Source code for simba.data_processors.cue_light_clf_statistics

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import glob
import os
from typing import List, Optional, Union

import pandas as pd

from simba.cue_light_tools.cue_light_tools import find_frames_when_cue_light_on
from simba.mixins.config_reader import ConfigReader
from simba.utils.checks import (
    check_all_file_names_are_represented_in_video_log, check_if_dir_exists,
    check_int, check_that_column_exist, check_valid_dataframe, check_valid_lst)
from simba.utils.data import detect_bouts
from simba.utils.enums import Formats
from simba.utils.errors import NoDataError, NoFilesFoundError
from simba.utils.printing import stdout_success
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
                                    get_fn_ext, read_config_entry, read_df)


class CueLightClfAnalyzer(ConfigReader):
    r"""
    Compute aggregate statistics when classified behaviors are occurring in relation to the cue light ON and OFF states.

    For every cue-light bout, the time each classifier is present/absent is computed (i) within the bout,
    (ii) within ``pre_window`` seconds before the bout onset, and (iii) within ``post_window`` seconds after
    the bout offset.

    :param Union[str, os.PathLike] config_path: path to SimBA project config file in Configparser format
    :param List[str] cue_light_names: Names of cue lights, as defined in the SimBA ROI interface.
    :param List[str] clf_names: Names of the classifiers we want to compute aggregate statistics for. Must be classifiers defined in the project.
    :param Optional[Union[str, os.PathLike]] data_dir: Directory holding the cue-light data files. If None, the project's ``cue_lights_data_dir`` is used.
    :param int pre_window: Time period (in seconds) before the onset of each cue light to compute aggregate classification statistics within. Default 0.
    :param int post_window: Time period (in seconds) after the offset of each cue light to compute aggregate classification statistics within. Default 0.

    .. note::
       `Cue light tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/cue_light_tutorial.md>`__.

    :example:
    >>> test = CueLightClfAnalyzer(config_path=r"C:\troubleshooting\cue_light\t1\project_folder\project_config.ini",
    >>>                            pre_window=1,
    >>>                            post_window=1,
    >>>                            cue_light_names=['cl'],
    >>>                            clf_names=['freeze'])
    >>> test.run()
    >>> test.save()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 cue_light_names: List[str],
                 clf_names: List[str],
                 data_dir: Optional[Union[str, os.PathLike]] = None,
                 pre_window: int = 0,
                 post_window: int = 0):

        ConfigReader.__init__(self, config_path=config_path)
        check_valid_lst(data=cue_light_names, source=f'{self.__class__.__name__} cue_light_names', valid_dtypes=(str,), min_len=1, raise_error=True)
        # At this point self.clf_names still holds the value set by ConfigReader; it is used to
        # validate the user-requested classifiers before being overwritten further down.
        check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), min_len=1, raise_error=True, valid_values=self.clf_names)
        check_int(name=f'{self.__class__.__name__} pre_window', value=pre_window, min_value=0)
        check_int(name=f'{self.__class__.__name__} post_window', value=post_window, min_value=0)
        if data_dir is None:
            self.data_dir = self.cue_lights_data_dir
        else:
            check_if_dir_exists(in_dir=data_dir)
            self.data_dir = data_dir
        self.cue_light_paths = find_files_of_filetypes_in_directory(directory=self.data_dir, extensions=[f'.{self.file_type}'], raise_error=True, as_dict=True)
        self.machine_results_paths = find_files_of_filetypes_in_directory(directory=self.machine_results_dir, extensions=[f'.{self.file_type}'], raise_error=True, as_dict=True)
        # Every cue-light file needs a classification file with the same key.
        missing_ml = [x for x in self.cue_light_paths if x not in self.machine_results_paths]
        if len(missing_ml) > 0:
            raise NoDataError(msg=f'{len(missing_ml)} cue-light file(s) are missing classification files in the {self.machine_results_dir} directory: {missing_ml}', source=self.__class__.__name__)
        self.cue_light_names, self.pre_window, self.post_window, self.clf_names = cue_light_names, pre_window, post_window, clf_names
        self.save_path = os.path.join(self.logs_path, f"Cue_lights_clf_statistics_{self.datetime}.csv")

    def _pre_window_columns(self) -> List[str]:
        """Return the result column names holding the pre-bout window statistics."""
        return [f'PRE CUE LIGHT BOUT ({self.pre_window}s) PRESENT (S)',
                f'PRE CUE LIGHT BOUT ({self.pre_window}s) ABSENT (S)']

    def _post_window_columns(self) -> List[str]:
        """Return the result column names holding the post-bout window statistics."""
        # Labelled with post_window (not pre_window) so the header reports the window actually used.
        return [f'POST CUE LIGHT BOUT ({self.post_window}s) PRESENT (S)',
                f'POST CUE LIGHT BOUT ({self.post_window}s) ABSENT (S)']

    def run(self) -> None:
        """
        Compute the per-bout classifier statistics for every cue-light file and store them in ``self.results``.

        One row is produced per video, cue-light bout and classifier.
        """
        check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=list(self.cue_light_paths.values()))
        # NOTE: the leading space in ' CUE LIGHT BOUT BEHAVIOR PRESENT (S)' is kept as-is so that
        # the header of the saved CSV stays compatible with earlier output.
        columns = ['VIDEO', 'CUE LIGHT', 'CLASSIFIER',
                   'CUE LIGHT BOUT START TIME', 'CUE LIGHT BOUT END TIME',
                   'CUE LIGHT BOUT START FRAME', 'CUE LIGHT BOUT END FRAME',
                   ' CUE LIGHT BOUT BEHAVIOR PRESENT (S)', 'CUE LIGHT BOUT BEHAVIOR ABSENT (S)']
        columns += self._pre_window_columns() + self._post_window_columns()
        # Rows are collected in a list and turned into a frame once, rather than growing a
        # DataFrame one row at a time.
        rows = []
        print('Running cue-light classifier statistics...')
        for video_name, cue_light_data_path in self.cue_light_paths.items():
            print(f'Analyzing {video_name}...')
            machine_results_path = self.machine_results_paths[video_name]
            ml_df = read_df(machine_results_path, self.file_type)
            cue_light_df = read_df(cue_light_data_path, self.file_type)
            check_valid_dataframe(df=ml_df, source=machine_results_path, valid_dtypes=Formats.NUMERIC_DTYPES.value, required_fields=self.clf_names)
            check_valid_dataframe(df=cue_light_df, source=cue_light_data_path, valid_dtypes=Formats.NUMERIC_DTYPES.value, required_fields=self.cue_light_names)
            data_df = pd.concat([ml_df, cue_light_df[self.cue_light_names]], axis=1)
            del cue_light_df, ml_df
            _, _, fps = self.read_video_info(video_name=video_name)
            # Window lengths are given in seconds; convert to whole frame counts for this video.
            self.prior_window_frames_cnt = int(self.pre_window * fps)
            self.post_window_frames_cnt = int(self.post_window * fps)
            cue_light_bouts = detect_bouts(data_df=data_df, target_lst=self.cue_light_names, fps=fps).reset_index(drop=True)
            for _, bout in cue_light_bouts.iterrows():
                start_frm, end_frm = int(bout['Start_frame']), int(bout['End_frame'])
                cue_frms = list(range(start_frm, end_frm + 1))
                # The N frames immediately before the bout, clipped at the first frame.
                pre_frms = list(range(max(0, start_frm - self.prior_window_frames_cnt), start_frm))
                # The N frames immediately after the bout, clipped at the last frame. The exclusive
                # stop is end_frm + 1 + N so the window holds N frames, mirroring the pre window.
                post_frms = list(range(end_frm + 1, min(end_frm + 1 + self.post_window_frames_cnt, len(data_df))))
                cue_df = data_df.loc[cue_frms, self.clf_names]
                pre_df = data_df.loc[pre_frms, self.clf_names]
                post_df = data_df.loc[post_frms, self.clf_names]
                for clf in self.clf_names:
                    cue_clf_present = round(cue_df[clf].sum() / fps, 4)
                    cue_clf_absent = round(bout['Bout_time'] - cue_clf_present, 4)
                    # NOTE(review): absent time is computed against the nominal window length, also
                    # when the window is clipped by the start/end of the video - confirm this is intended.
                    pre_clf_present = round(pre_df[clf].sum() / fps, 4)
                    pre_clf_absent = round(self.pre_window - pre_clf_present, 4)
                    post_clf_present = round(post_df[clf].sum() / fps, 4)
                    post_clf_absent = round(self.post_window - post_clf_present, 4)
                    rows.append([video_name, bout['Event'], clf,
                                 bout['Start_time'], bout['End Time'],
                                 start_frm, end_frm,
                                 cue_clf_present, cue_clf_absent,
                                 pre_clf_present, pre_clf_absent,
                                 post_clf_present, post_clf_absent])
        self.results = pd.DataFrame(rows, columns=columns)

    def save(self) -> None:
        """
        Sort the results and write them to ``self.save_path`` as CSV.

        Columns belonging to a window of length 0 are dropped before saving. Requires that
        :meth:`run` has populated ``self.results``.
        """
        self.results = self.results.sort_values(by=['VIDEO', 'CUE LIGHT', 'CUE LIGHT BOUT START TIME'], ascending=True)
        if self.post_window == 0:
            self.results = self.results.drop(self._post_window_columns(), axis=1)
        if self.pre_window == 0:
            self.results = self.results.drop(self._pre_window_columns(), axis=1)
        self.results.to_csv(self.save_path)
        self.timer.stop_timer()
        stdout_success(msg=f'Cue light classifier statistics saved at {self.save_path}', elapsed_time=self.timer.elapsed_time_str)
# test = CueLightClfAnalyzer(config_path=r"C:\troubleshooting\cue_light\t1\project_folder\project_config.ini",
#                            pre_window=1,
#                            post_window=1,
#                            cue_light_names=['MY_CUE_LIGHT'],
#                            clf_names=['freeze'])
# test.run()
# test.save()