# Source code for simba.pose_importers.sleap_slp_importer

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import h5py
import numpy as np
import pandas as pd

from simba.data_processors.interpolate import Interpolate
from simba.data_processors.smoothing import Smoothing
from simba.mixins.config_reader import ConfigReader
from simba.mixins.pose_importer_mixin import PoseImporterMixin
from simba.utils.checks import (check_file_exist_and_readable,
                                check_if_dir_exists,
                                check_if_keys_exist_in_dict, check_int,
                                check_str, check_valid_lst)
from simba.utils.enums import Methods, TagNames
from simba.utils.errors import CountError
from simba.utils.printing import (SimbaTimer, log_event, stdout_information,
                                  stdout_success)
from simba.utils.read_write import (clean_sleap_file_name,
                                    find_all_videos_in_project, get_fn_ext,
                                    get_video_meta_data, write_df)

# Positional column indices for the .slp HDF5 tables. They are passed to `_field`
# as the fallback position for when a column cannot be looked up by field name.

# `frames` table: frame index, and the start / end of the frame's instance range.
FRAME_IDX_COL = 2
INST_START_COL = 3
INST_END_COL = 4

# `instances` table: track identity of the instance.
INSTANCE_TRACK_COL = 4

# `pred_points` table: x, y and score of each predicted point.
POINT_X_COL = 0
POINT_Y_COL = 1
POINT_SCORE_COL = 4


def _field(arr: np.ndarray, name: str, pos: int) -> np.ndarray:
    """Return a column of a (possibly structured) numpy array by field name, falling back to positional index."""
    names = arr.dtype.names
    if names is not None and name in names:
        return np.asarray(arr[name])
    if names is not None:
        return np.asarray(arr[names[pos]])
    return np.asarray(arr[:, pos])


