Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "ComfyUI-QuantOps"
description = "Extended quantization layouts for ComfyUI (INT8, row/block-wise FP8)"
version = "1.7.2"
version = "1.8.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
Expand Down
27 changes: 25 additions & 2 deletions unified_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,14 @@ def _load_from_state_dict(

if is_tensorwise and _HAS_TENSORWISE_INT8_LAYOUT:
self.layout_type = "TensorWiseINT8Layout"
_orig_dtype_str = layer_conf.get("orig_dtype", "torch.bfloat16") if layer_conf else "torch.bfloat16"
_DTYPE_MAP = {"torch.bfloat16": torch.bfloat16, "torch.float16": torch.float16, "torch.float32": torch.float32}
_orig_dtype = _DTYPE_MAP.get(_orig_dtype_str, torch.bfloat16)
layout_params = TensorWiseINT8Layout.Params(
scale=scale.to(torch.float32)
if scale is not None
else None,
orig_dtype=torch.bfloat16,
orig_dtype=_orig_dtype,
orig_shape=tuple(weight_tensor.shape),
is_weight=True,
)
Expand Down Expand Up @@ -205,6 +208,14 @@ def _load_from_state_dict(
requires_grad=False,
)
else:
# TODO (#2 — medium severity, low risk): this branch fires when
# is_tensorwise=True but _HAS_TENSORWISE_INT8_LAYOUT=False (ck absent).
# Result: raw int8 tensor stored with is_quantized=True, layout_type=None.
# That is a broken state — forward() will hit F.linear with raw int8 weight.
# Fix: degrade to BlockWiseINT8Layout if _HAS_INT8_LAYOUT, else set
# is_quantized=False and log a warning. Not patching now because ck is
# effectively required for tensorwise; if ck import failed the checkpoint
# is already unrunnable regardless.
self.weight = torch.nn.Parameter(
weight_tensor, requires_grad=False
)
Expand Down Expand Up @@ -463,7 +474,19 @@ def forward_comfy_cast_weights(self, input):
)

else:
# Default trigger for QuantizedTensor dispatch -> layout-specific handler
# Default trigger for QuantizedTensor dispatch -> layout-specific handler.
# TensorWiseINT8Layout and BlockWiseINT8Layout land here — aten.linear
# dispatch in comfy_kitchen handles the actual matmul.
#
# TODO (#3 — low-medium severity, medium risk): this else branch has no 3D
# input reshape guard, unlike all the explicit elif branches above. ComfyUI
# transformer attention layers pass [batch, seq, hidden] (3D). F.linear
# handles 3D natively so it works, but ck dispatch handlers may not. If
# tensorwise inference produces wrong shapes on 3D inputs, add the standard
# tensor_3d guard here (reshape -1,hidden before linear, reshape back after).
# Not patching now — risk of breaking currently-working layouts that fall
# through to this branch (e.g. RowWiseFP8, BlockWiseFP8 if aten dispatch
# handles them here too).
out = torch.nn.functional.linear(input, weight, bias)

else:
Expand Down
39 changes: 33 additions & 6 deletions utils/eager_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,40 @@ def int8_linear(
bias: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""INT8 linear layer using torch.int8_mm for direct quantized matmul.

Uses native torch.int8_mm which avoids materializing large float32 intermediates
and handles scaling more efficiently than manual int32 -> float32 conversion.

Ported from comfy-kitchen eager backend with OOM fixes.
"""INT8 linear layer. Delegates to comfy_kitchen.int8_linear (triton->eager)
when available, falls back to local torch.int8_mm chunked path.

ck.int8_linear signature matches exactly:
(x, weight, weight_scale, bias=None, out_dtype=None)
weight: [N, K] int8, weight_scale: scalar float32, out_dtype defaults bfloat16.
"""
# Prefer comfy_kitchen dispatch (triton -> eager via registry).
# ck.int8_linear routes through torch.ops.comfy_kitchen.int8_linear which
# goes through the registry with priority ["cuda", "triton", "eager"].
# cuda backend has no int8_linear, so triton wins if available, else eager.
try:
import comfy_kitchen as ck
return ck.int8_linear(x, weight, weight_scale, bias, out_dtype)
except ImportError:
pass
except Exception as e:
import logging
logging.warning(f"ComfyUI-QuantOps: ck.int8_linear failed, falling back to local path: {e}")

# --- Local fallback: chunked torch.int8_mm path (OOM-safe) ---
# Unwrap QuantizedTensor if weight arrived still wrapped (defensive).
try:
from comfy.quant_ops import QuantizedTensor
if isinstance(weight, QuantizedTensor):
weight_scale = weight._params.scale
weight = weight._qdata
except ImportError:
pass

# Ensure weight is raw int8 and contiguous before torch.int8_mm.
if not weight.is_contiguous():
weight = weight.contiguous()

orig_shape = x.shape
x_2d = x.reshape(-1, x.shape[-1])

Expand Down