[graph_trainer] Add torch.no_grad() and graph-based SAC to traced execution #2766

Merged

tugsbayasgalan merged 90 commits into main from gh/tugsbayasgalan/11/head on Apr 16, 2026
Conversation

tugsbayasgalan (Contributor) commented Mar 31, 2026

Stack from ghstack (oldest at bottom):

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.
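
A minimal sketch of that execution change, assuming a hypothetical `run_traced` wrapper (not the PR's actual API):

```python
import torch

def run_traced(gm, *flat_args):
    # gm is a torch.fx.GraphModule that already contains explicit backward
    # ops, so autograd tracking would only build a second, redundant graph
    # and keep forward intermediates alive via grad_fn references.
    with torch.no_grad():
        return gm(*flat_args)
```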

Adds a SAC option for aot_fx_trace in GraphTrainer.

Adds an option to dump peak memory from the CUDA caching allocator.
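
For reference, a hedged sketch of such a dump using the public allocator APIs (the PR's actual option and plumbing are not shown here):

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run one training step ...
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
print(f"peak reserved:  {torch.cuda.max_memory_reserved() / 2**30:.2f} GiB")
```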

Adds a test, similar to test_bitwise_equivalency.py, for peak memory.

Here is the run comparison between eager and aot_fx_trace with SAC: (screenshot)

Llama3 8B with dp=4 tp=2 on a 10-step run, before SAC, with CUDA graph, seq_len=1024: (screenshot)

Llama3 8B with dp=4 tp=2 on a 10-step run, after SAC, with CUDA graph, seq_len=1024: (screenshot)

Llama3 8B with dp=4 tp=2 on a 10-step run, before SAC, without CUDA graph, seq_len=8192: (screenshot)

Llama3 8B with dp=4 tp=2 on a 10-step run, after SAC, without CUDA graph, seq_len=8192: (screenshot)

Deepseek3 16B before SAC: (screenshot)

Deepseek3 16B after SAC: (screenshot)

…ion API

Authored-by: Claude

Redesign the graph trainer's tracing API based on the aot_function design doc.

Key changes:

make_fx_tracer.py:
- Rename trace_module -> aot_function. Takes any callable (not just nn.Module)
  with nn.Module instances auto-detected in args and their params/buffers
  lifted as graph inputs. When fn is an nn.Module, it is prepended to args
  and type(fn).__call__ is used as the callable.
- Delete run_traced_module. TracedResult is now directly callable — pass the
  same positional args (with live modules) to execute the graph. Fresh params
  are read from the modules automatically on each call.
- Store and restore output pytree spec so TracedResult.__call__ returns the
  same pytree structure as the original function (e.g. single tensor, list,
  tuple, dict), not a flat list.
- Add _ModuleParamsMeta with FQN storage. Parameter FQNs are recorded at trace
  time and validated at execute time to catch module structure mismatches.
- Add _collect_module_params helper for multi-module param extraction.
- Install TracingContext before make_fx so invoke_subgraph deduplication works.
- Validate that all pytree leaves in args are tensors or primitives
  (int/float/bool/str). Non-primitive values (callables, custom objects) must
  be captured in fn's closure or registered via pytree.register_constant /
  register_pytree_node.

trainer.py:
- Replace FwdBwdStepModule (nn.Module wrapper that only existed because the
  old trace_module required nn.Module as fn) with _make_fwd_bwd_step, a plain
  function factory. The model is now passed as an arg, loss_fn is captured in
  the closure.
- Remove manual params_and_buffers dict construction — TracedResult.__call__
  reads fresh params from the live module automatically.
- Add TODO for investigating loss_fn interaction with non-strict trace.

test_trace_module.py:
- Replace TrainStepModule with _make_train_step plain function factory.
- Remove _get_params_and_buffers helper (no longer needed).
- Update all callsites: trace_module -> aot_function, run_traced_module ->
  direct TracedResult.__call__.
- Register BlockMask as pytree node at module level so flex_attention tests
  pass the leaf validation.
- Add test_module_not_first_arg: module at position 1 in args.
- Add test_multiple_modules: two nn.Modules interleaved with a tensor.
- Add test_mismatched_module_raises: FQN validation catches wrong module.
- Add test_non_tensor_leaf_raises: callable leaf in args raises ValueError.

All 7 model tests pass (llama3, llama4, qwen3, qwen3_moe, deepseek_v3,
gpt_oss, flex_attention_annotations).

[ghstack-poisoned]
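
A hypothetical usage sketch of the aot_function API summarized above; the import path and exact signature are assumptions based on this commit message, not verified code:

```python
import torch
import torch.nn as nn
# assumed location, for illustration only:
# from torchtitan.experiments.graph_trainer.make_fx_tracer import aot_function

model = nn.Linear(4, 4)

def step(m, x):
    # modules passed in args are auto-detected; their params/buffers are
    # lifted as graph inputs at trace time
    return m(x).sum()

x = torch.randn(2, 4)
traced = aot_function(step, (model, x))  # trace once
loss = traced(model, x)                  # TracedResult is directly callable;
                                         # fresh params are read from `model`
```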
…al_fx_tracer

- Annotate backward FX nodes with {"remat_pass_tag": "is_backward"} during
  _patch_engine_run_backward so remat_using_tags_for_fwd_loss_bwd_graph can
  identify the forward/backward boundary.
- Apply remat_using_tags_for_fwd_loss_bwd_graph as a default post-trace pass.
  Nodes tagged PREFER_RECOMPUTE (from selective AC) are duplicated before
  backward and the forward copies are DCE'd, reducing peak memory.
- Execute traced graph under torch.no_grad() since the graph already contains
  explicit backward ops. Without this, PyTorch builds a redundant autograd
  graph keeping all forward intermediates alive via grad_fn references.
- Add test_llama_1b_peak_memory: verifies traced+AC peak memory is within 20%
  of eager+AC on Llama 1B (BS=2, seq=2048, bf16).

[ghstack-poisoned]
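
A hedged sketch of the tagging idea in the first bullet; whether the tag lives under node.meta["custom"] and how the boundary is detected are simplifications, not the PR's code:

```python
import torch.fx as fx

def tag_backward_region(gm: fx.GraphModule, first_backward_node: str) -> fx.GraphModule:
    # Mark every call_function node from the first backward node onward so
    # remat_using_tags_for_fwd_loss_bwd_graph can find the fwd/bwd boundary.
    in_backward = False
    for node in gm.graph.nodes:
        if node.name == first_backward_node:
            in_backward = True
        if in_backward and node.op == "call_function":
            node.meta.setdefault("custom", {})["remat_pass_tag"] = "is_backward"
    return gm
```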
meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Mar 31, 2026
# Rematerialize activations tagged PREFER_RECOMPUTE by selective AC.
# Duplicates recomputable forward ops before the backward region and
# DCEs the original copies, reducing peak memory.
traced = remat_using_tags_for_fwd_loss_bwd_graph(traced)
Contributor

No, the user should apply this if they want. Keep the tracing simple.

Contributor Author

Isn't it kind of weird that the user annotates nodes with AC and doesn't see any improvement in the end? It feels like silent incorrectness to me.

Comment on lines +550 to +553
for node in traced.graph.nodes:
    ac_id = node.meta.get("custom", {}).get("ac_graph_id")
    if ac_id is not None:
        node.meta["ac_graph_id"] = ac_id
Contributor

@SherlockNoMad commented Mar 31, 2026

Again, remove this. It should be a post-processing pass in the user's control.


claude Bot commented Apr 15, 2026

Claude finished @tugsbayasgalan's task in 2m 1s.


CI Failure Analysis

  • Check CI status
  • Analyze failing jobs
  • Report findings

There are 4 failed jobs out of 14. Lint and CPU Unit Test both passed. 8 jobs are still queued/in-progress.


Failure 1: GraphTrainer 8 GPU Integration Tests (both runs)

Error: AttributeError: module 'torch.compiler' has no attribute '_patch_autograd_grad'

Location: make_fx_tracer.py:374 — the code uses:

with torch.compiler._patch_autograd_grad():

Failed test flavors:

  • aot_fx_trace llama3 FSDP+TP+cudagraph
  • aot_fx_trace llama3 FSDP+TP+FlexAttn

All 8 ranks crash with the same AttributeError. This API doesn't exist in the PyTorch nightly installed in CI.

This is caused by this PR's changes.


Failure 2: GraphTrainer 8 GPU H100 Integration Tests

Same root cause: torch.compiler._patch_autograd_grad is missing.

Failed test flavors:

  • aot_fx_trace deepseek_v3 FSDP+TP+EP
  • aot_fx_trace deepseek_v3 FSDP+TP+EP+FlexAttn
  • aot_fx_trace qwen3 FSDP+TP
  • aot_fx_trace qwen3 MoE FSDP+TP+EP

All aot_fx_trace configurations fail with the identical AttributeError.

This is caused by this PR's changes.


Failure 3: 8 GPU Feature Tests (NOT related to this PR)

3 sub-tests failed, all due to pre-existing infrastructure issues:

| Test | Error |
| --- | --- |
| FSDP+FLEX_ATTN | InductorError: A compilation subprocess exited unexpectedly (Inductor subprocess crash) |
| FSDP + FLEX + per op SAC | LLVM SLPVectorizer.cpp:2690 assertion failure → SIGABRT (Triton/LLVM compiler bug) |
| SFT ChatDataset | Same LLVM SLPVectorizer assertion → SIGABRT |

These are NOT caused by this PR — they're Inductor/LLVM/Triton bugs.


Summary

| Job | Status | Caused by this PR? |
| --- | --- | --- |
| Lint | Passed | |
| CPU Unit Test | Passed | |
| GraphTrainer 8 GPU (run 1) | Failed | Yes (_patch_autograd_grad missing) |
| GraphTrainer 8 GPU (run 2) | Failed | Yes (same) |
| GraphTrainer 8 GPU H100 | Failed | Yes (same; deepseek, qwen3) |
| 8 GPU Feature Tests | Failed | No (LLVM/Inductor bugs) |
| 8 remaining jobs | Queued/In-progress | |

Single root cause for all PR-related failures: torch.compiler._patch_autograd_grad() at make_fx_tracer.py:374 does not exist in the CI PyTorch nightly. You need to either:

  1. Guard with hasattr(torch.compiler, '_patch_autograd_grad') and make it optional (see the sketch after this list)
  2. Wait for the PyTorch nightly that includes this API to land in CI
  3. Use an alternative if the API was renamed or moved
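
A sketch of option 1; the nullcontext fallback is an assumption, not necessarily how the PR should resolve it:

```python
import contextlib
import torch

if hasattr(torch.compiler, "_patch_autograd_grad"):
    autograd_patch = torch.compiler._patch_autograd_grad
else:
    # older nightlies: fall back to a no-op context manager
    autograd_patch = contextlib.nullcontext

with autograd_patch():
    ...  # run make_fx tracing here
```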

…aph-based SAC to traced execution"


Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Adds a SAC option for aot_fx_trace in GraphTrainer.

Adds an option to dump peak memory from the CUDA caching allocator.

Adds a test, similar to test_bitwise_equivalency.py, for peak memory.

Here is the run comparison between eager and aot_fx_trace with SAC: (screenshot)

Llama3 8B with dp=4 tp=2 on a 10-step run, before SAC: (screenshot)

Llama3 8B with dp=4 tp=2 on a 10-step run, after SAC: (screenshot)


For Llama3 8B, from the step-30 rank0 memory snapshots:

  - No SAC:
      - end-of-step memory_allocated() = 14.99 GiB
      - end-of-step memory_reserved() = 22.75 GiB
  - SAC:
      - end-of-step memory_allocated() = 14.99 GiB
      - end-of-step memory_reserved() = 16.6 GiB
Deepseek3 16B before SAC: (screenshot)

Deepseek3 16B after SAC: (screenshot)



[ghstack-poisoned]
tugsbayasgalan added a commit that referenced this pull request Apr 15, 2026
…cution

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Replace monkey-patched _CachingTorchDispatchMode AC approach with clean
graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass
(which now skips backward-tagged nodes) for post-hoc tagging, then
remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

apply_ac_remat_pass now takes GraphModule and returns GraphModule,
following the standard pass signature convention.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC
19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

ghstack-source-id: 157213f
Pull Request resolved: #2766
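
Put together, the pass order described in this commit message is roughly the following sketch; signatures are assumed from the summary:

```python
# annotate_ac_regions(model) runs before tracing; then, on the traced graph:
def apply_ac_remat_pass(gm):
    # GraphModule in, GraphModule out, per the pass-signature convention above
    gm = apply_sac_pass(gm)  # tag MUST_SAVE / PREFER_RECOMPUTE, skipping backward-tagged nodes
    return remat_using_tags_for_fwd_loss_bwd_graph(gm)  # duplicate tagged fwd ops, DCE originals
```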
tugsbayasgalan (Contributor, Author) commented:

@claude what are the failures?


claude Bot commented Apr 15, 2026

Claude finished @tugsbayasgalan's task in 3m 39s.


CI Failure Analysis

  • Check CI status
  • Analyze failing jobs
  • Report findings

There are 2 completed failures out of 14 jobs (9 still queued/in-progress). Lint, CPU Unit Test, and 8 GPU Feature Tests all passed.


Failure 1 & 2: GraphTrainer 8 GPU Integration Tests (both runs — same errors)

Both runs fail on the same 2 test flavors with different errors:

Test Flavor 1: aot_fx_trace llama3 FSDP+TP+cudagraph

Error: RNG ops in recompute regions

RuntimeError: Activation checkpoint rematerialization in `forward-loss-backward` graph
does not support RNG ops in recompute regions. Please move RNG operations outside of
recompute regions, or use joint graph mode (where partitioner handles RNG).

Stack trace:

trainer.py:132  →  _make_fx_forward_backward_step
passes.py:120   →  apply_graph_passes → pass_fn(gm, example_inputs)
passes.py:491   →  apply_ac_on_fwd_bwd_graph
torch/...remat_using_tags_for_fwd_loss_bwd_graph_pass.py:124  →  raises RuntimeError

Root cause: The remat_using_tags_for_fwd_loss_bwd_graph pass detects RNG operations (e.g., dropout) in regions tagged for recomputation and rejects them. The fwd-loss-bwd graph remat pass doesn't handle RNG state save/restore like the joint graph partitioner does. The SAC tagging in apply_sac_pass is tagging nodes that include RNG ops as PREFER_RECOMPUTE, but the remat pass can't safely recompute those because the random state would differ on replay.

Test Flavor 2: aot_fx_trace llama3 FSDP+TP+FlexAttn

Error: Disjoint backward regions

RuntimeError: Detected 7 disjoint backward regions in the graph but remat only supports
a single backward region. This can happen when non-backward computation appears between
backward sections. Please annotate the real backward with
torch.fx.traceback.annotate({"phase": "backward"}).

Stack trace: Same path through passes.py:491 → remat_using_tags_for_fwd_loss_bwd_graph

Root cause: When FlexAttention is used, the traced graph contains nested backward computations from flex_attention's internal autograd. The _is_backward_node annotation logic (which uses custom["phase"] == "backward") isn't correctly identifying all backward nodes in the FlexAttn case, causing the remat pass to see what looks like multiple disjoint backward regions instead of one contiguous backward phase. This is the exact edge case noted in the _is_backward_node comment at passes.py:48-59 about flex_attention nested backward tracing.
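
For reference, the annotation the error message asks for would look roughly like this; exactly where it should wrap the backward in the tracer is an assumption:

```python
import torch
import torch.fx.traceback as fx_traceback

# inside the traced fwd+loss+bwd function:
with fx_traceback.annotate({"phase": "backward"}):
    grads = torch.autograd.grad(loss, params)
```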


Summary

| Job | Status | This PR? | Error |
| --- | --- | --- | --- |
| Lint | Passed | | |
| CPU Unit Test | Passed | | |
| 8 GPU Feature Tests | Passed | | |
| GraphTrainer 8 GPU (run 1) | Failed | Yes | RNG in recompute + disjoint backward |
| GraphTrainer 8 GPU (run 2) | Failed | Yes | Same |
| GraphTrainer 8 GPU H100 | Queued | | |
| 8 other jobs | Queued | | |

Note: The previous torch.compiler._patch_autograd_grad error is gone — the CI PyTorch nightly now includes that API. The new failures are both in the SAC/remat pass:

  1. FSDP+TP+cudagraph: apply_sac_pass tags RNG ops (dropout) for recomputation, but remat_using_tags_for_fwd_loss_bwd_graph doesn't support RNG ops in recompute regions. Fix: either tag RNG ops as MUST_SAVE instead of PREFER_RECOMPUTE, or exclude RNG ops from recomputation tagging in apply_sac_pass (see the sketch after this list).

  2. FSDP+TP+FlexAttn: The backward phase annotation doesn't capture all backward nodes when FlexAttention's nested backward is involved, causing the remat pass to see 7 disjoint backward regions. Fix: the backward annotation needs to handle flex_attention's nested autograd.grad calls correctly.
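
A hedged sketch of that first fix; the op set and the exact tagging hook inside apply_sac_pass are assumptions based on this analysis, not the PR's code:

```python
import torch
from torch.utils.checkpoint import CheckpointPolicy

# ops whose replay would consume different random state
RNG_OPS = {
    torch.ops.aten.native_dropout.default,
    torch.ops.aten.bernoulli.default,
    torch.ops.aten.rand.default,
    torch.ops.aten.randn.default,
}

def sac_policy_for(node):
    # save RNG ops instead of recomputing them, so the remat pass never
    # sees an RNG op inside a recompute region
    if node.target in RNG_OPS:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE
```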


tugsbayasgalan (Contributor, Author) commented:

@claude I am pretty sure there were no RNG ops to be saved a few days ago. Is this a recent change?


claude Bot commented Apr 15, 2026

Claude finished @tugsbayasgalan's task in 13m 2s.


Analysis: RNG Ops in Recompute Regions

  • Read current passes.py, make_fx_tracer.py, and trainer.py
  • Analyze CI failure logs for both test flavors
  • Investigate whether RNG ops are from model code or PyTorch internals
  • Determine if this is a recent PyTorch nightly change

Short answer: This is almost certainly a recent PyTorch nightly change, not something you introduced. The Llama3 debug model has no dropout or explicit RNG ops in its forward pass.

Here's the detailed analysis:

What the CI shows

The SAC pass itself runs successfully — the logs show it tagging nodes with MUST_SAVE/PREFER_RECOMPUTE across all AC regions. The error comes after tagging, when remat_using_tags_for_fwd_loss_bwd_graph (line 124 of the PyTorch pass) validates the tagged graph and finds RNG ops in PREFER_RECOMPUTE regions.

The Llama3 debug model has no RNG ops

I checked the entire model path: TransformerBlock, attention (ScaledDotProductAttention), FeedForward (SwiGLU) — none have dropout, torch.rand, torch.randn, or any other RNG operations. The train_context is also clean (just optional loss_parallel).

What likely changed: _patch_autograd_grad or the remat pass itself

This PR switched from the old _patch_engine_run_backward() to torch.compiler._patch_autograd_grad() (make_fx_tracer.py:355). Two things could be happening:

  1. _patch_autograd_grad introduces RNG state-tracking ops into the traced graph. When PyTorch's patched autograd.grad traces backward ops, the nightly may now emit RNG state save/restore operations (like philox_seed_like or aten.bernoulli_.float) as part of its backward decomposition. These ops would appear in the forward portion of the fwd+loss+bwd graph (since the entire torch.autograd.grad call is traced as one flat graph) and get tagged as PREFER_RECOMPUTE by apply_sac_pass.

  2. The remat_using_tags_for_fwd_loss_bwd_graph pass recently added or tightened the RNG check. Since you said the RNG error didn't appear a few days ago, and the model code hasn't changed, the most likely explanation is that the PyTorch nightly updated remat_using_tags_for_fwd_loss_bwd_graph to check for RNG ops that it previously didn't check for, or _patch_autograd_grad started emitting new ops that trigger the check.

How to confirm

You can dump the graph right before the remat pass to see exactly which ops are flagged:

# In apply_ac_on_fwd_bwd_graph, before calling remat_using_tags_for_fwd_loss_bwd_graph:
import logging

from torch.utils.checkpoint import CheckpointPolicy

logger = logging.getLogger(__name__)

for node in gm.graph.nodes:
    if node.op == "call_function" and "recompute" in node.meta:
        if node.meta["recompute"] == CheckpointPolicy.PREFER_RECOMPUTE:
            # Check what PyTorch considers RNG (private helper; the guarded
            # import keeps this working on nightlies that lack it)
            try:
                from torch._functorch._activation_checkpointing import _is_rng_node
            except ImportError:
                break
            if _is_rng_node(node):
                logger.info(f"RNG node in recompute region: {node.target}")

Or simpler: print all node targets in the graph and grep for anything RNG-related (bernoulli, uniform, normal, philox, rng).
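
That simpler check could look like this sketch:

```python
# grep node targets for RNG-sounding ops, per the suggestion above
rng_keywords = ("bernoulli", "uniform", "normal", "philox", "rng", "dropout")
print([n for n in gm.graph.nodes
       if any(k in str(n.target).lower() for k in rng_keywords)])
```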

Two separate failures

  1. FSDP+TP+cudagraph → RNG ops error: The remat pass finds RNG ops in recompute regions. Likely caused by new ops from PyTorch nightly in the traced backward.

  2. FSDP+TP+FlexAttn → 7 disjoint backward regions: FlexAttention's nested autograd.grad calls create multiple backward sections that _is_backward_node (which checks node.meta.get("autograd_backward", False)) doesn't fully identify. This is a separate issue from the RNG one.

Suggested fix for the RNG issue

If the RNG ops are coming from PyTorch internals (not your model), you likely need to either:

  • Tag those specific RNG ops as MUST_SAVE in apply_sac_pass (skip them from PREFER_RECOMPUTE)
  • Or coordinate with the PyTorch team on _patch_autograd_grad to not emit RNG ops, or on remat_using_tags_for_fwd_loss_bwd_graph to handle them

tugsbayasgalan (Contributor, Author) commented:

@claude What is the test command to reproduce these errors?


claude Bot commented Apr 15, 2026

Claude finished @tugsbayasgalan's task in 3m 0s.


Test Commands to Reproduce CI Failures

  • Check CI workflow files for test commands
  • Read run_tests.py for test configurations
  • Find the override definitions for GraphTrainer tests
  • Provide reproduction commands

The two failing test flavors from CI are run via the integration test runner at torchtitan/experiments/graph_trainer/tests/integration_tests.py. Here's how to reproduce each:

Full suite command (what CI runs)

# This runs ALL graph_trainer_default tests (including the two that fail)
python -m torchtitan.experiments.graph_trainer.tests.integration_tests \
    --test_suite graph_trainer_default \
    --gpu_arch_type cuda \
    /tmp/test_output \
    --ngpu 8

Individual failing tests

Failure 1: aot_fx_trace llama3 FSDP+TP+cudagraph (RNG ops in recompute regions)

python -m torchtitan.experiments.graph_trainer.tests.integration_tests \
    --test_suite graph_trainer_default \
    --test_name aot_fx_trace_llama3_fsdp_tp \
    --gpu_arch_type cuda \
    /tmp/test_output \
    --ngpu 8

Which under the hood runs ./run_train.sh with these overrides:

--module graph_trainer.llama3
--config graph_trainer_llama3_debugmodel
--compile.mode aot_fx_trace
--parallelism.data_parallel_shard_degree 4
--parallelism.tensor_parallel_degree 2

Failure 2: aot_fx_trace llama3 FSDP+TP+FlexAttn (Disjoint backward regions)

python -m torchtitan.experiments.graph_trainer.tests.integration_tests \
    --test_suite graph_trainer_default \
    --test_name aot_fx_trace_llama3_fsdp_tp_flexattn \
    --gpu_arch_type cuda \
    /tmp/test_output \
    --ngpu 8

Which runs ./run_train.sh with:

--module graph_trainer.llama3
--config graph_trainer_llama3_debugmodel_flex_attn
--compile.mode aot_fx_trace
--parallelism.data_parallel_shard_degree 4
--parallelism.tensor_parallel_degree 2

Equivalent direct run_train.sh commands

If you want to run them directly without the test harness:

# Failure 1: FSDP+TP+cudagraph (RNG error)
NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 MODULE=graph_trainer.llama3 CONFIG=graph_trainer_llama3_debugmodel \
  ./run_train.sh --dump_folder /tmp/test_output/aot_fx_trace_llama3_fsdp_tp \
    --compile.mode aot_fx_trace \
    --parallelism.data_parallel_shard_degree 4 \
    --parallelism.tensor_parallel_degree 2

# Failure 2: FSDP+TP+FlexAttn (disjoint backward regions)
NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 MODULE=graph_trainer.llama3 CONFIG=graph_trainer_llama3_debugmodel_flex_attn \
  ./run_train.sh --dump_folder /tmp/test_output/aot_fx_trace_llama3_fsdp_tp_flexattn \
    --compile.mode aot_fx_trace \
    --parallelism.data_parallel_shard_degree 4 \
    --parallelism.tensor_parallel_degree 2

CI also runs these unit tests (which passed)

# Bitwise deterministic test
pytest torchtitan/experiments/graph_trainer/tests/test_bitwise_deterministic.py -v

# SAC peak memory test
pytest torchtitan/experiments/graph_trainer/tests/test_sac_peak_memory.py -v

# SAC pass unit tests
pytest torchtitan/experiments/graph_trainer/tests/test_passes.py::TestApplySACPass -v

The test definitions are in torchtitan/experiments/graph_trainer/tests/integration_tests.py — look for the build_graph_trainer_default_test_list() function to see all OverrideDefinitions.


ctx = TracingContext(fake_mode)
# preserve_node_meta propagates fx.traceback.annotate metadata to traced nodes
with fake_mode, tracing(ctx), preserve_node_meta(), _skip_nested_compile():
    # _skip_nested_compile lets this tracer run when an outer dynamo trace
Contributor

not "outer dynamo trace", it's the "current make_fx trace"?


passes: list[Callable] = [
    functools.partial(tlparse_log_graph_pass, graph_name="make_fx_graph_traced"),
    apply_ac_on_fwd_bwd_graph,
Contributor

Apply this after the graph cleanup pass.

return gm


def apply_ac_on_fwd_bwd_graph(
Contributor

Please ask Claude to come up with a more tasteful name, and follow the naming convention: it should end with *_pass.

The model must have been annotated with annotate_ac_regions before
tracing so that nodes have custom["ac_region_id"] metadata.
"""
from torch._functorch._activation_checkpointing.remat_using_tags_for_fwd_loss_bwd_graph_pass import (
Contributor

Follow-up for a future diff: I also think it's better to have remat_using_tags_for_fwd_loss_bwd_graph_pass in titan, since no one is using it in core and its applicable scenario is niche.

Contributor Author

We do use it in torch.compile(fullgraph=True) today.

with fake_mode, tracing(ctx), preserve_node_meta(), _skip_nested_compile():
    # _skip_nested_compile lets this tracer run when an outer dynamo trace
    # reaches torch.compile'd FlexAttention kernels.
    # _non_strict_tracing_context is required by _patch_autograd_grad() and
Contributor

I am not a fan of coupling _non_strict_tracing_context with _patch_autograd_grad.
What breaks if we don't have this context?
Are there other ways to fix this?

Contributor Author

We only want to enable patching autograd in cases where we know it is carefully coupled with the rest of the stack, so we added a check in patch_autograd that only succeeds if this specific context is on. The context is _non_strict_tracing_context. For example, we don't want to do this in non-strict export.

Comment thread on tests/integration_tests/run_tests.py (Outdated)


def run_tests(args, test_list: list[OverrideDefinitions], module=None, config=None):
def run_tests(
Contributor

I don't see a meaningful change in this file.

…aph-based SAC to traced execution"


Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Adds SAC option on aot_fx_trace on GraphTrainer. 

Adds option to dump the peak memory from CUDA caching allocator. 

Adds similar test as test_bitwise_equivalency.py for peak memory 

Here is the run comparison between eager and aot_fx_trace with SAC
<img width="1114" height="455" alt="Screenshot 2026-04-14 at 1 18 41 PM" src="https://github.com/user-attachments/assets/bf4ffdd3-8783-47be-ac64-32587787591f" />

Llama3B 8b with dp=4 tp=2 on 10 step run before SAC 
<img width="1728" height="585" alt="Screenshot 2026-04-14 at 10 39 41 PM" src="https://github.com/user-attachments/assets/d8a1fe31-bd91-41a4-8e6f-a0e738cc2b76" />

Llama3B 8b with dp=4 tp=2 on 10 step run after SAC 
<img width="1728" height="552" alt="Screenshot 2026-04-14 at 10 39 26 PM" src="https://github.com/user-attachments/assets/273b6e65-c2a3-473b-a2db-53b67a95be3a" />


 For Llama3 8B, from the step-30 rank0 memory snapshots:

  - No SAC:
      - end-of-step memory_allocated() = 14.99 GiB
        GiB
      - end-of-step memory_reserved() = 22.75 GiB
  - SAC:
      - end-of-step memory_allocated() = 14.99 GiB
        GiB
      - end-of-step memory_reserved() = 16.6 GiB
     
Deepseek3 16B before SAC 
<img width="1728" height="600" alt="Screenshot 2026-04-14 at 10 54 31 PM" src="https://github.com/user-attachments/assets/61180fa4-debd-4c10-834e-e4936cfbd67e" />

Deepseek3 16B after SAC 

<img width="1728" height="578" alt="Screenshot 2026-04-14 at 10 58 30 PM" src="https://github.com/user-attachments/assets/3e3b4e71-e13f-43b7-9404-ad9a0a04f75c" />



[ghstack-poisoned]
… traced execution"


Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Adds SAC option on aot_fx_trace on GraphTrainer. 

Adds option to dump the peak memory from CUDA caching allocator. 

Adds similar test as test_bitwise_equivalency.py for peak memory 

Here is the run comparison between eager and aot_fx_trace with SAC
<img width="1114" height="455" alt="Screenshot 2026-04-14 at 1 18 41 PM" src="https://github.com/user-attachments/assets/bf4ffdd3-8783-47be-ac64-32587787591f" />

Llama3B 8b with dp=4 tp=2 on 10 step run before SAC 
<img width="1728" height="585" alt="Screenshot 2026-04-14 at 10 39 41 PM" src="https://github.com/user-attachments/assets/d8a1fe31-bd91-41a4-8e6f-a0e738cc2b76" />

Llama3B 8b with dp=4 tp=2 on 10 step run after SAC 
<img width="1728" height="552" alt="Screenshot 2026-04-14 at 10 39 26 PM" src="https://github.com/user-attachments/assets/273b6e65-c2a3-473b-a2db-53b67a95be3a" />


 For Llama3 8B, from the step-30 rank0 memory snapshots:

  - No SAC:
      - end-of-step memory_allocated() = 14.99 GiB
        GiB
      - end-of-step memory_reserved() = 22.75 GiB
  - SAC:
      - end-of-step memory_allocated() = 14.99 GiB
        GiB
      - end-of-step memory_reserved() = 16.6 GiB
     
Deepseek3 16B before SAC 
<img width="1728" height="600" alt="Screenshot 2026-04-14 at 10 54 31 PM" src="https://github.com/user-attachments/assets/61180fa4-debd-4c10-834e-e4936cfbd67e" />

Deepseek3 16B after SAC 

<img width="1728" height="578" alt="Screenshot 2026-04-14 at 10 58 30 PM" src="https://github.com/user-attachments/assets/3e3b4e71-e13f-43b7-9404-ad9a0a04f75c" />



[ghstack-poisoned]
tugsbayasgalan added a commit that referenced this pull request Apr 15, 2026
…cution

Execute traced graph under torch.no_grad() since it already contains
explicit backward ops (from torch.autograd.grad traced by make_fx).
Without this, PyTorch builds a redundant autograd graph keeping all
forward intermediates alive via grad_fn references.

Replace monkey-patched _CachingTorchDispatchMode AC approach with clean
graph-based SAC: annotate_ac_regions before tracing, apply_sac_pass
(which now skips backward-tagged nodes) for post-hoc tagging, then
remat_using_tags_for_fwd_loss_bwd_graph for the remat transform.

apply_ac_remat_pass now takes GraphModule and returns GraphModule,
following the standard pass signature convention.

Results on Llama 1B (H100): traced graph SAC uses 18.23 GB vs eager SAC
19.60 GB (0.93x ratio), with bitwise identical losses and gradients.

ghstack-source-id: 4b99a70
Pull Request resolved: #2766
tugsbayasgalan added a commit that referenced this pull request Apr 15, 2026
…cution

ghstack-source-id: 1188563
Pull Request resolved: #2766
@tugsbayasgalan tugsbayasgalan changed the base branch from gh/tugsbayasgalan/11/base to main April 16, 2026 17:48
tugsbayasgalan added a commit that referenced this pull request Apr 16, 2026
…cution

ghstack-source-id: d22c1e5
Pull Request resolved: #2766

def run_tests(args, test_list: list[OverrideDefinitions], module=None, config=None):
    """Run all integration tests to test the core features of TorchTitan"""

Contributor

nit: why delete this line?

@tugsbayasgalan tugsbayasgalan merged commit 132683a into main Apr 16, 2026
22 of 32 checks passed

Labels

ciflow/8gpu, CLA Signed