__author__ = "Simon Nilsson; sronilsson@gmail.com"
import itertools
import os
from typing import Any, Dict, List, Union
import cv2
import numpy as np
from simba.mixins.config_reader import ConfigReader
from simba.mixins.plotting_mixin import PlottingMixin
from simba.roi_tools.ROI_feature_analyzer import ROIFeatureCreator
from simba.utils.checks import (check_file_exist_and_readable,
check_if_keys_exist_in_dict, check_int,
check_valid_array, check_valid_dataframe,
check_valid_lst,
check_video_and_data_frm_count_align)
from simba.utils.data import slice_roi_dict_for_video
from simba.utils.enums import Formats, Keys, TextOptions
from simba.utils.errors import (BodypartColumnNotFoundError, NoFilesFoundError,
ROICoordinatesNotFoundError)
from simba.utils.printing import stdout_success
from simba.utils.read_write import get_fn_ext, get_video_meta_data, read_df
from simba.utils.warnings import DuplicateNamesWarning
# Keys of the user-supplied ``style_attr`` dictionary. Each constant holds the
# literal dictionary key; the visualizer reads the matching value when drawing.
ROI_CENTERS = "roi_centers"  # forwarded as ``show_center`` when ROIs are drawn
ROI_EAR_TAGS = "roi_ear_tags"  # forwarded as ``show_tags`` when ROIs are drawn
DIRECTIONALITY = "directionality"  # toggles the "facing" text and directing lines
DIRECTIONALITY_STYLE = "directionality_style"  # forwarded as ``style`` for directing lines
BORDER_COLOR = "border_color"  # fill value of the text border added to each frame
POSE = "pose_estimation"  # toggles drawing of body-part circles
ANIMAL_NAMES = "animal_names"  # toggles drawing of animal name labels

# Every key that must be present in ``style_attr``; validated in the class constructor.
STYLE_KEYS = [
    ROI_CENTERS,
    ROI_EAR_TAGS,
    DIRECTIONALITY,
    BORDER_COLOR,
    POSE,
    DIRECTIONALITY_STYLE,
    ANIMAL_NAMES,
]
class ROIfeatureVisualizer(ConfigReader):
    """
    Visualizing features that depend on the relationships between the location of the animals and user-defined
    ROIs. E.g., distances to centroids of ROIs, if animals are directing towards ROIs, and if animals are within ROIs.

    .. note::
       For improved run-time, see :meth:`simba.ROI_feature_visualizer_mp.ROIfeatureVisualizerMultiprocess` for multiprocess class.
       `Tutorials <https://github.com/sgoldenlab/simba/blob/master/docs/ROI_tutorial.md#part-5-visualizing-roi-features>`__.

    .. image:: _static/img/roi_visualize.png
       :alt: Roi visualize
       :width: 400
       :align: center

    .. image:: _static/img/ROIfeatureVisualizer_1.png
       :alt: ROIfeature Visualizer 1
       :width: 700
       :align: center

    .. image:: _static/img/ROIfeatureVisualizer_2.png
       :alt: ROIfeature Visualizer 2
       :width: 700
       :align: center

    :param Union[str, os.PathLike] config_path: Path to SimBA project config file in Configparser format
    :param Union[str, os.PathLike] video_path: Path to video file to overlay ROI features on.
    :param List[str] body_parts: List of body-parts to use as proxy for animal location(s).
    :param Dict[str, Any] style_attr: User-defined styles (sizes, colors etc.). Must contain every key in ``STYLE_KEYS``.
    :raises ROICoordinatesNotFoundError: If the project ROI coordinates file does not exist.
    :raises NoFilesFoundError: If the pose-estimation data file for the video does not exist.
    :raises BodypartColumnNotFoundError: If any entry of ``body_parts`` is not a body-part of the project.

    :example:
    >>> style_attr = {'roi_centers': True, 'roi_ear_tags': True, 'directionality': True, 'directionality_style': 'funnel', 'border_color': (0, 0, 0), 'pose_estimation': True, 'animal_names': True}
    >>> test = ROIfeatureVisualizer(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini', video_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/videos/2022-06-20_NOB_DOT_4.mp4', style_attr=style_attr, body_parts=['Nose'])
    >>> test.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 video_path: Union[str, os.PathLike],
                 body_parts: List[str],
                 style_attr: Dict[str, Any]):

        check_file_exist_and_readable(file_path=config_path)
        check_file_exist_and_readable(file_path=video_path)
        check_if_keys_exist_in_dict(data=style_attr, key=STYLE_KEYS, name=f"{self.__class__.__name__} style_attr")
        _, self.video_name, _ = get_fn_ext(video_path)
        ConfigReader.__init__(self, config_path=config_path)
        if not os.path.isfile(self.roi_coordinates_path):
            raise ROICoordinatesNotFoundError(expected_file_path=self.roi_coordinates_path)
        self.read_roi_data()
        # NOTE(review): the ``shape_names`` returned here is not stored. ``self.shape_names`` used by the
        # drawing methods is never assigned in this class, so it presumably comes from ``read_roi_data()``;
        # confirm that it only holds ROI names belonging to this video.
        self.roi_dict, shape_names = slice_roi_dict_for_video(data=self.roi_dict, video_name=self.video_name)
        self.data_path = os.path.join(self.outlier_corrected_dir, f"{self.video_name}.{self.file_type}")
        if not os.path.isfile(self.data_path):
            raise NoFilesFoundError(msg=f"SIMBA ERROR: Could not find the file at path {self.data_path}. Make sure the data file exist to create ROI visualizations", source=self.__class__.__name__)
        if not os.path.exists(self.roi_features_save_dir):
            os.makedirs(self.roi_features_save_dir)
        self.save_path = os.path.join(self.roi_features_save_dir, f"{self.video_name}.mp4")
        check_valid_lst(data=body_parts, source=f"{self.__class__.__name__} body-parts", valid_dtypes=(str,), min_len=1)
        for bp in body_parts:
            if bp not in self.body_parts_lst:
                raise BodypartColumnNotFoundError(msg=f"The body-part {bp} is not a valid body-part in the SimBA project. Options: {self.body_parts_lst}", source=self.__class__.__name__)
        # Compute the ROI features (in-zone, distance, facing) that get rendered per frame in run().
        self.roi_feature_creator = ROIFeatureCreator(config_path=config_path, body_parts=body_parts, append_data=False, data_path=self.data_path)
        self.roi_feature_creator.run()
        self.bp_lk = self.roi_feature_creator.bp_lk
        self.animal_bp_names = [f"{v[0]} {v[1]}" for v in self.bp_lk.values()]
        self.animal_names = [v[0] for v in self.bp_lk.values()]
        self.video_meta_data = get_video_meta_data(video_path, fps_as_int=False)
        self.fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
        self.cap = cv2.VideoCapture(video_path)
        # raise_error=False: a frame-count mismatch between video and data is tolerated here.
        check_video_and_data_frm_count_align(video=video_path, data=self.data_path, name=video_path, raise_error=False)
        self.style_attr = style_attr
        self.direct_viable = self.roi_feature_creator.roi_directing_viable
        self.data_df = read_df(file_path=self.data_path, file_type=self.file_type).reset_index(drop=True)
        self.shape_dicts = self.__create_shape_dicts()
        self.directing_df = self.roi_feature_creator.dr

    def __calc_text_locs(self) -> None:
        """
        Compute the font scale, circle size and the pixel location of every text label/value.

        Sets ``self.font_size``, ``self.x_scaler``, ``self.y_scaler``, ``self.circle_size`` and fills
        ``self.loc_dict``. Entries keyed by ``"<animal> <body-part>"`` hold in-zone and distance locations
        per shape. Entries keyed by ``"<animal>"`` hold the "facing" locations and are only created when
        directionality is viable and enabled in ``style_attr``.
        """
        width, height = self.video_meta_data["width"], self.video_meta_data["height"]
        add_spacer = TextOptions.FIRST_LINE_SPACING.value
        self.loc_dict = {}
        # The longest label that will be printed decides the font scale.
        txt_strs = [animal_bp_name + ' ' + shape + ' center distance' for animal_bp_name in self.animal_bp_names for shape in self.shape_names]
        longest_text_str = str(max(txt_strs, key=len))
        plotter = PlottingMixin()
        self.font_size, self.x_scaler, self.y_scaler = plotter.get_optimal_font_scales(text=longest_text_str, accepted_px_width=int(width / 2), accepted_px_height=int(height / 15), text_thickness=3)
        self.circle_size = plotter.get_optimal_circle_size(frame_size=(int(width), int(height)), circle_frame_ratio=100)

        # Labels start just right of the original frame (inside the added border); values sit x_scaler further right.
        label_x = int(width + TextOptions.BORDER_BUFFER_X.value)
        value_x = int(width + TextOptions.BORDER_BUFFER_X.value + self.x_scaler)

        def row_y(spacer, buffer=10):
            # Same arithmetic as the original layout: one text row per spacer step.
            return height - (height + buffer) + self.y_scaler * spacer

        for animal, animal_bp, _ in self.bp_lk.values():
            animal_name = f"{animal} {animal_bp}"
            self.loc_dict[animal_name] = {}
            self.loc_dict[animal] = {}
            for shape in self.shape_names:
                shape_locs = {}
                shape_locs["in_zone_text"] = f"{shape} {animal_name} in zone"
                shape_locs["distance_text"] = f"{shape} {animal_name} distance"
                shape_locs["in_zone_text_loc"] = (label_x, row_y(add_spacer))
                # Kept as in the original layout: this value row uses BORDER_BUFFER_Y, the others a fixed 10.
                shape_locs["in_zone_data_loc"] = (value_x, row_y(add_spacer, TextOptions.BORDER_BUFFER_Y.value))
                add_spacer += 1
                shape_locs["distance_text_loc"] = (label_x, row_y(add_spacer))
                shape_locs["distance_data_loc"] = (value_x, row_y(add_spacer))
                add_spacer += 1
                self.loc_dict[animal_name][shape] = shape_locs
                if self.direct_viable and self.style_attr[DIRECTIONALITY]:
                    self.loc_dict[animal][shape] = {
                        "directing_text": f"{shape} {animal} facing",
                        "directing_text_loc": (label_x, row_y(add_spacer)),
                        "directing_data_loc": (value_x, row_y(add_spacer)),
                    }
                    add_spacer += 1

    def __create_shape_dicts(self) -> Dict[str, Any]:
        """
        Flatten ``self.roi_dict`` into a single ``{ROI name: attribute dict}`` lookup.

        Within each shape type, duplicated ROI names are dropped (first kept) and a
        ``DuplicateNamesWarning`` is issued.

        :return: Dictionary mapping each ROI name to its row of attributes.
        """
        shape_dicts = {}
        for shape, df in self.roi_dict.items():
            if not df["Name"].is_unique:
                df = df.drop_duplicates(subset=["Name"], keep="first")
                DuplicateNamesWarning(msg=f'Some of your ROIs with the same shape ({shape}) has the same names for video {self.video_name}. E.g., you have two rectangles named "My rectangle". SimBA prefers ROI shapes with unique names. SimBA will keep one of the unique shape names and drop the rest.', source=self.__class__.__name__)
            shape_dicts.update(df.set_index("Name").to_dict(orient="index"))
        return shape_dicts

    def __insert_texts(self, shape_df) -> None:
        """
        Draw the static text labels (in zone / distance / facing) for every ROI in ``shape_df``.

        :param shape_df: ROI definitions of one shape type; iterated row-wise and read for ``Name`` and ``Color BGR``.
        """
        # The "facing" locations only exist in loc_dict when directionality is both viable and enabled,
        # so the label must be gated on the same condition to avoid a KeyError.
        show_directing = self.direct_viable and self.style_attr[DIRECTIONALITY]
        for animal, animal_bp, _ in self.bp_lk.values():
            animal_name = f"{animal} {animal_bp}"
            for _, shape in shape_df.iterrows():
                shape_name, shape_color = shape["Name"], shape["Color BGR"]
                locs = self.loc_dict[animal_name][shape_name]
                cv2.putText(self.img_w_border, locs["in_zone_text"], locs["in_zone_text_loc"], self.font, self.font_size, shape_color, TextOptions.TEXT_THICKNESS.value)
                cv2.putText(self.img_w_border, locs["distance_text"], locs["distance_text_loc"], self.font, self.font_size, shape_color, TextOptions.TEXT_THICKNESS.value)
                if show_directing:
                    directing_locs = self.loc_dict[animal][shape_name]
                    cv2.putText(self.img_w_border, directing_locs["directing_text"], directing_locs["directing_text_loc"], self.font, self.font_size, shape_color, TextOptions.TEXT_THICKNESS.value)

    def run(self) -> None:
        """
        Render the ROI feature video frame-by-frame and save it at ``self.save_path``.

        Each frame is widened with a border on its right side that holds the feature text. Rendering
        stops at the end of the video or at the last row of pose-estimation data, whichever comes first.
        The video capture and writer are released even if an error occurs while rendering.

        :raises NoFilesFoundError: If no frame could be rendered.
        """
        self.frame_cnt = 0
        self.writer = None
        n_data_frames = len(self.data_df)
        try:
            while self.cap.isOpened():
                ret, self.img = self.cap.read()
                # Stop at the end of the video, or when the video holds more frames than the data.
                if not ret or self.frame_cnt >= n_data_frames:
                    break
                self.img_w_border = cv2.copyMakeBorder(self.img, 0, 0, 0, self.video_meta_data["width"], borderType=cv2.BORDER_CONSTANT, value=self.style_attr[BORDER_COLOR])
                if self.frame_cnt == 0:
                    # The output size is only known once the first bordered frame exists.
                    self.img_w_border_h, self.img_w_border_w = (self.img_w_border.shape[0], self.img_w_border.shape[1])
                    self.__calc_text_locs()
                    self.writer = cv2.VideoWriter(self.save_path, self.fourcc, self.video_meta_data["fps"], (self.img_w_border_w, self.img_w_border_h))
                for shape_type in (Keys.ROI_RECTANGLES.value, Keys.ROI_CIRCLES.value, Keys.ROI_POLYGONS.value):
                    self.__insert_texts(self.roi_dict[shape_type])
                if self.style_attr[POSE]:
                    for animal_name, bp_data in self.animal_bp_dict.items():
                        for bp_cnt, bp in enumerate(zip(bp_data["X_bps"], bp_data["Y_bps"])):
                            bp_cords = self.data_df.loc[self.frame_cnt, list(bp)].values.astype(np.int64)
                            cv2.circle(self.img_w_border, (bp_cords[0], bp_cords[1]), self.circle_size, self.animal_bp_dict[animal_name]["colors"][bp_cnt], -1)
                if self.style_attr[ANIMAL_NAMES]:
                    for animal_name, bp_data in self.animal_bp_dict.items():
                        # The name label is anchored at the animal's last listed body-part.
                        headers = [bp_data["X_bps"][-1], bp_data["Y_bps"][-1]]
                        bp_cords = self.data_df.loc[self.frame_cnt, headers].values.astype(np.int64)
                        cv2.putText(self.img_w_border, animal_name, (bp_cords[0], bp_cords[1]), self.font, self.font_size, self.animal_bp_dict[animal_name]["colors"][0], TextOptions.TEXT_THICKNESS.value)
                self.img_w_border = PlottingMixin.roi_dict_onto_img(img=self.img_w_border, roi_dict=self.roi_dict, circle_size=self.circle_size, show_tags=self.style_attr[ROI_EAR_TAGS], show_center=self.style_attr[ROI_CENTERS])
                for animal_name, shape_name in itertools.product(self.animal_bp_names, self.shape_names):
                    in_zone_col_name = f"{shape_name} {animal_name} in zone"
                    distance_col_name = f"{shape_name} {animal_name} distance"
                    in_zone_value = str(bool(self.roi_feature_creator.out_df.loc[self.frame_cnt, in_zone_col_name]))
                    distance_value = round(self.roi_feature_creator.out_df.loc[self.frame_cnt, distance_col_name], 2)
                    shape_color = self.shape_dicts[shape_name]["Color BGR"]
                    cv2.putText(self.img_w_border, in_zone_value, self.loc_dict[animal_name][shape_name]["in_zone_data_loc"], self.font, self.font_size, shape_color, TextOptions.TEXT_THICKNESS.value)
                    cv2.putText(self.img_w_border, str(distance_value), self.loc_dict[animal_name][shape_name]["distance_data_loc"], self.font, self.font_size, shape_color, TextOptions.TEXT_THICKNESS.value)
                if self.direct_viable and self.style_attr[DIRECTIONALITY]:
                    for animal_name, shape_name in itertools.product(self.animal_names, self.shape_names):
                        facing_col_name = f"{shape_name} {animal_name} facing"
                        facing_value = self.roi_feature_creator.out_df.loc[self.frame_cnt, facing_col_name]
                        cv2.putText(self.img_w_border, str(bool(facing_value)), self.loc_dict[animal_name][shape_name]["directing_data_loc"], self.font, self.font_size, self.shape_dicts[shape_name]["Color BGR"], TextOptions.TEXT_THICKNESS.value)
                        if facing_value:
                            self.img_w_border = PlottingMixin.insert_directing_line(directing_df=self.directing_df,
                                                                                    img=self.img_w_border,
                                                                                    shape_name=shape_name,
                                                                                    animal_name=animal_name,
                                                                                    frame_id=self.frame_cnt,
                                                                                    color=self.shape_dicts[shape_name]["Color BGR"],
                                                                                    thickness=self.shape_dicts[shape_name]["Thickness"],
                                                                                    style=self.style_attr[DIRECTIONALITY_STYLE])
                self.frame_cnt += 1
                self.writer.write(np.uint8(self.img_w_border))
                print(f"Frame: {self.frame_cnt} / {self.video_meta_data['frame_count']}. Video: {self.video_name} ...")
        finally:
            # Always free the video handles, also when rendering fails part-way through.
            self.cap.release()
            if self.writer is not None:
                self.writer.release()
        if self.writer is None:
            # The writer is only created once a first frame has been rendered.
            raise NoFilesFoundError(msg=f"SIMBA ERROR: No frames could be rendered for video {self.video_name}. Make sure the video can be read and that the data file at {self.data_path} is not empty", source=self.__class__.__name__)
        self.timer.stop_timer()
        stdout_success(
            msg=f"Feature video {self.video_name} saved in {self.save_path} directory ...",
            elapsed_time=self.timer.elapsed_time_str,
        )
# style_attr = {'roi_centers': True, 'roi_ear_tags': True, 'directionality': True, 'directionality_style': 'lines', 'border_color': (0, 0, 0), 'pose_estimation': True, 'animal_names': True}
# test = ROIfeatureVisualizer(config_path=r"C:\troubleshooting\roi_duplicates\project_folder\project_config.ini",
# video_path=r"C:\troubleshooting\roi_duplicates\project_folder\videos\2021-12-21_15-03-57_CO_Trimmed.mp4",
# style_attr=style_attr,
# body_parts=['Butt/Proximal Tail']) #'Butt/Proximal Tail'
# test.run()
# style_attr = {'roi_centers': True, 'roi_ear_tags': True, 'directionality': True, 'directionality_style': 'lines', 'border_color': (0, 0, 0), 'pose_estimation': True, 'animal_names': True}
# test = ROIfeatureVisualizer(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/project_config.ini',
# video_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/videos/2022-06-20_NOB_DOT_4.mp4',
# style_attr=style_attr,
# body_parts=['Nose'])
# test.run()
# style_attr = {'roi_centers': True, 'roi_ear_tags': True, 'directionality': True, 'directionality_style': 'funnel', 'border_color': (0, 128, 0), 'pose_estimation': True, 'animal_names': True}
# test = ROIfeatureVisualizer(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
# video_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/videos/Together_1.avi',
# style_attr=style_attr,
# body_parts=['Nose_1', 'Nose_2'])
# test.run()
# style_attr = {'ROI_centers': True, 'ROI_ear_tags': True, 'Directionality': True, 'Directionality_style': 'Line', 'Border_color': (0, 128, 0), 'Pose_estimation': True}
# test = ROIfeatureVisualizer(config_path='/Users/simon/Desktop/envs/simba_dev/tests/test_data/mouse_open_field/project_folder/project_config.ini', video_name='Video1.mp4', style_attr=style_attr)
# test.run()
# test = ROIfeatureVisualizer(config_path='/Users/simon/Desktop/train_model_project/project_folder/project_config.ini', video_name='Together_1.avi')
# test.run()
# test.save_new_features_files()