diff --git a/qwix/_src/flax_util.py b/qwix/_src/flax_util.py
index 4ea5d80..daf45af 100644
--- a/qwix/_src/flax_util.py
+++ b/qwix/_src/flax_util.py
@@ -308,13 +308,13 @@ def fn(x):
 
 
 def update_sharding(
-    spec: Sequence[Any],
+    spec: Sequence[Any] | jax.sharding.PartitionSpec,
     *,
     shape: Sequence[int] | None = None,
     split: Collection[int] | None = None,
    merge: Collection[int] | None = None,
     transpose: Sequence[int | None] | None = None,
-) -> tuple[Any, ...]:
+) -> tuple[Any, ...] | jax.sharding.PartitionSpec:
   """Derives the partition spec from an existing spec.
 
   Args:
@@ -330,6 +330,8 @@ def update_sharding(
     The updated partition spec.
   """
   assert bool(split) + bool(merge) + bool(transpose) <= 1
+  is_pspec = isinstance(spec, jax.sharding.PartitionSpec)
+
   if split:
     spec = [(a, None) if i in split else (a,) for i, a in enumerate(spec)]
     spec = sum(spec, ())  # flatten the list of tuples.
@@ -344,6 +346,9 @@ def update_sharding(
     # For scales: remove sharding for dimensions of size 1.
     spec = tuple(None if d == 1 else a for a, d in zip(spec, shape))
 
+  if is_pspec:
+    return jax.sharding.PartitionSpec(*spec)
+
   return spec
 
 
@@ -380,7 +385,7 @@
   shape = boxed.unbox().shape
   for possible_field in ('names', 'mesh_axes', 'axes_types'):
     axes = getattr(boxed, possible_field, None)
-    if isinstance(axes, (list, tuple)):
+    if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
       axes = update_sharding(
           axes, shape=shape, split=split, merge=merge, transpose=transpose
       )
@@ -396,7 +401,8 @@
   else:
     sharding_key = 'sharding_names'
   axes = metadata.get(sharding_key, None)
-  if isinstance(axes, (list, tuple)):
+
+  if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
     axes = update_sharding(
         axes, shape=shape, split=split, merge=merge, transpose=transpose
     )
diff --git a/qwix/_src/providers/lora.py b/qwix/_src/providers/lora.py
index 3e03023..b8e76ef 100644
--- a/qwix/_src/providers/lora.py
+++ b/qwix/_src/providers/lora.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Low-Rank Adapation (LoRA) support."""
 import dataclasses
+import math
 import string
 from typing import Any, Callable, Collection, Sequence
 import warnings
@@ -131,6 +132,79 @@ def _parse_einsum_str_for_lora(
   )
 
 
+def _create_lora_layer_shapes(
+    rhs_ca: Sequence[int],
+    rhs_ba: Sequence[int],
+    rhs_ra: Sequence[int],
+    contract_shape: typing.Shape,
+    batch_shape: typing.Shape,
+    remain_shape: typing.Shape,
+    rank: int,
+) -> tuple[
+    typing.Shape,  # a_shape
+    Sequence[int | None],  # a_sharding_transpose
+    typing.Shape,  # b_shape
+    Sequence[int | None],  # b_sharding_transpose
+]:
+  """Returns lora param shapes and sharding transposes for dot_general."""
+
+  # LoRA A: (batch, contracting, rank)
+  a_shape = (*batch_shape, math.prod(contract_shape), rank)
+  # LoRA B: (batch, rank, remain)
+  b_shape = (*batch_shape, rank, math.prod(remain_shape))
+
+  # Inherit sharding from first dims; XLA's SPMD partitioner will
+  # automatically infer and propagate sharding for subsequent ones.
+  a_sharding_transpose = (*rhs_ba, rhs_ca[0] if rhs_ca else None, None)
+  b_sharding_transpose = (*rhs_ba, None, rhs_ra[0] if rhs_ra else None)
+
+  return a_shape, a_sharding_transpose, b_shape, b_sharding_transpose
+
+
+def _compute_lora_delta(
+    lhs: jax.Array,
+    lora_a: jax.Array,
+    lora_b: jax.Array,
+    lhs_ca: Sequence[int],
+    lhs_ba: Sequence[int],
+    contract_shape: typing.Shape,
+    batch_shape: typing.Shape,
+    remain_shape: typing.Shape,
+    rank: int,
+    precision: jax.lax.PrecisionLike = None,
+) -> jax.Array:
+  """Computes the raw LoRA delta."""
+
+  lora_a_reshaped = jax.numpy.reshape(
+      lora_a, (*batch_shape, *contract_shape, rank)
+  )
+  delta_batch_axes = (*range(len(batch_shape)),)
+  lora_a_contract_axes = (
+      *range(len(batch_shape), len(batch_shape) + len(contract_shape)),
+  )
+  delta_a = jax.lax.dot_general(
+      lhs,
+      lora_a_reshaped,
+      ((lhs_ca, lora_a_contract_axes), (lhs_ba, delta_batch_axes)),
+      precision=precision,
+  )
+
+  # delta = delta_a @ lora_b
+  lora_b_reshaped = jax.numpy.reshape(
+      lora_b, (*batch_shape, rank, *remain_shape)
+  )
+  delta = jax.lax.dot_general(
+      delta_a,
+      lora_b_reshaped,
+      (
+          ((delta_a.ndim - 1,), (len(batch_shape),)),
+          (delta_batch_axes, delta_batch_axes),
+      ),
+      precision=precision,
+  )
+  return delta
+
+
 class LoraProvider(ptq.PtqProvider):
   """Provider for (Q)LoRA.
 
@@ -190,20 +264,35 @@ def dot_general(
     if weight_name is None:  # rhs is not a weight.
       return res
 
-    # We only support ...a,ab->...b for now.
-    assert (
-        len(rhs.shape) == 2
-        and tuple(dimension_numbers[0][1]) == (0,)
-        and not dimension_numbers[1][1]
-    ), f'Unsupported: {rhs.shape=} {dimension_numbers=}'
+    (lhs_ca, rhs_ca), (lhs_ba, rhs_ba) = dimension_numbers
+
+    rhs_ra = (
+        *(i for i in range(rhs.ndim) if i not in rhs_ca and i not in rhs_ba),
+    )
+
+    contract_shape = (*(rhs.shape[i] for i in rhs_ca),)
+    batch_shape = (*(rhs.shape[i] for i in rhs_ba),)
+    remain_shape = (*(rhs.shape[i] for i in rhs_ra),)
+
+    a_shape, a_sharding_transpose, b_shape, b_sharding_transpose = (
+        _create_lora_layer_shapes(
+            rhs_ca,
+            rhs_ba,
+            rhs_ra,
+            contract_shape,
+            batch_shape,
+            remain_shape,
+            rule.rank,
+        )
+    )
 
     lora_a, lora_b = _get_or_create_lora_params(
         name=weight_name,
         rule=rule,
-        a_shape=(rhs.shape[0], rule.rank),
-        b_shape=(rule.rank, rhs.shape[1]),
-        a_sharding_transpose=(0, None),
-        b_sharding_transpose=(None, 1),
+        a_shape=a_shape,
+        b_shape=b_shape,
+        a_sharding_transpose=a_sharding_transpose,
+        b_sharding_transpose=b_sharding_transpose,
     )
 
     if rule.dropout > 0:
@@ -212,7 +301,20 @@ def dot_general(
           lhs, rngs=flax_util.make_rng('dropout')
       )
 
-    return res + lhs @ lora_a @ lora_b * (rule.alpha / rule.rank)
+    delta = _compute_lora_delta(
+        lhs,
+        lora_a,
+        lora_b,
+        lhs_ca,
+        lhs_ba,
+        contract_shape,
+        batch_shape,
+        remain_shape,
+        rule.rank,
+        precision=precision,
+    )
+
+    return res + delta * (rule.alpha / rule.rank)
 
   def einsum(
       self,
diff --git a/tests/_src/flax_util_test.py b/tests/_src/flax_util_test.py
index 17c534a..3e6a479 100644
--- a/tests/_src/flax_util_test.py
+++ b/tests/_src/flax_util_test.py
@@ -145,6 +145,34 @@ def __call__(self, x):
     self.assertIsInstance(variables["lora_a"], nnx.LoRAParam)
     self.assertIsInstance(variables["lora_b"], nnx.LoRAParam)
 
+  def test_update_sharding(self):
+    # Test tuple
+    tuple_spec = ("a", "b")
+    updated = flax_util.update_sharding(tuple_spec, transpose=[1, 0])
+    self.assertEqual(updated, ("b", "a"))
+
+    updated = flax_util.update_sharding(tuple_spec, split=[0])
+    self.assertEqual(updated, ("a", None, "b"))
+
+    tuple_spec_with_none = ("a", None, "b")
+    updated = flax_util.update_sharding(tuple_spec_with_none, merge=[0])
+    self.assertEqual(updated, ("a", "b"))
+
+    # Test jax.sharding.PartitionSpec
+    pspec = jax.sharding.PartitionSpec("a", "b")
+    updated = flax_util.update_sharding(pspec, transpose=[1, 0])
+    self.assertIsInstance(updated, jax.sharding.PartitionSpec)
+    self.assertEqual(updated, jax.sharding.PartitionSpec("b", "a"))
+
+    updated = flax_util.update_sharding(pspec, split=[0])
+    self.assertIsInstance(updated, jax.sharding.PartitionSpec)
+    self.assertEqual(updated, jax.sharding.PartitionSpec("a", None, "b"))
+
+    pspec_with_none = jax.sharding.PartitionSpec("a", None, "b")
+    updated = flax_util.update_sharding(pspec_with_none, merge=[0])
+    self.assertIsInstance(updated, jax.sharding.PartitionSpec)
+    self.assertEqual(updated, jax.sharding.PartitionSpec("a", "b"))
+
   def test_unbox(self):
     mesh = jax.make_mesh(
         (1, 1),
@@ -194,6 +222,37 @@ def test_update_boxed(self):
     self.assertIsInstance(updated, nnx.Param)
     self.assertEqual(updated.sharding_names, ("b", "a", None))
 
+    # Test nn.Partitioned with jax.sharding.PartitionSpec
+    boxed_pspec = nn.Partitioned(
+        jnp.ones((4, 8)), names=jax.sharding.PartitionSpec("a", "b")
+    )
+    updated_pspec = flax_util.update_boxed(
+        boxed_pspec, value=jnp.ones((2, 2, 8)), split=[0]
+    )
+    self.assertIsInstance(updated_pspec, nn.Partitioned)
+    self.assertEqual(updated_pspec.value.shape, (2, 2, 8))
+    self.assertEqual(
+        updated_pspec.names, jax.sharding.PartitionSpec("a", None, "b")
+    )
+
+    # Test nnx.Param with jax.sharding.PartitionSpec
+    boxed_nnx_pspec = nnx.Param(
+        jnp.ones((2, 2, 8)),
+        out_sharding=jax.sharding.PartitionSpec("a", None, "b"),
+    )
+    updated_nnx_pspec = flax_util.update_boxed(
+        boxed_nnx_pspec, transpose=[2, 0, None]
+    )
+    self.assertIsInstance(updated_nnx_pspec, nnx.Param)
+    metadata = updated_nnx_pspec.get_metadata()
+    sharding_key = (
+        "out_sharding" if "out_sharding" in metadata else "sharding_names"
+    )
+    self.assertIsInstance(metadata[sharding_key], jax.sharding.PartitionSpec)
+    self.assertEqual(
+        metadata[sharding_key], jax.sharding.PartitionSpec("b", "a", None)
+    )
+
   def test_make_rng_linen(self):
 
     class MyModule(nn.Module):
diff --git a/tests/_src/providers/lora_test.py b/tests/_src/providers/lora_test.py
index ee03684..5646630 100644
--- a/tests/_src/providers/lora_test.py
+++ b/tests/_src/providers/lora_test.py
@@ -341,6 +341,161 @@ def test_lora_dot_general_nn(self, weight_qtype):
     self.assertEqual(lora_b.unbox().shape, (3, 32))
     self.assertEqual(lora_b.names, (None, "b"))
 
+  def test_lora_dot_general_batch(self):
+    """Test LoRA on a dot_general operation with batch dimensions."""
+    batch, in_dim, contract, out_dim = 2, 4, 8, 16
+    rank = 3
+    alpha = 1.0
+
+    class BatchLinear(nnx.Module):
+
+      def __init__(self, rngs):
+        # kernel shape: (batch, contract, out)
+        self.kernel = nnx.Param(
+            jax.random.normal(rngs.params(), (batch, contract, out_dim))
+        )
+
+      def __call__(self, x):
+        # x: (batch, in, contract)
+        # dim_nums: ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
+        return jax.lax.dot_general(
+            x,
+            self.kernel.value,
+            (((2,), (1,)), ((0,), (0,))),
+        )
+
+    model = BatchLinear(nnx.Rngs(0))
+    lora_provider = lora.LoraProvider([
+        lora.LoraRule(
+            module_path=".*",
+            rank=rank,
+            alpha=alpha,
+            lora_b_initializer=nnx.initializers.ones,
+        ),
+    ])
+
+    lhs = jnp.ones((batch, in_dim, contract))
+    lora_model = lora.apply_lora_to_model(model, lora_provider, lhs)
+
+    # Check if LoRA parameters were created and have expected shapes
+    variables = nnx.variables(lora_model, nnx.LoRAParam)
+    self.assertIn("kernel_lora_a", variables)
+    self.assertIn("kernel_lora_b", variables)
+
+    lora_a = variables["kernel_lora_a"].value
+    lora_b = variables["kernel_lora_b"].value
+
+    self.assertEqual(lora_a.shape, (batch, contract, rank))
+    self.assertEqual(lora_b.shape, (batch, rank, out_dim))
+
+    # Verify forward pass
+    output = lora_model(lhs)
+    self.assertEqual(output.shape, (batch, in_dim, out_dim))
+
+    # Expected output
+    res = jax.lax.dot_general(
+        lhs, model.kernel.value, (((2,), (1,)), ((0,), (0,)))
+    )
+    delta_a = jax.lax.dot_general(lhs, lora_a, (((2,), (1,)), ((0,), (0,))))
+    delta = jax.lax.dot_general(delta_a, lora_b, (((2,), (1,)), ((0,), (0,))))
+    expected = res + delta * (alpha / rank)
+
+    self.assertTrue(jnp.allclose(output, expected))
+
+  def test_lora_dot_general_multi_out_dim(self):
+    """Test LoRA on a dot_general operation with multiple output dimensions."""
+    batch, in_dim, contract, out_dim1, out_dim2 = 2, 4, 8, 4, 4
+    rank = 3
+    alpha = 1.0
+
+    class MultiOutLinear(nnx.Module):
+
+      def __init__(self, rngs):
+        # kernel shape: (batch, contract, out_dim1, out_dim2)
+        self.kernel = nnx.Param(
+            jax.random.normal(
+                rngs.params(), (batch, contract, out_dim1, out_dim2)
+            )
+        )
+
+      def __call__(self, x):
+        # x: (batch, in, contract)
+        # dim_nums: ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
+        return jax.lax.dot_general(
+            x, self.kernel.value, (((2,), (1,)), ((0,), (0,)))
+        )
+
+    model = MultiOutLinear(nnx.Rngs(0))
+    lora_provider = lora.LoraProvider(rank=rank, alpha=alpha)
+    lhs = jnp.ones((batch, in_dim, contract))
+    lora_model = lora.apply_lora_to_model(model, lora_provider, lhs)
+
+    # Check if LoRA parameters were created and have expected shapes
+    variables = nnx.variables(lora_model, nnx.LoRAParam)
+    self.assertIn("kernel_lora_a", variables)
+    self.assertIn("kernel_lora_b", variables)
+
+    lora_a = variables["kernel_lora_a"].value
+    lora_b = variables["kernel_lora_b"].value
+
+    self.assertEqual(lora_a.shape, (batch, contract, rank))
+    self.assertEqual(lora_b.shape, (batch, rank, out_dim1 * out_dim2))
+
+    # Verify forward pass
+    output = lora_model(lhs)
+    self.assertEqual(output.shape, (batch, in_dim, out_dim1, out_dim2))
+
+    # Expected output
+    res = model(lhs)
+    delta_a = jax.lax.dot_general(lhs, lora_a, (((2,), (1,)), ((0,), (0,))))
+    lora_b_reshaped = jnp.reshape(lora_b, (batch, rank, out_dim1, out_dim2))
+    delta = jax.lax.dot_general(
+        delta_a, lora_b_reshaped, (((2,), (1,)), ((0,), (0,)))
+    )
+    expected = res + delta * (alpha / rank)
+
+    self.assertTrue(jnp.allclose(output, expected, atol=1e-5))
+
+  def test_lora_dot_general_no_weight(self):
+    """Tests that LoRA skips operations where rhs is not a weight."""
+
+    class NoWeightModel(nnx.Module):
+
+      def __call__(self, x):
+        # Multiply by a constant array, not a parameter
+        return jax.lax.dot_general(
+            x, jnp.ones((8, 8)), (((1,), (0,)), ((), ()))
+        )
+
+    model = NoWeightModel()
+    lora_provider = lora.LoraProvider(rank=2, alpha=1.0)
+    # This should work without error and without creating LoRA params
+    lora_model = lora.apply_lora_to_model(
+        model, lora_provider, jnp.ones((1, 8))
+    )
+    self.assertEmpty(nnx.variables(lora_model, nnx.LoRAParam))
+    output = lora_model(jnp.ones((1, 8)))
+    self.assertEqual(output.shape, (1, 8))
+
+  def test_lora_dot_general_no_remain(self):
+    """Tests LoRA where there are no remaining dimensions (full contraction)."""
+
+    class FullContractModel(nnx.Module):
+
+      def __init__(self, rngs):
+        self.kernel = nnx.Param(jax.random.normal(rngs.params(), (8,)))
+
+      def __call__(self, x):
+        return jax.lax.dot_general(
+            x, self.kernel.value, (((0,), (0,)), ((), ()))
+        )
+
+    model = FullContractModel(nnx.Rngs(0))
+    lora_provider = lora.LoraProvider(rank=2, alpha=1.0)
+    lora_model = lora.apply_lora_to_model(model, lora_provider, jnp.ones((8,)))
+    output = lora_model(jnp.ones((8,)))
+    self.assertEqual(output.shape, ())
+
   def test_lora_conv_nn(self):
     """Test LoRA on nn.Conv module."""
     conv = nn.Conv(