"""Source code for simba.plotting.roi_plotter_mp: multiprocess visualization of ROI entries and time spent in ROIs."""

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import functools
import itertools
import multiprocessing
import os
import platform
import shutil
from typing import Dict, List, Optional, Tuple, Union

try:
    from typing import Literal
except:
    from typing_extensions import Literal

import cv2
import numpy as np
import pandas as pd

from simba.mixins.config_reader import ConfigReader
from simba.mixins.geometry_mixin import GeometryMixin
from simba.mixins.plotting_mixin import PlottingMixin
from simba.roi_tools.roi_aggregate_statistics_analyzer import \
    ROIAggregateStatisticsAnalyzer
from simba.roi_tools.roi_utils import get_roi_dict_from_dfs
from simba.utils.checks import (check_file_exist_and_readable, check_float,
                                check_if_dir_exists, check_if_valid_rgb_tuple,
                                check_int, check_nvidea_gpu_available,
                                check_str, check_valid_boolean,
                                check_valid_lst,
                                check_video_and_data_frm_count_align)
from simba.utils.data import (create_color_palettes, detect_bouts,
                              get_cpu_pool, slice_roi_dict_for_video,
                              terminate_cpu_pool)
from simba.utils.enums import (ROI_SETTINGS, Formats, Keys, Options, Paths,
                               TextOptions)
from simba.utils.errors import (BodypartColumnNotFoundError, DuplicationError,
                                NoFilesFoundError, NoROIDataError,
                                ROICoordinatesNotFoundError)
from simba.utils.lookups import get_simba_font_name_and_path
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
                                    find_core_cnt, get_video_meta_data,
                                    read_df, seconds_to_timestamp)
from simba.utils.warnings import (DuplicateNamesWarning, FrameRangeWarning,
                                  GPUToolsWarning)

# Disable pandas' chained-assignment warning for the whole process (kept as a module-level side effect).
pd.options.mode.chained_assignment = None

# Accepted values for the ``print_timer`` argument: plain seconds or a clock-style timestamp.
SECONDS = 'seconds'
HHMMSSSSSS = 'hh:mm:ss.ssss'



def _roi_plotter_mp(data: Tuple[int, pd.DataFrame],
                    loc_dict: dict,
                    font_size: float,
                    circle_sizes: list,
                    save_temp_directory: str,
                    video_shape_names: list,
                    input_video_path: str,
                    body_part_dict: dict,
                    roi_dfs_dict: Dict[str, pd.DataFrame],
                    roi_dict: dict,
                    bp_colors: list,
                    show_animal_name: bool,
                    font_path: Optional[str],
                    font_size_px: Optional[int],
                    show_pose: bool,
                    animal_ids: list,
                    threshold: float,
                    outside_roi: bool,
                    verbose: bool,
                    print_timer: str,
                    border_bg_clr: tuple,
                    animal_bp_dict: dict,
                    bbox: Optional[str]) -> int:
    """
    Multiprocessing worker: render one contiguous chunk of frames of the ROI visualization.

    Each frame is widened by ``width`` pixels on the right (filled with ``border_bg_clr``). Static ROI labels, the
    ROI shapes, optional pose / animal-name / bounding-box overlays and the cumulative timer and entry counters are
    drawn. The result is written to ``<save_temp_directory>/<group_cnt>.mp4``.

    :param data: Tuple of (chunk index, DataFrame slice). The DataFrame index values are used as video frame numbers.
    :param loc_dict: Per-animal, per-shape text strings and pixel locations for the statistics panel.
    :param font_size: OpenCV font scale; replaced by ``font_size_px`` when ``font_path`` is given.
    :param circle_sizes: Per-animal pose-circle radius, also used as the bounding-box line thickness.
    :param save_temp_directory: Directory receiving the chunk video.
    :param video_shape_names: Shape names (optionally including the outside-ROI pseudo-shape) to print counters for.
    :param input_video_path: Path to the source video.
    :param body_part_dict: Per-animal list of three columns, unpacked as x, y and detection probability.
    :param roi_dfs_dict: ROI definitions passed to ``PlottingMixin.roi_dict_onto_img``.
    :param roi_dict: Shape name -> shape attributes (must contain ``'Color BGR'``).
    :param bp_colors: Per-animal color used for pose, name and bounding box.
    :param threshold: Minimum body-part probability required to draw pose / name / bounding box.
    :param outside_roi: If True, also print the static labels of the outside-ROI pseudo-shape.
    :param verbose: If True, print per-frame progress.
    :param print_timer: ``'seconds'`` or ``'hh:mm:ss.ssss'`` formatting of the timer counters.
    :param border_bg_clr: Color of the appended statistics panel.
    :param animal_bp_dict: Per-animal dict with ``'X_bps'`` / ``'Y_bps'`` column lists (used for bounding boxes).
    :param bbox: None, or the bounding-box style (``Options.AXIS_ALIGNED`` value, otherwise minimum rotated rectangle).
    :return: The chunk index ``group_cnt``.
    """
    # A TTF font is sized in pixels, the built-in OpenCV font by scale.
    font_size = font_size_px if font_path is not None else font_size
    outside_roi_name = ROI_SETTINGS.OUTSIDE_ROI.value

    def _put_text(img: np.ndarray, text: str, pos, color) -> np.ndarray:
        """Draw ``text`` at ``pos`` using the shared font settings of this worker."""
        return PlottingMixin().put_text(img=img, text=text, pos=pos, font_size=font_size, font=TextOptions.FONT.value, font_thickness=TextOptions.TEXT_THICKNESS.value, font_path=font_path, text_color=color, text_bg_alpha=0.0)

    def _insert_static_texts(img: np.ndarray) -> np.ndarray:
        """Draw the static "<shape> <animal> timer:/entries:" labels for EVERY animal."""
        for animal_name in animal_ids:
            animal_locs = loc_dict[animal_name]
            for shape_name, shape_data in roi_dict.items():
                shape_clr = tuple(int(v) for v in shape_data['Color BGR'])
                img = _put_text(img, animal_locs[shape_name]["timer_text"], animal_locs[shape_name]["timer_text_loc"], shape_clr)
                img = _put_text(img, animal_locs[shape_name]["entries_text"], animal_locs[shape_name]["entries_text_loc"], shape_clr)
            if outside_roi:
                outside_locs = animal_locs[outside_roi_name]
                img = _put_text(img, outside_locs["timer_text"], outside_locs["timer_text_loc"], TextOptions.WHITE.value)
                img = _put_text(img, outside_locs["entries_text"], outside_locs["entries_text_loc"], TextOptions.WHITE.value)
        # Return only after all animals were processed (previously returned after the first animal).
        return img

    fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
    group_cnt, data_df = data[0], data[1]
    df_frm_range = data_df.index.tolist()
    start_frm, current_frm, end_frm = df_frm_range[0], df_frm_range[0], df_frm_range[-1]
    save_path = os.path.join(save_temp_directory, f"{group_cnt}.mp4")
    video_meta_data = get_video_meta_data(video_path=input_video_path)
    # Output is twice as wide: original frame on the left, statistics panel on the right.
    writer = cv2.VideoWriter(save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"] * 2, video_meta_data["height"]))
    cap = cv2.VideoCapture(input_video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frm)
        while current_frm <= end_frm:
            ret, img = cap.read()
            if not ret:
                FrameRangeWarning(msg=f'Could not read frame {current_frm} in video {video_meta_data["video_name"]}', source=_roi_plotter_mp.__name__)
                break
            img = cv2.copyMakeBorder(img, 0, 0, 0, int(video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=border_bg_clr)
            img = _insert_static_texts(img)
            img = PlottingMixin.roi_dict_onto_img(img=img, roi_dict=roi_dfs_dict)
            for animal_cnt, animal_name in enumerate(animal_ids):
                animal_clr = bp_colors[animal_cnt]
                if show_animal_name or show_pose or bbox is not None:
                    # Read as float: casting the whole row to int would truncate the probability (e.g. 0.9 -> 0),
                    # making the threshold test fail for every non-perfect detection.
                    x_val, y_val, p = data_df.loc[current_frm, body_part_dict[animal_name]].fillna(0.0).values.astype(np.float64)
                    x, y = int(x_val), int(y_val)
                    if threshold <= p:
                        if show_pose:
                            img = cv2.circle(img, (x, y), circle_sizes[animal_cnt], animal_clr, -1)
                        if show_animal_name:
                            img = _put_text(img, animal_name, (x, y), tuple(int(v) for v in animal_clr))
                        if bbox is not None:
                            x_cols, y_cols = animal_bp_dict[animal_name]['X_bps'], animal_bp_dict[animal_name]['Y_bps']
                            animal_cols = [col for pair in zip(x_cols, y_cols) for col in pair]
                            animal_cords = data_df.loc[current_frm, animal_cols].fillna(0.0).values.astype(np.int32).reshape(-1, 2)
                            try:
                                if bbox == Options.AXIS_ALIGNED.value:
                                    animal_bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
                                else:
                                    animal_bbox = GeometryMixin().minimum_rotated_rectangle(shape=animal_cords, buffer=None, return_type='array')
                                img = cv2.polylines(img, [animal_bbox], True, animal_clr, thickness=circle_sizes[animal_cnt], lineType=cv2.LINE_AA)
                            except Exception:
                                # Deliberate best-effort: a bounding box that cannot be computed/drawn for this
                                # frame (NOTE(review): presumably degenerate keypoints) is skipped, not fatal.
                                pass
                for shape_name in video_shape_names:
                    shape_color = TextOptions.WHITE.value if shape_name == outside_roi_name else roi_dict[shape_name]["Color BGR"]
                    text_clr = tuple(int(v) for v in shape_color)
                    timer = round(data_df.loc[current_frm, f"{animal_name}_{shape_name}_cum_sum_time"], 2)
                    if print_timer == HHMMSSSSSS:
                        timer = seconds_to_timestamp(seconds=timer, hh_mm_ss_sss=True)
                    entries = data_df.loc[current_frm, f"{animal_name}_{shape_name}_cum_sum_entries"]
                    img = _put_text(img, str(timer), loc_dict[animal_name][shape_name]["timer_data_loc"], text_clr)
                    img = _put_text(img, str(entries), loc_dict[animal_name][shape_name]["entries_data_loc"], text_clr)

            writer.write(img)
            current_frm += 1
            if verbose:
                stdout_information(msg=f"Multi-processing video frame {current_frm}/{video_meta_data['frame_count']} (core batch: {group_cnt}, video: {video_meta_data['video_name']})...")
    finally:
        # Always release the handles so the chunk file is finalized even if drawing fails.
        cap.release()
        writer.release()
    return group_cnt

class ROIPlotMultiprocess(ConfigReader, PlottingMixin):
    """
    Visualize the ROI data (number of entries/exits, time-spent in ROIs) using multiprocessing for improved performance.

    The pose data is split into up to ``core_cnt`` contiguous frame chunks. Each chunk is rendered by a worker process
    into a temporary clip (original frame on the left, statistics panel on the right), and the clips are then
    concatenated into ``save_path``.

    .. note::
       `ROI tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/ROI_tutorial_new.md>`__.

    .. image:: _static/img/roi_visualize.png
       :alt: Roi visualize
       :width: 400
       :align: center

    .. image:: _static/img/ROIPlot_1.png
       :alt: ROIPlot 1
       :width: 1000
       :align: center

    .. video:: _static/img/ROIPlot_2.webm
       :width: 1000
       :autoplay:
       :loop:
       :muted:
       :align: center

    .. video:: _static/img/outside_roi_example.mp4
       :width: 800
       :autoplay:
       :loop:
       :muted:
       :align: center

    :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
    :param Union[str, os.PathLike] video_path: Path to video file to create ROI visualizations for.
    :param List[str] body_parts: List of the body-parts to use as proxy for animal locations.
    :param float threshold: Float between 0 and 1. Body-part locations detected below this confidence threshold are filtered. Default: 0.0.
    :param int core_cnt: Number of cores to use for multiprocessing. Default: -1 (uses all available cores).
    :param bool verbose: If True, print progress messages during video creation. Default: True.
    :param bool outside_roi: If True, SimBA will treat all areas NOT covered by a ROI drawing as a single additional ROI and visualize the stats for this single ROI. Default: False.
    :param bool show_body_part: If True, display body-part locations as circles on the video frames. Default: True.
    :param Optional[str] font: Optional SimBA font name, resolved with ``get_simba_font_name_and_path``. If None, the default ``TextOptions.FONT`` is used. Default: None.
    :param bool show_animal_name: If True, display animal names on the video frames. Default: False.
    :param Optional[Literal['axis-aligned', 'animal-aligned']] bbox: If not None, draw bounding boxes around each animal. ``'axis-aligned'`` = rectangle aligned with video axes; ``'animal-aligned'`` = minimum rotated rectangle aligned with the animal's orientation. Default: None (no bounding boxes).
    :param Literal['seconds', 'hh:mm:ss.ssss'] print_timer: Timer format for behavior/ROI counters shown in the border panel. ``'seconds'`` = numeric seconds, ``'hh:mm:ss.ssss'`` = clock-style timestamp with fractional seconds. Default: ``'seconds'``.
    :param Tuple[int, int, int] border_bg_clr: RGB tuple representing the background color of the border area where statistics are displayed. Default: (0, 0, 0).
    :param Optional[Union[str, os.PathLike]] data_path: Optional path to the pose-estimation data. If None, then locates file in ``outlier_corrected_movement_location`` directory. Default: None.
    :param Optional[Union[str, os.PathLike]] save_path: Optional path to where to save video. If None, then saves it in ``frames/output/roi_analysis`` directory of SimBA project. Default: None.
    :param Optional[List[Tuple[int, int, int]]] bp_colors: Optional list of tuples of same length as body_parts representing the colors of the body-parts in RGB format. Defaults to None and colors are automatically chosen. Default: None.
    :param Optional[List[Union[int]]] bp_sizes: Optional list of integers representing the sizes of the pose estimated body-part location circles. Defaults to None and size is automatically inferred. Default: None.
    :param bool gpu: If True, use GPU acceleration for video concatenation. Default: False.

    :example:
    >>> test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
    ...                            video_path="/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/videos/SI_DAY3_308_CD1_PRESENT.mp4",
    ...                            core_cnt=7,
    ...                            body_parts=['Nose'],
    ...                            show_body_part=True,
    ...                            show_animal_name=False)
    >>> test.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 video_path: Union[str, os.PathLike],
                 body_parts: List[str],
                 threshold: Optional[float] = 0.0,
                 core_cnt: int = -1,
                 verbose: bool = True,
                 outside_roi: bool = False,
                 show_body_part: bool = True,
                 font: Optional[str] = None,
                 show_animal_name: bool = False,
                 bbox: Optional[Literal['axis-aligned', 'animal-aligned']] = None,
                 print_timer: Literal['seconds', 'hh:mm:ss.ssss'] = 'seconds',
                 border_bg_clr: Tuple[int, int, int] = (0, 0, 0),
                 data_path: Optional[Union[str, os.PathLike]] = None,
                 save_path: Optional[Union[str, os.PathLike]] = None,
                 bp_colors: Optional[List[Tuple[int, int, int]]] = None,
                 bp_sizes: Optional[List[Union[int]]] = None,
                 gpu: bool = False):

        check_file_exist_and_readable(file_path=config_path)
        ConfigReader.__init__(self, config_path=config_path)
        self.video_meta_data = get_video_meta_data(video_path=video_path)
        self.video_name = self.video_meta_data['video_name']
        check_float(name=f'{self.__class__.__name__} threshold', value=threshold, min_value=0.0, max_value=1.0)
        check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0, ])
        check_if_valid_rgb_tuple(source=f'{self.__class__.__name__} border_bg_clr', data=border_bg_clr, raise_error=True)
        check_valid_boolean(value=[gpu], source=f'{self.__class__.__name__} gpu', raise_error=True)
        check_valid_boolean(value=[outside_roi], source=f'{self.__class__.__name__} outside_roi', raise_error=True)
        check_valid_boolean(value=[verbose], source=f'{self.__class__.__name__} verbose', raise_error=True)
        check_valid_boolean(value=show_body_part, source=f'{self.__class__.__name__} show_body_part', raise_error=True)
        check_valid_boolean(value=show_animal_name, source=f'{self.__class__.__name__} show_animal_name', raise_error=True)
        check_str(name=f'{self.__class__.__name__} timer', value=print_timer, options=(SECONDS, HHMMSSSSSS,))
        self.font_path, self.font = None, None
        if font is not None:
            self.font, self.font_path = get_simba_font_name_and_path(font=font)
        if bbox is not None:
            check_str(name=f'{self.__class__.__name__} bbox', value=bbox, options=Options.BBOX_OPTIONS.value, allow_blank=False, raise_error=True)
        if gpu and not check_nvidea_gpu_available():
            GPUToolsWarning(msg='GPU not detected but GPU set to True - skipping GPU use.')
            gpu = False
        max_core_cnt = find_core_cnt()[0]
        self.core_cnt = max_core_cnt if core_cnt == -1 or core_cnt > max_core_cnt else core_cnt
        if not os.path.isfile(self.roi_coordinates_path):
            raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path)
        self.read_roi_data()
        self.sliced_roi_dict, self.shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name)
        if len(self.shape_names) == 0:
            raise NoROIDataError(msg=f"Cannot plot ROI data for video {self.video_name}. No ROIs defined for this video.")
        if data_path is None:
            data_path = os.path.join(self.outlier_corrected_dir, f'{self.video_name}.{self.file_type}')
        elif not os.path.isfile(data_path):
            raise NoFilesFoundError(msg=f"SIMBA ERROR: Could not find the file at path {data_path}. Make sure the data file exist to create ROI visualizations", source=self.__class__.__name__)
        check_file_exist_and_readable(file_path=data_path)
        if save_path is None:
            save_path = os.path.join(self.project_path, Paths.ROI_ANALYSIS.value, f'{self.video_name}.mp4')
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        else:
            check_if_dir_exists(os.path.dirname(save_path))
        self.save_path, self.data_path = save_path, data_path
        check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,), min_len=1)
        if outside_roi:
            self.shape_names.append(ROI_SETTINGS.OUTSIDE_ROI.value)
        if len(set(body_parts)) != len(body_parts):
            raise DuplicationError(msg=f'All body-part entries have to be unique. Got {body_parts}', source=self.__class__.__name__)
        for bp in body_parts:
            if bp not in self.body_parts_lst:
                raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__)
        self.roi_analyzer = ROIAggregateStatisticsAnalyzer(config_path=self.config_path, data_path=self.data_path, detailed_bout_data=True, threshold=threshold, body_parts=body_parts, outside_rois=outside_roi, verbose=verbose)
        self.roi_analyzer.run()
        if bp_colors is not None:
            check_valid_lst(data=bp_colors, source=f'{self.__class__.__name__} bp_colors', valid_dtypes=(tuple,), exact_len=len(body_parts), raise_error=True)
            # Use keywords (as for border_bg_clr) so each color is validated and an invalid one raises.
            for clr in bp_colors:
                check_if_valid_rgb_tuple(source=f'{self.__class__.__name__} bp_colors', data=clr, raise_error=True)
            self.color_lst = bp_colors
        else:
            self.color_lst = create_color_palettes(self.roi_analyzer.animal_cnt, len(body_parts))[0]
        self.bp_sizes = bp_sizes
        try:
            self.detailed_roi_data = pd.concat(self.roi_analyzer.detailed_dfs, axis=0).reset_index(drop=True)
        except ValueError:
            # pd.concat raises ValueError for an empty list, i.e. no ROI bouts were detected.
            self.detailed_roi_data = None
        self.bp_dict = self.roi_analyzer.bp_dict
        self.animal_names = [self.find_animal_name_from_body_part_name(bp_name=x, bp_dict=self.animal_bp_dict) for x in body_parts]
        self.data_df = read_df(file_path=self.data_path, file_type=self.file_type).fillna(0.0).reset_index(drop=True)
        # One 0/1 "inside ROI" column per (animal, shape) pair, filled in by __get_roi_columns.
        self.shape_columns = []
        for animal_name, shape_name in itertools.product(self.animal_names, self.shape_names):
            col_name = f"{animal_name}_{shape_name}"
            self.data_df[col_name] = 0
            self.shape_columns.append(col_name)
        self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
        self.video_path = video_path
        check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.video_name, raise_error=False)
        self.cap = cv2.VideoCapture(self.video_path)
        self.threshold, self.body_parts, self.show_animal_name, self.gpu, self.outside_roi, self.verbose, self.border_bg_clr = threshold, body_parts, show_animal_name, gpu, outside_roi, verbose, border_bg_clr
        self.show_pose, self.bbox, self.print_timer = show_body_part, bbox, print_timer
        self.roi_dict_ = get_roi_dict_from_dfs(rectangle_df=self.sliced_roi_dict[Keys.ROI_RECTANGLES.value], circle_df=self.sliced_roi_dict[Keys.ROI_CIRCLES.value], polygon_df=self.sliced_roi_dict[Keys.ROI_POLYGONS.value])
        # Fresh temp folder for the per-core chunk videos.
        self.temp_folder = os.path.join(os.path.dirname(self.save_path), self.video_name, "temp")
        if os.path.exists(self.temp_folder):
            shutil.rmtree(self.temp_folder)
        os.makedirs(self.temp_folder)
        if platform.system() == "Darwin":
            multiprocessing.set_start_method("spawn", force=True)

    def __get_circle_sizes(self):
        """Set ``self.circle_sizes``: one pose-circle radius per animal (user sizes, or an automatic size)."""
        optimal_circle_size = PlottingMixin().get_optimal_circle_size(frame_size=(int(self.video_meta_data["width"]), int(self.video_meta_data["height"])), circle_frame_ratio=100)
        if self.bp_sizes is None:
            self.circle_sizes = [optimal_circle_size] * len(self.animal_names)
        else:
            self.circle_sizes = []
            for circle_size in self.bp_sizes:
                # Invalid user sizes fall back to the automatic size rather than raising.
                if not check_int(name='circle_size', value=circle_size, min_value=1, raise_error=False)[0]:
                    self.circle_sizes.append(optimal_circle_size)
                else:
                    self.circle_sizes.append(int(circle_size))

    def __get_roi_columns(self):
        """Set the "<animal>_<shape>" column to 1 for every frame inside a detected ROI bout (inclusive range)."""
        if self.detailed_roi_data is None:
            return
        frm_idx = self.data_df.index
        roi_entries = self.detailed_roi_data[["ANIMAL", "Event", "Start_frame", "End_frame"]].to_dict(orient="records")
        for entry_dict in roi_entries:
            entry_frm, exit_frm = int(entry_dict["Start_frame"]), int(entry_dict["End_frame"])
            col_name = f'{entry_dict["ANIMAL"]}_{entry_dict["Event"]}'
            # Single .loc assignment: chained `df[col][mask] = 1` does not write back under pandas Copy-on-Write.
            self.data_df.loc[(frm_idx >= entry_frm) & (frm_idx <= exit_frm), col_name] = 1

    def __get_text_locs(self) -> dict:
        """Compute font sizing and the panel pixel locations of every label and counter, per animal and shape."""
        loc_dict = {}
        label_strs = [f'{shape} {animal_name} timer:' for animal_name in self.animal_names for shape in self.shape_names] + [f'{shape} {animal_name} entries:' for animal_name in self.animal_names for shape in self.shape_names]
        longest_text_str = max(label_strs, key=len)
        width, height = self.video_meta_data["width"], self.video_meta_data["height"]
        self.font_size, x_spacer, y_spacer = PlottingMixin().get_optimal_font_scales(text=longest_text_str, accepted_px_width=int(width / 1.5), accepted_px_height=int(height / 10), text_thickness=TextOptions.TEXT_THICKNESS.value)
        if self.font_path is not None:
            self.video_font_size_px, x_spacer, _ = PlottingMixin().get_optimal_font_size_ttf(text=label_strs, font_path=self.font_path, accepted_px_width=int(width / 1.5), accepted_px_height=int(height / 10))
            y_spacer = self.get_optimal_font_spacing_ttf(font_path=self.font_path, size_px=self.video_font_size_px, text=label_strs, gap=0)
        else:
            self.video_font_size_px = None
        label_x = width + TextOptions.BORDER_BUFFER_X.value
        data_x = int(width + x_spacer + TextOptions.BORDER_BUFFER_X.value)
        row_counter = TextOptions.FIRST_LINE_SPACING.value
        for animal_name in self.animal_names:
            loc_dict[animal_name] = {}
            for shape in self.shape_names:
                shape_locs = {}
                shape_locs["timer_text"] = f"{shape} {animal_name} timer:"
                shape_locs["entries_text"] = f"{shape} {animal_name} entries:"
                row_y = height - (height + TextOptions.BORDER_BUFFER_Y.value) + y_spacer * row_counter
                shape_locs["timer_text_loc"] = (label_x, row_y)
                shape_locs["timer_data_loc"] = (data_x, row_y)
                row_counter += 1
                row_y = height - (height + TextOptions.BORDER_BUFFER_Y.value) + y_spacer * row_counter
                shape_locs["entries_text_loc"] = (label_x, row_y)
                shape_locs["entries_data_loc"] = (data_x, row_y)
                row_counter += 1
                loc_dict[animal_name][shape] = shape_locs
        return loc_dict

    def __get_counters(self) -> dict:
        """Return zeroed timer / entry counters and entry status per animal and shape."""
        cnt_dict = {}
        for animal_name in self.animal_names:
            cnt_dict[animal_name] = {}
            for shape in self.shape_names:
                cnt_dict[animal_name][shape] = {"timer": 0, "entries": 0, "entry_status": False}
        return cnt_dict

    def __get_cumulative_data(self):
        """Add cumulative time-in-ROI (seconds) and cumulative entry-count columns for every animal/shape pair."""
        fps = self.video_meta_data['fps']
        for animal_name in self.animal_names:
            for shape in self.shape_names:
                in_roi_col = f"{animal_name}_{shape}"
                self.data_df[f"{in_roi_col}_cum_sum_time"] = self.data_df[in_roi_col].cumsum() / fps
                roi_bouts = list(detect_bouts(data_df=self.data_df, target_lst=[in_roi_col], fps=fps)["Start_frame"])
                self.data_df[f"{in_roi_col}_entry"] = 0
                self.data_df.loc[roi_bouts, f"{in_roi_col}_entry"] = 1
                self.data_df[f"{in_roi_col}_cum_sum_entries"] = self.data_df[f"{in_roi_col}_entry"].cumsum()

    def __create_shape_dicts(self):
        """Merge all ROI shape tables into one ``name -> attributes`` dict, keeping the first of duplicate names."""
        shape_dicts = {}
        for shape, df in self.roi_dict.items():
            if not df["Name"].is_unique:
                df = df.drop_duplicates(subset=["Name"], keep="first")
                DuplicateNamesWarning(msg=f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', source=self.__class__.__name__)
            d = df.set_index("Name").to_dict(orient="index")
            shape_dicts = {**shape_dicts, **d}
        return shape_dicts

    def __get_bordered_img_size(self) -> Tuple[int, int]:
        """Return (height, width) of a frame after the statistics panel (``width`` px) is appended on the right."""
        cap = cv2.VideoCapture(self.video_path)
        try:
            ret, img = cap.read()
        finally:
            cap.release()
        if not ret or img is None:
            # Unreadable frame: derive the size from metadata instead of crashing in copyMakeBorder.
            return int(self.video_meta_data["height"]), int(self.video_meta_data["width"]) * 2
        bordered_img = cv2.copyMakeBorder(img, 0, 0, 0, int(self.video_meta_data["width"]), borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
        return bordered_img.shape[0], bordered_img.shape[1]

    def run(self):
        """Render the ROI visualization in parallel chunks and concatenate them into ``self.save_path``."""
        check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.video_name, raise_error=False)
        video_timer = SimbaTimer(start=True)
        self.__get_circle_sizes()
        self.__get_roi_columns()
        self.border_img_h, self.border_img_w = self.__get_bordered_img_size()
        self.loc_dict = self.__get_text_locs()
        self.cnt_dict = self.__get_counters()
        self.__get_cumulative_data()
        # Split row positions (not the DataFrame itself) and drop empty chunks: a worker cannot process an empty slice.
        chunk_idxs = [idx for idx in np.array_split(np.arange(len(self.data_df)), self.core_cnt) if len(idx) > 0]
        data = [(i, self.data_df.iloc[idx[0]: idx[-1] + 1]) for i, idx in enumerate(chunk_idxs)]
        del self.data_df
        # NOTE(review): presumably removes a logger handle that should not outlive/accompany the analysis — confirm.
        del self.roi_analyzer.logger
        if self.verbose:
            stdout_information(msg=f"Creating ROI images, multiprocessing (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})...")
        self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=self.maxtasksperchild, source=self.__class__.__name__, verbose=True)
        constants = functools.partial(_roi_plotter_mp,
                                      loc_dict=self.loc_dict,
                                      font_size=self.font_size,
                                      circle_sizes=self.circle_sizes,
                                      save_temp_directory=self.temp_folder,
                                      body_part_dict=self.bp_dict,
                                      input_video_path=self.video_path,
                                      roi_dfs_dict=self.sliced_roi_dict,
                                      roi_dict=self.roi_dict_,
                                      font_path=self.font_path,
                                      font_size_px=self.video_font_size_px,
                                      video_shape_names=self.shape_names,
                                      bp_colors=self.color_lst,
                                      print_timer=self.print_timer,
                                      show_animal_name=self.show_animal_name,
                                      show_pose=self.show_pose,
                                      animal_ids=self.animal_names,
                                      threshold=self.threshold,
                                      outside_roi=self.outside_roi,
                                      verbose=self.verbose,
                                      border_bg_clr=self.border_bg_clr,
                                      animal_bp_dict=self.animal_bp_dict,
                                      bbox=self.bbox)
        try:
            for batch_cnt in self.pool.imap(constants, data, chunksize=self.multiprocess_chunksize):
                if self.verbose:
                    stdout_information(msg=f'Image batch {batch_cnt+1} / {len(data)} complete...')
        finally:
            # Release workers and the capture opened in __init__ even if rendering fails.
            terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
            self.cap.release()
        if self.verbose:
            stdout_information(msg=f"Joining {self.video_name} multi-processed ROI video...")
        concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_path, video_format="mp4", remove_splits=True, gpu=self.gpu, verbose=self.verbose)
        video_timer.stop_timer()
        if self.verbose:
            stdout_success(msg=f"Video {self.video_name} created. ROI video saved at {self.save_path}", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__, )
# Usage examples (commented out).
# NOTE: several examples below use the obsolete ``ini_path`` / ``style_attr`` arguments; the current signature uses
# ``config_path``, ``show_body_part`` and ``show_animal_name`` instead.
#
# if __name__ == "__main__":
#     test = ROIPlotMultiprocess(config_path=r"G:\projects\sleap_bp_order\by_order\project_folder\project_config.ini",
#                                video_path=r"G:\projects\sleap_bp_order\by_order\project_folder\videos\rat_pilot.mp4",
#                                body_parts=['nose'],
#                                outside_roi=True,
#                                gpu=True,
#                                font='poppins regular',
#                                core_cnt=8,
#                                border_bg_clr=(0, 0, 0),
#                                bbox=None)
#     test.run()

# if __name__ == "__main__":
#     test = ROIPlotMultiprocess(config_path=r"C:\troubleshooting\roi_duplicates\project_folder\project_config.ini",
#                                video_path=r"C:\troubleshooting\roi_duplicates\project_folder\videos\2021-12-21_15-03-57_CO_Trimmed.mp4",
#                                body_parts=['Snout'],
#                                style_attr={'show_body_part': True, 'show_animal_name': False},
#                                bp_sizes=[20],
#                                bp_colors=[(155, 255, 243)])
#     test.run()
#
# if __name__ == '__main__':
#     test = ROIPlotMultiprocess(config_path=r"C:\troubleshooting\platea\project_folder\project_config.ini",
#                                video_path=r"C:\troubleshooting\platea\project_folder\videos\Video_1.mp4",
#                                body_parts=['NOSE'],
#                                style_attr={'show_body_part': True, 'show_animal_name': False})
#     test.run()

# test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/project_config.ini',
#                            video_path="/Users/simon/Desktop/envs/simba/troubleshooting/open_field_below/project_folder/videos/raw_clip1.mp4",
#                            body_parts=['Snout'],
#                            style_attr={'show_body_part': True, 'show_animal_name': False})
# test.run()

# test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
#                            video_path="/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/videos/2022-06-20_NOB_DOT_4.mp4",
#                            body_parts=['Nose'],
#                            style_attr={'show_body_part': True, 'show_animal_name': False})
# test.run()

# test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
#                            video_path="/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/videos/SI_DAY3_308_CD1_PRESENT.mp4",
#                            core_cnt=-1,
#                            style_attr={'show_body_part': True, 'show_animal_name': False},
#                            body_parts=['Nose'])
# test.run()

# test = ROIPlotMultiprocess(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
#                            video_path="/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.avi",
#                            core_cnt=7,
#                            style_attr={'show_body_part': True, 'show_animal_name': True},
#                            body_parts=['Nose_1', 'Nose_2'])
# test.run()
#
# test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
#                            video_path="2022-06-20_NOB_DOT_4.mp4",
#                            core_cnt=7,
#                            style_attr={'Show_body_part': True, 'Show_animal_name': True},
#                            body_parts={'Animal_1': 'Nose'})
# test.run()
#
# test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/spontenous_alternation/project_folder/project_config.ini',
#                            video_path="F1 HAB.mp4",
#                            core_cnt=5,
#                            style_attr={'Show_body_part': True, 'Show_animal_name': True})
# test.run()
#
# get_video_meta_data(video_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/frames/output/ROI_analysis/2022-06-20_NOB_DOT_4.mp4')
# get_video_meta_data(video_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.avi')

# test = ROIPlot(ini_path=r'/Users/simon/Desktop/troubleshooting/train_model_project/project_folder/project_config.ini', video_path=r"Together_1.avi")
# test.insert_data()
# test.visualize_ROI_data()

# test = ROIPlot(ini_path=r"Z:\DeepLabCut\DLC_extract\Troubleshooting\ROI_2_animals\project_folder\project_config.ini", video_path=r"Z:\DeepLabCut\DLC_extract\Troubleshooting\ROI_2_animals\project_folder\videos\Video7.mp4")
# test.insert_data()
# test.visualize_ROI_data()
#
# test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
#                            video_path="Together_1.avi",
#                            style_attr={'Show_body_part': True, 'Show_animal_name': False},
#                            core_cnt=5)
# test.run()

# test = ROIPlotMultiprocess(ini_path=r'/Users/simon/Desktop/envs/troubleshooting/DLC_2_Black_animals/project_folder/project_config.ini',
#                            video_path="Together_1.avi",
#                            style_attr={'Show_body_part': True, 'Show_animal_name': False},
#                            core_cnt=5)
# test.run()