diff --git a/test/core/test_tracker_integration.py b/test/core/test_tracker_integration.py index 31c4688b..f1427e33 100644 --- a/test/core/test_tracker_integration.py +++ b/test/core/test_tracker_integration.py @@ -70,6 +70,7 @@ def _run_tracker_on_flat_dataset( tracked = tracker.update(detections) if tracked.tracker_id is not None: mature = tracked[tracked.tracker_id != -1] + assert isinstance(mature, sv.Detections) mot.write(frame_idx, mature) else: mot.write(frame_idx, tracked) @@ -110,6 +111,9 @@ def test_tracker_regression( ) aggregate = result.aggregate + assert aggregate.HOTA is not None + assert aggregate.CLEAR is not None + assert aggregate.Identity is not None assert aggregate.HOTA.HOTA * 100 == pytest.approx(expected["HOTA"], abs=0.001) assert aggregate.CLEAR.MOTA * 100 == pytest.approx(expected["MOTA"], abs=0.001) assert aggregate.Identity.IDF1 * 100 == pytest.approx(expected["IDF1"], abs=0.001) diff --git a/trackers/core/bytetrack/kalman.py b/trackers/core/bytetrack/kalman.py index c091e067..ea9cfc40 100644 --- a/trackers/core/bytetrack/kalman.py +++ b/trackers/core/bytetrack/kalman.py @@ -51,7 +51,7 @@ def get_next_tracker_id(cls) -> int: cls.count_id += 1 return next_id - def __init__(self, bbox: np.ndarray): + def __init__(self, bbox: np.ndarray) -> None: # Initialize with a temporary ID of -1 # Will be assigned a real ID when the track is considered mature self.tracker_id = -1 diff --git a/trackers/eval/__init__.py b/trackers/eval/__init__.py index 401aec62..254f9286 100644 --- a/trackers/eval/__init__.py +++ b/trackers/eval/__init__.py @@ -8,6 +8,7 @@ from trackers.eval.box import box_ioa, box_iou from trackers.eval.clear import aggregate_clear_metrics, compute_clear_metrics +from trackers.eval.evaluate import evaluate_mot_sequence, evaluate_mot_sequences from trackers.eval.hota import aggregate_hota_metrics, compute_hota_metrics from trackers.eval.identity import aggregate_identity_metrics, compute_identity_metrics from trackers.eval.results import ( @@ -19,7 +20,7 @@ ) -def __getattr__(name: str): +def __getattr__(name: str) -> object: """Lazy imports for evaluate functions to avoid circular imports.""" if name in ("evaluate_mot_sequence", "evaluate_mot_sequences"): from trackers.eval import evaluate as _evaluate diff --git a/trackers/io/mot.py b/trackers/io/mot.py index 9f849c0b..bc6043ab 100644 --- a/trackers/io/mot.py +++ b/trackers/io/mot.py @@ -9,6 +9,7 @@ import csv from dataclasses import dataclass from pathlib import Path +from typing import TextIO import numpy as np import supervision as sv @@ -418,9 +419,9 @@ def _prepare_mot_sequence( class _MOTOutput: """Context manager for MOT format file writing.""" - def __init__(self, path: Path | None): + def __init__(self, path: Path | None) -> None: self.path = path - self._file = None + self._file: TextIO | None = None def write(self, frame_idx: int, detections: sv.Detections) -> None: """Write detections for a frame in MOT format.""" @@ -447,12 +448,12 @@ def write(self, frame_idx: int, detections: sv.Detections) -> None: f"{conf:.4f},-1,-1,-1\n" ) - def __enter__(self): + def __enter__(self) -> _MOTOutput: if self.path is not None: self.path.parent.mkdir(parents=True, exist_ok=True) self._file = open(self.path, "w") return self - def __exit__(self, *_): + def __exit__(self, *_) -> None: if self._file is not None: self._file.close() diff --git a/trackers/io/video.py b/trackers/io/video.py index 51a15000..2700f545 100644 --- a/trackers/io/video.py +++ b/trackers/io/video.py @@ -90,7 +90,7 @@ def _iter_image_folder_frames( class _VideoOutput: """Context manager for lazy video file writing.""" - def __init__(self, path: Path | None, *, fps: float = _DEFAULT_OUTPUT_FPS): + def __init__(self, path: Path | None, *, fps: float = _DEFAULT_OUTPUT_FPS) -> None: self.path = path self.fps = fps self._writer: cv2.VideoWriter | None = None @@ -137,7 +137,7 @@ def __exit__(self, *_: object) -> None: class _DisplayWindow: """Context manager for OpenCV display window with resizable output.""" - def __init__(self, window_name: str = "Tracking"): + def __init__(self, window_name: str = "Tracking") -> None: self.window_name = window_name self._quit_requested = False cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) @@ -159,8 +159,8 @@ def quit_requested(self) -> bool: """Return True if user pressed quit key.""" return self._quit_requested - def __enter__(self): + def __enter__(self) -> _DisplayWindow: return self - def __exit__(self, *_): + def __exit__(self, *_) -> None: cv2.destroyWindow(self.window_name) diff --git a/trackers/scripts/track.py b/trackers/scripts/track.py index 31265cf5..972bdeaf 100644 --- a/trackers/scripts/track.py +++ b/trackers/scripts/track.py @@ -11,6 +11,7 @@ import sys from contextlib import nullcontext from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import supervision as sv @@ -23,6 +24,9 @@ from trackers.scripts.progress import _classify_source, _SourceInfo, _TrackingProgress from trackers.utils.device import _best_device +if TYPE_CHECKING: + from inference_models import AnyModel + # Defaults DEFAULT_MODEL = "rfdetr-nano" DEFAULT_TRACKER = "bytetrack" @@ -555,7 +559,7 @@ def _init_model( *, device: str = DEFAULT_DEVICE, api_key: str | None = None, -): +) -> AnyModel: """Load detection model via inference-models. Args: @@ -585,7 +589,7 @@ def _init_model( ) -def _run_model(model, frame: np.ndarray, confidence: float) -> sv.Detections: +def _run_model(model: AnyModel, frame: np.ndarray, confidence: float) -> sv.Detections: """Run model inference and return sv.Detections.""" predictions = model(frame) if not predictions: @@ -627,7 +631,7 @@ def _extract_tracker_params( return params -def _init_tracker(tracker_id: str, **kwargs) -> BaseTracker: +def _init_tracker(tracker_id: str, **kwargs: object) -> BaseTracker: """Create tracker instance from registry. Args: