Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,17 @@
from megatron.core.distributed.distributed_data_parallel_config import (
DistributedDataParallelConfig,
)
from megatron.core.transformer.cuda_graphs import is_graph_capturing
from megatron.core.utils import is_submodule
except ImportError:
# Megatron-LM is not installed, use Megatron-FSDP as a standalone module.
logger.info("Megatron Core is not installed, Megatron-FSDP will run without Megatron Core.")
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .utils import is_submodule

def is_graph_capturing():
    """Fallback for standalone Megatron-FSDP (Megatron Core not installed).

    Stands in for ``megatron.core.transformer.cuda_graphs.is_graph_capturing``
    (imported in the ``try`` branch above) by asking PyTorch whether the
    current CUDA stream is being captured into a CUDA graph.
    """
    return torch.cuda.is_current_stream_capturing()


class TrainingState(Enum):
"""States of a FSDP parameter group, which are coupled with
Expand Down Expand Up @@ -688,6 +692,10 @@ def _process_post_backward_gradients(param_list):
- In hybrid FSDP configurations, an outer FSDP group gradient reduction
may be triggered.
"""
# Skip entire gradient processing during CUDA graph capture.
if is_graph_capturing():
return

# Filter out shared parameters whose gradients are handled by the root hook.
param_list = [p for p in param_list if not getattr(p, "_is_shared", False)]

Expand Down Expand Up @@ -1000,6 +1008,12 @@ def forward_hook(_module, inputs, output):
)
return output

# Tag for cuda_graphs.py: this forward hook wraps a pre-backward handler.
# cuda_graphs.py will withhold it from TE and extract the inner handler
# into backward_pre_hooks for per-callable manual invocation.
# Attribute name must match _CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR in cuda_graphs.py.
forward_hook._cuda_graph_backward_pre_handler = custom_backward_handler

# Register the post-forward hook that attaches the custom backward hook
# on the output tensor(s).
return module.register_forward_hook(forward_hook)
Expand Down Expand Up @@ -1069,13 +1083,16 @@ def _register_pre_backward_param_unshard_hook(module):
# and reduce-scatter gradients after the backward pass.
if isinstance(module, tuple(fsdp_unit_modules)):
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
_post_bwd_hook = functools.partial(
_register_post_backward_hook, _post_backward_release_module
)
# Tag for cuda_graphs.py: this forward_pre hook wraps a post-backward handler.
# cuda_graphs.py will withhold it from TE and extract the inner handler
# into backward_hooks for per-callable manual invocation.
# Attribute name must match _CUDA_GRAPH_BACKWARD_HANDLER_ATTR in cuda_graphs.py.
_post_bwd_hook._cuda_graph_backward_handler = _post_backward_release_module
self.forward_pre_hooks[f"module {name} register post-backward hook"] = (
module.register_forward_pre_hook(
functools.partial(
_register_post_backward_hook, _post_backward_release_module
),
with_kwargs=True,
)
module.register_forward_pre_hook(_post_bwd_hook, with_kwargs=True)
)
grad_acc_param_list = [p for p in module.parameters() if p.requires_grad]
else:
Expand Down
199 changes: 197 additions & 2 deletions megatron/core/transformer/cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,37 @@ def to_list(self, x):
return [x] if torch.is_tensor(x) else list(x)


# Keys used to split per-module hooks_dict into TE-facing vs restore-only dicts.
_TE_HOOK_KEYS = frozenset(
{
'forward_pre_hooks',
'forward_pre_hooks_with_kwargs',
'forward_hooks',
'forward_hooks_with_kwargs',
'backward_pre_hooks',
'backward_hooks',
}
)
_RESTORE_KEYS = frozenset(
{
'forward_pre_hooks_restore',
'forward_hooks_restore',
'backward_pre_hooks_restore',
'backward_hooks_restore',
}
)

