diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 05b1ba526f..a4afeff552 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -65,7 +65,8 @@ def get_input_model_registered_name(model_path: str, model_format: str): return register_name -def get_output_model_registered_name_and_config(model_path: str, model_format: str, dtype: str, group_size: int): +def get_output_model_registered_name_and_config(model_path: str, model_format: str, dtype: str, group_size: int, + quantized_format: str): """Get the registered name of the turbomind model and its configuration according to the input model path, format and user-input config. The name will be used to access the OUTPUT_MODELS registry. @@ -76,6 +77,8 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s ['hf', 'awq', 'gptq', 'compressed-tensors', 'fp8', 'mxfp4'] dtype (str): the data type of the model's weights and activations group_size (int): the quantization group size used by grouped formats + quantized_format (str): the quantized format of compressed-tensors model, + which can be one of ['pack-quantized', 'float-quantized'] """ register_name = 'tm' @@ -111,12 +114,17 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s group_size = _validate_quant_group_size(model_format, group_size) if model_format in ['awq', 'gptq', 'compressed-tensors']: - weight_type = 'int4' - dtype = 'float16' # force float16 for int4 quantized weights - if model_format == 'compressed-tensors': - # TurboMind reuses the AWQ int4 export path for pack-quantized - # compressed-tensors weights after the format-specific checks above. 
- model_format = 'awq' + if model_format in ['awq', 'gptq']: + weight_type = 'int4' + dtype = 'float16' # force float16 for int4 quantized weights + elif model_format == 'compressed-tensors': + if quantized_format == 'pack-quantized': + weight_type = 'int4' + model_format = 'awq' + dtype = 'float16' # force float16 for int4 quantized weights + elif quantized_format == 'float-quantized': + weight_type = 'fp8' + model_format = 'fp8' elif model_format == 'fp8': weight_type = 'fp8' elif model_format == 'mxfp4': @@ -231,18 +239,23 @@ def get_tm_model(model_path, _group_size = 32 elif quant_method == 'compressed-tensors': _format = quant_config['config_groups']['group_0']['format'] - assert _format == 'pack-quantized', ('compressed-tensors only supports pack-quantized format, ' f'but got {_format}') + assert _format in ['pack-quantized', 'float-quantized' ], ('compressed-tensors only supports pack-quantized/float-quantized format, ' f'but got {_format}') _weights = quant_config['config_groups']['group_0']['weights'] _group_size = _weights['group_size'] _num_bits = _weights['num_bits'] _type = _weights['type'] - assert _num_bits == 4 and _type == 'int', ('pack-quantized requires 4-bit int, ' f'but got {_num_bits}-bit {_type}') + assert (_num_bits == 4 and _type == 'int') or (_num_bits == 8 and _type == 'float'), ( 'pack-quantized requires 4-bit int, ' f'but got {_num_bits}-bit {_type}. 
' + 'or float-quantized requires 8-bit float, ' + f'but got {_num_bits}-bit {_type}') else: assert 0, f'unsupported quant_config: {quant_config}' engine_config.model_format = quant_method + quantized_format = _format if quant_method == 'compressed-tensors' else None group_size = _group_size group_size = _validate_quant_group_size(engine_config.model_format, group_size) @@ -250,16 +263,19 @@ def get_tm_model(model_path, input_model_name = get_input_model_registered_name(model_path, engine_config.model_format) fp8_quant = (engine_config.model_format == 'fp8' and not quant_config) - input_policy = get_input_policy(engine_config.model_format) + input_policy = get_input_policy(engine_config.model_format, + quantized_format=quantized_format if quant_config else None) input_model = INPUT_MODELS.get(input_model_name)(model_path=model_path, tokenizer_path=model_path, input_policy=input_policy, fp8_quant=fp8_quant) - output_model_name, tm_cfg = get_output_model_registered_name_and_config(model_path=model_path, - model_format=engine_config.model_format, - dtype=engine_config.dtype, - group_size=group_size) + output_model_name, tm_cfg = get_output_model_registered_name_and_config( + model_path=model_path, + model_format=engine_config.model_format, + dtype=engine_config.dtype, + group_size=group_size, + quantized_format=quantized_format if quant_config else None) if mixed_awq: # Mixed-precision AWQ: attention weights are fp16 (not quantized), diff --git a/lmdeploy/turbomind/deploy/parameter.py b/lmdeploy/turbomind/deploy/parameter.py index 59c6f0158f..4d6f34a895 100644 --- a/lmdeploy/turbomind/deploy/parameter.py +++ b/lmdeploy/turbomind/deploy/parameter.py @@ -112,6 +112,32 @@ def __call__(self, f, g, i): f(i, g('weight'), 'weight', identity) +class WeightScale(Parameter): + KEYS = '.weight_scale', '.weight' + + # TODO: flag any operations crossing the quant blocks as illegal + def __call__(self, f, g, i): + f(i, g('weight_scale'), 'scales', to_float, apply_gs=['w1', 'w3', 
'w2']) + f(i, g('weight'), 'weight', identity) + + +class CompressedWeight(Parameter): + KEYS = '.weight_packed', '.weight_scale', '.weight_zero_point' + + def __init__(self, xs): + self.has_zero_point = False + if any(key.endswith(self.KEYS[2]) for key in xs): + self.has_zero_point = True + + def __call__(self, f, g, i): + f(i, g('weight_packed'), 'qweight', pack_u4_row) + f(i, g('weight_scale'), 'scales', to_half, apply_gs=['w2']) + if self.has_zero_point: + f(i, g('weight_zero_point'), 'zeros', to_half, apply_gs=['w2']) + else: + f(i, generate_zero_point(g), 'zeros', to_half, apply_gs=['w2']) + + class Mxfp4Weight(Parameter): KEYS = '.blocks', '.scales' @@ -151,6 +177,11 @@ def get_params(keys: list[str], bias=0): ps.append(QuantWeightOnly(xs)) if WeightScaleInv.take(keys): ps.append(WeightScaleInv()) + if WeightScale.take(keys): + ps.append(WeightScale()) + xs = CompressedWeight.take(keys) + if xs: + ps.append(CompressedWeight(xs)) if Mxfp4Weight.take(keys): ps.append(Mxfp4Weight()) if Weight.take(keys): diff --git a/lmdeploy/turbomind/deploy/policy.py b/lmdeploy/turbomind/deploy/policy.py index 0e4c061c0d..6c54e4fa49 100644 --- a/lmdeploy/turbomind/deploy/policy.py +++ b/lmdeploy/turbomind/deploy/policy.py @@ -67,7 +67,7 @@ def process_fp8(x: torch.Tensor, kind: str): return x.to(dtype=torch.bfloat16) -def process_compressed_tensor(x: torch.Tensor, kind: str): +def process_compressed_packed_tensor(x: torch.Tensor, kind: str): x = x.cuda() if x.dtype == torch.int32: xs = get_u4_slices(x, torch.uint8) @@ -78,7 +78,7 @@ def process_compressed_tensor(x: torch.Tensor, kind: str): return x -def get_input_policy(model_format): +def get_input_policy(model_format, quantized_format=None): if model_format == 'awq': return process_awq_gemm elif model_format == 'gptq': @@ -88,6 +88,9 @@ def get_input_policy(model_format): elif model_format == 'fp8': return process_fp8 elif model_format == 'compressed-tensors': - return process_compressed_tensor + if quantized_format == 
'pack-quantized': + return process_compressed_packed_tensor + elif quantized_format == 'float-quantized': + return process_fp8 else: return to_cuda