# Source code for simba.pose_importers.madlc_importer

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import os
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

try:
    from typing import Literal
except:
    from typing_extensions import Literal

import numpy as np
import pandas as pd

from simba.data_processors.interpolate import Interpolate
from simba.data_processors.smoothing import Smoothing
from simba.mixins.config_reader import ConfigReader
from simba.mixins.pose_importer_mixin import PoseImporterMixin
from simba.utils.checks import (check_file_exist_and_readable,
                                check_if_dir_exists,
                                check_if_keys_exist_in_dict, check_instance,
                                check_int, check_str, check_valid_lst)
from simba.utils.enums import Formats, Methods, Options
from simba.utils.errors import BodypartColumnNotFoundError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (find_all_videos_in_project,
                                    get_video_meta_data, write_df)


class MADLCImporterH5(ConfigReader, PoseImporterMixin):
    """
    Import multi-animal DeepLabCut (maDLC) pose-estimation data (in H5 format) into a SimBA project
    in parquet or CSV format.

    :param str config_path: Path to SimBA project config file in Configparser format.
    :param str data_folder: Path to folder containing maDLC data in ``.h5`` format.
    :param str file_type: Method used to perform pose-estimation in maDLC. OPTIONS: `skeleton`, `box`, `ellipse`.
    :param List[str] id_lst: Names of animals.
    :param Optional[Dict[str, str]] interpolation_settings: Dict defining the type and method to use to perform interpolation, e.g. {'type': 'animals', 'method': 'linear'}. Accepted types: `body-parts`, `animals`. Accepted methods: `linear`, `quadratic`, `nearest`. If None, no interpolation is performed.
    :param Optional[Dict[str, Union[str, int]]] smoothing_settings: Dictionary defining the pose estimation smoothing method, e.g. {'time_window': 500, 'method': 'gaussian'}. Accepted methods: `savitzky-golay`, `gaussian`. ``time_window`` must be an integer >= 1. If None, no smoothing is performed.

    .. note::
       `Multi-animal import tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/Multi_animal_pose.md>`__.

    :examples:
    >>> _ = MADLCImporterH5(config_path=r'MyConfigPath', data_folder=r'maDLCDataFolder', file_type='ellipse', id_lst=['Animal_1', 'Animal_2'], interpolation_settings={'type': 'animals', 'method': 'linear'}, smoothing_settings={'time_window': 500, 'method': 'gaussian'}).run()

    References
    ----------
    .. [1] Lauer, J., et al. (2022). Multi-animal pose estimation, identification and tracking with DeepLabCut. `Nature Methods, 19, 496–504 <https://doi.org/10.1038/s41592-022-01443-0>`_.
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 data_folder: Union[str, os.PathLike],
                 file_type: Literal['skeleton', 'box', 'ellipse'],
                 id_lst: List[str],
                 interpolation_settings: Optional[Dict[str, str]] = None,
                 smoothing_settings: Optional[Dict[str, Any]] = None) -> None:

        # Validate every argument before any project state is read.
        check_file_exist_and_readable(file_path=config_path)
        check_if_dir_exists(in_dir=data_folder)
        check_str(name=f'{self.__class__.__name__} file_type', value=file_type, options=Options.MULTI_DLC_TYPE_IMPORT_OPTION.value)
        check_valid_lst(data=id_lst, source=f'{self.__class__.__name__} id_lst', valid_dtypes=(str,))
        if interpolation_settings is not None:
            check_if_keys_exist_in_dict(data=interpolation_settings, key=['method', 'type'], name=f'{self.__class__.__name__} interpolation_settings')
            check_str(name=f'{self.__class__.__name__} interpolation_settings type', value=interpolation_settings['type'], options=('body-parts', 'animals'))
            check_str(name=f'{self.__class__.__name__} interpolation_settings method', value=interpolation_settings['method'], options=('linear', 'quadratic', 'nearest'))
        if smoothing_settings is not None:
            check_if_keys_exist_in_dict(data=smoothing_settings, key=['method', 'time_window'], name=f'{self.__class__.__name__} smoothing_settings')
            check_str(name=f'{self.__class__.__name__} smoothing_settings method', value=smoothing_settings['method'], options=('savitzky-golay', 'gaussian'))
            check_int(name=f'{self.__class__.__name__} smoothing_settings time_window', value=smoothing_settings['time_window'], min_value=1)

        ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
        PoseImporterMixin.__init__(self)

        self.interpolation_settings, self.smoothing_settings = interpolation_settings, smoothing_settings
        self.data_folder, self.id_lst = data_folder, id_lst
        self.import_log_path = os.path.join(self.logs_path, f"data_import_log_{self.datetime}.csv")

        # Missing videos / unmatched data files only raise when more than one animal ID was passed.
        multiple_ids = len(id_lst) > 1
        self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir, raise_error=multiple_ids)
        # NOTE: the `file_type` argument is only used here to select the data-file extensions; it is not
        # stored on the instance. `self.file_type` (used in `run`) is not assigned in this class -
        # presumably it is set by ConfigReader; confirm.
        self.input_data_paths = self.find_data_files(dir=self.data_folder, extensions=Formats.DLC_FILETYPES.value[file_type])
        self.data_and_videos_lk = self.link_video_paths_to_data_paths(data_paths=self.input_data_paths,
                                                                      video_paths=self.video_paths,
                                                                      str_splits=Formats.DLC_NETWORK_FILE_NAMES.value,
                                                                      raise_error=multiple_ids)

        # NOTE(review): `is` tests object identity, not string equality, so this condition is only True when
        # `self.pose_setting` happens to be the very same object as the enum value. It is left unchanged on
        # purpose: the guarded call uses a double-underscore name, which inside this class is mangled to
        # `_MADLCImporterH5__update_config_animal_cnt`, and no such method is defined in the visible class
        # body. Confirm where that method lives before switching this comparison to `==`.
        if self.pose_setting is Methods.USER_DEFINED.value:
            self.__update_config_animal_cnt()
        if self.animal_cnt > 1:
            self.check_multi_animal_status()
            self.animal_bp_dict = self.create_body_part_dictionary(self.multi_animal_status, self.id_lst, self.animal_cnt, self.x_cols, self.y_cols, self.p_cols, self.clr_lst)
            # NOTE(review): same identity-vs-equality concern as above; kept as-is so runtime behavior
            # does not change until the intended semantics are confirmed.
            if self.pose_setting is Methods.USER_DEFINED.value:
                self.update_bp_headers_file(update_bp_headers=True)
        print(f"Importing {len(self.data_and_videos_lk)} multi-animal DLC file(s)...")

    def run(self) -> None:
        """
        Import every linked maDLC H5 file into the project's input directory.

        For each data file: read the H5 data, replace infinite values with NaN and fill NaN with 0, verify
        that the column count matches ``self.bp_headers`` and rename the columns accordingly. When the
        project has more than one animal, the multi-animal identification helpers are invoked; otherwise
        multi-index header columns are inserted directly. The result is written to ``self.input_csv_dir``
        and, if requested at construction, interpolation and then smoothing are run on the saved file.

        :raises BodypartColumnNotFoundError: If the number of columns in a data file differs from the number of body-part headers in the SimBA project.
        """

        for cnt, (video_name, video_data) in enumerate(self.data_and_videos_lk.items()):
            video_timer = SimbaTimer(start=True)
            self.add_spacer, self.frame_no, self.video_data, self.video_name = (2, 1, video_data, video_name)
            print(f"Processing {video_name} ({cnt+1}/{len(self.input_data_paths)})...")
            self.data_df = pd.read_hdf(video_data["DATA"]).replace([np.inf, -np.inf], np.nan).fillna(0)
            if len(self.data_df.columns) != len(self.bp_headers):
                raise BodypartColumnNotFoundError(
                    msg=f'The number of body-parts in data file {video_data["DATA"]} do not match the number of body-parts in your SimBA project. '
                        f"The number of of body-parts expected by your SimBA project is {int(len(self.bp_headers) / 3)}. "
                        f'The number of of body-parts contained in data file {video_data["DATA"]} is {int(len(self.data_df.columns) / 3)}. '
                        f"Make sure you have specified the correct number of animals and body-parts in your project. NOTE: The project body-parts is stored at {self.body_parts_path}."
                )
            self.data_df.columns = self.bp_headers
            if self.animal_cnt > 1:
                # `self.out_df` is not assigned in this branch here; presumably the mixin's
                # identification step sets it - confirm against PoseImporterMixin.
                self.initialize_multi_animal_ui(animal_bp_dict=self.animal_bp_dict,
                                                video_info=get_video_meta_data(video_data["VIDEO"]),
                                                data_df=self.data_df,
                                                video_path=video_data["VIDEO"])
                self.multianimal_identification()
            else:
                self.out_df = self.insert_multi_idx_columns(df=self.data_df.fillna(0))
            self.save_path = os.path.join(self.input_csv_dir, f"{self.video_name}.{self.file_type}")
            write_df(df=self.out_df, file_type=self.file_type, save_path=self.save_path, multi_idx_header=True)
            # Post-processing operates on the file that was just written: interpolation first, then smoothing.
            if self.interpolation_settings is not None:
                interpolator = Interpolate(config_path=self.config_path,
                                           data_path=self.save_path,
                                           type=self.interpolation_settings['type'],
                                           method=self.interpolation_settings['method'],
                                           multi_index_df_headers=True,
                                           copy_originals=False)
                interpolator.run()
            if self.smoothing_settings is not None:
                smoother = Smoothing(config_path=self.config_path,
                                     data_path=self.save_path,
                                     time_window=self.smoothing_settings['time_window'],
                                     method=self.smoothing_settings['method'],
                                     multi_index_df_headers=True,
                                     copy_originals=False)
                smoother.run()
            video_timer.stop_timer()
            stdout_success(msg=f"Video {video_name} data imported...", elapsed_time=video_timer.elapsed_time_str)
        # `self.timer` is not created in this class; presumably it is started by ConfigReader - confirm.
        self.timer.stop_timer()
        stdout_success(msg=f"All maDLC H5 data files imported to {self.input_csv_dir} directory", elapsed_time=self.timer.elapsed_time_str)
# Example usage:
# test = MADLCImporterH5(config_path=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini',
#                        data_folder=r'/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/h5',
#                        file_type='ellipse',
#                        id_lst=['Simon', 'JJ'],
#                        interpolation_settings={'type': 'animals', 'method': 'linear'},
#                        smoothing_settings={'time_window': 500, 'method': 'gaussian'})
# test.run()