77from lmdeploy .utils import get_logger
88from lmdeploy .vl .constants import Modality
99from lmdeploy .vl .model .base import VISION_MODELS , VisionModel
10+ from lmdeploy .vl .model .utils import disable_logging
1011
1112logger = get_logger ('lmdeploy' )
1213
1314
def check_qwen3_vl_deps_install():
    """Check dependencies for Qwen3-VL / Qwen3.5 (same vision stack as
    Qwen2-VL's ``check_qwen_vl_deps_install``).

    - **Transformers**: recent build with Qwen3-VL and Qwen3.5 classes (see Qwen3-VL model card on HF).
    - **Accelerate**: required for TurboMind split vision loading (`load_checkpoint_and_dispatch`).
    - **qwen-vl-utils** (optional): pip package ``qwen-vl-utils``; many upstream Qwen-VL recipes use it for
      video helpers. LMDeploy's Qwen3 preprocessor uses ``AutoProcessor`` only; warn if missing so users
      can align with `Qwen2VLModel` / official docs when needed.

    Raises:
        ImportError: when transformers lacks the Qwen3-VL/Qwen3.5 classes, or
            accelerate is not installed. The original import failure is kept
            as ``__cause__`` via exception chaining so users see the real error.
    """
    try:
        from transformers import (  # noqa: F401
            Qwen3_5ForConditionalGeneration,
            Qwen3_5MoeForConditionalGeneration,
            Qwen3VLForConditionalGeneration,
            Qwen3VLMoeForConditionalGeneration,
        )
    except ImportError as err:
        # Chain the original error (PEP 3134): distinguishes "transformers not
        # installed" from "installed but too old to have these classes".
        raise ImportError('please install a recent transformers with Qwen3-VL / Qwen3.5 support, e.g. '
                          'pip install git+https://github.com/huggingface/transformers.git') from err
    try:
        import accelerate  # noqa: F401
    except ImportError as err:
        raise ImportError('please install accelerate for TurboMind vision loading: pip install accelerate') from err
    try:
        import qwen_vl_utils  # noqa: F401
    except ImportError:
        # Optional dependency: warn once instead of failing.
        logger.warning_once(
            'qwen-vl-utils is not installed. Install with `pip install qwen-vl-utils` if you use '
            'video pipelines or helpers from the Qwen-VL examples (optional for LMDeploy Qwen3 preprocess).')
