diff --git a/sp1-gpu/crates/air/src/air_block.rs b/sp1-gpu/crates/air/src/air_block.rs index 4933caef14..5afd616fc1 100644 --- a/sp1-gpu/crates/air/src/air_block.rs +++ b/sp1-gpu/crates/air/src/air_block.rs @@ -3,11 +3,12 @@ use crate::{ }; use itertools::Itertools; use slop_air::{Air, AirBuilder, PairBuilder}; -use slop_algebra::AbstractField; +use slop_algebra::{AbstractExtensionField, AbstractField}; use slop_matrix::Matrix; use sp1_core_executor::events::FieldOperation; -use sp1_core_executor::SyscallCode; -use sp1_core_machine::air::{MemoryAirBuilder, SP1CoreAirBuilder}; +use sp1_core_executor::{ByteOpcode, SyscallCode}; +use sp1_core_machine::air::{MemoryAirBuilder, SP1CoreAirBuilder, WordAirBuilder}; +use sp1_core_machine::global::{GlobalChip, GlobalCols}; use sp1_core_machine::operations::{ AddrAddOperation, AddressSlicePageProtOperation, SyscallAddrOperation, }; @@ -26,14 +27,18 @@ use sp1_curves::params::FieldParameters; use sp1_curves::params::{Limbs, NumLimbs}; use sp1_curves::weierstrass::WeierstrassParameters; use sp1_curves::{BigUint, CurveType, EllipticCurve}; -use sp1_hypercube::air::InstructionAirBuilder; -#[cfg(feature = "mprotect")] -use sp1_hypercube::air::MachineAirBuilder; +use sp1_hypercube::air::{InstructionAirBuilder, MachineAirBuilder}; use sp1_hypercube::operations::poseidon2::air::{eval_external_round, eval_internal_rounds}; -use sp1_hypercube::operations::poseidon2::WIDTH; +use sp1_hypercube::operations::poseidon2::permutation::Poseidon2Cols; +use sp1_hypercube::operations::poseidon2::{NUM_EXTERNAL_ROUNDS, WIDTH}; +use sp1_hypercube::septic_curve::SepticCurve; +use sp1_hypercube::septic_extension::SepticExtension; use sp1_hypercube::Word; use sp1_hypercube::{ - air::{AirInteraction, InteractionScope, MachineAir, MessageBuilder}, + air::{ + AirInteraction, ByteAirBuilder, InteractionScope, MachineAir, MessageBuilder, + SepticExtensionAirBuilder, + }, InteractionKind, }; use sp1_primitives::consts::{PROT_READ, PROT_WRITE}; @@ -56,6 +61,30 @@ pub trait BlockAir: Air + MachineAir + 'static + Send + S } } +/// Number of [`BlockAir`] blocks consumed by a single Poseidon2 permutation: one block per external +/// round, plus one block holding all internal rounds. +pub const POSEIDON2_PERM_NUM_BLOCKS: usize = NUM_EXTERNAL_ROUNDS + 1; + +/// Evaluates the `index`-th block of the Poseidon2 permutation constraints over `perm_cols`. +/// +/// `index` must be in `0..POSEIDON2_PERM_NUM_BLOCKS`. The first `NUM_EXTERNAL_ROUNDS` indices each +/// evaluate one external round; the final index evaluates all internal rounds. +fn eval_poseidon2_perm_block( + builder: &mut AB, + perm_cols: &dyn Poseidon2Cols, + index: usize, +) where + AB: MachineAirBuilder + PairBuilder, +{ + if index < NUM_EXTERNAL_ROUNDS { + eval_external_round(builder, perm_cols, index); + } else if index == NUM_EXTERNAL_ROUNDS { + eval_internal_rounds(builder, perm_cols); + } else { + panic!("Poseidon2 permutation block index out of range: {index}"); + } +} + impl<'a> BlockAir> for RiscvAir { fn num_blocks(&self) -> usize { match self { @@ -64,6 +93,7 @@ impl<'a> BlockAir> for RiscvAir { RiscvAir::Secp256k1AddUser(secp256k1_add) => secp256k1_add.num_blocks(), RiscvAir::Secp256k1Double(secp256k1_double) => secp256k1_double.num_blocks(), RiscvAir::Secp256k1DoubleUser(secp256k1_double) => secp256k1_double.num_blocks(), + RiscvAir::Global(global) => global.num_blocks(), _ => 1, } } @@ -79,6 +109,7 @@ impl<'a> BlockAir> for RiscvAir { RiscvAir::Secp256k1DoubleUser(secp256k1_double) => { secp256k1_double.eval_block(builder, index) } + RiscvAir::Global(global) => global.eval_block(builder, index), _ => { assert!(index == 0); self.eval(builder); @@ -855,7 +886,7 @@ where impl<'a, const DEGREE: usize> BlockAir> for Poseidon2WideChip { fn num_blocks(&self) -> usize { - 9 + POSEIDON2_PERM_NUM_BLOCKS } fn eval_block(&self, builder: &mut SymbolicProverFolder<'a>, index: usize) { @@ -865,38 +896,275 @@ impl<'a, const DEGREE: usize> BlockAir> for Poseidon2Wi let prep_local = prepr.row_slice(0); let prep_local: &Poseidon2PreprocessedColsWide<_> = (*prep_local).borrow(); + if index == 0 { + // Dummy constraints to normalize to DEGREE. + let lhs = (0..DEGREE) + .map(|_| local_row.external_rounds_state()[0][0].into()) + .product::(); + let rhs = (0..DEGREE) + .map(|_| local_row.external_rounds_state()[0][0].into()) + .product::(); + builder.assert_eq(lhs, rhs); + + (0..WIDTH).for_each(|i| { + builder.receive_single( + prep_local.input[i], + local_row.external_rounds_state()[0][i], + prep_local.is_real, + ) + }); + + (0..WIDTH).for_each(|i| { + builder.send_single( + prep_local.output[i].addr, + local_row.perm_output()[i], + prep_local.output[i].mult, + ) + }); + } + eval_poseidon2_perm_block(builder, local_row.as_ref(), index); + } +} + +/// Number of `GlobalChip` [`BlockAir`] blocks consumed *after* the Poseidon2 permutation: one +/// block each for the curve formula, y-coordinate sign + range checks, and digest accumulation. +const GLOBAL_NUM_EC_BLOCKS: usize = 3; + +impl<'a> BlockAir> for GlobalChip { + fn num_blocks(&self) -> usize { + POSEIDON2_PERM_NUM_BLOCKS + GLOBAL_NUM_EC_BLOCKS + } + + fn eval_block(&self, builder: &mut SymbolicProverFolder<'a>, index: usize) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &GlobalCols = (*local).borrow(); + + let cols = local.interaction; + let acc = local.accumulation; + let is_real = local.is_real; + let is_receive: SymbolicExprF = local.is_receive.into(); + let is_send: SymbolicExprF = local.is_send.into(); + match index { 0 => { - // Dummy constraints to normalize to DEGREE. - let lhs = (0..DEGREE) - .map(|_| local_row.external_rounds_state()[0][0].into()) - .product::(); - let rhs = (0..DEGREE) - .map(|_| local_row.external_rounds_state()[0][0].into()) - .product::(); - builder.assert_eq(lhs, rhs); - - (0..WIDTH).for_each(|i| { - builder.receive_single( - prep_local.input[i], - local_row.external_rounds_state()[0][i], - prep_local.is_real, - ) - }); + // Top-level constraints from `GlobalChip::eval`. + builder.assert_bool(is_real); + builder.receive( + AirInteraction::new( + vec![ + SymbolicExprF::from(local.message[0]), + local.message[1].into(), + local.message[2].into(), + local.message[3].into(), + local.message[4].into(), + local.message[5].into(), + local.message[6].into(), + local.message[7].into(), + local.is_send.into(), + local.is_receive.into(), + local.kind.into(), + ], + is_real.into(), + InteractionKind::Global, + ), + InteractionScope::Local, + ); + + // Setup constraints from `GlobalInteractionOperation::eval_single_digest`. + builder.assert_bool(is_real); + builder.when(is_real).assert_eq(is_receive + is_send, SymbolicExprF::one()); + builder.assert_bool(is_receive); + builder.assert_bool(is_send); + + builder.send_byte( + SymbolicExprF::from_canonical_u32(ByteOpcode::U8Range as u32), + SymbolicExprF::zero(), + SymbolicExprF::zero(), + SymbolicExprF::from(cols.offset), + SymbolicExprF::from(is_real), + ); - (0..WIDTH).for_each(|i| { - builder.send_single( - prep_local.output[i].addr, - local_row.perm_output()[i], - prep_local.output[i].mult, - ) + builder.when(is_real).assert_eq( + SymbolicExprF::from(local.message[0]), + local.message_0_16bit_limb + + local.message_0_8bit_limb * F::from_canonical_u32(1 << 16), + ); + + builder.slice_range_check_u16( + &[SymbolicExprF::from(local.message_0_16bit_limb), local.message[7].into()], + is_real, + ); + builder.slice_range_check_u8(&[local.message_0_8bit_limb], is_real); + + builder.send_byte( + SymbolicExprF::from_canonical_u32(ByteOpcode::Range as u32), + SymbolicExprF::from(local.kind), + SymbolicExprF::from_canonical_u32(6), + SymbolicExprF::zero(), + SymbolicExprF::from(is_real), + ); + + // Constrain the permutation input to equal the hash trial. + let m_trial: [SymbolicExprF; WIDTH] = [ + SymbolicExprF::from(local.message[0]) + + SymbolicExprF::from_canonical_u32(1 << 24) * local.kind, + local.message[1].into(), + local.message[2].into(), + local.message[3].into(), + local.message[4].into(), + local.message[5].into(), + local.message[6].into(), + SymbolicExprF::from(local.message[7]) + + SymbolicExprF::from_canonical_u32(1 << 16) * cols.offset, + SymbolicExprF::zero(), + SymbolicExprF::zero(), + SymbolicExprF::zero(), + SymbolicExprF::zero(), + SymbolicExprF::zero(), + SymbolicExprF::zero(), + SymbolicExprF::zero(), + SymbolicExprF::zero(), + ]; + for (perm_input, trial) in cols.permutation.permutation.external_rounds_state()[0] + .iter() + .zip(m_trial.iter()) + { + builder.when(is_real).assert_eq(*perm_input, *trial); + } + + eval_poseidon2_perm_block(builder, &cols.permutation.permutation, 0); + } + i if i < POSEIDON2_PERM_NUM_BLOCKS - 1 => { + eval_poseidon2_perm_block(builder, &cols.permutation.permutation, i); + } + i if i == POSEIDON2_PERM_NUM_BLOCKS - 1 => { + eval_poseidon2_perm_block(builder, &cols.permutation.permutation, i); + + // The Poseidon2 output is the x-coordinate of the curve point. + let m_hash = cols.permutation.permutation.perm_output(); + for (x_coord, hash) in cols.x_coordinate.0.iter().zip(m_hash.iter()) { + builder.when(is_real).assert_eq(*x_coord, *hash); + } + } + 9 => { + // (x, y) lies on the septic curve y^2 = x^3 + 45x + 41z^3. + let x = SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(cols.x_coordinate[i]) + }); + let y = SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(cols.y_coordinate[i]) }); - eval_external_round(builder, local_row.as_ref(), 0); + let y2 = y.square(); + let curve = SepticCurve::::curve_formula(x); + builder.assert_septic_ext_eq(y2, curve); } - 1..8 => { - eval_external_round(builder, local_row.as_ref(), index); + 10 => { + // y6 byte decomposition + sign-of-y constraints. + let y = SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(cols.y_coordinate[i]) + }); + + let mut y6_value = SymbolicExprF::zero(); + for i in 0..3 { + y6_value = y6_value + + cols.y6_byte_decomp[i] * SymbolicExprF::from_canonical_u32(1 << (8 * i)); + builder.send_byte( + SymbolicExprF::from_canonical_u32(ByteOpcode::U8Range as u32), + SymbolicExprF::zero(), + SymbolicExprF::zero(), + SymbolicExprF::from(cols.y6_byte_decomp[i]), + SymbolicExprF::from(is_real), + ); + } + y6_value = + y6_value + cols.y6_byte_decomp[3] * SymbolicExprF::from_canonical_u32(1 << 24); + builder.send_byte( + SymbolicExprF::from_canonical_u32(ByteOpcode::LTU as u32), + SymbolicExprF::one(), + SymbolicExprF::from(cols.y6_byte_decomp[3]), + SymbolicExprF::from_canonical_u8(63), + SymbolicExprF::from(is_real), + ); + + builder.when(is_receive).assert_eq(y.0[6], SymbolicExprF::one() + y6_value); + builder.when(is_send).assert_zero(y.0[6] + SymbolicExprF::one() + y6_value); + } + 11 => { + // Accumulation: receive previous digest, check sum, send next digest. + builder.assert_bool(is_real); + builder.receive( + AirInteraction::new( + vec![local.index] + .into_iter() + .chain(acc.initial_digest.into_iter().flat_map(|septic| septic.0)) + .map(SymbolicExprF::from) + .collect(), + SymbolicExprF::from(is_real), + InteractionKind::GlobalAccumulation, + ), + InteractionScope::Local, + ); + + let initial_digest = SepticCurve:: { + x: SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(acc.initial_digest[0][i]) + }), + y: SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(acc.initial_digest[1][i]) + }), + }; + let cumulative_sum = SepticCurve:: { + x: SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(acc.cumulative_sum[0].0[i]) + }), + y: SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(acc.cumulative_sum[1].0[i]) + }), + }; + let point_to_add = SepticCurve:: { + x: SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(cols.x_coordinate.0[i]) + }), + y: SepticExtension::::from_base_fn(|i| { + SymbolicExprF::from(cols.y_coordinate.0[i]) + }), + }; + + let sum_checker_x = SepticCurve::::sum_checker_x( + initial_digest, + point_to_add, + cumulative_sum, + ); + let sum_checker_y = SepticCurve::::sum_checker_y( + initial_digest, + point_to_add, + cumulative_sum, + ); + + builder + .assert_septic_ext_eq(sum_checker_x, SepticExtension::::zero()); + builder + .when(is_real) + .assert_septic_ext_eq(sum_checker_y, SepticExtension::::zero()); + + builder.send( + AirInteraction::new( + vec![local.index + SymbolicExprF::one()] + .into_iter() + .chain( + acc.cumulative_sum + .into_iter() + .flat_map(|septic| septic.0) + .map(SymbolicExprF::from), + ) + .collect(), + SymbolicExprF::from(is_real), + InteractionKind::GlobalAccumulation, + ), + InteractionScope::Local, + ); } - 8 => eval_internal_rounds(builder, local_row.as_ref()), _ => unreachable!(), } }