# Source code for simba.data_processors.severity_frame_based_calculator

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import os
from datetime import datetime
from typing import Dict, Union

import cv2
import numpy as np
import pandas as pd

from simba.mixins.config_reader import ConfigReader
from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
from simba.utils.checks import (check_if_filepath_list_is_empty,
                                check_that_column_exist)
from simba.utils.data import (create_color_palettes, detect_bouts, find_bins,
                              get_video_meta_data)
from simba.utils.errors import InvalidVideoFileError, NoDataError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import get_fn_ext, read_df
from simba.utils.warnings import NoDataFoundWarning


class SeverityFrameCalculator(ConfigReader, FeatureExtractionMixin):
    """
    Computes the "severity" of classification frame events based on how much the animals are moving.
    Frames are scored as less or more severe at lower and higher movements, respectively.

    :param str config_path: path to SimBA project config file in Configparser format.
    :param dict settings: how to calculate the severity. Keys read by this class:
        ``brackets`` (number of severity bins), ``bracket_type`` and ``normalization`` (forwarded to ``find_bins``),
        ``clf`` (classifier column name), ``animals`` (animal names whose body-parts contribute to movement),
        ``time`` / ``frames`` (which measurements to report), ``save_bin_definitions`` (optional, default False),
        ``visualize`` (optional, default False) and, when visualizing, ``visualize_event_cnt``
        (int or ``'ALL CLIPS'``), ``video_speed`` and ``show_pose``.

    .. note::
       `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/Scenario2.md>`__.

    Examples
    ----------
    >>> settings = {'brackets': 10, 'bracket_type': 'QUANTILE', 'clf': 'Attack', 'animals': ['Simon', 'JJ'], 'time': True, 'frames': False, 'normalization': 'ALL VIDEOS', 'save_bin_definitions': True, 'visualize': False}
    >>> processor = SeverityFrameCalculator(config_path='project_folder/project_config.ini', settings=settings)
    >>> processor.run()
    """

    def __init__(self, config_path: Union[str, os.PathLike], settings: Dict):
        ConfigReader.__init__(self, config_path=config_path)
        self.settings = settings
        check_if_filepath_list_is_empty(
            filepaths=self.machine_results_paths,
            error_msg=f"SIMBA ERROR: Cannot process severity. {self.machine_results_dir} directory is empty",
        )
        # One timestamp shared by both output files so they can be paired up afterwards.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        self.movement_animal_bp_dict = {
            k: self.animal_bp_dict[k] for k in self.settings["animals"]
        }
        self.color_lst_lst = create_color_palettes(
            no_animals=len(self.animal_bp_dict),
            map_size=len(self.bp_headers),
        )
        self.save_path = os.path.join(self.logs_path, f"severity_{timestamp}.csv")
        self.definitions_path = os.path.join(
            self.logs_path, f"severity_bin_definitions_{timestamp}.csv"
        )
        self.results = {}

    def __calculate_movements(self):
        """
        Populate ``self.movements`` with one integer array per video holding, for every frame,
        the summed frame-to-frame displacement of all body-parts of the selected animals.

        Files lacking the classifier column are skipped with a warning and get no entry.
        """
        self.movements = {}
        for file_cnt, file_path in enumerate(self.machine_results_paths):
            _, video_name, _ = get_fn_ext(file_path)
            print(
                f"Analyzing movements in {video_name} ({file_cnt+1}/{len(self.machine_results_paths)})..."
            )
            _, px_per_mm, _ = self.read_video_info(video_name=video_name)
            df = read_df(file_path=file_path, file_type=self.file_type)
            if self.settings["clf"] not in df.columns:
                NoDataFoundWarning(
                    msg=f'Skipping file {video_name} - {self.settings["clf"]} data not present in file'
                )
                continue
            video_movement = np.zeros(len(df))
            for animal_bodyparts in self.movement_animal_bp_dict.values():
                animal_df = df[animal_bodyparts["X_bps"] + animal_bodyparts["Y_bps"]]
                animal_df = self.create_shifted_df(df=animal_df)
                for bp_x, bp_y in zip(
                    animal_bodyparts["X_bps"], animal_bodyparts["Y_bps"]
                ):
                    video_movement = np.add(
                        video_movement,
                        self.euclidean_distance(
                            animal_df[bp_x].values,
                            animal_df[f"{bp_x}_shifted"].values,
                            animal_df[bp_y].values,
                            animal_df[f"{bp_y}_shifted"].values,
                            px_per_mm,
                        ),
                    )
            # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is what it aliased.
            self.movements[video_name] = video_movement.astype(int)

    def run(self):
        """
        Compute per-video movement, bin each frame into a severity bracket, and count the
        classified frames (and/or seconds) falling in each bracket. Results are written to
        ``self.save_path``; clips are rendered afterwards when ``settings['visualize']`` is truthy.

        :raises NoDataError: if no machine-results file contains the classifier column.
        """
        self.__calculate_movements()
        if not self.movements:
            raise NoDataError(
                msg=f'Cannot process severity: no file in {self.machine_results_dir} contains {self.settings["clf"]} data'
            )
        self.visualization_data = {}
        self.video_bins_info = find_bins(
            data=self.movements,
            bracket_type=self.settings["bracket_type"],
            bracket_cnt=self.settings["brackets"],
            normalization_method=self.settings["normalization"],
        )
        clf_name = self.settings["clf"]
        bracket_cnt = self.settings["brackets"]
        visualize = self.settings.get("visualize", False)
        for file_cnt, file_path in enumerate(self.machine_results_paths):
            _, video_name, _ = get_fn_ext(file_path)
            if video_name not in self.movements:
                # Already reported as skipped while computing movements.
                continue
            print(
                f"Matching brackets to movements in video {video_name} ({file_cnt+1}/{len(self.machine_results_paths)})..."
            )
            self.results[video_name] = {}
            _, _, fps = self.read_video_info(video_name=video_name)
            df = read_df(file_path=file_path, file_type=self.file_type).astype(int)
            check_that_column_exist(df=df, column_name=clf_name, file_name=file_path)
            movement = self.movements[video_name]
            # Bin edges: 0 followed by the upper bound (last column) of every bracket.
            bin_edges = np.hstack(([0], self.video_bins_info[video_name][:, -1]))
            move_df = pd.DataFrame(movement, columns=["MOVEMENT"])
            # digitize yields 1..bracket_cnt+1 for non-negative input; values at or above the
            # final edge are folded into the top bracket so BIN is always within 1..bracket_cnt.
            move_df["BIN"] = np.fmin(np.digitize(movement, bin_edges), bracket_cnt)
            df = pd.concat([df, move_df], axis=1)
            if visualize:
                self.visualization_data[video_name] = df
            clf_bins = df["BIN"][df[clf_name] == 1].astype(int).reset_index(drop=True)
            # BIN is 1-based, so grade N must be matched against the value N (not N - 1).
            for grade in range(1, bracket_cnt + 1):
                frame_cnt = int((clf_bins == grade).sum())
                if self.settings["frames"]:
                    self.results[video_name][f"Grade {grade} (frames)"] = frame_cnt
                if self.settings["time"]:
                    self.results[video_name][f"Grade {grade} (s)"] = round(
                        frame_cnt / fps, 4
                    )
        self.__save()
        if visualize:
            self.__visualize()

    def __save(self):
        """
        Write the severity results to ``self.save_path`` and, when
        ``settings['save_bin_definitions']`` is truthy, the bracket bounds to ``self.definitions_path``.
        """
        print("Saving data..")
        rows = []
        for video_name, video_data in self.results.items():
            for grade, grade_data in video_data.items():
                rows.append([video_name, grade, grade_data])
        # dtype=object keeps frame counts as ints even when mixed with float seconds in VALUE.
        out_df = pd.DataFrame(
            rows, columns=["VIDEO", "MEASUREMENT", "VALUE"], dtype=object
        )
        out_df.to_csv(self.save_path)
        self.timer.stop_timer()
        stdout_success(
            msg=f"Severity data saved at {self.save_path}",
            elapsed_time=self.timer.elapsed_time_str,
        )
        if self.settings.get("save_bin_definitions", False):
            bin_rows = []
            for video_name, video_bins in self.video_bins_info.items():
                for bin_cnt, video_bin in enumerate(video_bins):
                    bin_rows.append(
                        [video_name, bin_cnt + 1, video_bin[0], video_bin[1]]
                    )
            results = pd.DataFrame(
                bin_rows, columns=["VIDEO", "SEVERITY_BIN", ">=", "<"], dtype=object
            )
            results.to_csv(self.definitions_path)
            stdout_success(
                msg=f"Severity bracket definitions saved at {self.definitions_path}"
            )

    def __visualize(self):
        """
        Render one mp4 clip per sampled classifier bout into
        ``<project_path>/frames/output/severity_bouts``, optionally overlaying the severity bin
        and the pose-estimated body-part locations on each frame.

        :raises InvalidVideoFileError: if a frame inside a bout cannot be read from the video.
        """
        video_timer = SimbaTimer(start=True)
        save_dir = os.path.join(
            self.project_path, "frames", "output", "severity_bouts"
        )
        os.makedirs(save_dir, exist_ok=True)
        clf_name = self.settings["clf"]
        bouts_df_lst = []
        for video_name, video_data in self.visualization_data.items():
            bouts_df = detect_bouts(data_df=video_data, target_lst=[clf_name], fps=-1)
            bouts_df.insert(0, "VIDEO", video_name)
            bouts_df_lst.append(bouts_df)
        if not bouts_df_lst:
            # pd.concat raises on an empty list; nothing to render in that case.
            NoDataFoundWarning(
                msg=f"No {clf_name} data available to create severity visualizations"
            )
            return
        bouts_df = pd.concat(bouts_df_lst, axis=0)

        # Resolve the clip count locally so the caller's settings dict is left untouched.
        requested_cnt = self.settings["visualize_event_cnt"]
        if requested_cnt == "ALL CLIPS":
            event_cnt = len(bouts_df)
        elif requested_cnt > len(bouts_df):
            event_cnt = len(bouts_df)
            print(
                f"User specified {requested_cnt} visualization but only {len(bouts_df)} bouts where found. Creating {len(bouts_df)} videos..."
            )
        else:
            event_cnt = requested_cnt
        bouts_df = bouts_df.sample(n=event_cnt).reset_index(drop=True)

        self.space_scale, self.radius_scale, self.res_scale, self.font_scale = (
            60,
            12,
            1500,
            1.1,
        )
        show_pose = self.settings["show_pose"]
        for idx, r in bouts_df.iterrows():
            video_name = r["VIDEO"]
            start_frm, end_frame = int(r["Start_frame"]), int(r["End_frame"])
            bout_df = self.visualization_data[video_name].iloc[start_frm:end_frame]
            clip_path = os.path.join(
                save_dir,
                f"{video_name}_{clf_name}_{start_frm}_{end_frame}.mp4",
            )
            video_path = self.find_video_of_file(
                video_dir=self.video_dir, filename=video_name
            )
            video_meta_data = get_video_meta_data(video_path)
            # Scale marker radius and font with the video resolution.
            self.max_dim = max(video_meta_data["width"], video_meta_data["height"])
            self.circle_scale = int(self.radius_scale / (self.res_scale / self.max_dim))
            self.font_size = float(self.font_scale / (self.res_scale / self.max_dim))
            # The writer needs an fps of at least 1 even at very low playback speeds.
            video_fps = max(
                1, int(video_meta_data["fps"] * self.settings["video_speed"])
            )
            event_frm_count = end_frame - start_frm
            cap = cv2.VideoCapture(video_path)
            writer = cv2.VideoWriter(
                clip_path,
                cv2.VideoWriter_fourcc(*"mp4v"),
                video_fps,
                (int(video_meta_data["width"]), int(video_meta_data["height"])),
            )
            try:
                cap.set(cv2.CAP_PROP_POS_FRAMES, start_frm)
                for frm_cnt, current_frm in enumerate(
                    range(start_frm, end_frame), start=1
                ):
                    ret, img = cap.read()
                    # Validate the read before drawing: a failed read returns img=None.
                    if not ret:
                        raise InvalidVideoFileError(
                            msg=f'Could not find frame {current_frm} in video {video_path}. Video {video_path} contains {video_meta_data["frame_count"]} frames.'
                        )
                    if show_pose:
                        frm_pose = bout_df.loc[current_frm]
                        cv2.putText(
                            img,
                            f'Severity bin: {frm_pose["BIN"]}',
                            (10, 50),
                            cv2.FONT_HERSHEY_COMPLEX,
                            self.font_size,
                            (255, 0, 0),
                            2,
                        )
                        for animal_cnt, animal_body_parts in enumerate(
                            self.animal_bp_dict.values()
                        ):
                            for bp_cnt, (x_name, y_name) in enumerate(
                                zip(
                                    animal_body_parts["X_bps"],
                                    animal_body_parts["Y_bps"],
                                )
                            ):
                                # Plain Python ints: OpenCV may reject NumPy scalars as point coordinates.
                                center = (int(frm_pose[x_name]), int(frm_pose[y_name]))
                                cv2.circle(
                                    img,
                                    center,
                                    0,
                                    self.color_lst_lst[animal_cnt][bp_cnt],
                                    self.circle_scale,
                                )
                    writer.write(img)
                    print(
                        f"Frame {str(frm_cnt)} / {str(event_frm_count)}, Event {idx + 1}/{str(len(bouts_df))}, Video {video_name}"
                    )
            finally:
                # Always free the capture and finalize the clip, including on read errors.
                cap.release()
                writer.release()
        video_timer.stop_timer()
        stdout_success(
            msg=f"Videos complete: saved in {save_dir}",
            elapsed_time=video_timer.elapsed_time_str,
        )
# settings = {'brackets': 10,
#             'clf': 'Attack',
#             'animals': ['Simon', 'JJ'],
#             'normalization': 'ALL VIDEOS', #BY VIDEO
#             'bracket_type': "QUANTILE",
#             'time': True,
#             'frames': True,
#             'save_bin_definitions': True,
#             'visualize': True,
#             'visualize_event_cnt': 'ALL CLIPS',
#             'video_speed': 0.1,
#             'show_pose': True}
# processor = SeverityFrameCalculator(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', settings=settings)
# processor.run()