diff --git a/captionizer.py b/captionizer.py index 9e2326c..bcc71fe 100644 --- a/captionizer.py +++ b/captionizer.py @@ -53,6 +53,8 @@ def run( from mhc.dataset import MHCDataset from mhc.transformer import MHCTransformer from mhc.constants import MHC_CHANNEL_CONFIG + from extractors.cross_channel import CrossChannelExtractor + from synthesizers.sleep import SleepSynthesizer from extractors.statistical import StatisticalExtractor from extractors.structural import StructuralExtractor from models.local import LocalConfig, LocalModel @@ -66,6 +68,7 @@ def run( StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG), SemanticExtractor(MHC_CHANNEL_CONFIG), + CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]), ]) captionizer = Captionizer(dataset, MHCTransformer(), annotator) diff --git a/explorer.py b/explorer.py index e1f14ef..a2f44f0 100644 --- a/explorer.py +++ b/explorer.py @@ -17,6 +17,8 @@ from annotator import Annotator from extractors import ChannelConfig +from extractors.cross_channel import CrossChannelExtractor +from synthesizers.sleep import SleepSynthesizer from extractors.statistical import StatisticalExtractor from extractors.structural import StructuralExtractor from mhc.constants import MHC_CHANNEL_CONFIG @@ -25,6 +27,9 @@ from timef.schema import Recording, SignalView from transformer import Transformer +IN_BED_AWAKE_LABEL = "in_bed_not_sleeping" +IN_BED_AWAKE_TARGET = "InBedAwake" + def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Interactive explorer for MHC rows, signals, and detector outputs.") @@ -93,6 +98,7 @@ def __init__( self.annotator = Annotator([ StatisticalExtractor(channel_config), StructuralExtractor(channel_config), + CrossChannelExtractor(channel_config, synthesizers=[SleepSynthesizer()]), ]) self.row_index = min(max(0, row_index), len(self.dataset) - 1) @@ -103,7 +109,7 @@ def __init__( self.detail_mode = "events" self.details_scroll = 0 self.details_page_lines = 12 - self.hit_target_names = self._available_detector_names() + self.hit_target_names = self._available_hit_target_names() self.hit_target = self.hit_target_names[0] if self.hit_target_names else None self.search_status = "Use hit< / hit> to jump to the selected event type." @@ -202,13 +208,13 @@ def _load_row_bundle(self, row_index: int) -> Recording: recording.annotations.extend(self.annotator.annotate(recording)) return recording - def _available_detector_names(self) -> list[str]: + def _available_hit_target_names(self) -> list[str]: names = { detector.__class__.__name__ for detectors in self.channel_config.detectors.values() for detector in detectors } - return sorted(names) + return sorted(names) + [IN_BED_AWAKE_TARGET] @staticmethod def _matches_hit_target(hit_target: str, detector_name: str, result: object) -> bool: @@ -260,6 +266,15 @@ def _set_hit_target(self, detector_name: str) -> None: self._update_hit_target_button_styles() self.render(reset_zoom=False) + def _cycle_hit_target(self, step: int) -> None: + if not self.hit_target_names: + return + if self.hit_target not in self.hit_target_names: + self._set_hit_target(self.hit_target_names[0]) + return + current = self.hit_target_names.index(self.hit_target) + self._set_hit_target(self.hit_target_names[(current + step) % len(self.hit_target_names)]) + def _set_detail_mode(self, mode: str) -> None: if mode != self.detail_mode: self.detail_mode = mode @@ -292,6 +307,10 @@ def _on_key_press(self, event) -> None: self._jump_to_hit(1) elif event.key == "p": self._jump_to_hit(-1) + elif event.key == "[": + self._cycle_hit_target(-1) + elif event.key == "]": + self._cycle_hit_target(1) def _on_scroll(self, event) -> None: if event.inaxes is not self.ax_details: @@ -332,9 +351,16 @@ def _row_detector_events(self, row_index: int) -> tuple[tuple[tuple[str, object] recording = self._load_row_bundle(row_index) return tuple(tuple(self._detector_events(signal)) for signal in recording.iter_channels()) + def _signal_has_hit_target(self, row_index: int, signal_index: int) -> bool: + if self.hit_target == IN_BED_AWAKE_TARGET: + recording = self._load_row_bundle(row_index) + return bool(self._cross_channel_windows(recording, signal_index, IN_BED_AWAKE_LABEL)) + row_signal_events = self._row_detector_events(row_index)[signal_index] + return any(detector_name == self.hit_target for detector_name, _ in row_signal_events) + def _jump_to_hit(self, step: int) -> None: if self.hit_target is None: - self.search_status = "No detector target is available." + self.search_status = "No search target is available." self.render(reset_zoom=False) return @@ -345,8 +371,7 @@ def _jump_to_hit(self, step: int) -> None: for offset in range(1, n_rows * n_signals): candidate = (flat_index + step * offset) % (n_rows * n_signals) row_index, signal_index = divmod(candidate, n_signals) - row_signal_events = self._row_detector_events(row_index)[signal_index] - if any(self._matches_hit_target(self.hit_target, detector_name, result) for detector_name, result in row_signal_events): + if self._signal_has_hit_target(row_index, signal_index): self.row_index = row_index self.signal_index = signal_index self.details_scroll = 0 @@ -367,6 +392,14 @@ def _captions_for_signal(recording: Recording, signal_idx: int) -> dict[str, lis grouped.setdefault(annotation.caption_type, []).append(annotation.text) return grouped + @staticmethod + def _cross_channel_windows(recording: Recording, signal_idx: int, label: str) -> list[tuple[int, int]]: + return [ + annotation.window + for annotation in recording.annotations_for_signal(signal_idx) + if annotation.label == label and annotation.window is not None + ] + @staticmethod def _overview_matrix(recording: Recording) -> np.ma.MaskedArray: rows = [] @@ -398,12 +431,12 @@ def _style_widgets(self) -> None: ): button.label.set_fontsize(8.5) - for button in self.hit_target_buttons.values(): - button.label.set_fontsize(8) for button in self.detail_tab_buttons.values(): button.label.set_fontsize(7.5) self.detail_up_button.label.set_fontsize(8) self.detail_down_button.label.set_fontsize(8) + for button in self.hit_target_buttons.values(): + button.label.set_fontsize(8) for button in self.overlay_buttons.values(): button.label.set_fontsize(7) self._sync_widgets() @@ -485,6 +518,7 @@ def render(self, reset_zoom: bool = False) -> None: detector_events = self._detector_events(signal) spike_labels = self._spike_labels(detector_events) captions = self._captions_for_signal(recording, self.signal_index) + cross_channel_windows = self._cross_channel_windows(recording, self.signal_index, IN_BED_AWAKE_LABEL) display_name = signal.display_name unit = signal.unit or "" _, _, decimals = self.channel_config.meta.get(signal.name, (signal.name, "", 2)) @@ -506,6 +540,8 @@ def render(self, reset_zoom: bool = False) -> None: ax.axis("off") self.ax_main.plot(x[valid], y[valid], color="steelblue", linewidth=1.0, label="signal") + for start, end in cross_channel_windows: + self.ax_main.axvspan(start, end, color="#f4d35e", alpha=0.22, label="in bed, awake") for detector_name, result in detector_events: if result.event_type == "trend" and self.show_trends: @@ -631,6 +667,18 @@ def render(self, reset_zoom: bool = False) -> None: color="#444444", weight="bold", ) + self.ax_hit_target.text( + 0.0, + 0.55, + _hit_target_label(self.hit_target), + transform=self.ax_hit_target.transAxes, + ha="left", + va="center", + fontsize=8.5, + color="#333333", + family="monospace", + weight="bold", + ) self.ax_hit_target.text( 1.0, 1.02, @@ -706,9 +754,10 @@ def render(self, reset_zoom: bool = False) -> None: "click a signal on the right or the overview heatmap to change channel", "up/down changes row", "left/right changes signal", - "choose a detector target above the signal list", - "n / p jump to next or previous hit for that detector", + "target< / target> changes the jump target", + "n / p jump to next or previous hit for that target", "hit< / hit> buttons do the same", + "[ / ] also changes the jump target", "mouse wheel over details scrolls", "PageUp/PageDown also scroll details", "overlay buttons toggle detector layers", diff --git a/extractors/__init__.py b/extractors/__init__.py index b97e004..9a94a6d 100644 --- a/extractors/__init__.py +++ b/extractors/__init__.py @@ -18,7 +18,7 @@ DEFAULT_TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json" -VALID_CAPTION_TYPES = ("statistical", "structural", "semantic") +VALID_CAPTION_TYPES = ("statistical", "structural", "semantic", "cross_channel") _ACTIVITY_RE = re.compile(r"HKWorkoutActivityType(.+)$") diff --git a/extractors/cross_channel.py b/extractors/cross_channel.py new file mode 100644 index 0000000..df08f21 --- /dev/null +++ b/extractors/cross_channel.py @@ -0,0 +1,25 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +from extractors import CaptionExtractor, ChannelConfig +from synthesizers import CrossChannelSynthesizer +from timef.schema import Annotation, Recording + + +class CrossChannelExtractor(CaptionExtractor): + caption_type = "cross_channel" + + def __init__(self, config: ChannelConfig, synthesizers: list[CrossChannelSynthesizer]): + super().__init__(config) + self.synthesizers = synthesizers + + def extract(self, row: Recording) -> list[Annotation]: + results: list[Annotation] = [] + for synth in self.synthesizers: + results.extend(synth.synthesize(row, self.config)) + return results diff --git a/synthesizers/__init__.py b/synthesizers/__init__.py new file mode 100644 index 0000000..5987202 --- /dev/null +++ b/synthesizers/__init__.py @@ -0,0 +1,31 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import abc + +import numpy as np + +from extractors import ChannelConfig +from timef.schema import Annotation, Recording + + +class CrossChannelSynthesizer(abc.ABC): + @abc.abstractmethod + def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]: ... + + +def contiguous_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]: + if not mask.any(): + return [] + + padded = np.concatenate(([False], mask, [False])) + diffs = np.diff(padded.astype(np.int8)) + starts = np.where(diffs == 1)[0] + ends = np.where(diffs == -1)[0] + keep = (ends - starts) >= min_duration + return list(zip(starts[keep].tolist(), ends[keep].tolist())) diff --git a/synthesizers/sleep.py b/synthesizers/sleep.py new file mode 100644 index 0000000..e1b0b47 --- /dev/null +++ b/synthesizers/sleep.py @@ -0,0 +1,53 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import json + +import numpy as np + +from extractors import CaptionExtractor, ChannelConfig +from synthesizers import CrossChannelSynthesizer, contiguous_windows +from timef.schema import Annotation, Recording + + +class SleepSynthesizer(CrossChannelSynthesizer): + def __init__(self, min_duration: int = 5): + self.min_duration = min_duration + + def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]: + try: + in_bed_idx = row.channel_names.index("sleep:inbed") + asleep_idx = row.channel_names.index("sleep:asleep") + except ValueError: + return [] + + in_bed = np.asarray(row.values[in_bed_idx], dtype=float) + asleep = np.asarray(row.values[asleep_idx], dtype=float) + if not np.any((~np.isnan(asleep)) & (asleep > 0)): + return [] + mask = (~np.isnan(in_bed)) & (in_bed > 0) & ~((~np.isnan(asleep)) & (asleep > 0)) + + time_unit = "hour" if config.time_unit == "hours" else "minute" + templates = json.loads(config.templates_path.read_text())["cross_channel"]["sleep"] + seed = CaptionExtractor._seed(row.row_id) + + results: list[Annotation] = [] + for i, (start, end) in enumerate(contiguous_windows(mask, self.min_duration)): + end_inclusive = max(start, end - 1) + template = templates[(seed + i) % len(templates)] + text = template.format(time_unit=time_unit, start=start, end=end_inclusive) + results.append( + Annotation( + caption_type="cross_channel", + text=text, + channel_idxs=(asleep_idx, in_bed_idx), + window=(start, end), + label="in_bed_not_sleeping", + ) + ) + return results diff --git a/templates/templates.json b/templates/templates.json index 60b79ff..e124c43 100644 --- a/templates/templates.json +++ b/templates/templates.json @@ -35,5 +35,14 @@ "{name} occurred at minute {start}.", "{name} was observed at minute {start}.", "Identified {name} at minute {start}." - ] + ], + "cross_channel": { + "sleep": [ + "In bed but not sleeping from {time_unit} {start} to {end}.", + "The user was in bed but awake from {time_unit} {start} to {end}.", + "Awake while in bed between {time_unit} {start} and {end}.", + "In-bed wakefulness detected from {time_unit} {start} to {end}.", + "The user remained awake in bed from {time_unit} {start} to {end}." + ] + } } diff --git a/visualizer.py b/visualizer.py index a1583d4..4e0ef35 100644 --- a/visualizer.py +++ b/visualizer.py @@ -114,13 +114,19 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] from mhc.dataset import MHCDataset from mhc.transformer import MHCTransformer from mhc.constants import MHC_CHANNEL_CONFIG + from extractors.cross_channel import CrossChannelExtractor + from synthesizers.sleep import SleepSynthesizer from extractors.statistical import StatisticalExtractor from extractors.structural import StructuralExtractor from annotator import Annotator from captionizer import Captionizer dataset = MHCDataset(min_wear_pct=90.0) - annotator = Annotator([StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG)]) + annotator = Annotator([ + StatisticalExtractor(MHC_CHANNEL_CONFIG), + StructuralExtractor(MHC_CHANNEL_CONFIG), + CrossChannelExtractor(MHC_CHANNEL_CONFIG, synthesizers=[SleepSynthesizer()]), + ]) captionizer = Captionizer(dataset, MHCTransformer(), annotator) result, _ = captionizer.run(max_rows=1) for row in result.iter_rows():