# Source code for simba.third_party_label_appenders.transform.simba_to_yolo

import os
import random
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd

try:
    from typing import Literal
except:
    from typing_extensions import Literal

import cv2
import numpy as np

from simba.mixins.config_reader import ConfigReader
from simba.third_party_label_appenders.transform.utils import (
    create_yolo_keypoint_yaml, get_yolo_keypoint_flip_idx)
from simba.utils.checks import (check_file_exist_and_readable, check_float,
                                check_if_dir_exists, check_int, check_str,
                                check_valid_boolean, check_valid_dataframe,
                                check_valid_tuple,
                                check_video_and_data_frm_count_align)
from simba.utils.enums import Formats, Options
from simba.utils.errors import NoFilesFoundError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (create_directory,
                                    find_files_of_filetypes_in_directory,
                                    get_video_meta_data, read_df,
                                    read_frm_of_video)
from simba.utils.warnings import NoDataFoundWarning
from simba.utils.yolo import keypoint_array_to_yolo_annotation_str


class SimBA2Yolo:
    r"""
    Convert pose estimation data from a SimBA project into the YOLO keypoint format, including frame sampling,
    image-label pair creation, bounding box computation, and train/validation splitting.

    .. note::
       For creating the ``flip_idx``, see :func:`simba.third_party_label_appenders.transform.utils.get_yolo_keypoint_flip_idx`.
       For creating the ``bp_id_idx``, see :func:`simba.third_party_label_appenders.converters.get_yolo_keypoint_bp_id_idx`

    :param Union[str, os.PathLike] config_path: Path to the SimBA project `.ini` configuration file.
    :param Union[str, os.PathLike] save_dir: Existing directory where YOLO-formatted data will be saved. Subdirectories for images/labels (train/val) are created.
    :param Optional[Union[str, os.PathLike]] data_dir: Optional directory containing outlier-corrected SimBA pose estimation data. If None, uses the outlier-corrected directory from the config.
    :param float train_size: Proportion of samples to allocate to the training set (range 0.1 to 0.99). Remaining samples go to validation.
    :param bool verbose: If True, prints progress updates to the console.
    :param bool greyscale: If True, saves extracted video frames in greyscale. Otherwise, saves in color.
    :param bool clahe: Forwarded to the frame reader (``read_frm_of_video``) as its ``clahe`` argument.
    :param float padding: Padding added around the bounding box (as a proportion of image dimensions, range 0.0 to 1.0). Useful if animal body-parts are in a "line".
    :param float threshold: Probability threshold (range 0.0 to 1.0). Only frames where every ``*_p`` column is strictly greater than this value are kept.
    :param Optional[Tuple[int, ...]] flip_idx: Tuple defining symmetric keypoint indices for horizontal flipping. Used to write the `map.yaml` file. If None, then attempt to infer.
    :param Tuple[str, ...] names: Class names, one per individual. Enumerated into ``{class_id: name}`` and written to `map.yaml`. Body-parts are split evenly, in order, between the classes.
    :param Optional[int] sample_size: If specified, limits the number of randomly sampled frames per video. If None, all frames are used.
    :param Optional[Dict[int, Union[Tuple[int], List[int]]]] bp_id_idx: Optional mapping of instance IDs to keypoint index groups. NOTE: stored on the instance, but ``run`` does not read it and always derives the grouping from ``names``.
    :param Optional[str] single_id: If the data contains pose-estimation for multiple individuals, but you want to treat it as examples of a single individual, pass the name of the single individual. If None, one class per entry in ``names`` is used.
    :return: None. Saves YOLO-formatted images and annotations to disk in the `save_dir` location.

    :example:
    >>> SAVE_DIR = r'D:\troubleshooting\mitra\mitra_yolo'
    >>> CONFIG_PATH = r"C:\troubleshooting\mitra\project_folder\project_config.ini"
    >>> runner = SimBA2Yolo(config_path=CONFIG_PATH, save_dir=SAVE_DIR, sample_size=10, verbose=True)
    >>> runner.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 save_dir: Union[str, os.PathLike],
                 data_dir: Optional[Union[str, os.PathLike]] = None,
                 train_size: float = 0.7,
                 verbose: bool = False,
                 greyscale: bool = False,
                 clahe: bool = False,
                 padding: float = 0.00,
                 threshold: float = 0.00,
                 flip_idx: Optional[Tuple[int, ...]] = None,
                 names: Tuple[str, ...] = ('animal_1',),
                 sample_size: Optional[int] = None,
                 bp_id_idx: Optional[Dict[int, Union[Tuple[int], List[int]]]] = None,
                 single_id: Optional[str] = None) -> None:

        # --- argument validation -------------------------------------------------
        check_valid_boolean(value=verbose, source=f'{self.__class__.__name__} verbose')
        check_valid_boolean(value=greyscale, source=f'{self.__class__.__name__} greyscale')
        check_valid_boolean(value=clahe, source=f'{self.__class__.__name__} clahe')
        check_file_exist_and_readable(file_path=config_path)
        check_float(name=f'{self.__class__.__name__} padding', value=padding, max_value=1.0, min_value=0.0, raise_error=True)
        check_float(name=f'{self.__class__.__name__} train_size', value=train_size, max_value=0.99, min_value=0.1)
        check_float(name=f'{self.__class__.__name__} threshold', value=threshold, max_value=1.0, min_value=0.0)
        check_valid_tuple(x=names, source=f'{self.__class__.__name__} names', valid_dtypes=(str,), minimum_length=1)
        check_if_dir_exists(in_dir=save_dir)
        if flip_idx is not None:
            check_valid_tuple(x=flip_idx, source=self.__class__.__name__, valid_dtypes=(int,), minimum_length=1)

        # --- output layout: <save_dir>/{images,labels}/{train,val} + map.yaml ------
        self.img_dir, self.lbl_dir = os.path.join(save_dir, 'images'), os.path.join(save_dir, 'labels')
        self.img_train_dir, self.img_val_dir = os.path.join(self.img_dir, 'train'), os.path.join(self.img_dir, 'val')
        # NOTE: ``lb_val_dir`` (sic) is kept under its original name; it is a public attribute.
        self.lbl_train_dir, self.lb_val_dir = os.path.join(self.lbl_dir, 'train'), os.path.join(self.lbl_dir, 'val')
        create_directory(paths=[self.img_train_dir, self.img_val_dir, self.lbl_train_dir, self.lb_val_dir], overwrite=False)
        self.map_path = os.path.join(save_dir, 'map.yaml')

        if single_id is not None:
            check_str(name=f'{self.__class__.__name__} single_id', value=single_id, raise_error=True)
        if sample_size is not None:
            check_int(name=f'{self.__class__.__name__} sample', value=sample_size, min_value=1)

        # --- locate pose data and the videos they belong to ------------------------
        self.config = ConfigReader(config_path=config_path)
        if data_dir is not None:
            check_if_dir_exists(in_dir=data_dir, source=f'{self.__class__.__name__} data_dir')
        else:
            data_dir = self.config.outlier_corrected_dir
        self.data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=[f'.{self.config.file_type}'], raise_error=True, as_dict=True)
        self.video_paths = find_files_of_filetypes_in_directory(directory=self.config.video_dir, extensions=Options.ALL_VIDEO_FORMAT_OPTIONS.value, raise_error=True, as_dict=True)
        missing_videos = [x for x in self.data_paths if x not in self.video_paths]
        if len(missing_videos) > 0:
            NoDataFoundWarning(msg=f'Data files {missing_videos} do not have corresponding videos in the {self.config.video_dir} directory', source=self.__class__.__name__)
        # Only data files that have a same-named video can be converted.
        self.data_w_video = [x for x in self.data_paths if x in self.video_paths]
        if len(self.data_w_video) == 0:
            raise NoFilesFoundError(msg=f'None of the data files in {data_dir} have matching videos in the {self.config.video_dir} directory', source=self.__class__.__name__)

        self.sample_size, self.train_size, self.verbose, self.save_dir = sample_size, train_size, verbose, save_dir
        self.greyscale, self.bp_id_idx, self.padding, self.flip_idx = greyscale, bp_id_idx, padding, flip_idx
        self.clahe, self.threshold, self.single_id = clahe, threshold, single_id
        # With ``single_id`` everything is collapsed into one class (id 0).
        self.names = {0: self.single_id} if self.single_id is not None else dict(enumerate(names))

    @staticmethod
    def _sample_train_rows(n_rows: int, train_size: float) -> set:
        """Return a random set of ``int(n_rows * train_size)`` row positions in ``[0, n_rows)`` for the train split."""
        return set(random.sample(range(n_rows), int(n_rows * train_size)))

    @staticmethod
    def _is_unlabelled(keypoints: np.ndarray) -> bool:
        """Return True if every coordinate in ``keypoints`` is NaN or 0.0, i.e. nothing usable is labelled."""
        return bool(np.all(np.isnan(keypoints) | (keypoints == 0.0)))

    @staticmethod
    def _add_visibility(keypoints: np.ndarray) -> np.ndarray:
        """
        Append a visibility column to an ``(n, 2)`` x/y array, returning an ``(n, 3)`` float array.

        Visibility is 2 for usable keypoints and 0 where both coordinates are zero or either is non-finite.
        """
        xy = np.asarray(keypoints, dtype=np.float64)
        missing = np.all(xy == 0, axis=1) | ~np.all(np.isfinite(xy), axis=1)
        return np.column_stack((xy, np.where(missing, 0.0, 2.0)))

    def run(self):
        """
        Sample frames, write image/label pairs into the train and val directories, and write ``map.yaml``.

        :return: None. Files are written beneath ``save_dir``.
        """
        annotations, timer, body_part_headers = [], SimbaTimer(start=True), []
        bp_header_names = [x.lower() for x in self.config.bp_headers]
        coordinate_cols = [x for x in bp_header_names if not x.endswith('_p')]

        for video_name in self.data_w_video:
            data_path = self.data_paths[video_name]
            data = read_df(file_path=data_path, file_type=self.config.file_type)
            data.columns = [x.lower() for x in list(data.columns)]
            check_valid_dataframe(df=data, source=f'{self.__class__.__name__} {data_path}', valid_dtypes=Formats.NUMERIC_DTYPES.value)
            video_path = self.video_paths[video_name]
            check_video_and_data_frm_count_align(video=video_path, data=data, name=data_path, raise_error=True)
            # After the reset, the index equals the frame number in the video.
            data = data.reset_index(drop=True)
            # Keep frames where *every* probability column is strictly above the threshold.
            # A positional boolean mask avoids mixing index labels with ``.iloc`` positions.
            is_p_col = data.columns.str.endswith('_p')
            confident = (data.loc[:, is_p_col] > self.threshold).all(axis=1).values
            # ``.copy()`` so the 'video' column below is added to an owned frame, not a slice.
            data = data.loc[confident, coordinate_cols].copy()
            body_part_headers = list(data.columns)
            data['video'], frm_cnt = video_name, len(data)
            if self.sample_size is None or self.sample_size > frm_cnt:
                video_sample_idx = list(range(frm_cnt))
            else:
                video_sample_idx = random.sample(range(frm_cnt), self.sample_size)
            # ``drop=False`` preserves the frame number in an 'index' column.
            annotations.append(data.iloc[video_sample_idx].reset_index(drop=False))

        if self.flip_idx is None:
            # Strip the trailing '_x' / '_y' to get unique, ordered body-part names.
            self.flip_idx = get_yolo_keypoint_flip_idx(x=list(dict.fromkeys([x[:-2] for x in body_part_headers])))

        annotations = pd.concat(annotations, axis=0).reset_index(drop=True)
        video_names = annotations.pop('video').reset_index(drop=True).values
        n_annotations = len(annotations)
        # Split on row position, not frame number: frame numbers repeat across videos, so a
        # frame-number based split leaks frames into train and distorts the requested ratio.
        train_rows = self._sample_train_rows(n_rows=n_annotations, train_size=self.train_size)
        # Read frame numbers and coordinates as separate arrays; iterating rows with
        # ``iterrows()`` would upcast the integer frame number to float.
        frame_numbers = annotations['index'].values
        coordinates = annotations.drop(columns=['index']).values.astype(np.float64)
        # Body-parts are divided evenly, in column order, between the classes.
        bp_id_idx = [list(x) for x in np.array_split(np.arange(len(body_part_headers) // 2), len(self.names))]
        video_stems = {}  # video key -> 'video_name' from the video meta data (looked up once per video)

        for cnt in range(n_annotations):
            source_video = video_names[cnt]
            vid_path = self.video_paths[source_video]
            frm_idx = int(frame_numbers[cnt])
            keypoints = coordinates[cnt].reshape(-1, 2).copy()
            # (0, 0) marks a missing body-part; convert to NaN so it is flagged as not visible.
            keypoints[np.all(keypoints == 0.0, axis=1)] = np.nan
            if self._is_unlabelled(keypoints):
                continue
            if self.verbose:
                print(f'Processing image {cnt + 1}/{len(annotations)}...')
            if source_video not in video_stems:
                video_stems[source_video] = get_video_meta_data(video_path=vid_path)['video_name']
            file_name = f'{video_stems[source_video]}.{frm_idx}'
            if cnt in train_rows:
                img_save_path, lbl_save_path = os.path.join(self.img_train_dir, f'{file_name}.png'), os.path.join(self.lbl_train_dir, f'{file_name}.txt')
            else:
                img_save_path, lbl_save_path = os.path.join(self.img_val_dir, f'{file_name}.png'), os.path.join(self.lb_val_dir, f'{file_name}.txt')
            img = read_frm_of_video(video_path=vid_path, frame_index=frm_idx, greyscale=self.greyscale, clahe=self.clahe)
            img_h, img_w = img.shape[0], img.shape[1]

            label_lines = []
            for instance_id, instance_bp_idx in enumerate(bp_id_idx):
                instance_keypoints = keypoints[instance_bp_idx, :]
                if self._is_unlabelled(instance_keypoints):
                    continue
                instance_keypoints = self._add_visibility(instance_keypoints)
                instance_str = f'{instance_id} ' if self.single_id is None else '0 '
                instance_str += keypoint_array_to_yolo_annotation_str(x=instance_keypoints, img_w=img_w, img_h=img_h, padding=self.padding)
                label_lines.append(instance_str.strip() + '\n')

            with open(lbl_save_path, mode='wt', encoding='utf-8') as f:
                f.write(''.join(label_lines))
            cv2.imwrite(img_save_path, img)

        create_yolo_keypoint_yaml(path=self.save_dir, train_path=self.img_train_dir, val_path=self.img_val_dir, names=self.names, save_path=self.map_path, kpt_shape=(len(self.flip_idx), 3), flip_idx=self.flip_idx)
        timer.stop_timer()
        stdout_success(msg=f'YOLO formated data saved in {self.save_dir} directory', source=self.__class__.__name__, elapsed_time=timer.elapsed_time_str)
#
# SAVE_DIR = r'E:\troubleshooting\mitra\yolo_0126\yolo_train_0126'
# CONFIG_PATH = r"E:\troubleshooting\mitra\project_folder\project_config.ini"
# runner = SimBA2Yolo(config_path=CONFIG_PATH, save_dir=SAVE_DIR, sample_size=50, verbose=True, names=('animal_1',), threshold=0.5)
# runner.run()