1- # Copyright 2024 Google LLC
1+ # Copyright 2026 Google LLC
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1313# limitations under the License.
1414"""Core sparsity quantized training support."""
1515
16- import flax . linen as nn
16+ from flax import nnx
1717import jax
1818import jax .numpy as jnp
19- from qwix ._src import flax_util
2019from qwix ._src .core import sparsity
2120
2221
23- class SparsityModule (nn .Module ):
24- """Sparsity module for Flax."""
22+ class SparsityModule (nnx .Module ):
23+ """A stateful module for managing and applying structured sparsity in Flax NNX.
2524
26- sparsity_rule : sparsity .SparsityRule | None = None
25+ This module tracks the training step and maintains a persistent sparsity mask
26+ as `nnx.BatchStat` variables (effectively part of the model's batch stats,
27+ not trainable parameters). It can be used to apply structured N:M sparsity
28+ to activations and/or weights.
29+
30+ For weight sparsity, it periodically updates a cached boolean mask based on
31+ the `SparsityRule` and applies it to the weights. For activation sparsity,
32+ it computes and applies the mask dynamically on each call if enabled.
33+
34+ Attributes:
35+ step: An `nnx.BatchStat` tracking the number of update steps.
36+ mask: An `nnx.BatchStat` holding the persistent boolean mask for weights.
37+ sparsity_rule: The `SparsityRule` configuration.
38+ """
39+
40+ step : nnx .BatchStat
41+ mask : nnx .BatchStat
42+
  def __init__(
      self,
      shape: tuple[int, ...],
      sharding_axes: tuple[jax.sharding.PartitionSpec | None, ...],
      sparsity_rule: sparsity.SparsityRule | None = None,
  ):
    """Initializes the sparsity state variables.

    Args:
      shape: Shape of the persistent weight-sparsity mask. Presumably this
        matches (or is at least as large as) the weight it will be applied
        to — `_maybe_update_mask` slices the mask down when the shapes
        differ. TODO: confirm against callers.
      sharding_axes: Per-axis sharding annotation for the mask variable;
        forwarded to `nnx.BatchStat` as its `sharding` metadata.
      sparsity_rule: Optional `sparsity.SparsityRule` configuration driving
        mask updates.
    """
    self.sparsity_rule = sparsity_rule
    # The step counter and mask live in batch stats so they travel with the
    # model state but are never treated as trainable parameters.
    self.step = nnx.BatchStat(jnp.zeros([], jnp.int32))
    # The mask starts all-True (fully dense); it is refreshed on update steps.
    self.mask = nnx.BatchStat(
        jnp.ones(shape, jnp.bool_), sharding=sharding_axes
    )
2754
2855 def _maybe_update_mask (
2956 self ,
3057 weight : jax .Array ,
3158 step : jax .Array ,
3259 ) -> jax .Array :
3360 """Updates the sparsity mask based on the current step and config."""
34-
35- mask_val = flax_util .get_or_create_variable (
36- 'compression' , 'mask' , lambda : jnp .ones (weight .shape , jnp .bool_ )
37- )
38- # NOTE: Reshape if mask and wesight have shape mismatch.
61+ mask_val = self .mask .value
3962 if mask_val .shape != weight .shape :
40- mask_val = jnp . reshape ( mask_val , weight .shape )
63+ mask_val = mask_val [ tuple ( slice ( 0 , s ) for s in weight .shape )]
4164
4265 def mask_update (w : jax .Array , mask_val : jax .Array ) -> jax .Array : # pylint: disable=unused-argument
4366 if self .sparsity_rule is None :
@@ -65,7 +88,8 @@ def should_update_mask(step: jax.Array):
6588 % self .sparsity_rule .weight_sparsity_update_step ,
6689 0 ,
6790 )
68- return jnp .logical_and (in_update_window , is_update_step )
91+ should_update = jnp .logical_and (in_update_window , is_update_step )
92+ return should_update
6993
7094 new_mask_val = jax .lax .cond (
7195 should_update_mask (step ),
@@ -76,7 +100,6 @@ def should_update_mask(step: jax.Array):
76100 )
77101 return new_mask_val
78102
79- @nn .compact
80103 def __call__ (
81104 self , inputs : jax .Array , weight : jax .Array
82105 ) -> tuple [jax .Array , jax .Array ]:
@@ -97,29 +120,18 @@ def __call__(
97120 input_mask , inputs , jnp .zeros (inputs .shape , inputs .dtype )
98121 )
99122 if self .sparsity_rule .weight_sparsity_m != 0 :
100-
101- step = flax_util .get_or_create_variable (
102- 'compression' , 'step' , lambda : jnp .zeros ([], jnp .int32 )
123+ if self .mask is None :
124+ self .mask = nnx .BatchStat (jnp .ones (weight .shape , jnp .bool_ ))
125+
126+ # Only update if not in eval mode
127+ if not self .sparsity_rule .eval_mode :
128+ new_mask = self ._maybe_update_mask (weight = weight , step = self .step .value )
129+          # NOTE(review): removed leftover personal debug print of the step counter.
130+ self .mask .value = new_mask
131+ self .step .value = self .step .value + 1
132+
133+ weight = jnp .where (
134+ self .mask .value , weight , jnp .zeros (weight .shape , weight .dtype )
103135 )
104136
105- mask = flax_util .get_or_create_variable (
106- 'compression' , 'mask' , lambda : jnp .ones (weight .shape , jnp .bool_ )
107- )
108-
109- if not self .is_initializing () and self .has_variable (
110- 'compression' , 'mask'
111- ):
112- # Do not update mask for eval.
113- if not self .sparsity_rule .eval_mode :
114- new_mask = self ._maybe_update_mask (weight = weight , step = step .value )
115- mask .value = new_mask
116- step .value = step .value + 1
117-
118- # Unless updated mask is all ones, so we apply mask irrespective of
119- # start_step
120-
121- weight = jnp .where (
122- mask .value , weight , jnp .zeros (weight .shape , weight .dtype )
123- )
124-
125137 return inputs , weight