# __init__.py (forked from NVIDIA/Megatron-LM)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import copy
import logging
import warnings
from collections import defaultdict
from dataclasses import astuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.optim import SGD as CPUSGD
from torch.optim import AdamW as CPUAdam
try:
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
from transformer_engine.pytorch.optimizers import FusedSGD as SGD
USING_PYTORCH_OPTIMIZER = False
except ImportError:
try:
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
USING_PYTORCH_OPTIMIZER = False
except ImportError:
        warnings.warn(
            'Transformer Engine and Apex are not installed. Falling back to Torch optimizers.'
        )
# Apex's FusedAdam is a drop-in replacement for torch's AdamW.
# pylint: disable-next=line-too-long.
# See https://github.com/NVIDIA/apex/blob/7b73b12361068a10b0f44844534613f252a5ea75/apex/optimizers/fused_adam.py#L16.
from torch.optim import SGD
from torch.optim import AdamW as Adam
USING_PYTORCH_OPTIMIZER = True
try:
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _pkg_version
_eo_ver = tuple(int(x) for x in _pkg_version('emerging-optimizers').split('.')[:2])
except (ImportError, PackageNotFoundError):
_eo_ver = (0, 0)
HAVE_EMERGING_OPTIMIZERS = _eo_ver >= (0, 2)
if HAVE_EMERGING_OPTIMIZERS:
from emerging_optimizers.scalar_optimizers import Lion
from megatron.core import parallel_state
from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer
from megatron.core.optimizer_param_scheduler import (
ParamGroupOverride,
combine_param_group_overrides,
param_group_override_to_tuple,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.fsdp_dtensor_checkpoint import get_global_unique_param_name
from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer
from ..transformer.module import MegatronModule
from ..utils import get_model_config, get_pg_rank, get_pg_size, is_te_min_version, log_single_rank
from .distrib_optimizer import DistributedOptimizer
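# NOTE: the import below re-binds HAVE_EMERGING_OPTIMIZERS from .emerging_optimizers,
# overriding the package-version check above; that check only gates the Lion import.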
from .emerging_optimizers import (
_EMERGING_OPTIMIZERS,
HAVE_EMERGING_OPTIMIZERS,
_create_emerging_optimizer,
)
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .layer_wise_optimizer import LayerWiseDistributedOptimizer
from .optimizer import (
ChainedOptimizer,
Float16OptimizerWithFloat16Params,
FP32Optimizer,
MegatronOptimizer,
param_group_identifier_keys,
)
# Subclass aliases kept for backward compatibility; all are OptimizerConfig.
from .optimizer_config import (
AdamOptimizerConfig,
OptimizerConfig,
ParamKey,
ParamPredicate,
ParamWithNamePredicate,
SGDOptimizerConfig,
)
logger = logging.getLogger(__name__)
def get_standard_config_overrides(config: OptimizerConfig) -> Dict[ParamKey, ParamGroupOverride]:
"""Get standard config overrides for the optimizer, handling decoupled LR and common wd skips.
Args:
config (OptimizerConfig): optimizer configuration object.
Returns:
Dict[ParamKey, ParamGroupOverride]: standard config overrides.
"""
    config_overrides: Dict[ParamKey, ParamGroupOverride] = {}
    # First, figure out how to handle weight-decay skipping. The two main approaches are:
    # 1. The classic Megatron approach: skip all length-1 and bias parameters.
    # 2. The Qwen3-Next approach: same as (1), but keep weight decay on QK layernorm
    #    parameters.
if config.apply_wd_to_qk_layernorm:
shape_1_not_qkln_param = ParamWithNamePredicate(
name="s1_not_qkln",
fn=lambda param, name: (len(param.shape) == 1 or name.endswith(".bias"))
and not ("q_layernorm." in name or "k_layernorm." in name),
)
param_wd_mult_key = ParamKey(with_name_predicate=shape_1_not_qkln_param)
else:
param_length_1_match = ParamPredicate(
name="param_len_1", fn=lambda param: len(param.shape) == 1
)
param_wd_mult_key = ParamKey(name="*.bias", predicate=param_length_1_match)
config_overrides[param_wd_mult_key] = ParamGroupOverride(wd_mult=0.0)
if config.decoupled_lr is not None:
decoupled_lr_config: ParamGroupOverride = {"max_lr": config.decoupled_lr}
decoupled_param_key = ParamKey(attr="is_embedding_or_output_parameter")
if config.decoupled_min_lr is not None:
decoupled_lr_config["min_lr"] = config.decoupled_min_lr
config_overrides[decoupled_param_key] = decoupled_lr_config
return config_overrides
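# Illustrative sketch (values hypothetical, not executed): with config.decoupled_lr = 1e-4
# and default settings, the returned overrides look roughly like
#     {ParamKey(name="*.bias", predicate=param_length_1_match): {"wd_mult": 0.0},
#      ParamKey(attr="is_embedding_or_output_parameter"): {"max_lr": 1e-4}}
# so biases/length-1 params skip weight decay and embedding/output params use the
# decoupled LR.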
def get_mup_config_overrides(
config: OptimizerConfig, mup_width_mult: float, optimizer_type: str = 'adam'
) -> Dict[ParamKey, ParamGroupOverride]:
"""Get MuP config overrides for per-layer LR and Adam epsilon scaling.
In MuP, optimizer learning rates are adjusted by parameter class to ensure
stable update scales across model widths and enable hyperparameter transfer.
MuP optimizer scaling rules (as implemented here):
- Adam/AdamW:
- hidden (matrix-like) lr = base_lr / width_mult
- hidden (matrix-like) eps = base_eps / width_mult
- vector-like params keep base lr and eps
- SGD:
- vector-like lr = base_lr * width_mult
- hidden (matrix-like) lr keeps base_lr in the current uniform-width setup
- no eps override is applied
- Non-Adam optimizers:
- hidden (matrix-like) lr = base_lr / width_mult
- no eps override is applied.
- for Muon optimizers, matrix-like params managed by Muon itself are
excluded from these Adam-style MuP overrides.
With decoupled_lr enabled, embedding/output params continue using decoupled LR
and MuP will not override those explicit decoupled values.
Args:
config (OptimizerConfig): optimizer configuration object.
mup_width_mult (float): Width multiplier (hidden_size / base_hidden_size).
optimizer_type (str): Optimizer type string from config.optimizer.
Returns:
Dict[ParamKey, ParamGroupOverride]: MuP optimizer overrides.
"""
optimizer_type_lower = optimizer_type.lower()
is_sgd_optimizer = optimizer_type_lower == 'sgd'
is_adam_optimizer = 'adam' in optimizer_type_lower
is_muon_optimizer = 'muon' in optimizer_type_lower
decoupled_lr_enabled = config.decoupled_lr is not None
if decoupled_lr_enabled:
message = (
"Both decoupled_lr and MuP LR scaling are enabled. decoupled_lr sets an "
"absolute LR for embedding+output params, and MuP LR scaling will not "
"override those parameters."
)
if is_adam_optimizer:
message += " MuP Adam epsilon scaling remains applied to hidden matrix-like parameters."
log_single_rank(logger, logging.WARNING, message)
if is_muon_optimizer:
muon_scale_mode = getattr(config, 'muon_scale_mode', 'spectral')
if muon_scale_mode == 'spectral':
log_single_rank(
logger,
logging.WARNING,
"Both MuP and muon_scale_mode=spectral are enabled. "
"Muon-managed matrix parameters will continue using spectral Muon scaling. "
"Set --muon-scale-mode unit_rms_norm to use unit_rms_norm scaling for "
"Muon-managed matrices with MuP.",
)
if mup_width_mult == 1.0:
# No scaling needed when width_mult is 1
return {}
hidden_lr_mult = 1.0 / mup_width_mult
base_lr = config.lr
base_min_lr = config.min_lr
# Hidden matrix-like layers get scaled LR/eps; vector-like params keep base values.
# Prefer the explicit parameter attribute set by LanguageModule. Fall back to
# a conservative name check for older or non-language modules.
def is_embedding_parameter(param: torch.nn.Parameter, param_name: str) -> bool:
if getattr(param, 'shared_embedding', False):
return True
if hasattr(param, 'is_embedding_parameter'):
return bool(param.is_embedding_parameter)
return 'embedding' in param_name.lower()
def is_vector_like_parameter(param: torch.nn.Parameter, param_name: str) -> bool:
if is_embedding_parameter(param, param_name):
return True
if param.dim() <= 1:
return True
return False
def is_muon_managed_matrix_parameter(param: torch.nn.Parameter, _: str) -> bool:
if not is_muon_optimizer:
return False
return param.dim() == 2 and not getattr(param, 'is_embedding_or_output_parameter', False)
def should_scale_lr_with_mup(param: torch.nn.Parameter, param_name: str) -> bool:
if decoupled_lr_enabled and getattr(param, 'is_embedding_or_output_parameter', False):
return False
if is_muon_managed_matrix_parameter(param, param_name):
return False
return not is_vector_like_parameter(param, param_name)
def should_scale_vector_like_lr_with_mup(param: torch.nn.Parameter, param_name: str) -> bool:
if decoupled_lr_enabled and getattr(param, 'is_embedding_or_output_parameter', False):
return False
return is_vector_like_parameter(param, param_name)
def should_scale_eps_with_mup(param: torch.nn.Parameter, param_name: str) -> bool:
if is_vector_like_parameter(param, param_name):
return False
if is_muon_managed_matrix_parameter(param, param_name):
return False
# MuP Appendix B.3: eps scales with fan_in when non-negligible.
# This implementation follows the common denominator form: sqrt(v) + eps.
return True
mup_overrides: Dict[ParamKey, ParamGroupOverride] = {}
if is_sgd_optimizer:
vector_like_lr_mult = mup_width_mult
vector_like_lr_override: ParamGroupOverride = {}
if base_lr is not None:
vector_like_lr_override["max_lr"] = base_lr * vector_like_lr_mult
if base_min_lr is not None:
vector_like_lr_override["min_lr"] = base_min_lr * vector_like_lr_mult
if vector_like_lr_override:
vector_like_predicate = ParamWithNamePredicate(
name="mup_sgd_vector_like_excluding_embedding_output",
fn=should_scale_vector_like_lr_with_mup,
)
mup_overrides[ParamKey(with_name_predicate=vector_like_predicate)] = (
vector_like_lr_override
)
return mup_overrides
lr_override: ParamGroupOverride = {}
if base_lr is not None:
lr_override["max_lr"] = base_lr * hidden_lr_mult
if base_min_lr is not None:
lr_override["min_lr"] = base_min_lr * hidden_lr_mult
eps_override: ParamGroupOverride = {}
if is_adam_optimizer and config.adam_eps is not None:
eps_override["eps"] = config.adam_eps * hidden_lr_mult
if decoupled_lr_enabled:
if lr_override:
hidden_predicate = ParamWithNamePredicate(
name="mup_hidden_only_excluding_embedding_output", fn=should_scale_lr_with_mup
)
mup_overrides[ParamKey(with_name_predicate=hidden_predicate)] = lr_override
if eps_override:
hidden_output_predicate = ParamWithNamePredicate(
name="mup_hidden_only_for_adam_eps", fn=should_scale_eps_with_mup
)
mup_overrides[ParamKey(with_name_predicate=hidden_output_predicate)] = eps_override
else:
combined_override: ParamGroupOverride = {}
combined_override.update(lr_override)
combined_override.update(eps_override)
if combined_override:
hidden_output_predicate = ParamWithNamePredicate(
name="mup_hidden_and_output", fn=should_scale_eps_with_mup
)
mup_overrides[ParamKey(with_name_predicate=hidden_output_predicate)] = combined_override
return mup_overrides
def _get_param_groups(
model_chunks: List[MegatronModule],
config: OptimizerConfig,
config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]],
) -> List[Dict]:
"""Create parameter groups for optimizer.
Creates parameter groups from provided optimizer config object.
    NOTE: a parameter may match more than one ParamKey. All matching ParamKey overrides
    are merged into a single ParamGroupOverride for that parameter, and the merged
    result serves as the parameter's grouping key: any parameters that get the same set
    of merged overrides are mapped into the same parameter group.
Args:
model_chunks (List[MegatronModule]): model chunks to create parameter
groups for.
config (OptimizerConfig): optimizer configuration object.
        config_overrides (Optional[Dict[ParamKey, ParamGroupOverride]]): optimizer overrides,
            specified on a per-layer basis. NOTE: if you want to skip applying weight decay
            on bias and length-1 parameters, and also do not want to do any other overrides,
            set this to an empty dictionary rather than the default value of None.
Returns:
List of parameter groups.
"""
# Map (pg_overrides, is_expert_parallel) to params.
params_map = {}
for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
continue
uses_default_config = False
# Get optimizer config overrides for this parameter.
param_overrides_list: list[ParamGroupOverride] = []
if config_overrides is not None:
for param_key, param_override in config_overrides.items():
if param_key.matches(param, name):
param_overrides_list.append(param_override)
if param_overrides_list:
param_override: ParamGroupOverride | None = combine_param_group_overrides(
param_overrides_list
)
else:
param_override = None
is_expert_parallel = not getattr(param, 'allreduce', True)
            # Create a config tuple that is hashable and has a consistent ordering of keys.
param_override_tuple: tuple[tuple[str, Any], ...] | None = (
param_group_override_to_tuple(param_override)
)
key = (param_override_tuple, is_expert_parallel)
if key not in params_map:
params_map[key] = []
params_map[key].append(param)
# Distributed checkpoint requires all ranks to have the same param groups,
# so we need to align the param groups across ranks, otherwise we may have
# runtime error when loading the checkpoint or numerical error when resuming training.
params_key = list(params_map.keys())
gathered_params_key = [None for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(gathered_params_key, params_key)
for keys in gathered_params_key:
for key in keys:
if key not in params_key:
params_key.append(key)
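    # E.g., a rank whose chunks have no expert-parallel params still creates the
    # corresponding (override, is_expert_parallel=True) group below, just with an
    # empty params list, keeping param groups identical across ranks.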
# Need to pick one of the param_override_tuples to use for the param group.
param_groups = []
# Sort keys, None first.
for key in sorted(params_key, key=lambda x: (x[0] is not None, x[0])):
param_override_tuple, is_expert_parallel = key
params = params_map[key] if key in params_map else []
if param_override_tuple is None:
param_override: ParamGroupOverride = {}
else:
param_override: ParamGroupOverride = {k: v for (k, v) in param_override_tuple}
        # True if param_override_tuple is None/empty, or if the override does not touch
        # the LR schedule.
        # NOTE: "default_config" is used for logging the learning rate in training.py,
        # e.g.:
        #     if param_group['default_config']:
        #         learning_rate = param_group['lr']
        uses_default_lr_schedule: bool = (not bool(param_override_tuple)) or not any(
            "lr" in k for k in param_override
        )
# TODO: Remove "backwards compatible" fields below eventually.
default_config: ParamGroupOverride = {
'wd_mult': 1.0,
'lr_mult': 1.0,
'is_decoupled_lr': False,
# The following two fields may be important to keep even when we remove the
# above "backwards compatible" fields.
"max_lr": config.lr, # user may override this in param_override
"min_lr": config.min_lr, # user may override this in param_override
}
assert (
"params" not in param_override
), "'params' should not be in param_override, this is a protected key"
param_group = {
'params': params,
'is_expert_parallel': is_expert_parallel,
'default_config': uses_default_lr_schedule,
**default_config,
**param_override, # keep **param_override last so that users can override other fields.
}
param_groups.append(param_group)
return param_groups
def _get_param_groups_and_buffers(
model_chunks: List[MegatronModule],
model_chunk_offset: int,
config: OptimizerConfig,
config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]],
filter_fn: Callable,
buffer_name: str,
) -> Tuple[List[Dict], Dict[int, List[_ParamAndGradBuffer]]]:
"""Returns parameter groups and buffer for optimizer.
Args:
model_chunks (List[MegatronModule]): model chunks to create parameter
groups for.
model_chunk_offset (int): offset of model_chunks in global model_chunks list.
config (OptimizerConfig): optimizer configuration object.
        config_overrides (Optional[Dict[ParamKey, ParamGroupOverride]]): optimizer/scheduler
            overrides, specified on the basis of ParamKey matches with each parameter.
filter_fn (callable): filtering function for param_groups.
buffer_name (str): name of buffer.
Returns:
List of parameter groups and dictionary of model chunk IDs to buffers.
"""
param_groups = _get_param_groups(model_chunks, config, config_overrides)
param_groups = list(filter(filter_fn, param_groups))
buffers = {}
for model_chunk_idx, model_chunk in enumerate(model_chunks):
if hasattr(model_chunk, buffer_name):
buffers[model_chunk_idx + model_chunk_offset] = getattr(model_chunk, buffer_name)
return param_groups, buffers
def _get_megatron_optimizer_based_on_param_groups(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
param_groups: List,
per_model_buffers: Optional[Dict[int, List[_ParamAndGradBuffer]]] = None,
model_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
data_parallel_group_idx: Optional[int] = None,
intra_dist_opt_group: Optional[torch.distributed.ProcessGroup] = None,
distributed_optimizer_instance_id: Optional[int] = 0,
pg_collection: Optional[ProcessGroupCollection] = None,
skip_megatron_wrapping: bool = False,
) -> Union[MegatronOptimizer, Tuple[Optional[torch.optim.Optimizer], Optional[Callable]]]:
"""Get Megatron optimizer based on parameter groups.
Args:
config (OptimizerConfig): optimizer configuration object.
model_chunks (list): list of model chunks.
param_groups (list): list of parameter groups.
per_model_buffers (dict, optional): buffers for distributed optimizer. Defaults to None.
data_parallel_group (torch.distributed.ProcessGroup, optional): data-parallel group for
distributed optimizer. Defaults to None.
data_parallel_group_gloo (torch.distributed.ProcessGroup, optional): gloo data-parallel
group for distributed optimizer. Defaults to None.
data_parallel_group_idx (int, optional): data-parallel group index for distributed
optimizer. Defaults to None.
        distributed_optimizer_instance_id (int, optional): distributed optimizer instance ID.
            Defaults to 0.
skip_megatron_wrapping (bool): if True, return a
``(optimizer, init_state_fn)`` tuple of the raw PyTorch optimizer
without any Megatron wrapping. Useful when the caller
(e.g. LayerWiseDistributedOptimizer) performs its own wrapping.
Returns:
Instance of MegatronOptimizer, or ``(optimizer, init_state_fn)`` when
*skip_megatron_wrapping=True*.
"""
# All param_groups passed here must belong to the same optimizer type (adam / sgd).
# Callers are responsible for splitting by optimizer type before calling this function.
if skip_megatron_wrapping and config.use_precision_aware_optimizer:
raise ValueError(
"skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer."
)
if skip_megatron_wrapping and config.optimizer_cpu_offload:
raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.")
# When freezing sub-models we may have no trainable parameters on a rank and
# hence an empty param_groups. However, we still need to create an optimizer
# for the purposes of grad stats reductions.
if param_groups:
if config.optimizer_cpu_offload:
            # String comparison of version numbers is unreliable (e.g. '2.10' < '2.3'
            # lexically), so compare the parsed (major, minor) tuple instead.
            torch_version = tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2])
            if torch_version < (2, 3):
                warnings.warn(
                    "CPU offload is only validated on PyTorch >= 2.3.0; "
                    "earlier versions are untested and may have convergence issues."
                )
assert (
config.decoupled_weight_decay
), "CPU offloading only supported with decoupled_weight_decay enabled (AdamW mode)."
gpu_optimizer_cls = Adam if config.optimizer == 'adam' else SGD
cpu_optimizer_cls = CPUAdam if config.optimizer == 'adam' else CPUSGD
            if config.use_torch_optimizer_for_cpu_offload:
                gpu_optimizer_cls = cpu_optimizer_cls
            # The optimizer classes are already set above; re-assigning them in the
            # branches below would clobber the use_torch_optimizer_for_cpu_offload
            # override, so only the optimizer defaults are chosen here.
            if config.optimizer == 'adam':
optimizer_defaults = dict(
lr=config.lr,
weight_decay=config.weight_decay,
betas=(config.adam_beta1, config.adam_beta2),
eps=config.adam_eps,
bias_correction=True,
fused=True, # this flag is used to improve the performance of the cpu optimizer
)
else:
optimizer_defaults = dict(
lr=config.lr, weight_decay=config.weight_decay, momentum=config.sgd_momentum
)
optimizer = HybridDeviceOptimizer(
param_groups,
offload_fraction=config.optimizer_offload_fraction,
cpu_optimizer_cls=cpu_optimizer_cls,
gpu_optimizer_cls=gpu_optimizer_cls,
overlap_cpu_optimizer_d2h_h2d=config.overlap_cpu_optimizer_d2h_h2d,
pin_cpu_grads=config.pin_cpu_grads,
pin_cpu_params=config.pin_cpu_params,
param_update_in_fp32=True,
**optimizer_defaults,
)
init_state_fn = None
elif config.optimizer == 'adam':
kwargs = {
"params": param_groups,
"lr": config.lr,
"weight_decay": config.weight_decay,
"betas": (config.adam_beta1, config.adam_beta2),
"eps": config.adam_eps,
"capturable": config.optimizer_cuda_graph,
}
# set Adam class and weight decay mode depending
# on source of optimizer (Torch or TE/Apex)
if USING_PYTORCH_OPTIMIZER:
adam_cls = torch.optim.AdamW if config.decoupled_weight_decay else torch.optim.Adam
else:
kwargs["adam_w_mode"] = config.decoupled_weight_decay
adam_cls = Adam
if config.use_precision_aware_optimizer:
kwargs.update(
{
"exp_avg_dtype": config.exp_avg_dtype,
"exp_avg_sq_dtype": config.exp_avg_sq_dtype,
}
)
# Master weight is managed by MCore when main_params_dtype is fp32. This is
# because we want to use fp8 primary weight with precision aware optimizer.
# Otherwise, master weight will be managed by TransformerEngine.
# Delayed scaling is an exception because casting as well as the computation
# of the scaling factor can be conducted in the adam kernel.
if config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:
kwargs.update(
{
"master_weights": True,
"use_decoupled_grad": True,
"master_weight_dtype": config.main_params_dtype,
}
)
if is_te_min_version("2.1.0.dev0"):
kwargs.update({"store_param_remainders": config.store_param_remainders})
optimizer = adam_cls(**kwargs)
def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
if config is None or not config.use_precision_aware_optimizer:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)
elif config.optimizer == 'lion':
if not HAVE_EMERGING_OPTIMIZERS:
raise ImportError(
"Lion optimizer requires emerging_optimizers >= 0.2. "
"Please install or upgrade it to use --optimizer lion."
)
optimizer = Lion( # pylint: disable=possibly-used-before-assignment
param_groups,
lr=config.lr,
betas=(config.lion_beta1, config.lion_beta2),
weight_decay=config.weight_decay,
)
def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
elif config.optimizer == 'sgd':
optimizer = SGD(
param_groups,
lr=config.lr,
weight_decay=config.weight_decay,
momentum=config.sgd_momentum,
)
init_state_fn = None
else:
            raise ValueError(f'{config.optimizer} optimizer is not supported.')
else:
optimizer = None
init_state_fn = None
if skip_megatron_wrapping:
return optimizer, init_state_fn
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
# from the MixedPrecisionOptimizer, which manages any optimizer where
# the model params and main params are distinct.
if config.fp16 or config.bf16 or config.use_distributed_optimizer:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
# if we are using fp16 and loss-scale is not present, use a
# dynamic scaler.
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler = None
# Constant loss scale.
if config.loss_scale:
grad_scaler = ConstantGradScaler(config.loss_scale)
# Dynamic loss scale.
else:
if config.fp16:
grad_scaler = DynamicGradScaler(
initial_scale=config.initial_loss_scale,
min_scale=config.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=config.loss_scale_window,
hysteresis=config.hysteresis,
)
optimizer_args = [optimizer, config, grad_scaler, init_state_fn]
if config.use_distributed_optimizer:
optimizer = DistributedOptimizer(
*optimizer_args,
model_chunks=model_chunks,
per_model_buffers=per_model_buffers,
data_parallel_group=data_parallel_group,
data_parallel_group_gloo=data_parallel_group_gloo,
data_parallel_group_idx=data_parallel_group_idx,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
)
            # Needed when num_distributed_optimizer_instances > 1: weight gradients are
            # all-reduced across optimizer instances, so each instance holds duplicated
            # weight gradients, and gradient stats must be reduced within each instance.
setattr(optimizer, 'grad_stats_parallel_group', intra_dist_opt_group)
else:
optimizer = Float16OptimizerWithFloat16Params(*optimizer_args)
setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
else:
# FP32 optimizer.
optimizer = FP32Optimizer(optimizer, config, init_state_fn)
setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
if pg_collection is None or not hasattr(pg_collection, 'tp'):
tp_group = parallel_state.get_tensor_model_parallel_group()
else:
tp_group = pg_collection.tp
# TODO(M4): plumb tp_group through optimizer constructors so this setattr disappears.
setattr(optimizer, 'tp_group', tp_group)
return optimizer
def check_config_overrides_consistency(
config: OptimizerConfig, config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]]
):
"""Check if the config overrides are consistent with the config."""
# TODO: Remove `optimizer` from this eventually (e.g., if we use Muon for some layers and
# Adam for other layers). This would need some more refactoring to work though (param_groups
# filtered by optimizer passed into _get_megatron_optimizer_based_on_param_groups).
if config_overrides is not None:
fields_to_check_for_consistency = [
'overlap_param_gather_with_optimizer_step',
'optimizer',
'optimizer_cpu_offload',
]
        all_config_overrides = list(config_overrides.values())
        for field_name in fields_to_check_for_consistency:
            base_field = getattr(config, field_name, None)
            for config_override in all_config_overrides:
                if field_name in config_override:
                    field = config_override[field_name]
                    if field != base_field:
                        raise ValueError(
                            f"Field {field_name} should not be overridden in a config override."
                        )
return True
def _get_megatron_emerging_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
config_overrides: Optional[Dict[ParamKey, Any]] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
) -> MegatronOptimizer:
"""Build an emerging optimizer (e.g. Muon) for the given model chunks.
Parameter separation (e.g., linear weights -> Muon, rest -> Adam) is expressed as a
config_override, the same mechanism used for weight-decay and learning-rate overrides.
Adam/SGD groups are delegated to _get_megatron_optimizer_based_on_param_groups so they
go through the exact same code path as the standard optimizer factory.
When ``config.use_layer_wise_distributed_optimizer`` is True, the underlying optimizers
are wrapped with :class:`LayerWiseDistributedOptimizer`.
"""
eopt_name = config.optimizer
use_layer_wise = config.use_layer_wise_distributed_optimizer
# Handle legacy "dist_*" optimizer names (e.g. "dist_muon" → "muon" + layer-wise).
if eopt_name.startswith('dist_'):
bare_name = eopt_name[len('dist_') :]
warnings.warn(
f"optimizer='{eopt_name}' is deprecated. "
f"Use optimizer='{bare_name}' with use_layer_wise_distributed_optimizer=True.",
DeprecationWarning,
stacklevel=3,
)
eopt_name = bare_name
use_layer_wise = True
if not HAVE_EMERGING_OPTIMIZERS:
raise ImportError(
f"emerging-optimizers package is required for optimizer='{eopt_name}'. "
"Install it with: pip install emerging-optimizers"
)
if eopt_name not in _EMERGING_OPTIMIZERS:
raise ValueError(f"Unsupported emerging optimizer: {eopt_name}")
if config.fp16:
raise ValueError('emerging optimizer with fp16 is not supported.')
if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
log_single_rank(logger, logging.INFO, f'Setting up emerging optimizer with config {config}')
# Tag parameters with optimizer-specific attributes (expert_tp, is_qkv).
for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
continue
if 'experts' in name and 'shared' not in name:
param.expert_tp = True
# TODO(deyuf): support MLA
if 'linear_qkv.weight' in name and len(param.shape) == 2:
param.is_qkv = True
    # Apply optimizer-specific default param overrides (e.g. muon: non-linear -> adam).
    # Guard against the documented default of None before mutating in place.
    if config_overrides is None:
        config_overrides = {}
    config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides)
# Build param groups and bucket by (optimizer_name, is_expert_parallel).
# Layer-wise distributed optimizer handles expert params internally so we skip that split.
all_param_groups = _get_param_groups(model_chunks, config, config_overrides)
grouped_param_groups = defaultdict(list)
for group in all_param_groups:
opt_name = group.get('optimizer', eopt_name)
is_expert = group['is_expert_parallel'] and not use_layer_wise
grouped_param_groups[(opt_name, is_expert)].append(group)
# Build an optimizer for each (optimizer_name, is_expert) bucket and combine.
results = []
for (opt_name, is_expert), groups in grouped_param_groups.items():
if not groups:
continue
model_parallel_group = pg_collection.tp_ep_pp if is_expert else pg_collection.mp
if opt_name in _EMERGING_OPTIMIZERS:
optimizer, init_state_fn = _create_emerging_optimizer(
config, groups, eopt_name, model_chunks, pg_collection
)
if use_layer_wise:
result = (optimizer, init_state_fn)
else:
if config.bf16:
optimizer = Float16OptimizerWithFloat16Params(
optimizer, config, None, init_state_fn
)
else:
optimizer = FP32Optimizer(optimizer, config, init_state_fn)
setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
if pg_collection is None or not hasattr(pg_collection, 'tp'):
tp_group = parallel_state.get_tensor_model_parallel_group()
else:
tp_group = pg_collection.tp
setattr(optimizer, 'tp_group', tp_group)
result = optimizer
else:
fallback_config = copy.copy(config)
fallback_config.optimizer = opt_name
fallback_config.use_distributed_optimizer = False
result = _get_megatron_optimizer_based_on_param_groups(
config=fallback_config,
model_chunks=model_chunks,
param_groups=groups,
model_parallel_group=model_parallel_group,
pg_collection=pg_collection,
skip_megatron_wrapping=use_layer_wise,
)
# TODO(deyuf): ChainedOptimizer currently asserts all sub-optimizers
# share the same config. Revisit this design now that emerging
# optimizers mix different optimizer types (e.g. Muon + Adam).
# For now, reset to the top-level config so the assertion holds.
if not use_layer_wise and hasattr(result, 'config'):
result.config = config
results.append(result)
if use_layer_wise:
base_optimizers, init_fns = (), ()
if results:
base_optimizers, init_fns = zip(*results)
log_single_rank(
logger, logging.INFO, f'Using LayerWiseDistributedOptimizer for {eopt_name}'
)
return LayerWiseDistributedOptimizer(
list(base_optimizers),
config,
pg_collection,
init_state_fn_list=list(init_fns),
model_chunks=model_chunks if config.overlap_param_gather else None,
overlap_param_gather=config.overlap_param_gather,
)
return ChainedOptimizer(results)
def get_megatron_optimizer(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]] = None,
use_gloo_process_groups: bool = True,
pg_collection: Optional[ProcessGroupCollection] = None,
dump_param_to_param_group_map: Optional[str] = None,
) -> MegatronOptimizer:
"""Retrieve the Megatron optimizer for model chunks.
Handles both standard optimizers (Adam, SGD) and emerging optimizers (e.g. Muon).
We use separate optimizers for expert parameters and non-expert parameters.
For emerging optimizers with ``config.use_layer_wise_distributed_optimizer=True``,
the optimizer is automatically wrapped with :class:`LayerWiseDistributedOptimizer`.
Args:
config (OptimizerConfig): optimizer configuration object.
model_chunks (List[MegatronModule]): model chunks to get optimizer for.
config_overrides (Optional[Dict[ParamKey, OptimizerConfig]]): optional dictionary of
optimizer configuration objects to override default optimizer behavior for different
subsets of parameters (identified by ParamKey).
use_gloo_process_groups (bool): if false, disable use of Gloo process groups
in underlying Megatron optimizers.
pg_collection: Optional unified process group for distributed training.
dump_param_to_param_group_map (Optional[str]): path to dump parameter to param group map.
Returns:
Instance of MegatronOptimizer.
"""
# None → apply standard defaults. To extend defaults with custom overrides,
# start from get_standard_config_overrides(config) and merge yours in.
if config_overrides is None:
config_overrides = get_standard_config_overrides(config)
check_config_overrides_consistency(config, config_overrides)
# TODO: the standard and emerging optimizer paths handle pg_collection differently;
# unify them so both use a single pg_collection-based flow.
if config.optimizer not in ('adam', 'sgd'):
return _get_megatron_emerging_optimizer(
config=config,
model_chunks=model_chunks,
config_overrides=config_overrides,
pg_collection=pg_collection,
)
log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')
# Separate out first model chunk if overlapping param AG with optimizer step.
if config.overlap_param_gather_with_optimizer_step:
all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]
overlap_param_gather_with_optimizer_step_flags = [True, False]
else:
all_dense_model_chunks = [model_chunks]
overlap_param_gather_with_optimizer_step_flags = [False]
# Setup process groups using helper method
process_groups_dict = ProcessGroupCollection.setup_process_groups_for_optimizer(
pg_collection, model_chunks, use_gloo_process_groups
)
dp_cp_group = process_groups_dict['dp_cp_group']
intra_dp_cp_group = process_groups_dict['intra_dp_cp_group']
intra_expt_dp_group = process_groups_dict['intra_expt_dp_group']
mp_group = process_groups_dict['mp_group']
expt_tp_pp_group = process_groups_dict['expt_tp_pp_group']
intra_dp_cp_group_gloo = process_groups_dict['intra_dp_cp_group_gloo']
intra_expt_dp_group_gloo = process_groups_dict['intra_expt_dp_group_gloo']
intra_dist_opt_group = process_groups_dict['intra_dist_opt_group']
model_parallel_rank = get_pg_rank(mp_group)
if get_pg_size(dp_cp_group) > get_pg_size(intra_dp_cp_group):
inter_dist_opt_group = process_groups_dict['inter_dist_opt_group']
distributed_optimizer_instance_id = get_pg_rank(inter_dist_opt_group)
else:
distributed_optimizer_instance_id = 0
optimizers = []
model_chunk_offset = 0
ddp_config = model_chunks[0].ddp_config # Use the first model chunk's DDP config
if ddp_config.use_megatron_fsdp:
for model_chunk, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
param_groups, buffers = _get_param_groups_and_buffers(
model_chunk,
model_chunk_offset=model_chunk_offset,
config=config,
config_overrides=config_overrides,
filter_fn=lambda g: True,
buffer_name='buffers',
)
optimizer_part = _get_megatron_optimizer_based_on_param_groups(
config=config,
model_chunks=model_chunk,
param_groups=param_groups,
per_model_buffers=buffers,
model_parallel_group=mp_group,
data_parallel_group=dp_cp_group,
data_parallel_group_gloo=intra_dp_cp_group_gloo,
data_parallel_group_idx=model_parallel_rank,
intra_dist_opt_group=intra_dist_opt_group,
distributed_optimizer_instance_id=distributed_optimizer_instance_id,
pg_collection=pg_collection,
)
if (
not USING_PYTORCH_OPTIMIZER
and config.use_precision_aware_optimizer
and getattr(optimizer_part.optimizer, "master_weights", None) is not None
):
                # NOTE(@cspades): FusedAdam is handed Megatron-FSDP's main weights as
                # non-quantized DTensor(s). Megatron-FSDP should NEVER use FusedAdam's
                # own master weights: that is a complete waste of memory, since the
                # optimizer step is already applied to the Megatron-FSDP main weights
                # and then extended to FusedAdam's main weights. Override this here.
setattr(optimizer_part.optimizer, "master_weights", False)
optimizers.append(optimizer_part)
model_chunk_offset += 1
if len(optimizers) == 1:
return optimizers[0]
return ChainedOptimizer(optimizers)
if dump_param_to_param_group_map is not None:
param_to_param_group = {}
param_group_id = 0
for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
):
param_groups, buffers = _get_param_groups_and_buffers(
dense_model_chunks,
model_chunk_offset=model_chunk_offset,
config=config,
config_overrides=config_overrides,
filter_fn=lambda g: not g['is_expert_parallel'],
buffer_name='buffers',