Source code for simba.mixins.pose_importer_mixin

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import itertools
import os
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from typing import Dict, List, Optional, Union

import cv2
import h5py
import numpy as np
import pandas as pd
import scipy.io as sio
from numba import jit, prange

from simba.mixins.plotting_mixin import PlottingMixin
from simba.utils.enums import ConfigKey
from simba.utils.errors import (CountError, IntegerError, InvalidInputError,
                                NoDataError, NoFilesFoundError)
from simba.utils.read_write import get_fn_ext
from simba.utils.warnings import FrameRangeWarning


[docs]class PoseImporterMixin(object): """ Methods for importing pose-estimation data. """ def __init__(self): self.datetime = datetime.now().strftime("%Y%m%d%H%M%S") pass
[docs] def initialize_multi_animal_ui(self, animal_bp_dict: dict, video_info: dict, data_df: pd.DataFrame, video_path: Union[str, os.PathLike], initial_frame_no: Optional[int] = None): """ :param Optional[int] initial_frame_no: Frame index to start the assignment UI on. Useful for jumping directly to a frame where all animals are simultaneously detected, so the user doesn't have to press "x" repeatedly to find one. If ``None`` (default) the UI starts at frame 0, preserving previous behaviour. Clamped to the valid range ``[0, len(data_df) - 1]``; out-of-range values emit a :class:`FrameRangeWarning` and fall back to 0. """ start_frame = 0 if initial_frame_no is not None: if not isinstance(initial_frame_no, int) or initial_frame_no < 0 or initial_frame_no >= len(data_df): FrameRangeWarning( msg=(f"initial_frame_no={initial_frame_no} is outside the valid range " f"[0, {len(data_df) - 1}] for this video; starting at frame 0 instead."), source=self.__class__.__name__, ) else: start_frame = initial_frame_no self.video_info, self.data_df, self.frame_no, self.add_spacer = (video_info, data_df, start_frame, 2) self.animal_bp_dict, self.cap = animal_bp_dict, cv2.VideoCapture(video_path) _, self.video_name, _ = get_fn_ext(video_path) self.get_video_scalers(video_info=video_info)
def get_video_scalers(self, video_info: dict): self.scalers = {} w, h = int(video_info["width"]), int(video_info["height"]) self.scalers["circle"] = PlottingMixin().get_optimal_circle_size(frame_size=(w, h), circle_frame_ratio=100) self.scalers["font"], _, self.scalers["space"] = PlottingMixin().get_optimal_font_scales(text="Press 'c' to continue to start assigning identities using this frame", accepted_px_width=w, accepted_px_height=int(h/6))
[docs] def find_data_files(self, dir: Union[str, os.PathLike], extensions: List[str]) -> List[str]: """ Search for files with specific extensions in a given directory and return their paths. :param Union[str, os.PathLike] dir: The directory to search for files. It can be a string or a path-like object. :param List[str] extensions: A list of file extensions to look for. :return: A list of paths to the files with the specified extensions found in the directory. :raises NoDataError: If no files with the specified extensions are found in the directory. """ data_paths = [] paths = [f for f in next(os.walk(dir))[2] if not f[0] == "."] paths = [os.path.join(dir, f) for f in paths] for extension in extensions: for path in paths: if path.endswith(extension): data_paths.append(path) if len(data_paths) == 0: raise NoDataError(msg=f"No files with {extensions} extensions found in {dir}.", source=self.__class__.__name__) return data_paths
def get_x_y_loc_of_mouse_click(self, event, x, y, flags, param): if event == cv2.EVENT_LBUTTONDOWN: self.click_loc = (x, y) self.id_cords[self.cnt] = {} self.id_cords[self.cnt]["cord"] = self.click_loc self.id_cords[self.cnt]["name"] = self.animal_name def insert_all_bodyparts_into_img(self, img: np.ndarray, body_parts: dict): for animal, bp_data in body_parts.items(): for bp_cnt, bp_tuple in enumerate(bp_data): try: cv2.circle(img, bp_tuple, self.scalers["circle"], self.animal_bp_dict[animal]["colors"][bp_cnt], -1, lineType=cv2.LINE_AA) except Exception as err: if type(err) == OverflowError: raise IntegerError( f"SimBA encountered a pose-estimated body-part located at pixel position {str(bp_tuple)}. " "This value is too large to be converted to an integer. " "Please check your pose-estimation data to make sure that it is accurate.", source=self.__class__.__name__, ) def insert_animal_names(self): for animal_cnt, animal_data in self.id_cords.items(): cv2.putText(self.new_frame, animal_data["name"], animal_data["cord"], cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 255), 2) def multianimal_identification(self): cv2.destroyAllWindows() self.cap.set(1, self.frame_no) self.all_frame_data = self.data_df.loc[self.frame_no, :] cv2.namedWindow("Define animal IDs", cv2.WINDOW_NORMAL) _, self.img = self.cap.read() self.img_overlay, self.img_bp_cords = deepcopy(self.img), defaultdict(list) for animal_cnt, (animal_name, animal_bps) in enumerate(self.animal_bp_dict.items()): self.img_bp_cords[animal_name] = [] for x_name, y_name in zip(animal_bps["X_bps"], animal_bps["Y_bps"]): self.img_bp_cords[animal_name].append(tuple(self.data_df.loc[self.frame_no, [x_name, y_name]].values.astype(np.int32))) self.insert_all_bodyparts_into_img(img=self.img_overlay, body_parts=self.img_bp_cords) side_img = np.ones((int(self.video_info["height"] / 2), self.video_info["width"], 3)) cv2.putText(side_img, f"Current video: {self.video_name}", (10, self.scalers["space"]), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 255), 2) cv2.putText(side_img, "Can you assign identities based on the displayed frame ?", (10, int(self.scalers["space"] * (self.add_spacer * 2))), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 255), 2) cv2.putText(side_img, 'Press "x" to display new, random, frame', (10, int(self.scalers["space"] * (self.add_spacer * 3))), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 0), 2) cv2.putText(side_img, 'Press "c" to continue to start assigning identities using this frame', (10, int(self.scalers["space"] * (self.add_spacer * 4))), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (0, 255, 0), 2) self.img_concat = np.uint8(np.concatenate((self.img_overlay, side_img), axis=0)) cv2.imshow("Define animal IDs", self.img_concat) cv2.resizeWindow("Define animal IDs", self.video_info["height"], self.video_info["width"]) keyboard_choice = False while not keyboard_choice: k = cv2.waitKey(20) if (k == ord("x")) or (k == ord("X")): if self.frame_no + 50 > len(self.data_df): FrameRangeWarning(msg=f"Cannot proceed to new frame: the frame {self.frame_no+50} does not exist in video {self.video_name}. The video {self.video_name} has {len(self.data_df)} frames.") else: cv2.destroyWindow("Define animal IDs") cv2.waitKey(0) self.frame_no += 50 self.multianimal_identification() break elif (k == ord("c")) or (k == ord("C")): cv2.destroyWindow("Define animal IDs") cv2.waitKey(0) self.choose_animal_ui() break def choose_animal_ui(self): self.id_cords = {} for cnt, animal in enumerate(self.animal_bp_dict.keys()): self.animal_name, self.cnt = animal, cnt self.new_overlay = deepcopy(self.img_overlay) cv2.namedWindow("Define animal IDs", cv2.WINDOW_NORMAL) self.side_img = np.ones((int(self.video_info["height"] / 2), self.video_info["width"], 3)) cv2.putText(self.side_img, "Left mouse click on:", (10, self.scalers["space"]), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 255), 3) cv2.putText(self.side_img, animal, (10, int(self.scalers["space"] * (self.add_spacer * 2))), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 0), 3) for id in self.id_cords.keys(): cv2.putText(self.new_overlay, self.id_cords[id]["name"], self.id_cords[id]["cord"], cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 255), 3) self.new_overlay = np.uint8(np.concatenate((self.new_overlay, self.side_img), axis=0)) cv2.imshow("Define animal IDs", self.new_overlay) cv2.resizeWindow("Define animal IDs", self.video_info["height"], self.video_info["width"]) while cnt not in self.id_cords.keys(): cv2.setMouseCallback("Define animal IDs", self.get_x_y_loc_of_mouse_click) cv2.waitKey(200) self.confirm_ui() def confirm_ui(self): cv2.destroyAllWindows() cv2.namedWindow("Define animal IDs", cv2.WINDOW_NORMAL) cv2.resizeWindow("Define animal IDs", self.video_info["height"], self.video_info["width"]) self.new_frame = deepcopy(self.img) self.side_img = np.ones((int(self.video_info["height"] / 2), self.video_info["width"], 3)) cv2.putText(self.side_img, f"Current video: {self.video_name}", (10, self.scalers["space"]), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 255), 3) cv2.putText(self.side_img, "Are you happy with your assigned identities ?", (10, int(self.scalers["space"] * (self.add_spacer * 2))), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 255), 2) cv2.putText(self.side_img, 'Press "c" to continue (to finish, or proceed to the next video)', (10, int(self.scalers["space"] * (self.add_spacer * 3))), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (255, 255, 0), 2) cv2.putText(self.side_img, 'Press "x" to re-start assigning identities', (10, int(self.scalers["space"] * (self.add_spacer * 4))), cv2.FONT_HERSHEY_SIMPLEX, self.scalers["font"], (0, 255, 255), 2,) self.insert_all_bodyparts_into_img(img=self.new_frame, body_parts=self.img_bp_cords) self.insert_animal_names() self.img_concat = np.uint8(np.concatenate((self.new_frame, self.side_img), axis=0)) cv2.imshow("Define animal IDs", self.img_concat) cv2.resizeWindow("Define animal IDs", self.video_info["height"], self.video_info["width"]) keyboard_choice = False while not keyboard_choice: k = cv2.waitKey(20) if (k == ord("x")) or (k == ord("X")): if self.frame_no + 50 > len(self.data_df): FrameRangeWarning(msg=f"Cannot proceed to new frame: the frame {self.frame_no+50} does not exist in video {self.video_name}. The video {self.video_name} has {len(self.data_df)} frames.") else: cv2.destroyWindow("Define animal IDs") cv2.waitKey(1) self.frame_no += 50 self.multianimal_identification() break elif (k == ord("c")) or (k == ord("C")): cv2.destroyAllWindows() cv2.waitKey(1) self.cap.release() self.find_closest_animals() break def find_closest_animals(self): self.animal_order = {} for animal_number, animal_click_data in self.id_cords.items(): animal_name, animal_cord = (animal_click_data["name"], animal_click_data["cord"]) closest_animal = {"animal_name": None, "body_part_name": None, "distance": np.inf} for other_animal_name, animal_bps in self.animal_bp_dict.items(): animal_bp_names_x = self.animal_bp_dict[other_animal_name]["X_bps"] animal_bp_names_y = self.animal_bp_dict[other_animal_name]["Y_bps"] for x_col, y_col in zip(animal_bp_names_x, animal_bp_names_y): bp_x, bp_y = self.all_frame_data[x_col], self.all_frame_data[y_col] distance = np.sqrt((animal_cord[0] - bp_x) ** 2 + (animal_cord[1] - bp_y) ** 2) if distance < closest_animal["distance"]: closest_animal["animal_name"] = other_animal_name closest_animal["body_part_name"] = (x_col, y_col) closest_animal["distance"] = distance if closest_animal["animal_name"] is None: raise InvalidInputError( msg=(f'No tracked body-parts found on frame {self.frame_no} for click number ' f'{animal_number} ({animal_name}): every body-part of every animal is NaN at ' f'this frame. Press "x" to step to a frame where at least one body-part per ' f'animal is tracked, then assign identities.'), source=self.__class__.__name__, ) self.animal_order[animal_number] = closest_animal self.check_intergity_of_chosen_animal_order() self.organize_results() self.reinsert_multi_idx_columns() def organize_results(self): self.out_df = pd.DataFrame() for animal_cnt, animal_data in self.animal_order.items(): closest_animal_dict = self.animal_bp_dict[animal_data["animal_name"]] x_cols, y_cols, p_cols = ( closest_animal_dict["X_bps"], closest_animal_dict["Y_bps"], closest_animal_dict["P_bps"], ) for x_col, y_col, p_cols in zip(x_cols, y_cols, p_cols): df = self.data_df[[x_col, y_col, p_cols]] self.out_df = pd.concat([self.out_df, df], axis=1) def reinsert_multi_idx_columns(self): multi_idx_cols = [] for col_idx in range(len(self.out_df.columns)): multi_idx_cols.append( tuple(("IMPORTED_POSE", "IMPORTED_POSE", self.out_df.columns[col_idx])) ) self.out_df.columns = pd.MultiIndex.from_tuples( multi_idx_cols, names=("scorer", "bodypart", "coords") ) def insert_multi_idx_columns(self, df: pd.DataFrame) -> pd.DataFrame: multi_idx_cols = [] for col_idx in range(len(df.columns)): multi_idx_cols.append(tuple(("IMPORTED_POSE", "IMPORTED_POSE", df.columns[col_idx]))) df.columns = pd.MultiIndex.from_tuples(multi_idx_cols, names=("scorer", "bodypart", "coords")) return df def check_intergity_of_chosen_animal_order(self): for click_key_combination in itertools.combinations(list(self.animal_order.keys()), 2): click_n, click_n1 = click_key_combination[0], click_key_combination[1] animal_1, animal_2 = (self.animal_order[click_n]["animal_name"], self.animal_order[click_n1]["animal_name"]) if animal_1 == animal_2: raise InvalidInputError(msg=f"The animal most proximal to click number {str(click_n)} is animal named {animal_1}. The animal most proximal to click number {str(click_n1)} is also animal {animal_2}. Please indicate which animal is which using a video frame where the animals are clearly separated", source=self.__class__.__name__,) def intertwine_probability_cols(self, data: pd.DataFrame) -> pd.DataFrame: results = pd.DataFrame() for animal_name, animal_bps in self.animal_bp_dict.items(): x_cols, y_cols, p_cols = ( animal_bps["X_bps"], animal_bps["Y_bps"], animal_bps["P_bps"], ) for x_col, y_col, p_col in zip(x_cols, y_cols, p_cols): df = data[[x_col, y_col, p_col]] results = pd.concat([results, df], axis=1) return results def __update_config_animal_cnt(self): self.config.set( ConfigKey.GENERAL_SETTINGS.value, ConfigKey.ANIMAL_CNT.value, str(self.animal_cnt), ) with open(self.project_path, "w+") as f: self.config.write(f) f.close() def update_bp_headers_file(self, update_bp_headers: Optional[bool] = False): new_headers = [] for animal_name in self.animal_bp_dict.keys(): for bp in self.animal_bp_dict[animal_name]["X_bps"]: if animal_name not in bp: new_headers.append("{}_{}".format(animal_name, bp[:-2])) else: new_headers.append(bp[:-2]) new_bp_df = pd.DataFrame(new_headers) new_bp_df.to_csv(self.body_parts_path, index=False, header=False) if update_bp_headers: self.bp_headers = new_headers def update_config_animal_names(self): self.config.set( ConfigKey.MULTI_ANIMAL_ID_SETTING.value, ConfigKey.MULTI_ANIMAL_IDS.value, ",".join(str(x) for x in self.id_lst), ) with open(self.project_path, "w+") as f: self.config.write(f) f.close() @staticmethod @jit(nopython=True) def transpose_multi_animal_table(data: np.array, idx: np.array, animal_cnt: int) -> np.array: unique_tracks = np.unique(idx[:, 0]).flatten() unique_tracks = np.sort(unique_tracks) unique_tracks = unique_tracks[0:animal_cnt] results = np.full((np.max(idx[:, 1]+1), data.shape[1] * animal_cnt), 0.0) for i in prange(np.max(idx[:, 1])+1): for j in prange(animal_cnt): data_idx = np.argwhere((idx[:, 0] == unique_tracks[j]) & (idx[:, 1] == i)).flatten() if len(data_idx) == 1: animal_frm_data = data[data_idx[0]] else: animal_frm_data = np.full((data.shape[1]), 0.0) results[i][j * animal_frm_data.shape[0] : j * animal_frm_data.shape[0] + animal_frm_data.shape[0]] = animal_frm_data return results def read_apt_trk_file(self, file_path: str): print("Reading data using scipy.io...") try: trk_coordinates = sio.loadmat(file_path)["pTrk"] track_cnt = trk_coordinates.shape[3] data = [trk_coordinates[..., i] for i in range(track_cnt)] except NotImplementedError: print("Failed to read data using scipy.io. Reading data using h5py...") with h5py.File(file_path, "r") as trk_dict: trk_list = list(trk_dict["pTrk"]) t_second = np.array(trk_list) if len(t_second.shape) > 3: t_third = np.swapaxes(t_second, 0, 3) trk_coordinates = np.swapaxes(t_third, 1, 2) track_cnt = trk_coordinates.shape[3] data = [trk_coordinates[..., i] for i in range(track_cnt)] else: data = np.swapaxes(t_second, 0, 2) track_cnt = 1 if track_cnt != self.animal_cnt: raise CountError( msg=f"There are {str(track_cnt)} tracks in the .trk file {file_path}. But your SimBA project expects {str(self.animal_cnt)} tracks.", source=self.__class__.__name__, ) if self.animal_cnt != 1: animal_df_lst = [] for animal in data: m, n, r = animal.shape out_arr = np.column_stack((np.repeat(np.arange(m), n), animal.reshape(m * n, -1))) animal_df_lst.append(pd.DataFrame(out_arr).T.iloc[1:].reset_index(drop=True)) self.data_df = pd.concat(animal_df_lst, axis=1).fillna(0).astype(np.int) else: m, n, r = data.shape out_arr = np.column_stack((np.repeat(np.arange(m), n), data.reshape(m * n, -1))) self.data_df = (pd.DataFrame(out_arr).T.iloc[1:].reset_index(drop=True).astype(np.int)) p_cols = pd.DataFrame(1, index=self.data_df.index, columns=self.data_df.columns[1::2] + 0.5) return pd.concat([self.data_df, p_cols], axis=1) # .sort_index(axis=1)