From dba37bf0eb6a4265784badd252ca34e8e2127977 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 10 Apr 2025 18:41:47 +0000 Subject: [PATCH] make CPU Adam automatic when offloading optimizer in DS config --- arctic_training/config/trainer.py | 6 ++++++ arctic_training/trainer/dpo_trainer.py | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/arctic_training/config/trainer.py b/arctic_training/config/trainer.py index af9388b1..9de06675 100644 --- a/arctic_training/config/trainer.py +++ b/arctic_training/config/trainer.py @@ -334,6 +334,12 @@ def validate_single_checkpoint_resume(self) -> Self: assert sum(resume_checkpoint_values) <= 1, "Only one checkpoint can auto resume." return self + @model_validator(mode="after") + def set_optimizer_for_cpu_offload(self) -> Self: + if self.deepspeed.get("cpu_offload", {}).get("device", None) == "cpu" and self.optimizer.type == "fusedadam": + self.optimizer.type = "cpu_adam" + return self + def get_config(config_file_or_dict: Union[Path, Dict]) -> BaseConfig: if isinstance(config_file_or_dict, dict): diff --git a/arctic_training/trainer/dpo_trainer.py b/arctic_training/trainer/dpo_trainer.py index 7dbd2b99..7d574391 100644 --- a/arctic_training/trainer/dpo_trainer.py +++ b/arctic_training/trainer/dpo_trainer.py @@ -39,6 +39,7 @@ from arctic_training.data.dpo_factory import DPODataFactory from arctic_training.model.hf_factory import HFModelFactory from arctic_training.model.liger_factory import LigerModelFactory +from arctic_training.optimizer.adam_factory import CPUAdamOptimizerFactory from arctic_training.optimizer.adam_factory import FusedAdamOptimizerFactory from arctic_training.registry import get_registered_model_factory from arctic_training.scheduler.hf_factory import HFSchedulerFactory @@ -176,7 +177,7 @@ class DPOTrainer(Trainer): model_factory: Union[HFModelFactory, LigerModelFactory] ref_model_factory: Union[HFModelFactory, LigerModelFactory] checkpoint_engine: Union[DSCheckpointEngine, HFCheckpointEngine] - optimizer_factory: FusedAdamOptimizerFactory + optimizer_factory: Union[FusedAdamOptimizerFactory, CPUAdamOptimizerFactory] scheduler_factory: HFSchedulerFactory tokenizer_factory: HFTokenizerFactory ref_model: torch.nn.Module