From 049c995ff84f77806b942383c2bb0f24563e1dd3 Mon Sep 17 00:00:00 2001 From: Guillaume Fraux Date: Sun, 31 May 2026 18:41:35 +0200 Subject: [PATCH] Add ModelCapabilities to the JSON structs --- docs/src/core/reference/json-formats.rst | 70 +++++ docs/src/core/units.rst | 2 +- metatomic-core/include/metatomic.h | 18 +- metatomic-core/src/c_api/model.rs | 20 +- metatomic-core/src/lib.rs | 2 +- metatomic-core/src/metadata.rs | 373 +++++++++++++++++++++-- metatomic-core/src/quantities.rs | 10 +- metatomic-core/src/units.rs | 20 ++ 8 files changed, 486 insertions(+), 29 deletions(-) diff --git a/docs/src/core/reference/json-formats.rst b/docs/src/core/reference/json-formats.rst index 859f7d985..f7e3f468b 100644 --- a/docs/src/core/reference/json-formats.rst +++ b/docs/src/core/reference/json-formats.rst @@ -154,3 +154,73 @@ The JSON representation of a model's metadata. This is used for example by ``extra`` An object with string values, providing any additional key-value pairs the model author wishes to include. This can be used for any purpose. + +.. _core-json-model-capabilities: + +Model capabilities +------------------ + +The JSON representation of a model's capabilities, describing which outputs it +provides, which atomic types it supports, and other constraints. This is used +for example by :c:member:`mta_model_t.capabilities`. + +.. code-block:: json + + { + "type": "metatomic_model_capabilities", + "outputs": [ + { + "type": "metatomic_quantity", + "name": "energy", + "unit": "eV", + "sample_kind": "system", + "gradients": ["positions"], + "description": "Potential energy of the system" + }, + { + "type": "metatomic_quantity", + "name": "energy/pbe0", + "unit": "eV", + "sample_kind": "system", + "gradients": ["positions", "strain"], + "description": "Potential energy of the system" + }, + ], + "atomic_types": [1, 6, 8], + "interaction_range": 5.0, + "length_unit": "angstrom", + "supported_devices": ["cpu", "cuda"], + "dtype": "float32" + } + +``type`` + Must be the string ``"metatomic_model_capabilities"``. + +``outputs`` + Array of :ref:`quantity objects ` describing the + outputs this model can provide. + +``atomic_types`` + Array of integers listing the atomic types this model supports. The meaning + of these integers is up to the model, and is not required to be the atomic + numbers. + +``interaction_range`` + The interaction range of the model in the length unit of the model. This is + the maximum distance between two atoms for which the model's output can + depend on their relative position. Must be a non-negative number. + +``length_unit`` + String identifying the length unit used by the model, e.g. ``"angstrom"`` or + ``"nanometer"``. This must be a valid :ref:`unit expression ` with + dimensions compatible with length. + +``supported_devices`` + Array of strings listing the devices on which the model can run. Valid + values are ``"cpu"``, ``"cuda"``, ``"rocm"``, and ``"metal"``. + +``dtype`` + The data type of the model, used for all inputs and outputs. Must be either + ``"float32"`` or ``"float64"``. The model is free to use different data + types for internal computations, but all inputs and outputs must be in this + data type. diff --git a/docs/src/core/units.rst b/docs/src/core/units.rst index d5f3dd776..c9415ed9f 100644 --- a/docs/src/core/units.rst +++ b/docs/src/core/units.rst @@ -1,4 +1,4 @@ -.. _core-unit-expressions: +.. _units: Units ^^^^^ diff --git a/metatomic-core/include/metatomic.h b/metatomic-core/include/metatomic.h index a4ba5e95e..c8077b5c5 100644 --- a/metatomic-core/include/metatomic.h +++ b/metatomic-core/include/metatomic.h @@ -116,6 +116,20 @@ typedef struct mta_model_t { * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*unload)(void *model_data); + /** + * Get the capabilities of the model as a JSON string. + * + * @verbatim embed:rst:leading-asterisk + * The expected JSON structure is documented in :ref:`core-json-model-capabilities`. + * @endverbatim + * + * @param model_data the model's `data` pointer + * @param capabilities_json output string, set to a JSON-serialized + * `ModelCapabilities` object. The caller takes ownership and must + * free it with `mta_string_free`. + * @return `MTA_SUCCESS` on success, another status code on error + */ + enum mta_status_t (*capabilities)(const void *model_data, mta_string_t *capabilities_json); /** * Get metadata describing the model (name, authors, references, ...) as a * JSON string. @@ -126,8 +140,8 @@ typedef struct mta_model_t { * * @param model_data the model's `data` pointer * @param metadata_json output string, set to a JSON-serialized - * `ModelMetadata` object. The - * caller takes ownership and must free it with `mta_string_free`. + * `ModelMetadata` object. The caller takes ownership and must + * free it with `mta_string_free`. * @return `MTA_SUCCESS` on success, another status code on error */ enum mta_status_t (*metadata)(const void *model_data, mta_string_t *metadata_json); diff --git a/metatomic-core/src/c_api/model.rs b/metatomic-core/src/c_api/model.rs index c9fc83610..bf03e2a46 100644 --- a/metatomic-core/src/c_api/model.rs +++ b/metatomic-core/src/c_api/model.rs @@ -33,6 +33,22 @@ pub struct mta_model_t { /// @return `MTA_SUCCESS` on success, another status code on error pub unload: Option mta_status_t>, + /// Get the capabilities of the model as a JSON string. + /// + /// @verbatim embed:rst:leading-asterisk + /// The expected JSON structure is documented in :ref:`core-json-model-capabilities`. + /// @endverbatim + /// + /// @param model_data the model's `data` pointer + /// @param capabilities_json output string, set to a JSON-serialized + /// `ModelCapabilities` object. The caller takes ownership and must + /// free it with `mta_string_free`. + /// @return `MTA_SUCCESS` on success, another status code on error + pub capabilities: Option mta_status_t>, + /// Get metadata describing the model (name, authors, references, ...) as a /// JSON string. /// @@ -42,8 +58,8 @@ pub struct mta_model_t { /// /// @param model_data the model's `data` pointer /// @param metadata_json output string, set to a JSON-serialized - /// `ModelMetadata` object. The - /// caller takes ownership and must free it with `mta_string_free`. + /// `ModelMetadata` object. The caller takes ownership and must + /// free it with `mta_string_free`. /// @return `MTA_SUCCESS` on success, another status code on error pub metadata: Option for JsonValue { } } -impl TryFrom for PairListOptions { +impl<'a> TryFrom<&'a JsonValue> for PairListOptions { type Error = Error; - fn try_from(value: JsonValue) -> Result { + fn try_from(value: &'a JsonValue) -> Result { if !value.is_object() { return Err(Error::Serialization( "invalid JSON data for PairListOptions, expected an object".into() @@ -168,19 +169,19 @@ fn read_references(object: &JsonValue, key: &str) -> Result, Error> Ok(references) } -impl TryFrom for References { +impl<'a> TryFrom<&'a JsonValue> for References { type Error = Error; - fn try_from(value: JsonValue) -> Result { + fn try_from(value: &'a JsonValue) -> Result { if !value.is_object() { return Err(Error::Serialization( "invalid JSON data for references in ModelMetadata, expected an object".into() )); } - let model = read_references(&value, "model")?; - let architecture = read_references(&value, "architecture")?; - let implementation = read_references(&value, "implementation")?; + let model = read_references(value, "model")?; + let architecture = read_references(value, "architecture")?; + let implementation = read_references(value, "implementation")?; Ok(References { model, architecture, implementation }) } @@ -217,10 +218,10 @@ impl From for JsonValue { } } -impl TryFrom for ModelMetadata { +impl<'a> TryFrom<&'a JsonValue> for ModelMetadata { type Error = Error; - fn try_from(value: JsonValue) -> Result { + fn try_from(value: &'a JsonValue) -> Result { if !value.is_object() { return Err(Error::Serialization( "invalid JSON data for ModelMetadata, expected an object".into() @@ -253,7 +254,7 @@ impl TryFrom for ModelMetadata { "'description' in JSON for ModelMetadata must be a string".into() ))?.to_string(); - let references = References::try_from(value["references"].clone())?; + let references = References::try_from(&value["references"])?; if !value["extra"].is_object() { return Err(Error::Serialization( @@ -279,6 +280,211 @@ impl TryFrom for ModelMetadata { } } +/// The data type of a model, used for all inputs and outputs. The model can +/// still internally use a different data type for its calculations, but it will +/// get inputs in this type and must produce outputs in this type. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DType { + /// 32-bit floating point, following the IEEE 754 standard + Float32, + /// 64-bit floating point, following the IEEE 754 standard + Float64, +} + +impl From for JsonValue { + fn from(value: DType) -> Self { + match value { + DType::Float32 => "float32".into(), + DType::Float64 => "float64".into(), + } + } +} + +impl<'a> TryFrom<&'a JsonValue> for DType { + type Error = Error; + + fn try_from(value: &'a JsonValue) -> Result { + if let Some(s) = value.as_str() { + match s { + "float32" => Ok(DType::Float32), + "float64" => Ok(DType::Float64), + _ => Err(Error::Serialization( + "invalid string for dtype in JSON for ModelCapabilities, expected 'float32' or 'float64'".into() + )), + } + } else { + Err(Error::Serialization( + "dtype in JSON for ModelCapabilities must be a string".into() + )) + } + } +} + +/// A device on which a model can run. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Device(dlpk::DLDeviceType); + +impl From for JsonValue { + fn from(value: Device) -> Self { + match value.0 { + dlpk::DLDeviceType::kDLCPU => "cpu".into(), + dlpk::DLDeviceType::kDLCUDA => "cuda".into(), + dlpk::DLDeviceType::kDLROCM => "rocm".into(), + dlpk::DLDeviceType::kDLMetal => "metal".into(), + dlpk::DLDeviceType::kDLCUDAHost | dlpk::DLDeviceType::kDLCUDAManaged => { + // These refer to memory devices more than execution devices + panic!("Do not use kDLCUDAHost or kDLCUDAManaged, use kDLCUDA instead."); + } + dlpk::DLDeviceType::kDLROCMHost => { + // This refers to a memory device more than an execution device + panic!("Do not use kDLROCMHost, use kDLROCM instead."); + } + _ => { + // We don't want to expose other device types until we have a + // use case for them, and we don't want to accidentally leak + // them if they're added in the future + panic!("unsupported device type: {:?}", value.0); + } + } + } +} + +impl<'a> TryFrom<&'a JsonValue> for Device { + type Error = Error; + + fn try_from(value: &'a JsonValue) -> Result { + if let Some(s) = value.as_str() { + match s { + "cpu" => Ok(Device(dlpk::DLDeviceType::kDLCPU)), + "cuda" => Ok(Device(dlpk::DLDeviceType::kDLCUDA)), + "rocm" => Ok(Device(dlpk::DLDeviceType::kDLROCM)), + "metal" => Ok(Device(dlpk::DLDeviceType::kDLMetal)), + _ => Err(Error::Serialization( + "invalid string for device in JSON for ModelCapabilities, expected 'cpu', 'cuda', 'rocm', or 'metal'".into() + )), + } + } else { + Err(Error::Serialization( + "device in JSON for ModelCapabilities must be a string".into() + )) + } + } +} + +/// Capabilities about a model: which outputs it provides, which atoms it +/// supports, etc. +#[derive(Debug, Clone)] +pub struct ModelCapabilities { + /// The outputs this model can provide + pub outputs: Vec, + /// The atomic types this model supports. The meaning of the integers in + /// this list is up to the model, and is not required to be the atomic + /// numbers. + pub atomic_types: Vec, + /// The interaction range of the model (in the length unit of the model), + /// i.e. the maximum distance between two atoms for which the model's output + /// can depend on their relative position. + pub interaction_range: f64, + /// The length unit of the model, e.g. "angstrom" or "nanometer". This is + /// used to interpret the `interaction_range` and convert the inputs. + pub length_unit: String, + /// The devices on which the model can run, e.g. `["cpu", "cuda"]`. + pub supported_devices: Vec, + /// The data type of the model, used for all inputs and outputs. + pub dtype: DType, +} + +impl From for JsonValue { + fn from(value: ModelCapabilities) -> Self { + let mut result = JsonValue::new_object(); + result["type"] = "metatomic_model_capabilities".into(); + result["outputs"] = value.outputs.into(); + result["atomic_types"] = value.atomic_types.into(); + result["interaction_range"] = value.interaction_range.into(); + result["length_unit"] = value.length_unit.into(); + result["supported_devices"] = value.supported_devices.into(); + result["dtype"] = value.dtype.into(); + return result; + } +} + +impl<'a> TryFrom<&'a JsonValue> for ModelCapabilities { + type Error = Error; + + fn try_from(value: &'a JsonValue) -> Result { + if !value.is_object() { + return Err(Error::Serialization( + "invalid JSON data for ModelCapabilities, expected an object".into() + )); + } + + if value["type"].as_str() != Some("metatomic_model_capabilities") { + return Err(Error::Serialization( + "'type' in JSON for ModelCapabilities must be 'metatomic_model_capabilities'".into() + )); + } + + let mut outputs = Vec::new(); + if !value["outputs"].is_array() { + return Err(Error::Serialization( + "'outputs' in JSON for ModelCapabilities must be an array".into() + )); + } + for output in value["outputs"].members() { + outputs.push(Quantity::try_from(output)?); + } + + + let mut atomic_types = Vec::new(); + if !value["atomic_types"].is_array() { + return Err(Error::Serialization( + "'atomic_types' in JSON for ModelCapabilities must be an array".into() + )); + } + + for atomic_type in value["atomic_types"].members() { + let atomic_type = atomic_type.as_i64().ok_or_else(|| Error::Serialization( + "'atomic_types' in JSON for ModelCapabilities must be an array of integers".into() + ))?; + atomic_types.push(atomic_type); + } + + let interaction_range = value["interaction_range"].as_f64().ok_or_else(|| Error::Serialization( + "'interaction_range' in JSON for ModelCapabilities must be a number".into() + ))?; + if interaction_range < 0.0 { + return Err(Error::Serialization( + "'interaction_range' in JSON for ModelCapabilities must be non-negative".into() + )); + } + + let length_unit = value["length_unit"].as_str().ok_or_else(|| Error::Serialization( + "'length_unit' in JSON for ModelCapabilities must be a string".into() + ))?.to_string(); + validate_unit(&length_unit, "m", Some("'length_unit' in JSON for ModelCapabilities"))?; + + let mut supported_devices = Vec::new(); + if !value["supported_devices"].is_array() { + return Err(Error::Serialization( + "'supported_devices' in JSON for ModelCapabilities must be an array".into() + )); + } + for device in value["supported_devices"].members() { + supported_devices.push(Device::try_from(device)?); + } + + let dtype = DType::try_from(&value["dtype"])?; + + Ok(ModelCapabilities { + outputs, + atomic_types, + interaction_range, + length_unit, + supported_devices, + dtype, + }) + } +} #[cfg(test)] mod tests { @@ -304,7 +510,7 @@ mod tests { assert_eq!(json["full_list"].as_bool(), Some(true)); assert_eq!(json["strict"].as_bool(), Some(false)); - let parsed = PairListOptions::try_from(json).unwrap(); + let parsed = PairListOptions::try_from(&json).unwrap(); assert_eq!(parsed.cutoff.to_bits(), options.cutoff.to_bits()); assert_eq!(parsed.full_list, options.full_list); assert_eq!(parsed.strict, options.strict); @@ -315,7 +521,7 @@ mod tests { fn cutoff_keeps_full_precision() { let mut options = example(); options.cutoff = 1.0 / 3.0; - let parsed = PairListOptions::try_from(JsonValue::from(options.clone())).unwrap(); + let parsed = PairListOptions::try_from(&JsonValue::from(options.clone())).unwrap(); assert_eq!(parsed.cutoff.to_bits(), options.cutoff.to_bits()); } @@ -323,7 +529,7 @@ mod tests { fn requestors_are_optional() { let mut json: JsonValue = example().into(); json.remove("requestors"); - let parsed = PairListOptions::try_from(json).unwrap(); + let parsed = PairListOptions::try_from(&json).unwrap(); assert!(parsed.requestors.is_empty()); } @@ -380,7 +586,7 @@ mod tests { ]; for (json, expected) in cases { - let error = PairListOptions::try_from(json).expect_err("expected an error"); + let error = PairListOptions::try_from(&json).expect_err("expected an error"); assert_eq!(error.to_string(), expected); } } @@ -390,7 +596,7 @@ mod tests { let mut json: JsonValue = example().into(); json["requestors"] = json::array![ "a", "", "b", "a" ]; - let parsed = PairListOptions::try_from(json).unwrap(); + let parsed = PairListOptions::try_from(&json).unwrap(); assert_eq!(parsed.requestors, vec!["a".to_string(), "b".to_string()]); } } @@ -431,7 +637,7 @@ mod tests { assert_eq!(json["extra"]["key1"].as_str(), Some("value1")); assert_eq!(json["extra"]["key2"].as_str(), Some("value2")); - let parsed = ModelMetadata::try_from(json).unwrap(); + let parsed = ModelMetadata::try_from(&json).unwrap(); assert_eq!(parsed.name, metadata.name); assert_eq!(parsed.authors, metadata.authors); assert_eq!(parsed.description, metadata.description); @@ -494,7 +700,138 @@ mod tests { ]; for (json, expected) in cases { - let error = ModelMetadata::try_from(json).expect_err("expected an error"); + let error = ModelMetadata::try_from(&json).expect_err("expected an error"); + assert_eq!(error.to_string(), expected); + } + } + } + + mod model_capabilities { + use super::super::*; + + fn example() -> ModelCapabilities { + ModelCapabilities { + outputs: vec![ + Quantity { + name: "energy".into(), + unit: "eV".into(), + description: Some("total energy".into()), + gradients: vec![crate::Gradients::Positions], + sample_kind: crate::SampleKind::System, + }, + Quantity { + name: "charge".into(), + unit: "e".into(), + description: None, + gradients: vec![], + sample_kind: crate::SampleKind::Atom, + }, + ], + atomic_types: vec![1, 6, 8], + interaction_range: 5.0, + length_unit: "Angstrom".into(), + supported_devices: vec![Device(dlpk::DLDeviceType::kDLCPU), Device(dlpk::DLDeviceType::kDLCUDA)], + dtype: DType::Float32, + } + } + + #[test] + fn roundtrip() { + let capabilities = example(); + let json: JsonValue = capabilities.clone().into(); + + assert_eq!(json["type"].as_str(), Some("metatomic_model_capabilities")); + assert_eq!(json["outputs"][0]["name"].as_str(), Some("energy")); + assert_eq!(json["outputs"][1]["name"].as_str(), Some("charge")); + assert_eq!(json["atomic_types"][0].as_i64(), Some(1)); + assert_eq!(json["atomic_types"][1].as_i64(), Some(6)); + assert_eq!(json["atomic_types"][2].as_i64(), Some(8)); + assert_eq!(json["interaction_range"].as_f64(), Some(5.0)); + assert_eq!(json["length_unit"].as_str(), Some("Angstrom")); + assert_eq!(json["supported_devices"][0].as_str(), Some("cpu")); + assert_eq!(json["supported_devices"][1].as_str(), Some("cuda")); + assert_eq!(json["dtype"].as_str(), Some("float32")); + + let parsed = ModelCapabilities::try_from(&json).unwrap(); + assert_eq!(parsed.outputs.len(), 2); + assert_eq!(parsed.outputs[0].name, "energy"); + assert_eq!(parsed.outputs[1].name, "charge"); + assert_eq!(parsed.atomic_types, vec![1, 6, 8]); + assert_eq!(parsed.interaction_range.to_bits(), 5.0_f64.to_bits()); + assert_eq!(parsed.length_unit, "Angstrom"); + assert_eq!(parsed.supported_devices.len(), 2); + assert_eq!(parsed.dtype, DType::Float32); + } + + #[test] + fn rejects_invalid_json() { + let mut wrong_type = JsonValue::from(example()); + wrong_type["type"] = "something-else".into(); + + let mut non_array_outputs = JsonValue::from(example()); + non_array_outputs["outputs"] = "energy".into(); + + let mut non_array_atomic_types = JsonValue::from(example()); + non_array_atomic_types["atomic_types"] = "1".into(); + + let mut non_integer_atomic_type = JsonValue::from(example()); + non_integer_atomic_type["atomic_types"] = json::array![1, "x"]; + + let mut missing_interaction_range = JsonValue::from(example()); + missing_interaction_range.remove("interaction_range"); + + let mut negative_interaction_range = JsonValue::from(example()); + negative_interaction_range["interaction_range"] = (-1.0).into(); + + let mut missing_length_unit = JsonValue::from(example()); + missing_length_unit.remove("length_unit"); + + let mut wrong_dimension_length_unit = JsonValue::from(example()); + wrong_dimension_length_unit["length_unit"] = "eV".into(); + + let mut non_array_supported_devices = JsonValue::from(example()); + non_array_supported_devices["supported_devices"] = "cpu".into(); + + let mut invalid_device = JsonValue::from(example()); + invalid_device["supported_devices"] = json::array!["cpu", "wat"]; + + let mut missing_dtype = JsonValue::from(example()); + missing_dtype.remove("dtype"); + + let mut invalid_dtype = JsonValue::from(example()); + invalid_dtype["dtype"] = "float16".into(); + + let cases: Vec<(JsonValue, &str)> = vec![ + (JsonValue::from("not an object"), + "serialization error: invalid JSON data for ModelCapabilities, expected an object"), + (wrong_type, + "serialization error: 'type' in JSON for ModelCapabilities must be 'metatomic_model_capabilities'"), + (non_array_outputs, + "serialization error: 'outputs' in JSON for ModelCapabilities must be an array"), + (non_array_atomic_types, + "serialization error: 'atomic_types' in JSON for ModelCapabilities must be an array"), + (non_integer_atomic_type, + "serialization error: 'atomic_types' in JSON for ModelCapabilities must be an array of integers"), + (missing_interaction_range, + "serialization error: 'interaction_range' in JSON for ModelCapabilities must be a number"), + (negative_interaction_range, + "serialization error: 'interaction_range' in JSON for ModelCapabilities must be non-negative"), + (missing_length_unit, + "serialization error: 'length_unit' in JSON for ModelCapabilities must be a string"), + (wrong_dimension_length_unit, + "invalid parameter: dimension mismatch in 'length_unit' in JSON for ModelCapabilities: 'eV' has dimension [L^2 T^-2 M] but expected dimension [L]"), + (non_array_supported_devices, + "serialization error: 'supported_devices' in JSON for ModelCapabilities must be an array"), + (invalid_device, + "serialization error: invalid string for device in JSON for ModelCapabilities, expected 'cpu', 'cuda', 'rocm', or 'metal'"), + (missing_dtype, + "serialization error: dtype in JSON for ModelCapabilities must be a string"), + (invalid_dtype, + "serialization error: invalid string for dtype in JSON for ModelCapabilities, expected 'float32' or 'float64'"), + ]; + + for (json, expected) in cases { + let error = ModelCapabilities::try_from(&json).expect_err("expected an error"); assert_eq!(error.to_string(), expected); } } diff --git a/metatomic-core/src/quantities.rs b/metatomic-core/src/quantities.rs index 9d1dfebbf..93727c837 100644 --- a/metatomic-core/src/quantities.rs +++ b/metatomic-core/src/quantities.rs @@ -193,10 +193,10 @@ impl From for JsonValue { } -impl TryFrom for Quantity { +impl<'a> TryFrom<&'a JsonValue> for Quantity { type Error = Error; - fn try_from(value: JsonValue) -> Result { + fn try_from(value: &'a JsonValue) -> Result { if !value.is_object() { return Err(Error::Serialization( "invalid JSON data for Quantity, expected an object".into() @@ -272,7 +272,7 @@ mod tests { assert_eq!(json["gradients"][0].as_str(), Some("positions")); assert_eq!(json["sample_kind"].as_str(), Some("atom")); - let parsed = Quantity::try_from(json).unwrap(); + let parsed = Quantity::try_from(&json).unwrap(); assert_eq!(parsed.name, "energy"); assert_eq!(parsed.unit, "eV"); assert_eq!(parsed.gradients, vec![Gradients::Positions]); @@ -295,7 +295,7 @@ mod tests { gradients: grads.clone(), sample_kind: sample.clone(), }; - let parsed = Quantity::try_from(JsonValue::from(quantity.clone())).unwrap(); + let parsed = Quantity::try_from(&JsonValue::from(quantity.clone())).unwrap(); assert_eq!(parsed.name, quantity.name); assert_eq!(parsed.unit, quantity.unit); assert_eq!(parsed.gradients, grads); @@ -352,7 +352,7 @@ mod tests { ]; for (json, expected) in cases { - let error = Quantity::try_from(json).expect_err("expected an error"); + let error = Quantity::try_from(&json).expect_err("expected an error"); assert_eq!(error.to_string(), expected); } } diff --git a/metatomic-core/src/units.rs b/metatomic-core/src/units.rs index 4cfff2c1d..4239b86be 100644 --- a/metatomic-core/src/units.rs +++ b/metatomic-core/src/units.rs @@ -574,6 +574,26 @@ pub fn unit_conversion_factor(from_unit: &str, to_unit: &str) -> Result) -> Result<(), Error> { + let unit_value = parse_unit_expression(unit)?; + let reference_value = parse_unit_expression(reference_unit)?; + + if unit_value.dim != reference_value.dim { + return Err(Error::InvalidParameter(format!( + "dimension mismatch{}: '{}' has dimension {} but expected dimension {}", + context.map_or_else(String::new, |c| format!(" in {}", c)), + unit, + unit_value.dim, + reference_value.dim + ))); + } + + Ok(()) +} + + #[cfg(test)] #[allow(clippy::float_cmp)] mod tests {