Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,5 @@ known-third-party = ["ezmsg", "ezmsg.baseproc", "ezmsg.sigproc"]

[tool.uv.sources]
# Uncomment to use development version of ezmsg from git
#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
ezmsg-sigproc = { git = "https://github.com/ezmsg-org/ezmsg-sigproc.git", branch = "car-bank-aware" }
38 changes: 37 additions & 1 deletion src/ezmsg/learn/process/ssr.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
AffineTransformTransformer,
)
from ezmsg.sigproc.util.array import array_device, xp_create
from ezmsg.sigproc.util.channels import channel_clusters_from_field
from ezmsg.util.messages.axisarray import AxisArray

# ---------------------------------------------------------------------------
Expand All @@ -78,6 +79,13 @@ class SelfSupervisedRegressionSettings(ez.Settings):
channel_clusters: list[list[int]] | None = None
"""Per-cluster regression. ``None`` treats all channels as one cluster."""

cluster_by_field: str | None = None
"""Derive ``channel_clusters`` automatically from a structured field of the
channel coordinate axis (e.g. ``"bank"`` to regress within each electrode
bank). Used only when ``channel_clusters`` is None and the axis actually
carries that field; otherwise falls back to ``block_size`` (then a single
cluster). Explicit ``channel_clusters`` always takes precedence."""

block_size: int | None = None
"""If ``channel_clusters`` is ``None``, use this block size for an automatic clustering."""

Expand All @@ -94,6 +102,9 @@ class SelfSupervisedRegressionState:
cxx: object | None = None # Array API; namespace matches source data.
n_samples: int = 0
weights: object | None = None # Array API; namespace matches cxx.
resolved_clusters: list | None = None
"""Clusters derived from ``cluster_by_field`` at reset (message-dependent),
cached so ``_get_channel_clusters`` can return them without the message."""


