From 58a1792fc1d6d911e2547940e29d5051089791b1 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 26 May 2026 14:38:57 -0400 Subject: [PATCH 1/2] use first-party testing --- mace/calculators/mace_torchsim.py | 12 +- tests/test_torchsim.py | 510 ++++++++++++++++-------------- 2 files changed, 283 insertions(+), 239 deletions(-) diff --git a/mace/calculators/mace_torchsim.py b/mace/calculators/mace_torchsim.py index aab609ed2..2cd7429f9 100644 --- a/mace/calculators/mace_torchsim.py +++ b/mace/calculators/mace_torchsim.py @@ -488,7 +488,7 @@ def compute_forces(self) -> bool: def compute_stress(self) -> bool: return self._compute_stress - def forward(self, state: Any) -> Dict[str, torch.Tensor]: + def forward(self, state: Any, **_kwargs: Any) -> Dict[str, torch.Tensor]: if ts is None: raise RuntimeError( "torch-sim is required to call MaceTorchSimModel.forward" @@ -601,7 +601,15 @@ def forward(self, state: Any) -> Dict[str, torch.Tensor]: data_dict["total_spin"] = total_spin if self._is_polar: - if total_spin is not None and (total_spin == 0).all(): + if "total_charge" not in data_dict: + data_dict["total_charge"] = torch.zeros( + self.n_systems, device=self._device, dtype=self._dtype + ) + if "total_spin" not in data_dict: + data_dict["total_spin"] = torch.ones( + self.n_systems, device=self._device, dtype=self._dtype + ) + elif (data_dict["total_spin"] == 0).all(): log.warning( "PolarMACE detected with total_spin=0 for all systems. " "MACE's default total_spin is 1.0; if this is unintentional, " diff --git a/tests/test_torchsim.py b/tests/test_torchsim.py index e3c7a27ce..4b7199667 100644 --- a/tests/test_torchsim.py +++ b/tests/test_torchsim.py @@ -1,4 +1,13 @@ -"""Tests for the MACE TorchSim model interface.""" +"""Tests for the MACE TorchSim model interface. + +Uses factory functions (mirrored from the torch-sim test conftest) to build +model/calculator consistency tests and ``validate_model_outputs`` tests from +fixture names. Unlike the upstream factories we feed them sim_state *fixtures* +(built from ASE atoms) instead of the ``SIMSTATE_GENERATORS`` registry, since +the small water-only MACE model trained here doesn't know about Si/Mg/Fe. +""" + +from __future__ import annotations import os import subprocess @@ -8,7 +17,6 @@ import numpy as np import pytest import torch -from ase import build try: import torch_sim as ts @@ -24,42 +32,156 @@ except ImportError: CUET_AVAILABLE = False +try: + import graph_longrange # noqa: F401 + + GRAPH_LONGRANGE_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + GRAPH_LONGRANGE_AVAILABLE = False + pytestmark = pytest.mark.skipif( not TORCHSIM_AVAILABLE, reason="torch-sim not installed" ) +DEVICE = torch.device("cpu") +DTYPE = torch.float64 +POLAR_MODEL_NAME = "polar-1-s" + pytest_mace_dir = Path(__file__).parent.parent run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" +def make_model_calculator_consistency_test( + test_name: str, + model_fixture_name: str, + calculator_fixture_name: str, + sim_state_fixture_names: tuple[str, ...], + *, + energy_rtol: float = 1e-5, + energy_atol: float = 1e-5, + force_rtol: float = 1e-5, + force_atol: float = 1e-5, + stress_rtol: float = 1e-5, + stress_atol: float = 1e-5, +): + """Build a parametrized consistency test between a model and a calculator. + + Adapted from the torch-sim conftest factory; takes sim_state *fixtures* + rather than entries of ``SIMSTATE_GENERATORS``. + """ + + @pytest.mark.parametrize("sim_state_fixture_name", sim_state_fixture_names) + def _test(sim_state_fixture_name: str, request: pytest.FixtureRequest) -> None: + from torch_sim.testing import assert_model_calculator_consistency + + model = request.getfixturevalue(model_fixture_name) + calculator = request.getfixturevalue(calculator_fixture_name) + sim_state = request.getfixturevalue(sim_state_fixture_name) + assert_model_calculator_consistency( + model=model, + calculator=calculator, + sim_state=sim_state, + energy_rtol=energy_rtol, + energy_atol=energy_atol, + force_rtol=force_rtol, + force_atol=force_atol, + stress_rtol=stress_rtol, + stress_atol=stress_atol, + ) + + _test.__name__ = f"test_{test_name}_consistency" + return _test + + +def make_validate_model_outputs_test( + model_fixture_name: str, + *, + device: torch.device = DEVICE, + dtype: torch.dtype = DTYPE, + check_detached: bool = True, + state_modifier=None, +): + """Build a ``validate_model_outputs`` test for a model fixture.""" + + def _test(request: pytest.FixtureRequest) -> None: + from torch_sim.models.interface import validate_model_outputs + + model = request.getfixturevalue(model_fixture_name) + validate_model_outputs( + model, + device, + dtype, + check_detached=check_detached, + state_modifier=state_modifier, + ) + + _test.__name__ = f"test_{model_fixture_name}_output_validation" + return _test + + +def _ensure_periodic_cell(state, box_size: float = 20.0): + """Force PBC and replace any zero-volume cell with a cubic *box_size* box. + + Used as a ``state_modifier`` when running ``validate_model_outputs`` on + PolarMACE: the validator probes the model on a non-periodic benzene + molecule (zero cell) which makes graph_longrange's k-space code blow up. + Wrapping it in a big PBC box keeps the chemistry effectively molecular + while giving the long-range code a well-defined reciprocal lattice. + """ + state = state.clone() + state.pbc = torch.tensor([True, True, True]) + cell = state.cell.clone() + volumes = torch.linalg.det(cell).abs() + eye = torch.eye(3, device=cell.device, dtype=cell.dtype) * box_size + for i, vol in enumerate(volumes): + if vol < 1e-6: + cell[i] = eye + state.cell = cell + return state + + @pytest.fixture(scope="module") def trained_model_path(tmp_path_factory): """Train a minimal MACE model and return the path to the model file.""" from ase.atoms import Atoms + from ase.build import bulk, molecule + water = Atoms( numbers=[8, 1, 1], positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], cell=[4] * 3, pbc=[True] * 3, ) - fit_configs = [ - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), - ] - fit_configs[0].info["REF_energy"] = 1.0 - fit_configs[0].info["config_type"] = "IsolatedAtom" - fit_configs[1].info["REF_energy"] = -0.5 - fit_configs[1].info["config_type"] = "IsolatedAtom" + isolated_energies = {1: -0.5, 8: 1.0, 6: -1.3, 14: -0.7, 12: -0.2, 26: -2.1} + fit_configs = [] + for z, e0 in isolated_energies.items(): + atom = Atoms(numbers=[z], positions=[[0, 0, 0]], cell=[6] * 3) + atom.info["REF_energy"] = e0 + atom.info["config_type"] = "IsolatedAtom" + fit_configs.append(atom) np.random.seed(42) - for _ in range(10): - c = water.copy() + + def _rattled(atoms_template): + c = atoms_template.copy() c.positions += np.random.normal(0.1, size=c.positions.shape) c.info["REF_energy"] = np.random.normal(0.1) c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) c.info["REF_stress"] = np.random.normal(0.1, size=6) - fit_configs.append(c) + return c + + extra_templates = [ + bulk("Si", "diamond", a=5.43, cubic=True), + bulk("Mg", "hcp", a=3.21, c=5.21), + bulk("Fe", "bcc", a=2.87, cubic=True), + molecule("C6H6", vacuum=4.0), + ] + extra_templates[-1].pbc = [True] * 3 # MACE stress training requires PBC + templates = [water, *extra_templates] + for template in templates: + for _ in range(3): + fit_configs.append(_rattled(template)) tmp_path = tmp_path_factory.mktemp("torchsim_model_") import ase.io @@ -123,147 +245,113 @@ def water_atoms(): return atoms -def test_torchsim_basic(trained_model_path, water_atoms): - from mace.calculators.mace_torchsim import MaceTorchSimModel +@pytest.fixture(scope="module") +def water_batched_atoms(water_atoms): + rng = np.random.default_rng(seed=0) + w1, w2 = water_atoms.copy(), water_atoms.copy() + w2.positions += rng.normal(0.01, size=w2.positions.shape) + return [w1, w2] - model = MaceTorchSimModel( - model=trained_model_path, - device=torch.device("cpu"), - dtype=torch.float64, - compute_forces=True, - compute_stress=True, - ) - state = ts.io.atoms_to_state( - water_atoms, device=torch.device("cpu"), dtype=torch.float64 - ) +@pytest.fixture(scope="module") +def water_sim_state(water_atoms): + return ts.io.atoms_to_state(water_atoms, device=DEVICE, dtype=DTYPE) - results = model(state) - assert "energy" in results - assert "forces" in results - assert "stress" in results - assert results["energy"].shape == (1,) - assert results["forces"].shape[0] == len(water_atoms) - assert results["forces"].shape[1] == 3 + +@pytest.fixture(scope="module") +def water_batched_sim_state(water_batched_atoms): + return ts.io.atoms_to_state(water_batched_atoms, device=DEVICE, dtype=DTYPE) -def test_torchsim_no_stress(trained_model_path, water_atoms): +@pytest.fixture(scope="module") +def mace_model(trained_model_path): from mace.calculators.mace_torchsim import MaceTorchSimModel - model = MaceTorchSimModel( + return MaceTorchSimModel( model=trained_model_path, - device=torch.device("cpu"), - dtype=torch.float64, + device=DEVICE, + dtype=DTYPE, compute_forces=True, - compute_stress=False, - ) - - state = ts.io.atoms_to_state( - water_atoms, device=torch.device("cpu"), dtype=torch.float64 + compute_stress=True, ) - results = model(state) - assert "energy" in results - assert "forces" in results - - -def test_torchsim_matches_ase_calculator(trained_model_path, water_atoms): - from ase.stress import full_3x3_to_voigt_6_stress +@pytest.fixture(scope="module") +def mace_calculator(trained_model_path): from mace.calculators.mace import MACECalculator - from mace.calculators.mace_torchsim import MaceTorchSimModel - ase_calc = MACECalculator( - model_paths=trained_model_path, device="cpu", default_dtype="float64" + return MACECalculator( + model_paths=trained_model_path, device=DEVICE.type, default_dtype="float64" ) - atoms_ase = water_atoms.copy() - atoms_ase.calc = ase_calc - ase_energy = atoms_ase.get_potential_energy() - ase_forces = atoms_ase.get_forces() - ase_stress = atoms_ase.get_stress() - ts_model = MaceTorchSimModel( - model=trained_model_path, - device=torch.device("cpu"), - dtype=torch.float64, - ) - state = ts.io.atoms_to_state( - water_atoms, device=torch.device("cpu"), dtype=torch.float64 - ) - ts_results = ts_model(state) - np.testing.assert_allclose( - ts_results["energy"].item(), ase_energy, atol=1e-5, rtol=1e-5 - ) - np.testing.assert_allclose( - ts_results["forces"].detach().cpu().numpy(), ase_forces, atol=1e-5, rtol=1e-5 - ) - ts_stress_voigt = full_3x3_to_voigt_6_stress( - ts_results["stress"].detach().cpu().numpy().reshape(3, 3) - ) - np.testing.assert_allclose(ts_stress_voigt, ase_stress, atol=1e-5, rtol=1e-5) - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_torchsim_cueq(trained_model_path, water_atoms): - from mace.calculators.mace_torchsim import MaceTorchSimModel +test_mace_water = make_model_calculator_consistency_test( + test_name="mace_water", + model_fixture_name="mace_model", + calculator_fixture_name="mace_calculator", + sim_state_fixture_names=("water_sim_state",), +) - model = MaceTorchSimModel( - model=trained_model_path, - device=torch.device("cpu"), - dtype=torch.float64, - enable_cueq=True, - ) +test_mace_output_validation = make_validate_model_outputs_test( + model_fixture_name="mace_model", + check_detached=True, +) - state = ts.io.atoms_to_state( - water_atoms, device=torch.device("cpu"), dtype=torch.float64 - ) - results = model(state) +def test_mace_torchsim_basic(mace_model, water_sim_state, water_atoms): + """Smoke test: forward pass returns the expected keys/shapes.""" + results = mace_model(water_sim_state) assert "energy" in results assert "forces" in results + assert "stress" in results + assert results["energy"].shape == (1,) + assert results["forces"].shape == (len(water_atoms), 3) -def test_torchsim_batched(trained_model_path, water_atoms): +def test_mace_torchsim_no_stress(trained_model_path, water_sim_state): from mace.calculators.mace_torchsim import MaceTorchSimModel - w1 = water_atoms.copy() - w2 = water_atoms.copy() - w2.positions += np.random.RandomState(0).normal(0.01, size=w2.positions.shape) - model = MaceTorchSimModel( model=trained_model_path, - device=torch.device("cpu"), - dtype=torch.float64, - ) - - state = ts.io.atoms_to_state( - [w1, w2], device=torch.device("cpu"), dtype=torch.float64 + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=False, ) + results = model(water_sim_state) + assert "energy" in results + assert "forces" in results - results = model(state) - assert results["energy"].shape == (2,) - assert results["forces"].shape == (len(w1) + len(w2), 3) -# --------------------------------------------------------------------------- -# PolarMACE tests -# --------------------------------------------------------------------------- -try: - import graph_longrange # noqa: F401 +@pytest.fixture(scope="module") +def mace_cueq_model(trained_model_path): + if not CUET_AVAILABLE: + pytest.skip("cuequivariance not installed") + from mace.calculators.mace_torchsim import MaceTorchSimModel - GRAPH_LONGRANGE_AVAILABLE = True -except (ImportError, ModuleNotFoundError): - GRAPH_LONGRANGE_AVAILABLE = False + return MaceTorchSimModel( + model=trained_model_path, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + enable_cueq=True, + ) -POLAR_MODEL_NAME = "polar-1-s" -DEVICE = torch.device("cpu") -DTYPE = torch.float64 +test_mace_cueq_water = make_model_calculator_consistency_test( + test_name="mace_cueq_water", + model_fixture_name="mace_cueq_model", + calculator_fixture_name="mace_calculator", + sim_state_fixture_names=("water_sim_state",), +) -def _atoms_state(atoms): - return ts.io.atoms_to_state(atoms, device=DEVICE, dtype=DTYPE) +test_mace_cueq_output_validation = make_validate_model_outputs_test( + model_fixture_name="mace_cueq_model", + check_detached=True, +) def _skip_if_polar_unavailable(exc, model_name): @@ -289,34 +377,59 @@ def polar_raw_model(): @pytest.fixture(scope="module") -def water_state(water_atoms): - """SimState for a single water molecule (no extras).""" - return _atoms_state(water_atoms) +def polar_model(polar_raw_model): + from mace.calculators.mace_torchsim import MaceTorchSimModel + + return MaceTorchSimModel( + model=polar_raw_model, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) @pytest.fixture(scope="module") -def water_state_with_extras(water_state): - """SimState with polar-relevant extras set.""" - state = water_state.clone() - state.external_E_field = torch.tensor([[0.1, 0.0, 0.0]], dtype=DTYPE) - state.charge = torch.tensor([0.0], dtype=DTYPE) +def polar_calculator(): + if not GRAPH_LONGRANGE_AVAILABLE: + pytest.skip("graph_longrange is not installed") + from mace.calculators.foundations_models import mace_polar + + try: + return mace_polar( + model=POLAR_MODEL_NAME, device=DEVICE.type, default_dtype="float64" + ) + except (FileNotFoundError, ValueError, RuntimeError) as exc: + _skip_if_polar_unavailable(exc, POLAR_MODEL_NAME) + + +@pytest.fixture(scope="module") +def water_sim_state_polar(water_sim_state): + """Single water with spin=1.0 to match ASE's AtomicData default.""" + state = water_sim_state.clone() state.spin = torch.tensor([1.0], dtype=DTYPE) return state @pytest.fixture(scope="module") -def water_batched_state(water_atoms): - """Batched SimState with 2 water molecules.""" - rng = np.random.default_rng(seed=0) - w1, w2 = water_atoms.copy(), water_atoms.copy() - w2.positions += rng.normal(0.01, size=w2.positions.shape) - return _atoms_state([w1, w2]) +def water_batched_sim_state_polar(water_batched_sim_state): + state = water_batched_sim_state.clone() + state.spin = torch.tensor([1.0, 1.0], dtype=DTYPE) + return state + + +@pytest.fixture(scope="module") +def water_sim_state_with_extras(water_sim_state): + state = water_sim_state.clone() + state.external_E_field = torch.tensor([[0.1, 0.0, 0.0]], dtype=DTYPE) + state.charge = torch.tensor([0.0], dtype=DTYPE) + state.spin = torch.tensor([1.0], dtype=DTYPE) + return state @pytest.fixture(scope="module") -def water_batched_state_with_extras(water_batched_state): - """Batched SimState with polar-relevant extras.""" - state = water_batched_state.clone() +def water_batched_sim_state_with_extras(water_batched_sim_state): + state = water_batched_sim_state.clone() state.external_E_field = torch.tensor( [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0]], dtype=DTYPE ) @@ -325,19 +438,23 @@ def water_batched_state_with_extras(water_batched_state): return state -def test_torchsim_polar_basic(polar_raw_model, water_state): - """Forward pass with PolarMACE using defaults (no extras).""" - from mace.calculators.mace_torchsim import MaceTorchSimModel +test_polar_water = make_model_calculator_consistency_test( + test_name="polar_water", + model_fixture_name="polar_model", + calculator_fixture_name="polar_calculator", + sim_state_fixture_names=("water_sim_state_polar",), +) - model = MaceTorchSimModel( - model=polar_raw_model, - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) +test_polar_output_validation = make_validate_model_outputs_test( + model_fixture_name="polar_model", + check_detached=True, + state_modifier=_ensure_periodic_cell, # needed because default benzene is non-periodic +) - results = model(water_state) + +def test_polar_torchsim_basic(polar_model, water_sim_state): + """Forward pass with PolarMACE using defaults (no extras).""" + results = polar_model(water_sim_state) assert results["energy"].shape == (1,) assert results["forces"].shape == (3, 3) assert results["stress"].shape == (1, 3, 3) @@ -346,39 +463,20 @@ def test_torchsim_polar_basic(polar_raw_model, water_state): assert "density_coefficients" in results -def test_torchsim_polar_with_extras(polar_raw_model, water_state_with_extras): - """Forward pass with PolarMACE using explicit extras.""" - from mace.calculators.mace_torchsim import MaceTorchSimModel - - model = MaceTorchSimModel( - model=polar_raw_model, - device=DEVICE, - dtype=DTYPE, - ) - - results = model(water_state_with_extras) +def test_polar_torchsim_with_extras(polar_model, water_sim_state_with_extras): + results = polar_model(water_sim_state_with_extras) assert results["energy"].shape == (1,) assert "charges" in results assert "dipole" in results assert "density_coefficients" in results -def test_torchsim_polar_no_extras_vs_zero_extras(polar_raw_model, water_state): +def test_polar_torchsim_no_extras_vs_zero_extras(polar_model, water_sim_state): """Defaults (no extras) should match explicitly passing zeros.""" - from mace.calculators.mace_torchsim import MaceTorchSimModel - - model = MaceTorchSimModel( - model=polar_raw_model, - device=DEVICE, - dtype=DTYPE, - ) - - results_no_extras = model(water_state) - - state_zero_extras = water_state.clone() + results_no_extras = polar_model(water_sim_state) + state_zero_extras = water_sim_state.clone() state_zero_extras.external_E_field = torch.zeros(1, 3, dtype=DTYPE) - results_zero_extras = model(state_zero_extras) - + results_zero_extras = polar_model(state_zero_extras) np.testing.assert_allclose( results_no_extras["energy"].detach().cpu().numpy(), results_zero_extras["energy"].detach().cpu().numpy(), @@ -386,81 +484,19 @@ def test_torchsim_polar_no_extras_vs_zero_extras(polar_raw_model, water_state): ) -def test_torchsim_polar_batched(polar_raw_model, water_batched_state): - """Batched PolarMACE forward pass.""" - from mace.calculators.mace_torchsim import MaceTorchSimModel - - model = MaceTorchSimModel( - model=polar_raw_model, - device=DEVICE, - dtype=DTYPE, - ) - - results = model(water_batched_state) +def test_polar_torchsim_batched(polar_model, water_batched_sim_state): + results = polar_model(water_batched_sim_state) assert results["energy"].shape == (2,) assert results["forces"].shape == (6, 3) assert "dipole" in results assert results["dipole"].shape[0] == 2 -def test_torchsim_polar_batched_with_extras( - polar_raw_model, water_batched_state_with_extras +def test_polar_torchsim_batched_with_extras( + polar_model, water_batched_sim_state_with_extras ): - """Batched PolarMACE with per-system extras.""" - from mace.calculators.mace_torchsim import MaceTorchSimModel - - model = MaceTorchSimModel( - model=polar_raw_model, - device=DEVICE, - dtype=DTYPE, - ) - - results = model(water_batched_state_with_extras) + results = polar_model(water_batched_sim_state_with_extras) assert results["energy"].shape == (2,) assert results["forces"].shape == (6, 3) assert "dipole" in results assert results["dipole"].shape[0] == 2 - - -def test_torchsim_polar_matches_ase(polar_raw_model, water_atoms): - """PolarMACE TorchSim results should match ASE calculator.""" - from ase.stress import full_3x3_to_voigt_6_stress - - from mace.calculators.foundations_models import mace_polar - from mace.calculators.mace_torchsim import MaceTorchSimModel - - try: - ase_calc = mace_polar( - model=POLAR_MODEL_NAME, device=DEVICE.type, default_dtype="float64" - ) - except (FileNotFoundError, ValueError, RuntimeError) as exc: - _skip_if_polar_unavailable(exc, POLAR_MODEL_NAME) - - # ASE's AtomicData.from_config() defaults total_spin to 1.0; SimState - # defaults spin to 0.0. Pin both sides to 1.0 so they match. - atoms_ase = water_atoms.copy() - atoms_ase.info["spin"] = 1.0 - atoms_ase.calc = ase_calc - ase_energy = atoms_ase.get_potential_energy() - ase_forces = atoms_ase.get_forces() - ase_stress = atoms_ase.get_stress() - - ts_model = MaceTorchSimModel( - model=polar_raw_model, - device=DEVICE, - dtype=DTYPE, - ) - state = _atoms_state(water_atoms) - state.spin = torch.tensor([1.0], dtype=DTYPE) - ts_results = ts_model(state) - - np.testing.assert_allclose( - ts_results["energy"].item(), ase_energy, atol=1e-5, rtol=1e-5 - ) - np.testing.assert_allclose( - ts_results["forces"].detach().cpu().numpy(), ase_forces, atol=1e-5, rtol=1e-5 - ) - ts_stress_voigt = full_3x3_to_voigt_6_stress( - ts_results["stress"].detach().cpu().numpy().reshape(3, 3) - ) - np.testing.assert_allclose(ts_stress_voigt, ase_stress, atol=1e-5, rtol=1e-5) From 1ebcd1c40a8075aa0d976c0c5aaaf4f668e1475e Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Tue, 26 May 2026 14:42:56 -0400 Subject: [PATCH 2/2] doc --- tests/test_torchsim.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_torchsim.py b/tests/test_torchsim.py index 4b7199667..a636d2123 100644 --- a/tests/test_torchsim.py +++ b/tests/test_torchsim.py @@ -3,8 +3,9 @@ Uses factory functions (mirrored from the torch-sim test conftest) to build model/calculator consistency tests and ``validate_model_outputs`` tests from fixture names. Unlike the upstream factories we feed them sim_state *fixtures* -(built from ASE atoms) instead of the ``SIMSTATE_GENERATORS`` registry, since -the small water-only MACE model trained here doesn't know about Si/Mg/Fe. +(built from ASE atoms) instead of the ``SIMSTATE_GENERATORS`` registry, which +keeps the inputs hermetic to this repo and avoids needing torch-sim's bulk +crystal generators at consistency-test time. """ from __future__ import annotations