Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lmdeploy/lite/apis/auto_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
import torch
from torch import nn

from lmdeploy.lite.apis.calibrate import LAYER_TYPE_MAP, calibrate
from lmdeploy.lite.quantization.awq import FC_FCS_MAP, NORM_FCS_MAP, awq_layers, quant_weights, smooth_layers
from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.utils import try_import_deeplink

from .calibrate import LAYER_TYPE_MAP, calibrate


def save_vl_model(vl_model, model_path, dst_path):
vl_model.save_pretrained(dst_path, safe_serialization=True)
Expand Down
14 changes: 8 additions & 6 deletions lmdeploy/lite/apis/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,23 +251,25 @@ def calibrate(model: str,
model = load_hf_from_pretrained(model, dtype=dtype, trust_remote_code=True)
vl_model = None
elif model_type == 'vlm':
from transformers import AutoConfig
original_torch_dtype = AutoConfig.from_pretrained(model, trust_remote_code=True).torch_dtype
vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds an extra AutoConfig.from_pretrained() call for VLM calibration, but load_vl_model() already loads the HF config via get_model_arch() (which calls AutoConfig.from_pretrained). Consider reusing that existing config (or reading torch_dtype from the loaded model/config) to avoid duplicate network/cache IO.

Copilot uses AI. Check for mistakes.
model = vl_model
if hasattr(vl_model, 'language_model'): # deepseek-vl, ...
model = vl_model.language_model
if hasattr(vl_model, 'llm'): # MiniCPMV, ...
model = vl_model.llm
model.config.use_cache = False
if dtype == 'float16':
if hasattr(model.config, 'text_config'):
model.config.text_config.use_cache = False
elif hasattr(model.config, 'llm_config'):
model.config.llm_config.use_cache = False
if dtype == 'float16' or (dtype == 'auto' and original_torch_dtype == torch.float16):
model.half()
elif dtype == 'bfloat16':
elif dtype == 'bfloat16' or (dtype == 'auto' and original_torch_dtype == torch.bfloat16):
assert torch.cuda.is_bf16_supported(
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

original_torch_dtype from AutoConfig is not guaranteed to be a torch.dtype (it can be None or a string like 'float16'/'bfloat16'), so the equality checks against torch.float16/torch.bfloat16 can silently fail and skip the intended casting. Consider normalizing original_torch_dtype (e.g., mapping strings to torch.dtype) before these comparisons.

Copilot uses AI. Check for mistakes.
), 'your device does not support bfloat16 please set --dtype float16' # noqa
model.to(torch.bfloat16)
elif dtype == 'auto' and model.config.torch_dtype == torch.bfloat16:
print('Warning: we cast model to float16 to prevent OOM. You'
' may enforce it bfloat16 by `--dtype bfloat16`')
model.half()
model.eval()

model_type = type(model).__name__
Expand Down
5 changes: 4 additions & 1 deletion lmdeploy/lite/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,10 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module,
'clamping w_scales.pow(1 - alpha) to 1e-4')
w_scales_pow = w_scales_pow.clamp(min=1e-4)
scales = (act_scales.pow(alpha) / w_scales_pow).clamp(min=1e-4).to(device).to(dtype)
scales = scales / (scales.max() * scales.min()).sqrt()
# prevent scales.max() * scales.min() == inf
denom = (scales.max().float() * scales.min().float()).sqrt()
denom = denom.to(dtype=dtype)
scales = scales / denom
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Casting denom back to dtype can overflow again (e.g., float16) and reintroduce inf/0 scaling, undermining the float32 max/min fix. Keep denom in float32 (and optionally compute scales = scales.float() / denom), then cast the final scales back to the target dtype.

Suggested change
denom = denom.to(dtype=dtype)
scales = scales / denom
scales = (scales.float() / denom).to(device=device, dtype=dtype)

Copilot uses AI. Check for mistakes.

