Source code for simba.plotting.pose_plotter_mp

import functools
import multiprocessing
import os
import platform
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

try:
    from typing import Literal
except:
    from typing_extensions import Literal

import cv2
import numpy as np
import pandas as pd

from simba.feature_extractors.perimeter_jit import jitted_centroid
from simba.mixins.config_reader import ConfigReader
from simba.mixins.geometry_mixin import GeometryMixin
from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.checks import (check_if_valid_rgb_tuple, check_instance,
                                check_int, check_nvidea_gpu_available,
                                check_str, check_that_column_exist,
                                check_valid_boolean)
from simba.utils.data import (create_color_palette, get_cpu_pool,
                              terminate_cpu_pool)
from simba.utils.enums import OS, Formats, Options
from simba.utils.errors import CountError, InvalidFilepathError
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
                                    find_core_cnt,
                                    find_files_of_filetypes_in_directory,
                                    get_current_time, get_fn_ext,
                                    get_video_meta_data, read_df)
from simba.utils.warnings import FrameRangeWarning


def pose_plotter_mp(data: pd.DataFrame,
                    video_meta_data: dict,
                    video_path: str,
                    bp_dict: dict,
                    colors_dict: dict,
                    circle_size: int,
                    center_of_mass: Optional[dict],
                    center_of_mass_clr: tuple,
                    bbox: Optional[str],
                    video_save_dir: Union[str, os.PathLike],):
    """
    Render pose-estimation overlays for one chunk of frames and write them to ``<video_save_dir>/<group>.mp4``.

    Worker function used by :class:`PosePlotterMultiProcess`; each call handles one contiguous chunk of pose data.

    :param pd.DataFrame data: Chunk of pose data. Must contain a ``group`` column whose first value names the output file. The remaining index values are used as video frame numbers (inclusive range ``data.index[0]`` .. ``data.index[-1]``).
    :param dict video_meta_data: Video metadata; the ``fps``, ``width`` and ``height`` keys are read.
    :param str video_path: Path to the video the frames are read from.
    :param dict bp_dict: Maps animal name -> dict with ``X_bps`` and ``Y_bps`` lists of body-part column names.
    :param dict colors_dict: Maps animal position (order of ``bp_dict``) -> sequence of colors, one per body-part.
    :param int circle_size: Radius of the body-part circles.
    :param Optional[dict] center_of_mass: If not None, maps animal name -> per-frame (x, y) points indexed by frame number.
    :param tuple center_of_mass_clr: Color of the center-of-mass circle.
    :param Optional[str] bbox: None (no boxes), ``'axis-aligned'`` or ``'animal-aligned'``.
    :param Union[str, os.PathLike] video_save_dir: Directory where the chunk video is written.
    """
    fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
    group_cnt = int(data.iloc[0]["group"])
    data = data.drop(["group"], axis=1)
    start_frm, end_frm = data.index[0], data.index[-1]

    # Column presence does not change between frames: validate once per chunk rather than per frame and body-part.
    for animal_data in bp_dict.values():
        for x_name, y_name in zip(animal_data["X_bps"], animal_data["Y_bps"]):
            check_that_column_exist(df=data, column_name=[x_name, y_name], file_name=video_path)

    geometry = GeometryMixin() if bbox is not None else None
    bbox_thickness = max(1, int(circle_size / 1.5))

    save_path = os.path.join(video_save_dir, f"{group_cnt}.mp4")
    writer = cv2.VideoWriter(save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"], video_meta_data["height"]))
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frm)
        current_frm = start_frm
        # end_frm is the last index of this chunk and is inclusive: the next chunk starts at end_frm + 1,
        # so stopping at end_frm - 1 would drop one frame per chunk from the concatenated video.
        while current_frm <= end_frm:
            ret, img = cap.read()
            if not ret:
                FrameRangeWarning(msg=f'Frame {current_frm} not found in video {video_path}, terminating video creation...', source=pose_plotter_mp.__name__)
                break
            for animal_cnt, (animal_name, animal_data) in enumerate(bp_dict.items()):
                animal_bbox = []
                for cnt, (x_name, y_name) in enumerate(zip(animal_data["X_bps"], animal_data["Y_bps"])):
                    bp_tuple = (int(data.at[current_frm, x_name]), int(data.at[current_frm, y_name]))
                    img = cv2.circle(img, bp_tuple, circle_size, colors_dict[animal_cnt][cnt], -1)
                    animal_bbox.append(list(bp_tuple))
                # NOTE(review): boxes are only drawn for animals with more than 4 body-parts; the reason for this
                # threshold is not visible here.
                if bbox is not None and len(animal_bbox) > 4:
                    if bbox == 'axis-aligned':
                        animal_bbox = geometry.keypoints_to_axis_aligned_bounding_box(keypoints=np.array(animal_bbox).reshape(-1, len(animal_bbox), 2).astype(np.int32))
                    elif bbox == 'animal-aligned':
                        animal_bbox = geometry.minimum_rotated_rectangle(shape=np.array(animal_bbox).reshape(len(animal_bbox), 2), buffer=None)
                        animal_bbox = np.round(np.array(animal_bbox.exterior.coords)).astype(np.int32)
                    # NOTE(review): lineType=-1 (cv2.FILLED) is unusual for an outline; kept to preserve existing rendering.
                    img = cv2.polylines(img, [animal_bbox], True, colors_dict[animal_cnt][0], thickness=bbox_thickness, lineType=-1)
                if center_of_mass is not None:
                    center_point = center_of_mass[animal_name][current_frm]
                    center_point_tuple = (int(center_point[0]), int(center_point[1]))
                    img = cv2.circle(img, center_point_tuple, circle_size + 2, center_of_mass_clr, -1)
            writer.write(img)
            current_frm += 1
            stdout_information(msg=f"[{get_current_time()}] Multi-processing video frame {current_frm} on core {group_cnt}...", source=pose_plotter_mp.__name__)
    finally:
        # Always release the capture and flush the writer, even if drawing fails part-way through.
        cap.release()
        writer.release()
