handle weight sharing with init_on_device #3752
Changes from 3 commits
```diff
@@ -188,6 +188,18 @@ def test_init_empty_weights(self):
             assert module.weight.device == torch.device("cpu")
             assert module.running_mean.device == torch.device("cpu")
 
+    def test_init_empty_weights_with_tie_embedding(self):
+        with init_empty_weights():
+            module = torch.nn.ModuleList([torch.nn.Embedding(12, 12), torch.nn.Linear(12, 12)])
+            # tie embedding
+            module[0].weight = module[1].weight
+
+            from transformers.models import Qwen2Config, Qwen2ForCausalLM
+
+            qwen2 = Qwen2ForCausalLM(Qwen2Config(tie_word_embeddings=True))
+
+        assert module[0].weight is module[1].weight
+        assert qwen2.lm_head.weight is qwen2.model.embed_tokens.weight
```
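The test above checks that weight tying survives `init_empty_weights`. As a rough, framework-free sketch of the underlying failure mode (plain-Python stand-ins, not accelerate's or PyTorch's actual classes; `Param`, `Module`, `build`, and the hook names are all illustrative): a device-placement hook that re-wraps every assigned parameter silently unties shared weights, while a hook that memoizes the old-to-new mapping preserves `is` identity.

```python
class Param:
    """Minimal stand-in for torch.nn.Parameter."""
    def __init__(self, device="cpu"):
        self.device = device

class Module:
    """Minimal stand-in for torch.nn.Module: attribute writes go through a hook."""
    hook = None  # class-level interception hook, set while a context is active

    def __setattr__(self, name, value):
        if isinstance(value, Param) and Module.hook is not None:
            value = Module.hook(value)
        object.__setattr__(self, name, value)

def naive_hook(param):
    # Re-wraps on every assignment: assigning the SAME source parameter
    # twice yields two DIFFERENT objects, silently untying shared weights.
    return Param(device="meta")

def sharing_aware_hook():
    cache = {}  # id(old param) -> replacement, so tied weights stay tied
    def hook(param):
        if id(param) not in cache:
            cache[id(param)] = Param(device="meta")
        return cache[id(param)]
    return hook

def build(hook):
    """Build two modules that share one weight, under the given hook."""
    Module.hook = hook
    try:
        embed, head = Module(), Module()
        shared = Param()
        embed.weight = shared
        head.weight = shared  # tie the weights
    finally:
        Module.hook = None
    return embed, head

embed, head = build(naive_hook)
print(embed.weight is head.weight)   # False: tying silently broken

embed, head = build(sharing_aware_hook())
print(embed.weight is head.weight)   # True: identity preserved
print(embed.weight.device)           # meta
```

The `assert module[0].weight is module[1].weight` check in the diff is exactly the identity property the second hook preserves.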
Comment on lines +191 to +202

**Member:** Can you also check if this works with transformers? We used to require users to call

**Author:** Added tie weights testing based on `qwen2`.
```diff
     def test_init_empty_weights_very_large_model(self):
         # This is a 100 billion parameters model.
         with init_empty_weights():
```
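The "100 billion parameters" comment in this test is worth quantifying; a back-of-the-envelope sketch (assuming fp32 weights at 4 bytes per parameter) shows why such a model cannot simply be materialized in RAM, which is what meta-device initialization avoids:

```python
# Rough memory cost of materializing a 100B-parameter model.
# Assumes fp32 (4 bytes/param); halve for fp16/bf16.
n_params = 100_000_000_000
bytes_per_param = 4

total_gib = n_params * bytes_per_param / 1024**3
print(f"{total_gib:.0f} GiB")  # 373 GiB just to hold the weights

# On the meta device (what init_empty_weights uses), tensors carry only
# shape/dtype metadata, so "allocating" them costs essentially nothing.
```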
**Member:** I mean, try with an example showing that we don't need `tie_weights` anymore, like in `test_infer_auto_device_map_on_t0pp`. Remove that for now and we can update the tests later.

**Author:** I removed `tie_weights` in `test_infer_auto_device_map_on_t0pp` and it passed. Does this PR need to remove all of the `tie_weights` calls in the tests, or should we keep some `tie_weights` calls to ensure that `tie_weights` still works as expected?