From c6aa875c576161f1dd22554ab8ad5f276e72fa08 Mon Sep 17 00:00:00 2001 From: Markos Georghiades <53157953+Markos-The-G@users.noreply.github.com> Date: Tue, 12 May 2026 22:40:15 -0400 Subject: [PATCH] feat(jolt-witness): add modular witness crate --- Cargo.lock | 8 + Cargo.toml | 3 + crates/jolt-poly/src/dense.rs | 56 ++ crates/jolt-poly/src/expanding_table.rs | 212 +++++ crates/jolt-poly/src/lib.rs | 6 +- crates/jolt-poly/src/split_eq.rs | 446 +++++++++ crates/jolt-r1cs/src/lib.rs | 2 + crates/jolt-r1cs/src/row_dots.rs | 118 +++ crates/jolt-transcript/src/lib.rs | 2 + crates/jolt-transcript/src/mock.rs | 109 +++ crates/jolt-witness/Cargo.toml | 13 + crates/jolt-witness/src/lib.rs | 1119 +++++++++++++++++++++++ 12 files changed, 2093 insertions(+), 1 deletion(-) create mode 100644 crates/jolt-poly/src/expanding_table.rs create mode 100644 crates/jolt-poly/src/split_eq.rs create mode 100644 crates/jolt-r1cs/src/row_dots.rs create mode 100644 crates/jolt-transcript/src/mock.rs create mode 100644 crates/jolt-witness/Cargo.toml create mode 100644 crates/jolt-witness/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 191249e63a..91d1f3d200 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3153,6 +3153,14 @@ dependencies = [ "sha3 0.11.0", ] +[[package]] +name = "jolt-witness" +version = "0.1.0" +dependencies = [ + "jolt-field", + "jolt-poly", +] + [[package]] name = "js-sys" version = "0.3.91" diff --git a/Cargo.toml b/Cargo.toml index 56f9bd41ea..dfb4789692 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ members = [ "crates/jolt-hyperkzg", "crates/jolt-riscv", "crates/jolt-transcript", + "crates/jolt-witness", "crates/jolt-profiling", "crates/jolt-field", "jolt-core", @@ -382,6 +383,8 @@ jolt-openings = { path = "./crates/jolt-openings" } jolt-poly = { path = "./crates/jolt-poly" } jolt-transcript = { path = "./crates/jolt-transcript" } jolt-sumcheck = { path = "./crates/jolt-sumcheck" } +jolt-r1cs = { path = "./crates/jolt-r1cs" } +jolt-witness = { path = "./crates/jolt-witness" } jolt-riscv = { path = "./crates/jolt-riscv", default-features = false } jolt-program = { path = "./crates/jolt-program", default-features = false } jolt-lookup-tables = { path = "./crates/jolt-lookup-tables" } diff --git a/crates/jolt-poly/src/dense.rs b/crates/jolt-poly/src/dense.rs index 46293f2bbb..749f65dee8 100644 --- a/crates/jolt-poly/src/dense.rs +++ b/crates/jolt-poly/src/dense.rs @@ -360,6 +360,62 @@ impl crate::MultilinearBinding for Polynomial { } } +/// Fixes the first (MSB) variable in a dense evaluation vector. +#[inline] +pub fn bind_high_to_low(evals: &mut Vec, scalar: F) { + let half = evals.len() / 2; + + #[cfg(feature = "parallel")] + { + if half >= PAR_THRESHOLD { + use rayon::prelude::*; + let (lo, hi) = evals.split_at_mut(half); + lo.par_iter_mut().zip(hi.par_iter()).for_each(|(a, b)| { + *a = *a + scalar * (*b - *a); + }); + evals.truncate(half); + return; + } + } + + for i in 0..half { + let lo = evals[i]; + let hi = evals[i + half]; + evals[i] = lo + scalar * (hi - lo); + } + evals.truncate(half); +} + +/// Fixes the last (LSB) variable in a dense evaluation vector. +#[inline] +pub fn bind_low_to_high(evals: &mut Vec, scalar: F) { + let half = evals.len() / 2; + + #[cfg(feature = "parallel")] + { + if half >= PAR_THRESHOLD { + use rayon::prelude::*; + let coeffs = &*evals; + *evals = (0..half) + .into_par_iter() + .map(|i| { + let lo = coeffs[2 * i]; + let hi = coeffs[2 * i + 1]; + lo + scalar * (hi - lo) + }) + .collect(); + return; + } + } + + for i in 0..half { + let lo = evals[2 * i]; + let hi = evals[2 * i + 1]; + evals[i] = lo + scalar * (hi - lo); + } + evals.truncate(half); +} + #[inline] fn assert_matching_dims(a: &Polynomial, b: &Polynomial) -> (usize, usize) { assert_eq!( diff --git a/crates/jolt-poly/src/expanding_table.rs b/crates/jolt-poly/src/expanding_table.rs new file mode 100644 index 0000000000..1f48734d4f --- /dev/null +++ b/crates/jolt-poly/src/expanding_table.rs @@ -0,0 +1,212 @@ +//! Incrementally materialized equality tables. + +use std::ops::Index; + +use jolt_field::Field; + +use crate::{thread::unsafe_allocate_zero_vec, BindingOrder}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// Table containing the evaluations of `eq(x, r)` as challenges are streamed in. +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct ExpandingTable { + binding_order: BindingOrder, + len: usize, + values: Vec, + scratch_space: Vec, +} + +impl ExpandingTable { + #[tracing::instrument(skip_all, name = "ExpandingTable::new")] + pub fn new(capacity: usize, binding_order: BindingOrder) -> Self { + assert!(capacity > 0, "expanding table capacity must be positive"); + let (values, scratch_space) = join_or_serial( + || unsafe_allocate_zero_vec(capacity), + || match binding_order { + BindingOrder::LowToHigh => Vec::new(), + BindingOrder::HighToLow => unsafe_allocate_zero_vec(capacity), + }, + ); + Self { + binding_order, + len: 0, + values, + scratch_space, + } + } + + #[inline] + pub fn len(&self) -> usize { + self.len + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + #[inline] + pub fn order(&self) -> BindingOrder { + self.binding_order + } + + #[inline] + pub fn values(&self) -> &[F] { + &self.values[..self.len] + } + + pub fn reset(&mut self, value: F) { + assert!(!self.values.is_empty(), "expanding table has zero capacity"); + self.values[0] = value; + self.len = 1; + } + + pub fn clone_values(&self) -> Vec { + self.values().to_vec() + } + + #[tracing::instrument(skip_all, name = "ExpandingTable::update")] + pub fn update(&mut self, challenge: F) { + assert!(self.len > 0, "expanding table must be reset before update"); + assert!( + self.len * 2 <= self.values.len(), + "expanding table capacity exceeded" + ); + match self.binding_order { + BindingOrder::LowToHigh => self.update_low_to_high(challenge), + BindingOrder::HighToLow => self.update_high_to_low(challenge), + } + self.len *= 2; + } + + fn update_low_to_high(&mut self, challenge: F) { + #[cfg(feature = "parallel")] + { + let (left, right) = self.values.split_at_mut(self.len); + left.par_iter_mut() + .zip(right.par_iter_mut()) + .for_each(|(left, right)| { + *right = *left * challenge; + *left -= *right; + }); + } + + #[cfg(not(feature = "parallel"))] + { + let (left, right) = self.values.split_at_mut(self.len); + for (left, right) in left.iter_mut().zip(right.iter_mut()) { + *right = *left * challenge; + *left -= *right; + } + } + } + + fn update_high_to_low(&mut self, challenge: F) { + #[cfg(feature = "parallel")] + { + self.values[..self.len] + .par_iter() + .zip(self.scratch_space.par_chunks_mut(2)) + .for_each(|(&value, dest)| { + let eval_1 = value * challenge; + dest[0] = value - eval_1; + dest[1] = eval_1; + }); + std::mem::swap(&mut self.values, &mut self.scratch_space); + } + + #[cfg(not(feature = "parallel"))] + { + for (index, &value) in self.values[..self.len].iter().enumerate() { + let eval_1 = value * challenge; + self.scratch_space[2 * index] = value - eval_1; + self.scratch_space[2 * index + 1] = eval_1; + } + std::mem::swap(&mut self.values, &mut self.scratch_space); + } + } +} + +impl Index for ExpandingTable { + type Output = F; + + fn index(&self, index: usize) -> &Self::Output { + assert!( + index < self.len, + "expanding table index {index} out of bounds for len {}", + self.len + ); + &self.values[index] + } +} + +#[cfg(feature = "parallel")] +fn join_or_serial( + left: impl FnOnce() -> A + Send, + right: impl FnOnce() -> B + Send, +) -> (A, B) { + rayon::join(left, right) +} + +#[cfg(not(feature = "parallel"))] +fn join_or_serial(left: impl FnOnce() -> A, right: impl FnOnce() -> B) -> (A, B) { + (left(), right()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::EqPolynomial; + use jolt_field::{Fr, FromPrimitiveInt, RandomSampling}; + use num_traits::One; + use rand_chacha::ChaCha20Rng; + use rand_core::SeedableRng; + + #[test] + fn high_to_low_matches_eq_table_prefixes() { + let mut rng = ChaCha20Rng::seed_from_u64(710); + let point: Vec = (0..8).map(|_| Fr::random(&mut rng)).collect(); + let mut table = ExpandingTable::new(1 << point.len(), BindingOrder::HighToLow); + table.reset(Fr::one()); + + for prefix_len in 0..=point.len() { + let expected = EqPolynomial::::evals(&point[..prefix_len], None); + assert_eq!(table.values(), expected); + if prefix_len < point.len() { + table.update(point[prefix_len]); + } + } + } + + #[test] + fn low_to_high_matches_reversed_eq_prefixes() { + let mut rng = ChaCha20Rng::seed_from_u64(711); + let point: Vec = (0..8).map(|_| Fr::random(&mut rng)).collect(); + let mut reversed_prefix = Vec::new(); + let mut table = ExpandingTable::new(1 << point.len(), BindingOrder::LowToHigh); + table.reset(Fr::one()); + + for (prefix_len, &challenge) in point.iter().enumerate() { + let expected = EqPolynomial::::evals(&reversed_prefix, None); + assert_eq!(table.values(), expected, "prefix_len={prefix_len}"); + reversed_prefix.insert(0, challenge); + table.update(challenge); + } + let expected = EqPolynomial::::evals(&reversed_prefix, None); + assert_eq!(table.values(), expected); + } + + #[test] + fn clone_and_index_expose_active_prefix_only() { + let mut table = ExpandingTable::new(8, BindingOrder::HighToLow); + table.reset(Fr::from_u64(3)); + table.update(Fr::from_u64(5)); + + assert_eq!(table.len(), 2); + assert_eq!(table[0], Fr::from_u64(3) - Fr::from_u64(15)); + assert_eq!(table[1], Fr::from_u64(15)); + assert_eq!(table.clone_values(), table.values()); + } +} diff --git a/crates/jolt-poly/src/lib.rs b/crates/jolt-poly/src/lib.rs index 88c2d27055..a67085c12e 100644 --- a/crates/jolt-poly/src/lib.rs +++ b/crates/jolt-poly/src/lib.rs @@ -48,22 +48,26 @@ mod compressed_univariate; mod dense; mod eq; mod eq_plus_one; +mod expanding_table; mod identity; pub mod lagrange; mod lt; pub mod math; mod multilinear; mod one_hot; +mod split_eq; pub mod thread; mod univariate; pub use binding::BindingOrder; pub use compressed_univariate::CompressedPoly; -pub use dense::Polynomial; +pub use dense::{bind_high_to_low, bind_low_to_high, Polynomial}; pub use eq::EqPolynomial; pub use eq_plus_one::{EqPlusOnePolynomial, EqPlusOnePrefixSuffix}; +pub use expanding_table::ExpandingTable; pub use identity::IdentityPolynomial; pub use lt::LtPolynomial; pub use multilinear::{MultilinearBinding, MultilinearEvaluation, MultilinearPoly, RlcSource}; pub use one_hot::OneHotPolynomial; +pub use split_eq::GruenSplitEqPolynomial; pub use univariate::{UnivariatePoly, UnivariatePolynomial}; diff --git a/crates/jolt-poly/src/split_eq.rs b/crates/jolt-poly/src/split_eq.rs new file mode 100644 index 0000000000..7b256a7ee1 --- /dev/null +++ b/crates/jolt-poly/src/split_eq.rs @@ -0,0 +1,446 @@ +//! Split equality polynomial used by sumcheck provers. +//! +//! This implements the Dao-Thaler/Gruen factorization used by Jolt's larger +//! sumchecks. It stores prefix tables for two halves of the remaining equality +//! polynomial and tracks already-bound variables in a scalar. + +use jolt_field::Field; + +use crate::{BindingOrder, EqPolynomial, Polynomial, UnivariatePoly}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GruenSplitEqPolynomial { + current_index: usize, + current_scalar: F, + w: Vec, + e_in_vec: Vec>, + e_out_vec: Vec>, + binding_order: BindingOrder, +} + +impl GruenSplitEqPolynomial { + #[tracing::instrument(skip_all, name = "GruenSplitEqPolynomial::new_with_scaling")] + pub fn new_with_scaling( + w: &[F], + binding_order: BindingOrder, + scaling_factor: Option, + ) -> Self { + assert!(!w.is_empty(), "split eq requires at least one variable"); + match binding_order { + BindingOrder::LowToHigh => { + let split = w.len() / 2; + let w_prime = &w[..w.len() - 1]; + let (w_out, w_in) = w_prime.split_at(split); + let (e_out_vec, e_in_vec) = join_or_serial( + || EqPolynomial::evals_cached(w_out, None), + || EqPolynomial::evals_cached(w_in, None), + ); + Self { + current_index: w.len(), + current_scalar: scaling_factor.unwrap_or_else(F::one), + w: w.to_vec(), + e_in_vec, + e_out_vec, + binding_order, + } + } + BindingOrder::HighToLow => { + let w_prime = &w[1..]; + let split = w.len() / 2; + let (w_in, w_out) = w_prime.split_at(split); + let (e_in_vec, e_out_vec) = join_or_serial( + || EqPolynomial::evals_cached_rev(w_in, None), + || EqPolynomial::evals_cached_rev(w_out, None), + ); + Self { + current_index: 0, + current_scalar: scaling_factor.unwrap_or_else(F::one), + w: w.to_vec(), + e_in_vec, + e_out_vec, + binding_order, + } + } + } + } + + #[tracing::instrument(skip_all, name = "GruenSplitEqPolynomial::new")] + pub fn new(w: &[F], binding_order: BindingOrder) -> Self { + Self::new_with_scaling(w, binding_order, None) + } + + #[inline] + pub fn num_vars(&self) -> usize { + self.w.len() + } + + #[inline] + pub fn len(&self) -> usize { + match self.binding_order { + BindingOrder::LowToHigh => 1 << self.current_index, + BindingOrder::HighToLow => 1 << (self.w.len() - self.current_index), + } + } + + #[inline] + pub fn is_empty(&self) -> bool { + false + } + + #[inline] + pub fn num_bound_vars(&self) -> usize { + match self.binding_order { + BindingOrder::LowToHigh => self.w.len() - self.current_index, + BindingOrder::HighToLow => self.current_index, + } + } + + #[inline] + pub fn e_in_current_len(&self) -> usize { + self.e_in_current().len() + } + + #[inline] + pub fn e_out_current_len(&self) -> usize { + self.e_out_current().len() + } + + #[inline] + pub fn e_in_current(&self) -> &[F] { + &self.e_in_vec[self.e_in_vec.len() - 1] + } + + #[inline] + pub fn e_out_current(&self) -> &[F] { + &self.e_out_vec[self.e_out_vec.len() - 1] + } + + pub fn e_out_in_for_window(&self, window_size: usize) -> (&[F], &[F]) { + assert_eq!( + self.binding_order, + BindingOrder::LowToHigh, + "streaming windows are not defined for high-to-low split eq" + ); + let num_unbound = self.current_index; + let window_size = window_size.min(num_unbound); + let head_len = num_unbound.saturating_sub(window_size); + let split = self.w.len() / 2; + let head_out_bits = head_len.min(split); + let head_in_bits = head_len.saturating_sub(head_out_bits); + (&self.e_out_vec[head_out_bits], &self.e_in_vec[head_in_bits]) + } + + pub fn e_active_for_window(&self, window_size: usize) -> Vec { + if window_size <= 1 { + return vec![F::one()]; + } + assert_eq!( + self.binding_order, + BindingOrder::LowToHigh, + "streaming windows are not defined for high-to-low split eq" + ); + let num_unbound = self.current_index; + if window_size > num_unbound { + return vec![F::one()]; + } + let remaining_w = &self.w[..num_unbound]; + let window_start = remaining_w.len() - window_size; + let (_head, w_window) = remaining_w.split_at(window_start); + let (w_active, _w_current) = w_window.split_at(window_size - 1); + EqPolynomial::::evals(w_active, None) + } + + #[tracing::instrument(skip_all, name = "GruenSplitEqPolynomial::bind")] + pub fn bind(&mut self, r: F) { + match self.binding_order { + BindingOrder::LowToHigh => { + let w = self.w[self.current_index - 1]; + let prod = w * r; + self.current_scalar *= F::one() - w - r + prod + prod; + self.current_index -= 1; + if self.w.len() / 2 < self.current_index && self.e_in_vec.len() > 1 { + let _ = self.e_in_vec.pop(); + } else if 0 < self.current_index && self.e_out_vec.len() > 1 { + let _ = self.e_out_vec.pop(); + } + } + BindingOrder::HighToLow => { + let w = self.w[self.current_index]; + let prod = w * r; + self.current_scalar *= F::one() - w - r + prod + prod; + self.current_index += 1; + if self.current_index <= self.w.len() / 2 && self.e_in_vec.len() > 1 { + let _ = self.e_in_vec.pop(); + } else if self.current_index <= self.w.len() && self.e_out_vec.len() > 1 { + let _ = self.e_out_vec.pop(); + } + } + } + } + + pub fn gruen_poly_deg_3( + &self, + q_constant: F, + q_quadratic_coeff: F, + s_0_plus_s_1: F, + ) -> UnivariatePoly { + let eq_eval_1 = self.current_scalar * self.current_w(); + let eq_eval_0 = self.current_scalar - eq_eval_1; + let eq_slope = eq_eval_1 - eq_eval_0; + let eq_eval_2 = eq_eval_1 + eq_slope; + let eq_eval_3 = eq_eval_2 + eq_slope; + + let quadratic_eval_0 = q_constant; + let cubic_eval_0 = eq_eval_0 * quadratic_eval_0; + let cubic_eval_1 = s_0_plus_s_1 - cubic_eval_0; + let quadratic_eval_1 = field_div(cubic_eval_1, eq_eval_1); + let e_times_2 = q_quadratic_coeff + q_quadratic_coeff; + let quadratic_eval_2 = quadratic_eval_1 + quadratic_eval_1 - quadratic_eval_0 + e_times_2; + let quadratic_eval_3 = + quadratic_eval_2 + quadratic_eval_1 - quadratic_eval_0 + e_times_2 + e_times_2; + + UnivariatePoly::from_evals(&[ + cubic_eval_0, + cubic_eval_1, + eq_eval_2 * quadratic_eval_2, + eq_eval_3 * quadratic_eval_3, + ]) + } + + pub fn gruen_poly_deg_2(&self, q_0: F, previous_claim: F) -> UnivariatePoly { + let eq_eval_1 = self.current_scalar * self.current_w(); + let eq_eval_0 = self.current_scalar - eq_eval_1; + let eq_slope = eq_eval_1 - eq_eval_0; + let eq_eval_2 = eq_eval_1 + eq_slope; + + let quadratic_eval_0 = eq_eval_0 * q_0; + let quadratic_eval_1 = previous_claim - quadratic_eval_0; + let linear_eval_1 = field_div(quadratic_eval_1, eq_eval_1); + let linear_eval_2 = linear_eval_1 + linear_eval_1 - q_0; + + UnivariatePoly::from_evals(&[ + quadratic_eval_0, + quadratic_eval_1, + eq_eval_2 * linear_eval_2, + ]) + } + + pub fn gruen_poly_from_evals(&self, q_evals: &[F], s_0_plus_s_1: F) -> UnivariatePoly { + assert!(!q_evals.is_empty(), "q_evals must be non-empty"); + let r_round = self.current_w(); + let l_at_0 = self.current_scalar * EqPolynomial::::mle(&[F::zero()], &[r_round]); + let l_at_1 = self.current_scalar * EqPolynomial::::mle(&[F::one()], &[r_round]); + let q_at_0 = field_div(s_0_plus_s_1 - l_at_1 * q_evals[0], l_at_0); + + let mut full_q_evals = Vec::with_capacity(q_evals.len() + 1); + full_q_evals.push(q_at_0); + full_q_evals.extend_from_slice(q_evals); + let q = UnivariatePoly::from_evals_toom(&full_q_evals); + + let l_c0 = l_at_0; + let l_c1 = l_at_1 - l_at_0; + let q_coeffs = q.into_coefficients(); + let mut s_coeffs = vec![F::zero(); q_coeffs.len() + 1]; + for (index, q_coeff) in q_coeffs.into_iter().enumerate() { + s_coeffs[index] += q_coeff * l_c0; + s_coeffs[index + 1] += q_coeff * l_c1; + } + UnivariatePoly::new(s_coeffs) + } + + pub fn merge(&self) -> Polynomial { + let evals = match self.binding_order { + BindingOrder::LowToHigh => { + EqPolynomial::evals(&self.w[..self.current_index], Some(self.current_scalar)) + } + BindingOrder::HighToLow => { + EqPolynomial::evals(&self.w[self.current_index..], Some(self.current_scalar)) + } + }; + Polynomial::new(evals) + } + + #[inline] + pub fn current_scalar(&self) -> F { + self.current_scalar + } + + #[inline] + pub fn current_w(&self) -> F { + match self.binding_order { + BindingOrder::LowToHigh => self.w[self.current_index - 1], + BindingOrder::HighToLow => self.w[self.current_index], + } + } + + #[inline] + pub fn group_index(&self, x_out: usize, x_in: usize) -> usize { + let num_x_in_bits = self.e_in_current_len().trailing_zeros() as usize; + (x_out << num_x_in_bits) | x_in + } + + pub fn fold_out_in< + OuterAcc: Send, + InnerAcc: Send, + MakeInner: Fn() -> InnerAcc + Sync + Send, + InnerStep: Fn(&mut InnerAcc, usize, usize, F) + Sync + Send, + OuterStep: Fn(usize, F, InnerAcc) -> OuterAcc + Sync + Send, + Merge: Fn(OuterAcc, OuterAcc) -> OuterAcc + Sync + Send, + >( + &self, + make_inner: MakeInner, + inner_step: InnerStep, + outer_step: OuterStep, + merge: Merge, + ) -> OuterAcc { + let e_out = self.e_out_current(); + let e_in = self.e_in_current(); + + #[cfg(feature = "parallel")] + { + let result = (0..e_out.len()) + .into_par_iter() + .map(|x_out| { + let mut inner_acc = make_inner(); + for (x_in, &e_in) in e_in.iter().enumerate() { + let group = self.group_index(x_out, x_in); + inner_step(&mut inner_acc, group, x_in, e_in); + } + outer_step(x_out, e_out[x_out], inner_acc) + }) + .reduce_with(merge); + if let Some(result) = result { + result + } else { + assert!(!e_out.is_empty(), "split eq e_out invariant"); + std::process::abort(); + } + } + + #[cfg(not(feature = "parallel"))] + { + let mut iter = (0..e_out.len()).map(|x_out| { + let mut inner_acc = make_inner(); + for (x_in, &e_in) in e_in.iter().enumerate() { + let group = self.group_index(x_out, x_in); + inner_step(&mut inner_acc, group, x_in, e_in); + } + outer_step(x_out, e_out[x_out], inner_acc) + }); + let first = iter.next().expect("split eq e_out invariant"); + iter.fold(first, merge) + } + } +} + +#[inline] +fn field_div(numerator: F, denominator: F) -> F { + let Some(inverse) = denominator.inverse() else { + unreachable!("split equality denominator must be nonzero"); + }; + numerator * inverse +} + +#[cfg(feature = "parallel")] +fn join_or_serial( + left: impl FnOnce() -> A + Send, + right: impl FnOnce() -> B + Send, +) -> (A, B) { + rayon::join(left, right) +} + +#[cfg(not(feature = "parallel"))] +fn join_or_serial(left: impl FnOnce() -> A, right: impl FnOnce() -> B) -> (A, B) { + (left(), right()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::math::Math; + use jolt_field::{Fr, FromPrimitiveInt, RandomSampling}; + use num_traits::{One, Zero}; + use rand_chacha::ChaCha20Rng; + use rand_core::SeedableRng; + + #[test] + fn bind_low_to_high_matches_dense_eq() { + let mut rng = ChaCha20Rng::seed_from_u64(700); + let point: Vec = (0..10).map(|_| Fr::random(&mut rng)).collect(); + let mut dense = Polynomial::new(EqPolynomial::::evals(&point, None)); + let mut split = GruenSplitEqPolynomial::new(&point, BindingOrder::LowToHigh); + assert_eq!(dense, split.merge()); + + for _ in 0..point.len() { + let r = Fr::random(&mut rng); + dense.bind_with_order(r, BindingOrder::LowToHigh); + split.bind(r); + assert_eq!(dense, split.merge()); + } + } + + #[test] + fn bind_high_to_low_matches_dense_eq() { + let mut rng = ChaCha20Rng::seed_from_u64(701); + let point: Vec = (0..10).map(|_| Fr::random(&mut rng)).collect(); + let mut dense = Polynomial::new(EqPolynomial::::evals(&point, None)); + let mut split = GruenSplitEqPolynomial::new(&point, BindingOrder::HighToLow); + assert_eq!(dense, split.merge()); + + for _ in 0..point.len() { + let r = Fr::random(&mut rng); + dense.bind_with_order(r, BindingOrder::HighToLow); + split.bind(r); + assert_eq!(dense, split.merge()); + } + } + + #[test] + fn window_size_one_factors_current_head() { + let mut rng = ChaCha20Rng::seed_from_u64(702); + let point: Vec = (0..10).map(|_| Fr::random(&mut rng)).collect(); + let mut split = GruenSplitEqPolynomial::new(&point, BindingOrder::LowToHigh); + + for _round in 0..point.len() { + let num_unbound = split.current_index; + if num_unbound <= 1 { + break; + } + let (e_out, e_in) = split.e_out_in_for_window(1); + let head = EqPolynomial::::evals(&split.w[..num_unbound - 1], None); + assert_eq!(e_out.len() * e_in.len(), head.len()); + + let x_in_bits = e_in.len().log_2(); + for (x_out, &e_out) in e_out.iter().enumerate() { + for (x_in, &e_in) in e_in.iter().enumerate() { + let index = (x_out << x_in_bits) | x_in; + assert_eq!(e_out * e_in, head[index]); + } + } + + split.bind(Fr::random(&mut rng)); + } + } + + #[test] + fn gruen_degree_two_matches_direct_interpolation() { + let mut rng = ChaCha20Rng::seed_from_u64(703); + let point: Vec = (0..5).map(|_| Fr::random(&mut rng)).collect(); + let split = GruenSplitEqPolynomial::new(&point, BindingOrder::LowToHigh); + let q0 = Fr::from_u64(11); + let q1 = Fr::from_u64(29); + let l1 = split.current_scalar() * split.current_w(); + let l0 = split.current_scalar() - l1; + let previous_claim = l0 * q0 + l1 * q1; + + let poly = split.gruen_poly_deg_2(q0, previous_claim); + let q2 = q1 + q1 - q0; + let l2 = l1 + (l1 - l0); + assert_eq!(poly.evaluate(Fr::zero()), l0 * q0); + assert_eq!(poly.evaluate(Fr::one()), l1 * q1); + assert_eq!(poly.evaluate(Fr::from_u64(2)), l2 * q2); + } +} diff --git a/crates/jolt-r1cs/src/lib.rs b/crates/jolt-r1cs/src/lib.rs index 62f0f4176d..4ad18b2dfc 100644 --- a/crates/jolt-r1cs/src/lib.rs +++ b/crates/jolt-r1cs/src/lib.rs @@ -13,8 +13,10 @@ pub mod constraint; pub mod constraints; pub mod key; pub mod provider; +mod row_dots; pub use column::R1csColumn; pub use constraint::ConstraintMatrices; pub use key::R1csKey; pub use provider::{R1csSource, SpartanChallenges}; +pub use row_dots::{R1csRowDotSlice, R1csRowDotTable}; diff --git a/crates/jolt-r1cs/src/row_dots.rs b/crates/jolt-r1cs/src/row_dots.rs new file mode 100644 index 0000000000..11a040a10b --- /dev/null +++ b/crates/jolt-r1cs/src/row_dots.rs @@ -0,0 +1,118 @@ +use jolt_field::Field; + +use crate::R1csKey; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +#[derive(Clone, Debug)] +pub struct R1csRowDotTable { + row_count: usize, + cycle_count: usize, + a: Vec, + b: Vec, +} + +#[derive(Clone, Copy, Debug)] +pub struct R1csRowDotSlice<'a, F: Field> { + pub a: &'a [F], + pub b: &'a [F], +} + +impl R1csRowDotTable { + #[tracing::instrument(skip_all, name = "R1csRowDotTable::compute_ab_prefix")] + pub fn compute_ab_prefix(key: &R1csKey, witness: &[F], row_count: usize) -> Self { + assert!( + row_count <= key.matrices.num_constraints, + "row_count exceeds R1CS constraint count" + ); + let expected = key.num_cycles * key.num_vars_padded; + assert_eq!( + witness.len(), + expected, + "R1CS witness length does not match key shape" + ); + + let total = key.num_cycles * row_count; + let mut a = vec![F::zero(); total]; + let mut b = vec![F::zero(); total]; + compute_row_dots(key, witness, row_count, &mut a, &mut b); + + Self { + row_count, + cycle_count: key.num_cycles, + a, + b, + } + } + + #[inline] + pub fn row_count(&self) -> usize { + self.row_count + } + + #[inline] + pub fn cycle_count(&self) -> usize { + self.cycle_count + } + + #[inline] + pub fn cycle(&self, cycle: usize) -> R1csRowDotSlice<'_, F> { + assert!(cycle < self.cycle_count, "cycle index out of bounds"); + let start = cycle * self.row_count; + let end = start + self.row_count; + R1csRowDotSlice { + a: &self.a[start..end], + b: &self.b[start..end], + } + } +} + +#[cfg(feature = "parallel")] +fn compute_row_dots( + key: &R1csKey, + witness: &[F], + row_count: usize, + a: &mut [F], + b: &mut [F], +) { + a.par_chunks_mut(row_count) + .zip(b.par_chunks_mut(row_count)) + .enumerate() + .for_each(|(cycle, (a_chunk, b_chunk))| { + let start = cycle * key.num_vars_padded; + let witness_row = &witness[start..start + key.matrices.num_vars]; + for row in 0..row_count { + a_chunk[row] = row_dot(&key.matrices.a[row], witness_row); + b_chunk[row] = row_dot(&key.matrices.b[row], witness_row); + } + }); +} + +#[cfg(not(feature = "parallel"))] +fn compute_row_dots( + key: &R1csKey, + witness: &[F], + row_count: usize, + a: &mut [F], + b: &mut [F], +) { + for cycle in 0..key.num_cycles { + let witness_start = cycle * key.num_vars_padded; + let row_start = cycle * row_count; + let witness_row = &witness[witness_start..witness_start + key.matrices.num_vars]; + for row in 0..row_count { + a[row_start + row] = row_dot(&key.matrices.a[row], witness_row); + b[row_start + row] = row_dot(&key.matrices.b[row], witness_row); + } + } +} + +#[inline] +fn row_dot(row: &[(usize, F)], witness: &[F]) -> F { + let mut acc = F::zero(); + for &(variable, coefficient) in row { + acc += coefficient * witness[variable]; + } + acc +} diff --git a/crates/jolt-transcript/src/lib.rs b/crates/jolt-transcript/src/lib.rs index 9fcc5f4fee..e11c858c47 100644 --- a/crates/jolt-transcript/src/lib.rs +++ b/crates/jolt-transcript/src/lib.rs @@ -51,6 +51,7 @@ mod blanket; mod digest; pub mod domain; mod keccak; +mod mock; #[cfg(feature = "poseidon")] mod poseidon; mod transcript; @@ -59,6 +60,7 @@ pub use blake2b::Blake2bTranscript; pub use digest::DigestTranscript; pub use domain::{Label, LabelWithCount, U64Word}; pub use keccak::KeccakTranscript; +pub use mock::MockTranscript; #[cfg(feature = "poseidon")] pub use poseidon::PoseidonTranscript; pub use transcript::{AppendToTranscript, Transcript}; diff --git a/crates/jolt-transcript/src/mock.rs b/crates/jolt-transcript/src/mock.rs new file mode 100644 index 0000000000..edd41a1ad4 --- /dev/null +++ b/crates/jolt-transcript/src/mock.rs @@ -0,0 +1,109 @@ +//! Deterministic mock transcript for testing. +//! +//! All absorb operations are no-ops. Challenges come from a seeded Blake2b +//! counter, producing the same sequence regardless of what is absorbed. +//! Use the same seed in both jolt-core's and jolt-transcript's mock +//! transcripts to get identical challenges for cross-system comparison. + +use crate::transcript::{AppendToTranscript, Transcript}; +use blake2::digest::consts::U32; +use blake2::{Blake2b, Digest}; +use jolt_field::Field; +use std::marker::PhantomData; + +type Blake2b256 = Blake2b; + +/// Mock transcript that ignores absorbs and produces deterministic challenges. +/// +/// Challenges are derived from `H(seed || counter)` where counter increments +/// on each squeeze. Two mock transcripts with the same seed always produce +/// identical challenge sequences. +#[derive(Clone)] +pub struct MockTranscript { + seed: [u8; 32], + counter: u64, + _field: PhantomData, +} + +impl Default for MockTranscript { + fn default() -> Self { + Self { + seed: [0u8; 32], + counter: 0, + _field: PhantomData, + } + } +} + +impl MockTranscript { + /// Creates a mock transcript with the given seed bytes. + #[must_use] + pub fn with_seed(seed: &[u8]) -> Self { + let hash: [u8; 32] = Blake2b256::new().chain_update(seed).finalize().into(); + Self { + seed: hash, + counter: 0, + _field: PhantomData, + } + } + + fn next_bytes32(&mut self) -> [u8; 32] { + let hash: [u8; 32] = Blake2b256::new() + .chain_update(self.seed) + .chain_update(self.counter.to_le_bytes()) + .finalize() + .into(); + self.counter += 1; + hash + } +} + +impl Transcript for MockTranscript { + type Challenge = F; + + fn new(_label: &'static [u8]) -> Self { + Self::with_seed(b"mock_transcript_default_seed") + } + + fn append_bytes(&mut self, _bytes: &[u8]) {} + + fn append(&mut self, _value: &A) {} + + fn challenge(&mut self) -> F { + ::from_challenge_bytes(&self.next_bytes32()) + } + + fn state(&self) -> &[u8; 32] { + &self.seed + } + + #[cfg(test)] + fn compare_to(&mut self, _other: &Self) {} +} + +#[cfg(test)] +mod tests { + use super::*; + use jolt_field::Fr; + + #[test] + fn same_seed_same_challenges() { + let mut t1 = MockTranscript::::with_seed(b"test"); + let mut t2 = MockTranscript::::with_seed(b"test"); + + // Absorb different things — shouldn't matter + t1.append_bytes(b"hello"); + t2.append_bytes(b"world"); + + for _ in 0..100 { + assert_eq!(t1.challenge(), t2.challenge()); + } + } + + #[test] + fn different_seed_different_challenges() { + let mut t1 = MockTranscript::::with_seed(b"seed_a"); + let mut t2 = MockTranscript::::with_seed(b"seed_b"); + assert_ne!(t1.challenge(), t2.challenge()); + } +} diff --git a/crates/jolt-witness/Cargo.toml b/crates/jolt-witness/Cargo.toml new file mode 100644 index 0000000000..2ec9481490 --- /dev/null +++ b/crates/jolt-witness/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "jolt-witness" +version = "0.1.0" +edition = "2021" +license = "MIT OR Apache-2.0" +description = "Primitive oracle construction kernels for Bolt-generated Jolt code" + +[lints] +workspace = true + +[dependencies] +jolt-field = { path = "../jolt-field" } +jolt-poly = { path = "../jolt-poly" } diff --git a/crates/jolt-witness/src/lib.rs b/crates/jolt-witness/src/lib.rs new file mode 100644 index 0000000000..2e4dda0488 --- /dev/null +++ b/crates/jolt-witness/src/lib.rs @@ -0,0 +1,1119 @@ +//! Primitive oracle construction kernels for Bolt-generated Jolt code. +//! +//! This crate is intentionally not a runtime/provider abstraction. Generated +//! code calls these kernels after the Bolt lowering pipeline has made oracle +//! generation explicit in IR. + +use jolt_field::Field; +use jolt_poly::EqPolynomial; + +pub const NUM_DENSE_TRACE_COLUMNS: usize = 2; +pub const NUM_ONE_HOT_TRACE_SOURCES: usize = 3; + +/// Per-cycle primitive inputs consumed by Bolt oracle generation. +#[derive(Clone, Copy, Debug)] +pub struct CycleInput { + pub dense: [i128; NUM_DENSE_TRACE_COLUMNS], + pub one_hot: [Option; NUM_ONE_HOT_TRACE_SOURCES], +} + +impl CycleInput { + pub const PADDING: Self = Self { + dense: [0; NUM_DENSE_TRACE_COLUMNS], + one_hot: [Some(0), Some(0), None], + }; +} + +impl Default for CycleInput { + fn default() -> Self { + Self::PADDING + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct CommitmentTraceSources { + pub rd_inc: Vec, + pub ram_inc: Vec, + pub instruction_keys: Vec>, + pub ram_addresses: Vec>, + pub bytecode_indices: Vec>, +} + +impl CommitmentTraceSources { + pub fn from_cycle_inputs(cycle_inputs: &[CycleInput]) -> Self { + Self { + rd_inc: cycle_inputs.iter().map(|cycle| cycle.dense[0]).collect(), + ram_inc: cycle_inputs.iter().map(|cycle| cycle.dense[1]).collect(), + instruction_keys: one_hot_cycle_column(cycle_inputs, 0), + ram_addresses: one_hot_cycle_column(cycle_inputs, 2), + bytecode_indices: one_hot_cycle_column(cycle_inputs, 1), + } + } +} + +pub fn commitment_trace_sources(cycle_inputs: &[CycleInput]) -> CommitmentTraceSources { + CommitmentTraceSources::from_cycle_inputs(cycle_inputs) +} + +/// Returns a dense trace source by its generated oracle source name. +pub fn dense_cycle_source(cycle_inputs: &[CycleInput], source: &str) -> Vec { + let slot = match source { + "trace.rd_inc" => 0, + "trace.ram_inc" => 1, + _ => unreachable!("unsupported dense source `{source}`"), + }; + cycle_inputs.iter().map(|cycle| cycle.dense[slot]).collect() +} + +/// Returns a one-hot trace source by its generated oracle source name. +pub fn one_hot_cycle_source(cycle_inputs: &[CycleInput], source: &str) -> Vec> { + let slot = match source { + "trace.instruction_keys" => 0, + "trace.bytecode_indices" => 1, + "trace.ram_addresses" => 2, + _ => unreachable!("unsupported one-hot source `{source}`"), + }; + one_hot_cycle_column(cycle_inputs, slot) +} + +/// Maps generated one-hot padding policy names to the corresponding value. +pub fn one_hot_padding_value(padding: &str) -> Option { + match padding { + "zero" => Some(0), + "none" => None, + _ => unreachable!("unsupported padding `{padding}`"), + } +} + +/// Converts an i128 trace column to field elements and pads it to `target_len`. +/// +/// The input is normally trace-length data; commitment domains can be larger +/// than the trace domain, so generated code asks for the final committed length. +pub fn dense_i128_column_to_field(values: &[i128], target_len: usize) -> Vec { + assert!( + values.len() <= target_len, + "dense trace column has {} values, target length is {target_len}", + values.len() + ); + let mut output: Vec = values.iter().map(|&value| F::from_i128(value)).collect(); + output.resize(target_len, F::zero()); + output +} + +/// Pads an optional field-valued oracle to `target_len`. +/// +/// `None` stays `None`; zero-skipping policy is deliberately left to the +/// generated commitment code because skip semantics are protocol metadata. +pub fn optional_field_oracle(values: Option<&[F]>, target_len: usize) -> Option> { + values.map(|values| pad_field_oracle(values, target_len)) +} + +/// Pads a field-valued oracle to `target_len`. +pub fn pad_field_oracle(values: &[F], target_len: usize) -> Vec { + assert!( + values.len() <= target_len, + "field oracle has {} values, target length is {target_len}", + values.len() + ); + let mut output = values.to_vec(); + output.resize(target_len, F::zero()); + output +} + +/// Deterministic placeholder data for optional advice oracles in synthetic tests. +pub fn deterministic_oracle_data(oracle: &str, num_vars: usize) -> Vec { + let seed = oracle.bytes().fold(17u64, |state, byte| { + state.wrapping_mul(131).wrapping_add(byte as u64) + }); + (0..(1usize << num_vars)) + .map(|index| F::from_u64(seed.wrapping_add(index as u64 + 1))) + .collect() +} + +/// Returns synthetic data for non-advice oracles and `None` for optional advice. +pub fn optional_oracle_data(oracle: &str, num_vars: usize) -> Option> { + match oracle { + "UntrustedAdvice" | "TrustedAdvice" => None, + _ => Some(deterministic_oracle_data(oracle, num_vars)), + } +} + +/// Builds sparse per-cycle one-hot chunk indices. +/// +/// The returned vector has one entry per trace cycle. `None` means no one-hot +/// entry is active for that cycle. Chunk `0` is the most significant chunk, +/// matching jolt-core's committed RA decomposition. +pub fn one_hot_chunk_indices( + values: &[Option], + chunk: usize, + num_chunks: usize, + chunk_bits: usize, + trace_len: usize, + padding_value: Option, +) -> Vec> { + assert!( + values.len() <= trace_len, + "one-hot source has {} values, trace length is {trace_len}", + values.len() + ); + assert!( + chunk < num_chunks, + "chunk index {chunk} out of bounds for {num_chunks} chunks" + ); + assert!( + chunk_bits <= u8::BITS as usize, + "chunk_bits must fit in one byte" + ); + assert!( + chunk_bits * num_chunks <= u128::BITS as usize, + "one-hot chunks must fit in u128 source values" + ); + + let chunk_domain = 1usize << chunk_bits; + let shift = chunk_bits * (num_chunks - 1 - chunk); + let mask = (chunk_domain - 1) as u128; + let mut output = Vec::with_capacity(trace_len); + + for cycle in 0..trace_len { + let value = values.get(cycle).copied().flatten().or(padding_value); + output.push(value.map(|value| ((value >> shift) & mask) as u8)); + } + + output +} + +/// Builds one address-major one-hot chunk polynomial. +/// +/// Layout is `output[chunk_value * trace_len + cycle]`. Chunk `0` is the most +/// significant chunk, matching jolt-core's committed RA decomposition. +pub fn one_hot_chunk_address_major( + values: &[Option], + chunk: usize, + num_chunks: usize, + chunk_bits: usize, + trace_len: usize, + padding_value: Option, +) -> Vec { + let indices = one_hot_chunk_indices( + values, + chunk, + num_chunks, + chunk_bits, + trace_len, + padding_value, + ); + one_hot_address_major_from_indices(&indices, chunk_bits) +} + +/// Builds one address-major one-hot chunk polynomial from sparse per-cycle indices. +/// +/// Layout is `output[chunk_value * trace_len + cycle]`. +pub fn one_hot_address_major_from_indices( + indices: &[Option], + chunk_bits: usize, +) -> Vec { + assert!( + chunk_bits < usize::BITS as usize, + "chunk_bits must fit in usize" + ); + + let chunk_domain = 1usize << chunk_bits; + let mut output = vec![F::zero(); chunk_domain * indices.len()]; + + for (cycle, index) in indices.iter().enumerate() { + if let Some(index) = index { + let index = usize::from(*index); + assert!( + index < chunk_domain, + "one-hot index {index} exceeds chunk domain {chunk_domain}" + ); + output[index * indices.len() + cycle] = F::one(); + } + } + + output +} + +/// Builds one cycle-major one-hot chunk polynomial from sparse per-cycle indices. +/// +/// Layout is `output[cycle * chunk_domain + chunk_value]`. +pub fn one_hot_cycle_major_from_indices( + indices: &[Option], + chunk_bits: usize, +) -> Vec { + assert!( + chunk_bits < usize::BITS as usize, + "chunk_bits must fit in usize" + ); + + let chunk_domain = 1usize << chunk_bits; + let mut output = vec![F::zero(); chunk_domain * indices.len()]; + + for (cycle, index) in indices.iter().enumerate() { + if let Some(index) = index { + let index = usize::from(*index); + assert!( + index < chunk_domain, + "one-hot index {index} exceeds chunk domain {chunk_domain}" + ); + output[cycle * chunk_domain + index] = F::one(); + } + } + + output +} + +/// Evaluates one-hot per-cycle indices at an address-chunk point. +/// +/// The returned vector has one field element per cycle. Skipped entries +/// evaluate to zero. +pub fn one_hot_evals_at_chunk_point(indices: &[Option], point: &[F]) -> Vec { + let eq_table = EqPolynomial::::evals(point, None); + indices + .iter() + .map(|index| { + index.map_or(F::zero(), |index| { + let index = usize::from(index); + assert!( + index < eq_table.len(), + "one-hot index {index} exceeds chunk point domain {}", + eq_table.len() + ); + eq_table[index] + }) + }) + .collect() +} + +/// Returns most-significant-first chunk widths for a bitstring split by `chunk_bits`. +/// +/// If the high chunk is partial, it appears first. The result is padded with +/// full-width chunks until it reaches `chunk_count`. +pub fn msb_chunk_bit_widths( + total_bits: usize, + chunk_bits: usize, + chunk_count: usize, +) -> Vec { + assert!(chunk_bits > 0, "chunk_bits must be nonzero"); + let first_chunk_bits = total_bits % chunk_bits; + let mut widths = Vec::with_capacity(chunk_count); + if first_chunk_bits != 0 { + widths.push(first_chunk_bits); + } + while widths.len() < chunk_count { + widths.push(chunk_bits); + } + widths +} + +/// Splits a most-significant-first point into fixed-width chunks. +/// +/// The high chunk is left-padded with zero challenges if the point length is +/// not a multiple of `chunk_bits`. +pub fn msb_point_chunks(point: &[F], chunk_bits: usize) -> Vec> { + assert!(chunk_bits > 0, "chunk_bits must be nonzero"); + let mut padded = Vec::new(); + let remainder = point.len() % chunk_bits; + if remainder != 0 { + padded.resize(chunk_bits - remainder, F::zero()); + } + padded.extend_from_slice(point); + padded + .chunks(chunk_bits) + .map(|chunk| chunk.to_vec()) + .collect() +} + +/// Computes `post - pre` in the field for a `u64` value transition. +pub fn u64_increment(pre: u64, post: u64) -> F { + if post >= pre { + F::from_u64(post - pre) + } else { + -F::from_u64(pre - post) + } +} + +/// Computes a field increment column from `(pre, post)` `u64` transitions. +pub fn u64_increment_column(transitions: impl IntoIterator) -> Vec { + transitions + .into_iter() + .map(|(pre, post)| u64_increment(pre, post)) + .collect() +} + +/// Computes a field increment column where a missing write contributes zero. +pub fn optional_u64_increment_column( + transitions: impl IntoIterator>, +) -> Vec { + transitions + .into_iter() + .map(|transition| transition.map_or_else(F::zero, |(pre, post)| u64_increment(pre, post))) + .collect() +} + +/// Materializes an optional `usize` source column. +pub fn optional_usize_column( + values: impl IntoIterator>, +) -> Vec> { + values.into_iter().collect() +} + +#[derive(Clone, Debug)] +pub struct Stage45SparseTraceWitness { + pub rd_inc: Vec, + pub ram_addresses: Vec>, + pub ram_inc: Vec, + pub rd_write_addresses: Vec>, +} + +pub fn stage4_5_sparse_trace_witness( + register_writes: impl IntoIterator>, + ram_accesses: impl IntoIterator, u64, u64)>, +) -> Stage45SparseTraceWitness { + let mut rd_inc = Vec::new(); + let mut rd_write_addresses = Vec::new(); + for write in register_writes { + if let Some((address, pre_value, post_value)) = write { + rd_inc.push(u64_increment(pre_value, post_value)); + rd_write_addresses.push(Some(address)); + } else { + rd_inc.push(F::zero()); + rd_write_addresses.push(None); + } + } + + let mut ram_addresses = Vec::new(); + let mut ram_inc = Vec::new(); + for (address, read_value, write_value) in ram_accesses { + ram_addresses.push(address); + ram_inc.push(u64_increment(read_value, write_value)); + } + + Stage45SparseTraceWitness { + rd_inc, + ram_addresses, + ram_inc, + rd_write_addresses, + } +} + +/// Evaluates a `u64`-valued multilinear extension at `point`. +pub fn mle_eval_u64(values: &[u64], point: &[F]) -> F { + EqPolynomial::::evals(point, None) + .iter() + .zip(values) + .map(|(&weight, &value)| weight * F::from_u64(value)) + .sum() +} + +/// Builds the Stage 4 `RamValInit` opening from the initial RAM image. +/// +/// Stage 4 consumes this at the same address point as `RamValFinal`. +pub fn stage4_ram_val_init_opening( + initial_ram_state: &[u64], + ram_val_final_point: &[F], +) -> (Vec, F) { + ( + ram_val_final_point.to_vec(), + mle_eval_u64(initial_ram_state, ram_val_final_point), + ) +} + +/// Reverses a challenge point. +pub fn reverse_point(point: &[F]) -> Vec { + point.iter().rev().copied().collect() +} + +/// Returns the last `len` point coordinates in reverse order. +pub fn reversed_suffix(point: &[F], len: usize) -> Vec { + let Some(start) = point.len().checked_sub(len) else { + unreachable!("point is shorter than suffix length {len}"); + }; + point[start..].iter().rev().copied().collect() +} + +/// Normalizes Stage 4 register read/write points to address-major order. +pub fn normalized_stage4_registers_rw_point( + log_t: usize, + register_log_k: usize, + point: &[F], +) -> Vec { + let expected = log_t + register_log_k; + assert_eq!( + point.len(), + expected, + "Stage 4 registers point length mismatch" + ); + let (cycle, address) = point.split_at(log_t); + address + .iter() + .rev() + .copied() + .chain(cycle.iter().rev().copied()) + .collect() +} + +/// Extracts the Stage 5 instruction read-RAF cycle point. +pub fn stage5_instruction_cycle_point( + stage5_point: &[F], + instruction_ra_virtual_d: usize, + ra_virtual_log_k_chunk: usize, + log_t: usize, +) -> Vec { + let address_len = instruction_ra_virtual_d * ra_virtual_log_k_chunk; + let end = address_len + log_t; + assert!( + end <= stage5_point.len(), + "Stage 5 point is shorter than instruction address plus cycle arity" + ); + reverse_point(&stage5_point[address_len..end]) +} + +/// Builds a Stage 5 instruction RA opening point for one virtual address chunk. +pub fn stage5_instruction_ra_point( + stage5_point: &[F], + instruction_ra_virtual_d: usize, + ra_virtual_log_k_chunk: usize, + log_t: usize, + index: usize, +) -> Vec { + let start = index * ra_virtual_log_k_chunk; + let end = start + ra_virtual_log_k_chunk; + assert!( + end <= stage5_point.len(), + "Stage 5 point is shorter than instruction RA chunk {index}" + ); + let mut point = stage5_point[start..end].to_vec(); + point.extend(stage5_instruction_cycle_point( + stage5_point, + instruction_ra_virtual_d, + ra_virtual_log_k_chunk, + log_t, + )); + point +} + +/// Builds the Stage 5 RAM RA opening point from its input address and cycle point. +pub fn stage5_ram_ra_point( + stage5_input_point: &[F], + stage5_point: &[F], + log_k_ram: usize, + log_t: usize, +) -> Vec { + assert!( + stage5_input_point.len() >= log_k_ram, + "Stage 5 RAM RA input point is shorter than RAM address arity" + ); + let mut point = stage5_input_point[..log_k_ram].to_vec(); + point.extend(reversed_suffix(stage5_point, log_t)); + point +} + +/// Builds the Stage 5 RegistersVal opening point from address and cycle points. +pub fn stage5_registers_val_point( + stage5_input_point: &[F], + stage5_point: &[F], + register_log_k: usize, + log_t: usize, +) -> Vec { + assert!( + stage5_input_point.len() >= register_log_k, + "Stage 5 RegistersVal input point is shorter than register address arity" + ); + let mut point = stage5_input_point[..register_log_k].to_vec(); + point.extend(reversed_suffix(stage5_point, log_t)); + point +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6WitnessParams { + pub trace_len: usize, + pub log_k_chunk: usize, + pub log_k_bytecode: usize, + pub log_k_ram: usize, + pub lookups_ra_virtual_log_k_chunk: usize, + pub instruction_d: usize, + pub instruction_ra_virtual_d: usize, + pub bytecode_d: usize, + pub ram_d: usize, +} + +#[derive(Clone, Copy, Debug)] +pub struct Stage6BytecodeEntry { + pub address: F, + pub imm: F, + pub circuit_flags: [bool; 14], + pub rd: Option, + pub rs1: Option, + pub rs2: Option, + pub lookup_table: Option, + pub is_interleaved: bool, + pub is_branch: bool, + pub left_is_rs1: bool, + pub left_is_pc: bool, + pub right_is_rs2: bool, + pub right_is_imm: bool, + pub is_noop: bool, +} + +#[derive(Clone, Copy, Debug)] +pub struct Stage6OpeningInputRef<'a, F: Field> { + pub symbol: &'a str, + pub point: &'a [F], +} + +#[derive(Clone, Copy, Debug)] +pub struct Stage6WitnessInputs<'a, F: Field> { + pub params: Stage6WitnessParams, + pub cycle_inputs: &'a [CycleInput], + pub opening_inputs: &'a [Stage6OpeningInputRef<'a, F>], +} + +#[derive(Clone, Debug)] +pub struct Stage6WitnessPolynomials { + pub instruction_ra_indices: Vec>>, + pub bytecode_ra_indices: Vec>>, + pub ram_ra_indices: Vec>>, + pub instruction_ra_booleanity: Vec>, + pub bytecode_ra_booleanity: Vec>, + pub ram_ra_booleanity: Vec>, + pub bytecode_ra_read_raf: Vec>, + pub bytecode_ra_read_raf_chunk_lens: Vec, + pub instruction_ra_virtual: Vec>, + pub ram_ra_virtual: Vec>, + pub hamming_weight: Vec, + pub ram_inc: Vec, + pub rd_inc: Vec, +} + +#[derive(Clone, Debug)] +pub struct Stage6WitnessSlices<'a, F: Field> { + pub booleanity_chunks: Vec<&'a [F]>, + pub booleanity_index_chunks: Vec<&'a [Option]>, + pub bytecode_ra_read_raf_chunks: Vec<&'a [F]>, + pub bytecode_ra_read_raf_chunk_lens: Vec, + pub ram_ra_virtual_chunks: Vec<&'a [F]>, + pub instruction_ra_virtual_chunks: Vec<&'a [F]>, + pub instruction_ra_index_chunks: Vec<&'a [Option]>, + pub bytecode_ra_index_chunks: Vec<&'a [Option]>, + pub ram_ra_index_chunks: Vec<&'a [Option]>, +} + +impl Stage6WitnessPolynomials { + /// Returns borrowed slices in the order expected by the generated Stage 6/7 kernels. + pub fn slices(&self) -> Stage6WitnessSlices<'_, F> { + let mut booleanity_chunks = field_slices(&self.instruction_ra_booleanity); + booleanity_chunks.extend(field_slices(&self.bytecode_ra_booleanity)); + booleanity_chunks.extend(field_slices(&self.ram_ra_booleanity)); + + let mut booleanity_index_chunks = index_slices(&self.instruction_ra_indices); + booleanity_index_chunks.extend(index_slices(&self.bytecode_ra_indices)); + booleanity_index_chunks.extend(index_slices(&self.ram_ra_indices)); + + Stage6WitnessSlices { + booleanity_chunks, + booleanity_index_chunks, + bytecode_ra_read_raf_chunks: field_slices(&self.bytecode_ra_read_raf), + bytecode_ra_read_raf_chunk_lens: self.bytecode_ra_read_raf_chunk_lens.clone(), + ram_ra_virtual_chunks: field_slices(&self.ram_ra_virtual), + instruction_ra_virtual_chunks: field_slices(&self.instruction_ra_virtual), + instruction_ra_index_chunks: index_slices(&self.instruction_ra_indices), + bytecode_ra_index_chunks: index_slices(&self.bytecode_ra_indices), + ram_ra_index_chunks: index_slices(&self.ram_ra_indices), + } + } +} + +pub fn stage6_witness_polynomials( + inputs: Stage6WitnessInputs<'_, F>, +) -> Stage6WitnessPolynomials { + let params = inputs.params; + let trace_len = params.trace_len; + assert!( + inputs.cycle_inputs.len() <= trace_len, + "cycle input length {} exceeds trace length {trace_len}", + inputs.cycle_inputs.len() + ); + + let instruction_keys = one_hot_cycle_column(inputs.cycle_inputs, 0); + let bytecode_indices_source = one_hot_cycle_column(inputs.cycle_inputs, 1); + let ram_addresses = one_hot_cycle_column(inputs.cycle_inputs, 2); + + let instruction_indices = (0..params.instruction_d) + .map(|index| { + one_hot_chunk_indices( + &instruction_keys, + index, + params.instruction_d, + params.log_k_chunk, + trace_len, + Some(0), + ) + }) + .collect::>(); + let bytecode_indices = (0..params.bytecode_d) + .map(|index| { + one_hot_chunk_indices( + &bytecode_indices_source, + index, + params.bytecode_d, + params.log_k_chunk, + trace_len, + Some(0), + ) + }) + .collect::>(); + let ram_indices = (0..params.ram_d) + .map(|index| { + one_hot_chunk_indices( + &ram_addresses, + index, + params.ram_d, + params.log_k_chunk, + trace_len, + None, + ) + }) + .collect::>(); + + let bytecode_ra_read_raf_chunk_lens = + msb_chunk_bit_widths(params.log_k_bytecode, params.log_k_chunk, params.bytecode_d); + + let ram_address_chunks = stage6_ram_virtual_address_chunks(params, inputs.opening_inputs); + assert_eq!( + ram_address_chunks.len(), + params.ram_d, + "RAM Stage 6 address chunk count mismatch" + ); + let ram_ra_virtual = ram_indices + .iter() + .zip(&ram_address_chunks) + .map(|(indices, point)| one_hot_evals_at_chunk_point(indices, point)) + .collect::>(); + + let instruction_address_chunks = + stage6_instruction_virtual_address_chunks(params, inputs.opening_inputs); + assert_eq!( + instruction_address_chunks.len(), + params.instruction_d, + "instruction Stage 6 address chunk count mismatch" + ); + let instruction_ra_virtual = instruction_indices + .iter() + .zip(&instruction_address_chunks) + .map(|(indices, point)| one_hot_evals_at_chunk_point(indices, point)) + .collect::>(); + + Stage6WitnessPolynomials { + instruction_ra_indices: instruction_indices, + bytecode_ra_indices: bytecode_indices, + ram_ra_indices: ram_indices, + instruction_ra_booleanity: Vec::new(), + bytecode_ra_booleanity: Vec::new(), + ram_ra_booleanity: Vec::new(), + bytecode_ra_read_raf: Vec::new(), + bytecode_ra_read_raf_chunk_lens, + instruction_ra_virtual, + ram_ra_virtual, + hamming_weight: hamming_weight_from_cycle_inputs(inputs.cycle_inputs, trace_len), + ram_inc: dense_cycle_column_to_field(inputs.cycle_inputs, 1, trace_len), + rd_inc: dense_cycle_column_to_field(inputs.cycle_inputs, 0, trace_len), + } +} + +fn field_slices(values: &[Vec]) -> Vec<&[F]> { + values.iter().map(Vec::as_slice).collect() +} + +fn index_slices(values: &[Vec>]) -> Vec<&[Option]> { + values.iter().map(Vec::as_slice).collect() +} + +fn one_hot_cycle_column(cycle_inputs: &[CycleInput], slot: usize) -> Vec> { + cycle_inputs + .iter() + .map(|cycle| cycle.one_hot[slot]) + .collect() +} + +fn dense_cycle_column_to_field( + cycle_inputs: &[CycleInput], + slot: usize, + trace_len: usize, +) -> Vec { + assert!( + cycle_inputs.len() <= trace_len, + "cycle input length {} exceeds trace length {trace_len}", + cycle_inputs.len() + ); + let mut output = cycle_inputs + .iter() + .map(|cycle| F::from_i128(cycle.dense[slot])) + .collect::>(); + output.resize(trace_len, F::zero()); + output +} + +fn hamming_weight_from_cycle_inputs( + cycle_inputs: &[CycleInput], + trace_len: usize, +) -> Vec { + assert!( + cycle_inputs.len() <= trace_len, + "cycle input length {} exceeds trace length {trace_len}", + cycle_inputs.len() + ); + let mut output = cycle_inputs + .iter() + .map(|cycle| { + if cycle.one_hot[2].is_some() { + F::one() + } else { + F::zero() + } + }) + .collect::>(); + output.resize(trace_len, F::zero()); + output +} + +fn stage6_ram_virtual_address_chunks( + params: Stage6WitnessParams, + opening_inputs: &[Stage6OpeningInputRef<'_, F>], +) -> Vec> { + let point = stage6_opening_point( + opening_inputs, + "stage6.input.stage5.ram_ra_claim_reduction.RamRa", + ); + assert!( + point.len() >= params.log_k_ram, + "RAM RA opening point is shorter than the RAM address arity" + ); + msb_point_chunks(&point[..params.log_k_ram], params.log_k_chunk) +} + +fn stage6_instruction_virtual_address_chunks( + params: Stage6WitnessParams, + opening_inputs: &[Stage6OpeningInputRef<'_, F>], +) -> Vec> { + let mut address = Vec::with_capacity(params.instruction_d * params.log_k_chunk); + for index in 0..params.instruction_ra_virtual_d { + let symbol = format!("stage6.input.stage5.instruction_read_raf.InstructionRa_{index}"); + let point = stage6_opening_point(opening_inputs, &symbol); + assert!( + point.len() >= params.lookups_ra_virtual_log_k_chunk, + "instruction RA opening point is shorter than the virtual address chunk arity" + ); + address.extend_from_slice(&point[..params.lookups_ra_virtual_log_k_chunk]); + } + msb_point_chunks(&address, params.log_k_chunk) +} + +fn stage6_opening_point<'a, F: Field>( + opening_inputs: &'a [Stage6OpeningInputRef<'_, F>], + symbol: &str, +) -> &'a [F] { + let Some(input) = opening_inputs.iter().find(|input| input.symbol == symbol) else { + unreachable!("missing Stage 6 opening input `{symbol}`"); + }; + input.point +} + +#[cfg(test)] +mod tests { + use super::*; + use jolt_field::{Fr, FromPrimitiveInt}; + + fn fr(value: u64) -> Fr { + Fr::from_u64(value) + } + + #[test] + fn dense_column_converts_and_pads() { + let output = dense_i128_column_to_field::(&[5, -3], 4); + assert_eq!(output.len(), 4); + assert_eq!(output[0], Fr::from_i128(5)); + assert_eq!(output[1], Fr::from_i128(-3)); + assert_eq!(output[2], Fr::from_u64(0)); + assert_eq!(output[3], Fr::from_u64(0)); + } + + #[test] + fn cycle_sources_select_generated_trace_columns() { + let cycle_inputs = [ + CycleInput { + dense: [3, -2], + one_hot: [Some(7), Some(5), None], + }, + CycleInput { + dense: [8, 11], + one_hot: [Some(1), Some(4), Some(9)], + }, + ]; + let sources = commitment_trace_sources(&cycle_inputs); + assert_eq!(sources.rd_inc, vec![3, 8]); + assert_eq!(sources.ram_inc, vec![-2, 11]); + assert_eq!(sources.instruction_keys, vec![Some(7), Some(1)]); + assert_eq!(sources.ram_addresses, vec![None, Some(9)]); + assert_eq!(sources.bytecode_indices, vec![Some(5), Some(4)]); + assert_eq!( + dense_cycle_source(&cycle_inputs, "trace.rd_inc"), + vec![3, 8] + ); + assert_eq!( + dense_cycle_source(&cycle_inputs, "trace.ram_inc"), + vec![-2, 11] + ); + assert_eq!( + one_hot_cycle_source(&cycle_inputs, "trace.instruction_keys"), + vec![Some(7), Some(1)] + ); + assert_eq!( + one_hot_cycle_source(&cycle_inputs, "trace.bytecode_indices"), + vec![Some(5), Some(4)] + ); + assert_eq!( + one_hot_cycle_source(&cycle_inputs, "trace.ram_addresses"), + vec![None, Some(9)] + ); + } + + #[test] + fn increment_columns_compute_signed_field_deltas() { + assert_eq!(u64_increment::(2, 9), Fr::from_u64(7)); + assert_eq!(u64_increment::(9, 2), -Fr::from_u64(7)); + assert_eq!( + u64_increment_column::([(5, 8), (8, 3)]), + vec![Fr::from_u64(3), -Fr::from_u64(5)] + ); + assert_eq!( + optional_u64_increment_column::([Some((5, 8)), None, Some((8, 3))]), + vec![Fr::from_u64(3), Fr::from_u64(0), -Fr::from_u64(5)] + ); + } + + #[test] + fn stage4_5_sparse_trace_witness_groups_sparse_columns() { + let witness = stage4_5_sparse_trace_witness::( + [Some((2, 5, 8)), None, Some((3, 9, 4))], + [(Some(7), 10, 12), (None, 3, 3), (Some(8), 1, 0)], + ); + + assert_eq!(witness.rd_inc, vec![fr(3), fr(0), -fr(5)]); + assert_eq!(witness.rd_write_addresses, vec![Some(2), None, Some(3)]); + assert_eq!(witness.ram_addresses, vec![Some(7), None, Some(8)]); + assert_eq!(witness.ram_inc, vec![fr(2), fr(0), -fr(1)]); + } + + #[test] + fn mle_eval_u64_matches_boolean_hypercube_points() { + let values = [10, 20, 30, 40]; + let point = [Fr::from_u64(1), Fr::from_u64(0)]; + assert_eq!(mle_eval_u64(&values, &point), Fr::from_u64(30)); + } + + #[test] + fn stage4_ram_val_init_opening_uses_final_ram_point() { + let values = [10, 20, 30, 40]; + let point = [Fr::from_u64(1), Fr::from_u64(0)]; + let (opening_point, eval) = stage4_ram_val_init_opening(&values, &point); + + assert_eq!(opening_point, point); + assert_eq!(eval, Fr::from_u64(30)); + } + + #[test] + fn point_helpers_normalize_stage_points() { + let point = [fr(1), fr(2), fr(3), fr(4), fr(5)]; + assert_eq!( + reverse_point(&point), + vec![fr(5), fr(4), fr(3), fr(2), fr(1)] + ); + assert_eq!(reversed_suffix(&point, 3), vec![fr(5), fr(4), fr(3)]); + assert_eq!( + normalized_stage4_registers_rw_point(2, 3, &point), + vec![fr(5), fr(4), fr(3), fr(2), fr(1)] + ); + } + + #[test] + fn stage5_point_helpers_compose_address_and_cycle_points() { + let stage5_point = [fr(10), fr(11), fr(12), fr(13), fr(14), fr(15)]; + let input_point = [fr(20), fr(21), fr(22), fr(23)]; + assert_eq!( + stage5_instruction_cycle_point(&stage5_point, 2, 2, 2), + vec![fr(15), fr(14)] + ); + assert_eq!( + stage5_instruction_ra_point(&stage5_point, 2, 2, 2, 1), + vec![fr(12), fr(13), fr(15), fr(14)] + ); + assert_eq!( + stage5_ram_ra_point(&input_point, &stage5_point, 3, 2), + vec![fr(20), fr(21), fr(22), fr(15), fr(14)] + ); + assert_eq!( + stage5_registers_val_point(&input_point, &stage5_point, 2, 2), + vec![fr(20), fr(21), fr(15), fr(14)] + ); + } + + #[test] + fn one_hot_chunks_are_address_major_and_msb_first() { + let values = [Some(0xABu128), Some(0x12), None]; + let output = one_hot_chunk_address_major::(&values, 0, 2, 4, 4, Some(0)); + + assert_eq!(output.len(), 16 * 4); + assert_eq!(output[0xA * 4], Fr::from_u64(1)); + assert_eq!(output[5], Fr::from_u64(1)); + assert_eq!(output[2], Fr::from_u64(1)); + assert_eq!(output[3], Fr::from_u64(1)); + } + + #[test] + fn one_hot_address_major_from_indices_skips_none_entries() { + let output = one_hot_address_major_from_indices::(&[Some(2), None, Some(1)], 2); + + assert_eq!(output.len(), 12); + assert_eq!(output[2 * 3], Fr::from_u64(1)); + assert_eq!(output[5], Fr::from_u64(1)); + assert_eq!( + output + .iter() + .enumerate() + .filter(|(_, value)| **value == Fr::from_u64(1)) + .map(|(index, _)| index) + .collect::>(), + vec![5, 6] + ); + } + + #[test] + fn one_hot_cycle_major_from_indices_skips_none_entries() { + let output = one_hot_cycle_major_from_indices::(&[Some(2), None, Some(1)], 2); + + assert_eq!(output.len(), 12); + assert_eq!(output[2], Fr::from_u64(1)); + assert_eq!(output[2 * 4 + 1], Fr::from_u64(1)); + assert_eq!( + output + .iter() + .enumerate() + .filter(|(_, value)| **value == Fr::from_u64(1)) + .map(|(index, _)| index) + .collect::>(), + vec![2, 9] + ); + } + + #[test] + fn one_hot_evals_at_chunk_point_evaluates_sparse_indices() { + let point = [Fr::from_u64(5), Fr::from_u64(7)]; + let eq = EqPolynomial::::evals(&point, None); + let output = one_hot_evals_at_chunk_point(&[Some(0), Some(3), None], &point); + + assert_eq!(output, vec![eq[0], eq[3], Fr::from_u64(0)]); + } + + #[test] + fn stage6_witness_slices_preserve_kernel_order() { + let witness = Stage6WitnessPolynomials { + instruction_ra_indices: vec![vec![Some(1)]], + bytecode_ra_indices: vec![vec![Some(2)]], + ram_ra_indices: vec![vec![None]], + instruction_ra_booleanity: vec![vec![fr(10)]], + bytecode_ra_booleanity: vec![vec![fr(20)]], + ram_ra_booleanity: vec![vec![fr(30)]], + bytecode_ra_read_raf: vec![vec![fr(40)]], + bytecode_ra_read_raf_chunk_lens: vec![1], + instruction_ra_virtual: vec![vec![fr(50)]], + ram_ra_virtual: vec![vec![fr(60)]], + hamming_weight: vec![fr(70)], + ram_inc: vec![fr(80)], + rd_inc: vec![fr(90)], + }; + + let slices = witness.slices(); + assert_eq!( + slices.booleanity_chunks, + vec![ + witness.instruction_ra_booleanity[0].as_slice(), + witness.bytecode_ra_booleanity[0].as_slice(), + witness.ram_ra_booleanity[0].as_slice(), + ] + ); + assert_eq!( + slices.booleanity_index_chunks, + vec![ + witness.instruction_ra_indices[0].as_slice(), + witness.bytecode_ra_indices[0].as_slice(), + witness.ram_ra_indices[0].as_slice(), + ] + ); + assert_eq!( + slices.bytecode_ra_read_raf_chunks, + vec![witness.bytecode_ra_read_raf[0].as_slice()] + ); + assert_eq!(slices.bytecode_ra_read_raf_chunk_lens, vec![1]); + assert_eq!( + slices.instruction_ra_index_chunks, + vec![witness.instruction_ra_indices[0].as_slice()] + ); + assert_eq!( + slices.bytecode_ra_index_chunks, + vec![witness.bytecode_ra_indices[0].as_slice()] + ); + assert_eq!( + slices.ram_ra_index_chunks, + vec![witness.ram_ra_indices[0].as_slice()] + ); + } + + #[test] + fn msb_chunk_bit_widths_puts_partial_high_chunk_first() { + assert_eq!(msb_chunk_bit_widths(10, 4, 3), vec![2, 4, 4]); + assert_eq!(msb_chunk_bit_widths(12, 4, 3), vec![4, 4, 4]); + } + + #[test] + fn msb_point_chunks_left_pads_partial_high_chunk() { + let point = [Fr::from_u64(1), Fr::from_u64(2), Fr::from_u64(3)]; + let chunks = msb_point_chunks(&point, 2); + + assert_eq!( + chunks, + vec![ + vec![Fr::from_u64(0), Fr::from_u64(1)], + vec![Fr::from_u64(2), Fr::from_u64(3)] + ] + ); + } + + #[test] + fn one_hot_chunk_indices_are_msb_first_and_padded() { + let values = [Some(0xABu128), Some(0x12), None]; + let output = one_hot_chunk_indices(&values, 0, 2, 4, 4, Some(0)); + + assert_eq!(output, vec![Some(0xA), Some(0x1), Some(0), Some(0)]); + } + + #[test] + fn one_hot_chunk_indices_preserve_skipped_entries() { + let values = [Some(3u128), None]; + let output = one_hot_chunk_indices(&values, 0, 1, 2, 3, None); + + assert_eq!(output, vec![Some(3), None, None]); + } + + #[test] + fn one_hot_none_padding_skips_entries() { + let values = [Some(3u128), None]; + let output = one_hot_chunk_address_major::(&values, 0, 1, 2, 3, None); + + assert_eq!(output[3 * 3], Fr::from_u64(1)); + assert!(output + .iter() + .enumerate() + .all(|(index, value)| index == 9 || *value == Fr::from_u64(0))); + } +}