From 15ef0075429b07cb81ee36fb49c388101eb1b3c0 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Wed, 20 May 2026 10:00:22 -0700 Subject: [PATCH] bigint: Add `elem::Ref::{mul,squared}`. --- src/arithmetic/bigint/elem.rs | 55 ++++++++++++++++++++-------- src/arithmetic/bigint/exp.rs | 47 +++++++++++------------- src/arithmetic/bigint/modulus/one.rs | 4 +- 3 files changed, 64 insertions(+), 42 deletions(-) diff --git a/src/arithmetic/bigint/elem.rs b/src/arithmetic/bigint/elem.rs index b078cd944d..e1fabb0263 100644 --- a/src/arithmetic/bigint/elem.rs +++ b/src/arithmetic/bigint/elem.rs @@ -16,8 +16,8 @@ use crate::polyfill::prelude::*; use super::{ - super::montgomery::*, IntoMont, Mont, unwrap_impossible_len_mismatch_error, - unwrap_impossible_limb_slice_error, + super::{LimbSliceError, montgomery::*}, + IntoMont, Mont, unwrap_impossible_len_mismatch_error, unwrap_impossible_limb_slice_error, }; use crate::{ c, cpu, @@ -29,6 +29,7 @@ use crate::{ }, }; use core::{marker::PhantomData, num::NonZero}; + // TODO: Move here? pub(crate) use super::oversized_uninit::OversizedUninit; @@ -128,6 +129,7 @@ impl<'l, M, E> Mut<'l, M, E> { Ref::assume_in_range_and_encoded_less_safe(self.limbs) } + #[cfg(target_arch = "x86_64")] pub(super) fn leak_limbs_less_safe(&self) -> &[Limb] { self.limbs } @@ -188,7 +190,7 @@ impl<'l, M, E> Ref<'l, M, E> { self.limbs.len() } - #[cfg(test)] + #[cfg(any(test, target_arch = "x86_64"))] pub(super) fn leak_limbs_less_safe(&self) -> &[Limb] { self.limbs } @@ -303,18 +305,7 @@ impl<'l, M, E> Mut<'l, M, E> { where (E, OE): ProductEncoding, { - let oneRR = im.one(); - let m = im.modulus(cpu); - - let in_out = self.limbs; - let _: &[Limb] = limbs_mul_mont( - (InOut(&mut *in_out), oneRR.leak_limbs_less_safe()), - m.limbs(), - m.n0(), - m.cpu_features(), - ) - .unwrap_or_else(unwrap_impossible_limb_slice_error); - Mut::assume_in_range_and_encoded_less_safe(in_out) + self.mul(im.one().as_ref(), &im.modulus(cpu)) } } @@ -369,6 +360,40 @@ impl Ref<'_, M, Unencoded> { } } +#[cfg_attr(target_arch = "x86_64", expect(dead_code))] +impl Ref<'_, M, E> { + pub fn mul<'r, BE>( + self, + r: Uninit<'r, Limb>, + b: Ref, + m: &Mont, + ) -> Result::Output>, LimbSliceError> + where + (E, BE): ProductEncoding, + { + let r = limbs_mul_mont( + (r, self.limbs, b.limbs), + m.limbs(), + m.n0(), + m.cpu_features(), + )?; + Ok(Mut::assume_in_range_and_encoded_less_safe(r)) + } + + #[inline] + pub fn squared<'r>( + self, + r: Uninit<'r, Limb>, + m: &Mont, + ) -> Result::Output>, LimbSliceError> + where + (E, E): ProductEncoding, + { + let r = limbs_square_mont((r, self.limbs), m.limbs(), m.n0(), m.cpu_features())?; + Ok(Mut::assume_in_range_and_encoded_less_safe(r)) + } +} + impl<'l, M> Ref<'l, M, Unencoded> { // Assumes self < S*R_S where R_S is the Montgomery 1 for S, and returns // self/R mod S. (When P*Q=M and P and Q have the same bit length, diff --git a/src/arithmetic/bigint/exp.rs b/src/arithmetic/bigint/exp.rs index 63f32a9d01..a8507d924a 100644 --- a/src/arithmetic/bigint/exp.rs +++ b/src/arithmetic/bigint/exp.rs @@ -93,10 +93,7 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>( m: &Mont, tmp: &mut elem::OversizedUninit, ) -> Result, LimbSliceError> { - use super::{ - super::montgomery::{R, limbs_mul_mont, limbs_square_mont}, - elem, - }; + use super::{super::montgomery::R, elem}; use crate::{ bssl, c, error, polyfill::{StartMutPtr, slice::Buf}, @@ -160,16 +157,10 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>( // table[1] = base*R == (base/R * RRR)/R table.write_with(num_limbs.get(), |_init, uninit| { - limbs_mul_mont( - ( - uninit, - base_rinverse.leak_limbs_less_safe(), - oneRRR.leak_limbs_less_safe(), - ), - m.limbs(), - m.n0(), - m.cpu_features(), - ) + base_rinverse + .as_ref() + .mul(uninit, oneRRR.as_ref(), m) + .map(elem::Mut::leak_limbs_into_mut_less_safe) })?; for _i in 1..16 { @@ -178,19 +169,25 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>( // table[2*i] = (n**i)**2/R table.write_with(n, |init, uninit| { let sqrt_start = init.len() / 2; - let sqrt = init - .get(sqrt_start..(sqrt_start + n)) - .unwrap_or_else(|| unreachable!()); - limbs_square_mont((uninit, sqrt), m.limbs(), m.n0(), m.cpu_features()) + let sqrt = elem::Ref::<'_, M, R>::assume_in_range_and_encoded_less_safe( + init.get(sqrt_start..(sqrt_start + n)) + .unwrap_or_else(|| unreachable!()), + ); + sqrt.squared(uninit, m) + .map(elem::Mut::leak_limbs_into_mut_less_safe) })?; // table[2*i + 1] = (n**1)*(n**(2*i))/R table.write_with(n, |init, uninit| { - let one = init.get(n..(n + n)).unwrap_or_else(|| unreachable!()); - let previous = init - .get((init.len() - n)..) - .unwrap_or_else(|| unreachable!()); - limbs_mul_mont((uninit, one, previous), m.limbs(), m.n0(), m.cpu_features()) + let one = elem::Ref::<'_, M, R>::assume_in_range_and_encoded_less_safe( + init.get(n..(n + n)).unwrap_or_else(|| unreachable!()), + ); + let previous = elem::Ref::<'_, M, R>::assume_in_range_and_encoded_less_safe( + init.get((init.len() - n)..) + .unwrap_or_else(|| unreachable!()), + ); + one.mul(uninit, previous, m) + .map(elem::Mut::leak_limbs_into_mut_less_safe) })?; } let table: &[Limb] = table.into_filled(); @@ -261,7 +258,7 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>( .ok_or_else(|| LenMismatchError::new(m_len))? .len(); - let oneRRR = oneRRR.leak_limbs_less_safe(); + let oneRRR = oneRRR.as_ref(); // The x86_64 assembly was written under the assumption that the input data // is aligned to `MOD_EXP_CTIME_ALIGN` bytes, which was/is 64 in OpenSSL. @@ -307,7 +304,7 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>( let base_cached: &[Limb] = mul_mont5( base_cached.into(), base_rinverse, - oneRRR, + oneRRR.leak_limbs_less_safe(), m_cached, n0, cpu2, diff --git a/src/arithmetic/bigint/modulus/one.rs b/src/arithmetic/bigint/modulus/one.rs index c472a044cc..10faa6cfb8 100644 --- a/src/arithmetic/bigint/modulus/one.rs +++ b/src/arithmetic/bigint/modulus/one.rs @@ -43,8 +43,8 @@ impl One<'_, M, E> { } } - pub(in super::super) fn leak_limbs_less_safe(&self) -> &[Limb] { - self.value + pub(in super::super) fn as_ref(&self) -> elem::Ref<'_, M, E> { + elem::Ref::assume_in_range_and_encoded_less_safe(self.value) } }