Skip to content

Commit 8b75699

Browse files
windreamer and openhands-agent
authored and committed
fix: optimize TurboQuant kernel by replacing codebook lookup with inline centroid calculation
- Replace memory-loaded Lloyd-Max codebook with pure register-based centroid calculation in PagedAttention kernel (quant_policy==42)
- Add _k4v2_k_centroid and _k4v2_v_centroid triton JIT functions for inline centroid lookup, reducing memory access and improving performance
- Remove turbo_quant flag and codebook pointer parameters from kernel
- Fix conditional branch bug in turbo_quant.py (else -> elif bits == 4)
1 parent 1255c52 commit 8b75699

File tree

2 files changed

+55
-46
lines changed

2 files changed

+55
-46
lines changed

lmdeploy/pytorch/kernels/cuda/pagedattention.py

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from lmdeploy.messages import QuantPolicy
1313
from lmdeploy.utils import get_logger
1414

15-
from .turbo_quant import get_lloyd_max_codebook, hadamard_rotate, hadamard_rotate_inv
15+
from .turbo_quant import hadamard_rotate, hadamard_rotate_inv
1616
from .utils import get_device_props
1717

1818
logger = get_logger('lmdeploy')
@@ -220,15 +220,53 @@ def _fwd_grouped_split_kernel(
220220
tl.store(acc_out_ptr + off_meta + 1, l_i, mask=mask_h)
221221

222222

223+
224+
225+
@triton.jit
226+
def _k4v2_k_centroid(idx3, head_size: tl.constexpr):
227+
"""QJL4 K centroid lookup: 8 entries, pure register."""
228+
# Lloyd-Max 3-bit centroids at sigma=1
229+
S0: tl.constexpr = -2.1519456
230+
S1: tl.constexpr = -1.3439093
231+
S2: tl.constexpr = -0.7560052
232+
S3: tl.constexpr = -0.2450942
233+
S4: tl.constexpr = 0.2450942
234+
S5: tl.constexpr = 0.7560052
235+
S6: tl.constexpr = 1.3439093
236+
S7: tl.constexpr = 2.1519456
237+
sigma: tl.constexpr = 1.0 / tl.math.sqrt(head_size * 2.0)
238+
c = tl.where(idx3 < 4,
239+
tl.where(idx3 < 2,
240+
tl.where(idx3 == 0, S0, S1),
241+
tl.where(idx3 == 2, S2, S3)),
242+
tl.where(idx3 < 6,
243+
tl.where(idx3 == 4, S4, S5),
244+
tl.where(idx3 == 6, S6, S7)))
245+
return c * sigma
246+
247+
248+
@triton.jit
249+
def _k4v2_v_centroid(idx2, head_size_v: tl.constexpr):
250+
"""MSE int2 V centroid lookup: 4 entries, pure register."""
251+
# Lloyd-Max 2-bit centroids at sigma=1
252+
S0: tl.constexpr = -1.5104176
253+
S1: tl.constexpr = -0.4527808
254+
S2: tl.constexpr = 0.4527808
255+
S3: tl.constexpr = 1.5104176
256+
sigma: tl.constexpr = 1.0 / tl.math.sqrt(head_size_v * 4.0)
257+
c = tl.where(idx2 < 2,
258+
tl.where(idx2 == 0, S0, S1),
259+
tl.where(idx2 == 2, S2, S3))
260+
return c * sigma
261+
262+
223263
@triton.jit
224264
def _fwd_grouped_split_quant_kernel(
225265
q_ptr,
226266
k_ptr,
227267
v_ptr,
228268
KScalesZeros,
229269
VScalesZeros,
230-
k_codebook_ptr,
231-
v_codebook_ptr,
232270
sm_scale,
233271
cache_seqlens_ptr,
234272
page_table_ptr,
@@ -254,7 +292,6 @@ def _fwd_grouped_split_quant_kernel(
254292
stride_vszh: tl.constexpr,
255293
stride_vszd: tl.constexpr,
256294
quant_policy: tl.constexpr,
257-
turbo_quant: tl.constexpr,
258295
stride_ok: tl.constexpr,
259296
stride_obs: tl.constexpr,
260297
stride_oh: tl.constexpr,
@@ -406,23 +443,14 @@ def _fwd_grouped_split_quant_kernel(
406443
if quant_policy == 4 or quant_policy == 42:
407444
k = (k >> shift_kd) & 0x0F
408445

409-
if turbo_quant and quant_policy == 42:
410-
# K = QJL4:
411-
# low 3bit = mse idx
412-
# high 1bit = qjl sign
446+
if quant_policy == 42:
413447
kmse_norm = tl.load(ksz_ptrs + b_offset * stride_kszp)
414448
kqjl_norm = tl.load(ksz_ptrs + b_offset * stride_kszp + stride_kszd)
415449

416-
k_idx3 = k & 0x7
417-
k_bit1 = (k >> 3) & 0x1
418-
k_sign = k_bit1.to(tl.float32) * 2.0 - 1.0
419-
420-
k = tl.load(k_codebook_ptr + k_idx3.to(tl.int32))
421-
k = (kmse_norm * (k + kqjl_norm * k_sign)).to(q.dtype)
422-
elif turbo_quant:
423-
ks = tl.load(ksz_ptrs + b_offset * stride_kszp)
424-
k = tl.load(k_codebook_ptr + k.to(tl.int32))
425-
k = (k * ks).to(q.dtype)
450+
# k is 4-bit nibble: low 3 = mse_idx, high 1 = sign
451+
k_cent = _k4v2_k_centroid((k & 0x7), head_size)
452+
k_sign = ((k >> 3) & 0x1).to(tl.float32) * 2.0 - 1.0
453+
k = (kmse_norm * (k_cent + kqjl_norm * k_sign)).to(q.dtype)
426454
else:
427455
ks = tl.load(ksz_ptrs + b_offset * stride_kszp)
428456
kz = tl.load(ksz_ptrs + b_offset * stride_kszp + 1)
@@ -433,19 +461,13 @@ def _fwd_grouped_split_quant_kernel(
433461
if quant_policy == 4 or quant_policy == 42:
434462
k1 = (k1 >> shift_k1d) & 0x0F
435463

436-
if turbo_quant and quant_policy == 42:
464+
if quant_policy == 42:
437465
kmse_norm = tl.load(ksz_ptrs + b_offset * stride_kszp)
438466
kqjl_norm = tl.load(ksz_ptrs + b_offset * stride_kszp + stride_kszd)
439467

440-
k1_idx3 = k1 & 0x7
441-
k1_bit1 = (k1 >> 3) & 0x1
442-
k1_sign = k1_bit1.to(tl.float32) * 2.0 - 1.0
443-
444-
k1 = tl.load(k_codebook_ptr + k1_idx3.to(tl.int32))
445-
k1 = (kmse_norm * (k1 + kqjl_norm * k1_sign)).to(q.dtype)
446-
elif turbo_quant:
447-
k1 = tl.load(k_codebook_ptr + k1.to(tl.int32))
448-
k1 = (k1 * ks).to(q.dtype)
468+
k1_cent = _k4v2_k_centroid((k1 & 0x7), head_size)
469+
k1_sign = ((k1 >> 3) & 0x1).to(tl.float32) * 2.0 - 1.0
470+
k1 = (kmse_norm * (k1_cent + kqjl_norm * k1_sign)).to(q.dtype)
449471
else:
450472
k1 = ((k1 - kz) * ks).to(q.dtype)
451473

@@ -459,9 +481,9 @@ def _fwd_grouped_split_quant_kernel(
459481
else:
460482
v = tl.load(v_ptr + off_v + b_offset * stride_vp)
461483

462-
if turbo_quant:
484+
if quant_policy == 42:
463485
vs = tl.load(vsz_ptrs + b_offset * stride_vszp)
464-
v = tl.load(v_codebook_ptr + v.to(tl.int32))
486+
v = _k4v2_v_centroid(v, head_size_v)
465487
v = (v * vs).to(q.dtype)
466488
else:
467489
vs = tl.load(vsz_ptrs + b_offset * stride_vszp)
@@ -690,12 +712,6 @@ def _get_block_d(Lk):
690712
BLOCK_DV = triton.next_power_of_2(Lv)
691713
return BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV
692714

693-
turbo_quant = False
694-
# Triton still receives these arguments for quantized paths, so keep
695-
# valid tensor-backed pointers even when turbo quant is not enabled.
696-
# They will be overwritten with real codebooks when quant_policy == 42.
697-
turbo_k_codebook = q.new_empty((1, ))
698-
turbo_v_codebook = q.new_empty((1, ))
699715

700716
# shape constraints
701717
Lq, Lk, Lv = q.shape[-1], k_cache.shape[d_dim], v_cache.shape[d_dim]
@@ -720,19 +736,13 @@ def _get_block_d(Lk):
720736
# - V dequant as norm * centroid[idx2]
721737
# - output inverse-rotated because V is still rotated before caching
722738
if quant_policy == 42:
723-
turbo_quant = True
724739
real_k_dim = Lq
725740
real_v_dim = Lv * 4
726741
if real_k_dim & (real_k_dim - 1) != 0:
727742
raise ValueError(f'TurboQuant requires power-of-2 K/Q head dim, got {real_k_dim}')
728743
if real_v_dim & (real_v_dim - 1) != 0:
729744
raise ValueError(f'TurboQuant requires power-of-2 V head dim, got {real_v_dim}')
730745

731-
# K = QJL4 => 3bit centroid codebook
732-
turbo_k_codebook, _ = get_lloyd_max_codebook(real_k_dim, bits=3, device=q.device)
733-
# V = TurboQuant MSE int2 => 2bit centroid codebook
734-
turbo_v_codebook, _ = get_lloyd_max_codebook(real_v_dim, bits=2, device=q.device)
735-
736746
# Rotate query into the same domain as quantized K/V
737747
q = hadamard_rotate(q)
738748

@@ -795,8 +805,6 @@ def _get_block_d(Lk):
795805
v_cache,
796806
k_scales_zeros,
797807
v_scales_zeros,
798-
turbo_k_codebook,
799-
turbo_v_codebook,
800808
softmax_scale,
801809
cache_seqlens,
802810
page_table,
@@ -822,7 +830,6 @@ def _get_block_d(Lk):
822830
stride_vszh=v_scales_zeros.stride(h_dim),
823831
stride_vszd=v_scales_zeros.stride(d_dim),
824832
quant_policy=quant_policy,
825-
turbo_quant=turbo_quant,
826833
stride_ok=acc.stride(-2),
827834
stride_obs=acc.stride(-4),
828835
stride_oh=acc.stride(-3),

lmdeploy/pytorch/kernels/cuda/turbo_quant.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def get_lloyd_max_codebook(d: int, bits: int, device: str = 'cuda') -> tuple[Ten
182182
device=device,
183183
dtype=torch.float32,
184184
)
185-
else: # bits == 4
185+
elif bits == 4:
186186
centroids_std = torch.tensor(
187187
[-2.4175594, -1.7094618, -1.2629677, -0.9265621,
188188
-0.6470380, -0.4015197, -0.1756835, 0.0391761,
@@ -197,6 +197,8 @@ def get_lloyd_max_codebook(d: int, bits: int, device: str = 'cuda') -> tuple[Ten
197197
1.4898994, 2.0245643, 2.6740751],
198198
device=device, dtype=torch.float32
199199
)
200+
else:
201+
raise NotImplementedError
200202

201203
centroids = centroids_std * sigma
202204
boundaries = boundaries_std * sigma

0 commit comments

Comments (0)