__author__ = "Simon Nilsson; sronilsson@gmail.com"
import functools
import multiprocessing
import os
from typing import List, Union
import cv2
import numpy as np
import pandas as pd
from simba.mixins.config_reader import ConfigReader
from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.checks import (check_file_exist_and_readable, check_int,
check_valid_boolean, check_valid_dataframe,
check_valid_lst)
from simba.utils.data import (create_color_palettes, detect_bouts,
slice_roi_dict_from_attribute,
terminate_cpu_pool)
from simba.utils.enums import Defaults, Formats, TextOptions
from simba.utils.errors import NoROIDataError, NoSpecifiedOutputError
from simba.utils.printing import stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
create_directory, find_core_cnt,
get_fn_ext, get_video_meta_data, read_df,
read_frm_of_video)
def _plot_cue_light_data(frm_idxs: list,
video_setting: bool,
frame_setting: bool,
show_pose: bool,
data_df: pd.DataFrame,
bp_names: list,
font_size: int,
x_shift: int,
y_shift: int,
frames_save_dir: str,
video_save_dir: str,
circle_size: int,
roi_dict: dict,
video_path: str,
verbose: bool):
batch_id, frame_rng = frm_idxs[0], frm_idxs[1]
start_frm, end_frm, current_frm = frame_rng[0], frame_rng[-1], frame_rng[0]
video_writer = None
video_meta_data = get_video_meta_data(video_path=video_path)
clrs = create_color_palettes(no_animals=1, map_size=len(bp_names)+1, cmaps=['Set3'])[0]
if video_setting:
fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
video_save_path = os.path.join(video_save_dir, f"{batch_id}.mp4")
video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data['fps'], (int(video_meta_data['width']*2), video_meta_data['height']))
while current_frm <= end_frm:
img = read_frm_of_video(video_path, frame_index=current_frm)
img = cv2.copyMakeBorder(img, 0, 0, 0, int(video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
if show_pose:
for bp_cnt, bp_name in enumerate(bp_names):
col_names = [f'{bp_name}_x', f'{bp_name}_y']
bp_data = data_df.loc[current_frm, col_names].values.astype(np.int32)
img = cv2.circle(img, tuple(bp_data), circle_size, clrs[bp_cnt], -1)
img = PlottingMixin().roi_dict_onto_img(img=img, roi_dict=roi_dict)
y_shift_counts = 1
for cue_light_type, cue_light_type_data in roi_dict.items():
for _, cue_light_data in cue_light_type_data.iterrows():
color, name = cue_light_data['Color BGR'], cue_light_data['Name']
light_status = 'ON' if data_df.loc[current_frm, name] == 1 else 'OFF'
light_color = (0, 255, 255) if light_status == 'ON' else color
que_light_bouts = detect_bouts(data_df=data_df.loc[0:current_frm], target_lst=[name], fps=video_meta_data['fps'])
que_light_bouts_cnt = len(que_light_bouts)
que_light_bouts_duration = round(que_light_bouts['Bout_time'].sum(), 2)
off_duration = round((((current_frm + 1)/ video_meta_data['fps'])) - que_light_bouts_duration, 2)
img = PlottingMixin().put_text(img=img, text=f"{name} STATUS:", pos=(int(video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), int(y_shift*y_shift_counts)), font_size=font_size, font_thickness=2, text_color_bg=(0, 0, 0), text_color=color)
img = PlottingMixin().put_text(img=img, text=light_status, pos=(int(video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value + x_shift), int(y_shift*y_shift_counts)), font_size=font_size, font_thickness=2, text_color_bg=(0, 0, 0), text_color=light_color)
y_shift_counts += 1
img = PlottingMixin().put_text(img=img, text=f"{name} ONSET COUNTS:", pos=(int(video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), int(y_shift*y_shift_counts)), font_size=font_size, font_thickness=2, text_color_bg=(0, 0, 0), text_color=color)
img = PlottingMixin().put_text(img=img, text=str(que_light_bouts_cnt), pos=(int(video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value + x_shift), int(y_shift*y_shift_counts)), font_size=font_size, font_thickness=2, text_color_bg=(0, 0, 0), text_color=color)
y_shift_counts += 1
img = PlottingMixin().put_text(img=img, text=f"{name} TIME ON (S):", pos=(int(video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), int(y_shift*y_shift_counts)), font_size=font_size, font_thickness=2, text_color_bg=(0, 0, 0), text_color=color)
img = PlottingMixin().put_text(img=img, text=str(que_light_bouts_duration), pos=(int(video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value + x_shift), int(y_shift*y_shift_counts)), font_size=font_size, font_thickness=2, text_color_bg=(0, 0, 0), text_color=color)
y_shift_counts += 1
img = PlottingMixin().put_text(img=img, text=f"{name} TIME OFF (S):", pos=(int(video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), int(y_shift*y_shift_counts)), font_size=font_size, font_thickness=2, text_color_bg=(0, 0, 0), text_color=color)
img = PlottingMixin().put_text(img=img, text=str(off_duration), pos=(int(video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value + x_shift), int(y_shift*y_shift_counts)), font_size=font_size, font_thickness=2, text_color_bg=(0, 0, 0), text_color=color)
y_shift_counts += 1
if video_setting:
video_writer.write(np.uint8(img))
if frame_setting:
frame_save_path = os.path.join(frames_save_dir,f"{current_frm}.png")
cv2.imwrite(frame_save_path, current_frm)
if verbose:
print(f"Cue light frame complete: {current_frm} / {video_meta_data['frame_count']}. Video: {video_meta_data['video_name']} ")
current_frm += 1
if video_setting:
video_writer.release()
return batch_id
[docs]class CueLightVisualizer(ConfigReader):
"""
Visualize SimBA computed cue-light ON and OFF states and the aggregate statistics of ON and OFF
states.
:param str config_path: path to SimBA project config file in Configparser format.
:param List[str] cue_light_names: Names of cue lights, as defined in the SimBA ROI interface.
:param str video_path: Path to video which user wants to create visualizations of cue light states and aggregate statistics for.
:param bool frame_setting: If True, creates individual frames in png format. Defaults to False.
:param bool video_setting: If True, creates compressed videos in mp4 format. Defaults to True.
.. notes:
`Cue light tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/cue_light_tutorial.md>`__.
.. video:: _static/img/CueLightVisualizer.webm
:width: 800
:autoplay:
:loop:
:muted:
:align: center
:examples:
>>> cue_light_visualizer = CueLightVisualizer(config_path='SimBAConfig', cue_light_names=['Cue_light'], video_path='VideoPath', video_setting=True, frame_setting=False)
>>> cue_light_visualizer.run()
"""
def __init__(self,
config_path: Union[str, os.PathLike],
cue_light_names: List[str],
video_path: Union[str, os.PathLike],
data_path: Union[str, os.PathLike],
frame_setting: bool = False,
video_setting: bool = True,
core_cnt: int = -1,
show_pose: bool = True,
verbose: bool = True):
ConfigReader.__init__(self, config_path=config_path)
check_valid_boolean(value=[frame_setting], source=f'{self.__class__.__name__} frame_setting', raise_error=True)
check_valid_boolean(value=[video_setting], source=f'{self.__class__.__name__} video_setting', raise_error=True)
check_valid_boolean(value=[show_pose], source=f'{self.__class__.__name__} show_pose', raise_error=True)
check_valid_boolean(value=[verbose], source=f'{self.__class__.__name__} verbose', raise_error=True)
check_valid_lst(data=cue_light_names, source=self.__class__.__name__, valid_dtypes=(str,), min_len=1, raise_error=True)
check_file_exist_and_readable(file_path=video_path)
check_file_exist_and_readable(file_path=data_path)
check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0])
self.video_meta_data = get_video_meta_data(video_path)
_, self.video_name, _ = get_fn_ext(filepath=data_path)
if (not frame_setting) and (not video_setting):
raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please choose to select either videos, frames, or both frames and videos.")
self.cue_light_names, self.video_path, self.data_path = cue_light_names, video_path, data_path
self.data_df = read_df(self.data_path, self.file_type)
self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
self.font_size, self.x_shift, self.y_shift = PlottingMixin().get_optimal_font_scales(text='ONE LONG ARSE STRING FOR YA', accepted_px_height=int(self.video_meta_data['height']/2), accepted_px_width=int(self.video_meta_data['width']/2))
self.circle_size = PlottingMixin().get_optimal_circle_size(frame_size=(self.video_meta_data['width'], self.video_meta_data['height']), circle_frame_ratio=60)
self.read_roi_data()
self.video_setting, self.frame_setting, self.data_path, self.show_pose, self.verbose = video_setting, frame_setting, data_path, show_pose, verbose
self.video_save_dir = os.path.join(self.frames_output_dir, 'cue_lights')
self.frames_save_dir = os.path.join(self.frames_output_dir, 'cue_lights')
self.video_roi_dict, roi_names, video_roi_cnt = slice_roi_dict_from_attribute(data=self.roi_dict, shape_names=self.cue_light_names, video_names=[self.video_name])
missing_rois = [x for x in self.cue_light_names if x not in roi_names]
if len(missing_rois) > 0:
raise NoROIDataError(msg=f'The video {self.video_name} does not have cue light ROI(s) named {missing_rois}.', source=self.__class__.__name__)
if show_pose:
check_valid_dataframe(df=self.data_df, source=f'{self.__class__.__name__} {data_path}', valid_dtypes=Formats.NUMERIC_DTYPES.value, required_fields=self.bp_col_names)
def run(self):
print(f"Creating video for {len(self.cue_light_names)} cue light(s) in in {self.video_name}...")
frames_dir, video_temp_dir = None, None
if self.frame_setting:
frames_dir = os.path.join(self.frames_save_dir, self.video_name)
create_directory(paths=frames_dir, overwrite=True)
if self.video_setting:
self.save_video_path = os.path.join(self.video_save_dir, f"{self.video_name}.mp4")
video_temp_dir = os.path.join(self.video_save_dir, 'temp')
create_directory(paths=video_temp_dir, overwrite=True)
self.frm_lst = list(range(0, self.video_meta_data["frame_count"], 1))
self.frame_chunks = np.array_split(self.frm_lst, self.core_cnt)
self.frame_chunks = [(x, j) for x, j in enumerate(self.frame_chunks)]
with multiprocessing.Pool(self.core_cnt, maxtasksperchild=Defaults.MAXIMUM_MAX_TASK_PER_CHILD.value) as pool:
constants = functools.partial(_plot_cue_light_data,
frame_setting=self.frame_setting,
video_setting=self.video_setting,
show_pose=self.show_pose,
data_df=self.data_df,
frames_save_dir=frames_dir,
video_save_dir=video_temp_dir,
circle_size=self.circle_size,
font_size=self.font_size,
x_shift=self.x_shift,
y_shift=self.y_shift,
roi_dict=self.video_roi_dict,
video_path=self.video_path,
bp_names=self.body_parts_lst,
verbose=self.verbose)
for cnt, result in enumerate(pool.imap(constants, self.frame_chunks, chunksize=self.multiprocess_chunksize)):
if self.verbose:
print(f'Batch {int(result+1/self.core_cnt)} complete...')
terminate_cpu_pool(pool=pool, force=False)
self.timer.stop_timer()
if self.video_setting:
print(f"Joining {self.video_name} multiprocessed video...")
concatenate_videos_in_folder(in_folder=video_temp_dir, save_path=self.save_video_path)
stdout_success(msg=f"Cue light video visualization for video {self.video_name} saved at {self.save_video_path}", elapsed_time=self.timer.elapsed_time_str)
if self.frame_setting:
stdout_success(msg=f"Cue light frame visualization for video {self.video_name} saved at {frames_dir}", elapsed_time=self.timer.elapsed_time_str)
# if __name__ == "__main__":
# test = CueLightVisualizer(config_path=r"C:\troubleshooting\cue_light\t1\project_folder\project_config.ini",
# cue_light_names=['cl'],
# video_path=r"C:\troubleshooting\cue_light\t1\project_folder\videos\2025-05-21 16-10-06_cropped.mp4",
# data_path=r"C:\troubleshooting\cue_light\t1\project_folder\csv\cue_lights\2025-05-21 16-10-06_cropped.csv",
# video_setting=True,
# frame_setting=False,
# core_cnt=23)
#
# test.run()