Skip to content

Commit f90a1cf

Browse files
committed
Add NNX SparsityModule
1 parent c63b310 commit f90a1cf

File tree

2 files changed

+141
-37
lines changed

2 files changed

+141
-37
lines changed
Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Google LLC
1+
# Copyright 2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,31 +13,54 @@
1313
# limitations under the License.
1414
"""Core sparsity quantized training support."""
1515

16-
import flax.linen as nn
16+
from flax import nnx
1717
import jax
1818
import jax.numpy as jnp
19-
from qwix._src import flax_util
2019
from qwix._src.core import sparsity
2120

2221

23-
class SparsityModule(nn.Module):
24-
"""Sparsity module for Flax."""
22+
class SparsityModule(nnx.Module):
23+
"""A stateful module for managing and applying structured sparsity in Flax NNX.
2524
26-
sparsity_rule: sparsity.SparsityRule | None = None
25+
This module tracks the training step and maintains a persistent sparsity mask
26+
as `nnx.BatchStat` variables (effectively part of the model's batch stats,
27+
not trainable parameters). It can be used to apply structured N:M sparsity
28+
to activations and/or weights.
29+
30+
For weight sparsity, it periodically updates a cached boolean mask based on
31+
the `SparsityRule` and applies it to the weights. For activation sparsity,
32+
it computes and applies the mask dynamically on each call if enabled.
33+
34+
Attributes:
35+
step: An `nnx.BatchStat` tracking the number of update steps.
36+
mask: An `nnx.BatchStat` holding the persistent boolean mask for weights.
37+
sparsity_rule: The `SparsityRule` configuration.
38+
"""
39+
40+
step: nnx.BatchStat
41+
mask: nnx.BatchStat
42+
43+
def __init__(
    self,
    shape: tuple[int, ...],
    sharding_axes: tuple[jax.sharding.PartitionSpec | None, ...],
    sparsity_rule: sparsity.SparsityRule | None = None,
):
  """Initializes the step counter and the persistent weight mask.

  Args:
    shape: Shape of the boolean weight mask to allocate.
    sharding_axes: Sharding annotation attached to the mask variable.
    sparsity_rule: Optional sparsity configuration; ``None`` means the
      module performs no sparsification.
  """
  self.sparsity_rule = sparsity_rule
  # Both counter and mask are BatchStat so they are tracked as mutable
  # module state (not trainable parameters).
  self.step = nnx.BatchStat(jnp.zeros([], jnp.int32))
  self.mask = nnx.BatchStat(
      jnp.ones(shape, jnp.bool_), sharding=sharding_axes
  )
2754

2855
def _maybe_update_mask(
2956
self,
3057
weight: jax.Array,
3158
step: jax.Array,
3259
) -> jax.Array:
3360
"""Updates the sparsity mask based on the current step and config."""
34-
35-
mask_val = flax_util.get_or_create_variable(
36-
'compression', 'mask', lambda: jnp.ones(weight.shape, jnp.bool_)
37-
)
38-
# NOTE: Reshape if mask and wesight have shape mismatch.
61+
mask_val = self.mask.value
3962
if mask_val.shape != weight.shape:
40-
mask_val = jnp.reshape(mask_val, weight.shape)
63+
mask_val = mask_val[tuple(slice(0, s) for s in weight.shape)]
4164

4265
def mask_update(w: jax.Array, mask_val: jax.Array) -> jax.Array: # pylint: disable=unused-argument
4366
if self.sparsity_rule is None:
@@ -65,7 +88,8 @@ def should_update_mask(step: jax.Array):
6588
% self.sparsity_rule.weight_sparsity_update_step,
6689
0,
6790
)
68-
return jnp.logical_and(in_update_window, is_update_step)
91+
should_update = jnp.logical_and(in_update_window, is_update_step)
92+
return should_update
6993

7094
new_mask_val = jax.lax.cond(
7195
should_update_mask(step),
@@ -76,7 +100,6 @@ def should_update_mask(step: jax.Array):
76100
)
77101
return new_mask_val
78102

