diff --git a/docs/api-reference/atomic.md b/docs/api-reference/atomic.md index f5c01ad..5caa3a3 100644 --- a/docs/api-reference/atomic.md +++ b/docs/api-reference/atomic.md @@ -53,6 +53,8 @@ Atom definitions, element properties, electronic structure (SCF), orbital evalua ## Pretraining ```{eval-rst} +.. autoclass:: jaqmc.utils.atomic.pretrain.PretrainReferenceConfig + .. autofunction:: jaqmc.utils.atomic.pretrain.make_pretrain_log_amplitude .. autofunction:: jaqmc.utils.atomic.pretrain.make_pretrain_loss ``` diff --git a/docs/guide/estimators/ecp.md b/docs/guide/estimators/ecp.md index dd1e4fe..69c5161 100644 --- a/docs/guide/estimators/ecp.md +++ b/docs/guide/estimators/ecp.md @@ -44,6 +44,7 @@ where $r = |\mathbf{r} - \mathbf{R}|$ is the electron-atom distance, $\mathbf{r} ## See also -- The ECP estimator is automatically added when `ecp` is set in the system configuration. See [Basis Sets and ECPs](#molecule-basis-sets-and-ecps). +- The ECP estimator is automatically added when `ecp` is set in the system + configuration. See the molecule guide's [ECP setup](#molecule-ecps). - Configuration: [Molecule](#molecule-estimators), [Solid](#solid-estimators) - API: {class}`~jaqmc.estimator.ecp.estimator.ECPEnergy` diff --git a/docs/systems/molecule/index.md b/docs/systems/molecule/index.md index 4b8fdb4..2cd1c90 100644 --- a/docs/systems/molecule/index.md +++ b/docs/systems/molecule/index.md @@ -5,7 +5,8 @@ boundary conditions. Most runs start from a YAML definition and a single `jaqmc molecule train` command. JaQMC then follows the standard molecular workflow: -1. **Hartree-Fock (HF)** computes reference orbitals with PySCF. +1. **Hartree-Fock (HF)** computes a reference electronic-structure solution with + PySCF. 2. **Pretraining** matches the neural wavefunction to those orbitals. 3. **VMC training** performs the main energy optimization. 
@@ -39,6 +40,12 @@ system: electron_spins: [5, 5] # [n_up, n_down] ``` +`electron_spins` gives `[n_up, n_down]` for the electrons included in the QMC +simulation. This water example is all-electron, so `[5, 5]` includes all ten +electrons. If you later add an ECP, leave out the core electrons it replaces; +`electron_spins` should count only the valence electrons JaQMC samples +explicitly. + Then run training: ```bash @@ -76,14 +83,14 @@ that generate the underlying configuration for you. ### Single Atoms For a single atom, `system.module=atom` is a shortcut. You provide the element -symbol and optional HF settings, and JaQMC fills in the matching electron spin -configuration automatically. +symbol, and JaQMC fills in the matching electron spin configuration +automatically. By default it uses the all-electron count; when `system.ecp` is +set, it uses the valence count instead. ```yaml system: module: atom symbol: Li # Element symbol (H, He, Li, Be, ...) - basis: sto-3g # Basis set for SCF initialization # ecp: ccecp # Optional: effective core potential ``` @@ -96,8 +103,9 @@ jaqmc molecule train --yml atom_li.yml workflow.save_path=./runs/atom_li ### Diatomic Molecules For common two-atom systems, `system.module=diatomic` is a shortcut. You provide -the chemical formula, bond length, and optional total spin; JaQMC places the -atoms along the z-axis and computes `electron_spins` for you. +the chemical formula, bond length, and optional spin for the simulated +electrons. JaQMC places the atoms along the z-axis and computes +`electron_spins` for you. ```yaml system: @@ -105,8 +113,7 @@ system: formula: LiH # Chemical formula (H2, LiH, N2, ClF, ...) 
bond_length: 3.015 # Distance between atoms unit: bohr # Length unit for bond_length - spin: 0 # n_up - n_down for the full molecule - basis: cc-pvdz + spin: 0 # n_up - n_down for electrons being simulated ``` Save as `li_h_diatomic.yml`, then run: @@ -115,36 +122,105 @@ Save as `li_h_diatomic.yml`, then run: jaqmc molecule train --yml li_h_diatomic.yml workflow.save_path=./runs/li_h_diatomic ``` -(molecule-basis-sets-and-ecps)= -## Basis Sets and ECPs +(molecule-ecps)= +## Effective core potentials + +Most examples above are all-electron calculations: JaQMC represents every +electron in the molecule explicitly. For heavier elements, you may instead +replace core electrons with an effective core potential (ECP). The core +electrons no longer appear as QMC electrons; their effect enters through the +pseudopotential, while JaQMC samples the remaining valence electrons. + +Enable an ECP by setting `system.ecp`: + +```yaml +system: + ecp: ccecp +``` + +Use an ECP designed for correlated many-body calculations rather than a +DFT-only pseudopotential. The correlation-consistent ECP family, `ccecp`, is the +usual choice for QMC runs. + +Once an ECP is enabled, `electron_spins` describes the electrons being sampled, +not the full electron count of the physical atoms. The `atom` and `diatomic` +shortcuts use `system.ecp` to choose the valence count automatically. If you +define `atoms` and `electron_spins` directly, set `electron_spins` to the +valence-electron system you want to simulate. + +For mixed systems, apply ECPs only to the elements that need them: + +```yaml +system: + ecp: + Fe: ccecp +``` + +(molecule-pretrain-reference)= +## Pretrain reference settings -The `basis` parameter controls the basis set used for the HF calculation. Any basis set supported by PySCF works: +`pretrain.reference.*` configures the PySCF Hartree-Fock calculation used to +generate the target orbitals for pretraining. 
In most runs, the basis is the +only reference setting you need to choose. The default is cc-pVDZ, and you can +change it with: -- Minimal: `sto-3g` (default, fast) -- Split-valence: `6-31g`, `6-311g` -- Correlation-consistent: `cc-pvdz`, `cc-pvtz`, `cc-pvqz` +```yaml +pretrain: + reference: + basis: sto-3g +``` -For heavy elements (transition metals, lanthanides), use an effective core potential (ECP) to replace core electrons with a pseudopotential, reducing the number of electrons treated explicitly: +If the system uses an ECP, choose a pretrain basis that matches that +pseudopotential. For example, with ccECP use the corresponding ccECP basis +family: ```yaml system: module: atom symbol: Fe - basis: ccecpccpvdz ecp: ccecp +pretrain: + reference: + basis: ccecpccpvdz ``` -Both `basis` and `ecp` can be specified per element: +For mixed systems, keep the same per-element split between the physical system +and the HF reference: put ECPs in `system.ecp`, and put matching PySCF basis +choices in `pretrain.reference.basis`. ```yaml system: - basis: - Fe: ccecpccpvdz - O: cc-pvdz ecp: Fe: ccecp +pretrain: + reference: + basis: + Fe: ccecpccpvdz + O: cc-pvdz ``` +When the HF calculation itself needs tuning, use the `pretrain.reference.*` +block for PySCF solver settings. JaQMC supports +`pretrain.reference.method` (`UHF` or `RHF`) and forwards additional keys to the +selected PySCF mean-field object. + +```yaml +pretrain: + reference: + method: RHF + basis: cc-pvdz + conv_tol: 1.0e-10 + max_cycle: 200 + diis_space: 12 +``` + +Use these extra keys for SCF convergence and solver behavior tuning, such as +`conv_tol`, `max_cycle`, and related PySCF options. If a key is not supported by +the selected PySCF object, JaQMC ignores it and logs a warning. + +For authoritative key definitions and defaults under `pretrain.reference.*`, see +. 
+ ## Estimators The training stage computes energy from several components: kinetic energy, diff --git a/docs/systems/molecule/train.md b/docs/systems/molecule/train.md index b8e70b6..e9e49f0 100644 --- a/docs/systems/molecule/train.md +++ b/docs/systems/molecule/train.md @@ -185,6 +185,16 @@ Initializes the neural network to approximate Hartree-Fock orbitals before VMC training. It uses the same run, sampler, and writer schemas as the train stage, but with a different optimizer default and a workflow-wired supervised loss. +### Reference (`pretrain.reference.*`) + +The Hartree-Fock reference is the PySCF calculation JaQMC uses to generate the +target orbitals for pretraining. Most runs can keep the default settings. + +```{eval-rst} +.. config-defaults:: jaqmc.app.molecule.config.base.MoleculePretrainReferenceConfig + :prefix: pretrain.reference +``` + ### Run options (`pretrain.run.*`) ```{eval-rst} diff --git a/docs/systems/solid/eval.md b/docs/systems/solid/eval.md index 01d3dff..baa800d 100644 --- a/docs/systems/solid/eval.md +++ b/docs/systems/solid/eval.md @@ -31,6 +31,17 @@ are identical to the [training system config](#solid-train-system). Must match the training run. The effective defaults and built-in module choices are identical to the [training wavefunction config](#solid-train-wf). +## Reference (`reference.*`) + +The Hartree-Fock reference is the PySCF calculation JaQMC uses when it needs +reference orbitals or related setup from that calculation. It is recommended to +set the values to match the reference configuration used during training. + +```{eval-rst} +.. 
config-defaults:: jaqmc.app.solid.config.base.SolidPretrainReferenceConfig + :prefix: reference +``` + ## Run Options (`run.*`) Evaluation reuses the same checkpointing and sampling controls as training, but diff --git a/docs/systems/solid/index.md b/docs/systems/solid/index.md index 1bec9e7..868eae6 100644 --- a/docs/systems/solid/index.md +++ b/docs/systems/solid/index.md @@ -5,7 +5,8 @@ boundary conditions. Most runs start from a YAML definition and a single `jaqmc solid train` command. JaQMC then follows the same three-stage workflow used for [molecules](../molecule/index.md): -1. **Hartree-Fock (HF)** computes reference orbitals with PySCF. +1. **Hartree-Fock (HF)** computes a reference electronic-structure solution with + PySCF. 2. **Pretraining** matches the neural wavefunction to those orbitals. 3. **VMC training** performs the main energy optimization. @@ -47,9 +48,15 @@ system: - symbol: H coords: [3.78, 3.78, 3.78] electron_spins: [2, 2] # [n_up, n_down] per primitive cell - basis: sto-3g ``` +`electron_spins` is counted per primitive cell and describes the electrons JaQMC +samples explicitly. In an all-electron solid, that is the full electron count +per primitive cell. With an ECP, core electrons are replaced by the +pseudopotential, so `electron_spins` should count only the valence electrons. +If you later expand to a supercell, JaQMC multiplies these primitive-cell counts +by the number of primitive cells in the supercell. + Then run training: ```bash @@ -78,7 +85,9 @@ that generate the underlying configuration for you. For FCC rock-salt structures such as LiH or NaCl, `system.module=rock_salt` is a shortcut. You provide the species and lattice constant, and JaQMC builds the -primitive cell and fills in the corresponding electron counts automatically. +primitive cell and fills in the corresponding electron counts automatically. It +uses all-electron counts by default; when `system.ecp` is set, it uses valence +counts instead. 
```yaml system: @@ -88,7 +97,6 @@ system: lattice_constant: 4.0 # in angstrom by default unit: angstrom # or "bohr" # supercell: [2, 2, 2] # Optional diagonal supercell shorthand - basis: sto-3g ``` Save as `rock_salt.yml`, then run: @@ -100,8 +108,9 @@ jaqmc solid train --yml rock_salt.yml workflow.save_path=./runs/rock_salt ### Two-Atom Chain For simple one-dimensional test systems, `system.module=two_atom_chain` is a -shortcut. You provide the element, bond length, and optional spin; JaQMC builds -a primitive cell with two atoms along the chain direction. +shortcut. You provide the element, bond length, and optional spin for the +simulated electrons. JaQMC builds a primitive cell with two atoms along the +chain direction. ```yaml system: @@ -111,7 +120,6 @@ system: unit: bohr # or "angstrom" spin: 0 # n_up - n_down per primitive cell # supercell: 4 # Optional repetition along the chain direction - basis: sto-3g ``` Save as `two_atom_chain.yml`, then run: @@ -120,8 +128,56 @@ Save as `two_atom_chain.yml`, then run: jaqmc solid train --yml two_atom_chain.yml workflow.save_path=./runs/two_atom_chain ``` -Basis sets and ECPs work the same as for -[molecules](#molecule-basis-sets-and-ecps). +(solid-ecps)= +## Effective core potentials + +Solids use the same ECP mechanism as molecules: core electrons are replaced by a +pseudopotential, and JaQMC samples only the remaining valence electrons. Set +`system.ecp` to apply an ECP. A mapping lets you apply it only to selected +elements: + +```yaml +system: + ecp: + Li: ccecp +``` + +The `rock_salt` and `two_atom_chain` shortcuts use `system.ecp` to choose +valence electron counts automatically. If you define `atoms` and +`electron_spins` directly, set `electron_spins` to the valence-electron count per +primitive cell. + +See for the broader ECP setup guidance. + +(solid-pretrain-reference)= +## Pretrain reference settings + +`pretrain.reference.*` configures the PySCF Hartree-Fock calculation used for +pretraining. 
For most solid runs, the basis is the only reference setting you +need to choose. The default is cc-pVDZ, and you can change it with: + +```yaml +pretrain: + reference: + basis: sto-3g +``` + +If the system uses an ECP, choose a pretrain basis that matches that +pseudopotential: + +```yaml +system: + ecp: + Li: ccecp +pretrain: + reference: + basis: + Li: ccecpccpvdz + H: cc-pvdz +``` + +The available reference settings are shared with molecule runs; see + for the detailed discussion. ## Supercell Expansion diff --git a/docs/systems/solid/train.md b/docs/systems/solid/train.md index aa3d02c..a01c394 100644 --- a/docs/systems/solid/train.md +++ b/docs/systems/solid/train.md @@ -172,6 +172,16 @@ Initializes the neural network to approximate Hartree-Fock orbitals before VMC training. It uses the same run, sampler, and writer schemas as the train stage, but with a different optimizer default and a workflow-wired supervised loss. +### Reference (`pretrain.reference.*`) + +The Hartree-Fock reference is the PySCF calculation JaQMC uses to generate the +target orbitals for pretraining. Most runs can keep the default settings. + +```{eval-rst} +.. config-defaults:: jaqmc.app.solid.config.base.SolidPretrainReferenceConfig + :prefix: pretrain.reference +``` + ### Run options (`pretrain.run.*`) ```{eval-rst} diff --git a/src/jaqmc/app/molecule/config/__init__.py b/src/jaqmc/app/molecule/config/__init__.py index 2680d3b..29f9a28 100644 --- a/src/jaqmc/app/molecule/config/__init__.py +++ b/src/jaqmc/app/molecule/config/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) 2025-2026 Bytedance Ltd. 
and/or its affiliates # SPDX-License-Identifier: Apache-2.0 -from .base import MoleculeConfig +from .base import MoleculeConfig, MoleculePretrainReferenceConfig -__all__ = ["MoleculeConfig"] +__all__ = ["MoleculeConfig", "MoleculePretrainReferenceConfig"] diff --git a/src/jaqmc/app/molecule/config/atom.py b/src/jaqmc/app/molecule/config/atom.py index e908b01..bf2d688 100644 --- a/src/jaqmc/app/molecule/config/atom.py +++ b/src/jaqmc/app/molecule/config/atom.py @@ -11,7 +11,6 @@ def atom_config( symbol: str = "H", electron_init_width: float = 1.0, - basis: str = "sto-3g", ecp: str | None = None, ): """Create a MoleculeConfig for a single atom. @@ -19,7 +18,6 @@ def atom_config( Args: symbol: Element symbol (e.g., "H", "Li", "Fe"). electron_init_width: Width of Gaussian for electron initialization. - basis: Basis set name. ecp: Effective core potential name. Can be None (no ECP) or a string (e.g., "ccecp"). @@ -41,6 +39,5 @@ def atom_config( atoms=[atom], electron_spins=electron_spins, electron_init_width=electron_init_width, - basis=basis, ecp=ecp, ) diff --git a/src/jaqmc/app/molecule/config/base.py b/src/jaqmc/app/molecule/config/base.py index e178328..96ceee3 100644 --- a/src/jaqmc/app/molecule/config/base.py +++ b/src/jaqmc/app/molecule/config/base.py @@ -2,12 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import field +from typing import Literal from jaqmc.utils.atomic import Atom from jaqmc.utils.atomic.atomic_system import AtomicSystemConfig +from jaqmc.utils.atomic.pretrain import PretrainReferenceConfig from jaqmc.utils.config import configurable_dataclass -__all__ = ["MoleculeConfig"] +__all__ = ["MoleculeConfig", "MoleculePretrainReferenceConfig"] @configurable_dataclass @@ -16,3 +18,9 @@ class MoleculeConfig(AtomicSystemConfig): electron_spins: tuple[int, int] = (1, 0) fixed_spins_per_atom: list[tuple[int, int]] | None = None electron_init_width: float = 1.0 + + +@configurable_dataclass +class 
MoleculePretrainReferenceConfig(PretrainReferenceConfig): + method: Literal["RHF", "UHF"] = "UHF" + "Variants of Hartree-Fock method." diff --git a/src/jaqmc/app/molecule/config/diatomic.py b/src/jaqmc/app/molecule/config/diatomic.py index 33557ec..e9594b0 100644 --- a/src/jaqmc/app/molecule/config/diatomic.py +++ b/src/jaqmc/app/molecule/config/diatomic.py @@ -51,7 +51,6 @@ def diatomic_config( formula: str = "H2", bond_length: float = 1.4, unit: LengthUnit = LengthUnit.bohr, - basis: str | dict[str, str] = "sto-3g", ecp: str | dict[str, str] | None = None, spin: int = 0, electron_init_width: float = 1.0, @@ -66,8 +65,6 @@ def diatomic_config( bond_length: Distance between the two atoms. unit: Length unit for ``bond_length`` and atom coordinates. Either ``"bohr"`` or ``"angstrom"``. - basis: Basis set name, or per-element mapping - (e.g., ``{"Li": "ccecpccpvdz", "H": "cc-pvdz"}``). ecp: Effective core potential specification. Can be ``None`` (all-electron), a string (e.g., ``"ccecp"``), or a per-element mapping (e.g., ``{"Li": "ccecp"}``). 
@@ -117,6 +114,5 @@ def diatomic_config( atoms=atoms, electron_spins=(n_alpha, n_beta), electron_init_width=electron_init_width, - basis=basis, ecp=ecp, ) diff --git a/src/jaqmc/app/molecule/workflow.py b/src/jaqmc/app/molecule/workflow.py index d52a259..faa08c7 100644 --- a/src/jaqmc/app/molecule/workflow.py +++ b/src/jaqmc/app/molecule/workflow.py @@ -8,8 +8,6 @@ import numpy as np -from jaqmc.app.molecule.config import MoleculeConfig -from jaqmc.app.molecule.data import data_init from jaqmc.estimator import EstimatorLike from jaqmc.estimator.density import CartesianAxis, CartesianDensity from jaqmc.estimator.ecp import ECPEnergy @@ -28,6 +26,8 @@ from jaqmc.workflow.stage.vmc import VMCWorkStage from jaqmc.workflow.vmc import VMCWorkflow +from .config import MoleculeConfig, MoleculePretrainReferenceConfig +from .data import data_init from .hamiltonian import potential_energy from .wavefunction import MoleculeWavefunction @@ -45,7 +45,12 @@ def default_preset(cls) -> dict[str, Any]: return { "pretrain": { "run": {"iterations": 2_000}, - "optim": {"learning_rate": {"rate": 3e-4}}, + "optim": { + "learning_rate": { + "module": "jaqmc.optimizer.schedule:Constant", + "rate": 3e-4, + } + }, }, "train": { "run": {"iterations": 200_000}, @@ -58,7 +63,8 @@ def __init__(self, cfg: ConfigManager) -> None: system_config, wf = configure_system(cfg) nspins = system_config.electron_spins - self.scf = make_scf(system_config) + pretrain_config = cfg.get("pretrain.reference", MoleculePretrainReferenceConfig) + self.scf = make_scf(pretrain_config, system_config) self.data_init = partial(data_init, system_config) sampler = cfg.get("sampler", MCMCSampler) @@ -66,7 +72,9 @@ def __init__(self, cfg: ConfigManager) -> None: orbitals_fn=wf.orbitals, scf=self.scf, nspins=nspins, full_det=wf.full_det ) pretrain_f_log_amplitude = make_pretrain_log_amplitude( - wf.logpsi, lambda data: self.scf.eval_slater(data.electrons, nspins)[1] + wf.logpsi, + lambda data: 
self.scf.eval_slater(data.electrons, nspins)[1], + scf_fraction=pretrain_config.sample_fraction, ) pretrain = VMCWorkStage.builder(cfg.scoped("pretrain"), wf) @@ -79,7 +87,7 @@ def __init__(self, cfg: ConfigManager) -> None: train.configure_sample_plan(wf.logpsi, {"electrons": sampler}) train.configure_optimizer(default=KFACOptimizer, f_log_psi=wf.logpsi) estimators = make_estimators( - cfg, wf, self.scf, system_config, always_enable_energy=True + cfg, wf, system_config, self.scf._mol._ecp, always_enable_energy=True ) train.configure_estimators(**estimators) train.configure_loss_grads(f_log_psi=wf.logpsi) @@ -99,21 +107,14 @@ def __init__(self, cfg: ConfigManager) -> None: self.data_init = partial(data_init, system_config) - nspins = system_config.electron_spins - core_electrons = get_core_electrons(system_config.atoms, system_config.ecp) - scf = MolecularSCF( - system_config.atoms, - nspins, - basis=system_config.basis, - ecp=system_config.ecp, - core_electrons=core_electrons, - ) - evaluation = EvaluationWorkStage.builder(cfg, wf, name="evaluation") sampler = cfg.get("sampler", MCMCSampler) evaluation.configure_sample_plan(wf.logpsi, {"electrons": sampler}) + + # Just a simple SCF object to get ecp coefficients + scf = make_scf(MoleculePretrainReferenceConfig(), system_config) eval_estimators: dict[str, EstimatorLike] = make_estimators( - cfg, wf, scf, system_config + cfg, wf, system_config, scf._mol._ecp ) evaluation.configure_estimators(**eval_estimators) @@ -140,21 +141,27 @@ def configure_system( return system_config, wf -def make_scf(system_config: MoleculeConfig) -> MolecularSCF: +def make_scf( + pretrain_config: MoleculePretrainReferenceConfig, system_config: MoleculeConfig +) -> MolecularSCF: + restricted = pretrain_config.method == "RHF" return MolecularSCF( system_config.atoms, system_config.electron_spins, - basis=system_config.basis, + basis=pretrain_config.basis, + restricted=restricted, ecp=system_config.ecp, 
core_electrons=get_core_electrons(system_config.atoms, system_config.ecp), + verbose=pretrain_config.verbose, + pyscf_options=pretrain_config.extra, ) def make_estimators( cfg: ConfigManagerLike, wf: MoleculeWavefunction, - scf: MolecularSCF, system_config: MoleculeConfig, + ecp_coefficients: dict[str, Any] | None = None, always_enable_energy: bool = False, ) -> dict[str, EstimatorLike]: estimators: dict[str, EstimatorLike] = {} @@ -163,12 +170,12 @@ def make_estimators( estimators["kinetic"] = cfg.get( "estimators.energy.kinetic", EuclideanKinetic(f_log_psi=wf.logpsi) ) - if scf._mol._ecp: - logger.info("ECP enabled for elements: %s", list(scf._mol._ecp.keys())) + if ecp_coefficients: + logger.info("ECP enabled for elements: %s", list(ecp_coefficients.keys())) estimators["ecp"] = cfg.get( "estimators.energy.ecp", ECPEnergy( - ecp_coefficients=scf._mol._ecp, + ecp_coefficients=ecp_coefficients, atom_symbols=[atom.symbol for atom in system_config.atoms], phase_logpsi=wf.phase_logpsi, ), diff --git a/src/jaqmc/app/solid/config/__init__.py b/src/jaqmc/app/solid/config/__init__.py index 2190f80..eaffefa 100644 --- a/src/jaqmc/app/solid/config/__init__.py +++ b/src/jaqmc/app/solid/config/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) 2025-2026 Bytedance Ltd. 
and/or its affiliates # SPDX-License-Identifier: Apache-2.0 -from .base import SolidConfig +from .base import SolidConfig, SolidPretrainReferenceConfig -__all__ = ["SolidConfig"] +__all__ = ["SolidConfig", "SolidPretrainReferenceConfig"] diff --git a/src/jaqmc/app/solid/config/base.py b/src/jaqmc/app/solid/config/base.py index b7f90a7..0a9ccf2 100644 --- a/src/jaqmc/app/solid/config/base.py +++ b/src/jaqmc/app/solid/config/base.py @@ -2,14 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import field -from typing import Any +from typing import Any, Literal import numpy as np from jaqmc.utils.atomic.atomic_system import AtomicSystemConfig +from jaqmc.utils.atomic.pretrain import PretrainReferenceConfig from jaqmc.utils.config import configurable_dataclass -__all__ = ["SolidConfig"] +__all__ = ["SolidConfig", "SolidPretrainReferenceConfig"] @configurable_dataclass @@ -60,3 +61,9 @@ def __post_init__(self): self.scale = round(np.linalg.det(supercell_arr)) else: self.supercell_lattice = lattice_arr + + +@configurable_dataclass +class SolidPretrainReferenceConfig(PretrainReferenceConfig): + method: Literal["KRHF", "KUHF"] = "KUHF" + "Variants of Hartree-Fock method." diff --git a/src/jaqmc/app/solid/config/rock_salt.py b/src/jaqmc/app/solid/config/rock_salt.py index 687d7c6..5effefa 100644 --- a/src/jaqmc/app/solid/config/rock_salt.py +++ b/src/jaqmc/app/solid/config/rock_salt.py @@ -32,7 +32,6 @@ def rock_salt_config( lattice_constant: float = 4.0, unit: LengthUnit = LengthUnit.angstrom, supercell: list[int] | None = None, - basis: str = "sto-3g", ecp: Any = None, electron_init_width: float = 1.0, ): @@ -44,7 +43,6 @@ def rock_salt_config( lattice_constant: Lattice constant. unit: Unit of the lattice constant ('angstrom' or 'bohr'). supercell: Supercell dimensions [nx, ny, nz]. Defaults to [1, 1, 1]. - basis: Basis set name for HF pretrain. ecp: Effective core potential specification. 
Can be ``None`` (all-electron), a string (e.g., ``"ccecp"``), or a per-element mapping (e.g., ``{"Li": "ccecp"}``). @@ -92,7 +90,6 @@ def rock_salt_config( [0, 0, supercell[2]], ], electron_spins=(n_up, n_down), - basis=basis, ecp=ecp, electron_init_width=electron_init_width, ) diff --git a/src/jaqmc/app/solid/config/two_atom_chain.py b/src/jaqmc/app/solid/config/two_atom_chain.py index f57d443..4ef04e2 100644 --- a/src/jaqmc/app/solid/config/two_atom_chain.py +++ b/src/jaqmc/app/solid/config/two_atom_chain.py @@ -26,7 +26,6 @@ def two_atom_chain( supercell: int = 1, vacuum_separation: float = 100.0, spin: int = 0, - basis: str = "sto-3g", ecp: Any = None, electron_init_width: float = 1.0, ): @@ -40,7 +39,6 @@ def two_atom_chain( vacuum_separation: Lattice constant in y and z directions (in Bohr) to isolate the 1D chain. spin: Total spin polarization (n_up - n_down) for the primitive cell. - basis: Basis set name for HF pretrain. ecp: Effective core potential specification. Can be ``None`` (all-electron), a string (e.g., ``"ccecp"``), or a per-element mapping (e.g., ``{"Li": "ccecp"}``). 
@@ -91,7 +89,6 @@ def two_atom_chain( lattice_vectors=lattice_vectors, supercell_matrix=[[supercell, 0, 0], [0, 1, 0], [0, 0, 1]], electron_spins=(n_up, n_down), - basis=basis, ecp=ecp, electron_init_width=electron_init_width, ) diff --git a/src/jaqmc/app/solid/workflow.py b/src/jaqmc/app/solid/workflow.py index bacf134..362dee5 100644 --- a/src/jaqmc/app/solid/workflow.py +++ b/src/jaqmc/app/solid/workflow.py @@ -30,7 +30,7 @@ from jaqmc.workflow.stage.vmc import VMCWorkStage from jaqmc.workflow.vmc import VMCWorkflow -from .config import SolidConfig +from .config import SolidConfig, SolidPretrainReferenceConfig from .data import data_init from .hamiltonian import PotentialEnergy from .wavefunction import SolidWavefunction @@ -49,7 +49,12 @@ def default_preset(cls) -> dict[str, Any]: return { "pretrain": { "run": {"iterations": 2_000}, - "optim": {"learning_rate": {"rate": 3e-4}}, + "optim": { + "learning_rate": { + "module": "jaqmc.optimizer.schedule:Constant", + "rate": 3e-4, + } + }, }, "train": { "run": {"iterations": 200_000}, @@ -59,21 +64,25 @@ def default_preset(cls) -> dict[str, Any]: def __init__(self, cfg: ConfigManager) -> None: super().__init__(cfg) - system_config, wf, scf, sampling_proposal = configure_system(cfg) + system_config, wf, sampling_proposal = configure_system(cfg) + + pretrain_config = cfg.get("pretrain.reference", SolidPretrainReferenceConfig) + self.scf = make_scf(pretrain_config, system_config) nspins = ( system_config.electron_spins[0] * system_config.scale, system_config.electron_spins[1] * system_config.scale, ) - self.scf = scf self.wf = wf self.data_init = partial(data_init, system_config) loss_estimator = make_pretrain_loss( - orbitals_fn=wf.orbitals, scf=scf, nspins=nspins, full_det=wf.full_det + orbitals_fn=wf.orbitals, scf=self.scf, nspins=nspins, full_det=wf.full_det ) f_log_amplitude = make_pretrain_log_amplitude( - wf.logpsi, lambda data: scf.eval_slater(data.electrons, nspins).real + wf.logpsi, + lambda data: 
self.scf.eval_slater(data.electrons, nspins).real, + scf_fraction=pretrain_config.sample_fraction, ) sampler = cfg.get("sampler", MCMCSampler(sampling_proposal=sampling_proposal)) @@ -87,7 +96,7 @@ def __init__(self, cfg: ConfigManager) -> None: train.configure_sample_plan(wf.logpsi, {"electrons": sampler}) train.configure_optimizer(default=KFACOptimizer, f_log_psi=wf.logpsi) estimators = make_estimators( - cfg, wf, scf, system_config, always_enable_energy=True + cfg, wf, system_config, self.scf._cell._ecp, always_enable_energy=True ) train.configure_estimators(**estimators) train.configure_loss_grads(f_log_psi=wf.logpsi) @@ -104,17 +113,19 @@ class SolidEvalWorkflow(EvaluationWorkflow): def __init__(self, cfg: ConfigManager) -> None: super().__init__(cfg) - system_config, wf, scf, sampling_proposal = configure_system(cfg) + system_config, wf, sampling_proposal = configure_system(cfg) - self.scf = scf self.wf = wf self.data_init = partial(data_init, system_config) evaluation = EvaluationWorkStage.builder(cfg, wf, name="evaluation") sampler = cfg.get("sampler", MCMCSampler(sampling_proposal=sampling_proposal)) evaluation.configure_sample_plan(wf.logpsi, {"electrons": sampler}) + + reference_config = cfg.get("reference", SolidPretrainReferenceConfig) + self.scf = make_scf(reference_config, system_config) eval_estimators: dict[str, EstimatorLike] = make_estimators( - cfg, wf, scf, system_config + cfg, wf, system_config, self.scf._cell._ecp ) evaluation.configure_estimators(**eval_estimators) self.evaluation_stage = evaluation.build() @@ -127,7 +138,7 @@ def run(self) -> None: def configure_system( cfg: ConfigManagerLike, -) -> tuple[SolidConfig, SolidWavefunction, PeriodicSCF, Callable]: +) -> tuple[SolidConfig, SolidWavefunction, Callable]: """Build the shared system objects for solid workflows. 
Returns: @@ -155,40 +166,46 @@ def configure_system( wf.primitive_lattice = lattice_vectors wf.simulation_lattice = supercell_lattice - # Compute k-points for PeriodicSCF + sampling_proposal = make_pbc_gaussian_proposal(supercell_lattice) + + if not isinstance(wf, Wavefunction) or not isinstance(wf, SolidWavefunction): + raise TypeError( + f"Wavefunction must implement SolidWavefunction, got {type(wf).__name__}" + ) + return system_config, wf, sampling_proposal + + +def make_scf( + pretrain_config: SolidPretrainReferenceConfig, system_config: SolidConfig +) -> PeriodicSCF: + S = jnp.array(system_config.supercell_matrix) - prim_rec_vecs = get_reciprocal_vectors(lattice_vectors) + prim_rec_vecs = get_reciprocal_vectors(jnp.array(system_config.lattice_vectors)) kpts_folding = get_supercell_kpts(S, prim_rec_vecs) twist = jnp.array(system_config.twist) - sim_rec_vecs = get_reciprocal_vectors(supercell_lattice) + sim_rec_vecs = get_reciprocal_vectors(jnp.array(system_config.supercell_lattice)) k_twist = jnp.dot(twist, sim_rec_vecs) kpts = kpts_folding + k_twist[None, :] - core_electrons = get_core_electrons(system_config.atoms, system_config.ecp) - scf = PeriodicSCF( - atoms=system_config.atoms, - nelectrons=system_config.electron_spins, - lattice_vectors=np.asarray(lattice_vectors), + return PeriodicSCF( + system_config.atoms, + system_config.electron_spins, + lattice_vectors=np.asarray(system_config.lattice_vectors), kpts=np.asarray(kpts), - basis=system_config.basis, + basis=pretrain_config.basis, + restricted=pretrain_config.method == "KRHF", ecp=system_config.ecp, - core_electrons=core_electrons, + core_electrons=get_core_electrons(system_config.atoms, system_config.ecp), + verbose=pretrain_config.verbose, + pyscf_options=pretrain_config.extra, ) - sampling_proposal = make_pbc_gaussian_proposal(supercell_lattice) - - if not isinstance(wf, Wavefunction) or not isinstance(wf, SolidWavefunction): - raise TypeError( - f"Wavefunction must implement SolidWavefunction, got 
{type(wf).__name__}" - ) - return system_config, wf, scf, sampling_proposal - def make_estimators( cfg: ConfigManagerLike, wf: SolidWavefunction, - scf: PeriodicSCF, system_config: SolidConfig, + ecp_coefficients: dict[str, Any] | None = None, always_enable_energy: bool = False, ) -> dict[str, EstimatorLike]: estimators: dict[str, EstimatorLike] = {} @@ -200,12 +217,12 @@ def make_estimators( estimators["kinetic"] = cfg.get( "estimators.energy.kinetic", EuclideanKinetic(f_log_psi=wf.logpsi) ) - if system_config.ecp is not None: - logger.info("ECP enabled for elements: %s", list(scf._cell._ecp.keys())) + if ecp_coefficients: + logger.info("ECP enabled for elements: %s", list(ecp_coefficients.keys())) estimators["ecp"] = cfg.get( "estimators.energy.ecp", ECPEnergy( - ecp_coefficients=scf._cell._ecp, + ecp_coefficients=ecp_coefficients, atom_symbols=[atom.symbol for atom in system_config.atoms] * system_config.scale, phase_logpsi=wf.phase_logpsi, diff --git a/src/jaqmc/utils/atomic/atomic_system.py b/src/jaqmc/utils/atomic/atomic_system.py index d7e3cb4..23db8a5 100644 --- a/src/jaqmc/utils/atomic/atomic_system.py +++ b/src/jaqmc/utils/atomic/atomic_system.py @@ -11,11 +11,10 @@ class AtomicSystemConfig: Args: atoms: List of atoms in the system. - electron_spins: Tuple of two integers representing the - number of up and down electrons. - basis: The basis set for Hartree-Fock pretrain. Can be a string - (e.g., "sto-3g", "ccecpccpvdz") or a dict mapping element - symbols to basis names (e.g., {"Fe": "ccecpccpvdz", "O": "cc-pvdz"}). + electron_spins: Tuple of two integers representing the number of + explicitly simulated up and down electrons. In all-electron systems + this is the full electron count; with ECPs it is the valence count + after core electrons are replaced by the pseudopotential. ecp: Effective core potential specification. 
Can be None (no ECP), a string (same ECP for all atoms, e.g., "ccecp"), or a dict mapping element symbols to ECP names (e.g., {"Fe": "ccecp"}). @@ -27,7 +26,6 @@ class AtomicSystemConfig: """ atoms: list[Atom] - basis: str | dict[str, str] = "sto-3g" ecp: str | dict[str, str] | None = None electron_spins: tuple[int, int] fixed_spins_per_atom: list[tuple[int, int]] | None = None diff --git a/src/jaqmc/utils/atomic/pretrain.py b/src/jaqmc/utils/atomic/pretrain.py index 4233055..76e6e8a 100644 --- a/src/jaqmc/utils/atomic/pretrain.py +++ b/src/jaqmc/utils/atomic/pretrain.py @@ -7,6 +7,7 @@ from typing import Any import jax +import serde from jax import numpy as jnp from jaqmc.array_types import Params, PRNGKey @@ -14,12 +15,33 @@ from jaqmc.estimator import FunctionEstimator from jaqmc.estimator.base import Estimator from jaqmc.utils import parallel_jax +from jaqmc.utils.config import configurable_dataclass from jaqmc.wavefunction import NumericWavefunctionEvaluate from jaqmc.wavefunction.base import WavefunctionEvaluate from .scf import MolecularSCF, PeriodicSCF +@configurable_dataclass +class PretrainReferenceConfig: + """Configuration for the Hartree-Fock reference used during pretraining. + + Args: + basis: The basis set for Hartree-Fock pretrain. Can be a string + (e.g., "sto-3g", "ccecpccpvdz") or a dict mapping element + symbols to basis names (e.g., {"Fe": "ccecpccpvdz", "O": "cc-pvdz"}). + sample_fraction: Mixing fraction for SCF during pretrain sampling. + (0.0 = pure NN, 1.0 = pure SCF.) + extra: Extra options for the PySCF mean-field object. + When specifying in CLI, all unknown/extra fields are captured. 
+ """ + + basis: str | Mapping[str, str] | None = "cc-pVDZ" + sample_fraction: float = 1.0 + verbose: int = 4 + extra: dict[str, Any] = serde.field(flatten=True, default_factory=dict) + + def make_pretrain_log_amplitude[DataT: Data]( log_psi_fn: WavefunctionEvaluate[DataT, jnp.ndarray], scf_log_amplitude_fn: Callable[[DataT], jnp.ndarray], diff --git a/src/jaqmc/utils/atomic/scf.py b/src/jaqmc/utils/atomic/scf.py index 3e23f0f..dac8600 100644 --- a/src/jaqmc/utils/atomic/scf.py +++ b/src/jaqmc/utils/atomic/scf.py @@ -27,6 +27,7 @@ import logging from collections.abc import Mapping, Sequence +from typing import Any import numpy as np import pyscf.gto @@ -75,6 +76,26 @@ def _extract_spin_blocks( return alpha_matrix, beta_matrix +def _slogdet_spin_block(matrix: NDArray) -> tuple[jnp.ndarray, jnp.ndarray]: + """``jnp.linalg.slogdet`` for one spin block. + + An empty block (shape ``(..., 0, 0)``) contributes determinant one + (``log|det| = 0``, ``sign = 1``). Avoiding ``slogdet`` on empty matrices also + sidesteps JAX ``diagonal``/``platform_dependent`` inconsistencies under manual + partitioning (see JaQMC ``qmc_batch_axis`` + JAX 0.6.x). + + Returns: + Same as ``jnp.linalg.slogdet``: ``(sign, log_abs_det)`` with batch shape + ``matrix.shape[:-2]``. + """ + if matrix.shape[-1] == 0: + leading = matrix.shape[:-2] + ones = jnp.ones(leading, dtype=matrix.dtype) + zeros = jnp.zeros(leading, dtype=matrix.dtype) + return ones, zeros + return jnp.linalg.slogdet(matrix) + + def _eval_slater_from_orbitals( alpha_matrix: NDArray, beta_matrix: NDArray ) -> tuple[jnp.ndarray, jnp.ndarray]: @@ -87,8 +108,8 @@ def _eval_slater_from_orbitals( Returns: Tuple of (sign, log_abs_det) for the Slater determinant. 
""" - sign_alpha, logdet_alpha = jnp.linalg.slogdet(alpha_matrix) - sign_beta, logdet_beta = jnp.linalg.slogdet(beta_matrix) + sign_alpha, logdet_alpha = _slogdet_spin_block(alpha_matrix) + sign_beta, logdet_beta = _slogdet_spin_block(beta_matrix) return sign_alpha * sign_beta, logdet_alpha + logdet_beta @@ -121,7 +142,9 @@ def __init__( ecp: str | Mapping[str, str] | None = None, core_electrons: Mapping[str, int] | None = None, pyscf_mol: pyscf.gto.Mole | None = None, - restricted: bool = True, + restricted: bool = False, + verbose: int = 4, + pyscf_options: Mapping[str, Any] | None = None, ): pyscf.lib.param.TMPDIR = None @@ -152,6 +175,7 @@ def __init__( self._mol.spin = nelectrons[0] - nelectrons[1] self._mol.charge = charge self._mol.ecp = ecp + self._mol.verbose = verbose self._mol.build() if self._mol.nelectron != sum(nelectrons): raise RuntimeError("PySCF molecule not consistent with QMC molecule.") @@ -166,6 +190,12 @@ def __init__( self.eval_aos = AtomicOrbitalEvaluator.from_pyscf(self._mol) self.restricted = restricted + for k, v in (pyscf_options or {}).items(): + if k not in self.mean_field._keys: + logger.warning("Ignoring option %s as it's not used by PySCF", k) + else: + setattr(self.mean_field, k, v) + def run(self, dm0: np.ndarray | None = None): """Runs the Hartree-Fock calculation. @@ -176,6 +206,7 @@ def run(self, dm0: np.ndarray | None = None): A pyscf scf object (i.e. pyscf.scf.rhf.RHF, pyscf.scf.uhf.UHF or pyscf.scf.rohf.ROHF depending on the spin and restricted settings). """ + logger.info("Start %s", type(self.mean_field).__name__) try: self.mean_field.kernel(dm0=dm0) except TypeError: @@ -185,6 +216,7 @@ def run(self, dm0: np.ndarray | None = None): ) # 1e solvers (e.g. uhf.HF1e) do not take any keyword arguments. 
self.mean_field.kernel() + logger.info("Complete %s", type(self.mean_field).__name__) return self.mean_field def eval_mos(self, positions: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: @@ -321,8 +353,10 @@ def __init__( ecp: str | Mapping[str, str] | None = None, core_electrons: Mapping[str, int] | None = None, pyscf_cell: pyscf.pbc.gto.Cell | None = None, - restricted: bool = True, + restricted: bool = False, rcut: float | None = None, + verbose: int = 4, + pyscf_options: Mapping[str, Any] | None = None, ): pyscf.lib.param.TMPDIR = None @@ -353,6 +387,7 @@ def __init__( self._cell.spin = nelectrons[0] - nelectrons[1] self._cell.charge = charge self._cell.ecp = ecp + self._cell.verbose = verbose self._cell.build() # Set up k-points @@ -372,6 +407,12 @@ def __init__( self.eval_aos = PBCAtomicOrbitalEvaluator.from_pyscf(self._cell, rcut=rcut) self._mo_coeff: tuple[list, list] | None = None + for k, v in (pyscf_options or {}).items(): + if k not in self.mean_field._keys: + logger.warning("Ignoring option %s as it's not used by PySCF", k) + else: + setattr(self.mean_field, k, v) + def run(self, dm0: np.ndarray | None = None): """Run the k-point HF calculation. 
diff --git a/tests/cli_dry_run_test.py b/tests/cli_dry_run_test.py index ec06fae..6c04b59 100644 --- a/tests/cli_dry_run_test.py +++ b/tests/cli_dry_run_test.py @@ -85,13 +85,14 @@ class CliDryRunCase: formula: LiH bond_length: 3.015 unit: bohr - basis: sto-3g workflow: batch_size: 4 wf: hidden_dims_single: [4, 4] hidden_dims_double: [2, 2] pretrain: + reference: + basis: sto-3g run: iterations: 1 train: @@ -108,7 +109,6 @@ class CliDryRunCase: system: module: atom symbol: Fe - basis: ccecpccpvdz ecp: ccecp workflow: batch_size: 4 @@ -116,6 +116,9 @@ class CliDryRunCase: hidden_dims_single: [4, 4] hidden_dims_double: [2, 2] pretrain: + reference: + basis: ccecpccpvdz + method: UHF run: iterations: 1 train: @@ -169,13 +172,14 @@ class CliDryRunCase: - symbol: H coords: [3.78, 3.78, 3.78] electron_spins: [2, 2] - basis: sto-3g workflow: batch_size: 4 wf: hidden_dims_single: [4, 4] hidden_dims_double: [2, 2] pretrain: + reference: + basis: sto-3g run: iterations: 1 train: @@ -195,13 +199,15 @@ class CliDryRunCase: symbol_b: H lattice_constant: 4.0 unit: angstrom - basis: sto-3g workflow: batch_size: 4 wf: hidden_dims_single: [4, 4] hidden_dims_double: [2, 2] pretrain: + reference: + basis: sto-3g + method: KRHF run: iterations: 1 train: @@ -268,6 +274,9 @@ class CliDryRunCase: wf: hidden_dims_single: [4, 4] hidden_dims_double: [2, 2] +reference: + basis: sto-3g + method: KRHF run: iterations: 1 estimators: diff --git a/tests/utils/atomic/scf_test.py b/tests/utils/atomic/scf_test.py index 3d37c9b..5b000cb 100644 --- a/tests/utils/atomic/scf_test.py +++ b/tests/utils/atomic/scf_test.py @@ -10,7 +10,7 @@ from jax import numpy as jnp from jaqmc.utils.atomic.atom import Atom -from jaqmc.utils.atomic.scf import PeriodicSCF +from jaqmc.utils.atomic.scf import MolecularSCF, PeriodicSCF @pytest.fixture(autouse=True) @@ -18,6 +18,39 @@ def no_temp_file(): pyscf.lib.param.TMPDIR = None +def test_molecular_scf_pyscf_options_applied(): + """User-supplied pyscf_options are applied 
to the PySCF mean-field object.""" + atoms = [Atom("H", (0.0, 0.0, 0.0))] + scf = MolecularSCF( + atoms, + (1, 0), + basis="sto-3g", + verbose=0, + # Listed in mean_field._keys (PySCF applies options only for those names). + pyscf_options={"init_guess_breaksym": False}, + ) + assert not scf.mean_field.init_guess_breaksym + + +def test_periodic_scf_pyscf_options_applied(): + """Same as test_molecular_scf_pyscf_options_applied for periodic HF.""" + atoms = [Atom("H", (0.0, 0.0, 0.0))] + latvec = np.eye(3) * 5.0 + recip_vecs = 2 * np.pi * np.linalg.inv(latvec).T + kpts = np.array([[0.0, 0.0, 0.0]]) @ recip_vecs + scf = PeriodicSCF( + atoms=atoms, + nelectrons=(1, 0), + basis="sto-3g", + lattice_vectors=latvec, + kpts=kpts, + restricted=False, + verbose=0, + pyscf_options={"init_guess_breaksym": False}, + ) + assert not scf.mean_field.init_guess_breaksym + + # k-point configurations: (name, fractional coordinates) KPOINT_CONFIGS = { "gamma": np.array([[0.0, 0.0, 0.0]]),