Source code for simba.model.yolo_fit

import os
import sys
import urllib.request
from contextlib import redirect_stderr, redirect_stdout

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import argparse
from typing import Optional, Union

try:
    from typing import Literal
except:
    from typing_extensions import Literal

try:
    from ultralytics import YOLO
except ModuleNotFoundError:
    YOLO = None

from simba.data_processors.cuda.utils import _is_cuda_available
from simba.utils.checks import (check_file_exist_and_readable,
                                check_if_dir_exists, check_int, check_str,
                                check_valid_boolean, check_valid_device,
                                check_valid_url)
from simba.utils.enums import Options
from simba.utils.errors import SimBAGPUError, SimBAPAckageVersionError
from simba.utils.printing import stdout_information
from simba.utils.read_write import find_core_cnt, get_current_time
from simba.utils.yolo import load_yolo_model

#YOLO_X_PATH = "https://huggingface.co/Ultralytics/YOLO11/resolve/main/yolo11x-pose.pt"

YOLO_M_PATH = "https://huggingface.co/Ultralytics/YOLO11/resolve/main/yolo11m-pose.pt"


class FitYolo():
    """
    Fit an Ultralytics YOLO model (detection, pose, or segmentation) from SimBA projects with parameter validation.

    .. note::
       - Works with any Ultralytics model flavour (bbox, pose, segmentation).
       - Download starter weights from `HuggingFace <https://huggingface.co/Ultralytics>`__.
       - Example dataset YAMLs: `bbox <https://github.com/sgoldenlab/simba/blob/master/misc/ex_yolo_model.yaml>`__, `pose <https://github.com/sgoldenlab/simba/blob/master/misc/ex_yolo_model_keypoints.yaml>`__.
       - A CUDA GPU must be detected at construction time, regardless of the ``device`` argument.

    .. seealso::
       :func:`simba.bounding_box_tools.yolo.utils.fit_yolo` for the functional API.
       :func:`simba.utils.yolo.load_yolo_model` to load trained weights.
       For instructions, see `YOLO Pose Estimation Training Documentation <https://github.com/sgoldenlab/simba/blob/master/docs/yolo_train.md>`_.

    :param Union[str, os.PathLike] model_yaml: Dataset configuration YAML describing dataset folders and class labels.
    :param Union[str, os.PathLike] save_path: Directory where training outputs (weights, metrics, plots) are written. Must already exist.
    :param Optional[Union[str, os.PathLike]] weights_path: Path to base weights (e.g., ``yolo11n.pt`` or ``.onnx`` export). If ``None``, ``yolo11m-pose.pt`` is downloaded from HuggingFace into the current working directory (an existing file with that name is reused). Default ``None``.
    :param int epochs: Training epochs to run. Must be ≥ 1. Default ``200``.
    :param Union[int, float] batch: Batch size per step, passed unchanged to Ultralytics ``model.train(batch=...)``. Default ``16``.
    :param bool plots: If ``True``, Ultralytics saves training curves. Default ``True``.
    :param int imgsz: Square image resolution used during training. Default ``640``.
    :param Optional[str] format: Optional weights format override (case-insensitive; stored lower-cased). Must belong to :class:`simba.utils.enums.Options.VALID_YOLO_FORMATS`. Default ``None``.
    :param Union[Literal['cpu'], int] device: Compute device string or CUDA index. Overridden with ``0`` when ``device_id`` is given. Default ``0``.
    :param bool verbose: Emit detailed progress information. Default ``True``.
    :param int workers: Data-loader worker processes. Use ``-1`` for all cores; ``0`` is rejected. Default ``8``.
    :param int patience: Early-stopping patience (epochs without improvement). Must be ≥ 1. Default ``500``.
    :param Union[bool, Literal['disk']] cache: Image caching strategy. ``True`` caches all dataset images in RAM on the first epoch so subsequent epochs read from memory instead of disk (fastest, requires the dataset to fit in RAM). ``"disk"`` caches decoded images as ``.npy`` files on disk (avoids re-decoding each epoch without needing the dataset to fit in RAM, but uses more disk space). ``False`` disables caching. Default ``False``.
    :param Optional[int] device_id: GPU id to train on. When set, ``CUDA_VISIBLE_DEVICES`` is restricted to this id and training uses device ``0``. Must be one of the detected GPU ids. Default ``None``.

    :raises SimBAGPUError: If no CUDA-capable GPU is detected, or ``device_id`` is not among the detected GPUs.
    :raises SimBAPAckageVersionError: If ``ultralytics`` is unavailable in the environment.
    :raises FileNotFoundError: If ``weights_path`` or ``model_yaml`` do not exist.
    :raises ValueError: If provided arguments fail SimBA validation checks.

    :example:
    >>> fitter = FitYolo(
    ...     weights_path=r"D:\\yolo_weights\\yolo11n-pose.pt",
    ...     model_yaml=r"D:\\datasets\\pose_project\\map.yaml",
    ...     save_path=r"D:\\datasets\\pose_project\\mdl",
    ...     epochs=300,
    ...     batch=24,
    ...     device=0,
    ...     imgsz=640,
    ... )
    >>> fitter.run()
    """

    def __init__(self,
                 model_yaml: Union[str, os.PathLike],
                 save_path: Union[str, os.PathLike],
                 weights_path: Optional[Union[str, os.PathLike]] = None,
                 epochs: int = 200,
                 batch: Union[int, float] = 16,
                 plots: bool = True,
                 imgsz: int = 640,
                 format: Optional[str] = None,
                 device: Union[Literal['cpu'], int] = 0,
                 verbose: bool = True,
                 workers: int = 8,
                 patience: int = 500,
                 cache: Union[bool, Literal['disk']] = False,
                 device_id: Optional[int] = None):
        os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
        name = self.__class__.__name__

        gpu_available, gpus = _is_cuda_available()
        if not gpu_available:
            raise SimBAGPUError(msg='No GPU detected.', source=name)
        if device_id is not None:
            check_int(name=f'{name} device_id', value=device_id, min_value=0)
            gpu_ids = list(gpus.keys())
            if device_id not in gpu_ids:
                raise SimBAGPUError(msg=f'GPU device_id {device_id} not found. Available GPU id(s): {gpu_ids}', source=name)
            # Expose only the requested GPU; it is then addressed as index 0.
            # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if no CUDA context has been
            # created in this process yet (e.g. by _is_cuda_available) — confirm.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
            device = 0
        if YOLO is None:
            raise SimBAPAckageVersionError(msg='Ultralytics package not detected.', source=name)

        # Validate every argument before any (potentially slow) weights download so
        # invalid input fails fast.
        if weights_path is not None:
            check_file_exist_and_readable(file_path=weights_path)
        check_file_exist_and_readable(file_path=model_yaml)
        check_valid_boolean(value=verbose, source=f'{name} verbose', raise_error=True)
        check_valid_boolean(value=plots, source=f'{name} plots', raise_error=True)
        if isinstance(cache, bool):
            check_valid_boolean(value=cache, source=f'{name} cache', raise_error=True)
        else:
            check_str(name=f'{name} cache', value=cache, options=('disk',), raise_error=True)
        check_if_dir_exists(in_dir=save_path)
        if format is not None:
            # Validate and store the same (lower-cased) value so downstream code sees what was checked.
            format = format.lower()
            check_str(name=f'{name} format', value=format, options=Options.VALID_YOLO_FORMATS.value, raise_error=True)
        max_cores = find_core_cnt()[0]
        check_int(name=f'{name} epochs', value=epochs, min_value=1)
        check_int(name=f'{name} imgsz', value=imgsz, min_value=1)
        check_int(name=f'{name} workers', value=workers, min_value=-1, unaccepted_vals=[0], max_value=max_cores)
        check_int(name=f'{name} patience', value=patience, min_value=1)
        if workers == -1:
            workers = max_cores
        check_valid_device(device=device)

        self.model_yaml, self.epochs, self.batch = model_yaml, epochs, batch
        self.imgsz, self.device, self.workers, self.format = imgsz, device, workers, format
        self.plots, self.save_path, self.verbose, self.patience = plots, save_path, verbose, patience
        self.cache = cache

        # Resolve start weights last: this may trigger a network download.
        if weights_path is not None:
            self.weights_path = weights_path
        else:
            self._download_start_weights()

    def _download_start_weights(self, url: str = YOLO_M_PATH, save_path: Union[str, os.PathLike] = "yolo11m-pose.pt"):
        """
        Ensure starter weights exist at ``save_path`` and set ``self.weights_path`` to it.

        If ``save_path`` already exists it is reused without touching the network. Otherwise
        ``url`` is validated and downloaded to a temporary ``<save_path>.part`` file that is
        atomically renamed into place, so an interrupted download never leaves a truncated
        weights file that would be silently reused later.

        :param str url: Download URL for the starter weights.
        :param Union[str, os.PathLike] save_path: Local destination (relative paths resolve against the CWD).
        """
        name = self.__class__.__name__
        if os.path.isfile(save_path):
            stdout_information(msg=f'No start weights provided, using existing {save_path}', source=name)
        else:
            check_valid_url(url=url, raise_error=True, source=name)
            stdout_information(msg=f'No start weights provided, downloading {save_path} from {url}...', source=name)
            tmp_path = f'{os.fspath(save_path)}.part'
            try:
                urllib.request.urlretrieve(url, tmp_path)
                os.replace(tmp_path, save_path)
            finally:
                # Remove any partial download left behind by a failed transfer.
                if os.path.isfile(tmp_path):
                    os.remove(tmp_path)
            stdout_information(msg=f'Downloaded initial weights from {url}', source=name)
        self.weights_path = save_path

    def run(self):
        """
        Load the start weights and launch Ultralytics training with the validated settings.

        Training output is written to ``self.save_path`` (passed as Ultralytics ``project``).
        """
        name = self.__class__.__name__
        stdout_information(msg=f'[{get_current_time()}] Please follow the YOLO model training in the terminal from where SimBA was launched ...', source=name)
        stdout_information(msg=f'[{get_current_time()}] Results will be stored in the {self.save_path} directory ..', source=name)
        # Route Ultralytics output to the process-level streams (presumably so progress shows in the
        # launching terminal even if sys.stdout was redirected, e.g. by a GUI). Those streams can be
        # None (e.g. under pythonw), in which case fall back to the current ones.
        out_stream = sys.__stdout__ if sys.__stdout__ is not None else sys.stdout
        err_stream = sys.__stderr__ if sys.__stderr__ is not None else sys.stderr
        with redirect_stdout(out_stream), redirect_stderr(err_stream):
            model = load_yolo_model(weights_path=self.weights_path, verbose=self.verbose, format=self.format, device=self.device)
            model.train(data=self.model_yaml,
                        epochs=self.epochs,
                        project=self.save_path,
                        batch=self.batch,
                        plots=self.plots,
                        imgsz=self.imgsz,
                        workers=self.workers,
                        device=self.device,
                        patience=self.patience,
                        cache=self.cache)
