diff --git a/.gitignore b/.gitignore index ba0430d..c18dd8d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1 @@ -__pycache__/ \ No newline at end of file +__pycache__/ diff --git a/README.md b/README.md index e978985..f517c8f 100644 --- a/README.md +++ b/README.md @@ -15,3 +15,27 @@ export MHC_DATASET_DIR="../hf-daily_max-nonwear=50" ```bash python captionizer.py ``` + +## Export To TimeF + +```python +from pathlib import Path + +from captionizer import Captionizer +from exporters.timef_export import TimeFExportConfig, export_caption_result + +result, _ = captionizer.run(max_rows=5)  # captionizer: a configured Captionizer instance +root = export_caption_result( + result, + TimeFExportConfig( + output_root=Path("exports"), + dataset_id="mhc_caption_runs", + sampling_period=1, + timestamp_unit=1, + unit_sampling_rate="1 / minute", + unit_timestamp="minute", + time_column_name="time_minute", + ), +) +print(root) +``` diff --git a/annotator.py b/annotator.py index 95f5657..d5dfdf5 100644 --- a/annotator.py +++ b/annotator.py @@ -7,7 +7,7 @@ from __future__ import annotations from extractors import CaptionExtractor -from timef.schema import Annotation, AnnotationSpec, Sample, Signal +from runtime import Annotation, RuntimeRow class Annotator: @@ -19,14 +19,8 @@ def __init__(self, extractors: list[CaptionExtractor]): raise ValueError(f"Duplicate extractor for caption_type={extractor.caption_type!r}.") seen.add(extractor.caption_type) - def get_annotation_specs(self) -> list[AnnotationSpec]: - return [extractor.get_annotation_spec() for extractor in self.extractors] - - def annotate(self, signals: list[Signal]) -> tuple[list[Sample], list[Annotation]]: - samples: list[Sample] = [] + def annotate(self, row: RuntimeRow) -> list[Annotation]: annotations: list[Annotation] = [] for extractor in self.extractors: - for s, a in extractor.extract(signals): - samples.append(s) - annotations.append(a) - return samples, annotations + annotations.extend(extractor.extract(row)) + return annotations diff --git a/captionizer.py 
b/captionizer.py index 637539a..1286049 100644 --- a/captionizer.py +++ b/captionizer.py @@ -6,7 +6,7 @@ # from __future__ import annotations -from timef.schema import DatasetManifest, CaptionResult +from runtime import CaptionResult from transformer import Transformer from annotator import Annotator from reviewer import EvaluationResult, Reviewer @@ -28,13 +28,7 @@ def __init__( def run( self, max_rows: int | None = None, ) -> tuple[CaptionResult, EvaluationResult | None]: - manifest = DatasetManifest() - for spec in self.transformer.get_signal_specs(): - manifest.signal_specs[spec.id] = spec - for spec in self.annotator.get_annotation_specs(): - manifest.annotation_specs[spec.id] = spec - - result = CaptionResult(manifest=manifest) + result = CaptionResult() failed_rows: list[int] = [] for i in range(len(self.dataset)): @@ -42,14 +36,9 @@ def run( break try: - row = self.dataset[i] - signals = self.transformer.transform_row(row) - samples, annotations = self.annotator.annotate(signals) - - for signal in signals: - result.signals[signal.id] = signal - result.samples.extend(samples) - result.annotations.extend(annotations) + row = self.transformer.transform_row(self.dataset[i]) + row.annotations.extend(self.annotator.annotate(row)) + result.rows.append(row) except Exception: failed_rows.append(i) @@ -86,19 +75,15 @@ def run( print(f"Dataset size: {len(dataset)}") result, evaluation = captionizer.run(max_rows=5) - print(f"Signals: {len(result.signals)}") - print(f"Samples: {len(result.samples)}") - print(f"Annotations: {len(result.annotations)}") + print(f"Rows: {len(result.rows)}") + print(f"Annotations: {sum(len(row.annotations) for row in result.rows)}") if evaluation: print(f"Evaluation: {len(evaluation.scores)} scores, mean={evaluation.mean_score}") shown = 0 - for signals, samples, annotations in result.iter_rows(): - active = sum( - s.metadata.get("has_any_data", True) for s in signals - ) - if active >= 9: - plot_row(signals, samples, annotations) + for row 
in result.iter_rows(): + if row.active_channel_count() >= 9: + plot_row(row) shown += 1 if shown >= 4: break diff --git a/exporters/__init__.py b/exporters/__init__.py new file mode 100644 index 0000000..188e9e9 --- /dev/null +++ b/exporters/__init__.py @@ -0,0 +1,12 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from exporters.timef_export import TimeFExportConfig, export_caption_result + +__all__ = [ + "TimeFExportConfig", + "export_caption_result", +] diff --git a/exporters/timef_export.py b/exporters/timef_export.py new file mode 100644 index 0000000..39157c1 --- /dev/null +++ b/exporters/timef_export.py @@ -0,0 +1,215 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq + +from runtime import CaptionResult +from timenet_timef import ( + Annotation as PersistedAnnotation, + AnnotationSampleRef, + AnnotationSpec, + DatasetManifest, + Sample, + SampleSignalRef, + Signal, + SignalSpec, + TimeFWriter, + mark_validated, + validate_dataset, +) + +_SPEC_IDS = { + "statistical": 0, + "structural": 1, + "semantic": 2, +} + +_QUESTIONS = { + "statistical": "Describe summary statistics for the selected signal.", + "structural": "Describe the structural event in the selected signal.", + "semantic": "Describe the semantic event in the selected signal.", +} + + +@dataclass(frozen=True) +class TimeFExportConfig: + output_root: Path + dataset_id: str + version: str | None = None + source: str = "sensortslm" + description: str = "" + domains: tuple[str, ...] = ("health", "activity", "sleep") + tags: tuple[str, ...] 
= ("sensortslm", "generated") + sampling_period: float = 60.0 + timestamp_unit: float = 1.0 + unit_sampling_rate: str = "Hz" + unit_timestamp: str = "s" + time_column_name: str = "time_s" + value_unit_map: dict[str, str] | None = None + + +def export_caption_result(result: CaptionResult, config: TimeFExportConfig) -> Path: + version = config.version or datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + root = Path(config.output_root) / config.dataset_id / version + if root.exists(): + raise FileExistsError(f"Export target already exists: {root}") + if not result.rows: + raise ValueError("CaptionResult is empty") + if config.sampling_period <= 0: + raise ValueError("sampling_period must be > 0") + if config.timestamp_unit <= 0: + raise ValueError("timestamp_unit must be > 0") + if not config.time_column_name: + raise ValueError("time_column_name must be non-empty") + + first_row = result.rows[0] + channel_names = tuple(first_row.channel_names) + signal_specs = _build_signal_specs(channel_names, config) + annotation_specs = [ + AnnotationSpec(id=_SPEC_IDS["statistical"], task="statistical", domains=list(config.domains)), + AnnotationSpec(id=_SPEC_IDS["structural"], task="structural", domains=list(config.domains)), + AnnotationSpec(id=_SPEC_IDS["semantic"], task="semantic", domains=list(config.domains)), + ] + + writer = TimeFWriter(root) + samples: list[Sample] = [] + signals: list[Signal] = [] + annotations: list[PersistedAnnotation] = [] + sampling_rate = config.timestamp_unit / config.sampling_period + annotation_id = 0 + + for sample_id, row in enumerate(result.rows): + _validate_row_shape(row, channel_names) + signal_file = f"sample-{sample_id}.parquet" + _write_signal_frame(root / "signals" / signal_file, config.sampling_period, config.time_column_name, row) + total_duration = float(row.values.shape[1] * config.sampling_period) + + signal_refs: list[SampleSignalRef] = [] + for channel_idx, channel_name in enumerate(channel_names): + signal_id = sample_id * 
len(channel_names) + channel_idx + signals.append( + Signal( + id=signal_id, + spec_id=channel_idx, + name=f"{row.row_id}:{channel_name}", + sampling_rate=sampling_rate, + total_duration=total_duration, + shard_file=signal_file, + row_group_id=0, + row_start=0, + row_count=row.values.shape[1], + ) + ) + signal_refs.append( + SampleSignalRef( + signal_id=signal_id, + sampling_rate=sampling_rate, + channels=[channel_name], + ) + ) + + samples.append( + Sample( + id=sample_id, + dataset_id=config.dataset_id, + unit_timestamp=config.unit_timestamp, + windows=None, + signals=signal_refs, + ) + ) + + for annotation in row.annotations: + spec_id = _SPEC_IDS.get(annotation.kind) + if spec_id is None: + raise ValueError(f"Unsupported annotation kind: {annotation.kind}") + reference = json.dumps( + { + "channel_names": [channel_names[idx] for idx in annotation.channel_idxs], + "window": list(annotation.window) if annotation.window is not None else None, + "kind": annotation.kind, + }, + sort_keys=True, + ) + + label = None if annotation.label is None else str(annotation.label) + answer = annotation.text if label is None else None + rationale = annotation.text if label is not None and annotation.text else None + + annotations.append( + PersistedAnnotation( + id=annotation_id, + dataset_id=config.dataset_id, + spec_id=spec_id, + samples=[AnnotationSampleRef(sample=sample_id, reference=reference)], + question=_QUESTIONS[annotation.kind], + answer=answer, + label=label, + rationale=rationale, + ) + ) + annotation_id += 1 + + manifest = DatasetManifest( + dataset_id=config.dataset_id, + version=version, + source=config.source, + description=config.description, + domains=list(config.domains), + tags=list(config.tags), + signals=signals, + signal_spec=signal_specs, + sensor_spec=[], + annotation_spec=annotation_specs, + sample_count=len(samples), + annotation_count=len(annotations), + validated=False, + ) + + writer.write_manifest(manifest) + writer.write_samples(samples) + 
writer.write_annotations(annotations, manifest) + validate_dataset(root) + mark_validated(root, validated=True) + return root + + +def _build_signal_specs(channel_names: tuple[str, ...], config: TimeFExportConfig) -> list[SignalSpec]: + value_unit_map = config.value_unit_map or {} + return [ + SignalSpec( + id=idx, + name=channel_name, + channels=[channel_name], + unit_sampling_rate=config.unit_sampling_rate, + unit_timestamp=config.unit_timestamp, + unit_value=value_unit_map.get(channel_name), + ) + for idx, channel_name in enumerate(channel_names) + ] + + +def _validate_row_shape(row, channel_names: tuple[str, ...]) -> None: + if tuple(row.channel_names) != channel_names: + raise ValueError("All RuntimeRow objects must share the same channel_names order for export") + if row.values.shape[0] != len(channel_names): + raise ValueError("RuntimeRow values must have one row per channel") + + +def _write_signal_frame(path: Path, sampling_period: float, time_column_name: str, row) -> None: + time_axis = [idx * sampling_period for idx in range(row.values.shape[1])] + payload: dict[str, list[float]] = {time_column_name: time_axis} + for idx, channel_name in enumerate(row.channel_names): + payload[channel_name] = row.values[idx].astype(float).tolist() + path.parent.mkdir(parents=True, exist_ok=True) + pq.write_table(pa.table(payload), path) diff --git a/extractors/__init__.py b/extractors/__init__.py index 22f1419..6404395 100644 --- a/extractors/__init__.py +++ b/extractors/__init__.py @@ -13,7 +13,7 @@ from aggregators import MetricAggregator from detectors import StructuralDetector -from timef.schema import Annotation, AnnotationSpec, Sample, SampleRef, SampleSignalRef, Signal +from runtime import Annotation, RuntimeRow VALID_CAPTION_TYPES = ("statistical", "structural", "semantic") @@ -51,16 +51,13 @@ def __init_subclass__(cls, **kwargs): def __init__(self, config: ChannelConfig): self.config = config - def get_annotation_spec(self) -> AnnotationSpec: - return 
AnnotationSpec(id=f"captioning:{self.caption_type}", task="captioning") - @staticmethod def _seed(key: str) -> int: return zlib.crc32(key.encode("utf-8")) & 0xFFFFFFFF @abc.abstractmethod - def extract(self, signals: list[Signal]) -> list[tuple[Sample, Annotation]]: - """Extract captions and return (Sample, Annotation) pairs.""" + def extract(self, row: RuntimeRow) -> list[Annotation]: + """Extract annotations for a row.""" ... diff --git a/extractors/generative.py b/extractors/generative.py index 01c8c60..c1bfa06 100644 --- a/extractors/generative.py +++ b/extractors/generative.py @@ -8,7 +8,7 @@ from extractors import CaptionExtractor, ChannelConfig from models.base import BaseModel -from timef.schema import Annotation, Sample, Signal +from runtime import Annotation, RuntimeRow class GenerativeExtractor(CaptionExtractor): @@ -20,5 +20,5 @@ def __init__(self, config: ChannelConfig, model: BaseModel): super().__init__(config) self.model = model - def extract(self, signals: list[Signal]) -> list[tuple[Sample, Annotation]]: + def extract(self, row: RuntimeRow) -> list[Annotation]: raise NotImplementedError diff --git a/extractors/semantic.py b/extractors/semantic.py index d7417ec..5ec6294 100644 --- a/extractors/semantic.py +++ b/extractors/semantic.py @@ -12,7 +12,7 @@ import numpy as np from extractors import CaptionExtractor, ChannelConfig -from timef.schema import Annotation, Sample, SampleRef, SampleSignalRef, Signal +from runtime import Annotation, RuntimeRow TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json" SEM_TEMPLATES: list[str] = json.loads(TEMPLATES_PATH.read_text())["semantic"] @@ -34,46 +34,37 @@ def __init__( self.semantic_names = set(activity_channels) | set(sleep_channels) self.min_minutes = min_minutes - def extract(self, signals: list[Signal]) -> list[tuple[Sample, Annotation]]: - seed = self._seed(signals[0].id) - semantic_signals = {s.name: s for s in signals if s.name in self.semantic_names} - if 
not semantic_signals: + def extract(self, row: RuntimeRow) -> list[Annotation]: + seed = self._seed(row.row_id) + semantic_indices = [i for i, name in enumerate(row.channel_names) if name in self.semantic_names] + if not semantic_indices: return [] - row_key = signals[0].id.rsplit(":", 1)[0] + continuous_indices = tuple( + i for i, name in enumerate(row.channel_names) if name in self.config.continuous + ) - continuous_refs = [ - SampleSignalRef(signal_id=s.id) - for s in signals if s.name in self.config.continuous - ] - - results = [] + results: list[Annotation] = [] j = 0 - for name, signal in semantic_signals.items(): - windows = _contiguous_windows(signal.data, self.min_minutes) + for idx in semantic_indices: + name = row.channel_names[idx] + windows = _contiguous_windows(row.values[idx], self.min_minutes) for start, end in windows: template = SEM_TEMPLATES[(seed + j) % len(SEM_TEMPLATES)] display_name = self.config.display_name(name) caption = template.format(name=display_name, start=start, end=max(end - 1, start)) caption = caption[0].upper() + caption[1:] - label = name if name in self.activity_channels else None - - sample_id = f"{row_key}:{self.caption_type}:{name}:{j}" - sample = Sample( - id=sample_id, - windows=[(start, end)], - signals=continuous_refs, - ) - annotation = Annotation( - id=sample_id, - spec_id=f"captioning:{self.caption_type}", - samples=[SampleRef(sample_id=sample_id)], - answer=caption, - label=label, + results.append( + Annotation( + kind=self.caption_type, + text=caption, + channel_idxs=continuous_indices, + window=(start, end), + label=label, + ) ) - results.append((sample, annotation)) j += 1 return results diff --git a/extractors/statistical.py b/extractors/statistical.py index a1ded63..4ea1592 100644 --- a/extractors/statistical.py +++ b/extractors/statistical.py @@ -13,7 +13,7 @@ from aggregators import MetricAggregator from extractors import CaptionExtractor -from timef.schema import Annotation, Sample, SampleRef, 
SampleSignalRef, Signal +from runtime import Annotation, RuntimeRow TEMPLATES_PATH = pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json" STAT_TEMPLATES: list[str] = json.loads(TEMPLATES_PATH.read_text())["statistical"] @@ -24,17 +24,14 @@ class StatisticalExtractor(CaptionExtractor): caption_type = "statistical" - def extract(self, signals: list[Signal]) -> list[tuple[Sample, Annotation]]: - if not signals: - raise ValueError("signals must not be empty") - seed = self._seed(signals[0].id) - results = [] + def extract(self, row: RuntimeRow) -> list[Annotation]: + seed = self._seed(row.row_id) + results: list[Annotation] = [] - for i, signal in enumerate(signals): + for i, signal in enumerate(row.iter_channels()): if signal.name not in self.config.continuous: continue - # remove data gaps series = signal.data[~np.isnan(signal.data)].astype(float) aggregator = self.config.aggregators.get(signal.name, DEFAULT_AGGREGATOR) @@ -52,15 +49,6 @@ def extract(self, signals: list[Signal]) -> list[tuple[Sample, Annotation]]: mean=format(mean_v), max=format(max_v), min=format(min_v), std=format(std_v), ) - - sample_id = f"{signal.id}:{self.caption_type}" - sample = Sample(id=sample_id, signals=[SampleSignalRef(signal_id=signal.id)]) - annotation = Annotation( - id=f"{signal.id}:{self.caption_type}", - spec_id=f"captioning:{self.caption_type}", - samples=[SampleRef(sample_id=sample_id)], - answer=caption, - ) - results.append((sample, annotation)) + results.append(Annotation(kind=self.caption_type, text=caption, channel_idxs=(i,))) return results diff --git a/extractors/structural.py b/extractors/structural.py index 337b010..be03223 100644 --- a/extractors/structural.py +++ b/extractors/structural.py @@ -10,7 +10,7 @@ import pathlib from extractors import CaptionExtractor -from timef.schema import Annotation, Sample, SampleRef, SampleSignalRef, Signal +from runtime import Annotation, RuntimeRow TEMPLATES_PATH = 
pathlib.Path(__file__).resolve().parent.parent / "templates" / "templates.json" STRUCT_TEMPLATES: dict[str, list[str]] = json.loads(TEMPLATES_PATH.read_text())["structural"] @@ -19,34 +19,27 @@ class StructuralExtractor(CaptionExtractor): caption_type = "structural" - def extract(self, signals: list[Signal]) -> list[tuple[Sample, Annotation]]: - seed = self._seed(signals[0].id) - results = [] - + def extract(self, row: RuntimeRow) -> list[Annotation]: + seed = self._seed(row.row_id) + results: list[Annotation] = [] j = 0 - for i, signal in enumerate(signals): + for i, signal in enumerate(row.iter_channels()): detectors = self.config.detectors.get(signal.name, []) - display_name = self.config.meta.get(signal.name, (signal.name, "", 0))[0] + display_name = signal.display_name for detector in detectors: for result in detector.detect(signal.data): templates = STRUCT_TEMPLATES[result.event_type] template = templates[(seed + i + j) % len(templates)] caption = template.format(name=display_name, **result.template_vars()) - - sample_id = f"{signal.id}:{self.caption_type}:{j}" - sample = Sample( - id=sample_id, - windows=[result.window], - signals=[SampleSignalRef(signal_id=signal.id)], - ) - annotation = Annotation( - id=f"{signal.id}:{self.caption_type}:{j}", - spec_id=f"captioning:{self.caption_type}", - samples=[SampleRef(sample_id=sample_id)], - answer=caption, + results.append( + Annotation( + kind=self.caption_type, + text=caption, + channel_idxs=(i,), + window=result.window, + ) ) - results.append((sample, annotation)) j += 1 return results diff --git a/mhc/transformer.py b/mhc/transformer.py index 22c2ede..09add59 100644 --- a/mhc/transformer.py +++ b/mhc/transformer.py @@ -10,7 +10,7 @@ from extractors import ChannelConfig from mhc.constants import MHC_CHANNEL_CONFIG -from timef.schema import SignalSpec, Signal +from runtime import RuntimeRow from transformer import Transformer @@ -18,30 +18,11 @@ class MHCTransformer(Transformer): def __init__(self, config: 
ChannelConfig = MHC_CHANNEL_CONFIG): self.config = config - def get_signal_specs(self) -> list[SignalSpec]: - specs = [] - for ch in self.config.names: - if ch in self.config.continuous: - display_name, unit, _ = self.config.meta[ch] - else: - display_name = ch - unit = "binary" - specs.append(SignalSpec( - id=f"mhc:{ch}", - name=ch, - display_name=display_name, - channels=["value"], - unit_sampling_rate="1/min", - unit_timestamp="minutes", - unit_value=unit, - )) - return specs - - def transform_row(self, row: dict) -> list[Signal]: + def transform_row(self, row: dict) -> RuntimeRow: user_id = row["user_id"] date = row["date"] row_id = f"mhc:{user_id}:{date}" - data = np.array(row["data"], dtype=np.float32).copy() # (19, 1440) + data = np.asarray(row["data"], dtype=np.float32).copy() # (19, 1440) # ZeroToNaNTransform (according to MHC-Benchmark) data[5][data[5] == 0] = np.nan @@ -52,32 +33,49 @@ def transform_row(self, row: dict) -> list[Signal]: # Heart rate (channel 5): convert from beats/sec to BPM data[5] *= 60 - total_nonwear_minutes = row.get("total_nonwear_minutes") - has_any_data = row.get("has_any_data", []) - minutes_nonzero_or_nan = row.get("minutes_nonzero_or_nan", []) + has_any_data = np.asarray(row.get("has_any_data", []), dtype=bool) + minutes_nonzero_or_nan = np.asarray(row.get("minutes_nonzero_or_nan", []), dtype=np.float32) channel_names = row.get("channel_names", []) channel_units = row.get("channel_units", []) - channel_variance = row.get("channel_variance", []) + channel_variance = np.asarray(row.get("channel_variance", []), dtype=np.float32) + + n_channels = len(self.config.names) + if has_any_data.shape != (n_channels,): + has_any_data = np.array( + [has_any_data[i] if i < len(has_any_data) else False for i in range(n_channels)], + dtype=bool, + ) + if minutes_nonzero_or_nan.shape != (n_channels,): + minutes_nonzero_or_nan = np.array( + [ + float(minutes_nonzero_or_nan[i]) if i < len(minutes_nonzero_or_nan) else 0.0 + for i in 
range(n_channels) + ], + dtype=np.float32, + ) + if channel_variance.shape != (n_channels,): + channel_variance = np.array( + [float(channel_variance[i]) if i < len(channel_variance) else 0.0 for i in range(n_channels)], + dtype=np.float32, + ) - signals = [] + display_names = [] + units = [] for i, ch in enumerate(self.config.names): - signals.append(Signal( - id=f"mhc:{user_id}:{date}:{ch}", - spec_id=f"mhc:{ch}", - name=ch, - data=data[i], - metadata={ - "row_id": row_id, - "user_id": user_id, - "date": date, - "channel_name": channel_names[i] if i < len(channel_names) else ch, - "channel_unit": channel_units[i] if i < len(channel_units) else None, - "has_any_data": has_any_data[i] if i < len(has_any_data) else None, - "minutes_nonzero_or_nan": ( - minutes_nonzero_or_nan[i] if i < len(minutes_nonzero_or_nan) else None - ), - "channel_variance": channel_variance[i] if i < len(channel_variance) else None, - "total_nonwear_minutes": total_nonwear_minutes, - }, - )) - return signals + display_names.append(channel_names[i] if i < len(channel_names) else self.config.display_name(ch)) + units.append(channel_units[i] if i < len(channel_units) else self.config.meta.get(ch, ("", None, 0))[1]) + + return RuntimeRow( + row_id=row_id, + user_id=user_id, + date=date, + values=np.ascontiguousarray(data), + channel_names=tuple(self.config.names), + display_names=tuple(display_names), + units=tuple(units), + has_any_data=has_any_data, + minutes_nonzero_or_nan=minutes_nonzero_or_nan, + channel_variance=channel_variance, + total_nonwear_minutes=row.get("total_nonwear_minutes"), + wear_pct=row.get("wear_pct"), + ) diff --git a/models/base.py b/models/base.py index 754682e..27109b8 100644 --- a/models/base.py +++ b/models/base.py @@ -13,7 +13,7 @@ import numpy as np from extractors import ChannelConfig -from timef.schema import Annotation, Signal +from runtime import Annotation, SignalView @dataclass @@ -32,31 +32,32 @@ def __init__(self, channel_config: ChannelConfig): def process( 
self, prompt: str, - signals: list[Signal], + signals: list[SignalView], annotations: list[Annotation] | None = None, multi_channel: bool = False, ) -> ModelResponse: """Render signals as plots, assemble with optional annotations, call model.""" caption = self._format_annotations(annotations) if annotations else None return self._call( - prompt=prompt, signals=signals, caption=caption, multi_channel=multi_channel, + prompt=prompt, signals=signals, caption=caption, multi_channel=multi_channel ) @abc.abstractmethod def _call( self, prompt: str, - signals: list[Signal], + signals: list[SignalView], caption: str | None, multi_channel: bool, ) -> ModelResponse: """Subclasses render signals in their native format and call the model.""" ... - def _plot_signal(self, signal: Signal) -> plt.Figure: + def _plot_signal(self, signal: SignalView) -> plt.Figure: """Plot a single signal. Subclasses convert the figure to their native image format.""" - meta = self.channel_config.meta.get(signal.name) - title = f"{meta[0]} ({meta[1]})" if meta else signal.name + title = signal.display_name + if signal.unit: + title = f"{title} ({signal.unit})" data = signal.data time = np.arange(len(data)) @@ -72,4 +73,4 @@ def _plot_signal(self, signal: Signal) -> plt.Figure: def _format_annotations(self, annotations: list[Annotation]) -> str: """Concatenate all annotation answers.""" - return "\n".join(ann.answer for ann in annotations if ann.answer) + return "\n".join(annotation.text for annotation in annotations if annotation.text) diff --git a/models/client.py b/models/client.py index f921759..5b2a5c6 100644 --- a/models/client.py +++ b/models/client.py @@ -10,7 +10,7 @@ from extractors import ChannelConfig from models.base import BaseModel, ModelResponse -from timef.schema import Signal +from runtime import SignalView @dataclass(frozen=True) @@ -32,7 +32,7 @@ def __init__(self, config: ClientConfig, channel_config: ChannelConfig): def _call( self, prompt: str, - signals: list[Signal], + 
signals: list[SignalView], caption: str | None, multi_channel: bool, ) -> ModelResponse: diff --git a/models/local.py b/models/local.py index 8349640..d333034 100644 --- a/models/local.py +++ b/models/local.py @@ -16,7 +16,7 @@ from extractors import ChannelConfig from models.base import BaseModel, ModelResponse -from timef.schema import Signal +from runtime import SignalView @dataclass(frozen=True) @@ -43,7 +43,7 @@ def __init__(self, config: LocalConfig, channel_config: ChannelConfig): def _call( self, prompt: str, - signals: list[Signal], + signals: list[SignalView], caption: str | None, multi_channel: bool, ) -> ModelResponse: @@ -63,7 +63,7 @@ def _call( return ModelResponse(text=generated, input_tokens=0, output_tokens=0) - def _render_signal(self, signal: Signal) -> Image.Image: + def _render_signal(self, signal: SignalView) -> Image.Image: """Plot signal and convert to PIL Image.""" fig = self._plot_signal(signal) buf = io.BytesIO() diff --git a/reviewer.py b/reviewer.py index 8672668..cb3a8f3 100644 --- a/reviewer.py +++ b/reviewer.py @@ -11,7 +11,7 @@ from dataclasses import dataclass, field from models.base import BaseModel -from timef.schema import Annotation, CaptionResult +from runtime import Annotation, CaptionResult EVALUATE_PROMPT = """\ @@ -64,7 +64,9 @@ def evaluate(self, result: CaptionResult, per_channel: bool = False) -> Evaluati ) scores.append(self._parse_score(response.text)) else: - for signals, _, annotations in result.iter_rows(): + for row in result.iter_rows(): + signals = list(row.iter_channels()) + annotations = row.annotations if not annotations: continue response = self.model.process( diff --git a/runtime/__init__.py b/runtime/__init__.py new file mode 100644 index 0000000..492c4a1 --- /dev/null +++ b/runtime/__init__.py @@ -0,0 +1,14 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM 
open-source project. +# +# SPDX-License-Identifier: MIT +# +from runtime.types import Annotation, CaptionResult, RuntimeRow, SignalView + +__all__ = [ + "Annotation", + "CaptionResult", + "RuntimeRow", + "SignalView", +] diff --git a/runtime/types.py b/runtime/types.py new file mode 100644 index 0000000..5029f70 --- /dev/null +++ b/runtime/types.py @@ -0,0 +1,122 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Generator + +import numpy as np + + +@dataclass(slots=True) +class Annotation: + kind: str + text: str + channel_idxs: tuple[int, ...] = () + window: tuple[int, int] | None = None + label: Any = None + + +@dataclass(slots=True) +class SignalView: + row: "RuntimeRow" + idx: int + + @property + def id(self) -> str: + return f"{self.row.row_id}:{self.name}" + + @property + def name(self) -> str: + return self.row.channel_names[self.idx] + + @property + def display_name(self) -> str: + return self.row.display_names[self.idx] + + @property + def unit(self) -> str | None: + return self.row.units[self.idx] + + @property + def data(self) -> np.ndarray: + return self.row.values[self.idx] + + @property + def has_any_data(self) -> bool: + return bool(self.row.has_any_data[self.idx]) + + @property + def minutes_nonzero_or_nan(self) -> float: + return float(self.row.minutes_nonzero_or_nan[self.idx]) + + @property + def variance(self) -> float: + return float(self.row.channel_variance[self.idx]) + + +@dataclass(slots=True) +class RuntimeRow: + row_id: str + user_id: str + date: str + values: np.ndarray + channel_names: tuple[str, ...] + display_names: tuple[str, ...] + units: tuple[str | None, ...] 
+ has_any_data: np.ndarray + minutes_nonzero_or_nan: np.ndarray + channel_variance: np.ndarray + total_nonwear_minutes: float | None = None + wear_pct: float | None = None + annotations: list[Annotation] = field(default_factory=list) + + def __post_init__(self) -> None: + if self.values.ndim != 2: + raise ValueError("values must have shape (channels, time)") + n_channels = self.values.shape[0] + for field_name in ("channel_names", "display_names", "units"): + if len(getattr(self, field_name)) != n_channels: + raise ValueError(f"{field_name} length must match number of channels") + for field_name in ("has_any_data", "minutes_nonzero_or_nan", "channel_variance"): + if getattr(self, field_name).shape != (n_channels,): + raise ValueError(f"{field_name} must have shape ({n_channels},)") + + def signal(self, idx: int) -> SignalView: + return SignalView(row=self, idx=idx) + + def iter_channels(self) -> Generator[SignalView, None, None]: + for idx in range(self.values.shape[0]): + yield SignalView(row=self, idx=idx) + + def channel(self, idx: int) -> SignalView: + return self.signal(idx) + + def iter_signals(self) -> Generator[SignalView, None, None]: + yield from self.iter_channels() + + def annotations_for_signal(self, idx: int) -> list[Annotation]: + return [annotation for annotation in self.annotations if idx in annotation.channel_idxs] + + def active_channel_count(self) -> int: + return int(np.count_nonzero(self.has_any_data)) + + +@dataclass(slots=True) +class CaptionResult: + rows: list[RuntimeRow] = field(default_factory=list) + + def iter_rows(self) -> Generator[RuntimeRow, None, None]: + yield from self.rows + + def iter_channels(self) -> Generator[tuple[SignalView, list[Annotation]], None, None]: + for row in self.rows: + for signal in row.iter_channels(): + yield signal, row.annotations_for_signal(signal.idx) + + def iter_signals(self) -> Generator[tuple[SignalView, list[Annotation]], None, None]: + yield from self.iter_channels() diff --git a/timef/__init__.py 
b/timef/__init__.py index ee30aac..cdb938a 100644 --- a/timef/__init__.py +++ b/timef/__init__.py @@ -4,26 +4,20 @@ # # SPDX-License-Identifier: MIT # -from timef.schema import ( - Annotation, - AnnotationSpec, - CaptionResult, - DatasetManifest, - Sample, - SampleRef, - SampleSignalRef, - Signal, - SignalSpec, -) +"""Compatibility shim for legacy timef imports. + +Prefer importing runtime types from `runtime` and export helpers from +`exporters.timef_export`. +""" + +from exporters.timef_export import TimeFExportConfig, export_caption_result +from runtime import Annotation, CaptionResult, RuntimeRow, SignalView __all__ = [ "Annotation", - "AnnotationSpec", "CaptionResult", - "DatasetManifest", - "Sample", - "SampleRef", - "SampleSignalRef", - "Signal", - "SignalSpec", + "RuntimeRow", + "SignalView", + "TimeFExportConfig", + "export_caption_result", ] diff --git a/timef/folder_only_for_backwards_compatability.txt b/timef/folder_only_for_backwards_compatability.txt new file mode 100644 index 0000000..7f56159 --- /dev/null +++ b/timef/folder_only_for_backwards_compatability.txt @@ -0,0 +1,7 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +This folder is only here for backwards compatibility; it serves no purpose in the current implementation. It is supposed to be deleted soon. \ No newline at end of file diff --git a/timef/schema.py b/timef/schema.py index 6b4dc5d..2705eb6 100644 --- a/timef/schema.py +++ b/timef/schema.py @@ -4,137 +4,16 @@ # # SPDX-License-Identifier: MIT # -from __future__ import annotations +"""Compatibility shim for legacy timef.schema imports. -from dataclasses import dataclass, field -from typing import Any, Generator +Prefer importing runtime types from `runtime.types`. 
+""" +from runtime.types import Annotation, CaptionResult, RuntimeRow, SignalView -@dataclass(frozen=True) -class SignalSpec: - id: str # e.g. "mhc:heart_rate" - name: str # e.g. "heart_rate" - display_name: str # e.g. "Heart rate" - channels: list[str] # e.g. ["value"] (single-channel for MHC) - unit_sampling_rate: str # e.g. "1/min" - unit_timestamp: str # e.g. "minutes" - unit_value: str # e.g. "bpm" - sensor_id: str | None = None - - -@dataclass(frozen=True) -class Signal: - id: str # e.g. "mhc:user42:2024-01-15:heart_rate" - spec_id: str # -> SignalSpec.id - name: str - data: Any = field(repr=False) # np.ndarray, in-memory for now - metadata: dict = field(default_factory=dict) - - -@dataclass(frozen=True) -class SampleSignalRef: - signal_id: str - sampling_rate: float | None = None - channels: list[str] | None = None # None = all - - -@dataclass(frozen=True) -class Sample: - id: str - windows: list[tuple[int, int]] | None = None # None = full signal - signals: list[SampleSignalRef] = field(default_factory=list) - - -@dataclass(frozen=True) -class AnnotationSpec: - id: str # e.g. "captioning:statistical" - task: str # e.g. 
"captioning" - schema: dict | None = None - - -@dataclass(frozen=True) -class SampleRef: - sample_id: str - reference: str | None = None # alias for [ref:X] in text - forecast: bool = False - - -@dataclass(frozen=True) -class Annotation: - id: str - spec_id: str # -> AnnotationSpec.id - samples: list[SampleRef] = field(default_factory=list) - question: str | None = None - answer: str | None = None - label: Any = None - - -@dataclass -class DatasetManifest: - signal_specs: dict[str, SignalSpec] = field(default_factory=dict) - annotation_specs: dict[str, AnnotationSpec] = field(default_factory=dict) - metadata: dict = field(default_factory=dict) - - -@dataclass -class CaptionResult: - manifest: DatasetManifest = field(default_factory=DatasetManifest) - signals: dict[str, Signal] = field(default_factory=dict) - samples: list[Sample] = field(default_factory=list) - annotations: list[Annotation] = field(default_factory=list) - - def iter_rows(self) -> Generator[tuple[list[Signal], list[Sample], list[Annotation]], None, None]: - rows: dict[object, list[Signal]] = {} - for signal in self.signals.values(): - row_id = signal.metadata.get("row_id") - key = row_id if row_id is not None else tuple(sorted(signal.metadata.items())) - rows.setdefault(key, []).append(signal) - - sample_to_annotations: dict[str, list[Annotation]] = {} - for ann in self.annotations: - for sr in ann.samples: - sample_to_annotations.setdefault(sr.sample_id, []).append(ann) - - for signals in rows.values(): - signal_ids = {s.id for s in signals} - - # Find samples referencing this row's signals - row_samples = [ - s for s in self.samples - if any(ref.signal_id in signal_ids for ref in s.signals) - ] - - # Find annotations referencing this row's samples - row_annotations = [] - seen = set() - for sample in row_samples: - for ann in sample_to_annotations.get(sample.id, []): - if ann.id not in seen: - seen.add(ann.id) - row_annotations.append(ann) - - yield signals, row_samples, row_annotations - - def 
iter_channels(self) -> Generator[tuple[Signal, list[Annotation]], None, None]: - """Yield (signal, annotations) per individual signal/channel.""" - # Map signal_id → samples that reference it - signal_to_samples: dict[str, list[Sample]] = {} - for sample in self.samples: - for ref in sample.signals: - signal_to_samples.setdefault(ref.signal_id, []).append(sample) - - # Map sample_id → annotations - sample_to_annotations: dict[str, list[Annotation]] = {} - for ann in self.annotations: - for sr in ann.samples: - sample_to_annotations.setdefault(sr.sample_id, []).append(ann) - - for signal in self.signals.values(): - anns: list[Annotation] = [] - seen: set[str] = set() - for sample in signal_to_samples.get(signal.id, []): - for ann in sample_to_annotations.get(sample.id, []): - if ann.id not in seen: - seen.add(ann.id) - anns.append(ann) - yield signal, anns +__all__ = [ + "Annotation", + "CaptionResult", + "RuntimeRow", + "SignalView", +] diff --git a/timenet_timef/__init__.py b/timenet_timef/__init__.py new file mode 100644 index 0000000..0111fcf --- /dev/null +++ b/timenet_timef/__init__.py @@ -0,0 +1,40 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from timenet_timef.io import TimeFReader, TimeFWriter +from timenet_timef.schema import ( + Annotation, + AnnotationSampleRef, + AnnotationSpec, + DatasetManifest, + Sample, + SampleSignalRef, + SensorSpec, + Signal, + SignalSpec, + VALID_DOMAINS, + validate_annotation_against_spec, +) +from timenet_timef.validate import TimeFValidationError, mark_validated, validate_dataset + +__all__ = [ + "Annotation", + "AnnotationSampleRef", + "AnnotationSpec", + "DatasetManifest", + "Sample", + "SampleSignalRef", + "SensorSpec", + "Signal", + "SignalSpec", + "TimeFReader", + "TimeFWriter", + "TimeFValidationError", + "VALID_DOMAINS", + "mark_validated", + "validate_annotation_against_spec", + "validate_dataset", +] diff --git a/timenet_timef/io.py b/timenet_timef/io.py new file mode 100644 index 0000000..3fb7432 --- /dev/null +++ b/timenet_timef/io.py @@ -0,0 +1,227 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterator + +import pyarrow as pa +import pyarrow.parquet as pq + +from timenet_timef.schema import Annotation, DatasetManifest, Sample +from timenet_timef.utils import ensure_dir + + +def _table_from_rows(rows: list[dict[str, Any]]) -> pa.Table: + if not rows: + return pa.table({}) + return pa.Table.from_pylist(rows) + + +def _is_time_column(column_name: str) -> bool: + return column_name == "time" or column_name.startswith("time_") + + +@dataclass(slots=True) +class SampleSignalIndexRow: + sample_id: int + signal_id: int + sampling_rate: float | None + channels: str | None + + def to_dict(self) -> dict[str, Any]: + return { + "sample_id": self.sample_id, + "signal_id": self.signal_id, + "sampling_rate": self.sampling_rate, + "channels": self.channels, + } + + +@dataclass(slots=True) +class AnnotationTaskIndexRow: + sample_id: int + annotation_id: int + task_type: str + + def to_dict(self) -> dict[str, Any]: + return { + "sample_id": self.sample_id, + "annotation_id": self.annotation_id, + "task_type": self.task_type, + } + + +@dataclass(slots=True) +class AnnotationDomainIndexRow: + sample_id: int + annotation_id: int + domain: str + + def to_dict(self) -> dict[str, Any]: + return { + "sample_id": self.sample_id, + "annotation_id": self.annotation_id, + "domain": self.domain, + } + + +class TimeFWriter: + def __init__(self, root: Path): + self.root = root + self.signals_dir = ensure_dir(root / "signals") + + def write_manifest(self, manifest: DatasetManifest) -> None: + ensure_dir(self.root) + (self.root / "manifest.json").write_text( + json.dumps(manifest.to_dict(), indent=2), + encoding="utf-8", + ) + + def write_samples(self, rows: list[Sample]) -> None: + pq.write_table(_table_from_rows([row.to_parquet_dict() for row in rows]), self.root / "samples.parquet") + + index_rows: list[dict[str, 
Any]] = [] + for sample in rows: + for signal_ref in sample.signals: + channels = None if signal_ref.channels is None else json.dumps(signal_ref.channels) + index_rows.append( + SampleSignalIndexRow( + sample_id=sample.id, + signal_id=signal_ref.signal_id, + sampling_rate=signal_ref.sampling_rate, + channels=channels, + ).to_dict() + ) + pq.write_table(_table_from_rows(index_rows), self.root / "sample_signal_index.parquet") + + def write_annotations(self, rows: list[Annotation], manifest: DatasetManifest | None = None) -> None: + annotation_rows = [row.to_parquet_dict() for row in rows] + if annotation_rows: + annotation_table = _table_from_rows(annotation_rows) + else: + annotation_table = pa.table( + { + "id": pa.array([], type=pa.int64()), + "dataset_id": pa.array([], type=pa.string()), + "spec_id": pa.array([], type=pa.int64()), + "samples": pa.array([], type=pa.string()), + "question": pa.array([], type=pa.string()), + "answer": pa.array([], type=pa.string()), + "label": pa.array([], type=pa.string()), + "rationale": pa.array([], type=pa.string()), + } + ) + pq.write_table(annotation_table, self.root / "annotations.parquet") + + task_lookup: dict[int, str] = {} + domains: list[str] = [] + if manifest is not None: + task_lookup = {spec.id: spec.task for spec in manifest.annotation_spec} + domains = list(manifest.domains) + + task_rows: list[dict[str, Any]] = [] + domain_rows: list[dict[str, Any]] = [] + for annotation in rows: + task_type = task_lookup.get(annotation.spec_id) + for sample_ref in annotation.samples: + if task_type is not None: + task_rows.append( + AnnotationTaskIndexRow( + sample_id=sample_ref.sample, + annotation_id=annotation.id, + task_type=task_type, + ).to_dict() + ) + for domain in domains: + domain_rows.append( + AnnotationDomainIndexRow( + sample_id=sample_ref.sample, + annotation_id=annotation.id, + domain=domain, + ).to_dict() + ) + + if task_rows: + task_table = _table_from_rows(task_rows) + else: + task_table = pa.table( + { + 
"sample_id": pa.array([], type=pa.int64()), + "annotation_id": pa.array([], type=pa.int64()), + "task_type": pa.array([], type=pa.string()), + } + ) + if domain_rows: + domain_table = _table_from_rows(domain_rows) + else: + domain_table = pa.table( + { + "sample_id": pa.array([], type=pa.int64()), + "annotation_id": pa.array([], type=pa.int64()), + "domain": pa.array([], type=pa.string()), + } + ) + + pq.write_table(task_table, self.root / "annotation_tasks_index.parquet") + pq.write_table(domain_table, self.root / "annotation_domains_index.parquet") + + def write_signal_frame(self, signal_file: str, frame: dict[str, list[Any]]) -> None: + pq.write_table(pa.table(frame), self.signals_dir / signal_file) + + +class TimeFReader: + def __init__(self, root: Path): + self.root = root + + def read_manifest(self) -> DatasetManifest: + payload = json.loads((self.root / "manifest.json").read_text(encoding="utf-8")) + return DatasetManifest.from_dict(payload) + + def read_samples(self) -> Iterator[Sample]: + table = pq.read_table(self.root / "samples.parquet") + for row in table.to_pylist(): + yield Sample.from_parquet_dict(row) + + def read_sample(self, sample_id: int) -> Sample | None: + for sample in self.read_samples(): + if sample.id == sample_id: + return sample + return None + + def read_annotations(self) -> Iterator[Annotation]: + table = pq.read_table(self.root / "annotations.parquet") + for row in table.to_pylist(): + yield Annotation.from_parquet_dict(row) + + def read_signal_frame(self, signal_file: str) -> list[dict[str, Any]]: + return pq.read_table(self.root / "signals" / signal_file).to_pylist() + + def read_signal_frame_for_sample(self, sample: Sample, manifest: DatasetManifest) -> dict[int, list[dict[str, Any]]]: + signals_by_id = {signal.id: signal for signal in manifest.signals} + result: dict[int, list[dict[str, Any]]] = {} + for signal_ref in sample.signals: + signal = signals_by_id.get(signal_ref.signal_id) + if signal is None: + continue + rows = 
self.read_signal_frame(signal.shard_file) + sliced = rows[signal.row_start : signal.row_start + signal.row_count] + if signal_ref.channels is not None: + filtered: list[dict[str, Any]] = [] + for row in sliced: + filtered.append( + { + key: value + for key, value in row.items() + if _is_time_column(key) or key in signal_ref.channels + } + ) + sliced = filtered + result[signal_ref.signal_id] = sliced + return result diff --git a/timenet_timef/labels.py b/timenet_timef/labels.py new file mode 100644 index 0000000..b27aa0f --- /dev/null +++ b/timenet_timef/labels.py @@ -0,0 +1,28 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +LABEL_SCHEMAS: dict[str, dict[str, list[str] | str]] = { + "binary": { + "description": "Binary yes/no classification", + "labels": ["yes", "no"], + }, +} + + +def get_schema(name: str) -> dict[str, list[str] | str]: + if name not in LABEL_SCHEMAS: + raise ValueError(f"Unknown label schema: {name}") + return LABEL_SCHEMAS[name] + + +def get_valid_labels(name: str) -> list[str]: + schema = get_schema(name) + labels = schema["labels"] + if not isinstance(labels, list): + raise ValueError(f"Invalid labels payload for schema {name}") + return labels diff --git a/timenet_timef/schema.py b/timenet_timef/schema.py new file mode 100644 index 0000000..24b8e6d --- /dev/null +++ b/timenet_timef/schema.py @@ -0,0 +1,295 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from typing import Any + +from timenet_timef.labels import get_valid_labels +from timenet_timef.units import validate_ucum + + +VALID_DOMAINS: frozenset[str] = frozenset( + { + "health", + "activity", + "general", + "forecasting", + "economics", + "sleep", + "cardiology", + "finance", + "energy", + "environment", + "industrial", + "transport", + } +) + + +@dataclass(slots=True) +class SensorSpec: + id: int + name: str + + +@dataclass(slots=True) +class SignalSpec: + id: int + name: str + channels: list[str] = field(default_factory=list) + unit_sampling_rate: str = "Hz" + unit_timestamp: str = "s" + unit_value: str | None = None + sensor_id: int | None = None + + def __post_init__(self) -> None: + validate_ucum(self.unit_sampling_rate) + validate_ucum(self.unit_timestamp) + if self.unit_value is not None: + validate_ucum(self.unit_value) + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "SignalSpec": + return cls(**payload) + + +@dataclass(slots=True) +class Signal: + id: int + spec_id: int + name: str + sampling_rate: float | None + total_duration: float + shard_file: str + row_group_id: int + row_start: int = 0 + row_count: int = 0 + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "Signal": + return cls(**payload) + + +@dataclass(slots=True) +class SampleSignalRef: + signal_id: int + sampling_rate: float | None = None + channels: list[str] | None = None + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "SampleSignalRef": + return cls(**payload) + + +@dataclass(slots=True) +class Sample: + id: int + dataset_id: str + unit_timestamp: str + windows: list[list[float]] | None = None + signals: list[SampleSignalRef] = field(default_factory=list) + + def __post_init__(self) -> None: + validate_ucum(self.unit_timestamp) + if not self.signals: + raise ValueError("Sample must reference at 
least one signal") + if self.windows is not None: + for window in self.windows: + if len(window) != 2 or window[0] >= window[1]: + raise ValueError("Sample windows must be [start, end] with start < end") + + def to_parquet_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "dataset_id": self.dataset_id, + "unit_timestamp": self.unit_timestamp, + "windows": json.dumps(self.windows), + "signals": json.dumps([asdict(sig) for sig in self.signals]), + } + + @classmethod + def from_parquet_dict(cls, payload: dict[str, Any]) -> "Sample": + windows = payload.get("windows") + signals = payload.get("signals") + if isinstance(windows, str): + windows = json.loads(windows) + if isinstance(signals, str): + signals = json.loads(signals) + return cls( + id=int(payload["id"]), + dataset_id=str(payload["dataset_id"]), + unit_timestamp=str(payload["unit_timestamp"]), + windows=windows, + signals=[SampleSignalRef.from_dict(sig) for sig in signals], + ) + + +@dataclass(slots=True) +class AnnotationSpec: + id: int + task: str + schema_: str | None = None + domains: list[str] = field(default_factory=list) + + def __post_init__(self) -> None: + invalid = set(self.domains) - VALID_DOMAINS + if invalid: + raise ValueError(f"Invalid annotation domains: {sorted(invalid)}") + if self.task in {"labeling", "classification"} and self.schema_ is None: + raise ValueError("Classification annotation specs require a schema") + if self.schema_ is not None: + get_valid_labels(self.schema_) + + def to_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "task": self.task, + "schema": self.schema_, + "domains": self.domains, + } + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "AnnotationSpec": + data = dict(payload) + if "schema" in data and "schema_" not in data: + data["schema_"] = data.pop("schema") + return cls(**data) + + def get_valid_labels(self) -> list[str] | None: + if self.schema_ is None: + return None + return get_valid_labels(self.schema_) + + 
+@dataclass(slots=True) +class AnnotationSampleRef: + sample: int + reference: str | None = None + forecast: bool = False + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "AnnotationSampleRef": + return cls(**payload) + + +@dataclass(slots=True) +class Annotation: + id: int + dataset_id: str + spec_id: int + samples: list[AnnotationSampleRef] + question: str | None = None + answer: str | None = None + label: str | None = None + rationale: str | None = None + + def __post_init__(self) -> None: + if not self.samples: + raise ValueError("Annotation must reference at least one sample") + has_answer = self.answer is not None + has_label = self.label is not None + if has_answer == has_label: + raise ValueError("Annotation must have exactly one of answer or label") + + def to_parquet_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "dataset_id": self.dataset_id, + "spec_id": self.spec_id, + "samples": json.dumps([asdict(sample) for sample in self.samples]), + "question": self.question, + "answer": self.answer, + "label": self.label, + "rationale": self.rationale, + } + + @classmethod + def from_parquet_dict(cls, payload: dict[str, Any]) -> "Annotation": + samples = payload.get("samples") + if isinstance(samples, str): + samples = json.loads(samples) + return cls( + id=int(payload["id"]), + dataset_id=str(payload["dataset_id"]), + spec_id=int(payload["spec_id"]), + samples=[AnnotationSampleRef.from_dict(sample) for sample in samples], + question=payload.get("question"), + answer=payload.get("answer"), + label=payload.get("label"), + rationale=payload.get("rationale"), + ) + + +@dataclass(slots=True) +class DatasetManifest: + dataset_id: str + version: str + source: str + description: str = "" + domains: list[str] = field(default_factory=list) + tags: list[str] = field(default_factory=list) + signals: list[Signal] = field(default_factory=list) + signal_spec: list[SignalSpec] = field(default_factory=list) + sensor_spec: list[SensorSpec] = 
field(default_factory=list) + annotation_spec: list[AnnotationSpec] = field(default_factory=list) + sample_count: int = 0 + annotation_count: int = 0 + validated: bool = False + + def __post_init__(self) -> None: + invalid = set(self.domains) - VALID_DOMAINS + if invalid: + raise ValueError(f"Invalid manifest domains: {sorted(invalid)}") + + def to_dict(self) -> dict[str, Any]: + return { + "dataset_id": self.dataset_id, + "version": self.version, + "source": self.source, + "description": self.description, + "domains": self.domains, + "tags": self.tags, + "signals": [asdict(signal) for signal in self.signals], + "signal_spec": [asdict(spec) for spec in self.signal_spec], + "sensor_spec": [asdict(spec) for spec in self.sensor_spec], + "annotation_spec": [spec.to_dict() for spec in self.annotation_spec], + "sample_count": self.sample_count, + "annotation_count": self.annotation_count, + "validated": self.validated, + } + + @classmethod + def from_dict(cls, payload: dict[str, Any]) -> "DatasetManifest": + return cls( + dataset_id=payload["dataset_id"], + version=payload["version"], + source=payload["source"], + description=payload.get("description", ""), + domains=list(payload.get("domains", [])), + tags=list(payload.get("tags", [])), + signals=[Signal.from_dict(item) for item in payload.get("signals", [])], + signal_spec=[SignalSpec.from_dict(item) for item in payload.get("signal_spec", [])], + sensor_spec=[SensorSpec(**item) for item in payload.get("sensor_spec", [])], + annotation_spec=[AnnotationSpec.from_dict(item) for item in payload.get("annotation_spec", [])], + sample_count=int(payload.get("sample_count", 0)), + annotation_count=int(payload.get("annotation_count", 0)), + validated=bool(payload.get("validated", False)), + ) + + +def validate_annotation_against_spec(annotation: Annotation, spec: AnnotationSpec) -> None: + if spec.task not in {"labeling", "classification"}: + return + if annotation.label is None: + raise ValueError(f"Classification annotation 
{annotation.id} requires a label") + valid_labels = spec.get_valid_labels() + if valid_labels is not None and annotation.label not in valid_labels: + raise ValueError( + f"Invalid label '{annotation.label}' for annotation spec {spec.id}. Valid labels: {valid_labels}" + ) diff --git a/timenet_timef/units.py b/timenet_timef/units.py new file mode 100644 index 0000000..03f88e2 --- /dev/null +++ b/timenet_timef/units.py @@ -0,0 +1,13 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + + +def validate_ucum(unit_code: str) -> str: + if not isinstance(unit_code, str) or not unit_code.strip(): + raise ValueError("Unit code must be a non-empty string") + return unit_code diff --git a/timenet_timef/utils.py b/timenet_timef/utils.py new file mode 100644 index 0000000..8b407a8 --- /dev/null +++ b/timenet_timef/utils.py @@ -0,0 +1,14 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. +# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +from pathlib import Path + + +def ensure_dir(path: Path) -> Path: + path.mkdir(parents=True, exist_ok=True) + return path diff --git a/timenet_timef/validate.py b/timenet_timef/validate.py new file mode 100644 index 0000000..85db686 --- /dev/null +++ b/timenet_timef/validate.py @@ -0,0 +1,106 @@ +# +# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) +# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. 
+# +# SPDX-License-Identifier: MIT +# +from __future__ import annotations + +import json +from pathlib import Path + +import pyarrow.parquet as pq + +from timenet_timef.schema import VALID_DOMAINS, validate_annotation_against_spec +from timenet_timef.io import TimeFReader + + +class TimeFValidationError(ValueError): + pass + + +def validate_dataset(root: Path) -> None: + reader = TimeFReader(root) + manifest = reader.read_manifest() + samples = list(reader.read_samples()) + annotations = list(reader.read_annotations()) + + if manifest.sample_count != len(samples): + raise TimeFValidationError("manifest.sample_count does not match samples.parquet") + if manifest.annotation_count != len(annotations): + raise TimeFValidationError("manifest.annotation_count does not match annotations.parquet") + + invalid_domains = set(manifest.domains) - VALID_DOMAINS + if invalid_domains: + raise TimeFValidationError(f"invalid manifest domains: {sorted(invalid_domains)}") + + signal_spec_by_id = {spec.id: spec for spec in manifest.signal_spec} + signal_by_id = {signal.id: signal for signal in manifest.signals} + annotation_spec_by_id = {spec.id: spec for spec in manifest.annotation_spec} + sample_ids = {sample.id for sample in samples} + + for required in ( + "manifest.json", + "samples.parquet", + "annotations.parquet", + "sample_signal_index.parquet", + "annotation_tasks_index.parquet", + "annotation_domains_index.parquet", + ): + if not (root / required).exists(): + raise TimeFValidationError(f"missing required file {required}") + + for signal in manifest.signals: + if signal.spec_id not in signal_spec_by_id: + raise TimeFValidationError(f"signal {signal.id} references unknown spec_id {signal.spec_id}") + shard_path = root / "signals" / signal.shard_file + if not shard_path.exists(): + raise TimeFValidationError(f"missing shard file {signal.shard_file}") + metadata = pq.ParquetFile(shard_path).metadata + if signal.row_group_id >= metadata.num_row_groups: + raise 
TimeFValidationError(f"row_group_id out of bounds for signal {signal.id}") + row_group_rows = metadata.row_group(signal.row_group_id).num_rows + if signal.row_start + signal.row_count > row_group_rows: + raise TimeFValidationError(f"row span out of bounds for signal {signal.id}") + + for sample in samples: + for signal_ref in sample.signals: + if signal_ref.signal_id not in signal_by_id: + raise TimeFValidationError( + f"sample {sample.id} references unknown signal_id {signal_ref.signal_id}" + ) + if signal_ref.channels is not None: + spec = signal_spec_by_id[signal_by_id[signal_ref.signal_id].spec_id] + invalid_channels = set(signal_ref.channels) - set(spec.channels) + if invalid_channels: + raise TimeFValidationError( + f"sample {sample.id} selects invalid channels {sorted(invalid_channels)}" + ) + + for annotation in annotations: + spec = annotation_spec_by_id.get(annotation.spec_id) + if spec is None: + raise TimeFValidationError(f"annotation {annotation.id} references unknown spec_id {annotation.spec_id}") + for ref in annotation.samples: + if ref.sample not in sample_ids: + raise TimeFValidationError( + f"annotation {annotation.id} references unknown sample {ref.sample}" + ) + if ref.reference is not None: + try: + json.loads(ref.reference) + except json.JSONDecodeError as exc: + raise TimeFValidationError( + f"annotation {annotation.id} has invalid reference JSON" + ) from exc + try: + validate_annotation_against_spec(annotation, spec) + except ValueError as exc: + raise TimeFValidationError(str(exc)) from exc + + +def mark_validated(root: Path, validated: bool = True) -> None: + manifest_path = root / "manifest.json" + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + payload["validated"] = validated + manifest_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") diff --git a/transformer.py b/transformer.py index b726394..601d743 100644 --- a/transformer.py +++ b/transformer.py @@ -8,12 +8,9 @@ import abc -from timef.schema 
import SignalSpec, Signal +from runtime import RuntimeRow class Transformer(abc.ABC): @abc.abstractmethod - def get_signal_specs(self) -> list[SignalSpec]: ... - - @abc.abstractmethod - def transform_row(self, row: dict) -> list[Signal]: ... + def transform_row(self, row: dict) -> RuntimeRow: ... diff --git a/visualizer.py b/visualizer.py index 72b45fe..6f56e2a 100644 --- a/visualizer.py +++ b/visualizer.py @@ -12,52 +12,31 @@ import matplotlib.pyplot as plt import numpy as np -from timef.schema import Annotation, Sample, Signal +from runtime import RuntimeRow def plot_row( - signals: list[Signal], - samples: list[Sample], - annotations: list[Annotation], + row: RuntimeRow, save_path: str | None = None, n_cols: int = 3, ) -> plt.Figure | None: - channel_names = [s.name for s in signals] + channel_names = list(row.channel_names) n_channels = len(channel_names) - - data = np.stack([s.data for s in signals]) + data = row.values time_minutes = np.arange(data.shape[1]) - # Build signal_id -> channel_name - sig_id_to_name = {s.id: s.name for s in signals} - - # Build sample_id -> list of channel names - sample_channels: dict[str, list[str]] = {} - for sample in samples: - names = [ - sig_id_to_name[ref.signal_id] - for ref in sample.signals - if ref.signal_id in sig_id_to_name - ] - if names: - sample_channels[sample.id] = names - - # Split annotations: semantic captions are global, others are per-channel - semantic_captions: list[str] = [] - channel_captions: dict[str, dict[str, list[str]]] = {ch: {} for ch in channel_names} - for ann in annotations: - if not ann.answer: + semantic_captions = [ + annotation.text for annotation in row.annotations if annotation.kind == "semantic" and annotation.text + ] + channel_captions: dict[str, dict[str, list[str]]] = {name: {} for name in channel_names} + + for caption in row.annotations: + if not caption.text or caption.kind == "semantic": continue - ann_type = ann.spec_id.split(":")[-1] if ":" in ann.spec_id else ann.spec_id - if 
ann_type == "semantic": - semantic_captions.append(ann.answer) - else: - for sample_ref in ann.samples: - for ch in sample_channels.get(sample_ref.sample_id, []): - if ch in channel_captions: - channel_captions[ch].setdefault(ann_type, []).append(ann.answer) - - # Derive per-channel nonwear regions from NaN + for idx in caption.channel_idxs: + name = channel_names[idx] + channel_captions[name].setdefault(caption.kind, []).append(caption.text) + def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]]: regions = [] in_region = False @@ -111,9 +90,7 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] elif i in bottom_row_indices: ax.set_xlabel("Time (minutes)") - meta = signals[0].metadata if signals else {} - title = f"{meta.get('user_id', '')} | {meta.get('date', '')}" - fig.suptitle(title, fontsize=10, y=1.0) + fig.suptitle(f"{row.user_id} | {row.date}", fontsize=10, y=1.0) if semantic_captions: sem_text = textwrap.fill(" ".join(semantic_captions), width=140) @@ -145,7 +122,7 @@ def _nan_regions(arr: np.ndarray, min_length: int = 30) -> list[tuple[int, int]] dataset = MHCDataset(min_wear_pct=90.0) annotator = Annotator([StatisticalExtractor(MHC_CHANNEL_CONFIG), StructuralExtractor(MHC_CHANNEL_CONFIG)]) captionizer = Captionizer(dataset, MHCTransformer(), annotator) - result = captionizer.run(max_rows=1) - for signals, samples, annotations in result.iter_rows(): - plot_row(signals, samples, annotations, save_path="sample_plot.png") + result, _ = captionizer.run(max_rows=1) + for row in result.iter_rows(): + plot_row(row, save_path="sample_plot.png") print("Saved sample_plot.png")