diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 0dd41d6fd450..08b18abc0287 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -751,7 +751,9 @@ def create_causal_mask_mapping(
     Uses `pixel_values` as an optional input to disambiguate edge cases.
     """
     if is_training and token_type_ids is None:
-        raise ValueError("`token_type_ids` is required as a model input when training")
+        token_type_ids = torch.zeros(
+            inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
+        )
 
     mask_kwargs = {
         "config": config.get_text_config(),
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index fe8678265ead..e894485a83b8 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -624,7 +624,9 @@ def create_causal_mask_mapping(
     Uses `pixel_values` as an optional input to disambiguate edge cases.
     """
     if is_training and token_type_ids is None:
-        raise ValueError("`token_type_ids` is required as a model input when training")
+        token_type_ids = torch.zeros(
+            inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
+        )
 
     mask_kwargs = {
         "config": config.get_text_config(),
diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py
index f690c0425c8c..94c5ca94a094 100644
--- a/src/transformers/models/gemma4/modeling_gemma4.py
+++ b/src/transformers/models/gemma4/modeling_gemma4.py
@@ -2002,7 +2002,9 @@ def create_causal_mask_mapping(
     Uses `pixel_values` as an optional input to disambiguate edge cases.
     """
     if is_training and mm_token_type_ids is None:
-        raise ValueError("`mm_token_type_ids` is required as a model input when training")
+        mm_token_type_ids = torch.zeros(
+            inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
+        )
 
     mask_kwargs = {
         "config": config.get_text_config(),
diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py
index a97273802213..b4eb30529821 100644
--- a/src/transformers/models/gemma4/modular_gemma4.py
+++ b/src/transformers/models/gemma4/modular_gemma4.py
@@ -1624,7 +1624,9 @@ def create_causal_mask_mapping(
     Uses `pixel_values` as an optional input to disambiguate edge cases.
     """
     if is_training and mm_token_type_ids is None:
-        raise ValueError("`mm_token_type_ids` is required as a model input when training")
+        mm_token_type_ids = torch.zeros(
+            inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
+        )
 
     mask_kwargs = {
         "config": config.get_text_config(),