Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 303 additions & 35 deletions sp1-gpu/crates/air/src/air_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use crate::{
};
use itertools::Itertools;
use slop_air::{Air, AirBuilder, PairBuilder};
use slop_algebra::AbstractField;
use slop_algebra::{AbstractExtensionField, AbstractField};
use slop_matrix::Matrix;
use sp1_core_executor::events::FieldOperation;
use sp1_core_executor::SyscallCode;
use sp1_core_machine::air::{MemoryAirBuilder, SP1CoreAirBuilder};
use sp1_core_executor::{ByteOpcode, SyscallCode};
use sp1_core_machine::air::{MemoryAirBuilder, SP1CoreAirBuilder, WordAirBuilder};
use sp1_core_machine::global::{GlobalChip, GlobalCols};
use sp1_core_machine::operations::{
AddrAddOperation, AddressSlicePageProtOperation, SyscallAddrOperation,
};
Expand All @@ -26,14 +27,18 @@ use sp1_curves::params::FieldParameters;
use sp1_curves::params::{Limbs, NumLimbs};
use sp1_curves::weierstrass::WeierstrassParameters;
use sp1_curves::{BigUint, CurveType, EllipticCurve};
use sp1_hypercube::air::InstructionAirBuilder;
#[cfg(feature = "mprotect")]
use sp1_hypercube::air::MachineAirBuilder;
use sp1_hypercube::air::{InstructionAirBuilder, MachineAirBuilder};
use sp1_hypercube::operations::poseidon2::air::{eval_external_round, eval_internal_rounds};
use sp1_hypercube::operations::poseidon2::WIDTH;
use sp1_hypercube::operations::poseidon2::permutation::Poseidon2Cols;
use sp1_hypercube::operations::poseidon2::{NUM_EXTERNAL_ROUNDS, WIDTH};
use sp1_hypercube::septic_curve::SepticCurve;
use sp1_hypercube::septic_extension::SepticExtension;
use sp1_hypercube::Word;
use sp1_hypercube::{
air::{AirInteraction, InteractionScope, MachineAir, MessageBuilder},
air::{
AirInteraction, ByteAirBuilder, InteractionScope, MachineAir, MessageBuilder,
SepticExtensionAirBuilder,
},
InteractionKind,
};
use sp1_primitives::consts::{PROT_READ, PROT_WRITE};
Expand All @@ -56,6 +61,30 @@ pub trait BlockAir<AB: AirBuilder>: Air<AB> + MachineAir<F> + 'static + Send + S
}
}

/// Number of [`BlockAir`] blocks consumed by a single Poseidon2 permutation: one block per external
/// round, plus one block holding all internal rounds.
pub const POSEIDON2_PERM_NUM_BLOCKS: usize = NUM_EXTERNAL_ROUNDS + 1;

/// Evaluates the `index`-th block of the Poseidon2 permutation constraints over `perm_cols`.
///
/// `index` must be in `0..POSEIDON2_PERM_NUM_BLOCKS`. The first `NUM_EXTERNAL_ROUNDS` indices each
/// evaluate one external round; the final index evaluates all internal rounds.
fn eval_poseidon2_perm_block<AB>(
builder: &mut AB,
perm_cols: &dyn Poseidon2Cols<AB::Var>,
index: usize,
) where
AB: MachineAirBuilder + PairBuilder,
{
if index < NUM_EXTERNAL_ROUNDS {
eval_external_round(builder, perm_cols, index);
} else if index == NUM_EXTERNAL_ROUNDS {
eval_internal_rounds(builder, perm_cols);
} else {
panic!("Poseidon2 permutation block index out of range: {index}");
}
}

