__author__ = "Simon Nilsson; sronilsson@gmail.com"
import glob
import os
import platform
import random
import re
import struct
import subprocess
import sys
import tkinter as tk
from copy import copy
from datetime import datetime
from itertools import groupby
from multiprocessing import Lock, Value
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
try:
from typing import Literal
except:
from typing_extensions import Literal
import matplotlib.font_manager
import numpy as np
import pandas as pd
import psutil
import pyglet
from matplotlib import cm
from matplotlib.colors import hsv_to_rgb, rgb2hex
from tabulate import tabulate
import simba
from simba.utils.checks import (check_ffmpeg_available,
check_file_exist_and_readable, check_float,
check_if_dir_exists, check_if_valid_rgb_tuple,
check_instance, check_int, check_str,
check_valid_dict, check_valid_tuple)
from simba.utils.enums import (OS, UML, Defaults, FontPaths, Formats, Keys,
Methods, Options, Paths)
from simba.utils.errors import (FFMPEGNotFoundError, InvalidInputError,
NoFilesFoundError, SimBAPAckageVersionError)
from simba.utils.printing import stdout_information
from simba.utils.read_write import (fetch_pip_data,
find_files_of_filetypes_in_directory,
get_fn_ext, get_video_meta_data, read_json)
from simba.utils.warnings import NoDataFoundWarning
# Windows only: overwrite pyglet's COINIT_MULTITHREADED constant with 0x2, which is
# the value of COINIT_APARTMENTTHREADED, so anything reading this constant afterwards
# gets the apartment-threaded flag instead.
# NOTE(review): presumably a workaround for a COM threading-model clash between pyglet
# and other GUI components - confirm the motivation with the author.
if platform.system() == OS.WINDOWS.value:
    from pyglet.libs.win32 import constants
    constants.COINIT_MULTITHREADED = 0x2 # 0x2 = COINIT_APARTMENTTHREADED
# Type alias: a colour expressed as three float channels.
RGBFloat = Tuple[float, float, float]
class SharedCounter:
    """
    Integer counter stored in a ``multiprocessing.Value`` so it can be shared between processes.

    :param int initval: Value the counter starts at. Default 0.
    """

    def __init__(self, initval=0):
        # A dedicated lock guards the read-modify-write performed in ``increment``.
        self.lock = Lock()
        self.val = Value("i", initval)

    def increment(self):
        """Add one to the counter while holding the lock."""
        with self.lock:
            self.val.value = self.val.value + 1

    def value(self):
        """Return the current count, read while holding the lock."""
        with self.lock:
            current = self.val.value
        return current
def get_body_part_configurations() -> Dict[str, Union[str, os.PathLike]]:
    """
    Pair the named pose-estimation schemas shipped with SimBA with their schematic images.

    The schema names are read from the first column of the header-less CSV at
    ``Paths.PROJECT_POSE_CONFIG_NAMES``; the ``.png`` files in ``Paths.SCHEMATICS`` are
    sorted in natural (numeric-aware) order and zipped against the names.

    :return: Dictionary with schema names as keys. Each value is a dict holding the image path under the ``img_path`` key.
    """
    package_dir = os.path.dirname(simba.__file__)
    schematics_dir = os.path.join(package_dir, Paths.SCHEMATICS.value)
    config_names_path = os.path.join(package_dir, Paths.PROJECT_POSE_CONFIG_NAMES.value)
    check_file_exist_and_readable(file_path=config_names_path)
    check_if_dir_exists(in_dir=schematics_dir)

    def _natural_key(path):
        # Digit runs compare as integers so e.g. "2.png" sorts before "10.png".
        return [int(tok) if tok.isdigit() else tok for tok in re.findall(r"[^0-9]|[0-9]+", path)]

    config_names = list(pd.read_csv(config_names_path, header=None)[0])
    schematic_paths = sorted(glob.glob(schematics_dir + "/*.png"), key=_natural_key)
    # zip() stops at the shorter sequence, so surplus names or images are ignored.
    return {name: {"img_path": path} for name, path in zip(config_names, schematic_paths)}
def get_bp_config_codes() -> Dict[str, str]:
    """
    Map the human-readable pose-configuration names to the codes stored under
    [create ensemble settings][pose_estimation_body_parts] in the SimBA project_config.ini.

    :return: Dictionary with configuration display names as keys and their string codes as values.
    """
    codes: Dict[str, str] = {}
    # Fixed single- and two-animal configurations: the code is the body-part count.
    codes.update({
        "1 animal; 4 body-parts": "4",
        "1 animal; 7 body-parts": "7",
        "1 animal; 8 body-parts": "8",
        "1 animal; 9 body-parts": "9",
        "2 animals; 8 body-parts": "8",
        "2 animals; 14 body-parts": "14",
        "2 animals; 16 body-parts": "16",
    })
    codes["MARS"] = Methods.USER_DEFINED.value
    # Multi-animal configurations map to twice the per-animal body-part count.
    codes.update({
        "Multi-animals; 4 body-parts": "8",
        "Multi-animals; 7 body-parts": "14",
        "Multi-animals; 8 body-parts": "16",
    })
    codes["3D tracking"] = "3D_user_defined"
    codes["AMBER"] = "AMBER"
    codes["SimBA BLOB Tracking"] = Methods.SIMBA_BLOB.value
    codes["FaceMap"] = Methods.FACEMAP.value
    codes["SuperAnimal-TopView"] = Methods.SUPER_ANIMAL_TOPVIEW.value
    return codes
def get_bp_config_code_class_pairs() -> Dict[str, object]:
    """
    Map the [create ensemble settings][pose_estimation_body_parts] code of a SimBA
    project_config.ini to the feature-extraction class handling that configuration.

    .. note::
       The extractor modules are imported inside the function rather than at module import time.
       The ``"8"`` entry is itself a dictionary keyed by animal count (1 or 2).

    :return: Dictionary with configuration codes as keys and feature-extractor classes (or, for ``"8"``, a dict of classes) as values.
    """
    from simba.feature_extractors.amber_feature_extractor import (
        AmberFeatureExtractor,
    )
    from simba.feature_extractors.feature_extractor_4bp import (
        ExtractFeaturesFrom4bps,
    )
    from simba.feature_extractors.feature_extractor_7bp import (
        ExtractFeaturesFrom7bps,
    )
    from simba.feature_extractors.feature_extractor_8bp import (
        ExtractFeaturesFrom8bps,
    )
    from simba.feature_extractors.feature_extractor_8bps_2_animals import (
        ExtractFeaturesFrom8bps2Animals,
    )
    from simba.feature_extractors.feature_extractor_9bp import (
        ExtractFeaturesFrom9bps,
    )
    from simba.feature_extractors.feature_extractor_14bp import (
        ExtractFeaturesFrom14bps,
    )
    from simba.feature_extractors.feature_extractor_16bp import (
        ExtractFeaturesFrom16bps,
    )
    from simba.feature_extractors.feature_extractor_user_defined import (
        UserDefinedFeatureExtractor,
    )

    # The 8 body-part code is ambiguous, so it is resolved by the number of animals.
    eight_bp_by_animal_count = {1: ExtractFeaturesFrom8bps, 2: ExtractFeaturesFrom8bps2Animals}
    extractors = {
        "16": ExtractFeaturesFrom16bps,
        "14": ExtractFeaturesFrom14bps,
        "9": ExtractFeaturesFrom9bps,
        "8": eight_bp_by_animal_count,
        "7": ExtractFeaturesFrom7bps,
        "4": ExtractFeaturesFrom4bps,
        "user_defined": UserDefinedFeatureExtractor,
        "AMBER": AmberFeatureExtractor,
    }
    return extractors
def rgb_to_hex(color: Tuple[int, int, int]) -> str:
    """
    Convert an RGB tuple to a hex colour string using matplotlib's ``rgb2hex``.

    The tuple is validated with ``check_if_valid_rgb_tuple`` and each channel is divided
    by 255 before conversion. The alpha channel is not kept.

    :param Tuple[int, int, int] color: The colour as a (red, green, blue) tuple.
    :return: Hex string representation of the colour.
    :rtype: str
    """
    check_if_valid_rgb_tuple(data=color, raise_error=True, source=rgb_to_hex.__name__)
    red, green, blue = color
    unit_scaled = (red / 255, green / 255, blue / 255)
    return rgb2hex(unit_scaled, keep_alpha=False)
def get_icons_paths() -> Dict[str, Union[str, os.PathLike]]:
    """
    Collect the ``.png`` icons found in the ``Paths.ICON_ASSETS`` directory of the SimBA installation.

    :return: Dictionary keyed by icon name (taken from the file name via ``get_fn_ext``). Each value is a dict holding the file path under the ``icon_path`` key.
    """
    icons_dir = os.path.join(os.path.dirname(simba.__file__), Paths.ICON_ASSETS.value)
    icons = {}
    for path in glob.glob(icons_dir + "/*.png"):
        _, stem, _ = get_fn_ext(path)
        icons[stem] = {"icon_path": path}
    return icons
