diff --git a/src/transformers/models/gemma4/modeling_gemma4.py b/src/transformers/models/gemma4/modeling_gemma4.py
index f690c0425c8c..9d5b3ce1ea77 100644
--- a/src/transformers/models/gemma4/modeling_gemma4.py
+++ b/src/transformers/models/gemma4/modeling_gemma4.py
@@ -1592,6 +1592,11 @@ def forward(
             "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
         }
 
+        # Ensure a cache exists for KV sharing between layers, even when use_cache=False.
+        # This must happen after mask creation to avoid affecting causal mask computation.
+        if past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
         # embed positions
         hidden_states = inputs_embeds
         position_embeddings = {}
@@ -1616,7 +1621,7 @@
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values if use_cache else None,
         )
 
     def get_per_layer_inputs(self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None) -> torch.Tensor:
diff --git a/src/transformers/models/gemma4/modular_gemma4.py b/src/transformers/models/gemma4/modular_gemma4.py
index a97273802213..085ad8dad2b1 100644
--- a/src/transformers/models/gemma4/modular_gemma4.py
+++ b/src/transformers/models/gemma4/modular_gemma4.py
@@ -1361,6 +1361,11 @@ def forward(
             "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
         }
 
+        # Ensure a cache exists for KV sharing between layers, even when use_cache=False.
+        # This must happen after mask creation to avoid affecting causal mask computation.
+        if past_key_values is None:
+            past_key_values = DynamicCache(config=self.config)
+
         # embed positions
         hidden_states = inputs_embeds
         position_embeddings = {}
@@ -1385,7 +1390,7 @@
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values if use_cache else None,
         )
 
 
diff --git a/tests/models/gemma4/test_modeling_gemma4.py b/tests/models/gemma4/test_modeling_gemma4.py
index c63e9ba20165..42f941b14418 100644
--- a/tests/models/gemma4/test_modeling_gemma4.py
+++ b/tests/models/gemma4/test_modeling_gemma4.py
@@ -114,6 +114,27 @@ def test_model_rope_scaling_from_config(self):
     def test_generate_from_random_inputs_embeds(self):
         pass
 
+    def test_use_cache_false_with_kv_sharing(self):
+        """Regression test: use_cache=False must produce the same logits as use_cache=True.
+
+        Gemma4 uses KV sharing (num_kv_shared_layers) where later layers reuse K/V from earlier
+        layers via the cache object. When use_cache=False the cache was not created, breaking the
+        sharing mechanism and causing receiver layers to use keys as values (garbage logits).
+        See https://github.com/huggingface/transformers/issues/45242
+        """
+        config = self.model_tester.get_config()
+        config.attention_k_eq_v = True
+        config.num_global_key_value_heads = config.num_key_value_heads
+        model = Gemma4ForCausalLM(config).to(torch_device).eval()
+        input_ids = ids_tensor([1, 16], config.vocab_size).to(torch_device)
+
+        with torch.no_grad():
+            out_cached = model(input_ids, use_cache=True)
+            out_uncached = model(input_ids, use_cache=False)
+
+        torch.testing.assert_close(out_cached.logits, out_uncached.logits, atol=1e-4, rtol=1e-4)
+        self.assertIsNone(out_uncached.past_key_values, "past_key_values should be None when use_cache=False")
+
     @unittest.skip(
         "Flaky on CI, but not locally on Mac. If model is set to fp32 instead of bf16, not flaky anymore."
         "TODO Cyril: investigate where the loss of precision between bf16 and fp32 comes from."
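For context on the failure mode the test guards against, here is a minimal toy sketch of cross-layer KV sharing: receiver layers read K/V that donor layers wrote into the cache, so skipping cache creation when use_cache=False leaves receivers with nothing to read. ToyCache, forward_layers, shared_from, and the arithmetic stand-ins are hypothetical illustrations, not transformers APIs; only the dependency structure mirrors the fix above.

import torch

class ToyCache:
    """Hypothetical stand-in for DynamicCache: records each donor layer's K/V."""

    def __init__(self):
        self.keys, self.values = {}, {}

    def update(self, key, value, layer_idx):
        self.keys[layer_idx], self.values[layer_idx] = key, value
        return key, value

def forward_layers(x, num_layers, shared_from, cache):
    # shared_from maps a receiver layer index to the donor layer whose K/V it reuses.
    for idx in range(num_layers):
        if idx in shared_from:
            # Receiver layer: reads the K/V its donor wrote into the cache.
            # If no cache was ever created (the old use_cache=False path),
            # this handoff breaks and the layer cannot run as trained.
            k, v = cache.keys[shared_from[idx]], cache.values[shared_from[idx]]
        else:
            # Donor layer: computes K/V (stand-ins for k_proj/v_proj) and caches them.
            k, v = x * 2.0, x * 3.0
            cache.update(k, v, idx)
        x = x + k + v  # stand-in for attention
    return x

x = torch.ones(1, 4)
# Layer 2 reuses layer 0's K/V; this only works because a cache object exists.
out = forward_layers(x, num_layers=3, shared_from={2: 0}, cache=ToyCache())

This is why the diff creates a DynamicCache unconditionally but only returns it in past_key_values when use_cache is True: the object is needed internally for the donor-to-receiver handoff, while callers who disabled caching still get None back.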