From 03a8cab626dccdd75ce93abd88fdf8f900c1799d Mon Sep 17 00:00:00 2001 From: Onyeka Obi Date: Thu, 2 Apr 2026 06:27:39 -0700 Subject: [PATCH] Fix uninit_vector UB by returning Vec> uninit_vector previously returned Vec with uninitialized memory, which is undefined behavior (especially for types implementing Drop). Change return type to Vec> and add assume_init_vec helper for zero-cost conversion after initialization. Update all ~30 call sites across math, crypto, fri, prover, and examples crates. Closes #396 --- crypto/benches/merkle.rs | 9 ++-- crypto/src/merkle/concurrent.rs | 20 +++++--- crypto/src/merkle/mod.rs | 15 +++--- .../src/rescue_raps/custom_trace_table.rs | 5 +- examples/src/rescue_raps/prover.rs | 7 +-- fri/src/folding/mod.rs | 9 ++-- fri/src/prover/mod.rs | 10 ++-- math/src/fft/concurrent.rs | 10 +++- math/src/fft/serial.rs | 12 +++-- math/src/polynom/mod.rs | 8 +-- math/src/utils/mod.rs | 31 ++++++++---- prover/src/constraints/evaluation_table.rs | 14 ++++-- .../constraints/evaluator/periodic_table.rs | 8 +-- prover/src/matrix/col_matrix.rs | 10 ++-- prover/src/matrix/row_matrix.rs | 49 +++++++++++-------- prover/src/matrix/segments.rs | 7 +-- prover/src/trace/trace_table.rs | 5 +- utils/core/src/lib.rs | 36 ++++++++++---- 18 files changed, 168 insertions(+), 97 deletions(-) diff --git a/crypto/benches/merkle.rs b/crypto/benches/merkle.rs index c44db8af3..f14961e46 100644 --- a/crypto/benches/merkle.rs +++ b/crypto/benches/merkle.rs @@ -6,7 +6,8 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use math::fields::f128::BaseElement; use rand_utils::rand_value; -use utils::uninit_vector; +use core::mem::MaybeUninit; +use utils::{assume_init_vec, uninit_vector}; use winter_crypto::{build_merkle_nodes, concurrent, hashers::Blake3_256, Hasher}; type Blake3 = Blake3_256; @@ -20,11 +21,11 @@ pub fn merkle_tree_construction(c: &mut Criterion) { for size in &BATCH_SIZES { let data: Vec = { - let mut res = unsafe { uninit_vector(*size) }; + let mut res = uninit_vector(*size); for i in 0..*size { - res[i] = Blake3::hash(&rand_value::().to_le_bytes()); + res[i] = MaybeUninit::new(Blake3::hash(&rand_value::().to_le_bytes())); } - res + unsafe { assume_init_vec(res) } }; merkle_group.bench_with_input(BenchmarkId::new("sequential", size), &data, |b, i| { b.iter(|| build_merkle_nodes::(i)) diff --git a/crypto/src/merkle/concurrent.rs b/crypto/src/merkle/concurrent.rs index 66696c174..296783eae 100644 --- a/crypto/src/merkle/concurrent.rs +++ b/crypto/src/merkle/concurrent.rs @@ -24,11 +24,13 @@ pub const MIN_CONCURRENT_LEAVES: usize = 1024; /// results in a single vector such that root of the tree is at position 1, nodes immediately /// under the root is at positions 2 and 3 etc. pub fn build_merkle_nodes(leaves: &[H::Digest]) -> Vec { + use core::mem::MaybeUninit; + let n = leaves.len() / 2; // create un-initialized array to hold all intermediate nodes - let mut nodes = unsafe { utils::uninit_vector::(2 * n) }; - nodes[0] = H::Digest::default(); + let mut nodes = utils::uninit_vector::(2 * n); + nodes[0] = MaybeUninit::new(H::Digest::default()); // re-interpret leaves as an array of two leaves fused together and use it to // build first row of internal nodes (parents of leaves) @@ -36,7 +38,7 @@ pub fn build_merkle_nodes(leaves: &[H::Digest]) -> Vec { nodes[n..] .par_iter_mut() .zip(two_leaves.par_iter()) - .for_each(|(target, source)| *target = H::merge(source)); + .for_each(|(target, source)| *target = MaybeUninit::new(H::merge(source))); // calculate all other tree nodes, we can't use regular iterators here because // access patterns are rather complicated - so, we use regular threads instead @@ -45,19 +47,21 @@ pub fn build_merkle_nodes(leaves: &[H::Digest]) -> Vec { let num_subtrees = rayon::current_num_threads().next_power_of_two(); let batch_size = n / num_subtrees; - // re-interpret nodes as an array of two nodes fused together + // re-interpret nodes as an array of two nodes fused together; MaybeUninit has the + // same layout as T, so the pointer cast is valid let two_nodes = unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [H::Digest; 2], n) }; // process each subtree in a separate thread rayon::scope(|s| { for i in 0..num_subtrees { - let nodes = unsafe { &mut *(&mut nodes[..] as *mut [H::Digest]) }; + let nodes = + unsafe { &mut *(&mut nodes[..] as *mut [MaybeUninit] as *mut [MaybeUninit]) }; s.spawn(move |_| { let mut batch_size = batch_size / 2; let mut start_idx = n / 2 + batch_size * i; while start_idx >= num_subtrees { for k in (start_idx..(start_idx + batch_size)).rev() { - nodes[k] = H::merge(&two_nodes[k]); + nodes[k] = MaybeUninit::new(H::merge(&two_nodes[k])); } start_idx /= 2; batch_size /= 2; @@ -68,10 +72,10 @@ pub fn build_merkle_nodes(leaves: &[H::Digest]) -> Vec { // finish the tip of the tree for i in (1..num_subtrees).rev() { - nodes[i] = H::merge(&two_nodes[i]); + nodes[i] = MaybeUninit::new(H::merge(&two_nodes[i])); } - nodes + unsafe { utils::assume_init_vec(nodes) } } // TESTS diff --git a/crypto/src/merkle/mod.rs b/crypto/src/merkle/mod.rs index 708effdf1..96e66f867 100644 --- a/crypto/src/merkle/mod.rs +++ b/crypto/src/merkle/mod.rs @@ -342,29 +342,32 @@ impl MerkleTree { /// This function is exposed primarily for benchmarking purposes. It is not intended to be used /// directly by the end users of the crate. pub fn build_merkle_nodes(leaves: &[H::Digest]) -> Vec { + use core::mem::MaybeUninit; + let n = leaves.len() / 2; // create un-initialized array to hold all intermediate nodes - let mut nodes = unsafe { utils::uninit_vector::(2 * n) }; - nodes[0] = H::Digest::default(); + let mut nodes = utils::uninit_vector::(2 * n); + nodes[0] = MaybeUninit::new(H::Digest::default()); // re-interpret leaves as an array of two leaves fused together let two_leaves = unsafe { slice::from_raw_parts(leaves.as_ptr() as *const [H::Digest; 2], n) }; // build first row of internal nodes (parents of leaves) for (i, j) in (0..n).zip(n..nodes.len()) { - nodes[j] = H::merge(&two_leaves[i]); + nodes[j] = MaybeUninit::new(H::merge(&two_leaves[i])); } - // re-interpret nodes as an array of two nodes fused together + // re-interpret nodes as an array of two nodes fused together; safe because all elements + // from index n onwards are initialized, and lower indices will be initialized below let two_nodes = unsafe { slice::from_raw_parts(nodes.as_ptr() as *const [H::Digest; 2], n) }; // calculate all other tree nodes for i in (1..n).rev() { - nodes[i] = H::merge(&two_nodes[i]); + nodes[i] = MaybeUninit::new(H::merge(&two_nodes[i])); } - nodes + unsafe { utils::assume_init_vec(nodes) } } fn map_indexes( diff --git a/examples/src/rescue_raps/custom_trace_table.rs b/examples/src/rescue_raps/custom_trace_table.rs index e28722917..661ff885e 100644 --- a/examples/src/rescue_raps/custom_trace_table.rs +++ b/examples/src/rescue_raps/custom_trace_table.rs @@ -3,7 +3,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use core_utils::uninit_vector; +use core_utils::{assume_init_vec, uninit_vector}; use winterfell::{math::StarkField, matrix::ColMatrix, EvaluationFrame, Trace, TraceInfo}; // RAP TRACE TABLE @@ -87,7 +87,8 @@ impl RapTraceTable { meta.len() ); - let columns = unsafe { (0..width).map(|_| uninit_vector(length)).collect() }; + // SAFETY: each column is fully initialized via fill() or update_row() before being read. + let columns = (0..width).map(|_| unsafe { assume_init_vec(uninit_vector(length)) }).collect(); Self { info: TraceInfo::new_multi_segment(width, 3, 3, length, meta), trace: ColMatrix::new(columns), diff --git a/examples/src/rescue_raps/prover.rs b/examples/src/rescue_raps/prover.rs index 7b04f98b9..2f5ce0b72 100644 --- a/examples/src/rescue_raps/prover.rs +++ b/examples/src/rescue_raps/prover.rs @@ -3,7 +3,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use core_utils::uninit_vector; +use core_utils::{assume_init_vec, uninit_vector}; use winterfell::{ crypto::MerkleTree, matrix::ColMatrix, AuxRandElements, CompositionPoly, CompositionPolyTrace, ConstraintCompositionCoefficients, DefaultConstraintCommitment, DefaultConstraintEvaluator, @@ -168,8 +168,9 @@ where let main_trace = trace.main_segment(); let rand_elements = aux_rand_elements.rand_elements(); - let mut current_row = unsafe { uninit_vector(main_trace.num_cols()) }; - let mut next_row = unsafe { uninit_vector(main_trace.num_cols()) }; + // SAFETY: read_row_into fully initializes each row buffer before it is read. + let mut current_row = unsafe { assume_init_vec(uninit_vector(main_trace.num_cols())) }; + let mut next_row = unsafe { assume_init_vec(uninit_vector(main_trace.num_cols())) }; main_trace.read_row_into(0, &mut current_row); let mut aux_columns = vec![vec![E::ZERO; main_trace.num_rows()]; trace.aux_trace_width()]; diff --git a/fri/src/folding/mod.rs b/fri/src/folding/mod.rs index 58385953e..7a2f541ee 100644 --- a/fri/src/folding/mod.rs +++ b/fri/src/folding/mod.rs @@ -16,7 +16,8 @@ use math::{ }; #[cfg(feature = "concurrent")] use utils::iterators::*; -use utils::{iter_mut, uninit_vector}; +use core::mem::MaybeUninit; +use utils::{assume_init_vec, iter_mut, uninit_vector}; // DEGREE-RESPECTING PROJECTION // ================================================================================================ @@ -93,7 +94,7 @@ where let inv_twiddles = get_inv_twiddles::(N); let len_offset = E::inv((N as u32).into()); - let mut result = unsafe { uninit_vector(values.len()) }; + let mut result = uninit_vector(values.len()); iter_mut!(result) .zip(values) .zip(inv_offsets) @@ -111,10 +112,10 @@ where } // evaluate the polynomial at alpha, and save the result - *result = polynom::eval(&poly, alpha) + *result = MaybeUninit::new(polynom::eval(&poly, alpha)) }); - result + unsafe { assume_init_vec(result) } } // POSITION FOLDING diff --git a/fri/src/prover/mod.rs b/fri/src/prover/mod.rs index b661083c7..ef1cf8273 100644 --- a/fri/src/prover/mod.rs +++ b/fri/src/prover/mod.rs @@ -10,8 +10,10 @@ use crypto::{ElementHasher, Hasher, VectorCommitment}; use math::{fft, FieldElement}; #[cfg(feature = "concurrent")] use utils::iterators::*; +use core::mem::MaybeUninit; use utils::{ - flatten_vector_elements, group_slice_elements, iter_mut, transpose_slice, uninit_vector, + assume_init_vec, flatten_vector_elements, group_slice_elements, iter_mut, transpose_slice, + uninit_vector, }; use crate::{ @@ -326,11 +328,11 @@ where H: ElementHasher, V: VectorCommitment, { - let mut hashed_evaluations: Vec = unsafe { uninit_vector(values.len()) }; + let mut hashed_evaluations: Vec> = uninit_vector(values.len()); iter_mut!(hashed_evaluations, 1024).zip(values).for_each(|(e, v)| { let digest: H::Digest = H::hash_elements(v); - *e = digest + *e = MaybeUninit::new(digest) }); - V::new(hashed_evaluations) + V::new(unsafe { assume_init_vec(hashed_evaluations) }) } diff --git a/math/src/fft/concurrent.rs b/math/src/fft/concurrent.rs index 1b3771cb8..7febb7816 100644 --- a/math/src/fft/concurrent.rs +++ b/math/src/fft/concurrent.rs @@ -5,7 +5,8 @@ use alloc::vec::Vec; -use utils::{iterators::*, rayon, uninit_vector}; +use core::mem::MaybeUninit; +use utils::{assume_init_vec, iterators::*, rayon, uninit_vector}; use super::fft_inputs::FftInputs; use crate::field::{FieldElement, StarkField}; @@ -31,7 +32,7 @@ pub fn evaluate_poly_with_offset>( ) -> Vec { let domain_size = p.len() * blowup_factor; let g = B::get_root_of_unity(domain_size.ilog2()); - let mut result = unsafe { uninit_vector(domain_size) }; + let mut result = uninit_vector(domain_size); result .as_mut_slice() @@ -40,10 +41,15 @@ pub fn evaluate_poly_with_offset>( .for_each(|(i, chunk)| { let idx = super::permute_index(blowup_factor, i) as u64; let offset = g.exp(idx.into()) * domain_offset; + // SAFETY: MaybeUninit has the same layout as E; we fully initialize + // the chunk via clone_and_shift before reading via split_radix_fft. + let chunk = unsafe { &mut *(chunk as *mut [MaybeUninit] as *mut [E]) }; clone_and_shift(p, chunk, offset); split_radix_fft(chunk, twiddles); }); + let mut result = unsafe { assume_init_vec(result) }; + permute(&mut result); result } diff --git a/math/src/fft/serial.rs b/math/src/fft/serial.rs index 30647332f..760335412 100644 --- a/math/src/fft/serial.rs +++ b/math/src/fft/serial.rs @@ -5,7 +5,8 @@ use alloc::vec::Vec; -use utils::uninit_vector; +use core::mem::MaybeUninit; +use utils::{assume_init_vec, uninit_vector}; use super::fft_inputs::FftInputs; use crate::{field::StarkField, FieldElement}; @@ -38,16 +39,21 @@ where { let domain_size = p.len() * blowup_factor; let g = B::get_root_of_unity(domain_size.ilog2()); - let mut result = unsafe { uninit_vector(domain_size) }; + let mut result = uninit_vector(domain_size); result.as_mut_slice().chunks_mut(p.len()).enumerate().for_each(|(i, chunk)| { let idx = super::permute_index(blowup_factor, i) as u64; let offset = g.exp(idx.into()) * domain_offset; let mut factor = E::BaseField::ONE; for (d, c) in chunk.iter_mut().zip(p.iter()) { - *d = (*c).mul_base(factor); + *d = MaybeUninit::new((*c).mul_base(factor)); factor *= offset; } + }); + + let mut result = unsafe { assume_init_vec(result) }; + + result.as_mut_slice().chunks_mut(p.len()).for_each(|chunk| { chunk.fft_in_place(twiddles); }); diff --git a/math/src/polynom/mod.rs b/math/src/polynom/mod.rs index 8967a12ec..8b12eebb2 100644 --- a/math/src/polynom/mod.rs +++ b/math/src/polynom/mod.rs @@ -662,9 +662,11 @@ where /// assert_eq!(expected_poly, poly); /// ``` pub fn poly_from_roots(xs: &[E]) -> Vec { - let mut result = unsafe { utils::uninit_vector(xs.len() + 1) }; - fill_zero_roots(xs, &mut result); - result + let mut result = utils::uninit_vector(xs.len() + 1); + // fill_zero_roots writes all elements of result + let result_slice = unsafe { &mut *(result.as_mut_slice() as *mut [core::mem::MaybeUninit] as *mut [E]) }; + fill_zero_roots(xs, result_slice); + unsafe { utils::assume_init_vec(result) } } // HELPER FUNCTIONS diff --git a/math/src/utils/mod.rs b/math/src/utils/mod.rs index 6f5cb6673..b4bb7bf28 100644 --- a/math/src/utils/mod.rs +++ b/math/src/utils/mod.rs @@ -5,9 +5,11 @@ use alloc::vec::Vec; +use core::mem::MaybeUninit; + #[cfg(feature = "concurrent")] use utils::iterators::*; -use utils::{batch_iter_mut, iter_mut, uninit_vector}; +use utils::{assume_init_vec, batch_iter_mut, iter_mut, uninit_vector}; use crate::{field::FieldElement, ExtensionOf}; @@ -37,12 +39,15 @@ pub fn get_power_series(b: E, n: usize) -> Vec where E: FieldElement, { - let mut result = unsafe { uninit_vector(n) }; - batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| { + let mut result = uninit_vector(n); + batch_iter_mut!(&mut result, 1024, |batch: &mut [MaybeUninit], batch_offset: usize| { + // SAFETY: MaybeUninit has the same layout as E; fill_power_series initializes + // every element of the batch. + let batch = unsafe { &mut *(batch as *mut [MaybeUninit] as *mut [E]) }; let start = b.exp((batch_offset as u64).into()); fill_power_series(batch, b, start); }); - result + unsafe { assume_init_vec(result) } } /// Returns a vector containing successive powers of a given base offset by the specified value. @@ -70,12 +75,15 @@ pub fn get_power_series_with_offset(b: E, s: E, n: usize) -> Vec where E: FieldElement, { - let mut result = unsafe { uninit_vector(n) }; - batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| { + let mut result = uninit_vector(n); + batch_iter_mut!(&mut result, 1024, |batch: &mut [MaybeUninit], batch_offset: usize| { + // SAFETY: MaybeUninit has the same layout as E; fill_power_series initializes + // every element of the batch. + let batch = unsafe { &mut *(batch as *mut [MaybeUninit] as *mut [E]) }; let start = s * b.exp((batch_offset as u64).into()); fill_power_series(batch, b, start); }); - result + unsafe { assume_init_vec(result) } } /// Computes element-wise sum of the provided vectors, and stores the result in the first vector. @@ -170,13 +178,16 @@ pub fn batch_inversion(values: &[E]) -> Vec where E: FieldElement, { - let mut result: Vec = unsafe { uninit_vector(values.len()) }; - batch_iter_mut!(&mut result, 1024, |batch: &mut [E], batch_offset: usize| { + let mut result: Vec> = uninit_vector(values.len()); + batch_iter_mut!(&mut result, 1024, |batch: &mut [MaybeUninit], batch_offset: usize| { + // SAFETY: MaybeUninit has the same layout as E; serial_batch_inversion + // initializes every element of the batch. + let batch = unsafe { &mut *(batch as *mut [MaybeUninit] as *mut [E]) }; let start = batch_offset; let end = start + batch.len(); serial_batch_inversion(&values[start..end], batch); }); - result + unsafe { assume_init_vec(result) } } // HELPER FUNCTIONS diff --git a/prover/src/constraints/evaluation_table.rs b/prover/src/constraints/evaluation_table.rs index 524a9b3c4..c9217ab61 100644 --- a/prover/src/constraints/evaluation_table.rs +++ b/prover/src/constraints/evaluation_table.rs @@ -12,7 +12,8 @@ use math::fft; use math::{batch_inversion, FieldElement, StarkField}; #[cfg(feature = "concurrent")] use utils::iterators::*; -use utils::{batch_iter_mut, iter_mut, uninit_vector}; +use core::mem::MaybeUninit; +use utils::{assume_init_vec, batch_iter_mut, iter_mut, uninit_vector}; use super::{ConstraintDivisor, StarkDomain}; @@ -288,7 +289,9 @@ impl EvaluationTableFragment<'_, E> { /// Allocates memory for a two-dimensional data structure without initializing it. fn uninit_matrix(num_cols: usize, num_rows: usize) -> Vec> { - unsafe { (0..num_cols).map(|_| uninit_vector(num_rows)).collect() } + // SAFETY: each column is fully initialized before being read; the callers write to + // every element via constraint evaluation before accessing the data. + (0..num_cols).map(|_| unsafe { assume_init_vec(uninit_vector(num_rows)) }).collect() } /// Breaks the source data into a mutable set of fragments such that each fragment has the same @@ -390,17 +393,18 @@ fn get_inv_evaluation( let domain_offset_exp = domain.offset().exp(a.into()); // compute x^a - b for all x - let mut evaluations = unsafe { uninit_vector(n) }; + let mut evaluations = uninit_vector(n); batch_iter_mut!( &mut evaluations, 128, // min batch size - |batch: &mut [B], batch_offset: usize| { + |batch: &mut [MaybeUninit], batch_offset: usize| { for (i, evaluation) in batch.iter_mut().enumerate() { let x = domain.get_ce_x_power_at(batch_offset + i, a, domain_offset_exp); - *evaluation = x - b; + *evaluation = MaybeUninit::new(x - b); } } ); + let evaluations = unsafe { assume_init_vec(evaluations) }; // compute 1 / (x^a - b) batch_inversion(&evaluations) diff --git a/prover/src/constraints/evaluator/periodic_table.rs b/prover/src/constraints/evaluator/periodic_table.rs index ec72aa766..9d91159f5 100644 --- a/prover/src/constraints/evaluator/periodic_table.rs +++ b/prover/src/constraints/evaluator/periodic_table.rs @@ -7,7 +7,8 @@ use alloc::{collections::BTreeMap, vec::Vec}; use air::Air; use math::{fft, StarkField}; -use utils::uninit_vector; +use core::mem::MaybeUninit; +use utils::{assume_init_vec, uninit_vector}; pub struct PeriodicValueTable { values: Vec, @@ -54,12 +55,13 @@ impl PeriodicValueTable { // table in such a way that values for the same row are adjacent to each other. let row_width = polys.len(); let column_length = max_poly_size * air.ce_blowup_factor(); - let mut values = unsafe { uninit_vector(row_width * column_length) }; + let mut values = uninit_vector(row_width * column_length); for i in 0..column_length { for (j, column) in evaluations.iter().enumerate() { - values[i * row_width + j] = column[i % column.len()]; + values[i * row_width + j] = MaybeUninit::new(column[i % column.len()]); } } + let values = unsafe { assume_init_vec(values) }; PeriodicValueTable { values, diff --git a/prover/src/matrix/col_matrix.rs b/prover/src/matrix/col_matrix.rs index 6b15f8088..7e8744c65 100644 --- a/prover/src/matrix/col_matrix.rs +++ b/prover/src/matrix/col_matrix.rs @@ -10,7 +10,8 @@ use crypto::{ElementHasher, VectorCommitment}; use math::{fft, polynom, FieldElement}; #[cfg(feature = "concurrent")] use utils::iterators::*; -use utils::{batch_iter_mut, iter, iter_mut, uninit_vector}; +use core::mem::MaybeUninit; +use utils::{assume_init_vec, batch_iter_mut, iter, iter_mut, uninit_vector}; use crate::StarkDomain; @@ -265,7 +266,7 @@ impl ColMatrix { V: VectorCommitment, { // allocate vector to store row hashes - let mut row_hashes = unsafe { uninit_vector::(self.num_rows()) }; + let mut row_hashes = uninit_vector::(self.num_rows()); // iterate though matrix rows, hashing each row; the hashing is done by first copying a // row into row_buf to avoid heap allocations, and then by applying the hash function to @@ -273,15 +274,16 @@ impl ColMatrix { batch_iter_mut!( &mut row_hashes, 128, // min batch size - |batch: &mut [H::Digest], batch_offset: usize| { + |batch: &mut [MaybeUninit], batch_offset: usize| { let mut row_buf = vec![E::ZERO; self.num_cols()]; for (i, row_hash) in batch.iter_mut().enumerate() { self.read_row_into(i + batch_offset, &mut row_buf); - *row_hash = H::hash_elements(&row_buf); + *row_hash = MaybeUninit::new(H::hash_elements(&row_buf)); } } ); + let row_hashes = unsafe { assume_init_vec(row_hashes) }; V::new(row_hashes).expect("failed to construct trace vector commitment") } diff --git a/prover/src/matrix/row_matrix.rs b/prover/src/matrix/row_matrix.rs index 6a7a18233..7573e2112 100644 --- a/prover/src/matrix/row_matrix.rs +++ b/prover/src/matrix/row_matrix.rs @@ -10,7 +10,8 @@ use crypto::{ElementHasher, VectorCommitment}; use math::{fft, FieldElement, StarkField}; #[cfg(feature = "concurrent")] use utils::iterators::*; -use utils::{batch_iter_mut, flatten_vector_elements, uninit_vector}; +use core::mem::MaybeUninit; +use utils::{assume_init_vec, batch_iter_mut, flatten_vector_elements, uninit_vector}; use super::{ColMatrix, Segment}; use crate::StarkDomain; @@ -187,7 +188,7 @@ impl RowMatrix { V: VectorCommitment, { // allocate vector to store row hashes - let mut row_hashes = unsafe { uninit_vector::(self.num_rows()) }; + let mut row_hashes = uninit_vector::(self.num_rows()); let partition_size = partition_options.partition_size::(self.num_cols()); if partition_size == self.num_cols() { @@ -195,9 +196,9 @@ impl RowMatrix { batch_iter_mut!( &mut row_hashes, 128, // min batch size - |batch: &mut [H::Digest], batch_offset: usize| { + |batch: &mut [MaybeUninit], batch_offset: usize| { for (i, row_hash) in batch.iter_mut().enumerate() { - *row_hash = H::hash_elements(self.row(batch_offset + i)); + *row_hash = MaybeUninit::new(H::hash_elements(self.row(batch_offset + i))); } } ); @@ -208,7 +209,7 @@ impl RowMatrix { batch_iter_mut!( &mut row_hashes, 128, // min batch size - |batch: &mut [H::Digest], batch_offset: usize| { + |batch: &mut [MaybeUninit], batch_offset: usize| { let mut buffer = vec![H::Digest::default(); num_partitions]; for (i, row_hash) in batch.iter_mut().enumerate() { self.row(batch_offset + i) @@ -217,13 +218,14 @@ impl RowMatrix { .for_each(|(chunk, buf)| { *buf = H::hash_elements(chunk); }); - *row_hash = H::merge_many(&buffer); + *row_hash = MaybeUninit::new(H::merge_many(&buffer)); } } ); } // build the vector commitment to the hashed rows + let row_hashes = unsafe { assume_init_vec(row_hashes) }; V::new(row_hashes).expect("failed to construct trace vector commitment") } } @@ -244,17 +246,17 @@ pub fn get_evaluation_offsets( let g = E::BaseField::get_root_of_unity(domain_size.ilog2()); // allocate memory to hold the offsets - let mut offsets = unsafe { uninit_vector(domain_size) }; + let mut offsets = uninit_vector(domain_size); // define a closure to compute offsets for a given chunk of the result; the number of chunks // is defined by the blowup factor. for example, for blowup factor = 2, the number of chunks // will be 2, for blowup factor = 8, the number of chunks will be 8 etc. - let compute_offsets = |(chunk_idx, chunk): (usize, &mut [E::BaseField])| { + let compute_offsets = |(chunk_idx, chunk): (usize, &mut [MaybeUninit])| { let idx = fft::permute_index(blowup_factor, chunk_idx) as u64; let offset = g.exp_vartime(idx.into()) * domain_offset; let mut factor = E::BaseField::ONE; for res in chunk.iter_mut() { - *res = factor; + *res = MaybeUninit::new(factor); factor *= offset; } }; @@ -267,7 +269,7 @@ pub fn get_evaluation_offsets( #[cfg(feature = "concurrent")] offsets.par_chunks_mut(poly_size).enumerate().for_each(compute_offsets); - offsets + unsafe { assume_init_vec(offsets) } } /// Returns matrix segments constructed by evaluating polynomials in the specified matrix over the @@ -308,7 +310,7 @@ fn transpose(mut segments: Vec>) -> // allocate memory to hold the transposed result; // TODO: investigate transposing in-place - let mut result = unsafe { uninit_vector::<[B; N]>(result_len) }; + let mut result = uninit_vector::<[B; N]>(result_len); // determine number of batches in which transposition will be preformed; if `concurrent` // feature is not enabled, the number of batches will always be 1 @@ -316,16 +318,21 @@ fn transpose(mut segments: Vec>) -> let rows_per_batch = num_rows / num_batches; // define a closure for transposing a given batch - let transpose_batch = |(batch_idx, batch): (usize, &mut [[B; N]])| { - let row_offset = batch_idx * rows_per_batch; - for i in 0..rows_per_batch { - let row_idx = i + row_offset; - for j in 0..num_segs { - let v = &segments[j][row_idx]; - batch[i * num_segs + j].copy_from_slice(v); + let transpose_batch = + |(batch_idx, batch): (usize, &mut [MaybeUninit<[B; N]>])| { + // SAFETY: MaybeUninit<[B; N]> has the same layout as [B; N]; every element + // of the batch is fully written via copy_from_slice. + let batch = + unsafe { &mut *(batch as *mut [MaybeUninit<[B; N]>] as *mut [[B; N]]) }; + let row_offset = batch_idx * rows_per_batch; + for i in 0..rows_per_batch { + let row_idx = i + row_offset; + for j in 0..num_segs { + let v = &segments[j][row_idx]; + batch[i * num_segs + j].copy_from_slice(v); + } } - } - }; + }; // call the closure either once (for single-threaded transposition) or in a parallel // iterator (for multi-threaded transposition) @@ -339,7 +346,7 @@ fn transpose(mut segments: Vec>) -> .enumerate() .for_each(transpose_batch); - result + unsafe { assume_init_vec(result) } } #[cfg(not(feature = "concurrent"))] diff --git a/prover/src/matrix/segments.rs b/prover/src/matrix/segments.rs index df7b04e3e..70c65053e 100644 --- a/prover/src/matrix/segments.rs +++ b/prover/src/matrix/segments.rs @@ -9,7 +9,7 @@ use core::ops::Deref; use math::{fft::fft_inputs::FftInputs, FieldElement, StarkField}; #[cfg(feature = "concurrent")] use utils::iterators::*; -use utils::uninit_vector; +use utils::{assume_init_vec, uninit_vector}; use super::ColMatrix; @@ -65,8 +65,9 @@ impl Segment { // allocate memory for the segment let data = if polys.num_base_cols() - poly_offset >= N { - // if we will fill the entire segment, we allocate uninitialized memory - unsafe { uninit_vector::<[B; N]>(domain_size) } + // if we will fill the entire segment, we allocate uninitialized memory; + // SAFETY: new_with_buffer fully initializes all elements before reading. + unsafe { assume_init_vec(uninit_vector::<[B; N]>(domain_size)) } } else { // but if some columns in the segment will remain unfilled, we allocate memory // initialized to zeros to make sure we don't end up with memory with diff --git a/prover/src/trace/trace_table.rs b/prover/src/trace/trace_table.rs index caa9da8b7..93665a571 100644 --- a/prover/src/trace/trace_table.rs +++ b/prover/src/trace/trace_table.rs @@ -7,7 +7,7 @@ use alloc::vec::Vec; use air::{EvaluationFrame, TraceInfo}; use math::StarkField; -use utils::uninit_vector; +use utils::{assume_init_vec, uninit_vector}; #[cfg(feature = "concurrent")] use utils::{iterators::*, rayon}; @@ -103,7 +103,8 @@ impl TraceTable { length.ilog2() ); - let columns = unsafe { (0..width).map(|_| uninit_vector(length)).collect() }; + // SAFETY: each column is fully initialized via fill() or set() before being read. + let columns = (0..width).map(|_| unsafe { assume_init_vec(uninit_vector(length)) }).collect(); Self { info, trace: ColMatrix::new(columns) } } diff --git a/utils/core/src/lib.rs b/utils/core/src/lib.rs index 947abdf18..a1379d6ab 100644 --- a/utils/core/src/lib.rs +++ b/utils/core/src/lib.rs @@ -16,7 +16,7 @@ extern crate std; pub mod iterators; use alloc::vec::Vec; -use core::{mem, slice}; +use core::{mem, mem::MaybeUninit, slice}; mod serde; #[cfg(feature = "std")] @@ -76,12 +76,25 @@ impl AsBytes for [[u8; N]] { /// overwrite all contents of the vector immediately after memory allocation. /// /// # Safety -/// Using values from the returned vector before initializing them will lead to undefined behavior. -#[allow(clippy::uninit_vec)] -pub unsafe fn uninit_vector(length: usize) -> Vec { - let mut vector = Vec::with_capacity(length); - vector.set_len(length); - vector +/// All elements must be initialized before reading from them or converting via +/// [`assume_init_vec`]. +pub fn uninit_vector(length: usize) -> Vec> { + let mut result = Vec::with_capacity(length); + // SAFETY: MaybeUninit does not require initialization; we are only extending the + // length to match the capacity, which is valid for MaybeUninit elements. + unsafe { result.set_len(length) }; + result +} + +/// Converts a vector of fully-initialized `MaybeUninit` values into a `Vec`. +/// +/// This is a zero-cost conversion (same memory layout). +/// +/// # Safety +/// Every element in `v` must have been initialized before calling this function. +pub unsafe fn assume_init_vec(v: Vec>) -> Vec { + let mut v = mem::ManuallyDrop::new(v); + Vec::from_raw_parts(v.as_mut_ptr().cast::(), v.len(), v.capacity()) } // GROUPING / UN-GROUPING FUNCTIONS @@ -173,13 +186,16 @@ pub fn transpose_slice(source: &[T]) -> V source.len() ); - let mut result: Vec<[T; N]> = unsafe { uninit_vector(row_count) }; + let mut result = uninit_vector::<[T; N]>(row_count); iter_mut!(result, 1024).enumerate().for_each(|(i, element)| { + let mut arr = MaybeUninit::<[T; N]>::uninit(); + let ptr = arr.as_mut_ptr() as *mut T; for j in 0..N { - element[j] = source[i + j * row_count] + unsafe { ptr.add(j).write(source[i + j * row_count]) }; } + *element = arr; }); - result + unsafe { assume_init_vec(result) } } // RANDOMNESS