45+
46+
47+ def resolve_qwen_vl_family_automodel (arch : str ) -> tuple [type , list [str ]]:
48+ """Map HF architecture name to the model class and accelerate no-split
49+ vision block names.
50+
51+ Qwen3-VL introduced this TurboMind split-vision path; Qwen3.5 reuses the same stack.
52+ """
53+ if arch == 'Qwen3VLForConditionalGeneration' :
54+ from transformers import Qwen3VLForConditionalGeneration as AutoModelCls
55+
56+ no_split = ['Qwen3VLVisionBlock' , 'Qwen3VLMoeVisionBlock' ]
57+ elif arch == 'Qwen3VLMoeForConditionalGeneration' :
58+ from transformers import Qwen3VLMoeForConditionalGeneration as AutoModelCls
59+
60+ no_split = ['Qwen3VLVisionBlock' , 'Qwen3VLMoeVisionBlock' ]
61+ elif arch == 'Qwen3_5ForConditionalGeneration' :
62+ from transformers import Qwen3_5ForConditionalGeneration as AutoModelCls
63+
64+ no_split = ['Qwen3_5VisionBlock' , 'Qwen3_5MoeVisionBlock' ]
65+ elif arch == 'Qwen3_5MoeForConditionalGeneration' :
66+ from transformers import Qwen3_5MoeForConditionalGeneration as AutoModelCls
67+
68+ no_split = ['Qwen3_5VisionBlock' , 'Qwen3_5MoeVisionBlock' ]
69+ else :
70+ raise ValueError (f'Unsupported Qwen VL family architecture: { arch } ' )
71+ return AutoModelCls , no_split
72+
73+
def load_qwen_vl_family_vision_backbone(
    model_path: str,
    hf_config: Any,
    with_llm: bool,
    max_memory: dict[int, int] | None,
) -> Any:
    """Load vision tower only (TurboMind path) for Qwen3-VL and Qwen3.5.

    Args:
        model_path: local path or HF repo id of the checkpoint.
        hf_config: parsed HF config; ``architectures[0]`` selects the model
            class. NOTE: mutated in place below (``tie_word_embeddings`` is
            forced off, including on ``text_config`` when present).
        with_llm: when True, load the full model (vision + LLM) on CPU and
            return it as-is instead of dispatching only the vision tower.
        max_memory: per-device memory budget forwarded to accelerate, or None.

    Returns:
        The full CPU model when ``with_llm`` is True; otherwise the dispatched
        ``model.model`` (LLM parts removed) in eval mode, weights in fp16.
    """
    arch = hf_config.architectures[0]
    AutoModelCls, no_split = resolve_qwen_vl_family_automodel(arch)

    if with_llm:
        # Caller wants the whole model (vision + language) kept on CPU.
        return AutoModelCls.from_pretrained(model_path, device_map='cpu')

    from accelerate import init_empty_weights, load_checkpoint_and_dispatch

    with init_empty_weights():
        # Untie embeddings so deleting the language model / lm_head below does
        # not leave shared parameters dangling when weights are dispatched.
        config = hf_config
        config.tie_word_embeddings = False
        if hasattr(config, 'text_config'):
            config.text_config.tie_word_embeddings = False
        # Build the module skeleton on the meta device (no real allocations).
        model = AutoModelCls._from_config(config)
        # Only the vision tower is served by this path; drop the LLM parts
        # before dispatch so their checkpoints are never materialized.
        del model.model.language_model
        del model.lm_head
        model.half()

    with disable_logging():
        # Load real weights and place remaining (vision) modules across
        # devices; vision blocks are kept whole via no_split_module_classes.
        load_checkpoint_and_dispatch(
            model=model,
            checkpoint=model_path,
            device_map='auto',
            max_memory=max_memory,
            no_split_module_classes=no_split,
            dtype=torch.half,
        )
    return model.model.eval()
20109
21110
22111@VISION_MODELS .register_module ()
@@ -26,7 +115,7 @@ class Qwen3VLModel(VisionModel):
26115 _arch = ['Qwen3VLForConditionalGeneration' , 'Qwen3VLMoeForConditionalGeneration' ]
27116
28117 def build_preprocessor (self ):
29- check_transformers ()
118+ check_qwen3_vl_deps_install ()
30119 self .processor = AutoProcessor .from_pretrained (self .model_path )
31120
32121 # image tokens
@@ -229,13 +318,60 @@ def to_pytorch(self,
229318 return self .to_pytorch_aux (messages , prompt , self .image_token , tokenizer , sequence_start )
230319
231320 def build_model (self ):
232- # TODO: implement for turbomind
233- pass
321+ """Load vision tower for TurboMind split path (Qwen3-VL and Qwen3.5
322+ share the same stack)."""
323+ loaded = load_qwen_vl_family_vision_backbone (self .model_path , self .hf_config , self .with_llm ,
324+ self .max_memory )
325+ if self .with_llm :
326+ self .vl_model = loaded
327+ else :
328+ self .model = loaded
234329
235330 @torch .no_grad ()
236331 def forward (self , messages : list [dict ], max_batch_size : int = 1 ) -> list [dict ]:
237- # TODO: implement for turbomind
238- pass
332+ """Run vision encoder for TurboMind split path (shared Qwen3 VL
333+ family)."""
334+ inputs = [x ['content' ] for x in messages if x ['role' ] == 'preprocess' ][0 ]
335+ dtype = torch .half
336+ device = next (self .model .visual .parameters ()).device
337+ outputs = []
338+ for idx in range (0 , len (inputs ), max_batch_size ):
339+ pixel_values = [x ['pixel_values' ].type (dtype ) for x in inputs [idx :idx + max_batch_size ]]
340+ image_grid_thw = [x ['image_grid_thw' ] for x in inputs [idx :idx + max_batch_size ]]
341+ pixel_values = torch .cat (pixel_values , dim = 0 ).to (device )
342+ image_grid_thw = torch .cat (image_grid_thw , dim = 0 ).to (device )
343+ image_embeds = self .model .visual (pixel_values , grid_thw = image_grid_thw )
344+ if hasattr (image_embeds , 'pooler_output' ):
345+ image_embeds = image_embeds .pooler_output
346+ merge_length = self .processor .image_processor .merge_size ** 2
347+ split_size = image_grid_thw .prod (dim = 1 ) // merge_length
348+ image_embeds = image_embeds .split (split_size .tolist ())
349+ outputs .extend (image_embeds )
350+ messages .append (dict (role = 'forward' , content = outputs ))
351+ return messages
352+
    @staticmethod
    def get_mrope_info(seq_len: int, grid_thws: list[tuple] | None = None, ranges: list[tuple] | None = None):
        """Compute multimodal RoPE (mrope) position ids for a text+image sequence.

        Args:
            seq_len: total token length of the sequence (text + image tokens).
            grid_thws: per-image ``(t, h, w)`` patch grids, one per image.
            ranges: per-image ``(start, end)`` token ranges of the image
                embeddings inside the sequence, in ascending order.

        Returns:
            tuple: ``(mrope_position_ids, mrope_position_delta)`` — a
            ``(3, seq_len)`` tensor of (temporal, height, width) positions and
            a 1-element long tensor holding ``st_idx - seq_len``.

        Note:
            Despite the ``None`` defaults, both ``grid_thws`` and ``ranges``
            are required: ``ranges[0][0]`` is indexed immediately, so passing
            ``None`` raises ``TypeError``.
        """
        # Text tokens before the first image: plain 0..start-1 on all 3 axes.
        mrope_position_ids = [torch.arange(ranges[0][0]).expand(3, -1)]
        st_idx = ranges[0][0]
        for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)):
            llm_grid_t, llm_grid_h, llm_grid_w = grid_thw
            # Spatial merge factor of 2: each LLM token covers a 2x2 patch block.
            llm_grid_h //= 2
            llm_grid_w //= 2
            # Per-axis position grids for the image tokens, flattened in
            # (t, h, w) order and offset by the running start index.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            mrope_position_ids.append(torch.stack([t_index, h_index, w_index]) + st_idx)
            # NOTE(review): advances by max(h, w) only; confirm t should not
            # participate for multi-frame (video) grids, cf. Qwen2-VL mrope.
            st_idx += max(llm_grid_h, llm_grid_w)
            if i < len(ranges) - 1:
                # Text run between this image and the next one.
                text_len = ranges[i + 1][0] - ranges[i][1]
            else:
                # Trailing text after the last image.
                text_len = seq_len - embedding_range[1]
            mrope_position_ids.append(torch.arange(text_len).expand(3, -1) + st_idx)
            st_idx += text_len
        mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
        # Delta the engine applies to positions generated after this prompt —
        # presumably consumed by the decode-phase rotary embedding; verify
        # against the TurboMind consumer of `input_meta`.
        mrope_position_delta = torch.tensor([st_idx - seq_len], dtype=torch.long)
        return mrope_position_ids, mrope_position_delta
239375
240376 def to_turbomind (self ,
241377 messages ,
@@ -244,5 +380,13 @@ def to_turbomind(self,
244380 sequence_start ,
245381 chat_template_kwargs : dict | None = None ,
246382 ** kwargs ):
247- # TODO: implement for turbomind
248- pass
383+ prompt , IMAGE_TOKEN = self .proc_messages (messages , chat_template , sequence_start , chat_template_kwargs )
384+ info = super ().to_turbomind_aux (messages , prompt , IMAGE_TOKEN , tokenizer , sequence_start )
385+ inputs = [x ['content' ] for x in messages if x ['role' ] == 'preprocess' ][0 ]
386+ grid_thws = [x ['image_grid_thw' ].tolist ()[0 ] for x in inputs ]
387+ seq_len = len (info ['input_ids' ])
388+ ranges = info ['input_embedding_ranges' ]
389+ mrope_position_ids , mrope_position_delta = self .get_mrope_info (seq_len , grid_thws , ranges )
390+ meta = dict (mrope_position_ids = mrope_position_ids , mrope_position_delta = mrope_position_delta )
391+ info .update (dict (input_meta = meta ))
392+ return info
0 commit comments