diff --git a/lmdeploy/serve/processors/multimodal.py b/lmdeploy/serve/processors/multimodal.py index 8847c6f2c1..f0ffda5a96 100644 --- a/lmdeploy/serve/processors/multimodal.py +++ b/lmdeploy/serve/processors/multimodal.py @@ -109,25 +109,37 @@ def _parse_multimodal_item(i: int, in_messages: list[dict], out_messages: list[d out_message['content'].append(item) continue - item_params = item.get(item_type, {}).copy() - data_src = item_params.pop('url', None) or item_params.pop('data', None) + item_val = item.get(item_type, {}) + if isinstance(item_val, dict): + # value is a dict containing data and other params + # msg = {'type': 'image_url', 'image_url': {'url': xxx, ...}} + # msg = {'type': 'image', 'image': {'url': xxx, ...}} + # msg = {'type': 'image_data', 'image_data': {'data': PIL.Image.Image, ...}} + item_params = item_val.copy() + data_src = item_params.pop('url', None) or item_params.pop('data', None) + else: + # value is a direct data reference + # msg = {'type': 'image_url', 'image_url': 'xxx', ...} + # msg = {'type': 'image', 'image': 'xxx', ...} + item_params = {k: v for k, v in item.items() if k not in ('type', item_type)} + data_src = item_val if item_type == 'image_data': modality = Modality.IMAGE data = data_src - elif item_type == 'image_url': + elif item_type in ('image_url', 'image'): modality = Modality.IMAGE - img_io = ImageMediaIO(**media_io_kwargs.get('image', {})) - data = load_from_url(data_src, img_io) - elif item_type == 'video_url': + data = load_from_url(data_src, ImageMediaIO(**media_io_kwargs.get('image', {}))) + elif item_type in ('video_url', 'video'): modality = Modality.VIDEO - vid_io = VideoMediaIO(image_io=ImageMediaIO(), **media_io_kwargs.get('video', {})) - data, metadata = load_from_url(data_src, vid_io) + data, metadata = load_from_url(data_src, + VideoMediaIO(image_io=ImageMediaIO(), + **media_io_kwargs.get('video', {})) + ) item_params['video_metadata'] = metadata - elif item_type == 'time_series_url': + elif item_type in ('time_series_url', 'time_series'): modality = Modality.TIME_SERIES - ts_io = TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {})) - data = load_from_url(data_src, ts_io) + data = load_from_url(data_src, TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {}))) else: raise NotImplementedError(f'unknown type: {item_type}') @@ -304,7 +316,7 @@ def _re_format_prompt_images_pair(prompt: tuple) -> dict: def _has_multimodal_input(self, messages: list[dict]) -> bool: """Check if messages contain multimodal input (images).""" - multimodal_types = ['image_url', 'image_data', 'video_url', 'time_series_url'] + multimodal_types = ['image_url', 'image_data', 'image', 'video_url', 'video', 'time_series_url', 'time_series'] return any( isinstance(message.get('content'), list) and any( item.get('type') in multimodal_types for item in message['content']) for message in messages)