diff --git a/.cargo/katex-header.html b/.cargo/katex-header.html index 5db5bc0b1..ca338654e 100644 --- a/.cargo/katex-header.html +++ b/.cargo/katex-header.html @@ -11,6 +11,7 @@ renderMathInElement(document.body, { fleqn: false, macros: { + "\\B": "\\mathbb{B}", "\\F": "\\mathbb{F}", "\\G": "\\mathbb{G}", "\\O": "\\mathcal{O}", diff --git a/Cargo.toml b/Cargo.toml index b0ed3f07c..7a5f28ff7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,8 @@ members = [ "prover", "verifier", "winterfell", - "examples" -] + "examples", + "sumcheck",] resolver = "2" [profile.release] diff --git a/air/src/air/aux.rs b/air/src/air/aux.rs index 01f59035a..d9fa3c2d5 100644 --- a/air/src/air/aux.rs +++ b/air/src/air/aux.rs @@ -3,39 +3,32 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use alloc::{string::ToString, vec::Vec}; +use alloc::vec::Vec; -use crypto::{ElementHasher, RandomCoin, RandomCoinError}; -use math::FieldElement; -use utils::Deserializable; +use math::{ExtensionOf, FieldElement}; -use super::lagrange::LagrangeKernelRandElements; +use super::{LagrangeKernelRandElements, LogUpGkrOracle}; -/// Holds the randomly generated elements necessary to build the auxiliary trace. +/// Holds the randomly generated elements used in defining the auxiliary segment of the trace. /// -/// Specifically, [`AuxRandElements`] currently supports 3 types of random elements: -/// - the ones needed to build the Lagrange kernel column (when using GKR to accelerate LogUp), -/// - the ones needed to build the "s" auxiliary column (when using GKR to accelerate LogUp), -/// - the ones needed to build all the other auxiliary columns +/// Specifically, [`AuxRandElements`] currently supports 2 types of random elements: +/// - the ones needed to build all the auxiliary columns except for the ones associated +/// to LogUp-GKR. +/// - the ones needed to build the "s" and Lagrange kernel auxiliary columns (when using GKR to +/// accelerate LogUp). These also include additional information needed to evaluate constraints +/// one these two columns. #[derive(Debug, Clone)] -pub struct AuxRandElements { +pub struct AuxRandElements { rand_elements: Vec, - gkr: Option>, + gkr: Option>, } -impl AuxRandElements { - /// Creates a new [`AuxRandElements`], where the auxiliary trace doesn't contain a Lagrange - /// kernel column. - pub fn new(rand_elements: Vec) -> Self { - Self { rand_elements, gkr: None } - } - - /// Creates a new [`AuxRandElements`], where the auxiliary trace contains columns needed when +impl AuxRandElements { + /// Creates a new [`AuxRandElements`], where the auxiliary segment may contain columns needed when /// using GKR to accelerate LogUp (i.e. a Lagrange kernel column and the "s" column). - pub fn new_with_gkr(rand_elements: Vec, gkr: GkrRandElements) -> Self { - Self { rand_elements, gkr: Some(gkr) } + pub fn new(rand_elements: Vec, gkr: Option>) -> Self { + Self { rand_elements, gkr } } - /// Returns the random elements needed to build all columns other than the two GKR-related ones. pub fn rand_elements(&self) -> &[E] { &self.rand_elements @@ -43,7 +36,7 @@ impl AuxRandElements { /// Returns the random elements needed to build the Lagrange kernel column. pub fn lagrange(&self) -> Option<&LagrangeKernelRandElements> { - self.gkr.as_ref().map(|gkr| &gkr.lagrange) + self.gkr.as_ref().map(|gkr| &gkr.lagrange_kernel_eval_point) } /// Returns the random values used to linearly combine the openings returned from the GKR proof. @@ -52,83 +45,101 @@ impl AuxRandElements { pub fn gkr_openings_combining_randomness(&self) -> Option<&[E]> { self.gkr.as_ref().map(|gkr| gkr.openings_combining_randomness.as_ref()) } + + /// Returns a collection of data necessary for implementing the univariate IOP for multi-linear + /// evaluations of [1] when LogUp-GKR is enabled, else returns a `None`. + /// + /// [1]: https://eprint.iacr.org/2023/1284 + pub fn gkr_data(&self) -> Option> { + self.gkr.clone() + } } -/// Holds all the random elements needed when using GKR to accelerate LogUp. +/// Holds all the data needed when using LogUp-GKR in order to build and verify the correctness of +/// two extra auxiliary columns required for running the univariate IOP for multi-linear +/// evaluations of [1]. /// -/// This consists of two sets of random values: -/// 1. The Lagrange kernel random elements (expanded on in [`LagrangeKernelRandElements`]), and +/// This consists of: +/// 1. The Lagrange kernel random elements (expanded on in [`LagrangeKernelRandElements`]). These +/// make up the evaluation point of the multi-linear extension polynomials underlying the oracles +/// in point 4 below. /// 2. The "openings combining randomness". +/// 3. The openings of the multi-linear extension polynomials of the main trace columns involved +/// in LogUp. +/// 4. A description of the each of the oracles involved in LogUp. /// -/// After the verifying the LogUp-GKR circuit, the verifier is left with unproven claims provided -/// nondeterministically by the prover about the evaluations of the MLE of the main trace columns at -/// the Lagrange kernel random elements. Those claims are (linearly) combined into one using the -/// openings combining randomness. +/// After verifying the LogUp-GKR circuit, the verifier is left with unproven claims provided +/// by the prover about the evaluations of the MLEs of the main trace columns at the evaluation +/// point defining the Lagrange kernel. Those claims are (linearly) batched into one using the +/// openings combining randomness and checked against the batched oracles using univariate IOP +/// for multi-linear evaluations of [1]. +/// +/// [1]: https://eprint.iacr.org/2023/1284 #[derive(Clone, Debug)] -pub struct GkrRandElements { - lagrange: LagrangeKernelRandElements, - openings_combining_randomness: Vec, +pub struct GkrData { + pub lagrange_kernel_eval_point: LagrangeKernelRandElements, + pub openings_combining_randomness: Vec, + pub openings: Vec, + pub oracles: Vec, } -impl GkrRandElements { - /// Constructs a new [`GkrRandElements`] from [`LagrangeKernelRandElements`], and the openings - /// combining randomness. +impl GkrData { + /// Constructs a new [`GkrData`] from [`LagrangeKernelRandElements`], the openings combining + /// randomness and the LogUp-GKR oracles. /// - /// See [`GkrRandElements`] for a more detailed description. + /// See [`GkrData`] for a more detailed description. pub fn new( - lagrange: LagrangeKernelRandElements, + lagrange_kernel_eval_point: LagrangeKernelRandElements, openings_combining_randomness: Vec, + openings: Vec, + oracles: Vec, ) -> Self { - Self { lagrange, openings_combining_randomness } + Self { + lagrange_kernel_eval_point, + openings_combining_randomness, + openings, + oracles, + } } /// Returns the random elements needed to build the Lagrange kernel column. pub fn lagrange_kernel_rand_elements(&self) -> &LagrangeKernelRandElements { - &self.lagrange + &self.lagrange_kernel_eval_point } /// Returns the random values used to linearly combine the openings returned from the GKR proof. pub fn openings_combining_randomness(&self) -> &[E] { &self.openings_combining_randomness } -} -/// A trait for verifying a GKR proof. -/// -/// Specifically, the use case in mind is proving the constraints of a LogUp bus using GKR, as -/// described in [Improving logarithmic derivative lookups using -/// GKR](https://eprint.iacr.org/2023/1284.pdf). -pub trait GkrVerifier { - /// The GKR proof. - type GkrProof: Deserializable; - /// The error that can occur during GKR proof verification. - type Error: ToString; - - /// Verifies the GKR proof, and returns the random elements that were used in building - /// the Lagrange kernel auxiliary column. - fn verify( - &self, - gkr_proof: Self::GkrProof, - public_coin: &mut impl RandomCoin, - ) -> Result, Self::Error> - where - E: FieldElement, - Hasher: ElementHasher; -} + pub fn openings(&self) -> &[E] { + &self.openings + } -impl GkrVerifier for () { - type GkrProof = (); - type Error = RandomCoinError; + pub fn oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + pub fn compute_batched_claim(&self) -> E { + self.openings[0] + + self + .openings + .iter() + .skip(1) + .zip(self.openings_combining_randomness.iter()) + .fold(E::ZERO, |acc, (a, b)| acc + *a * *b) + } - fn verify( - &self, - _gkr_proof: Self::GkrProof, - _public_coin: &mut impl RandomCoin, - ) -> Result, Self::Error> + pub fn compute_batched_query(&self, query: &[F]) -> E where - E: FieldElement, - Hasher: ElementHasher, + F: FieldElement, + E: ExtensionOf, { - Ok(GkrRandElements::new(LagrangeKernelRandElements::default(), Vec::new())) + E::from(query[0]) + + query + .iter() + .skip(1) + .zip(self.openings_combining_randomness.iter()) + .fold(E::ZERO, |acc, (a, b)| acc + b.mul_base(*a)) } } diff --git a/air/src/air/boundary/mod.rs b/air/src/air/boundary/mod.rs index 7f92c80ab..2c15ac5a3 100644 --- a/air/src/air/boundary/mod.rs +++ b/air/src/air/boundary/mod.rs @@ -58,8 +58,8 @@ impl BoundaryConstraints { /// coefficients. /// * The specified assertions are not valid in the context of the computation (e.g., assertion /// column index is out of bounds). - pub fn new( - context: &AirContext, + pub fn new

( + context: &AirContext, main_assertions: Vec>, aux_assertions: Vec>, composition_coefficients: &[E], @@ -88,7 +88,7 @@ impl BoundaryConstraints { ); let trace_length = context.trace_info.length(); - let main_trace_width = context.trace_info.main_trace_width(); + let main_trace_width = context.trace_info.main_segment_width(); let aux_trace_width = context.trace_info.aux_segment_width(); // make sure the assertions are valid in the context of their respective trace segments; @@ -151,9 +151,9 @@ impl BoundaryConstraints { /// Translates the provided assertions into boundary constraints, groups the constraints by their /// divisor, and sorts the resulting groups by the degree adjustment factor. -fn group_constraints( +fn group_constraints( assertions: Vec>, - context: &AirContext, + context: &AirContext, composition_coefficients: &[E], inv_g: F::BaseField, twiddle_map: &mut BTreeMap>, diff --git a/air/src/air/coefficients.rs b/air/src/air/coefficients.rs index b82b2ac6b..ed6c3fa99 100644 --- a/air/src/air/coefficients.rs +++ b/air/src/air/coefficients.rs @@ -27,11 +27,19 @@ use math::FieldElement; /// /// The coefficients are separated into two lists: one for transition constraints and another one /// for boundary constraints. This separation is done for convenience only. +/// +/// In addition to the above, and when LogUp-GKR is enabled, there are two extra sets of +/// constraint composition coefficients that are used, namely for: +/// +/// 1. Lagrange kernel constraints, which include both transition and boundary constraints. +/// 2. S-column constraint, which is used in implementing the cohomological sum-check argument +/// of https://eprint.iacr.org/2021/930 #[derive(Debug, Clone)] pub struct ConstraintCompositionCoefficients { pub transition: Vec, pub boundary: Vec, pub lagrange: Option>, + pub s_col: Option, } /// Stores the constraint composition coefficients for the Lagrange kernel transition and boundary @@ -83,8 +91,9 @@ pub struct LagrangeConstraintsCompositionCoefficients { /// negligible increase in soundness error. The formula for the updated error can be found in /// Theorem 8 of https://eprint.iacr.org/2022/1216. /// -/// In the case when the trace polynomials contain a trace polynomial corresponding to a Lagrange -/// kernel column, the above expression of $Y(x)$ includes the additional term given by +/// In the case when LogUp-GKR is enabled, the trace polynomials contain an additional trace +/// polynomial corresponding to a Lagrange kernel column and the above expression of $Y(x)$ +/// includes the additional term given by /// /// $$ /// \gamma \cdot \frac{T_l(x) - p_S(x)}{Z_S(x)} @@ -99,8 +108,13 @@ pub struct LagrangeConstraintsCompositionCoefficients { /// 4. $p_S(X)$ is the polynomial of minimal degree interpolating the set ${(a, T_l(a)): a \in S}$. /// 5. $Z_S(X)$ is the polynomial of minimal degree vanishing over the set $S$. /// -/// Note that, if a Lagrange kernel trace polynomial is present, then $\rho^{+}$ from above should -/// be updated to be $\rho^{+} := \frac{\kappa + log_2(\nu) + 1}{\nu}$. +/// Note that when LogUp-GKR is enabled, we also have to take into account an additional column, +/// called s-column throughout, used in implementing the univariate IOP for multi-linear evaluation. +/// This means that we need and additional random value, in addition to $\gamma$ above, when +/// LogUp-GKR is enabled. +/// +/// Note that, when LogUp-GKR is enabled, $\rho^{+}$ from above should be updated to be +/// $\rho^{+} := \frac{\kappa + log_2(\nu) + 1}{\nu}$. #[derive(Debug, Clone)] pub struct DeepCompositionCoefficients { /// Trace polynomial composition coefficients $\alpha_i$. @@ -109,4 +123,6 @@ pub struct DeepCompositionCoefficients { pub constraints: Vec, /// Lagrange kernel trace polynomial composition coefficient $\gamma$. pub lagrange: Option, + /// S-column trace polynomial composition coefficient. + pub s_col: Option, } diff --git a/air/src/air/context.rs b/air/src/air/context.rs index 183f575fc..a4074036a 100644 --- a/air/src/air/context.rs +++ b/air/src/air/context.rs @@ -14,21 +14,21 @@ use crate::{air::TransitionConstraintDegree, ProofOptions, TraceInfo}; // ================================================================================================ /// STARK parameters and trace properties for a specific execution of a computation. #[derive(Clone, PartialEq, Eq)] -pub struct AirContext { +pub struct AirContext { pub(super) options: ProofOptions, pub(super) trace_info: TraceInfo, + pub(super) pub_inputs: P, pub(super) main_transition_constraint_degrees: Vec, pub(super) aux_transition_constraint_degrees: Vec, pub(super) num_main_assertions: usize, pub(super) num_aux_assertions: usize, - pub(super) lagrange_kernel_aux_column_idx: Option, pub(super) ce_blowup_factor: usize, pub(super) trace_domain_generator: B, pub(super) lde_domain_generator: B, pub(super) num_transition_exemptions: usize, } -impl AirContext { +impl AirContext { // CONSTRUCTORS // -------------------------------------------------------------------------------------------- /// Returns a new instance of [AirContext] instantiated for computations which require a single @@ -48,6 +48,7 @@ impl AirContext { /// * `trace_info` describes a multi-segment execution trace. pub fn new( trace_info: TraceInfo, + pub_inputs: P, transition_constraint_degrees: Vec, num_assertions: usize, options: ProofOptions, @@ -58,11 +59,11 @@ impl AirContext { ); Self::new_multi_segment( trace_info, + pub_inputs, transition_constraint_degrees, Vec::new(), num_assertions, 0, - None, options, ) } @@ -91,11 +92,11 @@ impl AirContext { /// of the specified transition constraints. pub fn new_multi_segment( trace_info: TraceInfo, + pub_inputs: P, main_transition_constraint_degrees: Vec, aux_transition_constraint_degrees: Vec, num_main_assertions: usize, num_aux_assertions: usize, - lagrange_kernel_aux_column_idx: Option, options: ProofOptions, ) -> Self { assert!( @@ -105,14 +106,16 @@ impl AirContext { assert!(num_main_assertions > 0, "at least one assertion must be specified"); if trace_info.is_multi_segment() { - assert!( - !aux_transition_constraint_degrees.is_empty(), - "at least one transition constraint degree must be specified for the auxiliary trace segment" + if !trace_info.logup_gkr_enabled() { + assert!( + !aux_transition_constraint_degrees.is_empty(), + "at least one transition constraint degree must be specified for the auxiliary trace segment" ); - assert!( - num_aux_assertions > 0, - "at least one assertion must be specified against the auxiliary trace segment" - ); + assert!( + num_aux_assertions > 0, + "at least one assertion must be specified against the auxiliary trace segment" + ); + } } else { assert!( aux_transition_constraint_degrees.is_empty(), @@ -124,15 +127,6 @@ impl AirContext { ); } - // validate Lagrange kernel aux column, if any - if let Some(lagrange_kernel_aux_column_idx) = lagrange_kernel_aux_column_idx { - assert!( - lagrange_kernel_aux_column_idx == trace_info.get_aux_segment_width() - 1, - "Lagrange kernel column should be the last column of the auxiliary trace: index={}, but aux trace width is {}", - lagrange_kernel_aux_column_idx, trace_info.get_aux_segment_width() - ); - } - // determine minimum blowup factor needed to evaluate transition constraints by taking // the blowup factor of the highest degree constraint let mut ce_blowup_factor = 0; @@ -161,11 +155,11 @@ impl AirContext { AirContext { options, trace_info, + pub_inputs, main_transition_constraint_degrees, aux_transition_constraint_degrees, num_main_assertions, num_aux_assertions, - lagrange_kernel_aux_column_idx, ce_blowup_factor, trace_domain_generator: B::get_root_of_unity(trace_length.ilog2()), lde_domain_generator: B::get_root_of_unity(lde_domain_size.ilog2()), @@ -209,6 +203,10 @@ impl AirContext { self.trace_info.length() * self.options.blowup_factor() } + pub fn public_inputs(&self) -> &P { + &self.pub_inputs + } + /// Returns the number of transition constraints for a computation, excluding the Lagrange /// kernel transition constraints, which are managed separately. /// @@ -230,14 +228,14 @@ impl AirContext { self.aux_transition_constraint_degrees.len() } - /// Returns the index of the auxiliary column which implements the Lagrange kernel, if any - pub fn lagrange_kernel_aux_column_idx(&self) -> Option { - self.lagrange_kernel_aux_column_idx + /// Returns the index of the auxiliary column which implements the Lagrange kernel, if any. + pub fn lagrange_kernel_column_idx(&self) -> Option { + self.trace_info.lagrange_kernel_column_idx() } - /// Returns true if the auxiliary trace segment contains a Lagrange kernel column - pub fn has_lagrange_kernel_aux_column(&self) -> bool { - self.lagrange_kernel_aux_column_idx().is_some() + /// Returns true if LogUp-GKR is enabled. + pub fn logup_gkr_enabled(&self) -> bool { + self.trace_info.logup_gkr_enabled() } /// Returns the total number of assertions defined for a computation, excluding the Lagrange @@ -307,10 +305,8 @@ impl AirContext { let trace_length = self.trace_len(); let transition_divisior_degree = trace_length - self.num_transition_exemptions(); - // we use the identity: ceil(a/b) = (a + b - 1)/b let num_constraint_col = - (highest_constraint_degree - transition_divisior_degree + trace_length - 1) - / trace_length; + (highest_constraint_degree - transition_divisior_degree).div_ceil(trace_length); cmp::max(num_constraint_col, 1) } diff --git a/air/src/air/lagrange/boundary.rs b/air/src/air/logup_gkr/lagrange/boundary.rs similarity index 84% rename from air/src/air/lagrange/boundary.rs rename to air/src/air/logup_gkr/lagrange/boundary.rs index 5d1954615..3eaad9f5d 100644 --- a/air/src/air/lagrange/boundary.rs +++ b/air/src/air/logup_gkr/lagrange/boundary.rs @@ -5,7 +5,7 @@ use math::FieldElement; -use crate::{LagrangeKernelEvaluationFrame, LagrangeKernelRandElements}; +use super::{LagrangeKernelEvaluationFrame, LagrangeKernelRandElements}; #[derive(Debug, Clone, Eq, PartialEq)] pub struct LagrangeKernelBoundaryConstraint @@ -31,27 +31,28 @@ where } } + /// Returns the constraint composition coefficient for this boundary constraint. + pub fn constraint_composition_coefficient(&self) -> E { + self.composition_coefficient + } + /// Returns the evaluation of the boundary constraint at `x`, multiplied by the composition /// coefficient. /// /// `frame` is the evaluation frame of the Lagrange kernel column `c`, starting at `c(x)` pub fn evaluate_at(&self, x: E, frame: &LagrangeKernelEvaluationFrame) -> E { - let numerator = self.evaluate_numerator_at(frame); + let numerator = self.evaluate_numerator_at(frame) * self.composition_coefficient; let denominator = self.evaluate_denominator_at(x); numerator / denominator } - /// Returns the evaluation of the boundary constraint numerator, multiplied by the composition - /// coefficient. + /// Returns the evaluation of the boundary constraint numerator. /// /// `frame` is the evaluation frame of the Lagrange kernel column `c`, starting at `c(x)` for /// some `x` pub fn evaluate_numerator_at(&self, frame: &LagrangeKernelEvaluationFrame) -> E { - let trace_value = frame.inner()[0]; - let constraint_evaluation = trace_value - self.assertion_value; - - constraint_evaluation * self.composition_coefficient + frame[0] - self.assertion_value } /// Returns the evaluation of the boundary constraint denominator at point `x`. diff --git a/air/src/air/lagrange/frame.rs b/air/src/air/logup_gkr/lagrange/frame.rs similarity index 81% rename from air/src/air/lagrange/frame.rs rename to air/src/air/logup_gkr/lagrange/frame.rs index d0ffc4fa4..6dc0a64cc 100644 --- a/air/src/air/lagrange/frame.rs +++ b/air/src/air/logup_gkr/lagrange/frame.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; +use core::ops::{Index, IndexMut}; use math::{polynom, FieldElement, StarkField}; @@ -25,14 +26,15 @@ impl LagrangeKernelEvaluationFrame { // -------------------------------------------------------------------------------------------- /// Constructs a Lagrange kernel evaluation frame from the raw column polynomial evaluations. - pub fn new(frame: Vec) -> Self { + pub fn with_values(frame: Vec) -> Self { Self { frame } } /// Constructs an empty Lagrange kernel evaluation frame from the raw column polynomial /// evaluations. The frame can subsequently be filled using [`Self::frame_mut`]. - pub fn new_empty() -> Self { - Self { frame: Vec::new() } + pub fn new(trace_len: usize) -> Self { + let frame_length = trace_len.ilog2() as usize + 1; + Self { frame: vec![E::ZERO; frame_length] } } /// Constructs the frame from the Lagrange kernel column trace polynomial coefficients for an @@ -61,14 +63,6 @@ impl LagrangeKernelEvaluationFrame { Self { frame } } - // MUTATORS - // -------------------------------------------------------------------------------------------- - - /// Returns a mutable reference to the inner frame. - pub fn frame_mut(&mut self) -> &mut Vec { - &mut self.frame - } - // ACCESSORS // -------------------------------------------------------------------------------------------- @@ -84,3 +78,17 @@ impl LagrangeKernelEvaluationFrame { self.frame.len() } } + +impl Index for LagrangeKernelEvaluationFrame { + type Output = E; + + fn index(&self, index: usize) -> &Self::Output { + &self.frame[index] + } +} + +impl IndexMut for LagrangeKernelEvaluationFrame { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.frame[index] + } +} diff --git a/air/src/air/lagrange/mod.rs b/air/src/air/logup_gkr/lagrange/mod.rs similarity index 95% rename from air/src/air/lagrange/mod.rs rename to air/src/air/logup_gkr/lagrange/mod.rs index fed5897f3..9d80b4437 100644 --- a/air/src/air/lagrange/mod.rs +++ b/air/src/air/logup_gkr/lagrange/mod.rs @@ -3,17 +3,18 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -mod boundary; use alloc::vec::Vec; use core::ops::Deref; +use math::FieldElement; + +mod boundary; pub use boundary::LagrangeKernelBoundaryConstraint; mod frame; pub use frame::LagrangeKernelEvaluationFrame; mod transition; -use math::FieldElement; pub use transition::LagrangeKernelTransitionConstraints; use crate::LagrangeConstraintsCompositionCoefficients; @@ -22,7 +23,6 @@ use crate::LagrangeConstraintsCompositionCoefficients; pub struct LagrangeKernelConstraints { pub transition: LagrangeKernelTransitionConstraints, pub boundary: LagrangeKernelBoundaryConstraint, - pub lagrange_kernel_col_idx: usize, } impl LagrangeKernelConstraints { @@ -30,7 +30,6 @@ impl LagrangeKernelConstraints { pub fn new( lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, lagrange_kernel_rand_elements: &LagrangeKernelRandElements, - lagrange_kernel_col_idx: usize, ) -> Self { Self { transition: LagrangeKernelTransitionConstraints::new( @@ -40,7 +39,6 @@ impl LagrangeKernelConstraints { lagrange_composition_coefficients.boundary, lagrange_kernel_rand_elements, ), - lagrange_kernel_col_idx, } } } diff --git a/air/src/air/lagrange/transition.rs b/air/src/air/logup_gkr/lagrange/transition.rs similarity index 91% rename from air/src/air/lagrange/transition.rs rename to air/src/air/logup_gkr/lagrange/transition.rs index 18bdfa9be..5f5b110e6 100644 --- a/air/src/air/lagrange/transition.rs +++ b/air/src/air/logup_gkr/lagrange/transition.rs @@ -43,6 +43,11 @@ impl LagrangeKernelTransitionConstraints { } } + /// Returns the constraint composition coefficients for the Lagrange kernel transition constraints. + pub fn lagrange_constraint_coefficients(&self) -> &[E] { + &self.lagrange_constraint_coefficients + } + /// Evaluates the numerator of the `constraint_idx`th transition constraint. pub fn evaluate_ith_numerator( &self, @@ -54,14 +59,12 @@ impl LagrangeKernelTransitionConstraints { F: FieldElement, E: ExtensionOf, { - let c = lagrange_kernel_column_frame.inner(); - let v = c.len() - 1; + let c = lagrange_kernel_column_frame; + let v = c.num_rows() - 1; let r = lagrange_kernel_rand_elements; let k = constraint_idx + 1; - let eval = (r[v - k] * c[0]) - ((E::ONE - r[v - k]) * c[v - k + 1]); - - self.lagrange_constraint_coefficients[constraint_idx].mul_base(eval) + (r[v - k] * c[0]) - ((E::ONE - r[v - k]) * c[v - k + 1]) } /// Evaluates the divisor of the `constraint_idx`th transition constraint. @@ -124,8 +127,8 @@ impl LagrangeKernelTransitionConstraints { let log2_trace_len = lagrange_kernel_column_frame.num_rows() - 1; let mut transition_evals = vec![E::ZERO; log2_trace_len]; - let c = lagrange_kernel_column_frame.inner(); - let v = c.len() - 1; + let c = lagrange_kernel_column_frame; + let v = c.num_rows() - 1; let r = lagrange_kernel_rand_elements; for k in 1..v + 1 { diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs new file mode 100644 index 000000000..d3e198912 --- /dev/null +++ b/air/src/air/logup_gkr/mod.rs @@ -0,0 +1,305 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; +use core::marker::PhantomData; + +use crypto::{ElementHasher, RandomCoin}; +use math::{ExtensionOf, FieldElement, StarkField, ToElements}; + +use super::{EvaluationFrame, GkrData, LagrangeConstraintsCompositionCoefficients}; +mod s_column; +use s_column::SColumnConstraint; + +mod lagrange; +pub use lagrange::{ + LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, + LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, +}; + +/// A trait containing the necessary information in order to run the LogUp-GKR protocol of [1]. +/// +/// The trait contains useful information for running the GKR protocol as well as for implementing +/// the univariate IOP for multi-linear evaluation of Section 5 in [1] for the final evaluation +/// check resulting from GKR. +/// +/// [1]: https://eprint.iacr.org/2023/1284 +pub trait LogUpGkrEvaluator: Clone + Sync { + /// Defines the base field of the evaluator. + type BaseField: StarkField; + + /// Public inputs need to compute the final claim. + type PublicInputs: ToElements + Send; + + /// Gets a list of all oracles involved in LogUp-GKR; this is intended to be used in construction of + /// MLEs. + fn get_oracles(&self) -> &[LogUpGkrOracle]; + + /// A vector of virtual periodic columns defined by their values in some given cycle. + /// Note that the cycle lengths must be powers of 2. + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + /// Returns the number of random values needed to evaluate a query. + fn get_num_rand_values(&self) -> usize; + + /// Returns the number of fractions in the LogUp-GKR statement. + fn get_num_fractions(&self) -> usize; + + /// Returns the maximal degree of the multi-variate associated to the input layer. + /// + /// This is equal to the max of $1 + deg_k(\text{numerator}_i) * deg_k(\text{denominator}_j)$ where + /// $i$ and $j$ range over the number of numerators and denominators, respectively, and $deg_k$ + /// is the degree of a multi-variate polynomial in its $k$-th variable. + fn max_degree(&self) -> usize; + + /// Builds a query from the provided main trace frame and periodic values. + /// + /// Note: it should be possible to provide an implementation of this method based on the + /// information returned from `get_oracles()`. However, this implementation is likely to be + /// expensive compared to the hand-written implementation. However, we could provide a test + /// which verifies that `get_oracles()` and `build_query()` methods are consistent. + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement; + + /// Evaluates the provided query and writes the results into the numerators and denominators. + /// + /// Note: it is also possible to combine `build_query()` and `evaluate_query()` into a single + /// method to avoid the need to first build the query struct and then evaluate it. However: + /// - We assume that the compiler will be able to optimize this away. + /// - Merging the methods will make it more difficult avoid inconsistencies between + /// `evaluate_query()` and `get_oracles()` methods. + fn evaluate_query( + &self, + query: &[F], + periodic_values: &[F], + logup_randomness: &[E], + numerators: &mut [E], + denominators: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf; + + /// Computes the final claim for the LogUp-GKR circuit. + /// + /// The default implementation of this method returns E::ZERO as it is expected that the + /// fractional sums will cancel out. However, in cases when some boundary conditions need to + /// be imposed on the LogUp-GKR relations, this method can be overridden to compute the final + /// expected claim. + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } + + /// Generates the data needed for running the univariate IOP for multi-linear evaluation of [1]. + /// + /// This mainly generates the batching randomness used to batch a number of multi-linear + /// evaluation claims and includes some additional data that is needed for building/verifying + /// the univariate IOP for multi-linear evaluation of [1]. + /// + /// This is the $\lambda$ randomness in section 5.2 in [1] but using different random values for + /// each term instead of powers of a single random element. + /// + /// [1]: https://eprint.iacr.org/2023/1284 + fn generate_univariate_iop_for_multi_linear_opening_data( + &self, + openings: Vec, + eval_point: Vec, + public_coin: &mut impl RandomCoin, + ) -> GkrData + where + E: FieldElement, + H: ElementHasher, + { + public_coin.reseed(H::hash_elements(&openings)); + + let mut batching_randomness = Vec::with_capacity(openings.len() - 1); + for _ in 0..openings.len() - 1 { + batching_randomness.push(public_coin.draw().expect("failed to generate randomness")) + } + + GkrData::new( + LagrangeKernelRandElements::new(eval_point), + batching_randomness, + openings, + self.get_oracles().to_vec(), + ) + } + + /// Returns a new [`LagrangeKernelConstraints`]. + fn get_lagrange_kernel_constraints>( + &self, + lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, + lagrange_kernel_rand_elements: &LagrangeKernelRandElements, + ) -> LagrangeKernelConstraints { + LagrangeKernelConstraints::new( + lagrange_composition_coefficients, + lagrange_kernel_rand_elements, + ) + } + + /// Returns a new [`SColumnConstraints`]. + fn get_s_column_constraints>( + &self, + gkr_data: GkrData, + composition_coefficient: E, + ) -> SColumnConstraint { + SColumnConstraint::new(gkr_data, composition_coefficient) + } + + /// Returns the periodic values used in the LogUp-GKR statement, either as base field element + /// during circuit evaluation or as extension field element during the run of sum-check for + /// the input layer. + fn build_periodic_values(&self) -> PeriodicTable + where + E: FieldElement, + { + let table = self + .get_periodic_column_values() + .iter() + .map(|values| values.iter().map(|x| E::from(*x)).collect()) + .collect(); + + PeriodicTable { table } + } +} + +#[derive(Clone, Default)] +pub(crate) struct PhantomLogUpGkrEval> { + _field: PhantomData, + _public_inputs: PhantomData

, +} + +impl PhantomLogUpGkrEval +where + B: StarkField, + P: Clone + Send + Sync + ToElements, +{ + pub fn new() -> Self { + Self { + _field: PhantomData, + _public_inputs: PhantomData, + } + } +} + +impl LogUpGkrEvaluator for PhantomLogUpGkrEval +where + B: StarkField, + P: Clone + Send + Sync + ToElements, +{ + type BaseField = B; + + type PublicInputs = P; + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn get_num_rand_values(&self) -> usize { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn get_num_fractions(&self) -> usize { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn max_degree(&self) -> usize { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn build_query(&self, _frame: &EvaluationFrame, _query: &mut [E]) + where + E: FieldElement, + { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn evaluate_query( + &self, + _query: &[F], + _periodic_values: &[F], + _rand_values: &[E], + _numerator: &mut [E], + _denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } +} + +#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] +pub enum LogUpGkrOracle { + /// A column with a given index in the main trace segment. + CurrentRow(usize), + /// A column with a given index in the main trace segment but shifted upwards. + NextRow(usize), +} + +// PERIODIC COLUMNS FOR LOGUP +// ================================================================================================= + +/// Stores the periodic columns used in a LogUp-GKR statement. +/// +/// Each stored periodic column is interpreted as a multi-linear extension polynomial of the column +/// with the given periodic values. Due to the periodic nature of the values, storing, binding of +/// an argument and evaluating the said multi-linear extension can be all done linearly in the size +/// of the smallest cycle defining the periodic values. Hence we only store the values of this +/// smallest cycle. The cycle is assumed throughout to be a power of 2. +#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)] +pub struct PeriodicTable { + pub table: Vec>, +} + +impl PeriodicTable +where + E: FieldElement, +{ + pub fn new(table: Vec>) -> Self { + let table = table.iter().map(|col| col.iter().map(|x| E::from(*x)).collect()).collect(); + + Self { table } + } + + pub fn num_columns(&self) -> usize { + self.table.len() + } + + pub fn table(&self) -> &[Vec] { + &self.table + } + + pub fn fill_periodic_values_at(&self, row: usize, values: &mut [E]) { + self.table + .iter() + .zip(values.iter_mut()) + .for_each(|(col, value)| *value = col[row % col.len()]) + } + + pub fn bind_least_significant_variable(&mut self, round_challenge: E) { + for col in self.table.iter_mut() { + if col.len() > 1 { + let num_evals = col.len() >> 1; + for i in 0..num_evals { + col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]); + } + col.truncate(num_evals) + } + } + } +} diff --git a/air/src/air/logup_gkr/s_column.rs b/air/src/air/logup_gkr/s_column.rs new file mode 100644 index 000000000..685c6e026 --- /dev/null +++ b/air/src/air/logup_gkr/s_column.rs @@ -0,0 +1,56 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use math::FieldElement; + +use super::{super::Air, EvaluationFrame, GkrData}; +use crate::LogUpGkrEvaluator; + +/// Represents the transition constraint for the s-column, as well as the random coefficient used +/// to linearly combine the constraint into the constraint composition polynomial. +/// +/// The s-column implements the cohomological sum-check argument of [1] and the constraint in +/// [`SColumnConstraint`] is exactly Eq (4) in Lemma 1 in [1]. +/// +/// +/// [1]: https://eprint.iacr.org/2021/930 +pub struct SColumnConstraint { + gkr_data: GkrData, + composition_coefficient: E, +} + +impl SColumnConstraint { + pub fn new(gkr_data: GkrData, composition_coefficient: E) -> Self { + Self { gkr_data, composition_coefficient } + } + + /// Evaluates the transition constraint over the specificed main trace segment, s-column, + /// and Lagrange kernel evaluation frames. + pub fn evaluate( + &self, + air: &A, + main_trace_frame: &EvaluationFrame, + s_cur: E, + s_nxt: E, + l_cur: E, + x: E, + ) -> E + where + A: Air, + { + let batched_claim = self.gkr_data.compute_batched_claim(); + let mean = batched_claim + .mul_base(E::BaseField::ONE / E::BaseField::from(air.trace_length() as u32)); + + let mut query = vec![E::ZERO; air.get_logup_gkr_evaluator().get_oracles().len()]; + air.get_logup_gkr_evaluator().build_query(main_trace_frame, &mut query); + let batched_claim_at_query = self.gkr_data.compute_batched_query::(&query); + let rhs = s_cur - mean + batched_claim_at_query * l_cur; + let lhs = s_nxt; + + let divisor = x.exp((air.trace_length() as u32).into()) - E::ONE; + self.composition_coefficient * (rhs - lhs) / divisor + } +} diff --git a/air/src/air/mod.rs b/air/src/air/mod.rs index 53a59fa5a..cc2e82d2b 100644 --- a/air/src/air/mod.rs +++ b/air/src/air/mod.rs @@ -11,7 +11,7 @@ use math::{fft, ExtensibleField, ExtensionOf, FieldElement, StarkField, ToElemen use crate::ProofOptions; mod aux; -pub use aux::{AuxRandElements, GkrRandElements, GkrVerifier}; +pub use aux::{AuxRandElements, GkrData}; mod trace_info; pub use trace_info::TraceInfo; @@ -28,10 +28,12 @@ pub use boundary::{BoundaryConstraint, BoundaryConstraintGroup, BoundaryConstrai mod transition; pub use transition::{EvaluationFrame, TransitionConstraintDegree, TransitionConstraints}; -mod lagrange; -pub use lagrange::{ +mod logup_gkr; +use logup_gkr::PhantomLogUpGkrEval; +pub use logup_gkr::{ LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, - LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, + LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, + LogUpGkrOracle, PeriodicTable, }; mod coefficients; @@ -42,7 +44,6 @@ pub use coefficients::{ mod divisor; pub use divisor::ConstraintDivisor; -use utils::{Deserializable, Serializable}; #[cfg(test)] mod tests; @@ -192,13 +193,7 @@ pub trait Air: Send + Sync { /// A type defining shape of public inputs for the computation described by this protocol. /// This could be any type as long as it can be serialized into a sequence of field elements. - type PublicInputs: ToElements + Send; - - /// An GKR proof object. If not needed, set to `()`. - type GkrProof: Serializable + Deserializable + Send; - - /// A verifier for verifying GKR proofs. If not needed, set to `()`. - type GkrVerifier: GkrVerifier; + type PublicInputs: ToElements + Clone + Send + Sync; // REQUIRED METHODS // -------------------------------------------------------------------------------------------- @@ -214,7 +209,7 @@ pub trait Air: Send + Sync { fn new(trace_info: TraceInfo, pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self; /// Returns context for this instance of the computation. - fn context(&self) -> &AirContext; + fn context(&self) -> &AirContext; /// Evaluates transition constraints over the specified evaluation frame. /// @@ -303,16 +298,15 @@ pub trait Air: Send + Sync { Vec::new() } - // AUXILIARY PROOF VERIFIER + // LOGUP-GKR EVALUATOR // -------------------------------------------------------------------------------------------- - /// Returns the [`GkrVerifier`] to be used to verify the GKR proof. - /// - /// Leave unimplemented if the `Air` doesn't use a GKR proof. - fn get_gkr_proof_verifier>( + /// Returns the object needed for the LogUp-GKR argument. + fn get_logup_gkr_evaluator( &self, - ) -> Self::GkrVerifier { - unimplemented!("`get_auxiliary_proof_verifier()` must be implemented when the proof contains a GKR proof"); + ) -> impl LogUpGkrEvaluator + { + PhantomLogUpGkrEval::new() } // PROVIDED METHODS @@ -335,22 +329,6 @@ pub trait Air: Send + Sync { Ok(rand_elements) } - /// Returns a new [`LagrangeKernelConstraints`] if a Lagrange kernel auxiliary column is present - /// in the trace, or `None` otherwise. - fn get_lagrange_kernel_constraints>( - &self, - lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, - lagrange_kernel_rand_elements: &LagrangeKernelRandElements, - ) -> Option> { - self.context().lagrange_kernel_aux_column_idx().map(|col_idx| { - LagrangeKernelConstraints::new( - lagrange_composition_coefficients, - lagrange_kernel_rand_elements, - col_idx, - ) - }) - } - /// Returns values for all periodic columns used in the computation. /// /// These values will be used to compute column values at specific states of the computation @@ -545,7 +523,7 @@ pub trait Air: Send + Sync { b_coefficients.push(public_coin.draw()?); } - let lagrange = if self.context().has_lagrange_kernel_aux_column() { + let lagrange = if self.context().logup_gkr_enabled() { let mut lagrange_kernel_t_coefficients = Vec::new(); for _ in 0..self.context().trace_len().ilog2() { lagrange_kernel_t_coefficients.push(public_coin.draw()?); @@ -561,10 +539,17 @@ pub trait Air: Send + Sync { None }; + let s_col = if self.context().logup_gkr_enabled() { + Some(public_coin.draw()?) + } else { + None + }; + Ok(ConstraintCompositionCoefficients { transition: t_coefficients, boundary: b_coefficients, lagrange, + s_col, }) } @@ -588,7 +573,13 @@ pub trait Air: Send + Sync { c_coefficients.push(public_coin.draw()?); } - let lagrange_cc = if self.context().has_lagrange_kernel_aux_column() { + let lagrange_cc = if self.context().logup_gkr_enabled() { + Some(public_coin.draw()?) + } else { + None + }; + + let s_col_cc = if self.context().logup_gkr_enabled() { Some(public_coin.draw()?) } else { None @@ -598,6 +589,7 @@ pub trait Air: Send + Sync { trace: t_coefficients, constraints: c_coefficients, lagrange: lagrange_cc, + s_col: s_col_cc, }) } } diff --git a/air/src/air/tests.rs b/air/src/air/tests.rs index e0063ed3b..2400cb883 100644 --- a/air/src/air/tests.rs +++ b/air/src/air/tests.rs @@ -9,8 +9,8 @@ use crypto::{hashers::Blake3_256, DefaultRandomCoin, RandomCoin}; use math::{fields::f64::BaseElement, get_power_series, polynom, FieldElement, StarkField}; use super::{ - Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo, - TransitionConstraintDegree, + logup_gkr::PhantomLogUpGkrEval, Air, AirContext, Assertion, EvaluationFrame, ProofOptions, + TraceInfo, TransitionConstraintDegree, }; use crate::FieldExtension; @@ -192,7 +192,7 @@ fn get_boundary_constraints() { // ================================================================================================ struct MockAir { - context: AirContext, + context: AirContext, assertions: Vec>, periodic_columns: Vec>, } @@ -225,8 +225,6 @@ impl MockAir { impl Air for MockAir { type BaseField = BaseElement; type PublicInputs = (); - type GkrProof = (); - type GkrVerifier = (); fn new(trace_info: TraceInfo, _pub_inputs: (), _options: ProofOptions) -> Self { let num_assertions = trace_info.meta()[0] as usize; @@ -238,7 +236,7 @@ impl Air for MockAir { } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } @@ -257,6 +255,13 @@ impl Air for MockAir { _result: &mut [E], ) { } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl super::LogUpGkrEvaluator + { + PhantomLogUpGkrEval::default() + } } // UTILITY FUNCTIONS @@ -266,11 +271,11 @@ pub fn build_context( trace_length: usize, trace_width: usize, num_assertions: usize, -) -> AirContext { +) -> AirContext { let options = ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 31); let t_degrees = vec![TransitionConstraintDegree::new(2)]; let trace_info = TraceInfo::new(trace_width, trace_length); - AirContext::new(trace_info, t_degrees, num_assertions, options) + AirContext::new(trace_info, (), t_degrees, num_assertions, options) } pub fn build_prng() -> DefaultRandomCoin> { diff --git a/air/src/air/trace_info.rs b/air/src/air/trace_info.rs index 99ff4aa6d..29cf4726b 100644 --- a/air/src/air/trace_info.rs +++ b/air/src/air/trace_info.rs @@ -27,6 +27,7 @@ pub struct TraceInfo { num_aux_segment_rands: usize, trace_length: usize, trace_meta: Vec, + logup_gkr: bool, } impl TraceInfo { @@ -38,6 +39,10 @@ impl TraceInfo { pub const MAX_META_LENGTH: usize = 65535; /// Maximum number of random elements in the auxiliary trace segment; currently set to 255. pub const MAX_RAND_SEGMENT_ELEMENTS: usize = 255; + /// The Lagrange kernel, if present, is the last column of the auxiliary trace. + pub const LAGRANGE_KERNEL_OFFSET: usize = 1; + /// The s-column, if present, is the second to last column of the auxiliary trace. + pub const S_COLUMN_OFFSET: usize = 2; // CONSTRUCTORS // -------------------------------------------------------------------------------------------- @@ -65,7 +70,7 @@ impl TraceInfo { /// * Length of `meta` is greater than 65535; pub fn with_meta(width: usize, length: usize, meta: Vec) -> Self { assert!(width > 0, "trace width must be greater than 0"); - Self::new_multi_segment(width, 0, 0, length, meta) + Self::new_multi_segment(width, 0, 0, length, meta, false) } /// Creates a new [TraceInfo] with main and auxiliary segments. @@ -90,6 +95,7 @@ impl TraceInfo { num_aux_segment_rands: usize, trace_length: usize, trace_meta: Vec, + logup_gkr: bool, ) -> Self { assert!( trace_length >= Self::MIN_TRACE_LENGTH, @@ -110,7 +116,7 @@ impl TraceInfo { // validate trace segment widths assert!(main_segment_width > 0, "main trace segment must consist of at least one column"); - let full_width = main_segment_width + aux_segment_width; + let full_width = main_segment_width + aux_segment_width + 2 * logup_gkr as usize; assert!( full_width <= TraceInfo::MAX_TRACE_WIDTH, "total number of columns in the trace cannot be greater than {}, but was {}", @@ -138,6 +144,7 @@ impl TraceInfo { num_aux_segment_rands, trace_length, trace_meta, + logup_gkr, } } @@ -146,9 +153,13 @@ impl TraceInfo { /// Returns the total number of columns in an execution trace. /// + /// When LogUp-GKR is enabled, we also account for two extra columns, in the auxiliary segment, + /// which are needed for implementing the univariate IOP for multi-linear evaluation in + /// https://eprint.iacr.org/2023/1284. + /// /// This is guaranteed to be between 1 and 255. pub fn width(&self) -> usize { - self.main_segment_width + self.aux_segment_width + self.main_segment_width + self.aux_segment_width + 2 * self.logup_gkr as usize } /// Returns execution trace length. @@ -163,21 +174,27 @@ impl TraceInfo { &self.trace_meta } - /// Returns true if an execution trace contains the auxiliary trace segment. + /// Returns true if an execution trace contains an auxiliary trace segment. + /// + /// This includes either the case when the auxiliary trace segment is user defined or the case + /// when the segment is created as part of LogUp-GKR. pub fn is_multi_segment(&self) -> bool { - self.aux_segment_width > 0 + self.aux_segment_width > 0 || self.logup_gkr } /// Returns the number of columns in the main segment of an execution trace. /// /// This is guaranteed to be between 1 and 255. - pub fn main_trace_width(&self) -> usize { + pub fn main_segment_width(&self) -> usize { self.main_segment_width } /// Returns the number of columns in the auxiliary segment of an execution trace. + /// + /// This includes both the columns that are user defined as well as the two columns defined + /// as part of LogUp-GKR when the latter is enabled. pub fn aux_segment_width(&self) -> usize { - self.aux_segment_width + self.aux_segment_width + 2 * self.logup_gkr as usize } /// Returns the total number of segments in an execution trace. @@ -198,9 +215,27 @@ impl TraceInfo { } } - /// Returns the number of columns in the auxiliary trace segment. - pub fn get_aux_segment_width(&self) -> usize { - self.aux_segment_width + /// Returns a boolean indicating whether LogUp-GKR is enabled. + pub fn logup_gkr_enabled(&self) -> bool { + self.logup_gkr + } + + /// Returns the index of the auxiliary column which implements the Lagrange kernel, if any. + pub fn lagrange_kernel_column_idx(&self) -> Option { + if self.logup_gkr_enabled() { + Some(self.aux_segment_width() - TraceInfo::LAGRANGE_KERNEL_OFFSET) + } else { + None + } + } + + /// Returns the index of the auxiliary column which implements the s-column, if any. + pub fn s_column_idx(&self) -> Option { + if self.logup_gkr_enabled() { + Some(self.aux_segment_width() - TraceInfo::S_COLUMN_OFFSET) + } else { + None + } } /// Returns the number of random elements needed to build all auxiliary columns, except for the @@ -264,6 +299,9 @@ impl Serializable for TraceInfo { // store trace meta target.write_u16(self.trace_meta.len() as u16); target.write_bytes(&self.trace_meta); + + // write bool indicating if LogUp-GKR is used + target.write_bool(self.logup_gkr); } } @@ -326,12 +364,16 @@ impl Deserializable for TraceInfo { vec![] }; + // read `logup_gkr` + let logup_gkr = source.read_bool()?; + Ok(Self::new_multi_segment( main_segment_width, aux_segment_width, num_aux_segment_rands, trace_length, trace_meta, + logup_gkr, )) } } @@ -387,6 +429,7 @@ mod tests { aux_rands, trace_length as usize, trace_meta, + false, ); assert_eq!(expected, info.to_elements()); diff --git a/air/src/air/transition/mod.rs b/air/src/air/transition/mod.rs index 60e641817..d29cbbb8b 100644 --- a/air/src/air/transition/mod.rs +++ b/air/src/air/transition/mod.rs @@ -46,7 +46,7 @@ impl TransitionConstraints { /// # Panics /// Panics if the number of transition constraints in the context does not match the number of /// provided composition coefficients. - pub fn new(context: &AirContext, composition_coefficients: &[E]) -> Self { + pub fn new

(context: &AirContext, composition_coefficients: &[E]) -> Self { assert_eq!( context.num_transition_constraints(), composition_coefficients.len(), diff --git a/air/src/lib.rs b/air/src/lib.rs index 539a812d9..39ef44d18 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -44,9 +44,9 @@ mod air; pub use air::{ Air, AirContext, Assertion, AuxRandElements, BoundaryConstraint, BoundaryConstraintGroup, BoundaryConstraints, ConstraintCompositionCoefficients, ConstraintDivisor, - DeepCompositionCoefficients, EvaluationFrame, GkrRandElements, GkrVerifier, + DeepCompositionCoefficients, EvaluationFrame, GkrData, LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, - LagrangeKernelTransitionConstraints, TraceInfo, TransitionConstraintDegree, - TransitionConstraints, + LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable, + TraceInfo, TransitionConstraintDegree, TransitionConstraints, }; diff --git a/air/src/proof/context.rs b/air/src/proof/context.rs index 83c2beece..73152709a 100644 --- a/air/src/proof/context.rs +++ b/air/src/proof/context.rs @@ -190,6 +190,7 @@ mod tests { aux_rands, trace_length, vec![], + false, ); let mut expected = trace_info.to_elements(); @@ -213,8 +214,14 @@ mod tests { fri_folding_factor as usize, fri_remainder_max_degree as usize, ); - let trace_info = - TraceInfo::new_multi_segment(main_width, aux_width, aux_rands, trace_length, vec![]); + let trace_info = TraceInfo::new_multi_segment( + main_width, + aux_width, + aux_rands, + trace_length, + vec![], + false, + ); let context = Context::new::(trace_info, options); assert_eq!(expected, context.to_elements()); } diff --git a/air/src/proof/ood_frame.rs b/air/src/proof/ood_frame.rs index d4b3f14ec..9ae017094 100644 --- a/air/src/proof/ood_frame.rs +++ b/air/src/proof/ood_frame.rs @@ -131,7 +131,7 @@ impl OodFrame { let lagrange_kernel_frame = if lagrange_kernel_frame_size > 0 { let lagrange_kernel_trace = reader.read_many(lagrange_kernel_frame_size)?; - Some(LagrangeKernelEvaluationFrame::new(lagrange_kernel_trace)) + Some(LagrangeKernelEvaluationFrame::with_values(lagrange_kernel_trace)) } else { None }; @@ -229,6 +229,8 @@ impl Deserializable for OodFrame { // OOD FRAME TRACE STATES // ================================================================================================ +/// Stores trace evaluations at an OOD point. +/// /// Stores the trace evaluations at `z` and `gz`, where `z` is a random Field element in /// `current_row` and `next_row`, respectively. If the Air contains a Lagrange kernel auxiliary /// column, then that column interpolated polynomial will be evaluated at `z`, `gz`, `g^2 z`, ... diff --git a/crypto/src/hash/mds/mds_f64_12x12.rs b/crypto/src/hash/mds/mds_f64_12x12.rs index 44f5660b9..ddf79f4a2 100644 --- a/crypto/src/hash/mds/mds_f64_12x12.rs +++ b/crypto/src/hash/mds/mds_f64_12x12.rs @@ -12,19 +12,19 @@ use math::{ FieldElement, }; -/// This module contains helper functions as well as constants used to perform a 12x12 vector-matrix -/// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce -/// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain". -/// This follows from the simple fact that every circulant matrix has the columns of the discrete -/// Fourier transform matrix as orthogonal eigenvectors. -/// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that -/// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, -/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of -/// an MDS matrix that has small powers of 2 entries in frequency domain. -/// The following implementation has benefited greatly from the discussions and insights of -/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's implementation -/// in [Plonky2](https://github.com/mir-protocol/plonky2). -/// The circulant matrix is identified by its first row: [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8]. +// This module contains helper functions as well as constants used to perform a 12x12 vector-matrix +// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce +// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain". +// This follows from the simple fact that every circulant matrix has the columns of the discrete +// Fourier transform matrix as orthogonal eigenvectors. +// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that +// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, +// divisions by 2 and repeated modular reductions. This is because of our explicit choice of +// an MDS matrix that has small powers of 2 entries in frequency domain. +// The following implementation has benefited greatly from the discussions and insights of +// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's implementation +// in [Plonky2](https://github.com/mir-protocol/plonky2). +// The circulant matrix is identified by its first row: [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8]. // MDS matrix in frequency domain. // More precisely, this is the output of the three 4-point (real) FFTs of the first column of @@ -33,6 +33,7 @@ use math::{ // The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4. // The code to generate the matrix in frequency domain is based on an adaptation of a code, to generate // MDS matrices efficiently in original domain, that was developed by the Polygon Zero team. + const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16]; const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)]; const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1]; diff --git a/crypto/src/hash/mds/mds_f64_8x8.rs b/crypto/src/hash/mds/mds_f64_8x8.rs index 037dee721..4e7818357 100644 --- a/crypto/src/hash/mds/mds_f64_8x8.rs +++ b/crypto/src/hash/mds/mds_f64_8x8.rs @@ -12,25 +12,26 @@ use math::{ FieldElement, }; -/// This module contains helper functions as well as constants used to perform a 8x8 vector-matrix -/// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce -/// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain". -/// This follows from the simple fact that every circulant matrix has the columns of the discrete -/// Fourier transform matrix as orthogonal eigenvectors. -/// The implementation also avoids the use of internal 2-point FFTs, and 2-point iFFTs, and substitutes -/// them with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, -/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of -/// an MDS matrix that has small powers of 2 entries in frequency domain. -/// The following implementation has benefited greatly from the discussions and insights of -/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero is based on Nabaglo's implementation -/// in [Plonky2](https://github.com/mir-protocol/plonky2). -/// The circulant matrix is identified by its first row: [23, 8, 13, 10, 7, 6, 21, 8]. +// This module contains helper functions as well as constants used to perform a 8x8 vector-matrix +// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce +// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain". +// This follows from the simple fact that every circulant matrix has the columns of the discrete +// Fourier transform matrix as orthogonal eigenvectors. +// The implementation also avoids the use of internal 2-point FFTs, and 2-point iFFTs, and substitutes +// them with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, +// divisions by 2 and repeated modular reductions. This is because of our explicit choice of +// an MDS matrix that has small powers of 2 entries in frequency domain. +// The following implementation has benefited greatly from the discussions and insights of +// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero is based on Nabaglo's implementation +// in [Plonky2](https://github.com/mir-protocol/plonky2). +// The circulant matrix is identified by its first row: [23, 8, 13, 10, 7, 6, 21, 8]. // MDS matrix in frequency domain. // More precisely, this is the output of the two 4-point (real) FFTs of the first column of // the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors // and application of the final four 2-point FFT in order to get the full 8-point FFT. // The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4. + const MDS_FREQ_BLOCK_ONE: [i64; 2] = [16, 8]; const MDS_FREQ_BLOCK_TWO: [(i64, i64); 2] = [(8, -4), (-1, 1)]; const MDS_FREQ_BLOCK_THREE: [i64; 2] = [-1, 1]; diff --git a/crypto/src/merkle/concurrent.rs b/crypto/src/merkle/concurrent.rs index 637bd51b5..7a3ba077f 100644 --- a/crypto/src/merkle/concurrent.rs +++ b/crypto/src/merkle/concurrent.rs @@ -18,9 +18,10 @@ pub const MIN_CONCURRENT_LEAVES: usize = 1024; // PUBLIC FUNCTIONS // ================================================================================================ -/// Builds all internal nodes of the Merkle using all available threads and stores the -/// results in a single vector such that root of the tree is at position 1, nodes immediately -/// under the root is at positions 2 and 3 etc. +/// Builds all internal nodes of the Merkle tree. +/// +/// This uses all available threads and stores the results in a single vector such that root of +/// the tree is at position 1, nodes immediately under the root is at positions 2 and 3 etc. pub fn build_merkle_nodes(leaves: &[H::Digest]) -> Vec { let n = leaves.len() / 2; diff --git a/examples/src/fibonacci/fib2/air.rs b/examples/src/fibonacci/fib2/air.rs index 9e5d75a48..4019ddcae 100644 --- a/examples/src/fibonacci/fib2/air.rs +++ b/examples/src/fibonacci/fib2/air.rs @@ -14,15 +14,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct FibAir { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for FibAir { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -30,12 +28,12 @@ impl Air for FibAir { let degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; assert_eq!(TRACE_WIDTH, trace_info.width()); FibAir { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/fibonacci/fib8/air.rs b/examples/src/fibonacci/fib8/air.rs index 4d7aef9ba..17edc7970 100644 --- a/examples/src/fibonacci/fib8/air.rs +++ b/examples/src/fibonacci/fib8/air.rs @@ -15,15 +15,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct Fib8Air { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for Fib8Air { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -31,12 +29,12 @@ impl Air for Fib8Air { let degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; assert_eq!(TRACE_WIDTH, trace_info.width()); Fib8Air { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/fibonacci/fib_small/air.rs b/examples/src/fibonacci/fib_small/air.rs index 66580c872..b48eb734b 100644 --- a/examples/src/fibonacci/fib_small/air.rs +++ b/examples/src/fibonacci/fib_small/air.rs @@ -14,15 +14,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct FibSmall { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for FibSmall { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -30,12 +28,12 @@ impl Air for FibSmall { let degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; assert_eq!(TRACE_WIDTH, trace_info.width()); FibSmall { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/fibonacci/mulfib2/air.rs b/examples/src/fibonacci/mulfib2/air.rs index 3190d2e41..501adf6af 100644 --- a/examples/src/fibonacci/mulfib2/air.rs +++ b/examples/src/fibonacci/mulfib2/air.rs @@ -16,15 +16,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct MulFib2Air { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for MulFib2Air { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -32,12 +30,12 @@ impl Air for MulFib2Air { let degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2)]; assert_eq!(TRACE_WIDTH, trace_info.width()); MulFib2Air { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/fibonacci/mulfib8/air.rs b/examples/src/fibonacci/mulfib8/air.rs index bbbe1dea0..c76f4f091 100644 --- a/examples/src/fibonacci/mulfib8/air.rs +++ b/examples/src/fibonacci/mulfib8/air.rs @@ -16,15 +16,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct MulFib8Air { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for MulFib8Air { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -41,12 +39,12 @@ impl Air for MulFib8Air { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); MulFib8Air { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/lamport/aggregate/air.rs b/examples/src/lamport/aggregate/air.rs index 29b6e2372..57708fd74 100644 --- a/examples/src/lamport/aggregate/air.rs +++ b/examples/src/lamport/aggregate/air.rs @@ -38,7 +38,7 @@ impl ToElements for PublicInputs { } pub struct LamportAggregateAir { - context: AirContext, + context: AirContext, pub_keys: Vec<[BaseElement; 2]>, messages: Vec<[BaseElement; 2]>, } @@ -46,8 +46,6 @@ pub struct LamportAggregateAir { impl Air for LamportAggregateAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -88,13 +86,13 @@ impl Air for LamportAggregateAir { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); LamportAggregateAir { - context: AirContext::new(trace_info, degrees, 22, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 22, options), pub_keys: pub_inputs.pub_keys, messages: pub_inputs.messages, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/lamport/threshold/air.rs b/examples/src/lamport/threshold/air.rs index 41983c743..b68a2a24d 100644 --- a/examples/src/lamport/threshold/air.rs +++ b/examples/src/lamport/threshold/air.rs @@ -22,7 +22,7 @@ const TWO: BaseElement = BaseElement::new(2); // THRESHOLD LAMPORT PLUS SIGNATURE AIR // ================================================================================================ -#[derive(Clone)] +#[derive(Clone, Default)] pub struct PublicInputs { pub pub_key_root: [BaseElement; 2], pub num_pub_keys: usize, @@ -41,7 +41,7 @@ impl ToElements for PublicInputs { } pub struct LamportThresholdAir { - context: AirContext, + context: AirContext, pub_key_root: [BaseElement; 2], num_pub_keys: usize, num_signatures: usize, @@ -51,8 +51,6 @@ pub struct LamportThresholdAir { impl Air for LamportThresholdAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -99,7 +97,7 @@ impl Air for LamportThresholdAir { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); LamportThresholdAir { - context: AirContext::new(trace_info, degrees, 26, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 26, options), pub_key_root: pub_inputs.pub_key_root, num_pub_keys: pub_inputs.num_pub_keys, num_signatures: pub_inputs.num_signatures, @@ -244,7 +242,7 @@ impl Air for LamportThresholdAir { result } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } } diff --git a/examples/src/merkle/air.rs b/examples/src/merkle/air.rs index e0c8b177c..5d38397ff 100644 --- a/examples/src/merkle/air.rs +++ b/examples/src/merkle/air.rs @@ -14,6 +14,7 @@ use crate::utils::{are_equal, is_binary, is_zero, not, EvaluationResult}; // MERKLE PATH VERIFICATION AIR // ================================================================================================ +#[derive(Clone)] pub struct PublicInputs { pub tree_root: [BaseElement; 2], } @@ -25,15 +26,13 @@ impl ToElements for PublicInputs { } pub struct MerkleAir { - context: AirContext, + context: AirContext, tree_root: [BaseElement; 2], } impl Air for MerkleAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -49,12 +48,12 @@ impl Air for MerkleAir { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); MerkleAir { - context: AirContext::new(trace_info, degrees, 4, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 4, options), tree_root: pub_inputs.tree_root, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/rescue/air.rs b/examples/src/rescue/air.rs index a9d3d5ebb..09bf9c450 100644 --- a/examples/src/rescue/air.rs +++ b/examples/src/rescue/air.rs @@ -37,6 +37,7 @@ const CYCLE_MASK: [BaseElement; CYCLE_LENGTH] = [ // RESCUE AIR // ================================================================================================ +#[derive(Clone)] pub struct PublicInputs { pub seed: [BaseElement; 2], pub result: [BaseElement; 2], @@ -51,7 +52,7 @@ impl ToElements for PublicInputs { } pub struct RescueAir { - context: AirContext, + context: AirContext, seed: [BaseElement; 2], result: [BaseElement; 2], } @@ -59,8 +60,6 @@ pub struct RescueAir { impl Air for RescueAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -73,13 +72,13 @@ impl Air for RescueAir { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); RescueAir { - context: AirContext::new(trace_info, degrees, 4, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 4, options), seed: pub_inputs.seed, result: pub_inputs.result, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/rescue_raps/air.rs b/examples/src/rescue_raps/air.rs index 6fb5321b1..694e189bc 100644 --- a/examples/src/rescue_raps/air.rs +++ b/examples/src/rescue_raps/air.rs @@ -41,6 +41,7 @@ const CYCLE_MASK: [BaseElement; CYCLE_LENGTH] = [ // RESCUE AIR // ================================================================================================ +#[derive(Clone)] pub struct PublicInputs { pub result: [[BaseElement; 2]; 2], } @@ -52,15 +53,13 @@ impl ToElements for PublicInputs { } pub struct RescueRapsAir { - context: AirContext, + context: AirContext, result: [[BaseElement; 2]; 2], } impl Air for RescueRapsAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -76,18 +75,18 @@ impl Air for RescueRapsAir { RescueRapsAir { context: AirContext::new_multi_segment( trace_info, + pub_inputs.clone(), main_degrees, aux_degrees, 8, 2, - None, options, ), result: pub_inputs.result, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/rescue_raps/custom_trace_table.rs b/examples/src/rescue_raps/custom_trace_table.rs index 063d509a4..f6f9d075b 100644 --- a/examples/src/rescue_raps/custom_trace_table.rs +++ b/examples/src/rescue_raps/custom_trace_table.rs @@ -89,7 +89,7 @@ impl RapTraceTable { let columns = unsafe { (0..width).map(|_| uninit_vector(length)).collect() }; Self { - info: TraceInfo::new_multi_segment(width, 3, 3, length, meta), + info: TraceInfo::new_multi_segment(width, 3, 3, length, meta, false), trace: ColMatrix::new(columns), } } @@ -113,7 +113,7 @@ impl RapTraceTable { I: Fn(&mut [B]), U: Fn(usize, &mut [B]), { - let mut state = vec![B::ZERO; self.info.main_trace_width()]; + let mut state = vec![B::ZERO; self.info.main_segment_width()]; init(&mut state); self.update_row(0, &state); @@ -133,7 +133,7 @@ impl RapTraceTable { /// Returns the number of columns in this execution trace. pub fn width(&self) -> usize { - self.info.main_trace_width() + self.info.main_segment_width() } /// Returns value of the cell in the specified column at the specified row of this trace. diff --git a/examples/src/rescue_raps/prover.rs b/examples/src/rescue_raps/prover.rs index 7adee9bbb..6e50f1572 100644 --- a/examples/src/rescue_raps/prover.rs +++ b/examples/src/rescue_raps/prover.rs @@ -139,16 +139,11 @@ where DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) } - fn build_aux_trace( - &self, - trace: &Self::Trace, - aux_rand_elements: &AuxRandElements, - ) -> ColMatrix + fn build_aux_trace(&self, trace: &Self::Trace, aux_rand_elements: &[E]) -> ColMatrix where E: FieldElement, { let main_trace = trace.main_segment(); - let rand_elements = aux_rand_elements.rand_elements(); let mut current_row = unsafe { uninit_vector(main_trace.num_cols()) }; let mut next_row = unsafe { uninit_vector(main_trace.num_cols()) }; @@ -157,10 +152,10 @@ where // Columns storing the copied values for the permutation argument are not necessary, but // help understanding the construction of RAPs and are kept for illustrative purposes. - aux_columns[0][0] = - rand_elements[0] * current_row[0].into() + rand_elements[1] * current_row[1].into(); - aux_columns[1][0] = - rand_elements[0] * current_row[4].into() + rand_elements[1] * current_row[5].into(); + aux_columns[0][0] = aux_rand_elements[0] * current_row[0].into() + + aux_rand_elements[1] * current_row[1].into(); + aux_columns[1][0] = aux_rand_elements[0] * current_row[4].into() + + aux_rand_elements[1] * current_row[5].into(); // Permutation argument column aux_columns[2][0] = E::ONE; @@ -172,14 +167,16 @@ where main_trace.read_row_into(index, &mut current_row); main_trace.read_row_into(index + 1, &mut next_row); - aux_columns[0][index] = rand_elements[0] * (next_row[0] - current_row[0]).into() - + rand_elements[1] * (next_row[1] - current_row[1]).into(); - aux_columns[1][index] = rand_elements[0] * (next_row[4] - current_row[4]).into() - + rand_elements[1] * (next_row[5] - current_row[5]).into(); + aux_columns[0][index] = aux_rand_elements[0] + * (next_row[0] - current_row[0]).into() + + aux_rand_elements[1] * (next_row[1] - current_row[1]).into(); + aux_columns[1][index] = aux_rand_elements[0] + * (next_row[4] - current_row[4]).into() + + aux_rand_elements[1] * (next_row[5] - current_row[5]).into(); } - let num = aux_columns[0][index - 1] + rand_elements[2]; - let denom = aux_columns[1][index - 1] + rand_elements[2]; + let num = aux_columns[0][index - 1] + aux_rand_elements[2]; + let denom = aux_columns[1][index - 1] + aux_rand_elements[2]; aux_columns[2][index] = aux_columns[2][index - 1] * num * denom.inv(); } diff --git a/examples/src/utils/rescue.rs b/examples/src/utils/rescue.rs index e09cb094e..be297fcf3 100644 --- a/examples/src/utils/rescue.rs +++ b/examples/src/utils/rescue.rs @@ -21,6 +21,8 @@ pub const RATE_WIDTH: usize = 4; /// Two elements (32-bytes) are returned as digest. const DIGEST_SIZE: usize = 2; +/// Number of rounds used in Rescue. +/// /// The number of rounds is set to 7 to provide 128-bit security level with 40% security margin; /// computed using algorithm 7 from /// security margin here differs from Rescue Prime specification which suggests 50% security diff --git a/examples/src/vdf/exempt/air.rs b/examples/src/vdf/exempt/air.rs index 9254e4e0a..015778459 100644 --- a/examples/src/vdf/exempt/air.rs +++ b/examples/src/vdf/exempt/air.rs @@ -29,7 +29,7 @@ impl ToElements for VdfInputs { // ================================================================================================ pub struct VdfAir { - context: AirContext, + context: AirContext, seed: BaseElement, result: BaseElement, } @@ -37,16 +37,14 @@ pub struct VdfAir { impl Air for VdfAir { type BaseField = BaseElement; type PublicInputs = VdfInputs; - type GkrProof = (); - type GkrVerifier = (); fn new(trace_info: TraceInfo, pub_inputs: VdfInputs, options: ProofOptions) -> Self { let degrees = vec![TransitionConstraintDegree::new(3)]; assert_eq!(TRACE_WIDTH, trace_info.width()); // make sure the last two rows are excluded from transition constraints as we populate // values in the last row with garbage - let context = - AirContext::new(trace_info, degrees, 2, options).set_num_transition_exemptions(2); + let context = AirContext::new(trace_info, pub_inputs.clone(), degrees, 2, options) + .set_num_transition_exemptions(2); Self { context, seed: pub_inputs.seed, @@ -76,7 +74,7 @@ impl Air for VdfAir { ] } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } } diff --git a/examples/src/vdf/regular/air.rs b/examples/src/vdf/regular/air.rs index b434c1478..bec2ccb3c 100644 --- a/examples/src/vdf/regular/air.rs +++ b/examples/src/vdf/regular/air.rs @@ -29,7 +29,7 @@ impl ToElements for VdfInputs { // ================================================================================================ pub struct VdfAir { - context: AirContext, + context: AirContext, seed: BaseElement, result: BaseElement, } @@ -37,14 +37,12 @@ pub struct VdfAir { impl Air for VdfAir { type BaseField = BaseElement; type PublicInputs = VdfInputs; - type GkrProof = (); - type GkrVerifier = (); fn new(trace_info: TraceInfo, pub_inputs: VdfInputs, options: ProofOptions) -> Self { let degrees = vec![TransitionConstraintDegree::new(3)]; assert_eq!(TRACE_WIDTH, trace_info.width()); Self { - context: AirContext::new(trace_info, degrees, 2, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 2, options), seed: pub_inputs.seed, result: pub_inputs.result, } @@ -67,7 +65,7 @@ impl Air for VdfAir { vec![Assertion::single(0, 0, self.seed), Assertion::single(0, last_step, self.result)] } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } } diff --git a/math/src/field/f64/mod.rs b/math/src/field/f64/mod.rs index 119676076..64c637c0a 100644 --- a/math/src/field/f64/mod.rs +++ b/math/src/field/f64/mod.rs @@ -3,9 +3,10 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$ -//! using Montgomery representation. -//! Our implementation follows and is constant-time. +//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$. +//! +//! Our implementation uses Montgomery representation and follows +//! and is constant-time. //! //! This field supports very fast modular arithmetic and has a number of other attractive //! properties, including: diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 36272766f..125ab2a4f 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -16,16 +16,20 @@ rust-version = "1.78" bench = false [[bench]] -name = "row_matrix" +name = "logup_gkr_e2e" harness = false [[bench]] -name = "lagrange_kernel" +name = "logup_gkr" +harness = false + +[[bench]] +name = "row_matrix" harness = false [features] async = ["maybe_async/async"] -concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "std"] +concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "sumcheck/concurrent", "std"] default = ["std"] std = ["air/std", "crypto/std", "fri/std", "math/std", "utils/std"] @@ -35,6 +39,8 @@ crypto = { version = "0.9", path = "../crypto", package = "winter-crypto", defau fri = { version = "0.9", path = '../fri', package = "winter-fri", default-features = false } math = { version = "0.9", path = "../math", package = "winter-math", default-features = false } maybe_async = { path = "../utils/maybe_async" , package = "winter-maybe-async" } +sumcheck = { version = "0.1", path = "../sumcheck", package = "winter-sumcheck", default-features = false } +thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes"]} utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } diff --git a/prover/benches/lagrange_kernel.rs b/prover/benches/lagrange_kernel.rs index 7ee8ab3c3..348554806 100644 --- a/prover/benches/lagrange_kernel.rs +++ b/prover/benches/lagrange_kernel.rs @@ -7,15 +7,14 @@ use std::time::Duration; use air::{ Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, - EvaluationFrame, FieldExtension, GkrRandElements, LagrangeKernelRandElements, ProofOptions, - TraceInfo, TransitionConstraintDegree, + EvaluationFrame, FieldExtension, ProofOptions, TraceInfo, TransitionConstraintDegree, }; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; -use crypto::{hashers::Blake3_256, DefaultRandomCoin, MerkleTree, RandomCoin}; +use crypto::{hashers::Blake3_256, DefaultRandomCoin, MerkleTree}; use math::{fields::f64::BaseElement, ExtensionOf, FieldElement}; use winter_prover::{ - matrix::ColMatrix, DefaultConstraintEvaluator, DefaultTraceLde, Prover, ProverGkrProof, - StarkDomain, Trace, TracePolyTable, + matrix::ColMatrix, DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, Trace, + TracePolyTable, }; const TRACE_LENS: [usize; 2] = [2_usize.pow(16), 2_usize.pow(20)]; @@ -61,7 +60,7 @@ impl LagrangeTrace { Self { main_trace: ColMatrix::new(vec![main_trace_col]), - info: TraceInfo::new_multi_segment(1, aux_segment_width, 0, trace_len, vec![]), + info: TraceInfo::new_multi_segment(1, aux_segment_width, 0, trace_len, vec![], false), } } @@ -94,31 +93,28 @@ impl Trace for LagrangeTrace { // ================================================================================================= struct LagrangeKernelAir { - context: AirContext, + context: AirContext, } impl Air for LagrangeKernelAir { type BaseField = BaseElement; - type GkrProof = (); - type GkrVerifier = (); - type PublicInputs = (); fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { Self { context: AirContext::new_multi_segment( trace_info, + _pub_inputs, vec![TransitionConstraintDegree::new(1)], vec![TransitionConstraintDegree::new(1)], 1, 1, - Some(0), options, ), } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } @@ -221,42 +217,14 @@ impl Prover for LagrangeProver { DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) } - fn generate_gkr_proof( - &self, - main_trace: &Self::Trace, - public_coin: &mut Self::RandomCoin, - ) -> (ProverGkrProof, GkrRandElements) - where - E: FieldElement, - { - let main_trace = main_trace.main_segment(); - let lagrange_kernel_rand_elements = { - let log_trace_len = main_trace.num_rows().ilog2() as usize; - let mut rand_elements = Vec::with_capacity(log_trace_len); - for _ in 0..log_trace_len { - rand_elements.push(public_coin.draw().unwrap()); - } - - LagrangeKernelRandElements::new(rand_elements) - }; - - ((), GkrRandElements::new(lagrange_kernel_rand_elements, Vec::new())) - } - - fn build_aux_trace( - &self, - main_trace: &Self::Trace, - aux_rand_elements: &AuxRandElements, - ) -> ColMatrix + fn build_aux_trace(&self, main_trace: &Self::Trace, aux_rand_elements: &[E]) -> ColMatrix where E: FieldElement, { let main_trace = main_trace.main_segment(); let mut columns = Vec::new(); - let lagrange_kernel_rand_elements = aux_rand_elements - .lagrange() - .expect("expected lagrange kernel random elements to be present."); + let lagrange_kernel_rand_elements = aux_rand_elements; // first build the Lagrange kernel column { diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs new file mode 100644 index 000000000..e86c84aef --- /dev/null +++ b/prover/benches/logup_gkr.rs @@ -0,0 +1,287 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, time::Duration, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, LogUpGkrEvaluator, + LogUpGkrOracle, ProofOptions, TraceInfo, TransitionConstraintDegree, +}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::RandomCoin; +use math::StarkField; +use winter_prover::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + prove_gkr, Trace, +}; + +const TRACE_LENS: [usize; 4] = [2_usize.pow(18), 2_usize.pow(19), 2_usize.pow(20), 2_usize.pow(21)]; + +/// Simple benchmark for the GKR part of STARK with LogUp-GKR. +/// +/// The main trace contains `5` columns and the LogUp relation is a simple one where we have: +/// +/// 1. a table of values from `0` to `trace_len - 1`. +/// 2. a multiplicity column containing the number of look ups for each value in the table. +/// 3. three columns with values contained in the table above. +/// +/// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling +/// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. +fn prove_with_logup_gkr(c: &mut Criterion) { + let mut group = c.benchmark_group("prove LogUp-GKR"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(20)); + + for &trace_len in TRACE_LENS.iter() { + group.bench_function(BenchmarkId::new("", trace_len), |b| { + let main_trace = LogUpGkrSimpleTrace::new(trace_len); + let evaluator = PlainLogUpGkrEval::new(); + + b.iter_batched( + || (main_trace.clone(), evaluator.clone()), + |(main_trace, evaluator)| { + let mut public_coin = + DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); + prove_gkr::(&main_trace, &evaluator, &mut public_coin) + }, + BatchSize::SmallInput, + ) + }); + } +} + +criterion_group!(logup_gkr_group, prove_with_logup_gkr); +criterion_main!(logup_gkr_group); + +// LogUpGkrSimple +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrSimpleTrace { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimpleTrace { + fn new(trace_len: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + // we create a column for the table we are looking values into. These are just the integers + // from 0 to `trace_len`. + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + + // we create three columns that contains values contained in `table`. For simplicity, we + // look up only the values `0` or `1`, we look up the value `1` four times and the value `0` + // `trace_len - 4` times. + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + // we create the multiplicity column + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + // we look up the value `1` four times in three columns + multiplicity[1] = BaseElement::new(3 * 4); + // we look up the value `0` `trace_len - 4` in three columns + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, 0, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimpleTrace { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} diff --git a/prover/benches/logup_gkr_e2e.rs b/prover/benches/logup_gkr_e2e.rs new file mode 100644 index 000000000..2f81dd850 --- /dev/null +++ b/prover/benches/logup_gkr_e2e.rs @@ -0,0 +1,368 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, time::Duration, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, + EvaluationFrame, FieldExtension, LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, + TransitionConstraintDegree, +}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::MerkleTree; +use math::StarkField; +use winter_prover::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultTraceLde, LogUpGkrConstraintEvaluator, Prover, StarkDomain, Trace, TracePolyTable, +}; + +const TRACE_LENS: [usize; 2] = [2_usize.pow(18), 2_usize.pow(20)]; +const AUX_TRACE_WIDTH: usize = 2; + +/// Simple end-to-end benchmark for LogUp-GKR. +/// +/// The main trace contains `5` columns and the LogUp relation is a simple one where we have: +/// +/// 1. a table of values from `0` to `trace_len - 1`. +/// 2. a multiplicity column containing the number of look ups for each value in the table. +/// 3. three columns with values contained in the table above. +/// +/// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling +/// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. +fn prove_with_logup_gkr(c: &mut Criterion) { + let mut group = c.benchmark_group("prove with LogUp-GKR"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(20)); + + for &trace_len in TRACE_LENS.iter() { + group.bench_function(BenchmarkId::new("", trace_len), |b| { + let trace = LogUpGkrSimpleTrace::new(trace_len, AUX_TRACE_WIDTH); + let prover = LogUpGkrSimpleProver::new(AUX_TRACE_WIDTH); + + b.iter_batched( + || trace.clone(), + |trace| prover.prove(trace).unwrap(), + BatchSize::SmallInput, + ) + }); + } +} + +criterion_group!(logup_gkr_group, prove_with_logup_gkr); +criterion_main!(logup_gkr_group); + +// LogUpGkrSimple +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrSimpleTrace { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimpleTrace { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + // we create a column for the table we are looking values into. These are just the integers + // from 0 to `trace_len`. + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + + // we create three columns that contains values contained in `table`. For simplicity, we + // look up only the values `0` or `1`, we look up the value `1` four times and the value `0` + // `trace_len - 4` times. + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + // we create the multiplicity column + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + // we look up the value `1` four times in three columns + multiplicity[1] = BaseElement::new(3 * 4); + // we look up the value `0` `trace_len - 4` in three columns + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimpleTrace { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} +// Prover +// ================================================================================================ + +struct LogUpGkrSimpleProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrSimpleProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrSimpleProver { + type BaseField = BaseElement; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimpleTrace; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} diff --git a/prover/src/constraints/evaluation_table.rs b/prover/src/constraints/evaluation_table.rs index 826c61253..5ec4f92ee 100644 --- a/prover/src/constraints/evaluation_table.rs +++ b/prover/src/constraints/evaluation_table.rs @@ -46,8 +46,9 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { pub fn new( domain: &'a StarkDomain, divisors: Vec>, + logup_gkr_enabled: bool, ) -> Self { - let num_columns = divisors.len(); + let num_columns = divisors.len() + logup_gkr_enabled as usize; let num_rows = domain.ce_domain_size(); ConstraintEvaluationTable { evaluations: uninit_matrix(num_columns, num_rows), @@ -64,8 +65,9 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { domain: &'a StarkDomain, divisors: Vec>, transition_constraints: &TransitionConstraints, + logup_gkr_enabled: bool, ) -> Self { - let num_columns = divisors.len(); + let num_columns = divisors.len() + logup_gkr_enabled as usize; let num_rows = domain.ce_domain_size(); let num_tm_columns = transition_constraints.num_main_constraints(); let num_ta_columns = transition_constraints.num_aux_constraints(); @@ -161,13 +163,28 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { /// Divides constraint evaluation columns by their respective divisor (in evaluation form) and /// combines the results into a single column. pub fn combine(self) -> Vec { - // allocate memory for the combined polynomial - let mut combined_poly = vec![E::ZERO; self.num_rows()]; + // when LogUp-GKR is enabled, the last column contains the constraint evaluations of + // the Lagrange kernel column and the s-column. These evaluations were already divided by + // their respective divisors, and hence we just have to add them to `combined_poly`. + let mut combined_poly = if self.evaluations.len() != self.divisors.len() { + // allocate memory for the combined polynomial + let mut combined_poly = unsafe { uninit_vector(self.num_rows()) }; + + iter_mut!(combined_poly) + .enumerate() + .for_each(|(i, row)| *row = self.evaluations[self.divisors.len()][i]); + combined_poly + } else { + vec![E::ZERO; self.num_rows()] + }; // iterate over all columns of the constraint evaluation table, divide each column // by the evaluations of its corresponding divisor, and add all resulting evaluations - // together into a single vector - for (column, divisor) in self.evaluations.into_iter().zip(self.divisors.iter()) { + // together into a single vector. When LogUp-GKR is enabled, we skip the last two columns + // of the evaluation table as these were already handled above. + for (column, divisor) in + self.evaluations.into_iter().take(self.divisors.len()).zip(self.divisors.iter()) + { // divide the column by the divisor and accumulate the result into combined_poly acc_column(column, divisor, self.domain, &mut combined_poly); } @@ -210,11 +227,13 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { max_degree = core::cmp::max(max_degree, degree); } - // make sure expected and actual degrees are equal - assert_eq!( - self.expected_transition_degrees, actual_degrees, - "transition constraint degrees didn't match\nexpected: {:>3?}\nactual: {:>3?}", - self.expected_transition_degrees, actual_degrees + // make sure the actual degrees are less than or equal to the expected degree bounds + assert!( + self.expected_transition_degrees >= actual_degrees, + "transition constraint degrees do not satisfy the expected degree bounds + \nexpected degree bounds: {:>3?}\nactual degrees: {:>3?}", + self.expected_transition_degrees, + actual_degrees ); // make sure evaluation domain size does not exceed the size required by max degree diff --git a/prover/src/constraints/evaluator/default.rs b/prover/src/constraints/evaluator/default.rs index 8f96c7dcd..1a5da1b4b 100644 --- a/prover/src/constraints/evaluator/default.rs +++ b/prover/src/constraints/evaluator/default.rs @@ -13,9 +13,8 @@ use utils::iter_mut; use utils::{iterators::*, rayon}; use super::{ - super::EvaluationTableFragment, lagrange::LagrangeKernelConstraintsBatchEvaluator, - BoundaryConstraints, CompositionPolyTrace, ConstraintEvaluationTable, ConstraintEvaluator, - PeriodicValueTable, StarkDomain, TraceLde, + super::EvaluationTableFragment, BoundaryConstraints, CompositionPolyTrace, + ConstraintEvaluationTable, ConstraintEvaluator, PeriodicValueTable, StarkDomain, TraceLde, }; // CONSTANTS @@ -40,7 +39,6 @@ pub struct DefaultConstraintEvaluator<'a, A: Air, E: FieldElement, transition_constraints: TransitionConstraints, - lagrange_constraints_evaluator: Option>, aux_rand_elements: Option>, periodic_values: PeriodicValueTable, } @@ -80,10 +78,14 @@ where // memory to hold all transition constraint evaluations (before they are merged into a // single value) so that we can check their degrees later #[cfg(not(debug_assertions))] - let mut evaluation_table = ConstraintEvaluationTable::::new(domain, divisors); + let mut evaluation_table = ConstraintEvaluationTable::::new(domain, divisors, false); #[cfg(debug_assertions)] - let mut evaluation_table = - ConstraintEvaluationTable::::new(domain, divisors, &self.transition_constraints); + let mut evaluation_table = ConstraintEvaluationTable::::new( + domain, + divisors, + &self.transition_constraints, + false, + ); // when `concurrent` feature is enabled, break the evaluation table into multiple fragments // to evaluate them into multiple threads; unless the constraint evaluation domain is small, @@ -116,16 +118,7 @@ where #[cfg(debug_assertions)] evaluation_table.validate_transition_degrees(); - // combine all constraint evaluations into a single column, including the evaluations of the - // Lagrange kernel constraints (if present) - let combined_evaluations = { - let mut constraints_evaluations = evaluation_table.combine(); - self.evaluate_lagrange_kernel_constraints(trace, domain, &mut constraints_evaluations); - - constraints_evaluations - }; - - CompositionPolyTrace::new(combined_evaluations) + CompositionPolyTrace::new(evaluation_table.combine()) } } @@ -143,6 +136,11 @@ where aux_rand_elements: Option>, composition_coefficients: ConstraintCompositionCoefficients, ) -> Self { + assert!( + !air.context().logup_gkr_enabled(), + "evaluating LogUp-GKR constraints is not supported in `DefaultConstraintEvaluator`" + ); + // build transition constraint groups; these will be used to compose transition constraint // evaluations let transition_constraints = @@ -158,28 +156,10 @@ where &composition_coefficients.boundary, ); - let lagrange_constraints_evaluator = if air.context().has_lagrange_kernel_aux_column() { - let aux_rand_elements = - aux_rand_elements.as_ref().expect("expected aux rand elements to be present"); - let lagrange_rand_elements = aux_rand_elements - .lagrange() - .expect("expected lagrange rand elements to be present"); - Some(LagrangeKernelConstraintsBatchEvaluator::new( - air, - lagrange_rand_elements.clone(), - composition_coefficients - .lagrange - .expect("expected Lagrange kernel composition coefficients to be present"), - )) - } else { - None - }; - DefaultConstraintEvaluator { air, boundary_constraints, transition_constraints, - lagrange_constraints_evaluator, aux_rand_elements, periodic_values, } @@ -198,7 +178,7 @@ where fragment: &mut EvaluationTableFragment, ) { // initialize buffers to hold trace values and evaluation results at each step; - let mut main_frame = EvaluationFrame::new(trace.trace_info().main_trace_width()); + let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); let mut evaluations = vec![E::ZERO; fragment.num_columns()]; let mut t_evaluations = vec![E::BaseField::ZERO; self.num_main_transition_constraints()]; @@ -249,7 +229,7 @@ where fragment: &mut EvaluationTableFragment, ) { // initialize buffers to hold trace values and evaluation results at each step - let mut main_frame = EvaluationFrame::new(trace.trace_info().main_trace_width()); + let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); let mut aux_frame = EvaluationFrame::new(trace.trace_info().aux_segment_width()); let mut tm_evaluations = vec![E::BaseField::ZERO; self.num_main_transition_constraints()]; let mut ta_evaluations = vec![E::ZERO; self.num_aux_transition_constraints()]; @@ -295,29 +275,6 @@ where } } - /// If present, evaluates the Lagrange kernel constraints over the constraint evaluation domain. - /// The evaluation of each constraint (both boundary and transition) is divided by its divisor, - /// multiplied by its composition coefficient, the result of which is added to - /// `combined_evaluations_accumulator`. - /// - /// Specifically, `combined_evaluations_accumulator` is a buffer whose length is the size of the - /// constraint evaluation domain, where each index contains combined evaluations of other - /// constraints in the system. - fn evaluate_lagrange_kernel_constraints>( - &self, - trace: &T, - domain: &StarkDomain, - combined_evaluations_accumulator: &mut [E], - ) { - if let Some(ref lagrange_constraints_evaluator) = self.lagrange_constraints_evaluator { - lagrange_constraints_evaluator.evaluate_constraints( - trace, - domain, - combined_evaluations_accumulator, - ) - } - } - // TRANSITION CONSTRAINT EVALUATORS // -------------------------------------------------------------------------------------------- diff --git a/prover/src/constraints/evaluator/lagrange.rs b/prover/src/constraints/evaluator/lagrange.rs deleted file mode 100644 index 89d07f62c..000000000 --- a/prover/src/constraints/evaluator/lagrange.rs +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -use alloc::vec::Vec; - -use air::{ - Air, LagrangeConstraintsCompositionCoefficients, LagrangeKernelConstraints, - LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, -}; -use math::{batch_inversion, FieldElement}; - -use crate::{StarkDomain, TraceLde}; - -/// Contains a specific strategy for evaluating the Lagrange kernel boundary and transition -/// constraints where the divisors' evaluation is batched. -/// -/// Specifically, [`batch_inversion`] is used to reduce the number of divisions performed. -pub struct LagrangeKernelConstraintsBatchEvaluator { - lagrange_kernel_constraints: LagrangeKernelConstraints, - rand_elements: LagrangeKernelRandElements, -} - -impl LagrangeKernelConstraintsBatchEvaluator { - /// Constructs a new [`LagrangeConstraintsBatchEvaluator`]. - pub fn new( - air: &A, - lagrange_kernel_rand_elements: LagrangeKernelRandElements, - lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, - ) -> Self - where - E: FieldElement, - { - Self { - lagrange_kernel_constraints: air - .get_lagrange_kernel_constraints( - lagrange_composition_coefficients, - &lagrange_kernel_rand_elements, - ) - .expect("expected Lagrange kernel constraints to be present"), - rand_elements: lagrange_kernel_rand_elements, - } - } - - /// Evaluates the transition and boundary constraints. Specifically, the constraint evaluations - /// are divided by their corresponding divisors, and the resulting terms are linearly combined - /// using the composition coefficients. - /// - /// Writes the evaluations in `combined_evaluations_acc` at the corresponding (constraint - /// evaluation) domain index. - pub fn evaluate_constraints( - &self, - trace: &T, - domain: &StarkDomain, - combined_evaluations_acc: &mut [E], - ) where - T: TraceLde, - { - let lde_shift = domain.ce_to_lde_blowup().trailing_zeros(); - let trans_constraints_divisors = LagrangeKernelTransitionConstraintsDivisor::new( - self.lagrange_kernel_constraints.transition.num_constraints(), - domain, - ); - let boundary_divisors_inv = self.compute_boundary_divisors_inv(domain); - - let mut frame = LagrangeKernelEvaluationFrame::new_empty(); - - for step in 0..domain.ce_domain_size() { - // compute Lagrange kernel frame - trace.read_lagrange_kernel_frame_into( - step << lde_shift, - self.lagrange_kernel_constraints.lagrange_kernel_col_idx, - &mut frame, - ); - - // Compute the combined transition and boundary constraints evaluations for this row - let combined_evaluations = { - let mut combined_evaluations = E::ZERO; - - // combine transition constraints - for trans_constraint_idx in - 0..self.lagrange_kernel_constraints.transition.num_constraints() - { - let numerator = self - .lagrange_kernel_constraints - .transition - .evaluate_ith_numerator(&frame, &self.rand_elements, trans_constraint_idx); - let inv_divisor = trans_constraints_divisors - .get_inverse_divisor_eval(trans_constraint_idx, step); - - combined_evaluations += numerator * inv_divisor; - } - - // combine boundary constraints - { - let boundary_numerator = - self.lagrange_kernel_constraints.boundary.evaluate_numerator_at(&frame); - - combined_evaluations += boundary_numerator * boundary_divisors_inv[step]; - } - - combined_evaluations - }; - - combined_evaluations_acc[step] += combined_evaluations; - } - } - - // HELPERS - // --------------------------------------------------------------------------------------------- - - /// Computes the inverse boundary divisor at every point of the constraint evaluation domain. - /// That is, returns a vector of the form `[1 / div_0, ..., 1 / div_n]`, where `div_i` is the - /// divisor for the Lagrange kernel boundary constraint at the i'th row of the constraint - /// evaluation domain. - fn compute_boundary_divisors_inv(&self, domain: &StarkDomain) -> Vec { - let mut boundary_denominator_evals = Vec::with_capacity(domain.ce_domain_size()); - for step in 0..domain.ce_domain_size() { - let domain_point = domain.get_ce_x_at(step); - let boundary_denominator = self - .lagrange_kernel_constraints - .boundary - .evaluate_denominator_at(domain_point.into()); - boundary_denominator_evals.push(boundary_denominator); - } - - batch_inversion(&boundary_denominator_evals) - } -} - -/// Holds all the transition constraint inverse divisor evaluations over the constraint evaluation -/// domain. -/// -/// [`LagrangeKernelTransitionConstraintsDivisor`] takes advantage of some structure in the -/// divisors' evaluations. Recall that the divisor for the i'th transition constraint is `x^(2^i) - -/// 1`. When substituting `x` for each value of the constraint evaluation domain, for constraints -/// `i>0`, the divisor evaluations "wrap-around" such that some values repeat. For example, -/// -/// i=0: no repetitions -/// i=1: the first half of the buffer is equal to the second half -/// i=2: each 1/4th of the buffer are equal -/// i=3: each 1/8th of the buffer are equal -/// ... -/// Therefore, we only compute the non-repeating section of the buffer in each iteration, and index -/// into it accordingly. -struct LagrangeKernelTransitionConstraintsDivisor { - divisor_evals_inv: Vec, - - // Precompute the indices into `divisors_evals_inv` of the slices that correspond to each - // transition constraint. - // - // For example, for a CE domain size `n=8`, `slice_indices_precomputes = [0, 8, 12, 14]`, such - // that transition constraint `idx` owns the range: - // idx=0: [0, 8) - // idx=1: [8, 12) - // idx=2: [12, 14) - slice_indices_precomputes: Vec, -} - -impl LagrangeKernelTransitionConstraintsDivisor { - pub fn new( - num_lagrange_transition_constraints: usize, - domain: &StarkDomain, - ) -> Self { - let divisor_evals_inv = { - let divisor_evaluator = TransitionDivisorEvaluator::::new( - num_lagrange_transition_constraints, - domain.offset(), - ); - - // The number of divisor evaluations is - // `ce_domain_size + ce_domain_size/2 + ce_domain_size/4 + ... + ce_domain_size/(log(ce_domain_size)-1)`, - // which is slightly smaller than `ce_domain_size * 2` - let mut divisor_evals: Vec = Vec::with_capacity(domain.ce_domain_size() * 2); - - for trans_constraint_idx in 0..num_lagrange_transition_constraints { - let num_non_repeating_denoms = - domain.ce_domain_size() / 2_usize.pow(trans_constraint_idx as u32); - - for step in 0..num_non_repeating_denoms { - let divisor_eval = - divisor_evaluator.evaluate_ith_divisor(trans_constraint_idx, domain, step); - - divisor_evals.push(divisor_eval.into()); - } - } - - batch_inversion(&divisor_evals) - }; - - let slice_indices_precomputes = { - let num_indices = num_lagrange_transition_constraints + 1; - let mut slice_indices_precomputes = Vec::with_capacity(num_indices); - - slice_indices_precomputes.push(0); - - let mut current_slice_len = domain.ce_domain_size(); - for i in 1..num_indices { - let next_precompute_index = slice_indices_precomputes[i - 1] + current_slice_len; - slice_indices_precomputes.push(next_precompute_index); - - current_slice_len /= 2; - } - - slice_indices_precomputes - }; - - Self { - divisor_evals_inv, - slice_indices_precomputes, - } - } - - /// Returns the evaluation `1 / divisor`, where `divisor` is the divisor for the given - /// transition constraint, at the given row of the constraint evaluation domain - pub fn get_inverse_divisor_eval(&self, trans_constraint_idx: usize, row_idx: usize) -> E { - let inv_divisors_slice_for_constraint = - self.get_transition_constraint_slice(trans_constraint_idx); - - inv_divisors_slice_for_constraint[row_idx % inv_divisors_slice_for_constraint.len()] - } - - // HELPERS - // --------------------------------------------------------------------------------------------- - - /// Returns a slice containing all the inverse divisor evaluations for the given transition - /// constraint. - fn get_transition_constraint_slice(&self, trans_constraint_idx: usize) -> &[E] { - let start = self.slice_indices_precomputes[trans_constraint_idx]; - let end = self.slice_indices_precomputes[trans_constraint_idx + 1]; - - &self.divisor_evals_inv[start..end] - } -} - -/// Encapsulates the efficient evaluation of the Lagrange kernel transition constraints divisors. -/// -/// `s` stands for the domain offset (i.e. coset shift element). The key concept in this -/// optimization is to realize that the computation of the first transition constraint divisor can -/// be reused for all the other divisors (call the evaluations `d`). -/// -/// Specifically, each subsequent transition constraint divisor evaluation is equivalent to -/// multiplying an element `d` by a fixed number. For example, the multiplier for the transition -/// constraints are: -/// -/// - transition constraint 1's multiplier: s -/// - transition constraint 2's multiplier: s^3 -/// - transition constraint 3's multiplier: s^7 -/// - transition constraint 4's multiplier: s^15 -/// - ... -/// -/// This is what `s_precomputes` stores. -/// -/// Finally, recall that the ith Lagrange kernel divisor is `x^(2^i) - 1`. -/// [`TransitionDivisorEvaluator`] is only concerned with values of `x` in the constraint evaluation -/// domain, where the j'th element is `s * g^j`, where `g` is the group generator. To understand the -/// implementation of [`Self::evaluate_ith_divisor`], plug in `x = s * g^j` into `x^(2^i) - 1`. -pub struct TransitionDivisorEvaluator { - s_precomputes: Vec, -} - -impl TransitionDivisorEvaluator { - /// Constructs a new [`TransitionDivisorEvaluator`] - pub fn new(num_lagrange_transition_constraints: usize, domain_offset: E::BaseField) -> Self { - let s_precomputes = { - // s_precomputes = [1, s, s^3, s^7, s^15, ...] (where s = domain_offset) - let mut s_precomputes = Vec::with_capacity(num_lagrange_transition_constraints); - - let mut s_exp = E::BaseField::ONE; - for _ in 0..num_lagrange_transition_constraints { - s_precomputes.push(s_exp); - s_exp = s_exp * s_exp * domain_offset; - } - - s_precomputes - }; - - Self { s_precomputes } - } - - /// Evaluates the divisor of the `trans_constraint_idx`'th transition constraint. See - /// [`TransitionDivisorEvaluator`] for a more in-depth description of the algorithm. - pub fn evaluate_ith_divisor( - &self, - trans_constraint_idx: usize, - domain: &StarkDomain, - ce_domain_step: usize, - ) -> E::BaseField { - let domain_idx = ((1 << trans_constraint_idx) * ce_domain_step) % domain.ce_domain_size(); - - self.s_precomputes[trans_constraint_idx] * domain.get_ce_x_at(domain_idx) - - E::BaseField::ONE - } -} diff --git a/prover/src/constraints/evaluator/logup_gkr.rs b/prover/src/constraints/evaluator/logup_gkr.rs new file mode 100644 index 000000000..a729c7545 --- /dev/null +++ b/prover/src/constraints/evaluator/logup_gkr.rs @@ -0,0 +1,326 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; + +use air::{ + Air, GkrData, LagrangeConstraintsCompositionCoefficients, LagrangeKernelConstraints, + LogUpGkrEvaluator, +}; +use math::{batch_inversion, FieldElement}; + +use crate::StarkDomain; + +/// Contains a specific strategy for evaluating the Lagrange kernel and s-column boundary and +/// transition constraints. +pub struct LogUpGkrConstraintsEvaluator { + pub(crate) lagrange_kernel_constraints: LagrangeKernelConstraints, + pub(crate) gkr_data: GkrData, + pub(crate) s_col_composition_coefficient: E, + pub(crate) s_col_idx: usize, + pub(crate) l_col_idx: usize, + pub(crate) mean: E, +} + +impl LogUpGkrConstraintsEvaluator +where + E: FieldElement, +{ + /// Constructs a new [`LogUpGkrConstraintsEvaluator`]. + pub fn new>( + air: &A, + gkr_data: GkrData, + lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, + s_col_composition_coefficient: E, + ) -> Self { + let trace_info = air.trace_info(); + let s_col_idx = trace_info.s_column_idx().expect("S-column should be present"); + let l_col_idx = trace_info + .lagrange_kernel_column_idx() + .expect("Lagrange kernel should be present"); + + let c = gkr_data.compute_batched_claim(); + let mean = c / E::from(E::BaseField::from(trace_info.length() as u32)); + Self { + lagrange_kernel_constraints: air + .get_logup_gkr_evaluator() + .get_lagrange_kernel_constraints( + lagrange_composition_coefficients, + gkr_data.lagrange_kernel_rand_elements(), + ), + gkr_data, + s_col_composition_coefficient, + s_col_idx, + l_col_idx, + mean, + } + } +} + +/// Holds all the transition and boundary constraint inverse divisor evaluations over +/// the constraint evaluation domain for both the Lagrange kernel as well the s-column. +/// +/// [`LogUpGkrConstraintsDivisors`] takes advantage of some structure in the divisors' +/// evaluations for transition constraints. +/// Recall that the divisor for the i'th transition constraint is `x^(2^i) - 1`. +/// When substituting `x` for each value of the constraint evaluation domain, for constraints +/// `i>0`, the divisor evaluations "wrap-around" such that some values repeat. For example, +/// +/// i=0: no repetitions +/// i=1: the first half of the buffer is equal to the second half +/// i=2: each 1/4th of the buffer are equal +/// i=3: each 1/8th of the buffer are equal +/// ... +/// Therefore, we only compute the non-repeating section of the buffer in each iteration, and index +/// into it accordingly. +/// +/// Note that instead of storing `1 / div` for Lagrange and s-column transition and boundary +/// constraints, we store instead `c / div` where `c` is the constraint composition coefficient +/// associated to divisor `div`. We call `c / div` constraint evaluation multipliers or just +/// constraint multipliers. +pub(crate) struct LogUpGkrConstraintsDivisors { + lagrange_transition_multipliers: Vec, + + lagrange_boundary_multipliers: Vec, + + s_col_transition_multipliers: Vec, + + // Precompute the indices into `divisors_evals_inv` of the slices that correspond to each + // transition constraint. + // + // For example, for a CE domain size `n=8`, `slice_indices_precomputes = [0, 8, 12, 14]`, such + // that transition constraint `idx` owns the range: + // idx=0: [0, 8) + // idx=1: [8, 12) + // idx=2: [12, 14) + slice_indices_precomputes: Vec, +} + +impl LogUpGkrConstraintsDivisors { + pub fn new( + logup_gkr_constraints: &LogUpGkrConstraintsEvaluator, + domain: &StarkDomain, + ) -> Self { + let num_lagrange_transition_constraints = + logup_gkr_constraints.lagrange_kernel_constraints.transition.num_constraints(); + + // collect all constraint composition coefficient in order to optimize inversion + let mut lagrange_transition_cc = logup_gkr_constraints + .lagrange_kernel_constraints + .transition + .lagrange_constraint_coefficients() + .to_vec(); + let lagrange_boundary_cc = logup_gkr_constraints + .lagrange_kernel_constraints + .boundary + .constraint_composition_coefficient(); + let s_col_cc = logup_gkr_constraints.s_col_composition_coefficient; + + lagrange_transition_cc.push(lagrange_boundary_cc); + lagrange_transition_cc.push(s_col_cc); + + // batch invert + let constraint_composition_coefficients = lagrange_transition_cc; + let constraint_composition_coefficients_inv = + batch_inversion(&constraint_composition_coefficients); + + let lagrange_cc_inv = + &constraint_composition_coefficients_inv[..num_lagrange_transition_constraints]; + let lagrange_transition_multipliers = { + let divisor_evaluator = TransitionDivisorEvaluator::::new( + num_lagrange_transition_constraints, + domain.offset(), + ); + + // The number of divisor evaluations is + // `ce_domain_size + ce_domain_size/2 + ce_domain_size/4 + ... + + // ce_domain_size/(log(ce_domain_size)-1)`, + // which is slightly smaller than `ce_domain_size * 2`. + // This is also the number of multipliers `c / div` for Lagrange transition constraints + let mut multipliers: Vec = Vec::with_capacity(domain.ce_domain_size() * 2); + + for (trans_constraint_idx, cc_inv) in lagrange_cc_inv.iter().enumerate() { + let num_non_repeating_denoms = + domain.ce_domain_size() / 2_usize.pow(trans_constraint_idx as u32); + + for step in 0..num_non_repeating_denoms { + let divisor_eval = + divisor_evaluator.evaluate_ith_divisor(trans_constraint_idx, domain, step); + + multipliers.push(cc_inv.mul_base(divisor_eval)); + } + } + + batch_inversion(&multipliers) + }; + + // computes the inverse boundary divisor multiplier by the corresponding constraint + // composition at every point of the constraint evaluation domain. + // That is, returns a vector of the form `[c / div_0, ..., c / div_n]`, where `div_i` is the + // divisor for the Lagrange kernel boundary constraint against the first row at the i'th row + // of the constraint evaluation domain, and `c` is the constraint evaluation coefficient. + let lagrange_boundary_multipliers = { + let mut multipliers = Vec::with_capacity(domain.ce_domain_size()); + for step in 0..domain.ce_domain_size() { + let domain_point = domain.get_ce_x_at(step); + let boundary_denominator = domain_point - E::BaseField::ONE; + let multiplier = constraint_composition_coefficients_inv + [num_lagrange_transition_constraints] + .mul_base(boundary_denominator); + multipliers.push(multiplier); + } + + batch_inversion(&multipliers) + }; + + // compute the divisors for the s-column transition constraint + let s_col_transition_multipliers = compute_s_col_multipliers( + domain, + constraint_composition_coefficients_inv[num_lagrange_transition_constraints + 1], + ); + + let slice_indices_precomputes = { + let num_indices = num_lagrange_transition_constraints + 1; + let mut slice_indices_precomputes = Vec::with_capacity(num_indices); + + slice_indices_precomputes.push(0); + + let mut current_slice_len = domain.ce_domain_size(); + for i in 1..num_indices { + let next_precompute_index = slice_indices_precomputes[i - 1] + current_slice_len; + slice_indices_precomputes.push(next_precompute_index); + + current_slice_len /= 2; + } + + slice_indices_precomputes + }; + + Self { + lagrange_transition_multipliers, + lagrange_boundary_multipliers, + slice_indices_precomputes, + s_col_transition_multipliers, + } + } + + /// Returns the evaluation `c / divisor`, where `divisor` is the divisor for the given + /// Lagrange kernel transition constraint, at the given row of the constraint evaluation domain + /// and `c` is the corresponding constraint composition coefficient. + pub fn get_lagrange_transition_multiplier( + &self, + trans_constraint_idx: usize, + row_idx: usize, + ) -> E { + let multipliers_slice = self.get_lagrange_transition_constraint_slice(trans_constraint_idx); + + multipliers_slice[row_idx % multipliers_slice.len()] + } + + /// Returns the evaluation `c / divisor`, where `divisor` runs over all Lagrange kernel + /// boundary constraint divisors at the given row of the constraint evaluation domain and `c` + /// is the corresponding constraint composition coefficient. + pub fn get_lagrange_boundary_multiplier(&self, row_idx: usize) -> E { + self.lagrange_boundary_multipliers[row_idx % self.lagrange_boundary_multipliers.len()] + } + + /// Returns the evaluation `c / divisor`, where `divisor` is the divisor for the s-column + /// transition constraint, at the given row of the constraint evaluation domain and `c` is + /// the corresponding constraint composition coefficient. + pub fn get_s_col_transition_multiplier(&self, row_idx: usize) -> E { + self.s_col_transition_multipliers[row_idx % (self.s_col_transition_multipliers.len())] + } + + // HELPERS + // --------------------------------------------------------------------------------------------- + + /// Returns a slice containing all the multipliers evaluations' for the given Lagrange + /// transition constraint. + fn get_lagrange_transition_constraint_slice(&self, trans_constraint_idx: usize) -> &[E] { + let start = self.slice_indices_precomputes[trans_constraint_idx]; + let end = self.slice_indices_precomputes[trans_constraint_idx + 1]; + + &self.lagrange_transition_multipliers[start..end] + } +} + +/// Encapsulates the efficient evaluation of the Lagrange kernel transition constraints divisors. +/// +/// `s` stands for the domain offset (i.e. coset shift element). The key concept in this +/// optimization is to realize that the computation of the first transition constraint divisor can +/// be reused for all the other divisors (call the evaluations `d`). +/// +/// Specifically, each subsequent transition constraint divisor evaluation is equivalent to +/// multiplying an element `d` by a fixed number. For example, the multiplier for the transition +/// constraints are: +/// +/// - transition constraint 1's multiplier: s +/// - transition constraint 2's multiplier: s^3 +/// - transition constraint 3's multiplier: s^7 +/// - transition constraint 4's multiplier: s^15 +/// - ... +/// +/// This is what `s_precomputes` stores. +/// +/// Finally, recall that the ith Lagrange kernel divisor is `x^(2^i) - 1`. +/// [`TransitionDivisorEvaluator`] is only concerned with values of `x` in the constraint evaluation +/// domain, where the j'th element is `s * g^j`, where `g` is the group generator. To understand the +/// implementation of [`Self::evaluate_ith_divisor`], plug in `x = s * g^j` into `x^(2^i) - 1`. +pub struct TransitionDivisorEvaluator { + s_precomputes: Vec, +} + +impl TransitionDivisorEvaluator { + /// Constructs a new [`TransitionDivisorEvaluator`] + pub fn new(num_lagrange_transition_constraints: usize, domain_offset: E::BaseField) -> Self { + let s_precomputes = { + // s_precomputes = [1, s, s^3, s^7, s^15, ...] (where s = domain_offset) + let mut s_precomputes = Vec::with_capacity(num_lagrange_transition_constraints); + + let mut s_exp = E::BaseField::ONE; + for _ in 0..num_lagrange_transition_constraints { + s_precomputes.push(s_exp); + s_exp = s_exp * s_exp * domain_offset; + } + + s_precomputes + }; + + Self { s_precomputes } + } + + /// Evaluates the divisor of the `trans_constraint_idx`'th transition constraint. See + /// [`TransitionDivisorEvaluator`] for a more in-depth description of the algorithm. + pub fn evaluate_ith_divisor( + &self, + trans_constraint_idx: usize, + domain: &StarkDomain, + ce_domain_step: usize, + ) -> E::BaseField { + let domain_idx = ((1 << trans_constraint_idx) * ce_domain_step) % domain.ce_domain_size(); + + self.s_precomputes[trans_constraint_idx] * domain.get_ce_x_at(domain_idx) + - E::BaseField::ONE + } +} + +/// Computes the evaluations of the s-column multipliers. +/// +/// The divisor for the s-column is $X^n - 1$ where $n$ is the trace length. This means that +/// we need only compute `ce_blowup` many values and thus only that many exponentiations. +fn compute_s_col_multipliers( + domain: &StarkDomain, + composition_coef_inv: E, +) -> Vec { + let degree = domain.trace_length() as u32; + let mut result = Vec::with_capacity(domain.trace_to_ce_blowup()); + + for row in 0..domain.trace_to_ce_blowup() { + let divisor = domain.get_ce_x_at(row).exp(degree.into()) - E::BaseField::ONE; + + result.push(composition_coef_inv.mul_base(divisor)); + } + batch_inversion(&result) +} diff --git a/prover/src/constraints/evaluator/logup_gkr_evaluator.rs b/prover/src/constraints/evaluator/logup_gkr_evaluator.rs new file mode 100644 index 000000000..2c26e735b --- /dev/null +++ b/prover/src/constraints/evaluator/logup_gkr_evaluator.rs @@ -0,0 +1,433 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use air::{ + Air, AuxRandElements, ConstraintCompositionCoefficients, EvaluationFrame, + LagrangeKernelEvaluationFrame, LogUpGkrEvaluator, TransitionConstraints, +}; +use math::FieldElement; +use tracing::instrument; +use utils::iter_mut; +#[cfg(feature = "concurrent")] +use utils::{iterators::*, rayon}; + +use super::{ + super::EvaluationTableFragment, + logup_gkr::{LogUpGkrConstraintsDivisors, LogUpGkrConstraintsEvaluator}, + BoundaryConstraints, CompositionPolyTrace, ConstraintEvaluationTable, ConstraintEvaluator, + PeriodicValueTable, StarkDomain, TraceLde, +}; + +// CONSTANTS +// ================================================================================================ + +#[cfg(feature = "concurrent")] +const MIN_CONCURRENT_DOMAIN_SIZE: usize = 8192; + +// DEFAULT CONSTRAINT EVALUATOR +// ================================================================================================ + +/// Default implementation of the [ConstraintEvaluator] trait. +/// +/// This implementation iterates over all evaluation frames of an extended execution trace and +/// evaluates constraints over these frames one-by-one. Constraint evaluations are merged together +/// using random linear combinations and in the end, only a single column is returned. +/// +/// When `concurrent` feature is enabled, the extended execution trace is split into sets of +/// sequential evaluation frames (called fragments), and frames in each fragment are evaluated +/// in separate threads. +pub struct LogUpGkrConstraintEvaluator<'a, A: Air, E: FieldElement> { + air: &'a A, + boundary_constraints: BoundaryConstraints, + transition_constraints: TransitionConstraints, + periodic_values: PeriodicValueTable, + logup_gkr_constraints_evaluator: LogUpGkrConstraintsEvaluator, + aux_rand_elements: AuxRandElements, +} + +impl<'a, A, E> ConstraintEvaluator for LogUpGkrConstraintEvaluator<'a, A, E> +where + A: Air, + E: FieldElement, +{ + type Air = A; + + #[instrument( + skip_all, + name = "evaluate_constraints", + fields( + ce_domain_size = %domain.ce_domain_size() + ) + )] + fn evaluate>( + self, + trace: &T, + domain: &StarkDomain<::BaseField>, + ) -> CompositionPolyTrace { + assert_eq!( + trace.trace_len(), + domain.lde_domain_size(), + "extended trace length is not consistent with evaluation domain" + ); + + // build a list of constraint divisors; currently, all transition constraints have the same + // divisor which we put at the front of the list; boundary constraint divisors are appended + // after that + let mut divisors = vec![self.transition_constraints.divisor().clone()]; + divisors.append(&mut self.boundary_constraints.get_divisors()); + + // build the divisors related to LogUp-GKR + let logup_gkr_constraints_divisors = + LogUpGkrConstraintsDivisors::::new(&self.logup_gkr_constraints_evaluator, domain); + + // allocate space for constraint evaluations; when we are in debug mode, we also allocate + // memory to hold all transition constraint evaluations (before they are merged into a + // single value) so that we can check their degrees later + #[cfg(not(debug_assertions))] + let mut evaluation_table = ConstraintEvaluationTable::::new(domain, divisors, true); + #[cfg(debug_assertions)] + let mut evaluation_table = ConstraintEvaluationTable::::new( + domain, + divisors, + &self.transition_constraints, + true, + ); + + // when `concurrent` feature is enabled, break the evaluation table into multiple fragments + // to evaluate them into multiple threads; unless the constraint evaluation domain is small, + // then don't bother with concurrent evaluation + + #[cfg(not(feature = "concurrent"))] + let num_fragments = 1; + + #[cfg(feature = "concurrent")] + let num_fragments = if domain.ce_domain_size() >= MIN_CONCURRENT_DOMAIN_SIZE { + rayon::current_num_threads().next_power_of_two() + } else { + 1 + }; + + // evaluate constraints for each fragment; if the trace consist of multiple segments + // we evaluate constraints for all segments. otherwise, we evaluate constraints only + // for the main segment. + let mut fragments = evaluation_table.fragments(num_fragments); + iter_mut!(fragments).for_each(|fragment| { + self.evaluate_fragment_full(trace, domain, fragment, &logup_gkr_constraints_divisors); + }); + + // when in debug mode, make sure expected transition constraint degrees align with + // actual degrees we got during constraint evaluation + #[cfg(debug_assertions)] + evaluation_table.validate_transition_degrees(); + + CompositionPolyTrace::new(evaluation_table.combine()) + } +} + +impl<'a, A, E> LogUpGkrConstraintEvaluator<'a, A, E> +where + A: Air, + E: FieldElement, +{ + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + /// Returns a new evaluator which can be used to evaluate transition and boundary constraints + /// over extended execution trace. + pub fn new( + air: &'a A, + aux_rand_elements: AuxRandElements, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self { + assert!( + air.context().logup_gkr_enabled(), + "`LogUpGkrConstraintEvaluator` can only be used when LogUp-GKR is enabled" + ); + + // build transition constraint groups; these will be used to compose transition constraint + // evaluations + let transition_constraints = + air.get_transition_constraints(&composition_coefficients.transition); + // build periodic value table + let periodic_values = PeriodicValueTable::new(air); + + // build boundary constraint groups; these will be used to evaluate and compose boundary + // constraint evaluations. + let boundary_constraints = BoundaryConstraints::new( + air, + Some(&aux_rand_elements), + &composition_coefficients.boundary, + ); + + let logup_gkr_constraints_evaluator = LogUpGkrConstraintsEvaluator::new( + air, + aux_rand_elements + .gkr_data() + .expect("expected LogUp-GKR randomness to be present"), + composition_coefficients + .lagrange + .expect("expected Lagrange kernel composition coefficients to be present"), + composition_coefficients + .s_col + .expect("expected s-column composition coefficient to be present"), + ); + air.trace_info(); + + Self { + air, + boundary_constraints, + transition_constraints, + logup_gkr_constraints_evaluator, + aux_rand_elements, + periodic_values, + } + } + + // EVALUATION HELPER + // -------------------------------------------------------------------------------------------- + + /// Evaluates constraints for a single fragment of the evaluation table. + /// + /// This evaluates constraints only over all segments of the execution trace (i.e. main segment + /// and all auxiliary segments). + fn evaluate_fragment_full>( + &self, + trace: &T, + domain: &StarkDomain, + fragment: &mut EvaluationTableFragment, + logup_gkr_divisors: &LogUpGkrConstraintsDivisors, + ) { + // initialize buffers to hold trace values and evaluation results at each step + let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); + let mut aux_frame = EvaluationFrame::new(trace.trace_info().aux_segment_width()); + let mut tm_evaluations = vec![E::BaseField::ZERO; self.num_main_transition_constraints()]; + let mut ta_evaluations = vec![E::ZERO; self.num_aux_transition_constraints()]; + let mut evaluations = vec![E::ZERO; fragment.num_columns()]; + let mut lagrange_frame = LagrangeKernelEvaluationFrame::new(trace.trace_info().length()); + + let evaluator = self.air.get_logup_gkr_evaluator(); + let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + + // this will be used to convert steps in constraint evaluation domain to steps in + // LDE domain + let lde_shift = domain.ce_to_lde_blowup().trailing_zeros(); + + for i in 0..fragment.num_rows() { + let step = i + fragment.offset(); + + // read both the main and the auxiliary evaluation frames from the trace + trace.read_main_trace_frame_into(step << lde_shift, &mut main_frame); + trace.read_lagrange_kernel_frame_into( + step << lde_shift, + self.logup_gkr_constraints_evaluator.l_col_idx, + &mut lagrange_frame, + ); + trace.read_aux_trace_frame_into(step << lde_shift, &mut aux_frame); + + // evaluate transition constraints and save the merged result the first slot of the + // evaluations buffer; we evaluate and compose constraints in the same function, we + // can just add up the results of evaluating main and auxiliary constraints. + evaluations[0] = self.evaluate_main_transition(&main_frame, step, &mut tm_evaluations); + evaluations[0] += + self.evaluate_aux_transition(&main_frame, &aux_frame, step, &mut ta_evaluations); + + // when in debug mode, save transition constraint evaluations + #[cfg(debug_assertions)] + fragment.update_transition_evaluations(i, &tm_evaluations, &ta_evaluations); + + // evaluate Lagrange kernel constraints and assign them to the last column + *evaluations.last_mut().expect("should contain at least one entry") = self + .evaluate_s_column_transition( + &evaluator, + &main_frame, + &aux_frame, + &mut query, + logup_gkr_divisors.get_s_col_transition_multiplier(step), + ); + // evaluate s-column constraints and add them to the last column + *evaluations.last_mut().expect("should contain at least one entry") += + self.evaluate_lagrange_transition(&lagrange_frame, step, logup_gkr_divisors); + + // evaluate boundary constraints; the results go into remaining slots of the + // evaluations buffer + let main_state = main_frame.current(); + let aux_state = aux_frame.current(); + let limit = evaluations.len() - 1; + self.boundary_constraints.evaluate_all( + main_state, + aux_state, + domain, + step, + &mut evaluations[1..limit], + ); + + // record the result in the evaluation table + fragment.update_row(i, &evaluations); + } + } + + // TRANSITION CONSTRAINT EVALUATOR + // -------------------------------------------------------------------------------------------- + + /// Evaluates transition constraints of the main execution trace at the specified step of the + /// constraint evaluation domain. + /// + /// `x` is the corresponding domain value at the specified step. That is, x = s * g^step, + /// where g is the generator of the constraint evaluation domain, and s is the domain offset. + fn evaluate_main_transition( + &self, + main_frame: &EvaluationFrame, + step: usize, + evaluations: &mut [E::BaseField], + ) -> E { + // TODO: use a more efficient way to zero out memory + evaluations.fill(E::BaseField::ZERO); + + // get periodic values at the evaluation step + let periodic_values = self.periodic_values.get_row(step); + + // evaluate transition constraints over the main segment of the execution trace and save + // the results into evaluations buffer + self.air.evaluate_transition(main_frame, periodic_values, evaluations); + + // merge transition constraint evaluations into a single value and return it; + // we can do this here because all transition constraints have the same divisor. + evaluations + .iter() + .zip(self.transition_constraints.main_constraint_coef().iter()) + .fold(E::ZERO, |acc, (&const_eval, &coef)| acc + coef.mul_base(const_eval)) + } + + /// Evaluates all transition constraints (i.e., for main and the auxiliary trace segment) at the + /// specified step of the constraint evaluation domain. + /// + /// `x` is the corresponding domain value at the specified step. That is, x = s * g^step, + /// where g is the generator of the constraint evaluation domain, and s is the domain offset. + fn evaluate_aux_transition( + &self, + main_frame: &EvaluationFrame, + aux_frame: &EvaluationFrame, + step: usize, + evaluations: &mut [E], + ) -> E { + // TODO: use a more efficient way to zero out memory + evaluations.fill(E::ZERO); + + // get periodic values at the evaluation step + let periodic_values = self.periodic_values.get_row(step); + + // evaluate transition constraints over the auxiliary trace segment and save the results into + // evaluations buffer + self.air.evaluate_aux_transition( + main_frame, + aux_frame, + periodic_values, + &self.aux_rand_elements, + evaluations, + ); + + // merge transition constraint evaluations into a single value and return it; + // we can do this here because all transition constraints have the same divisor. + let evaluation = evaluations + .iter() + .zip(self.transition_constraints.aux_constraint_coef().iter()) + .fold(E::ZERO, |acc, (&const_eval, &coef)| acc + coef * const_eval); + + evaluation + } + + /// Computes the transition and boundary constraints for the Lagrange kernel. + fn evaluate_lagrange_transition( + &self, + lagrange_frame: &LagrangeKernelEvaluationFrame, + step: usize, + constraints_divisors: &LogUpGkrConstraintsDivisors, + ) -> E { + // Compute the combined transition and boundary constraints evaluations for this row + + let mut combined_evaluations = E::ZERO; + + // combine transition constraints + for trans_constraint_idx in 0..self + .logup_gkr_constraints_evaluator + .lagrange_kernel_constraints + .transition + .num_constraints() + { + let numerator = self + .logup_gkr_constraints_evaluator + .lagrange_kernel_constraints + .transition + .evaluate_ith_numerator( + lagrange_frame, + &self.logup_gkr_constraints_evaluator.gkr_data.lagrange_kernel_eval_point, + trans_constraint_idx, + ); + let multiplier = + constraints_divisors.get_lagrange_transition_multiplier(trans_constraint_idx, step); + + combined_evaluations += numerator * multiplier; + } + + // combine boundary constraints + { + let boundary_numerator = self + .logup_gkr_constraints_evaluator + .lagrange_kernel_constraints + .boundary + .evaluate_numerator_at(lagrange_frame); + + combined_evaluations += + boundary_numerator * constraints_divisors.get_lagrange_boundary_multiplier(step); + } + + combined_evaluations + } + + /// Computes the transition constraints for the s-column. + /// + /// The s-column implements the cohomological sum-check argument of [1] and + /// the constraint we enfore is exactly Eq (4) in Lemma 1 in [1]. + /// + /// [1]: https://eprint.iacr.org/2021/930 + fn evaluate_s_column_transition( + &self, + evaluator: &impl LogUpGkrEvaluator, + main_frame: &EvaluationFrame, + aux_frame: &EvaluationFrame, + query: &mut [E::BaseField], + multiplier: E, + ) -> E { + let l_col_idx = self.logup_gkr_constraints_evaluator.l_col_idx; + let s_col_idx = self.logup_gkr_constraints_evaluator.s_col_idx; + let mean = self.logup_gkr_constraints_evaluator.mean; + + let l_cur = aux_frame.current()[l_col_idx]; + let s_cur = aux_frame.current()[s_col_idx]; + let s_nxt = aux_frame.next()[s_col_idx]; + + evaluator.build_query(main_frame, query); + let batched_query = + self.logup_gkr_constraints_evaluator.gkr_data.compute_batched_query(query); + + let rhs = s_cur - mean + batched_query * l_cur; + let lhs = s_nxt; + + (rhs - lhs) * multiplier + } + + // ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the number of transition constraints applied against the main segment of the + /// execution trace. + fn num_main_transition_constraints(&self) -> usize { + self.transition_constraints.num_main_constraints() + } + + /// Returns the number of transition constraints applied against the auxiliary trace segment. + fn num_aux_transition_constraints(&self) -> usize { + self.transition_constraints.num_aux_constraints() + } +} diff --git a/prover/src/constraints/evaluator/mod.rs b/prover/src/constraints/evaluator/mod.rs index da8a166c2..ce0488a57 100644 --- a/prover/src/constraints/evaluator/mod.rs +++ b/prover/src/constraints/evaluator/mod.rs @@ -14,7 +14,10 @@ pub use default::DefaultConstraintEvaluator; mod boundary; use boundary::BoundaryConstraints; -mod lagrange; +mod logup_gkr; + +mod logup_gkr_evaluator; +pub use logup_gkr_evaluator::LogUpGkrConstraintEvaluator; mod periodic_table; use periodic_table::PeriodicValueTable; diff --git a/prover/src/constraints/evaluator/periodic_table.rs b/prover/src/constraints/evaluator/periodic_table.rs index ec72aa766..f1fc751e0 100644 --- a/prover/src/constraints/evaluator/periodic_table.rs +++ b/prover/src/constraints/evaluator/periodic_table.rs @@ -94,7 +94,7 @@ mod tests { use air::Air; use math::{ - fields::f128::BaseElement, get_power_series_with_offset, polynom, FieldElement, StarkField, + fields::f64::BaseElement, get_power_series_with_offset, polynom, FieldElement, StarkField, }; use crate::tests::MockAir; @@ -104,8 +104,8 @@ mod tests { let trace_length = 32; // instantiate AIR with 2 periodic columns - let col1 = vec![1u128, 2].into_iter().map(BaseElement::new).collect::>(); - let col2 = vec![3u128, 4, 5, 6].into_iter().map(BaseElement::new).collect::>(); + let col1 = vec![1u64, 2].into_iter().map(BaseElement::new).collect::>(); + let col2 = vec![3u64, 4, 5, 6].into_iter().map(BaseElement::new).collect::>(); let air = MockAir::with_periodic_columns(vec![col1, col2], trace_length); // build a table of periodic values diff --git a/prover/src/constraints/mod.rs b/prover/src/constraints/mod.rs index 566065f0f..9cf84c3dd 100644 --- a/prover/src/constraints/mod.rs +++ b/prover/src/constraints/mod.rs @@ -6,7 +6,7 @@ use super::{ColMatrix, ConstraintDivisor, RowMatrix, StarkDomain}; mod evaluator; -pub use evaluator::{ConstraintEvaluator, DefaultConstraintEvaluator}; +pub use evaluator::{ConstraintEvaluator, DefaultConstraintEvaluator, LogUpGkrConstraintEvaluator}; mod composition_poly; pub use composition_poly::{CompositionPoly, CompositionPolyTrace}; diff --git a/prover/src/errors.rs b/prover/src/errors.rs index a0d01a233..3a14de46e 100644 --- a/prover/src/errors.rs +++ b/prover/src/errors.rs @@ -21,6 +21,8 @@ pub enum ProverError { /// This error occurs when the base field specified by the AIR does not support field extension /// of degree specified by proof options. UnsupportedFieldExtension(usize), + /// This error occurs when generation of the GKR proof for the LogUp relation fails. + FailedToGenerateGkrProof, } impl fmt::Display for ProverError { @@ -36,6 +38,9 @@ impl fmt::Display for ProverError { Self::UnsupportedFieldExtension(degree) => { write!(f, "field extension of degree {degree} is not supported for the specified base field") } + ProverError::FailedToGenerateGkrProof => { + write!(f, "Failed to generate the GKR proof for the LogUp relation") + } } } } diff --git a/prover/src/lib.rs b/prover/src/lib.rs index ac0e82be2..d9da99dd5 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -48,7 +48,7 @@ pub use air::{ EvaluationFrame, FieldExtension, LagrangeKernelRandElements, ProofOptions, TraceInfo, TransitionConstraintDegree, }; -use air::{AuxRandElements, GkrRandElements}; +use air::{AuxRandElements, GkrData, LogUpGkrEvaluator}; pub use crypto; use crypto::{ElementHasher, RandomCoin, VectorCommitment}; use fri::FriProver; @@ -58,6 +58,7 @@ use math::{ fields::{CubeExtension, QuadExtension}, ExtensibleField, FieldElement, StarkField, ToElements, }; +use sumcheck::FinalOpeningClaim; use tracing::{event, info_span, instrument, Level}; pub use utils::{ iterators, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, @@ -73,7 +74,7 @@ use matrix::{ColMatrix, RowMatrix}; mod constraints; pub use constraints::{ CompositionPoly, CompositionPolyTrace, ConstraintCommitment, ConstraintEvaluator, - DefaultConstraintEvaluator, + DefaultConstraintEvaluator, LogUpGkrConstraintEvaluator, }; mod composer; @@ -86,6 +87,9 @@ pub use trace::{ TraceTableFragment, }; +mod logup_gkr; +pub use logup_gkr::{build_lagrange_column, build_s_column, prove_gkr}; + mod channel; use channel::ProverChannel; @@ -101,9 +105,6 @@ pub mod tests; // this segment width seems to give the best performance for small fields (i.e., 64 bits) const DEFAULT_SEGMENT_WIDTH: usize = 8; -/// Accesses the `GkrProof` type in a [`Prover`]. -pub type ProverGkrProof

= <

::Air as Air>::GkrProof; - /// Defines a STARK prover for a computation. /// /// A STARK prover can be used to generate STARK proofs. The prover contains definitions of a @@ -201,28 +202,11 @@ pub trait Prover { // PROVIDED METHODS // -------------------------------------------------------------------------------------------- - /// Builds the GKR proof. If the [`Air`] doesn't use a GKR proof, leave unimplemented. - #[allow(unused_variables)] - #[maybe_async] - fn generate_gkr_proof( - &self, - main_trace: &Self::Trace, - public_coin: &mut Self::RandomCoin, - ) -> (ProverGkrProof, GkrRandElements) - where - E: FieldElement, - { - unimplemented!("`Prover::generate_gkr_proof` needs to be implemented when the auxiliary trace has a Lagrange kernel column.") - } - /// Builds and returns the auxiliary trace. #[allow(unused_variables)] #[maybe_async] - fn build_aux_trace( - &self, - main_trace: &Self::Trace, - aux_rand_elements: &AuxRandElements, - ) -> ColMatrix + #[instrument(skip_all)] + fn build_aux_trace(&self, main_trace: &Self::Trace, aux_rand_elements: &[E]) -> ColMatrix where E: FieldElement, { @@ -241,7 +225,6 @@ pub trait Prover { fn prove(&self, trace: Self::Trace) -> Result where ::PublicInputs: Send, - ::GkrProof: Send, { // figure out which version of the generic proof generation procedure to run. this is a sort // of static dispatch for selecting two generic parameter: extension field and hash @@ -275,7 +258,6 @@ pub trait Prover { where E: FieldElement, ::PublicInputs: Send, - ::GkrProof: Send, { // 0 ----- instantiate AIR and prover channel --------------------------------------------- @@ -314,27 +296,40 @@ pub trait Prover { // build the auxiliary trace segment, and append the resulting segments to trace commitment // and trace polynomial table structs let aux_trace_with_metadata = if air.trace_info().is_multi_segment() { - let (gkr_proof, aux_rand_elements) = if air.context().has_lagrange_kernel_aux_column() { - let (gkr_proof, gkr_rand_elements) = - maybe_await!(self.generate_gkr_proof(&trace, channel.public_coin())); - - let rand_elements = air - .get_aux_rand_elements(channel.public_coin()) - .expect("failed to draw random elements for the auxiliary trace segment"); - - let aux_rand_elements = - AuxRandElements::new_with_gkr(rand_elements, gkr_rand_elements); - - (Some(gkr_proof), aux_rand_elements) + // build the auxiliary segment without the LogUp-GKR related part + let aux_rand_elements = air + .get_aux_rand_elements(channel.public_coin()) + .expect("failed to draw random elements for the auxiliary trace segment"); + let mut aux_trace = maybe_await!(self.build_aux_trace(&trace, &aux_rand_elements)); + + // build the LogUp-GKR related section of the auxiliary segment, if any. This will also + // build an object containing randomness and data related to the LogUp-GKR section of + // the auxiliary trace segment. + let (gkr_proof, gkr_rand_elements) = if air.context().logup_gkr_enabled() { + let gkr_proof = + prove_gkr(&trace, &air.get_logup_gkr_evaluator(), channel.public_coin()) + .map_err(|_| ProverError::FailedToGenerateGkrProof)?; + + let FinalOpeningClaim { eval_point, openings } = + gkr_proof.get_final_opening_claim(); + + let gkr_data = air + .get_logup_gkr_evaluator() + .generate_univariate_iop_for_multi_linear_opening_data( + openings, + eval_point, + channel.public_coin(), + ); + + // add the extra columns required for LogUp-GKR + maybe_await!(build_logup_gkr_columns(&air, &trace, &mut aux_trace, &gkr_data)); + + (Some(gkr_proof), Some(gkr_data)) } else { - let rand_elements = air - .get_aux_rand_elements(channel.public_coin()) - .expect("failed to draw random elements for the auxiliary trace segment"); - - (None, AuxRandElements::new(rand_elements)) + (None, None) }; - - let aux_trace = maybe_await!(self.build_aux_trace(&trace, &aux_rand_elements)); + // build the set of all random values associated to the auxiliary segment + let aux_rand_elements = AuxRandElements::new(aux_rand_elements, gkr_rand_elements); // commit to the auxiliary trace segment let aux_segment_polys = { @@ -352,7 +347,7 @@ pub trait Prover { }; trace_polys - .add_aux_segment(aux_segment_polys, air.context().lagrange_kernel_aux_column_idx()); + .add_aux_segment(aux_segment_polys, air.context().lagrange_kernel_column_idx()); Some(AuxTraceWithMetadata { aux_trace, aux_rand_elements, gkr_proof }) } else { @@ -616,3 +611,27 @@ pub trait Prover { (constraint_commitment, composition_poly) } } + +/// Builds and appends to the auxiliary segment two additional columns needed for implementing +/// the univariate IOP for multi-linear evaluation of Section 5 in [1]. +/// +/// [1]: https://eprint.iacr.org/2023/1284 +#[maybe_async] +#[instrument(skip_all)] +fn build_logup_gkr_columns( + air: &A, + main_trace: &T, + aux_trace: &mut ColMatrix, + gkr_data: &GkrData, +) where + E: FieldElement, + A: Air, + T: Trace, +{ + let evaluator = air.get_logup_gkr_evaluator(); + let lagrange_col = build_lagrange_column(&gkr_data.lagrange_kernel_eval_point); + let s_col = build_s_column(main_trace, gkr_data, &evaluator, &lagrange_col); + + aux_trace.merge_column(s_col); + aux_trace.merge_column(lagrange_col); +} diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs new file mode 100644 index 000000000..4d2ee975a --- /dev/null +++ b/prover/src/logup_gkr/mod.rs @@ -0,0 +1,514 @@ +use alloc::vec::Vec; +use core::ops::Add; + +use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; +use math::FieldElement; +use sumcheck::{EqFunction, MultiLinearPoly, SumCheckProverError}; +use tracing::instrument; +use utils::{ + batch_iter_mut, chunks, uninit_vector, ByteReader, ByteWriter, Deserializable, + DeserializationError, Serializable, +}; + +use crate::Trace; + +mod prover; +pub use prover::prove_gkr; +#[cfg(feature = "concurrent")] +pub use utils::{ + rayon::{current_num_threads as rayon_num_threads, prelude::*}, + {chunks_mut, iter, iter_mut}, +}; + +// EVALUATED CIRCUIT +// ================================================================================================ + +/// Evaluation of a layered circuit for computing a sum of fractions. +/// +/// The circuit computes a sum of fractions based on the formula a / c + b / d = (a * d + b * c) / +/// (c * d) which defines a "gate" ((a, b), (c, d)) --> (a * d + b * c, c * d) upon which the +/// [`EvaluatedCircuit`] is built. Due to the uniformity of the circuit, each of the circuit +/// layers collect all the: +/// +/// 1. `a`'s into a [`MultiLinearPoly`] called `left_numerators`. +/// 2. `b`'s into a [`MultiLinearPoly`] called `right_numerators`. +/// 3. `c`'s into a [`MultiLinearPoly`] called `left_denominators`. +/// 4. `d`'s into a [`MultiLinearPoly`] called `right_denominators`. +/// +/// The relation between two subsequent layers is given by the formula +/// +/// p_0[layer + 1](x_0, x_1, ..., x_{ν - 2}) = p_0[layer](x_0, x_1, ..., x_{ν - 2}, 0) * +/// q_1[layer](x_0, x_1, ..., x_{ν - 2}, 0) +/// + p_1[layer](x_0, x_1, ..., x_{ν - 2}, 0) * q_0[layer](x_0, +/// x_1, ..., x_{ν - 2}, 0) +/// +/// p_1[layer + 1](x_0, x_1, ..., x_{ν - 2}) = p_0[layer](x_0, x_1, ..., x_{ν - 2}, 1) * +/// q_1[layer](x_0, x_1, ..., x_{ν - 2}, 1) +/// + p_1[layer](x_0, x_1, ..., x_{ν - 2}, 1) * q_0[layer](x_0, +/// x_1, ..., x_{ν - 2}, 1) +/// +/// and +/// +/// q_0[layer + 1](x_0, x_1, ..., x_{ν - 2}) = q_0[layer](x_0, x_1, ..., x_{ν - 2}, 0) * +/// q_1[layer](x_0, x_1, ..., x_{ν - 1}, 0) +/// q_1[layer + 1](x_0, x_1, ..., x_{ν - 2}) = q_0[layer](x_0, x_1, ..., x_{ν - 2}, 1) * +/// q_1[layer](x_0, x_1, ..., x_{ν - 1}, 1) +/// +/// This logic is encoded in [`CircuitWire`]. +/// +/// This means that layer ν will be the output layer and will consist of four values +/// (p_0[ν - 1], p_1[ν - 1], p_0[ν - 1], p_1[ν - 1]) ∈ 𝔽^ν. +pub struct EvaluatedCircuit { + layer_polys: Vec>, +} + +impl EvaluatedCircuit { + /// Creates a new [`EvaluatedCircuit`] by evaluating the circuit where the input layer is + /// defined from the main trace columns. + #[instrument(skip_all, name = "evaluate_logup_gkr_circuit")] + pub fn new( + main_trace_columns: &impl Trace, + evaluator: &impl LogUpGkrEvaluator, + log_up_randomness: &[E], + ) -> Result { + let mut layer_polys = Vec::new(); + + let mut current_layer = + Self::generate_input_layer(main_trace_columns, evaluator, log_up_randomness); + while current_layer.num_wires() > 1 { + let next_layer = Self::compute_next_layer(¤t_layer); + + layer_polys.push(CircuitLayerPolys::from_circuit_layer(current_layer)); + + current_layer = next_layer; + } + + Ok(Self { layer_polys }) + } + + /// Returns all layers of the evaluated circuit, starting from the input layer. + /// + /// Note that the return type is a slice of [`CircuitLayerPolys`] as opposed to + /// [`CircuitLayer`], since the evaluated layers are stored in a representation which can be + /// proved using GKR. + pub fn layers(self) -> Vec> { + self.layer_polys + } + + /// Returns the numerator/denominator polynomials representing the output layer of the circuit. + pub fn output_layer(&self) -> &CircuitLayerPolys { + self.layer_polys.last().expect("circuit has at least one layer") + } + + /// Evaluates the output layer at `query`, where the numerators of the output layer are treated + /// as evaluations of a multilinear polynomial, and similarly for the denominators. + pub fn evaluate_output_layer(&self, query: E) -> (E, E) { + let CircuitLayerPolys { numerators, denominators } = self.output_layer(); + + (numerators.evaluate(&[query]), denominators.evaluate(&[query])) + } + + // HELPERS + // ------------------------------------------------------------------------------------------- + + /// Generates the input layer of the circuit from the main trace columns and some randomness + /// provided by the verifier. + fn generate_input_layer( + trace: &impl Trace, + evaluator: &impl LogUpGkrEvaluator, + log_up_randomness: &[E], + ) -> CircuitLayer { + let num_fractions = evaluator.get_num_fractions(); + let periodic_values = evaluator.build_periodic_values(); + + let mut input_layer_wires = + unsafe { uninit_vector(trace.main_segment().num_rows() * num_fractions) }; + let num_cols = trace.main_segment().num_cols(); + let num_oracles = evaluator.get_oracles().len(); + let num_periodic_cols = periodic_values.num_columns(); + + batch_iter_mut!( + &mut input_layer_wires, + 1024, + |batch: &mut [CircuitWire], batch_offset: usize| { + let mut main_frame = EvaluationFrame::new(num_cols); + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut periodic_values_row = vec![E::BaseField::ZERO; num_periodic_cols]; + let mut numerators = vec![E::ZERO; num_fractions]; + let mut denominators = vec![E::ZERO; num_fractions]; + + let row_offset = batch_offset / num_fractions; + let batch_size = batch.len(); + let num_rows_per_batch = batch_size / num_fractions; + + for i in + (0..trace.main_segment().num_rows()).skip(row_offset).take(num_rows_per_batch) + { + trace.read_main_frame(i, &mut main_frame); + periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); + evaluator.build_query(&main_frame, &mut query); + + evaluator.evaluate_query( + &query, + &periodic_values_row, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + + let n = (i - row_offset) * num_fractions; + for ((wire, numerator), denominator) in batch[n..n + num_fractions] + .iter_mut() + .zip(numerators.iter()) + .zip(denominators.iter()) + { + *wire = CircuitWire::new(*numerator, *denominator); + } + } + } + ); + + CircuitLayer::new(input_layer_wires) + } + + /// Computes the subsequent layer of the circuit from a given layer. + fn compute_next_layer(prev_layer: &CircuitLayer) -> CircuitLayer { + let next_layer_wires = chunks!(prev_layer.wires(), 2) + .map(|input_wires| { + let left_input_wire = input_wires[0]; + let right_input_wire = input_wires[1]; + + // output wire + left_input_wire + right_input_wire + }) + .collect(); + + CircuitLayer::new(next_layer_wires) + } +} + +// CIRCUIT LAYER POLYS +// =============================================================================================== + +/// Holds a layer of an [`EvaluatedCircuit`] in a representation amenable to proving circuit +/// evaluation using GKR. +#[derive(Clone, Debug)] +pub struct CircuitLayerPolys { + pub numerators: MultiLinearPoly, + pub denominators: MultiLinearPoly, +} + +impl CircuitLayerPolys +where + E: FieldElement, +{ + pub fn from_circuit_layer(layer: CircuitLayer) -> Self { + Self::from_wires(layer.wires) + } + + pub fn from_wires(wires: Vec>) -> Self { + let mut numerators = Vec::new(); + let mut denominators = Vec::new(); + + for wire in wires { + numerators.push(wire.numerator); + denominators.push(wire.denominator); + } + + Self { + numerators: MultiLinearPoly::from_evaluations(numerators), + denominators: MultiLinearPoly::from_evaluations(denominators), + } + } + + fn into_numerators_denominators(self) -> (MultiLinearPoly, MultiLinearPoly) { + (self.numerators, self.denominators) + } +} + +impl Serializable for CircuitLayerPolys +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { numerators, denominators } = self; + numerators.write_into(target); + denominators.write_into(target); + } +} + +impl Deserializable for CircuitLayerPolys +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + numerators: MultiLinearPoly::read_from(source)?, + denominators: MultiLinearPoly::read_from(source)?, + }) + } +} + +// CIRCUIT LAYER +// =============================================================================================== + +/// Represents a layer in a [`EvaluatedCircuit`]. +/// +/// A layer is made up of a set of `n` wires, where `n` is a power of two. This is the natural +/// circuit representation of a layer, where each consecutive pair of wires are summed to yield a +/// wire in the subsequent layer of an [`EvaluatedCircuit`]. +/// +/// Note that a [`Layer`] needs to be first converted to a [`LayerPolys`] before the evaluation of +/// the layer can be proved using GKR. +pub struct CircuitLayer { + wires: Vec>, +} + +impl CircuitLayer { + /// Creates a new [`Layer`] from a set of projective coordinates. + /// + /// Panics if the number of projective coordinates is not a power of two. + pub fn new(wires: Vec>) -> Self { + assert!(wires.len().is_power_of_two()); + + Self { wires } + } + + /// Returns the wires that make up this circuit layer. + pub fn wires(&self) -> &[CircuitWire] { + &self.wires + } + + /// Returns the number of wires in the layer. + pub fn num_wires(&self) -> usize { + self.wires.len() + } +} + +// CIRCUIT WIRE +// =============================================================================================== + +/// Represents a fraction `numerator / denominator` as a pair `(numerator, denominator)`. This is +/// the type for the gates' inputs in [`prover::EvaluatedCircuit`]. +/// +/// Hence, addition is defined in the natural way fractions are added together: `a/b + c/d = (ad + +/// bc) / bd`. +#[derive(Debug, Clone, Copy)] +pub struct CircuitWire { + numerator: E, + denominator: E, +} + +impl CircuitWire +where + E: FieldElement, +{ + /// Creates new projective coordinates from a numerator and a denominator. + pub fn new(numerator: E, denominator: E) -> Self { + assert_ne!(denominator, E::ZERO); + + Self { numerator, denominator } + } +} + +impl Add for CircuitWire +where + E: FieldElement, +{ + type Output = Self; + + fn add(self, other: Self) -> Self { + let numerator = self.numerator * other.denominator + other.numerator * self.denominator; + let denominator = self.denominator * other.denominator; + + Self::new(numerator, denominator) + } +} + +/// Represents a claim to be proven by a subsequent call to the sum-check protocol. +#[derive(Debug)] +pub struct GkrClaim { + pub evaluation_point: Vec, + pub claimed_evaluation: (E, E), +} + +/// We receive our 4 multilinear polynomials which were evaluated at a random point: +/// `left_numerators` (or `p0`), `right_numerators` (or `p1`), `left_denominators` (or `q0`), and +/// `right_denominators` (or `q1`). We'll call the 4 evaluations at a random point `p0(r)`, `p1(r)`, +/// `q0(r)`, and `q1(r)`, respectively, where `r` is the random point. Note that `r` is a shorthand +/// for a tuple of random values `(r_0, ... r_{l-1})`, where `2^{l + 1}` is the number of wires in +/// the layer. +/// +/// It is important to recall how `p0` and `p1` were constructed (and analogously for `q0` and +/// `q1`). They are the `numerators` layer polynomial (or `p`) evaluations `p(0, r)` and `p(1, r)`, +/// obtained from [`MultiLinearPoly::project_least_significant_variable`]. Hence, `[p0, p1]` form +/// the evaluations of polynomial `p'(x_0) = p(x_0, r)`. Then, the round claim for `numerators`, +/// defined as `p(r_layer, r)`, is simply `p'(r_layer)`. +fn reduce_layer_claim( + left_numerators_opening: E, + right_numerators_opening: E, + left_denominators_opening: E, + right_denominators_opening: E, + r_layer: E, +) -> (E, E) +where + E: FieldElement, +{ + // This is the `numerators` layer polynomial `f(x_0) = numerators(x_0, rx_0, ..., rx_{l-1})`, + // where `rx_0, ..., rx_{l-1}` are the random variables that were sampled during the sumcheck + // round for this layer. + let numerators_univariate = + MultiLinearPoly::from_evaluations(vec![left_numerators_opening, right_numerators_opening]); + + // This is analogous to `numerators_univariate`, but for the `denominators` layer polynomial + let denominators_univariate = MultiLinearPoly::from_evaluations(vec![ + left_denominators_opening, + right_denominators_opening, + ]); + + ( + numerators_univariate.evaluate(&[r_layer]), + denominators_univariate.evaluate(&[r_layer]), + ) +} + +/// Builds the auxiliary trace column for the univariate sum-check argument. +/// +/// Following Section 5.2 in [1] and using the inner product representation of multi-linear queries, +/// we need two univariate oracles, or equivalently two columns in the auxiliary trace, namely: +/// +/// 1. The Lagrange oracle, denoted by $c(X)$ in [1], and refered to throughout the codebase by +/// the Lagrange kernel column. +/// 2. The oracle witnessing the univariate sum-check relation defined by the aforementioned inner +/// product i.e., equation (12) in [1]. This oracle is refered to throughout the codebase as +/// the s-column. +/// +/// The following function's purpose is two build the column in point 2 given the one in point 1. +/// +/// [1]: https://eprint.iacr.org/2023/1284 +pub fn build_s_column( + trace: &impl Trace, + gkr_data: &GkrData, + evaluator: &impl LogUpGkrEvaluator, + lagrange_kernel_col: &[E], +) -> Vec { + let c = gkr_data.compute_batched_claim(); + let num_oracles = evaluator.get_oracles().len(); + + let main_segment = trace.main_segment(); + let num_cols = main_segment.num_cols(); + let num_rows = main_segment.num_rows(); + let mean = c / E::from(E::BaseField::from(num_rows as u32)); + + #[cfg(not(feature = "concurrent"))] + let result = { + let mut result = Vec::with_capacity(num_rows); + let mut last_value = E::ZERO; + result.push(last_value); + + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut main_frame = EvaluationFrame::new(num_cols); + + for (i, item) in lagrange_kernel_col.iter().enumerate().take(num_rows - 1) { + trace.read_main_frame(i, &mut main_frame); + + evaluator.build_query(&main_frame, &mut query); + let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; + + result.push(cur_value); + last_value = cur_value; + } + + result + }; + + #[cfg(feature = "concurrent")] + let result = { + let mut deltas = unsafe { uninit_vector(num_rows) }; + deltas[0] = E::ZERO; + let batch_size = num_rows / rayon_num_threads().next_power_of_two(); + batch_iter_mut!(&mut deltas[1..], batch_size, |batch: &mut [E], batch_offset: usize| { + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut main_frame = EvaluationFrame::::new(num_cols); + + for (i, v) in batch.iter_mut().enumerate() { + trace.read_main_frame(i + batch_offset, &mut main_frame); + + evaluator.build_query(&main_frame, &mut query); + *v = gkr_data.compute_batched_query(&query) * lagrange_kernel_col[i + batch_offset] + - mean; + } + }); + + // note that `deltas[0]` is set `0` and thus `deltas` satisfies the conditions for invoking + // the function + let mut cumulative_sum = deltas; + prefix_sum_parallel(&mut cumulative_sum, batch_size); + cumulative_sum + }; + + result +} + +/// Builds the Lagrange kernel column at a given point. +pub fn build_lagrange_column(lagrange_randomness: &[E]) -> Vec { + EqFunction::new(lagrange_randomness.into()).evaluations() +} + +#[derive(Debug, thiserror::Error)] +pub enum GkrProverError { + #[error("failed to generate the sum-check proof")] + FailedToProveSumCheck(#[from] SumCheckProverError), + #[error("failed to generate the random challenge")] + FailedToGenerateChallenge, +} + +// HELPER +// ================================================================================================= + +/// Computes the cumulative sum, also called prefix sum, of a vector of field elements using +/// parallelism, in place. +/// +/// The function divides the vector into non-overlapping segments and then computes an array of sums +/// for each segment. The function then applies the naive serial implementation to each segment and +/// uses the pre-computed sums in each segment in order to coordinate the results in the different +/// segments. +/// +/// The input vector is of the form `0 || values` where `values` are the values the cumulative sum +/// vector will be computed for, in place. +#[cfg(feature = "concurrent")] +fn prefix_sum_parallel(vector: &mut [E], batch_size: usize) { + let num_partitions = vector.len().div_ceil(batch_size); + let mut sum_per_partition = vec![E::ZERO; num_partitions]; + + chunks!(vector, batch_size) + .zip(iter_mut!(sum_per_partition)) + .for_each(|(chunk, entry)| *entry = chunk.iter().fold(E::ZERO, |acc, term| acc + *term)); + + prefix_sum_truncate_right(&mut sum_per_partition); + + chunks_mut!(vector, batch_size) + .zip(iter!(sum_per_partition)) + .for_each(|(chunk, sum_so_far)| prefix_sum_truncate_left(chunk, *sum_so_far)); +} + +/// Computes the cumulative sum of a vector but omits the final cumulative sum. +#[cfg(feature = "concurrent")] +fn prefix_sum_truncate_right(values: &mut [E]) { + let mut sum = E::ZERO; + values.iter_mut().for_each(|v| { + let tmp = *v; + *v = sum; + sum += tmp; + }); +} + +/// Computes the cumulative sum of a vector but omits the initial cumulative sum, namely zero. +#[cfg(feature = "concurrent")] +fn prefix_sum_truncate_left(values: &mut [E], sum: E) { + let mut sum = sum; + values.iter_mut().for_each(|v| { + sum += *v; + *v = sum; + }); +} diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs new file mode 100644 index 000000000..6f413f201 --- /dev/null +++ b/prover/src/logup_gkr/prover.rs @@ -0,0 +1,278 @@ +use alloc::vec::Vec; + +use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +use sumcheck::{ + sum_check_prove_higher_degree, sumcheck_prove_plain, BeforeFinalLayerProof, CircuitOutput, + EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, +}; +use tracing::instrument; +#[cfg(feature = "concurrent")] +use utils::rayon::prelude::*; +use utils::{iter, iter_mut, uninit_vector}; + +use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; +use crate::{matrix::ColMatrix, Trace}; + +// PROVER +// ================================================================================================ + +/// Evaluates and proves a fractional sum circuit given a set of composition polynomials. +/// +/// For the input layer of the circuit, each individual component of the quadruple +/// [p_0, p_1, q_0, q_1] is of the form: +/// +/// m(z_0, ... , z_{μ - 1}, x_0, ... , x_{ν - 1}) = \sum_{y ∈ {0,1}^μ} EQ(z, y) * g_{[y]}(f_0(x_0, +/// ... , x_{ν - 1}), ... , f_{κ - 1}(x_0, ... , x_{ν +/// - 1})) +/// +/// where: +/// +/// 1. μ is the log_2 of the number of different numerator/denominator expressions divided by two. +/// 2. [y] := \sum_{j = 0}^{μ - 1} y_j * 2^j +/// 3. κ is the number of multi-linears (i.e., main trace columns) involved in the computation of +/// the circuit (i.e., virtual bus). +/// 4. ν is the log_2 of the trace length. +/// +/// The above `m` is usually referred to as the merge of the individual composed multi-linear +/// polynomials g_{[y]}(f_0(x_0, ... , x_{ν - 1}), ... , f_{κ - 1}(x_0, ... , x_{ν - 1})). +/// +/// The composition polynomials `g` are provided as inputs and then used in order to compute the +/// evaluations of each of the four merge polynomials over {0, 1}^{μ + ν}. The resulting evaluations +/// are then used in order to evaluate the circuit. At this point, the GKR protocol is used to prove +/// the correctness of circuit evaluation. It should be noted that the input layer, which +/// corresponds to the last layer treated by the GKR protocol, is handled differently from the other +/// layers. More specifically, the sum-check protocol used for the input layer is composed of two +/// sum-check protocols, the first one works directly with the evaluations of the `m`'s over {0, +/// 1}^{μ + ν} and runs for μ - 1 rounds. After these μ - 1 rounds, and using the resulting [`RoundClaim`], +/// we run the second and final sum-check protocol for ν rounds on the composed multi-linear +/// polynomial given by +/// +/// \sum_{y ∈ {0,1}^μ} EQ(ρ', y) * g_{[y]}(f_0(x_0, ... , x_{ν - 1}), ... , f_{κ - 1}(x_0, ... , +/// x_{ν - 1})) +/// +/// where ρ' is the randomness sampled during the first sum-check protocol. +/// +/// As part of the final sum-check protocol, the openings {f_j(ρ)} are provided as part of a +/// [`FinalOpeningClaim`]. This latter claim will be proven by the STARK prover later on using the +/// auxiliary trace. +#[instrument(skip_all)] +pub fn prove_gkr( + main_trace: &impl Trace, + evaluator: &impl LogUpGkrEvaluator, + public_coin: &mut impl RandomCoin, +) -> Result, GkrProverError> { + let num_logup_random_values = evaluator.get_num_rand_values(); + let mut logup_randomness: Vec = Vec::with_capacity(num_logup_random_values); + + for _ in 0..num_logup_random_values { + logup_randomness.push(public_coin.draw().expect("failed to generate randomness")); + } + + // evaluate the GKR fractional sum circuit + let circuit = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?; + + // include the circuit output as part of the final proof + let CircuitLayerPolys { numerators, denominators } = circuit.output_layer().clone(); + + // run the GKR prover for all layers except the input layer + let (before_final_layer_proofs, gkr_claim) = prove_intermediate_layers(circuit, public_coin)?; + + // build the MLEs of the relevant main trace columns + let main_trace_mls = + build_mle_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + // build the periodic table representing periodic columns as multi-linear extensions + let periodic_table = evaluator.build_periodic_values(); + + // run the GKR prover for the input layer + let final_layer_proof = prove_input_layer( + evaluator, + logup_randomness, + main_trace_mls, + periodic_table, + gkr_claim, + public_coin, + )?; + + Ok(GkrCircuitProof { + circuit_outputs: CircuitOutput { numerators, denominators }, + before_final_layer_proofs, + final_layer_proof, + }) +} + +/// Proves the final GKR layer which corresponds to the input circuit layer. +#[instrument(skip_all)] +fn prove_input_layer< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + evaluator: &impl LogUpGkrEvaluator, + log_up_randomness: Vec, + multi_linear_ext_polys: Vec>, + periodic_table: PeriodicTable, + claim: GkrClaim, + transcript: &mut C, +) -> Result, GkrProverError> { + // parse the [GkrClaim] resulting from the previous GKR layer + let GkrClaim { evaluation_point, claimed_evaluation } = claim; + + transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); + let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + let claim = claimed_evaluation.0 + claimed_evaluation.1 * r_batch; + + let proof = sum_check_prove_higher_degree( + evaluator, + evaluation_point, + claim, + r_batch, + log_up_randomness, + multi_linear_ext_polys, + periodic_table, + transcript, + )?; + + Ok(FinalLayerProof::new(proof)) +} + +/// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for +/// LogUp-GKR. +#[instrument(skip_all)] +fn build_mle_from_main_trace_segment( + oracles: &[LogUpGkrOracle], + main_trace: &ColMatrix<::BaseField>, +) -> Result>, GkrProverError> { + let mut mls = Vec::with_capacity(oracles.len()); + + for oracle in oracles { + match oracle { + LogUpGkrOracle::CurrentRow(index) => { + let col = main_trace.get_column(*index); + let values: Vec = iter!(col).map(|value| E::from(*value)).collect(); + let ml = MultiLinearPoly::from_evaluations(values); + mls.push(ml) + }, + LogUpGkrOracle::NextRow(index) => { + let col = main_trace.get_column(*index); + + let mut values: Vec = unsafe { uninit_vector(col.len()) }; + values[col.len() - 1] = E::from(col[0]); + iter_mut!(&mut values[..col.len() - 1]) + .enumerate() + .for_each(|(i, value)| *value = E::from(col[i + 1])); + let ml = MultiLinearPoly::from_evaluations(values); + mls.push(ml) + }, + }; + } + + Ok(mls) +} + +/// Proves all GKR layers except for input layer. +#[instrument(skip_all)] +fn prove_intermediate_layers< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + circuit: EvaluatedCircuit, + transcript: &mut C, +) -> Result<(BeforeFinalLayerProof, GkrClaim), GkrProverError> { + // absorb the circuit output layer. This corresponds to sending the four values of the output + // layer to the verifier. The verifier then replies with a challenge `r` in order to evaluate + // `p` and `q` at `r` as multi-linears. + let CircuitLayerPolys { numerators, denominators } = circuit.output_layer(); + let mut evaluations = numerators.evaluations().to_vec(); + evaluations.extend_from_slice(denominators.evaluations()); + transcript.reseed(H::hash_elements(&evaluations)); + + // generate the challenge and reduce [p0, p1, q0, q1] to [pr, qr] + let r = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + let mut claimed_evaluation = circuit.evaluate_output_layer(r); + + let mut layer_proofs: Vec> = Vec::new(); + let mut evaluation_point = vec![r]; + + // Loop over all inner layers, from output to input. + // + // In a layered circuit, each layer is defined in terms of its predecessor. The first inner + // layer (starting from the output layer) is the first layer that has a predecessor. Here, we + // loop over all inner layers in order to iteratively reduce a layer in terms of its successor + // layer. Note that we don't include the input layer, since its predecessor layer will be + // reduced in terms of the input layer separately in `prove_final_circuit_layer`. + for inner_layer in circuit.layers().into_iter().skip(1).rev().skip(1) { + // construct the Lagrange kernel evaluated at the previous GKR round randomness + let mut eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); + + let (numerators, denominators) = inner_layer.into_numerators_denominators(); + + // run the sumcheck protocol + let proof = sum_check_prove_num_rounds_degree_3( + claimed_evaluation, + &evaluation_point, + numerators, + denominators, + &mut eq_mle, + transcript, + )?; + + // sample a random challenge to reduce claims + transcript.reseed(H::hash_elements(&proof.openings_claim.openings)); + let r_layer = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + + // reduce the claim + claimed_evaluation = { + let left_numerators_opening = proof.openings_claim.openings[0]; + let right_numerators_opening = proof.openings_claim.openings[1]; + let left_denominators_opening = proof.openings_claim.openings[2]; + let right_denominators_opening = proof.openings_claim.openings[3]; + + reduce_layer_claim( + left_numerators_opening, + right_numerators_opening, + left_denominators_opening, + right_denominators_opening, + r_layer, + ) + }; + + // collect the randomness used for the current layer + let mut ext = vec![r_layer]; + ext.extend_from_slice(&proof.openings_claim.eval_point); + evaluation_point = ext; + + layer_proofs.push(proof); + } + + Ok(( + BeforeFinalLayerProof { proof: layer_proofs }, + GkrClaim { evaluation_point, claimed_evaluation }, + )) +} + +/// Runs the sum-check prover used in all but the input layer. +#[allow(clippy::too_many_arguments)] +fn sum_check_prove_num_rounds_degree_3< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + claim: (E, E), + evaluation_point: &[E], + p: MultiLinearPoly, + q: MultiLinearPoly, + eq: &mut MultiLinearPoly, + transcript: &mut C, +) -> Result, GkrProverError> { + // generate challenge to batch two sumchecks + transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + let claim = claim.0 + claim.1 * r_batch; + + let proof = sumcheck_prove_plain(claim, evaluation_point, r_batch, p, q, eq, transcript)?; + + Ok(proof) +} diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 6b44fa0e9..5132e2025 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -9,7 +9,7 @@ use air::{ Air, AirContext, Assertion, EvaluationFrame, FieldExtension, ProofOptions, TraceInfo, TransitionConstraintDegree, }; -use math::{fields::f128::BaseElement, FieldElement, StarkField}; +use math::{fields::f64::BaseElement, FieldElement, StarkField}; use crate::TraceTable; @@ -34,7 +34,7 @@ pub fn build_fib_trace(length: usize) -> TraceTable { // ================================================================================================ pub struct MockAir { - context: AirContext, + context: AirContext, assertions: Vec>, periodic_columns: Vec>, } @@ -75,8 +75,6 @@ impl MockAir { impl Air for MockAir { type BaseField = BaseElement; type PublicInputs = (); - type GkrProof = (); - type GkrVerifier = (); fn new(trace_info: TraceInfo, _pub_inputs: (), _options: ProofOptions) -> Self { let context = build_context(trace_info, 8, 1); @@ -87,7 +85,7 @@ impl Air for MockAir { } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } @@ -115,8 +113,8 @@ fn build_context( trace_info: TraceInfo, blowup_factor: usize, num_assertions: usize, -) -> AirContext { +) -> AirContext { let options = ProofOptions::new(32, blowup_factor, 0, FieldExtension::None, 4, 31); let t_degrees = vec![TransitionConstraintDegree::new(2)]; - AirContext::new(trace_info, t_degrees, num_assertions, options) + AirContext::new(trace_info, (), t_degrees, num_assertions, options) } diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 26b383a3b..08fb49a2a 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -3,8 +3,12 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use air::{Air, AuxRandElements, EvaluationFrame, LagrangeKernelBoundaryConstraint, TraceInfo}; +use air::{ + Air, AuxRandElements, EvaluationFrame, LagrangeKernelBoundaryConstraint, LogUpGkrEvaluator, + TraceInfo, +}; use math::{polynom, FieldElement, StarkField}; +use sumcheck::GkrCircuitProof; use super::ColMatrix; @@ -22,7 +26,7 @@ mod tests; /// Defines an [`AuxTraceWithMetadata`] type where the type arguments use their equivalents in an /// [`Air`]. -type AirAuxTraceWithMetadata = AuxTraceWithMetadata::GkrProof>; +type AirAuxTraceWithMetadata = AuxTraceWithMetadata; // AUX TRACE WITH METADATA // ================================================================================================ @@ -30,10 +34,10 @@ type AirAuxTraceWithMetadata = AuxTraceWithMetadata::GkrProo /// Holds the auxiliary trace, the random elements used when generating the auxiliary trace, and /// optionally, a GKR proof. See [`crate::Proof`] for more information about the auxiliary /// proof. -pub struct AuxTraceWithMetadata { +pub struct AuxTraceWithMetadata { pub aux_trace: ColMatrix, pub aux_rand_elements: AuxRandElements, - pub gkr_proof: Option, + pub gkr_proof: Option>, } // TRACE TRAIT @@ -52,7 +56,7 @@ pub struct AuxTraceWithMetadata { /// implementation supports concurrent trace generation and should be sufficient in most /// situations. However, if functionality provided by [TraceTable] is not sufficient, uses can /// provide custom implementations of the [Trace] trait which better suit their needs. -pub trait Trace: Sized { +pub trait Trace: Sized + Sync { /// Base field for this execution trace. /// /// All cells of this execution trace contain values which are elements in this field. @@ -79,7 +83,7 @@ pub trait Trace: Sized { /// Returns the number of columns in the main segment of this trace. fn main_trace_width(&self) -> usize { - self.info().main_trace_width() + self.info().main_segment_width() } /// Returns the number of columns in the auxiliary trace segment. @@ -90,21 +94,18 @@ pub trait Trace: Sized { /// Checks if this trace is valid against the specified AIR, and panics if not. /// /// NOTE: this is a very expensive operation and is intended for use only in debug mode. - fn validate( - &self, - air: &A, - aux_trace_with_metadata: Option<&AirAuxTraceWithMetadata>, - ) where + fn validate(&self, air: &A, aux_trace_with_metadata: Option<&AirAuxTraceWithMetadata>) + where A: Air, E: FieldElement, { // make sure the width align; if they don't something went terribly wrong assert_eq!( self.main_trace_width(), - air.trace_info().main_trace_width(), + air.trace_info().main_segment_width(), "inconsistent trace width: expected {}, but was {}", self.main_trace_width(), - air.trace_info().main_trace_width(), + air.trace_info().main_segment_width(), ); // --- 1. make sure the assertions are valid ---------------------------------------------- @@ -141,7 +142,7 @@ pub trait Trace: Sized { } // then, check the Lagrange kernel assertion, if any - if let Some(lagrange_kernel_col_idx) = air.context().lagrange_kernel_aux_column_idx() { + if let Some(lagrange_kernel_col_idx) = air.context().lagrange_kernel_column_idx() { let boundary_constraint_assertion_value = LagrangeKernelBoundaryConstraint::assertion_value( aux_rand_elements @@ -224,19 +225,27 @@ pub trait Trace: Sized { x *= g; } - // evaluate transition constraints for Lagrange kernel column (if any) and make sure - // they all evaluate to zeros - if let Some(col_idx) = air.context().lagrange_kernel_aux_column_idx() { + // evaluate transition constraints for Lagrange kernel column and s-column, when LogUp-GKR + // is enabled, and make sure they all evaluate to zeros + if air.context().logup_gkr_enabled() { let aux_trace_with_metadata = aux_trace_with_metadata.expect("expected aux trace to be present"); let aux_trace = &aux_trace_with_metadata.aux_trace; let aux_rand_elements = &aux_trace_with_metadata.aux_rand_elements; - - let c = aux_trace.get_column(col_idx); - let v = self.length().ilog2() as usize; - let r = aux_rand_elements.lagrange().expect("expected Lagrange column to be present"); - - // Loop over every constraint + let l_col_idx = air + .context() + .trace_info() + .lagrange_kernel_column_idx() + .expect("should not be None"); + let s_col_idx = air.context().trace_info().s_column_idx().expect("should not be None"); + + let c = aux_trace.get_column(l_col_idx); + let trace_length = self.length(); + let v = trace_length.ilog2() as usize; + let gkr_data = aux_rand_elements.gkr_data().expect("should not be None"); + let r = gkr_data.lagrange_kernel_rand_elements(); + + // Loop over every Lagrange kernel constraint for constraint_idx in 1..v + 1 { let domain_step = 2_usize.pow((v - constraint_idx + 1) as u32); let domain_half_step = 2_usize.pow((v - constraint_idx) as u32); @@ -256,6 +265,36 @@ pub trait Trace: Sized { ); } } + + // Validate the s-column constraint + let evaluator = air.get_logup_gkr_evaluator(); + let mut aux_frame = EvaluationFrame::new(self.aux_trace_width()); + + let c = gkr_data.compute_batched_claim(); + let mean = c / E::from(E::BaseField::from(trace_length as u32)); + + let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + for step in 0..self.length() { + self.read_main_frame(step, &mut main_frame); + read_aux_frame(aux_trace, step, &mut aux_frame); + + let l_cur = aux_frame.current()[l_col_idx]; + let s_cur = aux_frame.current()[s_col_idx]; + let s_nxt = aux_frame.next()[s_col_idx]; + + evaluator.build_query(&main_frame, &mut query); + let batched_query = gkr_data.compute_batched_query(&query); + + let rhs = s_cur - mean + batched_query * l_cur; + let lhs = s_nxt; + + let evaluation = rhs - lhs; + + assert!( + evaluation == E::ZERO, + "s-column transition constraint did not evaluate to ZERO at step {step}" + ); + } } } } diff --git a/prover/src/trace/tests.rs b/prover/src/trace/tests.rs index fc653bbde..b08771a3e 100644 --- a/prover/src/trace/tests.rs +++ b/prover/src/trace/tests.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; -use math::fields::f128::BaseElement; +use math::fields::f64::BaseElement; use crate::{tests::build_fib_trace, Trace}; diff --git a/prover/src/trace/trace_lde/default/mod.rs b/prover/src/trace/trace_lde/default/mod.rs index e06839d53..2cb177bc5 100644 --- a/prover/src/trace/trace_lde/default/mod.rs +++ b/prover/src/trace/trace_lde/default/mod.rs @@ -195,20 +195,17 @@ where lagrange_kernel_aux_column_idx: usize, frame: &mut LagrangeKernelEvaluationFrame, ) { - let frame = frame.frame_mut(); - frame.truncate(0); - let aux_segment = self.aux_segment_lde.as_ref().expect("expected aux segment to be present"); - frame.push(aux_segment.get(lagrange_kernel_aux_column_idx, lde_step)); + frame[0] = aux_segment.get(lagrange_kernel_aux_column_idx, lde_step); let frame_length = self.trace_info.length().ilog2() as usize + 1; for i in 0..frame_length - 1 { let shift = self.blowup() * (1 << i); let next_lde_step = (lde_step + shift) % self.trace_len(); - frame.push(aux_segment.get(lagrange_kernel_aux_column_idx, next_lde_step)); + frame[i + 1] = aux_segment.get(lagrange_kernel_aux_column_idx, next_lde_step); } } diff --git a/prover/src/trace/trace_lde/default/tests.rs b/prover/src/trace/trace_lde/default/tests.rs index c06cc2e60..e1b9b6299 100644 --- a/prover/src/trace/trace_lde/default/tests.rs +++ b/prover/src/trace/trace_lde/default/tests.rs @@ -7,7 +7,7 @@ use alloc::vec::Vec; use crypto::{hashers::Blake3_256, ElementHasher, MerkleTree}; use math::{ - fields::f128::BaseElement, get_power_series, get_power_series_with_offset, polynom, + fields::f64::BaseElement, get_power_series, get_power_series_with_offset, polynom, FieldElement, StarkField, }; diff --git a/prover/src/trace/trace_table.rs b/prover/src/trace/trace_table.rs index dfbd6fe72..0d26c73a0 100644 --- a/prover/src/trace/trace_table.rs +++ b/prover/src/trace/trace_table.rs @@ -166,7 +166,7 @@ impl TraceTable { I: FnOnce(&mut [B]), U: FnMut(usize, &mut [B]), { - let mut state = vec![B::ZERO; self.info.main_trace_width()]; + let mut state = vec![B::ZERO; self.info.main_segment_width()]; init(&mut state); self.update_row(0, &state); @@ -255,7 +255,7 @@ impl TraceTable { /// Returns the number of columns in this execution trace. pub fn width(&self) -> usize { - self.info.main_trace_width() + self.info.main_segment_width() } /// Returns the entire trace column at the specified index. diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml new file mode 100644 index 000000000..7db2e8058 --- /dev/null +++ b/sumcheck/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "winter-sumcheck" +version = "0.1.0" +description = "Implementation of the sum-check protocol for the LogUp-GKR protocol" +authors = ["winterfell contributors"] +readme = "README.md" +license = "MIT" +repository = "https://github.com/novifinancial/winterfell" +documentation = "https://docs.rs/winter-sumcheck/0.1.0" +categories = ["cryptography", "no-std"] +keywords = ["crypto", "sumcheck", "iop"] +edition = "2021" +rust-version = "1.78" + +[[bench]] +name = "sum_check_plain" +harness = false + +[[bench]] +name = "sum_check_high_degree" +harness = false + +[[bench]] +name = "eq_function" +harness = false + +[[bench]] +name = "bind_variable" +harness = false + +[features] +concurrent = ["utils/concurrent", "dep:rayon", "std"] +default = ["std"] +std = ["utils/std"] + +[dependencies] +air = { version = "0.9", path = "../air", package = "winter-air", default-features = false } +crypto = { version = "0.9", path = "../crypto", package = "winter-crypto", default-features = false } +math = { version = "0.9", path = "../math", package = "winter-math", default-features = false } +utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } +rayon = { version = "1.8", optional = true } +smallvec = { version = "1.13", default-features = false } +thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } + +[dev-dependencies] +criterion = "0.5" +rand-utils = { version = "0.9", path = "../utils/rand", package = "winter-rand-utils" } \ No newline at end of file diff --git a/sumcheck/README.md b/sumcheck/README.md new file mode 100644 index 000000000..be6734aae --- /dev/null +++ b/sumcheck/README.md @@ -0,0 +1,24 @@ +# Winter sum-check +This crate contains an implementation of the sum-check protocol intended to be used for [LogUp-GKR](https://eprint.iacr.org/2023/1284) by the Winterfell STARK prover and verifier. + +The crate provides two implementations of the sum-check protocol: + +* An implementation for the sum-check protocol as used in [LogUp-GKR](https://eprint.iacr.org/2023/1284). +* An implementation which generalizes the previous one to the case where the numerators and denominators appearing in the fractional sum-checks in Section 3 of [LogUp-GKR](https://eprint.iacr.org/2023/1284) can be non-linear compositions of multi-linear polynomials. + +The first implementation is intended to be used by the GKR protocol for proving the correct evaluation of all of the layers of the fractionl sum circuit except for the input layer. The second implementation is intended to be used for proving the correct evaluation of the input layer. + + +## Crate features +This crate can be compiled with the following features: + +* `std` - enabled by default and relies on the Rust standard library. +* `concurrent` - implies `std` and also re-exports `rayon` crate and enables multi-threaded execution for some of the crate functions. +* `no_std` - does not rely on Rust's standard library and enables compilation to WebAssembly. + +To compile with `no_std`, disable default features via `--no-default-features` flag. + +License +------- + +This project is [MIT licensed](../LICENSE). \ No newline at end of file diff --git a/sumcheck/benches/bind_variable.rs b/sumcheck/benches/bind_variable.rs new file mode 100644 index 000000000..4e65f684b --- /dev/null +++ b/sumcheck/benches/bind_variable.rs @@ -0,0 +1,39 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use math::fields::f64::BaseElement; +use rand_utils::{rand_value, rand_vector}; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use winter_sumcheck::MultiLinearPoly; + +const POLY_SIZE: [usize; 2] = [1 << 18, 1 << 20]; + +fn bind_variable(c: &mut Criterion) { + let mut group = c.benchmark_group("bind variable "); + group.measurement_time(Duration::from_secs(15)); + + for &poly_size in POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("", poly_size), |b| { + b.iter_batched( + || { + let random_challenge: BaseElement = rand_value(); + let poly = MultiLinearPoly::from_evaluations(rand_vector(poly_size)); + (random_challenge, poly) + }, + |(random_challenge, mut poly)| { + poly.bind_least_significant_variable(random_challenge) + }, + BatchSize::SmallInput, + ) + }); + } +} + +criterion_group!(group, bind_variable); +criterion_main!(group); diff --git a/sumcheck/benches/eq_function.rs b/sumcheck/benches/eq_function.rs new file mode 100644 index 000000000..df2326f95 --- /dev/null +++ b/sumcheck/benches/eq_function.rs @@ -0,0 +1,37 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use math::fields::f64::BaseElement; +use rand_utils::rand_vector; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use winter_sumcheck::EqFunction; + +const LOG_POLY_SIZE: [usize; 2] = [18, 20]; + +fn evaluate_eq(c: &mut Criterion) { + let mut group = c.benchmark_group("EQ function evaluations"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &log_poly_size in LOG_POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("", log_poly_size), |b| { + b.iter_batched( + || { + let randomness: Vec = rand_vector(log_poly_size); + EqFunction::new(randomness.into()) + }, + |eq_function| eq_function.evaluations(), + BatchSize::SmallInput, + ) + }); + } +} + +criterion_group!(group, evaluate_eq); +criterion_main!(group); diff --git a/sumcheck/benches/sum_check_high_degree.rs b/sumcheck/benches/sum_check_high_degree.rs new file mode 100644 index 000000000..483890579 --- /dev/null +++ b/sumcheck/benches/sum_check_high_degree.rs @@ -0,0 +1,178 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, time::Duration}; + +use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin}; +use math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField}; +use rand_utils::{rand_value, rand_vector}; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use winter_sumcheck::{sum_check_prove_higher_degree, MultiLinearPoly}; + +const LOG_POLY_SIZE: [usize; 2] = [18, 20]; + +fn sum_check_high_degree(c: &mut Criterion) { + let mut group = c.benchmark_group("Sum-check prover high degree"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &log_poly_size in LOG_POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("", log_poly_size), |b| { + b.iter_batched( + || { + let logup_randomness = rand_vector(1); + let evaluator = PlainLogUpGkrEval::::default(); + let transcript = + DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); + ( + setup_sum_check::(log_poly_size), + evaluator, + logup_randomness, + transcript, + ) + }, + |( + (claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4), periodic_table), + evaluator, + logup_randomness, + transcript, + )| { + let mls = vec![ml0, ml1, ml2, ml3, ml4]; + let mut transcript = transcript; + + sum_check_prove_higher_degree( + &evaluator, + rand_pt, + claim, + r_batch, + logup_randomness, + mls, + periodic_table, + &mut transcript, + ) + }, + BatchSize::SmallInput, + ) + }); + } +} + +#[allow(clippy::too_many_arguments)] +#[allow(clippy::type_complexity)] +fn setup_sum_check( + log_size: usize, +) -> ( + E, + E, + Vec, + ( + MultiLinearPoly, + MultiLinearPoly, + MultiLinearPoly, + MultiLinearPoly, + MultiLinearPoly, + ), + PeriodicTable, +) { + let n = 1 << log_size; + let table = MultiLinearPoly::from_evaluations(rand_vector(n)); + let multiplicity = MultiLinearPoly::from_evaluations(rand_vector(n)); + let values_0 = MultiLinearPoly::from_evaluations(rand_vector(n)); + let values_1 = MultiLinearPoly::from_evaluations(rand_vector(n)); + let values_2 = MultiLinearPoly::from_evaluations(rand_vector(n)); + let periodic_table = PeriodicTable::default(); + + // this will not generate the correct claim with overwhelming probability but should be fine + // for benchmarking + let rand_pt: Vec = rand_vector(log_size + 2); + let r_batch: E = rand_value(); + let claim: E = rand_value(); + + ( + claim, + r_batch, + rand_pt, + (table, multiplicity, values_0, values_1, values_2), + periodic_table, + ) +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f); + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } +} + +criterion_group!(group, sum_check_high_degree); +criterion_main!(group); diff --git a/sumcheck/benches/sum_check_plain.rs b/sumcheck/benches/sum_check_plain.rs new file mode 100644 index 000000000..203961fa4 --- /dev/null +++ b/sumcheck/benches/sum_check_plain.rs @@ -0,0 +1,73 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin}; +use math::{fields::f64::BaseElement, FieldElement}; +use rand_utils::{rand_value, rand_vector}; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use winter_sumcheck::{sumcheck_prove_plain, EqFunction, MultiLinearPoly}; +const LOG_POLY_SIZE: [usize; 2] = [18, 20]; + +fn sum_check_plain(c: &mut Criterion) { + let mut group = c.benchmark_group("Sum-check prover plain"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &log_poly_size in LOG_POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("", log_poly_size), |b| { + b.iter_batched( + || { + let transcript = + DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); + (setup_sum_check::(log_poly_size), transcript) + }, + |((claim, evaluation_point, r_batch, p, q, eq), transcript)| { + let mut eq = eq; + let mut transcript = transcript; + sumcheck_prove_plain( + claim, + &evaluation_point, + r_batch, + p, + q, + &mut eq, + &mut transcript, + ) + }, + BatchSize::SmallInput, + ) + }); + } +} + +#[allow(clippy::too_many_arguments)] +#[allow(clippy::type_complexity)] +fn setup_sum_check( + log_size: usize, +) -> (E, Vec, E, MultiLinearPoly, MultiLinearPoly, MultiLinearPoly) { + let n = 1 << (log_size + 1); + let p: Vec = rand_vector(n); + let q: Vec = rand_vector(n); + + // this will not generate the correct claim with overwhelming probability but should be fine + // for benchmarking + let rand_pt = rand_vector(log_size); + let r_batch: E = rand_value(); + let claim: E = rand_value(); + let evaluation_point = rand_vector(log_size); + + let p = MultiLinearPoly::from_evaluations(p); + let q = MultiLinearPoly::from_evaluations(q); + let eq = MultiLinearPoly::from_evaluations(EqFunction::new(rand_pt.into()).evaluations()); + + (claim, evaluation_point, r_batch, p, q, eq) +} + +criterion_group!(group, sum_check_plain); +criterion_main!(group); diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs new file mode 100644 index 000000000..b11f19d74 --- /dev/null +++ b/sumcheck/src/lib.rs @@ -0,0 +1,281 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#![no_std] + +use alloc::vec::Vec; + +use ::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use math::FieldElement; + +#[macro_use] +extern crate alloc; + +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; + +mod prover; +pub use prover::*; + +mod verifier; +pub use verifier::*; + +mod univariate; +pub use univariate::CompressedUnivariatePoly; + +mod multilinear; +pub use multilinear::{inner_product, EqFunction, MultiLinearPoly}; + +/// Represents an opening claim at an evaluation point against a batch of oracles. +/// +/// After verifying [`Proof`], the verifier is left with a question on the validity of a final +/// claim on a number of oracles open to a given set of values at some given point. +/// This question is answered either using further interaction with the Prover or using +/// a polynomial commitment opening proof in the compiled protocol. +#[derive(Clone, Debug)] +pub struct FinalOpeningClaim { + pub eval_point: Vec, + pub openings: Vec, +} + +impl Serializable for FinalOpeningClaim { + fn write_into(&self, target: &mut W) { + let Self { eval_point, openings } = self; + eval_point.write_into(target); + openings.write_into(target); + } +} + +impl Deserializable for FinalOpeningClaim +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + eval_point: Deserializable::read_from(source)?, + openings: Deserializable::read_from(source)?, + }) + } +} + +/// A sum-check proof. +/// +/// Composed of the round proofs i.e., the polynomials sent by the Prover at each round as well as +/// the (claimed) openings of the multi-linear oracles at the evaluation point given by the round +/// challenges. +#[derive(Debug, Clone)] +pub struct SumCheckProof { + pub openings_claim: FinalOpeningClaim, + pub round_proofs: Vec>, +} + +impl Serializable for SumCheckProof +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + self.openings_claim.write_into(target); + self.round_proofs.write_into(target); + } +} + +impl Deserializable for SumCheckProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + openings_claim: Deserializable::read_from(source)?, + round_proofs: Deserializable::read_from(source)?, + }) + } +} + +/// A sum-check round proof. +/// +/// This represents the partial polynomial sent by the Prover during one of the rounds of the +/// sum-check protocol. The polynomial is in coefficient form and excludes the coefficient for +/// the linear term as the Verifier can recover it from the other coefficients and the current +/// (reduced) claim. +#[derive(Debug, Clone)] +pub struct RoundProof { + pub round_poly_coefs: CompressedUnivariatePoly, +} + +impl Serializable for RoundProof { + fn write_into(&self, target: &mut W) { + let Self { round_poly_coefs } = self; + round_poly_coefs.write_into(target); + } +} + +impl Deserializable for RoundProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + round_poly_coefs: Deserializable::read_from(source)?, + }) + } +} + +/// A proof for the input circuit layer i.e., the final layer in the GKR protocol. +#[derive(Debug, Clone)] +pub struct FinalLayerProof(SumCheckProof); + +impl FinalLayerProof { + pub fn new(proof: SumCheckProof) -> Self { + Self(proof) + } +} + +impl Serializable for FinalLayerProof +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + self.0.write_into(target); + } +} + +impl Deserializable for FinalLayerProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self(Deserializable::read_from(source)?)) + } +} + +/// Contains the round challenges sent by the Verifier up to some round as well as the current +/// reduced claim. +#[derive(Debug)] +pub struct SumCheckRoundClaim { + pub eval_point: Vec, + pub claim: E, +} + +// GKR CIRCUIT PROOF +// =============================================================================================== + +/// A GKR proof for the correct evaluation of the sum of fractions circuit. +#[derive(Debug, Clone)] +pub struct GkrCircuitProof { + pub circuit_outputs: CircuitOutput, + pub before_final_layer_proofs: BeforeFinalLayerProof, + pub final_layer_proof: FinalLayerProof, +} + +impl GkrCircuitProof { + pub fn get_final_opening_claim(&self) -> FinalOpeningClaim { + self.final_layer_proof.0.openings_claim.clone() + } +} + +impl Serializable for GkrCircuitProof +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + self.circuit_outputs.write_into(target); + self.before_final_layer_proofs.write_into(target); + self.final_layer_proof.0.write_into(target); + } +} + +impl Deserializable for GkrCircuitProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + circuit_outputs: CircuitOutput::read_from(source)?, + before_final_layer_proofs: BeforeFinalLayerProof::read_from(source)?, + final_layer_proof: FinalLayerProof::read_from(source)?, + }) + } +} + +/// A set of sum-check proofs for all GKR layers but for the input circuit layer. +#[derive(Debug, Clone)] +pub struct BeforeFinalLayerProof { + pub proof: Vec>, +} + +impl Serializable for BeforeFinalLayerProof +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { proof } = self; + proof.write_into(target); + } +} + +impl Deserializable for BeforeFinalLayerProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + proof: Deserializable::read_from(source)?, + }) + } +} + +/// Holds the output layer of an [`EvaluatedCircuit`]. +#[derive(Clone, Debug)] +pub struct CircuitOutput { + pub numerators: MultiLinearPoly, + pub denominators: MultiLinearPoly, +} + +impl Serializable for CircuitOutput +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { numerators, denominators } = self; + numerators.write_into(target); + denominators.write_into(target); + } +} + +impl Deserializable for CircuitOutput +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + numerators: MultiLinearPoly::read_from(source)?, + denominators: MultiLinearPoly::read_from(source)?, + }) + } +} + +/// The non-linear composition polynomial of the LogUp-GKR protocol. +/// +/// This is the result of batching the `p_k` and `q_k` of section 3.2 in +/// https://eprint.iacr.org/2023/1284.pdf. +#[inline(always)] +fn comb_func(p0: E, p1: E, q0: E, q1: E, eq: E, r_batch: E) -> E { + (p0 * q1 + p1 * q0 + r_batch * q0 * q1) * eq +} + +/// The non-linear composition polynomial of the LogUp-GKR protocol specific to the input layer. +pub fn evaluate_composition_poly( + eq_at_mu: &[E], + numerators: &[E], + denominators: &[E], + eq_eval: E, + r_sum_check: E, +) -> E { + numerators + .chunks(2) + .zip(denominators.chunks(2).zip(eq_at_mu.iter())) + .map(|(p, (q, eq_w))| *eq_w * comb_func(p[0], p[1], q[0], q[1], eq_eval, r_sum_check)) + .fold(E::ZERO, |acc, x| acc + x) +} diff --git a/sumcheck/src/multilinear.rs b/sumcheck/src/multilinear.rs new file mode 100644 index 000000000..df6177914 --- /dev/null +++ b/sumcheck/src/multilinear.rs @@ -0,0 +1,395 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; +use core::ops::Index; + +use math::FieldElement; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use smallvec::SmallVec; +use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +// MULTI-LINEAR POLYNOMIAL +// ================================================================================================ + +/// Represents a multi-linear polynomial. +/// +/// The representation stores the evaluations of the polynomial over the boolean hyper-cube +/// ${0 , 1}^{\nu}$. +#[derive(Clone, Debug, PartialEq)] +pub struct MultiLinearPoly { + evaluations: Vec, +} + +impl MultiLinearPoly { + /// Constructs a [`MultiLinearPoly`] from its evaluations over the boolean hyper-cube ${0 , 1}^{\nu}$. + pub fn from_evaluations(evaluations: Vec) -> Self { + assert!(evaluations.len().is_power_of_two(), "A multi-linear polynomial should have a power of 2 number of evaluations over the Boolean hyper-cube"); + Self { evaluations } + } + + /// Returns the number of variables of the multi-linear polynomial. + pub fn num_variables(&self) -> usize { + self.evaluations.len().trailing_zeros() as usize + } + + /// Returns the evaluations over the boolean hyper-cube. + pub fn evaluations(&self) -> &[E] { + &self.evaluations + } + + /// Returns the number of evaluations. This is equal to the size of the boolean hyper-cube. + pub fn num_evaluations(&self) -> usize { + self.evaluations.len() + } + + /// Evaluate the multi-linear at some query $(r_0, ..., r_{{\nu} - 1}) \in \mathbb{F}^{\nu}$. + /// + /// It first computes the evaluations of the Lagrange basis polynomials over the interpolating + /// set ${0 , 1}^{\nu}$ at $(r_0, ..., r_{{\nu} - 1})$ i.e., the Lagrange kernel at $(r_0, ..., r_{{\nu} - 1})$. + /// The evaluation then is the inner product, indexed by ${0 , 1}^{\nu}$, of the vector of + /// evaluations times the Lagrange kernel. + pub fn evaluate(&self, query: &[E]) -> E { + let tensored_query = compute_lagrange_basis_evals_at(query); + inner_product(&self.evaluations, &tensored_query) + } + + /// Similar to [`Self::evaluate`], except that the query was already turned into the Lagrange + /// kernel (i.e. the [`lagrange_ker::EqFunction`] evaluated at every point in the set + /// `${0 , 1}^{\nu}$`). + /// + /// This is more efficient than [`Self::evaluate`] when multiple different [`MultiLinearPoly`] + /// need to be evaluated at the same query point. + pub fn evaluate_with_lagrange_kernel(&self, lagrange_kernel: &[E]) -> E { + inner_product(&self.evaluations, lagrange_kernel) + } + + /// Computes $f(r_0, y_1, ..., y_{{\nu} - 1})$ using the linear interpolation formula + /// $(1 - r_0) * f(0, y_1, ..., y_{{\nu} - 1}) + r_0 * f(1, y_1, ..., y_{{\nu} - 1})$ and assigns + /// the resulting multi-linear, defined over a domain of half the size, to `self`. + #[inline(always)] + pub fn bind_least_significant_variable(&mut self, round_challenge: E) { + let num_evals = self.evaluations.len() >> 1; + #[cfg(not(feature = "concurrent"))] + { + for i in 0..num_evals { + // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is + // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is + // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. + let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i << 1) }; + let evaluations_2i_plus_1 = + unsafe { *self.evaluations.get_unchecked((i << 1) + 1) }; + + self.evaluations[i] = + evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); + } + self.evaluations.truncate(num_evals); + } + + #[cfg(feature = "concurrent")] + { + let mut result = unsafe { utils::uninit_vector(num_evals) }; + result.par_iter_mut().enumerate().for_each(|(i, ev)| { + // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is + // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is + // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. + let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i << 1) }; + let evaluations_2i_plus_1 = + unsafe { *self.evaluations.get_unchecked((i << 1) + 1) }; + + *ev = evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); + }); + self.evaluations = result + } + } + + /// Given the multilinear polynomial $f(y_0, y_1, ..., y_{{\nu} - 1})$, returns two polynomials: + /// $f(0, y_1, ..., y_{{\nu} - 1})$ and $f(1, y_1, ..., y_{{\nu} - 1})$. + pub fn project_least_significant_variable(mut self) -> (Self, Self) { + let odds: Vec = self + .evaluations + .iter() + .enumerate() + .filter_map(|(idx, x)| if idx % 2 == 1 { Some(*x) } else { None }) + .collect(); + + // Builds the evens multilinear from the current `self.evaluations` buffer, which saves an + // allocation. + let evens = { + let evens_size = self.num_evaluations() / 2; + for write_idx in 0..evens_size { + let read_idx = write_idx * 2; + self.evaluations[write_idx] = self.evaluations[read_idx]; + } + self.evaluations.truncate(evens_size); + + self.evaluations + }; + + (Self::from_evaluations(evens), Self::from_evaluations(odds)) + } +} + +impl Index for MultiLinearPoly { + type Output = E; + + fn index(&self, index: usize) -> &E { + &(self.evaluations[index]) + } +} + +impl Serializable for MultiLinearPoly +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { evaluations } = self; + evaluations.write_into(target); + } +} + +impl Deserializable for MultiLinearPoly +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + evaluations: Deserializable::read_from(source)?, + }) + } +} + +// EQ FUNCTION +// ================================================================================================ + +/// Maximal expected size of the point of a given Lagrange kernel. +const MAX_EQ_SIZE: usize = 25; + +/// The EQ (equality) function is the binary function defined by +/// +/// $$ +/// EQ: {0 , 1}^{\nu} ⛌ {0 , 1}^{\nu} \longrightarrow {0 , 1} +/// ((x_0, ..., x_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1})) \mapsto \prod_{i = 0}^{{\nu} - 1} (x_i \cdot y_i + (1 - x_i) +/// \cdot (1 - y_i)) +/// $$ +/// +/// Taking its multi-linear extension $\tilde{EQ}$, we can define a basis for the set of multi-linear +/// polynomials in {\nu} variables by +/// $${\tilde{EQ}(., (y_0, ..., y_{{\nu} - 1})): (y_0, ..., y_{{\nu} - 1}) \in {0 , 1}^{\nu}}$$ +/// where each basis function is a function of its first argument. This is called the Lagrange or +/// evaluation basis for evaluation set ${0 , 1}^{\nu}$. +/// +/// Given a function $(f: {0 , 1}^{\nu} \longrightarrow \mathbb{F})$, its multi-linear extension (i.e., the unique +/// mult-linear polynomial extending `f` to $(\tilde{f}: \mathbb{F}^{\nu} \longrightarrow \mathbb{F})$ and agreeing with it on ${0 , 1}^{\nu}$) is +/// defined as the summation of the evaluations of f against the Lagrange basis. +/// More specifically, given $(r_0, ..., r_{{\nu} - 1}) \in \mathbb{F}^{\nu}$, then: +/// +/// $$ +/// \tilde{f}(r_0, ..., r_{{\nu} - 1}) = \sum_{(y_0, ..., y_{{\nu} - 1}) \in {0 , 1}^{\nu}} +/// f(y_0, ..., y_{{\nu} - 1}) \tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1})) +/// $$ +/// +/// We call the Lagrange kernel the evaluation of the $\tilde{EQ}$ function at +/// $((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1}))$ for all $(y_0, ..., y_{{\nu} - 1}) \in {0 , 1}^{\nu}$ for +/// a fixed $(r_0, ..., r_{{\nu} - 1}) \in \mathbb{F}^{\nu}$. +/// +/// [`EqFunction`] represents $\tilde{EQ}$ the multi-linear extension of +/// +/// $((y_0, ..., y_{{\nu} - 1}) \mapsto EQ((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1})))$ +/// +/// and contains a method to generate the Lagrange kernel for defining evaluations of multi-linear +/// extensions of arbitrary functions $(f: {0 , 1}^{\nu} \longrightarrow \mathbb{F})$ at a given point $(r_0, ..., r_{{\nu} - 1})$ +/// as well as a method to evaluate $\tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (t_0, ..., t_{{\nu} - 1})))$ for +/// $(t_0, ..., t_{{\nu} - 1}) \in \mathbb{F}^{\nu}$. +pub struct EqFunction { + r: SmallVec<[E; MAX_EQ_SIZE]>, +} + +impl EqFunction { + /// Creates a new [EqFunction]. + pub fn new(r: SmallVec<[E; MAX_EQ_SIZE]>) -> Self { + EqFunction { r } + } + + /// Computes $\tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (t_0, ..., t_{{\nu} - 1})))$. + pub fn evaluate(&self, t: &[E]) -> E { + assert_eq!(self.r.len(), t.len()); + + (0..self.r.len()) + .map(|i| self.r[i] * t[i] + (E::ONE - self.r[i]) * (E::ONE - t[i])) + .fold(E::ONE, |acc, term| acc * term) + } + + /// Computes $\tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1}))$ for all + /// $(y_0, ..., y_{{\nu} - 1}) \in {0 , 1}^{\nu}$ i.e., the Lagrange kernel at $r = (r_0, ..., r_{{\nu} - 1})$. + pub fn evaluations(&self) -> Vec { + compute_lagrange_basis_evals_at(&self.r) + } + + /// Returns the evaluations of + /// $((y_0, ..., y_{{\nu} - 1}) \mapsto \tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1})))$ + /// over ${0 , 1}^{\nu}$. + pub fn ml_at(evaluation_point: SmallVec<[E; MAX_EQ_SIZE]>) -> MultiLinearPoly { + let eq_evals = EqFunction::new(evaluation_point).evaluations(); + MultiLinearPoly::from_evaluations(eq_evals) + } +} + +// HELPER +// ================================================================================================ + +/// Computes the evaluations of the Lagrange basis polynomials over the interpolating +/// set ${0 , 1}^{\nu}$ at $(r_0, ..., r_{{\nu} - 1})$ i.e., the Lagrange kernel at $(r_0, ..., r_{{\nu} - 1})$. +/// +/// If `concurrent` feature is enabled, this function can make use of multi-threading. +/// +/// The implementation uses the memoization technique in Lemma 3.8 in [1]. More precisely, we can +/// build a table $A^{(\nu)}$ in $\nu$ steps using the following master equation: +/// +/// $$ +/// A^{(j)}\left[\left(w_{1}, \dots, w_{j} \right)\right] = +/// A^{(j - 1)}\left[\left(w_{1}, \dots, w_{j - 1} \right)\right] \times +/// \left(w_{j}\cdot r_{j} + (1 - w_{j})\cdot( 1 - r_{j}) \right) +/// $$ +/// +/// if we interpret $\left(w_{1}, \dots, w_{j} \right)$ in little endian i.e., +/// $\left(w_{1}, \dots, w_{j} \right) = \sum_{i=1}^{\nu} 2^{i - 1}\cdot w_{i}$. +/// +/// We thus have the following algorithm: +/// +/// 1. Split current table, stored as a vector, $A^{(j)}\left[\left(w_{1}, \dots, w_{j} \right)\right]$ +/// into two tables $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 0 \right)\right]$ and +/// $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 1 \right)\right]$, +/// with the first part initialized to $A^{(j - 1)}\left[\left(w_{1}, \dots, w_{j-1} \right)\right]$. +/// 2. Iterating over $\left(w_{1}, \dots, w_{j-1} \right)$, do: +/// 1. Let $factor = A^{(j - 1)}\left[\left(w_{1}, \dots, w_{j-1} \right)\right]$, which is equal +/// by the above to $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 0 \right)\right]$. +/// 2. $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 1 \right)\right] = factor \cdot r_j$ +/// 3. $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 0 \right)\right] = +/// A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 0 \right)\right] - +/// A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 1 \right)\right]$ +/// +/// Note that we can allocate from the start a vector of size $2^{\nu}$ in order to hold the final +/// as well as the intermediate tables. +/// +/// [1]: https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf +fn compute_lagrange_basis_evals_at(query: &[E]) -> Vec { + let n = 1 << query.len(); + let mut evals = unsafe { utils::uninit_vector(n) }; + + let mut size = 1; + evals[0] = E::ONE; + #[cfg(not(feature = "concurrent"))] + let evals = { + for r_i in query.iter() { + let (left_evals, right_evals) = evals.split_at_mut(size); + left_evals.iter_mut().zip(right_evals.iter_mut()).for_each(|(left, right)| { + let factor = *left; + *right = factor * *r_i; + *left -= *right; + }); + + size <<= 1; + } + evals + }; + + #[cfg(feature = "concurrent")] + let evals = { + for r_i in query.iter() { + let (left_evals, right_evals) = evals.split_at_mut(size); + left_evals + .par_iter_mut() + .zip(right_evals.par_iter_mut()) + .for_each(|(left, right)| { + let factor = *left; + *right = factor * *r_i; + *left -= *right; + }); + + size <<= 1; + } + evals + }; + + evals +} + +/// Computes the inner product in the extension field of two slices with the same number of items. +/// +/// If `concurrent` feature is enabled, this function can make use of multi-threading. +pub fn inner_product(x: &[E], y: &[E]) -> E { + #[cfg(not(feature = "concurrent"))] + return x.iter().zip(y.iter()).fold(E::ZERO, |acc, (x_i, y_i)| acc + *x_i * *y_i); + + #[cfg(feature = "concurrent")] + return x + .par_iter() + .zip(y.par_iter()) + .map(|(x_i, y_i)| *x_i * *y_i) + .reduce(|| E::ZERO, |a, b| a + b); +} + +// TESTS +// ================================================================================================ + +#[test] +fn multi_linear_sanity_checks() { + use math::fields::f64::BaseElement; + let nu = 3; + let n = 1 << nu; + + // the zero multi-linear should evaluate to zero + let p = MultiLinearPoly::from_evaluations(vec![BaseElement::ZERO; n]); + let challenge: Vec = rand_utils::rand_vector(nu); + + assert_eq!(BaseElement::ZERO, p.evaluate(&challenge)); + + // the constant multi-linear should be constant everywhere + let constant = rand_utils::rand_value(); + let p = MultiLinearPoly::from_evaluations(vec![constant; n]); + let challenge: Vec = rand_utils::rand_vector(nu); + + assert_eq!(constant, p.evaluate(&challenge)) +} + +#[test] +fn test_bind() { + use math::fields::f64::BaseElement; + let mut p = MultiLinearPoly::from_evaluations(vec![BaseElement::ONE; 8]); + let expected = MultiLinearPoly::from_evaluations(vec![BaseElement::ONE; 4]); + + let challenge = rand_utils::rand_value(); + p.bind_least_significant_variable(challenge); + assert_eq!(p, expected) +} + +#[test] +fn test_eq_function() { + use math::fields::f64::BaseElement; + use rand_utils::rand_value; + use smallvec::smallvec; + + let one = BaseElement::ONE; + + // Lagrange kernel is computed correctly + let r0 = rand_value(); + let r1 = rand_value(); + let eq_function = EqFunction::new(smallvec![r0, r1]); + + let expected = vec![(one - r0) * (one - r1), r0 * (one - r1), (one - r0) * r1, r0 * r1]; + + assert_eq!(expected, eq_function.evaluations()); + + // Lagrange kernel evaluation is correct + let q0 = rand_value(); + let q1 = rand_value(); + let tensored_query = vec![(one - q0) * (one - q1), q0 * (one - q1), (one - q0) * q1, q0 * q1]; + + let expected = inner_product(&tensored_query, &eq_function.evaluations()); + + assert_eq!(expected, eq_function.evaluate(&[q0, q1])) +} diff --git a/sumcheck/src/prover/error.rs b/sumcheck/src/prover/error.rs new file mode 100644 index 000000000..c86198d73 --- /dev/null +++ b/sumcheck/src/prover/error.rs @@ -0,0 +1,15 @@ +#[derive(Debug, thiserror::Error)] +pub enum SumCheckProverError { + #[error("number of rounds for sum-check must be greater than zero")] + NumRoundsZero, + #[error("the number of rounds is greater than the number of variables")] + TooManyRounds, + #[error("should provide at least one multi-linear polynomial as input")] + NoMlsProvided, + #[error("failed to generate round challenge")] + FailedToGenerateChallenge, + #[error("the provided multi-linears have different arities")] + MlesDifferentArities, + #[error("multi-linears should have at least one variable")] + AtLeastOneVariable, +} diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs new file mode 100644 index 000000000..da2021f63 --- /dev/null +++ b/sumcheck/src/prover/high_degree.rs @@ -0,0 +1,675 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; + +use air::{LogUpGkrEvaluator, PeriodicTable}; +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; + +use super::{compute_scaling_down_factors, to_coefficients, SumCheckProverError}; +use crate::{ + evaluate_composition_poly, EqFunction, FinalOpeningClaim, MultiLinearPoly, RoundProof, + SumCheckProof, SumCheckRoundClaim, +}; + +/// A sum-check prover for the input layer which can accommodate non-linear expressions in +/// the numerators of the LogUp relation. +/// +/// The LogUp-GKR protocol in [1] is an IOP for the following statement +/// +/// $$ +/// \sum_{v_i, x_i} \frac{p_n\left(v_1, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right)} +/// {q_n\left(v_1, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right)} = C +/// $$ +/// +/// where: +/// +/// $$ +/// p_n\left(v_1, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = +/// \sum_{w\in\{0, 1\}^\mu} EQ\left(\left(v_1, \cdots, v_{\mu}\right), +/// \left(w_1, \cdots, w_{\mu}\right)\right) +/// g_{[w]}\left(f_1\left(x_1, \cdots, x_{\nu}\right), +/// \cdots, f_l\left(x_1, \cdots, x_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// $$ +/// q_n\left(v_1, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = +/// \sum_{w\in\{0, 1\}^\mu} EQ\left(\left(v_1, \cdots, v_{\mu}\right), +/// \left(w_1, \cdots, w_{\mu}\right)\right) +/// h_{[w]}\left(f_1\left(x_1, \cdots, x_{\nu}\right), +/// \cdots, f_l\left(x_1, \cdots, x_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// 1. $f_i$ are multi-linears. +/// 2. ${[w]} := \sum_i w_i \cdot 2^i$ and $w := (w_1, \cdots, w_{\mu})$. +/// 3. $h_{j}$ and $g_{j}$ are multi-variate polynomials for $j = 0, \cdots, 2^{\mu} - 1$. +/// 4. $n := \nu + \mu$ +/// 5. $\\B_{\gamma} := \{0, 1\}^{\gamma}$ for positive integer $\gamma$. +/// +/// The sum above is evaluated using a layered circuit with the equation linking the input layer +/// values $p_n$ to the next layer values $p_{n-1}$ given by the following relations +/// +/// $$ +/// p_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = \sum_{w_i, y_i} +/// EQ\left(\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right), +/// \left(w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) +/// \cdot \left( p_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) + +/// p_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \cdot +/// q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// $$ +/// q_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = \sum_{w_i, y_i} +/// EQ\left(\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right), +/// \left(w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) +/// \cdot \left( q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) +/// $$ +/// +/// and similarly for all subsequent layers. +/// +/// By the properties of the $EQ$ function, we can write the above as follows: +/// +/// $$ +/// p_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = \sum_{y_i} +/// EQ\left(\left(x_1, \cdots, x_{\nu}\right), +/// \left(y_1, \cdots, y_{\nu}\right)\right) +/// \left( \sum_{w_i} EQ\left(\left(v_2, \cdots, v_{\mu}\right), +/// \left(w_2, \cdots, w_{\mu}\right)\right) +/// \cdot \left( p_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) + +/// p_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \cdot +/// q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) \right) +/// $$ +/// +/// and +/// +/// $$ +/// q_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = \sum_{y_i} +/// EQ\left(\left(x_1, \cdots, x_{\nu}\right), +/// \left(y_1, \cdots, y_{\nu}\right)\right) +/// \left( \sum_{w_i} EQ\left(\left(v_2, \cdots, v_{\mu}\right)\right) +/// \cdot q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \right) +/// $$ +/// +/// These expressions are nothing but the equations in Section 3.2 in [1] but with the projection +/// happening in the first argument instead of the last one. +/// The current function is then tasked with running a batched sum-check protocol for +/// +/// $$ +/// p_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = +/// \sum_{y\in\\B_{\nu}} G(y_{1}, ..., y_{\nu}) +/// $$ +/// +/// and +/// +/// $$ +/// q_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = +/// \sum_{y\in\\B_{\nu}} H\left(y_1, \cdots, y_{\nu} \right) +/// $$ +/// +/// where +/// +/// $$ +/// G := \left( \left(y_1, \cdots, y_{\nu}\right) \longrightarrow +/// EQ\left(\left(x_1, \cdots, x_{\nu}\right), +/// \left(y_1, \cdots, y_{\nu}\right)\right) +/// \left( \sum_{w_i} EQ\left(\left(v_2, \cdots, v_{\mu}\right), +/// \left(w_2, \cdots, w_{\mu}\right)\right) +/// \cdot \left( p_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) + +/// p_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \cdot +/// q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) \right) +/// \right) +/// $$ +/// +/// and +/// +/// $$ +/// H := \left( \left(y_1, \cdots, y_{\nu}\right) \longrightarrow +/// EQ\left(\left(x_1, \cdots, x_{\nu}\right), +/// \left(y_1, \cdots, y_{\nu}\right)\right) +/// \left( \sum_{w_i} EQ\left(\left(v_2, \cdots, v_{\mu}\right)\right) +/// \cdot q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \right) +/// \right) +/// $$ +/// +/// +/// We now discuss a further optimization due to [2]. Suppose that we have a sum-check statment of +/// the following form: +/// +/// $$v_0=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{\nu - 1}\right);\left( x_0, \cdots, x_{\nu - 1}\right)\right) +/// C\left( x_0, \cdots, x_{\nu - 1} \right)$$ +/// +/// Then during round $i + 1$ of sum-check, the prover needs to send the following polynomial +/// +/// $$v_{i+1}(X)=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1},\alpha_i, \alpha_{i+1},\cdots\alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// We can write $v_{i+1}(X)$ as: +/// +/// $$v_{i+1}(X)=Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1} \right);\left(r_0,\cdots,r_{i-1}\right)\right) +/// \cdot Eq\left(\alpha_i ;X\right)\sum_{x}Eq\left(\left(\alpha_{i+1},\cdots\alpha_{\nu - 1}\right);\left( x_{i+1}, \cdots x_{\nu - 1}\right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// This means that $v_{i+1}(X)$ is the product of: +/// +/// 1. A constant polynomial: $Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right);\left( r_0, \cdots, r_{i-1} \right) \right)$ +/// 2. A linear polynomial: $Eq\left( \alpha_i ; X \right)$ +/// 3. A high degree polynomial: $\sum_{x} +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$ +/// +/// The advantage of the above decomposition is that the prover when computing $v_{i+1}(X)$ needs to sum over +/// +/// $$ +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// instead of +/// +/// $$ +/// Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1}, \alpha_i, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// which has the advantage of being of degree $1$ less and hence requires less work on the part of the prover. +/// +/// Thus, the prover computes the following polynomial +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and then scales it in order to get +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right) \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// As the prover computes $v_{i+1}^{'}(X)$ in evaluation form and hence also $v_{i+1}(X)$, this +/// means that due to the degrees being off by $1$, the prover uses the linear factor in order to +/// obtain an additional evaluation point in order to be able to interpolate $v_{i+1}(X)$. +/// More precisely, we can get a root of $$v_{i+1}(X) = 0$$ by solving $$Eq\left( \alpha_i ; X \right) = 0$$ +/// The latter equation has as solution $$\mathsf{r} = \frac{1 - \alpha}{1 - 2\cdot\alpha}$$ +/// which is, except with negligible probability, an evaluation point not in the original +/// evaluation set and hence the prover is able to interpolate $v_{i+1}(X)$ and send it to +/// the verifier. +/// +/// Note that in order to avoid having to compute $\{Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ from $\{Eq\left( \left( \alpha_{i}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i}, \cdots x_{\nu - 1} \right) \right)\}$, or vice versa, we can write +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// as +/// +/// $$v_{i+1}^{'}(X) = \frac{1}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \sum_{x} +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// Thus, $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ can be read from +/// $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{\nu - 1} \right);\left(x_{0}, \cdots x_{\nu - 1} \right) \right)\}$ +/// directly, at the cost of the relation between $v_{i+1}^{'}(X)$ and $v_{i+1}(X)$ becoming +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// +/// [1]: https://eprint.iacr.org/2023/1284 +/// [2]: https://eprint.iacr.org/2024/108 +#[allow(clippy::too_many_arguments)] +pub fn sum_check_prove_higher_degree< + E: FieldElement, + H: ElementHasher, +>( + evaluator: &impl LogUpGkrEvaluator::BaseField>, + evaluation_point: Vec, + claim: E, + r_sum_check: E, + log_up_randomness: Vec, + mut mls: Vec>, + mut periodic_table: PeriodicTable, + coin: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let num_rounds = mls[0].num_variables(); + + let mut round_proofs = vec![]; + + // split the evaluation point into two points of dimension mu and nu, respectively + let mu = evaluator.get_num_fractions().trailing_zeros() - 1; + let (evaluation_point_mu, evaluation_point_nu) = evaluation_point.split_at(mu as usize); + let eq_mu = EqFunction::ml_at(evaluation_point_mu.into()).evaluations().to_vec(); + let eq_nu = EqFunction::ml_at(evaluation_point_nu.into()); + + // setup first round claim + let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; + + // run the first round of the protocol + let mut round_poly_evals = sumcheck_round( + 0, + &eq_mu, + evaluator, + &eq_nu, + &mls, + &periodic_table, + &log_up_randomness, + r_sum_check, + ); + + // this will hold `Eq((\alpha_0, \cdots, \alpha_{i - 1});(r_0, \cdots, r_{i-1}))` + let mut scaling_up_factor = E::ONE; + // this will hold `Eq((\alpha_{0}, \cdots, \alpha_{i}); (0, \cdots, 0))` for all `i` + let scaling_down_factors = compute_scaling_down_factors(evaluation_point_nu); + // this is `\alpha_i` above + let mut alpha_i = evaluation_point_nu[0]; + let scaling_down_factor = scaling_down_factors[0]; + let round_poly_coefs = to_coefficients( + &mut round_poly_evals, + current_round_claim.claim, + alpha_i, + scaling_down_factor, + scaling_up_factor, + ); + + // reseed with the s_0 polynomial + coin.reseed(H::hash_elements(&round_poly_coefs.0)); + round_proofs.push(RoundProof { round_poly_coefs }); + + for i in 1..num_rounds { + // generate random challenge r_i for the i-th round + let round_challenge = + coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // update `scaling_up_factor` + alpha_i = evaluation_point_nu[evaluation_point_nu.len() - mls[0].num_variables()]; + scaling_up_factor *= + round_challenge * alpha_i + (E::ONE - round_challenge) * (E::ONE - alpha_i); + + // compute the new reduced round claim + let new_round_claim = + reduce_claim(&round_proofs[i - 1], current_round_claim, round_challenge); + + // fold each multi-linear using the round challenge + mls.iter_mut() + .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); + + // fold each periodic multi-linear using the round challenge + periodic_table.bind_least_significant_variable(round_challenge); + + // run the i-th round of the protocol using the folded multi-linears for the new reduced + // claim. This basically computes the s_i polynomial. + let mut round_poly_evals = sumcheck_round( + i, + &eq_mu, + evaluator, + &eq_nu, + &mls, + &periodic_table, + &log_up_randomness, + r_sum_check, + ); + + // update the claim + current_round_claim = new_round_claim; + + let alpha_i = evaluation_point_nu[i]; + let round_poly_coefs = to_coefficients( + &mut round_poly_evals, + current_round_claim.claim, + alpha_i, + scaling_down_factors[i], + scaling_up_factor, + ); + + // reseed with the s_i polynomial + coin.reseed(H::hash_elements(&round_poly_coefs.0)); + let round_proof = RoundProof { round_poly_coefs }; + round_proofs.push(round_proof); + } + + // generate the last random challenge + let round_challenge = + coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // fold each multi-linear using the last random round challenge + mls.iter_mut() + .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); + + let SumCheckRoundClaim { eval_point, claim: _claim } = + reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); + + let openings = mls.iter_mut().map(|ml| ml.evaluations()[0]).collect(); + + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { eval_point, openings }, + round_proofs, + }) +} + +/// Computes the polynomial +/// +/// $$ +/// s_i(X_i) := \sum_{(x_{i + 1},\cdots, x_{\nu - 1}) +/// w(r_0,\cdots, r_{i - 1}, X_i, x_{i + 1}, \cdots, x_{\nu - 1}). +/// $$ +/// +/// where +/// +/// $$ +/// w(x_0,\cdots, x_{\nu - 1}) := g(f_0((x_0,\cdots, x_{\nu - 1})), +/// \cdots , f_c((x_0,\cdots, x_{\nu - 1}))). +/// $$ +/// +/// where `g` is the expression defined in the documentation of [`sum_check_prove_higher_degree`] +/// +/// Given a degree bound `d_max` for all variables, it suffices to compute the evaluations of `s_i` +/// at `d_max + 1` points. Given that $s_{i}(0) = s_{i}(1) - s_{i - 1}(r_{i - 1})$ it is sufficient +/// to compute the evaluations on only `d_max` points. +/// +/// The algorithm works by iterating over the variables $(x_{i + 1}, \cdots, x_{\nu - 1})$ in +/// ${0, 1}^{\nu - 1 - i}$. For each such tuple, we store the evaluations of the (folded) +/// multi-linears at $(0, x_{i + 1}, \cdots, x_{\nu - 1})$ and +/// $(1, x_{i + 1}, \cdots, x_{\nu - 1})$ in two arrays, `evals_zero` and `evals_one`. +/// Using `evals_one`, remember that we optimize evaluating at 0 away, we get the first evaluation +/// i.e., $s_i(1)$. +/// +/// For the remaining evaluations, we use the fact that the folded `f_i` is multi-linear and hence +/// we can write +/// +/// $$ +/// f_i(X_i, x_{i + 1}, \cdots, x_{\nu - 1}) = +/// (1 - X_i) . f_i(0, x_{i + 1}, \cdots, x_{\nu - 1}) + +/// X_i . f_i(1, x_{i + 1}, \cdots, x_{\nu - 1}) +/// $$ +/// +/// Note that we omitted writing the folding randomness for readability. +/// Since the evaluation domain is $\{0, 1, ... , d_max\}$, we can compute the evaluations based on +/// the previous one using only additions. This is the purpose of `deltas`, to hold the increments +/// added to each multi-linear to compute the evaluation at the next point, and `evals_x` to hold +/// the current evaluation at $x$ in $\{2, ... , d_max\}$. +#[allow(clippy::too_many_arguments)] +fn sumcheck_round( + sum_check_round: usize, + eq_mu: &[E], + evaluator: &impl LogUpGkrEvaluator::BaseField>, + eq_ml: &MultiLinearPoly, + mls: &[MultiLinearPoly], + periodic_table: &PeriodicTable, + log_up_randomness: &[E], + r_sum_check: E, +) -> Vec { + let num_mls = mls.len(); + let num_periodic = periodic_table.num_columns(); + let num_vars = mls[0].num_variables(); + let num_rounds = num_vars - 1; + + #[cfg(not(feature = "concurrent"))] + let evaluations = { + let mut evals_one = vec![E::ZERO; num_mls]; + let mut evals_zero = vec![E::ZERO; num_mls]; + let mut evals_x = vec![E::ZERO; num_mls]; + + let mut evals_periodic_one = vec![E::ZERO; num_periodic]; + let mut evals_periodic_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_x = vec![E::ZERO; num_periodic]; + + let mut deltas = vec![E::ZERO; num_mls]; + let mut deltas_periodic = vec![E::ZERO; num_periodic]; + + let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; + (0..1 << num_rounds) + .map(|i| { + let mut poly_evals = vec![E::ZERO; evaluator.max_degree() - 1]; + + for (j, ml) in mls.iter().enumerate() { + evals_zero[j] = ml.evaluations()[2 * i]; + evals_one[j] = ml.evaluations()[2 * i + 1]; + } + + // add evaluation of periodic columns + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + + // `(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1})` + let j = i << (sum_check_round + 1); + // `Eq((\alpha_{0}, \cdots, \alpha_{\nu - 1}); (0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1})) ` + let eq_at_zero = eq_ml.evaluations()[j]; + + // compute the evaluation at 0 + evaluator.evaluate_query( + &evals_zero, + &evals_periodic_zero, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + poly_evals[0] = evaluate_composition_poly( + eq_mu, + &numerators, + &denominators, + eq_at_zero, + r_sum_check, + ); + + // compute the evaluations at `2, ..., d_max - 1` points + for i in 0..num_mls { + deltas[i] = evals_one[i] - evals_zero[i]; + evals_x[i] = evals_one[i]; + } + for i in 0..num_periodic { + deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; + evals_periodic_x[i] = evals_periodic_one[i]; + } + + for e in poly_evals.iter_mut().skip(1) { + evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); + + evaluator.evaluate_query( + &evals_x, + &evals_periodic_x, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + *e = evaluate_composition_poly( + eq_mu, + &numerators, + &denominators, + eq_at_zero, + r_sum_check, + ); + } + + poly_evals + }) + .fold(vec![E::ZERO; evaluator.max_degree() - 1], |mut acc, poly_eval| { + acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { + *a += *b; + }); + acc + }) + }; + + #[cfg(feature = "concurrent")] + let evaluations = (0..1 << num_rounds) + .into_par_iter() + .fold( + || { + ( + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_periodic], + vec![E::ZERO; evaluator.max_degree() - 1], + ) + }, + |( + mut evals_zero, + mut evals_one, + mut evals_x, + mut evals_periodic_zero, + mut evals_periodic_one, + mut evals_periodic_x, + mut numerators, + mut denominators, + mut deltas, + mut deltas_periodic, + mut poly_evals, + ), + i| { + for (j, ml) in mls.iter().enumerate() { + evals_zero[j] = ml.evaluations()[2 * i]; + evals_one[j] = ml.evaluations()[2 * i + 1]; + } + + // add evaluation of periodic columns + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + + // `(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1})` + let j = i << (sum_check_round + 1); + // `Eq((\alpha_{0}, \cdots, \alpha_{\nu - 1}); (0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1})) ` + let eq_at_zero = eq_ml.evaluations()[j]; + + // compute the evaluation at 0 + evaluator.evaluate_query( + &evals_zero, + &evals_periodic_zero, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + poly_evals[0] += evaluate_composition_poly( + eq_mu, + &numerators, + &denominators, + eq_at_zero, + r_sum_check, + ); + + // compute the evaluations at `2, ..., d_max - 1` points + for i in 0..num_mls { + deltas[i] = evals_one[i] - evals_zero[i]; + evals_x[i] = evals_one[i]; + } + for i in 0..num_periodic { + deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; + evals_periodic_x[i] = evals_periodic_one[i]; + } + + for e in poly_evals.iter_mut().skip(1) { + evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); + + evaluator.evaluate_query( + &evals_x, + &evals_periodic_x, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + *e += evaluate_composition_poly( + eq_mu, + &numerators, + &denominators, + eq_at_zero, + r_sum_check, + ); + } + + ( + evals_zero, + evals_one, + evals_x, + evals_periodic_zero, + evals_periodic_one, + evals_periodic_x, + numerators, + denominators, + deltas, + deltas_periodic, + poly_evals, + ) + }, + ) + .map(|(.., poly_evals)| poly_evals) + .reduce( + || vec![E::ZERO; evaluator.max_degree() - 1], + |mut acc, poly_eval| { + acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { + *a += *b; + }); + acc + }, + ); + + evaluations +} + +/// Reduces an old claim to a new claim using the round challenge. +fn reduce_claim( + current_poly: &RoundProof, + current_round_claim: SumCheckRoundClaim, + round_challenge: E, +) -> SumCheckRoundClaim { + // evaluate the round polynomial at the round challenge to obtain the new claim + let new_claim = current_poly + .round_poly_coefs + .evaluate_using_claim(¤t_round_claim.claim, &round_challenge); + + // update the evaluation point using the round challenge + let mut new_partial_eval_point = current_round_claim.eval_point; + new_partial_eval_point.push(round_challenge); + + SumCheckRoundClaim { + eval_point: new_partial_eval_point, + claim: new_claim, + } +} diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs new file mode 100644 index 000000000..66a110ee6 --- /dev/null +++ b/sumcheck/src/prover/mod.rs @@ -0,0 +1,80 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +mod high_degree; +use alloc::vec::Vec; + +pub use high_degree::sum_check_prove_higher_degree; + +mod plain; +use math::{batch_inversion, FieldElement}; +pub use plain::sumcheck_prove_plain; + +mod error; +pub use error::SumCheckProverError; + +use crate::CompressedUnivariatePoly; + +/// Takes the evaluation of the polynomial $v_{i+1}^{'}(X)$ defined by +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and computes the interpolation of the $v_{i+1}(X)$ polynomial defined by +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// The function returns a `CompressedUnivariatePoly` instead of the full list of coefficients. +fn to_coefficients( + round_poly_evals: &mut [E], + claim: E, + alpha: E, + scaling_down_factor: E, + scaling_up_factor: E, +) -> CompressedUnivariatePoly { + let a = scaling_down_factor; + round_poly_evals.iter_mut().for_each(|e| *e *= scaling_up_factor); + + let mut round_poly_evaluations = Vec::with_capacity(round_poly_evals.len() + 1); + round_poly_evaluations.push(round_poly_evals[0] * compute_weight(alpha, E::ZERO) * a); + round_poly_evaluations.push(claim - round_poly_evaluations[0]); + + for (x, eval) in round_poly_evals.iter().skip(1).enumerate() { + round_poly_evaluations.push(*eval * compute_weight(alpha, E::from(x as u32 + 2)) * a) + } + + let root = (E::ONE - alpha) / (E::ONE - alpha.double()); + + CompressedUnivariatePoly::interpolate_equidistant_points(&round_poly_evaluations, root) +} + +/// Computes +/// +/// $$ +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right) +/// $$ +/// +/// given $(\alpha_0, \cdots, \alpha_{\nu - 1})$ for all $i$ in $0, \cdots, \nu - 1$. +fn compute_scaling_down_factors(gkr_point: &[E]) -> Vec { + let cumulative_product: Vec = gkr_point + .iter() + .scan(E::ONE, |acc, &x| { + *acc *= E::ONE - x; + Some(*acc) + }) + .collect(); + batch_inversion(&cumulative_product) +} + +/// Computes $EQ(x; \alpha)$. +fn compute_weight(alpha: E, x: E) -> E { + x * alpha + (E::ONE - x) * (E::ONE - alpha) +} diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs new file mode 100644 index 000000000..8e4766b6a --- /dev/null +++ b/sumcheck/src/prover/plain.rs @@ -0,0 +1,270 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; + +use super::{compute_scaling_down_factors, to_coefficients, SumCheckProverError}; +use crate::{comb_func, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof}; + +/// Sum-check prover for non-linear multivariate polynomial of the simple LogUp-GKR. +/// +/// More specifically, the following function implements the logic of the sum-check prover as +/// described in Section 3.2 in [1], that is, given verifier challenges , the following implements +/// the sum-check prover for the following two statements +/// $$ +/// p_{\nu - \kappa}\left(v_{\kappa+1}, \cdots, v_{\nu}\right) = \sum_{w_i} +/// EQ\left(\left(v_{\kappa+1}, \cdots, v_{\nu}\right), \left(w_{\kappa+1}, \cdots, +/// w_{\nu}\right)\right) \cdot +/// \left( p_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right) + +/// p_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// $$ +/// q_{\nu -k}\left(v_{\kappa+1}, \cdots, v_{\nu}\right) = \sum_{w_i}EQ\left(\left(v_{\kappa+1}, +/// \cdots, v_{\nu}\right), \left(w_{\kappa+1}, \cdots, w_{\nu }\right)\right) \cdot +/// \left( q_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right)\right) +/// $$ +/// +/// for $k = 1, \cdots, \nu - 1$ +/// +/// Instead of executing two runs of the sum-check protocol, a batching randomness `r_batch` is +/// sent by the verifier at the outset in order to batch the two statments. +/// +/// Note that the degree of the non-linear composition polynomial is 3. +/// +/// +/// We now discuss a further optimization due to [2]. Suppose that we have a sum-check statment of +/// the following form: +/// +/// $$v_0=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{\nu - 1}\right);\left( x_0, \cdots, x_{\nu - 1}\right)\right) +/// C\left( x_0, \cdots, x_{\nu - 1} \right)$$ +/// +/// Then during round $i + 1$ of sum-check, the prover needs to send the following polynomial +/// +/// $$v_{i+1}(X)=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1},\alpha_i, \alpha_{i+1},\cdots\alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// We can write $v_{i+1}(X)$ as: +/// +/// $$v_{i+1}(X)=Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1} \right);\left(r_0,\cdots,r_{i-1}\right)\right) +/// \cdot Eq\left(\alpha_i ;X\right)\sum_{x}Eq\left(\left(\alpha_{i+1},\cdots\alpha_{\nu - 1}\right);\left( x_{i+1}, \cdots x_{\nu - 1}\right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// This means that $v_{i+1}(X)$ is the product of: +/// +/// 1. A constant polynomial: $Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right);\left( r_0, \cdots, r_{i-1} \right) \right)$ +/// 2. A linear polynomial: $Eq\left( \alpha_i ; X \right)$ +/// 3. A high degree polynomial: $\sum_{x} +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$ +/// +/// The advantage of the above decomposition is that the prover when computing $v_{i+1}(X)$ needs to sum over +/// +/// $$ +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// instead of +/// +/// $$ +/// Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1}, \alpha_i, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// which has the advantage of being of degree $1$ less and hence requires less work on the part of the prover. +/// +/// Thus, the prover computes the following polynomial +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and then scales it in order to get +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right) \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// As the prover computes $v_{i+1}^{'}(X)$ in evaluation form and hence also $v_{i+1}(X)$, this +/// means that due to the degrees being off by $1$, the prover uses the linear factor in order to +/// obtain an additional evaluation point in order to be able to interpolate $v_{i+1}(X)$. +/// More precisely, we can get a root of $$v_{i+1}(X) = 0$$ by solving $$Eq\left( \alpha_i ; X \right) = 0$$ +/// The latter equation has as solution $$\mathsf{r} = \frac{1 - \alpha}{1 - 2\cdot\alpha}$$ +/// which is, except with negligible probability, an evaluation point not in the original +/// evaluation set and hence the prover is able to interpolate $v_{i+1}(X)$ and send it to +/// the verifier. +/// +/// Note that in order to avoid having to compute $\{Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ from $\{Eq\left( \left( \alpha_{i}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i}, \cdots x_{\nu - 1} \right) \right)\}$, or vice versa, we can write +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// as +/// +/// $$v_{i+1}^{'}(X) = \frac{1}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \sum_{x} +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// Thus, $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ can be read from +/// $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{\nu - 1} \right);\left(x_{0}, \cdots x_{\nu - 1} \right) \right)\}$ +/// directly, at the cost of the relation between $v_{i+1}^{'}(X)$ and $v_{i+1}(X)$ becoming +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// +/// [1]: https://eprint.iacr.org/2023/1284 +/// [2]: https://eprint.iacr.org/2024/108 +#[allow(clippy::too_many_arguments)] +pub fn sumcheck_prove_plain>( + mut claim: E, + gkr_point: &[E], + r_batch: E, + p: MultiLinearPoly, + q: MultiLinearPoly, + eq: &mut MultiLinearPoly, + transcript: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let mut round_proofs = vec![]; + + let mut challenges = vec![]; + + // construct the vector of multi-linear polynomials + let (mut p0, mut p1) = p.project_least_significant_variable(); + let (mut q0, mut q1) = q.project_least_significant_variable(); + + let num_rounds = p0.num_variables(); + + let scaling_down_factors = compute_scaling_down_factors(gkr_point); + let mut scaling_up_factor = E::ONE; + + for i in 0..num_rounds { + let len = p0.num_evaluations() / 2; + + #[cfg(not(feature = "concurrent"))] + let (round_poly_eval_at_0, round_poly_eval_at_2) = + (0..len).fold((E::ZERO, E::ZERO), |(acc_point_0, acc_point_2), k| { + let j = k << (i + 1); + let round_poly_eval_at_0 = + comb_func(p0[2 * k], p1[2 * k], q0[2 * k], q1[2 * k], eq[j], r_batch); + + let p0_delta = p0[2 * k + 1] - p0[2 * k]; + let p1_delta = p1[2 * k + 1] - p1[2 * k]; + let q0_delta = q0[2 * k + 1] - q0[2 * k]; + let q1_delta = q1[2 * k + 1] - q1[2 * k]; + + let p0_eval_at_x = p0[2 * k + 1] + p0_delta; + let p1_eval_at_x = p1[2 * k + 1] + p1_delta; + let q0_eval_at_x = q0[2 * k + 1] + q0_delta; + let q1_eval_at_x = q1[2 * k + 1] + q1_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq[j], + r_batch, + ); + + (round_poly_eval_at_0 + acc_point_0, round_poly_eval_at_2 + acc_point_2) + }); + + #[cfg(feature = "concurrent")] + let (round_poly_eval_at_0, round_poly_eval_at_2) = (0..len) + .into_par_iter() + .fold( + || (E::ZERO, E::ZERO), + |(a, b), k| { + let j = k << (i + 1); + let round_poly_eval_at_0 = + comb_func(p0[2 * k], p1[2 * k], q0[2 * k], q1[2 * k], eq[j], r_batch); + + let p0_delta = p0[2 * k + 1] - p0[2 * k]; + let p1_delta = p1[2 * k + 1] - p1[2 * k]; + let q0_delta = q0[2 * k + 1] - q0[2 * k]; + let q1_delta = q1[2 * k + 1] - q1[2 * k]; + + let p0_eval_at_x = p0[2 * k + 1] + p0_delta; + let p1_eval_at_x = p1[2 * k + 1] + p1_delta; + let q0_eval_at_x = q0[2 * k + 1] + q0_delta; + let q1_eval_at_x = q1[2 * k + 1] + q1_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq[j], + r_batch, + ); + + (round_poly_eval_at_0 + a, round_poly_eval_at_2 + b) + }, + ) + .reduce(|| (E::ZERO, E::ZERO), |(a0, b0), (a1, b1)| (a0 + a1, b0 + b1)); + + let alpha_i = gkr_point[i]; + let compressed_round_poly = to_coefficients( + &mut [round_poly_eval_at_0, round_poly_eval_at_2], + claim, + alpha_i, + scaling_down_factors[i], + scaling_up_factor, + ); + + // reseed with the s_i polynomial + transcript.reseed(H::hash_elements(&compressed_round_poly.0)); + let round_proof = RoundProof { + round_poly_coefs: compressed_round_poly.clone(), + }; + + let round_challenge = + transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // fold each multi-linear using the round challenge + p0.bind_least_significant_variable(round_challenge); + p1.bind_least_significant_variable(round_challenge); + q0.bind_least_significant_variable(round_challenge); + q1.bind_least_significant_variable(round_challenge); + + // update the scaling up factor + scaling_up_factor *= + round_challenge * alpha_i + (E::ONE - round_challenge) * (E::ONE - alpha_i); + + // compute the new reduced round claim + claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); + + round_proofs.push(round_proof); + challenges.push(round_challenge); + } + + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { + eval_point: challenges, + openings: vec![p0[0], p1[0], q0[0], q1[0]], + }, + round_proofs, + }) +} diff --git a/sumcheck/src/univariate.rs b/sumcheck/src/univariate.rs new file mode 100644 index 000000000..ebc6dfa47 --- /dev/null +++ b/sumcheck/src/univariate.rs @@ -0,0 +1,278 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; + +use math::{batch_inversion, polynom, FieldElement}; +use smallvec::SmallVec; +use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +// CONSTANTS +// ================================================================================================ + +/// Maximum expected size of the round polynomials. This is needed for `SmallVec`. The size of +/// the round polynomials is dictated by the degree of the non-linearity in the sum-check statement +/// which is direcly influenced by the maximal degrees of the numerators and denominators appearing +/// in the LogUp-GKR relation and equal to one plus the maximal degree of the numerators and +/// maximal degree of denominators. +/// The following value assumes that this degree is at most 10. +const MAX_POLY_SIZE: usize = 10; + +// COMPRESSED UNIVARIATE POLYNOMIAL +// ================================================================================================ + +/// The coefficients of a univariate polynomial of degree n with the linear term coefficient +/// omitted. +/// +/// This compressed representation is useful during the sum-check protocol as the full uncompressed +/// representation can be recovered from the compressed one and the current sum-check round claim. +#[derive(Clone, Debug, PartialEq)] +pub struct CompressedUnivariatePoly(pub(crate) SmallVec<[E; MAX_POLY_SIZE]>); + +impl CompressedUnivariatePoly { + /// Evaluates a polynomial at a challenge point using a round claim. + /// + /// The round claim is used to recover the coefficient of the linear term using the relation + /// 2 * c0 + c1 + ... c_{n - 1} = claim. Using the complete list of coefficients, the polynomial + /// is then evaluated using Horner's method. + pub fn evaluate_using_claim(&self, claim: &E, challenge: &E) -> E { + // recover the coefficient of the linear term + let c1 = *claim - self.0.iter().fold(E::ZERO, |acc, term| acc + *term) - self.0[0]; + + // construct the full coefficient list + let mut complete_coefficients = vec![self.0[0], c1]; + complete_coefficients.extend_from_slice(&self.0[1..]); + + // evaluate + polynom::eval(&complete_coefficients, *challenge) + } + + /// Given the evaluations of a polynomial over the set $0, 1, \cdots, d - 1$ and a `root` not in + /// the interpolation set, computes its coefficients. + pub fn interpolate_equidistant_points(ys: &[E], root: E) -> CompressedUnivariatePoly { + // we factor out the term `(x - r)` where `r` is the root + let quotient: Vec = (0..ys.len()).map(|i| E::from(i as u32) - root).collect(); + let quotient_inv = batch_inversion("ient); + let mut ys: Vec = ys.iter().zip(quotient_inv.iter()).map(|(&y, &q)| y * q).collect(); + + // the zeroth coefficient can be recovered immediately + let c0 = ys.remove(0); + + // build the interpolation set + let n_minus_1 = ys.len(); + let points = (1..=n_minus_1 as u32).map(E::BaseField::from).collect::>(); + + // construct their inverses. These will be needed for computing the evaluations + // of the q polynomial as well as for doing the interpolation on q where q is + // defined as $p(x) = c0 + x * q(x) where q(x) = c1 + ... + c_{n-1} * x^{n - 2}$ + let points_inv = batch_inversion(&points); + + // compute the evaluations of q + let q_evals: Vec = ys + .iter() + .enumerate() + .map(|(i, evals)| (*evals - c0).mul_base(points_inv[i])) + .collect(); + + // interpolate q + let q_coefs = multiply_by_inverse_vandermonde(&q_evals, &points_inv); + + // append c0 to the coefficients of q to get the coefficients of p. The linear term + // coefficient is removed as this can be recovered from the other coefficients using + // the reduced claim. + let mut coefficients = SmallVec::<[E; MAX_POLY_SIZE]>::with_capacity(ys.len() + 1); + coefficients.push(c0); + coefficients.extend_from_slice(&q_coefs[..]); + + // multiply back the factor `(x - r)` + let mut p_coefficients = polynom::mul(&coefficients, &[-root, E::ONE]); + + // remove the linear factor as it can be recovered from the `claim` and the other factors + p_coefficients.remove(1); + + CompressedUnivariatePoly(p_coefficients.into()) + } +} + +impl Serializable for CompressedUnivariatePoly { + fn write_into(&self, target: &mut W) { + let vector: Vec = self.0.clone().into_vec(); + vector.write_into(target); + } +} + +impl Deserializable for CompressedUnivariatePoly +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + let vector: Vec = Vec::::read_from(source)?; + Ok(Self(vector.into())) + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Given a (row) vector `v`, computes the vector-matrix product `v * V^{-1}` where `V` is +/// the Vandermonde matrix over the points `1, ..., n` where `n` is the length of `v`. +/// The resulting vector will then be the coefficients of the minimal interpolating polynomial +/// through the points `(i+1, v[i])` for `i` in `0, ..., n - 1` +/// +/// The naive way would be to invert the matrix `V` and then compute the vector-matrix product +/// this will cost `O(n^3)` operations and `O(n^2)` memory. We can also try Gaussian elimination +/// but this is also worst case `O(n^3)` operations and `O(n^2)` memory. +/// In the following implementation, we use the fact that the points over which we are interpolating +/// is a set of equidistant points and thus both the Vandermonde matrix and its inverse can be +/// described by sparse linear recurrence equations. +/// More specifically, we use the representation given in [1], where `V^{-1}` is represented as +/// `U * M` where: +/// +/// 1. `M` is a lower triangular matrix where its entries are given by M(i, j) = M(i - 1, j) - M(i - +/// 1, j - 1) / (i - 1) with boundary conditions M(i, 1) = 1 and M(i, j) = 0 when j > i. +/// +/// 2. `U` is an upper triangular (involutory) matrix where its entries are given by U(i, j) = U(i, +/// j - 1) - U(i - 1, j - 1) with boundary condition U(1, j) = 1 and U(i, j) = 0 when i > j. +/// +/// Note that the matrix indexing in the formulas above matches the one in the reference and starts +/// from 1. +/// +/// The above implies that we can do the vector-matrix multiplication in `O(n^2)` and using only +/// `O(n)` space. +/// +/// [1]: https://link.springer.com/article/10.1007/s002110050360 +fn multiply_by_inverse_vandermonde( + vector: &[E], + nodes_inv: &[E::BaseField], +) -> Vec { + let res = multiply_by_u(vector); + multiply_by_m(&res, nodes_inv) +} + +/// Multiplies a (row) vector `v` by an upper triangular matrix `U` to compute `v * U`. +/// +/// `U` is an upper triangular (involutory) matrix with its entries given by +/// U(i, j) = U(i, j - 1) - U(i - 1, j - 1) +/// with boundary condition U(1, j) = 1 and U(i, j) = 0 when i > j. +fn multiply_by_u(vector: &[E]) -> Vec { + let n = vector.len(); + let mut previous_u_col = vec![E::BaseField::ZERO; n]; + previous_u_col[0] = E::BaseField::ONE; + let mut current_u_col = vec![E::BaseField::ZERO; n]; + current_u_col[0] = E::BaseField::ONE; + + let mut result: Vec = vec![E::ZERO; n]; + for (i, res) in result.iter_mut().enumerate() { + *res = vector[0]; + + for (j, v) in vector.iter().enumerate().take(i + 1).skip(1) { + let u_entry: E::BaseField = + compute_u_entry::(j, &mut previous_u_col, &mut current_u_col); + *res += v.mul_base(u_entry); + } + previous_u_col.clone_from(¤t_u_col); + } + + result +} + +/// Multiplies a (row) vector `v` by a lower triangular matrix `M` to compute `v * M`. +/// +/// `M` is a lower triangular matrix with its entries given by +/// M(i, j) = M(i - 1, j) - M(i - 1, j - 1) / (i - 1) +/// with boundary conditions M(i, 1) = 1 and M(i, j) = 0 when j > i. +fn multiply_by_m(vector: &[E], nodes_inv: &[E::BaseField]) -> Vec { + let n = vector.len(); + let mut previous_m_col = vec![E::BaseField::ONE; n]; + let mut current_m_col = vec![E::BaseField::ZERO; n]; + current_m_col[0] = E::BaseField::ONE; + + let mut result: Vec = vec![E::ZERO; n]; + result[0] = vector.iter().fold(E::ZERO, |acc, term| acc + *term); + for (i, res) in result.iter_mut().enumerate().skip(1) { + current_m_col = vec![E::BaseField::ZERO; n]; + + for (j, v) in vector.iter().enumerate().skip(i) { + let m_entry: E::BaseField = + compute_m_entry::(j, &mut previous_m_col, &mut current_m_col, nodes_inv[j - 1]); + *res += v.mul_base(m_entry); + } + previous_m_col.clone_from(¤t_m_col); + } + + result +} + +/// Returns the j-th entry of the i-th column of matrix `U` given the values of the (i - 1)-th +/// column. The i-th column is also updated with the just computed `U(i, j)` entry. +/// +/// `U` is an upper triangular (involutory) matrix with its entries given by +/// U(i, j) = U(i, j - 1) - U(i - 1, j - 1) +/// with boundary condition U(1, j) = 1 and U(i, j) = 0 when i > j. +fn compute_u_entry( + j: usize, + col_prev: &mut [E::BaseField], + col_cur: &mut [E::BaseField], +) -> E::BaseField { + let value = col_prev[j] - col_prev[j - 1]; + col_cur[j] = value; + value +} + +/// Returns the j-th entry of the i-th column of matrix `M` given the values of the (i - 1)-th +/// and the i-th columns. The i-th column is also updated with the just computed `M(i, j)` entry. +/// +/// `M` is a lower triangular matrix with its entries given by +/// M(i, j) = M(i - 1, j) - M(i - 1, j - 1) / (i - 1) +/// with boundary conditions M(i, 1) = 1 and M(i, j) = 0 when j > i. +fn compute_m_entry( + j: usize, + col_previous: &mut [E::BaseField], + col_current: &mut [E::BaseField], + node_inv: E::BaseField, +) -> E::BaseField { + let value = col_current[j - 1] - node_inv * col_previous[j - 1]; + col_current[j] = value; + value +} + +// TESTS +// ================================================================================================ + +#[test] +fn test_poly_partial() { + use math::fields::f64::BaseElement; + + let degree = 1000; + + // compute the claim + let p: Vec = rand_utils::rand_vector(degree); + let evals = polynom::eval_many(&p, &[BaseElement::ZERO, BaseElement::ONE]); + let claim = evals[0] + evals[1]; + + // build compressed polynomial + let mut poly_coeff = p.clone(); + poly_coeff.remove(1); + let poly_coeff = CompressedUnivariatePoly(poly_coeff.into()); + + // generate random challenge + let r = rand_utils::rand_vector(1); + + assert_eq!(polynom::eval(&p, r[0]), poly_coeff.evaluate_using_claim(&claim, &r[0])) +} + +#[test] +fn test_serialization() { + use math::fields::f64::BaseElement; + + let original_poly = + CompressedUnivariatePoly(rand_utils::rand_array::().into()); + let poly_bytes = original_poly.to_bytes(); + + let deserialized_poly = + CompressedUnivariatePoly::::read_from_bytes(&poly_bytes).unwrap(); + + assert_eq!(original_poly, deserialized_poly) +} diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs new file mode 100644 index 000000000..900be4c86 --- /dev/null +++ b/sumcheck/src/verifier/mod.rs @@ -0,0 +1,175 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; + +use air::{LogUpGkrEvaluator, PeriodicTable}; +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; + +use crate::{ + comb_func, evaluate_composition_poly, EqFunction, FinalLayerProof, FinalOpeningClaim, + MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, +}; + +/// Verifies sum-check proofs, as part of the GKR proof, for all GKR layers except for the last one +/// i.e., the circuit input layer. +pub fn verify_sum_check_intermediate_layers< + E: FieldElement, + H: ElementHasher, +>( + proof: &SumCheckProof, + gkr_eval_point: &[E], + claim: (E, E), + transcript: &mut impl RandomCoin, +) -> Result, SumCheckVerifierError> { + // generate challenge to batch sum-checks + transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + let r_batch: E = transcript + .draw() + .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; + + // compute the claim for the batched sum-check + let reduced_claim = claim.0 + claim.1 * r_batch; + + let SumCheckProof { openings_claim, round_proofs } = proof; + + let final_round_claim = verify_rounds(reduced_claim, round_proofs, transcript)?; + assert_eq!(openings_claim.eval_point, final_round_claim.eval_point); + + let p0 = openings_claim.openings[0]; + let p1 = openings_claim.openings[1]; + let q0 = openings_claim.openings[2]; + let q1 = openings_claim.openings[3]; + + let eq = EqFunction::new(gkr_eval_point.into()).evaluate(&openings_claim.eval_point); + + if comb_func(p0, p1, q0, q1, eq, r_batch) != final_round_claim.claim { + return Err(SumCheckVerifierError::FinalEvaluationCheckFailed); + } + + Ok(openings_claim.clone()) +} + +/// Sum-check verifier for the input layer. +/// +/// Verifies the final sum-check proof i.e., the one for the input layer, including the final check, +/// and returns a [`FinalOpeningClaim`] to the STARK verifier in order to verify the correctness of +/// the openings. +pub fn verify_sum_check_input_layer>( + evaluator: &impl LogUpGkrEvaluator, + proof: &FinalLayerProof, + log_up_randomness: Vec, + gkr_eval_point: &[E], + claim: (E, E), + transcript: &mut impl RandomCoin, +) -> Result, SumCheckVerifierError> { + // generate challenge to batch sum-checks + transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + let r_batch: E = transcript + .draw() + .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; + + // compute the claim for the batched sum-check + let reduced_claim = claim.0 + claim.1 * r_batch; + + // verify the sum-check proof + let SumCheckRoundClaim { eval_point, claim } = + verify_rounds(reduced_claim, &proof.0.round_proofs, transcript)?; + + // execute the final evaluation check + if proof.0.openings_claim.eval_point != eval_point { + return Err(SumCheckVerifierError::WrongOpeningPoint); + } + + let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; + + let periodic_columns = evaluator.build_periodic_values(); + let periodic_columns_evaluations = + evaluate_periodic_columns_at(periodic_columns, &proof.0.openings_claim.eval_point); + + evaluator.evaluate_query( + &proof.0.openings_claim.openings, + &periodic_columns_evaluations, + &log_up_randomness, + &mut numerators, + &mut denominators, + ); + + let mu = evaluator.get_num_fractions().trailing_zeros() - 1; + let (evaluation_point_mu, evaluation_point_nu) = gkr_eval_point.split_at(mu as usize); + + let eq_mu = EqFunction::new(evaluation_point_mu.into()).evaluations(); + let eq_nu = EqFunction::new(evaluation_point_nu.into()); + + let eq_nu_eval = eq_nu.evaluate(&proof.0.openings_claim.eval_point); + let expected_evaluation = + evaluate_composition_poly(&eq_mu, &numerators, &denominators, eq_nu_eval, r_batch); + + if expected_evaluation != claim { + Err(SumCheckVerifierError::FinalEvaluationCheckFailed) + } else { + Ok(proof.0.openings_claim.clone()) + } +} + +/// Verifies a round of the sum-check protocol without executing the final check. +fn verify_rounds( + claim: E, + round_proofs: &[RoundProof], + coin: &mut impl RandomCoin, +) -> Result, SumCheckVerifierError> +where + E: FieldElement, + H: ElementHasher, +{ + let mut round_claim = claim; + let mut evaluation_point = vec![]; + for round_proof in round_proofs { + let round_poly_coefs = round_proof.round_poly_coefs.clone(); + coin.reseed(H::hash_elements(&round_poly_coefs.0)); + + let r = coin.draw().map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; + + round_claim = round_proof.round_poly_coefs.evaluate_using_claim(&round_claim, &r); + evaluation_point.push(r); + } + + Ok(SumCheckRoundClaim { + eval_point: evaluation_point, + claim: round_claim, + }) +} + +#[derive(Debug, thiserror::Error)] +pub enum SumCheckVerifierError { + #[error("the final evaluation check of sum-check failed")] + FinalEvaluationCheckFailed, + #[error("failed to generate round challenge")] + FailedToGenerateChallenge, + #[error("wrong opening point for the oracles")] + WrongOpeningPoint, +} + +// HELPER +// ================================================================================================= + +/// Evaluates periodic columns as multi-linear extensions. +fn evaluate_periodic_columns_at( + periodic_columns: PeriodicTable, + eval_point: &[E], +) -> Vec { + let mut evaluations = vec![]; + for col in periodic_columns.table() { + let ml = MultiLinearPoly::from_evaluations(col.to_vec()); + let num_variables = ml.num_variables(); + let point = &eval_point[..num_variables]; + + let evaluation = ml.evaluate(point); + evaluations.push(evaluation) + } + evaluations +} diff --git a/utils/core/src/iterators.rs b/utils/core/src/iterators.rs index 2d9782730..cf8328ae9 100644 --- a/utils/core/src/iterators.rs +++ b/utils/core/src/iterators.rs @@ -115,3 +115,39 @@ macro_rules! batch_iter_mut { $c($e, 0); }; } + +/// Returns either a regular or a parallel iterator over at most `chunk_size` elements depending +/// on whether `concurrent` feature is enabled. +/// +/// When `concurrent` feature is enabled, creates a parallel iterator; otherwise, creates a +/// regular iterator. +#[macro_export] +macro_rules! chunks { + ($e: expr, $chunk_size: expr) => {{ + #[cfg(feature = "concurrent")] + let result = $e.par_chunks($chunk_size); + + #[cfg(not(feature = "concurrent"))] + let result = $e.chunks($chunk_size); + + result + }}; +} + +/// Returns either a regular or a parallel mutable iterator over at most `chunk_size` elements +/// depending on whether `concurrent` feature is enabled. +/// +/// When `concurrent` feature is enabled, creates a parallel iterator; otherwise, creates a +/// regular iterator. +#[macro_export] +macro_rules! chunks_mut { + ($e: expr, $chunk_size: expr) => {{ + #[cfg(feature = "concurrent")] + let result = $e.par_chunks_mut($chunk_size); + + #[cfg(not(feature = "concurrent"))] + let result = $e.chunks_mut($chunk_size); + + result + }}; +} diff --git a/verifier/Cargo.toml b/verifier/Cargo.toml index de8c3f24c..0c07d493c 100644 --- a/verifier/Cargo.toml +++ b/verifier/Cargo.toml @@ -24,6 +24,8 @@ air = { version = "0.9", path = "../air", package = "winter-air", default-featur crypto = { version = "0.9", path = "../crypto", package = "winter-crypto", default-features = false } fri = { version = "0.9", path = "../fri", package = "winter-fri", default-features = false } math = { version = "0.9", path = "../math", package = "winter-math", default-features = false } +sumcheck = { version = "0.1", path = "../sumcheck", package = "winter-sumcheck", default-features = false } +thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } # Allow math in docs diff --git a/verifier/src/channel.rs b/verifier/src/channel.rs index c84f4ec2a..5fefe991d 100644 --- a/verifier/src/channel.rs +++ b/verifier/src/channel.rs @@ -13,6 +13,8 @@ use air::{ use crypto::{ElementHasher, VectorCommitment}; use fri::VerifierChannel as FriVerifierChannel; use math::{FieldElement, StarkField}; +use sumcheck::GkrCircuitProof; +use utils::Deserializable; use crate::VerifierError; @@ -81,7 +83,7 @@ where let constraint_frame_width = air.context().num_constraint_composition_columns(); let num_trace_segments = air.trace_info().num_segments(); - let main_trace_width = air.trace_info().main_trace_width(); + let main_trace_width = air.trace_info().main_segment_width(); let aux_trace_width = air.trace_info().aux_segment_width(); let lde_domain_size = air.lde_domain_size(); let fri_options = air.options().to_fri_options(); @@ -172,9 +174,12 @@ where self.pow_nonce } - /// Returns the serialized GKR proof, if any. - pub fn read_gkr_proof(&self) -> Option<&Vec> { - self.gkr_proof.as_ref() + /// Returns the GKR proof, if any. + pub fn read_gkr_proof(&self) -> Result, VerifierError> { + GkrCircuitProof::read_from_bytes( + self.gkr_proof.as_ref().expect("Expected a GKR proof but there was none"), + ) + .map_err(|err| VerifierError::ProofDeserializationError(err.to_string())) } /// Returns trace states at the specified positions of the LDE domain. This also checks if @@ -313,7 +318,7 @@ where ); // parse main trace segment queries - let main_segment_width = air.trace_info().main_trace_width(); + let main_segment_width = air.trace_info().main_segment_width(); let main_segment_queries = queries.remove(0); let (main_segment_query_proofs, main_segment_states) = main_segment_queries .parse::(air.lde_domain_size(), num_queries, main_segment_width) @@ -331,7 +336,7 @@ where let aux_trace_states = if air.trace_info().is_multi_segment() { let mut aux_trace_states = Vec::new(); let segment_queries = queries.remove(0); - let segment_width = air.trace_info().get_aux_segment_width(); + let segment_width = air.trace_info().aux_segment_width(); let (segment_query_proof, segment_trace_states) = segment_queries .parse::(air.lde_domain_size(), num_queries, segment_width) .map_err(|err| { diff --git a/verifier/src/composer.rs b/verifier/src/composer.rs index 5f10ef79f..ae20f4586 100644 --- a/verifier/src/composer.rs +++ b/verifier/src/composer.rs @@ -43,7 +43,7 @@ impl DeepComposer { x_coordinates, z: [z, z * E::from(g_trace)], g_trace, - lagrange_kernel_column_idx: air.context().lagrange_kernel_aux_column_idx(), + lagrange_kernel_column_idx: air.context().lagrange_kernel_column_idx(), } } diff --git a/verifier/src/evaluator.rs b/verifier/src/evaluator.rs index 10910a555..a226ec9c8 100644 --- a/verifier/src/evaluator.rs +++ b/verifier/src/evaluator.rs @@ -7,7 +7,7 @@ use alloc::vec::Vec; use air::{ Air, AuxRandElements, ConstraintCompositionCoefficients, EvaluationFrame, - LagrangeKernelEvaluationFrame, + LagrangeKernelEvaluationFrame, LogUpGkrEvaluator, }; use math::{polynom, FieldElement}; @@ -89,32 +89,50 @@ pub fn evaluate_constraints>( // 3 ----- evaluate Lagrange kernel constraints ------------------------------------ if let Some(lagrange_kernel_column_frame) = lagrange_kernel_frame { + let logup_gkr_evaluator = air.get_logup_gkr_evaluator(); + let lagrange_coefficients = composition_coefficients .lagrange .expect("expected Lagrange kernel composition coefficients to be present"); - let lagrange_kernel_aux_rand_elements = { - let aux_rand_elements = - aux_rand_elements.expect("expected aux rand elements to be present"); - - aux_rand_elements - .lagrange() - .expect("expected lagrange rand elements to be present") - }; - - let lagrange_constraints = air - .get_lagrange_kernel_constraints( - lagrange_coefficients, - lagrange_kernel_aux_rand_elements, - ) - .expect("expected Lagrange kernel constraints to be present"); + + let gkr_data = aux_rand_elements + .expect("expected aux rand elements to be present") + .gkr_data() + .expect("expected LogUp-GKR rand elements to be present"); + + // Lagrange kernel constraints + + let lagrange_constraints = logup_gkr_evaluator.get_lagrange_kernel_constraints( + lagrange_coefficients, + &gkr_data.lagrange_kernel_eval_point, + ); result += lagrange_constraints.transition.evaluate_and_combine::( lagrange_kernel_column_frame, - lagrange_kernel_aux_rand_elements, + &gkr_data.lagrange_kernel_eval_point, x, ); - result += lagrange_constraints.boundary.evaluate_at(x, lagrange_kernel_column_frame); + + // s-column constraints + + let s_col_idx = air.trace_info().s_column_idx().expect("s-column should be present"); + + let aux_trace_frame = + aux_trace_frame.as_ref().expect("expected aux rand elements to be present"); + + let s_cur = aux_trace_frame.current()[s_col_idx]; + let s_nxt = aux_trace_frame.next()[s_col_idx]; + let l_cur = lagrange_kernel_column_frame.inner()[0]; + + let s_column_cc = composition_coefficients + .s_col + .expect("expected constraint composition coefficient for s-column to be present"); + + let s_column_constraint = + logup_gkr_evaluator.get_s_column_constraints(gkr_data, s_column_cc); + + result += s_column_constraint.evaluate(air, main_trace_frame, s_cur, s_nxt, l_cur, x); } result diff --git a/verifier/src/lib.rs b/verifier/src/lib.rs index 2c75ecd1d..cabbda6c7 100644 --- a/verifier/src/lib.rs +++ b/verifier/src/lib.rs @@ -38,7 +38,7 @@ pub use air::{ ConstraintCompositionCoefficients, ConstraintDivisor, DeepCompositionCoefficients, EvaluationFrame, FieldExtension, ProofOptions, TraceInfo, TransitionConstraintDegree, }; -use air::{AuxRandElements, GkrVerifier}; +use air::{AuxRandElements, LogUpGkrEvaluator}; pub use crypto; use crypto::{ElementHasher, Hasher, RandomCoin, VectorCommitment}; use fri::FriVerifier; @@ -47,6 +47,7 @@ use math::{ fields::{CubeExtension, QuadExtension}, FieldElement, ToElements, }; +use sumcheck::FinalOpeningClaim; pub use utils::{ ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader, }; @@ -60,6 +61,9 @@ use evaluator::evaluate_constraints; mod composer; use composer::DeepComposer; +mod logup_gkr; +use logup_gkr::verify_gkr; + mod errors; pub use errors::VerifierError; @@ -175,35 +179,40 @@ where // process auxiliary trace segments (if any), to build a set of random elements for each segment let aux_trace_rand_elements = if air.trace_info().is_multi_segment() { - if air.context().has_lagrange_kernel_aux_column() { - let gkr_proof = { - let gkr_proof_serialized = channel - .read_gkr_proof() - .expect("Expected an a GKR proof because trace has lagrange kernel column"); - - Deserializable::read_from_bytes(gkr_proof_serialized) - .map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))? - }; - let gkr_rand_elements = air - .get_gkr_proof_verifier::() - .verify::(gkr_proof, &mut public_coin) - .map_err(|err| VerifierError::GkrProofVerificationFailed(err.to_string()))?; - - let rand_elements = air.get_aux_rand_elements(&mut public_coin).expect( - "failed to generate the random elements needed to build the auxiliary trace", - ); + // build the set of random elements related to the auxiliary segment without the LogUp-GKR + // related ones. + let trace_rand_elements = air + .get_aux_rand_elements(&mut public_coin) + .expect("failed to generate the random elements needed to build the auxiliary trace"); + + // if LogUp-GKR is enabled, verify the attached proof and build an object which includes + // randomness and data related to LogUp-GKR + if air.context().logup_gkr_enabled() { + let gkr_proof = channel.read_gkr_proof()?; + let logup_gkr_evaluator = air.get_logup_gkr_evaluator(); + + let FinalOpeningClaim { eval_point, openings } = verify_gkr::( + air.context().public_inputs(), + &gkr_proof, + &logup_gkr_evaluator, + &mut public_coin, + ) + .map_err(|err| VerifierError::GkrProofVerificationFailed(err.to_string()))?; + + let gkr_data = logup_gkr_evaluator + .generate_univariate_iop_for_multi_linear_opening_data( + openings, + eval_point, + &mut public_coin, + ); public_coin.reseed(trace_commitments[AUX_TRACE_IDX]); - Some(AuxRandElements::new_with_gkr(rand_elements, gkr_rand_elements)) + Some(AuxRandElements::new(trace_rand_elements, Some(gkr_data))) } else { - let rand_elements = air.get_aux_rand_elements(&mut public_coin).expect( - "failed to generate the random elements needed to build the auxiliary trace", - ); - public_coin.reseed(trace_commitments[AUX_TRACE_IDX]); - Some(AuxRandElements::new(rand_elements)) + Some(AuxRandElements::new(trace_rand_elements, None)) } } else { None diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs new file mode 100644 index 000000000..e317e0ab1 --- /dev/null +++ b/verifier/src/logup_gkr/mod.rs @@ -0,0 +1,115 @@ +use alloc::vec::Vec; + +use air::{Air, LogUpGkrEvaluator}; +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +use sumcheck::{ + verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, + FinalOpeningClaim, GkrCircuitProof, SumCheckVerifierError, +}; + +/// Verifies the validity of a GKR proof for a LogUp-GKR relation. +pub fn verify_gkr< + A: Air, + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + pub_inputs: &A::PublicInputs, + proof: &GkrCircuitProof, + evaluator: &impl LogUpGkrEvaluator, + transcript: &mut C, +) -> Result, VerifierError> { + let num_logup_random_values = evaluator.get_num_rand_values(); + let mut logup_randomness: Vec = Vec::with_capacity(num_logup_random_values); + + for _ in 0..num_logup_random_values { + logup_randomness.push(transcript.draw().expect("failed to generate randomness")); + } + + let GkrCircuitProof { + circuit_outputs, + before_final_layer_proofs, + final_layer_proof, + } = proof; + + let CircuitOutput { numerators, denominators } = circuit_outputs; + let p0 = numerators.evaluations()[0]; + let p1 = numerators.evaluations()[1]; + let q0 = denominators.evaluations()[0]; + let q1 = denominators.evaluations()[1]; + + // make sure that both denominators are not equal to E::ZERO + if q0 == E::ZERO || q1 == E::ZERO { + return Err(VerifierError::ZeroOutputDenominator); + } + + // check that the output matches the expected `claim` + let claim = evaluator.compute_claim(pub_inputs, &logup_randomness); + if (p0 * q1 + p1 * q0) / (q0 * q1) != claim { + return Err(VerifierError::MismatchingCircuitOutput); + } + + // generate the random challenge to reduce two claims into a single claim + let mut evaluations = numerators.evaluations().to_vec(); + evaluations.extend_from_slice(denominators.evaluations()); + transcript.reseed(H::hash_elements(&evaluations)); + let r = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; + + // reduce the claim + let p_r = p0 + r * (p1 - p0); + let q_r = q0 + r * (q1 - q0); + let mut reduced_claim = (p_r, q_r); + + // verify all GKR layers but for the last one + let num_layers = before_final_layer_proofs.proof.len(); + let mut evaluation_point = vec![r]; + for i in 0..num_layers { + let FinalOpeningClaim { eval_point, openings } = verify_sum_check_intermediate_layers( + &before_final_layer_proofs.proof[i], + &evaluation_point, + reduced_claim, + transcript, + )?; + + // generate the random challenge to reduce two claims into a single claim + transcript.reseed(H::hash_elements(&openings)); + let r_layer = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; + + let p0 = openings[0]; + let p1 = openings[1]; + let q0 = openings[2]; + let q1 = openings[3]; + reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); + + // collect the randomness used for the current layer + let rand_sumcheck = eval_point; + let mut ext = vec![r_layer]; + ext.extend_from_slice(&rand_sumcheck); + evaluation_point = ext; + } + + // verify the proof of the final GKR layer and pass final opening claim for verification + // to the STARK + verify_sum_check_input_layer( + evaluator, + final_layer_proof, + logup_randomness, + &evaluation_point, + reduced_claim, + transcript, + ) + .map_err(VerifierError::FailedToVerifySumCheck) +} + +#[derive(Debug, thiserror::Error)] +pub enum VerifierError { + #[error("one of the claimed circuit denominators is zero")] + ZeroOutputDenominator, + #[error("the output of the fraction circuit is not equal to the expected value")] + MismatchingCircuitOutput, + #[error("failed to generate the random challenge")] + FailedToGenerateChallenge, + #[error("failed to verify the sum-check proof")] + FailedToVerifySumCheck(#[from] SumCheckVerifierError), +} diff --git a/winterfell/src/lib.rs b/winterfell/src/lib.rs index 86c5e0345..621796864 100644 --- a/winterfell/src/lib.rs +++ b/winterfell/src/lib.rs @@ -150,12 +150,13 @@ //! ```no_run //! use winterfell::{ //! math::{fields::f128::BaseElement, FieldElement, ToElements}, -//! Air, AirContext, Assertion, GkrVerifier, EvaluationFrame, +//! Air, AirContext, Assertion, EvaluationFrame, //! ProofOptions, TraceInfo, TransitionConstraintDegree, //! crypto::{hashers::Blake3_256, DefaultRandomCoin, MerkleTree}, //! }; //! //! // Public inputs for our computation will consist of the starting value and the end result. +//! #[derive(Clone)] //! pub struct PublicInputs { //! start: BaseElement, //! result: BaseElement, @@ -172,7 +173,7 @@ //! // the computation's context which we'll build in the constructor. The context is used //! // internally by the Winterfell prover/verifier when interpreting this AIR. //! pub struct WorkAir { -//! context: AirContext, +//! context: AirContext, //! start: BaseElement, //! result: BaseElement, //! } @@ -182,8 +183,6 @@ //! // the public inputs must look like. //! type BaseField = BaseElement; //! type PublicInputs = PublicInputs; -//! type GkrProof = (); -//! type GkrVerifier = (); //! //! // Here, we'll construct a new instance of our computation which is defined by 3 //! // parameters: starting value, number of steps, and the end result. Another way to @@ -206,7 +205,7 @@ //! let num_assertions = 2; //! //! WorkAir { -//! context: AirContext::new(trace_info, degrees, num_assertions, options), +//! context: AirContext::new(trace_info, pub_inputs.clone(), degrees, num_assertions, options), //! start: pub_inputs.start, //! result: pub_inputs.result, //! } @@ -246,7 +245,7 @@ //! //! // This is just boilerplate which is used by the Winterfell prover/verifier to retrieve //! // the context of the computation. -//! fn context(&self) -> &AirContext { +//! fn context(&self) -> &AirContext { //! &self.context //! } //! } @@ -269,6 +268,7 @@ //! # EvaluationFrame, TraceInfo, TransitionConstraintDegree, //! # }; //! # +//! # #[derive(Clone)] //! # pub struct PublicInputs { //! # start: BaseElement, //! # result: BaseElement, @@ -281,7 +281,7 @@ //! # } //! # //! # pub struct WorkAir { -//! # context: AirContext, +//! # context: AirContext, //! # start: BaseElement, //! # result: BaseElement, //! # } @@ -289,14 +289,12 @@ //! # impl Air for WorkAir { //! # type BaseField = BaseElement; //! # type PublicInputs = PublicInputs; -//! # type GkrProof = (); -//! # type GkrVerifier = (); //! # //! # fn new(trace_info: TraceInfo, pub_inputs: PublicInputs, options: ProofOptions) -> Self { //! # assert_eq!(1, trace_info.width()); //! # let degrees = vec![TransitionConstraintDegree::new(3)]; //! # WorkAir { -//! # context: AirContext::new(trace_info, degrees, 2, options), +//! # context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 2, options), //! # start: pub_inputs.start, //! # result: pub_inputs.result, //! # } @@ -321,7 +319,7 @@ //! # ] //! # } //! # -//! # fn context(&self) -> &AirContext { +//! # fn context(&self) -> &AirContext { //! # &self.context //! # } //! # } @@ -418,7 +416,7 @@ //! # trace //! # } //! # -//! # +//! # #[derive(Clone)] //! # pub struct PublicInputs { //! # start: BaseElement, //! # result: BaseElement, @@ -431,7 +429,7 @@ //! # } //! # //! # pub struct WorkAir { -//! # context: AirContext, +//! # context: AirContext, //! # start: BaseElement, //! # result: BaseElement, //! # } @@ -439,14 +437,12 @@ //! # impl Air for WorkAir { //! # type BaseField = BaseElement; //! # type PublicInputs = PublicInputs; -//! # type GkrProof = (); -//! # type GkrVerifier = (); //! # //! # fn new(trace_info: TraceInfo, pub_inputs: PublicInputs, options: ProofOptions) -> Self { //! # assert_eq!(1, trace_info.width()); //! # let degrees = vec![TransitionConstraintDegree::new(3)]; //! # WorkAir { -//! # context: AirContext::new(trace_info, degrees, 2, options), +//! # context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 2, options), //! # start: pub_inputs.start, //! # result: pub_inputs.result, //! # } @@ -471,7 +467,7 @@ //! # ] //! # } //! # -//! # fn context(&self) -> &AirContext { +//! # fn context(&self) -> &AirContext { //! # &self.context //! # } //! # } @@ -594,14 +590,14 @@ #[cfg(test)] extern crate std; -pub use air::{AuxRandElements, GkrVerifier}; +pub use air::{AuxRandElements, LogUpGkrEvaluator}; pub use prover::{ crypto, iterators, math, matrix, Air, AirContext, Assertion, AuxTraceWithMetadata, BoundaryConstraint, BoundaryConstraintGroup, CompositionPolyTrace, ConstraintCompositionCoefficients, ConstraintDivisor, ConstraintEvaluator, DeepCompositionCoefficients, DefaultConstraintEvaluator, DefaultTraceLde, EvaluationFrame, - FieldExtension, Proof, ProofOptions, Prover, ProverError, ProverGkrProof, StarkDomain, Trace, - TraceInfo, TraceLde, TracePolyTable, TraceTable, TraceTableFragment, + FieldExtension, LogUpGkrConstraintEvaluator, Proof, ProofOptions, Prover, ProverError, + StarkDomain, Trace, TraceInfo, TraceLde, TracePolyTable, TraceTable, TraceTableFragment, TransitionConstraintDegree, }; pub use verifier::{verify, AcceptableOptions, ByteWriter, VerifierError}; diff --git a/winterfell/src/tests.rs b/winterfell/src/tests.rs deleted file mode 100644 index 3757e2010..000000000 --- a/winterfell/src/tests.rs +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -use std::{vec, vec::Vec}; - -use air::{GkrRandElements, LagrangeKernelRandElements}; -use crypto::MerkleTree; -use prover::{ - crypto::{hashers::Blake3_256, DefaultRandomCoin, RandomCoin}, - math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, - matrix::ColMatrix, -}; - -use super::*; - -const AUX_TRACE_WIDTH: usize = 2; - -#[test] -fn test_complex_lagrange_kernel_air() { - let trace = LagrangeComplexTrace::new(2_usize.pow(10), AUX_TRACE_WIDTH); - - let prover = LagrangeComplexProver::new(AUX_TRACE_WIDTH); - - let proof = prover.prove(trace).unwrap(); - - verify::< - LagrangeKernelComplexAir, - Blake3_256, - DefaultRandomCoin>, - MerkleTree>, - >(proof, (), &AcceptableOptions::MinConjecturedSecurity(0)) - .unwrap() -} - -// LagrangeComplexTrace -// ================================================================================================= - -#[derive(Clone, Debug)] -struct LagrangeComplexTrace { - // dummy main trace - main_trace: ColMatrix, - info: TraceInfo, -} - -impl LagrangeComplexTrace { - fn new(trace_len: usize, aux_segment_width: usize) -> Self { - assert!(trace_len < u32::MAX.try_into().unwrap()); - - let main_trace_col: Vec = - (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); - - Self { - main_trace: ColMatrix::new(vec![main_trace_col]), - info: TraceInfo::new_multi_segment(1, aux_segment_width, 0, trace_len, vec![]), - } - } - - fn len(&self) -> usize { - self.main_trace.num_rows() - } -} - -impl Trace for LagrangeComplexTrace { - type BaseField = BaseElement; - - fn info(&self) -> &TraceInfo { - &self.info - } - - fn main_segment(&self) -> &ColMatrix { - &self.main_trace - } - - fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { - let next_row_idx = row_idx + 1; - assert_ne!(next_row_idx, self.len()); - - self.main_trace.read_row_into(row_idx, frame.current_mut()); - self.main_trace.read_row_into(next_row_idx, frame.next_mut()); - } -} - -// AIR -// ================================================================================================= - -#[derive(Debug, Clone, Default)] -struct DummyGkrVerifier; - -impl GkrVerifier for DummyGkrVerifier { - // `GkrProof` is log(trace_len) for this dummy example, so that the verifier knows how many aux - // random variables to generate - type GkrProof = usize; - type Error = VerifierError; - - fn verify( - &self, - gkr_proof: usize, - public_coin: &mut impl RandomCoin, - ) -> Result, Self::Error> - where - E: FieldElement, - Hasher: crypto::ElementHasher, - { - let log_trace_len = gkr_proof; - let lagrange_kernel_rand_elements: LagrangeKernelRandElements = { - let mut rand_elements = Vec::with_capacity(log_trace_len); - for _ in 0..log_trace_len { - rand_elements.push(public_coin.draw().unwrap()); - } - - LagrangeKernelRandElements::new(rand_elements) - }; - - Ok(GkrRandElements::new(lagrange_kernel_rand_elements, Vec::new())) - } -} - -struct LagrangeKernelComplexAir { - context: AirContext, -} - -impl Air for LagrangeKernelComplexAir { - type BaseField = BaseElement; - // `GkrProof` is log(trace_len) for this dummy example, so that the verifier knows how many aux - // random variables to generate - type GkrProof = usize; - type GkrVerifier = DummyGkrVerifier; - - type PublicInputs = (); - - fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { - Self { - context: AirContext::new_multi_segment( - trace_info, - vec![TransitionConstraintDegree::new(1)], - vec![TransitionConstraintDegree::new(1)], - 1, - 1, - Some(1), - options, - ), - } - } - - fn context(&self) -> &AirContext { - &self.context - } - - fn evaluate_transition>( - &self, - frame: &EvaluationFrame, - _periodic_values: &[E], - result: &mut [E], - ) { - let current = frame.current()[0]; - let next = frame.next()[0]; - - // increments by 1 - result[0] = next - current - E::ONE; - } - - fn get_assertions(&self) -> Vec> { - vec![Assertion::single(0, 0, BaseElement::ZERO)] - } - - fn evaluate_aux_transition( - &self, - _main_frame: &EvaluationFrame, - _aux_frame: &EvaluationFrame, - _periodic_values: &[F], - _aux_rand_elements: &AuxRandElements, - _result: &mut [E], - ) where - F: FieldElement, - E: FieldElement + ExtensionOf, - { - // do nothing - } - - fn get_aux_assertions>( - &self, - _aux_rand_elements: &AuxRandElements, - ) -> Vec> { - vec![Assertion::single(0, 0, E::ZERO)] - } - - fn get_gkr_proof_verifier>( - &self, - ) -> Self::GkrVerifier { - DummyGkrVerifier - } -} - -// LagrangeComplexProver -// ================================================================================================ - -struct LagrangeComplexProver { - aux_trace_width: usize, - options: ProofOptions, -} - -impl LagrangeComplexProver { - fn new(aux_trace_width: usize) -> Self { - Self { - aux_trace_width, - options: ProofOptions::new(1, 2, 0, FieldExtension::None, 2, 1), - } - } -} - -impl Prover for LagrangeComplexProver { - type BaseField = BaseElement; - type Air = LagrangeKernelComplexAir; - type Trace = LagrangeComplexTrace; - type HashFn = Blake3_256; - type VC = MerkleTree>; - type RandomCoin = DefaultRandomCoin; - type TraceLde> = - DefaultTraceLde; - type ConstraintEvaluator<'a, E: FieldElement> = - DefaultConstraintEvaluator<'a, LagrangeKernelComplexAir, E>; - - fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { - } - - fn options(&self) -> &ProofOptions { - &self.options - } - - fn new_trace_lde( - &self, - trace_info: &TraceInfo, - main_trace: &ColMatrix, - domain: &StarkDomain, - ) -> (Self::TraceLde, TracePolyTable) - where - E: math::FieldElement, - { - DefaultTraceLde::new(trace_info, main_trace, domain) - } - - fn new_evaluator<'a, E>( - &self, - air: &'a Self::Air, - aux_rand_elements: Option>, - composition_coefficients: ConstraintCompositionCoefficients, - ) -> Self::ConstraintEvaluator<'a, E> - where - E: math::FieldElement, - { - DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) - } - - fn generate_gkr_proof( - &self, - main_trace: &Self::Trace, - public_coin: &mut Self::RandomCoin, - ) -> (ProverGkrProof, GkrRandElements) - where - E: FieldElement, - { - let main_trace = main_trace.main_segment(); - let log_trace_len = main_trace.num_rows().ilog2() as usize; - let lagrange_kernel_rand_elements = { - let mut rand_elements = Vec::with_capacity(log_trace_len); - for _ in 0..log_trace_len { - rand_elements.push(public_coin.draw().unwrap()); - } - - LagrangeKernelRandElements::new(rand_elements) - }; - - (log_trace_len, GkrRandElements::new(lagrange_kernel_rand_elements, Vec::new())) - } - - fn build_aux_trace( - &self, - main_trace: &Self::Trace, - aux_rand_elements: &AuxRandElements, - ) -> ColMatrix - where - E: FieldElement, - { - let main_trace = main_trace.main_segment(); - let lagrange_kernel_rand_elements = aux_rand_elements - .lagrange() - .expect("expected lagrange random elements to be present."); - - let mut columns = Vec::new(); - - // First all other auxiliary columns - let rand_summed = lagrange_kernel_rand_elements.iter().fold(E::ZERO, |acc, &r| acc + r); - for _ in 1..self.aux_trace_width { - // building a dummy auxiliary column - let column = main_trace - .get_column(0) - .iter() - .map(|row_val| rand_summed.mul_base(*row_val)) - .collect(); - - columns.push(column); - } - - // then build the Lagrange kernel column - { - let r = &lagrange_kernel_rand_elements; - - let mut lagrange_col = Vec::with_capacity(main_trace.num_rows()); - - for row_idx in 0..main_trace.num_rows() { - let mut row_value = E::ONE; - for (bit_idx, &r_i) in r.iter().enumerate() { - if row_idx & (1 << bit_idx) == 0 { - row_value *= E::ONE - r_i; - } else { - row_value *= r_i; - } - } - lagrange_col.push(row_value); - } - - columns.push(lagrange_col); - } - - ColMatrix::new(columns) - } -} diff --git a/winterfell/src/tests/logup_gkr_periodic.rs b/winterfell/src/tests/logup_gkr_periodic.rs new file mode 100644 index 000000000..849cbbd5d --- /dev/null +++ b/winterfell/src/tests/logup_gkr_periodic.rs @@ -0,0 +1,357 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, vec, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, FieldExtension, + LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, +}; +use crypto::MerkleTree; +use math::StarkField; + +use super::super::*; +use crate::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultTraceLde, Prover, StarkDomain, TracePolyTable, +}; + +#[test] +fn test_logup_gkr_periodic() { + let aux_trace_width = 1; + let trace = LogUpGkrPeriodic::new(2_usize.pow(12), aux_trace_width); + let prover = LogUpGkrPeriodicProver::new(aux_trace_width); + + let proof = prover.prove(trace).unwrap(); + + verify::< + LogUpGkrPeriodicAir, + Blake3_256, + DefaultRandomCoin>, + MerkleTree>, + >(proof, (), &AcceptableOptions::MinConjecturedSecurity(0)) + .unwrap() +} + +// LogUpGkrPeriodic +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrPeriodic { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrPeriodic { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + let mut multiplicity = vec![BaseElement::ZERO; trace_len]; + multiplicity.iter_mut().step_by(8).for_each(|m| *m = BaseElement::from(3_u32)); + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_0[8 * i] = BaseElement::from(8 * i as u32); + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_1[8 * i] = BaseElement::from(8 * i as u32); + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_2[8 * i] = BaseElement::from(8 * i as u32); + } + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrPeriodic { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrPeriodicAir { + context: AirContext, +} + +impl Air for LogUpGkrPeriodicAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + (), + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PeriodicLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PeriodicLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PeriodicLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PeriodicLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![vec![ + Self::BaseField::ONE, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + ]] + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::from(periodic_values[0]); + numerator[2] = E::from(periodic_values[0]); + numerator[3] = E::from(periodic_values[0]); + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} + +// Prover +// ================================================================================================ + +struct LogUpGkrPeriodicProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrPeriodicProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrPeriodicProver { + type BaseField = BaseElement; + type Air = LogUpGkrPeriodicAir; + type Trace = LogUpGkrPeriodic; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + LogUpGkrConstraintEvaluator<'a, LogUpGkrPeriodicAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs new file mode 100644 index 000000000..6c814c948 --- /dev/null +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -0,0 +1,369 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, vec, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, FieldExtension, + LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, +}; +use crypto::MerkleTree; +use math::StarkField; + +use super::super::*; +use crate::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultTraceLde, Prover, StarkDomain, TracePolyTable, +}; + +#[test] +fn test_logup_gkr() { + let aux_trace_width = 1; + let trace = LogUpGkrSimple::new(2_usize.pow(7), aux_trace_width); + let prover = LogUpGkrSimpleProver::new(aux_trace_width); + + let proof = prover.prove(trace).unwrap(); + + verify::< + LogUpGkrSimpleAir, + Blake3_256, + DefaultRandomCoin>, + MerkleTree>, + >(proof, (), &AcceptableOptions::MinConjecturedSecurity(0)) + .unwrap() +} + +// LogUpGkrSimple +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrSimple { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimple { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + multiplicity[1] = BaseElement::new(3 * 4); + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimple { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let committed_0_next_row = LogUpGkrOracle::NextRow(0); + let committed_1_next_row = LogUpGkrOracle::NextRow(1); + let committed_2_next_row = LogUpGkrOracle::NextRow(2); + let committed_3_next_row = LogUpGkrOracle::NextRow(3); + let committed_4_next_row = LogUpGkrOracle::NextRow(4); + let oracles = vec![ + committed_0, + committed_1, + committed_2, + committed_3, + committed_4, + committed_0_next_row, + committed_1_next_row, + committed_2_next_row, + committed_3_next_row, + committed_4_next_row, + ]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query[0] = frame.current()[0]; + query[1] = frame.current()[1]; + query[2] = frame.current()[2]; + query[3] = frame.current()[3]; + query[4] = frame.current()[4]; + + query[5] = frame.next()[0]; + query[6] = frame.next()[1]; + query[7] = frame.next()[2]; + query[8] = frame.next()[3]; + query[9] = frame.next()[4]; + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 10); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} +// Prover +// ================================================================================================ + +struct LogUpGkrSimpleProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrSimpleProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrSimpleProver { + type BaseField = BaseElement; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimple; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} diff --git a/winterfell/src/tests/mod.rs b/winterfell/src/tests/mod.rs new file mode 100644 index 000000000..51881e55e --- /dev/null +++ b/winterfell/src/tests/mod.rs @@ -0,0 +1,8 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +mod logup_gkr_simple; + +mod logup_gkr_periodic;