Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/dev-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ jobs:
environment-file: devtools/conda-envs/nightly.yaml
create-args: >-
python=${{ matrix.python-version }}
pydantic=2

- name: Install nightly pytorch
run: |
Expand Down
4 changes: 1 addition & 3 deletions .github/workflows/examples-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@ env:

jobs:
example_tests:
name: Examples CI (${{ matrix.os }}, py-${{ matrix.python-version }}, rdkit=${{ matrix.include-rdkit }}, openeye=${{ matrix.include-openeye }}, dgl=${{ matrix.include-dgl }}), pydantic=${{ matrix.pydantic-version }}
name: Examples CI (${{ matrix.os }}, py-${{ matrix.python-version }}, rdkit=${{ matrix.include-rdkit }}, openeye=${{ matrix.include-openeye }}, dgl=${{ matrix.include-dgl }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.12"]
pydantic-version: ["2"]
include-rdkit: [true]
include-openeye: [false]
include-dgl: [true]
Expand All @@ -53,7 +52,6 @@ jobs:
environment-file: devtools/conda-envs/examples_env.yaml
create-args: >-
python=${{ matrix.python-version }}
pydantic=${{ matrix.pydantic-version }}

- name: Install package
run: |
Expand Down
4 changes: 1 addition & 3 deletions .github/workflows/gh-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@ env:

jobs:
main_tests:
name: CI (${{ matrix.os }}, py-${{ matrix.python-version }}, rdkit=${{ matrix.include-rdkit }}, openeye=${{ matrix.include-openeye }}, dgl=${{ matrix.include-dgl }}), pydantic=${{ matrix.pydantic-version }}
name: CI (${{ matrix.os }}, py-${{ matrix.python-version }}, rdkit=${{ matrix.include-rdkit }}, openeye=${{ matrix.include-openeye }}, dgl=${{ matrix.include-dgl }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macOS-latest, ubuntu-latest]
python-version: ["3.12"]
pydantic-version: ["1", "2"]
include-rdkit: [false, true]
include-openeye: [false, true]
include-dgl: [false, true]
Expand All @@ -56,7 +55,6 @@ jobs:
environment-file: devtools/conda-envs/test_env_dgl_${{ matrix.include-dgl }}.yaml
create-args: >-
python=${{ matrix.python-version }}
pydantic=${{ matrix.pydantic-version }}

- name: Install package
run: |
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- openff-toolkit-base >=0.18
- openff-nagl-models
- openff-units
- pydantic <3
- pydantic =2
- rdkit
- scipy

Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/examples_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies:
- openff-qcsubmit
- psi4 =1.9.1 # solver can't figure out 1.10+
- libint =2.9
- pydantic <3
- pydantic =2
- rdkit

# database
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/nightly.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- tqdm

# chemistry
- pydantic <3
- pydantic =2
- rdkit
- scipy
- openeye-toolkits
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env_dgl_false.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies:
- openff-toolkit-base >=0.18
- openff-nagl-models
- openff-units
- pydantic <3
- pydantic =2
- rdkit !=2024.03.6
- scipy
- ambertools
Expand Down
2 changes: 1 addition & 1 deletion devtools/conda-envs/test_env_dgl_true.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies:
- openff-toolkit-base >=0.18
- openff-nagl-models
- openff-units
- pydantic <3
- pydantic =2
- rdkit
- scipy
- ambertools
Expand Down
58 changes: 35 additions & 23 deletions openff/nagl/_base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,41 @@
from openff.units import unit


try:
from pydantic.v1 import BaseModel
except ImportError:
from pydantic import BaseModel
from pydantic import BaseModel, model_serializer, ConfigDict

def _encode_values(obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, unit.Quantity):
return obj.to_tuple()
if isinstance(obj, enum.Enum):
return obj.name
if isinstance(obj, pathlib.Path):
return str(obj)
if isinstance(obj, (tuple, set)):
return list(obj)
if isinstance(obj, dict):
return {k: _encode_values(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_encode_values(i) for i in obj]
return obj

class MutableModel(BaseModel):
"""
Base class that all classes should subclass.
"""

class Config:
validate_all = True
arbitrary_types_allowed = True
underscore_attrs_are_private = True
validate_assignment = True
extra = "forbid"
json_encoders = {
np.ndarray: lambda x: x.tolist(),
tuple: list,
set: list,
unit.Quantity: lambda x: x.to_tuple(),
enum.Enum: lambda x: x.name,
pathlib.Path: str,
}
model_config = ConfigDict(
validate_default=True,
arbitrary_types_allowed=True,
validate_assignment=True,
extra="forbid",
)

@model_serializer(mode="wrap")
def _serialize(self, handler):
data = handler(self)
return _encode_values(data)

def __init__(self, *args, **kwargs):
self.__pre_init__(*args, **kwargs)
Expand All @@ -43,8 +54,9 @@ def __pre_init__(self, *args, **kwargs):
def __post_init__(self, *args, **kwargs):
pass

def to_json(self):
return self.json(
def model_dump_json(self, **kwargs):
# sort_keys=True removed in v2, can kinda wrap it to maintain behavior
return json.dumps(self.model_dump(**kwargs),
sort_keys=True,
indent=2,
separators=(",", ": "),
Expand All @@ -64,7 +76,7 @@ def from_json(cls, string_or_file):
return validator(string_or_file)

def to_yaml(self, filename):
data = json.loads(self.json())
data = json.loads(self.model_dump_json())
with open(filename, "w") as f:
yaml.dump(data, f)

Expand All @@ -75,5 +87,5 @@ def from_yaml(cls, filename):
return cls(**data)

class ImmutableModel(MutableModel):
class Config(MutableModel.Config):
allow_mutation = False
# other options are **merged** with parent's config_dict
model_config = ConfigDict(frozen=True)
5 changes: 1 addition & 4 deletions openff/nagl/config/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from openff.nagl._base.base import ImmutableModel
from openff.nagl.utils._types import FromYamlMixin

try:
from pydantic.v1 import Field
except ImportError:
from pydantic import Field
from pydantic import Field

DiscriminatedTargetType = typing.Annotated[TargetType, Field(discriminator="name")]

Expand Down
10 changes: 4 additions & 6 deletions openff/nagl/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
AggregatorType = typing.Literal["mean", "gcn", "pool", "lstm", "sum"]
PostprocessType = typing.Literal["readout", "compute_partial_charges", "regularized_compute_partial_charges"]

try:
from pydantic.v1 import Field, validator
except ImportError:
from pydantic import Field, validator
from pydantic import Field, field_validator

class BaseLayer(ImmutableModel):
"""Base class for single layer in the neural network"""
Expand All @@ -36,15 +33,16 @@ class BaseLayer(ImmutableModel):
description="The dropout to apply after each layer"
)

@validator("activation_function", pre=True)
@field_validator("activation_function", mode="before")
@classmethod
def _validate_activation_function(cls, v):
return ActivationFunction._get_class(v)


class ConvolutionLayer(BaseLayer):
"""Configuration for a single convolution layer"""
aggregator_type: AggregatorType = Field(
default=None,
default=None, # this conflicts with the type annotation
description="The aggregator function to apply after each convolution"
)

Expand Down
5 changes: 1 addition & 4 deletions openff/nagl/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from openff.nagl._base.base import ImmutableModel
from openff.nagl.toolkits.openff import ensure_toolkit_registry

try:
from pydantic.v1 import Field
except ImportError:
from pydantic import Field
from pydantic import Field

if typing.TYPE_CHECKING:
from openff.toolkit.topology import Molecule
Expand Down
8 changes: 3 additions & 5 deletions openff/nagl/features/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
from ._base import CategoricalMixin, Feature
from ._utils import one_hot_encode

try:
from pydantic.v1 import validator, Field
except ImportError:
from pydantic import validator, Field
from pydantic import field_validator, Field

if typing.TYPE_CHECKING:
from openff.toolkit.topology import Molecule
Expand Down Expand Up @@ -110,7 +107,8 @@ class AtomHybridization(CategoricalMixin, AtomFeature):
]
"""The supported hybridization modes."""

@validator("categories", pre=True, each_item=True)
@field_validator("categories", mode="before")
@classmethod
def _validate_categories(cls, v):
if isinstance(v, str):
return HybridizationType[v.upper()]
Expand Down
7 changes: 2 additions & 5 deletions openff/nagl/features/bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@
from ._base import CategoricalMixin, Feature #, FeatureMeta
from ._utils import one_hot_encode

try:
from pydantic.v1 import Field
except ImportError:
from pydantic import Field
from pydantic import Field

if typing.TYPE_CHECKING:
from openff.nagl.toolkits.registry import NAGLToolkitRegistry
Expand Down Expand Up @@ -155,7 +152,7 @@ class BondOrder(CategoricalMixin, BondFeature):
"""
name: typing.Literal["bond_order"] = "bond_order"

categories = [1, 2, 3]
categories: typing.List[int] = [1, 2, 3]

def _encode(self, molecule, toolkit_registry: typing.Optional["NAGLToolkitRegistry"] = None) -> torch.Tensor:
return torch.vstack(
Expand Down
8 changes: 3 additions & 5 deletions openff/nagl/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from openff.nagl.toolkits.openff import ensure_toolkit_registry
from openff.nagl.utils._utils import is_iterable, potential_dict_to_list

try:
from pydantic.v1 import Field, validator
except ImportError:
from pydantic import Field, validator
from pydantic import Field, field_validator

if typing.TYPE_CHECKING:
from openff.toolkit.topology import Molecule
Expand Down Expand Up @@ -89,7 +86,8 @@ class AtomPropertiesLookupTable(BaseLookupTable):
description="The property lookup table"
)

@validator("properties", pre=True)
@field_validator("properties", mode="before")
@classmethod
def _convert_property_lookup_table(cls, v):
"""
Do two things:
Expand Down
6 changes: 3 additions & 3 deletions openff/nagl/nn/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ def __init__(

lookup_tables_dict = {}
for k, v in valid_lookup_tables.items():
v_ = v.dict()
v_ = v.model_dump()
v_["properties"] = dict(v_["properties"])
lookup_tables_dict[k] = v_

self.save_hyperparameters({
"config": config.dict(),
"chemical_domain": chemical_domain.dict(),
"config": config.model_dump(),
"chemical_domain": chemical_domain.model_dump(),
"lookup_tables": lookup_tables_dict,
})
self.config = config
Expand Down
15 changes: 7 additions & 8 deletions openff/nagl/tests/_base/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
import json
import textwrap

try:
from pydantic.v1 import Field, validator
except ImportError:
from pydantic import Field, validator
from pydantic import field_validator

class TestMutableModel:

Expand All @@ -19,11 +16,13 @@ class Model(MutableModel):
tuple_type: tuple
unit_type: unit.Quantity

@validator("np_array_type", pre=True)
@field_validator("np_array_type", mode="before")
@classmethod
def _validate_np_array_type(cls, v):
return np.asarray(v)

@validator("unit_type", pre=True)

@field_validator("unit_type", mode="before")
@classmethod
def _validate_unit_type(cls, v):
if not isinstance(v, unit.Quantity):
return unit.Quantity.from_tuple(v)
Expand All @@ -42,7 +41,7 @@ def test_init(self):
def test_to_json(self):
arr = np.arange(10).reshape(2, 5)
model = self.Model(int_type=1, float_type=1.0, list_type=[1, 2, 3], np_array_type=arr, tuple_type=(1, 2, 3), unit_type=unit.Quantity(1.0, "angstrom"))
json_dict = json.loads(model.to_json())
json_dict = json.loads(model.model_dump_json())
expected = {
"int_type": 1,
"float_type": 1.0,
Expand Down
5 changes: 3 additions & 2 deletions openff/nagl/tests/training/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def get_required_columns(self) -> List[str]:
def evaluate_target(self, molecules, labels, predictions, readout_modules) -> "torch.Tensor":
return torch.tensor([0.0])

@pytest.mark.skip(reason="TODO")
def test_validate_metric(self):
input_text = '{"metric": "rmse", "name": "readout", "prediction_label": "charges", "target_label": "charges"}'
target = ReadoutTarget.parse_raw(input_text)
target = ReadoutTarget.model_validate_json(input_text)
assert isinstance(target.metric, RMSEMetric)

@pytest.mark.skip(reason="TODO")
def test_non_implemented_methods(self):
target = self.BaseTarget(name="base", metric="rmse", target_label="charges")
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -273,7 +275,6 @@ def test_single_molecule(self, dgl_methane):
readout_modules={},
)
assert torch.allclose(loss, torch.tensor([52.48571234]))


def test_multiple_molecules(self, dgl_batch):
charges = torch.cat([
Expand Down
Loading
Loading