class SLEAPImporterSLP(ConfigReader, PoseImporterMixin):
    """
    Class for importing SLEAP pose-estimation data (``.slp`` files) into a SimBA project.

    :param str project_path: path to SimBA project config file in Configparser format.
    :param str data_folder: Path to folder containing SLEAP data in ``.slp`` format.
    :param List[str] id_lst: Animal names. This will be ignored in one animal projects and default to ``Animal_1``.
    :param Optional[Dict[str, str]] interpolation_settings: Dict defining the type and method to use to perform interpolation e.g. {'type': 'animals', 'method': 'linear'}.
    :param Optional[Dict[str, Union[str, int]]] smoothing_settings: Dictionary defining the pose estimation smoothing method e.g. {'time_window': 500, 'method': 'gaussian'}.

    :example:
    >>> slp_importer = SLEAPImporterSLP(project_path="MyConfigPath", data_folder=r'MySLPDataFolder', id_lst=['Mouse_1', 'Mouse_2'], interpolation_settings={'type': 'animals', 'method': 'linear'}, smoothing_settings={'time_window': 200, 'method': 'savitzky-golay'})
    >>> slp_importer.run()

    References
    ----------
    .. [1] Pereira, T. D., et al. (2022). SLEAP: A deep learning system for multi-animal pose tracking. `Nature Methods, 19, 486–495 <https://doi.org/10.1038/s41592-022-01426-1>`_.
    """

    def __init__(self,
                 project_path: Union[str, os.PathLike],
                 data_folder: Union[str, os.PathLike],
                 id_lst: List[str],
                 interpolation_settings: Optional[Dict[str, str]] = None,
                 smoothing_settings: Optional[Dict[str, Any]] = None):

        ConfigReader.__init__(self, config_path=project_path, read_video_info=False)
        PoseImporterMixin.__init__(self)
        log_event(logger_name=str(__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals()))
        check_file_exist_and_readable(file_path=project_path)
        check_if_dir_exists(in_dir=data_folder, source=self.__class__.__name__)
        check_valid_lst(data=id_lst, source=f'{self.__class__.__name__} id_lst', valid_dtypes=(str,), min_len=1)
        if interpolation_settings is not None:
            check_if_keys_exist_in_dict(data=interpolation_settings, key=['method', 'type'], name=f'{self.__class__.__name__} interpolation_settings')
            check_str(name=f'{self.__class__.__name__} interpolation_settings type', value=interpolation_settings['type'], options=('body-parts', 'animals'))
            check_str(name=f'{self.__class__.__name__} interpolation_settings method', value=interpolation_settings['method'], options=('linear', 'quadratic', 'nearest'))
        if smoothing_settings is not None:
            check_if_keys_exist_in_dict(data=smoothing_settings, key=['method', 'time_window'], name=f'{self.__class__.__name__} smoothing_settings')
            check_str(name=f'{self.__class__.__name__} smoothing_settings method', value=smoothing_settings['method'], options=('savitzky-golay', 'gaussian'))
            check_int(name=f'{self.__class__.__name__} smoothing_settings time_window', value=smoothing_settings['time_window'], min_value=1)
        self.interpolation_settings, self.smoothing_settings = interpolation_settings, smoothing_settings
        self.data_folder, self.id_lst = data_folder, id_lst
        self.import_log_path = os.path.join(self.logs_path, f"data_import_log_{self.datetime}.csv")
        # Videos are only required when more than one animal name is passed.
        self.video_paths = find_all_videos_in_project(videos_dir=self.video_dir, raise_error=len(id_lst) > 1)
        self.input_data_paths = self.find_data_files(dir=self.data_folder, extensions=[".slp"])
        # NOTE(review): `is` compares object identity, not string equality. If `pose_setting` is a string
        # read from the project config at runtime, this branch is likely never taken. Left unchanged because
        # `__update_config_animal_cnt` is name-mangled and not defined in this class - confirm the intended
        # behaviour before switching to `==`.
        if self.pose_setting is Methods.USER_DEFINED.value:
            self.__update_config_animal_cnt()
        if self.animal_cnt > 1:
            self.data_and_videos_lk = self.link_video_paths_to_data_paths(data_paths=self.input_data_paths, video_paths=self.video_paths, filename_cleaning_func=clean_sleap_file_name)
            self.check_multi_animal_status()
            self.animal_bp_dict = self.create_body_part_dictionary(self.multi_animal_status, self.id_lst, self.animal_cnt, self.x_cols, self.y_cols, self.p_cols, self.clr_lst)
            # NOTE(review): same identity-vs-equality concern as above.
            if self.pose_setting is Methods.USER_DEFINED.value:
                self.update_bp_headers_file(update_bp_headers=True)
        else:
            # Single-animal projects: no video is linked, the data file name is the key.
            self.data_and_videos_lk = {get_fn_ext(file_path)[1]: {"DATA": file_path, "VIDEO": None} for file_path in self.input_data_paths}
        stdout_information(msg=f"Importing {len(self.data_and_videos_lk)} file(s)...", source=self.__class__.__name__)

    @staticmethod
    def _read_metadata(file_path: Union[str, os.PathLike]) -> dict:
        """
        Pull the JSON ``metadata`` blob (skeleton / node order) out of an .slp file.

        :param Union[str, os.PathLike] file_path: Path to the ``.slp`` (HDF5) file.
        :return: The decoded metadata as a dictionary.
        :raises KeyError: If the file has no top-level ``metadata`` object.
        """
        with h5py.File(file_path, "r") as f:
            attrs = f["metadata"].attrs
            if "json" in attrs:
                # Look the blob up by name rather than by position, so the result
                # does not depend on the order in which attributes are iterated.
                raw = attrs["json"]
            else:
                # Fallback: the second attribute, as in the earlier positional lookup.
                raw = list(attrs.items())[1][1]
        # String attributes can be returned as bytes or as str; accept both.
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        return dict(json.loads(raw))

    def _slp_to_arrays(self, file_path: Union[str, os.PathLike]) -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """
        Read an .slp file and return:

        * ``data``        : (n_instances, n_nodes * 3) float64 array of interleaved [x, y, p].
        * ``idx``         : (n_instances, 2) array of [track_id, frame_idx].
        * ``ordered_bps`` : body-part names in skeleton node order.

        :param Union[str, os.PathLike] file_path: Path to the ``.slp`` (HDF5) file.
        :raises CountError: If the number of predicted points does not equal instances x nodes, or if the
                            per-frame instance ranges do not add up to the number of instances.
        """
        meta = self._read_metadata(file_path)
        bp_names = [bp["name"] for bp in meta["nodes"]]
        ordered_ids = [n["id"] for n in meta["skeletons"][0]["nodes"]]
        ordered_bps = [bp_names[i] for i in ordered_ids]
        n_nodes = len(ordered_bps)

        with h5py.File(file_path, "r") as f:
            frames = f["frames"][:]
            instances = f["instances"][:]
            pred_points = f["pred_points"][:]

        n_instances = instances.shape[0]
        expected_points = n_instances * n_nodes
        if pred_points.shape[0] != expected_points:
            raise CountError(msg=f"SLP file {file_path}: expected {expected_points} predicted points "
                                 f"({n_instances} instances x {n_nodes} nodes) but found {pred_points.shape[0]}.",
                             source=self.__class__.__name__)

        frame_idx_vals = _field(frames, "frame_idx", FRAME_IDX_COL).astype(np.int64)
        inst_start = _field(frames, "instance_id_start", INST_START_COL).astype(np.int64)
        inst_end = _field(frames, "instance_id_end", INST_END_COL).astype(np.int64)

        # Each frame row owns the instance rows [start, end). Check that the ranges are sane and add up
        # to the instance table, so a malformed file fails here with a clear message rather than with
        # an opaque numpy error further down.
        inst_per_frame = inst_end - inst_start
        if np.any(inst_per_frame < 0) or int(inst_per_frame.sum()) != n_instances:
            raise CountError(msg=f"SLP file {file_path}: the frame table references {int(inst_per_frame.sum())} "
                                 f"instance(s) but the instance table holds {n_instances}.",
                             source=self.__class__.__name__)
        # Expand to one frame index per instance row.
        instance_frame_idx = np.repeat(frame_idx_vals, inst_per_frame)
        track_ids = _field(instances, "track", INSTANCE_TRACK_COL).astype(np.int64)

        px = _field(pred_points, "x", POINT_X_COL).astype(np.float64).reshape(n_instances, n_nodes)
        py = _field(pred_points, "y", POINT_Y_COL).astype(np.float64).reshape(n_instances, n_nodes)
        pp = _field(pred_points, "score", POINT_SCORE_COL).astype(np.float64).reshape(n_instances, n_nodes)

        # Interleave as x0, y0, p0, x1, y1, p1, ...
        data = np.empty((n_instances, n_nodes * 3), dtype=np.float64)
        data[:, 0::3] = px
        data[:, 1::3] = py
        data[:, 2::3] = pp

        idx = np.column_stack([track_ids, instance_frame_idx])
        return data, idx, ordered_bps

    def run(self):
        """
        Import every linked ``.slp`` file: read it, shape it to the project's body-part headers, write it
        to the project input directory, and optionally interpolate and smooth the saved file.

        :raises CountError: If the reshaped data does not have as many columns as the project's body-part headers.
        """
        for video_name, video_data in self.data_and_videos_lk.items():
            output_filename = clean_sleap_file_name(filename=video_name)
            stdout_information(msg=f"Importing {output_filename}...", source=self.__class__.__name__)
            video_timer = SimbaTimer(start=True)
            self.video_name = video_name
            self.save_path = os.path.join(self.input_csv_dir, f"{output_filename}.{self.file_type}")
            data, idx, _ = self._slp_to_arrays(file_path=video_data["DATA"])
            if self.animal_cnt > 1:
                self.data_df = pd.DataFrame(self.transpose_multi_animal_table(data=data, idx=idx, animal_cnt=self.animal_cnt))
            else:
                # Index rows by frame number, then zero-fill frames that have no row.
                self.data_df = pd.DataFrame(data).set_index(idx[:, 1]).sort_index()
                self.data_df = self.data_df.reindex(range(0, int(self.data_df.index[-1]) + 1), fill_value=0)
            if len(self.bp_headers) != len(self.data_df.columns):
                raise CountError(msg=f"The SimBA project expects {len(self.bp_headers)} data columns, but the reshaped "
                                     f"SLEAP data at {video_data['DATA']} produced {len(self.data_df.columns)} columns. "
                                     f"Make sure you have specified the correct number of animals and body-parts in your project.",
                                 source=self.__class__.__name__)
            self.data_df.columns = self.bp_headers
            if self.animal_cnt > 1:
                self.initialize_multi_animal_ui(animal_bp_dict=self.animal_bp_dict,
                                                video_info=get_video_meta_data(video_data["VIDEO"]),
                                                data_df=self.data_df,
                                                video_path=video_data["VIDEO"])
                self.multianimal_identification()
            else:
                self.out_df = self.insert_multi_idx_columns(df=self.data_df.fillna(0))
            # NOTE(review): in the multi-animal branch `self.out_df` is presumably set by
            # `multianimal_identification()` - confirm against PoseImporterMixin.
            write_df(df=self.out_df, file_type=self.file_type, save_path=self.save_path, multi_idx_header=True)
            if self.interpolation_settings is not None:
                Interpolate(config_path=self.config_path,
                            data_path=self.save_path,
                            type=self.interpolation_settings['type'],
                            method=self.interpolation_settings['method'],
                            multi_index_df_headers=True,
                            copy_originals=False).run()
            if self.smoothing_settings is not None:
                Smoothing(config_path=self.config_path,
                          data_path=self.save_path,
                          time_window=self.smoothing_settings['time_window'],
                          method=self.smoothing_settings['method'],
                          multi_index_df_headers=True,
                          copy_originals=False).run()
            video_timer.stop_timer()
            stdout_success(msg=f"Video {output_filename} data imported...", elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__)
        self.timer.stop_timer()
        stdout_success(msg=f"All SLEAP SLP data files imported to {self.input_csv_dir} directory", elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)