Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/accelerate/big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,14 @@ def register_empty_parameter(module, name, param):
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
# When we have a case of tensor2 = tensor1, it would call the set_attr
# of param, which in turn would call the register_parameter API.
# In this case, the new param is already on meta-device, since it was moved
# previously when it was initialized. Hence, when resetting, you can
# directly assign that tensor instead of re-init. If you re-init you would
# lose the relationship.
module._parameters[name] = param if param.device == device else \
param_cls(module._parameters[name].to(device), **kwargs)

def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ def test_init_empty_weights(self):
assert module.weight.device == torch.device("cpu")
assert module.running_mean.device == torch.device("cpu")

def test_init_empty_weights_with_tie_embedding(self):
    """Tying two parameters under init_empty_weights must preserve identity.

    Assigning one module's weight to another (weight tying) goes through
    register_parameter; the patched hook should keep the shared tensor
    rather than re-initializing a fresh meta parameter, which would break
    the tie.
    """
    with init_empty_weights():
        layers = torch.nn.ModuleList([torch.nn.Embedding(12, 12), torch.nn.Linear(12, 12)])
        # Tie the embedding weight to the linear layer's weight.
        layers[0].weight = layers[1].weight
        # Both modules must reference the exact same parameter object.
        assert layers[0].weight is layers[1].weight

Comment on lines +191 to +202
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also check if this works with transformers? We used to require users to call tie_weights, and it would be really nice if this is not needed anymore.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tie weights testing based on qwen2.

def test_init_empty_weights_very_large_model(self):
# This is a 100 billion parameters model.
with init_empty_weights():
Expand Down