diff --git a/captionizer.py b/captionizer.py index 9e2326c..c20c319 100644 --- a/captionizer.py +++ b/captionizer.py @@ -53,6 +53,10 @@ def run( from mhc.dataset import MHCDataset from mhc.transformer import MHCTransformer from mhc.constants import MHC_CHANNEL_CONFIG + from extractors.cross_channel import CrossChannelExtractor + from synthesizers.cardio import CardioSynthesizer + from synthesizers.sleep import SleepSynthesizer + from synthesizers.stationary_activity import StationaryActivitySynthesizer from extractors.statistical import StatisticalExtractor from extractors.structural import StructuralExtractor from models.local import LocalConfig, LocalModel @@ -66,6 +70,14 @@ def run( StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG), SemanticExtractor(MHC_CHANNEL_CONFIG), + CrossChannelExtractor( + MHC_CHANNEL_CONFIG, + synthesizers=[ + SleepSynthesizer(min_duration=5), + StationaryActivitySynthesizer(min_duration=5), + CardioSynthesizer(min_duration=5), + ], + ), ]) captionizer = Captionizer(dataset, MHCTransformer(), annotator) diff --git a/explorer.py b/explorer.py index 21ce81b..8b3b2fe 100644 --- a/explorer.py +++ b/explorer.py @@ -17,14 +17,64 @@ from annotator import Annotator from extractors import ChannelConfig +from extractors.cross_channel import CrossChannelExtractor from extractors.semantic import SemanticExtractor +from extractors.statistical import StatisticalExtractor from extractors.structural import StructuralExtractor +from synthesizers.cardio import CardioSynthesizer +from synthesizers.sleep import SleepSynthesizer +from synthesizers.stationary_activity import StationaryActivitySynthesizer from mhc.constants import MHC_CHANNEL_CONFIG from mhc.dataset import MHCDataset from mhc.transformer import MHCTransformer from timef.schema import Recording, SignalView from transformer import Transformer +CROSS_CHANNEL_TARGETS = { + "InBedAwake": { + "label": "in_bed_not_sleeping", + "display": "inbed+awake", + 
"overlay_label": "in bed, awake", + "color": "#f4d35e", + "alpha": 0.22, + }, + "StationaryHIIT": { + "label": "stationary_hiit", + "display": "stationary hiit", + "overlay_label": "stationary HIIT", + "color": "#f28482", + "alpha": 0.20, + }, + "StationaryStrength": { + "label": "stationary_strength", + "display": "stationary strength", + "overlay_label": "stationary strength", + "color": "#84a59d", + "alpha": 0.20, + }, + "StationaryFunctional": { + "label": "stationary_functional", + "display": "stationary functional", + "overlay_label": "stationary functional", + "color": "#577590", + "alpha": 0.20, + }, + "CardioRunning": { + "label": "cardio_running", + "display": "cardio running", + "overlay_label": "cardio running", + "color": "#43aa8b", + "alpha": 0.20, + }, + "CardioCycling": { + "label": "cardio_cycling", + "display": "cardio cycling", + "overlay_label": "cardio cycling", + "color": "#277da1", + "alpha": 0.20, + }, +} + def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Interactive explorer for MHC rows, signals, and detector outputs.") @@ -87,26 +137,41 @@ def __init__( channel_config: ChannelConfig, row_index: int = 0, signal_index: int = 0, + include_cross_channel: bool = True, ) -> None: self.dataset = dataset self.transformer = transformer self.channel_config = channel_config - self.annotator = Annotator([ + extractors = [ + StatisticalExtractor(channel_config), StructuralExtractor(channel_config), SemanticExtractor(channel_config), - ]) + ] + if include_cross_channel: + extractors.append( + CrossChannelExtractor( + channel_config, + synthesizers=[ + SleepSynthesizer(min_duration=5), + StationaryActivitySynthesizer(min_duration=5), + CardioSynthesizer(min_duration=5), + ], + ) + ) + self.annotator = Annotator(extractors) self.row_index = min(max(0, row_index), len(self.dataset) - 1) self.signal_index = min(max(0, signal_index), len(channel_config.names) - 1) self.show_trends = True self.show_spikes = True + 
self.show_nonwear = True self.detail_mode = "events" self.details_scroll = 0 self.details_page_lines = 12 - self.hit_target_names = self._available_detector_names() + self.hit_target_names = self._available_hit_target_names() self.hit_target = self.hit_target_names[0] if self.hit_target_names else None - self.search_status = "Use hit< / hit> to jump to the selected event type." + self.search_status = "Use hit< / hit> to jump to the selected event." self._ignore_widget_events = False @@ -136,19 +201,11 @@ def __init__( self.detail_tab_buttons[label] = Button(ax, label) self.detail_up_ax = self.fig.add_axes([0.935, 0.085, 0.022, 0.028]) self.detail_down_ax = self.fig.add_axes([0.958, 0.085, 0.022, 0.028]) - self.prev_hit_ax = self.fig.add_axes([0.76, 0.698, 0.035, 0.038]) - self.next_hit_ax = self.fig.add_axes([0.945, 0.698, 0.035, 0.038]) - self.hit_target_buttons: dict[str, Button] = {} - if self.hit_target_names: - n_targets = len(self.hit_target_names) - total_gap = 0.004 * max(0, n_targets - 1) - available_width = 0.145 - width = (available_width - total_gap) / max(1, n_targets) - for i, detector_name in enumerate(self.hit_target_names): - x0 = 0.805 + i * (width + 0.004) - ax = self.fig.add_axes([x0, 0.698, width, 0.038]) - self.hit_target_buttons[detector_name] = Button(ax, _hit_target_label(detector_name)) - overlay_labels = ["trend", "spike"] + self.prev_target_ax = self.fig.add_axes([0.76, 0.698, 0.045, 0.038]) + self.next_target_ax = self.fig.add_axes([0.81, 0.698, 0.045, 0.038]) + self.prev_hit_ax = self.fig.add_axes([0.895, 0.698, 0.040, 0.038]) + self.next_hit_ax = self.fig.add_axes([0.94, 0.698, 0.040, 0.038]) + overlay_labels = ["trend", "spike", "nonwear"] self.overlay_buttons: dict[str, Button] = {} start_x = 0.83 button_width = 0.035 @@ -169,6 +226,8 @@ def __init__( self.reset_zoom_button = Button(self.reset_zoom_ax, "reset") self.prev_row_button = Button(self.prev_row_ax, "<") self.next_row_button = Button(self.next_row_ax, ">") + 
self.prev_target_button = Button(self.prev_target_ax, "target<") + self.next_target_button = Button(self.next_target_ax, "target>") self.prev_hit_button = Button(self.prev_hit_ax, "hit<") self.next_hit_button = Button(self.next_hit_ax, "hit>") self.detail_up_button = Button(self.detail_up_ax, "^") @@ -180,12 +239,12 @@ def __init__( self.reset_zoom_button.on_clicked(lambda _: self.render(reset_zoom=True)) self.prev_row_button.on_clicked(lambda _: self._set_row(self.row_index - 1)) self.next_row_button.on_clicked(lambda _: self._set_row(self.row_index + 1)) + self.prev_target_button.on_clicked(lambda _: self._cycle_hit_target(-1)) + self.next_target_button.on_clicked(lambda _: self._cycle_hit_target(1)) self.prev_hit_button.on_clicked(lambda _: self._jump_to_hit(-1)) self.next_hit_button.on_clicked(lambda _: self._jump_to_hit(1)) self.detail_up_button.on_clicked(lambda _: self._scroll_details(-1)) self.detail_down_button.on_clicked(lambda _: self._scroll_details(1)) - for detector_name, button in self.hit_target_buttons.items(): - button.on_clicked(lambda _, name=detector_name: self._set_hit_target(name)) for label, button in self.detail_tab_buttons.items(): button.on_clicked(lambda _, name=label: self._set_detail_mode(name)) for label, button in self.overlay_buttons.items(): @@ -203,13 +262,26 @@ def _load_row_bundle(self, row_index: int) -> Recording: recording.annotations.extend(self.annotator.annotate(recording)) return recording - def _available_detector_names(self) -> list[str]: + def _available_hit_target_names(self) -> list[str]: names = { detector.__class__.__name__ for detectors in self.channel_config.detectors.values() for detector in detectors } - return sorted(names) + return sorted(names) + list(CROSS_CHANNEL_TARGETS) + + @staticmethod + def _hit_target_button_label(hit_target: str) -> str: + if hit_target in CROSS_CHANNEL_TARGETS: + return CROSS_CHANNEL_TARGETS[hit_target]["display"] + return hit_target.replace("Detector", "").lower() + + 
@staticmethod + def _matches_hit_target(hit_target: str, detector_name: str, result: object) -> bool: + if ":" not in hit_target: + return detector_name == hit_target + target_detector_name, target_event_type = hit_target.split(":", 1) + return detector_name == target_detector_name and getattr(result, "event_type", None) == target_event_type @staticmethod def _matches_hit_target(hit_target: str, detector_name: str, result: object) -> bool: @@ -251,16 +323,26 @@ def _on_toggle(self, label: str) -> None: self.show_trends = not self.show_trends elif label == "spike": self.show_spikes = not self.show_spikes + elif label == "nonwear": + self.show_nonwear = not self.show_nonwear self._update_overlay_button_styles() self.render(reset_zoom=False) def _set_hit_target(self, detector_name: str) -> None: if detector_name != self.hit_target: self.hit_target = detector_name - self.search_status = f"Jump target set to {_hit_target_label(detector_name)}." - self._update_hit_target_button_styles() + self.search_status = f"Jump target set to {self._hit_target_button_label(detector_name)}." 
self.render(reset_zoom=False) + def _cycle_hit_target(self, step: int) -> None: + if not self.hit_target_names: + return + if self.hit_target not in self.hit_target_names: + self._set_hit_target(self.hit_target_names[0]) + return + current = self.hit_target_names.index(self.hit_target) + self._set_hit_target(self.hit_target_names[(current + step) % len(self.hit_target_names)]) + def _set_detail_mode(self, mode: str) -> None: if mode != self.detail_mode: self.detail_mode = mode @@ -293,6 +375,10 @@ def _on_key_press(self, event) -> None: self._jump_to_hit(1) elif event.key == "p": self._jump_to_hit(-1) + elif event.key == "[": + self._cycle_hit_target(-1) + elif event.key == "]": + self._cycle_hit_target(1) def _on_scroll(self, event) -> None: if event.inaxes is not self.ax_details: @@ -331,30 +417,76 @@ def _row_detector_events(self, row_index: int) -> tuple[tuple[tuple[str, object] recording = self._load_row_bundle(row_index) return tuple(tuple(self._detector_events(signal)) for signal in recording.iter_channels()) + @lru_cache(maxsize=256) + def _row_hit_signal_indices(self, row_index: int, target_name: str) -> tuple[int, ...]: + if target_name in CROSS_CHANNEL_TARGETS: + recording = self._load_row_bundle(row_index) + label = CROSS_CHANNEL_TARGETS[target_name]["label"] + return tuple( + signal_idx + for signal_idx in range(recording.values.shape[0]) + if self._cross_channel_windows(recording, signal_idx, label) + ) + + row_signal_events = self._row_detector_events(row_index) + return tuple( + signal_idx + for signal_idx, events in enumerate(row_signal_events) + if any(detector_name == target_name for detector_name, _ in events) + ) + + def _signal_has_hit_target(self, row_index: int, signal_index: int) -> bool: + if self.hit_target is None: + return False + return signal_index in self._row_hit_signal_indices(row_index, self.hit_target) + def _jump_to_hit(self, step: int) -> None: if self.hit_target is None: - self.search_status = "No detector target is 
available." + self.search_status = "No search target is available." self.render(reset_zoom=False) return n_rows = len(self.dataset) n_signals = len(self.channel_config.names) - flat_index = self.row_index * n_signals + self.signal_index - - for offset in range(1, n_rows * n_signals): - candidate = (flat_index + step * offset) % (n_rows * n_signals) - row_index, signal_index = divmod(candidate, n_signals) - row_signal_events = self._row_detector_events(row_index)[signal_index] - if any(self._matches_hit_target(self.hit_target, detector_name, result) for detector_name, result in row_signal_events): - self.row_index = row_index - self.signal_index = signal_index - self.details_scroll = 0 - self.search_status = f"Jumped to row {row_index}, signal {signal_index} with {_hit_target_label(self.hit_target)}." - self._sync_widgets() - self.render(reset_zoom=True) - return - - self.search_status = f"No hits found for {_hit_target_label(self.hit_target)} in the scanned dataset." + start_row = self.row_index + for row_offset in range(n_rows): + row_index = (start_row + step * row_offset) % n_rows + hit_signals = self._row_hit_signal_indices(row_index, self.hit_target) + if not hit_signals: + continue + + if row_index == self.row_index: + if step > 0: + candidates = [idx for idx in hit_signals if idx > self.signal_index] + if candidates: + signal_index = candidates[0] + elif row_offset == 0: + continue + else: + signal_index = hit_signals[0] + else: + candidates = [idx for idx in hit_signals if idx < self.signal_index] + if candidates: + signal_index = candidates[-1] + elif row_offset == 0: + continue + else: + signal_index = hit_signals[-1] + else: + signal_index = hit_signals[0] if step > 0 else hit_signals[-1] + + self.row_index = row_index + self.signal_index = signal_index + self.details_scroll = 0 + self.search_status = ( + f"Jumped to row {row_index}, signal {signal_index} " + f"with {self._hit_target_button_label(self.hit_target)}." 
+ ) + self._sync_widgets() + self.render(reset_zoom=True) + return + + self.search_status = f"No hits found for {self._hit_target_button_label(self.hit_target)} in the scanned dataset." self.render(reset_zoom=False) @staticmethod @@ -366,6 +498,14 @@ def _captions_for_signal(recording: Recording, signal_idx: int) -> dict[str, lis grouped.setdefault(annotation.caption_type, []).append(annotation.text) return grouped + @staticmethod + def _cross_channel_windows(recording: Recording, signal_idx: int, label: str) -> list[tuple[int, int]]: + return [ + annotation.window + for annotation in recording.annotations_for_signal(signal_idx) + if annotation.label == label and annotation.window is not None + ] + @staticmethod def _overview_matrix(recording: Recording) -> np.ma.MaskedArray: rows = [] @@ -397,16 +537,15 @@ def _style_widgets(self) -> None: ): button.label.set_fontsize(8.5) - for button in self.hit_target_buttons.values(): - button.label.set_fontsize(8) for button in self.detail_tab_buttons.values(): button.label.set_fontsize(7.5) self.detail_up_button.label.set_fontsize(8) self.detail_down_button.label.set_fontsize(8) + self.prev_target_button.label.set_fontsize(7.2) + self.next_target_button.label.set_fontsize(7.2) for button in self.overlay_buttons.values(): button.label.set_fontsize(7) self._sync_widgets() - self._update_hit_target_button_styles() self._update_detail_tab_styles() self._update_overlay_button_styles() @@ -414,6 +553,7 @@ def _overlay_state(self, label: str) -> bool: return { "trend": self.show_trends, "spike": self.show_spikes, + "nonwear": self.show_nonwear, }[label] def _update_overlay_button_styles(self) -> None: @@ -432,22 +572,6 @@ def _update_overlay_button_styles(self) -> None: button.label.set_color(text) button.label.set_fontweight("bold" if enabled else "normal") - def _update_hit_target_button_styles(self) -> None: - for detector_name, button in self.hit_target_buttons.items(): - active = detector_name == self.hit_target - face = 
"#204a87" if active else "#f4f5f7" - edge = "#f4d35e" if active else "#c7cdd6" - text = "white" if active else "#5f6b7a" - button.ax.set_facecolor(face) - button.ax.patch.set_edgecolor(edge) - button.ax.patch.set_linewidth(2.0 if active else 1.1) - for spine in button.ax.spines.values(): - spine.set_edgecolor(edge) - spine.set_linewidth(2.0 if active else 1.1) - button.hovercolor = "#3465a4" if active else "#eaedf2" - button.label.set_color(text) - button.label.set_fontweight("bold" if active else "normal") - @staticmethod def _build_detail_lines(title: str, lines: list[str], width: int) -> list[str]: rendered = [title] @@ -484,6 +608,10 @@ def render(self, reset_zoom: bool = False) -> None: detector_events = self._detector_events(signal) spike_labels = self._spike_labels(detector_events) captions = self._captions_for_signal(recording, self.signal_index) + cross_channel_windows = { + target_name: self._cross_channel_windows(recording, self.signal_index, target["label"]) + for target_name, target in CROSS_CHANNEL_TARGETS.items() + } display_name = signal.display_name unit = signal.unit or "" _, _, decimals = self.channel_config.meta.get(signal.name, (signal.name, "", 2)) @@ -505,6 +633,19 @@ def render(self, reset_zoom: bool = False) -> None: ax.axis("off") self.ax_main.plot(x[valid], y[valid], color="steelblue", linewidth=1.0, label="signal") + if self.show_nonwear: + for start, end in _nan_regions(y): + self.ax_main.axvspan(start, end, color="#d62728", alpha=0.08, label="nonwear") + for target_name, windows in cross_channel_windows.items(): + target = CROSS_CHANNEL_TARGETS[target_name] + for start, end in windows: + self.ax_main.axvspan( + start, + end, + color=target["color"], + alpha=target["alpha"], + label=target["overlay_label"], + ) for detector_name, result in detector_events: if result.event_type == "trend" and self.show_trends: @@ -623,7 +764,7 @@ def render(self, reset_zoom: bool = False) -> None: self.ax_hit_target.text( 0.0, 1.02, - f"Find 
{_hit_target_label(self.hit_target)} hits" if self.hit_target else "Find hits", + "Find Hits", transform=self.ax_hit_target.transAxes, ha="left", va="bottom", @@ -631,6 +772,18 @@ def render(self, reset_zoom: bool = False) -> None: color="#444444", weight="bold", ) + self.ax_hit_target.text( + 0.0, + 0.55, + self._hit_target_button_label(self.hit_target), + transform=self.ax_hit_target.transAxes, + ha="left", + va="center", + fontsize=8.5, + color="#333333", + family="monospace", + weight="bold", + ) self.ax_hit_target.text( 1.0, 1.02, @@ -706,9 +859,10 @@ def render(self, reset_zoom: bool = False) -> None: "click a signal on the right or the overview heatmap to change channel", "up/down changes row", "left/right changes signal", - "choose a detector target above the signal list", - "n / p jump to next or previous hit for that detector", + "target< / target> changes the jump target", + "n / p jump to next or previous hit for that target", "hit< / hit> buttons do the same", + "[ / ] also changes the jump target", "mouse wheel over details scrolls", "PageUp/PageDown also scroll details", "overlay buttons toggle detector layers", @@ -778,6 +932,7 @@ def main() -> None: channel_config=channel_config, row_index=args.row_index, signal_index=args.signal_index, + include_cross_channel=not args.weekly, ) if args.save_path: explorer.save(args.save_path) diff --git a/extractors/__init__.py b/extractors/__init__.py index b97e004..107f0e6 100644 --- a/extractors/__init__.py +++ b/extractors/__init__.py @@ -9,16 +9,16 @@ import abc import pathlib import re -import zlib from dataclasses import dataclass, field from aggregators import MetricAggregator from detectors import StructuralDetector from timef.schema import Annotation, Recording +from util import seed_from_key DEFAULT_TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json" -VALID_CAPTION_TYPES = ("statistical", "structural", "semantic") +VALID_CAPTION_TYPES = ("statistical", 
"structural", "semantic", "cross_channel") _ACTIVITY_RE = re.compile(r"HKWorkoutActivityType(.+)$") @@ -59,7 +59,7 @@ def __init__(self, config: ChannelConfig): @staticmethod def _seed(key: str) -> int: - return zlib.crc32(key.encode("utf-8")) & 0xFFFFFFFF + return seed_from_key(key) @abc.abstractmethod def extract(self, row: Recording) -> list[Annotation]: diff --git a/extractors/cross_channel.py b/extractors/cross_channel.py new file mode 100644 index 0000000..df08f21 --- /dev/null +++ b/extractors/cross_channel.py @@ -0,0 +1,25 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +from extractors import CaptionExtractor, ChannelConfig +from synthesizers import CrossChannelSynthesizer +from timef.schema import Annotation, Recording + + +class CrossChannelExtractor(CaptionExtractor): + caption_type = "cross_channel" + + def __init__(self, config: ChannelConfig, synthesizers: list[CrossChannelSynthesizer]): + super().__init__(config) + self.synthesizers = synthesizers + + def extract(self, row: Recording) -> list[Annotation]: + results: list[Annotation] = [] + for synth in self.synthesizers: + results.extend(synth.synthesize(row, self.config)) + return results diff --git a/mhc/constants.py b/mhc/constants.py index b5e6389..662af21 100644 --- a/mhc/constants.py +++ b/mhc/constants.py @@ -58,6 +58,40 @@ ] SLEEP_CHANNELS = ["sleep:asleep", "sleep:inbed"] +WATCH_HR_CHANNEL = "hk_watch:HKQuantityTypeIdentifierHeartRate" +WATCH_DISTANCE_CHANNEL = "hk_watch:HKQuantityTypeIdentifierDistanceWalkingRunning" +WATCH_STEP_CHANNEL = "hk_watch:HKQuantityTypeIdentifierStepCount" + +CARDIO_WORKOUT_CHANNELS = ( + ( + "workout:HKWorkoutActivityTypeRunning", + "cardio_running", + "running", + ), + ( + "workout:HKWorkoutActivityTypeCycling", + 
"cardio_cycling", + "cycling", + ), +) + +STATIONARY_WORKOUT_CHANNELS = ( + ( + "workout:HKWorkoutActivityTypeHighIntensityIntervalTraining", + "stationary_hiit", + "stationary HIIT", + ), + ( + "workout:HKWorkoutActivityTypeTraditionalStrengthTraining", + "stationary_strength", + "traditional strength", + ), + ( + "workout:HKWorkoutActivityTypeFunctionalStrengthTraining", + "stationary_functional", + "functional", + ), +) MHC_CHANNEL_CONFIG = ChannelConfig( diff --git a/mhc_weekly/constants.py b/mhc_weekly/constants.py index f40923e..a8eee51 100644 --- a/mhc_weekly/constants.py +++ b/mhc_weekly/constants.py @@ -9,8 +9,8 @@ from extractors import ChannelConfig from aggregators import NonZeroAggregator -from detectors.spike import SpikeDetector from detectors.trend import TrendDetector +from detectors.spike import SpikeDetector from mhc.constants import ACTIVITY_CHANNELS, CHANNEL_NAMES, CONTINUOUS_CHANNELS, SLEEP_CHANNELS HOURLY_TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates_hourly.json" diff --git a/synthesizers/__init__.py b/synthesizers/__init__.py new file mode 100644 index 0000000..52ce8b3 --- /dev/null +++ b/synthesizers/__init__.py @@ -0,0 +1,10 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +from synthesizers._base import CrossChannelSynthesizer +from synthesizers._workout import WorkoutSynthesizer diff --git a/synthesizers/_base.py b/synthesizers/_base.py new file mode 100644 index 0000000..bfab02d --- /dev/null +++ b/synthesizers/_base.py @@ -0,0 +1,17 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import abc + +from extractors import ChannelConfig +from timef.schema import Annotation, Recording + + +class CrossChannelSynthesizer(abc.ABC): + @abc.abstractmethod + def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]: ... diff --git a/synthesizers/_helper.py b/synthesizers/_helper.py new file mode 100644 index 0000000..977fcc9 --- /dev/null +++ b/synthesizers/_helper.py @@ -0,0 +1,329 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import json +from functools import lru_cache + +import numpy as np + +from extractors import ChannelConfig +from mhc.constants import WATCH_HR_CHANNEL +from timef.schema import Recording +from util import seed_from_key + + +def index_or_none(row: Recording, channel_name: str) -> int | None: + try: + return row.channel_names.index(channel_name) + except ValueError: + return None + + +def positive_metric_values(row: Recording, idx: int | None, start: int, end: int) -> np.ndarray | None: + if idx is None: + return None + values = np.asarray(row.values[idx][start:end], dtype=float) + valid = np.isfinite(values) & (values > 0) + if not valid.any(): + return None + return values[valid] + + +def metric_mean(row: Recording, idx: int | None, start: int, end: int) -> float | None: + values = positive_metric_values(row, idx, start, end) + if values is None: + return None + return float(np.mean(values)) + + +def metric_peak(row: Recording, idx: int | None, start: int, end: int) -> float | None: + values = positive_metric_values(row, idx, start, end) + if values is None: + return None + return float(np.max(values)) + + +def metric_total(row: Recording, idx: int | None, start: int, end: int) -> float | None: 
+ values = positive_metric_values(row, idx, start, end) + if values is None: + return None + return float(np.sum(values)) + + +def metric_day_mean_delta(row: Recording, idx: int | None, start: int, end: int) -> float | None: + window_mean = metric_mean(row, idx, start, end) + if window_mean is None or idx is None: + return None + + day_values = positive_metric_values(row, idx, 0, row.values.shape[1]) + if day_values is None: + return None + return float(window_mean - np.mean(day_values)) + + +def metric_day_mean(row: Recording, idx: int | None) -> float | None: + if idx is None: + return None + day_values = positive_metric_values(row, idx, 0, row.values.shape[1]) + if day_values is None: + return None + return float(np.mean(day_values)) + + +def channel_meta(config: ChannelConfig, channel_name: str) -> tuple[str, str]: + display_name, unit, _ = config.meta.get(channel_name, (config.display_name(channel_name), "", 0)) + return display_name, unit + + +@lru_cache(maxsize=None) +def _load_template_payload(path: str) -> dict: + with open(path, encoding="utf-8") as f: + return json.load(f) + + +def _cross_channel_metric_templates(config: ChannelConfig) -> dict[str, list[str]]: + payload = _load_template_payload(str(config.templates_path)) + cross_channel = payload.get("cross_channel", {}) + metrics = cross_channel.get("metrics", {}) + return metrics if isinstance(metrics, dict) else {} + + +def _select_metric_template( + config: ChannelConfig, + row: Recording, + start: int, + end: int, + template_key: str, +) -> str | None: + templates = _cross_channel_metric_templates(config).get(template_key) + if not templates: + return None + idx = seed_from_key(f"{row.row_id}:{start}:{end}:{template_key}") % len(templates) + return templates[idx] + + +def render_metric_text( + config: ChannelConfig, + row: Recording, + start: int, + end: int, + template_key: str, + fallback: str, + **kwargs, +) -> str: + template = _select_metric_template(config, row, start, end, template_key) + 
return (template or fallback).format(**kwargs) + + +def sentence(text: str) -> str: + text = text.strip() + if not text: + return "" + return text if text.endswith(".") else f"{text}." + + +def finalize_caption_text(text: str, metrics_suffix: str) -> str: + if metrics_suffix: + return text.rstrip(".") + "." + return text + + +def _join_metric_fragments(fragments: list[str]) -> str: + return ", ".join(fragment.strip() for fragment in fragments if fragment.strip()) + + +def append_hr_metrics( + parts: list[str], + channel_idxs: list[int], + config: ChannelConfig, + row: Recording, + start: int, + end: int, + hr_idx: int | None, + elevated_threshold: float, + include_space_before_day_unit: bool = True, +) -> None: + """Append heart-rate summary sentences and include the HR channel when present.""" + hr_mean = metric_mean(row, hr_idx, start, end) + if hr_mean is None or hr_idx is None: + return + + hr_name, hr_unit = channel_meta(config, WATCH_HR_CHANNEL) + hr_peak = metric_peak(row, hr_idx, start, end) + summary_parts = [ + render_metric_text( + config, + row, + start, + end, + "hr_summary", + "averaging a {name} of {mean:.0f} {unit}", + name=hr_name, + mean=hr_mean, + unit=hr_unit, + ) + ] + if hr_peak is not None: + summary_parts.append( + render_metric_text( + config, + row, + start, + end, + "hr_peak", + "peaking at {peak:.0f} {unit}", + name=hr_name, + peak=hr_peak, + unit=hr_unit, + ) + ) + if elevated_threshold is not None and hr_mean > elevated_threshold: + summary_parts.append( + render_metric_text( + config, + row, + start, + end, + "hr_elevated", + "with an elevated {name} during this phase", + name=hr_name, + ) + ) + parts.append(sentence(_join_metric_fragments(summary_parts))) + + hr_day_delta = metric_day_mean_delta(row, hr_idx, start, end) + hr_day_mean = metric_day_mean(row, hr_idx) + if hr_day_delta is not None and hr_day_mean is not None: + direction = "higher" if hr_day_delta >= 0 else "lower" + day_mean_unit = f" {hr_unit}" if 
include_space_before_day_unit and hr_unit else hr_unit + template_key = "hr_day_delta_higher" if hr_day_delta >= 0 else "hr_day_delta_lower" + fallback = "The {name} was {delta:.0f} {unit} {direction} than the day's mean of {day_mean:.0f}{day_mean_unit}" + parts.append( + sentence( + render_metric_text( + config, + row, + start, + end, + template_key, + fallback, + name=hr_name, + delta=abs(hr_day_delta), + unit=hr_unit, + direction=direction, + day_mean=hr_day_mean, + day_mean_unit=day_mean_unit, + ) + ) + ) + + channel_idxs.append(hr_idx) + + +def append_distance_metrics( + parts: list[str], + channel_idxs: list[int], + config: ChannelConfig, + row: Recording, + start: int, + end: int, + distance_idx: int | None, +) -> None: + """Append distance summary sentences and include the distance channel when present.""" + distance_mean = metric_mean(row, distance_idx, start, end) + distance_total = metric_total(row, distance_idx, start, end) + if distance_mean is None or distance_idx is None: + return + + parts.append( + sentence( + render_metric_text( + config, + row, + start, + end, + "distance_mean", + "The watch recorded an average distance of {mean:.1f} m/min during this period", + mean=distance_mean, + ) + ) + ) + if distance_total is not None: + parts.append( + sentence( + render_metric_text( + config, + row, + start, + end, + "distance_total", + "The total distance recorded by the watch in that interval was {total:.1f} m", + total=distance_total, + ) + ) + ) + + channel_idxs.append(distance_idx) + + +def append_step_metrics( + parts: list[str], + channel_idxs: list[int], + config: ChannelConfig, + row: Recording, + start: int, + end: int, + step_idx: int | None, +) -> None: + """Append step-count summary sentences and include the step channel when present.""" + step_mean = metric_mean(row, step_idx, start, end) + step_total = metric_total(row, step_idx, start, end) + if step_mean is None or step_idx is None: + return + + parts.append( + sentence( + 
render_metric_text( + config, + row, + start, + end, + "step_mean", + "The watch recorded an average step count of {mean:.1f} steps/min during this period", + mean=step_mean, + ) + ) + ) + if step_total is not None: + parts.append( + sentence( + render_metric_text( + config, + row, + start, + end, + "step_total", + "The total step count recorded by the watch during that time was {total:.0f}", + total=step_total, + ) + ) + ) + + channel_idxs.append(step_idx) + + +def contiguous_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]: + if not mask.any(): + return [] + + padded = np.concatenate(([False], mask, [False])) + diffs = np.diff(padded.astype(np.int8)) + starts = np.where(diffs == 1)[0] + ends = np.where(diffs == -1)[0] + keep = (ends - starts) >= min_duration + return list(zip(starts[keep].tolist(), ends[keep].tolist())) diff --git a/synthesizers/_workout.py b/synthesizers/_workout.py new file mode 100644 index 0000000..2c97165 --- /dev/null +++ b/synthesizers/_workout.py @@ -0,0 +1,94 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import abc +import json + +import numpy as np + +from extractors import ChannelConfig +from synthesizers._base import CrossChannelSynthesizer +from synthesizers._helper import contiguous_windows, finalize_caption_text +from timef.schema import Annotation, Recording +from util import seed_from_key + + +class WorkoutSynthesizer(CrossChannelSynthesizer, abc.ABC): + def __init__(self, min_duration: int = 0): + self.min_duration = min_duration + + @property + @abc.abstractmethod + def template_key(self) -> str: + """Template key for cross-channel workout captions.""" + + @property + @abc.abstractmethod + def workout_channels(self) -> list[tuple[str, str, str]]: + """Triples of workout channel, annotation label, and activity name.""" + + @abc.abstractmethod + def _metrics_suffix( + self, + config: ChannelConfig, + label: str, + row: Recording, + start: int, + end: int, + channel_idxs: list[int], + ) -> str: + """Build the optional metrics suffix and append referenced channel indices.""" + + def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]: + templates = json.loads(config.templates_path.read_text())["cross_channel"][self.template_key] + time_unit = "hour" if config.time_unit == "hours" else "minute" + seed = seed_from_key(row.row_id) + + results: list[Annotation] = [] + for workout_channel, label, activity_name in self.workout_channels: + try: + workout_idx = row.channel_names.index(workout_channel) + except ValueError: + continue + + workout = np.asarray(row.values[workout_idx], dtype=float) + workout_active = np.isfinite(workout) & (workout > 0) + if not workout_active.any(): + continue + + windows = contiguous_windows(workout_active, self.min_duration) + for i, (start, end) in enumerate(windows): + end_inclusive = max(start, end - 1) + template = templates[(seed + i) % len(templates)] + channel_idxs = [workout_idx] + metrics_suffix = self._metrics_suffix( + 
config=config, + label=label, + row=row, + start=start, + end=end, + channel_idxs=channel_idxs, + ) + text = template.format( + activity_name=activity_name, + time_unit=time_unit, + start=start, + end=end_inclusive, + metrics_suffix=metrics_suffix, + ) + text = finalize_caption_text(text, metrics_suffix) + results.append( + Annotation( + caption_type="cross_channel", + text=text, + channel_idxs=tuple(channel_idxs), + window=(start, end), + label=label, + ) + ) + return results diff --git a/synthesizers/cardio.py b/synthesizers/cardio.py new file mode 100644 index 0000000..7e89bda --- /dev/null +++ b/synthesizers/cardio.py @@ -0,0 +1,60 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +from extractors import ChannelConfig +from mhc.constants import CARDIO_WORKOUT_CHANNELS, WATCH_DISTANCE_CHANNEL, WATCH_HR_CHANNEL, WATCH_STEP_CHANNEL +from synthesizers import WorkoutSynthesizer +from synthesizers._helper import ( + append_distance_metrics, + append_hr_metrics, + append_step_metrics, + index_or_none, +) +from timef.schema import Recording + + +class CardioSynthesizer(WorkoutSynthesizer): + def __init__( + self, + min_duration: int = 0, + hr_elevated_threshold_bpm: float = 100.0, + ): + super().__init__(min_duration=min_duration) + self.hr_elevated_threshold_bpm = hr_elevated_threshold_bpm + + @property + def template_key(self) -> str: + return "cardio" + + @property + def workout_channels(self) -> list[tuple[str, str, str]]: + return CARDIO_WORKOUT_CHANNELS + + def _metrics_suffix( + self, + config: ChannelConfig, + label: str, + row: Recording, + start: int, + end: int, + channel_idxs: list[int], + ) -> str: + hr_idx = index_or_none(row, WATCH_HR_CHANNEL) + distance_idx = index_or_none(row, WATCH_DISTANCE_CHANNEL) + step_idx = 
index_or_none(row, WATCH_STEP_CHANNEL) + parts: list[str] = [] + + append_hr_metrics(parts, channel_idxs, config, row, start, end, hr_idx, elevated_threshold=self.hr_elevated_threshold_bpm) + append_distance_metrics(parts, channel_idxs, config, row, start, end, distance_idx) + + if label == "cardio_running": + append_step_metrics(parts, channel_idxs, config, row, start, end, step_idx) + + if not parts: + return "" + return " " + " ".join(parts) diff --git a/synthesizers/sleep.py b/synthesizers/sleep.py new file mode 100644 index 0000000..ff852e7 --- /dev/null +++ b/synthesizers/sleep.py @@ -0,0 +1,56 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import json + +import numpy as np + +from extractors import ChannelConfig +from synthesizers import CrossChannelSynthesizer +from synthesizers._helper import contiguous_windows +from timef.schema import Annotation, Recording +from util import seed_from_key + + +class SleepSynthesizer(CrossChannelSynthesizer): + def __init__(self, min_duration: int = 0): + self.min_duration = min_duration + + def synthesize(self, row: Recording, config: ChannelConfig) -> list[Annotation]: + try: + in_bed_idx = row.channel_names.index("sleep:inbed") + asleep_idx = row.channel_names.index("sleep:asleep") + except ValueError: + return [] + + in_bed = np.asarray(row.values[in_bed_idx], dtype=float) + asleep = np.asarray(row.values[asleep_idx], dtype=float) + if not np.any(~np.isnan(asleep)): + return [] + # Mark intervals where the user is in bed but not recorded as asleep. 
+ mask = (~np.isnan(in_bed)) & (in_bed > 0) & ~((~np.isnan(asleep)) & (asleep > 0)) + + time_unit = "hour" if config.time_unit == "hours" else "minute" + templates = json.loads(config.templates_path.read_text())["cross_channel"]["sleep"] + seed = seed_from_key(row.row_id) + + results: list[Annotation] = [] + for i, (start, end) in enumerate(contiguous_windows(mask, self.min_duration)): + end_inclusive = max(start, end - 1) + template = templates[(seed + i) % len(templates)] + text = template.format(time_unit=time_unit, start=start, end=end_inclusive) + results.append( + Annotation( + caption_type="cross_channel", + text=text, + channel_idxs=(asleep_idx, in_bed_idx), + window=(start, end), + label="in_bed_not_sleeping", + ) + ) + return results diff --git a/synthesizers/stationary_activity.py b/synthesizers/stationary_activity.py new file mode 100644 index 0000000..005e07e --- /dev/null +++ b/synthesizers/stationary_activity.py @@ -0,0 +1,72 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +from extractors import ChannelConfig +from mhc.constants import STATIONARY_WORKOUT_CHANNELS, WATCH_DISTANCE_CHANNEL, WATCH_HR_CHANNEL, WATCH_STEP_CHANNEL +from synthesizers import WorkoutSynthesizer +from synthesizers._helper import ( + append_distance_metrics, + append_hr_metrics, + append_step_metrics, + index_or_none, +) +from timef.schema import Recording + + +class StationaryActivitySynthesizer(WorkoutSynthesizer): + def __init__( + self, + min_duration: int = 0, + hr_floor_bpm: float = 100.0, + hr_margin_bpm: float = 15.0, + max_distance_per_min: float = 3.0, + ): + super().__init__(min_duration=min_duration) + self.hr_floor_bpm = hr_floor_bpm + self.hr_margin_bpm = hr_margin_bpm + self.max_distance_per_min = max_distance_per_min + + @property + def template_key(self) -> str: + return "stationary_activity" + + @property + def workout_channels(self) -> list[tuple[str, str, str]]: + return STATIONARY_WORKOUT_CHANNELS + + def _metrics_suffix( + self, + config: ChannelConfig, + label: str, + row: Recording, + start: int, + end: int, + channel_idxs: list[int], + ) -> str: + hr_idx = index_or_none(row, WATCH_HR_CHANNEL) + distance_idx = index_or_none(row, WATCH_DISTANCE_CHANNEL) + step_idx = index_or_none(row, WATCH_STEP_CHANNEL) + parts: list[str] = [] + + append_hr_metrics( + parts, + channel_idxs, + config, + row, + start, + end, + hr_idx, + elevated_threshold=self.hr_floor_bpm, + include_space_before_day_unit=False, + ) + append_distance_metrics(parts, channel_idxs, config, row, start, end, distance_idx) + append_step_metrics(parts, channel_idxs, config, row, start, end, step_idx) + + if not parts: + return "" + return " " + " ".join(parts) diff --git a/templates/templates.json b/templates/templates.json index c9719f3..8d8dec9 100644 --- a/templates/templates.json +++ b/templates/templates.json @@ -72,5 +72,93 @@ "{name} occurred at minute {start}.", "{name} was observed at minute
{start}.", "Identified {name} at minute {start}." - ] + ], + "cross_channel": { + "metrics": { + "hr_summary": [ + "averaging a {name} of {mean:.0f} {unit}", + "with a mean {name} of {mean:.0f} {unit}", + "at roughly {mean:.0f} {unit} {name} on average", + "with {name} averaging around {mean:.0f} {unit}", + "where {name} averaged at {mean:.0f} {unit}" + ], + "hr_peak": [ + "peaking at {peak:.0f} {unit}", + "reaching a peak {name} of {peak:.0f} {unit}", + "with a maximum of {peak:.0f} {unit}", + "with {name} climbing to {peak:.0f} {unit} at its highest", + "where {name} topped out at {peak:.0f} {unit}" + ], + "hr_elevated": [ + "with an elevated {name} during this phase", + "indicating an elevated {name} in this interval", + "showing a heightened {name} throughout this phase", + "suggesting the {name} stayed elevated throughout this stretch", + "with the {name} remaining noticeably elevated across the interval" + ], + "hr_day_delta_higher": [ + "The {name} was {delta:.0f} {unit} higher than the day's mean of {day_mean:.0f}{day_mean_unit}", + "Compared with the day's mean of {day_mean:.0f}{day_mean_unit}, the {name} was {delta:.0f} {unit} higher", + "This interval's {name} was {delta:.0f} {unit} above the day's mean of {day_mean:.0f}{day_mean_unit}", + "The {name} sat about {delta:.0f} {unit} above the day's mean of {day_mean:.0f}{day_mean_unit}", + "Relative to the day's mean of {day_mean:.0f}{day_mean_unit}, the {name} ran about {delta:.0f} {unit} higher" + ], + "hr_day_delta_lower": [ + "The {name} was {delta:.0f} {unit} lower than the day's mean of {day_mean:.0f}{day_mean_unit}", + "Compared with the day's mean of {day_mean:.0f}{day_mean_unit}, the {name} was {delta:.0f} {unit} lower", + "This interval's {name} was {delta:.0f} {unit} below the day's mean of {day_mean:.0f}{day_mean_unit}", + "The {name} sat about {delta:.0f} {unit} below the day's mean of {day_mean:.0f}{day_mean_unit}", + "Relative to the day's mean of {day_mean:.0f}{day_mean_unit}, the {name} ran 
about {delta:.0f} {unit} lower" + ], + "distance_mean": [ + "The watch recorded an average distance of {mean:.1f} m/min during this period", + "Average watch distance during this interval was {mean:.1f} m/min", + "The watch logged roughly {mean:.1f} m/min on average in this period", + "The watch logged about {mean:.1f} m/min on average over this stretch", + "Over this interval, the watch averaged {mean:.1f} m/min" + ], + "distance_total": [ + "The total distance recorded by the watch in that interval was {total:.1f} m", + "Across that interval, the watch recorded {total:.1f} m in total distance", + "The watch logged a total distance of {total:.1f} m during that time", + "Over the full interval, the watch recorded {total:.1f} m of distance", + "In total, the watch logged {total:.1f} m across the interval" + ], + "step_mean": [ + "The watch recorded an average step count of {mean:.1f} steps/min during this period", + "Average watch step count during this interval was {mean:.1f} steps/min", + "The watch logged roughly {mean:.1f} steps/min on average in this period", + "The watch logged about {mean:.1f} steps/min on average over this stretch", + "Over this interval, the watch averaged {mean:.1f} steps/min" + ], + "step_total": [ + "The total step count recorded by the watch during that time was {total:.0f}", + "Across that interval, the watch recorded {total:.0f} total steps", + "The watch logged a total of {total:.0f} steps during that time", + "Over the full interval, the watch recorded {total:.0f} steps", + "In total, the watch logged {total:.0f} steps across the interval" + ] + }, + "sleep": [ + "In bed but not sleeping from {time_unit} {start} to {end}.", + "The user was in bed but awake from {time_unit} {start} to {end}.", + "Awake while in bed between {time_unit} {start} and {end}.", + "In-bed wakefulness detected from {time_unit} {start} to {end}.", + "The user remained awake in bed from {time_unit} {start} to {end}." 
+ ], + "stationary_activity": [ + "{activity_name} training was detected from {time_unit} {start} to {end}{metrics_suffix}.", + "The user did {activity_name} training between {time_unit} {start} and {end}{metrics_suffix}.", + "{activity_name} training took place from {time_unit} {start} to {end}{metrics_suffix}.", + "From {time_unit} {start} to {end}, the user was doing {activity_name} training{metrics_suffix}.", + "A stretch of {activity_name} training lasted from {time_unit} {start} to {end}{metrics_suffix}." + ], + "cardio": [ + "{activity_name} lasted from {time_unit} {start} to {end}{metrics_suffix}.", + "The user was doing {activity_name} between {time_unit} {start} and {end}{metrics_suffix}.", + "{activity_name} took place from {time_unit} {start} to {end}{metrics_suffix}.", + "From {time_unit} {start} to {end}, the user was in a {activity_name} interval{metrics_suffix}.", + "A period of {activity_name} ran from {time_unit} {start} to {end}{metrics_suffix}." + ] + } } diff --git a/util.py b/util.py new file mode 100644 index 0000000..ad367ad --- /dev/null +++ b/util.py @@ -0,0 +1,13 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import zlib + + +def seed_from_key(key: str) -> int: + return zlib.crc32(key.encode("utf-8")) & 0xFFFFFFFF diff --git a/visualizer.py b/visualizer.py index a1583d4..50c8a32 100644 --- a/visualizer.py +++ b/visualizer.py @@ -114,13 +114,28 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] from mhc.dataset import MHCDataset from mhc.transformer import MHCTransformer from mhc.constants import MHC_CHANNEL_CONFIG + from extractors.cross_channel import CrossChannelExtractor + from synthesizers.cardio import CardioSynthesizer + from synthesizers.sleep import SleepSynthesizer + from synthesizers.stationary_activity import StationaryActivitySynthesizer from extractors.statistical import StatisticalExtractor from extractors.structural import StructuralExtractor from annotator import Annotator from captionizer import Captionizer dataset = MHCDataset(min_wear_pct=90.0) - annotator = Annotator([StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG)]) + annotator = Annotator([ + StatisticalExtractor(MHC_CHANNEL_CONFIG), + StructuralExtractor(MHC_CHANNEL_CONFIG), + CrossChannelExtractor( + MHC_CHANNEL_CONFIG, + synthesizers=[ + SleepSynthesizer(min_duration=5), + StationaryActivitySynthesizer(min_duration=5), + CardioSynthesizer(min_duration=5), + ], + ), + ]) captionizer = Captionizer(dataset, MHCTransformer(), annotator) result, _ = captionizer.run(max_rows=1) for row in result.iter_rows():