class SelfSupervisedRegressionTransformer(
Expand All @@ -115,13 +126,34 @@ class SelfSupervisedRegressionTransformer(
def _hash_message(self, message: AxisArray) -> int:
    """Return a hash identifying the stream layout this processor's state depends on.

    The hash always covers ``message.key`` and the length of the selected
    axis (``settings.axis``, or the last dim when that is unset). When
    clusters are derived from a structured channel-axis field -- i.e.
    ``settings.channel_clusters`` is ``None`` and
    ``settings.cluster_by_field`` names a field present on the axis data --
    the raw bytes of that field are folded in too.

    Args:
        message: Incoming message; must expose ``key``, ``dims``, ``data``,
            ``axes`` and ``get_axis_idx``.

    Returns:
        Hash of the collected components.
    """
    axis = self.settings.axis or message.dims[-1]
    axis_idx = message.get_axis_idx(axis)
    components: tuple = (message.key, message.data.shape[axis_idx])

    # Field-derived clusters go stale if the field's values change while the
    # key and channel count stay the same, so the field must contribute to
    # the hash. This is scoped to the cluster_by_field path so the common
    # (no-field) case pays nothing; the cost is O(channels) per message.
    field = self.settings.cluster_by_field
    if self.settings.channel_clusters is None and field is not None:
        ax = message.axes.get(axis)
        data = getattr(ax, "data", None)
        # ``dtype.names`` is None for non-structured arrays, so plain axes
        # (or axes without a data array) fall through untouched.
        names = getattr(getattr(data, "dtype", None), "names", None)
        if data is not None and names and field in names:
            components += (data[field].tobytes(),)
    return hash(components)

def _reset_state(self, message: AxisArray) -> None:
axis = self.settings.axis or message.dims[-1]
axis_idx = message.get_axis_idx(axis)
n_channels = message.data.shape[axis_idx]

# Derive clusters from a structured channel-axis field (e.g. "bank") when
# requested and no explicit clusters were given. Cached so the
# message-less _get_channel_clusters can return them later.
if self.settings.channel_clusters is None and self.settings.cluster_by_field is not None:
self._state.resolved_clusters = channel_clusters_from_field(message, axis, self.settings.cluster_by_field)
else:
self._state.resolved_clusters = None

self._validate_clusters(n_channels)
self._state.cxx = None
self._state.n_samples = 0
Expand All @@ -141,7 +173,11 @@ def _reset_state(self, message: AxisArray) -> None:
# -- cluster validation --------------------------------------------------

def _get_channel_clusters(self, n_channels: int) -> list[list[int]] | None:
# Precedence: explicit channel_clusters > cluster_by_field-derived
# (cached at reset) > block_size > None (single cluster).
clusters = self.settings.channel_clusters
if clusters is None:
clusters = getattr(self._state, "resolved_clusters", None)
if clusters is None and self.settings.block_size is not None:
clusters = [
list(range(i, min(i + self.settings.block_size, n_channels)))
Expand Down
84 changes: 84 additions & 0 deletions tests/unit/test_ssr.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,90 @@ def test_channel_clusters(self):
assert np.any(off_diag != 0), "Expected non-zero within-cluster weights"


def _banked_axisarray(data: np.ndarray, banks: list[str], key: str = "test") -> AxisArray:
    """Wrap ``data`` in a ``(time, ch)`` AxisArray with a structured channel axis.

    The ``ch`` axis is a CoordinateAxis whose records carry ``label``,
    ``bank`` and ``elec`` fields (like ezmsg-blackrock ChannelMap emits).
    ``banks`` supplies one bank value per channel; labels are ``ch0, ch1, ...``
    and electrode numbers count up from 1.
    """
    n_ch = len(banks)
    channel_dtype = np.dtype([("label", "U16"), ("bank", "U1"), ("elec", "i4")])

    channels = np.zeros(n_ch, dtype=channel_dtype)
    channels["label"] = [f"ch{idx}" for idx in range(n_ch)]
    channels["elec"] = list(range(1, n_ch + 1))
    channels["bank"] = banks

    time_axis = AxisArray.TimeAxis(fs=100.0, offset=0.0)
    channel_axis = AxisArray.CoordinateAxis(data=channels, dims=["ch"])
    return AxisArray(
        data=data,
        dims=["time", "ch"],
        axes={"time": time_axis, "ch": channel_axis},
        key=key,
    )


class TestClusterByField:
    """Tests for ``cluster_by_field``: derivation, precedence, fallback and re-derivation."""

    def test_bank_field_matches_explicit_clusters(self):
        """Clustering on the ``bank`` field gives the same weights as spelling
        out the equivalent ``channel_clusters`` by hand."""
        bank_labels = ["A", "A", "A", "A", "B", "B", "B", "B"]
        samples = _random_data(n_ch=len(bank_labels), rng=np.random.default_rng(7))

        by_field = LRRTransformer(LRRSettings(axis="ch", cluster_by_field="bank"))
        by_field.partial_fit(_banked_axisarray(samples, bank_labels))

        by_hand = LRRTransformer(LRRSettings(axis="ch", channel_clusters=[[0, 1, 2, 3], [4, 5, 6, 7]]))
        by_hand.partial_fit(_make_axisarray(samples))

        field_weights = by_field.state.weights
        np.testing.assert_array_equal(field_weights, by_hand.state.weights)
        # Weights linking channels of different banks must be exactly zero.
        cross_bank = field_weights[np.ix_([0, 1, 2, 3], [4, 5, 6, 7])]
        np.testing.assert_array_equal(cross_bank, 0.0)

    def test_explicit_clusters_take_precedence(self):
        """When both are given, ``channel_clusters`` overrides ``cluster_by_field``."""
        bank_labels = ["A", "A", "A", "A", "B", "B", "B", "B"]
        samples = _random_data(n_ch=len(bank_labels), rng=np.random.default_rng(8))

        # A single cluster spanning every channel should win over the bank grouping.
        settings = LRRSettings(axis="ch", channel_clusters=[list(range(8))], cluster_by_field="bank")
        proc = LRRTransformer(settings)
        proc.partial_fit(_banked_axisarray(samples, bank_labels))

        # One cluster means cross-"bank" weights are not forced to zero.
        cross_bank = proc.state.weights[np.ix_([0, 1, 2, 3], [4, 5, 6, 7])]
        assert np.any(cross_bank != 0)

    def test_missing_field_falls_back_to_block_size(self):
        """Without a structured bank field on the axis, ``block_size`` decides the clusters."""
        samples = _random_data(n_ch=8, rng=np.random.default_rng(9))

        # Plain axis (no structured bank field) with block_size=4 -> two contiguous blocks.
        with_field = LRRTransformer(LRRSettings(axis="ch", cluster_by_field="bank", block_size=4))
        with_field.partial_fit(_make_axisarray(samples))

        block_only = LRRTransformer(LRRSettings(axis="ch", block_size=4))
        block_only.partial_fit(_make_axisarray(samples))

        np.testing.assert_array_equal(with_field.state.weights, block_only.state.weights)

    def test_clusters_re_derive_when_bank_field_changes(self):
        """Same key and channel count but a different bank layout must trigger
        re-derivation of the clusters. That relies on ``_hash_message`` folding
        the bank field into the hash; otherwise the cached clusters (and their
        zeroed cross-bank weights) would go stale."""
        samples = _random_data(n_ch=4, rng=np.random.default_rng(11))
        proc = LRRTransformer(LRRSettings(axis="ch", cluster_by_field="bank"))

        # Layout 1: banks A,A,B,B -> clusters {0,1} and {2,3}.
        first_msg = _banked_axisarray(samples, ["A", "A", "B", "B"], key="x")
        proc.partial_fit(first_msg)
        assert proc.state.resolved_clusters == [[0, 1], [2, 3]]
        np.testing.assert_array_equal(proc.state.weights[np.ix_([0, 1], [2, 3])], 0.0)

        # Layout 2: identical key and width, banks A,B,A,B -> clusters {0,2} and {1,3}.
        second_msg = _banked_axisarray(samples, ["A", "B", "A", "B"], key="x")
        proc.partial_fit(second_msg)
        assert proc.state.resolved_clusters == [[0, 2], [1, 3]]
        np.testing.assert_array_equal(proc.state.weights[np.ix_([0, 2], [1, 3])], 0.0)


class TestIncrementalAccumulates:
def test_incremental_accumulates(self):
"""Two partial_fits with incremental=True should match one fit on concatenated data."""
Expand Down
Loading