diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 37d55c3dfc..7a85ce772f 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -741,6 +741,7 @@ The returned images are automatically injected into the conversation and passed Tested with: - [**Gemma4**](https://huggingface.co/collections/google/gemma-4) — e.g., `google/gemma-4-E2B-it` +- **GLM-4-MoE** ([4.5](https://huggingface.co/collections/zai-org/glm-45), [4.6](https://huggingface.co/collections/zai-org/glm-46) or [4.7](https://huggingface.co/collections/zai-org/glm-47)) — e.g., `zai-org/GLM-4.7` - [**GPT-OSS**](https://huggingface.co/collections/openai/gpt-oss) — e.g., `openai/gpt-oss-20b` - [**Qwen3**](https://huggingface.co/collections/Qwen/qwen3) — e.g., `Qwen/Qwen3-0.6B` - [**Qwen3.5**](https://huggingface.co/collections/Qwen/qwen35) — e.g., `Qwen/Qwen3.5-2B` diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 4c4193d8b9..efaae83289 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -114,6 +114,7 @@ def test_clone_with_sequence_classification_model(self): @pytest.mark.parametrize( "tokenizer_name", [ + pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), @@ -465,6 +466,7 @@ def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, t @pytest.mark.parametrize( "tokenizer_name", [ + pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"), pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), pytest.param("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification", id="qwen3"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", 
id="qwen35"),
diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py
index 441de1f176..9c6d4eba72 100644
--- a/trl/chat_template_utils.py
+++ b/trl/chat_template_utils.py
@@ -117,6 +117,41 @@ def clone_chat_template(
     return model, tokenizer, added_tokens
 
 
+glm4moe_schema = {
+    "x-regex": r"^(?:\n?<think>\n?(?:(?P<reasoning_content>.*?\S.*?)\n?|[\s]*)</think>\s*)?(?P<content>.*?)(?:\n(?=<tool_call>))?(?=(?:<tool_call>|$))(?P<tool_calls>(?:<tool_call>.+?</tool_call>\s*)+)?$",
+    "type": "object",
+    "properties": {
+        "role": {"const": "assistant"},
+        "content": {"type": "string"},
+        "reasoning_content": {"type": "string"},
+        "tool_calls": {
+            "type": "array",
+            "x-regex-iterator": r"<tool_call>\s*(.+?)\s*</tool_call>",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "type": {"const": "function"},
+                    "function": {
+                        "type": "object",
+                        "properties": {
+                            "name": {"type": "string", "x-regex": r"^(\S+)"},
+                            "arguments": {
+                                "type": "object",
+                                "x-regex-key-value": r"<arg_key>(?P<key>[^<]+)</arg_key>\s*\n<arg_value>(?P<value>.*?)</arg_value>",
+                                "default": {},
+                                "additionalProperties": {
+                                    "x-parser": "json",
+                                    "x-parser-args": {"allow_non_json": True},
+                                },
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    },
+}
+
+
 gptoss_schema = {
     # Normalize final content to analysis format so both map to the same "content" group.
"x-regex-substitutions": [
@@ -230,6 +265,8 @@ def clone_chat_template(
 }
 
 
+glm4moe_chat_template = (_CHAT_TEMPLATES_DIR / "glm4moe.jinja").read_text()
+
 gptoss_chat_template = (_CHAT_TEMPLATES_DIR / "gptoss.jinja").read_text()
 
 qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text()
@@ -268,6 +305,8 @@ def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
     {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}
     ```
     """
+    if tokenizer.chat_template == glm4moe_chat_template:
+        tokenizer.response_schema = glm4moe_schema
     if tokenizer.chat_template == gptoss_chat_template:
         tokenizer.response_schema = gptoss_schema
     return tokenizer
diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md
index 4a7ac33e8c..7b05d49783 100644
--- a/trl/chat_templates/README.md
+++ b/trl/chat_templates/README.md
@@ -11,6 +11,10 @@ Jinja2 chat templates stored here serve two purposes:
 
 Used for identity comparison only.
 
+### `glm4moe.jinja`
+
+Original GLM-4-MoE chat template.
+
 ### `gptoss.jinja`
 
 Original GPT-OSS chat template.
diff --git a/trl/chat_templates/glm4moe.jinja b/trl/chat_templates/glm4moe.jinja
new file mode 100644
index 0000000000..41478957ac
--- /dev/null
+++ b/trl/chat_templates/glm4moe.jinja
@@ -0,0 +1,103 @@
+[gMASK]<sop>
+{%- if tools -%}
+<|system|>
+# Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{% for tool in tools %}
+{{ tool | tojson(ensure_ascii=False) }}
+{% endfor %}
+</tools>
+
+For each function call, output the function name and arguments within the following XML format:
+<tool_call>{function-name}
+<arg_key>{arg-key-1}</arg_key>
+<arg_value>{arg-value-1}</arg_value>
+<arg_key>{arg-key-2}</arg_key>
+<arg_value>{arg-value-2}</arg_value>
+...
+</tool_call>
+{%- endif -%}
+{%- macro visible_text(content) -%}
+    {%- if content is string -%}
+        {{- content }}
+    {%- elif content is iterable and content is not mapping -%}
+        {%- for item in content -%}
+            {%- if item is mapping and item.type == 'text' -%}
+                {{- item.text }}
+            {%- elif item is string -%}
+                {{- item }}
+            {%- endif -%}
+        {%- endfor -%}
+    {%- else -%}
+        {{- content }}
+    {%- endif -%}
+{%- endmacro -%}
+{%- set ns = namespace(last_user_index=-1) %}
+{%- for m in messages %}
+    {%- if m.role == 'user' %}
+        {% set ns.last_user_index = loop.index0 -%}
+    {%- endif %}
+{%- endfor %}
+{% for m in messages %}
+{%- if m.role == 'user' -%}<|user|>
+{{ visible_text(m.content) }}
+{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}
+{%- elif m.role == 'assistant' -%}
+<|assistant|>
+{%- set reasoning_content = '' %}
+{%- set content = visible_text(m.content) %}
+{%- if m.reasoning_content is string %}
+    {%- set reasoning_content = m.reasoning_content %}
+{%- else %}
+    {%- if '</think>' in content %}
+        {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+        {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+    {%- endif %}
+{%- endif %}
+{%- if loop.index0 > ns.last_user_index and reasoning_content -%}
+{{ '\n<think>' + reasoning_content.strip() + '</think>'}}
+{%- else -%}
+{{ '\n<think></think>' }}
+{%- endif -%}
+{%- if content.strip() -%}
+{{ '\n' + content.strip() }}
+{%- endif -%}
+{% if m.tool_calls %}
+{% for tc in m.tool_calls %}
+{%- if tc.function %}
+    {%- set tc = tc.function %}
+{%- endif %}
+{{ '\n<tool_call>' + tc.name }}
+{% set _args = tc.arguments %}
+{% for k, v in _args.items() %}
+<arg_key>{{ k }}</arg_key>
+<arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>
+{% endfor %}
+</tool_call>{% endfor %}
+{% endif %}
+{%- elif m.role == 'tool' -%}
+{%- if m.content is string -%}
+{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+    {{- '<|observation|>' }}
+{%- endif %}
+{{- '\n<tool_response>\n' }}
+{{- m.content }}
+{{- '\n</tool_response>' }}
+{%- else -%}
+<|observation|>{% for tr in m.content %}
+
+<tool_response>
+{{ tr.output if tr.output is defined else tr }}
+</tool_response>{% endfor -%}
+{% endif -%}
+{%- elif m.role == 'system' -%}
+<|system|>
+{{ visible_text(m.content) }}
+{%- endif -%}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+    <|assistant|>{{- '\n<think></think>' if (enable_thinking is defined and not enable_thinking) else '' -}}
+{%- endif -%}
\ No newline at end of file