Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions lmdeploy/archs.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def autoget_backend_config(
return backend, config


def check_vl_llm(backend: str, config: dict) -> bool:
def check_vl_llm(config: dict) -> bool:
"""Check if the model is a vl model from model config."""
if 'auto_map' in config:
for _, v in config['auto_map'].items():
Expand Down Expand Up @@ -121,22 +121,22 @@ def check_vl_llm(backend: str, config: dict) -> bool:
return True
elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] and 'vision_config' in config:
return True
elif arch in ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration'] and backend == 'turbomind':
return False
elif arch in supported_archs:
return True
return False


def get_task(backend: str, model_path: str):
def get_task(model_path: str, backend_config: PytorchEngineConfig | TurbomindEngineConfig | None = None):
"""Get pipeline type and pipeline class from model config."""
from lmdeploy.serve.core import AsyncEngine

if os.path.exists(os.path.join(model_path, 'triton_models', 'weights')):
# workspace model
return 'llm', AsyncEngine
_, config = get_model_arch(model_path)
if check_vl_llm(backend, config.to_dict()):
if check_vl_llm(config.to_dict()):
if backend_config and backend_config.disable_vision_encoder:
return 'llm', AsyncEngine
Comment on lines +138 to +139
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When disable_vision_encoder is set, get_task() routes VL architectures to the plain AsyncEngine. AsyncEngine constructs MultimodalProcessor without a vl_encoder, and MultimodalProcessor.get_prompt_input() will then treat multimodal messages as text-only and silently drop image/video blocks (it only joins type=='text'). If the intended contract is to reject multimodal inputs when vision is disabled, this needs an explicit error path (or a different engine selection) to avoid silent data loss.

Suggested change
if backend_config and backend_config.disable_vision_encoder:
return 'llm', AsyncEngine
if backend_config and getattr(backend_config, 'disable_vision_encoder', False):
raise ValueError(
'Invalid configuration: disable_vision_encoder is True for a vision-language '
'model. This would route the model to a text-only engine and silently drop '
'image/video inputs. Please use a pure language model or enable the vision '
'encoder.'
)

Copilot uses AI. Check for mistakes.
from lmdeploy.serve.core import VLAsyncEngine
return 'vlm', VLAsyncEngine

Expand Down
3 changes: 2 additions & 1 deletion lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ def api_server(args):
async_=args.async_,
communicator=args.communicator,
enable_metrics=not args.disable_metrics,
hf_overrides=args.hf_overrides)
hf_overrides=args.hf_overrides,
disable_vision_encoder=args.disable_vision_encoder)
chat_template_config = get_chat_template(args.chat_template, args.model_path)
speculative_config = get_speculative_config(args)

Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/lite/apis/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def calibrate(model: str,
'Support only `wikitext2`, `c4`, `pileval`, `gsm8k`, ' \
'`neuralmagic_calibration`, `open-platypus`, `openwebtext`.'

model_type, _ = get_task(backend='turbomind', model_path=model)
model_type, _ = get_task(model_path=model)
make_compatible_internvl_config(model)

# Load tokenizer and configuration
Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ class TurbomindEngineConfig:
it to True if you want to update weights after create the pipeline
hf_overrides: Huggingface overrides for the model.
It can be used to override the default config of the model
disable_vision_encoder: Whether to disable loading vision
encoder. Default to False.
enable_metrics: enable metrics system
"""

Expand Down Expand Up @@ -291,6 +293,7 @@ class TurbomindEngineConfig:
empty_init: bool = False
communicator: str = 'nccl'
hf_overrides: dict[str, Any] | None = None
disable_vision_encoder: bool = False
enable_metrics: bool = True

def __post_init__(self):
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(self,

# Create inference engine
backend, backend_config = autoget_backend_config(model_path, backend_config)
_, pipeline_class = get_task(backend, model_path)
_, pipeline_class = get_task(model_path, backend_config)
self.async_engine = pipeline_class(model_path,
backend=backend,
backend_config=backend_config,
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,7 @@ def serve(model_path: str,
http_or_https = 'https'

handle_torchrun()
_, pipeline_class = get_task(backend, model_path)
_, pipeline_class = get_task(model_path, backend_config)
if isinstance(backend_config, PytorchEngineConfig):
backend_config.enable_mp_engine = True
# router replay
Expand Down
29 changes: 22 additions & 7 deletions lmdeploy/turbomind/deploy/source_model/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ def moe_ffn_shared_gate(self, i):
return self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.shared_expert_gate.weight')


def _configure_nested_language_model_prefix(reader):
"""Handle VL checkpoints whose text weights live under
``model.language_model``."""
if any(k.startswith('model.language_model.') for k in reader.params.keys()):
reader.attn_layer_prefix = 'model.language_model.layers'
reader.tok_embeddings_key = 'model.language_model.embed_tokens.weight'
reader.norm_weight_key = 'model.language_model.norm.weight'
if reader.model_cfg.get('tie_word_embeddings', False):
reader.output_weight_key = reader.tok_embeddings_key


@INPUT_MODELS.register_module(name='qwen2-moe')
class Qwen2MoeModel(LlamaModel):

Expand All @@ -172,6 +183,11 @@ def model_info(self):


class Qwen3Reader(LlamaReader):
attn_layer_patten = r'(?:model\.language_model\.|model\.)layers\.([0-9]+)\.'

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
_configure_nested_language_model_prefix(self)

def qk_norm(self, i: int):
result = []
Expand All @@ -193,6 +209,11 @@ def model_info(self):


class Qwen3MoeReader(Qwen2MoeReader):
attn_layer_patten = r'(?:model\.language_model\.|model\.)layers\.([0-9]+)\.'

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
_configure_nested_language_model_prefix(self)

def qk_norm(self, i: int):
result = []
Expand Down Expand Up @@ -236,13 +257,7 @@ class Qwen3_5ReaderMixin:

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if any(k.startswith('model.language_model.') for k in self.params.keys()):
self.attn_layer_prefix = 'model.language_model.layers'
self.tok_embeddings_key = 'model.language_model.embed_tokens.weight'
self.norm_weight_key = 'model.language_model.norm.weight'
tie_word_embeddings = self.model_cfg.get('tie_word_embeddings', False)
if tie_word_embeddings:
self.output_weight_key = self.tok_embeddings_key
_configure_nested_language_model_prefix(self)

# ---- zero-centered RMSNorm: add 1 to weights during export ----
def attn_norm(self, i: int):
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/turbomind/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
# Qwen3
Qwen3ForCausalLM='qwen3',
Qwen3MoeForCausalLM='qwen3-moe',
Qwen3VLForConditionalGeneration='qwen3',
Qwen3VLMoeForConditionalGeneration='qwen3-moe',
# Qwen 3.5
Qwen3_5ForConditionalGeneration='qwen3_5',
Qwen3_5MoeForConditionalGeneration='qwen3_5-moe',
Expand Down
174 changes: 163 additions & 11 deletions lmdeploy/vl/model/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,105 @@
from lmdeploy.utils import get_logger
from lmdeploy.vl.constants import Modality
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel
from lmdeploy.vl.model.utils import disable_logging

logger = get_logger('lmdeploy')


def check_transformers():
def check_qwen3_vl_deps_install():
"""Check dependencies for Qwen3-VL / Qwen3.5 (same vision stack as
Qwen2-VL's ``check_qwen_vl_deps_install``).

- **Transformers**: recent build with Qwen3-VL and Qwen3.5 classes (see Qwen3-VL model card on HF).
- **Accelerate**: required for TurboMind split vision loading (`load_checkpoint_and_dispatch`).
- **qwen-vl-utils** (optional): pip package ``qwen-vl-utils``; many upstream Qwen-VL recipes use it for
video helpers. LMDeploy's Qwen3 preprocessor uses ``AutoProcessor`` only; warn if missing so users
can align with `Qwen2VLModel` / official docs when needed.
"""
try:
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration # noqa: F401
from transformers import ( # noqa: F401
Qwen3_5ForConditionalGeneration,
Qwen3_5MoeForConditionalGeneration,
Qwen3VLForConditionalGeneration,
Qwen3VLMoeForConditionalGeneration,
)
Comment on lines 25 to +31
Copy link

Copilot AI Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_qwen3_vl_deps_install() currently requires both Qwen3-VL and Qwen3.5 classes to be importable from transformers. This can regress Qwen3-VL usage in environments where transformers has Qwen3-VL but not yet Qwen3.5 (or vice-versa). Consider checking only the architecture actually being loaded (e.g., based on self.hf_config.architectures[0]), or making the Qwen3.5 import optional unless a Qwen3.5 arch is detected.

Copilot uses AI. Check for mistakes.
except ImportError:
raise ImportError('please install latest transformers by '
raise ImportError('please install a recent transformers with Qwen3-VL / Qwen3.5 support, e.g. '
'pip install git+https://github.com/huggingface/transformers.git')
try:
import accelerate # noqa: F401
except ImportError:
raise ImportError('please install accelerate for TurboMind vision loading: pip install accelerate')
try:
import qwen_vl_utils # noqa: F401
except ImportError:
logger.warning_once(
'qwen-vl-utils is not installed. Install with `pip install qwen-vl-utils` if you use '
'video pipelines or helpers from the Qwen-VL examples (optional for LMDeploy Qwen3 preprocess).')


def resolve_qwen_vl_family_automodel(arch: str) -> tuple[type, list[str]]:
    """Map an HF architecture name to its transformers model class and the
    accelerate ``no_split_module_classes`` list for its vision blocks.

    Qwen3-VL introduced this TurboMind split-vision path; Qwen3.5 reuses the
    same stack.

    Args:
        arch: ``config.architectures[0]`` of the checkpoint being loaded.

    Returns:
        (model class, list of vision-block class names accelerate must not
        split across devices).

    Raises:
        ValueError: if *arch* is not a supported Qwen VL family architecture.
    """
    # Table-driven dispatch replaces four near-duplicate if/elif branches;
    # dense and MoE variants of each generation share one no-split list.
    qwen3_vl_blocks = ['Qwen3VLVisionBlock', 'Qwen3VLMoeVisionBlock']
    qwen3_5_blocks = ['Qwen3_5VisionBlock', 'Qwen3_5MoeVisionBlock']
    no_split_by_arch = {
        'Qwen3VLForConditionalGeneration': qwen3_vl_blocks,
        'Qwen3VLMoeForConditionalGeneration': qwen3_vl_blocks,
        'Qwen3_5ForConditionalGeneration': qwen3_5_blocks,
        'Qwen3_5MoeForConditionalGeneration': qwen3_5_blocks,
    }
    if arch not in no_split_by_arch:
        raise ValueError(f'Unsupported Qwen VL family architecture: {arch}')
    # Each class name equals its architecture string, so resolve it from the
    # transformers namespace; the import stays lazy (after the arch check),
    # matching the per-branch imports this replaces.
    import transformers
    AutoModelCls = getattr(transformers, arch)
    return AutoModelCls, no_split_by_arch[arch]


def load_qwen_vl_family_vision_backbone(
    model_path: str,
    hf_config: Any,
    with_llm: bool,
    max_memory: dict[int, int] | None,
) -> Any:
    """Load vision tower only (TurboMind path) for Qwen3-VL and Qwen3.5.

    Args:
        model_path: local HF checkpoint directory to stream weights from.
        hf_config: the checkpoint's HF config; ``architectures[0]`` selects
            the concrete model class. NOTE(review): mutated in place below
            (``tie_word_embeddings`` is cleared) — confirm no caller relies
            on the original value afterwards.
        with_llm: when True, load the full model (vision + LLM) on CPU and
            return it; otherwise build a vision-only split.
        max_memory: per-device memory budget forwarded to accelerate's
            ``load_checkpoint_and_dispatch``; None lets accelerate decide.

    Returns:
        The full ``from_pretrained`` model when ``with_llm`` is True, else
        ``model.model`` (language stack and lm_head removed) in eval mode.
    """
    arch = hf_config.architectures[0]
    AutoModelCls, no_split = resolve_qwen_vl_family_automodel(arch)

    # Full-model path: plain CPU load, no empty-weight tricks needed.
    if with_llm:
        return AutoModelCls.from_pretrained(model_path, device_map='cpu')

    from accelerate import init_empty_weights, load_checkpoint_and_dispatch

    # Build an empty (meta-device) skeleton, drop the text stack, then stream
    # only the remaining (vision) weights from the checkpoint.
    with init_empty_weights():
        config = hf_config
        # Untie embeddings before construction — presumably so `del
        # model.lm_head` below cannot alias the kept embedding weights;
        # TODO confirm.
        config.tie_word_embeddings = False
        if hasattr(config, 'text_config'):
            config.text_config.tie_word_embeddings = False
        model = AutoModelCls._from_config(config)
        # The language model and LM head are served by TurboMind, not here.
        del model.model.language_model
        del model.lm_head
        model.half()

    # Dispatch checkpoint shards across devices; classes in `no_split` are
    # kept whole on one device so a vision block is never torn in half.
    with disable_logging():
        load_checkpoint_and_dispatch(
            model=model,
            checkpoint=model_path,
            device_map='auto',
            max_memory=max_memory,
            no_split_module_classes=no_split,
            dtype=torch.half,
        )
    return model.model.eval()


@VISION_MODELS.register_module()
Expand All @@ -26,7 +115,7 @@ class Qwen3VLModel(VisionModel):
_arch = ['Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration']

def build_preprocessor(self):
check_transformers()
check_qwen3_vl_deps_install()
self.processor = AutoProcessor.from_pretrained(self.model_path)

# image tokens
Expand Down Expand Up @@ -167,7 +256,13 @@ def proc_messages(self, messages, chat_template, sequence_start, chat_template_k
else:
prompt_messages = messages
prompt = chat_template.messages2prompt(prompt_messages, sequence_start, **chat_template_kwargs)
return prompt, None
return prompt, self.image_token

    def _ensure_turbomind_image_only(self, inputs: list[dict]):
        """TurboMind split vision currently supports image inputs only.

        Args:
            inputs: preprocessed multimodal items (the ``preprocess`` message
                content); video items carry a ``video_grid_thw`` entry.

        Raises:
            NotImplementedError: if any video input is present.
        """
        # Video is detected either via the instance flag (presumably set
        # during preprocessing — confirm) or per item by its grid key.
        has_video = self.contains_video_input or any('video_grid_thw' in item for item in inputs)
        if has_video:
            raise NotImplementedError('TurboMind split vision for the Qwen3 VL family currently supports images only.')

def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequence_start):
"""Pack the video input to the compatible format with pytorch
Expand Down Expand Up @@ -229,13 +324,61 @@ def to_pytorch(self,
return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start)

def build_model(self):
# TODO: implement for turbomind
pass
"""Load vision tower for TurboMind split path (Qwen3-VL and Qwen3.5
share the same stack)."""
loaded = load_qwen_vl_family_vision_backbone(self.model_path, self.hf_config, self.with_llm,
self.max_memory)
if self.with_llm:
self.vl_model = loaded
else:
self.model = loaded

@torch.no_grad()
def forward(self, messages: list[dict], max_batch_size: int = 1) -> list[dict]:
# TODO: implement for turbomind
pass
"""Run vision encoder for TurboMind split path (shared Qwen3 VL
family)."""
inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0]
self._ensure_turbomind_image_only(inputs)
dtype = torch.half
device = next(self.model.visual.parameters()).device
outputs = []
for idx in range(0, len(inputs), max_batch_size):
pixel_values = [x['pixel_values'].type(dtype) for x in inputs[idx:idx + max_batch_size]]
image_grid_thw = [x['image_grid_thw'] for x in inputs[idx:idx + max_batch_size]]
pixel_values = torch.cat(pixel_values, dim=0).to(device)
image_grid_thw = torch.cat(image_grid_thw, dim=0).to(device)
image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)
if hasattr(image_embeds, 'pooler_output'):
image_embeds = image_embeds.pooler_output
merge_length = self.processor.image_processor.merge_size**2
split_size = image_grid_thw.prod(dim=1) // merge_length
image_embeds = image_embeds.split(split_size.tolist())
outputs.extend(image_embeds)
messages.append(dict(role='forward', content=outputs))
return messages

    @staticmethod
    def get_mrope_info(seq_len: int, grid_thws: list[tuple] | None = None, ranges: list[tuple] | None = None):
        """Build multimodal RoPE (M-RoPE) position ids for a text/image mix.

        Args:
            seq_len: total token length of the sequence.
            grid_thws: per-image (t, h, w) vision grids, aligned with *ranges*.
            ranges: per-image (start, end) token spans of the image embeddings
                within the sequence.

        Returns:
            Tuple of (position-id tensor of shape (3, total) with
            temporal/height/width rows, delta tensor offsetting positions of
            tokens appended after the prompt).
        """
        # Leading text before the first image: all three rows share 0..start-1.
        mrope_position_ids = [torch.arange(ranges[0][0]).expand(3, -1)]
        st_idx = ranges[0][0]
        for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)):
            llm_grid_t, llm_grid_h, llm_grid_w = grid_thw
            # Halve h/w — presumably because the processor's 2x2 patch merge
            # (merge_size) shrinks the token grid; TODO confirm vs processor.
            llm_grid_h //= 2
            llm_grid_w //= 2
            # 3D coordinates for every image token, flattened t-major.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            mrope_position_ids.append(torch.stack([t_index, h_index, w_index]) + st_idx)
            # Advance past the image by its largest spatial extent.
            # NOTE(review): ignores llm_grid_t — fine for images (t == 1);
            # verify if a video grid ever reaches this path.
            st_idx += max(llm_grid_h, llm_grid_w)
            if i < len(ranges) - 1:
                # Text gap between this image span and the next one.
                text_len = ranges[i + 1][0] - ranges[i][1]
            else:
                # Trailing text after the final image span.
                text_len = seq_len - embedding_range[1]
            mrope_position_ids.append(torch.arange(text_len).expand(3, -1) + st_idx)
            st_idx += text_len
        mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
        # How far the final position counter overshoots the raw sequence
        # length; consumed when positioning later generated tokens.
        mrope_position_delta = torch.tensor([st_idx - seq_len], dtype=torch.long)
        return mrope_position_ids, mrope_position_delta

def to_turbomind(self,
messages,
Expand All @@ -244,5 +387,14 @@ def to_turbomind(self,
sequence_start,
chat_template_kwargs: dict | None = None,
**kwargs):
# TODO: implement for turbomind
pass
prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs)
inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0]
self._ensure_turbomind_image_only(inputs)
info = super().to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)
grid_thws = [x['image_grid_thw'].tolist()[0] for x in inputs]
seq_len = len(info['input_ids'])
ranges = info['input_embedding_ranges']
mrope_position_ids, mrope_position_delta = self.get_mrope_info(seq_len, grid_thws, ranges)
meta = dict(mrope_position_ids=mrope_position_ids, mrope_position_delta=mrope_position_delta)
info.update(dict(input_meta=meta))
return info
Loading
Loading