Skip to content

Commit 1f1303a

Browse files
committed
Additional round of review
1 parent 4b8430c commit 1f1303a

File tree

3 files changed

+22
-23
lines changed

3 files changed

+22
-23
lines changed

src/rust/cryptography-key-parsing/src/pkcs8.rs

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,24 @@ pub struct PrivateKeyInfo<'a> {
2424

2525
// RFC 9881 Section 6.5
2626
#[cfg(CRYPTOGRAPHY_IS_AWSLC)]
27-
// NO-COVERAGE-START
2827
#[derive(asn1::Asn1Read, asn1::Asn1Write)]
29-
// NO-COVERAGE-END
30-
pub enum MlDsaPrivateKey<'a> {
28+
pub enum MlDsaPrivateKey {
3129
#[implicit(0)]
32-
Seed(&'a [u8]),
30+
Seed([u8; 32]),
31+
}
32+
33+
/// Extract the 32-byte ML-DSA-65 seed from a private key.
34+
///
35+
/// AWS-LC's `raw_private_key()` returns the expanded key, not the seed.
36+
/// This function round-trips through the native PKCS#8 encoding to extract it.
37+
/// https://github.com/aws/aws-lc/issues/3072
38+
#[cfg(CRYPTOGRAPHY_IS_AWSLC)]
39+
pub fn mldsa_seed_from_pkey(
40+
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
41+
) -> Result<MlDsaPrivateKey, openssl::error::ErrorStack> {
42+
let pkcs8_der = pkey.private_key_to_pkcs8()?;
43+
let pki = asn1::parse_single::<PrivateKeyInfo<'_>>(&pkcs8_der).unwrap();
44+
Ok(asn1::parse_single::<MlDsaPrivateKey>(pki.private_key).unwrap())
3345
}
3446

3547
pub fn parse_private_key(
@@ -120,9 +132,8 @@ pub fn parse_private_key(
120132

121133
#[cfg(CRYPTOGRAPHY_IS_AWSLC)]
122134
AlgorithmParameters::MlDsa65 => {
123-
let MlDsaPrivateKey::Seed(seed) =
124-
asn1::parse_single::<MlDsaPrivateKey<'_>>(k.private_key)?;
125-
Ok(cryptography_openssl::mldsa::new_raw_private_key(seed)?)
135+
let MlDsaPrivateKey::Seed(seed) = asn1::parse_single::<MlDsaPrivateKey>(k.private_key)?;
136+
Ok(cryptography_openssl::mldsa::new_raw_private_key(&seed)?)
126137
}
127138

128139
_ => Err(KeyParsingError::UnsupportedKeyType(
@@ -462,8 +473,7 @@ pub fn serialize_private_key(
462473
}
463474
#[cfg(CRYPTOGRAPHY_IS_AWSLC)]
464475
cryptography_openssl::mldsa::PKEY_ID => {
465-
let seed = pkey.raw_private_key()?;
466-
let private_key_der = asn1::write_single(&MlDsaPrivateKey::Seed(seed.as_slice()))?;
476+
let private_key_der = asn1::write_single(&mldsa_seed_from_pkey(pkey)?)?;
467477
(AlgorithmParameters::MlDsa65, private_key_der)
468478
}
469479
_ => {

src/rust/src/backend/mldsa.rs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,9 @@ impl MlDsa65PrivateKey {
9191
&self,
9292
py: pyo3::Python<'p>,
9393
) -> CryptographyResult<pyo3::Bound<'p, pyo3::types::PyBytes>> {
94-
// AWS-LC's raw_private_key() returns the expanded key, not the seed.
95-
// Round-trip through PKCS#8 DER to extract the 32-byte seed.
96-
// Note: private_key_to_pkcs8() (i2d_PKCS8PrivateKey_bio) must be used
97-
// instead of private_key_to_der() (i2d_PrivateKey), because AWS-LC's
98-
// i2d_PrivateKey doesn't support PQDSA keys.
99-
let pkcs8_der = self.pkey.private_key_to_pkcs8()?;
100-
let pki =
101-
asn1::parse_single::<cryptography_key_parsing::pkcs8::PrivateKeyInfo<'_>>(&pkcs8_der)
102-
.unwrap();
10394
let cryptography_key_parsing::pkcs8::MlDsaPrivateKey::Seed(seed) =
104-
asn1::parse_single::<cryptography_key_parsing::pkcs8::MlDsaPrivateKey<'_>>(
105-
pki.private_key,
106-
)
107-
.unwrap();
108-
Ok(pyo3::types::PyBytes::new(py, seed))
95+
cryptography_key_parsing::pkcs8::mldsa_seed_from_pkey(&self.pkey)?;
96+
Ok(pyo3::types::PyBytes::new(py, &seed))
10997
}
11098

11199
fn private_bytes<'p>(

tests/hazmat/primitives/test_mldsa.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def test_round_trip_private_serialization(
160160
serialized = key.private_bytes(encoding, fmt, encryption)
161161
loaded_key = load_func(serialized, passwd, backend)
162162
assert isinstance(loaded_key, MlDsa65PrivateKey)
163+
assert loaded_key.private_bytes_raw() == key.private_bytes_raw()
163164
sig = loaded_key.sign(b"test data")
164165
key.public_key().verify(sig, b"test data")
165166

0 commit comments

Comments
 (0)