diff --git a/docs/trackers/core/ksp/tracker.md b/docs/trackers/core/ksp/tracker.md new file mode 100644 index 00000000..32c11268 --- /dev/null +++ b/docs/trackers/core/ksp/tracker.md @@ -0,0 +1,186 @@ +--- +comments: true +--- + +# KSP + +[![IEEE](https://img.shields.io/badge/IEEE-10.1109/TPAMI.2011.21-blue.svg)](https://doi.org/10.1109/TPAMI.2011.21) +[![arXiv](https://img.shields.io/badge/arXiv-1808.01562-b31b1b.svg)](https://arxiv.org/abs/1808.01562) +[![PDF (Unofficial)](https://img.shields.io/badge/PDF-Stanford--Preprint-red.svg)](http://vision.stanford.edu/teaching/cs231b_spring1415/papers/Berclaz-tracking.pdf) +[![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-track-objects-with-sort-tracker.ipynb) + +## Overview + +**KSP Tracker** (K-Shortest Paths Tracker) is an offline, tracking-by-detection method that formulates multi-object tracking as a global optimization problem over a directed graph. Each object detection is represented as a node, and feasible transitions between detections are modeled as edges weighted by spatial and temporal consistency. By solving a K-shortest paths problem, the tracker extracts globally optimal trajectories that span the entire sequence. + +Unlike online trackers, which make frame-by-frame decisions, KSP Tracker leverages the full temporal context of a video to achieve greater robustness against occlusions, missed detections, and fragmented tracks. This makes it especially suitable for applications where high tracking accuracy is required, such as surveillance review, sports analytics, or autonomous system evaluation. However, the reliance on global optimization introduces higher computational cost and requires access to the full sequence before tracking can be performed. + +## Examples + +=== "inference" + + ```python hl_lines="2 6 11 15" + import supervision as sv + from trackers import KSPTracker + from inference import get_model + import numpy as np + + tracker = KSPTracker() + model = get_model(model_id="yolo11x") + box_annotator = sv.BoxAnnotator() + label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT) + + def get_model_detections(frame: np.ndarray): + result = model.infer(frame)[0] + return sv.Detections.from_inference(result) + + tracked_dets = tracker.track( + source_path="", + get_model_detections=get_model_detections + ) + + frame_idx_to_dets = {i: tracked_dets[i] for i in range(len(tracked_dets))} + + def annotate_frame(frame: np.ndarray, i: int) -> np.ndarray: + detections = frame_idx_to_dets.get(i, sv.Detections.empty()) + detections.tracker_id = detections.tracker_id or np.zeros(len(detections), dtype=int) + labels = [f"{tid}" for tid in detections.tracker_id] + ann = box_annotator.annotate(frame.copy(), detections) + return label_annotator.annotate(ann, detections, labels=labels) + + sv.process_video( + source_path="", + target_path="", + callback=annotate_frame, + ) + ``` + +=== "rf-detr" + + ```python hl_lines="2 6 11 14" + import supervision as sv + from trackers import KSPTracker + from rfdetr import RFDETRBase + import numpy as np + + tracker = KSPTracker() + model = RFDETRBase() + box_annotator = sv.BoxAnnotator() + label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT) + + def get_model_detections(frame: np.ndarray): + return model.predict(frame) + + tracked_dets = tracker.track( + source_path="", + get_model_detections=get_model_detections + ) + + frame_idx_to_dets = {i: tracked_dets[i] for i in range(len(tracked_dets))} + + def annotate_frame(frame: np.ndarray, i: int) -> np.ndarray: + detections = frame_idx_to_dets.get(i, sv.Detections.empty()) + detections.tracker_id = detections.tracker_id or np.zeros(len(detections), dtype=int) + labels = [f"{tid}" for tid in detections.tracker_id] + ann = box_annotator.annotate(frame.copy(), detections) + return label_annotator.annotate(ann, detections, labels=labels) + + sv.process_video( + source_path="", + target_path="", + callback=annotate_frame, + ) + ``` + +=== "ultralytics" + + ```python hl_lines="2 6 11 16" + import supervision as sv + from trackers import KSPTracker + from ultralytics import YOLO + import numpy as np + + tracker = KSPTracker() + model = YOLO("yolo11m.pt") + box_annotator = sv.BoxAnnotator() + label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT) + + def get_model_detections(frame: np.ndarray): + result = model(frame, imgsz=1280, verbose=False)[0] + detections = sv.Detections.from_ultralytics(result) + return detections[detections.class_id == 0] if not detections.is_empty() else detections + + tracked_dets = tracker.track( + source_path="", + get_model_detections=get_model_detections + ) + + frame_idx_to_dets = {i: tracked_dets[i] for i in range(len(tracked_dets))} + + def annotate_frame(frame: np.ndarray, i: int) -> np.ndarray: + detections = frame_idx_to_dets.get(i, sv.Detections.empty()) + detections.tracker_id = detections.tracker_id or np.zeros(len(detections), dtype=int) + labels = [f"{tid}" for tid in detections.tracker_id] + ann = box_annotator.annotate(frame.copy(), detections) + return label_annotator.annotate(ann, detections, labels=labels) + + sv.process_video( + source_path="", + target_path="", + callback=annotate_frame, + ) + ``` + +=== "transformers" + + ```python hl_lines="3 7 13 27" + import torch + import supervision as sv + from trackers import KSPTracker + from transformers import RTDetrV2ForObjectDetection, RTDetrImageProcessor + import numpy as np + + tracker = KSPTracker() + processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_v2_r18vd") + model = RTDetrV2ForObjectDetection.from_pretrained("PekingU/rtdetr_v2_r18vd") + box_annotator = sv.BoxAnnotator() + label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT) + + def get_model_detections(frame: np.ndarray): + inputs = processor(images=frame, return_tensors="pt") + with torch.no_grad(): + outputs = model(**inputs) + + h, w, _ = frame.shape + results = processor.post_process_object_detection( + outputs, + target_sizes=torch.tensor([(h, w)]), + threshold=0.5 + )[0] + + return sv.Detections.from_transformers(results, id2label=model.config.id2label) + + tracked_dets = tracker.track( + "", + get_model_detections=get_model_detections + ) + + frame_idx_to_dets = {i: tracked_dets[i] for i in range(len(tracked_dets))} + + def annotate_frame(frame: np.ndarray, i: int) -> np.ndarray: + detections = frame_idx_to_dets.get(i, sv.Detections.empty()) + detections.tracker_id = detections.tracker_id or np.zeros(len(detections), dtype=int) + labels = [f"{tid}" for tid in detections.tracker_id] + ann = box_annotator.annotate(frame.copy(), detections) + return label_annotator.annotate(ann, detections, labels=labels) + + sv.process_video( + source_path="", + target_path="", + callback=callback, + ) + ``` + +## API + +::: trackers.core.ksp.tracker.KSPTracker diff --git a/mkdocs.yml b/mkdocs.yml index ad896311..e97c69b7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -95,4 +95,5 @@ nav: - Trackers: - SORT: trackers/core/sort/tracker.md - DeepSORT: trackers/core/deepsort/tracker.md + - KSP: trackers/core/ksp/tracker.md - ReID: trackers/core/reid/reid.md diff --git a/pyproject.toml b/pyproject.toml index 3e0e12b0..597e580b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,9 @@ metrics = [ "tensorboard>=2.19.0", "wandb>=0.19.11", ] +ksptracker = [ + "networkx>=3.2.1", +] [dependency-groups] dev = [ diff --git a/trackers/core/base.py b/trackers/core/base.py index 6db95485..aeb67d04 100644 --- a/trackers/core/base.py +++ b/trackers/core/base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Callable, List, Optional import numpy as np import supervision as sv @@ -22,3 +23,18 @@ def update(self, detections: sv.Detections, frame: np.ndarray) -> sv.Detections: @abstractmethod def reset(self) -> None: pass + + +class BaseOfflineTracker(ABC): + @abstractmethod + def reset(self) -> None: + pass + + @abstractmethod + def track( + self, + source_path: str, + get_model_detections: Callable[[np.ndarray], sv.Detections], + num_of_tracks: Optional[int] = None, + ) -> List[sv.Detections]: + pass diff --git a/trackers/core/ksp/__init__.py b/trackers/core/ksp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trackers/core/ksp/solver.py b/trackers/core/ksp/solver.py new file mode 100644 index 00000000..a51e8850 --- /dev/null +++ b/trackers/core/ksp/solver.py @@ -0,0 +1,303 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, List, Optional, Set, Tuple + +import networkx as nx +import numpy as np +import supervision as sv +from tqdm.auto import tqdm + + +@dataclass(frozen=True) +class TrackNode: + frame_id: int + det_idx: int + class_id: int + position: tuple + bbox: np.ndarray + confidence: float + + def __hash__(self): + return hash((self.frame_id, self.det_idx)) + + def __eq__(self, other: Any): + return isinstance(other, TrackNode) and (self.frame_id, self.det_idx) == ( + other.frame_id, + other.det_idx, + ) + + def __str__(self): + return f"{self.frame_id}:{self.det_idx}@{self.position}" + + +class KSPSolver: + def __init__( + self, + path_overlap_penalty: float = 40, + iou_weight: float = 0.9, + dist_weight: float = 0.1, + size_weight: float = 0.1, + conf_weight: float = 0.1, + entry_weight: float = 2.0, + exit_weight: float = 2.0, + ): + self.path_overlap_penalty = ( + path_overlap_penalty if path_overlap_penalty is not None else 40 + ) + self.weight_key = "weight" + self.source = "SOURCE" + self.sink = "SINK" + self.detection_per_frame: List[sv.Detections] = [] + self.weights = {"iou": 0.9, "dist": 0.1, "size": 0.1, "conf": 0.1} + self.entry_weight = entry_weight + self.exit_weight = exit_weight + + if path_overlap_penalty is not None: + self.path_overlap_penalty = path_overlap_penalty + if iou_weight is not None: + self.weights["iou"] = iou_weight + if dist_weight is not None: + self.weights["dist"] = dist_weight + if size_weight is not None: + self.weights["size"] = size_weight + if conf_weight is not None: + self.weights["conf"] = conf_weight + + # Entry/exit region settings + self.entry_exit_regions: List[ + Tuple[int, int, int, int] + ] = [] # (x1, y1, x2, y2) + + # Border region settings + self.use_border_regions = True + self.active_borders: Set[str] = {"left", "right", "top", "bottom"} + self.border_margin = 40 + self.frame_size = (1920, 1080) + + self.reset() + + def reset(self) -> None: + """ + Reset the solver state, clearing all buffered detections and the graph. + """ + self.detection_per_frame = [] + self.graph = nx.DiGraph() + + def append_frame(self, detections: sv.Detections) -> None: + """ + Add detections for the current frame to the solver's buffer. + + Args: + detections (sv.Detections): Detections for the current frame. + """ + self.detection_per_frame.append(detections) + + def _get_center(self, bbox: np.ndarray) -> np.ndarray: + """ + Compute the center (x, y) of a bounding box. + + Args: + bbox (np.ndarray): Bounding box as [x1, y1, x2, y2]. + + Returns: + np.ndarray: Center coordinates as (x, y). + """ + x1, y1, x2, y2 = bbox + return np.array([(x1 + x2) / 2, (y1 + y2) / 2]) + + def set_entry_exit_regions(self, regions: List[Tuple[int, int, int, int]]) -> None: + """ + Set rectangular entry/exit zones (x1, y1, x2, y2). + + Args: + regions (List[Tuple[int, int, int, int]]): List of rectangular regions. + """ + self.entry_exit_regions = regions + + def set_border_entry_exit( + self, + use_border: bool = True, + borders: Optional[Set[str]] = None, + margin: int = 40, + frame_size: Tuple[int, int] = (1920, 1080), + ) -> None: + """ + Configure border-based entry/exit zones. + + Args: + use_border (bool): Enable/disable border-based entry/exit. + borders (Optional[Set[str]]): Set of borders to use. + margin (int): Border thickness in pixels. + frame_size (Tuple[int, int]): Size of the image (width, height). + """ + self.use_border_regions = use_border + self.active_borders = ( + borders if borders is not None else {"left", "right", "top", "bottom"} + ) + self.border_margin = margin + self.frame_size = frame_size + + def _in_door(self, node: TrackNode) -> bool: + """ + Check if a node is inside any entry/exit region (rectangular or border). + + Args: + node (TrackNode): The node to check. + + Returns: + bool: True if in any entry/exit region, else False. + """ + x, y = node.position + + # Check custom rectangular regions + for x1, y1, x2, y2 in self.entry_exit_regions: + if x1 <= x <= x2 and y1 <= y <= y2: + return True + + # Check image border zones + if self.use_border_regions: + width, height = self.frame_size + m = self.border_margin + + if "left" in self.active_borders and x <= m: + return True + if "right" in self.active_borders and x >= width - m: + return True + if "top" in self.active_borders and y <= m: + return True + if "bottom" in self.active_borders and y >= height - m: + return True + + return False + + def _edge_cost(self, nodeU: TrackNode, nodeV: TrackNode) -> float: + """ + Compute the cost of linking two detections (nodes) in the graph. + + Args: + nodeU (TrackNode): Source node. + nodeV (TrackNode): Destination node. + + Returns: + float: Edge cost based on IoU, distance, size, and confidence weights. + """ + bboxU, bboxV = nodeU.bbox, nodeV.bbox + conf_u, conf_v = nodeU.confidence, nodeV.confidence + + center_dist = np.linalg.norm(self._get_center(bboxU) - self._get_center(bboxV)) + iou_penalty = 1 - sv.box_iou_batch(np.array([bboxU]), np.array([bboxV])) + + area_a = (bboxU[2] - bboxU[0]) * (bboxU[3] - bboxU[1]) + area_b = (bboxV[2] - bboxV[0]) * (bboxV[3] - bboxV[1]) + size_penalty = np.log( + (max(area_a, area_b) / (min(area_a, area_b) + 1e-6)) + 1e-6 + ) + + conf_penalty = 1 - min(conf_u, conf_v) + + return ( + self.weights["iou"] * iou_penalty + + self.weights["dist"] * center_dist + + self.weights["size"] * size_penalty + + self.weights["conf"] * conf_penalty + ) + + def _build_graph(self): + """ + Build the directed graph of detections for KSP computation. + Nodes represent detections; edges represent possible associations. + """ + G = nx.DiGraph() + G.add_node(self.source) + G.add_node(self.sink) + + node_frames = [] + + for frame_id, detections in enumerate(self.detection_per_frame): + frame_nodes = [] + for det_idx, bbox in enumerate(detections.xyxy): + node = TrackNode( + frame_id=frame_id, + det_idx=det_idx, + class_id=int(detections.class_id[det_idx]), + position=tuple(self._get_center(bbox)), + bbox=bbox, + confidence=float(detections.confidence[det_idx]), + ) + G.add_node(node) + frame_nodes.append(node) + node_frames.append(frame_nodes) + + for t in range(len(node_frames) - 1): + for node_a in node_frames[t]: + if self._in_door(node_a): + G.add_edge(self.source, node_a, weight=t * self.entry_weight) + G.add_edge( + node_a, + self.sink, + weight=(len(node_frames) - 1 - t) * self.exit_weight, + ) + + for node_b in node_frames[t + 1]: + cost = self._edge_cost(node_a, node_b) + G.add_edge(node_a, node_b, weight=cost) + + for node in node_frames[0]: + G.add_edge(self.source, node, weight=0.0) + for node in node_frames[-1]: + G.add_edge(node, self.sink, weight=0.0) + + self.graph = G + + def solve(self, k: Optional[int] = None) -> List[List[TrackNode]]: + """ + Solve the K-Shortest Paths problem on the constructed detection graph. + + This method extracts up to k node-disjoint paths from the source to the sink in + the detection graph, assigning each path as a unique object track. Edge reuse is + penalized to encourage distinct tracks. The cost of each edge is determined by + the configured weights for IoU, distance, size, and confidence. + + Args: + k (Optional[int]): The number of tracks (paths) to extract. If None, uses + the maximum number of detections in any frame as the default. + + Returns: + List[List[TrackNode]]: A list of tracks, each track is a list of TrackNode + objects representing the detections assigned to that track. + """ + self._build_graph() + + G_base = self.graph.copy() + edge_reuse: defaultdict[Tuple[Any, Any], int] = defaultdict(int) + paths: List[List[TrackNode]] = [] + + if k is None: + k = max(len(f.xyxy) for f in self.detection_per_frame) + + for _i in tqdm(range(k), desc="Extracting k-shortest paths", leave=True): + G_mod = G_base.copy() + + for u, v, data in G_mod.edges(data=True): + base = data[self.weight_key] + penalty = self.path_overlap_penalty * 1000 * edge_reuse[(u, v)] * base + data[self.weight_key] = base + penalty + + try: + _, path = nx.single_source_dijkstra( + G_mod, self.source, self.sink, weight=self.weight_key + ) + except nx.NetworkXNoPath: + print(f"No path found from source to sink at {_i}th iteration") + break + + if path[1:-1] in paths: + print("Duplicate path found!") + continue + + paths.append(path[1:-1]) + + for u, v in zip(path[:-1], path[1:]): + edge_reuse[(u, v)] += 1 + + return paths diff --git a/trackers/core/ksp/tracker.py b/trackers/core/ksp/tracker.py new file mode 100644 index 00000000..0a2cafae --- /dev/null +++ b/trackers/core/ksp/tracker.py @@ -0,0 +1,308 @@ +import os +from collections import defaultdict +from typing import Callable, List, Optional, Set, Tuple, Union + +import cv2 +import numpy as np +import PIL +import supervision as sv +from tqdm.auto import tqdm + +from trackers.core.base import BaseOfflineTracker +from trackers.core.ksp.solver import KSPSolver, TrackNode + + +class KSPTracker(BaseOfflineTracker): + """ + Offline tracker using K-Shortest Paths (KSP) algorithm. + """ + + def __init__( + self, + path_overlap_penalty: float = 40, + iou_weight: float = 0.9, + dist_weight: float = 0.1, + size_weight: float = 0.1, + conf_weight: float = 0.1, + entry_weight: float = 2.0, + exit_weight: float = 2.0, + entry_exit_regions: Optional[List[Tuple[int, int, int, int]]] = None, + use_border: bool = True, + borders: Optional[Set[str]] = None, + border_margin: int = 40, + frame_size: Tuple[int, int] = (1920, 1080), + ) -> None: + """ + Initialize the KSPTracker and its underlying solver with region and cost + configuration. + + Args: + path_overlap_penalty (float): Penalty for reusing the same edge + (detection pairing) in multiple tracks. Higher values encourage the + tracker to produce more distinct, non-overlapping tracks by + discouraging shared detections between tracks. Default is 40. + iou_weight (float): Weight for the IoU penalty in the edge cost. + Higher values make the tracker favor linking detections with greater + spatial overlap, which helps maintain track continuity for objects + that move smoothly. Default is 0.9. + dist_weight (float): Weight for the Euclidean distance between + detection centers in the edge cost. Increasing this value penalizes + large jumps between detections in consecutive frames, promoting + smoother, more physically plausible tracks. Default is 0.1. + size_weight (float): Weight for the size difference penalty in + the edge cost. Higher values penalize linking detections with + significantly different bounding box areas, which helps prevent + identity switches when object size changes abruptly. Default is 0.1. + conf_weight (float): Weight for the confidence penalty in the + edge cost. Higher values penalize edges between detections with lower + confidence scores, making the tracker prefer more reliable detections + and reducing the impact of false positives. Default is 0.1. + entry_weight (float): Weight for entry node connections in the graph. + Higher values make the tracker more conservative about creating new tracks + when objects appear. Default is 2.0. + exit_weight (float): Weight for exit node connections in the graph. + Higher values make the tracker more conservative about ending tracks + when objects disappear. Default is 2.0. + entry_exit_regions (Optional[List[Tuple[int, int, int, int]]]): List of + rectangular entry/exit regions, each as (x1, y1, x2, y2) in pixels. + Used to determine when objects enter or exit the scene. Default is + an empty list. + use_border (bool): Whether to enable border-based entry/exit + logic. If True, objects entering or exiting through the image borders + (as defined by `borders` and `border_margin`) are considered for + entry/exit events. Default is True. + borders (Optional[Set[str]]): Set of border sides to use for entry/exit + logic. Valid values are any subset of {"left", "right", "top", + "bottom"}. Default is all four borders. + border_margin (int): Thickness of the border region (in pixels) + used for entry/exit detection. Default is 40. + frame_size (Tuple[int, int]): Size of the image frames as + (width, height). Used to determine border region extents. Default is + (1920, 1080). + """ + self.entry_exit_regions: List[Tuple[int, int, int, int]] = ( + entry_exit_regions if entry_exit_regions is not None else [] + ) + self.use_border: bool = use_border if use_border is not None else True + self.borders: Set[str] = ( + borders if borders is not None else {"left", "right", "top", "bottom"} + ) + self.border_margin: int = border_margin if border_margin is not None else 40 + self.frame_size: Tuple[int, int] = ( + frame_size if frame_size is not None else (1920, 1080) + ) + self._solver = KSPSolver( + path_overlap_penalty=path_overlap_penalty, + iou_weight=iou_weight, + dist_weight=dist_weight, + size_weight=size_weight, + conf_weight=conf_weight, + entry_weight=entry_weight, + exit_weight=exit_weight, + ) + self._solver.set_entry_exit_regions(self.entry_exit_regions) + self._solver.set_border_entry_exit( + use_border=self.use_border, + borders=self.borders, + margin=self.border_margin, + frame_size=self.frame_size, + ) + self.reset() + + def reset(self) -> None: + """ + Reset the KSPTracker and its solver state. + + This clears all buffered detections and resets the underlying solver. + """ + self._solver.reset() + + def _update(self, detections: sv.Detections) -> sv.Detections: + """ + Add detections for the current frame to the solver. + + Args: + detections (sv.Detections): Detections for the current frame. + + Returns: + sv.Detections: The same detections passed in. + """ + self._solver.append_frame(detections) + return detections + + def set_entry_exit_regions(self, regions: List[Tuple[int, int, int, int]]) -> None: + """ + Set rectangular entry/exit zones (x1, y1, x2, y2) and update both the + tracker and solver. + + Args: + regions (List[Tuple[int, int, int, int]]): List of rectangular + regions for entry/exit logic. + """ + self.entry_exit_regions = regions + self._solver.set_entry_exit_regions(regions) + + def set_border_entry_exit( + self, + use_border: bool = True, + borders: Optional[Set[str]] = None, + margin: int = 40, + frame_size: Tuple[int, int] = (1920, 1080), + ) -> None: + """ + Configure border-based entry/exit zones and update both the tracker and + solver. + + Args: + use_border (bool): Enable/disable border-based entry/exit. + borders (Optional[Set[str]]): Set of borders to use. Each value should + be one of "left", "right", "top", "bottom". + margin (int): Border thickness in pixels. + frame_size (Tuple[int, int]): Size of the image (width, height). + """ + self.use_border = use_border + self.borders = ( + borders if borders is not None else {"left", "right", "top", "bottom"} + ) + self.border_margin = margin + self.frame_size = frame_size + self._solver.set_border_entry_exit( + use_border=self.use_border, + borders=self.borders, + margin=self.border_margin, + frame_size=self.frame_size, + ) + + def _assign_tracker_ids_from_paths( + self, paths: List[List[TrackNode]] + ) -> List[sv.Detections]: + """ + Assigns each detection a unique tracker ID directly from node-disjoint paths. + + Args: + paths (List[List[TrackNode]]): List of tracks, each a list of TrackNode. + + Returns: + List[sv.Detections]: List of sv.Detections with tracker IDs assigned + for each frame. + """ + # Map from frame to list of dicts with detection info + tracker_id + frame_to_dets = defaultdict(list) + + # Assign each node a unique tracker ID (path index + 1) + for tracker_id, path in enumerate(paths, start=1): + for node in path: + frame_to_dets[node.frame_id].append( + { + "xyxy": node.bbox, + "confidence": node.confidence, + "class_id": node.class_id, + "tracker_id": tracker_id, + } + ) + + # Convert detections per frame into sv.Detections objects + frame_to_detections = [] + for frame in sorted(frame_to_dets.keys()): + dets_list = frame_to_dets[frame] + xyxy = np.array([d["xyxy"] for d in dets_list], dtype=np.float32) + confidence = np.array( + [d["confidence"] for d in dets_list], dtype=np.float32 + ) + class_id = np.array([d["class_id"] for d in dets_list], dtype=int) + tracker_id = np.array([d["tracker_id"] for d in dets_list], dtype=int) + + detections = sv.Detections( + xyxy=xyxy, + confidence=confidence, + class_id=class_id, + tracker_id=tracker_id, + ) + frame_to_detections.append(detections) + + return frame_to_detections + + def track( + self, + source: Union[str, List[PIL.Image.Image]], + get_model_detections: Callable[[np.ndarray], sv.Detections], + num_of_tracks: Optional[int] = None, + ) -> List[sv.Detections]: + """ + Run the KSP solver and assign tracker IDs to detections. + + Args: + source_path (str): Path to video file or directory of frames. + get_model_detections (Callable[[np.ndarray], sv.Detections]): + Function that takes an image (np.ndarray) and returns sv.Detections. + num_of_tracks (Optional[int]): Number of tracks to extract (K). + + Returns: + List[sv.Detections]: List of sv.Detections with tracker IDs assigned. + """ + if not source: + raise ValueError( + "`source_path` must be a string path to a directory or an .mp4 file." + ) + if get_model_detections is None: + raise TypeError( + "`get_model_detections` must be a callable that returns an " + "instance of `sv.Detections`." + ) + if isinstance(source, str) and source.lower().endswith(".mp4"): + frames_generator = sv.get_video_frames_generator(source_path=source) + video_info = sv.VideoInfo.from_video_path(video_path=source) + + self._solver.set_border_entry_exit( + self.use_border, + self.borders, + self.border_margin, + (video_info.width, video_info.height), + ) + + for frame in tqdm( + frames_generator, + total=video_info.total_frames, + desc="Extracting detections and buffering from video", + dynamic_ncols=True, + ): + detections = get_model_detections(frame) + self._update(detections) + elif isinstance(source, str) and os.path.isdir(source): + frame_paths = sorted( + [ + os.path.join(source, f) + for f in os.listdir(source) + if f.lower().endswith(".jpg") + ] + ) + + has_set_frame_size = False + + for frame_path in tqdm( + frame_paths, + desc="Extracting detections and buffering directory", + dynamic_ncols=True, + ): + image = cv2.imread(frame_path) + height, width = image.shape[:2] + + if not has_set_frame_size: + self._solver.set_border_entry_exit( + self.use_border, + self.borders, + self.border_margin, + (width, height), + ) + + detections = get_model_detections(image) + self._update(detections) + else: + raise ValueError(f"{source} not a valid path or list of PIL.Image.Image.") + paths = self._solver.solve(num_of_tracks) + + if not paths: + return [] + return self._assign_tracker_ids_from_paths( + paths + ), self._solver.detection_per_frame diff --git a/uv.lock b/uv.lock index 3afc9958..bcea71da 100644 --- a/uv.lock +++ b/uv.lock @@ -6635,6 +6635,10 @@ cu126 = [ { name = "torchvision", version = "0.22.1", source = { registry = "https://pypi.org/simple" }, marker = "(sys_platform == 'darwin' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu118') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-rocm61' and extra == 'extra-8-trackers-rocm624')" }, { name = "torchvision", version = "0.22.1+cu126", source = { registry = "https://download.pytorch.org/whl/cu126" }, marker = "(sys_platform != 'darwin' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu118') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-rocm61' and extra == 'extra-8-trackers-rocm624')" }, ] +ksptracker = [ + { name = "networkx", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10' or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu118') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-rocm61' and extra == 'extra-8-trackers-rocm624')" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10' or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu118') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-rocm61' and extra == 'extra-8-trackers-rocm624')" }, +] metrics = [ { name = "matplotlib", version = "3.9.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10' or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu118') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-rocm61' and extra == 'extra-8-trackers-rocm624')" }, { name = "matplotlib", version = "3.10.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10' or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu118') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cpu' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu124') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu118' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-cu126') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu124' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm61') or (extra == 'extra-8-trackers-cu126' and extra == 'extra-8-trackers-rocm624') or (extra == 'extra-8-trackers-rocm61' and extra == 'extra-8-trackers-rocm624')" }, @@ -6700,6 +6704,7 @@ requires-dist = [ { name = "aiofiles", specifier = ">=24.1.0" }, { name = "aiohttp", specifier = ">=3.11.16" }, { name = "matplotlib", marker = "extra == 'metrics'", specifier = ">=3.9.4" }, + { name = "networkx", marker = "extra == 'ksptracker'", specifier = ">=3.2.1" }, { name = "numpy", specifier = ">=2.0.2" }, { name = "pytorch-triton-rocm", marker = "sys_platform == 'darwin' and extra == 'rocm61'", specifier = ">=2.0.0" }, { name = "pytorch-triton-rocm", marker = "sys_platform == 'darwin' and extra == 'rocm624'", specifier = ">=2.0.0" }, @@ -6736,7 +6741,7 @@ requires-dist = [ { name = "validators", specifier = ">=0.34.0" }, { name = "wandb", marker = "extra == 'metrics'", specifier = ">=0.19.11" }, ] -provides-extras = ["cpu", "cu126", "cu124", "cu118", "rocm61", "rocm624", "reid", "metrics"] +provides-extras = ["cpu", "cu126", "cu124", "cu118", "rocm61", "rocm624", "reid", "metrics", "ksptracker"] [package.metadata.requires-dev] build = [