Multi-class classifiers

[ ]:
from simba.model.train_multiclass_rf import TrainMultiClassRandomForestClassifier
from simba.model.grid_search_multiclass_rf import GridSearchMulticlassRandomForestClassifier
from simba.model.inference_multiclass_batch import InferenceMulticlassBatch
from simba.plotting.plot_multiclass_clf_results import PlotMulticlassSklearnResultsSingleCore

CREATE MULTI-CLASS CLASSIFIERS REQUIREMENTS

Some requirements for this to run successfully (as of 10/23, there is no way of doing this in the GUI):

(1) All files inside the project_folder/csv/targets_inserted directory contain an annotation column whose name represents a class of mutually exclusive behaviors and whose values map to different types of that behavior class.

For example, the column may be named “Running”, and each row contains a value between 0 and 3, where 0 represents “no running”, 1 represents “slow speed running”, 2 represents “medium speed running”, and 3 represents “fast speed running”.
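Before training, it can help to confirm that the annotation column only holds the expected class values. The following is a minimal sketch using only the Python standard library; the function name, the `feature_a` column, and the in-memory CSV text are hypothetical and stand in for a real file in the targets_inserted directory. It is not part of SimBA.

```python
import csv
import io


def find_invalid_annotations(csv_text, column, allowed_values):
    """Return (row_index, value) pairs whose annotation is not an allowed class."""
    reader = csv.DictReader(io.StringIO(csv_text))
    invalid = []
    for row_idx, row in enumerate(reader):
        # Annotations may be stored as "2" or "2.0"; normalise both to int.
        value = int(float(row[column]))
        if value not in allowed_values:
            invalid.append((row_idx, value))
    return invalid


# Hypothetical four-frame file; real files also hold many feature columns.
example_csv = "feature_a,Running\n0.1,0\n0.4,2\n0.9,3\n0.2,5\n"
bad_rows = find_invalid_annotations(example_csv, "Running", {0, 1, 2, 3})
print(bad_rows)  # [(3, 5)]
```

An empty list means every frame carries one of the classes 0 to 3.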

(2) If creating a single classifier, the SimBA project_config.ini needs to contain a map pairing the values in the annotation column to behavior names. In the [SML settings] section, insert a classifier_map_1 option. See THIS config file as an example. For example, the [SML settings] section may read:

[SML settings]
model_dir = /Users/simon/Desktop/envs/troubleshooting/multilabel/models
model_path_1 = /Users/simon/Desktop/envs/troubleshooting/multilabel/models/generated_models/syllable_class.sav
no_targets = 1
target_name_1 = running
classifier_map_1 = {0: 'no running', 1: 'slow speed running', 2: 'medium speed running', 3: 'fast speed running'}
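To check that a classifier_map_1 entry is a well-formed dictionary before training, one option is to read it back with the standard library. This sketch assumes the value can be parsed as a Python literal; how SimBA itself parses the option is not shown here, and the config text below is an abbreviated in-memory copy of the example above.

```python
import ast
import configparser

config_text = """
[SML settings]
no_targets = 1
target_name_1 = running
classifier_map_1 = {0: 'no running', 1: 'slow speed running', 2: 'medium speed running', 3: 'fast speed running'}
"""

parser = configparser.ConfigParser()
parser.read_string(config_text)

# The option is stored as text; literal_eval turns it into a dict safely.
raw_map = parser.get("SML settings", "classifier_map_1")
classifier_map = ast.literal_eval(raw_map)

print(sorted(classifier_map))  # [0, 1, 2, 3]
print(classifier_map[3])       # fast speed running
```

If the entry contains a typo such as a missing quote or brace, `ast.literal_eval` raises an error instead of returning a dictionary.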

(3) If creating multiple multi-class classifiers, the SimBA meta files inside the project_folder/configs directory need to contain a map pairing the values in the annotation column to behavior names. This involves inserting a column named classifier_map and placing the map in the first row. See the last column in THIS config meta file as an example.
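Adding the classifier_map column can also be scripted. The sketch below works on a hypothetical two-column meta file held in memory; the column names `existing_column_a` and `existing_column_b` are placeholders, since real SimBA meta files contain other columns. Only the classifier_map column name comes from the text above.

```python
import ast
import csv
import io

# Hypothetical meta file content: one header row and one value row.
meta_text = "existing_column_a,existing_column_b\nrunning,2000\n"
classifier_map = {0: 'no running', 1: 'slow speed running',
                  2: 'medium speed running', 3: 'fast speed running'}

rows = list(csv.DictReader(io.StringIO(meta_text)))
fieldnames = list(rows[0].keys()) + ["classifier_map"]
rows[0]["classifier_map"] = str(classifier_map)  # map goes into the first row

out = io.StringIO()
writer = csv.DictWriter(out, fieldnames=fieldnames)
writer.writeheader()
writer.writerows(rows)
updated_meta_text = out.getvalue()

# Round trip: the csv module quotes the map, so its commas do not split the cell.
reread = list(csv.DictReader(io.StringIO(updated_meta_text)))
print(ast.literal_eval(reread[0]["classifier_map"]) == classifier_map)  # True
```

Using the csv module rather than manual string concatenation matters here because the map itself contains commas.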

(4) We can use sampling to balance the class distributions of annotated frames when creating classifiers.

  • When creating a single classifier: Open the project_config.ini in a text editor. Under the [create ensemble settings] section, find the under_sample_setting option and set it to either random undersample multiclass frames or random undersample multiclass bouts. Next, in the same [create ensemble settings] section, set the under_sample_ratio option to a dictionary that defines a baseline target class and the ratios at which to sample the non-target classes relative to the target class. For example, this is a valid entry: {'target_var': 0, 'sampling_ratio': {1: 1.0, 2: 1.0, 3: 1.0}}.

    Example I: You have 4 classes of behaviors: 0 represents “no running”, 1 represents “slow speed running”, 2 represents “medium speed running”, and 3 represents “fast speed running”. You have 100 annotated frames of “no running” and 500 annotated frames each of “slow speed running”, “medium speed running”, and “fast speed running”. Now you want to sample as many annotations in each class as there are in the “no running” class for training.

    Set the target_var to zero (no running) and the sampling ratios to 1.0 for each of the other classes: {'target_var': 0, 'sampling_ratio': {1: 1.0, 2: 1.0, 3: 1.0}}. For each of the non-target classes (1, 2, 3), a number of annotations equal to 100% of the count in the target class (0) will be sampled, which here is 100 frames per class.

    Example II: You have 100 annotated frames of “no running” and 500 annotated frames each of “slow speed running”, “medium speed running”, and “fast speed running”. Now you want each non-target class to have half the number of annotations present in the “no running” class. Set the target_var to zero (no running) and the sampling ratios to 0.5 for each of the non-target classes: {'target_var': 0, 'sampling_ratio': {1: 0.5, 2: 0.5, 3: 0.5}}. For each of the non-target classes (1, 2, 3), a number of annotations equal to 50% of the count in the target class (0) will be sampled, which here is 50 frames per class.

    Here is an example config file undersample setting for a multiclass task:

    [create ensemble settings]
    under_sample_setting = random undersample multiclass frames
    under_sample_ratio = {'target_var': 0, 'sampling_ratio': {1: 1, 2: 1, 3: 1}}
    
  • When creating multiple classifiers: Open the files in the project_folder/configs directory in a text editor.

    Under the under_sample_setting header, insert either random undersample multiclass frames or random undersample multiclass bouts. Under the under_sample_ratio header, insert your sampling ratios, e.g., {'target_var': 0, 'sampling_ratio': {1: 1, 2: 1, 3: 1}}.
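The arithmetic in Examples I and II can be reproduced with a short sketch. This is an illustration of frame-wise random undersampling under stated assumptions, not SimBA's implementation: the function name is hypothetical, all target-class frames are kept, and an error is raised when a class has too few annotated frames. The bout-wise variant is not covered.

```python
import random
from collections import Counter


def undersample_multiclass_frames(labels, target_var, sampling_ratio, seed=0):
    """Return sorted frame indices kept after random frame-wise undersampling.

    All frames of the target class are kept. For every other class,
    round(ratio * target_count) frames are drawn without replacement.
    """
    rng = random.Random(seed)
    target_idx = [i for i, y in enumerate(labels) if y == target_var]
    kept = list(target_idx)
    for cls, ratio in sampling_ratio.items():
        cls_idx = [i for i, y in enumerate(labels) if y == cls]
        n_wanted = int(round(len(target_idx) * ratio))
        if n_wanted > len(cls_idx):
            raise ValueError(
                f"class {cls}: need {n_wanted} frames, only {len(cls_idx)} annotated"
            )
        kept.extend(rng.sample(cls_idx, n_wanted))
    return sorted(kept)


# 100 frames of class 0 and 500 frames each of classes 1, 2, 3.
labels = [0] * 100 + [1] * 500 + [2] * 500 + [3] * 500

kept_i = undersample_multiclass_frames(labels, 0, {1: 1.0, 2: 1.0, 3: 1.0})
counts_i = Counter(labels[i] for i in kept_i)
kept_ii = undersample_multiclass_frames(labels, 0, {1: 0.5, 2: 0.5, 3: 0.5})
counts_ii = Counter(labels[i] for i in kept_ii)

print(dict(counts_i))   # {0: 100, 1: 100, 2: 100, 3: 100}
print(dict(counts_ii))  # {0: 100, 1: 50, 2: 50, 3: 50}
```

The ratios are always relative to the target class count, so changing target_var changes every sampled count.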

CREATE A SINGLE MULTI-CLASS CLASSIFIER

[2]:
#WE DEFINE THE PATH TO OUR SIMBA PROJECT CONFIG FILE
CONFIG_PATH = '/Users/simon/Desktop/envs/troubleshooting/multilabel/project_folder/project_config.ini'
[3]:
#WE CREATE A MULTI-CLASS TRAINER INSTANCE
model_trainer = TrainMultiClassRandomForestClassifier(config_path=CONFIG_PATH)
Reading in 1 annotated files...
Reading complete 01.YC015YC016phase45-sample (elapsed time: 51.3581s)...
Number of features in dataset: 14251
Number of None frames in dataset: 10 (25.0%)
Number of sharp frames in dataset: 10 (25.0%)
Number of track frames in dataset: 10 (25.0%)
Number of sync frames in dataset: 10 (25.0%)
[4]:
#WE RUN THIS MODEL TRAINER INSTANCE
model_trainer.run()
Training and evaluating model...
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
Calculating learning curves...
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.2s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.2s finished
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.1s
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.2s finished
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=-1)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-1)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
Learning curve calculation complete (elapsed time: 1.5253s) ...
Calculating PR curves...
Precision-recall curve calculation complete (elapsed time: 0.1334s) ...
Creating classification report visualization...
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
Creating feature importance log...
Creating feature importance bar chart...
Creating feature importance log...
[5]:
# FINALLY WE SAVE THE MODEL
model_trainer.save_model()
SIMBA COMPLETE: Classifier syllable_class saved in models/generated_models directory (elapsed time: 8.938s)     complete
SIMBA COMPLETE: Evaluation files are in models/generated_models/model_evaluations folders       complete

CREATE MULTIPLE MULTI-CLASS CLASSIFIERS, ONE FOR EACH FILE INSIDE THE PROJECT_FOLDER/CONFIGS DIRECTORY

[6]:
#WE CREATE A GRID-SEARCH MULTI-CLASS TRAINER INSTANCE
multi_model_trainer = GridSearchMulticlassRandomForestClassifier(config_path=CONFIG_PATH)
Reading in 1 annotated files...
SIMBA WARNING: MultiProcessingFailedWarning: Multi-processing file read failed, reverting to single core (increased run-time on large datasets).        warning
Reading in file 1/1...
Dataset size: 2.28352MB / 0.002284GB
1 file(s) read (elapsed time: 52.359s) ...
[7]:
#WE RUN THIS MODEL TRAINER INSTANCE
multi_model_trainer.run()
Training model 1/2 (syllable_class)...
Number of features in dataset: 14251
Number of None frames in dataset: 10 (25.0%)
Number of sharp frames in dataset: 10 (25.0%)
Number of track frames in dataset: 10 (25.0%)
Number of sync frames in dataset: 10 (25.0%)
MODEL 1 settings
+------------------------+----------------+
| Setting                | value          |
+========================+================+
| Model name             | syllable_class |
+------------------------+----------------+
| Ensemble method        | RF             |
+------------------------+----------------+
| Estimators (trees)     | 2000           |
+------------------------+----------------+
| Max features           | sqrt           |
+------------------------+----------------+
| Under sampling setting | None           |
+------------------------+----------------+
| Under sampling ratio   | nan            |
+------------------------+----------------+
| Over sampling setting  | None           |
+------------------------+----------------+
| Over sampling ratio    | nan            |
+------------------------+----------------+
| criterion              | gini           |
+------------------------+----------------+
| Min sample leaf        | 1              |
+------------------------+----------------+     TABLE
Fitting syllable_class model...
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    0.1s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:    0.3s
[Parallel(n_jobs=-1)]: Done 434 tasks      | elapsed:    0.6s
[Parallel(n_jobs=-1)]: Done 784 tasks      | elapsed:    1.1s
[Parallel(n_jobs=-1)]: Done 1234 tasks      | elapsed:    1.7s
[Parallel(n_jobs=-1)]: Done 1784 tasks      | elapsed:    2.5s
[Parallel(n_jobs=-1)]: Done 2000 out of 2000 | elapsed:    2.7s finished
Saving model meta data file...
Classifier syllable_class_0 saved in models/validations/model_files directory ...
SIMBA COMPLETE: All models and evaluations complete. The models and evaluation files are in models/validations folders  complete
Training model 2/2 (syllable_class)...
Number of features in dataset: 14251
Number of None frames in dataset: 10 (25.0%)
Number of sharp frames in dataset: 10 (25.0%)
Number of track frames in dataset: 10 (25.0%)
Number of sync frames in dataset: 10 (25.0%)
MODEL 2 settings
+------------------------+----------------+
| Setting                | value          |
+========================+================+
| Model name             | syllable_class |
+------------------------+----------------+
| Ensemble method        | RF             |
+------------------------+----------------+
| Estimators (trees)     | 2000           |
+------------------------+----------------+
| Max features           | sqrt           |
+------------------------+----------------+
| Under sampling setting | None           |
+------------------------+----------------+
| Under sampling ratio   | nan            |
+------------------------+----------------+
| Over sampling setting  | None           |
+------------------------+----------------+
| Over sampling ratio    | nan            |
+------------------------+----------------+
| criterion              | gini           |
+------------------------+----------------+
| Min sample leaf        | 1              |
+------------------------+----------------+     TABLE
Fitting syllable_class model...
[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed:    0.1s
[Parallel(n_jobs=-1)]: Done 184 tasks      | elapsed:    0.3s
[Parallel(n_jobs=-1)]: Done 434 tasks      | elapsed:    0.6s
[Parallel(n_jobs=-1)]: Done 784 tasks      | elapsed:    1.1s
[Parallel(n_jobs=-1)]: Done 1234 tasks      | elapsed:    1.7s
[Parallel(n_jobs=-1)]: Done 1784 tasks      | elapsed:    2.4s
Saving model meta data file...
Classifier syllable_class_1 saved in models/validations/model_files directory ...
SIMBA COMPLETE: All models and evaluations complete. The models and evaluation files are in models/validations folders  complete
[Parallel(n_jobs=-1)]: Done 2000 out of 2000 | elapsed:    2.7s finished

RUN INFERENCE

Now that we have created our multi-class model(s), we want to use one model to create predictions for new videos. In the project_config.ini, set the path to the model you want to use in the [SML settings] section under the model_path option (model_path_1 in the example config above).
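The model path can be edited by hand in a text editor, or scripted. The sketch below updates model_path_1 with the standard library on an in-memory config; all paths are hypothetical placeholders. Note that configparser drops comments when writing, so keep a backup before applying something like this to a real project_config.ini.

```python
import configparser
import io

# Hypothetical, abbreviated config; paths are placeholders.
config_text = """[SML settings]
model_dir = /path/to/models
model_path_1 = /path/to/models/old_model.sav
no_targets = 1
target_name_1 = running
"""

parser = configparser.ConfigParser()
parser.optionxform = str  # keep option names exactly as written
parser.read_string(config_text)

new_model_path = "/path/to/models/generated_models/new_model.sav"
parser.set("SML settings", "model_path_1", new_model_path)

# Write to a buffer here; for a real project, write to the config file instead.
buffer = io.StringIO()
parser.write(buffer)
updated_text = buffer.getvalue()
print(updated_text)
```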

[8]:
#WE CREATE A MULTI-CLASS BATCH INFERENCE INSTANCE
batch_inferencer = InferenceMulticlassBatch(config_path=CONFIG_PATH)

Analyzing 1 file(s) with 1 classifier(s)...
[9]:
#WE RUN THE BATCH INFERENCE INSTANCE
batch_inferencer.run()
Analyzing video 01.YC015YC016phase45-sample...
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done   6 out of  10 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=8)]: Done  10 out of  10 | elapsed:    0.0s finished
Predictions created for 01.YC015YC016phase45-sample (elapsed time: 5046.467) ...
SIMBA COMPLETE: Multi-class machine predictions complete. 1 file(s) saved in project_folder/csv/machine_results directory (elapsed time: 5046.4743s)    complete

After running batch_inferencer.run() in the cell directly above, the results are saved in the project_folder/csv/machine_results directory of the SimBA project. Rather than the one column per classifier created by Boolean classifiers, you will have one column per type of behavior in your classifier map. These fields contain the probabilities that each of the behaviors occurs in each frame (you may want to weight these scores to get the most accurate results).
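One way to turn the per-behavior probabilities into a single label per frame is to pick the behavior with the highest, optionally weighted, probability. This is a minimal sketch: the function name, the weights, and the probability values are made up for illustration, and the exact column names in your machine_results files may differ from the behavior names used as keys here.

```python
def most_likely_behavior(frame_probabilities, weights=None):
    """Return the behavior name with the highest (optionally weighted) probability."""
    weights = weights or {}
    weighted = {
        name: p * weights.get(name, 1.0)  # behaviors without a weight keep weight 1.0
        for name, p in frame_probabilities.items()
    }
    return max(weighted, key=weighted.get)


# Two hypothetical frames; each dict holds one probability per behavior.
frames = [
    {"no running": 0.70, "slow speed running": 0.20,
     "medium speed running": 0.05, "fast speed running": 0.05},
    {"no running": 0.40, "slow speed running": 0.35,
     "medium speed running": 0.15, "fast speed running": 0.10},
]

unweighted = [most_likely_behavior(f) for f in frames]
weighted = [most_likely_behavior(f, weights={"slow speed running": 1.5}) for f in frames]

print(unweighted)  # ['no running', 'no running']
print(weighted)    # ['no running', 'slow speed running']
```

Up-weighting a behavior makes it win in borderline frames such as the second one, while leaving clear-cut frames unchanged.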

VISUALIZE RESULTS

[13]:
#DEFINE HOW YOU WANT TO PLOT THE RESULTS
CREATE_FRAMES = False
CREATE_VIDEO = True
ROTATE_VIDEO = False
VIDEO_NAMES = ['01.YC015YC016phase45-sample.csv']
[14]:
#WE CREATE A MULTI-CLASS PLOTTING INSTANCE
multiclass_plotter = PlotMulticlassSklearnResultsSingleCore(config_path=CONFIG_PATH,
                                                            frame_setting=CREATE_FRAMES,
                                                            video_setting=CREATE_VIDEO,
                                                            video_names=VIDEO_NAMES,
                                                            rotate=ROTATE_VIDEO)
[12]:
#WE RUN THE PLOTTER (ONLY WORKS ON A SINGLE CORE AT THE MOMENT SO MIGHT TAKE SOME TIME)
multiclass_plotter.run()
SIMBA WARNING: FrameRangeWarning: The video /Users/simon/Desktop/envs/troubleshooting/multilabel/project_folder/videos/01.YC015YC016phase45-sample.mp4 contains 300 frames, while the data /Users/simon/Desktop/envs/troubleshooting/multilabel/project_folder/csv/machine_results/01.YC015YC016phase45-sample.csv contains 40 frames.  warning
Frame: 1 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 2 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 3 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 4 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 5 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 6 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 7 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 8 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 9 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 10 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 11 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 12 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 13 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 14 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 15 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 16 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 17 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 18 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 19 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 20 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 21 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 22 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 23 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 24 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 25 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 26 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 27 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 28 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 29 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 30 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 31 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 32 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 33 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 34 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 35 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 36 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 37 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 38 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 39 / 300. Video: 01.YC015YC016phase45-sample (1/1)
Frame: 40 / 300. Video: 01.YC015YC016phase45-sample (1/1)
SIMBA WARNING: FrameRangeWarning: Video terminated early: no data for frame 40 found in file /Users/simon/Desktop/envs/troubleshooting/multilabel/project_folder/csv/machine_results/01.YC015YC016phase45-sample.csv    warning