Use MHC-benchmark canonical splits across MHC datasets #69

max-rosenblattl merged 5 commits into main
Conversation
Replaces the SHA-1 hash-based 80/10/10 bucketing with explicit lookups against data/splits/sharable_users_seed42_2026.json, the canonical user-level split shared with MHC-benchmark. Adds a split_file parameter to MHCDataset and WeeklyMHCDataset that drops non-sharable users at load time, before any caption gets written. Threads the same split_file through all QA datasets, the multi-label fine-tuning script, both linear probes, the metabolic eval runner, and the curriculum-learning trainer. This guarantees: (1) non-sharable users never enter caption / QA / probe artifacts, (2) any user labeled "test" by MHC-benchmark is also "test" here, removing same-user-in-different-splits leakage across repos.
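The lookup-and-filter behavior described above can be sketched in plain Python. This is an illustrative sketch only, not the repo's actual code: `load_user_to_bucket` and `filter_rows` are hypothetical names, and the JSON shape `{"train": [...], "validation": [...], "test": [...]}` is the one used elsewhere in this PR.

```python
import json

def load_user_to_bucket(split_file):
    # Invert {"train": [...], "validation": [...], "test": [...]} into a
    # per-user lookup; users absent from the file are non-sharable.
    with open(split_file) as f:
        splits = json.load(f)
    return {
        uid: bucket
        for bucket in ("train", "validation", "test")
        for uid in splits[bucket]
    }

def filter_rows(rows, user_to_bucket, split):
    # Non-sharable users have no bucket, so .get() returns None and the row
    # is dropped; sharable users are kept only in their assigned split.
    return [r for r in rows if user_to_bucket.get(r["user_id"]) == split]
```

Because the assignment is a lookup rather than a hash, any repo loading the same JSON agrees on every user's bucket by construction.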
Caution: Review failed. The pull request is closed.
📝 Walkthrough

The PR threads a canonical JSON split specification through dataset constructors and scripts (replacing SHA-1 hash bucketing) and requires MHC_SPLIT_FILE in some entrypoints. Separately, the evaluation pipeline gains optional incremental checkpointing: Evaluator.run can flush partial EvalResult snapshots periodically (and on interrupt) to a reusable out_dir.

Changes

- Split-file integration (datasets, scripts, curriculum)
- Evaluation incremental checkpointing
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Evaluator
    participant ResultsWriter
    participant Filesystem
    User->>Evaluator: run(task, model, dataset, save_every, results_writer, out_dir)
    Note over Evaluator: pre-create out_dir
    loop per sample
        Evaluator->>Evaluator: evaluate sample -> append per_sample
        alt sample_count % save_every == 0
            Evaluator->>Evaluator: build partial EvalResult
            Evaluator->>ResultsWriter: write(partial, out_dir)
            ResultsWriter->>Filesystem: write metrics.json & predictions.csv
            Filesystem-->>ResultsWriter: files persisted
        end
    end
    alt KeyboardInterrupt
        User->>Evaluator: Ctrl+C
        Evaluator->>Evaluator: build partial EvalResult
        Evaluator->>ResultsWriter: write(partial, out_dir)
    else Normal completion
        Evaluator->>Evaluator: build final EvalResult
        Evaluator->>ResultsWriter: write(final, out_dir)
    end
    ResultsWriter-->>Evaluator: path returned
    Evaluator-->>User: return EvalResult
```
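The diagram's control flow can be sketched in plain Python. This is a minimal sketch of the pattern only; the actual `Evaluator.run` and `ResultsWriter` signatures in this repo will differ, and `run_with_checkpoints`, `evaluate`, and `write_partial` are hypothetical names.

```python
def run_with_checkpoints(samples, evaluate, write_partial, save_every=100):
    """Evaluate samples, flushing partial results every `save_every` steps
    and on KeyboardInterrupt, mirroring the sequence diagram above."""
    per_sample = []
    try:
        for i, sample in enumerate(samples, start=1):
            per_sample.append(evaluate(sample))
            if save_every and i % save_every == 0:
                write_partial(per_sample)  # periodic partial snapshot
    except KeyboardInterrupt:
        write_partial(per_sample)          # persist progress on Ctrl+C
        raise
    write_partial(per_sample)              # final write on normal completion
    return per_sample
```

Writing into the same out_dir each time is what makes the snapshot reusable: the final write simply overwrites the last partial one.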
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
scripts/linear_probe.py (1)
263-265: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Pass `args.split_file` on the normal finalize path.

This call is now missing `_finalize(...)`'s new `split_file` argument, so a successful extraction run will crash with `TypeError` before writing `embeddings.npz`. The interrupt and partial-resume paths already pass it.

🐛 Proposed fix

```diff
 _finalize(
-    out_path, embeddings, user_ids, dates, completed, args.embedding_source, args.label_lookup
+    out_path, embeddings, user_ids, dates, completed,
+    args.embedding_source, args.label_lookup, args.split_file,
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@scripts/linear_probe.py` around lines 263 - 265, The normal successful path calls _finalize(...) without the new split_file parameter causing a TypeError; update the call in scripts/linear_probe.py to pass args.split_file (i.e., call _finalize(out_path, embeddings, user_ids, dates, completed, args.embedding_source, args.label_lookup, args.split_file)) so it matches the interrupt/partial-resume paths and the _finalize function signature.
🧹 Nitpick comments (3)
time_series_datasets/mhc_base_qa_dataset.py (1)
30-45: 💤 Low value

Consider validating JSON structure for clearer error messages.

If the split JSON is malformed (e.g., missing the `"validation"` key), this will raise a generic `KeyError`. A brief validation could improve debuggability.

💡 Optional: Add key validation

```diff
 def _load_split_assignment(
     split_file: str | Path,
 ) -> dict[str, Literal["train", "validation", "test"]]:
     ...
     with open(split_file) as f:
         splits = json.load(f)
+    required = {"train", "validation", "test"}
+    missing = required - splits.keys()
+    if missing:
+        raise ValueError(f"split_file missing required keys: {missing}")
     user_to_bucket: dict[str, Literal["train", "validation", "test"]] = {}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@time_series_datasets/mhc_base_qa_dataset.py` around lines 30 - 45, The _load_split_assignment function currently assumes the loaded JSON has "train", "validation", and "test" keys and will raise a generic KeyError if any are missing; update it to validate the structure after json.load by checking that splits is a dict and contains the three required keys (and that each value is a sequence), and if not raise a clear ValueError or TypeError with a message like "split file missing keys: ..." mentioning which keys are absent; keep the existing user_to_bucket filling logic unchanged but only run it after the validation so callers get a helpful error instead of a raw KeyError.

evaluation/evaluator.py (1)
26-28: 💤 Low value

Global logging level change may have unintended side effects.

Calling `hf_logging.set_verbosity_error()` at module load time affects all HuggingFace transformers logging across the entire application. If other parts of the codebase rely on warning-level logs for debugging, they'll be silenced. Consider scoping this more narrowly or making it configurable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@evaluation/evaluator.py` around lines 26 - 28, hf_logging.set_verbosity_error() is being called at module import which globally silences all HuggingFace transformer logs; instead limit scope by either making the change configurable or targeting only the specific logger(s). Replace the global call with a scoped approach: read a config/env flag or add an optional parameter, and set the level on the specific logger name(s) (e.g., the "transformers" or "transformers.generation" logger) or apply a temporary context around the code that produces the pad_token_id warning; update references in evaluator.py where hf_logging.set_verbosity_error is used so only the intended logs are suppressed.

mhc_weekly/dataset.py (1)
31-37: 💤 Low value

Consider using PyArrow filtering for consistency with `mhc/dataset.py`.

The lambda-based `filter()` works correctly but differs from the vectorized `pc.is_in()` approach used in `mhc/dataset.py` (lines 41-46). For large datasets, the PyArrow approach is more efficient.

💡 Optional: Use PyArrow filtering for consistency

```diff
+import pyarrow as pa
+import pyarrow.compute as pc
 ...
 if split_file is not None:
     with open(split_file) as f:
         splits = json.load(f)
     allowed_users = (
         set(splits["train"]) | set(splits["validation"]) | set(splits["test"])
     )
-    self.ds = self.ds.filter(lambda row: row["user_id"] in allowed_users)
+    mask = pc.is_in(
+        self.ds.data.column("user_id"),
+        value_set=pa.array(sorted(allowed_users)),
+    )
+    indices = [i for i, m in enumerate(mask.to_pylist()) if m]
+    self.ds = self.ds.select(indices)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@mhc_weekly/dataset.py` around lines 31 - 37, Replace the Python-level lambda filter over self.ds with a PyArrow vectorized filter like in mhc/dataset.py: import pyarrow.compute as pc and build a PyArrow array of allowed_users (pa.array(list(allowed_users))) then call self.ds = self.ds.filter(pc.is_in(self.ds["user_id"], allowed_users_pa)) so filtering uses pc.is_in instead of lambda; reference symbols: self.ds, split_file, allowed_users, pc.is_in.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@scripts/linear_probe_vlm.py`:
- Around line 217-225: The code currently maps missing users to "unknown" via
user_to_bucket.get(...) and leaves them in the output; change _finalize so that
after computing splits (using _load_split_assignment) you filter out any indices
where user_to_bucket has no entry (i.e., where splits == "unknown") and drop
those rows from embeddings, user_ids, week_starts, and the joined labels_df
before saving; reference the symbols _finalize, _load_split_assignment,
embeddings, user_ids, week_starts, labels_df and splits to locate the change.
Ideally also apply the same filtering earlier when loading ds, but minimally
ensure this block removes non-sharable users rather than writing them as
"unknown".
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: b5de683f-a0a9-4c7c-9744-22d7450414d2
📒 Files selected for processing (14)
- curriculum_learning.py
- evaluation/evaluator.py
- evaluation/io/writer.py
- evaluation/run_eval.py
- mhc/dataset.py
- mhc_weekly/dataset.py
- scripts/export_captions.py
- scripts/export_captions_sharded.sh
- scripts/linear_probe.py
- scripts/linear_probe_vlm.py
- scripts/metabolic_finetune.py
- time_series_datasets/mhc_base_qa_dataset.py
- time_series_datasets/mhc_metabolic_qa_dataset.py
- time_series_datasets/mhc_multi_label_qa_dataset.py
```python
def __init__(self, split, EOS_TOKEN, labels, export_dir, label_lookup_path, split_file):
    """Args mirror MHCBaseQADataset; ``split_file`` points at MHC-benchmark's
    ``sharable_users_seed42_2026.json``. Required for canonical splits.
    """
    for label in labels:
        if label not in LABEL_TYPES:
            raise ValueError(f"Unknown label: {label!r}")

    user_to_bucket = _load_split_assignment(split_file)
```
Add explicit validation for the split_file parameter.

Unlike MHCBaseQADataset, which raises a clear error message when split_file is None, this constructor passes split_file directly to _load_split_assignment(). If None is passed, this will fail with an unhelpful TypeError when open(None) is called.
🛡️ Proposed fix: Add validation

```diff
 class MHCMultiLabelQADataset(Dataset):
     def __init__(self, split, EOS_TOKEN, labels, export_dir, label_lookup_path, split_file):
         """Args mirror MHCBaseQADataset; ``split_file`` points at MHC-benchmark's
         ``sharable_users_seed42_2026.json``. Required for canonical splits.
         """
+        if split_file is None:
+            raise ValueError(
+                "split_file is required: pass MHC-benchmark's "
+                "sharable_users_seed42_2026.json so split assignment is consistent."
+            )
         for label in labels:
             if label not in LABEL_TYPES:
                 raise ValueError(f"Unknown label: {label!r}")
         user_to_bucket = _load_split_assignment(split_file)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def __init__(self, split, EOS_TOKEN, labels, export_dir, label_lookup_path, split_file):
    """Args mirror MHCBaseQADataset; ``split_file`` points at MHC-benchmark's
    ``sharable_users_seed42_2026.json``. Required for canonical splits.
    """
    if split_file is None:
        raise ValueError(
            "split_file is required: pass MHC-benchmark's "
            "sharable_users_seed42_2026.json so split assignment is consistent."
        )
    for label in labels:
        if label not in LABEL_TYPES:
            raise ValueError(f"Unknown label: {label!r}")
    user_to_bucket = _load_split_assignment(split_file)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@time_series_datasets/mhc_multi_label_qa_dataset.py` around lines 91 - 99, The
__init__ of MHCMultiLabelQADataset currently passes split_file directly to
_load_split_assignment which will raise an unhelpful TypeError if split_file is
None; add explicit validation at the start of MHCMultiLabelQADataset.__init__
(similar to MHCBaseQADataset) that raises a clear ValueError when split_file is
None or empty, and only call _load_split_assignment(split_file) after this check
so _load_split_assignment always receives a valid path.
♻️ Current situation & Problem
SensorTSLM was assigning users to train / validation / test via SHA-1 hashing of the user_id (80/10/10). MHC-benchmark uses an explicit canonical split file. The two repos disagree on which split each user lives in, so a user marked "test" in MHC-benchmark could end up in "train" here — same-user leakage when downstream evaluations cross-reference. SensorTSLM also had no notion of non-sharable users, so they could leak into caption / QA / probe artifacts.
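The old hash-based assignment can be sketched as follows. This is a reconstruction of the general 80/10/10 SHA-1 pattern described above, not SensorTSLM's exact code; the digest slice and thresholds are assumptions.

```python
import hashlib

def hash_bucket(user_id: str) -> str:
    # Deterministic 80/10/10 assignment from the SHA-1 digest of the user id.
    # Any repo hashing differently (different salt, digest slice, or modulus)
    # will disagree on bucket boundaries, which is the cross-repo leakage risk.
    h = int(hashlib.sha1(user_id.encode("utf-8")).hexdigest(), 16) % 100
    if h < 80:
        return "train"
    if h < 90:
        return "validation"
    return "test"
```

The scheme is deterministic within one repo, but nothing ties it to MHC-benchmark's explicit split file, and it has no concept of a non-sharable user, which is why the canonical lookup replaces it.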
⚙️ Release Notes

- `sharable_users_seed42_2026.json`: canonical user-level split shared with MHC-benchmark
- New `split_file` parameter on `MHCDataset` and `WeeklyMHCDataset` that drops non-sharable users at load time, before any caption / QA shard is written
- `split_file` threaded through all QA datasets, both linear probes, the multi-label fine-tune script, the metabolic eval runner, and the curriculum trainer
- `--split_file` threaded through caption export entry points

📚 Documentation
With `split_file` provided, the union of `train` + `validation` + `test` user IDs becomes a hard filter at dataset load. The `split_file` has the same shape as MHC-benchmark's, so a single shared JSON drives both repos. Pass `None` to retain the previous SHA-1 behavior, which is useful for legacy pipelines.

✅ Testing
- `MHCDataset` and `WeeklyMHCDataset` loaded with and without `split_file`; row counts match the union of allowed user IDs
- `--split_file` flag verified end-to-end on a small shard

Code of Conduct & Contributing Guidelines
By creating and submitting this pull request, you agree to follow our Code of Conduct and Contributing Guidelines: