|
6 | 6 | # |
7 | 7 | from __future__ import annotations |
8 | 8 |
|
9 | | -import numpy as np |
10 | | - |
11 | 9 | from extractors import CaptionExtractor, ChannelConfig |
| 10 | +from synthesizers import CrossChannelSynthesizer |
12 | 11 | from timef.schema import Annotation, Recording |
13 | 12 |
|
14 | 13 |
|
15 | 14 | class CrossChannelExtractor(CaptionExtractor): |
16 | 15 | caption_type = "cross_channel" |
17 | 16 |
|
18 | | - def __init__(self, config: ChannelConfig, min_duration: int = 5): |
| 17 | + def __init__(self, config: ChannelConfig, synthesizers: list[CrossChannelSynthesizer]): |
19 | 18 | super().__init__(config) |
20 | | - self.min_duration = min_duration |
| 19 | + self.synthesizers = synthesizers |
21 | 20 |
|
22 | 21 | def extract(self, row: Recording) -> list[Annotation]: |
23 | | - try: |
24 | | - in_bed_idx = row.channel_names.index("sleep:inbed") |
25 | | - asleep_idx = row.channel_names.index("sleep:asleep") |
26 | | - except ValueError: |
27 | | - return [] |
28 | | - |
29 | | - in_bed = np.asarray(row.values[in_bed_idx], dtype=float) |
30 | | - asleep = np.asarray(row.values[asleep_idx], dtype=float) |
31 | | - if not np.any((~np.isnan(asleep)) & (asleep > 0)): |
32 | | - return [] |
33 | | - mask = (~np.isnan(in_bed)) & (in_bed > 0) & ~((~np.isnan(asleep)) & (asleep > 0)) |
34 | | - |
35 | 22 | results: list[Annotation] = [] |
36 | | - for start, end in _contiguous_windows(mask, self.min_duration): |
37 | | - end_inclusive = max(start, end - 1) |
38 | | - results.append( |
39 | | - Annotation( |
40 | | - caption_type=self.caption_type, |
41 | | - text=f"In bed but not sleeping from {self._time_label(start)} {start} to {self._time_label(end_inclusive)} {end_inclusive}.", |
42 | | - channel_idxs=(asleep_idx, in_bed_idx), |
43 | | - window=(start, end), |
44 | | - label="in_bed_not_sleeping", |
45 | | - ) |
46 | | - ) |
| 23 | + for synth in self.synthesizers: |
| 24 | + results.extend(synth.synthesize(row, self.config)) |
47 | 25 | return results |
48 | | - |
49 | | - def _time_label(self, value: int) -> str: |
50 | | - return "hour" if self.config.time_unit == "hours" else "minute" |
51 | | - |
52 | | - |
53 | | -def _contiguous_windows(mask: np.ndarray, min_duration: int) -> list[tuple[int, int]]: |
54 | | - if not mask.any(): |
55 | | - return [] |
56 | | - |
57 | | - padded = np.concatenate(([False], mask, [False])) |
58 | | - diffs = np.diff(padded.astype(np.int8)) |
59 | | - starts = np.where(diffs == 1)[0] |
60 | | - ends = np.where(diffs == -1)[0] |
61 | | - keep = (ends - starts) >= min_duration |
62 | | - return list(zip(starts[keep].tolist(), ends[keep].tolist())) |
0 commit comments