import os

# Set before the numeric libraries below are imported.
# NOTE(review): presumably works around the "duplicate OpenMP runtime" abort
# that can occur when several packages bundle their own OpenMP - confirm.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

try:
    from typing import Literal
except ImportError:
    # Only a failed import should trigger the backport fallback; a bare
    # ``except`` would also swallow KeyboardInterrupt / SystemExit.
    from typing_extensions import Literal

from typing import List, Optional, Union

import numpy as np
import pandas as pd
import torch

from simba.utils.checks import (check_file_exist_and_readable, check_float,
                                check_if_dir_exists, check_int,
                                check_valid_boolean, check_valid_lst,
                                get_fn_ext)
from simba.utils.data import resample_geometry_vertices
from simba.utils.enums import Options
from simba.utils.errors import InvalidFileTypeError
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
                                    get_video_meta_data)
from simba.utils.yolo import load_yolo_model, yolo_predict

# Task name that a loaded model must report (``model.model.task``) to be
# accepted as a segmentation model by this module.
TASK = 'segment'
class YOLOSegmentationInference():
    """
    Run inference on video(s) using a trained YOLO segmentation model.

    For every detection in every frame, the segmentation polygon is resampled to ``vertice_cnt`` vertices and
    stored as one row with the columns ``FRAME``, ``ID``, ``VERTICE_0_X``, ``VERTICE_0_Y``, ... . The ``ID`` column
    holds the value read from the last column of the prediction's box data (treated as the detected class).

    :param Union[str, os.PathLike] weights_path: Path to the trained YOLO `.pt` weights file.
    :param Union[str, os.PathLike, List[str]] video_path: Path to a single video, a directory of videos, or a list of video paths to run inference on.
    :param bool verbose: Whether to print progress information. Default is True.
    :param Union[str, os.PathLike] save_dir: Directory where one CSV per video is saved. If None, the results are returned by :meth:`run` instead.
    :param Union[str, int] device: Device to run inference on; use 'cpu' or an integer GPU index (e.g., 0).
    :param str format: Optional export format for the model. Supported values: "onnx", "engine", "torchscript", "onnxsimplify", "coreml", "openvino", "pb", "tf", "tflite". Defaults to None.
    :param Optional[int] batch_size: Number of frames to process at once. Increase for faster performance with sufficient memory.
    :param int torch_threads: Number of CPU threads passed to ``torch.set_num_threads``.
    :param bool half_precision: Whether to use half-precision (FP16) for inference on GPU. Default is True.
    :param bool stream: Whether to stream video processing (less memory, suitable for long videos).
    :param float threshold: Confidence threshold for object/segmentation detection.
    :param int max_tracks: Maximum number of detections per frame (forwarded to the predictor as ``max_detections``). Must be 1 or larger.
    :param bool interpolate: Whether to fill frames lacking a detection by linear interpolation of the vertex coordinates, separately for each ``ID``.
    :param int imgsz: Inference image size (width/height in pixels); must be multiple of 32.
    :param float iou: IoU threshold for non-max suppression (NMS).
    :param bool retina_msk: Whether to use high-resolution Retina-style masks.
    :param int vertice_cnt: Number of vertices used to approximate the segmentation mask polygon. Must be 3 or larger.
    :raises InvalidFileTypeError: If ``video_path`` is not a list, an existing file, or an existing directory, or if the loaded model is not a segmentation model.

    .. important::
       The ``imgsz`` parameter is critical for mask quality. Segmentation requires pixel-level precision along
       object boundaries, so spatial detail lost to downscaling hurts segmentation far more than detection or pose
       tasks. Set ``imgsz`` as large as your GPU memory allows. The default ``640`` may be too coarse for
       high-quality segmentation masks.

    .. note::
       To **create** YOLO segmentation dataset for fitting, use :func:`simba.third_party_label_appenders.transform.labelme_to_yolo_seg.LabelmeKeypoints2YoloSeg`.
       To fit YOLO model, see :func:`simba.model.yolo_fit.FitYolo`.
       To visualize the segmentation results, see :func:`simba.plotting.yolo_seg_visualizer.YOLOSegmentationVisualizer`

    :example:
    >>> weights_path = r"D:/platea/yolo_071525/mdl/train3/weights/best.pt"
    >>> video_path = r"D:/platea/platea_videos/videos/clipped/10B_Mouse_5-choice_MustTouchTrainingNEWFINAL_a7.mp4"
    >>> save_dir = r"D:/platea/platea_videos/videos/yolo_results"
    >>> runner = YOLOSegmentationInference(weights_path=weights_path, video_path=video_path, save_dir=save_dir, verbose=True, device=0, format=None, stream=True, batch_size=10, imgsz=320, interpolate=True, threshold=0.8, retina_msk=True)
    >>> runner.run()
    """

    def __init__(self,
                 weights_path: Union[str, os.PathLike],
                 video_path: Union[Union[str, os.PathLike], List[Union[str, os.PathLike]]],
                 verbose: Optional[bool] = True,
                 save_dir: Optional[Union[str, os.PathLike]] = None,
                 device: Union[Literal['cpu'], int] = 0,
                 format: Optional[str] = None,
                 batch_size: Optional[int] = 4,
                 torch_threads: int = 8,
                 half_precision: bool = True,
                 stream: bool = False,
                 threshold: float = 0.5,
                 max_tracks: int = 300,
                 interpolate: bool = False,
                 imgsz: int = 640,
                 iou: float = 0.5,
                 retina_msk: Optional[bool] = False,
                 vertice_cnt: int = 30):

        src = self.__class__.__name__
        # Normalise ``video_path`` to a list of video file paths.
        if isinstance(video_path, list):
            check_valid_lst(data=video_path, source=f'{src} video_path', valid_dtypes=(str, np.str_,), min_len=1)
        elif isinstance(video_path, (str, os.PathLike)) and os.path.isfile(video_path):
            video_path = [video_path]
        elif isinstance(video_path, (str, os.PathLike)) and os.path.isdir(video_path):
            video_path = find_files_of_filetypes_in_directory(directory=video_path, extensions=Options.ALL_VIDEO_FORMAT_OPTIONS.value, as_dict=False)
        else:
            # Without this guard a non-existing path string would be iterated
            # character by character in the loop below.
            raise InvalidFileTypeError(msg=f'video_path {video_path} is not a valid video file path, directory path, or list of video file paths.', source=src)
        # Reading the metadata up-front surfaces unreadable videos before the model is loaded.
        for path in video_path:
            _ = get_video_meta_data(video_path=path)

        check_file_exist_and_readable(file_path=weights_path)
        check_valid_boolean(value=verbose, source=f'{src} verbose')
        check_valid_boolean(value=interpolate, source=f'{src} interpolate')
        check_valid_boolean(value=retina_msk, source=f'{src} retina_msk')
        check_valid_boolean(value=half_precision, source=f'{src} half_precision')
        check_valid_boolean(value=stream, source=f'{src} stream')
        check_int(name=f'{src} batch_size', value=batch_size, min_value=1)
        check_int(name=f'{src} imgsz', value=imgsz, min_value=1)
        check_int(name=f'{src} vertice_cnt', value=vertice_cnt, min_value=3)
        check_int(name=f'{src} torch_threads', value=torch_threads, min_value=1)
        check_float(name=f'{src} threshold', value=threshold, min_value=10e-6, max_value=1.0)
        check_float(name=f'{src} iou', value=iou, min_value=10e-6, max_value=1.0)
        check_int(name=f'{src} max_tracks', value=max_tracks, min_value=1)
        if save_dir is not None:
            check_if_dir_exists(in_dir=save_dir, source=f'{src} save_dir')

        torch.set_num_threads(torch_threads)
        if verbose:
            stdout_information(msg=f'Loading YOLO model from {weights_path}...')
        self.model = load_yolo_model(weights_path=weights_path, device=device, format=format)
        if self.model.model.task != TASK:
            raise InvalidFileTypeError(msg=f'The model {weights_path} is not a segmentation model. It is a {self.model.model.task} model', source=src)

        self.half_precision, self.stream, self.video_path, self.retina_msk = half_precision, stream, video_path, retina_msk
        self.device, self.batch_size, self.threshold, self.max_tracks, self.iou = device, batch_size, threshold, max_tracks, iou
        self.verbose, self.save_dir, self.imgsz, self.interpolate = verbose, save_dir, imgsz, interpolate
        self.vertice_cnt = vertice_cnt
        # Output columns: FRAME, ID, then an X/Y pair per resampled vertex.
        self.vertice_col_names = ['FRAME', 'ID']
        for i in range(self.vertice_cnt):
            self.vertice_col_names.extend((f"VERTICE_{i}_X", f"VERTICE_{i}_Y"))

    def run(self) -> Optional[dict]:
        """
        Run the model on every video and collect the resampled polygon vertices.

        :return: If ``save_dir`` is None, a dict mapping each video name to its results dataframe. Otherwise one
                 ``<video name>.csv`` is written per video to ``save_dir`` and None is returned.
        """
        results = {}
        timer = SimbaTimer(start=True)
        for path in self.video_path:
            _, video_name, _ = get_fn_ext(filepath=path)
            video_meta_data = get_video_meta_data(video_path=path)
            video_predictions = yolo_predict(model=self.model,
                                             source=path,
                                             half=self.half_precision,
                                             batch_size=self.batch_size,
                                             stream=self.stream,
                                             imgsz=self.imgsz,
                                             device=self.device,
                                             threshold=self.threshold,
                                             max_detections=self.max_tracks,
                                             verbose=self.verbose,
                                             iou=self.iou,
                                             retina_msk=self.retina_msk)
            video_results = []
            for frm_cnt, video_prediction in enumerate(video_predictions):
                video_results.extend(self._get_frame_vertices(frm_idx=frm_cnt, prediction=video_prediction))
            vertices = pd.DataFrame(video_results, columns=self.vertice_col_names)
            if self.interpolate:
                vertices = self._interpolate_missing_frames(vertices=vertices, total_frames=int(video_meta_data['frame_count']))
            if self.save_dir is not None:
                vertices.to_csv(os.path.join(self.save_dir, f'{video_name}.csv'))
            else:
                results[video_name] = vertices
        timer.stop_timer()
        if self.save_dir is not None:
            if self.verbose:
                stdout_success(f'YOLO results saved in {self.save_dir} directory', timer.elapsed_time_str)
            return None
        if self.verbose:
            stdout_success('YOLO results created', timer.elapsed_time_str)
        return results

    def _get_frame_vertices(self, frm_idx: int, prediction) -> list:
        """
        Convert the detections of one frame prediction into output rows.

        :param int frm_idx: Index of the frame within the video; stored in the ``FRAME`` column.
        :param prediction: Single-frame result yielded by ``yolo_predict``; its ``masks`` and ``boxes`` attributes are read.
        :return: List with one 1D array per detection, laid out as ``[frame, id, x0, y0, x1, y1, ...]``.
        """
        if prediction.masks is None:
            return []
        boxes = prediction.boxes.data.cpu().numpy().astype(np.float32)
        if boxes.size == 0:
            return []
        # The last column of the box data is used as the detection's ID (the detected class).
        detected_classes = boxes[:, -1].astype(int)
        rows = []
        for detected_class, mask in zip(detected_classes, prediction.masks.xy):
            # An empty polygon cannot be reshaped or resampled; skip it rather than abort the whole video.
            if mask.shape[0] == 0:
                continue
            polygon = mask.reshape(-1, mask.shape[0], 2)
            vertices = resample_geometry_vertices(vertices=polygon, vertice_cnt=self.vertice_cnt).flatten()
            # Prepend frame index and ID in one call; np.insert keeps the dtype of ``vertices``.
            rows.append(np.insert(vertices, 0, (frm_idx, detected_class)))
        return rows

    def _interpolate_missing_frames(self, vertices: pd.DataFrame, total_frames: int) -> pd.DataFrame:
        """
        Fill in frames lacking a detection, separately for each ``ID``.

        Each ID's rows are reindexed onto ``range(total_frames)`` and the vertex coordinates are linearly
        interpolated in both directions, so leading/trailing gaps take the nearest observed value. If an ID has
        several rows in the same frame, only the first is kept.

        :param pd.DataFrame vertices: Results with the columns in ``self.vertice_col_names``.
        :param int total_frames: Number of frames in the video.
        :return: Dataframe with ``total_frames`` rows per ID, grouped by ID, or the input unchanged if it is empty.
        """
        if vertices.empty:
            return vertices
        coord_cols = [c for c in self.vertice_col_names if c not in ('FRAME', 'ID')]
        all_results = []
        for track_id in vertices['ID'].unique():
            id_data = vertices[vertices['ID'] == track_id].drop_duplicates(subset='FRAME', keep='first').copy()
            id_data = id_data.set_index('FRAME').reindex(range(total_frames))
            id_data['ID'] = track_id
            id_data[coord_cols] = id_data[coord_cols].interpolate(method='linear', limit_direction='both')
            # The rename covers the case where the index name was lost during reindexing.
            id_data = id_data.reset_index().rename(columns={'index': 'FRAME'})
            all_results.append(id_data)
        return pd.concat(all_results, axis=0).reset_index(drop=True)[self.vertice_col_names]
# weights_path = r"E:\litpose_yolo\yolo_from_sam3\mdl\train\weights\best.pt"
# video_path = r'E:\litpose_yolo\pi\videos'
# save_dir=r"E:\litpose_yolo\yolo_from_sam3\csv_results"
# i = YOLOSegmentationInference(weights_path=weights_path,
# video_path=video_path,
# save_dir=save_dir,
# verbose=True,
# device=0,
# format=None,
# stream=True,
# max_tracks=1,
# batch_size=10,
# imgsz=320,
# interpolate=True,
# threshold=0.8,
# retina_msk=False)
# i.run()