Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 269 additions & 1 deletion src/transformers/models/gemma4/convert_gemma4_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,278 @@

# ==== Internal Constants and Classes ====


def _patch_template_for_openai_tool_role(template: str) -> str:
"""Patch a Gemma4 chat template to support OpenAI-standard ``role: "tool"`` messages.

Applies three string replacements to the upstream template:

1. Injects a ``format_tool_response_block`` macro after the ``strip_thinking`` macro
to DRY up tool-response rendering.
2. Injects a ``last_user_idx`` pre-scan and replaces the entire message loop to:
- Skip ``role: "tool"`` messages in the outer loop (they are rendered proactively).
- Forward-scan consecutive ``role: "tool"`` messages from assistant turns that
have ``tool_calls``, rendering them as ``<|tool_response>`` blocks.
- Resolve ``tool_call_id`` back to function names from the originating ``tool_calls``.
- Handle ``content`` as both plain strings and OpenAI content-parts arrays.
- Suppress duplicate ``<|turn>model`` when consecutive assistant messages are
separated only by tool messages (multi-round tool-call loops).
- Render ``reasoning`` / ``reasoning_content`` fields as ``<|channel>thought`` blocks.
3. Preserves legacy ``tool_responses`` on assistant messages (Google/Gemma native format).
"""
# --- Change 1: Inject format_tool_response_block macro after strip_thinking ---
old_after_strip = """{%- endmacro -%}\n\n{%- set ns = namespace(prev_message_type=None) -%}"""

new_after_strip = (
"""{%- endmacro -%}\n"""
"""\n"""
"""{%- macro format_tool_response_block(tool_name, response) -%}\n"""
""" {{- '<|tool_response>' -}}\n"""
""" {%- if response is mapping -%}\n"""
""" {{- 'response:' + tool_name + '{' -}}\n"""
""" {%- for key, value in response | dictsort -%}\n"""
""" {{- key -}}:{{- format_argument(value, escape_keys=False) -}}\n"""
""" {%- if not loop.last %},{% endif -%}\n"""
""" {%- endfor -%}\n"""
""" {{- '}' -}}\n"""
""" {%- else -%}\n"""
""" {{- 'response:' + tool_name + '{value:' + format_argument(response, escape_keys=False) + '}' -}}\n"""
""" {%- endif -%}\n"""
""" {{- '<tool_response|>' -}}\n"""
"""{%- endmacro -%}\n"""
"""\n"""
"""{%- set ns = namespace(prev_message_type=None) -%}"""
)
template = template.replace(old_after_strip, new_after_strip)

