__author__ = "Simon Nilsson; sronilsson@gmail.com"
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
import functools
import gc
import multiprocessing
import os
import platform
import sys
from copy import deepcopy
from typing import List, Optional, Union
import cv2
# Decide whether we are running inside an interactive host (PyCharm's pydev
# console or IPython). Any failure while inspecting the loaded modules is
# treated as "interactive", which leaves the matplotlib backend untouched.
try:
    module_names = list(sys.modules.keys())
    is_pycharm_ipython = (
        any('pydev' in str(mod).lower() for mod in module_names)
        or 'IPython' in sys.modules
        or 'ipython' in sys.modules
    )
except Exception:
    is_pycharm_ipython = True

if is_pycharm_ipython:
    # Interactive session: import matplotlib without forcing a backend.
    import matplotlib
else:
    # Non-interactive run: default to the 'Agg' backend. The environment
    # variable is only set when the user has not already chosen a backend,
    # and it must be in place before matplotlib is imported.
    os.environ.setdefault('MPLBACKEND', 'Agg')
    try:
        import matplotlib
        matplotlib.use('Agg', force=False)
    except (RecursionError, RuntimeError, ValueError, SystemError):
        # Backend selection failed; fall back to a plain import.
        import matplotlib
import numpy as np
import pandas as pd
from simba.mixins.config_reader import ConfigReader
from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.checks import (
check_all_file_names_are_represented_in_video_log,
check_file_exist_and_readable, check_float, check_int, check_str,
check_that_column_exist, check_valid_boolean, check_valid_lst)
from simba.utils.data import (create_color_palette, detect_bouts, get_cpu_pool,
terminate_cpu_pool)
from simba.utils.enums import Formats, Options
from simba.utils.errors import NoSpecifiedOutputError
from simba.utils.lookups import get_fonts
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
create_directory, find_core_cnt,
get_fn_ext, read_df, seconds_to_timestamp)
# String keys naming the Gantt plot style attributes (plot size and label font
# settings), plus a list collecting all of them.
# NOTE(review): none of these constants is referenced in the visible part of
# this module - presumably kept for legacy callers; confirm before removing.
HEIGHT = "height"
WIDTH = "width"
FONT_ROTATION = "font rotation"
FONT_SIZE = "font size"
STYLE_KEYS = [HEIGHT, WIDTH, FONT_ROTATION, FONT_SIZE]
def gantt_creator_mp(data: tuple,
                     frame_setting: bool,
                     video_setting: bool,
                     video_save_dir: str,
                     frame_folder_dir: str,
                     bouts_df: pd.DataFrame,
                     clf_names: list,
                     fps: int,
                     bar_opacity: float,
                     video_name: str,
                     width: int,
                     height: int,
                     font_size: int,
                     font: str,
                     font_rotation: int,
                     palette: np.ndarray,
                     hhmmss: bool):
    """
    Multiprocessing worker: render the Gantt frames for one batch of frame indices.

    Every frame index from the first to the last entry of the batch (inclusive) is rendered
    with :meth:`PlottingMixin.make_gantt_plot`, using only the bouts whose ``End_frame`` is
    less than or equal to that frame index. Each rendered frame is optionally written as a
    PNG and/or appended to the batch video.

    :param tuple data: ``(batch_id, frame_rng)`` pair. ``frame_rng`` is the sequence of frame indices for this batch; only its first and last entries are read, so it is assumed to be contiguous and ascending.
    :param bool frame_setting: If True, each frame is saved as ``<frame number>.png`` in ``frame_folder_dir``. File names use the 1-based frame number.
    :param bool video_setting: If True, frames are appended to ``<batch_id>.mp4`` in ``video_save_dir``.
    :param str video_save_dir: Directory receiving the per-batch video.
    :param str frame_folder_dir: Directory receiving the per-frame images.
    :param pd.DataFrame bouts_df: Bout table; must contain an ``End_frame`` column.
    :param list clf_names: Classifier names, forwarded to ``make_gantt_plot``.
    :param int fps: Frame rate; used for the video writer, the plot, and the progress timestamp.
    :param float bar_opacity: Forwarded to ``make_gantt_plot``.
    :param str video_name: Video name; forwarded to ``make_gantt_plot`` and used in progress messages.
    :param int width: Output width in pixels (video writer frame size and plot width).
    :param int height: Output height in pixels (video writer frame size and plot height).
    :param int font_size: Forwarded to ``make_gantt_plot``.
    :param str font: Forwarded to ``make_gantt_plot``.
    :param int font_rotation: Forwarded to ``make_gantt_plot``.
    :param np.ndarray palette: Forwarded to ``make_gantt_plot``.
    :param bool hhmmss: Forwarded to ``make_gantt_plot``.
    :return: The ``batch_id`` that was passed in, so the parent process can report progress.
    """
    batch_id, frame_rng = data[0], data[1]
    # np.array_split hands out empty chunks when there are fewer frames than
    # batches; there is nothing to render for those, so return immediately
    # instead of failing on frame_rng[0].
    if len(frame_rng) == 0:
        return batch_id
    current_frm, end_frm = frame_rng[0], frame_rng[-1]
    # One plotter instance is enough for the whole batch.
    plotter = PlottingMixin()
    video_writer = None
    if video_setting:
        fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
        video_save_path = os.path.join(video_save_dir, f"{batch_id}.mp4")
        video_writer = cv2.VideoWriter(video_save_path, fourcc, fps, (width, height))
    try:
        while current_frm <= end_frm:
            bout_rows = bouts_df.loc[bouts_df["End_frame"] <= current_frm]
            plot = plotter.make_gantt_plot(x_length=current_frm + 1,
                                           bouts_df=bout_rows,
                                           clf_names=clf_names,
                                           fps=fps,
                                           width=width,
                                           height=height,
                                           font_size=font_size,
                                           font=font,
                                           bar_opacity=bar_opacity,
                                           font_rotation=font_rotation,
                                           video_name=video_name,
                                           save_path=None,
                                           palette=palette,
                                           hhmmss=hhmmss)
            # From here on current_frm is the 1-based number of the frame just rendered.
            current_frm += 1
            if frame_setting:
                frame_save_path = os.path.join(frame_folder_dir, f"{current_frm}.png")
                cv2.imwrite(frame_save_path, plot)
            if video_setting:
                video_writer.write(plot)
            del plot
            # Periodic collection keeps memory growth in check during long batches.
            if current_frm % 100 == 0:
                gc.collect()
            timestamp = seconds_to_timestamp(seconds=(current_frm / fps))
            stdout_information(msg=f"Gantt frame created: {current_frm}, (video: {video_name}, processing core: {batch_id + 1}, timestamp: {timestamp})")
    finally:
        # Release the writer even if rendering failed, so the file handle is not leaked.
        if video_writer is not None:
            video_writer.release()
            del video_writer
            gc.collect()
    return batch_id
