diff --git a/.gitignore b/.gitignore index 234e05d..3540a84 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ +data/ .env .claude/ docs/ diff --git a/detectors/__init__.py b/detectors/__init__.py index 3ae65ce..2f1103c 100644 --- a/detectors/__init__.py +++ b/detectors/__init__.py @@ -15,7 +15,7 @@ @dataclass class DetectionResult: - event_type: Literal["trend", "spike", "drop"] + event_type: Literal["trend", "spike"] start_minute: Optional[int] = None end_minute: Optional[int] = None direction: Optional[Literal["increasing", "decreasing"]] = None diff --git a/detectors/spike.py b/detectors/spike.py index 7ca6086..72c905e 100644 --- a/detectors/spike.py +++ b/detectors/spike.py @@ -7,163 +7,162 @@ from __future__ import annotations import numpy as np -from scipy.signal import find_peaks, peak_widths +from scipy.signal import find_peaks from detectors import DetectionResult, StructuralDetector class SpikeDetector(StructuralDetector): - """Detects spikes and drops. - """ + """Detects upward point anomalies using a local-baseline heuristic.""" def __init__( self, filter_zeros: bool = False, - prominence_scale: float = 3.0, min_prominence: float = 0.0, min_distance: int = 1, min_width: int = 1, top_k: int | None = None, - smooth_window: int = 1, - drop_localization: str = "center", + max_cluster_peaks: int = 2, + soft_top_k_margin: float = 0.90, ) -> None: super().__init__(filter_zeros=filter_zeros) - self.prominence_scale = prominence_scale self.min_prominence = min_prominence self.min_distance = max(1, min_distance) self.min_width = max(1, min_width) self.top_k = top_k - self.smooth_window = max(1, smooth_window) - self.drop_localization = drop_localization + self.max_cluster_peaks = max(1, max_cluster_peaks) + self.soft_top_k_margin = min(max(float(soft_top_k_margin), 0.0), 1.0) + self.duplicate_radius = max(1, min(3, self.min_distance // 5 if self.min_distance > 1 else 1)) def _detect(self, filtered_signal: np.ndarray, indices: np.ndarray) -> list[DetectionResult]: - smoothed_signal = self._smooth(filtered_signal) - prominence = self._prominence_threshold(filtered_signal) # threshold on raw/filtered, not smoothed - if prominence <= 0: - return [] - - results: list[DetectionResult] = [] - seen_minutes: set[int] = set() - - for peak_idx, peak_score, properties in self._select_peaks(smoothed_signal, prominence): - minute = self._localize_spike(peak_idx, properties, filtered_signal, indices) - if minute in seen_minutes: - continue - results.append(DetectionResult(event_type="spike", spike_minute=minute, score=peak_score)) - seen_minutes.add(minute) - - for peak_idx, peak_score, properties in self._select_drop_peaks(smoothed_signal, prominence): - minute = self._localize_drop(peak_idx, properties, smoothed_signal, indices) - if minute in seen_minutes: - continue - results.append(DetectionResult(event_type="drop", spike_minute=minute, score=peak_score)) - seen_minutes.add(minute) - + candidates = self._collect_candidates(filtered_signal, indices) + results = self._select_results(candidates) results.sort(key=lambda result: int(result.spike_minute)) return results - def _smooth(self, filtered_signal: np.ndarray) -> np.ndarray: - if self.smooth_window <= 1: - return filtered_signal - - kernel = np.ones(self.smooth_window, dtype=float) / float(self.smooth_window) - return np.convolve(filtered_signal, kernel, mode="same") - - def _select_peaks( + def _collect_candidates( self, - smoothed_signal: np.ndarray, - prominence: float, - ) -> list[tuple[int, float, dict[str, int]]]: + filtered_signal: np.ndarray, + indices: np.ndarray, + ) -> list[dict[str, float | int]]: peaks, properties = find_peaks( - smoothed_signal, - prominence=prominence, - distance=self.min_distance, + filtered_signal, + prominence=self.min_prominence, width=self.min_width, ) if len(peaks) == 0: return [] prominences = properties["prominences"].astype(float) - order = np.argsort(prominences)[::-1] - if self.top_k is not None: - order = order[: self.top_k] - left_ips = properties.get("left_ips", peaks).astype(float) right_ips = properties.get("right_ips", peaks).astype(float) - - selected: list[tuple[int, float, dict[str, int]]] = [] - for idx in order: - peak_idx = int(peaks[idx]) - selected.append( - ( - peak_idx, - float(prominences[idx]), - { - "left_idx": int(np.floor(left_ips[idx])), - "right_idx": int(np.ceil(right_ips[idx])), - }, - ) + left_bases = properties.get("left_bases", peaks).astype(int) + right_bases = properties.get("right_bases", peaks).astype(int) + widths = properties.get("widths", np.ones_like(peaks, dtype=float)).astype(float) + + candidates: list[dict[str, float | int]] = [] + for idx, peak in enumerate(peaks): + peak_idx = int(peak) + minute = self._localize_spike( + peak_idx, + { + "left_idx": int(np.floor(left_ips[idx])), + "right_idx": int(np.ceil(right_ips[idx])), + }, + filtered_signal, + indices, ) - return selected - - def _select_drop_peaks( - self, - smoothed_signal: np.ndarray, - prominence: float, - ) -> list[tuple[int, float, dict[str, int]]]: - peaks, properties = find_peaks( - -smoothed_signal, - prominence=prominence, - distance=self.min_distance, - width=self.min_width, - plateau_size=(1, None), - ) - if len(peaks) == 0: - return [] - - prominences = properties["prominences"].astype(float) - left_edges = properties.get("left_edges", peaks).astype(int) - order = np.argsort(prominences)[::-1] - if self.top_k is not None: - order = order[: self.top_k] - - selected: list[tuple[int, float, dict[str, int]]] = [] - for idx in order: - peak_idx = int(peaks[idx]) - left_edge_idx = int(left_edges[idx]) - selected.append( - ( - peak_idx, - float(prominences[idx]), - { - "left_edge_idx": left_edge_idx, - }, - ) + left_base_value = float(filtered_signal[int(np.clip(left_bases[idx], 0, len(filtered_signal) - 1))]) + right_base_value = float(filtered_signal[int(np.clip(right_bases[idx], 0, len(filtered_signal) - 1))]) + baseline_value = max(left_base_value, right_base_value) + local_jump = max(0.0, float(filtered_signal[peak_idx]) - baseline_value) + relative_jump = local_jump / max(abs(baseline_value), 1.0) + candidates.append( + { + "minute": minute, + "prominence": float(prominences[idx]), + "width": float(widths[idx]), + "local_jump": local_jump, + "relative_jump": relative_jump, + } ) - return selected + return candidates - def _localize_drop( - self, - peak_idx: int, - properties: dict[str, int], - smoothed_signal: np.ndarray, - indices: np.ndarray, - ) -> int: - if self.drop_localization == "left_edge": - return int(indices[properties["left_edge_idx"]]) + def _select_results(self, candidates: list[dict[str, float | int]]) -> list[DetectionResult]: + if not candidates: + return [] - if self.drop_localization == "left_ips": - try: - left_ips_arr = peak_widths(-smoothed_signal, [peak_idx], rel_height=1.0)[2] - left_idx = int(np.clip(round(float(left_ips_arr[0])), 0, len(indices) - 1)) - return int(indices[left_idx]) - except Exception: - return int(indices[peak_idx]) + minutes = np.asarray([float(candidate["minute"]) for candidate in candidates], dtype=float) + prominences = np.asarray([float(candidate["prominence"]) for candidate in candidates], dtype=float) + widths = np.asarray([float(candidate["width"]) for candidate in candidates], dtype=float) + local_jumps = np.asarray([float(candidate["local_jump"]) for candidate in candidates], dtype=float) + relative_jumps = np.asarray([float(candidate["relative_jump"]) for candidate in candidates], dtype=float) + + jump_cutoff = self._local_jump_threshold(local_jumps) + scores = local_jumps + 0.25 * widths + 0.10 * prominences + + duplicate_clusters = self._cluster_ids(minutes, radius=self.duplicate_radius) + accepted: set[int] = set() + cluster_members: dict[int, list[int]] = {} + for idx, cluster_id in enumerate(duplicate_clusters): + cluster_members.setdefault(int(cluster_id), []).append(idx) + + for cluster_id in sorted(cluster_members): + members = cluster_members[cluster_id] + qualifying = [ + idx for idx in members + if local_jumps[idx] >= jump_cutoff - 1e-12 or relative_jumps[idx] >= 1.0 - 1e-12 + ] + if not qualifying: + continue - return int(indices[peak_idx]) + qualifying.sort(key=lambda idx: (-scores[idx], -local_jumps[idx], -widths[idx], minutes[idx])) + keep_n = min(self.max_cluster_peaks, len(qualifying)) + boundary_score = float(scores[qualifying[keep_n - 1]]) + accepted.update(idx for idx in qualifying if scores[idx] >= boundary_score - 1e-12) + + if self.top_k is not None and len(accepted) > self.top_k: + accepted_list = sorted(accepted, key=lambda idx: (-scores[idx], -local_jumps[idx], -widths[idx], minutes[idx])) + boundary_score = float(scores[accepted_list[self.top_k - 1]]) + override_cutoff = boundary_score * self.soft_top_k_margin + accepted = {idx for idx in accepted_list if scores[idx] >= override_cutoff - 1e-12} + + return [ + DetectionResult( + event_type="spike", + spike_minute=int(candidates[idx]["minute"]), + score=float(scores[idx]), + ) + for idx in accepted + ] + + @staticmethod + def _cluster_ids(minutes: np.ndarray, radius: int) -> np.ndarray: + if len(minutes) == 0: + return np.zeros(0, dtype=int) + order = np.argsort(minutes) + cluster_ids = np.zeros(len(minutes), dtype=int) + cluster = 0 + prev_minute = float(minutes[order[0]]) + cluster_ids[order[0]] = cluster + for idx in order[1:]: + minute = float(minutes[idx]) + if minute - prev_minute > radius: + cluster += 1 + cluster_ids[idx] = cluster + prev_minute = minute + return cluster_ids + + @staticmethod + def _local_jump_threshold(local_jumps: np.ndarray) -> float: + if len(local_jumps) == 0: + return 0.0 + q25, q50, q75 = np.percentile(local_jumps, [25, 50, 75]) + iqr = float(q75 - q25) + return max(0.0, max(float(q25), float(q50 - 0.25 * iqr))) + @staticmethod def _localize_spike( - self, peak_idx: int, properties: dict[str, int], filtered_signal: np.ndarray, @@ -183,15 +182,3 @@ def _localize_spike( candidate_indices = left_idx + candidate_offsets best_idx = min(candidate_indices, key=lambda idx: abs(idx - peak_idx)) return int(indices[int(best_idx)]) - - def _prominence_threshold(self, filtered_signal: np.ndarray) -> float: - centered = filtered_signal - np.median(filtered_signal) - mad = float(np.median(np.abs(centered))) - if mad > 0: - scale = 1.4826 * mad - return max(self.min_prominence, self.prominence_scale * scale) - - spread = float(np.percentile(filtered_signal, 95) - np.percentile(filtered_signal, 5)) - if spread <= 1e-12: - return 0.0 - return max(self.min_prominence, 0.5 * spread) diff --git a/explorer.py b/explorer.py index bf7b4f8..2aedec8 100644 --- a/explorer.py +++ b/explorer.py @@ -58,17 +58,23 @@ def _format_detector_event(detector_name: str, result: object) -> str: return ( f"{detector_name}: {result.direction} {result.start_minute}-{result.end_minute}{score_suffix}" ) - if event_type in {"spike", "drop"}: + if event_type == "spike": return f"{detector_name}: {event_type} @{result.spike_minute}{score_suffix}" return f"{detector_name}: {event_type}{score_suffix}" - def _truncate(text: str, max_len: int = 34) -> str: if len(text) <= max_len: return text return text[: max_len - 1] + "..." +def _hit_target_label(hit_target: str) -> str: + if ":" in hit_target: + _, event_type = hit_target.split(":", 1) + return event_type.replace("_", " ") + return hit_target.replace("Detector", "").lower() + + class SensorExplorer: def __init__( self, @@ -88,14 +94,12 @@ def __init__( self.show_trends = True self.show_spikes = True - self.show_drops = True - self.show_nonwear = True self.detail_mode = "events" self.details_scroll = 0 self.details_page_lines = 12 self.hit_target_names = self._available_detector_names() self.hit_target = self.hit_target_names[0] if self.hit_target_names else None - self.search_status = "Use hit< / hit> to jump to the selected detector." + self.search_status = "Use hit< / hit> to jump to the selected event type." self._ignore_widget_events = False @@ -114,10 +118,10 @@ def __init__( self.prev_row_ax = self.fig.add_axes([0.70, 0.051, 0.035, 0.036]) self.next_row_ax = self.fig.add_axes([0.74, 0.051, 0.035, 0.036]) detail_tab_specs = [ - ("stats", 0.76), - ("events", 0.815), - ("captions", 0.870), - ("help", 0.925), + ("stats", 0.755), + ("events", 0.805), + ("captions", 0.855), + ("help", 0.905), ] self.detail_tab_buttons: dict[str, Button] = {} for label, x0 in detail_tab_specs: @@ -136,8 +140,8 @@ def __init__( for i, detector_name in enumerate(self.hit_target_names): x0 = 0.805 + i * (width + 0.004) ax = self.fig.add_axes([x0, 0.698, width, 0.038]) - self.hit_target_buttons[detector_name] = Button(ax, detector_name.replace("Detector", "").lower()) - overlay_labels = ["trend", "spike", "drop", "nonwear"] + self.hit_target_buttons[detector_name] = Button(ax, _hit_target_label(detector_name)) + overlay_labels = ["trend", "spike"] self.overlay_buttons: dict[str, Button] = {} start_x = 0.83 button_width = 0.035 @@ -194,13 +198,20 @@ def _load_row_bundle(self, row_index: int) -> Recording: @staticmethod def _available_detector_names() -> list[str]: - names = { - detector.__class__.__name__ - for detectors in MHC_CHANNEL_CONFIG.detectors.values() - for detector in detectors - } + names: set[str] = set() + for detectors in MHC_CHANNEL_CONFIG.detectors.values(): + for detector in detectors: + detector_name = detector.__class__.__name__ + names.add(detector_name) return sorted(names) + @staticmethod + def _matches_hit_target(hit_target: str, detector_name: str, result: object) -> bool: + if ":" not in hit_target: + return detector_name == hit_target + target_detector_name, target_event_type = hit_target.split(":", 1) + return detector_name == target_detector_name and getattr(result, "event_type", None) == target_event_type + def _set_row(self, row_index: int) -> None: row_index = min(max(0, int(row_index)), len(self.dataset) - 1) if row_index == self.row_index: @@ -234,17 +245,13 @@ def _on_toggle(self, label: str) -> None: self.show_trends = not self.show_trends elif label == "spike": self.show_spikes = not self.show_spikes - elif label == "drop": - self.show_drops = not self.show_drops - elif label == "nonwear": - self.show_nonwear = not self.show_nonwear self._update_overlay_button_styles() self.render(reset_zoom=False) def _set_hit_target(self, detector_name: str) -> None: if detector_name != self.hit_target: self.hit_target = detector_name - self.search_status = f"Jump target set to {detector_name}." + self.search_status = f"Jump target set to {_hit_target_label(detector_name)}." self._update_hit_target_button_styles() self.render(reset_zoom=False) @@ -299,12 +306,22 @@ def _detector_events(self, signal: SignalView) -> list[tuple[str, object]]: events.append((float(getattr(result, "score", 0.0)), detector_name, result)) events.sort(key=lambda item: item[0], reverse=True) - limit = StructuralExtractor.MAX_EVENTS_PER_SIGNAL - if limit is not None: - events = events[:limit] - return [(detector_name, result) for _, detector_name, result in events] + @staticmethod + def _spike_labels(detector_events: list[tuple[str, object]]) -> dict[tuple[str, int], int]: + labels: dict[tuple[str, int], int] = {} + rank = 1 + for detector_name, result in detector_events: + if getattr(result, "event_type", None) != "spike": + continue + minute = getattr(result, "spike_minute", None) + if minute is None: + continue + labels[(detector_name, int(minute))] = rank + rank += 1 + return labels + @lru_cache(maxsize=64) def _row_detector_events(self, row_index: int) -> tuple[tuple[tuple[str, object], ...], ...]: recording = self._load_row_bundle(row_index) @@ -324,16 +341,16 @@ def _jump_to_hit(self, step: int) -> None: candidate = (flat_index + step * offset) % (n_rows * n_signals) row_index, signal_index = divmod(candidate, n_signals) row_signal_events = self._row_detector_events(row_index)[signal_index] - if any(detector_name == self.hit_target for detector_name, _ in row_signal_events): + if any(self._matches_hit_target(self.hit_target, detector_name, result) for detector_name, result in row_signal_events): self.row_index = row_index self.signal_index = signal_index self.details_scroll = 0 - self.search_status = f"Jumped to row {row_index}, signal {signal_index} with {self.hit_target}." + self.search_status = f"Jumped to row {row_index}, signal {signal_index} with {_hit_target_label(self.hit_target)}." self._sync_widgets() self.render(reset_zoom=True) return - self.search_status = f"No hits found for {self.hit_target} in the scanned dataset." + self.search_status = f"No hits found for {_hit_target_label(self.hit_target)} in the scanned dataset." self.render(reset_zoom=False) @staticmethod @@ -393,8 +410,6 @@ def _overlay_state(self, label: str) -> bool: return { "trend": self.show_trends, "spike": self.show_spikes, - "drop": self.show_drops, - "nonwear": self.show_nonwear, }[label] def _update_overlay_button_styles(self) -> None: @@ -463,6 +478,7 @@ def render(self, reset_zoom: bool = False) -> None: n_signals = recording.values.shape[0] signal = recording.signal(self.signal_index) detector_events = self._detector_events(signal) + spike_labels = self._spike_labels(detector_events) captions = self._captions_for_signal(recording, self.signal_index) display_name = signal.display_name unit = signal.unit or "" @@ -485,9 +501,6 @@ def render(self, reset_zoom: bool = False) -> None: ax.axis("off") self.ax_main.plot(x[valid], y[valid], color="steelblue", linewidth=1.0, label="signal") - if self.show_nonwear: - for start, end in _nan_regions(y): - self.ax_main.axvspan(start, end, color="#d62728", alpha=0.08, label="nonwear") for detector_name, result in detector_events: if result.event_type == "trend" and self.show_trends: @@ -498,13 +511,8 @@ def render(self, reset_zoom: bool = False) -> None: minute = int(result.spike_minute) if minute < len(y) and not np.isnan(y[minute]): self.ax_main.scatter(minute, y[minute], color="#2ca02c", marker="^", s=38, zorder=4, label=detector_name) - self.ax_main.annotate(str(minute), (minute, y[minute]), xytext=(0, 8), textcoords="offset points", ha="center", fontsize=7) - elif result.event_type == "drop" and self.show_drops: - minute = int(result.spike_minute) - if minute < len(y) and not np.isnan(y[minute]): - self.ax_main.scatter(minute, y[minute], color="#d62728", marker="v", s=38, zorder=4, label=detector_name) - self.ax_main.annotate(str(minute), (minute, y[minute]), xytext=(0, -12), textcoords="offset points", ha="center", fontsize=7) - + label = spike_labels.get((detector_name, minute), minute) + self.ax_main.annotate(f"#{label}", (minute, y[minute]), xytext=(0, 8), textcoords="offset points", ha="center", fontsize=7) self.ax_main.set_title(f"Row {self.row_index} | {display_name}") self.ax_main.set_ylabel(f"{display_name}\n({unit or 'value'})") self.ax_main.set_xlabel("Minute of day") @@ -548,10 +556,7 @@ def render(self, reset_zoom: bool = False) -> None: valid_minutes = int(np.sum(valid)) active_channels = recording.active_channel_count() - total_nonwear = recording.total_nonwear_minutes wear_pct = recording.wear_pct - if wear_pct is None and total_nonwear is not None: - wear_pct = (1440.0 - float(total_nonwear)) / 1440.0 * 100.0 stats_text = "n/a" if valid.any(): values = y[valid] @@ -562,7 +567,16 @@ def render(self, reset_zoom: bool = False) -> None: f"max={np.max(values):.{decimals}f}" ) - detector_lines = [_format_detector_event(name, result) for name, result in detector_events] + detector_lines = [] + for detector_name, result in detector_events: + line = _format_detector_event(detector_name, result) + if getattr(result, "event_type", None) == "spike": + minute = getattr(result, "spike_minute", None) + if minute is not None: + rank = spike_labels.get((detector_name, int(minute))) + if rank is not None: + line = f"#{rank} {line}" + detector_lines.append(line) if not detector_lines: detector_lines = ["No detector events on this signal."] @@ -583,7 +597,7 @@ def render(self, reset_zoom: bool = False) -> None: f"row {self.row_index} signal {self.signal_index}/{n_signals - 1}", f"user {_truncate(str(recording.user_id), 24)}", f"date {recording.date}", - f"wear {wear_pct:.1f}% nonwear {float(total_nonwear):.0f}m" if wear_pct is not None and total_nonwear is not None else "wear n/a", + f"wear {wear_pct:.1f}%" if wear_pct is not None else "wear n/a", f"active {active_channels}/{n_signals} valid {valid_minutes}/1440", f"has_data {signal.has_any_data} nonzero_or_nan {signal.minutes_nonzero_or_nan:.0f}", ] @@ -602,7 +616,7 @@ def render(self, reset_zoom: bool = False) -> None: self.ax_hit_target.text( 0.0, 1.02, - f"Find {self.hit_target.replace('Detector', '').lower()} hits" if self.hit_target else "Find hits", + f"Find {_hit_target_label(self.hit_target)} hits" if self.hit_target else "Find hits", transform=self.ax_hit_target.transAxes, ha="left", va="bottom", diff --git a/extractors/structural.py b/extractors/structural.py index 3866596..4232c1b 100644 --- a/extractors/structural.py +++ b/extractors/structural.py @@ -18,7 +18,6 @@ class StructuralExtractor(CaptionExtractor): caption_type = "structural" - MAX_EVENTS_PER_SIGNAL: int | None = 6 def extract(self, row: Recording) -> list[Annotation]: seed = self._seed(row.row_id) @@ -44,8 +43,6 @@ def extract(self, row: Recording) -> list[Annotation]: j += 1 signal_events.sort(key=lambda item: item[0], reverse=True) - if self.MAX_EVENTS_PER_SIGNAL is not None: - signal_events = signal_events[: self.MAX_EVENTS_PER_SIGNAL] results.extend(annotation for _, annotation in signal_events) diff --git a/mhc/constants.py b/mhc/constants.py index 3683bb9..781735d 100644 --- a/mhc/constants.py +++ b/mhc/constants.py @@ -62,40 +62,28 @@ def _activity_spike(**overrides) -> SpikeDetector: params = { - "prominence_scale": 5.0, - "min_distance": 30, - "min_width": 3, + "min_distance": 15, + "min_width": 1, "top_k": 4, - "smooth_window": 5, - "drop_localization": "left_edge", } params.update(overrides) return SpikeDetector(**params) - def _heart_rate_spike() -> SpikeDetector: return SpikeDetector( filter_zeros=True, - prominence_scale=4.5, min_distance=20, - min_width=3, - top_k=3, - smooth_window=3, - drop_localization="center", + min_width=1, + top_k=4, ) def _flights_spike() -> SpikeDetector: return SpikeDetector( - prominence_scale=4.5, min_distance=30, min_width=1, - top_k=3, - smooth_window=1, - drop_localization="left_edge", + top_k=4, ) - - MHC_CHANNEL_CONFIG = ChannelConfig( names=CHANNEL_NAMES, meta=CHANNEL_META, diff --git a/templates/templates.json b/templates/templates.json index dbf7bf2..50345a3 100644 --- a/templates/templates.json +++ b/templates/templates.json @@ -20,13 +20,6 @@ "{name} shows a notable spike at minute {minute}.", "An outlier event in {name} occurred at minute {minute}.", "A significant peak in {name} was observed at minute {minute}." - ], - "drop": [ - "A drop in {name} was detected at minute {minute}.", - "A sudden decrease in {name} occurred at minute {minute}.", - "{name} shows a notable drop at minute {minute}.", - "An outlier low event in {name} occurred at minute {minute}.", - "A significant dip in {name} was observed at minute {minute}." ] }, "semantic": [