Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ __pycache__/
.claude/
docs/
sample_plot_*.png
data/
3 changes: 3 additions & 0 deletions captionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def run(
from mhc.dataset import MHCDataset
from mhc.transformer import MHCTransformer
from mhc.constants import MHC_CHANNEL_CONFIG
from extractors.cross_channel import CrossChannelExtractor
from synthesizers.sleep import SleepSynthesizer
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from models.local import LocalConfig, LocalModel
Expand All @@ -66,6 +68,7 @@ def run(
StatisticalExtractor(MHC_CHANNEL_CONFIG),
StructuralExtractor(MHC_CHANNEL_CONFIG),
SemanticExtractor(MHC_CHANNEL_CONFIG),
CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]),
])

captionizer = Captionizer(dataset, MHCTransformer(), annotator)
Expand Down
128 changes: 81 additions & 47 deletions explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from annotator import Annotator
from extractors import ChannelConfig
from extractors.cross_channel import CrossChannelExtractor
from synthesizers.sleep import SleepSynthesizer
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from mhc.constants import MHC_CHANNEL_CONFIG
Expand All @@ -25,6 +27,9 @@
from timef.schema import Recording, SignalView
from transformer import Transformer

IN_BED_AWAKE_LABEL = "in_bed_not_sleeping"
IN_BED_AWAKE_TARGET = "InBedAwake"


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Interactive explorer for MHC rows, signals, and detector outputs.")
Expand Down Expand Up @@ -87,6 +92,7 @@ def __init__(
self.annotator = Annotator([
StatisticalExtractor(channel_config),
StructuralExtractor(channel_config),
CrossChannelExtractor(channel_config, synthesizers=[SleepSynthesizer()]),
])

self.row_index = min(max(0, row_index), len(self.dataset) - 1)
Expand All @@ -99,9 +105,9 @@ def __init__(
self.detail_mode = "events"
self.details_scroll = 0
self.details_page_lines = 12
self.hit_target_names = self._available_detector_names()
self.hit_target_names = self._available_hit_target_names()
self.hit_target = self.hit_target_names[0] if self.hit_target_names else None
self.search_status = "Use hit< / hit> to jump to the selected detector."
self.search_status = "Use hit< / hit> to jump to the selected event."

self._ignore_widget_events = False

Expand Down Expand Up @@ -131,18 +137,10 @@ def __init__(
self.detail_tab_buttons[label] = Button(ax, label)
self.detail_up_ax = self.fig.add_axes([0.935, 0.085, 0.022, 0.028])
self.detail_down_ax = self.fig.add_axes([0.958, 0.085, 0.022, 0.028])
self.prev_hit_ax = self.fig.add_axes([0.76, 0.698, 0.035, 0.038])
self.next_hit_ax = self.fig.add_axes([0.945, 0.698, 0.035, 0.038])
self.hit_target_buttons: dict[str, Button] = {}
if self.hit_target_names:
n_targets = len(self.hit_target_names)
total_gap = 0.004 * max(0, n_targets - 1)
available_width = 0.145
width = (available_width - total_gap) / max(1, n_targets)
for i, detector_name in enumerate(self.hit_target_names):
x0 = 0.805 + i * (width + 0.004)
ax = self.fig.add_axes([x0, 0.698, width, 0.038])
self.hit_target_buttons[detector_name] = Button(ax, detector_name.replace("Detector", "").lower())
self.prev_target_ax = self.fig.add_axes([0.76, 0.698, 0.045, 0.038])
self.next_target_ax = self.fig.add_axes([0.81, 0.698, 0.045, 0.038])
self.prev_hit_ax = self.fig.add_axes([0.895, 0.698, 0.040, 0.038])
self.next_hit_ax = self.fig.add_axes([0.94, 0.698, 0.040, 0.038])
overlay_labels = ["trend", "spike", "drop", "nonwear"]
self.overlay_buttons: dict[str, Button] = {}
start_x = 0.83
Expand All @@ -164,6 +162,8 @@ def __init__(
self.reset_zoom_button = Button(self.reset_zoom_ax, "reset")
self.prev_row_button = Button(self.prev_row_ax, "<")
self.next_row_button = Button(self.next_row_ax, ">")
self.prev_target_button = Button(self.prev_target_ax, "target<")
self.next_target_button = Button(self.next_target_ax, "target>")
self.prev_hit_button = Button(self.prev_hit_ax, "hit<")
self.next_hit_button = Button(self.next_hit_ax, "hit>")
self.detail_up_button = Button(self.detail_up_ax, "^")
Expand All @@ -175,12 +175,12 @@ def __init__(
self.reset_zoom_button.on_clicked(lambda _: self.render(reset_zoom=True))
self.prev_row_button.on_clicked(lambda _: self._set_row(self.row_index - 1))
self.next_row_button.on_clicked(lambda _: self._set_row(self.row_index + 1))
self.prev_target_button.on_clicked(lambda _: self._cycle_hit_target(-1))
self.next_target_button.on_clicked(lambda _: self._cycle_hit_target(1))
self.prev_hit_button.on_clicked(lambda _: self._jump_to_hit(-1))
self.next_hit_button.on_clicked(lambda _: self._jump_to_hit(1))
self.detail_up_button.on_clicked(lambda _: self._scroll_details(-1))
self.detail_down_button.on_clicked(lambda _: self._scroll_details(1))
for detector_name, button in self.hit_target_buttons.items():
button.on_clicked(lambda _, name=detector_name: self._set_hit_target(name))
for label, button in self.detail_tab_buttons.items():
button.on_clicked(lambda _, name=label: self._set_detail_mode(name))
for label, button in self.overlay_buttons.items():
Expand All @@ -198,13 +198,19 @@ def _load_row_bundle(self, row_index: int) -> Recording:
recording.annotations.extend(self.annotator.annotate(recording))
return recording

def _available_detector_names(self) -> list[str]:
def _available_hit_target_names(self) -> list[str]:
names = {
detector.__class__.__name__
for detectors in self.channel_config.detectors.values()
for detector in detectors
}
return sorted(names)
return sorted(names) + [IN_BED_AWAKE_TARGET]

@staticmethod
def _hit_target_button_label(hit_target: str) -> str:
if hit_target == IN_BED_AWAKE_TARGET:
return "inbed+awake"
return hit_target.replace("Detector", "").lower()

def _set_row(self, row_index: int) -> None:
row_index = min(max(0, int(row_index)), len(self.dataset) - 1)
Expand Down Expand Up @@ -249,10 +255,18 @@ def _on_toggle(self, label: str) -> None:
def _set_hit_target(self, detector_name: str) -> None:
if detector_name != self.hit_target:
self.hit_target = detector_name
self.search_status = f"Jump target set to {detector_name}."
self._update_hit_target_button_styles()
self.search_status = f"Jump target set to {self._hit_target_button_label(detector_name)}."
self.render(reset_zoom=False)

def _cycle_hit_target(self, step: int) -> None:
if not self.hit_target_names:
return
if self.hit_target not in self.hit_target_names:
self._set_hit_target(self.hit_target_names[0])
return
current = self.hit_target_names.index(self.hit_target)
self._set_hit_target(self.hit_target_names[(current + step) % len(self.hit_target_names)])

def _set_detail_mode(self, mode: str) -> None:
if mode != self.detail_mode:
self.detail_mode = mode
Expand Down Expand Up @@ -285,6 +299,10 @@ def _on_key_press(self, event) -> None:
self._jump_to_hit(1)
elif event.key == "p":
self._jump_to_hit(-1)
elif event.key == "[":
self._cycle_hit_target(-1)
elif event.key == "]":
self._cycle_hit_target(1)

def _on_scroll(self, event) -> None:
if event.inaxes is not self.ax_details:
Expand Down Expand Up @@ -315,9 +333,16 @@ def _row_detector_events(self, row_index: int) -> tuple[tuple[tuple[str, object]
recording = self._load_row_bundle(row_index)
return tuple(tuple(self._detector_events(signal)) for signal in recording.iter_channels())

def _signal_has_hit_target(self, row_index: int, signal_index: int) -> bool:
if self.hit_target == IN_BED_AWAKE_TARGET:
recording = self._load_row_bundle(row_index)
return bool(self._cross_channel_windows(recording, signal_index, IN_BED_AWAKE_LABEL))
row_signal_events = self._row_detector_events(row_index)[signal_index]
return any(detector_name == self.hit_target for detector_name, _ in row_signal_events)

def _jump_to_hit(self, step: int) -> None:
if self.hit_target is None:
self.search_status = "No detector target is available."
self.search_status = "No search target is available."
self.render(reset_zoom=False)
return

Expand All @@ -328,17 +353,19 @@ def _jump_to_hit(self, step: int) -> None:
for offset in range(1, n_rows * n_signals):
candidate = (flat_index + step * offset) % (n_rows * n_signals)
row_index, signal_index = divmod(candidate, n_signals)
row_signal_events = self._row_detector_events(row_index)[signal_index]
if any(detector_name == self.hit_target for detector_name, _ in row_signal_events):
if self._signal_has_hit_target(row_index, signal_index):
self.row_index = row_index
self.signal_index = signal_index
self.details_scroll = 0
self.search_status = f"Jumped to row {row_index}, signal {signal_index} with {self.hit_target}."
self.search_status = (
f"Jumped to row {row_index}, signal {signal_index} "
f"with {self._hit_target_button_label(self.hit_target)}."
)
self._sync_widgets()
self.render(reset_zoom=True)
return

self.search_status = f"No hits found for {self.hit_target} in the scanned dataset."
self.search_status = f"No hits found for {self._hit_target_button_label(self.hit_target)} in the scanned dataset."
self.render(reset_zoom=False)

@staticmethod
Expand All @@ -350,6 +377,14 @@ def _captions_for_signal(recording: Recording, signal_idx: int) -> dict[str, lis
grouped.setdefault(annotation.caption_type, []).append(annotation.text)
return grouped

@staticmethod
def _cross_channel_windows(recording: Recording, signal_idx: int, label: str) -> list[tuple[int, int]]:
return [
annotation.window
for annotation in recording.annotations_for_signal(signal_idx)
if annotation.label == label and annotation.window is not None
]

@staticmethod
def _overview_matrix(recording: Recording) -> np.ma.MaskedArray:
rows = []
Expand Down Expand Up @@ -381,16 +416,15 @@ def _style_widgets(self) -> None:
):
button.label.set_fontsize(8.5)

for button in self.hit_target_buttons.values():
button.label.set_fontsize(8)
for button in self.detail_tab_buttons.values():
button.label.set_fontsize(7.5)
self.detail_up_button.label.set_fontsize(8)
self.detail_down_button.label.set_fontsize(8)
self.prev_target_button.label.set_fontsize(7.2)
self.next_target_button.label.set_fontsize(7.2)
for button in self.overlay_buttons.values():
button.label.set_fontsize(7)
self._sync_widgets()
self._update_hit_target_button_styles()
self._update_detail_tab_styles()
self._update_overlay_button_styles()

Expand Down Expand Up @@ -418,22 +452,6 @@ def _update_overlay_button_styles(self) -> None:
button.label.set_color(text)
button.label.set_fontweight("bold" if enabled else "normal")

def _update_hit_target_button_styles(self) -> None:
for detector_name, button in self.hit_target_buttons.items():
active = detector_name == self.hit_target
face = "#204a87" if active else "#f4f5f7"
edge = "#f4d35e" if active else "#c7cdd6"
text = "white" if active else "#5f6b7a"
button.ax.set_facecolor(face)
button.ax.patch.set_edgecolor(edge)
button.ax.patch.set_linewidth(2.0 if active else 1.1)
for spine in button.ax.spines.values():
spine.set_edgecolor(edge)
spine.set_linewidth(2.0 if active else 1.1)
button.hovercolor = "#3465a4" if active else "#eaedf2"
button.label.set_color(text)
button.label.set_fontweight("bold" if active else "normal")

@staticmethod
def _build_detail_lines(title: str, lines: list[str], width: int) -> list[str]:
rendered = [title]
Expand Down Expand Up @@ -469,6 +487,7 @@ def render(self, reset_zoom: bool = False) -> None:
signal = recording.signal(self.signal_index)
detector_events = self._detector_events(signal)
captions = self._captions_for_signal(recording, self.signal_index)
cross_channel_windows = self._cross_channel_windows(recording, self.signal_index, IN_BED_AWAKE_LABEL)
display_name = signal.display_name
unit = signal.unit or ""
_, _, decimals = self.channel_config.meta.get(signal.name, (signal.name, "", 2))
Expand All @@ -493,6 +512,8 @@ def render(self, reset_zoom: bool = False) -> None:
if self.show_nonwear:
for start, end in _nan_regions(y):
self.ax_main.axvspan(start, end, color="#d62728", alpha=0.08, label="nonwear")
for start, end in cross_channel_windows:
self.ax_main.axvspan(start, end, color="#f4d35e", alpha=0.22, label="in bed, awake")

for detector_name, result in detector_events:
if result.event_type == "trend" and self.show_trends:
Expand Down Expand Up @@ -609,14 +630,26 @@ def render(self, reset_zoom: bool = False) -> None:
self.ax_hit_target.text(
0.0,
1.02,
f"Find {self.hit_target.replace('Detector', '').lower()} hits" if self.hit_target else "Find hits",
"Find Hits",
transform=self.ax_hit_target.transAxes,
ha="left",
va="bottom",
fontsize=8.5,
color="#444444",
weight="bold",
)
self.ax_hit_target.text(
0.0,
0.55,
self._hit_target_button_label(self.hit_target),
transform=self.ax_hit_target.transAxes,
ha="left",
va="center",
fontsize=8.5,
color="#333333",
family="monospace",
weight="bold",
)
self.ax_hit_target.text(
1.0,
1.02,
Expand Down Expand Up @@ -692,9 +725,10 @@ def render(self, reset_zoom: bool = False) -> None:
"click a signal on the right or the overview heatmap to change channel",
"up/down changes row",
"left/right changes signal",
"choose a detector target above the signal list",
"n / p jump to next or previous hit for that detector",
"target< / target> changes the jump target",
"n / p jump to next or previous hit for that target",
"hit< / hit> buttons do the same",
"[ / ] also changes the jump target",
Comment on lines +757 to +760
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Help text references non-existent controls.

Line 757 mentions target< / target> but these controls don't exist in the UI. The actual ways to change the jump target are:

  • Click the detector name buttons (e.g., "InBedAwake")
  • Press [ or ] keys

Consider removing or correcting this line to avoid user confusion.

📝 Suggested fix
                     "up/down changes row",
                     "left/right changes signal",
-                    "target< / target> changes the jump target",
+                    "click detector buttons to change the jump target",
                     "n / p jump to next or previous hit for that target",
                     "hit< / hit> buttons do the same",
                     "[ / ] also changes the jump target",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@explorer.py` around lines 757 - 760, The help text still references a
non-existent control "target< / target>"; update the help string (the block
containing "target< / target> changes the jump target", "n / p jump to next or
previous hit for that target", etc.) to remove the bogus control and instead
instruct users to change the jump target by clicking the detector name buttons
(e.g., "InBedAwake") or by pressing "[" and "]" keys, keeping the related "n /
p" and "hit buttons" guidance intact; locate and edit the help text literal in
explorer.py to replace the incorrect line with the corrected wording.

"mouse wheel over details scrolls",
"PageUp/PageDown also scroll details",
"overlay buttons toggle detector layers",
Expand Down
2 changes: 1 addition & 1 deletion extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

DEFAULT_TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json"

VALID_CAPTION_TYPES = ("statistical", "structural", "semantic")
VALID_CAPTION_TYPES = ("statistical", "structural", "semantic", "cross_channel")


_ACTIVITY_RE = re.compile(r"HKWorkoutActivityType(.+)$")
Expand Down
25 changes: 25 additions & 0 deletions extractors/cross_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

from extractors import CaptionExtractor, ChannelConfig
from synthesizers import CrossChannelSynthesizer
from timef.schema import Annotation, Recording


class CrossChannelExtractor(CaptionExtractor):
caption_type = "cross_channel"

def __init__(self, config: ChannelConfig, synthesizers: list[CrossChannelSynthesizer]):
super().__init__(config)
self.synthesizers = synthesizers

def extract(self, row: Recording) -> list[Annotation]:
results: list[Annotation] = []
for synth in self.synthesizers:
results.extend(synth.synthesize(row, self.config))
return results
31 changes: 31 additions & 0 deletions synthesizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

import abc

import numpy as np

from extractors import ChannelConfig
from timef.schema import Annotation, Recording


class CrossChannelSynthesizer(abc.ABC):
@abc.abstractmethod
def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]: ...


def contiguous_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]:
if not mask.any():
return []

padded = np.concatenate(([False], mask, [False]))
diffs = np.diff(padded.astype(np.int8))
starts = np.where(diffs == 1)[0]
ends = np.where(diffs == -1)[0]
keep = (ends - starts) >= min_duration
return list(zip(starts[keep].tolist(), ends[keep].tolist()))
Loading