# Module authorship / contact metadata.
__author__ = "Simon Nilsson; sronilsson@gmail.com"
import ast
import os
from typing import Optional, Union
import pandas as pd
from sklearn.model_selection import train_test_split
from simba.mixins.config_reader import ConfigReader
from simba.mixins.train_model_mixin import TrainModelMixin
from simba.utils.checks import check_if_filepath_list_is_empty, check_int
from simba.utils.enums import (ConfigKey, Dtypes, Formats, Links, Methods,
MLParamKeys, Options)
from simba.utils.errors import ParametersFileError
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import read_config_entry, str_2_bool, write_df
class TrainRandomForestClassifier(ConfigReader, TrainModelMixin):
    """
    Train a single random forest model using hyperparameter setting and evaluation methods
    stored within the SimBA project config .ini file (``global environment``).

    :param Union[str, os.PathLike] config_path: path to SimBA project config file in Configparser format

    .. note::
       `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/Scenario1.md#step-7-train-machine-model>`_

    :example:
    >>> model_trainer = TrainRandomForestClassifier(config_path='MyConfigPath')
    >>> model_trainer.run()
    >>> model_trainer.save()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike]):
        ConfigReader.__init__(self, config_path=config_path, create_logger=False)
        TrainModelMixin.__init__(self)
        # Presumably sets the single-model attributes used below and in run() (clf_name, algo,
        # rf_max_depth, tt_size, sampling / split settings, ...) -- TODO confirm in the mixins.
        self.read_model_settings_from_config(config=self.config)
        check_if_filepath_list_is_empty(filepaths=self.target_file_paths, error_msg=f"Zero annotation files found in directory {self.targets_folder}, cannot create model.")
        self.bp_config = read_config_entry(config=self.config, section=ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, option=ConfigKey.POSE_SETTING.value, default_value='user_defined', data_type=Dtypes.STR.value)
        stdout_information(msg=f"Reading in {len(self.target_file_paths)} annotated files...")
        # "None" is the classifier name left in the config when no single-model settings were saved.
        if isinstance(self.clf_name, str) and self.clf_name.lower() == Dtypes.NONE.value.lower():
            raise ParametersFileError(msg=f'The single classifier name is named "None". Have you set the model settings for a SINGLE model in SimBA? "None" is the name of a behavior if the behavior has not been set. Please set the hyperparameters and click "SAVE SETTINGS (SPECIFIC MODEL)" before training the model as detailed HERE: {Links.TRAIN_ML_MODEL.value}')
        self.data_df, self.frm_idx = self.read_all_files_in_folder_mp_futures(self.target_file_paths, self.file_type, [self.clf_name])
        self.frm_idx = pd.DataFrame({"VIDEO": list(self.data_df.index), "FRAME_IDX": self.frm_idx})
        self.data_df = self.check_raw_dataset_integrity(df=self.data_df, logs_path=self.logs_path)
        self.data_df_wo_cords = self.drop_bp_cords(df=self.data_df)
        # Presumably removes the annotation columns of the project's other classifiers so they are
        # not used as features for this one -- confirm in TrainModelMixin.
        annotation_cols_to_remove = self.read_in_all_model_names_to_remove(self.config, self.clf_cnt, self.clf_name)
        self.x_y_df = self.delete_other_annotation_columns(self.data_df_wo_cords, list(annotation_cols_to_remove))
        self.class_names = ["Not_" + self.clf_name, self.clf_name]
        self.x_df, self.y_df = self.split_df_to_x_y(self.x_y_df, self.clf_name)
        self.feature_names = self.x_df.columns
        self.check_sampled_dataset_integrity(x_df=self.x_df, y_df=self.y_df)
        present_cnt = self.y_df.sum()
        # Scale to a percentage *before* rounding; rounding the ratio first and then multiplying by
        # 100 can print floating-point noise such as 12.340000000000002%.
        present_pct = round(present_cnt / len(self.y_df) * 100, 2)
        stdout_information(msg=f"Number of features in dataset: {len(self.x_df.columns)}")
        stdout_information(msg=f"Number of {self.clf_name} frames in dataset: {int(present_cnt)} ({present_pct}%)")

    def run(self):
        """
        Method for training single random forest model.

        Reads the random forest hyperparameters and the evaluation options from the
        ``ConfigKey.CREATE_ENSEMBLE_SETTINGS`` section of the project config, fits the classifier on
        the training split and writes every requested evaluation output to ``self.eval_out_path``.
        Nothing is fitted unless ``self.algo == "RF"``.
        """
        stdout_information(msg="Training and evaluating model...")
        self.timer = SimbaTimer(start=True)
        # Presumably creates self.x_train / self.x_test / self.y_train / self.y_test used below.
        self.perform_sampling()
        if self.algo != "RF":
            # NOTE(review): other algorithms are silently skipped, leaving self.rf_clf unset for save().
            return
        perform_flags = Options.PERFORM_FLAGS.value

        # --- Random forest hyperparameters ---
        n_estimators = self._read_ensemble_setting(MLParamKeys.RF_ESTIMATORS.value, Dtypes.INT.value)
        max_features = self._read_ensemble_setting(MLParamKeys.RF_MAX_FEATURES.value, Dtypes.STR.value)
        if max_features == "None":
            max_features = None
        criterion = self._read_ensemble_setting(MLParamKeys.RF_CRITERION.value, Dtypes.STR.value, options=Options.CLF_CRITERION.value)
        min_sample_leaf = self._read_ensemble_setting(MLParamKeys.MIN_LEAF.value, Dtypes.INT.value)

        # --- Evaluation toggles: kept as raw config values (they are written verbatim to the
        # meta-data file) and tested for membership in Options.PERFORM_FLAGS. ---
        compute_permutation_importance = self._read_ensemble_setting(MLParamKeys.PERMUTATION_IMPORTANCE.value, Dtypes.STR.value, default_value=False)
        generate_learning_curve = self._read_ensemble_setting(MLParamKeys.LEARNING_CURVE.value, Dtypes.STR.value, default_value=False)
        generate_precision_recall_curve = self._read_ensemble_setting(MLParamKeys.PRECISION_RECALL.value, Dtypes.STR.value, default_value=False)
        generate_example_decision_tree = self._read_ensemble_setting(MLParamKeys.EX_DECISION_TREE.value, Dtypes.STR.value, default_value=False)
        generate_classification_report = self._read_ensemble_setting(MLParamKeys.CLF_REPORT.value, Dtypes.STR.value, default_value=False)
        generate_features_importance_log = self._read_ensemble_setting(MLParamKeys.IMPORTANCE_LOG.value, Dtypes.STR.value, default_value=False)
        # Must read the bar-chart key (not IMPORTANCE_LOG), otherwise the bar-chart toggle just
        # mirrors the importance-log toggle.
        generate_features_importance_bar_graph = self._read_ensemble_setting(MLParamKeys.IMPORTANCE_BAR_CHART.value, Dtypes.STR.value, default_value=False)
        generate_example_decision_tree_fancy = self._read_ensemble_setting(MLParamKeys.EX_DECISION_TREE_FANCY.value, Dtypes.STR.value, default_value=False)
        generate_shap_scores = self._read_ensemble_setting(MLParamKeys.SHAP_SCORES.value, Dtypes.STR.value, default_value=False)
        save_meta_data = self._read_ensemble_setting(MLParamKeys.RF_METADATA.value, Dtypes.STR.value, default_value=False)
        compute_partial_dependency = self._read_ensemble_setting(MLParamKeys.PARTIAL_DEPENDENCY.value, Dtypes.STR.value, default_value=False)
        # String default so str_2_bool always receives a string when the CUDA option is absent.
        cuda = str_2_bool(self._read_ensemble_setting(MLParamKeys.CUDA.value, Dtypes.STR.value, default_value="False"))
        class_weights = self._read_class_weights()

        # --- Options that only matter when their evaluation toggle is on ---
        if generate_learning_curve in perform_flags:
            shuffle_splits = self._read_ensemble_setting(MLParamKeys.LEARNING_CURVE_K_SPLITS.value, Dtypes.INT.value, default_value=Dtypes.NAN.value)
            dataset_splits = self._read_ensemble_setting(MLParamKeys.LEARNING_CURVE_DATA_SPLITS.value, Dtypes.INT.value, default_value=Dtypes.NAN.value)
            check_int(name=MLParamKeys.LEARNING_CURVE_K_SPLITS.value, value=shuffle_splits)
            check_int(name=MLParamKeys.LEARNING_CURVE_DATA_SPLITS.value, value=dataset_splits)
        else:
            shuffle_splits, dataset_splits = Dtypes.NAN.value, Dtypes.NAN.value
        if generate_features_importance_bar_graph in perform_flags:
            feature_importance_bars = self._read_ensemble_setting(MLParamKeys.IMPORTANCE_BARS_N.value, Dtypes.INT.value, default_value=Dtypes.NAN.value)
            check_int(name=MLParamKeys.IMPORTANCE_BARS_N.value, value=feature_importance_bars, min_value=1)
        else:
            feature_importance_bars = Dtypes.NAN.value
        shap_target_present_cnt = shap_target_absent_cnt = shap_save_n = shap_multiprocess = None
        if generate_shap_scores in perform_flags:
            shap_target_present_cnt = self._read_ensemble_setting(MLParamKeys.SHAP_PRESENT.value, Dtypes.INT.value, default_value=0)
            shap_target_absent_cnt = self._read_ensemble_setting(MLParamKeys.SHAP_ABSENT.value, Dtypes.INT.value, default_value=0)
            shap_save_n = self._read_ensemble_setting(MLParamKeys.SHAP_SAVE_ITERATION.value, Dtypes.STR.value, default_value=Dtypes.NONE.value)
            shap_multiprocess = self._read_ensemble_setting(MLParamKeys.SHAP_MULTIPROCESS.value, Dtypes.STR.value, default_value="False")
            try:
                shap_save_n = int(shap_save_n)
            except (TypeError, ValueError):
                # Non-numeric setting (e.g. "None"): fall back to the total number of SHAP samples.
                shap_save_n = shap_target_present_cnt + shap_target_absent_cnt
            check_int(name=MLParamKeys.SHAP_PRESENT.value, value=shap_target_present_cnt)
            check_int(name=MLParamKeys.SHAP_ABSENT.value, value=shap_target_absent_cnt)

        # --- Fit ---
        self.rf_clf = self.clf_define(n_estimators=n_estimators,
                                      max_depth=self.rf_max_depth,
                                      max_features=max_features,
                                      n_jobs=-1,
                                      criterion=criterion,
                                      min_samples_leaf=min_sample_leaf,
                                      verbose=1,
                                      class_weight=class_weights,
                                      cuda=cuda)
        stdout_information(msg=f"Fitting {self.clf_name} model...")
        self.rf_clf = self.clf_fit(clf=self.rf_clf, x_df=self.x_train, y_df=self.y_train)

        # --- Evaluation outputs ---
        if compute_permutation_importance in perform_flags:
            self.calc_permutation_importance(x_test=self.x_test, y_test=self.y_test, clf=self.rf_clf, feature_names=self.feature_names, clf_name=self.clf_name, save_dir=self.eval_out_path, save_file_no=None, plot=True)
        if generate_learning_curve in perform_flags:
            self.calc_learning_curve(x_y_df=self.x_y_df, clf_name=self.clf_name, shuffle_splits=shuffle_splits, dataset_splits=dataset_splits, tt_size=self.tt_size, rf_clf=self.rf_clf, save_dir=self.eval_out_path)
        if generate_precision_recall_curve in perform_flags:
            self.calc_pr_curve(rf_clf=self.rf_clf, x_df=self.x_test, y_df=self.y_test, clf_name=self.clf_name, save_dir=self.eval_out_path)
        if generate_example_decision_tree in perform_flags:
            self.create_example_dt(self.rf_clf, self.clf_name, self.feature_names, self.class_names, self.eval_out_path)
        if generate_classification_report in perform_flags:
            self.create_clf_report(self.rf_clf, self.x_test, self.y_test, self.class_names, self.eval_out_path)
        if generate_features_importance_log in perform_flags:
            self.create_x_importance_log(rf_clf=self.rf_clf, x_names=self.feature_names, clf_name=self.clf_name, save_dir=self.eval_out_path)
        if generate_features_importance_bar_graph in perform_flags:
            self.create_x_importance_bar_chart(rf_clf=self.rf_clf, x_names=self.feature_names, clf_name=self.clf_name, save_dir=self.eval_out_path, n_bars=feature_importance_bars)
        if generate_example_decision_tree_fancy in perform_flags:
            self.dviz_classification_visualization(self.x_train, self.y_train, self.clf_name, self.class_names, self.eval_out_path)
        if generate_shap_scores in perform_flags:
            # SHAP plots are only requested for the '14' and '16' pose settings.
            shap_plot = self.bp_config in {'14', '16'}
            if shap_multiprocess not in perform_flags:
                self.create_shap_log(rf_clf=self.rf_clf,
                                     x=self.x_train,
                                     y=self.y_train,
                                     x_names=list(self.feature_names),
                                     clf_name=self.clf_name,
                                     cnt_present=shap_target_present_cnt,
                                     cnt_absent=shap_target_absent_cnt,
                                     verbose=True,
                                     plot=shap_plot,
                                     save_it=shap_save_n,
                                     save_dir=self.eval_out_path)
            else:
                self.create_shap_log_concurrent_mp(rf_clf=self.rf_clf,
                                                   x=self.x_train,
                                                   y=self.y_train,
                                                   x_names=list(self.feature_names),
                                                   clf_name=self.clf_name,
                                                   cnt_present=shap_target_present_cnt,
                                                   cnt_absent=shap_target_absent_cnt,
                                                   save_dir=self.eval_out_path,
                                                   plot=shap_plot)
        if compute_partial_dependency in perform_flags:
            self.partial_dependence_calculator(clf=self.rf_clf, x_df=self.x_train, clf_name=self.clf_name, save_dir=self.eval_out_path)

        # --- Meta-data: one-row CSV of the settings used for this model ---
        if save_meta_data in perform_flags:
            stdout_information(msg="Saving model meta data file...")
            save_path = os.path.join(self.eval_out_path, f"{self.clf_name}_meta.csv")
            # NOTE(review): rf_max_depth is not recorded in this table; it previously (wrongly) filled
            # the CLASS_WEIGHTS column. Consider adding a dedicated max-depth column.
            table = {MLParamKeys.CLASSIFIER_NAME.value: self.clf_name,
                     MLParamKeys.RF_CRITERION.value: criterion,
                     MLParamKeys.RF_MAX_FEATURES.value: max_features,
                     MLParamKeys.MIN_LEAF.value: min_sample_leaf,
                     MLParamKeys.RF_ESTIMATORS.value: n_estimators,
                     MLParamKeys.PERMUTATION_IMPORTANCE.value: compute_permutation_importance,
                     MLParamKeys.CLF_REPORT.value: generate_classification_report,
                     MLParamKeys.EX_DECISION_TREE.value: generate_example_decision_tree,
                     MLParamKeys.IMPORTANCE_BAR_CHART.value: generate_features_importance_bar_graph,
                     MLParamKeys.IMPORTANCE_LOG.value: generate_features_importance_log,
                     MLParamKeys.PRECISION_RECALL.value: generate_precision_recall_curve,
                     MLParamKeys.RF_METADATA.value: save_meta_data,
                     MLParamKeys.LEARNING_CURVE.value: generate_learning_curve,
                     MLParamKeys.LEARNING_CURVE_DATA_SPLITS.value: dataset_splits,
                     MLParamKeys.LEARNING_CURVE_K_SPLITS.value: shuffle_splits,
                     MLParamKeys.N_FEATURE_IMPORTANCE_BARS.value: feature_importance_bars,
                     MLParamKeys.OVERSAMPLE_RATIO.value: self.over_sample_ratio,
                     MLParamKeys.OVERSAMPLE_SETTING.value: self.over_sample_setting,
                     MLParamKeys.TT_SIZE.value: self.tt_size,
                     MLParamKeys.TRAIN_TEST_SPLIT_TYPE.value: self.split_type,
                     MLParamKeys.UNDERSAMPLE_RATIO.value: self.under_sample_ratio,
                     MLParamKeys.UNDERSAMPLE_SETTING.value: self.under_sample_setting,
                     MLParamKeys.CLASS_WEIGHTS.value: class_weights,
                     MLParamKeys.CUDA.value: cuda}
            out_df = pd.DataFrame([list(table.values())], columns=list(table.keys()))
            out_df.to_csv(save_path)

    def save(self) -> None:
        """
        Method for saving pickled RF model. The model is saved in the `models/generated_models` directory
        of the SimBA project tree.

        Must be called after :meth:`run`, which starts ``self.timer`` and fits ``self.rf_clf``.
        """
        self.timer.stop_timer()
        # exist_ok: creating the directory is a no-op when it already exists. The previous
        # os.listdir() check raised FileNotFoundError for a missing directory and let makedirs()
        # raise FileExistsError for an existing empty one.
        os.makedirs(self.model_dir_out, exist_ok=True)
        self.save_rf_model(self.rf_clf, self.clf_name, self.model_dir_out)
        stdout_success(msg=f"Classifier {self.clf_name} saved in {self.model_dir_out} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
        stdout_success(msg=f"Evaluation files are in {self.eval_out_path} folders", source=self.__class__.__name__)

    def _read_ensemble_setting(self, option: str, data_type, **kwargs):
        """
        Read ``option`` from the ``ConfigKey.CREATE_ENSEMBLE_SETTINGS`` section of the project config.

        Extra keyword arguments (e.g. ``default_value``, ``options``) are forwarded to
        ``read_config_entry`` unchanged; omitting them keeps ``read_config_entry``'s own defaults.
        """
        return read_config_entry(self.config, ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, option, data_type=data_type, **kwargs)

    def _read_class_weights(self) -> Optional[Union[str, dict]]:
        """
        Resolve the ``class_weight`` argument for the classifier from the config.

        :return: ``None`` when the class-weights option is absent or set to "None"; a dict with
                 ``int`` values parsed from the custom-weights option when the setting is
                 ``"custom"``; otherwise the raw string setting.
        """
        if not self.config.has_option(ConfigKey.CREATE_ENSEMBLE_SETTINGS.value, MLParamKeys.CLASS_WEIGHTS.value):
            return None
        class_weights = self._read_ensemble_setting(MLParamKeys.CLASS_WEIGHTS.value, Dtypes.STR.value, default_value=Dtypes.NONE.value)
        if class_weights == "custom":
            # literal_eval (not eval) only accepts Python literals, so the config string cannot run code.
            custom_weights = ast.literal_eval(self._read_ensemble_setting(MLParamKeys.CLASS_CUSTOM_WEIGHTS.value, Dtypes.STR.value))
            return {k: int(v) for k, v in custom_weights.items()}
        if class_weights == Dtypes.NONE.value:
            return None
        return class_weights
# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\open_field_rearing\project_folder\project_config.ini")
# test.run()
# test.save()
# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test.save()
# test = TrainRandomForestClassifier(config_path=r"/mnt/c/troubleshooting/mitra/project_folder/project_config.ini")
# test.run()
# test.save()
# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test.save()
# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test.save()
# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test.save()
#
# test = TrainRandomForestClassifier(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test.save()
# test = TrainRandomForestClassifier(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini')
# test.run()
# test.save()
# test = TrainRandomForestClassifier(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini')
# test.perform_sampling()
# test.train_model()
# test.save_model()
# test = TrainSingleModel(config_path='/Users/simon/Desktop/envs/troubleshooting/locomotion/project_folder/project_config.ini')
# test.perform_sampling()
# test.train_model()
# test.save_model()
# test = TrainSingleModel(config_path='/Users/simon/Desktop/envs/troubleshooting/prueba/project_folder/project_config.ini')
# test.perform_sampling()
# test.train_model()
# test.save_model()
# test = TrainSingleModel(config_path='/Users/simon/Desktop/envs/troubleshooting/naresh/project_folder/project_config.ini')
# test.perform_sampling()
# test.train_model()
# test.save_model()
# test = TrainSingleModel(config_path='/Users/simon/Desktop/envs/troubleshooting/Lucas/project_folder/project_config.ini')
# test.perform_sampling()
# test.train_model()
# test.save_model()