import glob
import os
import random
import shutil
from typing import Optional, Tuple, Union
import cv2
import numpy as np
import pandas as pd
from simba.utils.checks import check_if_dir_exists, check_int
from simba.utils.errors import InvalidInputError
from simba.utils.printing import SimbaTimer, stdout_success, stdout_warning
from simba.utils.read_write import get_fn_ext
[docs]class CropLPAnnotations:
"""
Creates a new, self-contained Lightning Pose project from an existing one where
every labeled image is cropped to a fixed size around the annotated animal.
When running inference on cropped frames (e.g. from an object detector), the
training data should also be cropped so the model sees the same distribution
at train and test time. This class produces a new LP project where every
labeled frame is cropped around the animal's keypoints, ready for training
a model that will run inference on crops — without any re-labeling.
The output project is ready for training/inference: configs, calibrations, models,
scripts, and ``project.yaml`` are copied, config paths are updated to the new
location, and all ``CollectedData_*.csv`` keypoint coordinates are shifted to
match the cropped frames. A row is only dropped when it is all-NaN in every
camera view; in any other case each view is processed independently. For a
view that has all-NaN keypoints but valid keypoints in some other view, the
image is center-cropped and the keypoint coords stay NaN.
:param str lp_project_dir: Root of the source LP project (e.g. ``Z:/home/simon/lp_300126``).
:param str save_dir: Root of the new cropped LP project.
:param Tuple[int, int] crop_size: Output crop ``(width, height)`` in pixels (e.g. ``(512, 512)``).
Each crop is centered on the keypoint centroid per frame.
:param Optional[Union[bool, int]] visualize: If ``True``, save annotated overlay images for every
cropped frame to ``save_dir/visualizations/``. If ``int``, save that many
randomly sampled overlays. ``None`` / ``False`` disables visualization.
:param Optional[int] padding: Minimum number of pixels between any keypoint and the crop edge.
The crop window is shifted (within image bounds) so that all keypoints
are at least ``padding`` pixels from each border. If the keypoint span
plus ``2 * padding`` exceeds ``crop_size`` in either dimension, a warning
is printed and the padding is best-effort. ``None`` is treated as ``0``.
.. seealso::
:class:`~simba.third_party_label_appenders.transform.litpose_crop_annotations_bbox_square.CropLPAnnotationsBboxSquare`
Bounding-box-based square crop that pads and resizes — matches inference-time crop behavior.
"""
def __init__(self,
lp_project_dir: str,
save_dir: str,
crop_size: Tuple[int, int] = (512, 512),
visualize: Optional[Union[bool, int]] = None,
padding: Optional[int] = None):
check_if_dir_exists(in_dir=lp_project_dir)
check_int(name="CropLPAnnotations crop_size width", value=crop_size[0], min_value=1)
check_int(name="CropLPAnnotations crop_size height", value=crop_size[1], min_value=1)
if isinstance(visualize, int) and not isinstance(visualize, bool):
check_int(name="CropLPAnnotations visualize", value=visualize, min_value=1)
if padding is not None:
check_int(name="CropLPAnnotations padding", value=padding, min_value=0)
self.lp_project_dir = lp_project_dir
self.save_dir = save_dir
self.crop_size = crop_size
self.visualize = visualize
self.padding = padding if padding is not None else 0
self.csv_paths = sorted([os.path.join(lp_project_dir, f) for f in os.listdir(lp_project_dir)
if f.startswith("CollectedData_") and f.endswith(".csv")])
if len(self.csv_paths) == 0:
raise InvalidInputError(msg=f"No CollectedData_*.csv files found in {lp_project_dir}.")
check_if_dir_exists(in_dir=os.path.join(lp_project_dir, "labeled-data"))
def run(self):
os.makedirs(self.save_dir, exist_ok=True)
timer = SimbaTimer(start=True)
self._copy_project_files(lp_project_dir=self.lp_project_dir, save_dir=self.save_dir)
self._update_config_paths(lp_project_dir=self.lp_project_dir, save_dir=self.save_dir)
drop_positions = self._find_all_nan_positions(csv_paths=self.csv_paths)
if drop_positions:
stdout_warning(msg=f"Dropping {len(drop_positions)} row position(s) all-NaN in every view: {sorted(drop_positions)}")
viz_candidates = []
for csv_path in self.csv_paths:
self._process_csv(csv_path=csv_path,
lp_project_dir=self.lp_project_dir,
save_dir=self.save_dir,
crop_size=self.crop_size,
viz_candidates=viz_candidates,
drop_positions=drop_positions,
padding=self.padding)
if self.visualize and len(viz_candidates) > 0:
viz_dir = os.path.join(self.save_dir, "visualizations")
os.makedirs(viz_dir, exist_ok=True)
if isinstance(self.visualize, int) and not isinstance(self.visualize, bool):
sample = random.sample(viz_candidates, min(self.visualize, len(viz_candidates)))
else:
sample = viz_candidates
for cropped_path, new_xs, new_ys, bp_names, viz_fn in sample:
viz_img = cv2.imread(cropped_path)
for bp_idx in range(len(bp_names)):
bx, by = new_xs[bp_idx], new_ys[bp_idx]
if np.isnan(bx) or np.isnan(by):
continue
pt = (int(round(bx)), int(round(by)))
cv2.circle(viz_img, pt, 4, (0, 0, 255), -1)
cv2.putText(viz_img, bp_names[bp_idx], (pt[0] + 6, pt[1] - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (0, 255, 0), 1, cv2.LINE_AA)
cv2.imwrite(os.path.join(viz_dir, viz_fn), viz_img)
stdout_success(msg=f"Saved {len(sample)} visualizations in {viz_dir}")
timer.stop_timer()
stdout_success(msg=f"Cropped LP annotations saved in {self.save_dir}", elapsed_time=timer.elapsed_time_str)
@staticmethod
def _copy_project_files(lp_project_dir: str, save_dir: str):
COPY_DIRS = ("configs", "calibrations", "models")
COPY_EXTS = (".yaml", ".yml", ".sh", ".json", ".jsonl", ".zip", ".txt")
for d in COPY_DIRS:
src = os.path.join(lp_project_dir, d)
dst = os.path.join(save_dir, d)
if os.path.isdir(src):
if os.path.exists(dst):
shutil.rmtree(dst)
shutil.copytree(src, dst)
stdout_success(msg=f"Copied directory {d}/", source="CropLPAnnotations")
for entry in os.listdir(lp_project_dir):
src_path = os.path.join(lp_project_dir, entry)
if not os.path.isfile(src_path):
continue
_, ext = os.path.splitext(entry)
if ext.lower() in COPY_EXTS:
shutil.copy2(src_path, os.path.join(save_dir, entry))
stdout_success(msg=f"Copied {entry}", source="CropLPAnnotations")
@staticmethod
def _to_posix_path(path: str) -> str:
"""Convert a path to POSIX format, stripping Windows drive letters (e.g. ``Z:/home/...`` -> ``/home/...``)."""
p = path.replace("\\", "/")
if len(p) >= 2 and p[1] == ":":
p = p[2:]
return p
def _update_config_paths(self, lp_project_dir: str, save_dir: str):
import re
yaml_files = glob.glob(os.path.join(save_dir, "**", "*.yaml"), recursive=True)
yaml_files += glob.glob(os.path.join(save_dir, "**", "*.yml"), recursive=True)
if len(yaml_files) == 0:
return
old_posix = self._to_posix_path(lp_project_dir)
new_posix = self._to_posix_path(save_dir)
VIDEO_KEYS = ("video_dir", "test_videos_directory")
n_updated = 0
for yaml_path in yaml_files:
with open(yaml_path, "r") as f:
lines = f.readlines()
changed = False
for i, line in enumerate(lines):
if old_posix not in line:
continue
key = line.split(":")[0].strip() if ":" in line else ""
if key in VIDEO_KEYS:
continue
lines[i] = line.replace(old_posix, new_posix)
changed = True
if changed:
with open(yaml_path, "w") as f:
f.writelines(lines)
n_updated += 1
if n_updated > 0:
stdout_success(msg=f"Updated paths in {n_updated} config file(s) to {new_posix}", source="CropLPAnnotations")
@staticmethod
def _row_all_nan(coords: np.ndarray) -> bool:
xs, ys = coords[0::2], coords[1::2]
return not np.any(~np.isnan(xs) & ~np.isnan(ys))
@staticmethod
def _find_all_nan_positions(csv_paths):
"""Return the set of row positions that are all-NaN in EVERY CSV."""
per_csv_nan = []
for csv_path in csv_paths:
df = pd.read_csv(csv_path, header=[0, 1, 2], index_col=0)
nan_set = set()
for row_pos in range(len(df)):
coords = df.iloc[row_pos].values.astype(float)
if CropLPAnnotations._row_all_nan(coords):
nan_set.add(row_pos)
per_csv_nan.append(nan_set)
if not per_csv_nan:
return set()
common = per_csv_nan[0]
for s in per_csv_nan[1:]:
common = common & s
return common
def _process_csv(self,
csv_path: str,
lp_project_dir: str,
save_dir: str,
crop_size: Tuple[int, int],
viz_candidates: list,
drop_positions: set,
padding: int = 0):
_, csv_fn, csv_ext = get_fn_ext(filepath=csv_path)
df = pd.read_csv(csv_path, header=[0, 1, 2], index_col=0)
bp_names = [df.columns[i][1] for i in range(0, len(df.columns), 2)]
out_rows = []
for row_pos in range(len(df)):
if row_pos in drop_positions:
continue
idx = df.index[row_pos]
coords = df.loc[idx].values.astype(float)
xs = coords[0::2]
ys = coords[1::2]
valid = ~np.isnan(xs) & ~np.isnan(ys)
img_rel = str(idx)
img_path = os.path.join(lp_project_dir, img_rel.replace("/", os.sep))
if not os.path.isfile(img_path):
stdout_warning(msg=f"{csv_fn}: skipped row_pos={row_pos} idx={idx} (image file not found)")
continue
img = cv2.imread(img_path)
if img is None:
stdout_warning(msg=f"{csv_fn}: skipped row_pos={row_pos} idx={idx} (cv2.imread returned None)")
continue
h, w = img.shape[:2]
if w < crop_size[0] or h < crop_size[1]:
scale = max(crop_size[0] / w, crop_size[1] / h)
new_w, new_h = int(np.ceil(w * scale)), int(np.ceil(h * scale))
img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
xs = np.where(np.isnan(xs), np.nan, xs * scale)
ys = np.where(np.isnan(ys), np.nan, ys * scale)
h, w = img.shape[:2]
if np.any(valid):
x_valid, y_valid = xs[valid], ys[valid]
x_min, x_max = float(np.min(x_valid)), float(np.max(x_valid))
y_min, y_max = float(np.min(y_valid)), float(np.max(y_valid))
cx = (x_min + x_max) / 2.0
cy = (y_min + y_max) / 2.0
else:
cx, cy = w / 2.0, h / 2.0
x_min, x_max = cx, cx
y_min, y_max = cy, cy
half_w = crop_size[0] / 2.0
half_h = crop_size[1] / 2.0
crop_x1 = int(np.floor(cx - half_w))
crop_y1 = int(np.floor(cy - half_h))
crop_x2 = crop_x1 + crop_size[0]
crop_y2 = crop_y1 + crop_size[1]
if crop_x1 < 0:
crop_x1, crop_x2 = 0, crop_size[0]
if crop_y1 < 0:
crop_y1, crop_y2 = 0, crop_size[1]
if crop_x2 > w:
crop_x2, crop_x1 = w, w - crop_size[0]
if crop_y2 > h:
crop_y2, crop_y1 = h, h - crop_size[1]
if padding > 0 and np.any(valid):
kp_span_w = (x_max - x_min) + 2 * padding
kp_span_h = (y_max - y_min) + 2 * padding
if kp_span_w > crop_size[0] or kp_span_h > crop_size[1]:
stdout_warning(msg=f"{csv_fn}: row_pos={row_pos} keypoint span + 2*padding ({kp_span_w:.0f}x{kp_span_h:.0f}) exceeds crop_size {crop_size}; padding is best-effort.")
need_x1 = x_min - padding
need_x2 = x_max + padding
if need_x1 < crop_x1:
crop_x1 = max(int(np.floor(need_x1)), 0)
crop_x2 = crop_x1 + crop_size[0]
if crop_x2 > w:
crop_x2, crop_x1 = w, w - crop_size[0]
if need_x2 > crop_x2:
crop_x2 = min(int(np.ceil(need_x2)), w)
crop_x1 = crop_x2 - crop_size[0]
if crop_x1 < 0:
crop_x1, crop_x2 = 0, crop_size[0]
need_y1 = y_min - padding
need_y2 = y_max + padding
if need_y1 < crop_y1:
crop_y1 = max(int(np.floor(need_y1)), 0)
crop_y2 = crop_y1 + crop_size[1]
if crop_y2 > h:
crop_y2, crop_y1 = h, h - crop_size[1]
if need_y2 > crop_y2:
crop_y2 = min(int(np.ceil(need_y2)), h)
crop_y1 = crop_y2 - crop_size[1]
if crop_y1 < 0:
crop_y1, crop_y2 = 0, crop_size[1]
cropped = img[crop_y1:crop_y2, crop_x1:crop_x2]
out_img_path = os.path.join(save_dir, img_rel.replace("/", os.sep))
os.makedirs(os.path.dirname(out_img_path), exist_ok=True)
cv2.imwrite(out_img_path, cropped)
new_coords = coords.copy()
new_xs = np.where(np.isnan(xs), np.nan, xs - crop_x1)
new_ys = np.where(np.isnan(ys), np.nan, ys - crop_y1)
new_coords[0::2] = new_xs
new_coords[1::2] = new_ys
out_rows.append((idx, new_coords))
if self.visualize:
parts = img_rel.replace("/", os.sep).split(os.sep)
viz_fn = f"{parts[-2]}_{parts[-1]}" if len(parts) >= 2 else parts[-1]
viz_candidates.append((out_img_path, new_xs.copy(), new_ys.copy(), list(bp_names), viz_fn))
if len(out_rows) == 0:
stdout_warning(msg=f"No valid rows in {csv_fn}{csv_ext}.")
return
indices, data = zip(*out_rows)
out_df = pd.DataFrame(np.array(data), index=list(indices), columns=df.columns)
out_df.index.name = df.index.name
out_csv_path = os.path.join(save_dir, f"{csv_fn}{csv_ext}")
out_df.to_csv(out_csv_path)
stdout_success(msg=f"Saved {out_csv_path} ({len(out_df)} rows).")
#
if __name__ == "__main__":
cropper = CropLPAnnotations(lp_project_dir=r"Z:\home\simon\LPProjects\mini_project_0504",
save_dir=r"Z:\home\simon\LPProjects\mini_project_0504_cropped_2",
crop_size=(512, 512),
visualize=True,
padding=20)
cropper.run()