# if __name__ == "__main__" and not hasattr(sys, 'ps1'): # parser = argparse.ArgumentParser(description="Fit YOLO model using ultralytics package.") # parser.add_argument('--weights_path', type=str, default=None, help='Path to the trained YOLO model weights (e.g., yolo11n-pose.pt). Omit to download default starter weights.') # parser.add_argument('--model_yaml', type=str, required=True, help='Path to map.yaml (model structure and label definitions)') # parser.add_argument('--save_path', type=str, required=True, help='Directory where trained model and logs will be saved') # parser.add_argument('--epochs', type=int, default=25, help='Number of epochs to train the model. Default is 25') # parser.add_argument('--batch', type=int, default=16, help='Batch size for training. Default is 16') # parser.add_argument('--plots', type=lambda x: str(x).lower() == 'true', default=True, help='Whether to plot training results. Use "True" or "False". Default is True') # parser.add_argument('--imgsz', type=int, default=640, help='Image size for training. Default is 640') # parser.add_argument('--format', type=str, default=None, help=f'Format of the YOLO model. Must be one of: {", ".join(Options.VALID_YOLO_FORMATS.value)}') # parser.add_argument('--device', type=str, default='0', help='Device to train on. Use "cpu" or GPU index (e.g., "0"). Default is "0"') # parser.add_argument('--verbose', type=lambda x: str(x).lower() == 'true', default=True, help='Print verbose messages. Use "True" or "False". Default is True') # parser.add_argument('--workers', type=int, default=8, help='Number of data loader workers. Default is 8. Use -1 for max cores') # parser.add_argument('--patience', type=int, default=100, help='Number of epochs to wait without improvement in validation metrics before early stopping the training. 
Default is 100') # # args = parser.parse_args() # # yolo_fitter = FitYolo(weights_path=args.weights_path, # model_yaml=args.model_yaml, # save_path=args.save_path, # epochs=args.epochs, # batch=args.batch, # plots=args.plots, # imgsz=args.imgsz, # format=args.format, # device=int(args.device) if args.device != 'cpu' else 'cpu', # verbose=args.verbose, # workers=args.workers, # patience=args.patience) # yolo_fitter.run() # fitter = FitYolo(weights_path=r"/home/cat/simon/yolo_0413/yolo26s.pt", # model_yaml=r"/home/cat/simon/yolo_0413/map.yaml", # save_path=r'/home/cat/simon/yolo_0413/mdl', # epochs=5000, # batch=2000, # format=None, # device_id=0, # imgsz=256) # fitter.run() # if __name__ == "__main__": # parser = argparse.ArgumentParser(description="Fit YOLO model") # parser.add_argument('--weights_path', type=str, default=r"/home/cat/simon/yolo_0413/yolo26s.pt") # parser.add_argument('--model_yaml', type=str, default=r"/home/cat/simon/yolo_0413/map.yaml") # parser.add_argument('--save_path', type=str, default=r'/home/cat/simon/yolo_0413/mdl') # parser.add_argument('--epochs', type=int, default=3000) # parser.add_argument('--batch', type=int, default=2000) # parser.add_argument('--imgsz', type=int, default=256) # parser.add_argument('--device_id', type=int, default=0) # parser.add_argument('--patience', type=int, default=500) # args = parser.parse_args() # # fitter = FitYolo(weights_path=args.weights_path, # model_yaml=args.model_yaml, # save_path=args.save_path, # epochs=args.epochs, # batch=args.batch, # format=None, # imgsz=args.imgsz, # patience=args.patience, # device_id=args.device_id) # fitter.run() # fitter = FitYolo(weights_path=r"D:\maplight_tg2576_yolo\yolo_mdl\original_weight_oct\best.pt", # model_yaml=r"D:\maplight_tg2576_yolo\yolo_mdl\map.yaml", # save_path=r"D:\maplight_tg2576_yolo\yolo_mdl\mdl", # epochs=1500, # batch=22, # format=None, # device=0, # imgsz=640) # fitter.run() # fitter = 
FitYolo(weights_path=r"E:\yolo_resident_intruder\mdl\train3\weights\best.pt", # model_yaml=r"E:\maplight_videos\yolo_mdl\map.yaml", # save_path=r"E:\maplight_videos\yolo_mdl\mdl", # epochs=1500, # batch=22, # format=None, # device=0, # imgsz=640) # fitter.run() # # fitter = FitYolo(weights_path=r"E:\netholabs_videos\3d\yolo_mdl\mdl\train5\weights\best.pt", # model_yaml=r"E:\netholabs_videos\3d\yolo_mdl\map.yaml", # save_path=r"E:\netholabs_videos\3d\yolo_mdl\mdl", # epochs=1500, # batch=8, # format=None, # device=0, # imgsz=820) # fitter.run() # fitter = FitYolo(weights_path=r"D:\yolo_weights\yolo11m-pose.pt", # model_yaml=r"D:\cvat_annotations\frames\yolo_072125\map.yaml", # save_path=r"D:\cvat_annotations\frames\yolo_072125\mdl", # epochs=1000, # batch=24, # format=None, # device=0, # imgsz=640) # fitter.run() # # # # # fitter = FitYolo(weights_path=r"D:\yolo_weights\yolo11m-seg.pt", # model_yaml=r"D:\troubleshooting\mitra\mitra_yolo_seg\map.yaml", # save_path=r"D:\troubleshooting\mitra\mitra_yolo_seg\mdl", # epochs=1500, # batch=16, # format=None, # device=0, # imgsz=640) # fitter.run() # # # fitter = FitYolo(weights_path=r"D:\maplight_tg2576_yolo\yolo_mdl\original_weight_oct\best.pt", # model_yaml=r"F:\todd_sleap\yolo_dataset\map.yaml", # save_path=r"F:\todd_sleap\yolo_dataset\mdl", # epochs=1500, # batch=22, # format=None, # device=0, # imgsz=640) # fitter.run()