From 5bf6979bf73249485d1853e6d32a3af414e3b24c Mon Sep 17 00:00:00 2001 From: Nicholas Gao Date: Wed, 15 Apr 2026 10:02:52 +0000 Subject: [PATCH] improve traceabaility --- orb_models/common/atoms/batch/graph_batch.py | 10 ++++++---- orb_models/common/models/segment_ops.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/orb_models/common/atoms/batch/graph_batch.py b/orb_models/common/atoms/batch/graph_batch.py index 9790838..735d6f8 100644 --- a/orb_models/common/atoms/batch/graph_batch.py +++ b/orb_models/common/atoms/batch/graph_batch.py @@ -79,7 +79,7 @@ def __init__( # repeat_interleave, e.g. tensor[node_batch_index] instead of tensor.repeat_interleave(n_node) self.node_batch_index = torch.arange( self.n_node.shape[0], dtype=torch.int64, device=self.n_node.device - ).repeat_interleave(self.n_node) + ).repeat_interleave(self.n_node, output_size=self.positions.shape[0]) def split(self: _T, clone=True) -> list["AtomGraphs"]: """Splits batched AtomGraphs into constituent system AtomGraphs. @@ -218,11 +218,12 @@ def _get_per_node_graph_indices(self): TODO: we could cache these indices. """ positions = self.node_features["positions"] - graph_indices = torch.zeros(len(positions), device=positions.device, dtype=torch.int) + n_atoms = positions.shape[0] + graph_indices = torch.zeros(n_atoms, device=positions.device, dtype=torch.int) cumsums = torch.cumsum(self.n_node, dim=0) graph_indices[:] = torch.searchsorted( cumsums, - torch.arange(len(positions), device=positions.device), + torch.arange(n_atoms, device=positions.device), right=True, ) return graph_indices # (natoms,) @@ -234,9 +235,10 @@ def _get_per_edge_graph_indices(self): """ graph_indices = torch.zeros_like(self.senders) # (nedges,) cumsums = torch.cumsum(self.n_edge, dim=0) # (ngraphs,) + n_edges = self.senders.shape[0] graph_indices[:] = torch.searchsorted( cumsums, - torch.arange(len(self.senders), device=self.senders.device), + torch.arange(n_edges, device=self.senders.device), right=True, ) return graph_indices # (nedges,) diff --git a/orb_models/common/models/segment_ops.py b/orb_models/common/models/segment_ops.py index 1aade66..28a0e41 100644 --- a/orb_models/common/models/segment_ops.py +++ b/orb_models/common/models/segment_ops.py @@ -31,7 +31,7 @@ def aggregate_nodes( # assert n_node.sum() == tensor.shape[0] device = tensor.device - count = len(n_node) + count = n_node.shape[0] if deterministic: import os @@ -40,7 +40,7 @@ def aggregate_nodes( https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility""" os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" torch.use_deterministic_algorithms(True) - segments = torch.arange(count, device=device).repeat_interleave(n_node) + segments = torch.arange(count, device=device).repeat_interleave(n_node, output_size=tensor.shape[0]) if reduction == "sum": return scatter_sum(tensor, segments, dim=0, dim_size=count) elif reduction == "mean":