__author__ = "Simon Nilsson; sronilsson@gmail.com"
import itertools
import os
from typing import Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import pandas as pd
from simba.mixins.config_reader import ConfigReader
from simba.mixins.geometry_mixin import GeometryMixin
from simba.mixins.plotting_mixin import PlottingMixin
from simba.roi_tools.roi_aggregate_statistics_analyzer import \
ROIAggregateStatisticsAnalyzer
from simba.roi_tools.roi_utils import get_roi_dict_from_dfs
from simba.utils.checks import (check_file_exist_and_readable, check_float,
check_if_dir_exists, check_if_valid_rgb_tuple,
check_int, check_valid_boolean,
check_valid_lst,
check_video_and_data_frm_count_align)
from simba.utils.data import (create_color_palettes, detect_bouts,
slice_roi_dict_for_video)
from simba.utils.enums import Formats, Keys, Paths, TagNames, TextOptions
from simba.utils.errors import (BodypartColumnNotFoundError, DuplicationError,
NoFilesFoundError, NoROIDataError,
ROICoordinatesNotFoundError)
from simba.utils.printing import SimbaTimer, log_event, stdout_success
from simba.utils.read_write import (get_video_meta_data, read_df,
read_frm_of_video)
from simba.utils.warnings import FrameRangeWarning
# Name of the pseudo-ROI that covers every area NOT covered by a drawn ROI.
# It is appended to the plotted shape names when ``ROIPlotter(outside_roi=True)``.
OUTSIDE_ROI = "OUTSIDE REGIONS OF INTEREST"
class ROIPlotter(ConfigReader):
    r"""
    Visualize ROI data (cumulative time spent in, and number of entries into, each ROI) on top of a video.

    The output video is the original video with a border of the same width appended on its right side,
    in which the running per-animal ROI statistics are printed for every frame.

    .. note::
       `ROI tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/ROI_tutorial_new.md>`__.

    .. seealso::
       Use :func:`simba.plotting.ROI_plotter_mp.ROIPlotMultiprocess` for improved run-time.

    .. image:: _static/img/ROIPlot_1.png
       :alt: ROIPlot 1
       :width: 800
       :align: center

    .. video:: _static/img/ROIPlot.webm
       :width: 800
       :autoplay:
       :loop:
       :muted:
       :align: center

    .. video:: _static/img/outside_roi_example.mp4
       :width: 800
       :autoplay:
       :loop:
       :muted:
       :align: center

    .. youtube:: Q2ByLfwJIaw
       :width: 640
       :height: 480
       :align: center

    :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format.
    :param Union[str, os.PathLike] video_path: Path to video file to create ROI visualizations for.
    :param List[str] body_parts: List of the body-parts to use as proxy for animal locations. Entries must be unique and exist in the SimBA project.
    :param bool outside_roi: If True, SimBA will treat all areas NOT covered by a ROI drawing as a single additional ROI and visualize the stats for this single ROI. Default: False.
    :param float threshold: Float between 0 and 1. Body-part locations detected below this confidence threshold are filtered from the ROI statistics and are not drawn. Default: 0.0.
    :param Optional[bool] verbose: If True, print progress messages during video creation. Default: True.
    :param bool show_animal_name: If True, display animal names on the video frames. Default: False.
    :param bool show_body_part: If True, display body-part locations as circles on the video frames. Default: True.
    :param bool show_bbox: If True, draw an axis-aligned bounding box around all body-parts of each animal whose selected body-part passes ``threshold``. Default: False.
    :param Optional[Union[str, os.PathLike]] data_path: Optional path to the pose-estimation data. If None, then locates file in ``outlier_corrected_movement_location`` directory. Default: None.
    :param Optional[Union[str, os.PathLike]] save_path: Optional path to where to save video. If None, then saves it in ``frames/output/roi_analysis`` directory of SimBA project. Default: None.
    :param Optional[List[Tuple[int, int, int]]] bp_colors: Optional list of tuples of same length as body_parts representing the colors of the body-parts in RGB format. Defaults to None and colors are automatically chosen. Default: None.
    :param Optional[List[Union[int]]] bp_sizes: Optional list of integers representing the sizes of the pose estimated body-part location circles. Entries that are not integers >= 1 fall back to an automatically inferred size. Default: None (all sizes inferred).
    :param Tuple[int, int, int] border_bg_clr: RGB tuple representing the background color of the border area where statistics are displayed. Default: (0, 0, 0).

    :example:
    >>> test = ROIPlotter(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
    >>>                   video_path="/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/videos/SI_DAY3_308_CD1_PRESENT.mp4",
    >>>                   body_parts=['Nose'],
    >>>                   show_body_part=True,
    >>>                   show_animal_name=True)
    >>> test.run()

    :example II:
    >>> test = ROIPlotter(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
    >>>                   video_path=r"C:\troubleshooting\mitra\project_folder\videos\501_MA142_Gi_Saline_0513.mp4",
    >>>                   body_parts=['Nose'],
    >>>                   show_body_part=True,
    >>>                   show_animal_name=False)
    >>> test.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 video_path: Union[str, os.PathLike],
                 body_parts: List[str],
                 outside_roi: bool = False,
                 threshold: float = 0.0,
                 verbose: Optional[bool] = True,
                 show_animal_name: bool = False,
                 show_body_part: bool = True,
                 show_bbox: bool = False,
                 data_path: Optional[Union[str, os.PathLike]] = None,
                 save_path: Optional[Union[str, os.PathLike]] = None,
                 bp_colors: Optional[List[Tuple[int, int, int]]] = None,
                 bp_sizes: Optional[List[Union[int]]] = None,
                 border_bg_clr: Tuple[int, int, int] = (0, 0, 0)):

        # Must stay the first statement so ``locals()`` only holds the init arguments.
        log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals()))
        check_float(name=f'{self.__class__.__name__} threshold', value=threshold, min_value=0.0, max_value=1.0)
        check_valid_boolean(value=outside_roi, source=f'{self.__class__.__name__} outside_roi', raise_error=True)
        check_valid_boolean(value=verbose, source=f'{self.__class__.__name__} verbose', raise_error=True)
        check_valid_boolean(value=show_bbox, source=f'{self.__class__.__name__} show_bbox', raise_error=True)
        self.video_meta = get_video_meta_data(video_path=video_path)
        self.video_path = video_path
        self.video_name = self.video_meta['video_name']
        ConfigReader.__init__(self, config_path=config_path)
        if not os.path.isfile(self.roi_coordinates_path):
            raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path)
        self.read_roi_data()
        self.sliced_roi_dict, self.shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name)
        if len(self.shape_names) == 0:
            raise NoROIDataError(msg=f"Cannot plot ROI data for video {self.video_name}. No ROIs defined for this video.")

        # Resolve input data / output video paths.
        if data_path is None:
            data_path = os.path.join(self.outlier_corrected_dir, f'{self.video_name}.{self.file_type}')
        elif not os.path.isfile(data_path):
            raise NoFilesFoundError(msg=f"SIMBA ERROR: Could not find the file at path {data_path}. Make sure the data file exist to create ROI visualizations", source=self.__class__.__name__)
        check_file_exist_and_readable(file_path=data_path)
        if save_path is None:
            save_path = os.path.join(self.project_path, Paths.ROI_ANALYSIS.value, f'{self.video_name}.mp4')
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        else:
            check_if_dir_exists(os.path.dirname(save_path))
        self.save_path, self.data_path = save_path, data_path

        # Validate styling / body-part arguments.
        check_valid_lst(data=body_parts, source=f'{self.__class__.__name__} body-parts', valid_dtypes=(str,), min_len=1)
        check_if_valid_rgb_tuple(source=f'{self.__class__.__name__} border_bg_clr', data=border_bg_clr, raise_error=True)
        check_valid_boolean(value=show_body_part, source=f'{self.__class__.__name__} show_body_part', raise_error=True)
        check_valid_boolean(value=show_animal_name, source=f'{self.__class__.__name__} show_animal_name', raise_error=True)
        if outside_roi:
            self.shape_names.append(OUTSIDE_ROI)
        if len(set(body_parts)) != len(body_parts):
            raise DuplicationError(msg=f'All body-part entries have to be unique. Got {body_parts}', source=self.__class__.__name__)
        for bp in body_parts:
            if bp not in self.body_parts_lst:
                raise BodypartColumnNotFoundError(msg=f'The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}', source=self.__class__.__name__)

        # Compute the ROI entry/exit bouts that are later painted into the per-frame columns.
        self.roi_analyzer = ROIAggregateStatisticsAnalyzer(config_path=self.config_path, data_path=self.data_path, detailed_bout_data=True, threshold=threshold, body_parts=body_parts, outside_rois=outside_roi, verbose=verbose)
        self.roi_analyzer.run()
        if bp_colors is not None:
            check_valid_lst(data=bp_colors, source=f'{self.__class__.__name__} bp_colors', valid_dtypes=(tuple,), exact_len=len(body_parts), raise_error=True)
            for clr in bp_colors:
                check_if_valid_rgb_tuple(source=f'{self.__class__.__name__} bp_colors', data=clr, raise_error=True)
            self.color_lst = bp_colors
        else:
            self.color_lst = create_color_palettes(self.roi_analyzer.animal_cnt, len(body_parts))[0]
        self.bp_sizes = bp_sizes
        try:
            self.detailed_roi_data = pd.concat(self.roi_analyzer.detailed_dfs, axis=0).reset_index(drop=True)
        except ValueError:
            # pd.concat raises ValueError on an empty sequence (presumably: no ROI bouts detected).
            # ``None`` is handled downstream by __get_roi_columns.
            self.detailed_roi_data = None
        self.bp_dict = self.roi_analyzer.bp_dict
        self.animal_names = [self.find_animal_name_from_body_part_name(bp_name=x, bp_dict=self.animal_bp_dict) for x in body_parts]
        self.data_df = read_df(file_path=self.data_path, file_type=self.file_type).fillna(0.0).reset_index(drop=True)

        # One 0/1 "inside ROI" column per animal/shape combination, filled in by __get_roi_columns.
        self.shape_columns = []
        for animal_name, shape_name in itertools.product(self.animal_names, self.shape_names):
            col_name = f"{animal_name}_{shape_name}"
            self.data_df[col_name] = 0
            self.shape_columns.append(col_name)
        self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
        # Non-fatal: a frame-count mismatch only warns; run() stops when either video or data runs out.
        check_video_and_data_frm_count_align(video=self.video_path, data=self.data_df, name=self.video_name, raise_error=False)
        self.cap = cv2.VideoCapture(self.video_path)
        self.threshold, self.body_parts, self.outside_roi, self.verbose, self.border_bg_clr = threshold, body_parts, outside_roi, verbose, border_bg_clr
        self.show_pose, self.show_animal_name, self.show_bbox = show_body_part, show_animal_name, show_bbox
        self.roi_dict_ = get_roi_dict_from_dfs(rectangle_df=self.sliced_roi_dict[Keys.ROI_RECTANGLES.value], circle_df=self.sliced_roi_dict[Keys.ROI_CIRCLES.value], polygon_df=self.sliced_roi_dict[Keys.ROI_POLYGONS.value])

    def __get_circle_sizes(self):
        """Set ``self.circle_sizes``: the radius used to draw each body-part (and the bbox line thickness)."""
        optimal_circle_size = PlottingMixin().get_optimal_circle_size(frame_size=(int(self.video_meta["width"]), int(self.video_meta["height"])), circle_frame_ratio=70)
        if self.bp_sizes is None:
            self.circle_sizes = [optimal_circle_size] * len(self.animal_names)
        else:
            self.circle_sizes = []
            for circle_size in self.bp_sizes:
                # Invalid user entries (non-int or < 1) fall back to the inferred size.
                if not check_int(name='circle_size', value=circle_size, min_value=1, raise_error=False)[0]:
                    self.circle_sizes.append(optimal_circle_size)
                else:
                    self.circle_sizes.append(int(circle_size))

    def __get_roi_columns(self):
        """Set the ``<animal>_<shape>`` column to 1 for every frame within a detected ROI bout (inclusive)."""
        if self.detailed_roi_data is None:
            return
        roi_entries = self.detailed_roi_data[["ANIMAL", "Event", "Start_frame", "End_frame"]].to_dict(orient="records")
        frm_idx = self.data_df.index
        for entry in roi_entries:
            start_frm, end_frm = int(entry["Start_frame"]), int(entry["End_frame"])
            col_name = f'{entry["ANIMAL"]}_{entry["Event"]}'
            # Single .loc assignment: chained ``df[col][mask] = 1`` does not write back under pandas Copy-on-Write.
            self.data_df.loc[(frm_idx >= start_frm) & (frm_idx <= end_frm), col_name] = 1

    def __get_bordered_img_size(self) -> Tuple[int, int]:
        """Return ``(height, width)`` of a video frame after the right-hand statistics border is added."""
        img = read_frm_of_video(video_path=self.video_path, frame_index=0)
        self.base_img = cv2.copyMakeBorder(img, 0, 0, 0, int(self.video_meta["width"]), borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0])
        return (self.base_img.shape[0], self.base_img.shape[1])

    def __get_text_locs(self) -> dict:
        """
        Compute font size and the border pixel locations of every label / value text.

        :return: ``{animal_name: {shape_name: {timer_text, entries_text, timer_text_loc, timer_data_loc, entries_text_loc, entries_data_loc}}}``.
        """
        txt_strs = [f'{animal_name} {shape} entries' for animal_name in self.animal_names for shape in self.shape_names]
        longest_text_str = max(txt_strs, key=len)
        self.font_size, x_spacer, y_spacer = PlottingMixin().get_optimal_font_scales(text=longest_text_str, accepted_px_width=int(self.video_meta["width"] / 1.5), accepted_px_height=int(self.video_meta["height"] / 10), text_thickness=TextOptions.TEXT_THICKNESS.value)
        width, height = self.video_meta["width"], self.video_meta["height"]
        label_x = width + TextOptions.BORDER_BUFFER_X.value
        value_x = int(width + x_spacer + TextOptions.BORDER_BUFFER_X.value)
        row_counter = TextOptions.FIRST_LINE_SPACING.value
        loc_dict = {}
        for animal_name in self.animal_names:
            loc_dict[animal_name] = {}
            for shape in self.shape_names:
                shape_locs = {"timer_text": f"{shape} {animal_name} timer:", "entries_text": f"{shape} {animal_name} entries:"}
                row_y = height - (height + TextOptions.BORDER_BUFFER_Y.value) + y_spacer * row_counter
                shape_locs["timer_text_loc"], shape_locs["timer_data_loc"] = (label_x, row_y), (value_x, row_y)
                row_counter += 1
                row_y = height - (height + TextOptions.BORDER_BUFFER_Y.value) + y_spacer * row_counter
                shape_locs["entries_text_loc"], shape_locs["entries_data_loc"] = (label_x, row_y), (value_x, row_y)
                row_counter += 1
                loc_dict[animal_name][shape] = shape_locs
        return loc_dict

    def __get_counters(self) -> dict:
        """Return zero-initialised timer/entry counters for every animal/shape combination."""
        return {animal_name: {shape: {"timer": 0, "entries": 0, "entry_status": False} for shape in self.shape_names} for animal_name in self.animal_names}

    def __insert_texts(self, roi_dict, img):
        """Write the static "timer:" / "entries:" labels for every animal/ROI into the border of ``img``."""
        for animal_name in self.animal_names:
            for shape_name, shape_data in roi_dict.items():
                img = cv2.putText(img, self.loc_dict[animal_name][shape_name]["timer_text"], self.loc_dict[animal_name][shape_name]["timer_text_loc"], TextOptions.FONT.value, self.font_size, shape_data['Color BGR'], TextOptions.TEXT_THICKNESS.value)
                img = cv2.putText(img, self.loc_dict[animal_name][shape_name]["entries_text"], self.loc_dict[animal_name][shape_name]["entries_text_loc"], TextOptions.FONT.value, self.font_size, shape_data['Color BGR'], TextOptions.TEXT_THICKNESS.value)
            if self.outside_roi:
                img = cv2.putText(img, self.loc_dict[animal_name][OUTSIDE_ROI]["timer_text"], self.loc_dict[animal_name][OUTSIDE_ROI]["timer_text_loc"], TextOptions.FONT.value, self.font_size, TextOptions.WHITE.value, TextOptions.TEXT_THICKNESS.value)
                img = cv2.putText(img, self.loc_dict[animal_name][OUTSIDE_ROI]["entries_text"], self.loc_dict[animal_name][OUTSIDE_ROI]["entries_text_loc"], TextOptions.FONT.value, self.font_size, TextOptions.WHITE.value, TextOptions.TEXT_THICKNESS.value)
        return img

    def __get_cumulative_data(self):
        """Add per-frame cumulative time (seconds) and cumulative entry-count columns for every animal/shape."""
        for animal_name in self.animal_names:
            for shape in self.shape_names:
                in_roi_col = f"{animal_name}_{shape}"
                self.data_df[f"{in_roi_col}_cum_sum_time"] = (self.data_df[in_roi_col].cumsum() / self.video_meta['fps'])
                # Each bout start frame counts as one entry.
                roi_bouts = list(detect_bouts(data_df=self.data_df, target_lst=[in_roi_col], fps=self.video_meta['fps'])["Start_frame"])
                self.data_df[f"{in_roi_col}_entry"] = 0
                self.data_df.loc[roi_bouts, f"{in_roi_col}_entry"] = 1
                self.data_df[f"{in_roi_col}_cum_sum_entries"] = (self.data_df[f"{in_roi_col}_entry"].cumsum())

    def __draw_pose(self, img: np.ndarray, frame_cnt: int) -> np.ndarray:
        """Draw body-part circles, animal names and/or bounding boxes for frame ``frame_cnt`` where the body-part probability >= threshold."""
        for animal_cnt, animal_name in enumerate(self.animal_names):
            x, y, p = self.data_df.loc[frame_cnt, self.bp_dict[animal_name]].fillna(0.0).values.astype(np.float64)
            if p < self.threshold:
                continue
            # Truncate only the pixel coordinates; casting p to int would turn every p < 1.0 into 0.
            x, y = int(x), int(y)
            if self.show_pose:
                img = cv2.circle(img, (x, y), self.circle_sizes[animal_cnt], self.color_lst[animal_cnt], -1)
            if self.show_animal_name:
                img = cv2.putText(img, animal_name, (x, y), self.font, self.font_size, self.color_lst[animal_cnt], TextOptions.TEXT_THICKNESS.value)
            if self.show_bbox:
                x_cols, y_cols = self.animal_bp_dict[animal_name]['X_bps'], self.animal_bp_dict[animal_name]['Y_bps']
                animal_cols = [col for pair in zip(x_cols, y_cols) for col in pair]
                animal_cords = self.data_df.loc[frame_cnt, animal_cols].fillna(0.0).values.astype(np.int32).reshape(-1, 2)
                try:
                    bbox = GeometryMixin().keypoints_to_axis_aligned_bounding_box(keypoints=animal_cords.reshape(-1, len(animal_cords), 2).astype(np.int32))
                    img = cv2.polylines(img, [bbox], True, self.color_lst[animal_cnt], thickness=self.circle_sizes[animal_cnt], lineType=cv2.LINE_AA)
                except Exception:
                    # Deliberate best-effort overlay: a bbox that cannot be computed/drawn for this
                    # frame is skipped rather than aborting the whole video.
                    pass
        return img

    def __insert_data(self, img: np.ndarray, frame_cnt: int) -> np.ndarray:
        """Write the cumulative time (s) and entry count of every animal/ROI at frame ``frame_cnt`` into the border."""
        for animal_name in self.animal_names:
            shape_clrs = [(shape_name, shape_data["Color BGR"]) for shape_name, shape_data in self.roi_dict_.items()]
            if self.outside_roi:
                shape_clrs.append((OUTSIDE_ROI, TextOptions.WHITE.value))
            for shape_name, clr in shape_clrs:
                time_str = str(round(self.data_df.loc[frame_cnt, f"{animal_name}_{shape_name}_cum_sum_time"], 2))
                entries_str = str(int(self.data_df.loc[frame_cnt, f"{animal_name}_{shape_name}_cum_sum_entries"]))
                img = cv2.putText(img, time_str, self.loc_dict[animal_name][shape_name]["timer_data_loc"], self.font, self.font_size, clr, TextOptions.TEXT_THICKNESS.value)
                img = cv2.putText(img, entries_str, self.loc_dict[animal_name][shape_name]["entries_data_loc"], self.font, self.font_size, clr, TextOptions.TEXT_THICKNESS.value)
        return img

    def run(self):
        """
        Create the ROI visualization video and save it at ``self.save_path``.

        Frames are written until the video or the pose-estimation data runs out, whichever happens first.
        The video capture and writer are always released, even if rendering fails.
        """
        video_timer = SimbaTimer(start=True)
        self.__get_circle_sizes()
        self.__get_roi_columns()
        self.border_img_h, self.border_img_w = self.__get_bordered_img_size()
        self.loc_dict = self.__get_text_locs()
        self.cnt_dict = self.__get_counters()
        self.__get_cumulative_data()
        if not self.cap.isOpened():
            # A previous run() released the capture; re-open so run() can be repeated.
            self.cap = cv2.VideoCapture(self.video_path)
        writer = cv2.VideoWriter(self.save_path, self.fourcc, self.video_meta["fps"], (self.border_img_w, self.border_img_h))
        data_frm_cnt, frame_cnt = len(self.data_df), 0
        try:
            while self.cap.isOpened():
                ret, img = self.cap.read()
                if not ret:
                    # End of stream is expected once all frames are read; only warn on an early failure.
                    if frame_cnt < self.video_meta['frame_count']:
                        FrameRangeWarning(msg=f'Could not read frame {frame_cnt} in video {self.video_name}', source=self.__class__.__name__)
                    break
                if frame_cnt >= data_frm_cnt:
                    FrameRangeWarning(msg=f'No pose-estimation data for frame {frame_cnt} in video {self.video_name} (data contains {data_frm_cnt} frames). Video creation stopped at frame {frame_cnt}.', source=self.__class__.__name__)
                    break
                img = cv2.copyMakeBorder(img, 0, 0, 0, int(self.video_meta["width"]), borderType=cv2.BORDER_CONSTANT, value=self.border_bg_clr)
                img = self.__insert_texts(self.roi_dict_, img)
                img = PlottingMixin.roi_dict_onto_img(img=img, roi_dict=self.sliced_roi_dict)
                img = self.__draw_pose(img, frame_cnt)
                img = self.__insert_data(img, frame_cnt)
                writer.write(img)
                if self.verbose:
                    print(f"Frame: {frame_cnt+1} / {self.video_meta['frame_count']}, Video: {self.video_name}.")
                frame_cnt += 1
        finally:
            writer.release()
            self.cap.release()
        video_timer.stop_timer()
        if self.verbose:
            stdout_success(msg=f"Video {self.video_name} created. Video saved at {self.save_path}", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__)
# if __name__ == "__main__":
# test = ROIPlotter(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
# video_path=r"C:\troubleshooting\mitra\project_folder\videos\502_MA141_Gi_Saline_0517.mp4",
# body_parts=['Nose'],
# style_attr={'show_body_part': True, 'show_animal_name': False})
# test.run()
# test = ROIPlotter(config_path=r"C:\troubleshooting\roi_duplicates\project_folder\project_config.ini",
# video_path=r"C:\troubleshooting\roi_duplicates\project_folder\videos\2021-12-21_15-03-57_CO_Trimmed.mp4",
# body_parts=['Snout'],
# style_attr={'show_body_part': True, 'show_animal_name': False})
# test.run()
# if __name__ == "__main__":
# test = ROIPlotMultiprocess(config_path=r"C:\troubleshooting\roi_duplicates\project_folder\project_config.ini",
# video_path=r"C:\troubleshooting\roi_duplicates\project_folder\videos\2021-12-21_15-03-57_CO_Trimmed.mp4",
# body_parts=['Snout'],
# style_attr={'show_body_part': True, 'show_animal_name': False},
# bp_sizes=[20],
# bp_colors=[(155, 255, 243)])
# test.run()
#
# test = ROIPlotter(config_path=r"C:\troubleshooting\mitra\project_folder\project_config.ini",
# video_path=r"C:\troubleshooting\mitra\project_folder\videos\501_MA142_Gi_Saline_0513.mp4",
# body_parts=['Nose'],
# style_attr={'show_body_part': True, 'show_animal_name': False},
# outside_roi=True)
# test.run()
# test = ROIPlot(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/project_config.ini',
# video_path="/Users/simon/Desktop/envs/simba/troubleshooting/mouse_open_field/project_folder/videos/SI_DAY3_308_CD1_PRESENT.mp4",
# body_parts=['Nose'],
# style_attr={'show_body_part': True, 'show_animal_name': True})
# test.run()
# test = ROIPlot(ini_path=r'/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini',
# video_path="termite_test.mp4",
# style_attr={'Show_body_part': True, 'Show_animal_name': True})
# test.insert_data()
# test.visualize_ROI_data()
# test = ROIPlot(ini_path=r'/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini', video_path="termite_test.mp4")
# test.insert_data()
# test.visualize_ROI_data()
# test = ROIPlot(ini_path=r'/Users/simon/Desktop/troubleshooting/train_model_project/project_folder/project_config.ini', video_path=r"Together_1.avi")
# test.insert_data()
# test.visualize_ROI_data()
# test = ROIPlot(ini_path=r'/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
# video_path="Together_1.avi",
# style_attr={'Show_body_part': True, 'Show_animal_name': False},
# body_parts={f'Simon': 'Ear_left_1'})
# test.insert_data()
# test.run()
# test = ROIPlot(ini_path=r'/Users/simon/Desktop/envs/troubleshooting/Termites_5/project_folder/project_config.ini',
# video_path="termite_test.mp4",
# style_attr={'Show_body_part': True, 'Show_animal_name': True},
# body_parts={f'Simon': 'Termite_1_Head_1'})
# test.insert_data()
# test.run()