import functools
import multiprocessing
import os
import random
from copy import deepcopy
from typing import List, Optional, Tuple, Union
import cv2
import numpy as np
import pandas as pd
from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.checks import (check_file_exist_and_readable, check_float,
check_if_dir_exists, check_int,
check_valid_boolean, check_valid_cpu_pool,
check_valid_dataframe, check_valid_lst,
check_valid_tuple)
from simba.utils.data import (create_color_palette, get_cpu_pool,
terminate_cpu_pool)
from simba.utils.enums import Defaults, Options
from simba.utils.errors import (CountError, DataHeaderError, FrameRangeError,
InvalidInputError, NoDataError)
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (concatenate_videos_in_folder,
create_directory, find_core_cnt,
find_files_of_filetypes_in_directory,
get_current_time, get_fn_ext,
get_video_meta_data, read_frm_of_video,
recursive_file_search, remove_a_folder)
# Column names used in the YOLO pose CSV files handled by this module.
FRAME, CLASS_ID, CLASS_NAME = 'FRAME', 'CLASS_ID', 'CLASS_NAME'
CONFIDENCE, TRACK = 'CONFIDENCE', 'TRACK'

# Bounding-box corner coordinates: X1, Y1, ..., X4, Y4.
BOX_CORD_FIELDS = [f'{axis}{corner}' for corner in range(1, 5) for axis in ('X', 'Y')]