# --- Change 2: Replace entire message loop with OpenAI-compatible version ---
# The old message loop is identical between E4B and 31B templates.
old_message_loop = (
"""{#- Loop through messages -#}\n"""
"""{%- for message in loop_messages -%}\n"""
""" {%- set ns.prev_message_type = None -%}\n"""
""" {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}\n"""
""" {{- '<|turn>' + role + '\\n' }}\n"""
"""\n"""
""" {%- if message['tool_calls'] -%}\n"""
""" {%- for tool_call in message['tool_calls'] -%}\n"""
""" {%- set function = tool_call['function'] -%}\n"""
""" {{- '<|tool_call>call:' + function['name'] + '{' -}}\n"""
""" {%- if function['arguments'] is mapping -%}\n"""
""" {%- set ns_args = namespace(found_first=false) -%}\n"""
""" {%- for key, value in function['arguments'] | dictsort -%}\n"""
""" {%- if ns_args.found_first %},{% endif -%}\n"""
""" {%- set ns_args.found_first = true -%}\n"""
""" {{- key -}}:{{- format_argument(value, escape_keys=False) -}}\n"""
""" {%- endfor -%}\n"""
""" {%- elif function['arguments'] is string -%}\n"""
""" {{- function['arguments'] -}}\n"""
""" {%- endif -%}\n"""
""" {{- '}<tool_call|>' -}}\n"""
""" {%- endfor -%}\n"""
""" {%- set ns.prev_message_type = 'tool_call' -%}\n"""
""" {%- endif -%}\n"""
"""\n"""
""" {%- if message['tool_responses'] -%}\n"""
""" {#- Tool Response handling -#}\n"""
""" {%- for tool_response in message['tool_responses'] -%}\n"""
""" {{- '<|tool_response>' -}}\n"""
""" {%- if tool_response['response'] is mapping -%}\n"""
""" {{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}\n"""
""" {%- for key, value in tool_response['response'] | dictsort -%}\n"""
""" {{- key -}}:{{- format_argument(value, escape_keys=False) -}}\n"""
""" {%- if not loop.last %},{% endif -%}\n"""
""" {%- endfor -%}\n"""
""" {{- '}' -}}\n"""
""" {%- else -%}\n"""
""" {{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}\n"""
""" {%- endif -%}\n"""
""" {{- '<tool_response|>' -}}\n"""
""" {%- endfor -%}\n"""
""" {%- set ns.prev_message_type = 'tool_response' -%}\n"""
""" {%- endif -%}\n"""
"""\n"""
""" {%- if message['content'] is string -%}\n"""
""" {%- if role == 'model' -%}\n"""
""" {{- strip_thinking(message['content']) -}}\n"""
""" {%- else -%}\n"""
""" {{- message['content'] | trim -}}\n"""
""" {%- endif -%}\n"""
""" {%- elif message['content'] is sequence -%}\n"""
""" {%- for item in message['content'] -%}\n"""
""" {%- if item['type'] == 'text' -%}\n"""
""" {%- if role == 'model' -%}\n"""
""" {{- strip_thinking(item['text']) -}}\n"""
""" {%- else -%}\n"""
""" {{- item['text'] | trim -}}\n"""
""" {%- endif -%}\n"""
""" {%- elif item['type'] == 'image' -%}\n"""
""" {{- '\\n\\n<|image|>\\n\\n' -}}\n"""
""" {%- set ns.prev_message_type = 'image' -%}\n"""
""" {%- elif item['type'] == 'audio' -%}\n"""
""" {{- '<|audio|>' -}}\n"""
""" {%- set ns.prev_message_type = 'audio' -%}\n"""
""" {%- elif item['type'] == 'video' -%}\n"""
""" {{- '\\n\\n<|video|>\\n\\n' -}}\n"""
""" {%- set ns.prev_message_type = 'video' -%}\n"""
""" {%- endif -%}\n"""
""" {%- endfor -%}\n"""
""" {%- endif -%}\n"""
"""\n"""
""" {%- if not (message['tool_responses'] and not message['content']) -%}\n"""
""" {{- '<turn|>\\n' -}}\n"""
""" {%- endif -%}\n"""
"""{%- endfor -%}"""
)

