-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Expand file tree
/
Copy pathmodule.py
More file actions
531 lines (441 loc) · 21.7 KB
/
module.py
File metadata and controls
531 lines (441 loc) · 21.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module."""
from functools import partial
from typing import Optional, Tuple
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
ensure_metadata_has_dp_cp_group,
make_sharded_tensors_for_checkpoint,
sharded_state_dict_default,
)
# Legacy tensor-type tuples used by the fp16/fp32 conversion helpers below to
# detect a tensor's dtype via isinstance checks (CPU and CUDA variants of each).
# NOTE(review): the `torch.*Tensor` type aliases are deprecated in recent
# PyTorch releases — presumably kept for backward compatibility; confirm.
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
def param_is_not_shared(param):
    """Return True unless ``param`` carries a truthy ``shared`` attribute.

    Parameters marked as shared (e.g. tied embeddings) set ``param.shared``;
    anything without the attribute is treated as not shared.
    """
    return not getattr(param, 'shared', False)
class MegatronModule(torch.nn.Module):
    """Base Megatron module inherited by all Models.

    Megatron specific extensions of torch Module with support
    for pipelining.

    Args:
        config (TransformerConfig): Transformer config
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

    def state_dict_for_save_checkpoint(self, prefix: str = '', keep_vars: bool = False):
        """Override state dict for saving checkpoints. Use this function to override the
        state dict for saving checkpoints.

        Args:
            prefix (str, optional): prefix prepended to all state dict keys. Defaults to ''.
            keep_vars (bool, optional): if True, tensors are not detached from autograd.
                Defaults to False.

        Returns:
            dict: the module's state dict.
        """
        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

    def sharded_state_dict(
        self,
        prefix: str = '',
        sharded_offsets: Tuple[Tuple[int, int, int], ...] = (),
        metadata: Optional[dict] = None,
    ) -> ShardedStateDict:
        """Default implementation for sharded state dict for distributed checkpointing.

        General definition of sharded_state_dict simply calls `sharded_state_dict_default`
        (which call sharded_state_dict method if possible or a default implementation otherwise)
        recursively on all submodules.

        Args:
            prefix (str): prefix for the state dict keys
            sharded_offsets (Tuple[Tuple[int, int, int]], optional): sharding already
                applied (e.g. PP related) by sup-modules. Passed along to ShardedTensor
            metadata (dict, optional): metadata passed recursively to sharded_state_dict methods

        Returns:
            dict: dictionary of state dict keys mapped to ShardedTensors
        """
        sharded_state_dict = {}
        # Save this module's direct parameters (keep_vars=True preserves the
        # Parameter objects so TP attributes on them remain visible).
        self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
        if not hasattr(self, 'tp_group'):
            # some model interface hasn't updated for m4, fallback needed
            tp_group = parallel_state.get_tensor_model_parallel_group()
        else:
            tp_group = self.tp_group
        # Guard for cases metadata is not provided
        metadata = ensure_metadata_has_dp_cp_group(metadata)
        sharded_state_dict = make_sharded_tensors_for_checkpoint(
            sharded_state_dict,
            prefix,
            sharded_offsets=sharded_offsets,
            tp_group=tp_group,
            dp_cp_group=metadata['dp_cp_group'],
        )
        # Recurse into submodules
        for name, module in self.named_children():
            sharded_state_dict.update(
                sharded_state_dict_default(
                    module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=tp_group
                )
            )
        return sharded_state_dict

    def set_is_first_microbatch(self):
        """Sets the is_first_microbatch flag if it exists and config.fp8==True.
        When this flag is set, TE modules will update their fp8 parameter cache.
        If kitchen is being used, kitchen controls quantization level.
        """
        if (
            self.config.fp8 is not None
            or self.config.fp4 is not None
            or getattr(self.config, 'use_kitchen', False)
        ):
            if not hasattr(self, "modules_with_is_first_microbatch"):
                # Cache the relevant submodules once to avoid a full module
                # traversal on every microbatch.
                self.modules_with_is_first_microbatch = []
                for m in self.modules():
                    if hasattr(m, "is_first_microbatch"):
                        self.modules_with_is_first_microbatch.append(m)
            for m in self.modules_with_is_first_microbatch:
                m.is_first_microbatch = True

    def set_symmetric_ar(self, set_to: Optional[str] = None) -> None:
        """
        Set symmetric all-reduce functionality across all eligible modules.

        This method traverses the model's module hierarchy to find all modules
        with the 'symmetric_ar_type' attribute, caches them, and then sets their
        'symmetric_ar_type' attribute to the specified value to enable or
        disable symmetric all-reduce operations.

        Args:
            set_to (str, optional): Value to set the 'symmetric_ar_type' to.
                Allowed choices ['two_shot', "one_shot", "multimem_all_reduce", None]
        """
        assert set_to in ['two_shot', "one_shot", "multimem_all_reduce", None]

        # Recursive function to find all modules with our target attributes
        def create_ar_cache(module):
            # Check if this module has any of our target attributes
            if hasattr(module, "symmetric_ar_type"):
                self._symmetric_ar_cache.append(module)
            # Check all children modules recursively
            for child in module._modules.values():
                if child is not None:
                    create_ar_cache(child)

        if not hasattr(self, "_symmetric_ar_cache"):
            self._symmetric_ar_cache = []
            create_ar_cache(self)

        # Fix: write the value to each module's `symmetric_ar_type` — the attribute
        # the cache was built from and that consumers read. Previously the value was
        # stored under `module._symmetric_ar_cache`, which nothing reads, so the
        # requested all-reduce mode was never actually applied.
        for module in self._symmetric_ar_cache:
            module.symmetric_ar_type = set_to
