Shapley calculations: Example I

In this example, we have already trained a classifier and still have the data that was used to train it. We now want to compute SHAP (SHapley Additive exPlanations) scores that explain the classifier's predictions.
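SHAP scores are Shapley values. Each feature's score is its weighted average marginal contribution to the model output over every coalition of the other features. The sketch below computes Shapley values exactly, by brute-force enumeration, for a made-up three-feature function. Features outside a coalition are set to a baseline of 0, which is an assumption of this toy. SHAP's tree explainer in `tree_path_dependent` mode (reported in the log output further down) instead uses the tree structure to handle absent features, and a fast tree-specific algorithm. Enumeration like this is exponential in the number of features and is only for illustration.

```python
from itertools import combinations
from math import factorial


def shapley_values(value, n_features):
    """Exact Shapley values by enumerating every coalition of the other features.

    `value` maps a frozenset of feature indices (the coalition) to the model
    output when only those features are known.
    """
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(len(others) + 1):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
            for subset in combinations(others, size):
                s = frozenset(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Toy model f(z) = z0 + 2*z1 + z0*z2 explained at the point x (invented values).
x = [2.0, 1.0, 3.0]


def toy_value(coalition):
    # Features outside the coalition are replaced by a baseline of 0 (toy assumption).
    z = [x[j] if j in coalition else 0.0 for j in range(len(x))]
    return z[0] + 2 * z[1] + z[0] * z[2]


phi = shapley_values(toy_value, 3)
print([round(p, 6) for p in phi])  # [5.0, 2.0, 3.0]
```

The interaction term `z0*z2` (worth 6 at `x`) is split equally between features 0 and 2, and the scores add up to `f(x) - f(baseline) = 10`.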

[1]:
from simba.mixins.train_model_mixin import TrainModelMixin
from simba.mixins.config_reader import ConfigReader
from simba.utils.read_write import read_df, read_config_file
import glob
[2]:
# DEFINITIONS
CONFIG_PATH = '/Users/simon/Desktop/envs/troubleshooting/Nastacia_unsupervised/project_folder/project_config.ini'
CLASSIFIER_PATH = '/Users/simon/Desktop/envs/troubleshooting/Nastacia_unsupervised/models/generated_models/Attack.sav'
CLASSIFIER_NAME = 'Attack'
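COUNT_PRESENT = 10  # number of frames where the behavior is present to compute SHAP values for
COUNT_ABSENT = 10  # number of frames where the behavior is absent to compute SHAP values for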
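With `COUNT_PRESENT = 10` and `COUNT_ABSENT = 10`, the log output later reports 20 SHAP frames. The sketch below shows one way such a balanced selection could be made from a 0/1 annotation column. `sample_shap_frames` is a hypothetical helper, not a SimBA function. Whether `create_shap_log` samples randomly, and what it does when too few frames are annotated, are assumptions here.

```python
import random


def sample_shap_frames(labels, cnt_present, cnt_absent, seed=None):
    """Hypothetical helper: pick row indices to explain with SHAP.

    labels: sequence of 0/1 annotations for the target behavior.
    Returns cnt_present indices where the behavior is present,
    followed by cnt_absent indices where it is absent.
    """
    rng = random.Random(seed)
    present = [i for i, y in enumerate(labels) if y == 1]
    absent = [i for i, y in enumerate(labels) if y == 0]
    if len(present) < cnt_present or len(absent) < cnt_absent:
        # Assumed behavior for this sketch: refuse rather than sample with replacement.
        raise ValueError("Not enough annotated frames to sample from.")
    return rng.sample(present, cnt_present) + rng.sample(absent, cnt_absent)


labels = [0, 1] * 15  # toy annotations: 15 present, 15 absent
frames = sample_shap_frames(labels, cnt_present=10, cnt_absent=10, seed=42)
print(len(frames))  # 20
```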
[3]:
# READ IN THE CONFIG AND THE CLASSIFIER
config = read_config_file(config_path=CONFIG_PATH)
config_object = ConfigReader(config_path=CONFIG_PATH)
clf = read_df(file_path=CLASSIFIER_PATH, file_type='pickle')
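The classifier file (`Attack.sav`) is read with `file_type='pickle'`, which suggests the model was saved with Python's pickle protocol. The round trip below uses the standard library `pickle` module, with a plain dict standing in for the fitted model. It is an assumption that `read_df` uses `pickle.load` internally. It may use a different loader. Only load model files you trust, because unpickling can execute arbitrary code.

```python
import pickle
import tempfile
from pathlib import Path


def save_model(model, path):
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load_model(path):
    # Only unpickle files from a trusted source: pickle can run code on load.
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "Attack.sav"
    stand_in = {"clf_name": "Attack"}  # stand-in for a fitted classifier
    save_model(stand_in, path)
    restored = load_model(path)

print(restored)  # {'clf_name': 'Attack'}
```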
[4]:
# READ IN THE DATA

# Collect the paths of all files in the project_folder/csv/targets_inserted directory
file_paths = glob.glob(config_object.targets_folder + '/*' + config_object.file_type)

# Read the data from every file in `file_paths` and combine it into one DataFrame
data = TrainModelMixin().read_all_files_in_folder_mp(
    file_paths=file_paths,
    file_type=config.get('General settings', 'workflow_file_type').strip(),
).reset_index(drop=True)

# Find all behavior annotations that are NOT the target. E.g., if SHAP values are calculated for Attack, we need to find the other annotation columns in the data, such as Escape and Defensive.
non_target_annotations = TrainModelMixin().read_in_all_model_names_to_remove(config=config, model_cnt=config_object.clf_cnt, clf_name=CLASSIFIER_NAME)

# Remove the body-part coordinate columns and the non-target annotation columns from the data
data = data.drop(non_target_annotations + config_object.bp_headers, axis=1)

# Move the target annotation column out of `data` and into its own variable
target_df = data.pop(CLASSIFIER_NAME)

Dataset size: 27.391392MB / 0.027391GB
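After cell [4], `data` holds only feature columns: `DataFrame.drop` removes the body-part and non-target annotation columns, and `DataFrame.pop` returns the target column while also removing it from `data`. Both pandas calls raise `KeyError` for a column that does not exist. The plain-list sketch below mirrors that column bookkeeping. `split_columns` is a hypothetical helper and the column names are invented for illustration.

```python
def split_columns(columns, non_target_annotations, bp_headers, target):
    """Hypothetical helper: return (feature_columns, target_column).

    Mirrors data.drop(non_target + bp_headers, axis=1) followed by data.pop(target).
    """
    to_drop = set(non_target_annotations) | set(bp_headers)
    missing = (to_drop | {target}) - set(columns)
    if missing:
        # pandas drop() and pop() also raise KeyError for unknown columns.
        raise KeyError(sorted(missing))
    features = [c for c in columns if c not in to_drop and c != target]
    return features, target


# Invented column names for illustration.
columns = ["Nose_1_x", "Nose_1_y", "Movement_mouse_1", "Distance_1_2", "Attack", "Escape"]
features, target = split_columns(columns, ["Escape"], ["Nose_1_x", "Nose_1_y"], "Attack")
print(features)  # ['Movement_mouse_1', 'Distance_1_2']
```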
[5]:
TrainModelMixin().create_shap_log(ini_file_path=CONFIG_PATH,
                                  rf_clf=clf,
                                  x_df=data,
                                  y_df=target_df,
                                  x_names=data.columns,
                                  clf_name=CLASSIFIER_NAME,
                                  cnt_present=COUNT_PRESENT,
                                  cnt_absent=COUNT_ABSENT,
                                  save_path=config_object.logs_path)
Calculating SHAP values...
Setting feature_perturbation = "tree_path_dependent" because no background data was given.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 200 out of 200 | elapsed:    0.0s finished
Saving SHAP data after 0 iterations...
SHAP frame: 1 / 20, elapsed time: 0.1383...
SHAP frame: 2 / 20, elapsed time: 0.1237...
SHAP frame: 3 / 20, elapsed time: 0.1211...
SHAP frame: 4 / 20, elapsed time: 0.1303...
SHAP frame: 5 / 20, elapsed time: 0.1275...
SHAP frame: 6 / 20, elapsed time: 0.1265...
SHAP frame: 7 / 20, elapsed time: 0.1271...
SHAP frame: 8 / 20, elapsed time: 0.1203...
SHAP frame: 9 / 20, elapsed time: 0.1283...
SHAP frame: 10 / 20, elapsed time: 0.1273...
SHAP frame: 11 / 20, elapsed time: 0.1259...
SHAP frame: 12 / 20, elapsed time: 0.1287...
SHAP frame: 13 / 20, elapsed time: 0.1272...
SHAP frame: 14 / 20, elapsed time: 0.1276...
SHAP frame: 15 / 20, elapsed time: 0.1303...
SHAP frame: 16 / 20, elapsed time: 0.1297...
SHAP frame: 17 / 20, elapsed time: 0.1289...
SHAP frame: 18 / 20, elapsed time: 0.1267...
SHAP frame: 19 / 20, elapsed time: 0.1332...
Saving SHAP data after 19 iterations...
SHAP frame: 20 / 20, elapsed time: 0.1683...
SIMBA COMPLETE: SHAP calculations complete (elapsed time: 2.7615s)      complete
SIMBA COMPLETE: Aggregate SHAP statistics saved in /Users/simon/Desktop/envs/troubleshooting/Nastacia_unsupervised/project_folder/logs/shap directory (elapsed time: 0.1415s)   complete
SIMBA COMPLETE: SHAP summary graph saved at /Users/simon/Desktop/envs/troubleshooting/Nastacia_unsupervised/project_folder/logs/shap/SHAP_summary_line_graph_Attack_20230525112836.png (elapsed time: 0.0443s)  complete
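Once per-frame SHAP values are available, two common checks are (1) ranking features by mean absolute SHAP value and (2) verifying SHAP's local-accuracy property: for each frame, the explainer's base value plus the sum of that frame's SHAP values equals the model output being explained. The sketch below assumes the SHAP values have already been loaded into nested lists (one list per frame). The layout of SimBA's saved log files is not shown here, and SimBA's own aggregate statistics may be computed differently. The helper names and numbers are hypothetical.

```python
from math import isclose
from statistics import mean


def mean_abs_shap(shap_rows, feature_names):
    """Rank features by mean absolute SHAP value across explained frames."""
    scores = {
        name: mean(abs(row[j]) for row in shap_rows)
        for j, name in enumerate(feature_names)
    }
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


def check_additivity(base_value, shap_row, prediction, tol=1e-6):
    """Local accuracy: base value + sum of a frame's SHAP values == model output."""
    return isclose(base_value + sum(shap_row), prediction, abs_tol=tol)


# Invented numbers: two hypothetical features over three frames.
features = ["Movement_mouse_1", "Distance_1_2"]
rows = [[0.2, -0.05], [-0.1, 0.02], [0.3, 0.0]]
ranked = mean_abs_shap(rows, features)
print(ranked[0][0])  # Movement_mouse_1
print(check_additivity(0.1, rows[0], 0.25))  # True
```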