diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 665f7556a..eb13c53e5 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -307,9 +307,21 @@ pub fn verify_zkvm_proof>( let num_lks: Var = builder.eval(C::N::from_canonical_usize(chip_vk.get_cs().num_lks())); + // Chips with EC-sum ops carry an extra hypercube variable; the + // prover fills it with EC-tree internal nodes that are inactive + // via `selector_zero = 0` and thus collapse to dummy lookup + // queries. Mirror the native verifier's adjustment here so the + // dummy multiplicity matches the prover. + let ecc_row_factor: usize = if circuit_vk.get_cs().has_ecc_ops() { + 2 + } else { + 1 + }; // each padding instance contribute to (2^rotation_vars) dummy lookup padding - let next_pow2_instance: Var = + let next_pow2_chip_rows: Var = pow_2(builder, chip_proof.log2_num_instances.get_var()); + let next_pow2_instance: Var = + builder.eval(next_pow2_chip_rows * C::N::from_canonical_usize(ecc_row_factor)); let num_padded_instance: Var = builder.eval(next_pow2_instance - chip_proof.sum_num_instances.clone()); let rotation_var: Var = builder.constant(C::N::from_canonical_usize( diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index e6dd8b334..42edd188c 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1497,6 +1497,30 @@ pub fn generate_witness<'a, E: ExtensionField>( &mut zkvm_witness, ) }).unwrap(); + + // Assign continuation circuits (LocalFinal + ShardRam) before + // `finalize_lk_multiplicities`: ShardRam's per-row y6_lo byte / + // LTU lookups must land in `combined_lk_mlt` so the U8 / LTU + // table `mlt` columns balance the logup grand product. LocalFinal + // does not consume `combined_lk_mlt`, so running it pre-finalize + // is safe — `assign_table_circuit` tolerates a not-yet-finalized + // multiplicity by passing an empty slice. + info_span!("assign_continuation").in_scope(|| { + system_config + .mmu_config + .assign_continuation_circuit( + &system_config.zkvm_cs, + &shard_ctx, + &mut zkvm_witness, + &pi, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.heap, + ) + }).unwrap(); + info_span!("finalize_lk_multiplicities").in_scope(|| { zkvm_witness.finalize_lk_multiplicities(); }); @@ -1535,6 +1559,23 @@ pub fn generate_witness<'a, E: ExtensionField>( &mut cpu_witness, ) .unwrap(); + // Mirror the main path so `combined_lk_mlt` comparison stays + // meaningful: continuation pushes ShardRamCircuit's per-row + // y6_lo lookups into `lk_mlts` before finalize. + system_config + .mmu_config + .assign_continuation_circuit( + &system_config.zkvm_cs, + &cpu_shard_ctx, + &mut cpu_witness, + &pi, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.heap, + ) + .unwrap(); cpu_witness.finalize_lk_multiplicities(); #[cfg(feature = "gpu")] @@ -1626,22 +1667,6 @@ pub fn generate_witness<'a, E: ExtensionField>( ) }).unwrap(); - info_span!("assign_continuation").in_scope(|| { - system_config - .mmu_config - .assign_continuation_circuit( - &system_config.zkvm_cs, - &shard_ctx, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.stack, - &emul_result.final_mem_state.heap, - ) - }).unwrap(); - info_span!("assign_program_table").in_scope(|| { zkvm_witness .assign_table_circuit::>( diff --git a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs index 1d987a2e6..650c68177 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs @@ -7,7 +7,7 @@ use rustc_hash::FxHashSet; use crate::{ e2e::ShardContext, error::ZKVMError, - tables::{MemFinalRecord, ShardRamConfig, ShardRamRecord}, + tables::{MemFinalRecord, ShardRamConfig, ShardRamRecord, Y6_LO_TOP_BYTE_LT_BOUND}, }; /// Filter and construct a cross-shard ShardRamRecord without EC computation. @@ -198,13 +198,18 @@ pub fn gpu_batch_continuation_ec_on_device( } /// Try to run ShardRamCircuit assign_instances on GPU. -/// Returns `Ok(None)` if GPU is unavailable or disabled. +/// Returns `Ok(None)` if GPU is unavailable or disabled. On success the +/// y6_lo byte / LTU lookup multiplicity is derived from `steps` and pushed +/// into `lk_multiplicity` so the caller sees the same per-row contribution +/// the CPU `assign_instance` path would have made. pub(crate) fn try_gpu_assign_shard_ram( config: &ShardRamConfig, num_witin: usize, num_structural_witin: usize, + lk_multiplicity: &mut crate::witness::LkMultiplicity, steps: &[crate::tables::ShardRamInput], ) -> Result>, ZKVMError> { + use crate::scheme::constants::SEPTIC_EXTENSION_DEGREE; use ceno_gpu::{ Buffer, CudaHal, bb31::CudaHalBB31, @@ -496,6 +501,23 @@ pub(crate) fn try_gpu_assign_shard_ram( ); } + // The GPU witness kernel above writes the row data but does not run + // the per-row `assign_instance` CPU path that pushes the y6_lo byte / + // LTU lookup multiplicity. Derive the same contribution from `steps` + // here so the caller's `lk_multiplicity` mirrors the CPU branch and + // `combined_lk_mlt` balances the U8 / LTU table `mlt` columns. Source + // of truth for the queries is `ShardRamConfig::configure`. + for step in steps { + let y6_lo = crate::tables::y6_lo_value::( + step.ec_point.point.y.0[SEPTIC_EXTENSION_DEGREE - 1], + step.record.is_to_write_set, + ); + for i in 0..3 { + lk_multiplicity.assert_const_range((y6_lo >> (8 * i)) & 0xff, 8); + } + lk_multiplicity.lookup_ltu_byte((y6_lo >> 24) & 0xff, Y6_LO_TOP_BYTE_LT_BOUND); + } + Ok(Some([raw_witin, raw_structural_witin])) } @@ -749,7 +771,10 @@ pub(crate) fn try_gpu_assign_shard_ram_from_device( } /// Full GPU pipeline for assign_shared_circuit: device-resident EC merge + partition + assign. -/// Returns `Ok(None)` if GPU is unavailable, `Ok(Some(inputs))` on success. +/// Returns `Ok(None)` if GPU is unavailable, `Ok(Some((inputs, lk_mlt)))` on +/// success — `lk_mlt` carries the y6_lo byte / LTU lookup multiplicity that +/// `ShardRamConfig::configure` consumes (mirrors the per-row CPU push in +/// `ShardRamCircuit::assign_instance`). #[allow(clippy::type_complexity)] pub(crate) fn try_gpu_assign_shared_circuit( shard_ctx: &crate::e2e::ShardContext, @@ -762,7 +787,13 @@ pub(crate) fn try_gpu_assign_shared_circuit( num_witin: usize, num_structural_witin: usize, max_chunk: usize, -) -> Result>>, ZKVMError> { +) -> Result< + Option<( + Vec>, + gkr_iop::utils::lk_multiplicity::Multiplicity, + )>, + ZKVMError, +> { use crate::{ instructions::gpu::{ chips::shard_ram::gpu_batch_continuation_ec_on_device, @@ -770,8 +801,10 @@ pub(crate) fn try_gpu_assign_shared_circuit( }, structs::{ChipInput, ZKVMWitnesses}, tables::{ShardRamCircuit, ShardRamRecord, TableCircuit}, + witness::LkMultiplicity, }; use ceno_gpu::Buffer; + use ff_ext::SmallField; use gkr_iop::gpu::get_cuda_hal; use rayon::prelude::*; use tracing::info_span; @@ -943,8 +976,51 @@ pub(crate) fn try_gpu_assign_shared_circuit( total_records - num_writes, ); - // 7. GPU assign_instances from device buffer (chunked by max_cross_shard) let record_u32s = std::mem::size_of::() / 4; + // GpuShardRamRecord (#[repr(C)]) layout — derived from shard_ram_record_to_gpu + // above: 4xu32 leader (addr, ram_type, value, _pad0), 3xu64 + // (shard, local_clk, global_clk), 2xu32 (is_to_write_set, nonce), + // [u32; 7] point_x, [u32; 7] point_y. Total = 26 u32s. + debug_assert_eq!(record_u32s, 26, "GpuShardRamRecord layout changed"); + const IS_TO_WRITE_SET_U32_OFFSET: usize = 10; + const POINT_Y6_U32_OFFSET: usize = 25; + + // 6.5. Derive ShardRam's per-row y6_lo byte / LTU lookup multiplicity + // from the partitioned device buffer. Mirrors the per-row CPU push in + // `ShardRamCircuit::assign_instance`; the constraint these queries serve + // lives in `ShardRamConfig::configure` (y6_lo bytes + lookup_ltu_byte). + let lk_mlt = info_span!("gpu_shard_ram_derive_lk_mlt", n = total_records).in_scope( + || -> Result, ZKVMError> { + if total_records == 0 { + return Ok(gkr_iop::utils::lk_multiplicity::Multiplicity::default()); + } + let host_data: Vec = partitioned_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("[GPU full pipeline] partitioned_buf D2H: {e}").into(), + ) + })?; + debug_assert_eq!(host_data.len(), total_records * record_u32s); + let prime = ::MODULUS_U64; + let lk_multiplicity = LkMultiplicity::default(); + host_data.par_chunks_exact(record_u32s).for_each(|rec| { + let mut local = lk_multiplicity.clone(); + let is_to_write_set = rec[IS_TO_WRITE_SET_U32_OFFSET] != 0; + let y6 = rec[POINT_Y6_U32_OFFSET] as u64; + let y6_lo = if is_to_write_set { + prime - 1 - y6 + } else { + y6 - 1 + }; + for i in 0..3 { + local.assert_const_range((y6_lo >> (8 * i)) & 0xff, 8); + } + local.lookup_ltu_byte((y6_lo >> 24) & 0xff, Y6_LO_TOP_BYTE_LT_BOUND); + }); + Ok(lk_multiplicity.into_finalize_result()) + }, + )?; + + // 7. GPU assign_instances from device buffer (chunked by max_cross_shard) let circuit_inputs = info_span!("shard_ram_assign_from_device", n = total_records).in_scope(|| { @@ -993,7 +1069,7 @@ pub(crate) fn try_gpu_assign_shared_circuit( total_records, ); - Ok(Some(circuit_inputs)) + Ok(Some((circuit_inputs, lk_mlt))) } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 6c12fbe8c..4d1d038fd 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -139,6 +139,16 @@ impl MmuConfig { Ok(()) } + /// Assign LocalFinalCircuit and ShardRamCircuit witnesses. Must run + /// *before* `ZKVMWitnesses::finalize_lk_multiplicities`: + /// - `ShardRamCircuit` accumulates its per-row y6_lo byte / LTU lookups + /// into `lk_mlts` via `assign_shared_circuit` (which threads a shared + /// `LkMultiplicity` through `assign_instances_with_lk_multiplicities`), + /// so they land in `combined_lk_mlt` and balance the U8 / LTU table + /// `mlt` columns. + /// - `LocalFinalCircuit` does not consume `combined_lk_mlt`; the regular + /// `assign_table_circuit` entry tolerates a not-yet-finalized + /// multiplicity by passing an empty slice. #[allow(clippy::too_many_arguments)] pub fn assign_continuation_circuit( &self, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 6a51ed550..968c9bddc 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -698,8 +698,20 @@ impl> // compute logup_sum padding // getting the number of dummy padding item that we used in this opcode circuit let num_lks = circuit_vk.get_cs().num_lks(); + // Chips with EC-sum ops carry an extra hypercube variable (one extra + // log2 row dimension) that the prover fills with EC-tree internal + // nodes; those rows are not "active" instances and their lookup + // queries collapse to the dummy_table_item via `selector_zero = 0`. + // Mirror that here so the verifier subtracts the right number of + // dummy queries. + let ecc_row_factor = if circuit_vk.get_cs().has_ecc_ops() { + 2 + } else { + 1 + }; + let padded_rows = next_pow2_instance_padding(num_instance) * ecc_row_factor; // each padding instance contribute to (2^rotation_vars) dummy lookup padding - let num_padded_instance = (next_pow2_instance_padding(num_instance) - num_instance) + let num_padded_instance = (padded_rows - num_instance) * (1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)); // each instance contribute to (2^rotation_vars - rotated) dummy lookup padding let num_instance_non_selected = num_instance diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 069802673..f205333ca 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -8,6 +8,7 @@ use crate::{ ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit, }, + witness::LkMultiplicity, }; use ceno_emul::{Addr, CENO_PLATFORM, Platform, RegIdx, StepIndex, StepRecord, WordAddr}; use ff_ext::{ExtensionField, PoseidonField}; @@ -471,13 +472,15 @@ impl ZKVMWitnesses { config: &TC::TableConfig, input: &TC::WitnessInput<'_>, ) -> Result<(), ZKVMError> { - assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&TC::name()).unwrap(); + let empty_mlt: Vec> = Vec::new(); + // Scope the immutable borrow of `self.combined_lk_mlt` so the + // `self.witnesses.insert` mutable borrow below is legal. let witness = TC::assign_instances( config, cs.zkvm_v1_css.num_witin as usize, cs.zkvm_v1_css.num_structural_witin as usize, - self.combined_lk_mlt.as_ref().unwrap(), + self.combined_lk_mlt.as_ref().unwrap_or(&empty_mlt), input, )?; let witness_instances = witness[0].num_instances(); @@ -645,19 +648,25 @@ impl ZKVMWitnesses { } } - assert!(self.combined_lk_mlt.is_some()); + assert!(self.combined_lk_mlt.is_none()); let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); let n_global = global_input.len(); + // `ShardRamCircuit::assign_instances` ignores the `multiplicity` + // argument (its lookup contribution is derived externally above), so + // an empty slice is sufficient here and matches the pre-finalize + // ordering: `combined_lk_mlt` is intentionally `None` at this point. + let lk_multiplicity = LkMultiplicity::default(); let circuit_inputs = info_span!("shard_ram_assign_instances", n = n_global).in_scope(|| { global_input .par_chunks(shard_ctx.max_num_cross_shard_accesses) .map(|shard_accesses| { - let witness = ShardRamCircuit::assign_instances( + let mut lk_multiplicity = lk_multiplicity.clone(); + let witness = ShardRamCircuit::assign_instances_with_lk_multiplicities( config, cs.zkvm_v1_css.num_witin as usize, cs.zkvm_v1_css.num_structural_witin as usize, - self.combined_lk_mlt.as_ref().unwrap(), + &mut lk_multiplicity, shard_accesses, )?; let num_reads = shard_accesses @@ -674,6 +683,15 @@ impl ZKVMWitnesses { }) .collect::, ZKVMError>>() })?; + + assert!( + self.lk_mlts + .insert( + ShardRamCircuit::::name(), + lk_multiplicity.into_finalize_result() + ) + .is_none() + ); // set num_read, num_write as separate instance assert!( self.witnesses @@ -687,6 +705,11 @@ impl ZKVMWitnesses { /// Full GPU pipeline for assign_shared_circuit: keep data on device, minimal CPU roundtrips. /// /// Returns Ok(true) if successful, Ok(false) if unavailable (no shared device buffers). + /// On success, inserts both `ChipInput` and `ShardRamCircuit`'s derived + /// lookup multiplicity (for the y6_lo byte / LTU queries) into + /// `self.witnesses` / `self.lk_mlts` so the subsequent + /// `finalize_lk_multiplicities` folds the contribution into + /// `combined_lk_mlt` — matching the CPU shortcut's invariant. #[cfg(feature = "gpu")] fn try_assign_shared_circuit_gpu( &mut self, @@ -695,7 +718,7 @@ impl ZKVMWitnesses { final_mem: &[(&'static str, Option>, &[MemFinalRecord])], config: & as TableCircuit>::TableConfig, ) -> Result { - assert!(self.combined_lk_mlt.is_some()); + assert!(self.combined_lk_mlt.is_none()); let cs_inner = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); let num_witin = cs_inner.zkvm_v1_css.num_witin as usize; let num_structural_witin = cs_inner.zkvm_v1_css.num_structural_witin as usize; @@ -708,12 +731,17 @@ impl ZKVMWitnesses { num_structural_witin, shard_ctx.max_num_cross_shard_accesses, )? { - Some(circuit_inputs) => { + Some((circuit_inputs, lk_mlt)) => { assert!( self.witnesses .insert(ShardRamCircuit::::name(), circuit_inputs) .is_none() ); + assert!( + self.lk_mlts + .insert(ShardRamCircuit::::name(), lk_mlt) + .is_none() + ); Ok(true) } None => Ok(false), diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 3dd744fc8..2e1ac58d6 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -4,6 +4,7 @@ use gkr_iop::{ chip::Chip, gkr::{GKRCircuit, layer::Layer}, selector::SelectorType, + utils::lk_multiplicity::LkMultiplicity, }; use itertools::Itertools; use multilinear_extensions::ToExpr; @@ -95,11 +96,25 @@ pub trait TableCircuit { input: &Self::FixedInput, ) -> RowMajorMatrix; + fn assign_instances_with_lk_multiplicities( + _config: &Self::TableConfig, + _num_witin: usize, + _num_structural_witin: usize, + _lk_multiplicity: &mut LkMultiplicity, + _input: &Self::WitnessInput<'_>, + ) -> Result, ZKVMError> { + unimplemented!( + "assign_instances_with_lk_multiplicities is not implemented for this table circuit" + ) + } + fn assign_instances( - config: &Self::TableConfig, - num_witin: usize, - num_structural_witin: usize, - multiplicity: &[FxHashMap], - input: &Self::WitnessInput<'_>, - ) -> Result, ZKVMError>; + _config: &Self::TableConfig, + _num_witin: usize, + _num_structural_witin: usize, + _multiplicity: &[FxHashMap], + _input: &Self::WitnessInput<'_>, + ) -> Result, ZKVMError> { + unimplemented!("assign_instances is not implemented for this table circuit") + } } diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 16e954508..de033aefd 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -1,4 +1,3 @@ -use rustc_hash::FxHashMap; use std::{iter::repeat_n, marker::PhantomData}; use crate::{ @@ -42,6 +41,8 @@ use witness::{InstancePaddingStrategy, next_pow2_instance_padding, set_val}; use crate::{instructions::riscv::constants::UInt, scheme::constants::SEPTIC_EXTENSION_DEGREE}; +pub(crate) const Y6_LO_TOP_BYTE_LT_BOUND: u64 = 60; + /// A record for a read/write into the shard RAM #[derive(Debug, Clone)] pub struct ShardRamRecord { @@ -133,24 +134,31 @@ impl ShardRamRecord { hasher.permute(input.clone())[0..SEPTIC_EXTENSION_DEGREE].into(); if let Some(p) = SepticPoint::from_x(x) { let y6 = (p.y.0)[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64(); - let is_y_in_2nd_half = y6 >= (prime / 2); - - // we negate y if needed - // to ensure read => y in [0, p/2) and write => y in [p/2, p) - let negate = match (self.is_to_write_set, is_y_in_2nd_half) { - (true, false) => true, // write, y in [0, p/2) - (false, true) => true, // read, y in [p/2, p) - _ => false, - }; - - let point = if negate { -p } else { p }; - - return ECPoint { nonce, point }; - } else { - // try again with different nonce - nonce += 1; - input[6] = E::BaseField::from_canonical_u32(nonce); + // Reject cases where y6 = 0 because then the y-sign + // binding in the circuit cannot distinguish read from write. + if y6 != 0 { + // Strict `>`: `prime / 2 = (p-1)/2` belongs to the lower + // half (read region `[1, (p-1)/2]`). Using `>=` would + // misclassify it and produce `y6_lo = (p-1)/2` whose top + // byte b3 = 60 fails `lookup_ltu_byte(b3, 60, 1)`. + let is_y_in_2nd_half = y6 > (prime / 2); + + // Enforce convention: + // is_to_write_set = 0 (read) => y6 in [1, (p-1)/2] + // is_to_write_set = 1 (write) => y6 in [(p+1)/2, p-1] + let negate = matches!( + (self.is_to_write_set, is_y_in_2nd_half), + (true, false) | (false, true) + ); + + let point = if negate { -p } else { p }; + + return ECPoint { nonce, point }; + } } + // try again with different nonce + nonce += 1; + input[6] = E::BaseField::from_canonical_u32(nonce); } } } @@ -180,6 +188,9 @@ pub struct ShardRamConfig { pub(crate) x: Vec, pub(crate) y: Vec, pub(crate) slope: Vec, + // Byte limbs of `y6_lo`, the helper that binds `y[SEPTIC_EXTENSION_DEGREE - 1]` + // to `is_global_write` in `configure`. + pub(crate) y6_lo_bytes: [WitIn; 4], pub(crate) perm_config: Poseidon2Config, } @@ -273,12 +284,45 @@ impl ShardRamConfig { cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; } - // both (x, y) and (x, -y) are valid ec points - // if is_global_write = 1, then y should be in [0, p/2) - // if is_global_write = 0, then y should be in [p/2, p) - - // TODO: enforce 0 <= y < p/2 if is_global_write = 1 - // enforce p/2 <= y < p if is_global_write = 0 + // Bind the sign of y[SEPTIC_EXTENSION_DEGREE - 1] (call it y6) to + // is_global_write: + // is_global_write = 0 (read) => y6 in [1, (p-1)/2] + // is_global_write = 1 (write) => y6 in [(p+1)/2, p-1] + // y6_lo is witnessed as four byte limbs with the top byte < 60. + // For BabyBear, (p-1)/2 = 60 * 2^24 exactly, so the byte bound + // gives y6_lo in [0, (p-1)/2). Branch equality: + // read : y6 = y6_lo + 1 + // write : y6 + y6_lo + 1 = 0 (mod p) + // y6 = 0 is the unique fixed point; `to_ec_point` rejects it. + assert_eq!( + ::MODULUS_U64, + 0x7800_0001, + "y6_lo byte bound assumes BabyBear's (p-1)/2 = 60 * 2^24" + ); + let y6_lo_bytes: [WitIn; 4] = + std::array::from_fn(|i| cb.create_witin(|| format!("y6_lo_b{i}"))); + for (i, w) in y6_lo_bytes.iter().enumerate().take(3) { + cb.assert_byte(|| format!("y6_lo_b{i} byte"), w.expr())?; + } + // `lookup_ltu_byte(a, b, 1)` asserts `a, b` are bytes and `a < b`. + cb.lookup_ltu_byte( + y6_lo_bytes[3].expr(), + E::BaseField::from_canonical_u64(Y6_LO_TOP_BYTE_LT_BOUND).expr(), + Expression::ONE, + )?; + let y6_lo = y6_lo_bytes[0].expr() + + y6_lo_bytes[1].expr() * E::BaseField::from_canonical_u64(1 << 8).expr() + + y6_lo_bytes[2].expr() * E::BaseField::from_canonical_u64(1 << 16).expr() + + y6_lo_bytes[3].expr() * E::BaseField::from_canonical_u64(1 << 24).expr(); + let y6 = y[SEPTIC_EXTENSION_DEGREE - 1].expr(); + cb.condition_require_equal( + || "y6 binds to is_global_write", + is_global_write.expr(), + y6, + E::BaseField::from_canonical_u64(::MODULUS_U64 - 1).expr() + - y6_lo.clone(), + y6_lo + Expression::ONE, + )?; Ok(ShardRamConfig { x, @@ -292,6 +336,7 @@ impl ShardRamConfig { local_clk, nonce, is_global_write, + y6_lo_bytes, perm_config, }) } @@ -311,11 +356,26 @@ pub struct ShardRamInput { pub ec_point: ECPoint, } +/// Decode `y6_lo` (the byte-decomposed helper bound to `is_global_write` in +/// `ShardRamConfig::configure`) from a witnessed `y6` field element. Mirrors +/// the prover-side derivation done inside the per-row witness assignment; +/// `to_ec_point` guarantees `y6 != 0` and the half-of-field convention, so +/// neither branch underflows. +pub(crate) fn y6_lo_value(y6: E::BaseField, is_to_write_set: bool) -> u64 { + let prime = ::MODULUS_U64; + let y6_u64 = y6.to_canonical_u64(); + if is_to_write_set { + prime - 1 - y6_u64 + } else { + y6_u64 - 1 + } +} + impl ShardRamCircuit { fn assign_instance( config: &ShardRamConfig, instance: &mut [E::BaseField], - _lk_multiplicity: &mut LkMultiplicity, + lk_multiplicity: &mut LkMultiplicity, input: &ShardRamInput, ) -> Result<(), crate::error::ZKVMError> { // assign basic fields @@ -350,6 +410,23 @@ impl ShardRamCircuit { instance[witin.id as usize] = *fe; }); + // y6_lo byte limbs for the y-sign binding constraint in `configure`. + // `to_ec_point` guarantees y6 != 0 and the half-of-field convention, + // so the subtraction below never underflows. + let y6_lo_u64 = y6_lo_value::( + point.y.0[SEPTIC_EXTENSION_DEGREE - 1], + record.is_to_write_set, + ); + for i in 0..4 { + let b = (y6_lo_u64 >> (8 * i)) & 0xff; + set_val!(instance, config.y6_lo_bytes[i], b); + } + for i in 0..3 { + let b = (y6_lo_u64 >> (8 * i)) & 0xff; + lk_multiplicity.assert_const_range(b, 8); + } + lk_multiplicity.lookup_ltu_byte((y6_lo_u64 >> 24) & 0xff, Y6_LO_TOP_BYTE_LT_BOUND); + let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32); let mut input = [E::BaseField::ZERO; 16]; @@ -483,11 +560,11 @@ impl TableCircuit for ShardRamCircuit { } /// steps format: local reads ++ local writes - fn assign_instances( + fn assign_instances_with_lk_multiplicities( config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[FxHashMap], + lk_multiplicity: &mut LkMultiplicity, steps: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { if steps.is_empty() { @@ -499,9 +576,13 @@ impl TableCircuit for ShardRamCircuit { #[cfg(feature = "gpu")] { - if let Some(result) = - Self::try_gpu_assign_instances(config, num_witin, num_structural_witin, steps)? - { + if let Some(result) = Self::try_gpu_assign_instances( + config, + num_witin, + num_structural_witin, + lk_multiplicity, + steps, + )? { return Ok(result); } } @@ -547,7 +628,6 @@ impl TableCircuit for ShardRamCircuit { let n = next_pow2_instance_padding(steps.len()); // compute the input for the binary tree for ec point summation - let lk_multiplicity = LkMultiplicity::default(); // *2 because we need to store the internal nodes of binary tree for ec point summation let num_rows_padded = 2 * n; @@ -692,12 +772,14 @@ impl ShardRamCircuit { config: &ShardRamConfig, num_witin: usize, num_structural_witin: usize, + lk_multiplicity: &mut LkMultiplicity, steps: &[ShardRamInput], ) -> Result>, ZKVMError> { crate::instructions::gpu::chips::shard_ram::try_gpu_assign_shard_ram( config, num_witin, num_structural_witin, + lk_multiplicity, steps, ) } @@ -735,14 +817,17 @@ mod tests { use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use transcript::BasicTranscript; + use super::ECPoint; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, scheme::{ PublicValues, constants::SEPTIC_EXTENSION_DEGREE, create_backend, create_prover, - hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, + hal::ProofInput, mock_prover::MockProver, prover::ZKVMProver, + septic_curve::SepticPoint, }, structs::{ComposedConstrainSystem, ProgramParams, RAMType, ZKVMProvingKey}, tables::{ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit}, + witness::LkMultiplicity, }; #[cfg(feature = "gpu")] use gkr_iop::{ @@ -845,11 +930,12 @@ mod tests { let public_value = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, [0; 8], shard_rw_sum); // assign witness - let witness = ShardRamCircuit::assign_instances( + let mut lk_multiplicity = LkMultiplicity::default(); + let witness = ShardRamCircuit::assign_instances_with_lk_multiplicities( &config, cs.num_witin as usize, cs.num_structural_witin as usize, - &[], + &mut lk_multiplicity, &input, ) .unwrap(); @@ -938,4 +1024,98 @@ mod tests { .create_chip_proof(&mut task, &mut transcript) .unwrap(); } + + /// Drive a single ShardRam row through the full `configure` + + /// `assign_instances` pipeline and validate with `MockProver`. + /// + /// * The honest witness satisfies every assert-zero and lookup + /// constraint MockProver checks. + /// * Negating the EC point but reusing the original record causes + /// `assign_instance` to derive `y6_lo` such that the algebraic + /// branch equality still holds, but the top byte + /// `b3 = (y6_lo >> 24) & 0xff` lands in `[60, 256)`. The query + /// `(Ltu, b3, 60, 1)` is missing from the LTU table (which only + /// carries `(Ltu, b3, 60, 0)` for `b3 >= 60`), so the + /// `lookup_Ltu` constraint rejects the tampered row. + #[test] + fn test_shard_ram_y_sign_circuit_rejects_negation() { + let perm = ::get_default_perm(); + + let mut cs = ConstraintSystem::new(|| "y_sign"); + let mut cb = CircuitBuilder::::new(&mut cs); + let (config, _gkr) = + ShardRamCircuit::::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural = cb.cs.num_structural_witin as usize; + // Pass a concrete challenge so `assert_with_expected_errors` routes + // through `run_with_challenge`; the no-challenge `run` path drops + // `structural_witin` and ShardRam relies on `selector_zero` to gate + // its lookup queries. + let mut rng = thread_rng(); + let challenge = [E::random(&mut rng), E::random(&mut rng)]; + + for is_to_write_set in [true, false] { + let record = ShardRamRecord { + addr: 0x1000, + ram_type: RAMType::Memory, + value: 0x1234_5678, + shard: if is_to_write_set { 1 } else { 2 }, + local_clk: if is_to_write_set { 7 } else { 0 }, + global_clk: 13, + is_to_write_set, + }; + let ec = record.to_ec_point::(&perm); + + // Honest row: every constraint MockProver checks must be + // satisfied. + let honest = [ShardRamInput { + name: "honest", + record: record.clone(), + ec_point: ec.clone(), + }]; + let mut honest_lkm = LkMultiplicity::default(); + let honest_witness = ShardRamCircuit::::assign_instances_with_lk_multiplicities( + &config, + num_witin, + num_structural, + &mut honest_lkm, + &honest, + ) + .unwrap(); + MockProver::::assert_satisfied_raw(&cb, honest_witness, &[], Some(challenge), None); + + // Tampered row: negate the EC point. `assign_instance` re-derives + // `y6_lo` from the witnessed `y6`, keeping the branch equality + // intact, so only the `lookup_Ltu` byte bound catches the wrong + // sign. + let tampered = [ShardRamInput { + name: "tampered", + record, + ec_point: ECPoint { + nonce: ec.nonce, + point: -ec.point, + }, + }]; + let mut tampered_lkm = LkMultiplicity::default(); + let [w, sw] = ShardRamCircuit::::assign_instances_with_lk_multiplicities( + &config, + num_witin, + num_structural, + &mut tampered_lkm, + &tampered, + ) + .unwrap(); + MockProver::::assert_with_expected_errors( + &cb, + &[], + &w.to_mles().into_iter().map(|v| v.into()).collect_vec(), + &sw.to_mles().into_iter().map(|v| v.into()).collect_vec(), + &[], + &["lookup_Ltu"], + Some(challenge), + None, + ); + } + } }