79-
@nn.compact
80103
def __call__(
81104
self, inputs: jax.Array, weight: jax.Array
82105
) -> tuple[jax.Array, jax.Array]:
@@ -97,29 +120,18 @@ def __call__(
97120
input_mask, inputs, jnp.zeros(inputs.shape, inputs.dtype)
98121
)
99122
if self.sparsity_rule.weight_sparsity_m != 0:
100-
101-
step = flax_util.get_or_create_variable(
102-
'compression', 'step', lambda: jnp.zeros([], jnp.int32)
123+
if self.mask is None:
124+
self.mask = nnx.BatchStat(jnp.ones(weight.shape, jnp.bool_))
125+
126+
# Only update if not in eval mode
127+
if not self.sparsity_rule.eval_mode:
128+
new_mask = self._maybe_update_mask(weight=weight, step=self.step.value)
129+
jax.debug.print('amanda Current Sparsity Step: {s}', s=self.step.value)
130+
self.mask.value = new_mask
131+
self.step.value = self.step.value + 1
132+
133+
weight = jnp.where(
134+
self.mask.value, weight, jnp.zeros(weight.shape, weight.dtype)
103135
)
104136

105-
mask = flax_util.get_or_create_variable(
106-
'compression', 'mask', lambda: jnp.ones(weight.shape, jnp.bool_)
107-
)
108-
109-
if not self.is_initializing() and self.has_variable(
110-
'compression', 'mask'
111-
):
112-
# Do not update mask for eval.
113-
if not self.sparsity_rule.eval_mode:
114-
new_mask = self._maybe_update_mask(weight=weight, step=step.value)
115-
mask.value = new_mask
116-
step.value = step.value + 1
117-
118-
# Unless updated mask is all ones, so we apply mask irrespective of
119-
# start_step
120-
121-
weight = jnp.where(
122-
mask.value, weight, jnp.zeros(weight.shape, weight.dtype)
123-
)
124-
125137
return inputs, weight
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for sparsity_module module, for update mask and apply sparsity."""
15+
16+
from absl.testing import absltest
17+
from absl.testing import parameterized
18+
import jax.numpy as jnp
19+
from qwix._src.core import sparsity
20+
from qwix.contrib.sparsity import sparsity_module
21+
22+
23+
class SparsityQtTest(parameterized.TestCase):
  """Tests for SparsityModule mask updates and sparsity application."""

  def test_no_sparsity(self):
    """With no SparsityRule configured, the module is a pass-through."""
    mod = sparsity_module.SparsityModule(shape=(), sharding_axes=())
    acts = jnp.arange(10, dtype=jnp.float32)
    w = jnp.arange(10, dtype=jnp.float32)
    out_acts, out_w = mod(acts, w)
    self.assertTrue(jnp.array_equal(out_acts, acts))
    self.assertTrue(jnp.array_equal(out_w, w))

  def test_activation_sparsity(self):
    """1:2 activation sparsity zeroes one value in each pair of inputs."""
    cfg = sparsity.SparsityRule(
        activation_sparsity_n=1, activation_sparsity_m=2
    )
    mod = sparsity_module.SparsityModule(
        shape=(), sharding_axes=(), sparsity_rule=cfg
    )
    acts = jnp.array([1.0, 2.0, 3.0, 4.0])
    w = jnp.array([1.0, 1.0, 1.0, 1.0])
    out_acts, out_w = mod(acts, w)
    # Activations are masked; weights are untouched by activation sparsity.
    self.assertTrue(
        jnp.array_equal(out_acts, jnp.array([0.0, 2.0, 0.0, 4.0]))
    )
    self.assertTrue(jnp.array_equal(out_w, w))

  def test_weight_sparsity(self):
    """1:2 weight sparsity masks weights and advances the step counter."""
    cfg = sparsity.SparsityRule(
        weight_sparsity_n=1,
        weight_sparsity_m=2,
        weight_sparsity_start_step=0,
        weight_sparsity_update_step=1,
    )
    mod = sparsity_module.SparsityModule(
        shape=(4,), sharding_axes=(), sparsity_rule=cfg
    )
    acts = jnp.array([1.0, 1.0, 1.0, 1.0])
    w = jnp.array([1.0, 2.0, 3.0, 4.0])

    self.assertEqual(mod.step.value, 0)

    out_acts, out_w = mod(acts, w)

    # One call advances the internal step counter exactly once.
    self.assertEqual(mod.step.value, 1)
    self.assertTrue(jnp.array_equal(out_acts, acts))
    self.assertTrue(jnp.array_equal(out_w, jnp.array([0.0, 2.0, 0.0, 4.0])))

  def test_eval_mode(self):
    """In eval mode neither the mask nor the step counter is updated."""
    cfg = sparsity.SparsityRule(
        weight_sparsity_n=1,
        weight_sparsity_m=2,
        eval_mode=True,
    )
    mod = sparsity_module.SparsityModule(
        shape=(4,), sharding_axes=(), sparsity_rule=cfg
    )
    acts = jnp.array([1.0, 1.0, 1.0, 1.0])
    w = jnp.array([1.0, 2.0, 3.0, 4.0])

    self.assertEqual(mod.step.value, 0)

    out_acts, out_w = mod(acts, w)

    # Weights pass through unmasked and the counter stays at zero.
    self.assertTrue(jnp.array_equal(out_w, w))
    self.assertTrue(jnp.array_equal(out_acts, acts))
    self.assertEqual(mod.step.value, 0)
89+
90+
91+
if __name__ == "__main__":
  # Run via absltest so absl flags are parsed before tests execute.
  absltest.main()

0 commit comments

Comments
 (0)