Fix torch.compile recompilation issue with HF modeling + TP #7

Open
3outeille wants to merge 7 commits into main from fix-compile-tp

Conversation


3outeille (Member) commented Dec 9, 2025

Fixes #6

TODO: the same change needs to be applied in transformers v5. That requires waiting for v5 to stabilize a bit before switching the torchtitan transformers modeling backend to v5 (for now, it relies on 4.57.1).

Issue

[rank3]:/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:321: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
[rank3]:  warnings.warn(
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] torch._dynamo hit config.recompile_limit (8)
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8]    function: 'forward' (/fsx/ferdinandmom/ferdinand-hf/huggingface/torchtitan/env_torchtitan_official/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:145)
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8]    last reason: 0/7: ___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)  # if hook_id in self._forward_pre_hooks_with_kwargs:  # nn/modules/module.py:1815 in inner
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[rank3]:[rank3]:W1209 10:08:01.582000 2539247 torch/_dynamo/convert_frame.py:1564] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/compile/programming_model.recompilation.html
[rank3]:[rank3]: Traceback (most recent call last):

Fix

  • Apply the current PR changes, plus the following change in transformers' modeling_llama.py:
        hidden_states, _ = self.self_attn(
   -        hidden_states=hidden_states,
   +        hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
  • Run ./tooling_dev/debug_local.sh debugperf_large --compile to confirm (a stricter programmatic check is sketched below)
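A minimal sketch of that stricter check, assuming a recent PyTorch (torch._dynamo.config.error_on_recompile is an existing dynamo config flag; the toy model here is illustrative, not this PR's code):

    import torch
    import torch._dynamo

    # Turn any recompilation into a hard error, instead of the warning that
    # only fires once config.recompile_limit (8) is hit, as in the log above.
    torch._dynamo.config.error_on_recompile = True

    # Toy stand-in for the compiled model; with the fix applied, the real
    # per-layer forward frames should likewise compile exactly once.
    model = torch.compile(torch.nn.Linear(16, 16))
    model(torch.randn(4, 16))  # raises on any subsequent guard failure

Running the real training loop under this flag (or under TORCH_LOGS="recompiles", as the warning suggests) should stay silent once the kwargs call is removed.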

Explanation

  • When torch.compile traces your model, it creates a compiled graph along with guards: conditions that must hold for that graph to be reused. If a guard fails, torch.compile recompiles.
  • In modeling_llama.py, self.self_attn(hidden_states=hidden_states, ...) is called with hidden_states passed as a kwarg.
  • When torchtitan applies TP, it registers a forward pre-hook on the module. Depending on whether the inputs are passed as kwargs or not, a different registration path is taken (cf https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L576).
    • In our case, it calls module.register_forward_pre_hook(lambda _, inputs, kwargs: some_fn(inputs, kwargs), with_kwargs=True)
  • Calling this function is problematic because the traced forward then hits the check if hook_id in self._forward_pre_hooks_with_kwargs: (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1808)
    • Since hook handles draw their ids from a global counter, each attention layer registers a different hook_id, hence the failing guard ___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)
  • When we don't use kwargs, self._forward_pre_hooks_with_kwargs stays empty (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1679C13-L1679C48), so the if check is never triggered and every attention layer shares the same (empty) hook state, thus no recompile (see the sketch after this list).
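A minimal standalone sketch of the bookkeeping above (toy module and hooks are illustrative, not torchtitan code), showing how with_kwargs=True leaves a different key per layer in _forward_pre_hooks_with_kwargs while the positional path leaves it empty:

    import torch.nn as nn

    class Attn(nn.Module):  # stand-in for the attention layer
        def forward(self, hidden_states):
            return hidden_states

    a, b = Attn(), Attn()

    # Kwargs path taken by TP when the call site uses kwargs: hook handles
    # draw ids from a global counter, so each layer records a different key.
    a.register_forward_pre_hook(lambda m, args, kw: (args, kw), with_kwargs=True)
    b.register_forward_pre_hook(lambda m, args, kw: (args, kw), with_kwargs=True)
    print(dict(a._forward_pre_hooks_with_kwargs))  # e.g. {0: True}
    print(dict(b._forward_pre_hooks_with_kwargs))  # e.g. {1: True}, differs per layer

    # Positional path: the kwargs dict stays empty for every layer, so the
    # traced `if hook_id in self._forward_pre_hooks_with_kwargs` check sees
    # identical (empty) state everywhere and the compiled graph is reused.
    c = Attn()
    c.register_forward_pre_hook(lambda m, args: args)
    print(dict(c._forward_pre_hooks_with_kwargs))  # {}

This matches the guard in the log: ___dict_contains(148, ...) can only hold for the one layer whose hook id happens to be 148, so every other layer forces a recompile.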

@3outeille 3outeille changed the base branch from main to improve_hf_throughput December 9, 2025 14:14
@3outeille 3outeille changed the base branch from improve_hf_throughput to main December 9, 2025 15:52