diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 5387ed31bf3..3339a912cd6 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -952,21 +952,32 @@ def get_megatron_optimizer( buffer_name='buffers', ) - optimizers.append( - _get_megatron_optimizer_based_on_param_groups( - config=config, - model_chunks=model_chunk, - param_groups=param_groups, - per_model_buffers=buffers, - model_parallel_group=mp_group, - data_parallel_group=dp_cp_group, - data_parallel_group_gloo=intra_dp_cp_group_gloo, - data_parallel_group_idx=model_parallel_rank, - intra_dist_opt_group=intra_dist_opt_group, - distributed_optimizer_instance_id=distributed_optimizer_instance_id, - pg_collection=pg_collection, - ) + optimizer_part = _get_megatron_optimizer_based_on_param_groups( + config=config, + model_chunks=model_chunk, + param_groups=param_groups, + per_model_buffers=buffers, + model_parallel_group=mp_group, + data_parallel_group=dp_cp_group, + data_parallel_group_gloo=intra_dp_cp_group_gloo, + data_parallel_group_idx=model_parallel_rank, + intra_dist_opt_group=intra_dist_opt_group, + distributed_optimizer_instance_id=distributed_optimizer_instance_id, + pg_collection=pg_collection, ) + if ( + not USING_PYTORCH_OPTIMIZER + and config.use_precision_aware_optimizer + and getattr(optimizer_part.optimizer, "master_weights", None) is not None + ): + # NOTE(@cspades): FusedAdam is given Megatron-FSDP's main weights as + # non-quantized DTensor(s). Megatron-FSDP should NEVER use FusedAdam's + # own main weights: they are a complete waste of memory, because the + # optimizer step is still applied to the Megatron-FSDP main weights and + # extended to FusedAdam's main weights. Override this here. 
+ setattr(optimizer_part.optimizer, "master_weights", False) + + optimizers.append(optimizer_part) model_chunk_offset += 1 if len(optimizers) == 1: diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index a18206ed591..320bd6ae7fc 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -2,7 +2,6 @@ """Megatron distributed optimizer.""" - import gc import itertools import logging @@ -97,7 +96,7 @@ def __len__(self): class DistributedOptimizer(MixedPrecisionOptimizer): """Optimizer that shards state across data-parallel ranks. - This class reduces memory usage by distributing optimizer states (like + This class reduces memory usage by distributing optimizer states (like momentum and variance buffers) across GPUs in the data-parallel group. Attributes: @@ -231,9 +230,9 @@ def _build_model_gbuf_range(cls, param_and_grad_buffer: _ParamAndGradBuffer, buc def _build_gbuf_range_map(cls, param_and_grad_buffer: _ParamAndGradBuffer): """Builds a map between parameters and their ranges in the grad buffer. - These mappings are partitioned according to data type. This method - iterates through all buckets of a grad buffer to construct param - ranges that this rank "owns" (the dp_rank'th shard of each bucket, + These mappings are partitioned according to data type. This method + iterates through all buckets of a grad buffer to construct param + ranges that this rank "owns" (the dp_rank'th shard of each bucket, where each shard is 1/dp_world_size of the bucket). Args: @@ -483,33 +482,33 @@ def __init__( ): """Initializes the distributed optimizer for FP16, BF16, and FP32. - The steps in this method create the core mapping between param and grad - buffers, parameters, and parameter shard ranges, that is needed for - converting between model param indexes and main parameter shard indexes. 
- This method also updates the optimizer parameter groups with the + The steps in this method create the core mapping between param and grad + buffers, parameters, and parameter shard ranges, that is needed for + converting between model param indexes and main parameter shard indexes. + This method also updates the optimizer parameter groups with the newly created shards. Args: optimizer (torch.optim.Optimizer): Base optimizer such as Adam or SGD. config (OptimizerConfig): Configuration object for the optimizer. - grad_scaler (MegatronGradScaler): Used for scaling gradients. Note that - this can be None for BF16 training if no loss scale is used. + grad_scaler (MegatronGradScaler): Used for scaling gradients. Note that + this can be None for BF16 training if no loss scale is used. For FP16, a grad scaler is always required. - init_state_fn (Callable, optional): Function to initialize state in + init_state_fn (Callable, optional): Function to initialize state in the optimizer. model_chunks (List[MegatronModule]): List of model chunks to optimize. - per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): The - implementation of the distributed optimizer is centered on using - a contiguous buffer for communicating grads & params between - the model state and the optimizer state. For a detailed + per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): The + implementation of the distributed optimizer is centered on using + a contiguous buffer for communicating grads & params between + the model state and the optimizer state. For a detailed description, see `docs/source/distrib_optimizer.md`. - data_parallel_group (ProcessGroup): Data-parallel group used to + data_parallel_group (ProcessGroup): Data-parallel group used to all-gather params after optimizer.step(). - data_parallel_group_gloo (ProcessGroup, optional): Gloo data-parallel + data_parallel_group_gloo (ProcessGroup, optional): Gloo data-parallel group used specifically for checkpoint loading and saving. 
- data_parallel_group_idx (int): Index in the data-parallel group + data_parallel_group_idx (int): Index in the data-parallel group used by distributed checkpointing logic. - distributed_optimizer_instance_id (int): Unique identifier for the + distributed_optimizer_instance_id (int): Unique identifier for the distributed optimizer instance. """ @@ -539,6 +538,7 @@ def __init__( self.is_stub_optimizer = False if self.ddp_config.use_megatron_fsdp: + # Megatron-FSDP will manage optimizer weights and gradients. return # Model grad buffer ranges. @@ -596,7 +596,7 @@ def __init__( param.main_param_sharded = True # Optimizer ranges. - (self.model_param_group_index_map, self.opt_group_ranges) = ( + self.model_param_group_index_map, self.opt_group_ranges = ( self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges) ) @@ -729,6 +729,8 @@ def load_state_dict(self, state_dict): list. """ if self.ddp_config.use_megatron_fsdp: + # When using Megatron-FSDP, directly load the optimizer state + # into the wrapped optimizer. if "param_to_group_meta" in state_dict: state_dict["param_groups"] = self._param2group_meta_to_param_groups( state_dict["param_to_group_meta"], self.optimizer.param_groups @@ -1255,6 +1257,7 @@ def sharded_state_dict( f"sharding_type {sharding_type} is not supported with Megatron FSDP." ) if sharding_type == "fsdp_dtensor": + # Megatron-FSDP custom sharded state dict construction. state_dict = self.sharded_param_state_fsdp_dtensor(is_loading) return state_dict @@ -2302,6 +2305,7 @@ def zero_grad(self, set_to_none: bool = True): """ if self.ddp_config.use_megatron_fsdp: for model_chunk in self.model_chunks: + # Zero gradients managed by Megatron-FSDP. model_chunk.zero_grad_buffer() return @@ -2366,6 +2370,7 @@ def _get_fp8_params_and_shard_fp32_from_fp8(self): shard_offsets_in_fp8 = [] if self.ddp_config.use_megatron_fsdp: + # Retrieve Megatron-FSDP compute weights. 
buffers = [] for m in self.model_chunks: for group in m.param_and_grad_buffer.parameter_groups: @@ -2419,6 +2424,9 @@ def _copy_model_grads_to_main_grads(self): return if self.ddp_config.use_megatron_fsdp: + # Megatron-FSDP manages unsharded gradient buffer allocation + # (with zero-copy if using NCCL UB and wgrad accum fusion) + # during the backward pass. return # Utility method for copying group grads. @@ -2462,6 +2470,8 @@ def _copy_main_params_to_model_params(self): return if self.ddp_config.use_megatron_fsdp: + # Update Megatron-FSDP's compute weights with optimized main weights. + # If using quantized parameters, this will also perform quantization. for model_chunk in self.model_chunks: model_chunk.param_and_grad_buffer.copy_main_weights_to_model_weights() return @@ -2509,6 +2519,10 @@ def _copy_main_params_to_param_buffer(self): param buffer is not mapped to model params for MXFP8 case. """ + if self.ddp_config.use_megatron_fsdp: + raise NotImplementedError( + "_copy_main_params_to_param_buffer not supported for Megatron-FSDP." + ) for shard_main_group, model_group in zip( self.shard_fp32_from_float16_groups, self.model_float16_groups ): @@ -2575,7 +2589,9 @@ def _copy_model_params_to_main_params(self, state_dict=None): return if self.ddp_config.use_megatron_fsdp: - return + raise NotImplementedError( + "Megatron-FSDP does not implement a model-to-main parameter update." + ) # When using precision-aware optimizer, main params are held by self.optimizer. It will also # do the work of copying data from main params to model params. @@ -2631,6 +2647,8 @@ def step_with_ready_grads(self) -> bool: timers('params-all-gather', log_level=1).start(barrier=self.config.barrier_with_L1_time) if self.ddp_config.use_megatron_fsdp: + # Optionally all-gather Megatron-FSDP sharded main weights + # early in preparation for the subsequent forward pass. 
for model_chunk in self.model_chunks: model_chunk.start_param_sync() else: @@ -2645,4 +2663,3 @@ def step_with_ready_grads(self) -> bool: timers('params-all-gather').stop() return update_successful - diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index 61979a50b6c..463280bfecc 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -101,7 +101,7 @@ class MegatronOptimizer(ABC): """ Base class for all Megatron optimizers. - Provides a consistent interface for gradient management, parameter + Provides a consistent interface for gradient management, parameter access, and state-dict handling across different optimization types. Args: @@ -138,11 +138,10 @@ def get_parameters(self) -> List[torch.nn.Parameter]: return params def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]: - """Collects gradients for norm calculation, filtering duplicates. - This method filters parameters based on whether the gradient is not None, - the parameter is not shared (to avoid double-counting gradients), and + This method filters parameters based on whether the gradient is not None, + the parameter is not shared (to avoid double-counting gradients), and the parameter is not a replica due to tensor model parallelism. Returns: @@ -151,15 +150,21 @@ def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]: params = self.get_parameters() grads_for_norm = [] for param in params: - if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8: + if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 or ( + # Megatron-FSDP always uses decoupled_grad with FusedAdam. + self.config.use_precision_aware_optimizer + and getattr(param, "__fsdp_param__", False) + ): grad = param.decoupled_grad if hasattr(param, "decoupled_grad") else None if ( getattr(param, "__fsdp_param__", False) and grad is not None and hasattr(grad, "_local_tensor") ): + # Megatron-FSDP gradients are DTensors. 
grad = grad._local_tensor elif getattr(param, "__fsdp_param__", False): + # Megatron-FSDP gradients are DTensors. grad = param.grad._local_tensor if param.grad is not None else None else: grad = param.grad @@ -228,7 +233,13 @@ def clip_grad_norm(self, clip_grad: float) -> float: params, clip_grad, grad_norm, - self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8, + # Decoupled Grad + use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 + or ( + # Megatron-FSDP always uses decoupled_grad with FusedAdam. + self.config.use_precision_aware_optimizer + and getattr(params[0], "__fsdp_param__", False) + ), ) return grad_norm @@ -238,7 +249,12 @@ def count_zeros(self) -> float: return count_zeros_fp32( params, grad_stats_parallel_group=self.get_grad_stats_parallel_group(), - use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8, + use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 + or ( + # Megatron-FSDP always uses decoupled_grad with FusedAdam. + self.config.use_precision_aware_optimizer + and getattr(params[0], "__fsdp_param__", False) + ), tp_group=getattr(self, 'tp_group', None), ) @@ -1314,7 +1330,12 @@ def count_zeros(self): return count_zeros_fp32( params, grad_stats_parallel_group=self.get_grad_stats_parallel_group(), - use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8, + use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 + or ( + # Megatron-FSDP always uses decoupled_grad with FusedAdam. + self.config.use_precision_aware_optimizer + and getattr(params[0], "__fsdp_param__", False) + ), ) else: num_zeros_in_grad = 0 @@ -1347,6 +1368,11 @@ def step(self): total_norm=grad_norm, use_decoupled_grad=( optimizer.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 + or ( + # Megatron-FSDP always uses decoupled_grad with FusedAdam. 
+ self.config.use_precision_aware_optimizer + and getattr(params[0], "__fsdp_param__", False) + ) ), ) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index e4755971edf..17346be3977 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -950,9 +950,25 @@ def validate_args(args, defaults={}): # Future updates will drop support for `use_custom_fsdp` to avoid confusion. args.use_custom_fsdp = True - if args.data_parallel_sharding_strategy in ["optim_grads_params", "optim_grads"]: + # Megatron-FSDP requires the DistributedOptimizer. + if not args.use_distributed_optimizer: warn_rank_0( - 'Please make sure your TransformerEngine support FSDP + gradient accumulation fusion', + 'Megatron-FSDP is only compatible with --use-distributed-optimizer. Using DistributedOptimizer...', + args.rank, + ) + args.use_distributed_optimizer = True + # Disable the optimizer-step MXFP8 buffer reuse, which is neither relevant to nor supported by Megatron-FSDP. + args.reuse_grad_buf_for_mxfp8_param_ag = False + # Optimizer compatibility check. + assert args.optimizer in ('sgd', 'adam'), \ + f"Megatron-FSDP does not support the {args.optimizer} optimizer yet." + + if ( + args.data_parallel_sharding_strategy in ["optim_grads_params", "optim_grads"] + and args.gradient_accumulation_fusion + ): + warn_rank_0( + 'Verify that fused gradient accumulation is supported by TransformerEngine for Megatron-FSDP.', args.rank, ) @@ -961,17 +977,14 @@ def validate_args(args, defaults={}): 'check_weight_hash_across_dp_replicas_interval is not supported with optim_grads_params' assert os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1", \ - 'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value large than one' + 'FSDP requires CUDA_DEVICE_MAX_CONNECTIONS > 1 or unset.' assert args.ckpt_format == "fsdp_dtensor", \ - "Megatron FSDP only supports fsdp_dtensor checkpoint format" + "Megatron-FSDP requires the `fsdp_dtensor` checkpointing format." 
if args.fsdp_manual_registration: - assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP" - assert args.nccl_ub, "FSDP manual registration is only supported with nccl-ub option" - - if args.use_megatron_fsdp: - args.reuse_grad_buf_for_mxfp8_param_ag = False + assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP." + assert args.nccl_ub, "FSDP manual registration is only supported with --nccl-ub argument." # Parameters dtype. args.params_dtype = torch.float diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py index 3f6670397e2..92f73d0c73f 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py @@ -690,6 +690,12 @@ def _training_loop(seed=42, **kwargs): train_iters=NUM_TRAINING_STEPS, **kwargs, ) + if kwargs.get("use_megatron_fsdp", False) and kwargs.get( + "use_precision_aware_optimizer", False + ): + assert ( + not optim.optimizer.master_weights + ), "Megatron-FSDP should not use FusedAdam master weights." 
# Prepare data iterator data_iterator = make_gpt_mock_data_iterator( @@ -759,6 +765,17 @@ def _training_loop(seed=42, **kwargs): ), id="optim_grads_params_double_buffer", ), + pytest.param( + dict( + data_parallel_sharding_strategy="optim_grads_params", + megatron_fsdp_main_params_dtype=torch.float32, + use_precision_aware_optimizer=True, + fp8_recipe="delayed", + fp8_param_gather=True, + bf16=True, + ), + id="optim_grads_params_fused_adam_e2e", + ), pytest.param( dict( data_parallel_sharding_strategy="optim_grads_params", fsdp_double_buffer=False @@ -776,8 +793,10 @@ def _training_loop(seed=42, **kwargs): ], ) def test_compatible_with_nd_parallel(self, ref_cache, nd_topology, spec_configs): - if spec_configs.get("fp8_recipe") == "mxfp8" and not HAVE_TE_MXFP8TENSOR: - pytest.skip("Requires PyTorch with TE MXFP8Tensor support") + if spec_configs.get("fp8_recipe") == "mxfp8" and ( + torch.cuda.get_device_capability()[0] < 10 or not HAVE_TE_MXFP8TENSOR + ): + pytest.skip("Requires PyTorch & CUDA device with TE MXFP8Tensor support") nd_topology_str = "_".join([f"{k}{v}" for k, v in nd_topology.items()]) if nd_topology_str not in ref_cache: