[graph_trainer] Replace trace_module/run_traced_module with minimal_fx_trace API #2753

Closed
tugsbayasgalan wants to merge 18 commits into gh/tugsbayasgalan/8/base from gh/tugsbayasgalan/8/head

Conversation

@tugsbayasgalan (Contributor) commented Mar 30, 2026

Stack from ghstack (oldest at bottom):

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:

  • Rename trace_module -> minimal_fx_tracer. It is now designed to work only over the training loop.
  • Drop run_traced_module: the returned TracedResult is directly callable, which is more ergonomic.
  • Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace
    time and validated at execute time to catch module structure mismatches.
  • Install TracingContext before make_fx so invoke_subgraph deduplication works.
  • Validate that all pytree leaves in args are tensors or primitives
    (int/float/bool/str). Non-primitive values (callables, custom objects) must
    be captured in fn's closure or registered via pytree.register_constant /
    register_pytree_node (a registration sketch follows this list).
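
For non-tensor leaves, here is a minimal registration sketch. The config class is hypothetical; the tests register BlockMask through the same register_pytree_node route.

import torch.utils._pytree as pytree

# Hypothetical non-tensor argument type; without registration it would fail
# the leaf validation described above.
class RopeConfig:
    def __init__(self, theta: float, max_seq_len: int) -> None:
        self.theta = theta
        self.max_seq_len = max_seq_len

pytree.register_pytree_node(
    RopeConfig,
    # flatten_fn: return (children, context)
    lambda cfg: ((cfg.theta, cfg.max_seq_len), None),
    # unflatten_fn: rebuild the object from children and context
    lambda children, context: RopeConfig(*children),
)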

trainer.py:

  • Replace FwdBwdStepModule (nn.Module wrapper that only existed because the
    old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain
    function factory. The model is now passed as an arg, loss_fn is captured in
    the closure.

test_trace_module.py:

  • Replace TrainStepModule with _make_train_step plain function factory.
  • Remove _get_params_and_buffers helper (no longer needed).
  • Update all callsites: trace_module -> aot_function, run_traced_module ->
    direct TracedResult.__call__.
  • Register BlockMask as pytree node at module level so flex_attention tests
    pass the leaf validation.
  • Add test_module_not_first_arg: module at position 1 in args.
  • Add test_multiple_modules: two nn.Modules interleaved with a tensor.
  • Add test_mismatched_module_raises: FQN validation catches wrong module.
  • Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3,
gpt_oss, flex_attention_annotations).
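
As a rough sketch of the resulting flow (the import path, name spelling, and exact signature follow the notes above and may differ from the final code):

import torch
import torch.nn as nn

# Assumed import location; the tracer lives in make_fx_tracer.py.
from torchtitan.experiments.graph_trainer.make_fx_tracer import minimal_fx_tracer


def make_train_step(loss_fn):
    # loss_fn is captured in the closure, so it is not a graph input.
    def train_step(model, inputs, labels):
        loss = loss_fn(model(inputs), labels)
        params = [p for p in model.parameters() if p.requires_grad]
        grads = torch.autograd.grad(loss, params)
        return [loss] + list(grads)

    return train_step


model = nn.Linear(16, 4)
inputs, labels = torch.randn(8, 16), torch.randint(0, 4, (8,))

# Trace once: the nn.Module in args has its params/buffers lifted as graph inputs.
traced = minimal_fx_tracer(make_train_step(nn.functional.cross_entropy), (model, inputs, labels))

# Execute: call the result with the same positional args (live module included);
# fresh params are read from the module on each call, and parameter FQNs are validated.
loss, *grads = traced(model, inputs, labels)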

…ion API

[ghstack-poisoned]
@meta-cla bot added the CLA Signed label Mar 30, 2026
tugsbayasgalan added a commit that referenced this pull request Mar 30, 2026
…ion API

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> aot_function. Takes any callable (not just nn.Module)
  with nn.Module instances auto-detected in args and their params/buffers
  lifted as graph inputs. When fn is an nn.Module, it is prepended to args
  and type(fn).__call__ is used as the callable.
- Delete run_traced_module. TracedResult is now directly callable — pass the
  same positional args (with live modules) to execute the graph. Fresh params
  are read from the modules automatically on each call.
- Store and restore output pytree spec so TracedResult.__call__ returns the
  same pytree structure as the original function (e.g. single tensor, list,
  tuple, dict), not a flat list.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace
  time and validated at execute time to catch module structure mismatches.
- Add _collect_module_params helper for multi-module param extraction.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives
  (int/float/bool/str). Non-primitive values (callables, custom objects) must
  be captured in fn's closure or registered via pytree.register_constant /
  register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the
  old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain
  function factory. The model is now passed as an arg, loss_fn is captured in
  the closure.
- Remove manual params_and_buffers dict construction — TracedResult.__call__
  reads fresh params from the live module automatically.
- Add TODO for investigating loss_fn interaction with non-strict trace.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module ->
  direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests
  pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3,
gpt_oss, flex_attention_annotations).

ghstack-source-id: ed9abe3
Pull Request resolved: #2753

Comment thread on trainer.py:

-class FwdBwdStepModule(nn.Module):
-    """Wraps model + loss_fn + autograd.grad into a single traceable forward.
+def _make_fwd_bwd_step(loss_fn):

Contributor (@anijain2305):

Remove the _ from the beginning. Claude does that. Please make a pass over the other files for similar nits as well.

Contributor Author:

Done
…h aot_function API"

[ghstack-poisoned]
tugsbayasgalan added a commit that referenced this pull request Mar 30, 2026
…ion API

ghstack-source-id: 3a07e8e
Pull Request resolved: #2753
params_and_buffers = {**named_parameters, **named_buffers}
params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
params_len = len(params_and_buffers_flat)
def __init__(
Contributor:

Needs a docstring. What are the params?


def __call__(self, *args: Any) -> list[torch.Tensor]:
    """Execute the traced graph, reading fresh params from modules in ``args``."""
    module_indices_set = set(self.module_indices)
Contributor:

What is module_indices_set?

It's computed for every call; should it be cached?

Contributor Author:

Yeah we can

Contributor Author:

Deleted this option.

unwrapped_args = []
input_layouts: list[SubclassLayout] = []
def aot_function(
Contributor:

Do we already have an aot_function in pytorch/pytorch?

Compare and contrast with PyTorch's version; mention what it DOESN'T do, e.g. no functionalization.

Contributor Author:

Need a better name here tbh... Do you have any ideas? What about minimal_fx_tracer?

Contributor Author:

Renamed to minimal_make_fx

Comment thread on torchtitan/experiments/graph_trainer/make_fx_tracer.py (outdated):
Comment on lines +287 to +288
input_subclass_layouts: list[SubclassLayout],
output_subclass_layouts: list[SubclassLayout],
Contributor:

In the case without DTensor, input/output are just plain tensors. In that case, what are the values for these two args?

Contributor Author (@tugsbayasgalan, Mar 30, 2026):

SubclassLayout(num_tensors=1, meta=None), but yeah I think we can simplify for plain tensors.

Contributor Author:

Changed.

Comment on lines +781 to +783
annotate_flex_attention_for_regional_inductor()
if use_regional_inductor
else contextlib.nullcontext()
Contributor:

maybe_regional_inductor = ...

with maybe_regional_inductor:
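
Roughly, combining the quoted ternary with the suggested name (a sketch, not the final diff; the two identifiers come from the quoted lines above):

import contextlib

maybe_regional_inductor = (
    annotate_flex_attention_for_regional_inductor()
    if use_regional_inductor
    else contextlib.nullcontext()
)
with maybe_regional_inductor:
    ...  # the tracing/compilation that should run under the annotation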

Contributor Author:

ok

rebuilt[i] = user_list[user_idx]
user_idx += 1

with contextlib.ExitStack() as stack:
Contributor:

What's the stack for? Can we live without it?

Contributor Author:

This was originally for supporting multiple models as input; it can be deleted if we only want to support one model.

self.output_subclass_layouts = output_subclass_layouts
self.output_spec = output_spec

def __call__(self, *args: Any) -> list[torch.Tensor]:
Contributor:

What's the rationale for making it a callable?

Not saying it's a bad idea, just looking for the design trade-offs.

Contributor Author:

This way we can automatically fetch the param/buffer FQNs etc. recorded at tracing time and keep the same calling convention as the user callable. With the old design it is a bit awkward.

Comment on lines +22 to +41
Replaced (FwdBwdStepModule):

class FwdBwdStepModule(nn.Module):
    """Wraps model + loss_fn + autograd.grad into a single traceable forward.

    This allows make_fx to trace through the entire fwd+loss+bwd as one graph.
    """

    def __init__(self, model, loss_fn):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, inputs, labels, global_valid_tokens, extra_inputs, extra_kwargs):
        pred = self.model(inputs, **extra_inputs, **extra_kwargs)
        loss = self.loss_fn(pred, labels) / global_valid_tokens
        params = [p for p in self.model.parameters() if p.requires_grad]
        grads = torch.autograd.grad(loss, params)
        return [loss] + list(grads)

Added (make_fwd_bwd_step):

def make_fwd_bwd_step(loss_fn):
    """Return a plain function that traces the entire fwd+loss+bwd step.

    ``loss_fn`` is captured in the closure so it is not a graph input.

    TODO: investigate how loss_fn interacts with non-strict trace. Currently
    it is captured as a closure variable, but non-strict tracing may need it
    registered via pytree.register_constant or passed differently.
    """

    def fwd_bwd_step(
        model, inputs, labels, global_valid_tokens, extra_inputs, extra_kwargs
    ):
        pred = model(inputs, **extra_inputs, **extra_kwargs)
        loss = loss_fn(pred, labels) / global_valid_tokens
        params = [p for p in model.parameters() if p.requires_grad]
        grads = torch.autograd.grad(loss, params)
        return [loss] + list(grads)

    return fwd_bwd_step
Contributor:

This smells off. A more natural design is to have loss_fn as part of the model.
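
(For reference, a rough sketch of that alternative: a hypothetical wrapper, not code from this PR, that folds loss_fn into the module so the traced function only sees the model and data.)

import torch.nn as nn

class ModelWithLoss(nn.Module):
    def __init__(self, model: nn.Module, loss_fn) -> None:
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn

    def forward(self, inputs, labels, global_valid_tokens, **extra):
        pred = self.model(inputs, **extra)
        # Same normalization as the fwd_bwd_step above.
        return self.loss_fn(pred, labels) / global_valid_tokens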

Contributor Author:

Do you think it is ok to have loss_fn as an input to the function we are tracing?

Contributor Author:

Discussed offline. We will go with this pattern.

@SherlockNoMad (Contributor):

@claude help me review.

This is an API design PR; look for opportunities to simplify the design and keep everything minimal.

If this becomes too complicated for a frontier user to hack on, it's bad.

@claude Bot commented Mar 30, 2026

Claude finished @SherlockNoMad's task in 2m 34s


Review: API Design for aot_function / TracedResult

  • Read all changed files and understand the API design
  • Analyze API complexity and simplification opportunities
  • Review existing review comments for context
  • Write comprehensive review feedback

Overall direction is good — replacing trace_module + run_traced_module with a single aot_function → callable TracedResult is cleaner. But there are several places where complexity has been added without a concrete use case to justify it yet. Below are the findings, ordered by importance.


1. Multi-module support adds significant complexity without a use case

(Aligning with @SherlockNoMad's comment)

module_indices, per_module_params, _ModuleParamsMeta, _collect_module_params, the rebuilt arg reconstruction loop, and the ExitStack for multiple _reparametrize_module calls — all of this exists to support multiple nn.Module instances interleaved in args. There's no concrete use case for this today. The only caller (trainer.py:107) passes a single module.

Recommendation: Restrict to a single module for now. This would let you drop:

  • module_indices (list → single int or just "always position 0")
  • _collect_module_params helper
  • _ModuleParamsMeta dataclass (just store fqns and num_params directly on TracedResult)
  • The rebuilt list reconstruction in fn_with_subclass_handling
  • The ExitStack loop (single _reparametrize_module context)
  • test_multiple_modules and test_module_not_first_arg

If multi-module is needed later, it can be added then. Right now it's speculative generality that makes the code harder to follow.


2. Naming: aot_function collides with PyTorch's torch._functorch.aot_autograd.aot_function

(Per @SherlockNoMad's comment)

This will confuse anyone who's worked with torch.compile internals. The PyTorch aot_function does functionalization + partitioning; this one does neither. The name implies capabilities that don't exist here.

Suggested alternatives: trace_training_step, make_fx_trace, trace_callable, or even just keeping trace_function (since you've generalized beyond modules).


3. TracedResult.__init__ has too many parameters (6)

make_fx_tracer.py:282-296 — The constructor takes 6 positional args. This is a sign the object is carrying too much state. If you simplify to single-module (point 1), this drops to ~4 which is more manageable.

Also, TracedResult.__init__ is missing a docstring (per @SherlockNoMad's comment).


4. module_indices_set is recomputed every __call__

In make_fx_tracer.py:300, set(self.module_indices) is computed on every forward call. As @SherlockNoMad noted, this should be cached in __init__. It's a small thing but signals that __call__ wasn't designed for hot-path efficiency.


5. FQN validation in __call__ is heavy for a hot path

make_fx_tracer.py:306-317 — On every call, you:

  1. Call mod.named_parameters() + mod.named_buffers() (these walk the module tree)
  2. Build a dict
  3. Extract FQN list
  4. Compare with saved FQNs

For a training loop that runs thousands of iterations, this is unnecessary overhead after the first validation. Consider either:

  • Validating once (on first call) and caching a "validated" flag (sketched just below), or
  • Dropping FQN validation entirely and trusting the caller (the module reference doesn't change in training loops)
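
For illustration, a minimal sketch of the validate-once option; the class and attribute names here are hypothetical, not the PR's actual code:

import torch.nn as nn

class FqnCheck:
    """Validate parameter/buffer FQNs against the traced module once, then skip."""

    def __init__(self, traced_fqns: list[str]) -> None:
        self.traced_fqns = traced_fqns
        self._validated = False

    def check(self, module: nn.Module) -> None:
        # Walk the module tree only on the first call; later calls are free.
        if self._validated:
            return
        current_fqns = [
            *dict(module.named_parameters(remove_duplicate=False)),
            *dict(module.named_buffers(remove_duplicate=False)),
        ]
        if current_fqns != self.traced_fqns:
            raise ValueError("module structure does not match the traced module")
        self._validated = True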

6. Duplicated param extraction logic between _collect_module_params and TracedResult.__call__

The pattern of {**dict(mod.named_parameters(...)), **dict(mod.named_buffers(...))} appears at both make_fx_tracer.py:261-264 and make_fx_tracer.py:306-309. If you keep multi-module support, at least factor out the "get params dict for a module" step into a shared helper.
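
For example, a shared helper along these lines (the name matches the _get_params_and_buffers helper mentioned in the later review rounds, but the body is a sketch):

import torch
import torch.nn as nn

def _get_params_and_buffers(mod: nn.Module) -> dict[str, torch.Tensor]:
    # One place for the repeated "params + buffers keyed by FQN" pattern.
    return {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }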


7. record_stack_traces=True — is this needed?

make_fx_tracer.py:477 — This adds stack trace recording to every FX node. @SherlockNoMad asked about the stack usage. Stack traces are useful for debugging but add overhead. If it's only needed for _copy_fwd_metadata_to_bw_nodes (which uses seq_nr), check whether seq_nr is still populated without record_stack_traces. If so, disable it for production use.


8. The fn is nn.Module special case adds implicit behavior

make_fx_tracer.py:367-369 — When fn is an nn.Module, it silently gets prepended to args and type(fn).__call__ becomes the callable. This means the calling convention changes depending on the type of fn:

# Module: trace with (tokens,), call with (model, tokens)
traced = aot_function(model, (tokens,))
result = traced(model, tokens)  # caller must know model was prepended

# Function: trace with (model, tokens), call with (model, tokens)  
traced = aot_function(train_step, (model, tokens))
result = traced(model, tokens)

The asymmetry between trace args and call args in the module case is confusing. Since trainer.py already uses the function form (via make_fwd_bwd_step), consider dropping the nn.Module special case entirely. Users who want to trace a module can write aot_function(lambda m, *a: m(*a), (model,) + args) or more naturally just aot_function(model.forward, ...) if you handle bound methods.


9. loss_fn in closure — the TODO is valid and should be addressed before landing

trainer.py:27-29:

# TODO: investigate how loss_fn interacts with non-strict trace.

If this is a known open question, it should be investigated now rather than landing as tech debt. If loss_fn gets captured as a constant during tracing, changing it later won't take effect — that's a subtle correctness issue.


10. Test file: _run_fsdp_model_test duplicates _run_bitwise_test

test_trace_module.py:734-827 largely duplicates test_trace_module.py:472-534. The FSDP version adds FSDP setup and collective op checks, but the core training loop is copy-pasted. This isn't introduced by this PR (it existed before), but it's worth noting since the PR touches both.


Summary

The core idea (callable → trace → callable result) is sound and simpler than the old trace_module + run_traced_module pair. The main bloat comes from:

  1. Multi-module support that has no concrete use case
  2. The nn.Module special case that creates an asymmetric API
  3. Per-call FQN validation that adds hot-path overhead

Addressing #1 and #2 alone would remove ~80 lines and make the code significantly easier for a frontier user to understand and hack on.


…h aot_function API"

[ghstack-poisoned]

…h aot_function API"

[ghstack-poisoned]

…h aot_function API"

[ghstack-poisoned]

…h aot_function API"

[ghstack-poisoned]
@tugsbayasgalan (Contributor Author):

@claude Give a review on this PR. Did I address all of @SherlockNoMad's feedback? Is there further opportunity to simplify the design?

@claude Bot commented Mar 31, 2026

Claude finished @tugsbayasgalan's task in 2m 28s


Review: Feedback resolution and remaining simplification opportunities

  • Read all changed files in their current state
  • Map @SherlockNoMad's feedback items to resolutions
  • Identify remaining simplification opportunities
  • Write comprehensive review

Feedback Resolution Status

Tracking each piece of @SherlockNoMad's feedback:

1. Rename aot_function — name collides with PyTorch's. Status: Addressed — renamed to minimal_fx_tracer
2. TracedResult.__init__ needs docstring. Status: Addressed — docstring added at make_fx_tracer.py:275-293
3. module_indices_set recomputed every __call__ — cache it. Status: Addressed. module_indices_set no longer exists (single-module restriction)
4. Multi-module support — do you have a concrete case? Status: Addressed — restricted to single module at position 0 (make_fx_tracer.py:383-393)
5. What are SubclassLayout args for plain tensors? Status: Partially addressed. @tugsbayasgalan acknowledged it can be simplified but the code still stores SubclassLayout for plain tensors via _unwrap_subclasses returning them as passthrough
6. record_stack_traces — what's the stack for, can we live without it? Status: Not addressed — still True at make_fx_tracer.py:471. @tugsbayasgalan said "Yeah we can" but it's still there. Caution: removing it may break _copy_fwd_metadata_to_bw_nodes, which depends on seq_nr metadata — needs verification before removal
7. What's the rationale for making TracedResult callable? Status: Addressed in discussion — cleaner calling convention
8. loss_fn as closure smells off. Status: Addressed in discussion — pattern agreed upon offline
9. maybe_regional_inductor refactor suggestion. Status: Addressed — refactored in both _run_bitwise_test (test_trace_module.py:433-438) and _run_fsdp_model_test (test_trace_module.py:733-738)
10. _ prefix naming (from @anijain2305). Status: Addressed. make_fwd_bwd_step and make_train_step no longer have the _ prefix

Bug: input_subclass_layouts iteration in DTensor test

test_trace_module.py:268-270:

has_subclass = any(
    layout.meta is not None for layout in traced.input_subclass_layouts
)

input_subclass_layouts is now a dict[int, SubclassLayout]. Iterating over a dict iterates over keys (ints), not values. This should be:

for layout in traced.input_subclass_layouts.values()

Iterating the dict yields int keys, and int has no .meta attribute, so this line should raise AttributeError rather than silently pass. Either this test isn't being run or there's something else going on — worth verifying.



Remaining Simplification Opportunities

1. _unwrap_subclasses / _wrap_subclasses still carry dict-based layout tracking when only subclasses need entries

The old list-based approach stored a SubclassLayout for every arg (including plain tensors with SubclassLayout(1, None)). The new dict-based approach only stores entries for subclasses, which is better. However, _wrap_subclasses still needs num_args to iterate through positions. An alternative is to just return (flat_tensors, layouts_dict) and let callers iterate through flat_tensors directly, using layouts_dict only when rewrapping is needed. This would simplify the reconstruction logic.

2. TracedResult.__call__ re-flattens user args with pytree.tree_flatten on every call

make_fx_tracer.py:341: user_args_flat, _ = pytree.tree_flatten(user_args) runs on every invocation. In the training loop, the structure of args[1:] doesn't change between iterations. Since user_args_spec is already captured at trace time, you could validate the spec matches and avoid the flatten cost. That said, this is a minor optimization — tree_flatten is cheap for shallow structures.

3. record_stack_traces=True should be resolved

As noted above, @SherlockNoMad asked about this and @tugsbayasgalan agreed it could be removed, but it's still present. The concern: _copy_fwd_metadata_to_bw_nodes at make_fx_tracer.py:245-269 relies on seq_nr in node metadata. If record_stack_traces=False causes seq_nr to not be recorded, removing it would silently break forward-to-backward annotation propagation. This needs a quick test: trace with record_stack_traces=False and check if seq_nr is still populated. If yes, remove it. If no, add a comment explaining why it's required.

4. _run_fsdp_model_test still duplicates _run_bitwise_test

test_trace_module.py:687-780 largely duplicates test_trace_module.py:420-483. The FSDP version adds FSDP setup and collective op checks but the core training loop is copy-pasted. This predates this PR, but since you're touching both methods anyway, extracting the shared training loop would reduce ~40 lines.

5. module_indices scan is unnecessary given the position-0 constraint

make_fx_tracer.py:383:

module_indices = [i for i, a in enumerate(args) if isinstance(a, nn.Module)]

Since you require exactly one module at position 0, you could simplify to:

if not isinstance(args[0], nn.Module):
    raise ValueError("args[0] must be an nn.Module")
if any(isinstance(a, nn.Module) for a in args[1:]):
    raise ValueError("Only one nn.Module allowed, at position 0")

This is clearer about the constraint and avoids building a list. Minor, but aligns with the "keep it minimal" goal.

6. Duplicated param extraction pattern

The pattern {**dict(mod.named_parameters(remove_duplicate=False)), **dict(mod.named_buffers(remove_duplicate=False))} appears at both make_fx_tracer.py:397-400 (in minimal_fx_tracer) and make_fx_tracer.py:324-327 (in TracedResult.__call__). A one-liner helper would reduce duplication and make intent clearer.


Summary

Most of @SherlockNoMad's feedback has been addressed. The main open items are:

  1. Bug: input_subclass_layouts dict iteration in DTensor test (test_trace_module.py:269)
  2. Unresolved: record_stack_traces=True@tugsbayasgalan agreed to remove it but it's still there, and removing it may have consequences that need verification
  3. Nice-to-have: The simplification opportunities above (items 4-6) would further reduce code but aren't blockers

The API is significantly cleaner than before — single function entry point, callable result, no more run_traced_module + manual param dict construction. The design is heading in the right direction.



…h aot_function API"

[ghstack-poisoned]
@tugsbayasgalan (Contributor Author):

@claude Give me another round of review.

@claude Bot commented Mar 31, 2026

Claude finished @tugsbayasgalan's task in 2m 51s


Review Round 3

  • Read current state of all changed files
  • Get the diff against base branch
  • Check resolution of previous review items
  • Identify any new or remaining issues
  • Write comprehensive review

The code is significantly cleaner than the previous rounds. Most of @SherlockNoMad's feedback has been addressed. Here's where things stand.


Previous feedback: resolution status

Rename aot_function (name collision). Status: Done — now minimal_fx_tracer
TracedResult.__init__ docstring. Status: Done (make_fx_tracer.py:282-302)
module_indices_set recomputed every call. Status: Done — eliminated with single-module restriction
Multi-module support without use case. Status: Done — restricted to single module at args[0]
_ prefix naming. Status: Done (make_fwd_bwd_step, make_train_step)
maybe_regional_inductor formatting. Status: Done (test_trace_module.py:434-438, test_trace_module.py:734-738)
Duplicated param extraction. Status: Done — factored into _get_params_and_buffers helper
FQN validation every call. Status: Done (_validated flag caches after first call)
input_subclass_layouts dict iteration bug. Status: Done (test_trace_module.py:270 uses .values())
record_stack_traces=True. Status: Not addressed — still present at make_fx_tracer.py:471

Remaining issues

1. record_stack_traces=True is still unresolved

make_fx_tracer.py:471: @SherlockNoMad asked about this, @tugsbayasgalan agreed it could be removed, but it's still set to True.

The question is whether seq_nr metadata (used by _copy_fwd_metadata_to_bw_nodes at line 245-269) is populated independently of record_stack_traces. If seq_nr comes from the autograd engine's _patch_engine_run_backward hook rather than from record_stack_traces, then the flag can be safely removed. If not, a comment explaining the dependency is needed. Either way, this should be resolved — either remove it or document why it must stay.

2. _validated flag can mask bugs in non-standard usage

make_fx_tracer.py:333-342 — The _validated flag means FQN validation runs exactly once. If a TracedResult is accidentally reused with a different module after the first successful call, the mismatch will be silently ignored. This is fine for training loops (always the same module), but it could cause confusing bugs if someone stores a TracedResult and reuses it across different contexts.

Not a blocker — just worth adding a one-line comment like # Only validate once; assumes the same module is used for all subsequent calls.

3. test_mlp_forward output structure assertion could be stronger

test_trace_module.py:129: torch.equal(out_eager, wrapped) implicitly tests that wrapped is a tensor (not a list), which validates the output spec reconstruction. This is good, but a comment noting this intentional check would make the test's purpose clearer — it's testing that a single-tensor return remains a single tensor, not [tensor].

4. Minor: record_module_stack=False comment is imprecise

make_fx_tracer.py:472:

record_module_stack=False,  # don't need nn_module_stack for now

But _copy_fwd_metadata_to_bw_nodes at line 267-269 does copy nn_module_stack from forward to backward nodes. If record_module_stack=False prevents nn_module_stack from being recorded in the first place, this copy is a no-op. If the intent is that nn_module_stack comes from a different source (like _reparametrize_module), the comment should say that. Otherwise this looks contradictory.


Overall assessment

The API is clean and well-designed. The minimal_fx_tracer → callable TracedResult pattern is a clear improvement over the old trace_module + run_traced_module + manual param dict. The single-module-at-position-0 constraint is well-validated with clear error messages.

The only actionable item before landing is resolving record_stack_traces — either remove it (if seq_nr is independent) or add a comment explaining why it's required. Everything else is minor.



…h aot_function API"

[ghstack-poisoned]

…h aot_function API"

[ghstack-poisoned]

…h aot_function API"

[ghstack-poisoned]
…h aot_function API"

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> aot_function. Takes any callable (not just nn.Module)
  with nn.Module instances auto-detected in args and their params/buffers
  lifted as graph inputs. When fn is an nn.Module, it is prepended to args
  and type(fn).__call__ is used as the callable.
- Delete run_traced_module. TracedResult is now directly callable — pass the
  same positional args (with live modules) to execute the graph. Fresh params
  are read from the modules automatically on each call.
- Store and restore output pytree spec so TracedResult.__call__ returns the
  same pytree structure as the original function (e.g. single tensor, list,
  tuple, dict), not a flat list.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace
  time and validated at execute time to catch module structure mismatches.
- Add _collect_module_params helper for multi-module param extraction.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives
  (int/float/bool/str). Non-primitive values (callables, custom objects) must
  be captured in fn's closure or registered via pytree.register_constant /
  register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the
  old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain
  function factory. The model is now passed as an arg, loss_fn is captured in
  the closure.
- Remove manual params_and_buffers dict construction — TracedResult.__call__
  reads fresh params from the live module automatically.
- Add TODO for investigating loss_fn interaction with non-strict trace.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module ->
  direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests
  pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3,
gpt_oss, flex_attention_annotations).

[ghstack-poisoned]
…h aot_function API"

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> aot_function. Takes any callable (not just nn.Module)
  with nn.Module instances auto-detected in args and their params/buffers
  lifted as graph inputs. When fn is an nn.Module, it is prepended to args
  and type(fn).__call__ is used as the callable.
- Delete run_traced_module. TracedResult is now directly callable — pass the
  same positional args (with live modules) to execute the graph. Fresh params
  are read from the modules automatically on each call.
- Store and restore output pytree spec so TracedResult.__call__ returns the
  same pytree structure as the original function (e.g. single tensor, list,
  tuple, dict), not a flat list.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace
  time and validated at execute time to catch module structure mismatches.
- Add _collect_module_params helper for multi-module param extraction.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives
  (int/float/bool/str). Non-primitive values (callables, custom objects) must
  be captured in fn's closure or registered via pytree.register_constant /
  register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the
  old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain
  function factory. The model is now passed as an arg, loss_fn is captured in
  the closure.
- Remove manual params_and_buffers dict construction — TracedResult.__call__
  reads fresh params from the live module automatically.
- Add TODO for investigating loss_fn interaction with non-strict trace.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module ->
  direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests
  pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3,
gpt_oss, flex_attention_annotations).

[ghstack-poisoned]
…h aot_function API"

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> aot_function. Takes any callable (not just nn.Module)
  with nn.Module instances auto-detected in args and their params/buffers
  lifted as graph inputs. When fn is an nn.Module, it is prepended to args
  and type(fn).__call__ is used as the callable.
- Delete run_traced_module. TracedResult is now directly callable — pass the
  same positional args (with live modules) to execute the graph. Fresh params
  are read from the modules automatically on each call.
- Store and restore output pytree spec so TracedResult.__call__ returns the
  same pytree structure as the original function (e.g. single tensor, list,
  tuple, dict), not a flat list.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace
  time and validated at execute time to catch module structure mismatches.
- Add _collect_module_params helper for multi-module param extraction.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives
  (int/float/bool/str). Non-primitive values (callables, custom objects) must
  be captured in fn's closure or registered via pytree.register_constant /
  register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the
  old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain
  function factory. The model is now passed as an arg, loss_fn is captured in
  the closure.
- Remove manual params_and_buffers dict construction — TracedResult.__call__
  reads fresh params from the live module automatically.
- Add TODO for investigating loss_fn interaction with non-strict trace.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module ->
  direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests
  pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3,
gpt_oss, flex_attention_annotations).

[ghstack-poisoned]
@tugsbayasgalan
Contributor Author

@claude Give review.

@claude

claude Bot commented Apr 6, 2026

Claude finished @tugsbayasgalan's task in 2m 33s.


Review Round 4

  • Read all changed files in current state
  • Get diff against base branch
  • Check resolution of previous review items
  • Identify any new or remaining issues
  • Write comprehensive review

Previous feedback resolution

# | Feedback | Status
1 | Rename aot_function (collision with PyTorch) | Done: minimal_fx_tracer
2 | TracedResult.__init__ docstring | Done: make_fx_tracer.py:289-301
3 | module_indices_set caching | Done: eliminated with the single-module restriction
4 | Multi-module support without use case | Done: single module at args[0] only
5 | _ prefix naming | Done: make_fwd_bwd_step, make_train_step
6 | maybe_regional_inductor formatting | Done
7 | Duplicated param extraction | Done: _get_params_and_buffers helper at make_fx_tracer.py:279
8 | Per-call FQN validation overhead | Done: moved to run_traced; the check now runs on every call (no cached _validated flag), but the overhead is reasonable since run_traced is a separate function
9 | input_subclass_layouts dict iteration bug | Done: test_trace_module.py:291 uses .values()
10 | record_stack_traces=True | Not addressed: still at make_fx_tracer.py:424

Remaining issues

1. record_stack_traces=True — still unresolved (minor)

make_fx_tracer.py:424 — This was flagged in rounds 1, 2, and 3. @tugsbayasgalan agreed it could be removed but it's still True.

After investigation: seq_nr (which _copy_fwd_metadata_to_bw_nodes depends on) comes from _patch_engine_run_backward / setup_stacktrace_preservation_hooks, which are independent of record_stack_traces. The record_stack_traces flag controls whether stack_trace strings are recorded on FX nodes — these are Python stacktraces that show where each op originated.

_copy_fwd_metadata_to_bw_nodes does copy stack_trace from forward to backward nodes (make_fx_tracer.py:274-276), but this is secondary — the critical metadata is custom (for regional inductor annotations) and seq_nr (for the fwd↔bwd correlation itself). Stack traces are purely for debugging.

There's a design choice here:

  • Remove it: Saves tracing overhead. seq_nr and custom metadata still work. Backward nodes won't have stack_trace metadata (the copy in _copy_fwd_metadata_to_bw_nodes will be a no-op for that field). test_backward_nodes_have_stack_trace would need to be removed or adjusted.
  • Keep it + add comment: If stack traces on backward nodes are useful for debugging graph passes.

Not a blocker, but it should be intentional. If you keep it, add a brief comment explaining why (e.g., # needed so _copy_fwd_metadata_to_bw_nodes can propagate stack_trace to backward nodes for debugging).

2. record_module_stack=False contradicts _copy_fwd_metadata_to_bw_nodes (minor)

make_fx_tracer.py:425:

record_module_stack=False,  # don't need nn_module_stack for now

But _copy_fwd_metadata_to_bw_nodes at lines 271-273 copies nn_module_stack:

nn_module_stack = fwd_node.meta.get("nn_module_stack")
if nn_module_stack is not None:
    node.meta["nn_module_stack"] = nn_module_stack.copy()

With record_module_stack=False, nn_module_stack is never recorded on forward nodes, so this copy is always a no-op. The code is correct (the if guard handles None), but the dead code path is misleading. Either:

  • Remove the nn_module_stack copy from _copy_fwd_metadata_to_bw_nodes since it can never fire, or
  • Add a comment noting it's there for future use when record_module_stack is enabled

3. FQN validation in run_traced runs every call

make_fx_tracer.py:462-470 — Every call to run_traced walks the module tree via _get_params_and_buffers and then compares FQN lists. This is a design trade-off: run_traced is a standalone function (not a method on TracedResult), so there's no place to cache a "validated" flag.

For training loops with thousands of iterations, the repeated named_parameters() + named_buffers() tree walk adds overhead. This is acceptable for now since the module structure doesn't change during training, but worth noting as a potential optimization target — either:

  • Add a skip_validation=False kwarg for hot loops, or
  • Cache validation state on TracedResult inside run_traced

Not a blocker.
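
A minimal sketch of the second option (caching the check on TracedResult); attribute names such as param_fqns and _fqns_validated are assumptions about the data class, not the PR's actual fields:

import torch.nn as nn

def _module_fqns(module: nn.Module) -> list[str]:
    # This full walk of named parameters/buffers is the per-call cost being discussed.
    return [n for n, _ in module.named_parameters()] + [n for n, _ in module.named_buffers()]

def validate_module_once(traced_result, module: nn.Module) -> None:
    # Run the FQN comparison only on the first call for this TracedResult;
    # assumes the module structure is fixed for the lifetime of the trace.
    if getattr(traced_result, "_fqns_validated", False):
        return
    if _module_fqns(module) != list(traced_result.param_fqns):
        raise ValueError("module structure does not match the module seen at trace time")
    traced_result._fqns_validated = True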

4. test_backward_nodes_have_stack_trace hardcodes num_checked == 24

test_trace_module.py:437:

self.assertEqual(num_checked, 24)

This is fragile — any change to SimpleMLP architecture, PyTorch's autograd decomposition, or the loss function would silently break this assertion. Consider using self.assertGreater(num_checked, 0) instead, unless the exact count 24 is part of the contract being tested.


Design assessment

The API is clean and well-structured. The overall flow is easy to follow:

# Trace
traced = minimal_fx_tracer(train_step)(model, tokens, labels)
# Execute
result = run_traced(traced, model, tokens, labels)

This is a significant improvement over the old pattern:

# Old: awkward module wrapping + manual param dict
train_step = TrainStepModule(model, loss_fn)
traced = trace_module(train_step, (tokens, labels))
params = _get_params_and_buffers(train_step)
result = run_traced_module(traced, params, (tokens, labels))
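
To make "plain function factory" concrete, here is a hedged sketch of what make_train_step could look like; the body is illustrative, not the PR's code. loss_fn lives in the closure, and the model is an ordinary positional argument:

import torch
import torch.nn as nn

def make_train_step(loss_fn):
    # loss_fn is captured in the closure, so it never appears as a traced input
    # and does not have to be a tensor/primitive pytree leaf.
    def train_step(model: nn.Module, tokens: torch.Tensor, labels: torch.Tensor):
        logits = model(tokens)
        loss = loss_fn(logits, labels)
        loss.backward()  # forward and backward are both captured by the tracer
        return loss

    return train_step

Under this shape, the new API above reads as traced = minimal_fx_tracer(make_train_step(loss_fn))(model, tokens, labels) followed by run_traced(traced, model, tokens, labels).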

Key design wins:

  • No more nn.Module wrapper boilerplate — make_fwd_bwd_step is a plain function factory
  • Symmetric API — trace args and execute args have the same shape
  • TracedResult as a data class — clean separation between tracing metadata and execution
  • run_traced as a standalone function — execution logic is explicit and hackable
  • Single module at args[0] — simple, clear constraint with good error messages
  • Leaf validation — catches common mistakes (passing callables as args instead of closures)
  • New TracingContext installation — make_fx_tracer.py:419-421 installs TracingContext before make_fx so that invoke_subgraph deduplication works

The _unwrap_subclasses / _wrap_subclasses refactor from list-based to dict-based layouts is also an improvement — only subclass positions get entries, making the common case (plain tensors) cleaner.
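
A hedged illustration of the difference; Layout and wrap are stand-ins, not the PR's actual types:

from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class Layout:
    wrap: Callable[[Any], Any]  # rebuilds the tensor subclass from a plain tensor

# List-based: every flat position gets a slot, mostly None for plain tensors.
#   layouts = [None, None, Layout(...), None]
# Dict-based: only positions that held subclasses appear at all.
#   layouts = {2: Layout(...)}

def wrap_to_subclasses(flat_outputs: list, layouts: dict[int, Layout]) -> list:
    # Plain tensors pass through untouched; only recorded positions are rewrapped.
    return [layouts[i].wrap(x) if i in layouts else x for i, x in enumerate(flat_outputs)]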


Summary

The PR is in good shape. All major feedback from @SherlockNoMad has been addressed. The remaining items are:

  1. record_stack_traces=True: Either remove it (if stack traces on backward nodes aren't needed) or add a comment explaining why it's kept. This has been open for 3 review rounds.
  2. Dead nn_module_stack copy: Minor inconsistency with record_module_stack=False. Add a comment or remove the dead path.
  3. Hardcoded num_checked == 24: Fragile test assertion.
  4. Per-call FQN validation: Acceptable overhead for now, optimization target for later.

None of these are blockers.



…h aot_function API"


Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> minimal_fx_tracer. It is now designed only for tracing the training loop.
- run_traced_module is now slightly more ergonomic to call.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace
  time and validated at execute time to catch module structure mismatches.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives
  (int/float/bool/str). Non-primitive values (callables, custom objects) must
  be captured in fn's closure or registered via pytree.register_constant /
  register_pytree_node.
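
The last bullet points to pytree registration as the escape hatch for non-primitive leaves. A hedged sketch of that route; the Coord class is purely illustrative and not part of the PR:

from dataclasses import dataclass
import torch.utils._pytree as pytree

@dataclass
class Coord:
    x: int
    y: int

# After registration, Coord instances in args flatten into leaves the tracer
# accepts instead of tripping the non-tensor-leaf ValueError.
# (pytree.register_constant is the alternative named above for values that
# should be treated as opaque constants rather than flattened.)
pytree.register_pytree_node(
    Coord,
    lambda c: ([c.x, c.y], None),            # flatten: children, context
    lambda children, ctx: Coord(*children),  # unflatten
)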

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the
  old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain
  function factory. The model is now passed as an arg, loss_fn is captured in
  the closure.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module ->
  direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests
  pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3,
gpt_oss, flex_attention_annotations).

[ghstack-poisoned]
@tugsbayasgalan tugsbayasgalan changed the title [graph_trainer] Replace trace_module/run_traced_module with aot_function API [graph_trainer] Replace trace_module/run_traced_module with minimal_fx_trace API Apr 6, 2026
@SherlockNoMad
Contributor

Nice, this is moving in the right direction.
What about not relying on the positional argument to identify stateful inputs, e.g. nn.Module?
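
One way to read this suggestion (purely illustrative, not code from the PR): scan the args for nn.Module instances rather than requiring the stateful input at a fixed position.

import torch.nn as nn

def find_stateful_inputs(args: tuple) -> dict[int, nn.Module]:
    # Map arg position -> module for every top-level arg that carries state.
    return {i: a for i, a in enumerate(args) if isinstance(a, nn.Module)}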

…h minimal_fx_trace API"


Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> minimal_fx_tracer. It is now designed only for tracing the training loop.
- run_traced_module is now slightly more ergonomic to call.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace
  time and validated at execute time to catch module structure mismatches.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives
  (int/float/bool/str). Non-primitive values (callables, custom objects) must
  be captured in fn's closure or registered via pytree.register_constant /
  register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the
  old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain
  function factory. The model is now passed as an arg, loss_fn is captured in
  the closure.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module ->
  direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests
  pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3,
gpt_oss, flex_attention_annotations).

[ghstack-poisoned]
flat_outputs = traced_result.gm(*all_args)
return _wrap_to_subclasses(flat_outputs, traced_result.output_subclass_layouts)
# TODO: Consider stronger state validation once the long-term state API settles.
state = extract_module_state(module)
Contributor


this is in the perf critical path, consider caching extracted state, say in TracedResult?
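
One possible shape of this suggestion; the attribute names are hypothetical, and extract_module_state is the PR's helper shown in the diff context above:

def get_module_state_cached(traced_result, module, refresh: bool = False):
    # Extract once per TracedResult and reuse on later calls. This assumes the
    # returned state holds references to the live parameter tensors, so in-place
    # optimizer updates stay visible; if it copied, the cache would go stale.
    if refresh or getattr(traced_result, "_cached_state", None) is None:
        traced_result._cached_state = extract_module_state(module)
    return traced_result._cached_state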


Labels

ciflow/8gpu, CLA Signed

3 participants