# Source code for simba.plotting.gantt_creator_mp

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
import functools
import gc
import multiprocessing
import os
import platform
import sys
from copy import deepcopy
from typing import List, Optional, Union

import cv2

# Decide, at import time, whether this process appears to be running under
# PyCharm's pydev tooling or an IPython session. The flag defaults to True and
# is only cleared when neither is found among the already-loaded modules.
is_pycharm_ipython = True
try:
    module_names = list(sys.modules.keys())
    # Any loaded module whose name contains "pydev" is treated as PyCharm.
    if any('pydev' in str(mod).lower() for mod in module_names):
        is_pycharm_ipython = True
    elif 'IPython' in sys.modules or 'ipython' in sys.modules:
        is_pycharm_ipython = True
    else:
        is_pycharm_ipython = False
except Exception:
    # If the module table cannot be inspected, keep the default (True) so the
    # matplotlib backend is left untouched below.
    is_pycharm_ipython = True

if not is_pycharm_ipython:
    # Outside PyCharm/IPython: request the "Agg" backend, without overriding an
    # MPLBACKEND value that is already present in the environment.
    # NOTE(review): presumably done so worker processes never try to open a GUI
    # backend -- confirm with the author.
    if 'MPLBACKEND' not in os.environ:  os.environ['MPLBACKEND'] = 'Agg'
    try:
        import matplotlib
        matplotlib.use('Agg', force=False)
    except (RecursionError, RuntimeError, ValueError, SystemError):
        # Backend selection failed; retry a plain import and keep whatever
        # backend matplotlib resolves on its own.
        import matplotlib
else:
    # Inside PyCharm/IPython the backend is not changed here.
    import matplotlib
import numpy as np
import pandas as pd

from simba.mixins.config_reader import ConfigReader
from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.checks import (
    check_all_file_names_are_represented_in_video_log,
    check_file_exist_and_readable, check_float, check_int, check_str,
    check_that_column_exist, check_valid_boolean, check_valid_lst)
from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
                              terminate_cpu_pool)
from simba.utils.enums import Formats, Options
from simba.utils.errors import NoSpecifiedOutputError
from simba.utils.lookups import get_fonts
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
                                    create_directory, find_core_cnt,
                                    get_fn_ext, read_df, seconds_to_timestamp)

# String keys for plot style attributes, and the list that groups them.
# NOTE(review): none of these names are referenced by the code visible in this
# module (sizes and fonts are passed as explicit keyword arguments); presumably
# kept for callers that still import them -- confirm before removing.
HEIGHT = "height"
WIDTH = "width"
FONT_ROTATION = "font rotation"
FONT_SIZE = "font size"
STYLE_KEYS = [HEIGHT, WIDTH, FONT_ROTATION, FONT_SIZE]

