Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
71cb0b8
In Bed but not Asleep annotation
KarlDeck Apr 9, 2026
e457736
Refactor CrossChannelExtractor into driver with pluggable synthesizers
max-rosenblattl Apr 12, 2026
7596a1b
made stationary activity synthesizer
KarlDeck Apr 12, 2026
e4320e5
added totals
KarlDeck Apr 12, 2026
08af225
added cardio synthesizers with totals
KarlDeck Apr 12, 2026
48927f1
Merge remote-tracking branch 'origin/main' into KarlDeck/Sleep-Bundles
KarlDeck Apr 12, 2026
7596829
Merge remote-tracking branch 'origin/main' into KarlDeck/Sleep-Bundles
KarlDeck Apr 12, 2026
5e3842d
put static methods into parent
KarlDeck Apr 13, 2026
43e3711
made min duration for synthesizer visible to users. Adresses Coderabb…
KarlDeck Apr 13, 2026
9e525e0
adressed coderabbit comment #2 comment
KarlDeck Apr 13, 2026
5158020
put variables into mhc/constants.py
KarlDeck Apr 13, 2026
3601edd
added HR delta
KarlDeck Apr 13, 2026
e7e5d0a
rephrased HR delta
KarlDeck Apr 13, 2026
fc3590a
solved duplication issue in templates/templates.json
KarlDeck Apr 13, 2026
52b8063
fixed --weekly issue
KarlDeck Apr 13, 2026
396e5b1
added 100 bpm threshold
KarlDeck Apr 13, 2026
05fb0ac
put _seed into util.py
KarlDeck Apr 13, 2026
0e09456
Merge remote-tracking branch 'origin/main' into KarlDeck/Sleep-Bundles
KarlDeck Apr 13, 2026
9f72309
reprased the synthesizer outputs
KarlDeck Apr 14, 2026
8a0956e
rephrase 2
KarlDeck Apr 14, 2026
df174d1
split up _metrics_suffix to make it easier to read
KarlDeck Apr 14, 2026
3ee312b
split up _metrics_suffix to make it easier to read
KarlDeck Apr 14, 2026
e6ecd2a
small cleanup
KarlDeck Apr 14, 2026
5cb38ef
Refactored _metrics_suffix into sub functions and transfered into parent
KarlDeck Apr 14, 2026
56d2fad
minor fix
KarlDeck Apr 14, 2026
9e39ee3
added docstrings
KarlDeck Apr 14, 2026
5de955d
comment added
KarlDeck Apr 14, 2026
c2be771
transfer functions from init to _helper
KarlDeck Apr 16, 2026
2fe7c84
created _workout base for cardio, stationary and furutre workouts
KarlDeck Apr 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ __pycache__/
.claude/
docs/
sample_plot_*.png
data/
3 changes: 3 additions & 0 deletions captionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def run(
from mhc.dataset import MHCDataset
from mhc.transformer import MHCTransformer
from mhc.constants import MHC_CHANNEL_CONFIG
from extractors.cross_channel import CrossChannelExtractor
from synthesizers.sleep import SleepSynthesizer
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from models.local import LocalConfig, LocalModel
Expand All @@ -66,6 +68,7 @@ def run(
StatisticalExtractor(MHC_CHANNEL_CONFIG),
StructuralExtractor(MHC_CHANNEL_CONFIG),
SemanticExtractor(MHC_CHANNEL_CONFIG),
CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]),
])

captionizer = Captionizer(dataset, MHCTransformer(), annotator)
Expand Down
128 changes: 81 additions & 47 deletions explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from annotator import Annotator
from extractors import ChannelConfig
from extractors.cross_channel import CrossChannelExtractor
from synthesizers.sleep import SleepSynthesizer
from extractors.statistical import StatisticalExtractor
from extractors.structural import StructuralExtractor
from mhc.constants import MHC_CHANNEL_CONFIG
Expand All @@ -25,6 +27,9 @@
from timef.schema import Recording, SignalView
from transformer import Transformer