# Columns every data file must hold; remaining columns are treated as key-point data.
EXPECTED_COLS = [FRAME, CLASS_ID, CLASS_NAME, CONFIDENCE] + BOX_CORD_FIELDS
def _yolo_keypoint_visualizer(frm_ids: Tuple[int, np.ndarray],
                              data: pd.DataFrame,
                              threshold: float,
                              video_path: str,
                              save_dir: str,
                              circle_size: int,
                              verbose: bool,
                              thickness: int,
                              palettes: dict,
                              bbox: bool,
                              skeleton: list):
    """
    Multiprocessing worker: render one batch of frames with YOLO pose annotations to ``{save_dir}/{batch_id}.mp4``.

    :param Tuple[int, np.ndarray] frm_ids: Two-part value holding the batch id and the frame indexes of the batch. Every frame from the first to the last index is rendered.
    :param pd.DataFrame data: Pose data holding the ``EXPECTED_COLS`` columns. A ``TRACK`` column is dropped if present. All remaining columns are reshaped as (x, y, p) triplets, one triplet per key-point.
    :param float threshold: Detections, key-points and skeleton edges are drawn only when their confidence is strictly greater than this value.
    :param str video_path: Path to the video the frames are read from.
    :param str save_dir: Directory where the batch video is written.
    :param int circle_size: Radius of the key-point circles. Skeleton edges are drawn with thickness ``max(1, int(circle_size / 2))``.
    :param bool verbose: If True, prints one progress message per frame.
    :param int thickness: Line thickness of the bounding-box polygon.
    :param dict palettes: Maps integer class id to a sequence of colors. The first color draws the bounding box, color ``i + 1`` draws key-point ``i``.
    :param bool bbox: If True, draws the bounding-box polygon of each detection.
    :param list skeleton: Optional pairs of key-point names. The columns ``{name}_X``, ``{name}_Y`` and ``{name}_P`` are read for each name.
    :return: The batch id, so the parent process can report which batch finished.
    """
    batch_id, frame_rng = frm_ids[0], frm_ids[1]
    if len(frame_rng) == 0:
        # Nothing to render for this batch (more batches than frames): avoid indexing an empty range.
        return batch_id
    start_frm, end_frm = int(frame_rng[0]), int(frame_rng[-1])
    video_meta_data = get_video_meta_data(video_path=video_path, fps_as_int=False)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    video_save_path = os.path.join(save_dir, f'{batch_id}.mp4')
    video_writer = cv2.VideoWriter(video_save_path, fourcc, video_meta_data["fps"], (video_meta_data["width"], video_meta_data["height"]))
    if TRACK in data.columns:
        data = data.drop([TRACK], axis=1)
    # Select the drawable detections of this batch once, instead of re-scanning the full dataframe for every frame.
    in_batch = data[FRAME].isin(np.arange(start_frm, end_frm + 1))
    batch_data = data[in_batch & (data[CONFIDENCE] > threshold)]
    frm_groups = {frm: frm_df for frm, frm_df in batch_data.groupby(FRAME)}
    # The plotting helper and the edge thickness do not change between frames.
    plotter = PlottingMixin() if skeleton is not None else None
    edge_thickness = int(max(1, int(circle_size / 2)))
    try:
        for current_frm in range(start_frm, end_frm + 1):
            if verbose: stdout_information(f'Processing frame {current_frm}/{video_meta_data["frame_count"]} (batch: {batch_id}, video: {video_meta_data["video_name"]})...')
            img = read_frm_of_video(video_path=video_path, frame_index=current_frm)
            frm_data = frm_groups.get(current_frm)
            if frm_data is not None:
                for _, row_data in frm_data.iterrows():
                    clrs = np.array(palettes[int(row_data[CLASS_ID])]).astype(np.int32)
                    bbox_cords = row_data[BOX_CORD_FIELDS].values.astype(np.int32).reshape(-1, 2)
                    # Everything that is not an expected column is key-point data laid out as (x, y, p) triplets.
                    kp_values = row_data.drop(EXPECTED_COLS).values
                    kp_coords = kp_values.astype(np.int32).reshape(-1, 3)[:, :-1]
                    p_cords = kp_values.astype(np.float64).reshape(-1, 3)[:, -1]
                    if bbox:
                        bbox_clr = tuple(int(c) for c in clrs[0])
                        img = cv2.polylines(img, [bbox_cords], True, bbox_clr, thickness=thickness, lineType=cv2.LINE_AA)
                    for kp_cnt, kp in enumerate(kp_coords):
                        if p_cords[kp_cnt] > threshold:
                            kp_clr = tuple(int(c) for c in clrs[kp_cnt + 1])
                            img = cv2.circle(img, (int(kp[0]), int(kp[1])), circle_size, kp_clr, -1)
                    if skeleton is not None:
                        for (kp1, kp2) in skeleton:
                            if (row_data[f'{kp1}_P'] > threshold) and (row_data[f'{kp2}_P'] > threshold):
                                pos_1 = np.array([row_data[[f'{kp1}_X', f'{kp1}_Y']].values.astype(np.int32)])
                                pos_2 = np.array([row_data[[f'{kp2}_X', f'{kp2}_Y']].values.astype(np.int32)])
                                img = plotter.draw_lines_on_img(img=img, start_positions=pos_1, end_positions=pos_2, color=(105, 105, 105), highlight_endpoint=False, thickness=edge_thickness)
            img = img.astype(np.uint8)
            video_writer.write(img)
    finally:
        # Always release the writer so the batch file is closed even if a frame fails.
        video_writer.release()
    return batch_id
