diff --git a/README.md b/README.md index e978985..0ac5359 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,11 @@ Dataset-agnostic captioning pipeline for sensor time-series data. ## Setup -Set the dataset path before running: +Install dependencies and set the dataset path before running: + +```bash +python3 -m pip install -r requirements.txt +``` ```bash export MHC_DATASET_DIR="../hf-daily_max-nonwear=50" @@ -13,5 +17,25 @@ export MHC_DATASET_DIR="../hf-daily_max-nonwear=50" ## Usage ```bash -python captionizer.py +python3 captionizer.py ``` + +## Explorer + +Use the interactive explorer to inspect one row at a time, switch signals, and see which detector events fired where on the time series. + +Start it with: + +```bash +python3 explorer.py +``` + +Useful controls: + +- Use the bottom row slider or `<` / `>` buttons to move between dataset rows. +- Click a signal in the right-hand signal list or in the channel overview heatmap to switch channels. +- Use the Matplotlib zoom and pan tools on the main plot to inspect parts of the signal in detail. +- Click `reset` or press `home` to reset the zoom. +- Use the overlay buttons to toggle `trend`, `plateau`, `spike`, `drop`, `gap`, and `nonwear` overlays. +- Use the `stats`, `events`, `captions`, and `help` tabs in the details panel to switch what metadata is shown. +- Scroll inside the details panel with the mouse wheel or the `^` / `v` buttons. diff --git a/detectors/spike.py b/detectors/spike.py index d26f129..abe65cc 100644 --- a/detectors/spike.py +++ b/detectors/spike.py @@ -7,6 +7,7 @@ from __future__ import annotations import numpy as np +from scipy.signal import find_peaks from detectors import DetectionResult, StructuralDetector @@ -15,8 +16,51 @@ class SpikeDetector(StructuralDetector): """Detects spikes and drops. 
""" - def __init__(self, filter_zeros: bool = False) -> None: + def __init__( + self, + filter_zeros: bool = False, + prominence_scale: float = 3.0, + min_prominence: float = 0.0, + min_distance: int = 1, + ) -> None: super().__init__(filter_zeros=filter_zeros) + self.prominence_scale = prominence_scale + self.min_prominence = min_prominence + self.min_distance = max(1, min_distance) def _detect(self, series: np.ndarray, indices: np.ndarray) -> list[DetectionResult]: - return [] + prominence = self._prominence_threshold(series) + if prominence <= 0: + return [] + + results: list[DetectionResult] = [] + seen_minutes: set[int] = set() + + for peak_idx in find_peaks(series, prominence=prominence, distance=self.min_distance)[0]: + minute = int(indices[peak_idx]) + if minute in seen_minutes: + continue + results.append(DetectionResult(event_type="spike", spike_minute=minute)) + seen_minutes.add(minute) + + for peak_idx in find_peaks(-series, prominence=prominence, distance=self.min_distance)[0]: + minute = int(indices[peak_idx]) + if minute in seen_minutes: + continue + results.append(DetectionResult(event_type="drop", spike_minute=minute)) + seen_minutes.add(minute) + + results.sort(key=lambda result: int(result.spike_minute)) + return results + + def _prominence_threshold(self, series: np.ndarray) -> float: + centered = series - np.median(series) + mad = float(np.median(np.abs(centered))) + if mad > 0: + scale = 1.4826 * mad + return max(self.min_prominence, self.prominence_scale * scale) + + spread = float(np.percentile(series, 95) - np.percentile(series, 5)) + if spread <= 1e-12: + return 0.0 + return max(self.min_prominence, 0.5 * spread) diff --git a/detectors/trend.py b/detectors/trend.py index 370f71d..ea075e2 100644 --- a/detectors/trend.py +++ b/detectors/trend.py @@ -5,18 +5,174 @@ # SPDX-License-Identifier: MIT # from __future__ import annotations +from dataclasses import dataclass import numpy as np from detectors import DetectionResult, StructuralDetector 
+@dataclass +class _TrendSegment: + direction: str + start_minute: int + end_minute: int + score: float + + class TrendDetector(StructuralDetector): - """Detects statistically significant trends. - """ + """Detects increasing/decreasing trends across multiple relative window sizes.""" - def __init__(self, filter_zeros: bool = False) -> None: + def __init__( + self, + filter_zeros: bool = False, + window_sizes: tuple[int, ...] | None = None, + window_fracs: tuple[float, ...] = (0.1, 0.25, 1.0), + min_window: int = 12, + max_window: int | None = None, + stride_frac: float = 0.25, + min_coverage: float = 0.75, + min_effect: float = 1.25, + min_r2: float = 0.5, + min_span: int = 10, + merge_gap: int = 5, + ) -> None: super().__init__(filter_zeros=filter_zeros) + self.window_sizes = window_sizes + self.window_fracs = tuple(window_fracs) + self.min_window = max(3, min_window) + self.max_window = max_window + self.stride_frac = stride_frac + self.min_coverage = min_coverage + self.min_effect = min_effect + self.min_r2 = min_r2 + self.min_span = min_span + self.merge_gap = merge_gap def _detect(self, series: np.ndarray, indices: np.ndarray) -> list[DetectionResult]: - return [] + window_sizes = self._resolve_window_sizes(len(series)) + segments: list[_TrendSegment] = [] + for window_size in window_sizes: + stride = max(1, int(round(window_size * self.stride_frac))) + for start_idx in self._window_starts(len(series), window_size, stride): + end_idx = start_idx + window_size + segment = self._classify_window(series[start_idx:end_idx], indices[start_idx:end_idx]) + if segment is not None: + segments.append(segment) + + merged = self._merge_segments(segments) + return [ + DetectionResult( + event_type="trend", + start_minute=segment.start_minute, + end_minute=segment.end_minute, + direction=segment.direction, + ) + for segment in merged + if segment.end_minute - segment.start_minute >= self.min_span + ] + + def _resolve_window_sizes(self, n_samples: int) -> list[int]: + sizes 
= set(self.window_sizes or ()) + for frac in self.window_fracs: + if frac <= 0: + continue + sizes.add(int(round(n_samples * frac))) + + resolved = [] + max_window = self.max_window or n_samples + for size in sorted(sizes): + size = max(self.min_window, min(size, max_window, n_samples)) + if size <= n_samples: + resolved.append(size) + + return sorted(set(resolved)) + + @staticmethod + def _window_starts(n_samples: int, window_size: int, stride: int) -> list[int]: + if window_size >= n_samples: + return [0] + + starts = list(range(0, n_samples - window_size + 1, stride)) + last_start = n_samples - window_size + if starts[-1] != last_start: + starts.append(last_start) + return starts + + def _classify_window( + self, + series: np.ndarray, + indices: np.ndarray, + ) -> _TrendSegment | None: + if len(series) < self.min_window: + return None + + start_minute = int(indices[0]) + end_minute = int(indices[-1]) + span = end_minute - start_minute + if span < self.min_span: + return None + + coverage = len(indices) / (span + 1) + if coverage < self.min_coverage: + return None + + x = indices.astype(float) + y = series.astype(float) + x_centered = x - np.mean(x) + y_centered = y - np.mean(y) + + denom = float(np.dot(x_centered, x_centered)) + if denom <= 0: + return None + + slope = float(np.dot(x_centered, y_centered) / denom) + fitted = np.mean(y) + slope * x_centered + + ss_tot = float(np.dot(y_centered, y_centered)) + if ss_tot <= 1e-12: + return None + + ss_res = float(np.dot(y - fitted, y - fitted)) + r2 = max(0.0, 1.0 - ss_res / ss_tot) + delta = slope * span + scale = float(np.std(y)) + if scale <= 1e-12: + return None + + effect = abs(delta) / scale + if effect < self.min_effect or r2 < self.min_r2: + return None + + direction = "increasing" if slope > 0 else "decreasing" + return _TrendSegment( + direction=direction, + start_minute=start_minute, + end_minute=end_minute, + score=effect * r2, + ) + + def _merge_segments(self, segments: list[_TrendSegment]) -> 
list[_TrendSegment]: + if not segments: + return [] + + segments.sort(key=lambda s: (s.start_minute, s.end_minute, -s.score)) + merged: list[_TrendSegment] = [] + + for segment in segments: + if not merged: + merged.append(segment) + continue + + prev = merged[-1] + if ( + segment.direction == prev.direction + and segment.start_minute <= prev.end_minute + self.merge_gap + ): + prev.end_minute = max(prev.end_minute, segment.end_minute) + prev.score = max(prev.score, segment.score) + continue + + merged.append(segment) + + return merged diff --git a/explorer.py b/explorer.py new file mode 100644 index 0000000..5072448 --- /dev/null +++ b/explorer.py @@ -0,0 +1,776 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import argparse +import textwrap +from functools import lru_cache + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.patches import Rectangle +from matplotlib.widgets import Button, Slider + +from annotator import Annotator +from extractors.statistical import StatisticalExtractor +from extractors.structural import StructuralExtractor +from mhc.constants import MHC_CHANNEL_CONFIG +from mhc.dataset import MHCDataset +from mhc.transformer import MHCTransformer +from timef.schema import Annotation, Sample, Signal + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Interactive explorer for MHC rows, signals, and detector outputs.") + parser.add_argument("--row-index", type=int, default=0, help="Initial dataset row index.") + parser.add_argument("--signal-index", type=int, default=0, help="Initial signal index.") + parser.add_argument("--min-wear-pct", type=float, default=0.0, help="Minimum wear percentage filter.") + parser.add_argument("--save-path", type=str, 
default=None, help="Optional snapshot path. Saves current view and exits.") + return parser.parse_args() + + +def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]]: + regions = [] + in_region = False + for i, val in enumerate(np.isnan(arr)): + if val and not in_region: + start = i + in_region = True + elif not val and in_region: + if i - start >= min_length: + regions.append((start, i - 1)) + in_region = False + if in_region and len(arr) - start >= min_length: + regions.append((start, len(arr) - 1)) + return regions + + +def _format_detector_event(detector_name: str, result) -> str: + if result.event_type == "trend": + return f"{detector_name}: {result.direction} {result.start_minute}-{result.end_minute}" + if result.event_type == "plateau": + return f"{detector_name}: plateau {result.start_minute}-{result.end_minute}" + if result.event_type in {"spike", "drop"}: + return f"{detector_name}: {result.event_type} @{result.spike_minute}" + if result.event_type == "gap": + return f"{detector_name}: gap {result.start_minute}-{result.end_minute}" + return f"{detector_name}: {result.event_type}" + + +def _display_name(signal_name: str) -> str: + return MHC_CHANNEL_CONFIG.meta.get(signal_name, (signal_name, "", 0))[0] + + +def _truncate(text: str, max_len: int = 34) -> str: + if len(text) <= max_len: + return text + return text[: max_len - 1] + "..." 
+ + +class SensorExplorer: + def __init__( + self, + dataset: MHCDataset, + row_index: int = 0, + signal_index: int = 0, + ) -> None: + self.dataset = dataset + self.transformer = MHCTransformer() + self.annotator = Annotator([ + StatisticalExtractor(MHC_CHANNEL_CONFIG), + StructuralExtractor(MHC_CHANNEL_CONFIG), + ]) + + self.row_index = min(max(0, row_index), len(self.dataset) - 1) + self.signal_index = min(max(0, signal_index), len(MHC_CHANNEL_CONFIG.names) - 1) + + self.show_trends = True + self.show_plateaus = True + self.show_spikes = True + self.show_drops = True + self.show_gaps = True + self.show_nonwear = True + self.detail_mode = "events" + self.details_scroll = 0 + self.details_page_lines = 12 + self.hit_target_names = self._available_detector_names() + self.hit_target = self.hit_target_names[0] if self.hit_target_names else None + self.search_status = "Use hit< / hit> to jump to the selected detector." + + self._ignore_widget_events = False + + self.fig = plt.figure(figsize=(17, 10)) + self.ax_main = self.fig.add_axes([0.04, 0.34, 0.70, 0.62]) + self.ax_overview = self.fig.add_axes([0.04, 0.16, 0.70, 0.12], sharex=self.ax_main) + self.ax_summary = self.fig.add_axes([0.77, 0.78, 0.21, 0.18]) + self.ax_hit_target = self.fig.add_axes([0.77, 0.715, 0.21, 0.045]) + self.ax_signal_list = self.fig.add_axes([0.77, 0.43, 0.21, 0.24]) + self.ax_details = self.fig.add_axes([0.76, 0.12, 0.22, 0.22]) + for ax in (self.ax_summary, self.ax_hit_target, self.ax_signal_list, self.ax_details): + ax.axis("off") + + self.reset_zoom_ax = self.fig.add_axes([0.04, 0.049, 0.055, 0.036]) + self.row_slider_ax = self.fig.add_axes([0.11, 0.060, 0.57, 0.024]) + self.prev_row_ax = self.fig.add_axes([0.70, 0.051, 0.035, 0.036]) + self.next_row_ax = self.fig.add_axes([0.74, 0.051, 0.035, 0.036]) + detail_tab_specs = [ + ("stats", 0.76), + ("events", 0.815), + ("captions", 0.870), + ("help", 0.925), + ] + self.detail_tab_buttons: dict[str, Button] = {} + for label, x0 in 
detail_tab_specs: + ax = self.fig.add_axes([x0 + 0.01, 0.375, 0.047, 0.028]) + self.detail_tab_buttons[label] = Button(ax, label) + self.detail_up_ax = self.fig.add_axes([0.935, 0.085, 0.022, 0.028]) + self.detail_down_ax = self.fig.add_axes([0.958, 0.085, 0.022, 0.028]) + self.prev_hit_ax = self.fig.add_axes([0.76, 0.698, 0.035, 0.038]) + self.next_hit_ax = self.fig.add_axes([0.945, 0.698, 0.035, 0.038]) + self.hit_target_buttons: dict[str, Button] = {} + if self.hit_target_names: + n_targets = len(self.hit_target_names) + total_gap = 0.004 * max(0, n_targets - 1) + available_width = 0.145 + width = (available_width - total_gap) / max(1, n_targets) + for i, detector_name in enumerate(self.hit_target_names): + x0 = 0.805 + i * (width + 0.004) + ax = self.fig.add_axes([x0, 0.698, width, 0.038]) + self.hit_target_buttons[detector_name] = Button(ax, detector_name.replace("Detector", "").lower()) + overlay_labels = ["trend", "plateau", "spike", "drop", "gap", "nonwear"] + self.overlay_buttons: dict[str, Button] = {} + start_x = 0.79 + button_width = 0.027 + gap = 0.004 + for i, label in enumerate(overlay_labels): + ax = self.fig.add_axes([start_x + i * (button_width + gap), 0.051, button_width, 0.036]) + self.overlay_buttons[label] = Button(ax, label) + + self.row_slider = Slider( + self.row_slider_ax, + "Row", + 0, + len(self.dataset) - 1, + valinit=self.row_index, + valstep=1, + valfmt="%d", + ) + self.reset_zoom_button = Button(self.reset_zoom_ax, "reset") + self.prev_row_button = Button(self.prev_row_ax, "<") + self.next_row_button = Button(self.next_row_ax, ">") + self.prev_hit_button = Button(self.prev_hit_ax, "hit<") + self.next_hit_button = Button(self.next_hit_ax, "hit>") + self.detail_up_button = Button(self.detail_up_ax, "^") + self.detail_down_button = Button(self.detail_down_ax, "v") + self._style_widgets() + self._sync_widgets() + + self.row_slider.on_changed(self._on_row_slider) + self.reset_zoom_button.on_clicked(lambda _: self.render(reset_zoom=True)) 
+ self.prev_row_button.on_clicked(lambda _: self._set_row(self.row_index - 1)) + self.next_row_button.on_clicked(lambda _: self._set_row(self.row_index + 1)) + self.prev_hit_button.on_clicked(lambda _: self._jump_to_hit(-1)) + self.next_hit_button.on_clicked(lambda _: self._jump_to_hit(1)) + self.detail_up_button.on_clicked(lambda _: self._scroll_details(-1)) + self.detail_down_button.on_clicked(lambda _: self._scroll_details(1)) + for detector_name, button in self.hit_target_buttons.items(): + button.on_clicked(lambda _, name=detector_name: self._set_hit_target(name)) + for label, button in self.detail_tab_buttons.items(): + button.on_clicked(lambda _, name=label: self._set_detail_mode(name)) + for label, button in self.overlay_buttons.items(): + button.on_clicked(lambda _, name=label: self._on_toggle(name)) + self.fig.canvas.mpl_connect("button_press_event", self._on_click) + self.fig.canvas.mpl_connect("key_press_event", self._on_key_press) + self.fig.canvas.mpl_connect("scroll_event", self._on_scroll) + + self.render(reset_zoom=True) + + @lru_cache(maxsize=12) + def _load_row_bundle(self, row_index: int) -> tuple[list[Signal], list[Sample], list[Annotation]]: + row = self.dataset[row_index] + signals = self.transformer.transform_row(row) + samples, annotations = self.annotator.annotate(signals) + return signals, samples, annotations + + @staticmethod + def _available_detector_names() -> list[str]: + names = { + detector.__class__.__name__ + for detectors in MHC_CHANNEL_CONFIG.detectors.values() + for detector in detectors + } + return sorted(names) + + def _set_row(self, row_index: int) -> None: + row_index = min(max(0, int(row_index)), len(self.dataset) - 1) + if row_index == self.row_index: + return + self.row_index = row_index + self.details_scroll = 0 + self._sync_widgets() + self.render(reset_zoom=True) + + def _set_signal(self, signal_index: int) -> None: + signal_index = min(max(0, int(signal_index)), len(MHC_CHANNEL_CONFIG.names) - 1) + if signal_index 
== self.signal_index: + return + self.signal_index = signal_index + self.details_scroll = 0 + self._sync_widgets() + self.render(reset_zoom=True) + + def _sync_widgets(self) -> None: + self._ignore_widget_events = True + self.row_slider.set_val(self.row_index) + self.row_slider_ax.set_title(f"Row {self.row_index} / {len(self.dataset) - 1}", loc="left", fontsize=10, pad=2) + self._ignore_widget_events = False + + def _on_row_slider(self, value: float) -> None: + if not self._ignore_widget_events: + self._set_row(int(value)) + + def _on_toggle(self, label: str) -> None: + if label == "trend": + self.show_trends = not self.show_trends + elif label == "plateau": + self.show_plateaus = not self.show_plateaus + elif label == "spike": + self.show_spikes = not self.show_spikes + elif label == "drop": + self.show_drops = not self.show_drops + elif label == "gap": + self.show_gaps = not self.show_gaps + elif label == "nonwear": + self.show_nonwear = not self.show_nonwear + self._update_overlay_button_styles() + self.render(reset_zoom=False) + + def _set_hit_target(self, detector_name: str) -> None: + if detector_name != self.hit_target: + self.hit_target = detector_name + self.search_status = f"Jump target set to {detector_name}." 
+ self._update_hit_target_button_styles() + self.render(reset_zoom=False) + + def _set_detail_mode(self, mode: str) -> None: + if mode != self.detail_mode: + self.detail_mode = mode + self.details_scroll = 0 + self._update_detail_tab_styles() + self.render(reset_zoom=False) + + def _on_click(self, event) -> None: + if event.inaxes is self.ax_overview and event.ydata is not None: + self._set_signal(int(round(event.ydata))) + elif event.inaxes is self.ax_signal_list and event.ydata is not None: + self._set_signal(int(round(event.ydata))) + + def _on_key_press(self, event) -> None: + if event.key == "up": + self._set_row(self.row_index - 1) + elif event.key == "down": + self._set_row(self.row_index + 1) + elif event.key == "left": + self._set_signal(self.signal_index - 1) + elif event.key == "right": + self._set_signal(self.signal_index + 1) + elif event.key == "home": + self.render(reset_zoom=True) + elif event.key == "pageup": + self._scroll_details(-1) + elif event.key == "pagedown": + self._scroll_details(1) + elif event.key == "n": + self._jump_to_hit(1) + elif event.key == "p": + self._jump_to_hit(-1) + + def _on_scroll(self, event) -> None: + if event.inaxes is not self.ax_details: + return + direction = -1 if event.button == "up" else 1 + self._scroll_details(direction) + + def _scroll_details(self, direction: int) -> None: + self.details_scroll += direction + self.render(reset_zoom=False) + + def _detector_events(self, signal: Signal) -> list[tuple[str, object]]: + events = [] + for detector in MHC_CHANNEL_CONFIG.detectors.get(signal.name, []): + detector_name = detector.__class__.__name__ + for result in detector.detect(signal.data): + events.append((detector_name, result)) + return events + + @lru_cache(maxsize=64) + def _row_detector_events(self, row_index: int) -> tuple[tuple[tuple[str, object], ...], ...]: + signals, _, _ = self._load_row_bundle(row_index) + return tuple(tuple(self._detector_events(signal)) for signal in signals) + + def 
_jump_to_hit(self, step: int) -> None: + n_rows = len(self.dataset) + n_signals = len(MHC_CHANNEL_CONFIG.names) + flat_index = self.row_index * n_signals + self.signal_index + + for offset in range(1, n_rows * n_signals): + candidate = (flat_index + step * offset) % (n_rows * n_signals) + row_index, signal_index = divmod(candidate, n_signals) + row_signal_events = self._row_detector_events(row_index)[signal_index] + if any(detector_name == self.hit_target for detector_name, _ in row_signal_events): + self.row_index = row_index + self.signal_index = signal_index + self.details_scroll = 0 + self.search_status = f"Jumped to row {row_index}, signal {signal_index} with {self.hit_target}." + self._sync_widgets() + self.render(reset_zoom=True) + return + + self.search_status = f"No hits found for {self.hit_target} in the scanned dataset." + self.render(reset_zoom=False) + + @staticmethod + def _captions_for_signal(signal_id: str, samples: list[Sample], annotations: list[Annotation]) -> dict[str, list[str]]: + sample_ids = { + sample.id + for sample in samples + if any(ref.signal_id == signal_id for ref in sample.signals) + } + grouped: dict[str, list[str]] = {} + for annotation in annotations: + if annotation.answer is None: + continue + if any(ref.sample_id in sample_ids for ref in annotation.samples): + grouped.setdefault(annotation.spec_id, []).append(annotation.answer) + return grouped + + @staticmethod + def _overview_matrix(signals: list[Signal]) -> np.ma.MaskedArray: + rows = [] + for signal in signals: + arr = np.asarray(signal.data, dtype=float) + normalized = np.full_like(arr, np.nan, dtype=float) + valid = ~np.isnan(arr) + if valid.any(): + values = arr[valid] + lo = float(np.nanpercentile(values, 5)) + hi = float(np.nanpercentile(values, 95)) + if hi - lo <= 1e-12: + normalized[valid] = 0.5 + else: + normalized[valid] = np.clip((values - lo) / (hi - lo), 0.0, 1.0) + rows.append(normalized) + return np.ma.masked_invalid(np.vstack(rows)) + + def 
_style_widgets(self) -> None: + self.row_slider.label.set_visible(False) + self.row_slider.valtext.set_visible(False) + + for button in ( + self.reset_zoom_button, + self.prev_row_button, + self.next_row_button, + self.prev_hit_button, + self.next_hit_button, + ): + button.label.set_fontsize(8.5) + + for button in self.hit_target_buttons.values(): + button.label.set_fontsize(8) + for button in self.detail_tab_buttons.values(): + button.label.set_fontsize(7.5) + self.detail_up_button.label.set_fontsize(8) + self.detail_down_button.label.set_fontsize(8) + for button in self.overlay_buttons.values(): + button.label.set_fontsize(6.8) + self._sync_widgets() + self._update_hit_target_button_styles() + self._update_detail_tab_styles() + self._update_overlay_button_styles() + + def _overlay_state(self, label: str) -> bool: + return { + "trend": self.show_trends, + "plateau": self.show_plateaus, + "spike": self.show_spikes, + "drop": self.show_drops, + "gap": self.show_gaps, + "nonwear": self.show_nonwear, + }[label] + + def _update_overlay_button_styles(self) -> None: + for label, button in self.overlay_buttons.items(): + enabled = self._overlay_state(label) + face = "#1f4f95" if enabled else "#f7f7f7" + edge = "#f4d35e" if enabled else "#b8c0cc" + text = "white" if enabled else "#6b7280" + button.ax.set_facecolor(face) + button.ax.patch.set_edgecolor(edge) + button.ax.patch.set_linewidth(2.4 if enabled else 1.2) + for spine in button.ax.spines.values(): + spine.set_edgecolor(edge) + spine.set_linewidth(2.4 if enabled else 1.2) + button.hovercolor = "#3465a4" if enabled else "#ebeff4" + button.label.set_color(text) + button.label.set_fontweight("bold" if enabled else "normal") + + def _update_hit_target_button_styles(self) -> None: + for detector_name, button in self.hit_target_buttons.items(): + active = detector_name == self.hit_target + face = "#204a87" if active else "#f4f5f7" + edge = "#f4d35e" if active else "#c7cdd6" + text = "white" if active else "#5f6b7a" + 
button.ax.set_facecolor(face) + button.ax.patch.set_edgecolor(edge) + button.ax.patch.set_linewidth(2.0 if active else 1.1) + for spine in button.ax.spines.values(): + spine.set_edgecolor(edge) + spine.set_linewidth(2.0 if active else 1.1) + button.hovercolor = "#3465a4" if active else "#eaedf2" + button.label.set_color(text) + button.label.set_fontweight("bold" if active else "normal") + + @staticmethod + def _build_detail_lines(title: str, lines: list[str], width: int) -> list[str]: + rendered = [title] + if not lines: + rendered.append(" none") + rendered.append("") + return rendered + for line in lines: + wrapped = textwrap.wrap(line, width=width) or [""] + rendered.extend(f" {part}" for part in wrapped) + rendered.append("") + return rendered + + def _update_detail_tab_styles(self) -> None: + for label, button in self.detail_tab_buttons.items(): + active = label == self.detail_mode + face = "#204a87" if active else "#f4f5f7" + edge = "#f4d35e" if active else "#c7cdd6" + text = "white" if active else "#5f6b7a" + button.ax.set_facecolor(face) + button.ax.patch.set_edgecolor(edge) + button.ax.patch.set_linewidth(2.2 if active else 1.1) + for spine in button.ax.spines.values(): + spine.set_edgecolor(edge) + spine.set_linewidth(2.2 if active else 1.1) + button.hovercolor = "#3465a4" if active else "#eaedf2" + button.label.set_color(text) + button.label.set_fontweight("bold" if active else "normal") + + def render(self, reset_zoom: bool = False) -> None: + signals, samples, annotations = self._load_row_bundle(self.row_index) + signal = signals[self.signal_index] + detector_events = self._detector_events(signal) + captions = self._captions_for_signal(signal.id, samples, annotations) + display_name, unit, decimals = MHC_CHANNEL_CONFIG.meta.get(signal.name, (signal.name, "", 2)) + + x = np.arange(len(signal.data)) + y = np.asarray(signal.data, dtype=float) + valid = ~np.isnan(y) + + old_xlim = self.ax_main.get_xlim() + old_ylim = self.ax_main.get_ylim() + + 
self.ax_main.clear() + self.ax_overview.clear() + self.ax_summary.clear() + self.ax_hit_target.clear() + self.ax_signal_list.clear() + self.ax_details.clear() + for ax in (self.ax_summary, self.ax_hit_target, self.ax_signal_list, self.ax_details): + ax.axis("off") + + self.ax_main.plot(x[valid], y[valid], color="steelblue", linewidth=1.0, label="signal") + if self.show_nonwear: + for start, end in _nan_regions(y): + self.ax_main.axvspan(start, end, color="#d62728", alpha=0.08) + + for detector_name, result in detector_events: + if result.event_type == "trend" and self.show_trends: + color = "#4daf4a" if result.direction == "increasing" else "#ff7f00" + label = f"{detector_name} ({result.direction})" + self.ax_main.axvspan(result.start_minute, result.end_minute, color=color, alpha=0.18, label=label) + elif result.event_type == "plateau" and self.show_plateaus: + self.ax_main.axvspan( + result.start_minute, + result.end_minute, + color="#377eb8", + alpha=0.16, + label=detector_name, + ) + elif result.event_type == "spike" and self.show_spikes: + minute = int(result.spike_minute) + if minute < len(y) and not np.isnan(y[minute]): + self.ax_main.scatter(minute, y[minute], color="#2ca02c", marker="^", s=38, zorder=4, label=detector_name) + self.ax_main.annotate(str(minute), (minute, y[minute]), xytext=(0, 8), textcoords="offset points", ha="center", fontsize=7) + elif result.event_type == "drop" and self.show_drops: + minute = int(result.spike_minute) + if minute < len(y) and not np.isnan(y[minute]): + self.ax_main.scatter(minute, y[minute], color="#d62728", marker="v", s=38, zorder=4, label=detector_name) + self.ax_main.annotate(str(minute), (minute, y[minute]), xytext=(0, -12), textcoords="offset points", ha="center", fontsize=7) + elif result.event_type == "gap" and self.show_gaps: + self.ax_main.axvspan(result.start_minute, result.end_minute, color="#d62728", alpha=0.12, label=detector_name) + + self.ax_main.set_title(f"Row {self.row_index} | {display_name}") + 
self.ax_main.set_ylabel(f"{display_name}\n({unit or 'value'})") + self.ax_main.set_xlabel("Minute of day") + self.ax_main.grid(alpha=0.2) + self.ax_main.margins(x=0) + + handles, labels = self.ax_main.get_legend_handles_labels() + deduped: dict[str, object] = {} + for handle, label in zip(handles, labels): + deduped.setdefault(label, handle) + if deduped: + self.ax_main.legend(deduped.values(), deduped.keys(), loc="upper right", fontsize=8) + + matrix = self._overview_matrix(signals) + cmap = plt.get_cmap("viridis").copy() + cmap.set_bad(color="#f1f1f1") + self.ax_overview.imshow(matrix, aspect="auto", interpolation="nearest", cmap=cmap, origin="upper") + self.ax_overview.axhspan( + self.signal_index - 0.5, + self.signal_index + 0.5, + facecolor="#f4d35e", + alpha=0.20, + edgecolor="#f4d35e", + linewidth=0, + ) + self.ax_overview.axhline(self.signal_index, color="white", linewidth=2) + self.ax_overview.set_title("Channel Overview", fontsize=10, loc="left", pad=4) + self.ax_overview.set_xlabel("Minute of day") + self.ax_overview.set_yticks([]) + self.ax_overview.tick_params(axis="x", labelsize=8) + self.ax_overview.text( + 1.0, + 1.02, + "click heatmap or signal list to change channel", + transform=self.ax_overview.transAxes, + ha="right", + va="bottom", + fontsize=8, + color="#555555", + ) + + valid_minutes = int(np.sum(valid)) + active_channels = sum(bool(s.metadata.get("has_any_data", False)) for s in signals) + total_nonwear = signal.metadata.get("total_nonwear_minutes") + wear_pct = None + if total_nonwear is not None: + wear_pct = (1440.0 - float(total_nonwear)) / 1440.0 * 100.0 + stats_text = "n/a" + if valid.any(): + values = y[valid] + stats_text = ( + f"mean={np.mean(values):.{decimals}f}\n" + f"std={np.std(values):.{decimals}f}\n" + f"min={np.min(values):.{decimals}f}\n" + f"max={np.max(values):.{decimals}f}" + ) + + detector_lines = [_format_detector_event(name, result) for name, result in detector_events] + if not detector_lines: + detector_lines = ["No 
detector events on this signal."] + + caption_lines = [] + for spec_id, values in captions.items(): + label = spec_id.split(":")[-1] + for value in values[:3]: + caption_lines.append(f"{label}: {value}") + if not caption_lines: + caption_lines = ["No captions for this signal."] + + self.ax_summary.set_xlim(0, 1) + self.ax_summary.set_ylim(0, 1) + self.ax_summary.add_patch(Rectangle((0.0, 0.72), 1.0, 0.28, facecolor="#204a87", edgecolor="none")) + self.ax_summary.text(0.03, 0.94, "Selected Signal", color="white", fontsize=9, va="top", weight="bold") + self.ax_summary.text(0.03, 0.80, _truncate(display_name, 26), color="white", fontsize=13, va="center", weight="bold") + self.ax_summary.text(0.03, 0.66, _truncate(signal.name, 34), color="#35506b", fontsize=8) + summary_lines = [ + f"row {self.row_index} signal {self.signal_index}/{len(signals) - 1}", + f"user {_truncate(str(signal.metadata.get('user_id', 'n/a')), 24)}", + f"date {signal.metadata.get('date', 'n/a')}", + f"wear {wear_pct:.1f}% nonwear {float(total_nonwear):.0f}m" if wear_pct is not None and total_nonwear is not None else "wear n/a", + f"active {active_channels}/{len(signals)} valid {valid_minutes}/1440", + f"has_data {signal.metadata.get('has_any_data', 'n/a')} nonzero_or_nan {signal.metadata.get('minutes_nonzero_or_nan', 'n/a')}", + ] + self.ax_summary.text( + 0.03, + 0.58, + "\n".join(summary_lines), + va="top", + ha="left", + fontsize=8.1, + family="monospace", + color="#222222", + ) + self.ax_hit_target.set_xlim(0, 1) + self.ax_hit_target.set_ylim(0, 1) + self.ax_hit_target.text( + 0.0, + 1.02, + f"Find {self.hit_target.replace('Detector', '').lower()} hits" if self.hit_target else "Find hits", + transform=self.ax_hit_target.transAxes, + ha="left", + va="bottom", + fontsize=8.5, + color="#444444", + weight="bold", + ) + self.ax_hit_target.text( + 1.0, + 1.02, + _truncate(self.search_status, 40), + transform=self.ax_hit_target.transAxes, + ha="right", + va="bottom", + fontsize=7.5, + color="#666666", 
+ ) + + self.ax_signal_list.set_xlim(0, 1) + self.ax_signal_list.set_ylim(len(signals), 0) + self.ax_signal_list.text( + 0.0, + 1.02, + "Signals", + transform=self.ax_signal_list.transAxes, + fontsize=10, + weight="bold", + color="#333333", + va="bottom", + ) + for idx, listed_signal in enumerate(signals): + y0 = idx + is_selected = idx == self.signal_index + is_active = bool(listed_signal.metadata.get("has_any_data", False)) + face = "#204a87" if is_selected else ("#f7f7f7" if idx % 2 == 0 else "#eeeeee") + edge = "#10253f" if is_selected else "#d0d0d0" + text_color = "white" if is_selected else ("#222222" if is_active else "#888888") + self.ax_signal_list.add_patch(Rectangle((0.0, y0), 1.0, 0.92, facecolor=face, edgecolor=edge, linewidth=0.8)) + self.ax_signal_list.text( + 0.03, + y0 + 0.46, + f"{idx:02d}", + va="center", + ha="left", + fontsize=8, + family="monospace", + color="#f4d35e" if is_selected else "#666666", + weight="bold", + ) + self.ax_signal_list.text( + 0.14, + y0 + 0.46, + _truncate(_display_name(listed_signal.name), 25), + va="center", + ha="left", + fontsize=8.8, + color=text_color, + weight="bold" if is_selected else "normal", + ) + self.ax_signal_list.text( + 0.99, + 1.01, + "click to select", + transform=self.ax_signal_list.transAxes, + ha="right", + va="bottom", + fontsize=7.5, + color="#666666", + ) + + stats_lines = stats_text.splitlines() if stats_text != "n/a" else ["n/a"] + if self.detail_mode == "stats": + detail_title = "Stats" + detail_lines = self._build_detail_lines("Stats", stats_lines, width=30) + elif self.detail_mode == "captions": + detail_title = "Captions" + detail_lines = self._build_detail_lines("Captions", caption_lines, width=30) + elif self.detail_mode == "help": + detail_title = "Help" + detail_lines = self._build_detail_lines( + "Help", + [ + "click a signal on the right or the overview heatmap to change channel", + "up/down changes row", + "left/right changes signal", + "choose a detector target above the signal 
list", + "n / p jump to next or previous hit for that detector", + "hit< / hit> buttons do the same", + "mouse wheel over details scrolls", + "PageUp/PageDown also scroll details", + "overlay buttons toggle detector layers", + ], + width=30, + ) + else: + detail_title = "Detector Events" + detail_lines = self._build_detail_lines("Detector Events", detector_lines, width=30) + detail_lines = detail_lines[:-1] if detail_lines and detail_lines[-1] == "" else detail_lines + max_scroll = max(0, len(detail_lines) - self.details_page_lines) + self.details_scroll = min(max(self.details_scroll, 0), max_scroll) + visible_lines = detail_lines[self.details_scroll:self.details_scroll + self.details_page_lines] + start_line = self.details_scroll + 1 if detail_lines else 0 + end_line = self.details_scroll + len(visible_lines) + self.ax_details.set_xlim(0, 1) + self.ax_details.set_ylim(0, 1) + self.ax_details.add_patch(Rectangle((0, 0), 1, 1, facecolor="#fbfbfb", edgecolor="#dddddd")) + track_x = 0.975 + track_y = 0.06 + track_h = 0.80 + self.ax_details.add_patch(Rectangle((track_x, track_y), 0.012, track_h, facecolor="#ececec", edgecolor="#d0d0d0")) + if detail_lines: + visible_frac = min(1.0, self.details_page_lines / len(detail_lines)) + thumb_h = max(0.10, track_h * visible_frac) + available_h = track_h - thumb_h + scroll_frac = 0.0 if max_scroll == 0 else self.details_scroll / max_scroll + thumb_y = track_y + available_h * (1.0 - scroll_frac) + self.ax_details.add_patch(Rectangle((track_x, thumb_y), 0.012, thumb_h, facecolor="#204a87", edgecolor="none")) + self.ax_details.text( + 0.03, + 0.90, + "\n".join(visible_lines), + va="top", + ha="left", + fontsize=7.8, + family="monospace", + color="#222222", + ) + + if reset_zoom: + self.ax_main.set_xlim(0, len(y) - 1) + else: + self.ax_main.set_xlim(old_xlim) + self.ax_main.set_ylim(old_ylim) + + self.fig.canvas.draw_idle() + + def save(self, path: str) -> None: + self.fig.savefig(path, dpi=150) + + +def main() -> None: + args = 
_parse_args() + dataset = MHCDataset(min_wear_pct=args.min_wear_pct) + explorer = SensorExplorer( + dataset=dataset, + row_index=args.row_index, + signal_index=args.signal_index, + ) + if args.save_path: + explorer.save(args.save_path) + print(f"Saved {args.save_path}") + plt.close(explorer.fig) + return + plt.show() + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 156cf55..43b5665 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ -numpy +datasets matplotlib +numpy +pyarrow +scipy torch Pillow transformers -datasets -pyarrow diff --git a/visualizer.py b/visualizer.py index 72b45fe..eae1661 100644 --- a/visualizer.py +++ b/visualizer.py @@ -6,12 +6,16 @@ # from __future__ import annotations +import argparse import textwrap from math import ceil import matplotlib.pyplot as plt import numpy as np +from matplotlib.lines import Line2D +from matplotlib.patches import Patch +from extractors import ChannelConfig from timef.schema import Annotation, Sample, Signal @@ -133,6 +137,137 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] return fig +def plot_detector_view( + signals: list[Signal], + config: ChannelConfig, + save_path: str | None = None, + n_cols: int = 3, +) -> plt.Figure | None: + channel_names = [s.name for s in signals] + n_channels = len(channel_names) + if n_channels == 0: + raise ValueError("signals must not be empty") + + data = np.stack([s.data for s in signals]) + time_minutes = np.arange(data.shape[1]) + + n_rows = ceil(n_channels / n_cols) + fig, axes_2d = plt.subplots(n_rows, n_cols, figsize=(20, 4.6 * n_rows), sharex="col") + axes = np.atleast_1d(axes_2d).flatten() + + legend_handles: dict[str, object] = { + "signal": Line2D([0], [0], color="steelblue", linewidth=1.0, label="signal"), + "trend:increasing": Patch(facecolor="#4daf4a", edgecolor="none", alpha=0.18, label="trend (increasing)"), + "trend:decreasing": Patch(facecolor="#ff7f00", 
edgecolor="none", alpha=0.18, label="trend (decreasing)"), + "plateau": Patch(facecolor="#377eb8", edgecolor="none", alpha=0.16, label="plateau"), + "spike": Line2D([0], [0], marker="^", linestyle="None", color="#2ca02c", markersize=6, label="spike"), + "drop": Line2D([0], [0], marker="v", linestyle="None", color="#d62728", markersize=6, label="drop"), + "gap": Patch(facecolor="#d62728", edgecolor="none", alpha=0.10, label="gap"), + } + + for i, (ax, signal) in enumerate(zip(axes[:n_channels], signals)): + series = data[i] + valid = ~np.isnan(series) + ax.plot(time_minutes[valid], series[valid], linewidth=0.8, color="steelblue") + + detector_summaries: list[str] = [] + for detector in config.detectors.get(signal.name, []): + detector_name = detector.__class__.__name__ + results = detector.detect(signal.data) + if not results: + detector_summaries.append(f"{detector_name}: none") + continue + + rendered = [] + for result in results: + if result.event_type == "trend": + direction = result.direction or "trend" + color = "#4daf4a" if direction == "increasing" else "#ff7f00" + ax.axvspan(result.start_minute, result.end_minute, color=color, alpha=0.18) + rendered.append(f"{direction} {result.start_minute}-{result.end_minute}") + elif result.event_type == "plateau": + ax.axvspan(result.start_minute, result.end_minute, color="#377eb8", alpha=0.16) + rendered.append(f"plateau {result.start_minute}-{result.end_minute}") + elif result.event_type == "spike": + minute = int(result.spike_minute) + value = signal.data[minute] + if not np.isnan(value): + ax.scatter(minute, value, color="#2ca02c", marker="^", s=28, zorder=3) + rendered.append(f"spike @{minute}") + elif result.event_type == "drop": + minute = int(result.spike_minute) + value = signal.data[minute] + if not np.isnan(value): + ax.scatter(minute, value, color="#d62728", marker="v", s=28, zorder=3) + rendered.append(f"drop @{minute}") + elif result.event_type == "gap": + ax.axvspan(result.start_minute, result.end_minute, 
color="#d62728", alpha=0.10) + rendered.append(f"gap {result.start_minute}-{result.end_minute}") + else: + rendered.append(result.event_type) + + detector_summaries.append(f"{detector_name}: {', '.join(rendered)}") + + display_name = config.meta.get(signal.name, (signal.name, "", 0))[0] + ax.set_title(display_name, fontsize=8) + ax.set_ylabel(signal.name, fontsize=6, rotation=90, ha="right", va="center") + ax.tick_params(axis="y", labelsize=6) + ax.tick_params(axis="x", labelsize=6, labelbottom=True) + ax.set_xlim(0, len(series)) + + summary_text = "\n".join(textwrap.fill(line, width=56) for line in detector_summaries) + if summary_text: + ax.text( + 0.01, + 0.99, + summary_text, + transform=ax.transAxes, + va="top", + ha="left", + fontsize=5.5, + bbox={"boxstyle": "round,pad=0.25", "facecolor": "white", "alpha": 0.85, "edgecolor": "#cccccc"}, + ) + + for col in range(n_cols): + axes[col].set_xticks(range(0, data.shape[1] + 1, 200)) + + for ax in axes[n_channels:]: + ax.set_visible(False) + + for ax in axes[:n_channels]: + ax.set_xlabel("Time (minutes)") + + meta = signals[0].metadata if signals else {} + title = f"{meta.get('user_id', '')} | {meta.get('date', '')} | detector debug view" + fig.suptitle(title, fontsize=10) + fig.legend( + handles=list(legend_handles.values()), + loc="upper center", + bbox_to_anchor=(0.5, 0.99), + ncol=6, + fontsize=7, + frameon=False, + ) + fig.tight_layout(rect=(0, 0, 1, 0.97)) + + if save_path: + fig.savefig(save_path, dpi=150) + plt.close(fig) + return None + + plt.show() + return fig + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Visualize caption or detector-debug plots for one dataset row.") + parser.add_argument("--mode", choices=("captions", "detectors"), default="captions") + parser.add_argument("--row-index", type=int, default=0, help="Dataset row index to visualize.") + parser.add_argument("--save-path", type=str, default=None, help="Optional output image path.") + 
parser.add_argument("--min-wear-pct", type=float, default=90.0, help="Minimum wear percentage filter.") + return parser.parse_args() + + if __name__ == "__main__": from mhc.dataset import MHCDataset from mhc.transformer import MHCTransformer @@ -142,10 +277,22 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] from annotator import Annotator from captionizer import Captionizer - dataset = MHCDataset(min_wear_pct=90.0) - annotator = Annotator([StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG)]) - captionizer = Captionizer(dataset, MHCTransformer(), annotator) - result = captionizer.run(max_rows=1) - for signals, samples, annotations in result.iter_rows(): - plot_row(signals, samples, annotations, save_path="sample_plot.png") - print("Saved sample_plot.png") + args = _parse_args() + + dataset = MHCDataset(min_wear_pct=args.min_wear_pct) + if args.row_index < 0 or args.row_index >= len(dataset): + raise IndexError(f"row-index must be between 0 and {len(dataset) - 1}, got {args.row_index}") + + row = dataset[args.row_index] + signals = MHCTransformer().transform_row(row) + + if args.mode == "detectors": + save_path = args.save_path or f"detector_debug_row_{args.row_index}.png" + plot_detector_view(signals, MHC_CHANNEL_CONFIG, save_path=save_path) + print(f"Saved {save_path}") + else: + annotator = Annotator([StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG)]) + samples, annotations = annotator.annotate(signals) + save_path = args.save_path or "sample_plot.png" + plot_row(signals, samples, annotations, save_path=save_path) + print(f"Saved {save_path}")