def load_simba_fonts():
    """
    Register every font file listed in :class:`simba.utils.enums.FontPaths` with pyglet.

    Each enum value is joined onto the SimBA package directory and passed to ``pyglet.font.add_file``.
    """
    package_dir = os.path.dirname(simba.__file__)
    for font in FontPaths:
        pyglet.font.add_file(os.path.join(package_dir, font.value))
def get_named_simba_fonts() -> Dict[str, str]:
    """
    Map the name (file name without extension) of each ``.ttf`` / ``.otf`` font found in
    ``simba/assets/fonts`` to its file path.

    .. seealso::
       For all fonts installed on the host OS (rather than only those bundled with SimBA), see :func:`~simba.utils.lookups.get_fonts`.

    :return: Dictionary with font names as keys and font file paths as values.
    :rtype: Dict[str, str]

    :example:
    >>> get_named_simba_fonts()
    >>> {'Poppins Regular': '.../simba/assets/fonts/Poppins Regular.ttf', ...}
    """
    package_dir = os.path.dirname(simba.__file__)
    check_if_dir_exists(in_dir=package_dir, source=get_named_simba_fonts.__name__)
    fonts_dir = os.path.join(package_dir, 'assets', 'fonts')
    check_if_dir_exists(in_dir=fonts_dir, source=get_named_simba_fonts.__name__)
    named_fonts = {}
    # TrueType files first, then OpenType; an .otf sharing a stem with a .ttf overrides it.
    for pattern in ('*.ttf', '*.otf'):
        for font_path in glob.glob(os.path.join(fonts_dir, pattern)):
            stem, _ = os.path.splitext(os.path.basename(font_path))
            named_fonts[stem] = font_path
    return named_fonts
def get_simba_font_name_and_path(font: str) -> Tuple[str, str]:
    """
    Resolve a bundled SimBA font name, matched case-insensitively, to its canonically-cased
    name and its font-file path.

    .. seealso::
       For the full name-to-path dictionary, see :func:`~simba.utils.lookups.get_named_simba_fonts`.

    :param str font: The font name to resolve (case-insensitive), as listed by :func:`get_named_simba_fonts`.
    :return: Tuple of (canonical font name, font-file path).
    :rtype: Tuple[str, str]
    :raises: Whatever ``check_str`` raises when ``font`` is not one of the bundled font names.

    :example:
    >>> get_simba_font_name_and_path(font='poppins regular')
    >>> ('Poppins Regular', '.../simba/assets/fonts/Poppins Regular.ttf')
    """
    arg_name = f'{get_simba_font_name_and_path.__name__} font'
    check_str(name=arg_name, value=font)
    name_to_path = get_named_simba_fonts()
    canonical_by_lower = {}
    for name in name_to_path:
        canonical_by_lower[name.lower()] = name
    lookup_key = font.lower()
    if lookup_key not in canonical_by_lower:
        # Delegate the error to check_str so the message lists the accepted font names.
        check_str(name=arg_name, value=font, options=tuple(name_to_path.keys()))
    canonical = canonical_by_lower[lookup_key]
    return canonical, name_to_path[canonical]
def get_emojis() -> Dict[str, str]:
    """
    Return the emojis used by SimBA, keyed by name.

    The same emojis are encoded differently depending on the running Python version:
    3.6 gets UTF-16 surrogate pairs, 3.7 / 3.9 / 3.10 get an encode-decode round trip, and
    every other version gets the plain string literals.

    :return: Dictionary with emoji names as keys and emoji strings as values.
    :rtype: Dict[str, str]
    """
    glyphs = {
        "thank_you": "\U0001f64f",
        "relaxed": "\U0001F600",
        "error": "\U0001F6A8",
        "complete": "\U0001F680",
        "warning": "\u2757\uFE0F",
        "trash": "\U0001F5D1",
        "information": "\U0001F4DD",  # 📝 memo
    }
    key_order = list(glyphs)
    version = f"{sys.version_info.major}.{sys.version_info.minor}"

    if version == "3.6":
        def convert(text):
            # Split into two big-endian 16-bit code units and rebuild one character per unit.
            return "".join(chr(unit) for unit in struct.unpack(">2H", text.encode("utf-16be")))
    elif version in ("3.10", "3.9"):
        # This branch lists "warning" ahead of "error"; keep that iteration order.
        key_order = ["thank_you", "relaxed", "warning", "error", "complete", "trash", "information"]

        def convert(text):
            return text.encode("utf-8", "replace").decode()
    elif version == "3.7":
        def convert(text):
            return text.encode("utf16", errors="surrogatepass").decode("utf16")
    else:
        def convert(text):
            return text

    return {name: convert(glyphs[name]) for name in key_order}
def get_cmaps() -> List[str]:
    """
    Return the names of the matplotlib colour palettes offered by SimBA.

    :return: List of palette names. A new list is created on every call.
    :rtype: List[str]
    """
    palette_names = (
        "spring", "summer", "autumn", "cool", "Wistia", "Pastel1",
        "Set1", "winter", "afmhot", "gist_heat", "copper",
    )
    return list(palette_names)
def get_categorical_palettes() -> List[str]:
    """
    Return the names of the categorical (qualitative) matplotlib palettes offered by SimBA.

    :return: List of palette names. A new list is created on every call.
    :rtype: List[str]
    """
    palette_names = (
        "Pastel1", "Pastel2", "Paired", "Accent", "Dark2", "Set1",
        "Set2", "Set3", "tab10", "tab20", "tab20b", "tab20c",
    )
    return list(palette_names)
def get_color_dict() -> Dict[str, Tuple[int, int, int]]:
    """
    Return the named colours used by SimBA as three-integer tuples.

    .. note::
       NOTE(review): the channel order is not uniform. E.g. ``"Red": (0, 0, 255)`` reads as BGR
       while ``"Firebrick": (178, 34, 34)`` reads as RGB. Values are kept exactly as they
       were - confirm the intended order before relying on it.

    :return: Dictionary with colour names as keys and colour tuples as values.
    :rtype: Dict[str, Tuple[int, int, int]]
    """
    named_colors = (
        ("Grey", (220, 200, 200)),
        ("Red", (0, 0, 255)),
        ("Dark-red", (0, 0, 139)),
        ("Maroon", (0, 0, 128)),
        ("Orange", (0, 165, 255)),
        ("Dark-orange", (0, 140, 255)),
        ("Coral", (80, 127, 255)),
        ("Chocolate", (30, 105, 210)),
        ("Yellow", (0, 255, 255)),
        ("Green", (0, 128, 0)),
        ("Dark-grey", (105, 105, 105)),
        ("Light-grey", (192, 192, 192)),
        ("Pink", (178, 102, 255)),
        ("Lime", (204, 255, 229)),
        ("Purple", (255, 51, 153)),
        ("Cyan", (255, 255, 102)),
        ("White", (255, 255, 255)),
        ("Black", (0, 0, 0)),
        ("Darkgoldenrod", (184, 134, 11)),
        ("Olive", (109, 113, 46)),
        ("Seagreen", (46, 139, 87)),
        ("Dodgerblue", (30, 144, 255)),
        ("Springgreen", (0, 255, 127)),
        ("Firebrick", (178, 34, 34)),
        ("Indigo", (63, 15, 183)),
    )
    return dict(named_colors)
def get_named_colors() -> List[str]:
    """
    Return the matplotlib colour names offered by SimBA.

    :return: List of colour names. A new list is created on every call.
    :rtype: List[str]
    """
    color_names = (
        "red", "pink", "lime", "gold", "coral", "lavender", "sienna", "tomato",
        "grey", "azure", "crimson", "lightgrey", "aqua", "plum", "blue", "teal",
        "maroon", "green", "black", "deeppink", "darkgoldenrod", "purple", "olive",
        "seagreen", "dodgerblue", "springgreen", "firebrick", "indigo", "white",
    )
    return list(color_names)
def create_color_palettes(no_animals: int, map_size: int) -> List[List[int]]:
    """
    Create list of lists of bgr colors, one for each animal. Each list is pulled from a different
    matplotlib color map.

    .. note::
       The previous implementation tested the integer animal index for membership among colormap
       names, which was never true, so every animal received the ``spring`` palette. Each animal
       now gets its own palette, cycling through the palette names when ``no_animals`` exceeds them.

    :param int no_animals: Number of different palette lists
    :param int map_size: Number of colors in each created palette.
    :return List[List[int]]: BGR colors, with each channel scaled to 0-255.

    :example:
    >>> create_color_palettes(no_animals=2, map_size=2)
    >>> [[[255.0, 0.0, 255.0], [0.0, 255.0, 255.0]], [[102.0, 127.5, 0.0], [102.0, 255.0, 255.0]]]
    """
    cmaps = [
        "spring",
        "summer",
        "autumn",
        "cool",
        "Wistia",
        "Pastel1",
        "Set1",
        "winter",
        "afmhot",
        "gist_heat",
        "copper",
        "viridis",
        "Set3",
        "Set2",
        "Paired",
        "seismic",
        "prism",
        "ocean",
    ]
    palettes = []
    for animal_idx in range(no_animals):
        # Wrap around so more animals than palette names reuses palettes instead of failing.
        cmap_name = cmaps[animal_idx % len(cmaps)]
        try:
            color_map = cm.get_cmap(cmap_name, map_size)
        except (ValueError, KeyError):
            # Palette name not registered in this matplotlib build: fall back to "spring".
            color_map = cm.get_cmap("spring", map_size)
        # Drop alpha, flip RGB -> BGR, and scale the 0-1 floats to 0-255.
        palettes.append(
            [[channel * 255 for channel in reversed(color_map(color_idx)[:3])] for color_idx in range(color_map.N)]
        )
    return palettes
def get_random_color_palette(n_colors: int):
    """
    Get a palette of ``n_colors`` randomly drawn colors.

    :param int n_colors: Number of colors to draw. Validated as an integer of at least 1.
    :return: List of ``n_colors`` tuples, each holding three random integers in the range 0-255.
    """
    check_int(name=f'{get_random_color_palette.__name__} n_colors', value=n_colors, min_value=1, raise_error=True)
    palette = []
    for _ in range(n_colors):
        palette.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
    return palette
def cardinality_to_integer_lookup() -> Dict[str, int]:
    """
    Create dictionary that maps the eight compass directions to integers, counting
    clockwise from north (``N`` = 0).

    :return: Dictionary with compass direction strings as keys and integers 0-7 as values.
    :rtype: Dict[str, int]

    :example:
    >>> data = ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]
    >>> [cardinality_to_integer_lookup()[d] for d in data]
    >>> [0, 1, 2, 3, 4, 5, 6, 7]
    """
    clockwise_from_north = ("N", "NE", "E", "SE", "S", "SW", "W", "NW")
    return {direction: code for code, direction in enumerate(clockwise_from_north)}
def integer_to_cardinality_lookup() -> Dict[int, str]:
    """
    Create dictionary that maps integers 0-7 to the eight compass directions, counting
    clockwise from north (0 = ``N``).

    :return: Dictionary with integers as keys and compass direction strings as values.
    :rtype: Dict[int, str]
    """
    clockwise_from_north = ("N", "NE", "E", "SE", "S", "SW", "W", "NW")
    return dict(enumerate(clockwise_from_north))
