Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions lmdeploy/serve/processors/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,37 @@ def _parse_multimodal_item(i: int, in_messages: list[dict], out_messages: list[d
out_message['content'].append(item)
continue

item_params = item.get(item_type, {}).copy()
data_src = item_params.pop('url', None) or item_params.pop('data', None)
item_val = item.get(item_type, {})
if isinstance(item_val, dict):
# value is a dict containing data and other params
# msg = {'type': 'image_url', 'image_url': {'url': xxx, ...}}
# msg = {'type': 'image', 'image': {'url': xxx, ...}}
# msg = {'type': 'image_data', 'image_data': {'data': PIL.Image.Image, ...}}
item_params = item_val.copy()
data_src = item_params.pop('url', None) or item_params.pop('data', None)
else:
# value is a direct data reference
# msg = {'type': 'image_url', 'image_url': 'xxx', ...}
# msg = {'type': 'image', 'image': 'xxx', ...}
item_params = {k: v for k, v in item.items() if k not in ('type', item_type)}
data_src = item_val

if item_type == 'image_data':
modality = Modality.IMAGE
data = data_src
elif item_type == 'image_url':
elif item_type in ('image_url', 'image'):
modality = Modality.IMAGE
img_io = ImageMediaIO(**media_io_kwargs.get('image', {}))
data = load_from_url(data_src, img_io)
elif item_type == 'video_url':
data = load_from_url(data_src, ImageMediaIO(**media_io_kwargs.get('image', {})))
elif item_type in ('video_url', 'video'):
modality = Modality.VIDEO
vid_io = VideoMediaIO(image_io=ImageMediaIO(), **media_io_kwargs.get('video', {}))
data, metadata = load_from_url(data_src, vid_io)
data, metadata = load_from_url(data_src,
VideoMediaIO(image_io=ImageMediaIO(),
**media_io_kwargs.get('video', {}))
)
item_params['video_metadata'] = metadata
elif item_type == 'time_series_url':
elif item_type in ('time_series_url', 'time_series'):
modality = Modality.TIME_SERIES
ts_io = TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {}))
data = load_from_url(data_src, ts_io)
data = load_from_url(data_src, TimeSeriesMediaIO(**media_io_kwargs.get('time_series', {})))
else:
raise NotImplementedError(f'unknown type: {item_type}')

Expand Down Expand Up @@ -304,7 +316,7 @@ def _re_format_prompt_images_pair(prompt: tuple) -> dict:

def _has_multimodal_input(self, messages: list[dict]) -> bool:
"""Check if messages contain multimodal input (images)."""
multimodal_types = ['image_url', 'image_data', 'video_url', 'time_series_url']
multimodal_types = ['image_url', 'image_data', 'image', 'video_url', 'video', 'time_series_url', 'time_series']
return any(
isinstance(message.get('content'), list) and any(
item.get('type') in multimodal_types for item in message['content']) for message in messages)
Expand Down
Loading