# Source code for simba.pose_importers.sleap_csv_importer

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import os
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd

from simba.data_processors.interpolate import Interpolate
from simba.data_processors.smoothing import Smoothing
from simba.mixins.config_reader import ConfigReader
from simba.mixins.pose_importer_mixin import PoseImporterMixin
from simba.utils.checks import (check_if_dir_exists,
                                check_if_keys_exist_in_dict, check_int,
                                check_str, check_that_column_exist,
                                check_valid_lst)
from simba.utils.enums import Methods, TagNames
from simba.utils.errors import AnimalNumberError, CountError
from simba.utils.printing import (SimbaTimer, log_event, stdout_information,
                                  stdout_success)
from simba.utils.read_write import (clean_sleap_file_name,
                                    find_all_videos_in_project,
                                    find_files_of_filetypes_in_directory,
                                    get_fn_ext, get_video_meta_data,
                                    read_sleap_csv, write_df)

# Name of the SLEAP CSV column that identifies which track each row belongs to.
TRACK = "track"
# Name of the SLEAP CSV column that, when present, is dropped before the data is imported.
INSTANCE_SCORE = "instance.score"

class SLEAPImporterCSV(ConfigReader, PoseImporterMixin):
    """
    Importing SLEAP pose-estimation data into SimBA project in ``CSV`` format.

    .. note::
       `Google Colab notebook for converting SLEAP .slp to CSV written by @Toshea111 <https://colab.research.google.com/drive/1EpyTKFHVMCqcb9Lj9vjMrriyaG9SvrPO?usp=sharing>`__.

       `Example expected SLEAP csv data file for 5 animals / 4 pose-estimated body-parts <https://github.com/sgoldenlab/simba/blob/master/misc/sleap_csv_example.csv>`__.

    :param str config_path: path to SimBA project config file in Configparser format
    :param str data_folder: Path to folder containing SLEAP data in `csv` format.
    :param List[str] id_lst: Animal names. This will be ignored in one animal projects and default to ``Animal_1``.
    :param Optional[Dict[str, str]] interpolation_settings: Dict defining the type and method to use to perform interpolation {'type': 'animals', 'method': 'linear'}.
    :param Optional[Dict[str, Union[str, int]]] smoothing_settings: Dictionary defining the pose estimation smoothing method {'time_window': 500, 'method': 'gaussian'}.

    References
    ----------
    .. [1] Pereira, T. D., et al. (2022). SLEAP: A deep learning system for multi-animal pose tracking. `Nature Methods, 19, 486–495 <https://doi.org/10.1038/s41592-022-01426-1>`_.

    :example:
    >>> sleap_csv_importer = SLEAPImporterCSV(config_path=r'project_folder/project_config.ini', data_folder=r'data_folder', id_lst=['Termite_1', 'Termite_2', 'Termite_3', 'Termite_4', 'Termite_5'], interpolation_settings={'type': 'animals', 'method': 'linear'}, smoothing_settings = {'time_window': 500, 'method': 'gaussian'})
    >>> sleap_csv_importer.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 data_folder: Union[str, os.PathLike],
                 id_lst: List[str],
                 interpolation_settings: Optional[Dict[str, str]] = None,
                 smoothing_settings: Optional[Dict[str, Union[int, str]]] = None):

        ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
        PoseImporterMixin.__init__(self)

        # Validate all user-supplied arguments up front, before any project state is touched.
        check_if_dir_exists(in_dir=data_folder)
        check_valid_lst(data=id_lst, source=f'{self.__class__.__name__} id_lst', valid_dtypes=(str,), min_len=1)
        if interpolation_settings is not None:
            check_if_keys_exist_in_dict(data=interpolation_settings, key=['method', 'type'], name=f'{self.__class__.__name__} interpolation_settings')
            check_str(name=f'{self.__class__.__name__} interpolation_settings type', value=interpolation_settings['type'], options=('body-parts', 'animals'))
            check_str(name=f'{self.__class__.__name__} interpolation_settings method', value=interpolation_settings['method'], options=('linear', 'quadratic', 'nearest'))
        if smoothing_settings is not None:
            check_if_keys_exist_in_dict(data=smoothing_settings, key=['method', 'time_window'], name=f'{self.__class__.__name__} smoothing_settings')
            check_str(name=f'{self.__class__.__name__} smoothing_settings method', value=smoothing_settings['method'], options=('savitzky-golay', 'gaussian'))
            check_int(name=f'{self.__class__.__name__} smoothing_settings time_window', value=smoothing_settings['time_window'], min_value=1)

        # locals() is captured here, so no additional local names may be bound above this line.
        log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals()))

        self.interpolation_settings, self.smoothing_settings = interpolation_settings, smoothing_settings
        self.data_folder, self.id_lst = data_folder, id_lst
        self.import_log_path = os.path.join(self.logs_path, f"data_import_log_{self.datetime}.csv")
        self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir)
        self.input_data_paths = find_files_of_filetypes_in_directory(directory=self.data_folder, extensions=['.csv'], raise_error=True)

        # NOTE(review): `is` tests object identity, not string equality, so this branch (and the
        # matching one further down) is only taken when both sides are the very same str object.
        # `==` is most likely what was intended. It is deliberately left unchanged here because
        # `self.__update_config_animal_cnt` is name-mangled to
        # `self._SLEAPImporterCSV__update_config_animal_cnt`, which is not defined in this class;
        # switching to `==` must be done together with confirming that call resolves.
        if self.pose_setting is Methods.USER_DEFINED.value:
            self.__update_config_animal_cnt()

        if self.animal_cnt > 1:
            # Multi-animal projects need the video for each data file (passed to the identification UI in run()).
            self.data_and_videos_lk = self.link_video_paths_to_data_paths(data_paths=self.input_data_paths, video_paths=self.video_paths, filename_cleaning_func=clean_sleap_file_name)
            self.check_multi_animal_status()
            self.animal_bp_dict = self.create_body_part_dictionary(self.multi_animal_status, self.id_lst, self.animal_cnt, self.x_cols, self.y_cols, self.p_cols, self.clr_lst)
            if self.pose_setting is Methods.USER_DEFINED.value:
                self.update_bp_headers_file(update_bp_headers=True)
        else:
            # Single-animal projects are imported without videos; key each file by its name without extension.
            self.data_and_videos_lk = {
                get_fn_ext(file_path)[1]: {"DATA": file_path, "VIDEO": None}
                for file_path in self.input_data_paths
            }
        stdout_information(msg=f"Importing {len(self.data_and_videos_lk)} SLEAP CSV file(s)...")

    def run(self) -> None:
        """
        Import every SLEAP CSV file found in ``data_folder`` into the SimBA project.

        For each file: read the data, validate the number of tracks against the project animal
        count, reshape to one row per frame, apply the project body-part headers, write the
        result to ``input_csv_dir`` and, if requested, interpolate and smooth the written file.

        :raises AnimalNumberError: If the number of unique tracks in a file differs from the project animal count.
        :raises CountError: If the number of data columns in a file differs from the number of project body-part headers.
        """
        for video_name, video_data in self.data_and_videos_lk.items():
            output_filename = clean_sleap_file_name(filename=video_name)
            print(f"Importing {output_filename}...")
            video_timer = SimbaTimer(start=True)
            self.video_name = video_name
            self.save_path = os.path.join(self.input_csv_dir, f"{output_filename}.{self.file_type}")

            data_df, _, _ = read_sleap_csv(file_path=video_data["DATA"])
            if INSTANCE_SCORE in data_df.columns:
                data_df = data_df.drop([INSTANCE_SCORE], axis=1)

            # The first two columns are the index columns (track and frame). Take an explicit copy
            # because the track column is rewritten below and must not be a view onto data_df.
            idx = data_df.iloc[:, :2].copy()
            check_that_column_exist(df=idx, column_name=TRACK, file_name=video_name)
            data_unique_tracks = list(idx[TRACK].unique())
            if len(data_unique_tracks) != self.animal_cnt:
                raise AnimalNumberError(msg=f'The SLEAP CSV file {video_data["DATA"]} contains data for {len(data_unique_tracks)} tracks (found tracks: {data_unique_tracks}). The SimBA project config says the SimBA project expects data for {self.animal_cnt} animals.')
            if len(data_unique_tracks) == 1:
                # With a single track, normalise whatever label the file uses to 'track_0'.
                idx[TRACK] = 'track_0'
                data_unique_tracks = ['track_0']
            idx[TRACK] = idx[TRACK].fillna(data_unique_tracks[0])
            # Strip everything except digits and dots so the track label becomes an integer id.
            idx[TRACK] = idx[TRACK].astype(str).str.replace(r"[^\d.]+", "", regex=True).astype(int)
            data_df = data_df.iloc[:, 2:].fillna(0)

            if self.animal_cnt > 1:
                self.data_df = pd.DataFrame(self.transpose_multi_animal_table(data=data_df.values, idx=idx.values, animal_cnt=self.animal_cnt))
            else:
                # Index by frame number, then fill any frames missing from the file with zeros.
                frame_idx = list(idx["frame_idx"])
                self.data_df = data_df.set_index([frame_idx]).sort_index()
                self.data_df.columns = np.arange(len(self.data_df.columns))
                self.data_df = self.data_df.reindex(range(0, self.data_df.index[-1] + 1), fill_value=0)

            if len(self.bp_headers) != len(self.data_df.columns):
                raise CountError(msg=f"The SimBA project expects {len(self.bp_headers)} data columns, but your SLEAP data file at {video_data['DATA']} contains {len(self.data_df.columns)} columns.", source=self.__class__.__name__)
            self.data_df.columns = self.bp_headers
            self.out_df = deepcopy(self.data_df)

            if self.animal_cnt > 1:
                self.initialize_multi_animal_ui(animal_bp_dict=self.animal_bp_dict, video_info=get_video_meta_data(video_data["VIDEO"]), data_df=self.data_df, video_path=video_data["VIDEO"])
                self.multianimal_identification()

            write_df(df=self.out_df, file_type=self.file_type, save_path=self.save_path, multi_idx_header=True)

            # Interpolation and smoothing both operate on the file that was just written to save_path.
            if self.interpolation_settings is not None:
                interpolator = Interpolate(config_path=self.config_path, data_path=self.save_path, type=self.interpolation_settings['type'], method=self.interpolation_settings['method'], multi_index_df_headers=True, copy_originals=False)
                interpolator.run()
            if self.smoothing_settings is not None:
                smoother = Smoothing(config_path=self.config_path, data_path=self.save_path, time_window=self.smoothing_settings['time_window'], method=self.smoothing_settings['method'], multi_index_df_headers=True, copy_originals=False)
                smoother.run()

            video_timer.stop_timer()
            stdout_success(msg=f"Video {video_name} data imported...", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__)

        self.timer.stop_timer()
        stdout_success(msg=f"{len(self.data_and_videos_lk)} file(s) imported to the SimBA project {self.input_csv_dir}", source=self.__class__.__name__)
# Ad-hoc manual test invocations (kept commented out).

# test = SLEAPImporterCSV(config_path=r"D:\troubleshooting\two_animals_sleap\project_folder\project_config.ini",
#                         data_folder=r'D:\troubleshooting\two_animals_sleap\import_data',
#                         id_lst=['Animal_1', 'Animal_2'],
#                         interpolation_settings=None,
#                         smoothing_settings = None)
# test.run()

# test = SLEAPImporterCSV(config_path=r"F:\troubleshooting\sam\sam\project_folder\project_config.ini",
#                         data_folder=r'F:\troubleshooting\sam\raw_pose',
#                         id_lst=['Animal_1',],
#                         interpolation_settings={'type': 'body-parts', 'method': 'nearest'},
#                         smoothing_settings = {'time_window': 200, 'method': 'savitzky-golay'})
# test.run()
#
#
# test = SLEAPImporterCSV(config_path=r"C:\troubleshooting\sleap_import_two_tracks\project_folder\project_config.ini",
#                         data_folder=r'C:\troubleshooting\sleap_import_two_tracks\data_csv',
#                         id_lst=['Track_0', 'Track_1'],
#                         interpolation_settings={'type': 'animals', 'method': 'linear'},
#                         smoothing_settings = {'time_window': 500, 'method': 'gaussian'})
# test.run()
#
#
# test = SLEAPImporterCSV(config_path=r"C:\troubleshooting\sleap_import_two_tracks\project_folder\project_config.ini",
#                         data_folder=r'C:\troubleshooting\sleap_import_two_tracks\data_csv',
#                         id_lst=['Track_0', 'Track_1'],
#                         interpolation_settings={'type': 'animals', 'method': 'linear'},
#                         smoothing_settings = {'time_window': 500, 'method': 'gaussian'})
# test.run()

# test = SLEAPImporterCSV(config_path=r'/Users/simon/Desktop/envs/troubleshooting/Hornet/project_folder/project_config.ini',
#                         data_folder=r'/Users/simon/Desktop/envs/troubleshooting/Hornet_single_slp/import',
#                         id_lst=['Hornet'],
#                         interpolation_settings="Body-parts: Nearest",
#                         smoothing_settings = {'Method': 'None', 'Parameters': {'Time_window': '200'}})
# test.run()

# test = SLEAPImporterCSV(config_path=r'/Users/simon/Desktop/envs/troubleshooting/Hornet/project_folder/project_config.ini',
#                         data_folder=r'/Users/simon/Desktop/envs/troubleshooting/Hornet_single_slp/import',
#                         id_lst=['Hornet'],
#                         interpolation_settings="Body-parts: Nearest",
#                         smoothing_settings = {'Method': 'None', 'Parameters': {'Time_window': '200'}})
# test.run()

# test = SLEAPImporterCSV(config_path=r'/Users/simon/Desktop/envs/troubleshooting/slp_1_animal_1_bp/project_folder',
#                         data_folder='/Users/simon/Desktop/envs/troubleshooting/slp_1_animal_1_bp/import',
#                         id_lst=['Termite_1'],
#                         interpolation_settings="Body-parts: Nearest",
#                         smoothing_settings = {'Method': 'Savitzky Golay', 'Parameters': {'Time_window': '200'}})
# test.run()