__author__ = "Simon Nilsson; sronilsson@gmail.com"
import functools
import itertools
import multiprocessing
import os
import platform
from copy import deepcopy
from typing import Dict, Optional, Tuple, Union
try:
from typing import Literal
except:
from typing_extensions import Literal
import cv2
import numpy as np
import pandas as pd
from simba.mixins.config_reader import ConfigReader
from simba.mixins.plotting_mixin import PlottingMixin
from simba.roi_tools.ROI_directing_analyzer import DirectingROIAnalyzer
from simba.utils.checks import (check_file_exist_and_readable,
check_if_string_value_is_valid_video_timestamp,
check_if_valid_rgb_tuple, check_int, check_str,
check_that_hhmmss_start_is_before_end,
check_valid_boolean, check_valid_dict,
check_video_and_data_frm_count_align)
from simba.utils.data import (find_frame_numbers_from_time_stamp, get_cpu_pool,
slice_roi_dict_for_video, terminate_cpu_pool)
from simba.utils.enums import Formats, TextOptions
from simba.utils.errors import (NoFilesFoundError, NoROIDataError,
ROICoordinatesNotFoundError)
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
find_core_cnt, get_fn_ext,
get_video_meta_data, read_df,
remove_a_folder, seconds_to_timestamp)
START_TIME, END_TIME = 'start_time', 'end_time'
def _roi_directing_visualizer_mp(frm_range: Tuple[int, np.ndarray],
data_df: pd.DataFrame,
text_locations: dict,
font_size: float,
circle_size: Union[float, int],
save_temp_dir: str,
video_meta_data: dict,
shape_info: dict,
shape_names: list,
video_path: str,
animal_names: list,
roi_dict: dict,
animal_bp_dict: dict,
directing_data: pd.DataFrame,
border_bg_color: tuple,
show_pose: bool,
show_roi_centers: bool,
show_animal_names: bool,
direction_color: Tuple[int, int, int],
direction_thickness: int,
direction_style: str,
verbose: bool,
cumulative_directing: dict,
fps: float):
fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
font = cv2.FONT_HERSHEY_SIMPLEX
group_cnt, frm_range = frm_range[0], frm_range[1]
current_frm, end_frm = frm_range[0], frm_range[-1]
save_path = os.path.join(save_temp_dir, f"{group_cnt}.mp4")
writer = cv2.VideoWriter(save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"] * 2, video_meta_data["height"]))
cap = cv2.VideoCapture(video_path)
cap.set(1, current_frm)
directing_lk = set(zip(directing_data["Animal"], directing_data["ROI"], directing_data["Frame"]))
while current_frm <= end_frm:
ret, img = cap.read()
if ret:
img = cv2.copyMakeBorder(img, 0, 0, 0, int(video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=border_bg_color)
img = PlottingMixin.roi_dict_onto_img(img=img, roi_dict=roi_dict, circle_size=circle_size, show_center=show_roi_centers)
if show_pose:
for animal_name, bp_data in animal_bp_dict.items():
for bp_cnt, bp in enumerate(zip(bp_data["X_bps"], bp_data["Y_bps"])):
bp_cords = data_df.loc[current_frm, list(bp)].values.astype(np.int64)
cv2.circle(img, (bp_cords[0], bp_cords[1]), 0, animal_bp_dict[animal_name]["colors"][bp_cnt], circle_size)
if show_animal_names:
for animal_name, bp_data in animal_bp_dict.items():
headers = [bp_data["X_bps"][-1], bp_data["Y_bps"][-1]]
bp_cords = data_df.loc[current_frm, headers].values.astype(np.int64)
cv2.putText(img, animal_name, (bp_cords[0], bp_cords[1]), font, font_size, animal_bp_dict[animal_name]["colors"][0], 1)
for animal_name, shape_name in itertools.product(animal_names, shape_names):
is_directing = (animal_name, shape_name, current_frm) in directing_lk
shape_clr = shape_info[shape_name]["Color BGR"]
cv2.putText(img, text_locations[animal_name][shape_name]["directing_text"], text_locations[animal_name][shape_name]["directing_text_loc"], font, font_size, shape_clr, 1)
cv2.putText(img, str(is_directing), text_locations[animal_name][shape_name]["directing_data_loc"], font, font_size, shape_clr, 1)
cum_key = (animal_name, shape_name)
cum_frms = cumulative_directing[cum_key][current_frm] if current_frm < len(cumulative_directing[cum_key]) else cumulative_directing[cum_key][-1]
cum_time = seconds_to_timestamp(seconds=cum_frms / fps, hh_mm_ss_sss=True)
cv2.putText(img, text_locations[animal_name][shape_name]["total_time_text"], text_locations[animal_name][shape_name]["total_time_text_loc"], font, font_size, shape_clr, 1)
cv2.putText(img, cum_time, text_locations[animal_name][shape_name]["total_time_data_loc"], font, font_size, shape_clr, 1)
if is_directing:
img = PlottingMixin.insert_directing_line(directing_df=directing_data,
img=img,
shape_name=shape_name,
animal_name=animal_name,
frame_id=current_frm,
color=direction_color,
thickness=direction_thickness,
style=direction_style)
writer.write(np.uint8(img))
if verbose:
seconds = seconds_to_timestamp(seconds=current_frm / video_meta_data['fps'], hh_mm_ss_sss=True)
stdout_information(msg=f"Multiprocessing frame: {current_frm}, time-stamp: {seconds} on core {group_cnt}...")
current_frm += 1
else:
break
cap.release()
writer.release()
return group_cnt
[docs]class DirectingROIVisualizer(ConfigReader, PlottingMixin):
"""
Visualize when animals are directing towards ROIs. Draws the ROIs on the video frames, overlays
pose-estimation body-parts, and draws directing lines (funnel or line style) from the animal eye
midpoint to the ROI when the animal is directing towards the ROI. A text panel shows the directing
boolean for each animal-ROI combination per frame. Uses multiprocessing.
:param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
:param Union[str, os.PathLike] video_path: Path to video file to overlay directing visualization on.
:param Literal['funnel', 'lines'] direction_style: Style of direction line. Default 'funnel'.
:param Tuple[int, int, int] direction_color: BGR color of the directing line. Default (0, 0, 255) (red).
:param Optional[int] direction_thickness: Thickness of the directing line (used for 'lines' style). If None, computed automatically based on video resolution. Default None.
:param Optional[int] circle_size: Size of the pose-estimation keypoint circles. If None, computed automatically based on video resolution. Default None.
:param bool show_pose: If True, draw pose-estimation keypoints on the video. Default True.
:param bool show_roi_centers: If True, draw the center of each ROI. Default True.
:param bool show_animal_names: If True, display animal names on the video. Default False.
:param Tuple[int, int, int] border_bg_clr: BGR color for the text panel background. Default (0, 0, 0).
:param Optional[Dict[str, str]] time_slice: Optional dict with 'start_time' and 'end_time' keys (HH:MM:SS format) to visualize a sub-clip. Default None.
:param Optional[Union[str, os.PathLike]] roi_coordinates_path: Optional path to ROI definitions file. If None, uses the project default. Default None.
:param Optional[str] left_ear_name: Optional custom left ear body-part name. Default None.
:param Optional[str] right_ear_name: Optional custom right ear body-part name. Default None.
:param Optional[str] nose_name: Optional custom nose body-part name. Default None.
:param int core_cnt: Number of CPU cores for multiprocessing. -1 uses all available. Default -1.
:param bool gpu: If True, use GPU for video concatenation when available. Default False.
:param bool verbose: If True, print progress messages during visualization. Default True.
.. video:: _static/img/DirectingROIVisualizer.webm
:width: 1000
:autoplay:
:loop:
:muted:
:align: center
:example:
>>> viz = DirectingROIVisualizer(config_path='/path/to/project_config.ini',
... video_path='/path/to/video.mp4',
... direction_style='funnel',
... show_pose=True,
... core_cnt=4)
>>> viz.run()
"""
def __init__(self,
config_path: Union[str, os.PathLike],
video_path: Union[str, os.PathLike],
direction_style: Literal['funnel', 'lines'] = 'lines',
direction_color: Tuple[int, int, int] = (0, 0, 255),
direction_thickness: Optional[int] = None,
circle_size: Optional[int] = None,
show_pose: bool = True,
show_roi_centers: bool = True,
show_animal_names: bool = False,
border_bg_clr: Tuple[int, int, int] = (0, 0, 0),
time_slice: Optional[Dict[str, str]] = None,
roi_coordinates_path: Optional[Union[str, os.PathLike]] = None,
left_ear_name: Optional[str] = None,
right_ear_name: Optional[str] = None,
nose_name: Optional[str] = None,
core_cnt: int = -1,
gpu: bool = False,
verbose: bool = True):
check_file_exist_and_readable(file_path=config_path)
check_file_exist_and_readable(file_path=video_path)
check_int(name=f"{self.__class__.__name__} core_cnt", value=core_cnt, min_value=-1, unaccepted_vals=[0])
check_str(name=f"{self.__class__.__name__} direction_style", value=direction_style, options=['funnel', 'lines'])
check_if_valid_rgb_tuple(data=direction_color, source=f"{self.__class__.__name__} direction_color")
if direction_thickness is not None:
check_int(name=f"{self.__class__.__name__} direction_thickness", value=direction_thickness, min_value=1)
if circle_size is not None:
check_int(name=f"{self.__class__.__name__} circle_size", value=circle_size, min_value=1)
check_valid_boolean(value=[show_pose, show_roi_centers, show_animal_names, gpu, verbose], source=self.__class__.__name__)
check_if_valid_rgb_tuple(data=border_bg_clr, source=f"{self.__class__.__name__} border_bg_clr")
if time_slice is not None:
check_valid_dict(x=time_slice, valid_key_dtypes=(str,), valid_values_dtypes=(str,), valid_keys=(START_TIME, END_TIME), required_keys=(START_TIME, END_TIME))
check_if_string_value_is_valid_video_timestamp(value=time_slice[START_TIME], name='START TIME', raise_error=True)
check_if_string_value_is_valid_video_timestamp(value=time_slice[END_TIME], name='END TIME', raise_error=True)
check_that_hhmmss_start_is_before_end(start_time=time_slice[START_TIME], end_time=time_slice[END_TIME], name='TIME SLICE', raise_error=True)
ConfigReader.__init__(self, config_path=config_path)
PlottingMixin.__init__(self)
self.core_cnt = find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt
self.gpu = gpu
self.show_pose, self.show_roi_centers, self.show_animal_names = show_pose, show_roi_centers, show_animal_names
self.border_bg_clr, self.time_slice = border_bg_clr, time_slice
self.direction_style, self.direction_color, self.direction_thickness, self.circle_size = direction_style, direction_color, direction_thickness, circle_size
self.verbose = verbose
self.video_path = video_path
if roi_coordinates_path is not None:
check_file_exist_and_readable(file_path=roi_coordinates_path)
self.roi_coordinates_path = deepcopy(roi_coordinates_path)
if not os.path.isfile(self.roi_coordinates_path):
raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path)
self.read_roi_data()
_, self.video_name, _ = get_fn_ext(video_path)
self.roi_dict, self.shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name)
if len(self.shape_names) == 0:
raise NoROIDataError(msg=f"No ROIs found for video {self.video_name}. Draw ROIs for this video before creating directing visualizations.", source=self.__class__.__name__)
self.data_path = os.path.join(self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}")
if not os.path.isfile(self.data_path):
raise NoFilesFoundError(msg=f"Could not find the file at path {self.data_path}. Make sure the data file exists to create directing ROI visualizations.", source=self.__class__.__name__)
self.directing_analyzer = DirectingROIAnalyzer(config_path=config_path,
data_path=self.data_path,
left_ear_name=left_ear_name,
right_ear_name=right_ear_name,
nose_name=nose_name)
self.directing_analyzer.run()
self.directing_df = self.directing_analyzer.results_df
self.animal_names = list(self.directing_analyzer.direct_bp_dict.keys())
self.video_meta_data = get_video_meta_data(video_path, fps_as_int=False)
if direction_thickness is None:
self.direction_thickness = max(1, self.get_optimal_circle_size(frame_size=(int(self.video_meta_data["height"]), int(self.video_meta_data["width"])), circle_frame_ratio=200))
self.data_df = read_df(file_path=self.data_path, file_type=self.file_type)
check_video_and_data_frm_count_align(video=video_path, data=self.data_path, name=video_path, raise_error=False)
self.save_dir = os.path.join(self.frames_output_dir, "ROI_directing_visualizations")
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
self.save_path = os.path.join(self.save_dir, f"{self.video_name}.mp4")
self.save_temp_dir = os.path.join(self.save_dir, "temp")
if os.path.exists(self.save_temp_dir):
remove_a_folder(folder_dir=self.save_temp_dir)
os.makedirs(self.save_temp_dir)
self.shape_dicts = self.__create_shape_dicts()
if platform.system() == "Darwin":
multiprocessing.set_start_method("spawn", force=True)
def __create_shape_dicts(self):
shape_dicts = {}
for shape, df in self.roi_dict.items():
if not df["Name"].is_unique:
df = df.drop_duplicates(subset=["Name"], keep="first")
d = df.set_index("Name").to_dict(orient="index")
shape_dicts = {**shape_dicts, **d}
return shape_dicts
def __calc_text_locs(self):
add_spacer = TextOptions.FIRST_LINE_SPACING.value
self.loc_dict = {}
txt_strs = []
for animal_name in self.animal_names:
for shape in self.shape_names:
txt_strs.append(f"{shape} {animal_name} directing")
txt_strs.append(f"{shape} {animal_name} total time (s)")
longest_text_str = str(max(txt_strs, key=len)) if len(txt_strs) > 0 else "N/A"
self.font_size, x_spacer, y_spacer = self.get_optimal_font_scales(text=longest_text_str, accepted_px_width=int(self.video_meta_data["width"] / 2), accepted_px_height=int(self.video_meta_data["height"] / 15), text_thickness=3)
if self.circle_size is None:
self.circle_size = self.get_optimal_circle_size(frame_size=(int(self.video_meta_data["height"]), int(self.video_meta_data["height"])), circle_frame_ratio=100)
for animal_name in self.animal_names:
self.loc_dict[animal_name] = {}
for shape in self.shape_names:
self.loc_dict[animal_name][shape] = {}
self.loc_dict[animal_name][shape]["directing_text"] = f"{shape} {animal_name} directing"
self.loc_dict[animal_name][shape]["directing_text_loc"] = ((self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + y_spacer * add_spacer))
self.loc_dict[animal_name][shape]["directing_data_loc"] = (int(self.video_meta_data["width"] + x_spacer + TextOptions.BORDER_BUFFER_X.value), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + y_spacer * add_spacer))
add_spacer += 1
self.loc_dict[animal_name][shape]["total_time_text"] = f"{shape} {animal_name} total time (s)"
self.loc_dict[animal_name][shape]["total_time_text_loc"] = ((self.video_meta_data["width"] + TextOptions.BORDER_BUFFER_X.value), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + y_spacer * add_spacer))
self.loc_dict[animal_name][shape]["total_time_data_loc"] = (int(self.video_meta_data["width"] + x_spacer + TextOptions.BORDER_BUFFER_X.value), (self.video_meta_data["height"] - (self.video_meta_data["height"] + 10) + y_spacer * add_spacer))
add_spacer += 1
def run(self):
self.timer = SimbaTimer(start=True)
self.__calc_text_locs()
data_df = self.data_df.copy()
if self.time_slice is not None:
frm_ids = find_frame_numbers_from_time_stamp(start_time=self.time_slice[START_TIME], end_time=self.time_slice[END_TIME], fps=int(self.video_meta_data['fps']))
data_df = data_df.loc[frm_ids].reset_index(drop=True)
n_frms = len(data_df)
cumulative_directing = {}
directing_lk = set(zip(self.directing_df["Animal"], self.directing_df["ROI"], self.directing_df["Frame"]))
for animal_name, shape_name in itertools.product(self.animal_names, self.shape_names):
arr = np.zeros(n_frms, dtype=np.int64)
for i in range(n_frms):
arr[i] = 1 if (animal_name, shape_name, i) in directing_lk else 0
cumulative_directing[(animal_name, shape_name)] = np.cumsum(arr)
frm_lst = np.arange(0, n_frms)
frm_lst = np.array_split(frm_lst, self.core_cnt)
frame_range = [(i, frm_lst[i]) for i in range(len(frm_lst))]
if self.verbose:
stdout_information(msg=f"Creating ROI directing visualization for video {self.video_name}, multiprocessing (cores: {self.core_cnt})...")
pool = get_cpu_pool(core_cnt=self.core_cnt, verbose=self.verbose, source=self.__class__.__name__)
constants = functools.partial(_roi_directing_visualizer_mp,
data_df=data_df.reset_index(drop=True),
text_locations=self.loc_dict,
font_size=self.font_size,
circle_size=self.circle_size,
video_meta_data=self.video_meta_data,
shape_info=self.shape_dicts,
roi_dict=self.roi_dict,
save_temp_dir=self.save_temp_dir,
directing_data=self.directing_df,
shape_names=self.shape_names,
video_path=self.video_path,
animal_names=self.animal_names,
animal_bp_dict=self.animal_bp_dict,
border_bg_color=self.border_bg_clr,
show_pose=self.show_pose,
show_roi_centers=self.show_roi_centers,
show_animal_names=self.show_animal_names,
direction_color=self.direction_color,
direction_thickness=self.direction_thickness,
direction_style=self.direction_style,
verbose=self.verbose,
cumulative_directing=cumulative_directing,
fps=self.video_meta_data['fps'])
for cnt, result in enumerate(pool.imap(constants, frame_range, chunksize=self.multiprocess_chunksize)):
if self.verbose:
stdout_information(msg=f"Batch core {result + 1}/{self.core_cnt} complete...")
if self.verbose:
stdout_information(f"Joining {self.video_name} multi-processed video...")
concatenate_videos_in_folder(in_folder=self.save_temp_dir, save_path=self.save_path, video_format="mp4", remove_splits=True, gpu=self.gpu)
self.timer.stop_timer()
terminate_cpu_pool(pool=pool, force=False, verbose=self.verbose, source=self.__class__.__name__)
stdout_success(msg=f"ROI directing visualization for video {self.video_name} complete. Video saved at {self.save_path}.", elapsed_time=self.timer.elapsed_time_str)
# if __name__ == '__main__':
# test = DirectingROIVisualizer(config_path=r"E:\troubleshooting\mitra_emergence_hour\project_folder\project_config.ini",
# video_path=r"E:\troubleshooting\mitra_emergence_hour\project_folder\videos\Box1_180mISOcontrol_Females.mp4",
# direction_style='funnel',
# show_pose=True,
# time_slice={'start_time': '00:00:00', 'end_time': '00:00:10'},
# core_cnt=4)
# test.run()