-
Notifications
You must be signed in to change notification settings - Fork 701
Expand file tree
/
Copy path_common.py
More file actions
187 lines (156 loc) · 6.13 KB
/
_common.py
File metadata and controls
187 lines (156 loc) · 6.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Helper functions used in fusible operations."""
from __future__ import annotations
from typing import Optional
import torch
from transformer_engine_torch import FP8TensorMeta
from ..torch_version import torch_version
from ..quantization import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..quantized_tensor import QuantizedTensorStorage
from ..utils import canonicalize_dtype
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool:
    """Whether ``tensor`` is backed by quantized storage (``QuantizedTensorStorage``)."""
    return isinstance(tensor, QuantizedTensorStorage)
def maybe_dequantize(
    tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None
) -> torch.Tensor:
    """Return a plain, contiguous ``torch.Tensor``.

    Quantized tensors are dequantized to ``dtype``. Plain tensors are cast
    to ``dtype`` (when given and different) and made contiguous if needed.
    """
    # Quantized path: dequantize handles dtype conversion itself.
    if is_quantized_tensor(tensor):
        return tensor.dequantize(dtype=dtype)
    out = tensor
    if dtype is not None and out.dtype != dtype:
        out = out.to(dtype)
    return out if out.is_contiguous() else out.contiguous()
def maybe_autocast_dtype(
    *,
    device_type: str = "cuda",
    default_dtype: Optional[torch.dtype] = None,
) -> torch.dtype:
    """Get autocast dtype if autocast is enabled, else the canonicalized default.

    Parameters
    ----------
    device_type : str, default "cuda"
        Device type passed to the autocast query (only honored on newer
        PyTorch; older versions query the GPU autocast state directly).
    default_dtype : torch.dtype, optional
        Fallback dtype when autocast is disabled.
    """
    # Device-aware autocast queries require a sufficiently new PyTorch;
    # older versions only expose the GPU-specific variants.
    if torch_version() >= (2, 4, 3):
        if torch.is_autocast_enabled(device_type):
            return torch.get_autocast_dtype(device_type)
    elif torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    return canonicalize_dtype(default_dtype)
def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, int]:
    """FP8TensorMeta object and index corresponding to a Float8Tensor.

    Reuses the metadata already attached to the tensor when available;
    otherwise synthesizes a minimal FP8TensorMeta from the tensor's
    scale-inverse (returned index is then 0).
    """
    # Prefer metadata already attached to the tensor.
    if tensor._fp8_meta is not None:
        meta_key = FP8GlobalStateManager.get_meta_tensor_key(
            forward=tensor._fp8_meta_forward,
        )
        return tensor._fp8_meta[meta_key], tensor._fp8_meta_index

    # No attached metadata: build one from the scale-inverse.
    meta = FP8TensorMeta()
    meta.scale = tensor._scale_inv.reciprocal()
    meta.scale_inv = tensor._scale_inv
    # Placeholder amax history (single uninitialized entry).
    meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=tensor.device)
    return meta, 0
def validate_grouped_mlp_dims(fc1, swiglu, fc2) -> None:
    """Validate FC1/SwiGLU/FC2 dimensions and interleave size for fused grouped MLP.

    Raises
    ------
    ValueError
        If any FC dimension is not a multiple of 64, if FC1/FC2 shapes or
        group counts are incompatible, or if the GLU interleave width is
        not 32.
    """
    # Both projections need 64-aligned dims for the fused kernel.
    for label, layer in (("FC1", fc1), ("FC2", fc2)):
        if layer.in_features % 64 != 0 or layer.out_features % 64 != 0:
            raise ValueError(
                f"Unsupported dims for {label} (num_groups={layer.num_groups}, "
                f"in_features={layer.in_features}, out_features={layer.out_features})."
            )
    # SwiGLU halves the feature dim, so FC1 output must be twice FC2 input.
    if fc1.num_groups != fc2.num_groups or fc1.out_features != 2 * fc2.in_features:
        raise ValueError(
            f"FC1 (num_groups={fc1.num_groups}, in_features={fc1.in_features}, "
            f"out_features={fc1.out_features}) "
            f"and FC2 (num_groups={fc2.num_groups}, in_features={fc2.in_features}, "
            f"out_features={fc2.out_features}) do not match."
        )
    if swiglu.glu_interleave_size != 32:
        raise ValueError(
            "Fused kernel requires 32-wide GLU interleaving, "
            f"but got glu_interleave_size={swiglu.glu_interleave_size}."
        )
def fuse_grouped_mlp_ops(
    ops,
    *,
    recipe,
    fused_op_cls,
):
    """Sliding-window fusion for GroupedLinear + ScaledSwiGLU + GroupedLinear.

    Parameters
    ----------
    ops : list of FusibleOperation
        Operations to scan.
    recipe : Recipe or None
        Quantization recipe. Fusion only applies when the recipe is MXFP8.
    fused_op_cls : type
        Fused operation class with ``is_supported()`` classmethod and
        constructor accepting ``fc1``, ``swiglu``, ``fc2`` keyword args.
        May also expose ``is_fc1_bias_supported()`` and/or
        ``is_fc2_bias_supported()`` classmethods for bias eligibility.

    Returns
    -------
    list of FusibleOperation
        Updated operations with matched triples replaced by fused ops.
    """
    from .basic import GroupedLinear, ScaledSwiGLU  # pylint: disable=import-outside-toplevel

    # Bail out early when the fused kernel or recipe is unsupported.
    if not fused_op_cls.is_supported():
        return ops
    if recipe is None or not recipe.mxfp8():
        return ops

    # A missing bias-support classmethod means bias is supported.
    fc1_bias_ok = (
        not hasattr(fused_op_cls, "is_fc1_bias_supported") or fused_op_cls.is_fc1_bias_supported()
    )
    fc2_bias_ok = (
        not hasattr(fused_op_cls, "is_fc2_bias_supported") or fused_op_cls.is_fc2_bias_supported()
    )

    out = []
    window, ops = ops[:3], ops[3:]
    while len(window) == 3:
        fc1, swiglu, fc2 = window
        # Mirror the constraints of validate_grouped_mlp_dims: op types,
        # matching group counts, 64-aligned dims, FC1/FC2 shape
        # compatibility (SwiGLU halves features, so FC1 out == 2 * FC2 in),
        # 32-wide GLU interleaving, and bias eligibility. The shape
        # compatibility check lets an incompatible triple gracefully fall
        # back to unfused execution instead of erroring in the fused op.
        matches_pattern = (
            isinstance(fc1, GroupedLinear)
            and isinstance(swiglu, ScaledSwiGLU)
            and isinstance(fc2, GroupedLinear)
            and fc1.num_groups == fc2.num_groups
            and fc1.in_features % 64 == 0
            and fc1.out_features % 64 == 0
            and fc2.in_features % 64 == 0
            and fc2.out_features % 64 == 0
            and fc1.out_features == 2 * fc2.in_features
            and swiglu.glu_interleave_size == 32
            and (fc1_bias_ok or not fc1.has_bias)
            and (fc2_bias_ok or not fc2.has_bias)
        )
        if matches_pattern:
            # Replace the triple with the fused op; keep it in the window
            # so scanning resumes immediately after it.
            window = [fused_op_cls(fc1=fc1, swiglu=swiglu, fc2=fc2)]
        else:
            # No match: emit the first op and slide the window by one --
            # the remaining two ops may still start a match.
            out.append(window[0])
            window = window[1:]
        # Refill the window to 3 ops (or fewer if the input is exhausted).
        need = 3 - len(window)
        window.extend(ops[:need])
        ops = ops[need:]
    out.extend(window)
    return out