-
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcaptionizer.py
More file actions
110 lines (94 loc) · 3.8 KB
/
captionizer.py
File metadata and controls
110 lines (94 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations
from timef.schema import CaptionResult
from transformer import Transformer
from annotator import Annotator
from reviewer import EvaluationResult, Reviewer
class Captionizer:
    """Caption pipeline driver: transforms each dataset row, annotates it,
    and optionally evaluates the aggregate result with a reviewer."""

    def __init__(
        self,
        dataset,
        transformer: Transformer,
        annotator: Annotator,
        reviewer: Reviewer | None = None,
    ) -> None:
        # dataset is indexable and sized (supports __getitem__/__len__);
        # concrete type is project-defined — e.g. MHCDataset in __main__.
        self.dataset = dataset
        self.transformer = transformer
        self.annotator = annotator
        self.reviewer = reviewer

    def run(
        self, max_rows: int | None = None,
    ) -> tuple[CaptionResult, EvaluationResult | None]:
        """Caption up to ``max_rows`` rows (all rows when ``None``).

        Returns the accumulated :class:`CaptionResult` and, when a reviewer
        was supplied, its :class:`EvaluationResult` (otherwise ``None``).
        Rows that raise during transform/annotate are skipped and summarized,
        not fatal.
        """
        result = CaptionResult()
        failed_rows: list[int] = []
        # Bound the iteration up front so the failure summary reports the
        # number of rows actually attempted. (The previous `i + 1` was off
        # by one whenever the loop exited via the max_rows break, and `i`
        # was referenced after a possibly zero-iteration loop.)
        total = len(self.dataset)
        if max_rows is not None:
            total = min(total, max_rows)
        for i in range(total):
            try:
                row = self.transformer.transform_row(self.dataset[i])
                row.annotations.extend(self.annotator.annotate(row))
                result.rows.append(row)
            except Exception:
                # Best-effort by design: one bad row must not abort the run.
                failed_rows.append(i)
        if failed_rows:
            print(f"{len(failed_rows)} / {total} rows failed during captioning: {failed_rows}")
        evaluation = self.reviewer.evaluate(result) if self.reviewer else None
        return result, evaluation
if __name__ == "__main__":
    # Demo entry point: caption one MHC dataset row, filter for rows with
    # enough active channels, review them with an OpenAI-backed model, and
    # plot a few samples.
    # NOTE: removed unused imports (`numpy`, `models.local.LocalConfig` /
    # `LocalModel`) — nothing in this script referenced them.
    from mhc.dataset import MHCDataset
    from mhc.transformer import MHCTransformer
    from mhc.constants import MHC_CHANNEL_CONFIG
    from extractors.cross_channel import CrossChannelExtractor
    from extractors.semantic import SemanticExtractor
    from extractors.statistical import StatisticalExtractor
    from extractors.structural import StructuralExtractor
    from models.client import ClientConfig, ClientModel
    from synthesizers.cardio import CardioSynthesizer
    from synthesizers.sleep import SleepSynthesizer
    from synthesizers.stationary_activity import StationaryActivitySynthesizer
    from visualizer import plot_row

    dataset = MHCDataset(min_wear_pct=0.0)
    annotator = Annotator([
        StatisticalExtractor(MHC_CHANNEL_CONFIG),
        StructuralExtractor(MHC_CHANNEL_CONFIG),
        SemanticExtractor(MHC_CHANNEL_CONFIG),
        CrossChannelExtractor(
            MHC_CHANNEL_CONFIG,
            synthesizers=[
                SleepSynthesizer(min_duration=5),
                StationaryActivitySynthesizer(min_duration=5),
                CardioSynthesizer(min_duration=5),
            ],
        ),
    ])
    captionizer = Captionizer(dataset, MHCTransformer(), annotator)
    print(f"Dataset size: {len(dataset)}")
    result, _ = captionizer.run(max_rows=1)
    print(f"Rows: {len(result.rows)}")
    print(f"Annotations: {sum(len(row.annotations) for row in result.rows)}")
    # Keep only rows with at least 9 active channels for review/plotting.
    filtered = CaptionResult(rows=[row for row in result.iter_rows() if row.active_channel_count() >= 9])
    print(f"Filtered rows: {len(filtered.rows)}")
    model = ClientModel(
        ClientConfig(
            base_url="https://api.openai.com/v1",
            model="gpt-5.4",
        ),
        MHC_CHANNEL_CONFIG,
    )
    reviewer = Reviewer(model)
    evaluation = reviewer.evaluate(filtered)
    print(f"Evaluation: {len(evaluation.scores)} scores, mean={evaluation.mean_score}")
    for j, s in enumerate(evaluation.scores):
        print(f"  [{j}] score={s.score} feedback={s.feedback}")
    # Plot at most the first four filtered rows.
    for i, row in enumerate(filtered.iter_rows()):
        plot_row(row, save_path=f"sample_plot_{i}.png")
        print(f"Saved sample_plot_{i}.png")
        if i >= 3:
            break