Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/accelerate/big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,14 @@ def register_empty_parameter(module, name, param):
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
# When we have a case of tensor2 = tensor1, it would call the set_attr
# of param, which in turn would call the register_parameter API.
# In this case, the new param is already on meta-device, since it was moved
# previously when it was initialized. Hence, when resetting, you can
# directly assign that tensor instead of re-init. If you re-init you would
# lose the relationship.
module._parameters[name] = param if param.device == device else \
param_cls(module._parameters[name].to(device), **kwargs)

def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ def test_init_empty_weights(self):
assert module.weight.device == torch.device("cpu")
assert module.running_mean.device == torch.device("cpu")

def test_init_empty_weights_with_tie_embedding(self):
    """Weights tied under ``init_empty_weights`` must keep their identity.

    Assigning ``module_a.weight = module_b.weight`` routes through the patched
    ``register_parameter`` hook; if the hook re-instantiates an already-meta
    parameter instead of assigning it directly, the tie is silently broken.
    This checks both a manual tie and a config-driven tie, without any
    explicit ``tie_weights()`` call by the user.
    """
    with init_empty_weights():
        # Manual tie between two plain torch modules.
        module = torch.nn.ModuleList([torch.nn.Embedding(12, 12), torch.nn.Linear(12, 12)])
        # tie embedding
        module[0].weight = module[1].weight

        # Config-driven tie: transformers ties lm_head to the input
        # embeddings when tie_word_embeddings=True. Local import keeps the
        # transformers dependency scoped to this test.
        from transformers.models import Qwen2Config, Qwen2ForCausalLM

        qwen2 = Qwen2ForCausalLM(Qwen2Config(tie_word_embeddings=True))

    # Identity (`is`), not equality: the tie must survive initialization.
    assert module[0].weight is module[1].weight
    assert qwen2.lm_head.weight is qwen2.model.embed_tokens.weight

Comment on lines +191 to +202
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also check if this works with transformers? We used to require users to call tie_weights, and it would be really nice if that is no longer needed.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tie weights testing based on qwen2.

def test_init_empty_weights_very_large_model(self):
# This is a 100 billion parameters model.
with init_empty_weights():
Expand Down