-
Notifications
You must be signed in to change notification settings - Fork 682
feat: implement Turbomind vision encoder support for Qwen3VL/3.5 families #4460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,16 +7,105 @@ | |
| from lmdeploy.utils import get_logger | ||
| from lmdeploy.vl.constants import Modality | ||
| from lmdeploy.vl.model.base import VISION_MODELS, VisionModel | ||
| from lmdeploy.vl.model.utils import disable_logging | ||
|
|
||
| logger = get_logger('lmdeploy') | ||
|
|
||
|
|
||
| def check_transformers(): | ||
| def check_qwen3_vl_deps_install(): | ||
| """Check dependencies for Qwen3-VL / Qwen3.5 (same vision stack as | ||
| Qwen2-VL's ``check_qwen_vl_deps_install``). | ||
|
|
||
| - **Transformers**: recent build with Qwen3-VL and Qwen3.5 classes (see Qwen3-VL model card on HF). | ||
| - **Accelerate**: required for TurboMind split vision loading (`load_checkpoint_and_dispatch`). | ||
| - **qwen-vl-utils** (optional): pip package ``qwen-vl-utils``; many upstream Qwen-VL recipes use it for | ||
| video helpers. LMDeploy's Qwen3 preprocessor uses ``AutoProcessor`` only; warn if missing so users | ||
| can align with `Qwen2VLModel` / official docs when needed. | ||
| """ | ||
| try: | ||
| from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration # noqa: F401 | ||
| from transformers import ( # noqa: F401 | ||
| Qwen3_5ForConditionalGeneration, | ||
| Qwen3_5MoeForConditionalGeneration, | ||
| Qwen3VLForConditionalGeneration, | ||
| Qwen3VLMoeForConditionalGeneration, | ||
| ) | ||
|
Comment on lines 25 to +31
|
||
| except ImportError: | ||
| raise ImportError('please install latest transformers by ' | ||
| raise ImportError('please install a recent transformers with Qwen3-VL / Qwen3.5 support, e.g. ' | ||
| 'pip install git+https://github.com/huggingface/transformers.git') | ||
| try: | ||
| import accelerate # noqa: F401 | ||
| except ImportError: | ||
| raise ImportError('please install accelerate for TurboMind vision loading: pip install accelerate') | ||
| try: | ||
| import qwen_vl_utils # noqa: F401 | ||
| except ImportError: | ||
| logger.warning_once( | ||
| 'qwen-vl-utils is not installed. Install with `pip install qwen-vl-utils` if you use ' | ||
| 'video pipelines or helpers from the Qwen-VL examples (optional for LMDeploy Qwen3 preprocess).') | ||
|
|
||
|
|
||
| def resolve_qwen_vl_family_automodel(arch: str) -> tuple[type, list[str]]: | ||
| """Map HF architecture name to the model class and accelerate no-split | ||
| vision block names. | ||
|
|
||
| Qwen3-VL introduced this TurboMind split-vision path; Qwen3.5 reuses the same stack. | ||
| """ | ||
| if arch == 'Qwen3VLForConditionalGeneration': | ||
| from transformers import Qwen3VLForConditionalGeneration as AutoModelCls | ||
|
|
||
| no_split = ['Qwen3VLVisionBlock', 'Qwen3VLMoeVisionBlock'] | ||
| elif arch == 'Qwen3VLMoeForConditionalGeneration': | ||
| from transformers import Qwen3VLMoeForConditionalGeneration as AutoModelCls | ||
|
|
||
| no_split = ['Qwen3VLVisionBlock', 'Qwen3VLMoeVisionBlock'] | ||
| elif arch == 'Qwen3_5ForConditionalGeneration': | ||
| from transformers import Qwen3_5ForConditionalGeneration as AutoModelCls | ||
|
|
||
| no_split = ['Qwen3_5VisionBlock', 'Qwen3_5MoeVisionBlock'] | ||
| elif arch == 'Qwen3_5MoeForConditionalGeneration': | ||
| from transformers import Qwen3_5MoeForConditionalGeneration as AutoModelCls | ||
|
|
||
| no_split = ['Qwen3_5VisionBlock', 'Qwen3_5MoeVisionBlock'] | ||
| else: | ||
| raise ValueError(f'Unsupported Qwen VL family architecture: {arch}') | ||
| return AutoModelCls, no_split | ||
|
|
||
|
|
||
| def load_qwen_vl_family_vision_backbone( | ||
| model_path: str, | ||
| hf_config: Any, | ||
| with_llm: bool, | ||
| max_memory: dict[int, int] | None, | ||
| ) -> Any: | ||
| """Load vision tower only (TurboMind path) for Qwen3-VL and Qwen3.5.""" | ||
| arch = hf_config.architectures[0] | ||
| AutoModelCls, no_split = resolve_qwen_vl_family_automodel(arch) | ||
|
|
||
| if with_llm: | ||
| return AutoModelCls.from_pretrained(model_path, device_map='cpu') | ||
|
|
||
| from accelerate import init_empty_weights, load_checkpoint_and_dispatch | ||
|
|
||
| with init_empty_weights(): | ||
| config = hf_config | ||
| config.tie_word_embeddings = False | ||
| if hasattr(config, 'text_config'): | ||
| config.text_config.tie_word_embeddings = False | ||
| model = AutoModelCls._from_config(config) | ||
| del model.model.language_model | ||
| del model.lm_head | ||
| model.half() | ||
|
|
||
| with disable_logging(): | ||
| load_checkpoint_and_dispatch( | ||
| model=model, | ||
| checkpoint=model_path, | ||
| device_map='auto', | ||
| max_memory=max_memory, | ||
| no_split_module_classes=no_split, | ||
| dtype=torch.half, | ||
| ) | ||
| return model.model.eval() | ||
|
|
||
|
|
||
| @VISION_MODELS.register_module() | ||
|
|
@@ -26,7 +115,7 @@ class Qwen3VLModel(VisionModel): | |
| _arch = ['Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration'] | ||
|
|
||
| def build_preprocessor(self): | ||
| check_transformers() | ||
| check_qwen3_vl_deps_install() | ||
| self.processor = AutoProcessor.from_pretrained(self.model_path) | ||
|
|
||
| # image tokens | ||
|
|
@@ -167,7 +256,13 @@ def proc_messages(self, messages, chat_template, sequence_start, chat_template_k | |
| else: | ||
| prompt_messages = messages | ||
| prompt = chat_template.messages2prompt(prompt_messages, sequence_start, **chat_template_kwargs) | ||
| return prompt, None | ||
| return prompt, self.image_token | ||
|
|
||
| def _ensure_turbomind_image_only(self, inputs: list[dict]): | ||
| """TurboMind split vision currently supports image inputs only.""" | ||
| has_video = self.contains_video_input or any('video_grid_thw' in item for item in inputs) | ||
| if has_video: | ||
| raise NotImplementedError('TurboMind split vision for the Qwen3 VL family currently supports images only.') | ||
|
|
||
| def to_pytorch_aux_video(self, messages, prompt, VIDEO_TOKEN, tokenizer, sequence_start): | ||
| """Pack the video input to the compatible format with pytorch | ||
|
|
@@ -229,13 +324,61 @@ def to_pytorch(self, | |
| return self.to_pytorch_aux(messages, prompt, self.image_token, tokenizer, sequence_start) | ||
|
|
||
| def build_model(self): | ||
| # TODO: implement for turbomind | ||
| pass | ||
| """Load vision tower for TurboMind split path (Qwen3-VL and Qwen3.5 | ||
| share the same stack).""" | ||
| loaded = load_qwen_vl_family_vision_backbone(self.model_path, self.hf_config, self.with_llm, | ||
| self.max_memory) | ||
| if self.with_llm: | ||
| self.vl_model = loaded | ||
| else: | ||
| self.model = loaded | ||
|
|
||
| @torch.no_grad() | ||
| def forward(self, messages: list[dict], max_batch_size: int = 1) -> list[dict]: | ||
| # TODO: implement for turbomind | ||
| pass | ||
| """Run vision encoder for TurboMind split path (shared Qwen3 VL | ||
| family).""" | ||
| inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0] | ||
| self._ensure_turbomind_image_only(inputs) | ||
| dtype = torch.half | ||
| device = next(self.model.visual.parameters()).device | ||
| outputs = [] | ||
| for idx in range(0, len(inputs), max_batch_size): | ||
| pixel_values = [x['pixel_values'].type(dtype) for x in inputs[idx:idx + max_batch_size]] | ||
| image_grid_thw = [x['image_grid_thw'] for x in inputs[idx:idx + max_batch_size]] | ||
| pixel_values = torch.cat(pixel_values, dim=0).to(device) | ||
| image_grid_thw = torch.cat(image_grid_thw, dim=0).to(device) | ||
| image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw) | ||
| if hasattr(image_embeds, 'pooler_output'): | ||
| image_embeds = image_embeds.pooler_output | ||
| merge_length = self.processor.image_processor.merge_size**2 | ||
| split_size = image_grid_thw.prod(dim=1) // merge_length | ||
| image_embeds = image_embeds.split(split_size.tolist()) | ||
| outputs.extend(image_embeds) | ||
| messages.append(dict(role='forward', content=outputs)) | ||
| return messages | ||
|
|
||
| @staticmethod | ||
| def get_mrope_info(seq_len: int, grid_thws: list[tuple] | None = None, ranges: list[tuple] | None = None): | ||
| mrope_position_ids = [torch.arange(ranges[0][0]).expand(3, -1)] | ||
| st_idx = ranges[0][0] | ||
| for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)): | ||
| llm_grid_t, llm_grid_h, llm_grid_w = grid_thw | ||
| llm_grid_h //= 2 | ||
| llm_grid_w //= 2 | ||
| t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() | ||
| h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() | ||
| w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() | ||
| mrope_position_ids.append(torch.stack([t_index, h_index, w_index]) + st_idx) | ||
| st_idx += max(llm_grid_h, llm_grid_w) | ||
| if i < len(ranges) - 1: | ||
| text_len = ranges[i + 1][0] - ranges[i][1] | ||
| else: | ||
| text_len = seq_len - embedding_range[1] | ||
| mrope_position_ids.append(torch.arange(text_len).expand(3, -1) + st_idx) | ||
| st_idx += text_len | ||
| mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) | ||
| mrope_position_delta = torch.tensor([st_idx - seq_len], dtype=torch.long) | ||
| return mrope_position_ids, mrope_position_delta | ||
|
|
||
| def to_turbomind(self, | ||
| messages, | ||
|
|
@@ -244,5 +387,14 @@ def to_turbomind(self, | |
| sequence_start, | ||
| chat_template_kwargs: dict | None = None, | ||
| **kwargs): | ||
| # TODO: implement for turbomind | ||
| pass | ||
| prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs) | ||
| inputs = [x['content'] for x in messages if x['role'] == 'preprocess'][0] | ||
| self._ensure_turbomind_image_only(inputs) | ||
| info = super().to_turbomind_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) | ||
| grid_thws = [x['image_grid_thw'].tolist()[0] for x in inputs] | ||
| seq_len = len(info['input_ids']) | ||
| ranges = info['input_embedding_ranges'] | ||
| mrope_position_ids, mrope_position_delta = self.get_mrope_info(seq_len, grid_thws, ranges) | ||
| meta = dict(mrope_position_ids=mrope_position_ids, mrope_position_delta=mrope_position_delta) | ||
| info.update(dict(input_meta=meta)) | ||
| return info | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When `disable_vision_encoder` is set, `get_task()` routes VL architectures to the plain `AsyncEngine`. `AsyncEngine` constructs `MultimodalProcessor` without a `vl_encoder`, and `MultimodalProcessor.get_prompt_input()` will then treat multimodal messages as text-only and silently drop image/video blocks (it only joins `type=='text'`). If the intended contract is to reject multimodal inputs when vision is disabled, this needs an explicit error path (or a different engine selection) to avoid silent data loss.