
make fp8 model quantized by llm-compressor can be inferenced in turbomind#4509

Open
43758726 wants to merge 2 commits intoInternLM:mainfrom
43758726:add/llm-compressor-fp8-inference

Conversation

Collaborator

@43758726 43758726 commented Apr 8, 2026


Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and more likely to receive feedback. If you do not understand some items, don't worry: just open the pull request and seek help from the maintainers.

Motivation

Enable FP8 models quantized by llm-compressor to run inference successfully in the turbomind engine.

Modification

lmdeploy/lmdeploy/turbomind/deploy/converter.py: Add configuration handling for llm-compressor fp8 models.
lmdeploy/lmdeploy/turbomind/deploy/policy.py: Add a branch for llm-compressor fp8 models in the get_input_policy function.
lmdeploy/lmdeploy/turbomind/deploy/parameter.py: Add a WeightScale class for llm-compressor fp8 models.
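The converter-side mapping described above can be sketched as follows. This is an illustrative sketch, not the PR's verbatim code; the function name `map_compressed_tensors_format` is hypothetical, while the format names and the target `awq`/`fp8` pathways follow the diff quoted later in the review.

```python
def map_compressed_tensors_format(quantized_format, group_size=0):
    """Sketch: map a compressed-tensors sub-format to turbomind's
    output config (names illustrative)."""
    if quantized_format == 'pack-quantized':
        # int4 pack-quantized weights reuse the existing AWQ pathway
        return {'weight_type': 'int4', 'model_format': 'awq',
                'dtype': 'float16', 'group_size': group_size}
    if quantized_format == 'float-quantized':
        # fp8 weights map to the FP8 pathway; group_size defaults to 128
        return {'weight_type': 'fp8', 'model_format': 'fp8',
                'group_size': 128 if group_size == 0 else group_size}
    raise ValueError(f'unsupported quantized_format: {quantized_format!r}')
```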

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

Contributor

Copilot AI left a comment


Pull request overview

This PR extends Turbomind’s deploy/conversion pipeline to support FP8 models produced by llm-compressor under the compressed-tensors quantization config, enabling successful inference in the Turbomind engine.

Changes:

  • Add compressed-tensors format branching to map pack-quantized → AWQ(int4) and float-quantized → FP8 pathways.
  • Extend input tensor processing policy selection to handle compressed-tensors sub-formats.
  • Add a new WeightScale parameter handler intended for llm-compressor FP8 scale tensors.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

File Description
lmdeploy/turbomind/deploy/policy.py Selects different input tensor processing functions depending on compressed-tensors quantized sub-format.
lmdeploy/turbomind/deploy/parameter.py Adds WeightScale parameter export logic and wires it into get_params().
lmdeploy/turbomind/deploy/converter.py Adds compressed-tensors config handling, validates formats, and maps to existing AWQ/FP8 output config paths.


Comment on lines 34 to 48
@@ -42,6 +43,8 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s
['hf', 'awq', 'gptq']
dtype (str): the data type of the model's weights and activations
group_size (int): the size of group used by awq model
quantized_format (str): the quantized format of compressed-tensors model,
which can be one of ['pack-quantized', 'float-quantized']
"""

Copilot AI Apr 8, 2026


get_output_model_registered_name_and_config now requires quantized_format, but the repo’s tests and existing callers invoke it without that argument (e.g. tests/test_lmdeploy/test_turbomind/test_converter.py::test_torch_dtype_fallback). This will raise a TypeError at runtime. Consider making quantized_format optional with a default (e.g. None) and updating the docstring/type hint accordingly, only validating it when model_format == 'compressed-tensors'.
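The backward-compatible default the comment asks for could look like the sketch below. The helper name `resolve_quantized_format` is hypothetical; it only illustrates the suggested pattern of defaulting to `None` and validating solely for `compressed-tensors` models.

```python
def resolve_quantized_format(model_format, quantized_format=None):
    """Sketch of the reviewer's suggestion: quantized_format defaults
    to None so existing callers that omit it keep working, and it is
    only validated for compressed-tensors models."""
    if model_format != 'compressed-tensors':
        return None  # other formats ignore the argument entirely
    if quantized_format not in ('pack-quantized', 'float-quantized'):
        raise ValueError(
            f"compressed-tensors requires quantized_format in "
            f"['pack-quantized', 'float-quantized'], got {quantized_format!r}")
    return quantized_format
```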

Comment on lines +81 to +87
class WeightScale(Parameter):
KEYS = '.weight_scale', '.weight'

# TODO: flag any operations crossing the quant blocks as illegal
def __call__(self, f, g, i):
f(i, g('weight_scale'), 'scales', to_float, apply_gs=['w1', 'w3', 'w2'])
f(i, g('weight'), 'weight', identity)

Copilot AI Apr 8, 2026


WeightScale.take() triggers on any key ending with .weight_scale, which also matches the existing CompressedWeight path. For compressed-tensors pack-quantized models this can cause WeightScale.__call__ to request g('weight') (a .weight tensor) that doesn’t exist, leading to a KeyError during export. Please tighten the selection logic so WeightScale only applies when both .weight_scale and .weight are present (and/or when .weight_packed is absent), or override take() to enforce that invariant.
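The invariant the comment asks `take()` to enforce can be sketched as a standalone predicate. The real `Parameter`/`take()` API in lmdeploy may differ; this only illustrates the key-matching rule.

```python
def weight_scale_applies(tensor_keys):
    """Sketch of the invariant suggested for WeightScale.take(): only
    claim a module when both .weight_scale and .weight are present and
    .weight_packed is absent, so pack-quantized modules stay with the
    existing CompressedWeight path."""
    has_scale = any(k.endswith('.weight_scale') for k in tensor_keys)
    has_weight = any(k.endswith('.weight') for k in tensor_keys)
    has_packed = any(k.endswith('.weight_packed') for k in tensor_keys)
    return has_scale and has_weight and not has_packed
```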

if quantized_format == 'pack-quantized':
return process_compressed_packed_tensor
elif quantized_format == 'float-quantized':
return process_fp8

Copilot AI Apr 8, 2026


get_input_policy() has a model_format == 'compressed-tensors' branch that doesn’t return anything if quantized_format is missing/unknown, so the function returns None and later code will fail when calling the policy. It would be safer to either return a default policy or raise a clear ValueError for unsupported quantized_format values.

Suggested change
      return process_fp8
+     raise ValueError(f'Unsupported quantized_format for compressed-tensors: {quantized_format!r}')
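With that suggestion applied, the whole selection could read as below. The processing functions are stand-in stubs so the sketch runs standalone; in lmdeploy they are the real functions in turbomind/deploy/policy.py.

```python
# Stand-in stubs for the real processing policies.
def process_hf(x): return ('hf', x)
def process_compressed_packed_tensor(x): return ('pack-quantized', x)
def process_fp8(x): return ('fp8', x)

def get_input_policy(model_format, quantized_format=None):
    """Sketch of the fixed selection: every path either returns a
    policy or raises, so the caller can never receive None."""
    if model_format == 'compressed-tensors':
        if quantized_format == 'pack-quantized':
            return process_compressed_packed_tensor
        if quantized_format == 'float-quantized':
            return process_fp8
        raise ValueError(f'Unsupported quantized_format for '
                         f'compressed-tensors: {quantized_format!r}')
    return process_hf  # default policy for non-compressed formats
```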

assert _format == 'pack-quantized', ('compressed-tennsors only supports pack-quantized format, '
f'but got {_format}')
assert _format in ['pack-quantized', 'float-quantized'
], ('compressed-tennsors only supports pack-quantized/float-quantized format, '
Copy link

Copilot AI Apr 8, 2026


Typo in assertion message: 'compressed-tennsors' → 'compressed-tensors'. Keeping error messages correct helps users debug config issues quickly.

Suggested change
-     ], ('compressed-tennsors only supports pack-quantized/float-quantized format, '
+     ], ('compressed-tensors only supports pack-quantized/float-quantized format, '

Comment on lines 80 to 92
if model_format in ['awq', 'gptq', 'compressed-tensors']:
weight_type = 'int4'
dtype = 'float16' # force float16 for int4 quantized weights
if model_format in ['awq', 'gptq']:
weight_type = 'int4'
dtype = 'float16' # force float16 for int4 quantized weights
elif model_format == 'compressed-tensors':
if quantized_format == 'pack-quantized':
weight_type = 'int4'
model_format = 'awq'
dtype = 'float16' # force float16 for int4 quantized weights
elif quantized_format == 'float-quantized':
weight_type = 'fp8'
model_format = 'fp8'
group_size = 128 if group_size == 0 else group_size

Copilot AI Apr 8, 2026


New logic for compressed-tensors (pack-quantized vs float-quantized) changes both config mapping (compressed-tensorsawq/fp8) and validation of _num_bits/_type, but there are no unit tests covering these new branches. Since tests/test_lmdeploy/test_turbomind/test_converter.py already exercises this module, please add targeted tests that build a minimal quantization_config dict for each format and assert the selected weight_type/model_format and policy behavior.

Comment on lines +43 to +47
['hf', 'awq', 'gptq']
dtype (str): the data type of the model's weights and activations
group_size (int): the size of group used by awq model
quantized_format (str): the quantized format of compressed-tensors model,
which can be one of ['pack-quantized', 'float-quantized']

Copilot AI Apr 8, 2026


Docstring for get_output_model_registered_name_and_config() still says model_format is one of ['hf', 'awq', 'gptq'], but the function now handles fp8, mxfp4, and compressed-tensors as well. Please update the docstring to reflect the supported values (and clarify how compressed-tensors maps to awq/fp8).

Suggested change
-         ['hf', 'awq', 'gptq']
+         ['hf', 'awq', 'gptq', 'fp8', 'mxfp4', 'compressed-tensors'].
+         For 'compressed-tensors', `quantized_format` determines the target
+         registration path: 'pack-quantized' maps to the AWQ path and
+         'float-quantized' maps to the FP8 path.
      dtype (str): the data type of the model's weights and activations
      group_size (int): the size of group used by awq model
-     quantized_format (str): the quantized format of compressed-tensors model,
-         which can be one of ['pack-quantized', 'float-quantized']
+     quantized_format (str): the quantized format of a
+         compressed-tensors model, which can be one of
+         ['pack-quantized', 'float-quantized']

@lvhan028 lvhan028 added the enhancement New feature or request label Apr 8, 2026
Comment on lines +216 to +220
assert (_num_bits == 4 and _type == 'int') or (_num_bits == 8 and _type == 'float'), (
'pack-quantized requires 4-bit int, '
f'but got {_num_bits}-bit {_type}. '
'or float-quantized requires 8-bit float, '
f'but got {_num_bits}-bit {_type}')
Collaborator


"pack-quantized requires int4 or fp8, but got type {_type} and {_num_bits} bits"
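The reviewer's combined message could be expressed as a single assertion, sketched below with the quantity checks wrapped in a hypothetical helper so it runs standalone.

```python
def validate_quant_bits(_num_bits, _type):
    """Sketch of the single combined assertion message suggested above."""
    assert (_num_bits, _type) in [(4, 'int'), (8, 'float')], (
        f'compressed-tensors supports int4 (pack-quantized) or fp8 '
        f'(float-quantized), but got type {_type} and {_num_bits} bits')
```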


Labels

enhancement New feature or request


3 participants