# Source code for simba.model.inference_batch

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import argparse
import os
import sys
from copy import deepcopy
from typing import Dict, List, Optional, Union

import numpy as np

from simba.data_processors.agg_clf_calculator import AggregateClfCalculator
from simba.mixins.config_reader import ConfigReader
from simba.mixins.train_model_mixin import TrainModelMixin
from simba.utils.checks import (
    check_all_file_names_are_represented_in_video_log,
    check_file_exist_and_readable, check_float, check_if_dir_exists,
    check_if_keys_exist_in_dict, check_int, check_that_column_exist,
    check_valid_dict, check_valid_lst)
from simba.utils.data import plug_holes_shortest_bout
from simba.utils.enums import ConfigKey, TagNames
from simba.utils.errors import InvalidInputError, NoFilesFoundError
from simba.utils.printing import (SimbaTimer, log_event, stdout_information,
                                  stdout_success)
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
                                    get_fn_ext, read_df, write_df)
from simba.utils.warnings import NoFileFoundWarning

MINIMUM_BOUT_LENGTH = 'minimum_bout_length'
THRESHOLD = 'threshold'
MODEL_NAME = 'model_name'
MODEL_PATH = 'model_path'

class InferenceBatch(TrainModelMixin, ConfigReader):
    """
    Run classifier inference on all featurized files in the ``project_folder/csv/features_extracted`` directory (or in ``features_dir``).
    Results are stored in the ``project_folder/csv/machine_results`` directory of the SimBA project (or in ``save_dir``).

    Each output file adds the following columns:

    * For a classifier ``<clf>``: a ``Probability_<clf>`` column holding the predicted probability, and a binary ``<clf>`` column.
      The binary column is 1 where the probability is strictly greater than the classifier threshold.
    * For a classifier with feature subsets: the columns are instead named ``Probability_<clf>_<subset>`` and ``<clf>_<subset>``.

    If the effective minimum bout length is above zero, the binary column is post-processed with
    :func:`simba.utils.data.plug_holes_shortest_bout` using that bout length (milliseconds).

    .. note::
       To compute aggregate statistics from the output of this class, see :func:`simba.data_processors.agg_clf_calculator.AggregateClfCalculator`, or pass ``save_agg_stats``.

    .. warning::
       When ``feature_subsets_by_clf`` is used, the project config is modified and rewritten on disk at ``config_path``.
       The classifier count and the ``target_name_<n>``, ``threshold_<n>`` and ``min_bout_<n>`` entries are overwritten with the ``<clf>_<subset>`` classifiers.

    :param Union[str, os.PathLike] config_path: path to SimBA project config file in Configparser format.
    :param Optional[Union[str, os.PathLike]] features_dir: Optional directory containing featurized files in CSV or parquet format (the project file type). If None, then the `project_folder/csv/features_extracted` directory of the project will be used.
    :param Optional[Union[str, os.PathLike]] save_dir: Optional existing directory to save the data for the analyzed videos. If None, then the `project_folder/csv/machine_results` directory of the project will be used.
    :param Optional[int] minimum_bout_length: Optional minimum bout length (milliseconds) override applied to every classifier. Must be a positive integer (zero is rejected). If None, classifier-specific minimum bout settings are used.
    :param Optional[Dict[str, Dict[str, List[str]]]] feature_subsets_by_clf: Optional per-classifier feature subsets to use during inference.
        Format: ``{classifier_name: {subset_name: [feature_col_1, feature_col_2, ...]}}``.
        Every ``classifier_name`` must be one of the classifiers being run.
        If provided, each classifier is applied once per subset and outputs are suffixed with the subset name.
    :param Optional[Dict[str, Dict[str, Union[str, int, float]]]] model_dict: Optional override of the classifiers to run.
        Format: ``{model_name: {'model_path': '/path/to/clf.sav', 'minimum_bout_length': 100, 'threshold': 0.5}}``.
        ``model_path`` must be a readable file, ``threshold`` must be in [0, 1], and ``minimum_bout_length`` must be an integer >= 0.
        If None, classifier definitions are read from the project config.
        When provided, these models replace the project-config classifiers for this run.
    :param Optional[Union[str, os.PathLike]] save_agg_stats: Optional existing directory in which to save aggregate classifier statistics. If None, no aggregate statistics are computed.
        If a directory is provided, :class:`simba.data_processors.agg_clf_calculator.AggregateClfCalculator` is run after inference completes.
        It reads from this class's ``save_dir`` and writes its CSV outputs to ``save_agg_stats``.
    :param bool verbose: If True, print progress and status messages during inference. Default: True.

    :example I:
    >>> inferencer = InferenceBatch(config_path='MyConfigPath')
    >>> inferencer.run()

    :example II:
    >>> inferencer = InferenceBatch(config_path=r"D:/troubleshooting/mitra/project_folder/project_config.ini", features_dir=r"D:/troubleshooting/mitra/project_folder/videos/bg_removed/rotated/tail_features/APPENDED")
    >>> inferencer.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 features_dir: Optional[Union[str, os.PathLike]] = None,
                 save_dir: Optional[Union[str, os.PathLike]] = None,
                 minimum_bout_length: Optional[int] = None,
                 feature_subsets_by_clf: Optional[Dict[str, Dict[str, List[str]]]] = None,
                 model_dict: Optional[Dict[str, Dict[str, Union[str, int, float]]]] = None,
                 save_agg_stats: Optional[Union[str, os.PathLike]] = None,
                 verbose: bool = True):

        # ConfigReader is expected to provide the project defaults used below
        # (feature_file_paths, features_dir, machine_results_dir, clf_names, clf_cnt, file_type, config).
        ConfigReader.__init__(self, config_path=config_path)
        if features_dir is not None:
            check_if_dir_exists(in_dir=features_dir, source=self.__class__.__name__)
            self.features_dir = deepcopy(features_dir)
            self.feature_file_paths = find_files_of_filetypes_in_directory(directory=self.features_dir, extensions=[f'.{self.file_type}'], raise_warning=False, raise_error=False)
        if save_dir is not None:
            check_if_dir_exists(in_dir=save_dir, source=self.__class__.__name__)
            self.save_dir = deepcopy(save_dir)
        else:
            self.save_dir = self.machine_results_dir
        TrainModelMixin.__init__(self)
        log_event(logger_name=str(self.__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals()))
        if len(self.feature_file_paths) == 0:
            raise NoFilesFoundError(msg=f"Zero files found in the {self.features_dir}. Create features before running classifier.", source=self.__class__.__name__,)

        if model_dict is not None:
            # Validate the user-supplied models and translate them into the same
            # {index: {model_path, model_name, threshold, minimum_bout_length}} layout used for config-derived models.
            check_valid_dict(x=model_dict, valid_key_dtypes=(str,), valid_values_dtypes=(dict,), min_len_keys=1, source=f'{self.__class__.__name__} model_dict')
            translated_model_dict = {}
            for cnt, (model_name, m_hyp) in enumerate(model_dict.items()):
                check_if_keys_exist_in_dict(data=m_hyp, key=[MODEL_PATH, THRESHOLD, MINIMUM_BOUT_LENGTH], name=f'{self.__class__.__name__} model_dict[{model_name}]', raise_error=True)
                check_file_exist_and_readable(file_path=m_hyp[MODEL_PATH])
                check_float(name=f'{self.__class__.__name__} model_dict[{model_name}] {THRESHOLD}', value=m_hyp[THRESHOLD], min_value=0.0, max_value=1.0)
                check_int(name=f'{self.__class__.__name__} model_dict[{model_name}] {MINIMUM_BOUT_LENGTH}', value=m_hyp[MINIMUM_BOUT_LENGTH], min_value=0)
                translated_model_dict[cnt] = {MODEL_PATH: m_hyp[MODEL_PATH],
                                              MODEL_NAME: model_name,
                                              THRESHOLD: float(m_hyp[THRESHOLD]),
                                              MINIMUM_BOUT_LENGTH: int(m_hyp[MINIMUM_BOUT_LENGTH])}
            # The override replaces the project classifiers for this run.
            self.clf_names = list(model_dict.keys())
            self.clf_cnt = len(self.clf_names)
            self._override_model_dict = translated_model_dict
        else:
            self._override_model_dict = None

        if feature_subsets_by_clf is not None:
            # Validated after a possible model_dict override so subsets are checked against the classifiers actually run.
            check_valid_dict(x=feature_subsets_by_clf, valid_key_dtypes=(str,), source=f'{self.__class__.__name__} feature_subsets_by_clf')
            for cnt, (k, v) in enumerate(feature_subsets_by_clf.items()):
                if k not in self.clf_names:
                    raise InvalidInputError(msg=f'Unknown classifier "{k}" in feature_subsets_by_clf. Valid classifiers: {self.clf_names}', source=self.__class__.__name__)
                check_valid_dict(x=v, valid_key_dtypes=(str,), valid_values_dtypes=(list,), source=f'{self.__class__.__name__} feature_subsets_by_clf {cnt}')
                for subset_name, feature_names in v.items():
                    check_valid_lst(data=feature_names, source=f'{self.__class__.__name__} feature_subsets_by_clf {k} {subset_name}', valid_dtypes=(str,), min_len=1, raise_error=True)

        if minimum_bout_length is not None:
            check_int(name=f'{self.__class__.__name__} minimum_bout_length', value=minimum_bout_length, allow_zero=False, allow_negative=False, raise_error=True)
        if save_agg_stats is not None:
            check_if_dir_exists(in_dir=save_agg_stats, source=f'{self.__class__.__name__} save_agg_stats')
        # Always assigned (possibly None) because run() reads it unconditionally.
        self.save_agg_stats = save_agg_stats
        self.verbose, self.feature_subsets_by_clf, self.minimum_bout_length = verbose, feature_subsets_by_clf, minimum_bout_length
        if verbose:
            stdout_information(msg=f"Analyzing {len(self.feature_file_paths)} file(s) with {self.clf_cnt} classifier(s)...")
        self.timer = SimbaTimer(start=True)
        self.model_dict = self._override_model_dict if self._override_model_dict is not None else self.get_model_info(config=self.config, model_cnt=self.clf_cnt)

    def _min_bout_for(self, m_hyp: dict):
        """Return the run-level ``minimum_bout_length`` override if set, otherwise the classifier's own minimum bout length."""
        return self.minimum_bout_length if self.minimum_bout_length is not None else m_hyp[MINIMUM_BOUT_LENGTH]

    def _predict_to_columns(self, out_df, clf, x_df, file_path: str, file_name: str, fps, m_hyp: dict, clf_col: str, model_name: str):
        """
        Predict with ``clf`` on ``x_df`` and write ``Probability_<clf_col>`` and binary ``<clf_col>`` columns into ``out_df``.

        Frames whose probability is strictly greater than the classifier threshold are set to 1.
        If the effective minimum bout length is above zero, the binary column is passed through ``plug_holes_shortest_bout``.

        :return: The updated output dataframe (``plug_holes_shortest_bout`` returns a new frame).
        """
        probability_column = f"Probability_{clf_col}"
        out_df[probability_column] = self.clf_predict_proba(clf=clf, x_df=x_df, data_path=file_path, model_name=model_name)
        out_df[clf_col] = np.where(out_df[probability_column] > m_hyp[THRESHOLD], 1, 0)
        clf_min_bout = self._min_bout_for(m_hyp)
        if int(clf_min_bout) > 0:
            if self.verbose:
                stdout_information(msg=f'Correcting minimum bouts in video {file_name} and classifier {m_hyp[MODEL_NAME]} ({clf_min_bout}ms)...')
            out_df = plug_holes_shortest_bout(data_df=out_df, clf_name=clf_col, fps=fps, shortest_bout=clf_min_bout)
        return out_df

    def _predict_feature_subsets(self, out_df, clf, x_df, file_path: str, file_name: str, fps, m_hyp: dict):
        """
        Apply ``clf`` once per feature subset registered for this classifier in ``feature_subsets_by_clf``.

        Side effect: the project config is updated with one ``<clf>_<subset>`` classifier per subset and rewritten at ``config_path``.

        :return: The updated output dataframe.
        """
        clf_name = m_hyp[MODEL_NAME]
        subsets = self.feature_subsets_by_clf[clf_name]
        # Presumably so that downstream SimBA tools reading the config see the subset columns as classifiers.
        self.config.set(section=ConfigKey.SML_SETTINGS.value, option=ConfigKey.TARGET_CNT.value, value=str(len(subsets.keys())))
        for mdl_cnt, (model_subset_name, model_x) in enumerate(subsets.items()):
            clf_col = f'{clf_name}_{model_subset_name}'
            self.config.set(section=ConfigKey.SML_SETTINGS.value, option=f'target_name_{mdl_cnt+1}', value=clf_col)
            self.config.set(section=ConfigKey.THRESHOLD_SETTINGS.value, option=f'threshold_{mdl_cnt+1}', value=f'{m_hyp[THRESHOLD]}')
            self.config.set(section=ConfigKey.MIN_BOUT_LENGTH.value, option=f'min_bout_{mdl_cnt+1}', value=f'{m_hyp[MINIMUM_BOUT_LENGTH]}')
            check_that_column_exist(df=x_df, column_name=model_x, file_name=file_path, raise_error=True)
            out_df = self._predict_to_columns(out_df=out_df, clf=clf, x_df=x_df[model_x], file_path=file_path, file_name=file_name, fps=fps, m_hyp=m_hyp, clf_col=clf_col, model_name=model_subset_name)
        with open(self.config_path, "w") as f:
            self.config.write(f)
        return out_df

    def _get_output_clf_names(self) -> List[str]:
        """
        Return the classifier column names written to the output files.

        A classifier with feature subsets produces only ``<clf>_<subset>`` columns (no bare ``<clf>`` column),
        so those names are expanded here. This keeps aggregate statistics pointed at columns that actually exist.
        """
        if self.feature_subsets_by_clf is None:
            return list(self.clf_names)
        out_names = []
        for clf_name in self.clf_names:
            if clf_name in self.feature_subsets_by_clf:
                out_names.extend(f'{clf_name}_{subset_name}' for subset_name in self.feature_subsets_by_clf[clf_name])
            else:
                out_names.append(clf_name)
        return out_names

    def run(self):
        """
        Run every classifier on every feature file and write one output file per input into ``save_dir``.

        A classifier whose model file is missing is skipped with a warning.
        If ``save_agg_stats`` was given, aggregate classifier statistics are computed from ``save_dir`` afterwards.
        """
        check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.feature_file_paths)
        for file_cnt, file_path in enumerate(self.feature_file_paths):
            video_timer = SimbaTimer(start=True)
            _, file_name, _ = get_fn_ext(file_path)
            if self.verbose:
                stdout_information(msg=f"Analyzing video {file_name}... (Video {file_cnt+1}/{len(self.feature_file_paths)})")
            file_save_path = os.path.join(self.save_dir, f"{file_name}.{self.file_type}")
            in_df = read_df(file_path, self.file_type)
            # Body-part coordinates are not model features; only the remaining columns are fed to the classifiers.
            x_df = self.drop_bp_cords(df=in_df).astype(np.float32)
            self.check_df_dataset_integrity(df=x_df, logs_path=self.logs_path, file_name=file_name)
            _, _, fps = self.read_video_info(video_name=file_name, raise_error=False)
            out_df = deepcopy(in_df)
            for m, m_hyp in self.model_dict.items():
                check_if_keys_exist_in_dict(data=m_hyp, key=[MODEL_PATH, MODEL_NAME, THRESHOLD, MINIMUM_BOUT_LENGTH], name=f'classifier {m}', raise_error=False)
                if not os.path.isfile(m_hyp[MODEL_PATH]):
                    # Report the classifier name (the dict key `m` is only a positional index).
                    NoFileFoundWarning(msg=f'SKIPPING CLASSIFIER {m_hyp.get(MODEL_NAME, m)} for video {file_name}. The classifier model file {m_hyp[MODEL_PATH]} could not be found.', source=self.__class__.__name__)
                    continue
                clf = self.read_pickle(file_path=m_hyp[MODEL_PATH])
                clf_name = m_hyp[MODEL_NAME]
                if self.feature_subsets_by_clf is None or clf_name not in self.feature_subsets_by_clf:
                    out_df = self._predict_to_columns(out_df=out_df, clf=clf, x_df=x_df, file_path=file_path, file_name=file_name, fps=fps, m_hyp=m_hyp, clf_col=clf_name, model_name=clf_name)
                else:
                    out_df = self._predict_feature_subsets(out_df=out_df, clf=clf, x_df=x_df, file_path=file_path, file_name=file_name, fps=fps, m_hyp=m_hyp)
            write_df(df=out_df, file_type=self.file_type, save_path=file_save_path)
            video_timer.stop_timer()
            if self.verbose:
                stdout_information(msg=f"Predictions created for {file_name} (frame count: {len(in_df)}, elapsed time: {video_timer.elapsed_time_str}) ...")

        if self.save_agg_stats is not None:
            if self.verbose:
                stdout_information(msg=f"Computing aggregate classifier statistics into {self.save_agg_stats}...")
            agg = AggregateClfCalculator(config_path=self.config_path, classifiers=self._get_output_clf_names(), data_dir=self.save_dir, save_dir=self.save_agg_stats)
            agg.run()
            agg.save()
        self.timer.stop_timer()
        stdout_success(msg=f"Machine predictions complete for {len(self.feature_file_paths)} file(s). Files saved in {self.save_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)
# if __name__ == "__main__" and not hasattr(sys, 'ps1'):
#     parser = argparse.ArgumentParser(description="Perform classifications according to rules defined in SimBA project_config.ini.")
#     parser.add_argument('--config_path', type=str, required=True, help='Path to SimBA Project config.')
#     args = parser.parse_args()
#     runner = InferenceBatch(config_path=args.config_path)
#     runner.run()
#
# test = InferenceBatch(config_path=r"H:\projects\jason_zhang\jason_project\project_folder\project_config.ini",
#                       save_dir=r'H:\projects\jason_zhang\jason_project\project_folder\csv\GROOMING\500_0.275_smoothing_500ms\csvs',
#                       model_dict={'GROOMING': {'model_path': r"H:\projects\jason_zhang\jason_project\models\GROOMING.sav", 'minimum_bout_length': 500, 'threshold': 0.275}},
#                       save_agg_stats=r'H:\projects\jason_zhang\jason_project\project_folder\csv\GROOMING\500_0.275_smoothing_500ms\csvs')
# test.run()
#
# test = InferenceBatch(config_path=r"E:\troubleshooting\two_black_animals_14bp\project_folder\project_config.ini")
# test.run()
# test = InferenceBatch(config_path=r"C:\troubleshooting\Top_down_old\project_folder\project_config.ini")
# test.run()
#
# test = InferenceBatch(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
#                       features_dir=r'C:\troubleshooting\mitra\project_folder\videos\additional\bg_removed\rotated\tail_features_additional\APPENDED')
# test.run()
# test = InferenceBatch(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini")
# test.run()
# test = InferenceBatch(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini",
#                       features_dir=r'D:\troubleshooting\mitra\project_folder\videos\bg_removed\rotated\laying_down_features\APPENDED')
# test.run()
# if __name__ == "__main__":
#     test = InferenceBatch(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini",
#                           features_dir=r'D:\troubleshooting\mitra\project_folder\videos\bg_removed\rotated\tail_features\APPENDED')
#     test.run()
# test = InferenceBatch(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/levi/project_folder/project_config.ini')
# test.run()
# test = InferenceBatch(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini')
# test.run()
# test = RunModel(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini')
# test.run_models()