class GraphableMegatronModule(MegatronModule):
    """Megatron module that can be used to capture and replay CUDA graphs.

    Now only TransformerLayer and MambaLayer are graphable.

    Args:
        config (TransformerConfig): Transformer config
    """

    def __init__(self, config: TransformerConfig, vp_stage: Optional[int] = None):
        # NOTE(review): `vp_stage` is accepted but not stored or used here —
        # presumably consumed by subclasses; confirm before relying on it.
        super().__init__(config)
        assert isinstance(config, TransformerConfig), "config must be a TransformerConfig"
        # Enable cuda graphs.
        if (
            config.cuda_graph_impl == "local"
            and CudaGraphScope.full_iteration not in config.cuda_graph_scope
        ):
            if hasattr(self, "create_mcore_cudagraph_manager"):
                # A subclass provides its own manager factory; prefer it.
                self.create_mcore_cudagraph_manager(config)
            else:
                # Imported lazily to avoid the dependency when graphs are disabled.
                from megatron.core.transformer.cuda_graphs import CudaGraphManager

                self.cudagraph_manager = CudaGraphManager(config)
        elif config.cuda_graph_impl == "transformer_engine":
            # List to store CUDA graphs. A list of `N` CUDA graphs for this layer where N is
            # the number of microbatches. Multiple CUDA graphs per layer is required to support
            # pipelining which requires running FWD graph of multiple microbatches before BWD
            # graph. To enable CUDA graph, this list should be populated in the model training
            # script with the graphs returned by make_graphed_callables API before the first
            # training step.
            self.cuda_graphs = []
            # List to store forward pre-hooks. Forward pre-hooks are not captured into CUDA
            # graphs. Those hooks and args are collected in this list and should be manually
            # triggered before CUDA Graph running. This is required to ensure the correct param
            # all-gather overlap with forward compute.
            self.cuda_graph_manual_hooks = []
            # _CudaGraphBackwardDWWrapper object used to manage the wgrad backward computation.
            # The `backward_dw` func api is the same as `TransformerLayerNode.backward_dw` and
            # calls wgrad computation in attention module (contains attn and shared expert)
            # according to CUDA graph scope.
            # NOTE(review): `init_backward_dw_wrapper` below assigns
            # `self.backward_dw_wrapper` (different name) — confirm whether these two
            # attributes are intentionally distinct.
            self.cuda_graph_backward_dw_wrapper = None

    def init_backward_dw_wrapper(self):
        """Initialize the backward_dw_wrapper."""
        # Imported lazily to avoid a circular import at module load time.
        from megatron.core.models.gpt.fine_grained_callables import _BackwardDWWrapper

        config = getattr(self, 'config', None)
        assert config is not None, (
            "TransformerLayer must be initialized before calling " "`init_backward_dw_wrapper`."
        )
        self.backward_dw_wrapper = _BackwardDWWrapper(self)

    def set_te_cuda_graph_backward_dw_wrapper(self):
        """Replace the backward_dw callable with dw cuda graph."""
        assert (
            self.backward_dw_wrapper is not None
        ), "`backward_dw_wrapper` must be set when cuda graphs are enabled for ep overlap."
        # Bind the current microbatch index so the correct graph is replayed later.
        self.backward_dw_wrapper.set_graphed_backward_dw_callable(
            partial(self._te_cuda_graph_backward_dw_graph, self.current_microbatch)
        )

    def _te_cuda_graph_backward_dw_graph(self, microbatch_idx):
        """
        CUDA Graph backward weight gradient computation for current layer.
        """
        # Graphs are indexed round-robin by microbatch.
        cg_index = microbatch_idx % len(self.cuda_graphs)
        # Graph objects without a `backward_dw` attribute have nothing to run; skip.
        if not hasattr(self.cuda_graphs[cg_index], 'backward_dw'):
            return
        self.cuda_graphs[cg_index].backward_dw()

    def get_layer_static_inputs(self, seq_length, micro_batch_size):
        """
        Get the static inputs for the layer.
        We assume that the module has one hidden_states input, whose shape is inferred
        from the seq_length, micro_batch_size, and parallel config.
        Override this method if the module has other inputs.

        Returns:
            Dict[str, torch.Tensor]: A dictionary containing the static inputs for the layer.
        """
        # Calculate data shape related values.
        context_parallel_size = self.config.context_parallel_size
        slen_per_cp = seq_length // context_parallel_size
        sequence_parallel = self.config.sequence_parallel
        tensor_model_parallel_size = self.config.tensor_model_parallel_size
        # With sequence parallelism the sequence dim is additionally split across TP ranks.
        slen_per_cptp = (
            slen_per_cp // tensor_model_parallel_size if sequence_parallel else slen_per_cp
        )
        static_inputs = {}
        params_dtype = (
            self.config.params_dtype if self.config.params_dtype is not None else torch.bfloat16
        )
        # Placeholder activation buffer captured by the graph; values are irrelevant,
        # only shape/dtype/device matter.
        static_inputs["hidden_states"] = torch.ones(
            (slen_per_cptp, micro_batch_size, self.config.hidden_size),
            dtype=params_dtype,
            requires_grad=True,
            device=torch.cuda.current_device(),
        )
        return static_inputs

    def setup_manual_hooks(self, make_hook_func):
        """
        Set CUDA Graph manual hooks for the submodules that contain direct parameters and are
        covered by cudagraphs.
        """
        self.cuda_graph_manual_hooks = []
        # Select the modules who contain direct parameters and are covered by cudagraphs.
        # Add these modules to the `cuda_graph_manual_hooks` because their hooks will not
        # be automatically triggered when they go through the CUDA Graph path.
        param_modules = {}
        for submodule in self._get_submodules_under_cudagraphs():
            for module in submodule.modules():
                if next(module.parameters(recurse=False), None) is not None:
                    # Module contains direct parameters.
                    # Keyed by id() to de-duplicate modules reachable via several submodules.
                    param_modules[id(module)] = module
        for module in param_modules.values():
            self.cuda_graph_manual_hooks.append((make_hook_func(), (module,)))

    def _get_submodules_under_cudagraphs(self):
        """
        Get the submodules that are covered by cudagraphs. Return a list that only contains the
        module itself if the whole layer is covered by cudagraphs.
        """
        return [self]

    def _te_cuda_graph_capture(self, *args, **kwargs):
        """
        CUDA Graph capture for this layer using TE interface.
        Normally it's just a forward pass if we're capturing the entire layer.
        """
        # `__call__` may have swapped `self.forward`; use the saved original if present.
        forward_func = getattr(self, '_original_forward', self.forward)
        return forward_func(*args, **kwargs)

    def _te_cuda_graph_replay(self, *args, **kwargs):
        """
        CUDA graph replay for this layer and microbatch `self.current_microbatch` using TE
        interface. TransformerEngine versions>=1.10 allow keyword arguments with CUDA graph.
        However, CUDA graph accepts only Tensor inputs.
        Hence, check if the arguments are all tensors.
        """
        for arg in args:
            assert isinstance(arg, torch.Tensor), "CUDA graph accepts only Tensor inputs."
        for _, v in kwargs.items():
            assert v is None or isinstance(
                v, torch.Tensor
            ), "CUDA graph accepts only Tensor inputs."
        # Round-robin graph selection by microbatch index.
        cg_index = getattr(self, 'current_microbatch', 0) % len(self.cuda_graphs)
        cudagraph_args, cudagraph_kwargs = self._get_te_cuda_graph_replay_args(*args, **kwargs)
        # Manually fire the hooks that the graphed path bypasses (see __init__ notes).
        for hook, hook_args in self.cuda_graph_manual_hooks:
            hook(*hook_args)
        return self.cuda_graphs[cg_index](*cudagraph_args, **cudagraph_kwargs)

    def _get_te_cuda_graph_replay_args(self, *args, **kwargs):
        """Helper function to get tensor arguments for TE CUDA graph."""
        if len(args) == 0:
            # hidden_states may arrive as a keyword; normalize it to positional.
            assert 'hidden_states' in kwargs, "hidden_states is required."
            hidden_states = kwargs.pop('hidden_states')
            cudagraph_args = (hidden_states,)
        else:
            assert (
                'hidden_states' not in kwargs
            ), "hidden_states should only be passed as either a positional or keyword argument."
            cudagraph_args = tuple(args)
        # Copy so the caller's kwargs dict is not mutated by the flag below.
        cudagraph_kwargs = kwargs.copy()
        cudagraph_kwargs['is_first_microbatch'] = getattr(self, 'current_microbatch', 0) == 0
        return cudagraph_args, cudagraph_kwargs

    def _should_call_local_cudagraph(self, *args, **kwargs):
        """
        Check if we should call the local cudagraph path.
        """
        # The manager attribute only exists when `cuda_graph_impl == "local"` (see __init__).
        return hasattr(self, 'cudagraph_manager')

    def _should_call_te_cudagraph(self, *args, **kwargs):
        """
        Check if we should call the TE cudagraph path.
        """
        from megatron.core.transformer.cuda_graphs import is_graph_capturing

        # TE path applies only in training, and only while capturing or once graphs exist.
        return (
            self.config.cuda_graph_impl == "transformer_engine"
            and self.training
            and (is_graph_capturing() or self.cuda_graphs)
        )

    def __call__(self, *args, **kwargs):
        # Dispatch to the local cudagraph manager, the TE graph path, or plain forward.
        if self._should_call_local_cudagraph(*args, **kwargs):
            return self.cudagraph_manager(self, args, kwargs)
        elif self._should_call_te_cudagraph(*args, **kwargs):
            # Temporarily replace forward with cuda graph function
            self._original_forward = self.forward
            try:
                if not self.cuda_graphs:
                    # Do CUDA Graphs capture.
                    self.forward = self._te_cuda_graph_capture
                else:
                    # Do CUDA Graphs replay.
                    self.forward = self._te_cuda_graph_replay
                return super().__call__(*args, **kwargs)
            finally:
                # Restore original forward and clean up temporary attribute
                self.forward = self._original_forward
                if hasattr(self, '_original_forward'):
                    delattr(self, '_original_forward')
        return super().__call__(*args, **kwargs)
