Fix torch.compile recompilation issue with HF modeling + TP #7

Status: Open
Fixes bug #6.
TODO: the change still needs to be applied in transformers v5. That requires waiting for v5 to stabilize a bit before switching the torchtitan transformers modeling backend to v5 (for now, it relies on 4.57.1).

Issue

#6: torch.compile recompiles every attention layer when the HF modeling code is combined with TP.
Fix

In transformers, in `modeling_llama.py`, change the attention call so that `hidden_states` is passed positionally rather than as a kwarg (see the diff). Reproduce with:

```
./tooling_dev/debug_local.sh debugperf_large --compile
```

Explanation

When torch.compile traces your model, it creates a compiled graph along with guards. Guards are conditions that must be true for that graph to be reused; if a guard fails, torch.compile recompiles.

In `modeling_llama.py`, the attention is invoked as `self.self_attn(hidden_states=hidden_states, ...)`, i.e. with kwargs. TP wraps the module by registering a forward pre-hook, and depending on whether the module is called with kwargs or not, a different registration path is taken (cf https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L576):

```python
module.register_forward_pre_hook(lambda _, inputs, kwargs: some_fn(inputs, kwargs), with_kwargs=True)
```

At call time, `nn.Module` checks whether each hook was registered with kwargs (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1808):

```python
if hook_id in self._forward_pre_hooks_with_kwargs:
```

Hook ids come from a global counter, so the `with_kwargs` registration puts a different `hook_id` into `_forward_pre_hooks_with_kwargs` for every attention layer. torch.compile guards on that dict, and the guard fails from one layer to the next:

```
___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)
```

Without kwargs, `self._forward_pre_hooks_with_kwargs` stays empty (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1679C13-L1679C48), so the check above is never triggered; the guard on the (empty) dict is identical for every attention layer, and nothing recompiles.
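To make the mechanism concrete, here is a minimal standalone sketch (an illustration, not code from this PR; assumes torch >= 2.0) showing how `with_kwargs=True` registrations populate `_forward_pre_hooks_with_kwargs` with a different id per module:

```python
import torch.nn as nn

# Hook handle ids come from a global counter, so every registration gets
# a fresh id, and with_kwargs=True records that id in
# _forward_pre_hooks_with_kwargs -- the dict torch.compile guards on.

kwargs_layers = [nn.Linear(4, 4) for _ in range(3)]
for layer in kwargs_layers:
    # Registration path TP takes when the module is called with kwargs:
    layer.register_forward_pre_hook(
        lambda _, inputs, kwargs: (inputs, kwargs), with_kwargs=True
    )
    print(dict(layer._forward_pre_hooks_with_kwargs))
# Prints e.g. {0: True}, {1: True}, {2: True}: a different key per layer,
# so a guard like ___dict_contains(0, ..._forward_pre_hooks_with_kwargs)
# holds for the first layer only, and every other layer recompiles.

positional_layers = [nn.Linear(4, 4) for _ in range(3)]
for layer in positional_layers:
    # Registration path TP takes for purely positional calls:
    layer.register_forward_pre_hook(lambda _, inputs: inputs)
    print(dict(layer._forward_pre_hooks_with_kwargs))
# Prints {} every time: the guard on the empty dict is identical across
# layers, so one compiled graph is reused.
```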
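To check the behavior on either side of the fix, standard torch.compile logging (not specific to this repo) can surface the recompilation reasons:

```
TORCH_LOGS="recompiles" ./tooling_dev/debug_local.sh debugperf_large --compile
```

Before the fix this should log a failing `___dict_contains(...)` guard for each attention layer; after it, the attention layers should reuse a single graph.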
torch.compiletraces your model, it creates a compiled graph along with guards. Guards are conditions that must be true for that graph to be reused. If guard fails,torch.compilewill recompiles.modeling_llama.py, theself.attn(hidden_states=hidden_states)is called withkwargsregister_forward_pre_hook. However, depending on if you usekwargsor not, it will call different function (cf https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L576).module.register_forward_pre_hook(lambda _, inputs, kwargs: some_fn(inputs, kwargs), with_kwargs=Trueif hook_id in self._forward_pre_hooks_with_kwargs:(cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1808)kwargswill results in differenthook_id, hence the error___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)kwargs,self._forward_pre_hooks_with_kwargswill always be empty (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1679C13-L1679C48) so the if check is not triggered, so each attention layer has samehook_id, thus no recompile