Skip to content

Commit e457736

Browse files
Refactor CrossChannelExtractor into driver with pluggable synthesizers
1 parent 71cb0b8 commit e457736

File tree

7 files changed

+105
-46
lines changed

7 files changed

+105
-46
lines changed

captionizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def run(
5454
from mhc.transformer import MHCTransformer
5555
from mhc.constants import MHC_CHANNEL_CONFIG
5656
from extractors.cross_channel import CrossChannelExtractor
57+
from synthesizers.sleep import SleepSynthesizer
5758
from extractors.statistical import StatisticalExtractor
5859
from extractors.structural import StructuralExtractor
5960
from models.local import LocalConfig, LocalModel
@@ -67,7 +68,7 @@ def run(
6768
StatisticalExtractor(MHC_CHANNEL_CONFIG),
6869
StructuralExtractor(MHC_CHANNEL_CONFIG),
6970
SemanticExtractor(MHC_CHANNEL_CONFIG),
70-
CrossChannelExtractor(MHC_CHANNEL_CONFIG),
71+
CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]),
7172
])
7273

7374
captionizer = Captionizer(dataset, MHCTransformer(), annotator)

explorer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from annotator import Annotator
1919
from extractors import ChannelConfig
2020
from extractors.cross_channel import CrossChannelExtractor
21+
from synthesizers.sleep import SleepSynthesizer
2122
from extractors.statistical import StatisticalExtractor
2223
from extractors.structural import StructuralExtractor
2324
from mhc.constants import MHC_CHANNEL_CONFIG
@@ -91,7 +92,7 @@ def __init__(
9192
self.annotator = Annotator([
9293
StatisticalExtractor(channel_config),
9394
StructuralExtractor(channel_config),
94-
CrossChannelExtractor(channel_config),
95+
CrossChannelExtractor(channel_config, synthesizers=[SleepSynthesizer()]),
9596
])
9697

9798
self.row_index = min(max(0, row_index), len(self.dataset) - 1)

extractors/cross_channel.py

Lines changed: 5 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,57 +6,20 @@
66
#
77
from __future__ import annotations
88

9-
import numpy as np
10-
119
from extractors import CaptionExtractor, ChannelConfig
10+
from synthesizers import CrossChannelSynthesizer
1211
from timef.schema import Annotation, Recording
1312

1413

1514
class CrossChannelExtractor(CaptionExtractor):
1615
caption_type = "cross_channel"
1716

18-
def __init__(self, config: ChannelConfig, min_duration: int = 5):
17+
def __init__(self, config: ChannelConfig, synthesizers: list[CrossChannelSynthesizer]):
1918
super().__init__(config)
20-
self.min_duration = min_duration
19+
self.synthesizers = synthesizers
2120

2221
def extract(self, row: Recording) -> list[Annotation]:
23-
try:
24-
in_bed_idx = row.channel_names.index("sleep:inbed")
25-
asleep_idx = row.channel_names.index("sleep:asleep")
26-
except ValueError:
27-
return []
28-
29-
in_bed = np.asarray(row.values[in_bed_idx], dtype=float)
30-
asleep = np.asarray(row.values[asleep_idx], dtype=float)
31-
if not np.any((~np.isnan(asleep)) & (asleep > 0)):
32-
return []
33-
mask = (~np.isnan(in_bed)) & (in_bed > 0) & ~((~np.isnan(asleep)) & (asleep > 0))
34-
3522
results: list[Annotation] = []
36-
for start, end in _contiguous_windows(mask, self.min_duration):
37-
end_inclusive = max(start, end - 1)
38-
results.append(
39-
Annotation(
40-
caption_type=self.caption_type,
41-
text=f"In bed but not sleeping from {self._time_label(start)} {start} to {self._time_label(end_inclusive)} {end_inclusive}.",
42-
channel_idxs=(asleep_idx, in_bed_idx),
43-
window=(start, end),
44-
label="in_bed_not_sleeping",
45-
)
46-
)
23+
for synth in self.synthesizers:
24+
results.extend(synth.synthesize(row, self.config))
4725
return results
48-
49-
def _time_label(self, value: int) -> str:
50-
return "hour" if self.config.time_unit == "hours" else "minute"
51-
52-
53-
def _contiguous_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]:
54-
if not mask.any():
55-
return []
56-
57-
padded = np.concatenate(([False], mask, [False]))
58-
diffs = np.diff(padded.astype(np.int8))
59-
starts = np.where(diffs == 1)[0]
60-
ends = np.where(diffs == -1)[0]
61-
keep = (ends - starts) >= min_duration
62-
return list(zip(starts[keep].tolist(), ends[keep].tolist()))