class PosePlotterMultiProcess():
    """
    Create pose-estimation visualizations from data within a SimBA project folder using multiprocessing.

    :param Union[str, os.PathLike] data_path: Path to a SimBA project directory containing pose-estimation data (parquet or CSV), or path to a single pose file. Must be under ``project_folder/csv/`` so that ``project_config.ini`` can be located.
    :param Optional[Union[str, os.PathLike]] out_dir: Directory where pose-estimation videos are saved. If None, saves to a new folder under the input data directory. Default None.
    :param Optional[Dict[str, str]] palettes: Dict mapping animal names to color palette names (e.g. ``{'Animal_1': 'Set1', 'Animal_2': 'Pastel1'}``). Must have one entry per animal; palettes are assigned to animals in dict order. If None, uses project default body-part colors. Default None.
    :param Optional[int] circle_size: Radius of circles drawn at each body-part location. If None, auto-computed from video resolution. Default None.
    :param Optional[int] core_cnt: Number of CPU cores for multiprocessing. -1 uses all available cores. Default -1.
    :param Optional[bool] gpu: If True, use GPU for video concatenation when available. Default False.
    :param Optional[Literal['axis-aligned', 'animal-aligned']] bbox: If not None, draw bounding boxes around each animal. ``'axis-aligned'`` = rectangle aligned with video axes; ``'animal-aligned'`` = minimum rotated rectangle aligned with the animal's orientation. Default None (no bounding boxes).
    :param Optional[Tuple[int, int, int]] center_of_mass: If not None, RGB tuple (e.g. (255, 0, 0)) for drawing a center-of-mass dot per animal. Default None (no center of mass).
    :param Optional[int] sample_time: If not None, render only the first N seconds of each video (N = this value). Useful for quick previews. Default None (full video).
    :param bool verbose: If True, print progress messages during video creation. Default True.

    .. image:: _static/img/pose_plotter.png
       :alt: Pose plotter
       :width: 600
       :align: center

    :example:
    >>> test = PosePlotterMultiProcess(data_path='project_folder/csv/outlier_corrected_movement_location', out_dir='/project_folder/test_viz', circle_size=10, core_cnt=1, palettes={'Animal_1': 'Set1', 'Animal_2': 'Pastel1'})
    >>> test.run()
    """

    def __init__(self,
                 data_path: Union[str, os.PathLike],
                 out_dir: Optional[Union[str, os.PathLike]] = None,
                 palettes: Optional[Dict[str, str]] = None,
                 circle_size: Optional[int] = None,
                 core_cnt: Optional[int] = -1,
                 gpu: Optional[bool] = False,
                 bbox: Optional[Literal['axis-aligned', 'animal-aligned']] = None,
                 center_of_mass: Optional[Tuple[int, int, int]] = None,
                 sample_time: Optional[int] = None,
                 verbose: bool = True) -> None:

        # A directory is expected at project_folder/csv/<sub_dir>; a file at project_folder/csv/<sub_dir>/<file>.
        if os.path.isdir(data_path):
            config_path = os.path.join(Path(data_path).parents[1], 'project_config.ini')
        elif os.path.isfile(data_path):
            config_path = os.path.join(Path(data_path).parents[2], 'project_config.ini')
        else:
            raise InvalidFilepathError(msg=f'{data_path} is not a valid file or directory path.', source=self.__class__.__name__)
        if not os.path.isfile(config_path):
            raise InvalidFilepathError(msg=f'When visualizing pose-estimation, select an input sub-directory of the project_folder/csv folder OR file in the project_folder/csv/ANY_FOLDER directory. {data_path} does not meet these requirements and therefore SimBA cannot locate the project_config.ini (expected at {config_path})', source=self.__class__.__name__)
        self.config = ConfigReader(config_path=config_path, read_video_info=False, create_logger=False)
        if os.path.isdir(data_path):
            files_found = find_files_of_filetypes_in_directory(directory=data_path, extensions=[f'.{self.config.file_type}'], raise_error=True)
        else:
            files_found = [data_path]
        # NOTE: despite its name this holds the project body-part list; kept for backwards compatibility.
        self.animal_bp_dict = self.config.body_parts_lst

        if circle_size is not None:
            check_int(name='circle_size', value=circle_size, min_value=1)
        check_int(name='core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0])
        if core_cnt == -1:
            core_cnt = find_core_cnt()[0]

        # color_dict maps animal position (order of config.animal_bp_dict) -> per-body-part colors.
        self.color_dict = {}
        if palettes is not None:
            check_instance(source=self.__class__.__name__, instance=palettes, accepted_types=(dict,))
            if len(palettes) != self.config.animal_cnt:
                raise CountError(msg=f'The number of color palettes ({len(palettes)}) specified is not the same as the number of animals ({self.config.animal_cnt}) in the SimBA project at {self.config.project_path}', source=self.__class__.__name__)
            # Palettes are assigned by position (dict order), not by matching keys to animal names.
            for cnt, palette_name in enumerate(palettes.values()):
                check_str(name='palette', value=palette_name, options=Options.PALETTE_OPTIONS_CATEGORICAL.value)
                self.color_dict[cnt] = create_color_palette(pallete_name=palette_name, increments=len(self.config.body_parts_lst))
        else:
            for cnt, animal_info in enumerate(self.config.animal_bp_dict.values()):
                self.color_dict[cnt] = animal_info["colors"]

        if bbox is not None:
            check_str(name=f'{self.__class__.__name__} bbox', value=bbox, options=['axis-aligned', 'animal-aligned'], allow_blank=False, raise_error=True)
        check_valid_boolean(value=verbose, source=f'{self.__class__.__name__} verbose', raise_error=True)
        if sample_time is not None:
            check_int(name='sample_time', value=sample_time, min_value=1)
        check_valid_boolean(value=gpu, source=f'{self.__class__.__name__} gpu', raise_error=True)
        if center_of_mass is not None:
            check_if_valid_rgb_tuple(data=center_of_mass, raise_error=True, source=f'{self.__class__.__name__} center_of_mass')

        # All arguments are validated before anything is created on disk.
        if out_dir is None:
            out_dir = os.path.join(os.path.dirname(files_found[0]), f'pose_videos_{self.config.datetime}')
        self.circle_size, self.core_cnt, self.out_dir = circle_size, core_cnt, out_dir
        self.sample_time, self.bbox, self.verbose = sample_time, bbox, verbose
        if not os.path.exists(self.out_dir):
            os.makedirs(self.out_dir)
        self.data, self.center_of_mass = {}, center_of_mass
        self.gpu = True if gpu and check_nvidea_gpu_available() else False
        # self.data maps pose-file path -> path of the matching video.
        for file in files_found:
            self.data[file] = self.config.find_video_of_file(video_dir=self.config.video_dir, filename=get_fn_ext(file)[1])
        if platform.system() == OS.MAC.value:
            multiprocessing.set_start_method("spawn", force=True)

    def _get_center_of_mass(self):
        """
        Compute a per-frame centroid for every animal from ``self.pose_df``.

        :return: Dict mapping animal name -> output of ``jitted_centroid`` over that animal's body-parts, one entry per frame.
        """
        if self.verbose:
            stdout_information(msg='Computing animal centroids...')
        center_of_mass_data = {}
        for animal_name, animal_data in self.config.animal_bp_dict.items():
            # Interleave columns as x1, y1, x2, y2, ... so the reshape below yields (x, y) pairs per body-part.
            animal_data_cols = [x for pair in zip(animal_data["X_bps"], animal_data["Y_bps"]) for x in pair]
            animal_df = self.pose_df[animal_data_cols]
            points = np.reshape(animal_df.values, (len(animal_df), -1, 2)).astype(np.float32)
            center_of_mass_data[animal_name] = jitted_centroid(points=points)
        return center_of_mass_data

    def run(self):
        """Create one pose-estimation video per pose file / video pair found at construction."""
        self.pool = get_cpu_pool(core_cnt=self.core_cnt, source=self.__class__.__name__)
        try:
            for file_cnt, (pose_path, video_path) in enumerate(self.data.items()):
                self._create_video(file_cnt=file_cnt, pose_path=pose_path, video_path=video_path)
        finally:
            # Release the worker processes even if rendering one of the videos fails.
            terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
        self.config.timer.stop_timer()
        stdout_success(f"Pose visualizations for {len(self.data)} video(s) created in {self.out_dir} directory", elapsed_time=self.config.timer.elapsed_time_str, source=self.__class__.__name__)

    def _create_video(self, file_cnt: int, pose_path: str, video_path: str):
        """Render, in parallel chunks, and concatenate the pose video for a single pose file / video pair."""
        video_timer = SimbaTimer(start=True)
        video_name = get_fn_ext(pose_path)[1]
        self.temp_folder = os.path.join(self.out_dir, video_name, "temp")
        if os.path.exists(self.temp_folder):
            self.config.remove_a_folder(self.temp_folder)
        os.makedirs(self.temp_folder, exist_ok=True)
        save_video_path = os.path.join(self.out_dir, f"{video_name}.mp4")
        pose_df = read_df(file_path=pose_path, file_type=self.config.file_type, check_multiindex=True)
        video_meta_data = get_video_meta_data(video_path=video_path)
        if self.circle_size is None:
            video_circle_size = PlottingMixin().get_optimal_circle_size(frame_size=(int(video_meta_data['width']), int(video_meta_data['height'])), circle_frame_ratio=70)
        else:
            video_circle_size = self.circle_size
        if (self.sample_time is None) and (video_meta_data["frame_count"] != len(pose_df)):
            FrameRangeWarning(msg=f'The video {video_name} has pose-estimation data for {len(pose_df)} frames, but the video has {video_meta_data["frame_count"]} frames. Ensure the data and video has an equal number of frames.', source=self.__class__.__name__)
        elif isinstance(self.sample_time, int):
            sample_frm_cnt = min(int(video_meta_data["fps"] * self.sample_time), len(pose_df))
            pose_df = pose_df.iloc[0:sample_frm_cnt]
        if 'input_csv' in os.path.dirname(pose_path):
            pose_df = self.config.insert_column_headers_for_outlier_correction(data_df=pose_df, new_headers=self.config.bp_headers, filepath=pose_path)
        self.pose_df = pose_df.apply(pd.to_numeric, errors="coerce").fillna(0).reset_index(drop=True)
        self.centroid_data = self._get_center_of_mass() if self.center_of_mass is not None else None
        # Split the cleaned frame: workers cast coordinates with int() (NaNs would crash) and use the index as
        # 0-based frame numbers that must line up with the video and with the centroid arrays above.
        pose_lst, obs_per_split = PlottingMixin().split_and_group_df(df=self.pose_df, splits=self.core_cnt)
        frm_cnt = len(self.pose_df)
        if self.verbose:
            stdout_information(msg=f"Creating pose videos, multiprocessing (chunksize: {self.config.multiprocess_chunksize}, cores: {self.core_cnt})...")
        constants = functools.partial(pose_plotter_mp,
                                      video_meta_data=video_meta_data,
                                      video_path=video_path,
                                      bp_dict=self.config.animal_bp_dict,
                                      colors_dict=self.color_dict,
                                      circle_size=video_circle_size,
                                      bbox=self.bbox,
                                      center_of_mass=self.centroid_data,
                                      center_of_mass_clr=self.center_of_mass,
                                      video_save_dir=self.temp_folder)
        for cnt, _ in enumerate(self.pool.imap(constants, pose_lst, chunksize=self.config.multiprocess_chunksize)):
            if self.verbose:
                stdout_information(msg=f"Image {min(frm_cnt, obs_per_split*(cnt+1))}/{frm_cnt}, Video {file_cnt+1}/{len(self.data)}...")
        if self.verbose:
            stdout_information(msg=f"Joining {video_name} multi-processed video...")
        concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=save_video_path, remove_splits=True, gpu=self.gpu)
        video_timer.stop_timer()
        stdout_success(msg=f"Pose video {video_name} complete and saved at {save_video_path}", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__)
# if __name__ == "__main__": # test = PosePlotterMultiProcess(data_path=r"E:\troubleshooting\mitra_emergence\project_folder\csv\outlier_corrected_movement_location\Box1_180mISOcontrol_Females.csv", # out_dir=None, # circle_size=8, # core_cnt=12, # palettes=None, # bbox=True, # center_of_mass=True) # test.run() # test = PosePlotterMultiProcess(data_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_1.csv', # out_dir=None, # circle_size=None, # core_cnt=1, # palettes=None) # test.run() # test = PosePlotter(in_dir='/Users/simon/Desktop/envs/troubleshooting/dam_nest-c-only_ryan/project_folder/csv/outlier_corrected_movement_location', # out_dir='/Users/simon/Desktop/video_tests_', # sample_time=2, # circle_size=10, # core_cnt=1, # color_settings=None) # # test.run() # test = PosePlotter(in_dir='/Users/simon/Desktop/envs/troubleshooting/piotr/project_folder/csv/outlier_corrected_movement_location', # out_dir='/Users/simon/Desktop/envs/troubleshooting/piotr/project_folder/frames/output/test', # circle_size=10, # core_cnt=6, # color_settings={'Animal_1': 'Green', 'Animal_2': 'Red'}) # test.run() # if __name__ == "__main__": # test = PosePlotterMultiProcess(data_path=r"F:\troubleshooting\sam\sam\project_folder\csv\outlier_corrected_movement_location", # out_dir=r'F:\troubleshooting\sam\sam\project_folder\frames\output\pose_validation_2', # circle_size=8, # core_cnt=5, # palettes={'Animal_1': 'Set1'}, # sample_time=60, # bbox=None, #'animal-aligned', # center_of_mass=None) # test.run()