Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@

logger = get_logger('lmdeploy')


@enum.unique
class QuantPolicy(enum.IntEnum):
    """Quantization policy for the KV cache.

    An ``IntEnum`` so existing call sites that pass the raw integers
    (0/4/8/42) keep working unchanged.

    Members:
        NONE: no KV-cache quantization.
        INT4: 4-bit KV cache (two elements packed per ``uint8``).
        INT8: 8-bit KV cache.
        TURBO_QUANT: TurboQuant scheme — K stored 4-bit (QJL4),
            V stored 2-bit (MSE); K/V live in a Hadamard-rotated domain.
    """
    NONE = 0
    INT4 = 4
    INT8 = 8
    # 42 is an arbitrary sentinel outside the bit-width-named values;
    # @enum.unique guards against a future member accidentally aliasing it.
    TURBO_QUANT = 42

LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a tensor of input_ids, the logits
tensor for the next token, and returns a modified tensor of logits to sample
Expand Down Expand Up @@ -298,7 +306,8 @@ def __post_init__(self):
assert self.dtype in ['auto', 'float16', 'bfloat16']
assert self.tp >= 1, 'tp must be a positive integer'
assert self.cache_max_entry_count > 0, 'invalid cache_max_entry_count'
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \
'invalid quant_policy'
assert self.rope_scaling_factor >= 0, 'invalid rope_scaling_factor'
assert self.max_prefill_token_num >= 0, \
'invalid max_prefill_token_num'
Expand Down Expand Up @@ -403,7 +412,7 @@ class PytorchEngineConfig:
custom_module_map: dict[str, str] = None
download_dir: str = None
revision: str = None
quant_policy: Literal[0, 4, 8] = 0
quant_policy: QuantPolicy = QuantPolicy.NONE
distributed_executor_backend: str = None
empty_init: bool = False
enable_microbatch: bool = False
Expand Down Expand Up @@ -440,7 +449,8 @@ def __post_init__(self):
assert self.max_prefill_token_num >= 0, \
'invalid max_prefill_token_num'
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.quant_policy in (QuantPolicy.NONE, QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT), \
'invalid quant_policy'
assert self.device_type in ['cuda', 'ascend', 'maca', 'camb'], (f'invalid device_type: {self.device_type}')
assert self.block_size >= 16 and (self.block_size & (self.block_size - 1)) == 0, \
f'block_size must be >= 16 and a power of 2, but got {self.block_size}'
Expand Down
6 changes: 4 additions & 2 deletions lmdeploy/pytorch/backends/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import Generic, Literal, TypeVar
from typing import Generic, TypeVar

import torch

from lmdeploy.messages import QuantPolicy


@dataclass
class AttentionMetadata:
Expand All @@ -18,7 +20,7 @@ class AttentionMetadata:
fill_seqlens: torch.Tensor = None
cu_seqlens_q: torch.Tensor = None
cu_seqlens_k: torch.Tensor = None
quant_policy: Literal[0, 4, 8] = 0
quant_policy: QuantPolicy = 0


T = TypeVar('T', bound=AttentionMetadata)
Expand Down
18 changes: 16 additions & 2 deletions lmdeploy/pytorch/backends/cuda/attention/default.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Literal

import torch

from lmdeploy.messages import QuantPolicy
from lmdeploy.pytorch.backends.attention import AttentionImpl, AttentionMetadata
from lmdeploy.utils import get_logger

Expand Down Expand Up @@ -40,7 +40,7 @@ class TritonAttentionMetadata(AttentionMetadata):
q_seqlens: torch.Tensor = None
kv_start_loc: torch.Tensor = None
kv_seqlens: torch.Tensor = None
quant_policy: Literal[0, 4, 8] = 0
quant_policy: QuantPolicy = 0
kv_flatten_size: int = None
# flash mla
tile_scheduler_metadata: torch.Tensor = None
Expand Down Expand Up @@ -279,6 +279,15 @@ def _forward_prefill(
flatten_kv_layout=kv_layout,
)

# For quant_policy==QuantPolicy.TURBO_QUANT, flattened K/V are in rotated domain.
# Rotate Q to match, and inverse-rotate output afterwards.
if quant_policy == QuantPolicy.TURBO_QUANT:
from lmdeploy.pytorch.kernels.cuda.turbo_quant import (
hadamard_rotate,
hadamard_rotate_inv,
)
query = hadamard_rotate(query)

attn_output = self.flash_attention_fwd(
query,
flatten_k,
Expand All @@ -297,6 +306,11 @@ def _forward_prefill(
block_sparse_size=self.block_sparse_size,
kv_layout=kv_layout,
)

# Inverse-rotate output back to original domain
if quant_policy == QuantPolicy.TURBO_QUANT:
attn_output = hadamard_rotate_inv(attn_output)

return attn_output

def forward(
Expand Down
28 changes: 28 additions & 0 deletions lmdeploy/pytorch/backends/cuda/attention/fa3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from lmdeploy.messages import QuantPolicy
from lmdeploy.utils import get_logger

from .default import TritonAttentionImpl, TritonAttentionMetadata
Expand Down Expand Up @@ -102,6 +103,19 @@ def _decoding_speculative(
Returns:
Attention output tensor.
"""
quant_policy = attn_metadata.quant_policy

# TurboQuant stores packed uint8 data in cache, which FA3's native
# flash_attn_with_kvcache cannot dequantize directly.
if quant_policy == QuantPolicy.TURBO_QUANT:
raise NotImplementedError(
'quant_policy=QuantPolicy.TURBO_QUANT is not supported with '
'FA3 speculative decoding (max_q_seqlen > 1). '
'FA3 speculative decoding accesses raw KV cache directly '
'and cannot dequantize TurboQuant packed data. '
'Use standard decoding (max_q_seqlen=1).'
)

block_offsets = attn_metadata.block_offsets
sliding_window = self._normalize_sliding_window(self.sliding_window)

Expand Down Expand Up @@ -257,6 +271,15 @@ def _forward_prefill(

sliding_window = self._normalize_sliding_window(self.sliding_window)

# For TurboQuant, flattened K/V are in rotated domain.
# Rotate Q to match, and inverse-rotate output afterwards.
if quant_policy == QuantPolicy.TURBO_QUANT:
from lmdeploy.pytorch.kernels.cuda.turbo_quant import (
hadamard_rotate,
hadamard_rotate_inv,
)
query = hadamard_rotate(query)

attn_output = self.flash_attn_varlen_func_v3(
q=query,
k=flatten_k,
Expand All @@ -270,6 +293,11 @@ def _forward_prefill(
window_size=sliding_window,
softcap=self.logit_softcapping,
)

# Inverse-rotate output back to original domain
if quant_policy == QuantPolicy.TURBO_QUANT:
attn_output = hadamard_rotate_inv(attn_output)

return attn_output

def forward(
Expand Down
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/backends/cuda/attention/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

from lmdeploy.messages import QuantPolicy
from lmdeploy.utils import get_logger

from .default import TritonAttentionImpl, TritonAttentionMetadata
Expand Down Expand Up @@ -405,7 +406,7 @@ def _fill_kv_cache_impl(self,
block_offsets = attn_metadata.block_offsets
kv_seqlens = attn_metadata.kv_seqlens
quant_policy = attn_metadata.quant_policy
assert quant_policy == 0
assert quant_policy == QuantPolicy.NONE

# fill seqlen args
fill_seqlens, fill_max_q_seqlen, fill_q_start_loc = self._get_fill_meta(
Expand Down
6 changes: 3 additions & 3 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import enum
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Literal
from typing import Any

import torch

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.messages import PytorchEngineConfig, QuantPolicy
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
from lmdeploy.pytorch.utils import maybe_register_config_serialize_by_value
from lmdeploy.utils import get_logger, is_bf16_supported
Expand Down Expand Up @@ -98,7 +98,7 @@ class CacheConfig:
cache_max_entry_count: float = 0.8
max_prefill_token_num: int = 4096
enable_prefix_caching: bool = False
quant_policy: Literal[0, 4, 8] = 0
quant_policy: QuantPolicy = QuantPolicy.NONE
device_type: str = 'cuda'
num_state_caches: int = None
states_shapes: list[tuple] = field(default_factory=list)
Expand Down
31 changes: 21 additions & 10 deletions lmdeploy/pytorch/engine/cache_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import math
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal

import torch

Expand All @@ -20,6 +19,7 @@
)
from lmdeploy.utils import get_logger

from ...messages import QuantPolicy
from ..config import CacheConfig, ModelConfig

KVCache = tuple[torch.Tensor, torch.Tensor]
Expand Down Expand Up @@ -140,7 +140,7 @@ def _get_key_block_shape_impl(cls,
block_size: int,
head_size: int,
world_size: int = 1,
quant_policy: Literal[0, 4, 8] = 0):
quant_policy: QuantPolicy = 0):
"""Get single block shape."""
attn_backend = get_backend()
dtype = model_config.dtype
Expand All @@ -155,7 +155,8 @@ def _get_key_block_shape_impl(cls,
if model_config.use_mla_fp8_cache:
return (block_size, num_heads, MLA_FP8_HEAD_DIM)

if quant_policy == 4: # pack head_dim to uint8
# pack head_dim to uint8 (4-bit)
if quant_policy == QuantPolicy.INT4 or quant_policy == QuantPolicy.TURBO_QUANT:
assert head_size % 2 == 0, \
f'head_size: {head_size}, quant_policy: {quant_policy}'
head_size = head_size // 2
Expand All @@ -167,7 +168,7 @@ def _get_value_block_shape_impl(cls,
block_size: int,
head_size: int,
world_size: int = 1,
quant_policy: Literal[0, 4, 8] = 0):
quant_policy: QuantPolicy = 0):
"""Get single block shape."""
attn_backend = get_backend()
dtype = model_config.dtype
Expand All @@ -183,7 +184,11 @@ def _get_value_block_shape_impl(cls,
# flash mla shared key and value
return (block_size, num_heads, 0)

if quant_policy == 4: # pack head_dim to uint8
if quant_policy == QuantPolicy.TURBO_QUANT: # pack head_dim to uint8 (2-bit for V cache)
assert head_size % 4 == 0, \
f'head_size: {head_size}, quant_policy: {quant_policy}'
head_size = head_size // 4
elif quant_policy == QuantPolicy.INT4: # pack head_dim to uint8 (4-bit)
assert head_size % 2 == 0, \
f'head_size: {head_size}, quant_policy: {quant_policy}'
head_size = head_size // 2
Expand All @@ -205,7 +210,7 @@ def get_k_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig,
)
shape = list(shape)
dtype = _get_kv_cache_dtype(model_config)
if cache_config.quant_policy in (4, 8):
if cache_config.quant_policy in (QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT):
dtype = torch.uint8
return CacheDesc(shape=shape, dtype=dtype)

Expand All @@ -224,20 +229,26 @@ def get_v_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig,
)
shape = list(shape)
dtype = _get_kv_cache_dtype(model_config)
if cache_config.quant_policy in (4, 8):
if cache_config.quant_policy in (QuantPolicy.INT4, QuantPolicy.INT8, QuantPolicy.TURBO_QUANT):
dtype = torch.uint8
return CacheDesc(shape=shape, dtype=dtype)

@classmethod
def get_quant_cache_descs(cls, k_cache_desc: CacheDesc, v_cache_desc: CacheDesc, model_config: ModelConfig,
                          cache_config: CacheConfig):
    """Build cache descriptors for the per-line quantization metadata.

    Args:
        k_cache_desc: descriptor of the (already quantized) key cache.
        v_cache_desc: descriptor of the (already quantized) value cache.
        model_config: provides the dtype used to store scale/zero values.
        cache_config: its ``quant_policy`` selects the metadata layout.

    Returns:
        ``[]`` when quantization is disabled, otherwise a two-element list
        ``[key_scale_zero_desc, val_scale_zero_desc]``.
    """
    if cache_config.quant_policy == QuantPolicy.NONE:
        return []

    dtype = model_config.dtype
    # Metadata elements appended per cache line:
    #   - INT4/INT8: K and V each keep [scale, zero] -> 2 elements.
    #   - TURBO_QUANT: K is 4-bit (MSE norm + QJL norm -> 2 elements),
    #     V is 2-bit (MSE norm only -> 1 element).
    # The K metadata shape is the same for every policy, so hoist it.
    key_scale_zero_shape = k_cache_desc.shape[:-1] + [2]
    if cache_config.quant_policy == QuantPolicy.TURBO_QUANT:
        val_scale_zero_shape = v_cache_desc.shape[:-1] + [1]
    else:
        val_scale_zero_shape = v_cache_desc.shape[:-1] + [2]
    key_scale_zero_desc = CacheDesc(shape=key_scale_zero_shape, dtype=dtype)
    val_scale_zero_desc = CacheDesc(shape=val_scale_zero_shape, dtype=dtype)
    return [key_scale_zero_desc, val_scale_zero_desc]
Expand Down
Loading
Loading