diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py
index bddcaa8a0cc..f025fefa92a 100644
--- a/src/accelerate/big_modeling.py
+++ b/src/accelerate/big_modeling.py
@@ -132,7 +132,15 @@ def register_empty_parameter(module, name, param):
             param_cls = type(module._parameters[name])
             kwargs = module._parameters[name].__dict__
             kwargs["requires_grad"] = param.requires_grad
-            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+            # When we have a case of tensor2 = tensor1, it would call the set_attr
+            # of param, which in turn would call the register_parameter API.
+            # In this case, the new param is already on meta-device, since it was moved
+            # previously when it was initialized. Hence, when resetting, you can
+            # directly assign that tensor instead of re-init. If you re-init you would
+            # lose the relationship.
+            module._parameters[name] = (
+                param if param.device == device else param_cls(module._parameters[name].to(device), **kwargs)
+            )
 
     def register_empty_buffer(module, name, buffer, persistent=True):
         old_register_buffer(module, name, buffer, persistent=persistent)
diff --git a/tests/test_big_modeling.py b/tests/test_big_modeling.py
index 7c960745565..3764ba2b98d 100644
--- a/tests/test_big_modeling.py
+++ b/tests/test_big_modeling.py
@@ -188,6 +188,18 @@ def test_init_empty_weights(self):
         assert module.weight.device == torch.device("cpu")
         assert module.running_mean.device == torch.device("cpu")
 
+    def test_init_empty_weights_with_tie_embedding(self):
+        with init_empty_weights():
+            module = torch.nn.ModuleList([torch.nn.Embedding(12, 12), torch.nn.Linear(12, 12)])
+            # tie embedding
+            module[0].weight = module[1].weight
+
+        from transformers.models import Qwen2Config, Qwen2ForCausalLM
+
+        qwen2 = Qwen2ForCausalLM(Qwen2Config(tie_word_embeddings=True))
+        assert module[0].weight is module[1].weight
+        assert qwen2.lm_head.weight is qwen2.model.embed_tokens.weight
+
     def test_init_empty_weights_very_large_model(self):
         # This is a 100 billion parameters model.
         with init_empty_weights():
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 4857b3b5df2..68f3eae039c 100644
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -677,7 +677,6 @@ def test_infer_auto_device_map_on_t0pp(self):
         config = AutoConfig.from_pretrained("bigscience/T0pp")
         with init_empty_weights():
             model = AutoModelForSeq2SeqLM.from_config(config)
-        model.tie_weights()
 
         special_dtypes = {n: torch.float32 for n, _ in model.named_parameters() if "wo" in n}
         max_memory = {0: 10**10, 1: 10**10, "cpu": 10**10}