Source code for simba.model.sam_inference

import os

import pandas as pd

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from typing import List, Optional, Tuple, Union

import numpy as np

from simba.data_processors.cuda.utils import _is_cuda_available

# ultralytics is an optional dependency: fall back to ``None`` when it (or the
# SAM2 predictor inside it) cannot be imported, and let the consumer raise a
# descriptive error at use time. ``Exception`` rather than a bare ``except`` so
# that KeyboardInterrupt / SystemExit raised during the import are not swallowed.
try:
    from ultralytics.models.sam import SAM2VideoPredictor
except Exception:
    SAM2VideoPredictor = None

from simba.utils.checks import (check_file_exist_and_readable, check_float,
                                check_if_dir_exists, check_instance, check_int,
                                check_valid_array, check_valid_tuple)
from simba.utils.data import resample_geometry_vertices
from simba.utils.enums import Formats, Options
from simba.utils.errors import (InvalidInputError, SimBAGPUError,
                                SimBAPAckageVersionError)
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
                                    get_fn_ext, get_video_meta_data, write_df)


class SamInference():
    r"""
    Run SAM2 video segmentation on one or more videos and save the resulting polygon
    vertices, one row per frame and animal, as a CSV file per video.

    .. note::
       A GPU is required; ``SimBAGPUError`` is raised when ``_is_cuda_available()`` reports none.
       The ``NAME`` column of the output holds the integer position of the animal in ``names``.
       Frames where an animal is not present in the prediction are written with all vertices set to ``-1``.

    :param Union[str, os.PathLike] video_path: Path to a single video file, or to a directory of video files.
    :param Union[str, os.PathLike] weights_path: Path to the SAM2 weights file handed to the predictor as ``model``.
    :param Union[str, os.PathLike] save_dir: Existing directory where the ``<video name>.csv`` files are written.
    :param Union[np.ndarray, List[List[int]]] prompts: Integer point prompts. Indexed by video order: entry ``i`` is passed as ``points`` for video ``i``.
    :param Union[np.ndarray, List[List[int]]] labels: Integer point labels. Indexed by video order: entry ``i`` is passed as ``labels`` for video ``i``.
    :param Tuple[str, ...] names: Animal names. Must have the same length as ``labels[0]``.
    :param Optional[int] imgsz: Inference image size passed to the predictor. Minimum 1. Default 1024.
    :param Optional[float] confidence: Confidence threshold passed to the predictor. Maximum 1.0. Default 0.25.
    :param Optional[int] vertice_cnt: Number of vertices each detected polygon is resampled to. Minimum 3. Default 100.
    :raises SimBAGPUError: If no GPU is detected.
    :raises SimBAPAckageVersionError: If ``ultralytics.models.sam.SAM2VideoPredictor`` could not be imported.
    :raises InvalidInputError: If ``video_path`` is neither a file nor a directory, or if fewer prompt/label entries than videos are passed.

    :example:
    >>> i = SamInference(video_path=r"MyVideo",
    >>>                  labels=[[1]],
    >>>                  prompts=[[166, 428]],
    >>>                  weights_path=r"D:\yolo_weights\sam2.1_b.pt",
    >>>                  save_dir=r'C:\troubleshooting\sam_results',
    >>>                  names=('Animal1',))
    >>> i.run()

    .. video:: _static/img/sam_example.webm
       :loop:
       :muted:
       :align: center
    """

    def __init__(self,
                 video_path: Union[str, os.PathLike],
                 weights_path: Union[str, os.PathLike],
                 save_dir: Union[str, os.PathLike],
                 prompts: Union[np.ndarray, List[List[int]]],
                 labels: Union[np.ndarray, List[List[int]]],
                 names: Tuple[str, ...],
                 imgsz: Optional[int] = 1024,
                 confidence: Optional[float] = 0.25,
                 vertice_cnt: Optional[int] = 100):

        source = self.__class__.__name__
        if not _is_cuda_available()[0]:
            raise SimBAGPUError(msg='No GPU detected.', source=source)
        if SAM2VideoPredictor is None:
            raise SimBAPAckageVersionError(msg='ultralytics.models.sam.SAM2VideoPredictor package not detected.', source=source)
        if os.path.isdir(video_path):
            self.video_paths = find_files_of_filetypes_in_directory(directory=video_path, extensions=Options.ALL_VIDEO_FORMAT_OPTIONS.value, raise_error=True)
        elif os.path.isfile(video_path):
            self.video_paths = [video_path]
        else:
            raise InvalidInputError(msg=f'{video_path} is not a valid file or directory path.', source=source)
        # Called only for its validation side effect: fail early on unreadable videos.
        for path in self.video_paths:
            get_video_meta_data(video_path=path)
        check_instance(source=f'{source} prompts', instance=prompts, accepted_types=(list, np.ndarray,), raise_error=True, warning=False)
        check_instance(source=f'{source} labels', instance=labels, accepted_types=(list, np.ndarray,), raise_error=True, warning=False)
        prompts = np.array(prompts) if isinstance(prompts, list) else prompts
        labels = np.array(labels) if isinstance(labels, list) else labels
        check_valid_array(data=prompts, source=f'{source} prompts', accepted_ndims=(1, 2, 3,), accepted_dtypes=Formats.INTEGER_DTYPES.value)
        check_valid_array(data=labels, source=f'{source} labels', accepted_ndims=(1, 2,), accepted_dtypes=Formats.INTEGER_DTYPES.value)
        # run() indexes prompts and labels by video position, so every video needs an entry.
        # Surface a shortfall here instead of as a bare IndexError part-way through a batch.
        video_cnt = len(self.video_paths)
        if len(prompts) < video_cnt or len(labels) < video_cnt:
            raise InvalidInputError(msg=f'{video_cnt} video(s) found, but {len(prompts)} prompt entries and {len(labels)} label entries were passed. One entry per video is required.', source=source)
        check_file_exist_and_readable(file_path=weights_path)
        check_if_dir_exists(in_dir=save_dir, raise_error=True)
        check_int(name=f'{source} imgsz', value=imgsz, min_value=1)
        check_int(name=f'{source} vertice_cnt', value=vertice_cnt, min_value=3)
        check_float(name=f'{source} confidence', value=confidence, min_value=10e-6, max_value=1.0)
        check_valid_tuple(x=names, source=f'{source} names', accepted_lengths=(len(labels[0]),))
        self.animal_name_dict = dict(enumerate(names))
        self.prompts, self.lbls, self.vertice_cnt = prompts, labels, vertice_cnt
        self.save_dir, self.names = save_dir, names
        self.overrides = self._build_overrides(weights_path=weights_path, imgsz=imgsz, confidence=confidence)
        self.predictor = SAM2VideoPredictor(overrides=self.overrides)
        self.vertice_col_names = self._vertice_column_names(vertice_cnt=self.vertice_cnt)

    @staticmethod
    def _build_overrides(weights_path: Union[str, os.PathLike], imgsz: int, confidence: float) -> dict:
        """Build the predictor ``overrides`` dict, using the user-supplied weights as the model."""
        return dict(conf=confidence, task="segment", mode="predict", imgsz=imgsz, model=str(weights_path))

    @staticmethod
    def _vertice_column_names(vertice_cnt: int) -> List[str]:
        """Return the output header: ``FRAME``, ``NAME``, then an x/y column pair per vertex."""
        col_names = ['FRAME', 'NAME']
        for i in range(vertice_cnt):
            col_names.extend((f"VERTICE_{i}_x", f"VERTICE_{i}_y"))
        return col_names

    @staticmethod
    def _to_row(frm_cnt: int, animal_idx: int, vertices: np.ndarray) -> np.ndarray:
        """Prefix a flat vertex array with the frame number and the animal index."""
        row = np.insert(vertices, 0, animal_idx)
        return np.insert(row, 0, int(frm_cnt))

    def run(self):
        """
        Segment every video in ``self.video_paths`` and write one ``<video name>.csv`` per video to ``self.save_dir``.

        Each row holds the frame number, the animal index, and ``vertice_cnt`` x/y vertex pairs.
        """
        for video_cnt, video_path in enumerate(self.video_paths):
            video_labels, video_prompts = self.lbls[video_cnt], self.prompts[video_cnt]
            video_meta_data = get_video_meta_data(video_path=video_path)
            _, video_name, _ = get_fn_ext(filepath=video_path)
            save_path = os.path.join(self.save_dir, f'{video_name}.csv')
            results = self.predictor(source=video_path, points=video_prompts, labels=video_labels, stream=True)
            video_results = []
            for frm_cnt, video_predictions in enumerate(results):
                predicted_names = video_predictions.names
                for animal_name_cnt, _ in enumerate(self.names):
                    if predicted_names is None or animal_name_cnt not in predicted_names:
                        # Animal absent from this frame's prediction: placeholder vertices.
                        vertices = np.full(shape=(int(self.vertice_cnt * 2)), fill_value=-1, dtype=np.int32)
                    else:
                        vertices = video_predictions.masks[animal_name_cnt].xy[0].astype(np.int64)
                        # Keep the polygon inside the frame before resampling.
                        vertices[:, 0] = np.clip(vertices[:, 0], 0, video_meta_data['width'])
                        vertices[:, 1] = np.clip(vertices[:, 1], 0, video_meta_data['height'])
                        vertices = resample_geometry_vertices(vertices=vertices.reshape(1, len(vertices), 2), vertice_cnt=self.vertice_cnt)[0].flatten().astype(np.int64)
                    video_results.append(self._to_row(frm_cnt=frm_cnt, animal_idx=animal_name_cnt, vertices=vertices))
            # NOTE: the NAME column keeps the integer animal index; self.animal_name_dict maps it to the name.
            video_results = pd.DataFrame(video_results, columns=self.vertice_col_names)
            video_results.to_csv(path_or_buf=save_path)
#write_df(df=video_results, file_type='csv', save_path=save_path, multi_idx_header=False) # i = SamInference(video_path=r"D:\platea\platea_videos\videos\clipped\10B_Mouse_5-choice_MustTouchTrainingNEWFINAL_a7_clipped_3.mp4", # labels=[[1]], # prompts=[[166, 428]], # weights_path=r"D:\yolo_weights\sam2.1_b.pt", # save_dir=r'C:\troubleshooting\sam_results', # names=('Animal1',)) # i.run()