Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,6 +1792,12 @@ def prepare_model(
if device_placement is None:
device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP

# Ensure we can't double wrap a model
if getattr(model, "_is_accelerate_prepared", False):
if model not in self._models:
self._models.append(model)
return model

self._models.append(model)

# TODO: Look at enabling native TP training directly with a proper config
Expand Down
13 changes: 13 additions & 0 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,19 @@ def test_is_accelerator_prepared(self):
"Valid Dataloader is missing `_is_accelerator_prepared` or is set to `False`"
)

def test_prepare_model_twice_does_not_double_wrap(self):
accelerator = Accelerator()
model = torch.nn.Linear(10, 2)
prepared_model = accelerator.prepare_model(model)
num_models_before = len(accelerator._models)
reprepared_model = accelerator.prepare_model(prepared_model)
assert len(accelerator._models) == num_models_before, (
"prepare_model should not add duplicate entries to _models"
)
assert reprepared_model is prepared_model, (
"prepare_model should return the same object when called twice"
)

@require_cuda_or_xpu
@slow
@require_bnb
Expand Down
Loading