From 88b77ae6c34e35f5c39c265b3204ae5e6c7a5ec9 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Tue, 12 May 2026 13:47:03 +0200 Subject: [PATCH 01/20] Add an empty metatomic-core python package, re-exporting metatomic-torch --- pyproject.toml | 14 +- python/metatomic_core/AUTHORS | 1 + .../CMakeLists.txt} | 0 python/metatomic_core/LICENSE | 1 + python/metatomic_core/MANIFEST.in | 6 + python/metatomic_core/metatomic/__init__.py | 0 python/metatomic_core/metatomic/torch.py | 14 ++ python/metatomic_core/pyproject.toml | 54 +++++++ python/metatomic_core/setup.py | 146 ++++++++++++++++++ python/metatomic_torch/CMakeLists.txt | 15 +- python/metatomic_torch/MANIFEST.in | 2 +- python/metatomic_torch/README.rst | 6 +- .../torch => metatomic_torch}/__init__.py | 8 + .../torch => metatomic_torch}/_c_lib.py | 0 .../torch => metatomic_torch}/_extensions.py | 0 .../ase_calculator.py | 0 .../documentation.py | 0 .../torch => metatomic_torch}/heat_flux.py | 2 +- .../torch => metatomic_torch}/model.py | 0 .../serialization.py | 0 .../systems_to_torch.py | 0 .../torch => metatomic_torch}/utils.py | 0 .../torch => metatomic_torch}/version.py | 0 python/metatomic_torch/pyproject.toml | 4 - python/metatomic_torch/setup.py | 8 +- scripts/clean-python.sh | 9 ++ setup.py | 20 ++- tox.ini | 22 ++- 28 files changed, 303 insertions(+), 29 deletions(-) create mode 120000 python/metatomic_core/AUTHORS rename python/{metatomic_torch/metatomic/__init__.py => metatomic_core/CMakeLists.txt} (100%) create mode 120000 python/metatomic_core/LICENSE create mode 100644 python/metatomic_core/MANIFEST.in create mode 100644 python/metatomic_core/metatomic/__init__.py create mode 100644 python/metatomic_core/metatomic/torch.py create mode 100644 python/metatomic_core/pyproject.toml create mode 100644 python/metatomic_core/setup.py rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/__init__.py (92%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/_c_lib.py (100%) rename 
python/metatomic_torch/{metatomic/torch => metatomic_torch}/_extensions.py (100%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/ase_calculator.py (100%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/documentation.py (100%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/heat_flux.py (99%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/model.py (100%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/serialization.py (100%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/systems_to_torch.py (100%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/utils.py (100%) rename python/metatomic_torch/{metatomic/torch => metatomic_torch}/version.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 88dc392b9..2db00795e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,12 +63,16 @@ filterwarnings = [ "ignore:ast.NameConstant is deprecated and will be removed in Python 3.14:DeprecationWarning", # TorchScript deprecation warnings "ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", + "ignore:`torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning", - "ignore:.*vesin.metatomic was only tested with metatomic.torch >=0.1.3,<0.2.*:UserWarning", "ignore:`torch.jit.load` is deprecated. 
Please switch to `torch.export`.:DeprecationWarning", "ignore:`torch.jit.script` is not supported in Python 3.14+:DeprecationWarning", + "ignore:`torch.jit.script_method` is not supported in Python 3.14+:DeprecationWarning", "ignore:`torch.jit.save` is not supported in Python 3.14+:DeprecationWarning", - # deprecation warning from warp/nvalchemi + # vesin and metatomic warning + "ignore:.*vesin.metatomic was only tested with metatomic.torch >=0.1.3,<0.2.*:UserWarning", + # Warnings from warp (dependency of nvalchemi) + "ignore:.*Structure will use memory layout compatible with MSVC:DeprecationWarning", "ignore:warp.config.quiet is deprecated:DeprecationWarning", ] @@ -95,6 +99,8 @@ docstring-code-format = true [tool.uv.pip] reinstall-package = [ - "metatomic-torch", - "metatomic-torchsim", + "metatomic_core", + "metatomic_torch", + "metatomic_torchsim", + "metatomic_ase", ] diff --git a/python/metatomic_core/AUTHORS b/python/metatomic_core/AUTHORS new file mode 120000 index 000000000..f04b7e8a2 --- /dev/null +++ b/python/metatomic_core/AUTHORS @@ -0,0 +1 @@ +../../AUTHORS \ No newline at end of file diff --git a/python/metatomic_torch/metatomic/__init__.py b/python/metatomic_core/CMakeLists.txt similarity index 100% rename from python/metatomic_torch/metatomic/__init__.py rename to python/metatomic_core/CMakeLists.txt diff --git a/python/metatomic_core/LICENSE b/python/metatomic_core/LICENSE new file mode 120000 index 000000000..30cff7403 --- /dev/null +++ b/python/metatomic_core/LICENSE @@ -0,0 +1 @@ +../../LICENSE \ No newline at end of file diff --git a/python/metatomic_core/MANIFEST.in b/python/metatomic_core/MANIFEST.in new file mode 100644 index 000000000..02404051b --- /dev/null +++ b/python/metatomic_core/MANIFEST.in @@ -0,0 +1,6 @@ +include pyproject.toml +include CMakeLists.txt +include AUTHORS +include LICENSE + +include git_version_info diff --git a/python/metatomic_core/metatomic/__init__.py b/python/metatomic_core/metatomic/__init__.py new file mode 
100644 index 000000000..e69de29bb diff --git a/python/metatomic_core/metatomic/torch.py b/python/metatomic_core/metatomic/torch.py new file mode 100644 index 000000000..060e7bccf --- /dev/null +++ b/python/metatomic_core/metatomic/torch.py @@ -0,0 +1,14 @@ +import sys + + +try: + import metatomic_torch +except ImportError as e: + raise ImportError( + "metatomic-torch is required to use the metatomic.torch module. " + "Please install it with `pip install metatomic-torch` or using " + "your favorite Python package manager." + ) from e + +# metatomic.torch is registered as an alias in metatomic_torch's __init__.py +assert sys.modules["metatomic.torch"] is metatomic_torch diff --git a/python/metatomic_core/pyproject.toml b/python/metatomic_core/pyproject.toml new file mode 100644 index 000000000..9107f6805 --- /dev/null +++ b/python/metatomic_core/pyproject.toml @@ -0,0 +1,54 @@ +[project] +name = "metatomic-core" +dynamic = ["version", "authors", "dependencies"] +requires-python = ">=3.10" + +# readme = "TODO" +license = "BSD-3-Clause" +description = "Interface between atomistic machine learning models and simulation tools" + +keywords = ["machine learning", "molecular modeling"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Operating System :: POSIX", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "Topic :: Scientific/Engineering :: Chemistry", + "Topic :: Scientific/Engineering :: Physics", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +[project.urls] +homepage = "https://docs.metatensor.org/metatomic/" +documentation = "https://docs.metatensor.org/metatomic/" +repository = "https://github.com/metatensor/metatomic" +# changelog = 
"TODO" + +### ======================================================================== ### +[build-system] +requires = [ + "setuptools >=77", + "packaging >=26", + "cmake", + "metatensor-core >=0.2.0,<0.3", +] + +build-backend = "setuptools.build_meta" + + +[tool.setuptools] +zip-safe = false + +### ======================================================================== ### +[tool.pytest.ini_options] +python_files = ["*.py"] +testpaths = ["tests"] +filterwarnings = [ + "error", +] diff --git a/python/metatomic_core/setup.py b/python/metatomic_core/setup.py new file mode 100644 index 000000000..f2654eef3 --- /dev/null +++ b/python/metatomic_core/setup.py @@ -0,0 +1,146 @@ +import os +import subprocess +import sys + +import packaging.version +from setuptools import setup +from setuptools.command.bdist_egg import bdist_egg +from setuptools.command.sdist import sdist + + +ROOT = os.path.realpath(os.path.dirname(__file__)) + +METATOMIC_CORE_VERSION = "0.1.0" + +METATOMIC_BUILD_TYPE = os.environ.get("METATOMIC_BUILD_TYPE", "release") +if METATOMIC_BUILD_TYPE not in ["debug", "release"]: + raise Exception( + f"invalid build type passed: '{METATOMIC_BUILD_TYPE}', " + "expected 'debug' or 'release'" + ) + + +class bdist_egg_disabled(bdist_egg): + """Disabled version of bdist_egg + + Prevents setup.py install performing setuptools' default easy_install, + which it should never ever do. + """ + + def run(self): + sys.exit( + "Aborting implicit building of eggs.\nUse `pip install .` or " + "`python -m build --wheel . && pip install dist/metatomic_torch-*.whl` " + "to install from source." 
+ ) + + +class sdist_generate_data(sdist): + """ + Create a sdist with an additional generated file: + - `git_version_info` + """ + + def run(self): + n_commits, git_hash = git_version_info() + with open("git_version_info", "w") as fd: + fd.write(f"{n_commits}\n{git_hash}\n") + + # run original sdist + super().run() + + os.unlink("git_version_info") + + +def git_version_info(): + """ + If git is available and we are building from a checkout, get the number of commits + since the last tag & full hash of the code. Otherwise, this always returns (0, ""). + """ + TAG_PREFIX = "metatomic-v" + + if os.path.exists("git_version_info"): + # we are building from a sdist, without git available, but the git + # version was recorded in the `git_version_info` file + with open("git_version_info") as fd: + n_commits = int(fd.readline().strip()) + git_hash = fd.readline().strip() + else: + script = os.path.join(ROOT, "..", "..", "scripts", "git-version-info.py") + assert os.path.exists(script) + + output = subprocess.run( + [sys.executable, script, TAG_PREFIX], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + encoding="utf8", + ) + + if output.returncode != 0: + raise Exception( + "failed to get git version info.\n" + f"stdout: {output.stdout}\n" + f"stderr: {output.stderr}\n" + ) + elif output.stderr: + print(output.stderr, file=sys.stderr) + n_commits = 0 + git_hash = "" + else: + lines = output.stdout.splitlines() + n_commits = int(lines[0].strip()) + git_hash = lines[1].strip() + + return n_commits, git_hash + + +def create_version_number(version): + version = packaging.version.parse(version) + + n_commits, git_hash = git_version_info() + + if n_commits != 0: + # if we have commits since the last tag, this means we are in a pre-release of + # the next version.
So we increase either the minor version number or the + # release candidate number (if we are closing up on a release) + if version.pre is not None: + assert version.pre[0] == "rc" + pre = ("rc", version.pre[1] + 1) + release = version.release + else: + major, minor, _ = version.release + release = (major, minor + 1, 0) + pre = None + + version = version.__replace__( + release=release, + pre=pre, + dev=n_commits, + local=git_hash, + ) + + return str(version) + + +if __name__ == "__main__": + with open(os.path.join(ROOT, "AUTHORS")) as fd: + authors = fd.read().splitlines() + + if authors[0].startswith(".."): + # handle "raw" symlink files (on Windows or from full repo tarball) + with open(os.path.join(ROOT, authors[0])) as fd: + authors = fd.read().splitlines() + + install_requires = [ + "metatensor-core >=0.2.0,<0.3", + ] + + setup( + version=create_version_number(METATOMIC_CORE_VERSION), + author=", ".join(authors), + install_requires=install_requires, + cmdclass={ + "bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled, + "sdist": sdist_generate_data, + }, + ) diff --git a/python/metatomic_torch/CMakeLists.txt b/python/metatomic_torch/CMakeLists.txt index 0fb2d5421..2ee4c2bb2 100644 --- a/python/metatomic_torch/CMakeLists.txt +++ b/python/metatomic_torch/CMakeLists.txt @@ -63,6 +63,9 @@ else() add_subdirectory("${METATOMIC_TORCH_SOURCE_DIR}" metatomic-torch) + if (CMAKE_VERSION VERSION_LESS "3.25") + set(LINUX $) + endif() if (LINUX OR APPLE) if (LINUX) @@ -74,12 +77,12 @@ else() set(metatomic_install_rpath "${CMAKE_INSTALL_RPATH}") # when loading the libraries from a Python installation: - # - $ORIGIN/../../../../torch/lib is where libtorch.so will be - # - $ORIGIN/../../../../metatensor/lib is where libmetatensor.so will be - # - $ORIGIN/../../../../metatensor/torch/torch-x.y/lib is where libmetatensor_torch.so will be - set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../../torch/lib") - 
set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../../metatensor/lib") - set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../../metatensor/torch/torch-${Torch_VERSION_MAJOR}.${Torch_VERSION_MINOR}/lib") + # - $ORIGIN/../../../torch/lib is where libtorch.so will be + # - $ORIGIN/../../../metatensor/lib is where libmetatensor.so will be + # - $ORIGIN/../../../metatensor_torch/torch-${Torch_VERSION_MAJOR}.${Torch_VERSION_MINOR}/lib is where libmetatensor_torch.so will be + set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../torch/lib") + set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../metatensor/lib") + set(metatomic_install_rpath "${metatomic_install_rpath};${rpath_origin}/../../../metatensor_torch/torch-${Torch_VERSION_MAJOR}.${Torch_VERSION_MINOR}/lib") set_target_properties( metatomic_torch PROPERTIES INSTALL_RPATH "${metatomic_install_rpath}" diff --git a/python/metatomic_torch/MANIFEST.in b/python/metatomic_torch/MANIFEST.in index 6d341b48e..9e6ef4edb 100644 --- a/python/metatomic_torch/MANIFEST.in +++ b/python/metatomic_torch/MANIFEST.in @@ -5,6 +5,6 @@ include LICENSE include git_version_info -include metatomic-torch-*.tar.gz +include metatomic-torch-cxx-*.tar.gz recursive-include build-backend *.py diff --git a/python/metatomic_torch/README.rst b/python/metatomic_torch/README.rst index f06f2b8af..994fda75e 100644 --- a/python/metatomic_torch/README.rst +++ b/python/metatomic_torch/README.rst @@ -1,4 +1,4 @@ -metatensor-torch -================ +metatomic-torch +=============== -This package contains the TorchScript bindings to the core API of metatensor. +This package contains the TorchScript bindings to the core API of metatomic. 
diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic_torch/__init__.py similarity index 92% rename from python/metatomic_torch/metatomic/torch/__init__.py rename to python/metatomic_torch/metatomic_torch/__init__.py index a8bf363aa..dc0ba38ca 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic_torch/__init__.py @@ -1,8 +1,11 @@ import os +import sys from typing import TYPE_CHECKING import torch +import metatomic + from ._c_lib import _load_library from .version import __version__ # noqa: F401 @@ -65,3 +68,8 @@ save_buffer, ) from .systems_to_torch import systems_to_torch # noqa: F401 + + +sys.modules["metatomic.torch"] = sys.modules[__name__] +if not hasattr(metatomic, "torch"): + metatomic.torch = sys.modules[__name__] diff --git a/python/metatomic_torch/metatomic/torch/_c_lib.py b/python/metatomic_torch/metatomic_torch/_c_lib.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/_c_lib.py rename to python/metatomic_torch/metatomic_torch/_c_lib.py diff --git a/python/metatomic_torch/metatomic/torch/_extensions.py b/python/metatomic_torch/metatomic_torch/_extensions.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/_extensions.py rename to python/metatomic_torch/metatomic_torch/_extensions.py diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic_torch/ase_calculator.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/ase_calculator.py rename to python/metatomic_torch/metatomic_torch/ase_calculator.py diff --git a/python/metatomic_torch/metatomic/torch/documentation.py b/python/metatomic_torch/metatomic_torch/documentation.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/documentation.py rename to python/metatomic_torch/metatomic_torch/documentation.py diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py 
b/python/metatomic_torch/metatomic_torch/heat_flux.py similarity index 99% rename from python/metatomic_torch/metatomic/torch/heat_flux.py rename to python/metatomic_torch/metatomic_torch/heat_flux.py index 4de0828e5..167149b06 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic_torch/heat_flux.py @@ -4,7 +4,7 @@ from metatensor.torch import Labels, TensorBlock, TensorMap from vesin.metatomic import NeighborList -from metatomic.torch import ( +from . import ( AtomisticModel, ModelCapabilities, ModelOutput, diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic_torch/model.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/model.py rename to python/metatomic_torch/metatomic_torch/model.py diff --git a/python/metatomic_torch/metatomic/torch/serialization.py b/python/metatomic_torch/metatomic_torch/serialization.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/serialization.py rename to python/metatomic_torch/metatomic_torch/serialization.py diff --git a/python/metatomic_torch/metatomic/torch/systems_to_torch.py b/python/metatomic_torch/metatomic_torch/systems_to_torch.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/systems_to_torch.py rename to python/metatomic_torch/metatomic_torch/systems_to_torch.py diff --git a/python/metatomic_torch/metatomic/torch/utils.py b/python/metatomic_torch/metatomic_torch/utils.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/utils.py rename to python/metatomic_torch/metatomic_torch/utils.py diff --git a/python/metatomic_torch/metatomic/torch/version.py b/python/metatomic_torch/metatomic_torch/version.py similarity index 100% rename from python/metatomic_torch/metatomic/torch/version.py rename to python/metatomic_torch/metatomic_torch/version.py diff --git a/python/metatomic_torch/pyproject.toml b/python/metatomic_torch/pyproject.toml index 
40259291d..fe432a0ae 100644 --- a/python/metatomic_torch/pyproject.toml +++ b/python/metatomic_torch/pyproject.toml @@ -48,10 +48,6 @@ backend-path = ["build-backend"] [tool.setuptools] zip-safe = false -[tool.setuptools.packages.find] -include = ["metatomic*"] -namespaces = true - ### ======================================================================== ### [tool.pytest.ini_options] python_files = ["*.py"] diff --git a/python/metatomic_torch/setup.py b/python/metatomic_torch/setup.py index 7f327b645..98b3d55c9 100644 --- a/python/metatomic_torch/setup.py +++ b/python/metatomic_torch/setup.py @@ -24,6 +24,7 @@ METATOMIC_TORCH_SRC = os.path.realpath( os.path.join(ROOT, "..", "..", "metatomic-torch") ) +METATOMIC_CORE = os.path.realpath(os.path.join(ROOT, "..", "metatomic_core")) METATOMIC_ASE = os.path.realpath(os.path.join(ROOT, "..", "metatomic_ase")) @@ -50,7 +51,7 @@ def run(self): source_dir = ROOT build_dir = os.path.join(ROOT, "build", "cmake-build") - install_dir = os.path.join(os.path.realpath(self.build_lib), "metatomic/torch") + install_dir = os.path.join(os.path.realpath(self.build_lib), "metatomic_torch") os.makedirs(build_dir, exist_ok=True) @@ -325,11 +326,14 @@ def create_version_number(version): # when packaging a sdist for release, we should never use local dependencies METATOMIC_NO_LOCAL_DEPS = os.environ.get("METATOMIC_NO_LOCAL_DEPS", "0") == "1" - if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_ASE): + if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_CORE): + assert os.path.exists(METATOMIC_ASE) # we are building from a git checkout or full repo archive + install_requires.append(f"metatomic-core @ file://{METATOMIC_CORE}") install_requires.append(f"metatomic-ase @ file://{METATOMIC_ASE}") else: # we are building from a sdist/installing from a wheel + install_requires.append("metatomic-core >=0.1.0,<0.2.0") install_requires.append("metatomic-ase >=0.1.1,<0.2.0") setup( diff --git a/scripts/clean-python.sh 
b/scripts/clean-python.sh index ba6a9e9f5..81e69b26a 100755 --- a/scripts/clean-python.sh +++ b/scripts/clean-python.sh @@ -14,9 +14,18 @@ rm -rf docs/build rm -rf docs/src/examples rm -rf docs/src/sg_execution_times.rst +rm -rf python/metatomic_core/dist +rm -rf python/metatomic_core/build + rm -rf python/metatomic_torch/dist rm -rf python/metatomic_torch/build +rm -rf python/metatomic_ase/dist +rm -rf python/metatomic_ase/build + +rm -rf python/metatomic_torchsim/dist +rm -rf python/metatomic_torchsim/build + find . -name "*.egg-info" -exec rm -rf "{}" + find . -name "__pycache__" -exec rm -rf "{}" + find . -name ".coverage" -exec rm -rf "{}" + diff --git a/setup.py b/setup.py index ced9f7146..2124530b5 100644 --- a/setup.py +++ b/setup.py @@ -4,29 +4,39 @@ ROOT = os.path.realpath(os.path.dirname(__file__)) +METATOMIC_CORE = os.path.join(ROOT, "python", "metatomic_core") METATOMIC_TORCH = os.path.join(ROOT, "python", "metatomic_torch") +METATOMIC_ASE = os.path.join(ROOT, "python", "metatomic_ase") METATOMIC_TORCHSIM = os.path.join(ROOT, "python", "metatomic_torchsim") if __name__ == "__main__": extras_require = {} + install_requires = [] # when packaging a sdist for release, we should never use local dependencies METATOMIC_NO_LOCAL_DEPS = os.environ.get("METATOMIC_NO_LOCAL_DEPS", "0") == "1" - if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_TORCH): + if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_CORE): + assert os.path.exists(METATOMIC_TORCH) + assert os.path.exists(METATOMIC_ASE) + assert os.path.exists(METATOMIC_TORCHSIM) + # we are building from a git checkout + install_requires.append(f"metatomic-core @ file://{METATOMIC_CORE}") extras_require["torch"] = f"metatomic-torch @ file://{METATOMIC_TORCH}" + extras_require["ase"] = f"metatomic-ase @ file://{METATOMIC_ASE}" + extras_require["torchsim"] = f"metatomic-torchsim @ file://{METATOMIC_TORCHSIM}" else: # we are building from a sdist/installing from a wheel - extras_require["torch"] 
= "metatomic-torch" + install_requires.append("metatomic-core") - if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_TORCHSIM): - extras_require["torchsim"] = f"metatomic-torchsim @ file://{METATOMIC_TORCHSIM}" - else: + extras_require["torch"] = "metatomic-torch" + extras_require["ase"] = "metatomic-ase" extras_require["torchsim"] = "metatomic-torchsim" setup( author=", ".join(open(os.path.join(ROOT, "AUTHORS")).read().splitlines()), + install_requires=install_requires, extras_require=extras_require, ) diff --git a/tox.ini b/tox.ini index 201c28456..6eae4e866 100644 --- a/tox.ini +++ b/tox.ini @@ -38,6 +38,7 @@ packaging_deps = testing_deps = pytest pytest-cov + pytest-custom_exit_code metatensor_deps = metatensor-torch >=0.9.0,<0.10 @@ -133,6 +134,7 @@ deps = changedir = python/metatomic_torch commands = + pip install {[testenv]build_single_wheel} ../metatomic_core pip install {[testenv]build_single_wheel} . pip install {[testenv]build_single_wheel} ../metatomic_ase @@ -157,12 +159,23 @@ deps = vesin >=0.5.6,<0.6 ase + torch-sim-atomistic + +setenv = + # ignore the fact that metatensor.torch.operations was loaded from a file + # not in `metatensor/torch/operations` + PY_IGNORE_IMPORTMISMATCH = 1 commands = + pip install {[testenv]build_single_wheel} python/metatomic_core pip install {[testenv]build_single_wheel} python/metatomic_torch pip install {[testenv]build_single_wheel} python/metatomic_ase + pip install {[testenv]build_single_wheel} python/metatomic_torchsim - pytest --doctest-modules --pyargs metatomic + pytest --suppress-no-test-exit-code --doctest-modules --pyargs metatomic + pytest --suppress-no-test-exit-code --doctest-modules --pyargs metatomic_torch + pytest --suppress-no-test-exit-code --doctest-modules --pyargs metatomic_ase + pytest --suppress-no-test-exit-code --doctest-modules --pyargs metatomic_torchsim ################################################################################ @@ -192,8 +205,9 @@ deps = changedir = 
python/metatomic_ase commands = - pip install {[testenv]build_single_wheel} . + pip install {[testenv]build_single_wheel} ../metatomic_core pip install {[testenv]build_single_wheel} ../metatomic_torch + pip install {[testenv]build_single_wheel} . # use the reference LJ implementation for tests {[testenv]install_lj_tests} @@ -222,8 +236,9 @@ deps = changedir = python/metatomic_torchsim commands = - pip install {[testenv]build_single_wheel} . + pip install {[testenv]build_single_wheel} ../metatomic_core pip install {[testenv]build_single_wheel} ../metatomic_torch + pip install {[testenv]build_single_wheel} . # use the reference LJ implementation for tests {[testenv]install_lj_tests} @@ -292,6 +307,7 @@ deps = chemiscope commands = + pip install {[testenv]build_single_wheel} python/metatomic_core pip install {[testenv]build_single_wheel} python/metatomic_torch pip install {[testenv]build_single_wheel} python/metatomic_ase pip install {[testenv]build_single_wheel} python/metatomic_torchsim From 0423b7aa17be4969ed102de7712772c230739263 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Thu, 21 May 2026 10:26:30 +0200 Subject: [PATCH 02/20] Use pathlib for all path manipulations --- python/metatomic_ase/setup.py | 21 ++++++------ python/metatomic_core/setup.py | 15 +++++---- python/metatomic_torch/setup.py | 53 ++++++++++++++---------------- python/metatomic_torchsim/setup.py | 19 ++++++----- setup.py | 29 ++++++++-------- 5 files changed, 70 insertions(+), 67 deletions(-) diff --git a/python/metatomic_ase/setup.py b/python/metatomic_ase/setup.py index a1930193c..83de5e00d 100644 --- a/python/metatomic_ase/setup.py +++ b/python/metatomic_ase/setup.py @@ -1,4 +1,5 @@ import os +import pathlib import subprocess import sys @@ -8,8 +9,8 @@ from setuptools.command.sdist import sdist -ROOT = os.path.realpath(os.path.dirname(__file__)) -METATOMIC_TORCH = os.path.realpath(os.path.join(ROOT, "..", "metatomic_torch")) +ROOT = pathlib.Path(__file__).parent.resolve() 
+METATOMIC_TORCH = (ROOT / ".." / "metatomic_torch").resolve() METATOMIC_ASE_VERSION = "0.1.1" @@ -53,15 +54,15 @@ def git_version_info(): """ TAG_PREFIX = "metatomic-ase-v" - if os.path.exists("git_version_info"): + if (ROOT / "git_version_info").exists(): # we are building from a sdist, without git available, but the git # version was recorded in the `git_version_info` file - with open("git_version_info") as fd: + with open(ROOT / "git_version_info") as fd: n_commits = int(fd.readline().strip()) git_hash = fd.readline().strip() else: - script = os.path.join(ROOT, "..", "..", "scripts", "git-version-info.py") - assert os.path.exists(script) + script = (ROOT / ".." / ".." / "scripts" / "git-version-info.py").resolve() + assert script.exists() output = subprocess.run( [sys.executable, script, TAG_PREFIX], @@ -127,19 +128,19 @@ def create_version_number(version): # when packaging a sdist for release, we should never use local dependencies METATOMIC_NO_LOCAL_DEPS = os.environ.get("METATOMIC_NO_LOCAL_DEPS", "0") == "1" - if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_TORCH): + if not METATOMIC_NO_LOCAL_DEPS and METATOMIC_TORCH.exists(): # we are building from a git checkout or full repo archive - install_requires.append(f"metatomic-torch @ file://{METATOMIC_TORCH}") + install_requires.append(f"metatomic-torch @ {METATOMIC_TORCH.as_uri()}") else: # we are building from a sdist/installing from a wheel install_requires.append("metatomic-torch >=0.1.12,<0.2") - with open(os.path.join(ROOT, "AUTHORS")) as fd: + with open(ROOT / "AUTHORS") as fd: authors = fd.read().splitlines() if authors[0].startswith(".."): # handle "raw" symlink files (on Windows or from full repo tarball) - with open(os.path.join(ROOT, authors[0])) as fd: + with open(ROOT / authors[0]) as fd: authors = fd.read().splitlines() setup( diff --git a/python/metatomic_core/setup.py b/python/metatomic_core/setup.py index f2654eef3..35a2ef16c 100644 --- a/python/metatomic_core/setup.py +++ 
b/python/metatomic_core/setup.py @@ -1,4 +1,5 @@ import os +import pathlib import subprocess import sys @@ -8,7 +9,7 @@ from setuptools.command.sdist import sdist -ROOT = os.path.realpath(os.path.dirname(__file__)) +ROOT = pathlib.Path(__file__).parent.resolve() METATOMIC_CORE_VERSION = "0.1.0" @@ -59,15 +60,15 @@ def git_version_info(): """ TAG_PREFIX = "metatomic-v" - if os.path.exists("git_version_info"): + if (ROOT / "git_version_info").exists(): # we are building from a sdist, without git available, but the git # version was recorded in the `git_version_info` file - with open("git_version_info") as fd: + with open(ROOT / "git_version_info") as fd: n_commits = int(fd.readline().strip()) git_hash = fd.readline().strip() else: - script = os.path.join(ROOT, "..", "..", "scripts", "git-version-info.py") - assert os.path.exists(script) + script = (ROOT / ".." / ".." / "scripts" / "git-version-info.py").resolve() + assert script.exists() output = subprocess.run( [sys.executable, script, TAG_PREFIX], @@ -123,12 +124,12 @@ def create_version_number(version): if __name__ == "__main__": - with open(os.path.join(ROOT, "AUTHORS")) as fd: + with open(ROOT / "AUTHORS") as fd: authors = fd.read().splitlines() if authors[0].startswith(".."): # handle "raw" symlink files (on Windows or from full repo tarball) - with open(os.path.join(ROOT, authors[0])) as fd: + with open(ROOT / authors[0]) as fd: authors = fd.read().splitlines() install_requires = [ diff --git a/python/metatomic_torch/setup.py b/python/metatomic_torch/setup.py index 98b3d55c9..9524ac444 100644 --- a/python/metatomic_torch/setup.py +++ b/python/metatomic_torch/setup.py @@ -1,5 +1,6 @@ import glob import os +import pathlib import subprocess import sys @@ -12,7 +13,7 @@ from setuptools.command.sdist import sdist -ROOT = os.path.realpath(os.path.dirname(__file__)) +ROOT = pathlib.Path(__file__).parent.resolve() METATOMIC_BUILD_TYPE = os.environ.get("METATOMIC_BUILD_TYPE", "release") if METATOMIC_BUILD_TYPE not in 
["debug", "release"]: @@ -21,11 +22,9 @@ "expected 'debug' or 'release'" ) -METATOMIC_TORCH_SRC = os.path.realpath( - os.path.join(ROOT, "..", "..", "metatomic-torch") -) -METATOMIC_CORE = os.path.realpath(os.path.join(ROOT, "..", "metatomic_core")) -METATOMIC_ASE = os.path.realpath(os.path.join(ROOT, "..", "metatomic_ase")) +METATOMIC_TORCH_SRC = (ROOT / ".." / ".." / "metatomic-torch").resolve() +METATOMIC_CORE = (ROOT / ".." / "metatomic_core").resolve() +METATOMIC_ASE = (ROOT / ".." / "metatomic_ase").resolve() class universal_wheel(bdist_wheel): @@ -50,10 +49,10 @@ def run(self): import torch source_dir = ROOT - build_dir = os.path.join(ROOT, "build", "cmake-build") - install_dir = os.path.join(os.path.realpath(self.build_lib), "metatomic_torch") + build_dir = ROOT / "build" / "cmake-build" + install_dir = pathlib.Path(self.build_lib).resolve() / "metatomic_torch" - os.makedirs(build_dir, exist_ok=True) + build_dir.mkdir(parents=True, exist_ok=True) # Tell CMake where to find metatensor, metatensor_torch, and torch cmake_prefix_path = [ @@ -66,9 +65,7 @@ def run(self): # compile the code. This allows having multiple version of this shared library # inside the wheel; and dynamically pick the right one. torch_major, torch_minor, *_ = torch.__version__.split(".") - cmake_install_prefix = os.path.join( - install_dir, f"torch-{torch_major}.{torch_minor}" - ) + cmake_install_prefix = install_dir / f"torch-{torch_major}.{torch_minor}" use_external_lib = os.environ.get( "METATOMIC_TORCH_PYTHON_USE_EXTERNAL_LIB", "OFF" @@ -142,8 +139,8 @@ def run(self): def generate_cxx_tar(): - script = os.path.join(ROOT, "..", "..", "scripts", "package-torch.sh") - assert os.path.exists(script) + script = (ROOT / ".." / ".." 
/ "scripts" / "package-torch.sh").resolve() + assert script.exists() try: output = subprocess.run( @@ -180,15 +177,15 @@ def git_version_info(): """ TAG_PREFIX = "metatomic-torch-v" - if os.path.exists("git_version_info"): + if (ROOT / "git_version_info").exists(): # we are building from a sdist, without git available, but the git # version was recorded in the `git_version_info` file - with open("git_version_info") as fd: + with open(ROOT / "git_version_info") as fd: n_commits = int(fd.readline().strip()) git_hash = fd.readline().strip() else: - script = os.path.join(ROOT, "..", "..", "scripts", "git-version-info.py") - assert os.path.exists(script) + script = (ROOT / ".." / ".." / "scripts" / "git-version-info.py").resolve() + assert script.exists() output = subprocess.run( [sys.executable, script, TAG_PREFIX], @@ -275,10 +272,10 @@ def create_version_number(version): # End of Windows/MKL/PIP hack - if not os.path.exists(METATOMIC_TORCH_SRC): + if not METATOMIC_TORCH_SRC.exists(): # we are building from a sdist, which should include metatomic-torch C++ # sources as a tarball - tarballs = glob.glob(os.path.join(ROOT, "metatomic-torch-cxx-*.tar.gz")) + tarballs = glob.glob(str(ROOT / "metatomic-torch-cxx-*.tar.gz")) if not len(tarballs) == 1: raise RuntimeError( @@ -286,7 +283,7 @@ def create_version_number(version): "metatomic-torch C++ sources" ) - METATOMIC_TORCH_SRC = os.path.realpath(tarballs[0]) + METATOMIC_TORCH_SRC = pathlib.Path(tarballs[0]).resolve() subprocess.run( ["cmake", "-E", "tar", "xf", METATOMIC_TORCH_SRC], cwd=ROOT, @@ -295,15 +292,15 @@ def create_version_number(version): - METATOMIC_TORCH_SRC = ".".join(METATOMIC_TORCH_SRC.split(".")[:-2]) + METATOMIC_TORCH_SRC = METATOMIC_TORCH_SRC.with_suffix("").with_suffix("") - with open(os.path.join(METATOMIC_TORCH_SRC, "VERSION")) as fd: + with open(METATOMIC_TORCH_SRC / "VERSION") as fd: METATOMIC_TORCH_VERSION = fd.read().strip() - with open(os.path.join(ROOT, "AUTHORS")) as fd: + with open(ROOT / "AUTHORS") as fd: authors = fd.read().splitlines() if authors[0].startswith(".."): # 
handle "raw" symlink files (on Windows or from full repo tarball) - with open(os.path.join(ROOT, authors[0])) as fd: + with open(ROOT / authors[0]) as fd: authors = fd.read().splitlines() try: @@ -326,11 +323,11 @@ def create_version_number(version): # when packaging a sdist for release, we should never use local dependencies METATOMIC_NO_LOCAL_DEPS = os.environ.get("METATOMIC_NO_LOCAL_DEPS", "0") == "1" - if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_CORE): - assert os.path.exists(METATOMIC_ASE) + if not METATOMIC_NO_LOCAL_DEPS and METATOMIC_CORE.exists(): + assert METATOMIC_ASE.exists() # we are building from a git checkout or full repo archive - install_requires.append(f"metatomic-core @ file://{METATOMIC_CORE}") - install_requires.append(f"metatomic-ase @ file://{METATOMIC_ASE}") + install_requires.append(f"metatomic-core @ {METATOMIC_CORE.as_uri()}") + install_requires.append(f"metatomic-ase @ {METATOMIC_ASE.as_uri()}") else: # we are building from a sdist/installing from a wheel install_requires.append("metatomic-core >=0.1.0,<0.2.0") diff --git a/python/metatomic_torchsim/setup.py b/python/metatomic_torchsim/setup.py index f3d5d0252..2e403a075 100644 --- a/python/metatomic_torchsim/setup.py +++ b/python/metatomic_torchsim/setup.py @@ -1,4 +1,5 @@ import os +import pathlib import subprocess import sys @@ -7,8 +8,8 @@ from setuptools.command.sdist import sdist -ROOT = os.path.realpath(os.path.dirname(__file__)) -METATOMIC_TORCH = os.path.realpath(os.path.join(ROOT, "..", "metatomic_torch")) +ROOT = pathlib.Path(__file__).parent.resolve() +METATOMIC_TORCH = (ROOT / ".." 
/ "metatomic_torch").resolve() METATOMIC_TORCHSIM_VERSION = "0.1.3" @@ -38,15 +39,15 @@ def git_version_info(): """ TAG_PREFIX = "metatomic-torchsim-v" - if os.path.exists("git_version_info"): + if (ROOT / "git_version_info").exists(): # we are building from a sdist, without git available, but the git # version was recorded in the `git_version_info` file - with open("git_version_info") as fd: + with open(ROOT / "git_version_info") as fd: n_commits = int(fd.readline().strip()) git_hash = fd.readline().strip() else: - script = os.path.join(ROOT, "..", "..", "scripts", "git-version-info.py") - assert os.path.exists(script) + script = (ROOT / ".." / ".." / "scripts" / "git-version-info.py").resolve() + assert script.exists() output = subprocess.run( [sys.executable, script, TAG_PREFIX], @@ -102,7 +103,7 @@ def create_version_number(version): if __name__ == "__main__": - with open(os.path.join(ROOT, "AUTHORS")) as fd: + with open(ROOT / "AUTHORS") as fd: authors = fd.read().splitlines() install_requires = [ @@ -113,9 +114,9 @@ def create_version_number(version): # when packaging a sdist for release, we should never use local dependencies METATOMIC_NO_LOCAL_DEPS = os.environ.get("METATOMIC_NO_LOCAL_DEPS", "0") == "1" - if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_TORCH): + if not METATOMIC_NO_LOCAL_DEPS and METATOMIC_TORCH.exists(): # we are building from a git checkout or full repo archive - install_requires.append(f"metatomic-torch @ file://{METATOMIC_TORCH}") + install_requires.append(f"metatomic-torch @ {METATOMIC_TORCH.as_uri()}") else: # we are building from a sdist/installing from a wheel install_requires.append("metatomic-torch >=0.1.12,<0.2") diff --git a/setup.py b/setup.py index 2124530b5..69699d06e 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,14 @@ import os +import pathlib from setuptools import setup -ROOT = os.path.realpath(os.path.dirname(__file__)) -METATOMIC_CORE = os.path.join(ROOT, "python", "metatomic_core") -METATOMIC_TORCH = 
os.path.join(ROOT, "python", "metatomic_torch") -METATOMIC_ASE = os.path.join(ROOT, "python", "metatomic_ase") -METATOMIC_TORCHSIM = os.path.join(ROOT, "python", "metatomic_torchsim") +ROOT = pathlib.Path(__file__).parent.resolve() +METATOMIC_CORE = (ROOT / "python" / "metatomic_core").resolve() +METATOMIC_TORCH = (ROOT / "python" / "metatomic_torch").resolve() +METATOMIC_ASE = (ROOT / "python" / "metatomic_ase").resolve() +METATOMIC_TORCHSIM = (ROOT / "python" / "metatomic_torchsim").resolve() if __name__ == "__main__": @@ -17,16 +18,18 @@ # when packaging a sdist for release, we should never use local dependencies METATOMIC_NO_LOCAL_DEPS = os.environ.get("METATOMIC_NO_LOCAL_DEPS", "0") == "1" - if not METATOMIC_NO_LOCAL_DEPS and os.path.exists(METATOMIC_CORE): - assert os.path.exists(METATOMIC_TORCH) - assert os.path.exists(METATOMIC_ASE) - assert os.path.exists(METATOMIC_TORCHSIM) + if not METATOMIC_NO_LOCAL_DEPS and METATOMIC_CORE.exists(): + assert METATOMIC_TORCH.exists() + assert METATOMIC_ASE.exists() + assert METATOMIC_TORCHSIM.exists() # we are building from a git checkout - install_requires.append(f"metatomic-core @ file://{METATOMIC_CORE}") - extras_require["torch"] = f"metatomic-torch @ file://{METATOMIC_TORCH}" - extras_require["ase"] = f"metatomic-ase @ file://{METATOMIC_ASE}" - extras_require["torchsim"] = f"metatomic-torchsim @ file://{METATOMIC_TORCHSIM}" + install_requires.append(f"metatomic-core @ {METATOMIC_CORE.as_uri()}") + extras_require["torch"] = f"metatomic-torch @ {METATOMIC_TORCH.as_uri()}" + extras_require["ase"] = f"metatomic-ase @ {METATOMIC_ASE.as_uri()}" + extras_require["torchsim"] = ( + f"metatomic-torchsim @ {METATOMIC_TORCHSIM.as_uri()}" + ) else: # we are building from a sdist/installing from a wheel install_requires.append("metatomic-core") From 0f18b704381cfca98f90da76027f2f5ae04937a0 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Tue, 12 May 2026 15:26:39 +0200 Subject: [PATCH 03/20] Switch main test runner from tox 
to cargo --- .github/workflows/python-tests.yml | 96 +++++ .github/workflows/torch-tests.yml | 96 ++--- .gitignore | 3 + CONTRIBUTING.rst | 80 +++- Cargo.toml | 7 + docs/src/devdoc/get-started.rst | 6 + docs/src/devdoc/index.rst | 26 ++ docs/src/index.rst | 1 + metatomic-torch/Cargo.toml | 13 + metatomic-torch/lib.rs | 1 + metatomic-torch/tests/CMakeLists.txt | 6 +- metatomic-torch/tests/check-torch-install.rs | 207 ++++++++++ metatomic-torch/tests/run-torch-tests.rs | 47 +++ metatomic-torch/tests/utils/mod.rs | 410 +++++++++++++++++++ python/Cargo.toml | 12 + python/lib.rs | 1 + python/tests/run-python-tests.rs | 23 ++ tox.ini | 71 ---- 18 files changed, 971 insertions(+), 135 deletions(-) create mode 100644 .github/workflows/python-tests.yml create mode 100644 Cargo.toml create mode 100644 docs/src/devdoc/get-started.rst create mode 100644 docs/src/devdoc/index.rst create mode 100644 metatomic-torch/Cargo.toml create mode 100644 metatomic-torch/lib.rs create mode 100644 metatomic-torch/tests/check-torch-install.rs create mode 100644 metatomic-torch/tests/run-torch-tests.rs create mode 100644 metatomic-torch/tests/utils/mod.rs create mode 100644 python/Cargo.toml create mode 100644 python/lib.rs create mode 100644 python/tests/run-python-tests.rs diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml new file mode 100644 index 000000000..da3944fbf --- /dev/null +++ b/.github/workflows/python-tests.yml @@ -0,0 +1,96 @@ +name: Python tests + +on: + push: + branches: [main] + pull_request: + # Check all PR + +concurrency: + group: python-tests-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + python-tests: + runs-on: ${{ matrix.os }} + name: ${{ matrix.os }} / Python ${{ matrix.python-version }} / Torch ${{ matrix.torch-version }} + strategy: + matrix: + include: + - os: ubuntu-24.04 + python-version: "3.10" + torch-version: "2.3" + numpy-version-pin: "<2.0" + # Do not run docs-tests with python 
3.10 since torch-sim-atomistic + # is not available for this version of python + tox-envs: lint,torch-tests + - os: ubuntu-24.04 + python-version: "3.10" + torch-version: "2.12" + # See above + tox-envs: lint,torch-tests + - os: ubuntu-24.04 + # TorchScript is no longer supported in Python 3.14 + # so we keep a test with 3.13 to make sure this doesn't break + python-version: "3.13" + torch-version: "2.12" + tox-envs: lint,torch-tests,docs-tests + - os: ubuntu-24.04 + python-version: "3.14" + torch-version: "2.12" + tox-envs: lint,torch-tests,docs-tests + - os: macos-15 + python-version: "3.14" + torch-version: "2.12" + tox-envs: lint,torch-tests,docs-tests + - os: windows-2022 + python-version: "3.14" + torch-version: "2.12" + tox-envs: lint,torch-tests,docs-tests + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: setup Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: setup rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + + - name: Cache Rust dependencies + uses: Leafwing-Studios/cargo-cache@v2.6.1 + with: + sweep-cache: true + + - name: Setup sccache + if: ${{ !env.ACT }} + uses: mozilla-actions/sccache-action@v0.0.10 + with: + version: "v0.10.0" + + - name: setup MSVC command prompt + uses: ilammy/msvc-dev-cmd@v1 + + - name: Setup sccache environnement variables + if: ${{ !env.ACT }} + run: | + echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV + echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV + echo "CMAKE_C_COMPILER_LAUNCHER=sccache" >> $GITHUB_ENV + echo "CMAKE_CXX_COMPILER_LAUNCHER=sccache" >> $GITHUB_ENV + + - name: install tests dependencies + run: | + python -m pip install --upgrade pip + python -m pip install tox coverage + + - name: run tests + run: tox -e ${{ matrix.tox-envs }} + env: + PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu + METATOMIC_TESTS_TORCH_VERSION: ${{ matrix.torch-version }} diff --git a/.github/workflows/torch-tests.yml 
b/.github/workflows/torch-tests.yml index 1c549795a..62ed6025c 100644 --- a/.github/workflows/torch-tests.yml +++ b/.github/workflows/torch-tests.yml @@ -13,81 +13,85 @@ concurrency: jobs: tests: runs-on: ${{ matrix.os }} - name: ${{ matrix.os }} / Python ${{ matrix.python-version }} / Torch ${{ matrix.torch-version }} + name: ${{ matrix.os }} / Torch ${{ matrix.torch-version }}${{ matrix.extra-name }} + container: ${{ matrix.container }} strategy: matrix: include: - os: ubuntu-24.04 - python-version: "3.10" - torch-version: "2.3" - - os: ubuntu-24.04 - python-version: "3.10" torch-version: "2.12" - - os: ubuntu-24.04 - # Keep a building with Python 3.13 since TorchScript is deprecated - # in Python 3.14 - python-version: "3.13" - torch-version: "2.12" - - os: ubuntu-24.04 python-version: "3.14" - torch-version: "2.12" + cargo-test-flags: --release + do-valgrind: true + + # check the build on a stock Ubuntu 22.04, which uses cmake 3.22 + - os: ubuntu-24.04 + container: ubuntu:22.04 + extra-name: ", cmake 3.22" + torch-version: "2.3" + cargo-test-flags: "" + - os: macos-15 - python-version: "3.14" torch-version: "2.12" - - os: windows-2022 python-version: "3.14" + cargo-test-flags: --release + + - os: windows-2022 torch-version: "2.12" + python-version: "3.14" + cargo-test-flags: --release steps: + - name: install dependencies in container + if: matrix.container == 'ubuntu:22.04' + run: | + apt update + apt install -y software-properties-common + add-apt-repository ppa:deadsnakes/ppa + apt install -y cmake make gcc g++ git curl python3.10 python3.10-venv + + update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 1 + - uses: actions/checkout@v6 with: fetch-depth: 0 - - name: setup Python - uses: actions/setup-python@v6 + - name: Configure git safe directory + if: matrix.container == 'ubuntu:22.04' + run: git config --global --add safe.directory /__w/metatomic/metatomic + + - name: setup rust + uses: dtolnay/rust-toolchain@master with: - 
python-version: ${{ matrix.python-version }} + toolchain: stable + + - name: Cache Rust dependencies + uses: Leafwing-Studios/cargo-cache@v2.6.1 + with: + sweep-cache: true + + - name: install valgrind + if: matrix.do-valgrind + run: | + sudo apt-get install -y valgrind - name: Setup sccache + if: ${{ !env.ACT }} uses: mozilla-actions/sccache-action@v0.0.10 with: version: "v0.10.0" - - name: setup MSVC command prompt - uses: ilammy/msvc-dev-cmd@v1 - - name: Setup sccache environnement variables + if: ${{ !env.ACT }} run: | echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV echo "CMAKE_C_COMPILER_LAUNCHER=sccache" >> $GITHUB_ENV echo "CMAKE_CXX_COMPILER_LAUNCHER=sccache" >> $GITHUB_ENV - - name: install tests dependencies - run: | - python -m pip install --upgrade pip - python -m pip install tox coverage - - - name: run Python tests - run: tox -e lint,torch-tests,docs-tests + - name: run TorchScript C++ tests + run: cargo test --package metatomic-torch ${{ matrix.cargo-test-flags }} env: + # Use the CPU only version of torch when building/running the code PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu METATOMIC_TESTS_TORCH_VERSION: ${{ matrix.torch-version }} - - - name: run C++ tests - run: tox -e torch-tests-cxx,torch-install-tests-cxx - env: - PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu - METATOMIC_TESTS_TORCH_VERSION: ${{ matrix.torch-version }} - - - name: combine Python coverage files - shell: bash - run: | - coverage combine .tox/*/.coverage - coverage xml - - - name: upload to codecov.io - uses: codecov/codecov-action@v6 - with: - fail_ci_if_error: true - files: coverage.xml - token: ${{ secrets.CODECOV_TOKEN }} + CXXFLAGS: ${{ matrix.cxx-flags }} diff --git a/.gitignore b/.gitignore index ab865aa23..265263ff8 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,6 @@ build/ htmlcov/ .coverage* coverage.xml + +Cargo.lock +target/ diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 
50c8dc986..e62180b83 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -16,6 +16,10 @@ on metatomic: - **git**: the software we use for version control of the source code. See https://git-scm.com/downloads for installation instructions. +- **the rust compiler**: you will need both ``rustc`` (the compiler) and + ``cargo`` (associated build tool). You can install both using `rustup`_, or + use a version provided by your operating system. We need at least Rust version + 1.74 to build metatomic. - **Python**: you can install ``Python`` and ``pip`` on your operating system. We require a Python version of at least 3.9. - **tox**: a Python test runner, see https://tox.readthedocs.io/en/latest/. You @@ -28,17 +32,21 @@ not have to interact with them directly: - **a C++ compiler** we need a compiler supporting C++11. GCC >= 7, clang >= 5 and MSVC >= 19 should all work, although MSVC is not yet tested continuously. +.. _rustup: https://rustup.rs +.. _`cargo` : https://doc.rust-lang.org/cargo/ +.. _tox: https://tox.readthedocs.io/en/latest + .. admonition:: Optional tools Depending on which part of the code you are working on, you might experience a - lot of time spent re-compiling code, even if you did not directly change them. - For faster builds (and in turn faster tests), you can use compiler cache, like - `sccache`_ or the classic `ccache`_ to reduce the recompilation of unchanged - source code. To do this, you should install and configure one of these tools - (we suggest ``sccache`` since it also supports Rust), and then configure - ``cmake`` and ``cargo`` to use them by setting environnement variables. On - Linux and macOS, you should set the following (look up how to do set - environment variable with your shell): + lot of time spend re-compiling Rust or C++ code, even if you did not change + them. 
If you'd like faster builds (and in turn faster tests), you can use + `sccache`_ or the classic `ccache`_ to only re-run the compiler if the + corresponding source code changed. To do this, you should install and configure + one of these tools (we suggest sccache since it also supports Rust), and then + configure cmake and cargo to use them by setting environment variables. On + Linux and macOS, you should set the following (look up how to set environment + variables with your shell): .. code-block:: bash @@ -88,32 +96,70 @@ changes: Running tests ------------- -The continuous integration pipeline is based on `tox`_. You can run all tests +The continuous integration pipeline is based on `cargo`_. You can run all tests with: .. code-block:: bash cd - tox + cargo test # or cargo test --release to run tests in release mode -These are exactly the same tests that will be performed online in our Github CI +These are exactly the same tests that will be performed online in our GitHub CI workflows. You can also run only a subset of tests with one of these commands: +- ``cargo test`` runs everything + +- ``cargo test --package=metatomic-torch`` to run the C++ TorchScript tests only; + + - ``cargo test --test=run-torch-tests`` will run the unit tests for the + TorchScript C++ extension; + - ``cargo test --test=check-torch-install`` will build the C++ TorchScript + extension, install it and then try to build a basic project depending on + this extension with CMake; + +- ``cargo test --package=metatomic-python`` (or ``tox`` directly, see below) to + run Python tests only; +- ``cargo test --lib`` to run unit tests; +- ``cargo test --doc`` to run documentation tests; +- ``cargo bench --test`` compiles and runs the benchmarks once, to quickly ensure + they still work. 
+ +You can add some flags to any of the above commands to further refine which tests +should run: + +- ``--release`` to run tests in release mode (default is to run tests in debug mode) +- ``-- <filter>`` to only run tests whose name contains ``<filter>``, for example ``cargo test -- system`` + +Also, you can run individual Python tests using `tox`_ if you wish to run a +subset of Python tests, for example: + .. code-block:: bash tox -e lint # check files for formatting errors tox -e torch-tests # unit tests for metatomic-torch, in Python - tox -e torch-tests-cxx # unit tests for metatomic-torch, in C++ - tox -e torch-install-tests-cxx # testing that the C++ code is a valid CMake package + tox -e ase-tests # unit tests for metatomic-ase, in Python + tox -e torchsim-tests # unit tests for metatomic-torchsim, in Python tox -e docs-tests # doctests (checking inline examples) for all packages - tox -e lint # code style tox -e format # format all files -The last command ``tox -e format`` will use ``tox`` to do actual formatting -instead of just checking it, you can use this to automatically fix some of the -issues detected by ``tox -e lint``. +The last command ``tox -e format`` will use tox to do actual formatting instead +of just checking it, you can use this to automatically fix some of the issues +detected by ``tox -e lint``. + +You can run only a subset of the tests with ``tox -e tests -- <test/file.py>``, +replacing ``<test/file.py>`` with the path to the files you want to test, e.g. +``tox -e tests -- python/tests/operations/abs.py``. + +To get the release build for ``tox`` runs, set the ``METATOMIC_BUILD_TYPE`` environment variable: + +.. code-block:: bash + + METATOMIC_BUILD_TYPE="release" tox -e torch-tests + +This corresponds to running ``cargo test --package=metatomic-python --release`` +but on the subset of interest. 
You can run only a subset of the tests with ``tox -e torch-tests -- <test/file.py>``, replacing ``<test/file.py>`` with the path to the files you diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..5256b9601 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,7 @@ +[workspace] +resolver = "2" + +members = [ + "metatomic-torch", + "python", +] diff --git a/docs/src/devdoc/get-started.rst b/docs/src/devdoc/get-started.rst new file mode 100644 index 000000000..4c19e4ef6 --- /dev/null +++ b/docs/src/devdoc/get-started.rst @@ -0,0 +1,6 @@ +.. _devdoc-get-started: + +Getting started +=============== + +.. include:: ../../../CONTRIBUTING.rst diff --git a/docs/src/devdoc/index.rst b/docs/src/devdoc/index.rst new file mode 100644 index 000000000..43755fdf6 --- /dev/null +++ b/docs/src/devdoc/index.rst @@ -0,0 +1,26 @@ +.. _devdoc: + +Developer documentation +####################### + +This developer documentation contains the following sections: + +1. :ref:`devdoc-get-started` explains how you can start developing code and + documentation; + +.. toctree:: + :maxdepth: 2 + + get-started + +Development team +---------------- + +Metatomic is developed in the `COSMO laboratory`_ at `EPFL`_, and made +available under the `BSD 3-clause license <LICENSE_>`_. We welcome +contributions from anyone; feel free to contact us if you need some help working +with the code! + +.. _COSMO laboratory: https://www.epfl.ch/labs/cosmo/ +.. _EPFL: https://www.epfl.ch/ +.. _LICENSE: https://github.com/metatensor/metatomic/blob/main/LICENSE diff --git a/docs/src/index.rst b/docs/src/index.rst index d94c6ded2..170c25c19 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -96,4 +96,5 @@ existing trained models, look into the metatrain_ project instead. 
quantities/index engines/index examples/index + devdoc/index cite diff --git a/metatomic-torch/Cargo.toml b/metatomic-torch/Cargo.toml new file mode 100644 index 000000000..3809a5a99 --- /dev/null +++ b/metatomic-torch/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "metatomic-torch" +version = "0.0.0" +edition = "2021" +publish = false +rust-version = "1.74" + +[lib] +path = "lib.rs" + +[dev-dependencies] +lazy_static = "1" +which = "8" diff --git a/metatomic-torch/lib.rs b/metatomic-torch/lib.rs new file mode 100644 index 000000000..59bc69bb6 --- /dev/null +++ b/metatomic-torch/lib.rs @@ -0,0 +1 @@ +// empty lib.rs, this crate only exists to run TorchScript C++ tests with cargo diff --git a/metatomic-torch/tests/CMakeLists.txt b/metatomic-torch/tests/CMakeLists.txt index 89a3db0f2..8a64a4f33 100644 --- a/metatomic-torch/tests/CMakeLists.txt +++ b/metatomic-torch/tests/CMakeLists.txt @@ -14,9 +14,11 @@ if (VALGRIND) "--leak-check=full" "--show-leak-kinds=definite,indirect,possible" "--track-origins=yes" "--gen-suppressions=all" "--suppressions=${CMAKE_CURRENT_SOURCE_DIR}/valgrind.supp" ) + set(USING_VALGRIND ON) endif() else() set(TEST_COMMAND "") + set(USING_VALGRIND OFF) endif() @@ -46,7 +48,9 @@ foreach(_file_ ${ALL_TESTS}) ) # stop tests if they run for more than 30s - set_tests_properties(torch-${_name_} PROPERTIES TIMEOUT 30) + if (NOT USING_VALGRIND) + set_tests_properties(torch-${_name_} PROPERTIES TIMEOUT 30) + endif() if(WIN32) # We need to set the path to allow access to torch.dll diff --git a/metatomic-torch/tests/check-torch-install.rs b/metatomic-torch/tests/check-torch-install.rs new file mode 100644 index 000000000..8883d916e --- /dev/null +++ b/metatomic-torch/tests/check-torch-install.rs @@ -0,0 +1,207 @@ +use std::path::PathBuf; +use std::sync::Mutex; + +mod utils; + +lazy_static::lazy_static! 
{ + // Make sure only one of the tests below run at the time, since they both + // try to modify the same files + static ref LOCK: Mutex<()> = Mutex::new(()); +} + +/// Check that metatomic-torch can be built and installed with cmake, and that +/// the installed version can be used from another cmake project with +/// `find_package` +#[test] +fn check_torch_install() { + let _guard = match LOCK.lock() { + Ok(guard) => guard, + Err(_) => { + panic!("another test failed, stopping") + } + }; + + const CARGO_TARGET_TMPDIR: &str = env!("CARGO_TARGET_TMPDIR"); + let cargo_manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); + + // ====================================================================== // + // build and install metatensor-torch with cmake + let mut build_dir = PathBuf::from(CARGO_TARGET_TMPDIR); + build_dir.push("torch-install"); + build_dir.push("cmake-find-package"); + std::fs::create_dir_all(&build_dir).expect("failed to create build dir"); + + + let deps_dir = build_dir.join("deps"); + + let torch_dep = deps_dir.join("virtualenv"); + std::fs::create_dir_all(&torch_dep).expect("failed to create virtualenv dir"); + let python = utils::create_python_venv(torch_dep); + let pytorch_cmake_prefix = utils::setup_torch_pip(&python); + let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python); + let metatensor_torch_cmake_prefix = utils::setup_metatensor_torch_pip(&python); + + // configure cmake for metatomic-torch + let metatomic_torch_dep = deps_dir.join("metatomic-torch"); + + let cmake_options = vec![ + format!( + "-DCMAKE_PREFIX_PATH={};{};{}", + pytorch_cmake_prefix.display(), + metatensor_cmake_prefix.display(), + metatensor_torch_cmake_prefix.display() + ), + // The two properties below handle the RPATH for metatomic_torch, + // setting it in such a way that we can always load libmetatensor.so and + // libtorch.so from the location they are found at when compiling + // metatomic-torch. 
See + // https://gitlab.kitware.com/cmake/community/-/wikis/doc/cmake/RPATH-handling + // for more information on CMake RPATH handling + "-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON".into(), + "-DCMAKE_INSTALL_RPATH_USE_LINK_PATH=ON".into(), + ]; + + let install_prefix = utils::setup_metatomic_torch_cmake( + &cargo_manifest_dir, + &metatomic_torch_dep, + cmake_options, + ); + + // ====================================================================== // + // // try to use the installed metatomic-torch from cmake + let mut source_dir = PathBuf::from(&cargo_manifest_dir); + source_dir.extend(["tests", "cmake-project"]); + + // configure cmake for the test cmake project + let mut cmake_config = utils::cmake_config(&source_dir, &build_dir); + cmake_config.arg(format!( + "-DCMAKE_PREFIX_PATH={};{};{};{}", + metatensor_cmake_prefix.display(), + pytorch_cmake_prefix.display(), + metatensor_torch_cmake_prefix.display(), + install_prefix.display(), + )); + + utils::run_command(cmake_config, "cmake configuration"); + + // build the code, linking to metatomic-torch + let cmake_build = utils::cmake_build(&build_dir); + utils::run_command(cmake_build, "cmake build"); + + // run the executables + let ctest = utils::ctest(&build_dir); + utils::run_command(ctest, "ctest"); +} + +/// Same as above, but using pre-built metatensor-torch from the Python wheel, +/// instead of building it from source with cmake. 
+#[test] +fn check_python_install() { + let _guard = match LOCK.lock() { + Ok(guard) => guard, + Err(_) => { + panic!("another test failed, stopping") + } + }; + + const CARGO_TARGET_TMPDIR: &str = env!("CARGO_TARGET_TMPDIR"); + + // ====================================================================== // + // build and install metatensor and metatensor-torch with pip + let mut build_dir = PathBuf::from(CARGO_TARGET_TMPDIR); + build_dir.push("torch-install"); + build_dir.push("python-wheels"); + std::fs::create_dir_all(&build_dir).expect("failed to create build dir"); + + let mut venv_dir = build_dir.clone(); + venv_dir.push("virtualenv"); + + let python_exe = utils::create_python_venv(venv_dir); + + let cargo_manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); + let pytorch_cmake_prefix = utils::setup_torch_pip(&python_exe); + let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python_exe); + let metatensor_torch_cmake_prefix = utils::setup_metatensor_torch_pip(&python_exe); + + let python_source_dir = cargo_manifest_dir.parent().unwrap().join("python").join("metatomic_torch"); + let metatomic_torch_cmake_prefix = utils::setup_metatomic_torch_pip(&python_exe, &python_source_dir); + + // ====================================================================== // + // try to use the installed metatensor-torch from cmake + let mut source_dir = PathBuf::from(&cargo_manifest_dir); + source_dir.extend(["tests", "cmake-project"]); + + // configure cmake for the test cmake project + let mut cmake_config = utils::cmake_config(&source_dir, &build_dir); + cmake_config.arg(format!( + "-DCMAKE_PREFIX_PATH={};{};{};{}", + pytorch_cmake_prefix.display(), + metatensor_cmake_prefix.display(), + metatensor_torch_cmake_prefix.display(), + metatomic_torch_cmake_prefix.display(), + )); + + utils::run_command(cmake_config, "cmake configuration"); + + // build the code, linking to metatensor-torch + let cmake_build = utils::cmake_build(&build_dir); + 
utils::run_command(cmake_build, "cmake build"); + + // run the executables + let ctest = utils::ctest(&build_dir); + utils::run_command(ctest, "ctest"); +} + +/// Same test as above, but building metatomic-torch in the same +/// CMake project (i.e. using add_subdirectory instead of find_package) +#[test] +fn check_cmake_subdirectory() { + let _guard = match LOCK.lock() { + Ok(guard) => guard, + Err(_) => { + panic!("another test failed, stopping") + } + }; + + const CARGO_TARGET_TMPDIR: &str = env!("CARGO_TARGET_TMPDIR"); + + // install torch + let mut build_dir = PathBuf::from(CARGO_TARGET_TMPDIR); + build_dir.push("torch-install"); + build_dir.push("cmake-subdirectory"); + std::fs::create_dir_all(&build_dir).expect("failed to create build dir"); + + let deps_dir = build_dir.join("deps"); + + let torch_dep = deps_dir.join("virtualenv"); + std::fs::create_dir_all(&torch_dep).expect("failed to create virtualenv dir"); + let python = utils::create_python_venv(torch_dep); + let pytorch_cmake_prefix = utils::setup_torch_pip(&python); + let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python); + let metatensor_torch_cmake_prefix = utils::setup_metatensor_torch_pip(&python); + + // ====================================================================== // + let cargo_manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); + let mut source_dir = PathBuf::from(&cargo_manifest_dir); + source_dir.extend(["tests", "cmake-project"]); + + // configure cmake for the test cmake project + let mut cmake_config = utils::cmake_config(&source_dir, &build_dir); + cmake_config.arg(format!( + "-DCMAKE_PREFIX_PATH={};{};{}", + pytorch_cmake_prefix.display(), + metatensor_cmake_prefix.display(), + metatensor_torch_cmake_prefix.display() + )); + cmake_config.arg("-DUSE_CMAKE_SUBDIRECTORY=ON"); + + utils::run_command(cmake_config, "cmake configuration"); + + // build the code, linking to metatomic-torch + let cmake_build = utils::cmake_build(&build_dir); + 
utils::run_command(cmake_build, "cmake build"); + + // run the executables + let ctest = utils::ctest(&build_dir); + utils::run_command(ctest, "ctest"); +} diff --git a/metatomic-torch/tests/run-torch-tests.rs b/metatomic-torch/tests/run-torch-tests.rs new file mode 100644 index 000000000..93772f0a6 --- /dev/null +++ b/metatomic-torch/tests/run-torch-tests.rs @@ -0,0 +1,47 @@ +use std::path::PathBuf; + +mod utils; + +#[test] +fn run_torch_tests() { + const CARGO_TARGET_TMPDIR: &str = env!("CARGO_TARGET_TMPDIR"); + let cargo_manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); + + // ====================================================================== // + // setup dependencies for the torch tests + + let mut build_dir = PathBuf::from(CARGO_TARGET_TMPDIR); + build_dir.push("torch-tests"); + let deps_dir = build_dir.join("deps"); + + let torch_dep = deps_dir.join("virtualenv"); + std::fs::create_dir_all(&torch_dep).expect("failed to create virtualenv dir"); + let python_exe = utils::create_python_venv(torch_dep); + let pytorch_cmake_prefix = utils::setup_torch_pip(&python_exe); + let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python_exe); + let metatensor_torch_cmake_prefix = utils::setup_metatensor_torch_pip(&python_exe); + + // ====================================================================== // + // build the metatomic-torch C++ tests and run them + let source_dir = cargo_manifest_dir; + + // configure cmake for the tests + let mut cmake_config = utils::cmake_config(&source_dir, &build_dir); + cmake_config.arg("-DMETATOMIC_TORCH_TESTS=ON"); + cmake_config.arg(format!( + "-DCMAKE_PREFIX_PATH={};{};{}", + pytorch_cmake_prefix.display(), + metatensor_cmake_prefix.display(), + metatensor_torch_cmake_prefix.display() + )); + + utils::run_command(cmake_config, "cmake configuration"); + + // build the tests + let cmake_build = utils::cmake_build(&build_dir); + utils::run_command(cmake_build, "cmake build"); + + // run the tests + 
let ctest = utils::ctest(&build_dir); + utils::run_command(ctest, "ctest"); +} diff --git a/metatomic-torch/tests/utils/mod.rs b/metatomic-torch/tests/utils/mod.rs new file mode 100644 index 000000000..e223bcee6 --- /dev/null +++ b/metatomic-torch/tests/utils/mod.rs @@ -0,0 +1,410 @@ +#![allow(dead_code)] +#![allow(clippy::needless_return)] + +use std::io::{Read, Write}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; + +fn build_type() -> &'static str { + // assume that debug assertion means that we are building the code in + // debug mode, even if that could be not true in some cases + if cfg!(debug_assertions) { + "debug" + } else { + "release" + } +} + +fn append_flags(existing: Option<String>, extra: &str) -> String { + match existing { + Some(flags) if !flags.trim().is_empty() => format!("{flags} {extra}"), + _ => extra.into(), + } +} + +pub fn cmake_config(source_dir: &Path, build_dir: &Path) -> Command { + let cmake = which::which("cmake").expect("could not find cmake"); + + let mut cmake_config = Command::new(cmake); + cmake_config.current_dir(build_dir); + cmake_config.arg(source_dir); + cmake_config.arg("--no-warn-unused-cli"); + cmake_config.arg(format!("-DCMAKE_BUILD_TYPE={}", build_type())); + + // the cargo executable currently running + let cargo_exe = std::env::var("CARGO").expect("CARGO env var is not set"); + cmake_config.arg(format!("-DCARGO_EXE={}", cargo_exe)); + + if std::env::var_os("CARGO_LLVM_COV").is_some() { + let coverage_compile_flags = "-fprofile-instr-generate -fcoverage-mapping"; + let coverage_link_flags = "-fprofile-instr-generate"; + + let c_flags = append_flags(std::env::var("CFLAGS").ok(), coverage_compile_flags); + let cxx_flags = append_flags(std::env::var("CXXFLAGS").ok(), coverage_compile_flags); + let exe_linker_flags = + append_flags(std::env::var("LDFLAGS").ok(), coverage_link_flags); + + cmake_config.arg(format!("-DCMAKE_C_FLAGS={c_flags}")); + cmake_config.arg(format!("-DCMAKE_CXX_FLAGS={cxx_flags}")); + 
cmake_config.arg(format!("-DCMAKE_EXE_LINKER_FLAGS={exe_linker_flags}")); + cmake_config.arg(format!("-DCMAKE_SHARED_LINKER_FLAGS={exe_linker_flags}")); + } + + return cmake_config; +} + +pub fn cmake_build(build_dir: &Path) -> Command { + let cmake = which::which("cmake").expect("could not find cmake"); + + let mut cmake_build = Command::new(cmake); + cmake_build.current_dir(build_dir); + cmake_build.arg("--build"); + cmake_build.arg("."); + cmake_build.arg("--parallel"); + cmake_build.arg("--config"); + cmake_build.arg(build_type()); + + return cmake_build; +} + + +pub fn ctest(build_dir: &Path) -> Command { + let ctest = which::which("ctest").expect("could not find ctest"); + + let mut ctest = Command::new(ctest); + ctest.current_dir(build_dir); + ctest.arg("--output-on-failure"); + ctest.arg("--build-config"); + ctest.arg(build_type()); + + return ctest +} + +/// Find the path to the uv binary, or None if not present +fn find_uv() -> Option<PathBuf> { + which::which("uv").ok() +} + +/// Find the path to the `python` or `python3` binary on the user system +fn find_python() -> PathBuf { + if let Ok(python) = which::which("python") { + let output = Command::new(&python) + .arg("-c") + .arg("import sys; print(sys.version_info.major)") + .output() + .expect("could not run python"); + + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + + if stdout.trim() == "3" { + // we found Python 3 + return python; + } + } + } + + // try python3 + let python = which::which("python3").expect("failed to run `which python3`"); + let output = Command::new(&python) + .arg("-c") + .arg("import sys; print(sys.version_info.major)") + .output() + .expect("could not run python"); + + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + if stdout.trim() == "3" { + // we found Python 3 + return python; + } + } + + panic!("could not find Python 3") +} + +/// Helper: get python executable path inside a venv +fn 
python_in_venv(venv_dir: &Path) -> PathBuf { + let mut python = venv_dir.to_path_buf(); + if cfg!(target_os = "windows") { + python.extend(["Scripts", "python.exe"]); + } else { + python.extend(["bin", "python"]); + } + python +} + +/// Create a fresh Python virtualenv using uv if available, else fallback to +/// `python -m venv`, and return the path to the python executable in the venv +pub fn create_python_venv(build_dir: PathBuf) -> PathBuf { + if let Some(uv_bin) = find_uv() { + let mut cmd = Command::new(&uv_bin); + cmd.arg("venv"); + cmd.arg("--clear"); + cmd.arg(&build_dir); + + run_command(cmd, "uv venv creation"); + } else { + let mut cmd = Command::new(find_python()); + cmd.arg("-m"); + cmd.arg("venv"); + cmd.arg(&build_dir); + + run_command(cmd, "python to create virtualenv with `venv`"); + + // update pip in case the system uses a very old one + let python = python_in_venv(&build_dir); + let mut cmd = Command::new(&python); + cmd.arg("-m"); + cmd.arg("pip"); + cmd.arg("install"); + cmd.arg("--upgrade"); + cmd.arg("pip"); + + run_command(cmd, "pip upgrade in virtualenv"); + } + + python_in_venv(&build_dir) +} + +#[derive(Default)] +pub struct PipInstallOptions { + pub upgrade: bool, + pub no_deps: bool, + pub no_build_isolation: bool, +} + +/// Install a package with pip (uses uv if present, else falls back to python) +fn pip_install( + python: &Path, + packages: &[&str], + options: PipInstallOptions, +) { + if let Some(uv_bin) = find_uv() { + let mut cmd = Command::new(&uv_bin); + cmd.arg("pip").arg("install").arg("--python").arg(python); + + // follow the same behavior as pip when there are multiple indexes + cmd.arg("--index-strategy"); + cmd.arg("unsafe-best-match"); + + if options.upgrade { + cmd.arg("--upgrade"); + } + if options.no_deps { + cmd.arg("--no-deps"); + } + if options.no_build_isolation { + cmd.arg("--no-build-isolation"); + // uv doesn't support --check-build-dependencies + } + + for package in packages { + cmd.arg(package); + } + + 
run_command(cmd, "uv pip install"); + } else { + let mut cmd = Command::new(python); + cmd.arg("-m").arg("pip").arg("install"); + if options.upgrade { + cmd.arg("--upgrade"); + } + if options.no_deps { + cmd.arg("--no-deps"); + } + if options.no_build_isolation { + // If pip, add both supported options + cmd.arg("--no-build-isolation"); + cmd.arg("--check-build-dependencies"); + } + + for package in packages { + cmd.arg(package); + } + + run_command(cmd, "pip install"); + } +} + +/// Download PyTorch in a Python virtualenv, and return the +/// CMAKE_PREFIX_PATH for the corresponding libtorch +pub fn setup_torch_pip(python: &Path) -> PathBuf { + let torch_version = std::env::var("METATOMIC_TESTS_TORCH_VERSION").unwrap_or("2.12".into()); + pip_install( + python, + &[&format!("torch=={}.*", torch_version)], + PipInstallOptions { upgrade: true, no_deps: false, no_build_isolation: false } + ); + + let mut cmd = Command::new(python); + cmd.arg("-c"); + cmd.arg("import torch; print(torch.utils.cmake_prefix_path)"); + + let output = run_command(cmd, "python to get torch cmake prefix"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let prefix = PathBuf::from(stdout.trim()); + if !prefix.exists() { + panic!("'torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + } + + return prefix; +} + +/// Install metatensor in a Python virtualenv with pip, and return the +/// CMAKE_PREFIX_PATH for the installed libmetatensor. 
+pub fn setup_metatensor_pip(python: &Path) -> PathBuf { + pip_install(python, &["metatensor-core >=0.2.0,<0.3"], PipInstallOptions::default()); + + let mut cmd = Command::new(python); + cmd.arg("-c"); + cmd.arg("import metatensor; print(metatensor.utils.cmake_prefix_path)"); + + let output = run_command(cmd, "python to get metatensor cmake prefix"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let prefix = PathBuf::from(stdout.trim()); + if !prefix.exists() { + panic!("'metatensor.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + } + + return prefix; +} + +/// Install metatensor-torch in a Python virtualenv with pip, and return the +/// CMAKE_PREFIX_PATH for the installed libmetatensor_torch. +pub fn setup_metatensor_torch_pip(python: &Path) -> PathBuf { + pip_install(python, &["metatensor-torch >=0.9.0,<0.10"], PipInstallOptions::default()); + + let mut cmd = Command::new(python); + cmd.arg("-c"); + cmd.arg("import metatensor.torch; print(metatensor.torch.utils.cmake_prefix_path)"); + + let output = run_command(cmd, "python to get metatensor_torch cmake prefix"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let prefix = PathBuf::from(stdout.trim()); + if !prefix.exists() { + panic!("'metatensor.torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + } + + return prefix; +} + +/// Build metatomic-torch located in `source_dir` inside `build_dir`, and return +/// the installation prefix. 
+pub fn setup_metatomic_torch_cmake(source_dir: &Path, build_dir: &Path, cmake_args: Vec) -> PathBuf { + std::fs::create_dir_all(build_dir).expect("failed to create metatomic build dir"); + + // configure cmake for metatomic-torch + let mut cmake_config = cmake_config(source_dir, build_dir); + + let install_prefix = build_dir.join("usr"); + cmake_config.arg(format!("-DCMAKE_INSTALL_PREFIX={}", install_prefix.display())); + + // Add any additional cmake arguments + for arg in cmake_args { + cmake_config.arg(arg); + } + + run_command(cmake_config, "cmake configuration for metatomic_torch"); + + // build and install metatomic-torch + let mut cmake_build = cmake_build(build_dir); + cmake_build.arg("--target"); + cmake_build.arg("install"); + + run_command(cmake_build, "cmake build for metatomic_torch"); + + install_prefix +} + + +/// Install metatomic-torch in a Python virtualenv with pip, and return the +/// CMAKE_PREFIX_PATH for the installed libmetatomic_torch. +pub fn setup_metatomic_torch_pip(python: &Path, source_dir: &Path) -> PathBuf { + pip_install(python, &["setuptools>=77", "packaging>=23", "cmake"], PipInstallOptions::default()); + + pip_install( + python, + &[&source_dir.display().to_string()], + PipInstallOptions { + upgrade: true, + no_deps: false, + no_build_isolation: true + } + ); + + let mut cmd = Command::new(python); + cmd.arg("-c"); + cmd.arg("import metatomic.torch; print(metatomic.torch.utils.cmake_prefix_path)"); + + let output = run_command(cmd, "python to get metatomic_torch cmake prefix"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let prefix = PathBuf::from(stdout.trim()); + if !prefix.exists() { + panic!("'metatomic.torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + } + + return prefix; +} + + +pub fn run_command(mut command: Command, context: &str) -> std::process::Output { + write!(std::io::stdout().lock(), "\n\n[Running] {:?}\n\n", command).unwrap(); + + let mut child = command + 
.stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn().unwrap_or_else(|_| panic!("failed to spawn {}", context)); + + let mut child_stdout = child.stdout.take().expect("missing stdout"); + let mut child_stderr = child.stderr.take().expect("missing stderr"); + + let out_handle = std::thread::spawn(move || -> std::io::Result<Vec<u8>> { + let mut buf = [0u8; 8192]; + let mut captured = Vec::new(); + let mut sink = std::io::stdout().lock(); + loop { + let n = child_stdout.read(&mut buf)?; + if n == 0 { + break; + } + sink.write_all(&buf[..n])?; + sink.flush()?; + captured.extend_from_slice(&buf[..n]); + } + Ok(captured) + }); + + let err_handle = std::thread::spawn(move || -> std::io::Result<Vec<u8>> { + let mut buf = [0u8; 8192]; + let mut captured = Vec::new(); + let mut sink = std::io::stderr().lock(); + loop { + let n = child_stderr.read(&mut buf)?; + if n == 0 { + break; + } + sink.write_all(&buf[..n])?; + sink.flush()?; + captured.extend_from_slice(&buf[..n]); + } + Ok(captured) + }); + + let status = child.wait().unwrap_or_else(|_| panic!("failed to run {}", context)); + let stdout = String::from_utf8_lossy(&out_handle.join().unwrap().unwrap()).into_owned(); + let stderr = String::from_utf8_lossy(&err_handle.join().unwrap().unwrap()).into_owned(); + + if !status.success() { + panic!( + "{} failed, status: {}\nstderr:\n\n{}\nstdout:\n\n{}\n", + context, status, stderr, stdout + ); + } + + return std::process::Output { status, stdout: stdout.into_bytes(), stderr: stderr.into_bytes() }; +} diff --git a/python/Cargo.toml b/python/Cargo.toml new file mode 100644 index 000000000..2ca54178e --- /dev/null +++ b/python/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "metatomic-python" +version = "0.0.0" +edition = "2021" +publish = false +rust-version = "1.74" + +[lib] +path = "lib.rs" + +[dev-dependencies] +which = "8" diff --git a/python/lib.rs b/python/lib.rs new file mode 100644 index 000000000..5ef74bad8 --- /dev/null +++ b/python/lib.rs @@ -0,0 +1 @@ +// empty lib.rs, this 
crate only exists to run Python tests with cargo diff --git a/python/tests/run-python-tests.rs b/python/tests/run-python-tests.rs new file mode 100644 index 000000000..8d52a6f83 --- /dev/null +++ b/python/tests/run-python-tests.rs @@ -0,0 +1,23 @@ +use std::path::PathBuf; +use std::process::Command; + +#[test] +fn run_python_tests() { + let tox = which::which("tox").expect("could not find tox"); + + let mut root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); + root.pop(); + + let mut tox = Command::new(tox); + tox.arg("--"); + if cfg!(debug_assertions) { + // assume that debug assertions means that we are building the code + // in debug mode, even if optimizations could be enabled + tox.env("METATOMIC_BUILD_TYPE", "debug"); + } else { + tox.env("METATOMIC_BUILD_TYPE", "release"); + } + tox.current_dir(&root); + let status = tox.status().expect("failed to run tox"); + assert!(status.success()); +} diff --git a/tox.ini b/tox.ini index 6eae4e866..8674eea74 100644 --- a/tox.ini +++ b/tox.ini @@ -6,8 +6,6 @@ requires = tox >=4.39 # `tox` in the command-line without anything else envlist = lint - torch-tests-cxx - torch-install-tests-cxx torch-tests docs-tests ase-tests @@ -45,75 +43,6 @@ metatensor_deps = metatensor-operations >=0.5.0,<0.6 -################################################################################ -##### C++ tests setup ##### -################################################################################ - -[testenv:torch-tests-cxx] -description = Run the C++ tests for metatomic-torch -deps = - cmake - {[testenv]metatensor_deps} - torch=={env:METATOMIC_TESTS_TORCH_VERSION:2.12}.* - -commands = - # configure cmake - cmake -B {env_dir}/build metatomic-torch \ - -DCMAKE_BUILD_TYPE=Debug \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ - -DCMAKE_PREFIX_PATH={env_site_packages_dir}/metatensor/;\ - {env_site_packages_dir}/torch/;\ - {env_site_packages_dir}/metatensor_torch/torch-{env:METATOMIC_TESTS_TORCH_VERSION:2.12}/ \ - 
-DMETATOMIC_TORCH_TESTS=ON - - # build code with cmake - cmake --build {env_dir}/build --config Debug --parallel - - # run all tests - ctest --test-dir {env_dir}/build --build-config Debug --output-on-failure - -[testenv:torch-install-tests-cxx] -description = Run the C++ tests for metatomic-torch -deps = - cmake - {[testenv]metatensor_deps} - torch=={env:METATOMIC_TESTS_TORCH_VERSION:2.12}.* - -commands = - # configure, build and install metatomic-torch - cmake -B {env_dir}/build-metatomic-torch metatomic-torch \ - -DCMAKE_BUILD_TYPE=Debug \ - -DCMAKE_INSTALL_PREFIX={env_dir}/usr/ \ - -DCMAKE_PREFIX_PATH={env_site_packages_dir}/metatensor/;\ - {env_site_packages_dir}/torch/;\ - {env_site_packages_dir}/metatensor_torch/torch-{env:METATOMIC_TESTS_TORCH_VERSION:2.12}/ \ - -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ - -DCMAKE_INSTALL_RPATH_USE_LINK_PATH=ON - cmake --build {env_dir}/build-metatomic-torch --config Debug --parallel --target install - - # try to use the installed metatomic-torch from another CMake project - cmake -B {env_dir}/build-find-package metatomic-torch/tests/cmake-project \ - -DCMAKE_BUILD_TYPE=Debug \ - -DCMAKE_PREFIX_PATH={env_site_packages_dir}/metatensor/;\ - {env_site_packages_dir}/torch/;\ - {env_site_packages_dir}/metatensor_torch/torch-{env:METATOMIC_TESTS_TORCH_VERSION:2.12}/;\ - {env_dir}/usr/ \ - -DUSE_CMAKE_SUBDIRECTORY=OFF - - cmake --build {env_dir}/build-find-package --config Debug --parallel - ctest --test-dir {env_dir}/build-find-package --build-config Debug --output-on-failure - - # Same, but using metatomic-torch as a CMake subdirectory - cmake -B {env_dir}/build-subdirectory metatomic-torch/tests/cmake-project \ - -DCMAKE_BUILD_TYPE=Debug \ - -DCMAKE_PREFIX_PATH={env_site_packages_dir}/metatensor/;\ - {env_site_packages_dir}/torch/;\ - {env_site_packages_dir}/metatensor_torch/torch-{env:METATOMIC_TESTS_TORCH_VERSION:2.12}/ \ - -DUSE_CMAKE_SUBDIRECTORY=ON - - cmake --build {env_dir}/build-subdirectory --config Debug --parallel - ctest 
--test-dir {env_dir}/build-subdirectory --build-config Debug --output-on-failure - ################################################################################ ##### Python tests setup ##### ################################################################################ From 63b9118ab8e36cf8595e25bebbb8e9d189bf0cd7 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Wed, 13 May 2026 14:45:08 +0200 Subject: [PATCH 04/20] Scaffold a new metatomic-core package --- .github/workflows/build-wheels.yml | 11 +- .github/workflows/torch-tests.yml | 7 + Cargo.toml | 1 + metatomic-core/CHANGELOG.md | 18 + metatomic-core/CMakeLists.txt | 506 ++++++++++++++++++ metatomic-core/Cargo.toml | 26 + metatomic-core/Clippy.toml | 1 + metatomic-core/build.rs | 48 ++ metatomic-core/cmake/dev-versions.cmake | 91 ++++ .../cmake/metatomic-config.in.cmake | 91 ++++ metatomic-core/cmake/tempdir.cmake | 51 ++ metatomic-core/include/metatomic.h | 32 ++ metatomic-core/include/metatomic.hpp | 2 + metatomic-core/include/metatomic/model.hpp | 7 + metatomic-core/include/metatomic/system.hpp | 7 + metatomic-core/src/c_api/mod.rs | 18 + metatomic-core/src/lib.rs | 13 + metatomic-core/tests/CMakeLists.txt | 86 +++ metatomic-core/tests/check-cxx-install.rs | 64 +++ .../tests/cmake-project/CMakeLists.txt | 84 +++ metatomic-core/tests/cmake-project/README.md | 3 + metatomic-core/tests/cmake-project/src/main.c | 8 + .../tests/cmake-project/src/main.cpp | 9 + .../tests/external/.gitattributes | 0 .../tests/external/CMakeLists.txt | 0 .../tests/external/catch/catch.cpp | 0 .../tests/external/catch/catch.hpp | 0 metatomic-core/tests/misc.cpp | 15 + metatomic-core/tests/run-cxx-tests.rs | 40 ++ metatomic-core/tests/utils/mod.rs | 470 ++++++++++++++++ metatomic-torch/tests/CMakeLists.txt | 3 +- metatomic-torch/tests/check-torch-install.rs | 10 +- metatomic-torch/tests/utils/mod.rs | 411 +------------- .../metatomic_torch/build-backend/backend.py | 17 +- 34 files changed, 1733 insertions(+), 417 
deletions(-) create mode 100644 metatomic-core/CHANGELOG.md create mode 100644 metatomic-core/CMakeLists.txt create mode 100644 metatomic-core/Cargo.toml create mode 100644 metatomic-core/Clippy.toml create mode 100644 metatomic-core/build.rs create mode 100644 metatomic-core/cmake/dev-versions.cmake create mode 100644 metatomic-core/cmake/metatomic-config.in.cmake create mode 100644 metatomic-core/cmake/tempdir.cmake create mode 100644 metatomic-core/include/metatomic.h create mode 100644 metatomic-core/include/metatomic.hpp create mode 100644 metatomic-core/include/metatomic/model.hpp create mode 100644 metatomic-core/include/metatomic/system.hpp create mode 100644 metatomic-core/src/c_api/mod.rs create mode 100644 metatomic-core/src/lib.rs create mode 100644 metatomic-core/tests/CMakeLists.txt create mode 100644 metatomic-core/tests/check-cxx-install.rs create mode 100644 metatomic-core/tests/cmake-project/CMakeLists.txt create mode 100644 metatomic-core/tests/cmake-project/README.md create mode 100644 metatomic-core/tests/cmake-project/src/main.c create mode 100644 metatomic-core/tests/cmake-project/src/main.cpp rename {metatomic-torch => metatomic-core}/tests/external/.gitattributes (100%) rename {metatomic-torch => metatomic-core}/tests/external/CMakeLists.txt (100%) rename {metatomic-torch => metatomic-core}/tests/external/catch/catch.cpp (100%) rename {metatomic-torch => metatomic-core}/tests/external/catch/catch.hpp (100%) create mode 100644 metatomic-core/tests/misc.cpp create mode 100644 metatomic-core/tests/run-cxx-tests.rs create mode 100644 metatomic-core/tests/utils/mod.rs mode change 100644 => 120000 metatomic-torch/tests/utils/mod.rs diff --git a/.github/workflows/build-wheels.yml b/.github/workflows/build-wheels.yml index b12d709ee..564e04409 100644 --- a/.github/workflows/build-wheels.yml +++ b/.github/workflows/build-wheels.yml @@ -100,8 +100,17 @@ jobs: CIBW_BUILD_VERBOSITY: 1 CIBW_MANYLINUX_X86_64_IMAGE: gcc11-manylinux_2_28_x86_64 
CIBW_MANYLINUX_AARCH64_IMAGE: gcc11-manylinux_2_28_aarch64 + # METATOMIC_NO_LOCAL_DEPS is set to 1 when building a tag of + # metatomic-torch, which will force to use the version of + # metatomic-core already released on PyPI. Otherwise, this will use + # the version of metatomic-core from git checkout (in case there are + # unreleased breaking changes). + # + # This means that when releasing a breaking change in metatomic-core, + # the full release should be available on PyPI before pushing the new + # metatomic-torch tag. CIBW_ENVIRONMENT: > - METATOMIC_NO_LOCAL_DEPS=1 + METATOMIC_NO_LOCAL_DEPS=${{ startsWith(github.ref, 'refs/tags/metatomic-torch-v') && '1' || '0' }} METATOMIC_TORCH_BUILD_WITH_TORCH_VERSION=${{ matrix.torch-version }}.* PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu MACOSX_DEPLOYMENT_TARGET=11 diff --git a/.github/workflows/torch-tests.yml b/.github/workflows/torch-tests.yml index 62ed6025c..ce3781709 100644 --- a/.github/workflows/torch-tests.yml +++ b/.github/workflows/torch-tests.yml @@ -55,6 +55,12 @@ jobs: with: fetch-depth: 0 + - name: setup Python + uses: actions/setup-python@v6 + if: matrix.container == null + with: + python-version: ${{ matrix.python-version }} + - name: Configure git safe directory if: matrix.container == 'ubuntu:22.04' run: git config --global --add safe.directory /__w/metatomic/metatomic @@ -95,3 +101,4 @@ jobs: PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu METATOMIC_TESTS_TORCH_VERSION: ${{ matrix.torch-version }} CXXFLAGS: ${{ matrix.cxx-flags }} + RUST_BACKTRACE: full diff --git a/Cargo.toml b/Cargo.toml index 5256b9601..1a233774c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ + "metatomic-core", "metatomic-torch", "python", ] diff --git a/metatomic-core/CHANGELOG.md b/metatomic-core/CHANGELOG.md new file mode 100644 index 000000000..160995db2 --- /dev/null +++ b/metatomic-core/CHANGELOG.md @@ -0,0 +1,18 @@ +# Changelog + +All notable changes to 
metatomic-core are documented here, following the [keep +a changelog](https://keepachangelog.com/en/1.1.0/) format. This project follows +[Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased](https://github.com/metatensor/metatomic/) + + diff --git a/metatomic-core/CMakeLists.txt b/metatomic-core/CMakeLists.txt new file mode 100644 index 000000000..717e52e81 --- /dev/null +++ b/metatomic-core/CMakeLists.txt @@ -0,0 +1,506 @@ +# This file defines the CMake build system for the C and C++ API of metatomic. +# +# This API is implemented in Rust, in the metatomic-core crate, but Rust users +# of the API should use the metatomic crate instead, wrapping metatomic-core in +# an easier to use, idiomatic Rust API. +cmake_minimum_required(VERSION 3.22) + +# Is metatomic the main project configured by the user? Or is this being used +# as a submodule/subdirectory? +if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR}) + set(METATOMIC_MAIN_PROJECT ON) +else() + set(METATOMIC_MAIN_PROJECT OFF) +endif() + +if(${METATOMIC_MAIN_PROJECT} AND NOT "${CACHED_LAST_CMAKE_VERSION}" VERSION_EQUAL ${CMAKE_VERSION}) + # We use CACHED_LAST_CMAKE_VERSION to only print the cmake version + # once in the configuration log + set(CACHED_LAST_CMAKE_VERSION ${CMAKE_VERSION} CACHE INTERNAL "Last version of cmake used to configure") + message(STATUS "Running CMake version ${CMAKE_VERSION}") +endif() + +if (POLICY CMP0077) + # use variables to set OPTIONS + cmake_policy(SET CMP0077 NEW) +endif() + +file(STRINGS "Cargo.toml" CARGO_TOML_CONTENT) +foreach(line ${CARGO_TOML_CONTENT}) + string(REGEX REPLACE "^version = \"(.*)\"" "\\1" METATOMIC_VERSION ${line}) + if (NOT ${CMAKE_MATCH_COUNT} EQUAL 0) + # stop on the first regex match, this should be the right version + break() + endif() +endforeach() + +include(cmake/dev-versions.cmake) +create_development_version("${METATOMIC_VERSION}" METATOMIC_FULL_VERSION "metatomic-core-v") +message(STATUS "Building metatomic-core 
v${METATOMIC_FULL_VERSION}") + +# strip any -dev/-rc suffix on the version since project(VERSION) does not support it +string(REGEX REPLACE "([0-9]*)\\.([0-9]*)\\.([0-9]*).*" "\\1.\\2.\\3" METATOMIC_VERSION ${METATOMIC_FULL_VERSION}) +project(metatomic + VERSION ${METATOMIC_VERSION} + LANGUAGES C CXX # we need to declare a language to access CMAKE_SIZEOF_VOID_P later +) +set(PROJECT_VERSION ${METATOMIC_FULL_VERSION}) + + +# We follow the standard CMake convention of using BUILD_SHARED_LIBS to provide +# either a shared or static library as a default target. But since cargo always +# builds both versions by default, we also install both versions by default. +# `METATOMIC_INSTALL_BOTH_STATIC_SHARED=OFF` allow to disable this behavior, and +# only install the file corresponding to `BUILD_SHARED_LIBS=ON/OFF`. +# +# BUILD_SHARED_LIBS controls the `metatomic` cmake target, making it an alias of +# either `metatomic::static` or `metatomic::shared`. This is mainly relevant +# when using metatomic from another cmake project, either as a submodule or from +# an installed library (see cmake/metatomic-config.cmake) +option(BUILD_SHARED_LIBS "Use a shared library by default instead of a static one" ON) +option(METATOMIC_INSTALL_BOTH_STATIC_SHARED "Install both shared and static libraries" ON) + +set(RUST_BUILD_TARGET "${RUST_BUILD_TARGET}" CACHE STRING "Cross-compilation target for rust code. 
Leave empty to build for the host") +set(EXTRA_RUST_FLAGS "${EXTRA_RUST_FLAGS}" CACHE STRING "Flags used to build rust code") + +include(GNUInstallDirs) + +if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "") + message(STATUS "Setting build type to 'release' as none was specified.") + set(CMAKE_BUILD_TYPE "release" + CACHE STRING + "Choose the type of build, options are: debug or release" + FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug) +endif() + +if(${METATOMIC_MAIN_PROJECT} AND NOT "${CACHED_LAST_CMAKE_BUILD_TYPE}" STREQUAL "${CMAKE_BUILD_TYPE}") + set(CACHED_LAST_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE} CACHE INTERNAL "Last build type used in configuration") + message(STATUS "Building metatomic in ${CMAKE_BUILD_TYPE} mode") +endif() + + +function(check_compatible_versions _actual_ _requested_) + if(${_actual_} MATCHES "^([0-9]+)\\.([0-9]+)") + set(_actual_major_ "${CMAKE_MATCH_1}") + set(_actual_minor_ "${CMAKE_MATCH_2}") + else() + message(FATAL_ERROR "Failed to parse actual version: ${_actual_}") + endif() + + if(${_requested_} MATCHES "^([0-9]+)\\.([0-9]+)") + set(_requested_major_ "${CMAKE_MATCH_1}") + set(_requested_minor_ "${CMAKE_MATCH_2}") + else() + message(FATAL_ERROR "Failed to parse requested version: ${_requested_}") + endif() + + if (${_requested_major_} EQUAL 0 AND ${_actual_minor_} EQUAL ${_requested_minor_}) + # major version is 0 and same minor version, everything is fine + elseif (${_actual_major_} EQUAL ${_requested_major_}) + # same major version, everything is fine + else() + # not compatible + message(FATAL_ERROR "Incompatible versions: we need ${_requested_}, but we got ${_actual_}") + endif() +endfunction() + + +set(REQUIRED_METATENSOR_VERSION "0.2.0") +# Either metatensor is built as part of the same CMake project, or we try to +# find the corresponding CMake package +if (TARGET metatensor) + get_target_property(METATENSOR_BUILD_VERSION metatensor BUILD_VERSION) + 
check_compatible_versions(${METATENSOR_BUILD_VERSION} ${REQUIRED_METATENSOR_VERSION}) +else() + find_package(metatensor ${REQUIRED_METATENSOR_VERSION} CONFIG REQUIRED) +endif() + + +find_program(CARGO_EXE "cargo" DOC "path to cargo (Rust build system)") +if (NOT CARGO_EXE) + message(FATAL_ERROR + "could not find cargo, please make sure the Rust compiler is installed \ + (see https://www.rust-lang.org/tools/install) or set CARGO_EXE" + ) +endif() + +execute_process( + COMMAND ${CARGO_EXE} "--version" "--verbose" + RESULT_VARIABLE CARGO_STATUS + OUTPUT_VARIABLE CARGO_VERSION_RAW +) + +if(CARGO_STATUS AND NOT CARGO_STATUS EQUAL 0) + message(FATAL_ERROR + "could not run cargo, please make sure the Rust compiler is installed \ + (see https://www.rust-lang.org/tools/install)" + ) +endif() + +set(REQUIRED_RUST_VERSION "1.74.0") +if (CARGO_VERSION_RAW MATCHES "cargo ([0-9]+\\.[0-9]+\\.[0-9]+).*") + set(CARGO_VERSION "${CMAKE_MATCH_1}") +else() + message(FATAL_ERROR "failed to determine cargo version, output was: ${CARGO_VERSION_RAW}") +endif() + +if (${CARGO_VERSION} VERSION_LESS ${REQUIRED_RUST_VERSION}) + message(FATAL_ERROR + "your Rust installation is too old (you have version ${CARGO_VERSION}), \ + at least ${REQUIRED_RUST_VERSION} is required" + ) +else() + if(NOT "${CACHED_LAST_CARGO_VERSION}" STREQUAL ${CARGO_VERSION}) + set(CACHED_LAST_CARGO_VERSION ${CARGO_VERSION} CACHE INTERNAL "Last version of cargo used in configuration") + message(STATUS "Using cargo version ${CARGO_VERSION} at ${CARGO_EXE}") + set(CARGO_VERSION_CHANGED TRUE) + endif() +endif() + +# ============================================================================ # +# determine Cargo flags + +set(CARGO_BUILD_ARG "") + +if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/Cargo.lock) + set(CARGO_BUILD_ARG "${CARGO_BUILD_ARG};--locked") +endif() + +# TODO: support multiple configuration generators (MSVC, ...) 
+string(TOLOWER ${CMAKE_BUILD_TYPE} BUILD_TYPE) +if ("${BUILD_TYPE}" STREQUAL "debug") + set(CARGO_BUILD_TYPE "debug") +elseif("${BUILD_TYPE}" STREQUAL "release") + set(CARGO_BUILD_ARG "${CARGO_BUILD_ARG};--release") + set(CARGO_BUILD_TYPE "release") +elseif("${BUILD_TYPE}" STREQUAL "relwithdebinfo") + set(CARGO_BUILD_ARG "${CARGO_BUILD_ARG};--release") + set(CARGO_BUILD_TYPE "release") +else() + message(FATAL_ERROR "unsupported build type: ${CMAKE_BUILD_TYPE}") +endif() + +set(CARGO_TARGET_DIR ${CMAKE_CURRENT_BINARY_DIR}/target) +set(CARGO_BUILD_ARG "${CARGO_BUILD_ARG};--target-dir=${CARGO_TARGET_DIR}") + +if (CARGO_VERSION_RAW MATCHES "host: ([a-zA-Z0-9_\\-]*)\n") + set(RUST_HOST_TARGET "${CMAKE_MATCH_1}") + if (RUST_HOST_TARGET MATCHES "([a-zA-Z0-9_]*)\\-") + set(RUST_HOST_ARCH "${CMAKE_MATCH_1}") + else() + message(FATAL_ERROR "failed to determine host CPU arch, target was: ${RUST_HOST_TARGET}") + endif() +else() + message(FATAL_ERROR "failed to determine host target, output was: ${CARGO_VERSION_RAW}") +endif() + +if (WIN32) + # on Windows, we need to use the same ABI in both CMake and cargo. If the + # user did not explicitly request a target, we can try to set it ourselves, + # otherwise we just check that it matches what we expect. 
+ if (MSVC) + if ("${RUST_BUILD_TARGET}" STREQUAL "") + set(RUST_BUILD_TARGET "${RUST_HOST_ARCH}-pc-windows-msvc") + message(STATUS "Setting rust target to ${RUST_BUILD_TARGET}") + elseif(NOT "${RUST_BUILD_TARGET}" MATCHES "-pc-windows-msvc") + message(FATAL_ERROR "CMake is building with MSVC but the Rust target is ${RUST_BUILD_TARGET}") + endif() + endif() + + if (MINGW) + if ("${RUST_BUILD_TARGET}" STREQUAL "") + set(RUST_BUILD_TARGET "${RUST_HOST_ARCH}-pc-windows-gnu") + message(STATUS "Setting rust target to ${RUST_BUILD_TARGET}") + elseif(NOT "${RUST_BUILD_TARGET}" MATCHES "-pc-windows-gnu") + message(FATAL_ERROR "CMake is building with MinGW but the Rust target is ${RUST_BUILD_TARGET}") + endif() + endif() +endif() + +# Handle cross compilation with RUST_BUILD_TARGET +if ("${RUST_BUILD_TARGET}" STREQUAL "") + if (${METATOMIC_MAIN_PROJECT}) + message(STATUS "Compiling to host (${RUST_HOST_TARGET})") + endif() + + set(CARGO_OUTPUT_DIR "${CARGO_TARGET_DIR}/${CARGO_BUILD_TYPE}") + set(RUST_BUILD_TARGET ${RUST_HOST_TARGET}) +else() + if (${METATOMIC_MAIN_PROJECT}) + message(STATUS "Cross-compiling to ${RUST_BUILD_TARGET}") + endif() + + set(CARGO_BUILD_ARG "${CARGO_BUILD_ARG};--target=${RUST_BUILD_TARGET}") + set(CARGO_OUTPUT_DIR "${CARGO_TARGET_DIR}/${RUST_BUILD_TARGET}/${CARGO_BUILD_TYPE}") +endif() + +# Get the list of libraries linked by default by cargo/rustc to add when linking +# to metatomic::static +if (CARGO_VERSION_CHANGED) + include(cmake/tempdir.cmake) + get_tempdir(TMPDIR) + + # Adapted from https://github.com/corrosion-rs/corrosion/blob/dc1e4e5/cmake/FindRust.cmake + execute_process( + COMMAND "${CARGO_EXE}" new --lib _cargo_required_libs + WORKING_DIRECTORY "${TMPDIR}" + RESULT_VARIABLE cargo_new_result + ERROR_QUIET + ) + + if (cargo_new_result) + message(FATAL_ERROR "could not create empty project to find default static libs: ${cargo_new_result}") + endif() + + file(APPEND "${TMPDIR}/_cargo_required_libs/Cargo.toml" 
"[lib]\ncrate-type=[\"staticlib\"]") + + execute_process( + COMMAND ${CARGO_EXE} rustc --color never --target=${RUST_BUILD_TARGET} -- --print=native-static-libs + WORKING_DIRECTORY "${TMPDIR}/_cargo_required_libs" + RESULT_VARIABLE cargo_static_libs_result + ERROR_VARIABLE cargo_static_libs_stderr + ) + + # clean up the files + file(REMOVE_RECURSE "${TMPDIR}") + + if (cargo_static_libs_result) + message(FATAL_ERROR + "could not extract default static libs (status ${cargo_static_libs_result}), stderr:\n${cargo_static_libs_stderr}" + ) + endif() + + # The pattern starts with `native-static-libs:` and goes to the end of the line. + if (cargo_static_libs_stderr MATCHES "native-static-libs: ([^\r\n]+)\r?\n") + string(REPLACE " " ";" "libs_list" "${CMAKE_MATCH_1}") + set(stripped_lib_list "") + foreach(lib ${libs_list}) + # Strip leading `-l` (unix) and potential .lib suffix (windows) + string(REGEX REPLACE "^-l" "" "stripped_lib" "${lib}") + string(REGEX REPLACE "\.lib$" "" "stripped_lib" "${stripped_lib}") + list(APPEND stripped_lib_list "${stripped_lib}") + endforeach() + + # Special case `msvcrt` to link with the debug version in Debug mode. 
+ list(TRANSFORM stripped_lib_list REPLACE "^msvcrt$" "\$<\$:msvcrtd>") + # Don't try to pass a linker *flag* where CMake expects libraries + list(REMOVE_ITEM stripped_lib_list "/defaultlib:msvcrt") + + if (APPLE) + # Prevent warnings about duplicated `System` in linked libraries + # from Apple's `ld` + list(REMOVE_ITEM stripped_lib_list "System") + endif() + + list(REMOVE_DUPLICATES stripped_lib_list) + set(CARGO_DEFAULT_LIBRARIES "${stripped_lib_list}" CACHE INTERNAL "list of implicitly linked libraries") + + if (${METATOMIC_MAIN_PROJECT}) + message(STATUS "Cargo default link libraries are: ${CARGO_DEFAULT_LIBRARIES}") + endif() + else() + message(FATAL_ERROR "could not find default static libs: `native-static-libs` not found in: `${cargo_static_libs_stderr}`") + endif() +endif() + +file(GLOB_RECURSE ALL_RUST_SOURCES + ${PROJECT_SOURCE_DIR}/Cargo.toml + ${PROJECT_SOURCE_DIR}/src/**.rs +) + +add_library(metatomic::shared SHARED IMPORTED GLOBAL) +set(METATOMIC_SHARED_LOCATION "${CARGO_OUTPUT_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}metatomic${CMAKE_SHARED_LIBRARY_SUFFIX}") +set(METATOMIC_IMPLIB_LOCATION "${METATOMIC_SHARED_LOCATION}.lib") + +if (MINGW) + # `rustc` does not follow the usual naming scheme for DLL with mingw (it + # would typically be 'libmetatomic.dll') + set(METATOMIC_SHARED_LOCATION "${CARGO_OUTPUT_DIR}/metatomic.dll") + set(METATOMIC_IMPLIB_LOCATION "${CARGO_OUTPUT_DIR}/libmetatomic.dll.a") +endif() + +add_library(metatomic::static STATIC IMPORTED GLOBAL) +set(METATOMIC_STATIC_LOCATION "${CARGO_OUTPUT_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}metatomic${CMAKE_STATIC_LIBRARY_SUFFIX}") + +get_filename_component(METATOMIC_SHARED_LIB_NAME ${METATOMIC_SHARED_LOCATION} NAME) +get_filename_component(METATOMIC_IMPLIB_NAME ${METATOMIC_IMPLIB_LOCATION} NAME) +get_filename_component(METATOMIC_STATIC_LIB_NAME ${METATOMIC_STATIC_LOCATION} NAME) + +# We need to add some metadata to the shared library to enable linking to it +# without using an absolute path. 
# Extra arguments passed to rustc (after `--`): on Unix they embed the library
# name in the shared library itself, and the imported target is given the
# matching IMPORTED_SONAME.
if (UNIX)
    if (APPLE)
        # set the install name to `@rpath/libmetatomic.dylib`
        set(CARGO_RUSTC_ARGS "-Clink-arg=-Wl,-install_name,@rpath/${METATOMIC_SHARED_LIB_NAME}")
        set_target_properties(metatomic::shared PROPERTIES
            IMPORTED_SONAME @rpath/${METATOMIC_SHARED_LIB_NAME}
        )
    else() # LINUX
        # set the SONAME to libmetatomic.so
        set(CARGO_RUSTC_ARGS "-Clink-arg=-Wl,-soname,${METATOMIC_SHARED_LIB_NAME}")
        set_target_properties(metatomic::shared PROPERTIES
            IMPORTED_SONAME ${METATOMIC_SHARED_LIB_NAME}
        )
    endif()
else()
    set(CARGO_RUSTC_ARGS "")
endif()

if (NOT "${EXTRA_RUST_FLAGS}" STREQUAL "")
    set(CARGO_RUSTC_ARGS "${CARGO_RUSTC_ARGS};${EXTRA_RUST_FLAGS}")
endif()

# Set environment variables for cargo build
set(CARGO_ENV "METATOMIC_FULL_VERSION=${METATOMIC_FULL_VERSION}")
if (NOT "${CMAKE_OSX_DEPLOYMENT_TARGET}" STREQUAL "")
    list(APPEND CARGO_ENV "MACOSX_DEPLOYMENT_TARGET=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
if (NOT "$ENV{RUSTC_WRAPPER}" STREQUAL "")
    list(APPEND CARGO_ENV "RUSTC_WRAPPER=$ENV{RUSTC_WRAPPER}")
endif()

# Select the crate types to build, and record the files cargo will produce so
# that the custom command below can declare them as its OUTPUT.
if (METATOMIC_INSTALL_BOTH_STATIC_SHARED)
    set(CARGO_BUILD_ARG "${CARGO_BUILD_ARG};--crate-type=cdylib;--crate-type=staticlib")
    set(CARGO_OUTPUTS ${METATOMIC_SHARED_LOCATION} ${METATOMIC_STATIC_LOCATION})
    if (WIN32)
        list(APPEND CARGO_OUTPUTS ${METATOMIC_IMPLIB_LOCATION})
        set(FILE_CREATED_MESSAGE "${METATOMIC_SHARED_LIB_NAME}, ${METATOMIC_STATIC_LIB_NAME}, and ${METATOMIC_IMPLIB_NAME}")
    else()
        set(FILE_CREATED_MESSAGE "${METATOMIC_SHARED_LIB_NAME} and ${METATOMIC_STATIC_LIB_NAME}")
    endif()
else()
    if (BUILD_SHARED_LIBS)
        set(CARGO_BUILD_ARG "${CARGO_BUILD_ARG};--crate-type=cdylib")
        set(CARGO_OUTPUTS ${METATOMIC_SHARED_LOCATION})
        if (WIN32)
            list(APPEND CARGO_OUTPUTS ${METATOMIC_IMPLIB_LOCATION})
            set(FILE_CREATED_MESSAGE "${METATOMIC_SHARED_LIB_NAME} and ${METATOMIC_IMPLIB_NAME}")
        else()
            set(FILE_CREATED_MESSAGE "${METATOMIC_SHARED_LIB_NAME}")
        endif()
    else()
        set(CARGO_BUILD_ARG "${CARGO_BUILD_ARG};--crate-type=staticlib")
        set(CARGO_OUTPUTS ${METATOMIC_STATIC_LOCATION})
        set(FILE_CREATED_MESSAGE "${METATOMIC_STATIC_LIB_NAME}")
    endif()
endif()

# Run cargo through `cmake -E env` so that CARGO_ENV is set for the build.
# Use the same cargo executable as the configure-time checks instead of
# whatever `cargo` happens to be first in PATH at build time.
# NOTE(review): assumes CARGO_EXE is set unconditionally earlier in this file
# (it is used by the native-static-libs probe) -- confirm.
add_custom_command(
    OUTPUT ${CARGO_OUTPUTS}
    COMMAND ${CMAKE_COMMAND} -E env ${CARGO_ENV}
        "${CARGO_EXE}" rustc ${CARGO_BUILD_ARG} -- ${CARGO_RUSTC_ARGS}
    WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
    DEPENDS ${ALL_RUST_SOURCES}
    COMMENT "Building ${FILE_CREATED_MESSAGE} with cargo"
    VERBATIM
)
add_custom_target(cargo-build-metatomic ALL DEPENDS ${CARGO_OUTPUTS})

# Auto-generate a header containing the version number as #define
set(_path_ "${CMAKE_CURRENT_BINARY_DIR}/generated-version.h")
file(WRITE ${_path_} "#pragma once\n\n")
file(APPEND ${_path_} "/** Full version of metatomic as a string */\n")
file(APPEND ${_path_} "#define METATOMIC_VERSION \"${METATOMIC_FULL_VERSION}\"\n\n")
file(APPEND ${_path_} "/** Major version number of metatomic as an integer */\n")
file(APPEND ${_path_} "#define METATOMIC_VERSION_MAJOR ${PROJECT_VERSION_MAJOR}\n\n")
file(APPEND ${_path_} "/** Minor version number of metatomic as an integer */\n")
file(APPEND ${_path_} "#define METATOMIC_VERSION_MINOR ${PROJECT_VERSION_MINOR}\n\n")
file(APPEND ${_path_} "/** Patch version number of metatomic as an integer */\n")
file(APPEND ${_path_} "#define METATOMIC_VERSION_PATCH ${PROJECT_VERSION_PATCH}\n")

# Only touch the header seen by the compiler when its content changed, to
# avoid needless rebuilds. The directory is created under the same root as the
# copy destination so that the two can never disagree.
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/include/metatomic")
set(_destination_ "${CMAKE_CURRENT_BINARY_DIR}/include/metatomic/version.h")
file(COPY_FILE ${_path_} ${_destination_} ONLY_IF_DIFFERENT)

# Anything linking to the imported targets must wait for cargo to run
add_dependencies(metatomic::shared cargo-build-metatomic)
add_dependencies(metatomic::static cargo-build-metatomic)

set_target_properties(metatomic::shared PROPERTIES
    IMPORTED_LOCATION ${METATOMIC_SHARED_LOCATION}
    INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_BINARY_DIR}/include"
    BUILD_VERSION "${METATOMIC_FULL_VERSION}"
)
# Users of the imported targets need at least C++17
target_compile_features(metatomic::shared INTERFACE cxx_std_17)

if (WIN32)
    # On Windows, consumers link against the import library, not the DLL
    set_target_properties(metatomic::shared PROPERTIES
        IMPORTED_IMPLIB ${METATOMIC_IMPLIB_LOCATION}
    )
endif()

# The static library also forwards the libraries rustc links implicitly
# (CARGO_DEFAULT_LIBRARIES) to anything that links against it.
set_target_properties(metatomic::static PROPERTIES
    IMPORTED_LOCATION ${METATOMIC_STATIC_LOCATION}
    INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_BINARY_DIR}/include"
    INTERFACE_LINK_LIBRARIES "${CARGO_DEFAULT_LIBRARIES}"
    BUILD_VERSION "${METATOMIC_FULL_VERSION}"
)
target_compile_features(metatomic::static INTERFACE cxx_std_17)

# Link each flavor of metatomic to the matching flavor of metatensor when that
# target exists, and to the generic `metatensor` target otherwise.
if (TARGET metatensor::static)
    target_link_libraries(metatomic::static INTERFACE metatensor::static)
else()
    target_link_libraries(metatomic::static INTERFACE metatensor)
endif()

if (TARGET metatensor::shared)
    target_link_libraries(metatomic::shared INTERFACE metatensor::shared)
else()
    target_link_libraries(metatomic::shared INTERFACE metatensor)
endif()


# `metatomic` is an alias to the flavor selected with BUILD_SHARED_LIBS
if (BUILD_SHARED_LIBS)
    add_library(metatomic ALIAS metatomic::shared)
else()
    add_library(metatomic ALIAS metatomic::static)
endif()

#------------------------------------------------------------------------------#
# Installation configuration
#------------------------------------------------------------------------------#
include(CMakePackageConfigHelpers)
configure_package_config_file(
    ${PROJECT_SOURCE_DIR}/cmake/metatomic-config.in.cmake
    ${PROJECT_BINARY_DIR}/metatomic-config.cmake
    INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/metatomic
)
write_basic_package_version_file(
    metatomic-config-version.cmake
    VERSION ${METATOMIC_FULL_VERSION}
    COMPATIBILITY SameMinorVersion
)

# Headers: the ones from the source tree, plus the generated version.h
install(FILES "include/metatomic.h" DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
install(FILES "include/metatomic.hpp" DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
install(DIRECTORY "include/metatomic" DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/include/metatomic/version.h" DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metatomic)

if (METATOMIC_INSTALL_BOTH_STATIC_SHARED OR BUILD_SHARED_LIBS)
    if (WIN32)
        # DLL files should go in <prefix>/bin
        install(
            FILES ${METATOMIC_SHARED_LOCATION}
            DESTINATION ${CMAKE_INSTALL_BINDIR}
            PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ GROUP_EXECUTE GROUP_READ WORLD_READ WORLD_EXECUTE
        )
        # .lib files should go in <prefix>/lib
        install(FILES ${METATOMIC_IMPLIB_LOCATION} DESTINATION ${CMAKE_INSTALL_LIBDIR})
    else()
        install(
            FILES ${METATOMIC_SHARED_LOCATION}
            DESTINATION ${CMAKE_INSTALL_LIBDIR}
            PERMISSIONS OWNER_EXECUTE OWNER_WRITE OWNER_READ GROUP_EXECUTE GROUP_READ WORLD_READ WORLD_EXECUTE
        )
    endif()
endif()

if (METATOMIC_INSTALL_BOTH_STATIC_SHARED OR NOT BUILD_SHARED_LIBS)
    install(FILES ${METATOMIC_STATIC_LOCATION} DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()

# CMake package files, used by `find_package(metatomic)`
install(FILES
    ${PROJECT_BINARY_DIR}/metatomic-config-version.cmake
    ${PROJECT_BINARY_DIR}/metatomic-config.cmake
    DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/metatomic
)
+#![allow(clippy::field_reassign_with_default)] + +use std::path::PathBuf; + +fn main() { + let crate_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); + + let generated_comment = "\ +/* ============ Automatically generated file, DO NOT EDIT. ============== * + * * + * This file is automatically generated from the metatomic sources, * + * using cbindgen. If you want to change this file (including documentation), * + * make the corresponding changes in the rust sources and regenerate it. * + * ============================================================================= */"; + + let mut config: cbindgen::Config = Default::default(); + config.language = cbindgen::Language::C; + config.cpp_compat = true; + config.include_guard = Some("METATOMIC_H".into()); + config.include_version = false; + config.documentation = true; + config.documentation_style = cbindgen::DocumentationStyle::Doxy; + config.line_endings = cbindgen::LineEndingStyle::LF; + config.autogen_warning = Some(generated_comment.into()); + config.includes.push("metatomic/version.h".into()); + + let result = cbindgen::Builder::new() + .with_crate(crate_dir) + .with_config(config) + .generate() + .map(|data| { + let mut path = PathBuf::from("include"); + path.push("metatomic.h"); + data.write_to_file(&path); + }); + + // if not ok, rerun the build script unconditionally + if result.is_ok() { + println!("cargo:rerun-if-changed=src"); + println!("cargo:rerun-if-changed=build.rs"); + } + + if std::env::var("METATOMIC_FULL_VERSION").is_err() { + let version = std::env::var("CARGO_PKG_VERSION").expect("missing CARGO_PKG_VERSION"); + println!("cargo:rustc-env=METATOMIC_FULL_VERSION={}+rust", version); + } + println!("cargo:rerun-if-env-changed=METATOMIC_FULL_VERSION"); +} diff --git a/metatomic-core/cmake/dev-versions.cmake b/metatomic-core/cmake/dev-versions.cmake new file mode 100644 index 000000000..543296493 --- /dev/null +++ b/metatomic-core/cmake/dev-versions.cmake @@ -0,0 +1,91 @@ +# Parse a `_version_` 
# Parse a `_version_` number, and store its components in `_major_` `_minor_`
# `_patch_` and `_rc_`
#
# Accepted forms are `X.Y.Z` (in which case `_rc_` is set to an empty string)
# and `X.Y.Z-rcN`. Anything else is a fatal error.
function(parse_version _version_ _major_ _minor_ _patch_ _rc_)
    string(REGEX MATCH "([0-9]+)\\.([0-9]+)\\.([0-9]+)(-rc)?([0-9]+)?" _ "${_version_}")

    # CMAKE_MATCH_COUNT is 3 when only X.Y.Z matched, and 5 when both the
    # `-rc` marker and its number matched
    if(${CMAKE_MATCH_COUNT} EQUAL 3)
        set(${_rc_} "" PARENT_SCOPE)
    elseif(${CMAKE_MATCH_COUNT} EQUAL 5)
        set(${_rc_} ${CMAKE_MATCH_5} PARENT_SCOPE)
    else()
        message(FATAL_ERROR "invalid version string ${_version_}")
    endif()

    set(${_major_} ${CMAKE_MATCH_1} PARENT_SCOPE)
    set(${_minor_} ${CMAKE_MATCH_2} PARENT_SCOPE)
    set(${_patch_} ${CMAKE_MATCH_3} PARENT_SCOPE)
endfunction()

# Get the number of commits since the last tag/release starting with
# `_tag_prefix_` (stored in `_output_n_commits_`), and a hash of the latest
# commit/full state of a dirty repository (stored in `_output_git_hash_`).
#
# If the helper script fails, a warning is printed and the outputs are set to
# `0` and an empty string, i.e. the build is treated as a release.
function(git_version_info _tag_prefix_ _output_n_commits_ _output_git_hash_)
    set(_script_ "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/../../scripts/git-version-info.py")

    if (EXISTS "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/git_version_info")
        # When building from a tarball, the script is executed and the result
        # put in this file
        file(STRINGS "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/git_version_info" _file_content_)
        list(GET _file_content_ 0 _n_commits_)
        list(GET _file_content_ 1 _git_hash_)

    elseif (EXISTS "${_script_}")
        # When building from a checkout, we'll need to run the script
        find_package(Python COMPONENTS Interpreter REQUIRED)
        execute_process(
            COMMAND "${Python_EXECUTABLE}" "${_script_}" "${_tag_prefix_}"
            RESULT_VARIABLE _status_
            OUTPUT_VARIABLE _stdout_
            ERROR_VARIABLE _stderr_
            WORKING_DIRECTORY ${CMAKE_CURRENT_FUNCTION_LIST_DIR}
        )

        if (NOT "${_status_}" STREQUAL "0")
            message(WARNING
                "git-version-info.py failed, version number might be wrong:\nstdout: ${_stdout_}\nstderr: ${_stderr_}")
            # Set both outputs: callers compare the number of commits with
            # "0" and would otherwise see undefined variables.
            set(${_output_n_commits_} 0 PARENT_SCOPE)
            set(${_output_git_hash_} "" PARENT_SCOPE)
            return()
        endif()

        if (NOT "${_stderr_}" STREQUAL "")
            message(WARNING "git-version-info.py gave some errors, version number might be wrong:\nstdout: ${_stdout_}\nstderr: ${_stderr_}")
        endif()

        # first line: number of commits, second line: hash
        string(REPLACE "\n" ";" _lines_ "${_stdout_}")
        list(GET _lines_ 0 _n_commits_)
        list(GET _lines_ 1 _git_hash_)
    else()
        message(FATAL_ERROR "could not update git version information")
    endif()

    string(STRIP "${_n_commits_}" _n_commits_)
    set(${_output_n_commits_} "${_n_commits_}" PARENT_SCOPE)

    string(STRIP "${_git_hash_}" _git_hash_)
    set(${_output_git_hash_} "${_git_hash_}" PARENT_SCOPE)
endfunction()


# Take the version declared in the package, and increase the right number if we
# are actually installing a development version from after the latest git tag
#
# - release (0 commits since the tag): `X.Y.Z` or `X.Y.Z-rcN`, unchanged
# - development, no rc: `X.(Y+1).0-dev<n>+<hash>`
# - development, rc: `X.Y.Z-rc(N+1)-dev<n>+<hash>`
function(create_development_version _version_ _output_ _tag_prefix_)
    git_version_info("${_tag_prefix_}" _n_commits_ _git_hash_)

    parse_version(${_version_} _major_ _minor_ _patch_ _rc_)
    if("${_n_commits_}" STREQUAL "0")
        # we are building a release, leave the version number as-is
        if("${_rc_}" STREQUAL "")
            set(${_output_} "${_major_}.${_minor_}.${_patch_}" PARENT_SCOPE)
        else()
            set(${_output_} "${_major_}.${_minor_}.${_patch_}-rc${_rc_}" PARENT_SCOPE)
        endif()
    else()
        # we are building a development version, increase the right part of the version
        if("${_rc_}" STREQUAL "")
            math(EXPR _minor_ "${_minor_} + 1")
            set(${_output_} "${_major_}.${_minor_}.0-dev${_n_commits_}+${_git_hash_}" PARENT_SCOPE)
        else()
            math(EXPR _rc_ "${_rc_} + 1")
            set(${_output_} "${_major_}.${_minor_}.${_patch_}-rc${_rc_}-dev${_n_commits_}+${_git_hash_}" PARENT_SCOPE)
        endif()
    endif()
endfunction()
version for metatensor-core as the main CMakeLists.txt +set(REQUIRED_METATENSOR_VERSION @REQUIRED_METATENSOR_VERSION@) +find_package(metatensor ${REQUIRED_METATENSOR_VERSION} CONFIG REQUIRED) + +get_filename_component(METATOMIC_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/@PACKAGE_RELATIVE_PATH@" ABSOLUTE) + +if (WIN32) + set(METATOMIC_SHARED_LOCATION ${METATOMIC_PREFIX_DIR}/@CMAKE_INSTALL_BINDIR@/@METATOMIC_SHARED_LIB_NAME@) + set(METATOMIC_IMPLIB_LOCATION ${METATOMIC_PREFIX_DIR}/@CMAKE_INSTALL_LIBDIR@/@METATOMIC_IMPLIB_NAME@) +else() + set(METATOMIC_SHARED_LOCATION ${METATOMIC_PREFIX_DIR}/@CMAKE_INSTALL_LIBDIR@/@METATOMIC_SHARED_LIB_NAME@) +endif() + +set(METATOMIC_STATIC_LOCATION ${METATOMIC_PREFIX_DIR}/@CMAKE_INSTALL_LIBDIR@/@METATOMIC_STATIC_LIB_NAME@) +set(METATOMIC_INCLUDE ${METATOMIC_PREFIX_DIR}/@CMAKE_INSTALL_INCLUDEDIR@/) + +if (NOT EXISTS ${METATOMIC_INCLUDE}/metatomic.h OR NOT EXISTS ${METATOMIC_INCLUDE}/metatomic.hpp) + message(FATAL_ERROR "could not find metatomic headers in '${METATOMIC_INCLUDE}', please re-install metatomic") +endif() + + +# Shared library target +if (@METATOMIC_INSTALL_BOTH_STATIC_SHARED@ OR @BUILD_SHARED_LIBS@) + if (NOT EXISTS ${METATOMIC_SHARED_LOCATION}) + message(FATAL_ERROR "could not find metatomic library at '${METATOMIC_SHARED_LOCATION}', please re-install metatomic") + endif() + + add_library(metatomic::shared SHARED IMPORTED) + set_target_properties(metatomic::shared PROPERTIES + IMPORTED_LOCATION ${METATOMIC_SHARED_LOCATION} + INTERFACE_INCLUDE_DIRECTORIES ${METATOMIC_INCLUDE} + BUILD_VERSION "@METATOMIC_FULL_VERSION@" + ) + + target_compile_features(metatomic::shared INTERFACE cxx_std_17) + + if (WIN32) + if (NOT EXISTS ${METATOMIC_IMPLIB_LOCATION}) + message(FATAL_ERROR "could not find metatomic library at '${METATOMIC_IMPLIB_LOCATION}', please re-install metatomic") + endif() + + set_target_properties(metatomic::shared PROPERTIES + IMPORTED_IMPLIB ${METATOMIC_IMPLIB_LOCATION} + ) + endif() +endif() + + +# Static library 
# Create a temporary directory using mktemp on *nix and powershell on windows
#
# The path of the new directory is stored in the variable named by `_outvar_`
# in the caller's scope. Three strategies are tried in order, the first one to
# succeed wins; if none does, this is a fatal error.
function(get_tempdir _outvar_)
    # special case for github actions, where $TEMP might
    # exist but point to nowhere/a non writable location
    # https://docs.github.com/en/actions/learn-github-actions/variables
    if (DEFINED ENV{RUNNER_TEMP})
        # use a random 12-character sub-directory of $RUNNER_TEMP
        string(RANDOM LENGTH 12 _dirname_)
        set(_output_ $ENV{RUNNER_TEMP}/${_dirname_})
        file(TO_NATIVE_PATH "${_output_}" _output_)
        file(MAKE_DIRECTORY ${_output_})
        set(${_outvar_} ${_output_} PARENT_SCOPE)
        return()
    endif()

    find_program(MKTEMP_EXE NAMES mktemp)
    if(MKTEMP_EXE)
        # `mktemp -d` prints the path of the directory on stdout
        execute_process(
            COMMAND ${MKTEMP_EXE} -d
            OUTPUT_VARIABLE _output_
            OUTPUT_STRIP_TRAILING_WHITESPACE
            RESULT_VARIABLE _status_
        )

        if(_status_ EQUAL 0)
            file(MAKE_DIRECTORY ${_output_})
            set(${_outvar_} ${_output_} PARENT_SCOPE)
            return()
        endif()
    endif()


    find_program(POWERSHELL_EXE NAMES pwsh powershell)
    if(POWERSHELL_EXE)
        # ask PowerShell for the temporary directory path, and create a random
        # 12-character sub-directory in it
        execute_process(
            COMMAND ${POWERSHELL_EXE} -c "[System.IO.Path]::GetTempPath()"
            OUTPUT_VARIABLE _output_
            OUTPUT_STRIP_TRAILING_WHITESPACE
            RESULT_VARIABLE _status_
        )

        if(_status_ EQUAL 0)
            string(RANDOM LENGTH 12 _dirname_)
            # no separator is added here
            # NOTE(review): presumably GetTempPath() output ends with a path
            # separator -- confirm.
            set(_output_ ${_output_}${_dirname_})
            file(MAKE_DIRECTORY ${_output_})
            set(${_outvar_} ${_output_} PARENT_SCOPE)
            return()
        endif()
    endif()

    message(FATAL_ERROR "Could not find mktemp or PowerShell to make temporary directory")
endfunction()
+ */ +const char *mta_version(void); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif /* METATOMIC_H */ diff --git a/metatomic-core/include/metatomic.hpp b/metatomic-core/include/metatomic.hpp new file mode 100644 index 000000000..016f26bc5 --- /dev/null +++ b/metatomic-core/include/metatomic.hpp @@ -0,0 +1,2 @@ +#include "metatomic/system.hpp" // IWYU pragma: export +#include "metatomic/model.hpp" // IWYU pragma: export diff --git a/metatomic-core/include/metatomic/model.hpp b/metatomic-core/include/metatomic/model.hpp new file mode 100644 index 000000000..1cae91bdf --- /dev/null +++ b/metatomic-core/include/metatomic/model.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace metatomic { + +} // namespace metatomic diff --git a/metatomic-core/include/metatomic/system.hpp b/metatomic-core/include/metatomic/system.hpp new file mode 100644 index 000000000..1cae91bdf --- /dev/null +++ b/metatomic-core/include/metatomic/system.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace metatomic { + +} // namespace metatomic diff --git a/metatomic-core/src/c_api/mod.rs b/metatomic-core/src/c_api/mod.rs new file mode 100644 index 000000000..33e0786dc --- /dev/null +++ b/metatomic-core/src/c_api/mod.rs @@ -0,0 +1,18 @@ +use std::ffi::CString; +use std::os::raw::c_char; + +use once_cell::sync::Lazy; + + +static VERSION: Lazy = Lazy::new(|| { + CString::new(env!("METATOMIC_FULL_VERSION")).expect("version contains NULL byte") +}); + + +/// Get the runtime version of the metatomic library as a string. +/// +/// This version follows the `..[-]` format. 
# Build and run the C++ unit tests of metatomic-core: every `*.cpp` file in
# this directory becomes one executable and one CTest test.
cmake_minimum_required(VERSION 3.22)
project(metatomic-tests)

# Only pick a default build type when this is the top-level project. Paths are
# quoted so that spaces or special characters cannot break the comparison.
if ("${CMAKE_CURRENT_SOURCE_DIR}" STREQUAL "${CMAKE_SOURCE_DIR}")
    if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
        message(STATUS "Setting build type to 'release' as none was specified.")
        set(CMAKE_BUILD_TYPE "release"
            CACHE STRING
            "Choose the type of build, options are: debug or release"
            FORCE)
        set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug)
    endif()
endif()

if (MINGW)
    # CI can't find libstdc++, so we statically link it
    set(CMAKE_EXE_LINKER_FLAGS "-static-libstdc++")
endif()

add_subdirectory(../ metatomic)
get_target_property(METATOMIC_IMPORTED_LOCATION metatomic::shared IMPORTED_LOCATION)
get_filename_component(METATOMIC_DIR ${METATOMIC_IMPORTED_LOCATION} DIRECTORY)

add_subdirectory(external)

# Run the tests through valgrind when it is available, unless explicitly
# disabled with METATOMIC_DISABLE_VALGRIND=1. TEST_COMMAND is initialised
# first so that it is defined on every path.
set(TEST_COMMAND "")
find_program(VALGRIND valgrind)
if (VALGRIND)
    if (NOT "$ENV{METATOMIC_DISABLE_VALGRIND}" EQUAL "1")
        message(STATUS "Running tests using valgrind")
        set(TEST_COMMAND
            "${VALGRIND}" "--tool=memcheck" "--dsymutil=yes" "--error-exitcode=125"
            "--leak-check=full" "--show-leak-kinds=definite,indirect,possible" "--track-origins=yes"
            "--gen-suppressions=all"
        )
    endif()
endif()

# With clang, enable all warnings and then disable the ones not wanted here
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Weverything")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++98-compat")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-c++98-compat-pedantic")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-float-equal")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-missing-prototypes")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-shadow-uncaptured-local")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-padded")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unsafe-buffer-usage")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-poison-system-directories")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-allocator-wrappers")
endif()


enable_testing()

file(GLOB ALL_TESTS *.cpp)
foreach(_file_ ${ALL_TESTS})
    get_filename_component(_name_ ${_file_} NAME_WE)
    add_executable(${_name_} ${_file_})
    target_link_libraries(${_name_} metatomic catch)

    set_target_properties(${_name_} PROPERTIES
        # Ensure that the binaries find the right shared library.
        #
        # Without this, when configuring with cmake before the library is built,
        # cmake does not find the library on the filesystem and does not add the
        # RPATH to executables linking to it
        BUILD_RPATH ${METATOMIC_DIR}
        NO_SYSTEM_FROM_IMPORTED ON
    )

    # NOTE(review): the generator expressions in this loop were reduced to a
    # bare `$` in the text under review; they are restored here to the
    # expressions the surrounding comments describe -- confirm against the
    # repository copy.
    add_test(
        NAME ${_name_}
        COMMAND ${TEST_COMMAND} $<TARGET_FILE:${_name_}>
    )

    if(WIN32)
        # We need to set the path to allow access to metatomic.dll
        # this does a similar job to the BUILD_RPATH above
        STRING(REPLACE ";" "\\;" PATH_STRING "$ENV{PATH}")
        set_tests_properties(${_name_} PROPERTIES
            ENVIRONMENT "PATH=${PATH_STRING}\;$<TARGET_FILE_DIR:metatomic::shared>"
        )
    endif()
endforeach()
# Sample project checking that metatomic can be used from another CMake
# project, either through `add_subdirectory` or through `find_package`.
cmake_minimum_required(VERSION 3.22)

message(STATUS "Running with CMake version ${CMAKE_VERSION}")

project(metatomic-test-cmake-project C CXX)

# `option()` takes a help string before the default value; without it, `OFF`
# would be parsed as the help string.
option(USE_CMAKE_SUBDIRECTORY "Build metatomic with add_subdirectory instead of using find_package" OFF)

if (MINGW)
    # CI can't find libstdc++, so we statically link it
    set(CMAKE_EXE_LINKER_FLAGS "-static-libstdc++")
endif()


# NOTE(review): the three generator expressions in this file were reduced to a
# bare `$` in the text under review; they are restored to
# `$<TARGET_FILE_DIR:metatomic::shared>`, which matches the comments next to
# them -- confirm against the repository copy.
if (USE_CMAKE_SUBDIRECTORY)
    message(STATUS "Using metatomic with add_subdirectory")
    # build metatomic as part of this project
    add_subdirectory(../../ metatomic)

    # load metatomic from the build path
    set(CMAKE_BUILD_RPATH "$<TARGET_FILE_DIR:metatomic::shared>")
else()
    message(STATUS "Using metatomic with find_package")
    # If building a dev version, we also need to update the REQUIRED_METATOMIC_VERSION
    # in the same way we update the metatomic-torch version
    include(../../cmake/dev-versions.cmake)
    set(REQUIRED_METATOMIC_VERSION "0.1.0")
    create_development_version("${REQUIRED_METATOMIC_VERSION}" METATOMIC_CORE_FULL_VERSION "metatomic-core-v")
    # keep only `<major>.<minor>` for the find_package() version request
    string(REGEX REPLACE "([0-9]*)\\.([0-9]*).*" "\\1.\\2" REQUIRED_METATOMIC_VERSION "${METATOMIC_CORE_FULL_VERSION}")

    find_package(metatomic ${REQUIRED_METATOMIC_VERSION} REQUIRED)

    # The installed targets must carry the exact version that was built.
    # Both sides are quoted so that an unset property (`...-NOTFOUND`) or an
    # empty value gives the explicit error below instead of a syntax error.
    if(TARGET metatomic::shared)
        get_target_property(mta_build_version metatomic::shared BUILD_VERSION)
        if (NOT "${mta_build_version}" STREQUAL "${METATOMIC_CORE_FULL_VERSION}")
            message(FATAL_ERROR "Invalid BUILD_VERSION for metatomic::shared, expected ${METATOMIC_CORE_FULL_VERSION} but got ${mta_build_version}")
        endif()
    endif()

    if(TARGET metatomic::static)
        get_target_property(mta_build_version metatomic::static BUILD_VERSION)
        if (NOT "${mta_build_version}" STREQUAL "${METATOMIC_CORE_FULL_VERSION}")
            message(FATAL_ERROR "Invalid BUILD_VERSION for metatomic::static, expected ${METATOMIC_CORE_FULL_VERSION} but got ${mta_build_version}")
        endif()
    endif()
endif()

enable_testing()


# C and C++ executables linked to the shared library
if(TARGET metatomic::shared)
    add_executable(c-main src/main.c)
    target_link_libraries(c-main metatomic::shared)

    add_executable(cxx-main src/main.cpp)
    target_link_libraries(cxx-main metatomic::shared)

    add_test(NAME c-main COMMAND c-main)
    add_test(NAME cxx-main COMMAND cxx-main)

    if(WIN32)
        # We need to set the path to allow access to metatomic.dll
        STRING(REPLACE ";" "\\;" PATH_STRING "$ENV{PATH}")
        set_tests_properties(c-main PROPERTIES
            ENVIRONMENT "PATH=${PATH_STRING}\;$<TARGET_FILE_DIR:metatomic::shared>"
        )

        set_tests_properties(cxx-main PROPERTIES
            ENVIRONMENT "PATH=${PATH_STRING}\;$<TARGET_FILE_DIR:metatomic::shared>"
        )
    endif()
endif()


# C and C++ executables linked to the static library
if(TARGET metatomic::static)
    add_executable(c-main-static src/main.c)
    target_link_libraries(c-main-static metatomic::static)

    add_executable(cxx-main-static src/main.cpp)
    target_link_libraries(cxx-main-static metatomic::static)

    add_test(NAME c-main-static COMMAND c-main-static)
    add_test(NAME cxx-main-static COMMAND cxx-main-static)
endif()
similarity index 100% rename from metatomic-torch/tests/external/catch/catch.cpp rename to metatomic-core/tests/external/catch/catch.cpp diff --git a/metatomic-torch/tests/external/catch/catch.hpp b/metatomic-core/tests/external/catch/catch.hpp similarity index 100% rename from metatomic-torch/tests/external/catch/catch.hpp rename to metatomic-core/tests/external/catch/catch.hpp diff --git a/metatomic-core/tests/misc.cpp b/metatomic-core/tests/misc.cpp new file mode 100644 index 000000000..bf0ce275f --- /dev/null +++ b/metatomic-core/tests/misc.cpp @@ -0,0 +1,15 @@ +#include + +#include "metatomic.h" + + +TEST_CASE("Version macros") { + CHECK(std::string(METATOMIC_VERSION) == mta_version()); + + auto version = std::to_string(METATOMIC_VERSION_MAJOR) + "." + + std::to_string(METATOMIC_VERSION_MINOR) + "." + + std::to_string(METATOMIC_VERSION_PATCH); + + // METATOMIC_VERSION should start with `x.y.z` + CHECK(std::string(METATOMIC_VERSION).find(version) == 0); +} diff --git a/metatomic-core/tests/run-cxx-tests.rs b/metatomic-core/tests/run-cxx-tests.rs new file mode 100644 index 000000000..0d3b48d9d --- /dev/null +++ b/metatomic-core/tests/run-cxx-tests.rs @@ -0,0 +1,40 @@ +use std::path::PathBuf; + +mod utils; + +#[test] +fn run_cxx_tests() { + const CARGO_TARGET_TMPDIR: &str = env!("CARGO_TARGET_TMPDIR"); + + let mut build_dir = PathBuf::from(CARGO_TARGET_TMPDIR); + build_dir.push("cxx-tests"); + std::fs::create_dir_all(&build_dir).expect("failed to create build dir"); + + // ====================================================================== // + // setup dependencies for the C++ tests + let deps_dir = build_dir.join("deps"); + let virtualenv_dir = deps_dir.join("virtualenv"); + std::fs::create_dir_all(&virtualenv_dir).expect("failed to create virtualenv dir"); + let python_exe = utils::create_python_venv(virtualenv_dir); + let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python_exe); + + //
====================================================================== // + // build the metatomic C++ tests and run them + + let mut source_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); + source_dir.push("tests"); + + // configure cmake for the tests + let mut cmake_config = utils::cmake_config(&source_dir, &build_dir); + cmake_config.arg("-DCMAKE_EXPORT_COMPILE_COMMANDS=ON"); + cmake_config.arg(format!("-DCMAKE_PREFIX_PATH={}", metatensor_cmake_prefix.display())); + utils::run_command(cmake_config, "cmake configuration"); + + // build the tests + let cmake_build = utils::cmake_build(&build_dir); + utils::run_command(cmake_build, "cmake build"); + + // run the tests + let ctest = utils::ctest(&build_dir); + utils::run_command(ctest, "ctest"); +} diff --git a/metatomic-core/tests/utils/mod.rs b/metatomic-core/tests/utils/mod.rs new file mode 100644 index 000000000..e12e6897c --- /dev/null +++ b/metatomic-core/tests/utils/mod.rs @@ -0,0 +1,470 @@ +#![allow(dead_code)] +#![allow(clippy::needless_return)] + +use std::io::{Read, Write}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; + +fn build_type() -> &'static str { + // assume that debug assertion means that we are building the code in + // debug mode, even if that could be not true in some cases + if cfg!(debug_assertions) { + "debug" + } else { + "release" + } +} + +fn append_flags(existing: Option, extra: &str) -> String { + match existing { + Some(flags) if !flags.trim().is_empty() => format!("{flags} {extra}"), + _ => extra.into(), + } +} + +pub fn cmake_config(source_dir: &Path, build_dir: &Path) -> Command { + let cmake = which::which("cmake").expect("could not find cmake"); + + let mut cmake_config = Command::new(cmake); + cmake_config.current_dir(build_dir); + cmake_config.arg(source_dir); + cmake_config.arg("--no-warn-unused-cli"); + cmake_config.arg(format!("-DCMAKE_BUILD_TYPE={}", build_type())); + + // the cargo executable currently running + let cargo_exe = 
std::env::var("CARGO").expect("CARGO env var is not set"); + cmake_config.arg(format!("-DCARGO_EXE={}", cargo_exe)); + + if std::env::var_os("CARGO_LLVM_COV").is_some() { + let coverage_compile_flags = "-fprofile-instr-generate -fcoverage-mapping"; + let coverage_link_flags = "-fprofile-instr-generate"; + + let c_flags = append_flags(std::env::var("CFLAGS").ok(), coverage_compile_flags); + let cxx_flags = append_flags(std::env::var("CXXFLAGS").ok(), coverage_compile_flags); + let exe_linker_flags = + append_flags(std::env::var("LDFLAGS").ok(), coverage_link_flags); + + cmake_config.arg(format!("-DCMAKE_C_FLAGS={c_flags}")); + cmake_config.arg(format!("-DCMAKE_CXX_FLAGS={cxx_flags}")); + cmake_config.arg(format!("-DCMAKE_EXE_LINKER_FLAGS={exe_linker_flags}")); + cmake_config.arg(format!("-DCMAKE_SHARED_LINKER_FLAGS={exe_linker_flags}")); + } + + return cmake_config; +} + +pub fn cmake_build(build_dir: &Path) -> Command { + let cmake = which::which("cmake").expect("could not find cmake"); + + let mut cmake_build = Command::new(cmake); + cmake_build.current_dir(build_dir); + cmake_build.arg("--build"); + cmake_build.arg("."); + cmake_build.arg("--parallel"); + cmake_build.arg("--config"); + cmake_build.arg(build_type()); + + return cmake_build; +} + + +pub fn ctest(build_dir: &Path) -> Command { + let ctest = which::which("ctest").expect("could not find ctest"); + + let mut ctest = Command::new(ctest); + ctest.current_dir(build_dir); + ctest.arg("--output-on-failure"); + ctest.arg("--build-config"); + ctest.arg(build_type()); + + return ctest +} + +/// Find the path to the uv binary, or None if not present +fn find_uv() -> Option { + which::which("uv").ok() +} + +/// Find the path to the `python` or `python3` binary on the user system +fn find_python() -> PathBuf { + if let Ok(python) = which::which("python") { + let output = Command::new(&python) + .arg("-c") + .arg("import sys; print(sys.version_info.major)") + .output() + .expect("could not run python"); + + if
output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + + if stdout.trim() == "3" { + // we found Python 3 + return python; + } + } + } + + // try python3 + let python = which::which("python3").expect("failed to run `which python3`"); + let output = Command::new(&python) + .arg("-c") + .arg("import sys; print(sys.version_info.major)") + .output() + .expect("could not run python"); + + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout); + if stdout.trim() == "3" { + // we found Python 3 + return python; + } + } + + panic!("could not find Python 3") +} + +/// Helper: get python executable path inside a venv +fn python_in_venv(venv_dir: &Path) -> PathBuf { + let mut python = venv_dir.to_path_buf(); + if cfg!(target_os = "windows") { + python.extend(["Scripts", "python.exe"]); + } else { + python.extend(["bin", "python"]); + } + python +} + +/// Create a fresh Python virtualenv using uv if available, else fallback to +/// `python -m venv`, and return the path to the python executable in the venv +pub fn create_python_venv(build_dir: PathBuf) -> PathBuf { + if let Some(uv_bin) = find_uv() { + let mut cmd = Command::new(&uv_bin); + cmd.arg("venv"); + cmd.arg("--clear"); + cmd.arg(&build_dir); + + run_command(cmd, "uv venv creation"); + } else { + let mut cmd = Command::new(find_python()); + cmd.arg("-m"); + cmd.arg("venv"); + cmd.arg(&build_dir); + + run_command(cmd, "python to create virtualenv with `venv`"); + + // update pip in case the system uses a very old one + let python = python_in_venv(&build_dir); + let mut cmd = Command::new(&python); + cmd.arg("-m"); + cmd.arg("pip"); + cmd.arg("install"); + cmd.arg("--upgrade"); + cmd.arg("pip"); + + run_command(cmd, "pip upgrade in virtualenv"); + } + + python_in_venv(&build_dir) +} + +#[derive(Default)] +pub struct PipInstallOptions { + pub upgrade: bool, + pub no_deps: bool, + pub no_build_isolation: bool, +} + +/// Install a package with pip (uses uv if 
present, else falls back to python) +fn pip_install( + python: &Path, + packages: &[&str], + options: PipInstallOptions, +) { + if let Some(uv_bin) = find_uv() { + let mut cmd = Command::new(&uv_bin); + cmd.arg("pip").arg("install").arg("--python").arg(python); + + // follow the same behavior as pip when there are multiple indexes + cmd.arg("--index-strategy"); + cmd.arg("unsafe-best-match"); + + if options.upgrade { + cmd.arg("--upgrade"); + } + if options.no_deps { + cmd.arg("--no-deps"); + } + if options.no_build_isolation { + cmd.arg("--no-build-isolation"); + // uv doesn't support --check-build-dependencies + } + + for package in packages { + cmd.arg(package); + } + + run_command(cmd, "uv pip install"); + } else { + let mut cmd = Command::new(python); + cmd.arg("-m").arg("pip").arg("install"); + if options.upgrade { + cmd.arg("--upgrade"); + } + if options.no_deps { + cmd.arg("--no-deps"); + } + if options.no_build_isolation { + // If pip, add both supported options + cmd.arg("--no-build-isolation"); + cmd.arg("--check-build-dependencies"); + } + + for package in packages { + cmd.arg(package); + } + + run_command(cmd, "pip install"); + } +} + +/// Download PyTorch in a Python virtualenv, and return the +/// CMAKE_PREFIX_PATH for the corresponding libtorch +pub fn setup_torch_pip(python: &Path) -> PathBuf { + let torch_version = std::env::var("METATOMIC_TESTS_TORCH_VERSION").unwrap_or("2.12".into()); + pip_install( + python, + &[&format!("torch=={}.*", torch_version)], + PipInstallOptions { upgrade: true, no_deps: false, no_build_isolation: false } + ); + + let mut cmd = Command::new(python); + cmd.arg("-c"); + cmd.arg("import torch; print(torch.utils.cmake_prefix_path)"); + + let output = run_command(cmd, "python to get torch cmake prefix"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let prefix = PathBuf::from(stdout.trim()); + if !prefix.exists() { + panic!("'torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + } + + return 
prefix; +} + +/// Install metatensor in a Python virtualenv with pip, and return the +/// CMAKE_PREFIX_PATH for the installed libmetatensor. +pub fn setup_metatensor_pip(python: &Path) -> PathBuf { + pip_install(python, &["metatensor-core >=0.2.0,<0.3"], PipInstallOptions::default()); + + let mut cmd = Command::new(python); + cmd.arg("-c"); + cmd.arg("import metatensor; print(metatensor.utils.cmake_prefix_path)"); + + let output = run_command(cmd, "python to get metatensor cmake prefix"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let prefix = PathBuf::from(stdout.trim()); + if !prefix.exists() { + panic!("'metatensor.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + } + + return prefix; +} + +/// Install metatensor-torch in a Python virtualenv with pip, and return the +/// CMAKE_PREFIX_PATH for the installed libmetatensor_torch. +pub fn setup_metatensor_torch_pip(python: &Path) -> PathBuf { + pip_install(python, &["metatensor-torch >=0.9.0,<0.10"], PipInstallOptions::default()); + + let mut cmd = Command::new(python); + cmd.arg("-c"); + cmd.arg("import metatensor.torch; print(metatensor.torch.utils.cmake_prefix_path)"); + + let output = run_command(cmd, "python to get metatensor_torch cmake prefix"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let prefix = PathBuf::from(stdout.trim()); + if !prefix.exists() { + panic!("'metatensor.torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + } + + return prefix; +} + +/// Build metatomic-torch located in `source_dir` inside `build_dir`, and return +/// the installation prefix. 
+pub fn setup_metatomic_torch_cmake(source_dir: &Path, build_dir: &Path, cmake_args: Vec) -> PathBuf { + std::fs::create_dir_all(build_dir).expect("failed to create metatomic build dir"); + + // configure cmake for metatomic-torch + let mut cmake_config = cmake_config(source_dir, build_dir); + + let install_prefix = build_dir.join("usr"); + cmake_config.arg(format!("-DCMAKE_INSTALL_PREFIX={}", install_prefix.display())); + + // Add any additional cmake arguments + for arg in cmake_args { + cmake_config.arg(arg); + } + + run_command(cmake_config, "cmake configuration for metatomic_torch"); + + // build and install metatomic-torch + let mut cmake_build = cmake_build(build_dir); + cmake_build.arg("--target"); + cmake_build.arg("install"); + + run_command(cmake_build, "cmake build for metatomic_torch"); + + install_prefix +} + +/// Build metatomic-core located in `source_dir` inside `build_dir`, and return +/// the installation prefix +pub fn setup_metatomic_cmake(source_dir: &Path, build_dir: &Path, cmake_args: Vec) -> PathBuf { + std::fs::create_dir_all(build_dir).expect("failed to create metatomic build dir"); + + // configure cmake for metatomic + let mut cmake_config = cmake_config(source_dir, build_dir); + + let install_prefix = build_dir.join("usr"); + cmake_config.arg(format!("-DCMAKE_INSTALL_PREFIX={}", install_prefix.display())); + + // Add any additional cmake arguments + for arg in cmake_args { + cmake_config.arg(arg); + } + + run_command(cmake_config, "cmake configuration for metatomic"); + + // build and install metatomic + let mut cmake_build = cmake_build(build_dir); + cmake_build.arg("--target"); + cmake_build.arg("install"); + + run_command(cmake_build, "cmake build for metatomic"); + + install_prefix +} + +/// Install metatomic-core in a Python virtualenv with pip, and return the +/// CMAKE_PREFIX_PATH for the installed libmetatomic. 
+pub fn setup_metatomic_core_pip(python: &Path, source_dir: &Path) -> PathBuf { + pip_install( + python, + &["cmake", "packaging >=26", "setuptools >=77"], + PipInstallOptions::default() + ); + + pip_install( + python, + &[&source_dir.display().to_string()], + PipInstallOptions { + upgrade: true, + no_deps: true, + no_build_isolation: true + } + ); + + // let mut cmd = Command::new(python); + // cmd.arg("-c"); + // cmd.arg("import metatomic; print(metatomic.utils.cmake_prefix_path)"); + + // let output = run_command(cmd, "python to get metatomic cmake prefix"); + + // let stdout = String::from_utf8_lossy(&output.stdout); + // let prefix = PathBuf::from(stdout.trim()); + // if !prefix.exists() { + // panic!("'metatomic.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + // } + + // return prefix; + return PathBuf::new(); +} + + +/// Install metatomic-torch in a Python virtualenv with pip, and return the +/// CMAKE_PREFIX_PATH for the installed libmetatomic_torch. +pub fn setup_metatomic_torch_pip(python: &Path, source_dir: &Path) -> PathBuf { + pip_install( + python, + &[&source_dir.display().to_string()], + PipInstallOptions { + upgrade: true, + no_deps: true, + no_build_isolation: true + } + ); + + let mut cmd = Command::new(python); + cmd.arg("-c"); + cmd.arg("import metatomic.torch; print(metatomic.torch.utils.cmake_prefix_path)"); + + let output = run_command(cmd, "python to get metatomic_torch cmake prefix"); + + let stdout = String::from_utf8_lossy(&output.stdout); + let prefix = PathBuf::from(stdout.trim()); + if !prefix.exists() { + panic!("'metatomic.torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); + } + + return prefix; +} + +pub fn run_command(mut command: Command, context: &str) -> std::process::Output { + write!(std::io::stdout().lock(), "\n\n[Running] {:?}\n\n", command).unwrap(); + + let mut child = command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn().unwrap_or_else(|_| panic!("failed to spawn {}", 
context)); + + let mut child_stdout = child.stdout.take().expect("missing stdout"); + let mut child_stderr = child.stderr.take().expect("missing stderr"); + + let out_handle = std::thread::spawn(move || -> std::io::Result> { + let mut buf = [0u8; 8192]; + let mut captured = Vec::new(); + let mut sink = std::io::stdout().lock(); + loop { + let n = child_stdout.read(&mut buf)?; + if n == 0 { + break; + } + sink.write_all(&buf[..n])?; + sink.flush()?; + captured.extend_from_slice(&buf[..n]); + } + Ok(captured) + }); + + let err_handle = std::thread::spawn(move || -> std::io::Result> { + let mut buf = [0u8; 8192]; + let mut captured = Vec::new(); + let mut sink = std::io::stderr().lock(); + loop { + let n = child_stderr.read(&mut buf)?; + if n == 0 { + break; + } + sink.write_all(&buf[..n])?; + sink.flush()?; + captured.extend_from_slice(&buf[..n]); + } + Ok(captured) + }); + + let status = child.wait().unwrap_or_else(|_| panic!("failed to run {}", context)); + let stdout = String::from_utf8_lossy(&out_handle.join().unwrap().unwrap()).into_owned(); + let stderr = String::from_utf8_lossy(&err_handle.join().unwrap().unwrap()).into_owned(); + + if !status.success() { + panic!( + "{} failed, status: {}\nstderr:\n\n{}\nstdout:\n\n{}\n", + context, status, stderr, stdout + ); + } + + return std::process::Output { status, stdout: stdout.into_bytes(), stderr: stderr.into_bytes() }; +} diff --git a/metatomic-torch/tests/CMakeLists.txt b/metatomic-torch/tests/CMakeLists.txt index 8a64a4f33..7d6257a0d 100644 --- a/metatomic-torch/tests/CMakeLists.txt +++ b/metatomic-torch/tests/CMakeLists.txt @@ -1,4 +1,5 @@ -add_subdirectory(external) +# re-use catch from metatomic-core C++ tests +add_subdirectory(../../metatomic-core/tests/external external) # make sure we compile catch with the flags that torch requires. 
In particular, # torch sets -D_GLIBCXX_USE_CXX11_ABI=0 on Linux, which changes some of the diff --git a/metatomic-torch/tests/check-torch-install.rs b/metatomic-torch/tests/check-torch-install.rs index 8883d916e..ad8cfb604 100644 --- a/metatomic-torch/tests/check-torch-install.rs +++ b/metatomic-torch/tests/check-torch-install.rs @@ -123,8 +123,11 @@ fn check_python_install() { let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python_exe); let metatensor_torch_cmake_prefix = utils::setup_metatensor_torch_pip(&python_exe); - let python_source_dir = cargo_manifest_dir.parent().unwrap().join("python").join("metatomic_torch"); - let metatomic_torch_cmake_prefix = utils::setup_metatomic_torch_pip(&python_exe, &python_source_dir); + let mta_core_source_dir = cargo_manifest_dir.parent().unwrap().join("python").join("metatomic_core"); + let metatomic_core_cmake_prefix = utils::setup_metatomic_core_pip(&python_exe, &mta_core_source_dir); + + let mta_torch_source_dir = cargo_manifest_dir.parent().unwrap().join("python").join("metatomic_torch"); + let metatomic_torch_cmake_prefix = utils::setup_metatomic_torch_pip(&python_exe, &mta_torch_source_dir); // ====================================================================== // // try to use the installed metatensor-torch from cmake @@ -134,10 +137,11 @@ fn check_python_install() { // configure cmake for the test cmake project let mut cmake_config = utils::cmake_config(&source_dir, &build_dir); cmake_config.arg(format!( - "-DCMAKE_PREFIX_PATH={};{};{};{}", + "-DCMAKE_PREFIX_PATH={};{};{};{};{}", pytorch_cmake_prefix.display(), metatensor_cmake_prefix.display(), metatensor_torch_cmake_prefix.display(), + metatomic_core_cmake_prefix.display(), metatomic_torch_cmake_prefix.display(), )); diff --git a/metatomic-torch/tests/utils/mod.rs b/metatomic-torch/tests/utils/mod.rs deleted file mode 100644 index e223bcee6..000000000 --- a/metatomic-torch/tests/utils/mod.rs +++ /dev/null @@ -1,410 +0,0 @@ -#![allow(dead_code)] 
-#![allow(clippy::needless_return)] - -use std::io::{Read, Write}; -use std::path::{Path, PathBuf}; -use std::process::{Command, Stdio}; - -fn build_type() -> &'static str { - // assume that debug assertion means that we are building the code in - // debug mode, even if that could be not true in some cases - if cfg!(debug_assertions) { - "debug" - } else { - "release" - } -} - -fn append_flags(existing: Option, extra: &str) -> String { - match existing { - Some(flags) if !flags.trim().is_empty() => format!("{flags} {extra}"), - _ => extra.into(), - } -} - -pub fn cmake_config(source_dir: &Path, build_dir: &Path) -> Command { - let cmake = which::which("cmake").expect("could not find cmake"); - - let mut cmake_config = Command::new(cmake); - cmake_config.current_dir(build_dir); - cmake_config.arg(source_dir); - cmake_config.arg("--no-warn-unused-cli"); - cmake_config.arg(format!("-DCMAKE_BUILD_TYPE={}", build_type())); - - // the cargo executable currently running - let cargo_exe = std::env::var("CARGO").expect("CARGO env var is not set"); - cmake_config.arg(format!("-DCARGO_EXE={}", cargo_exe)); - - if std::env::var_os("CARGO_LLVM_COV").is_some() { - let coverage_compile_flags = "-fprofile-instr-generate -fcoverage-mapping"; - let coverage_link_flags = "-fprofile-instr-generate"; - - let c_flags = append_flags(std::env::var("CFLAGS").ok(), coverage_compile_flags); - let cxx_flags = append_flags(std::env::var("CXXFLAGS").ok(), coverage_compile_flags); - let exe_linker_flags = - append_flags(std::env::var("LDFLAGS").ok(), coverage_link_flags); - - cmake_config.arg(format!("-DCMAKE_C_FLAGS={c_flags}")); - cmake_config.arg(format!("-DCMAKE_CXX_FLAGS={cxx_flags}")); - cmake_config.arg(format!("-DCMAKE_EXE_LINKER_FLAGS={exe_linker_flags}")); - cmake_config.arg(format!("-DCMAKE_SHARED_LINKER_FLAGS={exe_linker_flags}")); - } - - return cmake_config; -} - -pub fn cmake_build(build_dir: &Path) -> Command { - let cmake = which::which("cmake").expect("could not find cmake"); - 
- let mut cmake_build = Command::new(cmake); - cmake_build.current_dir(build_dir); - cmake_build.arg("--build"); - cmake_build.arg("."); - cmake_build.arg("--parallel"); - cmake_build.arg("--config"); - cmake_build.arg(build_type()); - - return cmake_build; -} - - -pub fn ctest(build_dir: &Path) -> Command { - let ctest = which::which("ctest").expect("could not find ctest"); - - let mut ctest = Command::new(ctest); - ctest.current_dir(build_dir); - ctest.arg("--output-on-failure"); - ctest.arg("--build-config"); - ctest.arg(build_type()); - - return ctest -} - -/// Find the path to the uv binary, or None if not present -fn find_uv() -> Option { - which::which("uv").ok() -} - -/// Find the path to the `python`or `python3` binary on the user system -fn find_python() -> PathBuf { - if let Ok(python) = which::which("python") { - let output = Command::new(&python) - .arg("-c") - .arg("import sys; print(sys.version_info.major)") - .output() - .expect("could not run python"); - - if output.status.success() { - let stdout = String::from_utf8_lossy(&output.stdout); - - if stdout.trim() == "3" { - // we found Python 3 - return python; - } - } - } - - // try python3 - let python = which::which("python3").expect("failed to run `which python3`"); - let output = Command::new(&python) - .arg("-c") - .arg("import sys; print(sys.version_info.major)") - .output() - .expect("could not run python"); - - if output.status.success() { - let stdout = String::from_utf8_lossy(&output.stdout); - if stdout.trim() == "3" { - // we found Python 3 - return python; - } - } - - panic!("could not find Python 3") -} - -/// Helper: get python executable path inside a venv -fn python_in_venv(venv_dir: &Path) -> PathBuf { - let mut python = venv_dir.to_path_buf(); - if cfg!(target_os = "windows") { - python.extend(["Scripts", "python.exe"]); - } else { - python.extend(["bin", "python"]); - } - python -} - -/// Create a fresh Python virtualenv using uv if available, else fallback to -/// `python -m 
venv`, and return the path to the python executable in the venv -pub fn create_python_venv(build_dir: PathBuf) -> PathBuf { - if let Some(uv_bin) = find_uv() { - let mut cmd = Command::new(&uv_bin); - cmd.arg("venv"); - cmd.arg("--clear"); - cmd.arg(&build_dir); - - run_command(cmd, "uv venv creation"); - } else { - let mut cmd = Command::new(find_python()); - cmd.arg("-m"); - cmd.arg("venv"); - cmd.arg(&build_dir); - - run_command(cmd, "python to create virtualenv with `venv`"); - - // update pip in case the system uses a very old one - let python = python_in_venv(&build_dir); - let mut cmd = Command::new(&python); - cmd.arg("-m"); - cmd.arg("pip"); - cmd.arg("install"); - cmd.arg("--upgrade"); - cmd.arg("pip"); - - run_command(cmd, "pip upgrade in virtualenv"); - } - - python_in_venv(&build_dir) -} - -#[derive(Default)] -pub struct PipInstallOptions { - pub upgrade: bool, - pub no_deps: bool, - pub no_build_isolation: bool, -} - -/// Install a package with pip (uses uv if present, else falls back to python) -fn pip_install( - python: &Path, - packages: &[&str], - options: PipInstallOptions, -) { - if let Some(uv_bin) = find_uv() { - let mut cmd = Command::new(&uv_bin); - cmd.arg("pip").arg("install").arg("--python").arg(python); - - // follow the same behavior as pip when there are multiple indexes - cmd.arg("--index-strategy"); - cmd.arg("unsafe-best-match"); - - if options.upgrade { - cmd.arg("--upgrade"); - } - if options.no_deps { - cmd.arg("--no-deps"); - } - if options.no_build_isolation { - cmd.arg("--no-build-isolation"); - // uv doesn't support --check-build-dependencies - } - - for package in packages { - cmd.arg(package); - } - - run_command(cmd, "uv pip install"); - } else { - let mut cmd = Command::new(python); - cmd.arg("-m").arg("pip").arg("install"); - if options.upgrade { - cmd.arg("--upgrade"); - } - if options.no_deps { - cmd.arg("--no-deps"); - } - if options.no_build_isolation { - // If pip, add both supported options - 
cmd.arg("--no-build-isolation"); - cmd.arg("--check-build-dependencies"); - } - - for package in packages { - cmd.arg(package); - } - - run_command(cmd, "pip install"); - } -} - -/// Download PyTorch in a Python virtualenv, and return the -/// CMAKE_PREFIX_PATH for the corresponding libtorch -pub fn setup_torch_pip(python: &Path) -> PathBuf { - let torch_version = std::env::var("METATOMIC_TESTS_TORCH_VERSION").unwrap_or("2.12".into()); - pip_install( - python, - &[&format!("torch=={}.*", torch_version)], - PipInstallOptions { upgrade: true, no_deps: false, no_build_isolation: false } - ); - - let mut cmd = Command::new(python); - cmd.arg("-c"); - cmd.arg("import torch; print(torch.utils.cmake_prefix_path)"); - - let output = run_command(cmd, "python to get torch cmake prefix"); - - let stdout = String::from_utf8_lossy(&output.stdout); - let prefix = PathBuf::from(stdout.trim()); - if !prefix.exists() { - panic!("'torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); - } - - return prefix; -} - -/// Install metatensor in a Python virtualenv with pip, and return the -/// CMAKE_PREFIX_PATH for the installed libmetatensor. -pub fn setup_metatensor_pip(python: &Path) -> PathBuf { - pip_install(python, &["metatensor-core >=0.2.0,<0.3"], PipInstallOptions::default()); - - let mut cmd = Command::new(python); - cmd.arg("-c"); - cmd.arg("import metatensor; print(metatensor.utils.cmake_prefix_path)"); - - let output = run_command(cmd, "python to get metatensor cmake prefix"); - - let stdout = String::from_utf8_lossy(&output.stdout); - let prefix = PathBuf::from(stdout.trim()); - if !prefix.exists() { - panic!("'metatensor.utils.cmake_prefix' at '{}' does not exist", prefix.display()); - } - - return prefix; -} - -/// Install metatensor-torch in a Python virtualenv with pip, and return the -/// CMAKE_PREFIX_PATH for the installed libmetatensor_torch. 
-pub fn setup_metatensor_torch_pip(python: &Path) -> PathBuf { - pip_install(python, &["metatensor-torch >=0.9.0,<0.10"], PipInstallOptions::default()); - - let mut cmd = Command::new(python); - cmd.arg("-c"); - cmd.arg("import metatensor.torch; print(metatensor.torch.utils.cmake_prefix_path)"); - - let output = run_command(cmd, "python to get metatensor_torch cmake prefix"); - - let stdout = String::from_utf8_lossy(&output.stdout); - let prefix = PathBuf::from(stdout.trim()); - if !prefix.exists() { - panic!("'metatensor.torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); - } - - return prefix; -} - -/// Build metatomic-torch located in `source_dir` inside `build_dir`, and return -/// the installation prefix. -pub fn setup_metatomic_torch_cmake(source_dir: &Path, build_dir: &Path, cmake_args: Vec) -> PathBuf { - std::fs::create_dir_all(build_dir).expect("failed to create metatomic build dir"); - - // configure cmake for metatomic-torch - let mut cmake_config = cmake_config(source_dir, build_dir); - - let install_prefix = build_dir.join("usr"); - cmake_config.arg(format!("-DCMAKE_INSTALL_PREFIX={}", install_prefix.display())); - - // Add any additional cmake arguments - for arg in cmake_args { - cmake_config.arg(arg); - } - - run_command(cmake_config, "cmake configuration for metatomic_torch"); - - // build and install metatomic-torch - let mut cmake_build = cmake_build(build_dir); - cmake_build.arg("--target"); - cmake_build.arg("install"); - - run_command(cmake_build, "cmake build for metatomic_torch"); - - install_prefix -} - - -/// Install metatomic-torch in a Python virtualenv with pip, and return the -/// CMAKE_PREFIX_PATH for the installed libmetatomic_torch. 
-pub fn setup_metatomic_torch_pip(python: &Path, source_dir: &Path) -> PathBuf { - pip_install(python, &["setuptools>=77", "packaging>=23", "cmake"], PipInstallOptions::default()); - - pip_install( - python, - &[&source_dir.display().to_string()], - PipInstallOptions { - upgrade: true, - no_deps: false, - no_build_isolation: true - } - ); - - let mut cmd = Command::new(python); - cmd.arg("-c"); - cmd.arg("import metatomic.torch; print(metatomic.torch.utils.cmake_prefix_path)"); - - let output = run_command(cmd, "python to get metatomic_torch cmake prefix"); - - let stdout = String::from_utf8_lossy(&output.stdout); - let prefix = PathBuf::from(stdout.trim()); - if !prefix.exists() { - panic!("'metatomic.torch.utils.cmake_prefix' at '{}' does not exist", prefix.display()); - } - - return prefix; -} - - -pub fn run_command(mut command: Command, context: &str) -> std::process::Output { - write!(std::io::stdout().lock(), "\n\n[Running] {:?}\n\n", command).unwrap(); - - let mut child = command - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn().unwrap_or_else(|_| panic!("failed to spawn {}", context)); - - let mut child_stdout = child.stdout.take().expect("missing stdout"); - let mut child_stderr = child.stderr.take().expect("missing stderr"); - - let out_handle = std::thread::spawn(move || -> std::io::Result> { - let mut buf = [0u8; 8192]; - let mut captured = Vec::new(); - let mut sink = std::io::stdout().lock(); - loop { - let n = child_stdout.read(&mut buf)?; - if n == 0 { - break; - } - sink.write_all(&buf[..n])?; - sink.flush()?; - captured.extend_from_slice(&buf[..n]); - } - Ok(captured) - }); - - let err_handle = std::thread::spawn(move || -> std::io::Result> { - let mut buf = [0u8; 8192]; - let mut captured = Vec::new(); - let mut sink = std::io::stderr().lock(); - loop { - let n = child_stderr.read(&mut buf)?; - if n == 0 { - break; - } - sink.write_all(&buf[..n])?; - sink.flush()?; - captured.extend_from_slice(&buf[..n]); - } - Ok(captured) - }); - 
- let status = child.wait().unwrap_or_else(|_| panic!("failed to run {}", context)); - let stdout = String::from_utf8_lossy(&out_handle.join().unwrap().unwrap()).into_owned(); - let stderr = String::from_utf8_lossy(&err_handle.join().unwrap().unwrap()).into_owned(); - - if !status.success() { - panic!( - "{} failed, status: {}\nstderr:\n\n{}\nstdout:\n\n{}\n", - context, status, stderr, stdout - ); - } - - return std::process::Output { status, stdout: stdout.into_bytes(), stderr: stderr.into_bytes() }; -} diff --git a/metatomic-torch/tests/utils/mod.rs b/metatomic-torch/tests/utils/mod.rs new file mode 120000 index 000000000..20b8b0094 --- /dev/null +++ b/metatomic-torch/tests/utils/mod.rs @@ -0,0 +1 @@ +../../../metatomic-core/tests/utils/mod.rs \ No newline at end of file diff --git a/python/metatomic_torch/build-backend/backend.py b/python/metatomic_torch/build-backend/backend.py index c762d91e6..be0389a2c 100644 --- a/python/metatomic_torch/build-backend/backend.py +++ b/python/metatomic_torch/build-backend/backend.py @@ -1,11 +1,24 @@ # This is a custom Python build backend wrapping setuptool's to only depend on # torch/metatensor-torch when building the wheel and not the sdist import os +import pathlib from setuptools import build_meta -ROOT = os.path.realpath(os.path.dirname(__file__)) +ROOT = pathlib.Path(__file__).parent.resolve() + +METATOMIC_CORE = (ROOT / ".." / ".." 
/ "metatomic_core").resolve() +METATOMIC_NO_LOCAL_DEPS = os.environ.get("METATOMIC_NO_LOCAL_DEPS", "0") == "1" + + +if not METATOMIC_NO_LOCAL_DEPS and METATOMIC_CORE.exists(): + # we are building from a git checkout + METATOMIC_CORE_DEP = f"metatomic-core @ {METATOMIC_CORE.as_uri()}" +else: + # we are building from a sdist + METATOMIC_CORE_DEP = "metatomic-core >=0.1.0,<0.2" + FORCED_TORCH_VERSION = os.environ.get("METATOMIC_TORCH_BUILD_WITH_TORCH_VERSION") if FORCED_TORCH_VERSION is not None: @@ -27,7 +40,7 @@ # Special dependencies to build the wheels def get_requires_for_build_wheel(config_settings=None): defaults = build_meta.get_requires_for_build_wheel(config_settings) - return defaults + [TORCH_DEP] + return defaults + [TORCH_DEP, METATOMIC_CORE_DEP] def build_editable(wheel_directory, config_settings=None, metadata_directory=None): From 6400c695680f21b3b51edbb1ba53ee92b983b569 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Wed, 27 May 2026 16:25:27 +0200 Subject: [PATCH 05/20] Draft the C API for metatomic-core --- .github/workflows/rust-tests.yml | 186 ++++++++++++++ .github/workflows/torch-tests.yml | 9 +- docs/Doxyfile | 4 +- docs/src/core/CHANGELOG.md | 1 + docs/src/core/index.rst | 17 ++ docs/src/core/reference/c/index.rst | 17 ++ docs/src/core/reference/c/misc.rst | 56 +++++ docs/src/core/reference/c/model.rst | 16 ++ docs/src/core/reference/c/plugin.rst | 16 ++ docs/src/core/reference/c/system.rst | 42 ++++ docs/src/index.rst | 1 + metatomic-core/Cargo.toml | 11 + metatomic-core/build.rs | 1 + .../cmake/metatomic-config.in.cmake | 2 + metatomic-core/include/metatomic.h | 237 ++++++++++++++++++ metatomic-core/include/metatomic.hpp | 4 +- metatomic-core/include/metatomic/plugin.hpp | 7 + metatomic-core/include/metatomic/utils.hpp | 7 + metatomic-core/src/c_api/mod.rs | 25 +- metatomic-core/src/c_api/model.rs | 76 ++++++ metatomic-core/src/c_api/plugin.rs | 41 +++ metatomic-core/src/c_api/status.rs | 42 ++++ metatomic-core/src/c_api/system.rs | 
131 ++++++++++ metatomic-core/src/c_api/utils.rs | 102 ++++++++ metatomic-core/src/lib.rs | 43 +++- metatomic-core/src/metadata.rs | 132 ++++++++++ metatomic-core/src/model.rs | 20 ++ metatomic-core/src/plugin.rs | 37 +++ metatomic-core/src/system.rs | 53 ++++ metatomic-core/src/units.rs | 7 + metatomic-core/tests/check-cxx-install.rs | 8 +- metatomic-torch/tests/check-torch-install.rs | 31 ++- rustfmt.toml | 1 + scripts/check-c-api-docs.py | 101 ++++++++ scripts/include/README | 4 + scripts/include/metatensor.h | 8 + scripts/include/metatomic/version.h | 0 scripts/include/stdarg.h | 0 scripts/include/stdbool.h | 1 + scripts/include/stddef.h | 6 + scripts/include/stdint.h | 7 + scripts/include/stdlib.h | 1 + 42 files changed, 1475 insertions(+), 36 deletions(-) create mode 100644 .github/workflows/rust-tests.yml create mode 120000 docs/src/core/CHANGELOG.md create mode 100644 docs/src/core/index.rst create mode 100644 docs/src/core/reference/c/index.rst create mode 100644 docs/src/core/reference/c/misc.rst create mode 100644 docs/src/core/reference/c/model.rst create mode 100644 docs/src/core/reference/c/plugin.rst create mode 100644 docs/src/core/reference/c/system.rst create mode 100644 metatomic-core/include/metatomic/plugin.hpp create mode 100644 metatomic-core/include/metatomic/utils.hpp create mode 100644 metatomic-core/src/c_api/model.rs create mode 100644 metatomic-core/src/c_api/plugin.rs create mode 100644 metatomic-core/src/c_api/status.rs create mode 100644 metatomic-core/src/c_api/system.rs create mode 100644 metatomic-core/src/c_api/utils.rs create mode 100644 metatomic-core/src/metadata.rs create mode 100644 metatomic-core/src/model.rs create mode 100644 metatomic-core/src/plugin.rs create mode 100644 metatomic-core/src/system.rs create mode 100644 metatomic-core/src/units.rs create mode 100644 rustfmt.toml create mode 100755 scripts/check-c-api-docs.py create mode 100644 scripts/include/README create mode 100644 scripts/include/metatensor.h create 
mode 100644 scripts/include/metatomic/version.h create mode 100644 scripts/include/stdarg.h create mode 100644 scripts/include/stdbool.h create mode 100644 scripts/include/stddef.h create mode 100644 scripts/include/stdint.h create mode 100644 scripts/include/stdlib.h diff --git a/.github/workflows/rust-tests.yml b/.github/workflows/rust-tests.yml new file mode 100644 index 000000000..77ac78387 --- /dev/null +++ b/.github/workflows/rust-tests.yml @@ -0,0 +1,186 @@ +name: Rust tests + +on: + push: + branches: [main] + pull_request: + # Check all PR + +concurrency: + group: rust-tests-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +jobs: + rust-tests: + name: ${{ matrix.os }} / Rust ${{ matrix.rust-version }}${{ matrix.extra-name }} + runs-on: ${{ matrix.os }} + container: ${{ matrix.container }} + defaults: + run: + shell: "bash" + env: + CMAKE_CXX_COMPILER: ${{ matrix.cxx }} + CMAKE_C_COMPILER: ${{ matrix.cc }} + CMAKE_GENERATOR: ${{ matrix.cmake-generator }} + strategy: + matrix: + include: + - os: ubuntu-24.04 + rust-version: stable + rust-target: x86_64-unknown-linux-gnu + cxx: g++ + cc: gcc + cmake-generator: Unix Makefiles + + # check the build on a stock Ubuntu 22.04, which uses cmake 3.22, and + # with our minimal supported rust version + - os: ubuntu-24.04 + rust-version: 1.74 + container: ubuntu:22.04 + rust-target: x86_64-unknown-linux-gnu + extra-name: ", cmake 3.22" + cxx: g++ + cc: gcc + cmake-generator: Unix Makefiles + + - os: macos-15 + rust-version: stable + rust-target: aarch64-apple-darwin + extra-name: "" + cxx: clang++ + cc: clang + cmake-generator: Unix Makefiles + + # - os: windows-2022 + # rust-version: stable + # rust-target: x86_64-pc-windows-msvc + # extra-name: " / MSVC" + # cxx: cl.exe + # cc: cl.exe + # cmake-generator: Visual Studio 17 2022 + + # - os: windows-2022 + # rust-version: stable + # rust-target: x86_64-pc-windows-gnu + # extra-name: " / MinGW" + # cxx: g++.exe + # cc: gcc.exe + # 
cmake-generator: MinGW Makefiles + steps: + - name: install dependencies in container + if: matrix.container == 'ubuntu:22.04' + run: | + apt update + apt install -y software-properties-common + apt install -y cmake make gcc g++ git curl python3-venv + + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Configure git safe directory + if: matrix.container == 'ubuntu:22.04' + run: git config --global --add safe.directory /__w/metatomic/metatomic + + - name: setup rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust-version }} + target: ${{ matrix.rust-target }} + + - name: setup Python + uses: actions/setup-python@v6 + if: matrix.container == null + with: + # Python 3.14.5 fails with "No module named pip.__main__; 'pip' is a + # package and cannot be directly executed" when using a venv, so we + # use 3.14.4 for now + python-version: "3.14.4" + + - name: Cache Rust dependencies + uses: Leafwing-Studios/cargo-cache@v2.6.1 + with: + sweep-cache: true + + - name: install valgrind + if: matrix.do-valgrind + run: | + sudo apt-get install -y valgrind + + - name: Setup sccache + if: ${{ !env.ACT }} + uses: mozilla-actions/sccache-action@v0.0.10 + with: + version: "v0.15.0" + + - name: Setup sccache environnement variables + if: ${{ !env.ACT }} + run: | + echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV + echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV + echo "CMAKE_C_COMPILER_LAUNCHER=sccache" >> $GITHUB_ENV + echo "CMAKE_CXX_COMPILER_LAUNCHER=sccache" >> $GITHUB_ENV + + - name: run tests + run: | + cargo test --package metatomic-core --target ${{ matrix.rust-target }} + env: + RUST_BACKTRACE: full + + - name: check that the header was already up to date + run: | + git diff --exit-code + + # check that the C API declarations are correctly documented and used + prevent-bitrot: + runs-on: ubuntu-24.04 + name: check C API declarations + steps: + - uses: actions/checkout@v6 + + - name: setup Python + uses: actions/setup-python@v6 + with: + 
python-version: "3.14" + + - name: install python dependencies + run: | + pip install pycparser + + - name: check that C API functions are all documented + run: | + python scripts/check-c-api-docs.py + + # make sure no debug print stays in the code + check-debug-prints: + runs-on: ubuntu-24.04 + name: check leftover debug print + + steps: + - uses: actions/checkout@v6 + + - name: install ripgrep + run: | + wget https://github.com/BurntSushi/ripgrep/releases/download/13.0.0/ripgrep-13.0.0-x86_64-unknown-linux-musl.tar.gz + tar xf ripgrep-13.0.0-x86_64-unknown-linux-musl.tar.gz + echo "$(pwd)/ripgrep-13.0.0-x86_64-unknown-linux-musl" >> $GITHUB_PATH + + - name: check for leftover dbg! + run: | + # use ripgrep (rg) to check for instances of `dbg!` in rust files. + # rg will return 1 if it fails to find a match, so we invert it again + # with the `!` builtin to get the error/success in CI + + ! rg "dbg!" --type=rust --quiet + + - name: check for leftover \#include + run: | + ! rg "" --iglob "\!metatomic-core/tests/cpp/external/catch/catch.hpp" --quiet + + - name: check for leftover std::cout + run: | + ! rg "cout" --iglob "\!metatomic-core/tests/cpp/external/catch/catch.hpp" --quiet + + - name: check for leftover std::cerr + run: | + ! 
rg "cerr" --iglob "\!metatomic-core/tests/cpp/external/catch/catch.hpp" --quiet diff --git a/.github/workflows/torch-tests.yml b/.github/workflows/torch-tests.yml index ce3781709..b088f7501 100644 --- a/.github/workflows/torch-tests.yml +++ b/.github/workflows/torch-tests.yml @@ -20,7 +20,10 @@ jobs: include: - os: ubuntu-24.04 torch-version: "2.12" - python-version: "3.14" + # Python 3.14.5 fails with "No module named pip.__main__; 'pip' is a + # package and cannot be directly executed" when using a venv, so we + # use 3.14.4 for now + python-version: "3.14.4" cargo-test-flags: --release do-valgrind: true @@ -33,12 +36,12 @@ jobs: - os: macos-15 torch-version: "2.12" - python-version: "3.14" + python-version: "3.14.4" cargo-test-flags: --release - os: windows-2022 torch-version: "2.12" - python-version: "3.14" + python-version: "3.14.4" cargo-test-flags: --release steps: - name: install dependencies in container diff --git a/docs/Doxyfile b/docs/Doxyfile index f48f15ed9..5cf71fe6f 100644 --- a/docs/Doxyfile +++ b/docs/Doxyfile @@ -991,7 +991,9 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../metatomic-torch/include/metatomic \ +INPUT = ../metatomic-core/include/ \ + ../metatomic-core/include/metatomic \ + ../metatomic-torch/include/metatomic \ ../metatomic-torch/include/metatomic/torch # This tag can be used to specify the character encoding of the source files diff --git a/docs/src/core/CHANGELOG.md b/docs/src/core/CHANGELOG.md new file mode 120000 index 000000000..a344bc46b --- /dev/null +++ b/docs/src/core/CHANGELOG.md @@ -0,0 +1 @@ +../../../metatomic-core/CHANGELOG.md \ No newline at end of file diff --git a/docs/src/core/index.rst b/docs/src/core/index.rst new file mode 100644 index 000000000..60512b353 --- /dev/null +++ b/docs/src/core/index.rst @@ -0,0 +1,17 @@ +Core Classes +============ + +WIP + + +.. toctree:: + :maxdepth: 2 + + reference/c/index + + +.. 
toctree:: + :maxdepth: 1 + :hidden: + + CHANGELOG.md diff --git a/docs/src/core/reference/c/index.rst b/docs/src/core/reference/c/index.rst new file mode 100644 index 000000000..f190a5e74 --- /dev/null +++ b/docs/src/core/reference/c/index.rst @@ -0,0 +1,17 @@ +.. _c-api-core: + +C API reference +=============== + +WIP + +The functions and types provided in ``metatomic.h`` can be grouped in four +main groups: + +.. toctree:: + :maxdepth: 1 + + system + model + plugin + misc diff --git a/docs/src/core/reference/c/misc.rst b/docs/src/core/reference/c/misc.rst new file mode 100644 index 000000000..6aec886bc --- /dev/null +++ b/docs/src/core/reference/c/misc.rst @@ -0,0 +1,56 @@ +Miscellaneous +============= + +Version number +^^^^^^^^^^^^^^ + +.. doxygenfunction:: mta_version + +.. c:macro:: METATOMIC_VERSION + + Macro containing the compile-time version of metatomic, as a string + +.. c:macro:: METATOMIC_VERSION_MAJOR + + Macro containing the compile-time **major** version number of metatomic, as + an integer + +.. c:macro:: METATOMIC_VERSION_MINOR + + Macro containing the compile-time **minor** version number of metatomic, as + an integer + +.. c:macro:: METATOMIC_VERSION_PATCH + + Macro containing the compile-time **patch** version number of metatomic, as + an integer + + +Error handling +^^^^^^^^^^^^^^ + +.. doxygenfunction:: mta_last_error + +.. doxygenfunction:: mta_set_last_error + +.. doxygenenum:: mta_status_t + + +String manipulation +^^^^^^^^^^^^^^^^^^^ + +.. doxygentypedef:: mta_string_t + +.. doxygenfunction:: mta_string_create + +.. doxygenfunction:: mta_string_free + +.. doxygenfunction:: mta_string_view + +.. doxygenfunction:: mta_format_metadata + + +Unit conversion +^^^^^^^^^^^^^^^ + +.. 
doxygenfunction:: mta_unit_conversion_factor diff --git a/docs/src/core/reference/c/model.rst b/docs/src/core/reference/c/model.rst new file mode 100644 index 000000000..6a3d9ee38 --- /dev/null +++ b/docs/src/core/reference/c/model.rst @@ -0,0 +1,16 @@ +Model +===== + +.. doxygenstruct:: mta_model_t + :members: + +The following functions operate on :c:type:`mta_model_t`: + +- :c:func:`mta_load_model`: TODO summary +- :c:func:`mta_execute_model`: TODO summary + +-------------------------------------------------------------------------------- + +.. doxygenfunction:: mta_load_model + +.. doxygenfunction:: mta_execute_model diff --git a/docs/src/core/reference/c/plugin.rst b/docs/src/core/reference/c/plugin.rst new file mode 100644 index 000000000..952650f4c --- /dev/null +++ b/docs/src/core/reference/c/plugin.rst @@ -0,0 +1,16 @@ +Plugin system +============= + +.. doxygenstruct:: mta_plugin_t + :members: + +The following functions operate on :c:type:`mta_plugin_t`: + +- :c:func:`mta_register_plugin`: TODO summary +- :c:func:`mta_load_plugin`: TODO summary + +-------------------------------------------------------------------------------- + +.. doxygenfunction:: mta_register_plugin + +.. doxygenfunction:: mta_load_plugin diff --git a/docs/src/core/reference/c/system.rst b/docs/src/core/reference/c/system.rst new file mode 100644 index 000000000..155245253 --- /dev/null +++ b/docs/src/core/reference/c/system.rst @@ -0,0 +1,42 @@ +System +====== + +.. 
doxygentypedef:: mta_system_t + +The following functions operate on :c:type:`mta_system_t`: + +- :c:func:`mta_system_create`: TODO summary +- :c:func:`mta_system_free`: TODO summary +- :c:func:`mta_system_size`: TODO summary +- :c:func:`mta_system_get_data`: TODO summary +- :c:func:`mta_system_get_length_unit`: TODO summary +- :c:func:`mta_system_add_pairs`: TODO summary +- :c:func:`mta_system_get_pairs`: TODO summary +- :c:func:`mta_system_known_pairs`: TODO summary +- :c:func:`mta_system_add_custom_data`: TODO summary +- :c:func:`mta_system_get_custom_data`: TODO summary +- :c:func:`mta_system_known_custom_data`: TODO summary + +-------------------------------------------------------------------------------- + +.. doxygenfunction:: mta_system_create + +.. doxygenfunction:: mta_system_free + +.. doxygenfunction:: mta_system_size + +.. doxygenfunction:: mta_system_get_data + +.. doxygenfunction:: mta_system_get_length_unit + +.. doxygenfunction:: mta_system_add_pairs + +.. doxygenfunction:: mta_system_get_pairs + +.. doxygenfunction:: mta_system_known_pairs + +.. doxygenfunction:: mta_system_add_custom_data + +.. doxygenfunction:: mta_system_get_custom_data + +.. doxygenfunction:: mta_system_known_custom_data diff --git a/docs/src/index.rst b/docs/src/index.rst index 170c25c19..441356c29 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -92,6 +92,7 @@ existing trained models, look into the metatrain_ project instead. 
overview installation + core/index torch/index quantities/index engines/index diff --git a/metatomic-core/Cargo.toml b/metatomic-core/Cargo.toml index 2a32c1c09..2335505a3 100644 --- a/metatomic-core/Cargo.toml +++ b/metatomic-core/Cargo.toml @@ -14,12 +14,23 @@ name = "metatomic" bench = false [dependencies] +metatensor = { version = "0.3.0" } once_cell = "1" +dlpk = "0.3" +json = "0.12" [build-dependencies] cbindgen = { version = "0.29", default-features = false } +# the last versions that supports Rust 1.74 +serde_spanned = "=1.0.1" +toml = "=0.9.6" +toml_datetime = "=0.7.1" +toml_parser = "=1.0.2" +toml_writer = "=1.0.2" +tempfile = "=3.24.0" +indexmap = "=2.11.4" [dev-dependencies] lazy_static = "1" diff --git a/metatomic-core/build.rs b/metatomic-core/build.rs index edec71e60..b92cc2925 100644 --- a/metatomic-core/build.rs +++ b/metatomic-core/build.rs @@ -22,6 +22,7 @@ fn main() { config.documentation_style = cbindgen::DocumentationStyle::Doxy; config.line_endings = cbindgen::LineEndingStyle::LF; config.autogen_warning = Some(generated_comment.into()); + config.includes.push("metatensor.h".into()); config.includes.push("metatomic/version.h".into()); let result = cbindgen::Builder::new() diff --git a/metatomic-core/cmake/metatomic-config.in.cmake b/metatomic-core/cmake/metatomic-config.in.cmake index 310f54364..90fca167a 100644 --- a/metatomic-core/cmake/metatomic-config.in.cmake +++ b/metatomic-core/cmake/metatomic-config.in.cmake @@ -46,6 +46,7 @@ if (@METATOMIC_INSTALL_BOTH_STATIC_SHARED@ OR @BUILD_SHARED_LIBS@) ) target_compile_features(metatomic::shared INTERFACE cxx_std_17) + target_link_libraries(metatomic::shared INTERFACE metatensor) if (WIN32) if (NOT EXISTS ${METATOMIC_IMPLIB_LOCATION}) @@ -74,6 +75,7 @@ if (@METATOMIC_INSTALL_BOTH_STATIC_SHARED@ OR NOT @BUILD_SHARED_LIBS@) ) target_compile_features(metatomic::static INTERFACE cxx_std_17) + target_link_libraries(metatomic::static INTERFACE metatensor) endif() # Export either the shared or static 
library as the metatomic target diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index 83b6cb1fc..1e69263e1 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -12,12 +12,118 @@ #include #include #include +#include "metatensor.h" #include "metatomic/version.h" +/** + * TODO + */ +#define MTA_ABI_VERSION 1 + +typedef enum mta_status_t { + MTA_SUCCESS = 0, + MTA_ERROR_OTHER = 255, +} mta_status_t; + +/** + * TODO + */ +typedef enum mta_system_data_kind { + MTA_SYSTEM_DATA_TYPES = 0, + MTA_SYSTEM_DATA_POSITIONS = 1, + MTA_SYSTEM_DATA_CELL = 2, + MTA_SYSTEM_DATA_PBC = 3, +} mta_system_data_kind; + +/** + * TODO + */ +typedef struct mta_opaque_string_t mta_opaque_string_t; + +/** + * TODO + */ +typedef struct mta_system_t mta_system_t; + +/** + * TODO + */ +typedef struct mta_opaque_string_t *mta_string_t; + +/** + * TODO + */ +typedef struct mta_model_t { + /** + * TODO + */ + void *data; + /** + * TODO + */ + enum mta_status_t (*unload)(void *model_data); + /** + * TODO + */ + enum mta_status_t (*metadata)(const void *model_data, mta_string_t *metadata_json); + /** + * TODO + */ + enum mta_status_t (*supported_outputs)(const void *model_data, mta_string_t *outputs_json); + /** + * TODO + */ + enum mta_status_t (*requested_pair_lists)(const void *model_data, mta_string_t *pair_options_json); + /** + * TODO + */ + enum mta_status_t (*requested_inputs)(const void *model_data, mta_string_t *inputs_json); + /** + * TODO + */ + enum mta_status_t (*execute_inner)(void *model_data, + const struct mta_system_t *const *systems, + uintptr_t systems_count, + const mts_labels_t *selected_atoms, + const char *const *requested_outputs_json, + uintptr_t requested_outputs_count, + mts_tensormap_t **outputs, + uintptr_t outputs_count); +} mta_model_t; + +/** + * TODO + */ +typedef struct mta_plugin_t { + /** + * TODO + */ + const char *name; + /** + * TODO + */ + enum mta_status_t (*load_model)(const char 
*load_from, + const char *options_json, + struct mta_model_t *model); +} mta_plugin_t; + #ifdef __cplusplus extern "C" { #endif // __cplusplus +/** + * TODO + */ +enum mta_status_t mta_last_error(const char **message, const char **origin, void **data); + +/** + * TODO + */ +enum mta_status_t mta_set_last_error(const char *message, + const char *origin, + void *data, + void (*data_deleter)(void*)); + /** * Get the runtime version of the metatomic library as a string. * @@ -25,6 +131,137 @@ extern "C" { */ const char *mta_version(void); +/** + * TODO + */ +mta_string_t mta_string_create(const char *raw); + +/** + * TODO + */ +void mta_string_free(mta_string_t string); + +/** + * TODO + */ +const char *mta_string_view(mta_string_t string); + +/** + * TODO + */ +enum mta_status_t mta_unit_conversion_factor(const char *from_unit, + const char *to_unit, + double *conversion); + +/** + * TODO + */ +enum mta_status_t mta_system_create(const char *length_unit, + DLManagedTensorVersioned *types, + DLManagedTensorVersioned *positions, + DLManagedTensorVersioned *cell, + DLManagedTensorVersioned *pbc, + struct mta_system_t **system); + +/** + * TODO + */ +enum mta_status_t mta_system_free(struct mta_system_t *system); + +/** + * TODO + */ +enum mta_status_t mta_system_size(const struct mta_system_t *system, uintptr_t *size); + +/** + * TODO + */ +enum mta_status_t mta_system_get_data(const struct mta_system_t *system, + enum mta_system_data_kind request, + DLManagedTensorVersioned **data); + +/** + * TODO + */ +enum mta_status_t mta_system_get_length_unit(const struct mta_system_t *system, + mta_string_t *length_unit); + +/** + * TODO + */ +enum mta_status_t mta_system_add_pairs(struct mta_system_t *system, + const char *options, + mts_block_t *pairs); + +/** + * TODO + */ +enum mta_status_t mta_system_get_pairs(const struct mta_system_t *system, + const char *options, + const mts_block_t **pairs); + +/** + * TODO + */ +enum mta_status_t mta_system_known_pairs(const struct 
mta_system_t *system, + mta_string_t *pairs_options); + +/** + * TODO + */ +enum mta_status_t mta_system_add_custom_data(struct mta_system_t *system, + const char *name, + mts_tensormap_t *data); + +/** + * TODO + */ +enum mta_status_t mta_system_get_custom_data(const struct mta_system_t *system, + const char *name, + const mts_tensormap_t **data); + +/** + * TODO + */ +enum mta_status_t mta_system_known_custom_data(const struct mta_system_t *system, + mta_string_t *names); + +/** + * TODO + */ +enum mta_status_t mta_execute_model(struct mta_model_t model, + const struct mta_system_t *const *systems, + uintptr_t systems_count, + const mts_labels_t *selected_atoms, + const char *const *requested_outputs_json, + uintptr_t requested_outputs_count, + bool check_consistency, + mts_tensormap_t **outputs, + uintptr_t outputs_count); + +/** + * TODO + */ +enum mta_status_t mta_format_metadata(const char *metadata, mta_string_t *printed); + +/** + * TODO + */ +void mta_register_plugin(struct mta_plugin_t plugin); + +/** + * TODO + */ +enum mta_status_t mta_load_plugin(const char *path); + +/** + * TODO + */ +enum mta_status_t mta_load_model(const char *plugin_name, + const char *load_from, + const char *options_json, + struct mta_model_t *model); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/metatomic-core/include/metatomic.hpp b/metatomic-core/include/metatomic.hpp index 016f26bc5..3b5c8ac2a 100644 --- a/metatomic-core/include/metatomic.hpp +++ b/metatomic-core/include/metatomic.hpp @@ -1,2 +1,4 @@ +#include "metatomic/utils.hpp" // IWYU pragma: export #include "metatomic/system.hpp" // IWYU pragma: export -#include "metatomic/model.hpp" // IWYU pragma: export +#include "metatomic/model.hpp" // IWYU pragma: export +#include "metatomic/plugin.hpp" // IWYU pragma: export diff --git a/metatomic-core/include/metatomic/plugin.hpp b/metatomic-core/include/metatomic/plugin.hpp new file mode 100644 index 000000000..1cae91bdf --- /dev/null +++ 
b/metatomic-core/include/metatomic/plugin.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace metatomic { + +} // namespace metatomic diff --git a/metatomic-core/include/metatomic/utils.hpp b/metatomic-core/include/metatomic/utils.hpp new file mode 100644 index 000000000..1cae91bdf --- /dev/null +++ b/metatomic-core/include/metatomic/utils.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace metatomic { + +} // namespace metatomic diff --git a/metatomic-core/src/c_api/mod.rs b/metatomic-core/src/c_api/mod.rs index 33e0786dc..cf6c6176d 100644 --- a/metatomic-core/src/c_api/mod.rs +++ b/metatomic-core/src/c_api/mod.rs @@ -1,18 +1,15 @@ -use std::ffi::CString; -use std::os::raw::c_char; +mod status; +pub use self::status::mta_status_t; -use once_cell::sync::Lazy; +mod utils; +pub use self::utils::mta_string_t; +pub use self::utils::{mta_string_create, mta_string_free, mta_string_view}; +mod system; +pub use self::system::mta_system_t; -static VERSION: Lazy = Lazy::new(|| { - CString::new(env!("METATOMIC_FULL_VERSION")).expect("version contains NULL byte") -}); +mod model; +pub use self::model::mta_model_t; - -/// Get the runtime version of the metatomic library as a string. -/// -/// This version follows the `..[-]` format. 
-#[no_mangle] -pub extern "C" fn mta_version() -> *const c_char { - return VERSION.as_ptr(); -} +mod plugin; +pub use self::plugin::{mta_plugin_t, mta_register_plugin, mta_load_model}; diff --git a/metatomic-core/src/c_api/model.rs b/metatomic-core/src/c_api/model.rs new file mode 100644 index 000000000..b586f73c8 --- /dev/null +++ b/metatomic-core/src/c_api/model.rs @@ -0,0 +1,76 @@ +use std::ffi::{c_void, c_char}; +use metatensor::c_api::{mts_labels_t, mts_tensormap_t}; + +use super::{mta_status_t, mta_string_t, mta_system_t}; + +/// TODO +#[repr(C)] +#[allow(non_camel_case_types)] +pub struct mta_model_t { + /// TODO + pub data: *mut c_void, + + /// TODO + pub unload: Option mta_status_t>, + + /// TODO + pub metadata: Option mta_status_t>, + + /// TODO + pub supported_outputs: Option mta_status_t>, + + /// TODO + pub requested_pair_lists: Option mta_status_t>, + + /// TODO + pub requested_inputs: Option mta_status_t>, + + /// TODO + pub execute_inner: Option mta_status_t>, +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_execute_model( + model: mta_model_t, + systems: *const *const mta_system_t, + systems_count: usize, + selected_atoms: *const mts_labels_t, + requested_outputs_json: *const *const c_char, + requested_outputs_count: usize, + check_consistency: bool, + outputs: *mut *mut mts_tensormap_t, + outputs_count: usize, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_format_metadata( + metadata: *const c_char, + printed: *mut mta_string_t, +) -> mta_status_t { + todo!() +} diff --git a/metatomic-core/src/c_api/plugin.rs b/metatomic-core/src/c_api/plugin.rs new file mode 100644 index 000000000..6dbfc4add --- /dev/null +++ b/metatomic-core/src/c_api/plugin.rs @@ -0,0 +1,41 @@ +use std::ffi::c_char; + +use super::{mta_model_t, mta_status_t}; + +/// TODO +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct mta_plugin_t { + /// TODO + pub name: *const c_char, + + /// TODO + pub load_model: Option 
mta_status_t>, +} + +/// TODO +#[no_mangle] +pub extern "C" fn mta_register_plugin(plugin: mta_plugin_t) { + todo!() +} + +/// TODO +#[no_mangle] +pub extern "C" fn mta_load_plugin(path: *const c_char) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub extern "C" fn mta_load_model( + plugin_name: *const c_char, + load_from: *const c_char, + options_json: *const c_char, + model: *mut mta_model_t, +) -> mta_status_t { + todo!() +} diff --git a/metatomic-core/src/c_api/status.rs b/metatomic-core/src/c_api/status.rs new file mode 100644 index 000000000..0c48707cc --- /dev/null +++ b/metatomic-core/src/c_api/status.rs @@ -0,0 +1,42 @@ +use std::ffi::{c_char, c_void}; + +use crate::Error; + + +// TODO +#[allow(non_camel_case_types)] +#[repr(C)] +#[derive(PartialEq, Eq, Debug)] +pub enum mta_status_t { + MTA_SUCCESS = 0, + // ... + MTA_ERROR_OTHER = 255, +} + + +impl From for mta_status_t { + fn from(err: Error) -> Self { + todo!() + } +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_last_error( + message: *mut *const c_char, + origin: *mut *const c_char, + data: *mut *mut c_void, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_set_last_error( + message: *const c_char, + origin: *const c_char, + data: *mut c_void, + data_deleter: Option, +) -> mta_status_t { + todo!() +} diff --git a/metatomic-core/src/c_api/system.rs b/metatomic-core/src/c_api/system.rs new file mode 100644 index 000000000..1697e5155 --- /dev/null +++ b/metatomic-core/src/c_api/system.rs @@ -0,0 +1,131 @@ +use std::ffi::c_char; + +use dlpk::sys::DLManagedTensorVersioned; +use metatensor::c_api::{mts_block_t, mts_tensormap_t}; + +use crate::System; +use super::{mta_status_t, mta_string_t}; + +/// TODO +#[allow(non_camel_case_types)] +pub struct mta_system_t(pub(crate) System); + + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_create( + length_unit: *const c_char, + types: *mut DLManagedTensorVersioned, + positions: *mut 
DLManagedTensorVersioned, + cell: *mut DLManagedTensorVersioned, + pbc: *mut DLManagedTensorVersioned, + system: *mut *mut mta_system_t, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_free(system: *mut mta_system_t) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_size( + system: *const mta_system_t, + size: *mut usize, +) -> mta_status_t { + todo!() +} + +/// TODO +#[allow(non_camel_case_types)] +#[repr(C)] +#[non_exhaustive] +pub enum mta_system_data_kind { + MTA_SYSTEM_DATA_TYPES = 0, + MTA_SYSTEM_DATA_POSITIONS = 1, + MTA_SYSTEM_DATA_CELL = 2, + MTA_SYSTEM_DATA_PBC = 3, +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_get_data( + system: *const mta_system_t, + request: mta_system_data_kind, + data: *mut *mut DLManagedTensorVersioned, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_get_length_unit( + system: *const mta_system_t, + length_unit: *mut mta_string_t, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_add_pairs( + system: *mut mta_system_t, + options: *const c_char, + pairs: *mut mts_block_t, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_get_pairs( + system: *const mta_system_t, + options: *const c_char, + pairs: *mut *const mts_block_t, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_known_pairs( + system: *const mta_system_t, + pairs_options: *mut mta_string_t, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_add_custom_data( + system: *mut mta_system_t, + name: *const c_char, + data: *mut mts_tensormap_t, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_get_custom_data( + system: *const mta_system_t, + name: *const c_char, + data: *mut *const 
mts_tensormap_t, +) -> mta_status_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_system_known_custom_data( + system: *const mta_system_t, + names: *mut mta_string_t, +) -> mta_status_t { + todo!() +} + + +// TODO: mta_system_to(device, dtype) diff --git a/metatomic-core/src/c_api/utils.rs b/metatomic-core/src/c_api/utils.rs new file mode 100644 index 000000000..350a9552d --- /dev/null +++ b/metatomic-core/src/c_api/utils.rs @@ -0,0 +1,102 @@ +use std::ffi::{CString, c_char}; + +use once_cell::sync::Lazy; + +use super::mta_status_t; + + +static VERSION: Lazy = Lazy::new(|| { + CString::new(env!("METATOMIC_FULL_VERSION")).expect("version contains NULL byte") +}); + + +/// Get the runtime version of the metatomic library as a string. +/// +/// This version follows the `..[-]` format. +#[no_mangle] +pub extern "C" fn mta_version() -> *const c_char { + return VERSION.as_ptr(); +} + +/// TODO +#[allow(non_camel_case_types)] +pub struct mta_opaque_string_t(CString); + +/// TODO +#[allow(non_camel_case_types)] +#[repr(transparent)] +pub struct mta_string_t(*mut mta_opaque_string_t); + +impl std::fmt::Debug for mta_string_t { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut builder = f.debug_tuple("mta_string_t"); + + if self.0.is_null() { + builder.field(&"NULL"); + } else { + builder.field(&self.as_str()); + } + builder.finish() + } +} + +impl mta_string_t { + /// TODO + pub fn new(value: impl Into) -> Self { + let cstring = CString::new(value.into()).unwrap(); + let boxed = Box::new(mta_opaque_string_t(cstring)); + mta_string_t(Box::into_raw(boxed)) + } + + /// TODO + pub fn null() -> Self { + mta_string_t(std::ptr::null_mut()) + } + + /// TODO + pub fn as_str(&self) -> &str { + if self.0.is_null() { + return ""; + } + unsafe { + return (*(self.0)).0.to_str().expect("mta_string_t is not valid UTF8") + } + } +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_string_create( + raw: *const c_char, +) -> 
mta_string_t { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_string_free(string: mta_string_t) { + todo!() +} + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_string_view( + string: mta_string_t, +) -> *const c_char { + todo!() +} + + +/// TODO +#[no_mangle] +pub unsafe extern "C" fn mta_unit_conversion_factor( + from_unit: *const c_char, + to_unit: *const c_char, + conversion: *mut f64, +) -> mta_status_t { + todo!() +} + + + +// TODO: logging & warnings? diff --git a/metatomic-core/src/lib.rs b/metatomic-core/src/lib.rs index bc47948b4..8e4c828b7 100644 --- a/metatomic-core/src/lib.rs +++ b/metatomic-core/src/lib.rs @@ -9,5 +9,46 @@ #![allow(clippy::let_underscore_untyped, clippy::manual_let_else, clippy::empty_line_after_doc_comments)] +// To be removed later +#![allow(unused_variables, dead_code, clippy::needless_pass_by_value)] + + #[doc(hidden)] -mod c_api; +pub mod c_api; + +mod metadata; +pub use self::metadata::{ModelMetadata, Quantity, PairListOptions}; + +mod system; +pub use self::system::System; + +mod model; +pub use self::model::Model; + +mod plugin; +pub use self::plugin::{Plugin, load_plugin, load_model}; + +mod units; +pub use self::units::unit_conversion_factor; + +/// TODO +#[derive(Debug)] +pub enum Error { + // TODO +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + todo!() + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + todo!() + } + + fn cause(&self) -> Option<&dyn std::error::Error> { + self.source() + } +} diff --git a/metatomic-core/src/metadata.rs b/metatomic-core/src/metadata.rs new file mode 100644 index 000000000..d56a4674a --- /dev/null +++ b/metatomic-core/src/metadata.rs @@ -0,0 +1,132 @@ +use json::JsonValue; + +use crate::Error; + +/// TODO +pub struct PairListOptions { + /// TODO + cutoff: f64, + /// TODO + full_list: bool, + /// TODO + strict: bool, + /// TODO + 
requestors: Vec, +} + +impl std::cmp::PartialEq for PairListOptions { + fn eq(&self, other: &Self) -> bool { + self.cutoff == other.cutoff + && self.full_list == other.full_list + && self.strict == other.strict + } +} + +impl std::cmp::Eq for PairListOptions {} + +impl std::cmp::PartialOrd for PairListOptions { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl std::cmp::Ord for PairListOptions { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.cutoff.partial_cmp(&other.cutoff).expect("cutoff is NaN") + .then_with(|| self.full_list.cmp(&other.full_list)) + .then_with(|| self.strict.cmp(&other.strict)) + } +} + +// TODO +// { +// "type": "metatomic_pair_options", +// "cutoff": "0xaeabf23", <== hex of the int corresponding to the f64 bits to keep full precision +// "full_list": false, +// "strict": false, +// "requestors": ["..."] +// } +impl From for JsonValue { + fn from(value: PairListOptions) -> Self { + todo!() + } +} + +impl TryFrom for PairListOptions { + type Error = Error; + + fn try_from(value: JsonValue) -> Result { + todo!() + } +} + +// ========================================================================== // +// ========================================================================== // +// ========================================================================== // + +/// TODO +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ModelMetadata { + pub name: String, + // TODO +} + +// { +// "type": "metatomic_model_metadata", +// "name": "...", +// "authors": ["..."], +// "references": { +// "implementation": ["..."], +// "architecture": ["..."], +// "model": ["..."] +// }, +// "extra": { +// "key...": "value..." 
+// } +// }, +impl From for JsonValue { + fn from(value: ModelMetadata) -> Self { + todo!() + } +} + +impl TryFrom for ModelMetadata { + type Error = Error; + + fn try_from(value: JsonValue) -> Result { + todo!() + } +} + +// ========================================================================== // +// ========================================================================== // +// ========================================================================== // + +/// TODO, previously `ModelOutput` +#[derive(Debug)] +pub struct Quantity { + pub name: String, + // TODO +} + +// TODO: +// { +// "type": "metatomic_quantity", +// "name": "...", +// "unit": "...", +// "gradients": ["...", "..."], +// "sample_kind": "atom" | "system" | "atom-pair", +// }, +impl From for JsonValue { + fn from(value: Quantity) -> Self { + todo!() + } +} + +impl TryFrom for Quantity { + type Error = Error; + + fn try_from(value: JsonValue) -> Result { + todo!() + } +} diff --git a/metatomic-core/src/model.rs b/metatomic-core/src/model.rs new file mode 100644 index 000000000..16b208ac1 --- /dev/null +++ b/metatomic-core/src/model.rs @@ -0,0 +1,20 @@ +use metatensor::{Labels, TensorMap}; + +use crate::{Error, Quantity, System}; + +use crate::c_api::mta_model_t; + +/// TODO +pub struct Model(pub(crate) mta_model_t); + + +/// TODO +pub fn execute_model( + model: &Model, + systems: &[System], + selected_atoms: Option, + requested_outputs: &[Quantity], + check_consistency: bool, +) -> Result, Error> { + todo!() +} diff --git a/metatomic-core/src/plugin.rs b/metatomic-core/src/plugin.rs new file mode 100644 index 000000000..60d145803 --- /dev/null +++ b/metatomic-core/src/plugin.rs @@ -0,0 +1,37 @@ +use std::collections::BTreeMap; + +use crate::c_api::mta_plugin_t; +use crate::{Error, Model}; + +/// TODO +pub const MTA_ABI_VERSION: i32 = 1; + +/// TODO +pub struct Plugin(mta_plugin_t); + +impl Plugin { + /// TODO + pub fn new(c_plugin: mta_plugin_t) -> Self { + Self(c_plugin) + } + + /// TODO + 
pub fn name(&self) -> &str { + todo!() + } + + /// TODO + pub fn load_model(&self, load_from: &str, options: BTreeMap) -> Result { + todo!() + } +} + +/// TODO +pub fn load_plugin(path: &str) -> Result<(), Error> { + todo!() +} + +/// TODO +pub fn load_model(plugin: Option<&str>, load_from: &str, options: BTreeMap) -> Result { + todo!() +} diff --git a/metatomic-core/src/system.rs b/metatomic-core/src/system.rs new file mode 100644 index 000000000..30677f5f9 --- /dev/null +++ b/metatomic-core/src/system.rs @@ -0,0 +1,53 @@ +use std::collections::{BTreeMap, HashMap}; + +use dlpk::DLPackTensor; +use metatensor::{TensorBlock, TensorMap}; + +use crate::PairListOptions; + + +/// TODO +pub struct System { + length_unit: String, + types: DLPackTensor, + positions: DLPackTensor, + cell: DLPackTensor, + pbc: DLPackTensor, + + pairs: BTreeMap, + custom_data: HashMap, +} + + +impl System { + /// TODO + pub fn new( + length_unit: String, + types: DLPackTensor, + positions: DLPackTensor, + cell: DLPackTensor, + pbc: DLPackTensor + ) -> Self { + todo!() + } + + /// TODO + pub fn add_pairs(&mut self, options: PairListOptions, pairs: TensorBlock, check_consistency: bool) { + todo!() + } + + /// TODO + pub fn get_pairs(&mut self, options: PairListOptions) -> Option<&TensorBlock> { + todo!() + } + + /// TODO + pub fn set_custom_data(&mut self, name: String, data: TensorMap) { + todo!() + } + + /// TODO + pub fn get_custom_data(&self, name: &str) -> Option<&TensorMap> { + todo!() + } +} diff --git a/metatomic-core/src/units.rs b/metatomic-core/src/units.rs new file mode 100644 index 000000000..d06eab413 --- /dev/null +++ b/metatomic-core/src/units.rs @@ -0,0 +1,7 @@ +use crate::Error; + + +/// TODO +pub fn unit_conversion_factor(from_unit: &str, to_unit: &str) -> Result { + todo!() +} diff --git a/metatomic-core/tests/check-cxx-install.rs b/metatomic-core/tests/check-cxx-install.rs index d66f4883b..6baa5b4e1 100644 --- a/metatomic-core/tests/check-cxx-install.rs +++ 
b/metatomic-core/tests/check-cxx-install.rs @@ -23,19 +23,21 @@ fn check_cxx_install() { const CARGO_TARGET_TMPDIR: &str = env!("CARGO_TARGET_TMPDIR"); - // ====================================================================== // - // build and install metatensor with cmake let mut build_dir = PathBuf::from(CARGO_TARGET_TMPDIR); build_dir.push("cxx-install"); build_dir.push("cmake-find-package"); std::fs::create_dir_all(&build_dir).expect("failed to create build dir"); + // ====================================================================== // + // install dependencies with pip let deps_dir = build_dir.join("deps"); let virtualenv_dir = deps_dir.join("virtualenv"); std::fs::create_dir_all(&virtualenv_dir).expect("failed to create virtualenv dir"); let python_exe = utils::create_python_venv(virtualenv_dir); let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python_exe); + // ====================================================================== // + // build and install metatomic with cmake let metatomic_dep = deps_dir.join("metatomic-core"); let source_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); @@ -54,7 +56,7 @@ fn check_cxx_install() { cmake_config.arg(format!("-DCMAKE_PREFIX_PATH={};{}", metatensor_cmake_prefix.display(), metatomic_cmake_prefix.display())); utils::run_command(cmake_config, "cmake configuration"); - // build the code, linking to metatensor + // build the code, linking to metatomic let cmake_build = utils::cmake_build(&build_dir); utils::run_command(cmake_build, "cmake build"); diff --git a/metatomic-torch/tests/check-torch-install.rs b/metatomic-torch/tests/check-torch-install.rs index ad8cfb604..14e85628a 100644 --- a/metatomic-torch/tests/check-torch-install.rs +++ b/metatomic-torch/tests/check-torch-install.rs @@ -24,14 +24,13 @@ fn check_torch_install() { const CARGO_TARGET_TMPDIR: &str = env!("CARGO_TARGET_TMPDIR"); let cargo_manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); - // 
====================================================================== // - // build and install metatensor-torch with cmake let mut build_dir = PathBuf::from(CARGO_TARGET_TMPDIR); build_dir.push("torch-install"); build_dir.push("cmake-find-package"); std::fs::create_dir_all(&build_dir).expect("failed to create build dir"); - + // ====================================================================== // + // install dependencies with pip let deps_dir = build_dir.join("deps"); let torch_dep = deps_dir.join("virtualenv"); @@ -41,7 +40,8 @@ fn check_torch_install() { let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python); let metatensor_torch_cmake_prefix = utils::setup_metatensor_torch_pip(&python); - // configure cmake for metatomic-torch + // ====================================================================== // + // build and install metatomic-torch with cmake let metatomic_torch_dep = deps_dir.join("metatomic-torch"); let cmake_options = vec![ @@ -68,7 +68,7 @@ fn check_torch_install() { ); // ====================================================================== // - // // try to use the installed metatomic-torch from cmake + // try to use the installed metatomic-torch from cmake let mut source_dir = PathBuf::from(&cargo_manifest_dir); source_dir.extend(["tests", "cmake-project"]); @@ -93,7 +93,7 @@ fn check_torch_install() { utils::run_command(ctest, "ctest"); } -/// Same as above, but using pre-built metatensor-torch from the Python wheel, +/// Same as above, but using metatomic-torch from the Python wheel, /// instead of building it from source with cmake. 
#[test] fn check_python_install() { @@ -106,13 +106,13 @@ fn check_python_install() { const CARGO_TARGET_TMPDIR: &str = env!("CARGO_TARGET_TMPDIR"); - // ====================================================================== // - // build and install metatensor and metatensor-torch with pip let mut build_dir = PathBuf::from(CARGO_TARGET_TMPDIR); build_dir.push("torch-install"); build_dir.push("python-wheels"); std::fs::create_dir_all(&build_dir).expect("failed to create build dir"); + // ====================================================================== // + // install dependencies with pip let mut venv_dir = build_dir.clone(); venv_dir.push("virtualenv"); @@ -123,6 +123,8 @@ fn check_python_install() { let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python_exe); let metatensor_torch_cmake_prefix = utils::setup_metatensor_torch_pip(&python_exe); + // ====================================================================== // + // build and install metatomic and metatomic-torch with pip let mta_core_source_dir = cargo_manifest_dir.parent().unwrap().join("python").join("metatomic_core"); let metatomic_core_cmake_prefix = utils::setup_metatomic_core_pip(&python_exe, &mta_core_source_dir); @@ -130,7 +132,7 @@ fn check_python_install() { let metatomic_torch_cmake_prefix = utils::setup_metatomic_torch_pip(&python_exe, &mta_torch_source_dir); // ====================================================================== // - // try to use the installed metatensor-torch from cmake + // try to use the installed metatomic-torch from cmake let mut source_dir = PathBuf::from(&cargo_manifest_dir); source_dir.extend(["tests", "cmake-project"]); @@ -147,7 +149,7 @@ fn check_python_install() { utils::run_command(cmake_config, "cmake configuration"); - // build the code, linking to metatensor-torch + // build the code, linking to metatomic-torch let cmake_build = utils::cmake_build(&build_dir); utils::run_command(cmake_build, "cmake build"); @@ -175,16 +177,19 @@ fn 
check_cmake_subdirectory() { build_dir.push("cmake-subdirectory"); std::fs::create_dir_all(&build_dir).expect("failed to create build dir"); + // ====================================================================== // + // install dependencies with pip let deps_dir = build_dir.join("deps"); - let torch_dep = deps_dir.join("virtualenv"); - std::fs::create_dir_all(&torch_dep).expect("failed to create virtualenv dir"); - let python = utils::create_python_venv(torch_dep); + let virtualenv_dir = deps_dir.join("virtualenv"); + std::fs::create_dir_all(&virtualenv_dir).expect("failed to create virtualenv dir"); + let python = utils::create_python_venv(virtualenv_dir); let pytorch_cmake_prefix = utils::setup_torch_pip(&python); let metatensor_cmake_prefix = utils::setup_metatensor_pip(&python); let metatensor_torch_cmake_prefix = utils::setup_metatensor_torch_pip(&python); // ====================================================================== // + // build metatomic-torch with cmake, using add_subdirectory let cargo_manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()); let mut source_dir = PathBuf::from(&cargo_manifest_dir); source_dir.extend(["tests", "cmake-project"]); diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 000000000..c7ad93baf --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +disable_all_formatting = true diff --git a/scripts/check-c-api-docs.py b/scripts/check-c-api-docs.py new file mode 100755 index 000000000..73ee7d921 --- /dev/null +++ b/scripts/check-c-api-docs.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +""" +A small script checking that all the C API functions are documented +""" + +import os +import sys + +from pycparser import c_ast, parse_file + + +ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")) +C_API_DOCS = os.path.join(ROOT, "docs", "src", "core", "reference", "c") +FAKE_INCLUDES = [os.path.join(ROOT, "scripts", "include")] +METATOMIC_HEADER = os.path.relpath( + os.path.join(ROOT, 
"metatomic-core", "include", "metatomic.h") +) + + +ERRORS = 0 + + +def error(message): + global ERRORS + ERRORS += 1 + print(message) + + +def documented_functions(): + functions = [] + + for root, _, paths in os.walk(C_API_DOCS): + for path in paths: + with open(os.path.join(root, path), encoding="utf8") as fd: + for line in fd: + if line.startswith(".. doxygenfunction::"): + name = line.split()[2] + functions.append(name) + + return functions + + +def functions_in_outline(): + # function from the "miscellaneous" section of the docs don't require an outline + # (since they are not related to a specific struct type) + functions = [ + "mta_version", + "mta_last_error", + "mta_set_last_error", + "mta_string_create", + "mta_string_free", + "mta_string_view", + "mta_format_metadata", + "mta_unit_conversion_factor", + ] + + for root, _, paths in os.walk(C_API_DOCS): + for path in paths: + with open(os.path.join(root, path), encoding="utf8") as fd: + for line in fd: + if ":c:func:" in line: + name = line.split("`")[1] + functions.append(name) + return functions + + +def all_functions(): + cpp_args = ["-E"] + for path in FAKE_INCLUDES: + cpp_args += ["-I", path] + ast = parse_file(METATOMIC_HEADER, use_cpp=True, cpp_path="gcc", cpp_args=cpp_args) + + functions = [] + + class AstVisitor(c_ast.NodeVisitor): + def visit_Decl(self, node): + if not isinstance(node.type, c_ast.FuncDecl): + return + + if not node.name.startswith("mta_"): + return + + functions.append(node.name) + + visitor = AstVisitor() + visitor.visit(ast) + + return functions + + +if __name__ == "__main__": + docs = documented_functions() + outline = functions_in_outline() + for function in all_functions(): + if function not in docs: + error("Missing documentation for {}".format(function)) + if function not in outline: + error("Missing outline for {}".format(function)) + + if ERRORS != 0: + sys.exit(1) diff --git a/scripts/include/README b/scripts/include/README new file mode 100644 index 
000000000..d56dd0788 --- /dev/null +++ b/scripts/include/README @@ -0,0 +1,4 @@ +This directory contains fake headers used to allow pycparser to parse the code +without having to deal with all the complexity of actual stdlib implementations + +See https://eli.thegreenplace.net/2015/on-parsing-c-type-declarations-and-fake-headers for more information diff --git a/scripts/include/metatensor.h b/scripts/include/metatensor.h new file mode 100644 index 000000000..fb8e88f0d --- /dev/null +++ b/scripts/include/metatensor.h @@ -0,0 +1,8 @@ +// empty header with minimal content, to be used to parse metatomic.h + +typedef struct mts_labels_t mts_labels_t; +typedef struct mts_block_t mts_block_t; +typedef struct mts_tensormap_t mts_tensormap_t; + + +typedef struct DLManagedTensorVersioned DLManagedTensorVersioned; diff --git a/scripts/include/metatomic/version.h b/scripts/include/metatomic/version.h new file mode 100644 index 000000000..e69de29bb diff --git a/scripts/include/stdarg.h b/scripts/include/stdarg.h new file mode 100644 index 000000000..e69de29bb diff --git a/scripts/include/stdbool.h b/scripts/include/stdbool.h new file mode 100644 index 000000000..3bd41ef29 --- /dev/null +++ b/scripts/include/stdbool.h @@ -0,0 +1 @@ +typedef _Bool bool; \ No newline at end of file diff --git a/scripts/include/stddef.h b/scripts/include/stddef.h new file mode 100644 index 000000000..48b3db663 --- /dev/null +++ b/scripts/include/stddef.h @@ -0,0 +1,6 @@ +#ifndef FAKE_STDDEF_H +#define FAKE_STDDEF_H + +typedef void nullptr_t; + +#endif /* FAKE_STDDEF_H */ diff --git a/scripts/include/stdint.h b/scripts/include/stdint.h new file mode 100644 index 000000000..43ccc01dd --- /dev/null +++ b/scripts/include/stdint.h @@ -0,0 +1,7 @@ +typedef int uint64_t; +typedef int int64_t; +typedef int int32_t; +typedef int uint32_t; +typedef int uint16_t; +typedef int uint8_t; +typedef int uintptr_t; diff --git a/scripts/include/stdlib.h b/scripts/include/stdlib.h new file mode 100644 index 
000000000..8b1378917 --- /dev/null +++ b/scripts/include/stdlib.h @@ -0,0 +1 @@ + From 743aae3a3add77971f2b7a890e3cf69d6cc67eb9 Mon Sep 17 00:00:00 2001 From: Sofiia Chorna Date: Thu, 28 May 2026 16:58:17 +0200 Subject: [PATCH 06/20] Implement PairListOptions json serialization --- docs/src/core/index.rst | 1 + docs/src/core/reference/json-formats.rst | 52 ++++++ metatomic-core/src/lib.rs | 11 +- metatomic-core/src/metadata.rs | 204 +++++++++++++++++++++-- 4 files changed, 250 insertions(+), 18 deletions(-) create mode 100644 docs/src/core/reference/json-formats.rst diff --git a/docs/src/core/index.rst b/docs/src/core/index.rst index 60512b353..4d0cf24a4 100644 --- a/docs/src/core/index.rst +++ b/docs/src/core/index.rst @@ -8,6 +8,7 @@ WIP :maxdepth: 2 reference/c/index + reference/json-formats .. toctree:: diff --git a/docs/src/core/reference/json-formats.rst b/docs/src/core/reference/json-formats.rst new file mode 100644 index 000000000..d1589fa2d --- /dev/null +++ b/docs/src/core/reference/json-formats.rst @@ -0,0 +1,52 @@ +.. _core-json-formats: + +JSON data formats +================= + +Some metatomic data structures are exchanged across the C API as JSON-encoded +strings rather than dedicated C types. This page documents the exact JSON +representation of each such structure, so that engines and models written in any +language can produce and consume them. + +Pair list options +----------------- + +Options describing a requested pair list (also known as a neighbor list). This +is the JSON representation of ``PairListOptions``, used for example by +:c:func:`mta_system_set_pairs`, :c:func:`mta_system_get_pairs` and +:c:func:`mta_system_pairs_options`. + +.. code-block:: json + + { + "type": "metatomic_pair_options", + "cutoff": "0x400c000000000000", + "full_list": false, + "strict": false, + "requestors": ["my-model"] + } + +``type`` + Must be the string ``"metatomic_pair_options"``. 
+ +``cutoff`` + Cutoff radius for the pair list in the length unit of the model. Must be a + positive finite number. + + It is stored as a string containing the hexadecimal representation of the + 64-bit integer with the same bit pattern as the ``cutoff`` floating-point + value (i.e. reinterpreting the ``double`` as a ``uint64_t``). + +``full_list`` + Boolean. If ``true``, the list is a full list containing both ``i -> j`` + and ``j -> i`` for each pair, if ``false``, it is a half list containing + only ``i -> j``. + +``strict`` + Boolean. If ``true``, the list is guaranteed to contain only atoms within + the cutoff, if ``false``, it may also include some pairs slightly beyond the + cutoff. + +``requestors`` + Optional array of strings identifying who requested this pair list. May be + omitted, in which case it is treated as an empty list. diff --git a/metatomic-core/src/lib.rs b/metatomic-core/src/lib.rs index 8e4c828b7..894e70e3a 100644 --- a/metatomic-core/src/lib.rs +++ b/metatomic-core/src/lib.rs @@ -34,18 +34,23 @@ pub use self::units::unit_conversion_factor; /// TODO #[derive(Debug)] pub enum Error { - // TODO + /// Error while serializing data to or deserializing data from JSON + Serialization(String), } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - todo!() + match self { + Error::Serialization(message) => write!(f, "{}", message), + } } } impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - todo!() + match self { + Error::Serialization(_) => None, + } } fn cause(&self) -> Option<&dyn std::error::Error> { diff --git a/metatomic-core/src/metadata.rs b/metatomic-core/src/metadata.rs index d56a4674a..913edcdfc 100644 --- a/metatomic-core/src/metadata.rs +++ b/metatomic-core/src/metadata.rs @@ -2,15 +2,19 @@ use json::JsonValue; use crate::Error; -/// TODO +/// Options for the calculation of a pair list (neighbor list) +#[derive(Debug, Clone)] pub struct 
PairListOptions { - /// TODO + /// Cutoff radius for this pair list in the length unit of the model cutoff: f64, - /// TODO + /// Whether the list is a full list (contains both the pair `i -> j` and `j -> i`) + /// or a half list (contains only `i -> j`) full_list: bool, - /// TODO + /// Whether the list guarantees that only atoms within the cutoff are + /// included (strict) or may also include pairs slightly beyond the cutoff + /// (non-strict) strict: bool, - /// TODO + /// List of strings describing who requested this pair list requestors: Vec, } @@ -38,17 +42,16 @@ impl std::cmp::Ord for PairListOptions { } } -// TODO -// { -// "type": "metatomic_pair_options", -// "cutoff": "0xaeabf23", <== hex of the int corresponding to the f64 bits to keep full precision -// "full_list": false, -// "strict": false, -// "requestors": ["..."] -// } impl From for JsonValue { fn from(value: PairListOptions) -> Self { - todo!() + let mut result = JsonValue::new_object(); + result["type"] = "metatomic_pair_options".into(); + // store the bit pattern so the float round-trips exactly + result["cutoff"] = format!("{:#x}", value.cutoff.to_bits()).into(); + result["full_list"] = value.full_list.into(); + result["strict"] = value.strict.into(); + result["requestors"] = value.requestors.into(); + return result; } } @@ -56,7 +59,61 @@ impl TryFrom for PairListOptions { type Error = Error; fn try_from(value: JsonValue) -> Result { - todo!() + if !value.is_object() { + return Err(Error::Serialization( + "invalid JSON data for PairListOptions, expected an object".into() + )); + } + + if value["type"].as_str() != Some("metatomic_pair_options") { + return Err(Error::Serialization( + "'type' in JSON for PairListOptions must be 'metatomic_pair_options'".into() + )); + } + + let cutoff = value["cutoff"].as_str().ok_or_else(|| Error::Serialization( + "'cutoff' in JSON for PairListOptions must be a hex-encoded string".into() + ))?; + let bits = 
u64::from_str_radix(cutoff.strip_prefix("0x").unwrap_or(cutoff), 16) + .map_err(|_| Error::Serialization( + "'cutoff' in JSON for PairListOptions must be a hex-encoded string".into() + ))?; + let cutoff = f64::from_bits(bits); + + if !cutoff.is_finite() || cutoff <= 0.0 { + return Err(Error::Serialization( + "'cutoff' in JSON for PairListOptions must be a finite positive number".into() + )); + } + + let full_list = value["full_list"].as_bool().ok_or_else(|| Error::Serialization( + "'full_list' in JSON for PairListOptions must be a boolean".into() + ))?; + + let strict = value["strict"].as_bool().ok_or_else(|| Error::Serialization( + "'strict' in JSON for PairListOptions must be a boolean".into() + ))?; + + let mut requestors = Vec::new(); + if value.has_key("requestors") { + if !value["requestors"].is_array() { + return Err(Error::Serialization( + "'requestors' in JSON for PairListOptions must be an array".into() + )); + } + + for requestor in value["requestors"].members() { + let requestor = requestor.as_str().ok_or_else(|| Error::Serialization( + "'requestors' in JSON for PairListOptions must be an array of strings".into() + ))?; + // ignore empty strings and duplicates, keeping first-seen order + if !requestor.is_empty() && !requestors.iter().any(|r| r == requestor) { + requestors.push(requestor.to_string()); + } + } + } + + return Ok(PairListOptions { cutoff, full_list, strict, requestors }); } } @@ -130,3 +187,120 @@ impl TryFrom for Quantity { todo!() } } + + +#[cfg(test)] +mod tests { + mod pair_list_options { + use super::super::*; + + fn example() -> PairListOptions { + PairListOptions { + cutoff: 3.5, + full_list: true, + strict: false, + requestors: vec!["nl-1".to_string(), "nl-2".to_string()], + } + } + + #[test] + fn roundtrip() { + let options = example(); + let json: JsonValue = options.clone().into(); + + assert_eq!(json["type"].as_str(), Some("metatomic_pair_options")); + assert_eq!(json["cutoff"].as_str(), Some(format!("{:#x}", 
3.5_f64.to_bits()).as_str())); + assert_eq!(json["full_list"].as_bool(), Some(true)); + assert_eq!(json["strict"].as_bool(), Some(false)); + + let parsed = PairListOptions::try_from(json).unwrap(); + assert_eq!(parsed.cutoff.to_bits(), options.cutoff.to_bits()); + assert_eq!(parsed.full_list, options.full_list); + assert_eq!(parsed.strict, options.strict); + assert_eq!(parsed.requestors, options.requestors); + } + + #[test] + fn cutoff_keeps_full_precision() { + let mut options = example(); + options.cutoff = 1.0 / 3.0; + let parsed = PairListOptions::try_from(JsonValue::from(options.clone())).unwrap(); + assert_eq!(parsed.cutoff.to_bits(), options.cutoff.to_bits()); + } + + #[test] + fn requestors_are_optional() { + let mut json: JsonValue = example().into(); + json.remove("requestors"); + let parsed = PairListOptions::try_from(json).unwrap(); + assert!(parsed.requestors.is_empty()); + } + + #[test] + fn rejects_invalid_json() { + // each case corrupts exactly one field of an otherwise valid object + let with_cutoff = |value: f64| { + let mut json = JsonValue::from(example()); + json["cutoff"] = format!("{:#x}", value.to_bits()).into(); + json + }; + + let mut wrong_type = JsonValue::from(example()); + wrong_type["type"] = "something-else".into(); + + let mut missing_cutoff = JsonValue::from(example()); + missing_cutoff.remove("cutoff"); + + let mut non_hex_cutoff = JsonValue::from(example()); + non_hex_cutoff["cutoff"] = "not-hex".into(); + + let mut non_boolean_flag = JsonValue::from(example()); + non_boolean_flag["full_list"] = "yes".into(); + + let mut non_array_requestors = JsonValue::from(example()); + non_array_requestors["requestors"] = "nl-1".into(); + + let mut non_string_requestor = JsonValue::from(example()); + non_string_requestor["requestors"] = json::array![ "nl-1", 42 ]; + + let cases = [ + (JsonValue::from("not an object"), + "invalid JSON data for PairListOptions, expected an object"), + (wrong_type, + "'type' in JSON for PairListOptions must be 
'metatomic_pair_options'"), + (missing_cutoff, + "'cutoff' in JSON for PairListOptions must be a hex-encoded string"), + (non_hex_cutoff, + "'cutoff' in JSON for PairListOptions must be a hex-encoded string"), + (with_cutoff(f64::NAN), + "'cutoff' in JSON for PairListOptions must be a finite positive number"), + (with_cutoff(f64::INFINITY), + "'cutoff' in JSON for PairListOptions must be a finite positive number"), + (with_cutoff(-1.0), + "'cutoff' in JSON for PairListOptions must be a finite positive number"), + (with_cutoff(0.0), + "'cutoff' in JSON for PairListOptions must be a finite positive number"), + (non_boolean_flag, + "'full_list' in JSON for PairListOptions must be a boolean"), + (non_array_requestors, + "'requestors' in JSON for PairListOptions must be an array"), + (non_string_requestor, + "'requestors' in JSON for PairListOptions must be an array of strings"), + ]; + + for (json, expected) in cases { + let error = PairListOptions::try_from(json).expect_err("expected an error"); + assert_eq!(error.to_string(), expected); + } + } + + #[test] + fn requestors_skip_empty_and_duplicates() { + let mut json: JsonValue = example().into(); + json["requestors"] = json::array![ "a", "", "b", "a" ]; + + let parsed = PairListOptions::try_from(json).unwrap(); + assert_eq!(parsed.requestors, vec!["a".to_string(), "b".to_string()]); + } + } +} From e02fe0b744dfa9d87b9a9b614180d6200d0567ad Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Thu, 28 May 2026 22:50:39 +0200 Subject: [PATCH 07/20] Implement JSON serialization for `Quantity` Co-Authored-By: Guillaume Fraux --- docs/src/core/reference/json-formats.rst | 47 +++- metatomic-core/src/lib.rs | 7 +- metatomic-core/src/metadata.rs | 33 --- metatomic-core/src/quantities.rs | 281 +++++++++++++++++++++++ 4 files changed, 329 insertions(+), 39 deletions(-) create mode 100644 metatomic-core/src/quantities.rs diff --git a/docs/src/core/reference/json-formats.rst b/docs/src/core/reference/json-formats.rst index 
d1589fa2d..d12ff6da5 100644 --- a/docs/src/core/reference/json-formats.rst +++ b/docs/src/core/reference/json-formats.rst @@ -11,10 +11,9 @@ language can produce and consume them. Pair list options ----------------- -Options describing a requested pair list (also known as a neighbor list). This -is the JSON representation of ``PairListOptions``, used for example by -:c:func:`mta_system_set_pairs`, :c:func:`mta_system_get_pairs` and -:c:func:`mta_system_pairs_options`. +The JSON representation of a requested pair list (also known as a neighbor +list). This is used for example by :c:func:`mta_system_add_pairs`, +:c:func:`mta_system_get_pairs` and :c:func:`mta_system_known_pairs`. .. code-block:: json @@ -50,3 +49,43 @@ is the JSON representation of ``PairListOptions``, used for example by ``requestors`` Optional array of strings identifying who requested this pair list. May be omitted, in which case it is treated as an empty list. + + +Quantities +---------- + +The JSON representation of a physical quantity, used to represent custom model +inputs and outputs. This is used for example in +:c:member:`mta_model_t.requested_inputs` and +:c:member:`mta_model_t.supported_outputs`. + +.. code-block:: json + + { + "type": "metatomic_quantity", + "name": "energy", + "unit": "eV", + "sample_kind": "system", + "gradients": ["positions"], + "description": "Potential energy of the system" + } + +``type`` + Must be the string ``"metatomic_quantity"``. + +``name`` + Name of the quantity, this can be a standard name from the list of + :ref:`standard-quantities`, or a custom name of the form + ``<namespace>::<name>[/<variant>]`` + +``unit`` + Unit of the quantity. + +``gradients`` + Array of strings identifying the gradients for this quantity. This can be an + empty array if the quantity has no gradients. Valid values for the gradients + are ``"positions"`` and ``"strain"``. + +``sample_kind`` + Kind of sample for which this quantity is defined. 
This can be one of the + following: ``"atom"``, ``"system"`` or ``"atom_pair"``. diff --git a/metatomic-core/src/lib.rs b/metatomic-core/src/lib.rs index 894e70e3a..c778867bb 100644 --- a/metatomic-core/src/lib.rs +++ b/metatomic-core/src/lib.rs @@ -17,7 +17,10 @@ pub mod c_api; mod metadata; -pub use self::metadata::{ModelMetadata, Quantity, PairListOptions}; +pub use self::metadata::{ModelMetadata, PairListOptions}; + +mod quantities; +pub use self::quantities::Quantity; mod system; pub use self::system::System; @@ -31,7 +34,7 @@ pub use self::plugin::{Plugin, load_plugin, load_model}; mod units; pub use self::units::unit_conversion_factor; -/// TODO +/// Error type used throughout `metatomic-core`. #[derive(Debug)] pub enum Error { /// Error while serializing data to or deserializing data from JSON diff --git a/metatomic-core/src/metadata.rs b/metatomic-core/src/metadata.rs index 913edcdfc..f28732919 100644 --- a/metatomic-core/src/metadata.rs +++ b/metatomic-core/src/metadata.rs @@ -155,39 +155,6 @@ impl TryFrom for ModelMetadata { } } -// ========================================================================== // -// ========================================================================== // -// ========================================================================== // - -/// TODO, previously `ModelOutput` -#[derive(Debug)] -pub struct Quantity { - pub name: String, - // TODO -} - -// TODO: -// { -// "type": "metatomic_quantity", -// "name": "...", -// "unit": "...", -// "gradients": ["...", "..."], -// "sample_kind": "atom" | "system" | "atom-pair", -// }, -impl From for JsonValue { - fn from(value: Quantity) -> Self { - todo!() - } -} - -impl TryFrom for Quantity { - type Error = Error; - - fn try_from(value: JsonValue) -> Result { - todo!() - } -} - #[cfg(test)] mod tests { diff --git a/metatomic-core/src/quantities.rs b/metatomic-core/src/quantities.rs new file mode 100644 index 000000000..072cdc41b --- /dev/null +++ 
b/metatomic-core/src/quantities.rs @@ -0,0 +1,281 @@ +use json::JsonValue; + +use crate::Error; + + +/// Different kind of samples a quantity can be associated with +#[derive(Debug, Clone, PartialEq)] +pub enum SampleKind { + /// The quantity is defined for each atom (e.g. atomic energy, charge, ...) + Atom, + /// The quantity is defined for the whole system (e.g. total energy, ...) + System, + /// The quantity is defined for each pair of atoms (e.g. hamiltonian elements, ...) + AtomPair, +} + +impl From for JsonValue { + fn from(value: SampleKind) -> Self { + let s = match value { + SampleKind::Atom => "atom", + SampleKind::System => "system", + SampleKind::AtomPair => "atom_pair", + }; + JsonValue::from(s) + } +} + +impl<'a> TryFrom<&'a JsonValue> for SampleKind { + type Error = Error; + + fn try_from(value: &'a JsonValue) -> Result { + let s = value.as_str().ok_or_else(|| Error::Serialization( + "'sample_kind' in JSON for Quantity must be a string".into() + ))?; + match s { + "atom" => Ok(SampleKind::Atom), + "system" => Ok(SampleKind::System), + "atom_pair" => Ok(SampleKind::AtomPair), + _ => Err(Error::Serialization(format!( + "'sample_kind' in JSON for Quantity must be 'atom', 'system' or 'atom_pair', got '{}'", s + ))), + } + } +} + +/// Different gradients that a quantity can have +#[derive(Debug, Clone, PartialEq)] +pub enum Gradients { + /// Gradients with respect to atomic positions + Positions, + /// Gradients with respect to the strain (typically used for stress) + Strain, +} + +impl From for JsonValue { + fn from(value: Gradients) -> Self { + let s = match value { + Gradients::Positions => "positions", + Gradients::Strain => "strain", + }; + JsonValue::from(s) + } +} + +impl<'a> TryFrom<&'a JsonValue> for Gradients { + type Error = Error; + + fn try_from(value: &'a JsonValue) -> Result { + let s = value.as_str().ok_or_else(|| Error::Serialization( + "'gradients' in JSON for Quantity must be a string".into() + ))?; + match s { + "positions" => 
Ok(Gradients::Positions), + "strain" => Ok(Gradients::Strain), + _ => Err(Error::Serialization(format!( + "'gradients' in JSON for Quantity must be 'positions' or 'strain', got '{}'", s + ))), + } + } +} + +/// A quantity that a model can use as input or output +#[derive(Debug, Clone)] +pub struct Quantity { + /// Name of the quantity, this can be a standard name from + /// , or + /// a custom name of the form `::[/]` + pub name: String, + /// Unit of the quantity + pub unit: String, + /// Description of the quantity, used to provide more details about the + /// quantity, especially when a model defines multiple variants of the same + /// quantity. + pub description: Option, + /// List of explicit gradients for this quantity, stored in the + /// corresponding `TensorMap` + pub gradients: Vec, + /// The kind of samples this quantity is associated with (e.g. per-atom, + /// per-system, ...) + pub sample_kind: SampleKind, +} + +impl From for JsonValue { + fn from(value: Quantity) -> Self { + let mut result = JsonValue::new_object(); + result["type"] = "metatomic_quantity".into(); + result["name"] = value.name.into(); + result["unit"] = value.unit.into(); + if let Some(description) = value.description { + result["description"] = description.into(); + } + result["gradients"] = value.gradients.into(); + result["sample_kind"] = value.sample_kind.into(); + return result; + } +} + + +impl TryFrom for Quantity { + type Error = Error; + + fn try_from(value: JsonValue) -> Result { + if !value.is_object() { + return Err(Error::Serialization( + "invalid JSON data for Quantity, expected an object".into() + )); + } + + if value["type"].as_str() != Some("metatomic_quantity") { + return Err(Error::Serialization( + "'type' in JSON for Quantity must be 'metatomic_quantity'".into() + )); + } + + let name = value["name"].as_str().ok_or_else(|| Error::Serialization( + "'name' in JSON for Quantity must be a string".into() + ))?; + + let unit = value["unit"].as_str().ok_or_else(|| 
Error::Serialization( + "'unit' in JSON for Quantity must be a string".into() + ))?; + + let mut description = value["description"].as_str().map(|s| s.to_string()); + if description == Some(String::new()) { + // Treat empty description as None + description = None; + } + + let gradients = &value["gradients"]; + if !gradients.is_array() { + return Err(Error::Serialization( + "'gradients' in JSON for Quantity must be an array".into() + )); + } + let gradients = gradients.members() + .map(Gradients::try_from) + .collect::, _>>()?; + + let sample_kind = SampleKind::try_from(&value["sample_kind"])?; + + Ok(Quantity { + name: name.to_string(), + unit: unit.to_string(), + description, + gradients, + sample_kind, + }) + } +} + + +#[cfg(test)] +mod tests { + use super::*; + + fn example() -> Quantity { + Quantity { + name: "energy".into(), + unit: "eV".into(), + description: Some("total energy of the system".into()), + gradients: vec![Gradients::Positions], + sample_kind: SampleKind::Atom, + } + } + + #[test] + fn roundtrip() { + let quantity = example(); + let json: JsonValue = quantity.into(); + + assert_eq!(json["type"].as_str(), Some("metatomic_quantity")); + assert_eq!(json["name"].as_str(), Some("energy")); + assert_eq!(json["unit"].as_str(), Some("eV")); + assert_eq!(json["gradients"][0].as_str(), Some("positions")); + assert_eq!(json["sample_kind"].as_str(), Some("atom")); + + let parsed = Quantity::try_from(json).unwrap(); + assert_eq!(parsed.name, "energy"); + assert_eq!(parsed.unit, "eV"); + assert_eq!(parsed.gradients, vec![Gradients::Positions]); + assert!(matches!(parsed.sample_kind, SampleKind::Atom)); + } + + #[test] + fn roundtrip_all_variants() { + for sample in [SampleKind::Atom, SampleKind::System, SampleKind::AtomPair] { + for grads in [ + vec![], + vec![Gradients::Positions], + vec![Gradients::Strain], + vec![Gradients::Positions, Gradients::Strain], + ] { + let quantity = Quantity { + name: "test".into(), + unit: "unit".into(), + description: 
Some("Hello".to_string()), + gradients: grads.clone(), + sample_kind: sample.clone(), + }; + let parsed = Quantity::try_from(JsonValue::from(quantity.clone())).unwrap(); + assert_eq!(parsed.name, quantity.name); + assert_eq!(parsed.unit, quantity.unit); + assert_eq!(parsed.gradients, grads); + assert_eq!(parsed.sample_kind, sample); + } + } + } + + #[test] + fn rejects_invalid_json() { + let mut wrong_type = JsonValue::from(example()); + wrong_type["type"] = "something-else".into(); + + let mut missing_name = JsonValue::from(example()); + missing_name.remove("name"); + + let mut missing_unit = JsonValue::from(example()); + missing_unit.remove("unit"); + + let mut missing_gradients = JsonValue::from(example()); + missing_gradients.remove("gradients"); + + let mut non_array_gradients = JsonValue::from(example()); + non_array_gradients["gradients"] = "positions".into(); + + let mut invalid_gradient = JsonValue::from(example()); + invalid_gradient["gradients"] = json::array!["positions", "foo"]; + + let mut missing_sample_kind = JsonValue::from(example()); + missing_sample_kind.remove("sample_kind"); + + let mut invalid_sample_kind = JsonValue::from(example()); + invalid_sample_kind["sample_kind"] = "foo".into(); + + let cases: Vec<(JsonValue, &str)> = vec![ + (JsonValue::from("not an object"), + "invalid JSON data for Quantity, expected an object"), + (wrong_type, + "'type' in JSON for Quantity must be 'metatomic_quantity'"), + (missing_name, + "'name' in JSON for Quantity must be a string"), + (missing_unit, + "'unit' in JSON for Quantity must be a string"), + (missing_gradients, + "'gradients' in JSON for Quantity must be an array"), + (non_array_gradients, + "'gradients' in JSON for Quantity must be an array"), + (invalid_gradient, + "'gradients' in JSON for Quantity must be 'positions' or 'strain', got 'foo'"), + (missing_sample_kind, + "'sample_kind' in JSON for Quantity must be a string"), + (invalid_sample_kind, + "'sample_kind' in JSON for Quantity must be 
'atom', 'system' or 'atom_pair', got 'foo'"), + ]; + + for (json, expected) in cases { + let error = Quantity::try_from(json).expect_err("expected an error"); + assert_eq!(error.to_string(), expected); + } + } +} From 36e9f33febdf5bac2d900266071e14f3219b8aba Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Fri, 29 May 2026 11:46:27 +0200 Subject: [PATCH 08/20] Validate quantities names --- metatomic-core/src/lib.rs | 7 +- metatomic-core/src/metadata.rs | 22 ++--- metatomic-core/src/quantities.rs | 164 +++++++++++++++++++++++++++++-- 3 files changed, 171 insertions(+), 22 deletions(-) diff --git a/metatomic-core/src/lib.rs b/metatomic-core/src/lib.rs index c778867bb..bf394f738 100644 --- a/metatomic-core/src/lib.rs +++ b/metatomic-core/src/lib.rs @@ -39,12 +39,15 @@ pub use self::units::unit_conversion_factor; pub enum Error { /// Error while serializing data to or deserializing data from JSON Serialization(String), + /// Invalid parameters passed to a function + InvalidParameters(String), } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Error::Serialization(message) => write!(f, "{}", message), + Error::Serialization(message) => write!(f, "serialization error: {}", message), + Error::InvalidParameters(message) => write!(f, "invalid parameter: {}", message), } } } @@ -52,7 +55,7 @@ impl std::fmt::Display for Error { impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - Error::Serialization(_) => None, + Error::Serialization(_) | Error::InvalidParameters(_) => None, } } diff --git a/metatomic-core/src/metadata.rs b/metatomic-core/src/metadata.rs index f28732919..26101c0fa 100644 --- a/metatomic-core/src/metadata.rs +++ b/metatomic-core/src/metadata.rs @@ -232,27 +232,27 @@ mod tests { let cases = [ (JsonValue::from("not an object"), - "invalid JSON data for PairListOptions, expected an object"), + "serialization error: invalid JSON 
data for PairListOptions, expected an object"), (wrong_type, - "'type' in JSON for PairListOptions must be 'metatomic_pair_options'"), + "serialization error: 'type' in JSON for PairListOptions must be 'metatomic_pair_options'"), (missing_cutoff, - "'cutoff' in JSON for PairListOptions must be a hex-encoded string"), + "serialization error: 'cutoff' in JSON for PairListOptions must be a hex-encoded string"), (non_hex_cutoff, - "'cutoff' in JSON for PairListOptions must be a hex-encoded string"), + "serialization error: 'cutoff' in JSON for PairListOptions must be a hex-encoded string"), (with_cutoff(f64::NAN), - "'cutoff' in JSON for PairListOptions must be a finite positive number"), + "serialization error: 'cutoff' in JSON for PairListOptions must be a finite positive number"), (with_cutoff(f64::INFINITY), - "'cutoff' in JSON for PairListOptions must be a finite positive number"), + "serialization error: 'cutoff' in JSON for PairListOptions must be a finite positive number"), (with_cutoff(-1.0), - "'cutoff' in JSON for PairListOptions must be a finite positive number"), + "serialization error: 'cutoff' in JSON for PairListOptions must be a finite positive number"), (with_cutoff(0.0), - "'cutoff' in JSON for PairListOptions must be a finite positive number"), + "serialization error: 'cutoff' in JSON for PairListOptions must be a finite positive number"), (non_boolean_flag, - "'full_list' in JSON for PairListOptions must be a boolean"), + "serialization error: 'full_list' in JSON for PairListOptions must be a boolean"), (non_array_requestors, - "'requestors' in JSON for PairListOptions must be an array"), + "serialization error: 'requestors' in JSON for PairListOptions must be an array"), (non_string_requestor, - "'requestors' in JSON for PairListOptions must be an array of strings"), + "serialization error: 'requestors' in JSON for PairListOptions must be an array of strings"), ]; for (json, expected) in cases { diff --git a/metatomic-core/src/quantities.rs 
b/metatomic-core/src/quantities.rs index 072cdc41b..c8d4b45d8 100644 --- a/metatomic-core/src/quantities.rs +++ b/metatomic-core/src/quantities.rs @@ -2,6 +2,83 @@ use json::JsonValue; use crate::Error; +static STANDARD_QUANTITIES: &[&str] = &[ + "charge", + "energy_ensemble", + "energy_uncertainty", + "energy", + "feature", + "heat_flux", + "mass", + "momentum", + "non_conservative_force", + "non_conservative_stress", + "position", + "spin_multiplicity", + "velocity", +]; + +fn is_valid_identifier(s: &str) -> bool { + if s.is_empty() { + return false; + } + let first = s.chars().next().unwrap(); + if !(first.is_ascii_alphabetic() || first == '_') { + return false; + } + s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') +} + +/// Validate a quantity name. +/// +/// The name can be either a standard name or a custom name with the form +/// `<namespace>::<name>`, where the namespace can itself contain `::` to define +/// sub-namespaces. +/// +/// Both standard and custom names can also define a variant with the form +/// `<name>/<variant>` or `<namespace>::<name>/<variant>`. +/// +/// All components (namespace, name, variant) must be non-empty if they are +/// present, and must be valid identifiers (alphanumeric + underscore, not +/// starting with a digit). 
+fn validate_quantity_name(name: &str) -> Result<(), Error> { + if STANDARD_QUANTITIES.contains(&name) { + return Ok(()); + } + + let (main_part, variant) = if let Some(pos) = name.find('/') { + (&name[..pos], Some(&name[pos + 1..])) + } else { + (name, None) + }; + + if main_part.is_empty() { + return Err(Error::InvalidParameters(format!( + "quantity name cannot be empty in '{}'", name + ))); + } + + if let Some(variant) = variant { + if !is_valid_identifier(variant) { + return Err(Error::InvalidParameters(format!( + "invalid quantity variant '{}' in '{}': must be a valid identifier (alphanumeric or underscore, not starting with a digit)", + variant, name + ))); + } + } + + for component in main_part.split("::") { + if !is_valid_identifier(component) { + return Err(Error::InvalidParameters(format!( + "invalid quantity name component '{}' in '{}': must be a valid identifier (alphanumeric or underscore, not starting with a digit)", + component, name + ))); + } + } + + Ok(()) +} + /// Different kind of samples a quantity can be associated with #[derive(Debug, Clone, PartialEq)] @@ -135,6 +212,7 @@ impl TryFrom for Quantity { let name = value["name"].as_str().ok_or_else(|| Error::Serialization( "'name' in JSON for Quantity must be a string".into() ))?; + validate_quantity_name(name)?; let unit = value["unit"].as_str().ok_or_else(|| Error::Serialization( "'unit' in JSON for Quantity must be a string".into() @@ -254,23 +332,23 @@ mod tests { let cases: Vec<(JsonValue, &str)> = vec![ (JsonValue::from("not an object"), - "invalid JSON data for Quantity, expected an object"), + "serialization error: invalid JSON data for Quantity, expected an object"), (wrong_type, - "'type' in JSON for Quantity must be 'metatomic_quantity'"), + "serialization error: 'type' in JSON for Quantity must be 'metatomic_quantity'"), (missing_name, - "'name' in JSON for Quantity must be a string"), + "serialization error: 'name' in JSON for Quantity must be a string"), (missing_unit, - "'unit' in 
JSON for Quantity must be a string"), + "serialization error: 'unit' in JSON for Quantity must be a string"), (missing_gradients, - "'gradients' in JSON for Quantity must be an array"), + "serialization error: 'gradients' in JSON for Quantity must be an array"), (non_array_gradients, - "'gradients' in JSON for Quantity must be an array"), + "serialization error: 'gradients' in JSON for Quantity must be an array"), (invalid_gradient, - "'gradients' in JSON for Quantity must be 'positions' or 'strain', got 'foo'"), + "serialization error: 'gradients' in JSON for Quantity must be 'positions' or 'strain', got 'foo'"), (missing_sample_kind, - "'sample_kind' in JSON for Quantity must be a string"), + "serialization error: 'sample_kind' in JSON for Quantity must be a string"), (invalid_sample_kind, - "'sample_kind' in JSON for Quantity must be 'atom', 'system' or 'atom_pair', got 'foo'"), + "serialization error: 'sample_kind' in JSON for Quantity must be 'atom', 'system' or 'atom_pair', got 'foo'"), ]; for (json, expected) in cases { @@ -278,4 +356,72 @@ mod tests { assert_eq!(error.to_string(), expected); } } + + #[test] + fn validate_names() { + for name in STANDARD_QUANTITIES { + assert!(validate_quantity_name(name).is_ok(), "expected '{}' to be valid", name); + } + + let custom = [ + "my_model::energy", + "org::my_model::custom_qty", + "ns1::ns2::ns3::energy", + "custom_name", + "some_ns::name_with_underscores", + "_underscore_start", + "_ns::_name", + ]; + for name in custom { + assert!(validate_quantity_name(name).is_ok(), "expected '{}' to be valid", name); + } + + let variants = [ + "energy/ensemble", + "my_ns::energy/raw", + "ns1::ns2::energy/some_variant", + ]; + for name in variants { + assert!(validate_quantity_name(name).is_ok(), "expected '{}' to be valid", name); + } + + let error = validate_quantity_name("").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: quantity name cannot be empty in ''"); + + let error = 
validate_quantity_name("/variant").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: quantity name cannot be empty in '/variant'"); + + let error = validate_quantity_name("name/").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity variant '' in 'name/': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("::energy").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity name component '' in '::energy': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("ns::").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity name component '' in 'ns::': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("ns::/variant").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity name component '' in 'ns::/variant': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("::").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity name component '' in '::': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("123name").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity name component '123name' in '123name': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("my_ns::123name").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity name component '123name' in 'my_ns::123name': must be a valid 
identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("my_ns::name/123variant").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity variant '123variant' in 'my_ns::name/123variant': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("has spaces").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity name component 'has spaces' in 'has spaces': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("my_ns::name/has spaces").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity variant 'has spaces' in 'my_ns::name/has spaces': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + + let error = validate_quantity_name("has-dash").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: invalid quantity name component 'has-dash' in 'has-dash': must be a valid identifier (alphanumeric or underscore, not starting with a digit)"); + } } From e9d95556c69bb65ef22a8f020fbd51527a6db753 Mon Sep 17 00:00:00 2001 From: Alessandro Forina Date: Thu, 28 May 2026 16:56:37 +0200 Subject: [PATCH 09/20] Implement JSON serialization for ModelMetadata Co-Authored-By: Guillaume Fraux --- docs/src/core/reference/json-formats.rst | 59 +++++ metatomic-core/src/metadata.rs | 265 +++++++++++++++++++++-- 2 files changed, 306 insertions(+), 18 deletions(-) diff --git a/docs/src/core/reference/json-formats.rst b/docs/src/core/reference/json-formats.rst index d12ff6da5..d47cb2715 100644 --- a/docs/src/core/reference/json-formats.rst +++ b/docs/src/core/reference/json-formats.rst @@ -89,3 +89,62 @@ inputs and outputs. 
This is used for example in ``sample_kind`` Kind of sample for which this quantity is defined. This can be one of the following: ``"atom"``, ``"system"`` or ``"atom_pair"``. + + +Model metadata +-------------- + +The JSON representation of a model's metadata. This is used for example by +:c:member:`mta_model_t.metadata`. + +.. code-block:: json + + { + "type": "metatomic_model_metadata", + "name": "MyCoolModel v1.2", + "authors": ["Alice Smith", "Bob Johnson <bob@example.com>"], + "description": "A machine learning potential for water", + "references": { + "model": ["doi:10.1234/model-paper"], + "architecture": ["doi:10.1234/arch-paper"], + "implementation": ["https://github.com/example/mycoolmodel"] + }, + "extra": { + "training_set": "QM9", + "cutoff": "4.5" + } + } + +``type`` + Must be the string ``"metatomic_model_metadata"``. + +``name`` + Name of the model, e.g. ``"MyCoolModel v1.2"``. + +``authors`` + Array of strings identifying the authors of the model. Each string can be a + name or a name with an email address, e.g. ``"Alice Smith"`` or + ``"Bob Johnson <bob@example.com>"``. + +``description`` + A free-text description of the model. + +``references`` + An object with three keys, each containing an array of strings (DOIs, URLs, + or any other format): + + ``model`` + References about the model as a whole, e.g. a paper describing the model + or a website presenting it. + + ``architecture`` + References about the architecture of the model, e.g. papers describing + the mathematical form of the model. + + ``implementation`` + References about the implementation of the model, e.g. a link to the + source code repository or a paper describing the software. + +``extra`` + An object with string values, providing any additional key-value pairs the + model author wishes to include. This can be used for any purpose. 
diff --git a/metatomic-core/src/metadata.rs b/metatomic-core/src/metadata.rs index 26101c0fa..f8cb769b0 100644 --- a/metatomic-core/src/metadata.rs +++ b/metatomic-core/src/metadata.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use json::JsonValue; use crate::Error; @@ -121,29 +123,97 @@ impl TryFrom for PairListOptions { // ========================================================================== // // ========================================================================== // -/// TODO -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// References for a model, divided into three categories: references about the +/// model as a whole, references about the architecture of the model, and +/// references about the implementation of the model. Each category is a list of +/// strings, which can be DOIs, URLs, or any other format the model author finds +/// useful. +#[derive(Debug, Clone)] +pub struct References { + /// The references about the model as a whole, e.g. a paper describing the + /// model or a website presenting it. + model: Vec, + /// The references about the architecture of the model, e.g. papers + /// describing the mathematical form of the model. + architecture: Vec, + /// The references about the implementation of the model, e.g. a link to + /// the source code repository or a paper describing the software. 
+ implementation: Vec, +} + +impl From for JsonValue { + fn from(value: References) -> Self { + let mut result = JsonValue::new_object(); + result["model"] = value.model.into(); + result["architecture"] = value.architecture.into(); + result["implementation"] = value.implementation.into(); + return result; + } +} + + +fn read_references(object: &JsonValue, key: &str) -> Result, Error> { + let mut references = Vec::new(); + if !object[key].is_array() { + return Err(Error::Serialization( + format!("'{}' in references of ModelMetadata must be an array", key) + )); + } + for reference in object[key].members() { + let reference = reference.as_str().ok_or_else(|| Error::Serialization( + format!("'{}' in references of ModelMetadata must be an array of strings", key) + ))?; + references.push(reference.to_string()); + } + Ok(references) +} + +impl TryFrom for References { + type Error = Error; + + fn try_from(value: JsonValue) -> Result { + if !value.is_object() { + return Err(Error::Serialization( + "invalid JSON data for references in ModelMetadata, expected an object".into() + )); + } + + let model = read_references(&value, "model")?; + let architecture = read_references(&value, "architecture")?; + let implementation = read_references(&value, "implementation")?; + + Ok(References { model, architecture, implementation }) + } +} + + +/// Metadata about a model +#[derive(Debug, Clone)] pub struct ModelMetadata { + /// The name of the model, e.g. `"MyCoolModel v1.2"` pub name: String, - // TODO + /// The authors of the model, e.g. `["Alice Smith", "Bob Johnson + /// "]` + pub authors: Vec, + /// A description of the model + pub description: String, + /// References for the model that should be cited when using it + pub references: References, + /// Any other key-value pairs the model author wants to include in the + /// metadata. This can be used for any purpose. 
+ pub extra: BTreeMap, } -// { -// "type": "metatomic_model_metadata", -// "name": "...", -// "authors": ["..."], -// "references": { -// "implementation": ["..."], -// "architecture": ["..."], -// "model": ["..."] -// }, -// "extra": { -// "key...": "value..." -// } -// }, impl From for JsonValue { fn from(value: ModelMetadata) -> Self { - todo!() + let mut result = JsonValue::new_object(); + result["type"] = "metatomic_model_metadata".into(); + result["name"] = value.name.into(); + result["authors"] = value.authors.into(); + result["description"] = value.description.into(); + result["references"] = value.references.into(); + result["extra"] = value.extra.into(); + return result; } } @@ -151,7 +221,61 @@ impl TryFrom for ModelMetadata { type Error = Error; fn try_from(value: JsonValue) -> Result { - todo!() + if !value.is_object() { + return Err(Error::Serialization( + "invalid JSON data for ModelMetadata, expected an object".into() + )); + } + + if value["type"].as_str() != Some("metatomic_model_metadata") { + return Err(Error::Serialization( + "'type' in JSON for ModelMetadata must be 'metatomic_model_metadata'".into() + )); + } + + let name = value["name"].as_str().ok_or_else(|| Error::Serialization( + "'name' in JSON for ModelMetadata must be a string".into() + ))?; + + if !value["authors"].is_array() { + return Err(Error::Serialization( + "'authors' in JSON for ModelMetadata must be an array".into() + )); + } + + let authors = value["authors"].members().map(|author| { + author.as_str().ok_or_else(|| Error::Serialization( + "'authors' in JSON for ModelMetadata must be an array of strings".into() + )).map(|s| s.to_string()) + }).collect::, Error>>()?; + + let description = value["description"].as_str().ok_or_else(|| Error::Serialization( + "'description' in JSON for ModelMetadata must be a string".into() + ))?.to_string(); + + let references = References::try_from(value["references"].clone())?; + + if !value["extra"].is_object() { + return 
Err(Error::Serialization( + "'extra' in JSON for ModelMetadata must be an object".into() + )); + } + + let mut extra = BTreeMap::new(); + for (key, value) in value["extra"].entries() { + let value = value.as_str().ok_or_else(|| Error::Serialization( + "'extra' in JSON for ModelMetadata must be an object with string values".into() + ))?; + extra.insert(key.to_string(), value.to_string()); + } + + Ok(ModelMetadata { + name: name.to_string(), + authors: authors, + description: description, + references: references, + extra: extra, + }) } } @@ -270,4 +394,109 @@ mod tests { assert_eq!(parsed.requestors, vec!["a".to_string(), "b".to_string()]); } } + + mod model_metadata { + use super::super::*; + + fn example() -> ModelMetadata { + ModelMetadata { + name: "test-model".into(), + authors: vec!["Alice".into(), "Bob ".into()], + description: "A test model".into(), + references: References { + model: vec!["doi:10.1234/test".into()], + architecture: vec!["doi:10.1234/arch".into()], + implementation: vec!["https://github.com/test".into()], + }, + extra: BTreeMap::from([ + ("key1".into(), "value1".into()), + ("key2".into(), "value2".into()), + ]), + } + } + + #[test] + fn roundtrip() { + let metadata = example(); + let json: JsonValue = metadata.clone().into(); + + assert_eq!(json["type"].as_str(), Some("metatomic_model_metadata")); + assert_eq!(json["name"].as_str(), Some("test-model")); + assert_eq!(json["authors"][0].as_str(), Some("Alice")); + assert_eq!(json["authors"][1].as_str(), Some("Bob ")); + assert_eq!(json["description"].as_str(), Some("A test model")); + assert_eq!(json["references"]["model"][0].as_str(), Some("doi:10.1234/test")); + assert_eq!(json["references"]["architecture"][0].as_str(), Some("doi:10.1234/arch")); + assert_eq!(json["references"]["implementation"][0].as_str(), Some("https://github.com/test")); + assert_eq!(json["extra"]["key1"].as_str(), Some("value1")); + assert_eq!(json["extra"]["key2"].as_str(), Some("value2")); + + let parsed = 
ModelMetadata::try_from(json).unwrap(); + assert_eq!(parsed.name, metadata.name); + assert_eq!(parsed.authors, metadata.authors); + assert_eq!(parsed.description, metadata.description); + assert_eq!(parsed.references.model, metadata.references.model); + assert_eq!(parsed.references.architecture, metadata.references.architecture); + assert_eq!(parsed.references.implementation, metadata.references.implementation); + assert_eq!(parsed.extra, metadata.extra); + } + + #[test] + fn rejects_invalid_json() { + let mut wrong_type = JsonValue::from(example()); + wrong_type["type"] = "something-else".into(); + + let mut missing_name = JsonValue::from(example()); + missing_name.remove("name"); + + let mut non_string_name = JsonValue::from(example()); + non_string_name["name"] = 42.into(); + + let mut non_array_authors = JsonValue::from(example()); + non_array_authors["authors"] = "Alice".into(); + + let mut non_string_author = JsonValue::from(example()); + non_string_author["authors"] = json::array!["Alice", 42]; + + let mut missing_description = JsonValue::from(example()); + missing_description.remove("description"); + + let mut non_object_extra = JsonValue::from(example()); + non_object_extra["extra"] = "not-an-object".into(); + + let mut non_string_extra_value = JsonValue::from(example()); + non_string_extra_value["extra"] = json::object!{ "key" => 42 }; + + let mut non_object_references = JsonValue::from(example()); + non_object_references["references"] = "not-an-object".into(); + + let cases = [ + (JsonValue::from("not an object"), + "serialization error: invalid JSON data for ModelMetadata, expected an object"), + (wrong_type, + "serialization error: 'type' in JSON for ModelMetadata must be 'metatomic_model_metadata'"), + (missing_name, + "serialization error: 'name' in JSON for ModelMetadata must be a string"), + (non_string_name, + "serialization error: 'name' in JSON for ModelMetadata must be a string"), + (non_array_authors, + "serialization error: 'authors' in JSON 
for ModelMetadata must be an array"), + (non_string_author, + "serialization error: 'authors' in JSON for ModelMetadata must be an array of strings"), + (missing_description, + "serialization error: 'description' in JSON for ModelMetadata must be a string"), + (non_object_extra, + "serialization error: 'extra' in JSON for ModelMetadata must be an object"), + (non_string_extra_value, + "serialization error: 'extra' in JSON for ModelMetadata must be an object with string values"), + (non_object_references, + "serialization error: invalid JSON data for references in ModelMetadata, expected an object"), + ]; + + for (json, expected) in cases { + let error = ModelMetadata::try_from(json).expect_err("expected an error"); + assert_eq!(error.to_string(), expected); + } + } + } } From b39f6752a66c8319124a329eb54d79fcbf41fcb9 Mon Sep 17 00:00:00 2001 From: Rocco Meli Date: Thu, 28 May 2026 17:31:09 +0200 Subject: [PATCH 10/20] Add error handling based on metatensor --- metatomic-core/include/metatomic.h | 35 +++++- metatomic-core/src/c_api/mod.rs | 1 + metatomic-core/src/c_api/status.rs | 185 +++++++++++++++++++++++++++-- metatomic-core/src/lib.rs | 44 +++++-- metatomic-core/src/quantities.rs | 6 +- 5 files changed, 247 insertions(+), 24 deletions(-) diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index 1e69263e1..2ffa3a2c1 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -20,9 +20,38 @@ */ #define MTA_ABI_VERSION 1 +/** + * Status type returned by all functions in the C API. + * + * The value 0 (`MTA_SUCCESS`) indicates success, while any non-zero value indicates an error. 
+ */ typedef enum mta_status_t { + /** + * Status code indicating success + */ MTA_SUCCESS = 0, - MTA_ERROR_OTHER = 255, + /** + * Status code indicating invalid function parameters + */ + MTA_INVALID_PARAMETER_ERROR = 1, + /** + * Status code indicating I/O errors + */ + MTA_IO_ERROR = 2, + /** + * Status code indicating serialization/deserialization errors + */ + MTA_SERIALIZATION_ERROR = 3, + /** + * Status code indicating errors that come from callbacks provided by the user. + * The error message and arbitrary data can be stored using `mta_set_last_error`, + * and retrieved using `mta_last_error`. + */ + MTA_CALLBACK_ERROR = 254, + /** + * Status code used when there is an internal error + */ + MTA_INTERNAL_ERROR = 255, } mta_status_t; /** @@ -112,12 +141,12 @@ extern "C" { #endif // __cplusplus /** - * TODO + * Get last error message that was created on the current thread. */ enum mta_status_t mta_last_error(const char **message, const char **origin, void **data); /** - * TODO + * Set last error message for the current thread. */ enum mta_status_t mta_set_last_error(const char *message, const char *origin, diff --git a/metatomic-core/src/c_api/mod.rs b/metatomic-core/src/c_api/mod.rs index cf6c6176d..c1b29c419 100644 --- a/metatomic-core/src/c_api/mod.rs +++ b/metatomic-core/src/c_api/mod.rs @@ -1,3 +1,4 @@ +#[macro_use] mod status; pub use self::status::mta_status_t; diff --git a/metatomic-core/src/c_api/status.rs b/metatomic-core/src/c_api/status.rs index 0c48707cc..8aef16c11 100644 --- a/metatomic-core/src/c_api/status.rs +++ b/metatomic-core/src/c_api/status.rs @@ -1,36 +1,170 @@ -use std::ffi::{c_char, c_void}; +use std::cell::RefCell; +use std::ffi::{c_char, c_void, CStr, CString}; +use std::panic::UnwindSafe; use crate::Error; +#[derive(Debug)] +struct LastError { + message: CString, + origin: CString, + custom_data: *mut c_void, + custom_data_deleter: Option, +} + +// Save the last error message in thread local storage. +thread_local! 
{ + pub static LAST_ERROR: RefCell = RefCell::new(LastError { + message: CString::new("").expect("invalid C string"), + origin: CString::new("").expect("invalid C string"), + custom_data: std::ptr::null_mut(), + custom_data_deleter: None, + }); +} -// TODO +/// Status type returned by all functions in the C API. +/// +/// The value 0 (`MTA_SUCCESS`) indicates success, while any non-zero value indicates an error. #[allow(non_camel_case_types)] #[repr(C)] #[derive(PartialEq, Eq, Debug)] pub enum mta_status_t { + /// Status code indicating success MTA_SUCCESS = 0, - // ... - MTA_ERROR_OTHER = 255, + /// Status code indicating invalid function parameters + MTA_INVALID_PARAMETER_ERROR = 1, + /// Status code indicating I/O errors + MTA_IO_ERROR = 2, + /// Status code indicating serialization/deserialization errors + MTA_SERIALIZATION_ERROR = 3, + /// Status code indicating errors that come from callbacks provided by the user. + /// The error message and arbitrary data can be stored using `mta_set_last_error`, + /// and retrieved using `mta_last_error`. + MTA_CALLBACK_ERROR = 254, + /// Status code used when there is an internal error + MTA_INTERNAL_ERROR = 255, } +/// `std::panic::catch_unwind` that automatically transforms +/// the error into `mta_status_t`. +pub fn catch_unwind(function: F) -> mta_status_t +where + F: FnOnce() -> Result<(), Error> + UnwindSafe, +{ + match std::panic::catch_unwind(function) { + Ok(Ok(())) => mta_status_t::MTA_SUCCESS, + Ok(Err(error)) => error.into(), + Err(error) => Error::from(error).into(), + } +} + +/// Check that pointers (used as C API function parameters) are not null. +#[macro_export] +#[doc(hidden)] +macro_rules! check_pointers_non_null { + ($pointer: ident) => { + if $pointer.is_null() { + return Err($crate::Error::InvalidParameter( + format!( + "got invalid NULL pointer for {} at {}:{}", + stringify!($pointer), file!(), line!() + ) + )); + } + }; + ($($pointer: ident),* $(,)?) 
=> { + $(check_pointers_non_null!($pointer);)* + } +} impl From for mta_status_t { - fn from(err: Error) -> Self { - todo!() + fn from(error: Error) -> mta_status_t { + if let Error::CallbackError = error { + // If the error is already a CallbackError, we can directly return the corresponding status code. + return mta_status_t::MTA_CALLBACK_ERROR; + } + + LAST_ERROR.with(|last_error| { + let mut last_error = last_error.borrow_mut(); + + // If there is a custom data deleter, + // use it to free the custom data before overwriting it with the new error. + if let Some(deleter) = last_error.custom_data_deleter { + unsafe { + deleter(last_error.custom_data); + } + } + + *last_error = LastError { + message: CString::new(format!("{}", error)) + .expect("error message contains a null byte"), + origin: CString::new("metatensor-core").expect("invalid C string"), + custom_data: std::ptr::null_mut(), + custom_data_deleter: None, + }; + }); + + match error { + Error::InvalidParameter(_) => mta_status_t::MTA_INVALID_PARAMETER_ERROR, + Error::Io(_) => mta_status_t::MTA_IO_ERROR, + Error::Serialization(_) => mta_status_t::MTA_SERIALIZATION_ERROR, + Error::CallbackError => unreachable!(), + Error::Internal(_) => mta_status_t::MTA_INTERNAL_ERROR, + } } } -/// TODO +/// Get last error message that was created on the current thread. 
#[no_mangle] pub unsafe extern "C" fn mta_last_error( message: *mut *const c_char, origin: *mut *const c_char, data: *mut *mut c_void, ) -> mta_status_t { - todo!() + let status = std::panic::catch_unwind(|| { + LAST_ERROR.with(|last_error| { + let last_error = last_error.borrow(); + if !message.is_null() { + *message = last_error.message.as_ptr(); + } + if !origin.is_null() { + *origin = last_error.origin.as_ptr(); + } + if !data.is_null() { + *data = last_error.custom_data; + } + }); + }); + + match status { + Ok(()) => mta_status_t::MTA_SUCCESS, + Err(error) => { + let last_error_debug = + LAST_ERROR.with(|last_error| format!("{:?}", last_error.borrow())); + if error.is::() { + eprintln!( + "panic in mta_last_error: {:?}, last_error: {:?}", + error.downcast_ref::(), + last_error_debug + ); + } else if error.is::<&str>() { + eprintln!( + "panic in mta_last_error: {:?}, last_error: {:?}", + error.downcast_ref::<&str>(), + last_error_debug + ); + } else { + eprintln!( + "panic in mta_last_error: unknown panic error type. last_error: {:?}", + last_error_debug + ); + } + mta_status_t::MTA_INTERNAL_ERROR + } + } } -/// TODO +/// Set last error message for the current thread. #[no_mangle] pub unsafe extern "C" fn mta_set_last_error( message: *const c_char, @@ -38,5 +172,36 @@ pub unsafe extern "C" fn mta_set_last_error( data: *mut c_void, data_deleter: Option, ) -> mta_status_t { - todo!() + catch_unwind(move || { + let message = if message.is_null() { + CString::new("").expect("invalid C string") + } else { + CString::from(CStr::from_ptr(message)) + }; + + let origin = if origin.is_null() { + CString::new("").expect("invalid C string") + } else { + CString::from(CStr::from_ptr(origin)) + }; + + LAST_ERROR.with(|last_error| { + let mut last_error = last_error.borrow_mut(); + + // Call custom data deleter before overwriting the custom data with the new one, to avoid memory leaks. 
+ if let Some(deleter) = last_error.custom_data_deleter { + unsafe { + deleter(last_error.custom_data); + } + } + + *last_error = LastError { + message: message, + origin: origin, + custom_data: data, + custom_data_deleter: data_deleter, + }; + }); + Ok(()) + }) } diff --git a/metatomic-core/src/lib.rs b/metatomic-core/src/lib.rs index bf394f738..09ec7a962 100644 --- a/metatomic-core/src/lib.rs +++ b/metatomic-core/src/lib.rs @@ -8,11 +8,9 @@ #![allow(clippy::similar_names, clippy::borrow_as_ptr, clippy::uninlined_format_args)] #![allow(clippy::let_underscore_untyped, clippy::manual_let_else, clippy::empty_line_after_doc_comments)] - -// To be removed lated +// To be removed later #![allow(unused_variables, dead_code, clippy::needless_pass_by_value)] - #[doc(hidden)] pub mod c_api; @@ -34,20 +32,31 @@ pub use self::plugin::{Plugin, load_plugin, load_model}; mod units; pub use self::units::unit_conversion_factor; -/// Error type used throughout `metatomic-core`. +/// The possible sources of error in metatomic #[derive(Debug)] pub enum Error { /// Error while serializing data to or deserializing data from JSON Serialization(String), /// Invalid parameters passed to a function - InvalidParameters(String), + InvalidParameter(String), + /// I/O error + Io(std::io::Error), + /// Error coming from an external function used as a callback + CallbackError, + /// Any other internal error, usually these are internal bugs. 
+ Internal(String), } impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Error::Serialization(message) => write!(f, "serialization error: {}", message), - Error::InvalidParameters(message) => write!(f, "invalid parameter: {}", message), + Error::Serialization(e) => write!(f, "serialization error: {}", e), + Error::InvalidParameter(e) => write!(f, "invalid parameter: {}", e), + Error::Io(e) => write!(f, "io error: {}", e), + Error::CallbackError => write!(f, "callback error"), + Error::Internal(e) => write!(f, + "internal metatomic error (this is likely a bug, please report it): {}", e + ), } } } @@ -55,7 +64,11 @@ impl std::fmt::Display for Error { impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - Error::Serialization(_) | Error::InvalidParameters(_) => None, + Error::InvalidParameter(_) + | Error::Serialization(_) + | Error::Internal(_) + | Error::CallbackError => None, + Error::Io(e) => Some(e), } } @@ -63,3 +76,18 @@ impl std::error::Error for Error { self.source() } } + +// Box is the error type in std::panic::catch_unwind +impl From> for Error { + fn from(error: Box) -> Error { + if error.is::() { + Error::Internal(*error.downcast::().expect("should be a String")) + } else if error.is::<&str>() { + Error::Internal((*error.downcast::<&str>().expect("should be an &str")).to_owned()) + } else if error.is::() { + return *error.downcast::().expect("it should be an Error"); + } else { + panic!("panic message is not a string, something is very wrong") + } + } +} diff --git a/metatomic-core/src/quantities.rs b/metatomic-core/src/quantities.rs index c8d4b45d8..9d1dfebbf 100644 --- a/metatomic-core/src/quantities.rs +++ b/metatomic-core/src/quantities.rs @@ -53,14 +53,14 @@ fn validate_quantity_name(name: &str) -> Result<(), Error> { }; if main_part.is_empty() { - return Err(Error::InvalidParameters(format!( + return 
Err(Error::InvalidParameter(format!( "quantity name cannot be empty in '{}'", name ))); } if let Some(variant) = variant { if !is_valid_identifier(variant) { - return Err(Error::InvalidParameters(format!( + return Err(Error::InvalidParameter(format!( "invalid quantity variant '{}' in '{}': must be a valid identifier (alphanumeric or underscore, not starting with a digit)", variant, name ))); @@ -69,7 +69,7 @@ fn validate_quantity_name(name: &str) -> Result<(), Error> { for component in main_part.split("::") { if !is_valid_identifier(component) { - return Err(Error::InvalidParameters(format!( + return Err(Error::InvalidParameter(format!( "invalid quantity name component '{}' in '{}': must be a valid identifier (alphanumeric or underscore, not starting with a digit)", component, name ))); From bfdda0a3a2fce71cef9d1aae780822c284e902fc Mon Sep 17 00:00:00 2001 From: Johannes Spies <13813209+johannes-spies@users.noreply.github.com> Date: Thu, 28 May 2026 16:18:18 +0200 Subject: [PATCH 11/20] Implement mta_string_t in the C API --- metatomic-core/build.rs | 11 ++++ metatomic-core/include/metatomic.h | 41 +++++++++---- metatomic-core/src/c_api/mod.rs | 2 +- metatomic-core/src/c_api/utils.rs | 99 ++++++++++++++++++++++++------ metatomic-core/tests/misc.cpp | 34 ++++++++++ 5 files changed, 155 insertions(+), 32 deletions(-) diff --git a/metatomic-core/build.rs b/metatomic-core/build.rs index b92cc2925..01f95d71c 100644 --- a/metatomic-core/build.rs +++ b/metatomic-core/build.rs @@ -25,6 +25,17 @@ fn main() { config.includes.push("metatensor.h".into()); config.includes.push("metatomic/version.h".into()); + config.export = cbindgen::ExportConfig { + include: vec!["mta_.*".into()], + // This is done manually below + exclude: vec!["mta_opaque_string_t".into()], + ..Default::default() + }; + config.after_includes = Some(" + +/** Heap allocated storage for mta_string_t */ +typedef struct mta_opaque_string_t mta_opaque_string_t;".into()); + let result = cbindgen::Builder::new() 
.with_crate(crate_dir) .with_config(config) diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index 2ffa3a2c1..25a5f7444 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -15,6 +15,10 @@ #include "metatensor.h" #include "metatomic/version.h" + +/** Heap allocated storage for mta_string_t */ +typedef struct mta_opaque_string_t mta_opaque_string_t; + /** * TODO */ @@ -64,20 +68,21 @@ typedef enum mta_system_data_kind { MTA_SYSTEM_DATA_PBC = 3, } mta_system_data_kind; -/** - * TODO - */ -typedef struct mta_opaque_string_t mta_opaque_string_t; - /** * TODO */ typedef struct mta_system_t mta_system_t; /** - * TODO + * A heap-allocated UTF-8 string passed across the C API boundary. + * + * This is used whenever a C API function or callback needs to return a string. + * + * A null pointer represents an absent or empty string. Use `mta_string_create` + * to allocate, `mta_string_free` to release, and `mta_string_view` to get a + * pointer to the inner C string. */ -typedef struct mta_opaque_string_t *mta_string_t; +typedef mta_opaque_string_t *mta_string_t; /** * TODO @@ -161,17 +166,31 @@ enum mta_status_t mta_set_last_error(const char *message, const char *mta_version(void); /** - * TODO + * Allocate a new `mta_string_t` by copying the null-terminated C string + * `string`. + * + * The returned string must be freed with `mta_string_free`. + * + * @param string A pointer to a null-terminated C string. Must not be null. + * @return A new `mta_string_t` containing a copy of `string`, or null if an + * error occurred. You can check the error with `mta_last_error`. */ -mta_string_t mta_string_create(const char *raw); +mta_string_t mta_string_create(const char *string); /** - * TODO + * Free a `mta_string_t` previously created by `mta_string_create`. + * + * @param string A `mta_string_t` to free. Can be null, in which case this function is a no-op. 
*/ void mta_string_free(mta_string_t string); /** - * TODO + * Return a pointer to the null-terminated string data inside `string`. + * + * The pointer is valid only for the lifetime of `string`. + * + * @param string A `mta_string_t` containing the string to view. Must not be null. + * @return A pointer to the null-terminated C string inside `string` */ const char *mta_string_view(mta_string_t string); diff --git a/metatomic-core/src/c_api/mod.rs b/metatomic-core/src/c_api/mod.rs index c1b29c419..282e230e5 100644 --- a/metatomic-core/src/c_api/mod.rs +++ b/metatomic-core/src/c_api/mod.rs @@ -1,6 +1,6 @@ #[macro_use] mod status; -pub use self::status::mta_status_t; +pub use self::status::{mta_status_t, catch_unwind}; mod utils; pub use self::utils::mta_string_t; diff --git a/metatomic-core/src/c_api/utils.rs b/metatomic-core/src/c_api/utils.rs index 350a9552d..8e22bcc47 100644 --- a/metatomic-core/src/c_api/utils.rs +++ b/metatomic-core/src/c_api/utils.rs @@ -2,7 +2,7 @@ use std::ffi::{CString, c_char}; use once_cell::sync::Lazy; -use super::mta_status_t; +use super::{mta_status_t, catch_unwind}; static VERSION: Lazy = Lazy::new(|| { @@ -18,11 +18,18 @@ pub extern "C" fn mta_version() -> *const c_char { return VERSION.as_ptr(); } -/// TODO +/// Heap-allocated backing storage for `mta_string_t`, opaque to C users. #[allow(non_camel_case_types)] -pub struct mta_opaque_string_t(CString); +#[repr(transparent)] +pub struct mta_opaque_string_t(c_char); -/// TODO +/// A heap-allocated UTF-8 string passed across the C API boundary. +/// +/// This is used whenever a C API function or callback needs to return a string. +/// +/// A null pointer represents an absent or empty string. Use `mta_string_create` +/// to allocate, `mta_string_free` to release, and `mta_string_view` to get a +/// pointer to the inner C string. 
#[allow(non_camel_case_types)] #[repr(transparent)] pub struct mta_string_t(*mut mta_opaque_string_t); @@ -41,51 +48,104 @@ impl std::fmt::Debug for mta_string_t { } impl mta_string_t { - /// TODO + /// Create a new `mta_string_t` from a Rust string. pub fn new(value: impl Into) -> Self { - let cstring = CString::new(value.into()).unwrap(); - let boxed = Box::new(mta_opaque_string_t(cstring)); - mta_string_t(Box::into_raw(boxed)) + let cstring = CString::new(value.into()).expect("string contains NULL byte"); + let ptr = CString::into_raw(cstring); + return mta_string_t(ptr.cast()); } - /// TODO + /// Create a null `mta_string_t`, representing an absent string. pub fn null() -> Self { mta_string_t(std::ptr::null_mut()) } - /// TODO + /// View the string as a `&str`. Returns `""` for a null string. pub fn as_str(&self) -> &str { if self.0.is_null() { return ""; } unsafe { - return (*(self.0)).0.to_str().expect("mta_string_t is not valid UTF8") + let cstr = std::ffi::CStr::from_ptr(self.0.cast()); + return cstr.to_str().expect("invalid UTF-8 in mta_string_t"); } } } -/// TODO +/// Allocate a new `mta_string_t` by copying the null-terminated C string +/// `string`. +/// +/// The returned string must be freed with `mta_string_free`. +/// +/// @param string A pointer to a null-terminated C string. Must not be null. +/// @return A new `mta_string_t` containing a copy of `string`, or null if an +/// error occurred. You can check the error with `mta_last_error`. 
#[no_mangle] pub unsafe extern "C" fn mta_string_create( - raw: *const c_char, + string: *const c_char, ) -> mta_string_t { - todo!() + let mut result = mta_string_t::null(); + let unwind_wrapper = std::panic::AssertUnwindSafe(&mut result); + + catch_unwind(move || { + check_pointers_non_null!(string); + + let cstr = std::ffi::CStr::from_ptr(string); + let string = CString::from(cstr); + + let ptr = CString::into_raw(string); + + let _ = &unwind_wrapper; + *unwind_wrapper.0 = mta_string_t(ptr.cast()); + Ok(()) + }); + + return result; } -/// TODO +/// Free a `mta_string_t` previously created by `mta_string_create`. +/// +/// @param string A `mta_string_t` to free. Can be null, in which case this function is a no-op. #[no_mangle] pub unsafe extern "C" fn mta_string_free(string: mta_string_t) { - todo!() + catch_unwind(|| { + if string.0.is_null() { + return Ok(()); + } + + let ptr = string.0.cast::(); + let cstring = CString::from_raw(ptr); + std::mem::drop(cstring); + + Ok(()) + }); } -/// TODO +/// Return a pointer to the null-terminated string data inside `string`. +/// +/// The pointer is valid only for the lifetime of `string`. +/// +/// @param string A `mta_string_t` containing the string to view. Must not be null. +/// @return A pointer to the null-terminated C string inside `string` #[no_mangle] pub unsafe extern "C" fn mta_string_view( string: mta_string_t, ) -> *const c_char { - todo!() -} + let mut result = std::ptr::null(); + let unwind_wrapper = std::panic::AssertUnwindSafe(&mut result); + + catch_unwind(move || { + let string = string.0; + check_pointers_non_null!(string); + + let _ = &unwind_wrapper; + *unwind_wrapper.0 = string.cast(); + Ok(()) + }); + + return result; +} /// TODO #[no_mangle] @@ -98,5 +158,4 @@ pub unsafe extern "C" fn mta_unit_conversion_factor( } - // TODO: logging & warnings? 
diff --git a/metatomic-core/tests/misc.cpp b/metatomic-core/tests/misc.cpp index bf0ce275f..188eca67c 100644 --- a/metatomic-core/tests/misc.cpp +++ b/metatomic-core/tests/misc.cpp @@ -1,3 +1,5 @@ +#include + #include #include "metatomic.h" @@ -13,3 +15,35 @@ TEST_CASE("Version macros") { // METATOMIC_VERSION should start with `x.y.z` CHECK(std::string(METATOMIC_VERSION).find(version) == 0); } + +TEST_CASE("mta_string_t") { + auto* str = mta_string_create("hello"); + REQUIRE(str != nullptr); + + const char* view = mta_string_view(str); + CHECK(std::strlen(view) == 5); + CHECK(std::string(view) == "hello"); + mta_string_free(str); + + // empty string + str = mta_string_create(""); + REQUIRE(str != nullptr); + CHECK(std::string(mta_string_view(str)) == ""); + mta_string_free(str); + + // special characters + str = mta_string_create("a\nb\tc\xFFºµ"); + REQUIRE(str != nullptr); + CHECK(std::string(mta_string_view(str)) == std::string("a\nb\tc\xFFºµ")); + mta_string_free(str); + + // long string + std::string long_str(10000, 'x'); + str = mta_string_create(long_str.c_str()); + REQUIRE(str != nullptr); + CHECK(std::string(mta_string_view(str)) == long_str); + mta_string_free(str); + + // free on a null pointer should work + mta_string_free(nullptr); +} From f04c6c04751cb4fdf8dc645a0ee0c05082a1bf88 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Fri, 29 May 2026 15:44:44 +0200 Subject: [PATCH 12/20] Port unit parsing from metatomic-torch --- docs/src/core/index.rst | 1 + docs/src/core/units.rst | 97 ++++ docs/src/torch/reference/misc.rst | 62 --- metatomic-core/include/metatomic.h | 19 +- metatomic-core/src/c_api/mod.rs | 2 + metatomic-core/src/c_api/utils.rs | 39 +- metatomic-core/src/units.rs | 723 ++++++++++++++++++++++++++++- metatomic-core/tests/misc.cpp | 23 + 8 files changed, 898 insertions(+), 68 deletions(-) create mode 100644 docs/src/core/units.rst diff --git a/docs/src/core/index.rst b/docs/src/core/index.rst index 4d0cf24a4..2a3691316 100644 --- 
a/docs/src/core/index.rst +++ b/docs/src/core/index.rst @@ -9,6 +9,7 @@ WIP reference/c/index reference/json-formats + units .. toctree:: diff --git a/docs/src/core/units.rst b/docs/src/core/units.rst new file mode 100644 index 000000000..d5f3dd776 --- /dev/null +++ b/docs/src/core/units.rst @@ -0,0 +1,97 @@ +.. _core-unit-expressions: + +Units +^^^^^ + +Models in metatomic can use arbitrary units for their inputs and outputs. The +unit conversion system allows models to specify the units they expect and +receive data in any compatible unit, with automatic conversion handled by +:c:func:`mta_execute_model`. + +The :c:func:`mta_unit_conversion_factor` function parses two unit expressions, +checks that they have compatible physical dimensions, and returns the +multiplicative conversion factor: + +.. code-block:: c + + // How many eV are in one kJ/mol? + double factor; + mta_unit_conversion_factor("kJ/mol", "eV", &factor); + // factor ≈ 0.01036 + + // How many GPa are in one eV/A^3? + mta_unit_conversion_factor("eV/A^3", "GPa", &factor); + // factor ≈ 160.22 + +If either (or both) unit strings are empty, the conversion returns ``1.0`` +without checking dimensions. This makes it safe to pass optional/unknown units. + +.. _known-base-units: + +Base units +~~~~~~~~~~ + +Unit expressions are built from the following base units. Matching is +case-insensitive, and whitespace is ignored. 
+ +**Temperature**: + ``Kelvin`` (``K``) + +**Length**: + ``angstrom`` (``A``), ``Bohr``, ``meter`` (``m``), ``centimeter`` (``cm``), + ``millimeter`` (``mm``), ``micrometer`` (``um``, ``µm``), ``nanometer`` (``nm``) + +**Energy**: + ``eV``, ``meV``, ``Hartree``, ``kcal``, ``kJ``, ``Joule`` (``J``), ``Rydberg`` (``Ry``) + +**Time**: + ``second`` (``s``), ``millisecond`` (``ms``), ``microsecond`` (``us``, ``µs``), + ``nanosecond`` (``ns``), ``picosecond`` (``ps``), ``femtosecond`` (``fs``) + +**Mass**: + ``Dalton`` (``u``), ``kilogram`` (``kg``), ``gram`` (``g``), ``electron_mass`` (``m_e``) + +**Charge**: + ``e``, ``Coulomb`` (``C``) + +**Pressure**: + ``Pascal`` (``Pa``), ``kiloPascal`` (``kPa``), ``MegaPascal`` (``MPa``), + ``GigaPascal`` (``GPa``), ``bar``, ``atm`` + +**Electric Dipole Moment**: + ``Debye`` (``D``) + +**Dimensionless**: + ``mol`` + +**Derived constants**: + ``hbar`` + +Expression syntax +~~~~~~~~~~~~~~~~~ + +Base units can be combined using the following operators: + +- Multiplication: ``*`` or whitespace (``kJ mol``, ``kJ*mol``) +- Division: ``/`` (``kJ/mol``) +- Exponentiation: ``^`` (``A^3``, ``m^2``) +- Parentheses: ``()`` for grouping (``(eV*u)^(1/2)``) + +Fractional powers + Exponents can be integers (``A^3``) or fractions enclosed in parentheses + (``^(1/2)``, ``^(2/3)``). Fractional powers are supported only when the + result has integer physical dimensions — for example ``(eV*u)^(1/2)`` + computes momentum with dimensions :math:`[L T^{-1} M]`. + +Numeric literals + Bare numbers can be used as dimensionless quantity expressions, e.g. + ``"2"`` evaluates to the conversion factor ``2.0``. This is useful when a + model needs to define a unit that is simply a scalar multiple of another. 
+ +Examples of valid compound expressions: + +- ``kJ/mol`` --- energy per mole +- ``eV/Angstrom^3`` or ``eV/A^3`` --- pressure +- ``(eV*u)^(1/2)`` --- momentum (fractional powers) +- ``Hartree/Bohr`` --- force in atomic units +- ``nm/fs`` --- velocity diff --git a/docs/src/torch/reference/misc.rst b/docs/src/torch/reference/misc.rst index 1d4b522e0..6bcea3818 100644 --- a/docs/src/torch/reference/misc.rst +++ b/docs/src/torch/reference/misc.rst @@ -14,65 +14,3 @@ The :py:func:`unit_conversion_factor` function accepts any valid unit expression built from base units combined with operators. There is no need to specify a physical quantity --- the parser automatically verifies dimensional compatibility between the source and target units. - -.. _known-base-units: - -Supported base units -~~~~~~~~~~~~~~~~~~~~ - -Unit expressions are built from the following base units. Matching is -case-insensitive, and whitespace is ignored. - - -**Temperature**: - ``Kelvin`` (``K``) - -**Length**: - ``angstrom`` (``A``), ``Bohr``, ``meter`` (``m``), ``centimeter`` (``cm``), - ``millimeter`` (``mm``), ``micrometer`` (``um``, ``µm``), ``nanometer`` (``nm``) - -**Energy**: - ``eV``, ``meV``, ``Hartree``, ``kcal``, ``kJ``, ``Joule`` (``J``), ``Rydberg`` (``Ry``) - -**Time**: - ``second`` (``s``), ``millisecond`` (``ms``), ``microsecond`` (``us``, ``µs``), - ``nanosecond`` (``ns``), ``picosecond`` (``ps``), ``femtosecond`` (``fs``) - -**Mass**: - ``Dalton`` (``u``), ``kilogram`` (``kg``), ``gram`` (``g``), ``electron_mass`` (``m_e``) - -**Charge**: - ``e``, ``Coulomb`` (``C``) - -**Pressure**: - ``Pascal`` (``Pa``), ``kiloPascal`` (``kPa``), ``MegaPascal`` (``MPa``), ``GigaPascal`` (``GPa``), ``bar``, ``atm`` - -**Electric Dipole Moment**: - ``Debye`` (``D``) - -**Dimensionless**: - ``mol`` - -**Derived constants**: - ``hbar`` - -Expression syntax -~~~~~~~~~~~~~~~~~~~ - -Base units can be combined using the following operators: - -- Multiplication: ``*`` or whitespace (``kJ mol``, 
``kJ*mol``) -- Division: ``/`` (``kJ/mol``) -- Exponentiation: ``^`` (``A^3``, ``m^2``) -- Parentheses: ``()`` for grouping (``(eV*u)^(1/2)``) - -Examples of valid compound expressions: - -- ``kJ/mol`` --- energy per mole -- ``eV/Angstrom^3`` or ``eV/A^3`` --- pressure -- ``(eV*u)^(1/2)`` --- momentum (fractional powers) -- ``Hartree/Bohr`` --- force in atomic units -- ``nm/fs`` --- velocity - -The parser automatically checks that both unit expressions have matching -physical dimensions before computing the conversion factor. diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index 25a5f7444..1a4fc35f9 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -195,7 +195,24 @@ void mta_string_free(mta_string_t string); const char *mta_string_view(mta_string_t string); /** - * TODO + * Get the multiplicative conversion factor to use to convert from + * `from_unit` to `to_unit`. Both units are parsed as expressions (e.g. + * "kJ/mol/A^2", "(eV*u)^(1/2)") and their dimensions must match. + * + * Unit expressions are built from base units combined with `*`, `/`, `^`, + * and parentheses. Unit lookup is case-insensitive, and whitespace is + * ignored. For example: + * + * - `"kJ/mol"` -- energy per mole + * - `"eV/Angstrom^3"` -- pressure + * - `"(eV*u)^(1/2)"` -- momentum (fractional powers) + * - `"Hartree/Bohr"` -- force in atomic units + * + * @param from_unit A null-terminated C string containing the unit to convert from. + * @param to_unit A null-terminated C string containing the unit to convert to. + * @param conversion A pointer to a `double` where the conversion factor will be stored. + * @return The status code of the operation. If this code is not `MTA_SUCCESS`, + * you can get more details about the error with `mta_last_error`. 
*/ enum mta_status_t mta_unit_conversion_factor(const char *from_unit, const char *to_unit, diff --git a/metatomic-core/src/c_api/mod.rs b/metatomic-core/src/c_api/mod.rs index 282e230e5..bffa5003a 100644 --- a/metatomic-core/src/c_api/mod.rs +++ b/metatomic-core/src/c_api/mod.rs @@ -1,3 +1,5 @@ +#![allow(clippy::doc_markdown)] + #[macro_use] mod status; pub use self::status::{mta_status_t, catch_unwind}; diff --git a/metatomic-core/src/c_api/utils.rs b/metatomic-core/src/c_api/utils.rs index 8e22bcc47..448c2b5d4 100644 --- a/metatomic-core/src/c_api/utils.rs +++ b/metatomic-core/src/c_api/utils.rs @@ -3,7 +3,7 @@ use std::ffi::{CString, c_char}; use once_cell::sync::Lazy; use super::{mta_status_t, catch_unwind}; - +use crate::Error; static VERSION: Lazy = Lazy::new(|| { CString::new(env!("METATOMIC_FULL_VERSION")).expect("version contains NULL byte") @@ -147,14 +147,47 @@ pub unsafe extern "C" fn mta_string_view( return result; } -/// TODO +/// Get the multiplicative conversion factor to use to convert from +/// `from_unit` to `to_unit`. Both units are parsed as expressions (e.g. +/// "kJ/mol/A^2", "(eV*u)^(1/2)") and their dimensions must match. +/// +/// Unit expressions are built from base units combined with `*`, `/`, `^`, +/// and parentheses. Unit lookup is case-insensitive, and whitespace is +/// ignored. For example: +/// +/// - `"kJ/mol"` -- energy per mole +/// - `"eV/Angstrom^3"` -- pressure +/// - `"(eV*u)^(1/2)"` -- momentum (fractional powers) +/// - `"Hartree/Bohr"` -- force in atomic units +/// +/// @param from_unit A null-terminated C string containing the unit to convert from. +/// @param to_unit A null-terminated C string containing the unit to convert to. +/// @param conversion A pointer to a `double` where the conversion factor will be stored. +/// @return The status code of the operation. If this code is not `MTA_SUCCESS`, +/// you can get more details about the error with `mta_last_error`. 
#[no_mangle] pub unsafe extern "C" fn mta_unit_conversion_factor( from_unit: *const c_char, to_unit: *const c_char, conversion: *mut f64, ) -> mta_status_t { - todo!() + catch_unwind(|| { + check_pointers_non_null!(from_unit, to_unit, conversion); + + let from_cstr = std::ffi::CStr::from_ptr(from_unit); + let to_cstr = std::ffi::CStr::from_ptr(to_unit); + + let from_str = from_cstr.to_str().map_err(|_| { + Error::InvalidParameter("from_unit is not valid UTF-8".into()) + })?; + let to_str = to_cstr.to_str().map_err(|_| { + Error::InvalidParameter("to_unit is not valid UTF-8".into()) + })?; + + *conversion = crate::unit_conversion_factor(from_str, to_str)?; + + Ok(()) + }) } diff --git a/metatomic-core/src/units.rs b/metatomic-core/src/units.rs index d06eab413..4cfff2c1d 100644 --- a/metatomic-core/src/units.rs +++ b/metatomic-core/src/units.rs @@ -1,7 +1,726 @@ use crate::Error; +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::fmt; +use std::ops::{Add, Sub}; -/// TODO +/// Physical dimension vector with named integer exponents: +/// [Length, Time, Mass, Electric Current, Temperature] +/// +/// Note: quantity of substance (mole) is intentionally not included, since we +/// want `kJ/mol` and `eV` to have the same dimension. +#[derive(Debug, Clone, PartialEq, Eq)] +struct Dimension { + length: i32, + time: i32, + mass: i32, + electric_current: i32, + temperature: i32, +} + +impl Dimension { + /// Dimensionless — all exponents are zero. 
+ const NONE: Dimension = Dimension { length: 0, time: 0, mass: 0, electric_current: 0, temperature: 0 }; + + /// Length dimension + const LENGTH: Dimension = Dimension { length: 1, time: 0, mass: 0, electric_current: 0, temperature: 0 }; + /// Time dimension + const TIME: Dimension = Dimension { length: 0, time: 1, mass: 0, electric_current: 0, temperature: 0 }; + /// Mass dimension + const MASS: Dimension = Dimension { length: 0, time: 0, mass: 1, electric_current: 0, temperature: 0 }; + /// Electric charge dimension (current × time) + const CHARGE: Dimension = Dimension { length: 0, time: 1, mass: 0, electric_current: 1, temperature: 0 }; + /// Temperature dimension + const TEMPERATURE: Dimension = Dimension { length: 0, time: 0, mass: 0, electric_current: 0, temperature: 1 }; + + /// Energy dimension: L² T⁻² M¹ + const ENERGY: Dimension = Dimension { length: 2, time: -2, mass: 1, electric_current: 0, temperature: 0 }; + /// Pressure dimension: L⁻¹ T⁻² M¹ + const PRESSURE: Dimension = Dimension { length: -1, time: -2, mass: 1, electric_current: 0, temperature: 0 }; + /// Electric dipole dimension: L¹ T¹ I¹ + const ELECTRIC_DIPOLE: Dimension = Dimension { length: 1, time: 1, mass: 0, electric_current: 1, temperature: 0 }; + + fn pow(&self, p: f64) -> Dimension { + Dimension { + length: round_if_integer(f64::from(self.length) * p), + time: round_if_integer(f64::from(self.time) * p), + mass: round_if_integer(f64::from(self.mass) * p), + electric_current: round_if_integer(f64::from(self.electric_current) * p), + temperature: round_if_integer(f64::from(self.temperature) * p), + } + } +} + +impl Add<&Dimension> for &Dimension { + type Output = Dimension; + + fn add(self, other: &Dimension) -> Dimension { + Dimension { + length: self.length + other.length, + time: self.time + other.time, + mass: self.mass + other.mass, + electric_current: self.electric_current + other.electric_current, + temperature: self.temperature + other.temperature, + } + } +} + +impl 
Sub<&Dimension> for &Dimension { + type Output = Dimension; + + fn sub(self, other: &Dimension) -> Dimension { + Dimension { + length: self.length - other.length, + time: self.time - other.time, + mass: self.mass - other.mass, + electric_current: self.electric_current - other.electric_current, + temperature: self.temperature - other.temperature, + } + } +} + +#[allow(clippy::cast_possible_truncation)] +fn round_if_integer(v: f64) -> i32 { + let rounded = v.round(); + assert!((v - rounded).abs() <= 1e-10, "non-integer dimension exponent {} is not supported", v); + return rounded as i32; +} + +impl fmt::Display for Dimension { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use fmt::Write; + let mut first = true; + f.write_char('[')?; + + for (name, v) in [ + ("L", self.length), + ("T", self.time), + ("M", self.mass), + ("I", self.electric_current), + ("Θ", self.temperature), + ] { + if v == 0 { + continue; + } + + if !first { + f.write_char(' ')?; + } + first = false; + + f.write_str(name)?; + + if v != 1 && v != -1 { + write!(f, "^{}", v)?; + } + + if v == -1 { + f.write_str("^-1")?; + } + } + + if first { + f.write_str("dimensionless")?; + } + f.write_char(']')?; + + Ok(()) + } +} + +/// A parsed unit value: SI conversion factor and physical dimension. +#[derive(Debug, Clone)] +struct UnitValue { + factor: f64, + dim: Dimension, +} + +/// All base units with SI factors and dimensions. +/// Factors are expressed in SI base units (m, s, kg, C, K). +/// Case-insensitive lookup: names are lowercased before searching. 
+static BASE_UNITS: Lazy> = Lazy::new(|| { + let mut map = HashMap::new(); + + // --- Temperature --- + map.insert("kelvin", UnitValue { factor: 1.0, dim: Dimension::TEMPERATURE }); + map.insert("k", UnitValue { factor: 1.0, dim: Dimension::TEMPERATURE }); + + // --- Length --- + map.insert("angstrom", UnitValue { factor: 1e-10, dim: Dimension::LENGTH }); + map.insert("a", UnitValue { factor: 1e-10, dim: Dimension::LENGTH }); + map.insert("bohr", UnitValue { factor: 5.2917721054482e-11, dim: Dimension::LENGTH }); + map.insert("nm", UnitValue { factor: 1e-9, dim: Dimension::LENGTH }); + map.insert("nanometer", UnitValue { factor: 1e-9, dim: Dimension::LENGTH }); + map.insert("meter", UnitValue { factor: 1.0, dim: Dimension::LENGTH }); + map.insert("m", UnitValue { factor: 1.0, dim: Dimension::LENGTH }); + map.insert("cm", UnitValue { factor: 1e-2, dim: Dimension::LENGTH }); + map.insert("centimeter", UnitValue { factor: 1e-2, dim: Dimension::LENGTH }); + map.insert("mm", UnitValue { factor: 1e-3, dim: Dimension::LENGTH }); + map.insert("millimeter", UnitValue { factor: 1e-3, dim: Dimension::LENGTH }); + map.insert("um", UnitValue { factor: 1e-6, dim: Dimension::LENGTH }); + map.insert("µm", UnitValue { factor: 1e-6, dim: Dimension::LENGTH }); + map.insert("micrometer", UnitValue { factor: 1e-6, dim: Dimension::LENGTH }); + + // --- Energy --- + map.insert("electronvolt", UnitValue { factor: 1.602176634e-19, dim: Dimension::ENERGY }); + map.insert("ev", UnitValue { factor: 1.602176634e-19, dim: Dimension::ENERGY }); + map.insert("mev", UnitValue { factor: 1.602176634e-19 * 1e-3, dim: Dimension::ENERGY }); + map.insert("hartree", UnitValue { factor: 4.359744722206048e-18, dim: Dimension::ENERGY }); + map.insert("ry", UnitValue { factor: 2.179872361103024e-18, dim: Dimension::ENERGY }); + map.insert("rydberg", UnitValue { factor: 2.179872361103024e-18, dim: Dimension::ENERGY }); + map.insert("joule", UnitValue { factor: 1.0, dim: Dimension::ENERGY }); + map.insert("j", 
UnitValue { factor: 1.0, dim: Dimension::ENERGY }); + map.insert("kcal", UnitValue { factor: 4184.0, dim: Dimension::ENERGY }); + map.insert("kj", UnitValue { factor: 1000.0, dim: Dimension::ENERGY }); + + // --- Time --- + map.insert("s", UnitValue { factor: 1.0, dim: Dimension::TIME }); + map.insert("second", UnitValue { factor: 1.0, dim: Dimension::TIME }); + map.insert("ms", UnitValue { factor: 1e-3, dim: Dimension::TIME }); + map.insert("millisecond", UnitValue { factor: 1e-3, dim: Dimension::TIME }); + map.insert("us", UnitValue { factor: 1e-6, dim: Dimension::TIME }); + map.insert("µs", UnitValue { factor: 1e-6, dim: Dimension::TIME }); + map.insert("microsecond", UnitValue { factor: 1e-6, dim: Dimension::TIME }); + map.insert("ns", UnitValue { factor: 1e-9, dim: Dimension::TIME }); + map.insert("nanosecond", UnitValue { factor: 1e-9, dim: Dimension::TIME }); + map.insert("ps", UnitValue { factor: 1e-12, dim: Dimension::TIME }); + map.insert("picosecond", UnitValue { factor: 1e-12, dim: Dimension::TIME }); + map.insert("fs", UnitValue { factor: 1e-15, dim: Dimension::TIME }); + map.insert("femtosecond", UnitValue { factor: 1e-15, dim: Dimension::TIME }); + + // --- Mass --- + map.insert("u", UnitValue { factor: 1.6605390689252e-27, dim: Dimension::MASS }); + map.insert("dalton", UnitValue { factor: 1.6605390689252e-27, dim: Dimension::MASS }); + map.insert("kg", UnitValue { factor: 1.0, dim: Dimension::MASS }); + map.insert("kilogram", UnitValue { factor: 1.0, dim: Dimension::MASS }); + map.insert("g", UnitValue { factor: 1e-3, dim: Dimension::MASS }); + map.insert("gram", UnitValue { factor: 1e-3, dim: Dimension::MASS }); + map.insert("electron_mass", UnitValue { factor: 9.109383713928e-31, dim: Dimension::MASS }); + map.insert("m_e", UnitValue { factor: 9.109383713928e-31, dim: Dimension::MASS }); + + // --- Charge --- + map.insert("e", UnitValue { factor: 1.602176634e-19, dim: Dimension::CHARGE }); + map.insert("coulomb", UnitValue { factor: 1.0, dim: 
Dimension::CHARGE }); + map.insert("c", UnitValue { factor: 1.0, dim: Dimension::CHARGE }); + + // --- Pressure --- + map.insert("pa", UnitValue { factor: 1.0, dim: Dimension::PRESSURE }); + map.insert("pascal", UnitValue { factor: 1.0, dim: Dimension::PRESSURE }); + map.insert("kpa", UnitValue { factor: 1e3, dim: Dimension::PRESSURE }); + map.insert("kilopascal", UnitValue { factor: 1e3, dim: Dimension::PRESSURE }); + map.insert("mpa", UnitValue { factor: 1e6, dim: Dimension::PRESSURE }); + map.insert("megapascal", UnitValue { factor: 1e6, dim: Dimension::PRESSURE }); + map.insert("gpa", UnitValue { factor: 1e9, dim: Dimension::PRESSURE }); + map.insert("gigapascal", UnitValue { factor: 1e9, dim: Dimension::PRESSURE }); + map.insert("bar", UnitValue { factor: 100000.0, dim: Dimension::PRESSURE }); + map.insert("atm", UnitValue { factor: 101325.0, dim: Dimension::PRESSURE }); + + // --- Electric dipole moment --- + map.insert("debye", UnitValue { factor: 1.0 / 299792458.0 * 1e-21, dim: Dimension::ELECTRIC_DIPOLE }); + map.insert("d", UnitValue { factor: 1.0 / 299792458.0 * 1e-21, dim: Dimension::ELECTRIC_DIPOLE }); + + // --- Dimensionless --- + map.insert("mol", UnitValue { factor: 6.02214076e23, dim: Dimension::NONE }); + + // --- Derived --- + map.insert("hbar", UnitValue { + factor: 1.0545718176462e-34, + dim: Dimension { length: 2, time: -1, mass: 1, electric_current: 0, temperature: 0 }, + }); + + map +}); + +// ---- Tokenizer ---- + +#[derive(Debug, Clone)] +enum Token { + LParen, + RParen, + Mul, + Div, + Pow, + Value(String), +} + +impl Token { + fn precedence(&self) -> i32 { + match self { + Token::LParen | Token::RParen => 0, + Token::Mul | Token::Div => 10, + Token::Pow => 20, + Token::Value(_) => -1, + } + } +} + +impl fmt::Display for Token { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Token::LParen => write!(f, "("), + Token::RParen => write!(f, ")"), + Token::Mul => write!(f, "*"), + Token::Div => write!(f, "/"), + 
Token::Pow => write!(f, "^"), + Token::Value(v) => write!(f, "{}", v), + } + } +} + +fn tokenize(unit: &str) -> Vec { + let mut tokens = Vec::new(); + let mut current = String::new(); + + for c in unit.chars() { + if c == '*' || c == '/' || c == '^' || c == '(' || c == ')' { + if !current.is_empty() { + tokens.push(Token::Value(current.clone())); + current.clear(); + } + let t = match c { + '*' => Token::Mul, + '/' => Token::Div, + '^' => Token::Pow, + '(' => Token::LParen, + ')' => Token::RParen, + _ => unreachable!(), + }; + tokens.push(t); + } else if !c.is_whitespace() { + current.push(c); + } + } + + if !current.is_empty() { + tokens.push(Token::Value(current)); + } + + tokens +} + +// ---- Shunting-Yard ---- + +/// Convert infix tokens to [Reverse Polish Notation] (RPN) using the +/// [Shunting-Yard] algorithm. +/// +/// RPN (also called postfix notation) writes operators after their operands, +/// e.g. `kJ / mol` becomes `kJ mol /`. This removes the need for parentheses +/// and precedence rules, making the expression easy to evaluate with a stack. +/// +/// All operators are treated as left-associative. 
+/// +/// [Reverse Polish Notation]: https://en.wikipedia.org/wiki/Reverse_Polish_notation +/// [Shunting-Yard]: https://en.wikipedia.org/wiki/Shunting-yard_algorithm +fn shunting_yard(tokens: &[Token]) -> Result, Error> { + let mut output: Vec = Vec::new(); + let mut operators: Vec = Vec::new(); + + for token in tokens { + match token { + Token::Value(_) => { + output.push(token.clone()); + } + Token::Mul | Token::Div | Token::Pow => { + while let Some(top) = operators.last() { + if token.precedence() <= top.precedence() { + output.push(operators.pop().unwrap()); + } else { + break; + } + } + operators.push(token.clone()); + } + Token::LParen => { + operators.push(token.clone()); + } + Token::RParen => { + while let Some(top) = operators.last() { + if matches!(top, Token::LParen) { + break; + } + output.push(operators.pop().unwrap()); + } + if operators.is_empty() || !matches!(operators.last(), Some(Token::LParen)) { + return Err(Error::InvalidParameter( + "unit expression has unbalanced parentheses".into(), + )); + } + operators.pop(); // discard LParen + } + } + } + + while let Some(top) = operators.pop() { + if matches!(top, Token::LParen | Token::RParen) { + return Err(Error::InvalidParameter( + "unit expression has unbalanced parentheses".into(), + )); + } + output.push(top); + } + + Ok(output) +} + +// ---- AST evaluator ---- + +struct UnitExpr { + val: UnitExprData, +} + +enum UnitExprData { + Val(UnitValue, String), + Mul(Box, Box), + Div(Box, Box), + Pow(Box, Box), +} + +impl fmt::Display for UnitExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.val { + UnitExprData::Val(_, name) => f.write_str(name), + UnitExprData::Mul(lhs, rhs) => { + write!(f, "({} * {})", lhs, rhs) + } + UnitExprData::Div(lhs, rhs) => { + write!(f, "({} / {})", lhs, rhs) + } + UnitExprData::Pow(base, exponent) => { + write!(f, "({} ^ {})", base, exponent) + } + } + } +} + +impl UnitExpr { + fn eval(&self) -> Result { + match &self.val { + 
UnitExprData::Val(v, _) => Ok(v.clone()), + UnitExprData::Mul(lhs, rhs) => { + let l = lhs.eval()?; + let r = rhs.eval()?; + let result_factor = l.factor * r.factor; + if !result_factor.is_finite() { + return Err(Error::InvalidParameter(format!( + "unit conversion factor overflows: multiplication result is infinite \ + or NaN for '{}'", + self + ))); + } + Ok(UnitValue { + factor: result_factor, + dim: &l.dim + &r.dim, + }) + } + UnitExprData::Div(lhs, rhs) => { + let l = lhs.eval()?; + let r = rhs.eval()?; + let result_factor = l.factor / r.factor; + if !result_factor.is_finite() { + return Err(Error::InvalidParameter(format!( + "unit conversion factor overflows: division result is infinite \ + or NaN for '{}'", + self + ))); + } + Ok(UnitValue { + factor: result_factor, + dim: &l.dim - &r.dim, + }) + } + UnitExprData::Pow(base, exponent) => { + let b = base.eval()?; + let e = exponent.eval()?; + + if e.dim != Dimension::NONE { + return Err(Error::InvalidParameter(format!( + "exponent in unit expression must be dimensionless, got dimension {} \ + for exponent '{}'", + e.dim, + exponent + ))); + } + let result_factor = b.factor.powf(e.factor); + if !result_factor.is_finite() { + return Err(Error::InvalidParameter(format!( + "unit conversion factor overflows: exponentiation result is infinite \ + or NaN for '{}'", + self + ))); + } + Ok(UnitValue { + factor: result_factor, + dim: b.dim.pow(e.factor), + }) + } + } + } +} + +/// Read one expression from the [RPN] stream (recursive, pops from the back). +/// +/// RPN arranges expressions as `lhs rhs op`, so `rhs` is on top of the stack +/// and must be popped first. For example `kJ mol /` pops `mol` (rhs) then +/// `kJ` (lhs) to build `Div(lhs=kJ, rhs=mol)`. 
+/// +/// [RPN]: https://en.wikipedia.org/wiki/Reverse_Polish_notation +fn read_expr(stream: &mut Vec) -> Result { + let token = stream.pop().ok_or_else(|| { + Error::InvalidParameter("malformed unit expression: missing a value".into()) + })?; + + match token { + Token::Value(v) => { + let lower = v.to_lowercase(); + if let Some(uv) = BASE_UNITS.get(lower.as_str()) { + return Ok(UnitExpr { + val: UnitExprData::Val(uv.clone(), v), + }); + } + if let Ok(val) = v.parse::() { + return Ok(UnitExpr { + val: UnitExprData::Val(UnitValue { factor: val, dim: Dimension::NONE }, v), + }); + } + Err(Error::InvalidParameter(format!("unknown unit '{}'", v))) + } + // RPN: lhs rhs Mul — pop rhs first, then lhs + Token::Mul => { + let rhs = read_expr(stream)?; + let lhs = read_expr(stream)?; + Ok(UnitExpr { + val: UnitExprData::Mul(Box::new(lhs), Box::new(rhs)), + }) + } + // RPN: lhs rhs Div — pop rhs first, then lhs + Token::Div => { + let rhs = read_expr(stream)?; + let lhs = read_expr(stream)?; + Ok(UnitExpr { + val: UnitExprData::Div(Box::new(lhs), Box::new(rhs)), + }) + } + // RPN: base exponent Pow — pop exponent first, then base + Token::Pow => { + let exponent = read_expr(stream)?; + let base = read_expr(stream)?; + Ok(UnitExpr { + val: UnitExprData::Pow(Box::new(base), Box::new(exponent)), + }) + } + _ => Err(Error::InvalidParameter(format!( + "unexpected symbol in unit expression: '{}'", + token + ))), + } +} + +/// Parse a unit expression string and return the evaluated `UnitValue`. 
+fn parse_unit_expression(unit: &str) -> Result { + if unit.is_empty() { + return Ok(UnitValue { factor: 1.0, dim: Dimension::NONE }); + } + + let tokens = tokenize(unit); + if tokens.is_empty() { + return Ok(UnitValue { factor: 1.0, dim: Dimension::NONE }); + } + + let mut rpn = shunting_yard(&tokens)?; + let ast = read_expr(&mut rpn)?; + + if !rpn.is_empty() { + let remaining: Vec = rpn.iter().map(|t| t.to_string()).collect(); + return Err(Error::InvalidParameter(format!( + "malformed unit expression: leftover input '{}'", + remaining.join(" ") + ))); + } + + ast.eval() +} + +/// Get the multiplicative conversion factor to use to convert from +/// `from_unit` to `to_unit`. Both units are parsed as expressions (e.g. +/// "kJ/mol/A^2", "(eV*u)^(1/2)") and their dimensions must match. +/// +/// Unit expressions are built from base units combined with `*`, `/`, `^`, +/// and parentheses. Unit lookup is case-insensitive, and whitespace is +/// ignored. For example: +/// +/// - `"kJ/mol"` -- energy per mole +/// - `"eV/Angstrom^3"` -- pressure +/// - `"(eV*u)^(1/2)"` -- momentum (fractional powers) +/// - `"Hartree/Bohr"` -- force in atomic units pub fn unit_conversion_factor(from_unit: &str, to_unit: &str) -> Result { - todo!() + if from_unit.is_empty() || to_unit.is_empty() { + return Ok(1.0); + } + + let from = parse_unit_expression(from_unit)?; + let to = parse_unit_expression(to_unit)?; + + if from.dim != to.dim { + return Err(Error::InvalidParameter(format!( + "dimension mismatch in unit conversion: '{}' has dimension {} but '{}' has dimension {}", + from_unit, + from.dim, + to_unit, + to.dim + ))); + } + + Ok(from.factor / to.factor) +} + +#[cfg(test)] +#[allow(clippy::float_cmp)] +mod tests { + use super::*; + + #[test] + fn test_tokenize_simple() { + let tokens = tokenize("eV"); + assert_eq!(tokens.len(), 1); + assert!(matches!(&tokens[0], Token::Value(v) if v == "eV")); + } + + #[test] + fn test_tokenize_operators() { + let tokens = tokenize("kJ/mol/A^2"); + 
let types: Vec = tokens.iter().map(|t| t.to_string()).collect(); + assert_eq!(types, vec!["kJ", "/", "mol", "/", "A", "^", "2"]); + } + + #[test] + fn test_tokenize_parens() { + let tokens = tokenize("(eV*u)^(1/2)"); + let types: Vec = tokens.iter().map(|t| t.to_string()).collect(); + assert_eq!(types, vec!["(", "eV", "*", "u", ")", "^", "(", "1", "/", "2", ")"]); + } + + #[test] + fn test_tokenize_whitespace() { + let tokens = tokenize(" kJ / mol "); + let types: Vec = tokens.iter().map(|t| t.to_string()).collect(); + assert_eq!(types, vec!["kJ", "/", "mol"]); + } + + #[test] + fn test_shunting_yard() { + let tokens = tokenize("kJ/mol"); + let rpn = shunting_yard(&tokens).unwrap(); + let types: Vec = rpn.iter().map(|t| t.to_string()).collect(); + assert_eq!(types, vec!["kJ", "mol", "/"]); + + let tokens = tokenize("kJ/mol/A^2"); + let rpn = shunting_yard(&tokens).unwrap(); + let types: Vec = rpn.iter().map(|t| t.to_string()).collect(); + assert_eq!(types, vec!["kJ", "mol", "/", "A", "2", "^", "/"]); + } + + #[test] + fn test_parens_mismatch() { + let tokens = tokenize("("); + let err = shunting_yard(&tokens).expect_err("expected error"); + assert_eq!( + err.to_string(), + "invalid parameter: unit expression has unbalanced parentheses" + ); + + let tokens = tokenize("(eV*u"); + let err = shunting_yard(&tokens).expect_err("expected error"); + assert_eq!( + err.to_string(), + "invalid parameter: unit expression has unbalanced parentheses" + ); + } + + #[test] + fn test_simple_conversion() { + let factor = unit_conversion_factor("eV", "eV").unwrap(); + assert_eq!(factor, 1.0); + + let factor = unit_conversion_factor("m", "A").unwrap(); + assert!((factor - 1e10).abs() < 1e-5); + + let factor = unit_conversion_factor("eV", "kJ").unwrap(); + assert!((factor - 1.602176634e-22).abs() < 1e-30); + } + + #[test] + fn test_dimension_mismatch() { + let err = unit_conversion_factor("eV", "m").expect_err("expected error"); + assert_eq!( + err.to_string(), + "invalid parameter: 
dimension mismatch in unit conversion: \ + 'eV' has dimension [L^2 T^-2 M] but 'm' has dimension [L]" + ); + } + + #[test] + fn test_empty_units() { + let factor = unit_conversion_factor("", "").unwrap(); + assert_eq!(factor, 1.0); + + let factor = unit_conversion_factor("eV", "").unwrap(); + assert_eq!(factor, 1.0); + } + + #[test] + fn test_compound_units() { + let from = unit_conversion_factor("kJ/mol", "eV").unwrap(); + assert!((from - 0.010364269656262174).abs() < 1e-15); + + let from = unit_conversion_factor("eV/A^3", "GPa").unwrap(); + assert!((from - 160.21766339999996).abs() < 1e-12); + } + + #[test] + fn test_case_insensitive() { + let f1 = unit_conversion_factor("eV", "eV").unwrap(); + let f2 = unit_conversion_factor("EV", "eV").unwrap(); + assert_eq!(f1, f2); + + let factor = unit_conversion_factor("eV", "MeV").unwrap(); + assert!((factor - 1000.0).abs() < 1e-12); + } + + #[test] + fn test_unknown_unit() { + let err = unit_conversion_factor("foo", "eV").expect_err("expected error"); + assert_eq!(err.to_string(), "invalid parameter: unknown unit 'foo'"); + } + + #[test] + fn test_numeric_literal() { + let factor = unit_conversion_factor("2", "1").unwrap(); + assert_eq!(factor, 2.0); + } + + #[test] + fn test_fractional_power() { + let err = unit_conversion_factor("(eV*u)^(1/2)", "eV*u").expect_err("expected error"); + assert_eq!( + err.to_string(), + "invalid parameter: dimension mismatch in unit conversion: \ + '(eV*u)^(1/2)' has dimension [L T^-1 M] but 'eV*u' has dimension [L^2 T^-2 M^2]" + ); + + let factor = unit_conversion_factor("(eV*u)^(1/2)", "(eV*u)^(1/2)").unwrap(); + assert_eq!(factor, 1.0); + } + + #[test] + fn test_dimension_to_string() { + assert_eq!(Dimension::NONE.to_string(), "[dimensionless]"); + assert_eq!(Dimension::LENGTH.to_string(), "[L]"); + assert_eq!(Dimension::ENERGY.to_string(), "[L^2 T^-2 M]"); + assert_eq!(Dimension::PRESSURE.to_string(), "[L^-1 T^-2 M]"); + assert_eq!(Dimension::TEMPERATURE.to_string(), "[Θ]"); + + let 
velocity = Dimension { length: 1, time: -1, mass: 0, electric_current: 0, temperature: 0 }; + assert_eq!(velocity.to_string(), "[L T^-1]"); + } } diff --git a/metatomic-core/tests/misc.cpp b/metatomic-core/tests/misc.cpp index 188eca67c..8c66d7561 100644 --- a/metatomic-core/tests/misc.cpp +++ b/metatomic-core/tests/misc.cpp @@ -47,3 +47,26 @@ TEST_CASE("mta_string_t") { // free on a null pointer should work mta_string_free(nullptr); } + +TEST_CASE("mta_unit_conversion_factor") { + double factor = 0.0; + + // same unit -> factor = 1.0 + auto status = mta_unit_conversion_factor("m", "m", &factor); + REQUIRE(status == MTA_SUCCESS); + CHECK(factor == 1.0); + + // kJ/mol -> eV + CHECK(mta_unit_conversion_factor("kJ/mol", "eV", &factor) == MTA_SUCCESS); + CHECK(factor == Approx(0.010364269656262174).epsilon(1e-15)); + + // dimension mismatch -> error + status = mta_unit_conversion_factor("m", "kg", &factor); + REQUIRE(status != MTA_SUCCESS); + + const char* error_msg = nullptr; + mta_last_error(&error_msg, nullptr, nullptr); + CHECK(std::string(error_msg) == + "invalid parameter: dimension mismatch in unit conversion: " + "'m' has dimension [L] but 'kg' has dimension [M]"); +} From 49ce8af181f194e8b041436cbd739643ab2ca658 Mon Sep 17 00:00:00 2001 From: Sofiia Chorna Date: Sat, 30 May 2026 13:58:51 +0200 Subject: [PATCH 13/20] document mta_model_t and related functions in C API --- docs/src/core/reference/json-formats.rst | 6 + metatomic-core/include/metatomic.h | 149 ++++++++++++++++++++--- metatomic-core/src/c_api/model.rs | 149 ++++++++++++++++++++--- metatomic-core/src/lib.rs | 1 + 4 files changed, 277 insertions(+), 28 deletions(-) diff --git a/docs/src/core/reference/json-formats.rst b/docs/src/core/reference/json-formats.rst index d47cb2715..859f7d985 100644 --- a/docs/src/core/reference/json-formats.rst +++ b/docs/src/core/reference/json-formats.rst @@ -8,6 +8,8 @@ strings rather than dedicated C types. 
This page documents the exact JSON representation of each such structure, so that engines and models written in any language can produce and consume them. +.. _core-json-pair-options: + Pair list options ----------------- @@ -51,6 +53,8 @@ list). This is used for example by :c:func:`mta_system_add_pairs`, omitted, in which case it is treated as an empty list. +.. _core-json-quantity: + Quantities ---------- @@ -91,6 +95,8 @@ inputs and outputs. This is used for example in following: ``"atom"``, ``"system"`` or ``"atom_pair"``. +.. _core-json-model-metadata: + Model metadata -------------- diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index 1a4fc35f9..a4ba5e95e 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -85,42 +85,136 @@ typedef struct mta_system_t mta_system_t; typedef mta_opaque_string_t *mta_string_t; /** - * TODO + * A model that computes physical properties of atomistic systems. + * + * `mta_model_t` is a small virtual table: `data` holds the model's own state, + * and the function pointers describe what the model can do. A model is usually + * produced by a plugin's `load_model` callback (see `mta_load_model`) and then + * executed with `mta_execute_model`. + * + * Every callback receives `data` as its first argument. metatomic treats + * `data` as opaque and only hands it back to the callbacks. Callbacks should + * report any error by saving it with `mta_set_last_error` and returning a + * non-success `mta_status_t`. */ typedef struct mta_model_t { /** - * TODO + * Opaque pointer to the model's internal state + * + * Its layout and meaning are private to the model implementation. It is + * initialized by whoever creates the model (e.g. a plugin's `load_model`) + * and released by `unload`. */ void *data; /** - * TODO + * Release the resources owned by `model_data` + * + * Called exactly once when the model is no longer needed. 
May be `NULL` if + * the model owns no resources. + * + * @param model_data the model's `data` pointer + * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*unload)(void *model_data); /** - * TODO + * Get metadata describing the model (name, authors, references, ...) as a + * JSON string. + * + * @verbatim embed:rst:leading-asterisk + * The expected JSON structure is documented in :ref:`core-json-model-metadata`. + * @endverbatim + * + * @param model_data the model's `data` pointer + * @param metadata_json output string, set to a JSON-serialized + * `ModelMetadata` object. The + * caller takes ownership and must free it with `mta_string_free`. + * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*metadata)(const void *model_data, mta_string_t *metadata_json); /** - * TODO + * List the outputs this model is able to compute as a JSON string. + * + * @verbatim embed:rst:leading-asterisk + * The expected JSON structure for each output is documented in :ref:`core-json-quantity`. + * @endverbatim + * + * @param model_data the model's `data` pointer + * @param outputs_json output string, set to a JSON array of `Quantity` + * objects, one per supported output. The caller takes ownership and + * must free it with `mta_string_free`. + * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*supported_outputs)(const void *model_data, mta_string_t *outputs_json); /** - * TODO + * List the pair lists (neighbor lists) the model needs as input as a JSON + * string. + * + * @verbatim embed:rst:leading-asterisk + * + * The engine is expected to compute these and attach them to every system + * with :c:func:`mta_system_add_pairs` before calling + * :c:func:`mta_execute_model`. + * + * The expected JSON structure for each pair list is documented in :ref:`core-json-pair-options`. 
+ * + * @endverbatim + * + * @param model_data the model's `data` pointer + * @param pair_options_json output string, set to a JSON array of + * `PairListOptions` objects. The caller takes ownership and must + * free it with `mta_string_free`. + * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*requested_pair_lists)(const void *model_data, mta_string_t *pair_options_json); /** - * TODO + * List the additional per-system inputs the model needs as a JSON string. + * + * @verbatim embed:rst:leading-asterisk + * + * These correspond to custom data the engine should attach to every system + * with :c:func:`mta_system_add_custom_data` before execution. + * + * The expected JSON structure for each input is documented in :ref:`core-json-quantity`. + * + * @endverbatim + * + * @param model_data the model's `data` pointer + * @param inputs_json output string, set to a JSON array of `Quantity` + * objects, one per requested input. The caller takes ownership and + * must free it with `mta_string_free`. + * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*requested_inputs)(const void *model_data, mta_string_t *inputs_json); /** - * TODO + * Run the model and compute the requested outputs + * + * @verbatim embed:rst:leading-asterisk + * + * This performs the model's actual computation. This should not be called + * directly, but rather through :c:func:`mta_execute_model`, which handles + * unit conversion and can check inputs and output data for consistency. + * + * @endverbatim + * + * @param model_data the model's `data` pointer + * @param systems array of `systems_count` systems to run the model on + * @param systems_count number of entries in `systems` + * @param selected_atoms optional labels selecting the subset of atoms to + * compute outputs for, or `NULL` to use all atoms. When set, it has the + * dimensions `"system"` and `"atom"` holding 0-based indices. 
+ * @param requested_outputs_json JSON string containing an array of + * `Quantity`, one for each output the model should produce + * @param outputs array of `outputs_count` tensor maps to fill, one per + * requested output and in the same order + * @param outputs_count number of entries in `outputs`, must equal + * the number of entries in the `requested_outputs_json` array + * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*execute_inner)(void *model_data, const struct mta_system_t *const *systems, uintptr_t systems_count, const mts_labels_t *selected_atoms, - const char *const *requested_outputs_json, - uintptr_t requested_outputs_count, + const char *requested_outputs_json, mts_tensormap_t **outputs, uintptr_t outputs_count); } mta_model_t; @@ -292,20 +386,47 @@ enum mta_status_t mta_system_known_custom_data(const struct mta_system_t *system mta_string_t *names); /** - * TODO + * Execute a model to compute the requested outputs for a set of systems + * + * This is the main entry point to run a model loaded through the C API. It + * validates the arguments and delegates the computation to the model's + * `execute_inner` callback. + * + * @param model the model to execute + * @param systems array of `systems_count` systems to run the model on + * @param systems_count number of entries in `systems` + * @param selected_atoms optional labels selecting the subset of atoms to + * compute outputs for, or `NULL` to use all atoms + * @param requested_outputs_json JSON string containing an array of + * `Quantity`, one for each output the model should produce + * @param check_consistency if `true`, run additional checks on the + * inputs and on the data produced by the model + * @param outputs array of `outputs_count` tensor maps to fill, one per + * requested output and in the same order. The caller takes ownership of + * the returned tensor maps.
+ * @param outputs_count number of entries in `outputs`, must equal + * the number of entries in the `requested_outputs_json` array + * @return `MTA_SUCCESS` on success, another status code on error (the message + * is available through `mta_last_error`) */ enum mta_status_t mta_execute_model(struct mta_model_t model, const struct mta_system_t *const *systems, uintptr_t systems_count, const mts_labels_t *selected_atoms, - const char *const *requested_outputs_json, - uintptr_t requested_outputs_count, + const char *requested_outputs_json, bool check_consistency, mts_tensormap_t **outputs, uintptr_t outputs_count); /** - * TODO + * Render model metadata as a human-readable string + * + * @param metadata a JSON-serialized `ModelMetadata` object as produced by a + * model's `metadata` callback. Must not be null. + * @param printed output string, set to a human-readable rendering of the + * metadata. The caller takes ownership and must free it with + * `mta_string_free`. + * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t mta_format_metadata(const char *metadata, mta_string_t *printed); diff --git a/metatomic-core/src/c_api/model.rs b/metatomic-core/src/c_api/model.rs index b586f73c8..c9fc83610 100644 --- a/metatomic-core/src/c_api/model.rs +++ b/metatomic-core/src/c_api/model.rs @@ -3,62 +3,176 @@ use metatensor::c_api::{mts_labels_t, mts_tensormap_t}; use super::{mta_status_t, mta_string_t, mta_system_t}; -/// TODO +/// A model that computes physical properties of atomistic systems. +/// +/// `mta_model_t` is a small virtual table: `data` holds the model's own state, +/// and the function pointers describe what the model can do. A model is usually +/// produced by a plugin's `load_model` callback (see `mta_load_model`) and then +/// executed with `mta_execute_model`. +/// +/// Every callback receives `data` as its first argument. metatomic treats +/// `data` as opaque and only hands it back to the callbacks.
Callbacks should +/// report any error by saving it with `mta_set_last_error` and returning a +/// non-success `mta_status_t`. #[repr(C)] #[allow(non_camel_case_types)] pub struct mta_model_t { - /// TODO + /// Opaque pointer to the model's internal state + /// + /// Its layout and meaning are private to the model implementation. It is + /// initialized by whoever creates the model (e.g. a plugin's `load_model`) + /// and released by `unload`. pub data: *mut c_void, - /// TODO + /// Release the resources owned by `model_data` + /// + /// Called exactly once when the model is no longer needed. May be `NULL` if + /// the model owns no resources. + /// + /// @param model_data the model's `data` pointer + /// @return `MTA_SUCCESS` on success, another status code on error pub unload: Option mta_status_t>, - /// TODO + /// Get metadata describing the model (name, authors, references, ...) as a + /// JSON string. + /// + /// @verbatim embed:rst:leading-asterisk + /// The expected JSON structure is documented in :ref:`core-json-model-metadata`. + /// @endverbatim + /// + /// @param model_data the model's `data` pointer + /// @param metadata_json output string, set to a JSON-serialized + /// `ModelMetadata` object. The + /// caller takes ownership and must free it with `mta_string_free`. + /// @return `MTA_SUCCESS` on success, another status code on error pub metadata: Option mta_status_t>, - /// TODO + /// List the outputs this model is able to compute as a JSON string. + /// + /// @verbatim embed:rst:leading-asterisk + /// The expected JSON structure for each output is documented in :ref:`core-json-quantity`. + /// @endverbatim + /// + /// @param model_data the model's `data` pointer + /// @param outputs_json output string, set to a JSON array of `Quantity` + /// objects, one per supported output. The caller takes ownership and + /// must free it with `mta_string_free`. 
+ /// @return `MTA_SUCCESS` on success, another status code on error pub supported_outputs: Option mta_status_t>, - /// TODO + /// List the pair lists (neighbor lists) the model needs as input as a JSON + /// string. + /// + /// @verbatim embed:rst:leading-asterisk + /// + /// The engine is expected to compute these and attach them to every system + /// with :c:func:`mta_system_add_pairs` before calling + /// :c:func:`mta_execute_model`. + /// + /// The expected JSON structure for each pair list is documented in :ref:`core-json-pair-options`. + /// + /// @endverbatim + /// + /// @param model_data the model's `data` pointer + /// @param pair_options_json output string, set to a JSON array of + /// `PairListOptions` objects. The caller takes ownership and must + /// free it with `mta_string_free`. + /// @return `MTA_SUCCESS` on success, another status code on error pub requested_pair_lists: Option mta_status_t>, - /// TODO + /// List the additional per-system inputs the model needs as a JSON string. + /// + /// @verbatim embed:rst:leading-asterisk + /// + /// These correspond to custom data the engine should attach to every system + /// with :c:func:`mta_system_add_custom_data` before execution. + /// + /// The expected JSON structure for each input is documented in :ref:`core-json-quantity`. + /// + /// @endverbatim + /// + /// @param model_data the model's `data` pointer + /// @param inputs_json output string, set to a JSON array of `Quantity` + /// objects, one per requested input. The caller takes ownership and + /// must free it with `mta_string_free`. + /// @return `MTA_SUCCESS` on success, another status code on error pub requested_inputs: Option mta_status_t>, - /// TODO + /// Run the model and compute the requested outputs + /// + /// @verbatim embed:rst:leading-asterisk + /// + /// This performs the model's actual computation. 
This should not be called +/// directly, but rather through :c:func:`mta_execute_model`, which handles +/// unit conversion and can check inputs and output data for consistency. +/// +/// @endverbatim +/// +/// @param model_data the model's `data` pointer +/// @param systems array of `systems_count` systems to run the model on +/// @param systems_count number of entries in `systems` +/// @param selected_atoms optional labels selecting the subset of atoms to +/// compute outputs for, or `NULL` to use all atoms. When set, it has the +/// dimensions `"system"` and `"atom"` holding 0-based indices. +/// @param requested_outputs_json JSON string containing an array of +/// `Quantity`, one for each output the model should produce +/// @param outputs array of `outputs_count` tensor maps to fill, one per +/// requested output and in the same order +/// @param outputs_count number of entries in `outputs`, must equal +/// the number of entries in the `requested_outputs_json` array +/// @return `MTA_SUCCESS` on success, another status code on error pub execute_inner: Option mta_status_t>, } -/// TODO +/// Execute a model to compute the requested outputs for a set of systems +/// +/// This is the main entry point to run a model loaded through the C API. It +/// validates the arguments and delegates the computation to the model's +/// `execute_inner` callback.
+/// +/// @param model the model to execute +/// @param systems array of `systems_count` systems to run the model on +/// @param systems_count number of entries in `systems` +/// @param selected_atoms optional labels selecting the subset of atoms to +/// compute outputs for, or `NULL` to use all atoms +/// @param requested_outputs_json JSON string containing an array of +/// `Quantity`, one for each output the model should produce +/// @param check_consistency if `true`, run additional checks on the +/// inputs and on the data produced by the model +/// @param outputs array of `outputs_count` tensor maps to fill, one per +/// requested output and in the same order. The caller takes ownership of +/// the returned tensor maps. +/// @param outputs_count number of entries in `outputs`, must equal +/// the number of entries in the `requested_outputs_json` array +/// @return `MTA_SUCCESS` on success, another status code on error (the message +/// is available through `mta_last_error`) #[no_mangle] pub unsafe extern "C" fn mta_execute_model( model: mta_model_t, systems: *const *const mta_system_t, systems_count: usize, selected_atoms: *const mts_labels_t, - requested_outputs_json: *const *const c_char, - requested_outputs_count: usize, + requested_outputs_json: *const c_char, check_consistency: bool, outputs: *mut *mut mts_tensormap_t, outputs_count: usize, @@ -66,7 +180,14 @@ pub unsafe extern "C" fn mta_execute_model( todo!() } -/// TODO +/// Render model metadata as a human-readable string +/// +/// @param metadata a JSON-serialized `ModelMetadata` object as produced by a +/// model's `metadata` callback. Must not be null. +/// @param printed output string, set to a human-readable rendering of the +/// metadata. The caller takes ownership and must free it with +/// `mta_string_free`.
+/// @return `MTA_SUCCESS` on success, another status code on error #[no_mangle] pub unsafe extern "C" fn mta_format_metadata( metadata: *const c_char, diff --git a/metatomic-core/src/lib.rs b/metatomic-core/src/lib.rs index 09ec7a962..89834e45d 100644 --- a/metatomic-core/src/lib.rs +++ b/metatomic-core/src/lib.rs @@ -6,6 +6,7 @@ #![allow(clippy::unreadable_literal, clippy::option_if_let_else, clippy::module_name_repetitions)] #![allow(clippy::missing_errors_doc, clippy::missing_panics_doc, clippy::missing_safety_doc)] #![allow(clippy::similar_names, clippy::borrow_as_ptr, clippy::uninlined_format_args)] +#![allow(clippy::doc_markdown)] #![allow(clippy::let_underscore_untyped, clippy::manual_let_else, clippy::empty_line_after_doc_comments)] // To be removed later From 4a4484aa31cf61345f1375fe9896c24e194af8b7 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Sun, 31 May 2026 18:41:35 +0200 Subject: [PATCH 14/20] Add ModelCapabilities to the JSON structs --- docs/src/core/reference/json-formats.rst | 70 +++++ docs/src/core/units.rst | 2 +- metatomic-core/include/metatomic.h | 18 +- metatomic-core/src/c_api/model.rs | 20 +- metatomic-core/src/lib.rs | 2 +- metatomic-core/src/metadata.rs | 373 +++++++++++++++++++++-- metatomic-core/src/quantities.rs | 10 +- metatomic-core/src/units.rs | 20 ++ 8 files changed, 486 insertions(+), 29 deletions(-) diff --git a/docs/src/core/reference/json-formats.rst b/docs/src/core/reference/json-formats.rst index 859f7d985..f7e3f468b 100644 --- a/docs/src/core/reference/json-formats.rst +++ b/docs/src/core/reference/json-formats.rst @@ -154,3 +154,73 @@ The JSON representation of a model's metadata. This is used for example by ``extra`` An object with string values, providing any additional key-value pairs the model author wishes to include. This can be used for any purpose. + +.. 
_core-json-model-capabilities: + +Model capabilities +------------------ + +The JSON representation of a model's capabilities, describing which outputs it +provides, which atomic types it supports, and other constraints. This is used +for example by :c:member:`mta_model_t.capabilities`. + +.. code-block:: json + + { + "type": "metatomic_model_capabilities", + "outputs": [ + { + "type": "metatomic_quantity", + "name": "energy", + "unit": "eV", + "sample_kind": "system", + "gradients": ["positions"], + "description": "Potential energy of the system" + }, + { + "type": "metatomic_quantity", + "name": "energy/pbe0", + "unit": "eV", + "sample_kind": "system", + "gradients": ["positions", "strain"], + "description": "Potential energy of the system" + } + ], + "atomic_types": [1, 6, 8], + "interaction_range": 5.0, + "length_unit": "angstrom", + "supported_devices": ["cpu", "cuda"], + "dtype": "float32" + } + +``type`` + Must be the string ``"metatomic_model_capabilities"``. + +``outputs`` + Array of :ref:`quantity objects <core-json-quantity>` describing the + outputs this model can provide. + +``atomic_types`` + Array of integers listing the atomic types this model supports. The meaning + of these integers is up to the model, and is not required to be the atomic + numbers. + +``interaction_range`` + The interaction range of the model in the length unit of the model. This is + the maximum distance between two atoms for which the model's output can + depend on their relative position. Must be a non-negative number. + +``length_unit`` + String identifying the length unit used by the model, e.g. ``"angstrom"`` or + ``"nanometer"``. This must be a valid :ref:`unit expression <units>` with + dimensions compatible with length. + +``supported_devices`` + Array of strings listing the devices on which the model can run. Valid + values are ``"cpu"``, ``"cuda"``, ``"rocm"``, and ``"metal"``. + +``dtype`` + The data type of the model, used for all inputs and outputs.
Must be either + ``"float32"`` or ``"float64"``. The model is free to use different data + types for internal computations, but all inputs and outputs must be in this + data type. diff --git a/docs/src/core/units.rst b/docs/src/core/units.rst index d5f3dd776..c9415ed9f 100644 --- a/docs/src/core/units.rst +++ b/docs/src/core/units.rst @@ -1,4 +1,4 @@ -.. _core-unit-expressions: +.. _units: Units ^^^^^ diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index a4ba5e95e..c8077b5c5 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -116,6 +116,20 @@ typedef struct mta_model_t { * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*unload)(void *model_data); + /** + * Get the capabilities of the model as a JSON string. + * + * @verbatim embed:rst:leading-asterisk + * The expected JSON structure is documented in :ref:`core-json-model-capabilities`. + * @endverbatim + * + * @param model_data the model's `data` pointer + * @param capabilities_json output string, set to a JSON-serialized + * `ModelCapabilities` object. The caller takes ownership and must + * free it with `mta_string_free`. + * @return `MTA_SUCCESS` on success, another status code on error + */ + enum mta_status_t (*capabilities)(const void *model_data, mta_string_t *capabilities_json); /** * Get metadata describing the model (name, authors, references, ...) as a * JSON string. @@ -126,8 +140,8 @@ typedef struct mta_model_t { * * @param model_data the model's `data` pointer * @param metadata_json output string, set to a JSON-serialized - * `ModelMetadata` object. The - * caller takes ownership and must free it with `mta_string_free`. + * `ModelMetadata` object. The caller takes ownership and must + * free it with `mta_string_free`. 
* @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*metadata)(const void *model_data, mta_string_t *metadata_json); diff --git a/metatomic-core/src/c_api/model.rs b/metatomic-core/src/c_api/model.rs index c9fc83610..bf03e2a46 100644 --- a/metatomic-core/src/c_api/model.rs +++ b/metatomic-core/src/c_api/model.rs @@ -33,6 +33,22 @@ pub struct mta_model_t { /// @return `MTA_SUCCESS` on success, another status code on error pub unload: Option mta_status_t>, + /// Get the capabilities of the model as a JSON string. + /// + /// @verbatim embed:rst:leading-asterisk + /// The expected JSON structure is documented in :ref:`core-json-model-capabilities`. + /// @endverbatim + /// + /// @param model_data the model's `data` pointer + /// @param capabilities_json output string, set to a JSON-serialized + /// `ModelCapabilities` object. The caller takes ownership and must + /// free it with `mta_string_free`. + /// @return `MTA_SUCCESS` on success, another status code on error + pub capabilities: Option mta_status_t>, + /// Get metadata describing the model (name, authors, references, ...) as a /// JSON string. /// @@ -42,8 +58,8 @@ pub struct mta_model_t { /// /// @param model_data the model's `data` pointer /// @param metadata_json output string, set to a JSON-serialized - /// `ModelMetadata` object. The - /// caller takes ownership and must free it with `mta_string_free`. + /// `ModelMetadata` object. The caller takes ownership and must + /// free it with `mta_string_free`. 
/// @return `MTA_SUCCESS` on success, another status code on error pub metadata: Option for JsonValue { } } -impl TryFrom for PairListOptions { +impl<'a> TryFrom<&'a JsonValue> for PairListOptions { type Error = Error; - fn try_from(value: JsonValue) -> Result { + fn try_from(value: &'a JsonValue) -> Result { if !value.is_object() { return Err(Error::Serialization( "invalid JSON data for PairListOptions, expected an object".into() @@ -168,19 +169,19 @@ fn read_references(object: &JsonValue, key: &str) -> Result, Error> Ok(references) } -impl TryFrom for References { +impl<'a> TryFrom<&'a JsonValue> for References { type Error = Error; - fn try_from(value: JsonValue) -> Result { + fn try_from(value: &'a JsonValue) -> Result { if !value.is_object() { return Err(Error::Serialization( "invalid JSON data for references in ModelMetadata, expected an object".into() )); } - let model = read_references(&value, "model")?; - let architecture = read_references(&value, "architecture")?; - let implementation = read_references(&value, "implementation")?; + let model = read_references(value, "model")?; + let architecture = read_references(value, "architecture")?; + let implementation = read_references(value, "implementation")?; Ok(References { model, architecture, implementation }) } @@ -217,10 +218,10 @@ impl From for JsonValue { } } -impl TryFrom for ModelMetadata { +impl<'a> TryFrom<&'a JsonValue> for ModelMetadata { type Error = Error; - fn try_from(value: JsonValue) -> Result { + fn try_from(value: &'a JsonValue) -> Result { if !value.is_object() { return Err(Error::Serialization( "invalid JSON data for ModelMetadata, expected an object".into() @@ -253,7 +254,7 @@ impl TryFrom for ModelMetadata { "'description' in JSON for ModelMetadata must be a string".into() ))?.to_string(); - let references = References::try_from(value["references"].clone())?; + let references = References::try_from(&value["references"])?; if !value["extra"].is_object() { return Err(Error::Serialization( 
@@ -279,6 +280,211 @@ impl TryFrom for ModelMetadata { } } +/// The data type of a model, used for all inputs and outputs. The model can +/// still internally use a different data type for its calculations, but it will +/// get inputs in this type and must produce outputs in this type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DType { + /// 32-bit floating point, following the IEEE 754 standard + Float32, + /// 64-bit floating point, following the IEEE 754 standard + Float64, +} + +impl From for JsonValue { + fn from(value: DType) -> Self { + match value { + DType::Float32 => "float32".into(), + DType::Float64 => "float64".into(), + } + } +} + +impl<'a> TryFrom<&'a JsonValue> for DType { + type Error = Error; + + fn try_from(value: &'a JsonValue) -> Result { + if let Some(s) = value.as_str() { + match s { + "float32" => Ok(DType::Float32), + "float64" => Ok(DType::Float64), + _ => Err(Error::Serialization( + "invalid string for dtype in JSON for ModelCapabilities, expected 'float32' or 'float64'".into() + )), + } + } else { + Err(Error::Serialization( + "dtype in JSON for ModelCapabilities must be a string".into() + )) + } + } +} + +/// A device on which a model can run. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Device(dlpk::DLDeviceType); + +impl From for JsonValue { + fn from(value: Device) -> Self { + match value.0 { + dlpk::DLDeviceType::kDLCPU => "cpu".into(), + dlpk::DLDeviceType::kDLCUDA => "cuda".into(), + dlpk::DLDeviceType::kDLROCM => "rocm".into(), + dlpk::DLDeviceType::kDLMetal => "metal".into(), + dlpk::DLDeviceType::kDLCUDAHost | dlpk::DLDeviceType::kDLCUDAManaged => { + // These refer to memory devices more than execution devices + panic!("Do not use kDLCUDAHost or kDLCUDAManaged, use kDLCUDA instead."); + } + dlpk::DLDeviceType::kDLROCMHost => { + // This refers to a memory device more than an execution device + panic!("Do not use kDLROCMHost, use kDLROCM instead."); + } + _ => { + // We don't want to expose other device types until we have a + // use case for them, and we don't want to accidentally leak + // them if they're added in the future + panic!("unsupported device type: {:?}", value.0); + } + } + } +} + +impl<'a> TryFrom<&'a JsonValue> for Device { + type Error = Error; + + fn try_from(value: &'a JsonValue) -> Result { + if let Some(s) = value.as_str() { + match s { + "cpu" => Ok(Device(dlpk::DLDeviceType::kDLCPU)), + "cuda" => Ok(Device(dlpk::DLDeviceType::kDLCUDA)), + "rocm" => Ok(Device(dlpk::DLDeviceType::kDLROCM)), + "metal" => Ok(Device(dlpk::DLDeviceType::kDLMetal)), + _ => Err(Error::Serialization( + "invalid string for device in JSON for ModelCapabilities, expected 'cpu', 'cuda', 'rocm', or 'metal'".into() + )), + } + } else { + Err(Error::Serialization( + "device in JSON for ModelCapabilities must be a string".into() + )) + } + } +} + +/// Capabilities about a model: which outputs it provides, which atoms it +/// supports, etc. +#[derive(Debug, Clone)] +pub struct ModelCapabilities { + /// The outputs this model can provide + pub outputs: Vec, + /// The atomic types this model supports. 
The meaning of the integers in + /// this list is up to the model, and is not required to be the atomic + /// numbers. + pub atomic_types: Vec, + /// The interaction range of the model (in the length unit of the model), + /// i.e. the maximum distance between two atoms for which the model's output + /// can depend on their relative position. + pub interaction_range: f64, + /// The length unit of the model, e.g. "angstrom" or "nanometer". This is + /// used to interpret the `interaction_range` and convert the inputs. + pub length_unit: String, + /// The devices on which the model can run, e.g. `["cpu", "cuda"]`. + pub supported_devices: Vec, + /// The data type of the model, used for all inputs and outputs. + pub dtype: DType, +} + +impl From for JsonValue { + fn from(value: ModelCapabilities) -> Self { + let mut result = JsonValue::new_object(); + result["type"] = "metatomic_model_capabilities".into(); + result["outputs"] = value.outputs.into(); + result["atomic_types"] = value.atomic_types.into(); + result["interaction_range"] = value.interaction_range.into(); + result["length_unit"] = value.length_unit.into(); + result["supported_devices"] = value.supported_devices.into(); + result["dtype"] = value.dtype.into(); + return result; + } +} + +impl<'a> TryFrom<&'a JsonValue> for ModelCapabilities { + type Error = Error; + + fn try_from(value: &'a JsonValue) -> Result { + if !value.is_object() { + return Err(Error::Serialization( + "invalid JSON data for ModelCapabilities, expected an object".into() + )); + } + + if value["type"].as_str() != Some("metatomic_model_capabilities") { + return Err(Error::Serialization( + "'type' in JSON for ModelCapabilities must be 'metatomic_model_capabilities'".into() + )); + } + + let mut outputs = Vec::new(); + if !value["outputs"].is_array() { + return Err(Error::Serialization( + "'outputs' in JSON for ModelCapabilities must be an array".into() + )); + } + for output in value["outputs"].members() { + 
outputs.push(Quantity::try_from(output)?); + } + + + let mut atomic_types = Vec::new(); + if !value["atomic_types"].is_array() { + return Err(Error::Serialization( + "'atomic_types' in JSON for ModelCapabilities must be an array".into() + )); + } + + for atomic_type in value["atomic_types"].members() { + let atomic_type = atomic_type.as_i64().ok_or_else(|| Error::Serialization( + "'atomic_types' in JSON for ModelCapabilities must be an array of integers".into() + ))?; + atomic_types.push(atomic_type); + } + + let interaction_range = value["interaction_range"].as_f64().ok_or_else(|| Error::Serialization( + "'interaction_range' in JSON for ModelCapabilities must be a number".into() + ))?; + if interaction_range < 0.0 { + return Err(Error::Serialization( + "'interaction_range' in JSON for ModelCapabilities must be non-negative".into() + )); + } + + let length_unit = value["length_unit"].as_str().ok_or_else(|| Error::Serialization( + "'length_unit' in JSON for ModelCapabilities must be a string".into() + ))?.to_string(); + validate_unit(&length_unit, "m", Some("'length_unit' in JSON for ModelCapabilities"))?; + + let mut supported_devices = Vec::new(); + if !value["supported_devices"].is_array() { + return Err(Error::Serialization( + "'supported_devices' in JSON for ModelCapabilities must be an array".into() + )); + } + for device in value["supported_devices"].members() { + supported_devices.push(Device::try_from(device)?); + } + + let dtype = DType::try_from(&value["dtype"])?; + + Ok(ModelCapabilities { + outputs, + atomic_types, + interaction_range, + length_unit, + supported_devices, + dtype, + }) + } +} #[cfg(test)] mod tests { @@ -304,7 +510,7 @@ mod tests { assert_eq!(json["full_list"].as_bool(), Some(true)); assert_eq!(json["strict"].as_bool(), Some(false)); - let parsed = PairListOptions::try_from(json).unwrap(); + let parsed = PairListOptions::try_from(&json).unwrap(); assert_eq!(parsed.cutoff.to_bits(), options.cutoff.to_bits()); assert_eq!(parsed.full_list, 
options.full_list); assert_eq!(parsed.strict, options.strict); @@ -315,7 +521,7 @@ mod tests { fn cutoff_keeps_full_precision() { let mut options = example(); options.cutoff = 1.0 / 3.0; - let parsed = PairListOptions::try_from(JsonValue::from(options.clone())).unwrap(); + let parsed = PairListOptions::try_from(&JsonValue::from(options.clone())).unwrap(); assert_eq!(parsed.cutoff.to_bits(), options.cutoff.to_bits()); } @@ -323,7 +529,7 @@ mod tests { fn requestors_are_optional() { let mut json: JsonValue = example().into(); json.remove("requestors"); - let parsed = PairListOptions::try_from(json).unwrap(); + let parsed = PairListOptions::try_from(&json).unwrap(); assert!(parsed.requestors.is_empty()); } @@ -380,7 +586,7 @@ mod tests { ]; for (json, expected) in cases { - let error = PairListOptions::try_from(json).expect_err("expected an error"); + let error = PairListOptions::try_from(&json).expect_err("expected an error"); assert_eq!(error.to_string(), expected); } } @@ -390,7 +596,7 @@ mod tests { let mut json: JsonValue = example().into(); json["requestors"] = json::array![ "a", "", "b", "a" ]; - let parsed = PairListOptions::try_from(json).unwrap(); + let parsed = PairListOptions::try_from(&json).unwrap(); assert_eq!(parsed.requestors, vec!["a".to_string(), "b".to_string()]); } } @@ -431,7 +637,7 @@ mod tests { assert_eq!(json["extra"]["key1"].as_str(), Some("value1")); assert_eq!(json["extra"]["key2"].as_str(), Some("value2")); - let parsed = ModelMetadata::try_from(json).unwrap(); + let parsed = ModelMetadata::try_from(&json).unwrap(); assert_eq!(parsed.name, metadata.name); assert_eq!(parsed.authors, metadata.authors); assert_eq!(parsed.description, metadata.description); @@ -494,7 +700,138 @@ mod tests { ]; for (json, expected) in cases { - let error = ModelMetadata::try_from(json).expect_err("expected an error"); + let error = ModelMetadata::try_from(&json).expect_err("expected an error"); + assert_eq!(error.to_string(), expected); + } + } + } + + mod 
model_capabilities { + use super::super::*; + + fn example() -> ModelCapabilities { + ModelCapabilities { + outputs: vec![ + Quantity { + name: "energy".into(), + unit: "eV".into(), + description: Some("total energy".into()), + gradients: vec![crate::Gradients::Positions], + sample_kind: crate::SampleKind::System, + }, + Quantity { + name: "charge".into(), + unit: "e".into(), + description: None, + gradients: vec![], + sample_kind: crate::SampleKind::Atom, + }, + ], + atomic_types: vec![1, 6, 8], + interaction_range: 5.0, + length_unit: "Angstrom".into(), + supported_devices: vec![Device(dlpk::DLDeviceType::kDLCPU), Device(dlpk::DLDeviceType::kDLCUDA)], + dtype: DType::Float32, + } + } + + #[test] + fn roundtrip() { + let capabilities = example(); + let json: JsonValue = capabilities.clone().into(); + + assert_eq!(json["type"].as_str(), Some("metatomic_model_capabilities")); + assert_eq!(json["outputs"][0]["name"].as_str(), Some("energy")); + assert_eq!(json["outputs"][1]["name"].as_str(), Some("charge")); + assert_eq!(json["atomic_types"][0].as_i64(), Some(1)); + assert_eq!(json["atomic_types"][1].as_i64(), Some(6)); + assert_eq!(json["atomic_types"][2].as_i64(), Some(8)); + assert_eq!(json["interaction_range"].as_f64(), Some(5.0)); + assert_eq!(json["length_unit"].as_str(), Some("Angstrom")); + assert_eq!(json["supported_devices"][0].as_str(), Some("cpu")); + assert_eq!(json["supported_devices"][1].as_str(), Some("cuda")); + assert_eq!(json["dtype"].as_str(), Some("float32")); + + let parsed = ModelCapabilities::try_from(&json).unwrap(); + assert_eq!(parsed.outputs.len(), 2); + assert_eq!(parsed.outputs[0].name, "energy"); + assert_eq!(parsed.outputs[1].name, "charge"); + assert_eq!(parsed.atomic_types, vec![1, 6, 8]); + assert_eq!(parsed.interaction_range.to_bits(), 5.0_f64.to_bits()); + assert_eq!(parsed.length_unit, "Angstrom"); + assert_eq!(parsed.supported_devices.len(), 2); + assert_eq!(parsed.dtype, DType::Float32); + } + + #[test] + fn 
rejects_invalid_json() { + let mut wrong_type = JsonValue::from(example()); + wrong_type["type"] = "something-else".into(); + + let mut non_array_outputs = JsonValue::from(example()); + non_array_outputs["outputs"] = "energy".into(); + + let mut non_array_atomic_types = JsonValue::from(example()); + non_array_atomic_types["atomic_types"] = "1".into(); + + let mut non_integer_atomic_type = JsonValue::from(example()); + non_integer_atomic_type["atomic_types"] = json::array![1, "x"]; + + let mut missing_interaction_range = JsonValue::from(example()); + missing_interaction_range.remove("interaction_range"); + + let mut negative_interaction_range = JsonValue::from(example()); + negative_interaction_range["interaction_range"] = (-1.0).into(); + + let mut missing_length_unit = JsonValue::from(example()); + missing_length_unit.remove("length_unit"); + + let mut wrong_dimension_length_unit = JsonValue::from(example()); + wrong_dimension_length_unit["length_unit"] = "eV".into(); + + let mut non_array_supported_devices = JsonValue::from(example()); + non_array_supported_devices["supported_devices"] = "cpu".into(); + + let mut invalid_device = JsonValue::from(example()); + invalid_device["supported_devices"] = json::array!["cpu", "wat"]; + + let mut missing_dtype = JsonValue::from(example()); + missing_dtype.remove("dtype"); + + let mut invalid_dtype = JsonValue::from(example()); + invalid_dtype["dtype"] = "float16".into(); + + let cases: Vec<(JsonValue, &str)> = vec![ + (JsonValue::from("not an object"), + "serialization error: invalid JSON data for ModelCapabilities, expected an object"), + (wrong_type, + "serialization error: 'type' in JSON for ModelCapabilities must be 'metatomic_model_capabilities'"), + (non_array_outputs, + "serialization error: 'outputs' in JSON for ModelCapabilities must be an array"), + (non_array_atomic_types, + "serialization error: 'atomic_types' in JSON for ModelCapabilities must be an array"), + (non_integer_atomic_type, + "serialization error: 
'atomic_types' in JSON for ModelCapabilities must be an array of integers"), + (missing_interaction_range, + "serialization error: 'interaction_range' in JSON for ModelCapabilities must be a number"), + (negative_interaction_range, + "serialization error: 'interaction_range' in JSON for ModelCapabilities must be non-negative"), + (missing_length_unit, + "serialization error: 'length_unit' in JSON for ModelCapabilities must be a string"), + (wrong_dimension_length_unit, + "invalid parameter: dimension mismatch in 'length_unit' in JSON for ModelCapabilities: 'eV' has dimension [L^2 T^-2 M] but expected dimension [L]"), + (non_array_supported_devices, + "serialization error: 'supported_devices' in JSON for ModelCapabilities must be an array"), + (invalid_device, + "serialization error: invalid string for device in JSON for ModelCapabilities, expected 'cpu', 'cuda', 'rocm', or 'metal'"), + (missing_dtype, + "serialization error: dtype in JSON for ModelCapabilities must be a string"), + (invalid_dtype, + "serialization error: invalid string for dtype in JSON for ModelCapabilities, expected 'float32' or 'float64'"), + ]; + + for (json, expected) in cases { + let error = ModelCapabilities::try_from(&json).expect_err("expected an error"); assert_eq!(error.to_string(), expected); } } diff --git a/metatomic-core/src/quantities.rs b/metatomic-core/src/quantities.rs index 9d1dfebbf..93727c837 100644 --- a/metatomic-core/src/quantities.rs +++ b/metatomic-core/src/quantities.rs @@ -193,10 +193,10 @@ impl From for JsonValue { } -impl TryFrom for Quantity { +impl<'a> TryFrom<&'a JsonValue> for Quantity { type Error = Error; - fn try_from(value: JsonValue) -> Result { + fn try_from(value: &'a JsonValue) -> Result { if !value.is_object() { return Err(Error::Serialization( "invalid JSON data for Quantity, expected an object".into() @@ -272,7 +272,7 @@ mod tests { assert_eq!(json["gradients"][0].as_str(), Some("positions")); assert_eq!(json["sample_kind"].as_str(), Some("atom")); - 
let parsed = Quantity::try_from(json).unwrap(); + let parsed = Quantity::try_from(&json).unwrap(); assert_eq!(parsed.name, "energy"); assert_eq!(parsed.unit, "eV"); assert_eq!(parsed.gradients, vec![Gradients::Positions]); @@ -295,7 +295,7 @@ mod tests { gradients: grads.clone(), sample_kind: sample.clone(), }; - let parsed = Quantity::try_from(JsonValue::from(quantity.clone())).unwrap(); + let parsed = Quantity::try_from(&JsonValue::from(quantity.clone())).unwrap(); assert_eq!(parsed.name, quantity.name); assert_eq!(parsed.unit, quantity.unit); assert_eq!(parsed.gradients, grads); @@ -352,7 +352,7 @@ mod tests { ]; for (json, expected) in cases { - let error = Quantity::try_from(json).expect_err("expected an error"); + let error = Quantity::try_from(&json).expect_err("expected an error"); assert_eq!(error.to_string(), expected); } } diff --git a/metatomic-core/src/units.rs b/metatomic-core/src/units.rs index 4cfff2c1d..4239b86be 100644 --- a/metatomic-core/src/units.rs +++ b/metatomic-core/src/units.rs @@ -574,6 +574,26 @@ pub fn unit_conversion_factor(from_unit: &str, to_unit: &str) -> Result) -> Result<(), Error> { + let unit_value = parse_unit_expression(unit)?; + let reference_value = parse_unit_expression(reference_unit)?; + + if unit_value.dim != reference_value.dim { + return Err(Error::InvalidParameter(format!( + "dimension mismatch{}: '{}' has dimension {} but expected dimension {}", + context.map_or_else(String::new, |c| format!(" in {}", c)), + unit, + unit_value.dim, + reference_value.dim + ))); + } + + Ok(()) +} + + #[cfg(test)] #[allow(clippy::float_cmp)] mod tests { From d35b62b956f41f549aa1965cac248d2159caca65 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 1 Jun 2026 14:49:24 +0200 Subject: [PATCH 15/20] Implement plugin registration and loading, model loading Co-Authored-By: Guillaume Fraux --- metatomic-core/Cargo.toml | 1 + metatomic-core/build.rs | 42 +++- metatomic-core/include/metatomic.h | 133 +++++++++++-- 
metatomic-core/src/c_api/mod.rs | 2 +- metatomic-core/src/c_api/model.rs | 15 ++ metatomic-core/src/c_api/plugin.rs | 151 ++++++++++++-- metatomic-core/src/c_api/status.rs | 17 +- metatomic-core/src/lib.rs | 24 ++- metatomic-core/src/model.rs | 11 + metatomic-core/src/plugin.rs | 188 ++++++++++++++++-- metatomic-core/tests/CMakeLists.txt | 3 + metatomic-core/tests/plugins.cpp | 45 +++++ .../tests/test-plugins/CMakeLists.txt | 14 ++ metatomic-core/tests/test-plugins/bad-abi.c | 17 ++ metatomic-core/tests/test-plugins/plugin.c | 17 ++ scripts/include/stdio.h | 0 16 files changed, 619 insertions(+), 61 deletions(-) create mode 100644 metatomic-core/tests/plugins.cpp create mode 100644 metatomic-core/tests/test-plugins/CMakeLists.txt create mode 100644 metatomic-core/tests/test-plugins/bad-abi.c create mode 100644 metatomic-core/tests/test-plugins/plugin.c create mode 100644 scripts/include/stdio.h diff --git a/metatomic-core/Cargo.toml b/metatomic-core/Cargo.toml index 2335505a3..dace0f8d5 100644 --- a/metatomic-core/Cargo.toml +++ b/metatomic-core/Cargo.toml @@ -18,6 +18,7 @@ metatensor = { version = "0.3.0" } once_cell = "1" dlpk = "0.3" json = "0.12" +libloading = "0.8" [build-dependencies] diff --git a/metatomic-core/build.rs b/metatomic-core/build.rs index 01f95d71c..1a58845d0 100644 --- a/metatomic-core/build.rs +++ b/metatomic-core/build.rs @@ -22,7 +22,8 @@ fn main() { config.documentation_style = cbindgen::DocumentationStyle::Doxy; config.line_endings = cbindgen::LineEndingStyle::LF; config.autogen_warning = Some(generated_comment.into()); - config.includes.push("metatensor.h".into()); + config.sys_includes.push("stdio.h".into()); + config.sys_includes.push("metatensor.h".into()); config.includes.push("metatomic/version.h".into()); config.export = cbindgen::ExportConfig { @@ -33,6 +34,45 @@ fn main() { }; config.after_includes = Some(" +#ifndef MTA_EXPORT + #if defined(_WIN32) || defined(__CYGWIN__) + #define MTA_EXPORT __declspec(dllexport) + #else + #define 
MTA_EXPORT __attribute__((visibility(\"default\"))) + #endif +#endif + +#ifndef MTA_EXTERN_C + #ifdef __cplusplus + #define MTA_EXTERN_C extern \"C\" + #else + #define MTA_EXTERN_C + #endif +#endif + +/** + * Define the exported plugin entry points. + * + * This macro should be used once in each plugin shared library with a + * `mta_plugin_t` expression. It exports the plugin ABI version and a + * registration function used by `mta_load_plugin`. + */ +#define MTA_REGISTER_PLUGIN(register_fn_name, ...) \\ + MTA_EXTERN_C MTA_EXPORT mta_status_t mta_plugin_init(int abi, void *data) { \\ + if (abi != MTA_ABI_VERSION) { \\ + char message[256]; \\ + snprintf(message, sizeof(message), \\ + \"Metatomic plugin ABI version mismatch: expected %d, got %d\", \\ + MTA_ABI_VERSION, abi \\ + ); \\ + mta_set_last_error(message, \"MTA_REGISTER_PLUGIN\", NULL, NULL); \\ + return MTA_INVALID_PARAMETER_ERROR; \\ + } \\ + mta_status_t (*register_fn_name)(mta_plugin_t) = (mta_status_t (*)(mta_plugin_t))data; \\ + __VA_ARGS__; \\ + return MTA_SUCCESS; \\ + } + /** Heap allocated storage for mta_string_t */ typedef struct mta_opaque_string_t mta_opaque_string_t;".into()); diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index c8077b5c5..a2a549e78 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -12,15 +12,59 @@ #include #include #include -#include "metatensor.h" +#include +#include #include "metatomic/version.h" +#ifndef MTA_EXPORT + #if defined(_WIN32) || defined(__CYGWIN__) + #define MTA_EXPORT __declspec(dllexport) + #else + #define MTA_EXPORT __attribute__((visibility("default"))) + #endif +#endif + +#ifndef MTA_EXTERN_C + #ifdef __cplusplus + #define MTA_EXTERN_C extern "C" + #else + #define MTA_EXTERN_C + #endif +#endif + +/** + * Define the exported plugin entry points. + * + * This macro should be used once in each plugin shared library with a + * `mta_plugin_t` expression. 
It exports the plugin ABI version and a + * registration function used by `mta_load_plugin`. + */ +#define MTA_REGISTER_PLUGIN(register_fn_name, ...) \ + MTA_EXTERN_C MTA_EXPORT mta_status_t mta_plugin_init(int abi, void *data) { \ + if (abi != MTA_ABI_VERSION) { \ + char message[256]; \ + snprintf(message, sizeof(message), \ + "Metatomic plugin ABI version mismatch: expected %d, got %d", \ + MTA_ABI_VERSION, abi \ + ); \ + mta_set_last_error(message, "MTA_REGISTER_PLUGIN", NULL, NULL); \ + return MTA_INVALID_PARAMETER_ERROR; \ + } \ + mta_status_t (*register_fn_name)(mta_plugin_t) = (mta_status_t (*)(mta_plugin_t))data; \ + __VA_ARGS__; \ + return MTA_SUCCESS; \ + } + /** Heap allocated storage for mta_string_t */ typedef struct mta_opaque_string_t mta_opaque_string_t; /** - * TODO + * ABI version of the metatomic plugin interface. + * + * This increases anytime the plugin or model C API changes in a non backward + * compatible way. Plugins compiled with an incompatible ABI version will be + * rejected at registration time. */ #define MTA_ABI_VERSION 1 @@ -47,11 +91,10 @@ typedef enum mta_status_t { */ MTA_SERIALIZATION_ERROR = 3, /** - * Status code indicating errors that come from callbacks provided by the user. - * The error message and arbitrary data can be stored using `mta_set_last_error`, - * and retrieved using `mta_last_error`. + * Status code used by plugins when a model is not supported by the + * current plugin */ - MTA_CALLBACK_ERROR = 254, + MTA_MODEL_NOT_SUPPORTED_ERROR = 4, /** * Status code used when there is an internal error */ @@ -234,15 +277,41 @@ typedef struct mta_model_t { } mta_model_t; /** - * TODO + * A metatomic plugin definition. */ typedef struct mta_plugin_t { /** - * TODO + * ABI version this plugin was compiled against, this should be set to + * `MTA_ABI_VERSION` when creating the plugin struct. + */ + int32_t abi_version; + /** + * Name of the plugin, as a null-terminated UTF-8 string. 
This is the name + * specified in `mta_load_model` when trying to load a model with a + * specific plugin. The name must be unique among all registered plugins. */ const char *name; /** - * TODO + * Callback function to load a model. This function should try to load a + * model from `load_from` (which can be a file path, a model name, etc.) + * and a set of key/values options passed as a JSON string. + * + * If the plugin can load the model, it should fill `model` with a pointer + * to a valid `mta_model_t` struct and return `MTA_SUCCESS`. If the data in + * `load_from` does not correspond to a model supported by the plugin, it + * should return `MTA_MODEL_NOT_SUPPORTED_ERROR`. If an error occurs while + * loading the model, it should return another status code and save an + * error message with `mta_set_last_error`. + * + * @param load_from a null-terminated UTF-8 string describing where to load + * the model from (e.g. a file path, a model name, etc.). The interpretation + * of this string is up to the plugin. + * @param options_json a null-terminated UTF-8 string containing a set of + * string keys and string value options for loading the model. + * @param model output pointer to the loaded model. The caller takes ownership of + * the model and must unload it when the model is no longer needed. + * @return `MTA_SUCCESS` if the model was loaded successfully, `MTA_MODEL_NOT_SUPPORTED_ERROR` + * if the plugin can not load the model, or another status code if an error occurs. */ enum mta_status_t (*load_model)(const char *load_from, const char *options_json, @@ -445,17 +514,55 @@ enum mta_status_t mta_execute_model(struct mta_model_t model, enum mta_status_t mta_format_metadata(const char *metadata, mta_string_t *printed); /** - * TODO + * Register a plugin. This is passed as a callback to the `MTA_REGISTER_PLUGIN` + * macro, and should not be called directly by C or C++ plugin implementations. 
+ * + * @param plugin the plugin to register + * @return `MTA_SUCCESS` if the plugin was registered successfully, or another + * status code if an error occurs. You can get more details about the error + * with `mta_last_error`. */ -void mta_register_plugin(struct mta_plugin_t plugin); +enum mta_status_t mta_register_plugin(struct mta_plugin_t plugin); /** - * TODO + * Load the shared library at `path` and register the plugin contained within. + * + * The library must export the symbols generated by the `MTA_REGISTER_PLUGIN` + * macro. + * + * @param path a null-terminated UTF-8 string containing the path to the plugin + * shared library + * @return `MTA_SUCCESS` if the plugin was loaded successfully, or another + * status code if an error occurs. You can get more details about the + * error with `mta_last_error`. */ enum mta_status_t mta_load_plugin(const char *path); /** - * TODO + * Load a model from `load_from` with the given options. + * + * If `plugin_name` is a NULL pointer, metatomic will try to determine the + * correct plugin to use by checking the `load_from` parameter. If we can not + * determine the correct plugin, we then try to load the model with each + * registered plugin until one succeeds. + * + * If `plugin_name` is given, then we only try to load the model with the + * specified plugin, and return an error if the plugin can not load the model. + * + * @param plugin_name optional null-terminated UTF-8 string containing the name + * of the plugin to use for loading the model, or `NULL` to let metatomic + * search for a correct plugin + * @param load_from a null-terminated UTF-8 string describing where to load the + * model from (e.g. a file path, a model name, etc.). The interpretation + * of this string is up to the plugin. + * @param options_json a null-terminated UTF-8 string containing a set of string + * keys and string value options for loading the model. The interpretation + * of these options is up to the plugin. 
+ * @param model output pointer to the loaded model. The caller takes ownership of + * the model and must unload it when the model is no longer needed. + * @return `MTA_SUCCESS` if the model was loaded successfully, or another + * status code if an error occurs. You can get more details about the + * error with `mta_last_error`. */ enum mta_status_t mta_load_model(const char *plugin_name, const char *load_from, diff --git a/metatomic-core/src/c_api/mod.rs b/metatomic-core/src/c_api/mod.rs index bffa5003a..235b00296 100644 --- a/metatomic-core/src/c_api/mod.rs +++ b/metatomic-core/src/c_api/mod.rs @@ -15,4 +15,4 @@ mod model; pub use self::model::mta_model_t; mod plugin; -pub use self::plugin::{mta_plugin_t, mta_register_plugin, mta_load_model}; +pub use self::plugin::{mta_plugin_t, mta_register_plugin, mta_load_plugin, mta_load_model}; diff --git a/metatomic-core/src/c_api/model.rs b/metatomic-core/src/c_api/model.rs index bf03e2a46..d2050643b 100644 --- a/metatomic-core/src/c_api/model.rs +++ b/metatomic-core/src/c_api/model.rs @@ -160,6 +160,21 @@ pub struct mta_model_t { ) -> mta_status_t>, } +impl mta_model_t { + pub(crate) fn null() -> Self { + return mta_model_t { + data: std::ptr::null_mut(), + unload: None, + capabilities: None, + metadata: None, + supported_outputs: None, + requested_pair_lists: None, + requested_inputs: None, + execute_inner: None, + }; + } +} + /// Execute a model to compute the requested outputs for a set of systems /// /// This is the main entry point to run a model loaded through the C API. It diff --git a/metatomic-core/src/c_api/plugin.rs b/metatomic-core/src/c_api/plugin.rs index 6dbfc4add..76b9989d1 100644 --- a/metatomic-core/src/c_api/plugin.rs +++ b/metatomic-core/src/c_api/plugin.rs @@ -1,15 +1,43 @@ -use std::ffi::c_char; +use std::ffi::{CStr, c_char}; +use super::catch_unwind; use super::{mta_model_t, mta_status_t}; +use crate::Error; +use crate::Plugin; -/// TODO +/// A metatomic plugin definition. 
#[allow(non_camel_case_types)] #[repr(C)] pub struct mta_plugin_t { - /// TODO + /// ABI version this plugin was compiled against, this should be set to + /// `MTA_ABI_VERSION` when creating the plugin struct. + pub abi_version: i32, + + /// Name of the plugin, as a null-terminated UTF-8 string. This is the name + /// specified in `mta_load_model` when trying to load a model with a + /// specific plugin. The name must be unique among all registered plugins. pub name: *const c_char, - /// TODO + /// Callback function to load a model. This function should try to load a + /// model from `load_from` (which can be a file path, a model name, etc.) + /// and a set of key/values options passed as a JSON string. + /// + /// If the plugin can load the model, it should fill `model` with a pointer + /// to a valid `mta_model_t` struct and return `MTA_SUCCESS`. If the data in + /// `load_from` does not correspond to a model supported by the plugin, it + /// should return `MTA_MODEL_NOT_SUPPORTED_ERROR`. If an error occurs while + /// loading the model, it should return another status code and save an + /// error message with `mta_set_last_error`. + /// + /// @param load_from a null-terminated UTF-8 string describing where to load + /// the model from (e.g. a file path, a model name, etc.). The interpretation + /// of this string is up to the plugin. + /// @param options_json a null-terminated UTF-8 string containing a set of + /// string keys and string value options for loading the model. + /// @param model output pointer to the loaded model. The caller takes ownership of + /// the model and must unload it when the model is no longer needed. + /// @return `MTA_SUCCESS` if the model was loaded successfully, `MTA_MODEL_NOT_SUPPORTED_ERROR` + /// if the plugin can not load the model, or another status code if an error occurs. pub load_model: Option mta_status_t>, } -/// TODO +unsafe impl Send for mta_plugin_t {} + +/// Register a plugin. 
This is passed as a callback to the `MTA_REGISTER_PLUGIN` +/// macro, and should not be called directly by C or C++ plugin implementations. +/// +/// @param plugin the plugin to register +/// @return `MTA_SUCCESS` if the plugin was registered successfully, or another +/// status code if an error occurs. You can get more details about the error +/// with `mta_last_error`. #[no_mangle] -pub extern "C" fn mta_register_plugin(plugin: mta_plugin_t) { - todo!() +pub unsafe extern "C" fn mta_register_plugin(plugin: mta_plugin_t) -> mta_status_t { + catch_unwind(move || { + let plugin = Plugin::new(plugin)?; + crate::plugin::register_plugin(plugin)?; + Ok(()) + }) } -/// TODO +/// Load the shared library at `path` and register the plugin contained within. +/// +/// The library must export the symbols generated by the `MTA_REGISTER_PLUGIN` +/// macro. +/// +/// @param path a null-terminated UTF-8 string containing the path to the plugin +/// shared library +/// @return `MTA_SUCCESS` if the plugin was loaded successfully, or another +/// status code if an error occurs. You can get more details about the +/// error with `mta_last_error`. #[no_mangle] -pub extern "C" fn mta_load_plugin(path: *const c_char) -> mta_status_t { - todo!() +pub unsafe extern "C" fn mta_load_plugin(path: *const c_char) -> mta_status_t { + catch_unwind(move || { + check_pointers_non_null!(path); + + let path = CStr::from_ptr(path).to_str().map_err(|_| { + Error::InvalidParameter("invalid UTF-8 in plugin path".into()) + })?; + + crate::plugin::load_plugin(path) + }) } -/// TODO +/// Load a model from `load_from` with the given options. +/// +/// If `plugin_name` is a NULL pointer, metatomic will try to determine the +/// correct plugin to use by checking the `load_from` parameter. If we can not +/// determine the correct plugin, we then try to load the model with each +/// registered plugin until one succeeds. 
+/// +/// If `plugin_name` is given, then we only try to load the model with the +/// specified plugin, and return an error if the plugin can not load the model. +/// +/// @param plugin_name optional null-terminated UTF-8 string containing the name +/// of the plugin to use for loading the model, or `NULL` to let metatomic +/// search for a correct plugin +/// @param load_from a null-terminated UTF-8 string describing where to load the +/// model from (e.g. a file path, a model name, etc.). The interpretation +/// of this string is up to the plugin. +/// @param options_json a null-terminated UTF-8 string containing a set of string +/// keys and string value options for loading the model. The interpretation +/// of these options is up to the plugin. +/// @param model output pointer to the loaded model. The caller takes ownership of +/// the model and must unload it when the model is no longer needed. +/// @return `MTA_SUCCESS` if the model was loaded successfully, or another +/// status code if an error occurs. You can get more details about the +/// error with `mta_last_error`. #[no_mangle] -pub extern "C" fn mta_load_model( +pub unsafe extern "C" fn mta_load_model( plugin_name: *const c_char, load_from: *const c_char, options_json: *const c_char, model: *mut mta_model_t, ) -> mta_status_t { - todo!() + let unwind_wrapper = std::panic::AssertUnwindSafe(model); + + catch_unwind(move || { + check_pointers_non_null!(load_from, model); + + let plugin_name = if plugin_name.is_null() { + None + } else { + Some(CStr::from_ptr(plugin_name).to_str().map_err(|_| { + Error::InvalidParameter("invalid UTF-8 in plugin name".into()) + })?) 
+ }; + + let options_json = if options_json.is_null() { + CStr::from_bytes_with_nul(b"{}\0").expect("invalid CStr") + } else { + CStr::from_ptr(options_json) + }; + + let options_str = options_json.to_str().map_err(|_| { + Error::InvalidParameter("invalid UTF-8 in options JSON".into()) + })?; + + let options = json::parse(options_str).map_err( + |e| Error::Serialization(format!("JSON parsing error: {}", e)) + )?; + if !options.is_object() { + return Err(Error::Serialization("JSON options must be an object in `mta_load_model`".into())) + } + + // just some validation, we pass the raw JSON down to the plugins + for (key, value) in options.entries() { + if !value.is_string() { + return Err(Error::InvalidParameter(format!( + "JSON option '{}' has a non-string value in `mta_load_model`", + key + ))); + } + } + + let loaded = crate::plugin::load_model(plugin_name, CStr::from_ptr(load_from), options_json)?; + + let _ = &unwind_wrapper; + *unwind_wrapper.0 = loaded.into_raw(); + Ok(()) + }) } diff --git a/metatomic-core/src/c_api/status.rs b/metatomic-core/src/c_api/status.rs index 8aef16c11..b9c9a16aa 100644 --- a/metatomic-core/src/c_api/status.rs +++ b/metatomic-core/src/c_api/status.rs @@ -27,7 +27,7 @@ thread_local! { /// The value 0 (`MTA_SUCCESS`) indicates success, while any non-zero value indicates an error. #[allow(non_camel_case_types)] #[repr(C)] -#[derive(PartialEq, Eq, Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum mta_status_t { /// Status code indicating success MTA_SUCCESS = 0, @@ -37,10 +37,9 @@ pub enum mta_status_t { MTA_IO_ERROR = 2, /// Status code indicating serialization/deserialization errors MTA_SERIALIZATION_ERROR = 3, - /// Status code indicating errors that come from callbacks provided by the user. - /// The error message and arbitrary data can be stored using `mta_set_last_error`, - /// and retrieved using `mta_last_error`. 
- MTA_CALLBACK_ERROR = 254, + /// Status code used by plugins when a model is not supported by the + /// current plugin + MTA_MODEL_NOT_SUPPORTED_ERROR = 4, /// Status code used when there is an internal error MTA_INTERNAL_ERROR = 255, } @@ -79,9 +78,9 @@ macro_rules! check_pointers_non_null { impl From for mta_status_t { fn from(error: Error) -> mta_status_t { - if let Error::CallbackError = error { + if let Error::CallbackError(status) = error { // If the error is already a CallbackError, we can directly return the corresponding status code. - return mta_status_t::MTA_CALLBACK_ERROR; + return status; } LAST_ERROR.with(|last_error| { @@ -98,7 +97,7 @@ impl From for mta_status_t { *last_error = LastError { message: CString::new(format!("{}", error)) .expect("error message contains a null byte"), - origin: CString::new("metatensor-core").expect("invalid C string"), + origin: CString::new("metatomic-core").expect("invalid C string"), custom_data: std::ptr::null_mut(), custom_data_deleter: None, }; @@ -108,7 +107,7 @@ impl From for mta_status_t { Error::InvalidParameter(_) => mta_status_t::MTA_INVALID_PARAMETER_ERROR, Error::Io(_) => mta_status_t::MTA_IO_ERROR, Error::Serialization(_) => mta_status_t::MTA_SERIALIZATION_ERROR, - Error::CallbackError => unreachable!(), + Error::CallbackError(_) => unreachable!("already handled above"), Error::Internal(_) => mta_status_t::MTA_INTERNAL_ERROR, } } diff --git a/metatomic-core/src/lib.rs b/metatomic-core/src/lib.rs index b57ab76dd..3450f05df 100644 --- a/metatomic-core/src/lib.rs +++ b/metatomic-core/src/lib.rs @@ -6,16 +6,20 @@ #![allow(clippy::unreadable_literal, clippy::option_if_let_else, clippy::module_name_repetitions)] #![allow(clippy::missing_errors_doc, clippy::missing_panics_doc, clippy::missing_safety_doc)] #![allow(clippy::similar_names, clippy::borrow_as_ptr, clippy::uninlined_format_args)] -#![allow(clippy::doc_markdown)] +#![allow(clippy::doc_markdown, clippy::needless_continue)] 
#![allow(clippy::let_underscore_untyped, clippy::manual_let_else, clippy::empty_line_after_doc_comments)] // To be removed later #![allow(unused_variables, dead_code, clippy::needless_pass_by_value)] +use std::sync::Arc; + #[doc(hidden)] pub mod c_api; mod metadata; +use crate::c_api::mta_status_t; + pub use self::metadata::{ModelMetadata, PairListOptions}; mod quantities; @@ -28,22 +32,22 @@ mod model; pub use self::model::Model; mod plugin; -pub use self::plugin::{Plugin, load_plugin, load_model}; +pub use self::plugin::Plugin; mod units; pub use self::units::unit_conversion_factor; /// The possible sources of error in metatomic -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum Error { /// Error while serializing data to or deserializing data from JSON Serialization(String), /// Invalid parameters passed to a function InvalidParameter(String), /// I/O error - Io(std::io::Error), + Io(Arc), /// Error coming from an external function used as a callback - CallbackError, + CallbackError(mta_status_t), /// Any other internal error, usually these are internal bugs. 
Internal(String), } @@ -54,7 +58,7 @@ impl std::fmt::Display for Error { Error::Serialization(e) => write!(f, "serialization error: {}", e), Error::InvalidParameter(e) => write!(f, "invalid parameter: {}", e), Error::Io(e) => write!(f, "io error: {}", e), - Error::CallbackError => write!(f, "callback error"), + Error::CallbackError(e) => write!(f, "callback error, status code: {:?}", e), Error::Internal(e) => write!(f, "internal metatomic error (this is likely a bug, please report it): {}", e ), @@ -68,7 +72,7 @@ impl std::error::Error for Error { Error::InvalidParameter(_) | Error::Serialization(_) | Error::Internal(_) - | Error::CallbackError => None, + | Error::CallbackError(_) => None, Error::Io(e) => Some(e), } } @@ -92,3 +96,9 @@ impl From> for Error { } } } + +impl From for Error { + fn from(error: std::io::Error) -> Self { + Error::Io(Arc::new(error)) + } +} diff --git a/metatomic-core/src/model.rs b/metatomic-core/src/model.rs index 16b208ac1..1b93577d2 100644 --- a/metatomic-core/src/model.rs +++ b/metatomic-core/src/model.rs @@ -7,6 +7,17 @@ use crate::c_api::mta_model_t; /// TODO pub struct Model(pub(crate) mta_model_t); +impl Model { + /// Create a new `Model` from the corresponding C API struct. + pub fn new(model: mta_model_t) -> Self { + return Model(model); + } + + /// Extract the underlying C API struct. + pub fn into_raw(self) -> mta_model_t { + return self.0; + } +} /// TODO pub fn execute_model( diff --git a/metatomic-core/src/plugin.rs b/metatomic-core/src/plugin.rs index 60d145803..4a9da5ab9 100644 --- a/metatomic-core/src/plugin.rs +++ b/metatomic-core/src/plugin.rs @@ -1,37 +1,191 @@ -use std::collections::BTreeMap; +use std::ffi::CStr; +use std::sync::Mutex; -use crate::c_api::mta_plugin_t; +use libloading::Library; +use once_cell::sync::Lazy; + +use crate::c_api::{mta_model_t, mta_plugin_t, mta_register_plugin, mta_status_t}; use crate::{Error, Model}; -/// TODO +/// ABI version of the metatomic plugin interface. 
+/// +/// This increases anytime the plugin or model C API changes in a non backward +/// compatible way. Plugins compiled with an incompatible ABI version will be +/// rejected at registration time. pub const MTA_ABI_VERSION: i32 = 1; -/// TODO +/// The list of registered plugins in the current process. +static PLUGINS: Lazy>> = Lazy::new(|| Mutex::new(Vec::new())); +/// Keep the loaded libraries alive for the entire process lifetime, to ensure +/// that the plugin code is not unloaded while it's still in use. +static LIBRARIES: Lazy>> = Lazy::new(|| Mutex::new(Vec::new())); + pub struct Plugin(mta_plugin_t); impl Plugin { - /// TODO - pub fn new(c_plugin: mta_plugin_t) -> Self { - Self(c_plugin) + /// Create a new plugin from the C struct + pub fn new(plugin: mta_plugin_t) -> Result { + if plugin.name.is_null() { + return Err(Error::InvalidParameter( + "can not register plugin: plugin `name` is NULL".into(), + )); + } + + let c_str_name = unsafe { CStr::from_ptr(plugin.name) }; + if c_str_name.to_str().is_err() { + return Err(Error::InvalidParameter(format!( + "can not register plugin: plugin `name` is not valid UTF-8: {}", + c_str_name.to_string_lossy() + ))); + } + + if plugin.load_model.is_none() { + return Err(Error::InvalidParameter( + "can not register plugin: plugin `load_model` callback is NULL".into(), + )); + } + + if plugin.abi_version != MTA_ABI_VERSION { + let name = unsafe { + CStr::from_ptr(plugin.name).to_string_lossy() + }; + + return Err(Error::InvalidParameter(format!( + "can not register plugin '{}': plugin ABI version is {}, but metatomic expects {}", + name, + plugin.abi_version, + MTA_ABI_VERSION, + ))); + } + + Ok(Plugin(plugin)) } - /// TODO + /// Get the plugin name. 
pub fn name(&self) -> &str { - todo!() + unsafe { + return CStr::from_ptr(self.0.name) + .to_str() + .expect("invalid UTF-8 in plugin name"); + } } - /// TODO - pub fn load_model(&self, load_from: &str, options: BTreeMap) -> Result { - todo!() + /// Try to load a model with this plugin. + pub fn load_model( + &self, + load_from: &CStr, + options_json: &CStr, + ) -> Result { + let load_model = self.0.load_model.expect("`load_model` is NULL"); + + let mut model = mta_model_t::null(); + let status = unsafe { + load_model(load_from.as_ptr(), options_json.as_ptr(), &mut model) + }; + + if status != mta_status_t::MTA_SUCCESS { + return Err(Error::CallbackError(status)); + } + + return Ok(Model::new(model)); } } -/// TODO +/// Register a new plugin in the current process. +pub fn register_plugin(plugin: Plugin) -> Result<(), Error> { + let mut plugins = PLUGINS.lock().expect("plugin registry mutex was poisoned"); + if plugins.iter().any(|existing| existing.name() == plugin.name()) { + return Err(Error::InvalidParameter(format!( + "a plugin named '{}' is already registered", + plugin.name() + ))); + } + + plugins.push(plugin); + return Ok(()); +} + +/// Load a plugin from a shared library. +/// +/// The shared library must export the symbols generated by the +/// `MTA_REGISTER_PLUGIN` C macro. 
pub fn load_plugin(path: &str) -> Result<(), Error> { - todo!() + // this needs to be kept in sync with the definition in `MTA_REGISTER_PLUGIN` in build.rs + type PluginInitFn = unsafe extern "C" fn(abi: i32, data: *mut std::ffi::c_void) -> mta_status_t; + + let library = unsafe { Library::new(path) }; + + let library = library.map_err(|error| { + std::io::Error::other( + format!("failed to load plugin '{}': {}", path, error), + ) + })?; + + let status = unsafe { + let init_plugin = library.get::(b"mta_plugin_init\0") + .map_err(|error| Error::InvalidParameter(format!( + "failed to load plugin registration symbol from '{}': {}", + path, error + )))?; + init_plugin(MTA_ABI_VERSION, mta_register_plugin as *mut std::ffi::c_void) + }; + + if status != mta_status_t::MTA_SUCCESS { + return Err(Error::CallbackError(status)); + } + + LIBRARIES.lock().expect("loaded plugin registry mutex was poisoned").push(library); + + return Ok(()); } -/// TODO -pub fn load_model(plugin: Option<&str>, load_from: &str, options: BTreeMap) -> Result { - todo!() +/// Load a model from `load_from`, using the given options. 
+pub fn load_model( + plugin_name: Option<&str>, + load_from: &CStr, + options_json: &CStr, +) -> Result { + let plugins = PLUGINS.lock().expect("plugin registry mutex was poisoned"); + + if let Some(plugin_name) = plugin_name { + for plugin in plugins.iter() { + if plugin.name() == plugin_name { + return plugin.load_model(load_from, options_json); + } + } + + return Err(Error::InvalidParameter(format!( + "no plugin named '{}' is registered", + plugin_name + ))); + } + + for plugin in plugins.iter() { + match plugin.load_model(load_from, options_json) { + Ok(model) => return Ok(model), + Err(e) => { + if let Error::CallbackError(mta_status_t::MTA_MODEL_NOT_SUPPORTED_ERROR) = e { + // try the next plugin + continue; + } else { + return Err(e); + } + } + } + } + + let message = if plugins.is_empty() { + "no plugin is registered".into() + } else { + format!( + "tried the following plugins, but none could load the model: {}", + plugins.iter().map(|p| p.name()).collect::>().join(", ") + ) + }; + + return Err(Error::InvalidParameter(format!( + "failed to load model from '{}': {}", + load_from.to_string_lossy(), + message + ))); } diff --git a/metatomic-core/tests/CMakeLists.txt b/metatomic-core/tests/CMakeLists.txt index 1a96108d5..77e251ddc 100644 --- a/metatomic-core/tests/CMakeLists.txt +++ b/metatomic-core/tests/CMakeLists.txt @@ -53,6 +53,7 @@ endif() enable_testing() +add_subdirectory(test-plugins) file(GLOB ALL_TESTS *.cpp) foreach(_file_ ${ALL_TESTS}) @@ -70,6 +71,8 @@ foreach(_file_ ${ALL_TESTS}) NO_SYSTEM_FROM_IMPORTED ON ) + target_compile_definitions(${_name_} PRIVATE PLUGIN_DIR="${CMAKE_CURRENT_BINARY_DIR}/test-plugins") + add_test( NAME ${_name_} COMMAND ${TEST_COMMAND} $ diff --git a/metatomic-core/tests/plugins.cpp b/metatomic-core/tests/plugins.cpp new file mode 100644 index 000000000..79b53947e --- /dev/null +++ b/metatomic-core/tests/plugins.cpp @@ -0,0 +1,45 @@ +#include + +#include "metatomic.h" + + +TEST_CASE("Load plugins") { + auto status = 
mta_load_plugin(PLUGIN_DIR "/test-c-plugin.so"); + CHECK(status == MTA_SUCCESS); + + // try to load the model with an explicit plugin name + struct mta_model_t model; + status = mta_load_model("test-c-plugin", "some_model", "{}", &model); + CHECK(status == MTA_MODEL_NOT_SUPPORTED_ERROR); + + // load the plugin without specifying the plugin name + status = mta_load_model(nullptr, "some_model", "{}", &model); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + const char* error_message; + const char* error_origin; + + status = mta_last_error(&error_message, &error_origin, nullptr); + REQUIRE(status == MTA_SUCCESS); + + CHECK(std::string(error_origin) == "metatomic-core"); + const char* expected_message = ( + "invalid parameter: failed to load model from 'some_model': tried the " + "following plugins, but none could load the model: test-c-plugin" + ); + CHECK(std::string(error_message) == expected_message); + + + status = mta_load_plugin(PLUGIN_DIR "/bad-abi-plugin.so"); + CHECK(status == MTA_INVALID_PARAMETER_ERROR); + + status = mta_last_error(&error_message, &error_origin, nullptr); + REQUIRE(status == MTA_SUCCESS); + + CHECK(std::string(error_origin) == "metatomic-core"); + expected_message = ( + "invalid parameter: can not register plugin 'bad-abi-plugin': " + "plugin ABI version is 2, but metatomic expects 1" + ); + CHECK(std::string(error_message) == expected_message); +} diff --git a/metatomic-core/tests/test-plugins/CMakeLists.txt b/metatomic-core/tests/test-plugins/CMakeLists.txt new file mode 100644 index 000000000..2693ceab7 --- /dev/null +++ b/metatomic-core/tests/test-plugins/CMakeLists.txt @@ -0,0 +1,14 @@ +add_library(test-c-plugin SHARED plugin.c) +target_link_libraries(test-c-plugin metatomic) +# create test plugins with a consistent name across platforms +set_target_properties(test-c-plugin PROPERTIES + PREFIX "" + SUFFIX ".so" +) + +add_library(bad-abi-plugin SHARED bad-abi.c) +target_link_libraries(bad-abi-plugin metatomic) 
+set_target_properties(bad-abi-plugin PROPERTIES + PREFIX "" + SUFFIX ".so" +) diff --git a/metatomic-core/tests/test-plugins/bad-abi.c b/metatomic-core/tests/test-plugins/bad-abi.c new file mode 100644 index 000000000..35e86bcfe --- /dev/null +++ b/metatomic-core/tests/test-plugins/bad-abi.c @@ -0,0 +1,17 @@ +#include + + +static mta_status_t load_model(const char *load_from, const char *options_json, struct mta_model_t *model) { + // This plugin can not load any model + return MTA_MODEL_NOT_SUPPORTED_ERROR; +} + + +MTA_REGISTER_PLUGIN(register_plugin, { + mta_plugin_t plugin = { + .abi_version = MTA_ABI_VERSION + 1, // incompatible ABI version + .name = "bad-abi-plugin", + .load_model = load_model, + }; + return register_plugin(plugin); +}); diff --git a/metatomic-core/tests/test-plugins/plugin.c b/metatomic-core/tests/test-plugins/plugin.c new file mode 100644 index 000000000..1602dfae7 --- /dev/null +++ b/metatomic-core/tests/test-plugins/plugin.c @@ -0,0 +1,17 @@ +#include + + +static mta_status_t load_model(const char *load_from, const char *options_json, struct mta_model_t *model) { + // This plugin can not load any model + return MTA_MODEL_NOT_SUPPORTED_ERROR; +} + + +MTA_REGISTER_PLUGIN(register_plugin, { + mta_plugin_t plugin = { + .abi_version = MTA_ABI_VERSION, + .name = "test-c-plugin", + .load_model = load_model, + }; + return register_plugin(plugin); +}); diff --git a/scripts/include/stdio.h b/scripts/include/stdio.h new file mode 100644 index 000000000..e69de29bb From 272d06cb6a797be9aa22da534725fc6ff1a1985a Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 29 May 2026 08:32:58 +0200 Subject: [PATCH 16/20] Add a C model registration test --- metatomic-core/tests/c-model.cpp | 197 +++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 metatomic-core/tests/c-model.cpp diff --git a/metatomic-core/tests/c-model.cpp b/metatomic-core/tests/c-model.cpp new file mode 100644 index 000000000..ec1da92f7 --- /dev/null +++ 
b/metatomic-core/tests/c-model.cpp @@ -0,0 +1,197 @@ +#include + +#include + +#include "metatomic.h" + +#include + + +struct SimpleModelData { + double scale; +}; + +static mta_status_t unload_impl(void* model_data) { + delete static_cast(model_data); + return MTA_SUCCESS; +} + +static mta_status_t metadata_impl(const void* model_data, mta_string_t* metadata_json) { + (void) model_data; + + *metadata_json = mta_string_create(R"({ + "name": "test C model", + "description": "small model used as a C API example", + "authors": [], + "references": { + "model": [], + "implementation": [], + "architecture": [] + } + })"); + return MTA_SUCCESS; +} + +static mta_status_t capabilities_impl(const void* model_data, mta_string_t* capabilities_json) { + (void) model_data; + + *capabilities_json = mta_string_create(R"({ + "outputs": [{ + "quantity": "energy", + "unit": "eV", + "per_atom": false + }], + "atomic_types": [1, 6, 8], + "interaction_range": 4.5, + "length_unit": "nm", + "supported_devices": ["cpu"], + "dtype": "float32" + })"); + return MTA_SUCCESS; +} + +static mta_status_t supported_outputs_impl( + const void* model_data, + mta_string_t* outputs_json +) { + (void) model_data; + *outputs_json = mta_string_create(R"([{ + "quantity": "energy", + "unit": "eV", + "per_atom": false + }])"); + return MTA_SUCCESS; +} + +static mta_status_t requested_pair_lists_impl( + const void* model_data, + mta_string_t* pair_options_json +) { + (void) model_data; + *pair_options_json = mta_string_create("[]"); + return MTA_SUCCESS; +} + +static mta_status_t requested_inputs_impl( + const void* model_data, + mta_string_t* requested_inputs_json +) { + (void) model_data; + *requested_inputs_json = mta_string_create("[]"); + return MTA_SUCCESS; +} + + +mts_tensormap_t* scalar_tensormap(double value) { + auto values = std::make_unique>( + std::vector{1, 1}, + std::vector{value} + ); + + auto array = metatensor::DataArrayBase::to_mts_array(std::move(values)); + + auto samples = 
metatensor::Labels({"system"}, {{0}}); + auto properties = metatensor::Labels({"energy"}, {{0}}); + + auto* block = mts_block( + std::move(array).release(), + samples.as_mts_labels_t(), + nullptr, + 0, + properties.as_mts_labels_t() + ); + if (block == nullptr) { + return nullptr; + } + + auto keys = metatensor::Labels({"_"}, {{0}}); + auto blocks = std::vector{block}; + return mts_tensormap(keys.as_mts_labels_t(), blocks.data(), blocks.size()); +} + +static mta_status_t execute_inner_impl( + void* model_data, + const mta_system_t* const* systems, + uintptr_t systems_count, + const mts_labels_t* selected_atoms, + const char* requested_outputs_json, + mts_tensormap_t** outputs, + uintptr_t outputs_count +) { + (void)model_data; + (void)systems; + (void)systems_count; + (void)selected_atoms; + (void)requested_outputs_json; + (void)outputs; + (void)outputs_count; + + return MTA_INTERNAL_ERROR; +} + +static mta_status_t load_model_impl( + const char* load_from, + const char* options_json, + mta_model_t* model +) { + (void)options_json; + assert(model != nullptr); + + if (std::strcmp(load_from, "test-c-model") != 0) { + return MTA_MODEL_NOT_SUPPORTED_ERROR; + } + + model->data = new SimpleModelData{2.0}; + model->unload = unload_impl; + model->metadata = metadata_impl; + model->capabilities = capabilities_impl; + model->supported_outputs = supported_outputs_impl; + model->requested_pair_lists = requested_pair_lists_impl; + model->requested_inputs = requested_inputs_impl; + model->execute_inner = execute_inner_impl; + + return MTA_SUCCESS; +} + +TEST_CASE("simple C model can be registered and loaded through the C API") { + static auto PLUGIN = mta_plugin_t { + MTA_ABI_VERSION, + "test-c-plugin", + load_model_impl, + }; + mta_register_plugin(PLUGIN); + + auto model = mta_model_t{}; + auto status = mta_load_model("test-c-plugin", "test-c-model", nullptr, &model); + REQUIRE(status == MTA_SUCCESS); + + CHECK(model.data != nullptr); + CHECK(model.unload != nullptr); + 
CHECK(model.metadata != nullptr); + CHECK(model.capabilities != nullptr); + CHECK(model.supported_outputs != nullptr); + CHECK(model.requested_pair_lists != nullptr); + CHECK(model.requested_inputs != nullptr); + CHECK(model.execute_inner != nullptr); + + mta_string_t metadata = nullptr; + status = model.metadata(model.data, &metadata); + REQUIRE(status == MTA_SUCCESS); + + CHECK(metadata != nullptr); + auto metadata_str = std::string(mta_string_view(metadata)); + mta_string_free(metadata); + + CHECK(metadata_str.find("\"name\": \"test C model\"") != std::string::npos); + + + mta_string_t pair_lists = nullptr; + status = model.requested_pair_lists(model.data, &pair_lists); + REQUIRE(status == MTA_SUCCESS); + + CHECK(pair_lists != nullptr); + CHECK(std::strcmp(mta_string_view(pair_lists), "[]") == 0); + mta_string_free(pair_lists); + + REQUIRE(model.unload(model.data) == MTA_SUCCESS); +} From 5a0882c78bfd9644798aff4352df84a7d3bdd096 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Mon, 1 Jun 2026 16:23:16 +0200 Subject: [PATCH 17/20] Implement System in metatomic-core Co-Authored-By: frostedoyster --- metatomic-core/Cargo.toml | 3 +- metatomic-core/include/metatomic.h | 10 +- metatomic-core/src/c_api/status.rs | 9 +- metatomic-core/src/lib.rs | 20 + metatomic-core/src/metadata.rs | 8 +- metatomic-core/src/quantities.rs | 23 +- metatomic-core/src/system.rs | 739 ++++++++++++++++++++++++++++- 7 files changed, 779 insertions(+), 33 deletions(-) diff --git a/metatomic-core/Cargo.toml b/metatomic-core/Cargo.toml index dace0f8d5..b88a91cda 100644 --- a/metatomic-core/Cargo.toml +++ b/metatomic-core/Cargo.toml @@ -16,9 +16,10 @@ bench = false [dependencies] metatensor = { version = "0.3.0" } once_cell = "1" -dlpk = "0.3" +dlpk = { version = "0.3", features = ["ndarray"]} json = "0.12" libloading = "0.8" +ndarray = "0.17" [build-dependencies] diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index a2a549e78..9273007d0 100644 --- 
a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -90,11 +90,19 @@ typedef enum mta_status_t { * Status code indicating serialization/deserialization errors */ MTA_SERIALIZATION_ERROR = 3, + /** + * Status code indicating dlpack errors + */ + MTA_DLPACK_ERROR = 4, + /** + * Status code indicating metatensor errors + */ + MTA_METATENSOR_ERROR = 5, /** * Status code used by plugins when a model is not supported by the * current plugin */ - MTA_MODEL_NOT_SUPPORTED_ERROR = 4, + MTA_MODEL_NOT_SUPPORTED_ERROR = 6, /** * Status code used when there is an internal error */ diff --git a/metatomic-core/src/c_api/status.rs b/metatomic-core/src/c_api/status.rs index b9c9a16aa..48d658777 100644 --- a/metatomic-core/src/c_api/status.rs +++ b/metatomic-core/src/c_api/status.rs @@ -37,9 +37,13 @@ pub enum mta_status_t { MTA_IO_ERROR = 2, /// Status code indicating serialization/deserialization errors MTA_SERIALIZATION_ERROR = 3, + /// Status code indicating dlpack errors + MTA_DLPACK_ERROR = 4, + /// Status code indicating metatensor errors + MTA_METATENSOR_ERROR = 5, /// Status code used by plugins when a model is not supported by the /// current plugin - MTA_MODEL_NOT_SUPPORTED_ERROR = 4, + MTA_MODEL_NOT_SUPPORTED_ERROR = 6, /// Status code used when there is an internal error MTA_INTERNAL_ERROR = 255, } @@ -107,8 +111,11 @@ impl From for mta_status_t { Error::InvalidParameter(_) => mta_status_t::MTA_INVALID_PARAMETER_ERROR, Error::Io(_) => mta_status_t::MTA_IO_ERROR, Error::Serialization(_) => mta_status_t::MTA_SERIALIZATION_ERROR, + Error::Dlpack(_) => mta_status_t::MTA_DLPACK_ERROR, + Error::Metatensor(_) => mta_status_t::MTA_METATENSOR_ERROR, Error::CallbackError(_) => unreachable!("already handled above"), Error::Internal(_) => mta_status_t::MTA_INTERNAL_ERROR, + } } } diff --git a/metatomic-core/src/lib.rs b/metatomic-core/src/lib.rs index 3450f05df..cc20d6961 100644 --- a/metatomic-core/src/lib.rs +++ b/metatomic-core/src/lib.rs @@ -46,6 
+46,10 @@ pub enum Error { InvalidParameter(String), /// I/O error Io(Arc), + /// Error related to dlpack tensors, such as invalid tensor shapes or types + Dlpack(Arc), + /// Error coming from metatensor + Metatensor(metatensor::Error), /// Error coming from an external function used as a callback CallbackError(mta_status_t), /// Any other internal error, usually these are internal bugs. @@ -58,6 +62,8 @@ impl std::fmt::Display for Error { Error::Serialization(e) => write!(f, "serialization error: {}", e), Error::InvalidParameter(e) => write!(f, "invalid parameter: {}", e), Error::Io(e) => write!(f, "io error: {}", e), + Error::Dlpack(e) => write!(f, "dlpack error: {}", e), + Error::Metatensor(e) => write!(f, "metatensor error: {}", e), Error::CallbackError(e) => write!(f, "callback error, status code: {:?}", e), Error::Internal(e) => write!(f, "internal metatomic error (this is likely a bug, please report it): {}", e @@ -74,6 +80,8 @@ impl std::error::Error for Error { | Error::Internal(_) | Error::CallbackError(_) => None, Error::Io(e) => Some(e), + Error::Dlpack(e) => Some(e), + Error::Metatensor(e) => Some(e), } } @@ -102,3 +110,15 @@ impl From for Error { Error::Io(Arc::new(error)) } } + +impl From for Error { + fn from(error: dlpk::ndarray::DLPackNDarrayError) -> Self { + Error::Dlpack(Arc::new(error)) + } +} + +impl From for Error { + fn from(error: metatensor::Error) -> Self { + Error::Metatensor(error) + } +} diff --git a/metatomic-core/src/metadata.rs b/metatomic-core/src/metadata.rs index 48bfeaad5..a0f0c3875 100644 --- a/metatomic-core/src/metadata.rs +++ b/metatomic-core/src/metadata.rs @@ -9,16 +9,16 @@ use crate::units::validate_unit; #[derive(Debug, Clone)] pub struct PairListOptions { /// Cutoff radius for this pair list in the length unit of the model - cutoff: f64, + pub cutoff: f64, /// Whether the list is a full list (contains both the pair `i -> j` and `j -> i`) /// or a half list (contains only `i -> j`) - full_list: bool, + pub full_list: 
bool, /// Whether the list guarantees that only atoms within the cutoff are /// included (strict) or may also include pairs slightly beyond the cutoff /// (non-strict) - strict: bool, + pub strict: bool, /// List of strings describing who requested this pair list - requestors: Vec, + pub requestors: Vec, } impl std::cmp::PartialEq for PairListOptions { diff --git a/metatomic-core/src/quantities.rs b/metatomic-core/src/quantities.rs index 93727c837..cc313d555 100644 --- a/metatomic-core/src/quantities.rs +++ b/metatomic-core/src/quantities.rs @@ -41,7 +41,7 @@ fn is_valid_identifier(s: &str) -> bool { /// All components (namespace, name, variant) must be non-empty if they are /// present, and must be valid identifiers (alphanumeric + underscore, not /// starting with a digit). -fn validate_quantity_name(name: &str) -> Result<(), Error> { +pub(crate) fn validate_quantity_name(name: &str) -> Result<(), Error> { if STANDARD_QUANTITIES.contains(&name) { return Ok(()); } @@ -67,7 +67,12 @@ fn validate_quantity_name(name: &str) -> Result<(), Error> { } } - for component in main_part.split("::") { + if STANDARD_QUANTITIES.contains(&main_part) { + return Ok(()); + } + + let components: Vec<_> = main_part.split("::").collect(); + for component in &components { if !is_valid_identifier(component) { return Err(Error::InvalidParameter(format!( "invalid quantity name component '{}' in '{}': must be a valid identifier (alphanumeric or underscore, not starting with a digit)", @@ -76,6 +81,13 @@ fn validate_quantity_name(name: &str) -> Result<(), Error> { } } + if components.len() == 1 { + return Err(Error::InvalidParameter(format!( + "'{}' is not a standard quantity name; custom quantity names must use '::'", + name + ))); + } + Ok(()) } @@ -289,7 +301,7 @@ mod tests { vec![Gradients::Positions, Gradients::Strain], ] { let quantity = Quantity { - name: "test".into(), + name: "test_ns::test".into(), unit: "unit".into(), description: Some("Hello".to_string()), gradients: 
grads.clone(), @@ -367,9 +379,7 @@ mod tests { "my_model::energy", "org::my_model::custom_qty", "ns1::ns2::ns3::energy", - "custom_name", "some_ns::name_with_underscores", - "_underscore_start", "_ns::_name", ]; for name in custom { @@ -388,6 +398,9 @@ mod tests { let error = validate_quantity_name("").expect_err("expected an error"); assert_eq!(error.to_string(), "invalid parameter: quantity name cannot be empty in ''"); + let error = validate_quantity_name("not_a_standard_name").expect_err("expected an error"); + assert_eq!(error.to_string(), "invalid parameter: 'not_a_standard_name' is not a standard quantity name; custom quantity names must use '::'"); + let error = validate_quantity_name("/variant").expect_err("expected an error"); assert_eq!(error.to_string(), "invalid parameter: quantity name cannot be empty in '/variant'"); diff --git a/metatomic-core/src/system.rs b/metatomic-core/src/system.rs index 30677f5f9..d30dccce4 100644 --- a/metatomic-core/src/system.rs +++ b/metatomic-core/src/system.rs @@ -1,12 +1,21 @@ -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; +use once_cell::sync::Lazy; -use dlpk::DLPackTensor; +use dlpk::sys::{DLDataType, DLDevice, DLDeviceType}; +use dlpk::{DLPackTensor, DLPackTensorRef}; use metatensor::{TensorBlock, TensorMap}; -use crate::PairListOptions; +use crate::{Error, PairListOptions}; +/// Names that can never be used as custom data in a system +static INVALID_DATA_NAMES: Lazy> = Lazy::new(|| { + HashSet::from(["types", "type", "positions", "position", "cell", "neighbors", "neighbor", "pair", "pairs"]) +}); -/// TODO +/// Storage for an atomistic system. +/// +/// This owns the raw DLPack tensors and metatensor objects used at FFI +/// boundaries. 
pub struct System { length_unit: String, types: DLPackTensor, @@ -18,36 +27,724 @@ pub struct System { custom_data: HashMap, } - impl System { - /// TODO + /// Create a `System` from raw DLPack tensors pub fn new( length_unit: String, types: DLPackTensor, positions: DLPackTensor, cell: DLPackTensor, - pbc: DLPackTensor - ) -> Self { - todo!() + pbc: DLPackTensor, + ) -> Result { + validate_system_tensors(&types, &positions, &cell, &pbc)?; + + let system = System { + length_unit, + types, + positions, + cell, + pbc, + pairs: BTreeMap::new(), + custom_data: HashMap::new(), + }; + + if system.device().device_type == DLDeviceType::kDLCPU { + validate_cpu_system_data(&system)?; + } + + return Ok(system); + } + + /// Get the length unit used by this system + pub fn length_unit(&self) -> &str { + &self.length_unit + } + + /// Get the number of atoms/particles in this system + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + pub fn size(&self) -> usize { + let size = self.types.shape()[0]; + debug_assert!(usize::try_from(size).is_ok()); + return size as usize; + } + + /// Get the particle types + pub fn types(&self) -> DLPackTensorRef<'_> { + self.types.as_ref() + } + + /// Get the particle positions + pub fn positions(&self) -> DLPackTensorRef<'_> { + self.positions.as_ref() + } + + /// Get the unit cell + pub fn cell(&self) -> DLPackTensorRef<'_> { + self.cell.as_ref() + } + + /// Get the periodic boundary condition flags + pub fn pbc(&self) -> DLPackTensorRef<'_> { + self.pbc.as_ref() + } + + /// Add a pair list to this system + pub fn add_pairs( + &mut self, + options: PairListOptions, + pairs: TensorBlock, + ) -> Result<(), Error> { + if self.pairs.contains_key(&options) { + return Err(Error::InvalidParameter( + "the pair list for these options already exists in this system".into(), + )); + } + + let samples = pairs.samples(); + let samples_names = samples.names(); + if samples_names != ["first_atom", "second_atom", "cell_shift_a", "cell_shift_b", 
"cell_shift_c"] { + return Err(Error::InvalidParameter( + "invalid samples for `pairs`: the samples names must be \ + 'first_atom', 'second_atom', 'cell_shift_a', 'cell_shift_b', \ + 'cell_shift_c'".into(), + )); + } + + let components = pairs.components(); + if components.len() != 1 || components[0].names() != ["xyz"] || components[0].count() != 3 { + return Err(Error::InvalidParameter( + "invalid components for `pairs`: there should be a \ + single 'xyz'=[0, 1, 2] component".into() + )); + } + + #[allow(clippy::collapsible_if)] + if components[0].device().device_type == DLDeviceType::kDLCPU { + if components[0][0] != [0] || components[0][1] != [1] || components[0][2] != [2] { + return Err(Error::InvalidParameter( + "invalid components for `pairs`: the 'xyz' \ + component should contain [0, 1, 2]".into() + )); + } + } + + let properties = pairs.properties(); + if properties.names() != ["distance"] || properties.count() != 1 { + return Err(Error::InvalidParameter( + "invalid properties for `pairs`: there should be a single \ + 'distance'=0 property".into() + )); + } + + #[allow(clippy::collapsible_if)] + if properties.device().device_type == DLDeviceType::kDLCPU { + if properties[0] != [0] { + return Err(Error::InvalidParameter( + "invalid properties for `pairs`: the 'distance' property \ + should contain [0]".into() + )); + } + } + + if !pairs.as_ref().gradient_list().is_empty() { + return Err(Error::InvalidParameter( + "`pairs` should not have any gradients".into() + )); + } + + // TODO: add TensorBlock::device/dtype and use them here + let values = pairs.values(); + let values_device = values.device()?; + if values_device != self.device() { + return Err(Error::InvalidParameter(format!( + "`pairs` device ({}) does not match this system's device ({})", + values_device, self.device(), + ))); + } + + let values_dtype = values.dtype()?; + if values_dtype != self.dtype() { + return Err(Error::InvalidParameter(format!( + "`pairs` dtype ({}) does not match this system's 
dtype ({})", + values_dtype, self.dtype(), + ))); + } + + self.pairs.insert(options, pairs); + return Ok(()); + } + + /// Get a pair list from this system + pub fn get_pairs(&self, options: &PairListOptions) -> Option<&TensorBlock> { + return self.pairs.get(options); + } + + /// Get all pair list options known by this system + pub fn known_pairs(&self) -> Vec<&PairListOptions> { + return self.pairs.keys().collect(); + } + + /// Add custom data to this system + /// + /// If `override_` is `true`, existing data with the same name will be + /// replaced. + pub fn add_custom_data(&mut self, name: String, data: TensorMap, override_: bool) -> Result<(), Error> { + if INVALID_DATA_NAMES.contains(name.to_lowercase().as_str()) { + return Err(Error::InvalidParameter(format!( + "custom data can not be named '{}'", name + ))); + } + + crate::quantities::validate_quantity_name(&name)?; + + if !override_ && self.custom_data.contains_key(&name) { + return Err(Error::InvalidParameter(format!( + "custom data '{}' is already present in this system", + name + ))); + } + + if data.keys().count() == 0 { + return Err(Error::InvalidParameter(format!( + "custom data '{}' has no blocks", name + ))); + } + + // TODO: add TensorMap::device/dtype and use them here + let block = data.block_by_id(0); + let values = block.values(); + let data_device = values.device()?; + if data_device != self.device() { + return Err(Error::InvalidParameter(format!( + "device ({}:{}) of the custom data '{}' does not match this system device ({}:{})", + data_device.device_type, data_device.device_id, name, + self.device().device_type, self.device().device_id, + ))); + } + + let values_dtype = values.dtype()?; + if values_dtype != self.dtype() { + return Err(Error::InvalidParameter(format!( + "dtype of custom data '{}' does not match this system dtype", + name, + ))); + } + + self.custom_data.insert(name, data); + return Ok(()); + } + + /// Get custom data from this system. 
+ pub fn get_custom_data(&self, name: &str) -> Result<&TensorMap, Error> { + let lower = name.to_lowercase(); + if INVALID_DATA_NAMES.contains(lower.as_str()) { + return Err(Error::InvalidParameter(format!( + "custom data can not be named '{}'", name + ))); + } + + return self.custom_data.get(name).ok_or_else(|| Error::InvalidParameter(format!( + "no data for '{}' found in this system", name + ))); + } + + /// Get all custom data names known by this system. + pub fn known_custom_data(&self) -> Vec<&str> { + return self.custom_data.keys().map(String::as_str).collect(); + } + + /// The device used for all tensors in this system + fn device(&self) -> DLDevice { + self.types.device() + } + + /// The data type used for the `positions` and `cell` tensors in this + /// system, as well as any pair lists and custom data added to this system. + fn dtype(&self) -> DLDataType { + self.positions.dtype() + } +} + +fn validate_system_tensors( + types: &DLPackTensor, + positions: &DLPackTensor, + cell: &DLPackTensor, + pbc: &DLPackTensor, +) -> Result<(), Error> { + let device = types.device(); + if positions.device() != device || cell.device() != device || pbc.device() != device { + return Err(Error::InvalidParameter( + "`types`, `positions`, `cell`, and `pbc` must be on the same device".into() + )); + } + + let dtype_i32 = ::get_dlpack_data_type(); + let dtype_f32 = ::get_dlpack_data_type(); + let dtype_f64 = ::get_dlpack_data_type(); + let dtype_bool = ::get_dlpack_data_type(); + + if types.dtype() != dtype_i32 { + return Err(Error::InvalidParameter( + "`types` must be a tensor of 32-bit integers".into() + )); + } + + let types_shape = types.shape(); + if types_shape.len() != 1 || types_shape[0] < 0 { + return Err(Error::InvalidParameter(format!( + "`types` must be a (n_atoms,) tensor, got a tensor with shape [{}]", + types_shape.iter().map(|dim| dim.to_string()).collect::>().join(", ") + ))); + } + + let n_atoms = types_shape[0]; + + let positions_shape = positions.shape(); + 
if positions_shape.len() != 2 || positions_shape[0] != n_atoms || positions_shape[1] != 3 { + return Err(Error::InvalidParameter(format!( + "`positions` must be a (n_atoms x 3) tensor, got a tensor with shape [{}]", + positions_shape.iter().map(|dim| dim.to_string()).collect::>().join(", ") + ))); + } + + if positions.dtype() != dtype_f32 && positions.dtype() != dtype_f64 { + return Err(Error::InvalidParameter( + "`positions` must be a tensor of 32 or 64-bit floating point data".into() + )); + } + + let cell_shape = cell.shape(); + if cell_shape.len() != 2 || cell_shape[0] != 3 || cell_shape[1] != 3 { + return Err(Error::InvalidParameter(format!( + "`cell` must be a (3 x 3) tensor, got a tensor with shape [{}]", + cell_shape.iter().map(|dim| dim.to_string()).collect::>().join(", ") + ))); } - /// TODO - pub fn add_pairs(&mut self, options: PairListOptions, pairs: TensorBlock, check_consistency: bool) { - todo!() + if cell.dtype() != positions.dtype() { + return Err(Error::InvalidParameter( + "`cell` must have the same dtype as `positions`".into() + )); } - /// TODO - pub fn get_pairs(&mut self, options: PairListOptions) -> Option<&TensorBlock> { - todo!() + let pbc_shape = pbc.shape(); + if pbc_shape.len() != 1 || pbc_shape[0] != 3 { + return Err(Error::InvalidParameter(format!( + "`pbc` must contain 3 entries, got a tensor with shape [{}]", + pbc_shape.iter().map(|dim| dim.to_string()).collect::>().join(", ") + ))); } - /// TODO - pub fn set_custom_data(&mut self, name: String, data: TensorMap) { - todo!() + if pbc.dtype() != dtype_bool { + return Err(Error::InvalidParameter( + "`pbc` must be a tensor of booleans".into() + )); } - /// TODO - pub fn get_custom_data(&self, name: &str) -> Option<&TensorMap> { - todo!() + return Ok(()); +} + +fn validate_cpu_system_data(system: &System) -> Result<(), Error> { + let pbc_array: ndarray::ArrayView1 = system.pbc().try_into()?; + + if system.dtype().bits == 32 { + let cell_array: ndarray::ArrayView2 = 
system.cell().try_into()?; + for i in 0..3 { + if !pbc_array[i] && !cell_array.row(i).iter().all(|&x| x == 0.0) { + return Err(Error::InvalidParameter(format!( + "invalid cell: for non-periodic dimensions, the corresponding \ + cell vector must be zero, but cell[{}] contains non-zero values", + i + ))); + } + } + } else { + let cell_array: ndarray::ArrayView2 = system.cell().try_into()?; + for i in 0..3 { + if !pbc_array[i] && !cell_array.row(i).iter().all(|&x| x == 0.0) { + return Err(Error::InvalidParameter(format!( + "invalid cell: for non-periodic dimensions, the corresponding \ + cell vector must be zero, but cell[{}] contains non-zero values", + i + ))); + } + } + } + + return Ok(()); +} + +#[cfg(test)] +mod tests { + use super::*; + use metatensor::Labels; +use ndarray::{Array1, Array2}; + + // ----------------------------------------------------------------------- + // helpers to create DLPack tensors + // ----------------------------------------------------------------------- + fn type_tensor(data: &[i32]) -> DLPackTensor { + Array1::from_vec(data.to_vec()).try_into().unwrap() + } + + #[allow(clippy::cast_precision_loss)] + fn positions_tensor(n_atoms: usize, dtype: &str) -> DLPackTensor { + match dtype { + "f32" => { + let mut data = Vec::with_capacity(3 * n_atoms); + for i in 0..n_atoms { + data.extend_from_slice(&[i as f32, 0.0, 0.0]); + } + Array2::from_shape_vec((n_atoms, 3), data).unwrap().try_into().unwrap() + } + "f64" => { + let mut data = Vec::with_capacity(3 * n_atoms); + for i in 0..n_atoms { + data.extend_from_slice(&[i as f64, 0.0, 0.0]); + } + Array2::from_shape_vec((n_atoms, 3), data).unwrap().try_into().unwrap() + } + _ => panic!("unsupported dtype '{}'", dtype), + } + } + + #[allow(clippy::cast_possible_truncation)] + fn cell_tensor(size: f64, dtype: &str) -> DLPackTensor { + match dtype { + "f32" => { + Array2::::from_shape_vec( + (3, 3), + vec![ + size as f32, 0.0, 0.0, + 0.0, size as f32, 0.0, + 0.0, 0.0, size as f32, + ], + 
).unwrap().try_into().unwrap() + } + "f64" => Array2::::from_shape_vec( + (3, 3), + vec![ + size, 0.0, 0.0, + 0.0, size, 0.0, + 0.0, 0.0, size, + ], + ).unwrap().try_into().unwrap(), + _ => panic!("unsupported dtype '{}'", dtype), + } + } + + fn pbc_tensor(data: &[bool]) -> DLPackTensor { + Array1::from_vec(data.to_vec()).try_into().unwrap() + } + + fn valid_pair_block(dtype: &str) -> TensorBlock { + let samples = Labels::new( + ["first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"], + &[[0i32, 1, 0, 0, 0]], + ); + let components = vec![Labels::new(["xyz"], &[[0i32], [1], [2]])]; + let properties = Labels::new(["distance"], &[[0i32]]); + + match dtype { + "f32" => { + let values = ndarray::ArrayD::::from_shape_vec(vec![1, 3, 1], vec![1.5, 2.5, 3.5]).unwrap(); + TensorBlock::new(values, &samples, &components, &properties).unwrap() + } + "f64" => { + let values = ndarray::ArrayD::::from_shape_vec(vec![1, 3, 1], vec![1.5, 2.5, 3.5]).unwrap(); + TensorBlock::new(values, &samples, &components, &properties).unwrap() + } + _ => panic!("unsupported dtype '{}'", dtype), + } + } + + fn valid_custom_data(dtype: &str) -> TensorMap { + let keys = Labels::new(["key"], &[[0i32]]); + let samples = Labels::new(["sample"], &[[0i32]]); + let properties = Labels::new(["property"], &[[0i32]]); + + let block = match dtype { + "f32" => { + let values = ndarray::ArrayD::::from_shape_vec(vec![1, 1], vec![42.0]).unwrap(); + TensorBlock::new(values, &samples, &[], &properties).unwrap() + } + "f64" => { + let values = ndarray::ArrayD::::from_shape_vec(vec![1, 1], vec![42.0]).unwrap(); + TensorBlock::new(values, &samples, &[], &properties).unwrap() + } + _ => panic!("unsupported dtype '{}'", dtype), + }; + + TensorMap::new(keys, vec![block]).unwrap() + } + + fn assert_error(result: Result, expected: &str) { + let error = match result { + Ok(_) => panic!("expected error"), + Err(error) => error, + }; + assert_eq!(error.to_string(), expected); + } + + #[test] + fn system() 
{ + let system = System::new( + "Angstrom".into(), + type_tensor(&[1, 6, 8]), + positions_tensor(3, "f32"), + cell_tensor(10.0, "f32"), + pbc_tensor(&[true, true, true]), + ).unwrap(); + + assert_eq!(system.length_unit(), "Angstrom"); + assert_eq!(system.size(), 3); + assert_eq!(system.device(), DLDevice::cpu()); + assert_eq!(system.dtype().bits, 32); + + let system = System::new( + "Angstrom".into(), + type_tensor(&[1, 6, 8]), + positions_tensor(3, "f64"), + cell_tensor(10.0, "f64"), + pbc_tensor(&[true, true, true]), + ).unwrap(); + assert_eq!(system.length_unit(), "Angstrom"); + assert_eq!(system.size(), 3); + assert_eq!(system.device(), DLDevice::cpu()); + assert_eq!(system.dtype().bits, 64); + } + + #[test] + fn system_invalid_tensors() { + let length_unit = "Angstrom".to_string(); + + let bad_types: DLPackTensor = Array1::::from_vec(vec![1.0, 2.0]).try_into().unwrap(); + let positions = positions_tensor(2, "f32"); + let cell = cell_tensor(0.0, "f32"); + let pbc = pbc_tensor(&[true, true, true]); + + assert_error( + System::new(length_unit.clone(), bad_types, positions, cell, pbc), + "invalid parameter: `types` must be a tensor of 32-bit integers", + ); + + let bad_types: DLPackTensor = Array2::::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap().try_into().unwrap(); + let positions = positions_tensor(2, "f32"); + let cell = cell_tensor(0.0, "f32"); + let pbc = pbc_tensor(&[true, true, true]); + assert_error( + System::new(length_unit.clone(), bad_types, positions, cell, pbc), + "invalid parameter: `types` must be a (n_atoms,) tensor, got a tensor with shape [2, 2]", + ); + + let types = type_tensor(&[1]); + let bad_positions: DLPackTensor = Array2::::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap().try_into().unwrap(); + let cell = cell_tensor(0.0, "f32"); + let pbc = pbc_tensor(&[true, true, true]); + assert_error( + System::new(length_unit.clone(), types, bad_positions, cell, pbc), + "invalid parameter: `positions` must be a tensor of 32 or 64-bit floating 
point data", + ); + + let types = type_tensor(&[1, 6]); + let bad_positions = Array2::::from_shape_vec((2, 2), vec![0.0; 4]).unwrap().try_into().unwrap(); + let cell = cell_tensor(0.0, "f32"); + let pbc = pbc_tensor(&[true, true, true]); + assert_error( + System::new("Angstrom".into(), types, bad_positions, cell, pbc), + "invalid parameter: `positions` must be a (n_atoms x 3) tensor, got a tensor with shape [2, 2]", + ); + + let types = type_tensor(&[1, 6]); + let positions = positions_tensor(2, "f32"); + let bad_cell = Array2::::from_shape_vec((2, 3), vec![0.0; 6]).unwrap().try_into().unwrap(); + let pbc = pbc_tensor(&[true, true, true]); + assert_error( + System::new(length_unit.clone(), types, positions, bad_cell, pbc), + "invalid parameter: `cell` must be a (3 x 3) tensor, got a tensor with shape [2, 3]", + ); + + let types = type_tensor(&[1, 6]); + let positions = positions_tensor(2, "f32"); + let cell = cell_tensor(0.0, "f64"); + let pbc = pbc_tensor(&[true, true, true]); + assert_error( + System::new(length_unit.clone(), types, positions, cell, pbc), + "invalid parameter: `cell` must have the same dtype as `positions`", + ); + + let bad_pbc_dtype: DLPackTensor = Array1::::from_vec(vec![1, 0, 1]).try_into().unwrap(); + let types = type_tensor(&[1, 6]); + let positions = positions_tensor(2, "f32"); + let cell = cell_tensor(0.0, "f32"); + assert_error( + System::new(length_unit.clone(), types, positions, cell, bad_pbc_dtype), + "invalid parameter: `pbc` must be a tensor of booleans", + ); + + let types = type_tensor(&[1, 6]); + let positions = positions_tensor(2, "f32"); + let cell = cell_tensor(0.0, "f32"); + let bad_pbc = pbc_tensor(&[true, true]); + assert_error( + System::new(length_unit, types, positions, cell, bad_pbc), + "invalid parameter: `pbc` must contain 3 entries, got a tensor with shape [2]", + ); + } + + #[test] + fn system_periodic() { + let length_unit = "Angstrom".to_string(); + + // valid periodicity combinations: (1) fully periodic + let 
types = type_tensor(&[1]); + let positions = positions_tensor(1, "f32"); + let cell = cell_tensor(10.0, "f32"); + let pbc = pbc_tensor(&[true, true, true]); + System::new(length_unit.clone(), types, positions, cell, pbc).unwrap(); + + // (2) fully non-periodic with zero cell + let types = type_tensor(&[1]); + let positions = positions_tensor(1, "f32"); + let cell = cell_tensor(0.0, "f32"); + let pbc = pbc_tensor(&[false, false, false]); + System::new(length_unit.clone(), types, positions, cell, pbc).unwrap(); + + // (3) mixed periodic/non-periodic + let types = type_tensor(&[1]); + let positions = positions_tensor(1, "f32"); + let cell: DLPackTensor = Array2::::from_shape_vec( + (3, 3), + vec![10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0], + ).unwrap().try_into().unwrap(); + let pbc = pbc_tensor(&[true, false, true]); + System::new(length_unit.clone(), types, positions, cell, pbc).unwrap(); + + // invalid periodicity/cell + let types = type_tensor(&[1]); + let positions = positions_tensor(1, "f32"); + let cell = cell_tensor(10.0, "f32"); + let pbc = pbc_tensor(&[true, false, true]); + assert_error( + System::new(length_unit.clone(), types, positions, cell, pbc), + "invalid parameter: invalid cell: for non-periodic dimensions, the corresponding cell vector must be zero, but cell[1] contains non-zero values", + ); + } + + #[test] + fn add_pairs() { + let mut system = System::new( + "Angstrom".into(), + type_tensor(&[1, 6, 8]), + positions_tensor(3, "f32"), + cell_tensor(10.0, "f32"), + pbc_tensor(&[true, true, true]), + ).unwrap(); + + let options = PairListOptions { cutoff: 3.5, full_list: true, strict: false, requestors: vec![] }; + let pairs = valid_pair_block("f32"); + system.add_pairs(options.clone(), pairs).unwrap(); + assert_eq!(system.known_pairs().len(), 1); + assert_eq!(system.get_pairs(&options).unwrap().properties().names(), ["distance"]); + + let options_with_requestor = PairListOptions { + cutoff: 3.5, + full_list: true, + strict: false, + requestors: 
vec!["test-requestor".into()], + }; + // TODO: check that this is the exact same block once we can get the + // pointer to check for id. + assert!(system.get_pairs(&options_with_requestor).is_some()); + + system.add_pairs( + PairListOptions { cutoff: 5.0, full_list: false, strict: true, requestors: vec![] }, + valid_pair_block("f32"), + ).unwrap(); + assert_eq!(system.known_pairs().len(), 2); + } + + + #[test] + fn custom_data() { + let mut system = System::new( + "Angstrom".into(), + type_tensor(&[1, 6, 8]), + positions_tensor(3, "f32"), + cell_tensor(10.0, "f32"), + pbc_tensor(&[true, true, true]), + ).unwrap(); + + let data = valid_custom_data("f32"); + system.add_custom_data("test::my_data".into(), data, false).unwrap(); + assert_eq!(system.known_custom_data(), vec!["test::my_data"]); + assert_eq!(system.get_custom_data("test::my_data").unwrap().keys().names(), ["key"]); + + assert_error( + system.add_custom_data("test::my_data".into(), valid_custom_data("f32"), false), + "invalid parameter: custom data 'test::my_data' is already present in this system", + ); + + let replacement = valid_custom_data("f32"); + system.add_custom_data("test::my_data".into(), replacement, true).unwrap(); + assert_eq!(system.known_custom_data(), vec!["test::my_data"]); + + let mut system = System::new( + "Angstrom".into(), + type_tensor(&[1, 6, 8]), + positions_tensor(3, "f32"), + cell_tensor(10.0, "f32"), + pbc_tensor(&[true, true, true]), + ).unwrap(); + system.add_custom_data("test::a".into(), valid_custom_data("f32"), false).unwrap(); + system.add_custom_data("test::b".into(), valid_custom_data("f32"), false).unwrap(); + let mut names = system.known_custom_data(); + names.sort_unstable(); + assert_eq!(names, vec!["test::a", "test::b"]); + + // TODO: check we get back the same pointer + assert!(system.get_custom_data("test::a").is_ok()); + assert!(system.get_custom_data("test::b").is_ok()); + + assert_error( + system.get_custom_data("no_such_data"), + "invalid parameter: no data 
for 'no_such_data' found in this system", + ); + } + + #[test] + fn custom_data_validation() { + let mut system = System::new( + "Angstrom".into(), + type_tensor(&[1, 6, 8]), + positions_tensor(3, "f32"), + cell_tensor(10.0, "f32"), + pbc_tensor(&[true, true, true]), + ).unwrap(); + for name in ["types", "type", "Positions", "position", "CELL", "neighbors", "neighbor", "pair", "pairs", "Types", "POSITIONS", "Cell", "Neighbors"] { + let data = valid_custom_data("f32"); + assert_error( + system.add_custom_data(name.to_string(), data, false), + &format!("invalid parameter: custom data can not be named '{}'", name), + ); + } + + assert_error( + system.add_custom_data("my_data".into(), valid_custom_data("f32"), false), + "invalid parameter: 'my_data' is not a standard quantity name; custom quantity names must use '::'", + ); + + let keys = Labels::empty(vec!["key"]); + let empty = TensorMap::new(keys, vec![]).unwrap(); + assert_error( + system.add_custom_data("test::empty".into(), empty, false), + "invalid parameter: custom data 'test::empty' has no blocks", + ); + + let dtype_mismatch = valid_custom_data("f64"); + assert_error( + system.add_custom_data("test::dtype".into(), dtype_mismatch, false), + "invalid parameter: dtype of custom data 'test::dtype' does not match this system dtype", + ); } } From f56f2abc7ca8d3d0bf9399160df356856bfe77d4 Mon Sep 17 00:00:00 2001 From: Qianjun Xu <92628709+GardevoirX@users.noreply.github.com> Date: Wed, 3 Jun 2026 13:53:45 +0200 Subject: [PATCH 18/20] Implement `mta_format_metadata` --- metatomic-core/src/c_api/model.rs | 16 ++- metatomic-core/src/metadata.rs | 219 +++++++++++++++++++++++++++++- metatomic-core/tests/misc.cpp | 46 +++++++ 3 files changed, 277 insertions(+), 4 deletions(-) diff --git a/metatomic-core/src/c_api/model.rs b/metatomic-core/src/c_api/model.rs index d2050643b..c8d321c6b 100644 --- a/metatomic-core/src/c_api/model.rs +++ b/metatomic-core/src/c_api/model.rs @@ -1,6 +1,9 @@ use std::ffi::{c_void, c_char}; use 
metatensor::c_api::{mts_labels_t, mts_tensormap_t}; +use super::catch_unwind; +use crate::Error; + use super::{mta_status_t, mta_string_t, mta_system_t}; /// A model that computes physical properties of atomistic systems. @@ -224,5 +227,16 @@ pub unsafe extern "C" fn mta_format_metadata( metadata: *const c_char, printed: *mut mta_string_t, ) -> mta_status_t { - todo!() + catch_unwind(|| { + check_pointers_non_null!(metadata, printed); + let metadata_cstr = std::ffi::CStr::from_ptr(metadata); + let metadata_str = metadata_cstr.to_str().map_err(|_| { + Error::InvalidParameter("metadata is not valid UTF-8".into()) + })?; + let metadata_json = json::parse(metadata_str).map_err(|e| { + Error::Serialization(format!("invalid JSON for ModelMetadata: {e}")) + })?; + *printed = mta_string_t::new(crate::ModelMetadata::try_from(&metadata_json)?.print()); + Ok(()) + }) } diff --git a/metatomic-core/src/metadata.rs b/metatomic-core/src/metadata.rs index a0f0c3875..7531d468a 100644 --- a/metatomic-core/src/metadata.rs +++ b/metatomic-core/src/metadata.rs @@ -188,6 +188,63 @@ impl<'a> TryFrom<&'a JsonValue> for References { } +fn normalize_whitespace(data: &str) -> String { + let mut normalized_string = String::new(); + for c in data.chars() { + if c == '\n' || c == '\r' || c == '\t' { + normalized_string.push(' '); + } else { + normalized_string.push(c); + } + } + normalized_string +} + + +fn wrap_80_chars(output: &mut String, data: &str, indent: &str) -> Result<(), Error> { + let string = normalize_whitespace(data); + let mut chars: Vec = string.chars().collect(); + let line_length = 80 - indent.len(); + assert!(line_length > 50); + let mut first_line = true; + loop { + if chars.len() <= line_length { + // last line + if !first_line { + output.push_str(indent); + } + output.extend(chars); + break; + } else { + // backtrack to find the end of a word + let mut word_found = false; + for i in (0..line_length - 1).rev() { + if chars[i] == ' ' { + word_found = true; + // print the 
current line + if !first_line { + output.push_str(indent); + } + output.extend(chars.drain(0..i)); + output.push('\n'); + // remove the space + chars.remove(0); + first_line = false; + break; + } + } + + if !word_found { + // this is only hit if a single word takes a full line. + return Err(Error::InvalidParameter("some words are too long to be wrapped, make them shorter".into())); + } + } + } + + Ok(()) +} + + /// Metadata about a model #[derive(Debug, Clone)] pub struct ModelMetadata { @@ -270,13 +327,135 @@ impl<'a> TryFrom<&'a JsonValue> for ModelMetadata { extra.insert(key.to_string(), value.to_string()); } - Ok(ModelMetadata { + // Validate the contents of `authors` and `references` + for author in &authors { + if author.is_empty() { + return Err(Error::InvalidParameter("author can not be empty string in ModelMetadata".into())); + } + } + + let References { + model, + architecture, + implementation, + } = &references; + for m in model.iter() { + if m.is_empty() { + return Err(Error::InvalidParameter("reference can not be empty string (in 'model' section)".into())); + } + } + for a in architecture.iter() { + if a.is_empty() { + return Err(Error::InvalidParameter("reference can not be empty string (in 'architecture' section)".into())); + } + } + for i in implementation.iter() { + if i.is_empty() { + return Err(Error::InvalidParameter("reference can not be empty string (in 'implementation' section)".into())); + } + } + + let metadata = ModelMetadata { name: name.to_string(), authors: authors, description: description, references: references, extra: extra, - }) + }; + Ok(metadata) + } +} + +impl ModelMetadata{ + fn validate(&self) -> Result<(), Error> { + for author in &self.authors { + if author.is_empty() { + return Err(Error::InvalidParameter("author can not be empty string in ModelMetadata".into())); + } + } + + let References { + model, + architecture, + implementation, + } = &self.references; + for m in model.iter() { + if m.is_empty() { + return 
Err(Error::InvalidParameter("reference can not be empty string (in 'model' section)".into())); + } + } + for a in architecture.iter() { + if a.is_empty() { + return Err(Error::InvalidParameter("reference can not be empty string (in 'architecture' section)".into())); + } + } + for i in implementation.iter() { + if i.is_empty() { + return Err(Error::InvalidParameter("reference can not be empty string (in 'implementation' section)".into())); + } + } + + Ok(()) + } + pub fn print(&self) -> String { + let mut output = String::new(); + if self.name.is_empty() { + output.push_str("This is an unnamed model\n"); + output.push_str("========================\n"); + } else { + output.push_str(&format!("This is the {} model\n", &self.name)); + output.push_str(&format!("============{}======\n", "=".repeat(self.name.len()))); + } + if !self.description.is_empty() { + output.push_str("\n"); + let _ = wrap_80_chars(&mut output, &(self.description), ""); + output.push_str("\n"); + } + + if !self.authors.is_empty() { + output.push_str("\nModel authors\n-------------\n\n"); + for author in self.authors.iter() { + output.push_str("- "); + let _ = wrap_80_chars(&mut output, &author, " "); + output.push_str("\n"); + } + } + + let mut references_output = String::new(); + if !self.references.model.is_empty() { + references_output.push_str("- about this specific model:\n"); + for reference in self.references.model.iter() { + references_output.push_str(" * "); + let _ = wrap_80_chars(&mut references_output, &reference, " "); + references_output.push_str("\n"); + } + } + + if !self.references.architecture.is_empty() { + references_output.push_str("- about the architecture of this model:\n"); + for reference in self.references.architecture.iter() { + references_output.push_str(" * "); + let _ = wrap_80_chars(&mut references_output, reference, " "); + references_output.push_str("\n"); + } + } + + if !self.references.implementation.is_empty() { + references_output.push_str("- about the 
implementation of this model:\n"); + for reference in self.references.implementation.iter() { + references_output.push_str(" * "); + let _ = wrap_80_chars(&mut references_output, reference, " "); + references_output.push_str("\n"); + } + } + + if !references_output.is_empty() { + output.push_str("\nModel references\n----------------\n\n"); + output.push_str("Please cite the following references when using this model:\n"); + output.push_str(&references_output); + } + + output } } @@ -486,6 +665,7 @@ impl<'a> TryFrom<&'a JsonValue> for ModelCapabilities { } } + #[cfg(test)] mod tests { mod pair_list_options { @@ -602,7 +782,8 @@ mod tests { } mod model_metadata { - use super::super::*; + +use super::super::*; fn example() -> ModelMetadata { ModelMetadata { @@ -704,6 +885,38 @@ mod tests { assert_eq!(error.to_string(), expected); } } + + #[test] + fn printing() { + let metadata = example(); + let output = metadata.print(); + let expected = String::from( + "This is the test-model model +============================ + +A test model + +Model authors +------------- + +- Alice +- Bob + +Model references +---------------- + +Please cite the following references when using this model: +- about this specific model: + * doi:10.1234/test +- about the architecture of this model: + * doi:10.1234/arch +- about the implementation of this model: + * https://github.com/test +" +); + + assert_eq!(output, expected); + } } mod model_capabilities { diff --git a/metatomic-core/tests/misc.cpp b/metatomic-core/tests/misc.cpp index 8c66d7561..93baecb67 100644 --- a/metatomic-core/tests/misc.cpp +++ b/metatomic-core/tests/misc.cpp @@ -70,3 +70,49 @@ TEST_CASE("mta_unit_conversion_factor") { "invalid parameter: dimension mismatch in unit conversion: " "'m' has dimension [L] but 'kg' has dimension [M]"); } + +TEST_CASE("mta_format_metadata") { + std::string json =R"({ + "type": "metatomic_model_metadata", + "name": "name", + "description": "Lorem ipsum dolor sit amet, consectetur adipiscing 
elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation.", + "authors": ["Short author", "Some extremely long author that will take more than one line in the printed output"], + "references": { + "architecture": ["ref-2", "ref-3"], + "model": ["a very long reference that will take more than one line in the printed output"], + "implementation": [] + }, + "extra": {} +})"; + auto* mta_string = mta_string_create(""); + REQUIRE(mta_string != nullptr); + auto status = mta_format_metadata(json.c_str(), &mta_string); + REQUIRE(status == MTA_SUCCESS); + const auto expected = R"(This is the name model +====================== + +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor +incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis +nostrud exercitation. + +Model authors +------------- + +- Short author +- Some extremely long author that will take more than one line in the printed + output + +Model references +---------------- + +Please cite the following references when using this model: +- about this specific model: + * a very long reference that will take more than one line in the printed + output +- about the architecture of this model: + * ref-2 + * ref-3 +)"; + CHECK(std::string(mta_string_view(mta_string)) == expected); + mta_string_free(mta_string); +} From 890c0c99b28376fa519eae05650f82524d5bbd44 Mon Sep 17 00:00:00 2001 From: Rocco Meli Date: Tue, 2 Jun 2026 10:56:37 +0200 Subject: [PATCH 19/20] Add unit conversion and error handling to C++ --- metatomic-core/include/metatomic.hpp | 1 + metatomic-core/include/metatomic/errors.hpp | 108 ++++++++++++++++++++ metatomic-core/include/metatomic/utils.hpp | 15 +++ metatomic-core/tests/misc.cpp | 69 ++++++++----- 4 files changed, 170 insertions(+), 23 deletions(-) create mode 100644 metatomic-core/include/metatomic/errors.hpp diff --git a/metatomic-core/include/metatomic.hpp 
b/metatomic-core/include/metatomic.hpp index 3b5c8ac2a..e41f09542 100644 --- a/metatomic-core/include/metatomic.hpp +++ b/metatomic-core/include/metatomic.hpp @@ -2,3 +2,4 @@ #include "metatomic/system.hpp" // IWYU pragma: export #include "metatomic/model.hpp" // IWYU pragma: export #include "metatomic/plugin.hpp" // IWYU pragma: export +#include "metatomic/errors.hpp" // IWYU pragma: export diff --git a/metatomic-core/include/metatomic/errors.hpp b/metatomic-core/include/metatomic/errors.hpp new file mode 100644 index 000000000..a926b388b --- /dev/null +++ b/metatomic-core/include/metatomic/errors.hpp @@ -0,0 +1,108 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace metatomic { + + /// Exception class used for all errors in metatomic + class Error: public std::runtime_error { + public: + /// Create a new MetatomicError with the given `message` + Error(const std::string& message): std::runtime_error(message) {} + }; + + namespace details { + /// Check if a return status from the C API indicates an error, and if it is + /// the case, throw an exception of type `metatomic::Error` with the last + /// error message from the library. + inline void check_status(mta_status_t status) { + if (status == MTA_SUCCESS) { + return; + } else if (status == MTA_MODEL_NOT_SUPPORTED_ERROR) { + const char* message = nullptr; + const char* origin = nullptr; + void* data = nullptr; + mta_last_error(&message, &origin, &data); + if (origin != nullptr &&std::strcmp(origin, "C++ exception") == 0 && data != nullptr) { + std::rethrow_exception(*static_cast(data)); + } else { + throw Error(message == nullptr ? "unknown error" : message); + } + } else { + const char* message = nullptr; + mta_last_error(&message, nullptr, nullptr); + throw Error(message == nullptr ? 
"unknown error" : message); + } + } + + /// Call the given `function` with the given `args` (the function should + /// return an `mta_status_t`), catching any C++ exception, and translating + /// them to native metatomic error code. + /// + /// This is required to prevent callbacks unwinding through the C API. + template + inline mta_status_t catch_exceptions(Function function, Args ...args) { + try { + function(std::move(args)...); + return MTA_SUCCESS; + } catch (...) { + auto* exception_ptr = new std::exception_ptr(std::current_exception()); + + const char* message = nullptr; + try { + std::rethrow_exception(*exception_ptr); + } catch (const std::exception& e) { + message = e.what(); + } catch (...) { + message = "C++ code threw an exception that was not a std::exception"; + } + + auto status = mta_set_last_error( + message, + "C++ exception", + exception_ptr, + [](void *ptr) { delete static_cast(ptr); } + ); + + if (status != MTA_SUCCESS) { + // If we failed to set the error, we are in a very bad state, + // but we should still try to report the original error + // message if possible. + std::fprintf(stderr, "INTERNAL ERROR: unable to set last error after C++ callback failure (status: %d). ", status); + if (message != nullptr) { + fprintf(stderr, "C++ error was: %s\n", message); + } else { + fprintf(stderr, "Unknown C++ error\n"); + } + delete exception_ptr; + } + + return MTA_MODEL_NOT_SUPPORTED_ERROR; + } + } + + /// Check if a pointer allocated by the C API is null, and if it is the + /// case, throw an exception of type `metatomic::Error` with the last + /// error message from the library. 
+ inline void check_pointer(const void* pointer) { + if (pointer == nullptr) { + const char* message = nullptr; + const char* origin = nullptr; + void* data = nullptr; + mta_last_error(&message, &origin, &data); + if (std::strcmp(origin, "C++ exception") == 0 && data != nullptr) { + std::rethrow_exception(*static_cast(data)); + } else { + throw Error(message); + } + } + } + } // namespace details + +} // namespace metatomic diff --git a/metatomic-core/include/metatomic/utils.hpp b/metatomic-core/include/metatomic/utils.hpp index 1cae91bdf..402fee1a7 100644 --- a/metatomic-core/include/metatomic/utils.hpp +++ b/metatomic-core/include/metatomic/utils.hpp @@ -1,7 +1,22 @@ #pragma once +#include + #include +#include namespace metatomic { + inline double unit_conversion_factor( + const std::string& from_unit, + const std::string& to_unit + ) { + double conversion = 0.0; + + auto status = mta_unit_conversion_factor(from_unit.c_str(), to_unit.c_str(), &conversion); + details::check_status(status); + + return conversion; + } + } // namespace metatomic diff --git a/metatomic-core/tests/misc.cpp b/metatomic-core/tests/misc.cpp index 93baecb67..0028f1b88 100644 --- a/metatomic-core/tests/misc.cpp +++ b/metatomic-core/tests/misc.cpp @@ -3,6 +3,7 @@ #include #include "metatomic.h" +#include "metatomic.hpp" TEST_CASE("Version macros") { @@ -48,30 +49,52 @@ TEST_CASE("mta_string_t") { mta_string_free(nullptr); } -TEST_CASE("mta_unit_conversion_factor") { - double factor = 0.0; - - // same unit -> factor = 1.0 - auto status = mta_unit_conversion_factor("m", "m", &factor); - REQUIRE(status == MTA_SUCCESS); - CHECK(factor == 1.0); - - // kJ/mol -> eV - CHECK(mta_unit_conversion_factor("kJ/mol", "eV", &factor) == MTA_SUCCESS); - CHECK(factor == Approx(0.010364269656262174).epsilon(1e-15)); - - // dimension mismatch -> error - status = mta_unit_conversion_factor("m", "kg", &factor); - REQUIRE(status != MTA_SUCCESS); - - const char* error_msg = nullptr; - mta_last_error(&error_msg, 
nullptr, nullptr); - CHECK(std::string(error_msg) == - "invalid parameter: dimension mismatch in unit conversion: " - "'m' has dimension [L] but 'kg' has dimension [M]"); +TEST_CASE("unit conversion factor") { + SECTION("C API") { + double factor = 0.0; + + // same unit -> factor = 1.0 + auto status = mta_unit_conversion_factor("m", "m", &factor); + REQUIRE(status == MTA_SUCCESS); + CHECK(factor == 1.0); + + // kJ/mol -> eV + CHECK(mta_unit_conversion_factor("kJ/mol", "eV", &factor) == MTA_SUCCESS); + CHECK(factor == Approx(0.010364269656262174).epsilon(1e-15)); + + // dimension mismatch -> error + status = mta_unit_conversion_factor("m", "kg", &factor); + REQUIRE(status != MTA_SUCCESS); + + const char* error_msg = nullptr; + mta_last_error(&error_msg, nullptr, nullptr); + CHECK(std::string(error_msg) == + "invalid parameter: dimension mismatch in unit conversion: " + "'m' has dimension [L] but 'kg' has dimension [M]" + ); + } + + SECTION("C++ API") { + // same unit -> factor = 1.0 + auto factor = metatomic::unit_conversion_factor("m", "m"); + CHECK(factor == 1.0); + + // kJ/mol -> eV + factor = metatomic::unit_conversion_factor("kJ/mol", "eV"); + CHECK(factor == Approx(0.010364269656262174).epsilon(1e-15)); + + // dimension mismatch -> error + try{ + factor = metatomic::unit_conversion_factor("m", "kg"); + } + catch(metatomic::Error& e){ + CHECK(std::string(e.what()) == "invalid parameter: dimension mismatch in unit conversion: 'm' has dimension [L] but 'kg' has dimension [M]"); + } + } } -TEST_CASE("mta_format_metadata") { + +TEST_CASE("metatdata formatting") { std::string json =R"({ "type": "metatomic_model_metadata", "name": "name", @@ -88,7 +111,7 @@ TEST_CASE("mta_format_metadata") { REQUIRE(mta_string != nullptr); auto status = mta_format_metadata(json.c_str(), &mta_string); REQUIRE(status == MTA_SUCCESS); - const auto expected = R"(This is the name model + const auto* expected = R"(This is the name model ====================== Lorem ipsum dolor sit amet, 
consectetur adipiscing elit, sed do eiusmod tempor From 0dbcd7ce8d7e06021b6d05ed856a71ff3b508a50 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Wed, 3 Jun 2026 13:51:04 +0200 Subject: [PATCH 20/20] Add C++ documentation --- docs/src/core/index.rst | 1 + docs/src/core/reference/cxx/index.rst | 17 +++++++++++++++++ docs/src/core/reference/cxx/misc.rst | 14 ++++++++++++++ docs/src/core/reference/cxx/model.rst | 2 ++ docs/src/core/reference/cxx/plugin.rst | 2 ++ docs/src/core/reference/cxx/system.rst | 2 ++ docs/src/core/units.rst | 20 ++++++++++++-------- metatomic-core/include/metatomic.h | 20 ++++++++++---------- metatomic-core/include/metatomic/utils.hpp | 18 ++++++++++++++++-- metatomic-core/src/c_api/utils.rs | 20 ++++++++++---------- 10 files changed, 86 insertions(+), 30 deletions(-) create mode 100644 docs/src/core/reference/cxx/index.rst create mode 100644 docs/src/core/reference/cxx/misc.rst create mode 100644 docs/src/core/reference/cxx/model.rst create mode 100644 docs/src/core/reference/cxx/plugin.rst create mode 100644 docs/src/core/reference/cxx/system.rst diff --git a/docs/src/core/index.rst b/docs/src/core/index.rst index 2a3691316..66bbcffb7 100644 --- a/docs/src/core/index.rst +++ b/docs/src/core/index.rst @@ -8,6 +8,7 @@ WIP :maxdepth: 2 reference/c/index + reference/cxx/index reference/json-formats units diff --git a/docs/src/core/reference/cxx/index.rst b/docs/src/core/reference/cxx/index.rst new file mode 100644 index 000000000..9a4a7add3 --- /dev/null +++ b/docs/src/core/reference/cxx/index.rst @@ -0,0 +1,17 @@ +.. _cxx-api-core: + +C++ API reference +================= + +WIP + +The functions and types provided in ``metatomic.hpp`` can be grouped in four +main groups: + +.. 
toctree:: + :maxdepth: 1 + + system + model + plugin + misc diff --git a/docs/src/core/reference/cxx/misc.rst b/docs/src/core/reference/cxx/misc.rst new file mode 100644 index 000000000..26ba29607 --- /dev/null +++ b/docs/src/core/reference/cxx/misc.rst @@ -0,0 +1,14 @@ +Miscellaneous +============= + + +Error handling +^^^^^^^^^^^^^^ + +.. doxygenclass:: metatomic::Error + + +Unit conversion +^^^^^^^^^^^^^^^ + +.. doxygenfunction:: metatomic::unit_conversion_factor diff --git a/docs/src/core/reference/cxx/model.rst b/docs/src/core/reference/cxx/model.rst new file mode 100644 index 000000000..75338c89d --- /dev/null +++ b/docs/src/core/reference/cxx/model.rst @@ -0,0 +1,2 @@ +Model +===== diff --git a/docs/src/core/reference/cxx/plugin.rst b/docs/src/core/reference/cxx/plugin.rst new file mode 100644 index 000000000..67cd50b04 --- /dev/null +++ b/docs/src/core/reference/cxx/plugin.rst @@ -0,0 +1,2 @@ +Plugin system +============= diff --git a/docs/src/core/reference/cxx/system.rst b/docs/src/core/reference/cxx/system.rst new file mode 100644 index 000000000..3dcbaeea1 --- /dev/null +++ b/docs/src/core/reference/cxx/system.rst @@ -0,0 +1,2 @@ +System +====== diff --git a/docs/src/core/units.rst b/docs/src/core/units.rst index c9415ed9f..6c50603ca 100644 --- a/docs/src/core/units.rst +++ b/docs/src/core/units.rst @@ -6,21 +6,25 @@ Units Models in metatensor can use arbitrary units for their inputs and outputs. The unit conversion system allows models to specify the units they expect and receive data in any compatible unit, with automatic conversion handled by -:c:func:`mta_execute_model`. +metatomic during model execution. -The :c:func:`mta_unit_conversion_factor` function parses two unit expressions, -checks that they have compatible physical dimensions, and returns the -multiplicative conversion factor: +Unit parsing is handled by one of the following functions: -.. 
code-block:: c +- :c:func:`mta_unit_conversion_factor` in C +- :cpp:func:`metatomic::unit_conversion_factor` in C++ + +These functions parse two unit expressions, check that they have compatible +physical dimensions, and return the multiplicative conversion factor. For +example, in C++: + +.. code-block:: C++ // How many eV are in one kJ/mol? - double factor; - mta_unit_conversion_factor("kJ/mol", "eV", &factor); + double factor = metatomic::unit_conversion_factor("kJ/mol", "eV"); // factor ≈ 0.01036 // How many GPa are in one eV/A^3? - mta_unit_conversion_factor("eV/A^3", "GPa", &factor); + factor = metatomic::unit_conversion_factor("eV/A^3", "GPa"); // factor ≈ 160.22 If either (or both) unit strings are empty, the conversion returns ``1.0`` diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index 9273007d0..e7de4118b 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -380,18 +380,18 @@ void mta_string_free(mta_string_t string); const char *mta_string_view(mta_string_t string); /** - * Get the multiplicative conversion factor to use to convert from - * `from_unit` to `to_unit`. Both units are parsed as expressions (e.g. - * "kJ/mol/A^2", "(eV*u)^(1/2)") and their dimensions must match. + * Get the multiplicative conversion factor to use to convert from `from_unit` + * to `to_unit`. Both units are parsed as expressions (e.g. `kJ / mol / A^2`, + * `(eV * u)^(1/2)`) and their dimensions must match. * - * Unit expressions are built from base units combined with `*`, `/`, `^`, - * and parentheses. Unit lookup is case-insensitive, and whitespace is - * ignored. For example: + * @verbatim embed:rst:leading-asterisk * - * - `"kJ/mol"` -- energy per mole - * - `"eV/Angstrom^3"` -- pressure - * - `"(eV*u)^(1/2)"` -- momentum (fractional powers) - * - `"Hartree/Bohr"` -- force in atomic units + * .. 
seealso:: + * + * The general documentation for :ref:`units`, with the expression + * syntax and list of supported base units. + * + * @endverbatim * * @param from_unit A null-terminated C string containing the unit to convert from. * @param to_unit A null-terminated C string containing the unit to convert to. diff --git a/metatomic-core/include/metatomic/utils.hpp b/metatomic-core/include/metatomic/utils.hpp index 402fee1a7..38f6aaf79 100644 --- a/metatomic-core/include/metatomic/utils.hpp +++ b/metatomic-core/include/metatomic/utils.hpp @@ -6,7 +6,22 @@ #include namespace metatomic { - + /// Get the multiplicative conversion factor to use to convert from + /// `from_unit` to `to_unit`. Both units are parsed as expressions + /// (e.g. `kJ / mol / A^2`, `(eV * u)^(1/2)`) and their dimensions must + /// match. + /// + /// @verbatim embed:rst:leading-slashes + /// + /// .. seealso:: + /// + /// The general documentation for :ref:`units`, with the expression + /// syntax and list of supported base units. + /// + /// @endverbatim + /// + /// @param from_unit the unit to convert from + /// @param to_unit the unit to convert to inline double unit_conversion_factor( const std::string& from_unit, const std::string& to_unit @@ -18,5 +33,4 @@ namespace metatomic { return conversion; } - } // namespace metatomic diff --git a/metatomic-core/src/c_api/utils.rs b/metatomic-core/src/c_api/utils.rs index 448c2b5d4..38e5222c4 100644 --- a/metatomic-core/src/c_api/utils.rs +++ b/metatomic-core/src/c_api/utils.rs @@ -147,18 +147,18 @@ pub unsafe extern "C" fn mta_string_view( return result; } -/// Get the multiplicative conversion factor to use to convert from -/// `from_unit` to `to_unit`. Both units are parsed as expressions (e.g. -/// "kJ/mol/A^2", "(eV*u)^(1/2)") and their dimensions must match. +/// Get the multiplicative conversion factor to use to convert from `from_unit` +/// to `to_unit`. Both units are parsed as expressions (e.g. 
`kJ / mol / A^2`, +/// `(eV * u)^(1/2)`) and their dimensions must match. /// -/// Unit expressions are built from base units combined with `*`, `/`, `^`, -/// and parentheses. Unit lookup is case-insensitive, and whitespace is -/// ignored. For example: +/// @verbatim embed:rst:leading-asterisk /// -/// - `"kJ/mol"` -- energy per mole -/// - `"eV/Angstrom^3"` -- pressure -/// - `"(eV*u)^(1/2)"` -- momentum (fractional powers) -/// - `"Hartree/Bohr"` -- force in atomic units +/// .. seealso:: +/// +/// The general documentation for :ref:`units`, with the expression +/// syntax and list of supported base units. +/// +/// @endverbatim /// /// @param from_unit A null-terminated C string containing the unit to convert from. /// @param to_unit A null-terminated C string containing the unit to convert to.