Skip to content

Commit 2d03817

Browse files
committed
refactor: introduce QuantPolicy enum and consolidate TurboQuant utilities
- Add `QuantPolicy` IntEnum to replace magic numbers (0, 4, 8, 42) for KV cache quantization policies: NONE, INT4, INT8, TURBO_QUANT
- Update `TurbomindEngineConfig` and `PytorchEngineConfig` to use the `QuantPolicy` type
- Extract TurboQuant utilities (Hadamard rotation, Lloyd-Max codebook) from `fill_kv_cache.py` into a new dedicated module, `turbo_quant.py`
- Rename `butterfly_rotate`/`butterfly_rotate_inv` to `hadamard_rotate`/`hadamard_rotate_inv` for naming accuracy (the transform uses a Hadamard matrix)
- Update all call sites across attention kernels, cache engine, and tests
- Update test fixtures and assertions to use `QuantPolicy` constants

This improves type safety and code readability, and maintains backward compatibility because the enum integer values match the previous magic numbers.
1 parent bbde920 commit 2d03817

File tree

17 files changed

+362
-301
lines changed

17 files changed

+362
-301
lines changed

lmdeploy/messages.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,14 @@
1616

1717
logger = get_logger('lmdeploy')
1818

19+
20+
class QuantPolicy(enum.IntEnum):
21+
"""Quantization policy constants for KV cache."""
22+
NONE = 0
23+
INT4 = 4 # 4-bit KV cache
24+
INT8 = 8 # 8-bit KV cache
25+
TURBO_QUANT = 42 # TurboQuant: K=4bit QJL4 + V=2bit MSE
26+
1927
LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
2028
"""LogitsProcessor is a function that takes a tensor of input_ids, the logits
2129
tensor for the next token, and returns a modified tensor of logits to sample
@@ -298,7 +306,8 @@ def __post_init__(self):
298306
assert self.dtype in ['auto', 'float16', 'bfloat16']
299307
assert self.tp >= 1, 'tp must be a positive integer'
300308
assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count'
301-
assert self.quant_policy in (0, 4, 8, 42), 'invalid quant_policy'
309+
assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \
310+
'invalid quant_policy'
302311
assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor'
303312
assert self.max_prefill_token_num >= 0, \
304313
'invalid max_prefill_token_num'
@@ -403,7 +412,7 @@ class PytorchEngineConfig:
403412
custom_module_map: dict[str, str] = None
404413
download_dir: str = None
405414
revision: str = None
406-
quant_policy: Literal[0, 4, 8, 42] = 0
415+
quant_policy: QuantPolicy = QuantPolicy.NONE
407416
distributed_executor_backend: str = None
408417
empty_init: bool = False
409418
enable_microbatch: bool = False
@@ -440,7 +449,8 @@ def __post_init__(self):
440449
assert self.max_prefill_token_num >= 0, \
441450
'invalid max_prefill_token_num'
442451
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
443-
assert self.quant_policy in (0, 4, 8, 42), 'invalid quant_policy'
452+
assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \
453+
'invalid quant_policy'
444454
assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')
445455
assert self.block_size >= 16 and (self.block_size & (self.block_size - 1)) == 0, \
446456
f'block_size must be >= 16 and a power of 2, but got {self.block_size}'

lmdeploy/pytorch/backends/attention.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,12 @@
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
44
from functools import lru_cache
5-
from typing import Generic, Literal, TypeVar
5+
from typing import Generic, TypeVar
66

77
import torch
88

9+
from lmdeploy.messages import QuantPolicy
10+
911

1012
@dataclass
1113
class AttentionMetadata:
@@ -18,7 +20,7 @@ class AttentionMetadata:
1820
fill_seqlens: torch.Tensor = None
1921
cu_seqlens_q: torch.Tensor = None
2022
cu_seqlens_k: torch.Tensor = None
21-
quant_policy: Literal[0, 4, 8, 42] = 0
23+
quant_policy: QuantPolicy = 0
2224

2325

2426
T = TypeVar('T', bound=AttentionMetadata)

lmdeploy/pytorch/backends/cuda/attention/default.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from dataclasses import dataclass
3-
from typing import Literal
43

54
import torch
65

6+
from lmdeploy.messages import QuantPolicy
77
from lmdeploy.pytorch.backends.attention import AttentionImpl, AttentionMetadata
88
from lmdeploy.utils import get_logger
99

@@ -40,7 +40,7 @@ class TritonAttentionMetadata(AttentionMetadata):
4040
q_seqlens: torch.Tensor = None
4141
kv_start_loc: torch.Tensor = None
4242
kv_seqlens: torch.Tensor = None
43-
quant_policy: Literal[0, 4, 8, 42] = 0
43+
quant_policy: QuantPolicy = 0
4444
kv_flatten_size: int = None
4545
# flash mla
4646
tile_scheduler_metadata: torch.Tensor = None
@@ -279,15 +279,15 @@ def _forward_prefill(
279279
flatten_kv_layout=kv_layout,
280280
)
281281

282-
# For quant_policy==42, flattened K/V are in rotated domain.
282+
# For quant_policy==QuantPolicy.TURBO_QUANT, flattened K/V are in rotated domain.
283283
# Rotate Q to match, and inverse-rotate output afterwards.
284-
if quant_policy == 42:
285-
from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import (
286-
butterfly_rotate,
287-
butterfly_rotate_inv,
284+
if quant_policy == QuantPolicy.TURBO_QUANT:
285+
from lmdeploy.pytorch.kernels.cuda.turbo_quant import (
286+
hadamard_rotate,
287+
hadamard_rotate_inv,
288288
)
289289
orig_dtype = query.dtype
290-
query = butterfly_rotate(query.float()).to(orig_dtype)
290+
query = hadamard_rotate(query.float()).to(orig_dtype)
291291

292292
attn_output = self.flash_attention_fwd(
293293
query,
@@ -309,8 +309,8 @@ def _forward_prefill(
309309
)
310310

311311
# Inverse-rotate output back to original domain
312-
if quant_policy == 42:
313-
attn_output = butterfly_rotate_inv(
312+
if quant_policy == QuantPolicy.TURBO_QUANT:
313+
attn_output = hadamard_rotate_inv(
314314
attn_output.float()
315315
).to(orig_dtype)
316316

lmdeploy/pytorch/backends/cuda/attention/mla.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from lmdeploy.messages import QuantPolicy
78
from lmdeploy.utils import get_logger
89

910
from .default import TritonAttentionImpl, TritonAttentionMetadata
@@ -405,7 +406,7 @@ def _fill_kv_cache_impl(self,
405406
block_offsets = attn_metadata.block_offsets
406407
kv_seqlens = attn_metadata.kv_seqlens
407408
quant_policy = attn_metadata.quant_policy
408-
assert quant_policy == 0
409+
assert quant_policy == QuantPolicy.NONE
409410

410411
# fill seqlen args
411412
fill_seqlens, fill_max_q_seqlen, fill_q_start_loc = self._get_fill_meta(

lmdeploy/pytorch/config.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,11 @@
22
import enum
33
from collections.abc import Callable
44
from dataclasses import dataclass, field
5-
from typing import Any, Literal
5+
from typing import Any
66

77
import torch
88

9-
from lmdeploy.messages import PytorchEngineConfig
9+
from lmdeploy.messages import PytorchEngineConfig, QuantPolicy
1010
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
1111
from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value
1212
from lmdeploy.utils import get_logger, is_bf16_supported
@@ -98,7 +98,7 @@ class CacheConfig:
9898
cache_max_entry_count: float = 0.8
9999
max_prefill_token_num: int = 4096
100100
enable_prefix_caching: bool = False
101-
quant_policy: Literal[0, 4, 8, 42] = 0
101+
quant_policy: QuantPolicy = QuantPolicy.NONE
102102
device_type: str = 'cuda'
103103
num_state_caches: int = None
104104
states_shapes: list[tuple] = field(default_factory=list)

lmdeploy/pytorch/engine/cache_engine.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import math
55
from collections.abc import Sequence
66
from dataclasses import dataclass
7-
from typing import Literal
87

98
import torch
109

@@ -20,6 +19,7 @@
2019
)
2120
from lmdeploy.utils import get_logger
2221

22+
from ...messages import QuantPolicy
2323
from ..config import CacheConfig, ModelConfig
2424

2525
KVCache = tuple[torch.Tensor, torch.Tensor]
@@ -140,7 +140,7 @@ def _get_key_block_shape_impl(cls,
140140
block_size: int,
141141
head_size: int,
142142
world_size: int = 1,
143-
quant_policy: Literal[0, 4, 8, 42] = 0):
143+
quant_policy: QuantPolicy = 0):
144144
"""Get single block shape."""
145145
attn_backend = get_backend()
146146
dtype = model_config.dtype
@@ -155,7 +155,8 @@ def _get_key_block_shape_impl(cls,
155155
if model_config.use_mla_fp8_cache:
156156
return (block_size, num_heads, MLA_FP8_HEAD_DIM)
157157

158-
if quant_policy == 4 or quant_policy == 42: # pack head_dim to uint8 (4-bit)
158+
# pack head_dim to uint8 (4-bit)
159+
if quant_policy == QuantPolicy.INT4 or quant_policy == QuantPolicy.TURBO_QUANT:
159160
assert head_size % 2 == 0, \
160161
f'head_size: {head_size}, quant_policy: {quant_policy}'
161162
head_size = head_size // 2
@@ -167,7 +168,7 @@ def _get_value_block_shape_impl(cls,
167168
block_size: int,
168169
head_size: int,
169170
world_size: int = 1,
170-
quant_policy: Literal[0, 4, 8, 42] = 0):
171+
quant_policy: QuantPolicy = 0):
171172
"""Get single block shape."""
172173
attn_backend = get_backend()
173174
dtype = model_config.dtype
@@ -183,11 +184,11 @@ def _get_value_block_shape_impl(cls,
183184
# flash mla shared key and value
184185
return (block_size, num_heads, 0)
185186

186-
if quant_policy == 42: # pack head_dim to uint8 (2-bit for V cache)
187+
if quant_policy == QuantPolicy.TURBO_QUANT: # pack head_dim to uint8 (2-bit for V cache)
187188
assert head_size % 4 == 0, \
188189
f'head_size: {head_size}, quant_policy: {quant_policy}'
189190
head_size = head_size // 4
190-
elif quant_policy == 4: # pack head_dim to uint8 (4-bit)
191+
elif quant_policy == QuantPolicy.INT4: # pack head_dim to uint8 (4-bit)
191192
assert head_size % 2 == 0, \
192193
f'head_size: {head_size}, quant_policy: {quant_policy}'
193194
head_size = head_size // 2
@@ -209,7 +210,7 @@ def get_k_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig,
209210
)
210211
shape = list(shape)
211212
dtype = _get_kv_cache_dtype(model_config)
212-
if cache_config.quant_policy in (4, 8, 42):
213+
if cache_config.quant_policy in (QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT):
213214
dtype = torch.uint8
214215
return CacheDesc(shape=shape, dtype=dtype)
215216

@@ -228,21 +229,21 @@ def get_v_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig,
228229
)
229230
shape = list(shape)
230231
dtype = _get_kv_cache_dtype(model_config)
231-
if cache_config.quant_policy in (4, 8, 42):
232+
if cache_config.quant_policy in (QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT):
232233
dtype = torch.uint8
233234
return CacheDesc(shape=shape, dtype=dtype)
234235

235236
@classmethod
236237
def get_quant_cache_descs(cls, k_cache_desc: CacheDesc, v_cache_desc: CacheDesc, model_config: ModelConfig,
237238
cache_config: CacheConfig):
238239
"""Get quant cache descs."""
239-
if cache_config.quant_policy == 0:
240+
if cache_config.quant_policy == QuantPolicy.NONE:
240241
return []
241242

242243
dtype = model_config.dtype
243-
# For quant_policy==42, K uses 4-bit quantization (has MSE norm and QJL norm),
244+
# For quant_policy==QuantPolicy.TURBO_QUANT, K uses 4-bit quantization (has MSE norm and QJL norm),
244245
# V uses 2-bit quantization (only has MSE norm)
245-
if cache_config.quant_policy == 42:
246+
if cache_config.quant_policy == QuantPolicy.TURBO_QUANT:
246247
key_scale_zero_shape = k_cache_desc.shape[:-1] + [2]
247248
val_scale_zero_shape = v_cache_desc.shape[:-1] + [1]
248249
else:

0 commit comments

Comments (0)