__author__ = "Simon Nilsson; sronilsson@gmail.com"
import functools
import multiprocessing
import os
from typing import Dict, List, Union
import cv2
import numpy as np
import pandas as pd
from simba.mixins.config_reader import ConfigReader
from simba.mixins.statistics_mixin import Statistics
from simba.utils.checks import (
check_all_file_names_are_represented_in_video_log, check_if_dir_exists,
check_if_valid_img, check_int, check_valid_boolean, check_valid_lst)
from simba.utils.data import (detect_bouts, slice_roi_dict_from_attribute,
terminate_cpu_pool)
from simba.utils.enums import Defaults, Keys
from simba.utils.errors import NoROIDataError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (find_core_cnt,
find_files_of_filetypes_in_directory,
find_video_of_file, get_video_meta_data,
read_df, read_frm_of_video, write_df)
def slice_rectangle_from_img(img: np.ndarray,
                             top_left_x: int,
                             top_left_y: int,
                             bottom_right_x: int,
                             bottom_right_y: int):
    """
    Crop an axis-aligned rectangle out of an image.

    Each corner coordinate is clamped to ``[0, width]`` (x) or ``[0, height]`` (y)
    before slicing, so a rectangle that partly overhangs the image is trimmed to
    the visible part.

    :param np.ndarray img: Image to crop. Validated with ``check_if_valid_img``.
    :param int top_left_x: X coordinate of the top-left corner.
    :param int top_left_y: Y coordinate of the top-left corner.
    :param int bottom_right_x: X coordinate of the bottom-right corner.
    :param int bottom_right_y: Y coordinate of the bottom-right corner.
    :return: The slice ``img[top:bottom, left:right]`` of the input image.
    :raises NoROIDataError: If the clamped rectangle has zero width or zero height.
    """

    def _clamp(value, upper):
        # Restrict a coordinate to the closed range [0, upper].
        return max(0, min(value, upper))

    check_if_valid_img(data=img, source=f'{slice_rectangle_from_img.__name__} img', raise_error=True)
    img_h, img_w = img.shape[0], img.shape[1]
    left = _clamp(top_left_x, img_w)
    right = _clamp(bottom_right_x, img_w)
    top = _clamp(top_left_y, img_h)
    bottom = _clamp(bottom_right_y, img_h)
    zero_width = left >= right
    zero_height = top >= bottom
    if zero_width or zero_height:
        raise NoROIDataError(msg='The ROI has no area', source=slice_rectangle_from_img.__name__)
    return img[top:bottom, left:right]
def _mean_roi_intensity(roi_image: np.ndarray) -> float:
    """Return the mean pixel intensity of an ROI crop (2D greyscale or 3D multi-channel)."""
    if roi_image.ndim == 3:
        # Per-pixel Euclidean norm across channels; dividing by sqrt(3) maps a
        # three-channel pixel with equal channel values (v, v, v) back to v.
        return np.average(np.linalg.norm(roi_image, axis=2)) / np.sqrt(3)
    return np.average(roi_image)


