# Source code for simba.model.train_multiclass_rf

import ast
import json
import os
from itertools import product
from typing import Dict, Union

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from simba.mixins.config_reader import ConfigReader
from simba.mixins.train_model_mixin import TrainModelMixin
from simba.utils.checks import (check_if_filepath_list_is_empty,
                                check_if_list_contains_values,
                                check_if_valid_input)
from simba.utils.enums import (ConfigKey, Dtypes, Formats, Methods,
                               MLParamKeys, Options, TagNames)
from simba.utils.errors import SamplingError
from simba.utils.printing import SimbaTimer, log_event, stdout_success
from simba.utils.read_write import read_config_entry, write_df


class TrainMultiClassRandomForestClassifier(ConfigReader, TrainModelMixin):
    """
    Train a single multi-class random forest classifier from the annotated
    files of a SimBA project.

    The constructor reads the project config, loads every annotated file,
    and builds the feature (``x_df``) and target (``y_df``) frames. ``run``
    splits the data, fits the forest and writes the evaluation outputs that
    are enabled in the config. ``save_model`` stores the fitted model.

    :param Union[str, os.PathLike] config_path: Path to the SimBA project config file.

    :example:
    >>> model_trainer = TrainMultiClassRandomForestClassifier(config_path='project_folder/project_config.ini')
    >>> model_trainer.run()
    >>> model_trainer.save_model()
    """

    def __init__(self, config_path: Union[str, os.PathLike]):
        ConfigReader.__init__(self, config_path=config_path)
        TrainModelMixin.__init__(self)
        log_event(
            logger_name=str(self.__class__.__name__),
            log_type=TagNames.CLASS_INIT.value,
            msg=self.create_log_msg_from_init_args(locals=locals()),
        )
        self.read_model_settings_from_config(config=self.config)
        self.bp_config = read_config_entry(
            config=self.config,
            section=ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
            option=ConfigKey.POSE_SETTING.value,
            default_value='user_defined',
            data_type=Dtypes.STR.value,
        )
        check_if_filepath_list_is_empty(
            filepaths=self.target_file_paths,
            error_msg="Zero annotation files found in project_folder/csv/targets_inserted, cannot create model.",
        )
        print(f"Reading in {len(self.target_file_paths)} annotated files...")
        self.data_df, self.frm_idx = self.read_all_files_in_folder_mp_futures(
            annotations_file_paths=self.target_file_paths,
            file_type=self.file_type,
            classifier_names=self.clf_names,
            raise_bool_clf_error=False,
        )
        # Keep a (video, frame index) lookup so the train/test membership of
        # each row can be written to disk later in `_perform_sampling`.
        self.frm_idx = pd.DataFrame(
            {"VIDEO": list(self.data_df.index), "FRAME_IDX": self.frm_idx}
        )
        self.check_raw_dataset_integrity(df=self.data_df, logs_path=self.logs_path)
        self.data_df_wo_cords = self.drop_bp_cords(df=self.data_df)
        annotation_cols_to_remove = self.read_in_all_model_names_to_remove(
            self.config, self.clf_cnt, self.clf_name
        )
        self.x_y_df = self.delete_other_annotation_columns(
            df=self.data_df_wo_cords,
            annotations_lst=list(annotation_cols_to_remove),
            raise_error=False,
        )
        # `classifier_map` is stored in the config as a Python literal whose
        # keys are the class codes found in the target column and whose
        # values are the class names (see how it is used below).
        self.classifier_map = ast.literal_eval(
            read_config_entry(
                self.config,
                ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
                "classifier_map",
                data_type="str",
            )
        )
        self.under_sample_setting = (
            read_config_entry(
                self.config,
                ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
                MLParamKeys.UNDERSAMPLE_SETTING.value,
                data_type=Dtypes.STR.value,
                default_value=Dtypes.NONE.value,
            )
            .lower()
            .strip()
        )
        # Parsed as a Python literal; `_perform_sampling` reads the keys
        # "target_var" and "sampling_ratio" from it when under-sampling.
        self.under_sample_ratio = ast.literal_eval(
            read_config_entry(
                self.config,
                ConfigKey.CREATE_ENSEMBLE_SETTINGS.value,
                MLParamKeys.UNDERSAMPLE_RATIO.value,
                data_type=Dtypes.STR.value,
            )
        )
        self.x_df, self.y_df = self.split_df_to_x_y(self.x_y_df, self.clf_name)
        check_if_list_contains_values(
            data=list(self.y_df),
            values=list(self.classifier_map.keys()),
            name=self.clf_name,
        )
        self.feature_names = self.x_df.columns
        self.check_sampled_dataset_integrity(x_df=self.x_df, y_df=self.y_df)
        print(f"Number of features in dataset: {len(self.feature_names)}")
        for clf_code, clf_label in self.classifier_map.items():
            frm_cnt = len(self.y_df[self.y_df == clf_code])
            print(
                f"Number of {clf_label} frames in dataset: {frm_cnt} ({(frm_cnt / len(self.x_df)) * 100}%)"
            )

    def _check_presence_of_classes_post_sampling(self):
        """
        Verify that every class in ``classifier_map`` has at least one
        observation in both the training and the test targets.

        Must be called after ``_perform_sampling`` has created ``y_train``
        and ``y_test``.

        NOTE(review): nothing in this class calls this helper; it is kept
        available for callers that want the check.

        :raises SamplingError: If a class has zero observations in either set.
        """
        for y_set, clf_code in product(
            [self.y_train, self.y_test], self.classifier_map.keys()
        ):
            # Count within the split set itself (not the full `y_df`),
            # otherwise a class missing from a split goes undetected.
            obs_cnt = len(y_set[y_set == clf_code])
            if obs_cnt == 0:
                raise SamplingError(
                    msg=f"Zero observations of {self.classifier_map[clf_code]} found in the training and/or test set.",
                    source=self.__class__.__name__,
                )

    def _perform_sampling(self):
        """
        Split the data into train and test sets, optionally under-sample the
        training set, and optionally save the frame indexes of both sets.

        :raises SamplingError: If the split type is not the frames split, or
            if the under-sample setting is not recognized.
        """
        if self.split_type == Methods.SPLIT_TYPE_FRAMES.value:
            self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(
                self.x_df, self.y_df, test_size=self.tt_size
            )
        else:
            # Only the frames split is implemented here. Fail with a clear
            # message instead of an AttributeError on `self.x_train` later.
            raise SamplingError(
                msg=f"Train-test split type {self.split_type} not recognized. Options: [{Methods.SPLIT_TYPE_FRAMES.value}]",
                source=self.__class__.__name__,
            )

        if self.under_sample_setting.lower() != Dtypes.NONE.value.lower():
            if self.under_sample_setting == "random undersample multiclass frames":
                self.x_train, self.y_train = self.random_multiclass_frm_sampler(
                    x_df=self.x_train,
                    y_df=self.y_train,
                    target_field=self.clf_name,
                    target_var=self.under_sample_ratio["target_var"],
                    sampling_ratio=self.under_sample_ratio["sampling_ratio"],
                )
            elif self.under_sample_setting == "random undersample multiclass bouts":
                self.x_train, self.y_train = self.random_multiclass_bout_sampler(
                    x_df=self.x_train,
                    y_df=self.y_train,
                    target_field=self.clf_name,
                    target_var=self.under_sample_ratio["target_var"],
                    sampling_ratio=self.under_sample_ratio["sampling_ratio"],
                )
            else:
                raise SamplingError(
                    msg=f"Under sample setting {self.under_sample_setting} not recognized. Options: [None, random undersample multiclass frames, random undersample multiclass bouts]",
                    source=self.__class__.__name__,
                )

        if self.save_train_test_frm_info:
            # Row labels of x_train / x_test are matched against the row
            # labels of the (video, frame) lookup built in `__init__`.
            train_data = self.frm_idx[
                self.frm_idx.index.isin(self.x_train.index)
            ].set_index("VIDEO")
            test_data = self.frm_idx[
                self.frm_idx.index.isin(self.x_test.index)
            ].set_index("VIDEO")
            write_df(
                df=train_data,
                file_type=Formats.CSV.value,
                save_path=os.path.join(self.eval_out_path, "train_idx.csv"),
            )
            write_df(
                df=test_data,
                file_type=Formats.CSV.value,
                save_path=os.path.join(self.eval_out_path, "test_idx.csv"),
            )
            print(
                f"Frame index of train and test set saved in {self.eval_out_path} directory..."
            )

    def run(self):
        """
        Sample the data, fit the random forest on the training set, and
        create every evaluation output that is enabled in the config.

        The fitted model is stored in ``self.rf_clf``; call ``save_model``
        afterwards to write it to disk.
        """
        print("Training and evaluating model...")
        self._perform_sampling()
        self.timer = SimbaTimer(start=True)
        self.rf_clf = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.rf_max_depth,
            max_features=self.max_features,
            n_jobs=-1,
            criterion=self.criterion,
            min_samples_leaf=self.min_sample_leaf,
            bootstrap=True,
            verbose=1,
            class_weight=self.class_weights,
        )
        # Fit on the (optionally under-sampled) training split only, so the
        # test split stays unseen for the evaluations below.
        self.rf_clf = self.clf_fit(
            clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train
        )

        perform_flags = Options.PERFORM_FLAGS.value
        class_names = list(self.classifier_map.values())

        if self.compute_permutation_importance in perform_flags:
            self.calc_permutation_importance(
                x_test=self.x_test,
                y_test=self.y_test,
                clf=self.rf_clf,
                feature_names=self.feature_names,
                clf_name=self.clf_name,
                save_dir=self.eval_out_path,
                save_file_no=None,
                plot=True,
            )
        if self.generate_learning_curve in perform_flags:
            self.calc_learning_curve(
                x_y_df=self.x_y_df,
                clf_name=self.clf_name,
                shuffle_splits=self.shuffle_splits,
                dataset_splits=self.dataset_splits,
                tt_size=self.tt_size,
                rf_clf=self.rf_clf,
                save_dir=self.eval_out_path,
                multiclass=True,
            )
        if self.generate_precision_recall_curve in perform_flags:
            self.calc_pr_curve(
                rf_clf=self.rf_clf,
                x_df=self.x_test,
                y_df=self.y_test,
                clf_name=self.clf_name,
                save_dir=self.eval_out_path,
                multiclass=True,
                classifier_map=self.classifier_map,
            )
        if self.generate_example_decision_tree in perform_flags:
            self.create_example_dt(
                self.rf_clf,
                self.clf_name,
                self.feature_names,
                class_names,
                self.eval_out_path,
            )
        if self.generate_classification_report in perform_flags:
            self.create_clf_report(
                self.rf_clf,
                self.x_test,
                self.y_test,
                class_names,
                self.eval_out_path,
                clf_name=self.clf_name,
            )
        if self.generate_features_importance_log in perform_flags:
            self.create_x_importance_log(
                self.rf_clf, self.feature_names, self.clf_name, self.eval_out_path
            )
        if self.generate_features_importance_bar_graph in perform_flags:
            self.create_x_importance_bar_chart(
                self.rf_clf,
                self.feature_names,
                self.clf_name,
                self.eval_out_path,
                self.feature_importance_bars,
            )
        if self.generate_example_decision_tree_fancy in perform_flags:
            self.dviz_classification_visualization(
                self.x_train,
                self.y_train,
                self.clf_name,
                class_names,
                self.eval_out_path,
            )
        if self.generate_shap_scores in perform_flags:
            # SHAP plots are only requested for the '14' and '16' pose settings.
            shap_plot = self.bp_config in {'14', '16'}
            if self.shap_multiprocess not in perform_flags:
                self.create_shap_log(
                    rf_clf=self.rf_clf,
                    x=self.x_train,
                    y=self.y_train,
                    x_names=list(self.feature_names),
                    clf_name=self.clf_name,
                    cnt_present=self.shap_target_present_cnt,
                    cnt_absent=self.shap_target_absent_cnt,
                    verbose=True,
                    plot=shap_plot,
                    save_it=self.shap_save_n,
                    save_dir=self.eval_out_path,
                )
            else:
                self.create_shap_log_concurrent_mp(
                    rf_clf=self.rf_clf,
                    x=self.x_train,
                    y=self.y_train,
                    x_names=list(self.feature_names),
                    clf_name=self.clf_name,
                    cnt_present=self.shap_target_present_cnt,
                    cnt_absent=self.shap_target_absent_cnt,
                    save_dir=self.eval_out_path,
                    plot=shap_plot,
                )
        if self.compute_partial_dependency in perform_flags:
            self.partial_dependence_calculator(
                clf=self.rf_clf,
                x_df=self.x_train,
                clf_name=self.clf_name,
                save_dir=self.eval_out_path,
            )
        if self.save_meta_data in perform_flags:
            # The order of this list is consumed by
            # `create_meta_data_csv_training_one_model`; do not reorder.
            meta_data_lst = [
                self.clf_name,
                self.criterion,
                self.max_features,
                self.min_sample_leaf,
                self.n_estimators,
                self.compute_permutation_importance,
                self.generate_classification_report,
                self.generate_example_decision_tree,
                self.generate_features_importance_bar_graph,
                self.generate_features_importance_log,
                self.generate_precision_recall_curve,
                self.save_meta_data,
                self.generate_learning_curve,
                self.dataset_splits,
                self.shuffle_splits,
                self.feature_importance_bars,
                self.over_sample_ratio,
                self.over_sample_setting,
                self.tt_size,
                self.split_type,
                self.under_sample_ratio,
                self.under_sample_setting,
                str(self.class_weights),
                self.rf_max_depth,
            ]
            self.create_meta_data_csv_training_one_model(
                meta_data_lst, self.clf_name, self.eval_out_path
            )

    def save_model(self) -> None:
        """
        Method for saving pickled RF model. The model is saved in the
        `models/generated_models` directory of the SimBA project tree.

        Must be called after ``run``, which creates ``self.rf_clf`` and
        starts ``self.timer``.
        """
        self.timer.stop_timer()
        self.save_rf_model(self.rf_clf, self.clf_name, self.model_dir_out)
        stdout_success(
            msg=f"Classifier {self.clf_name} saved in models/generated_models directory",
            elapsed_time=self.timer.elapsed_time_str,
            source=self.__class__.__name__,
        )
        stdout_success(
            msg="Evaluation files are in models/generated_models/model_evaluations folders",
            source=self.__class__.__name__,
        )
# model_trainer = TrainMultiClassRandomForestClassifier(config_path='/Users/simon/Desktop/envs/troubleshooting/multilabel/project_folder/project_config.ini')
# model_trainer.run()
# model_trainer.save_model()