diff --git a/tests/tool_parsers/test_gemma4_tool_parser.py b/tests/tool_parsers/test_gemma4_tool_parser.py index 26722e68d762..0631c9732204 100644 --- a/tests/tool_parsers/test_gemma4_tool_parser.py +++ b/tests/tool_parsers/test_gemma4_tool_parser.py @@ -104,6 +104,36 @@ def test_array_of_strings(self): result = _parse_gemma4_args('items:[<|"|>a<|"|>,<|"|>b<|"|>]') assert result == {"items": ["a", "b"]} + def test_string_with_internal_quotes(self): + """String values containing " must be preserved.""" + result = _parse_gemma4_args('content:<|"|>She said "hello" loudly<|"|>') + assert result == {"content": 'She said "hello" loudly'} + + def test_string_with_braces(self): + """String values containing { and } must be preserved.""" + result = _parse_gemma4_args( + 'content:<|"|>
{test}
<|"|>' + ) + assert result == {"content": "
{test}
"} + + def test_string_with_html_attributes(self): + """HTML attributes like class="main" inside string values.""" + result = _parse_gemma4_args( + 'path:<|"|>out.html<|"|>,' + 'content:<|"|>
hello
<|"|>' + ) + assert result == { + "path": "out.html", + "content": '
hello
', + } + + def test_string_with_code_content(self): + """Code with braces and quotes in string values.""" + result = _parse_gemma4_args( + 'code:<|"|>function() { return "ok"; }<|"|>' + ) + assert result == {"code": 'function() { return "ok"; }'} + def test_unterminated_string(self): """Unterminated strings should take everything after the delimiter.""" result = _parse_gemma4_args('key:<|"|>unterminated') diff --git a/tests/tool_parsers/test_gemma4_utils.py b/tests/tool_parsers/test_gemma4_utils.py new file mode 100644 index 000000000000..1034fac2066f --- /dev/null +++ b/tests/tool_parsers/test_gemma4_utils.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for vllm.tool_parsers.gemma4_utils — offline inference tool parser.""" + +import pytest + +from vllm.tool_parsers.gemma4_utils import ( + _parse_tool_arguments, + has_tool_response_tag, + parse_tool_calls, +) + + +# --------------------------------------------------------------------------- +# _parse_tool_arguments +# --------------------------------------------------------------------------- + + +class TestParseToolArguments: + """Unit tests for the offline tool argument parser.""" + + def test_simple_string(self): + result = _parse_tool_arguments('location:<|"|>Tokyo<|"|>') + assert result["location"] == "Tokyo" + + def test_multiple_string_values(self): + result = _parse_tool_arguments( + 'location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>' + ) + assert result == {"location": "San Francisco", "unit": "celsius"} + + def test_string_with_internal_quotes(self): + """Regression: internal " must not truncate the value.""" + result = _parse_tool_arguments( + 'content:<|"|>She said "hello" loudly<|"|>' + ) + assert result["content"] == 'She said "hello" loudly' + + def test_html_with_quoted_attributes(self): + """Regression: HTML attributes like class="main" must be preserved.""" + result = _parse_tool_arguments( + 'path:<|"|>out.html<|"|>,' + 'content:<|"|>
hello
<|"|>' + ) + assert result["path"] == "out.html" + assert result["content"] == '
hello
' + + def test_string_with_braces(self): + """String values containing { and } must be preserved.""" + result = _parse_tool_arguments( + 'content:<|"|>
{test}
<|"|>' + ) + assert result["content"] == "
{test}
" + + def test_string_with_mixed_special_chars(self): + """Values with quotes, braces, and angle brackets.""" + result = _parse_tool_arguments( + 'code:<|"|>function() { return "ok"; }<|"|>' + ) + assert result["code"] == 'function() { return "ok"; }' + + def test_empty_string(self): + assert _parse_tool_arguments("") == {} + + def test_whitespace_only(self): + assert _parse_tool_arguments(" ") == {} + + def test_bare_numeric_value(self): + """Bare (non-delimited) numeric values.""" + result = _parse_tool_arguments("count:42") + # Should return the value (as string is acceptable for offline parser) + assert "count" in result + + def test_multiline_content(self): + """Multi-line string values (e.g., code blocks).""" + result = _parse_tool_arguments( + 'content:<|"|>line1\nline2\nline3<|"|>' + ) + assert "line1" in result["content"] + assert "line3" in result["content"] + + +# --------------------------------------------------------------------------- +# parse_tool_calls (end-to-end) +# --------------------------------------------------------------------------- + + +class TestParseToolCalls: + """End-to-end tests for parse_tool_calls.""" + + def test_standard_format(self): + text = '<|tool_call>call:get_weather{location:<|"|>London<|"|>}' + calls = parse_tool_calls(text) + assert len(calls) == 1 + assert calls[0]["name"] == "get_weather" + assert calls[0]["arguments"]["location"] == "London" + + def test_standard_format_with_internal_quotes(self): + """Regression: tool call with internal quotes in string value.""" + text = ( + "<|tool_call>call:write_file{" + 'path:<|"|>index.html<|"|>,' + 'content:<|"|>
Hello
<|"|>' + "}" + ) + calls = parse_tool_calls(text) + assert len(calls) == 1 + assert calls[0]["name"] == "write_file" + assert calls[0]["arguments"]["path"] == "index.html" + assert calls[0]["arguments"]["content"] == '
Hello
' + + def test_multiple_tool_calls(self): + text = ( + '<|tool_call>call:read_file{path:<|"|>a.txt<|"|>}' + '<|tool_call>call:read_file{path:<|"|>b.txt<|"|>}' + ) + calls = parse_tool_calls(text) + assert len(calls) == 2 + + def test_no_tool_calls(self): + assert parse_tool_calls("Hello, how can I help?") == [] + + def test_strict_mode_ignores_fallback(self): + text = "call:get_weather{location:Tokyo}" + assert parse_tool_calls(text, strict=True) == [] + assert len(parse_tool_calls(text, strict=False)) >= 1 + + +# --------------------------------------------------------------------------- +# has_tool_response_tag +# --------------------------------------------------------------------------- + + +class TestHasToolResponseTag: + def test_with_tag(self): + assert has_tool_response_tag("some text <|tool_response>") is True + + def test_without_tag(self): + assert has_tool_response_tag("some text ") is False + + def test_with_trailing_whitespace(self): + assert has_tool_response_tag("some text <|tool_response> ") is True diff --git a/vllm/tool_parsers/gemma4_tool_parser.py b/vllm/tool_parsers/gemma4_tool_parser.py index 406ba9e70205..f81d86fb3de9 100644 --- a/vllm/tool_parsers/gemma4_tool_parser.py +++ b/vllm/tool_parsers/gemma4_tool_parser.py @@ -39,6 +39,12 @@ from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser +from vllm.tool_parsers.gemma4_utils import ( + _ESCAPE_TOKEN as STRING_DELIM, + parse_gemma4_args as _parse_gemma4_args, + parse_gemma4_array as _parse_gemma4_array, + parse_gemma4_value as _parse_gemma4_value, +) from vllm.tool_parsers.utils import find_common_prefix logger = init_logger(__name__) @@ -46,207 +52,6 @@ # Gemma4 special tokens for tool calls TOOL_CALL_START = "<|tool_call>" TOOL_CALL_END = "" -STRING_DELIM = '<|"|>' - - -# --------------------------------------------------------------------------- -# Gemma4 argument parser (used by both streaming and non-streaming paths) -# --------------------------------------------------------------------------- - - -def _parse_gemma4_value(value_str: str) -> object: - """Parse a single Gemma4 value (after key:) into a Python object.""" - value_str = value_str.strip() - if not value_str: - return value_str - - # Boolean - if value_str == "true": - return True - if value_str == "false": - return False - - # Number (int or float) - try: - if "." in value_str: - return float(value_str) - return int(value_str) - except ValueError: - pass - - # Bare string (no <|"|> delimiters — shouldn't happen but be safe) - return value_str - - -def _parse_gemma4_args(args_str: str) -> dict: - """Parse Gemma4's custom key:value format into a Python dict. - - Format examples:: - - location:<|"|>Tokyo<|"|> - location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|> - count:42,flag:true - nested:{inner_key:<|"|>val<|"|>} - items:[<|"|>a<|"|>,<|"|>b<|"|>] - - Returns a dict ready for ``json.dumps()``. - """ - if not args_str or not args_str.strip(): - return {} - - result: dict = {} - i = 0 - n = len(args_str) - - while i < n: - # Skip whitespace and commas - while i < n and args_str[i] in (" ", ",", "\n", "\t"): - i += 1 - if i >= n: - break - - # Parse key (unquoted, ends at ':') - key_start = i - while i < n and args_str[i] != ":": - i += 1 - if i >= n: - break - key = args_str[key_start:i].strip() - i += 1 # skip ':' - - # Parse value - if i >= n: - result[key] = "" - break - - # Skip whitespace after ':' - while i < n and args_str[i] in (" ", "\n", "\t"): - i += 1 - if i >= n: - result[key] = "" - break - - # String value: <|"|>...<|"|> - if args_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) - val_start = i - end_pos = args_str.find(STRING_DELIM, i) - if end_pos == -1: - # Unterminated string — take rest - result[key] = args_str[val_start:] - break - result[key] = args_str[val_start:end_pos] - i = end_pos + len(STRING_DELIM) - - # Nested object: {...} - elif args_str[i] == "{": - depth = 1 - obj_start = i + 1 - i += 1 - while i < n and depth > 0: - if args_str[i:].startswith(STRING_DELIM): - # Skip over string contents to avoid counting { inside strings - i += len(STRING_DELIM) - next_delim = args_str.find(STRING_DELIM, i) - i = n if next_delim == -1 else next_delim + len(STRING_DELIM) - continue - if args_str[i] == "{": - depth += 1 - elif args_str[i] == "}": - depth -= 1 - i += 1 - result[key] = _parse_gemma4_args(args_str[obj_start : i - 1]) - - # Array: [...] - elif args_str[i] == "[": - depth = 1 - arr_start = i + 1 - i += 1 - while i < n and depth > 0: - if args_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) - next_delim = args_str.find(STRING_DELIM, i) - i = n if next_delim == -1 else next_delim + len(STRING_DELIM) - continue - if args_str[i] == "[": - depth += 1 - elif args_str[i] == "]": - depth -= 1 - i += 1 - arr_content = args_str[arr_start : i - 1] - result[key] = _parse_gemma4_array(arr_content) - - # Bare value (number, boolean, etc.) - else: - val_start = i - while i < n and args_str[i] not in (",", "}", "]"): - i += 1 - result[key] = _parse_gemma4_value(args_str[val_start:i]) - - return result - - -def _parse_gemma4_array(arr_str: str) -> list: - """Parse a Gemma4 array content string into a Python list.""" - items: list = [] - i = 0 - n = len(arr_str) - - while i < n: - while i < n and arr_str[i] in (" ", ",", "\n", "\t"): - i += 1 - if i >= n: - break - - # String element - if arr_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) - end_pos = arr_str.find(STRING_DELIM, i) - if end_pos == -1: - items.append(arr_str[i:]) - break - items.append(arr_str[i:end_pos]) - i = end_pos + len(STRING_DELIM) - - # Nested object - elif arr_str[i] == "{": - depth = 1 - obj_start = i + 1 - i += 1 - while i < n and depth > 0: - if arr_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) - nd = arr_str.find(STRING_DELIM, i) - i = nd + len(STRING_DELIM) if nd != -1 else n - continue - if arr_str[i] == "{": - depth += 1 - elif arr_str[i] == "}": - depth -= 1 - i += 1 - items.append(_parse_gemma4_args(arr_str[obj_start : i - 1])) - - # Nested array - elif arr_str[i] == "[": - depth = 1 - sub_start = i + 1 - i += 1 - while i < n and depth > 0: - if arr_str[i] == "[": - depth += 1 - elif arr_str[i] == "]": - depth -= 1 - i += 1 - items.append(_parse_gemma4_array(arr_str[sub_start : i - 1])) - - # Bare value - else: - val_start = i - while i < n and arr_str[i] not in (",", "]"): - i += 1 - items.append(_parse_gemma4_value(arr_str[val_start:i])) - - return items # --------------------------------------------------------------------------- diff --git a/vllm/tool_parsers/gemma4_utils.py b/vllm/tool_parsers/gemma4_utils.py index 439ad1125ce2..63b13ef7ca44 100644 --- a/vllm/tool_parsers/gemma4_utils.py +++ b/vllm/tool_parsers/gemma4_utils.py @@ -49,45 +49,237 @@ _ESCAPE_TOKEN = '<|"|>' -def _parse_tool_arguments(args_str: str) -> dict[str, str]: - """Parse tool call arguments from the Gemma4 compact format. +# --------------------------------------------------------------------------- +# Shared Gemma4 argument parser +# --------------------------------------------------------------------------- +# This parser handles the Gemma4 custom key:value format natively, +# without replacing <|"|> delimiters with quotes. It is used by both +# the offline parser (this module) and the API server parser +# (gemma4_tool_parser.py). + + +def parse_gemma4_value(value_str: str) -> object: + """Parse a single Gemma4 value (after key:) into a Python object.""" + value_str = value_str.strip() + if not value_str: + return value_str + + # Boolean + if value_str == "true": + return True + if value_str == "false": + return False + + # Number (int or float) + try: + if "." in value_str: + return float(value_str) + return int(value_str) + except ValueError: + pass - Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback - to heuristic key-value extraction. Also tolerates the slightly different - ``key: "value"`` format (space + plain quotes) that some chat templates - produce. + # Bare string (no <|"|> delimiters — shouldn't happen but be safe) + return value_str - Args: - args_str: Raw argument string from inside ``call:name{...}``. - Returns: - Dictionary of argument name → value. +def parse_gemma4_args(args_str: str) -> dict: + """Parse Gemma4's custom key:value format into a Python dict. + + Format examples:: + + location:<|"|>Tokyo<|"|> + location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|> + count:42,flag:true + nested:{inner_key:<|"|>val<|"|>} + items:[<|"|>a<|"|>,<|"|>b<|"|>] + + Returns a dict ready for ``json.dumps()``. """ if not args_str or not args_str.strip(): return {} - # Replace Gemma4 escape tokens with standard quotes. - cleaned = args_str.replace(_ESCAPE_TOKEN, '"') + result: dict = {} + i = 0 + n = len(args_str) + + while i < n: + # Skip whitespace and commas + while i < n and args_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # Parse key (unquoted, ends at ':') + key_start = i + while i < n and args_str[i] != ":": + i += 1 + if i >= n: + break + key = args_str[key_start:i].strip() + i += 1 # skip ':' + + # Parse value + if i >= n: + result[key] = "" + break + + # Skip whitespace after ':' + while i < n and args_str[i] in (" ", "\n", "\t"): + i += 1 + if i >= n: + result[key] = "" + break + + # String value: <|"|>...<|"|> + if args_str[i:].startswith(_ESCAPE_TOKEN): + i += len(_ESCAPE_TOKEN) + val_start = i + end_pos = args_str.find(_ESCAPE_TOKEN, i) + if end_pos == -1: + # Unterminated string — take rest + result[key] = args_str[val_start:] + break + result[key] = args_str[val_start:end_pos] + i = end_pos + len(_ESCAPE_TOKEN) + + # Nested object: {...} + elif args_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i:].startswith(_ESCAPE_TOKEN): + # Skip over string contents to avoid counting { inside strings + i += len(_ESCAPE_TOKEN) + next_delim = args_str.find(_ESCAPE_TOKEN, i) + i = n if next_delim == -1 else next_delim + len(_ESCAPE_TOKEN) + continue + if args_str[i] == "{": + depth += 1 + elif args_str[i] == "}": + depth -= 1 + i += 1 + result[key] = parse_gemma4_args(args_str[obj_start : i - 1]) + + # Array: [...] + elif args_str[i] == "[": + depth = 1 + arr_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i:].startswith(_ESCAPE_TOKEN): + i += len(_ESCAPE_TOKEN) + next_delim = args_str.find(_ESCAPE_TOKEN, i) + i = n if next_delim == -1 else next_delim + len(_ESCAPE_TOKEN) + continue + if args_str[i] == "[": + depth += 1 + elif args_str[i] == "]": + depth -= 1 + i += 1 + arr_content = args_str[arr_start : i - 1] + result[key] = parse_gemma4_array(arr_content) + + # Bare value (number, boolean, etc.) + else: + val_start = i + while i < n and args_str[i] not in (",", "}", "]"): + i += 1 + result[key] = parse_gemma4_value(args_str[val_start:i]) + + return result + + +def parse_gemma4_array(arr_str: str) -> list: + """Parse a Gemma4 array content string into a Python list.""" + items: list = [] + i = 0 + n = len(arr_str) + + while i < n: + while i < n and arr_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # String element + if arr_str[i:].startswith(_ESCAPE_TOKEN): + i += len(_ESCAPE_TOKEN) + end_pos = arr_str.find(_ESCAPE_TOKEN, i) + if end_pos == -1: + items.append(arr_str[i:]) + break + items.append(arr_str[i:end_pos]) + i = end_pos + len(_ESCAPE_TOKEN) + + # Nested object + elif arr_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i:].startswith(_ESCAPE_TOKEN): + i += len(_ESCAPE_TOKEN) + nd = arr_str.find(_ESCAPE_TOKEN, i) + i = nd + len(_ESCAPE_TOKEN) if nd != -1 else n + continue + if arr_str[i] == "{": + depth += 1 + elif arr_str[i] == "}": + depth -= 1 + i += 1 + items.append(parse_gemma4_args(arr_str[obj_start : i - 1])) + + # Nested array + elif arr_str[i] == "[": + depth = 1 + sub_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i:].startswith(_ESCAPE_TOKEN): + i += len(_ESCAPE_TOKEN) + nd = arr_str.find(_ESCAPE_TOKEN, i) + i = nd + len(_ESCAPE_TOKEN) if nd != -1 else n + continue + if arr_str[i] == "[": + depth += 1 + elif arr_str[i] == "]": + depth -= 1 + i += 1 + items.append(parse_gemma4_array(arr_str[sub_start : i - 1])) + + # Bare value + else: + val_start = i + while i < n and arr_str[i] not in (",", "]"): + i += 1 + items.append(parse_gemma4_value(arr_str[val_start:i])) + + return items + + +# --------------------------------------------------------------------------- +# Legacy wrapper (kept for backward compatibility) +# --------------------------------------------------------------------------- - # Try JSON parsing first (handles nested values, arrays, etc.). - try: - parsed = json.loads("{" + cleaned + "}") - # Ensure all values are strings for consistency. - return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()} - except (json.JSONDecodeError, ValueError): - pass - # Fallback: extract key:"value" pairs (allow optional space after colon). - arguments = {} - for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned): - arguments[key] = value +def _parse_tool_arguments(args_str: str) -> dict[str, str]: + """Parse tool call arguments from the Gemma4 compact format. - if not arguments: - # Last resort: extract key:value pairs (unquoted). - for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str): - arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "") + Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, including + string values containing internal quotes, braces, and other special + characters. - return arguments + Args: + args_str: Raw argument string from inside ``call:name{...}``. + + Returns: + Dictionary of argument name → value (string). + """ + parsed = parse_gemma4_args(args_str) + # Convert all values to strings for backward compatibility + # with callers that expect dict[str, str]. + return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()} def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]: