-
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcaptionizer.py
More file actions
101 lines (85 loc) · 3.48 KB
/
captionizer.py
File metadata and controls
101 lines (85 loc) · 3.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations
from timef.schema import CaptionResult
from transformer import Transformer
from annotator import Annotator
from reviewer import EvaluationResult, Reviewer
class Captionizer:
    """Drives the captioning pipeline over a dataset.

    For each row: transform the raw record, attach extractor annotations,
    and collect the result. Optionally evaluates the collected captions
    with a reviewer.
    """

    def __init__(
        self,
        dataset,
        transformer: Transformer,
        annotator: Annotator,
        reviewer: Reviewer | None = None,
    ) -> None:
        # dataset: any indexable collection supporting len() and __getitem__
        # (its element type is whatever `transformer.transform_row` accepts).
        self.dataset = dataset
        self.transformer = transformer
        self.annotator = annotator
        self.reviewer = reviewer

    def run(
        self, max_rows: int | None = None,
    ) -> tuple[CaptionResult, EvaluationResult | None]:
        """Caption up to ``max_rows`` rows (all rows when ``None``).

        Rows that raise during transform/annotate are skipped and reported
        (best-effort, not fatal). Returns the accumulated ``CaptionResult``
        and, when a reviewer was supplied, its ``EvaluationResult``
        (otherwise ``None``).
        """
        result = CaptionResult()
        failed_rows: list[int] = []
        # Compute the attempt count up front instead of breaking mid-loop.
        # BUG FIX: the original reported ``i + 1`` attempts in the failure
        # message, which over-counted by one whenever the loop exited via
        # the ``max_rows`` break (break fires at i == max_rows, after only
        # ``max_rows`` rows were attempted).
        attempted = len(self.dataset)
        if max_rows is not None:
            attempted = min(attempted, max_rows)
        for i in range(attempted):
            try:
                row = self.transformer.transform_row(self.dataset[i])
                row.annotations.extend(self.annotator.annotate(row))
                result.rows.append(row)
            except Exception:
                # Best-effort: record the failure and keep going.
                failed_rows.append(i)
        if failed_rows:
            print(f"{len(failed_rows)} / {attempted} rows failed during captioning: {failed_rows}")
        evaluation = self.reviewer.evaluate(result) if self.reviewer else None
        return result, evaluation
if __name__ == "__main__":
    # Demo: caption MHC rows, filter to densely-populated ones, have an
    # LLM reviewer score them, and save a few sample plots.
    from mhc.dataset import MHCDataset
    from mhc.transformer import MHCTransformer
    from mhc.constants import MHC_CHANNEL_CONFIG
    from extractors.cross_channel import CrossChannelExtractor
    from synthesizers.sleep import SleepSynthesizer
    from extractors.statistical import StatisticalExtractor
    from extractors.structural import StructuralExtractor
    from models.local import LocalConfig, LocalModel
    from models.client import ClientConfig, ClientModel
    from extractors.semantic import SemanticExtractor
    from visualizer import plot_row
    import numpy as np

    mhc_dataset = MHCDataset(min_wear_pct=0.0)
    extractors = [
        StatisticalExtractor(MHC_CHANNEL_CONFIG),
        StructuralExtractor(MHC_CHANNEL_CONFIG),
        SemanticExtractor(MHC_CHANNEL_CONFIG),
        CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]),
    ]
    mhc_annotator = Annotator(extractors)
    pipeline = Captionizer(mhc_dataset, MHCTransformer(), mhc_annotator)
    print(f"Dataset size: {len(mhc_dataset)}")

    caption_result, _ = pipeline.run(max_rows=1)
    print(f"Rows: {len(caption_result.rows)}")
    print(f"Annotations: {sum(len(r.annotations) for r in caption_result.rows)}")

    # Keep only rows with at least 9 active channels.
    dense_rows = [r for r in caption_result.iter_rows() if r.active_channel_count() >= 9]
    dense = CaptionResult(rows=dense_rows)
    print(f"Filtered rows: {len(dense.rows)}")

    client_config = ClientConfig(
        base_url="https://api.openai.com/v1",
        model="gpt-5.4",
    )
    llm = ClientModel(client_config, MHC_CHANNEL_CONFIG)
    caption_reviewer = Reviewer(llm)
    review = caption_reviewer.evaluate(dense)
    print(f"Evaluation: {len(review.scores)} scores, mean={review.mean_score}")
    for idx, entry in enumerate(review.scores):
        print(f" [{idx}] score={entry.score} feedback={entry.feedback}")

    # Save plots for at most the first 4 filtered rows.
    for plot_idx, sample in enumerate(dense.iter_rows()):
        plot_row(sample, save_path=f"sample_plot_{plot_idx}.png")
        print(f"Saved sample_plot_{plot_idx}.png")
        if plot_idx >= 3:
            break