Source code for simba.roi_tools.ROI_directing_analyzer

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import os
from typing import Optional, Union

import numpy as np
import pandas as pd
from numba import jit, prange

from simba.mixins.config_reader import ConfigReader
from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
from simba.utils.checks import (
    check_all_file_names_are_represented_in_video_log,
    check_file_exist_and_readable, check_instance, check_str,
    check_valid_boolean, check_valid_dataframe)
from simba.utils.data import slice_roi_dict_for_video
from simba.utils.enums import ROI_SETTINGS, Formats
from simba.utils.errors import (InvalidInputError, NoDataError,
                                ROICoordinatesNotFoundError)
from simba.utils.lookups import create_directionality_cords
from simba.utils.printing import (SimbaTimer, stdout_information,
                                  stdout_success, stdout_warning)
from simba.utils.read_write import (get_fn_ext, read_data_paths, read_df,
                                    read_roi_data, read_video_info,
                                    seconds_to_timestamp)

# Body-part role keys looked up in each animal's directionality body-part dictionary.
EAR_LEFT = "Ear_left"
EAR_RIGHT = "Ear_right"
NOSE = "Nose"

# Keys selecting the x / y column name of a body-part role.
X_BPS = "X_bps"
Y_BPS = "Y_bps"

# Column names read from each ROI definition row.
CENTER_X = 'Center_X'
CENTER_Y = 'Center_Y'
NAME = 'Name'

# Measure labels written to the MEASURE column of the aggregate statistics table.
FIRST_DIRECTING_TIME, LAST_DIRECTING_TIME = 'FIRST DIRECTING TIME', 'LAST DIRECTING TIME'
DIRECTING_BOUTS, DIRECTING_TIME = 'DIRECTING BOUTS', 'TOTAL DIRECTING TIME'