synthesizers/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#
2+
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
3+
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
4+
#
5+
# SPDX-License-Identifier: MIT
6+
#
7+
from __future__ import annotations
8+
9+
import abc
10+
11+
import numpy as np
12+
13+
from extractors import ChannelConfig
14+
from timef.schema import Annotation, Recording
15+
16+
17+
class CrossChannelSynthesizer(abc.ABC):
    """Interface for pluggable generators of cross-channel annotations.

    Concrete subclasses implement :meth:`synthesize`; the class cannot be
    instantiated until that method is overridden.
    """

    @abc.abstractmethod
    def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]:
        """Return the annotations derived from ``row`` under ``config``."""
20+
21+
22+
def contiguous_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]:
    """Return the half-open ``(start, end)`` index ranges of set runs in ``mask``.

    Args:
        mask: One-dimensional array-like; every non-zero element counts as set.
        min_duration: Minimum run length, in samples, for a run to be kept.

    Returns:
        ``(start, end)`` pairs in ascending order, where ``start`` is the first
        index of a run and ``end`` is one past its last index. Runs shorter
        than ``min_duration`` are dropped; an all-unset or empty mask yields
        an empty list.
    """
    # Coerce to bool so the edge detection below only ever sees steps of +1/-1.
    # A numeric mask holding e.g. 2 would otherwise produce a +2 step that the
    # ``== 1`` comparison silently misses.
    flags = np.asarray(mask, dtype=bool)
    if not flags.any():
        return []

    # Pad with False on both sides so runs touching either end still have a
    # rising and a falling edge.
    padded = np.concatenate(([False], flags, [False]))
    edges = np.diff(padded.astype(np.int8))
    starts = np.where(edges == 1)[0]
    ends = np.where(edges == -1)[0]
    keep = (ends - starts) >= min_duration
    return list(zip(starts[keep].tolist(), ends[keep].tolist()))

