__author__ = "Simon Nilsson; sronilsson@gmail.com"
import math
from copy import deepcopy
from tkinter import Event, Toplevel
from typing import List, Optional, Tuple
import cv2
import numpy as np
import pandas as pd
from PIL import Image, ImageTk
from shapely.geometry import MultiPolygon, Point, Polygon
from shapely.ops import unary_union
from simba.mixins.geometry_mixin import GeometryMixin
from simba.mixins.plotting_mixin import PlottingMixin
from simba.roi_tools.roi_utils import (create_circle_entry,
create_rectangle_entry,
get_circle_df_headers,
get_image_from_label,
get_polygon_df_headers,
get_rectangle_df_headers)
from simba.utils.checks import check_instance
from simba.utils.enums import ROI_SETTINGS, Keys, TkBinds
TAGS, SHAPE_TYPE = 'Tags', 'Shape_type'
def _plot_roi(roi_dict: dict,
img: np.ndarray):
rectangles_df, circles_df, polygon_df = pd.DataFrame(columns=get_rectangle_df_headers()), pd.DataFrame(columns=get_circle_df_headers()), pd.DataFrame(columns=get_polygon_df_headers())
for roi_name, roi_data in roi_dict.items():
if (roi_data['Shape_type'].lower() == ROI_SETTINGS.RECTANGLE.value):
rectangles_df = pd.concat([rectangles_df, pd.DataFrame([roi_data])], ignore_index=True)
elif roi_data['Shape_type'].lower() == ROI_SETTINGS.CIRCLE.value:
circles_df = pd.concat([circles_df, pd.DataFrame([roi_data])], ignore_index=True)
elif roi_data['Shape_type'].lower() == ROI_SETTINGS.POLYGON.value:
polygon_df = pd.concat([polygon_df, pd.DataFrame([roi_data])], ignore_index=True)
roi_dict = {Keys.ROI_RECTANGLES.value: rectangles_df, Keys.ROI_CIRCLES.value: circles_df, Keys.ROI_POLYGONS.value: polygon_df}
img = PlottingMixin.roi_dict_onto_img(img=img, roi_dict=roi_dict, circle_size=None, show_tags=False, show_center=True, omitted_centers=list(polygon_df['Name'].unique()))
return img
[docs]class InteractiveROIBufferer():
"""
Interactive Tkinter-based tool for buffering (expanding or shrinking) ROI shapes by specified metric millimeter by clicking on their tags.
:param Toplevel img_window: Tkinter Toplevel window containing an image label named 'img_lbl' displaying the ROI image.
:param np.ndarray original_img: Original image as a numpy array in BGR format. Used as the base for redrawing ROIs.
:param dict roi_dict: Dictionary containing ROI definitions. Keys are ROI names (str), values are dictionaries with ROI properties including 'Shape_type', 'Tags', 'Color BGR', 'Thickness', 'Ear_tag_size', etc. Expected shape types: 'Rectangle', 'Circle', or 'Polygon'.
:param int buffer_mm: Buffer distance in millimeters. Positive values expand the shape, negative values shrink it.
:param float px_per_mm: Pixels per millimeter conversion factor. Used to convert buffer_mm to pixels. Must be > 0.
:param Optional[dict] settings: Optional dictionary of ROI settings. If None, uses default ROI_SETTINGS values. Default None.
:param Optional[List[Polygon]] hex_grid: Optional list of Shapely Polygon objects representing a hexagon grid overlay. Default None.
:param Optional[List[Polygon]] rectangle_grid: Optional list of Shapely Polygon objects representing a rectangle grid overlay. Default None.
"""
def __init__(self,
img_window: Toplevel,
original_img: np.ndarray,
roi_dict: dict,
buffer_mm: int,
px_per_mm: float,
settings: Optional[dict] = None,
hex_grid: Optional[List[Polygon]] = None,
rectangle_grid: Optional[List[Polygon]] = None):
check_instance(source=self.__class__.__name__, instance=img_window, accepted_types=(Toplevel,))
if settings is None: settings = {item.name: item.value for item in ROI_SETTINGS}
self.hex_grid, self.rectangle_grid = hex_grid, rectangle_grid
self.img_lbl = img_window.nametowidget("img_lbl")
self.img = get_image_from_label(self.img_lbl)
self.original_img, self.roi_dict = deepcopy(original_img), deepcopy(roi_dict)
_plot_roi(roi_dict=self.roi_dict, img=self.original_img.copy())
self.img_w, self.img_h = self.img.shape[1], self.img.shape[0]
self.img_window, self.settings, self.buffer_mm, self.px_per_mm = img_window, settings, buffer_mm, px_per_mm
self.bind_mouse()
def _find_closest_tag(self, roi_dict: dict, click_coordinate: Tuple[int, int]):
clicked_roi, clicked_tag = None, None
for roi_name, roi_data in roi_dict.items():
ear_tag_size = roi_data['Ear_tag_size']
for roi_tag_name, roi_tag_coordinate in roi_data[TAGS].items():
distance = math.sqrt((roi_tag_coordinate[0] - click_coordinate[0]) ** 2 + (roi_tag_coordinate[1] - click_coordinate[1]) ** 2)
if distance <= ear_tag_size:
clicked_roi, clicked_tag = roi_data, roi_tag_name
return clicked_roi, clicked_tag
def bind_mouse(self):
self.img_window.bind(TkBinds.B1_PRESS.value, self.left_mouse_down)
def unbind_mouse(self):
self.img_window.unbind(TkBinds.B1_PRESS.value)
def __update_image(self, img):
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(img_rgb)
tk_image = ImageTk.PhotoImage(pil_image)
self.img_lbl.configure(image=tk_image)
self.img_lbl.image = tk_image
def left_mouse_down(self, event: Event):
self.click_loc = (event.x, event.y)
self.clicked_roi, self.clicked_tag = self._find_closest_tag(roi_dict=self.roi_dict, click_coordinate=self.click_loc)
if self.clicked_roi is not None and self.clicked_roi[SHAPE_TYPE] != ROI_SETTINGS.POLYGON.value:
original_tags = self.clicked_roi[TAGS]
buffer_px = int(self.buffer_mm / self.px_per_mm)
if self.clicked_roi[SHAPE_TYPE] == ROI_SETTINGS.CIRCLE.value:
center, radius = original_tags['Center tag'], self.clicked_roi['radius']
roi_geometry = Point(center).buffer(distance=radius)
new_geometry = GeometryMixin().buffer_shape(shape=roi_geometry, size_mm=self.buffer_mm, pixels_per_mm=self.px_per_mm, resolution=1)
if isinstance(new_geometry, MultiPolygon):
new_geometry = unary_union(new_geometry)
if isinstance(new_geometry, MultiPolygon):
new_geometry = max(new_geometry.geoms, key=lambda p: p.area)
minx, miny, maxx, maxy = new_geometry.bounds
center_x, center_y = (minx + maxx) / 2, (miny + maxy) / 2
new_center = (int(center_x), int(center_y))
new_radius = int((maxx - minx) / 2)
self.circle_center, self.circle_radius = new_center, new_radius
self.width, self.height, self.center = new_radius * 2, new_radius * 2, new_center
self.top_left = (new_center[0] - new_radius, new_center[1] - new_radius)
self.bottom_right = (new_center[0] + new_radius, new_center[1] + new_radius)
self.left_border_tag = (new_center[0] - new_radius, new_center[1])
new_roi = create_circle_entry(circle_selector=self, video_name=self.clicked_roi['Video'], shape_name=self.clicked_roi['Name'], clr_name=self.clicked_roi['Color name'], clr_bgr=self.clicked_roi['Color BGR'], thickness=self.clicked_roi['Thickness'], ear_tag_size=int(self.clicked_roi['Ear_tag_size']), px_conversion_factor=self.px_per_mm)
else:
center_tag_names = ['Center tag', 'Center_tag']
corner_tags = {k: v for k, v in original_tags.items() if k not in center_tag_names}
if len(corner_tags) > 0:
tag_coords = np.array(list(corner_tags.values()))
centroid = np.mean(tag_coords, axis=0)
buffered_tags = {}
for tag_name, tag_coord in corner_tags.items():
tag_coord_np = np.array(tag_coord)
vec_to_tag = tag_coord_np - centroid
vec_length = np.linalg.norm(vec_to_tag)
if vec_length > 1e-10:
vec_normalized = vec_to_tag / vec_length
new_coord = tag_coord_np + vec_normalized * buffer_px
new_coord[0] = max(0, min(new_coord[0], self.img_w - 1))
new_coord[1] = max(0, min(new_coord[1], self.img_h - 1))
buffered_tags[tag_name] = tuple(new_coord.astype(int))
else:
buffered_tags[tag_name] = tag_coord
buffered_coords = np.array(list(buffered_tags.values()))
new_center = tuple(np.mean(buffered_coords, axis=0).astype(int))
if 'Center tag' in original_tags: buffered_tags['Center tag'] = new_center
elif 'Center_tag' in original_tags: buffered_tags['Center_tag'] = new_center
else:
buffered_tags = original_tags.copy()
self.top_left = buffered_tags.get('Top left tag', buffered_tags.get('Tag_0', (0, 0)))
self.bottom_right = buffered_tags.get('Bottom right tag', buffered_tags.get('Tag_2', (0, 0)))
self.top_right_tag = buffered_tags.get('Top right tag', buffered_tags.get('Tag_1', (0, 0)))
self.bottom_left_tag = buffered_tags.get('Bottom left tag', buffered_tags.get('Tag_3', (0, 0)))
self.center = buffered_tags.get('Center tag', buffered_tags.get('Center_tag', (0, 0)))
self.left_tag = buffered_tags.get('Left tag', (self.top_left[0], (self.top_left[1] + self.bottom_right[1]) // 2))
self.right_tag = buffered_tags.get('Right tag', (self.bottom_right[0], (self.top_left[1] + self.bottom_right[1]) // 2))
self.top_tag = buffered_tags.get('Top tag', ((self.top_left[0] + self.bottom_right[0]) // 2, self.top_left[1]))
self.bottom_tag = buffered_tags.get('Bottom tag', ((self.top_left[0] + self.bottom_right[0]) // 2, self.bottom_right[1]))
self.width = abs(self.bottom_right[0] - self.top_left[0])
self.height = abs(self.bottom_right[1] - self.top_left[1])
self.circle_radius = int(self.width / 2)
self.left_border_tag = self.left_tag
new_roi = create_rectangle_entry(rectangle_selector=self, video_name=self.clicked_roi['Video'], shape_name=self.clicked_roi['Name'], clr_name=self.clicked_roi['Color name'], clr_bgr=self.clicked_roi['Color BGR'], thickness=self.clicked_roi['Thickness'], ear_tag_size=int(self.clicked_roi['Ear_tag_size']), px_conversion_factor=self.px_per_mm)
self.roi_dict[self.clicked_roi['Name']] = new_roi
self.temp_img = _plot_roi(roi_dict=self.roi_dict, img=self.original_img.copy())
self.__update_image(img=self.temp_img)