Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
3d04ab9
[graph_trainer] Replace trace_module/run_traced_module with aot_funct…
tugsbayasgalan Mar 30, 2026
da526a8
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Mar 30, 2026
3394aaf
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Mar 31, 2026
8f276ba
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Mar 31, 2026
4f56847
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Mar 31, 2026
155d120
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Mar 31, 2026
9837317
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Mar 31, 2026
7d8b735
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Mar 31, 2026
8a620f8
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 2, 2026
45c8e76
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 4, 2026
8098738
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 4, 2026
11510b1
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 6, 2026
87ef3a3
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 6, 2026
f2d1fe6
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 6, 2026
baee1b1
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 6, 2026
581b8c5
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 6, 2026
c5f26fe
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 6, 2026
c9ba1fc
Update on "[graph_trainer] Replace trace_module/run_traced_module wit…
tugsbayasgalan Apr 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions torchtitan/experiments/graph_trainer/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.nn as nn
from torch.distributed.tensor import DTensor, Replicate
from torch.fx.traceback import annotate_fn
from torch.utils._pytree import register_pytree_node, tree_map
from torch.utils._pytree import register_constant, register_pytree_node, tree_map

from torchtitan.config import CompileConfig
from torchtitan.distributed import ParallelDims
Expand Down Expand Up @@ -70,6 +70,16 @@ def register_blockmask_pytree_node():
)


def maybe_register_blockmask_pytree_node() -> None:
"""Register flex-attention pytree helpers if they are missing."""
from torch.nn.attention.flex_attention import BlockMask, _MaskModWrapper

if BlockMask not in torch.utils._pytree.SUPPORTED_NODES:
register_blockmask_pytree_node()
if _MaskModWrapper not in torch.utils._pytree.SUPPORTED_NODES:
register_constant(_MaskModWrapper)


def end_with_pass(passes: list[Callable], names: list[str]) -> bool:
return (
len(passes) > 0
Expand Down Expand Up @@ -177,7 +187,7 @@ def apply_graph_ac(
if ac_config.mode != "selective":
raise ValueError(
f"graph_trainer only supports activation_checkpoint.mode 'selective' or "
f"'none', got '{ac_config.mode}'. Use 'selective' for graph-based SAC."
f"'none', got {ac_config.mode!r}. Use 'selective' for graph-based SAC."
)

joint_pass_names = getattr(compile_config, "joint_passes", [])
Expand Down
Loading
Loading