# Sentinel attribute names set by megatron_fsdp.py on forward hooks that wrap backward handlers.
# cuda_graphs.py reads these to detect which hooks to withhold from TE and how to reroute them.
# The attribute value is the inner backward handler to extract.
# Must match the attribute names used in megatron_fsdp.py.
#
# Set on a forward_pre hook: inner handler goes to backward_hooks (post-backward).
_CUDA_GRAPH_BACKWARD_HANDLER_ATTR = '_cuda_graph_backward_handler'
# Set on a forward hook: inner handler goes to backward_pre_hooks (pre-backward).
_CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR = '_cuda_graph_backward_pre_handler'


class CudaGraphManager(torch.nn.Module):
"""Creates and runs cudagraphs for a megatron module"""

Expand Down Expand Up @@ -2154,6 +2185,122 @@ def _get_cuda_graph_input_data(self):
# Generate sample arguments and keyword arguments for capturing.
sample_args, sample_kwargs = self._get_sample_arguments(order, chunk_id_list)

# Extract hooks from callables for manual invocation during CUDA Graph capture/replay.
# Two-phase approach:
# Phase 1 (_extract_module_hooks): general — copies ALL 4 hook dicts uniformly, clears them.
# Phase 2 (_apply_fsdp_hook_transforms): FSDP-specific — reroutes FSDP wrappers so their
# inner backward handlers land in the right TE-facing key while the wrappers themselves
# are withheld from TE. *_restore keys are never modified in Phase 2.

def _extract_module_hooks(module):
"""Phase 1 (general): copy all 4 PyTorch hook dicts; clear them from module.

Every hook goes into both its TE-facing key and its *_restore key (independent copy).
*_with_kwargs flag sets are populated where applicable.
No FSDP-specific logic here.
"""
hooks_dict = {}

if getattr(module, '_forward_pre_hooks', None):
with_kw = getattr(module, '_forward_pre_hooks_with_kwargs', set())
fph = dict(module._forward_pre_hooks)
hooks_dict['forward_pre_hooks'] = fph
hooks_dict['forward_pre_hooks_restore'] = dict(fph)
fph_kw = {hid: True for hid in fph if hid in with_kw}
if fph_kw:
hooks_dict['forward_pre_hooks_with_kwargs'] = fph_kw
module._forward_pre_hooks.clear()

if getattr(module, '_forward_hooks', None):
with_kw = getattr(module, '_forward_hooks_with_kwargs', set())
fh = dict(module._forward_hooks)
hooks_dict['forward_hooks'] = fh
hooks_dict['forward_hooks_restore'] = dict(fh)
fh_kw = {hid: True for hid in fh if hid in with_kw}
if fh_kw:
hooks_dict['forward_hooks_with_kwargs'] = fh_kw
module._forward_hooks.clear()

if getattr(module, '_backward_pre_hooks', None):
bph = dict(module._backward_pre_hooks)
hooks_dict['backward_pre_hooks'] = bph
hooks_dict['backward_pre_hooks_restore'] = dict(bph)
module._backward_pre_hooks.clear()

if getattr(module, '_backward_hooks', None):
bh = dict(module._backward_hooks)
hooks_dict['backward_hooks'] = bh
hooks_dict['backward_hooks_restore'] = dict(bh)
module._backward_hooks.clear()

return hooks_dict

