__author__ = "Simon Nilsson; sronilsson@gmail.com"
import glob
import os
from typing import Any, Dict, List, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from simba.mixins.plotting_mixin import PlottingMixin
from simba.mixins.unsupervised_mixin import UMLMixin
from simba.unsupervised.enums import Clustering, UMLOptions, Unsupervised
from simba.utils.checks import (check_if_dir_exists,
check_if_filepath_list_is_empty,
check_if_keys_exist_in_dict, check_instance,
check_int, check_str, check_that_column_exist,
check_valid_lst)
from simba.utils.enums import Formats, Options
from simba.utils.printing import stdout_success
from simba.utils.read_write import read_pickle
class GridSearchVisualizer(UMLMixin):
    """
    Visualize grid-searched latent spaces in .png format.

    .. image:: _static/img/GridSearchVisualizer.png
       :alt: Grid Search Visualizer
       :width: 800
       :align: center

    :param model_dir: Directory holding one or more unsupervised results as pickles in ``data_map.yaml`` format.
    :param save_dir: Existing directory in which the created .png images are saved.
    :param settings: User-defined image attributes. Must hold the keys ``CATEGORICAL_PALETTE``, ``CONTINUOUS_PALETTE`` and ``SCATTER_SIZE``.

    :example:
    >>> settings = {'CATEGORICAL_PALETTE': 'Pastel1', 'CONTINUOUS_PALETTE': 'magma', 'SCATTER_SIZE': 10}
    >>> visualizer = GridSearchVisualizer(model_dir='/Users/simon/Desktop/envs/troubleshooting/unsupervised/cluster_models_042023', save_dir='/Users/simon/Desktop/envs/troubleshooting/unsupervised/images', settings=settings)
    >>> visualizer.continuous_visualizer(continuous_vars=['START_FRAME'])
    >>> visualizer.categorical_visualizer(categorical_vars=['CLUSTER'])
    """

    def __init__(
        self,
        model_dir: Union[str, os.PathLike],
        save_dir: Union[str, os.PathLike],
        settings: Dict[str, Any],
    ):
        super().__init__()
        check_if_dir_exists(in_dir=save_dir)
        check_if_dir_exists(in_dir=model_dir)
        self.save_dir, self.settings, self.model_dir = save_dir, settings, model_dir
        # os.path.join (rather than ``model_dir + "/..."``) so that os.PathLike
        # inputs, which the signature allows, do not raise a TypeError.
        self.data_paths = glob.glob(
            os.path.join(model_dir, f"*.{Formats.PICKLE.value}")
        )
        check_if_keys_exist_in_dict(
            data=self.settings,
            key=["CATEGORICAL_PALETTE", "CONTINUOUS_PALETTE", "SCATTER_SIZE"],
            name=f"{self.__class__.__name__} settings",
        )
        check_if_filepath_list_is_empty(
            filepaths=self.data_paths,
            error_msg=f"SIMBA ERROR: No pickle files found in {model_dir}",
        )
        check_int(
            name=f"{self.__class__.__name__} scatter size",
            value=self.settings["SCATTER_SIZE"],
            min_value=1,
        )
        check_str(
            name=f"{self.__class__.__name__} categorical palette",
            value=self.settings["CATEGORICAL_PALETTE"],
            options=Options.PALETTE_OPTIONS_CATEGORICAL.value,
        )
        check_str(
            name=f"{self.__class__.__name__} continuous palette",
            value=self.settings["CONTINUOUS_PALETTE"],
            options=Options.PALETTE_OPTIONS.value,
        )

    def __extract_plot_data(self, data: dict) -> pd.DataFrame:
        """
        Assemble a single dataframe holding everything that can be plotted for one model.

        Concatenates, column-wise, the 2D embedding (columns ``X`` and ``Y``), the bout
        features (with their index reset into columns), the bout targets, and the
        cluster labels (column ``CLUSTER``). Where several columns share a name, only
        the first occurrence is kept.

        :param data: Un-pickled unsupervised results holding dimensionality reduction model, cluster model and data entries.
        :return: Copy of the concatenated dataframe with duplicated column names removed.
        """
        embedding_data = pd.DataFrame(
            data[Unsupervised.DR_MODEL.value][Unsupervised.MODEL.value].embedding_,
            columns=["X", "Y"],
        )
        bouts_data = data[Unsupervised.DATA.value][
            Unsupervised.BOUTS_FEATURES.value
        ].reset_index()
        # NOTE(review): unlike the bout features, the targets keep their own index;
        # the concat below aligns on index - confirm the targets use a default index.
        target_data = data[Unsupervised.DATA.value][Unsupervised.BOUTS_TARGETS.value]
        # NOTE(review): int8 tops out at 127 - confirm cluster labels cannot exceed this.
        cluster_data = pd.DataFrame(
            data[Clustering.CLUSTER_MODEL.value][Unsupervised.MODEL.value]
            .labels_.reshape(-1, 1)
            .astype(np.int8),
            columns=["CLUSTER"],
        )
        data = pd.concat(
            [embedding_data, bouts_data, target_data, cluster_data], axis=1
        )
        return data.loc[:, ~data.columns.duplicated()].copy()

    def __create_scatter_plots(
        self,
        variables: List[str],
        scatter_func,
        palette: str,
        plot_type: str,
    ) -> None:
        """
        Create one scatter plot per variable and per pickle found in ``model_dir``.

        Shared implementation behind :meth:`categorical_visualizer` and
        :meth:`continuous_visualizer`. Each image is saved in ``save_dir`` and named
        ``{cluster model hashed name}_{variable}.png``.

        :param variables: Names of the columns used to color the scatter points. At least one is required.
        :param scatter_func: Plotting callable (``PlottingMixin.categorical_scatter`` or ``PlottingMixin.continuous_scatter``).
        :param palette: Name of the palette passed to ``scatter_func``.
        :param plot_type: Human-readable plot kind used in the progress messages (``categorical`` or ``continuous``).
        """
        check_valid_lst(
            data=variables,
            source=self.__class__.__name__,
            valid_dtypes=(str,),
            min_len=1,
        )
        print(
            f"Creating {len(variables)} {plot_type} plot(s) from {len(self.data_paths)} data files..."
        )
        for file_path in self.data_paths:
            v = read_pickle(data_path=file_path)
            # The cluster model entry is read both when extracting the plot data and
            # when naming the image, so validate it together with the other entries.
            check_if_keys_exist_in_dict(
                data=v,
                key=[
                    Unsupervised.DR_MODEL.value,
                    Unsupervised.DATA.value,
                    Clustering.CLUSTER_MODEL.value,
                ],
                name=file_path,
            )
            data = self.__extract_plot_data(data=v)
            for variable in variables:
                check_that_column_exist(
                    df=data,
                    column_name=variable,
                    file_name=v[Unsupervised.DR_MODEL.value][
                        Unsupervised.HASHED_NAME.value
                    ],
                )
                save_path = os.path.join(
                    self.save_dir,
                    f"{v[Clustering.CLUSTER_MODEL.value][Unsupervised.HASHED_NAME.value]}_{variable}.png",
                )
                _ = scatter_func(
                    data=data,
                    columns=("X", "Y", variable),
                    palette=palette,
                    size=self.settings["SCATTER_SIZE"],
                    save_path=save_path,
                    show_box=False,
                )
                # Release the figure(s) so they do not accumulate across iterations.
                plt.close("all")
                stdout_success(msg=f"Saved {save_path}...")
        # NOTE(review): ``self.timer`` is not created in this class (presumably by
        # UMLMixin.__init__), so the elapsed time counts from wherever it was started.
        self.timer.stop_timer()
        stdout_success(
            msg=f"{len(variables) * len(self.data_paths)} {plot_type} plot(s) created.",
            elapsed_time=self.timer.elapsed_time_str,
        )

    def categorical_visualizer(self, categorical_vars: List[str]) -> None:
        """
        Create latent space scatter plots colored by categorical variables.

        One image is created per variable and per pickle in ``model_dir``, using the
        ``CATEGORICAL_PALETTE`` and ``SCATTER_SIZE`` settings.

        :param categorical_vars: Names of the categorical columns to color the points by, e.g., ``['CLUSTER']``.
        """
        self.__create_scatter_plots(
            variables=categorical_vars,
            scatter_func=PlottingMixin.categorical_scatter,
            palette=self.settings["CATEGORICAL_PALETTE"],
            plot_type="categorical",
        )

    def continuous_visualizer(self, continuous_vars: List[str]) -> None:
        """
        Create latent space scatter plots colored by continuous variables.

        One image is created per variable and per pickle in ``model_dir``, using the
        ``CONTINUOUS_PALETTE`` and ``SCATTER_SIZE`` settings.

        :param continuous_vars: Names of the continuous columns to color the points by, e.g., ``['START_FRAME']``.
        """
        self.__create_scatter_plots(
            variables=continuous_vars,
            scatter_func=PlottingMixin.continuous_scatter,
            palette=self.settings["CONTINUOUS_PALETTE"],
            plot_type="continuous",
        )
