Source code for simba.unsupervised.data_extractor

__author__ = "Simon Nilsson; sronilsson@gmail.com"

import json
import os
from typing import List, Optional, Union

try:
    from typing import Literal
except:
    from typing_extensions import Literal

import numpy as np
import pandas as pd

from simba.mixins.config_reader import ConfigReader
from simba.mixins.unsupervised_mixin import UMLMixin
from simba.unsupervised.enums import Clustering, UMLOptions, Unsupervised
from simba.utils.checks import (check_file_exist_and_readable,
                                check_if_dir_exists,
                                check_if_keys_exist_in_dict,
                                check_valid_extension)
from simba.utils.enums import Formats
from simba.utils.errors import InvalidInputError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
                                    get_fn_ext, read_pickle, write_pickle)

# Data-type labels. ``DataExtractor.run`` compares every requested data type
# against these strings to decide which part of the unpickled results to export.
CLUSTERER_PARAMETERS = "CLUSTERER HYPER-PARAMETERS"
DIMENSIONALITY_REDUCTION_PARAMETERS = "DIMENSIONALITY REDUCTION HYPER-PARAMETERS"
SCALER = "SCALER"
SCALED_DATA = "SCALED DATA"
LOW_VARIANCE_FIELDS = "LOW VARIANCE FIELDS"
FEATURE_NAMES = "FEATURE NAMES"
FRAME_FEATURES = "FRAME FEATURES"
FRAME_POSE = "FRAME POSE"
FRAME_TARGETS = "FRAME TARGETS"
BOUTS_FEATURES = "BOUTS FEATURES"
BOUTS_TARGETS = "BOUTS TARGETS"
BOUTS_DIM_CORDS = "BOUTS DIMENSIONALITY REDUCTION DATA"
BOUTS_CLUSTER_LABELS = "BOUTS CLUSTER LABELS"


