Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,9 @@ def _register_reduced_input_impl(self, x: Tensor) -> None:
trace = fns.sum(fns.multiply(x, x))
# NOTE: average trace?? divide by number of diagonal elements
# TODO(dlyakhov): revise this formula as possibly it is with an error; adopted from previous HAWQ implementation
Comment thread
AlexanderDokuchaev marked this conversation as resolved.
Outdated
self._container = (self._container + trace) / x.size
# We normalize the trace by the number of elements in the tensor so that larger matrices do not dominate
# the sensitivity scores
self._container += trace / x.size
Comment thread
anzr299 marked this conversation as resolved.

def _aggregate_impl(self) -> Tensor:
return self._container * 2 / self._collected_samples
Expand Down
4 changes: 2 additions & 2 deletions tests/common/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,8 @@ def test_aggregators_hash(self, aggregator_cls):

HAWQ_AGGREGATOR_REFERENCE_VALUES = [
([np.arange(10)], 57.0),
([np.arange(12).reshape((2, 6)), np.arange(24).reshape((4, 6))], 181.92361111111111),
([np.arange(8 * i).reshape((1, 8, i)) for i in range(1, 5)], 165.61627197265625),
([np.arange(12).reshape((2, 6)), np.arange(24).reshape((4, 6))], 222.33333333333331),
([np.arange(8 * i).reshape((1, 8, i)) for i in range(1, 5)], 300.3333333333333),
]
Comment thread
anzr299 marked this conversation as resolved.

@pytest.mark.parametrize("inputs,reference_output", HAWQ_AGGREGATOR_REFERENCE_VALUES)
Expand Down
Loading