synthesizers/sleep.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#
2+
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
3+
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
4+
#
5+
# SPDX-License-Identifier: MIT
6+
#
7+
from __future__ import annotations
8+
9+
import json
10+
11+
import numpy as np
12+
13+
from extractors import CaptionExtractor, ChannelConfig
14+
from synthesizers import CrossChannelSynthesizer, contiguous_windows
15+
from timef.schema import Annotation, Recording
16+
17+
18+
class SleepSynthesizer(CrossChannelSynthesizer):
    """Emits "in bed but not sleeping" annotations from the two sleep channels.

    A sample counts as in-bed-awake when ``sleep:inbed`` is positive and
    ``sleep:asleep`` is not positive (NaN is treated as not positive on both
    channels). Runs of such samples at least ``min_duration`` long become one
    annotation each.
    """

    def __init__(self, min_duration: int = 5):
        # Minimum number of consecutive samples a run must span to be captioned.
        self.min_duration = min_duration
        # Parsed sleep templates keyed by templates path, so the JSON file is
        # read once per path instead of once per recording.
        self._templates_by_path: dict[str, list[str]] = {}

    def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]:
        """Return one annotation per qualifying in-bed-awake window of ``row``.

        Args:
            row: Recording whose ``channel_names`` / ``values`` are searched
                for the ``sleep:inbed`` and ``sleep:asleep`` channels.
            config: Supplies ``time_unit`` for the caption text and
                ``templates_path`` for the caption templates.

        Returns:
            An empty list when either channel is missing, when no sample is
            asleep at all, or when no window reaches ``min_duration``;
            otherwise the annotations in window order.

        Raises:
            ValueError: If the templates file has no sleep templates.
        """
        try:
            in_bed_idx = row.channel_names.index("sleep:inbed")
            asleep_idx = row.channel_names.index("sleep:asleep")
        except ValueError:
            return []

        in_bed = np.asarray(row.values[in_bed_idx], dtype=float)
        asleep = np.asarray(row.values[asleep_idx], dtype=float)
        is_asleep = (~np.isnan(asleep)) & (asleep > 0)
        # A recording with no asleep samples at all yields no annotations.
        if not np.any(is_asleep):
            return []
        mask = (~np.isnan(in_bed)) & (in_bed > 0) & ~is_asleep

        # Find the windows first so the templates file is only touched when
        # there is actually something to caption.
        windows = contiguous_windows(mask, self.min_duration)
        if not windows:
            return []

        time_unit = "hour" if config.time_unit == "hours" else "minute"
        templates = self._sleep_templates(config)
        seed = CaptionExtractor._seed(row.row_id)

        results: list[Annotation] = []
        for i, (start, end) in enumerate(windows):
            # ``window`` stays half-open; the caption reports the last
            # included sample instead.
            end_inclusive = max(start, end - 1)
            # Rotate through the templates, offset by the per-row seed.
            template = templates[(seed + i) % len(templates)]
            text = template.format(time_unit=time_unit, start=start, end=end_inclusive)
            results.append(
                Annotation(
                    caption_type="cross_channel",
                    text=text,
                    channel_idxs=(asleep_idx, in_bed_idx),
                    window=(start, end),
                    label="in_bed_not_sleeping",
                )
            )
        return results

    def _sleep_templates(self, config: ChannelConfig) -> list[str]:
        """Load (and cache per path) the ``cross_channel.sleep`` templates."""
        key = str(config.templates_path)
        templates = self._templates_by_path.get(key)
        if templates is None:
            raw = config.templates_path.read_text(encoding="utf-8")
            templates = json.loads(raw)["cross_channel"]["sleep"]
            if not templates:
                # Fail with a clear message instead of a ZeroDivisionError
                # from the modulo in the template rotation.
                raise ValueError(f"No cross_channel sleep templates found in {key}")
            self._templates_by_path[key] = templates
        return templates

templates/templates.json

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,14 @@
4242
"{name} occurred at minute {start}.",
4343
"{name} was observed at minute {start}.",
4444
"Identified {name} at minute {start}."
45-
]
45+
],
46+
"cross_channel": {
47+
"sleep": [
48+
"In bed but not sleeping from {time_unit} {start} to {end}.",
49+
"The user was in bed but awake from {time_unit} {start} to {end}.",
50+
"Awake while in bed between {time_unit} {start} and {end}.",
51+
"In-bed wakefulness detected from {time_unit} {start} to {end}.",
52+
"The user remained awake in bed from {time_unit} {start} to {end}."
53+
]
54+
}
4655
}

visualizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]]
115115
from mhc.transformer import MHCTransformer
116116
from mhc.constants import MHC_CHANNEL_CONFIG
117117
from extractors.cross_channel import CrossChannelExtractor
118+
from synthesizers.sleep import SleepSynthesizer
118119
from extractors.statistical import StatisticalExtractor
119120
from extractors.structural import StructuralExtractor
120121
from annotator import Annotator
@@ -124,7 +125,7 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]]
124125
annotator = Annotator([
125126
StatisticalExtractor(MHC_CHANNEL_CONFIG),
126127
StructuralExtractor(MHC_CHANNEL_CONFIG),
127-
CrossChannelExtractor(MHC_CHANNEL_CONFIG),
128+
CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]),
128129
])
129130
captionizer = Captionizer(dataset, MHCTransformer(), annotator)
130131
result, _ = captionizer.run(max_rows=1)

0 commit comments

Comments
 (0)