Source code for simba.plotting.distance_plotter_mp

__author__ = "Simon Nilsson; sronilsson@gmail.com"
import functools
import multiprocessing
import os
import platform
from typing import Dict, List, Optional, Union

import cv2
import numpy as np
from numba import jit

from simba.mixins.config_reader import ConfigReader
from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.checks import (
    check_all_file_names_are_represented_in_video_log,
    check_file_exist_and_readable, check_instance, check_int,
    check_valid_boolean, check_valid_lst)
from simba.utils.data import terminate_cpu_pool
from simba.utils.errors import (CountError, InvalidInputError,
                                NoSpecifiedOutputError)
from simba.utils.lookups import get_color_dict
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
                                    find_core_cnt, get_fn_ext, read_df)


[docs]def distance_plotter_mp( frm_cnts: np.array, distances: np.ndarray, colors: List[str], video_setting: bool, frame_setting: bool, video_name: str, video_save_dir: str, frame_folder_dir: str, style_attr: dict, fps: int, ): group = int(distances[frm_cnts[0], 0]) video_writer = None if video_setting: fourcc = cv2.VideoWriter_fourcc(*"DIVX") temp_video_save_path = os.path.join(video_save_dir, f"{group}.avi") video_writer = cv2.VideoWriter( temp_video_save_path, fourcc, fps, (style_attr["width"], style_attr["height"]), ) for frm_cnt in frm_cnts: line_data = distances[:frm_cnt, 1:] line_data = np.hsplit(line_data, line_data.shape[1]) img = PlottingMixin.make_line_plot_plotly( data=line_data, colors=colors, width=style_attr["width"], height=style_attr["height"], line_width=style_attr["line width"], font_size=style_attr["font size"], title="Animal distances", y_lbl="distance (cm)", x_lbl="frame count", x_lbl_divisor=fps, y_max=style_attr["y_max"], line_opacity=style_attr["opacity"], save_path=None, ).astype(np.uint8) if video_setting: video_writer.write(img[:, :, :3]) if frame_setting: frm_name = os.path.join(frame_folder_dir, f"{frm_cnt}.png") cv2.imwrite(frm_name, np.uint8(img)) stdout_information(msg= f"Distance frame created: {frm_cnt} (of {distances.shape[0]}), Video: {video_name}, Processing core: {group}" ) if video_setting: video_writer.release() return group
[docs]class DistancePlotterMultiCore(ConfigReader, PlottingMixin): """ Visualize frame-wise body-part distances as line plots using multiprocessing. Produces one or more of: (i) frame-by-frame plot images, (ii) a dynamic distance-plot video, (iii) a final static distance plot (PNG or SVG). :param Union[str, os.PathLike] config_path: Path to SimBA project config file. :param List[Union[str, os.PathLike]] data_paths: One or more pose data files to process. :param bool frame_setting: If ``True``, save one plot image per frame. :param bool video_setting: If ``True``, save a video of the plot building over time. :param bool final_img: If ``True``, save a final static distance plot for each video. :param Dict[str, int] style_attr: Plot style dictionary. Expected keys include ``width``, ``height``, ``line width``, ``font size``, ``y_max``, and ``opacity``. :param List[List[str]] line_attr: Distance definitions. Each entry is ``[body_part_1, body_part_2, color_name]``. :param Optional[int] core_cnt: Number of CPU cores. ``-1`` uses all available cores. Default: ``-1``. :param bool last_frame_as_svg: If ``True``, final static distance image is saved as SVG; else PNG. Default: ``False``. .. note:: `GitHub tutorial/documentation <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#step-11-visualizations>`__. .. image:: _static/img/DistancePlotterMultiCore.png :alt: Distance Plotter Multi Core :width: 600 :align: center .. video:: _static/img/DistancePlotterMultiCore_1.webm :width: 600 :autoplay: :loop: :muted: :align: center :example: >>> style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'opacity': 0.5} >>> line_attr = {0: ['Center_1', 'Center_2', 'Green'], 1: ['Ear_left_2', 'Ear_left_1', 'Red']} >>> distance_plotter = DistancePlotterMultiCore(config_path=r'/tests_/project_folder/project_config.ini', frame_setting=False, video_setting=True, final_img=True, style_attr=style_attr, line_attr=line_attr, files_found=['/test_/project_folder/csv/machine_results/Together_1.csv'], core_cnt=5) >>> distance_plotter.run() """ def __init__(self, config_path: Union[str, os.PathLike], data_paths: List[Union[str, os.PathLike]], frame_setting: bool, video_setting: bool, final_img: bool, style_attr: Dict[str, int], line_attr: List[List[str]], core_cnt: Optional[int] = -1, last_frame_as_svg: bool = False): if (not frame_setting) and (not video_setting) and (not final_img): raise NoSpecifiedOutputError(msg="Please choice to create frames and/or video distance plots", source=self.__class__.__name__) check_int( name=f"{self.__class__.__name__} core_cnt", value=core_cnt, min_value=-1, max_value=find_core_cnt()[0], ) if core_cnt == -1: core_cnt = find_core_cnt()[0] ConfigReader.__init__(self, config_path=config_path) PlottingMixin.__init__(self) check_instance( source=f"{self.__class__.__name__} line_attr", instance=line_attr, accepted_types=(list,), ) for cnt, i in enumerate(line_attr): check_valid_lst( source=f"{self.__class__.__name__} line_attr {cnt}", data=i, valid_dtypes=(str,), exact_len=3, ) check_valid_lst(data=data_paths, valid_dtypes=(str,), min_len=1) _ = [check_file_exist_and_readable(i) for i in data_paths] ( self.video_setting, self.frame_setting, self.data_paths, self.style_attr, self.line_attr, self.final_img, self.core_cnt, ) = ( video_setting, frame_setting, data_paths, style_attr, line_attr, final_img, core_cnt, ) if not os.path.exists(self.line_plot_dir): os.makedirs(self.line_plot_dir) check_valid_boolean(value=last_frame_as_svg, source=f'{self.__class__.__name__} last_frame_as_svg', raise_error=False) self.last_frm_ext, self.last_frame_as_svg = 'svg' if last_frame_as_svg else 'png', last_frame_as_svg self.color_names = get_color_dict() if platform.system() == "Darwin": multiprocessing.set_start_method("spawn", force=True) @staticmethod @jit(nopython=True) def __insert_group_idx_column(data: np.array, group: int): group_col = np.full((data.shape[0], 1), group) return np.hstack((group_col, data))
[docs] def run(self): stdout_information(msg=f"Processing {len(self.data_paths)} video(s)...") check_all_file_names_are_represented_in_video_log( video_info_df=self.video_info_df, data_paths=self.data_paths ) for file_cnt, file_path in enumerate(self.data_paths): video_timer = SimbaTimer(start=True) _, video_name, _ = get_fn_ext(file_path) data_df = read_df(file_path, self.file_type) try: data_df.columns = self.bp_headers except ValueError: raise CountError( msg=f"SimBA expects {self.bp_headers} columns but found {len(data_df)} columns in {file_path}", source=self.__class__.__name__, ) self.video_info, px_per_mm, fps = self.read_video_info( video_name=video_name ) self.save_video_folder = os.path.join(self.line_plot_dir, video_name) self.temp_folder = os.path.join(self.line_plot_dir, video_name, "temp") self.save_frame_folder_dir = os.path.join(self.line_plot_dir, video_name) distances = [] colors = [] for cnt, i in enumerate(self.line_attr): if i[2] not in list(self.color_names.keys()): raise InvalidInputError( msg=f"{i[2]} is not a valid color. Options: {list(self.color_names.keys())}.", source=self.__class__.__name__, ) colors.append(i[2]) bp_1, bp_2 = [f"{i[0]}_x", f"{i[0]}_y"], [f"{i[1]}_x", f"{i[1]}_y"] if len(list(set(bp_1) - set(data_df.columns))) > 0: raise InvalidInputError( msg=f"Could not find fields {bp_1} in {file_path}", source=self.__class__.__name__, ) if len(list(set(bp_2) - set(data_df.columns))) > 0: raise InvalidInputError( msg=f"Could not find fields {bp_2} in {file_path}", source=self.__class__.__name__, ) distances.append( FeatureExtractionMixin.framewise_euclidean_distance( location_1=data_df[bp_1].values.astype(np.float64), location_2=data_df[bp_2].values.astype(np.float64), px_per_mm=np.float64(px_per_mm), centimeter=True, ) ) if self.frame_setting: if os.path.exists(self.save_frame_folder_dir): self.remove_a_folder(self.save_frame_folder_dir) os.makedirs(self.save_frame_folder_dir) if self.video_setting: self.video_folder = os.path.join(self.line_plot_dir, video_name) if os.path.exists(self.temp_folder): self.remove_a_folder(self.temp_folder) os.makedirs(self.temp_folder) self.save_video_path = os.path.join( self.line_plot_dir, f"{video_name}.mp4" ) if self.final_img: _ = PlottingMixin.make_line_plot( data=distances, colors=colors, width=self.style_attr["width"], height=self.style_attr["height"], line_width=self.style_attr["line width"], font_size=self.style_attr["font size"], title="Animal distances", y_lbl="distance (cm)", x_lbl="time (s)", x_lbl_divisor=fps, as_svg=self.last_frame_as_svg, y_max=self.style_attr["y_max"], line_opacity=self.style_attr["opacity"], save_path=os.path.join( self.line_plot_dir, f"{video_name}_final_distances.{self.last_frm_ext}" ), ) if self.video_setting or self.frame_setting: if self.style_attr["y_max"] == -1: self.style_attr["y_max"] = max([np.max(x) for x in distances]) distances = np.stack(distances, axis=1) frm_range = np.arange(0, distances.shape[0]) frm_range = np.array_split(frm_range, self.core_cnt) distances = np.array_split(distances, self.core_cnt) distances = [ self.__insert_group_idx_column(data=i, group=cnt) for cnt, i in enumerate(distances) ] distances = np.concatenate(distances, axis=0) stdout_information(msg= f"Creating distance plots, multiprocessing, follow progress in terminal (chunksize: {self.multiprocess_chunksize}, cores: {self.core_cnt})" ) with multiprocessing.Pool( self.core_cnt, maxtasksperchild=self.maxtasksperchild ) as pool: constants = functools.partial( distance_plotter_mp, distances=distances, video_setting=self.video_setting, frame_setting=self.frame_setting, video_name=video_name, video_save_dir=self.temp_folder, frame_folder_dir=self.save_frame_folder_dir, style_attr=self.style_attr, colors=colors, fps=fps, ) for cnt, result in enumerate( pool.map( constants, frm_range, chunksize=self.multiprocess_chunksize ) ): stdout_information(msg=f"Frame batch core {result} complete...") pass terminate_cpu_pool(pool=pool, force=False) if self.video_setting: concatenate_videos_in_folder( in_folder=self.temp_folder, save_path=self.save_video_path, video_format="avi", ) video_timer.stop_timer() stdout_success( msg=f"Distance visualizations created for {video_name} saved at {self.line_plot_dir}", elapsed_time=video_timer.elapsed_time_str, ) self.timer.stop_timer() stdout_success( msg=f"Distance visualizations complete for {len(self.data_paths)} video(s)", elapsed_time=self.timer.elapsed_time_str, )
# style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 12, 'y_max': -1, 'opacity': 0.5} # line_attr = [['nose', 'center', 'Green'], ['center', 'center', 'Red']] # test = DistancePlotterMultiCore(config_path=r"E:\troubleshooting\mitra_pbn\mitra_pbn\project_folder\project_config.ini", # frame_setting=False, # video_setting=False, # last_frame_as_svg=True, # style_attr=style_attr, # final_img=True, # data_paths=[r"E:\troubleshooting\mitra_pbn\mitra_pbn\project_folder\csv\outlier_corrected_movement_location\2026-01-05 14-17-54 box2_1143_L_Gq_5cno.csv"], # line_attr=line_attr, # core_cnt=-1) # test.run() # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 12, 'y_max': -1, 'opacity': 0.5} # line_attr = [['Center_1', 'Center_2', 'Green'], ['Ear_left_2', 'Ear_right_2', 'Red']] # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', # frame_setting=True, # video_setting=True, # style_attr=style_attr, # final_img=True, # data_paths=['/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/outlier_corrected_movement_location/Together_1.csv'], # line_attr=line_attr, # core_cnt=-1) # test.run() # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'y_max': 'auto', 'opacity': 0.9} # line_attr = {0: ['Center_1', 'Center_2', 'Green'], 1: ['Ear_left_2', 'Ear_left_1', 'Red']} # # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', # frame_setting=False, # video_setting=True, # style_attr=style_attr, # final_img=True, # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'], # line_attr=line_attr, # core_cnt=3) # test.create_distance_plot() # # # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8} # line_attr = {0: ['Termite_1_Head_1', 'Termite_1_Thorax_1', 'Dark-red']} # # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'opacity': 0.5, 'y_max': 'auto'} # line_attr = {0: ['Center_1', 'Center_2', 'Green'], 1: ['Ear_left_2', 'Ear_left_1', 'Red']} # # test = DistancePlotterMultiCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', # frame_setting=False, # video_setting=True, # style_attr=style_attr, # final_img=False, # files_found=['/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'], # line_attr=line_attr, # core_cnt=5) # test.create_distance_plot() # style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8} # line_attr = {0: ['Termite_1_Head_1', 'Termite_1_Thorax_1', 'Dark-red']} # test = DistancePlotterSingleCore(config_path=r'/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini', # frame_setting=False, # video_setting=True, # style_attr=style_attr, # files_found=['/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/csv/outlier_corrected_movement_location/termites_1.csv'], # line_attr=line_attr) # test.create_distance_plot()