impl<'a> BlockAir<SymbolicProverFolder<'a>> for RiscvAir<F> {
fn num_blocks(&self) -> usize {
match self {
Expand All @@ -64,6 +93,7 @@ impl<'a> BlockAir<SymbolicProverFolder<'a>> for RiscvAir<F> {
RiscvAir::Secp256k1AddUser(secp256k1_add) => secp256k1_add.num_blocks(),
RiscvAir::Secp256k1Double(secp256k1_double) => secp256k1_double.num_blocks(),
RiscvAir::Secp256k1DoubleUser(secp256k1_double) => secp256k1_double.num_blocks(),
RiscvAir::Global(global) => global.num_blocks(),
_ => 1,
}
}
Expand All @@ -79,6 +109,7 @@ impl<'a> BlockAir<SymbolicProverFolder<'a>> for RiscvAir<F> {
RiscvAir::Secp256k1DoubleUser(secp256k1_double) => {
secp256k1_double.eval_block(builder, index)
}
RiscvAir::Global(global) => global.eval_block(builder, index),
_ => {
assert!(index == 0);
self.eval(builder);
Expand Down Expand Up @@ -855,7 +886,7 @@ where

impl<'a, const DEGREE: usize> BlockAir<SymbolicProverFolder<'a>> for Poseidon2WideChip<DEGREE> {
fn num_blocks(&self) -> usize {
9
POSEIDON2_PERM_NUM_BLOCKS
}

fn eval_block(&self, builder: &mut SymbolicProverFolder<'a>, index: usize) {
Expand All @@ -865,38 +896,275 @@ impl<'a, const DEGREE: usize> BlockAir<SymbolicProverFolder<'a>> for Poseidon2Wi
let prep_local = prepr.row_slice(0);
let prep_local: &Poseidon2PreprocessedColsWide<_> = (*prep_local).borrow();

if index == 0 {
// Dummy constraints to normalize to DEGREE.
let lhs = (0..DEGREE)
.map(|_| local_row.external_rounds_state()[0][0].into())
.product::<SymbolicExprF>();
let rhs = (0..DEGREE)
.map(|_| local_row.external_rounds_state()[0][0].into())
.product::<SymbolicExprF>();
builder.assert_eq(lhs, rhs);

(0..WIDTH).for_each(|i| {
builder.receive_single(
prep_local.input[i],
local_row.external_rounds_state()[0][i],
prep_local.is_real,
)
});

(0..WIDTH).for_each(|i| {
builder.send_single(
prep_local.output[i].addr,
local_row.perm_output()[i],
prep_local.output[i].mult,
)
});
}
eval_poseidon2_perm_block(builder, local_row.as_ref(), index);
}
}

/// Number of `GlobalChip` [`BlockAir`] blocks consumed *after* the Poseidon2 permutation: one
/// block each for the curve formula, y-coordinate sign + range checks, and digest accumulation.
const GLOBAL_NUM_EC_BLOCKS: usize = 3;

impl<'a> BlockAir<SymbolicProverFolder<'a>> for GlobalChip {
fn num_blocks(&self) -> usize {
POSEIDON2_PERM_NUM_BLOCKS + GLOBAL_NUM_EC_BLOCKS
}

fn eval_block(&self, builder: &mut SymbolicProverFolder<'a>, index: usize) {
let main = builder.main();
let local = main.row_slice(0);
let local: &GlobalCols<SymbolicVarF> = (*local).borrow();

let cols = local.interaction;
let acc = local.accumulation;
let is_real = local.is_real;
let is_receive: SymbolicExprF = local.is_receive.into();
let is_send: SymbolicExprF = local.is_send.into();

match index {
0 => {
// Dummy constraints to normalize to DEGREE.
let lhs = (0..DEGREE)
.map(|_| local_row.external_rounds_state()[0][0].into())
.product::<SymbolicExprF>();
let rhs = (0..DEGREE)
.map(|_| local_row.external_rounds_state()[0][0].into())
.product::<SymbolicExprF>();
builder.assert_eq(lhs, rhs);

(0..WIDTH).for_each(|i| {
builder.receive_single(
prep_local.input[i],
local_row.external_rounds_state()[0][i],
prep_local.is_real,
)
});
// Top-level constraints from `GlobalChip::eval`.
builder.assert_bool(is_real);
builder.receive(
AirInteraction::new(
vec![
SymbolicExprF::from(local.message[0]),
local.message[1].into(),
local.message[2].into(),
local.message[3].into(),
local.message[4].into(),
local.message[5].into(),
local.message[6].into(),
local.message[7].into(),
local.is_send.into(),
local.is_receive.into(),
local.kind.into(),
],
is_real.into(),
InteractionKind::Global,
),
InteractionScope::Local,
);

// Setup constraints from `GlobalInteractionOperation::eval_single_digest`.
builder.assert_bool(is_real);
builder.when(is_real).assert_eq(is_receive + is_send, SymbolicExprF::one());
builder.assert_bool(is_receive);
builder.assert_bool(is_send);

builder.send_byte(
SymbolicExprF::from_canonical_u32(ByteOpcode::U8Range as u32),
SymbolicExprF::zero(),
SymbolicExprF::zero(),
SymbolicExprF::from(cols.offset),
SymbolicExprF::from(is_real),
);

(0..WIDTH).for_each(|i| {
builder.send_single(
prep_local.output[i].addr,
local_row.perm_output()[i],
prep_local.output[i].mult,
)
builder.when(is_real).assert_eq(
SymbolicExprF::from(local.message[0]),
local.message_0_16bit_limb
+ local.message_0_8bit_limb * F::from_canonical_u32(1 << 16),
);

builder.slice_range_check_u16(
&[SymbolicExprF::from(local.message_0_16bit_limb), local.message[7].into()],
is_real,
);
builder.slice_range_check_u8(&[local.message_0_8bit_limb], is_real);

builder.send_byte(
SymbolicExprF::from_canonical_u32(ByteOpcode::Range as u32),
SymbolicExprF::from(local.kind),
SymbolicExprF::from_canonical_u32(6),
SymbolicExprF::zero(),
SymbolicExprF::from(is_real),
);

// Constrain the permutation input to equal the hash trial.
let m_trial: [SymbolicExprF; WIDTH] = [
SymbolicExprF::from(local.message[0])
+ SymbolicExprF::from_canonical_u32(1 << 24) * local.kind,
local.message[1].into(),
local.message[2].into(),
local.message[3].into(),
local.message[4].into(),
local.message[5].into(),
local.message[6].into(),
SymbolicExprF::from(local.message[7])
+ SymbolicExprF::from_canonical_u32(1 << 16) * cols.offset,
SymbolicExprF::zero(),
SymbolicExprF::zero(),
SymbolicExprF::zero(),
SymbolicExprF::zero(),
SymbolicExprF::zero(),
SymbolicExprF::zero(),
SymbolicExprF::zero(),
SymbolicExprF::zero(),
];
for (perm_input, trial) in cols.permutation.permutation.external_rounds_state()[0]
.iter()
.zip(m_trial.iter())
{
builder.when(is_real).assert_eq(*perm_input, *trial);
}

eval_poseidon2_perm_block(builder, &cols.permutation.permutation, 0);
}
i if i < POSEIDON2_PERM_NUM_BLOCKS - 1 => {
eval_poseidon2_perm_block(builder, &cols.permutation.permutation, i);
}
i if i == POSEIDON2_PERM_NUM_BLOCKS - 1 => {
eval_poseidon2_perm_block(builder, &cols.permutation.permutation, i);

// The Poseidon2 output is the x-coordinate of the curve point.
let m_hash = cols.permutation.permutation.perm_output();
for (x_coord, hash) in cols.x_coordinate.0.iter().zip(m_hash.iter()) {
builder.when(is_real).assert_eq(*x_coord, *hash);
}
}
9 => {
// (x, y) lies on the septic curve y^2 = x^3 + 45x + 41z^3.
let x = SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(cols.x_coordinate[i])
});
let y = SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(cols.y_coordinate[i])
});
eval_external_round(builder, local_row.as_ref(), 0);
let y2 = y.square();
let curve = SepticCurve::<SymbolicExprF>::curve_formula(x);
builder.assert_septic_ext_eq(y2, curve);
}
1..8 => {
eval_external_round(builder, local_row.as_ref(), index);
10 => {
// y6 byte decomposition + sign-of-y constraints.
let y = SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(cols.y_coordinate[i])
});

let mut y6_value = SymbolicExprF::zero();
for i in 0..3 {
y6_value = y6_value
+ cols.y6_byte_decomp[i] * SymbolicExprF::from_canonical_u32(1 << (8 * i));
builder.send_byte(
SymbolicExprF::from_canonical_u32(ByteOpcode::U8Range as u32),
SymbolicExprF::zero(),
SymbolicExprF::zero(),
SymbolicExprF::from(cols.y6_byte_decomp[i]),
SymbolicExprF::from(is_real),
);
}
y6_value =
y6_value + cols.y6_byte_decomp[3] * SymbolicExprF::from_canonical_u32(1 << 24);
builder.send_byte(
SymbolicExprF::from_canonical_u32(ByteOpcode::LTU as u32),
SymbolicExprF::one(),
SymbolicExprF::from(cols.y6_byte_decomp[3]),
SymbolicExprF::from_canonical_u8(63),
SymbolicExprF::from(is_real),
);

builder.when(is_receive).assert_eq(y.0[6], SymbolicExprF::one() + y6_value);
builder.when(is_send).assert_zero(y.0[6] + SymbolicExprF::one() + y6_value);
}
11 => {
// Accumulation: receive previous digest, check sum, send next digest.
builder.assert_bool(is_real);
builder.receive(
AirInteraction::new(
vec![local.index]
.into_iter()
.chain(acc.initial_digest.into_iter().flat_map(|septic| septic.0))
.map(SymbolicExprF::from)
.collect(),
SymbolicExprF::from(is_real),
InteractionKind::GlobalAccumulation,
),
InteractionScope::Local,
);

let initial_digest = SepticCurve::<SymbolicExprF> {
x: SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(acc.initial_digest[0][i])
}),
y: SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(acc.initial_digest[1][i])
}),
};
let cumulative_sum = SepticCurve::<SymbolicExprF> {
x: SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(acc.cumulative_sum[0].0[i])
}),
y: SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(acc.cumulative_sum[1].0[i])
}),
};
let point_to_add = SepticCurve::<SymbolicExprF> {
x: SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(cols.x_coordinate.0[i])
}),
y: SepticExtension::<SymbolicExprF>::from_base_fn(|i| {
SymbolicExprF::from(cols.y_coordinate.0[i])
}),
};

let sum_checker_x = SepticCurve::<SymbolicExprF>::sum_checker_x(
initial_digest,
point_to_add,
cumulative_sum,
);
let sum_checker_y = SepticCurve::<SymbolicExprF>::sum_checker_y(
initial_digest,
point_to_add,
cumulative_sum,
);

builder
.assert_septic_ext_eq(sum_checker_x, SepticExtension::<SymbolicExprF>::zero());
builder
.when(is_real)
.assert_septic_ext_eq(sum_checker_y, SepticExtension::<SymbolicExprF>::zero());

builder.send(
AirInteraction::new(
vec![local.index + SymbolicExprF::one()]
.into_iter()
.chain(
acc.cumulative_sum
.into_iter()
.flat_map(|septic| septic.0)
.map(SymbolicExprF::from),
)
.collect(),
SymbolicExprF::from(is_real),
InteractionKind::GlobalAccumulation,
),
InteractionScope::Local,
);
}
8 => eval_internal_rounds(builder, local_row.as_ref()),
_ => unreachable!(),
}
}
Expand Down
Loading