51 changes: 38 additions & 13 deletions torchtitan/experiments/graph_trainer/make_fx_tracer.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
@@ -247,27 +248,51 @@ def _patched(t_outputs, *args, **kwargs): # type: ignore[no-untyped-def]


def _copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
"""Copy forward node metadata (custom) to later nodes sharing the same seq_nr.
"""Copy forward metadata to backward nodes across all nested FX subgraphs.

Walks the graph in a single pass. The first node seen for each seq_nr is
treated as the forward node.
Subsequent nodes with the same seq_nr (typically backward nodes) receive
the forward node's custom metadata.
Uses a two-pass approach over all submodule graphs (including HOP subgraphs
like score_mod/mask_mod). Pass 1 collects forward nodes by seq_nr; pass 2
copies custom/nn_module_stack/stack_trace from the matching forward node to
each backward node. Backward nodes are identified by the autograd engine's
``autograd_backward`` tag on ``node.meta``.
"""

def _is_backward(node: torch.fx.Node) -> bool:
return node.meta.get("autograd_backward", False)

seq_nr_to_fwd_node: dict[int, torch.fx.Node] = {}

for node in fx_g.graph.nodes:
if node.op not in ("call_function", "get_attr") or "seq_nr" not in node.meta:
for submod in fx_g.modules():
if not isinstance(submod, torch.fx.GraphModule):
continue
seq_nr = node.meta["seq_nr"]
if seq_nr not in seq_nr_to_fwd_node:
seq_nr_to_fwd_node[seq_nr] = node
else:
fwd_node = seq_nr_to_fwd_node[seq_nr]
for node in submod.graph.nodes:
if (
node.op not in ("call_function", "get_attr")
or "seq_nr" not in node.meta
or _is_backward(node)
):
continue
seq_nr = node.meta["seq_nr"]
if seq_nr not in seq_nr_to_fwd_node:
seq_nr_to_fwd_node[seq_nr] = node

for submod in fx_g.modules():
if not isinstance(submod, torch.fx.GraphModule):
continue
for node in submod.graph.nodes:
if (
node.op not in ("call_function", "get_attr")
or "seq_nr" not in node.meta
or not _is_backward(node)
):
continue
fwd_node = seq_nr_to_fwd_node.get(node.meta["seq_nr"])
if fwd_node is None or fwd_node is node:
continue

custom = fwd_node.meta.get("custom")
if custom:
node.meta.setdefault("custom", {}).update(custom)
node.meta.setdefault("custom", {}).update(copy.deepcopy(custom))
nn_module_stack = fwd_node.meta.get("nn_module_stack")
if nn_module_stack is not None:
node.meta["nn_module_stack"] = nn_module_stack.copy()
191 changes: 189 additions & 2 deletions torchtitan/experiments/graph_trainer/tests/test_trace_module.py
@@ -93,7 +93,10 @@ def _apply_regional_inductor(traced_result):
break

context = torch._guards.TracingContext(fake_mode)
with torch._guards.tracing(context):
with (
torch._guards.tracing(context),
torch._functorch.config.patch("remat_using_tags_for_fwd_loss_bwd_graph", False),
):
traced_result.gm = regional_inductor(traced_result.gm)

traced_result.gm.graph.set_codegen(CodeGen())
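
For context, torch._functorch.config.patch(...) scopes the override to the with-block and restores the previous value on exit, even if regional_inductor raises. A small sketch of that behavior (the flag name is taken from the diff above and assumed to exist in this PyTorch build):

import torch

flag = "remat_using_tags_for_fwd_loss_bwd_graph"
previous = getattr(torch._functorch.config, flag)
with torch._functorch.config.patch(flag, False):
    # Inside the block the override is visible to anything that reads the config.
    assert getattr(torch._functorch.config, flag) is False
# On exit the original value is restored.
assert getattr(torch._functorch.config, flag) == previous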
@@ -450,12 +453,15 @@ def test_copy_fwd_metadata_propagates_custom(self):
# Run the copy pass again
_copy_fwd_metadata_to_bw_nodes(gm)

def is_backward(node: torch.fx.Node) -> bool:
return node.meta.get("autograd_backward", False)

# Check that bwd nodes with shared seq_nr got the custom metadata
for node in gm.graph.nodes:
if node.op != "call_function" or "seq_nr" not in node.meta:
continue
seq_nr = node.meta["seq_nr"]
if node is not seq_nr_first.get(seq_nr):
if node is not seq_nr_first.get(seq_nr) and is_backward(node):
# This is a backward node
custom = node.meta.get("custom")
self.assertIsNotNone(
@@ -464,6 +470,21 @@ def test_copy_fwd_metadata_propagates_custom(self):
)
self.assertEqual(custom.get("test_key"), "test_value")

def test_copy_fwd_metadata_uses_backward_tagging(self):
graph = torch.fx.Graph()
fwd = graph.call_function(torch.ops.aten.add.Tensor, args=(1, 2))
fwd.meta["seq_nr"] = 7
fwd.meta["custom"] = {"test_key": "test_value"}
bwd = graph.call_function(torch.ops.aten.mul.Tensor, args=(fwd, 3))
bwd.meta["seq_nr"] = 7
bwd.meta["autograd_backward"] = True
graph.output(bwd)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)

_copy_fwd_metadata_to_bw_nodes(gm)

self.assertEqual(bwd.meta["custom"].get("test_key"), "test_value")

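The hand-built test above exercises only the root graph; in real traces the pass also has to reach nested FX subgraphs such as flex attention's score_mod/mask_mod, which is why both passes in _copy_fwd_metadata_to_bw_nodes iterate fx_g.modules() rather than fx_g.graph.nodes alone. A tiny sketch of that traversal, using a manually attached subgraph instead of a real higher-order op:

import torch

def leaf_graph_module():
    g = torch.fx.Graph()
    x = g.placeholder("x")
    g.output(g.call_function(torch.ops.aten.relu.default, args=(x,)))
    return torch.fx.GraphModule(torch.nn.Module(), g)

root_graph = torch.fx.Graph()
y = root_graph.placeholder("x")
root_graph.output(root_graph.call_function(torch.ops.aten.neg.default, args=(y,)))
root = torch.fx.GraphModule(torch.nn.Module(), root_graph)
root.add_submodule("subgraph_0", leaf_graph_module())

# Iterating modules() visits the root GraphModule and the attached subgraph,
# so nodes in both graphs are considered by the metadata-copy pass.
print([name for name, m in root.named_modules() if isinstance(m, torch.fx.GraphModule)])
# ['', 'subgraph_0']
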
def test_backward_nodes_have_stack_trace(self):
"""Verify that backward nodes get stack_trace from their forward counterpart."""
model = SimpleMLP().to(device=self.DEVICE, dtype=self.DTYPE)
@@ -661,6 +682,172 @@ def test_deepseek_v3(self):
config = deepseekv3_configs["debugmodel"]()
self._run_model_test(DeepSeekV3Model, config)

def test_deepseek_v3_flex_attention(self):
"""Multi-step bitwise test for DeepSeek MLA + flex attention + regional_inductor.

Uses a tiny model with small head dims to stay within triton shared
memory limits. Annotates FlexAttention.forward via annotate_fn before
tracing so compile_with_inductor flows into the graph naturally.
"""
from torch.fx.traceback import annotate_fn
from torch.nn.attention.flex_attention import and_masks

from torchtitan.models.common.attention import (
create_attention_mask,
FlexAttention,
get_causal_mask_mod,
get_document_mask_mod,
)
from torchtitan.models.common.linear import Linear
from torchtitan.models.common.rmsnorm import RMSNorm
from torchtitan.models.common.rope import RoPE
from torchtitan.models.deepseek_v3.model import Attention as DSAttention

dim = 64
n_heads = 4
rope_dim = 16
seq_len = 64
vocab_size = 128

# Build a tiny model: embedding -> MLA flex attention -> projection
class TinyFlexMLA(nn.Module):
def __init__(self):
super().__init__()
self.embed = nn.Embedding(vocab_size, dim)
kv_lora_rank = 32
qk_nope_head_dim = 16
v_head_dim = 16
qk_head_dim = qk_nope_head_dim + rope_dim
self.attn = DSAttention(
DSAttention.Config(
n_heads=n_heads,
dim=dim,
q_lora_rank=0,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=rope_dim,
v_head_dim=v_head_dim,
q_norm=RMSNorm.Config(normalized_shape=1),
kv_norm=RMSNorm.Config(normalized_shape=kv_lora_rank),
inner_attention=FlexAttention.Config(),
mask_type="block_causal",
wq=Linear.Config(
in_features=dim,
out_features=n_heads * qk_head_dim,
),
wkv_a=Linear.Config(
in_features=dim,
out_features=kv_lora_rank + rope_dim,
),
wkv_b=Linear.Config(
in_features=kv_lora_rank,
out_features=n_heads * (qk_nope_head_dim + v_head_dim),
),
wo=Linear.Config(
in_features=n_heads * v_head_dim,
out_features=dim,
),
),
)
self.rope = RoPE(
RoPE.Config(
dim=rope_dim,
max_seq_len=seq_len,
backend="complex",
scaling="none",
)
)
self.proj = nn.Linear(dim, vocab_size)

def init_states(self, buffer_device=None):
self.rope._init_self_buffers(
buffer_device=buffer_device or torch.device("cuda")
)

def forward(self, tokens, block_mask):
x = self.embed(tokens)
x = self.attn(x, self.rope.cache, block_mask)
return self.proj(x)

model_ref = TinyFlexMLA().to(device=self.DEVICE, dtype=self.DTYPE)
with torch.no_grad():
model_ref.init_states(buffer_device=torch.device(self.DEVICE))
model_test = TinyFlexMLA().to(device=self.DEVICE, dtype=self.DTYPE)
model_test.load_state_dict(model_ref.state_dict())
with torch.no_grad():
model_test.init_states(buffer_device=torch.device(self.DEVICE))

tokens = torch.randint(0, vocab_size, (1, seq_len), device=self.DEVICE)
labels = torch.randint(0, vocab_size, (1, seq_len), device=self.DEVICE)
# Insert EOS tokens to create document boundaries for block_causal mask
tokens[:, 15::16] = 1
block_mask = create_attention_mask(
and_masks(get_causal_mask_mod(), get_document_mask_mod(tokens, eos_id=1)),
B=1,
H=None,
Q_LEN=seq_len,
KV_LEN=seq_len,
)

# Annotate FlexAttention.forward so compile_with_inductor flows into
# the traced graph. Restore the original after tracing.
orig_forward = FlexAttention.forward
FlexAttention.forward = annotate_fn(
{
"compile_with_inductor": {
"inductor_configs": FlexAttention.inductor_configs
}
}
)(FlexAttention.forward)
try:
train_step = make_train_step(get_loss)
maybe_register_blockmask_pytree_node()
traced = trace_train_step(train_step)(model_ref, tokens, block_mask, labels)
finally:
FlexAttention.forward = orig_forward

# Verify flex attention HOPs got the annotation
for node in traced.gm.graph.nodes:
if node.target in {
torch.ops.higher_order.flex_attention,
torch.ops.higher_order.flex_attention_backward,
}:
custom = node.meta.get("custom", {})
self.assertIn(
"compile_with_inductor",
custom,
f"{node.name} missing compile_with_inductor annotation",
)

_apply_regional_inductor(traced)

opt_ref = torch.optim.Adam(model_ref.parameters(), lr=self.LR)
opt_copy = torch.optim.Adam(model_test.parameters(), lr=self.LR)

for step in range(1, self.NUM_STEPS + 1):
logits_ref = model_ref(tokens, block_mask)
loss_ref = get_loss(logits_ref, labels)
loss_ref.backward()
grads_ref = [p.grad.clone() for p in model_ref.parameters()]
opt_ref.step()
opt_ref.zero_grad()

wrapped = run_traced_train_step(
traced, model_test, tokens, block_mask, labels
)
loss_tr = wrapped[0]
grads_tr = wrapped[1:]
for p, g in zip(model_test.parameters(), grads_tr, strict=True):
p.grad = g
opt_copy.step()
opt_copy.zero_grad()

self.assertTrue(
torch.equal(loss_ref, loss_tr), f"Step {step}: loss mismatch"
)
for gr, gt in zip(grads_ref, grads_tr, strict=True):
self.assertTrue(torch.equal(gr, gt), f"Step {step}: grad mismatch")
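
For readers unfamiliar with the mask built in this test: "block_causal" here means causal attention additionally confined to EOS-delimited documents. A standalone sketch using only the public flex_attention API (create_block_mask and and_masks; torchtitan's create_attention_mask/get_document_mask_mod helpers are not used, the document-boundary convention may differ from theirs, and a CUDA device is assumed as in the test):

import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

seq_len, eos_id = 64, 1
tokens = torch.randint(0, 128, (1, seq_len), device="cuda")
tokens[:, 15::16] = eos_id  # same document-boundary layout as the test

# Document id per position: count of EOS tokens seen up to and including it,
# so each EOS token is grouped with the following document in this sketch.
doc_id = (tokens == eos_id).long().cumsum(dim=-1)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def same_document(b, h, q_idx, kv_idx):
    return doc_id[b, q_idx] == doc_id[b, kv_idx]

block_mask = create_block_mask(
    and_masks(causal, same_document), B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len
)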

def test_llama4(self):
from torchtitan.models.llama4 import llama4_configs
from torchtitan.models.llama4.model import Llama4Model