Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions arctic_training/config/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ def validate_single_checkpoint_resume(self) -> Self:
assert sum(resume_checkpoint_values) <= 1, "Only one checkpoint can auto resume."
return self

@model_validator(mode="after")
def set_optimizer_for_cpu_offload(self) -> Self:
    """Ensure the configured optimizer is compatible with DeepSpeed CPU offload.

    When the DeepSpeed config offloads optimizer state to the CPU, the
    GPU-fused ``fusedadam`` optimizer cannot be used.  If the user left the
    default ``fusedadam``, transparently switch it to ``cpu_adam``; if the
    user explicitly configured any other incompatible optimizer type, fail
    fast with an actionable error instead of silently accepting a config
    that will break at runtime.

    Returns:
        Self: the (possibly mutated) config object, per pydantic's
        ``model_validator(mode="after")`` contract.

    Raises:
        ValueError: if CPU offload is configured together with an explicit
            optimizer type other than ``cpu_adam``.
    """
    # NOTE(review): assumes the DeepSpeed dict exposes offload settings under
    # the "cpu_offload" key — confirm against the project's config schema.
    offload_device = self.deepspeed.get("cpu_offload", {}).get("device", None)
    if offload_device == "cpu":
        if self.optimizer.type == "fusedadam":
            # "fusedadam" is the default; swap it to the CPU-capable optimizer
            # automatically rather than burdening the user.
            self.optimizer.type = "cpu_adam"
        elif self.optimizer.type != "cpu_adam":
            raise ValueError(
                "DeepSpeed CPU optimizer offload is configured, therefore"
                " optimizer.type=cpu_adam is required, but the config file has"
                f" optimizer.type={self.optimizer.type}. Please either set it"
                " to cpu_adam or remove the explicit optimizer type and it"
                " will be set automatically."
            )
    return self


def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig:
if isinstance(config_file_or_dict, dict):
Expand Down
3 changes: 2 additions & 1 deletion arctic_training/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from arctic_training.data.dpo_factory import DPODataFactory
from arctic_training.model.hf_factory import HFModelFactory
from arctic_training.model.liger_factory import LigerModelFactory
from arctic_training.optimizer.adam_factory import CPUAdamOptimizerFactory
from arctic_training.optimizer.adam_factory import FusedAdamOptimizerFactory
from arctic_training.registry import get_registered_model_factory
from arctic_training.scheduler.hf_factory import HFSchedulerFactory
Expand Down Expand Up @@ -176,7 +177,7 @@ class DPOTrainer(Trainer):
model_factory: Union[HFModelFactory, LigerModelFactory]
ref_model_factory: Union[HFModelFactory, LigerModelFactory]
checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine]
optimizer_factory: FusedAdamOptimizerFactory
optimizer_factory: Union[FusedAdamOptimizerFactory, CPUAdamOptimizerFactory]
scheduler_factory: HFSchedulerFactory
tokenizer_factory: HFTokenizerFactory
ref_model: torch.nn.Module
Expand Down