diff --git a/src/arithmetic/bigint/exp.rs b/src/arithmetic/bigint/exp.rs index 96004199e7..939ab21cb8 100644 --- a/src/arithmetic/bigint/exp.rs +++ b/src/arithmetic/bigint/exp.rs @@ -153,15 +153,13 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>( let mut table = Buf::from(Uninit::from(table.as_mut())); // table[0] = base**0 (i.e. 1). - table.try_write_with::(num_limbs.get(), |_init, uninit| { - Ok( - One::write_mont_identity_assuming_full_upper_limb(uninit, m)? - .leak_limbs_into_mut_less_safe(), - ) + table.with_filled_and_unfilled_buf_checked(num_limbs.get(), |_init, uninit| { + One::write_mont_identity_assuming_full_upper_limb(uninit, m) + .map(elem::Mut::leak_limbs_into_mut_less_safe) })?; // table[1] = base*R == (base/R * RRR)/R - table.try_write_with(num_limbs.get(), |_init, uninit| { + table.with_filled_and_unfilled_buf_checked(num_limbs.get(), |_init, uninit| { limbs_mul_mont( ( uninit, @@ -178,7 +176,7 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>( let n = num_limbs.get(); // table[2*i] = (n**i)**2/R - table.try_write_with(n, |init, uninit| { + table.with_filled_and_unfilled_buf_checked(n, |init, uninit| { let sqrt_start = init.len() / 2; let sqrt = init .get(sqrt_start..(sqrt_start + n)) @@ -187,7 +185,7 @@ fn elem_exp_consttime_inner<'out, N, M, const STORAGE_LIMBS: usize>( })?; // table[2*i + 1] = (n**1)*(n**(2*i))/R - table.try_write_with(n, |init, uninit| { + table.with_filled_and_unfilled_buf_checked(n, |init, uninit| { let one = init.get(n..(n + n)).unwrap_or_else(|| unreachable!()); let previous = init .get((init.len() - n)..) diff --git a/src/arithmetic/bigint/modulus/mont.rs b/src/arithmetic/bigint/modulus/mont.rs index f5de8c7c45..2867b160be 100644 --- a/src/arithmetic/bigint/modulus/mont.rs +++ b/src/arithmetic/bigint/modulus/mont.rs @@ -137,7 +137,7 @@ impl ValidatedInput<'_> { if out.capacity() < storage_num_limbs { return Err(LenMismatchError::new(out.capacity())); } - out.with_unfilled_buf(|out| { + out.with_unfilled_buf_checked(|out| { // We can't compute `n0` until after we've written `value`. out.unfilled().write_repeat(limb::ZERO, N0::LIMBS_USED)?; out.unfilled().write(limb::limb_from_usize(num_limbs))?; @@ -149,7 +149,7 @@ impl ValidatedInput<'_> { let (_num_limbs, value) = rest.split_first().unwrap_or_else(|| unreachable!()); // Since we just wrote it. N0::write_into(n0, value); - out.try_write_with(num_limbs, |init, uninit| { + out.with_filled_and_unfilled_buf_checked(num_limbs, |init, uninit| { let m = &Mont::<'_, M>::from_storage_unchecked_less_safe(init, cpu); let r: elem::Mut<'_, _, RR> = One::write_mont_identity(uninit, m, self.len_bits())?.mul_r(m)?; // in place. diff --git a/src/polyfill/uninit_slice.rs b/src/polyfill/uninit_slice.rs index bb39b8baa3..6ee1b17bbf 100644 --- a/src/polyfill/uninit_slice.rs +++ b/src/polyfill/uninit_slice.rs @@ -313,14 +313,11 @@ impl<'target, E: Copy> Buf<'target, E> { /// Reserves the first `len` bytes of the unfilled space as `to_fill`, then /// calls `f(filled, to_fill)` where `filled` is the filled space. If `f` /// returns `Ok(filled)`, then `filled` must be `to_fill`, filled in. - pub fn try_write_with( + pub fn with_filled_and_unfilled_buf_checked>( &mut self, len: usize, f: impl for<'a> FnOnce(&mut [E], Uninit<'a, E>) -> Result<&'a mut [E], Err>, - ) -> Result<(), Err> - where - Err: From, - { + ) -> Result<(), Err> { let (filled, mut unfilled) = self.split_at_spare_mut(); let Some(to_fill) = unfilled.get_mut(..len) else { return Err(LenMismatchError::new(unfilled.len()))?; @@ -387,17 +384,21 @@ impl Cursor<'_, '_, E> { /// See `core::io::BorrowedCursor::with_unfilled_buf`. /// - /// # Panics - /// - /// Panics if `f` replaces `Buf` with a different one. - pub fn with_unfilled_buf(&mut self, f: impl FnOnce(&mut Buf<'_, E>) -> R) -> R { + /// `f` must not replace the `Buf` it is given with a different one; + /// i.e. it must not assign into the mutable reference it is given. + pub fn with_unfilled_buf_checked>( + &mut self, + f: impl FnOnce(&mut Buf<'_, E>) -> Result, + ) -> Result { let mut buf = Buf::from(self.buf.unfilled_uninit()); let ptr_and_len: *const [MaybeUninit] = ptr::from_ref(buf.storage.target); let res = f(&mut buf); // It would be OK if the length became shorter, but there's no reason // for us to support that case. It is important that the address is the // same and the length isn't longer. - assert!(ptr::eq(buf.storage.target, ptr_and_len)); + if !ptr::eq(buf.storage.target, ptr_and_len) { + Err(LenMismatchError::new(buf.storage.target.len()))?; + } debug_assert!(buf.filled <= buf.storage.len()); // invariant self.buf.filled += buf.filled; // The above assertions ensure our invariant is maintained.