Source code for simba.data_processors.severity_bout_based_calculator

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import os
from datetime import datetime
from typing import Dict, List, Union

import cv2
import numpy as np
import pandas as pd

from simba.mixins.config_reader import ConfigReader
from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
from simba.utils.checks import (check_if_filepath_list_is_empty,
                                check_that_column_exist)
from simba.utils.data import create_color_palettes, detect_bouts, find_bins
from simba.utils.errors import InvalidVideoFileError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import get_fn_ext, get_video_meta_data, read_df
from simba.utils.warnings import NoDataFoundWarning


class SeverityBoutCalculator(ConfigReader, FeatureExtractionMixin):
    """
    Computes the "severity" of classification bout events based on how much the animals are moving
    within the bout. Bouts are scored as less or more severe at lower and higher movements, respectively.

    :param str config_path: path to SimBA project config file in Configparser format.
    :param dict settings: how to calculate the severity. The keys read by this class are ``brackets``,
        ``bracket_type``, ``clf``, ``animals``, ``normalization``, ``save_bin_definitions``, ``visualize``,
        ``visualize_event_cnt`` (an integer or ``'ALL CLIPS'``), ``video_speed`` and ``show_pose``.

    .. note::
       `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/Scenario2.md>`__.

    Examples
    ----------
    >>> settings = {'brackets': 10, 'bracket_type': 'QUANTIZE', 'clf': 'Attack', 'animals': ['Simon', 'JJ'], 'normalization': 'ALL VIDEOS', 'save_bin_definitions': True, 'visualize': True, 'visualize_event_cnt': 'ALL CLIPS', 'video_speed': 1.0, 'show_pose': True}
    >>> processor = SeverityBoutCalculator(config_path='project_folder/project_config.ini', settings=settings)
    >>> processor.run()  # run() also writes the results through save()
    """

    def __init__(self, config_path: Union[str, os.PathLike], settings: Dict):
        ConfigReader.__init__(self, config_path=config_path)
        self.settings = settings
        # Only the animals listed in the settings contribute to the movement score.
        self.movement_animal_bp_dict = {
            k: self.animal_bp_dict[k] for k in self.settings["animals"]
        }
        self.color_lst_lst = create_color_palettes(
            no_animals=len(self.animal_bp_dict),
            map_size=len(self.bp_headers),
        )
        check_if_filepath_list_is_empty(
            filepaths=self.machine_results_paths,
            error_msg=f"SIMBA ERROR: Cannot process severity. {self.machine_results_dir} directory is empty",
        )
        # One timestamp for both outputs so the two files of a run always share a suffix.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        self.save_path = os.path.join(self.logs_path, f"severity_{timestamp}.csv")
        self.definitions_path = os.path.join(
            self.logs_path, f"severity_bin_definitions_{timestamp}.csv"
        )
        self.fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    def __calculate_movements(self, data_paths: List[str]):
        """
        Compute a per-frame movement array for every data file.

        For each file, the frame-to-frame distance of every body-part of the selected animals is
        summed per frame. Files lacking the classifier column are skipped with a warning.

        :param List[str] data_paths: paths to the data files to analyze.
        :return: dict mapping video name to its per-frame summed movement array.
        """
        movements = {}
        for file_cnt, file_path in enumerate(data_paths):
            _, video_name, _ = get_fn_ext(file_path)
            print(
                f"Analyzing movements in {video_name} ({file_cnt+1}/{len(data_paths)})..."
            )
            _, px_per_mm, _ = self.read_video_info(video_name=video_name)
            df = read_df(file_path=file_path, file_type=self.file_type)
            if self.settings["clf"] not in df.columns:
                NoDataFoundWarning(
                    msg=f'Skipping file {video_name} - {self.settings["clf"]} data not present in file'
                )
                continue
            video_movement = np.zeros(len(df))
            for animal_bodyparts in self.movement_animal_bp_dict.values():
                animal_df = df[animal_bodyparts["X_bps"] + animal_bodyparts["Y_bps"]]
                # create_shifted_df supplies the "<column>_shifted" columns used below.
                animal_df = self.create_shifted_df(df=animal_df)
                for bp_x, bp_y in zip(
                    animal_bodyparts["X_bps"], animal_bodyparts["Y_bps"]
                ):
                    video_movement = np.add(
                        video_movement,
                        self.euclidean_distance(
                            animal_df[bp_x].values,
                            animal_df[f"{bp_x}_shifted"].values,
                            animal_df[bp_y].values,
                            animal_df[f"{bp_y}_shifted"].values,
                            px_per_mm,
                        ),
                    )
            movements[video_name] = video_movement
        return movements

    @staticmethod
    def _severity_from_bins(movement, bins, bracket_cnt):
        """
        Map a bout's mean movement onto a 1-based severity bin.

        Values at or above the upper edge of the last bin get ``bracket_cnt``. Otherwise the first
        bin with ``lower <= movement < upper`` wins, matching the ">=" and "<" headers written by
        :meth:`save_bin_definitions`. ``np.nan`` is returned when no bin matches (e.g., a NaN movement).
        """
        if movement >= bins[-1][1]:
            return bracket_cnt
        for bin_idx, bin_edges in enumerate(bins):
            if bin_edges[0] <= movement < bin_edges[1]:
                return bin_idx + 1
        return np.nan

    def run(self):
        """
        Score every classifier bout in every data file and store the outcome in ``self.results``.

        Calls :meth:`save` when done and, depending on the settings, :meth:`save_bin_definitions`
        and :meth:`visualize`. Videos for which no movement could be computed (classifier column
        absent, already warned about) are skipped.
        """
        movements = self.__calculate_movements(data_paths=self.machine_results_paths)
        video_bins_info = find_bins(
            data=movements,
            bracket_type=self.settings["bracket_type"],
            bracket_cnt=self.settings["brackets"],
            normalization_method=self.settings["normalization"],
        )
        result_columns = [
            "VIDEO",
            "EVENT",
            "START TIME",
            "END TIME",
            "START FRAME",
            "END FRAME",
            "BOUT TIME",
            "MOVEMENT",
            "SEVERITY",
        ]
        video_results = []
        for file_path in self.machine_results_paths:
            _, video_name, _ = get_fn_ext(file_path)
            if video_name not in movements:
                continue
            _, _, fps = self.read_video_info(video_name=video_name)
            df = read_df(file_path=file_path, file_type=self.file_type)
            bout_df = detect_bouts(
                data_df=df, target_lst=[self.settings["clf"]], fps=fps
            )
            bout_df.columns = [
                "EVENT",
                "START TIME",
                "END TIME",
                "START FRAME",
                "END FRAME",
                "BOUT TIME",
            ]
            if len(bout_df) == 0:
                continue
            video_movement = movements[video_name]
            video_bins = video_bins_info[video_name]
            severity_lst, movement_lst = [], []
            # NOTE(review): the slice treats END FRAME as exclusive - confirm against detect_bouts.
            for start_frm, end_frm in bout_df[["START FRAME", "END FRAME"]].values:
                bout_frames = video_movement[start_frm:end_frm]
                # An empty slice has no mean; record NaN rather than triggering a numpy warning.
                bout_move_val = np.mean(bout_frames) if len(bout_frames) > 0 else np.nan
                severity_lst.append(
                    self._severity_from_bins(
                        bout_move_val, video_bins, self.settings["brackets"]
                    )
                )
                movement_lst.append(round(bout_move_val, 4))
            bout_df["MOVEMENT"] = movement_lst
            bout_df["SEVERITY"] = severity_lst
            bout_df.insert(0, "VIDEO", video_name)
            video_results.append(bout_df)
        if video_results:
            self.results = pd.concat(video_results, axis=0).reset_index(drop=True)
        else:
            self.results = pd.DataFrame(columns=result_columns)
        self.save()
        if self.settings["save_bin_definitions"]:
            self.save_bin_definitions(data=video_bins_info)
        if self.settings["visualize"]:
            self.visualize()

    def save(self):
        """Write ``self.results`` to ``self.save_path`` as CSV and report the elapsed time."""
        self.results.to_csv(self.save_path)
        self.timer.stop_timer()
        stdout_success(
            msg=f"Bout severity data saved at {self.save_path}",
            elapsed_time=self.timer.elapsed_time_str,
        )

    def save_bin_definitions(self, data: Union[Dict, None] = None):
        """
        Write the severity bin edges of every video to ``self.definitions_path`` as CSV.

        :param dict data: maps video name to a sequence of bins, where index 0 and 1 of each bin
            are written to the ">=" and "<" columns. ``None`` is treated as no bins.
        """
        rows = []
        for video_name, video_bins in (data or {}).items():
            for bin_cnt, video_bin in enumerate(video_bins):
                rows.append([video_name, bin_cnt + 1, video_bin[0], video_bin[1]])
        results = pd.DataFrame(rows, columns=["VIDEO", "SEVERITY_BIN", ">=", "<"])
        results.to_csv(self.definitions_path)
        stdout_success(
            msg=f"Severity bracket definitions saved at {self.definitions_path}"
        )

    def visualize(self):
        """
        Create one video clip per sampled bout in ``<project_path>/frames/output/severity_bouts``.

        The number of clips is ``settings['visualize_event_cnt']`` (``'ALL CLIPS'`` means every bout),
        capped at the number of bouts found. Pose-estimated body-parts are drawn on the frames when
        ``settings['show_pose']`` is truthy.

        :raises InvalidVideoFileError: if a frame inside a bout cannot be read from the video.
        """
        requested_cnt, available_cnt = self.settings["visualize_event_cnt"], len(self.results)
        if requested_cnt == "ALL CLIPS":
            self.settings["visualize_event_cnt"] = available_cnt
        elif requested_cnt > available_cnt:
            # Report the count the user asked for, before it is replaced by the cap.
            print(
                f"User specified {requested_cnt} visualization but only {available_cnt} bouts where found. Creating {available_cnt} videos..."
            )
            self.settings["visualize_event_cnt"] = available_cnt
        video_timer = SimbaTimer(start=True)
        save_dir = os.path.join(self.project_path, "frames", "output", "severity_bouts")
        os.makedirs(save_dir, exist_ok=True)
        viz_df = self.results.sample(
            n=self.settings["visualize_event_cnt"]
        ).reset_index(drop=True)
        self.space_scale, self.radius_scale, self.res_scale, self.font_scale = (
            60,
            12,
            1500,
            1.1,
        )
        pose_cache = {}  # one pose table per video, so a file is read once however many bouts it has
        for idx, r in viz_df.iterrows():
            video_name, severity = r["VIDEO"], r["SEVERITY"]
            start_frm, end_frame = int(r["START FRAME"]), int(r["END FRAME"])
            pose_df = None
            if self.settings["show_pose"]:
                if video_name not in pose_cache:
                    pose_cache[video_name] = read_df(
                        os.path.join(
                            self.machine_results_dir, video_name + "." + self.file_type
                        ),
                        self.file_type,
                        usecols=self.bp_headers,
                    ).astype(int)
                pose_df = pose_cache[video_name]
            clip_path = os.path.join(
                save_dir,
                f'{video_name}_{self.settings["clf"]}_{start_frm}_{end_frame}_{severity}.mp4',
            )
            video_path = self.find_video_of_file(
                video_dir=self.video_dir, filename=video_name
            )
            video_meta_data = get_video_meta_data(video_path)
            self.max_dim = max(video_meta_data["width"], video_meta_data["height"])
            self.circle_scale = int(self.radius_scale / (self.res_scale / self.max_dim))
            # Playback speed is applied through the output FPS; never drop below 1 FPS.
            video_fps = max(1, int(video_meta_data["fps"] * self.settings["video_speed"]))
            cap = cv2.VideoCapture(video_path)
            writer = cv2.VideoWriter(
                clip_path,
                self.fourcc,
                video_fps,
                (int(video_meta_data["width"]), int(video_meta_data["height"])),
            )
            event_frm_count = end_frame - start_frm
            try:
                cap.set(cv2.CAP_PROP_POS_FRAMES, start_frm)
                for frm_cnt, current_frm in enumerate(
                    range(start_frm, end_frame), start=1
                ):
                    ret, img = cap.read()
                    # Validate the read before touching img: it is not a usable frame on failure.
                    if not ret:
                        raise InvalidVideoFileError(
                            msg=f'Could not find frame {current_frm} in video {video_path}. Video {video_path} contains {video_meta_data["frame_count"]} frames.'
                        )
                    if pose_df is not None:
                        frm_pose = pose_df.iloc[current_frm]
                        for animal_cnt, animal_body_parts in enumerate(
                            self.animal_bp_dict.values()
                        ):
                            for bp_cnt, (x_name, y_name) in enumerate(
                                zip(
                                    animal_body_parts["X_bps"],
                                    animal_body_parts["Y_bps"],
                                )
                            ):
                                cv2.circle(
                                    img,
                                    (int(frm_pose[x_name]), int(frm_pose[y_name])),
                                    0,
                                    self.color_lst_lst[animal_cnt][bp_cnt],
                                    self.circle_scale,
                                )
                    writer.write(img)
                    print(
                        f"Frame {str(frm_cnt)} / {str(event_frm_count)}, Event {idx + 1}/{str(len(viz_df))}, Video {video_name}"
                    )
            finally:
                # Release both handles even when a frame could not be read.
                cap.release()
                writer.release()
        video_timer.stop_timer()
        stdout_success(
            msg=f"Videos complete: saved in {save_dir}",
            elapsed_time=video_timer.elapsed_time_str,
        )
# settings = {'brackets': 10,
#             'clf': 'Attack',
#             'animals': ['Simon', 'JJ'],
#             'normalization': 'ALL VIDEOS', #BY VIDEO
#             'bracket_type': "QUANTIZE",
#             'save_bin_definitions': True,
#             'visualize': True,
#             'visualize_event_cnt': 'ALL CLIPS',
#             'video_speed': 0.1,
#             'show_pose': True}
# processor = SeverityBoutCalculator(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', settings=settings)
# processor.run()