1212from lmdeploy .messages import QuantPolicy
1313from lmdeploy .utils import get_logger
1414
15- from .turbo_quant import get_lloyd_max_codebook , hadamard_rotate , hadamard_rotate_inv
15+ from .turbo_quant import hadamard_rotate , hadamard_rotate_inv
1616from .utils import get_device_props
1717
1818logger = get_logger ('lmdeploy' )
@@ -220,15 +220,53 @@ def _fwd_grouped_split_kernel(
220220 tl .store (acc_out_ptr + off_meta + 1 , l_i , mask = mask_h )
221221
222222
223+
224+
@triton.jit
def _k4v2_k_centroid(idx3, head_size: tl.constexpr):
    """Map a 3-bit K code to its scaled Lloyd-Max centroid.

    The eight centroid values are compile-time constants and the lookup is
    expressed as a tree of ``tl.where`` selections, so no codebook pointer
    is needed by the caller.

    Args:
        idx3: integer code(s) in ``0..7``. Values outside that range are not
            rejected; they simply fall through to the ``else`` side of each
            comparison below.
        head_size: compile-time head dimension used to derive the scale
            ``1 / sqrt(head_size * 2)``.
            NOTE(review): presumably this is the packed K dim (two 4-bit
            codes per element), making ``head_size * 2`` the unpacked
            dim -- confirm against the kernel launcher.

    Returns:
        The unit-variance centroid selected by ``idx3`` multiplied by the
        scale above.
    """
    # Lloyd-Max 3-bit centroids for sigma = 1 (symmetric around zero).
    C0: tl.constexpr = -2.1519456
    C1: tl.constexpr = -1.3439093
    C2: tl.constexpr = -0.7560052
    C3: tl.constexpr = -0.2450942
    C4: tl.constexpr = 0.2450942
    C5: tl.constexpr = 0.7560052
    C6: tl.constexpr = 1.3439093
    C7: tl.constexpr = 2.1519456
    inv_std: tl.constexpr = 1.0 / tl.math.sqrt(head_size * 2.0)

    # Leaves of the selection tree: one candidate per adjacent index pair.
    pair01 = tl.where(idx3 == 0, C0, C1)
    pair23 = tl.where(idx3 == 2, C2, C3)
    pair45 = tl.where(idx3 == 4, C4, C5)
    pair67 = tl.where(idx3 == 6, C6, C7)

    # Inner nodes, then the root split at index 4.
    lower_half = tl.where(idx3 < 2, pair01, pair23)
    upper_half = tl.where(idx3 < 6, pair45, pair67)
    unit_centroid = tl.where(idx3 < 4, lower_half, upper_half)
    return unit_centroid * inv_std
246+
247+
@triton.jit
def _k4v2_v_centroid(idx2, head_size_v: tl.constexpr):
    """Map a 2-bit V code to its scaled Lloyd-Max centroid.

    The four centroid values are compile-time constants and the lookup is
    expressed with ``tl.where`` selections, so no codebook pointer is
    needed by the caller.

    Args:
        idx2: integer code(s) in ``0..3``. Values outside that range are not
            rejected; they simply fall through to the ``else`` side of each
            comparison below.
        head_size_v: compile-time head dimension used to derive the scale
            ``1 / sqrt(head_size_v * 4)``.
            NOTE(review): presumably this is the packed V dim (four 2-bit
            codes per element), making ``head_size_v * 4`` the unpacked
            dim -- confirm against the kernel launcher.

    Returns:
        The unit-variance centroid selected by ``idx2`` multiplied by the
        scale above.
    """
    # Lloyd-Max 2-bit centroids for sigma = 1 (symmetric around zero).
    C0: tl.constexpr = -1.5104176
    C1: tl.constexpr = -0.4527808
    C2: tl.constexpr = 0.4527808
    C3: tl.constexpr = 1.5104176
    inv_std: tl.constexpr = 1.0 / tl.math.sqrt(head_size_v * 4.0)

    # One candidate per index pair, then pick the pair by comparing to 2.
    negative_side = tl.where(idx2 == 0, C0, C1)
    positive_side = tl.where(idx2 == 2, C2, C3)
    unit_centroid = tl.where(idx2 < 2, negative_side, positive_side)
    return unit_centroid * inv_std
261+
262+
223263@triton .jit
224264def _fwd_grouped_split_quant_kernel (
225265 q_ptr ,
226266 k_ptr ,
227267 v_ptr ,
228268 KScalesZeros ,
229269 VScalesZeros ,
230- k_codebook_ptr ,
231- v_codebook_ptr ,
232270 sm_scale ,
233271 cache_seqlens_ptr ,
234272 page_table_ptr ,
@@ -254,7 +292,6 @@ def _fwd_grouped_split_quant_kernel(
254292 stride_vszh : tl .constexpr ,
255293 stride_vszd : tl .constexpr ,
256294 quant_policy : tl .constexpr ,
257- turbo_quant : tl .constexpr ,
258295 stride_ok : tl .constexpr ,
259296 stride_obs : tl .constexpr ,
260297 stride_oh : tl .constexpr ,
@@ -406,23 +443,14 @@ def _fwd_grouped_split_quant_kernel(
406443 if quant_policy == 4 or quant_policy == 42 :
407444 k = (k >> shift_kd ) & 0x0F
408445
409- if turbo_quant and quant_policy == 42 :
410- # K = QJL4:
411- # low 3bit = mse idx
412- # high 1bit = qjl sign
446+ if quant_policy == 42 :
413447 kmse_norm = tl .load (ksz_ptrs + b_offset * stride_kszp )
414448 kqjl_norm = tl .load (ksz_ptrs + b_offset * stride_kszp + stride_kszd )
415449
416- k_idx3 = k & 0x7
417- k_bit1 = (k >> 3 ) & 0x1
418- k_sign = k_bit1 .to (tl .float32 ) * 2.0 - 1.0
419-
420- k = tl .load (k_codebook_ptr + k_idx3 .to (tl .int32 ))
421- k = (kmse_norm * (k + kqjl_norm * k_sign )).to (q .dtype )
422- elif turbo_quant :
423- ks = tl .load (ksz_ptrs + b_offset * stride_kszp )
424- k = tl .load (k_codebook_ptr + k .to (tl .int32 ))
425- k = (k * ks ).to (q .dtype )
450+ # k is 4-bit nibble: low 3 = mse_idx, high 1 = sign
451+ k_cent = _k4v2_k_centroid ((k & 0x7 ), head_size )
452+ k_sign = ((k >> 3 ) & 0x1 ).to (tl .float32 ) * 2.0 - 1.0
453+ k = (kmse_norm * (k_cent + kqjl_norm * k_sign )).to (q .dtype )
426454 else :
427455 ks = tl .load (ksz_ptrs + b_offset * stride_kszp )
428456 kz = tl .load (ksz_ptrs + b_offset * stride_kszp + 1 )
@@ -433,19 +461,13 @@ def _fwd_grouped_split_quant_kernel(
433461 if quant_policy == 4 or quant_policy == 42 :
434462 k1 = (k1 >> shift_k1d ) & 0x0F
435463
436- if turbo_quant and quant_policy == 42 :
464+ if quant_policy == 42 :
437465 kmse_norm = tl .load (ksz_ptrs + b_offset * stride_kszp )
438466 kqjl_norm = tl .load (ksz_ptrs + b_offset * stride_kszp + stride_kszd )
439467
440- k1_idx3 = k1 & 0x7
441- k1_bit1 = (k1 >> 3 ) & 0x1
442- k1_sign = k1_bit1 .to (tl .float32 ) * 2.0 - 1.0
443-
444- k1 = tl .load (k_codebook_ptr + k1_idx3 .to (tl .int32 ))
445- k1 = (kmse_norm * (k1 + kqjl_norm * k1_sign )).to (q .dtype )
446- elif turbo_quant :
447- k1 = tl .load (k_codebook_ptr + k1 .to (tl .int32 ))
448- k1 = (k1 * ks ).to (q .dtype )
468+ k1_cent = _k4v2_k_centroid ((k1 & 0x7 ), head_size )
469+ k1_sign = ((k1 >> 3 ) & 0x1 ).to (tl .float32 ) * 2.0 - 1.0
470+ k1 = (kmse_norm * (k1_cent + kqjl_norm * k1_sign )).to (q .dtype )
449471 else :
450472 k1 = ((k1 - kz ) * ks ).to (q .dtype )
451473
@@ -459,9 +481,9 @@ def _fwd_grouped_split_quant_kernel(
459481 else :
460482 v = tl .load (v_ptr + off_v + b_offset * stride_vp )
461483
462- if turbo_quant :
484+ if quant_policy == 42 :
463485 vs = tl .load (vsz_ptrs + b_offset * stride_vszp )
464- v = tl . load ( v_codebook_ptr + v . to ( tl . int32 ) )
486+ v = _k4v2_v_centroid ( v , head_size_v )
465487 v = (v * vs ).to (q .dtype )
466488 else :
467489 vs = tl .load (vsz_ptrs + b_offset * stride_vszp )
@@ -690,12 +712,6 @@ def _get_block_d(Lk):
690712 BLOCK_DV = triton .next_power_of_2 (Lv )
691713 return BLOCK_DMODEL , BLOCK_DMODEL1 , BLOCK_DV
692714
693- turbo_quant = False
694- # Triton still receives these arguments for quantized paths, so keep
695- # valid tensor-backed pointers even when turbo quant is not enabled.
696- # They will be overwritten with real codebooks when quant_policy == 42.
697- turbo_k_codebook = q .new_empty ((1 , ))
698- turbo_v_codebook = q .new_empty ((1 , ))
699715
700716 # shape constraints
701717 Lq , Lk , Lv = q .shape [- 1 ], k_cache .shape [d_dim ], v_cache .shape [d_dim ]
@@ -720,19 +736,13 @@ def _get_block_d(Lk):
720736 # - V dequant as norm * centroid[idx2]
721737 # - output inverse-rotated because V is still rotated before caching
722738 if quant_policy == 42 :
723- turbo_quant = True
724739 real_k_dim = Lq
725740 real_v_dim = Lv * 4
726741 if real_k_dim & (real_k_dim - 1 ) != 0 :
727742 raise ValueError (f'TurboQuant requires power-of-2 K/Q head dim, got { real_k_dim } ' )
728743 if real_v_dim & (real_v_dim - 1 ) != 0 :
729744 raise ValueError (f'TurboQuant requires power-of-2 V head dim, got { real_v_dim } ' )
730745
731- # K = QJL4 => 3bit centroid codebook
732- turbo_k_codebook , _ = get_lloyd_max_codebook (real_k_dim , bits = 3 , device = q .device )
733- # V = TurboQuant MSE int2 => 2bit centroid codebook
734- turbo_v_codebook , _ = get_lloyd_max_codebook (real_v_dim , bits = 2 , device = q .device )
735-
736746 # Rotate query into the same domain as quantized K/V
737747 q = hadamard_rotate (q )
738748
@@ -795,8 +805,6 @@ def _get_block_d(Lk):
795805 v_cache ,
796806 k_scales_zeros ,
797807 v_scales_zeros ,
798- turbo_k_codebook ,
799- turbo_v_codebook ,
800808 softmax_scale ,
801809 cache_seqlens ,
802810 page_table ,
@@ -822,7 +830,6 @@ def _get_block_d(Lk):
822830 stride_vszh = v_scales_zeros .stride (h_dim ),
823831 stride_vszd = v_scales_zeros .stride (d_dim ),
824832 quant_policy = quant_policy ,
825- turbo_quant = turbo_quant ,
826833 stride_ok = acc .stride (- 2 ),
827834 stride_obs = acc .stride (- 4 ),
828835 stride_oh = acc .stride (- 3 ),
0 commit comments