-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathaiter_linear.py
More file actions
110 lines (84 loc) · 3.15 KB
/
aiter_linear.py
File metadata and controls
110 lines (84 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import lru_cache
from aiter import hipb_mm, hipb_create_extension, per_tensor_quant_hip
from aiter.tuned_gemm import tgemm
from aiter.ops.shuffle import shuffle_weight
from diffsynth_engine.utils.platform import DTYPE_FP8
from contextlib import contextmanager
@lru_cache(maxsize=1)
def init_hipblas():
    """Create the hipBLASLt extension exactly once per process.

    ``lru_cache(maxsize=1)`` on a zero-argument function turns this into a
    thread-unsafe-but-idempotent one-time initializer: the first call runs
    ``hipb_create_extension()``, every later call is a cache hit and a no-op.
    """
    hipb_create_extension()
@contextmanager
def use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=True, use_scale_for_fp8=False):
    """Temporarily monkey-patch ``torch.nn.functional.linear`` with a
    hipBLASLt GEMM that uses pre-shuffled weights.

    Args:
        swizzle: when False, the context manager is a no-op and ``F.linear``
            is left untouched.
        use_fp8_linear: install the FP8 variant (casts/quantizes input and
            weight to ``DTYPE_FP8``) instead of the plain variant.
        use_scale_for_fp8: in the FP8 variant, per-tensor-quantize the
            activations (producing a scale passed as ``scaleA``) rather than
            doing a raw dtype cast.

    ``F.linear`` is always restored on exit, even if the ``with`` body raises.
    """
    if not swizzle:
        yield
        return

    # Preserve original F.linear so it can be restored on exit.
    _original_linear = F.linear

    def _shuffled_mm(input_flat, weight, bias, otype, scaleA, scaleB, device):
        """Run one pre-shuffled hipBLASLt GEMM on 2-D inputs."""
        init_hipblas()
        # Pre-shuffle the weight into the (16, 16) layout hipb_mm expects.
        # NOTE(review): this shuffle is recomputed on every call; consider
        # caching it per weight tensor if this sits on a hot path.
        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
        return hipb_mm(
            input_flat,
            weight_preshuffle.t(),
            bias=bias,
            solution_index=-1,
            out_dtype=otype,
            scaleA=scaleA,
            scaleB=scaleB,
            scaleOut=None,
            bpreshuffle=True,
        )

    def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
                         scaleA=None, scaleB=None, device="cuda"):
        """Drop-in F.linear replacement backed by hipb_mm."""
        # Collapse leading dims to 2-D for the GEMM, restore them afterwards.
        input_flat = input.reshape(-1, input.shape[-1])
        output_flat = _shuffled_mm(input_flat, weight, bias, otype, scaleA, scaleB, device)
        return output_flat.view(input.shape[:-1] + (weight.shape[0],))

    def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
                             scaleA=None, scaleB=None, device="cuda"):
        """FP8 F.linear replacement: quantize/cast operands, then hipb_mm."""
        input_flat = input.reshape(-1, input.shape[-1])
        if use_scale_for_fp8:
            # Per-tensor quantization returns the FP8 activation plus its
            # scale; the scale overrides any caller-supplied scaleA (matches
            # original behavior).
            input_flat, scaleA = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
        else:
            # Raw cast: no activation scale is produced.
            input_flat = input_flat.to(DTYPE_FP8)
        weight = weight.to(DTYPE_FP8)
        output_flat = _shuffled_mm(input_flat, weight, bias, otype, scaleA, scaleB, device)
        return output_flat.view(input.shape[:-1] + (weight.shape[0],))

    F.linear = optimized_linear_fp8 if use_fp8_linear else optimized_linear
    try:
        yield
    finally:
        # Bug fix: the original restored F.linear after a bare yield, so an
        # exception in the with-body left the patch installed permanently.
        F.linear = _original_linear