-
Notifications
You must be signed in to change notification settings - Fork 699
Expand file tree
/
Copy pathapi.py
More file actions
541 lines (466 loc) · 22.1 KB
/
api.py
File metadata and controls
541 lines (466 loc) · 22.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""API definition for nvidia-dlframework-inspect."""
import copy
import warnings
from typing import Dict, Union, Tuple, Optional
from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper
from nvdlfw_inspect.registry import Registry
import torch
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import get_all_tensor_types
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
class TEConfigAPIMapper(BaseConfigAPIMapper):
    """Class responsible for determining which NV DLFW Inspect API should be run for each tensor and gemm.

    It narrows a feature config down to the part that applies to the ``gemm`` and/or
    ``tensor_name`` of a single API call.
    """

    def parse_config_and_api(self, config, **kwargs):
        """Process the config and report whether it matches the API call arguments.

        Parameters
        ----------
        config: Dict
            feature config. It is deep-copied first, so the caller's object is not modified.
        **kwargs
            API call arguments. ``gemm_parsing`` and ``tensor_parsing`` (both default to
            ``False``) select how the config is matched; ``gemm`` and ``tensor_name`` are
            read when the respective mode needs them.

        Returns
        -------
        Tuple[bool, Optional[Dict]]
            ``(True, processed_config)`` - with the ``enabled`` key removed - if the config
            matches, ``(False, None)`` otherwise.
        """
        config_copy = copy.deepcopy(config)
        processed_config = None

        if kwargs.get("gemm_parsing", False):
            # parse with GEMM and/or tensor
            processed_config = self._process_transformer_engine_config(config_copy, **kwargs)
        elif kwargs.get("tensor_parsing", False):
            # parse with only tensor
            processed_config = self._process_tensor_config(config_copy, kwargs["tensor_name"])

        if not processed_config:
            return False, None

        processed_config.pop("enabled", None)
        return True, processed_config

    def _validate_gemm(self, gemm):
        """Assert that ``gemm`` is one of ``fprop``, ``wgrad`` or ``dgrad``."""
        assert gemm in ["fprop", "wgrad", "dgrad"], (
            f"[NVTORCH INSPECT ERROR] Invalid gemm: {gemm}. It must be one of the ['fprop',"
            " 'wgrad', 'dgrad']."
        )

    def _process_transformer_engine_config(self, config, **kwargs):
        """
        Return config specific to a particular tensor name and gemm that matches the api args.

        Parameters
        ----------
        config: Dict
            feature config containing either ``gemms_struct`` (list of dicts, each with a
            ``gemm`` key) or ``gemms`` (list of gemm names). It is modified in place.
        **kwargs
            must contain ``gemm``; ``tensor_name`` is read only if ``tensor_parsing`` is
            truthy. ``tensor_parsing`` defaults to ``False`` when it is not provided, which
            matches the default used by ``parse_config_and_api``.

        Returns
        -------
        Optional[Dict]
            the narrowed config, or ``None`` if the config does not apply to this call.

        Raises
        ------
        ValueError
            if the config contains neither ``gemms_struct`` nor ``gemms``.
        """
        # Optional flag: a missing key means "do not narrow by tensor" instead of a KeyError.
        tensor_parsing = kwargs.get("tensor_parsing", False)

        if "gemms_struct" in config:
            for cfg in config["gemms_struct"]:
                self._validate_gemm(cfg["gemm"])
                if cfg["gemm"] != kwargs["gemm"]:
                    continue
                if tensor_parsing:
                    cfg = self._process_tensor_config(cfg, kwargs["tensor_name"])
                    if not cfg:
                        return None
                cfg_copy = copy.deepcopy(cfg)
                config.pop("gemms_struct")
                assert (
                    "enabled" not in cfg_copy
                ), "[NVTORCH INSPECT ERROR] Enabled field should not be part of gemms_struct"
                # Fold the settings of the matching GEMM entry into the top-level config.
                config.update(cfg_copy)
                return config
            return None

        if "gemms" in config:
            for gemm in config["gemms"]:
                self._validate_gemm(gemm)
            if kwargs["gemm"] not in config["gemms"]:
                return None
            # Only the truthiness of the result is used here and `config` itself is returned
            # below - presumably _process_tensor_config narrows `config` in place;
            # TODO confirm against BaseConfigAPIMapper.
            if tensor_parsing and not self._process_tensor_config(config, kwargs["tensor_name"]):
                return None
            config["gemm"] = kwargs["gemm"]
            config.pop("gemms")
            return config

        raise ValueError(
            "[NVTORCH INSPECT ERROR] Provide 'gemms_struct: List[Dict]' or 'gemms: List[str]'"
            " in the config yaml"
        )
# Keyword arguments that must be present in a call of each API.
# `TransformerEngineAPI.input_assertions_hook` asserts their presence (using the
# "default" entry for API names that are not listed here), and
# `TransformerEngineAPI.routing_condition` uses the presence of "tensor_name" / "gemm"
# in an entry to decide whether the feature config is matched by tensor and/or by GEMM.
required_kwargs = {
    "fp8_gemm_enabled": ["gemm"],
    "modify_tensor_enabled": ["tensor_name", "gemm"],
    "modify_tensor": ["tensor_name", "gemm"],
    "inspect_tensor": ["tensor_name"],
    "inspect_tensor_postquantize": ["tensor_name"],
    "inspect_tensor_enabled": ["tensor_name"],
    "inspect_tensor_postquantize_enabled": ["tensor_name"],
    "default": ["tensor_name", "gemm"],
}
# pylint: disable=unused-argument
class TEDefaultFeatures:
    """Transformer Engine API calls default behavior."""

    def fp8_gemm_enabled(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        iteration: int,
    ) -> Union[bool, Tuple[bool, Optional[int]]]:
        """
        If the tensor is not processed using *modify_tensor* and the fp8 recipe is enabled,
        then the decision whether to cast it to fp8 is based on the value returned by the call *fp8_gemm_enabled*.
        If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled,
        the result of this call does not matter.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be disabled.
        It can return (bool, None) if the feature will never be enabled for that layer and gemm.
        Returning the next enabled iteration can help optimize CPU usage.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        gemm: str
            one of [`fprop`, `dgrad`, `wgrad`],
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.

        Returns
        -------

        Union[bool, Tuple[bool, Optional[int]]] - default is (True, None)
        """
        # If the returned flag is False, fp8_gemm will be turned off. Otherwise nothing happens.
        return True, None

    def modify_tensor_enabled(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        tensor_name: str,
        iteration: int,
    ) -> Union[bool, Tuple[bool, Optional[int]]]:
        """
        It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name.
        It has **higher priority** than fp8_gemm; if *modify_tensor_enabled* returns True or (True, next_enabled_iter),
        then modify_tensor call is invoked for the respective tensor no matter what.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
        It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor.
        Returning the next enabled iteration can help optimize CPU usage, especially when the interval between modify_tensor is large.
        Returning only a bool is deprecated.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        gemm: str
            one of [`fprop`, `dgrad`, `wgrad`],
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.

        Returns
        -------

        Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
        """
        return False, None

    def modify_tensor(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        tensor_name: str,
        tensor: torch.Tensor,
        default_quantizer: Quantizer,
        iteration: int,
        out: Optional[Union[torch.Tensor, QuantizedTensor]],
    ) -> Optional[Union[torch.Tensor, QuantizedTensor]]:
        """
        It allows tensor modification.
        For example, feature `FakeQuant` uses it to emulate casting to FP8.
        It can be invoked at most once for each tensor within a given GEMM operation.

        This call is invoked if `modify_tensor_enabled` returns `True` and the feature is enabled for the *tensor_name* and *gemm*.
        Then it is called **instead of** the default quantization.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        tensor: torch.Tensor
            tensor in high precision,
        gemm: str
            one of [`fprop`, `dgrad`, `wgrad`],
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        default_quantizer : Quantizer
            quantizer which is used to cast the tensor to lower precision
            if *modify_tensor* is not invoked. For example,
            feature per tensor scale uses it to obtain FP8 dtype of the tensor.
            If the recipe indicates that the tensor is not cast - for example,
            if running without FP8 autocast, then `default_quantizer=None`,
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.
        out: Optional[Union[torch.Tensor, QuantizedTensor]]
            output tensor, used in the weight caching mechanism.

        Returns
        -------

        Union[torch.Tensor, transformer_engine.pytorch.QuantizedTensor, None]
            can be `torch.Tensor` or one of the Transformer Engine's `QuantizedTensor` -
            the rule is that both tensors returned for each GEMM should have the same type.
            If both are `Float8Tensor`, then GEMM is run in FP8.
            If both are `torch.Tensor`, GEMM is run in high precision.
            Please take that into account especially if only one tensor of the GEMM
            is processed by the `modify_tensor()`. For example, `FakeQuant`
            disabled FP8 GEMM to ensure that the second tensor is also in high precision.
            If the tensor is not the input for any GEMM - namely `output`,
            `wgrad` and `dgrad` - the return type would match the input type.

        Should return `None` if `out` is not `None`.

        Raises
        ------

        NotImplementedError
            always, in this default implementation: reaching it means that no feature
            handled the call.
        """
        raise NotImplementedError(
            "modify_tensor_enabled() returned True, modify_tensor() was invoked, but it is not"
            " handled by any API."
        )

    def inspect_tensor(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        tensor: Optional[torch.Tensor],
        rowwise_quantized_tensor: Optional[torch.Tensor],
        columnwise_quantized_tensor: Optional[torch.Tensor],
        quantizer: Optional[Quantizer],
        iteration: int,
        tp_group: torch.distributed.ProcessGroup,
    ) -> None:
        """
        The feature is invoked if *inspect_tensor_enabled* returns `True`. It can be used to obtain information on the high precision tensor. For example, it is run by the `LogTensorStats` feature.

        The default implementation does nothing.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        tensor: Optional[torch.Tensor]
            tensor in high precision. It can be None only if fp8 model parameters are used and tensor name is `weight`.
        rowwise_quantized_tensor: Optional[torch.Tensor]
            rowwise quantized tensor,
        columnwise_quantized_tensor: Optional[torch.Tensor]
            columnwise quantized tensor,
        quantizer: Optional[Quantizer]
            quantizer,
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.
        tp_group: torch.distributed.ProcessGroup
            process group for the tensor parallel group. This is used for weight statistics reduction.
            This is not reduction group from debug_api.

        Returns
        -------

        Should return nothing.
        """

    def inspect_tensor_postquantize(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        tensor: torch.Tensor,
        iteration: int,
        tp_group: torch.distributed.ProcessGroup,
        rowwise: bool,
    ) -> None:
        """
        This is deprecated call, we advise to use *inspect_tensor* instead.

        Similar to *inspect_tensor*, but is run after one of the: fp8 cast, modify_tensor if they are run. If none of the fp8 cast or modify_tensor is invoked, then *inspect_tensor_postquantize* is also not invoked. The feature LogFp8Stats uses this call to collect FP8 statistics after the quantization.

        The default implementation does nothing.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        tensor: torch.Tensor
            tensor in fp8 or processed tensor after the modify_tensor call,
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.
        tp_group: torch.distributed.ProcessGroup
            process group for the tensor parallel group. This is used for weight statistics reduction.
            This is not reduction group from debug_api.
        rowwise: bool
            NOTE(review): this argument was not documented; presumably it tells whether
            `tensor` is the rowwise (True) or the columnwise (False) variant - confirm
            against the callers.

        Returns
        -------

        Should return nothing.
        """

    def inspect_tensor_enabled(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        iteration: int,
    ) -> Union[bool, Tuple[bool, Optional[int]]]:
        """
        It is a routing call, which is run at the initialization of the layer.
        Determines if *inspect_tensor* for a given tensor will be invoked.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
        It can return (bool, None) if the feature will never be enabled for that layer and tensor.
        Returning the next enabled iteration can help optimize CPU usage, especially when the interval between inspect_tensor is large.
        Returning only a bool is deprecated.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.

        Returns
        -------

        Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
        """
        return False, None

    def inspect_tensor_postquantize_enabled(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        tensor_name: str,
        iteration: int,
    ) -> Union[bool, Tuple[bool, Optional[int]]]:
        """
        This is deprecated call, we advise to use *inspect_tensor* and *inspect_tensor_enabled* instead.

        It is a routing call, which is run at the initialization of the layer.
        Determines if *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
        It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor name.
        Returning the next enabled iteration can help optimize CPU usage,
        especially when the interval between inspect_tensor_postquantize is large.
        Returning only a bool is deprecated.

        Parameters
        ----------

        config: Dict
            dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
        layer_name: str
        gemm: str
            one of [`fprop`, `dgrad`, `wgrad`],
        tensor_name: str
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.

        Returns
        -------

        Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
        """
        return False, None
@Registry.register_namespace_api(namespace="transformer_engine")
class TransformerEngineAPI(BaseNamespaceAPI):
    """
    Transformer Engine API class that contains default APIs that are invoked when a config is not provided
    or a layer is not selected in the config.
    TransformerEngine specific features must override these APIs wherever required.
    The overridden APIs will be invoked whenever the corresponding feature is enabled in the config.
    """

    def __init__(self):
        """Install the default API implementation and the cacheable-kwargs map."""
        BaseNamespaceAPI.__init__(self)
        self._default_api_impl = TEDefaultFeatures()
        self._cacheable_api_kwargs_map = {
            "fp8_gemm": ["gemm"],
            "modify_tensor": ["tensor_name", "gemm"],
            "inspect_tensor": ["tensor_name"],
            "inspect_tensor_postquantize": ["tensor_name"],
            "inspect_tensor_enabled": ["tensor_name", "iteration"],
            "inspect_tensor_postquantize_enabled": ["tensor_name", "iteration"],
            "modify_tensor_enabled": ["tensor_name"],
        }

    def is_multiple_feature_invocation_allowed(self, api_name):
        """
        Check if API allows executing multiple features for a single call
        """
        return api_name in {
            "fp8_gemm_enabled",
            "inspect_tensor",
            "inspect_tensor_postquantize",
            "inspect_tensor_enabled",
            "inspect_tensor_postquantize_enabled",
        }

    def input_assertions_hook(self, api_name, **kwargs):
        """
        These args must be passed as kwargs in the API call for all TransformerEngine specific APIs.

        API names without their own entry in `required_kwargs` are checked against the
        "default" entry.
        """
        for kwarg in required_kwargs.get(api_name, required_kwargs["default"]):
            assert kwarg in kwargs, (
                f"[NVTORCH INSPECT ERROR] Cannot route API, too ambiguous. Provide {kwarg} in"
                f" {api_name}."
            )

    def routing_condition(self, api_name, config, _, feature_obj, **kwargs):
        """
        Overridden APIs are selected based on the GEMM name in the config and kwargs.

        Returns the `(status, modified_config)` pair produced by
        `feature_obj.parse_config_and_api`.
        """
        # Same fallback as in input_assertions_hook, so that an API name without its own
        # entry is routed using the "default" requirements instead of raising KeyError.
        needed = required_kwargs.get(api_name, required_kwargs["default"])
        return feature_obj.parse_config_and_api(
            config,
            gemm_parsing="gemm" in needed,
            tensor_parsing="tensor_name" in needed,
            **kwargs,
        )

    def output_assertions_hook(self, api_name, ret, **kwargs):
        """Output hooks used to check correctness of the outputs of the API calls."""
        if "enabled" in api_name or api_name == "fp8_gemm":
            assert isinstance(ret, (bool, tuple))
        if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
            assert ret is None
        if api_name == "modify_tensor":
            # NOTE(review): modify_tensor is documented to return None when `out` is given;
            # confirm that get_all_tensor_types() admits that case.
            assert type(ret) in get_all_tensor_types()
            # Exact type check on purpose: subclasses of torch.Tensor are excluded here.
            if type(ret) is torch.Tensor and "dtype" in kwargs:
                if kwargs["dtype"] is not None:
                    assert ret.dtype == kwargs["dtype"]

    def call_feature(self, call, feat_config, layer_name, **kwargs):
        """
        For backward compatibility, remove kwargs that are not needed for the call

        Only the kwargs listed below are ever dropped, and only when their name does not
        appear in `call.__code__.co_varnames`; every other kwarg is always forwarded.
        """
        api_name = call.__name__
        if api_name == "inspect_tensor":
            droppable = (
                "quantizer",
                "columnwise_quantized_tensor",
                "rowwise_quantized_tensor",
                "tp_size",
            )
        elif api_name == "inspect_tensor_postquantize":
            warnings.warn(
                "inspect_tensor_postquantize is deprecated, use inspect_tensor instead.",
                DeprecationWarning,
            )
            droppable = ("tp_size",)
        else:
            return call(feat_config, layer_name, **kwargs)

        # co_varnames lists the parameters followed by the local variables of `call`,
        # so this is a lenient "does the feature know this name" check.
        known_names = call.__code__.co_varnames
        kwargs_copy = {
            key: value
            for key, value in kwargs.items()
            if key not in droppable or key in known_names
        }
        return call(feat_config, layer_name, **kwargs_copy)

    def handle_multi_feature_output(
        self, api_name, multi_feature_outputs, features_to_invoke, **kwargs
    ):
        """
        Handle multi-tensor output of the API calls.

        For `*_enabled` APIs the outputs of all the features are merged into one
        `(run_current, next_iter)` tuple; every other API is delegated to the base class.
        """
        if "enabled" not in api_name:
            return super().handle_multi_feature_output(
                api_name, multi_feature_outputs, features_to_invoke, **kwargs
            )

        # *_enabled feature calls can return bool, or tuple (bool, Optional[int]).
        # If all of them return a tuple (bool, Optional[int]), we return the minimum value,
        # representing the number of steps after the feature will be enabled next time.
        # If the second value is None, that means that the feature will never be enabled.
        if all(isinstance(output, tuple) for output in multi_feature_outputs):
            run_current = any(output[0] for output in multi_feature_outputs)
            known_iters = [
                output[1] for output in multi_feature_outputs if output[1] is not None
            ]
            next_iter = min(known_iters) if known_iters else None
            return run_current, next_iter

        # At least one feature returned a plain bool. Read the flag out of tuple outputs:
        # a non-empty tuple such as (False, 5) is truthy, so testing the tuple itself would
        # report "enabled" even when every feature said "disabled".
        run_current = any(
            output[0] if isinstance(output, tuple) else output
            for output in multi_feature_outputs
        )
        # NOTE(review): None is returned as the next iteration here, although for tuple
        # outputs None is described as "never enabled" - confirm that callers treat this
        # mixed case as "unknown".
        return run_current, None

    def step(self):
        """This function is called by the nvidia-dlframework-inspect after every debug_api.step()"""
        STATS_BUFFERS.log_stats()

    def end_debug(self):
        """This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()"""
        TEDebugState._reset()