class DataExtractor(UMLMixin, ConfigReader):
    """
    Extract human-readable data from a single pickle file, or a directory of
    pickle files, that holds unsupervised analyses.

    Output is written to a time-stamped ``extracted_unsupervised_model_data_*``
    directory inside the project logs directory, with one sub-directory per
    processed pickle file.

    :param config_path: Path to SimBA configparser.ConfigParser project_config.ini.
    :param data_path: Path to a pickle holding unsupervised results in ``data_map.yaml`` format, or a directory of such pickles.
    :param data_types: The types of data to extract. Every entry has to be present in ``UMLOptions.DATA_TYPES.value``. ``run`` handles the values of the module-level constants CLUSTERER_PARAMETERS, DIMENSIONALITY_REDUCTION_PARAMETERS, SCALER, SCALED_DATA, LOW_VARIANCE_FIELDS, FEATURE_NAMES, FRAME_FEATURES, FRAME_POSE, FRAME_TARGETS, BOUTS_FEATURES, BOUTS_TARGETS, BOUTS_DIM_CORDS and BOUTS_CLUSTER_LABELS.
    :param settings: Optional user-defined parameters for data extraction. Stored on the instance; not read by ``run``.
    :raises InvalidInputError: If ``data_types`` is empty or holds a value that is not an accepted data type.

    :example:
    >>> extractor = DataExtractor(data_path='unsupervised/cluster_models/awesome_curran.pickle', data_types=['BOUTS TARGETS'], settings=None, config_path='unsupervised/project_folder/project_config.ini')
    >>> extractor.run()
    """

    def __init__(
        self,
        config_path: Union[str, os.PathLike],
        data_path: Union[str, os.PathLike],
        data_types: List[str],
        settings: Optional[dict] = None,
    ):
        check_file_exist_and_readable(file_path=config_path)
        ConfigReader.__init__(self, config_path=config_path)
        UMLMixin.__init__(self)
        if os.path.isdir(data_path):
            check_if_dir_exists(in_dir=data_path)
            self.data_paths = find_files_of_filetypes_in_directory(
                directory=data_path,
                extensions=[f".{Formats.PICKLE.value}"],
                raise_error=True,
            )
        else:
            check_valid_extension(
                path=data_path, accepted_extensions=Formats.PICKLE.value
            )
            self.data_paths = [data_path]
        if len(data_types) == 0:
            raise InvalidInputError(
                msg=f"data_types is an empty list. Accepted options: {UMLOptions.DATA_TYPES.value}.",
                source=self.__class__.__name__,
            )
        # Report offending entries once each, in the order the caller passed them.
        accepted_dtypes = set(UMLOptions.DATA_TYPES.value)
        invalid_dtypes = list(
            dict.fromkeys(x for x in data_types if x not in accepted_dtypes)
        )
        if len(invalid_dtypes) > 0:
            raise InvalidInputError(
                msg=f"Found invalid data types: {invalid_dtypes}. Accepted: {UMLOptions.DATA_TYPES.value}.",
                source=self.__class__.__name__,
            )
        self.save_dir = os.path.join(
            self.logs_path, f"extracted_unsupervised_model_data_{self.datetime}"
        )
        os.makedirs(self.save_dir, exist_ok=True)
        self.settings, self.data_types = settings, data_types

    @staticmethod
    def _write_json(data, save_path: Union[str, os.PathLike]) -> None:
        """Serialize ``data`` to ``save_path`` as indented, key-sorted JSON and close the file."""
        with open(save_path, "w") as f:
            json.dump(data, f, indent=4, sort_keys=True)

    def run(self):
        """
        Export every requested data type from every pickle in ``self.data_paths``.

        :raises InvalidInputError: If a requested data type is not handled by any branch.
        """
        for file_cnt, file_path in enumerate(self.data_paths):
            print(f"Processing file {file_cnt+1}/{len(self.data_paths)}...")
            file_timer = SimbaTimer(start=True)
            # NOTE(review): this unpickles the file via ``read_pickle``; only
            # point the extractor at pickles from a trusted source.
            v = read_pickle(data_path=file_path)
            mdl_name = get_fn_ext(filepath=file_path)[1]
            save_subdir = os.path.join(self.save_dir, mdl_name)
            os.makedirs(save_subdir, exist_ok=True)
            for data_type in self.data_types:
                if data_type == CLUSTERER_PARAMETERS:
                    check_if_keys_exist_in_dict(
                        data=v, key=[Clustering.CLUSTER_MODEL.value], name=file_path
                    )
                    save_path = os.path.join(
                        save_subdir, f"cluster_parameters_{mdl_name}.json"
                    )
                    self._write_json(
                        data=v[Clustering.CLUSTER_MODEL.value][
                            Unsupervised.PARAMETERS.value
                        ],
                        save_path=save_path,
                    )
                    print(
                        f"Saved cluster parameters for model {mdl_name} at {save_path}..."
                    )
                elif data_type == DIMENSIONALITY_REDUCTION_PARAMETERS:
                    check_if_keys_exist_in_dict(
                        data=v, key=[Unsupervised.DR_MODEL.value], name=file_path
                    )
                    save_path = os.path.join(
                        save_subdir,
                        f"dimensionality_reduction_parameters_{mdl_name}.json",
                    )
                    self._write_json(
                        data=v[Unsupervised.DR_MODEL.value][
                            Unsupervised.PARAMETERS.value
                        ],
                        save_path=save_path,
                    )
                    print(
                        f"Saved dimension reduction parameters for model {mdl_name} at {save_path} ..."
                    )
                elif data_type == SCALER:
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.METHODS.value],
                        name=file_path,
                    )
                    # The scaler file is named after the hashed model name. Keep
                    # that in its own variable so it does not change the names
                    # used for the other data types of this file.
                    scaler_mdl_name = v[Unsupervised.DR_MODEL.value][
                        Unsupervised.HASHED_NAME.value
                    ]
                    save_path = os.path.join(
                        save_subdir, f"scaler_{scaler_mdl_name}.pickle"
                    )
                    write_pickle(
                        data=v[Unsupervised.METHODS.value][Unsupervised.SCALER.value],
                        save_path=save_path,
                    )
                    print(
                        f"Saved scaler for model {scaler_mdl_name} at {save_path}..."
                    )
                elif data_type == SCALED_DATA:
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.METHODS.value],
                        name=file_path,
                    )
                    save_path = os.path.join(save_subdir, f"scaled_data_{mdl_name}.csv")
                    v[Unsupervised.METHODS.value][
                        Unsupervised.SCALED_DATA.value
                    ].to_csv(save_path)
                    print(f"Saved scaled data for {mdl_name} model at {save_path}...")
                elif data_type == LOW_VARIANCE_FIELDS:
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.METHODS.value],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"low_variance_fields_{mdl_name}.csv"
                    )
                    out_df = pd.DataFrame(
                        data=v[Unsupervised.METHODS.value][
                            Unsupervised.LOW_VARIANCE_FIELDS.value
                        ],
                        columns=["FIELD_NAMES"],
                    )
                    out_df.to_csv(save_path)
                    print(
                        f"Saved low variance fields for model {mdl_name} at {save_path}..."
                    )
                elif data_type == FEATURE_NAMES:
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.METHODS.value],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"feature_names_{mdl_name}.csv"
                    )
                    out_df = pd.DataFrame(
                        data=v[Unsupervised.METHODS.value][
                            Unsupervised.FEATURE_NAMES.value
                        ],
                        columns=["FIELD_NAMES"],
                    )
                    out_df.to_csv(save_path)
                    print(f"Feature names saved for model {mdl_name} at {save_path}...")
                elif data_type == FRAME_FEATURES:
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.DATA.value],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"frame_wise_features_{mdl_name}.csv"
                    )
                    v[Unsupervised.DATA.value][
                        Unsupervised.FRAME_FEATURES.value
                    ].to_csv(save_path)
                    print(
                        f"Saved frame-wise features for model {mdl_name} at {save_path}..."
                    )
                elif data_type == FRAME_POSE:
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.DATA.value],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"frame_wise_pose_estimation_{mdl_name}.csv"
                    )
                    v[Unsupervised.DATA.value][Unsupervised.FRAME_POSE.value].to_csv(
                        save_path
                    )
                    print(
                        f"Frame-wise pose saved for model {mdl_name} at {save_path}..."
                    )
                elif data_type == FRAME_TARGETS:
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.DATA.value],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"frame_wise_target_data_{mdl_name}.csv"
                    )
                    v[Unsupervised.DATA.value][Unsupervised.FRAME_TARGETS.value].to_csv(
                        save_path
                    )
                    print(
                        f"Saved frame-wise target data for model {mdl_name} at {save_path}..."
                    )
                elif data_type == BOUTS_FEATURES:
                    # The DATA key is read below, so validate it alongside DR_MODEL.
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.DATA.value],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"bout_features_data_{mdl_name}.csv"
                    )
                    v[Unsupervised.DATA.value][
                        Unsupervised.BOUTS_FEATURES.value
                    ].to_csv(save_path)
                    print(
                        f"Saved bout features data for model {mdl_name} at {save_path}..."
                    )
                elif data_type == BOUTS_TARGETS:
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.DATA.value],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"bout_targets_{mdl_name}.csv"
                    )
                    v[Unsupervised.DATA.value][Unsupervised.BOUTS_TARGETS.value].to_csv(
                        save_path
                    )
                    print(
                        f"Saved bout target data for model {mdl_name} at {save_path}..."
                    )
                elif data_type == BOUTS_DIM_CORDS:
                    # The METHODS key supplies the row index below, so validate it too.
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[Unsupervised.DR_MODEL.value, Unsupervised.METHODS.value],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"bout_dim_cords_{mdl_name}.csv"
                    )
                    mdl_data = v[Unsupervised.DR_MODEL.value][
                        Unsupervised.MODEL.value
                    ].embedding_.astype(np.float32)
                    idx = v[Unsupervised.METHODS.value][
                        Unsupervised.SCALED_DATA.value
                    ].index
                    mdl_data = pd.DataFrame(
                        data=mdl_data, index=idx, columns=["X", "Y"]
                    )
                    mdl_data.to_csv(save_path)
                    print(
                        f"Saved bout dimensionality reduction data for model {mdl_name} at {save_path}..."
                    )
                elif data_type == BOUTS_CLUSTER_LABELS:
                    # The METHODS key supplies the row index below, so validate it too.
                    check_if_keys_exist_in_dict(
                        data=v,
                        key=[
                            Unsupervised.DR_MODEL.value,
                            Clustering.CLUSTER_MODEL.value,
                            Unsupervised.METHODS.value,
                        ],
                        name=file_path,
                    )
                    save_path = os.path.join(
                        save_subdir, f"bout_cluster_labels_{mdl_name}.csv"
                    )
                    # One label per row, shaped as a single column.
                    mdl_data = (
                        v[Clustering.CLUSTER_MODEL.value][Unsupervised.MODEL.value]
                        .labels_.astype(np.int64)
                        .reshape(-1, 1)
                    )
                    idx = v[Unsupervised.METHODS.value][
                        Unsupervised.SCALED_DATA.value
                    ].index
                    mdl_data = pd.DataFrame(data=mdl_data, index=idx, columns=["LABEL"])
                    mdl_data.to_csv(save_path)
                    print(
                        f"Saved bout cluster labels for model {mdl_name} at {save_path}..."
                    )
                else:
                    raise InvalidInputError(
                        msg=f"Invalid datatype {data_type}.",
                        source=self.__class__.__name__,
                    )
            file_timer.stop_timer()
            stdout_success(
                msg=f"{mdl_name} model data extraction complete",
                elapsed_time=file_timer.elapsed_time,
            )
        self.timer.stop_timer()
        stdout_success(
            msg=f"Data for {len(self.data_paths)} model(s) extracted",
            elapsed_time=self.timer.elapsed_time,
        )
# test = DataExtractor(data_path='/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/cluster_mdls/hopeful_khorana.pickle', # data_types=['CLUSTERER HYPER-PARAMETERS'], # settings=None, # config_path='/Users/simon/Desktop/envs/NG_Unsupervised/project_folder/project_config.ini') # # test.run()