-
Notifications
You must be signed in to change notification settings - Fork 682
Make fp8 models quantized by llm-compressor inferable in TurboMind #4509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -31,7 +31,8 @@ def get_input_model_registered_name(model_path: str, model_format: str): | |||||
| return register_name | ||||||
|
|
||||||
|
|
||||||
| def get_output_model_registered_name_and_config(model_path: str, model_format: str, dtype: str, group_size: int): | ||||||
| def get_output_model_registered_name_and_config(model_path: str, model_format: str, dtype: str, group_size: int, | ||||||
| quantized_format: str): | ||||||
| """Get the registered name of the turbomind model and its configuration | ||||||
| according to the input model path, format and user-input config. The name | ||||||
| will be used to access the OUTPUT_MODELS registry. | ||||||
|
|
@@ -42,6 +43,8 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s | |||||
| ['hf', 'awq', 'gptq'] | ||||||
| dtype (str): the data type of the model's weights and activations | ||||||
| group_size (int): the size of group used by awq model | ||||||
| quantized_format (str): the quantized format of compressed-tensors model, | ||||||
| which can be one of ['pack-quantized', 'float-quantized'] | ||||||
| """ | ||||||
|
Comment on lines
68
to
82
|
||||||
| register_name = 'tm' | ||||||
|
|
||||||
|
|
@@ -75,11 +78,18 @@ def get_output_model_registered_name_and_config(model_path: str, model_format: s | |||||
| session_len = _get_and_verify_max_len(model_config, None) | ||||||
|
|
||||||
| if model_format in ['awq', 'gptq', 'compressed-tensors']: | ||||||
| weight_type = 'int4' | ||||||
| dtype = 'float16' # force float16 for int4 quantized weights | ||||||
| if model_format in ['awq', 'gptq']: | ||||||
| weight_type = 'int4' | ||||||
| dtype = 'float16' # force float16 for int4 quantized weights | ||||||
| elif model_format == 'compressed-tensors': | ||||||
| if quantized_format == 'pack-quantized': | ||||||
| weight_type = 'int4' | ||||||
| model_format = 'awq' | ||||||
| dtype = 'float16' # force float16 for int4 quantized weights | ||||||
| elif quantized_format == 'float-quantized': | ||||||
| weight_type = 'fp8' | ||||||
| model_format = 'fp8' | ||||||
| group_size = 128 if group_size == 0 else group_size | ||||||
|
||||||
| if model_format == 'compressed-tensors': | ||||||
| model_format = 'awq' | ||||||
| elif model_format == 'fp8': | ||||||
| weight_type = 'fp8' | ||||||
| group_size = 128 | ||||||
|
|
@@ -196,18 +206,23 @@ def get_tm_model(model_path, | |||||
| _group_size = 32 | ||||||
| elif quant_method == 'compressed-tensors': | ||||||
| _format = quant_config['config_groups']['group_0']['format'] | ||||||
| assert _format == 'pack-quantized', ('compressed-tennsors only supports pack-quantized format, ' | ||||||
| f'but got {_format}') | ||||||
| assert _format in ['pack-quantized', 'float-quantized' | ||||||
| ], ('compressed-tennsors only supports pack-quantized/float-quantized format, ' | ||||||
|
||||||
| ], ('compressed-tennsors only supports pack-quantized/float-quantized format, ' | |
| ], ('compressed-tensors only supports pack-quantized/float-quantized format, ' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"pack-quantized requires int4 or fp8, but got type {_type} and {_num_bits} bits"
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -78,6 +78,15 @@ def __call__(self, f, g, i): | |
| f(i, g('weight'), 'weight', identity) | ||
|
|
||
|
|
||
| class WeightScale(Parameter): | ||
| KEYS = '.weight_scale', '.weight' | ||
|
|
||
| # TODO: flag any operations crossing the quant blocks as illegal | ||
| def __call__(self, f, g, i): | ||
| f(i, g('weight_scale'), 'scales', to_float, apply_gs=['w1', 'w3', 'w2']) | ||
| f(i, g('weight'), 'weight', identity) | ||
|
Comment on lines
+115
to
+121
|
||
|
|
||
|
|
||
| class CompressedWeight(Parameter): | ||
| KEYS = '.weight_packed', '.weight_scale', '.weight_zero_point' | ||
|
|
||
|
|
@@ -133,6 +142,8 @@ def get_params(keys: list[str], bias=0): | |
| ps.append(QuantWeightOnly()) | ||
| if WeightScaleInv.take(keys): | ||
| ps.append(WeightScaleInv()) | ||
| if WeightScale.take(keys): | ||
| ps.append(WeightScale()) | ||
| xs = CompressedWeight.take(keys) | ||
| if xs: | ||
| ps.append(CompressedWeight(xs)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -68,7 +68,7 @@ def process_fp8(x: torch.Tensor, kind: str): | |||||||
| return x.to(dtype=torch.bfloat16) | ||||||||
|
|
||||||||
|
|
||||||||
| def process_compressed_tensor(x: torch.Tensor, kind: str): | ||||||||
| def process_compressed_packed_tensor(x: torch.Tensor, kind: str): | ||||||||
| x = x.cuda() | ||||||||
| if x.dtype == torch.int32: | ||||||||
| xs = get_u4_slices(x, torch.uint8) | ||||||||
|
|
@@ -79,7 +79,7 @@ def process_compressed_tensor(x: torch.Tensor, kind: str): | |||||||
| return x | ||||||||
|
|
||||||||
|
|
||||||||
| def get_input_policy(model_format): | ||||||||
| def get_input_policy(model_format, quantized_format=None): | ||||||||
| if model_format == 'awq': | ||||||||
| return process_awq_gemm | ||||||||
| elif model_format == 'gptq': | ||||||||
|
|
@@ -89,6 +89,9 @@ def get_input_policy(model_format): | |||||||
| elif model_format == 'fp8': | ||||||||
| return process_fp8 | ||||||||
| elif model_format == 'compressed-tensors': | ||||||||
| return process_compressed_tensor | ||||||||
| if quantized_format == 'pack-quantized': | ||||||||
| return process_compressed_packed_tensor | ||||||||
| elif quantized_format == 'float-quantized': | ||||||||
| return process_fp8 | ||||||||
|
||||||||
| return process_fp8 | |
| return process_fp8 | |
| raise ValueError(f'Unsupported quantized_format for compressed-tensors: {quantized_format!r}') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docstring for
`get_output_model_registered_name_and_config()` still says `model_format` is one of `['hf', 'awq', 'gptq']`, but the function now handles `fp8`, `mxfp4`, and `compressed-tensors` as well. Please update the docstring to reflect the supported values (and clarify how `compressed-tensors` maps to `awq`/`fp8`).