[graph_trainer] Replace trace_module/run_traced_module with minimal_fx_trace API #2753
tugsbayasgalan wants to merge 18 commits into gh/tugsbayasgalan/8/base
Conversation
…ion API

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable.
- Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call.
- Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches.
- Add _collect_module_params helper for multi-module param extraction.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure.
- Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically.
- Add TODO for investigating loss_fn interaction with non-strict trace.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations).

[ghstack-poisoned]
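For readers skimming the thread, here is a self-contained sketch (not the PR's code) of the mechanism the commit message describes, written with stock PyTorch APIs: params/buffers become explicit graph inputs via make_fx plus functional_call, the joint fwd+loss+bwd is traced once, and the graph is re-run later with the live module's parameters. The PR's aot_function/TracedResult automate this kind of lifting and re-binding.

```python
# Illustrative only -- stock PyTorch, not the PR's implementation.
import torch
import torch.nn as nn
from torch.fx.experimental.proxy_tensor import make_fx

model = nn.Linear(4, 4)
loss_fn = nn.MSELoss()

def fwd_bwd(weight, bias, inputs, labels):
    # Run the module with explicitly passed params so they become graph inputs.
    pred = torch.func.functional_call(model, {"weight": weight, "bias": bias}, (inputs,))
    loss = loss_fn(pred, labels)
    grads = torch.autograd.grad(loss, (weight, bias))
    return [loss, *grads]

inputs, labels = torch.randn(2, 4), torch.randn(2, 4)

# Trace fwd + loss + bwd as a single graph.
gm = make_fx(fwd_bwd)(model.weight, model.bias, inputs, labels)

# Re-run the traced graph with fresh (live) params, as TracedResult does per call.
loss, grad_w, grad_b = gm(model.weight, model.bias, inputs, labels)
```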
class FwdBwdStepModule(nn.Module):
    """Wraps model + loss_fn + autograd.grad into a single traceable forward.
def _make_fwd_bwd_step(loss_fn):
Remove the _ from the beginning. Claude does that. Please make a pass over the other files for similar nits as well.
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…ion API Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). ghstack-source-id: 3a07e8e Pull Request resolved: #2753
params_and_buffers = {**named_parameters, **named_buffers}
params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
params_len = len(params_and_buffers_flat)
def __init__(
Needs a docstring: what are the params?
def __call__(self, *args: Any) -> list[torch.Tensor]:
    """Execute the traced graph, reading fresh params from modules in ``args``."""
    module_indices_set = set(self.module_indices)
what is module_indices_set?
it's computed for every call, should it be cached?
Deleted this option.
unwrapped_args = []
input_layouts: list[SubclassLayout] = []
def aot_function(
Do we already have an aot_function in pytorch/pytorch?
Compare and contrast with PyTorch's version, and mention what this one DOESN'T do, e.g. no functionalization.
Need a better name here tbh... Do you have any ideas? What about minimal_fx_tracer?
Renamed to minimal_make_fx
input_subclass_layouts: list[SubclassLayout],
output_subclass_layouts: list[SubclassLayout],
In the case without DTensor, the inputs/outputs are just plain tensors.
In that case, what are the values for these two args?
SubclassLayout(num_tensors=1, meta=None), but yeah i think we can simplify for plain tensors.
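For concreteness, a hypothetical sketch of what SubclassLayout could look like, based only on the reply above; num_tensors and meta are the fields mentioned in the thread, and everything else here is an assumption:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class SubclassLayout:
    num_tensors: int      # how many inner tensors the subclass flattens into
    meta: Optional[Any]   # subclass metadata (e.g. DTensor placements); None for plain tensors

# A plain torch.Tensor input/output would then be described as:
plain_layout = SubclassLayout(num_tensors=1, meta=None)
```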
annotate_flex_attention_for_regional_inductor()
if use_regional_inductor
else contextlib.nullcontext()
maybe_regional_inductor = ....
with maybe_regional_inductor:
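Sketch of the suggested restructuring (the flag and helper below are placeholder stubs, not the PR's definitions):

```python
import contextlib

use_regional_inductor = False  # placeholder; the real flag comes from the trainer config

def annotate_flex_attention_for_regional_inductor():
    # placeholder stub standing in for the PR's context manager
    return contextlib.nullcontext()

maybe_regional_inductor = (
    annotate_flex_attention_for_regional_inductor()
    if use_regional_inductor
    else contextlib.nullcontext()
)
with maybe_regional_inductor:
    pass  # tracing would happen under this context
```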
rebuilt[i] = user_list[user_idx]
user_idx += 1
with contextlib.ExitStack() as stack:
what's the stack for? can we live without it?
This was originally for supporting multiple models as input, can be deleted if we only want to support one model.
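For context, a minimal sketch of the ExitStack pattern under discussion, assuming (per the reply above) one context to enter per nn.Module argument; the context objects here are placeholders:

```python
import contextlib

# Placeholder contexts; in the tracer these would be per-module setup contexts.
per_module_contexts = [contextlib.nullcontext(), contextlib.nullcontext()]

with contextlib.ExitStack() as stack:
    for ctx in per_module_contexts:
        stack.enter_context(ctx)
    # ... tracing would run here with all module contexts active ...
    pass
```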
self.output_subclass_layouts = output_subclass_layouts
self.output_spec = output_spec
def __call__(self, *args: Any) -> list[torch.Tensor]:
what's the rationale for making it a callable?
Not saying it's a bad idea, just looking for design trade-offs.
Oh, this way we can automatically fetch the param/buffer FQNs recorded at tracing time and keep the same calling convention as the user callable. With the old design it was a bit awkward.
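Roughly, the rationale above amounts to something like the following stripped-down sketch (illustrative only; the class, field names, and validation here are assumptions, not the PR's actual TracedResult):

```python
import torch.nn as nn

class CallableTracedResult:
    """Callable wrapper: same positional calling convention as the traced fn."""

    def __init__(self, gm, param_fqns, module_index):
        self.gm = gm                      # traced fx.GraphModule
        self.param_fqns = param_fqns      # FQNs recorded at trace time
        self.module_index = module_index  # which positional arg is the nn.Module

    def __call__(self, *args):
        module = args[self.module_index]
        state = {**dict(module.named_parameters()), **dict(module.named_buffers())}
        # FQN check catches passing a structurally different module at run time.
        missing = [fqn for fqn in self.param_fqns if fqn not in state]
        if missing:
            raise ValueError(f"module structure does not match trace time, missing {missing}")
        flat_params = [state[fqn] for fqn in self.param_fqns]
        other_args = [a for i, a in enumerate(args) if i != self.module_index]
        return self.gm(*flat_params, *other_args)
```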
def make_fwd_bwd_step(loss_fn):
    """Return a plain function that traces the entire fwd+loss+bwd step.

This allows make_fx to trace through the entire fwd+loss+bwd as one graph.
"""
``loss_fn`` is captured in the closure so it is not a graph input.

def __init__(self, model, loss_fn):
    super().__init__()
    self.model = model
    self.loss_fn = loss_fn

TODO: investigate how loss_fn interacts with non-strict trace. Currently
it is captured as a closure variable, but non-strict tracing may need it
registered via pytree.register_constant or passed differently.
"""

def forward(self, inputs, labels, global_valid_tokens, extra_inputs, extra_kwargs):
    pred = self.model(inputs, **extra_inputs, **extra_kwargs)
    loss = self.loss_fn(pred, labels) / global_valid_tokens
    params = [p for p in self.model.parameters() if p.requires_grad]

def fwd_bwd_step(
    model, inputs, labels, global_valid_tokens, extra_inputs, extra_kwargs
):
    pred = model(inputs, **extra_inputs, **extra_kwargs)
    loss = loss_fn(pred, labels) / global_valid_tokens
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params)
    return [loss] + list(grads)

return fwd_bwd_step
This smells off.
A more natural design is to have loss_fn as a part of the model.
Do you think it is ok to have loss_fn as input to the function we are tracing?
Discussed offline. We will go with this pattern.
@claude help me review. This is an API design PR; look for opportunities to simplify the design and keep everything minimal. If this becomes too complicated for a frontier user to hack on, it's bad.
Claude finished @SherlockNoMad's task in 2m 34s. Review: API Design for
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
|
@claude Give a review on this PR. Did I address all of @SherlockNoMad's feedback? Is there further opportunity to simplify the design?
Claude finished @tugsbayasgalan's task in 2m 28s. Review: Feedback resolution and remaining simplification opportunities
Feedback Resolution Status
Tracking each piece of @SherlockNoMad's feedback:
Bug:
@claude Give me another round of review.
Claude finished @tugsbayasgalan's task in 2m 51s. Review Round 3
The code is significantly cleaner than the previous rounds. Most of @SherlockNoMad's feedback has been addressed. Here's where things stand.
Previous feedback: resolution status
Remaining issues
1.
The question is whether
2.
Not a blocker — just worth adding a one-line comment like
3.
4. Minor:
record_module_stack=False,  # don't need nn_module_stack for now
But
Overall assessment
The API is clean and well-designed. The
The only actionable item before landing is resolving
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> aot_function. Takes any callable (not just nn.Module) with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable. - Delete run_traced_module. TracedResult is now directly callable — pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call. - Store and restore output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Add _collect_module_params helper for multi-module param extraction. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. - Remove manual params_and_buffers dict construction — TracedResult.__call__ reads fresh params from the live module automatically. - Add TODO for investigating loss_fn interaction with non-strict trace. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
|
@claude Give review.
Claude finished @tugsbayasgalan's task in 2m 33s. Review Round 4
Previous feedback resolution
Remaining issues
1.
After investigation:
There's a design choice here:
Not a blocker, but it should be intentional. If you keep it, add a brief comment explaining why (e.g.,
2.
record_module_stack=False,  # don't need nn_module_stack for now
But
nn_module_stack = fwd_node.meta.get("nn_module_stack")
if nn_module_stack is not None:
    node.meta["nn_module_stack"] = nn_module_stack.copy()
With
3. FQN validation in
For training loops with thousands of iterations, the repeated
Not a blocker.
4.
self.assertEqual(num_checked, 24)
This is fragile — any change to
Design assessment
The API is clean and well-structured. The overall flow is easy to follow:
# Trace
traced = minimal_fx_tracer(train_step)(model, tokens, labels)
# Execute
result = run_traced(traced, model, tokens, labels)
This is a significant improvement over the old pattern:
# Old: awkward module wrapping + manual param dict
train_step = TrainStepModule(model, loss_fn)
traced = trace_module(train_step, (tokens, labels))
params = _get_params_and_buffers(train_step)
result = run_traced_module(traced, params, (tokens, labels))
Key design wins:
The
Summary
The PR is in good shape. All major feedback from @SherlockNoMad has been addressed. The remaining items are:
None of these are blockers.
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> minimal_fx_tracer. Only designed to work over training loop now. - run_traced_module now runs slightly more ergonomically. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h aot_function API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> minimal_fx_tracer. Only designed to work over training loop now. - run_traced_module now runs slightly more ergonomically. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
|
Nice, this is moving in the right direction.
…h minimal_fx_trace API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> minimal_fx_tracer. Only designed to work over training loop now. - run_traced_module now runs slightly more ergonomically. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h minimal_fx_trace API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> minimal_fx_tracer. Only designed to work over training loop now. - run_traced_module now runs slightly more ergonomically. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
…h minimal_fx_trace API" Authored-by: Claude Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes: make_fx_tracer.py: - Rename trace_module -> minimal_fx_tracer. Only designed to work over training loop now. - run_traced_module now runs slightly more ergonomically. - Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches. - Install TracingContext before make_fx so invoke_subgraph deduplication works. - Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node. trainer.py: - Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure. test_trace_module.py: - Replace TrainStepModule with _make_train_step plain function factory. - Remove _get_params_and_buffers helper (no longer needed). - Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__. - Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation. - Add test_module_not_first_arg: module at position 1 in args. - Add test_multiple_modules: two nn.Modules interleaved with a tensor. - Add test_mismatched_module_raises: FQN validation catches wrong module. - Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError. All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations). [ghstack-poisoned]
flat_outputs = traced_result.gm(*all_args)
return _wrap_to_subclasses(flat_outputs, traced_result.output_subclass_layouts)
# TODO: Consider stronger state validation once the long-term state API settles.
state = extract_module_state(module)
This is in the perf-critical path; consider caching the extracted state, say in TracedResult?
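One possible shape for that caching (purely illustrative; extract_module_state is the helper from the diff above, reproduced here as a stand-in, and the id-keyed cache is an assumption, not code from the PR):

```python
import torch.nn as nn

def extract_module_state(module: nn.Module) -> dict:
    # Stand-in for the PR's helper: flatten params + buffers by FQN.
    return {**dict(module.named_parameters()), **dict(module.named_buffers())}

_state_cache: dict[int, dict] = {}

def cached_module_state(module: nn.Module) -> dict:
    # Re-extract only for a module object we have not seen before. Parameters are
    # held by reference, so in-place optimizer updates still show up at run time;
    # swapping parameter objects would require invalidating this cache.
    key = id(module)
    if key not in _state_cache:
        _state_cache[key] = extract_module_state(module)
    return _state_cache[key]
```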
Stack from ghstack (oldest at bottom):

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> minimal_fx_tracer. Only designed to work over the training loop now.
- run_traced_module now runs slightly more ergonomically.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg, loss_fn is captured in the closure.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations).