
feat: Add TurboQuant (quant_policy=42) support for KV Cache Quantization#4510

Open
windreamer wants to merge 14 commits into InternLM:main from windreamer:turboquant-integration

Conversation

@windreamer
Collaborator

@windreamer windreamer commented Apr 8, 2026

Motivation

TurboQuant is a near-optimal KV cache quantization technique from Google (ICLR 2026) that achieves 5x compression at 3-bit with 99.5% attention fidelity. This PR implements TurboQuant K4V2 (K=4bit, V=2bit mixed precision) KV cache quantization, addressing issue #4499.

Reference Implementations:

Modification

  1. Introduce QuantPolicy enum: Consolidate quantization policy definitions in a centralized enum, supporting NONE, INT4, INT8, and TURBO_QUANT.

  2. Add TurboQuant kernel: Implement Hadamard rotation kernel for TurboQuant, supporting both forward and inverse rotation operations.

  3. Update KV cache kernels: Extend fill_kv_cache, flatten_kv_cache, and pagedattention kernels to support quant_policy==42.

  4. Update attention backends: Add TurboQuant support in CUDA attention implementations for default and MLA models.

  5. Add comprehensive tests: Include unit tests for TurboQuant kernels and QuantPolicy enum.
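A minimal sketch of what such an enum might look like (member names follow the commit log in this PR; the authoritative definition lives in lmdeploy/messages.py):

```python
from enum import IntEnum

class QuantPolicy(IntEnum):
    """Sketch of the KV-cache quantization policy enum described above."""
    NONE = 0
    INT4 = 4
    INT8 = 8
    TURBO_QUANT = 42

# Integer values match the old magic numbers, so configs that already
# pass quant_policy=8 keep working unchanged.
assert QuantPolicy(42) is QuantPolicy.TURBO_QUANT
```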

Usage

from lmdeploy import pipeline, PytorchEngineConfig

pipe = pipeline(
    "internlm/internlm2_7b",
    # quant_policy=42 enables TurboQuant K4V2 (PyTorch engine only)
    backend_config=PytorchEngineConfig(quant_policy=42)
)

Related Issues

windreamer and others added 6 commits April 8, 2026 15:57
…t_policy=42)

Add support for mixed precision KV cache quantization where K uses 4-bit
and V uses 2-bit quantization. This includes:

- Update flatten_kv_cache and pagedattention kernels for quant_policy=42
- Implement 2-fold/4-fold bit packing with simplified unpack operations
- Add fill_kv_cache support for TurboQuant K4V2
- Use FWHT to generate permutation matrices for TurboQuant
- Fix scales_zeros handling for last dimension == 1

Add test file to verify quant_policy=42 works correctly with PytorchEngine
for Qwen3-8B using normalized cosine similarity for reconstruction quality.
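The 4-fold bit packing mentioned above can be illustrated with a small NumPy sketch (function names are hypothetical, not the kernel's):

```python
import numpy as np

def pack_2bit(idx: np.ndarray) -> np.ndarray:
    """Pack four 2-bit codebook indices into each output byte."""
    idx = idx.reshape(-1, 4).astype(np.uint8)
    return idx[:, 0] | (idx[:, 1] << 2) | (idx[:, 2] << 4) | (idx[:, 3] << 6)

def unpack_2bit(packed: np.ndarray) -> np.ndarray:
    """Recover the four 2-bit indices from each packed byte."""
    return np.stack([(packed >> s) & 0x3 for s in (0, 2, 4, 6)], axis=-1).reshape(-1)

vals = np.array([0, 1, 2, 3, 3, 2, 1, 0], dtype=np.uint8)
assert np.array_equal(unpack_2bit(pack_2bit(vals)), vals)
```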
- Create tests/pytorch/kernel/test_turboquant.py for TurboQuant MSE
  algorithm tests (quant/dequant roundtrip, MSE bound, reconstruction
  quality, determinism)
- Remove duplicate TurboQuant tests from test_fill_kv_cache.py
- Add TestQuantPolicy42Accuracy in test_quant_policy.py to compare
  quant_policy=0 vs quant_policy=42 (logits, token accuracy, text
  quality, logprobs)
Implement quant_policy=42 (TurboQuant) support in fill_kv_cache, flatten_kv_cache,
and pagedattention kernels. This includes QJL4 quantization for K and 2-bit MSE
quantization for V, with butterfly rotation for improved accuracy.

Co-authored-by: openhands <openhands@all-hands.dev>
Extract TurboQuant (quant_policy=42) test utilities into turboquant_utils.py
and reorganize tests into dedicated test files for better maintainability.

Co-authored-by: openhands <openhands@all-hands.dev>
- Update butterfly_rotate_inv to use rotation matrix approach instead of
  removed fwht and _get_deterministic_signs functions
- Change default dtype in _get_rotation_matrix from float16 to float32
  to match expected dtype in tests

Co-authored-by: openhands <openhands@all-hands.dev>
When quant_policy==42, the flattened K/V are in rotated domain.
Rotate Q to match, and inverse-rotate output afterwards.

Also fix flatten_kv_cache to output float16 when quant_policy==42.
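The reason this is safe: the Hadamard rotation is orthogonal, so dot products — and hence attention logits — are preserved when Q is rotated into the same domain as the cached K. A quick NumPy check of that identity (illustrative; not the kernel code):

```python
import numpy as np

def hadamard(dim: int) -> np.ndarray:
    """Normalized Walsh-Hadamard matrix via the Sylvester construction;
    dim must be a power of two."""
    H = np.ones((1, 1))
    while H.shape[0] < dim:
        H = np.block([[H, H], [H, -H]])
    return H / np.sqrt(dim)

rng = np.random.default_rng(0)
q, k = rng.standard_normal((2, 64))
H = hadamard(64)
# Rotating both q and k leaves the dot product intact, so attention
# scores computed in the rotated domain match the original ones.
assert np.isclose(q @ k, (q @ H) @ (k @ H))
```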
@windreamer windreamer force-pushed the turboquant-integration branch 4 times, most recently from a8b2fd7 to 2d03817 Compare April 8, 2026 10:04
…ties

- Add QuantPolicy IntEnum to replace magic numbers (0, 4, 8, 42) for KV cache
  quantization policies: NONE, INT4, INT8, TURBO_QUANT
- Update TurbomindEngineConfig and PytorchEngineConfig to use QuantPolicy type
- Extract TurboQuant utilities (Hadamard rotation, Lloyd-Max codebook) from
  fill_kv_cache.py into new dedicated module turbo_quant.py
- Rename butterfly_rotate/butterfly_rotate_inv to hadamard_rotate/
  hadamard_rotate_inv for naming accuracy (the transform uses Hadamard matrix)
- Update all call sites across attention kernels, cache engine, and tests
- Update test fixtures and assertions to use QuantPolicy constants

This improves type safety, code readability, and maintains backward
compatibility through enum integer values matching previous magic numbers.
@windreamer windreamer force-pushed the turboquant-integration branch from 2d03817 to 62c1db2 Compare April 8, 2026 12:27
@grimoire
Collaborator

grimoire commented Apr 8, 2026

Attention might be dispatched to the fa3 implementation.

…nsform

- Use fast_hadamard_transform library when available (2.4x faster)
- Fallback to matmul-based Walsh-Hadamard if not installed
- Fuse dtype conversion inside hadamard_rotate/hadamard_rotate_inv
- Use standard Walsh-Hadamard matrix (same as fast_hadamard_transform)
- Add logging to notify when falling back to matmul implementation

Co-authored-by: openhands <openhands@all-hands.dev>
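The optional-dependency pattern described above might look like the following sketch (the flag name and logger name are illustrative; the package and function names are as published by the fast-hadamard-transform project, not verified against this PR's exact import):

```python
import logging

logger = logging.getLogger('turbo_quant')

try:
    # Optional CUDA kernel dependency (~2.4x faster per the commit message).
    from fast_hadamard_transform import hadamard_transform
    HAVE_FHT = True
except ImportError:
    HAVE_FHT = False
    logger.info('fast_hadamard_transform not installed; '
                'falling back to matmul-based Walsh-Hadamard')
```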
@windreamer windreamer marked this pull request as ready for review April 9, 2026 04:30
Contributor

Copilot AI left a comment


Pull request overview

This PR adds TurboQuant KV-cache quantization support (QuantPolicy.TURBO_QUANT / quant_policy=42) to the PyTorch backend, including Hadamard-rotation + Lloyd–Max codebooks, and extends the CUDA/Triton KV-cache and attention kernels plus tests to validate correctness.

Changes:

  • Introduce QuantPolicy enum (including TURBO_QUANT=42) and plumb it through engine/config/metadata.
  • Add TurboQuant utilities (hadamard_rotate, hadamard_rotate_inv, precomputed Lloyd–Max codebooks) and implement TurboQuant support in fill_kv_cache, flatten_kv_cache, and paged attention kernels/backends.
  • Add kernel-level and pipeline-level tests for TurboQuant behavior and accuracy.

Reviewed changes

Copilot reviewed 18 out of 19 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
tests/test_lmdeploy/test_quant_policy.py Adds end-to-end pipeline tests comparing TurboQuant vs non-quantized baseline.
tests/pytorch/kernel/turboquant_utils.py Adds shared Python reference quant/dequant helpers + metrics for TurboQuant kernel tests.
tests/pytorch/kernel/test_turboquant.py Adds kernel-level reference tests for TurboQuant MSE and QJL4 behavior.
tests/pytorch/kernel/test_paged_attention.py Extends paged-attention tests for TurboQuant cache formats and output comparisons.
tests/pytorch/kernel/test_flatten_kv_cache.py Adds TurboQuant flatten-cache coverage and uses shared test utilities.
tests/pytorch/kernel/test_fill_kv_cache.py Adds TurboQuant fill-cache tests and reference comparisons.
lmdeploy/pytorch/model_inputs.py Updates step context typing to use QuantPolicy.
lmdeploy/pytorch/kernels/cuda/turbo_quant.py New TurboQuant helper module (Hadamard rotation + Lloyd–Max codebooks).
lmdeploy/pytorch/kernels/cuda/pagedattention.py Adds TurboQuant support to decoding paged-attention kernel path.
lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py Adds TurboQuant dequantization path in KV-cache flattening kernel.
lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py Adds TurboQuant quantization path in KV-cache fill kernel (including rotation).
lmdeploy/pytorch/engine/cache_engine.py Adjusts KV-cache shapes/meta layouts for TurboQuant (K meta=2, V meta=1).
lmdeploy/pytorch/config.py Updates cache config typing/default to use QuantPolicy.
lmdeploy/pytorch/backends/cuda/attention/mla.py Adjusts quant-policy assertion to use QuantPolicy.NONE.
lmdeploy/pytorch/backends/cuda/attention/fa3.py Adds TurboQuant handling for FA3 prefill and explicit unsupported speculative decoding.
lmdeploy/pytorch/backends/cuda/attention/default.py Adds TurboQuant rotation/inverse-rotation around prefill attention.
lmdeploy/pytorch/backends/attention.py Updates attention metadata typing to use QuantPolicy.
lmdeploy/messages.py Defines QuantPolicy and expands config validation to accept TurboQuant.


@windreamer windreamer force-pushed the turboquant-integration branch from 40b9fb6 to 21a7d08 Compare April 9, 2026 04:59
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@windreamer windreamer force-pushed the turboquant-integration branch from 21a7d08 to 1255c52 Compare April 9, 2026 05:30
…ine centroid calculation

- Replace memory-loaded Lloyd-Max codebook with pure register-based centroid
  calculation in PagedAttention kernel (quant_policy==42)
- Add _k4v2_k_centroid and _k4v2_v_centroid triton JIT functions for inline
  centroid lookup, reducing memory access and improving performance
- Remove turbo_quant flag and codebook pointer parameters from kernel
- Fix conditional branch bug in turbo_quant.py (else -> elif bits == 4)
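As a rough illustration of the codebook the kernel inlines, here is a minimal NumPy sketch of a 2-bit Lloyd-Max quantizer. The centroid values are the commonly cited optima for a unit Gaussian and are assumptions here, not necessarily the kernel's constants:

```python
import numpy as np

# Commonly cited Lloyd-Max output levels for a unit Gaussian at 2 bits
# (4 levels). Illustrative values only.
CENTROIDS_2BIT = np.array([-1.510, -0.4528, 0.4528, 1.510], dtype=np.float32)

def quantize_2bit(x: np.ndarray) -> np.ndarray:
    """Nearest-centroid index (0..3) per element."""
    return np.abs(x[..., None] - CENTROIDS_2BIT).argmin(axis=-1).astype(np.uint8)

def dequantize_2bit(idx: np.ndarray) -> np.ndarray:
    """Codebook lookup; the Triton kernel derives the centroid inline
    from the two index bits instead of loading a table from memory."""
    return CENTROIDS_2BIT[idx]

x = np.array([-2.0, -0.3, 0.1, 1.0], dtype=np.float32)
recon = dequantize_2bit(quantize_2bit(x))
```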
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 18 out of 19 changed files in this pull request and generated 8 comments.



- Fix _unpack_indices in turboquant_utils.py to properly unpack nibble-packed
  data for nbits=4 (was returning packed data without unpacking)
- Fix quant_turboquant_mse to only support nbits=2 (V cache quantization)
- Refactor test_turboquant.py to import shared utilities from turboquant_utils.py
  instead of maintaining duplicate implementations (~180 lines removed)
- Fix raw integer 42 in test_fill_kv_cache.py to use QuantPolicy.TURBO_QUANT
- Remove unused 4-bit Lloyd-Max codebook from get_lloyd_max_codebook()
  (production only uses 2-bit for V and 3-bit for K with QJL4)
- Rename _get_hadamard_matrix to get_hadamard_matrix (expose as public API)
- Update comments to reflect K=3bit QJL4, V=2bit mixed precision

Co-authored-by: openhands <openhands@all-hands.dev>
@windreamer windreamer force-pushed the turboquant-integration branch from 768a62c to dd62e4b Compare April 9, 2026 08:52
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 18 out of 19 changed files in this pull request and generated 6 comments.



Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 18 out of 19 changed files in this pull request and generated 5 comments.



Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 18 out of 19 changed files in this pull request and generated 8 comments.



…nels

Replace Python type annotations (x: tl.constexpr = value) with proper
Triton constexpr instantiation (x = tl.constexpr(value)) to allow
Triton kernels to access these global variables.

Also simplify fill_kv_cache by removing redundant k_quant_policy,
v_quant_policy parameters, using quant_policy + is_value instead
and vectorize quant42 quantization in _make_blocked_cache_quant42

Co-authored-by: openhands <openhands@all-hands.dev>
@windreamer windreamer force-pushed the turboquant-integration branch from 5553de8 to 1824242 Compare April 10, 2026 02:20
This optimization fuses the split-K reduction and inverse Hadamard transform
into a single Triton kernel for TURBO_QUANT, reducing kernel launch overhead
and eliminating extra memory allocation by reusing the acc buffer as scratch
space for the butterfly operations.

- Add _bar_sync() for CTA-level synchronization
- Add _fused_reduce_hadamard_kernel() that performs both operations
- Remove separate hadamard_rotate_inv call in flash_attn_with_kvcache
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 18 out of 19 changed files in this pull request and generated 2 comments.



@windreamer
Collaborator Author

TurboQuant (K4V2) on H200

Setup

  • GPU: H200
  • Model: Qwen3-30B-A3B-Base
  • Dataset: ShareGPT
  • Backend: pytorch
  • Concurrency: 64
  • Requests: 5000

Key point

TurboQuant K4V2 (quant_policy=42) provides effective 3-bit KV cache, i.e. ~5x KV cache memory reduction vs. the non-quantized baseline (quant_policy=0).
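Back-of-envelope for the ~5x figure (ignoring scale/zero-point metadata, which adds a small overhead):

```python
# K4V2 averages (4 + 2) / 2 = 3 bits per element against a 16-bit
# (fp16/bf16) baseline, hence "~5x" rather than the raw 5.33x.
k_bits, v_bits, base_bits = 4, 2, 16
effective_bits = (k_bits + v_bits) / 2
compression = base_bits / effective_bits
print(f'{compression:.2f}x')  # 5.33x
```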

Performance cost

Under the same benchmark setting, the end-to-end overhead is moderate:

  • Input throughput: 2368.8 -> 2195.8 tok/s (-7.3%)
  • Output throughput: 2186.7 -> 2027.0 tok/s (-7.3%)
  • Request throughput: 10.74 -> 9.96 req/s (-7.3%)
  • Mean E2E latency: 5.888s -> 6.348s (+7.8%)
  • Mean TTFT: 1.139s -> 1.235s (+8.4%)
  • Mean TPOT: 0.024s -> 0.026s (+8.3%)
  • Mean ITL: 0.059s -> 0.059s (roughly unchanged)

Takeaway

On H200, TurboQuant K4V2 achieves ~5x KV cache reduction with about 7%–8% end-to-end performance overhead, which looks like a reasonable trade-off for memory-bound serving scenarios.

@hicofeng

Is there a plan to add support for the TurboMind engine, perhaps with a more noticeable performance overhead compared to the PyTorch engine?

@windreamer
Collaborator Author

Is there a plan to add support for the TurboMind engine, perhaps with a more noticeable performance overhead compared to the PyTorch engine?

We are still discussing internally whether we should introduce this feature. Currently we are not quite sure the quantization quality is good enough for practical use.

As you can see, we currently use TurboQuant 4-bit (3-bit MSE + 1-bit QJL) for the key cache and 2-bit MSE for the value cache. However, a lot of discussion in various communities suggests that 4-bit MSE is better for the key cache, as variance errors are more harmful than biases.

I believe there is still a lot of work to be done and evaluated before we can plan a TurboMind integration.

Could you elaborate on why you think TurboMind integration is needed?



Development

Successfully merging this pull request may close these issues.

Add TurboQuant Support for KV Cache Quantization

4 participants