From 54f2e9e3e7e3025d1a274b6b844909348918826c Mon Sep 17 00:00:00 2001 From: mmagician Date: Mon, 19 Feb 2024 16:57:22 +0100 Subject: [PATCH 01/50] point sumcheck dependency to HCS private fork --- Cargo.lock | 168 +++++++++++++++++++++-------------------------------- Cargo.toml | 2 +- 2 files changed, 67 insertions(+), 103 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 368aba0..b641b84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "ahash" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" dependencies = [ "cfg-if", "once_cell", @@ -44,7 +44,7 @@ checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "ark-bn254" version = "0.4.0" -source = "git+https://github.com/arkworks-rs/algebra/#228787b5ab87139dc2a79359d2f6b25237f46dac" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-ec", "ark-ff", @@ -63,9 +63,9 @@ dependencies = [ "ark-serialize", "ark-snark", "ark-std", - "blake2 0.10.6", + "blake2", "derivative", - "digest 0.10.7", + "digest", "sha2", ] @@ -76,13 +76,13 @@ source = "git+https://github.com/HungryCatsStudio/crypto-primitives?branch=absor dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] name = "ark-ec" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-ff", "ark-poly", @@ -90,7 +90,7 @@ dependencies = [ "ark-std", "derivative", "hashbrown", - "itertools 0.12.0", + "itertools 0.12.1", "num-bigint", "num-traits", "zeroize", @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "ark-ff" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-ff-asm", "ark-ff-macros", @@ -107,8 +107,8 @@ dependencies = [ "ark-std", "arrayvec", "derivative", - "digest 0.10.7", - "itertools 0.12.0", + "digest", + "itertools 0.12.1", "num-bigint", "num-traits", "paste", @@ -118,35 +118,22 @@ dependencies = [ [[package]] name = "ark-ff-asm" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] name = "ark-ff-macros" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "num-bigint", "num-traits", "proc-macro2", "quote", - "syn 2.0.48", -] - -[[package]] -name = "ark-linear-sumcheck" -version = "0.4.0" -source = "git+https://github.com/arkworks-rs/sumcheck/#956fdaa2b80ff72cda2eafefda3f62a57589ddbd" -dependencies = [ - "ark-ff", - "ark-poly", - "ark-serialize", - "ark-std", - "blake2 0.9.2", - "hashbrown", + "syn 2.0.49", ] [[package]] @@ -169,7 +156,7 @@ dependencies = [ [[package]] name = "ark-poly" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-ff", "ark-serialize", @@ -190,7 +177,7 @@ dependencies = [ "ark-serialize", "ark-std", "derivative", - "digest 0.10.7", + "digest", "num-traits", ] @@ -208,22 +195,22 @@ dependencies = [ [[package]] name = "ark-serialize" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "ark-serialize-derive", "ark-std", - "digest 0.10.7", + "digest", "num-bigint", ] [[package]] name = "ark-serialize-derive" version = "0.4.2" -source = "git+https://github.com/arkworks-rs/algebra/#bf96a5b2873e69f3c378c7b25d0901a6701efcc4" +source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] @@ -248,6 +235,20 @@ dependencies = [ "rand", ] +[[package]] +name = "ark-sumcheck" +version = "0.4.0" +source = "git+ssh://git@github.com/HungryCatsStudio/sumcheck-private.git#721fb56acd6aba79333d8862a0067b01023bc845" +dependencies = [ + "ark-crypto-primitives", + "ark-ff", + "ark-poly", + "ark-poly-commit", + "ark-serialize", + "ark-std", + "hashbrown", +] + [[package]] name = "arrayvec" version = "0.7.4" @@ -260,24 +261,13 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "blake2" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e37d16930f5459780f5621038b6382b9bb37c19016f39fb6b5808d831f174" -dependencies = [ - "crypto-mac", - "digest 0.9.0", - "opaque-debug", -] - [[package]] name = "blake2" version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" dependencies = [ - "digest 0.10.7", + "digest", ] [[package]] @@ -330,18 +320,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.0" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80c21025abd42669a92efc996ef13cfb2c5c627858421ea58d5c3b331a6c134f" +checksum = "c918d541ef2913577a0f9566e9ce27cb35b6df072075769e0b26cb5a554520da" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.0" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458bf1f341769dfcf849846f65dffdf9146daa56bcd2a47cb4e1de9915567c99" +checksum = "9f3e7391dad68afb0c2ede1bf619f579a3dc9c2ec67f089baa397123a2f3d1eb" dependencies = [ "anstyle", "clap_lex", @@ -412,16 +402,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "crypto-mac" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b584a330336237c1eecd3e94266efb216c56ed91225d634cb2991c5f3fd1aeab" -dependencies = [ - "generic-array", - "subtle", -] - [[package]] name = "derivative" version = "2.2.0" @@ -433,15 +413,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "digest" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" -dependencies = [ - "generic-array", -] - [[package]] name = "digest" version = "0.10.7" @@ -455,9 +426,9 @@ dependencies = [ [[package]] name = "either" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "generic-array" @@ -517,9 +488,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" dependencies = [ "either", ] @@ -532,9 +503,9 @@ checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "libc" -version = "0.2.152" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libm" @@ -561,19 +532,18 @@ dependencies = [ [[package]] name = "num-integer" -version = "0.1.45" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "autocfg", "num-traits", ] [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", "libm", @@ -591,12 +561,6 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "paste" version = "1.0.14" @@ -690,9 +654,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "same-file" @@ -705,29 +669,29 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.195" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.195" +version = "1.0.196" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] name = "serde_json" -version = "1.0.111" +version = "1.0.113" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4" +checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" dependencies = [ "itoa", "ryu", @@ -742,7 +706,7 @@ checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", - "digest 0.10.7", + "digest", ] [[package]] @@ -764,9 +728,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "915aea9e586f80826ee59f8453c1101f9d1c4b3964cd2460185ee8e299ada496" dependencies = [ "proc-macro2", "quote", @@ -819,13 +783,13 @@ dependencies = [ "ark-crypto-primitives", "ark-ec", "ark-ff", - "ark-linear-sumcheck", "ark-pcs-bench-templates", "ark-poly", "ark-poly-commit", "ark-serialize", "ark-std", - "blake2 0.10.6", + "ark-sumcheck", + "blake2", "serde_json", ] @@ -959,7 +923,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] @@ -979,5 +943,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] diff --git a/Cargo.toml b/Cargo.toml index 35fd49c..cfa3d4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ ark-serialize = { version = "^0.4.0", default-features = false, features = [ "de ark-poly = {version = "^0.4.0", default-features = false } ark-poly-commit = {version = "^0.4.0", default-features = false } ark-crypto-primitives = {version = "^0.4.0", default-features = false } -ark-linear-sumcheck = { git = "https://github.com/arkworks-rs/sumcheck/", default-features = false } +ark-sumcheck = { git = "ssh://git@github.com/HungryCatsStudio/sumcheck-private.git", default-features = false } [dev-dependencies] ark-bn254 = { version = "^0.4.0", default-features = false, features = [ "curve" ] } From b24b714a7e9bbbc8010bcf6dfd4d344d86fadfb3 Mon Sep 17 00:00:00 2001 From: mmagician Date: Mon, 19 Feb 2024 17:59:12 +0100 Subject: [PATCH 02/50] changed the interface of prove: take MLEs instead of QArrays --- src/lib.rs | 1 + src/model/mod.rs | 31 +++++++++++++++++++------------ src/model/nodes/bmm.rs | 18 +++++++++++++++--- src/model/nodes/mod.rs | 8 ++++---- src/model/nodes/relu.rs | 4 ++-- src/model/nodes/requantise_bmm.rs | 4 ++-- src/model/nodes/reshape.rs | 4 ++-- src/utils.rs | 26 ++++++++++++++++++++++++++ 8 files changed, 71 insertions(+), 25 deletions(-) create mode 100644 src/utils.rs diff --git a/src/lib.rs b/src/lib.rs index 4d537be..1876bc6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub(crate) mod model; pub(crate) mod quantization; +pub(crate) mod utils; #[cfg(test)] pub(crate) mod pcs_types; diff --git a/src/model/mod.rs b/src/model/mod.rs index 4b3a50f..b7f48fa 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -2,7 +2,7 @@ use ark_std::{log2, rand::RngCore}; use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly::DenseMultilinearExtension; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; use crate::model::nodes::{NodeOps, NodeOpsSNARK}; @@ -141,7 +141,7 @@ where 0, ); - let output_f = output.values().iter().map(|x| F::from(*x)).collect(); + let output_f: Vec = output.values().iter().map(|x| F::from(*x)).collect(); let mut output = QTypeArray::S(output); @@ -149,7 +149,10 @@ where // TODO handling F and QSmallType is inelegant; we might want to switch // to F for IO in NodeOps::prove let mut node_outputs = vec![output.clone()]; - let mut node_outputs_f = vec![output_f]; + let mut node_output_mles = vec![Poly::from_evaluations_vec( + log2(output_f.len()) as usize, + output_f, + )]; for node in &self.nodes { output = node.padded_evaluate(&output); @@ -160,18 +163,22 @@ where }; node_outputs.push(output.clone()); - node_outputs_f.push(output_f); + node_output_mles.push(Poly::from_evaluations_vec( + log2(output_f.len()) as usize, + output_f, + )); } // Committing to node outputs as MLEs (individual per node for now) - let output_mles: Vec>> = node_outputs_f + let output_mles: Vec>> = node_output_mles .iter() - .map(|values| + .map(|mle| // TODO change dummy label once we e.g. have given numbers to the // nodes in the model: fc_1, fc_2, relu_1, etc. + // TODO maybe we don't need to clone, if `prove` can take a reference LabeledPolynomial::new( "dummy".to_string(), - Poly::from_evaluations_vec(log2(values.len()) as usize, values.clone()), + mle.clone(), None, None, )) @@ -192,7 +199,7 @@ where .nodes .iter() .zip(node_commitments.iter()) - .zip(node_outputs.windows(2)) + .zip(node_output_mles.windows(2)) .zip(node_coms.windows(2)) .zip(node_com_states.windows(2)) { @@ -212,21 +219,21 @@ where // output nodes and instead working witht their plain values all along, // but that would require messy node-by-node handling let input_node = node_outputs.first().unwrap(); - let input_node_f = node_outputs_f.first().unwrap(); + let input_node_f = node_output_mles.first().unwrap().to_evaluations(); let input_labeled_value = output_mles.first().unwrap(); let input_node_com = node_coms.first().unwrap(); let input_node_com_state = node_com_states.first().unwrap(); let output_node = node_outputs.last().unwrap(); - let output_node_f = node_outputs_f.last().unwrap(); + let output_node_f = node_output_mles.last().unwrap().to_evaluations(); let output_labeled_value = output_mles.last().unwrap(); let output_node_com = node_coms.last().unwrap(); let output_node_com_state = node_com_states.last().unwrap(); // Absorb the model IO output and squeeze the challenge point // Absorb the plain output and squeeze the challenge point - sponge.absorb(input_node_f); - sponge.absorb(output_node_f); + sponge.absorb(&input_node_f); + sponge.absorb(&output_node_f); let input_challenge_point = sponge.squeeze_field_elements(log2(input_node_f.len()) as usize); let output_challenge_point = diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index b798ecc..e3fab68 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -1,3 +1,4 @@ +use ark_poly::MultilinearExtension; use ark_std::marker::PhantomData; use ark_crypto_primitives::sponge::CryptographicSponge; @@ -239,13 +240,24 @@ where fn prove( &self, - s: &mut S, + sponge: &mut S, node_com: &NodeCommitment, - input: QTypeArray, + input: Poly, input_com: &PCS::Commitment, - output: QTypeArray, + output: Poly, output_com: &PCS::Commitment, ) -> NodeProof { + // we can squeeze directly, since the sponge has already absorbed all the + // commitments in Model::prove_inference + let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); + + let weights_f = self.padded_weights.iter().map(|w| F::from(*w)).collect(); + // TODO this might need LE -> BE conversion + let weights_mle = Poly::from_evaluations_vec(self.com_num_vars(), weights_f); + + // TODO we actually need fix_variables_last + weights_mle.fix_variables(&r); + unimplemented!() } } diff --git a/src/model/nodes/mod.rs b/src/model/nodes/mod.rs index 7b9d83f..5d43bb4 100644 --- a/src/model/nodes/mod.rs +++ b/src/model/nodes/mod.rs @@ -100,9 +100,9 @@ where &self, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, + input: Poly, input_com: &PCS::Commitment, - output: QTypeArray, + output: Poly, output_com: &PCS::Commitment, ) -> NodeProof; } @@ -246,9 +246,9 @@ where &self, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, + input: Poly, input_com: &PCS::Commitment, - output: QTypeArray, + output: Poly, output_com: &PCS::Commitment, ) -> NodeProof { self.as_node_ops_snark() diff --git a/src/model/nodes/relu.rs b/src/model/nodes/relu.rs index 3e67abf..6115090 100644 --- a/src/model/nodes/relu.rs +++ b/src/model/nodes/relu.rs @@ -86,9 +86,9 @@ where &self, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, + input: Poly, input_com: &PCS::Commitment, - output: QTypeArray, + output: Poly, output_com: &PCS::Commitment, ) -> super::NodeProof { todo!() diff --git a/src/model/nodes/requantise_bmm.rs b/src/model/nodes/requantise_bmm.rs index a0ce211..4237f24 100644 --- a/src/model/nodes/requantise_bmm.rs +++ b/src/model/nodes/requantise_bmm.rs @@ -148,9 +148,9 @@ where &self, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, + input: Poly, input_com: &PCS::Commitment, - output: QTypeArray, + output: Poly, output_com: &PCS::Commitment, ) -> NodeProof { unimplemented!() diff --git a/src/model/nodes/reshape.rs b/src/model/nodes/reshape.rs index 6e7d091..690a863 100644 --- a/src/model/nodes/reshape.rs +++ b/src/model/nodes/reshape.rs @@ -123,9 +123,9 @@ where &self, s: &mut S, node_com: &NodeCommitment, - input: QTypeArray, + input: Poly, input_com: &PCS::Commitment, - output: QTypeArray, + output: Poly, output_com: &PCS::Commitment, ) -> NodeProof { unimplemented!() diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..16114fc --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,26 @@ +use ark_ff::Field; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; + +pub(crate) fn fix_variables( + poly: &DenseMultilinearExtension, + partial_point: &[F], +) -> DenseMultilinearExtension { + assert!( + partial_point.len() <= poly.num_vars, + "invalid size of partial point" + ); + let nv = poly.num_vars; + + let mut poly = poly.evaluations.to_vec(); + let dim = partial_point.len(); + // evaluate single variable of partial point from right to left + for i in 1..dim + 1 { + let r = partial_point[i - 1]; + for b in 0..(1 << (nv - i)) { + let left = poly[b << 1]; + let right = poly[(b << 1) + 1]; + poly[b] = left + r * (right - left); + } + } + DenseMultilinearExtension::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))]) +} From 20b4def29f3a7975bb4b4ffc58f57a6d367fbc64 Mon Sep 17 00:00:00 2001 From: mmagician Date: Mon, 19 Feb 2024 18:17:02 +0100 Subject: [PATCH 03/50] construct a list of products of polys from weights and input --- src/model/nodes/bmm.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index e3fab68..9c0d7cf 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -1,3 +1,5 @@ +use ark_std::rc::Rc; + use ark_poly::MultilinearExtension; use ark_std::marker::PhantomData; @@ -6,6 +8,7 @@ use ark_ff::PrimeField; use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; use ark_std::log2; use ark_std::rand::RngCore; +use ark_sumcheck::ml_sumcheck::protocol::ListOfProductsOfPolynomials; use crate::model::qarray::{QArray, QTypeArray}; use crate::model::Poly; @@ -251,6 +254,7 @@ where // commitments in Model::prove_inference let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); + // TODO consider whether this can be done once and stored let weights_f = self.padded_weights.iter().map(|w| F::from(*w)).collect(); // TODO this might need LE -> BE conversion let weights_mle = Poly::from_evaluations_vec(self.com_num_vars(), weights_f); @@ -258,6 +262,18 @@ where // TODO we actually need fix_variables_last weights_mle.fix_variables(&r); + // Constructing the sumcheck polynomial + // big_poly(x) = input(x) * weights(x, r) + let mut big_poly = ListOfProductsOfPolynomials::new(self.padded_dims_log.0); + + big_poly.add_product( + vec![weights_mle, input] + .into_iter() + .map(|mle| Rc::new(mle)) + .collect::>(), + F::one(), + ); + unimplemented!() } } From 1394e8fdd800ef761a15c53c8a2be341b218400b Mon Sep 17 00:00:00 2001 From: mmagician Date: Mon, 19 Feb 2024 21:29:49 +0100 Subject: [PATCH 04/50] Pass committer key to prove method, needed for opening proofs --- src/model/mod.rs | 1 + src/model/nodes/bmm.rs | 1 + src/model/nodes/mod.rs | 4 +++- src/model/nodes/relu.rs | 1 + src/model/nodes/requantise_bmm.rs | 1 + src/model/nodes/reshape.rs | 1 + 6 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/model/mod.rs b/src/model/mod.rs index b7f48fa..29ad4ab 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -205,6 +205,7 @@ where { // TODO prove likely needs to receive the sponge for randomness/FS node_proofs.push(node.prove( + ck, sponge, node_com, values[0].clone(), diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 9c0d7cf..73105b2 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -243,6 +243,7 @@ where fn prove( &self, + ck: &PCS::CommitterKey, sponge: &mut S, node_com: &NodeCommitment, input: Poly, diff --git a/src/model/nodes/mod.rs b/src/model/nodes/mod.rs index 5d43bb4..f2de74a 100644 --- a/src/model/nodes/mod.rs +++ b/src/model/nodes/mod.rs @@ -98,6 +98,7 @@ where /// Produce a node output proof fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, input: Poly, @@ -244,6 +245,7 @@ where fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, input: Poly, @@ -252,6 +254,6 @@ where output_com: &PCS::Commitment, ) -> NodeProof { self.as_node_ops_snark() - .prove(s, node_com, input, input_com, output, output_com) + .prove(ck, s, node_com, input, input_com, output, output_com) } } diff --git a/src/model/nodes/relu.rs b/src/model/nodes/relu.rs index 6115090..2990eec 100644 --- a/src/model/nodes/relu.rs +++ b/src/model/nodes/relu.rs @@ -84,6 +84,7 @@ where fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, input: Poly, diff --git a/src/model/nodes/requantise_bmm.rs b/src/model/nodes/requantise_bmm.rs index 4237f24..8824e06 100644 --- a/src/model/nodes/requantise_bmm.rs +++ b/src/model/nodes/requantise_bmm.rs @@ -146,6 +146,7 @@ where fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, input: Poly, diff --git a/src/model/nodes/reshape.rs b/src/model/nodes/reshape.rs index 690a863..9576f01 100644 --- a/src/model/nodes/reshape.rs +++ b/src/model/nodes/reshape.rs @@ -121,6 +121,7 @@ where fn prove( &self, + ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, input: Poly, From de2b316e8835102ad29f31590b61c59bb02592f5 Mon Sep 17 00:00:00 2001 From: mmagician Date: Mon, 19 Feb 2024 21:34:28 +0100 Subject: [PATCH 05/50] Add `F: Absorb` bounds --- src/model/nodes/bmm.rs | 5 +++-- src/model/nodes/mod.rs | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 73105b2..011a7f4 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -3,12 +3,13 @@ use ark_std::rc::Rc; use ark_poly::MultilinearExtension; use ark_std::marker::PhantomData; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; use ark_std::log2; use ark_std::rand::RngCore; use ark_sumcheck::ml_sumcheck::protocol::ListOfProductsOfPolynomials; +use ark_sumcheck::ml_sumcheck::MLSumcheck; use crate::model::qarray::{QArray, QTypeArray}; use crate::model::Poly; @@ -137,7 +138,7 @@ where impl NodeOpsSNARK for BMMNode where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { diff --git a/src/model/nodes/mod.rs b/src/model/nodes/mod.rs index f2de74a..479c73e 100644 --- a/src/model/nodes/mod.rs +++ b/src/model/nodes/mod.rs @@ -1,3 +1,4 @@ +use ark_crypto_primitives::sponge::Absorb; use ark_ff::PrimeField; use ark_poly_commit::PolynomialCommitment; use ark_std::rand::RngCore; @@ -155,7 +156,7 @@ where // elegantly by simply implementing the trait impl Node where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -192,7 +193,7 @@ where // elegantly by simply implementing the trait impl NodeOps for Node where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -214,7 +215,7 @@ where impl NodeOpsSNARK for Node where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { From f9fe2aec6c4b442e3ad4ba96f4bc11fbf8e143c4 Mon Sep 17 00:00:00 2001 From: mmagician Date: Mon, 19 Feb 2024 21:36:41 +0100 Subject: [PATCH 06/50] use shorthand inside the map --- src/model/nodes/bmm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 011a7f4..95c1e8e 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -271,7 +271,7 @@ where big_poly.add_product( vec![weights_mle, input] .into_iter() - .map(|mle| Rc::new(mle)) + .map(Rc::new) .collect::>(), F::one(), ); From 7d6078b958a61533980a1e441dd23ab3c6c0eb0b Mon Sep 17 00:00:00 2001 From: mmagician Date: Mon, 19 Feb 2024 22:24:37 +0100 Subject: [PATCH 07/50] Incomplete but compiling BMMProof Aside from earlier TODOs, the opening proof only considers input, need to add weights matrix too --- src/model/mod.rs | 35 +++++++++-------- src/model/nodes/bmm.rs | 63 ++++++++++++++++++++++++------- src/model/nodes/mod.rs | 55 ++++++++++++++++++--------- src/model/nodes/relu.rs | 22 ++++++----- src/model/nodes/requantise_bmm.rs | 20 +++++----- src/model/nodes/reshape.rs | 20 +++++----- 6 files changed, 141 insertions(+), 74 deletions(-) diff --git a/src/model/mod.rs b/src/model/mod.rs index 29ad4ab..82144c6 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -20,10 +20,11 @@ mod qarray; mod reshaping; pub(crate) type Poly = DenseMultilinearExtension; +pub(crate) type LabeledPoly = LabeledPolynomial>; pub(crate) struct InferenceProof where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -31,7 +32,7 @@ where outputs: Vec, // Proofs of evaluation of each of the model's nodes - node_proofs: Vec, + node_proofs: Vec>, // Proofs of opening of each of the model's outputs opening_proofs: Vec, @@ -170,7 +171,7 @@ where } // Committing to node outputs as MLEs (individual per node for now) - let output_mles: Vec>> = node_output_mles + let labeled_output_mles: Vec>> = node_output_mles .iter() .map(|mle| // TODO change dummy label once we e.g. have given numbers to the @@ -184,10 +185,10 @@ where )) .collect(); - let (node_coms, node_com_states) = PCS::commit(ck, &output_mles, rng).unwrap(); + let (output_coms, output_com_states) = PCS::commit(ck, &labeled_output_mles, rng).unwrap(); // Absorb all commitments into the sponge - sponge.absorb(&node_coms); + sponge.absorb(&output_coms); // TODO Prove that all commited NIOs live in the right range (to be // discussed) @@ -199,9 +200,9 @@ where .nodes .iter() .zip(node_commitments.iter()) - .zip(node_output_mles.windows(2)) - .zip(node_coms.windows(2)) - .zip(node_com_states.windows(2)) + .zip(labeled_output_mles.windows(2)) + .zip(output_coms.windows(2)) + .zip(output_com_states.windows(2)) { // TODO prove likely needs to receive the sponge for randomness/FS node_proofs.push(node.prove( @@ -209,9 +210,11 @@ where sponge, node_com, values[0].clone(), - l_v_coms[0].commitment(), + &l_v_coms[0], + v_coms_states[0].clone(), values[1].clone(), - l_v_coms[1].commitment(), + &l_v_coms[1], + v_coms_states[1].clone(), )); } @@ -221,15 +224,15 @@ where // but that would require messy node-by-node handling let input_node = node_outputs.first().unwrap(); let input_node_f = node_output_mles.first().unwrap().to_evaluations(); - let input_labeled_value = output_mles.first().unwrap(); - let input_node_com = node_coms.first().unwrap(); - let input_node_com_state = node_com_states.first().unwrap(); + let input_labeled_value = labeled_output_mles.first().unwrap(); + let input_node_com = output_coms.first().unwrap(); + let input_node_com_state = output_com_states.first().unwrap(); let output_node = node_outputs.last().unwrap(); let output_node_f = node_output_mles.last().unwrap().to_evaluations(); - let output_labeled_value = output_mles.last().unwrap(); - let output_node_com = node_coms.last().unwrap(); - let output_node_com_state = node_com_states.last().unwrap(); + let output_labeled_value = labeled_output_mles.last().unwrap(); + let output_node_com = output_coms.last().unwrap(); + let output_node_com_state = output_com_states.last().unwrap(); // Absorb the model IO output and squeeze the challenge point // Absorb the plain output and squeeze the challenge point diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 95c1e8e..a3a5b6b 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -1,18 +1,18 @@ use ark_std::rc::Rc; -use ark_poly::MultilinearExtension; +use ark_poly::{MultilinearExtension, Polynomial}; use ark_std::marker::PhantomData; use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use ark_std::log2; use ark_std::rand::RngCore; use ark_sumcheck::ml_sumcheck::protocol::ListOfProductsOfPolynomials; -use ark_sumcheck::ml_sumcheck::MLSumcheck; +use ark_sumcheck::ml_sumcheck::{MLSumcheck, Proof}; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::Poly; +use crate::model::{LabeledPoly, Poly}; use crate::quantization::{BMMQInfo, QInfo, QLargeType, QScaleType, QSmallType}; use crate::{Commitment, CommitmentState}; @@ -76,8 +76,14 @@ where { } -pub(crate) struct BMMNodeProof { - // this will be the sumcheck proof +pub(crate) struct BMMNodeProof< + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +> { + sumcheck_proof: Proof, + opening_proof: PCS::Proof, + claimed_evaluations: Vec, } impl NodeOps for BMMNode @@ -247,11 +253,13 @@ where ck: &PCS::CommitterKey, sponge: &mut S, node_com: &NodeCommitment, - input: Poly, - input_com: &PCS::Commitment, - output: Poly, - output_com: &PCS::Commitment, - ) -> NodeProof { + input: LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: PCS::CommitmentState, + output: LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: PCS::CommitmentState, + ) -> NodeProof { // we can squeeze directly, since the sponge has already absorbed all the // commitments in Model::prove_inference let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); @@ -268,15 +276,44 @@ where // big_poly(x) = input(x) * weights(x, r) let mut big_poly = ListOfProductsOfPolynomials::new(self.padded_dims_log.0); + // TODO we are cloning the input here, can we do better? big_poly.add_product( - vec![weights_mle, input] + vec![weights_mle, (*input).clone()] .into_iter() .map(Rc::new) .collect::>(), F::one(), ); - unimplemented!() + let (sumcheck_proof, prover_state) = + MLSumcheck::::prove_as_subprotocol(&big_poly, sponge).unwrap(); + + // Prover computes the claimed evaluations of Weights, Input, at the random point + // Note this is a different random point than `r` above: `prover_state.randomness` is + // the list of random values sampled vy V during the sumcheck itself + let claimed_evaluations: Vec = big_poly + .flattened_ml_extensions + .iter() + .map(|x| x.evaluate(&prover_state.randomness)) + .collect(); + + // TODO need to pass the labeled poly, and the commitment to, and the state for, the weights matrix. Currently only passing the data related to the input + let opening_proof = PCS::open( + &ck, + &[input], + &[(*input_com).clone()], + &prover_state.randomness, + sponge, + &[input_com_state], + None, + ) + .unwrap(); + + NodeProof::BMM(BMMNodeProof { + sumcheck_proof, + opening_proof, + claimed_evaluations, + }) } } diff --git a/src/model/nodes/mod.rs b/src/model/nodes/mod.rs index 479c73e..5855fb5 100644 --- a/src/model/nodes/mod.rs +++ b/src/model/nodes/mod.rs @@ -1,6 +1,6 @@ use ark_crypto_primitives::sponge::Absorb; use ark_ff::PrimeField; -use ark_poly_commit::PolynomialCommitment; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use ark_std::rand::RngCore; use crate::{ @@ -20,7 +20,10 @@ use self::{ reshape::ReshapeNode, }; -use super::qarray::{QArray, QTypeArray}; +use super::{ + qarray::{QArray, QTypeArray}, + LabeledPoly, +}; pub(crate) mod bmm; pub(crate) mod relu; @@ -52,7 +55,7 @@ pub(crate) trait NodeOps { pub(crate) trait NodeOpsSNARK where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -102,11 +105,13 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: Poly, - input_com: &PCS::Commitment, - output: Poly, - output_com: &PCS::Commitment, - ) -> NodeProof; + input: LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: PCS::CommitmentState, + output: LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: PCS::CommitmentState, + ) -> NodeProof; } pub(crate) enum Node @@ -121,8 +126,13 @@ where Reshape(ReshapeNode), } -pub(crate) enum NodeProof { - BMM(BMMNodeProof), +pub(crate) enum NodeProof +where + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +{ + BMM(BMMNodeProof), RequantiseBMM(RequantiseBMMNodeProof), ReLU(()), Reshape(()), @@ -249,12 +259,23 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: Poly, - input_com: &PCS::Commitment, - output: Poly, - output_com: &PCS::Commitment, - ) -> NodeProof { - self.as_node_ops_snark() - .prove(ck, s, node_com, input, input_com, output, output_com) + input: LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: PCS::CommitmentState, + output: LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: PCS::CommitmentState, + ) -> NodeProof { + self.as_node_ops_snark().prove( + ck, + s, + node_com, + input, + input_com, + input_com_state, + output, + output_com, + output_com_state, + ) } } diff --git a/src/model/nodes/relu.rs b/src/model/nodes/relu.rs index 2990eec..34dbbaa 100644 --- a/src/model/nodes/relu.rs +++ b/src/model/nodes/relu.rs @@ -1,16 +1,16 @@ use ark_std::log2; use ark_std::marker::PhantomData; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::PolynomialCommitment; +use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::Poly; +use crate::model::{LabeledPoly, Poly}; use crate::quantization::QSmallType; -use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK}; +use super::{NodeCommitment, NodeCommitmentState, NodeOps, NodeOpsSNARK, NodeProof}; // Rectified linear unit node performing x |-> max(0, x). pub(crate) struct ReLUNode @@ -49,7 +49,7 @@ where // impl NodeOpsSnark impl NodeOpsSNARK for ReLUNode where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -87,11 +87,13 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: Poly, - input_com: &PCS::Commitment, - output: Poly, - output_com: &PCS::Commitment, - ) -> super::NodeProof { + input: LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: PCS::CommitmentState, + output: LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: PCS::CommitmentState, + ) -> NodeProof { todo!() } } diff --git a/src/model/nodes/requantise_bmm.rs b/src/model/nodes/requantise_bmm.rs index 8824e06..aef6468 100644 --- a/src/model/nodes/requantise_bmm.rs +++ b/src/model/nodes/requantise_bmm.rs @@ -1,13 +1,13 @@ use ark_std::marker::PhantomData; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use ark_std::log2; use ark_std::rand::RngCore; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::Poly; +use crate::model::{LabeledPoly, Poly}; use crate::quantization::{ requantise_fc, BMMQInfo, QInfo, QLargeType, QScaleType, QSmallType, RoundingScheme, }; @@ -87,7 +87,7 @@ where impl NodeOpsSNARK for RequantiseBMMNode where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -149,11 +149,13 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: Poly, - input_com: &PCS::Commitment, - output: Poly, - output_com: &PCS::Commitment, - ) -> NodeProof { + input: LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: PCS::CommitmentState, + output: LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: PCS::CommitmentState, + ) -> NodeProof { unimplemented!() } } diff --git a/src/model/nodes/reshape.rs b/src/model/nodes/reshape.rs index 9576f01..83edc48 100644 --- a/src/model/nodes/reshape.rs +++ b/src/model/nodes/reshape.rs @@ -1,13 +1,13 @@ use ark_std::log2; use ark_std::marker::PhantomData; -use ark_crypto_primitives::sponge::CryptographicSponge; +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; -use ark_poly_commit::PolynomialCommitment; +use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::Poly; +use crate::model::{LabeledPoly, Poly}; use crate::quantization::QSmallType; use super::{NodeCommitment, NodeOps, NodeOpsSNARK, NodeProof}; @@ -59,7 +59,7 @@ where impl NodeOpsSNARK for ReshapeNode where - F: PrimeField, + F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { @@ -124,11 +124,13 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, - input: Poly, - input_com: &PCS::Commitment, - output: Poly, - output_com: &PCS::Commitment, - ) -> NodeProof { + input: LabeledPoly, + input_com: &LabeledCommitment, + input_com_state: PCS::CommitmentState, + output: LabeledPoly, + output_com: &LabeledCommitment, + output_com_state: PCS::CommitmentState, + ) -> NodeProof { unimplemented!() } } From fa9a3d6e1922fc21df8248af6ba4008303265155 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 09:06:59 +0100 Subject: [PATCH 08/50] added NodeCommitmentState to proving method --- src/model/nodes/bmm.rs | 39 ++++++++++++++++++++++++------- src/model/nodes/mod.rs | 3 +++ src/model/nodes/relu.rs | 1 + src/model/nodes/requantise_bmm.rs | 1 + src/model/nodes/reshape.rs | 3 ++- 5 files changed, 37 insertions(+), 10 deletions(-) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index a3a5b6b..c2829cc 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -82,8 +82,12 @@ pub(crate) struct BMMNodeProof< PCS: PolynomialCommitment, S>, > { sumcheck_proof: Proof, - opening_proof: PCS::Proof, - claimed_evaluations: Vec, + input_opening_proof: PCS::Proof, + input_opening_value: Vec, + weight_opening_proof: PCS::Proof, + weight_opening_value: Vec, + output_opening_proof: PCS::Proof, + output_opening_value: Vec, } impl NodeOps for BMMNode @@ -253,6 +257,7 @@ where ck: &PCS::CommitterKey, sponge: &mut S, node_com: &NodeCommitment, + node_com_state: &NodeCommitmentState, input: LabeledPoly, input_com: &LabeledCommitment, input_com_state: PCS::CommitmentState, @@ -260,10 +265,22 @@ where output_com: &LabeledCommitment, output_com_state: PCS::CommitmentState, ) -> NodeProof { + let (weight_com_state, bias_com_state) = match node_com_state { + NodeCommitmentState::BMM(BMMNodeCommitmentState { + weight_com_state, + bias_com_state, + }) => (weight_com_state, bias_com_state), + _ => panic!( + "BMMNode::prove expected node commitment state of type BMMNodeCommitmentState" + ), + }; + // we can squeeze directly, since the sponge has already absorbed all the // commitments in Model::prove_inference let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); + let input_mle = input.polynomial().clone(); + // TODO consider whether this can be done once and stored let weights_f = self.padded_weights.iter().map(|w| F::from(*w)).collect(); // TODO this might need LE -> BE conversion @@ -273,12 +290,12 @@ where weights_mle.fix_variables(&r); // Constructing the sumcheck polynomial - // big_poly(x) = input(x) * weights(x, r) + // big_poly(x) := input(x) * weights(x, r) let mut big_poly = ListOfProductsOfPolynomials::new(self.padded_dims_log.0); // TODO we are cloning the input here, can we do better? big_poly.add_product( - vec![weights_mle, (*input).clone()] + vec![input_mle, weights_mle] .into_iter() .map(Rc::new) .collect::>(), @@ -288,17 +305,21 @@ where let (sumcheck_proof, prover_state) = MLSumcheck::::prove_as_subprotocol(&big_poly, sponge).unwrap(); - // Prover computes the claimed evaluations of Weights, Input, at the random point - // Note this is a different random point than `r` above: `prover_state.randomness` is - // the list of random values sampled vy V during the sumcheck itself + // The prover computes the claimed evaluations of weight_mle and + // input_mle at the random challenge point + // c:= `prover_state.randomness`, the list of random values sampled by + // the verifier duriing sumcheck. Note that this is different from `r` + // above. If we denote the MLE interpolating the weight matrix by + // original_weight_mle, we need to open + // input_mle(s) * original_weight_mle(s, r) + // as well as output_mle(r) let claimed_evaluations: Vec = big_poly .flattened_ml_extensions .iter() .map(|x| x.evaluate(&prover_state.randomness)) .collect(); - // TODO need to pass the labeled poly, and the commitment to, and the state for, the weights matrix. Currently only passing the data related to the input - let opening_proof = PCS::open( + let input_opening_proof = PCS::open( &ck, &[input], &[(*input_com).clone()], diff --git a/src/model/nodes/mod.rs b/src/model/nodes/mod.rs index 5855fb5..bbf990a 100644 --- a/src/model/nodes/mod.rs +++ b/src/model/nodes/mod.rs @@ -105,6 +105,7 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, + node_com_state: &NodeCommitmentState, input: LabeledPoly, input_com: &LabeledCommitment, input_com_state: PCS::CommitmentState, @@ -259,6 +260,7 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, + node_com_state: &NodeCommitmentState, input: LabeledPoly, input_com: &LabeledCommitment, input_com_state: PCS::CommitmentState, @@ -270,6 +272,7 @@ where ck, s, node_com, + node_com_state, input, input_com, input_com_state, diff --git a/src/model/nodes/relu.rs b/src/model/nodes/relu.rs index 34dbbaa..d00bd36 100644 --- a/src/model/nodes/relu.rs +++ b/src/model/nodes/relu.rs @@ -87,6 +87,7 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, + node_com_state: &NodeCommitmentState, input: LabeledPoly, input_com: &LabeledCommitment, input_com_state: PCS::CommitmentState, diff --git a/src/model/nodes/requantise_bmm.rs b/src/model/nodes/requantise_bmm.rs index aef6468..0cf6a71 100644 --- a/src/model/nodes/requantise_bmm.rs +++ b/src/model/nodes/requantise_bmm.rs @@ -149,6 +149,7 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, + node_com_state: &NodeCommitmentState, input: LabeledPoly, input_com: &LabeledCommitment, input_com_state: PCS::CommitmentState, diff --git a/src/model/nodes/reshape.rs b/src/model/nodes/reshape.rs index 83edc48..556e024 100644 --- a/src/model/nodes/reshape.rs +++ b/src/model/nodes/reshape.rs @@ -7,7 +7,7 @@ use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; use ark_std::rand::RngCore; use crate::model::qarray::{QArray, QTypeArray}; -use crate::model::{LabeledPoly, Poly}; +use crate::model::{LabeledPoly, NodeCommitmentState, Poly}; use crate::quantization::QSmallType; use super::{NodeCommitment, NodeOps, NodeOpsSNARK, NodeProof}; @@ -124,6 +124,7 @@ where ck: &PCS::CommitterKey, s: &mut S, node_com: &NodeCommitment, + node_com_state: &NodeCommitmentState, input: LabeledPoly, input_com: &LabeledCommitment, input_com_state: PCS::CommitmentState, From d792a19d2d7550212d19f89247e616b6f6801b69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 10:24:08 +0100 Subject: [PATCH 09/50] code compiling, several steps missing --- src/model/mod.rs | 19 ++++--- src/model/nodes/bmm.rs | 95 +++++++++++++++++++++++-------- src/model/nodes/mod.rs | 16 +++--- src/model/nodes/relu.rs | 8 +-- src/model/nodes/requantise_bmm.rs | 8 +-- src/model/nodes/reshape.rs | 8 +-- 6 files changed, 102 insertions(+), 52 deletions(-) diff --git a/src/model/mod.rs b/src/model/mod.rs index 82144c6..f0df946 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -129,7 +129,8 @@ where ck: &PCS::CommitterKey, rng: Option<&mut dyn RngCore>, sponge: &mut S, - node_commitments: Vec>, + node_coms: &Vec>, + node_com_states: &Vec>, input: QArray, ) -> InferenceProof { // TODO Absorb public parameters into s (to be determined what exactly) @@ -196,10 +197,11 @@ where let mut node_proofs = Vec::new(); // Second pass: proving - for ((((node, node_com), values), l_v_coms), v_coms_states) in self + for (((((node, node_com), node_com_state), values), l_v_coms), v_coms_states) in self .nodes .iter() - .zip(node_commitments.iter()) + .zip(node_coms.iter()) + .zip(node_com_states.iter()) .zip(labeled_output_mles.windows(2)) .zip(output_coms.windows(2)) .zip(output_com_states.windows(2)) @@ -208,13 +210,14 @@ where node_proofs.push(node.prove( ck, sponge, - node_com, - values[0].clone(), + &node_com, + &node_com_state, + &values[0], &l_v_coms[0], - v_coms_states[0].clone(), - values[1].clone(), + &v_coms_states[0], + &values[1], &l_v_coms[1], - v_coms_states[1].clone(), + &v_coms_states[1], )); } diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index c2829cc..c968f19 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -46,8 +46,8 @@ where S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - weight_com: PCS::Commitment, - bias_com: PCS::Commitment, + weight_com: LabeledCommitment, + bias_com: LabeledCommitment, } impl Commitment for BMMNodeCommitment @@ -83,11 +83,11 @@ pub(crate) struct BMMNodeProof< > { sumcheck_proof: Proof, input_opening_proof: PCS::Proof, - input_opening_value: Vec, + input_opening_value: F, weight_opening_proof: PCS::Proof, - weight_opening_value: Vec, + weight_opening_value: F, output_opening_proof: PCS::Proof, - output_opening_value: Vec, + output_opening_value: F, } impl NodeOps for BMMNode @@ -242,8 +242,8 @@ where ( NodeCommitment::BMM(BMMNodeCommitment { - weight_com: coms.0[0].commitment().clone(), - bias_com: coms.0[1].commitment().clone(), + weight_com: coms.0[0].clone(), + bias_com: coms.0[1].clone(), }), NodeCommitmentState::BMM(BMMNodeCommitmentState { weight_com_state: coms.1[0].clone(), @@ -258,13 +258,21 @@ where sponge: &mut S, node_com: &NodeCommitment, node_com_state: &NodeCommitmentState, - input: LabeledPoly, + input: &LabeledPoly, input_com: &LabeledCommitment, - input_com_state: PCS::CommitmentState, - output: LabeledPoly, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, output_com: &LabeledCommitment, - output_com_state: PCS::CommitmentState, + output_com_state: &PCS::CommitmentState, ) -> NodeProof { + let (weight_com, bias_com) = match node_com { + NodeCommitment::BMM(BMMNodeCommitment { + weight_com, + bias_com, + }) => (weight_com, bias_com), + _ => panic!("BMMNode::prove expected node commitment of type BMMNodeCommitment"), + }; + let (weight_com_state, bias_com_state) = match node_com_state { NodeCommitmentState::BMM(BMMNodeCommitmentState { weight_com_state, @@ -279,15 +287,19 @@ where // commitments in Model::prove_inference let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); + // TODO is this value directly available from the output of sumcheck? + // It doesn't need to be used until the end of the method + let claimed_sum = output.evaluate(&r); + let input_mle = input.polynomial().clone(); // TODO consider whether this can be done once and stored let weights_f = self.padded_weights.iter().map(|w| F::from(*w)).collect(); // TODO this might need LE -> BE conversion - let weights_mle = Poly::from_evaluations_vec(self.com_num_vars(), weights_f); + let weight_mle = Poly::from_evaluations_vec(self.com_num_vars(), weights_f); // TODO we actually need fix_variables_last - weights_mle.fix_variables(&r); + let bound_weight_mle = weight_mle.fix_variables(&r); // Constructing the sumcheck polynomial // big_poly(x) := input(x) * weights(x, r) @@ -295,7 +307,7 @@ where // TODO we are cloning the input here, can we do better? big_poly.add_product( - vec![input_mle, weights_mle] + vec![input_mle, bound_weight_mle] .into_iter() .map(Rc::new) .collect::>(), @@ -307,12 +319,12 @@ where // The prover computes the claimed evaluations of weight_mle and // input_mle at the random challenge point - // c:= `prover_state.randomness`, the list of random values sampled by + // s:= `prover_state.randomness`, the list of random values sampled by // the verifier duriing sumcheck. Note that this is different from `r` - // above. If we denote the MLE interpolating the weight matrix by - // original_weight_mle, we need to open - // input_mle(s) * original_weight_mle(s, r) - // as well as output_mle(r) + // above. + // + // We need to open input_mle(s) * weight_mle(s, r) as well as + // output_mle(r) let claimed_evaluations: Vec = big_poly .flattened_ml_extensions .iter() @@ -321,19 +333,54 @@ where let input_opening_proof = PCS::open( &ck, - &[input], - &[(*input_com).clone()], + [input], + [input_com], &prover_state.randomness, sponge, - &[input_com_state], + [input_com_state], + None, + ) + .unwrap(); + + let weight_opening_proof = PCS::open( + &ck, + [&LabeledPolynomial::new( + "weight_mle".to_string(), + weight_mle, + Some(1), + None, + )], + [weight_com], + &prover_state + .randomness + .into_iter() + .chain(r.clone().into_iter()) + .collect(), + sponge, + [weight_com_state], + None, + ) + .unwrap(); + + let output_opening_proof = PCS::open( + &ck, + [output], + [output_com], + &r, + sponge, + [output_com_state], None, ) .unwrap(); NodeProof::BMM(BMMNodeProof { sumcheck_proof, - opening_proof, - claimed_evaluations, + input_opening_proof, + input_opening_value: claimed_evaluations[0], + weight_opening_proof, + weight_opening_value: claimed_evaluations[1], + output_opening_proof, + output_opening_value: claimed_sum, }) } } diff --git a/src/model/nodes/mod.rs b/src/model/nodes/mod.rs index bbf990a..e5cc3b3 100644 --- a/src/model/nodes/mod.rs +++ b/src/model/nodes/mod.rs @@ -106,12 +106,12 @@ where s: &mut S, node_com: &NodeCommitment, node_com_state: &NodeCommitmentState, - input: LabeledPoly, + input: &LabeledPoly, input_com: &LabeledCommitment, - input_com_state: PCS::CommitmentState, - output: LabeledPoly, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, output_com: &LabeledCommitment, - output_com_state: PCS::CommitmentState, + output_com_state: &PCS::CommitmentState, ) -> NodeProof; } @@ -261,12 +261,12 @@ where s: &mut S, node_com: &NodeCommitment, node_com_state: &NodeCommitmentState, - input: LabeledPoly, + input: &LabeledPoly, input_com: &LabeledCommitment, - input_com_state: PCS::CommitmentState, - output: LabeledPoly, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, output_com: &LabeledCommitment, - output_com_state: PCS::CommitmentState, + output_com_state: &PCS::CommitmentState, ) -> NodeProof { self.as_node_ops_snark().prove( ck, diff --git a/src/model/nodes/relu.rs b/src/model/nodes/relu.rs index d00bd36..31af516 100644 --- a/src/model/nodes/relu.rs +++ b/src/model/nodes/relu.rs @@ -88,12 +88,12 @@ where s: &mut S, node_com: &NodeCommitment, node_com_state: &NodeCommitmentState, - input: LabeledPoly, + input: &LabeledPoly, input_com: &LabeledCommitment, - input_com_state: PCS::CommitmentState, - output: LabeledPoly, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, output_com: &LabeledCommitment, - output_com_state: PCS::CommitmentState, + output_com_state: &PCS::CommitmentState, ) -> NodeProof { todo!() } diff --git a/src/model/nodes/requantise_bmm.rs b/src/model/nodes/requantise_bmm.rs index 0cf6a71..5c69e09 100644 --- a/src/model/nodes/requantise_bmm.rs +++ b/src/model/nodes/requantise_bmm.rs @@ -150,12 +150,12 @@ where s: &mut S, node_com: &NodeCommitment, node_com_state: &NodeCommitmentState, - input: LabeledPoly, + input: &LabeledPoly, input_com: &LabeledCommitment, - input_com_state: PCS::CommitmentState, - output: LabeledPoly, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, output_com: &LabeledCommitment, - output_com_state: PCS::CommitmentState, + output_com_state: &PCS::CommitmentState, ) -> NodeProof { unimplemented!() } diff --git a/src/model/nodes/reshape.rs b/src/model/nodes/reshape.rs index 556e024..053497d 100644 --- a/src/model/nodes/reshape.rs +++ b/src/model/nodes/reshape.rs @@ -125,12 +125,12 @@ where s: &mut S, node_com: &NodeCommitment, node_com_state: &NodeCommitmentState, - input: LabeledPoly, + input: &LabeledPoly, input_com: &LabeledCommitment, - input_com_state: PCS::CommitmentState, - output: LabeledPoly, + input_com_state: &PCS::CommitmentState, + output: &LabeledPoly, output_com: &LabeledCommitment, - output_com_state: PCS::CommitmentState, + output_com_state: &PCS::CommitmentState, ) -> NodeProof { unimplemented!() } From bad5c5757799f953524efef14aad46603a2c393c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 10:25:38 +0100 Subject: [PATCH 10/50] added unpadded inference tests --- .../examples/simple_perceptron_mnist/mod.rs | 28 +++++++++++++++++- .../two_layer_perceptron_mnist/mod.rs | 29 +++++++++++++++++-- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index 15c9f2b..b631117 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -57,7 +57,33 @@ where } #[test] -fn run_simple_perceptron_mnist() { +fn run_native_simple_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; + /**********************/ + + let perceptron = build_simple_perceptron_mnist::, Brakedown>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let output_i8 = perceptron.evaluate(input_i8); + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values(), expected_output); +} + + +#[test] +fn run_padded_simple_perceptron_mnist() { /**** Change here ****/ let input = NORMALISED_INPUT_TEST_150; let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; diff --git a/src/model/examples/two_layer_perceptron_mnist/mod.rs b/src/model/examples/two_layer_perceptron_mnist/mod.rs index ba0ad9d..427dd2b 100644 --- a/src/model/examples/two_layer_perceptron_mnist/mod.rs +++ b/src/model/examples/two_layer_perceptron_mnist/mod.rs @@ -80,7 +80,7 @@ where } #[test] -fn run_two_layer_perceptron_mnist() { +fn run_native_two_layer_perceptron_mnist() { /**** Change here ****/ let input = NORMALISED_INPUT_TEST_150; let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; @@ -96,10 +96,35 @@ fn run_two_layer_perceptron_mnist() { let input_i8 = (quantised_input.cast::() - 128).cast::(); - let output_i8 = perceptron.padded_evaluate(input_i8); + let output_i8 = perceptron.evaluate(input_i8); let output_u8 = (output_i8.cast::() + 128).cast::(); println!("Output: {:?}", output_u8.values()); assert_eq!(output_u8.move_values(), expected_output); } + +#[test] +fn run_padded_two_layer_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; + /**********************/ + + let perceptron = build_two_layer_perceptron_mnist::, Brakedown>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let output_i8 = perceptron.padded_evaluate(input_i8); + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values(), expected_output); +} \ No newline at end of file From 8082f348ba856cc32215aef943514d4fa5710087 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 10:46:34 +0100 Subject: [PATCH 11/50] small refactor --- src/lib.rs | 3 - .../examples/simple_perceptron_mnist/mod.rs | 40 ++++- .../two_layer_perceptron_mnist/mod.rs | 2 +- src/{utils.rs => utils/mod.rs} | 4 + src/{ => utils}/pcs_types.rs | 0 src/utils/test_sponge.rs | 142 ++++++++++++++++++ 6 files changed, 186 insertions(+), 5 deletions(-) rename src/{utils.rs => utils/mod.rs} (95%) rename src/{ => utils}/pcs_types.rs (100%) create mode 100644 src/utils/test_sponge.rs diff --git a/src/lib.rs b/src/lib.rs index 1876bc6..1d4523f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,9 +2,6 @@ pub(crate) mod model; pub(crate) mod quantization; pub(crate) mod utils; -#[cfg(test)] -pub(crate) mod pcs_types; - trait Commitment {} trait CommitmentState {} diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index b631117..e87ec50 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -3,7 +3,7 @@ use crate::{ nodes::{bmm::BMMNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, qarray::QArray, Model, Poly, - }, pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} + }, utils::pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} }; use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; @@ -14,6 +14,7 @@ use ark_ff::PrimeField; mod input; mod parameters; +use ark_std::test_rng; use input::*; use parameters::*; @@ -106,3 +107,40 @@ fn run_padded_simple_perceptron_mnist() { println!("Output: {:?}", output_u8.values()); assert_eq!(output_u8.move_values(), expected_output); } + +#[test] +fn prove_inference_simple_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; + /**********************/ + + let perceptron = build_simple_perceptron_mnist::, Brakedown>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let mut rng = test_rng(); + let (ck, vk) = perceptron.setup_keys(&mut rng).unwrap(); + + let mut sponge: PoseidonSponge = PoseidonSponge::new(&poseidon_parameters_for_test()) + + let output_i8 = perceptron.prove_inference( + &ck, + Some(&mut rng), + sponge: &mut S, + node_coms: &Vec>, + node_com_states: &Vec>, + input_i8, + ) + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values(), expected_output); +} \ No newline at end of file diff --git a/src/model/examples/two_layer_perceptron_mnist/mod.rs b/src/model/examples/two_layer_perceptron_mnist/mod.rs index 427dd2b..e77fded 100644 --- a/src/model/examples/two_layer_perceptron_mnist/mod.rs +++ b/src/model/examples/two_layer_perceptron_mnist/mod.rs @@ -3,7 +3,7 @@ use crate::{ nodes::{bmm::BMMNode, relu::ReLUNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, qarray::QArray, Model, Poly, - }, pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} + }, utils::pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} }; use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; diff --git a/src/utils.rs b/src/utils/mod.rs similarity index 95% rename from src/utils.rs rename to src/utils/mod.rs index 16114fc..b8350cb 100644 --- a/src/utils.rs +++ b/src/utils/mod.rs @@ -1,3 +1,7 @@ + +#[cfg(test)] +pub(crate) mod pcs_types; + use ark_ff::Field; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; diff --git a/src/pcs_types.rs b/src/utils/pcs_types.rs similarity index 100% rename from src/pcs_types.rs rename to src/utils/pcs_types.rs diff --git a/src/utils/test_sponge.rs b/src/utils/test_sponge.rs new file mode 100644 index 0000000..1b0c4f1 --- /dev/null +++ b/src/utils/test_sponge.rs @@ -0,0 +1,142 @@ +use core::{borrow::Borrow, marker::PhantomData}; + +use ark_crypto_primitives::{ + crh::{CRHScheme, TwoToOneCRHScheme}, + merkle_tree::{ByteDigestConverter, Config}, + sponge::{ + poseidon::{PoseidonConfig, PoseidonSponge}, + CryptographicSponge, + }, +}; +use ark_ff::PrimeField; +use ark_poly::DenseMultilinearExtension; +use ark_poly_commit::{ + linear_codes::{LinearCodePCS, MultilinearBrakedown}, + to_bytes, +}; +use ark_serialize::CanonicalSerialize; +use ark_std::{rand::RngCore, test_rng}; +use blake2::{Blake2s256, Digest}; +use sha2::Sha256; + +pub(crate) fn test_sponge() -> PoseidonSponge { + PoseidonSponge::new(&poseidon_parameters_for_test()) +} + +/// Generate default parameters for alpha = 17, state-size = 8 +/// +/// WARNING: This poseidon parameter is not secure. Please generate +/// your own parameters according the field you use. +fn poseidon_parameters_for_test() -> PoseidonConfig { + let full_rounds = 8; + let partial_rounds = 31; + let alpha = 17; + + let mds = vec![ + vec![F::one(), F::zero(), F::one()], + vec![F::one(), F::one(), F::zero()], + vec![F::zero(), F::one(), F::one()], + ]; + + let mut ark = Vec::new(); + let mut ark_rng = test_rng(); + + for _ in 0..(full_rounds + partial_rounds) { + let mut res = Vec::new(); + + for _ in 0..3 { + res.push(F::rand(&mut ark_rng)); + } + ark.push(res); + } + PoseidonConfig::new(full_rounds, partial_rounds, alpha, mds, ark, 2, 1) +} + +#[cfg(test)] +pub(crate) struct LeafIdentityHasher; + +#[cfg(test)] +impl CRHScheme for LeafIdentityHasher { + type Input = Vec; + type Output = Vec; + type Parameters = (); + + fn setup(_: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _: &Self::Parameters, + input: T, + ) -> Result { + Ok(input.borrow().to_vec().into()) + } +} + +#[cfg(test)] +pub(crate) struct FieldToBytesColHasher +where + F: PrimeField + CanonicalSerialize, + D: Digest, +{ + _phantom: PhantomData<(F, D)>, +} + +#[cfg(test)] +impl CRHScheme for FieldToBytesColHasher +where + F: PrimeField + CanonicalSerialize, + D: Digest, +{ + type Input = Vec; + type Output = Vec; + type Parameters = (); + + fn setup(_rng: &mut R) -> Result { + Ok(()) + } + + fn evaluate>( + _parameters: &Self::Parameters, + input: T, + ) -> Result { + let mut dig = D::new(); + dig.update(to_bytes!(input.borrow()).unwrap()); + Ok(dig.finalize().to_vec()) + } +} + +pub(crate) type LeafH = LeafIdentityHasher; +pub(crate) type CompressH = Sha256; +pub(crate) type ColHasher = FieldToBytesColHasher; + +pub(crate) struct MerkleTreeParams; + +impl Config for MerkleTreeParams { + type Leaf = Vec; + + type LeafDigest = ::Output; + type LeafInnerDigestConverter = ByteDigestConverter; + type InnerDigest = ::Output; + + type LeafHash = LeafH; + type TwoToOneHash = CompressH; +} + +pub(crate) type MTConfig = MerkleTreeParams; +type Sponge = PoseidonSponge; + +pub(crate) type BrakedownPCS = LinearCodePCS< + MultilinearBrakedown< + F, + MTConfig, + Sponge, + DenseMultilinearExtension, + ColHasher, + >, + F, + DenseMultilinearExtension, + Sponge, + MTConfig, + ColHasher, +>; From da48333b57e809ab4a193e56ffd8ff5b34d30b51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 10:49:13 +0100 Subject: [PATCH 12/50] added utils/mod.rs --- src/utils/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index b8350cb..d16f085 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,4 +1,3 @@ - #[cfg(test)] pub(crate) mod pcs_types; From 43957f54dce58884ab698cabebcf6d9b279ef2e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 10:50:55 +0100 Subject: [PATCH 13/50] added sha2 dependency --- Cargo.lock | 1 + Cargo.toml | 1 + src/model/examples/simple_perceptron_mnist/mod.rs | 2 +- src/utils/mod.rs | 3 +++ 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index b641b84..a075979 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -791,6 +791,7 @@ dependencies = [ "ark-sumcheck", "blake2", "serde_json", + "sha2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index cfa3d4d..a211b0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ ark-bn254 = { version = "^0.4.0", default-features = false, features = [ "curve" blake2 = { version = "0.10", default-features = false } serde_json = "1.0.108" ark-pcs-bench-templates = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "brakedown-com-absorb" } +sha2 = { version = "0.10", default-features = false } [patch.crates-io] ark-ff = { git = "https://github.com/arkworks-rs/algebra/" } diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index e87ec50..b4d9839 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -128,7 +128,7 @@ fn prove_inference_simple_perceptron_mnist() { let mut rng = test_rng(); let (ck, vk) = perceptron.setup_keys(&mut rng).unwrap(); - let mut sponge: PoseidonSponge = PoseidonSponge::new(&poseidon_parameters_for_test()) + let mut sponge: PoseidonSponge = test_sponge(); let output_i8 = perceptron.prove_inference( &ck, diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d16f085..69a4125 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,9 @@ #[cfg(test)] pub(crate) mod pcs_types; +#[cfg(test)] +pub(crate) mod test_sponge; + use ark_ff::Field; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; From 75970241abd5ab08d303ad00f44c2f32043a93a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 11:00:06 +0100 Subject: [PATCH 14/50] progress with the dummy test --- .../examples/simple_perceptron_mnist/mod.rs | 21 ++++++++++++------- src/model/mod.rs | 6 +++--- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index b4d9839..74069d1 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -1,9 +1,9 @@ use crate::{ model::{ nodes::{bmm::BMMNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, - qarray::QArray, + qarray::{QArray, QTypeArray}, Model, Poly, - }, utils::pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} + }, quantization::{quantise_f32_u8_nne, QSmallType}, utils::{pcs_types::Brakedown, test_sponge::test_sponge} }; use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; @@ -130,15 +130,22 @@ fn prove_inference_simple_perceptron_mnist() { let mut sponge: PoseidonSponge = test_sponge(); - let output_i8 = perceptron.prove_inference( + let inference_proof = perceptron.prove_inference( &ck, Some(&mut rng), - sponge: &mut S, - node_coms: &Vec>, - node_com_states: &Vec>, + &mut sponge, + node_coms, + node_com_states, input_i8, - ) + ); + + let output_qtypearray = inference_proof.outputs[0]; + let output_i8 = match output_qtypearray { + QTypeArray::S(o) => o, + _ => panic!("Expected QArray"), + }; + let output_u8 = (output_i8.cast::() + 128).cast::(); println!("Output: {:?}", output_u8.values()); diff --git a/src/model/mod.rs b/src/model/mod.rs index f0df946..db3ab57 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -29,13 +29,13 @@ where PCS: PolynomialCommitment, S>, { // Model output tensors - outputs: Vec, + pub(crate) outputs: Vec, // Proofs of evaluation of each of the model's nodes - node_proofs: Vec>, + pub(crate) node_proofs: Vec>, // Proofs of opening of each of the model's outputs - opening_proofs: Vec, + pub(crate) opening_proofs: Vec, } // TODO change the functions that receive vectors to receive slices instead whenever it makes sense From dc80d72d6e3fe8953da02640db8ba9976840af74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 11:44:44 +0100 Subject: [PATCH 15/50] tracking down bug, which is in Brakedown --- src/model/examples/simple_perceptron_mnist/mod.rs | 13 ++++++++----- src/model/nodes/bmm.rs | 9 +++++++++ src/model/nodes/reshape.rs | 12 ++++++------ 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index 74069d1..11b9575 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -126,24 +126,27 @@ fn prove_inference_simple_perceptron_mnist() { let input_i8 = (quantised_input.cast::() - 128).cast::(); let mut rng = test_rng(); - let (ck, vk) = perceptron.setup_keys(&mut rng).unwrap(); + let (ck, _) = perceptron.setup_keys(&mut rng).unwrap(); let mut sponge: PoseidonSponge = test_sponge(); + //let (hidden_nodes, com_states) = perceptron.commit(&ck, None).iter().unzip(); + let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); + let inference_proof = perceptron.prove_inference( &ck, Some(&mut rng), &mut sponge, - node_coms, - node_com_states, + &node_coms, + &node_com_states, input_i8, ); - let output_qtypearray = inference_proof.outputs[0]; + let output_qtypearray = inference_proof.outputs[0].clone(); let output_i8 = match output_qtypearray { QTypeArray::S(o) => o, - _ => panic!("Expected QArray"), + _ => panic!("Expected QTypeArray::S"), }; let output_u8 = (output_i8.cast::() + 128).cast::(); diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index c968f19..a9e3033 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -331,6 +331,12 @@ where .map(|x| x.evaluate(&prover_state.randomness)) .collect(); + // TODO remove + println!("*** Before opening input ***"); + println!("> node dimensions {:?}", self.padded_dims_log); + println!("> input_n_vars: {}", input.num_vars()); + println!("> randomness length: {}", prover_state.randomness.len()); + let input_opening_proof = PCS::open( &ck, [input], @@ -342,6 +348,9 @@ where ) .unwrap(); + // TODO remove + println!("*** After opening input ***"); + let weight_opening_proof = PCS::open( &ck, [&LabeledPolynomial::new( diff --git a/src/model/nodes/reshape.rs b/src/model/nodes/reshape.rs index 053497d..04bf966 100644 --- a/src/model/nodes/reshape.rs +++ b/src/model/nodes/reshape.rs @@ -112,11 +112,11 @@ where &self, ck: &PCS::CommitterKey, rng: Option<&mut dyn RngCore>, - ) -> ( - super::NodeCommitment, - super::NodeCommitmentState, - ) { - todo!() + ) -> (NodeCommitment, NodeCommitmentState) { + ( + NodeCommitment::Reshape(()), + NodeCommitmentState::Reshape(()), + ) } fn prove( @@ -132,7 +132,7 @@ where output_com: &LabeledCommitment, output_com_state: &PCS::CommitmentState, ) -> NodeProof { - unimplemented!() + NodeProof::Reshape(()) } } From 22b6295abab502ed51a0cd15054cae92c477efe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 11:45:19 +0100 Subject: [PATCH 16/50] text change --- src/model/nodes/bmm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index a9e3033..203dbe6 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -334,7 +334,7 @@ where // TODO remove println!("*** Before opening input ***"); println!("> node dimensions {:?}", self.padded_dims_log); - println!("> input_n_vars: {}", input.num_vars()); + println!("> input num_vars: {}", input.num_vars()); println!("> randomness length: {}", prover_state.randomness.len()); let input_opening_proof = PCS::open( From e5b5cc10643b3d3f88d113c83fe969e818dafa1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 11:55:57 +0100 Subject: [PATCH 17/50] added bias opening proof --- src/model/nodes/bmm.rs | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 203dbe6..96f1345 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -86,6 +86,8 @@ pub(crate) struct BMMNodeProof< input_opening_value: F, weight_opening_proof: PCS::Proof, weight_opening_value: F, + bias_opening_proof: PCS::Proof, + bias_opening_value: F, output_opening_proof: PCS::Proof, output_opening_value: F, } @@ -287,10 +289,6 @@ where // commitments in Model::prove_inference let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); - // TODO is this value directly available from the output of sumcheck? - // It doesn't need to be used until the end of the method - let claimed_sum = output.evaluate(&r); - let input_mle = input.polynomial().clone(); // TODO consider whether this can be done once and stored @@ -298,6 +296,16 @@ where // TODO this might need LE -> BE conversion let weight_mle = Poly::from_evaluations_vec(self.com_num_vars(), weights_f); + // TODO consider whether this can be done once and stored + let bias_f = self.padded_bias.iter().map(|w| F::from(*w)).collect(); + // TODO this might need LE -> BE conversion + let bias_mle = Poly::from_evaluations_vec(self.padded_dims_log.1, bias_f); + + // TODO is output_opening_value directly available from the output of sumcheck? + // It doesn't need to be used until the end of the method + let bias_opening_value = bias_mle.evaluate(&r); + let output_opening_value = output.evaluate(&r); + // TODO we actually need fix_variables_last let bound_weight_mle = weight_mle.fix_variables(&r); @@ -371,6 +379,22 @@ where ) .unwrap(); + let bias_opening_proof = PCS::open( + &ck, + [&LabeledPolynomial::new( + "bias_mle".to_string(), + bias_mle, + Some(1), + None, + )], + [bias_com], + &r, + sponge, + [bias_com_state], + None, + ) + .unwrap(); + let output_opening_proof = PCS::open( &ck, [output], @@ -388,8 +412,10 @@ where input_opening_value: claimed_evaluations[0], weight_opening_proof, weight_opening_value: claimed_evaluations[1], + bias_opening_proof, + bias_opening_value, output_opening_proof, - output_opening_value: claimed_sum, + output_opening_value, }) } } From d634aa61288e7941c90674b1713413a15e270988 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 12:28:23 +0100 Subject: [PATCH 18/50] tiny improvements --- src/model/nodes/bmm.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 96f1345..72d0999 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -220,14 +220,14 @@ where rng: Option<&mut dyn RngCore>, ) -> (NodeCommitment, NodeCommitmentState) { // TODO should we separate the associated commitment type into one with state and one without? - - let num_vars_weights = self.padded_dims_log.0 + self.padded_dims_log.1; let padded_weights_f: Vec = self.padded_weights.iter().map(|w| F::from(*w)).collect(); + // TODO part of this code is duplicated in prove, another hint that this should probs + // be stored let weight_poly = LabeledPolynomial::new( "weight_poly".to_string(), - Poly::from_evaluations_vec(num_vars_weights, padded_weights_f), - None, + Poly::from_evaluations_vec(self.com_num_vars(), padded_weights_f), + Some(1), None, ); @@ -236,7 +236,7 @@ where let bias_poly = LabeledPolynomial::new( "bias_poly".to_string(), Poly::from_evaluations_vec(self.padded_dims_log.1, padded_bias_f), - None, + Some(1), None, ); From c5b28528ff020b35e70bc84b2a3cd68ffd84c726 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 17:06:21 +0100 Subject: [PATCH 19/50] cleaned sponge utils file --- src/utils/test_sponge.rs | 111 ++------------------------------------- 1 file changed, 4 insertions(+), 107 deletions(-) diff --git a/src/utils/test_sponge.rs b/src/utils/test_sponge.rs index 1b0c4f1..eab3854 100644 --- a/src/utils/test_sponge.rs +++ b/src/utils/test_sponge.rs @@ -1,23 +1,9 @@ -use core::{borrow::Borrow, marker::PhantomData}; - -use ark_crypto_primitives::{ - crh::{CRHScheme, TwoToOneCRHScheme}, - merkle_tree::{ByteDigestConverter, Config}, - sponge::{ - poseidon::{PoseidonConfig, PoseidonSponge}, - CryptographicSponge, - }, +use ark_crypto_primitives::sponge::{ + poseidon::{PoseidonConfig, PoseidonSponge}, + CryptographicSponge, }; use ark_ff::PrimeField; -use ark_poly::DenseMultilinearExtension; -use ark_poly_commit::{ - linear_codes::{LinearCodePCS, MultilinearBrakedown}, - to_bytes, -}; -use ark_serialize::CanonicalSerialize; -use ark_std::{rand::RngCore, test_rng}; -use blake2::{Blake2s256, Digest}; -use sha2::Sha256; +use ark_std::test_rng; pub(crate) fn test_sponge() -> PoseidonSponge { PoseidonSponge::new(&poseidon_parameters_for_test()) @@ -51,92 +37,3 @@ fn poseidon_parameters_for_test() -> PoseidonConfig { } PoseidonConfig::new(full_rounds, partial_rounds, alpha, mds, ark, 2, 1) } - -#[cfg(test)] -pub(crate) struct LeafIdentityHasher; - -#[cfg(test)] -impl CRHScheme for LeafIdentityHasher { - type Input = Vec; - type Output = Vec; - type Parameters = (); - - fn setup(_: &mut R) -> Result { - Ok(()) - } - - fn evaluate>( - _: &Self::Parameters, - input: T, - ) -> Result { - Ok(input.borrow().to_vec().into()) - } -} - -#[cfg(test)] -pub(crate) struct FieldToBytesColHasher -where - F: PrimeField + CanonicalSerialize, - D: Digest, -{ - _phantom: PhantomData<(F, D)>, -} - -#[cfg(test)] -impl CRHScheme for FieldToBytesColHasher -where - F: PrimeField + CanonicalSerialize, - D: Digest, -{ - type Input = Vec; - type Output = Vec; - type Parameters = (); - - fn setup(_rng: &mut R) -> Result { - Ok(()) - } - - fn evaluate>( - _parameters: &Self::Parameters, - input: T, - ) -> Result { - let mut dig = D::new(); - dig.update(to_bytes!(input.borrow()).unwrap()); - Ok(dig.finalize().to_vec()) - } -} - -pub(crate) type LeafH = LeafIdentityHasher; -pub(crate) type CompressH = Sha256; -pub(crate) type ColHasher = FieldToBytesColHasher; - -pub(crate) struct MerkleTreeParams; - -impl Config for MerkleTreeParams { - type Leaf = Vec; - - type LeafDigest = ::Output; - type LeafInnerDigestConverter = ByteDigestConverter; - type InnerDigest = ::Output; - - type LeafHash = LeafH; - type TwoToOneHash = CompressH; -} - -pub(crate) type MTConfig = MerkleTreeParams; -type Sponge = PoseidonSponge; - -pub(crate) type BrakedownPCS = LinearCodePCS< - MultilinearBrakedown< - F, - MTConfig, - Sponge, - DenseMultilinearExtension, - ColHasher, - >, - F, - DenseMultilinearExtension, - Sponge, - MTConfig, - ColHasher, ->; From 3b5d847c32f15eddc315d9c83467999a436cf529 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 17:12:39 +0100 Subject: [PATCH 20/50] replaced Brakedown by Ligero --- Cargo.lock | 51 +++++++++++++++++-- Cargo.toml | 3 +- .../examples/simple_perceptron_mnist/mod.rs | 8 +-- .../two_layer_perceptron_mnist/mod.rs | 6 +-- src/model/mod.rs | 4 ++ src/utils/pcs_types.rs | 13 ++--- 6 files changed, 63 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a075979..7f572c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -145,7 +145,7 @@ dependencies = [ "ark-ec", "ark-ff", "ark-poly", - "ark-poly-commit", + "ark-poly-commit 0.4.0 (git+https://github.com/HungryCatsStudio/poly-commit?branch=brakedown-com-absorb)", "ark-serialize", "ark-std", "criterion", @@ -181,6 +181,23 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ark-poly-commit" +version = "0.4.0" +source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb#4bebbc94a94bbc1a4c48b884c6b022bfbe91d934" +dependencies = [ + "ark-crypto-primitives", + "ark-ec", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "derivative", + "digest", + "merlin", + "num-traits", +] + [[package]] name = "ark-relations" version = "0.4.0" @@ -243,7 +260,7 @@ dependencies = [ "ark-crypto-primitives", "ark-ff", "ark-poly", - "ark-poly-commit", + "ark-poly-commit 0.4.0 (git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb)", "ark-serialize", "ark-std", "hashbrown", @@ -279,6 +296,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "cast" version = "0.3.0" @@ -501,6 +524,15 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + [[package]] name = "libc" version = "0.2.153" @@ -519,6 +551,18 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +[[package]] +name = "merlin" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58c38e2799fc0978b65dfff8023ec7843e2330bb462f19198840b34b6582397d" +dependencies = [ + "byteorder", + "keccak", + "rand_core", + "zeroize", +] + [[package]] name = "num-bigint" version = "0.4.4" @@ -785,13 +829,12 @@ dependencies = [ "ark-ff", "ark-pcs-bench-templates", "ark-poly", - "ark-poly-commit", + "ark-poly-commit 0.4.0 (git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb)", "ark-serialize", "ark-std", "ark-sumcheck", "blake2", "serde_json", - "sha2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a211b0e..5df5965 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,13 +18,12 @@ ark-bn254 = { version = "^0.4.0", default-features = false, features = [ "curve" blake2 = { version = "0.10", default-features = false } serde_json = "1.0.108" ark-pcs-bench-templates = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "brakedown-com-absorb" } -sha2 = { version = "0.10", default-features = false } [patch.crates-io] ark-ff = { git = "https://github.com/arkworks-rs/algebra/" } ark-ec = { git = "https://github.com/arkworks-rs/algebra/" } ark-serialize = { git = "https://github.com/arkworks-rs/algebra/" } ark-poly = { git = "https://github.com/arkworks-rs/algebra/" } -ark-poly-commit = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "brakedown-com-absorb" } +ark-poly-commit = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "ligero-uni-and-ml-absorb" } ark-crypto-primitives = { git = "https://github.com/HungryCatsStudio/crypto-primitives", branch = "absorb"} ark-bn254 = { git = "https://github.com/arkworks-rs/algebra/" } diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index 11b9575..b8e702c 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -3,7 +3,7 @@ use crate::{ nodes::{bmm::BMMNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, qarray::{QArray, QTypeArray}, Model, Poly, - }, quantization::{quantise_f32_u8_nne, QSmallType}, utils::{pcs_types::Brakedown, test_sponge::test_sponge} + }, quantization::{quantise_f32_u8_nne, QSmallType}, utils::{pcs_types::Ligero, test_sponge::test_sponge} }; use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; @@ -64,7 +64,7 @@ fn run_native_simple_perceptron_mnist() { let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; /**********************/ - let perceptron = build_simple_perceptron_mnist::, Brakedown>(); + let perceptron = build_simple_perceptron_mnist::, Ligero>(); let quantised_input: QArray = input .iter() @@ -90,7 +90,7 @@ fn run_padded_simple_perceptron_mnist() { let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; /**********************/ - let perceptron = build_simple_perceptron_mnist::, Brakedown>(); + let perceptron = build_simple_perceptron_mnist::, Ligero>(); let quantised_input: QArray = input .iter() @@ -115,7 +115,7 @@ fn prove_inference_simple_perceptron_mnist() { let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; /**********************/ - let perceptron = build_simple_perceptron_mnist::, Brakedown>(); + let perceptron = build_simple_perceptron_mnist::, Ligero>(); let quantised_input: QArray = input .iter() diff --git a/src/model/examples/two_layer_perceptron_mnist/mod.rs b/src/model/examples/two_layer_perceptron_mnist/mod.rs index e77fded..0f16fab 100644 --- a/src/model/examples/two_layer_perceptron_mnist/mod.rs +++ b/src/model/examples/two_layer_perceptron_mnist/mod.rs @@ -3,7 +3,7 @@ use crate::{ nodes::{bmm::BMMNode, relu::ReLUNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, qarray::QArray, Model, Poly, - }, utils::pcs_types::Brakedown, quantization::{quantise_f32_u8_nne, QSmallType} + }, utils::pcs_types::Ligero, quantization::{quantise_f32_u8_nne, QSmallType} }; use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; @@ -86,7 +86,7 @@ fn run_native_two_layer_perceptron_mnist() { let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; /**********************/ - let perceptron = build_two_layer_perceptron_mnist::, Brakedown>(); + let perceptron = build_two_layer_perceptron_mnist::, Ligero>(); let quantised_input: QArray = input .iter() @@ -111,7 +111,7 @@ fn run_padded_two_layer_perceptron_mnist() { let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; /**********************/ - let perceptron = build_two_layer_perceptron_mnist::, Brakedown>(); + let perceptron = build_two_layer_perceptron_mnist::, Ligero>(); let quantised_input: QArray = input .iter() diff --git a/src/model/mod.rs b/src/model/mod.rs index db3ab57..ac9f022 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -207,6 +207,10 @@ where .zip(output_com_states.windows(2)) { // TODO prove likely needs to receive the sponge for randomness/FS + println!("Proving node: {}", node.type_name()); + println!("Node input num_vars: {:?}", values[0].num_vars()); + println!("Node output num_vars: {:?}", values[1].num_vars()); + node_proofs.push(node.prove( ck, sponge, diff --git a/src/utils/pcs_types.rs b/src/utils/pcs_types.rs index 727e397..e6b81cc 100644 --- a/src/utils/pcs_types.rs +++ b/src/utils/pcs_types.rs @@ -6,7 +6,7 @@ use ark_crypto_primitives::{ use ark_pcs_bench_templates::*; use ark_poly::DenseMultilinearExtension; -use ark_poly_commit::linear_codes::{LinearCodePCS, MultilinearBrakedown}; +use ark_poly_commit::linear_codes::{LinearCodePCS, MultilinearLigero}; use blake2::Blake2s256; // Brakedown PCS over BN254 @@ -26,14 +26,9 @@ impl Config for MerkleTreeParams { type MTConfig = MerkleTreeParams; type ColHasher = FieldToBytesColHasher; -pub(crate) type Brakedown = LinearCodePCS< - MultilinearBrakedown< - F, - MTConfig, - PoseidonSponge, - DenseMultilinearExtension, - ColHasher, - >, + +pub(crate) type Ligero = LinearCodePCS< + MultilinearLigero, DenseMultilinearExtension, ColHasher>, F, DenseMultilinearExtension, PoseidonSponge, From d7793bfd1e8660ceca4a0566cebd6d33b487ae16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 17:16:46 +0100 Subject: [PATCH 21/50] cleaned code --- src/model/mod.rs | 5 ----- src/model/nodes/bmm.rs | 6 ------ src/model/nodes/requantise_bmm.rs | 2 +- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/src/model/mod.rs b/src/model/mod.rs index ac9f022..d7b7697 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -206,11 +206,6 @@ where .zip(output_coms.windows(2)) .zip(output_com_states.windows(2)) { - // TODO prove likely needs to receive the sponge for randomness/FS - println!("Proving node: {}", node.type_name()); - println!("Node input num_vars: {:?}", values[0].num_vars()); - println!("Node output num_vars: {:?}", values[1].num_vars()); - node_proofs.push(node.prove( ck, sponge, diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 72d0999..18cf025 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -339,12 +339,6 @@ where .map(|x| x.evaluate(&prover_state.randomness)) .collect(); - // TODO remove - println!("*** Before opening input ***"); - println!("> node dimensions {:?}", self.padded_dims_log); - println!("> input num_vars: {}", input.num_vars()); - println!("> randomness length: {}", prover_state.randomness.len()); - let input_opening_proof = PCS::open( &ck, [input], diff --git a/src/model/nodes/requantise_bmm.rs b/src/model/nodes/requantise_bmm.rs index 5c69e09..319b21a 100644 --- a/src/model/nodes/requantise_bmm.rs +++ b/src/model/nodes/requantise_bmm.rs @@ -157,7 +157,7 @@ where output_com: &LabeledCommitment, output_com_state: &PCS::CommitmentState, ) -> NodeProof { - unimplemented!() + NodeProof::RequantiseBMM(RequantiseBMMNodeProof {}) } } From 1802651c2abe1af02624587eaf16a7531b90f1ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 17:17:57 +0100 Subject: [PATCH 22/50] cleaned code 2 --- src/model/nodes/bmm.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 18cf025..85a8784 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -350,9 +350,6 @@ where ) .unwrap(); - // TODO remove - println!("*** After opening input ***"); - let weight_opening_proof = PCS::open( &ck, [&LabeledPolynomial::new( From 60b13c6caba83864c3aa65473c4970fc09ed94ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Tue, 20 Feb 2024 17:29:10 +0100 Subject: [PATCH 23/50] dummy test working for two-layer perceptron --- .../examples/simple_perceptron_mnist/mod.rs | 6 +-- .../two_layer_perceptron_mnist/mod.rs | 52 ++++++++++++++++++- src/model/mod.rs | 6 +-- src/model/nodes/relu.rs | 4 +- 4 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index b8e702c..fa4247f 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -142,7 +142,7 @@ fn prove_inference_simple_perceptron_mnist() { input_i8, ); - let output_qtypearray = inference_proof.outputs[0].clone(); + let output_qtypearray = inference_proof.inputs_outputs[1].clone(); let output_i8 = match output_qtypearray { QTypeArray::S(o) => o, @@ -151,6 +151,6 @@ fn prove_inference_simple_perceptron_mnist() { let output_u8 = (output_i8.cast::() + 128).cast::(); - println!("Output: {:?}", output_u8.values()); - assert_eq!(output_u8.move_values(), expected_output); + println!("Padded output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS[0]], expected_output); } \ No newline at end of file diff --git a/src/model/examples/two_layer_perceptron_mnist/mod.rs b/src/model/examples/two_layer_perceptron_mnist/mod.rs index 0f16fab..9ed86f2 100644 --- a/src/model/examples/two_layer_perceptron_mnist/mod.rs +++ b/src/model/examples/two_layer_perceptron_mnist/mod.rs @@ -1,9 +1,9 @@ use crate::{ model::{ nodes::{bmm::BMMNode, relu::ReLUNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, - qarray::QArray, + qarray::{QArray, QTypeArray}, Model, Poly, - }, utils::pcs_types::Ligero, quantization::{quantise_f32_u8_nne, QSmallType} + }, quantization::{quantise_f32_u8_nne, QSmallType}, utils::{pcs_types::Ligero, test_sponge::test_sponge} }; use ark_crypto_primitives::sponge::{poseidon::PoseidonSponge, Absorb, CryptographicSponge}; @@ -14,6 +14,7 @@ use ark_ff::PrimeField; mod input; mod parameters; +use ark_std::test_rng; use input::*; use parameters::*; @@ -127,4 +128,51 @@ fn run_padded_two_layer_perceptron_mnist() { println!("Output: {:?}", output_u8.values()); assert_eq!(output_u8.move_values(), expected_output); +} + +#[test] +fn prove_inference_two_layer_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; + /**********************/ + + let perceptron = build_two_layer_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let mut rng = test_rng(); + let (ck, _) = perceptron.setup_keys(&mut rng).unwrap(); + + let mut sponge: PoseidonSponge = test_sponge(); + + //let (hidden_nodes, com_states) = perceptron.commit(&ck, None).iter().unzip(); + let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); + + let inference_proof = perceptron.prove_inference( + &ck, + Some(&mut rng), + &mut sponge, + &node_coms, + &node_com_states, + input_i8, + ); + + let output_qtypearray = inference_proof.inputs_outputs[1].clone(); + + let output_i8 = match output_qtypearray { + QTypeArray::S(o) => o, + _ => panic!("Expected QTypeArray::S"), + }; + + let output_u8 = (output_i8.cast::() + 128).cast::(); + + println!("Padded output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIM], expected_output); } \ No newline at end of file diff --git a/src/model/mod.rs b/src/model/mod.rs index d7b7697..169ca1d 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -28,8 +28,8 @@ where S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - // Model output tensors - pub(crate) outputs: Vec, + // Model input and output tensors in plain + pub(crate) inputs_outputs: Vec, // Proofs of evaluation of each of the model's nodes pub(crate) node_proofs: Vec>, @@ -276,7 +276,7 @@ where /* TODO (important) Change output_node to all boundary nodes: first and last */ // TODO prove that inputs match input commitments? InferenceProof { - outputs: vec![input_node.clone(), output_node.clone()], + inputs_outputs: vec![input_node.clone(), output_node.clone()], node_proofs, opening_proofs: vec![input_opening_proof, output_opening_proof], } diff --git a/src/model/nodes/relu.rs b/src/model/nodes/relu.rs index 31af516..1cfc44c 100644 --- a/src/model/nodes/relu.rs +++ b/src/model/nodes/relu.rs @@ -66,7 +66,7 @@ where ck: &PCS::CommitterKey, rng: Option<&mut dyn RngCore>, ) -> (NodeCommitment, NodeCommitmentState) { - todo!() + (NodeCommitment::ReLU(()), NodeCommitmentState::ReLU(())) } // TODO this is the same as evaluate() for now; the two will likely differ @@ -95,7 +95,7 @@ where output_com: &LabeledCommitment, output_com_state: &PCS::CommitmentState, ) -> NodeProof { - todo!() + NodeProof::ReLU(()) } } From b8cf8eeee1769993c9fbb9c6c9e1511a3a92191a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Wed, 21 Feb 2024 12:53:47 +0100 Subject: [PATCH 24/50] added NIO commitments to the inference proof --- src/model/mod.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/model/mod.rs b/src/model/mod.rs index 169ca1d..3741d62 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -3,7 +3,7 @@ use ark_std::{log2, rand::RngCore}; use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; -use ark_poly_commit::{LabeledPolynomial, PolynomialCommitment}; +use ark_poly_commit::{LabeledCommitment, LabeledPolynomial, PolynomialCommitment}; use crate::model::nodes::{NodeOps, NodeOpsSNARK}; use crate::{model::nodes::Node, quantization::QSmallType}; @@ -31,6 +31,9 @@ where // Model input and output tensors in plain pub(crate) inputs_outputs: Vec, + // Commitments to each of the node values + pub(crate) node_commitments: Vec>, + // Proofs of evaluation of each of the model's nodes pub(crate) node_proofs: Vec>, @@ -273,10 +276,10 @@ where ) .unwrap(); - /* TODO (important) Change output_node to all boundary nodes: first and last */ // TODO prove that inputs match input commitments? InferenceProof { inputs_outputs: vec![input_node.clone(), output_node.clone()], + node_commitments: output_coms, node_proofs, opening_proofs: vec![input_opening_proof, output_opening_proof], } From 13988c2ea27d93b26af3845b103ad664eee176d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Wed, 21 Feb 2024 14:30:44 +0100 Subject: [PATCH 25/50] small rename --- src/model/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model/mod.rs b/src/model/mod.rs index 3741d62..ce8cdba 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -32,7 +32,7 @@ where pub(crate) inputs_outputs: Vec, // Commitments to each of the node values - pub(crate) node_commitments: Vec>, + pub(crate) node_value_commitments: Vec>, // Proofs of evaluation of each of the model's nodes pub(crate) node_proofs: Vec>, @@ -279,7 +279,7 @@ where // TODO prove that inputs match input commitments? InferenceProof { inputs_outputs: vec![input_node.clone(), output_node.clone()], - node_commitments: output_coms, + node_value_commitments: output_coms, node_proofs, opening_proofs: vec![input_opening_proof, output_opening_proof], } From af5ee01cd02cce27a80b1ac38ccb60ea3f906125 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Wed, 21 Feb 2024 12:35:34 +0100 Subject: [PATCH 26/50] verfy_bmm_node compiling --- src/model/isolated_verification.rs | 138 +++++++++++++++++++++++++++++ src/model/mod.rs | 1 + src/model/nodes/bmm.rs | 22 ++--- 3 files changed, 150 insertions(+), 11 deletions(-) create mode 100644 src/model/isolated_verification.rs diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs new file mode 100644 index 0000000..1b5588d --- /dev/null +++ b/src/model/isolated_verification.rs @@ -0,0 +1,138 @@ +use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; +use ark_ff::PrimeField; +use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; +use ark_sumcheck::ml_sumcheck::{ + protocol::{verifier::SubClaim, PolynomialInfo}, + MLSumcheck, +}; + +use crate::model::nodes::bmm::{BMMNodeCommitment, BMMNodeProof}; + +use super::{ + nodes::{NodeCommitment, NodeProof}, + Poly, +}; + +fn verify_bmm_node( + vk: &PCS::VerifierKey, + sponge: &mut S, + node_com: &NodeCommitment, + input_com: &LabeledCommitment, + output_com: &LabeledCommitment, + proof: NodeProof, + padded_dims_log: (usize, usize), +) -> bool +where + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +{ + let NodeCommitment::BMM(BMMNodeCommitment { + weight_com, + bias_com, + }) = node_com + else { + panic!("Expected BMMNodeCommitment") + }; + + let BMMNodeProof { + sumcheck_proof, + input_opening_proof, + input_opening_value, + weight_opening_proof, + weight_opening_value, + bias_opening_proof, + bias_opening_value, + output_opening_proof, + output_opening_value, + } = match proof { + NodeProof::BMM(p) => p, + _ => panic!("Expected BMMNodeProof"), + }; + + let r: Vec = sponge.squeeze_field_elements(padded_dims_log.1); + + // The value proved in sumcheck should be the difference between the output + // and the bias + let sumcheck_evaluation = output_opening_value - bias_opening_value; + + // Information about the polynomial f(s) = input_mle(s) * weight_mle(s, r) + // to which sumcheck is applied + let info = PolynomialInfo { + max_multiplicands: 2, + num_variables: padded_dims_log.0, + products: vec![(F::one(), vec![0, 1])], + }; + + let Ok(subclaim) = MLSumcheck::verify(&info, sumcheck_evaluation, &sumcheck_proof, sponge) + else { + return false; + }; + + let SubClaim { + point: oracle_point, + expected_evaluation: oracle_evaluation, + } = subclaim; + + if oracle_evaluation != input_opening_value * weight_opening_value { + return false; + } + + // TODO possibly rng, not None + if !PCS::check( + vk, + [input_com], + &oracle_point, + [input_opening_value], + &input_opening_proof, + sponge, + None, + ) + .unwrap() + { + return false; + } + + // TODO possibly rng, not None + if !PCS::check( + vk, + [weight_com], + &oracle_point + .into_iter() + .chain(r.clone().into_iter()) + .collect(), + [weight_opening_value], + &weight_opening_proof, + sponge, + None, + ) + .unwrap() + { + return false; + } + + if !PCS::check( + vk, + [bias_com], + &r, + [bias_opening_value], + &bias_opening_proof, + sponge, + None, + ) + .unwrap() + { + return false; + } + + PCS::check( + vk, + [output_com], + &r, + [output_opening_value], + &output_opening_proof, + sponge, + None, + ) + .unwrap() +} diff --git a/src/model/mod.rs b/src/model/mod.rs index ce8cdba..5e9b463 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -15,6 +15,7 @@ use self::{ }; mod examples; +mod isolated_verification; mod nodes; mod qarray; mod reshaping; diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 85a8784..e06b49f 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -46,8 +46,8 @@ where S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - weight_com: LabeledCommitment, - bias_com: LabeledCommitment, + pub(crate) weight_com: LabeledCommitment, + pub(crate) bias_com: LabeledCommitment, } impl Commitment for BMMNodeCommitment @@ -81,15 +81,15 @@ pub(crate) struct BMMNodeProof< S: CryptographicSponge, PCS: PolynomialCommitment, S>, > { - sumcheck_proof: Proof, - input_opening_proof: PCS::Proof, - input_opening_value: F, - weight_opening_proof: PCS::Proof, - weight_opening_value: F, - bias_opening_proof: PCS::Proof, - bias_opening_value: F, - output_opening_proof: PCS::Proof, - output_opening_value: F, + pub(crate) sumcheck_proof: Proof, + pub(crate) input_opening_proof: PCS::Proof, + pub(crate) input_opening_value: F, + pub(crate) weight_opening_proof: PCS::Proof, + pub(crate) weight_opening_value: F, + pub(crate) bias_opening_proof: PCS::Proof, + pub(crate) bias_opening_value: F, + pub(crate) output_opening_proof: PCS::Proof, + pub(crate) output_opening_value: F, } impl NodeOps for BMMNode From 4f110667fd552ca1f1fccd1b41562ebacbceffee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Wed, 21 Feb 2024 14:28:35 +0100 Subject: [PATCH 27/50] wip; need to bring changes from parent branch --- src/model/isolated_verification.rs | 110 ++++++++++++++++++++++++++++- 1 file changed, 109 insertions(+), 1 deletion(-) diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index 1b5588d..d35eeb5 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -1,3 +1,5 @@ +use std::vec; + use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; @@ -10,7 +12,7 @@ use crate::model::nodes::bmm::{BMMNodeCommitment, BMMNodeProof}; use super::{ nodes::{NodeCommitment, NodeProof}, - Poly, + InferenceProof, Model, Poly, }; fn verify_bmm_node( @@ -136,3 +138,109 @@ where ) .unwrap() } + +fn verify_model( + vk: &PCS::VerifierKey, + model: &Model, + node_commitments: &Vec>, + inference_proof: InferenceProof, + sponge: &mut S, +) where + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +{ + let input_node = model.nodes.first().unwrap(); + let output_node = model.nodes.last().unwrap(); + + // Absorb all commitments into the sponge + sponge.absorb(inference_proof); + + // TODO Prove that all commited NIOs live in the right range (to be + // discussed) + + let mut node_proofs = Vec::new(); + + // Second pass: proving + for (((((node, node_com), node_com_state), values), l_v_coms), v_coms_states) in self + .nodes + .iter() + .zip(node_coms.iter()) + .zip(node_com_states.iter()) + .zip(labeled_output_mles.windows(2)) + .zip(output_coms.windows(2)) + .zip(output_com_states.windows(2)) + { + node_proofs.push(node.prove( + ck, + sponge, + &node_com, + &node_com_state, + &values[0], + &l_v_coms[0], + &v_coms_states[0], + &values[1], + &l_v_coms[1], + &v_coms_states[1], + )); + } + + // Opening model IO + // TODO maybe this can be made more efficient by not committing to the + // output nodes and instead working witht their plain values all along, + // but that would require messy node-by-node handling + let input_node = node_outputs.first().unwrap(); + let input_node_f = node_output_mles.first().unwrap().to_evaluations(); + let input_labeled_value = labeled_output_mles.first().unwrap(); + let input_node_com = output_coms.first().unwrap(); + let input_node_com_state = output_com_states.first().unwrap(); + + let output_node = node_outputs.last().unwrap(); + let output_node_f = node_output_mles.last().unwrap().to_evaluations(); + let output_labeled_value = labeled_output_mles.last().unwrap(); + let output_node_com = output_coms.last().unwrap(); + let output_node_com_state = output_com_states.last().unwrap(); + + // Absorb the model IO output and squeeze the challenge point + // Absorb the plain output and squeeze the challenge point + sponge.absorb(&input_node_f); + sponge.absorb(&output_node_f); + let input_challenge_point = sponge.squeeze_field_elements(log2(input_node_f.len()) as usize); + let output_challenge_point = sponge.squeeze_field_elements(log2(output_node_f.len()) as usize); + + // TODO we have to pass rng, not None, but it has been moved before + // fix this once we have decided how to handle the cumbersome + // Option<&mut rng...> + let input_opening_proof = PCS::open( + ck, + [input_labeled_value], + [input_node_com], + &input_challenge_point, + sponge, + [input_node_com_state], + None, + ) + .unwrap(); + + // TODO we have to pass rng, not None, but it has been moved before + // fix this once we have decided how to handle the cumbersome + // Option<&mut rng...> + let output_opening_proof = PCS::open( + ck, + [output_labeled_value], + [output_node_com], + &output_challenge_point, + sponge, + [output_node_com_state], + None, + ) + .unwrap(); + + /* TODO (important) Change output_node to all boundary nodes: first and last */ + // TODO prove that inputs match input commitments? + InferenceProof { + inputs_outputs: vec![input_node.clone(), output_node.clone()], + node_proofs, + opening_proofs: vec![input_opening_proof, output_opening_proof], + } +} From f5364e94d2e8181034289f2b22f319231ec06417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Wed, 21 Feb 2024 15:48:25 +0100 Subject: [PATCH 28/50] inference verification compiling --- src/model/isolated_verification.rs | 181 +++++++++++++++++++---------- src/model/mod.rs | 4 + src/model/nodes/bmm.rs | 4 + 3 files changed, 125 insertions(+), 64 deletions(-) diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index d35eeb5..540365c 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -2,7 +2,9 @@ use std::vec; use ark_crypto_primitives::sponge::{Absorb, CryptographicSponge}; use ark_ff::PrimeField; +use ark_poly::{DenseMultilinearExtension, Polynomial}; use ark_poly_commit::{LabeledCommitment, PolynomialCommitment}; +use ark_std::log2; use ark_sumcheck::ml_sumcheck::{ protocol::{verifier::SubClaim, PolynomialInfo}, MLSumcheck, @@ -11,7 +13,8 @@ use ark_sumcheck::ml_sumcheck::{ use crate::model::nodes::bmm::{BMMNodeCommitment, BMMNodeProof}; use super::{ - nodes::{NodeCommitment, NodeProof}, + nodes::{Node, NodeCommitment, NodeProof}, + qarray::{QArray, QTypeArray}, InferenceProof, Model, Poly, }; @@ -139,67 +142,109 @@ where .unwrap() } +fn verify_node( + vk: &PCS::VerifierKey, + sponge: &mut S, + node_com: &NodeCommitment, + input_com: &LabeledCommitment, + output_com: &LabeledCommitment, + proof: NodeProof, + padded_dims_log: Option<(usize, usize)>, +) -> bool +where + F: PrimeField + Absorb, + S: CryptographicSponge, + PCS: PolynomialCommitment, S>, +{ + match node_com { + NodeCommitment::BMM(_) => verify_bmm_node( + vk, + sponge, + node_com, + input_com, + output_com, + proof, + padded_dims_log.unwrap(), + ), + _ => true, + } +} + fn verify_model( vk: &PCS::VerifierKey, + sponge: &mut S, model: &Model, node_commitments: &Vec>, inference_proof: InferenceProof, - sponge: &mut S, -) where +) -> bool +where F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, { - let input_node = model.nodes.first().unwrap(); - let output_node = model.nodes.last().unwrap(); + let InferenceProof { + inputs_outputs, + node_value_commitments, + node_proofs, + opening_proofs, + } = inference_proof; // Absorb all commitments into the sponge - sponge.absorb(inference_proof); + sponge.absorb(&node_value_commitments); - // TODO Prove that all commited NIOs live in the right range (to be + // TODO Verify that all commited NIOs live in the right range (to be // discussed) - let mut node_proofs = Vec::new(); - - // Second pass: proving - for (((((node, node_com), node_com_state), values), l_v_coms), v_coms_states) in self + // Verify node proofs + for (((node, node_com), io_com), node_proof) in model .nodes .iter() - .zip(node_coms.iter()) - .zip(node_com_states.iter()) - .zip(labeled_output_mles.windows(2)) - .zip(output_coms.windows(2)) - .zip(output_com_states.windows(2)) + .zip(node_commitments.iter()) + .zip(node_value_commitments.windows(2)) + .zip(node_proofs.into_iter()) { - node_proofs.push(node.prove( - ck, + // This will not be necessary in the actual code, as the BMM dimensions + // will be contained in the hidden BMMNode and therefore won't be + // passed to the proving method + let get_padded_dims_log = match node { + Node::BMM(bmm) => Some(bmm.get_padded_dims_log()), + _ => None, + }; + + if !verify_node( + vk, sponge, - &node_com, - &node_com_state, - &values[0], - &l_v_coms[0], - &v_coms_states[0], - &values[1], - &l_v_coms[1], - &v_coms_states[1], - )); + node_com, + &io_com[0], + &io_com[1], + node_proof, + get_padded_dims_log, + ) { + return false; + } } - // Opening model IO + // Verifying model IO // TODO maybe this can be made more efficient by not committing to the // output nodes and instead working witht their plain values all along, // but that would require messy node-by-node handling - let input_node = node_outputs.first().unwrap(); - let input_node_f = node_output_mles.first().unwrap().to_evaluations(); - let input_labeled_value = labeled_output_mles.first().unwrap(); - let input_node_com = output_coms.first().unwrap(); - let input_node_com_state = output_com_states.first().unwrap(); - - let output_node = node_outputs.last().unwrap(); - let output_node_f = node_output_mles.last().unwrap().to_evaluations(); - let output_labeled_value = labeled_output_mles.last().unwrap(); - let output_node_com = output_coms.last().unwrap(); - let output_node_com_state = output_com_states.last().unwrap(); + let input_node_com = node_value_commitments.first().unwrap(); + let input_node_qarray = match &inputs_outputs[0] { + QTypeArray::S(i) => i, + _ => panic!("Model input should be QTypeArray::S"), + }; + let input_node_f: Vec = input_node_qarray + .values() + .iter() + .map(|x| F::from(*x)) + .collect(); + + let output_node_com = node_value_commitments.last().unwrap(); + // TODO maybe it's better to save this as F in the proof? + let output_node_f: Vec = match &inputs_outputs[0] { + QTypeArray::S(o) => o.values().iter().map(|x| F::from(*x)).collect(), + _ => panic!("Model output should be QTypeArray::S"), + }; // Absorb the model IO output and squeeze the challenge point // Absorb the plain output and squeeze the challenge point @@ -208,39 +253,47 @@ fn verify_model( let input_challenge_point = sponge.squeeze_field_elements(log2(input_node_f.len()) as usize); let output_challenge_point = sponge.squeeze_field_elements(log2(output_node_f.len()) as usize); - // TODO we have to pass rng, not None, but it has been moved before - // fix this once we have decided how to handle the cumbersome - // Option<&mut rng...> - let input_opening_proof = PCS::open( - ck, - [input_labeled_value], + // Verifying that the actual input was honestly padded with zeros + let padded_input_shape = input_node_qarray.shape().clone(); + let honestly_padded_input = input_node_qarray + .compact_resize(model.input_shape().clone(), 0) + .compact_resize(padded_input_shape, 0); + + if honestly_padded_input.values() != input_node_qarray.values() { + return false; + } + + // The verifier must evaluate the MLE given by the plain input values + let input_node_eval = + Poly::from_evaluations_vec(log2(input_node_f.len()) as usize, input_node_f) + .evaluate(&input_challenge_point); + let output_node_eval = + Poly::from_evaluations_vec(log2(output_node_f.len()) as usize, output_node_f) + .evaluate(&output_challenge_point); + + // TODO rng, None + if !PCS::check( + vk, [input_node_com], &input_challenge_point, + [input_node_eval], + &opening_proofs[0], sponge, - [input_node_com_state], None, ) - .unwrap(); - - // TODO we have to pass rng, not None, but it has been moved before - // fix this once we have decided how to handle the cumbersome - // Option<&mut rng...> - let output_opening_proof = PCS::open( - ck, - [output_labeled_value], + .unwrap() + { + return false; + } + + PCS::check( + vk, [output_node_com], &output_challenge_point, + [output_node_eval], + &opening_proofs[1], sponge, - [output_node_com_state], None, ) - .unwrap(); - - /* TODO (important) Change output_node to all boundary nodes: first and last */ - // TODO prove that inputs match input commitments? - InferenceProof { - inputs_outputs: vec![input_node.clone(), output_node.clone()], - node_proofs, - opening_proofs: vec![input_opening_proof, output_opening_proof], - } + .unwrap() } diff --git a/src/model/mod.rs b/src/model/mod.rs index 5e9b463..4975e09 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -75,6 +75,10 @@ where } } + pub(crate) fn input_shape(&self) -> &Vec { + &self.input_shape + } + pub(crate) fn setup_keys( &self, rng: &mut R, diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index e06b49f..cdb07af 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -465,6 +465,10 @@ where phantom: PhantomData, } } + + pub(crate) fn get_padded_dims_log(&self) -> (usize, usize) { + self.padded_dims_log + } } // TODO in constructor, add quantisation information checks? (s_bias = s_input * s_weight, z_bias = 0, z_weight = 0, etc.) // TODO in constructor, check bias length matches appropriate matrix dimension From d95aa6d9b3123ad311d2a591afc84bc9c249086f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Wed, 21 Feb 2024 15:51:07 +0100 Subject: [PATCH 29/50] rename --- .../examples/simple_perceptron_mnist/mod.rs | 48 +++++++++++++++++++ src/model/isolated_verification.rs | 2 +- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index fa4247f..7e263bc 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -151,6 +151,54 @@ fn prove_inference_simple_perceptron_mnist() { let output_u8 = (output_i8.cast::() + 128).cast::(); + println!("Padded output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS[0]], expected_output); +} + + +#[test] +fn verify_inference_simple_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![135, 109, 152, 161, 187, 157, 159, 151, 173, 202]; + /**********************/ + + let perceptron = build_simple_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let mut rng = test_rng(); + let (ck, _) = perceptron.setup_keys(&mut rng).unwrap(); + + let mut sponge: PoseidonSponge = test_sponge(); + + //let (hidden_nodes, com_states) = perceptron.commit(&ck, None).iter().unzip(); + let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); + + let inference_proof = perceptron.prove_inference( + &ck, + Some(&mut rng), + &mut sponge, + &node_coms, + &node_com_states, + input_i8, + ); + + let output_qtypearray = inference_proof.inputs_outputs[1].clone(); + + let output_i8 = match output_qtypearray { + QTypeArray::S(o) => o, + _ => panic!("Expected QTypeArray::S"), + }; + + let output_u8 = (output_i8.cast::() + 128).cast::(); + println!("Padded output: {:?}", output_u8.values()); assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS[0]], expected_output); } \ No newline at end of file diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index 540365c..1305815 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -170,7 +170,7 @@ where } } -fn verify_model( +pub(crate) fn verify_inference( vk: &PCS::VerifierKey, sponge: &mut S, model: &Model, From 72f3441a995e9921304fac3de6de1be083111cd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Wed, 21 Feb 2024 15:55:00 +0100 Subject: [PATCH 30/50] testing without node proofs --- .../examples/simple_perceptron_mnist/mod.rs | 16 ++++++++++++---- src/model/isolated_verification.rs | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index 7e263bc..893f816 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -1,8 +1,6 @@ use crate::{ model::{ - nodes::{bmm::BMMNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, - qarray::{QArray, QTypeArray}, - Model, Poly, + isolated_verification::verify_inference, nodes::{bmm::BMMNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, qarray::{QArray, QTypeArray}, Model, Poly }, quantization::{quantise_f32_u8_nne, QSmallType}, utils::{pcs_types::Ligero, test_sponge::test_sponge} }; @@ -174,7 +172,7 @@ fn verify_inference_simple_perceptron_mnist() { let input_i8 = (quantised_input.cast::() - 128).cast::(); let mut rng = test_rng(); - let (ck, _) = perceptron.setup_keys(&mut rng).unwrap(); + let (ck, vk) = perceptron.setup_keys(&mut rng).unwrap(); let mut sponge: PoseidonSponge = test_sponge(); @@ -192,6 +190,16 @@ fn verify_inference_simple_perceptron_mnist() { let output_qtypearray = inference_proof.inputs_outputs[1].clone(); + let mut sponge: PoseidonSponge = test_sponge(); + + verify_inference( + &vk, + &mut sponge, + &perceptron, + &node_coms, + inference_proof + ); + let output_i8 = match output_qtypearray { QTypeArray::S(o) => o, _ => panic!("Expected QTypeArray::S"), diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index 1305815..7f16b5a 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -220,7 +220,7 @@ where node_proof, get_padded_dims_log, ) { - return false; + // pass return false; } } From 36ff0eb816ce2196087e116fd91b2dafe1311066 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Fri, 23 Feb 2024 08:57:30 +0100 Subject: [PATCH 31/50] fixed W opening due to duality and shifted input MLE --- src/model/isolated_verification.rs | 26 +++++++++++++++----------- src/model/nodes/bmm.rs | 18 ++++++++++++------ 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index 7f16b5a..f26e193 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -26,6 +26,7 @@ fn verify_bmm_node( output_com: &LabeledCommitment, proof: NodeProof, padded_dims_log: (usize, usize), + zero_point: F, // This argument will not be here in the final code ) -> bool where F: PrimeField + Absorb, @@ -102,10 +103,7 @@ where if !PCS::check( vk, [weight_com], - &oracle_point - .into_iter() - .chain(r.clone().into_iter()) - .collect(), + &r.clone().into_iter().chain(oracle_point).collect(), [weight_opening_value], &weight_opening_proof, sponge, @@ -150,6 +148,7 @@ fn verify_node( output_com: &LabeledCommitment, proof: NodeProof, padded_dims_log: Option<(usize, usize)>, + zero_point: Option, ) -> bool where F: PrimeField + Absorb, @@ -165,6 +164,7 @@ where output_com, proof, padded_dims_log.unwrap(), + zero_point.unwrap(), ), _ => true, } @@ -204,11 +204,14 @@ where .zip(node_proofs.into_iter()) { // This will not be necessary in the actual code, as the BMM dimensions - // will be contained in the hidden BMMNode and therefore won't be - // passed to the proving method - let get_padded_dims_log = match node { - Node::BMM(bmm) => Some(bmm.get_padded_dims_log()), - _ => None, + // and zero point will be contained in the (possibly hidden) BMMNode + // and therefore won't be passed to the proving method + let (padded_dims_log, input_zero_point) = match node { + Node::BMM(bmm) => ( + Some(bmm.padded_dims_log()), + Some(F::from(bmm.input_zero_point())), + ), + _ => (None, None), }; if !verify_node( @@ -218,9 +221,10 @@ where &io_com[0], &io_com[1], node_proof, - get_padded_dims_log, + padded_dims_log, + input_zero_point, ) { - // pass return false; + return false; } } diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index cdb07af..f8c094d 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -289,7 +289,10 @@ where // commitments in Model::prove_inference let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); - let input_mle = input.polynomial().clone(); + let shifted_input_mle = Poly::from_evaluations_vec( + input.num_vars(), + input.polynomial().iter().map(|x| F::from(*x)).collect(), + ); // TODO consider whether this can be done once and stored let weights_f = self.padded_weights.iter().map(|w| F::from(*w)).collect(); @@ -315,7 +318,7 @@ where // TODO we are cloning the input here, can we do better? big_poly.add_product( - vec![input_mle, bound_weight_mle] + vec![shifted_input_mle, bound_weight_mle] .into_iter() .map(Rc::new) .collect::>(), @@ -359,10 +362,9 @@ where None, )], [weight_com], - &prover_state - .randomness + &r.clone() .into_iter() - .chain(r.clone().into_iter()) + .chain(prover_state.randomness) .collect(), sponge, [weight_com_state], @@ -466,9 +468,13 @@ where } } - pub(crate) fn get_padded_dims_log(&self) -> (usize, usize) { + pub(crate) fn padded_dims_log(&self) -> (usize, usize) { self.padded_dims_log } + + pub(crate) fn input_zero_point(&self) -> QSmallType { + self.input_zero_point + } } // TODO in constructor, add quantisation information checks? (s_bias = s_input * s_weight, z_bias = 0, z_weight = 0, etc.) // TODO in constructor, check bias length matches appropriate matrix dimension From 212f0cc15b47cff902d2aa6c38036665ac062999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Fri, 23 Feb 2024 09:20:50 +0100 Subject: [PATCH 32/50] bmm node proof and verification working --- src/model/examples/simple_perceptron_mnist/mod.rs | 5 ++--- src/model/isolated_verification.rs | 8 ++++---- src/model/nodes/bmm.rs | 12 +++++++++--- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index 893f816..6d4222e 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -176,7 +176,6 @@ fn verify_inference_simple_perceptron_mnist() { let mut sponge: PoseidonSponge = test_sponge(); - //let (hidden_nodes, com_states) = perceptron.commit(&ck, None).iter().unzip(); let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); let inference_proof = perceptron.prove_inference( @@ -192,13 +191,13 @@ fn verify_inference_simple_perceptron_mnist() { let mut sponge: PoseidonSponge = test_sponge(); - verify_inference( + assert!(verify_inference( &vk, &mut sponge, &perceptron, &node_coms, inference_proof - ); + )); let output_i8 = match output_qtypearray { QTypeArray::S(o) => o, diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index f26e193..b119e82 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -26,7 +26,7 @@ fn verify_bmm_node( output_com: &LabeledCommitment, proof: NodeProof, padded_dims_log: (usize, usize), - zero_point: F, // This argument will not be here in the final code + input_zero_point: F, // This argument will not be here in the final code ) -> bool where F: PrimeField + Absorb, @@ -80,7 +80,7 @@ where expected_evaluation: oracle_evaluation, } = subclaim; - if oracle_evaluation != input_opening_value * weight_opening_value { + if oracle_evaluation != (input_opening_value - input_zero_point) * weight_opening_value { return false; } @@ -148,7 +148,7 @@ fn verify_node( output_com: &LabeledCommitment, proof: NodeProof, padded_dims_log: Option<(usize, usize)>, - zero_point: Option, + input_zero_point: Option, ) -> bool where F: PrimeField + Absorb, @@ -164,7 +164,7 @@ where output_com, proof, padded_dims_log.unwrap(), - zero_point.unwrap(), + input_zero_point.unwrap(), ), _ => true, } diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index f8c094d..1d9f68b 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -289,9 +289,11 @@ where // commitments in Model::prove_inference let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); + let i_z_p_f = F::from(self.input_zero_point); + let shifted_input_mle = Poly::from_evaluations_vec( input.num_vars(), - input.polynomial().iter().map(|x| F::from(*x)).collect(), + input.polynomial().iter().map(|x| *x - i_z_p_f).collect(), ); // TODO consider whether this can be done once and stored @@ -342,6 +344,10 @@ where .map(|x| x.evaluate(&prover_state.randomness)) .collect(); + // Recall that the first MLE in big_poly was the *shifted* input + let input_opening_value = claimed_evaluations[0] + i_z_p_f; + let weight_opening_value = claimed_evaluations[1]; + let input_opening_proof = PCS::open( &ck, [input], @@ -402,9 +408,9 @@ where NodeProof::BMM(BMMNodeProof { sumcheck_proof, input_opening_proof, - input_opening_value: claimed_evaluations[0], + input_opening_value, weight_opening_proof, - weight_opening_value: claimed_evaluations[1], + weight_opening_value, bias_opening_proof, bias_opening_value, output_opening_proof, From 205a3af3266a1b10508ee6e3914ea3c69a95e1fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Fri, 23 Feb 2024 09:33:32 +0100 Subject: [PATCH 33/50] entire model proof verifying --- src/model/isolated_verification.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index b119e82..93842af 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -245,7 +245,7 @@ where let output_node_com = node_value_commitments.last().unwrap(); // TODO maybe it's better to save this as F in the proof? - let output_node_f: Vec = match &inputs_outputs[0] { + let output_node_f: Vec = match &inputs_outputs[1] { QTypeArray::S(o) => o.values().iter().map(|x| F::from(*x)).collect(), _ => panic!("Model output should be QTypeArray::S"), }; From a9e9d10803e4c5d7422386a6e9945650c52ec816 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Fri, 23 Feb 2024 09:37:57 +0100 Subject: [PATCH 34/50] remove unnecessary nesting --- src/model/examples/simple_perceptron_mnist/mod.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/model/examples/simple_perceptron_mnist/mod.rs b/src/model/examples/simple_perceptron_mnist/mod.rs index 6d4222e..5176c18 100644 --- a/src/model/examples/simple_perceptron_mnist/mod.rs +++ b/src/model/examples/simple_perceptron_mnist/mod.rs @@ -17,7 +17,7 @@ use input::*; use parameters::*; const INPUT_DIMS: &[usize] = &[28, 28]; -const OUTPUT_DIMS: &[usize] = &[10]; +const OUTPUT_DIMS: usize = 10; // TODO this is incorrect now that we have switched to logs fn build_simple_perceptron_mnist() -> Model @@ -34,12 +34,12 @@ where let bmm: BMMNode = BMMNode::new( WEIGHTS.to_vec(), BIAS.to_vec(), - (flat_dim, OUTPUT_DIMS[0]), + (flat_dim, OUTPUT_DIMS), Z_I, ); let req_bmm: RequantiseBMMNode = RequantiseBMMNode::new( - OUTPUT_DIMS[0], + OUTPUT_DIMS, S_I, Z_I, S_W, @@ -150,7 +150,7 @@ fn prove_inference_simple_perceptron_mnist() { let output_u8 = (output_i8.cast::() + 128).cast::(); println!("Padded output: {:?}", output_u8.values()); - assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS[0]], expected_output); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS], expected_output); } @@ -207,5 +207,5 @@ fn verify_inference_simple_perceptron_mnist() { let output_u8 = (output_i8.cast::() + 128).cast::(); println!("Padded output: {:?}", output_u8.values()); - assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS[0]], expected_output); -} \ No newline at end of file + assert_eq!(output_u8.move_values()[0..OUTPUT_DIMS], expected_output); +} From 080857d314b4fa1fe55255b5487cf4308e84713d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Fri, 23 Feb 2024 09:39:47 +0100 Subject: [PATCH 35/50] added verification test to two-layer perceptron, passing --- .../two_layer_perceptron_mnist/mod.rs | 61 +++++++++++++++++-- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/src/model/examples/two_layer_perceptron_mnist/mod.rs b/src/model/examples/two_layer_perceptron_mnist/mod.rs index 9ed86f2..c95326d 100644 --- a/src/model/examples/two_layer_perceptron_mnist/mod.rs +++ b/src/model/examples/two_layer_perceptron_mnist/mod.rs @@ -1,8 +1,6 @@ use crate::{ model::{ - nodes::{bmm::BMMNode, relu::ReLUNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, - qarray::{QArray, QTypeArray}, - Model, Poly, + isolated_verification::verify_inference, nodes::{bmm::BMMNode, relu::ReLUNode, requantise_bmm::RequantiseBMMNode, reshape::ReshapeNode, Node}, qarray::{QArray, QTypeArray}, Model, Poly }, quantization::{quantise_f32_u8_nne, QSmallType}, utils::{pcs_types::Ligero, test_sponge::test_sponge} }; @@ -152,7 +150,6 @@ fn prove_inference_two_layer_perceptron_mnist() { let mut sponge: PoseidonSponge = test_sponge(); - //let (hidden_nodes, com_states) = perceptron.commit(&ck, None).iter().unzip(); let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); let inference_proof = perceptron.prove_inference( @@ -173,6 +170,62 @@ fn prove_inference_two_layer_perceptron_mnist() { let output_u8 = (output_i8.cast::() + 128).cast::(); + println!("Padded output: {:?}", output_u8.values()); + assert_eq!(output_u8.move_values()[0..OUTPUT_DIM], expected_output); +} + +#[test] +fn verify_inference_two_layer_perceptron_mnist() { + /**** Change here ****/ + let input = NORMALISED_INPUT_TEST_150; + let expected_output: Vec = vec![138, 106, 149, 160, 174, 152, 141, 146, 169, 207]; + /**********************/ + + let perceptron = build_two_layer_perceptron_mnist::, Ligero>(); + + let quantised_input: QArray = input + .iter() + .map(|r| quantise_f32_u8_nne(r, S_INPUT, Z_INPUT)) + .collect::>>() + .into(); + + let input_i8 = (quantised_input.cast::() - 128).cast::(); + + let mut rng = test_rng(); + let (ck, vk) = perceptron.setup_keys(&mut rng).unwrap(); + + let mut sponge: PoseidonSponge = test_sponge(); + + let (node_coms, node_com_states): (Vec<_>, Vec<_>) = perceptron.commit(&ck, None).into_iter().unzip(); + + let inference_proof = perceptron.prove_inference( + &ck, + Some(&mut rng), + &mut sponge, + &node_coms, + &node_com_states, + input_i8, + ); + + let output_qtypearray = inference_proof.inputs_outputs[1].clone(); + + let mut sponge: PoseidonSponge = test_sponge(); + + assert!(verify_inference( + &vk, + &mut sponge, + &perceptron, + &node_coms, + inference_proof + )); + + let output_i8 = match output_qtypearray { + QTypeArray::S(o) => o, + _ => panic!("Expected QTypeArray::S"), + }; + + let output_u8 = (output_i8.cast::() + 128).cast::(); + println!("Padded output: {:?}", output_u8.values()); assert_eq!(output_u8.move_values()[0..OUTPUT_DIM], expected_output); } \ No newline at end of file From 433f94dc70686357b3f6b561132afbecf2cd7ae3 Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 12:12:03 +0100 Subject: [PATCH 36/50] Use a repo secret in the CI workflows To access private repos --- .github/workflows/ci.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 387e3b0..15102ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,6 +52,11 @@ jobs: target key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - name: Add SSH key for private repos + uses: webfactory/ssh-agent@v0.9.0 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + - name: Test uses: actions-rs/cargo@v1 with: From d3de9f2c620b1512ac59add2298838391a9c2980 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Fri, 23 Feb 2024 12:42:30 +0100 Subject: [PATCH 37/50] added mention of simple possible optimisation --- src/model/isolated_verification.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index 93842af..87898ac 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -114,6 +114,8 @@ where return false; } + // TODO: b and o are opened at the same point, so they could be verified + // with a single call to PCS::check if !PCS::check( vk, [bias_com], From e7b5f83bd88fec1631f4e7131fd322b68823048d Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 15:30:54 +0100 Subject: [PATCH 38/50] pcs-bench-templates and pcs crates should use the same revision --- Cargo.lock | 24 ++++-------------------- Cargo.toml | 2 +- 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7f572c8..35650ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -139,13 +139,13 @@ dependencies = [ [[package]] name = "ark-pcs-bench-templates" version = "0.4.0" -source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=brakedown-com-absorb#95fc96c5af94ad3cb6b83a5133810a2954336bc7" +source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb#fe882348c748a64f14b99853a8fc3aeb346843f0" dependencies = [ "ark-crypto-primitives", "ark-ec", "ark-ff", "ark-poly", - "ark-poly-commit 0.4.0 (git+https://github.com/HungryCatsStudio/poly-commit?branch=brakedown-com-absorb)", + "ark-poly-commit", "ark-serialize", "ark-std", "criterion", @@ -165,22 +165,6 @@ dependencies = [ "hashbrown", ] -[[package]] -name = "ark-poly-commit" -version = "0.4.0" -source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=brakedown-com-absorb#95fc96c5af94ad3cb6b83a5133810a2954336bc7" -dependencies = [ - "ark-crypto-primitives", - "ark-ec", - "ark-ff", - "ark-poly", - "ark-serialize", - "ark-std", - "derivative", - "digest", - "num-traits", -] - [[package]] name = "ark-poly-commit" version = "0.4.0" @@ -260,7 +244,7 @@ dependencies = [ "ark-crypto-primitives", "ark-ff", "ark-poly", - "ark-poly-commit 0.4.0 (git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb)", + "ark-poly-commit", "ark-serialize", "ark-std", "hashbrown", @@ -829,7 +813,7 @@ dependencies = [ "ark-ff", "ark-pcs-bench-templates", "ark-poly", - "ark-poly-commit 0.4.0 (git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb)", + "ark-poly-commit", "ark-serialize", "ark-std", "ark-sumcheck", diff --git a/Cargo.toml b/Cargo.toml index 5df5965..7f07589 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ ark-sumcheck = { git = "ssh://git@github.com/HungryCatsStudio/sumcheck-private.g ark-bn254 = { version = "^0.4.0", default-features = false, features = [ "curve" ] } blake2 = { version = "0.10", default-features = false } serde_json = "1.0.108" -ark-pcs-bench-templates = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "brakedown-com-absorb" } +ark-pcs-bench-templates = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "ligero-uni-and-ml-absorb" } [patch.crates-io] ark-ff = { git = "https://github.com/arkworks-rs/algebra/" } From fc71e44f44b91c16c4b5a5811e87468e358fad9e Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 15:47:51 +0100 Subject: [PATCH 39/50] add features to the Cargo manifest --- Cargo.lock | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 8 +++++- 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 35650ea..07ad8e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -93,6 +93,7 @@ dependencies = [ "itertools 0.12.1", "num-bigint", "num-traits", + "rayon", "zeroize", ] @@ -112,6 +113,7 @@ dependencies = [ "num-bigint", "num-traits", "paste", + "rayon", "zeroize", ] @@ -163,6 +165,7 @@ dependencies = [ "ark-std", "derivative", "hashbrown", + "rayon", ] [[package]] @@ -174,12 +177,14 @@ dependencies = [ "ark-ec", "ark-ff", "ark-poly", + "ark-relations", "ark-serialize", "ark-std", "derivative", "digest", "merlin", "num-traits", + "rayon", ] [[package]] @@ -191,6 +196,7 @@ dependencies = [ "ark-ff", "ark-std", "tracing", + "tracing-subscriber", ] [[package]] @@ -234,6 +240,7 @@ checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" dependencies = [ "num-traits", "rand", + "rayon", ] [[package]] @@ -248,6 +255,7 @@ dependencies = [ "ark-serialize", "ark-std", "hashbrown", + "rayon", ] [[package]] @@ -393,6 +401,31 @@ dependencies = [ "itertools 0.10.5", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + [[package]] name = "crunchy" version = "0.2.2" @@ -651,6 +684,26 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +[[package]] +name = "rayon" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "regex" version = "1.10.3" @@ -790,6 +843,19 @@ name = "tracing-core" version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "tracing-core", +] [[package]] name = "typenum" @@ -803,6 +869,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "verifiaml" version = "0.1.0" @@ -818,6 +890,7 @@ dependencies = [ "ark-std", "ark-sumcheck", "blake2", + "rayon", "serde_json", ] diff --git a/Cargo.toml b/Cargo.toml index 7f07589..cd78ce4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,12 +12,13 @@ ark-poly = {version = "^0.4.0", default-features = false } ark-poly-commit = {version = "^0.4.0", default-features = false } ark-crypto-primitives = {version = "^0.4.0", default-features = false } ark-sumcheck = { git = "ssh://git@github.com/HungryCatsStudio/sumcheck-private.git", default-features = false } +rayon = { version = "1.5", default-features = false, optional = true } [dev-dependencies] ark-bn254 = { version = "^0.4.0", default-features = false, features = [ "curve" ] } blake2 = { version = "0.10", default-features = false } serde_json = "1.0.108" -ark-pcs-bench-templates = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "ligero-uni-and-ml-absorb" } +ark-pcs-bench-templates = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "ligero-uni-and-ml-absorb", default-features = false } [patch.crates-io] ark-ff = { git = "https://github.com/arkworks-rs/algebra/" } @@ -27,3 +28,8 @@ ark-poly = { git = "https://github.com/arkworks-rs/algebra/" } ark-poly-commit = { git = "https://github.com/HungryCatsStudio/poly-commit", branch = "ligero-uni-and-ml-absorb" } ark-crypto-primitives = { git = "https://github.com/HungryCatsStudio/crypto-primitives", branch = "absorb"} ark-bn254 = { git = "https://github.com/arkworks-rs/algebra/" } + +[features] +default = [ "std", "parallel" ] +std = [ "ark-ff/std", "ark-ec/std", "ark-poly/std", "ark-serialize/std", "ark-crypto-primitives/std", "ark-poly-commit/std", "ark-sumcheck/std" ] +parallel = [ "std", "ark-ff/parallel", "ark-ec/parallel", "ark-poly/parallel", "ark-std/parallel", "ark-poly-commit/parallel", "ark-sumcheck/parallel", "rayon" ] \ No newline at end of file From 7274aa2d0b16a38f70c72a7d444f07fc9826eebf Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 15:58:06 +0100 Subject: [PATCH 40/50] use the updated cache action --- .github/workflows/ci.yml | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 15102ba..1f734c5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,13 +44,8 @@ jobs: toolchain: ${{ matrix.rust }} override: true - - uses: actions/cache@v3 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + - uses: Swatinem/rust-cache@v2 + name: Enable Rust Caching - name: Add SSH key for private repos uses: webfactory/ssh-agent@v0.9.0 From 7bcbcf466565a8ab8305bc1ecd7c4e5b8c9c459d Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 15:58:26 +0100 Subject: [PATCH 41/50] use a plain command for running the test; run without default features --- .github/workflows/ci.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f734c5..323f14e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,6 +53,4 @@ jobs: ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} - name: Test - uses: actions-rs/cargo@v1 - with: - command: test + run: cargo test --no-default-features From 6e0fe7500d761891d7bce8dbf52582378c2f00c4 Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 16:12:31 +0100 Subject: [PATCH 42/50] temp run `cargo test` with default features --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 323f14e..7a9b0dd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,4 +53,4 @@ jobs: ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} - name: Test - run: cargo test --no-default-features + run: cargo test From a0a159280d3d1d5f10f5121f895f5a55049aced4 Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 16:13:59 +0100 Subject: [PATCH 43/50] Add two build jobs: default & `--no-default-features` --- .github/workflows/ci.yml | 58 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7a9b0dd..4c1a8b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,64 @@ jobs: - name: Format run: cargo +nightly fmt --all -- --check + build: + name: Build + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - nightly + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust (${{ matrix.rust }}) + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + + - uses: Swatinem/rust-cache@v2 + name: Enable Rust Caching + + - name: Add SSH key for private repos + uses: webfactory/ssh-agent@v0.9.0 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + + - name: Build + run: cargo build + + build-no-std: + name: Build (no-std) + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - nightly + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust (${{ matrix.rust }}) + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + + - uses: Swatinem/rust-cache@v2 + name: Enable Rust Caching + + - name: Add SSH key for private repos + uses: webfactory/ssh-agent@v0.9.0 + with: + ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} + + - name: Build + run: cargo build --no-default-features + test: name: Test runs-on: ubuntu-latest From 1dabb04b23fd3cf8d0acb8c28d6fdb2d326be5f7 Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 16:14:14 +0100 Subject: [PATCH 44/50] no-std note on ark-pcs-bench-templates --- src/utils/pcs_types.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utils/pcs_types.rs b/src/utils/pcs_types.rs index e6b81cc..82414c9 100644 --- a/src/utils/pcs_types.rs +++ b/src/utils/pcs_types.rs @@ -3,6 +3,10 @@ use ark_crypto_primitives::{ merkle_tree::{ByteDigestConverter, Config}, sponge::poseidon::PoseidonSponge, }; +// no-std note: +// Currently, we use the `LeafIdentityHasher` from ark_pcs_bench_templates. +// This is not ideal, since the entire `ark_pcs_bench_templates` crate does not support `no_std` +// (due to `criterion`) dependency. use ark_pcs_bench_templates::*; use ark_poly::DenseMultilinearExtension; From 69e502d7afe55d34b7e314c75952c2c7ffc807ee Mon Sep 17 00:00:00 2001 From: mmagician Date: Fri, 23 Feb 2024 17:46:46 +0100 Subject: [PATCH 45/50] update deps --- Cargo.lock | 70 +++++++++++++++++++++++++++--------------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 07ad8e0..5fd48f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "ahash" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" +checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f" dependencies = [ "cfg-if", "once_cell", @@ -76,7 +76,7 @@ source = "git+https://github.com/HungryCatsStudio/crypto-primitives?branch=absor dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -123,7 +123,7 @@ version = "0.4.2" source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402ac037de01f6c069" dependencies = [ "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -135,13 +135,13 @@ dependencies = [ "num-traits", "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] name = "ark-pcs-bench-templates" version = "0.4.0" -source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb#fe882348c748a64f14b99853a8fc3aeb346843f0" +source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb#dfdd8e87d3df9059816dd7cec16ade0f4ac0623a" dependencies = [ "ark-crypto-primitives", "ark-ec", @@ -171,7 +171,7 @@ dependencies = [ [[package]] name = "ark-poly-commit" version = "0.4.0" -source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb#4bebbc94a94bbc1a4c48b884c6b022bfbe91d934" +source = "git+https://github.com/HungryCatsStudio/poly-commit?branch=ligero-uni-and-ml-absorb#dfdd8e87d3df9059816dd7cec16ade0f4ac0623a" dependencies = [ "ark-crypto-primitives", "ark-ec", @@ -217,7 +217,7 @@ source = "git+https://github.com/arkworks-rs/algebra/#3a6156785e12eeb9083a7a402a dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -750,29 +750,29 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] name = "serde_json" -version = "1.0.113" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "itoa", "ryu", @@ -809,9 +809,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.49" +version = "2.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915aea9e586f80826ee59f8453c1101f9d1c4b3964cd2460185ee8e299ada496" +checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" dependencies = [ "proc-macro2", "quote", @@ -952,9 +952,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "d380ba1dc7187569a8a9e91ed34b8ccfc33123bbacb8c0aed2d1ad7f3ef2dc5f" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -967,45 +967,45 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "68e5dcfb9413f53afd9c8f86e56a7b4d86d9a2fa26090ea2dc9e40fba56c6ec6" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "8dab469ebbc45798319e69eebf92308e541ce46760b49b18c6b3fe5e8965b30f" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "2a4e9b6a7cac734a8b4138a4e1044eac3404d8326b6c0f939276560687a033fb" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "28b0ec9c422ca95ff34a78755cfa6ad4a51371da2a5ace67500cf7ca5f232c58" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "704131571ba93e89d7cd43482277d6632589b18ecf4468f591fbae0a8b101614" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "42079295511643151e98d61c38c0acc444e52dd42ab456f7ccfd5152e8ecf21c" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" [[package]] name = "zerocopy" @@ -1024,7 +1024,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -1044,5 +1044,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] From 748f10b4fca913924591d2d649bd711119f910e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Mon, 26 Feb 2024 08:37:13 +0100 Subject: [PATCH 46/50] removed reshaping.rs file, overseeded by QArray::compact_rehsape --- src/model/mod.rs | 1 - src/model/reshaping.rs | 40 ---------------------------------------- 2 files changed, 41 deletions(-) delete mode 100644 src/model/reshaping.rs diff --git a/src/model/mod.rs b/src/model/mod.rs index 4975e09..d2dec26 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -18,7 +18,6 @@ mod examples; mod isolated_verification; mod nodes; mod qarray; -mod reshaping; pub(crate) type Poly = DenseMultilinearExtension; pub(crate) type LabeledPoly = LabeledPolynomial>; diff --git a/src/model/reshaping.rs b/src/model/reshaping.rs deleted file mode 100644 index beea6a6..0000000 --- a/src/model/reshaping.rs +++ /dev/null @@ -1,40 +0,0 @@ -use ark_std::vec; - -// Let `array` be an array of length m. Define M = 2^(ceil(max(log2(m), 0))) -// This function pads `array` to length M with the value `pad`. -pub(crate) fn pad_pow2_1d(mut array: Vec, pad: T) -> Vec { - let m = array.len().next_power_of_two(); - array.resize(m, pad); - array -} - -// Let `array` be a non-empty array of subarrays. Let m = array.len() and -// n = array[0].len(). Define M = 2^(ceil(max(log2(m), 0))) and -// N = 2^(ceil(max(log2(n), 0))). -// This function pads (with the value `pad`) or truncates each subarray of -// `array` to length N; and also pads `array` itself to length M with -// new subarrays of length N filled with the value `pad`. -// -// Panics if `array` is empty -pub(crate) fn pad_pow2_2d(array: Vec>, pad: T) -> Vec> { - assert!(array.is_empty()); - - let m_0 = array.len(); - let m = m_0.next_power_of_two(); - - let n = array[0].len().next_power_of_two(); - - let mut padded_array = Vec::with_capacity(m); - - for subarray in array { - let mut s = subarray.clone(); - s.resize(n, pad); - padded_array.push(s); - } - - for _ in 0..(m - m_0) { - padded_array.push(vec![pad; n]); - } - - padded_array -} From ffcf876097328ef8cb3377adf93fe4bf9156a5c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Mon, 26 Feb 2024 09:01:52 +0100 Subject: [PATCH 47/50] tweaked comments and documentaiton of BMM::prove --- src/model/nodes/bmm.rs | 56 ++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 1d9f68b..0484830 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -40,6 +40,8 @@ pub(crate) struct BMMNode { phantom: PhantomData<(F, S, PCS)>, } +/// Commitment to a BMM node, consisting of a commitment to the *dual* of the +/// weight MLE and one to the *dual* of the bias MLE pub(crate) struct BMMNodeCommitment where F: PrimeField, @@ -58,6 +60,8 @@ where { } +/// Commitment states associated to a BMMNodeCommitment: one for the weight and +/// one for the bias pub(crate) struct BMMNodeCommitmentState where F: PrimeField, @@ -76,18 +80,35 @@ where { } +/// Proof of execution of a BMM node, consisting of a sumcheck proof and four +/// PCS opening proofs pub(crate) struct BMMNodeProof< F: PrimeField + Absorb, S: CryptographicSponge, PCS: PolynomialCommitment, S>, > { + /// Sumcheck protocol proof for the polynomial + /// g(x) = (input - zero_point)^(x) * W^(r, x), + /// where v^ denotes the dual of the MLE of v and r is a challenge point pub(crate) sumcheck_proof: Proof, + + /// Value of the *dual* of the input MLE at the challenge point and proof of + /// opening pub(crate) input_opening_proof: PCS::Proof, pub(crate) input_opening_value: F, + + /// Value of the *dual* of the weight MLE at the challenge point and proof of + /// opening pub(crate) weight_opening_proof: PCS::Proof, pub(crate) weight_opening_value: F, + + /// Value of the *dual* of the bias MLE at the challenge point and proof of + /// opening pub(crate) bias_opening_proof: PCS::Proof, pub(crate) bias_opening_value: F, + + /// Value of the *dual* of the output MLE at the challenge point and proof of + /// opening pub(crate) output_opening_proof: PCS::Proof, pub(crate) output_opening_value: F, } @@ -285,12 +306,13 @@ where ), }; - // we can squeeze directly, since the sponge has already absorbed all the + // We can squeeze directly, since the sponge has already absorbed all the // commitments in Model::prove_inference let r: Vec = sponge.squeeze_field_elements(self.padded_dims_log.1); let i_z_p_f = F::from(self.input_zero_point); + /// (f - zero-point)^ let shifted_input_mle = Poly::from_evaluations_vec( input.num_vars(), input.polynomial().iter().map(|x| *x - i_z_p_f).collect(), @@ -298,12 +320,13 @@ where // TODO consider whether this can be done once and stored let weights_f = self.padded_weights.iter().map(|w| F::from(*w)).collect(); - // TODO this might need LE -> BE conversion + + // Dual of the MLE of the row-major flattening of the weight matrix let weight_mle = Poly::from_evaluations_vec(self.com_num_vars(), weights_f); // TODO consider whether this can be done once and stored let bias_f = self.padded_bias.iter().map(|w| F::from(*w)).collect(); - // TODO this might need LE -> BE conversion + // Dual of the MLE of the bias vector let bias_mle = Poly::from_evaluations_vec(self.padded_dims_log.1, bias_f); // TODO is output_opening_value directly available from the output of sumcheck? @@ -311,15 +334,13 @@ where let bias_opening_value = bias_mle.evaluate(&r); let output_opening_value = output.evaluate(&r); - // TODO we actually need fix_variables_last - let bound_weight_mle = weight_mle.fix_variables(&r); - // Constructing the sumcheck polynomial - // big_poly(x) := input(x) * weights(x, r) - let mut big_poly = ListOfProductsOfPolynomials::new(self.padded_dims_log.0); + // g(x) = (input - zero_point)^(x) * W^(r, x), + let bound_weight_mle = weight_mle.fix_variables(&r); + let mut g = ListOfProductsOfPolynomials::new(self.padded_dims_log.0); // TODO we are cloning the input here, can we do better? - big_poly.add_product( + g.add_product( vec![shifted_input_mle, bound_weight_mle] .into_iter() .map(Rc::new) @@ -328,23 +349,24 @@ where ); let (sumcheck_proof, prover_state) = - MLSumcheck::::prove_as_subprotocol(&big_poly, sponge).unwrap(); + MLSumcheck::::prove_as_subprotocol(&g, sponge).unwrap(); // The prover computes the claimed evaluations of weight_mle and // input_mle at the random challenge point - // s:= `prover_state.randomness`, the list of random values sampled by - // the verifier duriing sumcheck. Note that this is different from `r` + // s := prover_state.randomness, the list of random values sampled by + // the verifier during sumcheck. Note that this is different from r // above. // - // We need to open input_mle(s) * weight_mle(s, r) as well as - // output_mle(r) - let claimed_evaluations: Vec = big_poly + // We need to reveal g(s) by opening input^ at s and weight^ at s || r; + // and also open output^ and bias^ at r + let claimed_evaluations: Vec = g .flattened_ml_extensions .iter() .map(|x| x.evaluate(&prover_state.randomness)) .collect(); - // Recall that the first MLE in big_poly was the *shifted* input + // Recall that the first factor of g was the *shifted* dual input + // (input - zero_point)^ let input_opening_value = claimed_evaluations[0] + i_z_p_f; let weight_opening_value = claimed_evaluations[1]; @@ -378,6 +400,8 @@ where ) .unwrap(); + // TODO: b and o are opened at the same point, so they could be opened + // with a single call to PCS::open let bias_opening_proof = PCS::open( &ck, [&LabeledPolynomial::new( From a9884f87aa50841fbf88bec8933b0faba0ee9e25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Mon, 26 Feb 2024 10:16:01 +0100 Subject: [PATCH 48/50] removed unnecessary fix_variables function --- src/utils/mod.rs | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 69a4125..15a05fb 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -6,27 +6,3 @@ pub(crate) mod test_sponge; use ark_ff::Field; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; - -pub(crate) fn fix_variables( - poly: &DenseMultilinearExtension, - partial_point: &[F], -) -> DenseMultilinearExtension { - assert!( - partial_point.len() <= poly.num_vars, - "invalid size of partial point" - ); - let nv = poly.num_vars; - - let mut poly = poly.evaluations.to_vec(); - let dim = partial_point.len(); - // evaluate single variable of partial point from right to left - for i in 1..dim + 1 { - let r = partial_point[i - 1]; - for b in 0..(1 << (nv - i)) { - let left = poly[b << 1]; - let right = poly[(b << 1) + 1]; - poly[b] = left + r * (right - left); - } - } - DenseMultilinearExtension::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))]) -} From 76e408f64b73bbec941bb7e6fe67d6257aea8080 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Mon, 26 Feb 2024 10:45:00 +0100 Subject: [PATCH 49/50] polished model- and bmm-verification comments --- src/model/isolated_verification.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index 87898ac..96cf0d4 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -56,20 +56,23 @@ where _ => panic!("Expected BMMNodeProof"), }; + // Squeezing random challenge r to bind the first variables of W^ to let r: Vec = sponge.squeeze_field_elements(padded_dims_log.1); - // The value proved in sumcheck should be the difference between the output - // and the bias + // The hypercube sum proved in sumcheck should be the difference between + // the output and the bias let sumcheck_evaluation = output_opening_value - bias_opening_value; - // Information about the polynomial f(s) = input_mle(s) * weight_mle(s, r) - // to which sumcheck is applied + // Public information about the sumchecked polynomial + // g(x) = (input - zero_point)^(x) * W^(r, x), let info = PolynomialInfo { max_multiplicands: 2, num_variables: padded_dims_log.0, products: vec![(F::one(), vec![0, 1])], }; + // Verify the sumcheck proof for g and obtaining the oracle-call point s + // and claimed evaluation g(s) let Ok(subclaim) = MLSumcheck::verify(&info, sumcheck_evaluation, &sumcheck_proof, sponge) else { return false; @@ -80,10 +83,14 @@ where expected_evaluation: oracle_evaluation, } = subclaim; + // Verify g(s) agrees with the claims for (input - zero_point)^(s) and + // W^(r, s) if oracle_evaluation != (input_opening_value - input_zero_point) * weight_opening_value { return false; } + // Verify that the opening of input^ at s agrees with the claimed value for + // (input - zero_point)^(s) // TODO possibly rng, not None if !PCS::check( vk, @@ -99,6 +106,8 @@ where return false; } + // Verify the openings of W^ at r || s and b and o at r match the claimed + // values // TODO possibly rng, not None if !PCS::check( vk, @@ -277,6 +286,8 @@ where Poly::from_evaluations_vec(log2(output_node_f.len()) as usize, output_node_f) .evaluate(&output_challenge_point); + // The computed values should match the openings of the corresponding + // vectors // TODO rng, None if !PCS::check( vk, From efb0960eba5f70e915979042db3a10b2425539b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Mej=C3=ADas=20Gil?= Date: Mon, 26 Feb 2024 12:21:36 +0100 Subject: [PATCH 50/50] combined opening/verification of output and bias, since point is the same --- src/model/isolated_verification.rs | 27 +++------------- src/model/nodes/bmm.rs | 51 ++++++++++-------------------- 2 files changed, 22 insertions(+), 56 deletions(-) diff --git a/src/model/isolated_verification.rs b/src/model/isolated_verification.rs index 96cf0d4..1d2f642 100644 --- a/src/model/isolated_verification.rs +++ b/src/model/isolated_verification.rs @@ -47,10 +47,9 @@ where input_opening_value, weight_opening_proof, weight_opening_value, - bias_opening_proof, - bias_opening_value, - output_opening_proof, + output_bias_opening_proof, output_opening_value, + bias_opening_value, } = match proof { NodeProof::BMM(p) => p, _ => panic!("Expected BMMNodeProof"), @@ -123,28 +122,12 @@ where return false; } - // TODO: b and o are opened at the same point, so they could be verified - // with a single call to PCS::check - if !PCS::check( - vk, - [bias_com], - &r, - [bias_opening_value], - &bias_opening_proof, - sponge, - None, - ) - .unwrap() - { - return false; - } - PCS::check( vk, - [output_com], + [output_com, bias_com], &r, - [output_opening_value], - &output_opening_proof, + [output_opening_value, bias_opening_value], + &output_bias_opening_proof, sponge, None, ) diff --git a/src/model/nodes/bmm.rs b/src/model/nodes/bmm.rs index 0484830..cd7bcc2 100644 --- a/src/model/nodes/bmm.rs +++ b/src/model/nodes/bmm.rs @@ -92,25 +92,24 @@ pub(crate) struct BMMNodeProof< /// where v^ denotes the dual of the MLE of v and r is a challenge point pub(crate) sumcheck_proof: Proof, - /// Value of the *dual* of the input MLE at the challenge point and proof of - /// opening + /// Value of the *dual* of the input MLE at the challenge point s and proof + /// of opening pub(crate) input_opening_proof: PCS::Proof, pub(crate) input_opening_value: F, - /// Value of the *dual* of the weight MLE at the challenge point and proof of + /// Value of the *dual* of the weight MLE at the challenge point r || s and proof of /// opening pub(crate) weight_opening_proof: PCS::Proof, pub(crate) weight_opening_value: F, - /// Value of the *dual* of the bias MLE at the challenge point and proof of - /// opening - pub(crate) bias_opening_proof: PCS::Proof, - pub(crate) bias_opening_value: F, + /// Proof of opening of the *duals* of the output and bias MLEs at the + // challenge point + pub(crate) output_bias_opening_proof: PCS::Proof, - /// Value of the *dual* of the output MLE at the challenge point and proof of + /// Value of the *dual* of the weight MLE at the challenge point and proof of /// opening - pub(crate) output_opening_proof: PCS::Proof, pub(crate) output_opening_value: F, + pub(crate) bias_opening_value: F, } impl NodeOps for BMMNode @@ -329,8 +328,6 @@ where // Dual of the MLE of the bias vector let bias_mle = Poly::from_evaluations_vec(self.padded_dims_log.1, bias_f); - // TODO is output_opening_value directly available from the output of sumcheck? - // It doesn't need to be used until the end of the method let bias_opening_value = bias_mle.evaluate(&r); let output_opening_value = output.evaluate(&r); @@ -402,29 +399,16 @@ where // TODO: b and o are opened at the same point, so they could be opened // with a single call to PCS::open - let bias_opening_proof = PCS::open( + let output_bias_opening_proof = PCS::open( &ck, - [&LabeledPolynomial::new( - "bias_mle".to_string(), - bias_mle, - Some(1), - None, - )], - [bias_com], + [ + output, + &LabeledPolynomial::new("bias_mle".to_string(), bias_mle, Some(1), None), + ], + [output_com, bias_com], &r, sponge, - [bias_com_state], - None, - ) - .unwrap(); - - let output_opening_proof = PCS::open( - &ck, - [output], - [output_com], - &r, - sponge, - [output_com_state], + [output_com_state, bias_com_state], None, ) .unwrap(); @@ -435,10 +419,9 @@ where input_opening_value, weight_opening_proof, weight_opening_value, - bias_opening_proof, - bias_opening_value, - output_opening_proof, + output_bias_opening_proof, output_opening_value, + bias_opening_value, }) } }