Source code for simba.plotting.yolo_annotation_visualizer

import os
import random
from typing import Dict, List, Optional, Tuple, Union

# ``typing.Literal`` is only available on Python 3.8+; fall back to the
# ``typing_extensions`` backport on older interpreters. Only an ImportError
# should trigger the fallback, so that unrelated failures are not hidden.
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import cv2
import numpy as np
import yaml

from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.checks import (check_file_exist_and_readable,
                                check_if_dir_exists, check_int, check_str,
                                check_valid_boolean)
from simba.utils.data import create_color_palette
from simba.utils.enums import Options
from simba.utils.errors import InvalidInputError, NoDataError
from simba.utils.printing import SimbaTimer, stdout_information, stdout_success
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
                                    get_fn_ext)
from simba.utils.yolo import detect_yolo_project_type

# Presumably the number of values describing one YOLO bounding box
# (x-center, y-center, width, height).
# NOTE(review): not referenced in the visible code - confirm it is still needed.
BBOX_VALUE_CNT = 4
# Stride used when walking the keypoint values of a label line: each keypoint
# is read as an (x, y, visibility) group.
KPT_DIM = 3

class YOLOAnnotationVisualizer(object):
    r"""
    Visualize YOLO annotation label files overlaid on their source images.

    The project type (``'bbox'``, ``'keypoint'`` or ``'segmentation'``) is detected from the first
    matched label file, and every matched image is written to ``save_dir`` with its annotations drawn.

    .. seealso::
       For visualizing YOLO bounding-box inference results on video, see :func:`simba.plotting.yolo_visualize.YOLOVisualizer`.
       For visualizing YOLO keypoint pose-estimation results on video, see :func:`simba.plotting.yolo_pose_visualizer.YOLOPoseVisualizer`.
       For visualizing YOLO segmentation polygon results on video, see :func:`simba.plotting.yolo_seg_visualizer.YOLOSegmentationVisualizer`.
       For auto-detecting the YOLO project type from a label file, see :func:`simba.utils.yolo.detect_yolo_project_type`.

    :param Union[str, os.PathLike] map_yaml_path: Path to the YOLO project ``map.yaml`` file. Must contain the ``path`` and ``names`` keys.
    :param Union[str, os.PathLike] save_dir: Directory where annotated images are saved.
    :param Optional[str] split: Which split to visualize: ``'train'``, ``'val'``, or ``'all'``. ``'all'`` uses every non-empty ``train``/``val``/``test`` entry in the yaml. Default ``'all'``.
    :param Optional[int] n: Number of images to visualize (randomly sampled). If ``None``, visualize every image. Default ``None``.
    :param Optional[int] circle_size: Radius of keypoint circles. If ``None``, computed from image dimensions.
    :param Optional[int] thickness: Line thickness for bounding boxes / polygon edges. If ``None``, derived from the circle size.
    :param str palette: Color palette name (e.g. ``'Set1'``). Default ``'Set1'``.
    :param str img_format: Output image format extension. Default ``'.png'``.
    :param float seg_opacity: Opacity of filled segmentation polygons (0.0-1.0). Default ``0.5``.
    :param bool show_names: If True, draw class name labels on each annotation. Default False.
    :param bool show_outline: If True, draw polygon outline for segmentation annotations. Default False.
    :param bool verbose: Print progress messages. Default ``True``.

    :raises InvalidInputError: If ``map.yaml`` is not a mapping, lacks required keys, ``seg_opacity`` is out of range, or a split / directory cannot be resolved.
    :raises NoDataError: If no image has a matching label file.

    :example:
    >>> viz = YOLOAnnotationVisualizer(map_yaml_path=r'F:\netholabs\moira_lp_sam\map.yaml', save_dir=r'F:\netholabs\annotation_visualizations', n=400)
    >>> viz.run()
    >>> viz = YOLOAnnotationVisualizer(map_yaml_path=r'/path/to/map.yaml', save_dir=r'/path/to/output')
    >>> viz.run()
    >>> viz = YOLOAnnotationVisualizer(map_yaml_path=r'/path/to/map.yaml', save_dir=r'/path/to/output', n=50, circle_size=5, thickness=2, img_format='.jpeg')
    >>> viz.run()
    """

    def __init__(self,
                 map_yaml_path: Union[str, os.PathLike],
                 save_dir: Union[str, os.PathLike],
                 split: Optional[Literal['train', 'val', 'all']] = 'all',
                 n: Optional[int] = None,
                 circle_size: Optional[int] = None,
                 thickness: Optional[int] = None,
                 palette: str = 'Set1',
                 img_format: str = '.png',
                 seg_opacity: float = 0.5,
                 show_names: bool = False,
                 show_outline: bool = False,
                 verbose: bool = True):

        check_file_exist_and_readable(file_path=map_yaml_path)
        check_if_dir_exists(in_dir=save_dir, source=f'{self.__class__.__name__} save_dir')
        check_str(name=f'{self.__class__.__name__} split', value=split, options=('train', 'val', 'all'))
        check_valid_boolean(value=[verbose], source=f'{self.__class__.__name__} verbose', raise_error=True)
        check_valid_boolean(value=[show_names], source=f'{self.__class__.__name__} show_names', raise_error=True)
        check_valid_boolean(value=[show_outline], source=f'{self.__class__.__name__} show_outline', raise_error=True)
        if n is not None:
            check_int(name=f'{self.__class__.__name__} n', value=n, min_value=1)
        if circle_size is not None:
            check_int(name=f'{self.__class__.__name__} circle_size', value=circle_size, min_value=1)
        if thickness is not None:
            check_int(name=f'{self.__class__.__name__} thickness', value=thickness, min_value=1)
        # The opacity is used directly as a blending weight, so it has to be a real number in [0, 1].
        valid_opacity = isinstance(seg_opacity, (int, float, np.integer, np.floating)) and not isinstance(seg_opacity, bool)
        if not valid_opacity or not 0.0 <= float(seg_opacity) <= 1.0:
            raise InvalidInputError(msg=f'seg_opacity has to be a number between 0.0 and 1.0, got: {seg_opacity}', source=self.__class__.__name__)
        if not img_format.startswith('.'):
            img_format = f'.{img_format}'
        check_str(name=f'{self.__class__.__name__} img_format', value=img_format.lower(), options=('.png', '.jpeg', '.jpg', '.bmp', '.webp'))
        with open(map_yaml_path, 'r') as f:
            self.yolo_map = yaml.safe_load(f)
        # An empty yaml file loads as None; anything that is not a mapping cannot hold the project keys.
        if not isinstance(self.yolo_map, dict):
            raise InvalidInputError(msg=f'map.yaml is empty or not a valid mapping: {map_yaml_path}', source=self.__class__.__name__)
        required_keys = ['path', 'names']
        missing = [k for k in required_keys if k not in self.yolo_map]
        if len(missing) > 0:
            raise InvalidInputError(msg=f'map.yaml missing required keys: {missing}', source=self.__class__.__name__)
        self.project_path = self.yolo_map['path']
        self.names = self.yolo_map['names']
        self.kpt_shape = self.yolo_map.get('kpt_shape', None)
        self.save_dir = save_dir
        self.split = split
        self.n = n
        self.circle_size = circle_size
        self.thickness = thickness
        self.palette = palette
        self.img_format = img_format.lower()
        self.seg_opacity = float(seg_opacity)
        self.show_names = show_names
        self.show_outline = show_outline
        self.verbose = verbose

    def _find_image_label_pairs(self) -> List[Tuple[str, str]]:
        """
        Collect ``(image_path, label_path)`` tuples for the requested split(s).

        Images and labels are matched on the keys returned by
        ``find_files_of_filetypes_in_directory(..., as_dict=True)``.

        :return: List of matched ``(image_path, label_path)`` tuples.
        :raises InvalidInputError: If the requested split is absent or empty in ``map.yaml``, or an image / label directory does not exist.
        :raises NoDataError: If no image has a matching label file.
        """
        if self.split == 'all':
            # A split key may be present but left empty in the yaml (loads as None); skip those.
            splits = [key for key in ('train', 'val', 'test') if self.yolo_map.get(key)]
        else:
            if not self.yolo_map.get(self.split):
                raise InvalidInputError(msg=f'Split "{self.split}" not found in map.yaml', source=self.__class__.__name__)
            splits = [self.split]

        images_token = os.sep + 'images' + os.sep
        labels_token = os.sep + 'labels' + os.sep
        pairs = []
        for s in splits:
            # Normalise so mixed separators in the yaml do not defeat the images -> labels substitution below.
            img_dir = os.path.normpath(os.path.join(self.project_path, self.yolo_map[s]))
            lbl_dir_candidate = img_dir.replace(images_token, labels_token)
            if lbl_dir_candidate == img_dir:
                # No ".../images/..." component to swap: fall back to <project>/labels/<split>.
                lbl_dir_candidate = os.path.join(self.project_path, 'labels', s)
            if not os.path.isdir(img_dir):
                raise InvalidInputError(msg=f'Image directory not found: {img_dir}', source=self.__class__.__name__)
            if not os.path.isdir(lbl_dir_candidate):
                raise InvalidInputError(msg=f'Label directory not found: {lbl_dir_candidate}', source=self.__class__.__name__)
            img_files = find_files_of_filetypes_in_directory(directory=img_dir, extensions=Options.ALL_IMAGE_FORMAT_OPTIONS.value, as_dict=True, raise_error=True)
            lbl_files = find_files_of_filetypes_in_directory(directory=lbl_dir_candidate, extensions=['.txt'], as_dict=True, raise_error=False)
            if lbl_files is None:
                lbl_files = {}
            pairs.extend((img_path, lbl_files[img_name]) for img_name, img_path in img_files.items() if img_name in lbl_files)
        if len(pairs) == 0:
            raise NoDataError(msg='No matched image/label pairs found.', source=self.__class__.__name__)
        return pairs

    def _get_class_label(self, class_id: int) -> str:
        """
        Return the display name for ``class_id``.

        Handles ``names`` given either as a mapping (``{0: 'mouse'}``) or as a sequence (``['mouse']``),
        and falls back to the class id as a string when no name is available.
        """
        names = self.names
        if isinstance(names, dict):
            if class_id in names:
                return str(names[class_id])
            if str(class_id) in names:
                return str(names[str(class_id)])
        elif isinstance(names, (list, tuple)) and 0 <= class_id < len(names):
            return str(names[class_id])
        return str(class_id)

    @staticmethod
    def _bbox_to_pixels(parts: List[str], img_w: int, img_h: int) -> np.ndarray:
        """
        Convert the normalized ``xc yc w h`` values in ``parts[1:5]`` to pixel ``[x1, y1, x2, y2]``.

        :raises ValueError: If fewer than five values are present or a value is not numeric.
        """
        if len(parts) < 5:
            raise ValueError(f'expected a class id followed by 4 bounding-box values, got {len(parts)} value(s)')
        xc, yc, w, h = (float(v) for v in parts[1:5])
        xc_px, yc_px = xc * img_w, yc * img_h
        w_px, h_px = w * img_w, h * img_h
        x1 = int(xc_px - w_px / 2)
        y1 = int(yc_px - h_px / 2)
        x2 = int(xc_px + w_px / 2)
        y2 = int(yc_px + h_px / 2)
        return np.array([x1, y1, x2, y2], dtype=np.int32)

    @staticmethod
    def _parse_bbox_line(parts: List[str], img_w: int, img_h: int) -> Tuple[int, np.ndarray]:
        """
        Parse one bounding-box label line (``class xc yc w h``, normalized coordinates).

        :param List[str] parts: Whitespace-split tokens of the label line.
        :param int img_w: Image width in pixels.
        :param int img_h: Image height in pixels.
        :return: ``(class_id, bbox)`` where ``bbox`` is an int32 array ``[x1, y1, x2, y2]`` in pixels.
        :raises ValueError: If the line has too few values or a value is not numeric.
        """
        class_id = int(float(parts[0]))
        return class_id, YOLOAnnotationVisualizer._bbox_to_pixels(parts, img_w, img_h)

    @staticmethod
    def _parse_keypoint_line(parts: List[str], img_w: int, img_h: int, kpt_dim: Optional[int] = None) -> Tuple[int, np.ndarray, np.ndarray]:
        """
        Parse one keypoint label line (``class xc yc w h`` followed by keypoint groups).

        :param List[str] parts: Whitespace-split tokens of the label line.
        :param int img_w: Image width in pixels.
        :param int img_h: Image height in pixels.
        :param Optional[int] kpt_dim: Values per keypoint: ``3`` for ``x y visibility`` or ``2`` for ``x y``. If ``None``, the module constant ``KPT_DIM`` is used.
        :return: ``(class_id, bbox, keypoints)``; ``keypoints`` is an int32 array of shape ``(n, 3)`` holding ``x, y, visibility`` in pixels. 2-D keypoints carry no visibility flag and are reported with visibility ``2`` so they are drawn.
        :raises ValueError: If the line is incomplete, ``kpt_dim`` is unsupported, or a value is not numeric.
        """
        dim = KPT_DIM if kpt_dim is None else kpt_dim
        if dim not in (2, 3):
            raise ValueError(f'unsupported keypoint dimension: {dim}')
        class_id = int(float(parts[0]))
        bbox = YOLOAnnotationVisualizer._bbox_to_pixels(parts, img_w, img_h)
        kp_values = [float(v) for v in parts[5:]]
        if len(kp_values) % dim != 0:
            raise ValueError(f'{len(kp_values)} keypoint value(s) is not a multiple of {dim}')
        kps = []
        for i in range(0, len(kp_values), dim):
            kx = int(kp_values[i] * img_w)
            ky = int(kp_values[i + 1] * img_h)
            vis = int(kp_values[i + 2]) if dim == 3 else 2
            kps.append((kx, ky, vis))
        # reshape keeps a consistent (n, 3) shape even when the line holds no keypoints.
        return class_id, bbox, np.array(kps, dtype=np.int32).reshape(-1, 3)

    @staticmethod
    def _parse_seg_line(parts: List[str], img_w: int, img_h: int) -> Tuple[int, np.ndarray]:
        """
        Parse one segmentation label line (``class x1 y1 x2 y2 ...``, normalized coordinates).

        :param List[str] parts: Whitespace-split tokens of the label line.
        :param int img_w: Image width in pixels.
        :param int img_h: Image height in pixels.
        :return: ``(class_id, polygon)`` where ``polygon`` is an int32 array of shape ``(n, 2)`` in pixels.
        :raises ValueError: If the coordinates are not complete x/y pairs, describe fewer than three points, or are not numeric.
        """
        class_id = int(float(parts[0]))
        coords = [float(v) for v in parts[1:]]
        if len(coords) % 2 != 0:
            raise ValueError(f'polygon has an odd number of coordinate values ({len(coords)})')
        if len(coords) < 6:
            raise ValueError(f'polygon needs at least 3 points, got {len(coords) // 2}')
        points = [[int(x * img_w), int(y * img_h)] for x, y in zip(coords[0::2], coords[1::2])]
        return class_id, np.array(points, dtype=np.int32)

    def _draw_bbox(self, img: np.ndarray, class_id: int, bbox: np.ndarray, color: tuple, thickness: int) -> np.ndarray:
        """Draw a bounding box (and, if ``show_names``, its class label) on ``img`` and return it."""
        x1, y1, x2, y2 = (int(v) for v in bbox)
        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness, lineType=cv2.LINE_AA)
        if self.show_names:
            label = self._get_class_label(class_id)
            # Clamp the text baseline so labels of boxes touching the top edge stay inside the image.
            cv2.putText(img, label, (x1, max(y1 - 5, 0)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, max(1, thickness // 2), cv2.LINE_AA)
        return img

    def _draw_keypoints(self, img: np.ndarray, class_id: int, bbox: np.ndarray, kps: np.ndarray, colors: list, circle_size: int, thickness: int) -> np.ndarray:
        """
        Draw the bounding box and the visible keypoints of one annotation on ``img`` and return it.

        ``colors[0]`` is used for the box and label; keypoint ``i`` uses ``colors[i + 1]``, clamped to the
        last available color. Keypoints whose visibility flag is ``0`` are not drawn.
        """
        color = tuple(int(c) for c in colors[0])
        x1, y1, x2, y2 = (int(v) for v in bbox)
        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness, lineType=cv2.LINE_AA)
        if self.show_names:
            label = self._get_class_label(class_id)
            cv2.putText(img, label, (x1, max(y1 - 5, 0)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, max(1, thickness // 2), cv2.LINE_AA)
        for kp_idx, kp in enumerate(kps):
            if kp[2] > 0:
                clr_idx = min(kp_idx + 1, len(colors) - 1)
                clr = tuple(int(c) for c in colors[clr_idx])
                cv2.circle(img, (int(kp[0]), int(kp[1])), circle_size, clr, -1)
        return img

    def _draw_segmentation(self, img: np.ndarray, class_id: int, polygon: np.ndarray, color: tuple, thickness: int) -> np.ndarray:
        """
        Blend a filled polygon onto ``img`` with ``seg_opacity`` and return the image.

        The optional outline is drawn before blending, and the optional class label is placed at the
        mean of the polygon vertices.
        """
        overlay = img.copy()
        pts = polygon.reshape((-1, 1, 2))
        if self.show_outline:
            cv2.polylines(img, [pts], isClosed=True, color=color, thickness=thickness, lineType=cv2.LINE_AA)
        cv2.fillPoly(overlay, [pts], color=color)
        cv2.addWeighted(overlay, self.seg_opacity, img, 1 - self.seg_opacity, 0, img)
        if self.show_names:
            label = self._get_class_label(class_id)
            cx, cy = int(polygon[:, 0].mean()), int(polygon[:, 1].mean())
            cv2.putText(img, label, (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, max(1, thickness // 2), cv2.LINE_AA)
        return img

    def run(self):
        """
        Draw the annotations of every selected image/label pair and save the results in ``save_dir``.

        Unreadable images and malformed label lines are skipped (and reported when ``verbose``) rather
        than aborting the whole run.
        """
        timer = SimbaTimer(start=True)
        pairs = self._find_image_label_pairs()
        project_type = detect_yolo_project_type(label_path=pairs[0][1])
        if self.verbose:
            stdout_information(msg=f'Detected YOLO project type: {project_type} ({len(pairs)} image/label pairs found)', source=self.__class__.__name__)
        if self.n is not None:
            pairs = random.sample(pairs, min(self.n, len(pairs)))

        n_classes = len(self.names) if isinstance(self.names, (dict, list, tuple)) else 0
        kp_count, kpt_dim = 0, None
        if project_type == 'keypoint' and isinstance(self.kpt_shape, (list, tuple)) and len(self.kpt_shape) > 0:
            kp_count = int(self.kpt_shape[0])
            # kpt_shape is [number of keypoints, values per keypoint]; honour 2-D keypoints when declared.
            if len(self.kpt_shape) > 1 and self.kpt_shape[1] in (2, 3):
                kpt_dim = int(self.kpt_shape[1])
        palette_size = max(n_classes, kp_count + 1, 10)
        class_colors = create_color_palette(pallete_name=self.palette, increments=palette_size)
        # Only needed when the circle size has to be derived from each image's dimensions.
        plotter = PlottingMixin() if self.circle_size is None else None

        saved_cnt = 0
        for img_cnt, (img_path, lbl_path) in enumerate(pairs):
            img = cv2.imread(img_path)
            if img is None:
                if self.verbose:
                    stdout_information(msg=f'Could not read image: {img_path}, skipping...', source=self.__class__.__name__)
                continue
            img_h, img_w = img.shape[:2]
            if self.circle_size is None:
                circle_size = plotter.get_optimal_circle_size(frame_size=(img_w, img_h), circle_frame_ratio=80)
            else:
                circle_size = self.circle_size
            thickness = max(1, circle_size) if self.thickness is None else self.thickness

            with open(lbl_path, 'r') as f:
                lines = f.readlines()
            for line_no, line in enumerate(lines, start=1):
                parts = line.strip().split()
                if len(parts) < 2:
                    continue
                try:
                    if project_type == 'bbox':
                        class_id, bbox = self._parse_bbox_line(parts, img_w, img_h)
                    elif project_type == 'keypoint':
                        class_id, bbox, kps = self._parse_keypoint_line(parts, img_w, img_h, kpt_dim=kpt_dim)
                    elif project_type == 'segmentation':
                        class_id, polygon = self._parse_seg_line(parts, img_w, img_h)
                    else:
                        continue
                except ValueError as e:
                    # One bad annotation should not abort the visualization of the remaining data.
                    if self.verbose:
                        stdout_information(msg=f'Skipping malformed annotation on line {line_no} of {lbl_path}: {e}', source=self.__class__.__name__)
                    continue
                if project_type == 'bbox':
                    color = tuple(int(c) for c in class_colors[class_id % len(class_colors)])
                    img = self._draw_bbox(img, class_id, bbox, color, thickness)
                elif project_type == 'keypoint':
                    colors = class_colors[:max(kps.shape[0] + 1, 1)]
                    img = self._draw_keypoints(img, class_id, bbox, kps, colors, circle_size, thickness)
                else:
                    color = tuple(int(c) for c in class_colors[class_id % len(class_colors)])
                    img = self._draw_segmentation(img, class_id, polygon, color, thickness)

            _, img_name, _ = get_fn_ext(filepath=img_path)
            save_path = os.path.join(self.save_dir, f'{img_name}{self.img_format}')
            if not cv2.imwrite(save_path, img):
                if self.verbose:
                    stdout_information(msg=f'Could not write image: {save_path}, skipping...', source=self.__class__.__name__)
                continue
            saved_cnt += 1
            if self.verbose:
                stdout_information(msg=f'Annotated image {img_cnt + 1}/{len(pairs)} saved ({img_name})', source=self.__class__.__name__)
        timer.stop_timer()
        stdout_success(msg=f'{saved_cnt} annotated images saved in {self.save_dir}', source=self.__class__.__name__, elapsed_time=timer.elapsed_time_str)
#if __name__ == '__main__': # viz = YOLOAnnotationVisualizer(map_yaml_path=r"E:\open_video\open_field_2\yolo_seg_project\map.yaml", # save_dir=r"E:\open_video\open_field_2\yolo_seg_project\annotation_imgs", # n=150, # show_names=False, # show_outline=False) # viz.run()