Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/arithmetic/bigint/exp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,13 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>(
let mut table = Buf::from(Uninit::from(table.as_mut()));

// table[0] = base**0 (i.e. 1).
table.try_write_with::<LenMismatchError>(num_limbs.get(), |_init, uninit| {
Ok(
One::write_mont_identity_assuming_full_upper_limb(uninit, m)?
.leak_limbs_into_mut_less_safe(),
)
table.with_filled_and_unfilled_buf_checked(num_limbs.get(), |_init, uninit| {
One::write_mont_identity_assuming_full_upper_limb(uninit, m)
.map(elem::Mut::leak_limbs_into_mut_less_safe)
})?;

// table[1] = base*R == (base/R * RRR)/R
table.try_write_with(num_limbs.get(), |_init, uninit| {
table.with_filled_and_unfilled_buf_checked(num_limbs.get(), |_init, uninit| {
limbs_mul_mont(
(
uninit,
Expand All @@ -178,7 +176,7 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>(
let n = num_limbs.get();

// table[2*i] = (n**i)**2/R
table.try_write_with(n, |init, uninit| {
table.with_filled_and_unfilled_buf_checked(n, |init, uninit| {
let sqrt_start = init.len() / 2;
let sqrt = init
.get(sqrt_start..(sqrt_start + n))
Expand All @@ -187,7 +185,7 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>(
})?;

// table[2*i + 1] = (n**1)*(n**(2*i))/R
table.try_write_with(n, |init, uninit| {
table.with_filled_and_unfilled_buf_checked(n, |init, uninit| {
let one = init.get(n..(n + n)).unwrap_or_else(|| unreachable!());
let previous = init
.get((init.len() - n)..)
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/bigint/modulus/mont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl ValidatedInput<'_> {
if out.capacity() < storage_num_limbs {
return Err(LenMismatchError::new(out.capacity()));
}
out.with_unfilled_buf(|out| {
out.with_unfilled_buf_checked(|out| {
// We can't compute `n0` until after we've written `value`.
out.unfilled().write_repeat(limb::ZERO, N0::LIMBS_USED)?;
out.unfilled().write(limb::limb_from_usize(num_limbs))?;
Expand All @@ -149,7 +149,7 @@ impl ValidatedInput<'_> {
let (_num_limbs, value) = rest.split_first().unwrap_or_else(|| unreachable!()); // Since we just wrote it.
N0::write_into(n0, value);

out.try_write_with(num_limbs, |init, uninit| {
out.with_filled_and_unfilled_buf_checked(num_limbs, |init, uninit| {
let m = &Mont::<'_, M>::from_storage_unchecked_less_safe(init, cpu);
let r: elem::Mut<'_, _, RR> =
One::write_mont_identity(uninit, m, self.len_bits())?.mul_r(m)?; // in place.
Expand Down
21 changes: 11 additions & 10 deletions src/polyfill/uninit_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,11 @@ impl<'target, E: Copy> Buf<'target, E> {
/// Reserves the first `len` bytes of the unfilled space as `to_fill`, then
/// calls `f(filled, to_fill)` where `filled` is the filled space. If `f`
/// returns `Ok(filled)`, then `filled` must be `to_fill`, filled in.
pub fn try_write_with<Err>(
pub fn with_filled_and_unfilled_buf_checked<Err: From<LenMismatchError>>(
&mut self,
len: usize,
f: impl for<'a> FnOnce(&mut [E], Uninit<'a, E>) -> Result<&'a mut [E], Err>,
) -> Result<(), Err>
where
Err: From<LenMismatchError>,
{
) -> Result<(), Err> {
let (filled, mut unfilled) = self.split_at_spare_mut();
let Some(to_fill) = unfilled.get_mut(..len) else {
return Err(LenMismatchError::new(unfilled.len()))?;
Expand Down Expand Up @@ -387,17 +384,21 @@ impl<E: Copy> Cursor<'_, '_, E> {

/// See `core::io::BorrowedCursor::with_unfilled_buf`.
///
/// # Panics
///
/// Panics if `f` replaces `Buf` with a different one.
pub fn with_unfilled_buf<R>(&mut self, f: impl FnOnce(&mut Buf<'_, E>) -> R) -> R {
/// `f` must not replace the `Buf` it is given with a different one;
/// i.e. it must not assign into the mutable reference it is given.
pub fn with_unfilled_buf_checked<R, Err: From<LenMismatchError>>(
&mut self,
f: impl FnOnce(&mut Buf<'_, E>) -> Result<R, Err>,
) -> Result<R, Err> {
let mut buf = Buf::from(self.buf.unfilled_uninit());
let ptr_and_len: *const [MaybeUninit<E>] = ptr::from_ref(buf.storage.target);
let res = f(&mut buf);
// It would be OK if the length became shorter, but there's no reason
// for us to support that case. It is important that the address is the
// same and the length isn't longer.
assert!(ptr::eq(buf.storage.target, ptr_and_len));
if !ptr::eq(buf.storage.target, ptr_and_len) {
Err(LenMismatchError::new(buf.storage.target.len()))?;
}
debug_assert!(buf.filled <= buf.storage.len()); // invariant
self.buf.filled += buf.filled;
// The above assertions ensure our invariant is maintained.
Expand Down
Loading