Source code for simba.mixins.plotting_mixin

__author__ = "Simon Nilsson; sronilsson@gmail.com"
import io
import os
import shutil
from copy import copy
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
import imutils
import matplotlib
#matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import plotly.graph_objs as go
import seaborn as sns
from matplotlib import cm
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.collections import LineCollection
from matplotlib.figure import Figure
from matplotlib.ticker import FuncFormatter, MaxNLocator
from numba import bool_, njit, uint8
from PIL import Image, ImageDraw, ImageFont

try:
    from typing import Literal
except:
    from typing_extensions import Literal

import ast

from simba.mixins.config_reader import ConfigReader
from simba.utils.checks import (
    check_all_file_names_are_represented_in_video_log,
    check_file_exist_and_readable, check_float, check_if_dir_exists,
    check_if_keys_exist_in_dict, check_if_valid_img, check_if_valid_rgb_tuple,
    check_instance, check_int, check_str, check_that_column_exist,
    check_valid_array, check_valid_boolean, check_valid_dataframe,
    check_valid_lst, check_valid_tuple)
from simba.utils.data import (create_color_palette, detect_bouts,
                              savgol_smoother)
from simba.utils.enums import Formats, Keys, Options
from simba.utils.errors import InvalidFileTypeError, InvalidInputError
from simba.utils.lookups import (get_categorical_palettes, get_color_dict,
                                 get_fonts, get_named_colors,
                                 get_named_simba_fonts)
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (find_files_of_filetypes_in_directory,
                                    get_fn_ext, get_video_meta_data, read_df,
                                    read_frm_of_video, read_video_info,
                                    seconds_to_timestamp)


