forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcache_utils.py
More file actions
1629 lines (1365 loc) · 76.1 KB
/
cache_utils.py
File metadata and controls
1629 lines (1365 loc) · 76.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from abc import ABC, abstractmethod
from collections.abc import Iterable
import torch
from .configuration_utils import PreTrainedConfig
from .utils import (
is_hqq_available,
is_optimum_quanto_available,
is_quanto_greater,
is_torch_greater_or_equal,
is_torchdynamo_compiling,
logging,
)
# HQQ is optional: its quantizer is only imported when the package is installed (it is used by `HQQQuantizedLayer`).
if is_hqq_available():
    from hqq.core.quantize import Quantizer as HQQQuantizer
# Evaluated once at import time; dev builds of torch 2.7 also count as satisfying the requirement.
_is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
logger = logging.get_logger(__name__)
class CacheLayerMixin(ABC):
    """
    Abstract interface shared by every single-layer key/value cache.

    Subclasses allocate `keys` / `values` in `lazy_initialization` and set `is_initialized` once that happened;
    the concrete helpers defined here (offload, prefetch, reset, reorder) do nothing before that point.
    """

    is_compileable = False

    def __init__(self):
        # Backing tensors; they stay `None` until the first update triggers lazy initialization
        self.keys: torch.Tensor | None = None
        self.values: torch.Tensor | None = None
        self.is_initialized = False

    def __repr__(self):
        return f"{type(self).__name__}"

    @abstractmethod
    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: ...

    @abstractmethod
    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]: ...

    @abstractmethod
    def get_mask_sizes(self, query_length: int) -> tuple[int, int]: ...

    @abstractmethod
    def get_seq_length(self) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    def offload(self):
        """Move the cached keys and values to CPU (non-blocking copy); a no-op before initialization."""
        if not self.is_initialized:
            return
        self.keys, self.values = (tensor.to("cpu", non_blocking=True) for tensor in (self.keys, self.values))

    def prefetch(self):
        """Move offloaded keys and values back to `self.device` ahead of time, if they currently live elsewhere."""
        should_move = self.is_initialized and self.keys.device != self.device
        if should_move:
            target_device = self.device
            self.keys = self.keys.to(target_device, non_blocking=True)
            self.values = self.values.to(target_device, non_blocking=True)

    def reset(self) -> None:
        """Zero the cached tensors in place (the objects are kept) and rewind `cumulative_length` when present."""
        if self.is_initialized:
            for cached in (self.keys, self.values):
                cached.zero_()
        # Several layer types define this counter; dynamic layers use a python int, static ones a tensor
        if not hasattr(self, "cumulative_length"):
            return
        if isinstance(self.cumulative_length, int):
            self.cumulative_length = 0
        else:
            self.cumulative_length.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Select/permute the batch (beam) dimension of the cached states according to `beam_idx`."""
        if self.get_seq_length() <= 0:
            return
        self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
        self.values = self.values.index_select(0, beam_idx.to(self.values.device))
class DynamicLayer(CacheLayerMixin):
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
    """

    is_sliding = False

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Record dtype/device from `key_states` and start from empty (zero-element) key/value buffers."""
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        # Append the new states along the sequence dimension
        self.keys = torch.cat([self.keys, key_states], dim=-2)
        self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask"""
        kv_offset = 0
        kv_length = self.get_seq_length() + query_length
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        if not self.is_initialized or self.keys.numel() == 0:
            return 0
        return self.keys.shape[-2]

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return -1

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
        to remove `abs(max_length)` tokens; removing more tokens than are cached leaves the cache empty.
        """
        if max_length < 0:
            # Clamp at 0: a negative slice end would otherwise keep tokens from the *start* of the cache
            # (and would slice `None` on a layer that was never updated)
            max_length = max(self.get_seq_length() - abs(max_length), 0)

        if self.get_seq_length() <= max_length:
            return

        self.keys = self.keys[..., :max_length, :]
        self.values = self.values[..., :max_length, :]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension."""
        if self.get_seq_length() > 0:
            self.keys = self.keys.repeat_interleave(repeats, dim=0)
            self.values = self.values.repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache."""
        if self.get_seq_length() > 0:
            self.keys = self.keys[indices, ...]
            self.values = self.values[indices, ...]
class DynamicSlidingWindowLayer(DynamicLayer):
    """
    A cache layer that grows dynamically as more tokens are generated, up until the sliding window size.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
    """

    is_sliding = True

    def __init__(self, sliding_window: int):
        super().__init__()
        self.sliding_window = sliding_window
        # Total number of tokens seen so far (may exceed what is actually kept in `self.keys`)
        self.cumulative_length = 0
        self._sliding_window_tensor = torch.tensor(self.sliding_window, dtype=torch.long)

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Initialize like `DynamicLayer`, then move the window-size tensor to the states' device."""
        super().lazy_initialization(key_states, value_states)
        self._sliding_window_tensor = self._sliding_window_tensor.to(self.device)

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        self.cumulative_length += key_states.shape[-2]

        # Compute the full states
        full_key_states = torch.cat([self.keys, key_states], dim=-2)
        full_value_states = torch.cat([self.values, value_states], dim=-2)
        # Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that). The start index
        # is computed explicitly: slicing with `-(sliding_window - 1)` would keep *everything* when the window is 1
        start = max(full_key_states.shape[-2] - (self.sliding_window - 1), 0)
        self.keys = full_key_states[:, :, start:, :]
        self.values = full_value_states[:, :, start:, :]

        # Return the full states
        return full_key_states, full_value_states

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        is_full = self.cumulative_length >= self.sliding_window

        kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
        if is_full:
            kv_length = self.sliding_window - 1 + query_length
        else:
            kv_length = self.cumulative_length + query_length

        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        return self.cumulative_length

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.sliding_window

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens.

        Raises:
            ValueError: if the layer has already seen at least `sliding_window` tokens, since older states were
                discarded and cannot be restored.
        """
        if self.get_seq_length() >= self.sliding_window:
            raise ValueError(
                "Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its "
                "sliding window (otherwise some states are lost)"
            )
        if not self.is_initialized:
            # Nothing has been cached yet (`self.keys` is still None), so there is nothing to crop
            return
        super().crop(max_length)
        self.cumulative_length = self.keys.shape[-2]
class StaticLayer(CacheLayerMixin):
    """
    A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
    It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.

    Args:
        max_cache_len (`int`):
            Maximum number of tokens that can be stored, used for tensor preallocation.
    """

    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int):
        super().__init__()
        self.max_cache_len = max_cache_len
        # Very important that it's a tensor here, to avoid recompiling when we update it and use it to create positions
        self.cumulative_length = torch.tensor([0], dtype=int)

    def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """
        Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
        num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
        devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).

        If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
        function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
        internally don't compile the prefill, this is guaranteed to have been called already when compiling.
        If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
        it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
        i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
        not be compiled anyway for performances!
        """
        # Batch size, head count and head dims are all taken from the first states seen
        self.dtype, self.device = key_states.dtype, key_states.device
        self.max_batch_size, self.num_heads = key_states.shape[:2]
        self.v_head_dim = value_states.shape[-1]
        self.k_head_dim = key_states.shape[-1]

        # Preallocate the full `max_cache_len` positions up front; `update` only ever writes into these tensors
        self.keys = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.k_head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.values = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.v_head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.cumulative_length = self.cumulative_length.to(self.device)
        # Note: `mark_static_address` is used to tag the tensors as a fixed data pointer, preventing compiled graph
        # breaks or cudagraph skips due to inplace mutations when updating the cache. However, it is not supported when
        # tracing the graph, so we skip it in this case. As prefill should never be compiled, this is not an issue and it
        # will still be run (except when users compile prefill explicitly, but this should be avoided!)
        # Without this, we cannot use cudagraphs!!
        if not is_torchdynamo_compiling():
            torch._dynamo.mark_static_address(self.keys)
            torch._dynamo.mark_static_address(self.values)
            torch._dynamo.mark_static_address(self.cumulative_length)

        self.is_initialized = True

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states (the full preallocated tensors).
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        # Create a tensor to slice the static kv at the correct indices
        # (built before the in-place `add_` below, so it holds the positions right after the tokens already cached)
        kv_length = key_states.shape[-2]
        cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
        # Note that has to be performed in-place, as we have a static address that we need to keep
        self.cumulative_length.add_(kv_length)

        # Update the cache
        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # Fallback for devices like MPS where index_copy_ might not be supported.
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states
        return self.keys, self.values

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """
        Return the length and offset of the cache, used to generate the attention mask.

        The length is always `max_cache_len`, because `update` returns the whole preallocated tensors.
        """
        kv_offset = 0
        kv_length = self.max_cache_len
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """
        Returns the sequence length of the cached states.

        Note: once initialized this is the 1-element `cumulative_length` tensor rather than a python int.
        """
        return self.cumulative_length if self.is_initialized else 0

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.max_cache_len
class StaticSlidingWindowLayer(StaticLayer):
    """
    A static cache layer that stores the key and value states as static tensors of shape
    `[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. It lazily allocates its full backing
    tensors, and then mutates them in-place. Built for `torch.compile` support.

    Args:
        max_cache_len (`int`):
            Maximum number of tokens that can be stored, used for tensor preallocation.
        sliding_window (`int`):
            The size of the sliding window.
    """

    is_sliding = True

    def __init__(self, max_cache_len: int, sliding_window: int):
        # The backing tensors never need more positions than the window itself; from here on `self.max_cache_len`
        # is this effective (window-bounded) length
        effective_max_cache_len = min(sliding_window, max_cache_len)
        super().__init__(max_cache_len=effective_max_cache_len)
        # Here, to avoid data-dependent control flows, we also need to use a python int to keep track of the cumulative length
        self.cumulative_length_int = 0

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states. These are the static `self` tensors
            when nothing overflows the window, and the longer concatenated states otherwise.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)

        kv_length = key_states.shape[-2]
        current_length = self.cumulative_length_int
        is_full = current_length >= self.max_cache_len
        # Update it now that we saved the value above
        self.cumulative_length_int += kv_length

        if is_full:
            # In general, we should use a much simpler `cat` here as well, independently of the states size. However,
            # dynamo is currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
            if key_states.shape[-2] == 1:
                # Roll all values to the left by 1 position
                new_keys = self.keys.roll(-1, dims=-2)
                new_values = self.values.roll(-1, dims=-2)
                # Overwrite the last position with new states
                # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
                index = torch.tensor([-1], dtype=int, device=self.device)
                new_keys[:, :, index] = key_states
                new_values[:, :, index] = value_states

                # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
                self.keys.copy_(new_keys)
                self.values.copy_(new_values)
                # Very important to return the `self` tensors here, as they have the static dynamo address
                return self.keys, self.values
            # Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
            # The oldest cached position is dropped before appending the new states
            else:
                full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
                full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
        # Not yet full, but becoming full on this update
        elif current_length + kv_length > self.max_cache_len:
            # Fast prefill path, no need to cat() in this case, as the cache is currently empty
            if current_length == 0:
                full_key_states = key_states
                full_value_states = value_states
            else:
                full_key_states = torch.cat((self.keys[:, :, :current_length, :], key_states), dim=-2)
                full_value_states = torch.cat((self.values[:, :, :current_length, :], value_states), dim=-2)
        else:
            # Note: very important to use the tensor version of the cumulative length here, as otherwise cudagraphs
            # (triggered by mode="reduce-overhead") will lead to random crashes, as the int would be overwritten
            cache_position = torch.arange(kv_length, device=self.device) + self.cumulative_length
            try:
                self.keys.index_copy_(2, cache_position, key_states)
                self.values.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # Same fallback as `StaticLayer.update`, for devices where `index_copy_` is not implemented
                self.keys[:, :, cache_position] = key_states
                self.values[:, :, cache_position] = value_states

            # Update the tensor version of the length in-place (we don't need to update it if we are already outside
            # of this branch, as we don't need the tensor anymore)
            self.cumulative_length.add_(kv_length)
            # Very important to return the `self` tensors here, as they have the static dynamo address
            return self.keys, self.values

        # We only cache the last `sliding_window` tokens
        self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
        self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
        # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
        return full_key_states, full_value_states

    def get_mask_sizes(self, query_length: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        # `max_cache_len` was bounded by the window in `__init__`, so it plays the role of the window here
        sliding_window = self.max_cache_len
        is_full = self.cumulative_length_int >= self.max_cache_len

        kv_offset = max(self.cumulative_length_int - sliding_window + 1, 0)
        # The cache is already full
        if is_full:
            kv_length = sliding_window + query_length - 1
        # Not yet full, but becoming full on this update
        elif self.cumulative_length_int + query_length > sliding_window:
            kv_length = self.cumulative_length_int + query_length
        # Here the Cache is still smaller than the local size, but we return the local size as it's static
        else:
            kv_length = sliding_window
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states (the python-int counter of all tokens seen)."""
        return self.cumulative_length_int

    def reset(self):
        """Reset like `StaticLayer`, and also rewind the python-int length counter."""
        super().reset()
        self.cumulative_length_int = 0
class QuantizedLayer(DynamicLayer):
    """
    A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
    It allows the model to generate longer sequence length without allocating too much memory for the key and value caches by
    applying quantization.

    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length`
    is set as a maximum capacity for the original precision cache. When the length goes beyond maximum capacity, the original
    precision cache is discarded and moved into the quantized cache. The quantization is done per-channel with a set `q_group_size`
    for both Keys and Values, in contrast to what was described in the paper.

    Note: the residual (original precision) buffer is an empty, zero-element tensor right after initialization and
    after every flush into the quantized storage; the batch/beam helpers below only touch it when it holds tokens.
    """

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        super().__init__()
        self.nbits = nbits
        self.axis_key = axis_key
        self.axis_value = axis_value
        self.q_group_size = q_group_size
        self.residual_length = residual_length
        # Total number of tokens seen (quantized + residual); the residual buffer alone does not give this
        self.cumulative_length = 0

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        self.cumulative_length += key_states.shape[-2]

        # Lazy initialization: the very first states go straight into the quantized storage
        if not self.is_initialized:
            self.lazy_initialization(key_states, value_states)
            self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
            return key_states, value_states

        dequant_keys = self._dequantize(self._quantized_keys)
        dequant_values = self._dequantize(self._quantized_values)
        keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2)
        values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2)
        # Once the residual buffer reaches `residual_length`, fold everything into the quantized storage
        if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
            self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
            self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
            self.values = torch.tensor([], dtype=value_states.dtype, device=value_states.device)
        else:
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)

        return keys_to_return, values_to_return

    @abstractmethod
    def _quantize(self, tensor, axis): ...

    @abstractmethod
    def _dequantize(self, q_tensor): ...

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        return self.cumulative_length

    def _has_residual(self) -> bool:
        """Whether the original-precision residual buffer currently holds any tokens."""
        return self.is_initialized and self.keys.numel() > 0

    def _map_quantized(self, fn) -> None:
        """Dequantize the stored keys/values, apply `fn` to each, and re-quantize the results (no-op if absent)."""
        if not hasattr(self, "_quantized_keys"):
            return
        new_keys = fn(self._dequantize(self._quantized_keys))
        new_values = fn(self._dequantize(self._quantized_values))
        self._quantized_keys = self._quantize(new_keys.contiguous(), axis=self.axis_key)
        self._quantized_values = self._quantize(new_values.contiguous(), axis=self.axis_value)

    def reset(self) -> None:
        """
        Drop every cached state, residual and quantized alike, so the next `update` starts from an empty cache.

        The inherited reset only zeroes the residual buffer, which would leave the quantized history in place and
        make it re-appear (dequantized) in the next `update`.
        """
        self.keys = None
        self.values = None
        self.__dict__.pop("_quantized_keys", None)
        self.__dict__.pop("_quantized_values", None)
        self.cumulative_length = 0
        self.is_initialized = False

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders both the residual and quantized buffers for beam search."""
        # Indexing the empty residual placeholder along the batch dimension would raise, so skip it when empty
        if self._has_residual():
            self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
            self.values = self.values.index_select(0, beam_idx.to(self.values.device))
        self._map_quantized(lambda states: states.index_select(0, beam_idx.to(states.device)))

    def crop(self, max_length: int) -> None:
        """
        Crop the cache to its first `max_length` tokens; a negative value removes `abs(max_length)` tokens (removing
        more tokens than are cached empties the cache).

        If the cut falls inside the residual buffer, only that buffer is trimmed and the quantized states are left
        untouched. Otherwise the quantized states are dequantized, cropped and re-quantized, and the residual buffer
        (which lies entirely past the cut) is emptied.
        """
        seq_length = self.get_seq_length()
        if max_length < 0:
            max_length = max(seq_length - abs(max_length), 0)
        if seq_length <= max_length:
            return
        if max_length == 0:
            # Everything is removed: simply start over from an empty cache
            self.reset()
            return

        # `cumulative_length` counts quantized + residual tokens, so the quantized region is the remainder
        residual_len = self.keys.shape[-2] if self._has_residual() else 0
        quantized_len = seq_length - residual_len

        if max_length >= quantized_len:
            # Cut inside the residual buffer: trim it, keep the quantized states (no extra quantization error)
            keep = max_length - quantized_len
            self.keys = self.keys[..., :keep, :]
            self.values = self.values[..., :keep, :]
            self.cumulative_length = max_length
            return

        # Cut inside the quantized region: every residual token is past the cut, so only the quantized part matters
        full_keys = self._dequantize(self._quantized_keys)[..., :max_length, :]
        full_values = self._dequantize(self._quantized_values)[..., :max_length, :]
        self._quantized_keys = self._quantize(full_keys.contiguous(), axis=self.axis_key)
        self._quantized_values = self._quantize(full_values.contiguous(), axis=self.axis_value)
        self.keys = torch.tensor([], dtype=self.keys.dtype, device=self.keys.device)
        self.values = torch.tensor([], dtype=self.values.dtype, device=self.values.device)
        self.cumulative_length = max_length

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat both the residual and quantized buffers in the batch dimension."""
        if self._has_residual():
            self.keys = self.keys.repeat_interleave(repeats, dim=0)
            self.values = self.values.repeat_interleave(repeats, dim=0)
        self._map_quantized(lambda states: states.repeat_interleave(repeats, dim=0))

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select batch indices from both the residual and quantized buffers."""
        # Same guard as `reorder_cache`: the empty residual placeholder cannot be indexed by batch
        if self._has_residual():
            self.keys = self.keys[indices, ...]
            self.values = self.values[indices, ...]
        self._map_quantized(lambda states: states[indices, ...])
