From 929082240e98a23c99e73902ef9273f94e3b41ff Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 7 Apr 2026 12:45:04 +0000 Subject: [PATCH] Add ML-DSA-44 support alongside existing ML-DSA-65 https://claude.ai/code/session_01GaM5UZrNWJvnptDHvFNBWg --- CHANGELOG.rst | 2 +- docs/hazmat/primitives/asymmetric/mldsa.rst | 186 ++++++++++++++++ .../hazmat/bindings/_rust/openssl/mldsa.pyi | 5 + .../hazmat/primitives/asymmetric/mldsa.py | 143 +++++++++++++ .../cryptography-key-parsing/src/pkcs8.rs | 17 +- src/rust/cryptography-key-parsing/src/spki.rs | 18 +- src/rust/cryptography-openssl/src/mldsa.rs | 59 +++++- src/rust/cryptography-x509/src/common.rs | 2 + src/rust/cryptography-x509/src/oid.rs | 1 + src/rust/src/backend/keys.rs | 40 ++-- src/rust/src/backend/mldsa.rs | 198 +++++++++++++++++- tests/hazmat/primitives/test_mldsa.py | 92 ++++---- tests/wycheproof/test_mldsa.py | 77 ++++++- 13 files changed, 763 insertions(+), 77 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5c6decfc3488..2025e72f6b74 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -90,7 +90,7 @@ Changelog * Added :doc:`/hazmat/primitives/hpke` support implementing :rfc:`9180` for hybrid authenticated encryption. * Added new :doc:`/hazmat/primitives/asymmetric/mldsa` module with - support for ML-DSA-65 signing and verification with the AWS-LC backend. + support for ML-DSA signing and verification with the AWS-LC backend. .. _v46-0-6: diff --git a/docs/hazmat/primitives/asymmetric/mldsa.rst b/docs/hazmat/primitives/asymmetric/mldsa.rst index 9d78e6e2787f..6523c6006474 100644 --- a/docs/hazmat/primitives/asymmetric/mldsa.rst +++ b/docs/hazmat/primitives/asymmetric/mldsa.rst @@ -41,6 +41,192 @@ different contexts or protocols. Key interfaces ~~~~~~~~~~~~~~ +.. class:: MlDsa44PrivateKey + + .. versionadded:: 47.0 + + .. classmethod:: generate() + + Generate an ML-DSA-44 private key. + + :returns: :class:`MlDsa44PrivateKey` + + :raises cryptography.exceptions.UnsupportedAlgorithm: If ML-DSA-44 is + not supported by the backend ``cryptography`` is using. + + .. classmethod:: from_seed_bytes(data) + + Load an ML-DSA-44 private key from seed bytes. + + :param data: 32 byte seed. + :type data: :term:`bytes-like` + + :returns: :class:`MlDsa44PrivateKey` + + :raises ValueError: If the seed is not 32 bytes. + + :raises cryptography.exceptions.UnsupportedAlgorithm: If ML-DSA-44 is + not supported by the backend ``cryptography`` is using. + + .. doctest:: + :skipif: not _backend.mldsa_supported() + + >>> from cryptography.hazmat.primitives.asymmetric import mldsa + >>> private_key = mldsa.MlDsa44PrivateKey.generate() + >>> seed = private_key.private_bytes_raw() + >>> same_key = mldsa.MlDsa44PrivateKey.from_seed_bytes(seed) + + .. method:: public_key() + + :returns: :class:`MlDsa44PublicKey` + + .. method:: sign(data, context=None) + + Sign the data using ML-DSA-44. An optional context string can be + provided. + + :param data: The data to sign. + :type data: :term:`bytes-like` + + :param context: An optional context string (up to 255 bytes). + :type context: :term:`bytes-like` or ``None`` + + :returns bytes: The signature (2420 bytes). + + :raises ValueError: If the context is longer than 255 bytes. + + .. method:: private_bytes(encoding, format, encryption_algorithm) + + Allows serialization of the key to bytes. Encoding ( + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.PEM`, + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.DER`, or + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.Raw`) and + format ( + :attr:`~cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8` + or + :attr:`~cryptography.hazmat.primitives.serialization.PrivateFormat.Raw` + ) are chosen to define the exact serialization. + + This method only returns the serialization of the seed form of the + private key, never the expanded one. + + :param encoding: A value from the + :class:`~cryptography.hazmat.primitives.serialization.Encoding` enum. + + :param format: A value from the + :class:`~cryptography.hazmat.primitives.serialization.PrivateFormat` + enum. If the ``encoding`` is + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.Raw` + then ``format`` must be + :attr:`~cryptography.hazmat.primitives.serialization.PrivateFormat.Raw` + , otherwise it must be + :attr:`~cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8`. + + :param encryption_algorithm: An instance of an object conforming to the + :class:`~cryptography.hazmat.primitives.serialization.KeySerializationEncryption` + interface. + + :return bytes: Serialized key. + + .. method:: private_bytes_raw() + + Allows serialization of the key to raw bytes. This method is a + convenience shortcut for calling :meth:`private_bytes` with + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.Raw` + encoding, + :attr:`~cryptography.hazmat.primitives.serialization.PrivateFormat.Raw` + format, and + :class:`~cryptography.hazmat.primitives.serialization.NoEncryption`. + + This method only returns the seed form of the private key (32 bytes). + + :return bytes: Raw key (32-byte seed). + +.. class:: MlDsa44PublicKey + + .. versionadded:: 47.0 + + .. classmethod:: from_public_bytes(data) + + :param bytes data: 1312 byte public key. + + :returns: :class:`MlDsa44PublicKey` + + :raises ValueError: If the public key is not 1312 bytes. + + :raises cryptography.exceptions.UnsupportedAlgorithm: If ML-DSA-44 is + not supported by the backend ``cryptography`` is using. + + .. doctest:: + :skipif: not _backend.mldsa_supported() + + >>> from cryptography.hazmat.primitives import serialization + >>> from cryptography.hazmat.primitives.asymmetric import mldsa + >>> private_key = mldsa.MlDsa44PrivateKey.generate() + >>> public_key = private_key.public_key() + >>> public_bytes = public_key.public_bytes( + ... encoding=serialization.Encoding.Raw, + ... format=serialization.PublicFormat.Raw + ... ) + >>> loaded_public_key = mldsa.MlDsa44PublicKey.from_public_bytes(public_bytes) + + .. method:: public_bytes(encoding, format) + + Allows serialization of the key to bytes. Encoding ( + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.PEM`, + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.DER`, or + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.Raw`) and + format ( + :attr:`~cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo` + or + :attr:`~cryptography.hazmat.primitives.serialization.PublicFormat.Raw` + ) are chosen to define the exact serialization. + + :param encoding: A value from the + :class:`~cryptography.hazmat.primitives.serialization.Encoding` enum. + + :param format: A value from the + :class:`~cryptography.hazmat.primitives.serialization.PublicFormat` + enum. If the ``encoding`` is + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.Raw` + then ``format`` must be + :attr:`~cryptography.hazmat.primitives.serialization.PublicFormat.Raw` + , otherwise it must be + :attr:`~cryptography.hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo`. + + :returns bytes: The public key bytes. + + .. method:: public_bytes_raw() + + Allows serialization of the key to raw bytes. This method is a + convenience shortcut for calling :meth:`public_bytes` with + :attr:`~cryptography.hazmat.primitives.serialization.Encoding.Raw` + encoding and + :attr:`~cryptography.hazmat.primitives.serialization.PublicFormat.Raw` + format. + + :return bytes: 1312-byte raw public key. + + .. method:: verify(signature, data, context=None) + + Verify a signature using ML-DSA-44. If a context string was used during + signing, the same context must be provided for verification to succeed. + + :param signature: The signature to verify. + :type signature: :term:`bytes-like` + + :param data: The data to verify. + :type data: :term:`bytes-like` + + :param context: An optional context string (up to 255 bytes) that was + used during signing. + :type context: :term:`bytes-like` or ``None`` + + :returns: None + :raises cryptography.exceptions.InvalidSignature: Raised when the + signature cannot be verified. + :raises ValueError: If the context is longer than 255 bytes. + .. class:: MlDsa65PrivateKey .. versionadded:: 47.0 diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/mldsa.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/mldsa.pyi index 9d834b41b0fd..9028e560c09f 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/mldsa.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/mldsa.pyi @@ -5,9 +5,14 @@ from cryptography.hazmat.primitives.asymmetric import mldsa from cryptography.utils import Buffer +class MlDsa44PrivateKey: ... +class MlDsa44PublicKey: ... class MlDsa65PrivateKey: ... class MlDsa65PublicKey: ... +def generate_mldsa44_key() -> mldsa.MlDsa44PrivateKey: ... +def from_mldsa44_public_bytes(data: bytes) -> mldsa.MlDsa44PublicKey: ... +def from_mldsa44_seed_bytes(data: Buffer) -> mldsa.MlDsa44PrivateKey: ... def generate_mldsa65_key() -> mldsa.MlDsa65PrivateKey: ... def from_mldsa65_public_bytes(data: bytes) -> mldsa.MlDsa65PublicKey: ... def from_mldsa65_seed_bytes(data: Buffer) -> mldsa.MlDsa65PrivateKey: ... diff --git a/src/cryptography/hazmat/primitives/asymmetric/mldsa.py b/src/cryptography/hazmat/primitives/asymmetric/mldsa.py index 349e8a6411e6..87d5873b20a3 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/mldsa.py +++ b/src/cryptography/hazmat/primitives/asymmetric/mldsa.py @@ -12,6 +12,149 @@ from cryptography.utils import Buffer +class MlDsa44PublicKey(metaclass=abc.ABCMeta): + @classmethod + def from_public_bytes(cls, data: bytes) -> MlDsa44PublicKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.mldsa_supported(): + raise UnsupportedAlgorithm( + "ML-DSA-44 is not supported by this backend.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.mldsa.from_mldsa44_public_bytes(data) + + @abc.abstractmethod + def public_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PublicFormat, + ) -> bytes: + """ + The serialized bytes of the public key. + """ + + @abc.abstractmethod + def public_bytes_raw(self) -> bytes: + """ + The raw bytes of the public key. + Equivalent to public_bytes(Raw, Raw). + + The public key is 1,312 bytes for MLDSA-44. + """ + + @abc.abstractmethod + def verify( + self, + signature: Buffer, + data: Buffer, + context: Buffer | None = None, + ) -> None: + """ + Verify the signature. + """ + + @abc.abstractmethod + def __eq__(self, other: object) -> bool: + """ + Checks equality. + """ + + @abc.abstractmethod + def __copy__(self) -> MlDsa44PublicKey: + """ + Returns a copy. + """ + + @abc.abstractmethod + def __deepcopy__(self, memo: dict) -> MlDsa44PublicKey: + """ + Returns a deep copy. + """ + + +if hasattr(rust_openssl, "mldsa"): + MlDsa44PublicKey.register(rust_openssl.mldsa.MlDsa44PublicKey) + + +class MlDsa44PrivateKey(metaclass=abc.ABCMeta): + @classmethod + def generate(cls) -> MlDsa44PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.mldsa_supported(): + raise UnsupportedAlgorithm( + "ML-DSA-44 is not supported by this backend.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.mldsa.generate_mldsa44_key() + + @classmethod + def from_seed_bytes(cls, data: Buffer) -> MlDsa44PrivateKey: + from cryptography.hazmat.backends.openssl.backend import backend + + if not backend.mldsa_supported(): + raise UnsupportedAlgorithm( + "ML-DSA-44 is not supported by this backend.", + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM, + ) + + return rust_openssl.mldsa.from_mldsa44_seed_bytes(data) + + @abc.abstractmethod + def public_key(self) -> MlDsa44PublicKey: + """ + The MlDsa44PublicKey derived from the private key. + """ + + @abc.abstractmethod + def private_bytes( + self, + encoding: _serialization.Encoding, + format: _serialization.PrivateFormat, + encryption_algorithm: _serialization.KeySerializationEncryption, + ) -> bytes: + """ + The serialized bytes of the private key. + + This method only returns the serialization of the seed form of the + private key, never the expanded one. + """ + + @abc.abstractmethod + def private_bytes_raw(self) -> bytes: + """ + The raw bytes of the private key. + Equivalent to private_bytes(Raw, Raw, NoEncryption()). + + This method only returns the seed form of the private key (32 bytes). + """ + + @abc.abstractmethod + def sign(self, data: Buffer, context: Buffer | None = None) -> bytes: + """ + Signs the data. + """ + + @abc.abstractmethod + def __copy__(self) -> MlDsa44PrivateKey: + """ + Returns a copy. + """ + + @abc.abstractmethod + def __deepcopy__(self, memo: dict) -> MlDsa44PrivateKey: + """ + Returns a deep copy. + """ + + +if hasattr(rust_openssl, "mldsa"): + MlDsa44PrivateKey.register(rust_openssl.mldsa.MlDsa44PrivateKey) + + class MlDsa65PublicKey(metaclass=abc.ABCMeta): @classmethod def from_public_bytes(cls, data: bytes) -> MlDsa65PublicKey: diff --git a/src/rust/cryptography-key-parsing/src/pkcs8.rs b/src/rust/cryptography-key-parsing/src/pkcs8.rs index 94ae1d1caff3..56ae756b47b8 100644 --- a/src/rust/cryptography-key-parsing/src/pkcs8.rs +++ b/src/rust/cryptography-key-parsing/src/pkcs8.rs @@ -132,6 +132,17 @@ pub fn parse_private_key(data: &[u8]) -> KeyParsingResult { Ok(ParsedPrivateKey::Pkey(pkey)) } + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + AlgorithmParameters::MlDsa44 => { + let MlDsaPrivateKey::Seed(seed) = asn1::parse_single::(k.private_key)?; + Ok(ParsedPrivateKey::Pkey( + cryptography_openssl::mldsa::new_raw_private_key( + cryptography_openssl::mldsa::MlDsaVariant::MlDsa44, + &seed, + )?, + )) + } + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] AlgorithmParameters::MlDsa65 => { let MlDsaPrivateKey::Seed(seed) = asn1::parse_single::(k.private_key)?; @@ -481,7 +492,11 @@ pub fn serialize_private_key( #[cfg(CRYPTOGRAPHY_IS_AWSLC)] cryptography_openssl::mldsa::PKEY_ID => { let private_key_der = asn1::write_single(&mldsa_seed_from_pkey(pkey)?)?; - (AlgorithmParameters::MlDsa65, private_key_der) + let params = match cryptography_openssl::mldsa::MlDsaVariant::from_pkey(pkey) { + cryptography_openssl::mldsa::MlDsaVariant::MlDsa44 => AlgorithmParameters::MlDsa44, + cryptography_openssl::mldsa::MlDsaVariant::MlDsa65 => AlgorithmParameters::MlDsa65, + }; + (params, private_key_der) } _ => { unimplemented!("Unknown key type"); diff --git a/src/rust/cryptography-key-parsing/src/spki.rs b/src/rust/cryptography-key-parsing/src/spki.rs index cb03024b53e1..ad6199280dbb 100644 --- a/src/rust/cryptography-key-parsing/src/spki.rs +++ b/src/rust/cryptography-key-parsing/src/spki.rs @@ -109,6 +109,14 @@ pub fn parse_public_key(data: &[u8]) -> KeyParsingResult { Ok(ParsedPublicKey::Pkey(openssl::pkey::PKey::from_dh(dh)?)) } #[cfg(CRYPTOGRAPHY_IS_AWSLC)] + AlgorithmParameters::MlDsa44 => Ok(ParsedPublicKey::Pkey( + cryptography_openssl::mldsa::new_raw_public_key( + cryptography_openssl::mldsa::MlDsaVariant::MlDsa44, + k.subject_public_key.as_bytes(), + ) + .map_err(|_| KeyParsingError::InvalidKey)?, + )), + #[cfg(CRYPTOGRAPHY_IS_AWSLC)] AlgorithmParameters::MlDsa65 => Ok(ParsedPublicKey::Pkey( cryptography_openssl::mldsa::new_raw_public_key( cryptography_openssl::mldsa::MlDsaVariant::MlDsa65, @@ -234,11 +242,11 @@ pub fn serialize_public_key( #[cfg(CRYPTOGRAPHY_IS_AWSLC)] cryptography_openssl::mldsa::PKEY_ID => { let raw_bytes = pkey.raw_public_key()?; - assert_eq!( - raw_bytes.len(), - cryptography_openssl::mldsa::MlDsaVariant::MlDsa65.public_key_bytes() - ); - (AlgorithmParameters::MlDsa65, raw_bytes) + let params = match cryptography_openssl::mldsa::MlDsaVariant::from_pkey(pkey) { + cryptography_openssl::mldsa::MlDsaVariant::MlDsa44 => AlgorithmParameters::MlDsa44, + cryptography_openssl::mldsa::MlDsaVariant::MlDsa65 => AlgorithmParameters::MlDsa65, + }; + (params, raw_bytes) } _ => { unimplemented!("Unknown key type"); diff --git a/src/rust/cryptography-openssl/src/mldsa.rs b/src/rust/cryptography-openssl/src/mldsa.rs index 8b3ddc585736..d9f071b6818b 100644 --- a/src/rust/cryptography-openssl/src/mldsa.rs +++ b/src/rust/cryptography-openssl/src/mldsa.rs @@ -12,27 +12,26 @@ pub const PKEY_ID: openssl::pkey::Id = openssl::pkey::Id::from_raw(ffi::NID_PQDS #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MlDsaVariant { + MlDsa44, MlDsa65, } impl MlDsaVariant { pub fn nid(self) -> c_int { match self { + MlDsaVariant::MlDsa44 => ffi::NID_MLDSA44, MlDsaVariant::MlDsa65 => ffi::NID_MLDSA65, } } - pub fn public_key_bytes(self) -> usize { - match self { - MlDsaVariant::MlDsa65 => 1952, - } - } - - fn from_pkey(pkey: &openssl::pkey::PKeyRef) -> MlDsaVariant { + pub fn from_pkey( + pkey: &openssl::pkey::PKeyRef, + ) -> MlDsaVariant { // SAFETY: EVP_PKEY_pqdsa_get_type returns the NID of the PQDSA // algorithm for a valid PQDSA pkey. let nid = unsafe { ffi::EVP_PKEY_pqdsa_get_type(pkey.as_ptr()) }; match nid { + ffi::NID_MLDSA44 => MlDsaVariant::MlDsa44, ffi::NID_MLDSA65 => MlDsaVariant::MlDsa65, _ => panic!("Unsupported ML-DSA variant"), } @@ -40,9 +39,29 @@ impl MlDsaVariant { } extern "C" { - // We call ml_dsa_65_sign/verify directly instead of going through + // We call ml_dsa_{44,65}_sign/verify directly instead of going through // EVP_DigestSign/EVP_DigestVerify because the EVP PQDSA path hardcodes // context to (NULL, 0), so we'd lose context string support. + fn ml_dsa_44_sign( + private_key: *const u8, + sig: *mut u8, + sig_len: *mut usize, + message: *const u8, + message_len: usize, + ctx_string: *const u8, + ctx_string_len: usize, + ) -> c_int; + + fn ml_dsa_44_verify( + public_key: *const u8, + sig: *const u8, + sig_len: usize, + message: *const u8, + message_len: usize, + ctx_string: *const u8, + ctx_string_len: usize, + ) -> c_int; + fn ml_dsa_65_sign( private_key: *const u8, sig: *mut u8, @@ -104,7 +123,17 @@ pub fn sign( let raw_key = pkey.raw_private_key()?; let variant = MlDsaVariant::from_pkey(pkey); - let (signature_bytes, sign_func) = match variant { + type SignFn = unsafe extern "C" fn( + *const u8, + *mut u8, + *mut usize, + *const u8, + usize, + *const u8, + usize, + ) -> c_int; + let (signature_bytes, sign_func): (usize, SignFn) = match variant { + MlDsaVariant::MlDsa44 => (2420, ml_dsa_44_sign), MlDsaVariant::MlDsa65 => (3309, ml_dsa_65_sign), }; @@ -149,7 +178,17 @@ pub fn verify( let raw_key = pkey.raw_public_key()?; let variant = MlDsaVariant::from_pkey(pkey); - let verify_func = match variant { + type VerifyFn = unsafe extern "C" fn( + *const u8, + *const u8, + usize, + *const u8, + usize, + *const u8, + usize, + ) -> c_int; + let verify_func: VerifyFn = match variant { + MlDsaVariant::MlDsa44 => ml_dsa_44_verify, MlDsaVariant::MlDsa65 => ml_dsa_65_verify, }; diff --git a/src/rust/cryptography-x509/src/common.rs b/src/rust/cryptography-x509/src/common.rs index 7a22c100cabe..56e01d6d0cef 100644 --- a/src/rust/cryptography-x509/src/common.rs +++ b/src/rust/cryptography-x509/src/common.rs @@ -53,6 +53,8 @@ pub enum AlgorithmParameters<'a> { #[defined_by(oid::ED448_OID)] Ed448, + #[defined_by(oid::ML_DSA_44_OID)] + MlDsa44, #[defined_by(oid::ML_DSA_65_OID)] MlDsa65, diff --git a/src/rust/cryptography-x509/src/oid.rs b/src/rust/cryptography-x509/src/oid.rs index 4a69312d7ceb..d7ee487d380a 100644 --- a/src/rust/cryptography-x509/src/oid.rs +++ b/src/rust/cryptography-x509/src/oid.rs @@ -109,6 +109,7 @@ pub const X448_OID: asn1::ObjectIdentifier = asn1::oid!(1, 3, 101, 111); pub const ED25519_OID: asn1::ObjectIdentifier = asn1::oid!(1, 3, 101, 112); pub const ED448_OID: asn1::ObjectIdentifier = asn1::oid!(1, 3, 101, 113); +pub const ML_DSA_44_OID: asn1::ObjectIdentifier = asn1::oid!(2, 16, 840, 1, 101, 3, 4, 3, 17); pub const ML_DSA_65_OID: asn1::ObjectIdentifier = asn1::oid!(2, 16, 840, 1, 101, 3, 4, 3, 18); // Hashes diff --git a/src/rust/src/backend/keys.rs b/src/rust/src/backend/keys.rs index 3b37dcaaf45f..1461c2e4d41e 100644 --- a/src/rust/src/backend/keys.rs +++ b/src/rust/src/backend/keys.rs @@ -182,14 +182,18 @@ fn private_key_from_pkey<'p>( .into_any()), #[cfg(CRYPTOGRAPHY_IS_AWSLC)] cryptography_openssl::mldsa::PKEY_ID => { - let pub_len = pkey.raw_public_key()?.len(); - assert_eq!( - pub_len, - cryptography_openssl::mldsa::MlDsaVariant::MlDsa65.public_key_bytes() - ); - Ok(crate::backend::mldsa::mldsa65_private_key_from_pkey(pkey) - .into_pyobject(py)? - .into_any()) + match cryptography_openssl::mldsa::MlDsaVariant::from_pkey(pkey) { + cryptography_openssl::mldsa::MlDsaVariant::MlDsa44 => { + Ok(crate::backend::mldsa::mldsa44_private_key_from_pkey(pkey) + .into_pyobject(py)? + .into_any()) + } + cryptography_openssl::mldsa::MlDsaVariant::MlDsa65 => { + Ok(crate::backend::mldsa::mldsa65_private_key_from_pkey(pkey) + .into_pyobject(py)? + .into_any()) + } + } } _ => Err(CryptographyError::from( exceptions::UnsupportedAlgorithm::new_err("Unsupported key type."), @@ -349,14 +353,18 @@ fn public_key_from_pkey<'p>( #[cfg(CRYPTOGRAPHY_IS_AWSLC)] cryptography_openssl::mldsa::PKEY_ID => { - let pub_len = pkey.raw_public_key()?.len(); - assert_eq!( - pub_len, - cryptography_openssl::mldsa::MlDsaVariant::MlDsa65.public_key_bytes() - ); - Ok(crate::backend::mldsa::mldsa65_public_key_from_pkey(pkey) - .into_pyobject(py)? - .into_any()) + match cryptography_openssl::mldsa::MlDsaVariant::from_pkey(pkey) { + cryptography_openssl::mldsa::MlDsaVariant::MlDsa44 => { + Ok(crate::backend::mldsa::mldsa44_public_key_from_pkey(pkey) + .into_pyobject(py)? + .into_any()) + } + cryptography_openssl::mldsa::MlDsaVariant::MlDsa65 => { + Ok(crate::backend::mldsa::mldsa65_public_key_from_pkey(pkey) + .into_pyobject(py)? + .into_any()) + } + } } _ => Err(CryptographyError::from( exceptions::UnsupportedAlgorithm::new_err("Unsupported key type."), diff --git a/src/rust/src/backend/mldsa.rs b/src/rust/src/backend/mldsa.rs index 3205fdbf3be5..842ed0f1e781 100644 --- a/src/rust/src/backend/mldsa.rs +++ b/src/rust/src/backend/mldsa.rs @@ -12,6 +12,199 @@ use crate::exceptions; const MAX_CONTEXT_BYTES: usize = 255; +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.mldsa")] +pub(crate) struct MlDsa44PrivateKey { + pkey: openssl::pkey::PKey, +} + +#[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.mldsa")] +pub(crate) struct MlDsa44PublicKey { + pkey: openssl::pkey::PKey, +} + +pub(crate) fn mldsa44_private_key_from_pkey( + pkey: &openssl::pkey::PKeyRef, +) -> MlDsa44PrivateKey { + MlDsa44PrivateKey { + pkey: pkey.to_owned(), + } +} + +pub(crate) fn mldsa44_public_key_from_pkey( + pkey: &openssl::pkey::PKeyRef, +) -> MlDsa44PublicKey { + MlDsa44PublicKey { + pkey: pkey.to_owned(), + } +} + +#[pyo3::pyfunction] +fn generate_mldsa44_key() -> CryptographyResult { + let mut seed = [0u8; 32]; + cryptography_openssl::rand::rand_bytes(&mut seed)?; + let pkey = cryptography_openssl::mldsa::new_raw_private_key(MlDsaVariant::MlDsa44, &seed)?; + Ok(MlDsa44PrivateKey { pkey }) +} + +#[pyo3::pyfunction] +fn from_mldsa44_seed_bytes(data: CffiBuf<'_>) -> pyo3::PyResult { + let pkey = + cryptography_openssl::mldsa::new_raw_private_key(MlDsaVariant::MlDsa44, data.as_bytes()) + .map_err(|_| { + pyo3::exceptions::PyValueError::new_err("An ML-DSA-44 seed is 32 bytes long") + })?; + Ok(MlDsa44PrivateKey { pkey }) +} + +#[pyo3::pyfunction] +fn from_mldsa44_public_bytes(data: &[u8]) -> pyo3::PyResult { + let pkey = cryptography_openssl::mldsa::new_raw_public_key(MlDsaVariant::MlDsa44, data) + .map_err(|_| { + pyo3::exceptions::PyValueError::new_err("An ML-DSA-44 public key is 1312 bytes long") + })?; + Ok(MlDsa44PublicKey { pkey }) +} + +#[pyo3::pymethods] +impl MlDsa44PrivateKey { + #[pyo3(signature = (data, context=None))] + fn sign<'p>( + &self, + py: pyo3::Python<'p>, + data: CffiBuf<'_>, + context: Option>, + ) -> CryptographyResult> { + let ctx_bytes = context.as_ref().map_or(&[][..], |c| c.as_bytes()); + if ctx_bytes.len() > MAX_CONTEXT_BYTES { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Context must be at most 255 bytes"), + )); + } + let sig = cryptography_openssl::mldsa::sign(&self.pkey, data.as_bytes(), ctx_bytes)?; + Ok(pyo3::types::PyBytes::new(py, &sig)) + } + + fn public_key(&self) -> CryptographyResult { + let raw_bytes = self.pkey.raw_public_key()?; + Ok(MlDsa44PublicKey { + pkey: cryptography_openssl::mldsa::new_raw_public_key( + MlDsaVariant::MlDsa44, + &raw_bytes, + )?, + }) + } + + fn private_bytes_raw<'p>( + &self, + py: pyo3::Python<'p>, + ) -> CryptographyResult> { + let cryptography_key_parsing::pkcs8::MlDsaPrivateKey::Seed(seed) = + cryptography_key_parsing::pkcs8::mldsa_seed_from_pkey(&self.pkey)?; + Ok(pyo3::types::PyBytes::new(py, &seed)) + } + + fn private_bytes<'p>( + slf: &pyo3::Bound<'p, Self>, + py: pyo3::Python<'p>, + encoding: crate::serialization::Encoding, + format: crate::serialization::PrivateFormat, + encryption_algorithm: &pyo3::Bound<'p, pyo3::PyAny>, + ) -> CryptographyResult> { + if encoding == crate::serialization::Encoding::Raw + && format == crate::serialization::PrivateFormat::Raw + && encryption_algorithm.is_instance(&crate::types::NO_ENCRYPTION.get(py)?)? + { + return slf.borrow().private_bytes_raw(py); + } + utils::pkey_private_bytes( + py, + slf, + &slf.borrow().pkey, + encoding, + format, + encryption_algorithm, + true, + false, + ) + } + + fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> { + slf + } + + fn __deepcopy__<'p>( + slf: pyo3::PyRef<'p, Self>, + _memo: &pyo3::Bound<'p, pyo3::PyAny>, + ) -> pyo3::PyRef<'p, Self> { + slf + } +} + +#[pyo3::pymethods] +impl MlDsa44PublicKey { + #[pyo3(signature = (signature, data, context=None))] + fn verify( + &self, + signature: CffiBuf<'_>, + data: CffiBuf<'_>, + context: Option>, + ) -> CryptographyResult<()> { + let ctx_bytes = context.as_ref().map_or(&[][..], |c| c.as_bytes()); + if ctx_bytes.len() > MAX_CONTEXT_BYTES { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Context must be at most 255 bytes"), + )); + } + let valid = cryptography_openssl::mldsa::verify( + &self.pkey, + signature.as_bytes(), + data.as_bytes(), + ctx_bytes, + ) + .unwrap_or(false); + + if !valid { + return Err(CryptographyError::from( + exceptions::InvalidSignature::new_err(()), + )); + } + + Ok(()) + } + + fn public_bytes_raw<'p>( + &self, + py: pyo3::Python<'p>, + ) -> CryptographyResult> { + let raw_bytes = self.pkey.raw_public_key()?; + Ok(pyo3::types::PyBytes::new(py, &raw_bytes)) + } + + fn public_bytes<'p>( + slf: &pyo3::Bound<'p, Self>, + py: pyo3::Python<'p>, + encoding: crate::serialization::Encoding, + format: crate::serialization::PublicFormat, + ) -> CryptographyResult> { + utils::pkey_public_bytes(py, slf, &slf.borrow().pkey, encoding, format, true, true) + } + + fn __eq__(&self, other: pyo3::PyRef<'_, Self>) -> bool { + self.pkey.public_eq(&other.pkey) + } + + fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> { + slf + } + + fn __deepcopy__<'p>( + slf: pyo3::PyRef<'p, Self>, + _memo: &pyo3::Bound<'p, pyo3::PyAny>, + ) -> pyo3::PyRef<'p, Self> { + slf + } +} + #[pyo3::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.mldsa")] pub(crate) struct MlDsa65PrivateKey { pkey: openssl::pkey::PKey, @@ -212,7 +405,8 @@ impl MlDsa65PublicKey { pub(crate) mod mldsa { #[pymodule_export] use super::{ - from_mldsa65_public_bytes, from_mldsa65_seed_bytes, generate_mldsa65_key, - MlDsa65PrivateKey, MlDsa65PublicKey, + from_mldsa44_public_bytes, from_mldsa44_seed_bytes, from_mldsa65_public_bytes, + from_mldsa65_seed_bytes, generate_mldsa44_key, generate_mldsa65_key, MlDsa44PrivateKey, + MlDsa44PublicKey, MlDsa65PrivateKey, MlDsa65PublicKey, }; } diff --git a/tests/hazmat/primitives/test_mldsa.py b/tests/hazmat/primitives/test_mldsa.py index 8ddfd514df19..3b231efd72f2 100644 --- a/tests/hazmat/primitives/test_mldsa.py +++ b/tests/hazmat/primitives/test_mldsa.py @@ -13,6 +13,8 @@ from cryptography.exceptions import InvalidSignature, _Reasons from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.mldsa import ( + MlDsa44PrivateKey, + MlDsa44PublicKey, MlDsa65PrivateKey, MlDsa65PublicKey, ) @@ -35,6 +37,16 @@ class MlDsaVariant: ML_DSA_VARIANTS = [ + pytest.param( + MlDsaVariant( + private_key_class=MlDsa44PrivateKey, + public_key_class=MlDsa44PublicKey, + pub_key_size=1312, + sig_size=2420, + seed_size=32, + ), + id="ML-DSA-44", + ), pytest.param( MlDsaVariant( private_key_class=MlDsa65PrivateKey, @@ -50,9 +62,24 @@ class MlDsaVariant: @pytest.mark.supported( only_if=lambda backend: not backend.mldsa_supported(), - skip_message="Requires a backend without ML-DSA-65 support", + skip_message="Requires a backend without ML-DSA support", ) def test_mldsa_unsupported(backend): + with raises_unsupported_algorithm( + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM + ): + MlDsa44PublicKey.from_public_bytes(b"0" * 1312) + + with raises_unsupported_algorithm( + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM + ): + MlDsa44PrivateKey.from_seed_bytes(b"0" * 32) + + with raises_unsupported_algorithm( + _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM + ): + MlDsa44PrivateKey.generate() + with raises_unsupported_algorithm( _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM ): @@ -71,7 +98,7 @@ def test_mldsa_unsupported(backend): @pytest.mark.supported( only_if=lambda backend: backend.mldsa_supported(), - skip_message="Requires a backend with ML-DSA-65 support", + skip_message="Requires a backend with ML-DSA support", ) class TestMlDsa: @pytest.mark.parametrize("variant", ML_DSA_VARIANTS) @@ -109,7 +136,28 @@ def test_empty_context_equivalence(self, variant, backend): sig2 = key.sign(data, b"") pub.verify(sig2, data) - def test_kat_vectors(self, backend, subtests): + def test_kat_vectors_44(self, backend, subtests): + vectors = load_vectors_from_file( + os.path.join("asymmetric", "MLDSA", "kat_MLDSA_44_det_pure.rsp"), + load_nist_vectors, + ) + for vector in vectors: + with subtests.test(): + xi = binascii.unhexlify(vector["xi"]) + pk = binascii.unhexlify(vector["pk"]) + msg = binascii.unhexlify(vector["msg"]) + ctx = binascii.unhexlify(vector["ctx"]) + sm = binascii.unhexlify(vector["sm"]) + expected_sig = sm[:2420] + + key = MlDsa44PrivateKey.from_seed_bytes(xi) + assert key.private_bytes_raw() == xi + assert key.public_key().public_bytes_raw() == pk + + pub = MlDsa44PublicKey.from_public_bytes(pk) + pub.verify(expected_sig, msg, ctx) + + def test_kat_vectors_65(self, backend, subtests): vectors = load_vectors_from_file( os.path.join("asymmetric", "MLDSA", "kat_MLDSA_65_det_pure.rsp"), load_nist_vectors, @@ -358,26 +406,7 @@ def test_private_key_deepcopy(self, variant, backend): @pytest.mark.supported( only_if=lambda backend: backend.mldsa_supported(), - skip_message="Requires a backend with ML-DSA-65 support", -) -def test_unsupported_mldsa_variant_private_key(backend): - # ML-DSA-44 is not supported; loading it must raise UnsupportedAlgorithm. - pkcs8_der = load_vectors_from_file( - os.path.join("asymmetric", "MLDSA", "mldsa44_priv.der"), - lambda derfile: derfile.read(), - mode="rb", - ) - with raises_unsupported_algorithm( - _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM - ): - serialization.load_der_private_key( - pkcs8_der, password=None, backend=backend - ) - - -@pytest.mark.supported( - only_if=lambda backend: backend.mldsa_supported(), - skip_message="Requires a backend with ML-DSA-65 support", + skip_message="Requires a backend with ML-DSA support", ) def test_mldsa65_private_key_no_seed(backend): pkcs8_der = load_vectors_from_file( @@ -389,20 +418,3 @@ def test_mldsa65_private_key_no_seed(backend): serialization.load_der_private_key( pkcs8_der, password=None, backend=backend ) - - -@pytest.mark.supported( - only_if=lambda backend: backend.mldsa_supported(), - skip_message="Requires a backend with ML-DSA-65 support", -) -def test_unsupported_mldsa_variant_public_key(backend): - # ML-DSA-44 is not supported; loading it must raise UnsupportedAlgorithm. - spki_der = load_vectors_from_file( - os.path.join("asymmetric", "MLDSA", "mldsa44_pub.der"), - lambda derfile: derfile.read(), - mode="rb", - ) - with raises_unsupported_algorithm( - _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM - ): - serialization.load_der_public_key(spki_der, backend=backend) diff --git a/tests/wycheproof/test_mldsa.py b/tests/wycheproof/test_mldsa.py index 48615356efcc..877607fb1f7b 100644 --- a/tests/wycheproof/test_mldsa.py +++ b/tests/wycheproof/test_mldsa.py @@ -8,6 +8,8 @@ from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric.mldsa import ( + MlDsa44PrivateKey, + MlDsa44PublicKey, MlDsa65PrivateKey, MlDsa65PublicKey, ) @@ -17,7 +19,78 @@ @pytest.mark.supported( only_if=lambda backend: backend.mldsa_supported(), - skip_message="Requires a backend with ML-DSA-65 support", + skip_message="Requires a backend with ML-DSA support", +) +@wycheproof_tests("mldsa_44_verify_test.json") +def test_mldsa44_verify(backend, wycheproof): + try: + pub = MlDsa44PublicKey.from_public_bytes( + binascii.unhexlify(wycheproof.testgroup["publicKey"]) + ) + except ValueError: + assert wycheproof.invalid + assert wycheproof.has_flag("IncorrectPublicKeyLength") + return + + msg = binascii.unhexlify(wycheproof.testcase["msg"]) + sig = binascii.unhexlify(wycheproof.testcase["sig"]) + has_ctx = "ctx" in wycheproof.testcase + ctx = binascii.unhexlify(wycheproof.testcase["ctx"]) if has_ctx else None + + if wycheproof.valid: + pub.verify(sig, msg, ctx) + else: + with pytest.raises( + ( + ValueError, + InvalidSignature, + ) + ): + pub.verify(sig, msg, ctx) + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA support", +) +@wycheproof_tests("mldsa_44_sign_seed_test.json") +def test_mldsa44_sign_seed(backend, wycheproof): + # Skip "Internal" tests, they use the inner method `Sign_internal` + # instead of `Sign` which we do not expose. + if wycheproof.has_flag("Internal"): + return + + seed = binascii.unhexlify(wycheproof.testgroup["privateSeed"]) + try: + key = MlDsa44PrivateKey.from_seed_bytes(seed) + except ValueError: + assert wycheproof.invalid + assert wycheproof.has_flag("IncorrectPrivateKeyLength") + return + pub = MlDsa44PublicKey.from_public_bytes( + binascii.unhexlify(wycheproof.testgroup["publicKey"]) + ) + + assert key.public_key() == pub + + msg = binascii.unhexlify(wycheproof.testcase["msg"]) + has_ctx = "ctx" in wycheproof.testcase + ctx = binascii.unhexlify(wycheproof.testcase["ctx"]) if has_ctx else None + + if wycheproof.valid or wycheproof.acceptable: + # Sign and verify round-trip. We don't compare exact signature + # bytes because some backends use hedged (randomized) signing. + sig = key.sign(msg, ctx) + pub.verify(sig, msg, ctx) + else: + with pytest.raises(ValueError): + assert has_ctx + key.sign(msg, ctx) + + +@pytest.mark.supported( + only_if=lambda backend: backend.mldsa_supported(), + skip_message="Requires a backend with ML-DSA support", ) @wycheproof_tests("mldsa_65_verify_test.json") def test_mldsa65_verify(backend, wycheproof): @@ -49,7 +122,7 @@ def test_mldsa65_verify(backend, wycheproof): @pytest.mark.supported( only_if=lambda backend: backend.mldsa_supported(), - skip_message="Requires a backend with ML-DSA-65 support", + skip_message="Requires a backend with ML-DSA support", ) @wycheproof_tests("mldsa_65_sign_seed_test.json") def test_mldsa65_sign_seed(backend, wycheproof):