Skip to content

Commit 29c9a8c

Browse files
liudangyicopybara-github
authored andcommitted
Use native int2 types.
PiperOrigin-RevId: 753273544
1 parent cf7ca70 commit 29c9a8c

3 files changed

Lines changed: 17 additions & 5 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ targets (LiteRT).
1919
converter could produce full integer models.
2020
* LoRA/QLoRA: this mode enables LoRA and QLoRA on a model.
2121
* Supported numerics:
22-
* Native: `int4`, `int8`, `fp8`.
23-
* Emulated: `int1` to `int7`, `nf4`.
22+
* Native: `int2`, `int4`, `int8`, `fp8`.
23+
* Emulated: `int3` to `int7`, `nf4`.
2424
* Supported array calibration methods:
2525
* `absmax`: symmetric quantization using maximum absolute value.
2626
* `minmax`: asymmetric quantization using minimum and maximum values.

qwix/core/numerics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_symmetric_bound(qtype: jax.typing.DTypeLike) -> float:
4343
match qtype:
4444
case 'nf4':
4545
return 1.0
46-
case 'int2' | 'int3' | 'int5' | 'int6' | 'int7':
46+
case 'int3' | 'int5' | 'int6' | 'int7':
4747
# The bound is extended to qmax + 0.5 so that we have a better utilization
4848
# of the whole range. This is more important for fewer bits of int.
4949
return 2 ** (int(qtype[3:]) - 1) - 0.5
@@ -63,7 +63,7 @@ def convert_to(x: jax.Array, qtype: jax.typing.DTypeLike) -> jax.Array:
6363
match qtype:
6464
case 'nf4':
6565
return fp_to_nf4(x)
66-
case 'int2' | 'int3' | 'int5' | 'int6' | 'int7':
66+
case 'int3' | 'int5' | 'int6' | 'int7':
6767
bits = int(qtype[3:])
6868
qmin = -(2 ** (bits - 1))
6969
qmax = 2 ** (bits - 1) - 1

tests/core/numerics_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def test_convert_to(self):
3939
numerics.convert_to(jnp.array([1.2, 3.5, 8, -1300]), jnp.int4),
4040
jnp.array([1, 4, 7, -8], jnp.int4),
4141
)
42+
self._assert_equal(
43+
numerics.convert_to(jnp.array([1.2, 3.5, 8, -1300]), jnp.int2),
44+
jnp.array([1, 1, 1, -2], jnp.int2),
45+
)
4246

4347
def test_inf(self):
4448
self._assert_equal(
@@ -52,7 +56,15 @@ def test_arbitrary_integer_dtype(self):
5256
numerics.convert_to(jnp.array([1.2, 3.5, 129, -1300]), "int6"),
5357
jnp.array([1, 4, 31, -32], jnp.int8),
5458
)
55-
# jnp.int4 and "int4" should be the same.
59+
# jnp.int* and "int*" should be the same.
60+
self._assert_equal(
61+
numerics.get_symmetric_bound("int2"),
62+
numerics.get_symmetric_bound(jnp.int2),
63+
)
64+
self._assert_equal(
65+
numerics.convert_to(jnp.array([1.2, 3.5, 129, -1300]), "int2"),
66+
numerics.convert_to(jnp.array([1.2, 3.5, 129, -1300]), jnp.int2),
67+
)
5668
self._assert_equal(
5769
numerics.get_symmetric_bound("int4"),
5870
numerics.get_symmetric_bound(jnp.int4),

0 commit comments

Comments
 (0)