[graph_trainer] Add torch.no_grad() and graph-based SAC to traced execution #2766
tugsbayasgalan merged 90 commits into main
Conversation
…ion API

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc. Key changes:

make_fx_tracer.py:
- Rename trace_module -> aot_function. Takes any callable (not just nn.Module), with nn.Module instances auto-detected in args and their params/buffers lifted as graph inputs. When fn is an nn.Module, it is prepended to args and type(fn).__call__ is used as the callable.
- Delete run_traced_module. TracedResult is now directly callable: pass the same positional args (with live modules) to execute the graph. Fresh params are read from the modules automatically on each call.
- Store and restore the output pytree spec so TracedResult.__call__ returns the same pytree structure as the original function (e.g. single tensor, list, tuple, dict), not a flat list.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace time and validated at execute time to catch module structure mismatches.
- Add _collect_module_params helper for multi-module param extraction.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives (int/float/bool/str). Non-primitive values (callables, custom objects) must be captured in fn's closure or registered via pytree.register_constant / register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (an nn.Module wrapper that existed only because the old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain function factory. The model is now passed as an arg; loss_fn is captured in the closure.
- Remove manual params_and_buffers dict construction; TracedResult.__call__ reads fresh params from the live module automatically.
- Add TODO for investigating loss_fn interaction with non-strict trace.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step, a plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module -> direct TracedResult.__call__.
- Register BlockMask as a pytree node at module level so flex_attention tests pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches the wrong module.
- Add test_non_tensor_leaf_raises: a callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3, gpt_oss, flex_attention_annotations).

[ghstack-poisoned]
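The leaf-validation rule above (tensors and primitives only) can be sketched in plain Python. This is a simplified stand-in, not the actual make_fx_tracer code: the real implementation flattens args with torch.utils._pytree, and the names `_iter_leaves` / `validate_arg_leaves` are invented for illustration.

```python
# Simplified sketch of the arg-leaf validation: every leaf of the nested
# args structure must be a tensor or a primitive; anything else (callables,
# custom objects) must be captured in fn's closure or registered as a
# pytree node. Plain containers stand in for the real pytree traversal.

_PRIMITIVES = (int, float, bool, str)

def _iter_leaves(obj):
    """Yield the leaves of a nested list/tuple/dict structure."""
    if isinstance(obj, (list, tuple)):
        for item in obj:
            yield from _iter_leaves(item)
    elif isinstance(obj, dict):
        for item in obj.values():
            yield from _iter_leaves(item)
    else:
        yield obj

def validate_arg_leaves(args, is_tensor=lambda leaf: False):
    """Raise ValueError on any leaf that is neither a tensor nor a primitive."""
    for leaf in _iter_leaves(args):
        if leaf is None or isinstance(leaf, _PRIMITIVES) or is_tensor(leaf):
            continue
        raise ValueError(
            f"Unsupported leaf of type {type(leaf).__name__}: capture it in "
            "the function's closure or register it as a pytree node."
        )
```

With this sketch, `validate_arg_leaves((1, {"lr": 0.1}))` passes, while a callable leaf such as `validate_arg_leaves((lambda x: x,))` raises, mirroring test_non_tensor_leaf_raises.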
…al_fx_tracer
- Annotate backward FX nodes with {"remat_pass_tag": "is_backward"} during
_patch_engine_run_backward so remat_using_tags_for_fwd_loss_bwd_graph can
identify the forward/backward boundary.
- Apply remat_using_tags_for_fwd_loss_bwd_graph as a default post-trace pass.
Nodes tagged PREFER_RECOMPUTE (from selective AC) are duplicated before
backward and the forward copies are DCE'd, reducing peak memory.
- Execute traced graph under torch.no_grad() since the graph already contains
explicit backward ops. Without this, PyTorch builds a redundant autograd
graph keeping all forward intermediates alive via grad_fn references.
- Add test_llama_1b_peak_memory: verifies traced+AC peak memory is within 20%
of eager+AC on Llama 1B (BS=2, seq=2048, bf16).
[ghstack-poisoned]
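The tag-driven rematerialization described above can be illustrated on a toy flat node list. This is only a sketch of the idea behind remat_using_tags_for_fwd_loss_bwd_graph, not the actual FX pass: the node representation (dicts with `tags` and `uses`) and the function name are invented, and it assumes exactly one contiguous backward region.

```python
# Toy remat-by-tags: nodes tagged PREFER_RECOMPUTE are cloned just before
# the backward region, backward users are redirected to the clones, and a
# trivial DCE drops forward copies that nothing still reads.

PREFER_RECOMPUTE, IS_BACKWARD = "prefer_recompute", "is_backward"

def remat_by_tags(nodes):
    """nodes: list of dicts {'name': str, 'tags': set, 'uses': set of names}."""
    bwd_start = next(i for i, n in enumerate(nodes) if IS_BACKWARD in n["tags"])
    fwd, bwd = nodes[:bwd_start], nodes[bwd_start:]

    recomputed = []
    for n in fwd:
        if PREFER_RECOMPUTE in n["tags"]:
            clone = dict(n, name=n["name"] + "_recomputed")
            # Redirect backward users to the recomputed copy.
            for b in bwd:
                if n["name"] in b["uses"]:
                    b["uses"] = (b["uses"] - {n["name"]}) | {clone["name"]}
            recomputed.append(clone)

    # Trivial DCE: drop recompute-tagged forward nodes nothing reads anymore.
    live = set()
    for n in fwd + recomputed + bwd:
        live |= n["uses"]
    fwd = [n for n in fwd if PREFER_RECOMPUTE not in n["tags"] or n["name"] in live]
    return fwd + recomputed + bwd
```

If an activation's only remaining user was the backward region, its forward copy disappears and only the recomputed copy (placed right before backward) survives, which is the peak-memory win.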
# Rematerialize activations tagged PREFER_RECOMPUTE by selective AC.
# Duplicates recomputable forward ops before the backward region and
# DCEs the original copies, reducing peak memory.
traced = remat_using_tags_for_fwd_loss_bwd_graph(traced)
No, the user should apply this themselves if they want it.
Keep the tracing simple.
Isn't it kind of weird that the user annotates nodes with AC and then doesn't see any improvement in the end? It feels like silent incorrectness to me.
for node in traced.graph.nodes:
    ac_id = node.meta.get("custom", {}).get("ac_graph_id")
    if ac_id is not None:
        node.meta["ac_graph_id"] = ac_id
Again, remove this; it should be a post-processing pass under the user's control.
Claude finished @tugsbayasgalan's task in 2m 1s. CI Failure Analysis:

There are 4 failed jobs out of 14. Lint and CPU Unit Test both passed; 8 jobs are still queued/in-progress.

Failure 1: GraphTrainer 8 GPU Integration Tests (both runs)
Location: with torch.compiler._patch_autograd_grad():
All 8 ranks crash with the same error. This is caused by this PR's changes.

Failure 2: GraphTrainer 8 GPU H100 Integration Tests
Same root cause as Failure 1. This is caused by this PR's changes.

Failure 3: 8 GPU Feature Tests (NOT related to this PR)
3 sub-tests failed, all due to pre-existing infrastructure issues. These are NOT caused by this PR; they are Inductor/LLVM/Triton bugs.

Summary: a single root cause for all PR-related failures.
…aph-based SAC to traced execution"

Authored-by: Claude

Execute the traced graph under torch.no_grad() since it already contains explicit backward ops (from torch.autograd.grad traced by make_fx). Without this, PyTorch builds a redundant autograd graph keeping all forward intermediates alive via grad_fn references.

Adds a SAC option on aot_fx_trace on GraphTrainer. Adds an option to dump the peak memory from the CUDA caching allocator. Adds a test similar to test_bitwise_equivalency.py for peak memory.

Here is the run comparison between eager and aot_fx_trace with SAC:
<img width="1114" height="455" alt="Screenshot 2026-04-14 at 1 18 41 PM" src="https://github.com/user-attachments/assets/bf4ffdd3-8783-47be-ac64-32587787591f" />

Llama3 8B with dp=4 tp=2 on a 10-step run before SAC:
<img width="1728" height="585" alt="Screenshot 2026-04-14 at 10 39 41 PM" src="https://github.com/user-attachments/assets/d8a1fe31-bd91-41a4-8e6f-a0e738cc2b76" />

Llama3 8B with dp=4 tp=2 on a 10-step run after SAC:
<img width="1728" height="552" alt="Screenshot 2026-04-14 at 10 39 26 PM" src="https://github.com/user-attachments/assets/273b6e65-c2a3-473b-a2db-53b67a95be3a" />

For Llama3 8B, from the step-30 rank0 memory snapshots:
- No SAC:
  - end-of-step memory_allocated() = 14.99 GiB
  - end-of-step memory_reserved() = 22.75 GiB
- SAC:
  - end-of-step memory_allocated() = 14.99 GiB
  - end-of-step memory_reserved() = 16.6 GiB

Deepseek3 16B before SAC:
<img width="1728" height="600" alt="Screenshot 2026-04-14 at 10 54 31 PM" src="https://github.com/user-attachments/assets/61180fa4-debd-4c10-834e-e4936cfbd67e" />

Deepseek3 16B after SAC:
<img width="1728" height="578" alt="Screenshot 2026-04-14 at 10 58 30 PM" src="https://github.com/user-attachments/assets/3e3b4e71-e13f-43b7-9404-ad9a0a04f75c" />

[ghstack-poisoned]
…cution

Execute the traced graph under torch.no_grad() since it already contains explicit backward ops (from torch.autograd.grad traced by make_fx). Without this, PyTorch builds a redundant autograd graph keeping all forward intermediates alive via grad_fn references.

Replace the monkey-patched _CachingTorchDispatchMode AC approach with clean graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass (which now skips backward-tagged nodes) for post-hoc tagging, then remat_using_tags_for_fwd_loss_bwd_graph for the remat transform. apply_ac_remat_pass now takes a GraphModule and returns a GraphModule, following the standard pass signature convention.

Results on Llama 1B (H100): traced-graph SAC uses 18.23 GB vs eager SAC 19.60 GB (0.93x ratio), with bitwise-identical losses and gradients.

ghstack-source-id: 157213f
Pull Request resolved: #2766
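The no_grad rationale in this commit can be shown with a small standalone PyTorch snippet, independent of the trainer code: touching a requires_grad tensor outside no_grad() records an autograd graph whose grad_fn keeps forward intermediates alive, while under no_grad() nothing is recorded, which is safe once backward ops are already explicit nodes in the traced graph.

```python
# Outside no_grad(), ops on requires_grad tensors build an autograd graph
# (the result carries a grad_fn, pinning intermediates). Inside no_grad(),
# the same computation records nothing.
import torch

w = torch.randn(8, 8, requires_grad=True)
x = torch.randn(8)

y_tracked = (w @ x).sum()        # autograd graph is recorded
assert y_tracked.grad_fn is not None

with torch.no_grad():
    y_untracked = (w @ x).sum()  # no graph, no retained intermediates
assert y_untracked.grad_fn is None
```

This is exactly the "redundant autograd graph" the commit avoids: the traced GraphModule computes gradients via its own explicit nodes, so the eager autograd bookkeeping is pure overhead.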
@claude what are the failures?
Claude finished @tugsbayasgalan's task in 3m 39s. CI Failure Analysis

There are 2 completed failures out of 14 jobs (9 still queued/in-progress). Lint, CPU Unit Test, and 8 GPU Feature Tests all passed.

Failures 1 & 2: GraphTrainer 8 GPU Integration Tests (both runs, same errors). Both runs fail on the same 2 test flavors with different errors:
| Job | Status | This PR? | Error |
|---|---|---|---|
| Lint | Passed | — | — |
| CPU Unit Test | Passed | — | — |
| 8 GPU Feature Tests | Passed | — | — |
| GraphTrainer 8 GPU (run 1) | Failed | Yes | RNG in recompute + disjoint backward |
| GraphTrainer 8 GPU (run 2) | Failed | Yes | Same |
| GraphTrainer 8 GPU H100 | Queued | — | — |
| 8 other jobs | Queued | — | — |
Note: The previous torch.compiler._patch_autograd_grad error is gone; the CI PyTorch nightly now includes that API. The new failures are both in the SAC/remat pass:

- FSDP+TP+cudagraph: `apply_sac_pass` tags RNG ops (dropout) for recomputation, but `remat_using_tags_for_fwd_loss_bwd_graph` doesn't support RNG ops in recompute regions. Fix: either tag RNG ops as `MUST_SAVE` instead of `PREFER_RECOMPUTE`, or exclude RNG ops from recomputation tagging in `apply_sac_pass`.
- FSDP+TP+FlexAttn: The backward phase annotation doesn't capture all backward nodes when FlexAttention's nested backward is involved, causing the remat pass to see 7 disjoint backward regions. Fix: the backward annotation needs to handle flex_attention's nested `autograd.grad` calls correctly.
@claude I am pretty sure there were no RNG ops to be saved a few days ago, is this a recent change?
Claude finished @tugsbayasgalan's task in 13m 2s. Analysis: RNG Ops in Recompute Regions

Short answer: this is almost certainly a recent PyTorch nightly change, not something you introduced. The Llama3 debug model has no dropout or explicit RNG ops in its forward pass. Here's the detailed analysis.

What the CI shows: the SAC pass itself runs successfully; the logs show it tagging nodes with `MUST_SAVE`/`PREFER_RECOMPUTE` across all AC regions. The error comes after tagging.

The Llama3 debug model has no RNG ops: I checked the entire model path.

What likely changed:
@claude What is the test command to reproduce these errors?
Claude finished @tugsbayasgalan's task in 3m 0s. Test Commands to Reproduce CI Failures

The two failing test flavors from CI are run via the integration test runner.

Full suite command (what CI runs):

```shell
# This runs ALL graph_trainer_default tests (including the two that fail)
python -m torchtitan.experiments.graph_trainer.tests.integration_tests \
    --test_suite graph_trainer_default \
    --gpu_arch_type cuda \
    /tmp/test_output \
    --ngpu 8
```

Individual failing tests

Failure 1:

```shell
python -m torchtitan.experiments.graph_trainer.tests.integration_tests \
    --test_suite graph_trainer_default \
    --test_name aot_fx_trace_llama3_fsdp_tp \
    --gpu_arch_type cuda \
    /tmp/test_output \
    --ngpu 8
```

Which under the hood runs

Failure 2:

```shell
python -m torchtitan.experiments.graph_trainer.tests.integration_tests \
    --test_suite graph_trainer_default \
    --test_name aot_fx_trace_llama3_fsdp_tp_flexattn \
    --gpu_arch_type cuda \
    /tmp/test_output \
    --ngpu 8
```

Which runs

Equivalent direct
```python
ctx = TracingContext(fake_mode)
# preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes
with fake_mode, tracing(ctx), preserve_node_meta(), _skip_nested_compile():
    # _skip_nested_compile lets this tracer run when an outer dynamo trace
```

not "outer dynamo trace", it's the "current make_fx trace"?
```python
passes: list[Callable] = [
    functools.partial(tlparse_log_graph_pass, graph_name="make_fx_graph_traced"),
    apply_ac_on_fwd_bwd_graph,
```

apply after the graph cleanup pass
```python
def apply_ac_on_fwd_bwd_graph(
```

please ask claude to come up with a more tasteful name; also follow the naming convention of ending with *_pass
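For reference, the pass signature convention the comment refers to can be shown with a trivial example: a `*_pass` takes a `GraphModule` and returns a `GraphModule`. The body here is a placeholder, not the real remat logic.

```python
import torch
from torch import fx

def apply_ac_remat_pass(gm: fx.GraphModule) -> fx.GraphModule:
    # ... a real pass would rewrite nodes tagged for recompute here ...
    gm.graph.lint()   # sanity-check the graph after any mutation
    gm.recompile()    # regenerate forward() from the (possibly edited) graph
    return gm

def f(x):
    return x * 2 + 1

gm = apply_ac_remat_pass(fx.symbolic_trace(f))
```

This shape lets passes compose in a plain list and be applied in a loop, as the `passes: list[Callable]` snippet above does.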
```python
    The model must have been annotated with annotate_ac_regions before
    tracing so that nodes have custom["ac_region_id"] metadata.
    """
    from torch._functorch._activation_checkpointing.remat_using_tags_for_fwd_loss_bwd_graph_pass import (
```

Follow-up in a future diff: I also think it's better to have remat_using_tags_for_fwd_loss_bwd_graph_pass in titan, since no one is using it in core and its applicable scenario is niche.

we do use it in torch.compile(fullgraph=True) today.
```python
with fake_mode, tracing(ctx), preserve_node_meta(), _skip_nested_compile():
    # _skip_nested_compile lets this tracer run when an outer dynamo trace
    # reaches torch.compile'd FlexAttention kernels.
    # _non_strict_tracing_context is required by _patch_autograd_grad() and
```

I am not a fan of needing _non_strict_tracing_context coupled with _patch_autograd_grad. What breaks if we don't have this context? Is there another way to fix this?

We only want to enable patching autograd in cases where we know it is carefully coupled with the rest of the stack, so we added a check in patch_autograd that only succeeds if this specific context, _non_strict_tracing_context, is on. For example, we don't want this happening in non-strict export.
```diff
-def run_tests(args, test_list: list[OverrideDefinitions], module=None, config=None):
+def run_tests(
```

I don't see a meaningful change in this file
Stack from ghstack (oldest at bottom):
Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.
Adds SAC option on aot_fx_trace on GraphTrainer.
Adds option to dump the peak memory from the CUDA caching allocator.
Adds a test similar to test_bitwise_equivalency.py for peak memory.
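A peak-memory dump from the CUDA caching allocator can be sketched as below; the option name and output format torchtitan actually uses may differ, this only shows which allocator counters are involved.

```python
import torch

def dump_peak_memory(device: str = "cuda", tag: str = "") -> tuple[float, float]:
    # max_memory_allocated: peak bytes handed out to tensors;
    # max_memory_reserved: peak bytes the caching allocator held from CUDA.
    alloc_gib = torch.cuda.max_memory_allocated(device) / 2**30
    reserved_gib = torch.cuda.max_memory_reserved(device) / 2**30
    print(f"[{tag}] peak allocated={alloc_gib:.2f} GiB, peak reserved={reserved_gib:.2f} GiB")
    torch.cuda.reset_peak_memory_stats(device)  # start fresh for the next step
    return alloc_gib, reserved_gib
```

Comparing the reserved number per step is what makes runs like the SAC vs no-SAC tables above comparable, since fragmentation shows up in reserved but not allocated.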
Here is the run comparison between eager and aot_fx_trace with SAC

Llama3 8B with dp=4 tp=2 on a 10-step run before SAC, with cuda graph, seq_len=1024

Llama3 8B with dp=4 tp=2 on a 10-step run after SAC, with cuda graph, seq_len=1024

Llama3 8B with dp=4 tp=2 on a 10-step run before SAC, without cuda graph, seq_len=8192

Llama3 8B with dp=4 tp=2 on a 10-step run after SAC, without cuda graph, seq_len=8192

Deepseek3 16B before SAC

Deepseek3 16B after SAC