def percent_to_crf_lookup() -> Dict[str, int]:
    """
    Create dictionary that matches human-readable percent values to FFmpeg Constant Rate Factor (CRF)
    values that regulate video quality in CPU codecs. Higher CRF values translate to lower video
    quality and reduced file sizes.

    :return: Dictionary with percent strings ("10" to "100" in steps of ten) as keys and CRF integers as values.
    :rtype: Dict[str, int]
    """
    # The table is linear: every extra 10 percent lowers the CRF by 3, from 37 at "10" to 10 at "100".
    return {str(percent): 40 - 3 * (percent // 10) for percent in range(10, 101, 10)}
def gpu_quality_to_cpu_quality_lk() -> Dict[str, int]:
    """
    Map the preset names ``fast``, ``medium`` and ``slow`` to integer quality values.

    NOTE(review): judging by the function name, the integers are presumably CPU-codec quality
    (CRF) equivalents of GPU presets - confirm against the callers.

    :return: Dictionary with preset names as keys and integers as values.
    :rtype: Dict[str, int]
    """
    return dict(fast=34, medium=23, slow=13)
def percent_to_qv_lk() -> Dict[int, int]:
    """
    Create dictionary that matches human-readable percent values to FFmpeg quality scores.
    Higher FFmpeg quality scores map to smaller, lower quality videos. Used in some AVI codecs
    such as 'divx' and 'mjpeg'.

    :return: Dictionary with integer percents (100 down to 10 in steps of ten) as keys and quality scores as values.
    :rtype: Dict[int, int]
    """
    # The table is linear: every 10 percent less raises the score by 2, from 3 at 100 to 21 at 10.
    return {percent: 23 - percent // 5 for percent in range(100, 0, -10)}
def get_ffmpeg_crossfade_methods() -> List[str]:
    """
    Return the transition names accepted by the FFmpeg ``xfade`` filter that SimBA offers.

    .. note::
       Five entries were misspelled and were not valid ``xfade`` transitions: ``sideleft``,
       ``sideright``, ``sideup`` and ``sidedown`` are now ``slideleft``, ``slideright``,
       ``slideup`` and ``slidedown``, and ``cobverup`` is now ``coverup``.

    :return: List of transition names, in the same order as before.
    :rtype: List[str]
    """
    return ['fade',
            'fadeblack',
            'fadewhite',
            'distance',
            'wipeleft',
            'wiperight',
            'wipeup',
            'wipedown',
            'slideleft',
            'slideright',
            'slideup',
            'slidedown',
            'smoothleft',
            'smoothright',
            'smoothup',
            'smoothdown',
            'circlecrop',
            'rectcrop',
            'circleclose',
            'circleopen',
            'horzclose',
            'horzopen',
            'vertclose',
            'vertopen',
            'diagbl',
            'diagbr',
            'diagtl',
            'diagtr',
            'hlslice',
            'hrslice',
            'vuslice',
            'vdslice',
            'dissolve',
            'pixelize',
            'radial',
            'hblur',
            'wipetl',
            'wipetr',
            'wipebl',
            'wipebr',
            'fadegrays',
            'squeezev',
            'squeezeh',
            'zoomin',
            'hlwind',
            'hrwind',
            'vuwind',
            'vdwind',
            'coverleft',
            'coverright',
            'coverup',
            'coverdown',
            'revealleft',
            'revealright',
            'revealup',
            'revealdown']
def video_quality_to_preset_lookup() -> Dict[str, str]:
    """
    Create dictionary that matches human-readable video quality settings to FFmpeg presets for GPU codecs.

    :return: Dictionary with quality labels ("Low", "Medium", "High") as keys and preset names as values.
    :rtype: Dict[str, str]
    """
    quality_labels = ("Low", "Medium", "High")
    presets = ("fast", "medium", "slow")
    return dict(zip(quality_labels, presets))
def get_labelling_img_kbd_bindings() -> dict:
    """
    Returns dictionary of tkinter keyboard bindings for the image labelling interface.

    Each action name maps to a dict with a human-readable ``label`` and the tkinter event
    pattern ``kbd``.

    .. note::
        Change the ``kbd`` entries below to change keyboard shortcuts. Some possible patterns:

        <Key>, <KeyPress>, <KeyRelease>: Binds to any key press or release.
        <KeyPress-A>, <Key-a>: Binds to the 'a' key press (case sensitive).
        <Up>, <Down>, <Left>, <Right>: Binds to the arrow keys.
        <Control-KeyPress-A>, <Control-a>: Binds to Ctrl + A or Ctrl + a

    :return: Dictionary with action names as keys and ``{'label': ..., 'kbd': ...}`` dicts as values.
    :rtype: dict
    """
    # One row per shortcut: (action, label shown to the user, tkinter event pattern).
    shortcuts = (
        ('frame+1', 'Right Arrow = +1 frame', "<Right>"),  # move forward 1 frame
        ('frame-1', 'Left Arrow = -1 frame', "<Left>"),  # move back 1 frame
        ('save', 'Ctrl + s = Save annotations file', "<Control-s>"),  # save current annotations to disk
        ('frame+1_keep_choices', 'Ctrl + a = +1 frame and keep choices', "<Control-a>"),  # forward 1 frame, keep current selections
        ('frame-1_keep_choices', 'Ctrl + q = -1 frame and keep choices', "<Control-q>"),  # back 1 frame, keep current selections
        ('print_annotation_statistic', 'Ctrl + p = Show annotation statistics', "<Control-p>"),  # print annotation statistics
        ('last_frame', 'Ctrl + l = Last frame', "<Control-l>"),  # show last frame
        ('first_frame', 'Ctrl + o = First frame', "<Control-o>"),  # show first frame
    )
    return {action: {'label': label, 'kbd': kbd} for action, label, kbd in shortcuts}
def get_labelling_video_kbd_bindings() -> dict:
    """
    Returns a dictionary of OpenCV-compatible keyboard bindings for video labeling.

    The bindings are validated before being returned, and any ``kbd`` value that passes
    ``check_int`` is converted to ``int``.

    Notes:
        - Change the `kbd` values to customize keyboard shortcuts.
        - OpenCV key codes differ from Tkinter bindings (see `get_labelling_img_kbd_bindings`).
        - Use either single-character strings (e.g. 'p') or integer ASCII codes (e.g. 32 for space bar).

    Examples:
        Remap space bar to Pause/Play:
            {'Pause/Play': {'label': 'Space = Pause/Play', 'kbd': 32}}

    :return: Dictionary with action names as keys and ``{'label': ..., 'kbd': ...}`` dicts as values.
    :rtype: dict
    """
    # One row per shortcut: (action, label shown to the user, key).
    shortcuts = (
        ('Pause/Play', 'p = Pause/Play', 'p'),
        ('forward_two_frames', 'o = +2 frames', 'o'),
        ('forward_ten_frames', 'e = +10 frames', 'e'),
        ('forward_one_second', 'w = +1 second', 'w'),
        ('backwards_two_frames', 't = -2 frames', 't'),
        ('backwards_ten_frames', 's = -10 frames', 's'),
        ('backwards_one_second', 'x = -1 second', 'x'),
        ('close_window', 'q = Close video window', 'q'),
    )
    bindings = {action: {'label': label, 'kbd': key} for action, label, key in shortcuts}

    # Check that the bindings are defined correctly before handing them out.
    check_valid_dict( x=bindings, valid_key_dtypes=(str,), valid_values_dtypes=(dict,), source=f'{get_labelling_video_kbd_bindings.__name__} bindings')
    validated = {}
    for action, config in bindings.items():
        check_valid_dict(x=config, valid_key_dtypes=(str,), valid_values_dtypes=(str, int), required_keys=('label', 'kbd'))
        key = config['kbd']
        check_str(value=config['label'], allow_blank=False, raise_error=True, name=f'{get_labelling_video_kbd_bindings.__name__} action')
        key_is_int = check_int(name=f'{action} kbd', value=key, raise_error=False)[0]
        if not key_is_int:
            validated[action] = config
            continue
        # Integer-like keys are stored as real ints on a copy, leaving the source entry untouched.
        int_config = copy(config)
        int_config['kbd'] = int(key)
        validated[action] = int_config
    return validated
def get_fonts(sort_alphabetically: bool = False):
    """
    Returns a dictionary of the fonts known to matplotlib's font manager, with the font name as key
    and font path as value. Fonts whose name starts with ``.`` are skipped.

    .. seealso::
       For only the fonts bundled with SimBA (in ``simba/assets/fonts``), see :func:`~simba.utils.lookups.get_named_simba_fonts`.

    :param bool sort_alphabetically: If True, the returned dictionary is sorted by font name using natural (numeric-aware, case-insensitive) ordering. If False, fonts are returned in the order reported by matplotlib's font manager. Default False.
    :return: Dictionary mapping font name (key) to font file path (value). On Windows every ``C:`` substring is removed and paths are returned as POSIX-style strings.
    :rtype: Dict[str, str]

    :example:
    >>> get_fonts(sort_alphabetically=True)
    >>> {'Arial': '/Windows/Fonts/arial.ttf', ...}
    """
    fonts = {}
    for entry in matplotlib.font_manager.fontManager.ttflist:
        if entry.name.startswith('.'):
            continue
        # Several files can share a font name; the last one listed wins.
        fonts[entry.name] = entry.fname
    if not fonts:
        NoDataFoundWarning(msg='No fonts found on disk using matplotlib.font_manager', source=get_fonts.__name__)
    if platform.system() == OS.WINDOWS.value:
        fonts = {name: str(Path(path.replace('C:', '')).as_posix()) for name, path in fonts.items()}
    if sort_alphabetically:
        def _natural_key(item):
            # Split the name on digit runs so "Font 2" sorts before "Font 10".
            return [int(tok) if tok.isdigit() else tok.lower() for tok in re.split(r'(\d+)', item[0])]

        fonts = dict(sorted(fonts.items(), key=_natural_key))
    return fonts
def get_log_config() -> dict:
    """
    Return a logging configuration dictionary in the ``logging.config.dictConfig`` schema.

    It declares one formatter (``default``), one ``RotatingFileHandler`` (``file_handler``)
    and attaches that handler to the root logger at ``INFO`` level. No ``filename`` is set
    for the handler here.

    :return: Logging configuration dictionary. A new dictionary is created on every call.
    :rtype: dict
    """
    default_formatter = {
        "format": "%(asctime)s|%(name)s||%(message)s",
        "datefmt": "%Y-%m-%dT%H:%M:%SZ",
    }
    rotating_file_handler = {
        "class": "logging.handlers.RotatingFileHandler",
        "formatter": "default",
        "mode": "a",
        "backupCount": 5,
        "maxBytes": 5000000,
    }
    root_logger = {"level": "INFO", "handlers": ["file_handler"]}
    config = {"version": 1, "disable_existing_loggers": False}
    config["formatters"] = {"default": default_formatter}
    config["handlers"] = {"file_handler": rotating_file_handler}
    config["loggers"] = {"": root_logger}
    return config
def get_model_names():
    """
    Read the parquet file at ``Paths.UNSUPERVISED_MODEL_NAMES`` inside the SimBA installation
    and return its ``UML.NAMES`` column.

    :return: The values of the names column as a list.
    """
    names_path = os.path.join(os.path.dirname(simba.__file__), Paths.UNSUPERVISED_MODEL_NAMES.value)
    names_df = pd.read_parquet(names_path)
    return list(names_df[UML.NAMES.value])
def win_to_wsl_path(win_path: Union[str, os.PathLike]) -> str:
    """
    Convert a Windows path to its WSL equivalent by running ``wsl.exe wslpath``.

    :param Union[str, os.PathLike] win_path: The Windows path to convert.
    :return: The converted path, with surrounding whitespace stripped.
    :rtype: str
    :raises RuntimeError: If the ``wslpath`` command exits with a non-zero return code.
    """
    completed = subprocess.run(["wsl.exe", "wslpath", win_path], capture_output=True, text=True)
    if completed.returncode == 0:
        return completed.stdout.strip()
    raise RuntimeError(f"WSL path conversion failed: {completed.stderr}")
def get_available_ram():
    """
    Report total and available system memory, as given by ``psutil.virtual_memory()``,
    in bytes, megabytes and gigabytes.

    :return: Dictionary with keys ``bytes``, ``available_bytes``, ``megabytes``, ``available_mb``, ``gigabytes`` and ``available_gb``. The un-prefixed keys hold the total memory.
    :rtype: dict
    """
    total_bytes = psutil.virtual_memory().total
    available_bytes = psutil.virtual_memory().available
    bytes_per_mb = 1024 ** 2
    total_mb = total_bytes / bytes_per_mb
    available_mb = available_bytes / bytes_per_mb
    return {
        "bytes": total_bytes,
        "available_bytes": available_bytes,
        "megabytes": total_mb,
        "available_mb": available_mb,
        "gigabytes": total_mb / 1024,
        "available_gb": available_mb / 1024,
    }
def get_current_time() -> str:
    """
    Return the current local time formatted as ``HH:MM:SS``.

    :return: Time string, e.g. ``'14:03:59'``.
    :rtype: str
    """
    return f"{datetime.now():%H:%M:%S}"
def get_display_resolution() -> Tuple[int, int]:
    """
    Helper to get main monitor / display resolution, as reported by tkinter.

    A temporary, withdrawn Tk root window is created for the query and destroyed afterwards.

    .. note::
       May return the virtual geometry in multi-display setups. To return the resolution of each available monitor in mosaic, see :func:`simba.utils.lookups.get_monitor_info`.

    :return: Tuple of (width, height).
    :rtype: Tuple[int, int]
    """
    hidden_root = tk.Tk()
    hidden_root.withdraw()
    resolution = (hidden_root.winfo_screenwidth(), hidden_root.winfo_screenheight())
    hidden_root.destroy()
    return resolution
def get_img_resize_info(img_size: Tuple[int ,int],
                        display_resolution: Optional[Tuple[int, int]] = None,
                        max_height_ratio: float = 0.5,
                        max_width_ratio: float = 0.5,
                        min_height_ratio: float = 0.0,
                        min_width_ratio: float = 0.0) -> Tuple[int, int, float, float]:
    """
    Calculates the new dimensions and scaling factors needed to resize an image, preserving its
    aspect ratio, so that it fits within a given portion of the display resolution.

    The image is shrunk if either side exceeds the maximum bounds; otherwise it is enlarged if
    either side is below the minimum bounds; otherwise it is left unchanged.

    :param Tuple[int, int] img_size: The original size of the image as (width, height).
    :param Optional[Tuple[int, int]] display_resolution: Optional resolution of the display as (width, height). If None, the main monitor resolution from :func:`get_monitor_info` is used.
    :param float max_height_ratio: The maximum allowed height of the image as a fraction of the display height (default is 0.5).
    :param float max_width_ratio: The maximum allowed width of the image as a fraction of the display width (default is 0.5).
    :param float min_height_ratio: The minimum allowed height of the image as a fraction of the display height (default is 0.0).
    :param float min_width_ratio: The minimum allowed width of the image as a fraction of the display width (default is 0.0).
    :return: Length 4 tuple with resized width, resized height, the factor applied to the original size, and its inverse. ``(width, height, 1, 1)`` when no resize is needed.
    :rtype: Tuple[int, int, float, float]
    """
    if display_resolution is None:
        _, display_resolution = get_monitor_info()
    display_w, display_h = display_resolution[0], display_resolution[1]
    img_w, img_h = img_size[0], img_size[1]
    max_w, max_h = round(display_w * max_width_ratio), round(display_h * max_height_ratio)
    min_w, min_h = round(display_w * min_width_ratio), round(display_h * min_height_ratio)

    if img_w > max_w or img_h > max_h:
        # The smaller ratio is the tighter constraint, so both sides end up inside the bounds.
        downscale_factor = min(max_w / img_w, max_h / img_h)
        upscale_factor = 1 / downscale_factor
        return round(img_w * downscale_factor), round(img_h * downscale_factor), downscale_factor, upscale_factor

    if img_w < min_w or img_h < min_h:
        # The larger ratio ensures both sides meet or exceed the minimum.
        scale = max(min_w / img_w, min_h / img_h)
        return round(img_w * scale), round(img_h * scale), scale, 1 / scale

    return img_w, img_h, 1, 1
def is_running_in_ide():
    """
    Report whether the interpreter appears to be running interactively.

    :return: ``True`` if ``sys.ps1`` is defined; otherwise the value of ``sys.flags.interactive`` (truthy when the interactive flag is set).
    """
    if hasattr(sys, 'ps1'):
        return True
    return sys.flags.interactive
def get_monitor_info() -> Tuple[Dict[int, Dict[str, Union[int, bool]]], Tuple[int, int]]:
    """
    Helper to get the resolution of every available monitor, plus the main monitor resolution.

    Monitors are read from pyglet. A monitor is flagged as primary when its origin is at
    ``(0, 0)``. The main monitor is the first one flagged primary, or the first monitor listed
    when none is.

    .. note::
       To get the virtual geometry instead, see :func:`simba.utils.lookups.get_display_resolution`.

    :return: Tuple of (dictionary keyed by monitor index holding ``width``, ``height`` and ``primary``, tuple of main monitor (width, height)).
    :rtype: Tuple[Dict[int, Dict[str, Union[int, bool]]], Tuple[int, int]]
    """
    screens = pyglet.canvas.get_display().get_screens()
    results = {
        idx: {'width': screen.width,
              'height': screen.height,
              'primary': bool(screen.x == 0 and screen.y == 0)}
        for idx, screen in enumerate(screens)
    }
    # Start from the first monitor (raises StopIteration when no monitor is reported),
    # then prefer the first one flagged as primary.
    main_monitor = next(iter(results.values()))
    for info in results.values():
        if info.get('primary'):
            main_monitor = info
            break
    return results, (int(main_monitor['width']), int(main_monitor['height']))
def get_table(data: Dict[str, Any],
              headers: Optional[Tuple[str, str]] = ("SETTING", "VALUE"),
              tablefmt: str = "grid") -> str:
    """
    Render a dictionary as a two-column table string using the tabulate library.

    Every key-value pair becomes one row: the key in the first column, the value in the second.

    :param Dict[str, Any] data: Dictionary to render. Must have string keys and at least one entry.
    :param Optional[Tuple[str, str]] headers: Tuple of two strings used as the column headers. Default is ("SETTING", "VALUE").
    :param str tablefmt: Table format style. For options, see simba.utils.enums.Formats.VALID_TABLEFMT. Default "grid".
    :return str: Formatted table string ready for display or printing.

    :example:
    >>> data = {"fps": 30, "width": 1920, "height": 1080, "frame_count": 3000}
    >>> table = get_table(data=data, headers=("PARAMETER", "VALUE"))
    """
    check_valid_dict(x=data, valid_key_dtypes=(str,), min_len_keys=1, source=f'{get_table.__name__} data')
    check_valid_tuple(x=headers, source=f'{get_table.__name__} data', accepted_lengths=(2,), valid_dtypes=(str,))
    check_str(name=f'{get_table.__name__} tablefmt', value=tablefmt, options=Formats.VALID_TABLEFMT.value, raise_error=True)
    rows = [[setting, value] for setting, value in data.items()]
    return tabulate(rows, headers=headers, tablefmt=tablefmt)
def get_ffmpeg_encoders(raise_error: bool = True, alphabetically_sorted: bool = False) -> List[str]:
    """
    Get a list of all available FFmpeg encoders by parsing the output of ``ffmpeg -encoders``.

    .. note::
       Two defects of the previous implementation are fixed: ``raise_error=False`` is now honoured
       when FFmpeg is unavailable (the availability check used to raise regardless), and the legend
       rows of the FFmpeg output (e.g. `` V..... = Video``) are no longer returned as encoders named ``=``.

    :param bool raise_error: If True, raises an exception when FFmpeg is not available or the command cannot be run. If False, returns an empty list on error. Default: True.
    :param bool alphabetically_sorted: If True, the encoder names are returned sorted alphabetically. If False, they are returned in the order FFmpeg lists them. Default: False.
    :return: List of encoder names (e.g., ['libx264', 'aac', 'libvpx', ...]). Returns empty list if FFmpeg is unavailable and raise_error=False.
    :rtype: List[str]
    :raises FFMPEGNotFoundError: If the ``ffmpeg`` command cannot be run and ``raise_error`` is True.

    :example:
    >>> codecs = get_ffmpeg_encoders()
    >>> print(Formats.BATCH_CODEC.value in codecs)
    """
    # Pass the caller's preference through instead of always raising.
    # NOTE(review): assumes check_ffmpeg_available returns False (rather than raising) when
    # raise_error is False and FFmpeg is missing - confirm. If it returns something else, the
    # Popen call below still catches a missing binary.
    ffmpeg_available = check_ffmpeg_available(raise_error=raise_error)
    if ffmpeg_available is False:
        return []
    try:
        proc = subprocess.Popen(['ffmpeg', '-encoders'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, _ = proc.communicate()
    except Exception as e:
        if raise_error:
            raise FFMPEGNotFoundError(msg=str(e.args)) from e
        return []
    if isinstance(stdout, bytes):
        # Tolerate undecodable bytes rather than failing on an otherwise valid listing.
        stdout = stdout.decode('utf-8', errors='replace')
    encoders = []
    for line in stdout.split('\n'):
        # Encoder rows start with a capability flag column beginning with V, A or S.
        if not re.match(r'^\s*[VAS]', line):
            continue
        parts = line.split()
        # The legend rows have "=" as their second field; skip them.
        if len(parts) >= 2 and parts[1] != '=':
            encoders.append(parts[1])
    return sorted(encoders) if alphabetically_sorted else encoders
def find_closest_string(target: str,
                        string_list: List[str],
                        case_sensitive: bool = False,
                        token_based: bool = True) -> Optional[Tuple[str, Union[int, float]]]:
    """
    Find the string in a list that is closest to a target string.

    With ``token_based=True`` the score blends the Jaccard distance between the two token sets (70%)
    with the normalised Levenshtein distance between the sorted, re-joined tokens (30%). Tokens are
    split on underscores, hyphens and whitespace, so word order does not matter
    (e.g., "Left_ear" vs "Ear_left"). With ``token_based=False`` the plain Levenshtein distance is used.
    On ties, the earliest entry in ``string_list`` wins.

    :param str target: The target string to match against.
    :param List[str] string_list: List of strings to search through.
    :param bool case_sensitive: If True, comparison is case-sensitive. If False (default), both sides are lower-cased before comparison.
    :param bool token_based: If True (default), uses the hybrid token / Levenshtein score. If False, uses pure Levenshtein distance only.
    :return: Tuple of (closest_string, distance) or None if string_list is empty. The returned string keeps its original casing. When token_based=True, distance is a float score (lower is better). When token_based=False, distance is integer edit distance.
    :rtype: Optional[Tuple[str, Union[int, float]]]

    :example:
    >>> find_closest_string("cat", ["dog", "car", "bat"])
    >>> ('car', 0.8)  # approximately; 'car' and 'bat' tie and 'car' comes first
    >>> find_closest_string("Left_ear", ["Ear_left", "Right_ear", "Nose"])
    >>> ('Ear_left', 0.0)
    >>> find_closest_string("CAT", ["dog", "car", "bat"], case_sensitive=True, token_based=False)
    >>> ('dog', 3)  # every candidate is 3 edits away, so the first one is returned
    """
    check_str(name=f'{find_closest_string.__name__} target', value=target, allow_blank=False, raise_error=True)
    check_instance(source=f'{find_closest_string.__name__} string_list', instance=string_list, accepted_types=(list,), raise_error=True)
    if len(string_list) == 0:
        return None
    for entry in string_list:
        check_str(name=f'{find_closest_string.__name__} string_list entry', value=entry, allow_blank=False, raise_error=True)

    def _edit_distance(a: str, b: str) -> int:
        """Levenshtein distance, computed row by row over the shorter string."""
        if a == b:
            return 0
        if not a:
            return len(b)
        if not b:
            return len(a)
        shorter, longer = (a, b) if len(a) <= len(b) else (b, a)
        previous = list(range(len(shorter) + 1))
        for row_no, long_ch in enumerate(longer, start=1):
            current = [row_no]
            for col, short_ch in enumerate(shorter):
                substitution = previous[col] + (0 if short_ch == long_ch else 1)
                current.append(min(previous[col + 1] + 1, current[col] + 1, substitution))
            previous = current
        return previous[-1]

    def _sorted_tokens(text: str) -> List[str]:
        """Split on underscores, hyphens and whitespace; return the non-empty tokens sorted."""
        return sorted(tok for tok in re.split(r'[_\-\s]+', text) if tok)

    def _token_sort_distance(a: str, b: str) -> float:
        """Hybrid score: 0.0 is a perfect match, higher is worse."""
        tokens_a, tokens_b = _sorted_tokens(a), _sorted_tokens(b)
        unique_a, unique_b = set(tokens_a), set(tokens_b)
        n_shared = len(unique_a & unique_b)
        n_union = len(unique_a | unique_b)
        # Jaccard distance between the token sets.
        jaccard_distance = 1.0 if n_union == 0 else 1.0 - (n_shared / n_union)
        # Character-level distance between the order-normalised strings.
        joined_a, joined_b = '_'.join(tokens_a), '_'.join(tokens_b)
        longest = max(len(joined_a), len(joined_b))
        char_distance = 0.0 if longest == 0 else _edit_distance(joined_a, joined_b) / longest
        return jaccard_distance * 0.7 + char_distance * 0.3

    # Compare on lower-cased copies unless told otherwise; the result is always taken from the original list.
    needle = target if case_sensitive else target.lower()
    candidates = string_list if case_sensitive else [s.lower() for s in string_list]
    scorer = _token_sort_distance if token_based else _edit_distance
    scores = [scorer(needle, candidate) for candidate in candidates]
    best_idx = min(range(len(scores)), key=scores.__getitem__)
    return string_list[best_idx], scores[best_idx]
def create_directionality_cords(bp_dict: dict,
                                left_ear_name: str,
                                nose_name: str,
                                right_ear_name: str) -> dict:
    """
    Helper to create a dictionary mapping animal body-parts (nose, left ear, right ear) to their X and Y coordinate
    column names for directionality analysis.

    .. note::
       Body-part names are matched case-insensitively as substrings of the column names. Whether a column is an
       X or a Y coordinate is decided by the list it is found in (``X_bps`` or ``Y_bps``), not by searching the
       column name for the letters "x" / "y", so animal or body-part names containing those letters are handled.
       If several columns match the same body-part, the last matching column is used.

    :param dict bp_dict: Dictionary with animal names as keys and body-part coordinate information as values. Expected to contain 'X_bps' and 'Y_bps' keys with lists of column names.
    :param str left_ear_name: Name of the left ear body-part to search for in coordinate column names.
    :param str nose_name: Name of the nose body-part to search for in coordinate column names.
    :param str right_ear_name: Name of the right ear body-part to search for in coordinate column names.
    :return: Nested dictionary with animal names as keys, body-part types (nose, ear_left, ear_right) as second-level keys, and coordinate types (X_bps, Y_bps) as third-level keys with corresponding column names as values.
    :rtype: dict
    :raises InvalidInputError: If any required body-part, or either of its X / Y coordinates, cannot be found in the input dictionary.

    :example:
    >>> bp_dict = {'Animal_1': {'X_bps': ['Animal_1_Nose_x', 'Animal_1_Ear_left_x', 'Animal_1_Ear_right_x'], 'Y_bps': ['Animal_1_Nose_y', 'Animal_1_Ear_left_y', 'Animal_1_Ear_right_y']}}
    >>> create_directionality_cords(bp_dict=bp_dict, left_ear_name='Ear_left', nose_name='Nose', right_ear_name='Ear_right')
    >>> {'Animal_1': {'nose': {'X_bps': 'Animal_1_Nose_x', 'Y_bps': 'Animal_1_Nose_y'}, 'ear_left': {'X_bps': 'Animal_1_Ear_left_x', 'Y_bps': 'Animal_1_Ear_left_y'}, 'ear_right': {'X_bps': 'Animal_1_Ear_right_x', 'Y_bps': 'Animal_1_Ear_right_y'}}}
    """
    NOSE, EAR_LEFT, EAR_RIGHT = Keys.NOSE.value, Keys.EAR_LEFT.value, Keys.EAR_RIGHT.value
    dimensions = ("X_bps", "Y_bps")
    # Search priority is nose -> left ear -> right ear: the first name found in a column claims that column.
    bp_search_order = ((NOSE, nose_name.lower()),
                       (EAR_LEFT, left_ear_name.lower()),
                       (EAR_RIGHT, right_ear_name.lower()))
    results = {}
    for animal, animal_cords in bp_dict.items():
        results[animal] = {NOSE: {}, EAR_LEFT: {}, EAR_RIGHT: {}}
        for dimension in dimensions:
            for cord in animal_cords[dimension]:
                cord_lower = cord.lower()
                for bp_key, bp_search_name in bp_search_order:
                    if bp_search_name in cord_lower:
                        # The list being iterated already tells us the axis; do not sniff the name for "x"/"y".
                        results[animal][bp_key][dimension] = cord
                        break
    for animal_name, animal_bps in results.items():
        for bp_name, bp_values in animal_bps.items():
            if len(bp_values) == 0:
                raise InvalidInputError(msg=f'Could not detect a body-part for animal {animal_name}, body-part {bp_name} in SimBA project. Make sure the body-part configuration file at {Paths.BP_NAMES.value} lists the appropriate body-parts', source=create_directionality_cords.__name__)
            for cord_key in dimensions:
                # A missing key or an empty column name both mean the coordinate was not resolved.
                if not bp_values.get(cord_key):
                    raise InvalidInputError(msg=f'Could not detect a body-part for animal {animal_name}, body-part {bp_name} and coordinate {cord_key} in SimBA project. Make sure the body-part configuration file at {Paths.BP_NAMES.value} lists the appropriate body-parts. Passed values: {left_ear_name, nose_name, right_ear_name}', source=create_directionality_cords.__name__)
    return results
def intermittent_palette(n: int = 10,
                         base_light: float = 0.55,
                         contrast_delta: float = 0.18,
                         seed_hue: Optional[float] = None,
                         output: Literal["rgb", "rgb255", "hex"] = "rgb",
                         rng: Optional[random.Random] = None) -> Union[List[RGBFloat], List[Tuple[int, int, int]], List[str]]:
    """
    Build a categorical colour palette whose neighbouring entries are easy to tell apart.

    Successive hues are advanced by the golden-ratio conjugate (modulo 1), the HSV value alternates between
    ``base_light - contrast_delta`` and ``base_light + contrast_delta`` (clamped to ``[0.25, 0.85]``), and every
    third colour (indices 0, 3, 6, ...) uses saturation ``0.85`` while the rest use ``0.72``.

    .. note::
       The hue is advanced once before the first colour is produced, so the first colour does not use
       ``seed_hue`` itself but ``seed_hue`` shifted by the golden-ratio step.

    :param int n: Number of colours to generate. Must be greater than or equal to 1.
    :param float base_light: Midpoint HSV value (0-1) around which the lightness alternates. Default ``0.55``.
    :param float contrast_delta: Offset (0-1) subtracted from / added to ``base_light`` on even / odd indices. Default ``0.18``.
    :param Optional[float] seed_hue: Starting hue (0-1). If ``None``, a starting hue is drawn from ``rng`` (or a fresh ``random.Random``). Default ``None``.
    :param str output: Output colour format. One of ``{"rgb", "rgb255", "hex"}``. Default ``"rgb"``.
    :param Optional[random.Random] rng: Optional random generator used when ``seed_hue`` is ``None``; pass a seeded instance for reproducible palettes.
    :return: List of ``n`` colours as RGB float tuples, RGB 0-255 integer tuples, or hexadecimal strings, depending on ``output``.
    :rtype: Union[List[Tuple[float, float, float]], List[Tuple[int, int, int]], List[str]]
    :raises InvalidInputError: If ``rng`` is passed but is not a ``random.Random`` instance.

    :example:
    >>> palette = intermittent_palette(n=6, seed_hue=0.1, output="hex")
    >>> len(palette)
    >>> 6
    """
    source = intermittent_palette.__name__
    check_int(name=f"{source} n", value=n, min_value=1)
    check_float(name=f"{source} base_light", value=base_light, min_value=0.0, max_value=1.0)
    check_float(name=f"{source} contrast_delta", value=contrast_delta, min_value=0.0, max_value=1.0)
    if seed_hue is not None:
        check_float(name=f"{source} seed_hue", value=seed_hue, min_value=0.0, max_value=1.0)
    check_str(name=f"{source} output", value=output, options={"rgb", "rgb255", "hex"}, raise_error=True)
    if rng is not None and not isinstance(rng, random.Random):
        raise InvalidInputError(msg="rng must be an instance of random.Random.", source=source)

    hue_step = 0.618033988749895
    generator = rng or random.Random()
    if seed_hue is None:
        current_hue = generator.random()
    else:
        current_hue = seed_hue % 1.0

    palette: List[RGBFloat] = []
    for position in range(n):
        current_hue = (current_hue + hue_step) % 1.0
        saturation = 0.85 if position % 3 == 0 else 0.72
        # Even positions are darkened, odd positions are lightened.
        offset = contrast_delta if position % 2 else -contrast_delta
        value = min(max(base_light + offset, 0.25), 0.85)
        palette.append(tuple(hsv_to_rgb((current_hue, saturation, value))))

    requested = output.lower()
    if requested == "rgb":
        return palette
    if requested == "rgb255":
        return [tuple(int(round(channel * 255)) for channel in colour) for colour in palette]
    return [rgb2hex(colour) for colour in palette]
def quality_pct_to_crf(pct: int) -> int:
    """
    Translate a quality percentage into the value stored for the nearest percentage in ``percent_to_crf_lookup``.

    The keys returned by ``percent_to_crf_lookup()`` are cast to ``int`` and the key with the smallest absolute
    difference to ``pct`` is selected; on a tie the key that comes first in the lookup wins.

    :param int pct: Quality percentage between 1 and 100 (inclusive).
    :return: The lookup value for the closest percentage key (presumably an FFmpeg CRF value - verify against ``percent_to_crf_lookup``).
    :rtype: int
    """
    check_int(name=f'{quality_pct_to_crf.__name__} pct', min_value=1, max_value=100, raise_error=True, value=pct)
    crf_by_pct = {}
    for pct_key, crf_value in percent_to_crf_lookup().items():
        crf_by_pct[int(pct_key)] = crf_value
    _, nearest_crf = min(crf_by_pct.items(), key=lambda entry: abs(entry[0] - pct))
    return nearest_crf
def check_for_updates(time_out: int = 2):
    """
    Check for SimBA package updates by querying PyPI and comparing with the installed version.

    Fetches the latest SimBA version from PyPI via ``fetch_pip_data`` and compares it, by string equality, with
    the locally recorded version (``OS.SIMBA_VERSION.value``). Prints an informational message indicating whether
    the two match. Requires an active internet connection to query PyPI.

    .. note::
       Any difference between the two version strings is reported as a new version being available; the
       versions are not ordered, so a local version newer than PyPI is reported the same way.

    :param int time_out: Timeout in seconds for the PyPI API request. Default is 2 seconds.
        Must be at least 1 second.
    :return: None. Prints update information to stdout via stdout_information.
    :raises SimBAPAckageVersionError: If the latest version cannot be fetched from PyPI, or if
        the local SimBA version cannot be determined.

    :example:
    >>> check_for_updates()
    >>> # Prints: "UP-TO-DATE. You have the latest SimBA version (1.0.0)."
    >>> # or: "NEW SimBA VERSION AVAILABLE. You have SimBA version 1.0.0. The latest version is 1.1.0..."
    """
    # Label the validated argument with THIS function's name so an invalid time_out is reported against the right caller.
    check_int(name=f'{check_for_updates.__name__} time_out', value=time_out, min_value=1)
    _, latest_simba_version = fetch_pip_data(pip_url=r'https://pypi.org/pypi/simba-uw-tf-dev/json', time_out=time_out)
    env_simba_version = OS.SIMBA_VERSION.value
    if latest_simba_version is None:
        raise SimBAPAckageVersionError(msg='Could not fetch latest SimBA version.', source=check_for_updates.__name__)
    elif env_simba_version is None:
        raise SimBAPAckageVersionError(msg='Could not get local SimBA version.', source=check_for_updates.__name__)
    if latest_simba_version == env_simba_version:
        msg = f'UP-TO-DATE. \nYou have the latest SimBA version ({env_simba_version}).'
    else:
        msg = (f'NEW SimBA VERSION AVAILABLE. \nYou have SimBA version {env_simba_version}. \nThe latest version is {latest_simba_version}. '
               f'\nYou can update using "pip install simba-uw-tf-dev --upgrade"')
    stdout_information(msg=msg, source=check_for_updates.__name__)
def get_ext_codec_map() -> Dict[str, str]:
    """
    Map video file extensions to an FFmpeg encoder name, preferring encoders reported as available.

    For every extension a preferred and an alternative encoder are defined. The preferred encoder is returned
    when it is in the list from ``get_ffmpeg_encoders``; otherwise the alternative is returned when that is
    available; otherwise the preferred encoder name is returned unchanged.

    :return: Dictionary mapping file extensions (without leading dot) to codec names.
    :rtype: Dict[str, str]

    :example:
    >>> codec_map = get_ext_codec_map()
    >>> codec = codec_map.get('webm', 'libx264')
    """
    available = get_ffmpeg_encoders(raise_error=False) or []

    # General-purpose default: first widely supported encoder that is available, else the first
    # available encoder of any kind, else 'mpeg4'.
    default_codec = next((name for name in ('libx264', 'mpeg4', 'h264', 'mjpeg', 'libx265') if name in available), None)
    if default_codec is None:
        default_codec = available[0] if available else 'mpeg4'

    def pick(preferred: str, alternative: str = None) -> str:
        # NOTE(review): ``default_codec`` is only consulted when no alternative is supplied, and every
        # entry in the table below supplies one - so it currently never influences the result.
        if preferred in available:
            return preferred
        candidate = alternative or default_codec
        if candidate in available:
            return candidate
        return preferred

    # (extension, preferred encoder, alternative encoder)
    preferences = (
        ('webm', 'libvpx-vp9', 'libvpx'),
        ('avi', 'mpeg4', 'libx264'),
        ('mp4', 'libx264', 'mpeg4'),
        ('mov', 'libx264', 'mpeg4'),
        ('mkv', 'libx264', 'mpeg4'),
        ('flv', 'libx264', 'mpeg4'),
        ('m4v', 'libx264', 'mpeg4'),
        ('h264', 'libx264', 'h264'),
    )
    return {extension: pick(preferred=first_choice, alternative=second_choice)
            for extension, first_choice, second_choice in preferences}
def get_ffmpeg_codec(file_name: Union[str, os.PathLike],
                     fallback: str = 'mpeg4') -> str:
    """
    Get the recommended FFmpeg codec for a video file based on its extension.

    The extension is taken from ``get_fn_ext`` and matched case-insensitively against the keys of
    ``get_ext_codec_map`` (so ``VIDEO.MP4`` and ``video.mp4`` resolve to the same codec).

    :param Union[str, os.PathLike] file_name: Path to video file. NOTE(review): a bare extension such as ``'mp4'`` is only recognised if ``get_fn_ext`` reports it as the extension - confirm before relying on it.
    :param str fallback: Codec to return if file extension is not recognized. Default: 'mpeg4'.
    :return: Recommended FFmpeg codec name for the video file.
    :rtype: str

    :example:
    >>> codec = get_ffmpeg_codec(file_name='video.mp4')
    >>> codec = get_ffmpeg_codec(file_name='video.webm', fallback='libx264')
    >>> codec = get_ffmpeg_codec(file_name=r'C:/videos/my_video.avi')
    """
    codec_map = get_ext_codec_map()
    _, _, ext = get_fn_ext(filepath=file_name)
    # Drop the leading dot and normalise case: the map keys are lower-case, file extensions often are not.
    return codec_map.get(ext[1:].lower(), fallback)
def get_nvdec_count(gpu_name: Optional[str] = None) -> int:
    """
    Return the number of concurrent NVDEC (hardware video decode) sessions typical for the GPU model.

    .. csv-table::
       :header: EXPECTED RUNTIMES SINGLE NVDEC
       :file: ../../docs/tables/NVDECYoloInference.csv
       :widths: 10, 10, 40, 40
       :align: center
       :header-rows: 1

    .. note::
       When ``gpu_name`` is None, the first GPU name reported by ``nvidia-smi`` is used; if ``nvidia-smi``
       cannot be launched, ``1`` is returned. Matching is case-insensitive and by substring: the longest
       dictionary key contained in ``gpu_name`` wins, so shorter names do not shadow longer ones
       (e.g. ``RTX 4070 Ti Super`` before ``RTX 4070 Ti``). A key is not matched when it is immediately
       followed by another digit, so e.g. ``A100`` does not match ``RTX A1000`` and ``A30`` does not match
       ``RTX A3000``. Unknown or unmatched GPUs return ``1``.

    :param str | None gpu_name: Full GPU product string, or None to query the local GPU.
    :return: NVDEC engine count used for capacity hints (defaults to 1 if unknown).
    :rtype: int

    :example:
    >>> get_nvdec_count(gpu_name='NVIDIA GeForce RTX 4090')
    >>> 3
    """
    NVDEC = {
        "A10": 1,
        "L4": 1,
        "RTX 3060": 1,
        "RTX 3060 Ti": 1,
        "RTX 3070": 1,
        "RTX 3070 Ti": 1,
        "RTX 4060": 1,
        "RTX 4060 Ti": 1,
        "RTX 4070": 1,
        "RTX 4070 Super": 1,
        "RTX 4070 Ti": 1,
        "RTX 4070 Ti Super": 1,
        "RTX 5000 Ada": 1,
        "RTX A4000": 1,
        "RTX 5070": 1,
        "RTX 5070 Ti": 1,
        "A40": 1,
        "RTX 3080": 1,
        "RTX 3080 Ti": 1,
        "RTX 3090": 1,
        "RTX 3090 Ti": 1,
        "RTX A5000": 1,
        "RTX A6000": 1,
        "RTX 4080": 2,
        "RTX 4080 Super": 2,
        "RTX 5080": 2,
        "RTX 5090": 2,
        "L40": 3,
        "L40S": 3,
        "RTX 4090": 3,
        "RTX 5880 Ada": 3,
        "RTX 6000 Ada": 3,
        "RTX 6000 Pro": 4,
        "RTX PRO 6000": 4,
        "A30": 4,
        "A100": 5,
        "B100": 7,
        "B200": 7,
        "GB200": 7,
        "H100": 7,
        "H200": 7,
    }
    if gpu_name is None:
        try:
            result = subprocess.run(["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], capture_output=True, text=True)
        except OSError:
            # nvidia-smi is not installed / not executable: treat as an unknown GPU.
            return 1
        gpu_name = result.stdout.strip().split("\n")[0]
    # Longest keys first so a more specific model name always beats a shorter one it contains.
    for key in sorted(NVDEC, key=len, reverse=True):
        # The negative lookahead rejects keys that are merely a prefix of a longer model number.
        if re.search(re.escape(key) + r"(?!\d)", gpu_name, flags=re.IGNORECASE):
            return NVDEC[key]
    return 1
def find_best_multi_animal_assignment_frame(h5_path: Union[str, os.PathLike],
                                            expected_animals: int,
                                            strategy: Literal['longest_run_middle', 'first'] = 'longest_run_middle',
                                            min_bodyparts_per_animal: int = 1) -> Optional[int]:
    """
    Recommend a frame for the SimBA multi-animal identity-assignment UI.

    Reads a multi-animal pose H5 file into a pandas DataFrame and, for every frame, counts how many
    individuals have at least ``min_bodyparts_per_animal`` non-NaN entries. Frames where that count equals
    ``expected_animals`` are candidates. When the columns have a ``coords`` level containing ``likelihood``,
    only the likelihood columns are inspected; otherwise every column is counted.

    The returned value is a positional (0-based row) index into the file and can be passed as the
    ``initial_frame_no`` argument of multi-animal importers that expose that parameter, avoiding manual
    stepping through frames until all animals are visible.

    .. note::
       The file must have a column MultiIndex with an ``individuals`` level. Files without it (e.g.
       single-animal files) cannot be analysed; a warning is issued and ``None`` is returned.

    :param Union[str, os.PathLike] h5_path: Path to a multi-animal H5 file readable by ``pandas.read_hdf``
        and carrying an ``individuals`` column level.
    :param int expected_animals: Number of animals that must simultaneously meet the body-part threshold
        on the returned frame. Must be >= 1.
    :param Literal['longest_run_middle', 'first'] strategy: How to choose among candidate frames.
        ``'longest_run_middle'`` (default) returns the midpoint of the longest run of consecutive candidate
        frames (the earliest run wins a tie). ``'first'`` returns the first candidate frame.
    :param int min_bodyparts_per_animal: Minimum number of non-NaN body-parts each animal must have on a
        candidate frame. Default ``1``. Larger values select frames where the animals are more completely
        tracked.
    :return: Recommended frame index, or ``None`` if no frame satisfies the constraint or the file has no
        ``individuals`` column level.
    :rtype: Optional[int]
    :raises InvalidInputError: If the file cannot be read as a pandas DataFrame.

    :example:
    >>> frame = find_best_multi_animal_assignment_frame(h5_path='project/raw_data/video_el.h5', expected_animals=5)

    :example require >= 10 body-parts per animal for higher-quality assignment frames:
    >>> frame = find_best_multi_animal_assignment_frame(h5_path='project/raw_data/video_el.h5', expected_animals=5, min_bodyparts_per_animal=10)
    """
    fn_name = find_best_multi_animal_assignment_frame.__name__
    check_file_exist_and_readable(file_path=h5_path)
    check_int(name=f'{fn_name} expected_animals', value=expected_animals, min_value=1)
    check_str(name=f'{fn_name} strategy', value=strategy, options=('longest_run_middle', 'first'))
    check_int(name=f'{fn_name} min_bodyparts_per_animal', value=min_bodyparts_per_animal, min_value=1)

    try:
        pose_df = pd.read_hdf(h5_path)
    except Exception as e:
        raise InvalidInputError(
            msg=f'The H5 file {h5_path} could not be read as a pandas DataFrame: {type(e).__name__}: {e}',
            source=fn_name,
        )

    if isinstance(pose_df.columns, pd.MultiIndex):
        column_levels = list(pose_df.columns.names)
    else:
        column_levels = []
    if 'individuals' not in column_levels:
        NoDataFoundWarning(
            msg=(f'H5 file {h5_path} does not contain a multi-animal "individuals" column level; '
                 f'cannot search for a multi-animal assignment frame.'),
            source=fn_name,
        )
        return None

    # Prefer the likelihood columns as the "was this body-part detected" signal when they exist.
    has_likelihood = 'coords' in column_levels and 'likelihood' in pose_df.columns.get_level_values('coords')
    presence_df = pose_df.xs('likelihood', level='coords', axis=1) if has_likelihood else pose_df

    # dict.fromkeys de-duplicates while keeping the order in which individuals appear in the columns.
    animal_ids = list(dict.fromkeys(pose_df.columns.get_level_values('individuals')))
    tracked_counts = {}
    for animal_id in animal_ids:
        tracked_counts[animal_id] = presence_df.xs(animal_id, level='individuals', axis=1).notna().sum(axis=1)
    counts_df = pd.DataFrame(tracked_counts, index=pose_df.index)

    qualifying_animals = (counts_df >= min_bodyparts_per_animal).sum(axis=1).to_numpy()
    candidate_frames = np.flatnonzero(qualifying_animals == expected_animals).tolist()
    if len(candidate_frames) == 0:
        most_animals = int(qualifying_animals.max()) if len(qualifying_animals) else 0
        NoDataFoundWarning(
            msg=(f'H5 file {h5_path} contains {len(animal_ids)} individuals but no frame has all '
                 f'{expected_animals} of them with at least {min_bodyparts_per_animal} body-part(s) '
                 f'tracked simultaneously (max animals-meeting-threshold = {most_animals}).'),
            source=fn_name,
        )
        return None

    if strategy == 'first':
        return int(candidate_frames[0])

    # Walk the ascending candidate list, tracking the current run of consecutive frames and the longest
    # run seen so far. A later run only replaces the best one when it is strictly longer.
    best_first = best_last = run_first = run_last = candidate_frames[0]
    for frame in candidate_frames[1:]:
        if frame == run_last + 1:
            run_last = frame
            continue
        if (run_last - run_first) > (best_last - best_first):
            best_first, best_last = run_first, run_last
        run_first = run_last = frame
    if (run_last - run_first) > (best_last - best_first):
        best_first, best_last = run_first, run_last
    return int((best_first + best_last) // 2)