From 5fd476463c12544fd8ea715dd5f02c583fdcacca Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Tue, 30 Jun 2026 13:05:44 -0600 Subject: [PATCH] lrr cluster by field --- pyproject.toml | 3 +- src/ezmsg/learn/process/ssr.py | 38 ++++++++++++++- tests/unit/test_ssr.py | 84 ++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index db30ddd..366da3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,4 +75,5 @@ known-third-party = ["ezmsg", "ezmsg.baseproc", "ezmsg.sigproc"] [tool.uv.sources] # Uncomment to use development version of ezmsg from git -#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" } \ No newline at end of file +#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" } +ezmsg-sigproc = { git = "https://github.com/ezmsg-org/ezmsg-sigproc.git", branch = "car-bank-aware" } \ No newline at end of file diff --git a/src/ezmsg/learn/process/ssr.py b/src/ezmsg/learn/process/ssr.py index c719d31..10100f3 100644 --- a/src/ezmsg/learn/process/ssr.py +++ b/src/ezmsg/learn/process/ssr.py @@ -58,6 +58,7 @@ AffineTransformTransformer, ) from ezmsg.sigproc.util.array import array_device, xp_create +from ezmsg.sigproc.util.channels import channel_clusters_from_field from ezmsg.util.messages.axisarray import AxisArray # --------------------------------------------------------------------------- @@ -78,6 +79,13 @@ class SelfSupervisedRegressionSettings(ez.Settings): channel_clusters: list[list[int]] | None = None """Per-cluster regression. ``None`` treats all channels as one cluster.""" + cluster_by_field: str | None = None + """Derive ``channel_clusters`` automatically from a structured field of the + channel coordinate axis (e.g. ``"bank"`` to regress within each electrode + bank). Used only when ``channel_clusters`` is None and the axis actually + carries that field; otherwise falls back to ``block_size`` (then a single + cluster). Explicit ``channel_clusters`` always takes precedence.""" + block_size: int | None = None """If ``channel_clusters`` is ``None``, use this block size for an automatic clustering.""" @@ -94,6 +102,9 @@ class SelfSupervisedRegressionState: cxx: object | None = None # Array API; namespace matches source data. n_samples: int = 0 weights: object | None = None # Array API; namespace matches cxx. + resolved_clusters: list | None = None + """Clusters derived from ``cluster_by_field`` at reset (message-dependent), + cached so ``_get_channel_clusters`` can return them without the message.""" class SelfSupervisedRegressionTransformer( @@ -115,13 +126,34 @@ class SelfSupervisedRegressionTransformer( def _hash_message(self, message: AxisArray) -> int: axis = self.settings.axis or message.dims[-1] axis_idx = message.get_axis_idx(axis) - return hash((message.key, message.data.shape[axis_idx])) + components: tuple = (message.key, message.data.shape[axis_idx]) + # When clusters are derived from a structured channel-axis field, the + # cached clusters go stale if that field's values change even though the + # key and channel count are unchanged. Fold the field's bytes into the + # hash so the state re-derives. Scoped to the cluster_by_field path so + # the common (no-field) case pays nothing; the cost is O(channels) once + # per message, far below the regression compute. + if self.settings.channel_clusters is None and self.settings.cluster_by_field is not None: + ax = message.axes.get(axis) + data = getattr(ax, "data", None) + names = getattr(getattr(data, "dtype", None), "names", None) + if data is not None and names and self.settings.cluster_by_field in names: + components += (data[self.settings.cluster_by_field].tobytes(),) + return hash(components) def _reset_state(self, message: AxisArray) -> None: axis = self.settings.axis or message.dims[-1] axis_idx = message.get_axis_idx(axis) n_channels = message.data.shape[axis_idx] + # Derive clusters from a structured channel-axis field (e.g. "bank") when + # requested and no explicit clusters were given. Cached so the + # message-less _get_channel_clusters can return them later. + if self.settings.channel_clusters is None and self.settings.cluster_by_field is not None: + self._state.resolved_clusters = channel_clusters_from_field(message, axis, self.settings.cluster_by_field) + else: + self._state.resolved_clusters = None + self._validate_clusters(n_channels) self._state.cxx = None self._state.n_samples = 0 @@ -141,7 +173,11 @@ def _reset_state(self, message: AxisArray) -> None: # -- cluster validation -------------------------------------------------- def _get_channel_clusters(self, n_channels: int) -> list[list[int]] | None: + # Precedence: explicit channel_clusters > cluster_by_field-derived + # (cached at reset) > block_size > None (single cluster). clusters = self.settings.channel_clusters + if clusters is None: + clusters = getattr(self._state, "resolved_clusters", None) if clusters is None and self.settings.block_size is not None: clusters = [ list(range(i, min(i + self.settings.block_size, n_channels))) diff --git a/tests/unit/test_ssr.py b/tests/unit/test_ssr.py index a8a06a9..be8626e 100644 --- a/tests/unit/test_ssr.py +++ b/tests/unit/test_ssr.py @@ -126,6 +126,90 @@ def test_channel_clusters(self): assert np.any(off_diag != 0), "Expected non-zero within-cluster weights" +def _banked_axisarray(data: np.ndarray, banks: list[str], key: str = "test") -> AxisArray: + """AxisArray whose ch axis is a structured CoordinateAxis with a bank field, + like ezmsg-blackrock ChannelMap emits.""" + dt = np.dtype([("label", "U16"), ("bank", "U1"), ("elec", "i4")]) + ch = np.zeros(len(banks), dtype=dt) + ch["bank"] = banks + ch["elec"] = list(range(1, len(banks) + 1)) + ch["label"] = [f"ch{i}" for i in range(len(banks))] + return AxisArray( + data=data, + dims=["time", "ch"], + axes={"time": AxisArray.TimeAxis(fs=100.0, offset=0.0), "ch": AxisArray.CoordinateAxis(data=ch, dims=["ch"])}, + key=key, + ) + + +class TestClusterByField: + def test_bank_field_matches_explicit_clusters(self): + """cluster_by_field='bank' derives the same clusters (and weights) as + passing the equivalent explicit channel_clusters.""" + rng = np.random.default_rng(7) + banks = ["A", "A", "A", "A", "B", "B", "B", "B"] + X = _random_data(n_ch=len(banks), rng=rng) + + proc_field = LRRTransformer(LRRSettings(axis="ch", cluster_by_field="bank")) + proc_field.partial_fit(_banked_axisarray(X, banks)) + + proc_explicit = LRRTransformer(LRRSettings(axis="ch", channel_clusters=[[0, 1, 2, 3], [4, 5, 6, 7]])) + proc_explicit.partial_fit(_make_axisarray(X)) + + np.testing.assert_array_equal(proc_field.state.weights, proc_explicit.state.weights) + # And cross-bank weights are zero + W = proc_field.state.weights + np.testing.assert_array_equal(W[np.ix_([0, 1, 2, 3], [4, 5, 6, 7])], 0.0) + + def test_explicit_clusters_take_precedence(self): + """Explicit channel_clusters win over cluster_by_field.""" + rng = np.random.default_rng(8) + banks = ["A", "A", "A", "A", "B", "B", "B", "B"] + X = _random_data(n_ch=len(banks), rng=rng) + + # One all-channel cluster should override the bank grouping. + proc = LRRTransformer( + LRRSettings(axis="ch", channel_clusters=[list(range(8))], cluster_by_field="bank") + ) + proc.partial_fit(_banked_axisarray(X, banks)) + # With a single cluster, cross-"bank" weights are NOT forced to zero. + W = proc.state.weights + assert np.any(W[np.ix_([0, 1, 2, 3], [4, 5, 6, 7])] != 0) + + def test_missing_field_falls_back_to_block_size(self): + """cluster_by_field with no structured bank field falls back to block_size.""" + rng = np.random.default_rng(9) + n_ch = 8 + X = _random_data(n_ch=n_ch, rng=rng) + # Plain axis (no structured bank field) + block_size=4 -> two contiguous blocks. + proc_field = LRRTransformer(LRRSettings(axis="ch", cluster_by_field="bank", block_size=4)) + proc_field.partial_fit(_make_axisarray(X)) + + proc_block = LRRTransformer(LRRSettings(axis="ch", block_size=4)) + proc_block.partial_fit(_make_axisarray(X)) + + np.testing.assert_array_equal(proc_field.state.weights, proc_block.state.weights) + + def test_clusters_re_derive_when_bank_field_changes(self): + """A new message with the same key and channel count but different bank + assignments must re-derive clusters. This only happens because + ``_hash_message`` folds the bank field into the hash; without it the + cached clusters (and their zeroed cross-bank weights) would go stale.""" + rng = np.random.default_rng(11) + X = _random_data(n_ch=4, rng=rng) + proc = LRRTransformer(LRRSettings(axis="ch", cluster_by_field="bank")) + + # First arrangement: banks A,A,B,B -> clusters {0,1},{2,3}. + proc.partial_fit(_banked_axisarray(X, ["A", "A", "B", "B"], key="x")) + assert proc.state.resolved_clusters == [[0, 1], [2, 3]] + np.testing.assert_array_equal(proc.state.weights[np.ix_([0, 1], [2, 3])], 0.0) + + # Same key, same channel count, different banks A,B,A,B -> {0,2},{1,3}. + proc.partial_fit(_banked_axisarray(X, ["A", "B", "A", "B"], key="x")) + assert proc.state.resolved_clusters == [[0, 2], [1, 3]] + np.testing.assert_array_equal(proc.state.weights[np.ix_([0, 2], [1, 3])], 0.0) + + class TestIncrementalAccumulates: def test_incremental_accumulates(self): """Two partial_fits with incremental=True should match one fit on concatenated data."""