From 845af40815bfccd05aac4f12a7dbdcf0144c57a6 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 20 May 2026 10:53:14 -0700 Subject: [PATCH] bigint: Clarify what `elem::Ref::{mul,squared}` can return. --- src/arithmetic/bigint.rs | 13 +++++++++++-- src/arithmetic/bigint/elem.rs | 20 +++++++++++--------- src/arithmetic/bigint/exp.rs | 4 ++-- src/arithmetic/ffi.rs | 4 ++-- src/arithmetic/limbs/aarch64/mont.rs | 5 +++-- src/arithmetic/limbs/x86_64/mont.rs | 5 +++-- src/arithmetic/mod.rs | 10 ++++++++-- src/polyfill/cold_error.rs | 3 +++ 8 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 388bb2aa41..253e9ecec0 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -87,8 +87,17 @@ fn unwrap_impossible_len_mismatch_error(LenMismatchError { .. }: LenMismatchE fn unwrap_impossible_limb_slice_error(err: LimbSliceError) -> T { match err { LimbSliceError::LenMismatch(_) => unreachable!(), - LimbSliceError::TooShort(_) => unreachable!(), - LimbSliceError::TooLong(_) => unreachable!(), + LimbSliceError::ModulusTooShort(_) => unreachable!(), + LimbSliceError::ModulusTooLong(_) => unreachable!(), + } +} + +#[cold] +fn limb_slice_error_must_be_len_mismatch_error(err: LimbSliceError) -> LenMismatchError { + match err { + LimbSliceError::LenMismatch(err) => err, + LimbSliceError::ModulusTooLong(_) => unreachable!(), // since `m: Mont`. + LimbSliceError::ModulusTooShort(_) => unreachable!(), // since `m: Mont`. } } diff --git a/src/arithmetic/bigint/elem.rs b/src/arithmetic/bigint/elem.rs index e1fabb0263..d3f16ce1e8 100644 --- a/src/arithmetic/bigint/elem.rs +++ b/src/arithmetic/bigint/elem.rs @@ -16,8 +16,8 @@ use crate::polyfill::prelude::*; use super::{ - super::{LimbSliceError, montgomery::*}, - IntoMont, Mont, unwrap_impossible_len_mismatch_error, unwrap_impossible_limb_slice_error, + super::montgomery::*, IntoMont, Mont, limb_slice_error_must_be_len_mismatch_error, + unwrap_impossible_len_mismatch_error, unwrap_impossible_limb_slice_error, }; use crate::{ c, cpu, @@ -367,17 +367,18 @@ impl Ref<'_, M, E> { r: Uninit<'r, Limb>, b: Ref, m: &Mont, - ) -> Result::Output>, LimbSliceError> + ) -> Result::Output>, LenMismatchError> where (E, BE): ProductEncoding, { - let r = limbs_mul_mont( + limbs_mul_mont( (r, self.limbs, b.limbs), m.limbs(), m.n0(), m.cpu_features(), - )?; - Ok(Mut::assume_in_range_and_encoded_less_safe(r)) + ) + .map(Mut::assume_in_range_and_encoded_less_safe) + .map_err(limb_slice_error_must_be_len_mismatch_error) // because `m: Mont` } #[inline] @@ -385,12 +386,13 @@ impl Ref<'_, M, E> { self, r: Uninit<'r, Limb>, m: &Mont, - ) -> Result::Output>, LimbSliceError> + ) -> Result::Output>, LenMismatchError> where (E, E): ProductEncoding, { - let r = limbs_square_mont((r, self.limbs), m.limbs(), m.n0(), m.cpu_features())?; - Ok(Mut::assume_in_range_and_encoded_less_safe(r)) + limbs_square_mont((r, self.limbs), m.limbs(), m.n0(), m.cpu_features()) + .map(Mut::assume_in_range_and_encoded_less_safe) + .map_err(limb_slice_error_must_be_len_mismatch_error) // because `m: Mont` } } diff --git a/src/arithmetic/bigint/exp.rs b/src/arithmetic/bigint/exp.rs index a8507d924a..ffc2f269b4 100644 --- a/src/arithmetic/bigint/exp.rs +++ b/src/arithmetic/bigint/exp.rs @@ -461,8 +461,8 @@ mod tests { match actual_result { Ok(r) => assert_elem_eq(r.as_ref(), expected_result.as_ref()), Err(LimbSliceError::LenMismatch { .. }) => panic!(), - Err(LimbSliceError::TooLong { .. }) => panic!(), - Err(LimbSliceError::TooShort { .. }) => panic!(), + Err(LimbSliceError::ModulusTooLong { .. }) => panic!(), + Err(LimbSliceError::ModulusTooShort { .. }) => panic!(), }; Ok(()) diff --git a/src/arithmetic/ffi.rs b/src/arithmetic/ffi.rs index efbaa5e3ca..51e543483b 100644 --- a/src/arithmetic/ffi.rs +++ b/src/arithmetic/ffi.rs @@ -86,7 +86,7 @@ pub(super) unsafe fn bn_mul_mont_ffi<'o, Cpu, const LEN_MIN: usize, const LEN_MO assert_eq!(n.len() % LEN_MOD, 0); // The caller should guard against this. assert!(LEN_MIN >= MIN_LIMBS); if n.len() < LEN_MIN { - return Err(LimbSliceError::too_short(n.len())); + return Err(LimbSliceError::modulus_too_short(n.len())); } let len = NonZero::new(n.len()).unwrap_or_else(|| { // Unreachable because we checked against `LEN_MIN`, and we checked @@ -98,7 +98,7 @@ pub(super) unsafe fn bn_mul_mont_ffi<'o, Cpu, const LEN_MIN: usize, const LEN_MO // `2*len` + a non-trivial fixed amount. if len.get() > MAX_LIMBS { - return Err(LimbSliceError::too_long(n.len())); + return Err(LimbSliceError::modulus_too_long(n.len())); } let r = in_out.with_non_dangling_non_null_pointers(len, |mut r, [a, b]| { let n = n.as_ptr(); diff --git a/src/arithmetic/limbs/aarch64/mont.rs b/src/arithmetic/limbs/aarch64/mont.rs index e086042f0d..1c02d7e373 100644 --- a/src/arithmetic/limbs/aarch64/mont.rs +++ b/src/arithmetic/limbs/aarch64/mont.rs @@ -73,7 +73,8 @@ pub(in super::super::super) fn sqr_mont5<'o>( } let n = n.as_flattened(); - let num_limbs = NonZero::new(n.len()).ok_or_else(|| LimbSliceError::too_short(n.len()))?; + let num_limbs = + NonZero::new(n.len()).ok_or_else(|| LimbSliceError::modulus_too_short(n.len()))?; // Avoid stack overflow from the alloca inside. // @@ -82,7 +83,7 @@ pub(in super::super::super) fn sqr_mont5<'o>( // that we don't have to precisely audit the code. const _CHKSTK_NOT_NEEDED: () = _TWICE_MAX_LIMBS_LE_3KB; if num_limbs.get() > MAX_LIMBS { - return Err(LimbSliceError::too_long(num_limbs.get())); + return Err(LimbSliceError::modulus_too_long(num_limbs.get())); } let r = in_out.with_non_dangling_non_null_pointers(num_limbs, |mut r, [a]| { diff --git a/src/arithmetic/limbs/x86_64/mont.rs b/src/arithmetic/limbs/x86_64/mont.rs index 0a5981e41b..cd999ab8af 100644 --- a/src/arithmetic/limbs/x86_64/mont.rs +++ b/src/arithmetic/limbs/x86_64/mont.rs @@ -96,11 +96,12 @@ pub(in super::super::super) fn sqr_mont5<'o>( } let n = n.as_flattened(); - let num_limbs = NonZero::new(n.len()).ok_or_else(|| LimbSliceError::too_short(n.len()))?; + let num_limbs = + NonZero::new(n.len()).ok_or_else(|| LimbSliceError::modulus_too_short(n.len()))?; // Avoid stack overflow from the alloca inside. if num_limbs.get() > MAX_LIMBS { - return Err(LimbSliceError::too_long(num_limbs.get())); + return Err(LimbSliceError::modulus_too_long(num_limbs.get())); } // `Limb::from(mulx_adx.is_some())`, but intentionally branchy. diff --git a/src/arithmetic/mod.rs b/src/arithmetic/mod.rs index d38ca68c89..dc16a37a10 100644 --- a/src/arithmetic/mod.rs +++ b/src/arithmetic/mod.rs @@ -40,8 +40,14 @@ pub const MAX_LIMBS: usize = 8192 / LIMB_BITS; cold_exhaustive_error! { enum limb_slice_error::LimbSliceError { len_mismatch => LenMismatch(LenMismatchError), - too_short => TooShort(usize), - too_long => TooLong(usize), + /// "Too short" checks should only be done against the modulus, + /// not against other inputs. Callers rely on this + /// to reject these cases as impossible, if they've already + /// checked the modulus length. + modulus_too_short => ModulusTooShort(usize), + /// "Too long" checks should only be done against the modulus, + /// for the same reason as "too short" checks. + modulus_too_long => ModulusTooLong(usize), } } diff --git a/src/polyfill/cold_error.rs b/src/polyfill/cold_error.rs index 13fc9f6baa..3e4b48f33a 100644 --- a/src/polyfill/cold_error.rs +++ b/src/polyfill/cold_error.rs @@ -73,6 +73,7 @@ macro_rules! cold_exhaustive_error { { enum $mod_name:ident::$Error:ident { $( + $( #[$meta:meta] )* $constructor:ident => $Variant:ident($ValueType:ty), )+ } @@ -83,12 +84,14 @@ macro_rules! cold_exhaustive_error { pub enum $Error { $( + $( #[$meta] )* $Variant(#[allow(dead_code)] $ValueType) ),+ } impl $Error { $( + $( #[$meta] )* #[cold] #[inline(never)] pub(super) fn $constructor(value: $ValueType) -> Self {