def conversion_helper(val, conversion):
    """Recursively apply ``conversion`` to every leaf of a nested structure.

    Args:
        val: A single value, or a nested structure of tuples/lists of values.
        conversion (callable): Function applied to each non-container leaf.

    Returns:
        The converted value with the same nesting: a converted leaf for a
        single value, or a tuple/list (matching the input container type)
        whose elements are recursively converted.
    """
    if isinstance(val, (tuple, list)):
        converted = (conversion_helper(item, conversion) for item in val)
        return tuple(converted) if isinstance(val, tuple) else list(converted)
    return conversion(val)
def fp32_to_float16(val, float16_convertor):
    """Recursively convert fp32 tensors in ``val`` to half precision.

    Args:
        val: The value to convert. Can be a single tensor, a tuple, or a list.
        float16_convertor: A function that converts a single fp32 value to fp16/bf16.
    """

    def _to_half(item):
        # Unwrap Parameter/Variable so the dtype check sees the underlying tensor,
        # but convert the original wrapper object itself.
        candidate = item.data if isinstance(item, (Parameter, Variable)) else item
        if isinstance(candidate, _FLOAT_TYPES):
            return float16_convertor(item)
        return item

    return conversion_helper(val, _to_half)
def float16_to_fp32(val):
    """Recursively convert fp16/bf16 tensors in ``val`` to fp32.

    Args:
        val: The value to convert. Can be a single tensor, a tuple, or a list.
    """

    def _to_fp32(item):
        # Unwrap Parameter/Variable so the dtype check sees the underlying tensor,
        # but upcast the original wrapper object itself.
        candidate = item.data if isinstance(item, (Parameter, Variable)) else item
        if isinstance(candidate, (_BF16_TYPES, _HALF_TYPES)):
            return item.float()
        return item

    return conversion_helper(val, _to_fp32)
