# Source code for simba.pose_processors.remove_keypoints

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import glob
import os
import warnings
from datetime import datetime

import pandas as pd
from tables import NaturalNameWarning

from simba.utils.checks import check_if_filepath_list_is_empty
from simba.utils.errors import (BodypartColumnNotFoundError,
                                InvalidFileTypeError, NotDirectoryError)
from simba.utils.printing import SimbaTimer, stdout_success

warnings.filterwarnings("ignore", category=NaturalNameWarning)


class KeypointRemover(object):
    """
    Remove pose-estimated keypoints from data in CSV or H5 format.

    :param str data_folder: Path to directory containing pose-estimation CSV or H5 data.
    :param str pose_tool: Tool used to perform pose-estimation. E.g., `DLC` or `maDLC`.
    :param str file_format: File type (extension without the dot) of the pose-estimation data, e.g., `csv` or `h5`.

    .. note::
       `GitHub tutorial/documentation <https://github.com/sgoldenlab/simba/blob/master/docs/Tutorial_tools.md#remove-body-parts-from-tracking-data>`__.

    Examples
    ----------
    >>> keypoint_remover = KeypointRemover(data_folder="MyDataFolder", pose_tool='maDLC', file_format='h5')
    >>> keypoint_remover.run(animal_names=['Animal_1', 'Animal_2'], bp_to_remove_list=['Nose_1', 'Nose_2'])
    """

    def __init__(self, data_folder: str, pose_tool: str, file_format: str):
        """
        Locate the data files and read the header of the first one.

        After construction ``self.body_part_names`` and ``self.animal_names``
        hold the unique names found in the first file's column header
        (``animal_names`` stays empty when ``pose_tool`` is ``DLC``).

        :raises NotDirectoryError: If ``data_folder`` is not an existing directory.
        """
        if not os.path.isdir(data_folder):
            raise NotDirectoryError(
                msg="{} is not a valid directory.".format(str(data_folder))
            )
        self.files_found = glob.glob(os.path.join(data_folder, "*." + file_format))
        # Presumably raises when no files matched; see simba.utils.checks.
        check_if_filepath_list_is_empty(
            filepaths=self.files_found,
            error_msg="SIMBA ERROR: Zero files found of type {} in the {} directory".format(
                file_format, data_folder
            ),
        )
        # Timestamp used to name the output directory created by run().
        self.datetime = str(datetime.now().strftime("%Y%m%d%H%M%S"))
        self.pose_tool, self.data_folder = pose_tool, data_folder
        self.file_format = file_format
        if file_format == "h5":
            first_df = pd.read_hdf(self.files_found[0])
        else:
            first_df = pd.read_csv(self.files_found[0], header=[0, 1, 2])
        # The first column is skipped, as in the original implementation
        # (for CSV input read without index_col it is the index column).
        header_list = list(first_df.columns)[1:]
        self.body_part_names, self.animal_names = self._collect_header_names(
            header_list=header_list, pose_tool=pose_tool
        )

    @staticmethod
    def _collect_header_names(header_list, pose_tool):
        """
        Return ``(body_part_names, animal_names)`` found in the header tuples.

        For ``DLC`` the body part is read from position 1 of every header
        entry. For any other tool the animal name is read from position 1 and
        the body part from position 2. Names are de-duplicated while keeping
        their first-seen order so the result is deterministic.
        """
        body_parts, animals = {}, {}
        if pose_tool == "DLC":
            for header_entry in header_list:
                body_parts[header_entry[1]] = None
        else:
            for header_entry in header_list:
                animals[header_entry[1]] = None
                body_parts[header_entry[2]] = None
        return list(body_parts), list(animals)

    @staticmethod
    def _drop_csv_body_parts(df, bp_to_remove_list, file_path):
        """
        Return ``df`` without the columns whose level-1 label is a listed body part.

        :raises BodypartColumnNotFoundError: If a body part is not among the level-1 column labels.
        """
        for body_part in bp_to_remove_list:
            # Check the labels actually in use: ``levels`` may still list a
            # body part whose columns were already dropped.
            if body_part not in df.columns.get_level_values(1):
                raise BodypartColumnNotFoundError(
                    msg=f"{body_part} key point is not present in file {file_path}"
                )
            df = df.drop(body_part, axis=1, level=1)
        return df

    @staticmethod
    def _drop_h5_body_parts(df, animal_names, bp_to_remove_list, file_path):
        """
        Return ``df`` without the x, y and likelihood columns of each (animal, body part) pair.

        Pairs are formed position-wise from ``animal_names`` and
        ``bp_to_remove_list``; columns are addressed as
        ``(first level-0 value, animal, body part, coordinate)``.

        :raises InvalidFileTypeError: If the first level-0 column value cannot be read.
        :raises BodypartColumnNotFoundError: If one of the columns to drop does not exist.
        """
        try:
            first_header_value = df.columns.levels[0].values[0]
        except (AttributeError, IndexError) as err:
            # AttributeError: columns are not a MultiIndex; IndexError: no levels/values.
            raise InvalidFileTypeError(
                msg=f"{file_path} is not a valid maDLC pose-estimation file"
            ) from err
        for body_part, animal_name in zip(bp_to_remove_list, animal_names):
            for cord in ("x", "y", "likelihood"):
                try:
                    df = df.drop(
                        (first_header_value, animal_name, body_part, cord),
                        axis=1,
                    )
                except (KeyError, ValueError, TypeError) as err:
                    raise BodypartColumnNotFoundError(
                        msg=f"Could not find body part {body_part} in {file_path}"
                    ) from err
        return df

    def run(self, animal_names: list, bp_to_remove_list: list) -> None:
        """
        Remove the chosen body parts from every file and save the results.

        Output files keep their original file names and are written to a new
        ``Reorganized_bp_<timestamp>`` sub-directory of ``data_folder``. Files
        are only processed when ``pose_tool`` is ``DLC`` or ``maDLC`` and
        ``file_format`` is ``csv`` or ``h5``.

        :param list animal_names: Animal owning each body part. Only used for ``h5`` files, where it is paired position-wise with ``bp_to_remove_list``.
        :param list bp_to_remove_list: Names of the body parts to remove.
        :raises BodypartColumnNotFoundError: If a body part is missing from a file.
        :raises InvalidFileTypeError: If an ``h5`` file lacks the expected column header.
        """
        self.timer = SimbaTimer()
        self.timer.start_timer()
        save_directory = os.path.join(
            self.data_folder, "Reorganized_bp_{}".format(self.datetime)
        )
        os.makedirs(save_directory, exist_ok=True)
        print(
            "Saving {} new pose-estimation files in {} directory...".format(
                str(len(self.files_found)), save_directory
            )
        )
        if self.file_format == "h5" and len(animal_names) != len(bp_to_remove_list):
            # zip() pairs position-wise and stops at the shorter list; surface
            # that instead of silently removing fewer body parts than requested.
            warnings.warn(
                "animal_names has {} entries but bp_to_remove_list has {}; "
                "only the first {} (animal, body-part) pairs will be removed.".format(
                    len(animal_names),
                    len(bp_to_remove_list),
                    min(len(animal_names), len(bp_to_remove_list)),
                )
            )
        if self.pose_tool in ("DLC", "maDLC"):
            for file_cnt, file_path in enumerate(self.files_found):
                save_path = os.path.join(save_directory, os.path.basename(file_path))
                if self.file_format == "csv":
                    self.df = pd.read_csv(file_path, header=[0, 1, 2], index_col=0)
                    self.df = self._drop_csv_body_parts(
                        self.df, bp_to_remove_list, file_path
                    )
                    self.df.to_csv(save_path)
                elif self.file_format == "h5":
                    self.df = pd.read_hdf(file_path)
                    self.df = self._drop_h5_body_parts(
                        self.df, animal_names, bp_to_remove_list, file_path
                    )
                    self.df.to_hdf(
                        save_path, key="re-organized", format="table", mode="w"
                    )
                print(
                    "Saved {}, Video {}/{}.".format(
                        os.path.basename(file_path),
                        str(file_cnt + 1),
                        str(len(self.files_found)),
                    )
                )
        self.timer.stop_timer()
        stdout_success(
            msg=f"{str(len(self.files_found))} new data with {str(len(bp_to_remove_list))} body-parts removed saved in {save_directory} directory",
            elapsed_time=self.timer.elapsed_time_str,
        )