def _apply_fsdp_hook_transforms(hooks_dict):
    """Phase 2 (FSDP-specific): reroute forward hooks that wrap backward handlers.

    megatron_fsdp.py tags such hooks with sentinel attributes before registering them:
    - _CUDA_GRAPH_BACKWARD_HANDLER_ATTR: set on forward_pre hooks wrapping a post-backward
      handler (e.g. functools.partial over _post_backward_release_module). Cannot be passed
      to TE — TE's per-callable backward would fire the handler for ALL callables at once.
    - _CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR: set on forward hooks wrapping a pre-backward
      handler (e.g. create_custom_backward_hook pattern).

    Each tagged hook is withheld from TE; its inner handler is extracted to the
    appropriate backward key for per-callable manual invocation.
    *_restore keys are NOT modified.
    """

    def _reroute(src_key, flags_key, sentinel_attr, dst_key):
        # Move the inner handlers of tagged hooks from src_key to dst_key and
        # drop the wrapper hooks (plus their kwargs flags) from the TE-facing dicts.
        src = hooks_dict.get(src_key)
        if not src:
            return
        rerouted = {}
        for hook_id, hook_fn in list(src.items()):
            inner = getattr(hook_fn, sentinel_attr, None)
            if inner is None:
                continue
            rerouted[hook_id] = inner
            del src[hook_id]
            hooks_dict.get(flags_key, {}).pop(hook_id, None)
        if rerouted:
            hooks_dict.setdefault(dst_key, {}).update(rerouted)
        if not src:
            # Every hook was a tagged wrapper: drop the now-empty entries.
            del hooks_dict[src_key]
            hooks_dict.pop(flags_key, None)

    # 1. forward_pre_hooks tagged with _CUDA_GRAPH_BACKWARD_HANDLER_ATTR
    #    -> inner handlers become post-backward hooks.
    _reroute(
        'forward_pre_hooks',
        'forward_pre_hooks_with_kwargs',
        _CUDA_GRAPH_BACKWARD_HANDLER_ATTR,
        'backward_hooks',
    )
    # 2. forward_hooks tagged with _CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR
    #    -> inner handlers become pre-backward hooks.
    _reroute(
        'forward_hooks',
        'forward_hooks_with_kwargs',
        _CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR,
        'backward_pre_hooks',
    )

extracted_hooks = [] # TE-facing: passed to make_graphed_callables as capture_time_hooks
restore_hooks = [] # restore-only: applied to modules after graph capture
for callable_module in self.flattened_callables:
if isinstance(callable_module, torch.nn.Module):
hooks_dict = _extract_module_hooks(callable_module)
_apply_fsdp_hook_transforms(hooks_dict)
te_hooks = {k: v for k, v in hooks_dict.items() if k in _TE_HOOK_KEYS}
restore = {k: v for k, v in hooks_dict.items() if k in _RESTORE_KEYS}
extracted_hooks.append(te_hooks if te_hooks else None)
restore_hooks.append(restore if restore else None)
else:
extracted_hooks.append(None)
restore_hooks.append(None)