IN_BED_AWAKE_LABEL = "in_bed_not_sleeping"
IN_BED_AWAKE_TARGET = "InBedAwake"


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Interactive explorer for MHC rows, signals, and detector outputs.")
Expand Down Expand Up @@ -87,6 +92,7 @@ def __init__(
self.annotator = Annotator([
StatisticalExtractor(channel_config),
StructuralExtractor(channel_config),
CrossChannelExtractor(channel_config, synthesizers=[SleepSynthesizer()]),
])

self.row_index = min(max(0, row_index), len(self.dataset) - 1)
Expand All @@ -99,9 +105,9 @@ def __init__(
self.detail_mode = "events"
self.details_scroll = 0
self.details_page_lines = 12
self.hit_target_names = self._available_detector_names()
self.hit_target_names = self._available_hit_target_names()
self.hit_target = self.hit_target_names[0] if self.hit_target_names else None
self.search_status = "Use hit< / hit> to jump to the selected detector."
self.search_status = "Use hit< / hit> to jump to the selected event."

self._ignore_widget_events = False

Expand Down Expand Up @@ -131,18 +137,10 @@ def __init__(
self.detail_tab_buttons[label] = Button(ax, label)
self.detail_up_ax = self.fig.add_axes([0.935, 0.085, 0.022, 0.028])
self.detail_down_ax = self.fig.add_axes([0.958, 0.085, 0.022, 0.028])
self.prev_hit_ax = self.fig.add_axes([0.76, 0.698, 0.035, 0.038])
self.next_hit_ax = self.fig.add_axes([0.945, 0.698, 0.035, 0.038])
self.hit_target_buttons: dict[str, Button] = {}
if self.hit_target_names:
n_targets = len(self.hit_target_names)
total_gap = 0.004 * max(0, n_targets - 1)
available_width = 0.145
width = (available_width - total_gap) / max(1, n_targets)
for i, detector_name in enumerate(self.hit_target_names):
x0 = 0.805 + i * (width + 0.004)
ax = self.fig.add_axes([x0, 0.698, width, 0.038])
self.hit_target_buttons[detector_name] = Button(ax, detector_name.replace("Detector", "").lower())
self.prev_target_ax = self.fig.add_axes([0.76, 0.698, 0.045, 0.038])
self.next_target_ax = self.fig.add_axes([0.81, 0.698, 0.045, 0.038])
self.prev_hit_ax = self.fig.add_axes([0.895, 0.698, 0.040, 0.038])
self.next_hit_ax = self.fig.add_axes([0.94, 0.698, 0.040, 0.038])
overlay_labels = ["trend", "spike", "drop", "nonwear"]
self.overlay_buttons: dict[str, Button] = {}
start_x = 0.83
Expand All @@ -164,6 +162,8 @@ def __init__(
self.reset_zoom_button = Button(self.reset_zoom_ax, "reset")
self.prev_row_button = Button(self.prev_row_ax, "<")
self.next_row_button = Button(self.next_row_ax, ">")
self.prev_target_button = Button(self.prev_target_ax, "target<")
self.next_target_button = Button(self.next_target_ax, "target>")
self.prev_hit_button = Button(self.prev_hit_ax, "hit<")
self.next_hit_button = Button(self.next_hit_ax, "hit>")
self.detail_up_button = Button(self.detail_up_ax, "^")
Expand All @@ -175,12 +175,12 @@ def __init__(
self.reset_zoom_button.on_clicked(lambda _: self.render(reset_zoom=True))
self.prev_row_button.on_clicked(lambda _: self._set_row(self.row_index - 1))
self.next_row_button.on_clicked(lambda _: self._set_row(self.row_index + 1))
self.prev_target_button.on_clicked(lambda _: self._cycle_hit_target(-1))
self.next_target_button.on_clicked(lambda _: self._cycle_hit_target(1))
self.prev_hit_button.on_clicked(lambda _: self._jump_to_hit(-1))
self.next_hit_button.on_clicked(lambda _: self._jump_to_hit(1))
self.detail_up_button.on_clicked(lambda _: self._scroll_details(-1))
self.detail_down_button.on_clicked(lambda _: self._scroll_details(1))
for detector_name, button in self.hit_target_buttons.items():
button.on_clicked(lambda _, name=detector_name: self._set_hit_target(name))
for label, button in self.detail_tab_buttons.items():
button.on_clicked(lambda _, name=label: self._set_detail_mode(name))
for label, button in self.overlay_buttons.items():
Expand All @@ -198,13 +198,19 @@ def _load_row_bundle(self, row_index: int) -> Recording:
recording.annotations.extend(self.annotator.annotate(recording))
return recording

def _available_detector_names(self) -> list[str]:
def _available_hit_target_names(self) -> list[str]:
names = {
detector.__class__.__name__
for detectors in self.channel_config.detectors.values()
for detector in detectors
}
return sorted(names)
return sorted(names) + [IN_BED_AWAKE_TARGET]

@staticmethod
def _hit_target_button_label(hit_target: str) -> str:
if hit_target == IN_BED_AWAKE_TARGET:
return "inbed+awake"
return hit_target.replace("Detector", "").lower()

def _set_row(self, row_index: int) -> None:
row_index = min(max(0, int(row_index)), len(self.dataset) - 1)
Expand Down Expand Up @@ -249,10 +255,18 @@ def _on_toggle(self, label: str) -> None:
def _set_hit_target(self, detector_name: str) -> None:
if detector_name != self.hit_target:
self.hit_target = detector_name
self.search_status = f"Jump target set to {detector_name}."
self._update_hit_target_button_styles()
self.search_status = f"Jump target set to {self._hit_target_button_label(detector_name)}."
self.render(reset_zoom=False)

def _cycle_hit_target(self, step: int) -> None:
if not self.hit_target_names:
return
if self.hit_target not in self.hit_target_names:
self._set_hit_target(self.hit_target_names[0])
return
current = self.hit_target_names.index(self.hit_target)
self._set_hit_target(self.hit_target_names[(current + step) % len(self.hit_target_names)])

def _set_detail_mode(self, mode: str) -> None:
if mode != self.detail_mode:
self.detail_mode = mode
Expand Down Expand Up @@ -285,6 +299,10 @@ def _on_key_press(self, event) -> None:
self._jump_to_hit(1)
elif event.key == "p":
self._jump_to_hit(-1)
elif event.key == "[":
self._cycle_hit_target(-1)
elif event.key == "]":
self._cycle_hit_target(1)

def _on_scroll(self, event) -> None:
if event.inaxes is not self.ax_details:
Expand Down Expand Up @@ -315,9 +333,16 @@ def _row_detector_events(self, row_index: int) -> tuple[tuple[tuple[str, object]
recording = self._load_row_bundle(row_index)
return tuple(tuple(self._detector_events(signal)) for signal in recording.iter_channels())

def _signal_has_hit_target(self, row_index: int, signal_index: int) -> bool:
if self.hit_target == IN_BED_AWAKE_TARGET:
recording = self._load_row_bundle(row_index)
return bool(self._cross_channel_windows(recording, signal_index, IN_BED_AWAKE_LABEL))
row_signal_events = self._row_detector_events(row_index)[signal_index]
return any(detector_name == self.hit_target for detector_name, _ in row_signal_events)

def _jump_to_hit(self, step: int) -> None:
if self.hit_target is None:
self.search_status = "No detector target is available."
self.search_status = "No search target is available."
self.render(reset_zoom=False)
return

Expand All @@ -328,17 +353,19 @@ def _jump_to_hit(self, step: int) -> None:
for offset in range(1, n_rows * n_signals):
candidate = (flat_index + step * offset) % (n_rows * n_signals)
row_index, signal_index = divmod(candidate, n_signals)
row_signal_events = self._row_detector_events(row_index)[signal_index]
if any(detector_name == self.hit_target for detector_name, _ in row_signal_events):
if self._signal_has_hit_target(row_index, signal_index):
self.row_index = row_index
self.signal_index = signal_index
self.details_scroll = 0
self.search_status = f"Jumped to row {row_index}, signal {signal_index} with {self.hit_target}."
self.search_status = (
f"Jumped to row {row_index}, signal {signal_index} "
f"with {self._hit_target_button_label(self.hit_target)}."
)
self._sync_widgets()
self.render(reset_zoom=True)
return

self.search_status = f"No hits found for {self.hit_target} in the scanned dataset."
self.search_status = f"No hits found for {self._hit_target_button_label(self.hit_target)} in the scanned dataset."
Comment thread
KarlDeck marked this conversation as resolved.
self.render(reset_zoom=False)

@staticmethod
Expand All @@ -350,6 +377,14 @@ def _captions_for_signal(recording: Recording, signal_idx: int) -> dict[str, lis
grouped.setdefault(annotation.caption_type, []).append(annotation.text)
return grouped

@staticmethod
def _cross_channel_windows(recording: Recording, signal_idx: int, label: str) -> list[tuple[int, int]]:
return [
annotation.window
for annotation in recording.annotations_for_signal(signal_idx)
if annotation.label == label and annotation.window is not None
]

@staticmethod
def _overview_matrix(recording: Recording) -> np.ma.MaskedArray:
rows = []
Expand Down Expand Up @@ -381,16 +416,15 @@ def _style_widgets(self) -> None:
):
button.label.set_fontsize(8.5)

for button in self.hit_target_buttons.values():
button.label.set_fontsize(8)
for button in self.detail_tab_buttons.values():
button.label.set_fontsize(7.5)
self.detail_up_button.label.set_fontsize(8)
self.detail_down_button.label.set_fontsize(8)
self.prev_target_button.label.set_fontsize(7.2)
self.next_target_button.label.set_fontsize(7.2)
for button in self.overlay_buttons.values():
button.label.set_fontsize(7)
self._sync_widgets()
self._update_hit_target_button_styles()
self._update_detail_tab_styles()
self._update_overlay_button_styles()

Expand Down Expand Up @@ -418,22 +452,6 @@ def _update_overlay_button_styles(self) -> None:
button.label.set_color(text)
button.label.set_fontweight("bold" if enabled else "normal")

def _update_hit_target_button_styles(self) -> None:
for detector_name, button in self.hit_target_buttons.items():
active = detector_name == self.hit_target
face = "#204a87" if active else "#f4f5f7"
edge = "#f4d35e" if active else "#c7cdd6"
text = "white" if active else "#5f6b7a"
button.ax.set_facecolor(face)
button.ax.patch.set_edgecolor(edge)
button.ax.patch.set_linewidth(2.0 if active else 1.1)
for spine in button.ax.spines.values():
spine.set_edgecolor(edge)
spine.set_linewidth(2.0 if active else 1.1)
button.hovercolor = "#3465a4" if active else "#eaedf2"
button.label.set_color(text)
button.label.set_fontweight("bold" if active else "normal")

@staticmethod
def _build_detail_lines(title: str, lines: list[str], width: int) -> list[str]:
rendered = [title]
Expand Down Expand Up @@ -469,6 +487,7 @@ def render(self, reset_zoom: bool = False) -> None:
signal = recording.signal(self.signal_index)
detector_events = self._detector_events(signal)
captions = self._captions_for_signal(recording, self.signal_index)
cross_channel_windows = self._cross_channel_windows(recording, self.signal_index, IN_BED_AWAKE_LABEL)
display_name = signal.display_name
unit = signal.unit or ""
_, _, decimals = self.channel_config.meta.get(signal.name, (signal.name, "", 2))
Expand All @@ -493,6 +512,8 @@ def render(self, reset_zoom: bool = False) -> None:
if self.show_nonwear:
for start, end in _nan_regions(y):
self.ax_main.axvspan(start, end, color="#d62728", alpha=0.08, label="nonwear")
for start, end in cross_channel_windows:
self.ax_main.axvspan(start, end, color="#f4d35e", alpha=0.22, label="in bed, awake")

for detector_name, result in detector_events:
if result.event_type == "trend" and self.show_trends:
Expand Down Expand Up @@ -609,14 +630,26 @@ def render(self, reset_zoom: bool = False) -> None:
self.ax_hit_target.text(
0.0,
1.02,
f"Find {self.hit_target.replace('Detector', '').lower()} hits" if self.hit_target else "Find hits",
"Find Hits",
transform=self.ax_hit_target.transAxes,
ha="left",
va="bottom",
fontsize=8.5,
color="#444444",
weight="bold",
)
self.ax_hit_target.text(
0.0,
0.55,
self._hit_target_button_label(self.hit_target),
transform=self.ax_hit_target.transAxes,
ha="left",
va="center",
fontsize=8.5,
color="#333333",
family="monospace",
weight="bold",
)
self.ax_hit_target.text(
1.0,
1.02,
Expand Down Expand Up @@ -692,9 +725,10 @@ def render(self, reset_zoom: bool = False) -> None:
"click a signal on the right or the overview heatmap to change channel",
"up/down changes row",
"left/right changes signal",
"choose a detector target above the signal list",
"n / p jump to next or previous hit for that detector",
"target< / target> changes the jump target",
"n / p jump to next or previous hit for that target",
"hit< / hit> buttons do the same",
"[ / ] also changes the jump target",
"mouse wheel over details scrolls",
"PageUp/PageDown also scroll details",
"overlay buttons toggle detector layers",
Expand Down
2 changes: 1 addition & 1 deletion extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

DEFAULT_TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json"

VALID_CAPTION_TYPES = ("statistical", "structural", "semantic")
VALID_CAPTION_TYPES = ("statistical", "structural", "semantic", "cross_channel")


_ACTIVITY_RE = re.compile(r"HKWorkoutActivityType(.+)$")
Expand Down
25 changes: 25 additions & 0 deletions extractors/cross_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

from extractors import CaptionExtractor, ChannelConfig
from synthesizers import CrossChannelSynthesizer
from timef.schema import Annotation, Recording


class CrossChannelExtractor(CaptionExtractor):
caption_type = "cross_channel"

def __init__(self, config: ChannelConfig, synthesizers: list[CrossChannelSynthesizer]):
super().__init__(config)
self.synthesizers = synthesizers

def extract(self, row: Recording) -> list[Annotation]:
results: list[Annotation] = []
for synth in self.synthesizers:
results.extend(synth.synthesize(row, self.config))
return results
31 changes: 31 additions & 0 deletions synthesizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#
# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md)
# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project.
#
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

import abc

import numpy as np

from extractors import ChannelConfig
from timef.schema import Annotation, Recording


class CrossChannelSynthesizer(abc.ABC):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the all the helper methods in __init__.py are better suited in a _helper.py

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed in c2be771

@abc.abstractmethod
def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]: ...


def contiguous_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]:
if not mask.any():
return []

padded = np.concatenate(([False], mask, [False]))
diffs = np.diff(padded.astype(np.int8))
starts = np.where(diffs == 1)[0]
ends = np.where(diffs == -1)[0]
keep = (ends - starts) >= min_duration
return list(zip(starts[keep].tolist(), ends[keep].tolist()))
Loading