[docs]class PlottingMixin(object): """ Methods for visualizations """ def __init__(self): pass
[docs] def create_gantt_img(self, bouts_df: pd.DataFrame, clf_name: str, image_index: int, fps: int, gantt_img_title: str, header_font_size: int = 24, label_font_size: int = 12): """ Helper to create a single gantt plot based on the data preceeding the input image :param pd.DataFrame bouts_df: ataframe holding information on individual bouts created by :meth:`simba.misc_tools.get_bouts_for_gantt`. :param str clf_name: Name of the classifier. :param int image_index: The count of the image. E.g., ``1000`` will create a gantt image representing frame 1-1000. :param int fps: The fps of the input video. :param str gantt_img_title: Title of the image. :return np.ndarray """ fig, ax = plt.subplots() fig.suptitle(gantt_img_title, fontsize=header_font_size) relRows = bouts_df.loc[bouts_df["End_frame"] <= image_index] for i, event in enumerate(relRows.groupby("Event")): data_event = event[1][["Start_time", "Bout_time"]] ax.broken_barh(data_event.values, (4, 4), facecolors="red") xLength = (round(image_index / fps)) + 1 if xLength < 10: xLength = 10 ax.set_xlim(0, xLength) ax.set_ylim([0, 12]) plt.ylabel(clf_name, fontsize=label_font_size) plt.yticks([]) plt.xlabel("time(s)", fontsize=label_font_size) ax.yaxis.set_ticklabels([]) ax.grid(True) buffer_ = io.BytesIO() plt.savefig(buffer_, format="png") buffer_.seek(0) image = PIL.Image.open(buffer_) ar = np.asarray(image) open_cv_image = ar[:, :, ::-1] open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_RGB2BGR) open_cv_image = cv2.resize(open_cv_image, (640, 480)) open_cv_image = np.uint8(open_cv_image) buffer_.close() plt.close(fig) return open_cv_image
[docs] def create_single_color_lst( self, pallete_name: Literal[Options.PALETTE_OPTIONS], increments: int, as_rgb_ratio: bool = False, as_hex: bool = False, ) -> List[Union[str, int, float]]: """ Helper to create a color palette of bgr colors in a list. :param str pallete_name: Palette name (e.g., 'jet') :param int increments: Numbers of colors in the color palette to create. :param bool as_rgb_ratio: If True returns the colors as RGB ratios (0-1). :param bool as_hex: If True, returns the colors as HEX. :return list .. note:: If as_rgb_ratio **AND** as_hex, then returns HEX. """ if as_hex: as_rgb_ratio = True cmap = cm.get_cmap(pallete_name, increments + 1) color_lst = [] for i in range(cmap.N): rgb = list((cmap(i)[:3])) if not as_rgb_ratio: rgb = [i * 255 for i in rgb] rgb.reverse() if as_hex: rgb = matplotlib.colors.to_hex(rgb) color_lst.append(rgb) return color_lst
[docs] def remove_a_folder(self, folder_dir: str) -> None: """Helper to remove a directory, use for cleaning up smaller multiprocessed videos following concat""" shutil.rmtree(folder_dir, ignore_errors=True)
@staticmethod @lru_cache(maxsize=64) def _load_ttf(font_path: str, size: int): check_file_exist_and_readable(file_path=font_path) check_int(name=f'{PlottingMixin._load_ttf.__name__} size', value=size, min_value=1) try: return ImageFont.truetype(font_path, size=size) except OSError as e: raise InvalidFileTypeError(msg=f'The file {font_path} is not a valid font file (.ttf/.otf) and could not be loaded: {e}', source=PlottingMixin._load_ttf.__name__)
[docs] def split_and_group_df(self, df: pd.DataFrame, splits: int, include_row_index: bool = False, include_split_order: bool = True) -> (List[pd.DataFrame], int): """ Helper to split a dataframe for multiprocessing. If include_split_order, then include the group number in split data as a column. If include_row_index, includes a column representing the row index in the array, which can be helpful for knowing the frame indexes while multiprocessing videos. Returns split data and approximations of number of observations per split. """ if include_row_index: row_indices = np.arange(len(df)).reshape(-1, 1) df = np.concatenate((df, row_indices), axis=1) data_arr = np.array_split(df, splits) if include_split_order: for df_cnt in range(len(data_arr)): data_arr[df_cnt]["group"] = df_cnt obs_per_split = len(data_arr[0]) return data_arr, obs_per_split
[docs] def make_distance_plot( self, data: np.array, line_attr: Dict[int, str], style_attr: Dict[str, Any], fps: int, save_img: bool = False, save_path: Optional[str] = None, ) -> np.ndarray: """ Helper to make a single line plot .png image with N lines. :param np.array data: Two-dimensional array where rows represent frames and columns represent intertwined x and y coordinates. :param dict line_attr: Line color attributes. :param dict style_attr: Plot attributes (size, font size, line width etc). :param int fps: Video frame rate. :param Optionan[str] save_path: Location to store output .png image. If None, then return image. .. note:: `GitHub tutorial/documentation <https://github.com/sgoldenlab/simba/blob/master/docs/Scenario2.md#visualizing-distance-plots>`__. :example: >>> fps = 10 >>> data = np.random.random((100,2)) >>> line_attr = {0: ['Blue'], 1: ['Red']} >>> save_path = '/_tests/final_frm.png' >>> style_attr = {'width': 640, 'height': 480, 'line width': 6, 'font size': 8, 'y_max': 'auto'} >>> self.make_distance_plot(fps=fps, data=data, line_attr=line_attr, style_attr=style_attr, save_path=save_path) """ colors = get_color_dict() for j in range(data.shape[1]): color = colors[line_attr[j][-1]][::-1] color = tuple(x / 255 for x in color) plt.plot( data[:, j], color=color, linewidth=style_attr["line width"], alpha=style_attr["opacity"], ) timer = SimbaTimer(start=True) max_x = len(data) if style_attr["y_max"] == "auto": max_y = np.amax(data) else: max_y = float(style_attr["y_max"]) y_ticks_locs = y_lbls = np.round(np.linspace(0, max_y, 10), 2) x_ticks_locs = x_lbls = np.linspace(0, max_x, 5) x_lbls = np.round((x_lbls / fps), 1) plt.xlabel("time (s)") plt.ylabel("distance (cm)") plt.xticks( x_ticks_locs, x_lbls, rotation="horizontal", fontsize=style_attr["font size"], ) plt.yticks(y_ticks_locs, y_lbls, fontsize=style_attr["font size"]) plt.ylim(0, max_y) plt.suptitle( "Animal distances", x=0.5, y=0.92, fontsize=style_attr["font size"] + 4 ) buffer_ = io.BytesIO() plt.savefig(buffer_, format="png") buffer_.seek(0) img = PIL.Image.open(buffer_) img = np.uint8(cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)) buffer_.close() plt.close() img = cv2.resize(img, (style_attr["width"], style_attr["height"])) timer.stop_timer() if save_img: cv2.imwrite(save_path, img) stdout_success( f"Final distance plot saved at {save_path}", elapsed_time=timer.elapsed_time_str, source=self.__class__.__name__, ) else: return img
[docs] def make_probability_plot( self, data: pd.Series, style_attr: dict, clf_name: str, fps: int, save_path: str ) -> np.ndarray: """ Make a single classifier probability plot png image. :param pd.Series data: row representing frames and field representing classification probabilities. :param dict line_attr: Line color attributes. :param dict style_attr: Image attributes (size, font size, line width etc). :param int fps: Video frame rate. :param str ot :param str save_path: Location to store output .png image. .. notes: `Tutorial <https://github.com/sgoldenlab/simba/blob/master/docs/Scenario2.md#visualizing-classification-probabilities>`__. :example: >>> data = pd.Series(np.random.random((100, 1)).flatten()) >>> style_attr = {'width': 640, 'height': 480, 'font size': 10, 'line width': 6, 'color': 'blue', 'circle size': 20} >>> clf_name='Attack' >>> fps=10 >>> save_path = '/_test/frames/output/probability_plots/Together_1_final_frame.png' >>> _ = self.make_probability_plot(data=data, style_attr=style_attr, clf_name=clf_name, fps=fps, save_path=save_path) """ timer = SimbaTimer() timer.start_timer() if style_attr["y_max"] == "auto": max_y = float(data.max().round(2)) else: max_y = float(style_attr["y_max"]) max_x = len(data) plt.plot( list(data), color=style_attr["color"], linewidth=style_attr["line width"] ) plt.plot( len(data), list(data)[-1], "o", markersize=style_attr["circle size"], color=style_attr["color"], ) plt.ylim([0, max_y]) plt.ylabel(clf_name, fontsize=style_attr["font size"]) y_ticks_locs = y_lbls = np.round(np.linspace(0, max_y, 10), 2) x_ticks_locs = x_lbls = np.linspace(0, max_x, 5) x_lbls = np.round((x_lbls / fps), 1) plt.xlabel("Time (s)", fontsize=style_attr["font size"] + 4) plt.grid() plt.xticks( x_ticks_locs, x_lbls, rotation="horizontal", fontsize=style_attr["font size"], ) plt.yticks(y_ticks_locs, y_lbls, fontsize=style_attr["font size"]) plt.suptitle( "{} {}".format(clf_name, "probability"), x=0.5, y=0.92, fontsize=style_attr["font size"] + 4, ) buffer_ = io.BytesIO() plt.savefig(buffer_, format="png") buffer_.seek(0) image = PIL.Image.open(buffer_) ar = np.asarray(image) img = cv2.cvtColor(ar, cv2.COLOR_RGB2BGR) img = np.uint8(cv2.resize(img, (style_attr["width"], style_attr["height"]))) buffer_.close() plt.close() timer.stop_timer() cv2.imwrite(save_path, img) stdout_success( msg=f"Final probability plot saved at {save_path}", elapsed_time=timer.elapsed_time_str, source=self.__class__.__name__, )
[docs] def make_gantt_plot(self, bouts_df: pd.DataFrame, clf_names: List[str], palette: List[Tuple[int, int, int]], fps: int, x_length: int, video_name: str, width: int = 640, height: int = 480, font_size: int = 8, bar_opacity: float = 0.85, font_rotation: int = 45, x_tick_lbl_rotation: int = 45, font: Optional[str] = None, title: Optional[str] = None, title_font_size: int = 24, save_path: Optional[str] = None, edge_clr: Optional[str] = 'black', hhmmss: bool = False, as_svg: bool = False) -> Union[None, np.ndarray, str]: """ Create a Gantt chart visualization of behavioral bouts over time. Generates a horizontal bar chart where each row represents a behavior class, and bars indicate when behaviors occurred. Supports SVG output for scalable figures or PNG/NumPy array for video overlays. .. image:: _static/img/gantt_mosaic.webp :alt: Gantt mosaic :width: 1000 :align: center .. video:: _static/img/make_gantt_plot.webm :width: 1000 :autoplay: :loop: :muted: :align: center :param pd.DataFrame bouts_df: DataFrame containing bout data with columns 'Event', 'Start_time', and 'Bout_time'. :param List[str] clf_names: List of behavior/classifier names to display. Must match 'Event' values in ``bouts_df``. :param List[Tuple[int, int, int]] palette: List of RGB color tuples (0-255) for each behavior. Length should match ``clf_names``. :param int fps: Frames per second of the source video. Used to convert frame counts to time. :param int x_length: Total length of the session in frames. Determines x-axis range. :param str video_name: Title displayed at the top of the chart. :param int width: Output image width in pixels (when not SVG). Default: 640. :param int height: Output image height in pixels (when not SVG). Default: 480. :param int font_size: Base font size for labels and ticks. Default: 8. :param float bar_opacity: Opacity of behavior bars (0.0-1.0). Default: 0.85. :param int font_rotation: Rotation angle in degrees for y-axis labels. Default: 45. :param int x_tick_lbl_rotation: Rotation angle in degrees for x-axis tick labels. Default: 0. :param Optional[str] font: Font to render the chart text in. Accepts a bundled SimBA font name (the .ttf filename stem returned by :func:`~simba.utils.lookups.get_named_simba_fonts`, e.g. 'Poppins Regular'), an OS-installed font name (from :func:`~simba.utils.lookups.get_fonts`), or any matplotlib family name. Bundled fonts are auto-registered with matplotlib and resolved to their internal family name. If None, uses the matplotlib default. :param Optional[str] save_path: Path to save the image. If None and ``as_svg=False``, returns NumPy array. :param Optional[str] edge_clr: Color of bar edges. Default: 'black'. :param bool hhmmss: If True, displays x-axis time as HH:MM:SS. If False, displays seconds. Default: False. :param bool as_svg: If True, returns or saves SVG format. If False, uses PNG format. Default: False. :return Union[None, np.ndarray, str]: If ``as_svg=True`` and ``save_path=None``, returns SVG string. If ``save_path`` provided, returns None (saves file). Otherwise returns NumPy array (BGR format). """ video_timer = SimbaTimer(start=True) colour_tuple_x = list(np.arange(3.5, 203.5, 5)) original_font_family = copy(plt.rcParams['font.family']) if isinstance(plt.rcParams['font.family'], list) else plt.rcParams['font.family'] if font is not None: simba_fonts = get_named_simba_fonts() available_fonts = get_fonts() matplotlib.font_manager._get_font.cache_clear() if font in simba_fonts: # A SimBA-bundled font (e.g. 'Poppins Regular'): register the .ttf with matplotlib and resolve its internal family name - the filename-stem name used elsewhere in SimBA is not the matplotlib family name. matplotlib.font_manager.fontManager.addfont(simba_fonts[font]) family_name = matplotlib.font_manager.FontProperties(fname=simba_fonts[font]).get_name() plt.rcParams['font.family'] = [family_name, 'sans-serif'] elif font in available_fonts: plt.rcParams['font.family'] = font else: plt.rcParams['font.family'] = [font, 'sans-serif'] fig, ax = plt.subplots() fig.patch.set_facecolor('white') plt.title(video_name, fontsize=font_size + 2, pad=25, fontweight='bold') ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['left'].set_color('#666666') ax.spines['bottom'].set_color('#666666') if not hhmmss: x_label: str = "SESSION TIME (S)" else: x_label: str = "SESSION TIME (HH:MM:SS)" for i, event in enumerate(bouts_df.groupby("Event")): for x in clf_names: if event[0] == x: ix = clf_names.index(x) data_event = event[1][["Start_time", "Bout_time"]] ax.broken_barh(data_event.values, (colour_tuple_x[ix], 3), facecolors=palette[ix], edgecolor=edge_clr, linewidth=0.25, alpha=bar_opacity) x_ticks_seconds = np.round(np.linspace(0, x_length / fps, 6)) x_ticks_locs = x_ticks_seconds if hhmmss: x_lbls = [seconds_to_timestamp(sec) for sec in x_ticks_seconds] else: x_lbls = [int(x) for x in x_ticks_seconds] ax.set_xticks(x_ticks_locs) ax.set_xticklabels(x_lbls, rotation=x_tick_lbl_rotation) ax.set_ylim(0, colour_tuple_x[len(clf_names)]) ax.set_yticks(np.arange(5, 5 * len(clf_names) + 1, 5)) ax.set_yticklabels(clf_names, rotation=font_rotation, ha='right', va='center') ax.tick_params(axis="both", labelsize=font_size) if title is not None: ax.set_title(title, fontsize=title_font_size) plt.xlabel(x_label, fontsize=font_size + 3) ax.grid(True, axis='both', linewidth=1.0, color='gray', alpha=0.2, linestyle='--', which='major') plt.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.15) plt.tight_layout() if as_svg and save_path is None: svg_buffer = io.BytesIO() fig.savefig(svg_buffer, format="svg", bbox_inches="tight") svg_buffer.seek(0) svg_data = svg_buffer.getvalue().decode("utf-8") svg_buffer.close() plt.close(fig) if font is not None: plt.rcParams['font.family'] = original_font_family matplotlib.font_manager._get_font.cache_clear() return svg_data if as_svg and save_path is not None: fig.savefig(save_path, format="svg", bbox_inches="tight") plt.close(fig) if font is not None: plt.rcParams['font.family'] = original_font_family matplotlib.font_manager._get_font.cache_clear() return None else: buffer_ = io.BytesIO() plt.savefig(buffer_, format="png") buffer_.seek(0) image = PIL.Image.open(buffer_) ar = np.asarray(image) open_cv_image = cv2.cvtColor(ar, cv2.COLOR_RGB2BGR) open_cv_image = cv2.resize(open_cv_image, (width, height)) frame = np.uint8(open_cv_image) buffer_.close() plt.close('all') if font is not None: plt.rcParams['font.family'] = original_font_family matplotlib.font_manager._get_font.cache_clear() if save_path is not None: cv2.imwrite(save_path, frame) video_timer.stop_timer() stdout_success(msg=f"Final gantt frame for video {video_name} saved at {save_path}",elapsed_time=video_timer.elapsed_time_str, source=self.__class__.__name__) else: return frame
@staticmethod def make_clf_heatmap_plot( frm_data: np.array, max_scale: float, palette: Literal[Options.PALETTE_OPTIONS], aspect_ratio: float, shading: Literal["gouraud", "flat"], clf_name: str, img_size: Tuple[int, int], file_name: Optional[str] = None, final_img: bool = False, ): cum_df = pd.DataFrame(frm_data).reset_index() cum_df = cum_df.melt( id_vars="index", value_vars=None, var_name=None, value_name="seconds", col_level=None, ).rename(columns={"index": "vertical_idx", "variable": "horizontal_idx"}) cum_df["color"] = ( (cum_df["seconds"].astype(float) / float(max_scale)) .round(2) .clip(upper=100) ) color_array = np.zeros( ( len(cum_df["vertical_idx"].unique()), len(cum_df["horizontal_idx"].unique()), ) ) for i in range(color_array.shape[0]): for j in range(color_array.shape[1]): value = cum_df["color"][ (cum_df["horizontal_idx"] == j) & (cum_df["vertical_idx"] == i) ].values[0] color_array[i, j] = value color_array = color_array * max_scale matplotlib.font_manager._get_font.cache_clear() plt.close("all") fig = plt.figure() im_ratio = color_array.shape[0] / color_array.shape[1] plt.pcolormesh( color_array, shading=shading, cmap=palette, rasterized=True, alpha=1, vmin=0.0, vmax=float(max_scale), ) plt.gca().invert_yaxis() plt.xticks([]) plt.yticks([]) plt.axis("off") plt.tick_params(axis="both", which="both", length=0) cb = plt.colorbar(pad=0.0, fraction=0.023 * im_ratio) cb.ax.tick_params(size=0) cb.outline.set_visible(False) cb.set_label("{} (seconds)".format(clf_name), rotation=270, labelpad=10) plt.tight_layout() # plt.gca().set_aspect(aspect_ratio) buffer_ = io.BytesIO() plt.savefig(buffer_, format="png") buffer_.seek(0) image = PIL.Image.open(buffer_) image = np.uint8(cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)) image = cv2.resize(image, img_size) buffer_.close() plt.close() if final_img: cv2.imwrite(file_name, image) stdout_success( msg=f"Final classifier heatmap image saved at at {file_name}", source="make_clf_heatmap_plot", ) else: return image @staticmethod def make_location_heatmap_plot(frm_data: np.array, max_scale: float, palette: Literal[Options.PALETTE_OPTIONS], aspect_ratio: float, shading: str, img_size: Tuple[int, int], file_name: Optional[Union[str, os.PathLike]] = None, line_clr: Optional[str] = None, min_seconds: Optional[float] = None, bg_img: Optional[np.ndarray] = None, legend_lbl: str = "location (seconds)", heatmap_opacity: Optional[float] = 0.99, color_legend: bool = True, leg_width: Optional[int] = None, hh_mm_ss: bool = False) -> Union[np.ndarray, None]: check_valid_boolean(value=hh_mm_ss, source=f'{PlottingMixin.make_location_heatmap_plot.__name__} hh_mm_ss', raise_error=True) cum_df = pd.DataFrame(frm_data).reset_index() cum_df = cum_df.melt(id_vars="index", value_vars=None, var_name=None, value_name="seconds", col_level=None).rename(columns={"index": "vertical_idx", "variable": "horizontal_idx"}) below_min = cum_df["seconds"] < min_seconds if min_seconds is not None else pd.Series(False, index=cum_df.index) if min_seconds is not None: cum_df.loc[below_min, "seconds"] = 0 cum_df["color"] = ((cum_df["seconds"].astype(float) / float(max_scale)).round(2).clip(upper=100)) color_array = np.zeros((len(cum_df["vertical_idx"].unique()), len(cum_df["horizontal_idx"].unique()))) below_min_arr = np.zeros(color_array.shape, dtype=bool) for i in range(color_array.shape[0]): for j in range(color_array.shape[1]): mask = (cum_df["horizontal_idx"] == j) & (cum_df["vertical_idx"] == i) color_array[i, j] = cum_df.loc[mask, "color"].values[0] if min_seconds is not None: below_min_arr[i, j] = below_min.loc[mask].values[0] color_array = color_array * max_scale if min_seconds is not None and bg_img is None: color_array = color_array.astype(np.float64) color_array[below_min_arr] = np.nan vmin_plot = min_seconds if min_seconds is not None else 0.0 matplotlib.font_manager._get_font.cache_clear() plt.close("all") fig = plt.figure(facecolor="white") ax = plt.gca() ax.set_facecolor("white") if line_clr is not None: linewidths = PlottingMixin().get_optimal_circle_size(frame_size=img_size, circle_frame_ratio=175) else: linewidths = None plt.pcolormesh(color_array, shading=shading, cmap=palette, rasterized=True, alpha=1, vmin=vmin_plot, vmax=max_scale, linewidths=linewidths, edgecolors=line_clr) plt.gca().set_aspect("equal") plt.gca().invert_yaxis() plt.xticks([]) plt.yticks([]) plt.axis("off") plt.tick_params(axis="both", which="both", length=0) plt.tight_layout() canvas = FigureCanvas(fig) canvas.draw() mat = np.array(canvas.renderer._renderer) image = cv2.cvtColor(mat, cv2.COLOR_RGB2BGR) image = np.uint8(image) plt.close("all") if bg_img is not None: check_if_valid_img(data=bg_img, source=f'{PlottingMixin.make_location_heatmap_plot.__name__} bg_img', raise_error=False) bg_h, bg_w = bg_img.shape[0], bg_img.shape[1] if bg_img.ndim == 2: bg_img = cv2.cvtColor(bg_img, cv2.COLOR_GRAY2BGR) image = cv2.resize(image, (bg_w, bg_h)) heatmap_w = max(0.0, min(1.0, float(heatmap_opacity) if heatmap_opacity is not None else 0.8)) threshold = max(1e-9, float(max_scale) * 0.01) alpha_mask = (color_array > threshold).astype(np.float64) alpha_resized = cv2.resize(alpha_mask, (bg_w, bg_h), interpolation=cv2.INTER_LINEAR) blend = np.clip(alpha_resized * heatmap_w, 0, 1) blend_3ch = np.repeat(blend[:, :, np.newaxis], 3, axis=2) image = (image.astype(np.float64) * blend_3ch + bg_img.astype(np.float64) * (1.0 - blend_3ch)).round().astype(np.uint8) if color_legend: fig_cb = Figure(figsize=(0.6, 4)) ax_cb = fig_cb.add_subplot(111) sm = cm.ScalarMappable(cmap=palette, norm=matplotlib.colors.Normalize(vmin=vmin_plot, vmax=max_scale)) sm.set_array([]) cb = fig_cb.colorbar(sm, cax=ax_cb) cb.set_label(legend_lbl, rotation=270, labelpad=14, fontsize=10) cb.ax.yaxis.set_major_locator(MaxNLocator(integer=True)) def _legend_formatter(x, pos): if hh_mm_ss: from simba.utils.read_write import seconds_to_timestamp if x >= max_scale - 1e-9: return f">{seconds_to_timestamp(max_scale)}" if min_seconds is not None and x <= vmin_plot + 1e-9: return f"<{seconds_to_timestamp(min_seconds)}" return seconds_to_timestamp(float(x)) if x >= max_scale - 1e-9: return f">{int(round(max_scale))}" if min_seconds is not None and x <= vmin_plot + 1e-9: return f"<{int(round(min_seconds))}" return str(int(round(x))) cb.ax.yaxis.set_major_formatter(FuncFormatter(_legend_formatter)) cb.ax.tick_params(size=0) cb.outline.set_visible(False) buf = io.BytesIO() fig_cb.savefig(buf, format="png", dpi=100, bbox_inches="tight", pad_inches=0.15) buf.seek(0) mat_cb = np.array(Image.open(buf).convert("RGB")) plt.close(fig_cb) cb_bgr = cv2.cvtColor(mat_cb, cv2.COLOR_RGB2BGR) leg_h = image.shape[0] leg_w = leg_width if leg_width is not None else max(1, int(cb_bgr.shape[1] * leg_h / cb_bgr.shape[0])) cb_resized = cv2.resize(cb_bgr, (leg_w, leg_h), interpolation=cv2.INTER_LINEAR) image = np.hstack((image, cb_resized)) if not color_legend: h, w = image.shape[:2] target_w, target_h = img_size[0], img_size[1] scale = min(target_w / w, target_h / h) new_w, new_h = int(round(w * scale)), int(round(h * scale)) image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) if new_w < target_w or new_h < target_h: padded = np.zeros((target_h, target_w, image.shape[2]), dtype=image.dtype) padded[:] = 0 y0 = (target_h - new_h) // 2 x0 = (target_w - new_w) // 2 padded[y0 : y0 + new_h, x0 : x0 + new_w] = image image = padded if file_name is not None: cv2.imwrite(file_name, image) stdout_success(msg=f"Location heatmap image saved at at {file_name}", source=PlottingMixin.make_location_heatmap_plot.__class__.__name__) else: return image
[docs] def get_bouts_for_gantt( self, data_df: pd.DataFrame, clf_name: str, fps: int ) -> pd.DataFrame: """ Helper to detect all behavior bouts for a specific classifier. :param pd.DataFrame data_df: Pandas Dataframe with classifier prediction data. :param str clf_name: Name of the classifier :param int fps: The fps of the input video. :return pd.DataFrame: Holding the start time, end time, end frame, bout time etc of each classified bout. """ boutsList, nameList, startTimeList, endTimeList, endFrameList = ( [], [], [], [], [], ) groupDf = pd.DataFrame() v = (data_df[clf_name] != data_df[clf_name].shift()).cumsum() u = data_df.groupby(v)[clf_name].agg(["all", "count"]) m = u["all"] & u["count"].ge(1) groupDf["groups"] = data_df.groupby(v).apply( lambda x: (x.index[0], x.index[-1]) )[m] for indexes, rows in groupDf.iterrows(): currBout = list(rows["groups"]) boutTime = ((currBout[-1] - currBout[0]) + 1) / fps startTime = (currBout[0] + 1) / fps endTime = (currBout[1]) / fps endFrame = currBout[1] endTimeList.append(endTime) startTimeList.append(startTime) boutsList.append(boutTime) nameList.append(clf_name) endFrameList.append(endFrame) return pd.DataFrame( list(zip(nameList, startTimeList, endTimeList, endFrameList, boutsList)), columns=["Event", "Start_time", "End Time", "End_frame", "Bout_time"], )
[docs] def resize_gantt(self, gantt_img: np.array, img_height: int) -> np.ndarray: """ Helper to resize image while retaining aspect ratio. """ return imutils.resize(gantt_img, height=img_height)
@staticmethod def bbox_mp( frm_range: list, polygon_data: dict, animal_bp_dict: dict, data_df: Optional[pd.DataFrame], intersection_data_df: Optional[pd.DataFrame], roi_attributes: dict, video_path: str, key_points: bool, greyscale: bool, ): cap, current_frame = cv2.VideoCapture(video_path), frm_range[0] cap.set(1, frm_range[0]) img_lst = [] while current_frame < frm_range[-1]: ret, frame = cap.read() if ret: if key_points: frm_data = data_df.iloc[current_frame] if greyscale: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR) for animal_cnt, (animal, animal_data) in enumerate( animal_bp_dict.items() ): if key_points: for bp_cnt, (x_col, y_col) in enumerate( zip(animal_data["X_bps"], animal_data["Y_bps"]) ): cv2.circle( frame, (frm_data[x_col], frm_data[y_col]), 0, roi_attributes[animal]["bbox_clr"], roi_attributes[animal]["keypoint_size"], ) animal_polygon = np.array( list( polygon_data[animal][ current_frame ].convex_hull.exterior.coords ) ).astype(int) if intersection_data_df is not None: intersect = intersection_data_df.loc[ current_frame, intersection_data_df.columns.str.startswith(animal), ].sum() if intersect > 0: cv2.polylines( frame, [animal_polygon], 1, roi_attributes[animal]["highlight_clr"], roi_attributes[animal]["highlight_clr_thickness"], ) cv2.polylines( frame, [animal_polygon], 1, roi_attributes[animal]["bbox_clr"], roi_attributes[animal]["bbox_thickness"], ) img_lst.append(frame) current_frame += 1 else: print( "SIMBA WARNING: SimBA tried to grab frame number {} from video {}, but could not find it. The video has {} frames.".format( str(current_frame), video_path, str(cap.get(cv2.CAP_PROP_FRAME_COUNT)), ) ) return img_lst @staticmethod def path_plot_mp( data: np.array, video_setting: bool, frame_setting: bool, video_save_dir: str, video_name: str, frame_folder_dir: str, style_attr: dict, print_animal_names: bool, animal_attr: dict, fps: int, video_info: pd.DataFrame, clf_attr: dict, input_style_attr: dict, video_path: Optional[Union[str, os.PathLike]] = None, ): group = int(data[0][0]) color_dict = get_color_dict() if video_setting: fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value) video_save_path = os.path.join(video_save_dir, "{}.mp4".format(str(group))) video_writer = cv2.VideoWriter( video_save_path, fourcc, fps, (style_attr["width"], style_attr["height"]), ) if input_style_attr is not None: if (type(input_style_attr["bg color"]) == dict) and ( input_style_attr["bg color"]["type"] ) == "moving": check_file_exist_and_readable(file_path=video_path) video_cap = cv2.VideoCapture(video_path) for i in range(data.shape[0]): if input_style_attr is not None: if (type(input_style_attr["bg color"]) == dict) and ( input_style_attr["bg color"]["type"] ) == "moving": style_attr["bg color"] = input_style_attr["bg color"] frame_id = int(data[i][1]) frame_data = data[i][2:].astype(int) frame_data = np.split(frame_data, len(list(animal_attr.keys())), axis=0) img = np.zeros( ( int(video_info["Resolution_height"].values[0]), int(video_info["Resolution_width"].values[0]), 3, ) ) if (type(style_attr["bg color"]) == dict) and ( style_attr["bg color"]["type"] ) == "moving": style_attr["bg color"] = read_frm_of_video( video_path=video_cap, opacity=style_attr["bg color"]["opacity"], frame_index=frame_id, ) img[:] = style_attr["bg color"] for animal_cnt, animal_data in enumerate(frame_data): animal_clr = style_attr["animal clrs"][animal_cnt] cv2.line( img, tuple(animal_data), animal_clr, int(style_attr["line width"]) ) cv2.circle( img, tuple(animal_data[-1]), 0, animal_clr, style_attr["circle size"], ) if print_animal_names: cv2.putText( img, style_attr["animal names"][animal_cnt], tuple(animal_data[-1]), cv2.FONT_HERSHEY_COMPLEX, style_attr["font size"], animal_clr, style_attr["font thickness"], ) if clf_attr: for clf_cnt, clf_name in enumerate(clf_attr["data"].columns): clf_size = int(clf_attr["attr"][clf_cnt][-1].split(": ")[-1]) clf_clr = color_dict[clf_attr["attr"][clf_cnt][1]] clf_sliced = clf_attr["data"][clf_name].loc[0:frame_id] clf_sliced_idx = list(clf_sliced[clf_sliced == 1].index) locations = clf_attr["positions"][clf_sliced_idx, :] for i in range(locations.shape[0]): cv2.circle( img, (locations[i][0], locations[i][1]), 0, clf_clr, clf_size, ) img = cv2.resize(img, (style_attr["width"], style_attr["height"])) if video_setting: video_writer.write(np.uint8(img)) if frame_setting: frm_name = os.path.join(frame_folder_dir, str(frame_id) + ".png") cv2.imwrite(frm_name, np.uint8(img)) print( "Path frame created: {}, Video: {}, Processing core: {}".format( str(frame_id + 1), video_name, str(group + 1) ) ) if video_setting: video_writer.release() return group def violin_plot( self, data: pd.DataFrame, x: str, y: str, save_path: Union[str, os.PathLike], font_rotation: Optional[int] = 45, font_size: Optional[int] = 10, img_size: Optional[tuple] = (13.7, 8.27), cut: Optional[int] = 0, scale: Optional[Literal["area", "count", "width"]] = "area", ): named_colors = get_named_colors() palette = {} for cnt, violin in enumerate(sorted(list(data[x].unique()))): palette[violin] = named_colors[cnt] plt.figure() order = data.groupby(by=[x])[y].median().sort_values().iloc[::-1].index figure_FSCTT = sns.violinplot( x=x, y=y, data=data, cut=cut, scale=scale, order=order, palette=palette ) figure_FSCTT.set_xticklabels( figure_FSCTT.get_xticklabels(), rotation=font_rotation, size=font_size ) figure_FSCTT.figure.set_size_inches(img_size) figure_FSCTT.figure.savefig(save_path, bbox_inches="tight") stdout_success( msg=f"Violin plot saved at {save_path}", source=self.__class__.__name__ )
[docs] @staticmethod @njit([(uint8[:, :, :], bool_)]) def rotate_img(img: np.ndarray, right: bool) -> np.ndarray: """ Flip a color image 90 degrees to the left or right .. image:: _static/img/rotate_img.png :alt: Rotate img :width: 600 :align: center :param np.ndarray img: Input image as numpy array in uint8 format. :param bool right: If True, flips to the right. If False, flips to the left. :return: The rotated image as a numpy array of uint8 format. :example: >>> img = cv2.imread('/Users/simon/Desktop/test.png') >>> rotated_img = PlottingMixin.rotate_img(img=img, right=False) """ if right: img = np.transpose(img[:, ::-1, :], axes=(1, 0, 2)) else: img = np.transpose(img[::-1, :, :], axes=(1, 0, 2)) return np.ascontiguousarray(img).astype(np.uint8)
[docs] @staticmethod def continuous_scatter( data: Union[np.ndarray, pd.DataFrame], columns: Optional[List[str]] = ("X", "Y", "Cluster"), palette: Optional[str] = "magma", show_box: Optional[bool] = False, size: Optional[int] = 10, title: Optional[str] = None, bg_clr: Optional[str] = None, save_path: Optional[Union[str, os.PathLike]] = None, ): """Create a 2D scatterplot with a continuous legend""" check_instance( source=f"{PlottingMixin.continuous_scatter.__name__} data", instance=data, accepted_types=(np.ndarray, pd.DataFrame), ) if isinstance(data, pd.DataFrame): check_that_column_exist( df=data, column_name=columns, file_name=PlottingMixin.continuous_scatter.__name__, ) data = data[list(columns)] else: check_valid_array( data=data, source=PlottingMixin.continuous_scatter.__name__, accepted_ndims=(2,), max_axis_1=len(columns), min_axis_1=len(columns), ) data = pd.DataFrame(data, columns=list(columns)) fig, ax = plt.subplots() if bg_clr is not None: if bg_clr not in get_named_colors(): raise InvalidInputError( msg=f"bg_clr {bg_clr} is not a valid named color. Options: {get_named_colors()}", source=PlottingMixin.continuous_scatter.__name__, ) fig.set_facecolor(bg_clr) if not show_box: plt.axis("off") plt.xlabel(columns[0]) plt.ylabel(columns[1]) plot = ax.scatter( data[columns[0]], data[columns[1]], c=data[columns[2]], s=size, cmap=palette ) cbar = fig.colorbar(plot) cbar.set_label(columns[2], loc="center") if title is not None: plt.title( title, ha="center", fontsize=15, bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0}, ) if save_path is not None: check_if_dir_exists(in_dir=os.path.dirname(save_path)) fig.savefig(save_path) plt.close("all") else: return plot
[docs] @staticmethod def categorical_scatter( data: Union[np.ndarray, pd.DataFrame], columns: Optional[List[str]] = ("X", "Y", "Cluster"), palette: Optional[str] = "Set1", show_box: Optional[bool] = False, size: Optional[int] = 10, title: Optional[str] = None, save_path: Optional[Union[str, os.PathLike]] = None, ): """ Create a 2D scatterplot with a categorical legend. .. image:: _static/img/categorical_scatter.png :alt: Categorical scatter :width: 400 :align: center :param Union[np.ndarray, pd.DataFrame] data: Input data, either a NumPy array or a pandas DataFrame. :param Optional[List[str]] columns: A list of column names for the x-axis, y-axis, and the categorical variable respectively. Default is ["X", "Y", "Cluster"]. :param Optional[str] palette: The color palette to be used for the categorical variable. Default is "Set1". :param Optional[bool] show_box: Whether to display the plot axis. Default is False. :param Optional[int] size: Size of markers in the scatterplot. Default is 10. :param Optional[str] title: Title for the plot. Default is None. :param Optional[Union[str, os.PathLike]] save_path: The path where the plot will be saved. Default is None which returns the image. :return matplotlib.axes._subplots.AxesSubplot or None: The scatterplot if 'save_path' is not provided, otherwise None. """ cmaps = get_categorical_palettes() if palette not in cmaps: raise InvalidInputError( msg=f"{palette} is not a valid palette. Accepted options: {cmaps}.", source=PlottingMixin.categorical_scatter.__name__, ) check_instance( source=f"{PlottingMixin.categorical_scatter.__name__} data", instance=data, accepted_types=(np.ndarray, pd.DataFrame), ) if isinstance(data, pd.DataFrame): check_that_column_exist( df=data, column_name=columns, file_name=PlottingMixin.categorical_scatter.__name__, ) data = data[list(columns)] else: check_valid_array( data=data, source=PlottingMixin.categorical_scatter.__name__, accepted_ndims=(2,), max_axis_1=len(columns), min_axis_1=len(columns), ) data = pd.DataFrame(data, columns=list(columns)) if not show_box: plt.axis("off") # pct_x = np.percentile(data[columns[0]].values, 75) # pct_y = np.percentile(data[columns[1]].values, 75) # plt.xlim(data[columns[0]].min() - pct_x, data[columns[0]].max() + pct_x) # plt.ylim(data[columns[1]].min() - pct_y, data[columns[1]].max() + pct_y) plot = sns.scatterplot( data=data, x=columns[0], y=columns[1], hue=columns[2], palette=sns.color_palette(palette, len(data[columns[2]].unique())), s=size, ) if title is not None: plt.title( title, ha="center", fontsize=15, bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0}, ) if save_path is not None: check_if_dir_exists(in_dir=os.path.dirname(save_path)) plt.savefig(save_path) plt.close("all") else: return plot
[docs] @staticmethod def joint_plot( data: Union[np.ndarray, pd.DataFrame], columns: Optional[List[str]] = ("X", "Y", "Cluster"), palette: Optional[str] = "Set1", kind: Optional[str] = "scatter", size: Optional[int] = 10, title: Optional[str] = None, save_path: Optional[Union[str, os.PathLike]] = None, ): """ Generate a joint plot. Useful when visualizing embedded behavior data latent spaces with dense and overlapping scatters. .. image:: _static/img/joint_plot.png :alt: Joint plot :width: 700 :align: center :param Union[np.ndarray, pd.DataFrame] data: Input data, either a NumPy array or a pandas DataFrame. :param Optional[List[str]] columns: Names of columns if input is dataframe, default is ["X", "Y", "Cluster"]. :param Optional[str] palette: Palette for the plot, default is "Set1". :param Optional[str] kind: Type of plot ("scatter", "kde", "hist", or "reg"), default is "scatter". :param Optional[int] size: Size of markers for scatter plot, default is 10. :param Optional[str] title: Title of the plot, default is None. :param Optional[Union[str, os.PathLike]] save_path: Path to save the plot image, default is None. :return sns.JointGrid or None: JointGrid object if save_path is None, else None. :example: >>> data, lbls = make_blobs(n_samples=100000, n_features=2, centers=10, random_state=42) >>> data = np.hstack((data, lbls.reshape(-1, 1))) >>> PlottingMixin.joint_plot(data=data, columns=['X', 'Y', 'Cluster'], title='The plot') """ cmaps = get_categorical_palettes() if palette not in cmaps: raise InvalidInputError( msg=f"{palette} is not a valid palette. Accepted options: {cmaps}", source=PlottingMixin.joint_plot.__name__, ) check_instance( source=f"{PlottingMixin.joint_plot.__name__} data", instance=data, accepted_types=(np.ndarray, pd.DataFrame), ) check_str( name=f"{PlottingMixin.joint_plot.__name__} kind", value=kind, options=("kde", "reg", "hist", "scatter"), ) if isinstance(data, pd.DataFrame): check_that_column_exist( df=data, column_name=columns, file_name=PlottingMixin.joint_plot.__name__, ) data = data[list(columns)] else: check_valid_array( data=data, source=PlottingMixin.joint_plot.__name__, accepted_ndims=(2,), max_axis_1=len(columns), min_axis_1=len(columns), ) data = pd.DataFrame(data, columns=list(columns)) pct_x = np.percentile(data[columns[0]].values, 75) pct_y = np.percentile(data[columns[1]].values, 75) plot = sns.jointplot( data=data, x=columns[0], y=columns[1], hue=columns[2], xlim=(data[columns[0]].min() - pct_x, data[columns[0]].max() + pct_x), ylim=(data[columns[1]].min() - pct_y, data[columns[1]].max() + pct_y), palette=sns.color_palette(palette, len(data[columns[2]].unique())), kind=kind, marginal_ticks=False, s=size, ) if title is not None: plot.fig.suptitle( title, va="baseline", ha="center", fontsize=15, bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0}, ) if save_path is not None: check_if_dir_exists(in_dir=os.path.dirname(save_path)) plot.savefig(save_path) plt.close("all") else: return plot
[docs] @staticmethod def line_plot(df: pd.DataFrame, x: str, y: Union[str, List[str]], error: Optional[Union[str, List[str]]] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, title: Optional[str] = None, fig_size: Tuple[int] = (10, 6), error_opacity: float = 0.2, palette: str = 'Set1', grid: bool = True, bg_clr: str = 'white', line_width: float = 1.5, line_opacity: float = 1.0, save_path: Optional[Union[str, os.PathLike]] = None, dpi: Optional[int] = None, tight_layout: bool = True, show_legend: bool = True, legend_loc: Optional[str] = 'best', font_size: Optional[int] = None, title_font_size: Optional[int] = None, x_lim: Optional[Tuple[float, float]] = None, y_lim: Optional[Tuple[float, float]] = None, marker: Optional[str] = None, markersize: Optional[float] = None, linestyle: Optional[Union[str, List[str]]] = None, save_kwargs: Optional[Dict[str, Any]] = None, svg: bool = False, x_tick_lbl_rotation: Optional[int] = None, x_tick_interval: Optional[int] = None): """ Line plot from DataFrame with optional error bands. Optional styling arguments (useful for publication or clearer plots): - dpi: Resolution for saved figure (e.g. 150, 300). Ignored when svg=True. - tight_layout: If True, call fig.tight_layout() before save (default True). - show_legend: Whether to show the legend (default True). - legend_loc: Legend position, e.g. 'best', 'upper right', 'lower left'. - font_size: Font size for axis labels and tick labels. - title_font_size: Font size for the title (default 15 if title set). - x_lim, y_lim: (min, max) tuples to fix axis limits. - marker: Matplotlib marker for data points (e.g. 'o', 's', '^'). - markersize: Size of markers when marker is set. - linestyle: Single style or list of styles per series ('-', '--', '-.', ':'). - save_kwargs: Dict passed to plt.savefig (e.g. bbox_inches='tight', pad_inches=0.1). - svg: If True and save_path is set, save as SVG (path extension becomes .svg). """ check_instance(source=f"{PlottingMixin.line_plot.__name__} df", instance=df, accepted_types=(pd.DataFrame)) check_str(name=f"{PlottingMixin.line_plot.__name__} x", value=x, options=tuple(df.columns)) check_instance(source=f"{PlottingMixin.line_plot.__name__} y", instance=y, accepted_types=(str, list)) check_float(name=f"{PlottingMixin.line_plot.__name__} line_width", value=line_width, allow_zero=False, allow_negative=False, raise_error=True) check_float(name=f"{PlottingMixin.line_plot.__name__} line_opacity", value=line_opacity, min_value=0.0, max_value=1.0) check_valid_boolean(value=grid, source=f"{PlottingMixin.line_plot.__name__} grid", raise_error=True) check_str(name=f"{PlottingMixin.line_plot.__name__} bg_clr", value=bg_clr.lower(), options=tuple(get_named_colors())) bg_clr = bg_clr.lower() if dpi is not None: check_int(name=f"{PlottingMixin.line_plot.__name__} dpi", value=dpi, min_value=1, raise_error=True) if font_size is not None: check_int(name=f"{PlottingMixin.line_plot.__name__} font_size", value=font_size, min_value=1, raise_error=True) if title_font_size is not None: check_int(name=f"{PlottingMixin.line_plot.__name__} title_font_size", value=title_font_size, min_value=1, raise_error=True) if legend_loc is not None: check_str(name=f"{PlottingMixin.line_plot.__name__} legend_loc", value=legend_loc, options=('best', 'upper right', 'upper left', 'lower left', 'lower right', 'right', 'center left', 'center right', 'lower center', 'upper center', 'center'), raise_error=True) if x_lim is not None: check_valid_tuple(x=x_lim, source=f"{PlottingMixin.line_plot.__name__} x_lim", accepted_lengths=(2,), valid_dtypes=(int, float)) if y_lim is not None: check_valid_tuple(x=y_lim, source=f"{PlottingMixin.line_plot.__name__} y_lim", accepted_lengths=(2,), valid_dtypes=(int, float)) if markersize is not None: check_float(name=f"{PlottingMixin.line_plot.__name__} markersize", value=markersize, allow_zero=False, allow_negative=False, raise_error=True) if linestyle is not None and isinstance(linestyle, list): check_instance(source=f"{PlottingMixin.line_plot.__name__} linestyle", instance=linestyle, accepted_types=(list,)) if len(linestyle) != len(y): raise ValueError(f"{PlottingMixin.line_plot.__name__} linestyle list length must match number of y series ({len(y)}).") check_valid_boolean(value=svg, source=f"{PlottingMixin.line_plot.__name__} svg", raise_error=True) if x_tick_lbl_rotation is not None: check_int(name=f"{PlottingMixin.line_plot.__name__} x_tick_lbl_rotation", value=x_tick_lbl_rotation, min_value=0, max_value=360) if x_tick_interval is not None: check_int(name=f"{PlottingMixin.line_plot.__name__} x_tick_interval", value=x_tick_interval, min_value=1) if grid: sns.set_style(style="whitegrid", rc={"grid.linestyle": "--"}) else: sns.set_style(style="white") if isinstance(y, str): check_str(name=f"{PlottingMixin.line_plot.__name__} y", value=y, options=tuple(df.columns)) check_valid_lst(data=list(df[y]), source=f"{PlottingMixin.line_plot.__name__} y", valid_dtypes=Formats.NUMERIC_DTYPES.value) y = [y] if error is not None: check_instance(source=f"{PlottingMixin.line_plot.__name__} error", instance=error, accepted_types=(str,)) check_str(name=f"{PlottingMixin.line_plot.__name__} error", value=error, options=tuple(df.columns)) check_valid_lst(data=list(df[error]), source=f"{PlottingMixin.line_plot.__name__} error", valid_dtypes=Formats.NUMERIC_DTYPES.value) error = [error] else: for i in y: check_str(name=f"{PlottingMixin.line_plot.__name__} y", value=i, options=tuple(df.columns)) check_valid_lst(data=list(df[i]), source=f"{PlottingMixin.line_plot.__name__} error", valid_dtypes=Formats.NUMERIC_DTYPES.value) if error is not None: check_instance(source=f"{PlottingMixin.line_plot.__name__} error", instance=error, accepted_types=(list,)) for i in error: check_str(name=f"{PlottingMixin.line_plot.__name__} error", value=i, options=tuple(df.columns)) check_valid_lst(data=list(df[i]), source=f"{PlottingMixin.line_plot.__name__} error", valid_dtypes=Formats.NUMERIC_DTYPES.value) fig, ax = plt.subplots(figsize=fig_size) fig.set_facecolor(bg_clr) ax.set_facecolor(bg_clr) colors = sns.color_palette(palette, n_colors=len(y)) for i in range(len(y)): ls = linestyle[i] if isinstance(linestyle, list) else linestyle plot_kw = dict(data=df, x=x, y=y[i], label=y[i], color=colors[i], linewidth=line_width, alpha=line_opacity) if marker is not None: plot_kw['marker'] = marker if ls is not None: plot_kw['linestyle'] = ls sns.lineplot(**plot_kw) if error is not None: ax.fill_between(df[x], df[y[i]] - df[error[i]], df[y[i]] + df[error[i]], alpha=error_opacity, color=colors[i]) if marker is not None and markersize is not None: for line in ax.get_lines(): line.set_markersize(markersize) if x_label is not None: check_str(name=f"{PlottingMixin.line_plot.__name__} x_label", value=x_label) ax.set_xlabel(x_label, fontsize=font_size) if y_label is not None: check_str(name=f"{PlottingMixin.line_plot.__name__} y_label", value=y_label) ax.set_ylabel(y_label, fontsize=font_size) if title is not None: check_str(name=f"{PlottingMixin.line_plot.__name__} title", value=title) ax.set_title(title, ha="center", fontsize=title_font_size if title_font_size is not None else 15) if font_size is not None: ax.tick_params(axis='both', labelsize=font_size) if x_lim is not None: ax.set_xlim(x_lim[0], x_lim[1]) if y_lim is not None: ax.set_ylim(y_lim[0], y_lim[1]) if not show_legend: ax.legend_.remove() if ax.legend_ is not None else None elif legend_loc is not None and ax.legend_ is not None: ax.legend(loc=legend_loc, fontsize=font_size) ax.grid(grid) if x_tick_interval is not None: x_vals = df[x].values ax.set_xticks(range(0, len(x_vals), x_tick_interval)) ax.set_xticklabels(x_vals[::x_tick_interval]) if x_tick_lbl_rotation is not None: plt.setp(ax.get_xticklabels(), rotation=x_tick_lbl_rotation, ha='right') if tight_layout: fig.tight_layout() if save_path is not None: check_str(name=f"{PlottingMixin.line_plot.__name__} save_path", value=save_path) check_if_dir_exists(in_dir=os.path.dirname(save_path)) out_path = save_path save_opts = {} if svg: save_dir, save_name, _ = get_fn_ext(filepath=save_path) out_path = os.path.join(save_dir, f'{save_name}.svg') save_opts['format'] = 'svg' elif dpi is not None: save_opts['dpi'] = dpi if isinstance(save_kwargs, dict): save_opts.update(save_kwargs) plt.savefig(out_path, **save_opts) plt.close("all") else: return fig
[docs] @staticmethod def make_line_plot(data: List[np.ndarray], colors: List[str], show_box: Optional[bool] = True, width: Optional[int] = 640, height: Optional[int] = 480, line_width: Optional[int] = 6, font_size: Optional[int] = 8, bg_clr: Optional[str] = None, x_lbl_divisor: Optional[float] = None, title: Optional[str] = None, y_lbl: Optional[str] = None, x_lbl: Optional[str] = None, y_tick_lbls_as_int: bool = False, x_tick_lbls_as_int: bool = False, y_tick_cnt: int = 10, x_tick_cnt: int = 5, y_max: Optional[Union[int, float]] = -1, line_opacity: Optional[float] = 1.0, as_svg: bool = False, save_path: Optional[Union[str, os.PathLike]] = None, show_thresholds: bool = False): """ Create a multi-line plot from NumPy arrays. Generates a line plot with one or more data series, each with customizable colors and styling. Supports SVG output for scalable figures or PNG/NumPy array for video overlays. .. image:: _static/img/line_plot_mosaic.webp :alt: Line plot mosaic :width: 1000 :align: center :param List[np.ndarray] data: List of 1D or 2D NumPy arrays to plot. Each array becomes one line. :param List[str] colors: List of color names (must match length of ``data``). Uses SimBA color dictionary. :param Optional[bool] show_box: If False, hides plot axes and borders. Default: True. :param Optional[int] width: Output image width in pixels (when not SVG). Default: 640. :param Optional[int] height: Output image height in pixels (when not SVG). Default: 480. :param Optional[int] line_width: Width of plotted lines. Default: 6. :param Optional[int] font_size: Font size for labels and ticks. Default: 8. :param Optional[str] bg_clr: Background color name. If None, uses matplotlib default. :param Optional[float] x_lbl_divisor: Divide x-axis tick labels by this value (e.g., convert frames to seconds). Default: None. :param Optional[str] title: Plot title displayed at top. :param Optional[str] y_lbl: Y-axis label. :param Optional[str] x_lbl: X-axis label. :param bool y_tick_lbls_as_int: If True, formats y-axis ticks as integers. Default: False. :param bool x_tick_lbls_as_int: If True, formats x-axis ticks as integers. Default: False. :param int y_tick_cnt: Number of y-axis tick marks. Default: 10. :param int x_tick_cnt: Number of x-axis tick marks. Default: 5. :param Optional[Union[int, float]] y_max: Maximum y-axis value. If -1, auto-scales to data maximum. Default: -1. :param Optional[float] line_opacity: Opacity of lines (0.0-1.0). Default: 1.0. :param bool as_svg: If True, returns or saves SVG format. If False, uses PNG format. Default: False. :param Optional[Union[str, os.PathLike]] save_path: Path to save the image. If None and ``as_svg=False``, returns NumPy array. :param bool show_thresholds: If True, displays horizontal threshold lines at 25%, 50%, and 75%. Default: False. :return Union[None, np.ndarray, str]: If ``as_svg=True`` and ``save_path=None``, returns SVG string. If ``save_path`` provided, returns None (saves file). Otherwise returns NumPy array (BGR format). """ check_valid_lst(data=data, source=PlottingMixin.make_line_plot.__name__, valid_dtypes=(np.ndarray, list)) check_valid_lst(data=colors, source=PlottingMixin.make_line_plot.__name__, valid_dtypes=(str,), exact_len=len(data)) clr_dict = get_color_dict() matplotlib.font_manager._get_font.cache_clear() plt.close("all") fig, ax = plt.subplots() ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['left'].set_linewidth(1.2) ax.spines['bottom'].set_linewidth(1.2) if show_thresholds: ax.axhline(y=0.75, color='#ec4899', linestyle=(0, (3, 1, 1, 1)), linewidth=1.5, alpha=0.9, label='Threshold: 75%') ax.axhline(y=0.5, color='#3b82f6', linestyle=(0, (3, 1, 1, 1)), linewidth=1.5, alpha=0.9, label='Threshold: 50%') ax.axhline(y=0.25, color='#8b5cf6', linestyle=(0, (3, 1, 1, 1)), linewidth=1.5, alpha=0.9, label='Threshold: 25%') if bg_clr is not None: fig.set_facecolor(bg_clr) if not show_box: plt.axis("off") for i in range(len(data)): line_clr = clr_dict[colors[i]][::-1] line_clr = tuple(x / 255 for x in line_clr) flat_data = data[i].flatten() plt.plot( flat_data, color=line_clr, linewidth=line_width, alpha=line_opacity) max_x = max([len(x) for x in data]) if y_max == -1: y_max = max([np.max(x) for x in data]) if not y_tick_lbls_as_int: y_ticks_locs = y_lbls = np.round(np.linspace(0, y_max, y_tick_cnt), 2) else: y_ticks_locs = y_lbls = np.round(np.linspace(0, y_max, y_tick_cnt)).astype(np.int32) if not x_tick_lbls_as_int: x_ticks_locs = x_lbls = np.linspace(0, max_x, x_tick_cnt) else: x_ticks_locs = x_lbls = np.linspace(0, max_x, x_tick_cnt).astype(np.int32) if x_lbl_divisor is not None: x_lbls = np.round((x_lbls / x_lbl_divisor), 1) if y_lbl is not None: plt.ylabel(y_lbl) if x_lbl is not None: plt.xlabel(x_lbl) plt.xticks(x_ticks_locs, x_lbls, rotation="horizontal", fontsize=font_size) plt.yticks(y_ticks_locs, y_lbls, fontsize=font_size) plt.ylim(0, y_max) if title is not None: plt.suptitle(title, x=0.5, y=0.92, fontsize=font_size + 4) if as_svg and save_path is None: svg_buffer = io.BytesIO() plt.savefig(svg_buffer, format="svg", bbox_inches="tight") svg_buffer.seek(0) svg_data = svg_buffer.getvalue().decode("utf-8") svg_buffer.close() plt.close() return svg_data elif as_svg and save_path is not None: plt.savefig(save_path, format="svg", bbox_inches="tight") plt.close() stdout_success(msg=f"Line plot saved at {save_path}") return None else: buffer_ = io.BytesIO() plt.savefig(buffer_, format="png") buffer_.seek(0) img = PIL.Image.open(buffer_) img = np.uint8(cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)) buffer_.close() plt.close() img = cv2.resize(img, (width, height)) if save_path is not None: cv2.imwrite(save_path, img) stdout_success(msg=f"Line plot saved at {save_path}") else: return img
[docs] @staticmethod def make_line_plot_plotly( data: List[np.ndarray], colors: List[str], show_box: Optional[bool] = True, show_grid: Optional[bool] = False, width: Optional[int] = 640, height: Optional[int] = 480, line_width: Optional[int] = 6, font_size: Optional[int] = 8, bg_clr: Optional[str] = "white", x_lbl_divisor: Optional[float] = None, title: Optional[str] = None, y_lbl: Optional[str] = None, x_lbl: Optional[str] = None, y_max: Optional[int] = -1, line_opacity: Optional[int] = 0.5, save_path: Optional[Union[str, os.PathLike]] = None, ): """ Create a line plot using Plotly. .. note:: Plotly can be more reliable than matplotlib on some systems when accessed through multprocessing calls. If **not** called though multiprocessing, consider using ``simba.mixins.plotting_mixin.PlottingMixin.make_line_plot()`` Uses ``kaleido`` for transform image to numpy array or save to disk. .. image:: _static/img/make_line_plot_plotly.png :alt: Make line plot plotly :width: 500 :align: center :param List[np.ndarray] data: List of 1D numpy arrays representing lines. :param List[str] colors: List of named colors of size len(data). :param bool show_box: Whether to show the plot box (axes, title, etc.). :param bool show_grid: Whether to show gridlines on the plot. :param int width: Width of the plot in pixels. :param int height: Height of the plot in pixels. :param int line_width: Width of the lines in the plot. :param int font_size: Font size for axis labels and tick labels. :param str bg_clr: Background color of the plot. :param float x_lbl_divisor: Divisor for adjusting the tick spacing on the x-axis. :param str title: Title of the plot. :param str y_lbl: Label for the y-axis. :param str x_lbl: Label for the x-axis. :param int y_max: Maximum value for the y-axis. :param float line_opacity: Opacity of the lines in the plot. :param Union[str, os.PathLike] save_path: Path to save the plot image. If None, returns a numpy array of the plot. :return: If save_path is None, returns a numpy array representing the plot image. :example: >>> p = np.random.randint(0, 50, (100,)) >>> y = np.random.randint(0, 50, (200,)) >>> img = PlottingMixin.make_line_plot_plotly(data=[p, y], show_box=False, font_size=20, bg_clr='white', show_grid=False, x_lbl_divisor=30, colors=['Red', 'Green'], save_path='/Users/simon/Desktop/envs/simba/troubleshooting/beepboop174/project_folder/frames/output/line_plot/Trial 3_final_img.png') """ def tick_formatter(x): if x_lbl_divisor is not None: return str(round(x / x_lbl_divisor, 2)) else: return str(x) fig = go.Figure() clr_dict = get_color_dict() if y_max == -1: y_max = max([np.max(i) for i in data]) for i in range(len(data)): line_clr = clr_dict[colors[i]] line_clr = ( f"rgba({line_clr[0]}, {line_clr[1]}, {line_clr[2]}, {line_opacity})" ) fig.add_trace( go.Scatter( y=data[i].flatten(), mode="lines", line=dict(color=line_clr, width=line_width), ) ) if not show_box: fig.update_layout( width=width, height=height, title=title, xaxis_visible=False, yaxis_visible=False, showlegend=False, ) else: if fig["layout"]["xaxis"]["tickvals"] is None: tickvals = [i for i in range(data[0].shape[0])] else: tickvals = fig["layout"]["xaxis"]["tickvals"] if x_lbl_divisor is not None: ticktext = [tick_formatter(x) for x in tickvals] else: ticktext = tickvals fig.update_layout( width=width, height=height, title=title, xaxis=dict( title=x_lbl, tickvals=tickvals, ticktext=ticktext, tickmode="auto", tick0=0, dtick=10, tickfont=dict(size=font_size), showgrid=show_grid, ), yaxis=dict( title=y_lbl, tickfont=dict(size=font_size), range=[0, y_max], showgrid=show_grid, ), showlegend=False, ) if bg_clr is not None: fig.update_layout(plot_bgcolor=bg_clr) img_bytes = fig.to_image(format="png") img = PIL.Image.open(io.BytesIO(img_bytes)) img = np.array(img) if save_path is not None: cv2.imwrite(save_path, img) stdout_success(msg=f"Line plot saved at {save_path}") else: return img
[docs] @staticmethod def make_path_plot( data: List[np.ndarray], colors: List[Union[Tuple[int, int, int], str]], width: Optional[int] = 640, height: Optional[int] = 480, max_lines: Optional[int] = None, bg_clr: Optional[Union[Tuple[int, int, int], np.ndarray]] = (255, 255, 255), circle_size: Optional[Union[int, None]] = 3, font_size: Optional[float] = 2.0, font_thickness: Optional[int] = 2, line_width: Optional[int] = 2, animal_names: Optional[List[str]] = None, clf_attr: Optional[Dict[str, Any]] = None, save_path: Optional[Union[str, os.PathLike]] = None, ) -> Union[None, np.ndarray]: """ Creates a path plot visualization from the given data. .. image:: _static/img/make_path_plot.png :alt: Make path plot :width: 500 :align: center :param List[np.ndarray] data: List of numpy arrays containing path data. :param List[Tuple[int, int, int]] colors: List of RGB tuples, strings (names of palettes), or lists of list of tuples, representing colors for each path. :param width: Width of the output image (default is 640 pixels). :param height: Height of the output image (default is 480 pixels). :param max_lines: Maximum number of lines to plot from each path data. :param bg_clr: Background color of the plot (default is white). :param circle_size: Size of the circle marker at the end of each path (default is 3). :param font_size: Font size for displaying animal names (default is 2.0). :param font_thickness: Thickness of the font for displaying animal names (default is 2). :param line_width: Width of the lines representing paths (default is 2). :param animal_names: List of names for the animals corresponding to each path. :param clf_attr: Dictionary containing attributes for classification markers. :param save_path: Path to save the generated plot image. :return: If save_path is None, returns the generated image as a numpy array, otherwise, returns None. :example: >>> x = np.random.randint(0, 500, (100, 2)) >>> y = np.random.randint(0, 500, (100, 2)) >>> position_data = np.random.randint(0, 500, (100, 2)) >>> clf_data_1 = np.random.randint(0, 2, (100,)) >>> clf_data_2 = np.random.randint(0, 2, (100,)) >>> clf_data = {'Attack': {'color': (155, 1, 10), 'size': 30, 'positions': position_data, 'clfs': clf_data_1}, 'Sniffing': {'color': (155, 90, 10), 'size': 30, 'positions': position_data, 'clfs': clf_data_2}} >>> PlottingMixin.make_path_plot(data=[x, y], colors=[(0, 255, 0), (255, 0, 0)], clf_attr=clf_data) """ check_valid_lst(data=data, source=PlottingMixin.make_path_plot.__name__, valid_dtypes=(np.ndarray,), min_len=1) for i in data: check_valid_array(data=i, source=PlottingMixin.make_path_plot.__name__, accepted_ndims=(2,), accepted_axis_1_shape=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) check_instance(source="bg_clr", instance=bg_clr, accepted_types=(np.ndarray, tuple)) if isinstance(bg_clr, tuple): check_if_valid_rgb_tuple(data=bg_clr) check_int(name=f"{PlottingMixin.make_path_plot.__name__} height", value=height, min_value=1) check_int(name=f"{PlottingMixin.make_path_plot.__name__} height", value=width, min_value=1) check_float(name=f"{PlottingMixin.make_path_plot.__name__} font_size", value=font_size) check_int(name=f"{PlottingMixin.make_path_plot.__name__} font_thickness", value=font_thickness) check_int(name=f"{PlottingMixin.make_path_plot.__name__} line_width", value=line_width) timer = SimbaTimer(start=True) plot_clrs = [] check_valid_lst(data=colors, source=PlottingMixin.make_path_plot.__name__, valid_dtypes=(tuple, str, list,), exact_len=len(data)) for cnt, t in enumerate(colors): if check_if_valid_rgb_tuple(data=t, raise_error=False): plot_clrs.append([t] * data[cnt].shape[0]) elif t in Options.PALETTE_OPTIONS.value: plot_clrs.append(create_color_palette(pallete_name=t, increments=data[cnt].shape[0], as_int=True)) elif isinstance(t, list): check_valid_lst(data=t, source=PlottingMixin.make_path_plot.__name__, valid_dtypes=(tuple,), exact_len=data[cnt].shape[0]) for lst_clr in t: check_if_valid_rgb_tuple(data=lst_clr, raise_error=True, source=PlottingMixin.make_path_plot.__class__.__name__) plot_clrs.append(t) else: raise InvalidInputError(msg=f'The color {t} for is not a valid color palette or valid rgb color tuple.', source=PlottingMixin.make_path_plot.__class__.__name__) if (isinstance(bg_clr, np.ndarray)) and bg_clr.ndim > 1: img = np.zeros((bg_clr.shape[0], bg_clr.shape[1], 3)) else: img = np.zeros((height, width, 3)) img[:] = bg_clr for line_cnt in range(len(data)): last_clr = plot_clrs[line_cnt][-1] line_data = data[line_cnt] if max_lines is not None: check_int(name=f"{PlottingMixin.make_path_plot.__name__} max_lines", value=max_lines, min_value=1) line_data = line_data[-max_lines:] for i in range(1, line_data.shape[0]): clr = plot_clrs[line_cnt][i] cv2.line(img, tuple(int(x) for x in line_data[i]), tuple(int(x) for x in line_data[i - 1]), clr, line_width) if circle_size is not None: cv2.circle(img, tuple(int(x) for x in line_data[-1]),0, last_clr, circle_size) if animal_names is not None: cv2.putText( img, animal_names[line_cnt], tuple(int(x) for x in line_data[-1]), cv2.FONT_HERSHEY_COMPLEX, font_size, last_clr, font_thickness, ) if clf_attr is not None: check_instance( source=PlottingMixin.make_path_plot.__name__, instance=clf_attr, accepted_types=(dict,), ) for k, v in clf_attr.items(): check_if_keys_exist_in_dict( data=v, key=["color", "size", "positions", "clfs"], name="clf_attr" ) for clf_name, clf_data in clf_attr.items(): clf_positions = clf_data["positions"][ np.argwhere(clf_data["clfs"] == 1).flatten() ] for i in clf_positions: cv2.circle(img, tuple(i), 0, clf_data["color"], clf_data["size"]) img = cv2.resize(img, (width, height)).astype(np.uint8) if save_path is not None: check_if_dir_exists(in_dir=os.path.dirname(save_path)) timer.stop_timer() cv2.imwrite(save_path, img) stdout_success( msg=f"Path plot saved at {save_path}", elapsed_time=timer.elapsed_time_str, source=PlottingMixin.make_path_plot.__name__, ) else: return img
@staticmethod def rectangles_onto_image(img: np.ndarray, rectangles: pd.DataFrame, show_center: Optional[bool] = False, show_tags: Optional[bool] = False, circle_size: Optional[int] = 2, line_type: int = -1, print_metrics: bool = False, omitted_rois: Optional[List[str]] = None, omitted_centers: Optional[List[str]] = None, txt_size: Optional[Union[float, int]] = None) -> np.ndarray: check_valid_array(data=img, source=PlottingMixin.rectangles_onto_image.__name__) check_valid_dataframe(df=rectangles, source=PlottingMixin.rectangles_onto_image.__name__, required_fields=["topLeftX", "topLeftY", "Bottom_right_X", "Bottom_right_Y", "Color BGR", "Thickness", "Center_X", "Center_Y", "Tags", "Ear_tag_size", 'width', 'height', 'Name']) check_int(name='line_type', value=line_type, accepted_vals=[4, 8, 16, -1], raise_error=True) if circle_size is not None: check_int(name=PlottingMixin.rectangles_onto_image.__name__, value=circle_size, min_value=1) for _, row in rectangles.iterrows(): if omitted_rois is not None and row['Name'] in omitted_rois: continue rectangle_line_type = [4 if row['Thickness'] == 1 else line_type][0] tag_size = [row['Ear_tag_size'] if circle_size is None else circle_size][0] if isinstance(row["Color BGR"], str): row["Color BGR"] = ast.literal_eval(row["Color BGR"]) img = cv2.rectangle(img, (int(row["topLeftX"]), int(row["topLeftY"])), (int(row["Bottom_right_X"]), int(row["Bottom_right_Y"])), row["Color BGR"], int(row["Thickness"]), lineType=rectangle_line_type) if show_center: if omitted_centers is not None and row['Name'] in omitted_centers: continue img = cv2.circle(img, (int(row["Center_X"]), int(row["Center_Y"])), tag_size, row["Color BGR"], -1) if show_tags: for tag_name, tag_data in row["Tags"].items(): img = cv2.circle(img, tuple(tag_data), tag_size, row["Color BGR"], -1) if (print_metrics) and ('area_cm' in rectangles.columns): if txt_size is None: font_size, _, _ = PlottingMixin().get_optimal_font_scales(text='10 DIGIT TEXT', accepted_px_height=row['height'], accepted_px_width=row['width']) else: font_size = copy(txt_size) img = PlottingMixin().put_text(img=img, text=str(row['area_cm']), pos=row['Tags']['Left tag'], font_size=font_size, text_bg_alpha=0.6) return img @staticmethod def circles_onto_image(img: np.ndarray, circles: pd.DataFrame, show_center: Optional[bool] = False, show_tags: Optional[bool] = False, circle_size: Optional[int] = 2, line_type: Optional[int] = -1, print_metrics: bool = False, omitted_rois: Optional[List[str]] = None, omitted_centers: Optional[List[str]] = None, txt_size: Optional[Union[float, int]] = None) -> np.ndarray: check_valid_array(data=img, source=PlottingMixin.circles_onto_image.__name__) check_valid_dataframe(df=circles, source=PlottingMixin.circles_onto_image.__name__, required_fields=["centerX", "centerY", "radius", "Color BGR", "Thickness", "Tags", "Ear_tag_size", "Name"]) if circle_size is not None: check_int(name=PlottingMixin.circles_onto_image.__name__, value=circle_size, min_value=1) check_int(name='line_type', value=line_type, accepted_vals=[4, 8, 16, -1], raise_error=True) for _, row in circles.iterrows(): if omitted_rois is not None and row['Name'] in omitted_rois: continue circle_line_type = [4 if row['Thickness'] == 1 else line_type][0] tag_size = [row['Ear_tag_size'] if circle_size is None else circle_size][0] if isinstance(row["Color BGR"], str): row["Color BGR"] = ast.literal_eval(row["Color BGR"]) img = cv2.circle(img, (int(row["centerX"]), int(row["centerY"])), row["radius"], row["Color BGR"], int(row["Thickness"]), lineType=circle_line_type) if show_center: if omitted_centers is not None and row['Name'] in omitted_centers: continue try: img = cv2.circle(img, (int(row["Center_X"]), int(row["Center_Y"])), tag_size, row["Color BGR"], -1, lineType=circle_line_type) except KeyError: img = cv2.circle(img, (int(row["centerX"]), int(row["centerY"])), tag_size, row["Color BGR"], -1, lineType=circle_line_type) if show_tags: for tag_data in row["Tags"].values(): img = cv2.circle(img, tuple(tag_data), tag_size, row["Color BGR"], -1, lineType=circle_line_type) if (print_metrics) and ('radius_cm' in circles.columns): if txt_size is None: font_size, _, _ = PlottingMixin().get_optimal_font_scales(text='10 DIGIT TEXT', accepted_px_height=row['radius'] * 2, accepted_px_width=row['radius'] * 2) else: font_size = copy(txt_size) img = PlottingMixin().put_text(img=img, text=str(row['radius_cm']), pos=row['Tags']['Border tag'], font_size=font_size, text_bg_alpha=0.6) return img @staticmethod def polygons_onto_image(img: np.ndarray, polygons: pd.DataFrame, show_center: Optional[bool] = False, show_tags: Optional[bool] = False, circle_size: Optional[int] = 2, line_type: Optional[int] = -1, print_metrics: bool = False, omitted_rois: Optional[List[str]] = None, omitted_centers: Optional[List[str]] = None, txt_size: Optional[Union[float, int]] = None) -> np.ndarray: check_valid_array(data=img, source=f"{PlottingMixin.polygons_onto_image.__name__} img") check_valid_dataframe(df=polygons, source=f"{PlottingMixin.polygons_onto_image.__name__} polygons", required_fields=["vertices", "Color BGR", "Thickness", "Tags", "Name"]) check_int(name='line_type', value=line_type, accepted_vals=[4, 8, 16, -1], raise_error=True) if circle_size is not None: check_int(name=PlottingMixin.polygons_onto_image.__name__, value=circle_size, min_value=1) for _, row in polygons.iterrows(): if omitted_rois is not None and row['Name'] in omitted_rois: continue polygon_line_type = [4 if row['Thickness'] == 1 else line_type][0] tag_size = [row['Ear_tag_size'] if circle_size is None else circle_size][0] if isinstance(row["Color BGR"], str): row["Color BGR"] = ast.literal_eval(row["Color BGR"]) img = cv2.polylines( img, [row["vertices"].astype(np.int32)], True, row["Color BGR"], thickness=int(row["Thickness"]), lineType=polygon_line_type) if show_center: if omitted_centers is not None and row['Name'] in omitted_centers: continue img = cv2.circle(img, (int(row["Center_X"]), int(row["Center_Y"])), tag_size, row["Color BGR"], polygon_line_type) if show_tags: for tag_name, tag_data in row["Tags"].items(): img = cv2.circle(img, tuple(tag_data), tag_size, row["Color BGR"], polygon_line_type) if (print_metrics) and ('area_cm' in polygons.columns) and ('max_vertice_distance' in polygons.columns) and ('center' in polygons.columns): if txt_size is None: font_size, _, _ = PlottingMixin().get_optimal_font_scales(text='10 DIGIT TEXT', accepted_px_height=int(row['max_vertice_distance'] / 2), accepted_px_width=int(row['max_vertice_distance'] / 2)) else: font_size = copy(txt_size) img = PlottingMixin().put_text(img=img, text=str(row['area_cm']), pos=(int(row['center'][0]), int(row['center'][1])), font_size=font_size, text_bg_alpha=0.6) return img @staticmethod def roi_dict_onto_img(img: np.ndarray, roi_dict: Dict[str, pd.DataFrame], circle_size: Optional[int] = None, show_center: Optional[bool] = False, omitted_centers: Optional[List[str]] = None, show_tags: Optional[bool] = False) -> np.ndarray: check_valid_array(data=img, source=f"{PlottingMixin.roi_dict_onto_img.__name__} img") check_if_keys_exist_in_dict(data=roi_dict, key=[Keys.ROI_POLYGONS.value, Keys.ROI_CIRCLES.value, Keys.ROI_RECTANGLES.value], name=PlottingMixin.roi_dict_onto_img.__name__) img = PlottingMixin.rectangles_onto_image(img=img, rectangles=roi_dict[Keys.ROI_RECTANGLES.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags, omitted_centers=omitted_centers) img = PlottingMixin.circles_onto_image(img=img, circles=roi_dict[Keys.ROI_CIRCLES.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags, omitted_centers=omitted_centers) img = PlottingMixin.polygons_onto_image(img=img, polygons=roi_dict[Keys.ROI_POLYGONS.value], circle_size=circle_size, show_center=show_center, show_tags=show_tags, omitted_centers=omitted_centers) return img
[docs] @staticmethod def insert_directing_line( directing_df: pd.DataFrame, img: np.ndarray, shape_name: str, animal_name: str, frame_id: int, color: Optional[Tuple[int]] = (0, 0, 255), thickness: Optional[int] = 2, style: Optional[str] = "lines", ) -> np.ndarray: """ Helper to insert lines between the actor 'eye' and the ROI centers. :param directing_df: Dataframe containing eye and ROI locations. Stored as ``results`` in instance of ``simba.roi_tools.ROI_directing_analyzer.DirectingROIAnalyzer``. :param np.ndarray img: The image to draw the line on. :param str shape_name: The name of the shape to draw the line to. :param str animal_name: The name of the animal :param int frame_id: The frame number in the video :param Optional[Tuple[int]] color: The color of the line :param Optional[int] thickness: The thickness of the line. :param Optional[str] style: The style of the line. "lines" or "funnel". :return np.ndarray: The input image with the line. """ check_valid_array(data=img, source=PlottingMixin.insert_directing_line.__name__) check_valid_dataframe( df=directing_df, source=PlottingMixin.rectangles_onto_image.__name__, required_fields=[ "ROI", "Animal", "Frame", "ROI_edge_1_x", "ROI_edge_1_y", "ROI_edge_2_x", "ROI_edge_2_y", ], ) r = directing_df.loc[ (directing_df["ROI"] == shape_name) & (directing_df["Animal"] == animal_name) & (directing_df["Frame"] == frame_id) ].reset_index(drop=True) if style == "funnel": convex_hull_arr = ( np.array( [ [r["ROI_edge_1_x"], r["ROI_edge_1_y"]], [r["ROI_edge_2_x"], r["ROI_edge_2_y"]], [r["Eye_x"], r["Eye_y"]], ] ) .reshape(-1, 2) .astype(int) ) img = cv2.fillPoly(img, [convex_hull_arr], color) else: img = cv2.line( img, (int(r["Eye_x"]), int(r["Eye_y"])), (int(r["ROI_x"]), int(r["ROI_y"])), color, thickness, ) return img
[docs] @staticmethod def draw_lines_on_img(img: np.ndarray, start_positions: np.ndarray, end_positions: np.ndarray, color: Tuple[int, int, int], opacity: Optional[float] = None, highlight_endpoint: Optional[bool] = False, thickness: Optional[int] = 2, circle_size: Optional[int] = 2) -> np.ndarray: """ Helper to draw a set of lines onto an image. :param np.ndarray img: The image to draw the lines on. :param np.ndarray start_positions: 2D numpy array representing the start positions of the lines in x, y format. :param np.ndarray end_positions: 2D numpy array representing the end positions of the lines in x, y format. :param Tuple[int, int, int] color: The color of the lines in BGR format. :param Optional[bool] highlight_endpoint: If True, highlights the ends of the lines with circles. :param Optional[int] thickness: The thickness of the lines. :param Optional[int] circle_size: If ``highlight_endpoint`` is True, the size of the highlighted points. :return np.ndarray: The image with the lines overlayed. """ check_valid_array(data=start_positions, source=f"{PlottingMixin.draw_lines_on_img.__name__} start_positions") check_valid_array(data=start_positions, source=f"{PlottingMixin.draw_lines_on_img.__name__} start_positions", accepted_ndims=(2,), accepted_dtypes=(Formats.INTEGER_DTYPES.value), min_axis_0=1 ) check_valid_array(data=end_positions, source=f"{PlottingMixin.draw_lines_on_img.__name__} end_positions", accepted_shapes=[(start_positions.shape[0], 2),], accepted_dtypes=Formats.INTEGER_DTYPES.value) check_if_valid_img(data=img, source=f"{PlottingMixin.draw_lines_on_img.__name__} img", raise_error=True) if opacity is not None: check_float(name=f"{PlottingMixin.draw_lines_on_img.__name__} opacity", value=opacity, min_value=0.0, max_value=1.0, raise_error=True, allow_negative=False) check_if_valid_rgb_tuple(data=color) if opacity is not None and opacity < 1.0: line_layer = img.copy() for i in range(start_positions.shape[0]): cv2.line(img, pt1=(start_positions[i][0], start_positions[i][1]), pt2=(end_positions[i][0], end_positions[i][1]), color=color, thickness=thickness) if highlight_endpoint: cv2.circle(img,(end_positions[i][0], end_positions[i][1]), circle_size, color, -1) img = cv2.addWeighted(line_layer, 1 - opacity, img, opacity, 0) else: for i in range(start_positions.shape[0]): cv2.line(img, pt1=(start_positions[i][0], start_positions[i][1]), pt2=(end_positions[i][0], end_positions[i][1]), color=color, thickness=thickness) if highlight_endpoint: cv2.circle(img,(end_positions[i][0], end_positions[i][1]), circle_size, color, -1) return img
[docs] def get_optimal_font_scales(self, text: Union[str, List[str]], accepted_px_width: int, accepted_px_height: int, text_thickness: Optional[int] = 2, font: Optional[int] = cv2.FONT_HERSHEY_TRIPLEX) -> Tuple[float, int, int]: """ Get the optimal font size, column-wise and row-wise text distance of printed text for printing on images. :param str text: The text to be printed. Either a string or a list of strings. If a list, then the longest string will be used to evaluate spacings/font. :param int accepted_px_width: The widest allowed string in pixels. E.g., 1/4th of the image width. :param int accepted_px_height: The highest allowed string in pixels. E.g., 1/10th of the image size. :param Optional[int] text_thickness: The thickness of the font. Default: 2. :param Optional[int] font: The font integer representation 0-7. See ``simba.utils.enums.Options.CV2_FONTS.values :return Tuple[int, int, int]: The font size, the shift on x between successive columns, the shift in y between successive rows. :example: >>> img = cv2.imread('/Users/simon/Desktop/Screenshot 2024-07-08 at 4.46.03 PM.png') >>> accepted_px_width = int(img.shape[1] / 4) >>> accepted_px_height = int(img.shape[0] / 10) >>>>text = 'HELLO MY FELLOW' >>> PlottingMixin().get_optimal_font_scales(text=text, accepted_px_width=accepted_px_width, accepted_px_height=accepted_px_height, text_thickness=2) """ check_int(name='accepted_px_width', value=accepted_px_width, min_value=1) check_int(name='accepted_px_height', value=accepted_px_height, min_value=1) check_int(name='text_thickness', value=text_thickness, min_value=1) check_int(name='font', value=font, min_value=0, max_value=7) if isinstance(text, list): check_valid_lst(data=text, valid_dtypes=(str,), min_len=1) text = max(text, key=len) else: check_str(name='text', value=text) for scale in reversed(range(0, 100, 1)): text_size = cv2.getTextSize(text, fontFace=font, fontScale=scale / 10, thickness=text_thickness) new_width, new_height = text_size[0][0], text_size[0][1] if (new_width <= accepted_px_width) and (new_height <= accepted_px_height): font_scale = scale / 10 x_shift = new_width + text_size[1] y_shift = new_height + text_size[1] return (font_scale, x_shift, y_shift) return (1, 1, 1)
[docs] @staticmethod def get_optimal_font_size_ttf(text: Union[str, List[str]], font_path: str, accepted_px_width: int, accepted_px_height: int, max_px: int = 400) -> Tuple[int, int, int]: """ Get the optimal font PIXEL size, column-wise and row-wise text distance for printing text on images using a TrueType/OpenType (.ttf/.otf) font rendered with PIL. This is the TTF counterpart of :meth:`get_optimal_font_scales`. Note that the returned value is a font PIXEL size (to be passed as ``font_size`` to :meth:`put_text` together with ``font_path``), NOT the cv2 scale factor returned by :meth:`get_optimal_font_scales`. The two are not interchangeable. :param Union[str, List[str]] text: The text to be printed. Either a string or a list of strings. If a list, then the longest string will be used to evaluate spacings/font. :param str font_path: Path to the .ttf/.otf font file to measure with. :param int accepted_px_width: The widest allowed string in pixels. E.g., 1/4th of the image width. :param int accepted_px_height: The highest allowed string in pixels. E.g., 1/10th of the image size. :param int max_px: The largest font pixel size to consider. The search counts down from this value. Default: 400. :return Tuple[int, int, int]: The font pixel size, the shift on x between successive columns, the shift in y between successive rows. :example: >>> size_px, x_shift, y_shift = PlottingMixin.get_optimal_font_size_ttf(text='HELLO MY FELLOW', font_path='Poppins-Regular.ttf', accepted_px_width=480, accepted_px_height=108) """ check_file_exist_and_readable(file_path=font_path) check_int(name='accepted_px_width', value=accepted_px_width, min_value=1) check_int(name='accepted_px_height', value=accepted_px_height, min_value=1) check_int(name='max_px', value=max_px, min_value=1) if isinstance(text, list): check_valid_lst(data=text, valid_dtypes=(str,), min_len=1) text = max(text, key=len) else: check_str(name='text', value=text) for size in reversed(range(1, max_px + 1)): font = PlottingMixin._load_ttf(font_path, size) l, t, r, b = font.getbbox(text) new_width, new_height = r - l, b - t if (new_width <= accepted_px_width) and (new_height <= accepted_px_height): x_shift = new_width + (b - new_height) y_shift = new_height + (b - new_height) return (size, x_shift, y_shift) return (1, 1, 1)
[docs] @staticmethod def get_optimal_font_spacing_ttf(font_path: str, size_px: int, text: Union[str, List[str]], gap: int = 1) -> int: """ Return the optimal vertical pixel pitch (distance between consecutive baselines) for stacking lines of TTF text rendered by :meth:`put_text`, so the per-line background boxes sit snugly without overlapping. The pitch equals the tallest rendered box height of the actual ``text`` being drawn (i.e. the tight ascender-to-descender bbox of the supplied strings at ``size_px``, plus the same padding that :meth:`put_text` adds to its background box), plus a small ``gap``. Measuring the real text - rather than a worst-case probe - keeps caps-only stacks tight (no wasted descender room) while still accommodating descenders when present. This is the row-spacing counterpart to the font PIXEL size returned by :meth:`get_optimal_font_size_ttf`; do not use the cv2 spacing from :meth:`get_optimal_font_scales` for TTF text, as the metrics differ. :param str font_path: Path to the .ttf/.otf font file. :param int size_px: Font pixel height (as passed to :meth:`put_text` when ``font_path`` is set). :param Union[str, List[str]] text: The string(s) that will be stacked. The tallest rendered box across them sets the pitch. :param int gap: Extra pixels inserted between consecutive line boxes. Default 1. :return int: The recommended vertical pitch in pixels. :example: >>> PlottingMixin.get_optimal_font_spacing_ttf(font_path='Poppins Regular.ttf', size_px=13, text=['TIMERS:', 'grooming']) """ check_int(name='size_px', value=size_px, min_value=1) check_int(name='gap', value=gap, min_value=0) if isinstance(text, str): text = [text] check_valid_lst(data=text, valid_dtypes=(str,), min_len=1) pil_font = PlottingMixin._load_ttf(font_path, size_px) pad = max(1, int(size_px * 0.1)) # mirror the padding used in put_text's TTF branch measure = ImageDraw.Draw(Image.new("RGB", (1, 1))) max_box_h = max((measure.textbbox((0, 0), s, font=pil_font, anchor="ls")[3] - measure.textbbox((0, 0), s, font=pil_font, anchor="ls")[1]) for s in text) return max_box_h + (2 * pad) + gap
[docs] def get_optimal_circle_size(self, frame_size: Tuple[int, int], circle_frame_ratio: Optional[int] = 100) -> int: """ Calculate the optimal circle size for fitting within a rectangular frame based on a given ratio. This method computes the diameter of a circle that fits within the smallest dimension of a rectangular frame, scaled by a specified ratio. The resulting circle size ensures that it fits within the bounds of the frame while maintaining the specified size ratio. :param Tuple[int, int] frame_size: A tuple representing the dimensions of the rectangular frame (width, height). :param Optional[int] circle_frame_ratio: An integer representing the ratio between the frame's smallest dimension and the circle's diameter. A lower ratio results in a larger circle, and a higher ratio results in a smaller circle. :return int: The computed diameter of the circle that fits within the smallest dimension of the frame, scaled by the `circle_frame_ratio`. """ check_int(name='accepted_circle_size', value=circle_frame_ratio, min_value=1) check_valid_tuple(x=frame_size, source='frame_size', accepted_lengths=(2,), valid_dtypes=(int,)) for i in frame_size: check_int(name='frame_size', value=i, min_value=1) return int(min(frame_size[0], frame_size[1]) / circle_frame_ratio)
[docs] def put_text(self, img: np.ndarray, text: str, pos: Tuple[int, int], font_size: Union[int, float], font_thickness: Optional[int] = 2, font: Optional[int] = cv2.FONT_HERSHEY_DUPLEX, font_path: Optional[str] = None, text_color: Optional[Tuple[int, int, int]] = (255, 255, 255), text_color_bg: Optional[Tuple[int, int, int]] = (0, 0, 0), text_bg_alpha: float = 0.8) -> np.ndarray: """ Draws text on an image with a background color and transparency. This method overlays text on an image at the specified position, with options for adjusting font size, thickness, background color, and background transparency. The text is drawn with an optional background rectangle that can have a specified transparency level to ensure readability over various image backgrounds. :param img: The image on which the text is to be drawn. This is a NumPy array representing the image data. :param text: The text string to be drawn on the image. :param pos: The position (x, y) where the text will be placed on the image. The coordinates correspond to the bottom-left corner of the text. :param font_size: The size of the font. When ``font_path`` is None, this is the cv2 scale factor multiplied by the font-specific base size. When ``font_path`` is passed, this is interpreted as the font PIXEL height. :param font_thickness: The thickness of the text strokes. It is an integer specifying the number of pixels for the thickness. Used only by the cv2 path (ignored when ``font_path`` is passed). :param font: The font type used to render the text. It corresponds to one of the predefined OpenCV Hershey font types (0-7). Ignored when ``font_path`` is passed. :param font_path: Optional path to a TrueType/OpenType (.ttf/.otf) font file. If passed, it takes precedence over ``font`` and the text is rendered with PIL using this font (e.g. for custom fonts such as Poppins). If None, the cv2 Hershey ``font`` is used. :param text_color: The color of the text in RGB format. By default, the text color is white. :param text_color_bg: The background color for the text in RGB format. By default, the background color is black. :param text_bg_alpha: The transparency level of the background rectangle. A value between 0 and 1, where 0 is fully transparent and 1 is fully opaque. :return: The image with the overlaid text and background rectangle. """ check_valid_tuple(x=pos, accepted_lengths=(2,), valid_dtypes=(int,)) check_if_valid_rgb_tuple(data=text_color) check_if_valid_rgb_tuple(data=text_color_bg) check_float(name='text_bg_alpha', value=text_bg_alpha, min_value=0, max_value=1.0) if font_path is not None: pil_font = PlottingMixin._load_ttf(font_path, int(font_size)) x, y = pos measure = ImageDraw.Draw(Image.new("RGB", (1, 1))) x0, y0, x1, y1 = measure.textbbox(pos, text, font=pil_font, anchor="ls") pad = max(1, int(font_size * 0.1)) X0, Y0 = max(0, x0 - pad), max(0, y0 - pad) X1, Y1 = min(img.shape[1], x1 + pad), min(img.shape[0], y1 + pad) if X1 <= X0 or Y1 <= Y0: return img roi = img[Y0:Y1, X0:X1] if text_bg_alpha > 0: bg = roi.copy(); bg[:] = text_color_bg cv2.addWeighted(bg, text_bg_alpha, roi, 1 - text_bg_alpha, 0, roi) pil = Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)) ImageDraw.Draw(pil).text((x - X0, y - Y0), text, font=pil_font, fill=text_color[::-1], anchor="ls") img[Y0:Y1, X0:X1] = cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR) return img check_int(name='font_thickness', value=font_thickness, min_value=1) check_int(name='font', value=font, min_value=0, max_value=7) x, y = pos if text_bg_alpha <= 0: cv2.putText(img, text, (x, y), font, font_size, text_color, font_thickness) return img text_size, px_buffer = cv2.getTextSize(text, font, font_size, font_thickness) w, h = text_size overlay, output = img.copy(), img.copy() cv2.rectangle(overlay, (x, y-h), (x + w, y + px_buffer), text_color_bg, -1) cv2.addWeighted(overlay, text_bg_alpha, output, 1 - text_bg_alpha, 0, output) cv2.putText(output, text, (x, y), font, font_size, text_color, font_thickness) return output
[docs] @staticmethod def plot_bar_chart(df: pd.DataFrame, x: str, y: str, error: Optional[str] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, title: Optional[str] = None, fig_size: Tuple[int, int] = (10, 8), palette: str = 'magma', error_clr: str = 'grey', bar_alpha: float = 1.0, dpi: int = 600, orientation: Literal['vertical', 'horizontal'] = 'vertical', y_min: float = 0.0, y_max: Optional[float] = None, save_path: Optional[Union[str, os.PathLike]] = None, as_svg: bool = False): """ Create a bar chart from DataFrame columns. Generates a bar chart with optional error bars, supporting both vertical and horizontal orientations. Uses seaborn for styling with customizable colors, transparency, and axis limits. .. image:: _static/img/bar_chart_mosaic.webp :alt: Bar chart mosaic :width: 1000 :align: center :param pd.DataFrame df: DataFrame containing the data to plot. :param str x: Column name for x-axis categories. :param str y: Column name for y-axis values (must be numeric). :param Optional[str] error: Column name for error bar values. If None, no error bars are shown. :param Optional[str] x_label: X-axis label. If None, no label is displayed. :param Optional[str] y_label: Y-axis label. If None, no label is displayed. :param Optional[str] title: Chart title. If None, no title is displayed. :param Tuple[int, int] fig_size: Figure size (width, height) in inches. Default: (10, 8). :param str palette: Seaborn color palette name. Default: 'magma'. :param str error_clr: Color name for error bars. Default: 'grey'. :param float bar_alpha: Bar transparency (0.0-1.0). Default: 1.0. :param int dpi: Resolution for saved images. Default: 600. :param Literal['vertical', 'horizontal'] orientation: Bar orientation. Default: 'vertical'. :param float y_min: Minimum value for y-axis (or x-axis if horizontal). Default: 0.0. :param Optional[float] y_max: Maximum value for y-axis (or x-axis if horizontal). If None, auto-scales. Default: None. :param Optional[Union[str, os.PathLike]] save_path: Path to save the image. If None, returns matplotlib figure. :param bool as_svg: If True, saves as SVG format. If False, saves as PNG. Default: False. :return Optional[matplotlib.figure.Figure]: Returns matplotlib figure if ``save_path`` is None, otherwise returns None. """ check_instance(source=f"{PlottingMixin.plot_bar_chart.__name__} df", instance=df, accepted_types=(pd.DataFrame)) check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} x", value=x, options=tuple(df.columns)) check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} y", value=y, options=tuple(df.columns)) check_valid_lst(data=list(df[y]), source=f"{PlottingMixin.plot_bar_chart.__name__} y", valid_dtypes=Formats.NUMERIC_DTYPES.value) check_valid_boolean(value=as_svg, source=f"{PlottingMixin.plot_bar_chart.__name__} as_svg", raise_error=True) check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} error_clr", value=error_clr) check_float(name=f"{PlottingMixin.plot_bar_chart.__name__} bar_alpha", value=bar_alpha, min_value=0.0, max_value=1.0) check_int(name=f"{PlottingMixin.plot_bar_chart.__name__} dpi", value=dpi, min_value=1) check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} orientation", value=orientation, options=('vertical', 'horizontal')) check_float(name=f"{PlottingMixin.plot_bar_chart.__name__} y_min", value=y_min) if y_max is not None: check_float(name=f"{PlottingMixin.plot_bar_chart.__name__} y_max", value=y_max) if y_max <= y_min: raise InvalidInputError(msg=f"y_max ({y_max}) must be greater than y_min ({y_min})", source=f"{PlottingMixin.plot_bar_chart.__name__}") fig, ax = plt.subplots(figsize=fig_size) if error is not None: check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} error", value=error, options=tuple(df.columns)) check_valid_lst(data=list(df[error]), source=f"{PlottingMixin.plot_bar_chart.__name__} error",valid_dtypes=Formats.NUMERIC_DTYPES.value) if orientation == 'horizontal': sns.barplot(x=y, y=x, data=df, palette=palette, ax=ax, orient='h') ax.set_yticklabels(df[x].unique(), rotation=0, fontsize=8) if error is not None: for i, (value, error_val) in enumerate(zip(df[y], df[error])): ax.errorbar(value, i, xerr=[[0], [error_val]], fmt='o', color=error_clr, capsize=2) else: sns.barplot(x=x, y=y, data=df, palette=palette, ax=ax) ax.set_xticklabels(df[x].unique(), rotation=90, fontsize=8) if error is not None: for i, (value, error_val) in enumerate(zip(df[y], df[error])): ax.errorbar(i, value, yerr=[[0], [error_val]], fmt='o', color=error_clr, capsize=2) for patch in ax.patches: patch.set_alpha(bar_alpha) if orientation == 'horizontal': if y_max is not None: ax.set_xlim(left=y_min, right=y_max) else: ax.set_xlim(left=y_min) else: if y_max is not None: ax.set_ylim(bottom=y_min, top=y_max) else: ax.set_ylim(bottom=y_min) if x_label is not None: check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} x_label", value=x_label) plt.xlabel(x_label) if y_label is not None: check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} y_label", value=y_label) plt.ylabel(y_label) if title is not None: check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} title", value=title) plt.title(title, ha="center", fontsize=15) if save_path is not None: check_str(name=f"{PlottingMixin.plot_bar_chart.__name__} save_path", value=save_path) check_if_dir_exists(in_dir=os.path.dirname(save_path)) if as_svg: fig.savefig(save_path, dpi=dpi, format="svg", bbox_inches="tight") else: fig.savefig(save_path, dpi=dpi, bbox_inches='tight') else: return fig
@staticmethod def _plot_blobs(data: np.ndarray, verbose: bool, circle_size: float, circle_clr: tuple, video_path: os.PathLike, temp_dir: os.PathLike): group = int(data[0, 0]) fourcc = cv2.VideoWriter_fourcc(*"DIVX") temp_video_save_path = os.path.join(temp_dir, f"{group}.avi") video_meta_data = get_video_meta_data(video_path=video_path) video_writer = cv2.VideoWriter(temp_video_save_path, fourcc, video_meta_data['fps'], (video_meta_data["width"], video_meta_data["height"])) for frm_idx in range(data.shape[0]): frame_id = data[frm_idx][1] bps = data[frm_idx][2:4] frm = read_frm_of_video(video_path=video_path, frame_index=frame_id) cv2.circle(frm, center=(int(bps[0]), int(bps[1])), radius=circle_size, color=circle_clr, thickness=-1).astype(np.uint8) video_writer.write(frm) if verbose: print(f'Writing frame {frame_id} (frames count: {video_meta_data["frame_count"]}, Video: {video_meta_data["video_name"]})...') video_writer.release() return group
[docs] @staticmethod def plot_clf_cumcount(config_path: Union[str, os.PathLike], clf: str, data_dir: Optional[Union[str, os.PathLike]] = None, save_path: Optional[Union[str, os.PathLike]] = None, bouts: Optional[bool] = False, seconds: Optional[bool] = False) -> None: """ Generates and saves a cumulative count plot of a specified classifier's occurrences over video frames or time. .. image:: _static/img/plot_clf_cumcount.webp :alt: Plot clf cumcount :width: 500 :align: center :param Union[str, os.PathLike] config_path: Path to the configuration file, which includes settings and paths for data processing and storage. :param str clf: The classifier name (e.g., 'CIRCLING') for which to calculate cumulative counts. :param Optional[Union[str, os.PathLike]] data_dir: Directory containing the log files to analyze. If not provided, the default path in the configuration is used. :param Optional[Union[str, os.PathLike]] save_path: Destination path to save the plot image. If None, saves to the logs path in the configuration. :param Optional[bool] bouts: If True, calculates the cumulative count in terms of detected bouts instead of time or frames. :param Optional[bool] seconds: If True, calculates time in seconds rather than frames. :return: None. :example: >>> plot_clf_cumcount(config_path=r"D:\troubleshooting\mitra\project_folder\project_config.ini", clf='CIRCLING', data_dir=r'D:\troubleshooting\mitra\project_folder\logs\test', seconds=True, bouts=True) """ config = ConfigReader(config_path=config_path, read_video_info=True, create_logger=False) if data_dir is not None: check_if_dir_exists(in_dir=data_dir, source=f'{PlottingMixin.plot_clf_cumcount.__name__} data_dir') else: data_dir = config.machine_results_dir if save_path is None: save_path = os.path.join(config.logs_path, f'cumcount_{config.datetime}.png') data_paths = find_files_of_filetypes_in_directory(directory=data_dir, extensions=[f'.{config.file_type}'], raise_error=True) check_valid_boolean(value=[bouts, seconds], source=PlottingMixin.plot_clf_cumcount.__name__, raise_error=True) check_str(name=f'{PlottingMixin.plot_clf_cumcount.__name__} clf', value=clf) x_name = 'VIDEO TIME (FRAMES)' y_name = f'{clf} TIME (FRAMES)' if seconds: check_all_file_names_are_represented_in_video_log(video_info_df=config.video_info_df, data_paths=data_paths) x_name = f'VIDEO TIME (S)' if bouts: y_name = f'{clf} (BOUT COUNT)' clrs = create_color_palette(pallete_name='Set2', increments=len(data_paths), as_rgb_ratio=True) for file_cnt, file_path in enumerate(data_paths): _, video_name, _ = get_fn_ext(filepath=file_path) print(f'Analysing video {video_name} ({file_cnt + 1}/{len(data_paths)})...') df = read_df(file_path=file_path, file_type=config.file_type) check_valid_dataframe(df=df, source=f'{PlottingMixin.plot_clf_cumcount.__name__} {file_path}', required_fields=[clf]) if not bouts and not seconds: clf_sum = list(df[clf].cumsum().ffill()) time = list(df.index) elif not bouts and seconds: _, _, fps = read_video_info(vid_info_df=config.video_info_df, video_name=video_name) clf_sum = np.round(np.array(df[clf].cumsum().ffill() / fps), 2) time = list(df.index / fps) else: bout_starts = detect_bouts(data_df=df, target_lst=[clf], fps=1)['Start_frame'].values bouts_arr = np.full(len(df), fill_value=np.nan, dtype=np.float32) bouts_arr[0] = 0 for bout_cnt in range(bout_starts.shape[0]): bouts_arr[bout_starts[bout_cnt]] = bout_cnt + 1 clf_sum = pd.DataFrame(bouts_arr, columns=[clf]).ffill().values.reshape(-1) if seconds: _, _, fps = read_video_info(vid_info_df=config.video_info_df, video_name=video_name) time = list(df.index / fps) else: time = list(df.index) video_results = pd.DataFrame(data=clf_sum, columns=[y_name]) video_results['VIDEO'] = video_name video_results[x_name] = time sns.lineplot(data=video_results, x=x_name, y=y_name, hue="VIDEO", palette=[clrs[file_cnt]]) plt.savefig(save_path) config.timer.stop_timer() stdout_success(msg=f"Graph saved at {save_path}", elapsed_time=config.timer.elapsed_time_str)
@staticmethod def save_svg_markup(svg_markup: str, save_path: Union[str, os.PathLike]) -> None: check_str(name=f"{PlottingMixin.save_svg_markup.__name__} svg_markup", value=svg_markup, raise_error=True) check_str(name=f"{PlottingMixin.save_svg_markup.__name__} save_path", value=str(save_path), raise_error=True) check_if_dir_exists(in_dir=os.path.dirname(str(save_path))) print(svg_markup) with open(save_path, "w", encoding="utf-8") as f: f.write(svg_markup)
[docs] @staticmethod def get_path_img(data: np.ndarray, size: Optional[Tuple[int, int]] = None, # HxW line_thickness: float = 2, line_color: Union[Tuple[int, int, int], Literal['time', 'velocity']] = (147, 20, 255), bg_clr: Union[Tuple[int, int, int], np.ndarray] = (255, 255, 255), opacity: int = 1.0, smoothing_time: Optional[int] = None, save_path: Optional[Union[str, os.PathLike]] = None, svg: bool = False, dpi: int = 500) -> Optional[matplotlib.figure.Figure]: """ Create a path plot from NumPy array data. Generates a path visualization with optional time or velocity-based coloring, background images, smoothing, and SVG/PNG output. .. image:: _static/img/get_path_img.webp :alt: Get path img :width: 1000 :align: center :param np.ndarray data: 2D array with shape (N, 2) containing x, y coordinates. :param Optional[Tuple[int, int]] size: Image size as (height, width) in pixels. If None, auto-calculated from data or bg_img. :param float line_thickness: Thickness of the path line. Default: 2. :param Tuple[int, int, int] line_color: RGB color tuple (0-255). Default: (147, 20, 255). :param Tuple[int, int, int] bg_clr: Background color RGB tuple (0-255). Default: (255, 255, 255). :param Optional[np.ndarray] bg_img: Background image array. If provided, overrides bg_clr. Default: None. :param float opacity: Line opacity (0.0-1.0). Default: 1.0. :param Optional[int] smoothing_time: Smoothing time window in milliseconds. Applies Savitzky-Golay filter. If None, no smoothing. Default: None. :param Optional[Literal['time', 'velocity']] color_by: Color path by time progression or velocity. If None, uses line_color. Default: None. :param Optional[Union[str, os.PathLike]] save_path: Path to save the image. If None, returns figure. :param bool svg: If True, saves as SVG format. If False, saves as PNG. Default: False. :param int dpi: Resolution for saved images. Default: 500. :return Optional[matplotlib.figure.Figure]: Returns matplotlib figure if save_path is None, otherwise None. .. seealso:: For more complex path plots with multiprocessing and advanced features, see :class:`simba.plotting.path_plotter.PathPlotterSingleCore` and :class:`simba.plotting.path_plotter_mp.PathPlotterMulticore`. :example: >>> df = pd.read_csv('/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/csv/outlier_corrected_movement_location/2022-06-20_NOB_DOT_4.csv') >>> data = df[['Nose_x', 'Nose_y']].values >>> img = read_frm_of_video(video_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/videos/2022-06-20_NOB_DOT_4.mp4', frame_index=400) >>> PlottingMixin().get_path_img(data=data, >>> size=(1080, 1080), >>> line_thickness=0.5, >>> line_color=(0, 255, 0), >>> bg_clr=(255, 255, 255), >>> bg_img=img, >>> save_path='/Users/simon/Desktop/envs/simba/troubleshooting/RAT_NOR/project_folder/csv/outlier_corrected_movement_location/2022-06-20_NOB_DOT_4_3.png', >>> dpi=600, >>> opacity=1.0, >>> color_by=None, >>> svg=False, >>> smoothing_time=5000) """ check_valid_array(data=data, source=f"{PlottingMixin.get_path_img.__name__} end_positions", accepted_ndims=(2,), min_axis_1=2, accepted_dtypes=Formats.NUMERIC_DTYPES.value) if size is not None: check_instance(source=f"{PlottingMixin.get_path_img.__name__} size", instance=size, accepted_types=(tuple,)) check_valid_lst(data=list(size), source=f"{PlottingMixin.get_path_img.__name__} size", exact_len=2, valid_dtypes=(int,)) check_int(name=f"{PlottingMixin.get_path_img.__name__} size[0]", value=size[0], min_value=1) check_int(name=f"{PlottingMixin.get_path_img.__name__} size[1]", value=size[1], min_value=1) check_float(name=f"{PlottingMixin.get_path_img.__name__} line_thickness", value=line_thickness, allow_zero=False, allow_negative=False) check_float(name=f"{PlottingMixin.get_path_img.__name__} opacity", value=opacity, allow_zero=False, allow_negative=False) if isinstance(line_color, tuple): check_if_valid_rgb_tuple(data=line_color, raise_error=True, source=PlottingMixin.get_path_img.__name__) else: check_str(name=f'{PlottingMixin.get_path_img.__name__}', value=line_color, options=['time', 'velocity']) if isinstance(bg_clr, tuple): check_if_valid_rgb_tuple(data=bg_clr, raise_error=True, source=PlottingMixin.get_path_img.__name__) else: check_if_valid_img(data=bg_clr, source=f'{PlottingMixin.get_path_img.__name__} bg_clr', raise_error=True) check_int(name=f"{PlottingMixin.get_path_img.__name__} dpi", value=dpi, min_value=1) if save_path is not None: check_str(name=f"{PlottingMixin.get_path_img.__name__} save_path", value=save_path) check_if_dir_exists(in_dir=os.path.dirname(save_path)) if smoothing_time is not None: check_int(name=f"{PlottingMixin.get_path_img.__name__} smoothing_time", value=smoothing_time, min_value=1) data = savgol_smoother(data=data, fps=1, time_window=smoothing_time, source=PlottingMixin.get_path_img.__name__) x, y = data[:, 0], data[:, 1] if isinstance(bg_clr, np.ndarray): if size is None: size = (bg_clr.shape[0], bg_clr.shape[1]) # HxW else: if bg_clr.shape[0] != size[0] or bg_clr.shape[1] != size[1]: bg_clr = cv2.resize(bg_clr, (size[1], size[0])) elif size is None: size = (int(np.max(y.flatten())), int(np.max(x.flatten()))) fig, ax = plt.subplots(figsize=(size[1] / dpi, size[0] / dpi)) if isinstance(bg_clr, np.ndarray): if bg_clr.shape[2] == 3: bg_clr_rgb = cv2.cvtColor(bg_clr, cv2.COLOR_BGR2RGB) else: bg_clr_rgb = bg_clr ax.imshow(bg_clr_rgb, extent=[0, size[1], size[0], 0], origin='upper') else: bg_clr_normalized = (bg_clr[0] / 255.0, bg_clr[1] / 255.0, bg_clr[2] / 255.0) fig.patch.set_facecolor(bg_clr_normalized) ax.set_facecolor(bg_clr_normalized) if isinstance(line_color, tuple): clr = (line_color[0] / 255.0, line_color[1] / 255.0, line_color[2] / 255.0) ax.plot(x, y, color=clr, linewidth=line_thickness, alpha=opacity) elif line_color == 'time': points = np.array([x, y]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) colors = create_color_palette(pallete_name='jet', increments=len(segments), as_rgb_ratio=True) colors = [tuple(c) for c in colors] lc = LineCollection(segments, colors=colors, linewidths=line_thickness, alpha=opacity, capstyle='round', joinstyle='round', antialiaseds=True) ax.add_collection(lc) elif line_color == 'velocity': dx, dy = np.diff(x), np.diff(y) velocity = np.sqrt(dx ** 2 + dy ** 2) points = np.array([x, y]).T.reshape(-1, 1, 2) segments = np.concatenate([points[:-1], points[1:]], axis=1) if velocity.max() > velocity.min(): velocity_norm = (velocity - velocity.min()) / (velocity.max() - velocity.min()) else: velocity_norm = np.zeros_like(velocity) palette_colors = create_color_palette(pallete_name='jet', increments=len(segments), as_rgb_ratio=True) colors = [tuple(palette_colors[int(v * (len(palette_colors) - 1))]) for v in velocity_norm] lc = LineCollection(segments, colors=colors, linewidths=line_thickness, alpha=opacity, capstyle='round', joinstyle='round', antialiaseds=True) ax.add_collection(lc) ax.axis('off') if isinstance(bg_clr, np.ndarray): ax.set_xlim(0, size[1]) ax.set_ylim(size[0], 0) else: ax.set_xlim(x.min(), x.max()) ax.set_ylim(y.max(), y.min()) if save_path is not None: if isinstance(bg_clr, np.ndarray): facecolor = (1.0, 1.0, 1.0) else: facecolor = (bg_clr[0] / 255.0, bg_clr[1] / 255.0, bg_clr[2] / 255.0) fig.savefig(save_path, dpi=dpi, bbox_inches='tight', pad_inches=0, format='svg' if svg else 'png', facecolor=facecolor) plt.close(fig) return None else: return fig