6 changes: 6 additions & 0 deletions arctic_training/config/trainer.py
@@ -334,6 +334,12 @@ def validate_single_checkpoint_resume(self) -> Self:
assert sum(resume_checkpoint_values) <= 1, "Only one checkpoint can auto resume."
return self

@model_validator(mode="after")
def set_optimizer_for_cpu_offload(self) -> Self:
if self.deepspeed.get("cpu_offload", {}).get("device", None) == "cpu" and self.optimizer.type == "fusedadam":
self.optimizer.type = "cpu_adam"
Comment on lines +339 to +340
sfc-gh-sbekman (Collaborator) commented on Apr 10, 2025:
Should this assert if a user has explicitly set self.optimizer.type to the wrong value? Something like:

zero_optimization.offload_optimizer.device=cpu is configured, therefore optimizer.type=cpu_adam is required, but the config file has optimizer.type=... . Please either set it to cpu_adam or remove the explicit config and it will be set automatically.
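A rough sketch of what that stricter check could look like (an illustration of the suggestion only, not the PR's implementation; it assumes the same config fields used in the validator above and that "fusedadam" is the untouched default):

    @model_validator(mode="after")
    def set_optimizer_for_cpu_offload(self) -> Self:
        # Illustrative only: error out when the user explicitly configured a
        # conflicting optimizer type instead of silently overriding it.
        offload_device = self.deepspeed.get("cpu_offload", {}).get("device", None)
        if offload_device == "cpu" and self.optimizer.type != "cpu_adam":
            if self.optimizer.type == "fusedadam":
                # Assumed to be the default value, so it is safe to switch automatically.
                self.optimizer.type = "cpu_adam"
            else:
                raise ValueError(
                    "zero_optimization.offload_optimizer.device=cpu is configured,"
                    " so optimizer.type=cpu_adam is required, but the config file"
                    f" has optimizer.type={self.optimizer.type}. Please either set"
                    " it to cpu_adam or remove the explicit setting and it will be"
                    " set automatically."
                )
        return self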

return self


def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig:
if isinstance(config_file_or_dict, dict):
3 changes: 2 additions & 1 deletion arctic_training/trainer/dpo_trainer.py
@@ -39,6 +39,7 @@
from arctic_training.data.dpo_factory import DPODataFactory
from arctic_training.model.hf_factory import HFModelFactory
from arctic_training.model.liger_factory import LigerModelFactory
from arctic_training.optimizer.adam_factory import CPUAdamOptimizerFactory
from arctic_training.optimizer.adam_factory import FusedAdamOptimizerFactory
from arctic_training.registry import get_registered_model_factory
from arctic_training.scheduler.hf_factory import HFSchedulerFactory
@@ -176,7 +177,7 @@ class DPOTrainer(Trainer):
model_factory: Union[HFModelFactory, LigerModelFactory]
ref_model_factory: Union[HFModelFactory, LigerModelFactory]
checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine]
optimizer_factory: FusedAdamOptimizerFactory
optimizer_factory: Union[FusedAdamOptimizerFactory, CPUAdamOptimizerFactory]
scheduler_factory: HFSchedulerFactory
tokenizer_factory: HFTokenizerFactory
ref_model: torch.nn.Module