diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index 68fa03a407..74b0935dcf 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -110,9 +110,9 @@ def check_vl_llm(backend: str, config: dict) -> bool: 'InternVLChatModel', 'MiniCPMV', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration', 'Qwen3_5ForConditionalGeneration', - 'Qwen3_5MoeForConditionalGeneration', 'MllamaForConditionalGeneration', 'MolmoForCausalLM', - 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', 'InternVLForConditionalGeneration', - 'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration', + 'Qwen3_5MoeForConditionalGeneration', 'Qwen3OmniMoeForConditionalGeneration', 'MllamaForConditionalGeneration', + 'MolmoForCausalLM', 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', + 'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration', 'InternS1ProForConditionalGeneration', 'InternS1_1_ForConditionalGeneration', 'Glm4vForConditionalGeneration' ]) if arch == 'QWenLMHeadModel' and 'visual' in config: diff --git a/lmdeploy/model.py b/lmdeploy/model.py index dbd8939ecf..d0ab8f2be0 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -688,8 +688,19 @@ class HFChatTemplate(BaseChatTemplate): def __init__(self, model_path: str = '', **kwargs): self.model_path = model_path try: - from transformers import AutoTokenizer + from transformers import AutoProcessor, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Some tokenizers do not have chat_template, in this case try to get chat_template from processor + # If this still does not work, fallback to BaseChatTemplate. 
+ if getattr(self.tokenizer, 'chat_template', None) is None: + try: + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + self.tokenizer.chat_template = getattr(processor, 'chat_template', None) + except Exception as e: + logger.warning(f'Failed to load processor from {model_path} for chat template. ' + f'Fallback to tokenizer only. Error: {e}') + # Verify if the model can perform apply_chat_template with different roles. self.user_start, self.user_end, _, _ = self._user_instruction() self.assistant_start, self.assistant_end, _, _ = self._assistant_instruction() diff --git a/lmdeploy/pytorch/configurations/qwen3_omni.py b/lmdeploy/pytorch/configurations/qwen3_omni.py new file mode 100644 index 0000000000..9d1e3ec61c --- /dev/null +++ b/lmdeploy/pytorch/configurations/qwen3_omni.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .builder import AutoModelConfigBuilder +from .default import DefaultModelConfigBuilder + + +class Qwen3OmniModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.model_type == 'qwen3_omni_moe' + + @classmethod + def build(cls, hf_config, model_path: str = None, **kwargs): + """build.""" + cfg = DefaultModelConfigBuilder.build(hf_config.thinker_config.text_config, model_path, **kwargs) + cfg.hf_config = hf_config + cfg.use_mrope = True + return cfg diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index f2f21eb93b..847f74d6cf 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -190,6 +190,13 @@ 'Qwen3_5MTPModel': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_5_mtp.Qwen3_5MTPModel', }) +# qwen3 omni moe thinker +# only support thinker module, so map to Qwen3OmniMoeThinkerForConditionalGeneration +MODULE_MAP.update({ + 'Qwen3OmniMoeForConditionalGeneration': + 
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_omni_moe_thinker.Qwen3OmniMoeThinkerForConditionalGeneration', +}) + # starcoder2 MODULE_MAP.update({ 'Starcoder2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.Starcoder2ForCausalLM', diff --git a/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py new file mode 100644 index 0000000000..4d5d34091a --- /dev/null +++ b/lmdeploy/pytorch/models/qwen3_omni_moe_thinker.py @@ -0,0 +1,909 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import math +from collections.abc import Iterable, Sequence +from functools import lru_cache +from typing import Any + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalData +from lmdeploy.pytorch.nn import ApplyRotaryEmb, FlashAttention, LayerNorm +from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_qkv_proj, build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight +from lmdeploy.vl.constants import Modality + +from .qwen3_vl import Qwen3VLVisionBlock, Qwen3VLVisionPatchEmbed, Qwen3VLVisionRotaryEmbedding +from .qwen3_vl_moe import Qwen3VLMoeTextModel +from .utils.cudagraph import CudaGraphMixin +from .utils.model import DeployModelMixin, vlm_model + + +def _get_feat_extract_output_lengths(input_lengths): + """Computes the output length of the convolutional layers and the output + length of the audio encoder.""" + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + 
return output_lengths + + +class Qwen3OmniMoeAudioAttention(nn.Module): + """Vision attention.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + dim = config.d_model + num_heads = config.encoder_attention_heads + head_dim = dim // num_heads + self.head_dim = head_dim + + # packed qkv + self.qkv_proj = build_qkv_proj( + dim, + num_q_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + ) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + # attention + self.attention = FlashAttention( + num_heads, + head_dim, + causal=False, + ) + + # o_proj + self.out_proj = build_rowwise_linear(dim, + dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + q, k, v = self.qkv_proj.split_qkv(qkv_states) + + attn_output = self.attention( + q, + k, + v, + q_start_loc=cu_seqlens[:-1], + q_seqlens=cu_seqlens[1:] - cu_seqlens[:-1], + ) + + attn_output = attn_output.reshape(seq_length, -1) + + # o proj + attn_output = self.out_proj(attn_output) + return attn_output + + +class Qwen3OmniMoeAudioEncoderLayer(nn.Module): + """Qwen3OmniMoeAudioEncoderLayer.""" + + def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None: + super().__init__() + self.embed_dim = config.d_model + self.self_attn = Qwen3OmniMoeAudioAttention(config, dtype=dtype, device=device) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=1e-5, dtype=dtype, device=device) + + self.activation_fn = 
ACT2FN[config.activation_function] + self.fc1 = build_colwise_linear( + self.embed_dim, + config.encoder_ffn_dim, + bias=True, + dtype=dtype, + device=device, + ) + self.fc2 = build_rowwise_linear( + config.encoder_ffn_dim, + self.embed_dim, + bias=True, + dtype=dtype, + device=device, + ) + self.final_layer_norm = LayerNorm(self.embed_dim, eps=1e-5, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states, ) + + return outputs + + +class SinusoidsPositionEmbedding(nn.Module): + + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError('SinusoidsPositionEmbedding needs even channels input') + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + self.register_buffer( + 'positional_embedding', + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class Qwen3OmniMoeAudioEncoder(nn.Module): + """Qwen3OmniMoeAudioEncoder.""" + + def 
__init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None: + super().__init__() + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) + self.layers = nn.ModuleList( + [Qwen3OmniMoeAudioEncoderLayer(config, dtype=dtype, device=device) for _ in range(config.encoder_layers)]) + self.ln_post = LayerNorm(config.d_model, eps=1e-5, dtype=dtype, device=device) + self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1, dtype=dtype, device=device) + self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, + config.downsample_hidden_size, + 3, + 2, + padding=1, + dtype=dtype, + device=device) + self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, + config.downsample_hidden_size, + 3, + 2, + padding=1, + dtype=dtype, + device=device) + conv_out_dim = config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2) + self.conv_out = nn.Linear( + conv_out_dim, + config.d_model, + bias=False, + dtype=dtype, + device=device, + ) + self.proj1 = nn.Linear(config.d_model, config.d_model, dtype=dtype, device=device) + self.act = ACT2FN[config.activation_function] + self.proj2 = nn.Linear(config.d_model, config.output_dim, dtype=dtype, device=device) + self.n_window_infer = config.n_window_infer + self.conv_chunksize = config.conv_chunksize + + def forward( + self, + input_features: torch.Tensor, + feature_lens: torch.Tensor, + aftercnn_lens=None, + ): + r"""feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + + mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length after cnn + """ + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + chunk_num = torch.ceil(feature_lens / 
(self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths[chunk_lengths == 0] = self.n_window * 2 + + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + padded_mask_after_cnn = nn.utils.rnn.pad_sequence( + [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], + batch_first=True, + ) + padded_feature = padded_feature.unsqueeze(1) + # Split to chunk to avoid OOM during convolution + padded_embeds = [] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() + padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) + + positional_embedding = ( + self.positional_embedding.positional_embedding[:padded_embed.shape[1], :].unsqueeze(0).to( + padded_embed.dtype)) + padded_embed = padded_embed + positional_embedding + hidden_states = padded_embed[padded_mask_after_cnn] + cu_chunk_lens = [0] + window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + + 
for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.ln_post(hidden_states) + hidden_states = self.proj1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.proj2(hidden_states) + return hidden_states + + +class Qwen3OmniMoeVisionPatchMerger(nn.Module): + """Vision patch merger. + + Different namings with qwen3vl, but actual calculations are the same. + """ + + def __init__(self, + config: PretrainedConfig, + use_postshuffle_norm=False, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, + eps=1e-6, + dtype=dtype, + device=device) + self.mlp = nn.ModuleList([ + build_colwise_linear( + self.hidden_size, + self.hidden_size, + bias=True, + dtype=dtype, + device=device, + is_tp=True, + ), + nn.GELU(), + build_rowwise_linear( + self.hidden_size, + config.out_hidden_size, + bias=True, + dtype=dtype, + device=device, + is_tp=True, + ), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + for layer in self.mlp: + x = layer(x) + return x + + +@vlm_model +class Qwen3OmniMoeVisionEncoder(nn.Module): + """Vision transformer.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed(config=config, dtype=dtype, device=device) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device) + self.num_grid_per_side = 
int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2, device=device) + + self.blocks = nn.ModuleList( + [Qwen3VLVisionBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.depth)]) + self.merger = Qwen3OmniMoeVisionPatchMerger(config=config, + use_postshuffle_norm=False, + dtype=dtype, + device=device) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.merger_list = nn.ModuleList([ + Qwen3OmniMoeVisionPatchMerger(config=config, use_postshuffle_norm=True, dtype=dtype, device=device) + for _ in range(len(config.deepstack_visual_indexes)) + ]) + + @staticmethod + @lru_cache(maxsize=1024) + def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + h_div = h // spatial_merge_size + w_div = w // spatial_merge_size + + hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)) + hpos_ids = hpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + hpos_ids = hpos_ids.transpose(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w)) + wpos_ids = wpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + wpos_ids = wpos_ids.transpose(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + + return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1)) + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + """Rotary position embedding.""" + pos_ids = [] + + for t, h, w in grid_thw: + base = self.rot_pos_ids(int(h), int(w), self.spatial_merge_size) + pos_ids.append(base if t == 1 else base.repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + return rotary_pos_emb + + # copy from 
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474 + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + device = self.pos_embed.weight.device + + outputs = [] + for t, h, w in grid_thw: + h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device) + w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij') + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij') + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij') + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - w01 + + h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) + w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) + h_grid_idx = h_grid * num_grid_per_side + + indices = (h_grid_idx + w_grid).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.pos_embed.weight.dtype, device=device) + + embeds = self.pos_embed(indices) + embeds *= weights + combined = embeds.sum(dim=0) + + combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim) + combined = 
combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) + repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, + pos_embeds: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states + pos_embeds + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + if layer_num in self.deepstack_visual_indexes: + deepstack_merge_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.merger_list[deepstack_merge_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +class Qwen3OmniMoeThinkerForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + thinker_config = config.thinker_config + + # build preprocessor + self.input_processor = Qwen3OmniInputProcessor(self.config) + + # build audio encoder + self.audio_tower = Qwen3OmniMoeAudioEncoder( + thinker_config.audio_config, + dtype=dtype, + device=device, + ) + + # build vision encoder + self.visual = Qwen3OmniMoeVisionEncoder( + thinker_config.vision_config, + dtype=dtype, + device=device, + ) + + # build text model + self.language_model = 
Qwen3VLMoeTextModel(thinker_config.text_config, dtype=dtype, device=device) + + # build lm_head + self.lm_head = build_rowwise_linear(thinker_config.text_config.hidden_size, + thinker_config.text_config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: list[list[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + mrope_position_ids: torch.Tensor = None, + pixel_values: torch.Tensor = None, + vis_cu_seqlens: torch.Tensor = None, + vis_pos_emb: torch.Tensor = None, + image_mask: torch.Tensor = None, + pos_embeds: torch.Tensor = None, + grid_thw: torch.Tensor = None, + audio_values: torch.Tensor = None, + audio_mask: torch.Tensor = None, + audio_feature_lengths: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + + visual_pos_masks = None + deepstack_visual_embeds = None + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + dtype = inputs_embeds.dtype + pixel_values = pixel_values.to(dtype) + vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype)) + + # get image embeds and deepstack visual embeds + image_embeds, deepstack_visual_embeds = self.visual(pixel_values, + cu_seqlens=vis_cu_seqlens, + rotary_pos_emb=vis_pos_emb, + pos_embeds=pos_embeds) + + # split image embeds per sample + split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) + + # mask and scatter to create final input embeddings + expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) + + visual_pos_masks = expanded_image_mask + + if audio_values is not None: + dtype = inputs_embeds.dtype + audio_values = 
audio_values.to(dtype) + audio_embeds = self.audio_tower( + input_features=audio_values, + feature_lens=audio_feature_lengths, + ) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask.unsqueeze(-1), audio_embeds) + + hidden_states = self.language_model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + mrope_position_ids=mrope_position_ids, + # args for deepstack + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.language_model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: list[list[torch.Tensor]], + inputs_embeds: torch.Tensor | None = None, + context: StepContext = None, + ): + """Prepare input.""" + + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + pixel_values = None + vis_cu_seqlens = None + vis_pos_emb = None + image_mask = None + grid_thw = None + pos_embeds = None + audio_values = None + audio_mask = None + audio_feature_lengths = None + if context.input_multimodals is not None: + mm_inputs = [input_mm.get('mm_data', []) for input_mm in context.input_multimodals] + # flatten batch + mm_inputs = [item for sublist in mm_inputs for item in sublist] + + if len(mm_inputs) > 0: + modality = mm_inputs[0].modality + + image_token_id = mm_inputs[0].meta.get('image_token_id') + video_token_id = mm_inputs[0].meta.get('video_token_id') + audio_token_id = mm_inputs[0].meta.get('audio_token_id') + + if modality == Modality.AUDIO: + audio_values = torch.cat([inp.data for inp in mm_inputs]) + # FIXME: zhouxinyu, batch ? 
+ audio_values = audio_values.squeeze(0) + audio_mask = (input_ids == audio_token_id) + # FIXME: zhouxinyu, list ? + audio_feature_lengths = mm_inputs[0].meta['audio_feature_lengths'] + elif modality in [Modality.IMAGE, Modality.VIDEO]: + pixel_values = torch.cat([inp.data for inp in mm_inputs]) + + mm_token_id = image_token_id if modality == Modality.IMAGE else video_token_id + image_mask = (input_ids == mm_token_id) + + grid_thw = torch.cat([data.meta['grid_thw'] for data in mm_inputs]).cpu() + vis_pos_emb = self.visual.rot_pos_emb(grid_thw) + pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw) + vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).to(pixel_values.device) + vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32) + vis_pos_emb = vis_pos_emb.repeat(1, 2) + vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin()) + + mrope_position_ids = getattr(context, 'mrope_position_ids', None) + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + mrope_position_ids=mrope_position_ids, + pixel_values=pixel_values, + vis_cu_seqlens=vis_cu_seqlens, + vis_pos_emb=vis_pos_emb, + image_mask=image_mask, + grid_thw=grid_thw, + pos_embeds=pos_embeds, + audio_values=audio_values, + audio_mask=audio_mask, + audio_feature_lengths=audio_feature_lengths, + ) + + def rename_weight(self, name: str) -> str: + """Rename weight.""" + if name.startswith('thinker.model.'): + return 'language_model.' 
+ name[len('thinker.model.'):] + elif name.startswith('thinker.visual.'): + return 'visual.' + name[len('thinker.visual.'):] + elif name.startswith('thinker.audio_tower.'): + return 'audio_tower.' + name[len('thinker.audio_tower.'):] + # thinker_config.text_config tie_word_embeddings = False + elif name.startswith('thinker.lm_head.'): + return 'lm_head.' + name[len('thinker.lm_head.'):] + return name + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter], + expert_params_mapping: list): + """Load weight experts.""" + + for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + # expert mapping + num_experts = self.config.thinker_config.text_config.num_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + # (param_name, weight_name, expert_id, shard_id) + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' 
class Qwen3OmniInputProcessor(BaseModelInputProcessor):
    """Qwen3 Omni input processor.

    Converts the raw multimodal preprocessing results (image / video /
    audio) produced by the vision model into ``MultiModalData`` items
    consumed by the pytorch engine.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        self.config = config

    @staticmethod
    def _as_int(value) -> int:
        """Normalize a scalar that may arrive as a 0-dim ``torch.Tensor``."""
        return value.item() if isinstance(value, torch.Tensor) else value

    @classmethod
    def _get_multimodal_pos_ids(cls, grid_thw: Sequence[int], merge_size: int = 2) -> np.ndarray:
        """Compute mrope (t, h, w) position ids for one vision item.

        Args:
            grid_thw: ``(t, h, w)`` patch grid of the item, pre-merge.
            merge_size: spatial merge factor of the vision tower; the
                effective grid is ``(t, h // merge_size, w // merge_size)``.
                Defaults to 2, matching the previous hard-coded value.

        Returns:
            int array of shape ``(t * h' * w', 3)`` with per-token 3D positions.
        """
        t, h, w = grid_thw
        h = h // merge_size
        w = w // merge_size
        stride = np.array([h * w, w, 1])[None]
        size = np.array([t, h, w])[None]
        pos_ids = np.arange(t * h * w)[:, None].repeat(3, axis=1)
        # decompose the flat token index into (t, h, w) coordinates
        pos_ids = pos_ids // stride % size
        return pos_ids

    @classmethod
    def make_mrope(cls, grid_thw: torch.Tensor) -> np.ndarray:
        """Build mrope position ids from the first grid in `grid_thw`."""
        return cls._get_multimodal_pos_ids(grid_thw[0].tolist())

    def _make_image_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData:
        """Make image MultiModalData."""
        pixel_values = input_mm['pixel_values']
        image_grid_thw = input_mm['image_grid_thw']
        start = input_mm['offset']
        image_token_id = input_mm['image_token_id']
        num_pad = self._as_int(input_mm['mm_token_num'])
        mrope_pos_ids = self.make_mrope(image_grid_thw)
        return MultiModalData(modality=Modality.IMAGE,
                              data=pixel_values,
                              start=start,
                              end=start + num_pad,
                              mrope_pos_ids=mrope_pos_ids,
                              meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id))

    def _make_video_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData:
        """Make video MultiModalData."""
        pixel_values_videos = input_mm['pixel_values_videos']
        video_grid_thw = input_mm['video_grid_thw']
        start = input_mm['offset']
        video_token_id = input_mm['video_token_id']
        num_pad = self._as_int(input_mm['mm_token_num'])
        mrope_pos_ids = self.make_mrope(video_grid_thw)
        return MultiModalData(modality=Modality.VIDEO,
                              data=pixel_values_videos,
                              start=start,
                              end=start + num_pad,
                              mrope_pos_ids=mrope_pos_ids,
                              meta=dict(
                                  grid_thw=video_grid_thw,
                                  video_token_id=video_token_id,
                                  second_per_grid=input_mm.get('second_per_grid'),
                              ))

    def _make_audio_mm_data(self, input_mm: dict[str, Any]) -> MultiModalData:
        """Make audio MultiModalData (no mrope ids: audio positions are 1-D)."""
        input_features = input_mm['input_features']
        start = input_mm['offset']
        audio_token_id = input_mm['audio_token_id']
        num_pad = self._as_int(input_mm['mm_token_num'])
        return MultiModalData(modality=Modality.AUDIO,
                              data=input_features,
                              start=start,
                              end=start + num_pad,
                              meta=dict(
                                  audio_token_id=audio_token_id,
                                  audio_feature_lengths=input_mm.get('audio_feature_lengths'),
                              ))

    def preprocess_input(self,
                         input_ids: list[int],
                         input_multimodals: list[dict[str, Any]] = None,
                         **kwargs) -> PreprocessInputResult:
        """Prepare multimodal input.

        Args:
            input_ids: token ids of the prompt, placeholders included.
            input_multimodals: per-item preprocessing dicts, each tagged
                with a ``'modality'`` key.

        Raises:
            ValueError: if an item carries an unsupported modality.
        """
        if input_multimodals is None or len(input_multimodals) == 0:
            # NOTE(review): this early exit returns a bare tuple while the
            # normal path returns PreprocessInputResult — presumably callers
            # accept both; confirm against the engine's call site.
            return input_ids, input_multimodals

        builders = {
            Modality.IMAGE: self._make_image_mm_data,
            Modality.VIDEO: self._make_video_mm_data,
            Modality.AUDIO: self._make_audio_mm_data,
        }

        input_mm_data = []
        for input_mm in input_multimodals:
            modality = input_mm.get('modality')
            builder = builders.get(modality)
            if builder is None:
                # bug fix: an unknown modality previously referenced an
                # undefined (or stale) `mm_data` instead of failing loudly
                raise ValueError(f'Unsupported modality: {modality}')
            input_mm_data.append(builder(input_mm))

        return PreprocessInputResult(input_ids=input_ids, input_multimodals=dict(mm_data=input_mm_data))
8847c6f2c1..a9c7317820 100644 --- a/lmdeploy/serve/processors/multimodal.py +++ b/lmdeploy/serve/processors/multimodal.py @@ -8,6 +8,7 @@ from lmdeploy.tokenizer import Tokenizer from lmdeploy.utils import get_logger from lmdeploy.vl.constants import Modality +from lmdeploy.vl.media.audio import AudioMediaIO from lmdeploy.vl.media.connection import load_from_url from lmdeploy.vl.media.image import ImageMediaIO from lmdeploy.vl.media.time_series import TimeSeriesMediaIO @@ -124,6 +125,10 @@ def _parse_multimodal_item(i: int, in_messages: list[dict], out_messages: list[d vid_io = VideoMediaIO(image_io=ImageMediaIO(), **media_io_kwargs.get('video', {})) data, metadata = load_from_url(data_src, vid_io) item_params['video_metadata'] = metadata + elif item_type == 'audio_url': + modality = Modality.AUDIO + audio_io = AudioMediaIO(**media_io_kwargs.get('audio', {})) + data = load_from_url(data_src, audio_io) elif item_type == 'time_series_url': modality = Modality.TIME_SERIES ts_io = TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {})) @@ -304,7 +309,7 @@ def _re_format_prompt_images_pair(prompt: tuple) -> dict: def _has_multimodal_input(self, messages: list[dict]) -> bool: """Check if messages contain multimodal input (images).""" - multimodal_types = ['image_url', 'image_data', 'video_url', 'time_series_url'] + multimodal_types = ['image_url', 'image_data', 'video_url', 'audio_url', 'time_series_url'] return any( isinstance(message.get('content'), list) and any( item.get('type') in multimodal_types for item in message['content']) for message in messages) diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 9d83dc06b4..8ecc8b965b 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -305,6 +305,10 @@ def _get_and_verify_max_len( for key in llm_keys: hf_config = getattr(hf_config, key, hf_config) + # for qwen3-omni thinker + if hasattr(hf_config, 'thinker_config'): + hf_config = hf_config.thinker_config.text_config + logger = get_logger('lmdeploy') 
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/media/audio.py

import base64
from io import BytesIO
from pathlib import Path

import numpy.typing as npt

from .base import MediaIO


class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Load / encode audio as a ``(samples, sample_rate)`` pair.

    Decoding uses librosa, encoding uses soundfile; both are imported
    lazily so users without audio features need not install them.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

        # lazy import to avoid dependency issues for users who don't use audio features
        try:
            import librosa
            self._librosa = librosa
        except ImportError as err:
            # bug fix: chain the original error so the real failure cause
            # (e.g. a broken native dependency) is not lost
            raise ImportError('Please install librosa via `pip install librosa`.') from err

        try:
            import soundfile
            self._soundfile = soundfile
        except ImportError as err:
            raise ImportError('Please install soundfile via `pip install soundfile`.') from err

        # for potential custom arguments from --media-io-kwargs
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode raw audio bytes.

        sr=None preserves the original sampling rate of the audio file.
        """
        return self._librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode a base64-encoded audio payload; `media_type` is unused."""
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Load an audio file from disk, keeping its native sampling rate."""
        return self._librosa.load(filepath, sr=None)

    def encode_base64(
        self,
        media: tuple[npt.NDArray, int],
        *,
        audio_format: str = 'WAV',
    ) -> str:
        """Encode ``(samples, sample_rate)`` into base64 `audio_format` bytes.

        NOTE(review): the rate is hinted ``int`` here while the loaders hint
        ``float`` — soundfile accepts either; confirm which the callers pass.
        """
        audio, sr = media

        with BytesIO() as buffer:
            self._soundfile.write(buffer, audio, sr, format=audio_format)
            data = buffer.getvalue()

        return base64.b64encode(data).decode('utf-8')
def to_pytorch_aux(self, messages, prompt, mm_placeholder, tokenizer, sequence_start):
    """Pack the preprocessing results in a format compatible with what is
    required by pytorch engine.

    Args:
        messages (list[dict]): output of `preprocess`; exactly one entry
            with role 'preprocess' carries the multimodal items.
        prompt (str): the prompt after applying the chat template.
        mm_placeholder (str): placeholder marking where multimodal tokens
            will be inserted.
        tokenizer: the tokenizer model.
        sequence_start: starting flag of a sequence.
    """
    # collect all multi-modal preprocessing result from messages, keyed by 'preprocess'
    mm_items = [x['content'] for x in messages if x['role'] == 'preprocess']
    assert len(mm_items) == 1
    mm_items = mm_items[0]

    # split prompt into segments and validate data; placeholder count is
    # len(prompt_segments) - 1, which must equal the number of items
    prompt_segments = prompt.split(mm_placeholder)
    # bug fix: the two counts in the failure message were swapped
    assert len(prompt_segments) == len(mm_items) + 1, (
        f'the number of {mm_placeholder} is not equal '
        f'to input multi modal items, {len(prompt_segments) - 1} vs {len(mm_items)}')

    # calculate the token offset for each multi modal item
    input_ids = []
    mm_placeholder_id = tokenizer.encode(mm_placeholder, add_special_tokens=False)[-1]
    for i, seg in enumerate(prompt_segments):
        if i > 0 and i <= len(mm_items):
            mm_items[i - 1].update(offset=len(input_ids))
            mm_token_num = mm_items[i - 1]['mm_token_num']
            input_ids.extend([mm_placeholder_id] * mm_token_num)
        token_ids = tokenizer.encode(seg, add_bos=((i == 0) and sequence_start))
        input_ids.extend(token_ids)

    return dict(prompt=prompt, input_ids=input_ids, multimodal=mm_items)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any

import torch
from transformers import AutoProcessor
from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import Qwen3OmniMoeProcessorKwargs
from transformers.models.whisper import WhisperFeatureExtractor

from lmdeploy.utils import get_logger
from lmdeploy.vl.constants import Modality
from lmdeploy.vl.model.base import VISION_MODELS, VisionModel

logger = get_logger('lmdeploy')


def check_transformers():
    """Raise a helpful error when installed transformers lacks Qwen3-Omni."""
    try:
        from transformers import Qwen3OmniMoeForConditionalGeneration  # noqa: F401
    except ImportError as err:
        # bug fix: chain the original error so the real cause is preserved
        raise ImportError('please install latest transformers by '
                          'pip install git+https://github.com/huggingface/transformers.git') from err


@VISION_MODELS.register_module()
class Qwen3OmniModel(VisionModel):
    """Qwen3Omni model: image / video / audio preprocessing for the thinker."""

    _arch = ['Qwen3OmniMoeForConditionalGeneration']

    def build_preprocessor(self):
        """Load the HF processor and resolve multimodal placeholder token ids."""
        check_transformers()
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        tokenizer = self.processor.tokenizer

        # image tokens
        self.image_token = self.processor.image_token
        self.image_token_id = tokenizer.encode(self.image_token)[-1]

        # video tokens
        self.video_token = self.processor.video_token
        self.video_token_id = tokenizer.encode(self.video_token)[-1]

        # audio tokens
        self.audio_token = self.processor.audio_token
        self.audio_token_id = tokenizer.encode(self.audio_token)[-1]

        # default kwargs for hf processor
        self.default_mm_processor_kwargs = Qwen3OmniMoeProcessorKwargs._defaults

    def resolve_size_params(self, default_size, mm_processor_kwargs: dict[str, Any] | None = None):
        """Merge user min/max pixel overrides with the processor defaults.

        Falls back to the defaults when the overrides are absent or invalid
        (min > max).
        """
        default_min = default_size['shortest_edge']
        default_max = default_size['longest_edge']

        if not mm_processor_kwargs:
            return {'shortest_edge': default_min, 'longest_edge': default_max}

        min_pixels = mm_processor_kwargs.get('min_pixels', default_min)
        max_pixels = mm_processor_kwargs.get('max_pixels', default_max)

        if min_pixels > max_pixels:
            logger.warning(f'min_pixels {min_pixels} > max_pixels {max_pixels}, falling back to defaults.')
            return {'shortest_edge': default_min, 'longest_edge': default_max}

        return {'shortest_edge': min_pixels, 'longest_edge': max_pixels}

    def _preprocess_image(self,
                          data: list[Any],
                          params: dict[str, Any],
                          mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
        """Run the HF image processor and attach token bookkeeping."""
        default_image_size = self.processor.image_processor.size
        size = self.resolve_size_params(default_image_size, mm_processor_kwargs)
        result = self.processor.image_processor(images=data, size=size, return_tensors='pt')
        # one output token covers merge_size**2 patches
        merge_length = self.processor.image_processor.merge_size**2
        image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
        # NOTE(review): `data.size` assumes a single PIL image despite the
        # list[Any] hint — confirm against the caller
        result.update(dict(image_size=data.size, mm_token_num=image_tokens, image_token_id=self.image_token_id))
        return result

    def _preprocess_video(self,
                          data: list[Any],
                          params: dict[str, Any],
                          mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
        """Run the HF video processor on pre-sampled frames."""
        metadata = params['video_metadata']
        if metadata.get('fps') is None or metadata['fps'] <= 0:
            logger.warning('Qwen3Omni: fps not found or invalid, fallback to 24.')
            metadata['fps'] = 24

        default_video_kwargs = self.default_mm_processor_kwargs['videos_kwargs']
        size = self.resolve_size_params(default_video_kwargs['size'], mm_processor_kwargs)

        # do_resize = True, we leave resize to hf processor
        # do_sample_frames = False, we already sample frames in video loader, avoid duplicates in hf processor
        result = self.processor.video_processor(videos=data,
                                                size=size,
                                                return_metadata=True,
                                                do_resize=True,
                                                do_sample_frames=False,
                                                video_metadata=metadata,
                                                return_tensors='pt')

        merge_length = self.processor.video_processor.merge_size**2
        video_grid_thw = result['video_grid_thw']
        # TODO: custom fps
        # NOTE(review): uses the processor's default fps rather than the
        # actual metadata fps — confirm this matches upstream mrope timing
        second_per_grid = self.processor.video_processor.temporal_patch_size / default_video_kwargs.get('fps', 1.0)
        video_tokens = video_grid_thw[0].prod() // merge_length  # T * H * W / merge_size^2

        result.update(mm_token_num=video_tokens, second_per_grid=second_per_grid, video_token_id=self.video_token_id)
        return result

    def _get_feat_extract_output_lengths(self, input_lengths):
        """Computes the output length of the convolutional layers and the
        output length of the audio encoder."""
        input_lengths_leave = input_lengths % 100
        feat_lengths = (input_lengths_leave - 1) // 2 + 1
        output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
        return output_lengths

    def _preprocess_audio(self,
                          data: list[Any],
                          params: dict[str, Any],
                          mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
        """Extract Whisper features from `(samples, sample_rate)` audio."""
        # NOTE(review): the original sampling rate is discarded — presumably
        # the feature extractor expects pre-resampled input; confirm
        audio, original_sr = data
        default_audio_kwargs = self.default_mm_processor_kwargs['audios_kwargs']
        feature_extractor = self.processor.feature_extractor
        assert isinstance(feature_extractor, WhisperFeatureExtractor), \
            'Qwen3Omni audio processor only support WhisperFeatureExtractor'

        # truncation is explicitly set to False to avoid different hf processor behavior
        # https://github.com/huggingface/transformers/pull/41473
        result = feature_extractor(audio, truncation=False, return_tensors='pt', **default_audio_kwargs)

        feature_attention_mask = result.get('attention_mask')
        audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
        audio_output_length = self._get_feat_extract_output_lengths(audio_feature_lengths)
        audio_tokens = audio_output_length

        result.update(
            dict(mm_token_num=audio_tokens,
                 audio_feature_lengths=audio_feature_lengths,
                 audio_token_id=self.audio_token_id))
        return result

    def preprocess(self, messages: list[dict], mm_processor_kwargs: dict[str, Any] | None = None) -> list[dict]:
        """Refer to `super().preprocess()` for spec."""
        outputs = []
        self.contains_video_input = False
        self.contains_audio_input = False

        mm_items = self.collect_multimodal_items(messages)
        for modality, data, params in mm_items:
            if modality == Modality.IMAGE:
                result = self._preprocess_image(data, params, mm_processor_kwargs)
            elif modality == Modality.VIDEO:
                self.contains_video_input = True
                result = self._preprocess_video(data, params, mm_processor_kwargs)
            elif modality == Modality.AUDIO:
                self.contains_audio_input = True
                result = self._preprocess_audio(data, params, mm_processor_kwargs)
            else:
                # bug fix: an unsupported modality used to append an empty
                # result and fail obscurely downstream; fail loudly instead
                raise ValueError(f'Qwen3Omni does not support modality: {modality}')

            result.update(modality=modality)
            outputs.append(result)

        messages.append(dict(role='preprocess', content=outputs))
        return messages

    def proc_messages(self, messages, chat_template, sequence_start, chat_template_kwargs=None):
        """Apply chat template to get the prompt and pick the placeholder."""
        chat_template_kwargs = chat_template_kwargs or {}
        messages = [x for x in messages if x['role'] not in ['preprocess', 'forward']]
        prompt = chat_template.messages2prompt(messages, sequence_start, **chat_template_kwargs)

        # NOTE(review): a single placeholder is chosen per request — mixed
        # modalities in one prompt would all be matched against it; confirm
        # mixed-modality requests are rejected upstream
        mm_placeholder = self.image_token
        if self.contains_video_input:
            mm_placeholder = self.video_token
        elif self.contains_audio_input:
            mm_placeholder = self.audio_token

        return prompt, mm_placeholder

    def to_pytorch(self,
                   messages,
                   chat_template,
                   tokenizer,
                   sequence_start,
                   chat_template_kwargs: dict | None = None,
                   **kwargs):
        """Return the information needed by pytorch engine."""
        prompt, mm_placeholder = self.proc_messages(messages, chat_template, sequence_start, chat_template_kwargs)
        return self.to_pytorch_aux(messages, prompt, mm_placeholder, tokenizer, sequence_start)