class GanttCreatorMultiprocess(ConfigReader, PlottingMixin):
    """
    Create classifier Gantt charts using multiprocessing for faster generation.

    Generates one or more of: (i) frame-by-frame Gantt images, (ii) dynamic Gantt videos, (iii) a final static Gantt image (PNG or SVG).

    .. note::
       `GitHub gantt tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/tutorial.md#gantt-plot>`__.

    .. seealso::
       For single-process alternative, see :class:`simba.plotting.gantt_creator.GanttCreatorSingleProcess`.

    .. image:: _static/img/gantt_plot.png
       :alt: Gantt plot
       :width: 300
       :align: center

    :param Union[str, os.PathLike] config_path: Path to SimBA project config file.
    :param Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] data_paths: File path, list of file paths, or ``None`` (all machine result files in project).
    :param bool frame_setting: If ``True``, creates individual frame images. Default: ``False``.
    :param bool video_setting: If ``True``, creates dynamic Gantt videos. Default: ``False``.
    :param bool last_frm_setting: If ``True``, creates a final static Gantt image. Default: ``True``.
    :param bool last_frame_as_svg: If ``True``, saves final static frame as SVG; else PNG. Default: ``False``.
    :param int width: Width of output images/videos in pixels. Default: 640.
    :param int height: Height of output images/videos in pixels. Default: 480.
    :param int font_size: Font size for behavior labels. Default: 8.
    :param int font_rotation: Rotation angle for y-axis labels in degrees (0-180). Default: 45.
    :param Optional[str] font: Matplotlib font name. If ``None``, default font is used.
    :param float bar_opacity: Opacity of Gantt bars in range (0, 1]. Default: ``0.85``.
    :param str palette: Color palette name for behaviors. Default: 'Set1'.
    :param Optional[int] core_cnt: Number of CPU cores to use. If -1, uses all available cores. Default: -1.
    :param bool hhmmss: If ``True``, x-axis labels are formatted as ``HH:MM:SS``. If ``False``, seconds are used. Default: ``False``.
    :param Optional[List[str]] clf_names: Optional subset of classifiers to include. If ``None``, uses all project classifiers.
    :raises NoSpecifiedOutputError: If ``frame_setting``, ``video_setting`` and ``last_frm_setting`` are all falsy.

    :example:
    >>> gantt_creator = GanttCreatorMultiprocess(config_path='project_config.ini', video_setting=True, data_paths=['csv/machine_results/video1.csv'], core_cnt=5, hhmmss=True, last_frm_setting=True)
    >>> gantt_creator.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 data_paths: Optional[Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]]] = None,
                 frame_setting: Optional[bool] = False,
                 video_setting: Optional[bool] = False,
                 last_frm_setting: Optional[bool] = True,
                 last_frame_as_svg: bool = False,
                 width: int = 640,
                 height: int = 480,
                 font_size: int = 8,
                 font_rotation: int = 45,
                 font: Optional[str] = None,
                 bar_opacity: float = 0.85,
                 palette: str = 'Set1',
                 core_cnt: int = -1,
                 hhmmss: bool = False,
                 clf_names: Optional[List[str]] = None):

        check_file_exist_and_readable(file_path=config_path)
        if (not frame_setting) and (not video_setting) and (not last_frm_setting):
            raise NoSpecifiedOutputError(msg="SIMBA ERROR: Please select gantt videos, frames, and/or last frame.", source=self.__class__.__name__)
        check_int(value=width, min_value=1, name=f'{self.__class__.__name__} width')
        check_int(value=height, min_value=1, name=f'{self.__class__.__name__} height')
        check_int(value=font_size, min_value=1, name=f'{self.__class__.__name__} font_size')
        check_int(value=font_rotation, min_value=0, max_value=180, name=f'{self.__class__.__name__} font_rotation')
        # NOTE(review): the results of these two checks are discarded; with
        # raise_error=False they presumably never reject a bad value - confirm
        # against check_valid_boolean whether raise_error=True was intended.
        check_valid_boolean(value=hhmmss, source=f'{self.__class__.__name__} hhmmss', raise_error=False)
        check_valid_boolean(value=last_frame_as_svg, source=f'{self.__class__.__name__} last_frame_as_svg', raise_error=False)
        palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value + Options.PALETTE_OPTIONS.value
        check_str(name=f'{self.__class__.__name__} palette', value=palette, options=palettes)
        check_float(name=f'{self.__class__.__name__} bar_opacity', value=bar_opacity, allow_zero=False, allow_negative=False, max_value=1.0, raise_error=True)
        # Query the core count once and reuse it for validation and clamping.
        max_core_cnt = find_core_cnt()[0]
        check_int(name=f"{self.__class__.__name__} core_cnt", value=core_cnt, min_value=-1, unaccepted_vals=[0], max_value=max_core_cnt)
        self.core_cnt = max_core_cnt if core_cnt == -1 or core_cnt > max_core_cnt else core_cnt
        self.width, self.height, self.font_size, self.font_rotation, self.hhmmss = width, height, font_size, font_rotation, hhmmss
        if font is not None:
            check_str(name=f'{self.__class__.__name__} font', value=font, options=list(get_fonts().keys()), raise_error=True)
        ConfigReader.__init__(self, config_path=config_path, create_logger=False)
        if isinstance(data_paths, list):
            check_valid_lst(data=data_paths, source=f'{self.__class__.__name__} data_paths', valid_dtypes=(str,), min_len=1)
        elif isinstance(data_paths, (str, os.PathLike)):
            # A single path may be a str or any os.PathLike (e.g. pathlib.Path).
            # Normalise it to a one-element list of str so it is not silently
            # replaced by the project-wide default below.
            check_file_exist_and_readable(file_path=data_paths)
            data_paths = [os.fspath(data_paths)]
        else:
            # No explicit paths: process every machine-results file in the project.
            data_paths = deepcopy(self.machine_results_paths)
        for file_path in data_paths:
            check_file_exist_and_readable(file_path=file_path)
        if clf_names is not None:
            check_valid_lst(data=clf_names, source=f'{self.__class__.__name__} clf_names', valid_dtypes=(str,), valid_values=self.clf_names, min_len=1, raise_error=True)
            self.clf_names = clf_names
        PlottingMixin.__init__(self)
        # NOTE(review): the palette is sized from the number of body parts, not
        # the number of classifiers - confirm this is sufficient for projects
        # with more classifiers than body parts.
        self.clr_lst = create_color_palette(pallete_name=palette, increments=len(self.body_parts_lst) + 1, as_int=True, as_rgb_ratio=True)
        self.frame_setting, self.video_setting, self.data_paths, self.last_frm_setting, self.font, self.bar_opacity = frame_setting, video_setting, data_paths, last_frm_setting, font, bar_opacity
        self.last_frm_ext, self.last_frame_as_svg = 'svg' if last_frame_as_svg else 'png', last_frame_as_svg
        if not os.path.exists(self.gantt_plot_dir):
            os.makedirs(self.gantt_plot_dir)
        if platform.system() == "Darwin":
            multiprocessing.set_start_method("spawn", force=True)
        stdout_information(msg=f"Processing {len(self.data_paths)} video(s)...")

    def run(self):
        """
        Create the requested Gantt outputs for every file in ``self.data_paths``.

        For each data file the classifier bouts are detected, then (depending on the
        settings chosen at construction) a final static image is saved and/or the frames
        are rendered in parallel by :func:`gantt_creator_mp` and, for videos, the per-batch
        clips are joined into ``<video_name>.mp4`` in the Gantt plot directory.
        """
        check_all_file_names_are_represented_in_video_log(video_info_df=self.video_info_df, data_paths=self.data_paths)
        # A worker pool is only needed when per-frame output (frames or video) is requested.
        if self.video_setting or self.frame_setting:
            self.pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=self.maxtasksperchild, verbose=True, source=self.__class__.__name__)
        else:
            self.pool = None
        try:
            for file_cnt, file_path in enumerate(self.data_paths):
                video_timer = SimbaTimer(start=True)
                _, self.video_name, _ = get_fn_ext(file_path)
                self.data_df = read_df(file_path, self.file_type)
                check_that_column_exist(df=self.data_df, column_name=self.clf_names, file_name=file_path)
                stdout_information(msg=f"Processing video {self.video_name}, Frame count: {len(self.data_df)} (Video {(file_cnt + 1)}/{len(self.data_paths)})...")
                self.video_info_settings, _, self.fps = self.read_video_info(video_name=self.video_name)
                self.bouts_df = detect_bouts(data_df=self.data_df, target_lst=list(self.clf_names), fps=int(self.fps))
                self.temp_folder = os.path.join(self.gantt_plot_dir, self.video_name, "temp")
                self.save_frame_folder_dir = os.path.join(self.gantt_plot_dir, self.video_name)
                if self.frame_setting:
                    create_directory(paths=self.save_frame_folder_dir, overwrite=True)
                if self.video_setting:
                    create_directory(paths=self.temp_folder)
                    self.save_video_path = os.path.join(self.gantt_plot_dir, f"{self.video_name}.mp4")
                if self.last_frm_setting:
                    self.make_gantt_plot(x_length=len(self.data_df),
                                         bouts_df=self.bouts_df,
                                         clf_names=self.clf_names,
                                         fps=self.fps,
                                         width=self.width,
                                         height=self.height,
                                         font_size=self.font_size,
                                         font_rotation=self.font_rotation,
                                         video_name=self.video_name,
                                         bar_opacity=self.bar_opacity,
                                         font=self.font,
                                         as_svg=self.last_frame_as_svg,
                                         save_path=os.path.join(self.gantt_plot_dir, f"{self.video_name}_final_image.{self.last_frm_ext}"),
                                         palette=self.clr_lst,
                                         hhmmss=self.hhmmss)
                if self.video_setting or self.frame_setting:
                    frame_chunks = np.array_split(np.arange(len(self.data_df)), self.core_cnt)
                    # With fewer frames than cores np.array_split yields empty
                    # trailing chunks; drop them so no worker gets an empty batch.
                    frame_data = [(i, x) for i, x in enumerate(frame_chunks) if len(x) > 0]
                    stdout_information(msg=f"Creating gantt, multiprocessing (chunksize: {(self.multiprocess_chunksize)}, cores: {self.core_cnt})...")
                    constants = functools.partial(gantt_creator_mp,
                                                  video_setting=self.video_setting,
                                                  frame_setting=self.frame_setting,
                                                  video_save_dir=self.temp_folder,
                                                  frame_folder_dir=self.save_frame_folder_dir,
                                                  bouts_df=self.bouts_df,
                                                  clf_names=self.clf_names,
                                                  fps=self.fps,
                                                  width=self.width,
                                                  height=self.height,
                                                  font_size=self.font_size,
                                                  bar_opacity=self.bar_opacity,
                                                  font=self.font,
                                                  font_rotation=self.font_rotation,
                                                  video_name=self.video_name,
                                                  palette=self.clr_lst,
                                                  hhmmss=self.hhmmss)
                    for result in self.pool.imap(constants, frame_data, chunksize=self.multiprocess_chunksize):
                        stdout_information(msg=f'Batch {result+1}/{len(frame_data)} complete...')
                    if self.video_setting:
                        stdout_information(msg=f"Joining {self.video_name} multiprocessed video...")
                        concatenate_videos_in_folder(in_folder=self.temp_folder, save_path=self.save_video_path)
                    video_timer.stop_timer()
                    stdout_information(msg=f"Gantt video {self.video_name} complete (elapsed time: {video_timer.elapsed_time_str}s) ...")
        finally:
            # Always shut the pool down, including when processing a video fails.
            terminate_cpu_pool(pool=self.pool, force=False, source=self.__class__.__name__)
        self.timer.stop_timer()
        stdout_success(msg=f"Gantt visualizations for {len(self.data_paths)} videos created in {self.gantt_plot_dir} directory", elapsed_time=self.timer.elapsed_time_str)
# if __name__ == "__main__":
# test = GanttCreatorMultiprocess(config_path=r"E:\troubleshooting\mitra_pbn\mitra_pbn\project_folder\project_config.ini",
# frame_setting=False,
# video_setting=False,
# last_frm_setting=True,
# last_frame_as_svg=True,
# width=640,
# height= 480,
# font_size=10,
# font_rotation= 45,
# hhmmss=True)
# test.run()
# if __name__ == "__main__":
# test = GanttCreatorMultiprocess(config_path=r"D:\troubleshooting\maplight_ri\project_folder\project_config.ini",
# frame_setting=False,
# video_setting=True,
# data_paths=r"D:\troubleshooting\maplight_ri\project_folder\csv\machine_results\Trial_1_C24_D1_1.csv",
# last_frm_setting=False,
# width=640,
# height= 480,
# font_size=10,
# font_rotation= 45,
# core_cnt=16)
# test.run()
# test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
#                                 frame_setting=False,
#                                 video_setting=True,
#                                 data_paths=['/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/csv/machine_results/2022-06-20_NOB_DOT_4.csv'],
#                                 core_cnt=5,
#                                 last_frm_setting=False,
#                                 width=640,
#                                 height=480,
#                                 font_size=10,
#                                 font_rotation=65)
# test.run()
# test = GanttCreatorMultiprocess(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
#                                 frame_setting=False,
#                                 video_setting=True,
#                                 data_paths=['/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/machine_results/Together_1.csv'],
#                                 last_frm_setting=True,
#                                 width=640,
#                                 height=480,
#                                 font_size=10,
#                                 font_rotation=65,
#                                 core_cnt=2)
# test.run()