new_message_loop = (
"""{#- Pre-scan: find last user message index for reasoning guard -#}\n"""
"""{%- set ns_turn = namespace(last_user_idx=-1) -%}\n"""
"""{%- for i in range(loop_messages | length) -%}\n"""
""" {%- if loop_messages[i]['role'] == 'user' -%}\n"""
""" {%- set ns_turn.last_user_idx = i -%}\n"""
""" {%- endif -%}\n"""
"""{%- endfor -%}\n"""
"""\n"""
"""{#- Loop through messages -#}\n"""
"""{%- for message in loop_messages -%}\n"""
""" {%- if message['role'] != 'tool' -%}\n"""
""" {%- set ns.prev_message_type = None -%}\n"""
""" {%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}\n"""
""" {#- Detect continuation: suppress duplicate <|turn>model when previous non-tool message was also assistant -#}\n"""
""" {%- set prev_nt = namespace(role=None, found=false) -%}\n"""
""" {%- if loop.index0 > 0 -%}\n"""
""" {%- for j in range(loop.index0 - 1, -1, -1) -%}\n"""
""" {%- if not prev_nt.found -%}\n"""
""" {%- if loop_messages[j]['role'] != 'tool' -%}\n"""
""" {%- set prev_nt.role = loop_messages[j]['role'] -%}\n"""
""" {%- set prev_nt.found = true -%}\n"""
""" {%- endif -%}\n"""
""" {%- endif -%}\n"""
""" {%- endfor -%}\n"""
""" {%- endif -%}\n"""
""" {%- set continue_same_model_turn = (role == 'model' and prev_nt.role == 'assistant') -%}\n"""
""" {%- if not continue_same_model_turn -%}\n"""
""" {{- '<|turn>' + role + '\\n' }}\n"""
""" {%- endif -%}\n"""
"""\n"""
""" {#- Render reasoning/reasoning_content as thinking channel -#}\n"""
""" {%- set thinking_text = message.get('reasoning') or message.get('reasoning_content') -%}\n"""
""" {%- if thinking_text and loop.index0 > ns_turn.last_user_idx and message.get('tool_calls') -%}\n"""
""" {{- '<|channel>thought\\n' + thinking_text + '\\n<channel|>' -}}\n"""
""" {%- endif -%}\n"""
"""\n"""
""" {%- if message['tool_calls'] -%}\n"""
""" {%- for tool_call in message['tool_calls'] -%}\n"""
""" {%- set function = tool_call['function'] -%}\n"""
""" {{- '<|tool_call>call:' + function['name'] + '{' -}}\n"""
""" {%- if function['arguments'] is mapping -%}\n"""
""" {%- set ns_args = namespace(found_first=false) -%}\n"""
""" {%- for key, value in function['arguments'] | dictsort -%}\n"""
""" {%- if ns_args.found_first %},{% endif -%}\n"""
""" {%- set ns_args.found_first = true -%}\n"""
""" {{- key -}}:{{- format_argument(value, escape_keys=False) -}}\n"""
""" {%- endfor -%}\n"""
""" {%- elif function['arguments'] is string -%}\n"""
""" {{- function['arguments'] -}}\n"""
""" {%- endif -%}\n"""
""" {{- '}<tool_call|>' -}}\n"""
""" {%- endfor -%}\n"""
""" {%- set ns.prev_message_type = 'tool_call' -%}\n"""
""" {%- endif -%}\n"""
"""\n"""
""" {%- set ns_tr_out = namespace(flag=false) -%}\n"""
""" {%- if message.get('tool_responses') -%}\n"""
""" {#- Legacy: tool_responses embedded on the assistant message (Google/Gemma native) -#}\n"""
""" {%- for tool_response in message['tool_responses'] -%}\n"""
""" {{- format_tool_response_block(tool_response['name'] | default('unknown'), tool_response['response']) -}}\n"""
""" {%- set ns_tr_out.flag = true -%}\n"""
""" {%- set ns.prev_message_type = 'tool_response' -%}\n"""
""" {%- endfor -%}\n"""
""" {%- elif message.get('tool_calls') -%}\n"""
""" {#- OpenAI Chat Completions: forward-scan consecutive role:tool messages -#}\n"""
""" {%- set ns_tool_scan = namespace(stopped=false) -%}\n"""
""" {%- for k in range(loop.index0 + 1, loop_messages | length) -%}\n"""
""" {%- if ns_tool_scan.stopped -%}\n"""
""" {%- elif loop_messages[k]['role'] != 'tool' -%}\n"""
""" {%- set ns_tool_scan.stopped = true -%}\n"""
""" {%- else -%}\n"""
""" {%- set follow = loop_messages[k] -%}\n"""
""" {#- Resolve tool_call_id to function name -#}\n"""
""" {%- set ns_tname = namespace(name=follow.get('name') | default('unknown')) -%}\n"""
""" {%- for tc in message['tool_calls'] -%}\n"""
""" {%- if tc.get('id') == follow.get('tool_call_id') -%}\n"""
""" {%- set ns_tname.name = tc['function']['name'] -%}\n"""
""" {%- endif -%}\n"""
""" {%- endfor -%}\n"""
""" {#- Handle content as string or content-parts array -#}\n"""
""" {%- set tool_body = follow.get('content') -%}\n"""
""" {%- if tool_body is string -%}\n"""
""" {{- format_tool_response_block(ns_tname.name, tool_body) -}}\n"""
""" {%- elif tool_body is sequence and tool_body is not string -%}\n"""
""" {%- set ns_txt = namespace(s='') -%}\n"""
""" {%- for part in tool_body -%}\n"""
""" {%- if part.get('type') == 'text' -%}\n"""
""" {%- set ns_txt.s = ns_txt.s + (part.get('text') | default('')) -%}\n"""
""" {%- endif -%}\n"""
""" {%- endfor -%}\n"""
""" {{- format_tool_response_block(ns_tname.name, ns_txt.s) -}}\n"""
""" {%- else -%}\n"""
""" {{- format_tool_response_block(ns_tname.name, tool_body) -}}\n"""
""" {%- endif -%}\n"""
""" {%- set ns_tr_out.flag = true -%}\n"""
""" {%- set ns.prev_message_type = 'tool_response' -%}\n"""
""" {%- endif -%}\n"""
""" {%- endfor -%}\n"""
""" {%- endif -%}\n"""
"""\n"""
""" {%- if message['content'] is string -%}\n"""
""" {%- if role == 'model' -%}\n"""
""" {{- strip_thinking(message['content']) -}}\n"""
""" {%- else -%}\n"""
""" {{- message['content'] | trim -}}\n"""
""" {%- endif -%}\n"""
""" {%- elif message['content'] is sequence -%}\n"""
""" {%- for item in message['content'] -%}\n"""
""" {%- if item['type'] == 'text' -%}\n"""
""" {%- if role == 'model' -%}\n"""
""" {{- strip_thinking(item['text']) -}}\n"""
""" {%- else -%}\n"""
""" {{- item['text'] | trim -}}\n"""
""" {%- endif -%}\n"""
""" {%- elif item['type'] == 'image' -%}\n"""
""" {{- '\\n\\n<|image|>\\n\\n' -}}\n"""
""" {%- set ns.prev_message_type = 'image' -%}\n"""
""" {%- elif item['type'] == 'audio' -%}\n"""
""" {{- '<|audio|>' -}}\n"""
""" {%- set ns.prev_message_type = 'audio' -%}\n"""
""" {%- elif item['type'] == 'video' -%}\n"""
""" {{- '\\n\\n<|video|>\\n\\n' -}}\n"""
""" {%- set ns.prev_message_type = 'video' -%}\n"""
""" {%- endif -%}\n"""
""" {%- endfor -%}\n"""
""" {%- endif -%}\n"""
"""\n"""
""" {%- if not (ns_tr_out.flag and not message.get('content')) -%}\n"""
""" {{- '<turn|>\\n' -}}\n"""
""" {%- endif -%}\n"""
""" {%- endif -%}\n"""
"""{%- endfor -%}"""
)
template = template.replace(old_message_loop, new_message_loop)