def gantt_creator_mp(data: np.array,
                     frame_setting: bool,
                     video_setting: bool,
                     video_save_dir: str,
                     frame_folder_dir: str,
                     bouts_df: pd.DataFrame,
                     clf_names: list,
                     fps: int,
                     bar_opacity: float,
                     video_name: str,
                     width: int,
                     height: int,
                     font_size: int,
                     font: str,
                     font_rotation: int,
                     palette: np.ndarray,
                     hhmmss: bool):
    """
    Multiprocessing worker: render the Gantt frames for one batch of frame indices.

    For every frame from the first to the last value of the batch (inclusive), a Gantt
    plot is drawn from the bouts whose ``End_frame`` is at or before that frame. Each
    plot is optionally saved as a PNG and/or appended to a per-batch MP4 file.

    :param data: Two-item sequence ``(batch_id, frame_range)``. ``frame_range`` is an ordered sequence of frame indices; only its first and last values are used as loop bounds.
    :param bool frame_setting: If ``True``, save each plot as ``<frame_folder_dir>/<n>.png`` where ``n`` is the frame index plus one.
    :param bool video_setting: If ``True``, append each plot to ``<video_save_dir>/<batch_id>.mp4``.
    :param str video_save_dir: Directory receiving the per-batch video file.
    :param str frame_folder_dir: Directory receiving the per-frame images.
    :param pd.DataFrame bouts_df: Bout table; must contain an ``End_frame`` column.
    :param list clf_names: Classifier names forwarded to the plotting call.
    :param int fps: Frame rate used for the video writer and for the logged timestamp.
    :param float bar_opacity: Bar opacity forwarded to the plotting call.
    :param str video_name: Video name forwarded to the plotting call and used in log messages.
    :param int width: Plot/video width in pixels.
    :param int height: Plot/video height in pixels.
    :param int font_size: Font size forwarded to the plotting call.
    :param str font: Font name forwarded to the plotting call.
    :param int font_rotation: Label rotation forwarded to the plotting call.
    :param np.ndarray palette: Color palette forwarded to the plotting call.
    :param bool hhmmss: Time-axis format flag forwarded to the plotting call.
    :return: ``batch_id`` (the first item of ``data``), so the parent can report which batch completed.
    """
    batch_id, frame_rng = data[0], data[1]
    # np.array_split hands out empty chunks when there are more workers than frames;
    # there is nothing to render for such a batch, so do not index into it.
    if len(frame_rng) == 0:
        return batch_id

    end_frm, current_frm = frame_rng[-1], frame_rng[0]
    video_writer = None
    try:
        if video_setting:
            fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
            video_save_path = os.path.join(video_save_dir, f"{batch_id}.mp4")
            video_writer = cv2.VideoWriter(video_save_path, fourcc, fps, (width, height))
        while current_frm <= end_frm:
            # Only bouts that have already ended at this frame are drawn.
            bout_rows = bouts_df.loc[bouts_df["End_frame"] <= current_frm]
            # NOTE(review): a new PlottingMixin is built per frame; hoisting it out of the
            # loop looks safe but depends on make_gantt_plot being stateless -- confirm.
            plot = PlottingMixin().make_gantt_plot(x_length=current_frm + 1,
                                                   bouts_df=bout_rows,
                                                   clf_names=clf_names,
                                                   fps=fps,
                                                   width=width,
                                                   height=height,
                                                   font_size=font_size,
                                                   font=font,
                                                   bar_opacity=bar_opacity,
                                                   font_rotation=font_rotation,
                                                   video_name=video_name,
                                                   save_path=None,
                                                   palette=palette,
                                                   hhmmss=hhmmss)
            current_frm += 1
            if frame_setting:
                # current_frm was already incremented, so image names are 1-based.
                frame_save_path = os.path.join(frame_folder_dir, f"{current_frm}.png")
                cv2.imwrite(frame_save_path, plot)
            if video_setting:
                video_writer.write(plot)
            del plot
            # Periodic collection keeps memory bounded over long batches.
            if current_frm % 100 == 0:
                gc.collect()
            timestamp = seconds_to_timestamp(seconds=(current_frm / fps))
            # NOTE(review): current_frm is already incremented here, so the logged number is
            # one higher than the saved image name -- left unchanged, confirm intent.
            stdout_information(msg=f"Gantt frame created: {current_frm + 1}, (video: {video_name}, processing core: {batch_id + 1}, timestamp: {timestamp}")
    finally:
        # Always release the writer, also when rendering fails midway, so the file
        # handle is not leaked.
        if video_writer is not None:
            video_writer.release()
            video_writer = None
        gc.collect()
    return batch_id
class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
    """
    Create classifier Gantt charts using multiprocessing for faster generation.

    Generates one or more of: (i) frame-by-frame Gantt images, (ii) dynamic Gantt videos,
    (iii) a final static Gantt image (PNG or SVG).

    .. note::
       `GitHub gantt tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#gantt-plot>`__.

    .. seealso::
       For single-process alternative, see :class:`simba.plotting.gantt_creator.GanttCreatorSingleProcess`.

    .. image:: _static/img/gantt_plot.png
       :alt: Gantt plot
       :width: 300
       :align: center

    :param Union[str, os.PathLike] config_path: Path to SimBA project config file.
    :param Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] data_paths: File path, list of file paths, or ``None`` (all machine result files in project). Path-like objects are converted to strings.
    :param bool frame_setting: If ``True``, creates individual frame images. Default: ``False``.
    :param bool video_setting: If ``True``, creates dynamic Gantt videos. Default: ``False``.
    :param bool last_frm_setting: If ``True``, creates a final static Gantt image. Default: ``True``.
    :param bool last_frame_as_svg: If ``True``, saves final static frame as SVG; else PNG. Default: ``False``.
    :param int width: Width of output images/videos in pixels. Default: 640.
    :param int height: Height of output images/videos in pixels. Default: 480.
    :param int font_size: Font size for behavior labels. Default: 8.
    :param int font_rotation: Rotation angle for y-axis labels in degrees (0-180). Default: 45.
    :param Optional[str] font: Matplotlib font name. If ``None``, default font is used.
    :param float bar_opacity: Opacity of Gantt bars in range (0, 1]. Default: ``0.85``.
    :param str palette: Color palette name for behaviors. Default: 'Set1'.
    :param Optional[int] core_cnt: Number of CPU cores to use. If -1, uses all available cores. Default: -1.
    :param bool hhmmss: If ``True``, x-axis labels are formatted as ``HH:MM:SS``. If ``False``, seconds are used. Default: ``False``.
    :param Optional[List[str]] clf_names: Optional subset of classifiers to include. If ``None``, uses all project classifiers.
    :raises NoSpecifiedOutputError: If ``frame_setting``, ``video_setting`` and ``last_frm_setting`` are all falsy.

    :example:
    >>> gantt_creator = GanttCreatorMultiprocess(config_path='project_config.ini', video_setting=True, data_paths=['csv/machine_results/video1.csv'], core_cnt=5, hhmmss=True, last_frm_setting=True)
    >>> gantt_creator.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
                 frame_setting: Optional[bool] = False,
                 video_setting: Optional[bool] = False,
                 last_frm_setting: Optional[bool] = True,
                 last_frame_as_svg: bool = False,
                 width: int = 640,
                 height: int = 480,
                 font_size: int = 8,
                 font_rotation: int = 45,
                 font: Optional[str] = None,
                 bar_opacity: float = 0.85,
                 palette: str = 'Set1',
                 core_cnt: int = -1,
                 hhmmss: bool = False,
                 clf_names: Optional[List[str]] = None):

        source = self.__class__.__name__
        check_file_exist_and_readable(file_path=config_path)
        if (not frame_setting) and (not video_setting) and (not last_frm_setting):
            raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.", source=source)
        check_int(value=width, min_value=1, name=f'{source} width')
        check_int(value=height, min_value=1, name=f'{source} height')
        check_int(value=font_size, min_value=1, name=f'{source} font_size')
        check_int(value=font_rotation, min_value=0, max_value=180, name=f'{source} font_rotation')
        # NOTE(review): raise_error=False and the return value is ignored, so these two
        # checks look like no-ops -- left as in the original, confirm intent.
        check_valid_boolean(value=hhmmss, source=f'{source} hhmmss', raise_error=False)
        check_valid_boolean(value=last_frame_as_svg, source=f'{source} last_frame_as_svg', raise_error=False)
        palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
        check_str(name=f'{source} palette', value=palette, options=palettes)
        check_float(name=f'{source} bar_opacity', value=bar_opacity, allow_zero=False, allow_negative=False, max_value=1.0, raise_error=True)
        # Query the machine's core count once instead of on every use.
        available_cores = find_core_cnt()[0]
        check_int(name=f"{source} core_cnt", value=core_cnt, min_value=-1, unaccepted_vals=[0], max_value=available_cores)
        self.core_cnt = available_cores if core_cnt == -1 or core_cnt > available_cores else core_cnt
        self.width, self.height, self.font_size, self.font_rotation, self.hhmmss = width, height, font_size, font_rotation, hhmmss
        if font is not None:
            check_str(name=f'{source} font', value=font, options=list(get_fonts().keys()), raise_error=True)
        ConfigReader.__init__(self, config_path=config_path, create_logger=False)

        if isinstance(data_paths, (str, os.PathLike)):
            # A single path. Path-like objects used to fall through to the
            # "all project files" branch below and were silently ignored.
            data_paths = [os.fspath(data_paths)]
        elif isinstance(data_paths, list):
            data_paths = [os.fspath(p) if isinstance(p, os.PathLike) else p for p in data_paths]
            check_valid_lst(data=data_paths, source=f'{source} data_paths', valid_dtypes=(str,), min_len=1)
        else:
            # None (or any other type): use every machine-results file of the project.
            data_paths = deepcopy(self.machine_results_paths)
        for file_path in data_paths:
            check_file_exist_and_readable(file_path=file_path)

        if clf_names is not None:
            check_valid_lst(data=clf_names, source=f'{source} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
            self.clf_names = clf_names
        PlottingMixin.__init__(self)
        # NOTE(review): the palette is sized by body-part count, not classifier count;
        # left unchanged because altering `increments` may change existing colors -- confirm.
        self.clr_lst = create_color_palette(pallete_name=palette, increments=len(self.body_parts_lst) + 1, as_int=True, as_rgb_ratio=True)
        self.frame_setting, self.video_setting, self.data_paths = frame_setting, video_setting, data_paths
        self.last_frm_setting, self.font, self.bar_opacity = last_frm_setting, font, bar_opacity
        self.last_frame_as_svg = last_frame_as_svg
        self.last_frm_ext = 'svg' if last_frame_as_svg else 'png'
        if not os.path.exists(self.gantt_plot_dir):
            os.makedirs(self.gantt_plot_dir)
        if platform.system() == "Darwin":
            multiprocessing.set_start_method("spawn", force=True)
        stdout_information(msg=f"Processing {len(self.data_paths)} video(s)...")

    def run(self):
        """
        Create the requested Gantt outputs for every file in ``self.data_paths``.

        Per file: reads the data, detects classifier bouts, optionally writes the final
        static image, and (for frame and/or video output) splits the frame indices across
        the worker pool. Per-batch videos are concatenated into one MP4 per file. The pool
        is terminated when processing ends, whether it succeeds or raises.
        """
        check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
        dynamic_output = self.video_setting or self.frame_setting
        if dynamic_output:
            self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=self.maxtasksperchild, verbose=True, source=self.__class__.__name__)
        else:
            # Only the static final image is requested; no workers needed.
            self.pool = None
        try:
            for file_cnt, file_path in enumerate(self.data_paths):
                video_timer = SimbaTimer(start=True)
                _, self.video_name, _ = get_fn_ext(file_path)
                self.data_df = read_df(file_path, self.file_type)
                check_that_column_exist(df=self.data_df, column_name=self.clf_names, file_name=file_path)
                stdout_information(msg=f"Processing video {self.video_name}, Frame count: {len(self.data_df)} (Video {(file_cnt + 1)}/{len(self.data_paths)})...")
                self.video_info_settings, _, self.fps = self.read_video_info(video_name=self.video_name)
                self.bouts_df = detect_bouts(data_df=self.data_df, target_lst=list(self.clf_names), fps=int(self.fps))
                self.temp_folder = os.path.join(self.gantt_plot_dir, self.video_name, "temp")
                self.save_frame_folder_dir = os.path.join(self.gantt_plot_dir, self.video_name)
                if self.frame_setting:
                    create_directory(paths=self.save_frame_folder_dir, overwrite=True)
                if self.video_setting:
                    create_directory(paths=self.temp_folder)
                    self.save_video_path = os.path.join(self.gantt_plot_dir, f"{self.video_name}.mp4")
                if self.last_frm_setting:
                    self.make_gantt_plot(x_length=len(self.data_df),
                                         bouts_df=self.bouts_df,
                                         clf_names=self.clf_names,
                                         fps=self.fps,
                                         width=self.width,
                                         height=self.height,
                                         font_size=self.font_size,
                                         font_rotation=self.font_rotation,
                                         video_name=self.video_name,
                                         bar_opacity=self.bar_opacity,
                                         font=self.font,
                                         as_svg=self.last_frame_as_svg,
                                         save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name}_final_image.{self.last_frm_ext}"),
                                         palette=self.clr_lst,
                                         hhmmss=self.hhmmss)
                if dynamic_output:
                    frame_chunks = np.array_split(np.arange(len(self.data_df)), self.core_cnt)
                    # With fewer frames than cores, array_split returns empty trailing chunks;
                    # they carry no work and would make the worker index an empty array.
                    frame_data = [(i, x) for i, x in enumerate(frame_chunks) if len(x) > 0]
                    stdout_information(msg=f"Creating gantt, multiprocessing (chunksize: {(self.multiprocess_chunksize)}, cores: {self.core_cnt})...")
                    constants = functools.partial(gantt_creator_mp,
                                                  video_setting=self.video_setting,
                                                  frame_setting=self.frame_setting,
                                                  video_save_dir=self.temp_folder,
                                                  frame_folder_dir=self.save_frame_folder_dir,
                                                  bouts_df=self.bouts_df,
                                                  clf_names=self.clf_names,
                                                  fps=self.fps,
                                                  width=self.width,
                                                  height=self.height,
                                                  font_size=self.font_size,
                                                  bar_opacity=self.bar_opacity,
                                                  font=self.font,
                                                  font_rotation=self.font_rotation,
                                                  video_name=self.video_name,
                                                  palette=self.clr_lst,
                                                  hhmmss=self.hhmmss)
                    for result in self.pool.imap(constants, frame_data, chunksize=self.multiprocess_chunksize):
                        stdout_information(msg=f'Batch {result+1}/{self.core_cnt} complete...')
                    if self.video_setting:
                        stdout_information(msg=f"Joining {self.video_name} multiprocessed video...")
                        concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
                video_timer.stop_timer()
                stdout_information(msg=f"Gantt video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ...")
        finally:
            # Shut the workers down even when a file fails, so processes are not leaked.
            terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
        self.timer.stop_timer()
        stdout_success(msg=f"Gantt visualizations for {len(self.data_paths)} videos created in {self.gantt_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str)
# Usage examples (kept commented out; keyword names follow the current constructor signature).
#
# if __name__ == "__main__":
#     test = GanttCreatorMultiprocess(config_path=r"E:\troubleshooting\mitra_pbn\mitra_pbn\project_folder\project_config.ini",
#                                     frame_setting=False,
#                                     video_setting=False,
#                                     last_frm_setting=True,
#                                     last_frame_as_svg=True,
#                                     width=640,
#                                     height=480,
#                                     font_size=10,
#                                     font_rotation=45,
#                                     hhmmss=True)
#     test.run()
#
# if __name__ == "__main__":
#     test = GanttCreatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
#                                     frame_setting=False,
#                                     video_setting=True,
#                                     data_paths=r"D:\troubleshooting\maplight_ri\project_folder\csv\machine_results\Trial_1_C24_D1_1.csv",
#                                     last_frm_setting=False,
#                                     width=640,
#                                     height=480,
#                                     font_size=10,
#                                     font_rotation=45,
#                                     core_cnt=16)
#     test.run()
#
# test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
#                                 frame_setting=False,
#                                 video_setting=True,
#                                 data_paths=['/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/csv/machine_results/2022-06-20_NOB_DOT_4.csv'],
#                                 core_cnt=5,
#                                 last_frm_setting=False,
#                                 width=640,
#                                 height=480,
#                                 font_size=10,
#                                 font_rotation=65)
# test.run()
#
# test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
#                                 frame_setting=False,
#                                 video_setting=True,
#                                 data_paths=['/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'],
#                                 last_frm_setting=True,
#                                 width=640,
#                                 height=480,
#                                 font_size=10,
#                                 font_rotation=65,
#                                 core_cnt=2)
# test.run()