
Add curriculum learning training pipeline#50

Open
max-rosenblattl wants to merge 9 commits into main from max-rosenblattl/curriculum-learning

Conversation

@max-rosenblattl
Collaborator

@max-rosenblattl max-rosenblattl commented Apr 11, 2026

Add curriculum learning training pipeline

♻️ Current situation & Problem

SensorTSLM needs a multi-stage training pipeline to fine-tune OpenTSLM models on our captioned MHC datasets. The original OpenTSLM curriculum_learning.py (2,300 lines) includes DDP, SDFT, and 12 hardcoded stages we don't need.

⚙️ Release Notes

  • Add CurriculumTrainer with a configurable StageConfig for data-driven stage definitions
  • Support both OpenTSLMSP and OpenTSLMFlamingo model types
  • Manage checkpoints with automatic knowledge transfer between stages
  • Offer optional LoRA fine-tuning for the SP model in later stages
  • Support loading external pretrained checkpoints via --initial_checkpoint
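
For orientation, a stage definition along these lines could look like the sketch below; the field names are assumptions inferred from the feature list above, not the actual StageConfig signature in curriculum_learning.py.

```python
from dataclasses import dataclass
from typing import Optional, Type


# Hypothetical shape of a stage definition; field names are illustrative.
@dataclass
class StageConfig:
    name: str
    dataset_class: Type                        # dataset to instantiate per split
    num_epochs: int = 1
    batch_size: int = 4
    use_lora: bool = False                     # optional LoRA for SP models in later stages
    initial_checkpoint: Optional[str] = None   # external pretrained weights


# A data-driven curriculum is then just a list of such configs.
stages = [
    StageConfig(name="captioning_warmup", dataset_class=object),
    StageConfig(name="captioning_full", dataset_class=object, use_lora=True),
]
```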

📚 Documentation

  • opentslm hardcodes attn_implementation="eager" and torch_dtype=bfloat16 in
    OpenTSLMSP.py / OpenTSLMFlamingo.py.
  • On a V100, patching to sdpa + fp16 gave a 5.3× speedup (11.47 → 2.15 s/step,
    Llama-3.2-1B, batch=4). On A100 and newer, keep sdpa but revert the dtype to bfloat16.
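
The V100-vs-A100 distinction above boils down to compute capability: V100 is sm_70, which has no bfloat16 hardware support, while A100 is sm_80 and newer. A hypothetical helper that picks the load kwargs (the key names mirror transformers' from_pretrained arguments; this is a sketch, not the patch actually applied in the PR):

```python
def pick_model_load_kwargs(compute_capability: tuple) -> dict:
    """Pick attention implementation and dtype for a given GPU.

    V100 (sm_70) lacks bfloat16 hardware support, so fall back to fp16;
    A100 (sm_80) and newer keep bfloat16. Both use sdpa instead of eager.
    """
    major, _ = compute_capability
    dtype = "bfloat16" if major >= 8 else "float16"
    return {"attn_implementation": "sdpa", "torch_dtype": dtype}


# In practice the capability would come from torch.cuda.get_device_capability(),
# which returns (7, 0) on a V100 and (8, 0) on an A100.
```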

✅ Testing

Tested on Sherlock (V100, Llama-3.2-1B, 200 steps): the loss decreases and the generated captions replicate the summary statistics, but the model is still underfitted:

[gold]
The average iPhone step count value is 138 steps/hr, with extremes at 2415 (max) and 0 (min), and a std of 427. The iPhone distance data exhibits a mean of 115.9 m/hr, a standard deviation of 355.2, and its extreme values are 0.0 and 2050.3. flights climbed (iPhone) average 0.4 count/hr, reaching a maximum of 10.0 and a minimum of 0.0, with a standard deviation of 1.7. Apple Watch step count exhibits a mean of 306 steps/hr, with peak and minimal values reaching 6945 and 0, and a standard deviation of 839. For the Apple Watch distance measurements, the mean is 268.9 m/hr, the standard deviation is 796.2, and the data lies between 0.0 and 7067.8. The average heart rate value is 71.2 bpm, with extremes at 158.7 (max) and 57.1 (min), and a std of 15.0. The active energy data exhibits a mean of 428 cal/hr, a standard deviation of 796, and its extreme values are 0 and 5807. A spike in iPhone step count was detected at hour 117. Spike event recorded for iPhone step count at hour 112. An outlier event in iPhone step count occurred at hour 88. Spike event recorded for iPhone step count at hour 32. A significant peak in iPhone step count was observed at hour 104. A spike in iPhone step count was detected at hour 40. iPhone step count shows a notable spike at hour 80. iPhone step count shows a notable spike at hour 156. Spike event recorded for iPhone distance at hour 117. A significant peak in iPhone distance was observed at hour 112. iPhone distance shows a notable spike at hour 88. A significant peak in iPhone distance was observed at hour 32. Spike event recorded for iPhone distance at hour 40. An outlier event in iPhone distance occurred at hour 104. A spike in iPhone distance was detected at hour 80. A spike in iPhone distance was detected at hour 156. flights climbed (iPhone) shows a notable spike at hour 104. An outlier event in flights climbed (iPhone) occurred at hour 32. Spike event recorded for flights climbed (iPhone) at hour 80. 
A spike in flights climbed (iPhone) was detected at hour 117. flights climbed (iPhone) shows a notable spike at hour 157. A spike in flights climbed (iPhone) was detected at hour 88. A significant peak in flights climbed (iPhone) was observed at hour 40. Spike event recorded for flights climbed (iPhone) at hour 112. An outlier event in flights climbed (iPhone) occurred at hour 107. A significant peak in flights climbed (iPhone) was observed at hour 109. A spike in Apple Watch step count was detected at hour 154. Spike event recorded for Apple Watch step count at hour 139. A spike in Apple Watch step count was detected at hour 90. A significant peak in Apple Watch step count was observed at hour 117. A significant peak in Apple Watch step count was observed at hour 32. An outlier event in Apple Watch step count occurred at hour 112. Apple Watch step count shows a notable spike at hour 104. Spike event recorded for Apple Watch step count at hour 80. Apple Watch step count shows a notable spike at hour 157. A spike in Apple Watch distance was detected at hour 154. Spike event recorded for Apple Watch distance at hour 139. A spike in Apple Watch distance was detected at hour 90. A significant peak in Apple Watch distance was observed at hour 117. An outlier event in Apple Watch distance occurred at hour 112. A significant peak in Apple Watch distance was observed at hour 32. Apple Watch distance shows a notable spike at hour 104. Spike event recorded for Apple Watch distance at hour 80. Apple Watch distance shows a notable spike at hour 157. A spike in heart rate was detected at hour 154. A significant peak in heart rate was observed at hour 90. Spike event recorded for heart rate at hour 139. active energy shows a notable spike at hour 154. A spike in active energy was detected at hour 139. active energy shows a notable spike at hour 90. Spike event recorded for active energy at hour 117. An outlier event in active energy occurred at hour 104. 
A significant peak in active energy was observed at hour 32. Spike event recorded for active energy at hour 80. A significant peak in active energy was observed at hour 112. A spike in active energy was detected at hour 88. An outlier event in active energy occurred at hour 157. Between hour 120 and 134, active energy exhibited a increasing trend. From hour 24 to 29, the user had a period of asleep. Asleep recorded within the 72-78 hour range. Asleep episode occurred between hour 93 and 102. Asleep was recorded between hour 119 and 127. Identified asleep across the timeframe of hour 142 to 151. From hour 24 to 29, the user had a period of inbed. Inbed recorded within the 72-78 hour range. Inbed episode occurred between hour 93 and 102. Inbed was recorded between hour 119 and 127. Identified inbed across the timeframe of hour 142 to 150.<|end_of_text|>
[generated]
 The iPhone step count data for this time period has a mean of 138.0 steps per hour, with a standard deviation of 427. The iPhone distance data has a mean of 115.0 m/hr, with a standard deviation of 355.197. The flights climbed (iPhone) data has a mean of 0.43 count/hr, with a standard deviation of 1.73. The Apple Watch step count data has a mean of 305.65 steps/hr, with a standard deviation of 839.4528. The Apple Watch distance data has a mean of 268.88 m/hr, with a standard deviation of 796.218. The heart rate data has a mean of 71.18 bpm, with a standard deviation of 14.99. The active energy data has a mean of 427.69 cal/hr, with a standard deviation of 1.73. The following is a detailed binary mask for the following step count data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following distance data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following heart rate data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following active energy data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following flight distance data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following flights climbed (iPhone) data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following Apple Watch step count data, describing the person's activity, sleep, and physiological patterns as accurately as possible. 
The following is a detailed binary mask for the following Apple Watch distance data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following heart rate data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following active energy data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following step count data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following distance data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following flight distance data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following flights climbed (iPhone) data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following Apple Watch step count data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following Apple Watch distance data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following heart rate data, describing the person's activity, sleep, and physiological patterns as accurately as possible. 
The following is a detailed binary mask for the following following active energy data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following following step count data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following following distance data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following following flights climbed (iPhone) data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following following flights distance data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following flights climbed (iPhone) data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following active energy data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following following heart rate data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following active energy data, describing the person's activity, sleep, and physiological patterns as accurately as possible. The following is a detailed binary mask for the following following following following step count data, describing the person's activity,

Code of Conduct & Contributing Guidelines

By creating and submitting this pull request, you agree to follow our Code of Conduct and Contributing Guidelines.

@coderabbitai

coderabbitai bot commented Apr 12, 2026

📝 Walkthrough


Adds broader .gitignore rules, a new multi-stage curriculum training/evaluation orchestration module curriculum_learning.py with checkpointing/LoRA/CLI, and a small dataset prompt sanitization change in time_series_datasets/mhc_base_qa_dataset.py.

Changes

  • Git ignore updates (.gitignore): Broadened the PNG ignore from sample_plot_*.png to all *.png; added exports/ and results/ to the ignored paths.
  • Curriculum learning module (curriculum_learning.py): New module introducing StageConfig and CurriculumTrainer for multi-stage curriculum runs: model selection (OpenTSLMSP / OpenTSLMFlamingo), device and optional gradient checkpointing, LoRA handling, per-stage datasets/dataloaders, training loop (AdamW, linear warmup, grad clip), periodic/latest and best checkpointing, eval-only mode, test prediction export and metrics.json, and a CLI main().
  • Dataset prompt sanitization (time_series_datasets/mhc_base_qa_dataset.py): Changed _binary_prompt to convert NaNs to 0.0 via torch.nan_to_num before conversion and updated the prompt text from "NaN = nonwear" to "0 = inactive or nonwear."
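
For illustration, the sanitization amounts to the following in plain Python (a hypothetical equivalent of the torch.nan_to_num call, not the dataset's actual code):

```python
import math


def sanitize_series(values):
    """Replace NaNs with 0.0 before formatting a prompt, mirroring
    torch.nan_to_num(x, nan=0.0); in the prompt text, 0 is then
    described as 'inactive or nonwear' rather than NaN as 'nonwear'."""
    return [0.0 if math.isnan(v) else v for v in values]
```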

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant CLI as main()
    participant Trainer as CurriculumTrainer
    participant Model as Model<br/>(OpenTSLMSP / Flamingo)
    participant Dataset as Dataset<br/>& DataLoader
    participant Optimizer as AdamW<br/>+ Scheduler
    participant Storage as Checkpoint / Results

    User->>CLI: invoke curriculum run (stages, options)
    CLI->>Trainer: init with StageConfig list

    loop per stage
        Trainer->>Storage: load initial or previous best checkpoint
        Storage-->>Trainer: checkpoint data
        Trainer->>Model: construct/load weights, enable LoRA/grad-ckpt, set device
        Trainer->>Dataset: build train/val/test loaders

        alt eval_only
            Trainer->>Model: generate on test set
            Model-->>Trainer: predictions
        else training
            loop epochs
                loop batches
                    Dataset-->>Model: batch
                    Model->>Model: forward -> loss
                    Model->>Optimizer: backward
                    Optimizer->>Model: clip grads and step
                    Optimizer->>Optimizer: scheduler step
                end
                Trainer->>Dataset: validate -> val loss
                alt best val improved
                    Trainer->>Storage: save best checkpoint
                end
                Trainer->>Storage: optionally save latest checkpoint
            end
            Trainer->>Storage: load best checkpoint
            Trainer->>Model: apply best state
            Trainer->>Model: generate on test set
            Model-->>Trainer: predictions
        end

        Trainer->>Storage: write test_predictions.jsonl and metrics.json
    end

    Trainer->>Storage: aggregate curriculum_results.json
    Trainer-->>User: return aggregated results

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 3.70%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title clearly and concisely summarizes the main changeset: adding a curriculum learning training pipeline to the codebase.
  • Description check ✅ Passed: The PR description clearly relates to the changeset, detailing a curriculum learning pipeline addition with specific features, documentation, and testing results.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@curriculum_learning.py`:
- Around line 138-149: The _is_stage_completed helper currently returns True
based only on metrics and predictions, causing stages to be considered complete
even if the checkpoint (e.g., checkpoints/best_model.pt) is missing; update
_is_stage_completed to also verify the expected checkpoint file exists before
returning True. Locate the _is_stage_completed method and augment its final
return condition to include a check that the checkpoint path (the same path used
by _load_previous_stage_model, e.g., self._checkpoint_path(stage_name) or the
specific "checkpoints/best_model.pt" for that stage) exists; ensure you call
self._predictions_path(stage_name).exists() and the checkpoint exists in the
same boolean expression so the stage is only marked complete when metrics,
predictions, and the checkpoint are all present.
- Around line 361-363: The current sequence creates the model, then loads
previous-stage checkpoint LoRA weights before LoRA adapters exist, preventing
proper LoRA weight restoration for OpenTSLMSP stages; change the call order so
_enable_lora_if_needed(stage_name) runs immediately after self.model =
self._create_model() and before _load_previous_stage_model(stage_name), or
alternatively ensure _load_previous_stage_model defers calling
load_lora_state_from_checkpoint() until after _enable_lora_if_needed has created
adapters; update references to _create_model, _enable_lora_if_needed,
_load_previous_stage_model and load_lora_state_from_checkpoint accordingly.
- Line 184: Replace unsafe torch.load calls that pass weights_only=False with
weights_only=True to avoid unpickling arbitrary Python objects; specifically
update the checkpoint loads that assign checkpoint = torch.load(path,
map_location="cpu", weights_only=False) (and the analogous calls at the other
two occurrences around lines 207 and 218) so they use weights_only=True,
ensuring only tensors/state-dicts/primitives are deserialized when loading
checkpoints (including external --initial_checkpoint).
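
The completeness check requested in the first comment amounts to requiring all three artifacts on disk before a stage may be skipped. A minimal pathlib sketch, with the directory layout and file names assumed from the comment text rather than verified against the module:

```python
from pathlib import Path


def is_stage_completed(stage_dir: Path) -> bool:
    """A stage counts as done only if metrics, predictions, AND the best
    checkpoint all exist; with a missing checkpoint, later stages cannot
    load weights even though metrics and predictions are present."""
    return (
        (stage_dir / "metrics.json").exists()
        and (stage_dir / "test_predictions.jsonl").exists()
        and (stage_dir / "checkpoints" / "best_model.pt").exists()
    )
```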
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: bcd7158c-be47-4004-b589-36a423cac31b

📥 Commits

Reviewing files that changed from the base of the PR and between 0b910c7 and 07e3dc3.

📒 Files selected for processing (2)
  • .gitignore
  • curriculum_learning.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
curriculum_learning.py (1)

475-480: ⚠️ Potential issue | 🟠 Major

Don't skip completed stages when --eval_only is requested.

Right now run_curriculum() returns cached metrics before _train_stage(..., eval_only=True) can run, so --eval_only never re-evaluates an already-completed stage or regenerates predictions.

Suggested fix
-            if self._is_stage_completed(stage_name):
+            if not eval_only and self._is_stage_completed(stage_name):
                 print(f"Skipping completed stage: {stage_name}")
                 with open(self._metrics_path(stage_name)) as f:
                     results[stage_name] = json.load(f)
                 continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@curriculum_learning.py` around lines 475 - 480, The loop in run_curriculum
currently skips stages that return True from _is_stage_completed(stage_name),
which prevents --eval_only from re-running evaluation; change the logic so that
when eval_only (or the equivalent flag on self) is true you do not continue past
completed stages—instead call _train_stage(stage_name, eval_only=True) to
regenerate predictions/metrics and then load or store the results from
_metrics_path(stage_name); keep the existing behavior for non-eval runs (still
skip completed stages), and reference run_curriculum, _is_stage_completed,
_train_stage, and _metrics_path when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@curriculum_learning.py`:
- Around line 369-381: After constructing train_loader, val_loader, and
test_loader (the DataLoader instances created from config.dataset_class),
validate that each has at least one batch (e.g., check len(list(train_loader)) >
0 or the underlying dataset length) and raise a clear error or fallback behavior
if any split is empty to avoid zero-step schedules and divide-by-zero when
computing averages in functions that use train_loader/val_loader/test_loader
(references: train_loader, val_loader, test_loader, config.dataset_class);
ensure the check runs immediately after those DataLoader constructions and
prevents proceeding to schedule creation or metric averaging when a split is
empty.
- Around line 154-157: _mark_stage_completed currently overwrites the metrics
file and drops existing fields; change it to first read and parse the existing
metrics JSON (using self._metrics_path(stage_name)), merge/update the
"completed": True key into that dict (handling FileNotFoundError or JSON decode
errors by starting with an empty dict), then write the merged dict back to the
same path with json.dump so fields like "best_val_loss" are preserved.

---

Duplicate comments:
In `@curriculum_learning.py`:
- Around line 475-480: The loop in run_curriculum currently skips stages that
return True from _is_stage_completed(stage_name), which prevents --eval_only
from re-running evaluation; change the logic so that when eval_only (or the
equivalent flag on self) is true you do not continue past completed
stages—instead call _train_stage(stage_name, eval_only=True) to regenerate
predictions/metrics and then load or store the results from
_metrics_path(stage_name); keep the existing behavior for non-eval runs (still
skip completed stages), and reference run_curriculum, _is_stage_completed,
_train_stage, and _metrics_path when making the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 3b1c2436-064f-4a21-8f1e-46522dba3f8f

📥 Commits

Reviewing files that changed from the base of the PR and between 07e3dc3 and ecb835a.

📒 Files selected for processing (1)
  • curriculum_learning.py

Comment thread curriculum_learning.py
Comment on lines +154 to +157
    def _mark_stage_completed(self, stage_name: str, metrics: dict):
        metrics["completed"] = True
        with open(self._metrics_path(stage_name), "w") as f:
            json.dump(metrics, f, indent=2)


⚠️ Potential issue | 🟡 Minor

Preserve existing stage metrics when marking completion.

This helper always rewrites metrics.json from scratch. In the eval-only path that drops previously recorded fields like best_val_loss, so a reevaluation silently destroys training metadata.

Suggested fix
     def _mark_stage_completed(self, stage_name: str, metrics: dict):
-        metrics["completed"] = True
-        with open(self._metrics_path(stage_name), "w") as f:
-            json.dump(metrics, f, indent=2)
+        path = self._metrics_path(stage_name)
+        merged_metrics = {}
+        if path.exists():
+            with open(path) as f:
+                merged_metrics = json.load(f)
+        merged_metrics.update(metrics)
+        merged_metrics["completed"] = True
+        with open(path, "w") as f:
+            json.dump(merged_metrics, f, indent=2)
         print(f"Stage {stage_name} marked as completed")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@curriculum_learning.py` around lines 154 - 157, _mark_stage_completed
currently overwrites the metrics file and drops existing fields; change it to
first read and parse the existing metrics JSON (using
self._metrics_path(stage_name)), merge/update the "completed": True key into
that dict (handling FileNotFoundError or JSON decode errors by starting with an
empty dict), then write the merged dict back to the same path with json.dump so
fields like "best_val_loss" are preserved.

Comment thread curriculum_learning.py
Comment on lines +369 to +381
        train_dataset = config.dataset_class(split="train", **dataset_kwargs)
        val_dataset = config.dataset_class(split="validation", **dataset_kwargs)
        test_dataset = config.dataset_class(split="test", **dataset_kwargs)

        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, collate_fn=list
        )
        val_loader = DataLoader(
            val_dataset, batch_size=1, shuffle=False, collate_fn=list
        )
        test_loader = DataLoader(
            test_dataset, batch_size=1, shuffle=False, collate_fn=list
        )


⚠️ Potential issue | 🟠 Major

Guard against empty splits before training/evaluation.

Please validate the loaders right after construction. Empty smoke/export splits make Line 395 create a zero-step schedule, and the later averages at Line 423, Line 435, and Line 330 will divide by zero.

Suggested fix
         test_loader = DataLoader(
             test_dataset, batch_size=1, shuffle=False, collate_fn=list
         )
+        if len(test_loader) == 0:
+            raise ValueError(f"Stage {stage_name} has no test samples")
+        if not eval_only:
+            if len(train_loader) == 0:
+                raise ValueError(f"Stage {stage_name} has no training samples")
+            if len(val_loader) == 0:
+                raise ValueError(f"Stage {stage_name} has no validation samples")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@curriculum_learning.py` around lines 369 - 381, After constructing
train_loader, val_loader, and test_loader (the DataLoader instances created from
config.dataset_class), validate that each has at least one batch (e.g., check
len(list(train_loader)) > 0 or the underlying dataset length) and raise a clear
error or fallback behavior if any split is empty to avoid zero-step schedules
and divide-by-zero when computing averages in functions that use
train_loader/val_loader/test_loader (references: train_loader, val_loader,
test_loader, config.dataset_class); ensure the check runs immediately after
those DataLoader constructions and prevents proceeding to schedule creation or
metric averaging when a split is empty.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
curriculum_learning.py (2)

155-158: ⚠️ Potential issue | 🟡 Minor

Preserve prior metrics when marking a stage complete.

This rewrites metrics.json from scratch, so an eval-only rerun drops fields already recorded for the stage, including best_val_loss. Merge into the existing JSON before setting "completed": true.

♻️ Suggested change
     def _mark_stage_completed(self, stage_name: str, metrics: dict):
-        metrics["completed"] = True
-        with open(self._metrics_path(stage_name), "w") as f:
-            json.dump(metrics, f, indent=2)
+        path = self._metrics_path(stage_name)
+        merged_metrics = {}
+        if path.exists():
+            with open(path) as f:
+                merged_metrics = json.load(f)
+        merged_metrics.update(metrics)
+        merged_metrics["completed"] = True
+        with open(path, "w") as f:
+            json.dump(merged_metrics, f, indent=2)
         print(f"Stage {stage_name} marked as completed")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@curriculum_learning.py` around lines 155 - 158, In _mark_stage_completed,
avoid overwriting existing metrics.json: read the current JSON from
self._metrics_path(stage_name) if it exists, merge its contents with the
incoming metrics dict (preserving prior fields such as best_val_loss and any
other recorded keys), then set "completed": True on the merged dict and write
that merged object back; ensure you still use the same
self._metrics_path(stage_name) and handle missing/invalid JSON by falling back
to the provided metrics.

380-388: ⚠️ Potential issue | 🟠 Major

Fail fast on empty splits before building the schedule.

An empty training split makes Line 402 create a zero-step scheduler, and empty train/validation/test splits later divide by zero at Line 440, Line 461, and Line 337. Validate the loaders immediately after construction and raise a clear error.

🛡️ Suggested change
         test_loader = DataLoader(
             test_dataset, batch_size=1, shuffle=False, collate_fn=list
         )
+        if len(test_loader) == 0:
+            raise ValueError(f"Stage {stage_name} has no test samples")
+        if not eval_only:
+            if len(train_loader) == 0:
+                raise ValueError(f"Stage {stage_name} has no training samples")
+            if len(val_loader) == 0:
+                raise ValueError(f"Stage {stage_name} has no validation samples")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@curriculum_learning.py` around lines 380 - 388, After constructing
train_loader, val_loader, and test_loader, immediately validate that each
contains at least one sample and raise a clear ValueError if any split is empty;
specifically check len(train_loader.dataset) (or len(train_loader)) and
similarly for val_loader and test_loader before proceeding to build the
scheduler or use these loaders so you avoid creating a zero-step scheduler and
subsequent divide-by-zero errors when creating the scheduler and computing
per-epoch/step metrics. Ensure the error message names the offending loader
(train/val/test) and stops execution before scheduler creation or any downstream
computations.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@curriculum_learning.py`:
- Around lines 501-506: the short-circuit that skips completed stages (the for
loop over stage_names using self._is_stage_completed(stage_name) and reading
self._metrics_path(stage_name)) must honor the eval_only flag. Return cached
metrics only when eval_only is False; when eval_only is True, do not
short-circuit completed stages and instead call self._train_stage(stage_name,
..., eval_only=True) so predictions and metrics are refreshed from the
checkpoint before loading/saving metrics.

---

Duplicate comments:
In `@curriculum_learning.py`:
- Around lines 155-158: in _mark_stage_completed, avoid overwriting an
existing metrics.json. Read the current JSON from
self._metrics_path(stage_name) if it exists, merge its contents with the
incoming metrics dict (preserving prior fields such as best_val_loss), set
"completed": True on the merged dict, and write the merged object back to the
same path. Handle missing or invalid JSON by falling back to the provided
metrics.
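The merge-then-write behavior suggested for _mark_stage_completed can be sketched as a standalone function (hypothetical version — the real method lives on the trainer and uses self._metrics_path, and field names like best_val_loss depend on the actual metrics.json layout):

```python
import json
from pathlib import Path


def mark_stage_completed(metrics_path: Path, metrics: dict) -> None:
    """Merge new metrics into an existing metrics.json instead of clobbering it."""
    merged = dict(metrics)
    if metrics_path.exists():
        try:
            existing = json.loads(metrics_path.read_text())
        except (json.JSONDecodeError, OSError):
            existing = {}  # invalid/unreadable JSON: fall back to incoming metrics
        # Prior fields (e.g. best_val_loss) survive; incoming values win on conflict.
        merged = {**existing, **metrics}
    merged["completed"] = True
    metrics_path.write_text(json.dumps(merged, indent=2))
```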

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 5f1d5845-4860-42ca-a0ee-cb2d86c935f7

📥 Commits

Reviewing files that changed from the base of the PR and between ecb835a and d43ea04.

📒 Files selected for processing (1)
  • curriculum_learning.py

Comment thread curriculum_learning.py
Comment on lines +501 to +506
for stage_name in stage_names:
    if self._is_stage_completed(stage_name):
        print(f"Skipping completed stage: {stage_name}")
        with open(self._metrics_path(stage_name)) as f:
            results[stage_name] = json.load(f)
        continue

⚠️ Potential issue | 🟠 Major

Honor --eval_only before the completion short-circuit.

Completed stages are skipped before _train_stage(..., eval_only=True) can run, so --eval_only just returns cached metrics and never refreshes predictions or metrics from the saved checkpoint. Only skip when eval_only is false.

🔁 Suggested change
-            if self._is_stage_completed(stage_name):
+            if not eval_only and self._is_stage_completed(stage_name):
                 print(f"Skipping completed stage: {stage_name}")
                 with open(self._metrics_path(stage_name)) as f:
                     results[stage_name] = json.load(f)
                 continue
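To see the corrected control flow in isolation, here is a minimal sketch with the trainer's methods stubbed out as plain callables (names and signatures are placeholders, not the real CurriculumTrainer API):

```python
import json


def run_stages(stage_names, eval_only, is_completed, metrics_path, train_stage):
    """Skip completed stages only when training; --eval_only always re-runs eval."""
    results = {}
    for stage_name in stage_names:
        if not eval_only and is_completed(stage_name):
            print(f"Skipping completed stage: {stage_name}")
            with open(metrics_path(stage_name)) as f:
                results[stage_name] = json.load(f)
            continue
        # eval_only=True reaches here even for completed stages, so
        # predictions/metrics are refreshed from the saved checkpoint.
        results[stage_name] = train_stage(stage_name, eval_only=eval_only)
    return results
```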


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
curriculum_learning.py (2)

355-355: Add strict=True to zip() for defensive debugging.

A length mismatch between ts_texts and ts_values would indicate a data preparation bug. Adding strict=True surfaces this immediately rather than silently truncating the iteration.

Suggested fix
-        for i, (text, ts) in enumerate(zip(ts_texts, ts_values)):
+        for i, (text, ts) in enumerate(zip(ts_texts, ts_values, strict=True)):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@curriculum_learning.py` at line 355: the loop over ts_texts and ts_values
uses zip without strict checking, which silently truncates if the lengths
differ. Update the zip call in "for i, (text, ts) in enumerate(zip(ts_texts,
ts_values))" to zip(ts_texts, ts_values, strict=True) so a length mismatch
raises immediately, and adjust any surrounding comments to reflect the
stricter behavior.

509-511: Consider using itertools.islice for more efficient mid-epoch resume.

The current approach iterates through all batches and skips via continue, which still incurs DataLoader overhead for skipped samples. For large datasets, this could add noticeable startup time on resume.

Example using islice
+from itertools import islice
+
 ...
-            for step, batch in enumerate(train_loader):
-                if step < step_skip:
-                    continue
+            for step, batch in enumerate(islice(train_loader, step_skip, None), start=step_skip):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@curriculum_learning.py` around lines 509-511: the loop currently does
"for step, batch in enumerate(train_loader): if step < step_skip: continue",
which still iterates over the skipped DataLoader items. Replace it with
itertools.islice so the DataLoader never yields the skipped batches: import
islice and use enumerate(islice(train_loader, step_skip, None),
start=step_skip), updating references to `train_loader`, `step`, `batch`, and
`step_skip` in the enclosing function accordingly.
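The resume pattern can be exercised on any iterable, since islice only needs an iterator (a list stands in for the DataLoader in this sketch; `resumed_steps` is a hypothetical name, not a function in the PR):

```python
from itertools import islice


def resumed_steps(train_loader, step_skip):
    """Yield (step, batch) from step_skip onward; earlier batches are never yielded."""
    yield from enumerate(islice(train_loader, step_skip, None), start=step_skip)
```

Note that with a shuffling DataLoader this only reproduces the original epoch if the sampler's RNG state is also restored, which the snippet does not address.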

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 8c543b6e-ccc5-478f-8e61-8e8997b2a62c

📥 Commits

Reviewing files that changed from the base of the PR and between d43ea04 and 5e4ba8e.

📒 Files selected for processing (2)
  • curriculum_learning.py
  • time_series_datasets/mhc_base_qa_dataset.py

