-
Notifications
You must be signed in to change notification settings - Fork 341
Expand file tree
/
Copy pathtest_tracker_integration.py
More file actions
120 lines (98 loc) · 3.88 KB
/
test_tracker_integration.py
File metadata and controls
120 lines (98 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# ------------------------------------------------------------------------
# Trackers
# Copyright (c) 2026 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import pytest
import supervision as sv
from trackers.core.base import BaseTracker
from trackers.eval import evaluate_mot_sequences
from trackers.io.mot import _load_mot_file, _mot_frame_to_detections, _MOTOutput
# Tracker ids the regression test is parametrized over; each id is resolved
# through `BaseTracker._lookup_tracker` when the tracker is run.
_TRACKER_IDS = ["sort", "bytetrack", "ocsort"]
# Metric families passed to `evaluate_mot_sequences` via its `metrics` argument.
_METRICS = ["CLEAR", "HOTA", "Identity"]
# `data` directory next to this file's parent directory; holds the
# `tracker_expected_<dataset>.json` reference files read by `_load_expected`.
_TEST_DATA_DIR = Path(__file__).resolve().parent.parent / "data"
def _load_expected(dataset: str) -> dict[str, Any]:
    """Load TrackEval-derived reference metrics for a dataset.

    Args:
        dataset: Dataset key used to build the file name
            `tracker_expected_{dataset}.json` inside `_TEST_DATA_DIR`.

    Returns:
        The parsed JSON content of the reference file.

    Raises:
        FileNotFoundError: If no reference file exists for `dataset`.
    """
    path = _TEST_DATA_DIR / f"tracker_expected_{dataset}.json"
    # JSON is UTF-8 by specification; pin the encoding so the result does not
    # depend on the platform's locale default.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def _run_tracker_on_flat_dataset(
    tracker_id: str,
    data_path: Path,
    output_dir: Path,
    seqmap_path: Path,
) -> None:
    """Run a tracker on GT-derived detections and save flat MOT output files.

    For every sequence listed in the seqmap, the ground-truth file
    `data_path / "gt" / "<sequence>.txt"` is replayed frame by frame through a
    fresh tracker instance and the result is written to
    `output_dir / "<sequence>.txt"`. Sequences whose ground-truth file does not
    exist are skipped.

    Args:
        tracker_id: Id used to look the tracker up via
            `BaseTracker._lookup_tracker`.
        data_path: Dataset root; ground truth is read from its `gt` subfolder.
        output_dir: Directory for the per-sequence output files; created
            (including parents) when missing.
        seqmap_path: Text file with one sequence name per line. Blank lines
            and a line equal to `name` are ignored.

    Raises:
        AssertionError: If `tracker_id` is not a known tracker.
    """
    import trackers as _trackers  # noqa: F401 - triggers registration

    info = BaseTracker._lookup_tracker(tracker_id)
    assert info is not None, f"Unknown tracker: {tracker_id}"

    output_dir.mkdir(parents=True, exist_ok=True)
    gt_dir = data_path / "gt"

    # Pin the encoding so parsing does not depend on the platform locale, and
    # strip each line only once.
    with open(seqmap_path, encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        sequences = [name for name in stripped_lines if name and name != "name"]

    for seq_name in sequences:
        gt_file = gt_dir / f"{seq_name}.txt"
        if not gt_file.exists():
            continue

        gt_data = _load_mot_file(gt_file)
        # An empty ground-truth file yields no frames to replay.
        max_frame = max(gt_data, default=0)

        # Each sequence gets its own tracker so no state leaks across sequences.
        tracker = info.tracker_class()
        mot_path = output_dir / f"{seq_name}.txt"

        with _MOTOutput(mot_path) as mot:
            # Visit every frame number from 1 to the last annotated frame, so
            # the tracker is also updated on frames without annotations.
            for frame_idx in range(1, max_frame + 1):
                if frame_idx in gt_data:
                    detections = _mot_frame_to_detections(gt_data[frame_idx])
                else:
                    detections = sv.Detections.empty()

                tracked = tracker.update(detections)

                to_write = tracked
                if tracked.tracker_id is not None:
                    # Drop rows with id -1 (presumably tracks that are not yet
                    # confirmed - verify against the tracker implementations).
                    mature = tracked[tracked.tracker_id != -1]
                    assert isinstance(mature, sv.Detections)
                    to_write = mature
                mot.write(frame_idx, to_write)
@pytest.mark.integration
@pytest.mark.parametrize("tracker_id", _TRACKER_IDS)
@pytest.mark.parametrize(
    "dataset, fixture_name",
    [
        ("sportsmot", "sportsmot_flat_data"),
        ("dancetrack", "dancetrack_flat_data"),
    ],
)
def test_tracker_regression(
    tracker_id: str,
    dataset: str,
    fixture_name: str,
    request: pytest.FixtureRequest,
    tmp_path: Path,
) -> None:
    """Compare a tracker's aggregate metrics with the stored reference values.

    The tracker is run on the dataset provided by the named fixture, its
    output is evaluated with `evaluate_mot_sequences`, and the aggregate HOTA,
    MOTA and IDF1 (scaled by 100) must match the reference JSON within an
    absolute tolerance of 0.001. IDSW must match exactly.
    """
    # The fixture yields a pair; only its first element (the data root) is used.
    dataset_root, _ = request.getfixturevalue(fixture_name)
    reference = _load_expected(dataset)[tracker_id]

    seqmap_file = dataset_root / "seqmap.txt"
    predictions_dir = tmp_path / "tracker_output"

    _run_tracker_on_flat_dataset(
        tracker_id=tracker_id,
        data_path=dataset_root,
        output_dir=predictions_dir,
        seqmap_path=seqmap_file,
    )

    evaluation = evaluate_mot_sequences(
        gt_dir=dataset_root / "gt",
        tracker_dir=predictions_dir,
        seqmap=seqmap_file,
        metrics=_METRICS,
    )
    summary = evaluation.aggregate

    # Every requested metric family must be present before reading its fields.
    assert summary.HOTA is not None
    assert summary.CLEAR is not None
    assert summary.Identity is not None

    # Metrics are scaled by 100 before comparison with the reference values.
    tolerance = 0.001
    assert summary.HOTA.HOTA * 100 == pytest.approx(reference["HOTA"], abs=tolerance)
    assert summary.CLEAR.MOTA * 100 == pytest.approx(reference["MOTA"], abs=tolerance)
    assert summary.Identity.IDF1 * 100 == pytest.approx(
        reference["IDF1"], abs=tolerance
    )
    assert summary.CLEAR.IDSW == reference["IDSW"]