import pandas as pd
pd.options.mode.chained_assignment = None
import os
from copy import deepcopy
from typing import List, Optional, Union
import numpy as np
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
from simba.mixins.config_reader import ConfigReader
from simba.utils.checks import (check_file_exist_and_readable, check_str,
check_valid_lst)
from simba.utils.data import animal_interpolator, body_part_interpolator
from simba.utils.enums import TagNames
from simba.utils.errors import DataHeaderError, InvalidInputError
from simba.utils.printing import (SimbaTimer, log_event, stdout_information,
stdout_success)
from simba.utils.read_write import (copy_files_to_directory,
find_files_of_filetypes_in_directory,
get_fn_ext, read_df, write_df)
class Interpolate(ConfigReader):
    """
    Interpolate missing body-parts in pose-estimation data.

    "Missing" is defined as either (i) a single body-part being None, or (ii) all body-parts belonging to an
    animal being identical (i.e., the same 2D coordinate or all None).

    .. image:: _static/img/interpolation_comparison.webp
       :alt: Interpolation comparison
       :width: 500
       :align: center

    .. note::
       `Interpolation tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/Scenario1.md#to-import-multiple-dlc-csv-files>`__.

    .. important::
       The interpolated data overwrites the original data on disk. If the original data is required, pass ``copy_originals = True`` to save a copy of the original data.

    :param Union[str, os.PathLike] config_path: path to SimBA project config file in Configparser format.
    :param Union[str, os.PathLike] data_path: Path to a directory, path to a file, or a list of file paths to files with pose-estimation data in CSV or parquet format. When a directory is passed, files with the project file-type extension are collected from it.
    :param Optional[Literal['body-parts', 'animals']] type: If 'animals', then interpolation is performed when all body-parts belonging to an animal are identical (i.e., the same 2D coordinate or all None). If 'body-parts' then all body-parts that are None will be interpolated. Case-insensitive. Default: body-parts.
    :param Optional[Literal['nearest', 'linear', 'quadratic']] method: The interpolation method used to fill in the missing values. One of 'nearest', 'linear', or 'quadratic'. Case-insensitive. Default: nearest.
    :param Optional[bool] multi_index_df_headers: If truth-like, then the input data is anticipated to have multiple header columns, and output columns will have multiple header columns. Default: False.
    :param Optional[bool] copy_originals: If truth-like, then the pre-interpolated, original data, will be stored in a subdirectory of the original data directory. The subdirectory is named according to the method and type of interpolation and the datetime of the operation. Default: False.
    :raises InvalidInputError: If ``data_path`` is neither a list, an existing directory, nor an existing file.

    :example:
    >>> interpolator = Interpolate(config_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini', data_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/input_csv/test', type='body-parts', multi_index_df_headers=True, copy_originals=True)
    >>> interpolator.run()
    """

    def __init__(self,
                 config_path: Union[str, os.PathLike],
                 data_path: Union[str, os.PathLike, List[Union[str, os.PathLike]]],
                 type: Optional[Literal['body-parts', 'animals']] = 'body-parts',
                 method: Optional[Literal['nearest', 'linear', 'quadratic']] = 'nearest',
                 multi_index_df_headers: Optional[bool] = False,
                 copy_originals: Optional[bool] = False) -> None:

        # Must stay the first statement: locals() is logged, so no extra locals may exist yet.
        log_event(logger_name=str(self.__class__.__name__), log_type=TagNames.CLASS_INIT.value, msg=self.create_log_msg_from_init_args(locals=locals()))
        ConfigReader.__init__(self, config_path=config_path, read_video_info=False)
        check_str(name=f'{self.__class__.__name__} type', value=type.lower(), options=('body-parts', 'animals'))
        check_str(name=f'{self.__class__.__name__} method', value=method.lower(), options=('nearest', 'linear', 'quadratic'))

        # Resolve ``data_path`` (list of files / directory / single file) into a list of file paths.
        if isinstance(data_path, list):
            check_valid_lst(data=data_path, source=self.__class__.__name__, valid_dtypes=(str,))
            for path in data_path:
                check_file_exist_and_readable(file_path=path)
            # Copy so later changes to the caller's list do not affect this instance.
            self.file_paths = deepcopy(data_path)
        elif os.path.isdir(data_path):
            self.file_paths = find_files_of_filetypes_in_directory(directory=data_path, extensions=[f'.{self.file_type}'], raise_error=True)
        elif os.path.isfile(data_path):
            check_file_exist_and_readable(file_path=data_path)
            self.file_paths = [data_path]
        else:
            raise InvalidInputError(msg=f'{data_path} is not a valid data directory, or a valid file path, or a valid list of file paths', source=self.__class__.__name__)

        # Always define the attribute so it can be inspected even when no backup is requested.
        self.originals_dir = None
        if copy_originals:
            # The backup directory sits next to the first data file.
            self.originals_dir = os.path.join(os.path.dirname(self.file_paths[0]), f"Pre_{method}_{type}_interpolation_{self.datetime}")
            # Deliberately no ``exist_ok``: reusing an existing backup directory could overwrite
            # previously saved originals with already-interpolated data.
            os.makedirs(self.originals_dir)

        self.type, self.method = type.lower(), method.lower()
        self.multi_index_df_headers, self.copy_originals = multi_index_df_headers, copy_originals

    def __insert_multiindex_header(self, df: pd.DataFrame):
        """
        Replace the columns of ``df`` (in place) with a three-level MultiIndex where the two top
        levels are ``IMPORTED_POSE`` and the third level is the original column name.

        :param pd.DataFrame df: Dataframe with single-level column headers.
        :return: The same dataframe with the three-level column header.
        """
        df.columns = pd.MultiIndex.from_tuples([("IMPORTED_POSE", "IMPORTED_POSE", col) for col in df.columns])
        return df

    def run(self):
        """
        Interpolate every file in ``self.file_paths`` and overwrite each file with the result.

        For each file: the data is read; if multi-index headers are expected, the column count is
        validated against the project body-part headers and the columns are renamed to them; all values
        are coerced to numeric (non-numeric and missing values become 0.0) and negative values are set
        to 0.0; the data is interpolated by animal or by body-part; and the result is written back to the
        original path as float32. If ``copy_originals`` is set, the untouched file is copied to the
        backup directory before it is overwritten.

        :raises DataHeaderError: If multi-index headers are expected and the number of columns in a file differs from the number of project body-part headers.
        """
        stdout_information(msg=f'Running interpolation on {len(self.file_paths)} data files...')
        for file_path in self.file_paths:
            video_timer = SimbaTimer(start=True)
            _, self.video_name, _ = get_fn_ext(filepath=file_path)
            df = read_df(file_path=file_path, file_type=self.file_type, check_multiindex=self.multi_index_df_headers)
            if self.multi_index_df_headers:
                if len(df.columns) != len(self.bp_headers):
                    raise DataHeaderError( msg=f"The file {file_path} contains {len(df.columns)} columns, but your SimBA project expects {len(self.bp_headers)} columns representing {int(len(self.bp_headers) / 3)} body-parts (x, y, p). Check that the {self.body_parts_path} lists the correct body-parts associated with the project", source=self.__class__.__name__)
                df.columns = self.bp_headers
            # Unparsable entries become NaN and are then zero-filled; negative values are zeroed as well.
            # NOTE(review): presumably the interpolators treat these zeros as "missing" - confirm in simba.utils.data.
            df = df.apply(pd.to_numeric, errors="coerce").fillna(0.0)
            df[df < 0] = 0.0
            interpolator = animal_interpolator if self.type == 'animals' else body_part_interpolator
            df = interpolator(df=df, animal_bp_dict=self.animal_bp_dict, source=file_path, method=self.method)
            if self.multi_index_df_headers:
                df = self.__insert_multiindex_header(df=df)
            if self.copy_originals:
                # Back up the untouched file before it is overwritten below.
                copy_files_to_directory(file_paths=[file_path], dir=self.originals_dir)
            write_df(df=df.astype(np.float32), file_type=self.file_type, save_path=file_path, multi_idx_header=self.multi_index_df_headers)
            video_timer.stop_timer()
            stdout_information(msg=f"Video {self.video_name} interpolated (elapsed time {video_timer.elapsed_time_str}) ...")
        self.timer.stop_timer()
        msg = f"{len(self.file_paths)} data file(s) interpolated using {self.type} {self.method} methods."
        if self.copy_originals:
            msg += f" Originals saved in {self.originals_dir} directory."
        stdout_success(msg=msg, elapsed_time=self.timer.elapsed_time_str, source=self.__class__.__name__)