def _get_intensity_scores_in_rois(frm_list: List[int],
                                  video_rois: dict,
                                  video_path: str,
                                  verbose: bool):
    """
    Compute the mean intensity inside every ROI for a contiguous run of video frames.

    Every frame index from ``frm_list[0]`` to ``frm_list[-1]`` (inclusive) is read
    from the video and scored for each rectangle, polygon and circle in ``video_rois``.
    Polygon and circle crops are masked so that pixels of the bounding box that fall
    outside the shape contribute zero to the average.

    :param List[int] frm_list: Frame indexes to process. Only the first and last entries are used as range bounds. An empty sequence yields an empty result.
    :param dict video_rois: Dictionary holding the rectangle, polygon and circle ROI dataframes under the ``Keys.ROI_RECTANGLES``, ``Keys.ROI_POLYGONS`` and ``Keys.ROI_CIRCLES`` keys.
    :param str video_path: Path to the video the frames are read from.
    :param bool verbose: If True, prints the index of each frame being analyzed.
    :return: Nested dictionary ``{roi_name: {frame_index: mean_intensity}}``.
    """
    results = {}
    # np.array_split hands out empty chunks when there are more workers than frames.
    if len(frm_list) == 0:
        return results
    rectangles = video_rois[Keys.ROI_RECTANGLES.value]
    polygons = video_rois[Keys.ROI_POLYGONS.value]
    circles = video_rois[Keys.ROI_CIRCLES.value]
    for frm_idx in range(frm_list[0], frm_list[-1] + 1):
        if verbose:
            print(f'Analyzing frame {frm_idx}...')
        img = read_frm_of_video(video_path=video_path, frame_index=frm_idx)

        for _, rectangle in rectangles.iterrows():
            roi_scores = results.setdefault(rectangle["Name"], {})
            roi_image = slice_rectangle_from_img(img=img,
                                                 top_left_x=rectangle["topLeftX"],
                                                 top_left_y=rectangle["topLeftY"],
                                                 bottom_right_x=rectangle["Bottom_right_X"],
                                                 bottom_right_y=rectangle["Bottom_right_Y"])
            roi_scores[frm_idx] = _mean_roi_intensity(roi_image)

        for _, polygon in polygons.iterrows():
            roi_scores = results.setdefault(polygon["Name"], {})
            vertices = polygon["vertices"]
            x, y, w, h = cv2.boundingRect(vertices)
            # Clamp the crop origin: a negative slice start would wrap around to the far image edge.
            x0, y0 = max(0, x), max(0, y)
            roi_img = img[y0: y + h, x0: x + w].copy()
            # Express the vertices relative to the crop origin, keeping their dtype for OpenCV.
            pts = vertices - np.array([x0, y0], dtype=vertices.dtype)
            mask = np.zeros(roi_img.shape[:2], np.uint8)
            cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
            # Average the masked crop directly. Adding a uint8 background image to it
            # would wrap around modulo 256 and turn dark pixels into bright ones.
            roi_image = cv2.bitwise_and(roi_img, roi_img, mask=mask)
            roi_scores[frm_idx] = _mean_roi_intensity(roi_image)

        for _, circle in circles.iterrows():
            roi_scores = results.setdefault(circle["Name"], {})
            c_x, c_y, r = int(circle["centerX"]), int(circle["centerY"]), int(circle["radius"])
            # Clamp the crop origin so circles overlapping the top/left border do not wrap around.
            x0, y0 = max(0, c_x - r), max(0, c_y - r)
            roi_image = img[y0: c_y + r, x0: c_x + r].copy()
            mask = np.zeros(roi_image.shape[:2], dtype=np.uint8)
            # Circle centre relative to the (possibly clipped) crop; equals (r, r) when nothing was clipped.
            cv2.circle(mask, (c_x - x0, c_y - y0), r, 255, thickness=-1)
            roi_image = cv2.bitwise_and(roi_image, roi_image, mask=mask)
            roi_scores[frm_idx] = _mean_roi_intensity(roi_image)
    return results
