diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index 68fa03a407..9cc42638c7 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -93,7 +93,7 @@ def autoget_backend_config( return backend, config -def check_vl_llm(backend: str, config: dict) -> bool: +def check_vl_llm(config: dict) -> bool: """Check if the model is a vl model from model config.""" if 'auto_map' in config: for _, v in config['auto_map'].items(): @@ -121,14 +121,12 @@ def check_vl_llm(backend: str, config: dict) -> bool: return True elif arch in ['ChatGLMModel', 'ChatGLMForConditionalGeneration'] and 'vision_config' in config: return True - elif arch in ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration'] and backend == 'turbomind': - return False elif arch in supported_archs: return True return False -def get_task(backend: str, model_path: str): +def get_task(model_path: str, backend_config: PytorchEngineConfig | TurbomindEngineConfig | None = None): """Get pipeline type and pipeline class from model config.""" from lmdeploy.serve.core import AsyncEngine @@ -136,7 +134,9 @@ def get_task(backend: str, model_path: str): # workspace model return 'llm', AsyncEngine _, config = get_model_arch(model_path) - if check_vl_llm(backend, config.to_dict()): + if check_vl_llm(config.to_dict()): + if backend_config and backend_config.disable_vision_encoder: + return 'llm', AsyncEngine from lmdeploy.serve.core import VLAsyncEngine return 'vlm', VLAsyncEngine diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 155392f4a7..8274d7073d 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -272,7 +272,8 @@ def api_server(args): async_=args.async_, communicator=args.communicator, enable_metrics=not args.disable_metrics, - hf_overrides=args.hf_overrides) + hf_overrides=args.hf_overrides, + disable_vision_encoder=args.disable_vision_encoder) chat_template_config = get_chat_template(args.chat_template, args.model_path) speculative_config = get_speculative_config(args) diff --git a/lmdeploy/lite/apis/calibrate.py b/lmdeploy/lite/apis/calibrate.py index c46c1ab820..1f8ef62c9a 100644 --- a/lmdeploy/lite/apis/calibrate.py +++ b/lmdeploy/lite/apis/calibrate.py @@ -241,7 +241,7 @@ def calibrate(model: str, 'Support only `wikitext2`, `c4`, `pileval`, `gsm8k`, ' \ '`neuralmagic_calibration`, `open-platypus`, `openwebtext`.' - model_type, _ = get_task(backend='turbomind', model_path=model) + model_type, _ = get_task(model_path=model) make_compatible_internvl_config(model) # Load tokenizer and configuration diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index d6cd1a3329..61531f72f7 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -253,6 +253,8 @@ class TurbomindEngineConfig: it to True if you want to update weights after create the pipeline hf_overrides: Huggingface overrides for the model. It can be used to override the default config of the model + disable_vision_encoder: Whether to disable loading vision + encoder. Default to False. enable_metrics: enable metrics system """ @@ -291,6 +293,7 @@ class TurbomindEngineConfig: empty_init: bool = False communicator: str = 'nccl' hf_overrides: dict[str, Any] | None = None + disable_vision_encoder: bool = False enable_metrics: bool = True def __post_init__(self): diff --git a/lmdeploy/pipeline.py b/lmdeploy/pipeline.py index ca4c42bba0..d707e6f487 100644 --- a/lmdeploy/pipeline.py +++ b/lmdeploy/pipeline.py @@ -69,7 +69,7 @@ def __init__(self, # Create inference engine backend, backend_config = autoget_backend_config(model_path, backend_config) - _, pipeline_class = get_task(backend, model_path) + _, pipeline_class = get_task(model_path, backend_config) self.async_engine = pipeline_class(model_path, backend=backend, backend_config=backend_config, diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index f64586fff7..ce76bf3c07 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1485,7 +1485,7 @@ def serve(model_path: str, http_or_https = 'https' handle_torchrun() - _, pipeline_class = get_task(backend, model_path) + _, pipeline_class = get_task(model_path, backend_config) if isinstance(backend_config, PytorchEngineConfig): backend_config.enable_mp_engine = True # router replay diff --git a/lmdeploy/turbomind/deploy/source_model/qwen.py b/lmdeploy/turbomind/deploy/source_model/qwen.py index 2223151e54..4eaf29053d 100644 --- a/lmdeploy/turbomind/deploy/source_model/qwen.py +++ b/lmdeploy/turbomind/deploy/source_model/qwen.py @@ -153,6 +153,17 @@ def moe_ffn_shared_gate(self, i): return self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.shared_expert_gate.weight') +def _configure_nested_language_model_prefix(reader): + """Handle VL checkpoints whose text weights live under + ``model.language_model``.""" + if any(k.startswith('model.language_model.') for k in reader.params.keys()): + reader.attn_layer_prefix = 'model.language_model.layers' + reader.tok_embeddings_key = 'model.language_model.embed_tokens.weight' + reader.norm_weight_key = 'model.language_model.norm.weight' + if reader.model_cfg.get('tie_word_embeddings', False): + reader.output_weight_key = reader.tok_embeddings_key + + @INPUT_MODELS.register_module(name='qwen2-moe') class Qwen2MoeModel(LlamaModel): @@ -172,6 +183,11 @@ def model_info(self): class Qwen3Reader(LlamaReader): + attn_layer_patten = r'(?:model\.language_model\.|model\.)layers\.([0-9]+)\.' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + _configure_nested_language_model_prefix(self) def qk_norm(self, i: int): result = [] @@ -193,6 +209,11 @@ def model_info(self): class Qwen3MoeReader(Qwen2MoeReader): + attn_layer_patten = r'(?:model\.language_model\.|model\.)layers\.([0-9]+)\.' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + _configure_nested_language_model_prefix(self) def qk_norm(self, i: int): result = [] @@ -236,13 +257,7 @@ class Qwen3_5ReaderMixin: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if any(k.startswith('model.language_model.') for k in self.params.keys()): - self.attn_layer_prefix = 'model.language_model.layers' - self.tok_embeddings_key = 'model.language_model.embed_tokens.weight' - self.norm_weight_key = 'model.language_model.norm.weight' - tie_word_embeddings = self.model_cfg.get('tie_word_embeddings', False) - if tie_word_embeddings: - self.output_weight_key = self.tok_embeddings_key + _configure_nested_language_model_prefix(self) # ---- zero-centered RMSNorm: add 1 to weights during export ---- def attn_norm(self, i: int): diff --git a/lmdeploy/turbomind/supported_models.py b/lmdeploy/turbomind/supported_models.py index 732b38c84d..6acd19ec27 100644 --- a/lmdeploy/turbomind/supported_models.py +++ b/lmdeploy/turbomind/supported_models.py @@ -32,6 +32,8 @@ # Qwen3 Qwen3ForCausalLM='qwen3', Qwen3MoeForCausalLM='qwen3-moe', + Qwen3VLForConditionalGeneration='qwen3', + Qwen3VLMoeForConditionalGeneration='qwen3-moe', # Qwen 3.5 Qwen3_5ForConditionalGeneration='qwen3_5', Qwen3_5MoeForConditionalGeneration='qwen3_5-moe', diff --git a/lmdeploy/vl/model/qwen3.py b/lmdeploy/vl/model/qwen3.py index 729c2a333c..fc60678743 100644 --- a/lmdeploy/vl/model/qwen3.py +++ b/lmdeploy/vl/model/qwen3.py @@ -7,16 +7,105 @@ from lmdeploy.utils import get_logger from lmdeploy.vl.constants import Modality from lmdeploy.vl.model.base import VISION_MODELS, VisionModel +from lmdeploy.vl.model.utils import disable_logging logger = get_logger('lmdeploy') -def check_transformers(): +def check_qwen3_vl_deps_install(): + """Check dependencies for Qwen3-VL / Qwen3.5 (same vision stack as + Qwen2-VL's ``check_qwen_vl_deps_install``). + + - **Transformers**: recent build with Qwen3-VL and Qwen3.5 classes (see Qwen3-VL model card on HF). + - **Accelerate**: required for TurboMind split vision loading (`load_checkpoint_and_dispatch`). + - **qwen-vl-utils** (optional): pip package ``qwen-vl-utils``; many upstream Qwen-VL recipes use it for + video helpers. LMDeploy's Qwen3 preprocessor uses ``AutoProcessor`` only; warn if missing so users + can align with `Qwen2VLModel` / official docs when needed. + """ try: - from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration # noqa: F401 + from transformers import ( # noqa: F401 + Qwen3_5ForConditionalGeneration, + Qwen3_5MoeForConditionalGeneration, + Qwen3VLForConditionalGeneration, + Qwen3VLMoeForConditionalGeneration, + ) except ImportError: - raise ImportError('please install latest transformers by ' + raise ImportError('please install a recent transformers with Qwen3-VL / Qwen3.5 support, e.g. ' 'pip install git+https://github.com/huggingface/transformers.git') + try: + import accelerate # noqa: F401 + except ImportError: + raise ImportError('please install accelerate for TurboMind vision loading: pip install accelerate') + try: + import qwen_vl_utils # noqa: F401 + except ImportError: + logger.warning_once( + 'qwen-vl-utils is not installed. Install with `pip install qwen-vl-utils` if you use ' + 'video pipelines or helpers from the Qwen-VL examples (optional for LMDeploy Qwen3 preprocess).') + + +def resolve_qwen_vl_family_automodel(arch: str) -> tuple[type, list[str]]: + """Map HF architecture name to the model class and accelerate no-split + vision block names. + + Qwen3-VL introduced this TurboMind split-vision path; Qwen3.5 reuses the same stack. + """ + if arch == 'Qwen3VLForConditionalGeneration': + from transformers import Qwen3VLForConditionalGeneration as AutoModelCls + + no_split = ['Qwen3VLVisionBlock', 'Qwen3VLMoeVisionBlock'] + elif arch == 'Qwen3VLMoeForConditionalGeneration': + from transformers import Qwen3VLMoeForConditionalGeneration as AutoModelCls + + no_split = ['Qwen3VLVisionBlock', 'Qwen3VLMoeVisionBlock'] + elif arch == 'Qwen3_5ForConditionalGeneration': + from transformers import Qwen3_5ForConditionalGeneration as AutoModelCls + + no_split = ['Qwen3_5VisionBlock', 'Qwen3_5MoeVisionBlock'] + elif arch == 'Qwen3_5MoeForConditionalGeneration': + from transformers import Qwen3_5MoeForConditionalGeneration as AutoModelCls + + no_split = ['Qwen3_5VisionBlock', 'Qwen3_5MoeVisionBlock'] + else: + raise ValueError(f'Unsupported Qwen VL family architecture: {arch}') + return AutoModelCls, no_split + + +def load_qwen_vl_family_vision_backbone( + model_path: str, + hf_config: Any, + with_llm: bool, + max_memory: dict[int, int] | None, +) -> Any: + """Load vision tower only (TurboMind path) for Qwen3-VL and Qwen3.5.""" + arch = hf_config.architectures[0] + AutoModelCls, no_split = resolve_qwen_vl_family_automodel(arch) + + if with_llm: + return AutoModelCls.from_pretrained(model_path, device_map='cpu') + + from accelerate import init_empty_weights, load_checkpoint_and_dispatch + + with init_empty_weights(): + config = hf_config + config.tie_word_embeddings = False + if hasattr(config, 'text_config'): + config.text_config.tie_word_embeddings = False + model = AutoModelCls._from_config(config) + del model.model.language_model + del model.lm_head + model.half() + + with disable_logging(): + load_checkpoint_and_dispatch( + model=model, + checkpoint=model_path, + device_map='auto', + max_memory=max_memory, + no_split_module_classes=no_split, + dtype=torch.half, + ) + return model.model.eval() @VISION_MODELS.register_module() @@ -26,7 +115,7 @@ class Qwen3VLModel(VisionModel): _arch = ['Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration'] def build_preprocessor(self): - check_transformers() + check_qwen3_vl_deps_install() self.processor = AutoProcessor.from_pretrained(self.model_path) # image tokens @@ -167,7 +256,13 @@ def proc_messages(self, messages, chat_template, sequence_start, chat_template_k else: prompt_messages = messages prompt = chat_template.messages2prompt(prompt_messages, sequence_start, **chat_template_kwargs) - return prompt, None + return prompt, self.image_token + + def _ensure_turbomind_image_only(self, inputs: list[dict]): + """TurboMind split vision currently supports image inputs only.""" + has_video = self.contains_video_input or any('video_grid_thw' in item for item in inputs) + if has_video: + raise NotImplementedError('TurboMind split vision for the Qwen3 VL family currently supports images only.') def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequence_start): """Pack the video input to the compatible format with pytorch @@ -229,13 +324,61 @@ def to_pytorch(self, return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start) def build_model(self): - # TODO: implement for turbomind - pass + """Load vision tower for TurboMind split path (Qwen3-VL and Qwen3.5 + share the same stack).""" + loaded = load_qwen_vl_family_vision_backbone(self.model_path, self.hf_config, self.with_llm, + self.max_memory) + if self.with_llm: + self.vl_model = loaded + else: + self.model = loaded @torch.no_grad() def forward(self, messages: list[dict], max_batch_size: int = 1) -> list[dict]: - # TODO: implement for turbomind - pass + """Run vision encoder for TurboMind split path (shared Qwen3 VL + family).""" + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0] + self._ensure_turbomind_image_only(inputs) + dtype = torch.half + device = next(self.model.visual.parameters()).device + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + pixel_values = [x['pixel_values'].type(dtype) for x in inputs[idx:idx + max_batch_size]] + image_grid_thw = [x['image_grid_thw'] for x in inputs[idx:idx + max_batch_size]] + pixel_values = torch.cat(pixel_values, dim=0).to(device) + image_grid_thw = torch.cat(image_grid_thw, dim=0).to(device) + image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw) + if hasattr(image_embeds, 'pooler_output'): + image_embeds = image_embeds.pooler_output + merge_length = self.processor.image_processor.merge_size**2 + split_size = image_grid_thw.prod(dim=1) // merge_length + image_embeds = image_embeds.split(split_size.tolist()) + outputs.extend(image_embeds) + messages.append(dict(role='forward', content=outputs)) + return messages + + @staticmethod + def get_mrope_info(seq_len: int, grid_thws: list[tuple] | None = None, ranges: list[tuple] | None = None): + mrope_position_ids = [torch.arange(ranges[0][0]).expand(3, -1)] + st_idx = ranges[0][0] + for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)): + llm_grid_t, llm_grid_h, llm_grid_w = grid_thw + llm_grid_h //= 2 + llm_grid_w //= 2 + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + mrope_position_ids.append(torch.stack([t_index, h_index, w_index]) + st_idx) + st_idx += max(llm_grid_h, llm_grid_w) + if i < len(ranges) - 1: + text_len = ranges[i + 1][0] - ranges[i][1] + else: + text_len = seq_len - embedding_range[1] + mrope_position_ids.append(torch.arange(text_len).expand(3, -1) + st_idx) + st_idx += text_len + mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) + mrope_position_delta = torch.tensor([st_idx - seq_len], dtype=torch.long) + return mrope_position_ids, mrope_position_delta def to_turbomind(self, messages, @@ -244,5 +387,14 @@ def to_turbomind(self, sequence_start, chat_template_kwargs: dict | None = None, **kwargs): - # TODO: implement for turbomind - pass + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs) + inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0] + self._ensure_turbomind_image_only(inputs) + info = super().to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) + grid_thws = [x['image_grid_thw'].tolist()[0] for x in inputs] + seq_len = len(info['input_ids']) + ranges = info['input_embedding_ranges'] + mrope_position_ids, mrope_position_delta = self.get_mrope_info(seq_len, grid_thws, ranges) + meta = dict(mrope_position_ids=mrope_position_ids, mrope_position_delta=mrope_position_delta) + info.update(dict(input_meta=meta)) + return info diff --git a/lmdeploy/vl/model/qwen3_5.py b/lmdeploy/vl/model/qwen3_5.py index f030de5e5e..8ae3ecb43a 100644 --- a/lmdeploy/vl/model/qwen3_5.py +++ b/lmdeploy/vl/model/qwen3_5.py @@ -1,30 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. from transformers import AutoProcessor -from lmdeploy.utils import get_logger from lmdeploy.vl.model.base import VISION_MODELS -from .qwen3 import Qwen3VLModel - -logger = get_logger('lmdeploy') - - -def check_transformers(): - try: - from transformers import Qwen3_5ForConditionalGeneration, Qwen3_5MoeForConditionalGeneration # noqa: F401 - except ImportError: - raise ImportError('please install latest transformers by ' - 'pip install git+https://github.com/huggingface/transformers.git') +from .qwen3 import Qwen3VLModel, check_qwen3_vl_deps_install @VISION_MODELS.register_module() class Qwen3_5Model(Qwen3VLModel): - """Qwen3_5 model.""" + """Qwen3_5 model (TurboMind vision path is inherited from + `Qwen3VLModel`).""" _arch = ['Qwen3_5ForConditionalGeneration', 'Qwen3_5MoeForConditionalGeneration'] def build_preprocessor(self): - check_transformers() + check_qwen3_vl_deps_install() self.processor = AutoProcessor.from_pretrained(self.model_path) @@ -39,3 +29,4 @@ def build_preprocessor(self): # vision start and end tokens self.vision_start_token = self.processor.vision_start_token self.vision_end_token = self.processor.vision_end_token + self.mm_processor_kwargs = None diff --git a/tests/test_lmdeploy/test_pytorch/test_engine_disable_vision.py b/tests/test_lmdeploy/test_pytorch/test_engine_disable_vision.py new file mode 100644 index 0000000000..41ae9d34ba --- /dev/null +++ b/tests/test_lmdeploy/test_pytorch/test_engine_disable_vision.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +from unittest.mock import MagicMock, patch + +from lmdeploy.messages import PytorchEngineConfig, ResponseType + + +def test_on_add_message_disable_vision_rejects_multimodal(): + """Multimodal input with disable_vision_encoder must error, not strip + inputs.""" + from lmdeploy.pytorch.engine.engine import Engine + from lmdeploy.pytorch.engine.request import Request, RequestType, Response + + engine = Engine.__new__(Engine) + engine.engine_config = PytorchEngineConfig(disable_vision_encoder=True) + engine.input_processor = object() + engine.scheduler = MagicMock() + engine.scheduler.sessions = {1: MagicMock()} + engine.req_manager = MagicMock() + engine._add_message = MagicMock() + + resp = Response(type=ResponseType.SUCCESS, sender_id=0, event=asyncio.Event()) + req = Request( + type=RequestType.ADD_MESSAGE, + sender_id=0, + data={ + 'session_id': 1, + 'token_ids': [1, 2, 3], + 'input_multimodals': [{'image': []}], + 'response': True, + }, + resp=resp, + ) + + captured = [] + + def capture_response(req_manager, resp, resp_type, data, err_msg): + captured.append((resp_type, err_msg)) + + with patch('lmdeploy.pytorch.engine.engine.response_reqs', side_effect=capture_response): + engine._on_add_message([req]) + + assert len(captured) == 1 + assert captured[0][0] == ResponseType.INTERNAL_ENGINE_ERROR + assert 'disable_vision_encoder=True' in captured[0][1] + engine._add_message.assert_not_called() diff --git a/tests/test_lmdeploy/test_vl/test_qwen_vl_family.py b/tests/test_lmdeploy/test_vl/test_qwen_vl_family.py new file mode 100644 index 0000000000..df70eb72b6 --- /dev/null +++ b/tests/test_lmdeploy/test_vl/test_qwen_vl_family.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from types import SimpleNamespace + +import pytest +import torch + +from lmdeploy.archs import get_task +from lmdeploy.messages import TurbomindEngineConfig +from lmdeploy.vl.model.qwen3 import Qwen3VLModel, resolve_qwen_vl_family_automodel + + +@pytest.mark.parametrize('arch,expected_block', [ + ('Qwen3VLForConditionalGeneration', 'Qwen3VLVisionBlock'), + ('Qwen3VLMoeForConditionalGeneration', 'Qwen3VLMoeVisionBlock'), + ('Qwen3_5ForConditionalGeneration', 'Qwen3_5VisionBlock'), + ('Qwen3_5MoeForConditionalGeneration', 'Qwen3_5MoeVisionBlock'), +]) +def test_resolve_qwen_vl_family_automodel(arch, expected_block): + cls, no_split = resolve_qwen_vl_family_automodel(arch) + assert cls is not None + assert expected_block in no_split + + +def test_resolve_unknown_arch_raises(): + with pytest.raises(ValueError, match='Unsupported'): + resolve_qwen_vl_family_automodel('NotAModel') + + +def test_get_task_routes_qwen3_vl_to_vl_engine(monkeypatch): + cfg = SimpleNamespace(to_dict=lambda: {'architectures': ['Qwen3VLForConditionalGeneration']}) + monkeypatch.setattr('lmdeploy.archs.get_model_arch', lambda _path: ('Qwen3VLForConditionalGeneration', cfg)) + + task, pipeline_class = get_task('/fake-model', TurbomindEngineConfig()) + assert task == 'vlm' + assert pipeline_class.__name__ == 'VLAsyncEngine' + + +class _DummyChatTemplate: + + def __init__(self, prompt): + self.prompt = prompt + + def messages2prompt(self, messages, sequence_start, **kwargs): + return self.prompt + + +class _DummyTokenizer: + + def encode(self, text, add_bos=False): + tokens = [] if not text else [len(text)] + if add_bos: + return [0] + tokens + return tokens + + +def _build_qwen3_vl_stub(): + model = Qwen3VLModel.__new__(Qwen3VLModel) + model.image_token = '<|image_pad|>' + model.image_token_id = 151655 + model.contains_video_input = False + return model + + +def test_qwen3_vl_to_turbomind_uses_image_token_placeholder(): + model = _build_qwen3_vl_stub() + tokenizer = _DummyTokenizer() + prompt = 'prefix<|vision_start|><|image_pad|><|vision_end|>suffix' + chat_template = _DummyChatTemplate(prompt) + image_grid_thw = torch.tensor([[1, 2, 2]]) + image_embed = torch.randn(1, 4) + messages = [{ + 'role': 'user', + 'content': [{ + 'type': 'image', + 'data': object(), + }], + }, { + 'role': 'preprocess', + 'content': [{ + 'image_grid_thw': image_grid_thw, + }], + }, { + 'role': 'forward', + 'content': [image_embed], + }] + + info = model.to_turbomind(messages, chat_template, tokenizer, sequence_start=True) + + begin = len(tokenizer.encode('prefix<|vision_start|>', add_bos=True)) + assert info['input_embedding_ranges'] == [[begin, begin + image_embed.shape[0]]] + assert len(info['input_embeddings']) == 1 + assert info['input_meta']['mrope_position_ids'].shape[1] == len(info['input_ids']) + + +def test_qwen3_vl_to_turbomind_rejects_video(): + model = _build_qwen3_vl_stub() + model.contains_video_input = True + messages = [{ + 'role': 'preprocess', + 'content': [{ + 'video_grid_thw': torch.tensor([[1, 2, 2]]), + }], + }] + + with pytest.raises(NotImplementedError, match='supports images only'): + model.to_turbomind(messages, _DummyChatTemplate(''), _DummyTokenizer(), sequence_start=True)