def get_make_graphed_callables_kwargs():
kwargs = {
'allow_unused_input': True,
Expand Down Expand Up @@ -2234,7 +2381,36 @@ def _get_fp8_enabled():
return kwargs

kwargs = get_make_graphed_callables_kwargs()
return sample_args, kwargs

# Add extracted hooks to kwargs for TE to invoke during warmup/capture
# Note: We DON'T pass replay_hooks - hooks will be restored to modules after capture
# and automatically triggered during replay
if extracted_hooks and any(h for h in extracted_hooks):
# TE's make_graphed_callables asserts that all capture_time_hooks return None
# (hooks must only have side effects; returning tensors would corrupt the static
# CUDA graph buffers). Wrap each hook_fn so the return value is always None.
def _wrap_none(fn):
def wrapper(*args, **kwargs):
fn(*args, **kwargs)

return wrapper

def _wrap_hooks_dict(hooks_dict):
if hooks_dict is None:
return None
wrapped = {}
for key, id_to_fn in hooks_dict.items():
# forward_pre_hooks_with_kwargs and forward_hooks_with_kwargs are flag sets
# ({hook_id: True}), not hook_fn dicts — skip wrapping.
if key in ('forward_pre_hooks_with_kwargs', 'forward_hooks_with_kwargs'):
wrapped[key] = id_to_fn
else:
wrapped[key] = {hid: _wrap_none(fn) for hid, fn in id_to_fn.items()}
return wrapped

kwargs['capture_time_hooks'] = [_wrap_hooks_dict(h) for h in extracted_hooks]

return sample_args, kwargs, restore_hooks

def _start_capturing(self):
"""
Expand Down Expand Up @@ -2320,7 +2496,7 @@ def create_cudagraphs(self):
)
else:
# Prepare CUDA Graph capturing input data and call `make_graphed_callables`.
sample_args, kwargs = self._get_cuda_graph_input_data()
sample_args, kwargs, restore_hooks = self._get_cuda_graph_input_data()
if self.config.sequence_parallel:
rng_context = get_cuda_rng_tracker().fork()
else:
Expand All @@ -2330,6 +2506,25 @@ def create_cudagraphs(self):
tuple(self.flattened_callables), sample_args, **kwargs
)

# Restore original hooks to callables after CUDA Graph capture.
# restore_hooks contains only the hooks cleared before capture; _with_kwargs flag dicts
# survive .clear() and need no explicit restoration.
if restore_hooks and any(h for h in restore_hooks):
for callable_module, restore in zip(self.flattened_callables, restore_hooks):
if isinstance(callable_module, torch.nn.Module) and restore:
if 'forward_pre_hooks_restore' in restore:
for hook_id, hook_fn in restore['forward_pre_hooks_restore'].items():
callable_module._forward_pre_hooks[hook_id] = hook_fn
if 'forward_hooks_restore' in restore:
for hook_id, hook_fn in restore['forward_hooks_restore'].items():
callable_module._forward_hooks[hook_id] = hook_fn
if 'backward_pre_hooks_restore' in restore:
for hook_id, hook_fn in restore['backward_pre_hooks_restore'].items():
callable_module._backward_pre_hooks[hook_id] = hook_fn
if 'backward_hooks_restore' in restore:
for hook_id, hook_fn in restore['backward_hooks_restore'].items():
callable_module._backward_hooks[hook_id] = hook_fn

# Push the captured graphs to the corresponding TransformerBlock.
num_layers_accumulated = 0
for layers in self.callables_per_chunk:
Expand Down
31 changes: 22 additions & 9 deletions megatron/core/transformer/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,12 @@ def get_layer_static_inputs(self, seq_length, micro_batch_size):
)

static_inputs = {}
params_dtype = (
self.config.params_dtype if self.config.params_dtype is not None else torch.bfloat16
)
static_inputs["hidden_states"] = torch.ones(
(slen_per_cptp, micro_batch_size, self.config.hidden_size),
dtype=torch.bfloat16,
dtype=params_dtype,
requires_grad=True,
device=torch.cuda.current_device(),
)
Expand Down Expand Up @@ -285,7 +288,8 @@ def _te_cuda_graph_capture(self, *args, **kwargs):
CUDA Graph capture for this layer using TE interface.
Normally it's just a forward pass if we're capturing the entire layer.
"""
return self.forward(*args, **kwargs)
forward_func = getattr(self, '_original_forward', self.forward)
return forward_func(*args, **kwargs)

def _te_cuda_graph_replay(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -355,13 +359,22 @@ def __call__(self, *args, **kwargs):
if self._should_call_local_cudagraph(*args, **kwargs):
return self.cudagraph_manager(self, args, kwargs)
elif self._should_call_te_cudagraph(*args, **kwargs):
if not self.cuda_graphs:
# Do CUDA Graphs capture.
cuda_graph_func = self._te_cuda_graph_capture
else:
# Do CUDA Graphs replay.
cuda_graph_func = self._te_cuda_graph_replay
return cuda_graph_func(*args, **kwargs)
# Temporarily replace forward with cuda graph function
self._original_forward = self.forward
try:
if not self.cuda_graphs:
# Do CUDA Graphs capture.
self.forward = self._te_cuda_graph_capture
else:
# Do CUDA Graphs replay.
self.forward = self._te_cuda_graph_replay

return super().__call__(*args, **kwargs)
finally:
# Restore original forward and clean up temporary attribute
self.forward = self._original_forward
if hasattr(self, '_original_forward'):
delattr(self, '_original_forward')
return super().__call__(*args, **kwargs)


Expand Down
Loading