return template


# The correct chat templates were already uploaded to those 2 repos, so download
# from there, then patch each one to support OpenAI-standard role: "tool" messages.
def _load_patched_chat_template(repo_id):
    """Fetch ``chat_template.jinja`` from *repo_id* and apply the tool-role patch."""
    template_path = pathlib.Path(cached_file(repo_id, "chat_template.jinja"))
    return _patch_template_for_openai_tool_role(template_path.read_text())


_CHAT_TEMPLATE = _load_patched_chat_template("gg-hf-gg/gemma-4-E4B-it")
_CHAT_TEMPLATE_LARGE = _load_patched_chat_template("gg-hf-gg/gemma-4-31B-it")


_RESPONSE_SCHEMA = {
"type": "object",
"properties": {
Expand Down Expand Up @@ -1215,7 +1483,7 @@ def main(*args):
pad_token_id=config.get_text_config().pad_token_id,
bos_token_id=config.get_text_config().bos_token_id,
eos_token_id=(
tokenizer.convert_tokens_to_ids([tokenizer.eos_token, tokenizer.eot_token, tokenizer.str_token])
tokenizer.convert_tokens_to_ids([tokenizer.eos_token, tokenizer.eot_token])
if _INCLUDE_CHAT_TEMPLATE.value
else config.get_text_config().eos_token_id
),
Expand Down
Loading