__author__ = "Simon Nilsson; sronilsson@gmail.com"
import os
from typing import Any, Dict, Union
import pandas as pd
from simba.mixins.config_reader import ConfigReader
from simba.mixins.plotting_mixin import PlottingMixin
from simba.mixins.unsupervised_mixin import UMLMixin
from simba.unsupervised.enums import Clustering, Unsupervised
from simba.utils.checks import (check_file_exist_and_readable,
check_if_keys_exist_in_dict, check_instance)
from simba.utils.printing import stdout_success
from simba.utils.read_write import read_pickle
# Keys read from the user-supplied ``settings`` dict by the calculator below.
CORRELATIONS = "correlations"
PLOTS = "plots"
CREATE = "create"
PALETTE = "palette"

# Correlation method names (these match the values shown in the class
# docstring example; the names themselves are not referenced in the code
# visible in this file).
PEARSON = "pearson"
KENDALL = "kendall"
SPEARMAN = "spearman"

# Additional string keys; not referenced in the code visible in this file.
METHOD = "method"
SHAP = "shap"
class EmbeddingCorrelationCalculator(UMLMixin, ConfigReader):
    """
    Correlate the axes of a dimensionality-reduction embedding with the original features for explainability purposes.

    .. image:: _static/img/EmbeddingCorrelationCalculator.png
       :alt: Embedding Correlation Calculator
       :width: 800
       :align: center

    :param str data_path: path to pickle holding unsupervised results in ``data_map.yaml`` format.
    :param str config_path: path to SimBA configparser.ConfigParser project_config.ini
    :param dict settings: dict holding which statistical tests to use and how to create plots. Must contain the
        keys ``correlations`` (iterable of correlation method names passed to ``pandas.DataFrame.corrwith``) and
        ``plots`` (dict with the keys ``create`` and ``palette``).

    :Example:

    >>> settings = {'correlations': ['pearson', 'kendall', 'spearman'], 'plots': {'create': True, 'correlations': 'pearson', 'palette': 'jet'}}
    >>> calculator = EmbeddingCorrelationCalculator(config_path='unsupervised/project_folder/project_config.ini', data_path='unsupervised/cluster_models/quizzical_rhodes.pickle', settings=settings)
    >>> calculator.run()
    """

    def __init__(
        self,
        data_path: Union[str, os.PathLike],
        config_path: Union[str, os.PathLike],
        settings: Dict[str, Any],
    ) -> None:
        check_file_exist_and_readable(file_path=config_path)
        check_instance(
            source=f"{self.__class__.__name__} settings",
            instance=settings,
            accepted_types=(dict,),
        )
        ConfigReader.__init__(self, config_path=config_path)
        # Initialise the mixin under the name that is actually imported and
        # inherited from (``UMLMixin``); the previous call used a name that
        # is not defined in this module.
        UMLMixin.__init__(self)
        check_file_exist_and_readable(file_path=data_path)
        self.settings, self.data_path = settings, data_path
        check_if_keys_exist_in_dict(
            data=settings,
            key=[CORRELATIONS, PLOTS],
            name=f"{self.__class__.__name__} settings",
        )
        # NOTE(review): read_pickle presumably unpickles the file; only point
        # data_path at trusted files.
        self.data = read_pickle(data_path=self.data_path)
        check_if_keys_exist_in_dict(
            data=self.data,
            key=[Unsupervised.METHODS.value, Unsupervised.DR_MODEL.value],
            name=self.data_path,
        )
        self.save_path = os.path.join(
            self.logs_path,
            f"embedding_correlations_{self.data[Unsupervised.DR_MODEL.value][Unsupervised.HASHED_NAME.value]}_{self.datetime}.csv",
        )

    def run(self) -> None:
        """
        Compute the feature/embedding correlations, save them as a CSV file and, if requested in the
        settings, create one scatter plot per feature.

        The CSV is written to ``self.save_path`` and holds two columns per correlation method
        (``<method>_Y`` and ``<method>_X``), indexed by feature.
        """
        print("Calculating embedding correlations...")
        self.x_df = self.data[Unsupervised.METHODS.value][
            Unsupervised.SCALED_DATA.value
        ]
        # The embedding is labelled with exactly two columns, so the fitted
        # model's ``embedding_`` has to be two-dimensional.
        self.y_df = pd.DataFrame(
            self.data[Unsupervised.DR_MODEL.value][Unsupervised.MODEL.value].embedding_,
            columns=["X", "Y"],
            index=self.x_df.index,
        )
        results = pd.DataFrame()
        for correlation_method in self.settings[CORRELATIONS]:
            # Y before X: keeps the column order of the saved CSV unchanged.
            results[f"{correlation_method}_Y"] = self.x_df.corrwith(
                self.y_df["Y"], method=correlation_method
            )
            results[f"{correlation_method}_X"] = self.x_df.corrwith(
                self.y_df["X"], method=correlation_method
            )
        results.to_csv(self.save_path)
        # self.timer is not created in this class; presumably provided by a
        # parent initialiser - TODO confirm.
        self.timer.stop_timer()
        stdout_success(
            msg=f"Embedding correlations saved in {self.save_path}",
            elapsed_time=self.timer.elapsed_time_str,
        )

        if self.settings[PLOTS][CREATE]:
            self._create_plots()

        self.timer.stop_timer()
        stdout_success(
            msg="Embedding correlation calculations complete",
            elapsed_time=self.timer.elapsed_time_str,
        )

    def _create_plots(self) -> None:
        """Save one scatter plot per feature, with the embedding on the axes and the feature as the third column."""
        print("Creating embedding correlation plots...")
        df = pd.concat([self.x_df, self.y_df], axis=1)
        save_dir = os.path.join(
            self.logs_path,
            f"embedding_correlation_plots_{self.data[Unsupervised.DR_MODEL.value][Unsupervised.HASHED_NAME.value]}_{self.datetime}",
        )
        os.makedirs(save_dir, exist_ok=True)
        feature_names = self.data[Unsupervised.METHODS.value][
            Unsupervised.FEATURE_NAMES.value
        ]
        # Report progress against the number of plotted features; ``df`` also
        # holds the two embedding columns, so its width would overstate it.
        feature_total = len(feature_names)
        for feature_cnt, feature_name in enumerate(feature_names):
            save_path = os.path.join(save_dir, f"{feature_name}.png")
            _ = PlottingMixin.continuous_scatter(
                data=df,
                columns=["X", "Y", feature_name],
                palette=self.settings[PLOTS][PALETTE],
                title=feature_name,
                save_path=save_path,
                show_box=False,
            )
            print(
                f"Saving image {feature_cnt + 1}/{feature_total} ({feature_name})"
            )
# settings = {'correlations': ['pearson', 'kendall', 'spearman'], 'plots': {'create': True, 'correlations': 'pearson', 'palette': 'jet'}}
# calculator = EmbeddingCorrelationCalculator(config_path='/Users/simon/Desktop/envs/NG_Unsupervised/project_folder/project_config.ini',
# data_path='/Users/simon/Desktop/envs/NG_Unsupervised/project_folder/clusters/beautiful_beaver.pickle',
# settings=settings)
# calculator.run()