Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions captionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def run(
from mhc.dataset import MHCDataset
from mhc.transformer import MHCTransformer
from mhc.constants import MHC_CHANNEL_CONFIG
from extractors.cross_channel import CrossChannelExtractor
from synthesizers.sleep import SleepSynthesizer
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from models.local import LocalConfig, LocalModel
Expand All @@ -66,6 +68,7 @@ def run(
StatisticalExtractor(MHC_CHANNEL_CONFIG),
StructuralExtractor(MHC_CHANNEL_CONFIG),
SemanticExtractor(MHC_CHANNEL_CONFIG),
CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]),
])

captionizer = Captionizer(dataset, MHCTransformer(), annotator)
Expand Down
69 changes: 59 additions & 10 deletions explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from annotator import Annotator
from extractors import ChannelConfig
from extractors.cross_channel import CrossChannelExtractor
from synthesizers.sleep import SleepSynthesizer
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from mhc.constants import MHC_CHANNEL_CONFIG
Expand All @@ -25,6 +27,9 @@
from timef.schema import Recording, SignalView
from transformer import Transformer

IN_BED_AWAKE_LABEL = "in_bed_not_sleeping"
IN_BED_AWAKE_TARGET = "InBedAwake"


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Interactive explorer for MHC rows, signals, and detector outputs.")
Expand Down Expand Up @@ -93,6 +98,7 @@ def __init__(
self.annotator = Annotator([
StatisticalExtractor(channel_config),
StructuralExtractor(channel_config),
CrossChannelExtractor(channel_config, synthesizers=[SleepSynthesizer()]),
])

self.row_index = min(max(0, row_index), len(self.dataset) - 1)
Expand All @@ -103,7 +109,7 @@ def __init__(
self.detail_mode = "events"
self.details_scroll = 0
self.details_page_lines = 12
self.hit_target_names = self._available_detector_names()
self.hit_target_names = self._available_hit_target_names()
self.hit_target = self.hit_target_names[0] if self.hit_target_names else None
self.search_status = "Use hit< / hit> to jump to the selected event type."

Expand Down Expand Up @@ -202,13 +208,13 @@ def _load_row_bundle(self, row_index: int) -> Recording:
recording.annotations.extend(self.annotator.annotate(recording))
return recording

def _available_detector_names(self) -> list[str]:
def _available_hit_target_names(self) -> list[str]:
names = {
detector.__class__.__name__
for detectors in self.channel_config.detectors.values()
for detector in detectors
}
return sorted(names)
return sorted(names) + [IN_BED_AWAKE_TARGET]

@staticmethod
def _matches_hit_target(hit_target: str, detector_name: str, result: object) -> bool:
Expand Down Expand Up @@ -260,6 +266,15 @@ def _set_hit_target(self, detector_name: str) -> None:
self._update_hit_target_button_styles()
self.render(reset_zoom=False)

def _cycle_hit_target(self, step: int) -> None:
    """Move the jump target forward/backward by *step*, wrapping around the list."""
    targets = self.hit_target_names
    if not targets:
        return
    if self.hit_target not in targets:
        # Current target is stale (e.g. config changed); reset to the first one.
        self._set_hit_target(targets[0])
        return
    next_index = (targets.index(self.hit_target) + step) % len(targets)
    self._set_hit_target(targets[next_index])

def _set_detail_mode(self, mode: str) -> None:
if mode != self.detail_mode:
self.detail_mode = mode
Expand Down Expand Up @@ -292,6 +307,10 @@ def _on_key_press(self, event) -> None:
self._jump_to_hit(1)
elif event.key == "p":
self._jump_to_hit(-1)
elif event.key == "[":
self._cycle_hit_target(-1)
elif event.key == "]":
self._cycle_hit_target(1)

def _on_scroll(self, event) -> None:
if event.inaxes is not self.ax_details:
Expand Down Expand Up @@ -332,9 +351,16 @@ def _row_detector_events(self, row_index: int) -> tuple[tuple[tuple[str, object]
recording = self._load_row_bundle(row_index)
return tuple(tuple(self._detector_events(signal)) for signal in recording.iter_channels())

def _signal_has_hit_target(self, row_index: int, signal_index: int) -> bool:
    """Return True when the given (row, signal) pair contains the current jump target.

    The synthetic "InBedAwake" target is resolved from cross-channel annotation
    windows; every other target is resolved from detector events.
    """
    if self.hit_target == IN_BED_AWAKE_TARGET:
        # Synthetic target: look for cross-channel "in bed, awake" windows
        # instead of detector hits.
        recording = self._load_row_bundle(row_index)
        return bool(self._cross_channel_windows(recording, signal_index, IN_BED_AWAKE_LABEL))
    row_signal_events = self._row_detector_events(row_index)[signal_index]
    # Delegate to _matches_hit_target so the result-aware matching used by the
    # previous jump logic is preserved, rather than comparing names only.
    return any(
        self._matches_hit_target(self.hit_target, detector_name, result)
        for detector_name, result in row_signal_events
    )

def _jump_to_hit(self, step: int) -> None:
if self.hit_target is None:
self.search_status = "No detector target is available."
self.search_status = "No search target is available."
self.render(reset_zoom=False)
return

Expand All @@ -345,8 +371,7 @@ def _jump_to_hit(self, step: int) -> None:
for offset in range(1, n_rows * n_signals):
candidate = (flat_index + step * offset) % (n_rows * n_signals)
row_index, signal_index = divmod(candidate, n_signals)
row_signal_events = self._row_detector_events(row_index)[signal_index]
if any(self._matches_hit_target(self.hit_target, detector_name, result) for detector_name, result in row_signal_events):
if self._signal_has_hit_target(row_index, signal_index):
self.row_index = row_index
self.signal_index = signal_index
self.details_scroll = 0
Expand All @@ -367,6 +392,14 @@ def _captions_for_signal(recording: Recording, signal_idx: int) -> dict[str, lis
grouped.setdefault(annotation.caption_type, []).append(annotation.text)
return grouped

@staticmethod
def _cross_channel_windows(recording: Recording, signal_idx: int, label: str) -> list[tuple[int, int]]:
    """Collect the (start, end) windows of annotations on *signal_idx* carrying *label*."""
    windows: list[tuple[int, int]] = []
    for annotation in recording.annotations_for_signal(signal_idx):
        if annotation.label == label and annotation.window is not None:
            windows.append(annotation.window)
    return windows

@staticmethod
def _overview_matrix(recording: Recording) -> np.ma.MaskedArray:
rows = []
Expand Down Expand Up @@ -398,12 +431,12 @@ def _style_widgets(self) -> None:
):
button.label.set_fontsize(8.5)

for button in self.hit_target_buttons.values():
button.label.set_fontsize(8)
for button in self.detail_tab_buttons.values():
button.label.set_fontsize(7.5)
self.detail_up_button.label.set_fontsize(8)
self.detail_down_button.label.set_fontsize(8)
for button in self.hit_target_buttons.values():
button.label.set_fontsize(8)
for button in self.overlay_buttons.values():
button.label.set_fontsize(7)
self._sync_widgets()
Expand Down Expand Up @@ -485,6 +518,7 @@ def render(self, reset_zoom: bool = False) -> None:
detector_events = self._detector_events(signal)
spike_labels = self._spike_labels(detector_events)
captions = self._captions_for_signal(recording, self.signal_index)
cross_channel_windows = self._cross_channel_windows(recording, self.signal_index, IN_BED_AWAKE_LABEL)
display_name = signal.display_name
unit = signal.unit or ""
_, _, decimals = self.channel_config.meta.get(signal.name, (signal.name, "", 2))
Expand All @@ -506,6 +540,8 @@ def render(self, reset_zoom: bool = False) -> None:
ax.axis("off")

self.ax_main.plot(x[valid], y[valid], color="steelblue", linewidth=1.0, label="signal")
for start, end in cross_channel_windows:
self.ax_main.axvspan(start, end, color="#f4d35e", alpha=0.22, label="in bed, awake")

for detector_name, result in detector_events:
if result.event_type == "trend" and self.show_trends:
Expand Down Expand Up @@ -631,6 +667,18 @@ def render(self, reset_zoom: bool = False) -> None:
color="#444444",
weight="bold",
)
self.ax_hit_target.text(
0.0,
0.55,
_hit_target_label(self.hit_target),
transform=self.ax_hit_target.transAxes,
ha="left",
va="center",
fontsize=8.5,
color="#333333",
family="monospace",
weight="bold",
)
self.ax_hit_target.text(
1.0,
1.02,
Expand Down Expand Up @@ -706,9 +754,10 @@ def render(self, reset_zoom: bool = False) -> None:
"click a signal on the right or the overview heatmap to change channel",
"up/down changes row",
"left/right changes signal",
"choose a detector target above the signal list",
"n / p jump to next or previous hit for that detector",
"target< / target> changes the jump target",
"n / p jump to next or previous hit for that target",
"hit< / hit> buttons do the same",
"[ / ] also changes the jump target",
Comment on lines +757 to +760
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Help text references non-existent controls.

Line 757 mentions target< / target> but these controls don't exist in the UI. The actual ways to change the jump target are:

  • Click the detector name buttons (e.g., "InBedAwake")
  • Press [ or ] keys

Consider removing or correcting this line to avoid user confusion.

📝 Suggested fix
                     "up/down changes row",
                     "left/right changes signal",
-                    "target< / target> changes the jump target",
+                    "click detector buttons to change the jump target",
                     "n / p jump to next or previous hit for that target",
                     "hit< / hit> buttons do the same",
                     "[ / ] also changes the jump target",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@explorer.py` around lines 757 - 760, The help text still references a
non-existent control "target< / target>"; update the help string (the block
containing "target< / target> changes the jump target", "n / p jump to next or
previous hit for that target", etc.) to remove the bogus control and instead
instruct users to change the jump target by clicking the detector name buttons
(e.g., "InBedAwake") or by pressing "[" and "]" keys, keeping the related "n /
p" and "hit buttons" guidance intact; locate and edit the help text literal in
explorer.py to replace the incorrect line with the corrected wording.

"mouse wheel over details scrolls",
"PageUp/PageDown also scroll details",
"overlay buttons toggle detector layers",
Expand Down
2 changes: 1 addition & 1 deletion extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

DEFAULT_TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json"

VALID_CAPTION_TYPES = ("statistical", "structural", "semantic")
VALID_CAPTION_TYPES = ("statistical", "structural", "semantic", "cross_channel")


_ACTIVITY_RE = re.compile(r"HKWorkoutActivityType(.+)$")
Expand Down
25 changes: 25 additions & 0 deletions extractors/cross_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

from extractors import CaptionExtractor, ChannelConfig
from synthesizers import CrossChannelSynthesizer
from timef.schema import Annotation, Recording


class CrossChannelExtractor(CaptionExtractor):
    """Caption extractor that delegates to a list of cross-channel synthesizers.

    Each synthesizer inspects multiple channels of a recording at once and may
    contribute zero or more annotations; their outputs are concatenated in order.
    """

    caption_type = "cross_channel"

    def __init__(self, config: ChannelConfig, synthesizers: list[CrossChannelSynthesizer]):
        super().__init__(config)
        # Synthesizers run in the order given.
        self.synthesizers = synthesizers

    def extract(self, row: Recording) -> list[Annotation]:
        """Run every synthesizer over *row* and return all produced annotations."""
        annotations: list[Annotation] = []
        for synthesizer in self.synthesizers:
            annotations += synthesizer.synthesize(row, self.config)
        return annotations
31 changes: 31 additions & 0 deletions synthesizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

import abc

import numpy as np

from extractors import ChannelConfig
from timef.schema import Annotation, Recording


class CrossChannelSynthesizer(abc.ABC):
    """Interface for components that combine several channels of a recording
    into annotations (e.g. "in bed but awake" from the two sleep channels)."""

    @abc.abstractmethod
    def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]: ...


def contiguous_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]:
    """Return half-open (start, end) index pairs for each contiguous run of True
    values in *mask* whose length is at least *min_duration*."""
    if not mask.any():
        return []

    # Pad with False on both sides so runs touching the array edges still
    # produce a rising and a falling edge in the diff.
    edges = np.diff(np.concatenate(([False], mask, [False])).astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)
    return [
        (int(start), int(end))
        for start, end in zip(run_starts, run_ends)
        if end - start >= min_duration
    ]
53 changes: 53 additions & 0 deletions synthesizers/sleep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

import json

import numpy as np

from extractors import CaptionExtractor, ChannelConfig
from synthesizers import CrossChannelSynthesizer, contiguous_windows
from timef.schema import Annotation, Recording


class SleepSynthesizer(CrossChannelSynthesizer):
    """Synthesizes captions for windows where the user was in bed but not asleep,
    by combining the "sleep:inbed" and "sleep:asleep" channels."""

    def __init__(self, min_duration: int = 5):
        # Minimum window length (in samples) required to report a window.
        self.min_duration = min_duration

    def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]:
        """Return one cross-channel annotation per qualifying in-bed-awake window.

        Returns an empty list when either sleep channel is absent or when no
        sleep at all was recorded for the row.
        """
        names = row.channel_names
        try:
            in_bed_idx = names.index("sleep:inbed")
            asleep_idx = names.index("sleep:asleep")
        except ValueError:
            # One of the sleep channels is missing; nothing to synthesize.
            return []

        in_bed = np.asarray(row.values[in_bed_idx], dtype=float)
        asleep = np.asarray(row.values[asleep_idx], dtype=float)

        asleep_positive = (~np.isnan(asleep)) & (asleep > 0)
        if not asleep_positive.any():
            # No sleep recorded at all — "in bed but awake" would be meaningless.
            return []
        in_bed_positive = (~np.isnan(in_bed)) & (in_bed > 0)
        awake_in_bed = in_bed_positive & ~asleep_positive

        time_unit = "hour" if config.time_unit == "hours" else "minute"
        templates = json.loads(config.templates_path.read_text())["cross_channel"]["sleep"]
        # Deterministic per-row seed so template choice is stable across runs.
        seed = CaptionExtractor._seed(row.row_id)

        annotations: list[Annotation] = []
        for i, (start, end) in enumerate(contiguous_windows(awake_in_bed, self.min_duration)):
            # Windows are half-open [start, end); captions show an inclusive end.
            end_inclusive = max(start, end - 1)
            template = templates[(seed + i) % len(templates)]
            annotations.append(
                Annotation(
                    caption_type="cross_channel",
                    text=template.format(time_unit=time_unit, start=start, end=end_inclusive),
                    channel_idxs=(asleep_idx, in_bed_idx),
                    window=(start, end),
                    label="in_bed_not_sleeping",
                )
            )
        return annotations
11 changes: 10 additions & 1 deletion templates/templates.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,14 @@
"{name} occurred at minute {start}.",
"{name} was observed at minute {start}.",
"Identified {name} at minute {start}."
]
],
"cross_channel": {
"sleep": [
"In bed but not sleeping from {time_unit} {start} to {end}.",
"The user was in bed but awake from {time_unit} {start} to {end}.",
"Awake while in bed between {time_unit} {start} and {end}.",
"In-bed wakefulness detected from {time_unit} {start} to {end}.",
"The user remained awake in bed from {time_unit} {start} to {end}."
]
}
}
8 changes: 7 additions & 1 deletion visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,19 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]]
from mhc.dataset import MHCDataset
from mhc.transformer import MHCTransformer
from mhc.constants import MHC_CHANNEL_CONFIG
from extractors.cross_channel import CrossChannelExtractor
from synthesizers.sleep import SleepSynthesizer
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from annotator import Annotator
from captionizer import Captionizer

dataset = MHCDataset(min_wear_pct=90.0)
annotator = Annotator([StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG)])
annotator = Annotator([
StatisticalExtractor(MHC_CHANNEL_CONFIG),
StructuralExtractor(MHC_CHANNEL_CONFIG),
CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]),
])
Comment on lines +125 to +129
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

This wiring doesn't surface the new sleep annotations yet.

CrossChannelExtractor now runs here, but plot_row() still only renders statistical, structural, and semantic caption types into the subplot labels. The new cross_channel captions are generated and then dropped, so this example path never shows the feature it just enabled. Either render cross_channel captions as well or leave the extractor out of this demo entrypoint for now.

Possible follow-up in plot_row()
-        for ann_type in ("statistical", "structural", "semantic"):
+        for ann_type in ("statistical", "structural", "semantic", "cross_channel"):
             if ann_type in caps:
                 parts.append(textwrap.fill(" ".join(caps[ann_type]), width=80))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@visualizer.py` around lines 125 - 129, The demo wires CrossChannelExtractor
(with SleepSynthesizer) into Annotator but plot_row() currently only renders
'statistical', 'structural', and 'semantic' caption types so the new
'cross_channel' captions are dropped; either update plot_row() to handle and
render 'cross_channel' captions (add 'cross_channel' to the caption types it
iterates/labels and format the label text accordingly) or remove
CrossChannelExtractor from the Annotator initialization to keep the demo
consistent (modify the Annotator call where
CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]) is
added).

captionizer = Captionizer(dataset, MHCTransformer(), annotator)
result, _ = captionizer.run(max_rows=1)
for row in result.iter_rows():
Expand Down
Loading