# #
# settings = {'CATEGORICAL_PALETTE': 'tab20', 'CONTINUOUS_PALETTE': 'magma', 'SCATTER_SIZE': 10}
# test = GridSearchVisualizer(model_dir='/Users/simon/Desktop/envs/NG_Unsupervised/project_folder/clusters',
# save_dir='/Users/simon/Desktop/envs/NG_Unsupervised/project_folder/cluster_vis',
# settings=settings)
# test.categorical_visualizer(categorical_vars=['CLUSTER'])
# test.continuous_visualizer(continuous_vars=['PROBABILITY'])
# settings = {'CATEGORICAL_PALETTE': 'Pastel1', 'CONTINUOUS_PALETTE': 'magma', 'SCATTER_SIZE': 10}
# test = GridSearchVisualizer(model_dir='/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/clustering_test_2',
# save_dir='/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/cluster_viz_2',
# settings=settings)
# test.categorical_visualizer(categorical_vars=['CLASSIFIER'])
# settings = {'CATEGORICAL_PALETTE': 'Pastel1', 'CONTINUOUS_PALETTE': 'magma', 'SCATTER_SIZE': 10}
# test = GridSearchVisualizer(model_dir='/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/cluster_mdls',
# save_dir='/Users/simon/Desktop/envs/simba/troubleshooting/NG_Unsupervised/project_folder/cluster_vis',
# settings=settings)
# #test.categorical_visualizer(categorical_vars=['VIDEO', 'CLUSTER'])
# test.continuous_visualizer(continuous_vars=['START_FRAME'])
# settings = {'PALETTE': 'Pastel1'}
# test = GridSearchVisualizer(model_dir='/Users/simon/Desktop/envs/troubleshooting/unsupervised/cluster_models',
# save_dir='/Users/simon/Desktop/envs/troubleshooting/unsupervised/images',
# settings=settings)
# test.cluster_visualizer()
# settings = {'CATEGORICAL_PALETTE': 'Pastel1', 'SCATTER_SIZE': 10}
# test = GridSearchVisualizer(model_dir='/Users/simon/Desktop/envs/troubleshooting/unsupervised/cluster_models_042023',
# save_dir='/Users/simon/Desktop/envs/troubleshooting/unsupervised/images',
# settings=settings)
# #test.continuous_visualizer(continuous_vars=['START_FRAME'])
# test.categorical_visualizer(categorical_vars=['CLUSTER'])