Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/transformers/models/gemma4/modeling_gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,11 @@ def forward(
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}

# Ensure a cache exists for KV sharing between layers, even when use_cache=False.
# This must happen after mask creation to avoid affecting causal mask computation.
if past_key_values is None:
past_key_values = DynamicCache(config=self.config)

# embed positions
hidden_states = inputs_embeds
position_embeddings = {}
Expand All @@ -1616,7 +1621,7 @@ def forward(

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
past_key_values=past_key_values if use_cache else None,
)

def get_per_layer_inputs(self, input_ids: torch.Tensor | None, inputs_embeds: torch.Tensor | None) -> torch.Tensor:
Expand Down
7 changes: 6 additions & 1 deletion src/transformers/models/gemma4/modular_gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,11 @@ def forward(
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}

# Ensure a cache exists for KV sharing between layers, even when use_cache=False.
# This must happen after mask creation to avoid affecting causal mask computation.
if past_key_values is None:
past_key_values = DynamicCache(config=self.config)

# embed positions
hidden_states = inputs_embeds
position_embeddings = {}
Expand All @@ -1385,7 +1390,7 @@ def forward(

return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
past_key_values=past_key_values if use_cache else None,
)


Expand Down
21 changes: 21 additions & 0 deletions tests/models/gemma4/test_modeling_gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ def test_model_rope_scaling_from_config(self):
def test_generate_from_random_inputs_embeds(self):
pass

def test_use_cache_false_with_kv_sharing(self):
    """Regression test: use_cache=False must produce the same logits as use_cache=True.

    Later Gemma4 layers (num_kv_shared_layers) reuse K/V written by earlier layers
    through the cache object. Previously, when use_cache=False no cache was created,
    so the sharing mechanism broke and receiver layers consumed keys as values,
    producing garbage logits.
    See https://github.com/huggingface/transformers/issues/45242
    """
    # Configure KV sharing so the regression path is actually exercised.
    cfg = self.model_tester.get_config()
    cfg.attention_k_eq_v = True
    cfg.num_global_key_value_heads = cfg.num_key_value_heads

    model = Gemma4ForCausalLM(cfg).to(torch_device).eval()
    tokens = ids_tensor([1, 16], cfg.vocab_size).to(torch_device)

    # Run the same inputs through both cache modes; no grads needed for a forward-only check.
    with torch.no_grad():
        cached_out = model(tokens, use_cache=True)
        uncached_out = model(tokens, use_cache=False)

    # Logits must agree regardless of caching, and the uncached call must not leak a cache.
    torch.testing.assert_close(cached_out.logits, uncached_out.logits, atol=1e-4, rtol=1e-4)
    self.assertIsNone(uncached_out.past_key_values, "past_key_values should be None when use_cache=False")

@unittest.skip(
"Flaky on CI, but not locally on Mac. If model is set to fp32 instead of bf16, not flaky anymore."
"TODO Cyril: investigate where the loss of precision between bf16 and fp32 comes from."
Expand Down
Loading