diff --git a/mace/data/neighborhood.py b/mace/data/neighborhood.py index 03728969d..7e4210c28 100644 --- a/mace/data/neighborhood.py +++ b/mace/data/neighborhood.py @@ -24,16 +24,15 @@ def get_neighborhood( pbc_y = pbc[1] pbc_z = pbc[2] identity = np.identity(3, dtype=float) - max_positions = np.max(np.absolute(positions)) + 1 - # Extend cell in non-periodic directions - # For models with more than 5 layers, the multiplicative constant needs to be increased. - # temp_cell = np.copy(cell) + max_positions = np.max(positions, axis=0) - np.min(positions, axis=0) + padding = 1 # 1 angstrom padding + if not pbc_x: - cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + cell[0, :] = (max_positions[0] + cutoff + padding) * identity[0, :] if not pbc_y: - cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + cell[1, :] = (max_positions[1] + cutoff + padding) * identity[1, :] if not pbc_z: - cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + cell[2, :] = (max_positions[2] + cutoff + padding) * identity[2, :] sender, receiver, unit_shifts = neighbour_list( quantities="ijS",