__author__ = "Simon Nilsson; sronilsson@gmail.com"
import functools
import itertools
import multiprocessing
import os
import platform
import shutil
from typing import Dict, List, Optional, Tuple, Union
try:
from typing import Literal
except:
from typing_extensions import Literal
import cv2
import numpy as np
import pandas as pd
from simba.mixins.config_reader import ConfigReader
from simba.mixins.geometry_mixin import GeometryMixin
from simba.mixins.plotting_mixin import PlottingMixin
from simba.roi_tools.roi_aggregate_statistics_analyzer import \
ROIAggregateStatisticsAnalyzer
from simba.roi_tools.roi_utils import get_roi_dict_from_dfs
from simba.utils.checks import (check_file_exist_and_readable, check_float,
check_if_dir_exists, check_if_valid_rgb_tuple,
check_int, check_nvidea_gpu_available,
check_str, check_valid_boolean,
check_valid_lst,
check_video_and_data_frm_count_align)
from simba.utils.data import (create_color_palettes, detect_bouts,
get_cpu_pool, slice_roi_dict_for_video,
terminate_cpu_pool)
from simba.utils.enums import (ROI_SETTINGS, Formats, Keys, Options, Paths,
TextOptions)
from simba.utils.errors import (BodypartColumnNotFoundError, DuplicationError,
NoFilesFoundError, NoROIDataError,
ROICoordinatesNotFoundError)
from simba.utils.lookups import get_simba_font_name_and_path
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
find_core_cnt, get_video_meta_data,
read_df, seconds_to_timestamp)
from simba.utils.warnings import (DuplicateNamesWarning, FrameRangeWarning,
GPUToolsWarning)
# Process-wide pandas option: silence SettingWithCopyWarning for assignments made on DataFrame slices.
pd.options.mode.chained_assignment = None
# Accepted values for the ``print_timer`` argument: plain seconds, or a clock-style timestamp with fractional seconds.
SECONDS, HHMMSSSSSS = 'seconds', 'hh:mm:ss.ssss'
def _roi_plotter_mp(data: Tuple[int, pd.DataFrame],
                    loc_dict: dict,
                    font_size: float,
                    circle_sizes: list,
                    save_temp_directory: str,
                    video_shape_names: list,
                    input_video_path: str,
                    body_part_dict: dict,
                    roi_dfs_dict: Dict[str, pd.DataFrame],
                    roi_dict: dict,
                    bp_colors: list,
                    show_animal_name: bool,
                    font_path: Optional[str],
                    font_size_px: Optional[int],
                    show_pose: bool,
                    animal_ids: list,
                    threshold: float,
                    outside_roi: bool,
                    verbose: bool,
                    print_timer: str,
                    border_bg_clr: tuple,
                    animal_bp_dict: dict,
                    bbox: Optional[str]):
    """
    Multiprocessing worker: render one contiguous chunk of frames of the ROI visualization video.

    Each frame is read from ``input_video_path`` and padded on the right with a panel as wide as the video
    (filled with ``border_bg_clr``). The frame is then annotated with:

    * the ROI shapes;
    * optional body-part circles, animal names and bounding boxes;
    * the cumulative time and entry counters for every animal / ROI pair.

    The result is written to ``<save_temp_directory>/<group_cnt>.mp4``.

    :param Tuple[int, pd.DataFrame] data: ``(group_cnt, data_df)``. ``data_df`` is indexed by video frame number and holds the pose columns plus ``{animal}_{shape}_cum_sum_time`` / ``{animal}_{shape}_cum_sum_entries`` columns.
    :param dict loc_dict: ``loc_dict[animal][shape]`` holding the label strings (``timer_text``, ``entries_text``) and pixel positions (``*_text_loc``, ``*_data_loc``).
    :param float font_size: OpenCV font scale. Replaced by ``font_size_px`` when ``font_path`` is not None.
    :param list circle_sizes: Per-animal circle radius. Also used as the bounding-box line thickness.
    :param str save_temp_directory: Directory receiving the per-chunk mp4 file.
    :param list video_shape_names: ROI names whose counters are drawn. May include the "outside ROI" pseudo-shape.
    :param str input_video_path: Source video.
    :param dict body_part_dict: ``body_part_dict[animal]`` -> the three (x, y, probability) column names of the tracked body-part.
    :param Dict[str, pd.DataFrame] roi_dfs_dict: ROI definitions passed to ``PlottingMixin.roi_dict_onto_img``.
    :param dict roi_dict: ``roi_dict[shape_name]`` -> shape data including ``'Color BGR'``.
    :param list bp_colors: Per-animal drawing color.
    :param bool show_animal_name: Draw the animal name at the body-part location.
    :param Optional[str] font_path: Optional TrueType font path forwarded to ``put_text``.
    :param Optional[int] font_size_px: Pixel font size used when ``font_path`` is set.
    :param bool show_pose: Draw the body-part location as a filled circle.
    :param list animal_ids: Animal names, in the same order as ``circle_sizes`` / ``bp_colors``.
    :param float threshold: Minimum body-part probability required to draw pose / name / bounding box.
    :param bool outside_roi: Also draw labels for the "outside ROI" pseudo-shape.
    :param bool verbose: Print per-frame progress.
    :param str print_timer: ``'seconds'`` or ``'hh:mm:ss.ssss'`` timer formatting.
    :param tuple border_bg_clr: Fill color of the statistics panel (passed to ``cv2.copyMakeBorder``).
    :param dict animal_bp_dict: ``animal_bp_dict[animal]`` with ``'X_bps'`` / ``'Y_bps'`` column lists, used for bounding boxes.
    :param Optional[str] bbox: None, or a bounding-box mode (``Options.AXIS_ALIGNED`` or the rotated-rectangle mode).
    :return int: ``group_cnt``, so the caller can report which chunk finished.
    """
    group_cnt, data_df = data[0], data[1]
    if len(data_df) == 0:
        # np.array_split yields empty chunks when there are more cores than frames: nothing to render.
        return group_cnt
    font_size = font_size_px if font_path is not None else font_size
    plotter = PlottingMixin()
    geometry = GeometryMixin() if bbox is not None else None
    outside_name = ROI_SETTINGS.OUTSIDE_ROI.value

    def _put_text(img: np.ndarray, text: str, pos, color) -> np.ndarray:
        """Draw ``text`` at ``pos`` with the shared font settings on a transparent background."""
        return plotter.put_text(img=img, text=text, pos=pos, font_size=font_size, font=TextOptions.FONT.value, font_thickness=TextOptions.TEXT_THICKNESS.value, font_path=font_path, text_color=tuple(int(v) for v in color), text_bg_alpha=0.0)

    def _insert_static_texts(img: np.ndarray) -> np.ndarray:
        """Draw the frame-independent 'timer:' / 'entries:' labels for every animal and ROI."""
        for animal_name in animal_ids:
            for shape_name, shape_data in roi_dict.items():
                shape_locs = loc_dict[animal_name][shape_name]
                img = _put_text(img, shape_locs["timer_text"], shape_locs["timer_text_loc"], shape_data['Color BGR'])
                img = _put_text(img, shape_locs["entries_text"], shape_locs["entries_text_loc"], shape_data['Color BGR'])
            if outside_roi:
                outside_locs = loc_dict[animal_name][outside_name]
                img = _put_text(img, outside_locs["timer_text"], outside_locs["timer_text_loc"], TextOptions.WHITE.value)
                img = _put_text(img, outside_locs["entries_text"], outside_locs["entries_text_loc"], TextOptions.WHITE.value)
        # Return only after ALL animals are labelled (an early return inside the loop skipped later animals).
        return img

    def _draw_bbox(img: np.ndarray, animal_name: str, animal_cnt: int, frm: int) -> np.ndarray:
        """Draw the configured bounding box around all key-points of ``animal_name`` in frame ``frm``."""
        x_cols, y_cols = animal_bp_dict[animal_name]['X_bps'], animal_bp_dict[animal_name]['Y_bps']
        animal_cols = [col for pair in zip(x_cols, y_cols) for col in pair]
        animal_cords = data_df.loc[frm, animal_cols].fillna(0.0).values.astype(np.int32).reshape(-1, 2)
        try:
            if bbox == Options.AXIS_ALIGNED.value:
                animal_bbox = geometry.keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
            else:
                animal_bbox = geometry.minimum_rotated_rectangle(shape=animal_cords, buffer=None, return_type='array')
            return cv2.polylines(img, [animal_bbox], True, bp_colors[animal_cnt], thickness=circle_sizes[animal_cnt], lineType=cv2.LINE_AA)
        except Exception:
            # Deliberate best-effort: if the box cannot be built or drawn for this frame
            # (presumably degenerate key-point sets), skip it rather than abort the whole chunk.
            return img

    video_meta_data = get_video_meta_data(video_path=input_video_path)
    start_frm, end_frm = int(data_df.index[0]), int(data_df.index[-1])
    save_path = os.path.join(save_temp_directory, f"{group_cnt}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
    # Output frames are twice as wide as the input: the video on the left, the statistics panel on the right.
    writer = cv2.VideoWriter(save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"] * 2, video_meta_data["height"]))
    cap = cv2.VideoCapture(input_video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frm)
        current_frm = start_frm
        while current_frm <= end_frm:
            ret, img = cap.read()
            if not ret:
                FrameRangeWarning(msg=f'Could not read frame {current_frm} in video {video_meta_data["video_name"]}', source=_roi_plotter_mp.__name__)
                break
            img = cv2.copyMakeBorder(img, 0, 0, 0, int(video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=border_bg_clr)
            img = _insert_static_texts(img)
            img = PlottingMixin.roi_dict_onto_img(img=img, roi_dict=roi_dfs_dict)
            for animal_cnt, animal_name in enumerate(animal_ids):
                if show_animal_name or show_pose or bbox is not None:
                    # Read as float: an integer cast would truncate every probability below 1.0 to 0 and break thresholding.
                    x, y, p = data_df.loc[current_frm, body_part_dict[animal_name]].fillna(0.0).values.astype(np.float64)
                    if threshold <= p:
                        x, y = int(x), int(y)
                        if show_pose:
                            img = cv2.circle(img, (x, y), circle_sizes[animal_cnt], bp_colors[animal_cnt], -1)
                        if show_animal_name:
                            img = _put_text(img, animal_name, (x, y), bp_colors[animal_cnt])
                        if bbox is not None:
                            img = _draw_bbox(img, animal_name, animal_cnt, current_frm)
                for shape_name in video_shape_names:
                    shape_color = TextOptions.WHITE.value if shape_name == outside_name else roi_dict[shape_name]["Color BGR"]
                    timer = round(data_df.loc[current_frm, f"{animal_name}_{shape_name}_cum_sum_time"], 2)
                    if print_timer == HHMMSSSSSS:
                        timer = seconds_to_timestamp(seconds=timer, hh_mm_ss_sss=True)
                    entries = data_df.loc[current_frm, f"{animal_name}_{shape_name}_cum_sum_entries"]
                    shape_locs = loc_dict[animal_name][shape_name]
                    img = _put_text(img, str(timer), shape_locs["timer_data_loc"], shape_color)
                    img = _put_text(img, str(entries), shape_locs["entries_data_loc"], shape_color)
            writer.write(img)
            current_frm += 1
            if verbose:
                stdout_information(msg=f"Multi-processing video frame {current_frm}/{video_meta_data['frame_count']} (core batch: {group_cnt}, video: {video_meta_data['video_name']})...")
    finally:
        # Always release the OS handles, even if drawing raised mid-chunk.
        cap.release()
        writer.release()
    return group_cnt
class ROIPlotMultiprocess(ConfigReader, PlottingMixin):
    """
    Visualize the ROI data (number of entries/exits, time-spent in ROIs) using multiprocessing for improved performance.

    .. note::
       `ROI tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/ROI_tutorial_new.md>`__.

    .. image:: _static/img/roi_visualize.png
       :alt: Roi visualize
       :width: 400
       :align: center

    .. image:: _static/img/ROIPlot_1.png
       :alt: ROIPlot 1
       :width: 1000
       :align: center

    .. video:: _static/img/ROIPlot_2.webm
       :width: 1000
       :autoplay:
       :loop:
       :muted:
       :align: center

    .. video:: _static/img/outside_roi_example.mp4
       :width: 800
       :autoplay:
       :loop:
       :muted:
       :align: center

    :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
    :param Union[str, os.PathLike] video_path: Path to video file to create ROI visualizations for.
    :param List[str] body_parts: List of the body-parts to use as proxy for animal locations.
    :param float threshold: Float between 0 and 1. Body-part locations detected below this confidence threshold are not drawn. Default: 0.0.
    :param int core_cnt: Number of cores to use for multiprocessing. Default: -1 (uses all available cores).
    :param bool verbose: If True, print progress messages during video creation. Default: True.
    :param bool outside_roi: If True, SimBA will treat all areas NOT covered by a ROI drawing as a single additional ROI and visualize the stats for this single ROI. Default: False.
    :param bool show_body_part: If True, display body-part locations as circles on the video frames. Default: True.
    :param Optional[str] font: Optional font name resolved through ``get_simba_font_name_and_path``. When set, text is rendered with that font file and a pixel-based font size. Default: None (uses ``TextOptions.FONT``).
    :param bool show_animal_name: If True, display animal names on the video frames. Default: False.
    :param Optional[Literal['axis-aligned', 'animal-aligned']] bbox: If not None, draw bounding boxes around each animal. ``'axis-aligned'`` = rectangle aligned with video axes; ``'animal-aligned'`` = minimum rotated rectangle aligned with the animal's orientation. Default: None (no bounding boxes).
    :param Literal['seconds', 'hh:mm:ss.ssss'] print_timer: Format of the ROI time counters shown in the border panel. ``'seconds'`` = numeric seconds, ``'hh:mm:ss.ssss'`` = clock-style timestamp with fractional seconds. Default: ``'seconds'``.
    :param Tuple[int, int, int] border_bg_clr: Color tuple of the background of the border area where statistics are displayed. It is passed unchanged to OpenCV, which interprets it in BGR order. Default: (0, 0, 0).
    :param Optional[Union[str, os.PathLike]] data_path: Optional path to the pose-estimation data. If None, then locates file in ``outlier_corrected_movement_location`` directory. Default: None.
    :param Optional[Union[str, os.PathLike]] save_path: Optional path to where to save video. If None, then saves it in ``frames/output/roi_analysis`` directory of SimBA project. Default: None.
    :param Optional[List[Tuple[int, int, int]]] bp_colors: Optional list of color tuples of the same length as ``body_parts``. They are passed unchanged to OpenCV drawing calls (BGR order). Default: None (colors are chosen automatically).
    :param Optional[List[Union[int]]] bp_sizes: Optional list of circle radii, one per animal. Invalid or missing entries fall back to a size inferred from the frame size. Default: None.
    :param bool gpu: If True, use GPU acceleration for video concatenation. Default: False.

    :example:
    >>> test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
    >>>                            video_path="/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/videos/SI_DAY3_308_CD1_PRESENT.mp4",
    >>>                            core_cnt=7,
    >>>                            body_parts=['Nose'],
    >>>                            show_body_part=True,
    >>>                            show_animal_name=False)
    >>> test.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 video_path: Union[str, os.PathLike],
                 body_parts: List[str],
                 threshold: Optional[float] = 0.0,
                 core_cnt: int = -1,
                 verbose: bool = True,
                 outside_roi: bool = False,
                 show_body_part: bool = True,
                 font: Optional[str] = None,
                 show_animal_name: bool = False,
                 bbox: Optional[Literal['axis-aligned', 'animal-aligned']] = None,
                 print_timer: Literal['seconds', 'hh:mm:ss.ssss'] = 'seconds',
                 border_bg_clr: Tuple[int, int, int] = (0, 0, 0),
                 data_path: Optional[Union[str, os.PathLike]] = None,
                 save_path: Optional[Union[str, os.PathLike]] = None,
                 bp_colors: Optional[List[Tuple[int, int, int]]] = None,
                 bp_sizes: Optional[List[Union[int]]] = None,
                 gpu: bool = False):
        check_file_exist_and_readable(file_path=config_path)
        ConfigReader.__init__(self, config_path=config_path)
        self.video_meta_data = get_video_meta_data(video_path=video_path)
        self.video_name = self.video_meta_data['video_name']

        # ---- argument validation ----
        check_float(name=f'{self.__class__.__name__} threshold', value=threshold, min_value=0.0, max_value=1.0)
        check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0,])
        check_if_valid_rgb_tuple(source=f'{self.__class__.__name__} border_bg_clr', data=border_bg_clr, raise_error=True)
        check_valid_boolean(value=[gpu], source=f'{self.__class__.__name__} gpu', raise_error=True)
        check_valid_boolean(value=[outside_roi], source=f'{self.__class__.__name__} outside_roi', raise_error=True)
        check_valid_boolean(value=[verbose], source=f'{self.__class__.__name__} verbose', raise_error=True)
        check_valid_boolean(value=[show_body_part], source=f'{self.__class__.__name__} show_body_part', raise_error=True)
        check_valid_boolean(value=[show_animal_name], source=f'{self.__class__.__name__} show_animal_name', raise_error=True)
        check_str(name=f'{self.__class__.__name__} timer', value=print_timer, options=(SECONDS, HHMMSSSSSS,))
        self.font_path, self.font = None, None
        if font is not None:
            self.font, self.font_path = get_simba_font_name_and_path(font=font)
        if bbox is not None:
            check_str(name=f'{self.__class__.__name__} bbox', value=bbox, options=Options.BBOX_OPTIONS.value, allow_blank=False, raise_error=True)
        if gpu and not check_nvidea_gpu_available():
            GPUToolsWarning(msg='GPU not detected but GPU set to True - skipping GPU use.')
            gpu = False
        available_cores = find_core_cnt()[0]
        self.core_cnt = available_cores if core_cnt == -1 or core_cnt > available_cores else core_cnt

        # ---- ROI definitions ----
        if not os.path.isfile(self.roi_coordinates_path):
            raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path)
        self.read_roi_data()
        self.sliced_roi_dict, self.shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name)
        if len(self.shape_names) == 0:
            raise NoROIDataError(msg=f"Cannot plot ROI data for video {self.video_name}. No ROIs defined for this video.")

        # ---- input / output paths ----
        if data_path is None:
            data_path = os.path.join(self.outlier_corrected_dir, f'{self.video_name}.{self.file_type}')
        elif not os.path.isfile(data_path):
            raise NoFilesFoundError(msg=f"SIMBA ERROR: Could not find the file at path {data_path}. Make sure the data file exist to create ROI visualizations", source=self.__class__.__name__)
        check_file_exist_and_readable(file_path=data_path)
        if save_path is None:
            save_path = os.path.join(self.project_path, Paths.ROI_ANALYSIS.value, f'{self.video_name}.mp4')
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        else:
            check_if_dir_exists(os.path.dirname(save_path))
        self.save_path, self.data_path = save_path, data_path

        # ---- body-parts ----
        check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,), min_len=1)
        if outside_roi:
            self.shape_names.append(ROI_SETTINGS.OUTSIDE_ROI.value)
        if len(set(body_parts)) != len(body_parts):
            raise DuplicationError(msg=f'All body-part entries have to be unique. Got {body_parts}', source=self.__class__.__name__)
        for bp in body_parts:
            if bp not in self.body_parts_lst:
                raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__)

        # ---- ROI statistics ----
        self.roi_analyzer = ROIAggregateStatisticsAnalyzer(config_path=self.config_path, data_path=self.data_path, detailed_bout_data=True, threshold=threshold, body_parts=body_parts, outside_rois=outside_roi, verbose=verbose)
        self.roi_analyzer.run()
        if bp_colors is not None:
            check_valid_lst(data=bp_colors, source=f'{self.__class__.__name__} bp_colors', valid_dtypes=(tuple,), exact_len=len(body_parts), raise_error=True)
            for clr in bp_colors:
                check_if_valid_rgb_tuple(data=clr, raise_error=True, source=f'{self.__class__.__name__} bp_colors')
            self.color_lst = bp_colors
        else:
            self.color_lst = create_color_palettes(self.roi_analyzer.animal_cnt, len(body_parts))[0]
        self.bp_sizes = bp_sizes
        try:
            self.detailed_roi_data = pd.concat(self.roi_analyzer.detailed_dfs, axis=0).reset_index(drop=True)
        except ValueError:
            # pd.concat raises ValueError when there is nothing to concatenate (no detailed ROI bouts).
            self.detailed_roi_data = None
        self.bp_dict = self.roi_analyzer.bp_dict
        self.animal_names = [self.find_animal_name_from_body_part_name(bp_name=x, bp_dict=self.animal_bp_dict) for x in body_parts]

        # ---- pose data with one 0/1 "inside ROI" column per animal x shape ----
        self.data_df = read_df(file_path=self.data_path, file_type=self.file_type).fillna(0.0).reset_index(drop=True)
        self.shape_columns = []
        for animal_name, shape_name in itertools.product(self.animal_names, self.shape_names):
            col_name = f"{animal_name}_{shape_name}"
            self.data_df[col_name] = 0
            self.shape_columns.append(col_name)
        self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
        self.video_path = video_path
        check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.video_name, raise_error=False)
        self.threshold, self.body_parts, self.show_animal_name, self.gpu, self.outside_roi, self.verbose, self.border_bg_clr = threshold, body_parts, show_animal_name, gpu, outside_roi, verbose, border_bg_clr
        self.show_pose, self.bbox, self.print_timer = show_body_part, bbox, print_timer
        self.roi_dict_ = get_roi_dict_from_dfs(rectangle_df=self.sliced_roi_dict[Keys.ROI_RECTANGLES.value], circle_df=self.sliced_roi_dict[Keys.ROI_CIRCLES.value], polygon_df=self.sliced_roi_dict[Keys.ROI_POLYGONS.value])
        self.temp_folder = os.path.join(os.path.dirname(self.save_path), self.video_name, "temp")
        if os.path.exists(self.temp_folder):
            shutil.rmtree(self.temp_folder)
        os.makedirs(self.temp_folder)
        if platform.system() == "Darwin":
            multiprocessing.set_start_method("spawn", force=True)

    def __get_circle_sizes(self):
        """Populate ``self.circle_sizes`` with one circle radius per animal (indexed by animal position in the worker)."""
        frame_size = (int(self.video_meta_data["width"]), int(self.video_meta_data["height"]))
        optimal_circle_size = PlottingMixin().get_optimal_circle_size(frame_size=frame_size, circle_frame_ratio=100)
        if self.bp_sizes is None:
            self.circle_sizes = [optimal_circle_size] * len(self.animal_names)
            return
        self.circle_sizes = []
        for circle_size in self.bp_sizes:
            if check_int(name='circle_size', value=circle_size, min_value=1, raise_error=False)[0]:
                self.circle_sizes.append(int(circle_size))
            else:
                self.circle_sizes.append(optimal_circle_size)
        # Pad short user lists so the worker's circle_sizes[animal_cnt] lookup cannot raise IndexError.
        missing = len(self.animal_names) - len(self.circle_sizes)
        if missing > 0:
            self.circle_sizes.extend([optimal_circle_size] * missing)

    def __get_roi_columns(self):
        """Set ``{animal}_{shape}`` to 1 on every frame inside a detected ROI bout (inclusive start/end frames)."""
        if self.detailed_roi_data is None:
            return
        frm_idx = self.data_df.index
        bouts = self.detailed_roi_data[["ANIMAL", "Event", "Start_frame", "End_frame"]].to_dict(orient="records")
        for bout in bouts:
            start_frm, end_frm = int(bout["Start_frame"]), int(bout["End_frame"])
            col_name = f'{bout["ANIMAL"]}_{bout["Event"]}'
            in_bout = (frm_idx >= start_frm) & (frm_idx <= end_frm)
            # Single .loc assignment: the former chained df[col][mask] = 1 is a silent no-op under pandas Copy-on-Write.
            self.data_df.loc[in_bout, col_name] = 1

    def __get_text_locs(self) -> dict:
        """
        Compute font sizing and the pixel positions of every label / value in the statistics panel.

        Side effects: sets ``self.font_size`` and ``self.video_font_size_px``.

        :return dict: ``loc_dict[animal][shape]`` with ``timer_text``, ``entries_text``, ``timer_text_loc``, ``timer_data_loc``, ``entries_text_loc`` and ``entries_data_loc``.
        """
        width, height = self.video_meta_data["width"], self.video_meta_data["height"]
        label_strs = [f'{shape} {animal_name} timer:' for animal_name in self.animal_names for shape in self.shape_names] + [f'{shape} {animal_name} entries:' for animal_name in self.animal_names for shape in self.shape_names]
        longest_text_str = max(label_strs, key=len)
        max_px_w, max_px_h = int(width / 1.5), int(height / 10)
        self.font_size, x_spacer, y_spacer = PlottingMixin().get_optimal_font_scales(text=longest_text_str, accepted_px_width=max_px_w, accepted_px_height=max_px_h, text_thickness=TextOptions.TEXT_THICKNESS.value)
        if self.font_path is not None:
            self.video_font_size_px, x_spacer, _ = PlottingMixin().get_optimal_font_size_ttf(text=label_strs, font_path=self.font_path, accepted_px_width=max_px_w, accepted_px_height=max_px_h)
            y_spacer = self.get_optimal_font_spacing_ttf(font_path=self.font_path, size_px=self.video_font_size_px, text=label_strs, gap=0)
        else:
            self.video_font_size_px = None
        # Labels start just right of the original frame; values are offset by the label width (x_spacer).
        text_x = width + TextOptions.BORDER_BUFFER_X.value
        data_x = int(width + x_spacer + TextOptions.BORDER_BUFFER_X.value)

        def row_y(row: int):
            return height - (height + TextOptions.BORDER_BUFFER_Y.value) + y_spacer * row

        loc_dict = {}
        row_counter = TextOptions.FIRST_LINE_SPACING.value
        for animal_name in self.animal_names:
            loc_dict[animal_name] = {}
            for shape in self.shape_names:
                timer_y, entries_y = row_y(row_counter), row_y(row_counter + 1)
                loc_dict[animal_name][shape] = {"timer_text": f"{shape} {animal_name} timer:",
                                                "entries_text": f"{shape} {animal_name} entries:",
                                                "timer_text_loc": (text_x, timer_y),
                                                "timer_data_loc": (data_x, timer_y),
                                                "entries_text_loc": (text_x, entries_y),
                                                "entries_data_loc": (data_x, entries_y)}
                row_counter += 2
        return loc_dict

    def __get_counters(self) -> dict:
        """Return zero-initialised ``timer`` / ``entries`` / ``entry_status`` counters per animal and shape."""
        return {animal_name: {shape: {"timer": 0, "entries": 0, "entry_status": False} for shape in self.shape_names} for animal_name in self.animal_names}

    def __get_cumulative_data(self):
        """Add ``*_cum_sum_time`` (seconds), ``*_entry`` (bout-start flag) and ``*_cum_sum_entries`` columns per animal x shape."""
        fps = self.video_meta_data['fps']
        for animal_name in self.animal_names:
            for shape in self.shape_names:
                col = f"{animal_name}_{shape}"
                self.data_df[f"{col}_cum_sum_time"] = self.data_df[col].cumsum() / fps
                roi_bouts = list(detect_bouts(data_df=self.data_df, target_lst=[col], fps=fps)["Start_frame"])
                self.data_df[f"{col}_entry"] = 0
                self.data_df.loc[roi_bouts, f"{col}_entry"] = 1
                self.data_df[f"{col}_cum_sum_entries"] = self.data_df[f"{col}_entry"].cumsum()

    def __create_shape_dicts(self):
        """Merge all ROI shape tables into one ``{name: row_dict}`` mapping, keeping the first row of duplicated names."""
        shape_dicts = {}
        for shape, df in self.roi_dict.items():
            if not df["Name"].is_unique:
                df = df.drop_duplicates(subset=["Name"], keep="first")
                DuplicateNamesWarning(msg=f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', source=self.__class__.__name__)
            shape_dicts.update(df.set_index("Name").to_dict(orient="index"))
        return shape_dicts

    def __get_bordered_img_size(self) -> Tuple[int, int]:
        """
        Return ``(height, width)`` of an output frame.

        Output frames are the video frame plus a right-hand border as wide as the video. The size is
        derived from the video metadata instead of decoding a frame, which avoided a leaked capture handle
        and a crash when the read failed.
        """
        return int(self.video_meta_data["height"]), int(self.video_meta_data["width"]) * 2

    def run(self):
        """Render the ROI video in parallel chunks, then concatenate the chunks into ``self.save_path``."""
        video_timer = SimbaTimer(start=True)
        self.__get_circle_sizes()
        self.__get_roi_columns()
        self.border_img_h, self.border_img_w = self.__get_bordered_img_size()
        self.loc_dict = self.__get_text_locs()
        self.cnt_dict = self.__get_counters()
        self.__get_cumulative_data()
        # Split by row positions (np.array_split on a DataFrame is deprecated) and drop empty chunks,
        # which appear when there are more cores than frames.
        row_chunks = [idx for idx in np.array_split(np.arange(len(self.data_df)), self.core_cnt) if len(idx) > 0]
        data = [(chunk_cnt, self.data_df.iloc[idx]) for chunk_cnt, idx in enumerate(row_chunks)]
        del self.data_df
        del self.roi_analyzer.logger
        if self.verbose: stdout_information(msg=f"Creating ROI images, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})...")
        self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=self.maxtasksperchild, source=self.__class__.__name__, verbose=self.verbose)
        constants = functools.partial(_roi_plotter_mp,
                                      loc_dict=self.loc_dict,
                                      font_size=self.font_size,
                                      circle_sizes=self.circle_sizes,
                                      save_temp_directory=self.temp_folder,
                                      body_part_dict=self.bp_dict,
                                      input_video_path=self.video_path,
                                      roi_dfs_dict=self.sliced_roi_dict,
                                      roi_dict=self.roi_dict_,
                                      font_path=self.font_path,
                                      font_size_px=self.video_font_size_px,
                                      video_shape_names=self.shape_names,
                                      bp_colors=self.color_lst,
                                      print_timer=self.print_timer,
                                      show_animal_name=self.show_animal_name,
                                      show_pose=self.show_pose,
                                      animal_ids=self.animal_names,
                                      threshold=self.threshold,
                                      outside_roi=self.outside_roi,
                                      verbose=self.verbose,
                                      border_bg_clr=self.border_bg_clr,
                                      animal_bp_dict=self.animal_bp_dict,
                                      bbox=self.bbox)
        try:
            for batch_cnt in self.pool.imap(constants, data, chunksize=self.multiprocess_chunksize):
                if self.verbose: stdout_information(msg=f'Image batch {batch_cnt+1} / {len(data)} complete...')
        finally:
            # Shut the pool down even if a worker raised, so no processes are left behind.
            terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
        if self.verbose: stdout_information(msg=f"Joining {self.video_name} multi-processed ROI video...")
        concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_path, video_format="mp4", remove_splits=True, gpu=self.gpu, verbose=self.verbose)
        video_timer.stop_timer()
        if self.verbose: stdout_success(msg=f"Video {self.video_name} created. ROI video saved at {self.save_path}", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__, )
# if __name__ == "__main__":
# test = ROIPlotMultiprocess(config_path=r"G:\projects\sleap_bp_order\by_order\project_folder\project_config.ini",
# video_path=r"G:\projects\sleap_bp_order\by_order\project_folder\videos\rat_pilot.mp4",
# body_parts=['nose'],
# outside_roi=True,
# gpu=True,
# font='poppins regular',
# core_cnt=8,
# border_bg_clr=(0, 0, 0),
# bbox=None)
# test.run()
# if __name__ == "__main__":
# test = ROIPlotMultiprocess(config_path=r"C:\troubleshooting\roi_duplicates\project_folder\project_config.ini",
# video_path=r"C:\troubleshooting\roi_duplicates\project_folder\videos\2021-12-21_15-03-57_CO_Trimmed.mp4",
# body_parts=['Snout'],
# style_attr={'show_body_part': True, 'show_animal_name': False},
# bp_sizes=[20],
# bp_colors=[(155, 255, 243)])
# test.run()
#
# if __name__ == '__main__':
# test = ROIPlotMultiprocess(config_path=r"C:\troubleshooting\platea\project_folder\project_config.ini",
# video_path=r"C:\troubleshooting\platea\project_folder\videos\Video_1.mp4",
# body_parts=['NOSE'],
# style_attr={'show_body_part': True, 'show_animal_name': False})
# test.run()
# test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/project_config.ini',
# video_path="/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/videos/raw_clip1.mp4",
# body_parts=['Snout'],
# style_attr={'show_body_part': True, 'show_animal_name': False})
# test.run()
# test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
# video_path="/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/videos/2022-06-20_NOB_DOT_4.mp4",
# body_parts=['Nose'],
# style_attr={'show_body_part': True, 'show_animal_name': False})
# test.run()
# test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
# video_path="/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/videos/SI_DAY3_308_CD1_PRESENT.mp4",
# core_cnt=-1,
# style_attr={'show_body_part': True, 'show_animal_name': False},
# body_parts=['Nose'])
# test.run()
# test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
# video_path="/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.avi",
# core_cnt=7,
# style_attr={'show_body_part': True, 'show_animal_name': True},
# body_parts=['Nose_1', 'Nose_2'])
# test.run()
#
# test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
# video_path="2022-06-20_NOB_DOT_4.mp4",
# core_cnt=7,
# style_attr={'Show_body_part': True, 'Show_animal_name': True}, body_parts={'Animal_1': 'Nose'})
# test.run()
#
#
# test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/spontenous_alternation/project_folder/project_config.ini',
# video_path="F1 HAB.mp4",
# core_cnt=5,
# style_attr={'Show_body_part': True, 'Show_animal_name': True})
# test.run()
#
# get_video_meta_data(video_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/frames/output/ROI_analysis/2022-06-20_NOB_DOT_4.mp4')
# get_video_meta_data(video_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.avi')
# test = ROIPlot(ini_path=r'/Users/simon/Desktop/troubleshooting/train_model_project/project_folder/project_config.ini', video_path=r"Together_1.avi")
# test.insert_data()
# test.visualize_ROI_data()
# test = ROIPlot(ini_path=r"Z:\DeepLabCut\DLC_extract\Troubleshooting\ROI_2_animals\project_folder\project_config.ini", video_path=r"Z:\DeepLabCut\DLC_extract\Troubleshooting\ROI_2_animals\project_folder\videos\Video7.mp4")
# test.insert_data()
# test.visualize_ROI_data()
#
# test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
# video_path="Together_1.avi",
# style_attr={'Show_body_part': True, 'Show_animal_name': False},
# core_cnt=5)
# test.run()
# test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/troubleshooting/DLC_2_Black_animals/project_folder/project_config.ini',
# video_path="Together_1.avi",
# style_attr={'Show_body_part': True, 'Show_animal_name': False},
# core_cnt=5)
# test.run()