Skip to content

Commit dbce715

Browse files
Qwix Developers and copybara-github
authored and committed
Consolidated qwix improvements for LoRA and sharding support.
This CL is for MaxText to provide support for N-dimensional kernels and batch dimensions in LoRA dot_general, enabling LoRA to be applied to a wider range of model architectures. It also updates update_sharding and update_boxed to support jax.sharding.PartitionSpec, ensuring compatibility with modern JAX sharding APIs and aligning qwix with MaxText sharding requirements. Comprehensive unit tests are included to verify these enhancements. PiperOrigin-RevId: 892161072
1 parent 56471fc commit dbce715

File tree

4 files changed

+297
-15
lines changed

4 files changed

+297
-15
lines changed

qwix/_src/flax_util.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,13 @@ def fn(x):
308308

309309

310310
def update_sharding(
311-
spec: Sequence[Any],
311+
spec: Sequence[Any] | jax.sharding.PartitionSpec,
312312
*,
313313
shape: Sequence[int] | None = None,
314314
split: Collection[int] | None = None,
315315
merge: Collection[int] | None = None,
316316
transpose: Sequence[int | None] | None = None,
317-
) -> tuple[Any, ...]:
317+
) -> tuple[Any, ...] | jax.sharding.PartitionSpec:
318318
"""Derives the partition spec from an existing spec.
319319
320320
Args:
@@ -330,6 +330,8 @@ def update_sharding(
330330
The updated partition spec.
331331
"""
332332
assert bool(split) + bool(merge) + bool(transpose) <= 1
333+
is_pspec = isinstance(spec, jax.sharding.PartitionSpec)
334+
333335
if split:
334336
spec = [(a, None) if i in split else (a,) for i, a in enumerate(spec)]
335337
spec = sum(spec, ()) # flatten the list of tuples.
@@ -344,6 +346,9 @@ def update_sharding(
344346
# For scales: remove sharding for dimensions of size 1.
345347
spec = tuple(None if d == 1 else a for a, d in zip(spec, shape))
346348

349+
if is_pspec:
350+
return jax.sharding.PartitionSpec(*spec)
351+
347352
return spec
348353

349354

@@ -380,7 +385,7 @@ def update_boxed(
380385
shape = boxed.unbox().shape
381386
for possible_field in ('names', 'mesh_axes', 'axes_types'):
382387
axes = getattr(boxed, possible_field, None)
383-
if isinstance(axes, (list, tuple)):
388+
if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
384389
axes = update_sharding(
385390
axes, shape=shape, split=split, merge=merge, transpose=transpose
386391
)
@@ -396,7 +401,8 @@ def update_boxed(
396401
else:
397402
sharding_key = 'sharding_names'
398403
axes = metadata.get(sharding_key, None)
399-
if isinstance(axes, (list, tuple)):
404+
405+
if isinstance(axes, (list, tuple, jax.sharding.PartitionSpec)):
400406
axes = update_sharding(
401407
axes, shape=shape, split=split, merge=merge, transpose=transpose
402408
)

qwix/_src/providers/lora.py

Lines changed: 113 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Low-Rank Adapation (LoRA) support."""
1515
import dataclasses
16+
import math
1617
import string
1718
from typing import Any, Callable, Collection, Sequence
1819
import warnings
@@ -131,6 +132,79 @@ def _parse_einsum_str_for_lora(
131132
)
132133

133134

135+
def _create_lora_layer_shapes(
136+
rhs_ca: Sequence[int],
137+
rhs_ba: Sequence[int],
138+
rhs_ra: Sequence[int],
139+
contract_shape: typing.Shape,
140+
batch_shape: typing.Shape,
141+
remain_shape: typing.Shape,
142+
rank: int,
143+
) -> tuple[
144+
typing.Shape, # a_shape
145+
Sequence[int | None], # a_sharding_transpose
146+
typing.Shape, # b_shape
147+
Sequence[int | None], # b_sharding_transpose
148+
]:
149+
"""Returns lora param shapes and sharding transposes for dot_general."""
150+
151+
# LoRA A: (batch, contracting, rank)
152+
a_shape = (*batch_shape, math.prod(contract_shape), rank)
153+
# LoRA B: (batch, rank, remain)
154+
b_shape = (*batch_shape, rank, math.prod(remain_shape))
155+
156+
# Inherit sharding from first dims; XLA's SPMD partitioner will
157+
# automatically infer and propagate sharding for subsequent ones.
158+
a_sharding_transpose = (*rhs_ba, rhs_ca[0] if rhs_ca else None, None)
159+
b_sharding_transpose = (*rhs_ba, None, rhs_ra[0] if rhs_ra else None)
160+
161+
return a_shape, a_sharding_transpose, b_shape, b_sharding_transpose
162+
163+
164+
def _compute_lora_delta(
165+
lhs: jax.Array,
166+
lora_a: jax.Array,
167+
lora_b: jax.Array,
168+
lhs_ca: Sequence[int],
169+
lhs_ba: Sequence[int],
170+
contract_shape: typing.Shape,
171+
batch_shape: typing.Shape,
172+
remain_shape: typing.Shape,
173+
rank: int,
174+
precision: jax.lax.PrecisionLike = None,
175+
) -> jax.Array:
176+
"""Computes the raw LoRA delta."""
177+
178+
lora_a_reshaped = jax.numpy.reshape(
179+
lora_a, (*batch_shape, *contract_shape, rank)
180+
)
181+
delta_batch_axes = (*range(len(batch_shape)),)
182+
lora_a_contract_axes = (
183+
*range(len(batch_shape), len(batch_shape) + len(contract_shape)),
184+
)
185+
delta_a = jax.lax.dot_general(
186+
lhs,
187+
lora_a_reshaped,
188+
((lhs_ca, lora_a_contract_axes), (lhs_ba, delta_batch_axes)),
189+
precision=precision,
190+
)
191+
192+
# delta = delta_a @ lora_b
193+
lora_b_reshaped = jax.numpy.reshape(
194+
lora_b, (*batch_shape, rank, *remain_shape)
195+
)
196+
delta = jax.lax.dot_general(
197+
delta_a,
198+
lora_b_reshaped,
199+
(
200+
((delta_a.ndim - 1,), (len(batch_shape),)),
201+
(delta_batch_axes, delta_batch_axes),
202+
),
203+
precision=precision,
204+
)
205+
return delta
206+
207+
134208
class LoraProvider(ptq.PtqProvider):
135209
"""Provider for (Q)LoRA.
136210
@@ -190,20 +264,35 @@ def dot_general(
190264
if weight_name is None: # rhs is not a weight.
191265
return res
192266

193-
# We only support ...a,ab->...b for now.
194-
assert (
195-
len(rhs.shape) == 2
196-
and tuple(dimension_numbers[0][1]) == (0,)
197-
and not dimension_numbers[1][1]
198-
), f'Unsupported: {rhs.shape=} {dimension_numbers=}'
267+
(lhs_ca, rhs_ca), (lhs_ba, rhs_ba) = dimension_numbers
268+
269+
rhs_ra = (
270+
*(i for i in range(rhs.ndim) if i not in rhs_ca and i not in rhs_ba),
271+
)
272+
273+
contract_shape = (*(rhs.shape[i] for i in rhs_ca),)
274+
batch_shape = (*(rhs.shape[i] for i in rhs_ba),)
275+
remain_shape = (*(rhs.shape[i] for i in rhs_ra),)
276+
277+
a_shape, a_sharding_transpose, b_shape, b_sharding_transpose = (
278+
_create_lora_layer_shapes(
279+
rhs_ca,
280+
rhs_ba,
281+
rhs_ra,
282+
contract_shape,
283+
batch_shape,
284+
remain_shape,
285+
rule.rank,
286+
)
287+
)
199288

200289
lora_a, lora_b = _get_or_create_lora_params(
201290
name=weight_name,
202291
rule=rule,
203-
a_shape=(rhs.shape[0], rule.rank),
204-
b_shape=(rule.rank, rhs.shape[1]),
205-
a_sharding_transpose=(0, None),
206-
b_sharding_transpose=(None, 1),
292+
a_shape=a_shape,
293+
b_shape=b_shape,
294+
a_sharding_transpose=a_sharding_transpose,
295+
b_sharding_transpose=b_sharding_transpose,
207296
)
208297

209298
if rule.dropout > 0:
@@ -212,7 +301,20 @@ def dot_general(
212301
lhs, rngs=flax_util.make_rng('dropout')
213302
)
214303

215-
return res + lhs @ lora_a @ lora_b * (rule.alpha / rule.rank)
304+
delta = _compute_lora_delta(
305+
lhs,
306+
lora_a,
307+
lora_b,
308+
lhs_ca,
309+
lhs_ba,
310+
contract_shape,
311+
batch_shape,
312+
remain_shape,
313+
rule.rank,
314+
precision=precision,
315+
)
316+
317+
return res + delta * (rule.alpha / rule.rank)
216318

217319
def einsum(
218320
self,

tests/_src/flax_util_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,34 @@ def __call__(self, x):
145145
self.assertIsInstance(variables["lora_a"], nnx.LoRAParam)
146146
self.assertIsInstance(variables["lora_b"], nnx.LoRAParam)
147147

148+
def test_update_sharding(self):
149+
# Test tuple
150+
tuple_spec = ("a", "b")
151+
updated = flax_util.update_sharding(tuple_spec, transpose=[1, 0])
152+
self.assertEqual(updated, ("b", "a"))
153+
154+
updated = flax_util.update_sharding(tuple_spec, split=[0])
155+
self.assertEqual(updated, ("a", None, "b"))
156+
157+
tuple_spec_with_none = ("a", None, "b")
158+
updated = flax_util.update_sharding(tuple_spec_with_none, merge=[0])
159+
self.assertEqual(updated, ("a", "b"))
160+
161+
# Test jax.sharding.PartitionSpec
162+
pspec = jax.sharding.PartitionSpec("a", "b")
163+
updated = flax_util.update_sharding(pspec, transpose=[1, 0])
164+
self.assertIsInstance(updated, jax.sharding.PartitionSpec)
165+
self.assertEqual(updated, jax.sharding.PartitionSpec("b", "a"))
166+
167+
updated = flax_util.update_sharding(pspec, split=[0])
168+
self.assertIsInstance(updated, jax.sharding.PartitionSpec)
169+
self.assertEqual(updated, jax.sharding.PartitionSpec("a", None, "b"))
170+
171+
pspec_with_none = jax.sharding.PartitionSpec("a", None, "b")
172+
updated = flax_util.update_sharding(pspec_with_none, merge=[0])
173+
self.assertIsInstance(updated, jax.sharding.PartitionSpec)
174+
self.assertEqual(updated, jax.sharding.PartitionSpec("a", "b"))
175+
148176
def test_unbox(self):
149177
mesh = jax.make_mesh(
150178
(1, 1),
@@ -194,6 +222,37 @@ def test_update_boxed(self):
194222
self.assertIsInstance(updated, nnx.Param)
195223
self.assertEqual(updated.sharding_names, ("b", "a", None))
196224

225+
# Test nn.Partitioned with jax.sharding.PartitionSpec
226+
boxed_pspec = nn.Partitioned(
227+
jnp.ones((4, 8)), names=jax.sharding.PartitionSpec("a", "b")
228+
)
229+
updated_pspec = flax_util.update_boxed(
230+
boxed_pspec, value=jnp.ones((2, 2, 8)), split=[0]
231+
)
232+
self.assertIsInstance(updated_pspec, nn.Partitioned)
233+
self.assertEqual(updated_pspec.value.shape, (2, 2, 8))
234+
self.assertEqual(
235+
updated_pspec.names, jax.sharding.PartitionSpec("a", None, "b")
236+
)
237+
238+
# Test nnx.Param with jax.sharding.PartitionSpec
239+
boxed_nnx_pspec = nnx.Param(
240+
jnp.ones((2, 2, 8)),
241+
out_sharding=jax.sharding.PartitionSpec("a", None, "b"),
242+
)
243+
updated_nnx_pspec = flax_util.update_boxed(
244+
boxed_nnx_pspec, transpose=[2, 0, None]
245+
)
246+
self.assertIsInstance(updated_nnx_pspec, nnx.Param)
247+
metadata = updated_nnx_pspec.get_metadata()
248+
sharding_key = (
249+
"out_sharding" if "out_sharding" in metadata else "sharding_names"
250+
)
251+
self.assertIsInstance(metadata[sharding_key], jax.sharding.PartitionSpec)
252+
self.assertEqual(
253+
metadata[sharding_key], jax.sharding.PartitionSpec("b", "a", None)
254+
)
255+
197256
def test_make_rng_linen(self):
198257
class MyModule(nn.Module):
199258

0 commit comments

Comments (0)