class YOLOPoseVisualizer():
    """
    Visualizes YOLO-based keypoint pose estimation data on video frames and creates an annotated output video.

    This class takes keypoint data (CSV) and overlays it onto the corresponding video using color-coded keypoints
    and optional filtering. The result is saved as a new annotated video, and supports multicore parallel rendering
    for efficient processing of long videos.

    .. seealso::
       To create YOLO pose data, see :func:`~simba.bounding_box_tools.yolo.yolo_pose_inference.YOLOPoseInference`
       To fit YOLO model, see :func:`~simba.bounding_box_tools.yolo.yolo_fit.FitYolo`
       For instructions, see `YOLO Pose Estimation Visualization Documentation <https://github.com/sgoldenlab/simba/blob/master/docs/yolo_pose_plot.md>`_.

    .. video:: _static/img/YOLOPoseVisualizer.webm
       :width: 900
       :loop:
       :muted:
       :autoplay:
       :align: center

    .. video:: _static/img/T1.webm
       :width: 1000
       :autoplay:
       :loop:
       :muted:
       :align: center

    :param Union[str, os.PathLike] data_path: Path to the CSV file containing keypoint data, or folder containing keypoint data (output from YOLO pose inference).
    :param Union[str, os.PathLike] video_path: Path to the original input video, or folder containing original videos, to overlay keypoints on. Every data file must have a video with the same file name (without extension).
    :param Union[str, os.PathLike] save_dir: Existing directory to save the resulting annotated video(s) in.
    :param Optional[Union[str, Tuple[str, ...]]] palettes: Name(s) of categorical color palettes used to draw keypoints per detected class. A single string is treated as a one-palette tuple. If the number of palettes differs from the number of classes in a data file, default categorical palettes are used for that file. Defaults to 'Set1'.
    :param Optional[int] core_cnt: Number of CPU cores to use for parallel rendering. Defaults to -1 (all available cores).
    :param float threshold: Confidence threshold (0.0-1.0) for rendering bounding boxes, keypoints, and skeleton edges. Only entries with confidence strictly greater than the threshold are drawn.
    :param Optional[int] thickness: Thickness of bounding boxes. If None, the circle size is used.
    :param Optional[int] circle_size: Radius of keypoint circles. If None, computed from frame dimensions.
    :param Optional[bool] verbose: Set True to enable progress logging.
    :param Optional[bool] bbox: Set False to disable rendering of bounding boxes around detections.
    :param Optional[List[Tuple[str, str]]] skeleton: Iterable of keypoint name pairs defining skeleton edges to render when both keypoints exceed ``threshold``. Each name requires ``{name}_X``, ``{name}_Y`` and ``{name}_P`` columns in the data.
    :param Optional[bool] recursive: If True, search data and video directories recursively; otherwise only the top level is scanned.
    :param Optional[multiprocessing.Pool] pool: Optional existing multiprocessing pool to render with. If None, ``run`` creates a pool and terminates it when finished; a passed pool is left open.
    :param Optional[int] sample_n: Randomly sample ``sample_n`` data files to visualize (only applied when ``data_path`` is a directory). If None, visualize all detected files.

    :example:
    >>> video_path = r"/mnt/c/troubleshooting/mitra/project_folder/videos/501_MA142_Gi_CNO_0521.mp4"
    >>> data_path = "/mnt/c/troubleshooting/mitra/yolo_pose/501_MA142_Gi_CNO_0521.csv"
    >>> kp_vis = YOLOPoseVisualizer(data_path=data_path,
    >>>                             video_path=video_path,
    >>>                             save_dir='/mnt/c/troubleshooting/mitra/yolo_pose/',
    >>>                             core_cnt=18)
    >>> kp_vis.run()
    """

    def __init__(self,
                 data_path: Union[str, os.PathLike],
                 video_path: Union[str, os.PathLike],
                 save_dir: Union[str, os.PathLike],
                 palettes: Optional[Union[str, Tuple[str, ...]]] = 'Set1',
                 core_cnt: int = -1,
                 threshold: float = 0.0,
                 thickness: Optional[int] = None,
                 circle_size: Optional[int] = None,
                 verbose: bool = True,
                 bbox: Optional[bool] = True,
                 skeleton: List[Tuple[str, str]] = None,
                 recursive: Optional[bool] = False,
                 pool: Optional[multiprocessing.Pool] = None,
                 sample_n: Optional[int] = None):

        check_valid_boolean(value=[recursive], source=f'{self.__class__.__name__} recursive', raise_error=True)
        # Resolve the data file(s) as a {file name: file path} dictionary.
        if os.path.isdir(data_path):
            if not recursive:
                data_paths = find_files_of_filetypes_in_directory(directory=data_path, extensions=['.csv'], as_dict=True, raise_error=True)
            else:
                data_paths = recursive_file_search(directory=data_path, extensions='csv', as_dict=True, raise_error=True)
            if sample_n is not None:
                check_int(name=f'{self.__class__.__name__} sample', min_value=1, raise_error=True, value=sample_n)
                sample_n = min(sample_n, len(data_paths))
                data_paths = dict(random.sample(list(data_paths.items()), sample_n))
        elif os.path.isfile(data_path):
            check_file_exist_and_readable(file_path=data_path)
            data_paths = {get_fn_ext(filepath=data_path)[1]: data_path}
        else:
            raise InvalidInputError(msg=f'{data_path} is not a valid directory path or file path', source=self.__class__.__name__)
        # Resolve the video file(s) the same way, so data and videos can be matched on file name.
        if os.path.isdir(video_path):
            if not recursive:
                video_paths = find_files_of_filetypes_in_directory(directory=video_path, extensions=Options.ALL_VIDEO_FORMAT_OPTIONS.value, as_dict=True, raise_error=True)
            else:
                video_paths = recursive_file_search(directory=video_path, extensions=Options.ALL_VIDEO_FORMAT_OPTIONS_2.value, as_dict=True, raise_error=True)
        elif os.path.isfile(video_path):
            check_file_exist_and_readable(file_path=video_path)
            video_paths = {get_fn_ext(filepath=video_path)[1]: video_path}
        else:
            raise InvalidInputError(msg=f'{video_path} is not a valid directory path or file path', source=self.__class__.__name__)
        missing_video_paths = [x for x in data_paths if x not in video_paths]
        if len(missing_video_paths) > 0:
            raise NoDataError(msg=f'The data file(s) {missing_video_paths} does not have a representative video.', source=self.__class__.__name__)
        if pool is not None: check_valid_cpu_pool(value=pool, raise_error=True)
        self.data_paths, self.video_paths = data_paths, video_paths
        check_int(name=f'{self.__class__.__name__} core_cnt', value=core_cnt, min_value=-1, unaccepted_vals=[0])
        check_float(name=f'{self.__class__.__name__} threshold', value=threshold, min_value=0.0, max_value=1.0)
        if circle_size is not None:
            check_int(name=f'{self.__class__.__name__} circle_size', value=circle_size, min_value=1)
        if thickness is not None:
            check_int(name=f'{self.__class__.__name__} thickness', value=thickness, min_value=1)
        # Cap the core count at what the machine offers; -1 means "use all".
        available_cores = find_core_cnt()[0]
        self.core_cnt = available_cores if (core_cnt == -1 or core_cnt > available_cores) else core_cnt
        check_if_dir_exists(in_dir=save_dir)
        check_valid_boolean(value=[verbose], source=f'{self.__class__.__name__} verbose', raise_error=True)
        check_valid_boolean(value=[bbox], source=f'{self.__class__.__name__} bbox', raise_error=True)
        if isinstance(palettes, str):
            palettes = (palettes,)
        else:
            check_valid_tuple(x=palettes, source=f'{self.__class__.__name__} palettes', minimum_length=1, valid_dtypes=(str,), accepted_values=Options.PALETTE_OPTIONS_CATEGORICAL.value)
        # ``palette`` is kept next to ``palettes`` since the original class exposed both attribute names.
        self.save_dir, self.verbose, self.palette, self.palettes = save_dir, verbose, palettes, palettes
        self.threshold, self.circle_size, self.thickness, self.bbox = threshold, circle_size, thickness, bbox
        self.skeleton, self.pool = skeleton, pool
        self.timer = SimbaTimer(start=True)

    def run(self):
        """
        Render one annotated video per data file and save it as ``{save_dir}/{video name}.mp4``.

        For each data file, the frames are split into batches that are rendered in parallel into a temporary
        directory, and the batch videos are then concatenated into the final video.

        :raises CountError: If a data file holds more classes than there are categorical palettes available.
        :raises DataHeaderError: If a key-point column required by ``skeleton`` is missing from a data file.
        :raises FrameRangeError: If the number of unique frames in a data file differs from the video frame count.
        """
        # The skeleton structure does not depend on the data files: validate it once, before any pool is created.
        if self.skeleton is not None:
            check_valid_lst(data=self.skeleton, source=f'{self.__class__.__name__} skeleton', valid_dtypes=(list, tuple,), min_len=1, raise_error=True)
            for i in self.skeleton:
                check_valid_tuple(x=i, source=f'{self.__class__.__name__} {i}', accepted_lengths=(2,), valid_dtypes=(str,))
        # Only terminate the pool if it was created here; a pool passed by the caller is left open.
        pool_terminate_flag = self.pool is None
        pool = get_cpu_pool(core_cnt=self.core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value, source=self.__class__.__name__, verbose=self.verbose) if self.pool is None else self.pool
        video_total = len(self.data_paths)
        try:
            for video_cnt, (video_name, data_path) in enumerate(self.data_paths.items()):
                video_timer = SimbaTimer(start=True)
                self.video_temp_dir = os.path.join(self.save_dir, video_name, "temp")
                if os.path.isdir(self.video_temp_dir):
                    remove_a_folder(folder_dir=self.video_temp_dir)
                create_directory(paths=self.video_temp_dir)
                self.data_df = pd.read_csv(data_path, index_col=0)
                self.video_meta_data = get_video_meta_data(video_path=self.video_paths[video_name])
                check_valid_dataframe(df=self.data_df, source=f'{self.__class__.__name__} {data_path}', required_fields=EXPECTED_COLS)
                self.df_frm_cnt = np.unique(self.data_df[FRAME].values).shape[0]
                self.classes = np.unique(self.data_df[CLASS_NAME].values)
                self.save_path = os.path.join(self.save_dir, f'{video_name}.mp4')
                # Choose the palettes for this video in a local variable, so a fallback chosen for one
                # video does not replace the palettes the user passed for the following videos.
                palettes = self.palettes
                if len(self.classes) != len(palettes):
                    if len(self.classes) > len(Options.PALETTE_OPTIONS_CATEGORICAL.value):
                        raise CountError(msg=f'There are more classes ({len(self.classes)}) than available color palettes ({len(Options.PALETTE_OPTIONS_CATEGORICAL.value)})', source=self.__class__.__name__)
                    # One more palette than classes is kept, as before, so class ids 0..n all get a palette.
                    palettes = Options.PALETTE_OPTIONS_CATEGORICAL.value[:len(self.classes) + 1]
                clr_cnt = len(self.data_df.columns) - len(EXPECTED_COLS)
                self.clrs = {}
                for cnt, palette in enumerate(palettes):
                    self.clrs[cnt] = create_color_palette(pallete_name=palette, increments=clr_cnt)
                # The workers look colors up by class id: give any id without a positional palette a cycled one,
                # so e.g. a file holding only class id 1 with a single palette does not fail in the workers.
                for class_id in pd.unique(self.data_df[CLASS_ID].dropna()):
                    class_id = int(class_id)
                    if class_id not in self.clrs:
                        self.clrs[class_id] = create_color_palette(pallete_name=palettes[class_id % len(palettes)], increments=clr_cnt)
                if self.skeleton is not None:
                    required_s_cols = [f'{x}_{suffix}' for xs in self.skeleton for x in xs for suffix in ('X', 'Y', 'P')]
                    missing = [x for x in required_s_cols if x not in self.data_df.columns]
                    if len(missing) > 0:
                        raise DataHeaderError(msg=f'Columns {missing} missing in file {data_path} as passe by skeleton.', source=self.__class__.__name__)
                if self.video_meta_data['frame_count'] != self.df_frm_cnt:
                    raise FrameRangeError(msg=f'The bounding boxes contain data for {self.df_frm_cnt} frames, while the video is {self.video_meta_data["frame_count"]} frames ({self.video_meta_data["video_name"]})', source=self.__class__.__name__)
                if self.circle_size is None:
                    circle_size = PlottingMixin().get_optimal_circle_size(frame_size=(self.video_meta_data['width'], self.video_meta_data['height']), circle_frame_ratio=80)
                else:
                    circle_size = self.circle_size
                thickness = circle_size if self.thickness is None else self.thickness
                # Split the frames over the cores; drop empty batches (more cores than frames).
                frm_batches = np.array_split(np.arange(self.df_frm_cnt), self.core_cnt)
                frm_batches = [(i, j) for i, j in enumerate(frm_batches) if len(j) > 0]
                if self.verbose: stdout_information(msg=f'Visualizing video {self.video_meta_data["video_name"]} (frame count: {self.video_meta_data["frame_count"]}, video: {video_cnt+1}/{video_total})...')
                constants = functools.partial(_yolo_keypoint_visualizer,
                                              data=self.data_df,
                                              threshold=self.threshold,
                                              video_path=self.video_paths[video_name],
                                              save_dir=self.video_temp_dir,
                                              circle_size=circle_size,
                                              thickness=thickness,
                                              palettes=self.clrs,
                                              bbox=self.bbox,
                                              skeleton=self.skeleton,
                                              verbose=self.verbose)
                for result in pool.imap(constants, frm_batches, chunksize=1):
                    if self.verbose: stdout_information(msg=f'Video batch {result+1}/{len(frm_batches)} complete...')
                video_timer.stop_timer()
                concatenate_videos_in_folder(in_folder=self.video_temp_dir, save_path=self.save_path, gpu=False)
                if self.verbose: stdout_success(msg=f'YOLO pose video saved at {self.save_path} (Video {video_cnt+1}/{video_total})', source=self.__class__.__name__, elapsed_time=video_timer.elapsed_time_str)
        finally:
            # Terminate a self-created pool even when a video fails, so worker processes are not left behind.
            if pool_terminate_flag: terminate_cpu_pool(pool=pool, force=False, source=self.__class__.__name__)
        self.timer.stop_timer()
        if self.verbose: stdout_success(msg=f'{video_total} YOLO pose video saved in directory {self.save_dir}', source=self.__class__.__name__, elapsed_time=self.timer.elapsed_time_str)
