-
Notifications
You must be signed in to change notification settings - Fork 33k
Expand file tree
/
Copy pathmodular_gemma4.py
More file actions
2162 lines (1835 loc) · 95.1 KB
/
modular_gemma4.py
File metadata and controls
2162 lines (1835 loc) · 95.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
import torch
from torch import nn
from torch.nn import functional as F
from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...integrations import use_kernelized_func
from ...masking_utils import (
create_bidirectional_mask,
create_causal_mask,
create_masks_for_generate,
create_sliding_window_causal_mask,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
TransformersKwargs,
auto_docstring,
can_return_tuple,
logging,
torch_compilable_check,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ..auto.modeling_auto import AutoModel
from ..gemma3.modeling_gemma3 import (
Gemma3Attention,
Gemma3DecoderLayer,
Gemma3ForCausalLM,
Gemma3MLP,
Gemma3RotaryEmbedding,
Gemma3TextModel,
Gemma3TextScaledWordEmbedding,
)
from ..gemma3n.modeling_gemma3n import (
Gemma3nCausalLMOutputWithPast,
Gemma3nForConditionalGeneration,
Gemma3nModel,
Gemma3nModelOutputWithPast,
Gemma3nMultimodalEmbedder,
Gemma3nRMSNorm,
apply_rotary_pos_emb,
eager_attention_forward,
)
from ..llama.modeling_llama import LlamaRotaryEmbedding
from ..mixtral.modeling_mixtral import MixtralExperts
from ..moonshine_streaming.modeling_moonshine_streaming import sliding_window_mask_function
from .configuration_gemma4 import Gemma4AudioConfig, Gemma4Config, Gemma4TextConfig, Gemma4VisionConfig
# Module-level logger for this file.
logger = logging.get_logger(__name__)
# Reuses `Gemma3nModelOutputWithPast` unchanged; this subclass only provides the Gemma4 name.
class Gemma4ModelOutputWithPast(Gemma3nModelOutputWithPast):
    pass
# Reuses `Gemma3nCausalLMOutputWithPast` unchanged; this subclass only provides the Gemma4 name.
class Gemma4CausalLMOutputWithPast(Gemma3nCausalLMOutputWithPast):
    pass
# Audio-encoder output: the fields inherited from `BaseModelOutputWithPooling` plus a frame-validity mask.
# The docstring below is parsed by `auto_docstring`, so keep its format intact.
@dataclass
@auto_docstring
class Gemma4AudioModelOutput(BaseModelOutputWithPooling):
    r"""
    attention_mask (`torch.BoolTensor`, *optional*):
        A torch.BoolTensor of shape `(batch_size, num_frames)`. True for valid positions, False for padding.
    """

    # True for valid frames, False for padding (see docstring above).
    attention_mask: torch.BoolTensor | None = None
class Gemma4ClippableLinear(nn.Module):
    """Bias-free linear projection with optional clamping of its input and output.

    When `config.use_clipped_linears` is truthy, four scalar buffers are registered
    (`input_min`, `input_max`, `output_min`, `output_max`). They are initialised to -inf/+inf,
    so clamping is a no-op until real bounds are loaded into them. Otherwise the module
    behaves exactly like a bias-free `nn.Linear`.
    """

    def __init__(
        self,
        config: Gemma4VisionConfig | Gemma4AudioConfig,
        in_features: int,
        out_features: int,
    ) -> None:
        super().__init__()
        self.use_clipped_linears = config.use_clipped_linears
        self.linear = nn.Linear(in_features, out_features, bias=False)
        if not self.use_clipped_linears:
            return
        infinity = float("inf")
        # Registration order is kept stable so state_dict keys appear in the same order.
        for buffer_name, initial_bound in (
            ("input_min", -infinity),
            ("input_max", infinity),
            ("output_min", -infinity),
            ("output_max", infinity),
        ):
            self.register_buffer(buffer_name, torch.tensor(initial_bound))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        clip = self.use_clipped_linears
        projected = self.linear(
            torch.clamp(hidden_states, self.input_min, self.input_max) if clip else hidden_states
        )
        return torch.clamp(projected, self.output_min, self.output_max) if clip else projected
# Same as the Gemma3n RMSNorm; this subclass only provides the Gemma4 name.
class Gemma4RMSNorm(Gemma3nRMSNorm):
    pass
class Gemma4AudioRelPositionalEncoding(nn.Module):
    """Sinusoidal relative positional encoding for the audio encoder.

    Produces one embedding per relative offset, ordered from the largest past offset
    (`attention_context_left - 1`) down to the largest future offset (`-attention_context_right`).
    The output has shape `[1, attention_context_left + attention_context_right, 2 * (hidden_size // 2)]`,
    with a concatenated `[sin..., cos...]` layout along the last dimension that matches the original
    Gemma4 convention.
    """

    inv_timescales: torch.Tensor

    def __init__(self, config: "Gemma4AudioConfig"):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.context_size = (
            config.attention_chunk_size + config.attention_context_left - 1 + config.attention_context_right
        )
        # Same horizons as those computed in `Gemma4AudioAttention`.
        self.max_past_horizon = config.attention_context_left - 1
        self.max_future_horizon = config.attention_context_right

        min_timescale = 1.0
        max_timescale = 10000.0
        num_timescales = self.hidden_size // 2
        log_timescale_increment = math.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
        inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales) * -log_timescale_increment)
        self.register_buffer("inv_timescales", inv_timescales.unsqueeze(0).unsqueeze(0), persistent=False)

    @torch.no_grad()
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the relative position embeddings on the device and in the dtype of `hidden_states`.

        Only the device and dtype of `hidden_states` are used, not its values or shape.
        """
        # Offsets max_past_horizon, ..., 0, ..., -max_future_horizon. Previously this was hard-coded to
        # range(12, -1, -1), which is only correct for attention_context_left=13 and attention_context_right=0.
        position_ids = torch.arange(
            self.max_past_horizon, -self.max_future_horizon - 1, -1, device=hidden_states.device
        )
        position_ids = position_ids[..., None]
        scaled_time = position_ids * self.inv_timescales.to(device=hidden_states.device)
        pos_embed = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=-1)
        return pos_embed.to(dtype=hidden_states.dtype)
class Gemma4AudioAttention(nn.Module):
    """Chunked local attention with relative position bias.

    The sequence is split into non-overlapping query blocks of `attention_chunk_size` frames. Each block
    attends to a key/value window of `context_size` frames: `attention_context_left - 1` frames before the
    block, the block itself, and `attention_context_right` frames after it. The logits are the sum of:
    - a content term (query . key), and
    - a relative-position term (query . projected position embedding).
    They are soft-capped with `tanh` before masking and softmax.
    """

    def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_logits_soft_cap = config.attention_logit_cap
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        # `per_dim_scale` is zero-initialised and softplus(0) == log(2), so dividing by log(2) here makes the
        # effective initial query scale work out to head_dim**-0.5.
        self.q_scale = (self.head_dim**-0.5) / math.log(2)
        # log(1 + e) / log(2) == softplus(1) / softplus(0).
        self.k_scale = math.log(1 + math.e) / math.log(2)
        self.chunk_size = config.attention_chunk_size
        self.max_past_horizon = config.attention_context_left - 1
        self.max_future_horizon = config.attention_context_right
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon

        self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, self.num_heads * self.head_dim)
        self.post = Gemma4ClippableLinear(config, config.hidden_size, config.hidden_size)

        self.relative_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.per_dim_scale = nn.Parameter(torch.zeros(self.head_dim))
        self.register_buffer("softcap", torch.tensor(self.attention_logits_soft_cap), persistent=False)

    def _convert_to_block(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Splits a `(batch_size, seq_len, num_heads, head_dim)` tensor into non-overlapping blocks of `chunk_size` along the sequence dim."""
        batch_size, seq_len, num_heads, head_dim = hidden_states.shape
        num_blocks = (seq_len + self.chunk_size - 1) // self.chunk_size
        # Right-pad the sequence dim up to a multiple of chunk_size.
        pad = num_blocks * self.chunk_size - seq_len
        hidden_states = F.pad(hidden_states, (0, 0, 0, 0, 0, pad))
        return hidden_states.reshape(batch_size, num_blocks, self.chunk_size, num_heads, head_dim).contiguous()

    def _extract_block_context(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Extracts overlapping context windows of `context_size` for every block, strided by `chunk_size`.

        Returns a `(batch_size, num_blocks, context_size, num_heads, head_dim)` tensor.
        """
        batch_size, seq_len, num_heads, head_dim = hidden_states.shape
        # Left-pad by the past horizon. Right-pad enough that the last (partial) block still gets a full window.
        hidden_states = F.pad(
            hidden_states, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)
        )
        hidden_states = hidden_states.unfold(1, self.context_size, self.chunk_size)
        # unfold appends the window dim last; move it next to the block dim.
        hidden_states = torch.movedim(hidden_states, -1, 2)
        return hidden_states.contiguous()

    def _rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Relative position shift for blocked attention. See appendix B of https://huggingface.co/papers/1901.02860.

        Maps logits indexed by relative offset `(..., block_size, position_length)` to logits indexed by key
        position within the context window `(..., block_size, context_size)`.
        """
        batch_size, num_heads, num_blocks, block_size, position_length = x.shape
        context_size = self.context_size
        x = F.pad(x, (0, context_size + 1 - position_length))
        x = x.view(batch_size, num_heads, num_blocks, block_size * (context_size + 1))
        x = x[..., : block_size * context_size]
        return x.view(batch_size, num_heads, num_blocks, block_size, context_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.BoolTensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run chunked local attention.

        Args:
            hidden_states: `(batch_size, seq_length, hidden_size)` input frames.
            position_embeddings: relative position embeddings. They are projected by `relative_k_proj` and viewed
                as `(num_rel_positions, num_heads, head_dim)`.
            attention_mask: True where attention is allowed; False logits are replaced with
                `config.attention_invalid_logits_value`. NOTE(review): assumed broadcastable to
                `(batch_size, num_heads, num_blocks, chunk_size, context_size)` — confirm against the caller.

        Returns:
            The attention output `(batch_size, seq_length, hidden_size)` and the float32 attention probabilities
            `(batch_size, num_heads, num_blocks, chunk_size, context_size)`.
        """
        batch_size, seq_length, _ = hidden_states.shape
        hidden_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

        # Projections are computed in float32.
        query_states = self.q_proj(hidden_states).float().view(hidden_shape)
        key_states = self.k_proj(hidden_states).float().view(hidden_shape)
        value_states = self.v_proj(hidden_states).float().view(hidden_shape)

        query_states = query_states * self.q_scale * F.softplus(self.per_dim_scale)
        key_states = key_states * self.k_scale

        # query: (B, num_blocks, chunk, H, D); key/value: (B, num_blocks, context_size, H, D)
        query_states = self._convert_to_block(query_states)
        key_states = self._extract_block_context(key_states)
        value_states = self._extract_block_context(value_states)
        num_blocks = query_states.shape[1]

        # (num_rel_positions, H, D)
        relative_key_states = self.relative_k_proj(position_embeddings)
        relative_key_states = relative_key_states.view(-1, self.num_heads, self.head_dim)
        relative_key_states = relative_key_states.to(dtype=query_states.dtype)

        # Content term: (B, H, num_blocks, chunk, D) @ (B, H, num_blocks, D, context_size)
        queries = query_states.permute(0, 3, 1, 2, 4)
        matrix_ac = queries @ key_states.permute(0, 3, 1, 4, 2)

        # Position term: (B, H, num_blocks * chunk, D) @ (H, D, num_rel_positions), then shifted to key positions.
        queries_flat = queries.reshape(batch_size, self.num_heads, -1, self.head_dim)
        matrix_bd = queries_flat @ relative_key_states.permute(1, 2, 0)
        matrix_bd = matrix_bd.reshape(batch_size, self.num_heads, num_blocks, self.chunk_size, -1)
        matrix_bd = self._rel_shift(matrix_bd)

        attn_weights = matrix_ac + matrix_bd
        # Soft-cap the logits: softcap * tanh(logits / softcap).
        attn_weights = attn_weights / self.softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * self.softcap

        if attention_mask is not None:
            attn_weights = attn_weights.masked_fill(
                attention_mask.logical_not(), self.config.attention_invalid_logits_value
            )

        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        # (B, H, num_blocks, chunk, D) -> (B, num_blocks * chunk, H * D), then drop the block padding.
        attn_output = attn_weights @ value_states.permute(0, 3, 1, 2, 4)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, num_blocks * self.chunk_size, -1)
        attn_output = attn_output[:, :seq_length].contiguous()
        attn_output = self.post(attn_output.to(dtype=self.post.linear.weight.dtype))
        return attn_output, attn_weights
class Gemma4AudioSubSampleConvProjectionLayer(nn.Module):
def __init__(self, in_channels, out_channels, norm_eps):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3),
stride=(2, 2),
padding=1,
bias=False,
)
self.norm = nn.LayerNorm(out_channels, eps=norm_eps, elementwise_affine=True, bias=False)
self.act = nn.ReLU()
def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None):
if mask is not None:
mask = mask.to(device=hidden_states.device)
hidden_states = hidden_states * mask[:, None, :, None]
hidden_states = self.conv(hidden_states.to(self.conv.weight.dtype))
hidden_states = self.act(self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous())
if mask is not None:
mask = mask[:, ::2]
return hidden_states, mask
class Gemma4AudioSubSampleConvProjection(nn.Module):
    """Two stride-2 conv stages (4x down-sampling) followed by a linear projection to `hidden_size`."""

    def __init__(self, config: Gemma4AudioConfig):
        super().__init__()
        first_channels = config.subsampling_conv_channels[0]
        second_channels = config.subsampling_conv_channels[1]
        self.layer0 = Gemma4AudioSubSampleConvProjectionLayer(
            in_channels=1, out_channels=first_channels, norm_eps=config.rms_norm_eps
        )
        self.layer1 = Gemma4AudioSubSampleConvProjectionLayer(
            in_channels=first_channels, out_channels=second_channels, norm_eps=config.rms_norm_eps
        )
        # NOTE(review): `forward` flattens (remaining dim-3 size) * second_channels features. The remaining dim-3
        # size is the input's last dim after two stride-2 convs. Here that size is taken as
        # `subsampling_conv_channels[0] // 4`, which is only right when that channel count happens to match the
        # input feature size — confirm against the config defaults.
        self.input_proj_linear = nn.Linear((first_channels // 4) * second_channels, config.hidden_size, bias=False)

    def forward(
        self,
        input_features: torch.Tensor,
        input_features_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Add a singleton channel dim so the features can go through Conv2d.
        hidden_states, mask = self.layer0(input_features.unsqueeze(1), input_features_mask)
        hidden_states, mask = self.layer1(hidden_states, mask)
        batch_size, _, seq_len, _ = hidden_states.shape
        # (B, C, T, F) -> (B, T, F, C) -> (B, T, F * C)
        flattened = hidden_states.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1)
        return self.input_proj_linear(flattened), mask
class Gemma4AudioFeedForward(nn.Module):
    """Feed-forward module (4x expansion) with clamping, pre/post RMSNorm and a scaled residual."""

    def __init__(self, config: Gemma4AudioConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        expanded = width * 4
        self.ffw_layer_1 = Gemma4ClippableLinear(config, width, expanded)
        self.ffw_layer_2 = Gemma4ClippableLinear(config, expanded, width)
        self.pre_layer_norm = Gemma4RMSNorm(width)
        self.post_layer_norm = Gemma4RMSNorm(width)
        self.act_fn = ACT2FN[config.hidden_act]
        self.gradient_clipping = config.gradient_clipping
        self.post_layer_scale = config.residual_weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        # Never clamp beyond the largest finite value of the weight dtype, to avoid overflow/underflow.
        bound = min(self.gradient_clipping, torch.finfo(self.ffw_layer_1.linear.weight.dtype).max)

        out = self.pre_layer_norm(torch.clamp(hidden_states, -bound, bound))
        out = self.ffw_layer_2(self.act_fn(self.ffw_layer_1(out)))
        out = self.post_layer_norm(torch.clamp(out, -bound, bound))
        out *= self.post_layer_scale
        out += residual
        return out
# TODO: this could be imported from Voxtral realtime
class Gemma4AudioCausalConv1d(nn.Conv1d):
    """`nn.Conv1d` that left-pads its input so each output step only sees current and past inputs.

    The constructor is inherited unchanged from `nn.Conv1d`. Passing a non-zero `padding` would add
    symmetric padding on top of the causal left padding applied in `forward`, and break causality.
    """

    # TODO: a streaming padding cache (as in Voxtral realtime) could replace the static left padding.

    @cached_property
    def left_pad(self) -> int:
        """Zeros prepended along the last (time) axis: `(kernel_size - 1) * dilation + 1 - stride`."""
        effective_kernel_size = (self.kernel_size[0] - 1) * self.dilation[0] + 1
        return effective_kernel_size - self.stride[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pad only on the left (the past) so no output depends on future inputs.
        x = nn.functional.pad(x, (self.left_pad, 0))
        return super().forward(x)
class Gemma4AudioLightConv1d(nn.Module):
    """Gated (GLU) depthwise causal convolution block with a residual connection."""

    def __init__(self, config: Gemma4AudioConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.linear_start = Gemma4ClippableLinear(config, width, width * 2)
        self.linear_end = Gemma4ClippableLinear(config, width, width)
        self.depthwise_conv1d = Gemma4AudioCausalConv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=config.conv_kernel_size,
            groups=width,
            bias=False,
        )
        self.pre_layer_norm = Gemma4RMSNorm(width, eps=config.rms_norm_eps, with_scale=True)
        self.conv_norm = Gemma4RMSNorm(width, eps=config.rms_norm_eps, with_scale=True)
        self.act_fn = ACT2FN[config.hidden_act]
        self.gradient_clipping = config.gradient_clipping

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        # GLU halves the doubled projection from `linear_start` back to hidden_size.
        gated = nn.functional.glu(self.linear_start(self.pre_layer_norm(hidden_states)), dim=-1)
        # Conv1d wants channels in dim 1; presumably hidden_states is (batch, time, hidden) — swap around the conv.
        convolved = self.depthwise_conv1d(gated.transpose(1, 2)).transpose(1, 2)
        # Never clamp beyond the largest finite value of the weight dtype, to avoid overflow/underflow.
        bound = min(self.gradient_clipping, torch.finfo(self.linear_start.linear.weight.dtype).max)
        convolved = torch.clamp(convolved, -bound, bound)
        output = self.linear_end(self.act_fn(self.conv_norm(convolved)))
        output += residual
        return output
class Gemma4AudioLayer(nn.Module):
    """Audio encoder layer with five stages.

    The stages are: feed-forward, local self-attention (with residual), light convolution,
    feed-forward, and output norm.
    """

    def __init__(self, config: Gemma4AudioConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.feed_forward1 = Gemma4AudioFeedForward(config)
        self.feed_forward2 = Gemma4AudioFeedForward(config)
        self.self_attn = Gemma4AudioAttention(config, layer_idx)
        self.lconv1d = Gemma4AudioLightConv1d(config)
        self.norm_pre_attn = Gemma4RMSNorm(config.hidden_size)
        self.norm_post_attn = Gemma4RMSNorm(config.hidden_size)
        self.norm_out = Gemma4RMSNorm(config.hidden_size)
        self.gradient_clipping = config.gradient_clipping

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.BoolTensor | None,
        position_embeddings: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Never clamp beyond the largest finite value of the norm weight dtype, to avoid overflow/underflow.
        limit = min(self.gradient_clipping, torch.finfo(self.norm_pre_attn.weight.dtype).max)

        hidden_states = self.feed_forward1(hidden_states)

        # Attention sub-block: clamp -> norm -> attention -> clamp -> norm, then add the residual.
        attn_out, _ = self.self_attn(
            hidden_states=self.norm_pre_attn(torch.clamp(hidden_states, -limit, limit)),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
        )
        attn_out = self.norm_post_attn(torch.clamp(attn_out, -limit, limit))
        attn_out += hidden_states
        hidden_states = attn_out

        hidden_states = self.feed_forward2(self.lconv1d(hidden_states))
        return self.norm_out(torch.clamp(hidden_states, -limit, limit))
# ---- Vision Encoder Layers ----
class Gemma4VisionPatchEmbedder(nn.Module):
    """Projects flattened image patches to `hidden_size` and adds learned per-axis position embeddings."""

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.position_embedding_size = config.position_embedding_size
        flattened_patch_dim = 3 * self.patch_size**2
        self.input_proj = nn.Linear(flattened_patch_dim, self.hidden_size, bias=False)
        # One embedding table per spatial axis (leading dim of size 2).
        self.position_embedding_table = nn.Parameter(torch.ones(2, self.position_embedding_size, self.hidden_size))

    def _position_embeddings(self, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor) -> torch.Tensor:
        """Sum the per-axis position embeddings for each patch; padding patches get zeros."""
        # Negative ids (padding) would break one_hot, so clamp them; those rows are zeroed below anyway.
        safe_ids = pixel_position_ids.clamp(min=0)
        # (batch, patches, 2) -> one-hot (batch, patches, 2, P) -> (batch, 2, patches, P)
        selector = F.one_hot(safe_ids, num_classes=self.position_embedding_size).permute(0, 2, 1, 3)
        selector = selector.to(self.position_embedding_table)
        # Batched matmul with the (2, P, hidden) table gives (batch, 2, patches, hidden); sum over the axes.
        summed = (selector @ self.position_embedding_table).sum(dim=1)
        return torch.where(padding_positions.unsqueeze(-1), 0.0, summed)

    def forward(
        self, pixel_values: torch.Tensor, pixel_position_ids: torch.Tensor, padding_positions: torch.Tensor
    ) -> torch.Tensor:
        # Gemma4 applies no normalization in preprocessing; instead it maps x -> 2 * (x - 0.5) here.
        rescaled = 2 * (pixel_values - 0.5)
        patch_embeddings = self.input_proj(rescaled.to(self.input_proj.weight.dtype))
        return patch_embeddings + self._position_embeddings(pixel_position_ids, padding_positions)
class Gemma4VisionPooler(nn.Module):
"""Scaling and optional spatial pooling for vision encodings"""
def __init__(self, config: Gemma4VisionConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.root_hidden_size = self.hidden_size**0.5
def _avg_pool_by_positions(
self, hidden_states: torch.Tensor, pixel_position_ids: torch.Tensor, length: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
2D spatial pooling according to patch positions.
Pools the input tokens by averaging patches within a `k^2` grid, where `k` is determined by the ratio between
input and output lengths
"""
input_seq_len = hidden_states.shape[1]
k = int((input_seq_len // length) ** 0.5)
k_squared = k**2
if k_squared * length != input_seq_len:
raise ValueError(
f"Cannot pool {hidden_states.shape} to {length}: {k=}^2 times {length=} must be {input_seq_len}."
)
# Clamp padding positions (which are -1) to 0 so they don't break one_hot.
# Padding patches have zero hidden states so they contribute nothing to the average.
clamped_positions = pixel_position_ids.clamp(min=0)
max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
output = weights.transpose(1, 2) @ hidden_states.float()
mask = torch.logical_not((weights == 0).all(dim=1))
return output.to(hidden_states.dtype), mask
def forward(
self,
hidden_states: torch.Tensor,
pixel_position_ids: torch.Tensor,
padding_positions: torch.Tensor,
output_length: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if output_length > hidden_states.shape[1]:
raise ValueError(
f"Cannot output more soft tokens (requested {output_length}) than there are patches"
f" ({hidden_states.shape[1]}). Change the value of `num_soft_tokens` when processing."
)
hidden_states = hidden_states.masked_fill(padding_positions.unsqueeze(-1), 0.0)
if hidden_states.shape[1] != output_length:
hidden_states, padding_positions = self._avg_pool_by_positions(
hidden_states, pixel_position_ids, output_length
)
hidden_states *= self.root_hidden_size
return hidden_states, padding_positions
class Gemma4VisionMLP(Gemma3MLP):
    """Gemma3 MLP whose gate/up/down projections are replaced by `Gemma4ClippableLinear` layers."""

    def __init__(self, config: Gemma4VisionConfig):
        # `super().__init__` is already bound to `self`. Passing `self` again (as before) shifted `config` into
        # the wrong slot and raised a TypeError.
        super().__init__(config)
        # `hidden_size` / `intermediate_size` are read from the attributes set by the parent constructor.
        self.gate_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
        self.up_proj = Gemma4ClippableLinear(config, self.hidden_size, self.intermediate_size)
        self.down_proj = Gemma4ClippableLinear(config, self.intermediate_size, self.hidden_size)
def apply_multidimensional_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
    unsqueeze_dim: int = 2,
) -> torch.Tensor:
    """Applies multidimensional RoPE to inputs.

    Args:
        x (`torch.Tensor`): The tensor to embed.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            If `position_ids.ndim + 2 == x.ndim`, this function passes through to `apply_rotary_pos_emb()`.
            Otherwise, the size of the last dimension of `position_ids` is taken as the number of spatial
            dimensions `ndim`. `x`, `cos` and `sin` are then split into `ndim` equal chunks along their last
            dimension, each chunk is fed to `apply_rotary_pos_emb()`, and the results are concatenated back
            together. Only the shape of `position_ids` is read here, not its values.
        unsqueeze_dim (`int`, *optional*, defaults to 2):
            The dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `x`.
            For example, if `cos`/`sin` have shape `[batch_size, seq_len, head_dim]`:
            - when `x` has shape `[batch_size, seq_len, heads, head_dim]`, use `unsqueeze_dim=2`;
            - when `x` has shape `[batch_size, heads, seq_len, head_dim]`, use `unsqueeze_dim=1`.

    Returns:
        Tensor of the same shape as `x` (e.g. `[batch_size, seq_len, num_heads, head_dim]`) with RoPE applied.
        In the multidimensional case, the last dimension of `x` must be divisible by `2 * ndim`. Otherwise the
        chunk sizes do not cover the whole dimension and `torch.split` raises.

    Raises:
        ValueError: If the last dimension of `x` is smaller than `2 * ndim`.
    """
    if position_ids.ndim + 2 == x.ndim:
        # One-dimensional positions: plain RoPE over the full last dimension, as documented above.
        return apply_rotary_pos_emb(x=x, cos=cos, sin=sin, unsqueeze_dim=unsqueeze_dim)

    ndim = position_ids.shape[-1]
    num_input_channels = x.shape[-1]
    # Each spatial dimension gets an even number of channels (RoPE rotates channel pairs).
    num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim))

    if num_rotated_channels_per_dim <= 0:
        raise ValueError(
            "Invalid configuration: num_rotated_channels_per_dim must be > 0, got"
            f" {num_rotated_channels_per_dim} (num_input_channels={num_input_channels},"
            f" ndim={ndim})"
        )

    # Split the input (and cos/sin) into one chunk per spatial dimension.
    split_sizes = [num_rotated_channels_per_dim] * ndim
    x_parts = torch.split(x, split_sizes, dim=-1)
    cos_parts = torch.split(cos, split_sizes, dim=-1)
    sin_parts = torch.split(sin, split_sizes, dim=-1)
    y_parts = [
        apply_rotary_pos_emb(
            x=x_parts[k],
            cos=cos_parts[k],
            sin=sin_parts[k],
            unsqueeze_dim=unsqueeze_dim,
        )
        for k in range(ndim)
    ]
    return torch.cat(y_parts, dim=-1)
class Gemma4VisionRotaryEmbedding(LlamaRotaryEmbedding):
    """2D rotary embedding for vision patches.

    Each of the two spatial axes gets its own RoPE frequencies over half of the head dimension.
    """

    @staticmethod
    def compute_default_rope_parameters(
        config: Gemma4VisionConfig | None = None,
        device: torch.device | None = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # The reference implementation derives the frequencies independently for each spatial axis from the
        # per-axis width (head_dim // 2). Both x and y therefore get the same full frequency range, instead of
        # sharing and splitting one global inv_freq.
        per_axis_dim = head_dim // 2
        exponents = (
            torch.arange(0, per_axis_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / per_axis_dim
        )
        inv_freq = 1.0 / (base**exponents)
        return inv_freq, 1.0  # second item: attention factor, unused for this RoPE type

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        if isinstance(x.device.type, str) and x.device.type != "mps":
            device_type = x.device.type
        else:
            device_type = "cpu"

        # position_ids carries one coordinate per spatial axis in its last dim; rotate each axis separately.
        cos_chunks, sin_chunks = [], []
        for axis in range(2):
            axis_positions = position_ids[:, :, axis][:, None, :].float()
            with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
                freqs = (inv_freq_expanded.float() @ axis_positions.float()).transpose(1, 2)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos_chunks.append(emb.cos() * self.attention_scaling)
                sin_chunks.append(emb.sin() * self.attention_scaling)

        return (
            torch.cat(cos_chunks, dim=-1).to(dtype=x.dtype),
            torch.cat(sin_chunks, dim=-1).to(dtype=x.dtype),
        )
class Gemma4VisionAttention(Gemma3Attention):
    """Non-causal attention for the vision encoder.

    Uses multidimensional RoPE, clippable projections and a scale-free RMSNorm on the values.
    """

    def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
        # `super().__init__` is already bound to `self`. Passing `self` again (as before) shifted every argument
        # by one and raised a TypeError.
        super().__init__(config, layer_idx)
        # Attributes set by the parent constructor that the vision encoder does not use.
        del self.attn_logit_softcapping
        del self.sliding_window
        del self.is_sliding

        # Attention scores are not rescaled, and every patch attends to every other patch.
        self.scaling = 1.0
        self.is_causal = False
        self.k_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
        self.q_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_attention_heads * self.head_dim)
        self.v_proj = Gemma4ClippableLinear(config, config.hidden_size, config.num_key_value_heads * self.head_dim)
        self.o_proj = Gemma4ClippableLinear(config, config.num_attention_heads * self.head_dim, config.hidden_size)
        self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attend over patches.

        Args:
            hidden_states: input states; the last dim is the hidden size.
            position_embeddings: required `(cos, sin)` pair passed to `apply_multidimensional_rope`.
            attention_mask: forwarded to the selected attention implementation.
            position_ids: required; its last dimension gives the number of RoPE axes.

        Returns:
            `(attn_output, attn_weights)` as produced by the attention implementation, after `o_proj`.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        cos, sin = position_embeddings

        # Q/K: project -> per-head norm -> 2D RoPE -> heads before sequence.
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        query_states = self.q_norm(query_states)
        query_states = apply_multidimensional_rope(query_states, cos, sin, position_ids)
        query_states = query_states.transpose(1, 2)

        key_states = self.k_proj(hidden_states).view(hidden_shape)
        key_states = self.k_norm(key_states)
        key_states = apply_multidimensional_rope(key_states, cos, sin, position_ids)
        key_states = key_states.transpose(1, 2)

        # V: project -> scale-free RMSNorm (no RoPE).
        value_states = self.v_proj(hidden_states).view(hidden_shape)
        value_states = self.v_norm(value_states)
        value_states = value_states.transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
# Same forward as Gemma3DecoderLayer, but without a KV cache; returns only the hidden states.
class Gemma4VisionEncoderLayer(Gemma3DecoderLayer):
    """Transformer block of the Gemma4 vision encoder.

    It has a sandwich-norm layout: each sub-layer (attention, MLP) is normalized on input
    and output, then added back to its residual.
    """

    def __init__(self, config: Gemma4VisionConfig, layer_idx: int):
        # `super()` already binds `self`; passing it again would shift `config`/`layer_idx` by one position.
        super().__init__(config, layer_idx)
        self.self_attn = Gemma4VisionAttention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma4VisionMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers with residual connections.

        Args:
            hidden_states: Patch features entering the block.
            position_embeddings: `(cos, sin)` pair forwarded to the attention layer.
            attention_mask: Mask forwarded to the attention layer.
            position_ids: Patch coordinates forwarded to the attention layer.

        Returns:
            The updated hidden states; the attention weights are discarded.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
class Gemma4VisionEncoder(nn.Module):
    """Stack of `Gemma4VisionEncoderLayer`s sharing one vision rotary embedding and a bidirectional mask."""

    def __init__(self, config: Gemma4VisionConfig):
        super().__init__()
        self.config = config
        self.num_layers = config.num_hidden_layers
        self.rotary_emb = Gemma4VisionRotaryEmbedding(config)
        encoder_layers = [Gemma4VisionEncoderLayer(config=config, layer_idx=idx) for idx in range(self.num_layers)]
        self.layers = nn.ModuleList(encoder_layers)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        r"""
        pixel_position_ids (torch.Tensor):
            Patch positions as (x, y) coordinates in the image as [batch, num_patches, 2].
        """
        # The 2-D padding mask is expanded into a non-causal mask suited to the configured attention backend.
        bidirectional_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        hidden_states = inputs_embeds
        # Rotary tables are computed once and reused by every layer.
        position_embeddings = self.rotary_emb(hidden_states, pixel_position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer in active_layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=bidirectional_mask,
                position_embeddings=position_embeddings,
                position_ids=pixel_position_ids,
                **kwargs,
            )

        return BaseModelOutputWithPast(last_hidden_state=hidden_states)
# ---- Text model ----
class Gemma4TextMLP(Gemma3MLP):
    """Gemma3-style MLP for the Gemma4 text decoder.

    KV-shared layers get a doubled intermediate size when `config.use_double_wide_mlp` is set.
    """

    def __init__(self, config: Gemma4TextConfig, layer_idx: int):
        first_kv_shared_layer_idx = config.num_hidden_layers - config.num_kv_shared_layers
        # A layer counts as KV-shared when it sits at or after the first shared index. The `> 0` guard
        # excludes the degenerate case where every layer would be shared.
        is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer
        # The parent initializer requires the config; calling it bare raises TypeError.
        super().__init__(config)
        # NOTE(review): this override runs after the parent initializer. If Gemma3MLP sizes its
        # projections from `self.intermediate_size` during __init__, the doubling only takes effect via
        # the modular converter inlining the parent body -- verify the generated modeling file.
        self.intermediate_size = config.intermediate_size * (2 if use_double_wide_mlp else 1)
class Gemma4TextRotaryEmbedding(Gemma3RotaryEmbedding):
    """Rotary embedding holding a separate inverse-frequency table per attention layer type.

    For each distinct entry of `config.layer_types` whose RoPE parameters are not `None`, this:
    - registers non-persistent `{layer_type}_inv_freq` and `{layer_type}_original_inv_freq` buffers;
    - sets a `{layer_type}_attention_scaling` attribute;
    - records the init function in `rope_init_fns` and the RoPE type in `rope_type`.
    """

    def __init__(self, config: Gemma4TextConfig, device=None, layer_type=None):
        # Bypass Gemma3RotaryEmbedding.__init__; buffers are built per layer type below instead.
        nn.Module.__init__(self)
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.layer_types = set(config.layer_types)
        self.rope_init_fns: dict[str, Callable[..., tuple[torch.Tensor, float]]] = {}
        self.rope_type: dict[str, str] = {}

        # The `layer_type` argument is not used: every distinct layer type in the config is set up here.
        # Distinct types are visited in first-appearance order rather than set order. Set order for
        # strings depends on hash seeding, so buffers would otherwise register in a different order
        # from one process to the next.
        for current_type in dict.fromkeys(config.layer_types):
            rope_params = self.config.rope_parameters[current_type]
            if rope_params is None:
                continue

            rope_type = rope_params["rope_type"]
            if rope_type != "default":
                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
            else:
                rope_init_fn = self.compute_default_rope_parameters
            self.rope_init_fns[current_type] = rope_init_fn
            self.rope_type[current_type] = rope_type

            rope_init_fn_kwargs = {"device": device, "layer_type": current_type}
            # Full-attention layers may use `global_head_dim`; presumably this key tells the proportional
            # init which head-dim config field to read -- confirm against the init function.
            if current_type == "full_attention" and rope_type == "proportional":
                rope_init_fn_kwargs["head_dim_key"] = "global_head_dim"
            curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, **rope_init_fn_kwargs)
            self.register_buffer(f"{current_type}_inv_freq", curr_inv_freq, persistent=False)
            self.register_buffer(f"{current_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
            setattr(self, f"{current_type}_attention_scaling", curr_attention_scaling)
@use_kernelized_func(apply_rotary_pos_emb)
class Gemma4TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma4TextConfig, layer_idx: int):
    """Set up projections, norms and KV-sharing bookkeeping for decoder layer `layer_idx`.

    Args:
        config: Text model config; `config.layer_types[layer_idx]` selects sliding vs. full attention.
        layer_idx: Index of this layer in the decoder stack.
    """
    super().__init__()
    self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
    self.config = config
    self.layer_idx = layer_idx
    self.is_sliding = self.layer_type == "sliding_attention"
    self.sliding_window = config.sliding_window if self.is_sliding else None
    # Non-sliding layers use `global_head_dim` when it is set (truthy); all others use `head_dim`.
    self.head_dim = config.global_head_dim if not self.is_sliding and config.global_head_dim else config.head_dim
    # With `attention_k_eq_v`, non-sliding layers get no v_proj (see below). forward() then reuses the key
    # projection as values, and the KV head count comes from `num_global_key_value_heads`.
    self.use_alternative_attention = config.attention_k_eq_v and not self.is_sliding
    num_key_value_heads = (
        config.num_global_key_value_heads if self.use_alternative_attention else config.num_key_value_heads
    )
    self.num_key_value_groups = config.num_attention_heads // num_key_value_heads
    # Passed to the attention kernel as-is, i.e. the logits are not scaled by 1/sqrt(head_dim).
    self.scaling = 1.0
    self.attention_dropout = self.config.attention_dropout
    # Causal unless the config requests bidirectional attention for "all" layers.
    self.is_causal = config.use_bidirectional_attention != "all"

    # Shared kv cache
    # Layers at index >= first_kv_shared_layer_idx reuse K/V from an earlier layer instead of computing their own.
    # The `> 0` guard means no layer is shared when every layer would be.
    first_kv_shared_layer_idx = self.config.num_hidden_layers - getattr(self.config, "num_kv_shared_layers", 0)
    self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
    # Layer types of the non-shared prefix of the stack.
    prev_layers = config.layer_types[:first_kv_shared_layer_idx]
    # NOTE(review): `.index` below raises ValueError if `prev_layers` holds no layer of this type
    # (e.g. when num_kv_shared_layers == num_hidden_layers) -- such configs are presumably invalid; confirm.
    if self.is_kv_shared_layer:
        # For shared layers, find the last non-shared layer of the same type before sharing starts
        self.kv_shared_layer_index = len(prev_layers) - 1 - prev_layers[::-1].index(config.layer_types[layer_idx])
        self.store_full_length_kv = False
    else:
        self.kv_shared_layer_index = None
        # For non-shared layers, store full-length kv if this is the last non-shared layer of its type
        # (forward() then publishes its K/V in `past_key_values.shared_layers[layer_idx]`).
        self.store_full_length_kv = layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(
            config.layer_types[layer_idx]
        )

    # Per-head RMSNorms; the value norm has no learned scale.
    self.q_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
    self.k_norm = Gemma4RMSNorm(dim=self.head_dim, eps=config.rms_norm_eps)
    self.v_norm = Gemma4RMSNorm(self.head_dim, eps=config.rms_norm_eps, with_scale=False)

    self.k_proj = nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
    self.q_proj = nn.Linear(
        config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
    )
    # No separate value projection when keys double as values.
    self.v_proj = (
        nn.Linear(config.hidden_size, num_key_value_heads * self.head_dim, bias=config.attention_bias)
        if not self.use_alternative_attention
        else None
    )
    self.o_proj = nn.Linear(
        config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
    )
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: torch.Tensor | None,
past_key_values: Cache | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, torch.Tensor | None]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
cos, sin = position_embeddings
query_states = self.q_proj(hidden_states).view(hidden_shape)
query_states = self.q_norm(query_states)
query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2)
query_states = query_states.transpose(1, 2)
# For layers with shared KV (from kv sharing point onwards), we reuse the same keys/values states as the last non-sharing layer
if self.is_kv_shared_layer and past_key_values is not None:
key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index]
# Device of past layer may be different from current one
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states
key_states = self.k_norm(key_states)
key_states = apply_rotary_pos_emb(key_states, cos, sin, unsqueeze_dim=2)
key_states = key_states.transpose(1, 2)
value_states = self.v_norm(value_states)
value_states = value_states.transpose(1, 2)
if past_key_values is not None:
if not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
if self.store_full_length_kv:
if not hasattr(past_key_values, "shared_layers"):
past_key_values.shared_layers = {}
past_key_values.shared_layers[self.layer_idx] = key_states, value_states
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,