
Commit 9ee2a53

Save state model only

1 parent: 0d8ca1a

3 files changed

Lines changed: 37 additions & 14 deletions

src/accelerate/accelerator.py

Lines changed: 19 additions & 14 deletions
@@ -3632,6 +3632,8 @@ def _inner(folder):
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving current state to {output_dir}")
 
+        save_model_only = save_model_func_kwargs.pop("save_model_only", False)
+
         if self.distributed_type == DistributedType.XLA:
             # Finish running the previous step before checkpointing
             xm.mark_step()
@@ -3657,23 +3659,25 @@ def _inner(folder):
 
         # Save the optimizers taking care of FSDP and DeepSpeed nuances
         optimizers = []
-        if self.distributed_type == DistributedType.FSDP:
-            for i, opt in enumerate(self._optimizers):
-                logger.info("Saving FSDP Optimizer")
-                save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
-                logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
-        elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
-            optimizers = self._optimizers
+        if not save_model_only:
+            if self.distributed_type == DistributedType.FSDP:
+                for i, opt in enumerate(self._optimizers):
+                    logger.info("Saving FSDP Optimizer")
+                    save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
+                    logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
+            elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
+                optimizers = self._optimizers
 
         # Save the lr schedulers taking care of DeepSpeed nuances
         schedulers = []
-        if self.distributed_type == DistributedType.DEEPSPEED:
-            for i, scheduler in enumerate(self._schedulers):
-                if isinstance(scheduler, DeepSpeedSchedulerWrapper):
-                    continue
-                schedulers.append(scheduler)
-        elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
-            schedulers = self._schedulers
+        if not save_model_only:
+            if self.distributed_type == DistributedType.DEEPSPEED:
+                for i, scheduler in enumerate(self._schedulers):
+                    if isinstance(scheduler, DeepSpeedSchedulerWrapper):
+                        continue
+                    schedulers.append(scheduler)
+            elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
+                schedulers = self._schedulers
 
         # Save the samplers of the dataloaders
         dataloaders = self._dataloaders
@@ -3694,6 +3698,7 @@ def _inner(folder):
             self.scaler,
             save_on_each_node=self.project_configuration.save_on_each_node,
             safe_serialization=safe_serialization,
+            save_model_only=save_model_only
         )
         for i, obj in enumerate(self._custom_objects):
             save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
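Because `save_state` forwards unrecognized keyword arguments through `save_model_func_kwargs`, the new flag requires no change to its signature. A minimal usage sketch (the toy model and the "ckpt" directory name are illustrative, not from the commit):

    import torch

    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(10, 10))

    # `save_model_only=True` is popped from `save_model_func_kwargs` inside
    # `save_state`, so only the model weights are written; optimizer,
    # scheduler, and other training state is skipped.
    accelerator.save_state("ckpt", save_model_only=True)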

src/accelerate/checkpointing.py

Lines changed: 5 additions & 0 deletions
@@ -71,6 +71,7 @@ def save_accelerator_state(
     scaler: Optional[GradScaler] = None,
     save_on_each_node: bool = False,
     safe_serialization: bool = True,
+    save_model_only: bool = False,
 ):
     """
     Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.
@@ -113,6 +114,10 @@ def save_accelerator_state(
         output_model_file = output_dir.joinpath(weights_name)
         save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization)
         logger.info(f"Model weights saved in {output_model_file}")
+
+    if save_model_only:
+        return output_dir
+
     # Optimizer states
     for i, opt in enumerate(optimizers):
         state = opt.state_dict()
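With the flag set, `save_accelerator_state` returns `output_dir` immediately after the model weights file is written, so the optimizer states and everything serialized after them (scheduler, sampler, scaler, and RNG states, per the function's docstring) never reach disk; the checkpoint directory then contains only the weights file.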

tests/test_accelerator.py

Lines changed: 13 additions & 0 deletions
@@ -296,6 +296,19 @@ def test_save_model(self, use_safetensors):
             load_checkpoint_in_model(model, tmpdirname)
             assert abs(model_signature - get_signature(model)) < 1e-3
 
+    @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
+    def test_save_state_model_only(self, use_safetensors):
+        accelerator = Accelerator()
+        model = torch.nn.Linear(10, 10)
+        model = accelerator.prepare(model)
+
+        model_signature = get_signature(model)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            accelerator.save_state(tmpdirname, safe_serialization=use_safetensors, save_model_only=True)
+            # make sure loaded weights match
+            load_checkpoint_in_model(model, tmpdirname)
+            assert abs(model_signature - get_signature(model)) < 1e-3
+
     @parameterized.expand([True, False], name_func=parameterized_custom_name_func)
     def test_save_sharded_model(self, use_safetensors):
         accelerator = Accelerator()
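The new test mirrors `test_save_model` but routes through `save_state`, confirming that the weights round-trip even when only the model was saved. A model-only checkpoint is restored the same way (a sketch reusing the hypothetical "ckpt" directory from above; `load_state` would presumably not work here, since the optimizer and RNG state files are never written):

    import torch

    from accelerate.utils import load_checkpoint_in_model

    # Rebuild a matching architecture, then load weights from the
    # model-only checkpoint directory.
    model = torch.nn.Linear(10, 10)
    load_checkpoint_in_model(model, "ckpt")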
