From deda031f381b03da3b6cb4245e6b46cde2b7fd5c Mon Sep 17 00:00:00 2001 From: Geoffrey Schau Date: Tue, 12 May 2026 14:05:44 -0700 Subject: [PATCH 1/8] feat: biased sparse inference sampling Replaces the placeholder SparseGraphDataset with the prototype from feat/sparse-inference-sampling (5b53700), ported to the training-improvements SkeletonGraph API (get_branchings vs. branching_nodes). SparseGraphDataset now inherits from DenseGraphDataset and overrides only _generate_batch_nodes / estimate_iterations. Node selection is gated by sparse_sampling.compute_interesting_nodes, which picks the union of: - nodes within branch_radius graph-distance of any branching node - nodes within proximity_radius euclidean distance of a node in another connected component Everything else in merge_inference.py is unchanged from training-improvements. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../merge_proofreading/merge_inference.py | 100 ++++++++++++------ .../merge_proofreading/sparse_sampling.py | 62 +++++++++++ 2 files changed, 127 insertions(+), 35 deletions(-) create mode 100644 src/neuron_proofreader/merge_proofreading/sparse_sampling.py diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 26addf5f..88c0fb8c 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -24,6 +24,9 @@ from neuron_proofreader.machine_learning.point_cloud_models import ( subgraph_to_point_cloud, ) +from neuron_proofreader.merge_proofreading.sparse_sampling import ( + compute_interesting_nodes, +) from neuron_proofreader.utils import ( geometry_util, img_util, @@ -556,7 +559,13 @@ def estimate_iterations(self): return int(length / self.step_size) -class SparseGraphDataset(GraphDataset): +class SparseGraphDataset(DenseGraphDataset): + """ + Inference dataset that samples only nodes near branch points or near + other axons, 
skipping long isolated axon segments. Inherits the + image-prefetch and per-node feature extraction from DenseGraphDataset + and overrides only the node-selection logic. + """ def __init__( self, @@ -564,72 +573,93 @@ def __init__( img_path, patch_shape, batch_size=16, + branch_radius=25.0, + brightness_clip=300, is_multimodal=False, min_search_size=0, prefetch=128, + proximity_radius=15.0, segmentation_path=None, + step_size=10, subgraph_radius=100, - use_new_mask=False + use_new_mask=False, ): - # Call parent class super().__init__( graph, img_path, patch_shape, batch_size=batch_size, + brightness_clip=brightness_clip, is_multimodal=is_multimodal, min_search_size=min_search_size, prefetch=prefetch, segmentation_path=segmentation_path, + step_size=step_size, subgraph_radius=subgraph_radius, - use_new_mask=use_new_mask + use_new_mask=use_new_mask, ) # Instance attributes - self.search_mode = "branching_points" - - def _generate_batches_from_component(self): - pass + self.search_mode = "biased_sparse" + self.branch_radius = branch_radius + self.proximity_radius = proximity_radius + self._interesting_nodes = compute_interesting_nodes( + graph, + branch_radius=branch_radius, + proximity_radius=proximity_radius, + ) def _generate_batch_nodes(self, root): + """ + Iterates the connected component containing "root" via DFS at the + same "step_size" cadence as DenseGraphDataset, but emits only nodes + pre-selected by "compute_interesting_nodes". Long, isolated axon + segments contribute zero samples. 
+ """ nodes = list() - patch_centers = list() for i, j in nx.dfs_edges(self.graph, source=root): - # Check if starting new batch self.distance_traversed += self.graph.dist(i, j) - if len(patch_centers) == 0 and self.graph.degree[i] > 2: - root = i - nodes.append(i) - patch_centers.append(self.graph.get_voxel(i)) - # Check whether to yield batch - is_node_far = self.graph.dist(root, j) > 256 - is_batch_full = len(patch_centers) == self.batch_size - if is_node_far or is_batch_full: - # Yield batch metadata - patch_centers = np.array(patch_centers, dtype=int) - nodes = np.array(nodes, dtype=int) - yield nodes, patch_centers + # Open a batch on the first interesting node we reach + if len(nodes) == 0: + if i in self._interesting_nodes and self.is_node_valid(i): + root = i + last_node = i + nodes.append(i) + else: + continue - # Reset batch metadata + # Yield when batch is full or has spread too far for prefetch + is_node_far = self.graph.dist(root, j) > 512 + is_batch_full = len(nodes) == self.batch_size + if is_node_far or is_batch_full: + yield np.array(nodes, dtype=int) nodes = list() - patch_centers = list() - # Visit j - if self.graph.degree[j] > 2: + # Visit j: same step_size cadence as Dense, gated on the + # interesting set so only branch- and proximity-region nodes + # are emitted. + is_next = self.graph.dist(last_node, j) >= self.step_size - 2 + is_branching = self.graph.degree[j] >= 3 + if ( + (is_next or is_branching) + and j in self._interesting_nodes + and self.is_node_valid(j) + ): + last_node = j nodes.append(j) - patch_centers.append(self.graph.get_voxel(j)) - if len(patch_centers) == 1: + if len(nodes) == 1: root = j + if nodes: + yield np.array(nodes, dtype=int) + # --- Helpers --- def estimate_iterations(self): """ - Estimates the number of iterations required to search graph. - - Returns - ------- - int - Estimated number of iterations required to search graph. 
+ Estimates the number of iterations required to search graph: the + size of the interesting set, divided by the step_size cadence used + within those regions. """ - return len(self.graph.get_branchings()) + step = max(1, self.step_size) + return max(1, len(self._interesting_nodes) // step) diff --git a/src/neuron_proofreader/merge_proofreading/sparse_sampling.py b/src/neuron_proofreader/merge_proofreading/sparse_sampling.py new file mode 100644 index 00000000..7a35198c --- /dev/null +++ b/src/neuron_proofreader/merge_proofreading/sparse_sampling.py @@ -0,0 +1,62 @@ +""" +Helpers for biased sparse inference-time sampling: pre-selects skeleton +nodes that sit near a branch point or near another axon, so that merge +detection can skip long, isolated axon segments. + +Kept in its own module (only numpy + a networkx graph object) so it can be +exercised by unit tests without pulling in the full merge_inference image +and ML stack. +""" + +import numpy as np + + +def compute_interesting_nodes(graph, branch_radius=25.0, proximity_radius=15.0): + """ + Selects nodes worth running merge detection on: the union of (a) nodes + within "branch_radius" graph-distance of a branching node, and (b) nodes + within "proximity_radius" Euclidean distance of a node belonging to a + different connected component. + + Parameters + ---------- + graph : SkeletonGraph + Skeleton graph with "node_xyz", "node_component_id", and "kdtree" + populated. Must expose "get_branchings()", "neighbors(i)", + "dist(i, j)", and "set_kdtree()". + branch_radius : float, optional + Graph-distance window (microns) around branching nodes. Default 25. + proximity_radius : float, optional + Euclidean threshold (microns) for nodes treated as "near another + axon". Default 15 (matches "geometry_util.is_double_merge"). + + Returns + ------- + Set[int] + Node IDs to sample at inference time. 
+ """ + # Branch-region nodes: bounded DFS from every branching node + branch_set = set(graph.get_branchings()) + queue = [(i, 0.0) for i in branch_set] + while queue: + i, dist_i = queue.pop() + for j in graph.neighbors(i): + dist_j = dist_i + graph.dist(i, j) + if j not in branch_set and dist_j < branch_radius: + branch_set.add(j) + queue.append((j, dist_j)) + + # Inter-component proximity nodes + if graph.kdtree is None: + graph.set_kdtree() + proximity_set = set() + for i in graph.nodes: + idxs = np.array( + graph.kdtree.query_ball_point(graph.node_xyz[i], proximity_radius) + ) + if idxs.size and np.any( + graph.node_component_id[idxs] != graph.node_component_id[i] + ): + proximity_set.add(i) + + return branch_set | proximity_set From 1742e475bd7b33cf7051d3383e3a72441112221a Mon Sep 17 00:00:00 2001 From: Geoffrey Schau Date: Tue, 12 May 2026 21:20:47 -0700 Subject: [PATCH 2/8] init sparse sampling figure --- assets/overview/sparse_inference_overview.tex | 197 ++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 assets/overview/sparse_inference_overview.tex diff --git a/assets/overview/sparse_inference_overview.tex b/assets/overview/sparse_inference_overview.tex new file mode 100644 index 00000000..b08b0048 --- /dev/null +++ b/assets/overview/sparse_inference_overview.tex @@ -0,0 +1,197 @@ +% Figure: biased sparse inference-time sampling. +% +% Geometry (in figure units; treat as proxies for microns): +% axon A trunk: (0.5, 5) -- (10, 5) +% axon A branch: (3.5, 5) -- (5.7, 7) (branch point at (3.5, 5)) +% axon B trunk: flat at y=2 from x=0.5..5, bezier bump up to (7, 4.4), +% bezier back down to (9, 2), flat to (10, 2) +% r_branch = 1.4 (drawn around the branch point) +% r_prox = 1.3 (drawn around the A trunk node nearest to B's peak) +% +% A node is colored ("interesting") iff it falls within r_branch graph- +% distance of a branch node OR within r_prox Euclidean distance of a node +% in a different connected component. 
Otherwise it is grey ("skipped"). +% +% Membership has been precomputed by hand; see comments next to each +% \foreach loop for the math. + +\documentclass[border=10pt,tikz]{standalone} +\usepackage{tikz} +\usepackage{amsmath} +\usepackage{xcolor} +\usetikzlibrary{positioning, calc, arrows.meta, backgrounds, shapes.geometric} + +% --- Palette --- +\definecolor{compA}{HTML}{C0392B} % red -- axon A +\definecolor{compB}{HTML}{2471A3} % blue -- axon B +\definecolor{branchR}{HTML}{D68910} % amber -- r_branch +\definecolor{proxR}{HTML}{6C3483} % purple -- r_prox +\definecolor{idle}{HTML}{B0B7BD} % grey -- skipped + +\begin{document} +\begin{tikzpicture}[ + >=Stealth, + font=\small, + skelA/.style ={line width=1.3pt, color=compA, line cap=round}, + skelB/.style ={line width=1.3pt, color=compB, line cap=round}, + rBranch/.style ={dashed, line width=0.9pt, color=branchR}, + rProx/.style ={dashed, line width=0.9pt, color=proxR}, + intA/.style ={circle, fill=compA, draw=compA!50!black, + line width=0.2pt, inner sep=1.6pt}, + intB/.style ={circle, fill=compB, draw=compB!50!black, + line width=0.2pt, inner sep=1.6pt}, + skipped/.style ={circle, fill=idle, inner sep=0.9pt}, + branchMark/.style={star, star points=5, star point ratio=0.55, + fill=compA, draw=compA!30!black, line width=0.4pt, + inner sep=2.4pt}, +] + +% ==================================================================== +% Background shading of the two radius regions (drawn before everything) +% ==================================================================== +\begin{scope}[on background layer] + \fill[branchR!10] (3.5, 5) circle [radius=1.4]; + \fill[proxR!10] (7.0, 5) circle [radius=1.3]; +\end{scope} + +% ==================================================================== +% Axon A skeleton: trunk + branch +% ==================================================================== +\draw[skelA] (0.5, 5) -- (10.0, 5); +\draw[skelA] (3.5, 5) -- (5.7, 7.0); + +% 
==================================================================== +% Axon B skeleton: flat, bezier bump up to (7, 4.4), bezier back down +% ==================================================================== +\draw[skelB] + (0.5, 2.0) -- (5.0, 2.0) + .. controls (5.7, 2.0) and (6.3, 4.4) .. (7.0, 4.4) + .. controls (7.7, 4.4) and (8.3, 2.0) .. (9.0, 2.0) + -- (10.0, 2.0); + +% ==================================================================== +% Radii outlines +% ==================================================================== +\draw[rBranch] (3.5, 5) circle [radius=1.4]; +\draw[rProx] (7.0, 5) circle [radius=1.3]; + +% ==================================================================== +% Radius annotations +% ==================================================================== +% r_branch: arrow pointing upper-left from branch point to circle edge +\draw[branchR, line width=0.7pt, ->] + (3.5, 5) -- ({3.5 + 1.4*cos(120)}, {5 + 1.4*sin(120)}); +\node[color=branchR, anchor=south, font=\footnotesize] at (2.65, 6.30) + {$r_{\mathrm{branch}}$}; + +% r_prox: arrow pointing lower-left from A node to circle edge +\draw[proxR, line width=0.7pt, ->] + (7.0, 5) -- ({7.0 + 1.3*cos(-110)}, {5 + 1.3*sin(-110)}); +\node[color=proxR, anchor=north, font=\footnotesize] at (6.55, 3.62) + {$r_{\mathrm{prox}}$}; + +% Inter-axon distance illustration: thin segment between the closest +% pair of nodes from different components +\draw[proxR!70, line width=0.4pt, dotted] (7.0, 5) -- (7.0, 4.4); + +% ==================================================================== +% Branching node marker (drawn on top of the trunk) +% ==================================================================== +\node[branchMark] at (3.5, 5) {}; + +% ==================================================================== +% Axon labels on the right +% ==================================================================== +\node[color=compA, anchor=west, font=\footnotesize\itshape] at (10.1, 5) + {axon 
$A$}; +\node[color=compB, anchor=west, font=\footnotesize\itshape] at (10.1, 2) + {axon $B$}; +\node[color=compA!70!black, anchor=south west, font=\footnotesize\itshape] + at (5.75, 7.0) {branch}; + +% ==================================================================== +% Skeleton nodes +% ==================================================================== +% --- A trunk (x = 0.5..10.0 step 0.5) ------------------------------- +% Interesting iff (a) within r_branch=1.4 of (3.5, 5) along trunk: +% |x - 3.5| < 1.4 => x in {2.5, 3.0, 3.5, 4.0, 4.5} +% or (b) within r_prox=1.3 of a B node (closest B nodes computed below): +% x in {6.0, 6.5, 7.0, 7.5, 8.0} +% (x = 5.0, 5.5 sit between the two regions and are skipped) +\foreach \x in {0.5, 1.0, 1.5, 2.0, 5.0, 5.5, 8.5, 9.0, 9.5, 10.0} { + \node[skipped] at (\x, 5) {}; +} +\foreach \x in {2.5, 3.0, 3.5, 4.0, 4.5, 6.0, 6.5, 7.0, 7.5, 8.0} { + \node[intA] at (\x, 5) {}; +} + +% --- A branch (parametrized: (3.5 + 2.2t, 5 + 2.0t), length 2.974) --- +% Interesting iff t * 2.974 < r_branch = 1.4 => t < 0.471 +\foreach \t in {0.1, 0.2, 0.3, 0.4} { + \node[intA] at ({3.5 + 2.2*\t}, {5 + 2.0*\t}) {}; +} +\foreach \t in {0.5, 0.6, 0.7, 0.8, 0.9, 1.0} { + \node[skipped] at ({3.5 + 2.2*\t}, {5 + 2.0*\t}) {}; +} + +% --- B nodes along the bezier (approximate y from bezier evaluation) --- +% Interesting iff within r_prox=1.3 of any A node. +% Closest A node to B point (x, y) is (x, 5). 
+% B (6.5, 4.0): dist 1.0 -> interesting +% B (7.0, 4.4): dist 0.6 -> interesting +% B (7.5, 4.0): dist 1.0 -> interesting +% B (6.0, 3.2): dist 1.8 -> skipped +% all flat-region B nodes: -> skipped +\foreach \x in {0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0} { + \node[skipped] at (\x, 2.0) {}; +} +\node[skipped] at (5.5, 2.30) {}; +\node[skipped] at (6.0, 3.20) {}; +\node[intB] at (6.5, 4.00) {}; +\node[intB] at (7.0, 4.40) {}; +\node[intB] at (7.5, 4.00) {}; +\node[skipped] at (8.0, 3.20) {}; +\node[skipped] at (8.5, 2.30) {}; +\foreach \x in {9.0, 9.5, 10.0} { + \node[skipped] at (\x, 2.0) {}; +} + +% ==================================================================== +% Legend +% ==================================================================== +\begin{scope}[shift={(0, -1.7)}, font=\footnotesize] + \draw[black!25, line width=0.4pt, rounded corners=2pt, fill=white] + (0, -1.45) rectangle (10.7, 0.65); + + % --- Row 1: node-type swatches --- + \node[intA] at (0.30, 0.40) {}; + \node[anchor=west] at (0.50, 0.40) {sampled (axon $A$)}; + + \node[intB] at (3.20, 0.40) {}; + \node[anchor=west] at (3.40, 0.40) {sampled (axon $B$)}; + + \node[skipped] at (5.95, 0.40) {}; + \node[anchor=west] at (6.10, 0.40) {skipped}; + + \node[branchMark, inner sep=1.6pt] at (7.85, 0.40) {}; + \node[anchor=west] at (8.05, 0.40) {branching node}; + + % --- Row 2: r_branch --- + \draw[rBranch] (0.10, 0.00) -- (0.50, 0.00); + \node[anchor=west] at (0.55, 0.00) + {$r_{\mathrm{branch}}$: graph-distance buffer around branch points}; + + % --- Row 3: r_prox --- + \draw[rProx] (0.10, -0.40) -- (0.50, -0.40); + \node[anchor=west] at (0.55, -0.40) + {$r_{\mathrm{prox}}$: Euclidean buffer toward nodes in other components}; + + % --- Row 4: rule summary --- + \node[anchor=west, font=\footnotesize\itshape, color=black!70] + at (0.10, -1.00) + {sparse sampler emits a node iff (within $r_{\mathrm{branch}}$ of a branch) $\vee$ + (within $r_{\mathrm{prox}}$ of a node in another component)}; 
+\end{scope} + +\end{tikzpicture} +\end{document} From 56c1a800090596153b1c219be99e6675e31036fc Mon Sep 17 00:00:00 2001 From: Geoffrey Schau Date: Thu, 14 May 2026 13:28:42 -0700 Subject: [PATCH 3/8] rm tikz gen --- assets/overview/sparse_inference_overview.tex | 197 ------------------ 1 file changed, 197 deletions(-) delete mode 100644 assets/overview/sparse_inference_overview.tex diff --git a/assets/overview/sparse_inference_overview.tex b/assets/overview/sparse_inference_overview.tex deleted file mode 100644 index b08b0048..00000000 --- a/assets/overview/sparse_inference_overview.tex +++ /dev/null @@ -1,197 +0,0 @@ -% Figure: biased sparse inference-time sampling. -% -% Geometry (in figure units; treat as proxies for microns): -% axon A trunk: (0.5, 5) -- (10, 5) -% axon A branch: (3.5, 5) -- (5.7, 7) (branch point at (3.5, 5)) -% axon B trunk: flat at y=2 from x=0.5..5, bezier bump up to (7, 4.4), -% bezier back down to (9, 2), flat to (10, 2) -% r_branch = 1.4 (drawn around the branch point) -% r_prox = 1.3 (drawn around the A trunk node nearest to B's peak) -% -% A node is colored ("interesting") iff it falls within r_branch graph- -% distance of a branch node OR within r_prox Euclidean distance of a node -% in a different connected component. Otherwise it is grey ("skipped"). -% -% Membership has been precomputed by hand; see comments next to each -% \foreach loop for the math. 
- -\documentclass[border=10pt,tikz]{standalone} -\usepackage{tikz} -\usepackage{amsmath} -\usepackage{xcolor} -\usetikzlibrary{positioning, calc, arrows.meta, backgrounds, shapes.geometric} - -% --- Palette --- -\definecolor{compA}{HTML}{C0392B} % red -- axon A -\definecolor{compB}{HTML}{2471A3} % blue -- axon B -\definecolor{branchR}{HTML}{D68910} % amber -- r_branch -\definecolor{proxR}{HTML}{6C3483} % purple -- r_prox -\definecolor{idle}{HTML}{B0B7BD} % grey -- skipped - -\begin{document} -\begin{tikzpicture}[ - >=Stealth, - font=\small, - skelA/.style ={line width=1.3pt, color=compA, line cap=round}, - skelB/.style ={line width=1.3pt, color=compB, line cap=round}, - rBranch/.style ={dashed, line width=0.9pt, color=branchR}, - rProx/.style ={dashed, line width=0.9pt, color=proxR}, - intA/.style ={circle, fill=compA, draw=compA!50!black, - line width=0.2pt, inner sep=1.6pt}, - intB/.style ={circle, fill=compB, draw=compB!50!black, - line width=0.2pt, inner sep=1.6pt}, - skipped/.style ={circle, fill=idle, inner sep=0.9pt}, - branchMark/.style={star, star points=5, star point ratio=0.55, - fill=compA, draw=compA!30!black, line width=0.4pt, - inner sep=2.4pt}, -] - -% ==================================================================== -% Background shading of the two radius regions (drawn before everything) -% ==================================================================== -\begin{scope}[on background layer] - \fill[branchR!10] (3.5, 5) circle [radius=1.4]; - \fill[proxR!10] (7.0, 5) circle [radius=1.3]; -\end{scope} - -% ==================================================================== -% Axon A skeleton: trunk + branch -% ==================================================================== -\draw[skelA] (0.5, 5) -- (10.0, 5); -\draw[skelA] (3.5, 5) -- (5.7, 7.0); - -% ==================================================================== -% Axon B skeleton: flat, bezier bump up to (7, 4.4), bezier back down -% 
==================================================================== -\draw[skelB] - (0.5, 2.0) -- (5.0, 2.0) - .. controls (5.7, 2.0) and (6.3, 4.4) .. (7.0, 4.4) - .. controls (7.7, 4.4) and (8.3, 2.0) .. (9.0, 2.0) - -- (10.0, 2.0); - -% ==================================================================== -% Radii outlines -% ==================================================================== -\draw[rBranch] (3.5, 5) circle [radius=1.4]; -\draw[rProx] (7.0, 5) circle [radius=1.3]; - -% ==================================================================== -% Radius annotations -% ==================================================================== -% r_branch: arrow pointing upper-left from branch point to circle edge -\draw[branchR, line width=0.7pt, ->] - (3.5, 5) -- ({3.5 + 1.4*cos(120)}, {5 + 1.4*sin(120)}); -\node[color=branchR, anchor=south, font=\footnotesize] at (2.65, 6.30) - {$r_{\mathrm{branch}}$}; - -% r_prox: arrow pointing lower-left from A node to circle edge -\draw[proxR, line width=0.7pt, ->] - (7.0, 5) -- ({7.0 + 1.3*cos(-110)}, {5 + 1.3*sin(-110)}); -\node[color=proxR, anchor=north, font=\footnotesize] at (6.55, 3.62) - {$r_{\mathrm{prox}}$}; - -% Inter-axon distance illustration: thin segment between the closest -% pair of nodes from different components -\draw[proxR!70, line width=0.4pt, dotted] (7.0, 5) -- (7.0, 4.4); - -% ==================================================================== -% Branching node marker (drawn on top of the trunk) -% ==================================================================== -\node[branchMark] at (3.5, 5) {}; - -% ==================================================================== -% Axon labels on the right -% ==================================================================== -\node[color=compA, anchor=west, font=\footnotesize\itshape] at (10.1, 5) - {axon $A$}; -\node[color=compB, anchor=west, font=\footnotesize\itshape] at (10.1, 2) - {axon $B$}; -\node[color=compA!70!black, anchor=south west, 
font=\footnotesize\itshape] - at (5.75, 7.0) {branch}; - -% ==================================================================== -% Skeleton nodes -% ==================================================================== -% --- A trunk (x = 0.5..10.0 step 0.5) ------------------------------- -% Interesting iff (a) within r_branch=1.4 of (3.5, 5) along trunk: -% |x - 3.5| < 1.4 => x in {2.5, 3.0, 3.5, 4.0, 4.5} -% or (b) within r_prox=1.3 of a B node (closest B nodes computed below): -% x in {6.0, 6.5, 7.0, 7.5, 8.0} -% (x = 5.0, 5.5 sit between the two regions and are skipped) -\foreach \x in {0.5, 1.0, 1.5, 2.0, 5.0, 5.5, 8.5, 9.0, 9.5, 10.0} { - \node[skipped] at (\x, 5) {}; -} -\foreach \x in {2.5, 3.0, 3.5, 4.0, 4.5, 6.0, 6.5, 7.0, 7.5, 8.0} { - \node[intA] at (\x, 5) {}; -} - -% --- A branch (parametrized: (3.5 + 2.2t, 5 + 2.0t), length 2.974) --- -% Interesting iff t * 2.974 < r_branch = 1.4 => t < 0.471 -\foreach \t in {0.1, 0.2, 0.3, 0.4} { - \node[intA] at ({3.5 + 2.2*\t}, {5 + 2.0*\t}) {}; -} -\foreach \t in {0.5, 0.6, 0.7, 0.8, 0.9, 1.0} { - \node[skipped] at ({3.5 + 2.2*\t}, {5 + 2.0*\t}) {}; -} - -% --- B nodes along the bezier (approximate y from bezier evaluation) --- -% Interesting iff within r_prox=1.3 of any A node. -% Closest A node to B point (x, y) is (x, 5). 
-% B (6.5, 4.0): dist 1.0 -> interesting -% B (7.0, 4.4): dist 0.6 -> interesting -% B (7.5, 4.0): dist 1.0 -> interesting -% B (6.0, 3.2): dist 1.8 -> skipped -% all flat-region B nodes: -> skipped -\foreach \x in {0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0} { - \node[skipped] at (\x, 2.0) {}; -} -\node[skipped] at (5.5, 2.30) {}; -\node[skipped] at (6.0, 3.20) {}; -\node[intB] at (6.5, 4.00) {}; -\node[intB] at (7.0, 4.40) {}; -\node[intB] at (7.5, 4.00) {}; -\node[skipped] at (8.0, 3.20) {}; -\node[skipped] at (8.5, 2.30) {}; -\foreach \x in {9.0, 9.5, 10.0} { - \node[skipped] at (\x, 2.0) {}; -} - -% ==================================================================== -% Legend -% ==================================================================== -\begin{scope}[shift={(0, -1.7)}, font=\footnotesize] - \draw[black!25, line width=0.4pt, rounded corners=2pt, fill=white] - (0, -1.45) rectangle (10.7, 0.65); - - % --- Row 1: node-type swatches --- - \node[intA] at (0.30, 0.40) {}; - \node[anchor=west] at (0.50, 0.40) {sampled (axon $A$)}; - - \node[intB] at (3.20, 0.40) {}; - \node[anchor=west] at (3.40, 0.40) {sampled (axon $B$)}; - - \node[skipped] at (5.95, 0.40) {}; - \node[anchor=west] at (6.10, 0.40) {skipped}; - - \node[branchMark, inner sep=1.6pt] at (7.85, 0.40) {}; - \node[anchor=west] at (8.05, 0.40) {branching node}; - - % --- Row 2: r_branch --- - \draw[rBranch] (0.10, 0.00) -- (0.50, 0.00); - \node[anchor=west] at (0.55, 0.00) - {$r_{\mathrm{branch}}$: graph-distance buffer around branch points}; - - % --- Row 3: r_prox --- - \draw[rProx] (0.10, -0.40) -- (0.50, -0.40); - \node[anchor=west] at (0.55, -0.40) - {$r_{\mathrm{prox}}$: Euclidean buffer toward nodes in other components}; - - % --- Row 4: rule summary --- - \node[anchor=west, font=\footnotesize\itshape, color=black!70] - at (0.10, -1.00) - {sparse sampler emits a node iff (within $r_{\mathrm{branch}}$ of a branch) $\vee$ - (within $r_{\mathrm{prox}}$ of a node in another component)}; 
-\end{scope} - -\end{tikzpicture} -\end{document} From 75c2d737610c460a1eed5d2307bb41ff53c23e13 Mon Sep 17 00:00:00 2001 From: Geoffrey Schau Date: Mon, 18 May 2026 19:35:10 -0700 Subject: [PATCH 4/8] updated focal objective and early stopping criteria --- .../machine_learning/train.py | 128 +++++++++++++++--- .../machine_learning/vision_models.py | 11 +- .../merge_proofreading/merge_datasets.py | 73 ++++++++-- 3 files changed, 185 insertions(+), 27 deletions(-) diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index 911698b8..159ff4b1 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -29,10 +29,32 @@ import torch import torch.distributed as dist import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim from neuron_proofreader.utils import img_util, ml_util, util + +class FocalLoss(nn.Module): + """Binary focal loss for imbalanced classification. + + Downweights easy examples (high confidence, correct) so training + concentrates on the hard cases that drive false positives and false + negatives. Positives are weighted by alpha, negatives by 1 - alpha; + gamma sharpens the focus (gamma=0 reduces to alpha-weighted BCE). 
+ """ + + def __init__(self, alpha=0.25, gamma=2.0): + super().__init__() + self.alpha = alpha + self.gamma = gamma + + def forward(self, logits, targets): + bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none") + pt = torch.exp(-bce)  # probability the model assigns to the true class + alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) + return (alpha_t * (1 - pt) ** self.gamma * bce).mean() + logger = logging.getLogger(__name__) _LOG_EVERY = 100 # batches between progress log lines @@ -83,6 +105,8 @@ def __init__( warmup_epochs=5, scheduler_type="cosine", pos_weight=None, + focal_gamma=None, + focal_alpha=0.25, save_val_logits=False, save_mistake_mips=False, on_best_model_saved=None, @@ -124,6 +148,7 @@ def __init__( # Instance attributes self.best_f1 = 0 self.best_val_loss = float("inf") + self.best_f1_at_95recall = 0.0 self.device = device self.log_dir = log_dir self.max_epochs = max_epochs @@ -138,11 +163,15 @@ def __init__( self.save_mistake_mips = save_mistake_mips self.on_best_model_saved = on_best_model_saved - if pos_weight is None: + if focal_gamma is not None: + self.criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma) + print(f"Loss: FocalLoss(alpha={focal_alpha}, gamma={focal_gamma})") + elif pos_weight is None: self.criterion = nn.BCEWithLogitsLoss() else: pos_weight_tensor = torch.tensor([pos_weight], device=device) self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor) + print(f"Loss: BCEWithLogitsLoss(pos_weight={pos_weight})") self.model = model.to(device) self.optimizer = optim.AdamW( self._build_param_groups(self.model, lr, head_lr), @@ -256,16 +285,29 @@ def run(self, train_dataloader, val_dataloader): # Train-Validate train_stats = self.train_step(train_dataloader, epoch) val_stats = self.validate_step(val_dataloader, epoch) + new_best_loss = val_stats["loss"] < self.best_val_loss if new_best_loss: self.best_val_loss = val_stats["loss"] + + f1_95 = val_stats.get("f1_at_95recall", 0.0) + new_best_f1_95 = f1_95 > self.best_f1_at_95recall + if new_best_f1_95: + 
self.best_f1_at_95recall = f1_95 + + # Checkpoint: use F1@95recall once the model achieves it; fall back + # to val loss before that threshold is first reached. + if new_best_f1_95: + self.save_model(epoch, tag="best_f1_at_95recall") + if self.save_val_logits: + self._save_val_logits( + val_dataloader, self._last_val_y, self._last_val_hat_y, epoch + ) + elif new_best_loss and self.best_f1_at_95recall == 0.0: self.save_model(epoch, tag="best_loss") if self.save_val_logits: self._save_val_logits( - val_dataloader, - self._last_val_y, - self._last_val_hat_y, - epoch, + val_dataloader, self._last_val_y, self._last_val_hat_y, epoch ) # Log learning rate @@ -274,7 +316,12 @@ def run(self, train_dataloader, val_dataloader): self.writer.add_scalar("lr", current_lr, epoch) # Report results - print(f"\nEpoch {epoch}: " + ("New Best!" if new_best_loss else " ")) + is_new_best = new_best_f1_95 or (new_best_loss and self.best_f1_at_95recall == 0.0) + criterion_label = ( + f"F1@95R={f1_95:.4f}" if self.best_f1_at_95recall > 0.0 + else f"loss={val_stats['loss']:.4f}" + ) + print(f"\nEpoch {epoch}: " + (f"New Best! ({criterion_label})" if is_new_best else "")) self.report_stats(train_stats, is_train=True) self.report_stats(val_stats, is_train=False) @@ -289,8 +336,8 @@ def run(self, train_dataloader, val_dataloader): if new != old: print(f" LR reduced: group {i} {old:.2e} -> {new:.2e}") - # Early stopping check - if new_best_loss: + # Early stopping: track whichever criterion is active + if is_new_best: self.epochs_without_improvement = 0 else: self.epochs_without_improvement += 1 @@ -485,6 +532,27 @@ def forward_pass(self, x, y): return hat_y, loss # --- Helpers --- + @staticmethod + def _f1_at_recall_target(y, hat_y_logits, recall_target=0.95): + """Return the best F1 achievable at >= recall_target recall. + + Sweeps 200 probability thresholds and returns the maximum F1 among + those where recall >= recall_target. 
Returns 0.0 if the model never + achieves the target recall at any threshold. + """ + y_arr = np.array(y, dtype=int) + probs = 0.5 * (1.0 + np.tanh(0.5 * np.asarray(hat_y_logits, dtype=float)))  # overflow-free sigmoid + thresholds = np.unique(np.percentile(probs, np.linspace(0, 100, 200))) + best_f1 = 0.0 + for t in thresholds: + preds = (probs >= t).astype(int) + r = recall_score(y_arr, preds, zero_division=0) + if r >= recall_target: + p = precision_score(y_arr, preds, zero_division=0) + f1 = 2 * p * r / max(p + r, 1e-8) + best_f1 = max(best_f1, f1) + return best_f1 + @staticmethod def compute_stats(y, hat_y): """ @@ -515,8 +583,10 @@ def compute_stats(y, hat_y): avg_recall = recall_score(y, hat_y, zero_division=np.nan) avg_f1 = 2 * avg_prec * avg_recall / max((avg_prec + avg_recall), 1e-8) avg_acc = accuracy_score(y, hat_y) + f1_at_95recall = Trainer._f1_at_recall_target(y, hat_y_arr) stats = { "f1": avg_f1, + "f1_at_95recall": f1_at_95recall, "precision": avg_prec, "recall": avg_recall, "accuracy": avg_acc, @@ -800,6 +870,8 @@ def __init__( warmup_epochs=5, scheduler_type="cosine", pos_weight=None, + focal_gamma=None, + focal_alpha=0.25, save_val_logits=False, save_mistake_mips=False ): @@ -876,6 +948,7 @@ def __init__( # Now initialize parent class attributes without creating directories self.best_f1 = 0 self.best_val_loss = float("inf") + self.best_f1_at_95recall = 0.0 self.device = device self.log_dir = log_dir self.max_epochs = max_epochs @@ -889,11 +962,17 @@ def __init__( self.save_val_logits = save_val_logits self.save_mistake_mips = save_mistake_mips - if pos_weight is None: + if focal_gamma is not None: + self.criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma) + if self.rank == 0: + print(f"Loss: FocalLoss(alpha={focal_alpha}, gamma={focal_gamma})") + elif pos_weight is None: self.criterion = nn.BCEWithLogitsLoss() else: pos_weight_tensor = torch.tensor([pos_weight], device=device) self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor) + if self.rank == 0: + print(f"Loss: 
BCEWithLogitsLoss(pos_weight={pos_weight})") self.model = model.to(device) self.scaler = torch.cuda.amp.GradScaler(enabled=True) @@ -1092,15 +1171,32 @@ def run(self, train_dataloader, val_dataloader): new_best_loss = val_stats["loss"] < self.best_val_loss if new_best_loss: self.best_val_loss = val_stats["loss"] + + f1_95 = val_stats.get("f1_at_95recall", 0.0) + new_best_f1_95 = f1_95 > self.best_f1_at_95recall + if new_best_f1_95: + self.best_f1_at_95recall = f1_95 + + # Checkpoint: F1@95recall once achieved, val loss as fallback + if new_best_f1_95: + self.save_model(epoch, tag="best_f1_at_95recall") + if self.save_val_logits: + self._save_val_logits( + val_dataloader, self._last_val_y, self._last_val_hat_y, epoch + ) + elif new_best_loss and self.best_f1_at_95recall == 0.0: self.save_model(epoch, tag="best_loss") if self.save_val_logits: self._save_val_logits( - val_dataloader, - self._last_val_y, - self._last_val_hat_y, - epoch, + val_dataloader, self._last_val_y, self._last_val_hat_y, epoch ) - print(f"\nEpoch {epoch}: ", "New Best!" if new_best_loss else "") + + is_new_best = new_best_f1_95 or (new_best_loss and self.best_f1_at_95recall == 0.0) + criterion_label = ( + f"F1@95R={f1_95:.4f}" if self.best_f1_at_95recall > 0.0 + else f"loss={val_stats['loss']:.4f}" + ) + print(f"\nEpoch {epoch}: " + (f"New Best! 
({criterion_label})" if is_new_best else "")) self.report_stats(train_stats, is_train=True) self.report_stats(val_stats, is_train=False) @@ -1117,8 +1213,8 @@ def run(self, train_dataloader, val_dataloader): print(f" LR reduced: group {i} {old:.2e} -> {new:.2e}") if rank == 0: - # Early stopping check - if new_best_loss: + # Early stopping: track whichever criterion is active + if is_new_best: self.epochs_without_improvement = 0 else: self.epochs_without_improvement += 1 diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py index a20a91b5..cd5393db 100644 --- a/src/neuron_proofreader/machine_learning/vision_models.py +++ b/src/neuron_proofreader/machine_learning/vision_models.py @@ -177,7 +177,12 @@ def __init__( self.encoder_dim = self.model.encoder_dim self.n_prefix_tokens = 1 + encoder.n_register_tokens self.grid_size = tuple(int(g) for g in encoder.grid_size) - self.pool_power = pool_power + # Learnable pooling power γ (log-parameterized so it stays positive). + # exp(log(pool_power)) = pool_power at init, so skeleton tokens start + # pool_power× heavier than segment tokens and background stays zero. + self.pool_log_power = nn.Parameter( + torch.tensor(float(pool_power)).log() + ) # Dual-stream classifier: [CLS, mask-pooled] → 1 self.classifier = nn.Sequential( @@ -203,8 +208,8 @@ def forward(self, x): weights = F.adaptive_max_pool3d(mask, self.grid_size) weights = weights.reshape(weights.shape[0], -1) # (B, n_patches) - # Power-scale: skeleton=1.0, segment=0.25, bg=0.0 (with power=2) - weights = weights ** self.pool_power + # Power-scale with learned γ; exp keeps γ strictly positive. 
+ weights = weights ** self.pool_log_power.exp() # Normalize weights to sum to 1 weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8) diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index 19c957a5..eed011c4 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -452,12 +452,21 @@ def get_random_negative_site(self): outcome = random.random() while True: # Sample node - if outcome < 0.4: + if outcome < 0.1: + # Near-merge hard negative: node 25–80 µm from a known merge + # site but confirmed not a merge itself. These are the hardest + # false positives the detector encounters at deployment. + node = self._sample_near_merge_negative(brain_id) + if node is None: + outcome = random.random() + continue + subgraph = self.graphs[brain_id].get_rooted_subgraph( + node, self.subgraph_radius + ) + return brain_id, subgraph, 0 + elif outcome < 0.44: # Any node node = util.sample_once(list(self.graphs[brain_id].nodes)) - #elif outcome < 0.5: - # # Node close to soma - # node = self.sample_node_nearby_soma(brain_id) elif outcome < 0.8: # Branching node branching_nodes = self.graphs[brain_id].get_branchings() @@ -491,6 +500,31 @@ def get_random_negative_site(self): if not self.is_nearby_merge_site(brain_id, node): return brain_id, subgraph, 0 + def _sample_near_merge_negative(self, brain_id, min_dist=25.0, max_dist=80.0, max_tries=20): + """ + Sample a fragment node that is near (but not at) a merge site. + + Nodes between min_dist and max_dist µm from the nearest GT merge site + are the hardest false positives at deployment. Adding ~10% of training + negatives from this zone makes the model discriminate them better. + + Returns None after max_tries failed attempts (e.g. sparse brain). 
+ """ + if brain_id not in self.merge_site_kdtrees: + return None + kdtree = self.merge_site_kdtrees[brain_id] + graph = self.graphs[brain_id] + nodes = list(graph.nodes) + if not nodes: + return None + for _ in range(max_tries): + node = util.sample_once(nodes) + xyz = graph.node_xyz[node] + dist, _ = kdtree.query(xyz) + if min_dist <= dist <= max_dist: + return node + return None + def get_img_patch(self, brain_id, center): """ Extracts and normalizes a 3D image patch from the specified whole- @@ -537,8 +571,8 @@ def get_segment_mask(self, brain_id, center, subgraph): else: segment_mask = np.zeros(self.patch_shape) - # Annotate fragment - center = subgraph.get_voxel(0) + # Annotate fragment — use the passed center so translation augmentation + # shifts the skeleton overlay to match the shifted image read window. offset = img_util.get_offset(center, self.patch_shape) for node1, node2 in subgraph.edges: # Get local voxel coordinates @@ -663,7 +697,7 @@ class MergeSiteTrainDataset(MergeSiteDataset): A class for storing and retrieving training examples. """ - def __init__(self, base_dataset=None, idxs=None, negative_bias=0): + def __init__(self, base_dataset=None, idxs=None, negative_bias=0, max_translation=20): """ Instantiates a MergeSiteTrainDataset object. @@ -675,6 +709,10 @@ def __init__(self, base_dataset=None, idxs=None, negative_bias=0): Indices of examples to be kept in train dataset. negative_bias : float, optional Specifies percentage of additional negative examples to add. + max_translation : int, optional + Maximum voxel shift applied to the patch read center along each + axis during training. Shifts the merge site off-center to improve + robustness to misaligned inputs. Default is 20 voxels. 
""" # Create sub-dataset subset_dataset = base_dataset.subset(self.__class__, idxs) @@ -682,6 +720,7 @@ def __init__(self, base_dataset=None, idxs=None, negative_bias=0): # Instance attributes self.negative_bias = negative_bias + self.max_translation = max_translation self.transform = ImageTransforms() # --- Getters --- @@ -704,7 +743,25 @@ def __getitem__(self, idx): label : int 1 if the example is positive and 0 otherwise. """ - patches, subgraph, label = super().__getitem__(idx) + brain_id, subgraph, label = self.get_site(idx) + voxel = subgraph.get_voxel(0) + + # Random translation: shift the read window so the site appears + # off-center, training the model to be robust to misaligned inputs. + if self.max_translation > 0: + delta = np.random.randint(-self.max_translation, self.max_translation + 1, 3) + voxel = tuple(int(v + d) for v, d in zip(voxel, delta)) + + img_patch = self.get_img_patch(brain_id, voxel) + segment_mask = self.get_segment_mask(brain_id, voxel, subgraph) + + try: + patches = np.stack([img_patch, segment_mask], axis=0) + except ValueError: + img_patch = img_util.pad_to_shape(img_patch, self.patch_shape) + patches = np.stack([img_patch, segment_mask], axis=0) + + patches[0] = (patches[0] - patches[0].mean()) / (patches[0].std() + 1e-8) patches = self.transform(patches) return patches, subgraph, label From 2144de446b76a7346497cb7efb064a8409e7538e Mon Sep 17 00:00:00 2001 From: Geoffrey Schau Date: Tue, 19 May 2026 17:33:30 -0700 Subject: [PATCH 5/8] updated for sparse inference augmentation --- .../merge_proofreading/merge_dataloading.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/neuron_proofreader/merge_proofreading/merge_dataloading.py b/src/neuron_proofreader/merge_proofreading/merge_dataloading.py index 9ef69906..8873f9fa 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_dataloading.py +++ b/src/neuron_proofreader/merge_proofreading/merge_dataloading.py @@ -174,6 +174,19 @@ def load_images( 
segmentation_prefixes = util.read_json(segmentation_prefixes_path) brain_ids = get_brain_ids(dataset.merge_sites_df, is_test) + # Filter to brains present in both prefix maps. A missing key would + # KeyError inside the dict comprehension below before any futures are + # submitted, bypassing the try/except that guards future.result(). + loadable = [] + for bid in brain_ids: + if bid not in img_prefixes: + logger.warning("No image prefix for brain %s — skipping image load", bid) + elif bid not in segmentation_prefixes: + logger.warning("No seg prefix for brain %s — skipping image load", bid) + else: + loadable.append(bid) + brain_ids = loadable + _log_ram("before images") logger.info("Loading images (%d brains, %d workers)", len(brain_ids), max_workers) completed = 0 From 2ba1b8671edcd6e39dbe282e5a4622865b9cd9db Mon Sep 17 00:00:00 2001 From: Geoffrey Schau Date: Fri, 22 May 2026 13:48:39 -0700 Subject: [PATCH 6/8] updated sparsity --- src/neuron_proofreader/merge_proofreading/merge_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 88c0fb8c..cbe47e25 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -62,7 +62,7 @@ def __init__( # --- Core routines def search_graph(self): # Initialize progress bar - pbar = tqdm(total=self.dataset.estimate_iterations()) + pbar = tqdm(total=self.dataset.estimate_iterations(), miniters=1000, mininterval=0) t0 = time() # Iterate over dataset From 7b3324100f98e30a99457d574346c020c45d77fa Mon Sep 17 00:00:00 2001 From: Geoffrey Schau Date: Tue, 26 May 2026 10:05:58 -0700 Subject: [PATCH 7/8] updated logging timer --- src/neuron_proofreader/machine_learning/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/neuron_proofreader/machine_learning/train.py 
b/src/neuron_proofreader/machine_learning/train.py index 159ff4b1..826186a5 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -52,7 +52,8 @@ def __init__(self, alpha=0.25, gamma=2.0): def forward(self, logits, targets): bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none") pt = torch.exp(-bce) - focal_weight = self.alpha * (1 - pt) ** self.gamma + alpha_t = targets * self.alpha + (1 - targets) * (1 - self.alpha) + focal_weight = alpha_t * (1 - pt) ** self.gamma return (focal_weight * bce).mean() logger = logging.getLogger(__name__) From 92352d58ecf68b6a494cdfcc2a67b8f248647c00 Mon Sep 17 00:00:00 2001 From: Geoffrey Schau Date: Wed, 27 May 2026 10:15:11 -0700 Subject: [PATCH 8/8] train metrics update --- src/neuron_proofreader/machine_learning/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index 826186a5..173c5cd6 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -57,7 +57,7 @@ def forward(self, logits, targets): return (focal_weight * bce).mean() logger = logging.getLogger(__name__) -_LOG_EVERY = 100 # batches between progress log lines +_LOG_EVERY = 1000 # batches between progress log lines class Trainer: