@@ -3632,6 +3632,8 @@ def _inner(folder):
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving current state to {output_dir}")
 
+        save_model_only = save_model_func_kwargs.pop("save_model_only", False)
+
         if self.distributed_type == DistributedType.XLA:
             # Finish running the previous step before checkpointing
             xm.mark_step()
@@ -3657,23 +3659,25 @@ def _inner(folder):
 
         # Save the optimizers taking care of FSDP and DeepSpeed nuances
         optimizers = []
-        if self.distributed_type == DistributedType.FSDP:
-            for i, opt in enumerate(self._optimizers):
-                logger.info("Saving FSDP Optimizer")
-                save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
-                logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
-        elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
-            optimizers = self._optimizers
+        if not save_model_only:
+            if self.distributed_type == DistributedType.FSDP:
+                for i, opt in enumerate(self._optimizers):
+                    logger.info("Saving FSDP Optimizer")
+                    save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
+                    logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
+            elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
+                optimizers = self._optimizers
 
         # Save the lr schedulers taking care of DeepSpeed nuances
         schedulers = []
-        if self.distributed_type == DistributedType.DEEPSPEED:
-            for i, scheduler in enumerate(self._schedulers):
-                if isinstance(scheduler, DeepSpeedSchedulerWrapper):
-                    continue
-                schedulers.append(scheduler)
-        elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
-            schedulers = self._schedulers
+        if not save_model_only:
+            if self.distributed_type == DistributedType.DEEPSPEED:
+                for i, scheduler in enumerate(self._schedulers):
+                    if isinstance(scheduler, DeepSpeedSchedulerWrapper):
+                        continue
+                    schedulers.append(scheduler)
+            elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
+                schedulers = self._schedulers
 
         # Save the samplers of the dataloaders
         dataloaders = self._dataloaders
@@ -3694,6 +3698,7 @@ def _inner(folder):
             self.scaler,
             save_on_each_node=self.project_configuration.save_on_each_node,
             safe_serialization=safe_serialization,
+            save_model_only=save_model_only,
         )
         for i, obj in enumerate(self._custom_objects):
             save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
0 commit comments