class Float16Module(MegatronModule):
    """Float 16 Module.

    Wraps a module so it runs in half precision, converting stage-boundary
    inputs/outputs between fp32 and fp16/bf16 as needed.

    Attributes:
        config (TransformerConfig): Transformer config
        fp16 (bool) : Specifies if the model runs in fp16 mode
        bf16 (bool) : Specifies if the model runs in bf16 mode

    Args:
        config (TransformerConfig): The transformer config used to initialize the model
    """

    def __init__(self, config: TransformerConfig, module: torch.nn.Module):
        super(Float16Module, self).__init__(config)
        self.config = config
        self.fp16 = config.fp16
        self.bf16 = config.bf16
        self.vp_size = config.virtual_pipeline_model_parallel_size
        # Mirror pipeline/virtual-pipeline metadata from the wrapped module when present,
        # so stage checks in forward() work without reaching into the wrapped module.
        self.vp_stage = getattr(module, 'vp_stage', None)
        self.pg_collection = getattr(module, 'pg_collection', None)
        if self.fp16:
            # Convert the wrapped module's parameters in place to fp16.
            self.add_module('module', module.half())

            def float16_convertor(val):
                return val.half()

        elif self.bf16:
            # Convert the wrapped module's parameters in place to bf16.
            self.add_module('module', module.bfloat16())

            def float16_convertor(val):
                return val.bfloat16()

        else:
            raise Exception('Either config.fp16 or config.bf16 should be True.')
        # Used by forward() to downcast first-stage inputs to the model precision.
        self.float16_convertor = float16_convertor

    def set_input_tensor(self, input_tensor):
        """Delegate pipeline input-tensor injection to the wrapped module."""
        return self.module.set_input_tensor(input_tensor)

    def forward(self, *inputs, fp32_output=True, **kwargs):
        """
        Execute the wrapped module in model precision and optionally upcast outputs to fp32.

        On the first pipeline stage, positional/keyword tensor inputs are converted to the
        module precision (fp16 or bf16) before invoking the wrapped module. The wrapped module
        is called with the provided inputs and keyword arguments. On the last pipeline stage
        only, outputs are upcast to fp32 if ``fp32_output`` is True; otherwise, outputs are
        returned in the model precision (fp16/bf16).

        Args:
            *inputs: Positional inputs forwarded to the wrapped module (converted to fp16/bf16 on
                the pipeline first stage).
            fp32_output (bool, keyword-only): If True (default), upcast outputs to fp32 on the
                pipeline last stage. Has no effect on non-last stages. Set to False to keep outputs
                in model precision when downstream consumers expect half precision or to avoid
                extra casts.
            **kwargs: Keyword arguments forwarded to the wrapped module.

        Returns:
            The wrapped module's outputs, potentially upcast to fp32 depending on pipeline stage
            and ``fp32_output``.
        """
        # Imported here to avoid a circular import at module load time.
        from megatron.core.pipeline_parallel.utils import (
            is_pp_first_stage,
            is_pp_last_stage,
            is_vp_first_stage,
            is_vp_last_stage,
        )

        if self.pg_collection is None:
            # Fall back to the global process-group state when the wrapped module
            # did not carry its own process-group collection.
            pp_group = parallel_state.get_pipeline_model_parallel_group()
        else:
            pp_group = self.pg_collection.pp
        # Only the very first (VP and PP) stage receives fp32 inputs from outside.
        if is_vp_first_stage(self.vp_stage, self.vp_size) and is_pp_first_stage(pp_group):
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        # Only the very last (VP and PP) stage hands results back to fp32 consumers.
        if (
            is_vp_last_stage(self.vp_stage, self.vp_size)
            and is_pp_last_stage(pp_group)
            and fp32_output is True
        ):
            outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Retrieve state_dict from the module being wrapped (keys are unprefixed
        by this wrapper, so checkpoints are interchangeable with the bare module)."""
        return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """Retrieve state_dict from the module being wrapped."""
        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

    def sharded_state_dict(self, prefix='', *args, **kwargs):
        """Retrieve sharded_state_dict from the module being wrapped."""
        return self.module.sharded_state_dict(prefix, *args, **kwargs)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` directly into the wrapped module."""
        self.module.load_state_dict(state_dict, strict=strict)