From 251f935699643685eba8bd140610f3b23fa71d97 Mon Sep 17 00:00:00 2001 From: Karl Date: Fri, 27 Mar 2026 12:09:34 +0100 Subject: [PATCH 1/7] Add configurable sliding-window trend detector Implement TrendDetector using relative multi-scale windows, linear-fit scoring, coverage checks, and segment merging. Keep existing detector call sites unchanged so current channel configs continue to work with the new defaults. Future tuning can override detector parameters per dataset or per channel in mhc/constants.py. --- detectors/trend.py | 164 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 160 insertions(+), 4 deletions(-) diff --git a/detectors/trend.py b/detectors/trend.py index 370f71d..ea075e2 100644 --- a/detectors/trend.py +++ b/detectors/trend.py @@ -5,18 +5,174 @@ # SPDX-License-Identifier: MIT # from __future__ import annotations +from dataclasses import dataclass import numpy as np from detectors import DetectionResult, StructuralDetector +@dataclass +class _TrendSegment: + direction: str + start_minute: int + end_minute: int + score: float + + class TrendDetector(StructuralDetector): - """Detects statistically significant trends. - """ + """Detects increasing/decreasing trends across multiple relative window sizes.""" - def __init__(self, filter_zeros: bool = False) -> None: + def __init__( + self, + filter_zeros: bool = False, + window_sizes: tuple[int, ...] | None = None, + window_fracs: tuple[float, ...] = (0.1, 0.25, 1.0), + min_window: int = 12, + max_window: int | None = None, + stride_frac: float = 0.25, + min_coverage: float = 0.75, + min_effect: float = 1.25, + min_r2: float = 0.5, + min_span: int = 10, + merge_gap: int = 5, + ) -> None: super().__init__(filter_zeros=filter_zeros) + self.window_sizes = window_sizes + self.window_fracs = tuple(window_fracs) + self.min_window = max(3, min_window) + self.max_window = max_window + self.stride_frac = stride_frac + self.min_coverage = min_coverage + self.min_effect = min_effect + self.min_r2 = min_r2 + self.min_span = min_span + self.merge_gap = merge_gap def _detect(self, series: np.ndarray, indices: np.ndarray) -> list[DetectionResult]: - return [] + window_sizes = self._resolve_window_sizes(len(series)) + segments: list[_TrendSegment] = [] + for window_size in window_sizes: + stride = max(1, int(round(window_size * self.stride_frac))) + for start_idx in self._window_starts(len(series), window_size, stride): + end_idx = start_idx + window_size + segment = self._classify_window(series[start_idx:end_idx], indices[start_idx:end_idx]) + if segment is not None: + segments.append(segment) + + merged = self._merge_segments(segments) + return [ + DetectionResult( + event_type="trend", + start_minute=segment.start_minute, + end_minute=segment.end_minute, + direction=segment.direction, + ) + for segment in merged + if segment.end_minute - segment.start_minute >= self.min_span + ] + + def _resolve_window_sizes(self, n_samples: int) -> list[int]: + sizes = set(self.window_sizes or ()) + for frac in self.window_fracs: + if frac <= 0: + continue + sizes.add(int(round(n_samples * frac))) + + resolved = [] + max_window = self.max_window or n_samples + for size in sorted(sizes): + size = max(self.min_window, min(size, max_window, n_samples)) + if size <= n_samples: + resolved.append(size) + + return sorted(set(resolved)) + + @staticmethod + def _window_starts(n_samples: int, window_size: int, stride: int) -> list[int]: + if window_size >= n_samples: + return [0] + + starts = list(range(0, n_samples - window_size + 1, stride)) + last_start = n_samples - window_size + if starts[-1] != last_start: + starts.append(last_start) + return starts + + def _classify_window( + self, + series: np.ndarray, + indices: np.ndarray, + ) -> _TrendSegment | None: + if len(series) < self.min_window: + return None + + start_minute = int(indices[0]) + end_minute = int(indices[-1]) + span = end_minute - start_minute + if span < self.min_span: + return None + + coverage = len(indices) / (span + 1) + if coverage < self.min_coverage: + return None + + x = indices.astype(float) + y = series.astype(float) + x_centered = x - np.mean(x) + y_centered = y - np.mean(y) + + denom = float(np.dot(x_centered, x_centered)) + if denom <= 0: + return None + + slope = float(np.dot(x_centered, y_centered) / denom) + fitted = np.mean(y) + slope * x_centered + + ss_tot = float(np.dot(y_centered, y_centered)) + if ss_tot <= 1e-12: + return None + + ss_res = float(np.dot(y - fitted, y - fitted)) + r2 = max(0.0, 1.0 - ss_res / ss_tot) + delta = slope * span + scale = float(np.std(y)) + if scale <= 1e-12: + return None + + effect = abs(delta) / scale + if effect < self.min_effect or r2 < self.min_r2: + return None + + direction = "increasing" if slope > 0 else "decreasing" + return _TrendSegment( + direction=direction, + start_minute=start_minute, + end_minute=end_minute, + score=effect * r2, + ) + + def _merge_segments(self, segments: list[_TrendSegment]) -> list[_TrendSegment]: + if not segments: + return [] + + segments.sort(key=lambda s: (s.start_minute, s.end_minute, -s.score)) + merged: list[_TrendSegment] = [] + + for segment in segments: + if not merged: + merged.append(segment) + continue + + prev = merged[-1] + if ( + segment.direction == prev.direction + and segment.start_minute <= prev.end_minute + self.merge_gap + ): + prev.end_minute = max(prev.end_minute, segment.end_minute) + prev.score = max(prev.score, segment.score) + continue + + merged.append(segment) + + return merged From 705be76244fb266e6e9e498bde82d9ceb6f927c4 Mon Sep 17 00:00:00 2001 From: Karl Date: Fri, 27 Mar 2026 15:57:06 +0100 Subject: [PATCH 2/7] requirements + README update --- README.md | 8 ++++++-- requirements.txt | 6 ++++++ 2 files changed, 12 insertions(+), 2 deletions(-) create mode 100644 requirements.txt diff --git a/README.md b/README.md index e978985..30bb1b5 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,11 @@ Dataset-agnostic captioning pipeline for sensor time-series data. ## Setup -Set the dataset path before running: +Install dependencies and set the dataset path before running: + +```bash +python3 -m pip install -r requirements.txt +``` ```bash export MHC_DATASET_DIR="../hf-daily_max-nonwear=50" @@ -13,5 +17,5 @@ export MHC_DATASET_DIR="../hf-daily_max-nonwear=50" ## Usage ```bash -python captionizer.py +python3 captionizer.py ``` diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ad7e82c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +datasets +matplotlib +numpy +pyarrow +scipy +torch From 31867d0ced6d4da345fe887ae72108c8019f40c0 Mon Sep 17 00:00:00 2001 From: Karl Date: Fri, 27 Mar 2026 15:57:17 +0100 Subject: [PATCH 3/7] Implement spike and drop detection with SciPy find_peaks --- detectors/spike.py | 48 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/detectors/spike.py b/detectors/spike.py index d26f129..abe65cc 100644 --- a/detectors/spike.py +++ b/detectors/spike.py @@ -7,6 +7,7 @@ from __future__ import annotations import numpy as np +from scipy.signal import find_peaks from detectors import DetectionResult, StructuralDetector @@ -15,8 +16,51 @@ class SpikeDetector(StructuralDetector): """Detects spikes and drops. """ - def __init__(self, filter_zeros: bool = False) -> None: + def __init__( + self, + filter_zeros: bool = False, + prominence_scale: float = 3.0, + min_prominence: float = 0.0, + min_distance: int = 1, + ) -> None: super().__init__(filter_zeros=filter_zeros) + self.prominence_scale = prominence_scale + self.min_prominence = min_prominence + self.min_distance = max(1, min_distance) def _detect(self, series: np.ndarray, indices: np.ndarray) -> list[DetectionResult]: - return [] + prominence = self._prominence_threshold(series) + if prominence <= 0: + return [] + + results: list[DetectionResult] = [] + seen_minutes: set[int] = set() + + for peak_idx in find_peaks(series, prominence=prominence, distance=self.min_distance)[0]: + minute = int(indices[peak_idx]) + if minute in seen_minutes: + continue + results.append(DetectionResult(event_type="spike", spike_minute=minute)) + seen_minutes.add(minute) + + for peak_idx in find_peaks(-series, prominence=prominence, distance=self.min_distance)[0]: + minute = int(indices[peak_idx]) + if minute in seen_minutes: + continue + results.append(DetectionResult(event_type="drop", spike_minute=minute)) + seen_minutes.add(minute) + + results.sort(key=lambda result: int(result.spike_minute)) + return results + + def _prominence_threshold(self, series: np.ndarray) -> float: + centered = series - np.median(series) + mad = float(np.median(np.abs(centered))) + if mad > 0: + scale = 1.4826 * mad + return max(self.min_prominence, self.prominence_scale * scale) + + spread = float(np.percentile(series, 95) - np.percentile(series, 5)) + if spread <= 1e-12: + return 0.0 + return max(self.min_prominence, 0.5 * spread) From 2cc266d14139fe37ea011bdf3f0a5ee23089595e Mon Sep 17 00:00:00 2001 From: Karl Date: Tue, 31 Mar 2026 19:49:08 +0200 Subject: [PATCH 4/7] Implemented a detector-debug visualization mode in visualizer.py (line 128) so you can inspect which detector fired, on which channel, and at what time in the series. How to use: python visualizer.py --mode detectors --row-index 0 --save-path detector_debug_row_0.png --min-wear-pct 0 --- visualizer.py | 157 +++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 150 insertions(+), 7 deletions(-) diff --git a/visualizer.py b/visualizer.py index 0f6f64d..48bfa40 100644 --- a/visualizer.py +++ b/visualizer.py @@ -6,12 +6,16 @@ # from __future__ import annotations +import argparse import textwrap from math import ceil import matplotlib.pyplot as plt import numpy as np +from matplotlib.lines import Line2D +from matplotlib.patches import Patch +from extractors import ChannelConfig from timef.schema import Annotation, Sample, Signal @@ -121,6 +125,133 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] return fig +def plot_detector_view( + signals: list[Signal], + config: ChannelConfig, + save_path: str | None = None, + n_cols: int = 3, +) -> plt.Figure | None: + channel_names = [s.name for s in signals] + n_channels = len(channel_names) + if n_channels == 0: + raise ValueError("signals must not be empty") + + data = np.stack([s.data for s in signals]) + time_minutes = np.arange(data.shape[1]) + + n_rows = ceil(n_channels / n_cols) + fig, axes_2d = plt.subplots(n_rows, n_cols, figsize=(20, 4.6 * n_rows), sharex="col") + axes = np.atleast_1d(axes_2d).flatten() + + legend_handles: dict[str, object] = { + "signal": Line2D([0], [0], color="steelblue", linewidth=1.0, label="signal"), + "trend:increasing": Patch(facecolor="#4daf4a", edgecolor="none", alpha=0.18, label="trend (increasing)"), + "trend:decreasing": Patch(facecolor="#ff7f00", edgecolor="none", alpha=0.18, label="trend (decreasing)"), + "spike": Line2D([0], [0], marker="^", linestyle="None", color="#2ca02c", markersize=6, label="spike"), + "drop": Line2D([0], [0], marker="v", linestyle="None", color="#d62728", markersize=6, label="drop"), + "gap": Patch(facecolor="#d62728", edgecolor="none", alpha=0.10, label="gap"), + } + + for i, (ax, signal) in enumerate(zip(axes[:n_channels], signals)): + series = data[i] + valid = ~np.isnan(series) + ax.plot(time_minutes[valid], series[valid], linewidth=0.8, color="steelblue") + + detector_summaries: list[str] = [] + for detector in config.detectors.get(signal.name, []): + detector_name = detector.__class__.__name__ + results = detector.detect(signal.data) + if not results: + detector_summaries.append(f"{detector_name}: none") + continue + + rendered = [] + for result in results: + if result.event_type == "trend": + direction = result.direction or "trend" + color = "#4daf4a" if direction == "increasing" else "#ff7f00" + ax.axvspan(result.start_minute, result.end_minute, color=color, alpha=0.18) + rendered.append(f"{direction} {result.start_minute}-{result.end_minute}") + elif result.event_type == "spike": + minute = int(result.spike_minute) + value = signal.data[minute] + if not np.isnan(value): + ax.scatter(minute, value, color="#2ca02c", marker="^", s=28, zorder=3) + rendered.append(f"spike @{minute}") + elif result.event_type == "drop": + minute = int(result.spike_minute) + value = signal.data[minute] + if not np.isnan(value): + ax.scatter(minute, value, color="#d62728", marker="v", s=28, zorder=3) + rendered.append(f"drop @{minute}") + elif result.event_type == "gap": + ax.axvspan(result.start_minute, result.end_minute, color="#d62728", alpha=0.10) + rendered.append(f"gap {result.start_minute}-{result.end_minute}") + else: + rendered.append(result.event_type) + + detector_summaries.append(f"{detector_name}: {', '.join(rendered)}") + + display_name = config.meta.get(signal.name, (signal.name, "", 0))[0] + ax.set_title(display_name, fontsize=8) + ax.set_ylabel(signal.name, fontsize=6, rotation=90, ha="right", va="center") + ax.tick_params(axis="y", labelsize=6) + ax.tick_params(axis="x", labelsize=6, labelbottom=True) + ax.set_xlim(0, len(series)) + + summary_text = "\n".join(textwrap.fill(line, width=56) for line in detector_summaries) + if summary_text: + ax.text( + 0.01, + 0.99, + summary_text, + transform=ax.transAxes, + va="top", + ha="left", + fontsize=5.5, + bbox={"boxstyle": "round,pad=0.25", "facecolor": "white", "alpha": 0.85, "edgecolor": "#cccccc"}, + ) + + for col in range(n_cols): + axes[col].set_xticks(range(0, data.shape[1] + 1, 200)) + + for ax in axes[n_channels:]: + ax.set_visible(False) + + for ax in axes[:n_channels]: + ax.set_xlabel("Time (minutes)") + + meta = signals[0].metadata if signals else {} + title = f"{meta.get('user_id', '')} | {meta.get('date', '')} | detector debug view" + fig.suptitle(title, fontsize=10) + fig.legend( + handles=list(legend_handles.values()), + loc="upper center", + bbox_to_anchor=(0.5, 0.99), + ncol=6, + fontsize=7, + frameon=False, + ) + fig.tight_layout(rect=(0, 0, 1, 0.97)) + + if save_path: + fig.savefig(save_path, dpi=150) + plt.close(fig) + return None + + plt.show() + return fig + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Visualize caption or detector-debug plots for one dataset row.") + parser.add_argument("--mode", choices=("captions", "detectors"), default="captions") + parser.add_argument("--row-index", type=int, default=0, help="Dataset row index to visualize.") + parser.add_argument("--save-path", type=str, default=None, help="Optional output image path.") + parser.add_argument("--min-wear-pct", type=float, default=90.0, help="Minimum wear percentage filter.") + return parser.parse_args() + + if __name__ == "__main__": from mhc.dataset import MHCDataset from mhc.transformer import MHCTransformer @@ -130,10 +261,22 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] from annotator import Annotator from captionizer import Captionizer - dataset = MHCDataset(min_wear_pct=90.0) - annotator = Annotator([StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG)]) - captionizer = Captionizer(dataset, MHCTransformer(), annotator) - result = captionizer.run(max_rows=1) - for signals, samples, annotations in result.iter_rows(): - plot_row(signals, samples, annotations, save_path="sample_plot.png") - print("Saved sample_plot.png") + args = _parse_args() + + dataset = MHCDataset(min_wear_pct=args.min_wear_pct) + if args.row_index < 0 or args.row_index >= len(dataset): + raise IndexError(f"row-index must be between 0 and {len(dataset) - 1}, got {args.row_index}") + + row = dataset[args.row_index] + signals = MHCTransformer().transform_row(row) + + if args.mode == "detectors": + save_path = args.save_path or f"detector_debug_row_{args.row_index}.png" + plot_detector_view(signals, MHC_CHANNEL_CONFIG, save_path=save_path) + print(f"Saved {save_path}") + else: + annotator = Annotator([StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG)]) + samples, annotations = annotator.annotate(signals) + save_path = args.save_path or "sample_plot.png" + plot_row(signals, samples, annotations, save_path=save_path) + print(f"Saved {save_path}") From d703c93ddb55463c6cc765373cb44e1b462b08ae Mon Sep 17 00:00:00 2001 From: Karl Date: Tue, 31 Mar 2026 20:13:42 +0200 Subject: [PATCH 5/7] Add interactive detector explorer and document usage Introduce a Matplotlib-based explorer for browsing dataset rows, switching signals, inspecting detector overlays, and reviewing stats/captions in a side panel. Document how to launch and use the explorer in the README. --- README.md | 20 ++ explorer.py | 645 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 665 insertions(+) create mode 100644 explorer.py diff --git a/README.md b/README.md index 30bb1b5..0ac5359 100644 --- a/README.md +++ b/README.md @@ -19,3 +19,23 @@ export MHC_DATASET_DIR="../hf-daily_max-nonwear=50" ```bash python3 captionizer.py ``` + +## Explorer + +Use the interactive explorer to inspect one row at a time, switch signals, and see which detector events fired where on the time series. + +Start it with: + +```bash +python3 explorer.py +``` + +Useful controls: + +- Use the bottom row slider or `<` / `>` buttons to move between dataset rows. +- Click a signal in the right-hand signal list or in the channel overview heatmap to switch channels. +- Use the Matplotlib zoom and pan tools on the main plot to inspect parts of the signal in detail. +- Click `reset` or press `home` to reset the zoom. +- Use the overlay buttons to toggle `trend`, `spike`, `drop`, `gap`, and `nonwear` overlays. +- Use the `stats`, `events`, `captions`, and `help` tabs in the details panel to switch what metadata is shown. +- Scroll inside the details panel with the mouse wheel or the `^` / `v` buttons. diff --git a/explorer.py b/explorer.py new file mode 100644 index 0000000..6728eef --- /dev/null +++ b/explorer.py @@ -0,0 +1,645 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import argparse +import textwrap +from functools import lru_cache + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.patches import Rectangle +from matplotlib.widgets import Button, Slider + +from annotator import Annotator +from extractors.statistical import StatisticalExtractor +from extractors.structural import StructuralExtractor +from mhc.constants import MHC_CHANNEL_CONFIG +from mhc.dataset import MHCDataset +from mhc.transformer import MHCTransformer +from timef.schema import Annotation, Sample, Signal + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Interactive explorer for MHC rows, signals, and detector outputs.") + parser.add_argument("--row-index", type=int, default=0, help="Initial dataset row index.") + parser.add_argument("--signal-index", type=int, default=0, help="Initial signal index.") + parser.add_argument("--min-wear-pct", type=float, default=0.0, help="Minimum wear percentage filter.") + parser.add_argument("--save-path", type=str, default=None, help="Optional snapshot path. Saves current view and exits.") + return parser.parse_args() + + +def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]]: + regions = [] + in_region = False + for i, val in enumerate(np.isnan(arr)): + if val and not in_region: + start = i + in_region = True + elif not val and in_region: + if i - start >= min_length: + regions.append((start, i - 1)) + in_region = False + if in_region and len(arr) - start >= min_length: + regions.append((start, len(arr) - 1)) + return regions + + +def _format_detector_event(detector_name: str, result) -> str: + if result.event_type == "trend": + return f"{detector_name}: {result.direction} {result.start_minute}-{result.end_minute}" + if result.event_type in {"spike", "drop"}: + return f"{detector_name}: {result.event_type} @{result.spike_minute}" + if result.event_type == "gap": + return f"{detector_name}: gap {result.start_minute}-{result.end_minute}" + return f"{detector_name}: {result.event_type}" + + +def _display_name(signal_name: str) -> str: + return MHC_CHANNEL_CONFIG.meta.get(signal_name, (signal_name, "", 0))[0] + + +def _truncate(text: str, max_len: int = 34) -> str: + if len(text) <= max_len: + return text + return text[: max_len - 1] + "..." + + +class SensorExplorer: + def __init__( + self, + dataset: MHCDataset, + row_index: int = 0, + signal_index: int = 0, + ) -> None: + self.dataset = dataset + self.transformer = MHCTransformer() + self.annotator = Annotator([ + StatisticalExtractor(MHC_CHANNEL_CONFIG), + StructuralExtractor(MHC_CHANNEL_CONFIG), + ]) + + self.row_index = min(max(0, row_index), len(self.dataset) - 1) + self.signal_index = min(max(0, signal_index), len(MHC_CHANNEL_CONFIG.names) - 1) + + self.show_trends = True + self.show_spikes = True + self.show_drops = True + self.show_gaps = True + self.show_nonwear = True + self.detail_mode = "events" + self.details_scroll = 0 + self.details_page_lines = 12 + + self._ignore_widget_events = False + + self.fig = plt.figure(figsize=(17, 10)) + self.ax_main = self.fig.add_axes([0.06, 0.34, 0.66, 0.58]) + self.ax_overview = self.fig.add_axes([0.06, 0.16, 0.66, 0.12], sharex=self.ax_main) + self.ax_summary = self.fig.add_axes([0.76, 0.76, 0.22, 0.16]) + self.ax_signal_list = self.fig.add_axes([0.76, 0.44, 0.22, 0.28]) + self.ax_details = self.fig.add_axes([0.76, 0.16, 0.22, 0.22]) + for ax in (self.ax_summary, self.ax_signal_list, self.ax_details): + ax.axis("off") + + self.reset_zoom_ax = self.fig.add_axes([0.06, 0.049, 0.055, 0.036]) + self.row_slider_ax = self.fig.add_axes([0.13, 0.060, 0.54, 0.024]) + self.prev_row_ax = self.fig.add_axes([0.69, 0.051, 0.035, 0.036]) + self.next_row_ax = self.fig.add_axes([0.73, 0.051, 0.035, 0.036]) + detail_tab_specs = [ + ("stats", 0.76), + ("events", 0.815), + ("captions", 0.870), + ("help", 0.925), + ] + self.detail_tab_buttons: dict[str, Button] = {} + for label, x0 in detail_tab_specs: + ax = self.fig.add_axes([x0, 0.390, 0.05, 0.028]) + self.detail_tab_buttons[label] = Button(ax, label) + self.detail_up_ax = self.fig.add_axes([0.935, 0.125, 0.022, 0.028]) + self.detail_down_ax = self.fig.add_axes([0.958, 0.125, 0.022, 0.028]) + overlay_labels = ["trend", "spike", "drop", "gap", "nonwear"] + self.overlay_buttons: dict[str, Button] = {} + start_x = 0.79 + button_width = 0.037 + gap = 0.004 + for i, label in enumerate(overlay_labels): + ax = self.fig.add_axes([start_x + i * (button_width + gap), 0.051, button_width, 0.036]) + self.overlay_buttons[label] = Button(ax, label) + + self.row_slider = Slider( + self.row_slider_ax, + "Row", + 0, + len(self.dataset) - 1, + valinit=self.row_index, + valstep=1, + valfmt="%d", + ) + self.reset_zoom_button = Button(self.reset_zoom_ax, "reset") + self.prev_row_button = Button(self.prev_row_ax, "<") + self.next_row_button = Button(self.next_row_ax, ">") + self.detail_up_button = Button(self.detail_up_ax, "^") + self.detail_down_button = Button(self.detail_down_ax, "v") + self._style_widgets() + self._sync_widgets() + + self.row_slider.on_changed(self._on_row_slider) + self.reset_zoom_button.on_clicked(lambda _: self.render(reset_zoom=True)) + self.prev_row_button.on_clicked(lambda _: self._set_row(self.row_index - 1)) + self.next_row_button.on_clicked(lambda _: self._set_row(self.row_index + 1)) + self.detail_up_button.on_clicked(lambda _: self._scroll_details(-1)) + self.detail_down_button.on_clicked(lambda _: self._scroll_details(1)) + for label, button in self.detail_tab_buttons.items(): + button.on_clicked(lambda _, name=label: self._set_detail_mode(name)) + for label, button in self.overlay_buttons.items(): + button.on_clicked(lambda _, name=label: self._on_toggle(name)) + self.fig.canvas.mpl_connect("button_press_event", self._on_click) + self.fig.canvas.mpl_connect("key_press_event", self._on_key_press) + self.fig.canvas.mpl_connect("scroll_event", self._on_scroll) + + self.render(reset_zoom=True) + + @lru_cache(maxsize=12) + def _load_row_bundle(self, row_index: int) -> tuple[list[Signal], list[Sample], list[Annotation]]: + row = self.dataset[row_index] + signals = self.transformer.transform_row(row) + samples, annotations = self.annotator.annotate(signals) + return signals, samples, annotations + + def _set_row(self, row_index: int) -> None: + row_index = min(max(0, int(row_index)), len(self.dataset) - 1) + if row_index == self.row_index: + return + self.row_index = row_index + self._sync_widgets() + self.render(reset_zoom=True) + + def _set_signal(self, signal_index: int) -> None: + signal_index = min(max(0, int(signal_index)), len(MHC_CHANNEL_CONFIG.names) - 1) + if signal_index == self.signal_index: + return + self.signal_index = signal_index + self._sync_widgets() + self.render(reset_zoom=True) + + def _sync_widgets(self) -> None: + self._ignore_widget_events = True + self.row_slider.set_val(self.row_index) + self.row_slider_ax.set_title(f"Row {self.row_index} / {len(self.dataset) - 1}", loc="left", fontsize=10, pad=2) + self._ignore_widget_events = False + + def _on_row_slider(self, value: float) -> None: + if not self._ignore_widget_events: + self._set_row(int(value)) + + def _on_toggle(self, label: str) -> None: + if label == "trend": + self.show_trends = not self.show_trends + elif label == "spike": + self.show_spikes = not self.show_spikes + elif label == "drop": + self.show_drops = not self.show_drops + elif label == "gap": + self.show_gaps = not self.show_gaps + elif label == "nonwear": + self.show_nonwear = not self.show_nonwear + self._update_overlay_button_styles() + self.render(reset_zoom=False) + + def _set_detail_mode(self, mode: str) -> None: + if mode != self.detail_mode: + self.detail_mode = mode + self.details_scroll = 0 + self._update_detail_tab_styles() + self.render(reset_zoom=False) + + def _on_click(self, event) -> None: + if event.inaxes is self.ax_overview and event.ydata is not None: + self._set_signal(int(round(event.ydata))) + elif event.inaxes is self.ax_signal_list and event.ydata is not None: + self._set_signal(int(round(event.ydata))) + + def _on_key_press(self, event) -> None: + if event.key == "up": + self._set_row(self.row_index - 1) + elif event.key == "down": + self._set_row(self.row_index + 1) + elif event.key == "left": + self._set_signal(self.signal_index - 1) + elif event.key == "right": + self._set_signal(self.signal_index + 1) + elif event.key == "home": + self.render(reset_zoom=True) + elif event.key == "pageup": + self._scroll_details(-1) + elif event.key == "pagedown": + self._scroll_details(1) + + def _on_scroll(self, event) -> None: + if event.inaxes is not self.ax_details: + return + direction = -1 if event.button == "up" else 1 + self._scroll_details(direction) + + def _scroll_details(self, direction: int) -> None: + self.details_scroll += direction + self.render(reset_zoom=False) + + def _detector_events(self, signal: Signal) -> list[tuple[str, object]]: + events = [] + for detector in MHC_CHANNEL_CONFIG.detectors.get(signal.name, []): + detector_name = detector.__class__.__name__ + for result in detector.detect(signal.data): + events.append((detector_name, result)) + return events + + @staticmethod + def _captions_for_signal(signal_id: str, samples: list[Sample], annotations: list[Annotation]) -> dict[str, list[str]]: + sample_ids = { + sample.id + for sample in samples + if any(ref.signal_id == signal_id for ref in sample.signals) + } + grouped: dict[str, list[str]] = {} + for annotation in annotations: + if annotation.answer is None: + continue + if any(ref.sample_id in sample_ids for ref in annotation.samples): + grouped.setdefault(annotation.spec_id, []).append(annotation.answer) + return grouped + + @staticmethod + def _overview_matrix(signals: list[Signal]) -> np.ma.MaskedArray: + rows = [] + for signal in signals: + arr = np.asarray(signal.data, dtype=float) + normalized = np.full_like(arr, np.nan, dtype=float) + valid = ~np.isnan(arr) + if valid.any(): + values = arr[valid] + lo = float(np.nanpercentile(values, 5)) + hi = float(np.nanpercentile(values, 95)) + if hi - lo <= 1e-12: + normalized[valid] = 0.5 + else: + normalized[valid] = np.clip((values - lo) / (hi - lo), 0.0, 1.0) + rows.append(normalized) + return np.ma.masked_invalid(np.vstack(rows)) + + def _style_widgets(self) -> None: + self.row_slider.label.set_visible(False) + self.row_slider.valtext.set_visible(False) + + for button in ( + self.reset_zoom_button, + self.prev_row_button, + self.next_row_button, + ): + button.label.set_fontsize(9) + + for button in self.detail_tab_buttons.values(): + button.label.set_fontsize(7.5) + self.detail_up_button.label.set_fontsize(8) + self.detail_down_button.label.set_fontsize(8) + for button in self.overlay_buttons.values(): + button.label.set_fontsize(7.5) + self._sync_widgets() + self._update_detail_tab_styles() + self._update_overlay_button_styles() + + def _overlay_state(self, label: str) -> bool: + return { + "trend": self.show_trends, + "spike": self.show_spikes, + "drop": self.show_drops, + "gap": self.show_gaps, + "nonwear": self.show_nonwear, + }[label] + + def _update_overlay_button_styles(self) -> None: + for label, button in self.overlay_buttons.items(): + enabled = self._overlay_state(label) + face = "#1f4f95" if enabled else "#f7f7f7" + edge = "#f4d35e" if enabled else "#b8c0cc" + text = "white" if enabled else "#6b7280" + button.ax.set_facecolor(face) + button.ax.patch.set_edgecolor(edge) + button.ax.patch.set_linewidth(2.4 if enabled else 1.2) + for spine in button.ax.spines.values(): + spine.set_edgecolor(edge) + spine.set_linewidth(2.4 if enabled else 1.2) + button.hovercolor = "#3465a4" if enabled else "#ebeff4" + button.label.set_color(text) + button.label.set_fontweight("bold" if enabled else "normal") + + @staticmethod + def _build_detail_lines(title: str, lines: list[str], width: int) -> list[str]: + rendered = [title] + if not lines: + rendered.append(" none") + rendered.append("") + return rendered + for line in lines: + wrapped = textwrap.wrap(line, width=width) or [""] + rendered.extend(f" {part}" for part in wrapped) + rendered.append("") + return rendered + + def _update_detail_tab_styles(self) -> None: + for label, button in self.detail_tab_buttons.items(): + active = label == self.detail_mode + face = "#204a87" if active else "#f4f5f7" + edge = "#f4d35e" if active else "#c7cdd6" + text = "white" if active else "#5f6b7a" + button.ax.set_facecolor(face) + button.ax.patch.set_edgecolor(edge) + button.ax.patch.set_linewidth(2.2 if active else 1.1) + for spine in button.ax.spines.values(): + spine.set_edgecolor(edge) + spine.set_linewidth(2.2 if active else 1.1) + button.hovercolor = "#3465a4" if active else "#eaedf2" + button.label.set_color(text) + button.label.set_fontweight("bold" if active else "normal") + + def render(self, reset_zoom: bool = False) -> None: + signals, samples, annotations = self._load_row_bundle(self.row_index) + signal = signals[self.signal_index] + detector_events = self._detector_events(signal) + captions = self._captions_for_signal(signal.id, samples, annotations) + display_name, unit, decimals = MHC_CHANNEL_CONFIG.meta.get(signal.name, (signal.name, "", 2)) + + x = np.arange(len(signal.data)) + y = np.asarray(signal.data, dtype=float) + valid = ~np.isnan(y) + + old_xlim = self.ax_main.get_xlim() + old_ylim = self.ax_main.get_ylim() + + self.ax_main.clear() + self.ax_overview.clear() + self.ax_summary.clear() + self.ax_signal_list.clear() + self.ax_details.clear() + for ax in (self.ax_summary, self.ax_signal_list, self.ax_details): + ax.axis("off") + + self.ax_main.plot(x[valid], y[valid], color="steelblue", linewidth=1.0, label="signal") + if self.show_nonwear: + for start, end in _nan_regions(y): + self.ax_main.axvspan(start, end, color="#d62728", alpha=0.08) + + for detector_name, result in detector_events: + if result.event_type == "trend" and self.show_trends: + color = "#4daf4a" if result.direction == "increasing" else "#ff7f00" + label = f"{detector_name} ({result.direction})" + self.ax_main.axvspan(result.start_minute, result.end_minute, color=color, alpha=0.18, label=label) + elif result.event_type == "spike" and self.show_spikes: + minute = int(result.spike_minute) + if minute < len(y) and not np.isnan(y[minute]): + self.ax_main.scatter(minute, y[minute], color="#2ca02c", marker="^", s=38, zorder=4, label=detector_name) + self.ax_main.annotate(str(minute), (minute, y[minute]), xytext=(0, 8), textcoords="offset points", ha="center", fontsize=7) + elif result.event_type == "drop" and self.show_drops: + minute = int(result.spike_minute) + if minute < len(y) and not np.isnan(y[minute]): + self.ax_main.scatter(minute, y[minute], color="#d62728", marker="v", s=38, zorder=4, label=detector_name) + self.ax_main.annotate(str(minute), (minute, y[minute]), xytext=(0, -12), textcoords="offset points", ha="center", fontsize=7) + elif result.event_type == "gap" and self.show_gaps: + self.ax_main.axvspan(result.start_minute, result.end_minute, color="#d62728", alpha=0.12, label=detector_name) + + self.ax_main.set_title(f"Row {self.row_index} | {display_name}") + self.ax_main.set_ylabel(f"{display_name}\n({unit or 'value'})") + self.ax_main.set_xlabel("Minute of day") + self.ax_main.grid(alpha=0.2) + self.ax_main.margins(x=0) + + handles, labels = self.ax_main.get_legend_handles_labels() + deduped: dict[str, object] = {} + for handle, label in zip(handles, labels): + deduped.setdefault(label, handle) + if deduped: + self.ax_main.legend(deduped.values(), deduped.keys(), loc="upper right", fontsize=8) + + matrix = self._overview_matrix(signals) + cmap = plt.get_cmap("viridis").copy() + cmap.set_bad(color="#f1f1f1") + self.ax_overview.imshow(matrix, aspect="auto", interpolation="nearest", cmap=cmap, origin="upper") + self.ax_overview.axhspan( + self.signal_index - 0.5, + self.signal_index + 0.5, + facecolor="#f4d35e", + alpha=0.20, + edgecolor="#f4d35e", + linewidth=0, + ) + self.ax_overview.axhline(self.signal_index, color="white", linewidth=2) + self.ax_overview.set_title("Channel Overview", fontsize=10, loc="left", pad=4) + self.ax_overview.set_xlabel("Minute of day") + self.ax_overview.set_yticks([]) + self.ax_overview.tick_params(axis="x", labelsize=8) + self.ax_overview.text( + 1.0, + 1.02, + "click heatmap or signal list to change channel", + transform=self.ax_overview.transAxes, + ha="right", + va="bottom", + fontsize=8, + color="#555555", + ) + + valid_minutes = int(np.sum(valid)) + active_channels = sum(bool(s.metadata.get("has_any_data", False)) for s in signals) + total_nonwear = signal.metadata.get("total_nonwear_minutes") + wear_pct = None + if total_nonwear is not None: + wear_pct = (1440.0 - float(total_nonwear)) / 1440.0 * 100.0 + stats_text = "n/a" + if valid.any(): + values = y[valid] + stats_text = ( + f"mean={np.mean(values):.{decimals}f}\n" + f"std={np.std(values):.{decimals}f}\n" + f"min={np.min(values):.{decimals}f}\n" + f"max={np.max(values):.{decimals}f}" + ) + + detector_lines = [_format_detector_event(name, result) for name, result in detector_events] + if not detector_lines: + detector_lines = ["No detector events on this signal."] + + caption_lines = [] + for spec_id, values in captions.items(): + label = spec_id.split(":")[-1] + for value in values[:3]: + caption_lines.append(f"{label}: {value}") + if not caption_lines: + caption_lines = ["No captions for this signal."] + + self.ax_summary.set_xlim(0, 1) + self.ax_summary.set_ylim(0, 1) + self.ax_summary.add_patch(Rectangle((0.0, 0.72), 1.0, 0.28, facecolor="#204a87", edgecolor="none")) + self.ax_summary.text(0.03, 0.94, "Selected Signal", color="white", fontsize=9, va="top", weight="bold") + self.ax_summary.text(0.03, 0.80, _truncate(display_name, 26), color="white", fontsize=13, va="center", weight="bold") + self.ax_summary.text(0.03, 0.66, _truncate(signal.name, 34), color="#35506b", fontsize=8) + summary_lines = [ + f"row {self.row_index} signal {self.signal_index}/{len(signals) - 1}", + f"user {_truncate(str(signal.metadata.get('user_id', 'n/a')), 24)}", + f"date {signal.metadata.get('date', 'n/a')}", + f"wear {wear_pct:.1f}% nonwear {float(total_nonwear):.0f}m" if wear_pct is not None and total_nonwear is not None else "wear n/a", + f"active {active_channels}/{len(signals)} valid {valid_minutes}/1440", + f"has_data {signal.metadata.get('has_any_data', 'n/a')} nonzero_or_nan {signal.metadata.get('minutes_nonzero_or_nan', 'n/a')}", + ] + self.ax_summary.text( + 0.03, + 0.58, + "\n".join(summary_lines), + va="top", + ha="left", + fontsize=8.1, + family="monospace", + color="#222222", + ) + + self.ax_signal_list.set_xlim(0, 1) + self.ax_signal_list.set_ylim(len(signals), 0) + self.ax_signal_list.text( + 0.0, + 1.02, + "Signals", + transform=self.ax_signal_list.transAxes, + fontsize=10, + weight="bold", + color="#333333", + va="bottom", + ) + for idx, listed_signal in enumerate(signals): + y0 = idx + is_selected = idx == self.signal_index + is_active = bool(listed_signal.metadata.get("has_any_data", False)) + face = "#204a87" if is_selected else ("#f7f7f7" if idx % 2 == 0 else "#eeeeee") + edge = "#10253f" if is_selected else "#d0d0d0" + text_color = "white" if is_selected else ("#222222" if is_active else "#888888") + self.ax_signal_list.add_patch(Rectangle((0.0, y0), 1.0, 0.92, facecolor=face, edgecolor=edge, linewidth=0.8)) + self.ax_signal_list.text( + 0.03, + y0 + 0.46, + f"{idx:02d}", + va="center", + ha="left", + fontsize=8, + family="monospace", + color="#f4d35e" if is_selected else "#666666", + weight="bold", + ) + self.ax_signal_list.text( + 0.14, + y0 + 0.46, + _truncate(_display_name(listed_signal.name), 25), + va="center", + ha="left", + fontsize=8.8, + color=text_color, + weight="bold" if is_selected else "normal", + ) + self.ax_signal_list.text( + 0.99, + 1.01, + "click to select", + transform=self.ax_signal_list.transAxes, + ha="right", + va="bottom", + fontsize=7.5, + color="#666666", + ) + + stats_lines = stats_text.splitlines() if stats_text != "n/a" else ["n/a"] + if self.detail_mode == "stats": + detail_title = "Stats" + detail_lines = self._build_detail_lines("Stats", stats_lines, width=30) + elif self.detail_mode == "captions": + detail_title = "Captions" + detail_lines = self._build_detail_lines("Captions", caption_lines, width=30) + elif self.detail_mode == "help": + detail_title = "Help" + detail_lines = self._build_detail_lines( + "Help", + [ + "click a signal on the right or the overview heatmap to change channel", + "up/down changes row", + "left/right changes signal", + "mouse wheel over details scrolls", + "PageUp/PageDown also scroll details", + "overlay buttons toggle detector layers", + ], + width=30, + ) + else: + detail_title = "Detector Events" + detail_lines = self._build_detail_lines("Detector Events", detector_lines, width=30) + detail_lines = detail_lines[:-1] if detail_lines and detail_lines[-1] == "" else detail_lines + max_scroll = max(0, len(detail_lines) - self.details_page_lines) + self.details_scroll = min(max(self.details_scroll, 0), max_scroll) + visible_lines = detail_lines[self.details_scroll:self.details_scroll + self.details_page_lines] + start_line = self.details_scroll + 1 if detail_lines else 0 + end_line = self.details_scroll + len(visible_lines) + self.ax_details.set_xlim(0, 1) + self.ax_details.set_ylim(0, 1) + self.ax_details.add_patch(Rectangle((0, 0), 1, 1, facecolor="#fbfbfb", edgecolor="#dddddd")) + track_x = 0.975 + track_y = 0.06 + track_h = 0.80 + self.ax_details.add_patch(Rectangle((track_x, track_y), 0.012, track_h, facecolor="#ececec", edgecolor="#d0d0d0")) + if detail_lines: + visible_frac = min(1.0, self.details_page_lines / len(detail_lines)) + thumb_h = max(0.10, track_h * visible_frac) + available_h = track_h - thumb_h + scroll_frac = 0.0 if max_scroll == 0 else self.details_scroll / max_scroll + thumb_y = track_y + available_h * (1.0 - scroll_frac) + self.ax_details.add_patch(Rectangle((track_x, thumb_y), 0.012, thumb_h, facecolor="#204a87", edgecolor="none")) + self.ax_details.text( + 0.03, + 0.90, + "\n".join(visible_lines), + va="top", + ha="left", + fontsize=7.8, + family="monospace", + color="#222222", + ) + + if reset_zoom: + self.ax_main.set_xlim(0, len(y) - 1) + else: + self.ax_main.set_xlim(old_xlim) + self.ax_main.set_ylim(old_ylim) + + self.fig.suptitle("SensorTSLM Interactive Explorer", fontsize=14) + self.fig.canvas.draw_idle() + + def save(self, path: str) -> None: + self.fig.savefig(path, dpi=150) + + +def main() -> None: + args = _parse_args() + dataset = MHCDataset(min_wear_pct=args.min_wear_pct) + explorer = SensorExplorer( + dataset=dataset, + row_index=args.row_index, + signal_index=args.signal_index, + ) + if args.save_path: + explorer.save(args.save_path) + print(f"Saved {args.save_path}") + plt.close(explorer.fig) + return + plt.show() + + +if __name__ == "__main__": + main() From 35e9394547bdfb7b81dd6a96f082cae1e53f53eb Mon Sep 17 00:00:00 2001 From: Karl Date: Tue, 31 Mar 2026 21:41:30 +0200 Subject: [PATCH 6/7] added feature to find hits for certain detectors --- visualizer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/visualizer.py b/visualizer.py index 48bfa40..bb2fea1 100644 --- a/visualizer.py +++ b/visualizer.py @@ -147,6 +147,7 @@ def plot_detector_view( "signal": Line2D([0], [0], color="steelblue", linewidth=1.0, label="signal"), "trend:increasing": Patch(facecolor="#4daf4a", edgecolor="none", alpha=0.18, label="trend (increasing)"), "trend:decreasing": Patch(facecolor="#ff7f00", edgecolor="none", alpha=0.18, label="trend (decreasing)"), + "plateau": Patch(facecolor="#377eb8", edgecolor="none", alpha=0.16, label="plateau"), "spike": Line2D([0], [0], marker="^", linestyle="None", color="#2ca02c", markersize=6, label="spike"), "drop": Line2D([0], [0], marker="v", linestyle="None", color="#d62728", markersize=6, label="drop"), "gap": Patch(facecolor="#d62728", edgecolor="none", alpha=0.10, label="gap"), @@ -172,6 +173,9 @@ def plot_detector_view( color = "#4daf4a" if direction == "increasing" else "#ff7f00" ax.axvspan(result.start_minute, result.end_minute, color=color, alpha=0.18) rendered.append(f"{direction} {result.start_minute}-{result.end_minute}") + elif result.event_type == "plateau": + ax.axvspan(result.start_minute, result.end_minute, color="#377eb8", alpha=0.16) + rendered.append(f"plateau {result.start_minute}-{result.end_minute}") elif result.event_type == "spike": minute = int(result.spike_minute) value = signal.data[minute] From ce2818d185f27ff8da5d4e6e3075eb2ede9a7b34 Mon Sep 17 00:00:00 2001 From: Karl Date: Tue, 31 Mar 2026 21:59:43 +0200 Subject: [PATCH 7/7] hit miss update --- explorer.py | 169 ++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 150 insertions(+), 19 deletions(-) diff --git a/explorer.py b/explorer.py index 6728eef..5072448 100644 --- a/explorer.py +++ b/explorer.py @@ -52,6 +52,8 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] def _format_detector_event(detector_name: str, result) -> str: if result.event_type == "trend": return f"{detector_name}: {result.direction} {result.start_minute}-{result.end_minute}" + if result.event_type == "plateau": + return f"{detector_name}: plateau {result.start_minute}-{result.end_minute}" if result.event_type in {"spike", "drop"}: return f"{detector_name}: {result.event_type} @{result.spike_minute}" if result.event_type == "gap": @@ -87,6 +89,7 @@ def __init__( self.signal_index = min(max(0, signal_index), len(MHC_CHANNEL_CONFIG.names) - 1) self.show_trends = True + self.show_plateaus = True self.show_spikes = True self.show_drops = True self.show_gaps = True @@ -94,22 +97,26 @@ def __init__( self.detail_mode = "events" self.details_scroll = 0 self.details_page_lines = 12 + self.hit_target_names = self._available_detector_names() + self.hit_target = self.hit_target_names[0] if self.hit_target_names else None + self.search_status = "Use hit< / hit> to jump to the selected detector." self._ignore_widget_events = False self.fig = plt.figure(figsize=(17, 10)) - self.ax_main = self.fig.add_axes([0.06, 0.34, 0.66, 0.58]) - self.ax_overview = self.fig.add_axes([0.06, 0.16, 0.66, 0.12], sharex=self.ax_main) - self.ax_summary = self.fig.add_axes([0.76, 0.76, 0.22, 0.16]) - self.ax_signal_list = self.fig.add_axes([0.76, 0.44, 0.22, 0.28]) - self.ax_details = self.fig.add_axes([0.76, 0.16, 0.22, 0.22]) - for ax in (self.ax_summary, self.ax_signal_list, self.ax_details): + self.ax_main = self.fig.add_axes([0.04, 0.34, 0.70, 0.62]) + self.ax_overview = self.fig.add_axes([0.04, 0.16, 0.70, 0.12], sharex=self.ax_main) + self.ax_summary = self.fig.add_axes([0.77, 0.78, 0.21, 0.18]) + self.ax_hit_target = self.fig.add_axes([0.77, 0.715, 0.21, 0.045]) + self.ax_signal_list = self.fig.add_axes([0.77, 0.43, 0.21, 0.24]) + self.ax_details = self.fig.add_axes([0.76, 0.12, 0.22, 0.22]) + for ax in (self.ax_summary, self.ax_hit_target, self.ax_signal_list, self.ax_details): ax.axis("off") - self.reset_zoom_ax = self.fig.add_axes([0.06, 0.049, 0.055, 0.036]) - self.row_slider_ax = self.fig.add_axes([0.13, 0.060, 0.54, 0.024]) - self.prev_row_ax = self.fig.add_axes([0.69, 0.051, 0.035, 0.036]) - self.next_row_ax = self.fig.add_axes([0.73, 0.051, 0.035, 0.036]) + self.reset_zoom_ax = self.fig.add_axes([0.04, 0.049, 0.055, 0.036]) + self.row_slider_ax = self.fig.add_axes([0.11, 0.060, 0.57, 0.024]) + self.prev_row_ax = self.fig.add_axes([0.70, 0.051, 0.035, 0.036]) + self.next_row_ax = self.fig.add_axes([0.74, 0.051, 0.035, 0.036]) detail_tab_specs = [ ("stats", 0.76), ("events", 0.815), @@ -118,14 +125,26 @@ def __init__( ] self.detail_tab_buttons: dict[str, Button] = {} for label, x0 in detail_tab_specs: - ax = self.fig.add_axes([x0, 0.390, 0.05, 0.028]) + ax = self.fig.add_axes([x0 + 0.01, 0.375, 0.047, 0.028]) self.detail_tab_buttons[label] = Button(ax, label) - self.detail_up_ax = self.fig.add_axes([0.935, 0.125, 0.022, 0.028]) - self.detail_down_ax = self.fig.add_axes([0.958, 0.125, 0.022, 0.028]) - overlay_labels = ["trend", "spike", "drop", "gap", "nonwear"] + self.detail_up_ax = self.fig.add_axes([0.935, 0.085, 0.022, 0.028]) + self.detail_down_ax = self.fig.add_axes([0.958, 0.085, 0.022, 0.028]) + self.prev_hit_ax = self.fig.add_axes([0.76, 0.698, 0.035, 0.038]) + self.next_hit_ax = self.fig.add_axes([0.945, 0.698, 0.035, 0.038]) + self.hit_target_buttons: dict[str, Button] = {} + if self.hit_target_names: + n_targets = len(self.hit_target_names) + total_gap = 0.004 * max(0, n_targets - 1) + available_width = 0.145 + width = (available_width - total_gap) / max(1, n_targets) + for i, detector_name in enumerate(self.hit_target_names): + x0 = 0.805 + i * (width + 0.004) + ax = self.fig.add_axes([x0, 0.698, width, 0.038]) + self.hit_target_buttons[detector_name] = Button(ax, detector_name.replace("Detector", "").lower()) + overlay_labels = ["trend", "plateau", "spike", "drop", "gap", "nonwear"] self.overlay_buttons: dict[str, Button] = {} start_x = 0.79 - button_width = 0.037 + button_width = 0.027 gap = 0.004 for i, label in enumerate(overlay_labels): ax = self.fig.add_axes([start_x + i * (button_width + gap), 0.051, button_width, 0.036]) @@ -143,6 +162,8 @@ def __init__( self.reset_zoom_button = Button(self.reset_zoom_ax, "reset") self.prev_row_button = Button(self.prev_row_ax, "<") self.next_row_button = Button(self.next_row_ax, ">") + self.prev_hit_button = Button(self.prev_hit_ax, "hit<") + self.next_hit_button = Button(self.next_hit_ax, "hit>") self.detail_up_button = Button(self.detail_up_ax, "^") self.detail_down_button = Button(self.detail_down_ax, "v") self._style_widgets() @@ -152,8 +173,12 @@ def __init__( self.reset_zoom_button.on_clicked(lambda _: self.render(reset_zoom=True)) self.prev_row_button.on_clicked(lambda _: self._set_row(self.row_index - 1)) self.next_row_button.on_clicked(lambda _: self._set_row(self.row_index + 1)) + self.prev_hit_button.on_clicked(lambda _: self._jump_to_hit(-1)) + self.next_hit_button.on_clicked(lambda _: self._jump_to_hit(1)) self.detail_up_button.on_clicked(lambda _: self._scroll_details(-1)) self.detail_down_button.on_clicked(lambda _: self._scroll_details(1)) + for detector_name, button in self.hit_target_buttons.items(): + button.on_clicked(lambda _, name=detector_name: self._set_hit_target(name)) for label, button in self.detail_tab_buttons.items(): button.on_clicked(lambda _, name=label: self._set_detail_mode(name)) for label, button in self.overlay_buttons.items(): @@ -171,11 +196,21 @@ def _load_row_bundle(self, row_index: int) -> tuple[list[Signal], list[Sample], samples, annotations = self.annotator.annotate(signals) return signals, samples, annotations + @staticmethod + def _available_detector_names() -> list[str]: + names = { + detector.__class__.__name__ + for detectors in MHC_CHANNEL_CONFIG.detectors.values() + for detector in detectors + } + return sorted(names) + def _set_row(self, row_index: int) -> None: row_index = min(max(0, int(row_index)), len(self.dataset) - 1) if row_index == self.row_index: return self.row_index = row_index + self.details_scroll = 0 self._sync_widgets() self.render(reset_zoom=True) @@ -184,6 +219,7 @@ def _set_signal(self, signal_index: int) -> None: if signal_index == self.signal_index: return self.signal_index = signal_index + self.details_scroll = 0 self._sync_widgets() self.render(reset_zoom=True) @@ -200,6 +236,8 @@ def _on_row_slider(self, value: float) -> None: def _on_toggle(self, label: str) -> None: if label == "trend": self.show_trends = not self.show_trends + elif label == "plateau": + self.show_plateaus = not self.show_plateaus elif label == "spike": self.show_spikes = not self.show_spikes elif label == "drop": @@ -211,6 +249,13 @@ def _on_toggle(self, label: str) -> None: self._update_overlay_button_styles() self.render(reset_zoom=False) + def _set_hit_target(self, detector_name: str) -> None: + if detector_name != self.hit_target: + self.hit_target = detector_name + self.search_status = f"Jump target set to {detector_name}." + self._update_hit_target_button_styles() + self.render(reset_zoom=False) + def _set_detail_mode(self, mode: str) -> None: if mode != self.detail_mode: self.detail_mode = mode @@ -239,6 +284,10 @@ def _on_key_press(self, event) -> None: self._scroll_details(-1) elif event.key == "pagedown": self._scroll_details(1) + elif event.key == "n": + self._jump_to_hit(1) + elif event.key == "p": + self._jump_to_hit(-1) def _on_scroll(self, event) -> None: if event.inaxes is not self.ax_details: @@ -258,6 +307,32 @@ def _detector_events(self, signal: Signal) -> list[tuple[str, object]]: events.append((detector_name, result)) return events + @lru_cache(maxsize=64) + def _row_detector_events(self, row_index: int) -> tuple[tuple[tuple[str, object], ...], ...]: + signals, _, _ = self._load_row_bundle(row_index) + return tuple(tuple(self._detector_events(signal)) for signal in signals) + + def _jump_to_hit(self, step: int) -> None: + n_rows = len(self.dataset) + n_signals = len(MHC_CHANNEL_CONFIG.names) + flat_index = self.row_index * n_signals + self.signal_index + + for offset in range(1, n_rows * n_signals): + candidate = (flat_index + step * offset) % (n_rows * n_signals) + row_index, signal_index = divmod(candidate, n_signals) + row_signal_events = self._row_detector_events(row_index)[signal_index] + if any(detector_name == self.hit_target for detector_name, _ in row_signal_events): + self.row_index = row_index + self.signal_index = signal_index + self.details_scroll = 0 + self.search_status = f"Jumped to row {row_index}, signal {signal_index} with {self.hit_target}." + self._sync_widgets() + self.render(reset_zoom=True) + return + + self.search_status = f"No hits found for {self.hit_target} in the scanned dataset." + self.render(reset_zoom=False) + @staticmethod def _captions_for_signal(signal_id: str, samples: list[Sample], annotations: list[Annotation]) -> dict[str, list[str]]: sample_ids = { @@ -299,22 +374,28 @@ def _style_widgets(self) -> None: self.reset_zoom_button, self.prev_row_button, self.next_row_button, + self.prev_hit_button, + self.next_hit_button, ): - button.label.set_fontsize(9) + button.label.set_fontsize(8.5) + for button in self.hit_target_buttons.values(): + button.label.set_fontsize(8) for button in self.detail_tab_buttons.values(): button.label.set_fontsize(7.5) self.detail_up_button.label.set_fontsize(8) self.detail_down_button.label.set_fontsize(8) for button in self.overlay_buttons.values(): - button.label.set_fontsize(7.5) + button.label.set_fontsize(6.8) self._sync_widgets() + self._update_hit_target_button_styles() self._update_detail_tab_styles() self._update_overlay_button_styles() def _overlay_state(self, label: str) -> bool: return { "trend": self.show_trends, + "plateau": self.show_plateaus, "spike": self.show_spikes, "drop": self.show_drops, "gap": self.show_gaps, @@ -337,6 +418,22 @@ def _update_overlay_button_styles(self) -> None: button.label.set_color(text) button.label.set_fontweight("bold" if enabled else "normal") + def _update_hit_target_button_styles(self) -> None: + for detector_name, button in self.hit_target_buttons.items(): + active = detector_name == self.hit_target + face = "#204a87" if active else "#f4f5f7" + edge = "#f4d35e" if active else "#c7cdd6" + text = "white" if active else "#5f6b7a" + button.ax.set_facecolor(face) + button.ax.patch.set_edgecolor(edge) + button.ax.patch.set_linewidth(2.0 if active else 1.1) + for spine in button.ax.spines.values(): + spine.set_edgecolor(edge) + spine.set_linewidth(2.0 if active else 1.1) + button.hovercolor = "#3465a4" if active else "#eaedf2" + button.label.set_color(text) + button.label.set_fontweight("bold" if active else "normal") + @staticmethod def _build_detail_lines(title: str, lines: list[str], width: int) -> list[str]: rendered = [title] @@ -383,9 +480,10 @@ def render(self, reset_zoom: bool = False) -> None: self.ax_main.clear() self.ax_overview.clear() self.ax_summary.clear() + self.ax_hit_target.clear() self.ax_signal_list.clear() self.ax_details.clear() - for ax in (self.ax_summary, self.ax_signal_list, self.ax_details): + for ax in (self.ax_summary, self.ax_hit_target, self.ax_signal_list, self.ax_details): ax.axis("off") self.ax_main.plot(x[valid], y[valid], color="steelblue", linewidth=1.0, label="signal") @@ -398,6 +496,14 @@ def render(self, reset_zoom: bool = False) -> None: color = "#4daf4a" if result.direction == "increasing" else "#ff7f00" label = f"{detector_name} ({result.direction})" self.ax_main.axvspan(result.start_minute, result.end_minute, color=color, alpha=0.18, label=label) + elif result.event_type == "plateau" and self.show_plateaus: + self.ax_main.axvspan( + result.start_minute, + result.end_minute, + color="#377eb8", + alpha=0.16, + label=detector_name, + ) elif result.event_type == "spike" and self.show_spikes: minute = int(result.spike_minute) if minute < len(y) and not np.isnan(y[minute]): @@ -504,6 +610,29 @@ def render(self, reset_zoom: bool = False) -> None: family="monospace", color="#222222", ) + self.ax_hit_target.set_xlim(0, 1) + self.ax_hit_target.set_ylim(0, 1) + self.ax_hit_target.text( + 0.0, + 1.02, + f"Find {self.hit_target.replace('Detector', '').lower()} hits" if self.hit_target else "Find hits", + transform=self.ax_hit_target.transAxes, + ha="left", + va="bottom", + fontsize=8.5, + color="#444444", + weight="bold", + ) + self.ax_hit_target.text( + 1.0, + 1.02, + _truncate(self.search_status, 40), + transform=self.ax_hit_target.transAxes, + ha="right", + va="bottom", + fontsize=7.5, + color="#666666", + ) self.ax_signal_list.set_xlim(0, 1) self.ax_signal_list.set_ylim(len(signals), 0) @@ -572,6 +701,9 @@ def render(self, reset_zoom: bool = False) -> None: "click a signal on the right or the overview heatmap to change channel", "up/down changes row", "left/right changes signal", + "choose a detector target above the signal list", + "n / p jump to next or previous hit for that detector", + "hit< / hit> buttons do the same", "mouse wheel over details scrolls", "PageUp/PageDown also scroll details", "overlay buttons toggle detector layers", @@ -618,7 +750,6 @@ def render(self, reset_zoom: bool = False) -> None: self.ax_main.set_xlim(old_xlim) self.ax_main.set_ylim(old_ylim) - self.fig.suptitle("SensorTSLM Interactive Explorer", fontsize=14) self.fig.canvas.draw_idle() def save(self, path: str) -> None: