diff --git a/curriculum_learning.py b/curriculum_learning.py index 4989d0f..e225c8e 100644 --- a/curriculum_learning.py +++ b/curriculum_learning.py @@ -683,13 +683,22 @@ def main(): from time_series_datasets.mhc_caption_qa_dataset import MHCCaptionQADataset + split_file = os.environ.get("MHC_SPLIT_FILE") + if not split_file: + raise RuntimeError( + "MHC_SPLIT_FILE env var is required: point at MHC-benchmark's " + "sharable_users_seed42_2026.json so split assignment is consistent." + ) + caption_kwargs = { "export_dir": os.environ.get("CAPTION_EXPORT_DIR", "exports/lean_smoke"), + "split_file": split_file, } metabolic_kwargs = { "labels": os.environ.get("METABOLIC_LABELS", ",".join(DEFAULT_METABOLIC_LABELS)).split(","), "export_dir": os.environ.get("CAPTION_EXPORT_DIR", "exports/lean_smoke"), "label_lookup_path": os.environ.get("METABOLIC_LABEL_LOOKUP"), + "split_file": split_file, } curriculum_stages = OrderedDict( diff --git a/evaluation/evaluator.py b/evaluation/evaluator.py index ca332c7..b7191bd 100644 --- a/evaluation/evaluator.py +++ b/evaluation/evaluator.py @@ -18,10 +18,14 @@ from dataclasses import dataclass from tqdm import tqdm +from transformers import logging as hf_logging from evaluation.tasks.base import EvalTask from models.base import BaseModel, ModelResponse +# Silence the per-generate "Setting pad_token_id to eos_token_id" warning. +hf_logging.set_verbosity_error() + @dataclass(slots=True) class SampleResult: @@ -54,39 +58,60 @@ def run( dataset, model_name: str, max_samples: int | None = None, + save_every: int | None = None, + results_writer=None, + out_dir=None, ) -> EvalResult: """Evaluate *model* on *dataset* using *task*. The dataset must be an OpenTSLM-style ``QADataset`` and expose ``get_eval_context(idx)`` returning ``(Recording, ground_truth)``. ``dataset[idx]`` provides the model-facing sample dict. + + With ``save_every`` + ``results_writer`` + ``out_dir`` set, partial + results are flushed every N samples and on KeyboardInterrupt. 
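+
+        Sketch of the periodic-save wiring, mirroring ``run_eval.py``
+        (``task``, ``model`` and ``dataset`` assumed built as there; the
+        directory names are illustrative)::
+
+            writer = ResultsWriter()
+            out_dir = writer.results_dir / "metabolic" / "opentslm" / "run_01"
+            result = Evaluator().run(
+                task=task, model=model, dataset=dataset, model_name="opentslm",
+                save_every=100, results_writer=writer, out_dir=out_dir,
+            )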
""" n = len(dataset) if max_samples is None else min(max_samples, len(dataset)) per_sample: list[SampleResult] = [] - for i in tqdm(range(n), desc=f"{model_name} / {task.name}"): - recording, ground_truth = dataset.get_eval_context(i) - sample = dataset[i] - - response: ModelResponse = model.generate(sample=sample, recording=recording) - parsed = task.parse_prediction(response.text) - - per_sample.append(SampleResult( - sample_id=recording.row_id, - ground_truth=ground_truth, - raw_prediction=response.text, - parsed_prediction=parsed, - input_tokens=response.input_tokens, - output_tokens=response.output_tokens, - metadata={"user_id": recording.user_id, "date": recording.date}, - )) - - gts = [s.ground_truth for s in per_sample] - preds = [s.parsed_prediction for s in per_sample] - - return EvalResult( - model_name=model_name, - task_name=task.name, - per_sample=per_sample, - metrics=task.aggregate_metrics(gts, preds), - ) + def _build_partial() -> "EvalResult": + gts = [s.ground_truth for s in per_sample] + preds = [s.parsed_prediction for s in per_sample] + metrics = task.aggregate_metrics(gts, preds) if per_sample else {} + return EvalResult( + model_name=model_name, + task_name=task.name, + per_sample=per_sample, + metrics=metrics, + ) + + def _flush() -> None: + if results_writer is not None and out_dir is not None and per_sample: + results_writer.write(_build_partial(), out_dir=out_dir) + + try: + for i in tqdm(range(n), desc=f"{model_name} / {task.name}"): + recording, ground_truth = dataset.get_eval_context(i) + sample = dataset[i] + + response: ModelResponse = model.generate(sample=sample, recording=recording) + parsed = task.parse_prediction(response.text) + + per_sample.append(SampleResult( + sample_id=recording.row_id, + ground_truth=ground_truth, + raw_prediction=response.text, + parsed_prediction=parsed, + input_tokens=response.input_tokens, + output_tokens=response.output_tokens, + metadata={"user_id": recording.user_id, "date": recording.date}, + )) + + if save_every and (i + 1) % save_every == 0: + _flush() + except KeyboardInterrupt: + print(f"\n[interrupted at sample {len(per_sample)}/{n}] saving partial results") + _flush() + raise + + return _build_partial() diff --git a/evaluation/io/writer.py b/evaluation/io/writer.py index d7aa186..e13b886 100644 --- a/evaluation/io/writer.py +++ b/evaluation/io/writer.py @@ -36,18 +36,25 @@ class ResultsWriter: def __init__(self, results_dir: Path | str = Path("results")) -> None: self.results_dir = Path(results_dir) - def write(self, result: EvalResult) -> Path: + def write(self, result: EvalResult, out_dir: Path | None = None) -> Path: """Persist *result* to disk and return the output directory path. + If ``out_dir`` is given, write into it (overwriting existing files); + otherwise create a fresh ```` subdir. + Creates: metrics.json — aggregated metrics dict (directly JSON-serialisable). predictions.csv — one row per evaluated sample. 
""" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") - safe_task = re.sub(r"[^\w\-.]", "_", result.task_name) - safe_model = re.sub(r"[^\w\-.]", "_", result.model_name) - out_dir = self.results_dir / safe_task / safe_model / timestamp - out_dir.mkdir(parents=True, exist_ok=False) + if out_dir is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + safe_task = re.sub(r"[^\w\-.]", "_", result.task_name) + safe_model = re.sub(r"[^\w\-.]", "_", result.model_name) + out_dir = self.results_dir / safe_task / safe_model / timestamp + out_dir.mkdir(parents=True, exist_ok=False) + else: + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) # metrics.json with open(out_dir / "metrics.json", "w") as f: diff --git a/evaluation/run_eval.py b/evaluation/run_eval.py index e4dd8ec..db24c83 100644 --- a/evaluation/run_eval.py +++ b/evaluation/run_eval.py @@ -76,6 +76,13 @@ def _parse_args() -> argparse.Namespace: "OpenTSLM.load_pretrained." ), ) + parser.add_argument( + "--save-every", type=int, default=None, + help=( + "Persist partial results every N samples (and on KeyboardInterrupt) " + "to the same output dir. Useful for long evals on slow filesystems." + ), + ) return parser.parse_args() @@ -100,6 +107,12 @@ def _build_model(args: argparse.Namespace): def _build_dataset(args: argparse.Namespace, task: MetabolicTask): + split_file = os.environ.get("MHC_SPLIT_FILE") + if not split_file: + raise RuntimeError( + "MHC_SPLIT_FILE env var is required: point at MHC-benchmark's " + "sharable_users_seed42_2026.json so split assignment is consistent." + ) return task.dataset_cls( split="test", EOS_TOKEN="", @@ -108,6 +121,7 @@ def _build_dataset(args: argparse.Namespace, task: MetabolicTask): label_lookup_path=os.environ.get( "MHC_LABELS_PATH", "data/labels_dev_subset.parquet" ), + split_file=split_file, ) @@ -131,6 +145,15 @@ def main() -> None: print("No samples found — check MHC_EXPORT_DIR.") sys.exit(1) + # Pre-create the output dir so periodic saves rewrite the same path. 
+    from datetime import datetime
+    import re
+    writer = ResultsWriter()
+    safe_task = re.sub(r"[^\w\-.]", "_", task.name)
+    safe_model = re.sub(r"[^\w\-.]", "_", model_name)
+    out_dir = writer.results_dir / safe_task / safe_model / datetime.now().strftime("%Y%m%d_%H%M%S")
+    out_dir.mkdir(parents=True, exist_ok=True)
+
     evaluator = Evaluator()
     result = evaluator.run(
         task=task,
@@ -138,9 +161,12 @@
         dataset=dataset,
         model_name=model_name,
         max_samples=max_samples,
+        save_every=args.save_every,
+        results_writer=writer,
+        out_dir=out_dir,
     )
 
-    ResultsWriter().write(result)
+    writer.write(result, out_dir=out_dir)
 
     print("\n=== RESULTS ===")
     for key, val in result.metrics.items():
diff --git a/mhc/dataset.py b/mhc/dataset.py
index f1aed03..8f7ba83 100644
--- a/mhc/dataset.py
+++ b/mhc/dataset.py
@@ -4,8 +4,11 @@
 #
 # SPDX-License-Identifier: MIT
 #
+import json
+
 from datasets import load_from_disk
 import numpy as np
+import pyarrow as pa
 import pyarrow.compute as pc
 from torch.utils.data import Dataset
 
@@ -13,8 +16,30 @@ class MHCDataset(Dataset):
-    def __init__(self, min_wear_pct: float = 0.0, dataset_dir: str = DATASET_DIR):
+    def __init__(
+        self,
+        min_wear_pct: float = 0.0,
+        dataset_dir: str = DATASET_DIR,
+        split_file: str | None = None,
+    ):
+        """Load MHC daily dataset, optionally filtered by ``split_file``
+        (MHC-benchmark's ``sharable_users_seed42_2026.json`` — rows with
+        ``user_id`` outside train ∪ validation ∪ test are dropped)."""
         self.ds = load_from_disk(dataset_dir)
+
+        if split_file is not None:
+            with open(split_file) as f:
+                splits = json.load(f)
+            allowed_users = (
+                set(splits["train"]) | set(splits["validation"]) | set(splits["test"])
+            )
+            mask = pc.is_in(
+                self.ds.data.column("user_id"),
+                value_set=pa.array(sorted(allowed_users)),
+            )
+            indices = [i for i, m in enumerate(mask.to_pylist()) if m]
+            # .select() only records an indices mapping, leaving .data pointing
+            # at the unfiltered table; flatten so the wear-time mask below
+            # (also computed from .data) stays aligned with the selected rows.
+            self.ds = self.ds.select(indices).flatten_indices()
+
         if min_wear_pct > 0:
             max_nonwear = 1440 * (1 - min_wear_pct / 100)
             mask = pc.less_equal(self.ds.data.column("total_nonwear_minutes"), max_nonwear)
diff --git a/mhc_weekly/dataset.py b/mhc_weekly/dataset.py
index 5ddc92c..a880ca4 100644
--- a/mhc_weekly/dataset.py
+++ b/mhc_weekly/dataset.py
@@ -4,6 +4,8 @@
 #
 # SPDX-License-Identifier: MIT
 #
+import json
+
 from datasets import load_from_disk
 import numpy as np
 from torch.utils.data import Dataset
 
@@ -12,8 +14,25 @@ class WeeklyMHCDataset(Dataset):
-    def __init__(self, min_valid_hours: int = 0, dataset_dir: str = WEEKLY_DATASET_DIR):
+    def __init__(
+        self,
+        min_valid_hours: int = 0,
+        dataset_dir: str = WEEKLY_DATASET_DIR,
+        split_file: str | None = None,
+    ):
+        """Load weekly MHC dataset, optionally filtered by ``split_file``
+        (MHC-benchmark's ``sharable_users_seed42_2026.json`` — rows with
+        ``user_id`` outside train ∪ validation ∪ test are dropped)."""
         self.ds = load_from_disk(dataset_dir)
+
+        if split_file is not None:
+            with open(split_file) as f:
+                splits = json.load(f)
+            allowed_users = (
+                set(splits["train"]) | set(splits["validation"]) | set(splits["test"])
+            )
+            self.ds = self.ds.filter(lambda row: row["user_id"] in allowed_users)
+
         if min_valid_hours > 0:
             self.ds = self.ds.filter(lambda row: row["n_valid_hours"] >= min_valid_hours)
diff --git a/scripts/export_captions.py b/scripts/export_captions.py
index 35b9df9..0ea5a9a 100644
--- a/scripts/export_captions.py
+++ b/scripts/export_captions.py
@@ -29,19 +29,23 @@ from timef.schema import CaptionResult
 
-def _build_daily(min_wear_pct: float):
+def _build_daily(min_wear_pct: float, split_file: str | None):
     from 
mhc.constants import MHC_CHANNEL_CONFIG from mhc.dataset import MHCDataset from mhc.transformer import MHCTransformer - return MHCDataset(min_wear_pct=min_wear_pct), MHCTransformer(), MHC_CHANNEL_CONFIG + return ( + MHCDataset(min_wear_pct=min_wear_pct, split_file=split_file), + MHCTransformer(), + MHC_CHANNEL_CONFIG, + ) -def _build_weekly(min_valid_hours: int): +def _build_weekly(min_valid_hours: int, split_file: str | None): from mhc_weekly.constants import WEEKLY_CHANNEL_CONFIG from mhc_weekly.dataset import WeeklyMHCDataset from mhc_weekly.transformer import WeeklyMHCTransformer return ( - WeeklyMHCDataset(min_valid_hours=min_valid_hours), + WeeklyMHCDataset(min_valid_hours=min_valid_hours, split_file=split_file), WeeklyMHCTransformer(), WEEKLY_CHANNEL_CONFIG, ) @@ -65,12 +69,19 @@ def main(): help="Start index (for sharded runs).") parser.add_argument("--end", type=int, default=None, help="End index exclusive (for sharded runs).") + parser.add_argument("--split_file", default=None, + help="MHC-benchmark canonical split JSON. When set, " + "non-sharable users are dropped at load time.") args = parser.parse_args() if args.variant == "daily": - dataset, transformer, channel_config = _build_daily(args.min_wear_pct) + dataset, transformer, channel_config = _build_daily( + args.min_wear_pct, args.split_file + ) else: - dataset, transformer, channel_config = _build_weekly(args.min_valid_hours) + dataset, transformer, channel_config = _build_weekly( + args.min_valid_hours, args.split_file + ) annotator = Annotator([ StructuralExtractor(channel_config), diff --git a/scripts/export_captions_sharded.sh b/scripts/export_captions_sharded.sh index 6a7250a..6286f4d 100755 --- a/scripts/export_captions_sharded.sh +++ b/scripts/export_captions_sharded.sh @@ -21,6 +21,8 @@ # REPO_DIR default $HOME/SensorTSLM # VENV_DIR default $SCRATCH/sensor-venv (used via sherlock_setup.sh) # TOTAL_ROWS auto-detected for daily; defaults to 615490 for weekly +# SPLIT_FILE path to MHC-benchmark sharable_users_seed42_2026.json; when +# set, non-sharable users are dropped at load time set -euo pipefail @@ -57,6 +59,12 @@ fi mkdir -p "$REPO_DIR/logs" STRIDE=$(( (TOTAL_ROWS + NUM_SHARDS - 1) / NUM_SHARDS )) +SPLIT_ARG="" +if [ -n "${SPLIT_FILE:-}" ]; then + SPLIT_ARG="--split_file $SPLIT_FILE" + echo " split_file: $SPLIT_FILE" +fi + echo "Submitting $NUM_SHARDS shards of ~$STRIDE rows each ($VARIANT)" echo " out: $OUT_DIR/shard_" echo " source: $DATA_DIR" @@ -71,7 +79,7 @@ for i in $(seq 0 $((NUM_SHARDS - 1))); do source \$SCRATCH/sensor-venv/bin/activate; \ export ${ENV_VAR}=$DATA_DIR; \ cd $REPO_DIR; \ -python scripts/export_captions.py --variant $VARIANT --out $OUT_DIR/shard_$i --start $START --end $END" +python scripts/export_captions.py --variant $VARIANT --out $OUT_DIR/shard_$i --start $START --end $END $SPLIT_ARG" done echo "Check: squeue -u \$USER" diff --git a/scripts/linear_probe.py b/scripts/linear_probe.py deleted file mode 100644 index 79fd8e0..0000000 --- a/scripts/linear_probe.py +++ /dev/null @@ -1,432 +0,0 @@ -# -# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) -# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. -# -# SPDX-License-Identifier: MIT -# -"""Linear-probe a trained OpenTSLMSP checkpoint against MHC downstream labels. - -Two phases (first run does both; re-runs with --skip_extract skip phase 1): - 1. Extract a fixed-length embedding per sample (no gradients). 
- --embedding_source projector : mean-pool encoder+projector output (fast, skips LLM, ignores LoRA) - --embedding_source llm_last : LLM forward, mean-pool last hidden state (captures LoRA + prompt context) - 2. Fit sklearn probes per label on the user-hash split, report metrics. - -Example: - python scripts/linear_probe.py \ - --checkpoint results/.../latest.pt \ - --label_lookup /scratch/.../weekly_labels_lookup.parquet \ - --export_dir exports/lean_smoke \ - --output_dir probe_results/stage1 \ - --embedding_source llm_last -""" -from __future__ import annotations - -import argparse -import hashlib -import json -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) - -import numpy as np -import pandas as pd -import pyarrow.feather as feather -import torch -from torch.nn.utils.rnn import pad_sequence - -from opentslm.model.llm.OpenTSLMSP import OpenTSLMSP - -from mhc.constants import CONTINUOUS_CHANNELS -from time_series_datasets.mhc_label_lookup import LABEL_TYPES, LabelLookup - - -TRAIN_FRAC = 0.8 -VAL_FRAC = 0.1 -PRE_PROMPT = "You are an expert in wearable sensor time series analysis." -POST_PROMPT = "" - - -def _assign_split(user_id: str) -> str: - digest = int(hashlib.sha1(user_id.encode()).hexdigest(), 16) - bucket = (digest % 1000) / 1000.0 - if bucket < TRAIN_FRAC: - return "train" - if bucket < TRAIN_FRAC + VAL_FRAC: - return "validation" - return "test" - - -def _normalize_channel(channel: np.ndarray, is_continuous: bool) -> tuple[np.ndarray, float, float]: - if not is_continuous: - return np.nan_to_num(channel, nan=0.0).astype(np.float32), 0.0, 1.0 - finite = channel[np.isfinite(channel)] - if finite.size == 0: - return np.zeros_like(channel, dtype=np.float32), 0.0, 1.0 - mean = float(finite.mean()) - std = max(float(finite.std()), 1e-6) - out = (channel - mean) / std - return np.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32), mean, std - - -def _build_sample(row: dict) -> dict: - """Match MHCBaseQADataset's prompt + time-series formatting.""" - values = np.asarray(row["values"], dtype=np.float32) - channel_names = row["channel_names"] - display_names = row["display_names"] - units = row["units"] - - texts: list[str] = [] - tensors: list[torch.Tensor] = [] - for i, name in enumerate(channel_names): - is_cont = name in CONTINUOUS_CHANNELS - normed, mean, std = _normalize_channel(values[i], is_cont) - tensors.append(torch.from_numpy(normed)) - display = display_names[i] - unit = units[i] - unit_str = f" in {unit}" if unit else "" - if is_cont: - texts.append( - f"The following is the {display}{unit_str}, " - f"it has mean {mean:.4f} and std {std:.4f}:" - ) - else: - texts.append( - f"The following is a per-minute binary mask for {display} " - f"(1 = active, 0 = inactive, NaN = nonwear):" - ) - return { - "pre_prompt": PRE_PROMPT, - "time_series_text": texts, - "time_series": tensors, - "post_prompt": POST_PROMPT, - } - - -def _projector_embedding(model, batch_rows: list[dict]) -> torch.Tensor: - n_ch = len(batch_rows[0]["channel_names"]) - ts_list: list[torch.Tensor] = [] - for row in batch_rows: - values = np.asarray(row["values"], dtype=np.float32) - channel_names = row["channel_names"] - for i, name in enumerate(channel_names): - normed, _, _ = _normalize_channel(values[i], name in CONTINUOUS_CHANNELS) - ts_list.append(torch.from_numpy(normed)) - - ts_padded = pad_sequence(ts_list, batch_first=True) - T = ts_padded.size(1) - rem = T % model.patch_size - if rem: - pad = ts_padded.new_zeros(ts_padded.size(0), 
model.patch_size - rem) - ts_padded = torch.cat([ts_padded, pad], dim=1) - ts_padded = ts_padded.to(model.device, non_blocking=True) - - enc = model.encoder(ts_padded) - proj = model.projector(enc).float() - B = len(batch_rows) - return proj.view(B, n_ch, proj.size(1), proj.size(2)).mean(dim=(1, 2)) - - -def _llm_last_embedding(model, batch_rows: list[dict]) -> torch.Tensor: - samples = [_build_sample(row) for row in batch_rows] - inputs_embeds, attention_mask = model.pad_and_apply_batch(samples) - outputs = model.llm( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - output_hidden_states=True, - return_dict=True, - ) - hidden = outputs.hidden_states[-1] # [B, L, H] - mask = attention_mask.unsqueeze(-1).to(hidden.dtype) # [B, L, 1] - pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1) - return pooled.float() - - -def _snapshot(path: Path, embeddings, user_ids, dates, completed, source) -> None: - tmp = path.with_name(path.stem + ".tmp.npz") - np.savez( - tmp, - embeddings=embeddings[:completed], - user_ids=np.array(user_ids[:completed]), - dates=np.array(dates[:completed]), - completed=np.array(completed), - embedding_source=np.array(source), - ) - tmp.replace(path) - - -def _finalize(out_path: Path, embeddings, user_ids, dates, completed, source, label_lookup_path) -> None: - embeddings = embeddings[:completed] - user_ids = user_ids[:completed] - dates = dates[:completed] - print(f"Joining labels for {completed} samples...") - lookup = LabelLookup(label_lookup_path) - labels_df = lookup.join(user_ids, dates) - splits = np.array([_assign_split(uid) for uid in user_ids]) - np.savez_compressed( - out_path / "embeddings.npz", - embeddings=embeddings, - user_ids=np.array(user_ids), - dates=np.array(dates), - splits=splits, - embedding_source=np.array(source), - ) - labels_df.to_parquet(out_path / "labels.parquet") - # Clean up partial snapshot — final file supersedes it - partial = out_path / "embeddings_partial.npz" - if partial.exists(): - partial.unlink() - print(f"Saved embeddings + labels to {out_path} (source={source})") - - -def extract_embeddings(args, model) -> None: - export_dir = Path(args.export_dir) - shards = sorted(export_dir.glob("recordings_*.arrow")) - if not shards: - raise FileNotFoundError(f"No shards in {export_dir}") - - rows = [] - for shard in shards: - rows.extend(feather.read_table(shard).to_pylist()) - print(f"Loaded {len(rows)} rows from {len(shards)} shards") - - if args.max_samples and args.max_samples < len(rows): - rng = np.random.default_rng(args.seed) - idx = rng.choice(len(rows), size=args.max_samples, replace=False) - rows = [rows[i] for i in sorted(idx)] - print(f"Subsampled to {len(rows)} rows (seed={args.seed})") - - hidden = model.llm.config.hidden_size - out_path = Path(args.output_dir) - out_path.mkdir(parents=True, exist_ok=True) - partial_path = out_path / "embeddings_partial.npz" - - embeddings = np.zeros((len(rows), hidden), dtype=np.float32) - user_ids: list[str] = [""] * len(rows) - dates: list[str] = [""] * len(rows) - completed = 0 - - resume_from_path = None - if partial_path.exists() and not args.no_resume: - resume_from_path = partial_path - elif (out_path / "embeddings.npz").exists() and not args.no_resume: - resume_from_path = out_path / "embeddings.npz" - print(f"Found finalized {resume_from_path.name}, treating as resume point.") - - if resume_from_path is not None: - cached = np.load(resume_from_path, allow_pickle=True) - prev_completed = int(cached["completed"]) if "completed" in cached.files else 
len(cached["embeddings"]) - prev_source = str(cached["embedding_source"]) - if prev_source != args.embedding_source: - print( - f"Partial file has source={prev_source} but args.embedding_source={args.embedding_source}." - " Ignoring and restarting. Use --no_resume to silence, or delete partial manually." - ) - elif prev_completed > len(rows): - print( - f"Partial file has {prev_completed} samples but current run has only {len(rows)} rows." - " Ignoring (likely different --max_samples/seed). Use --no_resume to silence." - ) - else: - embeddings[:prev_completed] = cached["embeddings"] - prev_users = list(cached["user_ids"]) - prev_dates = list(cached["dates"]) - user_ids[:prev_completed] = prev_users - dates[:prev_completed] = prev_dates - completed = prev_completed - print(f"Resuming from sample {completed}/{len(rows)}") - - model.eval() - get_embed = _projector_embedding if args.embedding_source == "projector" else _llm_last_embedding - - try: - with torch.no_grad(): - for start in range(completed, len(rows), args.batch_size): - batch = rows[start : start + args.batch_size] - pooled = get_embed(model, batch) - embeddings[start : start + len(batch)] = pooled.cpu().numpy() - for i, row in enumerate(batch): - user_ids[start + i] = str(row["user_id"]) - dates[start + i] = str(row["date"]) - completed = start + len(batch) - - if (start // args.batch_size) % 50 == 0: - print(f" {completed}/{len(rows)}") - - if args.save_every > 0 and completed % args.save_every < args.batch_size: - _snapshot( - partial_path, embeddings, user_ids, dates, completed, args.embedding_source - ) - except KeyboardInterrupt: - print(f"\nInterrupted at {completed}/{len(rows)}. Saving partial snapshot...") - _snapshot( - partial_path, embeddings, user_ids, dates, completed, args.embedding_source - ) - print(f"Partial saved to {partial_path}. 
Re-run to resume or pass --skip_extract to probe partial.") - if completed == 0: - raise - _finalize( - out_path, embeddings, user_ids, dates, completed, args.embedding_source, args.label_lookup - ) - return - - _finalize( - out_path, embeddings, user_ids, dates, completed, args.embedding_source, args.label_lookup - ) - - -def run_probes(args) -> None: - from sklearn.linear_model import LogisticRegression, Ridge - from sklearn.metrics import ( - balanced_accuracy_score, - mean_absolute_error, - r2_score, - roc_auc_score, - ) - - out_path = Path(args.output_dir) - final_path = out_path / "embeddings.npz" - partial_path = out_path / "embeddings_partial.npz" - if not final_path.exists() and partial_path.exists(): - print(f"No final embeddings.npz — finalizing from partial snapshot.") - cached = np.load(partial_path, allow_pickle=True) - completed = int(cached["completed"]) - embeddings = cached["embeddings"] - user_ids = list(cached["user_ids"]) - dates = list(cached["dates"]) - source = str(cached["embedding_source"]) - _finalize(out_path, embeddings, user_ids, dates, completed, source, args.label_lookup) - cached = np.load(final_path, allow_pickle=True) - X = cached["embeddings"] - splits = cached["splits"] - labels_df = pd.read_parquet(out_path / "labels.parquet") - - train_mask = splits == "train" - test_mask = splits == "test" - print(f"Splits: train={train_mask.sum()}, test={test_mask.sum()}") - - results: dict[str, dict] = {} - for label in labels_df.columns: - ltype = LABEL_TYPES.get(label) - if ltype is None: - continue - y_raw = labels_df[label].to_numpy() - valid = ~pd.isna(y_raw) - tr = train_mask & valid - te = test_mask & valid - if tr.sum() < 50 or te.sum() < 50: - continue - - Xtr, Xte = X[tr], X[te] - if ltype == "binary": - ytr = y_raw[tr].astype(int) - yte = y_raw[te].astype(int) - clf = LogisticRegression(max_iter=2000, n_jobs=-1).fit(Xtr, ytr) - prob = clf.predict_proba(Xte)[:, 1] - pred = (prob > 0.5).astype(int) - results[label] = { - "type": ltype, - "n_train": int(tr.sum()), - "n_test": int(te.sum()), - "auroc": float(roc_auc_score(yte, prob)), - "bal_acc": float(balanced_accuracy_score(yte, pred)), - } - elif ltype == "ordinal": - ytr = y_raw[tr].astype(int) - yte = y_raw[te].astype(int) - clf = LogisticRegression(max_iter=2000, n_jobs=-1).fit(Xtr, ytr) - pred = clf.predict(Xte) - results[label] = { - "type": ltype, - "n_train": int(tr.sum()), - "n_test": int(te.sum()), - "bal_acc": float(balanced_accuracy_score(yte, pred)), - "mae": float(mean_absolute_error(yte, pred)), - } - else: # continuous - ytr = y_raw[tr].astype(float) - yte = y_raw[te].astype(float) - reg = Ridge(alpha=1.0).fit(Xtr, ytr) - pred = reg.predict(Xte) - results[label] = { - "type": ltype, - "n_train": int(tr.sum()), - "n_test": int(te.sum()), - "mae": float(mean_absolute_error(yte, pred)), - "r2": float(r2_score(yte, pred)), - } - - r = results[label] - summary = ", ".join( - f"{k}={v:.3f}" if isinstance(v, float) else f"{k}={v}" - for k, v in r.items() - if k != "type" - ) - print(f" {label} ({r['type']}): {summary}") - - with open(out_path / "probe_results.json", "w") as f: - json.dump(results, f, indent=2) - print(f"Wrote {out_path / 'probe_results.json'}") - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--checkpoint", required=True) - parser.add_argument("--label_lookup", required=True) - parser.add_argument("--export_dir", default="exports/lean_smoke") - parser.add_argument("--output_dir", default="probe_results") - parser.add_argument("--llm_id", 
default="meta-llama/Llama-3.2-1B") - parser.add_argument("--batch_size", type=int, default=16) - parser.add_argument("--device", default=None) - parser.add_argument( - "--embedding_source", - choices=["projector", "llm_last"], - default="llm_last", - help="projector: encoder+projector only (fast, no LoRA). " - "llm_last: full LLM forward, mean-pool last hidden state.", - ) - parser.add_argument( - "--lora_r", type=int, default=16, help="Used only if checkpoint has lora_state" - ) - parser.add_argument("--lora_alpha", type=int, default=32) - parser.add_argument( - "--max_samples", type=int, default=0, - help="If >0, randomly subsample this many rows before extraction.", - ) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument( - "--save_every", type=int, default=500, - help="Snapshot partial embeddings every N samples. 0 disables.", - ) - parser.add_argument( - "--no_resume", action="store_true", - help="Ignore existing embeddings_partial.npz and start from scratch.", - ) - parser.add_argument( - "--skip_extract", - action="store_true", - help="Reuse cached embeddings.npz/labels.parquet from --output_dir", - ) - args = parser.parse_args() - - device = args.device or ("cuda" if torch.cuda.is_available() else "cpu") - - if not args.skip_extract: - model = OpenTSLMSP(llm_id=args.llm_id, device=device) - ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False) - model.encoder.load_state_dict(ckpt["encoder_state"]) - model.projector.load_state_dict(ckpt["projector_state"]) - if "lora_state" in ckpt: - model.enable_lora( - lora_r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=0.0 - ) - model.load_lora_state_from_checkpoint(ckpt, allow_missing=True) - print(f"Loaded checkpoint from epoch {ckpt.get('epoch', '?')}") - extract_embeddings(args, model) - - run_probes(args) - - -if __name__ == "__main__": - main() diff --git a/scripts/linear_probe_vlm.py b/scripts/linear_probe_vlm.py deleted file mode 100644 index 9be38b5..0000000 --- a/scripts/linear_probe_vlm.py +++ /dev/null @@ -1,510 +0,0 @@ -# -*- coding: utf-8 -*- -# -# SPDX-FileCopyrightText: 2026 Stanford University, ETH Zurich, and the project authors (see CONTRIBUTORS.md) -# SPDX-FileCopyrightText: 2026 This source file is part of the SensorTSLM open-source project. -# -# SPDX-License-Identifier: MIT -# -"""Linear-probe a frozen VLM against MHC downstream labels. - -Renders weekly sensor tensors as multi-channel plots, feeds them to a frozen -vision-language model (e.g. Gemma 4 2B), extracts the last hidden state at the -final non-padding token, and trains sklearn probes on the resulting embeddings. - -Two phases (first run does both; re-runs with --skip_extract skip phase 1): - 1. Extract a fixed-length embedding per sample (no gradients). - 2. Fit sklearn probes per label on the user-hash split, report metrics. 
- -Example: - python scripts/linear_probe_vlm.py \ - --hf_dataset /scratch/.../weekly_hf \ - --label_lookup /scratch/.../weekly_labels_lookup.parquet \ - --output_dir probe_results/gemma4_2b \ - --model_name google/gemma-4-E2B-it -""" -from __future__ import annotations - -import argparse -import hashlib -import io -import json -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -from PIL import Image - -from time_series_datasets.mhc_label_lookup import LABEL_TYPES, LabelLookup - - -TRAIN_FRAC = 0.8 -VAL_FRAC = 0.1 - -# Channel layout -- matches the MHC weekly HF dataset (168 hours, 19 channels) -_CONTINUOUS = [ - (0, "#e41a1c", "iPhone step count", "steps/hr"), - (1, "#377eb8", "iPhone distance", "m/hr"), - (2, "#4daf4a", "flights climbed (iPhone)", "count/hr"), - (3, "#984ea3", "Apple Watch step count", "steps/hr"), - (4, "#ff7f00", "Apple Watch distance", "m/hr"), - (5, "#a65628", "heart rate", "bpm"), - (6, "#f781bf", "active energy", "cal/hr"), -] - -_BINARY = [ - (7, "#4c72b0", "Asleep"), - (8, "#dd8452", "InBed"), - (9, "#e41a1c", "Walking"), - (10, "#377eb8", "Cycling"), - (11, "#4daf4a", "Running"), - (12, "#984ea3", "Other"), - (13, "#ff7f00", "Mixed Cardio"), - (14, "#a65628", "Strength"), - (15, "#f781bf", "Elliptical"), - (16, "#999999", "HIIT"), - (17, "#66c2a5", "Functional"), - (18, "#fc8d62", "Yoga"), -] - -_CHANNEL_LEGEND = """\ -This image shows one week of wearable sensor data. The x-axis shows hours (0h-24h repeating over 7 days). -- Panel 1 (Activity): iPhone/Watch steps, distance, flights climbed -- Panel 2 (Physiology): Heart rate (bpm), active energy (cal/hr) -- Panel 3 (Sleep): Binary asleep/inbed indicators (filled area = asleep) -- Panel 4 (Workouts): Binary workout type indicators""" # noqa: E501 - - -def _assign_split(user_id: str) -> str: - digest = int(hashlib.sha1(user_id.encode()).hexdigest(), 16) - bucket = (digest % 1000) / 1000.0 - if bucket < TRAIN_FRAC: - return "train" - if bucket < TRAIN_FRAC + VAL_FRAC: - return "validation" - return "test" - - -def _render_weekly_tensor(values: np.ndarray, dpi: int = 100) -> Image.Image: - """Render a (168, 19) weekly sensor tensor as a single-column plot. - - Top: 7 continuous channels as line plots. - Bottom: active binary channels as horizontal strips. 
- """ - values = values.copy() - # Heart rate (ch 5) is stored as beats/sec -- convert to bpm - values[:, 5] = values[:, 5] * 60 - - hours = np.arange(168) - n_cont = len(_CONTINUOUS) - - active_binary = [] - for ch_idx, color, name in _BINARY: - data = values[:, ch_idx] - valid = ~np.isnan(data) - if (valid & (data > 0.5)).any(): - active_binary.append((ch_idx, color, name)) - n_bin = len(active_binary) - - heights = [3] * n_cont + [1] * max(n_bin, 1) - total_rows = n_cont + max(n_bin, 1) - - fig, axes = plt.subplots( - total_rows, 1, - figsize=(14, 1.2 * n_cont + 0.4 * max(n_bin, 1)), - gridspec_kw={"height_ratios": heights, "hspace": 0.35}, - ) - if total_rows == 1: - axes = [axes] - - for row, (ch_idx, color, name, unit) in enumerate(_CONTINUOUS): - ax = axes[row] - data = values[:, ch_idx] - ax.plot(hours, data, linewidth=0.8, color=color, alpha=0.9) - if np.all(np.isnan(data)): - vmin, vmax = 0.0, 1.0 - else: - vmin, vmax = float(np.nanmin(data)), float(np.nanmax(data)) - pad = max((vmax - vmin) * 0.15, 0.5) - ax.set_ylim(vmin - pad, vmax + pad) - ax.set_xlim(0, 167) - for d in range(1, 7): - ax.axvline(d * 24, color="gray", linewidth=0.3, alpha=0.5) - ax.set_xticks([]) - ax.set_yticks([]) - ax.spines[:].set_visible(False) - ax.set_title( - f"{name} ({unit}) [{vmin:.0f}\u2013{vmax:.0f}]", - fontsize=6, color=color, fontweight="bold", loc="left", pad=2, - ) - - if n_bin > 0: - for i, (ch_idx, color, name) in enumerate(active_binary): - ax = axes[n_cont + i] - data = values[:, ch_idx] - valid = ~np.isnan(data) - active = valid & (data > 0.5) - ax.fill_between(hours, 0, 1, where=active, color=color, alpha=0.8) - ax.set_xlim(0, 167) - ax.set_ylim(0, 1) - for d in range(1, 7): - ax.axvline(d * 24, color="gray", linewidth=0.3, alpha=0.5) - ax.set_xticks([]) - ax.set_yticks([]) - ax.spines[:].set_visible(False) - ax.set_title(name, fontsize=5, color=color, fontweight="bold", loc="left", pad=1) - else: - ax = axes[n_cont] - ax.set_xlim(0, 167) - ax.set_ylim(0, 1) - ax.set_xticks([]) - ax.set_yticks([]) - ax.spines[:].set_visible(False) - ax.set_title("(no activity/sleep data)", fontsize=5, color="gray", loc="left", pad=1) - - buf = io.BytesIO() - fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight", pad_inches=0.05) - plt.close(fig) - buf.seek(0) - return Image.open(buf).convert("RGB") - - -def _auto_model_class(model_name: str): - """Pick the right AutoModel class based on model name.""" - if "gemma-4" in model_name.lower(): - from transformers import AutoModelForMultimodalLM - return AutoModelForMultimodalLM - from transformers import AutoModelForImageTextToText - return AutoModelForImageTextToText - - -def _load_model(args): - """Load the frozen VLM and processor.""" - from transformers import AutoProcessor - - AutoModelCls = _auto_model_class(args.model_name) - dtype = getattr(torch, args.torch_dtype) - - model_kwargs = { - "torch_dtype": dtype, - "device_map": {"": args.device} if args.device != "auto" else "auto", - "attn_implementation": "eager", - } - if args.quantization == "4bit": - from transformers import BitsAndBytesConfig - model_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=dtype, - ) - - print(f"Loading frozen model {args.model_name}...") - model = AutoModelCls.from_pretrained(args.model_name, **model_kwargs) - model.eval() - - processor_name = args.model_name - if "-pt" in processor_name: - processor_name = processor_name.replace("-pt", "-it") - processor 
= AutoProcessor.from_pretrained(processor_name) - - return model, processor - - -def _snapshot(path: Path, embeddings, user_ids, week_starts, completed, model_name) -> None: - tmp = path.with_name(path.stem + ".tmp.npz") - np.savez( - tmp, - embeddings=embeddings[:completed], - user_ids=np.array(user_ids[:completed]), - week_starts=np.array(week_starts[:completed]), - completed=np.array(completed), - model_name=np.array(model_name), - ) - tmp.replace(path) - - -def _finalize(out_path, embeddings, user_ids, week_starts, completed, model_name, label_lookup_path): - embeddings = embeddings[:completed] - user_ids = user_ids[:completed] - week_starts = week_starts[:completed] - print(f"Joining labels for {completed} samples...") - lookup = LabelLookup(label_lookup_path) - labels_df = lookup.join(user_ids, week_starts) - splits = np.array([_assign_split(uid) for uid in user_ids]) - np.savez_compressed( - out_path / "embeddings.npz", - embeddings=embeddings, - user_ids=np.array(user_ids), - week_starts=np.array(week_starts), - splits=splits, - model_name=np.array(model_name), - ) - labels_df.to_parquet(out_path / "labels.parquet") - partial = out_path / "embeddings_partial.npz" - if partial.exists(): - partial.unlink() - print(f"Saved embeddings + labels to {out_path} (model={model_name})") - - -def extract_embeddings(args, model, processor) -> None: - import datasets as hf_ds - - print(f"Loading HF dataset from {args.hf_dataset}...") - ds = hf_ds.load_from_disk(args.hf_dataset) - n_total = len(ds) - print(f"Loaded {n_total} weekly samples") - - if args.max_samples and args.max_samples < n_total: - rng = np.random.default_rng(args.seed) - idx = rng.choice(n_total, size=args.max_samples, replace=False) - ds = ds.select(sorted(idx)) - n_total = len(ds) - print(f"Subsampled to {n_total} samples (seed={args.seed})") - - # Determine hidden size from model config - config = model.config - if hasattr(config, "text_config"): - hidden = config.text_config.hidden_size - elif hasattr(config, "hidden_size"): - hidden = config.hidden_size - else: - raise ValueError("Cannot determine hidden_size from model config") - print(f"Hidden size: {hidden}") - - out_path = Path(args.output_dir) - out_path.mkdir(parents=True, exist_ok=True) - partial_path = out_path / "embeddings_partial.npz" - - embeddings = np.zeros((n_total, hidden), dtype=np.float32) - user_ids: list[str] = [""] * n_total - week_starts: list[str] = [""] * n_total - completed = 0 - - # Resume logic - resume_from_path = None - if partial_path.exists() and not args.no_resume: - resume_from_path = partial_path - elif (out_path / "embeddings.npz").exists() and not args.no_resume: - resume_from_path = out_path / "embeddings.npz" - print(f"Found finalized {resume_from_path.name}, treating as resume point.") - - if resume_from_path is not None: - cached = np.load(resume_from_path, allow_pickle=True) - prev_completed = int(cached["completed"]) if "completed" in cached.files else len(cached["embeddings"]) - prev_model = str(cached["model_name"]) if "model_name" in cached.files else "" - if prev_model and prev_model != args.model_name: - print( - f"Partial file has model={prev_model} but args.model_name={args.model_name}. " - "Ignoring and restarting." - ) - elif prev_completed > n_total: - print( - f"Partial file has {prev_completed} samples but current run has only {n_total}. " - "Ignoring (likely different --max_samples/seed)." 
- ) - else: - embeddings[:prev_completed] = cached["embeddings"] - user_ids[:prev_completed] = list(cached["user_ids"]) - week_starts[:prev_completed] = list(cached["week_starts"]) - completed = prev_completed - print(f"Resuming from sample {completed}/{n_total}") - - prompt_text = _CHANNEL_LEGEND + "\n\nDescribe this person's health patterns." - - try: - with torch.no_grad(): - for start in range(completed, n_total, args.batch_size): - end = min(start + args.batch_size, n_total) - batch_texts = [] - batch_images = [] - - for i in range(start, end): - row = ds[i] - values = np.array(row["values"], dtype=np.float32) - image = _render_weekly_tensor(values, dpi=args.rendering_dpi) - messages = [{"role": "user", "content": [ - {"type": "image", "image": image}, - {"type": "text", "text": prompt_text}, - ]}] - text = processor.apply_chat_template( - messages, add_generation_prompt=True, tokenize=False, - ) - batch_texts.append(text.strip()) - batch_images.append([image]) - - user_ids[i] = str(row["user_id"]) - week_starts[i] = str(row["week_start"]) - - inputs = processor( - text=batch_texts, images=batch_images, - return_tensors="pt", padding=True, - ).to(model.device) - - outputs = model(**inputs, output_hidden_states=True) - hidden_states = outputs.hidden_states[-1] # (B, L, H) - - pad_id = processor.tokenizer.pad_token_id - for b in range(hidden_states.shape[0]): - ids = inputs["input_ids"][b] - non_pad = (ids != pad_id).nonzero(as_tuple=True)[0] - last_pos = non_pad[-1].item() if len(non_pad) > 0 else hidden_states.shape[1] - 1 - feat = hidden_states[b, last_pos].float().cpu().numpy() - embeddings[start + b] = feat - - completed = end - - if (start // args.batch_size) % 50 == 0: - print(f" {completed}/{n_total}") - - if args.save_every > 0 and completed % args.save_every < args.batch_size: - _snapshot(partial_path, embeddings, user_ids, week_starts, completed, args.model_name) - - except KeyboardInterrupt: - print(f"\nInterrupted at {completed}/{n_total}. Saving partial snapshot...") - _snapshot(partial_path, embeddings, user_ids, week_starts, completed, args.model_name) - print(f"Partial saved to {partial_path}. 
Re-run to resume or pass --skip_extract to probe partial.") - if completed == 0: - raise - _finalize(out_path, embeddings, user_ids, week_starts, completed, args.model_name, args.label_lookup) - return - - _finalize(out_path, embeddings, user_ids, week_starts, completed, args.model_name, args.label_lookup) - - -def run_probes(args) -> None: - from sklearn.linear_model import LogisticRegression, Ridge - from sklearn.metrics import ( - balanced_accuracy_score, - mean_absolute_error, - r2_score, - roc_auc_score, - ) - from sklearn.preprocessing import StandardScaler - - out_path = Path(args.output_dir) - final_path = out_path / "embeddings.npz" - partial_path = out_path / "embeddings_partial.npz" - if not final_path.exists() and partial_path.exists(): - print("No final embeddings.npz -- finalizing from partial snapshot.") - cached = np.load(partial_path, allow_pickle=True) - completed = int(cached["completed"]) - _finalize( - out_path, cached["embeddings"], list(cached["user_ids"]), - list(cached["week_starts"]), completed, - str(cached["model_name"]), args.label_lookup, - ) - cached = np.load(final_path, allow_pickle=True) - X = cached["embeddings"] - splits = cached["splits"] - labels_df = pd.read_parquet(out_path / "labels.parquet") - - train_mask = splits == "train" - test_mask = splits == "test" - print(f"Splits: train={train_mask.sum()}, test={test_mask.sum()}") - - results: dict[str, dict] = {} - for label in labels_df.columns: - ltype = LABEL_TYPES.get(label) - if ltype is None: - continue - y_raw = labels_df[label].to_numpy() - valid = ~pd.isna(y_raw) - tr = train_mask & valid - te = test_mask & valid - if tr.sum() < 50 or te.sum() < 50: - continue - - Xtr, Xte = X[tr], X[te] - scaler = StandardScaler().fit(Xtr) - Xtr, Xte = scaler.transform(Xtr), scaler.transform(Xte) - if ltype in ("binary", "ordinal"): - ytr = y_raw[tr].astype(int) - yte = y_raw[te].astype(int) - if len(np.unique(ytr)) < 2: - print(f" {label} ({ltype}): skip -- only one class in train") - continue - clf = LogisticRegression(max_iter=2000, n_jobs=-1).fit(Xtr, ytr) - if ltype == "binary": - prob = clf.predict_proba(Xte)[:, 1] - pred = (prob > 0.5).astype(int) - auroc = ( - float(roc_auc_score(yte, prob)) - if len(np.unique(yte)) >= 2 - else float("nan") - ) - results[label] = { - "type": ltype, - "n_train": int(tr.sum()), - "n_test": int(te.sum()), - "auroc": auroc, - "bal_acc": float(balanced_accuracy_score(yte, pred)), - } - else: # ordinal - pred = clf.predict(Xte) - results[label] = { - "type": ltype, - "n_train": int(tr.sum()), - "n_test": int(te.sum()), - "bal_acc": float(balanced_accuracy_score(yte, pred)), - "mae": float(mean_absolute_error(yte, pred)), - } - else: # continuous - ytr = y_raw[tr].astype(float) - yte = y_raw[te].astype(float) - reg = Ridge(alpha=1.0).fit(Xtr, ytr) - pred = reg.predict(Xte) - results[label] = { - "type": ltype, - "n_train": int(tr.sum()), - "n_test": int(te.sum()), - "mae": float(mean_absolute_error(yte, pred)), - "r2": float(r2_score(yte, pred)), - } - - r = results[label] - summary = ", ".join( - f"{k}={v:.3f}" if isinstance(v, float) else f"{k}={v}" - for k, v in r.items() - if k != "type" - ) - print(f" {label} ({r['type']}): {summary}") - - with open(out_path / "probe_results.json", "w") as f: - json.dump(results, f, indent=2) - print(f"Wrote {out_path / 'probe_results.json'}") - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument("--hf_dataset", required=True, help="Path to weekly HF dataset on disk") - parser.add_argument("--label_lookup", 
required=True, help="Path to weekly_labels_lookup.parquet") - parser.add_argument("--output_dir", default="probe_results/vlm") - parser.add_argument("--model_name", default="google/gemma-4-E2B-it") - parser.add_argument("--device", default="auto", help="'auto', 'cuda', 'cuda:0', 'cpu'") - parser.add_argument("--torch_dtype", default="bfloat16") - parser.add_argument("--quantization", default=None, choices=[None, "4bit"]) - parser.add_argument("--batch_size", type=int, default=4) - parser.add_argument("--rendering_dpi", type=int, default=100) - parser.add_argument("--max_samples", type=int, default=0, help="If >0, subsample this many rows") - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--save_every", type=int, default=500) - parser.add_argument("--no_resume", action="store_true") - parser.add_argument("--skip_extract", action="store_true", help="Reuse cached embeddings, only run probes") - args = parser.parse_args() - - if not args.skip_extract: - model, processor = _load_model(args) - extract_embeddings(args, model, processor) - del model, processor - torch.cuda.empty_cache() - - run_probes(args) - - -if __name__ == "__main__": - main() diff --git a/scripts/metabolic_finetune.py b/scripts/metabolic_finetune.py index 3580345..9fef9d4 100644 --- a/scripts/metabolic_finetune.py +++ b/scripts/metabolic_finetune.py @@ -49,8 +49,8 @@ from mhc.constants import CONTINUOUS_CHANNELS from time_series_datasets.mhc_base_qa_dataset import ( MHCBaseQADataset, - _assign_split, _iter_shard_rows, + _load_split_assignment, ) from time_series_datasets.mhc_label_lookup import ( LABEL_TYPES, @@ -123,13 +123,16 @@ class MHCMultiLabelQADataset(Dataset): samples — a silent correctness bug for multi-task training). """ - def __init__(self, split, eos, labels, export_dir, label_lookup_path): + def __init__(self, split, eos, labels, export_dir, label_lookup_path, split_file): for label in labels: if label not in LABEL_TYPES: raise ValueError(f"Unknown label: {label!r}") + user_to_bucket = _load_split_assignment(split_file) all_samples = _load_base_samples(export_dir, label_lookup_path, labels) - split_samples = [s for s in all_samples if _assign_split(s["user_id"]) == split] + split_samples = [ + s for s in all_samples if user_to_bucket.get(s["user_id"]) == split + ] items: list[dict] = [] pre_text = "You are an expert in wearable sensor time series analysis." 
@@ -179,8 +182,8 @@ def __getitem__(self, idx): return self.items[idx] -def _build_mixed(split, labels, eos, export_dir, label_lookup_path): - ds = MHCMultiLabelQADataset(split, eos, labels, export_dir, label_lookup_path) +def _build_mixed(split, labels, eos, export_dir, label_lookup_path, split_file): + ds = MHCMultiLabelQADataset(split, eos, labels, export_dir, label_lookup_path, split_file) if len(ds) == 0: raise RuntimeError(f"No usable samples in split={split} for any label") counts = ", ".join(f"{k}={v}" for k, v in ds.per_label_counts.items()) @@ -240,6 +243,10 @@ def main() -> None: "--checkpoint", required=True, help="Stage-1 checkpoint (latest.pt or best_model.pt)" ) parser.add_argument("--label_lookup", required=True) + parser.add_argument( + "--split_file", required=True, + help="MHC-benchmark canonical split JSON (sharable_users_seed42_2026.json).", + ) parser.add_argument("--export_dir", default="exports/lean_smoke") parser.add_argument("--output_dir", required=True) parser.add_argument("--llm_id", default="meta-llama/Llama-3.2-1B") @@ -293,8 +300,12 @@ def main() -> None: # Data eos = model.get_eos_token() - train_ds = _build_mixed("train", args.labels, eos, args.export_dir, args.label_lookup) - val_ds = _build_mixed("validation", args.labels, eos, args.export_dir, args.label_lookup) + train_ds = _build_mixed( + "train", args.labels, eos, args.export_dir, args.label_lookup, args.split_file, + ) + val_ds = _build_mixed( + "validation", args.labels, eos, args.export_dir, args.label_lookup, args.split_file, + ) train_loader = DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=list ) diff --git a/time_series_datasets/mhc_base_qa_dataset.py b/time_series_datasets/mhc_base_qa_dataset.py index 09b987d..af426c8 100644 --- a/time_series_datasets/mhc_base_qa_dataset.py +++ b/time_series_datasets/mhc_base_qa_dataset.py @@ -8,7 +8,7 @@ from __future__ import annotations -import hashlib +import json from pathlib import Path from typing import Iterator, List, Literal, Tuple @@ -25,19 +25,19 @@ DEFAULT_EXPORT_DIR = "exports" -TRAIN_FRAC = 0.8 -VAL_FRAC = 0.1 -def _assign_split(user_id: str) -> Literal["train", "validation", "test"]: - """Deterministic user-level split via SHA-1 bucketing (80 / 10 / 10).""" - digest = int(hashlib.sha1(user_id.encode()).hexdigest(), 16) - bucket = (digest % 1000) / 1000.0 - if bucket < TRAIN_FRAC: - return "train" - if bucket < TRAIN_FRAC + VAL_FRAC: - return "validation" - return "test" +def _load_split_assignment( + split_file: str | Path, +) -> dict[str, Literal["train", "validation", "test"]]: + """Invert MHC-benchmark's split JSON into a user_id → bucket map.""" + with open(split_file) as f: + splits = json.load(f) + user_to_bucket: dict[str, Literal["train", "validation", "test"]] = {} + for bucket in ("train", "validation", "test"): + for uid in splits[bucket]: + user_to_bucket[uid] = bucket + return user_to_bucket def _iter_shard_rows(export_dir: Path) -> Iterator[dict]: @@ -65,19 +65,25 @@ def __init__( export_dir: str | Path = DEFAULT_EXPORT_DIR, label_lookup_path: str | Path | None = None, required_labels: List[str] | None = None, + split_file: str | Path | None = None, ) -> None: """Load and format QA samples from captioned MHC arrow shards. required_labels: label columns to fetch from the labels parquet. - Subclasses pass their task's target(s) — e.g. - ``MHCMetabolicQADataset`` forwards ``[target_label]`` here. - Routed straight into ``LabelLookup.join(labels=...)``. 
``None`` pulls every column registered in ``LABEL_TYPES``. + split_file: MHC-benchmark's ``sharable_users_seed42_2026.json``. + Required; users absent from the JSON are dropped. """ + if split_file is None: + raise ValueError( + "split_file is required: pass MHC-benchmark's " + "sharable_users_seed42_2026.json" + ) self._split_name = split self._export_dir = Path(export_dir) self._label_lookup = LabelLookup(label_lookup_path) if label_lookup_path else None self._required_labels = required_labels + self._user_to_bucket = _load_split_assignment(split_file) super().__init__(split, EOS_TOKEN) # self._raw_samples is populated inside _load_splits. @@ -113,7 +119,9 @@ def _load_splits(self) -> Tuple[List[dict], List[dict], List[dict]]: sample["labels"] = row.to_dict() for sample in all_samples: - bucket = _assign_split(sample["user_id"]) + bucket = self._user_to_bucket.get(sample["user_id"]) + if bucket is None: + continue {"train": train, "validation": validation, "test": test}[bucket].append(sample) # Keep this instance's split for get_eval_context() — QADataset discards diff --git a/time_series_datasets/mhc_metabolic_qa_dataset.py b/time_series_datasets/mhc_metabolic_qa_dataset.py index 8c80da9..67e87ec 100644 --- a/time_series_datasets/mhc_metabolic_qa_dataset.py +++ b/time_series_datasets/mhc_metabolic_qa_dataset.py @@ -28,6 +28,7 @@ def __init__( target_label: str, export_dir: str | Path = "exports", label_lookup_path: str | Path | None = None, + split_file: str | Path | None = None, ) -> None: if target_label not in LABEL_TYPES: raise ValueError(f"Unknown label: {target_label!r}") @@ -37,6 +38,7 @@ def __init__( super().__init__( split, EOS_TOKEN, export_dir, label_lookup_path, required_labels=[target_label], + split_file=split_file, ) if self._target_label not in self._label_lookup.label_names: raise ValueError( diff --git a/time_series_datasets/mhc_multi_label_qa_dataset.py b/time_series_datasets/mhc_multi_label_qa_dataset.py index b1a0c67..82fe813 100644 --- a/time_series_datasets/mhc_multi_label_qa_dataset.py +++ b/time_series_datasets/mhc_multi_label_qa_dataset.py @@ -26,8 +26,8 @@ from mhc.constants import CONTINUOUS_CHANNELS from time_series_datasets.mhc_base_qa_dataset import ( MHCBaseQADataset, - _assign_split, _iter_shard_rows, + _load_split_assignment, ) from time_series_datasets.mhc_label_lookup import ( LABEL_TYPES, @@ -88,13 +88,19 @@ def _load_base_samples(export_dir, label_lookup_path, labels): class MHCMultiLabelQADataset(Dataset): - def __init__(self, split, EOS_TOKEN, labels, export_dir, label_lookup_path): + def __init__(self, split, EOS_TOKEN, labels, export_dir, label_lookup_path, split_file): + """Args mirror MHCBaseQADataset; ``split_file`` points at MHC-benchmark's + ``sharable_users_seed42_2026.json``. Required for canonical splits. + """ for label in labels: if label not in LABEL_TYPES: raise ValueError(f"Unknown label: {label!r}") + user_to_bucket = _load_split_assignment(split_file) all_samples = _load_base_samples(export_dir, label_lookup_path, labels) - split_samples = [s for s in all_samples if _assign_split(s["user_id"]) == split] + split_samples = [ + s for s in all_samples if user_to_bucket.get(s["user_id"]) == split + ] items: list[dict] = [] pre_text = "You are an expert in wearable sensor time series analysis."