#
#
# if __name__ == "__main__":
# #video_path = r"D:\cvat_annotations\videos\mp4_20250624155703\s34-drinking.mp4"
# #data_path = r"D:\cvat_annotations\yolo_07032025\out_data\s34-drinking.csv"
# #save_dir = r"D:\cvat_annotations\yolo_07032025\out_video"
#
# video_path = r"D:\cvat_annotations\videos\mp4_20250624155703\s34-drinking.mp4"
#
# data_path = r"D:\cvat_annotations\yolo_mdl_07122025\out_data\s34-drinking.csv"
# save_dir = r'D:\cvat_annotations\yolo_mdl_07122025\out_video'
# video_dir = r"D:\cvat_annotations\videos\mp4_20250624155703"
# data_dir = r"D:\cvat_annotations\yolo_mdl_07122025\out_data"
# data_dir = r"D:\platea\platea_videos\videos\yolo_results"
# video_dir = r"D:\platea\platea_videos\videos\videos"
# save_dir = r'D:\platea\platea_videos\videos\yolo_kp_video_out'
#
# video_paths = find_files_of_filetypes_in_directory(directory=video_dir, extensions=['.mp4'], as_dict=True)
# data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'], as_dict=True)
#
#
# skeleton = [('NOSE', 'LEFT_EAR'),
# ('NOSE', 'RIGHT_EAR'),
# ('RIGHT_EAR', 'LEFT_EAR'),
# ('LEFT_EAR', 'LEFT_SIDE'),
# ('RIGHT_EAR', 'RIGHT_SIDE'),
# ('LEFT_SIDE', 'CENTER'),
# ('LEFT_EAR', 'CENTER'),
# ('RIGHT_EAR', 'CENTER'),
# ('CENTER', 'RIGHT_SIDE'),
# ('CENTER', 'TAIL_BASE'),
# ('LEFT_SIDE', 'TAIL_BASE'),
# ('RIGHT_SIDE', 'TAIL_BASE'),
# ('TAIL_BASE', 'TAIL_CENTER'),
# ('TAIL_CENTER', 'TAIL_TIP')]
#
# skeleton = [('NOSE', 'LEFT_EAR'),
# ('NOSE', 'RIGHT_EAR'),
# ('RIGHT_EAR', 'LEFT_EAR'),
# ('LEFT_EAR', 'LEFT_SIDE'),
# ('RIGHT_EAR', 'RIGHT_SIDE'),
# ('LEFT_SIDE', 'CENTER'),
# ('LEFT_EAR', 'CENTER'),
# ('RIGHT_EAR', 'CENTER'),
# ('CENTER', 'RIGHT_SIDE'),
# ('CENTER', 'TAIL_BASE'),
# ('LEFT_SIDE', 'TAIL_BASE'),
# ('RIGHT_SIDE', 'TAIL_BASE')]
#
# for video_name, video_path in video_paths.items():
# data_path = data_paths[video_name]
# #kp_vis = YOLOPoseVisualizer(data_path=data_path, video_path=video_path)
#
# kp_vis = YOLOPoseVisualizer(data_path= data_path,
# video_path=video_path,
# save_dir=save_dir,
# core_cnt=31,
# bbox=True,
# verbose=True,
# skeleton=skeleton,
# threshold=0.5,
# thickness=3,
# circle_size=10)
#
# kp_vis.run()
#
# # video_dir = r'D:\cvat_annotations\videos\mp4_20250624155703'
# # data_dir = r'D:\cvat_annotations\yolo_mdl_07102025\out_csv'
# # save_dir = r'D:\cvat_annotations\yolo_mdl_07102025\out_video'
# #
# # video_paths = find_files_of_filetypes_in_directory(directory=video_dir, extensions=['.mp4'], as_dict=True)
# # data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=['.csv'], as_dict=True)
# #
# # for video_name, data_path in data_paths.items():
# # kp_vis = YOLOPoseVisualizer(data_path= data_path,
# # video_path=video_paths[video_name],
# # save_dir=save_dir,
# # core_cnt=28,
# # bbox=True,
# # verbose=True)
# #
# # kp_vis.run()
# if __name__ == "__main__":
# from simba.utils.read_write import find_files_of_filetypes_in_directory
#
# video_paths = find_files_of_filetypes_in_directory(directory=r'E:\netholabs_videos\mosaics\subset', extensions=['.mp4', '.avi'], as_dict=True)
# data_paths = find_files_of_filetypes_in_directory(directory=r'E:\netholabs_videos\mosaics_inference', extensions=['.csv'], as_dict=True)
# save_dir = r"E:\netholabs_videos\mosaics_inference\out_videos"
#
# # video_paths = r"D:\netholabs\minutes_examples\minute_27.avi"
# # data_paths = r'D:\netholabs\mdls_08202025\10x\results\minute_27.csv'
#
# for video_name, video_path in video_paths.items():
# data_path = data_paths[video_name]
# kp_vis = YOLOPoseVisualizer(data_path=data_path,
# video_path=video_path,
# save_dir=save_dir,
# core_cnt=8)
# kp_vis.run()
#
# # kp_vis = YOLOPoseVisualizer(data_path=data_paths,
# # video_path=video_paths,
# # save_dir=save_dir,
# # core_cnt=18)
# # kp_vis.run()
# if __name__ == "__main__":
# video_path = r"E:\maplight_videos"
# data_path = r"E:\maplight_videos\yolo_mdl\mdl\results"
# save_dir = r'E:\maplight_videos\yolo_mdl\mdl\videos'
# kp_vis = YOLOPoseVisualizer(data_path=data_path,
# video_path=video_path,
# save_dir=save_dir,
# core_cnt=14,
# palettes=('Set1', 'Pastel1'),
# recursive=True,
# sample_n=1)
#
#
# kp_vis.run()
# if __name__ == "__main__":
# video_path = r"E:\netholabs_videos\two_tracks\videos"
# data_path = r"E:\netholabs_videos\two_tracks\csv_track_025"
# save_dir = r'E:\netholabs_videos\two_tracks\track_videos'
# kp_vis = YOLOPoseVisualizer(data_path=data_path,
# video_path=video_path,
# save_dir=save_dir,
# core_cnt=14,
# palettes=('Set1',),
# recursive=True,
# sample_n=75)
#
#
# kp_vis.run()
# if __name__ == "__main__":
# video_path = r"E:\todd_tail\sample\video1_CLAHE_CLIPLIMIT_14_TILESIZE_19_20260216171121.mp4"
# data_path = r"E:\todd_tail\yolo_model_predict\video1_CLAHE_CLIPLIMIT_14_TILESIZE_19_20260216171121.csv"
# save_dir = r'E:\todd_tail\yolo_mdl_video'
# kp_vis = YOLOPoseVisualizer(data_path=data_path,
# video_path=video_path,
# save_dir=save_dir,
# core_cnt=14,
# threshold=0.0,)
#
#
# kp_vis.run()