Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/accelerate/big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,15 @@ def register_empty_parameter(module, name, param):
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
# When we have a case of tensor2 = tensor1, it would call the set_attr
# of param, which in turn would call the register_parameter API.
# In this case, the new param is already on meta-device, since it was moved
# previously when it was initialized. Hence, when resetting, you can
# directly assign that tensor instead of re-init. If you re-init you would
# lose the relationship.
module._parameters[name] = (
param if param.device == device else param_cls(module._parameters[name].to(device), **kwargs)
)

def register_empty_buffer(module, name, buffer, persistent=True):
old_register_buffer(module, name, buffer, persistent=persistent)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_big_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ def test_init_empty_weights(self):
assert module.weight.device == torch.device("cpu")
assert module.running_mean.device == torch.device("cpu")

def test_init_empty_weights_with_tie_embedding(self):
    """Weight tying performed under ``init_empty_weights`` must survive.

    Ties parameters two ways — by hand-assigning one module's ``weight``
    to another's, and via a model configured with
    ``tie_word_embeddings=True`` — then checks that each tied pair is
    still the very same tensor object afterwards.
    """
    from transformers.models import Qwen2Config, Qwen2ForCausalLM

    with init_empty_weights():
        module = torch.nn.ModuleList([torch.nn.Embedding(12, 12), torch.nn.Linear(12, 12)])
        # Manual tie: this assignment re-enters the register_parameter hook.
        module[0].weight = module[1].weight

        # Model-level tie: the config asks for input/output embeddings to share.
        qwen2 = Qwen2ForCausalLM(Qwen2Config(tie_word_embeddings=True))

    assert module[0].weight is module[1].weight
    assert qwen2.lm_head.weight is qwen2.model.embed_tokens.weight

def test_init_empty_weights_very_large_model(self):
# This is a 100 billion parameters model.
with init_empty_weights():
Expand Down
1 change: 0 additions & 1 deletion tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,6 @@ def test_infer_auto_device_map_on_t0pp(self):
# Build T0pp on the meta device so a device map can be planned without
# allocating the real (very large) weights.
config = AutoConfig.from_pretrained("bigscience/T0pp")
with init_empty_weights():
    model = AutoModelForSeq2SeqLM.from_config(config)
# NOTE(review): the PR diff marks this call as deleted — tying is expected
# to survive init_empty_weights without an explicit re-tie.
model.tie_weights()

# NOTE(review): presumably pins the T5-style "wo" feed-forward output
# projections to fp32 — confirm the intent against the full test.
special_dtypes = {n: torch.float32 for n, _ in model.named_parameters() if "wo" in n}
max_memory = {0: 10**10, 1: 10**10, "cpu": 10**10}
Expand Down
Loading