diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index 7536a1ba9d2..d22cf8f0515 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -50,6 +50,7 @@ from megatron.core.distributed.distributed_data_parallel_config import ( DistributedDataParallelConfig, ) + from megatron.core.transformer.cuda_graphs import is_graph_capturing from megatron.core.utils import is_submodule except ImportError: # Megatron-LM is not installed, use Megatron-FSDP as a standalone module. @@ -57,6 +58,9 @@ from .distributed_data_parallel_config import DistributedDataParallelConfig from .utils import is_submodule + def is_graph_capturing(): + return torch.cuda.is_current_stream_capturing() + class TrainingState(Enum): """States of a FSDP parameter group, which are coupled with @@ -690,6 +694,10 @@ def _process_post_backward_gradients(param_list): - In hybrid FSDP configurations, an outer FSDP group gradient reduction may be triggered. """ + # Skip entire gradient processing during CUDA graph capture. + if is_graph_capturing(): + return + # Filter out shared parameters whose gradients are handled by the root hook. param_list = [p for p in param_list if not getattr(p, "_is_shared", False)] @@ -980,6 +988,12 @@ def forward_hook(_module, inputs, output): ) return output + # Tag for cuda_graphs.py: this forward hook wraps a pre-backward handler. + # cuda_graphs.py will withhold it from TE and extract the inner handler + # into backward_pre_hooks for per-callable manual invocation. + # Attribute name must match _CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR in cuda_graphs.py. + forward_hook._cuda_graph_backward_pre_handler = custom_backward_handler + # Register the post-forward hook that attaches the custom backward hook # on the output tensor(s). return module.register_forward_hook(forward_hook) @@ -1047,13 +1061,16 @@ def _register_pre_backward_param_unshard_hook(module): # and reduce-scatter gradients after the backward pass. if isinstance(module, tuple(fsdp_unit_modules)): if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params": + _post_bwd_hook = functools.partial( + _register_post_backward_hook, _post_backward_release_module + ) + # Tag for cuda_graphs.py: this forward_pre hook wraps a post-backward handler. + # cuda_graphs.py will withhold it from TE and extract the inner handler + # into backward_hooks for per-callable manual invocation. + # Attribute name must match _CUDA_GRAPH_BACKWARD_HANDLER_ATTR in cuda_graphs.py. + _post_bwd_hook._cuda_graph_backward_handler = _post_backward_release_module self.forward_pre_hooks[f"module {name} register post-backward hook"] = ( - module.register_forward_pre_hook( - functools.partial( - _register_post_backward_hook, _post_backward_release_module - ), - with_kwargs=True, - ) + module.register_forward_pre_hook(_post_bwd_hook, with_kwargs=True) ) grad_acc_param_list = [p for p in module.parameters() if p.requires_grad] else: diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index c7631519e43..f7f5d9bb846 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1402,6 +1402,37 @@ def to_list(self, x): return [x] if torch.is_tensor(x) else list(x) +# Keys used to split per-module hooks_dict into TE-facing vs restore-only dicts. +_TE_HOOK_KEYS = frozenset( + { + 'forward_pre_hooks', + 'forward_pre_hooks_with_kwargs', + 'forward_hooks', + 'forward_hooks_with_kwargs', + 'backward_pre_hooks', + 'backward_hooks', + } +) +_RESTORE_KEYS = frozenset( + { + 'forward_pre_hooks_restore', + 'forward_hooks_restore', + 'backward_pre_hooks_restore', + 'backward_hooks_restore', + } +) + +# Sentinel attribute names set by megatron_fsdp.py on forward hooks that wrap backward handlers. +# cuda_graphs.py reads these to detect which hooks to withhold from TE and how to reroute them. +# The attribute value is the inner backward handler to extract. +# Must match the attribute names used in megatron_fsdp.py. +# +# Set on a forward_pre hook: inner handler goes to backward_hooks (post-backward). +_CUDA_GRAPH_BACKWARD_HANDLER_ATTR = '_cuda_graph_backward_handler' +# Set on a forward hook: inner handler goes to backward_pre_hooks (pre-backward). +_CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR = '_cuda_graph_backward_pre_handler' + + class CudaGraphManager(torch.nn.Module): """Creates and runs cudagraphs for a megatron module""" @@ -2162,6 +2193,122 @@ def _get_cuda_graph_input_data(self): # Generate sample arguments and keyword arguments for capturing. sample_args, sample_kwargs = self._get_sample_arguments(order, chunk_id_list) + # Extract hooks from callables for manual invocation during CUDA Graph capture/replay. + # Two-phase approach: + # Phase 1 (_extract_module_hooks): general — copies ALL 4 hook dicts uniformly, clears them. + # Phase 2 (_apply_fsdp_hook_transforms): FSDP-specific — reroutes FSDP wrappers so their + # inner backward handlers land in the right TE-facing key while the wrappers themselves + # are withheld from TE. *_restore keys are never modified in Phase 2. + + def _extract_module_hooks(module): + """Phase 1 (general): copy all 4 PyTorch hook dicts; clear them from module. + + Every hook goes into both its TE-facing key and its *_restore key (independent copy). + *_with_kwargs flag sets are populated where applicable. + No FSDP-specific logic here. + """ + hooks_dict = {} + + if getattr(module, '_forward_pre_hooks', None): + with_kw = getattr(module, '_forward_pre_hooks_with_kwargs', set()) + fph = dict(module._forward_pre_hooks) + hooks_dict['forward_pre_hooks'] = fph + hooks_dict['forward_pre_hooks_restore'] = dict(fph) + fph_kw = {hid: True for hid in fph if hid in with_kw} + if fph_kw: + hooks_dict['forward_pre_hooks_with_kwargs'] = fph_kw + module._forward_pre_hooks.clear() + + if getattr(module, '_forward_hooks', None): + with_kw = getattr(module, '_forward_hooks_with_kwargs', set()) + fh = dict(module._forward_hooks) + hooks_dict['forward_hooks'] = fh + hooks_dict['forward_hooks_restore'] = dict(fh) + fh_kw = {hid: True for hid in fh if hid in with_kw} + if fh_kw: + hooks_dict['forward_hooks_with_kwargs'] = fh_kw + module._forward_hooks.clear() + + if getattr(module, '_backward_pre_hooks', None): + bph = dict(module._backward_pre_hooks) + hooks_dict['backward_pre_hooks'] = bph + hooks_dict['backward_pre_hooks_restore'] = dict(bph) + module._backward_pre_hooks.clear() + + if getattr(module, '_backward_hooks', None): + bh = dict(module._backward_hooks) + hooks_dict['backward_hooks'] = bh + hooks_dict['backward_hooks_restore'] = dict(bh) + module._backward_hooks.clear() + + return hooks_dict + + def _apply_fsdp_hook_transforms(hooks_dict): + """Phase 2 (FSDP-specific): reroute forward hooks that wrap backward handlers. + + megatron_fsdp.py tags such hooks with sentinel attributes before registering them: + - _CUDA_GRAPH_BACKWARD_HANDLER_ATTR: set on forward_pre hooks wrapping a post-backward + handler (e.g. functools.partial over _post_backward_release_module). Cannot be passed + to TE — TE's per-callable backward would fire the handler for ALL callables at once. + - _CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR: set on forward hooks wrapping a pre-backward + handler (e.g. create_custom_backward_hook pattern). + + Each tagged hook is withheld from TE; its inner handler is extracted to the + appropriate backward key for per-callable manual invocation. + *_restore keys are NOT modified. + """ + # 1. forward_pre_hooks: hooks tagged with _CUDA_GRAPH_BACKWARD_HANDLER_ATTR + fph = hooks_dict.get('forward_pre_hooks') + if fph: + to_remove = [] + new_bh = {} + for hook_id, hook_fn in fph.items(): + handler = getattr(hook_fn, _CUDA_GRAPH_BACKWARD_HANDLER_ATTR, None) + if handler is not None: + new_bh[hook_id] = handler + to_remove.append(hook_id) + for hook_id in to_remove: + del fph[hook_id] + hooks_dict.get('forward_pre_hooks_with_kwargs', {}).pop(hook_id, None) + if new_bh: + hooks_dict.setdefault('backward_hooks', {}).update(new_bh) + if not fph: + del hooks_dict['forward_pre_hooks'] + hooks_dict.pop('forward_pre_hooks_with_kwargs', None) + + # 2. forward_hooks: hooks tagged with _CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR + fh = hooks_dict.get('forward_hooks') + if fh: + to_remove = [] + new_bph = {} + for hook_id, hook_fn in fh.items(): + handler = getattr(hook_fn, _CUDA_GRAPH_BACKWARD_PRE_HANDLER_ATTR, None) + if handler is not None: + new_bph[hook_id] = handler + to_remove.append(hook_id) + for hook_id in to_remove: + del fh[hook_id] + hooks_dict.get('forward_hooks_with_kwargs', {}).pop(hook_id, None) + if new_bph: + hooks_dict.setdefault('backward_pre_hooks', {}).update(new_bph) + if not fh: + del hooks_dict['forward_hooks'] + hooks_dict.pop('forward_hooks_with_kwargs', None) + + extracted_hooks = [] # TE-facing: passed to make_graphed_callables as capture_time_hooks + restore_hooks = [] # restore-only: applied to modules after graph capture + for callable_module in self.flattened_callables: + if isinstance(callable_module, torch.nn.Module): + hooks_dict = _extract_module_hooks(callable_module) + _apply_fsdp_hook_transforms(hooks_dict) + te_hooks = {k: v for k, v in hooks_dict.items() if k in _TE_HOOK_KEYS} + restore = {k: v for k, v in hooks_dict.items() if k in _RESTORE_KEYS} + extracted_hooks.append(te_hooks if te_hooks else None) + restore_hooks.append(restore if restore else None) + else: + extracted_hooks.append(None) + restore_hooks.append(None) + def get_make_graphed_callables_kwargs(): kwargs = { 'allow_unused_input': True, @@ -2233,7 +2380,36 @@ def _get_fp8_enabled(): return kwargs kwargs = get_make_graphed_callables_kwargs() - return sample_args, kwargs + + # Add extracted hooks to kwargs for TE to invoke during warmup/capture + # Note: We DON'T pass replay_hooks - hooks will be restored to modules after capture + # and automatically triggered during replay + if extracted_hooks and any(h for h in extracted_hooks): + # TE's make_graphed_callables asserts that all capture_time_hooks return None + # (hooks must only have side effects; returning tensors would corrupt the static + # CUDA graph buffers). Wrap each hook_fn so the return value is always None. + def _wrap_none(fn): + def wrapper(*args, **kwargs): + fn(*args, **kwargs) + + return wrapper + + def _wrap_hooks_dict(hooks_dict): + if hooks_dict is None: + return None + wrapped = {} + for key, id_to_fn in hooks_dict.items(): + # forward_pre_hooks_with_kwargs and forward_hooks_with_kwargs are flag sets + # ({hook_id: True}), not hook_fn dicts — skip wrapping. + if key in ('forward_pre_hooks_with_kwargs', 'forward_hooks_with_kwargs'): + wrapped[key] = id_to_fn + else: + wrapped[key] = {hid: _wrap_none(fn) for hid, fn in id_to_fn.items()} + return wrapped + + kwargs['capture_time_hooks'] = [_wrap_hooks_dict(h) for h in extracted_hooks] + + return sample_args, kwargs, restore_hooks def _start_capturing(self): """ @@ -2300,7 +2476,7 @@ def create_cudagraphs(self): ) else: # Prepare CUDA Graph capturing input data and call `make_graphed_callables`. - sample_args, kwargs = self._get_cuda_graph_input_data() + sample_args, kwargs, restore_hooks = self._get_cuda_graph_input_data() if self.config.sequence_parallel: rng_context = get_cuda_rng_tracker().fork() else: @@ -2310,6 +2486,25 @@ def create_cudagraphs(self): tuple(self.flattened_callables), sample_args, **kwargs ) + # Restore original hooks to callables after CUDA Graph capture. + # restore_hooks contains only the hooks cleared before capture; _with_kwargs flag dicts + # survive .clear() and need no explicit restoration. + if restore_hooks and any(h for h in restore_hooks): + for callable_module, restore in zip(self.flattened_callables, restore_hooks): + if isinstance(callable_module, torch.nn.Module) and restore: + if 'forward_pre_hooks_restore' in restore: + for hook_id, hook_fn in restore['forward_pre_hooks_restore'].items(): + callable_module._forward_pre_hooks[hook_id] = hook_fn + if 'forward_hooks_restore' in restore: + for hook_id, hook_fn in restore['forward_hooks_restore'].items(): + callable_module._forward_hooks[hook_id] = hook_fn + if 'backward_pre_hooks_restore' in restore: + for hook_id, hook_fn in restore['backward_pre_hooks_restore'].items(): + callable_module._backward_pre_hooks[hook_id] = hook_fn + if 'backward_hooks_restore' in restore: + for hook_id, hook_fn in restore['backward_hooks_restore'].items(): + callable_module._backward_hooks[hook_id] = hook_fn + # Push the captured graphs to the corresponding TransformerBlock. num_layers_accumulated = 0 for layers in self.callables_per_chunk: diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 6539ee36105..c4d1f44aeee 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -246,9 +246,12 @@ def get_layer_static_inputs(self, seq_length, micro_batch_size): ) static_inputs = {} + params_dtype = ( + self.config.params_dtype if self.config.params_dtype is not None else torch.bfloat16 + ) static_inputs["hidden_states"] = torch.ones( (slen_per_cptp, micro_batch_size, self.config.hidden_size), - dtype=torch.bfloat16, + dtype=params_dtype, requires_grad=True, device=torch.cuda.current_device(), ) @@ -285,7 +288,8 @@ def _te_cuda_graph_capture(self, *args, **kwargs): CUDA Graph capture for this layer using TE interface. Normally it's just a forward pass if we're capturing the entire layer. """ - return self.forward(*args, **kwargs) + forward_func = getattr(self, '_original_forward', self.forward) + return forward_func(*args, **kwargs) def _te_cuda_graph_replay(self, *args, **kwargs): """ @@ -346,13 +350,22 @@ def __call__(self, *args, **kwargs): if self._should_call_local_cudagraph(*args, **kwargs): return self.cudagraph_manager(self, args, kwargs) elif self._should_call_te_cudagraph(*args, **kwargs): - if not self.cuda_graphs: - # Do CUDA Graphs capture. - cuda_graph_func = self._te_cuda_graph_capture - else: - # Do CUDA Graphs replay. - cuda_graph_func = self._te_cuda_graph_replay - return cuda_graph_func(*args, **kwargs) + # Temporarily replace forward with cuda graph function + self._original_forward = self.forward + try: + if not self.cuda_graphs: + # Do CUDA Graphs capture. + self.forward = self._te_cuda_graph_capture + else: + # Do CUDA Graphs replay. + self.forward = self._te_cuda_graph_replay + + return super().__call__(*args, **kwargs) + finally: + # Restore original forward and clean up temporary attribute + self.forward = self._original_forward + if hasattr(self, '_original_forward'): + delattr(self, '_original_forward') return super().__call__(*args, **kwargs) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 403f0cf4c5f..b7df8f9ef6e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1564,6 +1564,15 @@ def validate_args(args, defaults={}): args.cpu_offloading = True # CUDA Graphs + if args.cuda_graph_scope == "full" or ( + isinstance(args.cuda_graph_scope, list) and "full" in args.cuda_graph_scope + ): + if isinstance(args.cuda_graph_scope, list): + assert args.cuda_graph_scope == ["full"], "full scope cannot be used with other scopes." + args.cuda_graph_scope = [] + warn_rank_0( + 'full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer.' + ) if args.cuda_graph_impl != "none": if ( "transformer_engine" in (args.transformer_impl, args.cuda_graph_impl) @@ -1579,16 +1588,20 @@ def validate_args(args, defaults={}): "Setting NCCL_GRAPH_REGISTER=0 to avoid illegal memory access when using " "CUDA Graph with PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True." ) - if args.cuda_graph_scope == "full" or ( - isinstance(args.cuda_graph_scope, list) and "full" in args.cuda_graph_scope - ): - if isinstance(args.cuda_graph_scope, list): - assert args.cuda_graph_scope == ["full"], "full scope cannot be used with other scopes." - args.cuda_graph_scope = [] - warn_rank_0( - 'full scope is deprecated. Use empty cuda_graph_scope to capture the whole layer.' - ) - + if "full_iteration" not in args.cuda_graph_scope and args.use_megatron_fsdp: + assert args.fsdp_double_buffer, ( + "CUDA Graph requires --fsdp-double-buffer when using Megatron-FSDP. " + "Without double buffer, FSDP parameter buffers addresses are dynamic across " + "iterations, causing numerical errors during graph replay." + ) + assert args.fsdp_db_use_persist_buf_on_alloc_fail, ( + "CUDA Graph with Megatron-FSDP and MoE requires " + "--fsdp-db-use-persist-buf-on-alloc-fail. This is to prevent failed allocation " + "goes to a dynamic buffer, causing illegal memory access during graph replay. " + "You may disable this assertion if you are sure there is no allocation failure " + "in the CUDA graph scope." + ) + if args.multi_latent_attention: assert not args.group_query_attention, "Group query attention is mutually exclusive with multi latent attention." @@ -2730,6 +2743,12 @@ def _add_distributed_args(parser): "Double-buffering the communication memory improves memory management efficiency by " "reusing previously allocated buffers, rather than creating new buffers for each FSDP communication. " "This is required for user buffer registration and is enabled by default when using NCCL user buffers.") + group.add_argument('--fsdp-db-use-persist-buf-on-alloc-fail', action='store_true', + help="Whether to fall back to persistent buffer when a bucket does not fit FSDP double buffer " + "size. If true, FSDP will use the persistently allocated buffer for the bucket that does not " + "fit, it will enable NCCL user buffer with the cost of more memory usage. If false, FSDP will " + "use dynamic memory allocator, NCCL user buffer won't be enabled, which usually leads to low " + "performance.") group.add_argument('--suggested-communication-unit-size', type=int, default=None, help='Specifies the number of elements to communicate at once during FSDP (Fully Sharded Data Parallel) operations. ' 'This flag also affects FSDP all-gather prefetch behavior. Setting a larger value increases the communication buffer size, ' diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py index 3f6670397e2..6c115746211 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py @@ -17,8 +17,12 @@ from megatron.core.optimizer import OptimizerConfig from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import HAVE_TE from megatron.core.transformer import TransformerConfig -from megatron.core.utils import is_torch_min_version +from megatron.core.transformer.cuda_graphs import TECudaGraphHelper +from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.utils import get_attr_wrapped_model, is_te_min_version, is_torch_min_version +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from tests.unit_tests.distributed.megatron_fsdp.utils import ( make_gpt_mock_data_iterator, make_moe_args_model_and_optimizer, @@ -633,6 +637,12 @@ def _training_loop(seed=42, **kwargs): - expert_model_parallel_size (int): Expert model parallel size for MoE. Default: 1. - expert_tensor_parallel_size (int): Expert tensor parallel size for MoE. Default: 1. - num_distributed_optimizer_instances (int): Number of distributed optimizer instances. Default: 1. + - cuda_graph_impl (str): CUDA graph backend. "transformer_engine" enables TE CUDA graphs. + The model and warmup steps run on a non-default side stream so AccumulateGrad nodes + are never on the legacy default stream (avoids cudaErrorStreamCaptureImplicit). + - cuda_graph_warmup_steps (int): Warmup iterations before CUDA graph capture. Default: 0. + Consumed locally; not forwarded to make_moe_args_model_and_optimizer. + - cuda_graph_scope (list[CudaGraphScope]): Scopes to capture; forwarded to TransformerConfig. Returns: list: A list of length train_iters containing the per-step language-model loss values (the value appended from output[-1] each iteration). Loss objects are returned as produced @@ -661,6 +671,8 @@ def _training_loop(seed=42, **kwargs): EP = kwargs.get("expert_model_parallel_size", 1) ETP = kwargs.get("expert_tensor_parallel_size", 1) OUTER_DP = kwargs.get("num_distributed_optimizer_instances", 1) + CUDA_GRAPH_IMPL = kwargs.get("cuda_graph_impl", "none") + CUDA_GRAPH_WARMUP_STEPS = kwargs.pop("cuda_graph_warmup_steps", 0) # Initialize model parallel groups Utils.initialize_model_parallel( @@ -672,22 +684,33 @@ def _training_loop(seed=42, **kwargs): ) DP_GROUP = mpu.get_data_parallel_group() - # Set manual seed for reproducibility + # Set manual seed for reproducibility. When using TE CUDA graphs, + # re-initialize the RNG tracker with te_rng_tracker=True. set_manual_seed(seed) + if CUDA_GRAPH_IMPL == "transformer_engine": + model_parallel_cuda_manual_seed( + seed, te_rng_tracker=True, use_cudagraphable_rng=True, force_reset_rng=True + ) + + # When using TE CUDA graphs, switch to a non-default side stream BEFORE + # model initialization. AccumulateGrad nodes are associated with the CUDA + # stream current at creation time. If they are created on the legacy/default + # stream (stream 0), TE's backward capture (which runs on an internal non- + # default capture stream) triggers cudaErrorStreamCaptureImplicit when the + # autograd engine tries to synchronize the default stream with the capture + # stream. Initializing the model and running warmup steps on a non-default + # side stream ensures AccumulateGrad is never on stream 0. + te_side_stream = None + if CUDA_GRAPH_IMPL == "transformer_engine": + te_side_stream = torch.cuda.Stream() + te_side_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.set_stream(te_side_stream) # Create model and optimizer model_chunks, optim = make_moe_args_model_and_optimizer( ut_filename="test_mcore_fully_sharded_data_parallel.py", - micro_batch_size=MICRO_BATCH_SIZE, - global_batch_size=GLOBAL_BATCH_SIZE, - vocab_size=VOCAB_SIZE, padded_vocab_size=VOCAB_SIZE, - seq_length=MAX_SEQ_LEN, sequence_parallel=TP > 1, - tensor_model_parallel_size=TP, - pipeline_model_parallel_size=PP, - num_layers_per_virtual_pipeline_stage=VPP, - train_iters=NUM_TRAINING_STEPS, **kwargs, ) @@ -700,10 +723,25 @@ def _training_loop(seed=42, **kwargs): num_samples=GLOBAL_BATCH_SIZE * NUM_TRAINING_STEPS, ) + # Create CUDA graph helper (after model is built). + cuda_graph_helper = None + if CUDA_GRAPH_IMPL == "transformer_engine": + config = get_attr_wrapped_model(model_chunks[0], 'config') + cuda_graph_helper = TECudaGraphHelper( + model=model_chunks, + config=config, + seq_length=MAX_SEQ_LEN, + micro_batch_size=MICRO_BATCH_SIZE, + optimizers=[optim], + ) + outputs = [] # Training loop - for _ in range(NUM_TRAINING_STEPS): + for i in range(NUM_TRAINING_STEPS): + if cuda_graph_helper is not None and i == CUDA_GRAPH_WARMUP_STEPS: + cuda_graph_helper.create_cudagraphs() + optim.zero_grad() output = pretrain_forward_backward( model=model_chunks, @@ -717,6 +755,9 @@ def _training_loop(seed=42, **kwargs): # Collect loss outputs.append(output[-1]) + if cuda_graph_helper is not None and cuda_graph_helper.graphs_created(): + cuda_graph_helper.delete_cuda_graphs() + Utils.destroy_model_parallel() return outputs @@ -812,6 +853,69 @@ def test_compatible_with_nd_parallel(self, ref_cache, nd_topology, spec_configs) ), ) + @pytest.mark.flaky_in_dev + @pytest.mark.skipif( + not (HAVE_TE and is_te_min_version("2.10.0")), + reason="Partial CUDA graph support requires TransformerEngine >= 2.10.0", + ) + @pytest.mark.parametrize( + "parallel_config", + [ + pytest.param({}, id="default"), + pytest.param({"tensor_model_parallel_size": 2}, id="TP2"), + pytest.param( + {"expert_model_parallel_size": 2, "expert_tensor_parallel_size": 2}, id="EP2_ETP2" + ), + ], + ) + def test_cudagraph_alignment_with_fsdp(self, parallel_config): + """CUDA graph replay must produce numerically identical loss to eager FSDP execution. + + Parametrized over parallelism configurations. For each config, runs one eager + baseline then verifies all CUDA graph scopes produce bit-identical losses. + """ + SCOPES = [ + [CudaGraphScope.attn], + [CudaGraphScope.attn, CudaGraphScope.moe_router, CudaGraphScope.moe_preprocess], + [CudaGraphScope.moe_router], + ] + FSDP_COMMON = dict( + use_megatron_fsdp=True, + data_parallel_sharding_strategy="optim_grads_params", + init_model_with_meta_device=True, + ckpt_format="fsdp_dtensor", + gradient_accumulation_fusion=False, + fsdp_double_buffer=True, + fsdp_db_use_persist_buf_on_alloc_fail=True, + ) + + reference_outputs = TestMegatronFSDPE2E._training_loop(**FSDP_COMMON, **parallel_config) + + for scope in SCOPES: + outputs = TestMegatronFSDPE2E._training_loop( + **FSDP_COMMON, + **parallel_config, + cuda_graph_impl="transformer_engine", + cuda_graph_scope=scope, + cuda_graph_warmup_steps=3, + ) + if torch.distributed.get_rank() == 0: + for step, (output, ref_output) in enumerate(zip(outputs, reference_outputs)): + loss = output["lm loss"] + ref_loss = ref_output["lm loss"] + assert_close( + loss, + ref_loss, + atol=0, + rtol=0, + msg=( + f"CUDA graph loss mismatch at step {step} " + f"(parallel={parallel_config}, scope={[s.name for s in scope]}): " + f"cuda_graph={loss.item():.6f}, eager={ref_loss.item():.6f}" + f", Compare = {compare_losses(loss.item(), ref_loss.item())}" + ), + ) + def compare_losses(loss_a: float, loss_b: float, reference: str = "b"): """ diff --git a/tests/unit_tests/distributed/megatron_fsdp/utils.py b/tests/unit_tests/distributed/megatron_fsdp/utils.py index 18a2da63786..db7cf1b036b 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/utils.py +++ b/tests/unit_tests/distributed/megatron_fsdp/utils.py @@ -77,6 +77,13 @@ def make_moe_args_model_and_optimizer(ut_filename, **overrides): min_lr=3e-5, use_distributed_optimizer=True, finalize_model_grads_func=finalize_model_grads, + micro_batch_size=2, + global_batch_size=32, + vocab_size=100, + seq_length=128, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + train_iters=20, ) base_args.update(overrides)