Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions curriculum_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,13 +683,22 @@ def main():

from time_series_datasets.mhc_caption_qa_dataset import MHCCaptionQADataset

split_file = os.environ.get("MHC_SPLIT_FILE")
if not split_file:
raise RuntimeError(
"MHC_SPLIT_FILE env var is required: point at MHC-benchmark's "
"sharable_users_seed42_2026.json so split assignment is consistent."
)

caption_kwargs = {
"export_dir": os.environ.get("CAPTION_EXPORT_DIR", "exports/lean_smoke"),
"split_file": split_file,
}
metabolic_kwargs = {
"labels": os.environ.get("METABOLIC_LABELS", ",".join(DEFAULT_METABOLIC_LABELS)).split(","),
"export_dir": os.environ.get("CAPTION_EXPORT_DIR", "exports/lean_smoke"),
"label_lookup_path": os.environ.get("METABOLIC_LABEL_LOOKUP"),
"split_file": split_file,
}

curriculum_stages = OrderedDict(
Expand Down
77 changes: 51 additions & 26 deletions evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
from dataclasses import dataclass

from tqdm import tqdm
from transformers import logging as hf_logging

from evaluation.tasks.base import EvalTask
from models.base import BaseModel, ModelResponse

# Silence the per-generate "Setting pad_token_id to eos_token_id" warning.
hf_logging.set_verbosity_error()


@dataclass(slots=True)
class SampleResult:
Expand Down Expand Up @@ -54,39 +58,60 @@ def run(
dataset,
model_name: str,
max_samples: int | None = None,
save_every: int | None = None,
results_writer=None,
out_dir=None,
) -> EvalResult:
"""Evaluate *model* on *dataset* using *task*.

The dataset must be an OpenTSLM-style ``QADataset`` and expose
``get_eval_context(idx)`` returning ``(Recording, ground_truth)``.
``dataset[idx]`` provides the model-facing sample dict.

With ``save_every`` + ``results_writer`` + ``out_dir`` set, partial
results are flushed every N samples and on KeyboardInterrupt.
"""
n = len(dataset) if max_samples is None else min(max_samples, len(dataset))
per_sample: list[SampleResult] = []

for i in tqdm(range(n), desc=f"{model_name} / {task.name}"):
recording, ground_truth = dataset.get_eval_context(i)
sample = dataset[i]

response: ModelResponse = model.generate(sample=sample, recording=recording)
parsed = task.parse_prediction(response.text)

per_sample.append(SampleResult(
sample_id=recording.row_id,
ground_truth=ground_truth,
raw_prediction=response.text,
parsed_prediction=parsed,
input_tokens=response.input_tokens,
output_tokens=response.output_tokens,
metadata={"user_id": recording.user_id, "date": recording.date},
))

gts = [s.ground_truth for s in per_sample]
preds = [s.parsed_prediction for s in per_sample]

return EvalResult(
model_name=model_name,
task_name=task.name,
per_sample=per_sample,
metrics=task.aggregate_metrics(gts, preds),
)
def _build_partial() -> "EvalResult":
gts = [s.ground_truth for s in per_sample]
preds = [s.parsed_prediction for s in per_sample]
metrics = task.aggregate_metrics(gts, preds) if per_sample else {}
return EvalResult(
model_name=model_name,
task_name=task.name,
per_sample=per_sample,
metrics=metrics,
)

def _flush() -> None:
if results_writer is not None and out_dir is not None and per_sample:
results_writer.write(_build_partial(), out_dir=out_dir)

try:
for i in tqdm(range(n), desc=f"{model_name} / {task.name}"):
recording, ground_truth = dataset.get_eval_context(i)
sample = dataset[i]

response: ModelResponse = model.generate(sample=sample, recording=recording)
parsed = task.parse_prediction(response.text)

per_sample.append(SampleResult(
sample_id=recording.row_id,
ground_truth=ground_truth,
raw_prediction=response.text,
parsed_prediction=parsed,
input_tokens=response.input_tokens,
output_tokens=response.output_tokens,
metadata={"user_id": recording.user_id, "date": recording.date},
))

if save_every and (i + 1) % save_every == 0:
_flush()
except KeyboardInterrupt:
print(f"\n[interrupted at sample {len(per_sample)}/{n}] saving partial results")
_flush()
raise

return _build_partial()
19 changes: 13 additions & 6 deletions evaluation/io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,25 @@ class ResultsWriter:
def __init__(self, results_dir: Path | str = Path("results")) -> None:
self.results_dir = Path(results_dir)

def write(self, result: EvalResult) -> Path:
def write(self, result: EvalResult, out_dir: Path | None = None) -> Path:
"""Persist *result* to disk and return the output directory path.

If ``out_dir`` is given, write into it (overwriting existing files);
otherwise create a fresh ``<timestamp>`` subdir.

Creates:
metrics.json — aggregated metrics dict (directly JSON-serialisable).
predictions.csv — one row per evaluated sample.
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
safe_task = re.sub(r"[^\w\-.]", "_", result.task_name)
safe_model = re.sub(r"[^\w\-.]", "_", result.model_name)
out_dir = self.results_dir / safe_task / safe_model / timestamp
out_dir.mkdir(parents=True, exist_ok=False)
if out_dir is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
safe_task = re.sub(r"[^\w\-.]", "_", result.task_name)
safe_model = re.sub(r"[^\w\-.]", "_", result.model_name)
out_dir = self.results_dir / safe_task / safe_model / timestamp
out_dir.mkdir(parents=True, exist_ok=False)
else:
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

# metrics.json
with open(out_dir / "metrics.json", "w") as f:
Expand Down
28 changes: 27 additions & 1 deletion evaluation/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def _parse_args() -> argparse.Namespace:
"OpenTSLM.load_pretrained."
),
)
parser.add_argument(
"--save-every", type=int, default=None,
help=(
"Persist partial results every N samples (and on KeyboardInterrupt) "
"to the same output dir. Useful for long evals on slow filesystems."
),
)
return parser.parse_args()


Expand All @@ -100,6 +107,12 @@ def _build_model(args: argparse.Namespace):


def _build_dataset(args: argparse.Namespace, task: MetabolicTask):
split_file = os.environ.get("MHC_SPLIT_FILE")
if not split_file:
raise RuntimeError(
"MHC_SPLIT_FILE env var is required: point at MHC-benchmark's "
"sharable_users_seed42_2026.json so split assignment is consistent."
)
return task.dataset_cls(
split="test",
EOS_TOKEN="",
Expand All @@ -108,6 +121,7 @@ def _build_dataset(args: argparse.Namespace, task: MetabolicTask):
label_lookup_path=os.environ.get(
"MHC_LABELS_PATH", "data/labels_dev_subset.parquet"
),
split_file=split_file,
)


Expand All @@ -131,16 +145,28 @@ def main() -> None:
print("No samples found — check MHC_EXPORT_DIR.")
sys.exit(1)

# Pre-create the output dir so periodic saves rewrite the same path.
from datetime import datetime
import re
writer = ResultsWriter()
safe_task = re.sub(r"[^\w\-.]", "_", task.name)
safe_model = re.sub(r"[^\w\-.]", "_", model_name)
out_dir = writer.results_dir / safe_task / safe_model / datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir.mkdir(parents=True, exist_ok=True)

evaluator = Evaluator()
result = evaluator.run(
task=task,
model=model,
dataset=dataset,
model_name=model_name,
max_samples=max_samples,
save_every=args.save_every,
results_writer=writer,
out_dir=out_dir,
)

ResultsWriter().write(result)
writer.write(result, out_dir=out_dir)

print("\n=== RESULTS ===")
for key, val in result.metrics.items():
Expand Down
27 changes: 26 additions & 1 deletion mhc/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,42 @@
#
# SPDX-License-Identifier: MIT
#
import json

from datasets import load_from_disk
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from torch.utils.data import Dataset

from mhc.constants import DATASET_DIR


class MHCDataset(Dataset):
def __init__(self, min_wear_pct: float = 0.0, dataset_dir: str = DATASET_DIR):
def __init__(
self,
min_wear_pct: float = 0.0,
dataset_dir: str = DATASET_DIR,
split_file: str | None = None,
):
"""Load MHC daily dataset, optionally filtered by ``split_file``
(MHC-benchmark's ``sharable_users_seed42_2026.json`` — rows with
``user_id`` outside train ∪ validation ∪ test are dropped)."""
self.ds = load_from_disk(dataset_dir)

if split_file is not None:
with open(split_file) as f:
splits = json.load(f)
allowed_users = (
set(splits["train"]) | set(splits["validation"]) | set(splits["test"])
)
mask = pc.is_in(
self.ds.data.column("user_id"),
value_set=pa.array(sorted(allowed_users)),
)
indices = [i for i, m in enumerate(mask.to_pylist()) if m]
self.ds = self.ds.select(indices)

if min_wear_pct > 0:
max_nonwear = 1440 * (1 - min_wear_pct / 100)
mask = pc.less_equal(self.ds.data.column("total_nonwear_minutes"), max_nonwear)
Expand Down
21 changes: 20 additions & 1 deletion mhc_weekly/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#
# SPDX-License-Identifier: MIT
#
import json

from datasets import load_from_disk
import numpy as np
from torch.utils.data import Dataset
Expand All @@ -12,8 +14,25 @@


class WeeklyMHCDataset(Dataset):
def __init__(self, min_valid_hours: int = 0, dataset_dir: str = WEEKLY_DATASET_DIR):
def __init__(
self,
min_valid_hours: int = 0,
dataset_dir: str = WEEKLY_DATASET_DIR,
split_file: str | None = None,
):
"""Load weekly MHC dataset, optionally filtered by ``split_file``
(MHC-benchmark's ``sharable_users_seed42_2026.json`` — rows with
``user_id`` outside train ∪ validation ∪ test are dropped)."""
self.ds = load_from_disk(dataset_dir)

if split_file is not None:
with open(split_file) as f:
splits = json.load(f)
allowed_users = (
set(splits["train"]) | set(splits["validation"]) | set(splits["test"])
)
self.ds = self.ds.filter(lambda row: row["user_id"] in allowed_users)

if min_valid_hours > 0:
self.ds = self.ds.filter(lambda row: row["n_valid_hours"] >= min_valid_hours)

Expand Down
23 changes: 17 additions & 6 deletions scripts/export_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,23 @@
from timef.schema import CaptionResult


def _build_daily(min_wear_pct: float, split_file: str | None):
    """Build the pieces of the daily-MHC export pipeline.

    Imports are deliberately function-local so that running with
    ``--variant weekly`` never imports (or requires) the daily modules.

    Args:
        min_wear_pct: Wear-percentage threshold forwarded to ``MHCDataset``
            (rows with too much non-wear time are dropped there).
        split_file: Optional path to MHC-benchmark's canonical split JSON
            (``sharable_users_seed42_2026.json``); when set, non-sharable
            users are dropped at dataset load time.

    Returns:
        Tuple of ``(dataset, transformer, channel_config)`` for the daily
        variant.
    """
    from mhc.constants import MHC_CHANNEL_CONFIG
    from mhc.dataset import MHCDataset
    from mhc.transformer import MHCTransformer

    return (
        MHCDataset(min_wear_pct=min_wear_pct, split_file=split_file),
        MHCTransformer(),
        MHC_CHANNEL_CONFIG,
    )


def _build_weekly(min_valid_hours: int, split_file: str | None):
    """Build the pieces of the weekly-MHC export pipeline.

    Imports are deliberately function-local so that running with
    ``--variant daily`` never imports (or requires) the weekly modules.

    Args:
        min_valid_hours: Threshold forwarded to ``WeeklyMHCDataset``
            (rows with fewer valid hours are dropped there).
        split_file: Optional path to MHC-benchmark's canonical split JSON
            (``sharable_users_seed42_2026.json``); when set, non-sharable
            users are dropped at dataset load time.

    Returns:
        Tuple of ``(dataset, transformer, channel_config)`` for the weekly
        variant.
    """
    from mhc_weekly.constants import WEEKLY_CHANNEL_CONFIG
    from mhc_weekly.dataset import WeeklyMHCDataset
    from mhc_weekly.transformer import WeeklyMHCTransformer

    return (
        WeeklyMHCDataset(min_valid_hours=min_valid_hours, split_file=split_file),
        WeeklyMHCTransformer(),
        WEEKLY_CHANNEL_CONFIG,
    )
Expand All @@ -65,12 +69,19 @@ def main():
help="Start index (for sharded runs).")
parser.add_argument("--end", type=int, default=None,
help="End index exclusive (for sharded runs).")
parser.add_argument("--split_file", default=None,
help="MHC-benchmark canonical split JSON. When set, "
"non-sharable users are dropped at load time.")
args = parser.parse_args()

if args.variant == "daily":
dataset, transformer, channel_config = _build_daily(args.min_wear_pct)
dataset, transformer, channel_config = _build_daily(
args.min_wear_pct, args.split_file
)
else:
dataset, transformer, channel_config = _build_weekly(args.min_valid_hours)
dataset, transformer, channel_config = _build_weekly(
args.min_valid_hours, args.split_file
)

annotator = Annotator([
StructuralExtractor(channel_config),
Expand Down
10 changes: 9 additions & 1 deletion scripts/export_captions_sharded.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
# REPO_DIR default $HOME/SensorTSLM
# VENV_DIR default $SCRATCH/sensor-venv (used via sherlock_setup.sh)
# TOTAL_ROWS auto-detected for daily; defaults to 615490 for weekly
# SPLIT_FILE path to MHC-benchmark sharable_users_seed42_2026.json; when
# set, non-sharable users are dropped at load time

set -euo pipefail

Expand Down Expand Up @@ -57,6 +59,12 @@ fi
mkdir -p "$REPO_DIR/logs"
STRIDE=$(( (TOTAL_ROWS + NUM_SHARDS - 1) / NUM_SHARDS ))

SPLIT_ARG=""
if [ -n "${SPLIT_FILE:-}" ]; then
SPLIT_ARG="--split_file $SPLIT_FILE"
echo " split_file: $SPLIT_FILE"
fi

echo "Submitting $NUM_SHARDS shards of ~$STRIDE rows each ($VARIANT)"
echo " out: $OUT_DIR/shard_<i>"
echo " source: $DATA_DIR"
Expand All @@ -71,7 +79,7 @@ for i in $(seq 0 $((NUM_SHARDS - 1))); do
source \$SCRATCH/sensor-venv/bin/activate; \
export ${ENV_VAR}=$DATA_DIR; \
cd $REPO_DIR; \
python scripts/export_captions.py --variant $VARIANT --out $OUT_DIR/shard_$i --start $START --end $END"
python scripts/export_captions.py --variant $VARIANT --out $OUT_DIR/shard_$i --start $START --end $END $SPLIT_ARG"
done

echo "Check: squeue -u \$USER"
Loading
Loading