
Use MHC-benchmark canonical splits across MHC datasets #69

Merged
max-rosenblattl merged 5 commits into main from max-rosenblattl/canonical-splits
May 5, 2026

Conversation

@max-rosenblattl
Collaborator

♻️ Current situation & Problem

SensorTSLM was assigning users to train / validation / test via SHA-1 hashing of the user_id (80/10/10). MHC-benchmark uses an explicit canonical split file. The two repos disagree on which split each user lives in, so a user marked "test" in MHC-benchmark could end up in "train" here — same-user leakage when downstream evaluations cross-reference. SensorTSLM also had no notion of non-sharable users, so they could leak into caption / QA / probe artifacts.
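For context, the replaced scheme looked roughly like the sketch below; the exact hash-to-fraction mapping and cut points are illustrative assumptions, not the repo's actual code:

```python
import hashlib

def sha1_bucket(user_id: str) -> str:
    """Deterministic 80/10/10 bucketing from a SHA-1 digest (illustrative)."""
    digest = int(hashlib.sha1(user_id.encode()).hexdigest(), 16)
    frac = (digest % 10_000) / 10_000  # pseudo-uniform in [0, 1)
    if frac < 0.80:
        return "train"
    if frac < 0.90:
        return "validation"
    return "test"

bucket = sha1_bucket("some-user-id")
```

Two repos running something like this independently only agree if they share the exact hash and cut points, and a user flagged non-sharable still gets a bucket, which is the leak described above.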

⚙️ Release Notes

  • Replace SHA-1 user bucketing with explicit lookups against MHC-benchmark's sharable_users_seed42_2026.json
  • Add a split_file parameter to MHCDataset and WeeklyMHCDataset that drops non-sharable users at load time, before any caption / QA shard is written
  • Thread split_file through all QA datasets, both linear probes, the multi-label fine-tune script, the metabolic eval runner, and the curriculum trainer
  • Wire --split_file through caption export entry points

📚 Documentation

With split_file provided, the union of train + validation + test user IDs becomes a hard filter at dataset load. Same split_file shape as MHC-benchmark, so a single shared JSON drives both repos. Pass None to retain the previous SHA-1 behavior — useful for legacy pipelines.
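A minimal sketch of that filter, assuming only the split-file shape stated above (top-level train / validation / test lists of user IDs); `load_allowed_users` is a hypothetical name, not the repo's actual function:

```python
import json
import os
import tempfile

def load_allowed_users(split_file):
    """Union of train / validation / test user IDs from the canonical split JSON."""
    with open(split_file) as f:
        splits = json.load(f)
    return set(splits["train"]) | set(splits["validation"]) | set(splits["test"])

# Tiny stand-in for sharable_users_seed42_2026.json with the same top-level shape.
fd, path = tempfile.mkstemp(suffix=".json")
with os.fdopen(fd, "w") as f:
    json.dump({"train": ["u1"], "validation": ["u2"], "test": ["u3"]}, f)

allowed = load_allowed_users(path)
rows = [{"user_id": "u1"}, {"user_id": "u9"}]  # u9 is non-sharable
kept = [r for r in rows if r["user_id"] in allowed]  # hard filter at load time
os.remove(path)
```

Filtering on the union at load time is what keeps non-sharable users out of every downstream caption / QA / probe artifact.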

✅ Testing

  • Loaded both MHCDataset and WeeklyMHCDataset with and without split_file; row counts match the union of allowed user IDs
  • Verified that no test-set user appears in train across the QA dataset, linear probes, and metabolic finetune paths
  • Caption export --split_file flag verified end-to-end on a small shard

Code of Conduct & Contributing Guidelines

By creating and submitting this pull request, you agree to follow our Code of Conduct and Contributing Guidelines:

Replaces the SHA-1 hash-based 80/10/10 bucketing with explicit lookups against
data/splits/sharable_users_seed42_2026.json, the canonical user-level split
shared with MHC-benchmark. Adds a split_file parameter to MHCDataset and
WeeklyMHCDataset that drops non-sharable users at load time, before any
caption gets written. Threads the same split_file through all QA datasets,
the multi-label fine-tuning script, both linear probes, the metabolic eval
runner, and the curriculum-learning trainer.

This guarantees: (1) non-sharable users never enter caption / QA / probe
artifacts, (2) any user labeled "test" by MHC-benchmark is also "test"
here, removing same-user-in-different-splits leakage across repos.
@coderabbitai
Contributor

coderabbitai Bot commented May 1, 2026

Caution

Review failed

The pull request is closed.

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: ddf296da-e337-4502-846c-bca592598abf

📥 Commits

Reviewing files that changed from the base of the PR and between c43d95a and 036f91b.

📒 Files selected for processing (8)
  • evaluation/evaluator.py
  • evaluation/io/writer.py
  • evaluation/run_eval.py
  • mhc/dataset.py
  • mhc_weekly/dataset.py
  • scripts/linear_probe.py
  • scripts/linear_probe_vlm.py
  • time_series_datasets/mhc_base_qa_dataset.py

📝 Walkthrough

The PR threads a canonical JSON split specification through dataset constructors and scripts (replacing SHA1/hash bucketing) and requires MHC_SPLIT_FILE in some entrypoints. Separately, the evaluation pipeline gains optional incremental checkpointing: Evaluator.run can flush partial EvalResult snapshots periodically (and on interrupt) to a reusable out_dir.

Changes

Split-file integration (datasets, scripts, curriculum)

| Layer | File(s) | Summary |
| --- | --- | --- |
| Data shape / config | `time_series_datasets/mhc_base_qa_dataset.py`, `mhc/dataset.py`, `mhc_weekly/dataset.py` | Add an optional/required `split_file` parameter; load the canonical JSON mapping and derive allowed user IDs or a user-to-bucket mapping. |
| Core dataset behavior | `time_series_datasets/mhc_multi_label_qa_dataset.py`, `time_series_datasets/mhc_metabolic_qa_dataset.py` | Constructors now accept and forward `split_file` into the base dataset logic; the base class performs lookup-based split assignment and drops users missing from the mapping. |
| Scripts / wiring | `scripts/export_captions.py`, `scripts/linear_probe.py`, `scripts/linear_probe_vlm.py`, `scripts/metabolic_finetune.py`, `scripts/export_captions_sharded.sh` | Add a `--split_file` CLI arg (required in some scripts); thread `split_file` into dataset constructors and replace inline SHA-1/fraction bucketing with the canonical split lookup. |
| Entrypoint config | `curriculum_learning.py` | Require the `MHC_SPLIT_FILE` env var at startup and thread its value into both captioning and metabolic dataset kwargs for consistent split selection. |
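The fail-fast env check could look like this sketch; `require_split_file` is a hypothetical helper, not the repo's actual function:

```python
import os

def require_split_file(env=None):
    """Fail fast at startup when the canonical split file is not configured."""
    env = os.environ if env is None else env
    split_file = env.get("MHC_SPLIT_FILE")
    if not split_file:
        raise RuntimeError(
            "MHC_SPLIT_FILE must point at the canonical split JSON, e.g. "
            "data/splits/sharable_users_seed42_2026.json"
        )
    return split_file

# Threaded into both captioning and metabolic dataset kwargs.
dataset_kwargs = {"split_file": require_split_file({"MHC_SPLIT_FILE": "splits.json"})}
```

Failing at entrypoint startup, rather than deep inside a dataset constructor, surfaces a misconfigured run before any training time is spent.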

Evaluation incremental checkpointing

| Layer | File(s) | Summary |
| --- | --- | --- |
| API / signature | `evaluation/evaluator.py`, `evaluation/io/writer.py`, `evaluation/run_eval.py` | Extend `Evaluator.run(...)` with `save_every`, `results_writer`, and `out_dir` parameters for periodic checkpointing. |
| Core control flow | `evaluation/evaluator.py` | Introduce `_build_partial()` to compute a partial `EvalResult` from `per_sample` and `_flush()` to call `results_writer.write(..., out_dir)`; flush partial results every `save_every` samples and on `KeyboardInterrupt`; the final return uses `_build_partial()`. Also reduce transformers logging verbosity to error. |
| I/O behavior | `evaluation/io/writer.py`, `evaluation/run_eval.py` | When `out_dir` is provided, `ResultsWriter.write` writes/overwrites `metrics.json` and `predictions.csv` into it (created with `exist_ok=True`); `run_eval` now creates and reuses a timestamped `out_dir` for periodic and final writes. |
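A sketch of that writer behavior; the file names and `exist_ok=True` come from the summary above, while the result-dict shape and `write_results` name are assumptions:

```python
import csv
import json
import os
import tempfile

def write_results(result, out_dir):
    """Overwrite metrics.json and predictions.csv in a reusable out_dir."""
    os.makedirs(out_dir, exist_ok=True)  # safe to call again on every periodic flush
    with open(os.path.join(out_dir, "metrics.json"), "w") as f:
        json.dump(result["metrics"], f)
    with open(os.path.join(out_dir, "predictions.csv"), "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["sample_id", "prediction"])
        writer.writerows(result["predictions"])
    return out_dir

out = write_results(
    {"metrics": {"accuracy": 0.75}, "predictions": [["s1", 1], ["s2", 0]]},
    tempfile.mkdtemp(),
)
```

Overwriting in place keeps the directory layout stable, so the latest flush is always the authoritative snapshot.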

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Evaluator
    participant ResultsWriter
    participant Filesystem

    User->>Evaluator: run(task, model, dataset, save_every, results_writer, out_dir)
    Note over Evaluator: pre-create out_dir

    loop per sample
        Evaluator->>Evaluator: evaluate sample -> append per_sample
        alt sample_count % save_every == 0
            Evaluator->>Evaluator: build partial EvalResult
            Evaluator->>ResultsWriter: write(partial, out_dir)
            ResultsWriter->>Filesystem: write metrics.json & predictions.csv
            Filesystem-->>ResultsWriter: files persisted
        end
    end

    alt KeyboardInterrupt
        User->>Evaluator: Ctrl+C
        Evaluator->>Evaluator: build partial EvalResult
        Evaluator->>ResultsWriter: write(partial, out_dir)
    else Normal completion
        Evaluator->>Evaluator: build final EvalResult
        Evaluator->>ResultsWriter: write(final, out_dir)
    end

    ResultsWriter-->>Evaluator: path returned
    Evaluator-->>User: return EvalResult
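The same flow as a runnable sketch; class and method names follow the walkthrough, while the per-sample scoring and the result shape are stand-ins, not the repo's actual implementation:

```python
# Names mirror the walkthrough (_build_partial, _flush, save_every); the scoring
# and EvalResult shape are simplified stand-ins.
class Evaluator:
    def __init__(self, writer):
        self.writer = writer          # stand-in for results_writer.write
        self.per_sample = []

    def _build_partial(self):
        n = len(self.per_sample)
        return {"n": n, "accuracy": sum(self.per_sample) / n if n else 0.0}

    def _flush(self, out_dir):
        self.writer(self._build_partial(), out_dir)

    def run(self, samples, save_every, out_dir):
        try:
            for i, score in enumerate(samples, 1):
                self.per_sample.append(score)   # stand-in for evaluating a sample
                if save_every and i % save_every == 0:
                    self._flush(out_dir)        # periodic partial snapshot
        except KeyboardInterrupt:
            self._flush(out_dir)                # persist progress on Ctrl+C
            raise
        self._flush(out_dir)                    # final write reuses the same dir
        return self._build_partial()

written = []
ev = Evaluator(lambda result, out_dir: written.append((out_dir, result)))
final = ev.run([1, 0, 1, 1], save_every=2, out_dir="runs/demo")
```

Because every flush targets the same `out_dir`, an interrupted run leaves the latest partial metrics behind instead of losing all progress.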

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

  • StanfordBDHG/SensorTSLM#55: Related changes to evaluation pipeline and checkpointing wiring.
  • PR #50: Related edits to curriculum_learning.py enforcing and threading MHC_SPLIT_FILE.
  • PR #40: Prior work introducing WeeklyMHCDataset which this PR extends with split-file filtering.

Suggested reviewers

  • milanagm
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 22.22%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately and concisely describes the main change: replacing SHA-1 bucketing with MHC-benchmark canonical splits across all dataset classes. |
| Description check | ✅ Passed | The description is directly related to the changeset, providing clear context about the problem (user leakage), the solution (canonical split files), and implementation details across multiple components. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |




@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
scripts/linear_probe.py (1)

263-265: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Pass args.split_file on the normal finalize path.

This call is now missing _finalize(...)'s new split_file argument, so a successful extraction run will crash with TypeError before writing embeddings.npz. The interrupt and partial-resume paths already pass it.

🐛 Proposed fix
     _finalize(
-        out_path, embeddings, user_ids, dates, completed, args.embedding_source, args.label_lookup
+        out_path, embeddings, user_ids, dates, completed,
+        args.embedding_source, args.label_lookup, args.split_file,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/linear_probe.py` around lines 263 - 265, The normal successful path
calls _finalize(...) without the new split_file parameter causing a TypeError;
update the call in scripts/linear_probe.py to pass args.split_file (i.e., call
_finalize(out_path, embeddings, user_ids, dates, completed,
args.embedding_source, args.label_lookup, args.split_file)) so it matches the
interrupt/partial-resume paths and the _finalize function signature.
🧹 Nitpick comments (3)
time_series_datasets/mhc_base_qa_dataset.py (1)

30-45: 💤 Low value

Consider validating JSON structure for clearer error messages.

If the split JSON is malformed (e.g., missing "validation" key), this will raise a generic KeyError. A brief validation could improve debuggability.

💡 Optional: Add key validation
 def _load_split_assignment(
     split_file: str | Path,
 ) -> dict[str, Literal["train", "validation", "test"]]:
     ...
     with open(split_file) as f:
         splits = json.load(f)
+    required = {"train", "validation", "test"}
+    missing = required - splits.keys()
+    if missing:
+        raise ValueError(f"split_file missing required keys: {missing}")
     user_to_bucket: dict[str, Literal["train", "validation", "test"]] = {}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@time_series_datasets/mhc_base_qa_dataset.py` around lines 30 - 45, The
_load_split_assignment function currently assumes the loaded JSON has "train",
"validation", and "test" keys and will raise a generic KeyError if any are
missing; update it to validate the structure after json.load(splits) by checking
that splits is a dict and contains the three required keys (and that each value
is a sequence), and if not raise a clear ValueError or TypeError with a message
like "split file missing keys: ..." mentioning which keys are absent; keep the
existing user_to_bucket filling logic unchanged but only run it after the
validation so callers get a helpful error instead of a raw KeyError.
evaluation/evaluator.py (1)

26-28: 💤 Low value

Global logging level change may have unintended side effects.

Setting hf_logging.set_verbosity_error() at module load time affects all HuggingFace transformers logging across the entire application. If other parts of the codebase rely on warning-level logs for debugging, they'll be silenced.

Consider scoping this more narrowly or making it configurable.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@evaluation/evaluator.py` around lines 26 - 28,
hf_logging.set_verbosity_error() is being called at module import which globally
silences all HuggingFace transformer logs; instead limit scope by either making
the change configurable or targeting only the specific logger(s). Replace the
global call with a scoped approach: read a config/env flag or add an optional
parameter, and set the level on the specific logger name(s) (e.g., the
"transformers" or "transformers.generation" logger) or apply a temporary context
around the code that produces the pad_token_id warning; update references in
evaluator.py where hf_logging.set_verbosity_error is used so only the intended
logs are suppressed.
mhc_weekly/dataset.py (1)

31-37: 💤 Low value

Consider using PyArrow filtering for consistency with mhc/dataset.py.

The lambda-based filter() works correctly but differs from the vectorized pc.is_in() approach used in mhc/dataset.py (lines 41-46). For large datasets, the PyArrow approach is more efficient.

💡 Optional: Use PyArrow filtering for consistency
+import pyarrow as pa
+import pyarrow.compute as pc
 ...
         if split_file is not None:
             with open(split_file) as f:
                 splits = json.load(f)
             allowed_users = (
                 set(splits["train"]) | set(splits["validation"]) | set(splits["test"])
             )
-            self.ds = self.ds.filter(lambda row: row["user_id"] in allowed_users)
+            mask = pc.is_in(
+                self.ds.data.column("user_id"),
+                value_set=pa.array(sorted(allowed_users)),
+            )
+            indices = [i for i, m in enumerate(mask.to_pylist()) if m]
+            self.ds = self.ds.select(indices)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@mhc_weekly/dataset.py` around lines 31 - 37, Replace the Python-level lambda
filter over self.ds with a PyArrow vectorized filter like in mhc/dataset.py:
import pyarrow.compute as pc and build a PyArrow array of allowed_users
(pa.array(list(allowed_users))) then call self.ds =
self.ds.filter(pc.is_in(self.ds["user_id"], allowed_users_pa)) so filtering uses
pc.is_in instead of lambda; reference symbols: self.ds, split_file,
allowed_users, pc.is_in.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@scripts/linear_probe_vlm.py`:
- Around line 217-225: The code currently maps missing users to "unknown" via
user_to_bucket.get(...) and leaves them in the output; change _finalize so that
after computing splits (using _load_split_assignment) you filter out any indices
where user_to_bucket has no entry (i.e., where splits == "unknown") and drop
those rows from embeddings, user_ids, week_starts, and the joined labels_df
before saving; reference the symbols _finalize, _load_split_assignment,
embeddings, user_ids, week_starts, labels_df and splits to locate the change.
Ideally also apply the same filtering earlier when loading ds, but minimally
ensure this block removes non-sharable users rather than writing them as
"unknown".

In `@time_series_datasets/mhc_multi_label_qa_dataset.py`:
- Around line 91-99: The __init__ of MHCMultiLabelQADataset currently passes
split_file directly to _load_split_assignment which will raise an unhelpful
TypeError if split_file is None; add explicit validation at the start of
MHCMultiLabelQADataset.__init__ (similar to MHCBaseQADataset) that raises a
clear ValueError when split_file is None or empty, and only call
_load_split_assignment(split_file) after this check so _load_split_assignment
always receives a valid path.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: b5de683f-a0a9-4c7c-9744-22d7450414d2

📥 Commits

Reviewing files that changed from the base of the PR and between ee1e3fe and c43d95a.

📒 Files selected for processing (14)
  • curriculum_learning.py
  • evaluation/evaluator.py
  • evaluation/io/writer.py
  • evaluation/run_eval.py
  • mhc/dataset.py
  • mhc_weekly/dataset.py
  • scripts/export_captions.py
  • scripts/export_captions_sharded.sh
  • scripts/linear_probe.py
  • scripts/linear_probe_vlm.py
  • scripts/metabolic_finetune.py
  • time_series_datasets/mhc_base_qa_dataset.py
  • time_series_datasets/mhc_metabolic_qa_dataset.py
  • time_series_datasets/mhc_multi_label_qa_dataset.py

Comment thread: time_series_datasets/mhc_multi_label_qa_dataset.py (Outdated)
Comment on lines +91 to +99
def __init__(self, split, EOS_TOKEN, labels, export_dir, label_lookup_path, split_file):
    """Args mirror MHCBaseQADataset; ``split_file`` points at MHC-benchmark's
    ``sharable_users_seed42_2026.json``. Required for canonical splits.
    """
    for label in labels:
        if label not in LABEL_TYPES:
            raise ValueError(f"Unknown label: {label!r}")

    user_to_bucket = _load_split_assignment(split_file)

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add explicit validation for split_file parameter.

Unlike MHCBaseQADataset, which rejects a None split_file with a clear error message, this constructor passes split_file directly to _load_split_assignment(). If None is passed, it fails with an unhelpful TypeError when open(None) is called.

🛡️ Proposed fix: Add validation
 class MHCMultiLabelQADataset(Dataset):
     def __init__(self, split, EOS_TOKEN, labels, export_dir, label_lookup_path, split_file):
         """Args mirror MHCBaseQADataset; ``split_file`` points at MHC-benchmark's
         ``sharable_users_seed42_2026.json``. Required for canonical splits.
         """
+        if split_file is None:
+            raise ValueError(
+                "split_file is required: pass MHC-benchmark's "
+                "sharable_users_seed42_2026.json so split assignment is consistent."
+            )
         for label in labels:
             if label not in LABEL_TYPES:
                 raise ValueError(f"Unknown label: {label!r}")

         user_to_bucket = _load_split_assignment(split_file)

@max-rosenblattl max-rosenblattl merged commit 345da28 into main May 5, 2026
2 of 3 checks passed
@max-rosenblattl max-rosenblattl deleted the max-rosenblattl/canonical-splits branch May 5, 2026 02:03
