Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/metatrain/experimental/flashmd/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def forward(
)
)
else:
output_blocks.append(b)
output_blocks.append(b.copy())
return_dict[name] = TensorMap(
return_dict[name].keys, output_blocks
)
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/flashmd_symplectic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def forward(
)
)
else:
output_blocks.append(b)
output_blocks.append(b.copy())
return_dict[name] = TensorMap(
return_dict[name].keys, output_blocks
)
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/mace/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def add_additive_contributions(
)
)
else:
output_blocks.append(b)
output_blocks.append(b.copy())
values[name] = TensorMap(values[name].keys, output_blocks)

def supported_outputs(self) -> Dict[str, ModelOutput]:
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/experimental/phace/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def forward(
)
)
else:
output_blocks.append(b)
output_blocks.append(b.copy())
return_dict[name] = TensorMap(return_dict[name].keys, output_blocks)

# For atomic basis targets, sparsify to create blocks with "atom_type"
Expand Down
25 changes: 22 additions & 3 deletions src/metatrain/gap/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,17 @@ def forward(
for idx_key in range(len(self._species_labels)):
key = self._species_labels.entry(idx_key)
if soap_features.keys.position(key) is not None:
new_blocks.append(soap_features.block(key))
block = soap_features.block(key)
new_blocks.append(
TensorBlock(
values=block.values,
samples=block.samples,
components=block.components,
properties=block.properties,
)
)
else:
new_blocks.append(dummyblock)
new_blocks.append(dummyblock.copy())
soap_features = TensorMap(keys=self._species_labels, blocks=new_blocks)
soap_features = soap_features.keys_to_samples("center_type")
# here, we move to properties to use metatensor operations to aggregate
Expand All @@ -254,7 +262,18 @@ def forward(
soap_features = soap_features.keys_to_properties(
["neighbor_1_type", "neighbor_2_type"]
)
soap_features = TensorMap(self._keys, soap_features.blocks())
soap_features = TensorMap(
self._keys,
[
TensorBlock(
values=block.values,
samples=block.samples,
components=block.components,
properties=block.properties,
)
for block in soap_features.blocks()
],
)
output_key = list(outputs.keys())[0]
energies = self._subset_of_regressors_torch(soap_features)
return_dict = {output_key: energies}
Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def forward(
)
)
else:
output_blocks.append(b)
output_blocks.append(b.copy())
return_dict[name] = TensorMap(return_dict[name].keys, output_blocks)

# For atomic basis targets, sparsify to create blocks with "atom_type"
Expand Down
5 changes: 4 additions & 1 deletion src/metatrain/utils/data/target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,10 @@ def is_auxiliary_output(name: str) -> bool:
:return: `True` if the target is an auxiliary output, `False` otherwise.
"""
is_auxiliary = (
name == "features" or name == "energy_ensemble" or name.startswith("mtt::aux::")
name == "features"
or name == "feature"
or name == "energy_ensemble"
or name.startswith("mtt::aux::")
)
return is_auxiliary

Expand Down
2 changes: 1 addition & 1 deletion src/metatrain/utils/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def evaluate_model(
energy_targets_that_require_strain_gradients = []
for target_name in targets.keys():
# Check if the target is an energy:
if model_outputs[target_name].quantity == "energy":
if targets[target_name].quantity == "energy":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intended to match on energy, energy_ensemble, and energy_uncertaintyoutputs, or onlyenergy`?

in the second case, this could be done better using

base_target_name = target_name.split("/")[0]
if base_target_name == "energy"

i.e. checking the output name, handling both energy and energy/pbe0. This way we could also remove the quantity field from metatrain

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes but i am wondering that the custom energy names will be missed in this case, like mtt::U0

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There really should not be a point to use something like this nowadays =) This kind of naming was used before we got variants on the models outputs.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmmmm if we remove quantity, old options.yaml will sometimes not work as Sofiia says for example when you pass a target called mtt::U0 (which is what the tutorials do).

Although I agree that the default behavior is a bit strange, everything by default is quantity: "energy" (even spherical targets), it is set here:

CONF_TARGET_FIELDS = OmegaConf.create(
{
"quantity": "energy",
"read_from": "${...systems.read_from}",
"reader": None,
"key": None,
"unit": "",
"per_atom": False,
"type": "scalar",
"num_subtargets": 1,
"description": "",
}
)

which makes quantity: "energy" by itself quite useless, and the code needs to perform further checks to know that it is dealing with an energy, like here:
is_energy = (
(target["quantity"] == "energy")
and (not target["per_atom"])
and target["num_subtargets"] == 1
and target["type"] == "scalar"
)

So yeah I agree this should go away but I'm a bit scared of touching it tbh

energy_targets.append(target_name)
if isinstance(targets[target_name], TargetInfo):
# Check if the energy requires gradients:
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/utils/scaler/_base_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,11 +831,11 @@ def _set_fixed_weights(

self.scales[target_name] = TensorMap(
self.Y2[target_name].keys.to(device=block.values.device),
[block],
[block.copy()],
)
self.per_target_scales[target_name] = TensorMap(
self.Y2[target_name].keys.to(device=block.values.device),
[block],
[block.copy()],
)

def _sync_device_dtype(self, device: torch.device, dtype: torch.dtype) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/utils/testing/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def test_output_features(

features_output_options = ModelOutput(
quantity="",
unit="unitless",
unit="",
per_atom=per_atom,
)
model = model.to(system.positions.dtype)
Expand Down Expand Up @@ -705,7 +705,7 @@ def test_output_last_layer_features(
# last-layer features per atom:
ll_output_options = ModelOutput(
quantity="",
unit="unitless",
unit="",
per_atom=per_atom,
)
model = model.to(system.positions.dtype)
Expand Down
Loading