Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/nncf/torch/quantization/quantize_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nncf.torch.quantization.extensions import QuantizedFunctionsCPU
from nncf.torch.quantization.extensions import QuantizedFunctionsCUDA
from nncf.torch.quantization.reference import ReferenceQuantizedFunctions as RQ
from nncf.torch.utils import CompilationWrapper
from nncf.torch.utils import add_ov_domain


Expand Down Expand Up @@ -302,6 +303,32 @@ def asymmetric_quantize_lora(
)
if skip:
return input_
return _asymmetric_quantize_lora(
input_,
Comment on lines +306 to +307
input_shape,
A,
B,
input_low_,
input_range_,
level_low,
level_high,
levels,
eps,
)


def _asymmetric_quantize_lora(
input_,
input_shape,
A,
B,
input_low_,
input_range_,
level_low,
level_high,
levels,
eps,
):
input_range_safe = abs(input_range_) + eps
input_low, input_range = TuneRange.apply(input_low_, input_range_safe, levels)
input_ = (input_ + B @ A).type(input_.dtype) # input(float16) + lora(bfloat16) = float32, need a cast to float16
Expand Down Expand Up @@ -334,6 +361,10 @@ def symmetric_quantize_lora(input_, input_shape, A, B, scale, level_low, level_h
)
if skip:
return input_
return _symmetric_quantize_lora(input_, input_shape, A, B, scale, level_low, level_high, levels, eps)

Comment on lines +364 to +365

def _symmetric_quantize_lora(input_, input_shape, A, B, scale, level_low, level_high, levels, eps):
scale_safe = torch.where(torch.abs(scale) < eps, eps, scale)
input_ = (input_ + B @ A).type(input_.dtype) # input(float16) + lora(bfloat16) = float32, need a cast to float16
return QuantizeSymmetricTorch.apply(
Expand Down Expand Up @@ -471,3 +502,7 @@ def unpack_int4(packed_tensor: torch.Tensor) -> torch.Tensor:
"""
t = unpack_uint4(packed_tensor)
return t.type(torch.int8) - 8


_asymmetric_quantize_lora = CompilationWrapper(_asymmetric_quantize_lora)
_symmetric_quantize_lora = CompilationWrapper(_symmetric_quantize_lora)
Comment on lines +507 to +508
72 changes: 42 additions & 30 deletions src/nncf/torch/quantization/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,6 @@
GeneralizedTensor = TypeVar("GeneralizedTensor", torch.Tensor, np.ndarray)


def fp32_accum_wrapper(func):
def wrapper(tensor_to_sum, ret_tensor):
half = tensor_to_sum.dtype == np.float16
if half:
tensor_to_sum = tensor_to_sum.astype(np.float32)
retval = func(tensor_to_sum, ret_tensor)
if half:
retval = retval.astype(np.float16)
return retval

return wrapper


@fp32_accum_wrapper
def sum_like(tensor_to_sum, ref_tensor):
"""Warning: may modify tensor_to_sum"""
if ref_tensor.size == 1:
return tensor_to_sum.sum()

for dim, size in enumerate(ref_tensor.shape):
if size == 1:
if isinstance(tensor_to_sum, np.ndarray):
tensor_to_sum = tensor_to_sum.sum(dim, keepdims=True)
else:
tensor_to_sum = tensor_to_sum.sum(dim, keepdim=True)
return tensor_to_sum


class ReferenceBackendType(Enum):
NUMPY = "numpy"
TORCH = "torch"
Expand Down Expand Up @@ -79,6 +51,46 @@ def _reciprocal(self, tensor: GeneralizedTensor) -> GeneralizedTensor:
return np.reciprocal(tensor)
return torch.reciprocal(tensor)

def _sum_like(self, tensor_to_sum: GeneralizedTensor, ref_tensor: GeneralizedTensor):
"""Warning: may modify tensor_to_sum"""
Comment on lines +54 to +55
if self.backend is np:
half = tensor_to_sum.dtype == np.float16
if half:
tensor_to_sum = tensor_to_sum.astype(np.float32)
retval = self._sum_like_fp32(tensor_to_sum, ref_tensor)
if half:
retval = retval.astype(np.float16)
return retval

half = tensor_to_sum.dtype == torch.float16
if half:
tensor_to_sum = tensor_to_sum.type(torch.float32)
retval = self._sum_like_fp32(tensor_to_sum, ref_tensor)
if half:
retval = retval.type(torch.float16)
return retval

def _sum_like_fp32(self, tensor_to_sum: GeneralizedTensor, ref_tensor: GeneralizedTensor):
"""Warning: may modify tensor_to_sum"""
Comment on lines +73 to +74
if self.backend is np:
n_elements = ref_tensor.size
if n_elements == 1:
return tensor_to_sum.sum().reshape(ref_tensor.shape)

for dim, size in enumerate(ref_tensor.shape):
if size == 1:
tensor_to_sum = tensor_to_sum.sum(dim, keepdims=True)
return tensor_to_sum

n_elements = ref_tensor.numel()
if n_elements == 1:
return tensor_to_sum.sum().reshape(ref_tensor.shape)

for dim, size in enumerate(ref_tensor.shape):
if size == 1:
tensor_to_sum = tensor_to_sum.sum(dim, keepdim=True)
return tensor_to_sum

def forward(
self, input_: GeneralizedTensor, input_low: GeneralizedTensor, input_range: GeneralizedTensor, levels: int
) -> GeneralizedTensor:
Expand Down Expand Up @@ -114,12 +126,12 @@ def backward(
output = self.forward(input_, input_low, input_range, levels)
err = (output - input_) * self._reciprocal(input_range * range_sign)
grad_range = grad_output * (err * mask_in + range_sign * (level_low / level_high) * mask_lo + mask_hi)
grad_range = sum_like(grad_range, input_range)
grad_range = self._sum_like(grad_range, input_range)

grad_input = grad_output * mask_in

grad_low = grad_output * (mask_hi + mask_lo)
grad_low = sum_like(grad_low, input_low)
grad_low = self._sum_like(grad_low, input_low)
return [grad_input, grad_low, grad_range]

def tune_range(
Expand Down
4 changes: 4 additions & 0 deletions src/nncf/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:

:return: Result of the function call.
"""
# Prevent nested compilation
if torch.compiler.is_compiling():
return self._func(*args, **kwargs)

if self._compiled_func is None:
try:
self._compiled_func = torch.compile(self._func)
Expand Down