fix lite module for transformers>=5.0 #4488

Open
43758726 wants to merge 7 commits into InternLM:main from 43758726:transformers_compatible

Conversation

43758726 (Collaborator) commented Apr 2, 2026

Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and easier to review. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.

Motivation

lmdeploy.lite fails to quantize/calibrate some models when running with transformers >= 5.0.

Modification

  • lmdeploy/lite/quantization/calibration.py: added fallback logic in _guess_num_heads() to unwrap nested config objects by checking for text_config and llm_config attributes before accessing head-count parameters.
  • lmdeploy/lite/quantization/awq.py: cast scales.max() and scales.min() to float32 before multiplication to prevent float16/bfloat16 overflow that produces inf.
  • lmdeploy/lite/apis/auto_awq.py: changed the import of LAYER_TYPE_MAP and calibrate from a relative import to an absolute import to avoid potential circular-import issues.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

Copilot AI review requested due to automatic review settings April 2, 2026 13:02
Contributor

Copilot AI left a comment


Pull request overview

Fixes lmdeploy.lite quantization/calibration regressions when used with transformers>=5.0, focusing on newer nested config wrappers and numerical stability in AWQ smoothing.

Changes:

  • Add fallback logic in calibration to unwrap nested HF config objects before reading head-count fields.
  • Prevent potential overflow in AWQ scale normalization by computing extrema in float32.
  • Switch auto_awq to absolute imports for calibrate/LAYER_TYPE_MAP.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

  • lmdeploy/lite/quantization/calibration.py: unwraps nested config objects in _guess_num_heads; also includes new commented-out debug prints in the wrapped forward.
  • lmdeploy/lite/quantization/awq.py: adjusts AWQ smooth_fc_fcs normalization to avoid float16/bfloat16 overflow.
  • lmdeploy/lite/apis/auto_awq.py: changes the relative import of calibrate/LAYER_TYPE_MAP to an absolute import.


Comment on lines +83 to +86

```python
if hasattr(model.config, 'text_config'):
    model.config = model.config.text_config
if hasattr(model.config, 'llm_config'):
    model.config = model.config.llm_config
```

Copilot AI Apr 2, 2026


_guess_num_heads() mutates model.config by reassigning it to text_config / llm_config. This has side effects for the rest of calibration (e.g., later code uses model.config.hidden_size, use_cache, and config updates/saving) and can break models whose wrapper config contains fields not present on the nested config. Use a local variable (e.g., cfg = model.config and unwrap cfg), and leave model.config unchanged.
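The side effect the reviewer describes can be demonstrated with plain namespaces (illustrative field names only, not taken from any real config):

```python
from types import SimpleNamespace

# A wrapper config carrying a field that the nested text_config lacks.
wrapper = SimpleNamespace(
    tie_word_embeddings=False,  # wrapper-only field (illustrative)
    text_config=SimpleNamespace(num_attention_heads=32),
)

# Non-mutating unwrap: read head counts from a local variable and leave
# the wrapper (i.e. model.config) untouched for later calibration code.
cfg = getattr(wrapper, 'text_config', wrapper)
num_heads = cfg.num_attention_heads

assert num_heads == 32
# Wrapper-only fields remain reachable because wrapper was never reassigned.
assert wrapper.tie_word_embeddings is False
```

If model.config had been reassigned to text_config instead, any later access to tie_word_embeddings would raise AttributeError.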

@43758726 43758726 requested a review from lvhan028 April 2, 2026 13:07
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.



"""
hidden_size = mod.weight.shape[0]
eps = mod.variance_epsilon
eps = getattr(mod, 'variance_epsilon', None) or getattr(mod, 'eps', 1e-6)

Copilot AI Apr 7, 2026


`or` treats 0/0.0 as falsy, so an explicit variance_epsilon=0 would be replaced by mod.eps or the default. Prefer an explicit `is None` check (or nested getattr) so that only missing attributes fall back, not valid zero values.

Suggested change

```diff
-eps = getattr(mod, 'variance_epsilon', None) or getattr(mod, 'eps', 1e-6)
+eps = getattr(mod, 'variance_epsilon', None)
+if eps is None:
+    eps = getattr(mod, 'eps', 1e-6)
```
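The falsy-zero pitfall is easy to reproduce with a dummy object (sketch, not code from the PR):

```python
class Mod:
    pass

m = Mod()
m.variance_epsilon = 0.0  # explicitly set, but falsy
m.eps = 1e-6

# `or` discards the explicit zero and falls back to mod.eps:
assert (getattr(m, 'variance_epsilon', None) or getattr(m, 'eps', 1e-6)) == 1e-6

# An explicit `is None` check preserves the zero:
eps = getattr(m, 'variance_epsilon', None)
if eps is None:
    eps = getattr(m, 'eps', 1e-6)
assert eps == 0.0
```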

```python
print('Warning: we cast model to float16 to prevent OOM. '
      'You may enforce it bfloat16 by `--dtype bfloat16`')
torch_dtype = torch.float16
```

Copilot AI Apr 7, 2026


With dtype='auto', torch_dtype can still resolve to torch.bfloat16 (from config) even when torch.cuda.is_bf16_supported() is false; the current guard only checks dtype == 'bfloat16'. This can lead to failures later when moving the model to CUDA. Consider adding an auto branch to fall back to float16 (or raise) when the resolved dtype is bf16 but the device can’t run bf16.

Suggested change

```python
if torch_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
    if dtype == 'auto':
        torch_dtype = torch.float16
        if hasattr(hf_config, 'bf16'):
            hf_config.bf16 = False
        if hasattr(hf_config, 'fp16'):
            hf_config.fp16 = True
    else:
        raise RuntimeError('Your device does not support bf16 (bfloat16), '
                           'please change to fp16 (float16)')
```

Comment on lines +241 to +242
```python
denom = denom.to(dtype=dtype)
scales = scales / denom
```

Copilot AI Apr 7, 2026


Casting denom back to dtype can overflow again (e.g., float16) and reintroduce inf/0 scaling, undermining the float32 max/min fix. Keep denom in float32 (and optionally compute scales = scales.float() / denom), then cast the final scales back to the target dtype.

Suggested change

```diff
-denom = denom.to(dtype=dtype)
-scales = scales / denom
+scales = (scales.float() / denom).to(device=device, dtype=dtype)
```
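The overflow itself is a property of IEEE float16 (max finite value 65504) rather than of any particular framework. A NumPy sketch, standing in for the torch tensors used in awq.py (values chosen for illustration):

```python
import numpy as np

scales = np.array([60000.0, 4.0], dtype=np.float16)

# float16 * float16 stays float16: 60000 * 4 = 240000 > 65504, so inf.
prod16 = scales.max() * scales.min()

# Computing the extrema product in float32 first stays finite.
prod32 = scales.max().astype(np.float32) * scales.min().astype(np.float32)

assert np.isinf(prod16)
assert np.isfinite(np.sqrt(prod32))
```

This is why the fix must keep the intermediate values in float32 all the way to the final cast; converting back to float16 before the division reintroduces the same failure mode.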

Comment on lines 255 to 270
```diff
+original_torch_dtype = AutoConfig.from_pretrained(model, trust_remote_code=True).torch_dtype
 vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
 model = vl_model
 if hasattr(vl_model, 'language_model'):  # deepseek-vl, ...
     model = vl_model.language_model
 if hasattr(vl_model, 'llm'):  # MiniCPMV, ...
     model = vl_model.llm
 model.config.use_cache = False
-if dtype == 'float16':
+if hasattr(model.config, 'text_config'):
+    model.config.text_config.use_cache = False
+elif hasattr(model.config, 'llm_config'):
+    model.config.llm_config.use_cache = False
+if dtype == 'float16' or (dtype == 'auto' and original_torch_dtype == torch.float16):
     model.half()
-elif dtype == 'bfloat16':
+elif dtype == 'bfloat16' or (dtype == 'auto' and original_torch_dtype == torch.bfloat16):
     assert torch.cuda.is_bf16_supported(
```

Copilot AI Apr 7, 2026


original_torch_dtype from AutoConfig is not guaranteed to be a torch.dtype (it can be None or a string like 'float16'/'bfloat16'), so the equality checks against torch.float16/torch.bfloat16 can silently fail and skip the intended casting. Consider normalizing original_torch_dtype (e.g., mapping strings to torch.dtype) before these comparisons.
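One way to normalize, as the reviewer suggests (a hypothetical helper, not part of the PR): canonicalize torch_dtype to a string before comparing, relying on the fact that str(torch.float16) is the string 'torch.float16'.

```python
def normalize_torch_dtype(value):
    """Map config.torch_dtype (a torch.dtype, a string such as 'float16',
    or None) to a canonical string like 'float16', or None if unset."""
    if value is None:
        return None
    s = str(value)
    if s.startswith('torch.'):      # str(torch.float16) == 'torch.float16'
        s = s[len('torch.'):]
    return s

# Usage sketch, replacing the direct equality check:
#   dtype == 'auto' and normalize_torch_dtype(original_torch_dtype) == 'float16'
```

With this, None, 'float16', and torch.float16 all compare predictably, so an 'auto' dtype cannot silently skip the cast.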

Comment on lines 254 to 256
```python
from transformers import AutoConfig
original_torch_dtype = AutoConfig.from_pretrained(model, trust_remote_code=True).torch_dtype
vl_model = load_vl_model(model, backend=None, with_llm=True).vl_model
```

Copilot AI Apr 7, 2026


This adds an extra AutoConfig.from_pretrained() call for VLM calibration, but load_vl_model() already loads the HF config via get_model_arch() (which calls AutoConfig.from_pretrained). Consider reusing that existing config (or reading torch_dtype from the loaded model/config) to avoid duplicate network/cache IO.