class CueLightAnalyzer(ConfigReader):
    """
    Analyze when cue lights are in ON and OFF states. Results are stored in the
    ``project_folder/csv/cue_lights`` cue lights directory.

    :param Union[str, os.PathLike] config_path: path to SimBA project config file in Configparser format
    :param Union[str, os.PathLike] data_dir: directory holding pose-estimation data. E.g., ``project_folder/csv/outlier_corrected_movement_location``
    :param List[str] cue_light_names: Names of cue light ROIs, as defined in the SimBA ROI interface.
    :param Union[str, os.PathLike] save_dir: Optional directory to save the results in. If None, the project cue lights data directory is used. Created if it does not exist.
    :param int core_cnt: Number of processes to use. ``-1``, or a value above the count reported by ``find_core_cnt``, uses the reported count.
    :param bool detailed_data: If True, a CSV listing every detected cue light ON bout is saved in the project logs directory.
    :param bool verbose: If True, prints progress messages.

    .. note::
       `Cue light tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/cue_light_tutorial.md>`__.

    References
    ----------
    .. [1] López-Moraga, A., Luyten, L., & Beckers, T. (2025). Generalization and extinction of platform-mediated avoidance in male and female rats.
           `Scientific Reports, 15, 9730 <https://doi.org/10.1038/s41598-025-94265-x>`_.

    :example:
    >>> cue_light_analyzer = CueLightAnalyzer(config_path='MyProjectConfig', data_dir='project_folder/csv/outlier_corrected_movement_location', cue_light_names=['Cue_light'])
    >>> cue_light_analyzer.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 data_dir: Union[str, os.PathLike],
                 cue_light_names: List[str],
                 save_dir: Union[str, os.PathLike] = None,
                 core_cnt: int = -1,
                 detailed_data: bool = False,
                 verbose: bool = True):

        ConfigReader.__init__(self, config_path=config_path, read_video_info=True)
        check_if_dir_exists(in_dir=data_dir, source=self.__class__.__name__, raise_error=True)
        check_valid_lst(data=cue_light_names, source=self.__class__.__name__, valid_dtypes=(str,), min_len=1, raise_error=True)
        check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0])
        check_valid_boolean(value=detailed_data, source=f'{self.__class__.__name__} detailed_data', raise_error=True)
        check_valid_boolean(value=verbose, source=f'{self.__class__.__name__} verbose', raise_error=True)
        self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=[f'.{self.file_type}'], raise_error=True, as_dict=True)
        self.read_roi_data()
        available_cores = find_core_cnt()[0]
        # Fall back to the reported core count when "all cores" (-1) or too many cores are requested.
        self.core_cnt = available_cores if core_cnt == -1 or core_cnt > available_cores else core_cnt
        self.cue_light_names, self.detailed_data, self.verbose = cue_light_names, detailed_data, verbose
        # One bout dataframe per (video, cue light); concatenated in run() when detailed_data is True.
        self.detailed_df_lst = []
        self.save_dir = self.cue_lights_data_dir if save_dir is None else save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        self.video_cnt = len(self.data_paths)

    def _get_kmeans(self,
                    intensities: Dict[str, Dict[int, float]]):
        """
        Split the per-frame intensities of each cue light into two clusters (OFF / ON).

        :param Dict[str, Dict[int, float]] intensities: ``{cue_light_name: {frame_index: intensity}}``.
        :return: ``{cue_light_name: {'labels': ..., 'intensities': ..., 'centroids': ...}}`` where the arrays are ordered by frame index and label ``1`` is the cluster with the larger centroid.
        """
        kmeans_timer = SimbaTimer(start=True)
        if self.verbose:
            print(f'Performing kmeans for {len(self.cue_light_names)} cue-lights for video {self.video_name}...')
        results = {}
        for cue_light_name, cue_light_data in intensities.items():
            # Worker results arrive chunk by chunk; sort by frame index before clustering.
            cue_light_data = dict(sorted(cue_light_data.items()))
            cue_light_data = np.array(list(cue_light_data.values())).astype(np.float64)
            centroids, labels, _ = Statistics().kmeans_1d(data=cue_light_data, k=2, max_iters=300, calc_medians=True)
            centroids = centroids.flatten()
            if centroids[0] > centroids[1]:
                # Make label 1 always denote the brighter cluster.
                labels = 1 - labels
                centroids = centroids[::-1]
            results[cue_light_name] = {'labels': labels, 'intensities': cue_light_data, 'centroids': centroids}
        kmeans_timer.stop_timer()
        if self.verbose:
            print(f'Kmeans for {len(self.cue_light_names)} cue-lights for video {self.video_name} complete (elapsed time: {kmeans_timer.elapsed_time}s)')
        return results

    def _append_light_data(self, data_df: pd.DataFrame, kmeans_data: dict):
        """
        Add one label column and one ``<name>_INTENSITY`` column per cue light to the data.

        :param pd.DataFrame data_df: Dataframe to add the columns to (modified in place).
        :param dict kmeans_data: Output of ``_get_kmeans``.
        :return: ``data_df`` with missing values replaced by 0.
        """
        for shape_name in self.cue_light_names:
            data_df[f'{shape_name}'] = kmeans_data[shape_name]['labels']
            data_df[f'{shape_name}_INTENSITY'] = kmeans_data[shape_name]['intensities']
        return data_df.fillna(0)

    def _remove_outlier_events(self, data_df: pd.DataFrame, time_threshold: float = 0.03):
        """
        Zero out very short ON bouts and record the remaining bouts for the detailed output.

        :param pd.DataFrame data_df: Dataframe holding one label column per cue light (modified in place).
        :param float time_threshold: Bouts whose ``Bout_time`` (as reported by ``detect_bouts``) is at or below this value are removed.
        :return: ``data_df`` with the short bouts set to 0.
        """
        for shape_name in self.cue_light_names:
            que_light_bouts = detect_bouts(data_df=data_df, target_lst=[f'{shape_name}'], fps=self.fps)
            que_light_negative_outliers = que_light_bouts[que_light_bouts["Bout_time"] <= time_threshold]
            for _, r in que_light_negative_outliers.iterrows():
                # Label-based slice that also covers the frame before and the frame after the bout.
                data_df.loc[r["Start_frame"] - 1: r["End_frame"] + 1, f'{shape_name}'] = 0
            # Re-detect bouts on the cleaned labels for the detailed report.
            detailed_df = detect_bouts(data_df=data_df, target_lst=[f'{shape_name}'], fps=self.fps)
            detailed_df = detailed_df.rename(columns={'Event': 'CUE LIGHT', 'Start_time': 'ONSET TIME', 'End Time': 'OFFSET TIME', 'Start_frame': 'ONSET FRAME', 'End_frame': 'OFFSET FRAME', 'Bout_time': 'ONSET DURATION'})
            detailed_df['VIDEO'] = self.video_name
            detailed_df = detailed_df[['VIDEO', 'CUE LIGHT', 'ONSET TIME', 'OFFSET TIME', 'ONSET FRAME', 'OFFSET FRAME', 'ONSET DURATION']]
            self.detailed_df_lst.append(detailed_df)
        return data_df

    def run(self):
        """
        Run the cue light analysis on every data file and save the results.

        For each data file: the matching video is located, ROI intensities are computed
        for every frame in parallel, the intensities are clustered into OFF / ON states,
        short ON bouts are removed, and the data with the appended cue light columns is
        written to ``save_dir``. If ``detailed_data`` is True, a CSV with all detected
        bouts is additionally written to the project logs directory.

        :raises NoROIDataError: If a video lacks one or more of the requested cue light ROIs.
        """
        print(f"Processing {len(self.cue_light_names)} cue light(s) in {len(self.data_paths)} data file(s)...")
        check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=list(self.data_paths.values()))
        for file_cnt, (file_name, file_path) in enumerate(self.data_paths.items()):
            # Started per video so the reported elapsed time does not accumulate across videos.
            video_timer = SimbaTimer(start=True)
            self.data_df = read_df(file_path, self.file_type)
            self.video_name = file_name
            self.save_path = os.path.join(self.save_dir, f"{file_name}.{self.file_type}")
            _, _, self.fps = self.read_video_info(video_name=file_name)
            video_roi_dict, roi_names, video_roi_cnt = slice_roi_dict_from_attribute(data=self.roi_dict, shape_names=self.cue_light_names, video_names=[file_name])
            missing_rois = [x for x in self.cue_light_names if x not in roi_names]
            if len(missing_rois) > 0:
                raise NoROIDataError(msg=f'The video {file_name} does not have cue light ROI(s) named {missing_rois}.', source=self.__class__.__name__)
            self.video_path = find_video_of_file(video_dir=self.video_dir, filename=file_name, raise_error=True)
            self.video_meta_data = get_video_meta_data(self.video_path)
            self.frm_lst = list(range(0, self.video_meta_data["frame_count"], 1))
            # np.array_split yields empty chunks when there are more cores than frames; drop them
            # because the worker indexes the first and last element of its chunk.
            self.frame_chunks = [chunk for chunk in np.array_split(self.frm_lst, self.core_cnt) if len(chunk) > 0]
            self.intensities = {}
            print(f'Getting light intensities for {video_roi_cnt} cue light in video {file_name}... (frame count: {self.video_meta_data["frame_count"]}, video: {file_cnt+1}/{self.video_cnt})')
            with multiprocessing.Pool(self.core_cnt, maxtasksperchild=Defaults.MAXIMUM_MAX_TASK_PER_CHILD.value) as pool:
                constants = functools.partial(_get_intensity_scores_in_rois,
                                              video_rois=video_roi_dict,
                                              video_path=self.video_path,
                                              verbose=self.verbose)
                for cnt, result in enumerate(pool.imap(constants, self.frame_chunks, chunksize=self.multiprocess_chunksize)):
                    # Merge the per-chunk {roi_name: {frame_index: intensity}} dictionaries.
                    for key, subdict in result.items():
                        if key in self.intensities:
                            self.intensities[key].update(subdict)
                        else:
                            self.intensities[key] = subdict
                    if self.verbose:
                        print(f'Batch {cnt + 1} complete...')
                terminate_cpu_pool(pool=pool, force=False)
            kmeans = self._get_kmeans(intensities=self.intensities)
            self.data_df = self._append_light_data(data_df=self.data_df, kmeans_data=kmeans)
            self.data_df = self._remove_outlier_events(data_df=self.data_df)
            write_df(self.data_df, self.file_type, self.save_path)
            video_timer.stop_timer()
            print(f'Cue-light data video {file_name} complete. Saved at {self.save_path} (elapsed time: {video_timer.elapsed_time_str}s).')
        if self.detailed_data:
            details_save_path = os.path.join(self.logs_path, f'cue_light_details_{self.datetime}.csv')
            detailed_df = pd.concat(self.detailed_df_lst, axis=0).reset_index(drop=True)
            detailed_df = detailed_df.sort_values(by=['VIDEO', 'CUE LIGHT', 'ONSET TIME'], ascending=True)
            detailed_df.to_csv(details_save_path)
            print(f'Detailed cue light data saved at {details_save_path}...')
        self.timer.stop_timer()
        stdout_success(msg=f"Analysed {self.video_cnt} files. Data stored in {self.save_dir}", elapsed_time=self.timer.elapsed_time)
# if __name__ == "__main__":
# test = CueLightAnalyzer(config_path=r"C:\troubleshooting\cue_light\t1\project_folder\project_config.ini",
# data_dir=r'C:\troubleshooting\cue_light\t1\project_folder\csv\outlier_corrected_movement_location',
# cue_light_names=['cl', 'cl2'],
# save_dir=r'C:\troubleshooting\cue_light\t1\project_folder\csv\cue_lights',
# core_cnt=18,
# detailed_data=True)
# test.run()
# test = CueLightAnalyzer(config_path='/Users/simon/Desktop/troubleshooting/light_analyzer/project_folder/project_config.ini',
# data_dir='/Users/simon/Desktop/troubleshooting/light_analyzer/project_folder/csv/outlier_corrected_movement_location',
# cue_light_names=['Cue_light'])
# test.run()