-
Notifications
You must be signed in to change notification settings - Fork 699
Expand file tree
/
Copy pathdebug_quantization.py
More file actions
710 lines (615 loc) · 27.9 KB
/
debug_quantization.py
File metadata and controls
710 lines (615 loc) · 27.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains DebugQuantizer and DebugQuantizedTensor objects,
which are wrappers over Quantizer and QuantizedTensor.
These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union, List
import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.quantized_tensor import (
QuantizedTensor,
Quantizer,
QuantizedTensorStorage,
prepare_for_saving,
restore_from_saved,
)
from transformer_engine.debug.pytorch.debug_state import TEDebugState
# Names of the possible "plans" a DebugQuantizer can follow for a tensor:
# call the modify_tensor() API, run the standard (parent) quantizer,
# or keep the tensor in high precision.
API_CALL_MODIFY = "modify_tensor()"
STANDARD_QUANTIZE = "Quantize"
HIGH_PRECISION = "High Precision"

# Maps a tensor name to the pair [rowwise gemm name, columnwise gemm name].
# The second entry is None for tensors that are used by a single gemm only.
_tensor_to_gemm_names_map = {
    "weight": ["fprop", "dgrad"],
    "activation": ["fprop", "wgrad"],
    "output": ["fprop", None],
    "gradient": ["dgrad", "wgrad"],
    "wgrad": ["wgrad", None],
    "dgrad": ["dgrad", None],
}

