Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions megatron/core/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,21 +952,32 @@ def get_megatron_optimizer(
buffer_name='buffers',
)

optimizers.append(
_get_megatron_optimizer_based_on_param_groups(
config=config,
model_chunks=model_chunk,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mp_group,
data_parallel_group=dp_cp_group,
data_parallel_group_gloo=intra_dp_cp_group_gloo,
data_parallel_group_idx=model_parallel_rank,
intra_dist_opt_group=intra_dist_opt_group,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
pg_collection=pg_collection,
)
optimizer_part = _get_megatron_optimizer_based_on_param_groups(
config=config,
model_chunks=model_chunk,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mp_group,
data_parallel_group=dp_cp_group,
data_parallel_group_gloo=intra_dp_cp_group_gloo,
data_parallel_group_idx=model_parallel_rank,
intra_dist_opt_group=intra_dist_opt_group,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
pg_collection=pg_collection,
)
if (
not USING_PYTORCH_OPTIMIZER
and config.use_precision_aware_optimizer
and getattr(optimizer_part.optimizer, "master_weights", None) is not None
):
# NOTE(@cspades): FusedAdam is provided Megatron-FSDP's main weights as
# non-quantized DTensor(s). Megatron-FSDP should NEVER use FusedAdam's
# main weights, complete waste of memory as the optimizer step is still
# applied to the Megatron-FSDP main weight and extended to FusedAdam
# main weights. Override this here.
setattr(optimizer_part.optimizer, "master_weights", False)

optimizers.append(optimizer_part)
model_chunk_offset += 1

if len(optimizers) == 1:
Expand Down
63 changes: 40 additions & 23 deletions megatron/core/optimizer/distrib_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

"""Megatron distributed optimizer."""


import gc
import itertools
import logging
Expand Down Expand Up @@ -97,7 +96,7 @@ def __len__(self):
class DistributedOptimizer(MixedPrecisionOptimizer):
"""Optimizer that shards state across data-parallel ranks.

This class reduces memory usage by distributing optimizer states (like
This class reduces memory usage by distributing optimizer states (like
momentum and variance buffers) across GPUs in the data-parallel group.

Attributes:
Expand Down Expand Up @@ -231,9 +230,9 @@ def _build_model_gbuf_range(cls, param_and_grad_buffer: _ParamAndGradBuffer, buc
def _build_gbuf_range_map(cls, param_and_grad_buffer: _ParamAndGradBuffer):
"""Builds a map between parameters and their ranges in the grad buffer.

These mappings are partitioned according to data type. This method
iterates through all buckets of a grad buffer to construct param
ranges that this rank "owns" (the dp_rank'th shard of each bucket,
These mappings are partitioned according to data type. This method
iterates through all buckets of a grad buffer to construct param
ranges that this rank "owns" (the dp_rank'th shard of each bucket,
where each shard is 1/dp_world_size of the bucket).

Args:
Expand Down Expand Up @@ -483,33 +482,33 @@ def __init__(
):
"""Initializes the distributed optimizer for FP16, BF16, and FP32.

The steps in this method create the core mapping between param and grad
buffers, parameters, and parameter shard ranges, that is needed for
converting between model param indexes and main parameter shard indexes.
This method also updates the optimizer parameter groups with the
The steps in this method create the core mapping between param and grad
buffers, parameters, and parameter shard ranges, that is needed for
converting between model param indexes and main parameter shard indexes.
This method also updates the optimizer parameter groups with the
newly created shards.

Args:
optimizer (torch.optim.Optimizer): Base optimizer such as Adam or SGD.
config (OptimizerConfig): Configuration object for the optimizer.
grad_scaler (MegatronGradScaler): Used for scaling gradients. Note that
this can be None for BF16 training if no loss scale is used.
grad_scaler (MegatronGradScaler): Used for scaling gradients. Note that
this can be None for BF16 training if no loss scale is used.
For FP16, a grad scaler is always required.
init_state_fn (Callable, optional): Function to initialize state in
init_state_fn (Callable, optional): Function to initialize state in
the optimizer.
model_chunks (List[MegatronModule]): List of model chunks to optimize.
per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): The
implementation of the distributed optimizer is centered on using
a contiguous buffer for communicating grads & params between
the model state and the optimizer state. For a detailed
per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): The
implementation of the distributed optimizer is centered on using
a contiguous buffer for communicating grads & params between
the model state and the optimizer state. For a detailed
description, see `docs/source/distrib_optimizer.md`.
data_parallel_group (ProcessGroup): Data-parallel group used to
data_parallel_group (ProcessGroup): Data-parallel group used to
all-gather params after optimizer.step().
data_parallel_group_gloo (ProcessGroup, optional): Gloo data-parallel
data_parallel_group_gloo (ProcessGroup, optional): Gloo data-parallel
group used specifically for checkpoint loading and saving.
data_parallel_group_idx (int): Index in the data-parallel group
data_parallel_group_idx (int): Index in the data-parallel group
used by distributed checkpointing logic.
distributed_optimizer_instance_id (int): Unique identifier for the
distributed_optimizer_instance_id (int): Unique identifier for the
distributed optimizer instance.
"""

Expand Down Expand Up @@ -539,6 +538,7 @@ def __init__(

self.is_stub_optimizer = False
if self.ddp_config.use_megatron_fsdp:
# Megatron-FSDP will manage optimizer weights and gradients.
return

# Model grad buffer ranges.
Expand Down Expand Up @@ -596,7 +596,7 @@ def __init__(
param.main_param_sharded = True

# Optimizer ranges.
(self.model_param_group_index_map, self.opt_group_ranges) = (
self.model_param_group_index_map, self.opt_group_ranges = (
self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
)

Expand Down Expand Up @@ -729,6 +729,8 @@ def load_state_dict(self, state_dict):
list.
"""
if self.ddp_config.use_megatron_fsdp:
# When using Megatron-FSDP, directly load the optimizer state
# into the wrapped optimizer.
if "param_to_group_meta" in state_dict:
state_dict["param_groups"] = self._param2group_meta_to_param_groups(
state_dict["param_to_group_meta"], self.optimizer.param_groups
Expand Down Expand Up @@ -1255,6 +1257,7 @@ def sharded_state_dict(
f"sharding_type {sharding_type} is not supported with Megatron FSDP."
)
if sharding_type == "fsdp_dtensor":
# Megatron-FSDP custom sharded state dict construction.
state_dict = self.sharded_param_state_fsdp_dtensor(is_loading)
return state_dict

Expand Down Expand Up @@ -2302,6 +2305,7 @@ def zero_grad(self, set_to_none: bool = True):
"""
if self.ddp_config.use_megatron_fsdp:
for model_chunk in self.model_chunks:
# Zero gradients managed by Megatron-FSDP.
model_chunk.zero_grad_buffer()
return

Expand Down Expand Up @@ -2366,6 +2370,7 @@ def _get_fp8_params_and_shard_fp32_from_fp8(self):
shard_offsets_in_fp8 = []

if self.ddp_config.use_megatron_fsdp:
# Retrieve Megatron-FSDP compute weights.
buffers = []
for m in self.model_chunks:
for group in m.param_and_grad_buffer.parameter_groups:
Expand Down Expand Up @@ -2419,6 +2424,9 @@ def _copy_model_grads_to_main_grads(self):
return

if self.ddp_config.use_megatron_fsdp:
# Megatron-FSDP manages unsharded gradient buffer allocation
# (with zero-copy if using NCCL UB and wgrad accum fusion)
# during the backward pass.
return

# Utility method for copying group grads.
Expand Down Expand Up @@ -2462,6 +2470,8 @@ def _copy_main_params_to_model_params(self):
return

if self.ddp_config.use_megatron_fsdp:
# Update Megatron-FSDP's compute weights with optimized main weights.
# If using quantized parameters, this will also perform quantization.
for model_chunk in self.model_chunks:
model_chunk.param_and_grad_buffer.copy_main_weights_to_model_weights()
return
Expand Down Expand Up @@ -2509,6 +2519,10 @@ def _copy_main_params_to_param_buffer(self):
param buffer is not mapped to model params for MXFP8 case.

"""
if self.ddp_config.use_megatron_fsdp:
raise NotImplementedError(
"_copy_main_params_to_param_buffer not supported for Megatron-FSDP."
)
for shard_main_group, model_group in zip(
self.shard_fp32_from_float16_groups, self.model_float16_groups
):
Expand Down Expand Up @@ -2575,7 +2589,9 @@ def _copy_model_params_to_main_params(self, state_dict=None):
return

if self.ddp_config.use_megatron_fsdp:
return
raise NotImplementedError(
"Megatron-FSDP does not implement a model-to-main parameter update."
)

# When using precision-aware optimizer, main params are held by self.optimizer. It will also
# do the work of copying data from main params to model params.
Expand Down Expand Up @@ -2631,6 +2647,8 @@ def step_with_ready_grads(self) -> bool:
timers('params-all-gather', log_level=1).start(barrier=self.config.barrier_with_L1_time)

if self.ddp_config.use_megatron_fsdp:
# Optionally all-gather Megatron-FSDP sharded main weights
# early in preparation for the subsequent forward pass.
for model_chunk in self.model_chunks:
model_chunk.start_param_sync()
else:
Expand All @@ -2645,4 +2663,3 @@ def step_with_ready_grads(self) -> bool:
timers('params-all-gather').stop()

return update_successful

42 changes: 34 additions & 8 deletions megatron/core/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class MegatronOptimizer(ABC):
"""
Base class for all Megatron optimizers.

Provides a consistent interface for gradient management, parameter
Provides a consistent interface for gradient management, parameter
access, and state-dict handling across different optimization types.

Args:
Expand Down Expand Up @@ -138,11 +138,10 @@ def get_parameters(self) -> List[torch.nn.Parameter]:
return params

def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:

"""Collects gradients for norm calculation, filtering duplicates.

This method filters parameters based on whether the gradient is not None,
the parameter is not shared (to avoid double-counting gradients), and
This method filters parameters based on whether the gradient is not None,
the parameter is not shared (to avoid double-counting gradients), and
the parameter is not a replica due to tensor model parallelism.

Returns:
Expand All @@ -151,15 +150,21 @@ def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
params = self.get_parameters()
grads_for_norm = []
for param in params:
if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:
if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8 or (
# Megatron-FSDP always uses decoupled_grad with FusedAdam.
self.config.use_precision_aware_optimizer
and getattr(param, "__fsdp_param__", False)
):
grad = param.decoupled_grad if hasattr(param, "decoupled_grad") else None
if (
getattr(param, "__fsdp_param__", False)
and grad is not None
and hasattr(grad, "_local_tensor")
):
# Megatron-FSDP gradients are DTensors.
grad = grad._local_tensor
elif getattr(param, "__fsdp_param__", False):
# Megatron-FSDP gradients are DTensors.
grad = param.grad._local_tensor if param.grad is not None else None
else:
grad = param.grad
Expand Down Expand Up @@ -228,7 +233,13 @@ def clip_grad_norm(self, clip_grad: float) -> float:
params,
clip_grad,
grad_norm,
self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8,
# Decoupled Grad
use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8
or (
# Megatron-FSDP always uses decoupled_grad with FusedAdam.
self.config.use_precision_aware_optimizer
and getattr(params[0], "__fsdp_param__", False)
),
)
return grad_norm

Expand All @@ -238,7 +249,12 @@ def count_zeros(self) -> float:
return count_zeros_fp32(
params,
grad_stats_parallel_group=self.get_grad_stats_parallel_group(),
use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8,
use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8
or (
# Megatron-FSDP always uses decoupled_grad with FusedAdam.
self.config.use_precision_aware_optimizer
and getattr(params[0], "__fsdp_param__", False)
),
tp_group=getattr(self, 'tp_group', None),
)

Expand Down Expand Up @@ -1314,7 +1330,12 @@ def count_zeros(self):
return count_zeros_fp32(
params,
grad_stats_parallel_group=self.get_grad_stats_parallel_group(),
use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8,
use_decoupled_grad=self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8
or (
# Megatron-FSDP always uses decoupled_grad with FusedAdam.
self.config.use_precision_aware_optimizer
and getattr(params[0], "__fsdp_param__", False)
),
)
else:
num_zeros_in_grad = 0
Expand Down Expand Up @@ -1347,6 +1368,11 @@ def step(self):
total_norm=grad_norm,
use_decoupled_grad=(
optimizer.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8
or (
# Megatron-FSDP always uses decoupled_grad with FusedAdam.
self.config.use_precision_aware_optimizer
and getattr(params[0], "__fsdp_param__", False)
)
),
)

Expand Down
31 changes: 22 additions & 9 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,9 +950,25 @@ def validate_args(args, defaults={}):
# Future updates will drop support for `use_custom_fsdp` to avoid confusion.
args.use_custom_fsdp = True

if args.data_parallel_sharding_strategy in ["optim_grads_params", "optim_grads"]:
# Megatron-FSDP requires the DistributedOptimizer.
if not args.use_distributed_optimizer:
warn_rank_0(
'Please make sure your TransformerEngine support FSDP + gradient accumulation fusion',
'Megatron-FSDP is only compatible with --use-distributed-optimizer. Using DistributedOptimizer...',
args.rank,
)
args.use_distributed_optimizer = True
# Optimizer step MXFP8 buffer operation that is not relevant or supported for Megatron-FSDP.
args.reuse_grad_buf_for_mxfp8_param_ag = False
# Optimizer compatibility check.
assert args.optimizer in ('sgd', 'adam'), \
f"Megatron-FSDP does not support the {args.optimizer} optimizer yet."

if (
args.data_parallel_sharding_strategy in ["optim_grads_params", "optim_grads"]
and args.gradient_accumulation_fusion
):
warn_rank_0(
'Verify that fused gradient accumulation is supported by TransformerEngine for Megatron-FSDP.',
args.rank,
)

Expand All @@ -961,17 +977,14 @@ def validate_args(args, defaults={}):
'check_weight_hash_across_dp_replicas_interval is not supported with optim_grads_params'

assert os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1", \
'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value large than one'
'FSDP requires CUDA_DEVICE_MAX_CONNECTIONS > 1 or unset.'

assert args.ckpt_format == "fsdp_dtensor", \
"Megatron FSDP only supports fsdp_dtensor checkpoint format"
"Megatron-FSDP requires the `fsdp_dtensor` checkpointing format."

if args.fsdp_manual_registration:
assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP"
assert args.nccl_ub, "FSDP manual registration is only supported with nccl-ub option"

if args.use_megatron_fsdp:
args.reuse_grad_buf_for_mxfp8_param_ag = False
assert args.use_megatron_fsdp, "FSDP manual registration is only supported with Megatron FSDP."
assert args.nccl_ub, "FSDP manual registration is only supported with --nccl-ub argument."

# Parameters dtype.
args.params_dtype = torch.float
Expand Down
Loading
Loading