diff --git a/arctic_training/model/hf_factory.py b/arctic_training/model/hf_factory.py index a56ee94a..b785e1ea 100644 --- a/arctic_training/model/hf_factory.py +++ b/arctic_training/model/hf_factory.py @@ -83,7 +83,12 @@ def post_create_model_callback(self, model): ) if not self.config.disable_activation_checkpoint: - model.gradient_checkpointing_enable() + gc_kwargs = None + if self.trainer.config.activation_checkpoint_cpu_offload: + # CPU-offload monkey patch (see arctic_training.monkey_patches) + # only intercepts torch.utils.checkpoint's reentrant path. + gc_kwargs = {"use_reentrant": True} + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs) model = self.make_model_gradient_checkpointing_compatible(model) return model