# Shortcut to the ATen operator namespace.
aten = torch.ops.aten
class DebugQuantizer(Quantizer):
    """
    DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect.

    It allows adding custom calls inside the quantization process - which enables
    modifying tensors or gathering tensor stats. The optional ``parent_quantizer``
    is the quantizer that performs the standard quantization.
    """

    def __init__(
        self,
        layer_name: str,
        tensor_name: str,
        parent_quantizer: Optional[Quantizer],
        tp_group: torch.distributed.ProcessGroup,
        tp_size: int,
    ):
        """
        Create the debug quantizer and query the debug API for its plans.

        Parameters
        ----------
        layer_name : str
            Name of the layer, forwarded to every debug API call.
        tensor_name : str
            One of the keys of ``_tensor_to_gemm_names_map``.
        parent_quantizer : Optional[Quantizer]
            Quantizer used for standard quantization, or None.
        tp_group : torch.distributed.ProcessGroup
            Forwarded to the inspect_tensor API calls.
        tp_size : int
            Forwarded to the inspect_tensor API calls.
        """
        super().__init__(rowwise=True, columnwise=True)
        self.layer_name = layer_name
        self.tensor_name = tensor_name
        self.parent_quantizer = parent_quantizer
        self.tp_group = tp_group  # used in inspect_tensor calls
        self.tp_size = tp_size
        self.iteration = TEDebugState.get_iteration()

        # Configure parent quantizer
        if parent_quantizer is not None:
            # .internal = True is slightly faster, but results
            # in errors when caching the weights.
            # Setting .internal = False is safer.
            parent_quantizer.internal = False

            # .optimize_for_gemm = True is not supported because debug
            # quantizers perform non-GEMM operations.
            parent_quantizer.optimize_for_gemm = False

        self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]

        # Next iteration when this quantizer will call any API.
        # It is None at init and it is computed after the *_enabled API calls.
        # None at the beginning means that if nothing will be done,
        # this quantizer will never call any API.
        self.next_debug_iter = None

        # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
        # rowwise_tensor_plan, and columnwise_tensor_plan are computed.
        # These fields indicate the path where API calls will be inserted.
        #
        # inspect_tensor*_enabled are bool fields,
        # indicating whether some feature will need to run inspect_tensor_* calls.
        #
        # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION]
        # determining what will happen when the quantizer is used for that tensor.
        self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
        if self.output_tensor:
            self.inspect_tensor_enabled, self.rowwise_tensor_plan = (
                self.get_plans_for_output_tensors()
            )
        else:
            (
                self.inspect_tensor_enabled,
                self.inspect_tensor_postquantize_enabled_rowwise,
                self.inspect_tensor_postquantize_enabled_columnwise,
            ) = self.get_enabled_look_at_tensors()
            self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan()

            # Logged only here: output tensors do not define columnwise_tensor_plan,
            # which log_messages_about_plans() reads.
            self.log_messages_about_plans()

    def get_plans_for_output_tensors(self) -> Tuple[bool, str]:
        """
        Returns tuple (inspect_tensor_enabled: bool, plan: str). Plan is one of the
        API_CALL_MODIFY or HIGH_PRECISION, because debug quantizer does not support
        gemm output in FP8.
        """
        import nvdlfw_inspect.api as debug_api

        inspect_tensor_enabled = self.process_enabled_api_call(
            debug_api.transformer_engine.inspect_tensor_enabled(
                layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
            )
        )
        modify_enabled = self.process_enabled_api_call(
            debug_api.transformer_engine.modify_tensor_enabled(
                layer_name=self.layer_name,
                gemm=self.rowwise_gemm_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
            )
        )
        plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION
        return inspect_tensor_enabled, plan

    def get_enabled_look_at_tensors(self) -> Tuple[bool, bool, bool]:
        """
        Returns a tuple of booleans determining which inspect_tensor*(...) calls
        should be made: (inspect_tensor, postquantize rowwise, postquantize columnwise).
        """
        import nvdlfw_inspect.api as debug_api

        inspect_tensor_enabled = self.process_enabled_api_call(
            debug_api.transformer_engine.inspect_tensor_enabled(
                layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
            )
        )
        inspect_tensor_postquantize_enabled_rowwise = self.process_enabled_api_call(
            debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
                gemm=self.rowwise_gemm_name,
            )
        )
        inspect_tensor_postquantize_enabled_columnwise = self.process_enabled_api_call(
            debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
                gemm=self.columnwise_gemm_name,
            )
        )
        return (
            inspect_tensor_enabled,
            inspect_tensor_postquantize_enabled_rowwise,
            inspect_tensor_postquantize_enabled_columnwise,
        )

    def _get_plan_for_gemm(self, gemm_name: Optional[str]) -> str:
        """
        Returns the plan (API_CALL_MODIFY, STANDARD_QUANTIZE or HIGH_PRECISION)
        for the usage of this tensor in the gemm ``gemm_name``.
        """
        import nvdlfw_inspect.api as debug_api

        if gemm_name is None:
            # The tensor is not used by such a gemm; no API is queried.
            return HIGH_PRECISION

        modify_enabled = self.process_enabled_api_call(
            debug_api.transformer_engine.modify_tensor_enabled(
                layer_name=self.layer_name,
                gemm=gemm_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
            )
        )
        if modify_enabled:
            return API_CALL_MODIFY

        # Standard quantization is only possible with a parent quantizer.
        if self.parent_quantizer is not None:
            quantize_enabled = self.process_enabled_api_call(
                debug_api.transformer_engine.fp8_gemm_enabled(  # API name kept for compatibility
                    layer_name=self.layer_name,
                    gemm=gemm_name,
                    iteration=self.iteration,
                )
            )
            if quantize_enabled:
                return STANDARD_QUANTIZE
        return HIGH_PRECISION

    def get_tensors_plan(self) -> Tuple[str, str]:
        """
        Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
        API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior
        of this quantizer with respect to these tensors.
        """
        # The rowwise plan is resolved first so the API calls keep their order.
        rowwise_plan = self._get_plan_for_gemm(self.rowwise_gemm_name)
        columnwise_plan = self._get_plan_for_gemm(self.columnwise_gemm_name)
        return rowwise_plan, columnwise_plan

    def log_messages_about_plans(self):
        """
        Logs the messages about the plans for each of the tensors.
        """
        import nvdlfw_inspect.api as debug_api

        debug_api.log_message(
            f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -"
            f" {self.rowwise_tensor_plan}",
            layer_name=self.layer_name,
            extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name),
        )
        debug_api.log_message(
            f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -"
            f" {self.columnwise_tensor_plan}",
            layer_name=self.layer_name,
            extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name),
        )

    def _call_inspect_tensor_api(
        self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None
    ):
        """
        Calls inspect_tensor() and inspect_tensor_postquantize() if they are enabled.

        ``tensor`` is the tensor passed to inspect_tensor(); ``rowwise_gemm_tensor`` and
        ``columnwise_gemm_tensor`` are the tensors prepared for the respective gemms.
        For output tensors only inspect_tensor() is considered.
        """
        import nvdlfw_inspect.api as debug_api

        args = {
            "layer_name": self.layer_name,
            "tensor": tensor,
            "tensor_name": self.tensor_name,
            "iteration": TEDebugState.get_iteration(),
            "tp_group": self.tp_group,
            "tp_size": self.tp_size,
            "columnwise_quantized_tensor": columnwise_gemm_tensor,
            "rowwise_quantized_tensor": rowwise_gemm_tensor,
            "quantizer": self.parent_quantizer,
        }
        if self.inspect_tensor_enabled:
            debug_api.transformer_engine.inspect_tensor(**args)

        if self.output_tensor:
            return

        # These keys are passed to inspect_tensor() only.
        del args["columnwise_quantized_tensor"]
        del args["rowwise_quantized_tensor"]
        del args["quantizer"]

        if (
            self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
            and self.inspect_tensor_postquantize_enabled_rowwise
        ):
            args["tensor"] = rowwise_gemm_tensor
            args["rowwise"] = True
            debug_api.transformer_engine.inspect_tensor_postquantize(**args)
        if (
            self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
            and self.inspect_tensor_postquantize_enabled_columnwise
        ):
            args["tensor"] = columnwise_gemm_tensor
            args["rowwise"] = False
            debug_api.transformer_engine.inspect_tensor_postquantize(**args)

    def quantize(
        self,
        tensor: torch.Tensor,
        *,
        out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Returns DebugQuantizedTensor object.

        If ``out`` is provided, it is updated with update_quantized() and returned.
        Must not be called on quantizers of output tensors.

        Raises
        ------
        ValueError
            If ``dtype`` is given and modify_tensor() returns a tensor of another dtype.
        """
        import nvdlfw_inspect.api as debug_api

        assert not self.output_tensor
        if out is not None:
            # The destination of the update is `out`, not the quantizer itself.
            return self.update_quantized(tensor, out)

        # 1. If there is fp8 quantization in at least one of the gemms,
        #    the quantization using the self.parent_quantizer is performed.
        self._update_parent_quantizer_usage()

        # Only columnwise quantization is not supported.
        if self.parent_quantizer is not None:
            if not self.parent_quantizer.rowwise_usage and self.parent_quantizer.columnwise_usage:
                self.parent_quantizer.set_usage(rowwise=True)

        rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
        if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
            quantized_tensor = self.parent_quantizer(tensor)
            # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized,
            # one tensor with columnwise=True and rowwise=True is computed
            # and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
            if self.rowwise_tensor_plan == STANDARD_QUANTIZE:
                rowwise_gemm_tensor = quantized_tensor
            if self.columnwise_tensor_plan == STANDARD_QUANTIZE:
                columnwise_gemm_tensor = quantized_tensor

        # 2. modify_tensor() is called, if it is used.
        if self.columnwise_tensor_plan == API_CALL_MODIFY:
            columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.columnwise_gemm_name,
                tensor=tensor,
                default_quantizer=self.parent_quantizer,
                iteration=self.iteration,
                dtype=dtype,
            )
            if dtype is not None:
                if columnwise_gemm_tensor.dtype != dtype:
                    raise ValueError("Dtype does not match the output of the modify_tensor call")
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.rowwise_gemm_name,
                tensor=tensor,
                default_quantizer=self.parent_quantizer,
                iteration=self.iteration,
                dtype=dtype,
            )
            if dtype is not None:
                if rowwise_gemm_tensor.dtype != dtype:
                    raise ValueError("Dtype does not match the output of the modify_tensor call")

        # 3. If some tensors still are not defined we use high precision tensor.
        if self.rowwise_tensor_plan == HIGH_PRECISION:
            rowwise_gemm_tensor = tensor.to(dtype)
        if self.columnwise_tensor_plan == HIGH_PRECISION:
            columnwise_gemm_tensor = tensor.to(dtype)

        self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor)

        # sometimes we may want to return simple tensor with only rowwise_gemm
        if self.tensor_name in ["wgrad", "dgrad", "output"]:
            return rowwise_gemm_tensor

        return DebugQuantizedTensor(
            rowwise_gemm_tensor=rowwise_gemm_tensor,
            columnwise_gemm_tensor=columnwise_gemm_tensor,
            quantizer=self,
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
        )

    def process_gemm_output(self, tensor: torch.Tensor):
        """This call is invoked after the gemm to inspect and modify the output tensor."""
        import nvdlfw_inspect.api as debug_api

        assert self.parent_quantizer is None, "Quantized output is not supported for debug=True."
        assert self.output_tensor
        tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                gemm=tensor_to_gemm[self.tensor_name],
                tensor_name=self.tensor_name,
                tensor=tensor,
                iteration=self.iteration,
                default_quantizer=self.parent_quantizer,
            )
        # The (possibly modified) tensor is the one passed to inspect_tensor().
        self._call_inspect_tensor_api(tensor)
        return tensor

    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
    ) -> QuantizedTensor:
        """
        Override make_empty() from Quantizer class.

        Delegates to the parent quantizer if there is one; otherwise returns
        the result of torch.empty() with the given shape, dtype and device.
        """
        if self.parent_quantizer is not None:
            return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
        return torch.empty(shape, dtype=dtype, device=device)

    def any_feature_enabled(self) -> bool:
        """Returns bool if there is at least one API call enabled."""
        if self.output_tensor:
            return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
        # pylint: disable=too-many-boolean-expressions
        if (
            self.inspect_tensor_enabled
            or self.inspect_tensor_postquantize_enabled_rowwise
            or self.inspect_tensor_postquantize_enabled_columnwise
            or API_CALL_MODIFY in (self.rowwise_tensor_plan, self.columnwise_tensor_plan)
        ):
            return True
        # With a parent quantizer, any plan other than standard quantization
        # also counts as an enabled feature.
        if self.parent_quantizer is not None:
            if self.rowwise_tensor_plan != STANDARD_QUANTIZE:
                return True
            if self.columnwise_tensor_plan != STANDARD_QUANTIZE:
                return True
        return False

    def calibrate(self, tensor: torch.Tensor):
        """Calibration override, should not be invoked. Always raises RuntimeError."""
        raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported")

    def update_quantized(
        self,
        src: torch.Tensor,
        dst: QuantizedTensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> QuantizedTensor:
        """
        Update quantized tensor - used in weight caching.

        Updates the rowwise/columnwise gemm tensors held by ``dst`` from ``src``
        according to the plans of this quantizer and returns ``dst``.
        ``noop_flag`` must be None.
        """
        import nvdlfw_inspect.api as debug_api

        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"

        updated_rowwise_gemm = False
        if self.parent_quantizer is not None:
            if (
                dst.rowwise_gemm_tensor is not None
                and self.rowwise_tensor_plan == STANDARD_QUANTIZE
            ):
                if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
                    dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
                else:
                    tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None)
                updated_rowwise_gemm = True
            # The columnwise tensor is quantized only if the rowwise update
            # above did not run.
            if (
                dst.columnwise_gemm_tensor is not None
                and self.columnwise_tensor_plan == STANDARD_QUANTIZE
                and not updated_rowwise_gemm
            ):
                if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
                    dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None)
                else:
                    tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None)

        if self.columnwise_tensor_plan == API_CALL_MODIFY:
            out = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.columnwise_gemm_name,
                tensor=src,
                default_quantizer=self.parent_quantizer,
                out=dst.columnwise_gemm_tensor,
                iteration=self.iteration,
            )
            assert out is None, (
                "API call debug_api.transformer_engine.modify_tensor with out != None should"
                " return None"
            )
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.rowwise_gemm_name,
                tensor=src,
                default_quantizer=self.parent_quantizer,
                out=dst.rowwise_gemm_tensor,
                iteration=self.iteration,
            )

        if self.rowwise_tensor_plan == HIGH_PRECISION:
            dst.rowwise_gemm_tensor.copy_(src)
        if self.columnwise_tensor_plan == HIGH_PRECISION:
            # if they are the same tensor object, it is sufficient to update one
            if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor:
                dst.columnwise_gemm_tensor.copy_(src)

        self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)
        return dst

    def get_next_debug_iter(self) -> Optional[int]:
        """
        Returns the next iteration for which the debug is enabled for this tensor.
        If the next iteration is None, then the debug is not enabled for this tensor.
        """
        return self.next_debug_iter

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        """Probably not needed for debug quantizer. Always returns None."""
        return None

    def process_enabled_api_call(
        self, enabled_call_output: bool | Tuple[bool, Optional[int]]
    ) -> bool:
        """
        Process enabled API call output.

        ``enabled_call_output`` is either a bool or a tuple (enabled, next_iter).
        For a plain bool, the next iteration is taken to be ``self.iteration + 1``.
        Updates self.next_debug_iter field accordingly (keeping the minimum).
        Return the bool representing if the API call is enabled.
        """
        if isinstance(enabled_call_output, tuple):
            assert len(enabled_call_output) == 2, "Expected a tuple of length 2"
            enabled_bool, next_iter = enabled_call_output
        else:
            enabled_bool = enabled_call_output
            next_iter = self.iteration + 1

        if self.next_debug_iter is None:
            self.next_debug_iter = next_iter
        elif next_iter is not None:
            # If next iter is None, that means that call will never be enabled.
            self.next_debug_iter = min(self.next_debug_iter, next_iter)

        return enabled_bool

    def supports_only_rowwise_all_gather(self) -> bool:
        """Delegates to the parent quantizer; False if there is no parent quantizer."""
        if self.parent_quantizer is not None:
            return self.parent_quantizer.supports_only_rowwise_all_gather()
        return False

    def _update_parent_quantizer_usage(self):
        """
        Updates the usage of the parent quantizer.

        The parent quantizer is only touched if at least one plan is STANDARD_QUANTIZE;
        a usage is enabled when it is both requested from this quantizer and planned
        as STANDARD_QUANTIZE.
        """
        rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE
        columnwise_gemm_quantize = (
            self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE
        )

        if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
            self.parent_quantizer.set_usage(
                rowwise=rowwise_gemm_quantize,
                columnwise=columnwise_gemm_quantize,
            )

    def set_usage(self, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None):
        """
        Sets the usage of the quantizer.

        For non-output tensors the usage of the parent quantizer is updated as well.
        """
        super().set_usage(rowwise=rowwise, columnwise=columnwise)
        if not self.output_tensor:
            self._update_parent_quantizer_usage()

    def wrap_quantized_tensor(self, tensor: QuantizedTensor):
        """
        Wraps the quantized tensor with the debug quantizer.
        It is used for weight tensors when fp8 model parameters are enabled.
        """
        assert (
            self.rowwise_tensor_plan == STANDARD_QUANTIZE
            and self.columnwise_tensor_plan == STANDARD_QUANTIZE
        ), (
            "[NVTORCH INSPECT ERROR] Weight tensor with fp8 model parameters enabled cannot be"
            " modified by any feature."
        )

        # No high precision tensor is passed here, only the quantized one.
        self._call_inspect_tensor_api(None, tensor, tensor)

        return DebugQuantizedTensor(
            rowwise_gemm_tensor=tensor,
            columnwise_gemm_tensor=tensor,
            quantizer=self,
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
        )

    @classmethod
    def multi_tensor_quantize(
        cls,
        tensor: torch.Tensor,
        quantizers: List[Quantizer],
        m_splits: List[int],
        activation_dtype: torch.dtype,
    ) -> List[DebugQuantizedTensor]:
        """
        Splits a tensor into a list of tensors and quantizes each tensor using a list of quantizers.

        The i-th chunk of torch.split(tensor, m_splits) is quantized by quantizers[i]
        with ``dtype=activation_dtype``.
        """
        chunks = torch.split(tensor, m_splits)
        return [
            quantizer.quantize(chunk, dtype=activation_dtype)
            for chunk, quantizer in zip(chunks, quantizers)
        ]
class DebugQuantizedTensor(QuantizedTensorStorage):
    """
    Class containing quantized tensors after debug. Depending on configuration
    it can contain one or two different objects. These objects can be accessed by the method
    get_tensor().
    """

    def __init__(
        self,
        rowwise_gemm_tensor,
        columnwise_gemm_tensor,
        quantizer,
        layer_name=None,
        tensor_name=None,
    ):
        """
        Store the tensors used by the rowwise and columnwise gemms (they may be
        the same object) together with the quantizer that produced them.
        """
        self.rowwise_gemm_tensor = rowwise_gemm_tensor
        self.columnwise_gemm_tensor = columnwise_gemm_tensor
        self.quantizer = quantizer
        self._layer_name = layer_name
        self._tensor_name = tensor_name

    def prepare_for_saving(self):
        """
        Prepare for saving method override.

        Returns (tensor_list, self). If both gemm tensors are the same object,
        it is passed to prepare_for_saving() only once.
        """
        self.tensors_to_save = (
            [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor]
            if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor
            else [self.rowwise_gemm_tensor]
        )
        tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
        # Keep the tensor objects so that restore_from_saved() can use them.
        self.tensors_to_save = tensor_objects_list
        # pylint: disable=unbalanced-tuple-unpacking
        return tensor_list, self

    def restore_from_saved(self, tensors):
        """
        Restore from saved method override.

        Restores the gemm tensors from ``tensors`` and returns the saved tensors
        returned by restore_from_saved().
        """
        tensor_objects_list, saved_tensors = restore_from_saved(
            self.tensors_to_save,
            tensors,
            return_saved_tensors=True,
        )
        if len(tensor_objects_list) == 2:
            # pylint: disable=unbalanced-tuple-unpacking
            self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list
        else:
            # A single object was saved: it serves both gemms.
            self.rowwise_gemm_tensor = tensor_objects_list[0]
            self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
        return saved_tensors

    def quantize_(self, tensor, *, noop_flag=None):
        """quantize_ method override. ``noop_flag`` must be None."""
        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
        self.quantizer.update_quantized(tensor, self)

    def dequantize(self, *, dtype=None):
        """
        dequantize method override.

        Dequantizes the rowwise gemm tensor and casts it to ``dtype``
        (the dtype of the rowwise gemm tensor if ``dtype`` is None).
        """
        if dtype is None:
            dtype = self.rowwise_gemm_tensor.dtype
        return self.rowwise_gemm_tensor.dequantize().to(dtype)

    def get_tensor(self, transpose: bool):
        """Is used in the python gemm() to get tensor or transpose of the tensor."""
        return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor

    def size(self, *args):
        """Size of the tensor (taken from the rowwise gemm tensor)."""
        return self.rowwise_gemm_tensor.size(*args)

    def update_usage(
        self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
    ):
        """
        Update usage of the tensor.

        Raises
        ------
        RuntimeError
            If a usage is requested for which no tensor is available.
        """
        if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor:
            # If the same object is used both for rowwise and columnwise gemms,
            # there is no benefit in erasing the usage of one of them.
            # And there are scenarios when not deleting the usage of one of them is needed.
            # For example when we want to recreate columnwise from rowwise.
            if rowwise_usage is False:
                self.rowwise_gemm_tensor = None
            if columnwise_usage is False:
                self.columnwise_gemm_tensor = None

        if isinstance(self.rowwise_gemm_tensor, QuantizedTensor):
            self.rowwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage)
        if isinstance(self.columnwise_gemm_tensor, QuantizedTensor):
            self.columnwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage)

        if rowwise_usage and self.rowwise_gemm_tensor is None:
            raise RuntimeError(
                "Cannot recreate rowwise tensor from columnwise tensor in debug mode."
            )
        if columnwise_usage and self.columnwise_gemm_tensor is None:
            raise RuntimeError(
                "Cannot recreate columnwise tensor from rowwise tensor in debug mode."
            )

    @property
    def device(self):
        """Return the device of the tensor. Define this to avoid expensive PyObject lookups."""
        if self.rowwise_gemm_tensor is not None:
            return self.rowwise_gemm_tensor.device
        if self.columnwise_gemm_tensor is not None:
            return self.columnwise_gemm_tensor.device
        raise RuntimeError("DebugQuantizedTensor has no data!")