From cc6ea6ac2a7fa009b2b05a7340be2ac36ee86b28 Mon Sep 17 00:00:00 2001
From: Jash Shah
Date: Fri, 3 Apr 2026 09:27:10 -0700
Subject: [PATCH] fix(gemma3, gemma4): default token_type_ids to zeros for text-only training

Fixes #45200
---
 src/transformers/models/gemma3/modeling_gemma3.py | 4 +++-
 src/transformers/models/gemma3/modular_gemma3.py | 4 +++-
 src/transformers/models/gemma4/modeling_gemma4.py | 4 +++-
 src/transformers/models/gemma4/modular_gemma4.py | 4 +++-
 4 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 0dd41d6fd450..08b18abc0287 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -751,7 +751,9 @@ def create_causal_mask_mapping(
     Uses `pixel_values` as an optional input to disambiguate edge cases.
     """
     if is_training and token_type_ids is None:
-        raise ValueError("`token_type_ids` is required as a model input when training")
+        token_type_ids = torch.zeros(
+            inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
+        )
 
     mask_kwargs = {
         "config": config.get_text_config(),
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index fe8678265ead..e894485a83b8 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -624,7 +624,9 @@ def create_causal_mask_mapping(
     Uses `pixel_values` as an optional input to disambiguate edge cases.
     """
     if is_training and token_type_ids is None:
-        raise ValueError("`token_type_ids` is required as a model input when training")
+        token_type_ids = torch.zeros(
+            inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
+        )
 
     mask_kwargs = {
         "config": config.get_text_config(),
diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py
index f690c0425c8c..94c5ca94a094 100644
--- a/src/transformers/models/gemma4/modeling_gemma4.py
+++ b/src/transformers/models/gemma4/modeling_gemma4.py
@@ -2002,7 +2002,9 @@ def create_causal_mask_mapping(
     Uses `pixel_values` as an optional input to disambiguate edge cases.
     """
     if is_training and mm_token_type_ids is None:
-        raise ValueError("`mm_token_type_ids` is required as a model input when training")
+        mm_token_type_ids = torch.zeros(
+            inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
+        )
 
     mask_kwargs = {
         "config": config.get_text_config(),
diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py
index a97273802213..b4eb30529821 100644
--- a/src/transformers/models/gemma4/modular_gemma4.py
+++ b/src/transformers/models/gemma4/modular_gemma4.py
@@ -1624,7 +1624,9 @@ def create_causal_mask_mapping(
     Uses `pixel_values` as an optional input to disambiguate edge cases.
     """
     if is_training and mm_token_type_ids is None:
-        raise ValueError("`mm_token_type_ids` is required as a model input when training")
+        mm_token_type_ids = torch.zeros(
+            inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
+        )
 
     mask_kwargs = {
         "config": config.get_text_config(),