Add curriculum learning training pipeline #50
Conversation
📝 Walkthrough
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant CLI as main()
    participant Trainer as CurriculumTrainer
    participant Model as Model<br/>(OpenTSLMSP / Flamingo)
    participant Dataset as Dataset<br/>& DataLoader
    participant Optimizer as AdamW<br/>+ Scheduler
    participant Storage as Checkpoint / Results
    User->>CLI: invoke curriculum run (stages, options)
    CLI->>Trainer: init with StageConfig list
    loop per stage
        Trainer->>Storage: load initial or previous best checkpoint
        Storage-->>Trainer: checkpoint data
        Trainer->>Model: construct/load weights, enable LoRA/grad-ckpt, set device
        Trainer->>Dataset: build train/val/test loaders
        alt eval_only
            Trainer->>Model: generate on test set
            Model-->>Trainer: predictions
        else training
            loop epochs
                loop batches
                    Dataset-->>Model: batch
                    Model->>Model: forward -> loss
                    Model->>Optimizer: backward
                    Optimizer->>Model: clip grads and step
                    Optimizer->>Optimizer: scheduler step
                end
                Trainer->>Dataset: validate -> val loss
                alt best val improved
                    Trainer->>Storage: save best checkpoint
                end
                Trainer->>Storage: optionally save latest checkpoint
            end
            Trainer->>Storage: load best checkpoint
            Trainer->>Model: apply best state
            Trainer->>Model: generate on test set
            Model-->>Trainer: predictions
        end
        Trainer->>Storage: write test_predictions.jsonl and metrics.json
    end
    Trainer->>Storage: aggregate curriculum_results.json
    Trainer-->>User: return aggregated results
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@curriculum_learning.py`:
- Around line 138-149: The _is_stage_completed helper currently returns True
based only on metrics and predictions, causing stages to be considered complete
even if the checkpoint (e.g., checkpoints/best_model.pt) is missing; update
_is_stage_completed to also verify the expected checkpoint file exists before
returning True. Locate the _is_stage_completed method and augment its final
return condition to include a check that the checkpoint path (the same path used
by _load_previous_stage_model, e.g., self._checkpoint_path(stage_name) or the
specific "checkpoints/best_model.pt" for that stage) exists; ensure you call
self._predictions_path(stage_name).exists() and the checkpoint exists in the
same boolean expression so the stage is only marked complete when metrics,
predictions, and the checkpoint are all present.
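A minimal stdlib-only sketch of the strengthened check. The stage-directory layout and the standalone function name here are assumptions based on the paths the comment mentions, not the trainer's actual method:

```python
from pathlib import Path


def is_stage_completed(stage_dir: Path) -> bool:
    """A stage counts as complete only when metrics, predictions,
    AND the best checkpoint are all present on disk."""
    metrics = stage_dir / "metrics.json"
    predictions = stage_dir / "test_predictions.jsonl"
    checkpoint = stage_dir / "checkpoints" / "best_model.pt"
    return metrics.exists() and predictions.exists() and checkpoint.exists()
```

With this shape, a stage whose checkpoint was deleted (but whose metrics survived) is re-run instead of being skipped.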
- Around line 361-363: The current sequence creates the model, then loads
previous-stage checkpoint LoRA weights before LoRA adapters exist, preventing
proper LoRA weight restoration for OpenTSLMSP stages; change the call order so
_enable_lora_if_needed(stage_name) runs immediately after self.model =
self._create_model() and before _load_previous_stage_model(stage_name), or
alternatively ensure _load_previous_stage_model defers calling
load_lora_state_from_checkpoint() until after _enable_lora_if_needed has created
adapters; update references to _create_model, _enable_lora_if_needed,
_load_previous_stage_model and load_lora_state_from_checkpoint accordingly.
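As a stdlib-only illustration of why the order matters, here is a toy stand-in; the class and method names mirror the comment's references but are hypothetical, not the real trainer:

```python
class TrainerSketch:
    """Toy model of stage setup: LoRA adapters must exist before a
    previous stage's LoRA state can be restored into them."""

    def __init__(self):
        self.model = None
        self.adapters = None  # stands in for LoRA adapter modules

    def _create_model(self):
        return {"base": "weights"}

    def _enable_lora_if_needed(self, stage_name):
        self.adapters = {}  # adapters now exist and can receive state

    def _load_previous_stage_model(self, stage_name, lora_state):
        if self.adapters is None:
            raise RuntimeError("enable LoRA before loading LoRA state")
        self.adapters.update(lora_state)


trainer = TrainerSketch()
trainer.model = trainer._create_model()
trainer._enable_lora_if_needed("stage2")                       # 1) create adapters
trainer._load_previous_stage_model("stage2", {"lora_A": 1.0})  # 2) then restore
```

Reversing steps 1 and 2 leaves the restored weights with no adapters to land in, which is the bug the review describes.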
- Line 184: Replace unsafe torch.load calls that pass weights_only=False with
weights_only=True to avoid unpickling arbitrary Python objects; specifically
update the checkpoint loads that assign checkpoint = torch.load(path,
map_location="cpu", weights_only=False) (and the analogous calls at the other
two occurrences around lines 207 and 218) so they use weights_only=True,
ensuring only tensors/state-dicts/primitives are deserialized when loading
checkpoints (including external --initial_checkpoint).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: bcd7158c-be47-4004-b589-36a423cac31b
📒 Files selected for processing (2)
.gitignore, curriculum_learning.py
Actionable comments posted: 2
♻️ Duplicate comments (1)
curriculum_learning.py (1)
475-480: ⚠️ Potential issue | 🟠 Major — Don't skip completed stages when --eval_only is requested.
Right now run_curriculum() returns cached metrics before _train_stage(..., eval_only=True) can run, so --eval_only never re-evaluates an already-completed stage or regenerates predictions.
Suggested fix
- if self._is_stage_completed(stage_name):
+ if not eval_only and self._is_stage_completed(stage_name):
      print(f"Skipping completed stage: {stage_name}")
      with open(self._metrics_path(stage_name)) as f:
          results[stage_name] = json.load(f)
      continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@curriculum_learning.py` around lines 475 - 480, The loop in run_curriculum currently skips stages that return True from _is_stage_completed(stage_name), which prevents --eval_only from re-running evaluation; change the logic so that when eval_only (or the equivalent flag on self) is true you do not continue past completed stages—instead call _train_stage(stage_name, eval_only=True) to regenerate predictions/metrics and then load or store the results from _metrics_path(stage_name); keep the existing behavior for non-eval runs (still skip completed stages), and reference run_curriculum, _is_stage_completed, _train_stage, and _metrics_path when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@curriculum_learning.py`:
- Around line 369-381: After constructing train_loader, val_loader, and
test_loader (the DataLoader instances created from config.dataset_class),
validate that each has at least one batch (e.g., check len(list(train_loader)) >
0 or the underlying dataset length) and raise a clear error or fallback behavior
if any split is empty to avoid zero-step schedules and divide-by-zero when
computing averages in functions that use train_loader/val_loader/test_loader
(references: train_loader, val_loader, test_loader, config.dataset_class);
ensure the check runs immediately after those DataLoader constructions and
prevents proceeding to schedule creation or metric averaging when a split is
empty.
- Around line 154-157: _mark_stage_completed currently overwrites the metrics
file and drops existing fields; change it to first read and parse the existing
metrics JSON (using self._metrics_path(stage_name)), merge/update the
"completed": True key into that dict (handling FileNotFoundError or JSON decode
errors by starting with an empty dict), then write the merged dict back to the
same path with json.dump so fields like "best_val_loss" are preserved.
---
Duplicate comments:
In `@curriculum_learning.py`:
- Around line 475-480: The loop in run_curriculum currently skips stages that
return True from _is_stage_completed(stage_name), which prevents --eval_only
from re-running evaluation; change the logic so that when eval_only (or the
equivalent flag on self) is true you do not continue past completed
stages—instead call _train_stage(stage_name, eval_only=True) to regenerate
predictions/metrics and then load or store the results from
_metrics_path(stage_name); keep the existing behavior for non-eval runs (still
skip completed stages), and reference run_curriculum, _is_stage_completed,
_train_stage, and _metrics_path when making the change.
📒 Files selected for processing (1)
curriculum_learning.py
def _mark_stage_completed(self, stage_name: str, metrics: dict):
    metrics["completed"] = True
    with open(self._metrics_path(stage_name), "w") as f:
        json.dump(metrics, f, indent=2)
Preserve existing stage metrics when marking completion.
This helper always rewrites metrics.json from scratch. In the eval-only path that drops previously recorded fields like best_val_loss, so a reevaluation silently destroys training metadata.
Suggested fix
def _mark_stage_completed(self, stage_name: str, metrics: dict):
- metrics["completed"] = True
- with open(self._metrics_path(stage_name), "w") as f:
- json.dump(metrics, f, indent=2)
+ path = self._metrics_path(stage_name)
+ merged_metrics = {}
+ if path.exists():
+ with open(path) as f:
+ merged_metrics = json.load(f)
+ merged_metrics.update(metrics)
+ merged_metrics["completed"] = True
+ with open(path, "w") as f:
+ json.dump(merged_metrics, f, indent=2)
    print(f"Stage {stage_name} marked as completed")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@curriculum_learning.py` around lines 154 - 157, _mark_stage_completed
currently overwrites the metrics file and drops existing fields; change it to
first read and parse the existing metrics JSON (using
self._metrics_path(stage_name)), merge/update the "completed": True key into
that dict (handling FileNotFoundError or JSON decode errors by starting with an
empty dict), then write the merged dict back to the same path with json.dump so
fields like "best_val_loss" are preserved.
train_dataset = config.dataset_class(split="train", **dataset_kwargs)
val_dataset = config.dataset_class(split="validation", **dataset_kwargs)
test_dataset = config.dataset_class(split="test", **dataset_kwargs)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=list
)
val_loader = DataLoader(
    val_dataset, batch_size=1, shuffle=False, collate_fn=list
)
test_loader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=list
)
Guard against empty splits before training/evaluation.
Please validate the loaders right after construction. Empty smoke/export splits make Line 395 create a zero-step schedule, and the later averages at Line 423, Line 435, and Line 330 will divide by zero.
Suggested fix
test_loader = DataLoader(
test_dataset, batch_size=1, shuffle=False, collate_fn=list
)
+ if len(test_loader) == 0:
+ raise ValueError(f"Stage {stage_name} has no test samples")
+ if not eval_only:
+ if len(train_loader) == 0:
+ raise ValueError(f"Stage {stage_name} has no training samples")
+ if len(val_loader) == 0:
+ raise ValueError(f"Stage {stage_name} has no validation samples")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@curriculum_learning.py` around lines 369 - 381, After constructing
train_loader, val_loader, and test_loader (the DataLoader instances created from
config.dataset_class), validate that each has at least one batch (e.g., check
len(list(train_loader)) > 0 or the underlying dataset length) and raise a clear
error or fallback behavior if any split is empty to avoid zero-step schedules
and divide-by-zero when computing averages in functions that use
train_loader/val_loader/test_loader (references: train_loader, val_loader,
test_loader, config.dataset_class); ensure the check runs immediately after
those DataLoader constructions and prevents proceeding to schedule creation or
metric averaging when a split is empty.
Actionable comments posted: 1
♻️ Duplicate comments (2)
curriculum_learning.py (2)
155-158: ⚠️ Potential issue | 🟡 Minor — Preserve prior metrics when marking a stage complete.
This rewrites metrics.json from scratch, so an eval-only rerun drops fields already recorded for the stage, including best_val_loss. Merge into the existing JSON before setting "completed": true.
♻️ Suggested change
 def _mark_stage_completed(self, stage_name: str, metrics: dict):
-    metrics["completed"] = True
-    with open(self._metrics_path(stage_name), "w") as f:
-        json.dump(metrics, f, indent=2)
+    path = self._metrics_path(stage_name)
+    merged_metrics = {}
+    if path.exists():
+        with open(path) as f:
+            merged_metrics = json.load(f)
+    merged_metrics.update(metrics)
+    merged_metrics["completed"] = True
+    with open(path, "w") as f:
+        json.dump(merged_metrics, f, indent=2)
     print(f"Stage {stage_name} marked as completed")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@curriculum_learning.py` around lines 155 - 158, In _mark_stage_completed, avoid overwriting existing metrics.json: read the current JSON from self._metrics_path(stage_name) if it exists, merge its contents with the incoming metrics dict (preserving prior fields such as best_val_loss and any other recorded keys), then set "completed": True on the merged dict and write that merged object back; ensure you still use the same self._metrics_path(stage_name) and handle missing/invalid JSON by falling back to the provided metrics.
380-388: ⚠️ Potential issue | 🟠 Major — Fail fast on empty splits before building the schedule.
An empty training split makes Line 402 create a zero-step scheduler, and empty train/validation/test splits later divide by zero at Line 440, Line 461, and Line 337. Validate the loaders immediately after construction and raise a clear error.
🛡️ Suggested change
 test_loader = DataLoader(
     test_dataset, batch_size=1, shuffle=False, collate_fn=list
 )
+if len(test_loader) == 0:
+    raise ValueError(f"Stage {stage_name} has no test samples")
+if not eval_only:
+    if len(train_loader) == 0:
+        raise ValueError(f"Stage {stage_name} has no training samples")
+    if len(val_loader) == 0:
+        raise ValueError(f"Stage {stage_name} has no validation samples")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@curriculum_learning.py` around lines 380 - 388, After constructing train_loader, val_loader, and test_loader, immediately validate that each contains at least one sample and raise a clear ValueError if any split is empty; specifically check len(train_loader.dataset) (or len(train_loader)) and similarly for val_loader and test_loader before proceeding to build the scheduler or use these loaders so you avoid creating a zero-step scheduler and subsequent divide-by-zero errors when creating the scheduler and computing per-epoch/step metrics. Ensure the error message names the offending loader (train/val/test) and stops execution before scheduler creation or any downstream computations.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@curriculum_learning.py`:
- Around line 501-506: The short-circuit that skips completed stages (the for
loop over stage_names using self._is_stage_completed(stage_name) and reading
self._metrics_path(stage_name)) must honor the eval_only flag: change the logic
so you only skip returning cached metrics when eval_only is False; if eval_only
is True, do not short-circuit completed stages and instead call
self._train_stage(stage_name, ..., eval_only=True) so predictions/metrics are
refreshed from the checkpoint before loading/saving metrics. Ensure references
to _is_stage_completed, _train_stage, and _metrics_path are used to implement
this conditional behavior.
---
Duplicate comments:
In `@curriculum_learning.py`:
- Around line 155-158: In _mark_stage_completed, avoid overwriting existing
metrics.json: read the current JSON from self._metrics_path(stage_name) if it
exists, merge its contents with the incoming metrics dict (preserving prior
fields such as best_val_loss and any other recorded keys), then set "completed":
True on the merged dict and write that merged object back; ensure you still use
the same self._metrics_path(stage_name) and handle missing/invalid JSON by
falling back to the provided metrics.
- Around line 380-388: After constructing train_loader, val_loader, and
test_loader, immediately validate that each contains at least one sample and
raise a clear ValueError if any split is empty; specifically check
len(train_loader.dataset) (or len(train_loader)) and similarly for val_loader
and test_loader before proceeding to build the scheduler or use these loaders so
you avoid creating a zero-step scheduler and subsequent divide-by-zero errors
when creating the scheduler and computing per-epoch/step metrics. Ensure the
error message names the offending loader (train/val/test) and stops execution
before scheduler creation or any downstream computations.
📒 Files selected for processing (1)
curriculum_learning.py
for stage_name in stage_names:
    if self._is_stage_completed(stage_name):
        print(f"Skipping completed stage: {stage_name}")
        with open(self._metrics_path(stage_name)) as f:
            results[stage_name] = json.load(f)
        continue
Honor --eval_only before the completion short-circuit.
Completed stages are skipped before _train_stage(..., eval_only=True) can run, so --eval_only just returns cached metrics and never refreshes predictions or metrics from the saved checkpoint. Only skip when eval_only is false.
🔁 Suggested change
- if self._is_stage_completed(stage_name):
+ if not eval_only and self._is_stage_completed(stage_name):
print(f"Skipping completed stage: {stage_name}")
with open(self._metrics_path(stage_name)) as f:
results[stage_name] = json.load(f)
      continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@curriculum_learning.py` around lines 501 - 506, The short-circuit that skips
completed stages (the for loop over stage_names using
self._is_stage_completed(stage_name) and reading self._metrics_path(stage_name))
must honor the eval_only flag: change the logic so you only skip returning
cached metrics when eval_only is False; if eval_only is True, do not
short-circuit completed stages and instead call self._train_stage(stage_name,
..., eval_only=True) so predictions/metrics are refreshed from the checkpoint
before loading/saving metrics. Ensure references to _is_stage_completed,
_train_stage, and _metrics_path are used to implement this conditional behavior.
🧹 Nitpick comments (2)
curriculum_learning.py (2)
355-355: Add strict=True to zip() for defensive debugging.
A length mismatch between ts_texts and ts_values would indicate a data preparation bug. Adding strict=True surfaces this immediately rather than silently truncating the iteration.
Suggested fix
- for i, (text, ts) in enumerate(zip(ts_texts, ts_values)):
+ for i, (text, ts) in enumerate(zip(ts_texts, ts_values, strict=True)):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@curriculum_learning.py` at line 355, The loop over ts_texts and ts_values uses zip without strict checking, which silently truncates if lengths differ; update the zip call in the loop that iterates "for i, (text, ts) in enumerate(zip(ts_texts, ts_values))" to use zip(ts_texts, ts_values, strict=True) so a length mismatch raises immediately (adjust any surrounding code/comments to reflect the stricter behavior).
509-511: Consider using itertools.islice for more efficient mid-epoch resume.
The current approach iterates through all batches and skips via continue, which still incurs DataLoader overhead for skipped samples. For large datasets, this could add noticeable startup time on resume.
Example using islice
+from itertools import islice
+
- for step, batch in enumerate(train_loader):
-     if step < step_skip:
-         continue
+ for step, batch in enumerate(islice(train_loader, step_skip, None), start=step_skip):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@curriculum_learning.py` around lines 509 - 511, The loop currently does "for step, batch in enumerate(train_loader): if step < step_skip: continue" which still iterates DataLoader items; replace this with itertools.islice to skip upstream: import and use islice(train_loader, step_skip, None) (and adjust enumerate start to step_skip) so the DataLoader does not yield skipped batches; update references to `train_loader`, `step`, `batch`, and `step_skip` in the enclosing function in curriculum_learning.py accordingly.
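The resume pattern above can be sketched with a plain list standing in for the DataLoader; the name step_skip mirrors the review's description:

```python
from itertools import islice

batches = ["batch0", "batch1", "batch2", "batch3", "batch4"]
step_skip = 2  # e.g. steps already completed before the restart

# islice skips upstream, so already-seen batches are never yielded;
# enumerate(..., start=step_skip) keeps the global step counter aligned.
resumed = list(enumerate(islice(batches, step_skip, None), start=step_skip))
# resumed == [(2, "batch2"), (3, "batch3"), (4, "batch4")]
```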
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@curriculum_learning.py`:
- Line 355: The loop over ts_texts and ts_values uses zip without strict
checking, which silently truncates if lengths differ; update the zip call in the
loop that iterates "for i, (text, ts) in enumerate(zip(ts_texts, ts_values))" to
use zip(ts_texts, ts_values, strict=True) so a length mismatch raises
immediately (adjust any surrounding code/comments to reflect the stricter
behavior).
- Around line 509-511: The loop currently does "for step, batch in
enumerate(train_loader): if step < step_skip: continue" which still iterates
DataLoader items; replace this with itertools.islice to skip upstream: import
and use islice(train_loader, step_skip, None) (and adjust enumerate start to
step_skip) so the DataLoader does not yield skipped batches; update references
to `train_loader`, `step`, `batch`, and `step_skip` in the enclosing function in
curriculum_learning.py accordingly.
📒 Files selected for processing (2)
curriculum_learning.py, time_series_datasets/mhc_base_qa_dataset.py
Add curriculum learning training pipeline
♻️ Current situation & Problem
SensorTSLM needs a multi-stage training pipeline to fine-tune OpenTSLM models on our captioned MHC datasets. The original OpenTSLM curriculum_learning.py (2,300 lines) includes DDP, SDFT, and 12 hardcoded stages we don't need.
⚙️ Release Notes
- CurriculumTrainer with configurable StageConfig for data-driven stage definitions
- --initial_checkpoint
📚 Documentation
opentslm hardcodes attn_implementation="eager" and torch_dtype=bfloat16 in OpenTSLMSP.py / OpenTSLMFlamingo.py. sdpa + fp16 gave a 5.3× speedup (11.47 → 2.15 s/step, Llama-3.2-1B, batch=4). On A100+ keep sdpa, revert dtype to bfloat16.
✅ Testing
Tested on Sherlock using a V100 with Llama-3.2-1B for 200 steps: loss is going down and captions replicate the statistics, but the model is still underfitted.
Code of Conduct & Contributing Guidelines
By creating and submitting this pull request, you agree to follow our Code of Conduct and Contributing Guidelines: