Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions qwix/_src/flax_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,13 @@ def fn(x):


def update_sharding(
spec: Sequence[Any],
spec: Sequence[Any] | jax.sharding.PartitionSpec,
*,
shape: Sequence[int] | None = None,
split: Collection[int] | None = None,
merge: Collection[int] | None = None,
transpose: Sequence[int | None] | None = None,
) -> tuple[Any, ...]:
) -> tuple[Any, ...] | jax.sharding.PartitionSpec:
"""Derives the partition spec from an existing spec.

Args:
Expand All @@ -330,6 +330,8 @@ def update_sharding(
The updated partition spec.
"""
assert bool(split) + bool(merge) + bool(transpose) <= 1
is_pspec = isinstance(spec, jax.sharding.PartitionSpec)

if split:
spec = [(a, None) if i in split else (a,) for i, a in enumerate(spec)]
spec = sum(spec, ()) # flatten the list of tuples.
Expand All @@ -344,6 +346,9 @@ def update_sharding(
# For scales: remove sharding for dimensions of size 1.
spec = tuple(None if d == 1 else a for a, d in zip(spec, shape))

if is_pspec:
return jax.sharding.PartitionSpec(*spec)

return spec


Expand Down Expand Up @@ -380,7 +385,7 @@ def update_boxed(
shape = boxed.unbox().shape
for possible_field in ('names', 'mesh_axes', 'axes_types'):
axes = getattr(boxed, possible_field, None)
if isinstance(axes, (list, tuple)):
if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
axes = update_sharding(
axes, shape=shape, split=split, merge=merge, transpose=transpose
)
Expand All @@ -396,7 +401,8 @@ def update_boxed(
else:
sharding_key = 'sharding_names'
axes = metadata.get(sharding_key, None)
if isinstance(axes, (list, tuple)):

if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
axes = update_sharding(
axes, shape=shape, split=split, merge=merge, transpose=transpose
)
Expand Down
124 changes: 113 additions & 11 deletions qwix/_src/providers/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Low-Rank Adapation (LoRA) support."""
import dataclasses
import math
import string
from typing import Any, Callable, Collection, Sequence
import warnings
Expand Down Expand Up @@ -131,6 +132,79 @@ def _parse_einsum_str_for_lora(
)


def _create_lora_layer_shapes(
    rhs_ca: Sequence[int],
    rhs_ba: Sequence[int],
    rhs_ra: Sequence[int],
    contract_shape: typing.Shape,
    batch_shape: typing.Shape,
    remain_shape: typing.Shape,
    rank: int,
) -> tuple[
    typing.Shape,  # a_shape
    Sequence[int | None],  # a_sharding_transpose
    typing.Shape,  # b_shape
    Sequence[int | None],  # b_sharding_transpose
]:
  """Returns lora param shapes and sharding transposes for dot_general."""

  # Collapse the contracting and remaining dims each into a single flat axis
  # so the LoRA factors are plain (per-batch) matrices.
  flat_contract = math.prod(contract_shape)
  flat_remain = math.prod(remain_shape)

  # LoRA A: (batch, contracting, rank); LoRA B: (batch, rank, remain).
  a_shape = (*batch_shape, flat_contract, rank)
  b_shape = (*batch_shape, rank, flat_remain)

  # Only the leading contracting/remaining dim contributes its sharding;
  # XLA's SPMD partitioner infers and propagates sharding for the rest.
  leading_ca = rhs_ca[0] if rhs_ca else None
  leading_ra = rhs_ra[0] if rhs_ra else None
  a_sharding_transpose = (*rhs_ba, leading_ca, None)
  b_sharding_transpose = (*rhs_ba, None, leading_ra)

  return a_shape, a_sharding_transpose, b_shape, b_sharding_transpose


def _compute_lora_delta(
    lhs: jax.Array,
    lora_a: jax.Array,
    lora_b: jax.Array,
    lhs_ca: Sequence[int],
    lhs_ba: Sequence[int],
    contract_shape: typing.Shape,
    batch_shape: typing.Shape,
    remain_shape: typing.Shape,
    rank: int,
    precision: jax.lax.PrecisionLike = None,
) -> jax.Array:
  """Computes the raw LoRA delta, i.e. (lhs @ lora_a) @ lora_b."""

  n_batch = len(batch_shape)
  batch_axes = tuple(range(n_batch))

  # Un-flatten LoRA A back to (batch, *contract, rank) so it contracts
  # against lhs along the original contracting dims.
  a_full = lora_a.reshape(*batch_shape, *contract_shape, rank)
  a_contract_axes = tuple(range(n_batch, n_batch + len(contract_shape)))
  delta_a = jax.lax.dot_general(
      lhs,
      a_full,
      ((lhs_ca, a_contract_axes), (lhs_ba, batch_axes)),
      precision=precision,
  )

  # Un-flatten LoRA B to (batch, rank, *remain) and contract over the rank
  # axis, which dot_general placed last in delta_a.
  b_full = lora_b.reshape(*batch_shape, rank, *remain_shape)
  return jax.lax.dot_general(
      delta_a,
      b_full,
      (
          ((delta_a.ndim - 1,), (n_batch,)),
          (batch_axes, batch_axes),
      ),
      precision=precision,
  )


class LoraProvider(ptq.PtqProvider):
"""Provider for (Q)LoRA.

Expand Down Expand Up @@ -190,20 +264,35 @@ def dot_general(
if weight_name is None: # rhs is not a weight.
return res

# We only support ...a,ab->...b for now.
assert (
len(rhs.shape) == 2
and tuple(dimension_numbers[0][1]) == (0,)
and not dimension_numbers[1][1]
), f'Unsupported: {rhs.shape=} {dimension_numbers=}'
(lhs_ca, rhs_ca), (lhs_ba, rhs_ba) = dimension_numbers

rhs_ra = (
*(i for i in range(rhs.ndim) if i not in rhs_ca and i not in rhs_ba),
)

contract_shape = (*(rhs.shape[i] for i in rhs_ca),)
batch_shape = (*(rhs.shape[i] for i in rhs_ba),)
remain_shape = (*(rhs.shape[i] for i in rhs_ra),)

a_shape, a_sharding_transpose, b_shape, b_sharding_transpose = (
_create_lora_layer_shapes(
rhs_ca,
rhs_ba,
rhs_ra,
contract_shape,
batch_shape,
remain_shape,
rule.rank,
)
)

lora_a, lora_b = _get_or_create_lora_params(
name=weight_name,
rule=rule,
a_shape=(rhs.shape[0], rule.rank),
b_shape=(rule.rank, rhs.shape[1]),
a_sharding_transpose=(0, None),
b_sharding_transpose=(None, 1),
a_shape=a_shape,
b_shape=b_shape,
a_sharding_transpose=a_sharding_transpose,
b_sharding_transpose=b_sharding_transpose,
)

if rule.dropout > 0:
Expand All @@ -212,7 +301,20 @@ def dot_general(
lhs, rngs=flax_util.make_rng('dropout')
)

return res + lhs @ lora_a @ lora_b * (rule.alpha / rule.rank)
delta = _compute_lora_delta(
lhs,
lora_a,
lora_b,
lhs_ca,
lhs_ba,
contract_shape,
batch_shape,
remain_shape,
rule.rank,
precision=precision,
)

return res + delta * (rule.alpha / rule.rank)

def einsum(
self,
Expand Down
59 changes: 59 additions & 0 deletions tests/_src/flax_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,34 @@ def __call__(self, x):
self.assertIsInstance(variables["lora_a"], nnx.LoRAParam)
self.assertIsInstance(variables["lora_b"], nnx.LoRAParam)

def test_update_sharding(self):
  """update_sharding handles plain tuples and PartitionSpec inputs alike."""
  # Plain tuple specs: output stays a tuple.
  self.assertEqual(
      flax_util.update_sharding(("a", "b"), transpose=[1, 0]), ("b", "a")
  )
  self.assertEqual(
      flax_util.update_sharding(("a", "b"), split=[0]), ("a", None, "b")
  )
  self.assertEqual(
      flax_util.update_sharding(("a", None, "b"), merge=[0]), ("a", "b")
  )

  # PartitionSpec inputs: output round-trips back to a PartitionSpec.
  pspec = jax.sharding.PartitionSpec("a", "b")
  for kwargs, expected in (
      ({"transpose": [1, 0]}, jax.sharding.PartitionSpec("b", "a")),
      ({"split": [0]}, jax.sharding.PartitionSpec("a", None, "b")),
  ):
    got = flax_util.update_sharding(pspec, **kwargs)
    self.assertIsInstance(got, jax.sharding.PartitionSpec)
    self.assertEqual(got, expected)

  got = flax_util.update_sharding(
      jax.sharding.PartitionSpec("a", None, "b"), merge=[0]
  )
  self.assertIsInstance(got, jax.sharding.PartitionSpec)
  self.assertEqual(got, jax.sharding.PartitionSpec("a", "b"))

def test_unbox(self):
mesh = jax.make_mesh(
(1, 1),
Expand Down Expand Up @@ -194,6 +222,37 @@ def test_update_boxed(self):
self.assertIsInstance(updated, nnx.Param)
self.assertEqual(updated.sharding_names, ("b", "a", None))

# Test nn.Partitioned with jax.sharding.PartitionSpec
boxed_pspec = nn.Partitioned(
jnp.ones((4, 8)), names=jax.sharding.PartitionSpec("a", "b")
)
updated_pspec = flax_util.update_boxed(
boxed_pspec, value=jnp.ones((2, 2, 8)), split=[0]
)
self.assertIsInstance(updated_pspec, nn.Partitioned)
self.assertEqual(updated_pspec.value.shape, (2, 2, 8))
self.assertEqual(
updated_pspec.names, jax.sharding.PartitionSpec("a", None, "b")
)

# Test nnx.Param with jax.sharding.PartitionSpec
boxed_nnx_pspec = nnx.Param(
jnp.ones((2, 2, 8)),
out_sharding=jax.sharding.PartitionSpec("a", None, "b"),
)
updated_nnx_pspec = flax_util.update_boxed(
boxed_nnx_pspec, transpose=[2, 0, None]
)
self.assertIsInstance(updated_nnx_pspec, nnx.Param)
metadata = updated_nnx_pspec.get_metadata()
sharding_key = (
"out_sharding" if "out_sharding" in metadata else "sharding_names"
)
self.assertIsInstance(metadata[sharding_key], jax.sharding.PartitionSpec)
self.assertEqual(
metadata[sharding_key], jax.sharding.PartitionSpec("b", "a", None)
)

def test_make_rng_linen(self):
class MyModule(nn.Module):

Expand Down
Loading