# (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale
# phi3 fused qkv and gate_up
Expand Down
21 changes: 14 additions & 7 deletions lmdeploy/lite/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def __init__(self,
self.norm_type = norm_type
self.batch_size = batch_size

num_kv_heads, num_attn_heads = self._guess_num_heads(model)
num_kv_heads, num_attn_heads, text_config = self._guess_num_heads(model)
self.num_kv_heads = num_kv_heads
self.head_dim = model.config.hidden_size // num_attn_heads
self.head_dim = text_config.hidden_size // num_attn_heads
self.model = model

self.tokenizer = tokenizer
Expand All @@ -80,14 +80,21 @@ def __init__(self,

def _guess_num_heads(self, model):

if hasattr(model.config, 'num_key_value_heads'):
num_kv_heads = model.config.num_key_value_heads
if hasattr(model.config, 'text_config'):
text_config = model.config.text_config
elif hasattr(model.config, 'llm_config'):
text_config = model.config.llm_config
else:
num_kv_heads = model.config.num_attention_heads
text_config = model.config

num_attn_heads = model.config.num_attention_heads
if hasattr(text_config, 'num_key_value_heads'):
num_kv_heads = text_config.num_key_value_heads
else:
num_kv_heads = text_config.num_attention_heads

num_attn_heads = text_config.num_attention_heads

return num_kv_heads, num_attn_heads
return num_kv_heads, num_attn_heads, text_config

def _init_input_observers(self, name2mod):
"""Initialize input observers for given modules."""
Expand Down
4 changes: 0 additions & 4 deletions lmdeploy/lite/utils/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ def load_hf_from_pretrained(pretrained_model_name_or_path, dtype: Literal['float
torch_dtype = torch.bfloat16
elif dtype == 'float16':
torch_dtype = torch.float16
elif dtype == 'auto' and torch_dtype == torch.bfloat16:
print('Warning: we cast model to float16 to prevent OOM. '
'You may enforce it bfloat16 by `--dtype bfloat16`')
torch_dtype = torch.float16

Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With dtype='auto', torch_dtype can still resolve to torch.bfloat16 (from config) even when torch.cuda.is_bf16_supported() is false; the current guard only checks dtype == 'bfloat16'. This can lead to failures later when moving the model to CUDA. Consider adding an auto branch to fall back to float16 (or raise) when the resolved dtype is bf16 but the device can’t run bf16.

Suggested change
if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
if dtype == 'auto':
torch_dtype = torch.float16
if hasattr(hf_config, 'bf16'):
hf_config.bf16 = False
if hasattr(hf_config, 'fp16'):
hf_config.fp16 = True
else:
raise RuntimeError('Your device does not supports bf16(bfloat16), '
'please change to fp16(float16)')

Copilot uses AI. Check for mistakes.
with LoadNoInit():
# Load model
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/models/q_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def from_float(cls, mod: nn.Module, initialization: bool = True, quant_dtype=tor
`initialization = True` for real init. `initialization = False` for dummy init.
"""
hidden_size = mod.weight.shape[0]
eps = mod.variance_epsilon
eps = getattr(mod, 'variance_epsilon', None) or getattr(mod, 'eps', 1e-6)
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or treats 0/0.0 as falsy, so an explicit variance_epsilon=0 would be replaced by mod.eps/default. Prefer an explicit is None check (or nested getattr) so only missing attributes fall back, not valid zero values.

Suggested change
eps = getattr(mod, 'variance_epsilon', None) or getattr(mod, 'eps', 1e-6)
eps = getattr(mod, 'variance_epsilon', None)
if eps is None:
eps = getattr(mod, 'eps', 1e-6)

Copilot uses AI. Check for mistakes.
q_mod = cls(hidden_size, eps, quant_dtype=quant_dtype)
if initialization:
q_mod.weight = nn.Parameter(mod.weight.detach())
Expand Down
Loading