class DirectingROIAnalyzer(ConfigReader, FeatureExtractionMixin):
    """
    Compute aggregate statistics for animals directing towards ROIs.

    :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
    :param Optional[Union[str, os.PathLike]] data_path: Path to folder or file holding the data used to calculate ROI aggregate statistics. If None, then defaults to the `project_folder/csv/outlier_corrected_movement_location` directory of the SimBA project. Default: None.
    :param Optional[Union[str, os.PathLike]] roi_coordinates_path: Path to a file holding the ROI definitions. If None, the ROI definitions file of the SimBA project is used. Default: None.
    :param bool detailed_table: If True, :meth:`save` writes the frame-wise directing table (``ROI_directionality_summary_<datetime>.csv``) to the project logs directory. Default: True.
    :param bool agg_stats: If True, :meth:`run` computes first directing time, last directing time, total directing time and number of directing bouts per video / animal / ROI, and :meth:`save` writes them (``ROI_directionality_aggregate_stats_<datetime>.csv``). Default: True.
    :param bool transpose_agg_stats: If True, the aggregate statistics are pivoted so that each measure becomes a column (one row per video / animal / ROI / ROI type). Default: False.
    :param Optional[str] left_ear_name: Name of the left-ear body-part. Must be passed together with ``right_ear_name`` and ``nose_name``, or all three left as None so SimBA detects them automatically. Default: None.
    :param Optional[str] right_ear_name: Name of the right-ear body-part. See ``left_ear_name``. Default: None.
    :param Optional[str] nose_name: Name of the nose body-part. See ``left_ear_name``. Default: None.
    :raises InvalidInputError: If only some of the body-part names are given, or if directionality body-parts cannot be detected.
    :raises ROICoordinatesNotFoundError: If the ROI definitions file does not exist.

    .. note::
       `ROI tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/ROI_tutorial_new.md>`__.
       `Example expected output file <https://github.com/sgoldenlab/simba/blob/master/docs/ROI_tutorial_new.md>`__.

    :example:
    >>> test = DirectingROIAnalyzer(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini')
    >>> test.run()
    >>> test.save()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 data_path: Optional[Union[str, os.PathLike]] = None,
                 roi_coordinates_path: Optional[Union[str, os.PathLike]] = None,
                 detailed_table: bool = True,
                 agg_stats: bool = True,
                 transpose_agg_stats: bool = False,
                 left_ear_name: Optional[str] = None,
                 right_ear_name: Optional[str] = None,
                 nose_name: Optional[str] = None):

        check_file_exist_and_readable(file_path=config_path)
        if data_path is not None:
            # The signature advertises os.PathLike support: normalise to str before the str-only validation.
            if isinstance(data_path, os.PathLike):
                data_path = os.fspath(data_path)
            check_instance(source=f'{self.__class__.__name__} data_path', instance=data_path, accepted_types=(str,))
        if roi_coordinates_path is not None:
            check_file_exist_and_readable(file_path=roi_coordinates_path)
        check_valid_boolean(value=[detailed_table, agg_stats, transpose_agg_stats], source=f'{self.__class__.__name__} detailed_table/agg_stats/transpose_agg_stats')
        for bp_name, bp_value in zip(['left_ear_name', 'right_ear_name', 'nose_name'], [left_ear_name, right_ear_name, nose_name]):
            if bp_value is not None:
                check_str(name=bp_name, value=bp_value)
        ConfigReader.__init__(self, config_path=config_path)
        FeatureExtractionMixin.__init__(self, config_path=config_path)
        if roi_coordinates_path is not None:
            # Previously the default project path was read and the result overwrote self.roi_coordinates_path,
            # so the user file was ignored and self.roi_dict was never populated. ConfigReader.read_roi_data()
            # is expected to load from self.roi_coordinates_path (the attribute validated below), so point it at the user file.
            self.roi_coordinates_path = roi_coordinates_path
        if not os.path.isfile(self.roi_coordinates_path):
            raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path)
        self.read_roi_data()
        self.data_paths = read_data_paths(path=data_path, default=self.outlier_corrected_paths, default_name=self.outlier_corrected_dir, file_type=self.file_type)
        self.transpose_agg_stats, self.detailed_table, self.compute_agg_stats = transpose_agg_stats, detailed_table, agg_stats
        passed_bps = [left_ear_name, right_ear_name, nose_name]
        if sum(p is None for p in passed_bps) not in (0, len(passed_bps)):
            raise InvalidInputError(msg="left_ear_name, right_ear_name, and nose_name must either all be None or all be provided as strings", source=self.__class__.__name__)
        if isinstance(left_ear_name, str) and isinstance(right_ear_name, str) and isinstance(nose_name, str):
            self.direct_bp_dict = create_directionality_cords(bp_dict=self.animal_bp_dict, left_ear_name=left_ear_name, nose_name=nose_name, right_ear_name=right_ear_name)
        else:
            if not self.check_directionality_viable()[0]:
                raise InvalidInputError(msg="You are not tracking the necessary body-parts to calculate direction. Either (i) pass the body-parts names, or (ii) name the body-parts properly so SimBA can automatically detect the left ear, right ear, and nose body-part names.", source=self.__class__.__name__)
            self.direct_bp_dict = self.check_directionality_cords()

    def __format_direction_data(self,
                                direction_data: np.ndarray,
                                nose_arr: np.ndarray,
                                roi_center: np.ndarray,
                                animal_name: str,
                                shape_name: str) -> pd.DataFrame:
        """
        Build the per-frame table of frames in which the animal directs towards one ROI.

        The "eye" point is the midpoint between the point in columns 1-2 of ``direction_data`` and the nose.
        After prepending the eye x/y and deleting columns 2, 3 and 4, exactly three columns must remain,
        so ``direction_data`` is expected to have four columns; its last column (index 3) becomes ``Directing_BOOL``.

        :param np.ndarray direction_data: Output of ``jitted_line_crosses_to_static_targets`` (column layout as described above; TODO confirm semantics).
        :param np.ndarray nose_arr: Nose x/y per frame.
        :param np.ndarray roi_center: ROI center x/y.
        :param str animal_name: Written to the ``Animal`` column.
        :param str shape_name: Written to the ``ROI`` column.
        :return: Rows where ``Directing_BOOL == 1``; ``Frame`` is the row position in the input arrays.
        """
        x_min = np.minimum(direction_data[:, 1], nose_arr[:, 0])
        y_min = np.minimum(direction_data[:, 2], nose_arr[:, 1])
        delta_x = abs((direction_data[:, 1] - nose_arr[:, 0]) / 2)
        delta_y = abs((direction_data[:, 2] - nose_arr[:, 1]) / 2)
        x_middle, y_middle = np.add(x_min, delta_x), np.add(y_min, delta_y)
        direction_data = np.concatenate((y_middle.reshape(-1, 1), direction_data), axis=1)
        direction_data = np.concatenate((x_middle.reshape(-1, 1), direction_data), axis=1)
        direction_data = np.delete(direction_data, [2, 3, 4], 1)
        bp_data = pd.DataFrame(direction_data, columns=["Eye_x", "Eye_y", "Directing_BOOL"])
        bp_data["ROI_x"] = roi_center[0]
        bp_data["ROI_y"] = roi_center[1]
        bp_data = bp_data[["Eye_x", "Eye_y", "ROI_x", "ROI_y", "Directing_BOOL"]]
        bp_data.insert(loc=0, column="ROI", value=shape_name)
        bp_data.insert(loc=0, column="Animal", value=animal_name)
        bp_data.insert(loc=0, column="Video", value=self.video_name)
        bp_data = bp_data.reset_index().rename(columns={"index": "Frame"})
        bp_data = bp_data[bp_data["Directing_BOOL"] == 1].reset_index(drop=True)
        return bp_data

    @staticmethod
    @jit(nopython=True, fastmath=True)
    def ccw(roi_lines: np.array, eye_lines: np.array, shape_type: str):
        """
        For each eye->ROI-center line, pick the ROI edge associated with it.

        :param np.array roi_lines: (M, 4) edges as [x1, y1, x2, y2].
        :param np.array eye_lines: (N, 4) rows as [eye_x, eye_y, roi_x, roi_y].
        :param str shape_type: ``"Circle"`` selects the circle heuristic; any other value uses segment intersection.
        :return: (N, 4) integer array of chosen edges; rows stay ``-1`` when no edge qualifies.
        """
        def calc(A, B, C):
            # Orientation test: True when A, B, C are in counter-clockwise order.
            return (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])

        results = np.full((eye_lines.shape[0], 4), -1)
        for i in prange(eye_lines.shape[0]):
            eye, roi = eye_lines[i][0:2], eye_lines[i][2:4]
            min_distance = np.inf
            if shape_type == "Circle":
                # Find the diameter with the endpoint closest to the eye and return the *other* diameter.
                reversed_roi_lines = roi_lines[::-1]
                for j in prange(roi_lines.shape[0]):
                    dist_1 = np.sqrt((eye[0] - roi_lines[j][0]) ** 2 + (eye[1] - roi_lines[j][1]) ** 2)
                    dist_2 = np.sqrt((eye[0] - roi_lines[j][2]) ** 2 + (eye[1] - roi_lines[j][3]) ** 2)
                    if (dist_1 < min_distance) or (dist_2 < min_distance):
                        min_distance = min(dist_1, dist_2)
                        results[i] = reversed_roi_lines[j]
            else:
                for j in prange(roi_lines.shape[0]):
                    line_a, line_b = roi_lines[j][0:2], roi_lines[j][2:4]
                    # Edge midpoint (previously `a + b // 2` due to operator precedence).
                    center_x = (line_a[0] + line_b[0]) / 2
                    center_y = (line_a[1] + line_b[1]) / 2
                    # Segments intersect only if EACH segment's endpoints straddle the other segment's line;
                    # `or` would also accept edges crossed only by the extended line.
                    if (calc(eye, line_a, line_b) != calc(roi, line_a, line_b)) and (calc(eye, roi, line_a) != calc(eye, roi, line_b)):
                        distance = np.sqrt((eye[0] - center_x) ** 2 + (eye[1] - center_y) ** 2)
                        if distance < min_distance:
                            results[i] = roi_lines[j]
                            min_distance = distance
        return results

    def __find_roi_intersections(self, bp_data: pd.DataFrame, shape_info: pd.Series):
        """
        Build the edge list of one ROI and resolve, per directing frame, the edge the eye->ROI line is associated with.

        :param pd.DataFrame bp_data: Output of ``__format_direction_data``.
        :param pd.Series shape_info: One ROI definition row.
        :raises InvalidInputError: If the shape type is not rectangle, polygon or circle.
        """
        eye_lines = bp_data[["Eye_x", "Eye_y", "ROI_x", "ROI_y"]].values.astype(int)
        shape_type = str(shape_info["Shape_type"]).lower()
        if shape_type == ROI_SETTINGS.RECTANGLE.value:
            top_left_x, top_left_y = (shape_info["topLeftX"], shape_info["topLeftY"])
            bottom_right_x, bottom_right_y = (shape_info["Bottom_right_X"], shape_info["Bottom_right_Y"])
            top_right_x, top_right_y = top_left_x + shape_info["width"], top_left_y
            bottom_left_x, bottom_left_y = (bottom_right_x - shape_info["width"], bottom_right_y)
            roi_lines = np.array([[top_left_x, top_left_y, bottom_left_x, bottom_left_y],
                                  [bottom_left_x, bottom_left_y, bottom_right_x, bottom_right_y],
                                  [bottom_right_x, bottom_right_y, top_right_x, top_right_y],
                                  [top_right_x, top_right_y, top_left_x, top_left_y]])
        elif shape_type == ROI_SETTINGS.POLYGON.value:
            roi_lines = np.full((shape_info["vertices"].shape[0], 4), np.nan)
            # Closing edge between the last and the first vertex.
            roi_lines[-1] = np.hstack((shape_info["vertices"][0], shape_info["vertices"][-1]))
            for i in range(shape_info["vertices"].shape[0] - 1):
                roi_lines[i] = np.hstack((shape_info["vertices"][i], shape_info["vertices"][i + 1]))
        elif shape_type == ROI_SETTINGS.CIRCLE.value:
            center = shape_info[["centerX", "centerY"]].values.astype(int)
            roi_lines = np.full((2, 4), np.nan)
            # Vertical and horizontal diameters.
            roi_lines[0] = np.array([center[0], center[1] - shape_info["radius"], center[0], center[1] + shape_info["radius"]])
            roi_lines[1] = np.array([center[0] - shape_info["radius"], center[1], center[0] + shape_info["radius"], center[1]])
        else:
            raise InvalidInputError(msg=f'Unsupported ROI shape type: {shape_info["Shape_type"]}', source=self.__class__.__name__)
        # ccw() compares against the literal "Circle"; pass that exact spelling regardless of stored casing.
        ccw_shape_type = "Circle" if shape_type == ROI_SETTINGS.CIRCLE.value else shape_type
        return self.ccw(roi_lines=roi_lines, eye_lines=eye_lines, shape_type=ccw_shape_type)

    def run(self):
        """
        Compute directing data for every data file, animal and ROI.

        Populates ``self.results_df`` (frame-wise directing rows) and ``self.agg_stats_df`` (aggregate measures,
        pivoted when ``transpose_agg_stats`` is True). Videos without ROIs are skipped with a warning.

        :raises NoDataError: If no results were produced for any data file.
        """
        self.results, self.agg_stats_df = [], pd.DataFrame(columns=['VIDEO', 'ANIMAL', 'ROI', 'ROI_TYPE', 'MEASURE', 'VALUE'])
        check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
        for file_path in self.data_paths:
            _, self.video_name, _ = get_fn_ext(file_path)
            _, _, fps = read_video_info(video_name=self.video_name, video_info_df=self.video_info_df, raise_error=True)
            video_timer = SimbaTimer(start=True)
            stdout_information(msg=f"Analyzing ROI directionality in video {self.video_name}...")
            data_df = read_df(file_path=file_path, file_type=self.file_type)
            video_roi_dict, shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name)
            if len(shape_names) == 0:
                stdout_warning(msg=f'Skipping video {self.video_name}: No ROIs drawn for video {self.video_name} detected.')
                continue
            for animal_name, bps in self.direct_bp_dict.items():
                required_bps = list(set(v for bp in bps.values() for v in bp.values()))
                check_valid_dataframe(df=data_df, source=file_path, valid_dtypes=Formats.NUMERIC_DTYPES.value, required_fields=required_bps)
                ear_left_arr = data_df[[bps[EAR_LEFT][X_BPS], bps[EAR_LEFT][Y_BPS]]].values
                ear_right_arr = data_df[[bps[EAR_RIGHT][X_BPS], bps[EAR_RIGHT][Y_BPS]]].values
                nose_arr = data_df[[bps[NOSE][X_BPS], bps[NOSE][Y_BPS]]].values.astype(np.int64)
                for roi_type, roi_type_data in video_roi_dict.items():
                    for _, row in roi_type_data.iterrows():
                        roi_center, roi_name = np.array([row[CENTER_X], row[CENTER_Y]]), row[NAME]
                        direction_data = FeatureExtractionMixin.jitted_line_crosses_to_static_targets(left_ear_array=ear_left_arr, right_ear_array=ear_right_arr, nose_array=nose_arr, target_array=roi_center)
                        bp_data = self.__format_direction_data(direction_data=direction_data, nose_arr=nose_arr, roi_center=roi_center, animal_name=animal_name, shape_name=roi_name)
                        eye_roi_intersections = pd.DataFrame(self.__find_roi_intersections(bp_data=bp_data, shape_info=row), columns=["ROI_edge_1_x", "ROI_edge_1_y", "ROI_edge_2_x", "ROI_edge_2_y"])
                        self.results.append(pd.concat([bp_data, eye_roi_intersections], axis=1))
                        if len(bp_data) > 0 and self.compute_agg_stats:
                            first_direct_time = seconds_to_timestamp(seconds=bp_data['Frame'].min() / fps)
                            last_direct_time = seconds_to_timestamp(seconds=bp_data['Frame'].max() / fps)
                            directing_time = seconds_to_timestamp(seconds=len(bp_data) / fps)
                            # A new bout starts wherever consecutive directing frames are not adjacent.
                            direction_bouts = int((bp_data['Frame'].diff() > 1).sum() + 1)
                            for measure, value in [(FIRST_DIRECTING_TIME, first_direct_time), (LAST_DIRECTING_TIME, last_direct_time), (DIRECTING_TIME, directing_time), (DIRECTING_BOUTS, direction_bouts)]:
                                self.agg_stats_df.loc[len(self.agg_stats_df)] = [self.video_name, animal_name, roi_name, roi_type, measure, value]
            video_timer.stop_timer()
            stdout_information(msg=f"ROI directionality analyzed in video {self.video_name}... (elapsed time: {video_timer.elapsed_time_str}s)")
        if len(self.results) == 0:
            raise NoDataError(msg=f'No ROI DATA exists for data files {self.data_paths}', source=self.__class__.__name__)
        self.results_df = pd.concat(self.results, axis=0)
        if self.transpose_agg_stats and len(self.agg_stats_df) > 0:
            self.agg_stats_df = self.agg_stats_df.pivot_table(index=['VIDEO', 'ANIMAL', 'ROI', 'ROI_TYPE'], columns='MEASURE', values='VALUE', aggfunc='first').reset_index()
            self.agg_stats_df.columns.name = None

    def save(self):
        """
        Write the results of :meth:`run` to the project logs directory.

        :raises InvalidInputError: If :meth:`run` has not produced ``results_df`` yet.
        """
        if not hasattr(self, "results_df"):
            raise InvalidInputError(msg="Run the ROI direction analyzer before saving", source=self.__class__.__name__)
        if self.detailed_table:
            path = os.path.join(self.logs_path, f"ROI_directionality_summary_{self.datetime}.csv")
            self.results_df.to_csv(path)
            stdout_success(msg=f"Detailed ROI directionality data saved in {path}", source=self.__class__.__name__)
        if self.compute_agg_stats:
            path = os.path.join(self.logs_path, f"ROI_directionality_aggregate_stats_{self.datetime}.csv")
            self.agg_stats_df.to_csv(path)
            stdout_success(msg=f"Detailed ROI aggregate statistics saved in {path}", source=self.__class__.__name__)
# test = DirectingROIAnalyzer(config_path=r"E:\troubleshooting\mitra_emergence_hour\project_folder\project_config.ini", # nose_name='left_ear', # left_ear_name='nose', # right_ear_name='tail_base') # test.run() # test.save()