diff --git a/qwix/_src/providers/sparsity_qt.py b/qwix/_src/providers/sparsity_qt.py index 7b72ad0..965d4fe 100644 --- a/qwix/_src/providers/sparsity_qt.py +++ b/qwix/_src/providers/sparsity_qt.py @@ -17,26 +17,23 @@ import jax import jax.numpy as jnp from qwix._src import flax_util -from qwix._src import qconfig from qwix._src.core import sparsity class SparsityModule(nn.Module): """Sparsity module for Flax.""" - sparsity_rule: qconfig.SparsityRule | None = None + sparsity_rule: sparsity.SparsityRule | None = None def _maybe_update_mask( self, weight: jax.Array, step: jax.Array, + mask_val: jax.Array, ) -> jax.Array: """Updates the sparsity mask based on the current step and config.""" - mask_val = flax_util.get_or_create_variable( - 'compression', 'mask', lambda: jnp.ones(weight.shape, jnp.bool_) - ) - # NOTE: Reshape if mask and wesight have shape mismatch. + # NOTE: Reshape if mask and weight have shape mismatch. if mask_val.shape != weight.shape: mask_val = jnp.reshape(mask_val, weight.shape) @@ -112,7 +109,9 @@ def __call__( ): # Do not update mask for eval. if not self.sparsity_rule.eval_mode: - new_mask = self._maybe_update_mask(weight=weight, step=step.value) + new_mask = self._maybe_update_mask( + weight=weight, step=step.value, mask_val=mask.value + ) mask.value = new_mask step.value = step.value + 1 diff --git a/tests/_src/providers/sparsity_qt_test.py b/tests/_src/providers/sparsity_qt_test.py new file mode 100644 index 0000000..a24179d --- /dev/null +++ b/tests/_src/providers/sparsity_qt_test.py @@ -0,0 +1,92 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + """Tests for the sparsity_qt module: updating the mask and applying sparsity.""" + + from absl.testing import absltest + from absl.testing import parameterized + import jax + import jax.numpy as jnp + from qwix._src.core import sparsity + from qwix._src.providers import sparsity_qt + + + class SparsityQtTest(parameterized.TestCase): + + def test_no_sparsity(self): + module = sparsity_qt.SparsityModule() + inputs = jnp.arange(10, dtype=jnp.float32) + weight = jnp.arange(10, dtype=jnp.float32) + out_inputs, out_weight = module.apply({}, inputs, weight) + self.assertTrue(jnp.array_equal(out_inputs, inputs)) + self.assertTrue(jnp.array_equal(out_weight, weight)) + + def test_activation_sparsity(self): + rule = sparsity.SparsityRule( + activation_sparsity_n=1, activation_sparsity_m=2 + ) + module = sparsity_qt.SparsityModule(sparsity_rule=rule) + inputs = jnp.array([1.0, 2.0, 3.0, 4.0]) + weight = jnp.array([1.0, 1.0, 1.0, 1.0]) + out_inputs, out_weight = module.apply({}, inputs, weight) + self.assertTrue( + jnp.array_equal(out_inputs, jnp.array([0.0, 2.0, 0.0, 4.0])) + ) + self.assertTrue(jnp.array_equal(out_weight, weight)) + + def test_weight_sparsity(self): + rule = sparsity.SparsityRule( + weight_sparsity_n=1, + weight_sparsity_m=2, + weight_sparsity_start_step=0, + weight_sparsity_update_step=1, + ) + module = sparsity_qt.SparsityModule(sparsity_rule=rule) + inputs = jnp.array([1.0, 1.0, 1.0, 1.0]) + weight = jnp.array([1.0, 2.0, 3.0, 4.0]) + + variables = module.init(jax.random.key(0), inputs, weight) + self.assertEqual(variables["compression"]["step"], 0) + + 
(out_inputs, out_weight), new_vars = module.apply( + variables, inputs, weight, mutable=["compression"] + ) + self.assertEqual(new_vars["compression"]["step"], 1) + expected_weight = jnp.array([0.0, 2.0, 0.0, 4.0]) + self.assertTrue(jnp.array_equal(out_inputs, inputs)) + self.assertTrue(jnp.array_equal(out_weight, expected_weight)) + + def test_eval_mode(self): + rule = sparsity.SparsityRule( + weight_sparsity_n=1, + weight_sparsity_m=2, + eval_mode=True, + ) + module = sparsity_qt.SparsityModule(sparsity_rule=rule) + inputs = jnp.array([1.0, 1.0, 1.0, 1.0]) + weight = jnp.array([1.0, 2.0, 3.0, 4.0]) + + variables = module.init(jax.random.key(0), inputs, weight) + # Mask initialized to all ones in evaluation/init + (out_inputs, out_weight), new_vars = module.apply( + variables, inputs, weight, mutable=["compression"] + ) + # In eval_mode, mask isn't updated and step isn't incremented. + # It just applies the existing mask, which is currently all ones. + self.assertTrue(jnp.array_equal(out_weight, weight)) + self.assertTrue(jnp.array_equal(out_inputs, inputs)) + self.assertEqual(new_vars["compression"]["step"], 0) + + +if __name__ == "__main__": + absltest.main()