class QuantoQuantizedLayer(QuantizedLayer):
    """
    `QuantizedLayer` backed by `optimum-quanto` (2- or 4-bit, quantization axis 0 or -1 for keys and values).

    Raises:
        ImportError: if `optimum-quanto` is not installed, or is older than 0.2.5.
        ValueError: if `nbits`, `axis_key` or `axis_value` is not supported by this backend.
    """

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        super().__init__(
            nbits=nbits,
            axis_key=axis_key,
            axis_value=axis_value,
            q_group_size=q_group_size,
            residual_length=residual_length,
        )

        # We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
        if not is_optimum_quanto_available():
            raise ImportError(
                "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto "
                "backend. Please install it via `pip install optimum-quanto`"
            )
        elif is_quanto_greater("0.2.5", accept_dev=True):
            from optimum.quanto import MaxOptimizer, qint2, qint4
        else:
            raise ImportError(
                "You need optimum-quanto package version to be greater than or equal to 0.2.5 to use "
                "`QuantoQuantizedLayer`."
            )

        if self.nbits not in [2, 4]:
            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

        if self.axis_key not in [0, -1]:
            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")

        if self.axis_value not in [0, -1]:
            raise ValueError(
                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
            )

        self.qtype = qint4 if self.nbits == 4 else qint2
        self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis` with scale/zero-point from the max optimizer."""
        from optimum.quanto import quantize_weight

        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
        qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
        return qtensor

    def _dequantize(self, qtensor):
        """Return the full-precision tensor reconstructed from a quanto quantized tensor."""
        return qtensor.dequantize()
class HQQQuantizedLayer(QuantizedLayer):
    """
    `QuantizedLayer` backed by HQQ (`nbits` in {1, 2, 3, 4, 8}, quantization axis 0 or 1 for keys and values).

    Raises:
        ImportError: if `hqq` is not installed.
        ValueError: if `nbits`, `axis_key` or `axis_value` is not supported by this backend.
    """

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        super().__init__(
            nbits=nbits,
            axis_key=axis_key,
            axis_value=axis_value,
            q_group_size=q_group_size,
            residual_length=residual_length,
        )

        if not is_hqq_available():
            raise ImportError(
                "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                "Please install it via `pip install hqq`"
            )

        if self.nbits not in [1, 2, 3, 4, 8]:
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )

        if self.axis_key not in [0, 1]:
            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")

        if self.axis_value not in [0, 1]:
            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        """
        Quantize `tensor` along `axis` using the residual buffer's device and dtype.

        Returns:
            A `(quantized_tensor, meta)` tuple, as consumed by `_dequantize`.
        """
        qtensor, meta = self.quantizer.quantize(
            tensor,
            axis=axis,
            device=self.keys.device,
            compute_dtype=self.keys.dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.keys.dtype
        self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device)  # Move to device and cast to dtype
        meta["scale"] = meta["scale"].to(qtensor.device)
        meta["zero"] = meta["zero"].to(qtensor.device)
        return qtensor, meta

    def _dequantize(self, qtensor):
        """Reconstruct the full-precision tensor from a `(quantized_tensor, meta)` tuple."""
        quant_tensor, meta = qtensor
        tensor = self.quantizer.dequantize(quant_tensor, meta)
        return tensor
class LinearAttentionCacheLayerMixin(ABC):
"""Base, abstract class for a linear attention single layer's cache."""
# All shapes are static by essence in a LinearAttention layer, so it is compileable
is_compileable = True
def __init__(self):
self.conv_states: torch.Tensor | None = None
self.recurrent_states: torch.Tensor | None = None
self.is_conv_states_initialized = False
self.is_recurrent_states_initialized = False
self.has_previous_state = False
def __repr__(self):
return f"{self.__class__.__name__}"
@abstractmethod
def lazy_initialization(
self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None
) -> None: ...
@abstractmethod
def update_conv_state(self, conv_states: torch.Tensor) -> torch.Tensor: ...
@abstractmethod
def update_recurrent_state(self, recurrent_states: torch.Tensor) -> torch.Tensor: ...
def offload(self):
"""Offload this layer's data to CPU device."""
if self.is_conv_states_initialized:
self.conv_states = self.conv_states.to("cpu", non_blocking=True)
if self.is_recurrent_states_initialized:
self.recurrent_states = self.recurrent_states.to("cpu", non_blocking=True)
def prefetch(self):
"""In case of layer offloading, this allows to move the data back to the layer's device ahead of time."""
if self.is_conv_states_initialized and self.conv_states.device != self.device:
self.conv_states = self.conv_states.to(self.device, non_blocking=True)
if self.is_recurrent_states_initialized and self.recurrent_states.device != self.device:
self.recurrent_states = self.recurrent_states.to(self.device, non_blocking=True)
def reset(self) -> None:
"""Resets the cache values while preserving the objects"""
if self.is_conv_states_initialized:
self.conv_states.zero_()
if self.is_recurrent_states_initialized:
self.recurrent_states.zero_()
self.has_previous_state = False
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
if self.is_conv_states_initialized:
self.conv_states = self.conv_states.index_select(0, beam_idx.to(self.device))
# recurrent_states can stay empty sometimes, see e.g. lfm2 which only uses the conv_states
if self.is_recurrent_states_initialized:
self.recurrent_states = self.recurrent_states.index_select(0, beam_idx.to(self.device))
def crop(self, max_length: int):
# We don't crop the linear attention cache, so simply do nothing here
pass
class LinearAttentionLayer(LinearAttentionCacheLayerMixin):
    """
    Cache for a single linear attention layer, holding a conv state and a recurrent state.

    Each state is allocated lazily (and independently) on its first update, as a zero tensor with the same shape as the
    first tensor received. Afterwards it is always updated in-place, so that its address stays static (for cudagraphs).
    """

    def lazy_initialization(
        self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None
    ) -> None:
        """
        Allocate the static buffer(s) for whichever states are passed.

        Args:
            conv_states (`torch.Tensor`, *optional*):
                Template for the conv states buffer. It also sets `self.dtype`, `self.device`, `self.max_batch_size`
                and `self.conv_kernel_size` (taken from its first and last dimensions).
            recurrent_states (`torch.Tensor`, *optional*):
                Template for the recurrent states buffer. It is allocated with `self.dtype` / `self.device` when those
                are already set, otherwise with the dtype / device of this tensor (which then defines them).
        """
        # Here, we will lazy init both states separately, each in their own update function
        if conv_states is not None:
            self.dtype, self.device = conv_states.dtype, conv_states.device
            # Even if prefill is larger/shorter than the conv_size, the tensor is always either padded or truncated
            self.max_batch_size, self.conv_kernel_size = conv_states.shape[0], conv_states.shape[-1]
            # The shape is always static, so we init as such
            self.conv_states = torch.zeros_like(conv_states, dtype=self.dtype, device=self.device)
            # Mark as static address to be able to use cudagraphs
            if not is_torchdynamo_compiling():
                torch._dynamo.mark_static_address(self.conv_states)
            self.is_conv_states_initialized = True

        if recurrent_states is not None:
            # The recurrent states may be initialized before the conv states (or the layer may never use conv states),
            # in which case `self.dtype` / `self.device` do not exist yet: take them from the recurrent states
            if getattr(self, "dtype", None) is None:
                self.dtype = recurrent_states.dtype
            if getattr(self, "device", None) is None:
                self.device = recurrent_states.device
            # The shape is always static, so we init as such
            self.recurrent_states = torch.zeros_like(recurrent_states, dtype=self.dtype, device=self.device)
            # Mark as static address to be able to use cudagraphs
            if not is_torchdynamo_compiling():
                torch._dynamo.mark_static_address(self.recurrent_states)
            self.is_recurrent_states_initialized = True

    def update_conv_state(self, conv_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Update the linear attention cache in-place, and return the necessary conv states.

        Args:
            conv_states (`torch.Tensor`): The new conv states to cache.

        Returns:
            `torch.Tensor`: The updated conv states.
        """
        # Lazy initialization
        if not self.is_conv_states_initialized:
            self.lazy_initialization(conv_states=conv_states)

        if not self.has_previous_state:
            # Note that we copy instead of assigning, to preserve the static address for cudagraphs
            self.conv_states.copy_(conv_states)
            self.has_previous_state = True
        # Technically, this update is not logically correct if the prefill is smaller than `conv_kernel_size`,
        # as it will `roll` anyway in the first decoding step, even though it should `roll` ONLY if the cache is already full.
        # But since `conv_kernel_size=4` in practice, it's almost impossible to have a smaller prefill so it's mostly fine for now
        else:
            # Note that we copy instead of assigning, to preserve the static address for cudagraphs
            num_new_tokens = conv_states.shape[-1]
            if num_new_tokens >= self.conv_kernel_size:
                # The new tokens alone fill the whole window: keep only the most recent `conv_kernel_size` of them
                self.conv_states.copy_(conv_states[..., -self.conv_kernel_size :])
            else:
                # Shift the window left by `num_new_tokens` and write the new tokens at the end
                new_conv_states = self.conv_states.roll(shifts=-num_new_tokens, dims=-1)
                new_conv_states[..., -num_new_tokens:] = conv_states
                self.conv_states.copy_(new_conv_states)

        return self.conv_states

    def update_recurrent_state(self, recurrent_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Update the linear attention cache in-place, and return the necessary recurrent states.

        Args:
            recurrent_states (`torch.Tensor`): The new recurrent (ssm) states to cache.

        Returns:
            `torch.Tensor`: The updated recurrent states.
        """
        if not self.is_recurrent_states_initialized:
            self.lazy_initialization(recurrent_states=recurrent_states)

        # Note that we copy instead of assigning, to preserve the static address for cudagraphs
        self.recurrent_states.copy_(recurrent_states)
        return self.recurrent_states
class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer):
    """
    Cache for a layer combining a linear attention part (conv/recurrent states, from `LinearAttentionLayer`) and a full
    attention part (from `DynamicLayer`).

    In the MRO, `LinearAttentionCacheLayerMixin` comes before `DynamicLayer`. Any method defined by both therefore
    resolves to the linear attention version alone. The methods that must act on both parts (`reset`, `reorder_cache`,
    `offload`, `prefetch`) are overridden here to explicitly call both implementations.

    NOTE(review): `crop` is not overridden, so it resolves to the linear attention no-op and the attention part is not
    cropped either. Confirm this is intended: cropping only the attention part would leave it out of sync with the
    (non-croppable) linear states.
    """

    # The dynamic Attention part makes it non-compileable
    is_compileable = False

    def __init__(self):
        DynamicLayer.__init__(self)
        LinearAttentionLayer.__init__(self)

    def lazy_initialization(self, *args, **kwargs) -> None:
        # When the Attention cache is used with `update`, `lazy_initialization` is called with 2 positional args
        if len(args) == 2 and len(kwargs) == 0:
            DynamicLayer.lazy_initialization(self, *args)
        # Otherwise, for the LinearAttention cache, when it's called in `update_conv_state` or `update_recurrent_state`, it's
        # always called with 1 single kwarg (cause it needs to know if it's for the conv or ssm states)
        if len(args) == 0 and len(kwargs) == 1:
            LinearAttentionLayer.lazy_initialization(self, **kwargs)

    def offload(self):
        """Offload this layer's data (both the linear attention states and the attention cache) to CPU device."""
        # Without this override, only the linear attention states would be offloaded (see MRO note above)
        LinearAttentionLayer.offload(self)
        DynamicLayer.offload(self)

    def prefetch(self):
        """Move both the linear attention states and the attention cache back to the layer's device ahead of time."""
        LinearAttentionLayer.prefetch(self)
        DynamicLayer.prefetch(self)

    def reset(self) -> None:
        LinearAttentionLayer.reset(self)
        DynamicLayer.reset(self)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        LinearAttentionLayer.reorder_cache(self, beam_idx)
        DynamicLayer.reorder_cache(self, beam_idx)
class Cache:
    """
    A `Cache` is mostly a list of `CacheLayerMixin` (or `LinearAttentionCacheLayerMixin`) objects, one per model layer.
    It serves as a container for the Cache of each layer.

    Args:
        layers (`list[CacheLayerMixin | LinearAttentionCacheLayerMixin]`, *optional*):
            A list of pre-created `CacheLayerMixin` or `LinearAttentionCacheLayerMixin`. If omitted (`None`), then
            `layer_class_to_replicate` will be used. Exactly one of `layers` and `layer_class_to_replicate` must be
            provided.
        layer_class_to_replicate (`type[CacheLayerMixin | LinearAttentionCacheLayerMixin]`, *optional*):
            Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
            and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than or equal
            to the current number of layers.
        offloading (`bool`, *optional*, defaults to `False`):
            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
        offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
    """
def __init__(
    self,
    layers: list[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None,
    layer_class_to_replicate: type[CacheLayerMixin | LinearAttentionCacheLayerMixin] | None = None,
    offloading: bool = False,
    offload_only_non_sliding: bool = True,
):
    has_layers = layers is not None
    has_layer_class = layer_class_to_replicate is not None

    # Exactly one of the two construction modes must be selected
    if has_layers and has_layer_class:
        raise ValueError(
            "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
            "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
            "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
        )
    if not (has_layers or has_layer_class):
        raise ValueError(
            "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
        )

    self.layers = layers if has_layers else []
    self.layer_class_to_replicate = layer_class_to_replicate
    self.offloading = offloading

    if offloading:
        # Offloading-only attributes: which layers to offload, and the side stream used for prefetching.
        # The generic `torch.Stream` is used on torch>=2.7, a CUDA stream otherwise.
        self.only_non_sliding = offload_only_non_sliding
        self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
def __repr__(self):
    class_name = type(self).__name__
    return f"{class_name}(layers={self.layers})"
def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
    """
    Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
    which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
    Note that we use a non-default stream for this, to avoid blocking.

    Args:
        layer_idx (`int`):
            Index of the layer to prefetch (or, if `only_non_sliding`, the index to start searching from).
        only_non_sliding (`bool`, *optional*, defaults to `True`):
            If `True`, prefetch the first non-sliding layer at or after `layer_idx`, circling back to the first
            non-sliding layer if there is none. If no layer is non-sliding, nothing is prefetched: in this mode
            `offload` never offloads sliding layers, so there is nothing to bring back.
    """
    if only_non_sliding:
        is_sliding = self.is_sliding
        # Previously `is_sliding.index(False)` raised an uncaught `ValueError` when every layer is sliding
        if False not in is_sliding:
            return
        # Try to find next non-sliding, starting at `layer_idx`
        try:
            layer_idx = layer_idx + is_sliding[layer_idx:].index(False)
        # In this case, we need to circle back to the beginning
        except ValueError:
            layer_idx = is_sliding.index(False)
    else:
        layer_idx = layer_idx if layer_idx < len(self.layers) else 0

    # Prefetch
    with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
        self.layers[layer_idx].prefetch()
def offload(self, layer_idx: int, only_non_sliding: bool = True):
    """
    Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
    non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
    computation in the layer's `update` methods are finished.
    """
    # `is_sliding` is only consulted when sliding layers are meant to be skipped
    skip_layer = only_non_sliding and self.is_sliding[layer_idx]
    if skip_layer:
        return
    self.layers[layer_idx].offload()
def update(
    self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        *args, **kwargs:
            Forwarded unchanged to the `update` method of the layer.

    Return:
        A tuple containing the updated key and value states.
    """
    # Layers were not given upfront: lazily create every missing layer up to (and including) `layer_idx`
    if self.layer_class_to_replicate is not None:
        num_missing = layer_idx + 1 - len(self.layers)
        self.layers.extend(self.layer_class_to_replicate() for _ in range(num_missing))

    if self.offloading:
        # Make the default stream wait for any pending prefetch, then start prefetching the next layer
        torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
        self.prefetch(layer_idx + 1, self.only_non_sliding)

    keys, values = self.layers[layer_idx].update(key_states, value_states, *args, **kwargs)

    if self.offloading:
        # Done on the default stream, after the layer's `update` work has been queued
        self.offload(layer_idx, self.only_non_sliding)

    return keys, values
def update_conv_state(self, conv_states: torch.Tensor, layer_idx: int, **kwargs) -> torch.Tensor:
"""