diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c5f68e0f54..60b480146c 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -47,6 +47,16 @@ jobs: - uses: actions-rust-lang/setup-rust-toolchain@v1 with: components: clippy + - name: Install LLVM + run: | + wget -qO /tmp/llvm.sh https://apt.llvm.org/llvm.sh + chmod +x /tmp/llvm.sh + sudo /tmp/llvm.sh 22 all + sudo apt-get install -y libmlir-22-dev mlir-22-tools + - name: Configure LLVM + run: | + echo "MLIR_SYS_220_PREFIX=/usr/lib/llvm-22" >> "$GITHUB_ENV" + echo "/usr/lib/llvm-22/bin" >> "$GITHUB_PATH" - name: cargo clippy --features allocative,host run: cargo clippy --all --all-targets --features allocative,host - name: cargo clippy --features allocative,host,zk @@ -222,6 +232,18 @@ jobs: steps: - uses: actions/checkout@v6 - uses: actions-rust-lang/setup-rust-toolchain@v1 + - name: Install LLVM + run: | + wget -qO /tmp/llvm.sh https://apt.llvm.org/llvm.sh + chmod +x /tmp/llvm.sh + sudo /tmp/llvm.sh 22 all + sudo apt-get install -y libmlir-22-dev mlir-22-tools + - name: Configure LLVM + run: | + echo "MLIR_SYS_220_PREFIX=/usr/lib/llvm-22" >> "$GITHUB_ENV" + echo "/usr/lib/llvm-22/bin" >> "$GITHUB_PATH" + - name: Build and install jolt CLI + run: cargo install --path . --locked --force - name: Install nextest uses: taiki-e/install-action@nextest - name: Discover modular crates diff --git a/.gitignore b/.gitignore index 10c39b7ff4..dda984bddf 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ *.txt .DS_Store +.bolt-dev-env pprof/ # pprof files diff --git a/CLAUDE.md b/CLAUDE.md index 7cf20e6ce7..33a72f8ef7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -42,6 +42,20 @@ cargo build -p jolt-core -q cargo install --path . --locked ``` +### Local Bolt/MLIR Environment + +Before running Bolt MLIR/codegen checks locally, generate and source the Bolt +dev environment: + +```bash +scripts/setup-bolt-dev.sh +source .bolt-dev-env +``` + +Agents should source `.bolt-dev-env` before any Bolt, generated-role, or +Jolt-on-Bolt equivalence command so `llvm-config`, `MLIR_SYS_220_PREFIX`, and +the local `jolt` CLI all resolve consistently. + ### Profiling ```bash diff --git a/Cargo.lock b/Cargo.lock index 91d1f3d200..d15f61c316 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -954,6 +954,26 @@ dependencies = [ "virtue", ] +[[package]] +name = "bindgen" +version = "0.72.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" +dependencies = [ + "bitflags 2.11.0", + "cexpr", + "clang-sys", + "itertools 0.13.0", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.117", +] + [[package]] name = "bit-set" version = "0.8.0" @@ -1075,6 +1095,13 @@ dependencies = [ "zeroize", ] +[[package]] +name = "bolt" +version = "0.1.0" +dependencies = [ + "melior", +] + [[package]] name = "borsh" version = "1.6.0" @@ -1188,6 +1215,15 @@ dependencies = [ "serde", ] +[[package]] +name = "caseless" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6fd507454086c8edfd769ca6ada439193cdb209c7681712ef6275cccbfe5d8" +dependencies = [ + "unicode-normalization", +] + [[package]] name = "cast" version = "0.3.0" @@ -1206,6 +1242,15 @@ dependencies = [ "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -1259,6 +1304,17 @@ dependencies = [ "half", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.6.1" @@ -1347,6 +1403,23 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "comrak" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f07383e7799d964bf7ffa6fc4457d177c54a44614661c7458bb0bd91b108e32" +dependencies = [ + "caseless", + "entities", + "finl_unicode", + "jetscii", + "phf", + "phf_codegen", + "rustc-hash", + "smallvec", + "typed-arena", +] + [[package]] name = "const-hex" version = "1.18.1" @@ -1406,6 +1479,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "convert_case" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "affbf0190ed2caf063e3def54ff444b449371d55c58e513a95ab98eca50adb49" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1746,7 +1828,7 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" dependencies = [ - "convert_case", + "convert_case 0.10.0", "proc-macro2", "quote", "rustc_version 0.4.1", @@ -1928,6 +2010,12 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" +[[package]] +name = "entities" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" + [[package]] name = "enum-ordinalize" version = "4.3.2" @@ -2117,6 +2205,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "finl_unicode" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9844ddc3a6e533d62bba727eb6c28b5d360921d5175e9ff0f1e621a5c590a4d5" + [[package]] name = "fixed-hash" version = "0.8.0" @@ -2644,6 +2738,12 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +[[package]] +name = "jetscii" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47f142fe24a9c9944451e8349de0a56af5f3e7226dc46f3ed4d4ecc0b85af75e" + [[package]] name = "jiff" version = "0.2.23" @@ -3226,6 +3326,16 @@ version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.16" @@ -3325,6 +3435,32 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "melior" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c388c08773539126c32f2a9c65260695ed9e2b9376164ce8b839826e3085d8d" +dependencies = [ + "melior-macro", + "mlir-sys", +] + +[[package]] +name = "melior-macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aac876dfce9f514df46c6d4b7267f1fc513126ddc34dfde11909584fdfc87bd" +dependencies = [ + "comrak", + "convert_case 0.11.0", + "proc-macro2", + "quote", + "regex", + "syn 2.0.117", + "tblgen", + "unindent", +] + [[package]] name = "memchr" version = "2.8.0" @@ -3413,6 +3549,12 @@ name = "mini-template" version = "0.1.0" source = "git+https://github.com/LayerZero-Labs/ZeroOS.git?rev=3b132ce862ba6769a4261d151f69ff32c5f0dc30#3b132ce862ba6769a4261d151f69ff32c5f0dc30" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -3423,6 +3565,15 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "mlir-sys" +version = "220.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55f8c5781dc23ebe7456f9e9aaecab9eee9e542f7e9051b95b0a31b1d5d4b61f" +dependencies = [ + "bindgen", +] + [[package]] name = "modinv" version = "0.1.0" @@ -3512,6 +3663,16 @@ dependencies = [ "libc", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "ntapi" version = "0.4.3" @@ -3920,6 +4081,16 @@ dependencies = [ "serde", ] +[[package]] +name = "phf_codegen" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49aa7f9d80421bca176ca8dbfebe668cc7a2684708594ec9f3c0db0805d5d6e1" +dependencies = [ + "phf_generator", + "phf_shared", +] + [[package]] name = "phf_generator" version = "0.13.1" @@ -5592,6 +5763,18 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tblgen" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15a6768dc51535aeaecbe812e2a792a77cbef10b1c4529e29bea5864a738dc6a" +dependencies = [ + "bindgen", + "cc", + "paste", + "thiserror 2.0.18", +] + [[package]] name = "tempfile" version = "3.27.0" @@ -5719,6 +5902,21 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tinyvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "toml_datetime" version = "1.1.1+spec-1.1.0" @@ -6034,6 +6232,15 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-normalization" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -6046,6 +6253,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + [[package]] name = "unsafe-libyaml" version = "0.2.11" diff --git a/Cargo.toml b/Cargo.toml index dfb4789692..81b7728ae6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ members = [ "crates/jolt-riscv", "crates/jolt-transcript", "crates/jolt-witness", + "crates/bolt", "crates/jolt-profiling", "crates/jolt-field", "jolt-core", @@ -385,6 +386,7 @@ jolt-transcript = { path = "./crates/jolt-transcript" } jolt-sumcheck = { path = "./crates/jolt-sumcheck" } jolt-r1cs = { path = "./crates/jolt-r1cs" } jolt-witness = { path = "./crates/jolt-witness" } +bolt = { path = "./crates/bolt" } jolt-riscv = { path = "./crates/jolt-riscv", default-features = false } jolt-program = { path = "./crates/jolt-program", default-features = false } jolt-lookup-tables = { path = "./crates/jolt-lookup-tables" } diff --git a/crates/bolt/Cargo.toml b/crates/bolt/Cargo.toml new file mode 100644 index 0000000000..31ca735f59 --- /dev/null +++ b/crates/bolt/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "bolt" +version = "0.1.0" +authors = ["Jolt Contributors"] +edition = "2021" +license = "MIT OR Apache-2.0" +description = "Bolt-shaped compiler prototype for the Jolt zkVM" +repository = "https://github.com/a16z/jolt" +keywords = ["SNARK", "compiler", "mlir", "protocol"] +categories = ["cryptography"] + +[lints] +workspace = true + +[dependencies] +melior = "0.27" diff --git a/crates/bolt/GENERIC_PROTOCOL_GOAL.md b/crates/bolt/GENERIC_PROTOCOL_GOAL.md new file mode 100644 index 0000000000..607035f3b3 --- /dev/null +++ b/crates/bolt/GENERIC_PROTOCOL_GOAL.md @@ -0,0 +1,228 @@ +# Bolt Generic Protocol Goal + +Bolt should be a compiler framework for SNARK/PIOP-style protocols, not a +Jolt-shaped compiler with configurable names. Jolt is the first complete +protocol package and the correctness/performance oracle, but generic Bolt +layers must remain reusable for other protocols. + +## Objective + +Refactor the compiler boundaries so generic IRs, passes, validation, and Rust +artifact assembly operate over protocol concepts: + +```text +roles +stages +transcript events +oracles and commitments +claims and relations +sumcheck obligations +opening obligations +proof slots +role-specific execution plans +``` + +Jolt-specific facts should enter only through the Jolt protocol package: + +```text +protocol params +stage ordering +oracle names +relation definitions +proof-slot names +transcript labels +artifact crate names +prover kernel ABI mappings +Jolt-specific evaluation-proof composition +``` + +The result should make adding another protocol a matter of adding a new +`src/protocols//` package plus artifact config and any required prover +kernels, not editing Bolt's generic compiler core. + +## Non-Negotiables + +- Generic Bolt modules must not branch on Jolt stage names, Jolt relation + symbols, or Jolt artifact names. +- Jolt symbols may appear as ordinary MLIR symbol data carried by Jolt-built + modules, but generic passes may only preserve, validate structurally, or emit + that data. +- Generic lowering remains: + +```text +protocol -> concrete -> party -> compute -> cpu -> Rust +``` + +- Rust emission is the final target. Protocol behavior should be represented + in dialect ops or MLIR-derived typed plans before Rust is emitted. +- Verifier emission must remain kernel-free and protocol-auditable. +- Prover kernels are protocol-package implementation details below the + dialect boundary. +- Checked-in generated role crates remain generated artifacts, not + hand-maintained source. + +## Target Source Layout + +Generic compiler modules: + +```text +crates/bolt/src/dialects.rs +crates/bolt/src/ir.rs +crates/bolt/src/mlir.rs +crates/bolt/src/schema/ +crates/bolt/src/pass/ +crates/bolt/src/emit/rust/ +``` + +Jolt protocol package: + +```text +crates/bolt/src/protocols/jolt/ + params.rs + validate.rs + oracles.rs + phases/ + relations/ + emit/ + rust/ + artifacts.rs +``` + +The exact file split can evolve, but ownership should not: generic modules own +compiler mechanics; `protocols/jolt` owns Jolt instantiation facts. + +## Generic IR Criteria + +Generic IR should expose enough structure for passes and emitters to reason +without protocol-specific string matching: + +- `protocol` declares roles, stage boundaries, protocol params, and proof + boundaries. +- `transcript` declares absorb/squeeze events and explicit state threading. +- `piop` declares oracles, claims, sumchecks, relation obligations, opening + claims, opening equalities, and proof slots. +- `pcs` declares commitment, opening, verification, and evaluation-aggregation + obligations over abstract PCS schemes. +- `party` represents role projection without deleting semantic obligations + needed by later validation. +- `compute` represents executable obligations and optional prover kernel hooks. +- `cpu` represents backend-ready execution plans while staying + protocol-agnostic. + +If a generic emitter needs to branch on a string like +`jolt.stage6.booleanity`, the IR has not been lowered into a sufficiently +typed plan. + +## Generic Pass Criteria + +Generic passes may branch on: + +```text +dialect op name +role +phase +declared proof-slot kind +declared relation-plan kind +declared PCS/transcript operation kind +backend target +``` + +Generic passes must not branch on: + +```text +Jolt stage names +Jolt relation symbols +Jolt oracle names +Jolt artifact crate names +Jolt kernel ABI strings +``` + +Jolt-specific lowering is allowed inside `protocols/jolt`, but it should +produce generic dialect ops and typed plans consumed by the shared compiler. + +## Generic Artifact Criteria + +The generic Rust artifact assembler should be driven by `ProtocolArtifactConfig` +and ordered `ProtocolRustArtifact` values: + +- Protocol name, type prefix, transcript label, role crate names, dependencies, + forbidden imports, and type paths are config data. +- Stage modules are emitted from `ProtocolStage` data, not hardcoded enums in + generic artifact code. +- Top-level `prover.rs` and `verifier.rs` are generated from role/stage/proof + plans. +- Protocol-specific proof extensions are represented by explicit extension + config or generic PCS/evaluation IR, not by checks like + `type_prefix == "Jolt"`. +- `jolt-prover` may import verifier-owned proof types; `jolt-verifier` must + never import prover or kernel code. + +## Jolt Quarantine Criteria + +These are the only acceptable homes for Jolt-specific compiler knowledge: + +```text +crates/bolt/src/protocols/jolt/** +crates/bolt/tests/** when the test explicitly targets Jolt +crates/jolt-prover/** +crates/jolt-verifier/** +crates/jolt-kernels/** +crates/jolt-equivalence/** +``` + +Generic Bolt modules should have a hygiene gate rejecting `jolt`, `Jolt`, +`stage6`, `stage7`, `stage8`, and Jolt relation/policy names, with a small +temporary allowlist during migration. + +## Correctness Criteria + +Every genericity cleanup slice must preserve the existing semantic oracles: + +```text +generated role crates still compile +checked-in generated role crates match canonical generation +Bolt prover artifacts are accepted by the generated Bolt verifier +Bolt prover artifacts are accepted by the core oracle for implemented stages +Bolt/core transcript histories match for implemented stages +internal prover/verifier transcript histories match for implemented stages +tampering gates still reject malformed artifacts +generated verifier import boundaries remain intact +verifier CPU IR remains kernel-free +``` + +For pure file moves and namespace refactors, generated output should either be +byte-for-byte unchanged or intentionally regenerated with a clear explanation +of why the generated surface changed. + +## Migration Algorithm + +For each cleanup slice: + +1. Identify the Jolt-specific fact currently living in a generic module. +2. Decide whether it is protocol data, relation semantics, artifact config, or + a prover-kernel implementation detail. +3. Move that fact to `protocols/jolt` or encode it as generic IR/typed plan + data. +4. Keep generic APIs protocol-named (`Protocol*`) and provide Jolt convenience + wrappers only under `protocols::jolt`. +5. Add or tighten a hygiene gate so the leak does not reappear. +6. Regenerate artifacts only through the canonical generator. +7. Run the relevant schema, generation, import, equivalence, and tamper gates. + +Do not hide protocol semantics in opaque Rust helpers to pass the hygiene gate. +If the generic emitter needs new information, add a typed plan field or dialect +operation and validate it. + +## Definition Of Done + +- `crates/bolt/src/lib.rs` exports generic compiler APIs at the root and keeps + Jolt APIs namespaced under `bolt::protocols::jolt`. +- `crates/bolt/src/emit/rust` contains generic Rust backend mechanics only. +- `JoltProtocolStage`, Jolt artifact config, Jolt stage emitters, Jolt relation + mappings, and Jolt eval-proof composition live under `protocols/jolt`. +- Generic artifact assembly can produce role crates for a non-Jolt protocol + fixture using only `ProtocolArtifactConfig` and `ProtocolStage` data. +- Generic passes and validators have automated hygiene tests preventing Jolt + leakage. +- Existing Jolt correctness, transcript, tamper, import, generated-artifact, + and performance gates remain available and green for implemented stages. diff --git a/crates/bolt/GOAL.md b/crates/bolt/GOAL.md new file mode 100644 index 0000000000..4f0aa601ff --- /dev/null +++ b/crates/bolt/GOAL.md @@ -0,0 +1,395 @@ +# Bolt Jolt Verifier Goal + +Bolt's first full-field, non-zk Jolt implementation is semantically complete +enough to move the active goal from stage bring-up to verifier-pipeline +hardening. The next long-haul objective is to make the Bolt-generated Jolt +verifier compact, human-readable, auditable, and security-hardened while +preserving the existing full-`Fr` Jolt semantics. + +## Objective + +Refactor the Bolt-generated Jolt verifier pipeline so the generated verifier is +a small orchestration layer plus declarative verifier plans, backed by reusable +verifier runtime modules. The compiler should continue to own protocol facts +through MLIR and typed plan data; generated Rust should not rediscover Jolt +semantics late through ad hoc string matching or repeated stage-local helper +code. + +Starting baseline: + +```text +generated jolt-verifier: ~21.5k LOC +stage6 + stage7: ~13.2k LOC +verifier.rs: 649 LOC +``` + +Current locked cleanup baseline: + +```text +generated jolt-verifier total: 7,755 LOC +generated verifier surface: 5,966 LOC +shared verifier runtime: 1,789 LOC +stage6 + stage7: 1,669 LOC +verifier.rs: 487 LOC +``` + +Target: + +```text +generated verifier surface: <= 4k-6k LOC +stretch generated surface: <= 2k-3k LOC +verifier.rs orchestration: <= 350-500 LOC +stage6 + stage7 generated surface: <= 2k-3k LOC +shared runtime/helpers: allowed when generic, named, and reviewed +``` + +The goal is to reduce the human-facing generated verifier surface by roughly an +order of magnitude. Shared runtime code may exist, but it must be modular, +boring to audit, and driven by explicit MLIR-derived plan data. + +This verifier cleanup is coupled to the generic protocol cleanup in +`GENERIC_PROTOCOL_GOAL.md`: shrinking the generated verifier should move generic +mechanics into Bolt IR/typed plans and shared runtime, not into Jolt-specific +emitter special cases. + +## Locked Genericity Decisions + +The next cleanup track should make Jolt a quarantined protocol package over +generic Bolt compiler infrastructure: + +- Root `bolt::*` exports should be generic-only. Jolt APIs should be imported + from `bolt::protocols::jolt::*`. +- Jolt-specific emitters are not the long-term target. Quarantine them first so + leakage is explicit, then progressively lift stage emission into a generic + `cpu -> Rust` backend driven by typed MLIR-derived plans. +- Replace the current Jolt evaluation-proof special case with either a generic + protocol extension hook or generic PCS/evaluation IR. Start with the minimal + extension hook if that keeps the cleanup mechanical. +- Add hygiene gates for generic compiler modules, initially targeting + `crates/bolt/src/{schema.rs,pass.rs,emit/rust}`. Any temporary Jolt allowlist + must be explicit and shrink over time. +- Namespace and file-layout refactors should preserve generated + `jolt-prover`/`jolt-verifier` output byte-for-byte unless the change + intentionally updates artifact structure. +- At the end of each goal-mode slice, report which quarantined Jolt emitters are + still genuinely protocol-specific and which are ready to lift into generic + typed-plan emission. + +## Immediate Goal-Mode Slice + +First objective for another agent: + +```text +Quarantine Jolt-specific artifact APIs out of generic Rust artifact assembly +while preserving generated output and all current gates. +``` + +Required steps: + +1. Move `JoltProtocolStage`, `jolt_artifact_config`, `jolt_rust_artifact`, + `assemble_jolt_*`, `write_jolt_generated_crates`, and + `validate_jolt_rust_artifact_imports` out of generic + `crates/bolt/src/emit/rust/artifacts.rs` into a Jolt-owned module such as + `crates/bolt/src/protocols/jolt/artifacts.rs`. +2. Keep `ProtocolArtifactConfig`, `ProtocolStage`, `ProtocolRustArtifact`, + `GeneratedCrate`, `assemble_generated_crates`, `write_generated_crates`, and + `validate_rust_artifact_imports` in the generic artifact layer. +3. Stop re-exporting Jolt APIs from `crates/bolt/src/lib.rs`; update callers in + Bolt tests, `jolt-equivalence`, and perf harnesses to import from + `bolt::protocols::jolt`. +4. Add the first genericity hygiene test that rejects new Jolt protocol strings + in generic compiler modules, using a small documented allowlist only for + migration leftovers. +5. Run focused generation/import gates and confirm checked-in generated role + crates are unchanged unless an intentional artifact-structure change is + documented. + +Acceptance criteria: + +```text +generic artifact assembly has no Jolt stage enum or Jolt artifact config +root bolt exports are generic-only +Jolt artifact helpers are namespaced under protocols::jolt +generated jolt-prover/jolt-verifier are byte-for-byte unchanged, or changes are intentional +genericity hygiene gate exists +existing generated-artifact and verifier-boundary gates pass +``` + +## Non-Negotiables + +- Preserve the current full-field non-zk Jolt protocol path: + `Transcript`. +- `jolt-verifier` must not depend on `jolt-prover`, `jolt-kernels`, + `jolt-core`, `jolt-equivalence`, `jolt-profiling`, or tracer internals. +- Bolt compiler boundaries remain: + `protocol -> concrete -> party -> compute -> cpu -> Rust`. +- Verifier CPU IR must remain kernel-free. Prover kernels are temporary + implementation details below the dialect boundary. +- Jolt semantics should be represented in protocol builders, dialect ops, + validators, lowering passes, or typed verifier plans. The Rust emitter should + not infer protocol meaning from loose strings when a typed enum, attr, op, or + plan field can carry it. +- Generated verifier files should be mostly declarative: + +```rust +pub const STAGE_PLAN: StagePlan = ...; + +pub fn verify_stage(...) -> Result { + runtime::verify_stage(&STAGE_PLAN, ...) +} +``` + +## Target Architecture + +The final verifier shape should read like this: + +```text +crates/jolt-verifier + src/lib.rs + src/verifier.rs + public API + proof shape + stage ordering + error mapping + + src/stages/ + commitment.rs + stage1_outer.rs + stage2.rs + ... + mostly declarative generated plans + + src/runtime/ or shared verifier crate + generic stage verifier + generic field expression evaluator + generic opening-claim machinery + generic sumcheck/eval proof conversion + transcript helpers + typed relation evaluators +``` + +Generated stage files should answer: + +```text +What claims exist? +What expressions are evaluated? +What transcript events happen? +What openings are checked? +What relations are verified? +``` + +Runtime modules should answer: + +```text +How is a field expression plan evaluated? +How is a stage plan verified? +How are opening/eval consistency checks performed? +How are proof records converted into runtime verifier inputs? +``` + +## Main Refactor Tracks + +1. **Verifier runtime extraction** + + Move duplicated stage-local machinery into one runtime: + + ```text + field expression evaluation + opening claim lookup and equality checks + sumcheck driver verification + transcript squeeze/absorb helpers + stage proof conversion + stage plan execution + ``` + +2. **Shared verifier plan types** + + Replace stage-specific copies such as `Stage6FieldExprPlan` and + `Stage7OpeningClaimPlan` with shared plan structs: + + ```text + FieldExprPlan + OpeningClaimPlan + OpeningEqualityPlan + SumcheckClaimPlan + SumcheckDriverPlan + SumcheckEvalPlan + StagePlan + RelationPlan + ``` + +3. **Compact field expression encoding** + + Stage 6 and Stage 7 are bloated by per-expression constants and operand + arrays. Replace those with compact tables or pooled operand slices. + +4. **Typed relation dispatch** + + Replace stringly relation handling with typed plan data where practical: + + ```text + RelationKind::RamReadWrite + RelationKind::InstructionReadRaf + RelationKind::BytecodeReadRaf + RelationKind::Booleanity + RelationKind::HammingBooleanity + RelationKind::RegistersReadWrite + ... + ``` + + Any remaining string dispatch must be explicitly allowlisted and covered by + schema tests. + +5. **Clean top-level verifier API** + + `verifier.rs` should be readable orchestration: proof shape, verifier + inputs, verifier programs, stage ordering, evaluation proof handling, and + clear error mapping. Repeated per-stage proof conversion should disappear. + +## One-Time Hardening Work + +Before large readability refactors, add a durable verifier hardening suite. +The suite should include positive equivalence and negative tamper oracles. + +Verifier tamper cases: + +```text +valid generated proof verifies +core and Bolt verifier accept/reject agree +tampered sumcheck coefficient rejects +tampered sumcheck point rejects +tampered named eval rejects +tampered commitment rejects +missing commitment rejects +missing stage proof rejects +reordered stage proof rejects +stage proof in the wrong slot rejects +wrong transcript state rejects +wrong evaluation proof rejects +missing evaluation setup rejects +missing evaluation proof rejects +extra/missing opening claims reject +opening claims in the wrong order reject +opening equality mismatch rejects +PCS proof mismatch rejects +``` + +MLIR/compiler hardening cases: + +```text +unknown dialects rejected +prover-only ops rejected in verifier pipeline +verifier-only ops rejected in prover pipeline +unthreaded transcript ops rejected +hidden or reordered opening batch claims rejected +unsupported equality modes rejected +duplicate proof slots rejected +invalid point arity rejected +invalid round schedule rejected +invalid relation kind rejected +verifier CPU IR contains no kernel dispatch +generated verifier imports no forbidden crates +``` + +## Concrete Gates + +Readability and LOC gates: + +```text +total generated jolt-verifier LOC trends down +verifier.rs <= 500 LOC, stretch <= 350 +stage6 + stage7 generated LOC <= 3k-5k, stretch <= 2k-3k +no duplicate stage-local generic plan structs +no duplicate stage-local field-expression interpreter +no duplicate stage-local opening equality interpreter +no giant per-expression operand constants after compaction +stage files are mostly declarative plan data +``` + +Security and boundary gates: + +```text +jolt-verifier imports are allowlisted +no jolt-prover dependency from jolt-verifier +no jolt-kernels dependency from jolt-verifier +no jolt-core dependency from jolt-verifier +no prover role ops in verifier MLIR +no kernel attrs in verifier CPU IR +all transcript-producing ops thread transcript state +all opening batches preserve explicit ordered claims +all relation dispatch is typed or allowlisted +``` + +Semantic gates: + +```bash +cargo fmt --check +cargo check -p bolt -p jolt-verifier -p jolt-prover -p jolt-equivalence --quiet +cargo nextest run -p bolt --test verifier_cleanup --no-capture +cargo nextest run -p bolt --test commitment_ir --cargo-quiet +cargo nextest run -p jolt-equivalence --test generated_role_crates --cargo-quiet +cargo nextest run -p jolt-equivalence --test bolt_commitment --no-capture +``` + +Required semantic outcomes: + +```text +core accepts Bolt proof +Bolt verifier accepts Bolt proof +core and Bolt transcript state histories match +core and Bolt observable proof artifacts match +core and Bolt reject equivalent tampered proofs +generated prover/verifier crates stay in sync with artifact rail +``` + +Perf remains a regression guard, not the center of this task. The existing +`sha2-chain` e2e/proving Perfetto traces are useful for confirming cleanup does +not accidentally move prover cost, but the main objective is verifier +readability, simplicity, and security. + +## Iteration Algorithm + +Each cleanup loop should follow the same rule: + +```text +1. Measure current LOC, duplication, imports, and typed-vs-string dispatch. +2. Pick one duplication class or hygiene issue to eliminate. +3. Move generic logic into runtime only if semantics remain explicit in MLIR or + typed plan data. +4. Regenerate checked-in verifier artifacts through the compiler rail. +5. Run hardening, equivalence, import, and schema gates. +6. Keep the change only if readability improves and no oracle weakens. +``` + +Use this scoring function when choosing work: + +```text +score = + LOC reduction ++ fewer duplicate structs/functions ++ fewer string dispatch sites ++ fewer generated helper bodies ++ stronger negative oracles ++ clearer verifier.rs +- semantic opacity introduced into runtime +``` + +## Definition Of Done + +This long-haul cleanup is complete when: + +```text +generated verifier surface is <= 4k-6k LOC +verifier.rs is <= 500 LOC +stage files are mostly declarative plans +generic verifier mechanics live once +Jolt relation semantics are typed and auditable +MLIR verifier pathway has malformed-input rejection tests +tamper suite covers commitments, transcript, stages, openings, evals, and PCS proof +core/Bolt accept/reject equivalence is preserved +generated verifier import boundaries are enforced +``` + +The desired end state is not merely fewer lines. The verifier should be easy to +navigate, easy to audit, and hard for the compiler pipeline to accidentally +weaken. diff --git a/crates/bolt/JOLT_PROTOCOL_IMPLEMENTATION.md b/crates/bolt/JOLT_PROTOCOL_IMPLEMENTATION.md new file mode 100644 index 0000000000..a6034e37f8 --- /dev/null +++ b/crates/bolt/JOLT_PROTOCOL_IMPLEMENTATION.md @@ -0,0 +1,74 @@ +# Jolt Protocol Implementation Notes + +The original stage-by-stage bring-up plan has been completed for the first +full-field, non-zk Jolt-on-Bolt implementation. The active long-haul goal now +lives in `GOAL.md`: make the generated Jolt verifier much smaller, cleaner, +and better hardened. + +This file keeps the durable implementation rules that should continue to guide +that cleanup. + +The companion genericity goal lives in `GENERIC_PROTOCOL_GOAL.md`. It defines +the rule that Jolt is a protocol package over Bolt, not a special case inside +generic IRs, passes, validation, or Rust artifact assembly. + +## Permanent Compiler Rules + +- Protocol facts live in `crates/bolt/src/protocols/jolt` and typed MLIR/plan + structures, not in generated Rust control flow. +- Generic dialects should remain generic. Jolt-only names and parameters may be + ordinary attrs or SSA values carried by the Jolt protocol definition, but + they should not become hidden assumptions in generic lowering code. +- Generic artifact assembly should consume protocol config and ordered stage + artifacts; Jolt artifact names, stage enums, relation mappings, and eval-proof + composition belong under `crates/bolt/src/protocols/jolt`. +- Lowering order remains: + +```text +protocol -> concrete -> party -> compute -> cpu -> Rust +``` + +- Rust emission is the final target. Before emission, behavior should be + represented as dialect ops, validation passes, analyses, rewrites, lowerings, + or typed plan extraction. +- Prover code may use coarse CPU kernels while performance work continues. + Those kernels are below the dialect boundary. +- Verifier code must stay kernel-free and audit-stable. It should use modular + verifier crates and generated plan data, not `jolt-kernels` or `jolt-core`. + +## Verifier Cleanup Algorithm + +For every verifier cleanup iteration: + +1. Measure generated verifier LOC, stage LOC, duplicate plan structs, duplicate + helper functions, forbidden imports, and string-dispatch sites. +2. Pick one duplication class or compiler hygiene issue. +3. Move generic mechanics into shared verifier runtime only when protocol + semantics remain explicit in MLIR-derived typed plan data. +4. Regenerate checked-in `jolt-prover` and `jolt-verifier` artifacts through the + canonical artifact rail. +5. Run schema, import, equivalence, and tamper gates. +6. Keep the change only when generated code is easier to read and no semantic + oracle weakens. + +## Do Not Regress + +- Verifier CPU IR must not contain kernel attrs or prover-only ops. +- Generated verifier Rust must not import `jolt-kernels`, `jolt-core`, + `jolt-prover`, `jolt-equivalence`, `jolt-profiling`, or tracer internals. +- Transcript state must be explicitly threaded through MLIR. +- Opening batches must preserve ordered claim lists. +- Opening equality checks must reject incompatible claim metadata. +- Sumcheck relation dispatch should be typed or explicitly allowlisted. +- Full-field transcript challenges are the intended path: + `Transcript`. + +## Regeneration Rail + +Checked-in generated role crates are not hand-maintained. Regenerate them with: + +```bash +JOLT_UPDATE_GOLDENS=1 cargo nextest run -p bolt generated_jolt_artifacts_have_uniform_crate_layout_and_import_rules --cargo-quiet +``` + +Then run the gates in `TESTING.md`. diff --git a/crates/bolt/README.md b/crates/bolt/README.md new file mode 100644 index 0000000000..19301e32f6 --- /dev/null +++ b/crates/bolt/README.md @@ -0,0 +1,114 @@ +# bolt + +This crate is the Bolt-shaped compiler prototype for the full-field, non-zk +Jolt implementation. `melior::ir::Module` is the IR source of truth; Rust types +provide phase/role guardrails, schema validation, builders, analysis results, +and final Rust emission. + +## Active Goal + +The first Jolt-on-Bolt implementation is semantically complete enough that the +active work is verifier cleanup and hardening, not stage bring-up. See +`GOAL.md` for the long-haul target: + +```text +make the generated jolt-verifier compact, human-readable, auditable, +security-hardened, and driven by explicit MLIR-derived plan data +``` + +`GENERIC_PROTOCOL_GOAL.md` describes the parallel cleanup track that makes Bolt +generic over protocol packages instead of Jolt-shaped. `JOLT_PROTOCOL_IMPLEMENTATION.md` +keeps the durable compiler-boundary rules. `TESTING.md` lists the LOC, +readability, equivalence, import, MLIR, and tamper gates for this cleanup track. + +## Compiler Shape + +Protocol-specific facts live under `src/protocols/`. Generic compiler layers +understand Bolt dialect operations but should not learn Jolt-only protocol +semantics except as ordinary typed attrs, SSA values, or typed plan data carried +by a protocol definition. + +The intended lowering path is: + +```text +protocol -> concrete -> party -> compute -> cpu -> Rust +``` + +The dialect split matters: + +- `protocol`, `piop`, `poly`, `field`, `transcript`, `commit`, and `pcs` model + protocol obligations. +- `party` projects prover/verifier visibility. +- `compute` represents role-specific executable structure while preserving + semantic dataflow. +- `cpu` is the final MLIR target before Rust emission. +- Rust is generated output, not the place where protocol meaning should be + inferred. + +## Verifier Boundary + +The generated verifier must remain audit-stable: + +```text +no jolt-prover dependency +no jolt-kernels dependency +no jolt-core dependency +no jolt-equivalence dependency +no jolt-profiling dependency +no tracer internals +``` + +Verifier CPU IR must stay kernel-free. Prover code may still call coarse +`jolt-kernels` CPU kernels while performance work continues, but those kernels +are below the dialect boundary and must not become verifier infrastructure. + +The cleanup target is for generated verifier modules to become mostly +declarative plan data, with generic mechanics factored into named verifier +runtime modules. + +## Generated Artifacts + +Generated Jolt Rust artifacts are organized as two role crates: + +```text +crates/jolt-prover +crates/jolt-verifier +``` + +The checked-in role crates are generated artifacts, not hand-maintained code. +Regenerate them through the Rust artifact rail with: + +```bash +JOLT_UPDATE_GOLDENS=1 cargo nextest run -p bolt generated_jolt_artifacts_have_uniform_crate_layout_and_import_rules --cargo-quiet +``` + +The generator emits manifests, stage registries, `src/stages/*.rs`, and the +top-level `prover.rs`/`verifier.rs` APIs. `jolt-verifier` owns proof types and +verification. `jolt-prover` may construct verifier-owned proof types, but must +not import verifier stage internals. + +## Local MLIR Toolchain + +The easiest setup path on macOS is: + +```bash +make bolt-dev-setup +source .bolt-dev-env +``` + +The helper installs Homebrew LLVM, Rust components used by CI, and writes the +local environment required by `mlir-sys`. + +On macOS with Homebrew LLVM: + +```bash +brew install llvm +export MLIR_SYS_220_PREFIX=/opt/homebrew/opt/llvm +export PATH="/opt/homebrew/opt/llvm/bin:$PATH" +export SDKROOT="$(xcrun --show-sdk-path)" +export BINDGEN_EXTRA_CLANG_ARGS="-isysroot$(xcrun --show-sdk-path)" +``` + +Do not set `MLIR_SYS_LINK_SHARED=1` with the Homebrew LLVM 22 bottle; it does +not ship `libMLIR-C.dylib`, so `mlir-sys` needs its default static MLIR link +path. diff --git a/crates/bolt/TESTING.md b/crates/bolt/TESTING.md new file mode 100644 index 0000000000..d1573410fc --- /dev/null +++ b/crates/bolt/TESTING.md @@ -0,0 +1,210 @@ +# Jolt-on-Bolt Equivalence Gates + +The first full-field, non-zk Jolt-on-Bolt implementation is in equivalence, +hardening, and perf-gating mode. The active objective is in +`crates/jolt-equivalence/GOAL.md`: keep `jolt-equivalence` as a thin oracle and +gate suite while semantic construction lives in Bolt, generated crates, +`jolt-kernels`, or `jolt-witness`. + +## Fast Local Gates + +Set up and source the Bolt dev environment first: + +```bash +scripts/setup-bolt-dev.sh +source .bolt-dev-env +``` + +Run: + +```bash +cargo fmt --check +cargo check -p bolt -p jolt-verifier -p jolt-prover -p jolt-equivalence --quiet +cargo nextest run -p bolt --test verifier_cleanup --no-capture +cargo nextest run -p bolt --test commitment_ir --cargo-quiet +cargo nextest run -p jolt-equivalence --test generated_role_crates --cargo-quiet +``` + +`commitment_ir` can also materialize ignored MLIR/Rust scratch fixtures for +local inspection: + +```bash +JOLT_UPDATE_GOLDENS=1 cargo nextest run -p bolt --test commitment_ir --cargo-quiet +``` + +These gates cover: + +- MLIR dialect registration and schema validation. +- Concrete transcript threading. +- Prover/verifier party projection. +- `compute` and `cpu` schema validation. +- Prover-only kernel resolution. +- Kernel-free verifier CPU IR. +- Generated Rust compilation. +- Generated role-crate layout and import boundaries. +- Matching generated stage registries. +- Generated verifier LOC, duplication, relation-string, and boundary metrics. + +## Real-Data Equivalence Gate + +Run: + +```bash +cargo nextest run -p jolt-equivalence --test bolt_commitment --no-capture +``` + +This is the main semantic oracle. It should continue to prove: + +- Bolt verifier accepts Bolt proof artifacts on real trace data. +- Core accepts the corresponding Bolt proof path. +- Bolt/core transcript histories match. +- Bolt/core observable proof artifacts match. +- Generated standalone and top-level verifier paths agree. +- Representative tampering is rejected by the generated verifier. + +The `Bolt equivalence` workflow runs the generated role parity and real-data +tamper gates on pull requests. It also has an optional full +`jolt-equivalence` sweep that runs on the nightly schedule, or manually through +`workflow_dispatch` with `include_full_sweep=true`: + +```bash +cargo nextest run -p jolt-equivalence --cargo-quiet +``` + +## Required Hardening Coverage + +The verifier hardening suite should cover these negative cases: + +```text +tampered commitment +missing commitment +tampered sumcheck coefficient +tampered sumcheck point +tampered named eval +missing stage proof +reordered stage proof +stage proof in wrong slot +wrong transcript state +missing opening claim +extra opening claim +opening claims in wrong order +opening equality mismatch +wrong evaluation proof +missing evaluation setup +missing evaluation proof +PCS proof mismatch +``` + +The MLIR/compiler hardening suite should cover: + +```text +unknown dialect rejection +prover-only op rejection in verifier pipeline +verifier-only op rejection in prover pipeline +unthreaded transcript op rejection +hidden/reordered opening batch claim rejection +unsupported equality mode rejection +duplicate proof slot rejection +invalid point arity rejection +invalid round schedule rejection +invalid relation kind rejection +kernel attr rejection in verifier CPU IR +forbidden generated verifier imports +``` + +## LOC And Readability Gates + +Track these metrics before and after each cleanup iteration: + +```text +total generated jolt-verifier LOC +verifier.rs LOC +stage6 + stage7 generated LOC +number of stage-local generic plan structs +number of stage-local helper/interpreter functions +number of field-expression operand constants +number of relation string-dispatch sites +forbidden imports +``` + +Targets: + +```text +generated verifier surface: <= 4k-6k LOC +stretch generated surface: <= 2k-3k LOC +verifier.rs orchestration: <= 350-500 LOC +stage6 + stage7 generated surface: <= 2k-3k LOC +``` + +Do not accept a LOC reduction that hides semantics in opaque runtime code. The +generated surface should shrink because generic mechanics moved into named, +reviewable runtime modules and the remaining generated code became declarative +plan data. + +## Regeneration Gate + +Checked-in generated role crates must stay synchronized with the artifact rail: + +```bash +JOLT_UPDATE_GOLDENS=1 cargo nextest run -p bolt generated_jolt_artifacts_have_uniform_crate_layout_and_import_rules --cargo-quiet +``` + +After regenerating, rerun the fast local gates and the real-data equivalence +gate. + +## Perf Oracle Guard + +New Jolt-on-Bolt changes should preserve a core-vs-Bolt perf oracle that uses +`jolt-profiling` as the shared instrumentation layer. The gate should run the +same program, inputs, trace length, PCS setup size, and transcript mode through: + +```text +core reference path: + setup, prove, verify, proof size, peak RSS + +Bolt generated path: + setup, prove, verify, proof size, peak RSS +``` + +Both paths must emit the same named tracing spans through `jolt-profiling`, at +minimum: + +```text +core.setup +core.prove +core.verify +bolt.setup +bolt.prove +bolt.commitment +bolt.commitment.batch +bolt.commitment.dory_commit +bolt.stage1 ... bolt.stage8 +bolt.evaluate +bolt.evaluate.claims +bolt.evaluate.materialize_joint_polynomial +bolt.evaluate.joint_opening_hint +bolt.evaluate.dory_open +bolt.verify +bolt.verify.evaluation_state +bolt.verify.dory_verify +``` + +The checked-in CI smoke programs are: + +```text +PR gate: bolt_sha2_chain_2_16_core_vs_bolt_perf_oracle +PR gate: bolt_sha2_chain_2_20_core_vs_bolt_perf_oracle +``` + +Both tests live in `jolt-equivalence/tests/bolt_perf.rs` because they reuse the +real semantic oracle fixture and pass paired `PerfMetrics` into +`jolt-profiling`'s `check_core_vs_bolt_gate`. The workflow sets +`JOLT_BOLT_PERF_TRACE=1` so the same run writes Perfetto JSON traces under +`benchmark-runs/perfetto_traces/`. + +To run them locally after `source .bolt-dev-env`: + +```bash +JOLT_BOLT_PERF_TRACE=1 cargo nextest run -p jolt-equivalence --test bolt_perf --release --cargo-quiet --run-ignored only --no-capture bolt_sha2_chain_2_16_core_vs_bolt_perf_oracle +JOLT_BOLT_PERF_TRACE=1 cargo nextest run -p jolt-equivalence --test bolt_perf --release --cargo-quiet --run-ignored only --no-capture bolt_sha2_chain_2_20_core_vs_bolt_perf_oracle +``` diff --git a/crates/bolt/irdl/commit.mlir b/crates/bolt/irdl/commit.mlir new file mode 100644 index 0000000000..b759683bd2 --- /dev/null +++ b/crates/bolt/irdl/commit.mlir @@ -0,0 +1,25 @@ +irdl.dialect @commit { + irdl.type @artifact + irdl.operation @publish_batch { + %artifact_type = irdl.parametric @commit::@artifact<> + %sym = irdl.any + %oracle_family = irdl.any + %label = irdl.any + irdl.attributes {"sym_name" = %sym, "oracle_family" = %oracle_family, "label" = %label} + irdl.results(artifact: %artifact_type) + } + irdl.operation @publish_optional { + %artifact_type = irdl.parametric @commit::@artifact<> + %sym = irdl.any + %oracle = irdl.any + %label = irdl.any + %skip_policy = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "label" = %label, + "skip_policy" = %skip_policy + } + irdl.results(artifact: %artifact_type) + } +} diff --git a/crates/bolt/irdl/compute.mlir b/crates/bolt/irdl/compute.mlir new file mode 100644 index 0000000000..60c01c9eaa --- /dev/null +++ b/crates/bolt/irdl/compute.mlir @@ -0,0 +1,807 @@ +irdl.dialect @compute { + irdl.type @commitment_artifact + irdl.type @transcript_state + irdl.type @oracle_buffer + irdl.type @oracle_family + irdl.type @field_value + irdl.type @point + irdl.type @sumcheck_claim_type + irdl.type @sumcheck_batch_type + irdl.type @sumcheck_result_type + irdl.type @sumcheck_proof_type + irdl.type @opening_claim_type + irdl.type @opening_batch_type + irdl.type @opening_proof_type + + irdl.operation @params { + %sym = irdl.any + %field = irdl.any + %pcs = irdl.any + %transcript = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field, "pcs" = %pcs, "transcript" = %transcript} + } + irdl.operation @function { + %sym = irdl.any + %source = irdl.any + irdl.attributes {"sym_name" = %sym, "source" = %source} + } + irdl.operation @relation { + %sym = irdl.any + %kind = irdl.any + %domain = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + %output_count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "kind" = %kind, + "domain" = %domain, + "num_rounds" = %num_rounds, + "degree" = %degree, + "output_count" = %output_count + } + } + irdl.operation @kernel { + %sym = irdl.any + %relation = irdl.any + %kind = irdl.any + %backend = irdl.any + %abi = irdl.any + irdl.attributes { + "sym_name" = %sym, + "relation" = %relation, + "kind" = %kind, + "backend" = %backend, + "abi" = %abi + } + } + irdl.operation @oracle_dense_trace { + %buffer = irdl.parametric @compute::@oracle_buffer<> + %sym = irdl.any + %oracle = irdl.any + %source = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %padding = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "source" = %source, + "domain" = %domain, + "num_vars" = %num_vars, + "padding" = %padding + } + irdl.results(buffer: %buffer) + } + irdl.operation @oracle_one_hot_chunk { + %buffer = irdl.parametric @compute::@oracle_buffer<> + %sym = irdl.any + %oracle = irdl.any + %source = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %trace_num_vars = irdl.any + %chunk = irdl.any + %num_chunks = irdl.any + %chunk_bits = irdl.any + %padding = irdl.any + %layout = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "source" = %source, + "domain" = %domain, + "num_vars" = %num_vars, + "trace_num_vars" = %trace_num_vars, + "chunk" = %chunk, + "num_chunks" = %num_chunks, + "chunk_bits" = %chunk_bits, + "padding" = %padding, + "layout" = %layout + } + irdl.results(buffer: %buffer) + } + irdl.operation @oracle_optional_advice { + %buffer = irdl.parametric @compute::@oracle_buffer<> + %sym = irdl.any + %oracle = irdl.any + %source = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %skip_policy = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "source" = %source, + "domain" = %domain, + "num_vars" = %num_vars, + "skip_policy" = %skip_policy + } + irdl.results(buffer: %buffer) + } + irdl.operation @oracle_ref { + %buffer = irdl.parametric @compute::@oracle_buffer<> + %sym = irdl.any + %oracle = irdl.any + %domain = irdl.any + %num_vars = irdl.any + irdl.attributes {"sym_name" = %sym, "oracle" = %oracle, "domain" = %domain, "num_vars" = %num_vars} + irdl.results(buffer: %buffer) + } + irdl.operation @oracle_family_init { + %family_type = irdl.parametric @compute::@oracle_family<> + %sym = irdl.any + %family = irdl.any + %count = irdl.any + irdl.attributes {"sym_name" = %sym, "family" = %family, "count" = %count} + irdl.results(family: %family_type) + } + irdl.operation @oracle_family_append { + %family_type = irdl.parametric @compute::@oracle_family<> + %buffer = irdl.parametric @compute::@oracle_buffer<> + %sym = irdl.any + %family = irdl.any + %oracle = irdl.any + %index = irdl.any + irdl.attributes {"sym_name" = %sym, "family" = %family, "oracle" = %oracle, "index" = %index} + irdl.operands(input: %family_type, oracle_buffer: %buffer) + irdl.results(output: %family_type) + } + irdl.operation @pcs_commit_batch { + %artifact_type = irdl.parametric @compute::@commitment_artifact<> + %family_type = irdl.parametric @compute::@oracle_family<> + %sym = irdl.any + %artifact = irdl.any + %pcs = irdl.any + %oracle_family = irdl.any + %ordered_oracles = irdl.any + %label = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "artifact" = %artifact, + "pcs" = %pcs, + "oracle_family" = %oracle_family, + "ordered_oracles" = %ordered_oracles, + "label" = %label, + "domain" = %domain, + "num_vars" = %num_vars, + "count" = %count + } + irdl.operands(oracles: %family_type) + irdl.results(artifact: %artifact_type) + } + irdl.operation @pcs_commit_optional { + %artifact_type = irdl.parametric @compute::@commitment_artifact<> + %buffer = irdl.parametric @compute::@oracle_buffer<> + %sym = irdl.any + %artifact = irdl.any + %pcs = irdl.any + %oracle = irdl.any + %label = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %skip_policy = irdl.any + irdl.attributes { + "sym_name" = %sym, + "artifact" = %artifact, + "pcs" = %pcs, + "oracle" = %oracle, + "label" = %label, + "domain" = %domain, + "num_vars" = %num_vars, + "skip_policy" = %skip_policy + } + irdl.operands(oracle_buffer: %buffer) + irdl.results(artifact: %artifact_type) + } + irdl.operation @pcs_receive_batch { + %artifact_type = irdl.parametric @compute::@commitment_artifact<> + %family_type = irdl.parametric @compute::@oracle_family<> + %sym = irdl.any + %artifact = irdl.any + %pcs = irdl.any + %oracle_family = irdl.any + %ordered_oracles = irdl.any + %label = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "artifact" = %artifact, + "pcs" = %pcs, + "oracle_family" = %oracle_family, + "ordered_oracles" = %ordered_oracles, + "label" = %label, + "domain" = %domain, + "num_vars" = %num_vars, + "count" = %count + } + irdl.operands(oracles: %family_type) + irdl.results(artifact: %artifact_type) + } + irdl.operation @pcs_receive_optional { + %artifact_type = irdl.parametric @compute::@commitment_artifact<> + %buffer = irdl.parametric @compute::@oracle_buffer<> + %sym = irdl.any + %artifact = irdl.any + %pcs = irdl.any + %oracle = irdl.any + %label = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %skip_policy = irdl.any + irdl.attributes { + "sym_name" = %sym, + "artifact" = %artifact, + "pcs" = %pcs, + "oracle" = %oracle, + "label" = %label, + "domain" = %domain, + "num_vars" = %num_vars, + "skip_policy" = %skip_policy + } + irdl.operands(oracle_buffer: %buffer) + irdl.results(artifact: %artifact_type) + } + irdl.operation @transcript_init { + %state = irdl.parametric @compute::@transcript_state<> + %sym = irdl.any + %scheme = irdl.any + irdl.attributes {"sym_name" = %sym, "scheme" = %scheme} + irdl.results(state: %state) + } + irdl.operation @transcript_absorb { + %state = irdl.parametric @compute::@transcript_state<> + %artifact = irdl.parametric @compute::@commitment_artifact<> + %sym = irdl.any + %label = irdl.any + %optional = irdl.any + irdl.attributes { + "sym_name" = %sym, + "label" = %label, + "optional" = %optional + } + irdl.operands(input: %state, artifact: %artifact) + irdl.results(output: %state) + } + irdl.operation @transcript_absorb_bytes { + %state = irdl.parametric @compute::@transcript_state<> + %sym = irdl.any + %label = irdl.any + %payload = irdl.any + irdl.attributes { + "sym_name" = %sym, + "label" = %label, + "payload" = %payload + } + irdl.operands(input: %state) + irdl.results(output: %state) + } + irdl.operation @transcript_squeeze { + %state = irdl.parametric @compute::@transcript_state<> + %challenge = irdl.any + %sym = irdl.any + %label = irdl.any + %kind = irdl.any + %count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "label" = %label, + "kind" = %kind, + "count" = %count + } + irdl.operands(input: %state) + irdl.results(output: %state, challenge: %challenge) + } + irdl.operation @opening_input { + %point = irdl.parametric @compute::@point<> + %eval = irdl.parametric @compute::@field_value<> + %claim = irdl.parametric @compute::@opening_claim_type<> + %sym = irdl.any + %source_stage = irdl.any + %source_claim = irdl.any + %oracle = irdl.any + %domain = irdl.any + %point_arity = irdl.any + %claim_kind = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source_stage" = %source_stage, + "source_claim" = %source_claim, + "oracle" = %oracle, + "domain" = %domain, + "point_arity" = %point_arity, + "claim_kind" = %claim_kind + } + irdl.results(point: %point, eval: %eval, claim: %claim) + } + irdl.operation @point_slice { + %input = irdl.parametric @compute::@point<> + %output = irdl.parametric @compute::@point<> + %sym = irdl.any + %source = irdl.any + %offset = irdl.any + %length = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "offset" = %offset, + "length" = %length + } + irdl.operands(input: %input) + irdl.results(output: %output) + } + irdl.operation @point_zero { + %output = irdl.parametric @compute::@point<> + %sym = irdl.any + %field = irdl.any + %arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "field" = %field, + "arity" = %arity + } + irdl.results(output: %output) + } + irdl.operation @point_concat { + %input = irdl.parametric @compute::@point<> + %output = irdl.parametric @compute::@point<> + %sym = irdl.any + %layout = irdl.any + %arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "layout" = %layout, + "arity" = %arity + } + irdl.operands(inputs: variadic %input) + irdl.results(output: %output) + } + irdl.operation @field_const { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + %field = irdl.any + %value = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field, "value" = %value} + irdl.results(value: %value_type) + } + irdl.operation @field_zero { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + %field = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field} + irdl.results(value: %value_type) + } + irdl.operation @field_one { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + %field = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field} + irdl.results(value: %value_type) + } + irdl.operation @field_add { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %value_type, rhs: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @field_sub { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %value_type, rhs: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @field_neg { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(input: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @field_mul { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %value_type, rhs: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @field_pow { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + %exponent = irdl.any + irdl.attributes {"sym_name" = %sym, "exponent" = %exponent} + irdl.operands(input: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @poly_lagrange_basis_eval { + %value_type = irdl.parametric @compute::@field_value<> + %sym = irdl.any + %domain_start = irdl.any + %domain_size = irdl.any + %index = irdl.any + irdl.attributes { + "sym_name" = %sym, + "domain_start" = %domain_start, + "domain_size" = %domain_size, + "index" = %index + } + irdl.operands(point: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @sumcheck_claim { + %input_claim = irdl.parametric @compute::@field_value<> + %opening_claim = irdl.parametric @compute::@opening_claim_type<> + %claim_type = irdl.parametric @compute::@sumcheck_claim_type<> + %sym = irdl.any + %stage = irdl.any + %domain = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + %claim = irdl.any + %relation = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "domain" = %domain, + "num_rounds" = %num_rounds, + "degree" = %degree, + "claim" = %claim, + "relation" = %relation + } + irdl.operands(input_claim: %input_claim, inputs: variadic %opening_claim) + irdl.results(claim: %claim_type) + } + irdl.operation @sumcheck_kernel_claim { + %input_claim = irdl.parametric @compute::@field_value<> + %opening_claim = irdl.parametric @compute::@opening_claim_type<> + %claim_type = irdl.parametric @compute::@sumcheck_claim_type<> + %sym = irdl.any + %stage = irdl.any + %domain = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + %claim = irdl.any + %kernel = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "domain" = %domain, + "num_rounds" = %num_rounds, + "degree" = %degree, + "claim" = %claim, + "kernel" = %kernel + } + irdl.operands(input_claim: %input_claim, inputs: variadic %opening_claim) + irdl.results(claim: %claim_type) + } + irdl.operation @sumcheck_verify_claim { + %input_claim = irdl.parametric @compute::@field_value<> + %opening_claim = irdl.parametric @compute::@opening_claim_type<> + %claim_type = irdl.parametric @compute::@sumcheck_claim_type<> + %sym = irdl.any + %stage = irdl.any + %domain = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + %claim = irdl.any + %relation = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "domain" = %domain, + "num_rounds" = %num_rounds, + "degree" = %degree, + "claim" = %claim, + "relation" = %relation + } + irdl.operands(input_claim: %input_claim, inputs: variadic %opening_claim) + irdl.results(claim: %claim_type) + } + irdl.operation @sumcheck_batch { + %claim_type = irdl.parametric @compute::@sumcheck_claim_type<> + %batch_type = irdl.parametric @compute::@sumcheck_batch_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %round_schedule = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims, + "claim_label" = %claim_label, + "round_label" = %round_label, + "round_schedule" = %round_schedule + } + irdl.operands(claims: variadic %claim_type) + irdl.results(batch: %batch_type) + } + irdl.operation @sumcheck_driver { + %state = irdl.parametric @compute::@transcript_state<> + %batch_type = irdl.parametric @compute::@sumcheck_batch_type<> + %point = irdl.parametric @compute::@point<> + %result = irdl.parametric @compute::@sumcheck_result_type<> + %proof = irdl.parametric @compute::@sumcheck_proof_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %relation = irdl.any + %policy = irdl.any + %round_schedule = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "relation" = %relation, + "policy" = %policy, + "round_schedule" = %round_schedule, + "claim_label" = %claim_label, + "round_label" = %round_label, + "num_rounds" = %num_rounds, + "degree" = %degree + } + irdl.operands(input: %state, batch: %batch_type) + irdl.results(output: %state, point: %point, result: %result, proof: %proof) + } + irdl.operation @sumcheck_kernel_driver { + %state = irdl.parametric @compute::@transcript_state<> + %batch_type = irdl.parametric @compute::@sumcheck_batch_type<> + %point = irdl.parametric @compute::@point<> + %result = irdl.parametric @compute::@sumcheck_result_type<> + %proof = irdl.parametric @compute::@sumcheck_proof_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %kernel = irdl.any + %policy = irdl.any + %round_schedule = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "kernel" = %kernel, + "policy" = %policy, + "round_schedule" = %round_schedule, + "claim_label" = %claim_label, + "round_label" = %round_label, + "num_rounds" = %num_rounds, + "degree" = %degree + } + irdl.operands(input: %state, batch: %batch_type) + irdl.results(output: %state, point: %point, result: %result, proof: %proof) + } + irdl.operation @sumcheck_verify { + %state = irdl.parametric @compute::@transcript_state<> + %batch_type = irdl.parametric @compute::@sumcheck_batch_type<> + %point = irdl.parametric @compute::@point<> + %result = irdl.parametric @compute::@sumcheck_result_type<> + %proof = irdl.parametric @compute::@sumcheck_proof_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %relation = irdl.any + %policy = irdl.any + %round_schedule = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "relation" = %relation, + "policy" = %policy, + "round_schedule" = %round_schedule, + "claim_label" = %claim_label, + "round_label" = %round_label, + "num_rounds" = %num_rounds, + "degree" = %degree + } + irdl.operands(input: %state, batch: %batch_type) + irdl.results(output: %state, point: %point, result: %result, proof: %proof) + } + irdl.operation @sumcheck_eval { + %result = irdl.parametric @compute::@sumcheck_result_type<> + %eval = irdl.parametric @compute::@field_value<> + %sym = irdl.any + %source = irdl.any + %name = irdl.any + %index = irdl.any + %oracle = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "name" = %name, + "index" = %index, + "oracle" = %oracle + } + irdl.operands(result: %result) + irdl.results(eval: %eval) + } + irdl.operation @sumcheck_instance_result { + %input_point = irdl.parametric @compute::@point<> + %output_point = irdl.parametric @compute::@point<> + %input_result = irdl.parametric @compute::@sumcheck_result_type<> + %output_result = irdl.parametric @compute::@sumcheck_result_type<> + %sym = irdl.any + %source = irdl.any + %claim = irdl.any + %relation = irdl.any + %index = irdl.any + %point_arity = irdl.any + %num_rounds = irdl.any + %round_offset = irdl.any + %point_order = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "claim" = %claim, + "relation" = %relation, + "index" = %index, + "point_arity" = %point_arity, + "num_rounds" = %num_rounds, + "round_offset" = %round_offset, + "point_order" = %point_order, + "degree" = %degree + } + irdl.operands(input_point: %input_point, input_result: %input_result) + irdl.results(instance_point: %output_point, instance_result: %output_result) + } + irdl.operation @opening_claim { + %point = irdl.parametric @compute::@point<> + %eval = irdl.parametric @compute::@field_value<> + %claim = irdl.parametric @compute::@opening_claim_type<> + %sym = irdl.any + %oracle = irdl.any + %domain = irdl.any + %point_arity = irdl.any + %claim_kind = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "domain" = %domain, + "point_arity" = %point_arity, + "claim_kind" = %claim_kind + } + irdl.operands(point: %point, eval: %eval) + irdl.results(claim: %claim) + } + irdl.operation @opening_claim_equal { + %claim = irdl.parametric @compute::@opening_claim_type<> + %sym = irdl.any + %mode = irdl.any + irdl.attributes { + "sym_name" = %sym, + "mode" = %mode + } + irdl.operands(left: %claim, right: %claim) + } + irdl.operation @opening_batch { + %claim = irdl.parametric @compute::@opening_claim_type<> + %batch = irdl.parametric @compute::@opening_batch_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims + } + irdl.operands(claims: variadic %claim) + irdl.results(batch: %batch) + } + irdl.operation @pcs_opening_claim { + %point = irdl.parametric @compute::@point<> + %eval = irdl.parametric @compute::@field_value<> + %claim = irdl.parametric @compute::@opening_claim_type<> + %sym = irdl.any + %oracle = irdl.any + %family = irdl.any + %domain = irdl.any + %point_arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "family" = %family, + "domain" = %domain, + "point_arity" = %point_arity + } + irdl.operands(point: %point, eval: %eval) + irdl.results(claim: %claim) + } + irdl.operation @pcs_opening_batch { + %claim = irdl.parametric @compute::@opening_claim_type<> + %batch = irdl.parametric @compute::@opening_batch_type<> + %sym = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + irdl.attributes { + "sym_name" = %sym, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims + } + irdl.operands(claims: variadic %claim) + irdl.results(batch: %batch) + } + irdl.operation @pcs_batch_open { + %state = irdl.parametric @compute::@transcript_state<> + %batch = irdl.parametric @compute::@opening_batch_type<> + %proof = irdl.parametric @compute::@opening_proof_type<> + %sym = irdl.any + %pcs = irdl.any + %proof_slot = irdl.any + %transcript_label = irdl.any + irdl.attributes { + "sym_name" = %sym, + "pcs" = %pcs, + "proof_slot" = %proof_slot, + "transcript_label" = %transcript_label + } + irdl.operands(input: %state, batch: %batch) + irdl.results(output: %state, proof: %proof) + } + irdl.operation @pcs_batch_verify { + %state = irdl.parametric @compute::@transcript_state<> + %batch = irdl.parametric @compute::@opening_batch_type<> + %proof = irdl.parametric @compute::@opening_proof_type<> + %sym = irdl.any + %pcs = irdl.any + %proof_slot = irdl.any + %transcript_label = irdl.any + irdl.attributes { + "sym_name" = %sym, + "pcs" = %pcs, + "proof_slot" = %proof_slot, + "transcript_label" = %transcript_label + } + irdl.operands(input: %state, batch: %batch) + irdl.results(output: %state, proof: %proof) + } + irdl.operation @generate_oracle { + %sym = irdl.any + %oracle = irdl.any + %source = irdl.any + %generation = irdl.any + irdl.attributes {"sym_name" = %sym, "oracle" = %oracle, "source" = %source, "generation" = %generation} + } + irdl.operation @generate_oracle_family { + %sym = irdl.any + %family = irdl.any + %source = irdl.any + %generation = irdl.any + irdl.attributes {"sym_name" = %sym, "family" = %family, "source" = %source, "generation" = %generation} + } +} diff --git a/crates/bolt/irdl/cpu.mlir b/crates/bolt/irdl/cpu.mlir new file mode 100644 index 0000000000..22fe7ad519 --- /dev/null +++ b/crates/bolt/irdl/cpu.mlir @@ -0,0 +1,723 @@ +irdl.dialect @cpu { + irdl.type @commitment_artifact + irdl.type @transcript_state + irdl.type @oracle_buffer + irdl.type @oracle_family + irdl.type @field_value + irdl.type @point + irdl.type @sumcheck_claim_type + irdl.type @sumcheck_batch_type + irdl.type @sumcheck_result_type + irdl.type @sumcheck_proof_type + irdl.type @opening_claim_type + irdl.type @opening_batch_type + irdl.type @opening_proof_type + + irdl.operation @params { + %sym = irdl.any + %field = irdl.any + %pcs = irdl.any + %transcript = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field, "pcs" = %pcs, "transcript" = %transcript} + } + irdl.operation @function { + %sym = irdl.any + %source = irdl.any + irdl.attributes {"sym_name" = %sym, "source" = %source} + } + irdl.operation @kernel { + %sym = irdl.any + %relation = irdl.any + %kind = irdl.any + %backend = irdl.any + %abi = irdl.any + irdl.attributes { + "sym_name" = %sym, + "relation" = %relation, + "kind" = %kind, + "backend" = %backend, + "abi" = %abi + } + } + irdl.operation @oracle_dense_trace { + %buffer = irdl.parametric @cpu::@oracle_buffer<> + %sym = irdl.any + %oracle = irdl.any + %source = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %padding = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "source" = %source, + "domain" = %domain, + "num_vars" = %num_vars, + "padding" = %padding + } + irdl.results(buffer: %buffer) + } + irdl.operation @oracle_one_hot_chunk { + %buffer = irdl.parametric @cpu::@oracle_buffer<> + %sym = irdl.any + %oracle = irdl.any + %source = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %trace_num_vars = irdl.any + %chunk = irdl.any + %num_chunks = irdl.any + %chunk_bits = irdl.any + %padding = irdl.any + %layout = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "source" = %source, + "domain" = %domain, + "num_vars" = %num_vars, + "trace_num_vars" = %trace_num_vars, + "chunk" = %chunk, + "num_chunks" = %num_chunks, + "chunk_bits" = %chunk_bits, + "padding" = %padding, + "layout" = %layout + } + irdl.results(buffer: %buffer) + } + irdl.operation @oracle_optional_advice { + %buffer = irdl.parametric @cpu::@oracle_buffer<> + %sym = irdl.any + %oracle = irdl.any + %source = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %skip_policy = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "source" = %source, + "domain" = %domain, + "num_vars" = %num_vars, + "skip_policy" = %skip_policy + } + irdl.results(buffer: %buffer) + } + irdl.operation @oracle_ref { + %buffer = irdl.parametric @cpu::@oracle_buffer<> + %sym = irdl.any + %oracle = irdl.any + %domain = irdl.any + %num_vars = irdl.any + irdl.attributes {"sym_name" = %sym, "oracle" = %oracle, "domain" = %domain, "num_vars" = %num_vars} + irdl.results(buffer: %buffer) + } + irdl.operation @oracle_family_init { + %family_type = irdl.parametric @cpu::@oracle_family<> + %sym = irdl.any + %family = irdl.any + %count = irdl.any + irdl.attributes {"sym_name" = %sym, "family" = %family, "count" = %count} + irdl.results(family: %family_type) + } + irdl.operation @oracle_family_append { + %family_type = irdl.parametric @cpu::@oracle_family<> + %buffer = irdl.parametric @cpu::@oracle_buffer<> + %sym = irdl.any + %family = irdl.any + %oracle = irdl.any + %index = irdl.any + irdl.attributes {"sym_name" = %sym, "family" = %family, "oracle" = %oracle, "index" = %index} + irdl.operands(input: %family_type, oracle_buffer: %buffer) + irdl.results(output: %family_type) + } + irdl.operation @pcs_commit_batch { + %artifact_type = irdl.parametric @cpu::@commitment_artifact<> + %family_type = irdl.parametric @cpu::@oracle_family<> + %sym = irdl.any + %artifact = irdl.any + %pcs = irdl.any + %oracle_family = irdl.any + %ordered_oracles = irdl.any + %label = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "artifact" = %artifact, + "pcs" = %pcs, + "oracle_family" = %oracle_family, + "ordered_oracles" = %ordered_oracles, + "label" = %label, + "domain" = %domain, + "num_vars" = %num_vars, + "count" = %count + } + irdl.operands(oracles: %family_type) + irdl.results(artifact: %artifact_type) + } + irdl.operation @pcs_commit_optional { + %artifact_type = irdl.parametric @cpu::@commitment_artifact<> + %buffer = irdl.parametric @cpu::@oracle_buffer<> + %sym = irdl.any + %artifact = irdl.any + %pcs = irdl.any + %oracle = irdl.any + %label = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %skip_policy = irdl.any + irdl.attributes { + "sym_name" = %sym, + "artifact" = %artifact, + "pcs" = %pcs, + "oracle" = %oracle, + "label" = %label, + "domain" = %domain, + "num_vars" = %num_vars, + "skip_policy" = %skip_policy + } + irdl.operands(oracle_buffer: %buffer) + irdl.results(artifact: %artifact_type) + } + irdl.operation @pcs_receive_batch { + %artifact_type = irdl.parametric @cpu::@commitment_artifact<> + %family_type = irdl.parametric @cpu::@oracle_family<> + %sym = irdl.any + %artifact = irdl.any + %pcs = irdl.any + %oracle_family = irdl.any + %ordered_oracles = irdl.any + %label = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "artifact" = %artifact, + "pcs" = %pcs, + "oracle_family" = %oracle_family, + "ordered_oracles" = %ordered_oracles, + "label" = %label, + "domain" = %domain, + "num_vars" = %num_vars, + "count" = %count + } + irdl.operands(oracles: %family_type) + irdl.results(artifact: %artifact_type) + } + irdl.operation @pcs_receive_optional { + %artifact_type = irdl.parametric @cpu::@commitment_artifact<> + %buffer = irdl.parametric @cpu::@oracle_buffer<> + %sym = irdl.any + %artifact = irdl.any + %pcs = irdl.any + %oracle = irdl.any + %label = irdl.any + %domain = irdl.any + %num_vars = irdl.any + %skip_policy = irdl.any + irdl.attributes { + "sym_name" = %sym, + "artifact" = %artifact, + "pcs" = %pcs, + "oracle" = %oracle, + "label" = %label, + "domain" = %domain, + "num_vars" = %num_vars, + "skip_policy" = %skip_policy + } + irdl.operands(oracle_buffer: %buffer) + irdl.results(artifact: %artifact_type) + } + irdl.operation @transcript_init { + %state = irdl.parametric @cpu::@transcript_state<> + %sym = irdl.any + %scheme = irdl.any + irdl.attributes {"sym_name" = %sym, "scheme" = %scheme} + irdl.results(state: %state) + } + irdl.operation @transcript_absorb { + %state = irdl.parametric @cpu::@transcript_state<> + %artifact = irdl.parametric @cpu::@commitment_artifact<> + %sym = irdl.any + %label = irdl.any + %optional = irdl.any + irdl.attributes { + "sym_name" = %sym, + "label" = %label, + "optional" = %optional + } + irdl.operands(input: %state, artifact: %artifact) + irdl.results(output: %state) + } + irdl.operation @transcript_absorb_bytes { + %state = irdl.parametric @cpu::@transcript_state<> + %sym = irdl.any + %label = irdl.any + %payload = irdl.any + irdl.attributes { + "sym_name" = %sym, + "label" = %label, + "payload" = %payload + } + irdl.operands(input: %state) + irdl.results(output: %state) + } + irdl.operation @transcript_squeeze { + %state = irdl.parametric @cpu::@transcript_state<> + %challenge = irdl.any + %sym = irdl.any + %label = irdl.any + %kind = irdl.any + %count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "label" = %label, + "kind" = %kind, + "count" = %count + } + irdl.operands(input: %state) + irdl.results(output: %state, challenge: %challenge) + } + irdl.operation @opening_input { + %point = irdl.parametric @cpu::@point<> + %eval = irdl.parametric @cpu::@field_value<> + %claim = irdl.parametric @cpu::@opening_claim_type<> + %sym = irdl.any + %source_stage = irdl.any + %source_claim = irdl.any + %oracle = irdl.any + %domain = irdl.any + %point_arity = irdl.any + %claim_kind = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source_stage" = %source_stage, + "source_claim" = %source_claim, + "oracle" = %oracle, + "domain" = %domain, + "point_arity" = %point_arity, + "claim_kind" = %claim_kind + } + irdl.results(point: %point, eval: %eval, claim: %claim) + } + irdl.operation @point_slice { + %input = irdl.parametric @cpu::@point<> + %output = irdl.parametric @cpu::@point<> + %sym = irdl.any + %source = irdl.any + %offset = irdl.any + %length = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "offset" = %offset, + "length" = %length + } + irdl.operands(input: %input) + irdl.results(output: %output) + } + irdl.operation @point_zero { + %output = irdl.parametric @cpu::@point<> + %sym = irdl.any + %field = irdl.any + %arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "field" = %field, + "arity" = %arity + } + irdl.results(output: %output) + } + irdl.operation @point_concat { + %input = irdl.parametric @cpu::@point<> + %output = irdl.parametric @cpu::@point<> + %sym = irdl.any + %layout = irdl.any + %arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "layout" = %layout, + "arity" = %arity + } + irdl.operands(inputs: variadic %input) + irdl.results(output: %output) + } + irdl.operation @field_const { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + %field = irdl.any + %value = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field, "value" = %value} + irdl.results(value: %value_type) + } + irdl.operation @field_zero { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + %field = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field} + irdl.results(value: %value_type) + } + irdl.operation @field_one { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + %field = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field} + irdl.results(value: %value_type) + } + irdl.operation @field_add { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %value_type, rhs: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @field_sub { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %value_type, rhs: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @field_neg { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(input: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @field_mul { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %value_type, rhs: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @field_pow { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + %exponent = irdl.any + irdl.attributes {"sym_name" = %sym, "exponent" = %exponent} + irdl.operands(input: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @poly_lagrange_basis_eval { + %value_type = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + %domain_start = irdl.any + %domain_size = irdl.any + %index = irdl.any + irdl.attributes { + "sym_name" = %sym, + "domain_start" = %domain_start, + "domain_size" = %domain_size, + "index" = %index + } + irdl.operands(point: %value_type) + irdl.results(value: %value_type) + } + irdl.operation @sumcheck_claim { + %input_claim = irdl.parametric @cpu::@field_value<> + %opening_claim = irdl.parametric @cpu::@opening_claim_type<> + %claim_type = irdl.parametric @cpu::@sumcheck_claim_type<> + %sym = irdl.any + %stage = irdl.any + %domain = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + %claim = irdl.any + %kernel = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "domain" = %domain, + "num_rounds" = %num_rounds, + "degree" = %degree, + "claim" = %claim, + "kernel" = %kernel + } + irdl.operands(input_claim: %input_claim, inputs: variadic %opening_claim) + irdl.results(claim: %claim_type) + } + irdl.operation @sumcheck_verify_claim { + %input_claim = irdl.parametric @cpu::@field_value<> + %opening_claim = irdl.parametric @cpu::@opening_claim_type<> + %claim_type = irdl.parametric @cpu::@sumcheck_claim_type<> + %sym = irdl.any + %stage = irdl.any + %domain = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + %claim = irdl.any + %relation = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "domain" = %domain, + "num_rounds" = %num_rounds, + "degree" = %degree, + "claim" = %claim, + "relation" = %relation + } + irdl.operands(input_claim: %input_claim, inputs: variadic %opening_claim) + irdl.results(claim: %claim_type) + } + irdl.operation @sumcheck_batch { + %claim_type = irdl.parametric @cpu::@sumcheck_claim_type<> + %batch_type = irdl.parametric @cpu::@sumcheck_batch_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %round_schedule = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims, + "claim_label" = %claim_label, + "round_label" = %round_label, + "round_schedule" = %round_schedule + } + irdl.operands(claims: variadic %claim_type) + irdl.results(batch: %batch_type) + } + irdl.operation @sumcheck_driver { + %state = irdl.parametric @cpu::@transcript_state<> + %batch_type = irdl.parametric @cpu::@sumcheck_batch_type<> + %point = irdl.parametric @cpu::@point<> + %result = irdl.parametric @cpu::@sumcheck_result_type<> + %proof = irdl.parametric @cpu::@sumcheck_proof_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %kernel = irdl.any + %policy = irdl.any + %round_schedule = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "kernel" = %kernel, + "policy" = %policy, + "round_schedule" = %round_schedule, + "claim_label" = %claim_label, + "round_label" = %round_label, + "num_rounds" = %num_rounds, + "degree" = %degree + } + irdl.operands(input: %state, batch: %batch_type) + irdl.results(output: %state, point: %point, result: %result, proof: %proof) + } + irdl.operation @sumcheck_verify { + %state = irdl.parametric @cpu::@transcript_state<> + %batch_type = irdl.parametric @cpu::@sumcheck_batch_type<> + %point = irdl.parametric @cpu::@point<> + %result = irdl.parametric @cpu::@sumcheck_result_type<> + %proof = irdl.parametric @cpu::@sumcheck_proof_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %relation = irdl.any + %policy = irdl.any + %round_schedule = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "relation" = %relation, + "policy" = %policy, + "round_schedule" = %round_schedule, + "claim_label" = %claim_label, + "round_label" = %round_label, + "num_rounds" = %num_rounds, + "degree" = %degree + } + irdl.operands(input: %state, batch: %batch_type) + irdl.results(output: %state, point: %point, result: %result, proof: %proof) + } + irdl.operation @sumcheck_eval { + %result = irdl.parametric @cpu::@sumcheck_result_type<> + %eval = irdl.parametric @cpu::@field_value<> + %sym = irdl.any + %source = irdl.any + %name = irdl.any + %index = irdl.any + %oracle = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "name" = %name, + "index" = %index, + "oracle" = %oracle + } + irdl.operands(result: %result) + irdl.results(eval: %eval) + } + irdl.operation @sumcheck_instance_result { + %input_point = irdl.parametric @cpu::@point<> + %output_point = irdl.parametric @cpu::@point<> + %input_result = irdl.parametric @cpu::@sumcheck_result_type<> + %output_result = irdl.parametric @cpu::@sumcheck_result_type<> + %sym = irdl.any + %source = irdl.any + %claim = irdl.any + %relation = irdl.any + %index = irdl.any + %point_arity = irdl.any + %num_rounds = irdl.any + %round_offset = irdl.any + %point_order = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "claim" = %claim, + "relation" = %relation, + "index" = %index, + "point_arity" = %point_arity, + "num_rounds" = %num_rounds, + "round_offset" = %round_offset, + "point_order" = %point_order, + "degree" = %degree + } + irdl.operands(input_point: %input_point, input_result: %input_result) + irdl.results(instance_point: %output_point, instance_result: %output_result) + } + irdl.operation @opening_claim { + %point = irdl.parametric @cpu::@point<> + %eval = irdl.parametric @cpu::@field_value<> + %claim = irdl.parametric @cpu::@opening_claim_type<> + %sym = irdl.any + %oracle = irdl.any + %domain = irdl.any + %point_arity = irdl.any + %claim_kind = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "domain" = %domain, + "point_arity" = %point_arity, + "claim_kind" = %claim_kind + } + irdl.operands(point: %point, eval: %eval) + irdl.results(claim: %claim) + } + irdl.operation @opening_claim_equal { + %claim = irdl.parametric @cpu::@opening_claim_type<> + %sym = irdl.any + %mode = irdl.any + irdl.attributes { + "sym_name" = %sym, + "mode" = %mode + } + irdl.operands(left: %claim, right: %claim) + } + irdl.operation @opening_batch { + %claim = irdl.parametric @cpu::@opening_claim_type<> + %batch = irdl.parametric @cpu::@opening_batch_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims + } + irdl.operands(claims: variadic %claim) + irdl.results(batch: %batch) + } + irdl.operation @pcs_opening_claim { + %point = irdl.parametric @cpu::@point<> + %eval = irdl.parametric @cpu::@field_value<> + %claim = irdl.parametric @cpu::@opening_claim_type<> + %sym = irdl.any + %oracle = irdl.any + %family = irdl.any + %domain = irdl.any + %point_arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "family" = %family, + "domain" = %domain, + "point_arity" = %point_arity + } + irdl.operands(point: %point, eval: %eval) + irdl.results(claim: %claim) + } + irdl.operation @pcs_opening_batch { + %claim = irdl.parametric @cpu::@opening_claim_type<> + %batch = irdl.parametric @cpu::@opening_batch_type<> + %sym = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + irdl.attributes { + "sym_name" = %sym, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims + } + irdl.operands(claims: variadic %claim) + irdl.results(batch: %batch) + } + irdl.operation @pcs_batch_open { + %state = irdl.parametric @cpu::@transcript_state<> + %batch = irdl.parametric @cpu::@opening_batch_type<> + %proof = irdl.parametric @cpu::@opening_proof_type<> + %sym = irdl.any + %pcs = irdl.any + %proof_slot = irdl.any + %transcript_label = irdl.any + irdl.attributes { + "sym_name" = %sym, + "pcs" = %pcs, + "proof_slot" = %proof_slot, + "transcript_label" = %transcript_label + } + irdl.operands(input: %state, batch: %batch) + irdl.results(output: %state, proof: %proof) + } + irdl.operation @pcs_batch_verify { + %state = irdl.parametric @cpu::@transcript_state<> + %batch = irdl.parametric @cpu::@opening_batch_type<> + %proof = irdl.parametric @cpu::@opening_proof_type<> + %sym = irdl.any + %pcs = irdl.any + %proof_slot = irdl.any + %transcript_label = irdl.any + irdl.attributes { + "sym_name" = %sym, + "pcs" = %pcs, + "proof_slot" = %proof_slot, + "transcript_label" = %transcript_label + } + irdl.operands(input: %state, batch: %batch) + irdl.results(output: %state, proof: %proof) + } +} diff --git a/crates/bolt/irdl/field.mlir b/crates/bolt/irdl/field.mlir new file mode 100644 index 0000000000..2f14d2b061 --- /dev/null +++ b/crates/bolt/irdl/field.mlir @@ -0,0 +1,67 @@ +irdl.dialect @field { + irdl.type @scalar + irdl.operation @define { + %sym = irdl.any + %modulus_bits = irdl.any + %role = irdl.any + irdl.attributes {"sym_name" = %sym, "modulus_bits" = %modulus_bits, "role" = %role} + } + irdl.operation @const { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + %field = irdl.any + %value = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field, "value" = %value} + irdl.results(value: %scalar) + } + irdl.operation @zero { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + %field = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field} + irdl.results(value: %scalar) + } + irdl.operation @one { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + %field = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field} + irdl.results(value: %scalar) + } + irdl.operation @add { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %scalar, rhs: %scalar) + irdl.results(value: %scalar) + } + irdl.operation @sub { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %scalar, rhs: %scalar) + irdl.results(value: %scalar) + } + irdl.operation @neg { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(input: %scalar) + irdl.results(value: %scalar) + } + irdl.operation @mul { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + irdl.attributes {"sym_name" = %sym} + irdl.operands(lhs: %scalar, rhs: %scalar) + irdl.results(value: %scalar) + } + irdl.operation @pow { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + %exponent = irdl.any + irdl.attributes {"sym_name" = %sym, "exponent" = %exponent} + irdl.operands(input: %scalar) + irdl.results(value: %scalar) + } +} diff --git a/crates/bolt/irdl/hash.mlir b/crates/bolt/irdl/hash.mlir new file mode 100644 index 0000000000..66b2f984fa --- /dev/null +++ b/crates/bolt/irdl/hash.mlir @@ -0,0 +1,7 @@ +irdl.dialect @hash { + irdl.operation @function { + %sym = irdl.any + %algorithm = irdl.any + irdl.attributes {"sym_name" = %sym, "algorithm" = %algorithm} + } +} diff --git a/crates/bolt/irdl/party.mlir b/crates/bolt/irdl/party.mlir new file mode 100644 index 0000000000..46dfb26d30 --- /dev/null +++ b/crates/bolt/irdl/party.mlir @@ -0,0 +1,8 @@ +irdl.dialect @party { + irdl.operation @function { + %sym = irdl.any + %source = irdl.any + %role = irdl.any + irdl.attributes {"sym_name" = %sym, "source" = %source, "role" = %role} + } +} diff --git a/crates/bolt/irdl/pcs.mlir b/crates/bolt/irdl/pcs.mlir new file mode 100644 index 0000000000..a5fff00470 --- /dev/null +++ b/crates/bolt/irdl/pcs.mlir @@ -0,0 +1,89 @@ +irdl.dialect @pcs { + irdl.type @scheme_type + irdl.type @opening_claim_type + irdl.type @opening_batch_type + irdl.type @opening_proof_type + irdl.operation @scheme { + %sym = irdl.any + %field = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field} + } + irdl.operation @commit_batch { + %artifact = irdl.parametric @commit::@artifact<> + %sym = irdl.any + %scheme = irdl.any + irdl.attributes {"sym_name" = %sym, "scheme" = %scheme} + irdl.operands(commitment: %artifact) + } + irdl.operation @opening_claim { + %point = irdl.parametric @poly::@point<> + %eval = irdl.parametric @field::@scalar<> + %claim = irdl.parametric @pcs::@opening_claim_type<> + %sym = irdl.any + %oracle = irdl.any + %family = irdl.any + %domain = irdl.any + %point_arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "family" = %family, + "domain" = %domain, + "point_arity" = %point_arity + } + irdl.operands(point: %point, eval: %eval) + irdl.results(claim: %claim) + } + irdl.operation @opening_batch { + %claim = irdl.parametric @pcs::@opening_claim_type<> + %batch = irdl.parametric @pcs::@opening_batch_type<> + %sym = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + irdl.attributes { + "sym_name" = %sym, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims + } + irdl.operands(claims: variadic %claim) + irdl.results(batch: %batch) + } + irdl.operation @batch_open { + %state = irdl.parametric @transcript::@state_type<> + %batch = irdl.parametric @pcs::@opening_batch_type<> + %proof = irdl.parametric @pcs::@opening_proof_type<> + %sym = irdl.any + %pcs = irdl.any + %proof_slot = irdl.any + %transcript_label = irdl.any + irdl.attributes { + "sym_name" = %sym, + "pcs" = %pcs, + "proof_slot" = %proof_slot, + "transcript_label" = %transcript_label + } + irdl.operands(input: %state, batch: %batch) + irdl.results(output: %state, proof: %proof) + } + irdl.operation @batch_verify { + %state = irdl.parametric @transcript::@state_type<> + %batch = irdl.parametric @pcs::@opening_batch_type<> + %proof = irdl.parametric @pcs::@opening_proof_type<> + %sym = irdl.any + %pcs = irdl.any + %proof_slot = irdl.any + %transcript_label = irdl.any + irdl.attributes { + "sym_name" = %sym, + "pcs" = %pcs, + "proof_slot" = %proof_slot, + "transcript_label" = %transcript_label + } + irdl.operands(input: %state, batch: %batch) + irdl.results(output: %state, proof: %proof) + } +} diff --git a/crates/bolt/irdl/piop.mlir b/crates/bolt/irdl/piop.mlir new file mode 100644 index 0000000000..169bc5c9ab --- /dev/null +++ b/crates/bolt/irdl/piop.mlir @@ -0,0 +1,270 @@ +irdl.dialect @piop { + irdl.type @stage_type + irdl.type @sumcheck_claim_type + irdl.type @sumcheck_batch_type + irdl.type @sumcheck_result_type + irdl.type @sumcheck_proof_type + irdl.type @opening_claim_type + irdl.type @opening_batch_type + + irdl.operation @oracle { + %sym = irdl.any + %field = irdl.any + %domain = irdl.any + %commit_domain = irdl.any + %visibility = irdl.any + %layout = irdl.any + irdl.attributes { + "sym_name" = %sym, + "field" = %field, + "domain" = %domain, + "commit_domain" = %commit_domain, + "visibility" = %visibility, + "layout" = %layout + } + } + irdl.operation @oracle_family { + %sym = irdl.any + %ordered_oracles = irdl.any + %visibility = irdl.any + %count = irdl.any + %domain = irdl.any + irdl.attributes { + "sym_name" = %sym, + "ordered_oracles" = %ordered_oracles, + "visibility" = %visibility, + "count" = %count, + "domain" = %domain + } + } + irdl.operation @stage { + %stage_type = irdl.parametric @piop::@stage_type<> + %sym = irdl.any + %name = irdl.any + %order = irdl.any + %roles = irdl.any + irdl.attributes { + "sym_name" = %sym, + "name" = %name, + "order" = %order, + "roles" = %roles + } + irdl.results(stage: %stage_type) + } + irdl.operation @relation { + %sym = irdl.any + %kind = irdl.any + %domain = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + %output_count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "kind" = %kind, + "domain" = %domain, + "num_rounds" = %num_rounds, + "degree" = %degree, + "output_count" = %output_count + } + } + irdl.operation @sumcheck_claim { + %input_claim = irdl.parametric @field::@scalar<> + %opening_claim = irdl.parametric @piop::@opening_claim_type<> + %claim_type = irdl.parametric @piop::@sumcheck_claim_type<> + %sym = irdl.any + %stage = irdl.any + %domain = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + %claim = irdl.any + %relation = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "domain" = %domain, + "num_rounds" = %num_rounds, + "degree" = %degree, + "claim" = %claim, + "relation" = %relation + } + irdl.operands(input_claim: %input_claim, inputs: variadic %opening_claim) + irdl.results(claim: %claim_type) + } + irdl.operation @opening_input { + %point = irdl.parametric @poly::@point<> + %eval = irdl.parametric @field::@scalar<> + %claim = irdl.parametric @piop::@opening_claim_type<> + %sym = irdl.any + %source_stage = irdl.any + %source_claim = irdl.any + %oracle = irdl.any + %domain = irdl.any + %point_arity = irdl.any + %claim_kind = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source_stage" = %source_stage, + "source_claim" = %source_claim, + "oracle" = %oracle, + "domain" = %domain, + "point_arity" = %point_arity, + "claim_kind" = %claim_kind + } + irdl.results(point: %point, eval: %eval, claim: %claim) + } + irdl.operation @sumcheck_batch { + %stage_type = irdl.parametric @piop::@stage_type<> + %claim_type = irdl.parametric @piop::@sumcheck_claim_type<> + %batch_type = irdl.parametric @piop::@sumcheck_batch_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %round_schedule = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims, + "claim_label" = %claim_label, + "round_label" = %round_label, + "round_schedule" = %round_schedule + } + irdl.operands(stage: %stage_type, claims: variadic %claim_type) + irdl.results(batch: %batch_type) + } + irdl.operation @sumcheck { + %state = irdl.parametric @transcript::@state_type<> + %batch_type = irdl.parametric @piop::@sumcheck_batch_type<> + %point = irdl.parametric @poly::@point<> + %result = irdl.parametric @piop::@sumcheck_result_type<> + %proof = irdl.parametric @piop::@sumcheck_proof_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %relation = irdl.any + %policy = irdl.any + %round_schedule = irdl.any + %claim_label = irdl.any + %round_label = irdl.any + %num_rounds = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "relation" = %relation, + "policy" = %policy, + "round_schedule" = %round_schedule, + "claim_label" = %claim_label, + "round_label" = %round_label, + "num_rounds" = %num_rounds, + "degree" = %degree + } + irdl.operands(input: %state, batch: %batch_type) + irdl.results(output: %state, point: %point, result: %result, proof: %proof) + } + irdl.operation @sumcheck_eval { + %result = irdl.parametric @piop::@sumcheck_result_type<> + %eval = irdl.parametric @field::@scalar<> + %sym = irdl.any + %source = irdl.any + %name = irdl.any + %index = irdl.any + %oracle = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "name" = %name, + "index" = %index, + "oracle" = %oracle + } + irdl.operands(result: %result) + irdl.results(eval: %eval) + } + irdl.operation @sumcheck_instance_result { + %input_point = irdl.parametric @poly::@point<> + %output_point = irdl.parametric @poly::@point<> + %input_result = irdl.parametric @piop::@sumcheck_result_type<> + %output_result = irdl.parametric @piop::@sumcheck_result_type<> + %sym = irdl.any + %source = irdl.any + %claim = irdl.any + %relation = irdl.any + %index = irdl.any + %point_arity = irdl.any + %num_rounds = irdl.any + %round_offset = irdl.any + %point_order = irdl.any + %degree = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "claim" = %claim, + "relation" = %relation, + "index" = %index, + "point_arity" = %point_arity, + "num_rounds" = %num_rounds, + "round_offset" = %round_offset, + "point_order" = %point_order, + "degree" = %degree + } + irdl.operands(input_point: %input_point, input_result: %input_result) + irdl.results(instance_point: %output_point, instance_result: %output_result) + } + irdl.operation @opening_claim { + %point = irdl.parametric @poly::@point<> + %eval = irdl.parametric @field::@scalar<> + %claim = irdl.parametric @piop::@opening_claim_type<> + %sym = irdl.any + %oracle = irdl.any + %domain = irdl.any + %point_arity = irdl.any + %claim_kind = irdl.any + irdl.attributes { + "sym_name" = %sym, + "oracle" = %oracle, + "domain" = %domain, + "point_arity" = %point_arity, + "claim_kind" = %claim_kind + } + irdl.operands(point: %point, eval: %eval) + irdl.results(claim: %claim) + } + irdl.operation @opening_claim_equal { + %claim = irdl.parametric @piop::@opening_claim_type<> + %sym = irdl.any + %mode = irdl.any + irdl.attributes { + "sym_name" = %sym, + "mode" = %mode + } + irdl.operands(left: %claim, right: %claim) + } + irdl.operation @opening_batch { + %claim = irdl.parametric @piop::@opening_claim_type<> + %batch = irdl.parametric @piop::@opening_batch_type<> + %sym = irdl.any + %stage = irdl.any + %proof_slot = irdl.any + %policy = irdl.any + %count = irdl.any + %ordered_claims = irdl.any + irdl.attributes { + "sym_name" = %sym, + "stage" = %stage, + "proof_slot" = %proof_slot, + "policy" = %policy, + "count" = %count, + "ordered_claims" = %ordered_claims + } + irdl.operands(claims: variadic %claim) + irdl.results(batch: %batch) + } +} diff --git a/crates/bolt/irdl/poly.mlir b/crates/bolt/irdl/poly.mlir new file mode 100644 index 0000000000..6ad0ea5951 --- /dev/null +++ b/crates/bolt/irdl/poly.mlir @@ -0,0 +1,68 @@ +irdl.dialect @poly { + irdl.type @domain_type + irdl.type @oracle + irdl.type @point + irdl.operation @domain { + %sym = irdl.any + %field = irdl.any + %log_size = irdl.any + irdl.attributes {"sym_name" = %sym, "field" = %field, "log_size" = %log_size} + } + irdl.operation @point_slice { + %input = irdl.parametric @poly::@point<> + %output = irdl.parametric @poly::@point<> + %sym = irdl.any + %source = irdl.any + %offset = irdl.any + %length = irdl.any + irdl.attributes { + "sym_name" = %sym, + "source" = %source, + "offset" = %offset, + "length" = %length + } + irdl.operands(input: %input) + irdl.results(output: %output) + } + irdl.operation @point_zero { + %output = irdl.parametric @poly::@point<> + %sym = irdl.any + %field = irdl.any + %arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "field" = %field, + "arity" = %arity + } + irdl.results(output: %output) + } + irdl.operation @point_concat { + %input = irdl.parametric @poly::@point<> + %output = irdl.parametric @poly::@point<> + %sym = irdl.any + %layout = irdl.any + %arity = irdl.any + irdl.attributes { + "sym_name" = %sym, + "layout" = %layout, + "arity" = %arity + } + irdl.operands(inputs: variadic %input) + irdl.results(output: %output) + } + irdl.operation @lagrange_basis_eval { + %scalar = irdl.parametric @field::@scalar<> + %sym = irdl.any + %domain_start = irdl.any + %domain_size = irdl.any + %index = irdl.any + irdl.attributes { + "sym_name" = %sym, + "domain_start" = %domain_start, + "domain_size" = %domain_size, + "index" = %index + } + irdl.operands(point: %scalar) + irdl.results(value: %scalar) + } +} diff --git a/crates/bolt/irdl/protocol.mlir b/crates/bolt/irdl/protocol.mlir new file mode 100644 index 0000000000..f17629f4b9 --- /dev/null +++ b/crates/bolt/irdl/protocol.mlir @@ -0,0 +1,19 @@ +irdl.dialect @protocol { + irdl.operation @params { + %sym = irdl.any + %field = irdl.any + %pcs = irdl.any + %transcript = irdl.any + irdl.attributes { + "sym_name" = %sym, + "field" = %field, + "pcs" = %pcs, + "transcript" = %transcript + } + } + irdl.operation @boundary { + %sym = irdl.any + %roles = irdl.any + irdl.attributes {"sym_name" = %sym, "roles" = %roles} + } +} diff --git a/crates/bolt/irdl/transcript.mlir b/crates/bolt/irdl/transcript.mlir new file mode 100644 index 0000000000..7e2dd926bf --- /dev/null +++ b/crates/bolt/irdl/transcript.mlir @@ -0,0 +1,62 @@ +irdl.dialect @transcript { + irdl.type @state_type + irdl.operation @scheme { + %sym = irdl.any + %hash = irdl.any + irdl.attributes {"sym_name" = %sym, "hash" = %hash} + } + irdl.operation @state { + %state = irdl.parametric @transcript::@state_type<> + %sym = irdl.any + %scheme = irdl.any + irdl.attributes {"sym_name" = %sym, "scheme" = %scheme} + irdl.results(state: %state) + } + irdl.operation @absorb { + %state = irdl.parametric @transcript::@state_type<> + %artifact = irdl.parametric @commit::@artifact<> + %sym = irdl.any + %label = irdl.any + irdl.attributes {"sym_name" = %sym, "label" = %label} + irdl.operands(input: %state, artifact: %artifact) + irdl.results(output: %state) + } + irdl.operation @absorb_optional { + %state = irdl.parametric @transcript::@state_type<> + %artifact = irdl.parametric @commit::@artifact<> + %sym = irdl.any + %label = irdl.any + irdl.attributes {"sym_name" = %sym, "label" = %label} + irdl.operands(input: %state, artifact: %artifact) + irdl.results(output: %state) + } + irdl.operation @absorb_bytes { + %state = irdl.parametric @transcript::@state_type<> + %sym = irdl.any + %label = irdl.any + %payload = irdl.any + irdl.attributes { + "sym_name" = %sym, + "label" = %label, + "payload" = %payload + } + irdl.operands(input: %state) + irdl.results(output: %state) + } + irdl.operation @squeeze { + %state = irdl.parametric @transcript::@state_type<> + %challenge = irdl.any + %sym = irdl.any + %label = irdl.any + %kind = irdl.any + %count = irdl.any + irdl.attributes { + "sym_name" = %sym, + "label" = %label, + "kind" = %kind, + "count" = %count + } + irdl.operands(input: %state) + irdl.results(output: %state, challenge: %challenge) + } +} diff --git a/crates/bolt/src/dialects.rs b/crates/bolt/src/dialects.rs new file mode 100644 index 0000000000..655ae7c43f --- /dev/null +++ b/crates/bolt/src/dialects.rs @@ -0,0 +1,39 @@ +use melior::ir::Module; +use melior::utility::load_irdl_dialects; +use melior::Context; + +pub const BOLT_IRDL: &str = concat!( + "module {\n", + include_str!("../irdl/field.mlir"), + "\n", + include_str!("../irdl/poly.mlir"), + "\n", + include_str!("../irdl/hash.mlir"), + "\n", + include_str!("../irdl/transcript.mlir"), + "\n", + include_str!("../irdl/commit.mlir"), + "\n", + include_str!("../irdl/pcs.mlir"), + "\n", + include_str!("../irdl/protocol.mlir"), + "\n", + include_str!("../irdl/piop.mlir"), + "\n", + include_str!("../irdl/party.mlir"), + "\n", + include_str!("../irdl/compute.mlir"), + "\n", + include_str!("../irdl/cpu.mlir"), + "\n}\n" +); + +pub fn load_bolt_dialects(context: &Context) -> Result<(), String> { + let module = Module::parse(context, BOLT_IRDL) + .ok_or_else(|| "failed to parse Bolt IRDL dialect definitions".to_owned())?; + if load_irdl_dialects(&module) { + Ok(()) + } else { + Err("failed to load Bolt IRDL dialect definitions".to_owned()) + } +} diff --git a/crates/bolt/src/emit/mod.rs b/crates/bolt/src/emit/mod.rs new file mode 100644 index 0000000000..0ad9e7d3d7 --- /dev/null +++ b/crates/bolt/src/emit/mod.rs @@ -0,0 +1 @@ +pub mod rust; diff --git a/crates/bolt/src/emit/rust/artifacts.rs b/crates/bolt/src/emit/rust/artifacts.rs new file mode 100644 index 0000000000..ec2d2a6c7f --- /dev/null +++ b/crates/bolt/src/emit/rust/artifacts.rs @@ -0,0 +1,1653 @@ +#![expect( + clippy::format_push_string, + reason = "Rust artifact emission assembles generated source text from format templates" +)] + +use std::path::{Component, Path}; + +use crate::ir::Role; + +use super::{EmitError, RustSourceFile}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProtocolArtifactConfig { + pub protocol_name: String, + pub type_prefix: String, + pub transcript_label: String, + pub repository: Option, + pub prover_crate_name: String, + pub verifier_crate_name: String, + pub crates_io_patches: Vec, + pub standalone_dependency_overrides: Vec, + pub common_dependencies: Vec, + pub prover_dependencies: Vec, + pub verifier_dependencies: Vec, + pub instrumentation_prefix: Option, + pub prover_forbidden_imports: Vec, + pub verifier_forbidden_imports: Vec, + pub kernel_crate: Option, + pub field_type: RustTypeRef, + pub default_transcript_type: RustTypeRef, + pub transcript_trait: RustTypeRef, + pub commitment_type: RustTypeRef, + pub prover_setup_type: RustTypeRef, + pub role_api_extension: Option, + pub verifier_runtime_modules: Vec, + pub verifier_named_eval_type: RustTypeRef, + pub verifier_sumcheck_output_type: RustTypeRef, + pub verifier_stage_proof_type: RustTypeRef, +} + +impl ProtocolArtifactConfig { + fn protocol_snake(&self) -> String { + snake_case(&self.protocol_name) + } + + fn crate_name(&self, role: &Role) -> &str { + match role { + Role::Prover => &self.prover_crate_name, + Role::Verifier => &self.verifier_crate_name, + } + } + + fn dependencies(&self, role: &Role) -> Vec { + let mut dependencies = self.common_dependencies.clone(); + match role { + Role::Prover => { + dependencies.extend(self.prover_dependencies.clone()); + if !dependencies.contains(&self.verifier_crate_name) { + dependencies.push(self.verifier_crate_name.clone()); + } + } + Role::Verifier => dependencies.extend(self.verifier_dependencies.clone()), + } + dependencies.sort(); + dependencies.dedup(); + dependencies + } + + fn forbidden_imports(&self, role: &Role) -> &[String] { + match role { + Role::Prover => &self.prover_forbidden_imports, + Role::Verifier => &self.verifier_forbidden_imports, + } + } + + fn verifier_crate_import(&self) -> String { + rust_crate_ident(&self.verifier_crate_name) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProtocolStandaloneDependency { + pub package: String, + pub manifest_entry: String, +} + +impl ProtocolStandaloneDependency { + pub fn new(package: impl Into, manifest_entry: impl Into) -> Self { + Self { + package: package.into(), + manifest_entry: manifest_entry.into(), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProtocolArtifactExtension { + pub required_commitment: bool, + pub required_proof_stages: Vec, + pub required_artifact_stages: Vec, + pub prover: ProtocolProverApiExtension, + pub verifier: ProtocolVerifierApiExtension, +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct ProtocolProverApiExtension { + pub lib_module: String, + pub imports: String, + pub input_fields: String, + pub program_fields: String, + pub default_program_fields: String, + pub error_variants: String, + pub error_items: String, + pub error_conversions: String, + pub after_stage_execution: String, + pub proof_fields: String, + pub helper_items: String, +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct ProtocolVerifierApiExtension { + pub lib_module: String, + pub imports: String, + pub proof_fields: String, + pub proof_items: String, + pub inputs_derive: Option, + pub input_fields: String, + pub program_fields: String, + pub default_program_fields: String, + pub error_variants: String, + pub error_items: String, + pub error_conversions: String, + pub after_default_verify: String, + pub with_programs_body_intro: String, + pub stage_verification_override: String, + pub after_stage_verification: String, + pub helper_items: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProtocolCrateRef { + pub package: String, + pub import: String, +} + +impl ProtocolCrateRef { + pub fn new(package: impl Into, import: impl Into) -> Self { + Self { + package: package.into(), + import: import.into(), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RustTypeRef { + pub path: String, +} + +impl RustTypeRef { + pub fn new(path: impl Into) -> Self { + Self { path: path.into() } + } + + fn ident(&self) -> &str { + self.path.rsplit("::").next().unwrap_or(&self.path) + } + + fn use_line(&self) -> String { + format!("use {};\n", self.path) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ProtocolStageKind { + Commitment, + Proof, + Evaluation, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProtocolStage { + name: String, + module_name: String, + ordinal: usize, + kind: ProtocolStageKind, +} + +impl ProtocolStage { + pub fn new( + name: impl Into, + module_name: impl Into, + ordinal: usize, + kind: ProtocolStageKind, + ) -> Self { + Self { + name: name.into(), + module_name: module_name.into(), + ordinal, + kind, + } + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn module_name(&self) -> &str { + &self.module_name + } + + pub fn order(&self) -> usize { + self.ordinal + } + + pub fn is_commitment(&self) -> bool { + self.kind == ProtocolStageKind::Commitment + } + + pub fn is_proof(&self) -> bool { + self.kind == ProtocolStageKind::Proof + } + + pub fn is_evaluation(&self) -> bool { + self.kind == ProtocolStageKind::Evaluation + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ArtifactCrateRole { + Prover, + Verifier, +} + +impl ArtifactCrateRole { + pub fn for_role(role: &Role) -> Self { + match role { + Role::Prover => Self::Prover, + Role::Verifier => Self::Verifier, + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProtocolRustArtifact { + pub role: Role, + pub stage: ProtocolStage, + pub crate_name: String, + pub path: String, + pub source: RustSourceFile, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GeneratedCrate { + pub crate_name: String, + pub files: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GeneratedFile { + pub path: String, + pub source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProtocolRuntimeModule { + pub module_name: String, + pub file: GeneratedFile, +} + +impl GeneratedCrate { + pub fn write_to(&self, output_root: impl AsRef) -> Result<(), EmitError> { + let crate_root = output_root.as_ref().join(&self.crate_name); + for file in &self.files { + let path = generated_file_path(&crate_root, &file.path)?; + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|error| { + EmitError::new(format!( + "failed to create generated crate directory `{}`: {error}", + parent.display() + )) + })?; + } + std::fs::write(&path, &file.source).map_err(|error| { + EmitError::new(format!( + "failed to write generated crate file `{}`: {error}", + path.display() + )) + })?; + } + Ok(()) + } +} + +pub fn write_generated_crates( + generated_crates: &[GeneratedCrate], + output_root: impl AsRef, +) -> Result<(), EmitError> { + for generated_crate in generated_crates { + generated_crate.write_to(output_root.as_ref())?; + } + Ok(()) +} + +pub fn protocol_rust_artifact( + config: &ProtocolArtifactConfig, + stage: ProtocolStage, + role: Role, + source: RustSourceFile, +) -> ProtocolRustArtifact { + let crate_name = config.crate_name(&role).to_owned(); + let path = format!("{crate_name}/src/stages/{}.rs", stage.module_name()); + ProtocolRustArtifact { + role, + stage, + crate_name, + path, + source, + } +} + +pub fn validate_rust_artifact_imports( + config: &ProtocolArtifactConfig, + artifact: &ProtocolRustArtifact, +) -> Result<(), EmitError> { + for import in config.forbidden_imports(&artifact.role) { + if artifact.source.source.contains(import) { + return Err(EmitError::new(format!( + "{} artifact `{}` for {} imports forbidden `{import}`", + artifact.crate_name, + artifact.path, + artifact.stage.name() + ))); + } + } + Ok(()) +} + +pub fn assemble_generated_crates( + config: &ProtocolArtifactConfig, + artifacts: Vec, + dependency_root: &str, +) -> Result, EmitError> { + assemble_generated_crates_with_manifest( + config, + artifacts, + ManifestMode::Standalone { dependency_root }, + ) +} + +pub fn assemble_workspace_generated_crates( + config: &ProtocolArtifactConfig, + artifacts: Vec, +) -> Result, EmitError> { + assemble_generated_crates_with_manifest(config, artifacts, ManifestMode::Workspace) +} + +fn assemble_generated_crates_with_manifest( + config: &ProtocolArtifactConfig, + artifacts: Vec, + manifest_mode: ManifestMode<'_>, +) -> Result, EmitError> { + let mut prover = Vec::new(); + let mut verifier = Vec::new(); + for artifact in artifacts { + validate_rust_artifact_imports(config, &artifact)?; + match artifact.role { + Role::Prover => prover.push(artifact), + Role::Verifier => verifier.push(artifact), + } + } + Ok(vec![ + generated_crate(config, Role::Prover, prover, manifest_mode), + generated_crate(config, Role::Verifier, verifier, manifest_mode), + ]) +} + +fn generated_crate( + config: &ProtocolArtifactConfig, + role: Role, + mut artifacts: Vec, + manifest_mode: ManifestMode<'_>, +) -> GeneratedCrate { + artifacts.sort_by_key(|artifact| artifact.stage.order()); + let crate_name = config.crate_name(&role).to_owned(); + let mut stage_module_lines = Vec::new(); + if role == Role::Verifier { + stage_module_lines.extend( + config + .verifier_runtime_modules + .iter() + .map(|module| format!("pub mod {};", module.module_name)), + ); + } + stage_module_lines.extend(artifacts.iter().map(|artifact| { + format!( + "#[rustfmt::skip]\npub mod {};", + artifact.stage.module_name() + ) + })); + let stage_modules = stage_module_lines.join("\n"); + let mut files = vec![ + GeneratedFile { + path: "Cargo.toml".to_owned(), + source: generated_manifest(config, &role, manifest_mode), + }, + GeneratedFile { + path: "src/lib.rs".to_owned(), + source: generated_lib(config, &role, &artifacts), + }, + generated_role_api_file(config, &role, &artifacts), + GeneratedFile { + path: "src/stages/mod.rs".to_owned(), + source: format!("{stage_modules}\n"), + }, + ]; + if role == Role::Verifier { + files.extend( + config + .verifier_runtime_modules + .iter() + .map(|module| module.file.clone()), + ); + } + files.extend(artifacts.into_iter().map(|artifact| GeneratedFile { + path: format!("src/stages/{}.rs", artifact.stage.module_name()), + source: artifact.source.source, + })); + GeneratedCrate { crate_name, files } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ManifestMode<'a> { + Standalone { dependency_root: &'a str }, + Workspace, +} + +fn generated_manifest( + config: &ProtocolArtifactConfig, + role: &Role, + manifest_mode: ManifestMode<'_>, +) -> String { + let crate_name = config.crate_name(role); + let dependencies = config.dependencies(role); + match manifest_mode { + ManifestMode::Standalone { dependency_root } => { + let patch_section = if config.crates_io_patches.is_empty() { + String::new() + } else { + format!( + "\n[patch.crates-io]\n{}\n", + config.crates_io_patches.join("\n") + ) + }; + let dependencies = dependencies + .into_iter() + .map(|name| standalone_dependency_entry(config, dependency_root, &name)) + .collect::>() + .join("\n"); + format!( + "[package]\nname = \"{crate_name}\"\nversion = \"0.0.0\"\nedition = \"2021\"\n{patch_section}\n[dependencies]\n{dependencies}\n" + ) + } + ManifestMode::Workspace => { + let dependencies = dependencies + .into_iter() + .map(|name| format!("{name}.workspace = true")) + .collect::>() + .join("\n"); + let role_name = match role { + Role::Prover => "prover", + Role::Verifier => "verifier", + }; + let repository = config + .repository + .as_ref() + .map(|repository| format!("repository = \"{repository}\"\n")) + .unwrap_or_default(); + format!( + "[package]\nname = \"{crate_name}\"\nversion = \"0.0.0\"\nedition = \"2021\"\nlicense = \"MIT OR Apache-2.0\"\ndescription = \"Bolt-generated {} {role_name} role crate\"\n{repository}\n[lints]\nworkspace = true\n\n[dependencies]\n{dependencies}\n", + config.protocol_name + ) + } + } +} + +fn standalone_dependency_entry( + config: &ProtocolArtifactConfig, + dependency_root: &str, + package: &str, +) -> String { + config + .standalone_dependency_overrides + .iter() + .find(|dependency| dependency.package == package) + .map_or_else( + || format!("{package} = {{ path = \"{dependency_root}/{package}\" }}"), + |dependency| dependency.manifest_entry.clone(), + ) +} + +fn generated_lib( + config: &ProtocolArtifactConfig, + role: &Role, + artifacts: &[ProtocolRustArtifact], +) -> String { + let protocol_snake = config.protocol_snake(); + let prefix = &config.type_prefix; + let stage_apis = stage_apis(config, artifacts); + let commitment_api = commitment_api(artifacts); + let extension = + active_role_api_extension(config, &stage_apis, commitment_api.as_ref(), artifacts); + let role_module = match (role, extension) { + (Role::Prover, Some(extension)) => extension.prover.lib_module.clone(), + (Role::Prover, None) => format!( + "#[rustfmt::skip]\npub mod prover;\npub mod stages;\n\npub use prover::{{\n default_prover_programs, prove_{protocol_snake}, prove_{protocol_snake}_with_programs,\n Default{prefix}Transcript, {prefix}ProveError, {prefix}ProverArtifacts, {prefix}ProverInputs,\n {prefix}ProverPrograms,\n}};" + ), + (Role::Verifier, Some(extension)) => extension.verifier.lib_module.clone(), + (Role::Verifier, None) => format!( + "pub mod stages;\n#[rustfmt::skip]\npub mod verifier;\n\npub use verifier::{{\n default_verifier_programs, verify_{protocol_snake}, verify_{protocol_snake}_with_programs, {prefix}NamedEval, {prefix}Proof,\n {prefix}StageProof, {prefix}SumcheckOutput, {prefix}VerificationArtifacts, {prefix}VerifierInputs,\n {prefix}VerifierPrograms, {prefix}VerifyError,\n}};" + ), + }; + let stages = artifacts + .iter() + .map(|artifact| { + format!( + " GeneratedStage {{\n name: \"{}\",\n module: \"{}\",\n ordinal: {},\n }},", + artifact.stage.name(), + artifact.stage.module_name(), + artifact.stage.order() + ) + }) + .collect::>() + .join("\n"); + format!( + "{role_module}\n\npub const TRANSCRIPT_LABEL: &[u8] = {};\n\n#[derive(Clone, Copy, Debug, PartialEq, Eq)]\npub struct GeneratedStage {{\n pub name: &'static str,\n pub module: &'static str,\n pub ordinal: usize,\n}}\n\npub const GENERATED_STAGES: &[GeneratedStage] = &[\n{stages}\n];\n\npub fn generated_stage_names() -> impl Iterator {{\n GENERATED_STAGES.iter().map(|stage| stage.name)\n}}\n", + byte_string_literal(&config.transcript_label) + ) +} + +fn generated_role_api_file( + config: &ProtocolArtifactConfig, + role: &Role, + artifacts: &[ProtocolRustArtifact], +) -> GeneratedFile { + match role { + Role::Prover => GeneratedFile { + path: "src/prover.rs".to_owned(), + source: generated_prover_api(config, artifacts), + }, + Role::Verifier => GeneratedFile { + path: "src/verifier.rs".to_owned(), + source: generated_verifier_api(config, artifacts), + }, + } +} + +#[derive(Clone, Debug)] +struct StageRustApi { + field_name: String, + module_alias: String, + variant_name: String, + output_type: String, + eval_type: String, + artifacts_type: String, + error_type: String, + verifier_fn: Option, + with_program_verifier_fn: Option, + program_type: Option, + program_const: Option, + prover_fn: Option, + with_program_prover_fn: Option, + kernel_module: Option, + opening_input_type: Option, + ram_data_type: Option, + verifier_data_type: Option, +} + +#[derive(Clone, Debug)] +struct CommitmentRustApi { + field_name: String, + module_alias: String, + variant_name: String, + artifacts_type: String, + error_type: String, + verifier_fn: Option, + with_program_verifier_fn: Option, + program_type: Option, + program_const: Option, + prover_fn: Option, + with_program_prover_fn: Option, + input_provider_trait: Option, +} + +fn generated_verifier_api( + config: &ProtocolArtifactConfig, + artifacts: &[ProtocolRustArtifact], +) -> String { + let stages = stage_apis(config, artifacts); + let modules = role_modules(artifacts); + let commitment = commitment_api(artifacts); + let extension = active_role_api_extension(config, &stages, commitment.as_ref(), artifacts); + let prefix = &config.type_prefix; + let protocol_snake = config.protocol_snake(); + let field_type = config.field_type.ident(); + let transcript_trait = config.transcript_trait.ident(); + let commitment_type = config.commitment_type.ident(); + let runtime_named_eval_type = &config.verifier_named_eval_type.path; + let runtime_sumcheck_output_type = &config.verifier_sumcheck_output_type.path; + let runtime_stage_proof_type = &config.verifier_stage_proof_type.path; + let named_eval_type = format!("{prefix}NamedEval"); + let sumcheck_output_type = format!("{prefix}SumcheckOutput"); + let stage_proof_type = format!("{prefix}StageProof"); + let proof_type = format!("{prefix}Proof"); + let verifier_inputs_type = format!("{prefix}VerifierInputs"); + let verifier_programs_type = format!("{prefix}VerifierPrograms"); + let verification_artifacts_type = format!("{prefix}VerificationArtifacts"); + let verify_error_type = format!("{prefix}VerifyError"); + + let mut source = String::new(); + if let Some(extension) = extension { + source.push_str(&extension.verifier.imports); + } else { + if commitment.is_some() { + source.push_str(&config.commitment_type.use_line()); + } + source.push_str(&config.field_type.use_line()); + source.push_str(&config.transcript_trait.use_line()); + } + source.push('\n'); + if !modules.is_empty() { + source.push_str(&format!( + "use crate::stages::{{{}}};\n\n", + aliased_modules(&modules).join(", ") + )); + } + + source.push_str(&format!( + "pub type {named_eval_type} = {runtime_named_eval_type}<{field_type}>;\n\ + pub type {sumcheck_output_type} = {runtime_sumcheck_output_type}<{field_type}>;\n\ + pub type {stage_proof_type} = {runtime_stage_proof_type}<{field_type}>;\n\n", + )); + source.push_str(&format!( + "#[derive(Clone, Debug)]\npub struct {proof_type} {{\n" + )); + if commitment.is_some() { + source.push_str(&format!( + " pub commitments: Vec>,\n" + )); + } + for stage in &stages { + source.push_str(&format!( + " pub {}: {stage_proof_type},\n", + stage.field_name + )); + } + if let Some(extension) = extension { + source.push_str(&extension.verifier.proof_fields); + } + source.push_str("}\n\n"); + + if let Some(extension) = extension { + source.push_str(&extension.verifier.proof_items); + } + + let verifier_inputs_derive = extension + .and_then(|extension| extension.verifier.inputs_derive.as_deref()) + .unwrap_or("#[derive(Clone, Copy, Debug)]"); + source.push_str(&format!( + "{verifier_inputs_derive}\npub struct {verifier_inputs_type}<'a> {{\n" + )); + for stage in &stages { + if let Some(opening_type) = &stage.opening_input_type { + source.push_str(&format!( + " pub {}_openings: &'a [{}::{}<{field_type}>],\n", + stage.field_name, stage.module_alias, opening_type + )); + } + if let Some(ram_type) = &stage.ram_data_type { + source.push_str(&format!( + " pub {}_ram: Option<&'a {}::{}<'a>>,\n", + stage.field_name, stage.module_alias, ram_type + )); + } + if let Some(data_type) = &stage.verifier_data_type { + source.push_str(&format!( + " pub {}_data: Option<&'a {}::{}>,\n", + stage.field_name, stage.module_alias, data_type + )); + } + } + if let Some(extension) = extension { + source.push_str(&extension.verifier.input_fields); + } + source.push_str("}\n\n"); + + source.push_str(&format!( + "#[derive(Clone, Copy, Debug)]\npub struct {verifier_programs_type} {{\n" + )); + if let Some(commitment) = &commitment { + if let (Some(program_type), Some(_), Some(_)) = ( + &commitment.program_type, + &commitment.program_const, + &commitment.with_program_verifier_fn, + ) { + source.push_str(&format!( + " pub {}: &'static {}::{},\n", + commitment.field_name, commitment.module_alias, program_type + )); + } + } + for stage in &stages { + if let (Some(program_type), Some(_), Some(_)) = ( + &stage.program_type, + &stage.program_const, + &stage.with_program_verifier_fn, + ) { + source.push_str(&format!( + " pub {}: &'static {}::{},\n", + stage.field_name, stage.module_alias, program_type + )); + } + } + if let Some(extension) = extension { + source.push_str(&extension.verifier.program_fields); + } + source.push_str("}\n\n"); + + source.push_str(&format!( + "pub fn default_verifier_programs() -> {verifier_programs_type} {{\n {verifier_programs_type} {{\n" + )); + if let Some(commitment) = &commitment { + if let (Some(_), Some(program_const), Some(_)) = ( + &commitment.program_type, + &commitment.program_const, + &commitment.with_program_verifier_fn, + ) { + source.push_str(&format!( + " {}: &{}::{},\n", + commitment.field_name, commitment.module_alias, program_const + )); + } + } + for stage in &stages { + if let (Some(_), Some(program_const), Some(_)) = ( + &stage.program_type, + &stage.program_const, + &stage.with_program_verifier_fn, + ) { + source.push_str(&format!( + " {}: &{}::{},\n", + stage.field_name, stage.module_alias, program_const + )); + } + } + if let Some(extension) = extension { + source.push_str(&extension.verifier.default_program_fields); + } + source.push_str(" }\n}\n\n"); + + source.push_str(&format!( + "#[derive(Clone, Debug)]\npub struct {verification_artifacts_type} {{\n" + )); + if let Some(commitment) = &commitment { + source.push_str(&format!( + " pub {}: {}::{},\n", + commitment.field_name, commitment.module_alias, commitment.artifacts_type + )); + } + for stage in &stages { + source.push_str(&format!( + " pub {}: {}::{}<{field_type}>,\n", + stage.field_name, stage.module_alias, stage.artifacts_type + )); + } + source.push_str("}\n\n"); + + source.push_str(&format!( + "#[derive(Debug)]\npub enum {verify_error_type} {{\n" + )); + if let Some(commitment) = &commitment { + source.push_str(&format!( + " {}({}::{}),\n", + commitment.variant_name, commitment.module_alias, commitment.error_type + )); + } + for stage in &stages { + source.push_str(&format!( + " {}({}::{}),\n", + stage.variant_name, stage.module_alias, stage.error_type + )); + } + if let Some(extension) = extension { + source.push_str(&extension.verifier.error_variants); + } + source.push_str("}\n\n"); + + if let Some(extension) = extension { + source.push_str(&extension.verifier.error_items); + } + + source.push_str(&format!( + "macro_rules! define_{protocol_snake}_verify_error_from {{\n ($module:ident, $error_ty:ident, $variant:ident) => {{\n impl From<$module::$error_ty> for {verify_error_type} {{\n fn from(error: $module::$error_ty) -> Self {{\n Self::$variant(error)\n }}\n }}\n }};\n}}\n\n" + )); + if let Some(commitment) = &commitment { + source.push_str(&format!( + "define_{protocol_snake}_verify_error_from!({module}, {error}, {variant});\n", + module = commitment.module_alias, + error = commitment.error_type, + variant = commitment.variant_name, + )); + } + for stage in &stages { + source.push_str(&format!( + "define_{protocol_snake}_verify_error_from!({}, {}, {});\n", + stage.module_alias, stage.error_type, stage.variant_name + )); + } + if commitment.is_some() || !stages.is_empty() { + source.push('\n'); + } + if let Some(extension) = extension { + source.push_str(&extension.verifier.error_conversions); + } + + source.push_str(&format!( + "pub fn verify_{protocol_snake}>(proof: &{proof_type}, inputs: {verifier_inputs_type}<'_>, transcript: &mut T) -> Result<{verification_artifacts_type}, {verify_error_type}> {{\n", + )); + source.push_str(&format!( + " verify_{protocol_snake}_with_programs(proof, inputs, default_verifier_programs(), transcript)\n}}\n\n" + )); + if let Some(extension) = extension { + source.push_str(&extension.verifier.after_default_verify); + } + source.push_str(&format!( + "pub fn verify_{protocol_snake}_with_programs>(proof: &{proof_type}, inputs: {verifier_inputs_type}<'_>, programs: {verifier_programs_type}, transcript: &mut T) -> Result<{verification_artifacts_type}, {verify_error_type}> {{\n", + )); + if let Some(extension) = extension { + source.push_str(&extension.verifier.with_programs_body_intro); + } + if let Some(prefix) = &config.instrumentation_prefix { + source.push_str(&format!( + " let _verify_span = tracing::info_span!(\"{prefix}.verify\").entered();\n" + )); + } + if let Some(commitment) = &commitment { + let verifier_fn = commitment + .with_program_verifier_fn + .as_deref() + .or(commitment.verifier_fn.as_deref()) + .unwrap_or("missing_commitment_verifier_function"); + let program_arg = if commitment.with_program_verifier_fn.is_some() + && commitment.program_type.is_some() + && commitment.program_const.is_some() + { + format!("programs.{}, ", commitment.field_name) + } else { + String::new() + }; + source.push_str(&format!( + " let {field} = {module}::{verifier_fn}({program_arg}&proof.commitments, transcript)?;\n", + field = commitment.field_name, + module = commitment.module_alias, + )); + } + if let Some(extension) = extension { + if !extension.verifier.stage_verification_override.is_empty() { + source.push_str(&extension.verifier.stage_verification_override); + } else { + emit_verifier_stage_calls(&mut source, &stages); + } + } else { + emit_verifier_stage_calls(&mut source, &stages); + } + if let Some(extension) = extension { + source.push_str(&extension.verifier.after_stage_verification); + } + source.push_str(&format!("\n Ok({verification_artifacts_type} {{\n")); + if let Some(commitment) = &commitment { + source.push_str(&format!(" {},\n", commitment.field_name)); + } + for stage in &stages { + source.push_str(&format!(" {},\n", stage.field_name)); + } + source.push_str(" })\n}\n\n"); + + if let Some(extension) = extension { + source.push_str(&extension.verifier.helper_items); + } + + source +} + +fn emit_verifier_stage_calls(source: &mut String, stages: &[StageRustApi]) { + for stage in stages { + let verifier_fn = stage + .with_program_verifier_fn + .as_deref() + .or(stage.verifier_fn.as_deref()) + .unwrap_or("missing_verifier_function"); + let mut args = vec![format!("&proof.{}", stage.field_name)]; + if stage.with_program_verifier_fn.is_some() + && stage.program_type.is_some() + && stage.program_const.is_some() + { + args.insert(0, format!("programs.{}", stage.field_name)); + } + if stage.opening_input_type.is_some() { + args.push(format!("inputs.{}_openings", stage.field_name)); + } + if stage.ram_data_type.is_some() { + args.push(format!("inputs.{}_ram", stage.field_name)); + } + if stage.verifier_data_type.is_some() { + args.push(format!("inputs.{}_data", stage.field_name)); + } + args.push("transcript".to_owned()); + source.push_str(&format!( + " let {} = {}::{}({})?;\n", + stage.field_name, + stage.module_alias, + verifier_fn, + args.join(", ") + )); + } +} + +fn generated_prover_api( + config: &ProtocolArtifactConfig, + artifacts: &[ProtocolRustArtifact], +) -> String { + let stages = stage_apis(config, artifacts); + let modules = role_modules(artifacts); + let kernel_modules = unique_kernel_modules(&stages); + let commitment = commitment_api(artifacts); + let has_commitment = commitment.is_some(); + let extension = active_role_api_extension(config, &stages, commitment.as_ref(), artifacts); + let generic_params = prover_generic_params(&stages, has_commitment); + let prefix = &config.type_prefix; + let protocol_snake = config.protocol_snake(); + let field_type = config.field_type.ident(); + let default_transcript_type = config.default_transcript_type.ident(); + let transcript_trait = config.transcript_trait.ident(); + let prover_setup_type = config.prover_setup_type.ident(); + let verifier_import = config.verifier_crate_import(); + let named_eval_type = format!("{prefix}NamedEval"); + let sumcheck_output_type = format!("{prefix}SumcheckOutput"); + let stage_proof_type = format!("{prefix}StageProof"); + let proof_type = format!("{prefix}Proof"); + let prover_inputs_type = format!("{prefix}ProverInputs"); + let prover_programs_type = format!("{prefix}ProverPrograms"); + let prover_artifacts_type = format!("{prefix}ProverArtifacts"); + let prove_error_type = format!("{prefix}ProveError"); + let default_transcript_alias = format!("Default{prefix}Transcript"); + + let mut source = String::new(); + if let Some(extension) = extension { + source.push_str(&extension.prover.imports); + } else { + if has_commitment { + source.push_str(&config.prover_setup_type.use_line()); + } + source.push_str(&config.field_type.use_line()); + if !kernel_modules.is_empty() { + let kernel_crate = config + .kernel_crate + .as_ref() + .map_or("missing_kernel_crate", |kernel_crate| { + kernel_crate.import.as_str() + }); + source.push_str(&format!( + "use {kernel_crate}::{{{}}};\n", + kernel_modules.join(", ") + )); + } + source.push_str(&config.default_transcript_type.use_line()); + source.push_str(&config.transcript_trait.use_line()); + source.push_str(&format!( + "use {verifier_import}::{{{named_eval_type}, {proof_type}, {stage_proof_type}, {sumcheck_output_type}}};\n\n", + )); + } + if !modules.is_empty() { + source.push_str(&format!( + "use crate::stages::{{{}}};\n\n", + aliased_modules(&modules).join(", ") + )); + } + source.push_str(&format!( + "pub type {default_transcript_alias} = {default_transcript_type}<{field_type}>;\n\n" + )); + + source.push_str(&format!( + "pub struct {prover_inputs_type}<'a, {}> {{\n", + generic_params.join(", ") + )); + if has_commitment { + source.push_str(" pub commitment_inputs: &'a mut CommitmentInputs,\n"); + source.push_str(&format!(" pub prover_setup: &'a {prover_setup_type},\n")); + } + for stage in &stages { + source.push_str(&format!( + " pub {}_executor: &'a mut {}Executor,\n", + stage.field_name, stage.variant_name + )); + } + if let Some(extension) = extension { + source.push_str(&extension.prover.input_fields); + } + source.push_str("}\n\n"); + + source.push_str(&format!( + "#[derive(Clone, Copy, Debug)]\npub struct {prover_programs_type} {{\n" + )); + if let Some(commitment) = &commitment { + if let (Some(program_type), Some(_), Some(_)) = ( + &commitment.program_type, + &commitment.program_const, + &commitment.with_program_prover_fn, + ) { + source.push_str(&format!( + " pub {}: &'static {}::{},\n", + commitment.field_name, commitment.module_alias, program_type + )); + } + } + for stage in &stages { + if let (Some(program_type), Some(_), Some(_)) = ( + &stage.program_type, + &stage.program_const, + &stage.with_program_prover_fn, + ) { + let program_module = stage + .kernel_module + .as_deref() + .unwrap_or(stage.module_alias.as_str()); + source.push_str(&format!( + " pub {}: &'static {}::{},\n", + stage.field_name, program_module, program_type + )); + } + } + if let Some(extension) = extension { + source.push_str(&extension.prover.program_fields); + } + source.push_str("}\n\n"); + + source.push_str(&format!( + "pub fn default_prover_programs() -> {prover_programs_type} {{\n {prover_programs_type} {{\n" + )); + if let Some(commitment) = &commitment { + if let (Some(_), Some(program_const), Some(_)) = ( + &commitment.program_type, + &commitment.program_const, + &commitment.with_program_prover_fn, + ) { + source.push_str(&format!( + " {}: &{}::{},\n", + commitment.field_name, commitment.module_alias, program_const + )); + } + } + for stage in &stages { + if let (Some(_), Some(program_const), Some(_)) = ( + &stage.program_type, + &stage.program_const, + &stage.with_program_prover_fn, + ) { + source.push_str(&format!( + " {}: &{}::{},\n", + stage.field_name, stage.module_alias, program_const + )); + } + } + if let Some(extension) = extension { + source.push_str(&extension.prover.default_program_fields); + } + source.push_str(" }\n}\n\n"); + + source.push_str(&format!( + "#[derive(Clone, Debug)]\npub struct {prover_artifacts_type} {{\n" + )); + if let Some(commitment) = &commitment { + source.push_str(&format!( + " pub {}: {}::{},\n", + commitment.field_name, commitment.module_alias, commitment.artifacts_type + )); + } + for stage in &stages { + let kernel_module = stage + .kernel_module + .as_deref() + .unwrap_or(stage.module_alias.as_str()); + source.push_str(&format!( + " pub {}: {}::{}<{field_type}>,\n", + stage.field_name, kernel_module, stage.artifacts_type + )); + } + source.push_str("}\n\n"); + + source.push_str(&format!( + "#[derive(Debug)]\npub enum {prove_error_type} {{\n" + )); + if let Some(commitment) = &commitment { + source.push_str(&format!( + " {}({}::{}),\n", + commitment.variant_name, commitment.module_alias, commitment.error_type + )); + } + for stage in &stages { + let kernel_module = stage + .kernel_module + .as_deref() + .unwrap_or(stage.module_alias.as_str()); + source.push_str(&format!( + " {}({}::{}),\n", + stage.variant_name, kernel_module, stage.error_type + )); + } + if let Some(extension) = extension { + source.push_str(&extension.prover.error_variants); + } + source.push_str("}\n\n"); + + if let Some(extension) = extension { + source.push_str(&extension.prover.error_items); + } + + if let Some(commitment) = &commitment { + source.push_str(&format!( + "impl From<{module}::{error}> for {prove_error_type} {{\n fn from(error: {module}::{error}) -> Self {{\n Self::{variant}(error)\n }}\n}}\n\n", + module = commitment.module_alias, + error = commitment.error_type, + variant = commitment.variant_name, + )); + } + for stage in &stages { + let kernel_module = stage + .kernel_module + .as_deref() + .unwrap_or(stage.module_alias.as_str()); + source.push_str(&format!( + "impl From<{}::{}> for {prove_error_type} {{\n fn from(error: {}::{}) -> Self {{\n Self::{}(error)\n }}\n}}\n\n", + kernel_module, stage.error_type, kernel_module, stage.error_type, stage.variant_name + )); + } + if let Some(extension) = extension { + source.push_str(&extension.prover.error_conversions); + } + + source.push_str(&format!( + "pub fn prove_{protocol_snake}<{}, T>(\n inputs: {prover_inputs_type}<'_, {}>,\n transcript: &mut T,\n) -> Result<({proof_type}, {prover_artifacts_type}), {prove_error_type}>\nwhere\n", + generic_params.join(", "), + generic_params.join(", ") + )); + if let Some(commitment) = &commitment { + let input_provider = commitment + .input_provider_trait + .as_deref() + .unwrap_or("MissingCommitmentInputProvider"); + source.push_str(&format!( + " CommitmentInputs: {}::{input_provider},\n", + commitment.module_alias + )); + } + for stage in &stages { + let kernel_module = stage + .kernel_module + .as_deref() + .unwrap_or(stage.module_alias.as_str()); + let kernel_trait = kernel_executor_type(&stage.error_type); + source.push_str(&format!( + " {}Executor: {}::{}<{field_type}>,\n", + stage.variant_name, kernel_module, kernel_trait + )); + } + source.push_str(&format!( + " T: {transcript_trait},\n" + )); + source.push_str("{\n"); + source.push_str(&format!( + " prove_{protocol_snake}_with_programs(inputs, default_prover_programs(), transcript)\n}}\n\n" + )); + + source.push_str(&format!( + "pub fn prove_{protocol_snake}_with_programs<{}, T>(\n inputs: {prover_inputs_type}<'_, {}>,\n programs: {prover_programs_type},\n transcript: &mut T,\n) -> Result<({proof_type}, {prover_artifacts_type}), {prove_error_type}>\nwhere\n", + generic_params.join(", "), + generic_params.join(", ") + )); + if let Some(commitment) = &commitment { + let input_provider = commitment + .input_provider_trait + .as_deref() + .unwrap_or("MissingCommitmentInputProvider"); + source.push_str(&format!( + " CommitmentInputs: {}::{input_provider},\n", + commitment.module_alias + )); + } + for stage in &stages { + let kernel_module = stage + .kernel_module + .as_deref() + .unwrap_or(stage.module_alias.as_str()); + let kernel_trait = kernel_executor_type(&stage.error_type); + source.push_str(&format!( + " {}Executor: {}::{}<{field_type}>,\n", + stage.variant_name, kernel_module, kernel_trait + )); + } + source.push_str(&format!( + " T: {transcript_trait},\n" + )); + source.push_str("{\n"); + if let Some(prefix) = &config.instrumentation_prefix { + source.push_str(&format!( + " let _prove_span = tracing::info_span!(\"{prefix}.prove\").entered();\n" + )); + } + if let Some(commitment) = &commitment { + let prover_fn = commitment + .with_program_prover_fn + .as_deref() + .or(commitment.prover_fn.as_deref()) + .unwrap_or("missing_commitment_prover_function"); + let program_arg = if commitment.with_program_prover_fn.is_some() + && commitment.program_type.is_some() + && commitment.program_const.is_some() + { + format!("programs.{}, ", commitment.field_name) + } else { + String::new() + }; + if let Some(prefix) = &config.instrumentation_prefix { + source.push_str(&format!( + " let _{field}_span = tracing::info_span!(\"{prefix}.{field}\").entered();\n", + field = commitment.field_name + )); + } + source.push_str(&format!( + " let {field} = {module}::{prover_fn}(\n {program_arg}inputs.commitment_inputs,\n inputs.prover_setup,\n transcript,\n )?;\n", + field = commitment.field_name, + module = commitment.module_alias + )); + if config.instrumentation_prefix.is_some() { + source.push_str(&format!(" drop(_{}_span);\n", commitment.field_name)); + } + } + for stage in &stages { + let prover_fn = stage + .with_program_prover_fn + .as_deref() + .or(stage.prover_fn.as_deref()) + .unwrap_or("missing_prover_function"); + let program_arg = if stage.with_program_prover_fn.is_some() + && stage.program_type.is_some() + && stage.program_const.is_some() + { + format!("programs.{}, ", stage.field_name) + } else { + String::new() + }; + if let Some(prefix) = &config.instrumentation_prefix { + source.push_str(&format!( + " let _{field}_span = tracing::info_span!(\"{prefix}.{span}\").entered();\n", + field = stage.field_name, + span = generated_stage_span_name(&stage.field_name) + )); + } + source.push_str(&format!( + " let {} = {}::{}({program_arg}inputs.{}_executor, transcript)?;\n", + stage.field_name, stage.module_alias, prover_fn, stage.field_name + )); + if config.instrumentation_prefix.is_some() { + source.push_str(&format!(" drop(_{}_span);\n", stage.field_name)); + } + } + if let Some(extension) = extension { + source.push_str(&extension.prover.after_stage_execution); + } + source.push_str(&format!("\n let proof = {proof_type} {{\n")); + if let Some(commitment) = &commitment { + source.push_str(&format!( + " commitments: {}.commitments.clone(),\n", + commitment.field_name + )); + } + for stage in &stages { + source.push_str(&format!( + " {}: {}_proof(&{}),\n", + stage.field_name, stage.field_name, stage.field_name + )); + } + if let Some(extension) = extension { + source.push_str(&extension.prover.proof_fields); + } + source.push_str(&format!( + " }};\n let artifacts = {prover_artifacts_type} {{\n" + )); + if let Some(commitment) = &commitment { + source.push_str(&format!(" {},\n", commitment.field_name)); + } + for stage in &stages { + source.push_str(&format!(" {},\n", stage.field_name)); + } + source.push_str(" };\n Ok((proof, artifacts))\n}\n\n"); + + if let Some(extension) = extension { + source.push_str(&extension.prover.helper_items); + } + + for stage in &stages { + let kernel_module = stage + .kernel_module + .as_deref() + .unwrap_or(stage.module_alias.as_str()); + source.push_str(&format!( + "pub fn {field}_proof(artifacts: &{kernel}::{artifacts_ty}<{field_type}>) -> {stage_proof_type} {{\n {stage_proof_type} {{\n sumchecks: artifacts.sumchecks.iter().map({field}_sumcheck).collect(),\n }}\n}}\n\n", + field = stage.field_name, + kernel = kernel_module, + artifacts_ty = stage.artifacts_type + )); + source.push_str(&format!( + "fn {field}_sumcheck(output: &{kernel}::{output_ty}<{field_type}>) -> {sumcheck_output_type} {{\n {sumcheck_output_type} {{\n driver: output.driver,\n point: output.point.clone(),\n evals: output.evals.iter().map({field}_eval).collect(),\n proof: output.proof.clone(),\n }}\n}}\n\n", + field = stage.field_name, + kernel = kernel_module, + output_ty = stage.output_type + )); + source.push_str(&format!( + "fn {field}_eval(eval: &{kernel}::{eval_ty}<{field_type}>) -> {named_eval_type} {{\n {named_eval_type} {{\n name: eval.name,\n oracle: eval.oracle,\n value: eval.value,\n }}\n}}\n\n", + field = stage.field_name, + kernel = kernel_module, + eval_ty = stage.eval_type + )); + } + source +} + +fn stage_apis( + config: &ProtocolArtifactConfig, + artifacts: &[ProtocolRustArtifact], +) -> Vec { + artifacts + .iter() + .filter(|artifact| artifact.stage.is_proof()) + .map(|artifact| stage_api(config, artifact)) + .collect() +} + +fn stage_api(config: &ProtocolArtifactConfig, artifact: &ProtocolRustArtifact) -> StageRustApi { + let source = artifact.source.source.as_str(); + let artifacts_type = find_type_with_suffix(source, "ExecutionArtifacts").unwrap_or_else(|| { + format!( + "{}ExecutionArtifacts", + upper_camel(artifact.stage.module_name()) + ) + }); + let prefix = artifacts_type + .strip_suffix("ExecutionArtifacts") + .unwrap_or(&artifacts_type); + let opening_input_name = format!("{prefix}OpeningInputValue"); + let ram_data_name = format!("{prefix}RamData"); + let verifier_data_name = format!("{prefix}VerifierData"); + let program_type_suffix = match artifact.role { + Role::Prover => "CpuProgramPlan", + Role::Verifier => "VerifierProgramPlan", + }; + let program_type = find_type_with_suffix(source, program_type_suffix); + let program_const = program_type + .as_deref() + .and_then(|program_type| find_public_const_of_type(source, program_type)); + let error_type = match artifact.role { + Role::Prover => find_type_with_suffix(source, "KernelError") + .unwrap_or_else(|| format!("{prefix}KernelError")), + Role::Verifier => find_public_item(source, "pub enum ", "Error") + .unwrap_or_else(|| format!("Verify{prefix}Error")), + }; + StageRustApi { + field_name: artifact.stage.module_name().to_owned(), + module_alias: module_alias(artifact.stage.module_name()), + variant_name: upper_camel(artifact.stage.module_name()), + output_type: format!("{prefix}SumcheckOutput"), + eval_type: format!("{prefix}NamedEval"), + artifacts_type, + error_type, + verifier_fn: find_public_fn(source, &["verify_"]), + with_program_verifier_fn: find_public_fn_containing(source, &["verify_"], "_with_program"), + program_type, + program_const, + prover_fn: find_public_fn(source, &["prove_", "execute_"]), + with_program_prover_fn: find_public_fn_containing( + source, + &["prove_", "execute_"], + "_with_program", + ), + kernel_module: find_kernel_module(config, source), + opening_input_type: has_public_type_name(source, &opening_input_name) + .then_some(opening_input_name), + ram_data_type: has_public_type_name(source, &ram_data_name).then_some(ram_data_name), + verifier_data_type: has_public_type_name(source, &verifier_data_name) + .then_some(verifier_data_name), + } +} + +fn generated_stage_span_name(field_name: &str) -> &str { + field_name.strip_suffix("_outer").unwrap_or(field_name) +} + +fn commitment_api(artifacts: &[ProtocolRustArtifact]) -> Option { + let artifact = artifacts + .iter() + .find(|artifact| artifact.stage.is_commitment())?; + let source = artifact.source.source.as_str(); + let artifacts_type = find_type_with_suffix(source, "Artifacts") + .unwrap_or_else(|| format!("{}Artifacts", upper_camel(artifact.stage.module_name()))); + let error_type = find_public_item(source, "pub enum ", "Error") + .unwrap_or_else(|| format!("{}Error", upper_camel(artifact.stage.module_name()))); + let input_provider_trait = find_public_item(source, "pub trait ", "InputProvider"); + let program_type_suffix = match artifact.role { + Role::Prover => "ProverProgramPlan", + Role::Verifier => "VerifierProgramPlan", + }; + let program_type = find_type_with_suffix(source, program_type_suffix); + let program_const = program_type + .as_deref() + .and_then(|program_type| find_public_const_of_type(source, program_type)); + Some(CommitmentRustApi { + field_name: artifact.stage.module_name().to_owned(), + module_alias: module_alias(artifact.stage.module_name()), + variant_name: upper_camel(artifact.stage.module_name()), + artifacts_type, + error_type, + verifier_fn: find_public_fn(source, &["verify_"]), + with_program_verifier_fn: find_public_fn_containing(source, &["verify_"], "_with_program"), + program_type, + program_const, + prover_fn: find_public_fn(source, &["prove_"]), + with_program_prover_fn: find_public_fn_containing(source, &["prove_"], "_with_program"), + input_provider_trait, + }) +} + +fn active_role_api_extension<'a>( + config: &'a ProtocolArtifactConfig, + stages: &[StageRustApi], + commitment: Option<&CommitmentRustApi>, + artifacts: &[ProtocolRustArtifact], +) -> Option<&'a ProtocolArtifactExtension> { + let extension = config.role_api_extension.as_ref()?; + if extension.required_commitment && commitment.is_none() { + return None; + } + if !extension + .required_proof_stages + .iter() + .all(|required| stages.iter().any(|stage| &stage.field_name == required)) + { + return None; + } + if !extension.required_artifact_stages.iter().all(|required| { + artifacts + .iter() + .any(|artifact| artifact.stage.module_name() == required) + }) { + return None; + } + Some(extension) +} + +fn role_modules(artifacts: &[ProtocolRustArtifact]) -> Vec { + artifacts + .iter() + .map(|artifact| artifact.stage.module_name().to_owned()) + .collect() +} + +fn aliased_modules(modules: &[String]) -> Vec { + modules + .iter() + .map(|module| format!("{module} as {}", module_alias(module))) + .collect() +} + +fn module_alias(module: &str) -> String { + format!("{module}_stage") +} + +fn unique_kernel_modules(stages: &[StageRustApi]) -> Vec { + let mut modules = Vec::new(); + for stage in stages { + if let Some(kernel_module) = &stage.kernel_module { + if !modules.contains(kernel_module) { + modules.push(kernel_module.clone()); + } + } + } + modules +} + +fn prover_generic_params(stages: &[StageRustApi], has_commitment: bool) -> Vec { + let mut params = if has_commitment { + vec!["CommitmentInputs".to_owned()] + } else { + Vec::new() + }; + params.extend( + stages + .iter() + .map(|stage| format!("{}Executor", stage.variant_name)), + ); + params +} + +fn find_public_item(source: &str, prefix: &str, suffix: &str) -> Option { + source.lines().find_map(|line| { + let trimmed = line.trim_start(); + let rest = trimmed.strip_prefix(prefix)?; + let name = rest + .split(|character: char| { + matches!(character, '<' | '(' | '{') || character.is_whitespace() + }) + .next()?; + name.ends_with(suffix).then(|| name.to_owned()) + }) +} + +fn find_type_with_suffix(source: &str, suffix: &str) -> Option { + source + .split(|character: char| !character.is_ascii_alphanumeric() && character != '_') + .find(|token| token.ends_with(suffix) && token.len() > suffix.len()) + .map(ToOwned::to_owned) +} + +fn find_public_fn(source: &str, prefixes: &[&str]) -> Option { + source.lines().find_map(|line| { + let trimmed = line.trim_start(); + let rest = trimmed.strip_prefix("pub fn ")?; + let name = rest + .split(|character: char| matches!(character, '<' | '(') || character.is_whitespace()) + .next()?; + prefixes + .iter() + .any(|prefix| name.starts_with(prefix)) + .then(|| name.to_owned()) + }) +} + +fn find_public_fn_containing(source: &str, prefixes: &[&str], needle: &str) -> Option { + source.lines().find_map(|line| { + let trimmed = line.trim_start(); + let rest = trimmed.strip_prefix("pub fn ")?; + let name = rest + .split(|character: char| matches!(character, '<' | '(') || character.is_whitespace()) + .next()?; + (name.contains(needle) && prefixes.iter().any(|prefix| name.starts_with(prefix))) + .then(|| name.to_owned()) + }) +} + +fn find_public_const_of_type(source: &str, type_name: &str) -> Option { + source.lines().find_map(|line| { + let trimmed = line.trim_start(); + let rest = trimmed.strip_prefix("pub const ")?; + let name = rest + .split(|character: char| character == ':' || character.is_whitespace()) + .next()?; + rest.contains(&format!(": {type_name}")) + .then(|| name.to_owned()) + }) +} + +fn has_public_type_name(source: &str, type_name: &str) -> bool { + source.contains(&format!("pub struct {type_name}")) + || source.contains(&format!("pub type {type_name}")) + || source.contains(&format!(" as {type_name}")) +} + +fn kernel_executor_type(error_type: &str) -> String { + error_type.strip_suffix("KernelError").map_or_else( + || { + error_type + .replace("Verify", "") + .replace("Error", "KernelExecutor") + }, + |prefix| format!("{prefix}KernelExecutor"), + ) +} + +fn find_kernel_module(config: &ProtocolArtifactConfig, source: &str) -> Option { + let kernel_import = config.kernel_crate.as_ref()?.import.as_str(); + let prefix = format!("use {kernel_import}::"); + source.lines().find_map(|line| { + let rest = line.trim_start().strip_prefix(&prefix)?; + rest.split(|character: char| matches!(character, ':' | '{') || character.is_whitespace()) + .next() + .filter(|name| !name.is_empty()) + .map(ToOwned::to_owned) + }) +} + +fn upper_camel(name: &str) -> String { + let mut output = String::new(); + for segment in name.split('_') { + let mut chars = segment.chars(); + if let Some(first) = chars.next() { + output.extend(first.to_uppercase()); + output.push_str(chars.as_str()); + } + } + output +} + +fn snake_case(name: &str) -> String { + let mut output = String::new(); + for (index, character) in name.chars().enumerate() { + if character.is_ascii_uppercase() { + if index != 0 { + output.push('_'); + } + output.extend(character.to_lowercase()); + } else if character == '-' || character == ' ' { + output.push('_'); + } else { + output.push(character); + } + } + output +} + +fn rust_crate_ident(package_name: &str) -> String { + package_name.replace('-', "_") +} + +fn byte_string_literal(label: &str) -> String { + let escaped = label.escape_default().to_string(); + format!("b\"{escaped}\"") +} + +fn generated_file_path(root: &Path, relative_path: &str) -> Result { + let path = Path::new(relative_path); + if path.is_absolute() + || path.components().any(|component| { + matches!( + component, + Component::ParentDir | Component::RootDir | Component::Prefix(_) + ) + }) + { + return Err(EmitError::new(format!( + "generated crate file path `{relative_path}` must be relative and stay inside the crate" + ))); + } + Ok(root.join(path)) +} diff --git a/crates/bolt/src/emit/rust/mod.rs b/crates/bolt/src/emit/rust/mod.rs new file mode 100644 index 0000000000..aeb2a0bcf2 --- /dev/null +++ b/crates/bolt/src/emit/rust/mod.rs @@ -0,0 +1,19 @@ +mod artifacts; +mod source; + +pub(crate) fn push_format(source: &mut String, args: std::fmt::Arguments<'_>) { + use std::fmt::Write as _; + + if source.write_fmt(args).is_err() { + std::process::abort(); + } +} + +pub use artifacts::{ + assemble_generated_crates, assemble_workspace_generated_crates, protocol_rust_artifact, + validate_rust_artifact_imports, write_generated_crates, ArtifactCrateRole, GeneratedCrate, + GeneratedFile, ProtocolArtifactConfig, ProtocolArtifactExtension, ProtocolCrateRef, + ProtocolProverApiExtension, ProtocolRuntimeModule, ProtocolRustArtifact, ProtocolStage, + ProtocolStageKind, ProtocolStandaloneDependency, ProtocolVerifierApiExtension, RustTypeRef, +}; +pub use source::{EmitError, RustSourceFile}; diff --git a/crates/bolt/src/emit/rust/source.rs b/crates/bolt/src/emit/rust/source.rs new file mode 100644 index 0000000000..8adfb86d44 --- /dev/null +++ b/crates/bolt/src/emit/rust/source.rs @@ -0,0 +1,37 @@ +use std::error::Error; +use std::fmt::{self, Display, Formatter}; + +use crate::schema::SchemaError; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RustSourceFile { + pub filename: String, + pub source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct EmitError { + message: String, +} + +impl EmitError { + pub(crate) fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} + +impl Display for EmitError { + fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result { + formatter.write_str(&self.message) + } +} + +impl Error for EmitError {} + +impl From for EmitError { + fn from(error: SchemaError) -> Self { + Self::new(error.to_string()) + } +} diff --git a/crates/bolt/src/ir.rs b/crates/bolt/src/ir.rs new file mode 100644 index 0000000000..f4721730ea --- /dev/null +++ b/crates/bolt/src/ir.rs @@ -0,0 +1,160 @@ +use std::fmt::{self, Display, Formatter}; +use std::marker::PhantomData; + +use melior::ir::operation::OperationLike; +use melior::ir::{Attribute, Module}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Protocol; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Concrete; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Party; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Compute; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Cpu; + +pub trait Phase { + const NAME: &'static str; +} + +impl Phase for Protocol { + const NAME: &'static str = "protocol"; +} + +impl Phase for Concrete { + const NAME: &'static str = "concrete"; +} + +impl Phase for Party { + const NAME: &'static str = "party"; +} + +impl Phase for Compute { + const NAME: &'static str = "compute"; +} + +impl Phase for Cpu { + const NAME: &'static str = "cpu"; +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Role { + Prover, + Verifier, +} + +impl Role { + pub fn as_str(&self) -> &'static str { + match self { + Self::Prover => "prover", + Self::Verifier => "verifier", + } + } + + fn parse(value: &str) -> Option { + match value { + "prover" => Some(Self::Prover), + "verifier" => Some(Self::Verifier), + _ => None, + } + } +} + +impl Display for Role { + fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result { + formatter.write_str(self.as_str()) + } +} + +#[derive(Debug)] +pub struct BoltModule<'c, P: Phase> { + module: Module<'c>, + phase: PhantomData

, +} + +impl<'c, P: Phase> BoltModule<'c, P> { + pub(crate) fn from_mlir(module: Module<'c>) -> Self { + Self { + module, + phase: PhantomData, + } + } + + pub fn as_mlir_module(&self) -> &Module<'c> { + &self.module + } + + pub fn as_mlir_module_mut(&mut self) -> &mut Module<'c> { + &mut self.module + } + + pub fn into_mlir_module(self) -> Module<'c> { + self.module + } + + pub fn name(&self) -> String { + self.string_attr("sym_name") + .unwrap_or_else(|| "anonymous".to_owned()) + } + + pub fn role(&self) -> Option { + self.string_attr("bolt.role") + .and_then(|value| Role::parse(&value)) + } + + pub fn verify(&self) -> bool { + self.module.as_operation().verify() + } + + fn string_attr(&self, name: &str) -> Option { + self.module + .as_operation() + .attribute(name) + .ok() + .and_then(string_attribute_value) + } +} + +pub trait TextMlir { + fn to_text_mlir(&self) -> String; +} + +impl TextMlir for BoltModule<'_, P> { + fn to_text_mlir(&self) -> String { + self.module.as_operation().to_string() + } +} + +pub(crate) fn string_attribute_value(attribute: Attribute<'_>) -> Option { + let value = attribute.to_string(); + value + .strip_prefix('"') + .and_then(|value| value.strip_suffix('"')) + .map(ToOwned::to_owned) +} + +pub(crate) fn symbol_attribute_value(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .strip_prefix('@') + .map(ToOwned::to_owned) +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Diagnostic { + pub message: String, +} + +impl Diagnostic { + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} diff --git a/crates/bolt/src/lib.rs b/crates/bolt/src/lib.rs new file mode 100644 index 0000000000..c802d446ca --- /dev/null +++ b/crates/bolt/src/lib.rs @@ -0,0 +1,28 @@ +pub mod dialects; +pub mod emit; +pub mod ir; +pub mod mlir; +pub mod pass; +pub mod protocols; +pub mod schema; + +pub use emit::rust::{ + assemble_generated_crates, assemble_workspace_generated_crates, protocol_rust_artifact, + validate_rust_artifact_imports, write_generated_crates, ArtifactCrateRole, EmitError, + GeneratedCrate, GeneratedFile, ProtocolArtifactConfig, ProtocolArtifactExtension, + ProtocolCrateRef, ProtocolProverApiExtension, ProtocolRuntimeModule, ProtocolRustArtifact, + ProtocolStage, ProtocolStageKind, ProtocolStandaloneDependency, ProtocolVerifierApiExtension, + RustSourceFile, RustTypeRef, +}; +pub use ir::{ + BoltModule, Compute, Concrete, Cpu, Diagnostic, Party, Phase, Protocol, Role, TextMlir, +}; +pub use mlir::{MeliorContext, MlirError}; +pub use pass::{ + derive_prover_role, derive_verifier_role, lower_piop_and_fiat_shamir, project_party, + project_prover_party, project_verifier_party, verify_concrete_transcript, VerifyError, +}; +pub use schema::{ + verify_compute_schema, verify_concrete_schema, verify_cpu_schema, verify_party_schema, + verify_protocol_schema, SchemaError, +}; diff --git a/crates/bolt/src/mlir.rs b/crates/bolt/src/mlir.rs new file mode 100644 index 0000000000..263bbcbad6 --- /dev/null +++ b/crates/bolt/src/mlir.rs @@ -0,0 +1,292 @@ +use std::error::Error; +use std::fmt::{self, Display, Formatter}; + +use melior::dialect::DialectRegistry; +use melior::ir::attribute::StringAttribute; +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationBuilder, OperationMutLike}; +use melior::ir::{Attribute, Identifier, Location, Module, OperationRef, Type, Value}; +use melior::utility::register_all_dialects; +use melior::Context; + +use crate::dialects::load_bolt_dialects; +use crate::ir::{BoltModule, Phase, Role, TextMlir}; + +#[derive(Debug)] +pub struct MeliorContext { + context: Context, +} + +impl MeliorContext { + pub fn new() -> Self { + Self::try_new().unwrap_or_else(Self::abort_init_error) + } + + pub fn try_new() -> Result { + let registry = DialectRegistry::new(); + register_all_dialects(®istry); + let context = Context::new_with_registry(®istry, false); + context.load_all_available_dialects(); + load_bolt_dialects(&context) + .map_err(|message| MlirError::DialectRegistration { message })?; + context.set_allow_unregistered_dialects(false); + Ok(Self { context }) + } + + fn abort_init_error(error: MlirError) -> Self { + drop(error); + std::process::abort(); + } + + pub fn context(&self) -> &Context { + &self.context + } + + pub fn new_module<'c, P: Phase>(&'c self, name: &str, role: Option) -> BoltModule<'c, P> { + let mut module = Module::new(Location::unknown(&self.context)); + module + .as_operation_mut() + .set_attribute("sym_name", StringAttribute::new(&self.context, name).into()); + module.as_operation_mut().set_attribute( + "bolt.phase", + StringAttribute::new(&self.context, P::NAME).into(), + ); + if let Some(role) = role { + module.as_operation_mut().set_attribute( + "bolt.role", + StringAttribute::new(&self.context, role.as_str()).into(), + ); + } + BoltModule::from_mlir(module) + } + + pub fn parse_module<'c, P: Phase>( + &'c self, + source: &str, + ) -> Result, MlirError> { + Module::parse(&self.context, source) + .map(BoltModule::from_mlir) + .ok_or_else(|| MlirError::ParseFailed { + source: source.to_owned(), + }) + } + + pub fn append_op<'c, P: Phase>( + &'c self, + module: &BoltModule<'c, P>, + name: &str, + symbol: Option<&str>, + attrs: &[(&str, &str)], + ) -> Result<(), MlirError> { + self.append_op_from_iter(module, name, symbol, attrs.iter().copied()) + } + + pub fn append_op_with_owned_attrs<'c, P: Phase>( + &'c self, + module: &BoltModule<'c, P>, + name: &str, + symbol: Option<&str>, + attrs: &[(String, String)], + ) -> Result<(), MlirError> { + self.append_op_from_iter( + module, + name, + symbol, + attrs + .iter() + .map(|(name, value)| (name.as_str(), value.as_str())), + ) + } + + fn append_op_from_iter<'c, P: Phase, I, K, V>( + &'c self, + module: &BoltModule<'c, P>, + name: &str, + symbol: Option<&str>, + attrs: I, + ) -> Result<(), MlirError> + where + I: IntoIterator, + K: AsRef, + V: AsRef, + { + let mut attributes = Vec::new(); + if let Some(symbol) = symbol { + attributes.push(( + Identifier::new(&self.context, "sym_name"), + StringAttribute::new(&self.context, symbol).into(), + )); + } + for (name, source) in attrs { + let name = name.as_ref(); + let source = source.as_ref(); + attributes.push(( + Identifier::new(&self.context, name), + self.parse_attr(name, source)?, + )); + } + + let operation = OperationBuilder::new(name, Location::unknown(&self.context)) + .add_attributes(&attributes) + .build() + .map_err(|source| MlirError::OperationBuild { + op: name.to_owned(), + source, + })?; + let _operation = module.as_mlir_module().body().append_operation(operation); + Ok(()) + } + + pub(crate) fn append_typed_op<'c, 'a, P: Phase>( + &'c self, + module: &'a BoltModule<'c, P>, + name: &str, + symbol: Option<&str>, + attrs: &[(&str, &str)], + operands: &[Value<'c, 'a>], + result_types: &[&str], + ) -> Result, MlirError> { + self.append_typed_op_from_iter( + module, + name, + symbol, + attrs.iter().copied(), + operands, + result_types, + ) + } + + pub(crate) fn append_typed_op_with_owned_attrs<'c, 'a, P: Phase>( + &'c self, + module: &'a BoltModule<'c, P>, + name: &str, + symbol: Option<&str>, + attrs: &[(String, String)], + operands: &[Value<'c, 'a>], + result_types: &[&str], + ) -> Result, MlirError> { + self.append_typed_op_from_iter( + module, + name, + symbol, + attrs + .iter() + .map(|(name, value)| (name.as_str(), value.as_str())), + operands, + result_types, + ) + } + + fn append_typed_op_from_iter<'c, 'a, P: Phase, I, K, V>( + &'c self, + module: &'a BoltModule<'c, P>, + name: &str, + symbol: Option<&str>, + attrs: I, + operands: &[Value<'c, 'a>], + result_types: &[&str], + ) -> Result, MlirError> + where + I: IntoIterator, + K: AsRef, + V: AsRef, + { + let mut attributes = Vec::new(); + if let Some(symbol) = symbol { + attributes.push(( + Identifier::new(&self.context, "sym_name"), + StringAttribute::new(&self.context, symbol).into(), + )); + } + for (name, source) in attrs { + let name = name.as_ref(); + let source = source.as_ref(); + attributes.push(( + Identifier::new(&self.context, name), + self.parse_attr(name, source)?, + )); + } + let result_types = result_types + .iter() + .map(|source| { + Type::parse(&self.context, source).ok_or_else(|| MlirError::TypeParse { + source: (*source).to_owned(), + }) + }) + .collect::, _>>()?; + + let operation = OperationBuilder::new(name, Location::unknown(&self.context)) + .add_operands(operands) + .add_results(&result_types) + .add_attributes(&attributes) + .build() + .map_err(|source| MlirError::OperationBuild { + op: name.to_owned(), + source, + })?; + Ok(module.as_mlir_module().body().append_operation(operation)) + } + + fn parse_attr<'c>(&'c self, name: &str, source: &str) -> Result, MlirError> { + Attribute::parse(&self.context, source).ok_or_else(|| MlirError::AttributeParse { + name: name.to_owned(), + source: source.to_owned(), + }) + } +} + +impl Default for MeliorContext { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +pub enum MlirError { + AttributeParse { name: String, source: String }, + TypeParse { source: String }, + OperationBuild { op: String, source: melior::Error }, + ParseFailed { source: String }, + Schema { message: String }, + DialectRegistration { message: String }, + VerificationFailed { source: String }, +} + +impl Display for MlirError { + fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result { + match self { + Self::AttributeParse { name, source } => { + write!( + formatter, + "failed to parse MLIR attribute `{name}` from `{source}`" + ) + } + Self::TypeParse { source } => { + write!(formatter, "failed to parse MLIR type `{source}`") + } + Self::OperationBuild { op, source } => { + write!(formatter, "failed to build MLIR operation `{op}`: {source}") + } + Self::ParseFailed { source } => { + write!(formatter, "failed to parse MLIR module:\n{source}") + } + Self::Schema { message } => formatter.write_str(message), + Self::DialectRegistration { message } => formatter.write_str(message), + Self::VerificationFailed { source } => { + write!(formatter, "MLIR module verification failed:\n{source}") + } + } + } +} + +impl Error for MlirError {} + +pub(crate) fn verify_module(module: &BoltModule<'_, P>) -> Result<(), MlirError> { + if module.verify() { + Ok(()) + } else { + Err(MlirError::VerificationFailed { + source: module.to_text_mlir(), + }) + } +} diff --git a/crates/bolt/src/pass.rs b/crates/bolt/src/pass.rs new file mode 100644 index 0000000000..8b883a1f6b --- /dev/null +++ b/crates/bolt/src/pass.rs @@ -0,0 +1,316 @@ +use std::error::Error; +use std::fmt::{self, Display, Formatter}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::OperationLike; + +use crate::ir::{BoltModule, Concrete, Party, Phase, Protocol, Role}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{verify_concrete_schema, verify_party_schema, verify_protocol_schema}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct VerifyError { + message: String, +} + +impl VerifyError { + fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} + +impl Display for VerifyError { + fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result { + formatter.write_str(&self.message) + } +} + +impl Error for VerifyError {} + +pub fn lower_piop_and_fiat_shamir<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, +) -> Result, MlirError> { + verify_protocol_schema(module)?; + let source = phase_copy_source(module, Concrete::NAME, None, &[]); + let concrete = context.parse_module::(&source)?; + verify_module(&concrete)?; + verify_concrete_schema(&concrete)?; + Ok(concrete) +} + +pub fn derive_prover_role<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Concrete>, +) -> Result, MlirError> { + derive_role(context, module, Role::Prover) +} + +pub fn derive_verifier_role<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Concrete>, +) -> Result, MlirError> { + derive_role(context, module, Role::Verifier) +} + +pub fn project_prover_party<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Concrete>, +) -> Result, MlirError> { + project_party(context, module, Role::Prover) +} + +pub fn project_verifier_party<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Concrete>, +) -> Result, MlirError> { + project_party(context, module, Role::Verifier) +} + +pub fn project_party<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Concrete>, + role: Role, +) -> Result, MlirError> { + verify_concrete_schema(module)?; + require_declared_role(module, &role)?; + let party_function = format!( + " \"party.function\"() {{role = \"{}\", source = @{}, sym_name = \"{}.{}\"}} : () -> ()", + role.as_str(), + module.name(), + module.name(), + role.as_str() + ); + let source = phase_copy_source(module, Party::NAME, Some(&role), &[party_function]); + let party = context.parse_module::(&source)?; + verify_module(&party)?; + verify_party_schema(&party)?; + Ok(party) +} + +pub fn verify_concrete_transcript

(module: &BoltModule<'_, P>) -> Result<(), VerifyError> +where + P: Phase, +{ + let mut current_state = None; + let mut error = None; + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "transcript.state" => { + if current_state.is_some() { + error = Some(VerifyError::new("multiple transcript.state ops")); + break; + } + let Ok(result) = op.result(0) else { + error = Some(VerifyError::new("transcript.state requires one result")); + break; + }; + current_state = Some(result.to_string()); + } + "transcript.absorb" | "transcript.absorb_optional" => { + let Some(expected_input) = current_state.as_deref() else { + error = Some(VerifyError::new( + "transcript absorb requires a prior transcript.state result", + )); + break; + }; + let Ok(input) = op.operand(0) else { + error = Some(VerifyError::new(format!( + "{} requires transcript-state operand 0", + operation_name(op) + ))); + break; + }; + let input = input.to_string(); + if input != expected_input { + error = Some(VerifyError::new(format!( + "{} consumed transcript state {input}, expected {expected_input}", + operation_name(op) + ))); + break; + } + if op.operand(1).is_err() { + error = Some(VerifyError::new(format!( + "{} requires commitment artifact operand 1", + operation_name(op) + ))); + break; + } + let Ok(result) = op.result(0) else { + error = Some(VerifyError::new(format!( + "{} requires one transcript-state result", + operation_name(op) + ))); + break; + }; + current_state = Some(result.to_string()); + } + "transcript.absorb_bytes" => { + let Some(expected_input) = current_state.as_deref() else { + error = Some(VerifyError::new( + "transcript absorb_bytes requires a prior transcript.state result", + )); + break; + }; + let Ok(input) = op.operand(0) else { + error = Some(VerifyError::new( + "transcript.absorb_bytes requires transcript-state operand 0", + )); + break; + }; + let input = input.to_string(); + if input != expected_input { + error = Some(VerifyError::new(format!( + "transcript.absorb_bytes consumed transcript state {input}, expected {expected_input}", + ))); + break; + } + let Ok(result) = op.result(0) else { + error = Some(VerifyError::new( + "transcript.absorb_bytes requires one transcript-state result", + )); + break; + }; + current_state = Some(result.to_string()); + } + "transcript.squeeze" | "piop.sumcheck" | "pcs.batch_open" | "pcs.batch_verify" => { + let Some(expected_input) = current_state.as_deref() else { + error = Some(VerifyError::new(format!( + "{} requires a prior transcript.state result", + operation_name(op) + ))); + break; + }; + let Ok(input) = op.operand(0) else { + error = Some(VerifyError::new(format!( + "{} requires transcript-state operand 0", + operation_name(op) + ))); + break; + }; + let input = input.to_string(); + if input != expected_input { + error = Some(VerifyError::new(format!( + "{} consumed transcript state {input}, expected {expected_input}", + operation_name(op) + ))); + break; + } + let Ok(result) = op.result(0) else { + error = Some(VerifyError::new(format!( + "{} requires transcript-state result 0", + operation_name(op) + ))); + break; + }; + current_state = Some(result.to_string()); + } + _ => {} + } + } + + match error { + Some(error) => Err(error), + None => Ok(()), + } +} + +fn derive_role<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Concrete>, + role: Role, +) -> Result, MlirError> { + let source = phase_copy_source(module, Concrete::NAME, Some(&role), &[]); + context.parse_module::(&source) +} + +fn phase_copy_source( + module: &BoltModule<'_, P>, + target_phase: &str, + role: Option<&Role>, + prefix_ops: &[String], +) -> String { + let mut source = format!( + "module @{} attributes {{bolt.phase = \"{target_phase}\"", + module.name() + ); + if let Some(role) = role { + source.push_str(", bolt.role = \""); + source.push_str(role.as_str()); + source.push('"'); + } + source.push_str("} {\n"); + for op in prefix_ops { + source.push_str(op); + source.push('\n'); + } + append_body_text(&mut source, module); + source.push_str("}\n"); + source +} + +fn require_declared_role(module: &BoltModule<'_, Concrete>, role: &Role) -> Result<(), MlirError> { + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + if operation_name(op) != "protocol.boundary" { + continue; + } + let roles = op + .attribute("roles") + .ok() + .and_then(|attribute| parse_string_array(&attribute.to_string())) + .ok_or_else(|| MlirError::Schema { + message: "protocol.boundary requires string array attr `roles`".to_owned(), + })?; + if roles.iter().any(|declared| declared == role.as_str()) { + return Ok(()); + } + return Err(MlirError::Schema { + message: format!("protocol.boundary does not declare role `{role}`"), + }); + } + + Err(MlirError::Schema { + message: "module missing required op `protocol.boundary`".to_owned(), + }) +} + +fn parse_string_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| { + item.trim() + .strip_prefix('"') + .and_then(|item| item.strip_suffix('"')) + .map(ToOwned::to_owned) + }) + .collect() +} + +fn operation_name<'c: 'a, 'a>(operation: impl OperationLike<'c, 'a>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} + +fn append_body_text(module_source: &mut String, module: &BoltModule<'_, P>) { + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + module_source.push_str(" "); + module_source.push_str(&op.to_string()); + module_source.push('\n'); + } +} diff --git a/crates/bolt/src/protocols/jolt/artifacts.rs b/crates/bolt/src/protocols/jolt/artifacts.rs new file mode 100644 index 0000000000..0a13379d7d --- /dev/null +++ b/crates/bolt/src/protocols/jolt/artifacts.rs @@ -0,0 +1,2349 @@ +use std::path::Path; + +use crate::emit::rust::{ + assemble_generated_crates, assemble_workspace_generated_crates, protocol_rust_artifact, + validate_rust_artifact_imports, write_generated_crates, EmitError, GeneratedCrate, + GeneratedFile, ProtocolArtifactConfig, ProtocolArtifactExtension, ProtocolCrateRef, + ProtocolProverApiExtension, ProtocolRuntimeModule, ProtocolRustArtifact, ProtocolStage, + ProtocolStageKind, ProtocolStandaloneDependency, ProtocolVerifierApiExtension, RustSourceFile, + RustTypeRef, +}; +use crate::ir::Role; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum JoltProtocolStage { + Commitment, + Stage1Outer, + Stage2, + Stage3, + Stage4, + Stage5, + Stage6, + Stage7, + Stage8, +} + +impl JoltProtocolStage { + pub fn name(self) -> &'static str { + match self { + Self::Commitment => "commitment", + Self::Stage1Outer => "stage1_outer", + Self::Stage2 => "stage2", + Self::Stage3 => "stage3", + Self::Stage4 => "stage4", + Self::Stage5 => "stage5", + Self::Stage6 => "stage6", + Self::Stage7 => "stage7", + Self::Stage8 => "stage8", + } + } + + fn expected_filename(self, role: &Role) -> &'static str { + match (self, role) { + (Self::Commitment, Role::Prover) => "prove_commitment_phase.rs", + (Self::Commitment, Role::Verifier) => "verify_commitment_phase.rs", + (Self::Stage1Outer, Role::Prover) => "prove_stage1_outer.rs", + (Self::Stage1Outer, Role::Verifier) => "verify_stage1_outer.rs", + (Self::Stage2, Role::Prover) => "prove_stage2.rs", + (Self::Stage2, Role::Verifier) => "verify_stage2.rs", + (Self::Stage3, Role::Prover) => "prove_stage3.rs", + (Self::Stage3, Role::Verifier) => "verify_stage3.rs", + (Self::Stage4, Role::Prover) => "prove_stage4.rs", + (Self::Stage4, Role::Verifier) => "verify_stage4.rs", + (Self::Stage5, Role::Prover) => "prove_stage5.rs", + (Self::Stage5, Role::Verifier) => "verify_stage5.rs", + (Self::Stage6, Role::Prover) => "prove_stage6.rs", + (Self::Stage6, Role::Verifier) => "verify_stage6.rs", + (Self::Stage7, Role::Prover) => "prove_stage7.rs", + (Self::Stage7, Role::Verifier) => "verify_stage7.rs", + (Self::Stage8, Role::Prover) => "prove_stage8.rs", + (Self::Stage8, Role::Verifier) => "verify_stage8.rs", + } + } +} + +impl From for ProtocolStage { + fn from(stage: JoltProtocolStage) -> Self { + match stage { + JoltProtocolStage::Commitment => { + ProtocolStage::new("commitment", "commitment", 0, ProtocolStageKind::Commitment) + } + JoltProtocolStage::Stage1Outer => { + ProtocolStage::new("stage1_outer", "stage1_outer", 1, ProtocolStageKind::Proof) + } + JoltProtocolStage::Stage2 => { + ProtocolStage::new("stage2", "stage2", 2, ProtocolStageKind::Proof) + } + JoltProtocolStage::Stage3 => { + ProtocolStage::new("stage3", "stage3", 3, ProtocolStageKind::Proof) + } + JoltProtocolStage::Stage4 => { + ProtocolStage::new("stage4", "stage4", 4, ProtocolStageKind::Proof) + } + JoltProtocolStage::Stage5 => { + ProtocolStage::new("stage5", "stage5", 5, ProtocolStageKind::Proof) + } + JoltProtocolStage::Stage6 => { + ProtocolStage::new("stage6", "stage6", 6, ProtocolStageKind::Proof) + } + JoltProtocolStage::Stage7 => { + ProtocolStage::new("stage7", "stage7", 7, ProtocolStageKind::Proof) + } + JoltProtocolStage::Stage8 => { + ProtocolStage::new("stage8", "stage8", 8, ProtocolStageKind::Evaluation) + } + } + } +} + +impl PartialEq for ProtocolStage { + fn eq(&self, other: &JoltProtocolStage) -> bool { + self == &ProtocolStage::from(*other) + } +} + +impl PartialEq for JoltProtocolStage { + fn eq(&self, other: &ProtocolStage) -> bool { + other == self + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum JoltArtifactCrate { + Prover, + Verifier, +} + +impl JoltArtifactCrate { + pub fn for_role(role: &Role) -> Self { + match role { + Role::Prover => Self::Prover, + Role::Verifier => Self::Verifier, + } + } + + pub fn name(self) -> &'static str { + match self { + Self::Prover => "jolt-prover", + Self::Verifier => "jolt-verifier", + } + } +} + +pub type JoltRustArtifact = ProtocolRustArtifact; +pub type JoltGeneratedCrate = GeneratedCrate; +pub type JoltGeneratedFile = GeneratedFile; + +pub fn write_jolt_generated_crates( + generated_crates: &[GeneratedCrate], + output_root: impl AsRef, +) -> Result<(), EmitError> { + write_generated_crates(generated_crates, output_root) +} + +pub fn jolt_rust_artifact( + stage: JoltProtocolStage, + role: Role, + source: RustSourceFile, +) -> Result { + let expected = stage.expected_filename(&role); + if source.filename != expected { + return Err(EmitError::new(format!( + "generated {} artifact for {} expected filename `{expected}`, got `{}`", + role, + stage.name(), + source.filename + ))); + } + + Ok(protocol_rust_artifact( + &jolt_artifact_config(), + ProtocolStage::from(stage), + role, + source, + )) +} + +pub fn validate_jolt_rust_artifact_imports( + artifact: &ProtocolRustArtifact, +) -> Result<(), EmitError> { + validate_rust_artifact_imports(&jolt_artifact_config(), artifact) +} + +pub fn assemble_jolt_generated_crates( + artifacts: Vec, + dependency_root: &str, +) -> Result, EmitError> { + assemble_generated_crates(&jolt_artifact_config(), artifacts, dependency_root) +} + +pub fn assemble_jolt_workspace_generated_crates( + artifacts: Vec, +) -> Result, EmitError> { + assemble_workspace_generated_crates(&jolt_artifact_config(), artifacts) +} + +pub fn jolt_artifact_config() -> ProtocolArtifactConfig { + ProtocolArtifactConfig { + protocol_name: "Jolt".to_owned(), + type_prefix: "Jolt".to_owned(), + transcript_label: "Jolt".to_owned(), + repository: Some("https://github.com/a16z/jolt".to_owned()), + prover_crate_name: "jolt-prover".to_owned(), + verifier_crate_name: "jolt-verifier".to_owned(), + crates_io_patches: vec![ + "ark-bn254 = { git = \"https://github.com/a16z/arkworks-algebra\", branch = \"dev/twist-shout\" }".to_owned(), + "ark-ec = { git = \"https://github.com/a16z/arkworks-algebra\", branch = \"dev/twist-shout\" }".to_owned(), + "ark-ff = { git = \"https://github.com/a16z/arkworks-algebra\", branch = \"dev/twist-shout\" }".to_owned(), + "ark-serialize = { git = \"https://github.com/a16z/arkworks-algebra\", branch = \"dev/twist-shout\" }".to_owned(), + ], + standalone_dependency_overrides: vec![ProtocolStandaloneDependency::new( + "rayon", + "rayon = \"1.12.0\"", + ), ProtocolStandaloneDependency::new( + "serde", + "serde = { version = \"1.0\", default-features = false, features = [\"derive\"] }", + ), ProtocolStandaloneDependency::new( + "tracing", + "tracing = { version = \"0.1.37\", default-features = false, features = [\"attributes\"] }", + )], + common_dependencies: vec![ + "jolt-field".to_owned(), + "jolt-openings".to_owned(), + "jolt-poly".to_owned(), + "jolt-transcript".to_owned(), + "tracing".to_owned(), + ], + prover_dependencies: vec![ + "jolt-dory".to_owned(), + "jolt-kernels".to_owned(), + "jolt-witness".to_owned(), + "rayon".to_owned(), + ], + verifier_dependencies: vec![ + "jolt-dory".to_owned(), + "jolt-lookup-tables".to_owned(), + "jolt-sumcheck".to_owned(), + "serde".to_owned(), + ], + instrumentation_prefix: Some("bolt".to_owned()), + prover_forbidden_imports: PROVER_FORBIDDEN_IMPORTS + .iter() + .map(ToString::to_string) + .collect(), + verifier_forbidden_imports: VERIFIER_FORBIDDEN_IMPORTS + .iter() + .map(ToString::to_string) + .collect(), + kernel_crate: Some(ProtocolCrateRef::new("jolt-kernels", "jolt_kernels")), + field_type: RustTypeRef::new("jolt_field::Fr"), + default_transcript_type: RustTypeRef::new("jolt_transcript::Blake2bTranscript"), + transcript_trait: RustTypeRef::new("jolt_transcript::Transcript"), + commitment_type: RustTypeRef::new("jolt_dory::DoryCommitment"), + prover_setup_type: RustTypeRef::new("jolt_dory::DoryProverSetup"), + role_api_extension: Some(jolt_evaluation_role_api_extension()), + verifier_runtime_modules: vec![ProtocolRuntimeModule { + module_name: "common".to_owned(), + file: GeneratedFile { + path: "src/stages/common.rs".to_owned(), + source: include_str!("verifier_common.rs.template").to_owned(), + }, + }], + verifier_named_eval_type: RustTypeRef::new("crate::stages::common::StageNamedEval"), + verifier_sumcheck_output_type: RustTypeRef::new( + "crate::stages::common::StageSumcheckOutput", + ), + verifier_stage_proof_type: RustTypeRef::new("crate::stages::common::StageProof"), + } +} + +fn jolt_prover_lib_module() -> String { + r"#[rustfmt::skip] +pub mod prover; +pub mod stages; + +pub use prover::{ + default_prover_programs, jolt_proof_through_stage5, jolt_proof_through_stage6, + jolt_proof_through_stage7, prove_jolt, prove_jolt_evaluation_proof, prove_jolt_with_programs, + prove_jolt_with_stage_inputs, prove_jolt_with_witness_inputs, + prove_stage1_outer_inputs_with_program, prove_stage2_inputs_with_program, + prove_stage3_inputs_with_program, prove_stage4_inputs_with_program, + prove_stage5_inputs_with_program, prove_stage6_inputs_with_program, + prove_stage7_inputs_with_program, replay_stage1_outer_proof_with_program, + replay_stage2_proof_with_program, replay_stage3_proof_with_program, + replay_stage4_proof_with_program, replay_stage5_proof_with_program, + replay_stage6_proof_with_program, replay_stage7_proof_with_program, stage1_outer_proof, + stage1_outer_proof_from_kernel_proof, stage1_outer_prover_inputs, + stage2_opening_inputs_from_artifacts, stage2_proof, stage2_prover_inputs, + stage2_verifier_ram_data, stage3_opening_inputs_from_artifacts, stage3_proof, + stage3_prover_inputs, stage4_opening_inputs_from_artifacts, stage4_proof, stage4_prover_inputs, + stage5_kernel_proof, stage5_opening_inputs_from_artifacts, stage5_proof, stage5_prover_inputs, + stage6_bytecode_read_raf_data_from_witness_entries, stage6_execution_artifacts, + stage6_kernel_proof, stage6_opening_inputs_from_artifacts, stage6_proof, stage6_prover_inputs, + stage6_witness_from_opening_inputs, stage7_execution_artifacts, stage7_kernel_proof, + stage7_opening_inputs_from_stage6_artifacts, + stage7_opening_inputs_from_stage6_artifacts_with_program, stage7_proof, stage7_prover_inputs, + verifier_opening_inputs_from_kernel, DefaultJoltTranscript, JoltEvaluationProveError, + JoltKernelOpeningInput, JoltOpeningInputError, JoltProveError, JoltProverArtifacts, + JoltProverInputs, JoltProverPrograms, JoltProverStageInputs, JoltProverWitnessInputs, + JoltStage2RamDataStorage, +}; + +pub use prover::{ + prove_stage1_outer_with_witness_inputs, prove_stage2_with_witness_inputs, + prove_stage3_with_witness_inputs, prove_stage4_with_trace_witness_inputs, + prove_stage4_with_witness_inputs, prove_stage5_with_trace_witness_inputs, + prove_stage5_with_witness_inputs, prove_stage6_with_trace_witness_inputs, + prove_stage6_with_witness_inputs, prove_stage7_with_trace_witness_inputs, + prove_stage7_with_witness_inputs, stage6_verifier_data_from_witness_entries, +};" + .to_owned() +} + +fn jolt_verifier_lib_module() -> String { + r"pub mod stages; +#[rustfmt::skip] +pub mod verifier; + +pub use stages::{ + stage1_outer::{verify_stage1_outer_with_program, Stage1VerifierProgramPlan}, + stage2::{verify_stage2_with_program, Stage2VerifierProgramPlan}, + stage3::{verify_stage3_with_program, Stage3VerifierProgramPlan}, + stage4::{verify_stage4_with_program, Stage4VerifierProgramPlan}, + stage5::{verify_stage5_with_program, Stage5VerifierProgramPlan}, + stage6::{verify_stage6_with_program, Stage6VerifierProgramPlan}, + stage7::{verify_stage7_with_program, Stage7VerifierProgramPlan}, +}; + +pub use verifier::{ + default_verifier_programs, verify_jolt, verify_jolt_evaluation_proof, verify_jolt_prefix, + verify_jolt_prefix_with_programs, verify_jolt_through_stage5, + verify_jolt_through_stage5_with_programs, verify_jolt_through_stage6, + verify_jolt_through_stage6_with_programs, verify_jolt_through_stage7, + verify_jolt_through_stage7_with_programs, verify_jolt_with_programs, JoltEvaluationProof, + JoltEvaluationProofError, JoltNamedEval, JoltProof, JoltStage2RamAccess, JoltStage2RamData, + JoltStage2RamOutputLayout, JoltStageChallengeVector, JoltStageExecutionArtifacts, + JoltStage6BytecodeEntry, JoltStage6BytecodeReadRafData, JoltStage6VerifierData, + JoltStageOpeningInputValue, JoltStageProof, JoltSumcheckOutput, JoltVerificationArtifacts, + JoltVerifierInputs, JoltVerifierPrograms, JoltVerifierTarget, JoltVerifyError, +};" + .to_owned() +} + +fn jolt_evaluation_role_api_extension() -> ProtocolArtifactExtension { + ProtocolArtifactExtension { + required_commitment: true, + required_proof_stages: vec!["stage6".to_owned(), "stage7".to_owned()], + required_artifact_stages: vec!["stage8".to_owned()], + prover: ProtocolProverApiExtension { + lib_module: jolt_prover_lib_module(), + imports: "#![expect(\n clippy::too_many_arguments,\n reason = \"generated prover helpers mirror staged protocol ABIs\"\n)]\n\nuse jolt_dory::{DoryCommitment, DoryHint, DoryProverSetup, DoryScheme};\nuse jolt_field::{Field, Fr};\nuse jolt_kernels::{stage1, stage2, stage3, stage4, stage5, stage6, stage7};\nuse jolt_openings::{AdditivelyHomomorphic, CommitmentScheme};\nuse jolt_poly::{EqPolynomial, Polynomial};\nuse jolt_transcript::{AppendToTranscript, Blake2bTranscript, LabelWithCount, Transcript};\nuse jolt_verifier::{JoltEvaluationProof, JoltNamedEval, JoltProof, JoltStage2RamAccess, JoltStage2RamData, JoltStage2RamOutputLayout, JoltStage6BytecodeEntry, JoltStage6BytecodeReadRafData, JoltStage6VerifierData, JoltStageChallengeVector, JoltStageExecutionArtifacts, JoltStageOpeningInputValue, JoltStageProof, JoltSumcheckOutput};\nuse jolt_witness::{stage4_ram_val_init_opening, CycleInput, Stage45SparseTraceWitness, Stage6BytecodeEntry as WitnessStage6BytecodeEntry, Stage6WitnessParams, Stage6WitnessPolynomials, Stage6WitnessSlices};\nuse rayon::prelude::*;\n\n".to_owned(), + input_fields: + " pub stage7_openings: Option<&'a [stage7::Stage7OpeningInputValue]>,\n" + .to_owned(), + program_fields: + " pub stage8: &'static stage8_stage::Stage8EvaluationProgramPlan,\n".to_owned(), + default_program_fields: " stage8: &stage8_stage::STAGE8_PROGRAM,\n".to_owned(), + error_variants: " Evaluation(JoltEvaluationProveError),\n".to_owned(), + error_items: "#[derive(Debug)]\npub enum JoltEvaluationProveError {\n MissingOracle { oracle: &'static str },\n MissingOpeningHint { oracle: &'static str },\n MissingStageEval { stage: &'static str, eval: &'static str },\n MissingStage7RaEval,\n MissingStage7EvaluationPoint,\n InvalidPointLength {\n artifact: &'static str,\n expected: usize,\n actual: usize,\n },\n TargetSizeOverflow { num_vars: usize },\n}\n\n#[derive(Debug)]\npub enum JoltOpeningInputError {\n MissingOpeningClaim { stage: &'static str, source_claim: &'static str },\n MissingStage6OpeningClaim { source_claim: &'static str },\n UnsupportedOpeningInputSource { stage: &'static str, symbol: &'static str, source_stage: &'static str },\n UnsupportedStage7InputSource { symbol: &'static str, source_stage: &'static str },\n InvalidPointLength {\n symbol: &'static str,\n expected: usize,\n actual: usize,\n },\n}\n\n".to_owned(), + error_conversions: "impl From for JoltProveError {\n fn from(error: JoltEvaluationProveError) -> Self {\n Self::Evaluation(error)\n }\n}\n\n".to_owned(), + after_stage_execution: " let evaluation = if let Some(stage7_openings) = inputs.stage7_openings {\n let _stage8_span = tracing::info_span!(\"bolt.stage8\").entered();\n let _evaluate_span = tracing::info_span!(\"bolt.evaluate\").entered();\n Some(prove_jolt_evaluation_proof(\n programs.stage8,\n inputs.commitment_inputs,\n inputs.prover_setup,\n &commitment,\n &stage6,\n &stage7,\n stage7_openings,\n transcript,\n )?)\n } else {\n None\n };\n".to_owned(), + proof_fields: " evaluation,\n".to_owned(), + helper_items: format!( + "{}{}", + jolt_prover_evaluation_helpers("Fr"), + jolt_prover_stage7_opening_input_helpers("Fr") + ), + }, + verifier: ProtocolVerifierApiExtension { + lib_module: jolt_verifier_lib_module(), + imports: "use std::collections::BTreeMap;\n\nuse jolt_dory::{DoryCommitment, DoryProof, DoryScheme, DoryVerifierSetup};\nuse jolt_field::{Field, Fr};\nuse jolt_openings::{AdditivelyHomomorphic, CommitmentScheme, OpeningsError};\nuse jolt_poly::EqPolynomial;\nuse jolt_transcript::{AppendToTranscript, LabelWithCount, Transcript};\n".to_owned(), + proof_fields: " pub evaluation: Option,\n".to_owned(), + proof_items: "pub type JoltStage2RamAccess = crate::stages::stage2::Stage2RamAccess;\npub type JoltStage2RamOutputLayout = crate::stages::stage2::Stage2RamOutputLayout;\npub type JoltStage2RamData<'a> = crate::stages::stage2::Stage2RamData<'a>;\npub type JoltStageChallengeVector = crate::stages::common::StageChallengeVector;\npub type JoltStageExecutionArtifacts = crate::stages::common::StageExecutionArtifacts;\npub type JoltStageOpeningInputValue = crate::stages::common::StageOpeningInputValue;\n\n#[derive(Clone, Debug)]\npub struct JoltEvaluationProof {\n pub joint_opening_proof: DoryProof,\n}\n\n".to_owned(), + inputs_derive: Some("#[derive(Clone, Copy)]".to_owned()), + input_fields: " pub evaluation_setup: Option<&'a DoryVerifierSetup>,\n".to_owned(), + program_fields: + " pub stage8: &'static stage8_stage::Stage8EvaluationProgramPlan,\n".to_owned(), + default_program_fields: " stage8: &stage8_stage::STAGE8_PROGRAM,\n".to_owned(), + error_variants: " Evaluation(JoltEvaluationProofError),\n".to_owned(), + error_items: format!("{}{}", jolt_verifier_target_items(), "#[derive(Debug)]\npub enum JoltEvaluationProofError {\n MissingProof,\n MissingVerifierSetup,\n MissingStageEval { stage: &'static str, eval: &'static str },\n MissingStage7RaEval,\n MissingStage7EvaluationPoint,\n MissingCommitment { oracle: &'static str },\n InvalidPointLength {\n artifact: &'static str,\n expected: usize,\n actual: usize,\n },\n Opening(OpeningsError),\n}\n\n"), + error_conversions: "impl From for JoltVerifyError {\n fn from(error: JoltEvaluationProofError) -> Self {\n Self::Evaluation(error)\n }\n}\n\nimpl From for JoltEvaluationProofError {\n fn from(error: OpeningsError) -> Self {\n Self::Opening(error)\n }\n}\n\n".to_owned(), + after_default_verify: "pub fn verify_jolt_prefix>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, transcript: &mut T) -> Result { verify_jolt_prefix_with_programs(proof, inputs, default_verifier_programs(), transcript) }\n\npub fn verify_jolt_through_stage5>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, transcript: &mut T) -> Result { verify_jolt_through_stage5_with_programs(proof, inputs, default_verifier_programs(), transcript) }\n\npub fn verify_jolt_through_stage6>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, transcript: &mut T) -> Result { verify_jolt_through_stage6_with_programs(proof, inputs, default_verifier_programs(), transcript) }\n\npub fn verify_jolt_through_stage7>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, transcript: &mut T) -> Result { verify_jolt_through_stage7_with_programs(proof, inputs, default_verifier_programs(), transcript) }\n\n".to_owned(), + with_programs_body_intro: " verify_jolt_with_programs_inner(proof, inputs, programs, transcript, JoltVerifierTarget::Full)\n}\n\npub fn verify_jolt_through_stage5_with_programs>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, programs: JoltVerifierPrograms, transcript: &mut T) -> Result { verify_jolt_with_programs_inner(proof, inputs, programs, transcript, JoltVerifierTarget::ThroughStage5) }\n\npub fn verify_jolt_through_stage6_with_programs>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, programs: JoltVerifierPrograms, transcript: &mut T) -> Result { verify_jolt_with_programs_inner(proof, inputs, programs, transcript, JoltVerifierTarget::ThroughStage6) }\n\npub fn verify_jolt_through_stage7_with_programs>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, programs: JoltVerifierPrograms, transcript: &mut T) -> Result { verify_jolt_with_programs_inner(proof, inputs, programs, transcript, JoltVerifierTarget::ThroughStage7) }\n\npub fn verify_jolt_prefix_with_programs>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, programs: JoltVerifierPrograms, transcript: &mut T) -> Result { verify_jolt_through_stage7_with_programs(proof, inputs, programs, transcript) }\n\nfn verify_jolt_with_programs_inner>(proof: &JoltProof, inputs: JoltVerifierInputs<'_>, programs: JoltVerifierPrograms, transcript: &mut T, target: JoltVerifierTarget) -> Result {\n".to_owned(), + stage_verification_override: jolt_verifier_stage_verification(), + after_stage_verification: jolt_verifier_evaluation_check(), + helper_items: format!( + "{}{}", + jolt_verifier_input_helpers("Jolt"), + jolt_verifier_evaluation_helpers("Jolt", "Fr") + ), + }, + } +} + +fn jolt_verifier_target_items() -> String { + "#[derive(Clone, Copy, Debug, PartialEq, Eq)]\npub enum JoltVerifierTarget {\n ThroughStage5,\n ThroughStage6,\n ThroughStage7,\n Full,\n}\n\nimpl JoltVerifierTarget {\n fn verifies_stage6(self) -> bool { matches!(self, Self::ThroughStage6 | Self::ThroughStage7 | Self::Full) }\n fn verifies_stage7(self) -> bool { matches!(self, Self::ThroughStage7 | Self::Full) }\n fn verifies_evaluation(self) -> bool { matches!(self, Self::Full) }\n fn allows_optional_evaluation(self) -> bool { matches!(self, Self::ThroughStage7 | Self::Full) }\n}\n\n".to_owned() +} + +fn jolt_verifier_stage_verification() -> String { + " let stage1_outer = stage1_outer_stage::verify_stage1_outer_with_program(programs.stage1_outer, &proof.stage1_outer, transcript)?;\n let stage2 = stage2_stage::verify_stage2_with_program(programs.stage2, &proof.stage2, inputs.stage2_openings, inputs.stage2_ram, transcript)?;\n let stage3 = stage3_stage::verify_stage3_with_program(programs.stage3, &proof.stage3, inputs.stage3_openings, transcript)?;\n let stage4 = stage4_stage::verify_stage4_with_program(programs.stage4, &proof.stage4, inputs.stage4_openings, transcript)?;\n let stage5 = stage5_stage::verify_stage5_with_program(programs.stage5, &proof.stage5, inputs.stage5_openings, transcript)?;\n let stage6 = if target.verifies_stage6() {\n stage6_stage::verify_stage6_with_program(programs.stage6, &proof.stage6, inputs.stage6_openings, inputs.stage6_data, transcript)?\n } else {\n stage6_stage::Stage6ExecutionArtifacts::default()\n };\n let stage7 = if target.verifies_stage7() {\n stage7_stage::verify_stage7_with_program(programs.stage7, &proof.stage7, inputs.stage7_openings, transcript)?\n } else {\n stage7_stage::Stage7ExecutionArtifacts::default()\n };\n".to_owned() +} + +fn jolt_verifier_evaluation_check() -> String { + " if target.allows_optional_evaluation() {\n match (&proof.evaluation, inputs.evaluation_setup) {\n (Some(evaluation), Some(setup)) => {\n verify_jolt_evaluation_proof(\n programs.stage8,\n evaluation,\n &commitment,\n &proof.stage6,\n &proof.stage7,\n inputs.stage7_openings,\n setup,\n transcript,\n )?;\n }\n (Some(_), None) => return Err(JoltEvaluationProofError::MissingVerifierSetup.into()),\n (None, Some(_)) => return Err(JoltEvaluationProofError::MissingProof.into()),\n (None, None) if target.verifies_evaluation() => return Err(JoltEvaluationProofError::MissingProof.into()),\n (None, None) => {}\n }\n }\n".to_owned() +} + +fn jolt_verifier_input_helpers(prefix: &str) -> String { + format!( + "impl<'a> {prefix}VerifierInputs<'a> {{\n pub fn through_stage5(mut self) -> Self {{ self.stage6_openings = &[]; self.stage7_openings = &[]; self.evaluation_setup = None; self }}\n pub fn through_stage6(mut self) -> Self {{ self.stage7_openings = &[]; self.evaluation_setup = None; self }}\n pub fn through_stage7(mut self) -> Self {{ self.evaluation_setup = None; self }}\n pub fn full(mut self, evaluation_setup: &'a DoryVerifierSetup) -> Self {{ self.evaluation_setup = Some(evaluation_setup); self }}\n}}\n\n" + ) +} + +fn jolt_verifier_evaluation_helpers(prefix: &str, field_type: &str) -> String { + format!( + r#"pub type {prefix}Stage6BytecodeEntry = crate::stages::stage6::Stage6BytecodeEntry; +pub type {prefix}Stage6BytecodeReadRafData = crate::stages::stage6::Stage6BytecodeReadRafData; +pub type {prefix}Stage6VerifierData = crate::stages::stage6::Stage6VerifierData; + +#[expect( + clippy::too_many_arguments, + reason = "generated verifier entry point follows the Jolt proof artifact boundary" +)] +pub fn verify_jolt_evaluation_proof( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + proof: &{prefix}EvaluationProof, + commitments: &commitment_stage::CommitmentArtifacts, + stage6: &{prefix}StageProof, + stage7: &{prefix}StageProof, + stage7_openings: &[stage7_stage::Stage7OpeningInputValue<{field_type}>], + verifier_setup: &DoryVerifierSetup, + transcript: &mut T, +) -> Result<(), {prefix}EvaluationProofError> +where + T: Transcript, +{{ + let _state_span = tracing::info_span!("bolt.verify.evaluation_state").entered(); + let state = + evaluation_proof_state(program, commitments, stage6, stage7, stage7_openings, transcript)?; + drop(_state_span); + let _dory_verify_span = tracing::info_span!("bolt.verify.dory_verify").entered(); + ::verify( + &state.joint_commitment, + &state.opening_point, + state.joint_claim, + &proof.joint_opening_proof, + verifier_setup, + transcript, + )?; + drop(_dory_verify_span); + let _bind_span = tracing::info_span!("bolt.verify.bind_opening_inputs").entered(); + ::bind_opening_inputs( + transcript, + &state.opening_point, + &state.joint_claim, + ); + drop(_bind_span); + Ok(()) +}} + +struct EvaluationProofState {{ + opening_point: Vec<{field_type}>, + joint_claim: {field_type}, + joint_commitment: DoryCommitment, +}} + +struct EvaluationClaim {{ + oracle: &'static str, + value: {field_type}, +}} + +fn evaluation_proof_state( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + commitments: &commitment_stage::CommitmentArtifacts, + stage6: &{prefix}StageProof, + stage7: &{prefix}StageProof, + stage7_openings: &[stage7_stage::Stage7OpeningInputValue<{field_type}>], + transcript: &mut T, +) -> Result +where + T: Transcript, +{{ + let (sumcheck_address_point, stage7_values) = stage7_claim_values(program, stage7)?; + let address_point = reverse_point(&sumcheck_address_point); + let opening_point = stage7_evaluation_opening_point(program, &address_point, stage7_openings)?; + let lagrange_factor = EqPolynomial::<{field_type}>::zero_selector(&address_point); + let claims = evaluation_claims(program, stage6, &stage7_values, lagrange_factor)?; + + append_rlc_claims(transcript, &claims); + let gamma_powers = gamma_powers(transcript, claims.len()); + let joint_claim = claims + .iter() + .zip(&gamma_powers) + .map(|(claim, gamma)| claim.value * *gamma) + .sum(); + let joint_commitment = joint_commitment(commitments, &claims, &gamma_powers)?; + + Ok(EvaluationProofState {{ + opening_point, + joint_claim, + joint_commitment, + }}) +}} + +fn stage_eval( + proof: &{prefix}StageProof, + stage: &'static str, + eval_name: &'static str, +) -> Result<{field_type}, {prefix}EvaluationProofError> {{ + for output in &proof.sumchecks {{ + if let Some(eval) = output.evals.iter().find(|eval| eval.name == eval_name) {{ + return Ok(eval.value); + }} + }} + Err({prefix}EvaluationProofError::MissingStageEval {{ + stage, + eval: eval_name, + }}) +}} + +fn evaluation_claims( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + stage6: &{prefix}StageProof, + stage7_values: &BTreeMap<&'static str, {field_type}>, + lagrange_factor: {field_type}, +) -> Result, {prefix}EvaluationProofError> {{ + let mut claims = Vec::with_capacity(program.opening_claims.len()); + for plan in program.opening_claims {{ + let value = match plan.source_stage {{ + "stage6" => stage_eval(stage6, plan.source_stage, plan.source_claim)? * lagrange_factor, + "stage7" => *stage7_values.get(plan.source_claim).ok_or( + {prefix}EvaluationProofError::MissingStageEval {{ + stage: plan.source_stage, + eval: plan.source_claim, + }}, + )?, + _ => {{ + return Err({prefix}EvaluationProofError::MissingStageEval {{ + stage: plan.source_stage, + eval: plan.source_claim, + }}); + }} + }}; + claims.push(EvaluationClaim {{ + oracle: plan.oracle, + value, + }}); + }} + Ok(claims) +}} + +fn stage7_claim_values( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + proof: &{prefix}StageProof, +) -> Result<(Vec<{field_type}>, BTreeMap<&'static str, {field_type}>), {prefix}EvaluationProofError> {{ + let stage7_plans = program + .opening_claims + .iter() + .filter(|plan| plan.source_stage == "stage7") + .collect::>(); + for output in &proof.sumchecks {{ + let mut values = BTreeMap::new(); + for plan in &stage7_plans {{ + if let Some(eval) = output.evals.iter().find(|eval| eval.name == plan.source_claim) {{ + let _ = values.insert(plan.source_claim, eval.value); + }} + }} + if values.len() == stage7_plans.len() {{ + return Ok((output.point.clone(), values)); + }} + }} + Err({prefix}EvaluationProofError::MissingStage7RaEval) +}} + +fn reverse_point(point: &[{field_type}]) -> Vec<{field_type}> {{ + point.iter().rev().copied().collect() +}} + +fn stage7_evaluation_opening_point( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + address_point: &[{field_type}], + stage7_openings: &[stage7_stage::Stage7OpeningInputValue<{field_type}>], +) -> Result, {prefix}EvaluationProofError> {{ + let cycle_source_symbol = program.evaluation_point_source.source_claim; + let cycle_source = stage7_openings + .iter() + .find(|input| input.symbol == cycle_source_symbol) + .ok_or({prefix}EvaluationProofError::MissingStage7EvaluationPoint)?; + if cycle_source.point.len() < address_point.len() {{ + return Err({prefix}EvaluationProofError::InvalidPointLength {{ + artifact: cycle_source_symbol, + expected: address_point.len(), + actual: cycle_source.point.len(), + }}); + }} + let mut point = Vec::with_capacity(cycle_source.point.len()); + point.extend_from_slice(address_point); + point.extend_from_slice(&cycle_source.point[address_point.len()..]); + Ok(point) +}} + +fn append_rlc_claims(transcript: &mut T, claims: &[EvaluationClaim]) +where + T: Transcript, +{{ + transcript.append(&LabelWithCount(b"rlc_claims", claims.len() as u64)); + for claim in claims {{ + claim.value.append_to_transcript(transcript); + }} +}} + +fn gamma_powers(transcript: &mut T, count: usize) -> Vec<{field_type}> +where + T: Transcript, +{{ + let gamma = transcript.challenge(); + let mut powers = Vec::with_capacity(count); + let mut power = {field_type}::from_u64(1); + for _ in 0..count {{ + powers.push(power); + power *= gamma; + }} + powers +}} + +fn joint_commitment( + commitments: &commitment_stage::CommitmentArtifacts, + claims: &[EvaluationClaim], + gamma_powers: &[{field_type}], +) -> Result {{ + let mut coefficients = BTreeMap::<&'static str, {field_type}>::new(); + for (claim, gamma) in claims.iter().zip(gamma_powers) {{ + let coefficient = coefficients.entry(claim.oracle).or_insert({field_type}::from_u64(0)); + *coefficient += *gamma; + }} + let mut commitment_values = Vec::with_capacity(coefficients.len()); + let mut scalars = Vec::with_capacity(coefficients.len()); + for (oracle, coefficient) in coefficients {{ + commitment_values.push(commitment_for_oracle(commitments, oracle)?); + scalars.push(coefficient); + }} + Ok(::combine( + &commitment_values, + &scalars, + )) +}} + +fn commitment_for_oracle( + commitments: &commitment_stage::CommitmentArtifacts, + oracle: &'static str, +) -> Result {{ + for (record, commitment) in commitments.records.iter().zip(&commitments.commitments) {{ + if record.oracle == oracle {{ + return commitment + .clone() + .ok_or({prefix}EvaluationProofError::MissingCommitment {{ oracle }}); + }} + }} + Err({prefix}EvaluationProofError::MissingCommitment {{ oracle }}) +}} + +"# + ) +} + +fn jolt_prover_stage7_opening_input_helpers(field_type: &str) -> String { + format!( + r#"pub struct JoltProverStageInputs<'a, CommitmentInputs> {{ + pub commitment_inputs: &'a mut CommitmentInputs, + pub prover_setup: &'a DoryProverSetup, + pub stage1_outer: stage1::Stage1ProverInputs<'a, {field_type}>, + pub stage2: stage2::Stage2ProverInputs<'a, {field_type}>, + pub stage3: stage3::Stage3ProverInputs<'a, {field_type}>, + pub stage4: stage4::Stage4ProverInputs<'a, {field_type}>, + pub stage5: stage5::Stage5ProverInputs<'a, {field_type}>, + pub stage6: stage6::Stage6ProverInputs<'a, {field_type}>, + pub stage7: stage7::Stage7ProverInputs<'a, {field_type}>, + pub stage7_openings: Option<&'a [stage7::Stage7OpeningInputValue<{field_type}>]>, +}} + +pub fn prove_jolt_with_stage_inputs( + inputs: JoltProverStageInputs<'_, CommitmentInputs>, + programs: JoltProverPrograms, + transcript: &mut T, +) -> Result<(JoltProof, JoltProverArtifacts), JoltProveError> +where + CommitmentInputs: commitment_stage::CommitmentInputProvider, + T: Transcript, +{{ + let JoltProverStageInputs {{ + commitment_inputs, + prover_setup, + stage1_outer, + stage2, + stage3, + stage4, + stage5, + stage6, + stage7, + stage7_openings, + }} = inputs; + let mut stage1_outer_executor = stage1::Stage1ProverKernelExecutor::new(stage1_outer); + let mut stage2_executor = stage2::Stage2ProverKernelExecutor::new(stage2); + let mut stage3_executor = stage3::Stage3ProverKernelExecutor::new(stage3); + let mut stage4_executor = stage4::Stage4ProverKernelExecutor::new(stage4); + let mut stage5_executor = stage5::Stage5ProverKernelExecutor::new(stage5); + let mut stage6_executor = stage6::Stage6ProverKernelExecutor::new(stage6); + let mut stage7_executor = stage7::Stage7ProverKernelExecutor::new(stage7); + prove_jolt_with_programs( + JoltProverInputs {{ + commitment_inputs, + prover_setup, + stage1_outer_executor: &mut stage1_outer_executor, + stage2_executor: &mut stage2_executor, + stage3_executor: &mut stage3_executor, + stage4_executor: &mut stage4_executor, + stage5_executor: &mut stage5_executor, + stage6_executor: &mut stage6_executor, + stage7_executor: &mut stage7_executor, + stage7_openings, + }}, + programs, + transcript, + ) +}} + +pub struct JoltProverWitnessInputs<'a, CommitmentInputs> {{ + pub commitment_inputs: &'a mut CommitmentInputs, + pub prover_setup: &'a DoryProverSetup, + pub stage1_trace_num_vars: usize, + pub stage1_outer_evaluator: &'a dyn stage1::Stage1OuterRemainingEvaluator<{field_type}>, + pub stage2_openings: &'a [stage2::Stage2OpeningInputValue<{field_type}>], + pub product_virtual_cycles: &'a [stage2::Stage2ProductVirtualCycle], + pub instruction_lookup_cycles: &'a [stage2::Stage2InstructionLookupCycle], + pub ram: &'a stage2::Stage2RamData<'a>, + pub stage3_openings: &'a [stage3::Stage3OpeningInputValue<{field_type}>], + pub stage3_cycles: &'a [stage3::Stage3Cycle], + pub stage4_openings: &'a [stage4::Stage4OpeningInputValue<{field_type}>], + pub register_count: usize, + pub trace_len: usize, + pub ram_k: usize, + pub register_accesses: &'a [stage4::Stage4RegisterAccess], + pub stage5_openings: &'a [stage5::Stage5OpeningInputValue<{field_type}>], + pub lookup_indices: &'a [u128], + pub lookup_table_indices: &'a [Option], + pub is_interleaved_operands: &'a [bool], + pub ra_virtual_log_k_chunk: usize, + pub stage6_openings: &'a [stage6::Stage6OpeningInputValue<{field_type}>], + pub stage6_bytecode_data: stage6::Stage6BytecodeReadRafData<'a, {field_type}>, + pub stage6_witness_params: Stage6WitnessParams, + pub cycle_inputs: &'a [CycleInput], + pub instruction_ra_virtual_d: usize, + pub stage7_openings: &'a [stage7::Stage7OpeningInputValue<{field_type}>], + pub evaluation_openings: Option<&'a [stage7::Stage7OpeningInputValue<{field_type}>]>, +}} + +pub fn prove_jolt_with_witness_inputs( + inputs: JoltProverWitnessInputs<'_, CommitmentInputs>, + programs: JoltProverPrograms, + transcript: &mut T, +) -> Result<(JoltProof, JoltProverArtifacts), JoltProveError> +where + CommitmentInputs: commitment_stage::CommitmentInputProvider, + T: Transcript, +{{ + let _input_span = tracing::info_span!("bolt.prove.inputs").entered(); + let _stage1_input_span = tracing::info_span!("bolt.prove.inputs.stage1").entered(); + let stage1_outer = + stage1_outer_prover_inputs(inputs.stage1_trace_num_vars, inputs.stage1_outer_evaluator); + drop(_stage1_input_span); + let _stage2_input_span = tracing::info_span!("bolt.prove.inputs.stage2").entered(); + let stage2 = stage2_prover_inputs( + inputs.stage2_openings, + inputs.product_virtual_cycles, + inputs.instruction_lookup_cycles, + inputs.ram, + )?; + drop(_stage2_input_span); + let _stage3_input_span = tracing::info_span!("bolt.prove.inputs.stage3").entered(); + let stage3 = stage3_prover_inputs(inputs.stage3_openings, inputs.stage3_cycles); + drop(_stage3_input_span); + let _stage45_witness_span = tracing::info_span!("bolt.prove.inputs.stage45_witness").entered(); + let stage45_witness = stage4::stage4_5_sparse_trace_witness_from_accesses( + inputs.register_accesses, + inputs.ram.accesses, + ); + drop(_stage45_witness_span); + let _stage4_input_span = tracing::info_span!("bolt.prove.inputs.stage4").entered(); + let stage4 = stage4_prover_inputs( + inputs.stage4_openings, + inputs.register_count, + inputs.trace_len, + inputs.ram_k, + inputs.register_accesses, + &stage45_witness, + ); + drop(_stage4_input_span); + let _stage5_input_span = tracing::info_span!("bolt.prove.inputs.stage5").entered(); + let stage5 = stage5_prover_inputs( + inputs.stage5_openings, + inputs.trace_len, + inputs.ram_k, + inputs.register_count, + inputs.lookup_indices, + inputs.lookup_table_indices, + inputs.is_interleaved_operands, + inputs.ra_virtual_log_k_chunk, + &stage45_witness, + ); + drop(_stage5_input_span); + let _stage6_witness_span = tracing::info_span!("bolt.prove.inputs.stage6_witness").entered(); + let stage6_witness = stage6_witness_from_opening_inputs( + inputs.stage6_witness_params, + inputs.cycle_inputs, + inputs.stage6_openings, + ); + let stage6_witness_slices = stage6_witness.slices(); + drop(_stage6_witness_span); + let _stage6_input_span = tracing::info_span!("bolt.prove.inputs.stage6").entered(); + let stage6 = stage6_prover_inputs( + inputs.stage6_openings, + inputs.stage6_bytecode_data, + &stage6_witness, + &stage6_witness_slices, + inputs.instruction_ra_virtual_d, + ); + drop(_stage6_input_span); + let _stage7_input_span = tracing::info_span!("bolt.prove.inputs.stage7").entered(); + let stage7 = stage7_prover_inputs(inputs.stage7_openings, &stage6_witness_slices); + drop(_stage7_input_span); + drop(_input_span); + prove_jolt_with_stage_inputs( + JoltProverStageInputs {{ + commitment_inputs: inputs.commitment_inputs, + prover_setup: inputs.prover_setup, + stage1_outer, + stage2, + stage3, + stage4, + stage5, + stage6, + stage7, + stage7_openings: inputs.evaluation_openings, + }}, + programs, + transcript, + ) +}} + +pub fn stage1_outer_prover_inputs( + trace_num_vars: usize, + evaluator: &dyn stage1::Stage1OuterRemainingEvaluator<{field_type}>, +) -> stage1::Stage1ProverInputs<'_, {field_type}> {{ + stage1::Stage1ProverInputs::empty(trace_num_vars).with_outer_remaining_evaluator(evaluator) +}} + +pub fn prove_stage1_outer_inputs_with_program( + program: &'static stage1::Stage1CpuProgramPlan, + inputs: stage1::Stage1ProverInputs<'_, {field_type}>, + transcript: &mut T, +) -> Result, stage1::Stage1KernelError> +where + T: Transcript, +{{ + let mut executor = stage1::Stage1ProverKernelExecutor::new(inputs); + stage1_outer_stage::prove_stage1_outer_with_program(program, &mut executor, transcript) +}} + +pub fn prove_stage1_outer_with_witness_inputs( + program: &'static stage1::Stage1CpuProgramPlan, + trace_num_vars: usize, + evaluator: &dyn stage1::Stage1OuterRemainingEvaluator<{field_type}>, + transcript: &mut T, +) -> Result, stage1::Stage1KernelError> +where + T: Transcript, +{{ + let inputs = stage1_outer_prover_inputs(trace_num_vars, evaluator); + prove_stage1_outer_inputs_with_program(program, inputs, transcript) +}} + +pub fn replay_stage1_outer_proof_with_program( + program: &'static stage1::Stage1CpuProgramPlan, + proof: &stage1::Stage1Proof<{field_type}>, + transcript: &mut T, +) -> Result, stage1::Stage1KernelError> +where + T: Transcript, +{{ + let mut executor = stage1::Stage1VerifierKernelExecutor::new(proof); + stage1::execute_stage1_program( + program, + stage1::Stage1ExecutionMode::Verifier, + &mut executor, + transcript, + ) +}} + +pub fn stage1_outer_proof_from_kernel_proof( + proof: &stage1::Stage1Proof<{field_type}>, +) -> JoltStageProof {{ + JoltStageProof {{ + sumchecks: proof + .sumchecks + .iter() + .map(stage1_outer_sumcheck) + .collect(), + }} +}} + +pub fn stage2_prover_inputs<'a>( + opening_inputs: &'a [stage2::Stage2OpeningInputValue<{field_type}>], + product_virtual_cycles: &'a [stage2::Stage2ProductVirtualCycle], + instruction_lookup_cycles: &'a [stage2::Stage2InstructionLookupCycle], + ram: &'a stage2::Stage2RamData<'a>, +) -> Result, stage2::Stage2KernelError> {{ + Ok(stage2::Stage2ProverInputs::new(opening_inputs) + .with_product_virtual_witness(product_virtual_cycles)? + .with_instruction_lookup_cycles(instruction_lookup_cycles) + .with_ram_data(ram)) +}} + +pub struct JoltStage2RamDataStorage<'a> {{ + log_k: usize, + start_address: u64, + initial_ram: &'a [u64], + final_ram: &'a [u64], + accesses: Vec, + output_layout: Option, +}} + +impl<'a> JoltStage2RamDataStorage<'a> {{ + pub fn from_kernel(ram: &stage2::Stage2RamData<'a>) -> Self {{ + Self {{ + log_k: ram.log_k, + start_address: ram.start_address, + initial_ram: ram.initial_ram, + final_ram: ram.final_ram, + accesses: ram + .accesses + .iter() + .map(|access| JoltStage2RamAccess {{ + remapped_address: access.remapped_address, + read_value: access.read_value, + write_value: access.write_value, + }}) + .collect(), + output_layout: ram.output_layout.map(|layout| JoltStage2RamOutputLayout {{ + io_start: layout.io_start, + io_end: layout.io_end, + }}), + }} + }} + + pub fn as_input(&self) -> JoltStage2RamData<'_> {{ + JoltStage2RamData {{ + log_k: self.log_k, + start_address: self.start_address, + initial_ram: self.initial_ram, + final_ram: self.final_ram, + accesses: &self.accesses, + output_layout: self.output_layout, + }} + }} +}} + +pub fn stage2_verifier_ram_data<'a>( + ram: &stage2::Stage2RamData<'a>, +) -> JoltStage2RamDataStorage<'a> {{ + JoltStage2RamDataStorage::from_kernel(ram) +}} + +pub trait JoltKernelOpeningInput {{ + fn symbol(&self) -> &'static str; + fn point(&self) -> &[{field_type}]; + fn eval(&self) -> {field_type}; +}} + +macro_rules! impl_jolt_kernel_opening_input {{ + ($opening:ty) => {{ + impl JoltKernelOpeningInput for $opening {{ + fn symbol(&self) -> &'static str {{ + self.symbol + }} + + fn point(&self) -> &[{field_type}] {{ + &self.point + }} + + fn eval(&self) -> {field_type} {{ + self.eval + }} + }} + }}; +}} + +impl_jolt_kernel_opening_input!(stage2::Stage2OpeningInputValue<{field_type}>); +impl_jolt_kernel_opening_input!(stage3::Stage3OpeningInputValue<{field_type}>); +impl_jolt_kernel_opening_input!(stage4::Stage4OpeningInputValue<{field_type}>); + +pub fn verifier_opening_inputs_from_kernel(inputs: &[I]) -> Vec +where + I: JoltKernelOpeningInput, +{{ + inputs + .iter() + .map(|input| JoltStageOpeningInputValue {{ + symbol: input.symbol(), + point: input.point().to_vec(), + eval: input.eval(), + }}) + .collect() +}} + +pub fn prove_stage2_inputs_with_program( + program: &'static stage2::Stage2CpuProgramPlan, + inputs: stage2::Stage2ProverInputs<'_, {field_type}>, + transcript: &mut T, +) -> Result, stage2::Stage2KernelError> +where + T: Transcript, +{{ + let mut executor = stage2::Stage2ProverKernelExecutor::new(inputs); + stage2_stage::execute_stage2_prover_with_program(program, &mut executor, transcript) +}} + +pub fn prove_stage2_with_witness_inputs<'a, T>( + program: &'static stage2::Stage2CpuProgramPlan, + opening_inputs: &'a [stage2::Stage2OpeningInputValue<{field_type}>], + product_virtual_cycles: &'a [stage2::Stage2ProductVirtualCycle], + instruction_lookup_cycles: &'a [stage2::Stage2InstructionLookupCycle], + ram: &'a stage2::Stage2RamData<'a>, + transcript: &mut T, +) -> Result, stage2::Stage2KernelError> +where + T: Transcript, +{{ + let inputs = stage2_prover_inputs( + opening_inputs, + product_virtual_cycles, + instruction_lookup_cycles, + ram, + )?; + prove_stage2_inputs_with_program(program, inputs, transcript) +}} + +pub fn stage2_opening_inputs_from_artifacts( + program: &'static stage2::Stage2CpuProgramPlan, + stage1_artifacts: &stage1::Stage1ExecutionArtifacts<{field_type}>, +) -> Result>, JoltOpeningInputError> {{ + program + .opening_inputs + .iter() + .map(|input| {{ + let (point, eval) = match input.source_stage {{ + "stage1" => stage1_opening_claim(stage1_artifacts, input.source_claim)?, + source_stage => {{ + return Err(JoltOpeningInputError::UnsupportedOpeningInputSource {{ + stage: "stage2", + symbol: input.symbol, + source_stage, + }}); + }} + }}; + validate_point_len(input.symbol, input.point_arity, point.len())?; + Ok(stage2::Stage2OpeningInputValue {{ + symbol: input.symbol, + point, + eval, + }}) + }}) + .collect() +}} + +pub fn replay_stage2_proof_with_program<'a, T>( + program: &'static stage2::Stage2CpuProgramPlan, + proof: &'a stage2::Stage2Proof<{field_type}>, + opening_inputs: &'a [stage2::Stage2OpeningInputValue<{field_type}>], + ram: Option<&'a stage2::Stage2RamData<'a>>, + transcript: &mut T, +) -> Result, stage2::Stage2KernelError> +where + T: Transcript, +{{ + let mut executor = stage2::Stage2VerifierKernelExecutor::new(proof, opening_inputs); + if let Some(ram) = ram {{ + executor = executor.with_ram_data(ram); + }} + stage2::execute_stage2_program( + program, + stage2::Stage2ExecutionMode::Verifier, + &mut executor, + transcript, + ) +}} + +pub fn stage3_prover_inputs<'a>( + opening_inputs: &'a [stage3::Stage3OpeningInputValue<{field_type}>], + cycles: &'a [stage3::Stage3Cycle], +) -> stage3::Stage3ProverInputs<'a, {field_type}> {{ + stage3::Stage3ProverInputs::new(opening_inputs).with_cycles(cycles) +}} + +pub fn prove_stage3_inputs_with_program( + program: &'static stage3::Stage3CpuProgramPlan, + inputs: stage3::Stage3ProverInputs<'_, {field_type}>, + transcript: &mut T, +) -> Result, stage3::Stage3KernelError> +where + T: Transcript, +{{ + let mut executor = stage3::Stage3ProverKernelExecutor::new(inputs); + stage3_stage::execute_stage3_prover_with_program(program, &mut executor, transcript) +}} + +pub fn prove_stage3_with_witness_inputs( + program: &'static stage3::Stage3CpuProgramPlan, + opening_inputs: &[stage3::Stage3OpeningInputValue<{field_type}>], + cycles: &[stage3::Stage3Cycle], + transcript: &mut T, +) -> Result, stage3::Stage3KernelError> +where + T: Transcript, +{{ + let inputs = stage3_prover_inputs(opening_inputs, cycles); + prove_stage3_inputs_with_program(program, inputs, transcript) +}} + +pub fn stage3_opening_inputs_from_artifacts( + program: &'static stage3::Stage3CpuProgramPlan, + stage1_artifacts: &stage1::Stage1ExecutionArtifacts<{field_type}>, + stage2_artifacts: &stage2::Stage2ExecutionArtifacts<{field_type}>, +) -> Result>, JoltOpeningInputError> {{ + program + .opening_inputs + .iter() + .map(|input| {{ + let (point, eval) = match input.source_stage {{ + "stage1" => stage1_opening_claim(stage1_artifacts, input.source_claim)?, + "stage2" => stage2_opening_claim(stage2_artifacts, input.source_claim)?, + source_stage => {{ + return Err(JoltOpeningInputError::UnsupportedOpeningInputSource {{ + stage: "stage3", + symbol: input.symbol, + source_stage, + }}); + }} + }}; + validate_point_len(input.symbol, input.point_arity, point.len())?; + Ok(stage3::Stage3OpeningInputValue {{ + symbol: input.symbol, + point, + eval, + }}) + }}) + .collect() +}} + +pub fn replay_stage3_proof_with_program( + program: &'static stage3::Stage3CpuProgramPlan, + proof: &stage3::Stage3Proof<{field_type}>, + opening_inputs: &[stage3::Stage3OpeningInputValue<{field_type}>], + transcript: &mut T, +) -> Result, stage3::Stage3KernelError> +where + T: Transcript, +{{ + let mut executor = stage3::Stage3VerifierKernelExecutor::new(proof, opening_inputs); + stage3::execute_stage3_program( + program, + stage3::Stage3ExecutionMode::Verifier, + &mut executor, + transcript, + ) +}} + +pub fn stage4_prover_inputs<'a>( + opening_inputs: &'a [stage4::Stage4OpeningInputValue<{field_type}>], + register_count: usize, + trace_len: usize, + ram_k: usize, + register_accesses: &'a [stage4::Stage4RegisterAccess], + witness: &'a Stage45SparseTraceWitness<{field_type}>, +) -> stage4::Stage4ProverInputs<'a, {field_type}> {{ + stage4::Stage4ProverInputs::new(opening_inputs).with_stage45_sparse_trace_witness( + register_count, + trace_len, + ram_k, + register_accesses, + witness, + ) +}} + +pub fn prove_stage4_inputs_with_program( + program: &'static stage4::Stage4CpuProgramPlan, + inputs: stage4::Stage4ProverInputs<'_, {field_type}>, + transcript: &mut T, +) -> Result, stage4::Stage4KernelError> +where + T: Transcript, +{{ + let mut executor = stage4::Stage4ProverKernelExecutor::new(inputs); + stage4_stage::execute_stage4_prover_with_program(program, &mut executor, transcript) +}} + +pub fn prove_stage4_with_witness_inputs( + program: &'static stage4::Stage4CpuProgramPlan, + opening_inputs: &[stage4::Stage4OpeningInputValue<{field_type}>], + register_count: usize, + trace_len: usize, + ram_k: usize, + register_accesses: &[stage4::Stage4RegisterAccess], + witness: &Stage45SparseTraceWitness<{field_type}>, + transcript: &mut T, +) -> Result, stage4::Stage4KernelError> +where + T: Transcript, +{{ + let inputs = stage4_prover_inputs( + opening_inputs, + register_count, + trace_len, + ram_k, + register_accesses, + witness, + ); + prove_stage4_inputs_with_program(program, inputs, transcript) +}} + +pub fn prove_stage4_with_trace_witness_inputs( + program: &'static stage4::Stage4CpuProgramPlan, + opening_inputs: &[stage4::Stage4OpeningInputValue<{field_type}>], + register_count: usize, + trace_len: usize, + ram_k: usize, + register_accesses: &[stage4::Stage4RegisterAccess], + ram_accesses: &[stage2::Stage2RamAccess], + transcript: &mut T, +) -> Result, stage4::Stage4KernelError> +where + T: Transcript, +{{ + let witness = stage4::stage4_5_sparse_trace_witness_from_accesses( + register_accesses, + ram_accesses, + ); + prove_stage4_with_witness_inputs( + program, + opening_inputs, + register_count, + trace_len, + ram_k, + register_accesses, + &witness, + transcript, + ) +}} + +pub fn stage4_opening_inputs_from_artifacts( + program: &'static stage4::Stage4CpuProgramPlan, + initial_ram_state: &[u64], + stage2_artifacts: &stage2::Stage2ExecutionArtifacts<{field_type}>, + stage3_artifacts: &stage3::Stage3ExecutionArtifacts<{field_type}>, +) -> Result>, JoltOpeningInputError> {{ + program + .opening_inputs + .iter() + .map(|input| {{ + let (point, eval) = match input.source_stage {{ + "stage2" => stage2_opening_claim(stage2_artifacts, input.source_claim)?, + "stage3" => stage3_opening_claim(stage3_artifacts, input.source_claim)?, + "stage4_precomputed" => {{ + let (point, _) = stage2_opening_claim( + stage2_artifacts, + "stage2.ram_output.opening.RamValFinal", + )?; + stage4_ram_val_init_opening(initial_ram_state, &point) + }} + source_stage => {{ + return Err(JoltOpeningInputError::UnsupportedOpeningInputSource {{ + stage: "stage4", + symbol: input.symbol, + source_stage, + }}); + }} + }}; + opening_input_value(input.symbol, input.point_arity, point, eval) + }}) + .collect() +}} + +pub fn replay_stage4_proof_with_program( + program: &'static stage4::Stage4CpuProgramPlan, + proof: &stage4::Stage4Proof<{field_type}>, + opening_inputs: &[stage4::Stage4OpeningInputValue<{field_type}>], + transcript: &mut T, +) -> Result, stage4::Stage4KernelError> +where + T: Transcript, +{{ + let mut executor = stage4::Stage4VerifierKernelExecutor::new(proof, opening_inputs); + stage4::execute_stage4_program( + program, + stage4::Stage4ExecutionMode::Verifier, + &mut executor, + transcript, + ) +}} + +pub fn stage5_prover_inputs<'a>( + opening_inputs: &'a [stage5::Stage5OpeningInputValue<{field_type}>], + trace_len: usize, + ram_k: usize, + register_count: usize, + lookup_indices: &'a [u128], + lookup_table_indices: &'a [Option], + is_interleaved_operands: &'a [bool], + ra_virtual_log_k_chunk: usize, + witness: &'a Stage45SparseTraceWitness<{field_type}>, +) -> stage5::Stage5ProverInputs<'a, {field_type}> {{ + stage5::Stage5ProverInputs::new(opening_inputs).with_stage45_sparse_trace_witness( + trace_len, + ram_k, + register_count, + lookup_indices, + lookup_table_indices, + is_interleaved_operands, + ra_virtual_log_k_chunk, + witness, + ) +}} + +pub fn prove_stage5_inputs_with_program( + program: &'static stage5::Stage5CpuProgramPlan, + inputs: stage5::Stage5ProverInputs<'_, {field_type}>, + transcript: &mut T, +) -> Result, stage5::Stage5KernelError> +where + T: Transcript, +{{ + let mut executor = stage5::Stage5ProverKernelExecutor::new(inputs); + stage5_stage::execute_stage5_prover_with_program(program, &mut executor, transcript) +}} + +pub fn prove_stage5_with_witness_inputs( + program: &'static stage5::Stage5CpuProgramPlan, + opening_inputs: &[stage5::Stage5OpeningInputValue<{field_type}>], + trace_len: usize, + ram_k: usize, + register_count: usize, + lookup_indices: &[u128], + lookup_table_indices: &[Option], + is_interleaved_operands: &[bool], + ra_virtual_log_k_chunk: usize, + witness: &Stage45SparseTraceWitness<{field_type}>, + transcript: &mut T, +) -> Result, stage5::Stage5KernelError> +where + T: Transcript, +{{ + let inputs = stage5_prover_inputs( + opening_inputs, + trace_len, + ram_k, + register_count, + lookup_indices, + lookup_table_indices, + is_interleaved_operands, + ra_virtual_log_k_chunk, + witness, + ); + prove_stage5_inputs_with_program(program, inputs, transcript) +}} + +pub fn prove_stage5_with_trace_witness_inputs( + program: &'static stage5::Stage5CpuProgramPlan, + opening_inputs: &[stage5::Stage5OpeningInputValue<{field_type}>], + trace_len: usize, + ram_k: usize, + register_count: usize, + lookup_indices: &[u128], + lookup_table_indices: &[Option], + is_interleaved_operands: &[bool], + ra_virtual_log_k_chunk: usize, + register_accesses: &[stage4::Stage4RegisterAccess], + ram_accesses: &[stage2::Stage2RamAccess], + transcript: &mut T, +) -> Result, stage5::Stage5KernelError> +where + T: Transcript, +{{ + let witness = stage4::stage4_5_sparse_trace_witness_from_accesses( + register_accesses, + ram_accesses, + ); + prove_stage5_with_witness_inputs( + program, + opening_inputs, + trace_len, + ram_k, + register_count, + lookup_indices, + lookup_table_indices, + is_interleaved_operands, + ra_virtual_log_k_chunk, + &witness, + transcript, + ) +}} + +pub fn stage5_opening_inputs_from_artifacts( + program: &'static stage5::Stage5CpuProgramPlan, + stage2_artifacts: &stage2::Stage2ExecutionArtifacts<{field_type}>, + stage4_artifacts: &stage4::Stage4ExecutionArtifacts<{field_type}>, +) -> Result>, JoltOpeningInputError> {{ + program + .opening_inputs + .iter() + .map(|input| {{ + let (point, eval) = match input.source_stage {{ + "stage2" => stage2_opening_claim(stage2_artifacts, input.source_claim)?, + "stage4" => stage4_opening_claim(stage4_artifacts, input.source_claim)?, + source_stage => {{ + return Err(JoltOpeningInputError::UnsupportedOpeningInputSource {{ + stage: "stage5", + symbol: input.symbol, + source_stage, + }}); + }} + }}; + opening_input_value(input.symbol, input.point_arity, point, eval) + }}) + .collect() +}} + +pub fn stage5_kernel_proof( + artifacts: &stage5::Stage5ExecutionArtifacts<{field_type}>, +) -> stage5::Stage5Proof<{field_type}> {{ + stage5::Stage5Proof {{ + sumchecks: artifacts.sumchecks.clone(), + }} +}} + +pub fn jolt_proof_through_stage5( + commitments: &[Option], + stage1_artifacts: &stage1::Stage1ExecutionArtifacts<{field_type}>, + stage2_artifacts: &stage2::Stage2ExecutionArtifacts<{field_type}>, + stage3_artifacts: &stage3::Stage3ExecutionArtifacts<{field_type}>, + stage4_artifacts: &stage4::Stage4ExecutionArtifacts<{field_type}>, + stage5_proof: &JoltStageProof, +) -> JoltProof {{ + JoltProof {{ + commitments: commitments.to_vec(), + stage1_outer: stage1_outer_proof(stage1_artifacts), + stage2: stage2_proof(stage2_artifacts), + stage3: stage3_proof(stage3_artifacts), + stage4: stage4_proof(stage4_artifacts), + stage5: stage5_proof.clone(), + stage6: JoltStageProof::default(), + stage7: JoltStageProof::default(), + evaluation: None, + }} +}} + +pub fn jolt_proof_through_stage6( + commitments: &[Option], + stage1_artifacts: &stage1::Stage1ExecutionArtifacts<{field_type}>, + stage2_artifacts: &stage2::Stage2ExecutionArtifacts<{field_type}>, + stage3_artifacts: &stage3::Stage3ExecutionArtifacts<{field_type}>, + stage4_artifacts: &stage4::Stage4ExecutionArtifacts<{field_type}>, + stage5_proof: &JoltStageProof, + stage6_proof: &JoltStageProof, +) -> JoltProof {{ + let mut proof = jolt_proof_through_stage5( + commitments, + stage1_artifacts, + stage2_artifacts, + stage3_artifacts, + stage4_artifacts, + stage5_proof, + ); + proof.stage6 = stage6_proof.clone(); + proof +}} + +pub fn jolt_proof_through_stage7( + commitments: &[Option], + stage1_artifacts: &stage1::Stage1ExecutionArtifacts<{field_type}>, + stage2_artifacts: &stage2::Stage2ExecutionArtifacts<{field_type}>, + stage3_artifacts: &stage3::Stage3ExecutionArtifacts<{field_type}>, + stage4_artifacts: &stage4::Stage4ExecutionArtifacts<{field_type}>, + stage5_proof: &JoltStageProof, + stage6_proof: &JoltStageProof, + stage7_proof: &JoltStageProof, +) -> JoltProof {{ + let mut proof = jolt_proof_through_stage6( + commitments, + stage1_artifacts, + stage2_artifacts, + stage3_artifacts, + stage4_artifacts, + stage5_proof, + stage6_proof, + ); + proof.stage7 = stage7_proof.clone(); + proof +}} + +pub fn replay_stage5_proof_with_program( + program: &'static stage5::Stage5CpuProgramPlan, + proof: &stage5::Stage5Proof<{field_type}>, + opening_inputs: &[stage5::Stage5OpeningInputValue<{field_type}>], + transcript: &mut T, +) -> Result, stage5::Stage5KernelError> +where + T: Transcript, +{{ + let mut executor = stage5::Stage5ProofCarryingKernelExecutor::new(proof, opening_inputs); + stage5_stage::execute_stage5_prover_with_program(program, &mut executor, transcript) +}} + +pub fn stage6_witness_from_opening_inputs( + params: Stage6WitnessParams, + cycle_inputs: &[CycleInput], + opening_inputs: &[stage6::Stage6OpeningInputValue<{field_type}>], +) -> Stage6WitnessPolynomials<{field_type}> {{ + stage6::stage6_witness_from_opening_inputs(params, cycle_inputs, opening_inputs) +}} + +pub fn stage6_bytecode_read_raf_data_from_witness_entries( + entries: &[WitnessStage6BytecodeEntry<{field_type}>], + entry_bytecode_index: usize, + num_lookup_tables: usize, +) -> stage6::Stage6BytecodeReadRafDataStorage<{field_type}> {{ + stage6::Stage6BytecodeReadRafDataStorage::from_witness_entries( + entries, + entry_bytecode_index, + num_lookup_tables, + ) +}} + +pub fn stage6_verifier_data_from_witness_entries( + entries: &[WitnessStage6BytecodeEntry<{field_type}>], + entry_bytecode_index: usize, + num_lookup_tables: usize, +) -> JoltStage6VerifierData {{ + JoltStage6VerifierData {{ + bytecode_read_raf: Some(JoltStage6BytecodeReadRafData {{ + entries: entries + .iter() + .map(|entry| JoltStage6BytecodeEntry {{ + address: entry.address, + imm: entry.imm, + circuit_flags: entry.circuit_flags, + rd: entry.rd, + rs1: entry.rs1, + rs2: entry.rs2, + lookup_table: entry.lookup_table, + is_interleaved: entry.is_interleaved, + is_branch: entry.is_branch, + left_is_rs1: entry.left_is_rs1, + left_is_pc: entry.left_is_pc, + right_is_rs2: entry.right_is_rs2, + right_is_imm: entry.right_is_imm, + is_noop: entry.is_noop, + }}) + .collect(), + entry_bytecode_index, + num_lookup_tables, + }}), + }} +}} + +pub fn stage6_prover_inputs<'a>( + opening_inputs: &'a [stage6::Stage6OpeningInputValue<{field_type}>], + bytecode_data: stage6::Stage6BytecodeReadRafData<'a, {field_type}>, + witness: &'a Stage6WitnessPolynomials<{field_type}>, + slices: &'a Stage6WitnessSlices<'a, {field_type}>, + instruction_ra_virtual_d: usize, +) -> stage6::Stage6ProverInputs<'a, {field_type}> {{ + stage6::Stage6ProverInputs::new(opening_inputs).with_stage6_witness( + bytecode_data, + witness, + slices, + instruction_ra_virtual_d, + ) +}} + +pub fn prove_stage6_inputs_with_program( + program: &'static stage6::Stage6CpuProgramPlan, + inputs: stage6::Stage6ProverInputs<'_, {field_type}>, + transcript: &mut T, +) -> Result, stage6::Stage6KernelError> +where + T: Transcript, +{{ + let mut executor = stage6::Stage6ProverKernelExecutor::new(inputs); + stage6_stage::execute_stage6_prover_with_program(program, &mut executor, transcript) +}} + +pub fn prove_stage6_with_witness_inputs( + program: &'static stage6::Stage6CpuProgramPlan, + opening_inputs: &[stage6::Stage6OpeningInputValue<{field_type}>], + bytecode_data: stage6::Stage6BytecodeReadRafData<'_, {field_type}>, + witness: &Stage6WitnessPolynomials<{field_type}>, + slices: &Stage6WitnessSlices<'_, {field_type}>, + instruction_ra_virtual_d: usize, + transcript: &mut T, +) -> Result, stage6::Stage6KernelError> +where + T: Transcript, +{{ + let inputs = stage6_prover_inputs( + opening_inputs, + bytecode_data, + witness, + slices, + instruction_ra_virtual_d, + ); + prove_stage6_inputs_with_program(program, inputs, transcript) +}} + +pub fn prove_stage6_with_trace_witness_inputs( + program: &'static stage6::Stage6CpuProgramPlan, + opening_inputs: &[stage6::Stage6OpeningInputValue<{field_type}>], + bytecode_data: stage6::Stage6BytecodeReadRafData<'_, {field_type}>, + witness_params: Stage6WitnessParams, + cycle_inputs: &[CycleInput], + instruction_ra_virtual_d: usize, + transcript: &mut T, +) -> Result, stage6::Stage6KernelError> +where + T: Transcript, +{{ + let witness = stage6_witness_from_opening_inputs(witness_params, cycle_inputs, opening_inputs); + let slices = witness.slices(); + prove_stage6_with_witness_inputs( + program, + opening_inputs, + bytecode_data, + &witness, + &slices, + instruction_ra_virtual_d, + transcript, + ) +}} + +pub fn stage6_opening_inputs_from_artifacts( + program: &'static stage6::Stage6CpuProgramPlan, + stage1_artifacts: &stage1::Stage1ExecutionArtifacts<{field_type}>, + stage2_artifacts: &stage2::Stage2ExecutionArtifacts<{field_type}>, + stage3_artifacts: &stage3::Stage3ExecutionArtifacts<{field_type}>, + stage4_artifacts: &stage4::Stage4ExecutionArtifacts<{field_type}>, + stage5_artifacts: &stage5::Stage5ExecutionArtifacts<{field_type}>, +) -> Result>, JoltOpeningInputError> {{ + program + .opening_inputs + .iter() + .map(|input| {{ + let (point, eval) = match input.source_stage {{ + "stage1" => stage1_opening_claim(stage1_artifacts, input.source_claim)?, + "stage2" => stage2_opening_claim(stage2_artifacts, input.source_claim)?, + "stage3" => stage3_opening_claim(stage3_artifacts, input.source_claim)?, + "stage4" => stage4_opening_claim(stage4_artifacts, input.source_claim)?, + "stage5" => stage5_opening_claim(stage5_artifacts, input.source_claim)?, + source_stage => {{ + return Err(JoltOpeningInputError::UnsupportedOpeningInputSource {{ + stage: "stage6", + symbol: input.symbol, + source_stage, + }}); + }} + }}; + opening_input_value(input.symbol, input.point_arity, point, eval) + }}) + .collect() +}} + +pub fn stage6_kernel_proof(proof: &JoltStageProof) -> stage6::Stage6Proof<{field_type}> {{ + stage6::Stage6Proof {{ + sumchecks: proof + .sumchecks + .iter() + .map(stage6_kernel_sumcheck_output) + .collect(), + }} +}} + +fn stage6_kernel_sumcheck_output( + output: &JoltSumcheckOutput, +) -> stage6::Stage6SumcheckOutput<{field_type}> {{ + stage6::Stage6SumcheckOutput {{ + driver: output.driver, + point: output.point.clone(), + evals: output.evals.iter().map(stage6_kernel_eval).collect(), + opening_claims: Vec::new(), + proof: output.proof.clone(), + }} +}} + +fn stage6_kernel_eval(eval: &JoltNamedEval) -> stage6::Stage6NamedEval<{field_type}> {{ + stage6::Stage6NamedEval {{ + name: eval.name, + oracle: eval.oracle, + value: eval.value, + }} +}} + +pub fn stage6_execution_artifacts( + artifacts: &stage6::Stage6ExecutionArtifacts<{field_type}>, +) -> JoltStageExecutionArtifacts {{ + JoltStageExecutionArtifacts {{ + challenge_vectors: artifacts + .challenge_vectors + .iter() + .map(|challenge| JoltStageChallengeVector {{ + symbol: challenge.symbol, + values: challenge.values.clone(), + }}) + .collect(), + sumchecks: stage6_proof(artifacts).sumchecks, + opening_batches: Vec::new(), + }} +}} + +pub fn replay_stage6_proof_with_program<'a, T>( + program: &'static stage6::Stage6CpuProgramPlan, + proof: &'a stage6::Stage6Proof<{field_type}>, + opening_inputs: &'a [stage6::Stage6OpeningInputValue<{field_type}>], + bytecode_data: Option>, + transcript: &mut T, +) -> Result, stage6::Stage6KernelError> +where + T: Transcript, +{{ + let mut executor = stage6::Stage6ProofCarryingKernelExecutor::new(proof, opening_inputs); + if let Some(bytecode_data) = bytecode_data {{ + executor = executor.with_bytecode_read_raf_data(bytecode_data); + }} + stage6_stage::execute_stage6_prover_with_program(program, &mut executor, transcript) +}} + +pub fn stage7_prover_inputs<'a>( + opening_inputs: &'a [stage7::Stage7OpeningInputValue<{field_type}>], + slices: &'a Stage6WitnessSlices<'a, {field_type}>, +) -> stage7::Stage7ProverInputs<'a, {field_type}> {{ + stage7::Stage7ProverInputs::new(opening_inputs).with_stage6_witness_indices(slices) +}} + +pub fn prove_stage7_inputs_with_program( + program: &'static stage7::Stage7CpuProgramPlan, + inputs: stage7::Stage7ProverInputs<'_, {field_type}>, + transcript: &mut T, +) -> Result, stage7::Stage7KernelError> +where + T: Transcript, +{{ + let mut executor = stage7::Stage7ProverKernelExecutor::new(inputs); + stage7_stage::execute_stage7_prover_with_program(program, &mut executor, transcript) +}} + +pub fn prove_stage7_with_witness_inputs( + program: &'static stage7::Stage7CpuProgramPlan, + opening_inputs: &[stage7::Stage7OpeningInputValue<{field_type}>], + slices: &Stage6WitnessSlices<'_, {field_type}>, + transcript: &mut T, +) -> Result, stage7::Stage7KernelError> +where + T: Transcript, +{{ + let inputs = stage7_prover_inputs(opening_inputs, slices); + prove_stage7_inputs_with_program(program, inputs, transcript) +}} + +pub fn prove_stage7_with_trace_witness_inputs( + program: &'static stage7::Stage7CpuProgramPlan, + opening_inputs: &[stage7::Stage7OpeningInputValue<{field_type}>], + witness_params: Stage6WitnessParams, + cycle_inputs: &[CycleInput], + stage6_openings: &[stage6::Stage6OpeningInputValue<{field_type}>], + transcript: &mut T, +) -> Result, stage7::Stage7KernelError> +where + T: Transcript, +{{ + let witness = stage6_witness_from_opening_inputs(witness_params, cycle_inputs, stage6_openings); + let slices = witness.slices(); + prove_stage7_with_witness_inputs(program, opening_inputs, &slices, transcript) +}} + +pub fn stage7_kernel_proof(proof: &JoltStageProof) -> stage7::Stage7Proof<{field_type}> {{ + stage7::Stage7Proof {{ + sumchecks: proof + .sumchecks + .iter() + .map(stage7_kernel_sumcheck_output) + .collect(), + }} +}} + +fn stage7_kernel_sumcheck_output( + output: &JoltSumcheckOutput, +) -> stage7::Stage7SumcheckOutput<{field_type}> {{ + stage7::Stage7SumcheckOutput {{ + driver: output.driver, + point: output.point.clone(), + evals: output.evals.iter().map(stage7_kernel_eval).collect(), + opening_claims: Vec::new(), + proof: output.proof.clone(), + }} +}} + +fn stage7_kernel_eval(eval: &JoltNamedEval) -> stage7::Stage7NamedEval<{field_type}> {{ + stage7::Stage7NamedEval {{ + name: eval.name, + oracle: eval.oracle, + value: eval.value, + }} +}} + +pub fn stage7_execution_artifacts( + artifacts: &stage7::Stage7ExecutionArtifacts<{field_type}>, +) -> JoltStageExecutionArtifacts {{ + JoltStageExecutionArtifacts {{ + challenge_vectors: artifacts + .challenge_vectors + .iter() + .map(|challenge| JoltStageChallengeVector {{ + symbol: challenge.symbol, + values: challenge.values.clone(), + }}) + .collect(), + sumchecks: stage7_proof(artifacts).sumchecks, + opening_batches: Vec::new(), + }} +}} + +pub fn replay_stage7_proof_with_program( + program: &'static stage7::Stage7CpuProgramPlan, + proof: &stage7::Stage7Proof<{field_type}>, + opening_inputs: &[stage7::Stage7OpeningInputValue<{field_type}>], + transcript: &mut T, +) -> Result, stage7::Stage7KernelError> +where + T: Transcript, +{{ + let mut executor = stage7::Stage7ProofCarryingKernelExecutor::new(proof, opening_inputs); + stage7_stage::execute_stage7_prover_with_program(program, &mut executor, transcript) +}} + +pub fn stage7_opening_inputs_from_stage6_artifacts( + artifacts: &stage6::Stage6ExecutionArtifacts<{field_type}>, +) -> Result>, JoltOpeningInputError> {{ + stage7_opening_inputs_from_stage6_artifacts_with_program(&stage7_stage::STAGE7_PROGRAM, artifacts) +}} + +pub fn stage7_opening_inputs_from_stage6_artifacts_with_program( + program: &'static stage7::Stage7CpuProgramPlan, + artifacts: &stage6::Stage6ExecutionArtifacts<{field_type}>, +) -> Result>, JoltOpeningInputError> {{ + program + .opening_inputs + .iter() + .map(|input| {{ + let (point, eval) = stage6_opening_claim(artifacts, input.symbol, input.source_stage, input.source_claim, input.point_arity)?; + Ok(stage7::Stage7OpeningInputValue {{ + symbol: input.symbol, + point, + eval, + }}) + }}) + .collect() +}} + +fn stage6_opening_claim( + artifacts: &stage6::Stage6ExecutionArtifacts<{field_type}>, + symbol: &'static str, + source_stage: &'static str, + source_claim: &'static str, + point_arity: usize, +) -> Result<(Vec<{field_type}>, {field_type}), JoltOpeningInputError> {{ + if source_stage != "stage6" {{ + return Err(JoltOpeningInputError::UnsupportedStage7InputSource {{ + symbol, + source_stage, + }}); + }} + let opening = artifacts + .opening_claims + .iter() + .find(|opening| opening.symbol == source_claim) + .ok_or(JoltOpeningInputError::MissingStage6OpeningClaim {{ source_claim }})?; + if opening.point.len() != point_arity {{ + return Err(JoltOpeningInputError::InvalidPointLength {{ + symbol, + expected: point_arity, + actual: opening.point.len(), + }}); + }} + Ok((opening.point.clone(), opening.eval)) +}} + +fn opening_input_value( + symbol: &'static str, + point_arity: usize, + point: Vec<{field_type}>, + eval: {field_type}, +) -> Result, JoltOpeningInputError> {{ + validate_point_len(symbol, point_arity, point.len())?; + Ok(stage4::Stage4OpeningInputValue {{ + symbol, + point, + eval, + }}) +}} + +fn validate_point_len( + symbol: &'static str, + expected: usize, + actual: usize, +) -> Result<(), JoltOpeningInputError> {{ + if actual != expected {{ + return Err(JoltOpeningInputError::InvalidPointLength {{ + symbol, + expected, + actual, + }}); + }} + Ok(()) +}} + +fn stage1_opening_claim( + artifacts: &stage1::Stage1ExecutionArtifacts<{field_type}>, + source_claim: &'static str, +) -> Result<(Vec<{field_type}>, {field_type}), JoltOpeningInputError> {{ + let opening = artifacts.opening_value(source_claim).ok_or( + JoltOpeningInputError::MissingOpeningClaim {{ + stage: "stage1", + source_claim, + }}, + )?; + Ok((opening.point.clone(), opening.eval)) +}} + +fn stage2_opening_claim( + artifacts: &stage2::Stage2ExecutionArtifacts<{field_type}>, + source_claim: &'static str, +) -> Result<(Vec<{field_type}>, {field_type}), JoltOpeningInputError> {{ + artifacts + .opening_claims + .iter() + .find(|opening| opening.symbol == source_claim) + .map(|opening| (opening.point.clone(), opening.eval)) + .ok_or(JoltOpeningInputError::MissingOpeningClaim {{ + stage: "stage2", + source_claim, + }}) +}} + +fn stage3_opening_claim( + artifacts: &stage3::Stage3ExecutionArtifacts<{field_type}>, + source_claim: &'static str, +) -> Result<(Vec<{field_type}>, {field_type}), JoltOpeningInputError> {{ + artifacts + .opening_claims + .iter() + .find(|opening| opening.symbol == source_claim) + .map(|opening| (opening.point.clone(), opening.eval)) + .ok_or(JoltOpeningInputError::MissingOpeningClaim {{ + stage: "stage3", + source_claim, + }}) +}} + +fn stage4_opening_claim( + artifacts: &stage4::Stage4ExecutionArtifacts<{field_type}>, + source_claim: &'static str, +) -> Result<(Vec<{field_type}>, {field_type}), JoltOpeningInputError> {{ + artifacts + .opening_claims + .iter() + .find(|opening| opening.symbol == source_claim) + .map(|opening| (opening.point.clone(), opening.eval)) + .ok_or(JoltOpeningInputError::MissingOpeningClaim {{ + stage: "stage4", + source_claim, + }}) +}} + +fn stage5_opening_claim( + artifacts: &stage5::Stage5ExecutionArtifacts<{field_type}>, + source_claim: &'static str, +) -> Result<(Vec<{field_type}>, {field_type}), JoltOpeningInputError> {{ + artifacts + .opening_claims + .iter() + .find(|opening| opening.symbol == source_claim) + .map(|opening| (opening.point.clone(), opening.eval)) + .ok_or(JoltOpeningInputError::MissingOpeningClaim {{ + stage: "stage5", + source_claim, + }}) +}} + +"# + ) +} + +fn jolt_prover_evaluation_helpers(field_type: &str) -> String { + format!( + r#"pub fn prove_jolt_evaluation_proof( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + commitment_inputs: &mut I, + prover_setup: &DoryProverSetup, + commitments: &commitment_stage::CommitmentArtifacts, + stage6: &stage6::Stage6ExecutionArtifacts<{field_type}>, + stage7: &stage7::Stage7ExecutionArtifacts<{field_type}>, + stage7_openings: &[stage7::Stage7OpeningInputValue<{field_type}>], + transcript: &mut T, +) -> Result +where + I: commitment_stage::CommitmentInputProvider, + T: Transcript, +{{ + let _claims_span = tracing::info_span!("bolt.evaluate.claims").entered(); + let (sumcheck_address_point, stage7_values) = stage7_claim_values(program, stage7)?; + let address_point = reverse_point(&sumcheck_address_point); + let (opening_point, log_t) = + stage7_evaluation_opening_point(program, &address_point, stage7_openings)?; + let lagrange_factor = EqPolynomial::<{field_type}>::zero_selector(&address_point); + let claims = evaluation_claims(program, stage6, &stage7_values, lagrange_factor)?; + drop(_claims_span); + + let _rlc_span = tracing::info_span!("bolt.evaluate.rlc_claims").entered(); + append_rlc_claims(transcript, &claims); + let gamma_powers = gamma_powers(transcript, claims.len()); + let joint_claim = claims + .iter() + .zip(&gamma_powers) + .map(|(claim, gamma)| claim.value * *gamma) + .sum(); + drop(_rlc_span); + let _materialize_span = + tracing::info_span!("bolt.evaluate.materialize_joint_polynomial").entered(); + let joint_evals = materialize_joint_polynomial( + commitment_inputs, + &claims, + &gamma_powers, + log_t, + opening_point.len(), + )?; + drop(_materialize_span); + let joint_poly = Polynomial::new(joint_evals); + let _hint_span = tracing::info_span!("bolt.evaluate.joint_opening_hint").entered(); + let joint_hint = joint_opening_hint(commitments, &claims, &gamma_powers)?; + drop(_hint_span); + let _dory_open_span = tracing::info_span!("bolt.evaluate.dory_open").entered(); + let joint_opening_proof = ::open( + &joint_poly, + &opening_point, + joint_claim, + prover_setup, + Some(joint_hint), + transcript, + ); + drop(_dory_open_span); + let _bind_span = tracing::info_span!("bolt.evaluate.bind_opening_inputs").entered(); + ::bind_opening_inputs( + transcript, + &opening_point, + &joint_claim, + ); + drop(_bind_span); + Ok(JoltEvaluationProof {{ joint_opening_proof }}) +}} + +struct EvaluationClaim {{ + oracle: &'static str, + source_stage: &'static str, + value: {field_type}, +}} + +fn stage6_eval_claim( + artifacts: &stage6::Stage6ExecutionArtifacts<{field_type}>, + eval_name: &'static str, +) -> Result<{field_type}, JoltEvaluationProveError> {{ + for output in &artifacts.sumchecks {{ + if let Some(eval) = output.evals.iter().find(|eval| eval.name == eval_name) {{ + return Ok(eval.value); + }} + }} + Err(JoltEvaluationProveError::MissingStageEval {{ + stage: "stage6", + eval: eval_name, + }}) +}} + +fn evaluation_claims( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + stage6: &stage6::Stage6ExecutionArtifacts<{field_type}>, + stage7_values: &std::collections::BTreeMap<&'static str, {field_type}>, + lagrange_factor: {field_type}, +) -> Result, JoltEvaluationProveError> {{ + let mut claims = Vec::with_capacity(program.opening_claims.len()); + for plan in program.opening_claims {{ + let value = match plan.source_stage {{ + "stage6" => stage6_eval_claim(stage6, plan.source_claim)? * lagrange_factor, + "stage7" => *stage7_values.get(plan.source_claim).ok_or( + JoltEvaluationProveError::MissingStageEval {{ + stage: plan.source_stage, + eval: plan.source_claim, + }}, + )?, + _ => {{ + return Err(JoltEvaluationProveError::MissingStageEval {{ + stage: plan.source_stage, + eval: plan.source_claim, + }}); + }} + }}; + claims.push(EvaluationClaim {{ + oracle: plan.oracle, + source_stage: plan.source_stage, + value, + }}); + }} + Ok(claims) +}} + +fn stage7_claim_values( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + artifacts: &stage7::Stage7ExecutionArtifacts<{field_type}>, +) -> Result<(Vec<{field_type}>, std::collections::BTreeMap<&'static str, {field_type}>), JoltEvaluationProveError> {{ + let stage7_plans = program + .opening_claims + .iter() + .filter(|plan| plan.source_stage == "stage7") + .collect::>(); + for output in &artifacts.sumchecks {{ + let mut values = std::collections::BTreeMap::new(); + for plan in &stage7_plans {{ + if let Some(eval) = output.evals.iter().find(|eval| eval.name == plan.source_claim) {{ + let _ = values.insert(plan.source_claim, eval.value); + }} + }} + if values.len() == stage7_plans.len() {{ + return Ok((output.point.clone(), values)); + }} + }} + Err(JoltEvaluationProveError::MissingStage7RaEval) +}} + +fn reverse_point(point: &[{field_type}]) -> Vec<{field_type}> {{ + point.iter().rev().copied().collect() +}} + +fn stage7_evaluation_opening_point( + program: &'static stage8_stage::Stage8EvaluationProgramPlan, + address_point: &[{field_type}], + stage7_openings: &[stage7::Stage7OpeningInputValue<{field_type}>], +) -> Result<(Vec<{field_type}>, usize), JoltEvaluationProveError> {{ + let cycle_source_symbol = program.evaluation_point_source.source_claim; + let cycle_source = stage7_openings + .iter() + .find(|input| input.symbol == cycle_source_symbol) + .ok_or(JoltEvaluationProveError::MissingStage7EvaluationPoint)?; + if cycle_source.point.len() < address_point.len() {{ + return Err(JoltEvaluationProveError::InvalidPointLength {{ + artifact: cycle_source_symbol, + expected: address_point.len(), + actual: cycle_source.point.len(), + }}); + }} + let cycle_len = cycle_source.point.len() - address_point.len(); + let mut point = Vec::with_capacity(cycle_source.point.len()); + point.extend_from_slice(address_point); + point.extend_from_slice(&cycle_source.point[address_point.len()..]); + Ok((point, cycle_len)) +}} + +fn append_rlc_claims(transcript: &mut T, claims: &[EvaluationClaim]) +where + T: Transcript, +{{ + transcript.append(&LabelWithCount(b"rlc_claims", claims.len() as u64)); + for claim in claims {{ + claim.value.append_to_transcript(transcript); + }} +}} + +fn gamma_powers(transcript: &mut T, count: usize) -> Vec<{field_type}> +where + T: Transcript, +{{ + let gamma = transcript.challenge(); + let mut powers = Vec::with_capacity(count); + let mut power = {field_type}::from_u64(1); + for _ in 0..count {{ + powers.push(power); + power *= gamma; + }} + powers +}} + +fn materialize_joint_polynomial( + commitment_inputs: &mut I, + claims: &[EvaluationClaim], + gamma_powers: &[{field_type}], + log_t: usize, + main_num_vars: usize, +) -> Result, JoltEvaluationProveError> +where + I: commitment_stage::CommitmentInputProvider, +{{ + let trace_len = target_len(log_t)?; + let main_len = target_len(main_num_vars)?; + let mut joint = vec![{field_type}::from_u64(0); main_len]; + for (claim, gamma) in claims.iter().zip(gamma_powers) {{ + if claim.source_stage == "stage6" {{ + add_oracle_scaled(commitment_inputs, &mut joint, claim.oracle, log_t, trace_len, *gamma)?; + }} else {{ + add_oracle_scaled( + commitment_inputs, + &mut joint, + claim.oracle, + main_num_vars, + main_len, + *gamma, + )?; + }} + }} + Ok(joint) +}} + +fn add_oracle_scaled( + commitment_inputs: &mut I, + joint: &mut [{field_type}], + oracle: &'static str, + num_vars: usize, + limit: usize, + scalar: {field_type}, +) -> Result<(), JoltEvaluationProveError> +where + I: commitment_stage::CommitmentInputProvider, +{{ + if commitment_inputs.add_scaled_to_joint(oracle, joint, num_vars, limit, scalar) {{ + return Ok(()); + }} + let target_len = target_len(num_vars)?; + let data = commitment_inputs + .materialize_with_num_vars(oracle, num_vars) + .ok_or(JoltEvaluationProveError::MissingOracle {{ oracle }})?; + if data.len() > target_len {{ + return Err(JoltEvaluationProveError::InvalidPointLength {{ + artifact: oracle, + expected: target_len, + actual: data.len(), + }}); + }} + let zero = {field_type}::from_u64(0); + let one = {field_type}::from_u64(1); + let len = limit.min(joint.len()).min(data.len()); + if len >= 1 << 15 {{ + joint[..len] + .par_iter_mut() + .zip(data[..len].par_iter()) + .for_each(|(dst, value)| {{ + if *value == zero {{ + return; + }} + if *value == one {{ + *dst += scalar; + }} else {{ + *dst += *value * scalar; + }} + }}); + }} else {{ + for (dst, value) in joint.iter_mut().take(len).zip(data.iter()) {{ + if *value == zero {{ + continue; + }} + if *value == one {{ + *dst += scalar; + }} else {{ + *dst += *value * scalar; + }} + }} + }} + Ok(()) +}} + +fn joint_opening_hint( + commitments: &commitment_stage::CommitmentArtifacts, + claims: &[EvaluationClaim], + gamma_powers: &[{field_type}], +) -> Result {{ + let mut coefficients = std::collections::BTreeMap::<&'static str, {field_type}>::new(); + for (claim, gamma) in claims.iter().zip(gamma_powers) {{ + let coefficient = coefficients.entry(claim.oracle).or_insert({field_type}::from_u64(0)); + *coefficient += *gamma; + }} + + let mut hints = Vec::with_capacity(coefficients.len()); + let mut scalars = Vec::with_capacity(coefficients.len()); + for (oracle, coefficient) in coefficients {{ + hints.push(opening_hint_for_oracle(commitments, oracle)?); + scalars.push(coefficient); + }} + + Ok(::combine_hints( + hints, &scalars, + )) +}} + +fn opening_hint_for_oracle( + commitments: &commitment_stage::CommitmentArtifacts, + oracle: &'static str, +) -> Result {{ + commitments + .hints + .iter() + .find(|hint| hint.oracle == oracle) + .map(|hint| hint.hint.clone()) + .ok_or(JoltEvaluationProveError::MissingOpeningHint {{ oracle }}) +}} + +fn target_len(num_vars: usize) -> Result {{ + if num_vars >= usize::BITS as usize {{ + return Err(JoltEvaluationProveError::TargetSizeOverflow {{ num_vars }}); + }} + Ok(1usize << num_vars) +}} + +"# + ) +} + +const PROVER_FORBIDDEN_IMPORTS: &[&str] = &[ + "use jolt_core", + "jolt_core::", + "use jolt_verifier::stages", + "jolt_verifier::stages::", + "use jolt_equivalence", + "jolt_equivalence::", + "use jolt_profiling", + "jolt_profiling::", +]; + +const VERIFIER_FORBIDDEN_IMPORTS: &[&str] = &[ + "use jolt_kernels", + "jolt_kernels::", + "use jolt_prover", + "jolt_prover::", + "use jolt_core", + "jolt_core::", + "use jolt_equivalence", + "jolt_equivalence::", + "use jolt_profiling", + "jolt_profiling::", + "tracer::", +]; diff --git a/crates/bolt/src/protocols/jolt/emit/mod.rs b/crates/bolt/src/protocols/jolt/emit/mod.rs new file mode 100644 index 0000000000..0ad9e7d3d7 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/mod.rs @@ -0,0 +1 @@ +pub mod rust; diff --git a/crates/bolt/src/protocols/jolt/emit/rust/commitment.rs b/crates/bolt/src/protocols/jolt/emit/rust/commitment.rs new file mode 100644 index 0000000000..66186d2419 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/commitment.rs @@ -0,0 +1,1744 @@ +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{Attribute, OperationRef}; + +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CommitmentCpuProgram { + pub role: Role, + pub params: CommitmentParams, + pub oracle_plans: Vec, + pub batch_plans: Vec, + pub optional_plans: Vec, + pub transcript_steps: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CommitmentParams { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct OraclePlan { + pub oracle: String, + pub source: String, + pub domain: String, + pub num_vars: usize, + pub generation: OracleGeneration, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum OracleGeneration { + Reference, + DenseTrace { + padding: String, + }, + OneHotChunk { + trace_num_vars: usize, + chunk: usize, + num_chunks: usize, + chunk_bits: usize, + padding: String, + layout: String, + }, + OptionalAdvice { + skip_policy: OptionalSkipPolicy, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CommitmentBatchPlan { + pub artifact: String, + pub pcs: String, + pub oracle_family: String, + pub label: String, + pub oracles: Vec, + pub count: usize, + pub domain: String, + pub num_vars: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct OptionalCommitmentPlan { + pub artifact: String, + pub pcs: String, + pub oracle: String, + pub label: String, + pub domain: String, + pub num_vars: usize, + pub skip_policy: OptionalSkipPolicy, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum OptionalSkipPolicy { + MissingOrZero, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TranscriptStep { + pub label: String, + pub source: String, + pub optional: bool, +} + +pub fn emit_commitment_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = commitment_cpu_program(module)?; + + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source()?, + }) +} + +pub fn commitment_cpu_program( + module: &BoltModule<'_, Cpu>, +) -> Result { + verify_cpu_schema(module)?; + let program = CommitmentCpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +impl CommitmentCpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let mut params = None; + let mut oracle_plans = Vec::new(); + let mut batch_plans = Vec::new(); + let mut optional_plans = Vec::new(); + let mut transcript_steps = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(CommitmentParams { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.oracle_dense_trace" => { + oracle_plans.push(OraclePlan { + oracle: symbol_attr(op, "oracle")?, + source: symbol_attr(op, "source")?, + domain: symbol_attr(op, "domain")?, + num_vars: int_attr(op, "num_vars")?, + generation: OracleGeneration::DenseTrace { + padding: string_attr(op, "padding")?, + }, + }); + } + "cpu.oracle_one_hot_chunk" => { + oracle_plans.push(OraclePlan { + oracle: symbol_attr(op, "oracle")?, + source: symbol_attr(op, "source")?, + domain: symbol_attr(op, "domain")?, + num_vars: int_attr(op, "num_vars")?, + generation: OracleGeneration::OneHotChunk { + trace_num_vars: int_attr(op, "trace_num_vars")?, + chunk: int_attr(op, "chunk")?, + num_chunks: int_attr(op, "num_chunks")?, + chunk_bits: int_attr(op, "chunk_bits")?, + padding: string_attr(op, "padding")?, + layout: string_attr(op, "layout")?, + }, + }); + } + "cpu.oracle_optional_advice" => { + oracle_plans.push(OraclePlan { + oracle: symbol_attr(op, "oracle")?, + source: symbol_attr(op, "source")?, + domain: symbol_attr(op, "domain")?, + num_vars: int_attr(op, "num_vars")?, + generation: OracleGeneration::OptionalAdvice { + skip_policy: skip_policy_attr(op, "skip_policy")?, + }, + }); + } + "cpu.oracle_ref" => { + oracle_plans.push(OraclePlan { + oracle: symbol_attr(op, "oracle")?, + source: String::new(), + domain: symbol_attr(op, "domain")?, + num_vars: int_attr(op, "num_vars")?, + generation: OracleGeneration::Reference, + }); + } + "cpu.pcs_commit_batch" | "cpu.pcs_receive_batch" => { + batch_plans.push(CommitmentBatchPlan { + artifact: symbol_attr(op, "artifact")?, + pcs: symbol_attr(op, "pcs")?, + oracle_family: symbol_attr(op, "oracle_family")?, + label: string_attr(op, "label")?, + oracles: symbol_array_attr(op, "ordered_oracles")?, + count: int_attr(op, "count")?, + domain: symbol_attr(op, "domain")?, + num_vars: int_attr(op, "num_vars")?, + }); + } + "cpu.pcs_commit_optional" | "cpu.pcs_receive_optional" => { + optional_plans.push(OptionalCommitmentPlan { + artifact: symbol_attr(op, "artifact")?, + pcs: symbol_attr(op, "pcs")?, + oracle: symbol_attr(op, "oracle")?, + label: string_attr(op, "label")?, + domain: symbol_attr(op, "domain")?, + num_vars: int_attr(op, "num_vars")?, + skip_policy: skip_policy_attr(op, "skip_policy")?, + }); + } + "cpu.transcript_absorb" => { + transcript_steps.push(TranscriptStep { + label: string_attr(op, "label")?, + source: transcript_artifact_source(op)?, + optional: bool_attr(op, "optional")?, + }); + } + _ => {} + } + } + + Ok(Self { + params: params.ok_or_else(|| EmitError::new("missing cpu.params"))?, + role: module + .role() + .ok_or_else(|| EmitError::new("missing cpu party role"))?, + oracle_plans, + batch_plans, + optional_plans, + transcript_steps, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + require_supported_symbol("field", &self.params.field, "bn254_fr")?; + require_supported_symbol("pcs", &self.params.pcs, "dory")?; + require_supported_symbol("transcript", &self.params.transcript, "blake2b_transcript")?; + for plan in &self.batch_plans { + require_supported_symbol("batch pcs", &plan.pcs, "dory")?; + } + for plan in &self.optional_plans { + require_supported_symbol("optional pcs", &plan.pcs, "dory")?; + } + Ok(()) + } + + fn emit_source(&self) -> Result { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + source.push_str(self.emit_imports()); + source.push_str("\n\n"); + source.push_str(&self.emit_types()?); + source.push('\n'); + source.push_str(&self.emit_constants()); + source.push('\n'); + source.push_str(self.emit_entrypoint()); + Ok(source) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_commitment_phase.rs", + Role::Verifier => "verify_commitment_phase.rs", + } + } + + fn emit_imports(&self) -> &'static str { + match self.role { + Role::Prover => { + "use std::borrow::Cow;\n\ + \n\ + use jolt_dory::{DoryCommitment, DoryHint, DoryProverSetup, DoryScheme};\n\ + use jolt_field::{Field, Fr};\n\ + use jolt_openings::CommitmentScheme as _;\n\ + use jolt_poly::{EqPolynomial, MultilinearPoly};\n\ + use jolt_transcript::{AppendToTranscript, Blake2bTranscript, LabelWithCount, Transcript};\n\ + use jolt_witness::{dense_i128_column_to_field, one_hot_chunk_address_major, one_hot_chunk_indices, optional_field_oracle, CommitmentTraceSources};\n\ + use rayon::prelude::*;" + } + Role::Verifier => { + "use jolt_dory::DoryCommitment;\n\ + use jolt_field::Fr;\n\ + use jolt_transcript::{AppendToTranscript, Blake2bTranscript, LabelWithCount, Transcript};" + } + } + } + + fn emit_types(&self) -> Result { + match self.role { + Role::Prover => { + let mut types = Self::emit_prover_types().to_owned(); + types.push('\n'); + types.push_str(&self.emit_oracle_store_types()?); + Ok(types) + } + Role::Verifier => Ok(Self::emit_verifier_types().to_owned()), + } + } + + fn emit_prover_types() -> &'static str { + r"pub type DefaultCommitmentTranscript = Blake2bTranscript; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CommitmentParams { + pub field: &'static str, + pub pcs: &'static str, + pub transcript: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OraclePlan { + pub oracle: &'static str, + pub domain: &'static str, + pub num_vars: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CommitmentBatchPlan { + pub artifact: &'static str, + pub pcs: &'static str, + pub oracle_family: &'static str, + pub label: &'static str, + pub oracles: &'static [&'static str], + pub count: usize, + pub domain: &'static str, + pub num_vars: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum OptionalSkipPolicy { + MissingOrZero, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OptionalCommitmentPlan { + pub artifact: &'static str, + pub pcs: &'static str, + pub oracle: &'static str, + pub label: &'static str, + pub domain: &'static str, + pub num_vars: usize, + pub skip_policy: OptionalSkipPolicy, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct TranscriptStep { + pub label: &'static str, + pub source: &'static str, + pub optional: bool, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CommitmentProverProgramPlan { + pub params: CommitmentParams, + pub oracle_plans: &'static [OraclePlan], + pub batch_plans: &'static [CommitmentBatchPlan], + pub optional_plans: &'static [OptionalCommitmentPlan], + pub transcript_steps: &'static [TranscriptStep], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CommitmentRecord { + pub artifact: &'static str, + pub oracle: &'static str, + pub label: &'static str, + pub num_vars: usize, +} + +#[derive(Clone, Debug)] +pub struct OracleOpeningHint { + pub oracle: &'static str, + pub hint: DoryHint, +} + +#[derive(Clone, Debug)] +pub struct CommittedOracle { + pub commitment: Option, + pub record: CommitmentRecord, + pub hint: Option, +} + +#[derive(Clone, Debug, Default)] +pub struct CommitmentArtifacts { + pub commitments: Vec>, + pub records: Vec, + pub hints: Vec, +} + +pub trait CommitmentInputProvider { + fn materialize(&mut self, oracle: &'static str) -> Option>; + + fn materialize_with_num_vars( + &mut self, + oracle: &'static str, + _num_vars: usize, + ) -> Option> { + self.materialize(oracle) + } + + fn commit_batch( + &mut self, + _program: &CommitmentProverProgramPlan, + _plan: &CommitmentBatchPlan, + _prover_setup: &DoryProverSetup, + ) -> Option, CommitmentPhaseError>> { + None + } + + fn add_scaled_to_joint( + &mut self, + _oracle: &'static str, + _joint: &mut [Fr], + _num_vars: usize, + _limit: usize, + _scalar: Fr, + ) -> bool { + false + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum CommitmentPhaseError { + MissingOracle { oracle: &'static str }, + MissingTranscriptSource { source: &'static str }, + PlanCountMismatch { artifact: &'static str, expected: usize, actual: usize }, + OracleTooLarge { oracle: &'static str, len: usize, target_len: usize }, + TargetSizeOverflow { num_vars: usize }, +}" + } + + fn emit_oracle_store_types(&self) -> Result { + let input_type = r" +pub struct CommitmentOracleInputs<'a> { + pub rd_inc: &'a [i128], + pub ram_inc: &'a [i128], + pub instruction_keys: &'a [Option], + pub ram_addresses: &'a [Option], + pub bytecode_indices: &'a [Option], + pub untrusted_advice: Option<&'a [Fr]>, + pub trusted_advice: Option<&'a [Fr]>, +} + +impl<'a> CommitmentOracleInputs<'a> { + pub fn from_trace_sources( + sources: &'a CommitmentTraceSources, + untrusted_advice: Option<&'a [Fr]>, + trusted_advice: Option<&'a [Fr]>, + ) -> Self { + Self { + rd_inc: &sources.rd_inc, + ram_inc: &sources.ram_inc, + instruction_keys: &sources.instruction_keys, + ram_addresses: &sources.ram_addresses, + bytecode_indices: &sources.bytecode_indices, + untrusted_advice, + trusted_advice, + } + } +} +"; + let sparse_provider = r#" +struct AddressMajorOneHotPolynomial { + trace_len: usize, + chunk_domain: usize, + indices: Vec>, + num_vars: usize, +} + +impl AddressMajorOneHotPolynomial { + fn new( + trace_len: usize, + chunk_domain: usize, + indices: Vec>, + num_vars: usize, + ) -> Result { + let active_len = trace_len + .checked_mul(chunk_domain) + .ok_or(CommitmentPhaseError::TargetSizeOverflow { num_vars })?; + let target_len = target_len(num_vars)?; + if active_len > target_len { + return Err(CommitmentPhaseError::OracleTooLarge { + oracle: "one_hot", + len: active_len, + target_len, + }); + } + Ok(Self { + trace_len, + chunk_domain, + indices, + num_vars, + }) + } + + fn nonzero_flat_indices(&self) -> impl Iterator + '_ { + self.indices + .iter() + .enumerate() + .filter_map(|(cycle, &index)| { + index.map(|index| { + let index = index as usize; + assert!( + index < self.chunk_domain, + "one-hot index {index} exceeds domain {}", + self.chunk_domain + ); + index * self.trace_len + cycle + }) + }) + } +} + +impl MultilinearPoly for AddressMajorOneHotPolynomial { + fn num_vars(&self) -> usize { + self.num_vars + } + + fn evaluate(&self, point: &[Fr]) -> Fr { + assert_eq!(point.len(), self.num_vars); + let eq_evals = EqPolynomial::new(point.to_vec()).evaluations(); + self.nonzero_flat_indices() + .fold(Fr::from_u64(0), |acc, flat| acc + eq_evals[flat]) + } + + fn for_each_row(&self, sigma: usize, f: &mut dyn FnMut(usize, &[Fr])) { + let num_cols = 1usize << sigma; + let num_rows = 1usize << (self.num_vars - sigma); + let mut entries = Vec::with_capacity(self.indices.len()); + for flat in self.nonzero_flat_indices() { + entries.push((flat / num_cols, flat % num_cols)); + } + entries.sort_unstable_by_key(|(row, _)| *row); + + let mut cursor = 0; + let mut row = vec![Fr::from_u64(0); num_cols]; + for row_index in 0..num_rows { + row.fill(Fr::from_u64(0)); + while cursor < entries.len() && entries[cursor].0 == row_index { + row[entries[cursor].1] = Fr::from_u64(1); + cursor += 1; + } + f(row_index, &row); + } + } + + fn fold_rows(&self, left: &[Fr], sigma: usize) -> Vec { + let num_cols = 1usize << sigma; + let num_rows = 1usize << (self.num_vars - sigma); + assert_eq!(left.len(), num_rows); + let mut result = vec![Fr::from_u64(0); num_cols]; + for flat in self.nonzero_flat_indices() { + result[flat % num_cols] += left[flat / num_cols]; + } + result + } + + fn is_sparse(&self) -> bool { + true + } + + fn for_each_nonzero(&self, f: &mut dyn FnMut(usize, Fr)) { + for flat in self.nonzero_flat_indices() { + f(flat, Fr::from_u64(1)); + } + } +} + +pub struct SparseCommitmentInputs<'a> { + pub inputs: CommitmentOracleInputs<'a>, + cache: std::collections::BTreeMap<(&'static str, usize), Option>>, + chunk_counts: OneHotChunkCounts, +} + +impl<'a> SparseCommitmentInputs<'a> { + pub fn new(inputs: CommitmentOracleInputs<'a>) -> Self { + Self { + inputs, + cache: std::collections::BTreeMap::new(), + chunk_counts: OneHotChunkCounts::default(), + } + } + + fn update_chunk_counts(&mut self, program: &CommitmentProverProgramPlan) { + let mut counts = OneHotChunkCounts::default(); + let mut instruction = 0; + let mut ram = 0; + let mut bytecode = 0; + for plan in program.oracle_plans { + if plan.oracle.strip_prefix("InstructionRa_").is_some() { + instruction += 1; + } else if plan.oracle.strip_prefix("RamRa_").is_some() { + ram += 1; + } else if plan.oracle.strip_prefix("BytecodeRa_").is_some() { + bytecode += 1; + } + } + if instruction > 0 { + counts.instruction = instruction; + } + if ram > 0 { + counts.ram = ram; + } + if bytecode > 0 { + counts.bytecode = bytecode; + } + self.chunk_counts = counts; + } + + fn one_hot_spec(&self, oracle: &'static str) -> Option { + let (prefix, num_chunks, values, padding) = + if let Some(suffix) = oracle.strip_prefix("InstructionRa_") { + ( + suffix, + self.chunk_counts.instruction, + OneHotSource::InstructionKeys, + Some(0), + ) + } else if let Some(suffix) = oracle.strip_prefix("RamRa_") { + ( + suffix, + self.chunk_counts.ram, + OneHotSource::RamAddresses, + None, + ) + } else if let Some(suffix) = oracle.strip_prefix("BytecodeRa_") { + ( + suffix, + self.chunk_counts.bytecode, + OneHotSource::BytecodeIndices, + Some(0), + ) + } else { + return None; + }; + let chunk = prefix.parse::().ok()?; + if chunk >= num_chunks { + return None; + } + Some(OneHotSpec { + source: values, + chunk, + num_chunks, + chunk_bits: 4, + padding, + }) + } + + fn source_values(&self, source: OneHotSource) -> &'a [Option] { + match source { + OneHotSource::InstructionKeys => self.inputs.instruction_keys, + OneHotSource::RamAddresses => self.inputs.ram_addresses, + OneHotSource::BytecodeIndices => self.inputs.bytecode_indices, + } + } + + fn one_hot_indices( + &self, + oracle: &'static str, + trace_len: usize, + ) -> Option>> { + let spec = self.one_hot_spec(oracle)?; + let values = self.source_values(spec.source); + Some(one_hot_chunk_indices( + values, + spec.chunk, + spec.num_chunks, + spec.chunk_bits, + trace_len, + spec.padding, + )) + } + + #[expect( + clippy::option_option, + reason = "distinguishes missing oracle from present optional oracle" + )] + fn materialize_oracle( + &self, + oracle: &'static str, + num_vars: usize, + ) -> Option>> { + let materialized = match oracle { + "RdInc" => Some(dense_i128_column_to_field( + self.inputs.rd_inc, + target_len(num_vars).ok()?, + )), + "RamInc" => Some(dense_i128_column_to_field( + self.inputs.ram_inc, + target_len(num_vars).ok()?, + )), + "UntrustedAdvice" => optional_field_oracle( + self.inputs.untrusted_advice, + target_len(num_vars).ok()?, + ), + "TrustedAdvice" => { + optional_field_oracle(self.inputs.trusted_advice, target_len(num_vars).ok()?) + } + _ => { + let spec = self.one_hot_spec(oracle)?; + let trace_len = target_len(num_vars.checked_sub(spec.chunk_bits)?).ok()?; + let values = self.source_values(spec.source); + Some(one_hot_chunk_address_major( + values, + spec.chunk, + spec.num_chunks, + spec.chunk_bits, + trace_len, + spec.padding, + )) + } + }; + Some(materialized) + } + + fn commit_oracle( + &self, + program: &CommitmentProverProgramPlan, + oracle: &'static str, + layout_num_vars: usize, + prover_setup: &DoryProverSetup, + ) -> Result<(DoryCommitment, DoryHint), CommitmentPhaseError> { + let oracle_num_vars = oracle_num_vars(program, oracle, layout_num_vars); + if let Some(spec) = self.one_hot_spec(oracle) { + let trace_len = target_len(oracle_num_vars - spec.chunk_bits)?; + let chunk_domain = target_len(spec.chunk_bits)?; + let indices = self + .one_hot_indices(oracle, trace_len) + .ok_or(CommitmentPhaseError::MissingOracle { oracle })?; + let poly = AddressMajorOneHotPolynomial::new( + trace_len, + chunk_domain, + indices, + layout_num_vars, + )?; + let _dory_commit_span = tracing::info_span!("bolt.commitment.dory_commit").entered(); + Ok(DoryScheme::commit(&poly, prover_setup)) + } else { + let data = self + .materialize_oracle(oracle, oracle_num_vars) + .flatten() + .ok_or(CommitmentPhaseError::MissingOracle { oracle })?; + let data = into_padded_oracle(oracle, oracle_num_vars, Cow::Owned(data))?; + commit_with_layout(&data, layout_num_vars, prover_setup) + } + } +} + +impl CommitmentInputProvider for SparseCommitmentInputs<'_> { + fn materialize(&mut self, oracle: &'static str) -> Option> { + let num_vars = match oracle { + "RdInc" | "RamInc" | "UntrustedAdvice" | "TrustedAdvice" => 16, + _ if self.one_hot_spec(oracle).is_some() => 20, + _ => return None, + }; + self.materialize_with_num_vars(oracle, num_vars) + } + + fn materialize_with_num_vars( + &mut self, + oracle: &'static str, + num_vars: usize, + ) -> Option> { + if !self.cache.contains_key(&(oracle, num_vars)) { + let materialized = self.materialize_oracle(oracle, num_vars).flatten(); + let _ = self.cache.insert((oracle, num_vars), materialized); + } + self.cache + .get(&(oracle, num_vars)) + .and_then(|values| values.as_ref()) + .map(|values| Cow::Borrowed(values.as_slice())) + } + + fn commit_batch( + &mut self, + program: &CommitmentProverProgramPlan, + plan: &CommitmentBatchPlan, + prover_setup: &DoryProverSetup, + ) -> Option, CommitmentPhaseError>> { + self.update_chunk_counts(program); + Some( + plan.oracles + .par_iter() + .map(|&oracle| { + let oracle_num_vars = oracle_num_vars(program, oracle, plan.num_vars); + let (commitment, hint) = + self.commit_oracle(program, oracle, plan.num_vars, prover_setup)?; + Ok(CommittedOracle { + commitment: Some(commitment), + record: CommitmentRecord { + artifact: plan.artifact, + oracle, + label: plan.label, + num_vars: oracle_num_vars, + }, + hint: Some(OracleOpeningHint { oracle, hint }), + }) + }) + .collect(), + ) + } + + fn add_scaled_to_joint( + &mut self, + oracle: &'static str, + joint: &mut [Fr], + num_vars: usize, + limit: usize, + scalar: Fr, + ) -> bool { + let dense = match oracle { + "RdInc" => Some(self.inputs.rd_inc), + "RamInc" => Some(self.inputs.ram_inc), + _ => None, + }; + if let Some(values) = dense { + let Ok(target_len) = target_len(num_vars) else { + return false; + }; + let len = limit.min(joint.len()).min(values.len()).min(target_len); + for (dst, &value) in joint.iter_mut().take(len).zip(values.iter()) { + if value != 0 { + *dst += Fr::from_i128(value) * scalar; + } + } + return true; + } + + let Some(spec) = self.one_hot_spec(oracle) else { + return false; + }; + let Some(trace_num_vars) = num_vars.checked_sub(spec.chunk_bits) else { + return false; + }; + let Ok(trace_len) = target_len(trace_num_vars) else { + return false; + }; + let Ok(chunk_domain) = target_len(spec.chunk_bits) else { + return false; + }; + let Some(active_len) = trace_len.checked_mul(chunk_domain) else { + return false; + }; + let max_flat = limit.min(joint.len()).min(active_len); + let Some(indices) = self.one_hot_indices(oracle, trace_len) else { + return false; + }; + for (cycle, index) in indices.into_iter().enumerate() { + let Some(index) = index else { + continue; + }; + let flat = index as usize * trace_len + cycle; + if flat < max_flat { + joint[flat] += scalar; + } + } + true + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum OneHotSource { + InstructionKeys, + RamAddresses, + BytecodeIndices, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct OneHotSpec { + source: OneHotSource, + chunk: usize, + num_chunks: usize, + chunk_bits: usize, + padding: Option, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +struct OneHotChunkCounts { + instruction: usize, + ram: usize, + bytecode: usize, +} + +impl Default for OneHotChunkCounts { + fn default() -> Self { + Self { + instruction: 32, + ram: 4, + bytecode: 3, + } + } +} +"#; + let mut fields = Vec::new(); + let mut provider_arms = Vec::new(); + let mut initializers = Vec::new(); + for plan in &self.oracle_plans { + match &plan.generation { + OracleGeneration::Reference => {} + OracleGeneration::OptionalAdvice { .. } => { + let field = rust_field_name(&plan.oracle); + fields.push(format!(" pub {field}: Option>,")); + provider_arms.push(format!( + " {} => self.{field}.as_deref().map(Cow::Borrowed),", + rust_str(&plan.oracle) + )); + initializers.push(format!( + " {field}: {},", + Self::oracle_initializer(plan)? + )); + } + _ => { + let field = rust_field_name(&plan.oracle); + fields.push(format!(" pub {field}: Vec,")); + provider_arms.push(format!( + " {} => Some(Cow::Borrowed(&self.{field})),", + rust_str(&plan.oracle) + )); + initializers.push(format!( + " {field}: {},", + Self::oracle_initializer(plan)? + )); + } + } + } + let fields = fields.join("\n"); + let provider_arms = provider_arms.join("\n"); + let initializers = initializers.join("\n"); + + Ok(format!( + "{input_type} +{sparse_provider} +#[derive(Clone, Debug, Default)] +pub struct CommitmentOracles {{ +{fields} +}} + +impl CommitmentInputProvider for CommitmentOracles {{ + fn materialize(&mut self, oracle: &'static str) -> Option> {{ + match oracle {{ +{provider_arms} + _ => None, + }} + }} +}} + +pub fn build_commitment_oracles( + inputs: &CommitmentOracleInputs<'_>, +) -> Result {{ + Ok(CommitmentOracles {{ +{initializers} + }}) +}} +" + )) + } + + fn oracle_initializer(plan: &OraclePlan) -> Result { + match &plan.generation { + OracleGeneration::Reference => Err(EmitError::new(format!( + "reference oracle @{} has no prover initializer", + plan.oracle + ))), + OracleGeneration::DenseTrace { .. } => Ok(format!( + "dense_i128_column_to_field(inputs.{}, target_len({})?)", + rust_input_field(&plan.source)?, + plan.num_vars + )), + OracleGeneration::OneHotChunk { + trace_num_vars, + chunk, + num_chunks, + chunk_bits, + padding, + .. + } => Ok(format!( + "one_hot_chunk_address_major(inputs.{}, {chunk}, {num_chunks}, {chunk_bits}, target_len({trace_num_vars})?, {})", + rust_input_field(&plan.source)?, + rust_padding_value(padding)? + )), + OracleGeneration::OptionalAdvice { .. } => Ok(format!( + "optional_field_oracle(inputs.{}, target_len({})?)", + rust_input_field(&plan.source)?, + plan.num_vars + )), + } + } + + fn emit_verifier_types() -> &'static str { + r"pub type DefaultCommitmentTranscript = Blake2bTranscript; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CommitmentParams { + pub field: &'static str, + pub pcs: &'static str, + pub transcript: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OraclePlan { + pub oracle: &'static str, + pub domain: &'static str, + pub num_vars: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CommitmentBatchPlan { + pub artifact: &'static str, + pub pcs: &'static str, + pub oracle_family: &'static str, + pub label: &'static str, + pub oracles: &'static [&'static str], + pub count: usize, + pub domain: &'static str, + pub num_vars: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum OptionalSkipPolicy { + MissingOrZero, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OptionalCommitmentPlan { + pub artifact: &'static str, + pub pcs: &'static str, + pub oracle: &'static str, + pub label: &'static str, + pub domain: &'static str, + pub num_vars: usize, + pub skip_policy: OptionalSkipPolicy, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct TranscriptStep { + pub label: &'static str, + pub source: &'static str, + pub optional: bool, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CommitmentVerifierProgramPlan { + pub params: CommitmentParams, + pub oracle_plans: &'static [OraclePlan], + pub batch_plans: &'static [CommitmentBatchPlan], + pub optional_plans: &'static [OptionalCommitmentPlan], + pub transcript_steps: &'static [TranscriptStep], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct CommitmentRecord { + pub artifact: &'static str, + pub oracle: &'static str, + pub label: &'static str, + pub num_vars: usize, +} + +#[derive(Clone, Debug, Default)] +pub struct CommitmentArtifacts { + pub commitments: Vec>, + pub records: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum CommitmentPhaseError { + MissingProofCommitment { oracle: &'static str }, + MissingProofCommitmentSlot { artifact: &'static str, oracle: &'static str }, + MissingTranscriptSource { source: &'static str }, + PlanCountMismatch { artifact: &'static str, expected: usize, actual: usize }, + ProofCommitmentCountMismatch { expected: usize, actual: usize }, +}" + } + + fn emit_constants(&self) -> String { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const COMMITMENT_PARAMS: CommitmentParams = CommitmentParams {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + + let oracle_plans = self + .oracle_plans + .iter() + .map(|plan| { + format!( + " OraclePlan {{ oracle: {}, domain: {}, num_vars: {} }},", + rust_str(&plan.oracle), + rust_str(&plan.domain), + plan.num_vars + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!("pub const ORACLE_PLANS: &[OraclePlan] = &[\n{oracle_plans}\n];\n"), + ); + + for (index, plan) in self.batch_plans.iter().enumerate() { + let oracles = plan + .oracles + .iter() + .map(|oracle| format!(" {},", rust_str(oracle))) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const COMMITMENT_BATCH_{index}_ORACLES: &[&str] = &[\n{oracles}\n];\n" + ), + ); + } + + let batch_plans = self + .batch_plans + .iter() + .enumerate() + .map(|(index, plan)| { + format!( + " CommitmentBatchPlan {{ artifact: {}, pcs: {}, oracle_family: {}, label: {}, oracles: COMMITMENT_BATCH_{index}_ORACLES, count: {}, domain: {}, num_vars: {} }},", + rust_str(&plan.artifact), + rust_str(&plan.pcs), + rust_str(&plan.oracle_family), + rust_str(&plan.label), + plan.count, + rust_str(&plan.domain), + plan.num_vars + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const COMMITMENT_BATCH_PLANS: &[CommitmentBatchPlan] = &[\n{batch_plans}\n];\n" + ), + ); + + let optional_plans = self + .optional_plans + .iter() + .map(|plan| { + format!( + " OptionalCommitmentPlan {{ artifact: {}, pcs: {}, oracle: {}, label: {}, domain: {}, num_vars: {}, skip_policy: {} }},", + rust_str(&plan.artifact), + rust_str(&plan.pcs), + rust_str(&plan.oracle), + rust_str(&plan.label), + rust_str(&plan.domain), + plan.num_vars, + plan.skip_policy.rust_variant() + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const OPTIONAL_COMMITMENT_PLANS: &[OptionalCommitmentPlan] = &[\n{optional_plans}\n];\n" + ), + ); + + let steps = self + .transcript_steps + .iter() + .map(|step| { + format!( + " TranscriptStep {{ label: {}, source: {}, optional: {} }},", + rust_str(&step.label), + rust_str(&step.source), + step.optional + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!("pub const TRANSCRIPT_PLAN: &[TranscriptStep] = &[\n{steps}\n];"), + ); + source.push('\n'); + let program_type = match self.role { + Role::Prover => "CommitmentProverProgramPlan", + Role::Verifier => "CommitmentVerifierProgramPlan", + }; + push_format( + &mut source, + format_args!( + "pub const COMMITMENT_PROGRAM: {program_type} = {program_type} {{\n\ + \x20 params: COMMITMENT_PARAMS,\n\ + \x20 oracle_plans: ORACLE_PLANS,\n\ + \x20 batch_plans: COMMITMENT_BATCH_PLANS,\n\ + \x20 optional_plans: OPTIONAL_COMMITMENT_PLANS,\n\ + \x20 transcript_steps: TRANSCRIPT_PLAN,\n\ + }};\n" + ), + ); + + source + } + + fn emit_entrypoint(&self) -> &'static str { + match self.role { + Role::Prover => Self::emit_prover_entrypoint(), + Role::Verifier => Self::emit_verifier_entrypoint(), + } + } + + fn emit_prover_entrypoint() -> &'static str { + r#"pub fn prove_commitment_phase( + inputs: &mut I, + prover_setup: &DoryProverSetup, + transcript: &mut T, +) -> Result +where + I: CommitmentInputProvider, + T: Transcript, +{ + prove_commitment_phase_with_program(&COMMITMENT_PROGRAM, inputs, prover_setup, transcript) +} + +pub fn prove_commitment_phase_with_program( + program: &'static CommitmentProverProgramPlan, + inputs: &mut I, + prover_setup: &DoryProverSetup, + transcript: &mut T, +) -> Result +where + I: CommitmentInputProvider, + T: Transcript, +{ + let mut artifacts = CommitmentArtifacts::default(); + for plan in program.batch_plans { + let _batch_span = tracing::info_span!("bolt.commitment.batch").entered(); + commit_batch(program, inputs, prover_setup, &mut artifacts, plan)?; + } + for plan in program.optional_plans { + let _optional_span = tracing::info_span!("bolt.commitment.optional").entered(); + commit_optional(program, inputs, prover_setup, &mut artifacts, plan)?; + } + absorb_transcript(program, &artifacts, transcript)?; + Ok(artifacts) +} + +fn commit_batch( + program: &CommitmentProverProgramPlan, + inputs: &mut I, + prover_setup: &DoryProverSetup, + artifacts: &mut CommitmentArtifacts, + plan: &CommitmentBatchPlan, +) -> Result<(), CommitmentPhaseError> +where + I: CommitmentInputProvider, +{ + if plan.count != plan.oracles.len() { + return Err(CommitmentPhaseError::PlanCountMismatch { + artifact: plan.artifact, + expected: plan.count, + actual: plan.oracles.len(), + }); + } + if let Some(committed) = inputs.commit_batch(program, plan, prover_setup) { + for committed in committed? { + artifacts.records.push(committed.record); + artifacts.commitments.push(committed.commitment); + if let Some(hint) = committed.hint { + artifacts.hints.push(hint); + } + } + return Ok(()); + } + for &oracle in plan.oracles { + let data = inputs + .materialize_with_num_vars(oracle, oracle_num_vars(program, oracle, plan.num_vars)) + .ok_or(CommitmentPhaseError::MissingOracle { oracle })?; + let oracle_num_vars = oracle_num_vars(program, oracle, plan.num_vars); + let data = into_padded_oracle(oracle, oracle_num_vars, data)?; + let (commitment, hint) = commit_with_layout(&data, plan.num_vars, prover_setup)?; + artifacts.records.push(CommitmentRecord { + artifact: plan.artifact, + oracle, + label: plan.label, + num_vars: oracle_num_vars, + }); + artifacts.commitments.push(Some(commitment)); + artifacts.hints.push(OracleOpeningHint { oracle, hint }); + } + Ok(()) +} + +fn commit_optional( + program: &CommitmentProverProgramPlan, + inputs: &mut I, + prover_setup: &DoryProverSetup, + artifacts: &mut CommitmentArtifacts, + plan: &OptionalCommitmentPlan, +) -> Result<(), CommitmentPhaseError> +where + I: CommitmentInputProvider, +{ + let Some(data) = inputs.materialize_with_num_vars(plan.oracle, plan.num_vars) else { + return push_skipped_optional(program, artifacts, plan); + }; + if should_skip_optional(plan.skip_policy, data.as_ref()) { + return push_skipped_optional(program, artifacts, plan); + } + let data = into_padded_oracle(plan.oracle, plan.num_vars, data)?; + let (commitment, hint) = commit_with_layout(&data, plan.num_vars, prover_setup)?; + artifacts.records.push(CommitmentRecord { + artifact: plan.artifact, + oracle: plan.oracle, + label: plan.label, + num_vars: oracle_num_vars(program, plan.oracle, plan.num_vars), + }); + artifacts.commitments.push(Some(commitment)); + artifacts.hints.push(OracleOpeningHint { + oracle: plan.oracle, + hint, + }); + Ok(()) +} + +fn push_skipped_optional( + program: &CommitmentProverProgramPlan, + artifacts: &mut CommitmentArtifacts, + plan: &OptionalCommitmentPlan, +) -> Result<(), CommitmentPhaseError> { + artifacts.records.push(CommitmentRecord { + artifact: plan.artifact, + oracle: plan.oracle, + label: plan.label, + num_vars: oracle_num_vars(program, plan.oracle, plan.num_vars), + }); + artifacts.commitments.push(None); + Ok(()) +} + +fn should_skip_optional(policy: OptionalSkipPolicy, data: &[Fr]) -> bool { + match policy { + OptionalSkipPolicy::MissingOrZero => data.iter().all(|value| *value == Fr::from_u64(0)), + } +} + +fn into_padded_oracle( + oracle: &'static str, + num_vars: usize, + data: Cow<'_, [Fr]>, +) -> Result, CommitmentPhaseError> { + let target_len = target_len(num_vars)?; + if data.len() > target_len { + return Err(CommitmentPhaseError::OracleTooLarge { + oracle, + len: data.len(), + target_len, + }); + } + let mut data = data.into_owned(); + data.resize(target_len, Fr::from_u64(0)); + Ok(data) +} + +fn oracle_num_vars( + program: &CommitmentProverProgramPlan, + oracle: &'static str, + fallback: usize, +) -> usize { + program + .oracle_plans + .iter() + .find(|plan| plan.oracle == oracle) + .map_or(fallback, |plan| plan.num_vars) +} + +fn commit_with_layout( + data: &[Fr], + layout_num_vars: usize, + prover_setup: &DoryProverSetup, +) -> Result<(DoryCommitment, DoryHint), CommitmentPhaseError> { + let row_len = target_len(layout_num_vars.div_ceil(2))?; + let _dory_commit_span = tracing::info_span!("bolt.commitment.dory_commit").entered(); + Ok(DoryScheme::commit_evaluations_with_row_len( + data, + row_len, + prover_setup, + )) +} + +fn target_len(num_vars: usize) -> Result { + if num_vars >= usize::BITS as usize { + return Err(CommitmentPhaseError::TargetSizeOverflow { num_vars }); + } + Ok(1usize << num_vars) +} + +fn absorb_transcript( + program: &CommitmentProverProgramPlan, + artifacts: &CommitmentArtifacts, + transcript: &mut T, +) -> Result<(), CommitmentPhaseError> +where + T: Transcript, +{ + for step in program.transcript_steps { + let mut appended = false; + for (record, commitment) in artifacts.records.iter().zip(&artifacts.commitments) { + if record.artifact != step.source { + continue; + } + if let Some(commitment) = commitment { + transcript.append(&LabelWithCount(step.label.as_bytes(), commitment.serialized_len())); + commitment.append_to_transcript(transcript); + appended = true; + } + } + if !step.optional && !appended { + return Err(CommitmentPhaseError::MissingTranscriptSource { + source: step.source, + }); + } + } + Ok(()) +} +"# + } + + fn emit_verifier_entrypoint() -> &'static str { + r"pub fn verify_commitment_phase( + proof_commitments: &[Option], + transcript: &mut T, +) -> Result +where + T: Transcript, +{ + verify_commitment_phase_with_program(&COMMITMENT_PROGRAM, proof_commitments, transcript) +} + +pub fn verify_commitment_phase_with_program( + program: &'static CommitmentVerifierProgramPlan, + proof_commitments: &[Option], + transcript: &mut T, +) -> Result +where + T: Transcript, +{ + let mut artifacts = CommitmentArtifacts::default(); + let mut cursor = 0usize; + for plan in program.batch_plans { + receive_batch(program, proof_commitments, &mut cursor, &mut artifacts, plan)?; + } + for plan in program.optional_plans { + receive_optional(program, proof_commitments, &mut cursor, &mut artifacts, plan)?; + } + if cursor != proof_commitments.len() { + return Err(CommitmentPhaseError::ProofCommitmentCountMismatch { + expected: cursor, + actual: proof_commitments.len(), + }); + } + absorb_transcript(program, &artifacts, transcript)?; + Ok(artifacts) +} + +fn receive_batch( + program: &'static CommitmentVerifierProgramPlan, + proof_commitments: &[Option], + cursor: &mut usize, + artifacts: &mut CommitmentArtifacts, + plan: &CommitmentBatchPlan, +) -> Result<(), CommitmentPhaseError> { + if plan.count != plan.oracles.len() { + return Err(CommitmentPhaseError::PlanCountMismatch { + artifact: plan.artifact, + expected: plan.count, + actual: plan.oracles.len(), + }); + } + for &oracle in plan.oracles { + let commitment = proof_commitments + .get(*cursor) + .ok_or(CommitmentPhaseError::MissingProofCommitmentSlot { + artifact: plan.artifact, + oracle, + })? + .as_ref() + .ok_or(CommitmentPhaseError::MissingProofCommitment { oracle })? + .clone(); + *cursor += 1; + let oracle_num_vars = oracle_num_vars(program, oracle, plan.num_vars); + artifacts.records.push(CommitmentRecord { + artifact: plan.artifact, + oracle, + label: plan.label, + num_vars: oracle_num_vars, + }); + artifacts.commitments.push(Some(commitment)); + } + Ok(()) +} + +fn receive_optional( + program: &'static CommitmentVerifierProgramPlan, + proof_commitments: &[Option], + cursor: &mut usize, + artifacts: &mut CommitmentArtifacts, + plan: &OptionalCommitmentPlan, +) -> Result<(), CommitmentPhaseError> { + let commitment = proof_commitments + .get(*cursor) + .ok_or(CommitmentPhaseError::MissingProofCommitmentSlot { + artifact: plan.artifact, + oracle: plan.oracle, + })? + .clone(); + *cursor += 1; + artifacts.records.push(CommitmentRecord { + artifact: plan.artifact, + oracle: plan.oracle, + label: plan.label, + num_vars: oracle_num_vars(program, plan.oracle, plan.num_vars), + }); + artifacts.commitments.push(commitment); + Ok(()) +} + +pub fn commitment_verifier_program() -> &'static CommitmentVerifierProgramPlan { + &COMMITMENT_PROGRAM +} + +fn oracle_num_vars( + program: &'static CommitmentVerifierProgramPlan, + oracle: &'static str, + fallback: usize, +) -> usize { + program + .oracle_plans + .iter() + .find(|plan| plan.oracle == oracle) + .map_or(fallback, |plan| plan.num_vars) +} + +fn absorb_transcript( + program: &'static CommitmentVerifierProgramPlan, + artifacts: &CommitmentArtifacts, + transcript: &mut T, +) -> Result<(), CommitmentPhaseError> +where + T: Transcript, +{ + for step in program.transcript_steps { + let mut appended = false; + for (record, commitment) in artifacts.records.iter().zip(&artifacts.commitments) { + if record.artifact != step.source { + continue; + } + if let Some(commitment) = commitment { + transcript.append(&LabelWithCount(step.label.as_bytes(), commitment.serialized_len())); + commitment.append_to_transcript(transcript); + appended = true; + } + } + if !step.optional && !appended { + return Err(CommitmentPhaseError::MissingTranscriptSource { + source: step.source, + }); + } + } + Ok(()) +} +" + } +} + +impl OptionalSkipPolicy { + fn parse(value: &str) -> Result { + match value { + "missing_or_zero" => Ok(Self::MissingOrZero), + _ => Err(EmitError::new(format!( + "unsupported optional commitment skip policy `{value}`" + ))), + } + } + + fn rust_variant(&self) -> &'static str { + match self { + Self::MissingOrZero => "OptionalSkipPolicy::MissingOrZero", + } + } +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn skip_policy_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result { + OptionalSkipPolicy::parse(&string_attr(operation, attr)?) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn bool_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(|attribute| match attribute.to_string().as_str() { + "true" => Some(true), + "false" => Some(false), + _ => None, + }) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "bool")) +} + +fn transcript_artifact_source(operation: OperationRef<'_, '_>) -> Result { + let artifact = operation + .operand(1) + .map_err(|_| attr_error(operation, "artifact operand", "value"))?; + let owner = OperationResult::try_from(artifact) + .map_err(|_| EmitError::new("cpu.transcript_absorb artifact operand must be op result"))? + .owner(); + symbol_attr(owner, "artifact") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name(operation: OperationRef<'_, '_>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn rust_field_name(value: &str) -> String { + let mut output = String::new(); + let mut previous_was_separator = false; + for (index, character) in value.chars().enumerate() { + if character == '_' { + output.push('_'); + previous_was_separator = true; + continue; + } + if character.is_ascii_uppercase() { + if index != 0 && !previous_was_separator { + output.push('_'); + } + output.push(character.to_ascii_lowercase()); + } else { + output.push(character); + } + previous_was_separator = false; + } + output +} + +fn rust_input_field(source: &str) -> Result<&'static str, EmitError> { + match source { + "trace.rd_inc" => Ok("rd_inc"), + "trace.ram_inc" => Ok("ram_inc"), + "trace.instruction_keys" => Ok("instruction_keys"), + "trace.ram_addresses" => Ok("ram_addresses"), + "trace.bytecode_indices" => Ok("bytecode_indices"), + "advice.untrusted" => Ok("untrusted_advice"), + "advice.trusted" => Ok("trusted_advice"), + _ => Err(EmitError::new(format!( + "unsupported oracle source `{source}`" + ))), + } +} + +fn rust_padding_value(padding: &str) -> Result<&'static str, EmitError> { + match padding { + "zero" => Ok("Some(0)"), + "none" => Ok("None"), + _ => Err(EmitError::new(format!( + "unsupported oracle padding `{padding}`" + ))), + } +} + +fn require_supported_symbol(kind: &str, actual: &str, expected: &str) -> Result<(), EmitError> { + if actual == expected { + Ok(()) + } else { + Err(EmitError::new(format!( + "unsupported {kind} @{actual}; Rust commitment emitter currently supports @{expected}" + ))) + } +} diff --git a/crates/bolt/src/protocols/jolt/emit/rust/mod.rs b/crates/bolt/src/protocols/jolt/emit/rust/mod.rs new file mode 100644 index 0000000000..50f36ba9f7 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/mod.rs @@ -0,0 +1,27 @@ +mod commitment; +mod stage1; +mod stage2; +mod stage3; +mod stage4; +mod stage5; +mod stage6; +mod stage7; +mod stage8; + +pub use commitment::{ + commitment_cpu_program, emit_commitment_rust, CommitmentBatchPlan, CommitmentCpuProgram, + CommitmentParams, OptionalCommitmentPlan, OptionalSkipPolicy, OracleGeneration, OraclePlan, + TranscriptStep, +}; +pub use stage1::{ + emit_stage1_rust, stage1_cpu_program, Stage1CpuProgram, Stage1KernelPlan, + Stage1OpeningBatchPlan, Stage1OpeningClaimPlan, Stage1Params, Stage1SumcheckBatchPlan, + Stage1SumcheckClaimPlan, Stage1SumcheckDriverPlan, Stage1SumcheckEvalPlan, +}; +pub use stage2::{emit_stage2_rust, stage2_cpu_program, Stage2CpuProgram}; +pub use stage3::{emit_stage3_rust, stage3_cpu_program, Stage3CpuProgram}; +pub use stage4::{emit_stage4_rust, stage4_cpu_program, Stage4CpuProgram}; +pub use stage5::{emit_stage5_rust, stage5_cpu_program, Stage5CpuProgram}; +pub use stage6::{emit_stage6_rust, stage6_cpu_program, Stage6CpuProgram}; +pub use stage7::{emit_stage7_rust, stage7_cpu_program, Stage7CpuProgram}; +pub use stage8::{emit_stage8_rust, stage8_cpu_program, Stage8CpuProgram}; diff --git a/crates/bolt/src/protocols/jolt/emit/rust/stage1.rs b/crates/bolt/src/protocols/jolt/emit/rust/stage1.rs new file mode 100644 index 0000000000..0755beb79d --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/stage1.rs @@ -0,0 +1,1653 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{Attribute, OperationRef}; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1CpuProgram { + pub role: Role, + pub params: Stage1Params, + pub transcript_squeezes: Vec, + pub kernels: Vec, + pub claims: Vec, + pub batches: Vec, + pub drivers: Vec, + pub instance_results: Vec, + pub evals: Vec, + pub opening_claims: Vec, + pub opening_batches: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1Params { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1KernelPlan { + pub symbol: String, + pub relation: String, + pub kind: String, + pub backend: String, + pub abi: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1TranscriptSqueezePlan { + pub symbol: String, + pub label: String, + pub kind: String, + pub count: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1SumcheckClaimPlan { + pub symbol: String, + pub stage: String, + pub domain: String, + pub num_rounds: usize, + pub degree: usize, + pub claim: String, + pub kernel: Option, + pub relation: Option, + pub claim_value: String, + pub input_openings: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1SumcheckBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, + pub claim_label: String, + pub round_label: String, + pub round_schedule: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1SumcheckDriverPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub kernel: Option, + pub relation: Option, + pub batch: String, + pub policy: String, + pub round_schedule: Vec, + pub claim_label: String, + pub round_label: String, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1SumcheckInstanceResultPlan { + pub symbol: String, + pub source: String, + pub claim: String, + pub relation: String, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: String, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1SumcheckEvalPlan { + pub symbol: String, + pub source: String, + pub name: String, + pub index: usize, + pub oracle: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1OpeningClaimPlan { + pub symbol: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, + pub point_source: String, + pub eval_source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage1OpeningBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, +} + +pub fn stage1_cpu_program(module: &BoltModule<'_, Cpu>) -> Result { + verify_cpu_schema(module)?; + let program = Stage1CpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +pub fn emit_stage1_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = stage1_cpu_program(module)?; + + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source()?, + }) +} + +impl Stage1CpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let mut params = None; + let mut transcript_squeezes = Vec::new(); + let mut kernels = Vec::new(); + let mut claims = Vec::new(); + let mut batches = Vec::new(); + let mut drivers = Vec::new(); + let mut instance_results = Vec::new(); + let mut evals = Vec::new(); + let mut opening_claims = Vec::new(); + let mut opening_batches = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(Stage1Params { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.kernel" => { + kernels.push(Stage1KernelPlan { + symbol: string_attr(op, "sym_name")?, + relation: symbol_attr(op, "relation")?, + kind: string_attr(op, "kind")?, + backend: string_attr(op, "backend")?, + abi: string_attr(op, "abi")?, + }); + } + "cpu.transcript_squeeze" => { + transcript_squeezes.push(Stage1TranscriptSqueezePlan { + symbol: string_attr(op, "sym_name")?, + label: string_attr(op, "label")?, + kind: string_attr(op, "kind")?, + count: int_attr(op, "count")?, + }); + } + "cpu.sumcheck_claim" => { + claims.push(Stage1SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_verify_claim" => { + claims.push(Stage1SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_batch" => { + batches.push(Stage1SumcheckBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + round_schedule: int_array_attr(op, "round_schedule")?, + }); + } + "cpu.sumcheck_driver" => { + drivers.push(Stage1SumcheckDriverPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_verify" => { + drivers.push(Stage1SumcheckDriverPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_eval" => { + evals.push(Stage1SumcheckEvalPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + name: symbol_attr(op, "name")?, + index: int_attr(op, "index")?, + oracle: symbol_attr(op, "oracle")?, + }); + } + "cpu.sumcheck_instance_result" => { + instance_results.push(Stage1SumcheckInstanceResultPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + claim: symbol_attr(op, "claim")?, + relation: symbol_attr(op, "relation")?, + index: int_attr(op, "index")?, + point_arity: int_attr(op, "point_arity")?, + num_rounds: int_attr(op, "num_rounds")?, + round_offset: int_attr(op, "round_offset")?, + point_order: string_attr(op, "point_order")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.opening_claim" => { + opening_claims.push(Stage1OpeningClaimPlan { + symbol: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + point_source: operand_symbol(op, 0)?, + eval_source: operand_symbol(op, 1)?, + }); + } + "cpu.opening_batch" => { + opening_batches.push(Stage1OpeningBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + }); + } + _ => {} + } + } + + Ok(Self { + params: params.ok_or_else(|| EmitError::new("missing cpu.params"))?, + role: module + .role() + .ok_or_else(|| EmitError::new("missing cpu party role"))?, + transcript_squeezes, + kernels, + claims, + batches, + drivers, + instance_results, + evals, + opening_claims, + opening_batches, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + require_supported_symbol("field", &self.params.field, "bn254_fr")?; + require_supported_symbol("pcs", &self.params.pcs, "dory")?; + require_supported_symbol("transcript", &self.params.transcript, "blake2b_transcript")?; + self.verify_transcript_squeezes()?; + self.verify_claim_batches()?; + match self.role { + Role::Prover => { + self.verify_kernel_definitions()?; + self.verify_prover_driver_bindings()?; + } + Role::Verifier => self.verify_verifier_driver_bindings()?, + } + self.verify_opening_flow() + } + + fn verify_transcript_squeezes(&self) -> Result<(), EmitError> { + for squeeze in &self.transcript_squeezes { + if squeeze.kind != "challenge_vector" { + return Err(EmitError::new(format!( + "stage1 transcript squeeze @{} has unsupported kind `{}`", + squeeze.symbol, squeeze.kind + ))); + } + if squeeze.count == 0 { + return Err(EmitError::new(format!( + "stage1 transcript squeeze @{} has zero count", + squeeze.symbol + ))); + } + } + Ok(()) + } + + fn verify_kernel_definitions(&self) -> Result<(), EmitError> { + for kernel in &self.kernels { + if kernel.backend != "cpu" { + return Err(EmitError::new(format!( + "stage1 kernel @{} targets unsupported backend `{}`", + kernel.symbol, kernel.backend + ))); + } + if kernel.kind != "sumcheck" { + return Err(EmitError::new(format!( + "stage1 kernel @{} has unsupported kind `{}`", + kernel.symbol, kernel.kind + ))); + } + let expected_abi = match kernel.relation.as_str() { + "jolt.stage1.outer.uniskip" => "jolt_stage1_outer_uniskip", + "jolt.stage1.outer.remaining" => "jolt_stage1_outer_remaining", + _ => { + return Err(EmitError::new(format!( + "unsupported stage1 kernel relation @{}", + kernel.relation + ))); + } + }; + if kernel.abi != expected_abi { + return Err(EmitError::new(format!( + "stage1 kernel @{} ABI `{}` does not match relation @{}", + kernel.symbol, kernel.abi, kernel.relation + ))); + } + } + Ok(()) + } + + fn verify_claim_batches(&self) -> Result<(), EmitError> { + let claims = symbols(self.claims.iter().map(|claim| &claim.symbol)); + for batch in &self.batches { + verify_count( + "sumcheck batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "sumcheck batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "sumcheck batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !claims.contains(claim) { + return Err(EmitError::new(format!( + "sumcheck batch @{} references missing claim @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn verify_prover_driver_bindings(&self) -> Result<(), EmitError> { + let kernels = symbols(self.kernels.iter().map(|kernel| &kernel.symbol)); + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + let claims: BTreeMap<_, _> = self + .claims + .iter() + .map(|claim| (claim.symbol.as_str(), claim)) + .collect(); + for claim in &self.claims { + let Some(kernel) = claim.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck claim @{} is missing kernel", + claim.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck claim @{} references missing kernel @{kernel}", + claim.symbol + ))); + } + } + for driver in &self.drivers { + let Some(kernel) = driver.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck driver @{} is missing kernel", + driver.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck driver @{} references missing kernel @{kernel}", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + for claim in &batch.ordered_claims { + let claim = claims.get(claim.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing claim @{claim}", + driver.symbol + )) + })?; + if claim.kernel.as_deref() != Some(kernel) { + return Err(EmitError::new(format!( + "sumcheck driver @{} kernel @{kernel} differs from claim @{} kernel {:?}", + driver.symbol, claim.symbol, claim.kernel + ))); + } + } + } + Ok(()) + } + + fn verify_verifier_driver_bindings(&self) -> Result<(), EmitError> { + if !self.kernels.is_empty() { + return Err(EmitError::new( + "verifier stage1 program must not contain kernels", + )); + } + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + let claims: BTreeMap<_, _> = self + .claims + .iter() + .map(|claim| (claim.symbol.as_str(), claim)) + .collect(); + for claim in &self.claims { + if claim.kernel.is_some() || claim.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck claim @{} must carry relation and no kernel", + claim.symbol + ))); + } + } + for driver in &self.drivers { + let Some(relation) = driver.relation.as_deref() else { + return Err(EmitError::new(format!( + "verifier sumcheck driver @{} is missing relation", + driver.symbol + ))); + }; + if driver.kernel.is_some() { + return Err(EmitError::new(format!( + "verifier sumcheck driver @{} must not carry kernel", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + for claim in &batch.ordered_claims { + let claim = claims.get(claim.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing claim @{claim}", + driver.symbol + )) + })?; + if claim.relation.as_deref() != Some(relation) { + return Err(EmitError::new(format!( + "sumcheck driver @{} relation @{relation} differs from claim @{} relation {:?}", + driver.symbol, claim.symbol, claim.relation + ))); + } + } + } + Ok(()) + } + + fn verify_opening_flow(&self) -> Result<(), EmitError> { + let drivers = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + let instance_results = symbols( + self.instance_results + .iter() + .map(|instance| &instance.symbol), + ); + for instance in &self.instance_results { + if !drivers.contains(&instance.source) { + return Err(EmitError::new(format!( + "sumcheck instance result @{} references missing driver @{}", + instance.symbol, instance.source + ))); + } + } + let mut point_sources = drivers.clone(); + point_sources.extend(instance_results); + let evals = symbols(self.evals.iter().map(|eval| &eval.symbol)); + let openings = symbols(self.opening_claims.iter().map(|claim| &claim.symbol)); + for eval in &self.evals { + if !drivers.contains(&eval.source) { + return Err(EmitError::new(format!( + "sumcheck eval @{} references missing driver @{}", + eval.symbol, eval.source + ))); + } + } + for claim in &self.opening_claims { + if !point_sources.contains(&claim.point_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing point source @{}", + claim.symbol, claim.point_source + ))); + } + if !evals.contains(&claim.eval_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing eval source @{}", + claim.symbol, claim.eval_source + ))); + } + } + for batch in &self.opening_batches { + verify_count( + "opening batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "opening batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "opening batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !openings.contains(claim) { + return Err(EmitError::new(format!( + "opening batch @{} references missing opening @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn emit_source(&self) -> Result { + match self.role { + Role::Prover => self.emit_prover_source(), + Role::Verifier => self.emit_verifier_source(), + } + } + + fn emit_prover_source(&self) -> Result { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + source.push_str(Self::emit_prover_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_prover_types()); + source.push('\n'); + source.push_str(&self.emit_prover_constants()?); + source.push('\n'); + source.push_str(Self::emit_prover_entrypoint()); + Ok(source) + } + + fn emit_verifier_source(&self) -> Result { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + source.push_str(Self::emit_verifier_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_verifier_types()); + source.push('\n'); + source.push_str(&self.emit_verifier_constants()?); + source.push('\n'); + source.push_str(Self::emit_verifier_entrypoint()); + Ok(source) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_stage1_outer.rs", + Role::Verifier => "verify_stage1_outer.rs", + } + } + + fn emit_prover_imports() -> &'static str { + "use jolt_field::Fr;\n\ + use jolt_kernels::stage1::{execute_stage1_program, Stage1CpuProgramPlan, Stage1ExecutionArtifacts, Stage1ExecutionMode, Stage1KernelError, Stage1KernelExecutor, Stage1KernelPlan, Stage1OpeningBatchPlan, Stage1OpeningClaimPlan, Stage1Params, Stage1SumcheckBatchPlan, Stage1SumcheckClaimPlan, Stage1SumcheckDriverPlan, Stage1SumcheckEvalPlan, Stage1SumcheckInstanceResultPlan, Stage1TranscriptSqueezePlan};\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_prover_types() -> &'static str { + "pub type DefaultStage1Transcript = Blake2bTranscript;\n" + } + + fn emit_prover_constants(&self) -> Result { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const STAGE1_PARAMS: Stage1Params = Stage1Params {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + + source.push_str(&self.emit_transcript_squeeze_constants()); + source.push_str(&self.emit_kernel_constants()); + source.push_str(&self.emit_sumcheck_claim_constants()?); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_sumcheck_driver_constants()?); + source.push_str(&self.emit_sumcheck_instance_result_constants()); + source.push_str(&self.emit_sumcheck_eval_constants()); + source.push_str(&self.emit_opening_claim_constants()); + source.push_str(&self.emit_opening_batch_constants()); + source.push_str( + "pub const STAGE1_PROGRAM: Stage1CpuProgramPlan = Stage1CpuProgramPlan {\n\ + \x20 params: STAGE1_PARAMS,\n\ + \x20 transcript_squeezes: STAGE1_TRANSCRIPT_SQUEEZES,\n\ + \x20 kernels: STAGE1_KERNELS,\n\ + \x20 claims: STAGE1_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE1_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE1_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE1_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE1_SUMCHECK_EVALS,\n\ + \x20 opening_claims: STAGE1_OPENING_CLAIMS,\n\ + \x20 opening_batches: STAGE1_OPENING_BATCHES,\n\ + };\n", + ); + Ok(source) + } + + fn emit_sumcheck_instance_result_constants(&self) -> String { + let instances = self + .instance_results + .iter() + .map(|instance| { + format!( + " Stage1SumcheckInstanceResultPlan {{ symbol: {}, source: {}, claim: {}, relation: {}, index: {}, point_arity: {}, num_rounds: {}, round_offset: {}, point_order: {}, degree: {} }},", + rust_str(&instance.symbol), + rust_str(&instance.source), + rust_str(&instance.claim), + rust_str(&instance.relation), + instance.index, + instance.point_arity, + instance.num_rounds, + instance.round_offset, + rust_str(&instance.point_order), + instance.degree + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE1_SUMCHECK_INSTANCE_RESULTS: &[Stage1SumcheckInstanceResultPlan] = &[\n{instances}\n];\n\n" + ) + } + + fn emit_transcript_squeeze_constants(&self) -> String { + let squeezes = self + .transcript_squeezes + .iter() + .map(|squeeze| { + format!( + " Stage1TranscriptSqueezePlan {{ symbol: {}, label: {}, kind: {}, count: {} }},", + rust_str(&squeeze.symbol), + rust_str(&squeeze.label), + rust_str(&squeeze.kind), + squeeze.count, + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE1_TRANSCRIPT_SQUEEZES: &[Stage1TranscriptSqueezePlan] = &[\n{squeezes}\n];\n\n" + ) + } + + fn emit_kernel_constants(&self) -> String { + let kernels = self + .kernels + .iter() + .map(|kernel| { + format!( + " Stage1KernelPlan {{ symbol: {}, relation: {}, kind: {}, backend: {}, abi: {} }},", + rust_str(&kernel.symbol), + rust_str(&kernel.relation), + rust_str(&kernel.kind), + rust_str(&kernel.backend), + rust_str(&kernel.abi) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE1_KERNELS: &[Stage1KernelPlan] = &[\n{kernels}\n];\n\n") + } + + fn emit_sumcheck_claim_constants(&self) -> Result { + let mut source = String::new(); + for (index, claim) in self.claims.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE1_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS"), + &claim.input_openings, + )); + } + let mut claims = Vec::new(); + for (index, claim) in self.claims.iter().enumerate() { + let kernel = claim + .kernel + .as_deref() + .ok_or_else(|| missing_role_binding("prover claim kernel", &claim.symbol))?; + claims.push(format!( + " Stage1SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: Some({}), relation: None, claim_value: {}, input_openings: STAGE1_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_str(kernel), + rust_str(&claim.claim_value) + )); + } + let claims = claims.join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE1_SUMCHECK_CLAIMS: &[Stage1SumcheckClaimPlan] = &[\n{claims}\n];\n" + ), + ); + Ok(source) + } + + fn emit_sumcheck_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE1_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage1SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {}, claim_label: {}, round_label: {}, round_schedule: STAGE1_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")), + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE1_SUMCHECK_BATCHES: &[Stage1SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + return source; + } + + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE1_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE1_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + source.push_str(&emit_usize_array( + &format!("STAGE1_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage1SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE1_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE1_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS, claim_label: {}, round_label: {}, round_schedule: STAGE1_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE1_SUMCHECK_BATCHES: &[Stage1SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_driver_constants(&self) -> Result { + let mut source = String::new(); + for (index, driver) in self.drivers.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE1_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE"), + &driver.round_schedule, + )); + } + let mut drivers = Vec::new(); + for (index, driver) in self.drivers.iter().enumerate() { + let kernel = driver + .kernel + .as_deref() + .ok_or_else(|| missing_role_binding("prover driver kernel", &driver.symbol))?; + drivers.push(format!( + " Stage1SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: Some({}), relation: None, batch: {}, policy: {}, round_schedule: STAGE1_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_str(kernel), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + )); + } + let drivers = drivers.join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE1_SUMCHECK_DRIVERS: &[Stage1SumcheckDriverPlan] = &[\n{drivers}\n];\n" + ), + ); + Ok(source) + } + + fn emit_sumcheck_eval_constants(&self) -> String { + let evals = self + .evals + .iter() + .map(|eval| { + format!( + " Stage1SumcheckEvalPlan {{ symbol: {}, source: {}, name: {}, index: {}, oracle: {} }},", + rust_str(&eval.symbol), + rust_str(&eval.source), + rust_str(&eval.name), + eval.index, + rust_str(&eval.oracle) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE1_SUMCHECK_EVALS: &[Stage1SumcheckEvalPlan] = &[\n{evals}\n];\n\n") + } + + fn emit_opening_claim_constants(&self) -> String { + let claims = self + .opening_claims + .iter() + .map(|claim| { + format!( + " Stage1OpeningClaimPlan {{ symbol: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {}, point_source: {}, eval_source: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.oracle), + rust_str(&claim.domain), + claim.point_arity, + rust_str(&claim.claim_kind), + rust_str(&claim.point_source), + rust_str(&claim.eval_source) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE1_OPENING_CLAIMS: &[Stage1OpeningClaimPlan] = &[\n{claims}\n];\n\n") + } + + fn emit_opening_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let batches = self + .opening_batches + .iter() + .map(|batch| { + format!( + " Stage1OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {} }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE1_OPENING_BATCHES: &[Stage1OpeningBatchPlan] = &[\n{batches}\n];\n" + ); + } + + let mut source = String::new(); + for (index, batch) in self.opening_batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE1_OPENING_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE1_OPENING_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + } + let batches = self + .opening_batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage1OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE1_OPENING_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE1_OPENING_BATCH_{index}_CLAIM_OPERANDS }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE1_OPENING_BATCHES: &[Stage1OpeningBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_verifier_imports() -> &'static str { + "use super::common::append_labeled_scalar;\n\ + use jolt_field::{Field, Fr};\n\ + use jolt_sumcheck::{CompressedLabeledRoundPoly, LabeledRoundPoly, SumcheckClaim, SumcheckError, SumcheckVerifier};\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_verifier_types() -> &'static str { + r"pub type DefaultStage1Transcript = Blake2bTranscript; + +pub type Stage1Params = super::common::StageParams; +pub type Stage1NamedEval = super::common::StageNamedEval; +pub type Stage1SumcheckOutput = super::common::StageSumcheckOutput; +pub type Stage1ChallengeVector = super::common::StageChallengeVector; +pub type Stage1ExecutionArtifacts = super::common::StageExecutionArtifacts; +pub type Stage1Proof = super::common::StageProof; +pub type Stage1VerifierProgramPlan = super::common::VerifierProgramPlanMinimal; + +pub use super::common::{ + OpeningBatchPlan as Stage1OpeningBatchPlan, OpeningClaimPlan as Stage1OpeningClaimPlan, + SumcheckBatchPlan as Stage1SumcheckBatchPlan, SumcheckEvalPlan as Stage1SumcheckEvalPlan, + SumcheckInstanceResultPlan as Stage1SumcheckInstanceResultPlan, + TranscriptSqueezePlan as Stage1TranscriptSqueezePlan, + SumcheckClaimPlan as Stage1SumcheckClaimPlan, + SumcheckDriverPlan as Stage1SumcheckDriverPlan, +}; + +#[derive(Debug)] +pub enum VerifyStage1Error { + UnexpectedProofCount { expected: usize, got: usize }, + MissingProof { driver: &'static str }, + MissingBatch { driver: &'static str, batch: &'static str }, + MissingClaim { driver: &'static str, claim: &'static str }, + MissingDependency { driver: &'static str, dependency: &'static str }, + InvalidProof { driver: &'static str, reason: &'static str }, + UnsupportedRelation { relation: &'static str }, + Sumcheck { driver: &'static str, error: SumcheckError }, +} +" + } + + fn emit_verifier_constants(&self) -> Result { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const STAGE1_PARAMS: Stage1Params = Stage1Params {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + + source.push_str(&self.emit_transcript_squeeze_constants()); + source.push_str(&self.emit_verifier_sumcheck_claim_constants()?); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_verifier_sumcheck_driver_constants()?); + source.push_str(&self.emit_sumcheck_instance_result_constants()); + source.push_str(&self.emit_sumcheck_eval_constants()); + source.push_str(&self.emit_opening_claim_constants()); + source.push_str(&self.emit_opening_batch_constants()); + source.push_str( + "pub const STAGE1_PROGRAM: Stage1VerifierProgramPlan = Stage1VerifierProgramPlan {\n\ + \x20 params: STAGE1_PARAMS,\n\ + \x20 transcript_squeezes: STAGE1_TRANSCRIPT_SQUEEZES,\n\ + \x20 claims: STAGE1_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE1_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE1_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE1_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE1_SUMCHECK_EVALS,\n\ + \x20 opening_claims: STAGE1_OPENING_CLAIMS,\n\ + \x20 opening_batches: STAGE1_OPENING_BATCHES,\n\ + };\n", + ); + Ok(source) + } + + fn emit_verifier_sumcheck_claim_constants(&self) -> Result { + let mut claims = Vec::new(); + for claim in &self.claims { + let relation = claim + .relation + .as_deref() + .ok_or_else(|| missing_role_binding("verifier claim relation", &claim.symbol))?; + claims.push(format!( + " Stage1SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: None, relation: Some({}), claim_value: {}, input_openings: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_str(relation), + rust_str(&claim.claim_value), + rust_str(&claim.input_openings.join("|")) + )); + } + let claims = claims.join("\n"); + Ok(format!( + "pub const STAGE1_SUMCHECK_CLAIMS: &[Stage1SumcheckClaimPlan] = &[\n{claims}\n];\n" + )) + } + + fn emit_verifier_sumcheck_driver_constants(&self) -> Result { + let mut source = String::new(); + for (index, driver) in self.drivers.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE1_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE"), + &driver.round_schedule, + )); + } + let mut drivers = Vec::new(); + for (index, driver) in self.drivers.iter().enumerate() { + let relation = driver + .relation + .as_deref() + .ok_or_else(|| missing_role_binding("verifier driver relation", &driver.symbol))?; + drivers.push(format!( + " Stage1SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: None, relation: Some({}), batch: {}, policy: {}, round_schedule: STAGE1_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_str(relation), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + )); + } + let drivers = drivers.join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE1_SUMCHECK_DRIVERS: &[Stage1SumcheckDriverPlan] = &[\n{drivers}\n];\n" + ), + ); + Ok(source) + } + + fn emit_prover_entrypoint() -> &'static str { + r"pub fn prove_stage1_outer( + executor: &mut E, + transcript: &mut T, +) -> Result, Stage1KernelError> +where + E: Stage1KernelExecutor, + T: Transcript, +{ + prove_stage1_outer_with_program(&STAGE1_PROGRAM, executor, transcript) +} + +pub fn prove_stage1_outer_with_program( + program: &'static Stage1CpuProgramPlan, + executor: &mut E, + transcript: &mut T, +) -> Result, Stage1KernelError> +where + E: Stage1KernelExecutor, + T: Transcript, +{ + execute_stage1_program( + program, + Stage1ExecutionMode::Prover, + executor, + transcript, + ) +} +" + } + + fn emit_verifier_entrypoint() -> &'static str { + r#"pub fn verify_stage1_outer( + proof: &Stage1Proof, + transcript: &mut T, +) -> Result, VerifyStage1Error> +where + T: Transcript, +{ + verify_stage1_outer_with_program(&STAGE1_PROGRAM, proof, transcript) +} + +pub fn verify_stage1_outer_with_program( + program: &'static Stage1VerifierProgramPlan, + proof: &Stage1Proof, + transcript: &mut T, +) -> Result, VerifyStage1Error> +where + T: Transcript, +{ + if proof.sumchecks.len() != program.drivers.len() { + return Err(VerifyStage1Error::UnexpectedProofCount { + expected: program.drivers.len(), + got: proof.sumchecks.len(), + }); + } + let mut artifacts = Stage1ExecutionArtifacts::default(); + for squeeze in program.transcript_squeezes { + let values = transcript.challenge_vector(squeeze.count); + artifacts.challenge_vectors.push(Stage1ChallengeVector { + symbol: squeeze.symbol, + values, + }); + } + for (index, driver) in program.drivers.iter().enumerate() { + let proof = proof.sumchecks.get(index).ok_or(VerifyStage1Error::MissingProof { + driver: driver.symbol, + })?; + let output = verify_stage1_driver(program, driver, proof, &artifacts.sumchecks, transcript)?; + artifacts.sumchecks.push(output); + } + artifacts + .opening_batches + .extend(program.opening_batches.iter()); + Ok(artifacts) +} + +pub fn stage1_outer_verifier_program() -> &'static Stage1VerifierProgramPlan { + &STAGE1_PROGRAM +} + +fn verify_stage1_driver( + program: &'static Stage1VerifierProgramPlan, + driver: &'static Stage1SumcheckDriverPlan, + proof: &Stage1SumcheckOutput, + completed: &[Stage1SumcheckOutput], + transcript: &mut T, +) -> Result, VerifyStage1Error> +where + T: Transcript, +{ + if proof.driver != driver.symbol { + return Err(VerifyStage1Error::InvalidProof { + driver: driver.symbol, + reason: "driver symbol mismatch", + }); + } + let relation = driver.relation.unwrap_or(""); + match relation { + "jolt.stage1.outer.uniskip" => verify_outer_uniskip(program, driver, proof, transcript), + "jolt.stage1.outer.remaining" => { + verify_outer_remaining(program, driver, proof, completed, transcript) + } + relation => Err(VerifyStage1Error::UnsupportedRelation { relation }), + } +} + +fn verify_outer_uniskip( + program: &'static Stage1VerifierProgramPlan, + driver: &'static Stage1SumcheckDriverPlan, + proof: &Stage1SumcheckOutput, + transcript: &mut T, +) -> Result, VerifyStage1Error> +where + T: Transcript, +{ + let claim = SumcheckClaim::new(driver.num_rounds, driver.degree, Fr::from_u64(0)); + let round_proofs = proof + .proof + .round_polynomials + .iter() + .map(|poly| LabeledRoundPoly::new(poly, driver.round_label.as_bytes())) + .collect::>(); + let output = SumcheckVerifier::verify(&claim, &round_proofs, transcript) + .map_err(|error| VerifyStage1Error::Sumcheck { + driver: driver.symbol, + error, + })?; + let eval = output.value; + let point = output.point; + if !proof.point.is_empty() && proof.point != point { + return Err(VerifyStage1Error::InvalidProof { + driver: driver.symbol, + reason: "uniskip point mismatch", + }); + } + validate_eval_shape(program, driver, &proof.evals, Some(eval))?; + append_labeled_scalar(transcript, "opening_claim", &eval); + Ok(Stage1SumcheckOutput { + driver: driver.symbol, + point, + evals: driver_evals(program, driver.symbol, eval), + proof: proof.proof.clone(), + }) +} + +fn verify_outer_remaining( + program: &'static Stage1VerifierProgramPlan, + driver: &'static Stage1SumcheckDriverPlan, + proof: &Stage1SumcheckOutput, + completed: &[Stage1SumcheckOutput], + transcript: &mut T, +) -> Result, VerifyStage1Error> +where + T: Transcript, +{ + let input_claim = completed + .iter() + .find(|output| output.driver == "stage1.uniskip.sumcheck") + .and_then(|output| output.evals.first()) + .map(|eval| eval.value) + .ok_or(VerifyStage1Error::MissingDependency { + driver: driver.symbol, + dependency: "stage1.uniskip.eval", + })?; + append_labeled_scalar(transcript, driver.claim_label, &input_claim); + let batching_coeff = transcript.challenge(); + let claim = SumcheckClaim::new( + driver.num_rounds, + driver.degree, + input_claim * batching_coeff, + ); + let round_proofs = proof + .proof + .round_polynomials + .iter() + .map(|poly| CompressedLabeledRoundPoly::new(poly, driver.round_label.as_bytes())) + .collect::>(); + let output = SumcheckVerifier::verify(&claim, &round_proofs, transcript) + .map_err(|error| VerifyStage1Error::Sumcheck { + driver: driver.symbol, + error, + })?; + let point = output.point; + if !proof.point.is_empty() && proof.point != point { + return Err(VerifyStage1Error::InvalidProof { + driver: driver.symbol, + reason: "outer remaining point mismatch", + }); + } + validate_eval_shape(program, driver, &proof.evals, None)?; + append_opening_claims(transcript, &proof.evals); + Ok(Stage1SumcheckOutput { + driver: driver.symbol, + point, + evals: proof.evals.clone(), + proof: proof.proof.clone(), + }) +} + +fn driver_evals( + program: &'static Stage1VerifierProgramPlan, + driver: &'static str, + value: Fr, +) -> Vec> { + program + .evals + .iter() + .filter(|eval| eval.source == driver) + .map(|eval| Stage1NamedEval { + name: eval.name, + oracle: eval.oracle, + value, + }) + .collect() +} + +fn validate_eval_shape( + program: &'static Stage1VerifierProgramPlan, + driver: &'static Stage1SumcheckDriverPlan, + actual: &[Stage1NamedEval], + expected_value: Option, +) -> Result<(), VerifyStage1Error> { + let expected = program + .evals + .iter() + .filter(|eval| eval.source == driver.symbol) + .collect::>(); + if actual.len() != expected.len() { + return Err(VerifyStage1Error::InvalidProof { + driver: driver.symbol, + reason: "eval count mismatch", + }); + } + for (actual, expected) in actual.iter().zip(expected) { + if actual.name != expected.name { + return Err(VerifyStage1Error::InvalidProof { + driver: driver.symbol, + reason: "eval name mismatch", + }); + } + if actual.oracle != expected.oracle { + return Err(VerifyStage1Error::InvalidProof { + driver: driver.symbol, + reason: "eval oracle mismatch", + }); + } + if expected_value.is_some_and(|value| actual.value != value) { + return Err(VerifyStage1Error::InvalidProof { + driver: driver.symbol, + reason: "eval value mismatch", + }); + } + } + Ok(()) +} + +fn append_opening_claims(transcript: &mut T, evals: &[Stage1NamedEval]) +where + T: Transcript, +{ + for eval in evals { + append_labeled_scalar(transcript, "opening_claim", &eval.value); + } +} +"# + } +} + +fn emit_str_array(name: &str, values: &[String]) -> String { + if values.is_empty() { + return format!("pub const {name}: &[&str] = &[];\n\n"); + } + if let [value] = values { + return format!("pub const {name}: &[&str] = &[{}];\n\n", rust_str(value)); + } + let entries = values + .iter() + .map(|value| format!(" {},", rust_str(value))) + .collect::>() + .join("\n"); + format!("pub const {name}: &[&str] = &[\n{entries}\n];\n\n") +} + +fn emit_usize_array(name: &str, values: &[usize]) -> String { + let entries = values + .iter() + .map(|value| format!(" {value},")) + .collect::>() + .join("\n"); + format!("pub const {name}: &[usize] = &[\n{entries}\n];\n\n") +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn verify_count(kind: &str, symbol: &str, expected: usize, actual: usize) -> Result<(), EmitError> { + if expected == actual { + Ok(()) + } else { + Err(EmitError::new(format!( + "{kind} @{symbol} count mismatch: expected {expected}, got {actual}" + ))) + } +} + +fn missing_role_binding(kind: &str, symbol: &str) -> EmitError { + EmitError::new(format!("missing {kind} for `{symbol}`")) +} + +fn symbols<'a>(values: impl Iterator) -> BTreeSet { + values.cloned().collect() +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn int_array_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "integer array"))?; + parse_int_array(&attribute).ok_or_else(|| attr_error(operation, attr, "integer array")) +} + +fn parse_int_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().parse().ok()) + .collect() +} + +fn operand_symbols( + operation: OperationRef<'_, '_>, + start_index: usize, +) -> Result, EmitError> { + (start_index..operation.operand_count()) + .map(|index| operand_symbol(operation, index)) + .collect() +} + +fn operand_symbol(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + EmitError::new(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + EmitError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name(operation: OperationRef<'_, '_>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} + +fn require_supported_symbol(kind: &str, actual: &str, expected: &str) -> Result<(), EmitError> { + if actual == expected { + Ok(()) + } else { + Err(EmitError::new(format!( + "unsupported {kind} @{actual}; Rust stage1 emitter currently supports @{expected}" + ))) + } +} diff --git a/crates/bolt/src/protocols/jolt/emit/rust/stage2.rs b/crates/bolt/src/protocols/jolt/emit/rust/stage2.rs new file mode 100644 index 0000000000..eb8eceadb7 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/stage2.rs @@ -0,0 +1,2695 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{Attribute, OperationRef}; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2CpuProgram { + pub role: Role, + pub params: Stage2Params, + pub steps: Vec, + pub transcript_squeezes: Vec, + pub opening_inputs: Vec, + pub field_constants: Vec, + pub field_exprs: Vec, + pub kernels: Vec, + pub claims: Vec, + pub batches: Vec, + pub drivers: Vec, + pub instance_results: Vec, + pub evals: Vec, + pub point_slices: Vec, + pub point_concats: Vec, + pub opening_claims: Vec, + pub opening_batches: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2Params { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2KernelPlan { + pub symbol: String, + pub relation: String, + pub kind: String, + pub backend: String, + pub abi: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2TranscriptSqueezePlan { + pub symbol: String, + pub label: String, + pub kind: String, + pub count: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2ProgramStepPlan { + pub kind: String, + pub symbol: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2OpeningInputPlan { + pub symbol: String, + pub source_stage: String, + pub source_claim: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2FieldConstantPlan { + pub symbol: String, + pub field: String, + pub value: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2FieldExprPlan { + pub symbol: String, + pub kind: String, + pub formula: String, + pub operand_names: Vec, + pub operands: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2SumcheckClaimPlan { + pub symbol: String, + pub stage: String, + pub domain: String, + pub num_rounds: usize, + pub degree: usize, + pub claim: String, + pub kernel: Option, + pub relation: Option, + pub claim_value: String, + pub input_openings: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2SumcheckBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, + pub claim_label: String, + pub round_label: String, + pub round_schedule: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2SumcheckDriverPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub kernel: Option, + pub relation: Option, + pub batch: String, + pub policy: String, + pub round_schedule: Vec, + pub claim_label: String, + pub round_label: String, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2SumcheckInstanceResultPlan { + pub symbol: String, + pub source: String, + pub claim: String, + pub relation: String, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: String, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2SumcheckEvalPlan { + pub symbol: String, + pub source: String, + pub name: String, + pub index: usize, + pub oracle: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2PointSlicePlan { + pub symbol: String, + pub source: String, + pub offset: usize, + pub length: usize, + pub input: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2PointConcatPlan { + pub symbol: String, + pub layout: String, + pub arity: usize, + pub inputs: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2OpeningClaimPlan { + pub symbol: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, + pub point_source: String, + pub eval_source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage2OpeningBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, +} + +pub fn stage2_cpu_program(module: &BoltModule<'_, Cpu>) -> Result { + verify_cpu_schema(module)?; + let program = Stage2CpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +pub fn emit_stage2_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = stage2_cpu_program(module)?; + + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source()?, + }) +} + +impl Stage2CpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let mut params = None; + let mut steps = Vec::new(); + let mut transcript_squeezes = Vec::new(); + let mut opening_inputs = Vec::new(); + let mut field_constants = Vec::new(); + let mut field_exprs = Vec::new(); + let mut kernels = Vec::new(); + let mut claims = Vec::new(); + let mut batches = Vec::new(); + let mut drivers = Vec::new(); + let mut instance_results = Vec::new(); + let mut evals = Vec::new(); + let mut point_slices = Vec::new(); + let mut point_concats = Vec::new(); + let mut opening_claims = Vec::new(); + let mut opening_batches = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(Stage2Params { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.kernel" => { + kernels.push(Stage2KernelPlan { + symbol: string_attr(op, "sym_name")?, + relation: symbol_attr(op, "relation")?, + kind: string_attr(op, "kind")?, + backend: string_attr(op, "backend")?, + abi: string_attr(op, "abi")?, + }); + } + "cpu.transcript_squeeze" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage2ProgramStepPlan { + kind: "transcript_squeeze".to_owned(), + symbol: symbol.clone(), + }); + transcript_squeezes.push(Stage2TranscriptSqueezePlan { + symbol, + label: string_attr(op, "label")?, + kind: string_attr(op, "kind")?, + count: int_attr(op, "count")?, + }); + } + "cpu.opening_input" => { + opening_inputs.push(Stage2OpeningInputPlan { + symbol: string_attr(op, "sym_name")?, + source_stage: symbol_attr(op, "source_stage")?, + source_claim: symbol_attr(op, "source_claim")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + }); + } + "cpu.field_const" => { + field_constants.push(Stage2FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: int_attr(op, "value")?, + }); + } + "cpu.field_zero" => { + field_constants.push(Stage2FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 0, + }); + } + "cpu.field_one" => { + field_constants.push(Stage2FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 1, + }); + } + "cpu.field_add" | "cpu.field_sub" | "cpu.field_mul" | "cpu.field_neg" => { + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage2FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: operation_name(op).replace("cpu.field_", "field."), + operand_names: operands.clone(), + operands, + }); + } + "cpu.field_pow" => { + let exponent = int_attr(op, "exponent")?; + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage2FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: format!("field.pow:{exponent}"), + operand_names: operands.clone(), + operands, + }); + } + "cpu.poly_lagrange_basis_eval" => { + let domain_start = signed_int_attr(op, "domain_start")?; + let domain_size = int_attr(op, "domain_size")?; + let index = int_attr(op, "index")?; + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage2FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: format!( + "poly.lagrange_basis_eval:{domain_start}:{domain_size}:{index}" + ), + operand_names: operands.clone(), + operands, + }); + } + "cpu.sumcheck_claim" => { + claims.push(Stage2SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_verify_claim" => { + claims.push(Stage2SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_batch" => { + batches.push(Stage2SumcheckBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + round_schedule: int_array_attr(op, "round_schedule")?, + }); + } + "cpu.sumcheck_driver" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage2ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage2SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_verify" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage2ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage2SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_instance_result" => { + instance_results.push(Stage2SumcheckInstanceResultPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + claim: symbol_attr(op, "claim")?, + relation: symbol_attr(op, "relation")?, + index: int_attr(op, "index")?, + point_arity: int_attr(op, "point_arity")?, + num_rounds: int_attr(op, "num_rounds")?, + round_offset: int_attr(op, "round_offset")?, + point_order: string_attr(op, "point_order")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_eval" => { + evals.push(Stage2SumcheckEvalPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + name: symbol_attr(op, "name")?, + index: int_attr(op, "index")?, + oracle: symbol_attr(op, "oracle")?, + }); + } + "cpu.point_slice" => { + point_slices.push(Stage2PointSlicePlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + offset: int_attr(op, "offset")?, + length: int_attr(op, "length")?, + input: operand_symbol(op, 0)?, + }); + } + "cpu.point_concat" => { + point_concats.push(Stage2PointConcatPlan { + symbol: string_attr(op, "sym_name")?, + layout: string_attr(op, "layout")?, + arity: int_attr(op, "arity")?, + inputs: operand_symbols(op, 0)?, + }); + } + "cpu.opening_claim" => { + opening_claims.push(Stage2OpeningClaimPlan { + symbol: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + point_source: operand_symbol(op, 0)?, + eval_source: operand_symbol(op, 1)?, + }); + } + "cpu.opening_batch" => { + opening_batches.push(Stage2OpeningBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + }); + } + _ => {} + } + } + + Ok(Self { + params: params.ok_or_else(|| EmitError::new("missing cpu.params"))?, + role: module + .role() + .ok_or_else(|| EmitError::new("missing cpu party role"))?, + steps, + transcript_squeezes, + opening_inputs, + field_constants, + field_exprs, + kernels, + claims, + batches, + drivers, + instance_results, + evals, + point_slices, + point_concats, + opening_claims, + opening_batches, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + require_supported_symbol("field", &self.params.field, "bn254_fr")?; + require_supported_symbol("pcs", &self.params.pcs, "dory")?; + require_supported_symbol("transcript", &self.params.transcript, "blake2b_transcript")?; + self.verify_transcript_squeezes()?; + self.verify_field_flow()?; + self.verify_claim_batches()?; + match self.role { + Role::Prover => { + self.verify_kernel_definitions()?; + self.verify_prover_driver_bindings()?; + } + Role::Verifier => self.verify_verifier_driver_bindings()?, + } + self.verify_opening_flow() + } + + fn verify_transcript_squeezes(&self) -> Result<(), EmitError> { + for squeeze in &self.transcript_squeezes { + if !matches!( + squeeze.kind.as_str(), + "challenge_scalar" | "challenge_vector" + ) { + return Err(EmitError::new(format!( + "stage2 transcript squeeze @{} has unsupported kind `{}`", + squeeze.symbol, squeeze.kind + ))); + } + if squeeze.count == 0 { + return Err(EmitError::new(format!( + "stage2 transcript squeeze @{} has zero count", + squeeze.symbol + ))); + } + } + Ok(()) + } + + fn verify_field_flow(&self) -> Result<(), EmitError> { + for constant in &self.field_constants { + require_supported_symbol("field constant field", &constant.field, "bn254_fr")?; + } + let field_values = self.field_value_symbols(); + for expr in &self.field_exprs { + verify_count( + "field expr operands", + &expr.symbol, + expr.operand_names.len(), + expr.operands.len(), + )?; + for operand in &expr.operands { + if !field_values.contains(operand) { + return Err(EmitError::new(format!( + "field expr @{} references missing field value @{operand}", + expr.symbol + ))); + } + } + } + for claim in &self.claims { + if !field_values.contains(&claim.claim_value) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing claim value @{}", + claim.symbol, claim.claim_value + ))); + } + } + Ok(()) + } + + fn field_value_symbols(&self) -> BTreeSet { + let mut values = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + values.extend(symbols( + self.field_constants.iter().map(|constant| &constant.symbol), + )); + values.extend(symbols( + self.transcript_squeezes + .iter() + .filter(|squeeze| matches!(squeeze.kind.as_str(), "challenge_scalar" | "scalar")) + .map(|squeeze| &squeeze.symbol), + )); + values.extend(symbols(self.field_exprs.iter().map(|expr| &expr.symbol))); + values.extend(symbols(self.evals.iter().map(|eval| &eval.symbol))); + values + } + + fn verify_kernel_definitions(&self) -> Result<(), EmitError> { + for kernel in &self.kernels { + if kernel.backend != "cpu" { + return Err(EmitError::new(format!( + "stage2 kernel @{} targets unsupported backend `{}`", + kernel.symbol, kernel.backend + ))); + } + if kernel.kind != "sumcheck" { + return Err(EmitError::new(format!( + "stage2 kernel @{} has unsupported kind `{}`", + kernel.symbol, kernel.kind + ))); + } + let expected_abi = match kernel.relation.as_str() { + "jolt.stage2.product_virtual.uniskip" => "jolt_stage2_product_virtual_uniskip", + "jolt.stage2.ram.read_write" => "jolt_stage2_ram_read_write", + "jolt.stage2.product_virtual.remainder" => "jolt_stage2_product_virtual_remainder", + "jolt.stage2.instruction_lookup.claim_reduction" => { + "jolt_stage2_instruction_lookup_claim_reduction" + } + "jolt.stage2.ram.raf_evaluation" => "jolt_stage2_ram_raf_evaluation", + "jolt.stage2.ram.output_check" => "jolt_stage2_ram_output_check", + "jolt.stage2.batched" => "jolt_stage2_batched", + _ => { + return Err(EmitError::new(format!( + "unsupported stage2 kernel relation @{}", + kernel.relation + ))); + } + }; + if kernel.abi != expected_abi { + return Err(EmitError::new(format!( + "stage2 kernel @{} ABI `{}` does not match relation @{}", + kernel.symbol, kernel.abi, kernel.relation + ))); + } + } + Ok(()) + } + + fn verify_claim_batches(&self) -> Result<(), EmitError> { + let claims = symbols(self.claims.iter().map(|claim| &claim.symbol)); + for batch in &self.batches { + verify_count( + "sumcheck batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "sumcheck batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "sumcheck batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !claims.contains(claim) { + return Err(EmitError::new(format!( + "sumcheck batch @{} references missing claim @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn verify_prover_driver_bindings(&self) -> Result<(), EmitError> { + let kernels = symbols(self.kernels.iter().map(|kernel| &kernel.symbol)); + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + let Some(kernel) = claim.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck claim @{} is missing kernel", + claim.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck claim @{} references missing kernel @{kernel}", + claim.symbol + ))); + } + } + for driver in &self.drivers { + let Some(kernel) = driver.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck driver @{} is missing kernel", + driver.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck driver @{} references missing kernel @{kernel}", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_verifier_driver_bindings(&self) -> Result<(), EmitError> { + if !self.kernels.is_empty() { + return Err(EmitError::new( + "verifier stage2 program must not contain kernels", + )); + } + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + if claim.kernel.is_some() || claim.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck claim @{} must carry relation and no kernel", + claim.symbol + ))); + } + } + for driver in &self.drivers { + if driver.kernel.is_some() || driver.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck driver @{} must carry relation and no kernel", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_opening_flow(&self) -> Result<(), EmitError> { + let mut point_sources = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + point_sources.extend(symbols( + self.instance_results + .iter() + .map(|instance| &instance.symbol), + )); + point_sources.extend(symbols( + self.opening_inputs.iter().map(|input| &input.symbol), + )); + point_sources.extend(symbols(self.point_slices.iter().map(|slice| &slice.symbol))); + point_sources.extend(symbols( + self.point_concats.iter().map(|concat| &concat.symbol), + )); + for slice in &self.point_slices { + if !point_sources.contains(&slice.input) { + return Err(EmitError::new(format!( + "point slice @{} uses missing point source @{}", + slice.symbol, slice.input + ))); + } + } + for concat in &self.point_concats { + for input in &concat.inputs { + if !point_sources.contains(input) { + return Err(EmitError::new(format!( + "point concat @{} uses missing point source @{input}", + concat.symbol + ))); + } + } + } + let eval_sources = self.field_value_symbols(); + let mut opening_sources = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + opening_sources.extend(symbols( + self.opening_claims.iter().map(|claim| &claim.symbol), + )); + for claim in &self.claims { + for input in &claim.input_openings { + if !opening_sources.contains(input) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing opening @{input}", + claim.symbol + ))); + } + } + } + let drivers = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + for instance in &self.instance_results { + if !drivers.contains(&instance.source) { + return Err(EmitError::new(format!( + "sumcheck instance result @{} references missing driver @{}", + instance.symbol, instance.source + ))); + } + } + for eval in &self.evals { + if !drivers.contains(&eval.source) { + return Err(EmitError::new(format!( + "sumcheck eval @{} references missing driver @{}", + eval.symbol, eval.source + ))); + } + } + for claim in &self.opening_claims { + if !point_sources.contains(&claim.point_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing point source @{}", + claim.symbol, claim.point_source + ))); + } + if !eval_sources.contains(&claim.eval_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing eval source @{}", + claim.symbol, claim.eval_source + ))); + } + } + let openings = symbols(self.opening_claims.iter().map(|claim| &claim.symbol)); + for batch in &self.opening_batches { + verify_count( + "opening batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "opening batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "opening batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !openings.contains(claim) { + return Err(EmitError::new(format!( + "opening batch @{} references missing opening @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn emit_source(&self) -> Result { + match self.role { + Role::Prover => self.emit_prover_source(), + Role::Verifier => self.emit_verifier_source(), + } + } + + fn emit_prover_source(&self) -> Result { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + source.push_str(Self::emit_prover_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_prover_types()); + source.push('\n'); + source.push_str(&self.emit_prover_constants()?); + source.push('\n'); + source.push_str(Self::emit_prover_entrypoint()); + Ok(source) + } + + fn emit_verifier_source(&self) -> Result { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + source.push_str(Self::emit_verifier_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_verifier_types()); + source.push('\n'); + source.push_str(&self.emit_verifier_constants()?); + source.push('\n'); + source.push_str(Self::emit_verifier_entrypoint()); + Ok(source) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_stage2.rs", + Role::Verifier => "verify_stage2.rs", + } + } + + fn emit_prover_imports() -> &'static str { + "use jolt_field::Fr;\n\ + use jolt_kernels::stage2::{execute_stage2_program, Stage2CpuProgramPlan, Stage2ExecutionArtifacts, Stage2ExecutionMode, Stage2FieldConstantPlan, Stage2FieldExprPlan, Stage2KernelError, Stage2KernelExecutor, Stage2KernelPlan, Stage2OpeningBatchPlan, Stage2OpeningClaimPlan, Stage2OpeningInputPlan, Stage2Params, Stage2PointConcatPlan, Stage2PointSlicePlan, Stage2ProgramStepPlan, Stage2SumcheckBatchPlan, Stage2SumcheckClaimPlan, Stage2SumcheckDriverPlan, Stage2SumcheckEvalPlan, Stage2SumcheckInstanceResultPlan, Stage2TranscriptSqueezePlan};\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_prover_types() -> &'static str { + "pub type DefaultStage2Transcript = Blake2bTranscript;\n" + } + + fn emit_verifier_imports() -> &'static str { + "use super::common::{append_labeled_scalar, batch_claims, eval_by_name, find_batch, find_plan, pow_field, require_operand_count, reverse_slice, single_operand};\n\ + use jolt_field::{Field, Fr};\n\ + use jolt_poly::lagrange::{lagrange_evals, lagrange_kernel_eval};\n\ + use jolt_poly::{EqPolynomial, UnivariatePoly};\n\ + use jolt_sumcheck::{CompressedLabeledRoundPoly, SumcheckClaim, SumcheckError, SumcheckVerifier};\n\ + use jolt_transcript::{Blake2bTranscript, LabelWithCount, Transcript};" + } + + fn emit_verifier_types() -> &'static str { + r"pub type DefaultStage2Transcript = Blake2bTranscript; + +pub type Stage2NamedEval = super::common::StageNamedEval; +pub type Stage2SumcheckOutput = super::common::StageSumcheckOutput; +pub type Stage2ChallengeVector = super::common::StageChallengeVector; +pub type Stage2ExecutionArtifacts = super::common::StageExecutionArtifacts; +pub type Stage2Proof = super::common::StageProof; +pub type Stage2OpeningInputValue = super::common::StageOpeningInputValue; +pub type Stage2VerifierProgramPlan = super::common::StageVerifierProgramPlanNoEqualities; + +pub use super::common::{ + FieldConstantPlan as Stage2FieldConstantPlan, FieldExprPlan as Stage2FieldExprPlan, + OpeningBatchPlan as Stage2OpeningBatchPlan, OpeningClaimPlan as Stage2OpeningClaimPlan, + OpeningInputPlan as Stage2OpeningInputPlan, PointConcatPlan as Stage2PointConcatPlan, + PointSlicePlan as Stage2PointSlicePlan, ProgramStepPlan as Stage2ProgramStepPlan, + StageParams as Stage2Params, SumcheckBatchPlan as Stage2SumcheckBatchPlan, + SumcheckEvalPlan as Stage2SumcheckEvalPlan, + SumcheckInstanceResultPlan as Stage2SumcheckInstanceResultPlan, + TranscriptSqueezePlan as Stage2TranscriptSqueezePlan, + SumcheckClaimPlan as Stage2SumcheckClaimPlan, + SumcheckDriverPlan as Stage2SumcheckDriverPlan, +}; + +#[derive(Clone, Copy, Debug)] +pub struct Stage2RamAccess { + pub remapped_address: Option, + pub read_value: u64, + pub write_value: u64, +} + +#[derive(Clone, Copy, Debug)] +pub struct Stage2RamOutputLayout { + pub io_start: usize, + pub io_end: usize, +} + +#[derive(Clone, Copy, Debug)] +pub struct Stage2RamData<'a> { + pub log_k: usize, + pub start_address: u64, + pub initial_ram: &'a [u64], + pub final_ram: &'a [u64], + pub accesses: &'a [Stage2RamAccess], + pub output_layout: Option, +} + +#[derive(Clone, Debug, Default)] +struct Stage2ValueStore(super::common::ValueStore); + +#[derive(Debug)] +pub enum VerifyStage2Error { + UnexpectedProofCount { expected: usize, got: usize }, + MissingProof { driver: &'static str }, + MissingBatch { driver: &'static str, batch: &'static str }, + MissingClaim { batch: &'static str, claim: &'static str }, + MissingValue { symbol: &'static str }, + InvalidInputLength { input: &'static str, expected: usize, actual: usize }, + InvalidProof { driver: &'static str, reason: &'static str }, + UnsupportedFieldExpr { symbol: &'static str, formula: &'static str }, + UnsupportedRelation { relation: &'static str }, + MissingRam { relation: &'static str }, + Sumcheck { driver: &'static str, error: SumcheckError }, +} + +super::common::impl_runtime_plan_error_conversion!(VerifyStage2Error); +" + } + + fn emit_prover_constants(&self) -> Result { + let mut source = self.emit_shared_constants(); + source.push_str(&self.emit_kernel_constants()); + source.push_str(&self.emit_prover_sumcheck_claim_constants()?); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_prover_sumcheck_driver_constants()?); + source.push_str(&self.emit_tail_constants()); + source.push_str( + "pub const STAGE2_PROGRAM: Stage2CpuProgramPlan = Stage2CpuProgramPlan {\n\ + \x20 params: STAGE2_PARAMS,\n\ + \x20 steps: STAGE2_PROGRAM_STEPS,\n\ + \x20 transcript_squeezes: STAGE2_TRANSCRIPT_SQUEEZES,\n\ + \x20 opening_inputs: STAGE2_OPENING_INPUTS,\n\ + \x20 field_constants: STAGE2_FIELD_CONSTANTS,\n\ + \x20 field_exprs: STAGE2_FIELD_EXPRS,\n\ + \x20 kernels: STAGE2_KERNELS,\n\ + \x20 claims: STAGE2_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE2_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE2_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE2_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE2_SUMCHECK_EVALS,\n\ + \x20 point_slices: STAGE2_POINT_SLICES,\n\ + \x20 point_concats: STAGE2_POINT_CONCATS,\n\ + \x20 opening_claims: STAGE2_OPENING_CLAIMS,\n\ + \x20 opening_batches: STAGE2_OPENING_BATCHES,\n\ + };\n", + ); + Ok(source) + } + + fn emit_verifier_constants(&self) -> Result { + let mut source = self.emit_shared_constants(); + source.push_str(&self.emit_verifier_sumcheck_claim_constants()?); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_verifier_sumcheck_driver_constants()?); + source.push_str(&self.emit_tail_constants()); + source.push_str( + "pub const STAGE2_PROGRAM: Stage2VerifierProgramPlan = Stage2VerifierProgramPlan {\n\ + \x20 params: STAGE2_PARAMS,\n\ + \x20 steps: STAGE2_PROGRAM_STEPS,\n\ + \x20 transcript_squeezes: STAGE2_TRANSCRIPT_SQUEEZES,\n\ + \x20 opening_inputs: STAGE2_OPENING_INPUTS,\n\ + \x20 field_constants: STAGE2_FIELD_CONSTANTS,\n\ + \x20 field_exprs: STAGE2_FIELD_EXPRS,\n\ + \x20 claims: STAGE2_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE2_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE2_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE2_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE2_SUMCHECK_EVALS,\n\ + \x20 point_slices: STAGE2_POINT_SLICES,\n\ + \x20 point_concats: STAGE2_POINT_CONCATS,\n\ + \x20 opening_claims: STAGE2_OPENING_CLAIMS,\n\ + \x20 opening_batches: STAGE2_OPENING_BATCHES,\n\ + };\n", + ); + Ok(source) + } + + fn emit_shared_constants(&self) -> String { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const STAGE2_PARAMS: Stage2Params = Stage2Params {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + source.push_str(&self.emit_program_step_constants()); + source.push_str(&self.emit_transcript_squeeze_constants()); + source.push_str(&self.emit_opening_input_constants()); + source.push_str(&self.emit_field_constant_constants()); + source.push_str(&self.emit_field_expr_constants()); + source + } + + fn emit_program_step_constants(&self) -> String { + let steps = self + .steps + .iter() + .map(|step| { + format!( + " Stage2ProgramStepPlan {{ kind: {}, symbol: {} }},", + rust_str(&step.kind), + rust_str(&step.symbol), + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE2_PROGRAM_STEPS: &[Stage2ProgramStepPlan] = &[\n{steps}\n];\n\n") + } + + fn emit_transcript_squeeze_constants(&self) -> String { + let squeezes = self + .transcript_squeezes + .iter() + .map(|squeeze| { + format!( + " Stage2TranscriptSqueezePlan {{ symbol: {}, label: {}, kind: {}, count: {} }},", + rust_str(&squeeze.symbol), + rust_str(&squeeze.label), + rust_str(&squeeze.kind), + squeeze.count, + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE2_TRANSCRIPT_SQUEEZES: &[Stage2TranscriptSqueezePlan] = &[\n{squeezes}\n];\n\n" + ) + } + + fn emit_opening_input_constants(&self) -> String { + let inputs = self + .opening_inputs + .iter() + .map(|input| { + format!( + " Stage2OpeningInputPlan {{ symbol: {}, source_stage: {}, source_claim: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {} }},", + rust_str(&input.symbol), + rust_str(&input.source_stage), + rust_str(&input.source_claim), + rust_str(&input.oracle), + rust_str(&input.domain), + input.point_arity, + rust_str(&input.claim_kind) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE2_OPENING_INPUTS: &[Stage2OpeningInputPlan] = &[\n{inputs}\n];\n\n") + } + + fn emit_field_constant_constants(&self) -> String { + let constants = self + .field_constants + .iter() + .map(|constant| { + format!( + " Stage2FieldConstantPlan {{ symbol: {}, field: {}, value: {} }},", + rust_str(&constant.symbol), + rust_str(&constant.field), + constant.value + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE2_FIELD_CONSTANTS: &[Stage2FieldConstantPlan] = &[\n{constants}\n];\n\n" + ) + } + + fn emit_field_expr_constants(&self) -> String { + if self.role == Role::Verifier { + let exprs = self + .field_exprs + .iter() + .map(|expr| { + format!( + " Stage2FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operands: {} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula), + rust_str(&expr.operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE2_FIELD_EXPRS: &[Stage2FieldExprPlan] = &[\n{exprs}\n];\n" + ); + } + + let mut source = String::new(); + let mut arrays = Vec::new(); + let mut array_refs = Vec::new(); + for (index, expr) in self.field_exprs.iter().enumerate() { + let operands = intern_str_array( + &mut source, + &mut arrays, + "STAGE2_FIELD_EXPR_OPERANDS", + &expr.operands, + ); + let operand_names = intern_str_array( + &mut source, + &mut arrays, + "STAGE2_FIELD_EXPR_OPERANDS", + &expr.operand_names, + ); + array_refs.push((index, operand_names, operands)); + } + let exprs = self + .field_exprs + .iter() + .enumerate() + .map(|(index, expr)| { + let (_, operand_names, operands) = &array_refs[index]; + format!( + " Stage2FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operand_names: {operand_names}, operands: {operands} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE2_FIELD_EXPRS: &[Stage2FieldExprPlan] = &[\n{exprs}\n];\n" + ), + ); + source + } + + fn emit_kernel_constants(&self) -> String { + let kernels = self + .kernels + .iter() + .map(|kernel| { + format!( + " Stage2KernelPlan {{ symbol: {}, relation: {}, kind: {}, backend: {}, abi: {} }},", + rust_str(&kernel.symbol), + rust_str(&kernel.relation), + rust_str(&kernel.kind), + rust_str(&kernel.backend), + rust_str(&kernel.abi) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE2_KERNELS: &[Stage2KernelPlan] = &[\n{kernels}\n];\n\n") + } + + fn emit_prover_sumcheck_claim_constants(&self) -> Result { + self.emit_sumcheck_claim_constants(true) + } + + fn emit_verifier_sumcheck_claim_constants(&self) -> Result { + self.emit_sumcheck_claim_constants(false) + } + + fn emit_sumcheck_claim_constants(&self, prover: bool) -> Result { + let mut source = String::new(); + if prover { + for (index, claim) in self.claims.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE2_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS"), + &claim.input_openings, + )); + } + } + let mut claims = Vec::new(); + for (index, claim) in self.claims.iter().enumerate() { + if prover { + let kernel = claim + .kernel + .as_deref() + .ok_or_else(|| missing_role_binding("prover claim kernel", &claim.symbol))?; + claims.push(format!( + " Stage2SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: Some({}), relation: None, claim_value: {}, input_openings: STAGE2_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_str(kernel), + rust_str(&claim.claim_value) + )); + } else { + let relation = claim.relation.as_deref().ok_or_else(|| { + missing_role_binding("verifier claim relation", &claim.symbol) + })?; + claims.push(format!( + " Stage2SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: None, relation: Some({}), claim_value: {}, input_openings: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_str(relation), + rust_str(&claim.claim_value), + rust_str(&claim.input_openings.join("|")) + )); + } + } + let claims = claims.join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE2_SUMCHECK_CLAIMS: &[Stage2SumcheckClaimPlan] = &[\n{claims}\n];\n" + ), + ); + Ok(source) + } + + fn emit_sumcheck_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE2_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage2SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {}, claim_label: {}, round_label: {}, round_schedule: STAGE2_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")), + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE2_SUMCHECK_BATCHES: &[Stage2SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + return source; + } + + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE2_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE2_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + source.push_str(&emit_usize_array( + &format!("STAGE2_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage2SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE2_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE2_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS, claim_label: {}, round_label: {}, round_schedule: STAGE2_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE2_SUMCHECK_BATCHES: &[Stage2SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_prover_sumcheck_driver_constants(&self) -> Result { + self.emit_sumcheck_driver_constants(true) + } + + fn emit_verifier_sumcheck_driver_constants(&self) -> Result { + self.emit_sumcheck_driver_constants(false) + } + + fn emit_sumcheck_driver_constants(&self, prover: bool) -> Result { + let mut source = String::new(); + for (index, driver) in self.drivers.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE2_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE"), + &driver.round_schedule, + )); + } + let mut drivers = Vec::new(); + for (index, driver) in self.drivers.iter().enumerate() { + if prover { + let kernel = driver + .kernel + .as_deref() + .ok_or_else(|| missing_role_binding("prover driver kernel", &driver.symbol))?; + drivers.push(format!( + " Stage2SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: Some({}), relation: None, batch: {}, policy: {}, round_schedule: STAGE2_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_str(kernel), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + )); + } else { + let relation = driver.relation.as_deref().ok_or_else(|| { + missing_role_binding("verifier driver relation", &driver.symbol) + })?; + drivers.push(format!( + " Stage2SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: None, relation: Some({}), batch: {}, policy: {}, round_schedule: STAGE2_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_str(relation), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + )); + } + } + let drivers = drivers.join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE2_SUMCHECK_DRIVERS: &[Stage2SumcheckDriverPlan] = &[\n{drivers}\n];\n" + ), + ); + Ok(source) + } + + fn emit_tail_constants(&self) -> String { + let mut source = String::new(); + source.push_str(&self.emit_sumcheck_instance_result_constants()); + source.push_str(&self.emit_sumcheck_eval_constants()); + source.push_str(&self.emit_point_slice_constants()); + source.push_str(&self.emit_point_concat_constants()); + source.push_str(&self.emit_opening_claim_constants()); + source.push_str(&self.emit_opening_batch_constants()); + source + } + + fn emit_sumcheck_instance_result_constants(&self) -> String { + let instances = self + .instance_results + .iter() + .map(|instance| { + format!( + " Stage2SumcheckInstanceResultPlan {{ symbol: {}, source: {}, claim: {}, relation: {}, index: {}, point_arity: {}, num_rounds: {}, round_offset: {}, point_order: {}, degree: {} }},", + rust_str(&instance.symbol), + rust_str(&instance.source), + rust_str(&instance.claim), + rust_str(&instance.relation), + instance.index, + instance.point_arity, + instance.num_rounds, + instance.round_offset, + rust_str(&instance.point_order), + instance.degree + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE2_SUMCHECK_INSTANCE_RESULTS: &[Stage2SumcheckInstanceResultPlan] = &[\n{instances}\n];\n\n" + ) + } + + fn emit_sumcheck_eval_constants(&self) -> String { + let evals = self + .evals + .iter() + .map(|eval| { + format!( + " Stage2SumcheckEvalPlan {{ symbol: {}, source: {}, name: {}, index: {}, oracle: {} }},", + rust_str(&eval.symbol), + rust_str(&eval.source), + rust_str(&eval.name), + eval.index, + rust_str(&eval.oracle) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE2_SUMCHECK_EVALS: &[Stage2SumcheckEvalPlan] = &[\n{evals}\n];\n\n") + } + + fn emit_point_slice_constants(&self) -> String { + let slices = self + .point_slices + .iter() + .map(|slice| { + format!( + " Stage2PointSlicePlan {{ symbol: {}, source: {}, offset: {}, length: {}, input: {} }},", + rust_str(&slice.symbol), + rust_str(&slice.source), + slice.offset, + slice.length, + rust_str(&slice.input) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE2_POINT_SLICES: &[Stage2PointSlicePlan] = &[\n{slices}\n];\n\n") + } + + fn emit_point_concat_constants(&self) -> String { + if self.role == Role::Verifier { + let concats = self + .point_concats + .iter() + .map(|concat| { + format!( + " Stage2PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: {} }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity, + rust_str(&concat.inputs.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE2_POINT_CONCATS: &[Stage2PointConcatPlan] = &[\n{concats}\n];\n" + ); + } + + let mut source = String::new(); + for (index, concat) in self.point_concats.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE2_POINT_CONCAT_{index}_INPUTS"), + &concat.inputs, + )); + } + let concats = self + .point_concats + .iter() + .enumerate() + .map(|(index, concat)| { + format!( + " Stage2PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: STAGE2_POINT_CONCAT_{index}_INPUTS }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE2_POINT_CONCATS: &[Stage2PointConcatPlan] = &[\n{concats}\n];\n" + ), + ); + source + } + + fn emit_opening_claim_constants(&self) -> String { + let claims = self + .opening_claims + .iter() + .map(|claim| { + format!( + " Stage2OpeningClaimPlan {{ symbol: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {}, point_source: {}, eval_source: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.oracle), + rust_str(&claim.domain), + claim.point_arity, + rust_str(&claim.claim_kind), + rust_str(&claim.point_source), + rust_str(&claim.eval_source) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE2_OPENING_CLAIMS: &[Stage2OpeningClaimPlan] = &[\n{claims}\n];\n\n") + } + + fn emit_opening_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let batches = self + .opening_batches + .iter() + .map(|batch| { + format!( + " Stage2OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {} }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE2_OPENING_BATCHES: &[Stage2OpeningBatchPlan] = &[\n{batches}\n];\n" + ); + } + + let mut source = String::new(); + for (index, batch) in self.opening_batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE2_OPENING_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE2_OPENING_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + } + let batches = self + .opening_batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage2OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE2_OPENING_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE2_OPENING_BATCH_{index}_CLAIM_OPERANDS }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE2_OPENING_BATCHES: &[Stage2OpeningBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_prover_entrypoint() -> &'static str { + "pub fn execute_stage2_prover(\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage2KernelError>\n\ + where\n\ + \x20 E: Stage2KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage2_prover_with_program(&STAGE2_PROGRAM, executor, transcript)\n\ + }\n\ + \n\ + pub fn execute_stage2_prover_with_program(\n\ + \x20 program: &'static Stage2CpuProgramPlan,\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage2KernelError>\n\ + where\n\ + \x20 E: Stage2KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage2_program(program, Stage2ExecutionMode::Prover, executor, transcript)\n\ + }\n" + } + + fn emit_verifier_entrypoint() -> &'static str { + r#"const PRODUCT_VIRTUAL_UNISKIP_DOMAIN_START: i64 = -1; +const PRODUCT_VIRTUAL_UNISKIP_DOMAIN_SIZE: usize = 3; + +pub fn verify_stage2( + proof: &Stage2Proof, + opening_inputs: &[Stage2OpeningInputValue], + ram: Option<&Stage2RamData<'_>>, + transcript: &mut T, +) -> Result, VerifyStage2Error> +where + T: Transcript, +{ + verify_stage2_with_program(&STAGE2_PROGRAM, proof, opening_inputs, ram, transcript) +} + +pub fn verify_stage2_with_program( + program: &'static Stage2VerifierProgramPlan, + proof: &Stage2Proof, + opening_inputs: &[Stage2OpeningInputValue], + ram: Option<&Stage2RamData<'_>>, + transcript: &mut T, +) -> Result, VerifyStage2Error> +where + T: Transcript, +{ + if proof.sumchecks.len() != program.drivers.len() { + return Err(VerifyStage2Error::UnexpectedProofCount { + expected: program.drivers.len(), + got: proof.sumchecks.len(), + }); + } + let mut store = Stage2ValueStore::with_opening_inputs(program, opening_inputs)?; + store.seed_constants(program); + let mut artifacts = Stage2ExecutionArtifacts::default(); + if program.steps.is_empty() { + for squeeze in program.transcript_squeezes { + verify_stage2_squeeze(program, squeeze, &mut store, transcript, &mut artifacts)?; + } + for driver in program.drivers { + verify_stage2_driver(program, driver, proof, ram, &mut store, transcript, &mut artifacts)?; + } + } else { + for step in program.steps { + match step.kind { + "transcript_squeeze" => { + let squeeze = find_plan(program.transcript_squeezes, step.symbol).ok_or(VerifyStage2Error::MissingValue { + symbol: step.symbol, + })?; + verify_stage2_squeeze(program, squeeze, &mut store, transcript, &mut artifacts)?; + } + "sumcheck_driver" => { + let driver = find_plan(program.drivers, step.symbol).ok_or(VerifyStage2Error::MissingProof { + driver: step.symbol, + })?; + verify_stage2_driver(program, driver, proof, ram, &mut store, transcript, &mut artifacts)?; + } + _ => { + return Err(VerifyStage2Error::InvalidProof { + driver: step.symbol, + reason: "unsupported stage2 program step", + }); + } + } + } + } + artifacts + .opening_batches + .extend(program.opening_batches.iter()); + Ok(artifacts) +} + +pub fn stage2_verifier_program() -> &'static Stage2VerifierProgramPlan { + &STAGE2_PROGRAM +} + +fn verify_stage2_squeeze( + program: &'static Stage2VerifierProgramPlan, + squeeze: &'static Stage2TranscriptSqueezePlan, + store: &mut Stage2ValueStore, + transcript: &mut T, + artifacts: &mut Stage2ExecutionArtifacts, +) -> Result<(), VerifyStage2Error> +where + T: Transcript, +{ + let values = transcript.challenge_vector(squeeze.count); + store.observe_challenge_vector(program, squeeze, &values)?; + artifacts.challenge_vectors.push(Stage2ChallengeVector { + symbol: squeeze.symbol, + values, + }); + Ok(()) +} + +fn verify_stage2_driver( + program: &'static Stage2VerifierProgramPlan, + driver: &'static Stage2SumcheckDriverPlan, + proof: &Stage2Proof, + ram: Option<&Stage2RamData<'_>>, + store: &mut Stage2ValueStore, + transcript: &mut T, + artifacts: &mut Stage2ExecutionArtifacts, +) -> Result<(), VerifyStage2Error> +where + T: Transcript, +{ + let proof = proof + .sumchecks + .get(artifacts.sumchecks.len()) + .ok_or(VerifyStage2Error::MissingProof { + driver: driver.symbol, + })?; + let relation = driver.relation.unwrap_or(""); + let output = match relation { + "jolt.stage2.product_virtual.uniskip" => { + verify_product_virtual_uniskip(program, driver, proof, store, transcript)? + } + "jolt.stage2.batched" => verify_batched_stage2(program, driver, proof, ram, store, transcript)?, + relation => return Err(VerifyStage2Error::UnsupportedRelation { relation }), + }; + artifacts.sumchecks.push(output); + Ok(()) +} + +fn verify_product_virtual_uniskip( + program: &'static Stage2VerifierProgramPlan, + driver: &'static Stage2SumcheckDriverPlan, + proof: &Stage2SumcheckOutput, + store: &mut Stage2ValueStore, + transcript: &mut T, +) -> Result, VerifyStage2Error> +where + T: Transcript, +{ + validate_driver_symbol(driver, proof)?; + let [poly] = proof.proof.round_polynomials.as_slice() else { + return Err(VerifyStage2Error::InvalidProof { + driver: driver.symbol, + reason: "unexpected product uniskip round count", + }); + }; + if polynomial_degree(poly) > driver.degree { + return Err(VerifyStage2Error::InvalidProof { + driver: driver.symbol, + reason: "product uniskip polynomial exceeds degree bound", + }); + } + let batch = find_batch(program.batches, driver.symbol, driver.batch)?; + let claim = batch_claims(program.claims, batch)? + .into_iter() + .next() + .ok_or(VerifyStage2Error::MissingClaim { + batch: batch.symbol, + claim: "stage2.product_virtual.uniskip.input", + })?; + let input_claim = store.claim_value(program, claim)?; + if !product_uniskip_sum_matches(poly, input_claim) { + return Err(VerifyStage2Error::InvalidProof { + driver: driver.symbol, + reason: "product uniskip input claim mismatch", + }); + } + append_univariate_poly(transcript, driver.round_label, poly); + let r0 = transcript.challenge(); + if !proof.point.is_empty() && proof.point != [r0] { + return Err(VerifyStage2Error::InvalidProof { + driver: driver.symbol, + reason: "product uniskip point mismatch", + }); + } + let eval = poly.evaluate(r0); + append_labeled_scalar(transcript, "opening_claim", &eval); + let output = Stage2SumcheckOutput { + driver: driver.symbol, + point: vec![r0], + evals: driver_evals(program, driver.symbol, eval), + proof: proof.proof.clone(), + }; + verify_named_evals(driver.symbol, &output.evals, &proof.evals)?; + store.observe_sumcheck_output(program, &output)?; + Ok(output) +} + +fn verify_batched_stage2( + program: &'static Stage2VerifierProgramPlan, + driver: &'static Stage2SumcheckDriverPlan, + proof: &Stage2SumcheckOutput, + ram: Option<&Stage2RamData<'_>>, + store: &mut Stage2ValueStore, + transcript: &mut T, +) -> Result, VerifyStage2Error> +where + T: Transcript, +{ + validate_driver_symbol(driver, proof)?; + let batch = find_batch(program.batches, driver.symbol, driver.batch)?; + let claims = batch_claims(program.claims, batch)?; + let input_claims = store.batch_claim_values(program, batch)?; + for claim in &input_claims { + append_labeled_scalar(transcript, batch.claim_label, claim); + } + let batching_coeffs = transcript.challenge_vector(claims.len()); + let claimed_sum = input_claims + .iter() + .zip(claims.iter()) + .zip(&batching_coeffs) + .map(|((claim, plan), coefficient)| { + claim.mul_pow_2(driver.num_rounds - plan.num_rounds) * *coefficient + }) + .sum::(); + let claim = SumcheckClaim::new(driver.num_rounds, driver.degree, claimed_sum); + let round_proofs = proof + .proof + .round_polynomials + .iter() + .map(|poly| CompressedLabeledRoundPoly::new(poly, driver.round_label.as_bytes())) + .collect::>(); + let output = SumcheckVerifier::verify(&claim, &round_proofs, transcript) + .map_err(|error| VerifyStage2Error::Sumcheck { + driver: driver.symbol, + error, + })?; + if !proof.point.is_empty() && proof.point != output.point { + return Err(VerifyStage2Error::InvalidProof { + driver: driver.symbol, + reason: "batched point mismatch", + }); + } + let expected = + expected_batched_output_claim(program, driver, &*store, &proof.evals, &output.point, &batching_coeffs, ram)?; + if output.value != expected { + return Err(VerifyStage2Error::InvalidProof { + driver: driver.symbol, + reason: "batched output claim mismatch", + }); + } + let verified = Stage2SumcheckOutput { + driver: driver.symbol, + point: output.point, + evals: proof.evals.clone(), + proof: proof.proof.clone(), + }; + store.observe_sumcheck_output(program, &verified)?; + super::common::append_opening_claims( + program.opening_inputs, + program.opening_claims, + program.opening_batches, + &mut store.0, + transcript, + &verified.evals, + |batch, claim| VerifyStage2Error::MissingClaim { batch, claim }, + |symbol| VerifyStage2Error::MissingValue { symbol }, + )?; + Ok(verified) +} + +impl Stage2ValueStore { + fn with_opening_inputs( + program: &'static Stage2VerifierProgramPlan, + inputs: &[Stage2OpeningInputValue], + ) -> Result { + Ok(Self(super::common::ValueStore::with_opening_inputs( + inputs, + program.opening_inputs, + )?)) + } + + fn seed_constants(&mut self, program: &'static Stage2VerifierProgramPlan) { + self.0.seed_constants(program.field_constants); + } + + fn observe_challenge_vector( + &mut self, + program: &'static Stage2VerifierProgramPlan, + plan: &'static Stage2TranscriptSqueezePlan, + values: &[F], + ) -> Result<(), VerifyStage2Error> { + self.0.observe_challenge_vector(plan, values, |input, expected, actual| { + VerifyStage2Error::InvalidInputLength { input, expected, actual } + })?; + self.evaluate_available_points(program)?; + self.evaluate_available_field_exprs(program)?; + Ok(()) + } + + fn observe_sumcheck_output( + &mut self, + program: &'static Stage2VerifierProgramPlan, + output: &Stage2SumcheckOutput, + ) -> Result<(), VerifyStage2Error> { + self.0.observe_sumcheck_output( + program.instance_results, + program.evals, + output, + |instance, mut point| { + match instance.point_order { + "as_is" => {} + "reverse" => point.reverse(), + _ => { + return Err(VerifyStage2Error::InvalidProof { + driver: output.driver, + reason: "unsupported point order", + }); + } + } + Ok(point) + }, + |input, expected, actual| VerifyStage2Error::InvalidInputLength { + input, + expected, + actual, + }, + |symbol| VerifyStage2Error::MissingValue { symbol }, + )?; + self.evaluate_available_points(program)?; + self.evaluate_available_field_exprs(program)?; + Ok(()) + } + + fn claim_value( + &mut self, + program: &'static Stage2VerifierProgramPlan, + claim: &Stage2SumcheckClaimPlan, + ) -> Result { + self.evaluate_available_field_exprs(program)?; + self.scalar(claim.claim_value) + } + + fn batch_claim_values( + &mut self, + program: &'static Stage2VerifierProgramPlan, + batch: &Stage2SumcheckBatchPlan, + ) -> Result, VerifyStage2Error> { + super::common::symbol_list(batch.claim_operands) + .map(|symbol| { + let claim = find_plan(program.claims, symbol).ok_or(VerifyStage2Error::MissingClaim { + batch: batch.symbol, + claim: symbol, + })?; + self.claim_value(program, claim) + }) + .collect() + } + + fn evaluate_available_points( + &mut self, + program: &'static Stage2VerifierProgramPlan, + ) -> Result<(), VerifyStage2Error> { + self.0.evaluate_available_points( + program.point_slices, + program.point_concats, + |input, expected, actual| VerifyStage2Error::InvalidInputLength { + input, + expected, + actual, + }, + ) + } + + fn evaluate_available_field_exprs( + &mut self, + program: &'static Stage2VerifierProgramPlan, + ) -> Result<(), VerifyStage2Error> { + self.0 + .evaluate_available_field_exprs(program.field_exprs, evaluate_stage2_field_expr) + } + + fn scalar(&self, symbol: &'static str) -> Result { + self.0 + .scalar_or(symbol, |symbol| VerifyStage2Error::MissingValue { symbol }) + } + + fn point(&self, symbol: &'static str) -> Result<&[F], VerifyStage2Error> { + self.0 + .point_or(symbol, |symbol| VerifyStage2Error::MissingValue { symbol }) + } + + fn try_point(&self, symbol: &str) -> Option<&[F]> { + self.0.try_point(symbol) + } +} + +fn evaluate_stage2_field_expr( + expr: &Stage2FieldExprPlan, + operands: &[F], +) -> Result { + match expr.formula { + "opening_eval" => Ok(single_operand(expr.symbol, operands)?), + "jolt_stage2_product_virtual_uniskip_input" => { + require_operand_count(expr.symbol, 4, operands.len())?; + let weights = lagrange_evals( + PRODUCT_VIRTUAL_UNISKIP_DOMAIN_START, + PRODUCT_VIRTUAL_UNISKIP_DOMAIN_SIZE, + operands[0], + ); + Ok(weights[0] * operands[1] + weights[1] * operands[2] + weights[2] * operands[3]) + } + "jolt_stage2_ram_read_write_input" => { + require_operand_count(expr.symbol, 3, operands.len())?; + Ok(operands[1] + operands[0] * operands[2]) + } + "jolt_stage2_instruction_lookup_input" => { + require_operand_count(expr.symbol, 6, operands.len())?; + let gamma = operands[0]; + let gamma2 = gamma.square(); + let gamma3 = gamma2 * gamma; + let gamma4 = gamma2.square(); + Ok(operands[1] + + gamma * operands[2] + + gamma2 * operands[3] + + gamma3 * operands[4] + + gamma4 * operands[5]) + } + "field.add" => { + require_operand_count(expr.symbol, 2, operands.len())?; + Ok(operands[0] + operands[1]) + } + "field.sub" => { + require_operand_count(expr.symbol, 2, operands.len())?; + Ok(operands[0] - operands[1]) + } + "field.mul" => { + require_operand_count(expr.symbol, 2, operands.len())?; + Ok(operands[0] * operands[1]) + } + "field.neg" => { + require_operand_count(expr.symbol, 1, operands.len())?; + Ok(-operands[0]) + } + formula => { + if let Some(exponent) = formula.strip_prefix("field.pow:") { + require_operand_count(expr.symbol, 1, operands.len())?; + let exponent = exponent.parse::().map_err(|_| { + VerifyStage2Error::UnsupportedFieldExpr { + symbol: expr.symbol, + formula, + } + })?; + return Ok(pow_field(operands[0], exponent)); + } + if let Some(spec) = formula.strip_prefix("poly.lagrange_basis_eval:") { + require_operand_count(expr.symbol, 1, operands.len())?; + let parts = spec.split(':').collect::>(); + if parts.len() != 3 { + return Err(VerifyStage2Error::UnsupportedFieldExpr { + symbol: expr.symbol, + formula, + }); + } + let domain_start = parts[0].parse::().map_err(|_| { + VerifyStage2Error::UnsupportedFieldExpr { + symbol: expr.symbol, + formula, + } + })?; + let domain_size = parts[1].parse::().map_err(|_| { + VerifyStage2Error::UnsupportedFieldExpr { + symbol: expr.symbol, + formula, + } + })?; + let index = parts[2].parse::().map_err(|_| { + VerifyStage2Error::UnsupportedFieldExpr { + symbol: expr.symbol, + formula, + } + })?; + let weights = lagrange_evals(domain_start, domain_size, operands[0]); + return weights + .get(index) + .copied() + .ok_or(VerifyStage2Error::InvalidInputLength { + input: expr.symbol, + expected: index + 1, + actual: weights.len(), + }); + } + Err(VerifyStage2Error::UnsupportedFieldExpr { + symbol: expr.symbol, + formula, + }) + } + } +} + +fn expected_batched_output_claim( + program: &'static Stage2VerifierProgramPlan, + driver: &'static Stage2SumcheckDriverPlan, + store: &Stage2ValueStore, + evals: &[Stage2NamedEval], + point: &[Fr], + batching_coeffs: &[Fr], + ram: Option<&Stage2RamData<'_>>, +) -> Result { + let batch = find_batch(program.batches, driver.symbol, driver.batch)?; + let claims = batch_claims(program.claims, batch)?; + let mut expected = Fr::from_u64(0); + for (claim, coefficient) in claims.iter().zip(batching_coeffs) { + let instance = program + .instance_results + .iter() + .find(|instance| instance.claim == claim.symbol && instance.source == driver.symbol) + .ok_or(VerifyStage2Error::MissingClaim { + batch: batch.symbol, + claim: claim.symbol, + })?; + let local_point = point + .get(instance.round_offset..instance.round_offset + instance.num_rounds) + .ok_or(VerifyStage2Error::InvalidInputLength { + input: instance.symbol, + expected: instance.round_offset + instance.num_rounds, + actual: point.len(), + })?; + let value = match instance.relation { + "jolt.stage2.ram.read_write" => expected_ram_read_write(store, evals, local_point)?, + "jolt.stage2.product_virtual.remainder" => { + expected_product_remainder(store, evals, local_point)? + } + "jolt.stage2.instruction_lookup.claim_reduction" => { + expected_instruction_lookup(store, evals, local_point)? + } + "jolt.stage2.ram.raf_evaluation" => expected_ram_raf(evals, local_point, ram)?, + "jolt.stage2.ram.output_check" => expected_ram_output(store, evals, local_point, ram)?, + relation => return Err(VerifyStage2Error::UnsupportedRelation { relation }), + }; + expected += *coefficient * value; + } + Ok(expected) +} + +fn expected_ram_read_write( + store: &Stage2ValueStore, + evals: &[Stage2NamedEval], + local_point: &[Fr], +) -> Result { + let r_cycle_stage1 = store.point("stage2.input.stage1.RamReadValue")?; + let log_t = r_cycle_stage1.len(); + let r_cycle = reverse_slice(&local_point[..log_t]); + let eq_eval = EqPolynomial::::mle(r_cycle_stage1, &r_cycle); + let gamma = store.scalar("stage2.ram_read_write.gamma")?; + let val = eval_by_name(evals, "stage2.ram_read_write.eval.RamVal")?; + let ra = eval_by_name(evals, "stage2.ram_read_write.eval.RamRa")?; + let inc = eval_by_name(evals, "stage2.ram_read_write.eval.RamInc")?; + Ok(eq_eval * ra * (val + gamma * (val + inc))) +} + +fn expected_product_remainder( + store: &Stage2ValueStore, + evals: &[Stage2NamedEval], + local_point: &[Fr], +) -> Result { + let tau_low = store.point("stage2.input.stage1.Product")?; + let tau_high = store.scalar("stage2.product_virtual.tau_high")?; + let r0 = *store + .point("stage2.product_virtual.uniskip.sumcheck")? + .first() + .ok_or(VerifyStage2Error::MissingValue { + symbol: "stage2.product_virtual.uniskip.sumcheck", + })?; + let r_tail = reverse_slice(local_point); + let low = EqPolynomial::::mle(tau_low, &r_tail); + let high = lagrange_kernel_eval( + PRODUCT_VIRTUAL_UNISKIP_DOMAIN_START, + PRODUCT_VIRTUAL_UNISKIP_DOMAIN_SIZE, + tau_high, + r0, + ); + let weights = lagrange_evals( + PRODUCT_VIRTUAL_UNISKIP_DOMAIN_START, + PRODUCT_VIRTUAL_UNISKIP_DOMAIN_SIZE, + r0, + ); + let left = weights[0] + * eval_by_name(evals, "stage2.product_virtual.remainder.eval.LeftInstructionInput")? + + weights[1] * eval_by_name(evals, "stage2.product_virtual.remainder.eval.LookupOutput")? + + weights[2] * eval_by_name(evals, "stage2.product_virtual.remainder.eval.OpFlagJump")?; + let right = weights[0] + * eval_by_name(evals, "stage2.product_virtual.remainder.eval.RightInstructionInput")? + + weights[1] + * eval_by_name(evals, "stage2.product_virtual.remainder.eval.InstructionFlagBranch")? + + weights[2] + * (Fr::from_u64(1) + - eval_by_name(evals, "stage2.product_virtual.remainder.eval.NextIsNoop")?); + Ok(high * low * left * right) +} + +fn expected_instruction_lookup( + store: &Stage2ValueStore, + evals: &[Stage2NamedEval], + local_point: &[Fr], +) -> Result { + let opening_point = reverse_slice(local_point); + let r_spartan = store.point("stage2.input.stage1.LookupOutput")?; + let eq_eval = EqPolynomial::::mle(&opening_point, r_spartan); + let gamma = store.scalar("stage2.instruction_lookup.gamma")?; + let gamma2 = gamma.square(); + let gamma3 = gamma2 * gamma; + let gamma4 = gamma2.square(); + let weighted = eval_by_name( + evals, + "stage2.instruction_lookup.claim_reduction.eval.LookupOutput", + )? + gamma + * eval_by_name( + evals, + "stage2.instruction_lookup.claim_reduction.eval.LeftLookupOperand", + )? + + gamma2 + * eval_by_name( + evals, + "stage2.instruction_lookup.claim_reduction.eval.RightLookupOperand", + )? + + gamma3 + * eval_by_name( + evals, + "stage2.instruction_lookup.claim_reduction.eval.LeftInstructionInput", + )? + + gamma4 + * eval_by_name( + evals, + "stage2.instruction_lookup.claim_reduction.eval.RightInstructionInput", + )?; + Ok(eq_eval * weighted) +} + +fn expected_ram_raf( + evals: &[Stage2NamedEval], + local_point: &[Fr], + ram: Option<&Stage2RamData<'_>>, +) -> Result { + let ram = ram.ok_or(VerifyStage2Error::MissingRam { + relation: "jolt.stage2.ram.raf_evaluation", + })?; + let address = reverse_slice(local_point); + let unmap = unmap_eval(ram.log_k, ram.start_address, &address); + Ok(unmap * eval_by_name(evals, "stage2.ram_raf.eval.RamRa")?) +} + +fn expected_ram_output( + store: &Stage2ValueStore, + evals: &[Stage2NamedEval], + local_point: &[Fr], + ram: Option<&Stage2RamData<'_>>, +) -> Result { + let ram = ram.ok_or(VerifyStage2Error::MissingRam { + relation: "jolt.stage2.ram.output_check", + })?; + let layout = ram.output_layout.ok_or(VerifyStage2Error::MissingRam { + relation: "jolt.stage2.ram.output_check.layout", + })?; + let r_address = store.point("stage2.ram_output.r_address")?; + let opening_point = reverse_slice(local_point); + let eq_eval = EqPolynomial::::mle(r_address, &opening_point); + let io_mask = range_mask_eval(layout.io_start, layout.io_end, &opening_point); + let val_io = sparse_final_ram_eval( + ram.final_ram, + layout.io_start, + layout.io_end, + &opening_point, + ); + let val_final = eval_by_name(evals, "stage2.ram_output.eval.RamValFinal")?; + Ok(eq_eval * io_mask * (val_final - val_io)) +} + +fn driver_evals( + program: &'static Stage2VerifierProgramPlan, + driver: &'static str, + value: Fr, +) -> Vec> { + program + .evals + .iter() + .filter(|eval| eval.source == driver) + .map(|eval| Stage2NamedEval { + name: eval.name, + oracle: eval.oracle, + value, + }) + .collect() +} + +fn verify_named_evals( + driver: &'static str, + expected: &[Stage2NamedEval], + actual: &[Stage2NamedEval], +) -> Result<(), VerifyStage2Error> { + if expected.len() != actual.len() { + return Err(VerifyStage2Error::InvalidProof { + driver, + reason: "eval count mismatch", + }); + } + for (expected, actual) in expected.iter().zip(actual) { + if expected.name != actual.name || expected.oracle != actual.oracle || expected.value != actual.value { + return Err(VerifyStage2Error::InvalidProof { + driver, + reason: "eval mismatch", + }); + } + } + Ok(()) +} + +fn validate_driver_symbol( + driver: &'static Stage2SumcheckDriverPlan, + proof: &Stage2SumcheckOutput, +) -> Result<(), VerifyStage2Error> { + if proof.driver == driver.symbol { + Ok(()) + } else { + Err(VerifyStage2Error::InvalidProof { + driver: driver.symbol, + reason: "driver symbol mismatch", + }) + } +} + +fn append_univariate_poly(transcript: &mut T, label: &'static str, poly: &UnivariatePoly) +where + T: Transcript, +{ + transcript.append(&LabelWithCount( + label.as_bytes(), + poly.coefficients().len() as u64, + )); + for coefficient in poly.coefficients() { + transcript.append(coefficient); + } +} + +fn product_uniskip_sum_matches(poly: &UnivariatePoly, claim: Fr) -> bool { + (0..PRODUCT_VIRTUAL_UNISKIP_DOMAIN_SIZE) + .map(|index| { + poly.evaluate(Fr::from_i64( + PRODUCT_VIRTUAL_UNISKIP_DOMAIN_START + index as i64, + )) + }) + .sum::() + == claim +} + +fn polynomial_degree(poly: &UnivariatePoly) -> usize { + poly.coefficients() + .iter() + .rposition(|coefficient| *coefficient != Fr::from_u64(0)) + .unwrap_or(0) +} + +fn unmap_eval(log_k: usize, start_address: u64, point: &[Fr]) -> Fr { + point + .iter() + .enumerate() + .fold(Fr::from_u64(start_address), |acc, (index, value)| { + acc + value.mul_pow_2(log_k - 1 - index).mul_u64(8) + }) +} + +fn range_mask_eval(start: usize, end: usize, point: &[Fr]) -> Fr { + eq_prefix_sum(end, point) - eq_prefix_sum(start, point) +} + +fn sparse_final_ram_eval(values: &[u64], start: usize, end: usize, point: &[Fr]) -> Fr { + values[start..end] + .iter() + .enumerate() + .filter(|(_, value)| **value != 0) + .map(|(offset, value)| Fr::from_u64(*value) * eq_eval_at_index(start + offset, point)) + .sum() +} + +fn eq_prefix_sum(end: usize, point: &[Fr]) -> Fr { + let domain_len = 1usize << point.len(); + if end >= domain_len { + return Fr::from_u64(1); + } + let mut sum = Fr::from_u64(0); + let mut prefix = Fr::from_u64(1); + for (bit, r) in point.iter().enumerate() { + let mask = 1usize << (point.len() - 1 - bit); + if end & mask == 0 { + prefix *= Fr::from_u64(1) - *r; + } else { + sum += prefix * (Fr::from_u64(1) - *r); + prefix *= *r; + } + } + sum +} + +fn eq_eval_at_index(index: usize, point: &[Fr]) -> Fr { + point.iter().enumerate().fold(Fr::from_u64(1), |acc, (bit, r)| { + let mask = 1usize << (point.len() - 1 - bit); + if index & mask == 0 { + acc * (Fr::from_u64(1) - *r) + } else { + acc * *r + } + }) +} +"# + } +} + +fn require_supported_symbol(kind: &str, actual: &str, expected: &str) -> Result<(), EmitError> { + if actual == expected { + Ok(()) + } else { + Err(EmitError::new(format!( + "unsupported {kind} @{actual}; expected @{expected}" + ))) + } +} + +fn emit_str_array(name: &str, values: &[String]) -> String { + if values.is_empty() { + return format!("pub const {name}: &[&str] = &[];\n\n"); + } + if let [value] = values { + return format!("pub const {name}: &[&str] = &[{}];\n\n", rust_str(value)); + } + let entries = values + .iter() + .map(|value| format!(" {},", rust_str(value))) + .collect::>() + .join("\n"); + format!("pub const {name}: &[&str] = &[\n{entries}\n];\n\n") +} + +fn emit_usize_array(name: &str, values: &[usize]) -> String { + let entries = values + .iter() + .map(|value| format!(" {value},")) + .collect::>() + .join("\n"); + format!("pub const {name}: &[usize] = &[\n{entries}\n];\n\n") +} + +fn intern_str_array( + source: &mut String, + arrays: &mut Vec<(Vec, String)>, + name_prefix: &str, + values: &[String], +) -> String { + if let Some((_, name)) = arrays + .iter() + .find(|(existing, _)| existing.as_slice() == values) + { + return name.clone(); + } + let name = format!("{name_prefix}_{}", arrays.len()); + source.push_str(&emit_str_array(&name, values)); + arrays.push((values.to_vec(), name.clone())); + name +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn verify_count(kind: &str, symbol: &str, expected: usize, actual: usize) -> Result<(), EmitError> { + if expected == actual { + Ok(()) + } else { + Err(EmitError::new(format!( + "{kind} @{symbol} count mismatch: expected {expected}, got {actual}" + ))) + } +} + +fn missing_role_binding(kind: &str, symbol: &str) -> EmitError { + EmitError::new(format!("missing {kind} for `{symbol}`")) +} + +fn symbols<'a>(values: impl Iterator) -> BTreeSet { + values.cloned().collect() +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn signed_int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_signed_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "signed integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn parse_signed_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn int_array_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "integer array"))?; + parse_int_array(&attribute).ok_or_else(|| attr_error(operation, attr, "integer array")) +} + +fn parse_int_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().parse().ok()) + .collect() +} + +fn operand_symbols( + operation: OperationRef<'_, '_>, + start_index: usize, +) -> Result, EmitError> { + (start_index..operation.operand_count()) + .map(|index| operand_symbol(operation, index)) + .collect() +} + +fn operand_symbol(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + EmitError::new(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + EmitError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name<'c: 'a, 'a>(operation: impl OperationLike<'c, 'a>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} diff --git a/crates/bolt/src/protocols/jolt/emit/rust/stage3.rs b/crates/bolt/src/protocols/jolt/emit/rust/stage3.rs new file mode 100644 index 0000000000..936557f08a --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/stage3.rs @@ -0,0 +1,2245 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{Attribute, OperationRef}; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3CpuProgram { + pub role: Role, + pub params: Stage3Params, + pub steps: Vec, + pub transcript_squeezes: Vec, + pub opening_inputs: Vec, + pub field_constants: Vec, + pub field_exprs: Vec, + pub kernels: Vec, + pub claims: Vec, + pub batches: Vec, + pub drivers: Vec, + pub instance_results: Vec, + pub evals: Vec, + pub point_slices: Vec, + pub point_concats: Vec, + pub opening_claims: Vec, + pub opening_equalities: Vec, + pub opening_batches: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3Params { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3KernelPlan { + pub symbol: String, + pub relation: String, + pub kind: String, + pub backend: String, + pub abi: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3TranscriptSqueezePlan { + pub symbol: String, + pub label: String, + pub kind: String, + pub count: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3ProgramStepPlan { + pub kind: String, + pub symbol: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3OpeningInputPlan { + pub symbol: String, + pub source_stage: String, + pub source_claim: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3FieldConstantPlan { + pub symbol: String, + pub field: String, + pub value: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3FieldExprPlan { + pub symbol: String, + pub kind: String, + pub formula: String, + pub operand_names: Vec, + pub operands: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3SumcheckClaimPlan { + pub symbol: String, + pub stage: String, + pub domain: String, + pub num_rounds: usize, + pub degree: usize, + pub claim: String, + pub kernel: Option, + pub relation: Option, + pub claim_value: String, + pub input_openings: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3SumcheckBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, + pub claim_label: String, + pub round_label: String, + pub round_schedule: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3SumcheckDriverPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub kernel: Option, + pub relation: Option, + pub batch: String, + pub policy: String, + pub round_schedule: Vec, + pub claim_label: String, + pub round_label: String, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3SumcheckInstanceResultPlan { + pub symbol: String, + pub source: String, + pub claim: String, + pub relation: String, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: String, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3SumcheckEvalPlan { + pub symbol: String, + pub source: String, + pub name: String, + pub index: usize, + pub oracle: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3PointSlicePlan { + pub symbol: String, + pub source: String, + pub offset: usize, + pub length: usize, + pub input: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3PointConcatPlan { + pub symbol: String, + pub layout: String, + pub arity: usize, + pub inputs: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3OpeningClaimPlan { + pub symbol: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, + pub point_source: String, + pub eval_source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3OpeningClaimEqualityPlan { + pub symbol: String, + pub mode: String, + pub lhs: String, + pub rhs: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage3OpeningBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, +} + +pub fn stage3_cpu_program(module: &BoltModule<'_, Cpu>) -> Result { + verify_cpu_schema(module)?; + let program = Stage3CpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +pub fn emit_stage3_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = stage3_cpu_program(module)?; + + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source()?, + }) +} + +impl Stage3CpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let mut params = None; + let mut steps = Vec::new(); + let mut transcript_squeezes = Vec::new(); + let mut opening_inputs = Vec::new(); + let mut field_constants = Vec::new(); + let mut field_exprs = Vec::new(); + let mut kernels = Vec::new(); + let mut claims = Vec::new(); + let mut batches = Vec::new(); + let mut drivers = Vec::new(); + let mut instance_results = Vec::new(); + let mut evals = Vec::new(); + let mut point_slices = Vec::new(); + let mut point_concats = Vec::new(); + let mut opening_claims = Vec::new(); + let mut opening_equalities = Vec::new(); + let mut opening_batches = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(Stage3Params { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.kernel" => { + kernels.push(Stage3KernelPlan { + symbol: string_attr(op, "sym_name")?, + relation: symbol_attr(op, "relation")?, + kind: string_attr(op, "kind")?, + backend: string_attr(op, "backend")?, + abi: string_attr(op, "abi")?, + }); + } + "cpu.transcript_squeeze" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage3ProgramStepPlan { + kind: "transcript_squeeze".to_owned(), + symbol: symbol.clone(), + }); + transcript_squeezes.push(Stage3TranscriptSqueezePlan { + symbol, + label: string_attr(op, "label")?, + kind: string_attr(op, "kind")?, + count: int_attr(op, "count")?, + }); + } + "cpu.opening_input" => { + opening_inputs.push(Stage3OpeningInputPlan { + symbol: string_attr(op, "sym_name")?, + source_stage: symbol_attr(op, "source_stage")?, + source_claim: symbol_attr(op, "source_claim")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + }); + } + "cpu.field_const" => { + field_constants.push(Stage3FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: int_attr(op, "value")?, + }); + } + "cpu.field_zero" => { + field_constants.push(Stage3FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 0, + }); + } + "cpu.field_one" => { + field_constants.push(Stage3FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 1, + }); + } + "cpu.field_add" | "cpu.field_sub" | "cpu.field_mul" | "cpu.field_neg" => { + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage3FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: operation_name(op).replace("cpu.field_", "field."), + operand_names: operands.clone(), + operands, + }); + } + "cpu.field_pow" => { + let exponent = int_attr(op, "exponent")?; + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage3FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: format!("field.pow:{exponent}"), + operand_names: operands.clone(), + operands, + }); + } + "cpu.poly_lagrange_basis_eval" => { + let domain_start = signed_int_attr(op, "domain_start")?; + let domain_size = int_attr(op, "domain_size")?; + let index = int_attr(op, "index")?; + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage3FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: format!( + "poly.lagrange_basis_eval:{domain_start}:{domain_size}:{index}" + ), + operand_names: operands.clone(), + operands, + }); + } + "cpu.sumcheck_claim" => { + claims.push(Stage3SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_verify_claim" => { + claims.push(Stage3SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_batch" => { + batches.push(Stage3SumcheckBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + round_schedule: int_array_attr(op, "round_schedule")?, + }); + } + "cpu.sumcheck_driver" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage3ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage3SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_verify" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage3ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage3SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_instance_result" => { + instance_results.push(Stage3SumcheckInstanceResultPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + claim: symbol_attr(op, "claim")?, + relation: symbol_attr(op, "relation")?, + index: int_attr(op, "index")?, + point_arity: int_attr(op, "point_arity")?, + num_rounds: int_attr(op, "num_rounds")?, + round_offset: int_attr(op, "round_offset")?, + point_order: string_attr(op, "point_order")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_eval" => { + evals.push(Stage3SumcheckEvalPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + name: symbol_attr(op, "name")?, + index: int_attr(op, "index")?, + oracle: symbol_attr(op, "oracle")?, + }); + } + "cpu.point_slice" => { + point_slices.push(Stage3PointSlicePlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + offset: int_attr(op, "offset")?, + length: int_attr(op, "length")?, + input: operand_symbol(op, 0)?, + }); + } + "cpu.point_concat" => { + point_concats.push(Stage3PointConcatPlan { + symbol: string_attr(op, "sym_name")?, + layout: string_attr(op, "layout")?, + arity: int_attr(op, "arity")?, + inputs: operand_symbols(op, 0)?, + }); + } + "cpu.opening_claim" => { + opening_claims.push(Stage3OpeningClaimPlan { + symbol: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + point_source: operand_symbol(op, 0)?, + eval_source: operand_symbol(op, 1)?, + }); + } + "cpu.opening_claim_equal" => { + opening_equalities.push(Stage3OpeningClaimEqualityPlan { + symbol: string_attr(op, "sym_name")?, + mode: string_attr(op, "mode")?, + lhs: operand_symbol(op, 0)?, + rhs: operand_symbol(op, 1)?, + }); + } + "cpu.opening_batch" => { + opening_batches.push(Stage3OpeningBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + }); + } + _ => {} + } + } + + Ok(Self { + params: params.ok_or_else(|| EmitError::new("missing cpu.params"))?, + role: module + .role() + .ok_or_else(|| EmitError::new("missing cpu party role"))?, + steps, + transcript_squeezes, + opening_inputs, + field_constants, + field_exprs, + kernels, + claims, + batches, + drivers, + instance_results, + evals, + point_slices, + point_concats, + opening_claims, + opening_equalities, + opening_batches, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + require_supported_symbol("field", &self.params.field, "bn254_fr")?; + require_supported_symbol("pcs", &self.params.pcs, "dory")?; + require_supported_symbol("transcript", &self.params.transcript, "blake2b_transcript")?; + self.verify_transcript_squeezes()?; + self.verify_field_flow()?; + self.verify_claim_batches()?; + match self.role { + Role::Prover => { + self.verify_kernel_definitions()?; + self.verify_prover_driver_bindings()?; + } + Role::Verifier => self.verify_verifier_driver_bindings()?, + } + self.verify_opening_flow() + } + + fn verify_transcript_squeezes(&self) -> Result<(), EmitError> { + for squeeze in &self.transcript_squeezes { + if !matches!( + squeeze.kind.as_str(), + "challenge_scalar" | "challenge_vector" + ) { + return Err(EmitError::new(format!( + "stage3 transcript squeeze @{} has unsupported kind `{}`", + squeeze.symbol, squeeze.kind + ))); + } + if squeeze.count == 0 { + return Err(EmitError::new(format!( + "stage3 transcript squeeze @{} has zero count", + squeeze.symbol + ))); + } + } + Ok(()) + } + + fn verify_field_flow(&self) -> Result<(), EmitError> { + for constant in &self.field_constants { + require_supported_symbol("field constant field", &constant.field, "bn254_fr")?; + } + let field_values = self.field_value_symbols(); + for expr in &self.field_exprs { + verify_count( + "field expr operands", + &expr.symbol, + expr.operand_names.len(), + expr.operands.len(), + )?; + for operand in &expr.operands { + if !field_values.contains(operand) { + return Err(EmitError::new(format!( + "field expr @{} references missing field value @{operand}", + expr.symbol + ))); + } + } + } + for claim in &self.claims { + if !field_values.contains(&claim.claim_value) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing claim value @{}", + claim.symbol, claim.claim_value + ))); + } + } + Ok(()) + } + + fn field_value_symbols(&self) -> BTreeSet { + let mut values = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + values.extend(symbols( + self.field_constants.iter().map(|constant| &constant.symbol), + )); + values.extend(symbols( + self.transcript_squeezes + .iter() + .filter(|squeeze| matches!(squeeze.kind.as_str(), "challenge_scalar" | "scalar")) + .map(|squeeze| &squeeze.symbol), + )); + values.extend(symbols(self.field_exprs.iter().map(|expr| &expr.symbol))); + values.extend(symbols(self.evals.iter().map(|eval| &eval.symbol))); + values + } + + fn verify_kernel_definitions(&self) -> Result<(), EmitError> { + for kernel in &self.kernels { + if kernel.backend != "cpu" { + return Err(EmitError::new(format!( + "stage3 kernel @{} targets unsupported backend `{}`", + kernel.symbol, kernel.backend + ))); + } + if kernel.kind != "sumcheck" { + return Err(EmitError::new(format!( + "stage3 kernel @{} has unsupported kind `{}`", + kernel.symbol, kernel.kind + ))); + } + let expected_abi = match kernel.relation.as_str() { + "jolt.stage3.spartan_shift" => "jolt_stage3_spartan_shift", + "jolt.stage3.instruction_input" => "jolt_stage3_instruction_input", + "jolt.stage3.registers_claim_reduction" => "jolt_stage3_registers_claim_reduction", + "jolt.stage3.batched" => "jolt_stage3_batched", + _ => { + return Err(EmitError::new(format!( + "unsupported stage3 kernel relation @{}", + kernel.relation + ))); + } + }; + if kernel.abi != expected_abi { + return Err(EmitError::new(format!( + "stage3 kernel @{} ABI `{}` does not match relation @{}", + kernel.symbol, kernel.abi, kernel.relation + ))); + } + } + Ok(()) + } + + fn verify_claim_batches(&self) -> Result<(), EmitError> { + let claims = symbols(self.claims.iter().map(|claim| &claim.symbol)); + for batch in &self.batches { + verify_count( + "sumcheck batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "sumcheck batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "sumcheck batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !claims.contains(claim) { + return Err(EmitError::new(format!( + "sumcheck batch @{} references missing claim @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn verify_prover_driver_bindings(&self) -> Result<(), EmitError> { + let kernels = symbols(self.kernels.iter().map(|kernel| &kernel.symbol)); + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + let Some(kernel) = claim.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck claim @{} is missing kernel", + claim.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck claim @{} references missing kernel @{kernel}", + claim.symbol + ))); + } + } + for driver in &self.drivers { + let Some(kernel) = driver.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck driver @{} is missing kernel", + driver.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck driver @{} references missing kernel @{kernel}", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_verifier_driver_bindings(&self) -> Result<(), EmitError> { + if !self.kernels.is_empty() { + return Err(EmitError::new( + "verifier stage3 program must not contain kernels", + )); + } + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + if claim.kernel.is_some() || claim.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck claim @{} must carry relation and no kernel", + claim.symbol + ))); + } + } + for driver in &self.drivers { + if driver.kernel.is_some() || driver.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck driver @{} must carry relation and no kernel", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_opening_flow(&self) -> Result<(), EmitError> { + let mut point_sources = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + point_sources.extend(symbols( + self.instance_results + .iter() + .map(|instance| &instance.symbol), + )); + point_sources.extend(symbols( + self.opening_inputs.iter().map(|input| &input.symbol), + )); + point_sources.extend(symbols(self.point_slices.iter().map(|slice| &slice.symbol))); + point_sources.extend(symbols( + self.point_concats.iter().map(|concat| &concat.symbol), + )); + for slice in &self.point_slices { + if !point_sources.contains(&slice.input) { + return Err(EmitError::new(format!( + "point slice @{} uses missing point source @{}", + slice.symbol, slice.input + ))); + } + } + for concat in &self.point_concats { + for input in &concat.inputs { + if !point_sources.contains(input) { + return Err(EmitError::new(format!( + "point concat @{} uses missing point source @{input}", + concat.symbol + ))); + } + } + } + let eval_sources = self.field_value_symbols(); + let mut opening_sources = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + opening_sources.extend(symbols( + self.opening_claims.iter().map(|claim| &claim.symbol), + )); + for equality in &self.opening_equalities { + if !opening_sources.contains(&equality.lhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing lhs opening @{}", + equality.symbol, equality.lhs + ))); + } + if !opening_sources.contains(&equality.rhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing rhs opening @{}", + equality.symbol, equality.rhs + ))); + } + } + for claim in &self.claims { + for input in &claim.input_openings { + if !opening_sources.contains(input) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing opening @{input}", + claim.symbol + ))); + } + } + } + let drivers = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + for instance in &self.instance_results { + if !drivers.contains(&instance.source) { + return Err(EmitError::new(format!( + "sumcheck instance result @{} references missing driver @{}", + instance.symbol, instance.source + ))); + } + } + for eval in &self.evals { + if !drivers.contains(&eval.source) { + return Err(EmitError::new(format!( + "sumcheck eval @{} references missing driver @{}", + eval.symbol, eval.source + ))); + } + } + for claim in &self.opening_claims { + if !point_sources.contains(&claim.point_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing point source @{}", + claim.symbol, claim.point_source + ))); + } + if !eval_sources.contains(&claim.eval_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing eval source @{}", + claim.symbol, claim.eval_source + ))); + } + } + let openings = symbols(self.opening_claims.iter().map(|claim| &claim.symbol)); + for batch in &self.opening_batches { + verify_count( + "opening batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "opening batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "opening batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !openings.contains(claim) { + return Err(EmitError::new(format!( + "opening batch @{} references missing opening @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn emit_source(&self) -> Result { + match self.role { + Role::Prover => self.emit_prover_source(), + Role::Verifier => self.emit_verifier_source(), + } + } + + fn emit_prover_source(&self) -> Result { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + source.push_str(Self::emit_prover_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_prover_types()); + source.push('\n'); + source.push_str(&self.emit_prover_constants()?); + source.push('\n'); + source.push_str(Self::emit_prover_entrypoint()); + Ok(source) + } + + fn emit_verifier_source(&self) -> Result { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + source.push_str(Self::emit_verifier_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_verifier_types()); + source.push('\n'); + source.push_str(&self.emit_verifier_constants()?); + source.push('\n'); + source.push_str(Self::emit_verifier_entrypoint()); + Ok(source) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_stage3.rs", + Role::Verifier => "verify_stage3.rs", + } + } + + fn emit_prover_imports() -> &'static str { + "use jolt_field::Fr;\n\ + use jolt_kernels::stage3::{execute_stage3_program, Stage3CpuProgramPlan, Stage3ExecutionArtifacts, Stage3ExecutionMode, Stage3FieldConstantPlan, Stage3FieldExprPlan, Stage3KernelError, Stage3KernelExecutor, Stage3KernelPlan, Stage3OpeningBatchPlan, Stage3OpeningClaimEqualityPlan, Stage3OpeningClaimPlan, Stage3OpeningInputPlan, Stage3Params, Stage3PointConcatPlan, Stage3PointSlicePlan, Stage3ProgramStepPlan, Stage3SumcheckBatchPlan, Stage3SumcheckClaimPlan, Stage3SumcheckDriverPlan, Stage3SumcheckEvalPlan, Stage3SumcheckInstanceResultPlan, Stage3TranscriptSqueezePlan};\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_prover_types() -> &'static str { + "pub type DefaultStage3Transcript = Blake2bTranscript;\n" + } + + fn emit_verifier_imports() -> &'static str { + "use super::common::{batch_claims, eval_by_name, find_batch, find_plan, reverse_slice};\n\ + use jolt_field::{Field, Fr};\n\ + use jolt_poly::{EqPlusOnePolynomial, EqPolynomial};\n\ + use jolt_sumcheck::SumcheckError;\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_verifier_types() -> &'static str { + r"pub type DefaultStage3Transcript = Blake2bTranscript; + +pub type Stage3NamedEval = super::common::StageNamedEval; +pub type Stage3SumcheckOutput = super::common::StageSumcheckOutput; +pub type Stage3ChallengeVector = super::common::StageChallengeVector; +pub type Stage3ExecutionArtifacts = super::common::StageExecutionArtifacts; +pub type Stage3Proof = super::common::StageProof; +pub type Stage3OpeningInputValue = super::common::StageOpeningInputValue; +pub type Stage3VerifierProgramPlan = super::common::StageVerifierProgramPlan; + +pub use super::common::{ + FieldConstantPlan as Stage3FieldConstantPlan, FieldExprPlan as Stage3FieldExprPlan, + OpeningBatchPlan as Stage3OpeningBatchPlan, + OpeningClaimEqualityPlan as Stage3OpeningClaimEqualityPlan, + OpeningClaimPlan as Stage3OpeningClaimPlan, OpeningInputPlan as Stage3OpeningInputPlan, + PointConcatPlan as Stage3PointConcatPlan, PointSlicePlan as Stage3PointSlicePlan, + ProgramStepPlan as Stage3ProgramStepPlan, StageParams as Stage3Params, + SumcheckBatchPlan as Stage3SumcheckBatchPlan, SumcheckEvalPlan as Stage3SumcheckEvalPlan, + SumcheckInstanceResultPlan as Stage3SumcheckInstanceResultPlan, + TranscriptSqueezePlan as Stage3TranscriptSqueezePlan, + SumcheckClaimPlan as Stage3SumcheckClaimPlan, + SumcheckDriverPlan as Stage3SumcheckDriverPlan, +}; + +#[derive(Debug)] +pub enum VerifyStage3Error { + UnexpectedProofCount { expected: usize, got: usize }, + MissingProof { driver: &'static str }, + MissingBatch { driver: &'static str, batch: &'static str }, + MissingClaim { batch: &'static str, claim: &'static str }, + MissingValue { symbol: &'static str }, + InvalidInputLength { input: &'static str, expected: usize, actual: usize }, + InvalidProof { driver: &'static str, reason: &'static str }, + UnsupportedFieldExpr { symbol: &'static str, formula: &'static str }, + UnsupportedRelation { relation: &'static str }, + Sumcheck { driver: &'static str, error: SumcheckError }, +} + +super::common::impl_runtime_plan_error_conversion!(VerifyStage3Error); +" + } + + fn emit_prover_constants(&self) -> Result { + let mut source = self.emit_shared_constants(); + source.push_str(&self.emit_kernel_constants()); + source.push_str(&self.emit_prover_sumcheck_claim_constants()?); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_prover_sumcheck_driver_constants()?); + source.push_str(&self.emit_tail_constants()); + source.push_str( + "pub const STAGE3_PROGRAM: Stage3CpuProgramPlan = Stage3CpuProgramPlan {\n\ + \x20 params: STAGE3_PARAMS,\n\ + \x20 steps: STAGE3_PROGRAM_STEPS,\n\ + \x20 transcript_squeezes: STAGE3_TRANSCRIPT_SQUEEZES,\n\ + \x20 opening_inputs: STAGE3_OPENING_INPUTS,\n\ + \x20 field_constants: STAGE3_FIELD_CONSTANTS,\n\ + \x20 field_exprs: STAGE3_FIELD_EXPRS,\n\ + \x20 kernels: STAGE3_KERNELS,\n\ + \x20 claims: STAGE3_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE3_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE3_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE3_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE3_SUMCHECK_EVALS,\n\ + \x20 point_slices: STAGE3_POINT_SLICES,\n\ + \x20 point_concats: STAGE3_POINT_CONCATS,\n\ + \x20 opening_claims: STAGE3_OPENING_CLAIMS,\n\ + \x20 opening_equalities: STAGE3_OPENING_EQUALITIES,\n\ + \x20 opening_batches: STAGE3_OPENING_BATCHES,\n\ + };\n", + ); + Ok(source) + } + + fn emit_verifier_constants(&self) -> Result { + let mut source = self.emit_shared_constants(); + source.push_str(&self.emit_verifier_sumcheck_claim_constants()?); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_verifier_sumcheck_driver_constants()?); + source.push_str(&self.emit_tail_constants()); + source.push_str( + "pub const STAGE3_PROGRAM: Stage3VerifierProgramPlan = Stage3VerifierProgramPlan {\n\ + \x20 params: STAGE3_PARAMS,\n\ + \x20 steps: STAGE3_PROGRAM_STEPS,\n\ + \x20 transcript_squeezes: STAGE3_TRANSCRIPT_SQUEEZES,\n\ + \x20 opening_inputs: STAGE3_OPENING_INPUTS,\n\ + \x20 field_constants: STAGE3_FIELD_CONSTANTS,\n\ + \x20 field_exprs: STAGE3_FIELD_EXPRS,\n\ + \x20 claims: STAGE3_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE3_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE3_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE3_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE3_SUMCHECK_EVALS,\n\ + \x20 point_slices: STAGE3_POINT_SLICES,\n\ + \x20 point_concats: STAGE3_POINT_CONCATS,\n\ + \x20 opening_claims: STAGE3_OPENING_CLAIMS,\n\ + \x20 opening_equalities: STAGE3_OPENING_EQUALITIES,\n\ + \x20 opening_batches: STAGE3_OPENING_BATCHES,\n\ + };\n", + ); + Ok(source) + } + + fn emit_shared_constants(&self) -> String { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const STAGE3_PARAMS: Stage3Params = Stage3Params {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + source.push_str(&self.emit_program_step_constants()); + source.push_str(&self.emit_transcript_squeeze_constants()); + source.push_str(&self.emit_opening_input_constants()); + source.push_str(&self.emit_field_constant_constants()); + source.push_str(&self.emit_field_expr_constants()); + source + } + + fn emit_program_step_constants(&self) -> String { + let steps = self + .steps + .iter() + .map(|step| { + format!( + " Stage3ProgramStepPlan {{ kind: {}, symbol: {} }},", + rust_str(&step.kind), + rust_str(&step.symbol), + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE3_PROGRAM_STEPS: &[Stage3ProgramStepPlan] = &[\n{steps}\n];\n\n") + } + + fn emit_transcript_squeeze_constants(&self) -> String { + let squeezes = self + .transcript_squeezes + .iter() + .map(|squeeze| { + format!( + " Stage3TranscriptSqueezePlan {{ symbol: {}, label: {}, kind: {}, count: {} }},", + rust_str(&squeeze.symbol), + rust_str(&squeeze.label), + rust_str(&squeeze.kind), + squeeze.count, + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE3_TRANSCRIPT_SQUEEZES: &[Stage3TranscriptSqueezePlan] = &[\n{squeezes}\n];\n\n" + ) + } + + fn emit_opening_input_constants(&self) -> String { + let inputs = self + .opening_inputs + .iter() + .map(|input| { + format!( + " Stage3OpeningInputPlan {{ symbol: {}, source_stage: {}, source_claim: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {} }},", + rust_str(&input.symbol), + rust_str(&input.source_stage), + rust_str(&input.source_claim), + rust_str(&input.oracle), + rust_str(&input.domain), + input.point_arity, + rust_str(&input.claim_kind) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE3_OPENING_INPUTS: &[Stage3OpeningInputPlan] = &[\n{inputs}\n];\n\n") + } + + fn emit_field_constant_constants(&self) -> String { + let constants = self + .field_constants + .iter() + .map(|constant| { + format!( + " Stage3FieldConstantPlan {{ symbol: {}, field: {}, value: {} }},", + rust_str(&constant.symbol), + rust_str(&constant.field), + constant.value + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE3_FIELD_CONSTANTS: &[Stage3FieldConstantPlan] = &[\n{constants}\n];\n\n" + ) + } + + fn emit_field_expr_constants(&self) -> String { + if self.role == Role::Verifier { + let exprs = self + .field_exprs + .iter() + .map(|expr| { + format!( + " Stage3FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operands: {} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula), + rust_str(&expr.operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE3_FIELD_EXPRS: &[Stage3FieldExprPlan] = &[\n{exprs}\n];\n" + ); + } + + let mut source = String::new(); + let mut arrays = Vec::new(); + let mut array_refs = Vec::new(); + for (index, expr) in self.field_exprs.iter().enumerate() { + let operands = intern_str_array( + &mut source, + &mut arrays, + "STAGE3_FIELD_EXPR_OPERANDS", + &expr.operands, + ); + let operand_names = intern_str_array( + &mut source, + &mut arrays, + "STAGE3_FIELD_EXPR_OPERANDS", + &expr.operand_names, + ); + array_refs.push((index, operand_names, operands)); + } + let exprs = self + .field_exprs + .iter() + .enumerate() + .map(|(index, expr)| { + let (_, operand_names, operands) = &array_refs[index]; + format!( + " Stage3FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operand_names: {operand_names}, operands: {operands} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE3_FIELD_EXPRS: &[Stage3FieldExprPlan] = &[\n{exprs}\n];\n" + ), + ); + source + } + + fn emit_kernel_constants(&self) -> String { + let kernels = self + .kernels + .iter() + .map(|kernel| { + format!( + " Stage3KernelPlan {{ symbol: {}, relation: {}, kind: {}, backend: {}, abi: {} }},", + rust_str(&kernel.symbol), + rust_str(&kernel.relation), + rust_str(&kernel.kind), + rust_str(&kernel.backend), + rust_str(&kernel.abi) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE3_KERNELS: &[Stage3KernelPlan] = &[\n{kernels}\n];\n\n") + } + + fn emit_prover_sumcheck_claim_constants(&self) -> Result { + self.emit_sumcheck_claim_constants(true) + } + + fn emit_verifier_sumcheck_claim_constants(&self) -> Result { + self.emit_sumcheck_claim_constants(false) + } + + fn emit_sumcheck_claim_constants(&self, prover: bool) -> Result { + let mut source = String::new(); + if prover { + for (index, claim) in self.claims.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE3_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS"), + &claim.input_openings, + )); + } + } + let mut claims = Vec::new(); + for (index, claim) in self.claims.iter().enumerate() { + if prover { + let kernel = claim + .kernel + .as_deref() + .ok_or_else(|| missing_role_binding("prover claim kernel", &claim.symbol))?; + claims.push(format!( + " Stage3SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: Some({}), relation: None, claim_value: {}, input_openings: STAGE3_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_str(kernel), + rust_str(&claim.claim_value) + )); + } else { + let relation = claim.relation.as_deref().ok_or_else(|| { + missing_role_binding("verifier claim relation", &claim.symbol) + })?; + claims.push(format!( + " Stage3SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: None, relation: Some({}), claim_value: {}, input_openings: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_str(relation), + rust_str(&claim.claim_value), + rust_str(&claim.input_openings.join("|")) + )); + } + } + let claims = claims.join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE3_SUMCHECK_CLAIMS: &[Stage3SumcheckClaimPlan] = &[\n{claims}\n];\n" + ), + ); + Ok(source) + } + + fn emit_sumcheck_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE3_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage3SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {}, claim_label: {}, round_label: {}, round_schedule: STAGE3_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")), + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE3_SUMCHECK_BATCHES: &[Stage3SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + return source; + } + + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE3_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE3_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + source.push_str(&emit_usize_array( + &format!("STAGE3_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage3SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE3_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE3_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS, claim_label: {}, round_label: {}, round_schedule: STAGE3_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE3_SUMCHECK_BATCHES: &[Stage3SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_prover_sumcheck_driver_constants(&self) -> Result { + self.emit_sumcheck_driver_constants(true) + } + + fn emit_verifier_sumcheck_driver_constants(&self) -> Result { + self.emit_sumcheck_driver_constants(false) + } + + fn emit_sumcheck_driver_constants(&self, prover: bool) -> Result { + let mut source = String::new(); + for (index, driver) in self.drivers.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE3_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE"), + &driver.round_schedule, + )); + } + let mut drivers = Vec::new(); + for (index, driver) in self.drivers.iter().enumerate() { + if prover { + let kernel = driver + .kernel + .as_deref() + .ok_or_else(|| missing_role_binding("prover driver kernel", &driver.symbol))?; + drivers.push(format!( + " Stage3SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: Some({}), relation: None, batch: {}, policy: {}, round_schedule: STAGE3_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_str(kernel), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + )); + } else { + let relation = driver.relation.as_deref().ok_or_else(|| { + missing_role_binding("verifier driver relation", &driver.symbol) + })?; + drivers.push(format!( + " Stage3SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: None, relation: Some({}), batch: {}, policy: {}, round_schedule: STAGE3_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_str(relation), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + )); + } + } + let drivers = drivers.join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE3_SUMCHECK_DRIVERS: &[Stage3SumcheckDriverPlan] = &[\n{drivers}\n];\n" + ), + ); + Ok(source) + } + + fn emit_tail_constants(&self) -> String { + let mut source = String::new(); + source.push_str(&self.emit_sumcheck_instance_result_constants()); + source.push_str(&self.emit_sumcheck_eval_constants()); + source.push_str(&self.emit_point_slice_constants()); + source.push_str(&self.emit_point_concat_constants()); + source.push_str(&self.emit_opening_claim_constants()); + source.push_str(&self.emit_opening_claim_equality_constants()); + source.push_str(&self.emit_opening_batch_constants()); + source + } + + fn emit_sumcheck_instance_result_constants(&self) -> String { + let instances = self + .instance_results + .iter() + .map(|instance| { + format!( + " Stage3SumcheckInstanceResultPlan {{ symbol: {}, source: {}, claim: {}, relation: {}, index: {}, point_arity: {}, num_rounds: {}, round_offset: {}, point_order: {}, degree: {} }},", + rust_str(&instance.symbol), + rust_str(&instance.source), + rust_str(&instance.claim), + rust_str(&instance.relation), + instance.index, + instance.point_arity, + instance.num_rounds, + instance.round_offset, + rust_str(&instance.point_order), + instance.degree + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE3_SUMCHECK_INSTANCE_RESULTS: &[Stage3SumcheckInstanceResultPlan] = &[\n{instances}\n];\n\n" + ) + } + + fn emit_sumcheck_eval_constants(&self) -> String { + let evals = self + .evals + .iter() + .map(|eval| { + format!( + " Stage3SumcheckEvalPlan {{ symbol: {}, source: {}, name: {}, index: {}, oracle: {} }},", + rust_str(&eval.symbol), + rust_str(&eval.source), + rust_str(&eval.name), + eval.index, + rust_str(&eval.oracle) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE3_SUMCHECK_EVALS: &[Stage3SumcheckEvalPlan] = &[\n{evals}\n];\n\n") + } + + fn emit_point_slice_constants(&self) -> String { + let slices = self + .point_slices + .iter() + .map(|slice| { + format!( + " Stage3PointSlicePlan {{ symbol: {}, source: {}, offset: {}, length: {}, input: {} }},", + rust_str(&slice.symbol), + rust_str(&slice.source), + slice.offset, + slice.length, + rust_str(&slice.input) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE3_POINT_SLICES: &[Stage3PointSlicePlan] = &[\n{slices}\n];\n\n") + } + + fn emit_point_concat_constants(&self) -> String { + if self.role == Role::Verifier { + let concats = self + .point_concats + .iter() + .map(|concat| { + format!( + " Stage3PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: {} }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity, + rust_str(&concat.inputs.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE3_POINT_CONCATS: &[Stage3PointConcatPlan] = &[\n{concats}\n];\n" + ); + } + + let mut source = String::new(); + for (index, concat) in self.point_concats.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE3_POINT_CONCAT_{index}_INPUTS"), + &concat.inputs, + )); + } + let concats = self + .point_concats + .iter() + .enumerate() + .map(|(index, concat)| { + format!( + " Stage3PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: STAGE3_POINT_CONCAT_{index}_INPUTS }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE3_POINT_CONCATS: &[Stage3PointConcatPlan] = &[\n{concats}\n];\n" + ), + ); + source + } + + fn emit_opening_claim_constants(&self) -> String { + let claims = self + .opening_claims + .iter() + .map(|claim| { + format!( + " Stage3OpeningClaimPlan {{ symbol: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {}, point_source: {}, eval_source: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.oracle), + rust_str(&claim.domain), + claim.point_arity, + rust_str(&claim.claim_kind), + rust_str(&claim.point_source), + rust_str(&claim.eval_source) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE3_OPENING_CLAIMS: &[Stage3OpeningClaimPlan] = &[\n{claims}\n];\n\n") + } + + fn emit_opening_claim_equality_constants(&self) -> String { + let equalities = self + .opening_equalities + .iter() + .map(|equality| { + format!( + " Stage3OpeningClaimEqualityPlan {{ symbol: {}, mode: {}, lhs: {}, rhs: {} }},", + rust_str(&equality.symbol), + rust_str(&equality.mode), + rust_str(&equality.lhs), + rust_str(&equality.rhs) + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE3_OPENING_EQUALITIES: &[Stage3OpeningClaimEqualityPlan] = &[\n{equalities}\n];\n\n" + ) + } + + fn emit_opening_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let batches = self + .opening_batches + .iter() + .map(|batch| { + format!( + " Stage3OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {} }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE3_OPENING_BATCHES: &[Stage3OpeningBatchPlan] = &[\n{batches}\n];\n" + ); + } + + let mut source = String::new(); + for (index, batch) in self.opening_batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE3_OPENING_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE3_OPENING_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + } + let batches = self + .opening_batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage3OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE3_OPENING_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE3_OPENING_BATCH_{index}_CLAIM_OPERANDS }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE3_OPENING_BATCHES: &[Stage3OpeningBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_prover_entrypoint() -> &'static str { + "pub fn execute_stage3_prover(\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage3KernelError>\n\ + where\n\ + \x20 E: Stage3KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage3_prover_with_program(&STAGE3_PROGRAM, executor, transcript)\n\ + }\n\ + \n\ + pub fn execute_stage3_prover_with_program(\n\ + \x20 program: &'static Stage3CpuProgramPlan,\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage3KernelError>\n\ + where\n\ + \x20 E: Stage3KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage3_program(program, Stage3ExecutionMode::Prover, executor, transcript)\n\ + }\n" + } + + fn emit_verifier_entrypoint() -> &'static str { + r#"pub fn verify_stage3( + proof: &Stage3Proof, + opening_inputs: &[Stage3OpeningInputValue], + transcript: &mut T, +) -> Result, VerifyStage3Error> +where + T: Transcript, +{ + verify_stage3_with_program(&STAGE3_PROGRAM, proof, opening_inputs, transcript) +} + +pub fn verify_stage3_with_program( + program: &'static Stage3VerifierProgramPlan, + proof: &Stage3Proof, + opening_inputs: &[Stage3OpeningInputValue], + transcript: &mut T, +) -> Result, VerifyStage3Error> +where + T: Transcript, +{ + if proof.sumchecks.len() != program.drivers.len() { + return Err(VerifyStage3Error::UnexpectedProofCount { + expected: program.drivers.len(), + got: proof.sumchecks.len(), + }); + } + let mut store = + super::common::ValueStore::with_opening_inputs(opening_inputs, program.opening_inputs)?; + store.seed_constants(program.field_constants); + let mut artifacts = Stage3ExecutionArtifacts::default(); + for step in program.steps { + match step.kind { + "transcript_squeeze" => { + let squeeze = + find_plan(program.transcript_squeezes, step.symbol).ok_or(VerifyStage3Error::MissingValue { + symbol: step.symbol, + })?; + verify_stage3_squeeze(program, squeeze, &mut store, transcript, &mut artifacts)?; + } + "sumcheck_driver" => { + let driver = + find_plan(program.drivers, step.symbol).ok_or(VerifyStage3Error::MissingProof { + driver: step.symbol, + })?; + verify_stage3_driver(program, driver, proof, &mut store, transcript, &mut artifacts)?; + } + _ => { + return Err(VerifyStage3Error::InvalidProof { + driver: step.symbol, + reason: "unsupported stage3 program step", + }); + } + } + } + artifacts + .opening_batches + .extend(program.opening_batches.iter()); + Ok(artifacts) +} + +pub fn stage3_verifier_program() -> &'static Stage3VerifierProgramPlan { + &STAGE3_PROGRAM +} + +fn verify_stage3_squeeze( + program: &'static Stage3VerifierProgramPlan, + squeeze: &'static Stage3TranscriptSqueezePlan, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage3ExecutionArtifacts, +) -> Result<(), VerifyStage3Error> +where + T: Transcript, +{ + let values = transcript.challenge_vector(squeeze.count); + store.observe_challenge_vector(squeeze, &values, |input, expected, actual| { + VerifyStage3Error::InvalidInputLength { + input, + expected, + actual, + } + })?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage3Error::from)?; + artifacts.challenge_vectors.push(Stage3ChallengeVector { + symbol: squeeze.symbol, + values, + }); + Ok(()) +} + +fn verify_stage3_driver( + program: &'static Stage3VerifierProgramPlan, + driver: &'static Stage3SumcheckDriverPlan, + proof: &Stage3Proof, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage3ExecutionArtifacts, +) -> Result<(), VerifyStage3Error> +where + T: Transcript, +{ + let proof = proof + .sumchecks + .get(artifacts.sumchecks.len()) + .ok_or(VerifyStage3Error::MissingProof { + driver: driver.symbol, + })?; + let relation = driver.relation.unwrap_or(""); + let output = match relation { + "jolt.stage3.batched" => { + verify_batched_stage3(program, driver, proof, store, transcript)? + } + _ => { + return Err(VerifyStage3Error::UnsupportedRelation { + relation, + }); + } + }; + artifacts.sumchecks.push(output); + Ok(()) +} + +fn verify_batched_stage3( + program: &'static Stage3VerifierProgramPlan, + driver: &'static Stage3SumcheckDriverPlan, + proof: &Stage3SumcheckOutput, + store: &mut super::common::ValueStore, + transcript: &mut T, +) -> Result, VerifyStage3Error> +where + T: Transcript, +{ + super::common::verify_batched_sumcheck( + driver, + proof, + program.claims, + program.batches, + program.field_exprs, + program.opening_inputs, + program.opening_claims, + program.opening_batches, + store, + transcript, + |store, evals, point, batching_coeffs| { + expected_batched_output_claim(program, driver, store, evals, point, batching_coeffs) + }, + |store, verified| observe_stage3_sumcheck_output(program, store, verified), + |driver, error| VerifyStage3Error::Sumcheck { driver, error }, + ) +} + +fn observe_stage3_sumcheck_output( + program: &'static Stage3VerifierProgramPlan, + store: &mut super::common::ValueStore, + output: &Stage3SumcheckOutput, +) -> Result<(), VerifyStage3Error> { + store.observe_sumcheck_output( + program.instance_results, + program.evals, + output, + |instance, mut point| { + match instance.point_order { + "as_is" => {} + "reverse" => point.reverse(), + _ => { + return Err(VerifyStage3Error::InvalidProof { + driver: output.driver, + reason: "unsupported point order", + }); + } + } + Ok(point) + }, + |input, expected, actual| VerifyStage3Error::InvalidInputLength { + input, + expected, + actual, + }, + |symbol| VerifyStage3Error::MissingValue { symbol }, + )?; + store.evaluate_available_points( + program.point_slices, + program.point_concats, + |input, expected, actual| VerifyStage3Error::InvalidInputLength { + input, + expected, + actual, + }, + )?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage3Error::from)?; + store.verify_opening_equalities( + program.opening_equalities, + |driver, reason| VerifyStage3Error::InvalidProof { driver, reason }, + |symbol| VerifyStage3Error::MissingValue { symbol }, + ) +} + +fn expected_batched_output_claim( + program: &'static Stage3VerifierProgramPlan, + driver: &'static Stage3SumcheckDriverPlan, + store: &super::common::ValueStore, + evals: &[Stage3NamedEval], + point: &[Fr], + batching_coeffs: &[Fr], +) -> Result { + let batch = find_batch(program.batches, driver.symbol, driver.batch)?; + let claims = batch_claims(program.claims, batch)?; + let mut expected = Fr::from_u64(0); + for (claim, coefficient) in claims.iter().zip(batching_coeffs) { + let instance = program + .instance_results + .iter() + .find(|instance| instance.claim == claim.symbol && instance.source == driver.symbol) + .ok_or(VerifyStage3Error::MissingClaim { + batch: batch.symbol, + claim: claim.symbol, + })?; + let local_point = point + .get(instance.round_offset..instance.round_offset + instance.num_rounds) + .ok_or(VerifyStage3Error::InvalidInputLength { + input: instance.symbol, + expected: instance.round_offset + instance.num_rounds, + actual: point.len(), + })?; + let value = match instance.relation { + "jolt.stage3.spartan_shift" => { + expected_spartan_shift(store, evals, local_point)? + } + "jolt.stage3.instruction_input" => { + expected_instruction_input(store, evals, local_point)? + } + "jolt.stage3.registers_claim_reduction" => { + expected_registers(store, evals, local_point)? + } + _ => { + return Err(VerifyStage3Error::UnsupportedRelation { + relation: instance.relation, + }); + } + }; + expected += *coefficient * value; + } + Ok(expected) +} + +fn expected_spartan_shift( + store: &super::common::ValueStore, + evals: &[Stage3NamedEval], + local_point: &[Fr], +) -> Result { + let opening_point = reverse_slice(local_point); + let eq_outer = + EqPlusOnePolynomial::::new(super::common::store_point(store, "stage3.input.stage1.NextPC")?.to_vec()) + .evaluate(&opening_point); + let eq_product = EqPlusOnePolynomial::::new( + super::common::store_point(store, "stage3.input.stage2.product_virtual.NextIsNoop")? + .to_vec(), + ) + .evaluate(&opening_point); + let weighted_outer = eval_by_name(evals, "stage3.spartan_shift.eval.UnexpandedPC")? + + super::common::store_scalar(store, "stage3.spartan_shift.gamma")? + * eval_by_name(evals, "stage3.spartan_shift.eval.PC")? + + super::common::store_scalar(store, "stage3.spartan_shift.gamma2")? + * eval_by_name(evals, "stage3.spartan_shift.eval.OpFlagVirtualInstruction")? + + super::common::store_scalar(store, "stage3.spartan_shift.gamma3")? + * eval_by_name(evals, "stage3.spartan_shift.eval.OpFlagIsFirstInSequence")?; + Ok(eq_outer * weighted_outer + + super::common::store_scalar(store, "stage3.spartan_shift.gamma4")? + * eq_product + * (Fr::from_u64(1) + - eval_by_name(evals, "stage3.spartan_shift.eval.InstructionFlagIsNoop")?)) +} + +fn expected_instruction_input( + store: &super::common::ValueStore, + evals: &[Stage3NamedEval], + local_point: &[Fr], +) -> Result { + let opening_point = reverse_slice(local_point); + let eq_eval = EqPolynomial::::mle( + &opening_point, + super::common::store_point(store, "stage3.input.stage2.product_virtual.LeftInstructionInput")?, + ); + let left = eval_by_name( + evals, + "stage3.instruction_input.eval.InstructionFlagLeftOperandIsRs1Value", + )? * eval_by_name(evals, "stage3.instruction_input.eval.Rs1Value")? + + eval_by_name( + evals, + "stage3.instruction_input.eval.InstructionFlagLeftOperandIsPC", + )? * eval_by_name(evals, "stage3.instruction_input.eval.UnexpandedPC")?; + let right = eval_by_name( + evals, + "stage3.instruction_input.eval.InstructionFlagRightOperandIsRs2Value", + )? * eval_by_name(evals, "stage3.instruction_input.eval.Rs2Value")? + + eval_by_name( + evals, + "stage3.instruction_input.eval.InstructionFlagRightOperandIsImm", + )? * eval_by_name(evals, "stage3.instruction_input.eval.Imm")?; + Ok(eq_eval * (right + super::common::store_scalar(store, "stage3.instruction_input.gamma")? * left)) +} + +fn expected_registers( + store: &super::common::ValueStore, + evals: &[Stage3NamedEval], + local_point: &[Fr], +) -> Result { + let opening_point = reverse_slice(local_point); + let eq_eval = EqPolynomial::::mle( + &opening_point, + super::common::store_point(store, "stage3.input.stage1.RdWriteValue")?, + ); + Ok(eq_eval + * (eval_by_name(evals, "stage3.registers_claim_reduction.eval.RdWriteValue")? + + super::common::store_scalar(store, "stage3.registers.gamma")? + * eval_by_name(evals, "stage3.registers_claim_reduction.eval.Rs1Value")? + + super::common::store_scalar(store, "stage3.registers.gamma2")? + * eval_by_name(evals, "stage3.registers_claim_reduction.eval.Rs2Value")?)) +} + +"# + } +} + +fn require_supported_symbol(kind: &str, actual: &str, expected: &str) -> Result<(), EmitError> { + if actual == expected { + Ok(()) + } else { + Err(EmitError::new(format!( + "unsupported {kind} @{actual}; expected @{expected}" + ))) + } +} + +fn emit_str_array(name: &str, values: &[String]) -> String { + if values.is_empty() { + return format!("pub const {name}: &[&str] = &[];\n\n"); + } + if let [value] = values { + return format!("pub const {name}: &[&str] = &[{}];\n\n", rust_str(value)); + } + let entries = values + .iter() + .map(|value| format!(" {},", rust_str(value))) + .collect::>() + .join("\n"); + format!("pub const {name}: &[&str] = &[\n{entries}\n];\n\n") +} + +fn emit_usize_array(name: &str, values: &[usize]) -> String { + let entries = values + .iter() + .map(|value| format!(" {value},")) + .collect::>() + .join("\n"); + format!("pub const {name}: &[usize] = &[\n{entries}\n];\n\n") +} + +fn intern_str_array( + source: &mut String, + arrays: &mut Vec<(Vec, String)>, + name_prefix: &str, + values: &[String], +) -> String { + if let Some((_, name)) = arrays + .iter() + .find(|(existing, _)| existing.as_slice() == values) + { + return name.clone(); + } + let name = format!("{name_prefix}_{}", arrays.len()); + source.push_str(&emit_str_array(&name, values)); + arrays.push((values.to_vec(), name.clone())); + name +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn verify_count(kind: &str, symbol: &str, expected: usize, actual: usize) -> Result<(), EmitError> { + if expected == actual { + Ok(()) + } else { + Err(EmitError::new(format!( + "{kind} @{symbol} count mismatch: expected {expected}, got {actual}" + ))) + } +} + +fn missing_role_binding(kind: &str, symbol: &str) -> EmitError { + EmitError::new(format!("missing {kind} for `{symbol}`")) +} + +fn symbols<'a>(values: impl Iterator) -> BTreeSet { + values.cloned().collect() +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn signed_int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_signed_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "signed integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn parse_signed_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn int_array_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "integer array"))?; + parse_int_array(&attribute).ok_or_else(|| attr_error(operation, attr, "integer array")) +} + +fn parse_int_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().parse().ok()) + .collect() +} + +fn operand_symbols( + operation: OperationRef<'_, '_>, + start_index: usize, +) -> Result, EmitError> { + (start_index..operation.operand_count()) + .map(|index| operand_symbol(operation, index)) + .collect() +} + +fn operand_symbol(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + EmitError::new(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + EmitError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name<'c: 'a, 'a>(operation: impl OperationLike<'c, 'a>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} diff --git a/crates/bolt/src/protocols/jolt/emit/rust/stage4.rs b/crates/bolt/src/protocols/jolt/emit/rust/stage4.rs new file mode 100644 index 0000000000..b0b5c0ca63 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/stage4.rs @@ -0,0 +1,2487 @@ +#![expect( + clippy::needless_raw_string_hashes, + reason = "generated Rust templates are kept as raw string blocks for copyable output" +)] + +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{Attribute, OperationRef}; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4CpuProgram { + pub role: Role, + pub params: Stage4Params, + pub steps: Vec, + pub transcript_squeezes: Vec, + pub transcript_absorb_bytes: Vec, + pub opening_inputs: Vec, + pub field_constants: Vec, + pub field_exprs: Vec, + pub kernels: Vec, + pub claims: Vec, + pub batches: Vec, + pub drivers: Vec, + pub instance_results: Vec, + pub evals: Vec, + pub point_slices: Vec, + pub point_concats: Vec, + pub opening_claims: Vec, + pub opening_equalities: Vec, + pub opening_batches: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4Params { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4KernelPlan { + pub symbol: String, + pub relation: String, + pub kind: String, + pub backend: String, + pub abi: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4TranscriptSqueezePlan { + pub symbol: String, + pub label: String, + pub kind: String, + pub count: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4TranscriptAbsorbBytesPlan { + pub symbol: String, + pub label: String, + pub payload: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4ProgramStepPlan { + pub kind: String, + pub symbol: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4OpeningInputPlan { + pub symbol: String, + pub source_stage: String, + pub source_claim: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4FieldConstantPlan { + pub symbol: String, + pub field: String, + pub value: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4FieldExprPlan { + pub symbol: String, + pub kind: String, + pub formula: String, + pub operand_names: Vec, + pub operands: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckClaimPlan { + pub symbol: String, + pub stage: String, + pub domain: String, + pub num_rounds: usize, + pub degree: usize, + pub claim: String, + pub kernel: Option, + pub relation: Option, + pub claim_value: String, + pub input_openings: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, + pub claim_label: String, + pub round_label: String, + pub round_schedule: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckDriverPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub kernel: Option, + pub relation: Option, + pub batch: String, + pub policy: String, + pub round_schedule: Vec, + pub claim_label: String, + pub round_label: String, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckInstanceResultPlan { + pub symbol: String, + pub source: String, + pub claim: String, + pub relation: String, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: String, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckEvalPlan { + pub symbol: String, + pub source: String, + pub name: String, + pub index: usize, + pub oracle: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4PointSlicePlan { + pub symbol: String, + pub source: String, + pub offset: usize, + pub length: usize, + pub input: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4PointConcatPlan { + pub symbol: String, + pub layout: String, + pub arity: usize, + pub inputs: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4OpeningClaimPlan { + pub symbol: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, + pub point_source: String, + pub eval_source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4OpeningClaimEqualityPlan { + pub symbol: String, + pub mode: String, + pub lhs: String, + pub rhs: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage4OpeningBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, +} + +pub fn stage4_cpu_program(module: &BoltModule<'_, Cpu>) -> Result { + verify_cpu_schema(module)?; + let program = Stage4CpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +pub fn emit_stage4_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = stage4_cpu_program(module)?; + + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source(), + }) +} + +impl Stage4CpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let mut params = None; + let mut steps = Vec::new(); + let mut transcript_squeezes = Vec::new(); + let mut transcript_absorb_bytes = Vec::new(); + let mut opening_inputs = Vec::new(); + let mut field_constants = Vec::new(); + let mut field_exprs = Vec::new(); + let mut kernels = Vec::new(); + let mut claims = Vec::new(); + let mut batches = Vec::new(); + let mut drivers = Vec::new(); + let mut instance_results = Vec::new(); + let mut evals = Vec::new(); + let mut point_slices = Vec::new(); + let mut point_concats = Vec::new(); + let mut opening_claims = Vec::new(); + let mut opening_equalities = Vec::new(); + let mut opening_batches = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(Stage4Params { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.kernel" => { + kernels.push(Stage4KernelPlan { + symbol: string_attr(op, "sym_name")?, + relation: symbol_attr(op, "relation")?, + kind: string_attr(op, "kind")?, + backend: string_attr(op, "backend")?, + abi: string_attr(op, "abi")?, + }); + } + "cpu.transcript_squeeze" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage4ProgramStepPlan { + kind: "transcript_squeeze".to_owned(), + symbol: symbol.clone(), + }); + transcript_squeezes.push(Stage4TranscriptSqueezePlan { + symbol, + label: string_attr(op, "label")?, + kind: string_attr(op, "kind")?, + count: int_attr(op, "count")?, + }); + } + "cpu.transcript_absorb_bytes" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage4ProgramStepPlan { + kind: "transcript_absorb_bytes".to_owned(), + symbol: symbol.clone(), + }); + transcript_absorb_bytes.push(Stage4TranscriptAbsorbBytesPlan { + symbol, + label: string_attr(op, "label")?, + payload: string_attr(op, "payload")?, + }); + } + "cpu.opening_input" => { + opening_inputs.push(Stage4OpeningInputPlan { + symbol: string_attr(op, "sym_name")?, + source_stage: symbol_attr(op, "source_stage")?, + source_claim: symbol_attr(op, "source_claim")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + }); + } + "cpu.field_const" => { + field_constants.push(Stage4FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: int_attr(op, "value")?, + }); + } + "cpu.field_zero" => { + field_constants.push(Stage4FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 0, + }); + } + "cpu.field_one" => { + field_constants.push(Stage4FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 1, + }); + } + "cpu.field_add" | "cpu.field_sub" | "cpu.field_mul" | "cpu.field_neg" => { + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage4FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: operation_name(op).replace("cpu.field_", "field."), + operand_names: operands.clone(), + operands, + }); + } + "cpu.field_pow" => { + let exponent = int_attr(op, "exponent")?; + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage4FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: format!("field.pow:{exponent}"), + operand_names: operands.clone(), + operands, + }); + } + "cpu.sumcheck_claim" => { + claims.push(Stage4SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_verify_claim" => { + claims.push(Stage4SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_batch" => { + batches.push(Stage4SumcheckBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + round_schedule: int_array_attr(op, "round_schedule")?, + }); + } + "cpu.sumcheck_driver" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage4ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage4SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_verify" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage4ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage4SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_instance_result" => { + instance_results.push(Stage4SumcheckInstanceResultPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + claim: symbol_attr(op, "claim")?, + relation: symbol_attr(op, "relation")?, + index: int_attr(op, "index")?, + point_arity: int_attr(op, "point_arity")?, + num_rounds: int_attr(op, "num_rounds")?, + round_offset: int_attr(op, "round_offset")?, + point_order: string_attr(op, "point_order")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_eval" => { + evals.push(Stage4SumcheckEvalPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + name: symbol_attr(op, "name")?, + index: int_attr(op, "index")?, + oracle: symbol_attr(op, "oracle")?, + }); + } + "cpu.point_slice" => { + point_slices.push(Stage4PointSlicePlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + offset: int_attr(op, "offset")?, + length: int_attr(op, "length")?, + input: operand_symbol(op, 0)?, + }); + } + "cpu.point_concat" => { + point_concats.push(Stage4PointConcatPlan { + symbol: string_attr(op, "sym_name")?, + layout: string_attr(op, "layout")?, + arity: int_attr(op, "arity")?, + inputs: operand_symbols(op, 0)?, + }); + } + "cpu.opening_claim" => { + opening_claims.push(Stage4OpeningClaimPlan { + symbol: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + point_source: operand_symbol(op, 0)?, + eval_source: operand_symbol(op, 1)?, + }); + } + "cpu.opening_claim_equal" => { + opening_equalities.push(Stage4OpeningClaimEqualityPlan { + symbol: string_attr(op, "sym_name")?, + mode: string_attr(op, "mode")?, + lhs: operand_symbol(op, 0)?, + rhs: operand_symbol(op, 1)?, + }); + } + "cpu.opening_batch" => { + opening_batches.push(Stage4OpeningBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + }); + } + _ => {} + } + } + + Ok(Self { + params: params.ok_or_else(|| EmitError::new("missing cpu.params"))?, + role: module + .role() + .ok_or_else(|| EmitError::new("missing cpu party role"))?, + steps, + transcript_squeezes, + transcript_absorb_bytes, + opening_inputs, + field_constants, + field_exprs, + kernels, + claims, + batches, + drivers, + instance_results, + evals, + point_slices, + point_concats, + opening_claims, + opening_equalities, + opening_batches, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + require_supported_symbol("field", &self.params.field, "bn254_fr")?; + require_supported_symbol("pcs", &self.params.pcs, "dory")?; + require_supported_symbol("transcript", &self.params.transcript, "blake2b_transcript")?; + self.verify_transcript_steps()?; + self.verify_field_flow()?; + self.verify_claim_batches()?; + match self.role { + Role::Prover => { + self.verify_kernel_definitions()?; + self.verify_prover_driver_bindings()?; + } + Role::Verifier => self.verify_verifier_driver_bindings()?, + } + self.verify_opening_flow() + } + + fn verify_transcript_steps(&self) -> Result<(), EmitError> { + for squeeze in &self.transcript_squeezes { + if !matches!( + squeeze.kind.as_str(), + "challenge_scalar" | "challenge_vector" + ) { + return Err(EmitError::new(format!( + "stage4 transcript squeeze @{} has unsupported kind `{}`", + squeeze.symbol, squeeze.kind + ))); + } + if squeeze.count == 0 { + return Err(EmitError::new(format!( + "stage4 transcript squeeze @{} has zero count", + squeeze.symbol + ))); + } + } + for absorb in &self.transcript_absorb_bytes { + if absorb.label.is_empty() { + return Err(EmitError::new(format!( + "stage4 transcript byte absorb @{} has empty label", + absorb.symbol + ))); + } + } + Ok(()) + } + + fn verify_field_flow(&self) -> Result<(), EmitError> { + for constant in &self.field_constants { + require_supported_symbol("field constant field", &constant.field, "bn254_fr")?; + } + let field_values = self.field_value_symbols(); + for expr in &self.field_exprs { + verify_count( + "field expr operands", + &expr.symbol, + expr.operand_names.len(), + expr.operands.len(), + )?; + for operand in &expr.operands { + if !field_values.contains(operand) { + return Err(EmitError::new(format!( + "field expr @{} references missing field value @{operand}", + expr.symbol + ))); + } + } + } + for claim in &self.claims { + if !field_values.contains(&claim.claim_value) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing claim value @{}", + claim.symbol, claim.claim_value + ))); + } + } + Ok(()) + } + + fn field_value_symbols(&self) -> BTreeSet { + let mut values = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + values.extend(symbols( + self.field_constants.iter().map(|constant| &constant.symbol), + )); + values.extend(symbols( + self.transcript_squeezes + .iter() + .filter(|squeeze| matches!(squeeze.kind.as_str(), "challenge_scalar" | "scalar")) + .map(|squeeze| &squeeze.symbol), + )); + values.extend(symbols(self.field_exprs.iter().map(|expr| &expr.symbol))); + values.extend(symbols(self.evals.iter().map(|eval| &eval.symbol))); + values + } + + fn verify_kernel_definitions(&self) -> Result<(), EmitError> { + for kernel in &self.kernels { + if kernel.backend != "cpu" { + return Err(EmitError::new(format!( + "stage4 kernel @{} targets unsupported backend `{}`", + kernel.symbol, kernel.backend + ))); + } + if kernel.kind != "sumcheck" { + return Err(EmitError::new(format!( + "stage4 kernel @{} has unsupported kind `{}`", + kernel.symbol, kernel.kind + ))); + } + let expected_abi = match kernel.relation.as_str() { + "jolt.stage4.registers_read_write" => "jolt_stage4_registers_read_write", + "jolt.stage4.ram_val_check" => "jolt_stage4_ram_val_check", + "jolt.stage4.batched" => "jolt_stage4_batched", + _ => { + return Err(EmitError::new(format!( + "unsupported stage4 kernel relation @{}", + kernel.relation + ))); + } + }; + if kernel.abi != expected_abi { + return Err(EmitError::new(format!( + "stage4 kernel @{} ABI `{}` does not match relation @{}", + kernel.symbol, kernel.abi, kernel.relation + ))); + } + } + Ok(()) + } + + fn verify_claim_batches(&self) -> Result<(), EmitError> { + let claims = symbols(self.claims.iter().map(|claim| &claim.symbol)); + for batch in &self.batches { + verify_count( + "sumcheck batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "sumcheck batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "sumcheck batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !claims.contains(claim) { + return Err(EmitError::new(format!( + "sumcheck batch @{} references missing claim @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn verify_prover_driver_bindings(&self) -> Result<(), EmitError> { + let kernels = symbols(self.kernels.iter().map(|kernel| &kernel.symbol)); + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + let Some(kernel) = claim.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck claim @{} is missing kernel", + claim.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck claim @{} references missing kernel @{kernel}", + claim.symbol + ))); + } + } + for driver in &self.drivers { + let Some(kernel) = driver.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck driver @{} is missing kernel", + driver.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck driver @{} references missing kernel @{kernel}", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_verifier_driver_bindings(&self) -> Result<(), EmitError> { + if !self.kernels.is_empty() { + return Err(EmitError::new( + "verifier stage4 program must not contain kernels", + )); + } + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + if claim.kernel.is_some() || claim.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck claim @{} must carry relation and no kernel", + claim.symbol + ))); + } + } + for driver in &self.drivers { + if driver.kernel.is_some() || driver.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck driver @{} must carry relation and no kernel", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_opening_flow(&self) -> Result<(), EmitError> { + let mut point_sources = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + point_sources.extend(symbols( + self.instance_results + .iter() + .map(|instance| &instance.symbol), + )); + point_sources.extend(symbols( + self.opening_inputs.iter().map(|input| &input.symbol), + )); + point_sources.extend(symbols(self.point_slices.iter().map(|slice| &slice.symbol))); + point_sources.extend(symbols( + self.point_concats.iter().map(|concat| &concat.symbol), + )); + for slice in &self.point_slices { + if !point_sources.contains(&slice.input) { + return Err(EmitError::new(format!( + "point slice @{} uses missing point source @{}", + slice.symbol, slice.input + ))); + } + } + for concat in &self.point_concats { + for input in &concat.inputs { + if !point_sources.contains(input) { + return Err(EmitError::new(format!( + "point concat @{} uses missing point source @{input}", + concat.symbol + ))); + } + } + } + let eval_sources = self.field_value_symbols(); + let mut opening_sources = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + opening_sources.extend(symbols( + self.opening_claims.iter().map(|claim| &claim.symbol), + )); + for equality in &self.opening_equalities { + if !opening_sources.contains(&equality.lhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing lhs opening @{}", + equality.symbol, equality.lhs + ))); + } + if !opening_sources.contains(&equality.rhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing rhs opening @{}", + equality.symbol, equality.rhs + ))); + } + } + for claim in &self.claims { + for input in &claim.input_openings { + if !opening_sources.contains(input) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing opening @{input}", + claim.symbol + ))); + } + } + } + let drivers = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + for instance in &self.instance_results { + if !drivers.contains(&instance.source) { + return Err(EmitError::new(format!( + "sumcheck instance result @{} references missing driver @{}", + instance.symbol, instance.source + ))); + } + } + for eval in &self.evals { + if !drivers.contains(&eval.source) { + return Err(EmitError::new(format!( + "sumcheck eval @{} references missing driver @{}", + eval.symbol, eval.source + ))); + } + } + for claim in &self.opening_claims { + if !point_sources.contains(&claim.point_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing point source @{}", + claim.symbol, claim.point_source + ))); + } + if !eval_sources.contains(&claim.eval_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing eval source @{}", + claim.symbol, claim.eval_source + ))); + } + } + let openings = symbols(self.opening_claims.iter().map(|claim| &claim.symbol)); + for batch in &self.opening_batches { + verify_count( + "opening batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "opening batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "opening batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !openings.contains(claim) { + return Err(EmitError::new(format!( + "opening batch @{} references missing opening @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_stage4.rs", + Role::Verifier => "verify_stage4.rs", + } + } + + fn emit_source(&self) -> String { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + match self.role { + Role::Prover => { + source.push_str(Self::emit_prover_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_prover_types()); + } + Role::Verifier => { + source.push_str(Self::emit_verifier_imports()); + source.push_str("\n\n"); + source.push_str(&Self::emit_verifier_types()); + } + } + source.push('\n'); + source.push_str(&self.emit_constants()); + source.push('\n'); + source.push_str(self.emit_entrypoint()); + source + } + + fn emit_prover_imports() -> &'static str { + "use jolt_field::Fr;\n\ + use jolt_kernels::stage4::{execute_stage4_program, Stage4CpuProgramPlan, Stage4ExecutionArtifacts, Stage4ExecutionMode, Stage4FieldConstantPlan, Stage4FieldExprPlan, Stage4KernelError, Stage4KernelExecutor, Stage4KernelPlan, Stage4OpeningBatchPlan, Stage4OpeningClaimEqualityPlan, Stage4OpeningClaimPlan, Stage4OpeningInputPlan, Stage4Params, Stage4PointConcatPlan, Stage4PointSlicePlan, Stage4ProgramStepPlan, Stage4SumcheckBatchPlan, Stage4SumcheckClaimPlan, Stage4SumcheckDriverPlan, Stage4SumcheckEvalPlan, Stage4SumcheckInstanceResultPlan, Stage4TranscriptAbsorbBytesPlan, Stage4TranscriptSqueezePlan};\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_prover_types() -> &'static str { + "pub type DefaultStage4Transcript = Blake2bTranscript;\n" + } + + fn emit_verifier_imports() -> &'static str { + "use super::common::{batch_claims, eval_by_name, find_batch, find_plan, lt_polynomial_eval, reverse_slice};\n\ + use jolt_field::{Field, Fr};\n\ + use jolt_poly::EqPolynomial;\n\ + use jolt_sumcheck::SumcheckError;\n\ + use jolt_transcript::{Blake2bTranscript, LabelWithCount, Transcript};" + } + + #[expect(dead_code)] + fn emit_types() -> &'static str { + r#"#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4Params { + pub field: &'static str, + pub pcs: &'static str, + pub transcript: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4KernelPlan { + pub symbol: &'static str, + pub relation: &'static str, + pub kind: &'static str, + pub backend: &'static str, + pub abi: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4TranscriptSqueezePlan { + pub symbol: &'static str, + pub label: &'static str, + pub kind: &'static str, + pub count: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4TranscriptAbsorbBytesPlan { + pub symbol: &'static str, + pub label: &'static str, + pub payload: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4ProgramStepPlan { + pub kind: &'static str, + pub symbol: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4OpeningInputPlan { + pub symbol: &'static str, + pub source_stage: &'static str, + pub source_claim: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4FieldConstantPlan { + pub symbol: &'static str, + pub field: &'static str, + pub value: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4FieldExprPlan { + pub symbol: &'static str, + pub kind: &'static str, + pub formula: &'static str, + pub operand_names: &'static [&'static str], + pub operands: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckClaimPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub domain: &'static str, + pub num_rounds: usize, + pub degree: usize, + pub claim: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub claim_value: &'static str, + pub input_openings: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static [&'static str], + pub claim_operands: &'static [&'static str], + pub claim_label: &'static str, + pub round_label: &'static str, + pub round_schedule: &'static [usize], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckDriverPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub batch: &'static str, + pub policy: &'static str, + pub round_schedule: &'static [usize], + pub claim_label: &'static str, + pub round_label: &'static str, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckInstanceResultPlan { + pub symbol: &'static str, + pub source: &'static str, + pub claim: &'static str, + pub relation: &'static str, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: &'static str, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4SumcheckEvalPlan { + pub symbol: &'static str, + pub source: &'static str, + pub name: &'static str, + pub index: usize, + pub oracle: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4PointSlicePlan { + pub symbol: &'static str, + pub source: &'static str, + pub offset: usize, + pub length: usize, + pub input: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4PointConcatPlan { + pub symbol: &'static str, + pub layout: &'static str, + pub arity: usize, + pub inputs: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4OpeningClaimPlan { + pub symbol: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, + pub point_source: &'static str, + pub eval_source: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4OpeningClaimEqualityPlan { + pub symbol: &'static str, + pub mode: &'static str, + pub lhs: &'static str, + pub rhs: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4OpeningBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static [&'static str], + pub claim_operands: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage4CpuProgramPlan { + pub role: &'static str, + pub params: Stage4Params, + pub steps: &'static [Stage4ProgramStepPlan], + pub transcript_squeezes: &'static [Stage4TranscriptSqueezePlan], + pub transcript_absorb_bytes: &'static [Stage4TranscriptAbsorbBytesPlan], + pub opening_inputs: &'static [Stage4OpeningInputPlan], + pub field_constants: &'static [Stage4FieldConstantPlan], + pub field_exprs: &'static [Stage4FieldExprPlan], + pub kernels: &'static [Stage4KernelPlan], + pub claims: &'static [Stage4SumcheckClaimPlan], + pub batches: &'static [Stage4SumcheckBatchPlan], + pub drivers: &'static [Stage4SumcheckDriverPlan], + pub instance_results: &'static [Stage4SumcheckInstanceResultPlan], + pub evals: &'static [Stage4SumcheckEvalPlan], + pub point_slices: &'static [Stage4PointSlicePlan], + pub point_concats: &'static [Stage4PointConcatPlan], + pub opening_claims: &'static [Stage4OpeningClaimPlan], + pub opening_equalities: &'static [Stage4OpeningClaimEqualityPlan], + pub opening_batches: &'static [Stage4OpeningBatchPlan], +} +"# + } + + fn emit_verifier_type_aliases() -> &'static str { + r#"pub type Stage4NamedEval = super::common::StageNamedEval; +pub type Stage4SumcheckOutput = super::common::StageSumcheckOutput; +pub type Stage4ChallengeVector = super::common::StageChallengeVector; +pub type Stage4ExecutionArtifacts = super::common::StageExecutionArtifacts; +pub type Stage4Proof = super::common::StageProof; +pub type Stage4OpeningInputValue = super::common::StageOpeningInputValue; + +pub use super::common::{ + FieldConstantPlan as Stage4FieldConstantPlan, FieldExprPlan as Stage4FieldExprPlan, + KernelPlan as Stage4KernelPlan, OpeningBatchPlan as Stage4OpeningBatchPlan, + OpeningClaimEqualityPlan as Stage4OpeningClaimEqualityPlan, + OpeningClaimPlan as Stage4OpeningClaimPlan, OpeningInputPlan as Stage4OpeningInputPlan, + PointConcatPlan as Stage4PointConcatPlan, PointSlicePlan as Stage4PointSlicePlan, + ProgramStepPlan as Stage4ProgramStepPlan, StageParams as Stage4Params, + StageProgramPlanNoPointZeros as Stage4CpuProgramPlan, + SumcheckBatchPlan as Stage4SumcheckBatchPlan, + SumcheckClaimPlan as Stage4SumcheckClaimPlan, SumcheckDriverPlan as Stage4SumcheckDriverPlan, + SumcheckEvalPlan as Stage4SumcheckEvalPlan, + SumcheckInstanceResultPlan as Stage4SumcheckInstanceResultPlan, + TranscriptAbsorbBytesPlan as Stage4TranscriptAbsorbBytesPlan, + TranscriptSqueezePlan as Stage4TranscriptSqueezePlan, +}; +"# + } + + fn emit_verifier_types() -> String { + let mut source = Self::emit_verifier_type_aliases().to_owned(); + source.push_str( + r#" +pub type DefaultStage4Transcript = Blake2bTranscript; +pub type Stage4VerifierProgramPlan = Stage4CpuProgramPlan; + +#[derive(Debug)] +pub enum VerifyStage4Error { + UnexpectedProofCount { expected: usize, got: usize }, + MissingProof { driver: &'static str }, + MissingBatch { driver: &'static str, batch: &'static str }, + MissingClaim { batch: &'static str, claim: &'static str }, + MissingValue { symbol: &'static str }, + InvalidInputLength { input: &'static str, expected: usize, actual: usize }, + InvalidProof { driver: &'static str, reason: &'static str }, + UnsupportedFieldExpr { symbol: &'static str, formula: &'static str }, + UnsupportedRelation { relation: &'static str }, + Sumcheck { driver: &'static str, error: SumcheckError }, +} + +super::common::impl_runtime_plan_error_conversion!(VerifyStage4Error); +"#, + ); + source + } + + fn emit_constants(&self) -> String { + let mut source = self.emit_shared_constants(); + source.push_str(&self.emit_kernel_constants()); + source.push_str(&self.emit_sumcheck_claim_constants()); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_sumcheck_driver_constants()); + source.push_str(&self.emit_tail_constants()); + push_format( + &mut source, + format_args!( + "pub const STAGE4_PROGRAM: {} = Stage4CpuProgramPlan {{\n\ + \x20 role: {},\n\ + \x20 params: STAGE4_PARAMS,\n\ + \x20 steps: STAGE4_PROGRAM_STEPS,\n\ + \x20 transcript_squeezes: STAGE4_TRANSCRIPT_SQUEEZES,\n\ + \x20 transcript_absorb_bytes: STAGE4_TRANSCRIPT_ABSORB_BYTES,\n\ + \x20 opening_inputs: STAGE4_OPENING_INPUTS,\n\ + \x20 field_constants: STAGE4_FIELD_CONSTANTS,\n\ + \x20 field_exprs: STAGE4_FIELD_EXPRS,\n\ + \x20 kernels: STAGE4_KERNELS,\n\ + \x20 claims: STAGE4_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE4_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE4_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE4_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE4_SUMCHECK_EVALS,\n\ + \x20 point_slices: STAGE4_POINT_SLICES,\n\ + \x20 point_concats: STAGE4_POINT_CONCATS,\n\ + \x20 opening_claims: STAGE4_OPENING_CLAIMS,\n\ + \x20 opening_equalities: STAGE4_OPENING_EQUALITIES,\n\ + \x20 opening_batches: STAGE4_OPENING_BATCHES,\n\ + }};\n", + self.program_plan_type(), + rust_str(self.role_label()) + ), + ); + source + } + + fn emit_shared_constants(&self) -> String { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const STAGE4_PARAMS: Stage4Params = Stage4Params {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + source.push_str(&self.emit_program_step_constants()); + source.push_str(&self.emit_transcript_squeeze_constants()); + source.push_str(&self.emit_transcript_absorb_bytes_constants()); + source.push_str(&self.emit_opening_input_constants()); + source.push_str(&self.emit_field_constant_constants()); + source.push_str(&self.emit_field_expr_constants()); + source + } + + fn emit_program_step_constants(&self) -> String { + let steps = self + .steps + .iter() + .map(|step| { + format!( + " Stage4ProgramStepPlan {{ kind: {}, symbol: {} }},", + rust_str(&step.kind), + rust_str(&step.symbol), + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE4_PROGRAM_STEPS: &[Stage4ProgramStepPlan] = &[\n{steps}\n];\n\n") + } + + fn emit_transcript_squeeze_constants(&self) -> String { + let squeezes = self + .transcript_squeezes + .iter() + .map(|squeeze| { + format!( + " Stage4TranscriptSqueezePlan {{ symbol: {}, label: {}, kind: {}, count: {} }},", + rust_str(&squeeze.symbol), + rust_str(&squeeze.label), + rust_str(&squeeze.kind), + squeeze.count, + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE4_TRANSCRIPT_SQUEEZES: &[Stage4TranscriptSqueezePlan] = &[\n{squeezes}\n];\n\n" + ) + } + + fn emit_transcript_absorb_bytes_constants(&self) -> String { + let absorbs = self + .transcript_absorb_bytes + .iter() + .map(|absorb| { + format!( + " Stage4TranscriptAbsorbBytesPlan {{ symbol: {}, label: {}, payload: {} }},", + rust_str(&absorb.symbol), + rust_str(&absorb.label), + rust_str(&absorb.payload), + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE4_TRANSCRIPT_ABSORB_BYTES: &[Stage4TranscriptAbsorbBytesPlan] = &[\n{absorbs}\n];\n\n" + ) + } + + fn emit_opening_input_constants(&self) -> String { + let inputs = self + .opening_inputs + .iter() + .map(|input| { + format!( + " Stage4OpeningInputPlan {{ symbol: {}, source_stage: {}, source_claim: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {} }},", + rust_str(&input.symbol), + rust_str(&input.source_stage), + rust_str(&input.source_claim), + rust_str(&input.oracle), + rust_str(&input.domain), + input.point_arity, + rust_str(&input.claim_kind) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE4_OPENING_INPUTS: &[Stage4OpeningInputPlan] = &[\n{inputs}\n];\n\n") + } + + fn emit_field_constant_constants(&self) -> String { + let constants = self + .field_constants + .iter() + .map(|constant| { + format!( + " Stage4FieldConstantPlan {{ symbol: {}, field: {}, value: {} }},", + rust_str(&constant.symbol), + rust_str(&constant.field), + constant.value + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE4_FIELD_CONSTANTS: &[Stage4FieldConstantPlan] = &[\n{constants}\n];\n\n" + ) + } + + fn emit_field_expr_constants(&self) -> String { + if self.role == Role::Verifier { + let exprs = self + .field_exprs + .iter() + .map(|expr| { + format!( + " Stage4FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operands: {} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula), + rust_str(&expr.operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE4_FIELD_EXPRS: &[Stage4FieldExprPlan] = &[\n{exprs}\n];\n" + ); + } + + let mut source = String::new(); + let mut arrays = Vec::new(); + let mut array_refs = Vec::new(); + for (index, expr) in self.field_exprs.iter().enumerate() { + let operands = intern_str_array( + &mut source, + &mut arrays, + "STAGE4_FIELD_EXPR_OPERANDS", + &expr.operands, + ); + let operand_names = intern_str_array( + &mut source, + &mut arrays, + "STAGE4_FIELD_EXPR_OPERANDS", + &expr.operand_names, + ); + array_refs.push((index, operand_names, operands)); + } + let exprs = self + .field_exprs + .iter() + .enumerate() + .map(|(index, expr)| { + let (_, operand_names, operands) = &array_refs[index]; + format!( + " Stage4FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operand_names: {operand_names}, operands: {operands} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE4_FIELD_EXPRS: &[Stage4FieldExprPlan] = &[\n{exprs}\n];\n" + ), + ); + source + } + + fn emit_kernel_constants(&self) -> String { + let kernels = self + .kernels + .iter() + .map(|kernel| { + format!( + " Stage4KernelPlan {{ symbol: {}, relation: {}, kind: {}, backend: {}, abi: {} }},", + rust_str(&kernel.symbol), + rust_str(&kernel.relation), + rust_str(&kernel.kind), + rust_str(&kernel.backend), + rust_str(&kernel.abi) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE4_KERNELS: &[Stage4KernelPlan] = &[\n{kernels}\n];\n\n") + } + + fn emit_sumcheck_claim_constants(&self) -> String { + if self.role == Role::Verifier { + let claims = self + .claims + .iter() + .map(|claim| { + format!( + " Stage4SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: {}, relation: {}, claim_value: {}, input_openings: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_option_str(claim.kernel.as_deref()), + rust_option_str(claim.relation.as_deref()), + rust_str(&claim.claim_value), + rust_str(&claim.input_openings.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE4_SUMCHECK_CLAIMS: &[Stage4SumcheckClaimPlan] = &[\n{claims}\n];\n" + ); + } + + let mut source = String::new(); + for (index, claim) in self.claims.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE4_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS"), + &claim.input_openings, + )); + } + let claims = self + .claims + .iter() + .enumerate() + .map(|(index, claim)| { + format!( + " Stage4SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: {}, relation: {}, claim_value: {}, input_openings: STAGE4_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_option_str(claim.kernel.as_deref()), + rust_option_str(claim.relation.as_deref()), + rust_str(&claim.claim_value) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE4_SUMCHECK_CLAIMS: &[Stage4SumcheckClaimPlan] = &[\n{claims}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE4_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage4SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {}, claim_label: {}, round_label: {}, round_schedule: STAGE4_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")), + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE4_SUMCHECK_BATCHES: &[Stage4SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + return source; + } + + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE4_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE4_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + source.push_str(&emit_usize_array( + &format!("STAGE4_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage4SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE4_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE4_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS, claim_label: {}, round_label: {}, round_schedule: STAGE4_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE4_SUMCHECK_BATCHES: &[Stage4SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_driver_constants(&self) -> String { + let mut source = String::new(); + for (index, driver) in self.drivers.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE4_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE"), + &driver.round_schedule, + )); + } + let drivers = self + .drivers + .iter() + .enumerate() + .map(|(index, driver)| { + format!( + " Stage4SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: {}, relation: {}, batch: {}, policy: {}, round_schedule: STAGE4_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_option_str(driver.kernel.as_deref()), + rust_option_str(driver.relation.as_deref()), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE4_SUMCHECK_DRIVERS: &[Stage4SumcheckDriverPlan] = &[\n{drivers}\n];\n" + ), + ); + source + } + + fn emit_tail_constants(&self) -> String { + let mut source = String::new(); + source.push_str(&self.emit_sumcheck_instance_result_constants()); + source.push_str(&self.emit_sumcheck_eval_constants()); + source.push_str(&self.emit_point_slice_constants()); + source.push_str(&self.emit_point_concat_constants()); + source.push_str(&self.emit_opening_claim_constants()); + source.push_str(&self.emit_opening_claim_equality_constants()); + source.push_str(&self.emit_opening_batch_constants()); + source + } + + fn emit_sumcheck_instance_result_constants(&self) -> String { + let instances = self + .instance_results + .iter() + .map(|instance| { + format!( + " Stage4SumcheckInstanceResultPlan {{ symbol: {}, source: {}, claim: {}, relation: {}, index: {}, point_arity: {}, num_rounds: {}, round_offset: {}, point_order: {}, degree: {} }},", + rust_str(&instance.symbol), + rust_str(&instance.source), + rust_str(&instance.claim), + rust_str(&instance.relation), + instance.index, + instance.point_arity, + instance.num_rounds, + instance.round_offset, + rust_str(&instance.point_order), + instance.degree + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE4_SUMCHECK_INSTANCE_RESULTS: &[Stage4SumcheckInstanceResultPlan] = &[\n{instances}\n];\n\n" + ) + } + + fn emit_sumcheck_eval_constants(&self) -> String { + let evals = self + .evals + .iter() + .map(|eval| { + format!( + " Stage4SumcheckEvalPlan {{ symbol: {}, source: {}, name: {}, index: {}, oracle: {} }},", + rust_str(&eval.symbol), + rust_str(&eval.source), + rust_str(&eval.name), + eval.index, + rust_str(&eval.oracle) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE4_SUMCHECK_EVALS: &[Stage4SumcheckEvalPlan] = &[\n{evals}\n];\n\n") + } + + fn emit_point_slice_constants(&self) -> String { + let slices = self + .point_slices + .iter() + .map(|slice| { + format!( + " Stage4PointSlicePlan {{ symbol: {}, source: {}, offset: {}, length: {}, input: {} }},", + rust_str(&slice.symbol), + rust_str(&slice.source), + slice.offset, + slice.length, + rust_str(&slice.input) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE4_POINT_SLICES: &[Stage4PointSlicePlan] = &[\n{slices}\n];\n\n") + } + + fn emit_point_concat_constants(&self) -> String { + if self.role == Role::Verifier { + let concats = self + .point_concats + .iter() + .map(|concat| { + format!( + " Stage4PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: {} }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity, + rust_str(&concat.inputs.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE4_POINT_CONCATS: &[Stage4PointConcatPlan] = &[\n{concats}\n];\n" + ); + } + + let mut source = String::new(); + for (index, concat) in self.point_concats.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE4_POINT_CONCAT_{index}_INPUTS"), + &concat.inputs, + )); + } + let concats = self + .point_concats + .iter() + .enumerate() + .map(|(index, concat)| { + format!( + " Stage4PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: STAGE4_POINT_CONCAT_{index}_INPUTS }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE4_POINT_CONCATS: &[Stage4PointConcatPlan] = &[\n{concats}\n];\n" + ), + ); + source + } + + fn emit_opening_claim_constants(&self) -> String { + let claims = self + .opening_claims + .iter() + .map(|claim| { + format!( + " Stage4OpeningClaimPlan {{ symbol: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {}, point_source: {}, eval_source: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.oracle), + rust_str(&claim.domain), + claim.point_arity, + rust_str(&claim.claim_kind), + rust_str(&claim.point_source), + rust_str(&claim.eval_source) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE4_OPENING_CLAIMS: &[Stage4OpeningClaimPlan] = &[\n{claims}\n];\n\n") + } + + fn emit_opening_claim_equality_constants(&self) -> String { + let equalities = self + .opening_equalities + .iter() + .map(|equality| { + format!( + " Stage4OpeningClaimEqualityPlan {{ symbol: {}, mode: {}, lhs: {}, rhs: {} }},", + rust_str(&equality.symbol), + rust_str(&equality.mode), + rust_str(&equality.lhs), + rust_str(&equality.rhs) + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE4_OPENING_EQUALITIES: &[Stage4OpeningClaimEqualityPlan] = &[\n{equalities}\n];\n\n" + ) + } + + fn emit_opening_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let batches = self + .opening_batches + .iter() + .map(|batch| { + format!( + " Stage4OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {} }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE4_OPENING_BATCHES: &[Stage4OpeningBatchPlan] = &[\n{batches}\n];\n" + ); + } + + let mut source = String::new(); + for (index, batch) in self.opening_batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE4_OPENING_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE4_OPENING_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + } + let batches = self + .opening_batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage4OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE4_OPENING_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE4_OPENING_BATCH_{index}_CLAIM_OPERANDS }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE4_OPENING_BATCHES: &[Stage4OpeningBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_entrypoint(&self) -> &'static str { + match self.role { + Role::Prover => { + "pub fn execute_stage4_prover(\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage4KernelError>\n\ + where\n\ + \x20 E: Stage4KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage4_prover_with_program(&STAGE4_PROGRAM, executor, transcript)\n\ + }\n\ + \n\ + pub fn execute_stage4_prover_with_program(\n\ + \x20 program: &'static Stage4CpuProgramPlan,\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage4KernelError>\n\ + where\n\ + \x20 E: Stage4KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage4_program(program, Stage4ExecutionMode::Prover, executor, transcript)\n\ + }\n" + } + Role::Verifier => { + r#"pub fn verify_stage4( + proof: &Stage4Proof, + opening_inputs: &[Stage4OpeningInputValue], + transcript: &mut T, +) -> Result, VerifyStage4Error> +where + T: Transcript, +{ + verify_stage4_with_program(&STAGE4_PROGRAM, proof, opening_inputs, transcript) +} + +pub fn verify_stage4_with_program( + program: &'static Stage4VerifierProgramPlan, + proof: &Stage4Proof, + opening_inputs: &[Stage4OpeningInputValue], + transcript: &mut T, +) -> Result, VerifyStage4Error> +where + T: Transcript, +{ + if proof.sumchecks.len() != program.drivers.len() { + return Err(VerifyStage4Error::UnexpectedProofCount { + expected: program.drivers.len(), + got: proof.sumchecks.len(), + }); + } + let mut store = + super::common::ValueStore::with_opening_inputs(opening_inputs, program.opening_inputs)?; + store.seed_constants(program.field_constants); + let mut artifacts = Stage4ExecutionArtifacts::default(); + for step in program.steps { + match step.kind { + "transcript_squeeze" => { + let squeeze = + find_plan(program.transcript_squeezes, step.symbol).ok_or(VerifyStage4Error::MissingValue { + symbol: step.symbol, + })?; + verify_stage4_squeeze(program, squeeze, &mut store, transcript, &mut artifacts)?; + } + "transcript_absorb_bytes" => { + let absorb = find_plan(program.transcript_absorb_bytes, step.symbol).ok_or( + VerifyStage4Error::MissingValue { + symbol: step.symbol, + }, + )?; + absorb_stage4_bytes(absorb, transcript); + } + "sumcheck_driver" => { + let driver = + find_plan(program.drivers, step.symbol).ok_or(VerifyStage4Error::MissingProof { + driver: step.symbol, + })?; + verify_stage4_driver(program, driver, proof, &mut store, transcript, &mut artifacts)?; + } + _ => { + return Err(VerifyStage4Error::InvalidProof { + driver: step.symbol, + reason: "unsupported stage4 program step", + }); + } + } + } + artifacts + .opening_batches + .extend(program.opening_batches.iter()); + Ok(artifacts) +} + +pub fn stage4_verifier_program() -> &'static Stage4VerifierProgramPlan { + &STAGE4_PROGRAM +} + +fn verify_stage4_squeeze( + program: &'static Stage4VerifierProgramPlan, + squeeze: &'static Stage4TranscriptSqueezePlan, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage4ExecutionArtifacts, +) -> Result<(), VerifyStage4Error> +where + T: Transcript, +{ + let values = transcript.challenge_vector(squeeze.count); + store.observe_challenge_vector(squeeze, &values, |input, expected, actual| { + VerifyStage4Error::InvalidInputLength { + input, + expected, + actual, + } + })?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage4Error::from)?; + artifacts.challenge_vectors.push(Stage4ChallengeVector { + symbol: squeeze.symbol, + values, + }); + Ok(()) +} + +fn absorb_stage4_bytes(absorb: &'static Stage4TranscriptAbsorbBytesPlan, transcript: &mut T) +where + T: Transcript, +{ + transcript.append(&LabelWithCount( + absorb.label.as_bytes(), + absorb.payload.len() as u64, + )); + transcript.append_bytes(absorb.payload.as_bytes()); +} + +fn verify_stage4_driver( + program: &'static Stage4VerifierProgramPlan, + driver: &'static Stage4SumcheckDriverPlan, + proof: &Stage4Proof, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage4ExecutionArtifacts, +) -> Result<(), VerifyStage4Error> +where + T: Transcript, +{ + let proof = proof + .sumchecks + .get(artifacts.sumchecks.len()) + .ok_or(VerifyStage4Error::MissingProof { + driver: driver.symbol, + })?; + let relation = driver.relation.unwrap_or(""); + let output = match relation { + "jolt.stage4.batched" => { + verify_batched_stage4(program, driver, proof, store, transcript)? + } + _ => return Err(VerifyStage4Error::UnsupportedRelation { relation }), + }; + artifacts.sumchecks.push(output); + Ok(()) +} + +fn verify_batched_stage4( + program: &'static Stage4VerifierProgramPlan, + driver: &'static Stage4SumcheckDriverPlan, + proof: &Stage4SumcheckOutput, + store: &mut super::common::ValueStore, + transcript: &mut T, +) -> Result, VerifyStage4Error> +where + T: Transcript, +{ + super::common::verify_batched_sumcheck( + driver, + proof, + program.claims, + program.batches, + program.field_exprs, + program.opening_inputs, + program.opening_claims, + program.opening_batches, + store, + transcript, + |store, evals, point, batching_coeffs| { + expected_batched_output_claim(program, driver, store, evals, point, batching_coeffs) + }, + |store, verified| observe_stage4_sumcheck_output(program, store, verified), + |driver, error| VerifyStage4Error::Sumcheck { driver, error }, + ) +} + +fn observe_stage4_sumcheck_output( + program: &'static Stage4VerifierProgramPlan, + store: &mut super::common::ValueStore, + output: &Stage4SumcheckOutput, +) -> Result<(), VerifyStage4Error> { + store.observe_sumcheck_output( + program.instance_results, + program.evals, + output, + |instance, mut point| { + match instance.point_order { + "as_is" => {} + "reverse" => point.reverse(), + "stage4_registers_rw" => { + point = normalize_stage4_registers_rw_point(program, output.driver, &point)?; + } + _ => { + return Err(VerifyStage4Error::InvalidProof { + driver: output.driver, + reason: "unsupported point order", + }); + } + } + Ok(point) + }, + |input, expected, actual| VerifyStage4Error::InvalidInputLength { + input, + expected, + actual, + }, + |symbol| VerifyStage4Error::MissingValue { symbol }, + )?; + store.evaluate_available_points( + program.point_slices, + program.point_concats, + |input, expected, actual| VerifyStage4Error::InvalidInputLength { + input, + expected, + actual, + }, + )?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage4Error::from)?; + store.verify_opening_equalities( + program.opening_equalities, + |driver, reason| VerifyStage4Error::InvalidProof { driver, reason }, + |symbol| VerifyStage4Error::MissingValue { symbol }, + ) +} + +fn expected_batched_output_claim( + program: &'static Stage4VerifierProgramPlan, + driver: &'static Stage4SumcheckDriverPlan, + store: &super::common::ValueStore, + evals: &[Stage4NamedEval], + point: &[Fr], + batching_coeffs: &[Fr], +) -> Result { + let batch = find_batch(program.batches, driver.symbol, driver.batch)?; + let claims = batch_claims(program.claims, batch)?; + let mut expected = Fr::from_u64(0); + for (claim, coefficient) in claims.iter().zip(batching_coeffs) { + let instance = program + .instance_results + .iter() + .find(|instance| instance.claim == claim.symbol && instance.source == driver.symbol) + .ok_or(VerifyStage4Error::MissingClaim { + batch: batch.symbol, + claim: claim.symbol, + })?; + let local_point = point + .get(instance.round_offset..instance.round_offset + instance.num_rounds) + .ok_or(VerifyStage4Error::InvalidInputLength { + input: instance.symbol, + expected: instance.round_offset + instance.num_rounds, + actual: point.len(), + })?; + let relation = claim.relation.unwrap_or(""); + let value = match relation { + "jolt.stage4.registers_read_write" => { + expected_registers_read_write(store, evals, local_point)? + } + "jolt.stage4.ram_val_check" => { + expected_ram_val_check(store, evals, local_point)? + } + _ => return Err(VerifyStage4Error::UnsupportedRelation { relation }), + }; + expected += *coefficient * value; + } + Ok(expected) +} + +fn expected_registers_read_write( + store: &super::common::ValueStore, + evals: &[Stage4NamedEval], + local_point: &[Fr], +) -> Result { + let trace_point = super::common::store_point(store, "stage4.input.stage3.registers.RdWriteValue")?; + let r_cycle = normalize_stage4_registers_rw_cycle_point( + local_point, + trace_point.len(), + "stage4.registers_read_write.instance", + )?; + let eq_eval = EqPolynomial::::mle(&r_cycle, trace_point); + let registers_val = eval_by_name( + evals, + "stage4.registers_read_write.eval.RegistersVal", + )?; + let rs1_ra = eval_by_name(evals, "stage4.registers_read_write.eval.Rs1Ra")?; + let rs2_ra = eval_by_name(evals, "stage4.registers_read_write.eval.Rs2Ra")?; + let rd_wa = eval_by_name(evals, "stage4.registers_read_write.eval.RdWa")?; + let rd_inc = eval_by_name(evals, "stage4.registers_read_write.eval.RdInc")?; + let gamma = super::common::store_scalar(store, "stage4.registers_read_write.gamma")?; + Ok(eq_eval + * (rd_wa * (registers_val + rd_inc) + + gamma * (rs1_ra * registers_val + gamma * rs2_ra * registers_val))) +} + +fn expected_ram_val_check( + store: &super::common::ValueStore, + evals: &[Stage4NamedEval], + local_point: &[Fr], +) -> Result { + let ram_val_point = super::common::store_point(store, "stage4.input.stage2.RamVal")?; + let r_cycle_prime = reverse_slice(local_point); + let r_cycle = suffix_point( + ram_val_point, + r_cycle_prime.len(), + "stage4.input.stage2.RamVal", + )?; + let lt_eval = lt_polynomial_eval(&r_cycle_prime, r_cycle); + let gamma = super::common::store_scalar(store, "stage4.ram_val_check.gamma")?; + let ram_ra = eval_by_name(evals, "stage4.ram_val_check.eval.RamRa")?; + let ram_inc = eval_by_name(evals, "stage4.ram_val_check.eval.RamInc")?; + Ok(ram_inc * ram_ra * (lt_eval + gamma)) +} + +fn suffix_point<'a>( + point: &'a [Fr], + length: usize, + input: &'static str, +) -> Result<&'a [Fr], VerifyStage4Error> { + point + .get(point.len().saturating_sub(length)..) + .filter(|suffix| suffix.len() == length) + .ok_or(VerifyStage4Error::InvalidInputLength { + input, + expected: length, + actual: point.len(), + }) +} + +fn normalize_stage4_registers_rw_point( + program: &'static Stage4VerifierProgramPlan, + driver: &'static str, + point: &[F], +) -> Result, VerifyStage4Error> { + let driver_plan = find_plan(program.drivers, driver).ok_or(VerifyStage4Error::MissingProof { + driver, + })?; + if driver_plan.round_schedule.len() != 2 { + return Err(VerifyStage4Error::InvalidProof { + driver, + reason: "stage4 registers point normalization requires [cycle, address] schedule", + }); + } + let cycle_rounds = driver_plan.round_schedule[0]; + let address_rounds = driver_plan.round_schedule[1]; + if point.len() != cycle_rounds + address_rounds { + return Err(VerifyStage4Error::InvalidInputLength { + input: "stage4.registers_read_write.instance", + expected: cycle_rounds + address_rounds, + actual: point.len(), + }); + } + let (cycle, address) = point.split_at(cycle_rounds); + Ok(address + .iter() + .rev() + .copied() + .chain(cycle.iter().rev().copied()) + .collect()) +} + +fn normalize_stage4_registers_rw_cycle_point( + point: &[F], + cycle_rounds: usize, + input: &'static str, +) -> Result, VerifyStage4Error> { + let cycle = point + .get(..cycle_rounds) + .filter(|cycle| cycle.len() == cycle_rounds) + .ok_or(VerifyStage4Error::InvalidInputLength { + input, + expected: cycle_rounds, + actual: point.len(), + })?; + Ok(cycle.iter().rev().copied().collect()) +} + +"# + } + } + } + + fn role_label(&self) -> &'static str { + match self.role { + Role::Prover => "prover", + Role::Verifier => "verifier", + } + } + + fn program_plan_type(&self) -> &'static str { + match self.role { + Role::Prover => "Stage4CpuProgramPlan", + Role::Verifier => "Stage4VerifierProgramPlan", + } + } +} + +fn require_supported_symbol(kind: &str, actual: &str, expected: &str) -> Result<(), EmitError> { + if actual == expected { + Ok(()) + } else { + Err(EmitError::new(format!( + "unsupported {kind} @{actual}; expected @{expected}" + ))) + } +} + +fn emit_str_array(name: &str, values: &[String]) -> String { + if values.is_empty() { + return format!("pub const {name}: &[&str] = &[];\n\n"); + } + if let [value] = values { + return format!("pub const {name}: &[&str] = &[{}];\n\n", rust_str(value)); + } + let entries = values + .iter() + .map(|value| format!(" {},", rust_str(value))) + .collect::>() + .join("\n"); + format!("pub const {name}: &[&str] = &[\n{entries}\n];\n\n") +} + +fn emit_usize_array(name: &str, values: &[usize]) -> String { + let entries = values + .iter() + .map(|value| format!(" {value},")) + .collect::>() + .join("\n"); + format!("pub const {name}: &[usize] = &[\n{entries}\n];\n\n") +} + +fn intern_str_array( + source: &mut String, + arrays: &mut Vec<(Vec, String)>, + name_prefix: &str, + values: &[String], +) -> String { + if let Some((_, name)) = arrays + .iter() + .find(|(existing, _)| existing.as_slice() == values) + { + return name.clone(); + } + let name = format!("{name_prefix}_{}", arrays.len()); + source.push_str(&emit_str_array(&name, values)); + arrays.push((values.to_vec(), name.clone())); + name +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn rust_option_str(value: Option<&str>) -> String { + value.map_or_else( + || "None".to_owned(), + |value| format!("Some({})", rust_str(value)), + ) +} + +fn verify_count(kind: &str, symbol: &str, expected: usize, actual: usize) -> Result<(), EmitError> { + if expected == actual { + Ok(()) + } else { + Err(EmitError::new(format!( + "{kind} @{symbol} count mismatch: expected {expected}, got {actual}" + ))) + } +} + +fn symbols<'a>(values: impl Iterator) -> BTreeSet { + values.cloned().collect() +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn int_array_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "integer array"))?; + parse_int_array(&attribute).ok_or_else(|| attr_error(operation, attr, "integer array")) +} + +fn parse_int_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().parse().ok()) + .collect() +} + +fn operand_symbols( + operation: OperationRef<'_, '_>, + start_index: usize, +) -> Result, EmitError> { + (start_index..operation.operand_count()) + .map(|index| operand_symbol(operation, index)) + .collect() +} + +fn operand_symbol(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + EmitError::new(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + EmitError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name<'c: 'a, 'a>(operation: impl OperationLike<'c, 'a>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} diff --git a/crates/bolt/src/protocols/jolt/emit/rust/stage5.rs b/crates/bolt/src/protocols/jolt/emit/rust/stage5.rs new file mode 100644 index 0000000000..30ab3d7edc --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/stage5.rs @@ -0,0 +1,2489 @@ +#![expect( + clippy::needless_raw_string_hashes, + reason = "generated Rust templates are kept as raw string blocks for copyable output" +)] + +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{Attribute, OperationRef}; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5CpuProgram { + pub role: Role, + pub params: Stage5Params, + pub steps: Vec, + pub transcript_squeezes: Vec, + pub transcript_absorb_bytes: Vec, + pub opening_inputs: Vec, + pub field_constants: Vec, + pub field_exprs: Vec, + pub kernels: Vec, + pub claims: Vec, + pub batches: Vec, + pub drivers: Vec, + pub instance_results: Vec, + pub evals: Vec, + pub point_slices: Vec, + pub point_concats: Vec, + pub opening_claims: Vec, + pub opening_equalities: Vec, + pub opening_batches: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5Params { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5KernelPlan { + pub symbol: String, + pub relation: String, + pub kind: String, + pub backend: String, + pub abi: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5TranscriptSqueezePlan { + pub symbol: String, + pub label: String, + pub kind: String, + pub count: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5TranscriptAbsorbBytesPlan { + pub symbol: String, + pub label: String, + pub payload: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5ProgramStepPlan { + pub kind: String, + pub symbol: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5OpeningInputPlan { + pub symbol: String, + pub source_stage: String, + pub source_claim: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5FieldConstantPlan { + pub symbol: String, + pub field: String, + pub value: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5FieldExprPlan { + pub symbol: String, + pub kind: String, + pub formula: String, + pub operand_names: Vec, + pub operands: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckClaimPlan { + pub symbol: String, + pub stage: String, + pub domain: String, + pub num_rounds: usize, + pub degree: usize, + pub claim: String, + pub kernel: Option, + pub relation: Option, + pub claim_value: String, + pub input_openings: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, + pub claim_label: String, + pub round_label: String, + pub round_schedule: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckDriverPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub kernel: Option, + pub relation: Option, + pub batch: String, + pub policy: String, + pub round_schedule: Vec, + pub claim_label: String, + pub round_label: String, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckInstanceResultPlan { + pub symbol: String, + pub source: String, + pub claim: String, + pub relation: String, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: String, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckEvalPlan { + pub symbol: String, + pub source: String, + pub name: String, + pub index: usize, + pub oracle: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5PointSlicePlan { + pub symbol: String, + pub source: String, + pub offset: usize, + pub length: usize, + pub input: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5PointConcatPlan { + pub symbol: String, + pub layout: String, + pub arity: usize, + pub inputs: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5OpeningClaimPlan { + pub symbol: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, + pub point_source: String, + pub eval_source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5OpeningClaimEqualityPlan { + pub symbol: String, + pub mode: String, + pub lhs: String, + pub rhs: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage5OpeningBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, +} + +pub fn stage5_cpu_program(module: &BoltModule<'_, Cpu>) -> Result { + verify_cpu_schema(module)?; + let program = Stage5CpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +pub fn emit_stage5_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = stage5_cpu_program(module)?; + + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source(), + }) +} + +impl Stage5CpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let mut params = None; + let mut steps = Vec::new(); + let mut transcript_squeezes = Vec::new(); + let mut transcript_absorb_bytes = Vec::new(); + let mut opening_inputs = Vec::new(); + let mut field_constants = Vec::new(); + let mut field_exprs = Vec::new(); + let mut kernels = Vec::new(); + let mut claims = Vec::new(); + let mut batches = Vec::new(); + let mut drivers = Vec::new(); + let mut instance_results = Vec::new(); + let mut evals = Vec::new(); + let mut point_slices = Vec::new(); + let mut point_concats = Vec::new(); + let mut opening_claims = Vec::new(); + let mut opening_equalities = Vec::new(); + let mut opening_batches = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(Stage5Params { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.kernel" => { + kernels.push(Stage5KernelPlan { + symbol: string_attr(op, "sym_name")?, + relation: symbol_attr(op, "relation")?, + kind: string_attr(op, "kind")?, + backend: string_attr(op, "backend")?, + abi: string_attr(op, "abi")?, + }); + } + "cpu.transcript_squeeze" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage5ProgramStepPlan { + kind: "transcript_squeeze".to_owned(), + symbol: symbol.clone(), + }); + transcript_squeezes.push(Stage5TranscriptSqueezePlan { + symbol, + label: string_attr(op, "label")?, + kind: string_attr(op, "kind")?, + count: int_attr(op, "count")?, + }); + } + "cpu.transcript_absorb_bytes" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage5ProgramStepPlan { + kind: "transcript_absorb_bytes".to_owned(), + symbol: symbol.clone(), + }); + transcript_absorb_bytes.push(Stage5TranscriptAbsorbBytesPlan { + symbol, + label: string_attr(op, "label")?, + payload: string_attr(op, "payload")?, + }); + } + "cpu.opening_input" => { + opening_inputs.push(Stage5OpeningInputPlan { + symbol: string_attr(op, "sym_name")?, + source_stage: symbol_attr(op, "source_stage")?, + source_claim: symbol_attr(op, "source_claim")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + }); + } + "cpu.field_const" => { + field_constants.push(Stage5FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: int_attr(op, "value")?, + }); + } + "cpu.field_zero" => { + field_constants.push(Stage5FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 0, + }); + } + "cpu.field_one" => { + field_constants.push(Stage5FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 1, + }); + } + "cpu.field_add" | "cpu.field_sub" | "cpu.field_mul" | "cpu.field_neg" => { + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage5FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: operation_name(op).replace("cpu.field_", "field."), + operand_names: operands.clone(), + operands, + }); + } + "cpu.field_pow" => { + let exponent = int_attr(op, "exponent")?; + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage5FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: format!("field.pow:{exponent}"), + operand_names: operands.clone(), + operands, + }); + } + "cpu.sumcheck_claim" => { + claims.push(Stage5SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_verify_claim" => { + claims.push(Stage5SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_batch" => { + batches.push(Stage5SumcheckBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + round_schedule: int_array_attr(op, "round_schedule")?, + }); + } + "cpu.sumcheck_driver" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage5ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage5SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_verify" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage5ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage5SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_instance_result" => { + instance_results.push(Stage5SumcheckInstanceResultPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + claim: symbol_attr(op, "claim")?, + relation: symbol_attr(op, "relation")?, + index: int_attr(op, "index")?, + point_arity: int_attr(op, "point_arity")?, + num_rounds: int_attr(op, "num_rounds")?, + round_offset: int_attr(op, "round_offset")?, + point_order: string_attr(op, "point_order")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_eval" => { + evals.push(Stage5SumcheckEvalPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + name: symbol_attr(op, "name")?, + index: int_attr(op, "index")?, + oracle: symbol_attr(op, "oracle")?, + }); + } + "cpu.point_slice" => { + point_slices.push(Stage5PointSlicePlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + offset: int_attr(op, "offset")?, + length: int_attr(op, "length")?, + input: operand_symbol(op, 0)?, + }); + } + "cpu.point_concat" => { + point_concats.push(Stage5PointConcatPlan { + symbol: string_attr(op, "sym_name")?, + layout: string_attr(op, "layout")?, + arity: int_attr(op, "arity")?, + inputs: operand_symbols(op, 0)?, + }); + } + "cpu.opening_claim" => { + opening_claims.push(Stage5OpeningClaimPlan { + symbol: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + point_source: operand_symbol(op, 0)?, + eval_source: operand_symbol(op, 1)?, + }); + } + "cpu.opening_claim_equal" => { + opening_equalities.push(Stage5OpeningClaimEqualityPlan { + symbol: string_attr(op, "sym_name")?, + mode: string_attr(op, "mode")?, + lhs: operand_symbol(op, 0)?, + rhs: operand_symbol(op, 1)?, + }); + } + "cpu.opening_batch" => { + opening_batches.push(Stage5OpeningBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + }); + } + _ => {} + } + } + + Ok(Self { + params: params.ok_or_else(|| EmitError::new("missing cpu.params"))?, + role: module + .role() + .ok_or_else(|| EmitError::new("missing cpu party role"))?, + steps, + transcript_squeezes, + transcript_absorb_bytes, + opening_inputs, + field_constants, + field_exprs, + kernels, + claims, + batches, + drivers, + instance_results, + evals, + point_slices, + point_concats, + opening_claims, + opening_equalities, + opening_batches, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + require_supported_symbol("field", &self.params.field, "bn254_fr")?; + require_supported_symbol("pcs", &self.params.pcs, "dory")?; + require_supported_symbol("transcript", &self.params.transcript, "blake2b_transcript")?; + self.verify_transcript_steps()?; + self.verify_field_flow()?; + self.verify_claim_batches()?; + match self.role { + Role::Prover => { + self.verify_kernel_definitions()?; + self.verify_prover_driver_bindings()?; + } + Role::Verifier => self.verify_verifier_driver_bindings()?, + } + self.verify_opening_flow() + } + + fn verify_transcript_steps(&self) -> Result<(), EmitError> { + for squeeze in &self.transcript_squeezes { + if !matches!( + squeeze.kind.as_str(), + "challenge_scalar" | "challenge_vector" + ) { + return Err(EmitError::new(format!( + "stage5 transcript squeeze @{} has unsupported kind `{}`", + squeeze.symbol, squeeze.kind + ))); + } + if squeeze.count == 0 { + return Err(EmitError::new(format!( + "stage5 transcript squeeze @{} has zero count", + squeeze.symbol + ))); + } + } + for absorb in &self.transcript_absorb_bytes { + if absorb.label.is_empty() { + return Err(EmitError::new(format!( + "stage5 transcript byte absorb @{} has empty label", + absorb.symbol + ))); + } + } + Ok(()) + } + + fn verify_field_flow(&self) -> Result<(), EmitError> { + for constant in &self.field_constants { + require_supported_symbol("field constant field", &constant.field, "bn254_fr")?; + } + let field_values = self.field_value_symbols(); + for expr in &self.field_exprs { + verify_count( + "field expr operands", + &expr.symbol, + expr.operand_names.len(), + expr.operands.len(), + )?; + for operand in &expr.operands { + if !field_values.contains(operand) { + return Err(EmitError::new(format!( + "field expr @{} references missing field value @{operand}", + expr.symbol + ))); + } + } + } + for claim in &self.claims { + if !field_values.contains(&claim.claim_value) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing claim value @{}", + claim.symbol, claim.claim_value + ))); + } + } + Ok(()) + } + + fn field_value_symbols(&self) -> BTreeSet { + let mut values = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + values.extend(symbols( + self.field_constants.iter().map(|constant| &constant.symbol), + )); + values.extend(symbols( + self.transcript_squeezes + .iter() + .filter(|squeeze| matches!(squeeze.kind.as_str(), "challenge_scalar" | "scalar")) + .map(|squeeze| &squeeze.symbol), + )); + values.extend(symbols(self.field_exprs.iter().map(|expr| &expr.symbol))); + values.extend(symbols(self.evals.iter().map(|eval| &eval.symbol))); + values + } + + fn verify_kernel_definitions(&self) -> Result<(), EmitError> { + for kernel in &self.kernels { + if kernel.backend != "cpu" { + return Err(EmitError::new(format!( + "stage5 kernel @{} targets unsupported backend `{}`", + kernel.symbol, kernel.backend + ))); + } + if kernel.kind != "sumcheck" { + return Err(EmitError::new(format!( + "stage5 kernel @{} has unsupported kind `{}`", + kernel.symbol, kernel.kind + ))); + } + let expected_abi = match kernel.relation.as_str() { + "jolt.stage5.instruction_read_raf" => "jolt_stage5_instruction_read_raf", + "jolt.stage5.ram_ra_claim_reduction" => "jolt_stage5_ram_ra_claim_reduction", + "jolt.stage5.registers_val_evaluation" => "jolt_stage5_registers_val_evaluation", + "jolt.stage5.batched" => "jolt_stage5_batched", + _ => { + return Err(EmitError::new(format!( + "unsupported stage5 kernel relation @{}", + kernel.relation + ))); + } + }; + if kernel.abi != expected_abi { + return Err(EmitError::new(format!( + "stage5 kernel @{} ABI `{}` does not match relation @{}", + kernel.symbol, kernel.abi, kernel.relation + ))); + } + } + Ok(()) + } + + fn verify_claim_batches(&self) -> Result<(), EmitError> { + let claims = symbols(self.claims.iter().map(|claim| &claim.symbol)); + for batch in &self.batches { + verify_count( + "sumcheck batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "sumcheck batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "sumcheck batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !claims.contains(claim) { + return Err(EmitError::new(format!( + "sumcheck batch @{} references missing claim @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn verify_prover_driver_bindings(&self) -> Result<(), EmitError> { + let kernels = symbols(self.kernels.iter().map(|kernel| &kernel.symbol)); + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + let Some(kernel) = claim.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck claim @{} is missing kernel", + claim.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck claim @{} references missing kernel @{kernel}", + claim.symbol + ))); + } + } + for driver in &self.drivers { + let Some(kernel) = driver.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck driver @{} is missing kernel", + driver.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck driver @{} references missing kernel @{kernel}", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_verifier_driver_bindings(&self) -> Result<(), EmitError> { + if !self.kernels.is_empty() { + return Err(EmitError::new( + "verifier stage5 program must not contain kernels", + )); + } + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + if claim.kernel.is_some() || claim.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck claim @{} must carry relation and no kernel", + claim.symbol + ))); + } + } + for driver in &self.drivers { + if driver.kernel.is_some() || driver.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck driver @{} must carry relation and no kernel", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_opening_flow(&self) -> Result<(), EmitError> { + let mut point_sources = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + point_sources.extend(symbols( + self.instance_results + .iter() + .map(|instance| &instance.symbol), + )); + point_sources.extend(symbols( + self.opening_inputs.iter().map(|input| &input.symbol), + )); + point_sources.extend(symbols(self.point_slices.iter().map(|slice| &slice.symbol))); + point_sources.extend(symbols( + self.point_concats.iter().map(|concat| &concat.symbol), + )); + for slice in &self.point_slices { + if !point_sources.contains(&slice.input) { + return Err(EmitError::new(format!( + "point slice @{} uses missing point source @{}", + slice.symbol, slice.input + ))); + } + } + for concat in &self.point_concats { + for input in &concat.inputs { + if !point_sources.contains(input) { + return Err(EmitError::new(format!( + "point concat @{} uses missing point source @{input}", + concat.symbol + ))); + } + } + } + let eval_sources = self.field_value_symbols(); + let mut opening_sources = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + opening_sources.extend(symbols( + self.opening_claims.iter().map(|claim| &claim.symbol), + )); + for equality in &self.opening_equalities { + if !opening_sources.contains(&equality.lhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing lhs opening @{}", + equality.symbol, equality.lhs + ))); + } + if !opening_sources.contains(&equality.rhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing rhs opening @{}", + equality.symbol, equality.rhs + ))); + } + } + for claim in &self.claims { + for input in &claim.input_openings { + if !opening_sources.contains(input) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing opening @{input}", + claim.symbol + ))); + } + } + } + let drivers = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + for instance in &self.instance_results { + if !drivers.contains(&instance.source) { + return Err(EmitError::new(format!( + "sumcheck instance result @{} references missing driver @{}", + instance.symbol, instance.source + ))); + } + } + for eval in &self.evals { + if !drivers.contains(&eval.source) { + return Err(EmitError::new(format!( + "sumcheck eval @{} references missing driver @{}", + eval.symbol, eval.source + ))); + } + } + for claim in &self.opening_claims { + if !point_sources.contains(&claim.point_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing point source @{}", + claim.symbol, claim.point_source + ))); + } + if !eval_sources.contains(&claim.eval_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing eval source @{}", + claim.symbol, claim.eval_source + ))); + } + } + let openings = symbols(self.opening_claims.iter().map(|claim| &claim.symbol)); + for batch in &self.opening_batches { + verify_count( + "opening batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "opening batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "opening batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !openings.contains(claim) { + return Err(EmitError::new(format!( + "opening batch @{} references missing opening @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_stage5.rs", + Role::Verifier => "verify_stage5.rs", + } + } + + fn emit_source(&self) -> String { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + match self.role { + Role::Prover => { + source.push_str(Self::emit_prover_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_prover_types()); + } + Role::Verifier => { + source.push_str(Self::emit_verifier_imports()); + source.push_str("\n\n"); + source.push_str(&Self::emit_verifier_types()); + } + } + source.push('\n'); + source.push_str(&self.emit_constants()); + source.push('\n'); + source.push_str(self.emit_entrypoint()); + source + } + + fn emit_prover_imports() -> &'static str { + "use jolt_field::Fr;\n\ + use jolt_kernels::stage5::{execute_stage5_program, Stage5CpuProgramPlan, Stage5ExecutionArtifacts, Stage5ExecutionMode, Stage5FieldConstantPlan, Stage5FieldExprPlan, Stage5KernelError, Stage5KernelExecutor, Stage5KernelPlan, Stage5OpeningBatchPlan, Stage5OpeningClaimEqualityPlan, Stage5OpeningClaimPlan, Stage5OpeningInputPlan, Stage5Params, Stage5PointConcatPlan, Stage5PointSlicePlan, Stage5ProgramStepPlan, Stage5SumcheckBatchPlan, Stage5SumcheckClaimPlan, Stage5SumcheckDriverPlan, Stage5SumcheckEvalPlan, Stage5SumcheckInstanceResultPlan, Stage5TranscriptAbsorbBytesPlan, Stage5TranscriptSqueezePlan};\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_prover_types() -> &'static str { + "pub type DefaultStage5Transcript = Blake2bTranscript;\n" + } + + fn emit_verifier_imports() -> &'static str { + "use super::common::{batch_claims, eval_by_name, find_batch, find_plan, identity_polynomial_eval, indexed_evals_by_prefix, indexed_evals_by_prefix_any, lt_polynomial_eval, normalize_instruction_read_raf_point, operand_polynomial_eval, reverse_slice, suffix_point};\n\ + use jolt_field::{Field, Fr};\n\ + use jolt_lookup_tables::LookupTableKind;\n\ + use jolt_poly::EqPolynomial;\n\ + use jolt_sumcheck::SumcheckError;\n\ + use jolt_transcript::{Blake2bTranscript, LabelWithCount, Transcript};" + } + + #[expect(dead_code)] + fn emit_types() -> &'static str { + r#"#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5Params { + pub field: &'static str, + pub pcs: &'static str, + pub transcript: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5KernelPlan { + pub symbol: &'static str, + pub relation: &'static str, + pub kind: &'static str, + pub backend: &'static str, + pub abi: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5TranscriptSqueezePlan { + pub symbol: &'static str, + pub label: &'static str, + pub kind: &'static str, + pub count: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5TranscriptAbsorbBytesPlan { + pub symbol: &'static str, + pub label: &'static str, + pub payload: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5ProgramStepPlan { + pub kind: &'static str, + pub symbol: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5OpeningInputPlan { + pub symbol: &'static str, + pub source_stage: &'static str, + pub source_claim: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5FieldConstantPlan { + pub symbol: &'static str, + pub field: &'static str, + pub value: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5FieldExprPlan { + pub symbol: &'static str, + pub kind: &'static str, + pub formula: &'static str, + pub operand_names: &'static [&'static str], + pub operands: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckClaimPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub domain: &'static str, + pub num_rounds: usize, + pub degree: usize, + pub claim: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub claim_value: &'static str, + pub input_openings: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static [&'static str], + pub claim_operands: &'static [&'static str], + pub claim_label: &'static str, + pub round_label: &'static str, + pub round_schedule: &'static [usize], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckDriverPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub batch: &'static str, + pub policy: &'static str, + pub round_schedule: &'static [usize], + pub claim_label: &'static str, + pub round_label: &'static str, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckInstanceResultPlan { + pub symbol: &'static str, + pub source: &'static str, + pub claim: &'static str, + pub relation: &'static str, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: &'static str, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5SumcheckEvalPlan { + pub symbol: &'static str, + pub source: &'static str, + pub name: &'static str, + pub index: usize, + pub oracle: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5PointSlicePlan { + pub symbol: &'static str, + pub source: &'static str, + pub offset: usize, + pub length: usize, + pub input: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5PointConcatPlan { + pub symbol: &'static str, + pub layout: &'static str, + pub arity: usize, + pub inputs: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5OpeningClaimPlan { + pub symbol: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, + pub point_source: &'static str, + pub eval_source: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5OpeningClaimEqualityPlan { + pub symbol: &'static str, + pub mode: &'static str, + pub lhs: &'static str, + pub rhs: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5OpeningBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static [&'static str], + pub claim_operands: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage5CpuProgramPlan { + pub role: &'static str, + pub params: Stage5Params, + pub steps: &'static [Stage5ProgramStepPlan], + pub transcript_squeezes: &'static [Stage5TranscriptSqueezePlan], + pub transcript_absorb_bytes: &'static [Stage5TranscriptAbsorbBytesPlan], + pub opening_inputs: &'static [Stage5OpeningInputPlan], + pub field_constants: &'static [Stage5FieldConstantPlan], + pub field_exprs: &'static [Stage5FieldExprPlan], + pub kernels: &'static [Stage5KernelPlan], + pub claims: &'static [Stage5SumcheckClaimPlan], + pub batches: &'static [Stage5SumcheckBatchPlan], + pub drivers: &'static [Stage5SumcheckDriverPlan], + pub instance_results: &'static [Stage5SumcheckInstanceResultPlan], + pub evals: &'static [Stage5SumcheckEvalPlan], + pub point_slices: &'static [Stage5PointSlicePlan], + pub point_concats: &'static [Stage5PointConcatPlan], + pub opening_claims: &'static [Stage5OpeningClaimPlan], + pub opening_equalities: &'static [Stage5OpeningClaimEqualityPlan], + pub opening_batches: &'static [Stage5OpeningBatchPlan], +} +"# + } + + fn emit_verifier_type_aliases() -> &'static str { + r#"pub type Stage5NamedEval = super::common::StageNamedEval; +pub type Stage5SumcheckOutput = super::common::StageSumcheckOutput; +pub type Stage5ChallengeVector = super::common::StageChallengeVector; +pub type Stage5ExecutionArtifacts = super::common::StageExecutionArtifacts; +pub type Stage5Proof = super::common::StageProof; +pub type Stage5OpeningInputValue = super::common::StageOpeningInputValue; + +pub use super::common::{ + FieldConstantPlan as Stage5FieldConstantPlan, FieldExprPlan as Stage5FieldExprPlan, + KernelPlan as Stage5KernelPlan, OpeningBatchPlan as Stage5OpeningBatchPlan, + OpeningClaimEqualityPlan as Stage5OpeningClaimEqualityPlan, + OpeningClaimPlan as Stage5OpeningClaimPlan, OpeningInputPlan as Stage5OpeningInputPlan, + PointConcatPlan as Stage5PointConcatPlan, PointSlicePlan as Stage5PointSlicePlan, + ProgramStepPlan as Stage5ProgramStepPlan, StageParams as Stage5Params, + StageProgramPlanNoPointZeros as Stage5CpuProgramPlan, + SumcheckBatchPlan as Stage5SumcheckBatchPlan, + SumcheckClaimPlan as Stage5SumcheckClaimPlan, SumcheckDriverPlan as Stage5SumcheckDriverPlan, + SumcheckEvalPlan as Stage5SumcheckEvalPlan, + SumcheckInstanceResultPlan as Stage5SumcheckInstanceResultPlan, + TranscriptAbsorbBytesPlan as Stage5TranscriptAbsorbBytesPlan, + TranscriptSqueezePlan as Stage5TranscriptSqueezePlan, +}; +"# + } + + fn emit_verifier_types() -> String { + let mut source = Self::emit_verifier_type_aliases().to_owned(); + source.push_str( + r#" +pub type DefaultStage5Transcript = Blake2bTranscript; +pub type Stage5VerifierProgramPlan = Stage5CpuProgramPlan; + +#[derive(Debug)] +pub enum VerifyStage5Error { + UnexpectedProofCount { expected: usize, got: usize }, + MissingProof { driver: &'static str }, + MissingBatch { driver: &'static str, batch: &'static str }, + MissingClaim { batch: &'static str, claim: &'static str }, + MissingValue { symbol: &'static str }, + InvalidInputLength { input: &'static str, expected: usize, actual: usize }, + InvalidProof { driver: &'static str, reason: &'static str }, + UnsupportedFieldExpr { symbol: &'static str, formula: &'static str }, + UnsupportedRelation { relation: &'static str }, + Sumcheck { driver: &'static str, error: SumcheckError }, +} + +super::common::impl_runtime_plan_error_conversion!(VerifyStage5Error); +"#, + ); + source + } + + fn emit_constants(&self) -> String { + let mut source = self.emit_shared_constants(); + source.push_str(&self.emit_kernel_constants()); + source.push_str(&self.emit_sumcheck_claim_constants()); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_sumcheck_driver_constants()); + source.push_str(&self.emit_tail_constants()); + push_format( + &mut source, + format_args!( + "pub const STAGE5_PROGRAM: {} = Stage5CpuProgramPlan {{\n\ + \x20 role: {},\n\ + \x20 params: STAGE5_PARAMS,\n\ + \x20 steps: STAGE5_PROGRAM_STEPS,\n\ + \x20 transcript_squeezes: STAGE5_TRANSCRIPT_SQUEEZES,\n\ + \x20 transcript_absorb_bytes: STAGE5_TRANSCRIPT_ABSORB_BYTES,\n\ + \x20 opening_inputs: STAGE5_OPENING_INPUTS,\n\ + \x20 field_constants: STAGE5_FIELD_CONSTANTS,\n\ + \x20 field_exprs: STAGE5_FIELD_EXPRS,\n\ + \x20 kernels: STAGE5_KERNELS,\n\ + \x20 claims: STAGE5_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE5_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE5_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE5_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE5_SUMCHECK_EVALS,\n\ + \x20 point_slices: STAGE5_POINT_SLICES,\n\ + \x20 point_concats: STAGE5_POINT_CONCATS,\n\ + \x20 opening_claims: STAGE5_OPENING_CLAIMS,\n\ + \x20 opening_equalities: STAGE5_OPENING_EQUALITIES,\n\ + \x20 opening_batches: STAGE5_OPENING_BATCHES,\n\ + }};\n", + self.program_plan_type(), + rust_str(self.role_label()) + ), + ); + source + } + + fn emit_shared_constants(&self) -> String { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const STAGE5_PARAMS: Stage5Params = Stage5Params {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + source.push_str(&self.emit_program_step_constants()); + source.push_str(&self.emit_transcript_squeeze_constants()); + source.push_str(&self.emit_transcript_absorb_bytes_constants()); + source.push_str(&self.emit_opening_input_constants()); + source.push_str(&self.emit_field_constant_constants()); + source.push_str(&self.emit_field_expr_constants()); + source + } + + fn emit_program_step_constants(&self) -> String { + let steps = self + .steps + .iter() + .map(|step| { + format!( + " Stage5ProgramStepPlan {{ kind: {}, symbol: {} }},", + rust_str(&step.kind), + rust_str(&step.symbol), + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE5_PROGRAM_STEPS: &[Stage5ProgramStepPlan] = &[\n{steps}\n];\n\n") + } + + fn emit_transcript_squeeze_constants(&self) -> String { + let squeezes = self + .transcript_squeezes + .iter() + .map(|squeeze| { + format!( + " Stage5TranscriptSqueezePlan {{ symbol: {}, label: {}, kind: {}, count: {} }},", + rust_str(&squeeze.symbol), + rust_str(&squeeze.label), + rust_str(&squeeze.kind), + squeeze.count, + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE5_TRANSCRIPT_SQUEEZES: &[Stage5TranscriptSqueezePlan] = &[\n{squeezes}\n];\n\n" + ) + } + + fn emit_transcript_absorb_bytes_constants(&self) -> String { + let absorbs = self + .transcript_absorb_bytes + .iter() + .map(|absorb| { + format!( + " Stage5TranscriptAbsorbBytesPlan {{ symbol: {}, label: {}, payload: {} }},", + rust_str(&absorb.symbol), + rust_str(&absorb.label), + rust_str(&absorb.payload), + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE5_TRANSCRIPT_ABSORB_BYTES: &[Stage5TranscriptAbsorbBytesPlan] = &[\n{absorbs}\n];\n\n" + ) + } + + fn emit_opening_input_constants(&self) -> String { + let inputs = self + .opening_inputs + .iter() + .map(|input| { + format!( + " Stage5OpeningInputPlan {{ symbol: {}, source_stage: {}, source_claim: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {} }},", + rust_str(&input.symbol), + rust_str(&input.source_stage), + rust_str(&input.source_claim), + rust_str(&input.oracle), + rust_str(&input.domain), + input.point_arity, + rust_str(&input.claim_kind) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE5_OPENING_INPUTS: &[Stage5OpeningInputPlan] = &[\n{inputs}\n];\n\n") + } + + fn emit_field_constant_constants(&self) -> String { + let constants = self + .field_constants + .iter() + .map(|constant| { + format!( + " Stage5FieldConstantPlan {{ symbol: {}, field: {}, value: {} }},", + rust_str(&constant.symbol), + rust_str(&constant.field), + constant.value + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE5_FIELD_CONSTANTS: &[Stage5FieldConstantPlan] = &[\n{constants}\n];\n\n" + ) + } + + fn emit_field_expr_constants(&self) -> String { + if self.role == Role::Verifier { + let exprs = self + .field_exprs + .iter() + .map(|expr| { + format!( + " Stage5FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operands: {} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula), + rust_str(&expr.operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE5_FIELD_EXPRS: &[Stage5FieldExprPlan] = &[\n{exprs}\n];\n" + ); + } + + let mut source = String::new(); + let mut arrays = Vec::new(); + let mut array_refs = Vec::new(); + for (index, expr) in self.field_exprs.iter().enumerate() { + let operands = intern_str_array( + &mut source, + &mut arrays, + "STAGE5_FIELD_EXPR_OPERANDS", + &expr.operands, + ); + let operand_names = intern_str_array( + &mut source, + &mut arrays, + "STAGE5_FIELD_EXPR_OPERANDS", + &expr.operand_names, + ); + array_refs.push((index, operand_names, operands)); + } + let exprs = self + .field_exprs + .iter() + .enumerate() + .map(|(index, expr)| { + let (_, operand_names, operands) = &array_refs[index]; + format!( + " Stage5FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operand_names: {operand_names}, operands: {operands} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE5_FIELD_EXPRS: &[Stage5FieldExprPlan] = &[\n{exprs}\n];\n" + ), + ); + source + } + + fn emit_kernel_constants(&self) -> String { + let kernels = self + .kernels + .iter() + .map(|kernel| { + format!( + " Stage5KernelPlan {{ symbol: {}, relation: {}, kind: {}, backend: {}, abi: {} }},", + rust_str(&kernel.symbol), + rust_str(&kernel.relation), + rust_str(&kernel.kind), + rust_str(&kernel.backend), + rust_str(&kernel.abi) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE5_KERNELS: &[Stage5KernelPlan] = &[\n{kernels}\n];\n\n") + } + + fn emit_sumcheck_claim_constants(&self) -> String { + if self.role == Role::Verifier { + let claims = self + .claims + .iter() + .map(|claim| { + format!( + " Stage5SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: {}, relation: {}, claim_value: {}, input_openings: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_option_str(claim.kernel.as_deref()), + rust_option_str(claim.relation.as_deref()), + rust_str(&claim.claim_value), + rust_str(&claim.input_openings.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE5_SUMCHECK_CLAIMS: &[Stage5SumcheckClaimPlan] = &[\n{claims}\n];\n" + ); + } + + let mut source = String::new(); + for (index, claim) in self.claims.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE5_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS"), + &claim.input_openings, + )); + } + let claims = self + .claims + .iter() + .enumerate() + .map(|(index, claim)| { + format!( + " Stage5SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: {}, relation: {}, claim_value: {}, input_openings: STAGE5_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_option_str(claim.kernel.as_deref()), + rust_option_str(claim.relation.as_deref()), + rust_str(&claim.claim_value) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE5_SUMCHECK_CLAIMS: &[Stage5SumcheckClaimPlan] = &[\n{claims}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE5_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage5SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {}, claim_label: {}, round_label: {}, round_schedule: STAGE5_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")), + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE5_SUMCHECK_BATCHES: &[Stage5SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + return source; + } + + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE5_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE5_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + source.push_str(&emit_usize_array( + &format!("STAGE5_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage5SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE5_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE5_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS, claim_label: {}, round_label: {}, round_schedule: STAGE5_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE5_SUMCHECK_BATCHES: &[Stage5SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_driver_constants(&self) -> String { + let mut source = String::new(); + for (index, driver) in self.drivers.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE5_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE"), + &driver.round_schedule, + )); + } + let drivers = self + .drivers + .iter() + .enumerate() + .map(|(index, driver)| { + format!( + " Stage5SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: {}, relation: {}, batch: {}, policy: {}, round_schedule: STAGE5_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_option_str(driver.kernel.as_deref()), + rust_option_str(driver.relation.as_deref()), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE5_SUMCHECK_DRIVERS: &[Stage5SumcheckDriverPlan] = &[\n{drivers}\n];\n" + ), + ); + source + } + + fn emit_tail_constants(&self) -> String { + let mut source = String::new(); + source.push_str(&self.emit_sumcheck_instance_result_constants()); + source.push_str(&self.emit_sumcheck_eval_constants()); + source.push_str(&self.emit_point_slice_constants()); + source.push_str(&self.emit_point_concat_constants()); + source.push_str(&self.emit_opening_claim_constants()); + source.push_str(&self.emit_opening_claim_equality_constants()); + source.push_str(&self.emit_opening_batch_constants()); + source + } + + fn emit_sumcheck_instance_result_constants(&self) -> String { + let instances = self + .instance_results + .iter() + .map(|instance| { + format!( + " Stage5SumcheckInstanceResultPlan {{ symbol: {}, source: {}, claim: {}, relation: {}, index: {}, point_arity: {}, num_rounds: {}, round_offset: {}, point_order: {}, degree: {} }},", + rust_str(&instance.symbol), + rust_str(&instance.source), + rust_str(&instance.claim), + rust_str(&instance.relation), + instance.index, + instance.point_arity, + instance.num_rounds, + instance.round_offset, + rust_str(&instance.point_order), + instance.degree + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE5_SUMCHECK_INSTANCE_RESULTS: &[Stage5SumcheckInstanceResultPlan] = &[\n{instances}\n];\n\n" + ) + } + + fn emit_sumcheck_eval_constants(&self) -> String { + let evals = self + .evals + .iter() + .map(|eval| { + format!( + " Stage5SumcheckEvalPlan {{ symbol: {}, source: {}, name: {}, index: {}, oracle: {} }},", + rust_str(&eval.symbol), + rust_str(&eval.source), + rust_str(&eval.name), + eval.index, + rust_str(&eval.oracle) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE5_SUMCHECK_EVALS: &[Stage5SumcheckEvalPlan] = &[\n{evals}\n];\n\n") + } + + fn emit_point_slice_constants(&self) -> String { + let slices = self + .point_slices + .iter() + .map(|slice| { + format!( + " Stage5PointSlicePlan {{ symbol: {}, source: {}, offset: {}, length: {}, input: {} }},", + rust_str(&slice.symbol), + rust_str(&slice.source), + slice.offset, + slice.length, + rust_str(&slice.input) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE5_POINT_SLICES: &[Stage5PointSlicePlan] = &[\n{slices}\n];\n\n") + } + + fn emit_point_concat_constants(&self) -> String { + if self.role == Role::Verifier { + let concats = self + .point_concats + .iter() + .map(|concat| { + format!( + " Stage5PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: {} }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity, + rust_str(&concat.inputs.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE5_POINT_CONCATS: &[Stage5PointConcatPlan] = &[\n{concats}\n];\n" + ); + } + + let mut source = String::new(); + for (index, concat) in self.point_concats.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE5_POINT_CONCAT_{index}_INPUTS"), + &concat.inputs, + )); + } + let concats = self + .point_concats + .iter() + .enumerate() + .map(|(index, concat)| { + format!( + " Stage5PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: STAGE5_POINT_CONCAT_{index}_INPUTS }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE5_POINT_CONCATS: &[Stage5PointConcatPlan] = &[\n{concats}\n];\n" + ), + ); + source + } + + fn emit_opening_claim_constants(&self) -> String { + let claims = self + .opening_claims + .iter() + .map(|claim| { + format!( + " Stage5OpeningClaimPlan {{ symbol: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {}, point_source: {}, eval_source: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.oracle), + rust_str(&claim.domain), + claim.point_arity, + rust_str(&claim.claim_kind), + rust_str(&claim.point_source), + rust_str(&claim.eval_source) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE5_OPENING_CLAIMS: &[Stage5OpeningClaimPlan] = &[\n{claims}\n];\n\n") + } + + fn emit_opening_claim_equality_constants(&self) -> String { + let equalities = self + .opening_equalities + .iter() + .map(|equality| { + format!( + " Stage5OpeningClaimEqualityPlan {{ symbol: {}, mode: {}, lhs: {}, rhs: {} }},", + rust_str(&equality.symbol), + rust_str(&equality.mode), + rust_str(&equality.lhs), + rust_str(&equality.rhs) + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE5_OPENING_EQUALITIES: &[Stage5OpeningClaimEqualityPlan] = &[\n{equalities}\n];\n\n" + ) + } + + fn emit_opening_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let batches = self + .opening_batches + .iter() + .map(|batch| { + format!( + " Stage5OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {} }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE5_OPENING_BATCHES: &[Stage5OpeningBatchPlan] = &[\n{batches}\n];\n" + ); + } + + let mut source = String::new(); + for (index, batch) in self.opening_batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE5_OPENING_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE5_OPENING_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + } + let batches = self + .opening_batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage5OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE5_OPENING_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE5_OPENING_BATCH_{index}_CLAIM_OPERANDS }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE5_OPENING_BATCHES: &[Stage5OpeningBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_entrypoint(&self) -> &'static str { + match self.role { + Role::Prover => { + "pub fn execute_stage5_prover(\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage5KernelError>\n\ + where\n\ + \x20 E: Stage5KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage5_prover_with_program(&STAGE5_PROGRAM, executor, transcript)\n\ + }\n\ + \n\ + pub fn execute_stage5_prover_with_program(\n\ + \x20 program: &'static Stage5CpuProgramPlan,\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage5KernelError>\n\ + where\n\ + \x20 E: Stage5KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage5_program(program, Stage5ExecutionMode::Prover, executor, transcript)\n\ + }\n" + } + Role::Verifier => { + r#"pub fn verify_stage5( + proof: &Stage5Proof, + opening_inputs: &[Stage5OpeningInputValue], + transcript: &mut T, +) -> Result, VerifyStage5Error> +where + T: Transcript, +{ + verify_stage5_with_program(&STAGE5_PROGRAM, proof, opening_inputs, transcript) +} + +pub fn verify_stage5_with_program( + program: &'static Stage5VerifierProgramPlan, + proof: &Stage5Proof, + opening_inputs: &[Stage5OpeningInputValue], + transcript: &mut T, +) -> Result, VerifyStage5Error> +where + T: Transcript, +{ + if proof.sumchecks.len() != program.drivers.len() { + return Err(VerifyStage5Error::UnexpectedProofCount { + expected: program.drivers.len(), + got: proof.sumchecks.len(), + }); + } + let mut store = + super::common::ValueStore::with_opening_inputs(opening_inputs, program.opening_inputs)?; + store.seed_constants(program.field_constants); + let mut artifacts = Stage5ExecutionArtifacts::default(); + for step in program.steps { + match step.kind { + "transcript_squeeze" => { + let squeeze = + find_plan(program.transcript_squeezes, step.symbol).ok_or(VerifyStage5Error::MissingValue { + symbol: step.symbol, + })?; + verify_stage5_squeeze(program, squeeze, &mut store, transcript, &mut artifacts)?; + } + "transcript_absorb_bytes" => { + let absorb = find_plan(program.transcript_absorb_bytes, step.symbol).ok_or( + VerifyStage5Error::MissingValue { + symbol: step.symbol, + }, + )?; + absorb_stage5_bytes(absorb, transcript); + } + "sumcheck_driver" => { + let driver = + find_plan(program.drivers, step.symbol).ok_or(VerifyStage5Error::MissingProof { + driver: step.symbol, + })?; + verify_stage5_driver(program, driver, proof, &mut store, transcript, &mut artifacts)?; + } + _ => { + return Err(VerifyStage5Error::InvalidProof { + driver: step.symbol, + reason: "unsupported stage5 program step", + }); + } + } + } + artifacts + .opening_batches + .extend(program.opening_batches.iter()); + Ok(artifacts) +} + +pub fn stage5_verifier_program() -> &'static Stage5VerifierProgramPlan { + &STAGE5_PROGRAM +} + +fn verify_stage5_squeeze( + program: &'static Stage5VerifierProgramPlan, + squeeze: &'static Stage5TranscriptSqueezePlan, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage5ExecutionArtifacts, +) -> Result<(), VerifyStage5Error> +where + T: Transcript, +{ + let values = transcript.challenge_vector(squeeze.count); + store.observe_challenge_vector(squeeze, &values, |input, expected, actual| { + VerifyStage5Error::InvalidInputLength { + input, + expected, + actual, + } + })?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage5Error::from)?; + artifacts.challenge_vectors.push(Stage5ChallengeVector { + symbol: squeeze.symbol, + values, + }); + Ok(()) +} + +fn absorb_stage5_bytes(absorb: &'static Stage5TranscriptAbsorbBytesPlan, transcript: &mut T) +where + T: Transcript, +{ + transcript.append(&LabelWithCount( + absorb.label.as_bytes(), + absorb.payload.len() as u64, + )); + transcript.append_bytes(absorb.payload.as_bytes()); +} + +fn verify_stage5_driver( + program: &'static Stage5VerifierProgramPlan, + driver: &'static Stage5SumcheckDriverPlan, + proof: &Stage5Proof, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage5ExecutionArtifacts, +) -> Result<(), VerifyStage5Error> +where + T: Transcript, +{ + let proof = proof + .sumchecks + .get(artifacts.sumchecks.len()) + .ok_or(VerifyStage5Error::MissingProof { + driver: driver.symbol, + })?; + let relation = driver.relation.unwrap_or(""); + let output = match relation { + "jolt.stage5.batched" => { + verify_batched_stage5(program, driver, proof, store, transcript)? + } + _ => return Err(VerifyStage5Error::UnsupportedRelation { relation }), + }; + artifacts.sumchecks.push(output); + Ok(()) +} + +fn verify_batched_stage5( + program: &'static Stage5VerifierProgramPlan, + driver: &'static Stage5SumcheckDriverPlan, + proof: &Stage5SumcheckOutput, + store: &mut super::common::ValueStore, + transcript: &mut T, +) -> Result, VerifyStage5Error> +where + T: Transcript, +{ + super::common::verify_batched_sumcheck( + driver, + proof, + program.claims, + program.batches, + program.field_exprs, + program.opening_inputs, + program.opening_claims, + program.opening_batches, + store, + transcript, + |store, evals, point, batching_coeffs| { + expected_batched_output_claim(program, driver, store, evals, point, batching_coeffs) + }, + |store, verified| observe_stage5_sumcheck_output(program, store, verified), + |driver, error| VerifyStage5Error::Sumcheck { driver, error }, + ) +} + +fn observe_stage5_sumcheck_output( + program: &'static Stage5VerifierProgramPlan, + store: &mut super::common::ValueStore, + output: &Stage5SumcheckOutput, +) -> Result<(), VerifyStage5Error> { + store.observe_sumcheck_output( + program.instance_results, + program.evals, + output, + |instance, mut point| { + match instance.point_order { + "as_is" => {} + "reverse" => point.reverse(), + "instruction_read_raf" => { + point = normalize_instruction_read_raf_point(&point, "stage5.instruction_read_raf.point")?; + } + _ => { + return Err(VerifyStage5Error::InvalidProof { + driver: output.driver, + reason: "unsupported point order", + }); + } + } + Ok(point) + }, + |input, expected, actual| VerifyStage5Error::InvalidInputLength { + input, + expected, + actual, + }, + |symbol| VerifyStage5Error::MissingValue { symbol }, + )?; + store.evaluate_available_points( + program.point_slices, + program.point_concats, + |input, expected, actual| VerifyStage5Error::InvalidInputLength { + input, + expected, + actual, + }, + )?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage5Error::from)?; + store.verify_opening_equalities( + program.opening_equalities, + |driver, reason| VerifyStage5Error::InvalidProof { driver, reason }, + |symbol| VerifyStage5Error::MissingValue { symbol }, + ) +} + +fn expected_batched_output_claim( + program: &'static Stage5VerifierProgramPlan, + driver: &'static Stage5SumcheckDriverPlan, + store: &super::common::ValueStore, + evals: &[Stage5NamedEval], + point: &[Fr], + batching_coeffs: &[Fr], +) -> Result { + let batch = find_batch(program.batches, driver.symbol, driver.batch)?; + let claims = batch_claims(program.claims, batch)?; + let mut expected = Fr::from_u64(0); + for (claim, coefficient) in claims.iter().zip(batching_coeffs) { + let instance = program + .instance_results + .iter() + .find(|instance| instance.claim == claim.symbol && instance.source == driver.symbol) + .ok_or(VerifyStage5Error::MissingClaim { + batch: batch.symbol, + claim: claim.symbol, + })?; + let local_point = point + .get(instance.round_offset..instance.round_offset + instance.num_rounds) + .ok_or(VerifyStage5Error::InvalidInputLength { + input: instance.symbol, + expected: instance.round_offset + instance.num_rounds, + actual: point.len(), + })?; + let relation = claim.relation.unwrap_or(""); + let value = match relation { + "jolt.stage5.instruction_read_raf" => { + expected_instruction_read_raf(store, evals, local_point)? + } + "jolt.stage5.ram_ra_claim_reduction" => { + expected_ram_ra_claim_reduction(store, evals, local_point)? + } + "jolt.stage5.registers_val_evaluation" => { + expected_registers_val_evaluation(store, evals, local_point)? + } + _ => return Err(VerifyStage5Error::UnsupportedRelation { relation }), + }; + expected += *coefficient * value; + } + Ok(expected) +} + +fn expected_instruction_read_raf( + store: &super::common::ValueStore, + evals: &[Stage5NamedEval], + local_point: &[Fr], +) -> Result { + const LOG_K: usize = 128; + const XLEN: usize = 64; + + if local_point.len() < LOG_K { + return Err(VerifyStage5Error::InvalidInputLength { + input: "stage5.instruction_read_raf.point", + expected: LOG_K, + actual: local_point.len(), + }); + } + + let (r_address_prime, r_cycle) = local_point.split_at(LOG_K); + let r_cycle_prime = reverse_slice(r_cycle); + let r_reduction = super::common::store_point(store, "stage5.input.stage2.instruction.LookupOutput")?; + let eq_eval_r_reduction = EqPolynomial::::mle(r_reduction, &r_cycle_prime); + + let left_operand_eval = operand_polynomial_eval(r_address_prime, true); + let right_operand_eval = operand_polynomial_eval(r_address_prime, false); + let identity_poly_eval = identity_polynomial_eval(r_address_prime); + + let table_values = LookupTableKind::::all() + .iter() + .map(|table| table.evaluate_mle::(r_address_prime)) + .collect::>(); + let table_flag_claims = indexed_evals_by_prefix( + evals, + "stage5.instruction_read_raf.eval.LookupTableFlag_", + table_values.len(), + )?; + let val_claim = table_values + .into_iter() + .zip(table_flag_claims) + .map(|(table_value, flag_claim)| table_value * flag_claim) + .sum::(); + + let ra_claim = indexed_evals_by_prefix_any( + evals, + "stage5.instruction_read_raf.eval.InstructionRa_", + )? + .into_iter() + .product::(); + let raf_flag_claim = eval_by_name( + evals, + "stage5.instruction_read_raf.eval.InstructionRafFlag", + )?; + let gamma = super::common::store_scalar(store, "stage5.instruction_read_raf.gamma")?; + + let raf_claim = (Fr::from_u64(1) - raf_flag_claim) + * (left_operand_eval + gamma * right_operand_eval) + + raf_flag_claim * gamma * identity_poly_eval; + Ok(eq_eval_r_reduction * ra_claim * (val_claim + gamma * raf_claim)) +} + +fn expected_ram_ra_claim_reduction( + store: &super::common::ValueStore, + evals: &[Stage5NamedEval], + local_point: &[Fr], +) -> Result { + let r_cycle_reduced = reverse_slice(local_point); + let r_cycle_raf = suffix_point( + super::common::store_point(store, "stage5.input.stage2.ram_raf.RamRa")?, + r_cycle_reduced.len(), + "stage5.input.stage2.ram_raf.RamRa", + )?; + let r_cycle_rw = suffix_point( + super::common::store_point(store, "stage5.input.stage2.ram_read_write.RamRa")?, + r_cycle_reduced.len(), + "stage5.input.stage2.ram_read_write.RamRa", + )?; + let r_cycle_val = suffix_point( + super::common::store_point(store, "stage5.input.stage4.ram_val_check.RamRa")?, + r_cycle_reduced.len(), + "stage5.input.stage4.ram_val_check.RamRa", + )?; + let gamma = super::common::store_scalar(store, "stage5.ram_ra_claim_reduction.gamma")?; + let eq_combined = EqPolynomial::::mle(r_cycle_raf, &r_cycle_reduced) + + gamma * EqPolynomial::::mle(r_cycle_rw, &r_cycle_reduced) + + gamma.square() * EqPolynomial::::mle(r_cycle_val, &r_cycle_reduced); + let ram_ra = eval_by_name(evals, "stage5.ram_ra_claim_reduction.eval.RamRa")?; + Ok(eq_combined * ram_ra) +} + +fn expected_registers_val_evaluation( + store: &super::common::ValueStore, + evals: &[Stage5NamedEval], + local_point: &[Fr], +) -> Result { + let registers_val_point = super::common::store_point(store, "stage5.input.stage4.registers.RegistersVal")?; + let r_cycle = suffix_point( + registers_val_point, + local_point.len(), + "stage5.input.stage4.registers.RegistersVal", + )?; + let r_reduced = reverse_slice(local_point); + let lt_eval = lt_polynomial_eval(&r_reduced, r_cycle); + let rd_inc = eval_by_name(evals, "stage5.registers_val_evaluation.eval.RdInc")?; + let rd_wa = eval_by_name(evals, "stage5.registers_val_evaluation.eval.RdWa")?; + Ok(rd_inc * rd_wa * lt_eval) +} + +"# + } + } + } + + fn role_label(&self) -> &'static str { + match self.role { + Role::Prover => "prover", + Role::Verifier => "verifier", + } + } + + fn program_plan_type(&self) -> &'static str { + match self.role { + Role::Prover => "Stage5CpuProgramPlan", + Role::Verifier => "Stage5VerifierProgramPlan", + } + } +} + +fn require_supported_symbol(kind: &str, actual: &str, expected: &str) -> Result<(), EmitError> { + if actual == expected { + Ok(()) + } else { + Err(EmitError::new(format!( + "unsupported {kind} @{actual}; expected @{expected}" + ))) + } +} + +fn emit_str_array(name: &str, values: &[String]) -> String { + if values.is_empty() { + return format!("pub const {name}: &[&str] = &[];\n\n"); + } + if let [value] = values { + return format!("pub const {name}: &[&str] = &[{}];\n\n", rust_str(value)); + } + let entries = values + .iter() + .map(|value| format!(" {},", rust_str(value))) + .collect::>() + .join("\n"); + format!("pub const {name}: &[&str] = &[\n{entries}\n];\n\n") +} + +fn emit_usize_array(name: &str, values: &[usize]) -> String { + let entries = values + .iter() + .map(|value| format!(" {value},")) + .collect::>() + .join("\n"); + format!("pub const {name}: &[usize] = &[\n{entries}\n];\n\n") +} + +fn intern_str_array( + source: &mut String, + arrays: &mut Vec<(Vec, String)>, + name_prefix: &str, + values: &[String], +) -> String { + if let Some((_, name)) = arrays + .iter() + .find(|(existing, _)| existing.as_slice() == values) + { + return name.clone(); + } + let name = format!("{name_prefix}_{}", arrays.len()); + source.push_str(&emit_str_array(&name, values)); + arrays.push((values.to_vec(), name.clone())); + name +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn rust_option_str(value: Option<&str>) -> String { + value.map_or_else( + || "None".to_owned(), + |value| format!("Some({})", rust_str(value)), + ) +} + +fn verify_count(kind: &str, symbol: &str, expected: usize, actual: usize) -> Result<(), EmitError> { + if expected == actual { + Ok(()) + } else { + Err(EmitError::new(format!( + "{kind} @{symbol} count mismatch: expected {expected}, got {actual}" + ))) + } +} + +fn symbols<'a>(values: impl Iterator) -> BTreeSet { + values.cloned().collect() +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn int_array_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "integer array"))?; + parse_int_array(&attribute).ok_or_else(|| attr_error(operation, attr, "integer array")) +} + +fn parse_int_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().parse().ok()) + .collect() +} + +fn operand_symbols( + operation: OperationRef<'_, '_>, + start_index: usize, +) -> Result, EmitError> { + (start_index..operation.operand_count()) + .map(|index| operand_symbol(operation, index)) + .collect() +} + +fn operand_symbol(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + EmitError::new(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + EmitError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name<'c: 'a, 'a>(operation: impl OperationLike<'c, 'a>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} diff --git a/crates/bolt/src/protocols/jolt/emit/rust/stage6.rs b/crates/bolt/src/protocols/jolt/emit/rust/stage6.rs new file mode 100644 index 0000000000..800eed9731 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/stage6.rs @@ -0,0 +1,2655 @@ +#![expect( + clippy::needless_raw_string_hashes, + reason = "generated Rust templates are kept as raw string blocks for copyable output" +)] + +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{Attribute, OperationRef}; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6CpuProgram { + pub role: Role, + pub params: Stage6Params, + pub steps: Vec, + pub transcript_squeezes: Vec, + pub transcript_absorb_bytes: Vec, + pub opening_inputs: Vec, + pub field_constants: Vec, + pub field_exprs: Vec, + pub kernels: Vec, + pub claims: Vec, + pub batches: Vec, + pub drivers: Vec, + pub instance_results: Vec, + pub evals: Vec, + pub point_zeros: Vec, + pub point_slices: Vec, + pub point_concats: Vec, + pub opening_claims: Vec, + pub opening_equalities: Vec, + pub opening_batches: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6Params { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6KernelPlan { + pub symbol: String, + pub relation: String, + pub kind: String, + pub backend: String, + pub abi: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6TranscriptSqueezePlan { + pub symbol: String, + pub label: String, + pub kind: String, + pub count: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6TranscriptAbsorbBytesPlan { + pub symbol: String, + pub label: String, + pub payload: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6ProgramStepPlan { + pub kind: String, + pub symbol: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6OpeningInputPlan { + pub symbol: String, + pub source_stage: String, + pub source_claim: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6FieldConstantPlan { + pub symbol: String, + pub field: String, + pub value: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6FieldExprPlan { + pub symbol: String, + pub kind: String, + pub formula: String, + pub operand_names: Vec, + pub operands: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckClaimPlan { + pub symbol: String, + pub stage: String, + pub domain: String, + pub num_rounds: usize, + pub degree: usize, + pub claim: String, + pub kernel: Option, + pub relation: Option, + pub claim_value: String, + pub input_openings: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, + pub claim_label: String, + pub round_label: String, + pub round_schedule: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckDriverPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub kernel: Option, + pub relation: Option, + pub batch: String, + pub policy: String, + pub round_schedule: Vec, + pub claim_label: String, + pub round_label: String, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckInstanceResultPlan { + pub symbol: String, + pub source: String, + pub claim: String, + pub relation: String, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: String, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckEvalPlan { + pub symbol: String, + pub source: String, + pub name: String, + pub index: usize, + pub oracle: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6PointZeroPlan { + pub symbol: String, + pub field: String, + pub arity: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6PointSlicePlan { + pub symbol: String, + pub source: String, + pub offset: usize, + pub length: usize, + pub input: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6PointConcatPlan { + pub symbol: String, + pub layout: String, + pub arity: usize, + pub inputs: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6OpeningClaimPlan { + pub symbol: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, + pub point_source: String, + pub eval_source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6OpeningClaimEqualityPlan { + pub symbol: String, + pub mode: String, + pub lhs: String, + pub rhs: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage6OpeningBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, +} + +pub fn stage6_cpu_program(module: &BoltModule<'_, Cpu>) -> Result { + verify_cpu_schema(module)?; + let program = Stage6CpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +pub fn emit_stage6_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = stage6_cpu_program(module)?; + + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source(), + }) +} + +impl Stage6CpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let mut params = None; + let mut steps = Vec::new(); + let mut transcript_squeezes = Vec::new(); + let mut transcript_absorb_bytes = Vec::new(); + let mut opening_inputs = Vec::new(); + let mut field_constants = Vec::new(); + let mut field_exprs = Vec::new(); + let mut kernels = Vec::new(); + let mut claims = Vec::new(); + let mut batches = Vec::new(); + let mut drivers = Vec::new(); + let mut instance_results = Vec::new(); + let mut evals = Vec::new(); + let mut point_zeros = Vec::new(); + let mut point_slices = Vec::new(); + let mut point_concats = Vec::new(); + let mut opening_claims = Vec::new(); + let mut opening_equalities = Vec::new(); + let mut opening_batches = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(Stage6Params { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.kernel" => { + kernels.push(Stage6KernelPlan { + symbol: string_attr(op, "sym_name")?, + relation: symbol_attr(op, "relation")?, + kind: string_attr(op, "kind")?, + backend: string_attr(op, "backend")?, + abi: string_attr(op, "abi")?, + }); + } + "cpu.transcript_squeeze" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage6ProgramStepPlan { + kind: "transcript_squeeze".to_owned(), + symbol: symbol.clone(), + }); + transcript_squeezes.push(Stage6TranscriptSqueezePlan { + symbol, + label: string_attr(op, "label")?, + kind: string_attr(op, "kind")?, + count: int_attr(op, "count")?, + }); + } + "cpu.transcript_absorb_bytes" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage6ProgramStepPlan { + kind: "transcript_absorb_bytes".to_owned(), + symbol: symbol.clone(), + }); + transcript_absorb_bytes.push(Stage6TranscriptAbsorbBytesPlan { + symbol, + label: string_attr(op, "label")?, + payload: string_attr(op, "payload")?, + }); + } + "cpu.opening_input" => { + opening_inputs.push(Stage6OpeningInputPlan { + symbol: string_attr(op, "sym_name")?, + source_stage: symbol_attr(op, "source_stage")?, + source_claim: symbol_attr(op, "source_claim")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + }); + } + "cpu.field_const" => { + field_constants.push(Stage6FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: int_attr(op, "value")?, + }); + } + "cpu.field_zero" => { + field_constants.push(Stage6FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 0, + }); + } + "cpu.field_one" => { + field_constants.push(Stage6FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 1, + }); + } + "cpu.field_add" | "cpu.field_sub" | "cpu.field_mul" | "cpu.field_neg" => { + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage6FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: operation_name(op).replace("cpu.field_", "field."), + operand_names: operands.clone(), + operands, + }); + } + "cpu.field_pow" => { + let exponent = int_attr(op, "exponent")?; + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage6FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: format!("field.pow:{exponent}"), + operand_names: operands.clone(), + operands, + }); + } + "cpu.sumcheck_claim" => { + claims.push(Stage6SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_verify_claim" => { + claims.push(Stage6SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_batch" => { + batches.push(Stage6SumcheckBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + round_schedule: int_array_attr(op, "round_schedule")?, + }); + } + "cpu.sumcheck_driver" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage6ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage6SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_verify" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage6ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage6SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_instance_result" => { + instance_results.push(Stage6SumcheckInstanceResultPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + claim: symbol_attr(op, "claim")?, + relation: symbol_attr(op, "relation")?, + index: int_attr(op, "index")?, + point_arity: int_attr(op, "point_arity")?, + num_rounds: int_attr(op, "num_rounds")?, + round_offset: int_attr(op, "round_offset")?, + point_order: string_attr(op, "point_order")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_eval" => { + evals.push(Stage6SumcheckEvalPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + name: symbol_attr(op, "name")?, + index: int_attr(op, "index")?, + oracle: symbol_attr(op, "oracle")?, + }); + } + "cpu.point_zero" => { + point_zeros.push(Stage6PointZeroPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + arity: int_attr(op, "arity")?, + }); + } + "cpu.point_slice" => { + point_slices.push(Stage6PointSlicePlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + offset: int_attr(op, "offset")?, + length: int_attr(op, "length")?, + input: operand_symbol(op, 0)?, + }); + } + "cpu.point_concat" => { + point_concats.push(Stage6PointConcatPlan { + symbol: string_attr(op, "sym_name")?, + layout: string_attr(op, "layout")?, + arity: int_attr(op, "arity")?, + inputs: operand_symbols(op, 0)?, + }); + } + "cpu.opening_claim" => { + opening_claims.push(Stage6OpeningClaimPlan { + symbol: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + point_source: operand_symbol(op, 0)?, + eval_source: operand_symbol(op, 1)?, + }); + } + "cpu.opening_claim_equal" => { + opening_equalities.push(Stage6OpeningClaimEqualityPlan { + symbol: string_attr(op, "sym_name")?, + mode: string_attr(op, "mode")?, + lhs: operand_symbol(op, 0)?, + rhs: operand_symbol(op, 1)?, + }); + } + "cpu.opening_batch" => { + opening_batches.push(Stage6OpeningBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + }); + } + _ => {} + } + } + + Ok(Self { + params: params.ok_or_else(|| EmitError::new("missing cpu.params"))?, + role: module + .role() + .ok_or_else(|| EmitError::new("missing cpu party role"))?, + steps, + transcript_squeezes, + transcript_absorb_bytes, + opening_inputs, + field_constants, + field_exprs, + kernels, + claims, + batches, + drivers, + instance_results, + evals, + point_zeros, + point_slices, + point_concats, + opening_claims, + opening_equalities, + opening_batches, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + require_supported_symbol("field", &self.params.field, "bn254_fr")?; + require_supported_symbol("pcs", &self.params.pcs, "dory")?; + require_supported_symbol("transcript", &self.params.transcript, "blake2b_transcript")?; + self.verify_transcript_steps()?; + self.verify_field_flow()?; + self.verify_claim_batches()?; + match self.role { + Role::Prover => { + self.verify_kernel_definitions()?; + self.verify_prover_driver_bindings()?; + } + Role::Verifier => self.verify_verifier_driver_bindings()?, + } + self.verify_opening_flow() + } + + fn verify_transcript_steps(&self) -> Result<(), EmitError> { + for squeeze in &self.transcript_squeezes { + if !matches!( + squeeze.kind.as_str(), + "challenge_scalar" | "challenge_vector" + ) { + return Err(EmitError::new(format!( + "stage6 transcript squeeze @{} has unsupported kind `{}`", + squeeze.symbol, squeeze.kind + ))); + } + if squeeze.count == 0 { + return Err(EmitError::new(format!( + "stage6 transcript squeeze @{} has zero count", + squeeze.symbol + ))); + } + } + for absorb in &self.transcript_absorb_bytes { + if absorb.label.is_empty() { + return Err(EmitError::new(format!( + "stage6 transcript byte absorb @{} has empty label", + absorb.symbol + ))); + } + } + Ok(()) + } + + fn verify_field_flow(&self) -> Result<(), EmitError> { + for constant in &self.field_constants { + require_supported_symbol("field constant field", &constant.field, "bn254_fr")?; + } + let field_values = self.field_value_symbols(); + for expr in &self.field_exprs { + verify_count( + "field expr operands", + &expr.symbol, + expr.operand_names.len(), + expr.operands.len(), + )?; + for operand in &expr.operands { + if !field_values.contains(operand) { + return Err(EmitError::new(format!( + "field expr @{} references missing field value @{operand}", + expr.symbol + ))); + } + } + } + for claim in &self.claims { + if !field_values.contains(&claim.claim_value) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing claim value @{}", + claim.symbol, claim.claim_value + ))); + } + } + Ok(()) + } + + fn field_value_symbols(&self) -> BTreeSet { + let mut values = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + values.extend(symbols( + self.field_constants.iter().map(|constant| &constant.symbol), + )); + values.extend(symbols( + self.transcript_squeezes + .iter() + .filter(|squeeze| matches!(squeeze.kind.as_str(), "challenge_scalar" | "scalar")) + .map(|squeeze| &squeeze.symbol), + )); + values.extend(symbols(self.field_exprs.iter().map(|expr| &expr.symbol))); + values.extend(symbols(self.evals.iter().map(|eval| &eval.symbol))); + values + } + + fn verify_kernel_definitions(&self) -> Result<(), EmitError> { + for kernel in &self.kernels { + if kernel.backend != "cpu" { + return Err(EmitError::new(format!( + "stage6 kernel @{} targets unsupported backend `{}`", + kernel.symbol, kernel.backend + ))); + } + if kernel.kind != "sumcheck" { + return Err(EmitError::new(format!( + "stage6 kernel @{} has unsupported kind `{}`", + kernel.symbol, kernel.kind + ))); + } + let expected_abi = match kernel.relation.as_str() { + "jolt.stage6.bytecode_read_raf" => "jolt_stage6_bytecode_read_raf", + "jolt.stage6.booleanity" => "jolt_stage6_booleanity", + "jolt.stage6.hamming_booleanity" => "jolt_stage6_hamming_booleanity", + "jolt.stage6.ram_ra_virtual" => "jolt_stage6_ram_ra_virtual", + "jolt.stage6.instruction_ra_virtual" => "jolt_stage6_instruction_ra_virtual", + "jolt.stage6.inc_claim_reduction" => "jolt_stage6_inc_claim_reduction", + "jolt.stage6.batched" => "jolt_stage6_batched", + _ => { + return Err(EmitError::new(format!( + "unsupported stage6 kernel relation @{}", + kernel.relation + ))); + } + }; + if kernel.abi != expected_abi { + return Err(EmitError::new(format!( + "stage6 kernel @{} ABI `{}` does not match relation @{}", + kernel.symbol, kernel.abi, kernel.relation + ))); + } + } + Ok(()) + } + + fn verify_claim_batches(&self) -> Result<(), EmitError> { + let claims = symbols(self.claims.iter().map(|claim| &claim.symbol)); + for batch in &self.batches { + verify_count( + "sumcheck batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "sumcheck batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "sumcheck batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !claims.contains(claim) { + return Err(EmitError::new(format!( + "sumcheck batch @{} references missing claim @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn verify_prover_driver_bindings(&self) -> Result<(), EmitError> { + let kernels = symbols(self.kernels.iter().map(|kernel| &kernel.symbol)); + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + let Some(kernel) = claim.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck claim @{} is missing kernel", + claim.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck claim @{} references missing kernel @{kernel}", + claim.symbol + ))); + } + } + for driver in &self.drivers { + let Some(kernel) = driver.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck driver @{} is missing kernel", + driver.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck driver @{} references missing kernel @{kernel}", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_verifier_driver_bindings(&self) -> Result<(), EmitError> { + if !self.kernels.is_empty() { + return Err(EmitError::new( + "verifier stage6 program must not contain kernels", + )); + } + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + if claim.kernel.is_some() || claim.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck claim @{} must carry relation and no kernel", + claim.symbol + ))); + } + } + for driver in &self.drivers { + if driver.kernel.is_some() || driver.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck driver @{} must carry relation and no kernel", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_opening_flow(&self) -> Result<(), EmitError> { + let mut point_sources = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + point_sources.extend(symbols( + self.instance_results + .iter() + .map(|instance| &instance.symbol), + )); + point_sources.extend(symbols( + self.opening_inputs.iter().map(|input| &input.symbol), + )); + point_sources.extend(symbols(self.point_zeros.iter().map(|zero| &zero.symbol))); + point_sources.extend(symbols(self.point_slices.iter().map(|slice| &slice.symbol))); + point_sources.extend(symbols( + self.point_concats.iter().map(|concat| &concat.symbol), + )); + for zero in &self.point_zeros { + require_supported_symbol("point zero field", &zero.field, "bn254_fr")?; + } + for slice in &self.point_slices { + if !point_sources.contains(&slice.input) { + return Err(EmitError::new(format!( + "point slice @{} uses missing point source @{}", + slice.symbol, slice.input + ))); + } + } + for concat in &self.point_concats { + for input in &concat.inputs { + if !point_sources.contains(input) { + return Err(EmitError::new(format!( + "point concat @{} uses missing point source @{input}", + concat.symbol + ))); + } + } + } + let eval_sources = self.field_value_symbols(); + let mut opening_sources = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + opening_sources.extend(symbols( + self.opening_claims.iter().map(|claim| &claim.symbol), + )); + for equality in &self.opening_equalities { + if !opening_sources.contains(&equality.lhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing lhs opening @{}", + equality.symbol, equality.lhs + ))); + } + if !opening_sources.contains(&equality.rhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing rhs opening @{}", + equality.symbol, equality.rhs + ))); + } + } + for claim in &self.claims { + for input in &claim.input_openings { + if !opening_sources.contains(input) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing opening @{input}", + claim.symbol + ))); + } + } + } + let drivers = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + for instance in &self.instance_results { + if !drivers.contains(&instance.source) { + return Err(EmitError::new(format!( + "sumcheck instance result @{} references missing driver @{}", + instance.symbol, instance.source + ))); + } + } + for eval in &self.evals { + if !drivers.contains(&eval.source) { + return Err(EmitError::new(format!( + "sumcheck eval @{} references missing driver @{}", + eval.symbol, eval.source + ))); + } + } + for claim in &self.opening_claims { + if !point_sources.contains(&claim.point_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing point source @{}", + claim.symbol, claim.point_source + ))); + } + if !eval_sources.contains(&claim.eval_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing eval source @{}", + claim.symbol, claim.eval_source + ))); + } + } + let openings = symbols(self.opening_claims.iter().map(|claim| &claim.symbol)); + for batch in &self.opening_batches { + verify_count( + "opening batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "opening batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "opening batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !openings.contains(claim) { + return Err(EmitError::new(format!( + "opening batch @{} references missing opening @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_stage6.rs", + Role::Verifier => "verify_stage6.rs", + } + } + + fn emit_source(&self) -> String { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + match self.role { + Role::Prover => { + source.push_str(Self::emit_prover_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_prover_types()); + } + Role::Verifier => { + source.push_str(Self::emit_verifier_imports()); + source.push_str("\n\n"); + source.push_str(&Self::emit_verifier_types()); + } + } + source.push('\n'); + source.push_str(&self.emit_constants()); + source.push('\n'); + source.push_str(self.emit_entrypoint()); + source + } + + fn emit_prover_imports() -> &'static str { + "use jolt_field::Fr;\n\ + use jolt_kernels::stage6::{execute_stage6_program, Stage6CpuProgramPlan, Stage6ExecutionArtifacts, Stage6ExecutionMode, Stage6FieldConstantPlan, Stage6FieldExprPlan, Stage6KernelError, Stage6KernelExecutor, Stage6KernelPlan, Stage6OpeningBatchPlan, Stage6OpeningClaimEqualityPlan, Stage6OpeningClaimPlan, Stage6OpeningInputPlan, Stage6Params, Stage6PointConcatPlan, Stage6PointSlicePlan, Stage6PointZeroPlan, Stage6ProgramStepPlan, Stage6SumcheckBatchPlan, Stage6SumcheckClaimPlan, Stage6SumcheckDriverPlan, Stage6SumcheckEvalPlan, Stage6SumcheckInstanceResultPlan, Stage6TranscriptAbsorbBytesPlan, Stage6TranscriptSqueezePlan};\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_prover_types() -> &'static str { + "pub type DefaultStage6Transcript = Blake2bTranscript;\n" + } + + fn emit_verifier_imports() -> &'static str { + "use super::common::{batch_claims, expected_stage67_booleanity, expected_stage67_bytecode_read_raf, expected_stage67_hamming_booleanity, expected_stage67_inc_claim_reduction, expected_stage67_instruction_ra_virtual, expected_stage67_ram_ra_virtual, find_batch, find_plan, normalize_bytecode_read_raf_point, normalize_instruction_read_raf_point, stage67_trace_rounds, Stage67BytecodeEntry, Stage67BytecodeSymbols, Stage67RelationSymbols};\n\ + use jolt_field::{Field, Fr};\n\ + use jolt_sumcheck::SumcheckError;\n\ + use jolt_transcript::{Blake2bTranscript, LabelWithCount, Transcript};" + } + + #[expect(dead_code)] + fn emit_types() -> &'static str { + r#"#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6Params { + pub field: &'static str, + pub pcs: &'static str, + pub transcript: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6KernelPlan { + pub symbol: &'static str, + pub relation: &'static str, + pub kind: &'static str, + pub backend: &'static str, + pub abi: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6TranscriptSqueezePlan { + pub symbol: &'static str, + pub label: &'static str, + pub kind: &'static str, + pub count: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6TranscriptAbsorbBytesPlan { + pub symbol: &'static str, + pub label: &'static str, + pub payload: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6ProgramStepPlan { + pub kind: &'static str, + pub symbol: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6OpeningInputPlan { + pub symbol: &'static str, + pub source_stage: &'static str, + pub source_claim: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6FieldConstantPlan { + pub symbol: &'static str, + pub field: &'static str, + pub value: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6FieldExprPlan { + pub symbol: &'static str, + pub kind: &'static str, + pub formula: &'static str, + pub operand_names: &'static [&'static str], + pub operands: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckClaimPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub domain: &'static str, + pub num_rounds: usize, + pub degree: usize, + pub claim: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub claim_value: &'static str, + pub input_openings: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static [&'static str], + pub claim_operands: &'static [&'static str], + pub claim_label: &'static str, + pub round_label: &'static str, + pub round_schedule: &'static [usize], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckDriverPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub batch: &'static str, + pub policy: &'static str, + pub round_schedule: &'static [usize], + pub claim_label: &'static str, + pub round_label: &'static str, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckInstanceResultPlan { + pub symbol: &'static str, + pub source: &'static str, + pub claim: &'static str, + pub relation: &'static str, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: &'static str, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6SumcheckEvalPlan { + pub symbol: &'static str, + pub source: &'static str, + pub name: &'static str, + pub index: usize, + pub oracle: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6PointZeroPlan { + pub symbol: &'static str, + pub field: &'static str, + pub arity: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6PointSlicePlan { + pub symbol: &'static str, + pub source: &'static str, + pub offset: usize, + pub length: usize, + pub input: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6PointConcatPlan { + pub symbol: &'static str, + pub layout: &'static str, + pub arity: usize, + pub inputs: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6OpeningClaimPlan { + pub symbol: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, + pub point_source: &'static str, + pub eval_source: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6OpeningClaimEqualityPlan { + pub symbol: &'static str, + pub mode: &'static str, + pub lhs: &'static str, + pub rhs: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6OpeningBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static [&'static str], + pub claim_operands: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage6CpuProgramPlan { + pub role: &'static str, + pub params: Stage6Params, + pub steps: &'static [Stage6ProgramStepPlan], + pub transcript_squeezes: &'static [Stage6TranscriptSqueezePlan], + pub transcript_absorb_bytes: &'static [Stage6TranscriptAbsorbBytesPlan], + pub opening_inputs: &'static [Stage6OpeningInputPlan], + pub field_constants: &'static [Stage6FieldConstantPlan], + pub field_exprs: &'static [Stage6FieldExprPlan], + pub kernels: &'static [Stage6KernelPlan], + pub claims: &'static [Stage6SumcheckClaimPlan], + pub batches: &'static [Stage6SumcheckBatchPlan], + pub drivers: &'static [Stage6SumcheckDriverPlan], + pub instance_results: &'static [Stage6SumcheckInstanceResultPlan], + pub evals: &'static [Stage6SumcheckEvalPlan], + pub point_zeros: &'static [Stage6PointZeroPlan], + pub point_slices: &'static [Stage6PointSlicePlan], + pub point_concats: &'static [Stage6PointConcatPlan], + pub opening_claims: &'static [Stage6OpeningClaimPlan], + pub opening_equalities: &'static [Stage6OpeningClaimEqualityPlan], + pub opening_batches: &'static [Stage6OpeningBatchPlan], +} +"# + } + + fn emit_verifier_type_aliases() -> &'static str { + r#"pub type Stage6NamedEval = super::common::StageNamedEval; +pub type Stage6SumcheckOutput = super::common::StageSumcheckOutput; +pub type Stage6ChallengeVector = super::common::StageChallengeVector; +pub type Stage6ExecutionArtifacts = super::common::StageExecutionArtifacts; +pub type Stage6Proof = super::common::StageProof; +pub type Stage6OpeningInputValue = super::common::StageOpeningInputValue; + +pub use super::common::{ + FieldConstantPlan as Stage6FieldConstantPlan, FieldExprPlan as Stage6FieldExprPlan, + KernelPlan as Stage6KernelPlan, OpeningBatchPlan as Stage6OpeningBatchPlan, + OpeningClaimEqualityPlan as Stage6OpeningClaimEqualityPlan, + OpeningClaimPlan as Stage6OpeningClaimPlan, OpeningInputPlan as Stage6OpeningInputPlan, + PointConcatPlan as Stage6PointConcatPlan, PointSlicePlan as Stage6PointSlicePlan, + PointZeroPlan as Stage6PointZeroPlan, ProgramStepPlan as Stage6ProgramStepPlan, + StageParams as Stage6Params, StageProgramPlan as Stage6CpuProgramPlan, + SumcheckBatchPlan as Stage6SumcheckBatchPlan, + SumcheckClaimPlan as Stage6SumcheckClaimPlan, SumcheckDriverPlan as Stage6SumcheckDriverPlan, + SumcheckEvalPlan as Stage6SumcheckEvalPlan, + SumcheckInstanceResultPlan as Stage6SumcheckInstanceResultPlan, + TranscriptAbsorbBytesPlan as Stage6TranscriptAbsorbBytesPlan, + TranscriptSqueezePlan as Stage6TranscriptSqueezePlan, +}; +"# + } + + fn emit_verifier_types() -> String { + let mut source = Self::emit_verifier_type_aliases().to_owned(); + source.push_str( + r#" +pub type DefaultStage6Transcript = Blake2bTranscript; +pub type Stage6VerifierProgramPlan = Stage6CpuProgramPlan; + +#[derive(Clone, Debug)] +pub struct Stage6BytecodeEntry { + pub address: Fr, + pub imm: Fr, + pub circuit_flags: [bool; 14], + pub rd: Option, + pub rs1: Option, + pub rs2: Option, + pub lookup_table: Option, + pub is_interleaved: bool, + pub is_branch: bool, + pub left_is_rs1: bool, + pub left_is_pc: bool, + pub right_is_rs2: bool, + pub right_is_imm: bool, + pub is_noop: bool, +} + +impl Stage67BytecodeEntry for Stage6BytecodeEntry { + fn address(&self) -> Fr { self.address } + fn imm(&self) -> Fr { self.imm } + fn circuit_flags(&self) -> &[bool; 14] { &self.circuit_flags } + fn rd(&self) -> Option { self.rd } + fn rs1(&self) -> Option { self.rs1 } + fn rs2(&self) -> Option { self.rs2 } + fn lookup_table(&self) -> Option { self.lookup_table } + fn is_interleaved(&self) -> bool { self.is_interleaved } + fn is_branch(&self) -> bool { self.is_branch } + fn left_is_rs1(&self) -> bool { self.left_is_rs1 } + fn left_is_pc(&self) -> bool { self.left_is_pc } + fn right_is_rs2(&self) -> bool { self.right_is_rs2 } + fn right_is_imm(&self) -> bool { self.right_is_imm } + fn is_noop(&self) -> bool { self.is_noop } +} + + +#[derive(Clone, Debug)] +pub struct Stage6BytecodeReadRafData { + pub entries: Vec, + pub entry_bytecode_index: usize, + pub num_lookup_tables: usize, +} + +#[derive(Clone, Debug)] +pub struct Stage6VerifierData { + pub bytecode_read_raf: Option, +} + +const STAGE6_RELATION_SYMBOLS: Stage67RelationSymbols = Stage67RelationSymbols { + hamming_booleanity_relation: "jolt.stage6.hamming_booleanity", + hamming_booleanity_instance: "stage6.hamming_booleanity.instance", + booleanity_point: "stage6.booleanity.point", + stage5_instruction_ra0: "stage6.input.stage5.instruction_read_raf.InstructionRa_0", + booleanity_combined_point: "stage6.booleanity.combined_point", + booleanity_gamma: "stage6.booleanity.gamma", + booleanity_instruction_ra_prefix: "stage6.booleanity.eval.InstructionRa_", + booleanity_bytecode_ra_prefix: "stage6.booleanity.eval.BytecodeRa_", + booleanity_ram_ra_prefix: "stage6.booleanity.eval.RamRa_", + hamming_weight_eval: "stage6.hamming_booleanity.eval.HammingWeight", + hamming_lookup_output: "stage6.input.stage1.LookupOutput", + ram_ra_virtual_cycle: "stage6.input.stage5.ram_ra_claim_reduction.RamRa", + ram_ra_virtual_eval_prefix: "stage6.ram_ra_virtual.eval.RamRa_", + instruction_ra_virtual_cycle: "stage6.input.stage5.instruction_read_raf.InstructionRa_0", + instruction_ra_virtual_eval_prefix: "stage6.instruction_ra_virtual.eval.InstructionRa_", + instruction_ra_virtual_input_prefix: "stage6.input.stage5.instruction_read_raf.InstructionRa_", + instruction_ra_virtual_gamma: "stage6.instruction_ra_virtual.gamma", + inc_ram_stage2: "stage6.input.stage2.ram_read_write.RamInc", + inc_ram_stage4: "stage6.input.stage4.ram_val_check.RamInc", + inc_rd_stage4: "stage6.input.stage4.registers_read_write.RdInc", + inc_rd_stage5: "stage6.input.stage5.registers_val_evaluation.RdInc", + inc_gamma: "stage6.inc_claim_reduction.gamma", + inc_ram_eval: "stage6.inc_claim_reduction.eval.RamInc", + inc_rd_eval: "stage6.inc_claim_reduction.eval.RdInc", +}; + +const STAGE6_BYTECODE_SYMBOLS: Stage67BytecodeSymbols = Stage67BytecodeSymbols { + point: "stage6.bytecode_read_raf.point", + gamma: "stage6.bytecode_read_raf.gamma", + bytecode_ra_eval_prefix: "stage6.bytecode_read_raf.eval.BytecodeRa_", + entries: "stage6.bytecode_read_raf.entries", + entry_bytecode_index: "stage6.bytecode_read_raf.entry_bytecode_index", + stage_gammas: [ + "stage6.bytecode_read_raf.stage1_gamma", + "stage6.bytecode_read_raf.stage2_gamma", + "stage6.bytecode_read_raf.stage3_gamma", + "stage6.bytecode_read_raf.stage4_gamma", + "stage6.bytecode_read_raf.stage5_gamma", + ], + stage_cycle_points: [ + "stage6.input.stage1.Imm", + "stage6.input.stage2.OpFlagJump", + "stage6.input.stage3.spartan_shift.UnexpandedPC", + "stage6.input.stage4.Rs1Ra", + "stage6.input.stage5.registers_val_evaluation.RdWa", + ], + stage4_register_point: "stage6.input.stage4.Rs1Ra", + stage5_register_point: "stage6.input.stage5.registers_val_evaluation.RdWa", + entry_rd: "stage6.bytecode.entry.rd", + entry_rs1: "stage6.bytecode.entry.rs1", + entry_rs2: "stage6.bytecode.entry.rs2", + entry_lookup_table: "stage6.bytecode.entry.lookup_table", +}; + +#[derive(Debug)] +pub enum VerifyStage6Error { + UnexpectedProofCount { expected: usize, got: usize }, + MissingProof { driver: &'static str }, + MissingBatch { driver: &'static str, batch: &'static str }, + MissingClaim { batch: &'static str, claim: &'static str }, + MissingValue { symbol: &'static str }, + InvalidInputLength { input: &'static str, expected: usize, actual: usize }, + InvalidProof { driver: &'static str, reason: &'static str }, + UnsupportedFieldExpr { symbol: &'static str, formula: &'static str }, + UnsupportedRelation { relation: &'static str }, + Sumcheck { driver: &'static str, error: SumcheckError }, +} + +super::common::impl_runtime_plan_error_conversion!(VerifyStage6Error); +"#, + ); + source + } + + fn emit_constants(&self) -> String { + let mut source = self.emit_shared_constants(); + source.push_str(&self.emit_kernel_constants()); + source.push_str(&self.emit_sumcheck_claim_constants()); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_sumcheck_driver_constants()); + source.push_str(&self.emit_tail_constants()); + push_format( + &mut source, + format_args!( + "pub const STAGE6_PROGRAM: {} = Stage6CpuProgramPlan {{\n\ + \x20 role: {},\n\ + \x20 params: STAGE6_PARAMS,\n\ + \x20 steps: STAGE6_PROGRAM_STEPS,\n\ + \x20 transcript_squeezes: STAGE6_TRANSCRIPT_SQUEEZES,\n\ + \x20 transcript_absorb_bytes: STAGE6_TRANSCRIPT_ABSORB_BYTES,\n\ + \x20 opening_inputs: STAGE6_OPENING_INPUTS,\n\ + \x20 field_constants: STAGE6_FIELD_CONSTANTS,\n\ + \x20 field_exprs: STAGE6_FIELD_EXPRS,\n\ + \x20 kernels: STAGE6_KERNELS,\n\ + \x20 claims: STAGE6_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE6_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE6_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE6_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE6_SUMCHECK_EVALS,\n\ + \x20 point_zeros: STAGE6_POINT_ZEROS,\n\ + \x20 point_slices: STAGE6_POINT_SLICES,\n\ + \x20 point_concats: STAGE6_POINT_CONCATS,\n\ + \x20 opening_claims: STAGE6_OPENING_CLAIMS,\n\ + \x20 opening_equalities: STAGE6_OPENING_EQUALITIES,\n\ + \x20 opening_batches: STAGE6_OPENING_BATCHES,\n\ + }};\n", + self.program_plan_type(), + rust_str(self.role_label()) + ), + ); + source + } + + fn emit_shared_constants(&self) -> String { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const STAGE6_PARAMS: Stage6Params = Stage6Params {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + source.push_str(&self.emit_program_step_constants()); + source.push_str(&self.emit_transcript_squeeze_constants()); + source.push_str(&self.emit_transcript_absorb_bytes_constants()); + source.push_str(&self.emit_opening_input_constants()); + source.push_str(&self.emit_field_constant_constants()); + source.push_str(&self.emit_field_expr_constants()); + source + } + + fn emit_program_step_constants(&self) -> String { + let steps = self + .steps + .iter() + .map(|step| { + format!( + " Stage6ProgramStepPlan {{ kind: {}, symbol: {} }},", + rust_str(&step.kind), + rust_str(&step.symbol), + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE6_PROGRAM_STEPS: &[Stage6ProgramStepPlan] = &[\n{steps}\n];\n\n") + } + + fn emit_transcript_squeeze_constants(&self) -> String { + let squeezes = self + .transcript_squeezes + .iter() + .map(|squeeze| { + format!( + " Stage6TranscriptSqueezePlan {{ symbol: {}, label: {}, kind: {}, count: {} }},", + rust_str(&squeeze.symbol), + rust_str(&squeeze.label), + rust_str(&squeeze.kind), + squeeze.count, + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE6_TRANSCRIPT_SQUEEZES: &[Stage6TranscriptSqueezePlan] = &[\n{squeezes}\n];\n\n" + ) + } + + fn emit_transcript_absorb_bytes_constants(&self) -> String { + let absorbs = self + .transcript_absorb_bytes + .iter() + .map(|absorb| { + format!( + " Stage6TranscriptAbsorbBytesPlan {{ symbol: {}, label: {}, payload: {} }},", + rust_str(&absorb.symbol), + rust_str(&absorb.label), + rust_str(&absorb.payload), + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE6_TRANSCRIPT_ABSORB_BYTES: &[Stage6TranscriptAbsorbBytesPlan] = &[\n{absorbs}\n];\n\n" + ) + } + + fn emit_opening_input_constants(&self) -> String { + let inputs = self + .opening_inputs + .iter() + .map(|input| { + format!( + " Stage6OpeningInputPlan {{ symbol: {}, source_stage: {}, source_claim: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {} }},", + rust_str(&input.symbol), + rust_str(&input.source_stage), + rust_str(&input.source_claim), + rust_str(&input.oracle), + rust_str(&input.domain), + input.point_arity, + rust_str(&input.claim_kind) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE6_OPENING_INPUTS: &[Stage6OpeningInputPlan] = &[\n{inputs}\n];\n\n") + } + + fn emit_field_constant_constants(&self) -> String { + let constants = self + .field_constants + .iter() + .map(|constant| { + format!( + " Stage6FieldConstantPlan {{ symbol: {}, field: {}, value: {} }},", + rust_str(&constant.symbol), + rust_str(&constant.field), + constant.value + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE6_FIELD_CONSTANTS: &[Stage6FieldConstantPlan] = &[\n{constants}\n];\n\n" + ) + } + + fn emit_field_expr_constants(&self) -> String { + if self.role == Role::Verifier { + let rows = self + .field_exprs + .chunks(8) + .map(|chunk| { + let exprs = chunk + .iter() + .map(|expr| { + format!( + "stage6_field_expr!({}, {}, {})", + rust_str(&expr.symbol), + rust_str(&expr.formula), + rust_str(&expr.operands.join("|")) + ) + }) + .collect::>() + .join(", "); + format!(" {exprs},") + }) + .collect::>() + .join("\n"); + return format!( + "macro_rules! stage6_field_expr {{\n ($symbol:literal, $formula:literal, $operands:literal) => {{\n Stage6FieldExprPlan {{ symbol: $symbol, kind: \"op\", formula: $formula, operands: $operands }}\n }};\n}}\n\n#[rustfmt::skip]\npub const STAGE6_FIELD_EXPRS: &[Stage6FieldExprPlan] = &[\n{rows}\n];\n" + ); + } + + let mut source = String::new(); + let mut arrays = Vec::new(); + let mut array_refs = Vec::new(); + for (index, expr) in self.field_exprs.iter().enumerate() { + let operands = intern_str_array( + &mut source, + &mut arrays, + "STAGE6_FIELD_EXPR_OPERANDS", + &expr.operands, + ); + let operand_names = intern_str_array( + &mut source, + &mut arrays, + "STAGE6_FIELD_EXPR_OPERANDS", + &expr.operand_names, + ); + array_refs.push((index, operand_names, operands)); + } + let exprs = self + .field_exprs + .iter() + .enumerate() + .map(|(index, expr)| { + let (_, operand_names, operands) = &array_refs[index]; + format!( + " Stage6FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operand_names: {operand_names}, operands: {operands} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE6_FIELD_EXPRS: &[Stage6FieldExprPlan] = &[\n{exprs}\n];\n" + ), + ); + source + } + + fn emit_kernel_constants(&self) -> String { + let kernels = self + .kernels + .iter() + .map(|kernel| { + format!( + " Stage6KernelPlan {{ symbol: {}, relation: {}, kind: {}, backend: {}, abi: {} }},", + rust_str(&kernel.symbol), + rust_str(&kernel.relation), + rust_str(&kernel.kind), + rust_str(&kernel.backend), + rust_str(&kernel.abi) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE6_KERNELS: &[Stage6KernelPlan] = &[\n{kernels}\n];\n\n") + } + + fn emit_sumcheck_claim_constants(&self) -> String { + if self.role == Role::Verifier { + let claims = self + .claims + .iter() + .map(|claim| { + format!( + " Stage6SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: {}, relation: {}, claim_value: {}, input_openings: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_option_str(claim.kernel.as_deref()), + rust_option_str(claim.relation.as_deref()), + rust_str(&claim.claim_value), + rust_str(&claim.input_openings.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE6_SUMCHECK_CLAIMS: &[Stage6SumcheckClaimPlan] = &[\n{claims}\n];\n" + ); + } + + let mut source = String::new(); + for (index, claim) in self.claims.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE6_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS"), + &claim.input_openings, + )); + } + let claims = self + .claims + .iter() + .enumerate() + .map(|(index, claim)| { + format!( + " Stage6SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: {}, relation: {}, claim_value: {}, input_openings: STAGE6_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_option_str(claim.kernel.as_deref()), + rust_option_str(claim.relation.as_deref()), + rust_str(&claim.claim_value) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE6_SUMCHECK_CLAIMS: &[Stage6SumcheckClaimPlan] = &[\n{claims}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE6_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage6SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {}, claim_label: {}, round_label: {}, round_schedule: STAGE6_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")), + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE6_SUMCHECK_BATCHES: &[Stage6SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + return source; + } + + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE6_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE6_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + source.push_str(&emit_usize_array( + &format!("STAGE6_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage6SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE6_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE6_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS, claim_label: {}, round_label: {}, round_schedule: STAGE6_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE6_SUMCHECK_BATCHES: &[Stage6SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_driver_constants(&self) -> String { + let mut source = String::new(); + for (index, driver) in self.drivers.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE6_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE"), + &driver.round_schedule, + )); + } + let drivers = self + .drivers + .iter() + .enumerate() + .map(|(index, driver)| { + format!( + " Stage6SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: {}, relation: {}, batch: {}, policy: {}, round_schedule: STAGE6_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_option_str(driver.kernel.as_deref()), + rust_option_str(driver.relation.as_deref()), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE6_SUMCHECK_DRIVERS: &[Stage6SumcheckDriverPlan] = &[\n{drivers}\n];\n" + ), + ); + source + } + + fn emit_tail_constants(&self) -> String { + let mut source = String::new(); + source.push_str(&self.emit_sumcheck_instance_result_constants()); + source.push_str(&self.emit_sumcheck_eval_constants()); + source.push_str(&self.emit_point_zero_constants()); + source.push_str(&self.emit_point_slice_constants()); + source.push_str(&self.emit_point_concat_constants()); + source.push_str(&self.emit_opening_claim_constants()); + source.push_str(&self.emit_opening_claim_equality_constants()); + source.push_str(&self.emit_opening_batch_constants()); + source + } + + fn emit_sumcheck_instance_result_constants(&self) -> String { + let instances = self + .instance_results + .iter() + .map(|instance| { + format!( + " Stage6SumcheckInstanceResultPlan {{ symbol: {}, source: {}, claim: {}, relation: {}, index: {}, point_arity: {}, num_rounds: {}, round_offset: {}, point_order: {}, degree: {} }},", + rust_str(&instance.symbol), + rust_str(&instance.source), + rust_str(&instance.claim), + rust_str(&instance.relation), + instance.index, + instance.point_arity, + instance.num_rounds, + instance.round_offset, + rust_str(&instance.point_order), + instance.degree + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE6_SUMCHECK_INSTANCE_RESULTS: &[Stage6SumcheckInstanceResultPlan] = &[\n{instances}\n];\n\n" + ) + } + + fn emit_sumcheck_eval_constants(&self) -> String { + let rows = self + .evals + .chunks(4) + .map(|chunk| { + let evals = chunk + .iter() + .map(|eval| { + format!( + "stage6_sumcheck_eval!({}, {}, {}, {}, {})", + rust_str(&eval.symbol), + rust_str(&eval.source), + rust_str(&eval.name), + eval.index, + rust_str(&eval.oracle) + ) + }) + .collect::>() + .join(", "); + format!(" {evals},") + }) + .collect::>() + .join("\n"); + format!( + "macro_rules! stage6_sumcheck_eval {{\n ($symbol:literal, $source:literal, $name:literal, $index:literal, $oracle:literal) => {{\n Stage6SumcheckEvalPlan {{ symbol: $symbol, source: $source, name: $name, index: $index, oracle: $oracle }}\n }};\n}}\n\n#[rustfmt::skip]\npub const STAGE6_SUMCHECK_EVALS: &[Stage6SumcheckEvalPlan] = &[\n{rows}\n];\n\n" + ) + } + + fn emit_point_zero_constants(&self) -> String { + let zeros = self + .point_zeros + .iter() + .map(|zero| { + format!( + " Stage6PointZeroPlan {{ symbol: {}, field: {}, arity: {} }},", + rust_str(&zero.symbol), + rust_str(&zero.field), + zero.arity + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE6_POINT_ZEROS: &[Stage6PointZeroPlan] = &[\n{zeros}\n];\n\n") + } + + fn emit_point_slice_constants(&self) -> String { + let slices = self + .point_slices + .iter() + .map(|slice| { + format!( + " Stage6PointSlicePlan {{ symbol: {}, source: {}, offset: {}, length: {}, input: {} }},", + rust_str(&slice.symbol), + rust_str(&slice.source), + slice.offset, + slice.length, + rust_str(&slice.input) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE6_POINT_SLICES: &[Stage6PointSlicePlan] = &[\n{slices}\n];\n\n") + } + + fn emit_point_concat_constants(&self) -> String { + if self.role == Role::Verifier { + let concats = self + .point_concats + .iter() + .map(|concat| { + format!( + " Stage6PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: {} }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity, + rust_str(&concat.inputs.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE6_POINT_CONCATS: &[Stage6PointConcatPlan] = &[\n{concats}\n];\n" + ); + } + + let mut source = String::new(); + for (index, concat) in self.point_concats.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE6_POINT_CONCAT_{index}_INPUTS"), + &concat.inputs, + )); + } + let concats = self + .point_concats + .iter() + .enumerate() + .map(|(index, concat)| { + format!( + " Stage6PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: STAGE6_POINT_CONCAT_{index}_INPUTS }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE6_POINT_CONCATS: &[Stage6PointConcatPlan] = &[\n{concats}\n];\n" + ), + ); + source + } + + fn emit_opening_claim_constants(&self) -> String { + let claims = self + .opening_claims + .iter() + .map(|claim| { + format!( + " Stage6OpeningClaimPlan {{ symbol: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {}, point_source: {}, eval_source: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.oracle), + rust_str(&claim.domain), + claim.point_arity, + rust_str(&claim.claim_kind), + rust_str(&claim.point_source), + rust_str(&claim.eval_source) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE6_OPENING_CLAIMS: &[Stage6OpeningClaimPlan] = &[\n{claims}\n];\n\n") + } + + fn emit_opening_claim_equality_constants(&self) -> String { + let equalities = self + .opening_equalities + .iter() + .map(|equality| { + format!( + " Stage6OpeningClaimEqualityPlan {{ symbol: {}, mode: {}, lhs: {}, rhs: {} }},", + rust_str(&equality.symbol), + rust_str(&equality.mode), + rust_str(&equality.lhs), + rust_str(&equality.rhs) + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE6_OPENING_EQUALITIES: &[Stage6OpeningClaimEqualityPlan] = &[\n{equalities}\n];\n\n" + ) + } + + fn emit_opening_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let batches = self + .opening_batches + .iter() + .map(|batch| { + format!( + " Stage6OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {} }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE6_OPENING_BATCHES: &[Stage6OpeningBatchPlan] = &[\n{batches}\n];\n" + ); + } + + let mut source = String::new(); + for (index, batch) in self.opening_batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE6_OPENING_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE6_OPENING_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + } + let batches = self + .opening_batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage6OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE6_OPENING_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE6_OPENING_BATCH_{index}_CLAIM_OPERANDS }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE6_OPENING_BATCHES: &[Stage6OpeningBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_entrypoint(&self) -> &'static str { + match self.role { + Role::Prover => { + "pub fn execute_stage6_prover(\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage6KernelError>\n\ + where\n\ + \x20 E: Stage6KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage6_prover_with_program(&STAGE6_PROGRAM, executor, transcript)\n\ + }\n\ + \n\ + pub fn execute_stage6_prover_with_program(\n\ + \x20 program: &'static Stage6CpuProgramPlan,\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage6KernelError>\n\ + where\n\ + \x20 E: Stage6KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage6_program(program, Stage6ExecutionMode::Prover, executor, transcript)\n\ + }\n" + } + Role::Verifier => { + r#"pub fn verify_stage6( + proof: &Stage6Proof, + opening_inputs: &[Stage6OpeningInputValue], + verifier_data: Option<&Stage6VerifierData>, + transcript: &mut T, +) -> Result, VerifyStage6Error> +where + T: Transcript, +{ + verify_stage6_with_program(&STAGE6_PROGRAM, proof, opening_inputs, verifier_data, transcript) +} + +pub fn verify_stage6_with_program( + program: &'static Stage6VerifierProgramPlan, + proof: &Stage6Proof, + opening_inputs: &[Stage6OpeningInputValue], + verifier_data: Option<&Stage6VerifierData>, + transcript: &mut T, +) -> Result, VerifyStage6Error> +where + T: Transcript, +{ + if proof.sumchecks.len() != program.drivers.len() { + return Err(VerifyStage6Error::UnexpectedProofCount { + expected: program.drivers.len(), + got: proof.sumchecks.len(), + }); + } + let mut store = + super::common::ValueStore::with_opening_inputs(opening_inputs, program.opening_inputs)?; + store.seed_constants(program.field_constants); + store.seed_point_zeros(program.point_zeros); + let mut artifacts = Stage6ExecutionArtifacts::default(); + for step in program.steps { + match step.kind { + "transcript_squeeze" => { + let squeeze = + find_plan(program.transcript_squeezes, step.symbol).ok_or(VerifyStage6Error::MissingValue { + symbol: step.symbol, + })?; + verify_stage6_squeeze(program, squeeze, &mut store, transcript, &mut artifacts)?; + } + "transcript_absorb_bytes" => { + let absorb = find_plan(program.transcript_absorb_bytes, step.symbol).ok_or( + VerifyStage6Error::MissingValue { + symbol: step.symbol, + }, + )?; + absorb_stage6_bytes(absorb, transcript); + } + "sumcheck_driver" => { + let driver = + find_plan(program.drivers, step.symbol).ok_or(VerifyStage6Error::MissingProof { + driver: step.symbol, + })?; + verify_stage6_driver( + program, + driver, + proof, + verifier_data, + &mut store, + transcript, + &mut artifacts, + )?; + } + _ => { + return Err(VerifyStage6Error::InvalidProof { + driver: step.symbol, + reason: "unsupported stage6 program step", + }); + } + } + } + artifacts + .opening_batches + .extend(program.opening_batches.iter()); + Ok(artifacts) +} + +pub fn stage6_verifier_program() -> &'static Stage6VerifierProgramPlan { + &STAGE6_PROGRAM +} + +fn verify_stage6_squeeze( + program: &'static Stage6VerifierProgramPlan, + squeeze: &'static Stage6TranscriptSqueezePlan, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage6ExecutionArtifacts, +) -> Result<(), VerifyStage6Error> +where + T: Transcript, +{ + let values = transcript.challenge_vector(squeeze.count); + store.observe_challenge_vector(squeeze, &values, |input, expected, actual| { + VerifyStage6Error::InvalidInputLength { + input, + expected, + actual, + } + })?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage6Error::from)?; + artifacts.challenge_vectors.push(Stage6ChallengeVector { + symbol: squeeze.symbol, + values, + }); + Ok(()) +} + +fn absorb_stage6_bytes(absorb: &'static Stage6TranscriptAbsorbBytesPlan, transcript: &mut T) +where + T: Transcript, +{ + transcript.append(&LabelWithCount( + absorb.label.as_bytes(), + absorb.payload.len() as u64, + )); + transcript.append_bytes(absorb.payload.as_bytes()); +} + +fn verify_stage6_driver( + program: &'static Stage6VerifierProgramPlan, + driver: &'static Stage6SumcheckDriverPlan, + proof: &Stage6Proof, + verifier_data: Option<&Stage6VerifierData>, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage6ExecutionArtifacts, +) -> Result<(), VerifyStage6Error> +where + T: Transcript, +{ + let proof = proof + .sumchecks + .get(artifacts.sumchecks.len()) + .ok_or(VerifyStage6Error::MissingProof { + driver: driver.symbol, + })?; + let relation = driver.relation.unwrap_or(""); + let output = match relation { + "jolt.stage6.batched" => { + verify_batched_stage6(program, driver, proof, verifier_data, store, transcript)? + } + _ => return Err(VerifyStage6Error::UnsupportedRelation { relation }), + }; + artifacts.sumchecks.push(output); + Ok(()) +} + +fn verify_batched_stage6( + program: &'static Stage6VerifierProgramPlan, + driver: &'static Stage6SumcheckDriverPlan, + proof: &Stage6SumcheckOutput, + verifier_data: Option<&Stage6VerifierData>, + store: &mut super::common::ValueStore, + transcript: &mut T, +) -> Result, VerifyStage6Error> +where + T: Transcript, +{ + super::common::verify_batched_sumcheck( + driver, + proof, + program.claims, + program.batches, + program.field_exprs, + program.opening_inputs, + program.opening_claims, + program.opening_batches, + store, + transcript, + |store, evals, point, batching_coeffs| { + expected_batched_output_claim( + program, + driver, + verifier_data, + store, + evals, + point, + batching_coeffs, + ) + }, + |store, verified| observe_stage6_sumcheck_output(program, store, verified), + |driver, error| VerifyStage6Error::Sumcheck { driver, error }, + ) +} + +fn observe_stage6_sumcheck_output( + program: &'static Stage6VerifierProgramPlan, + store: &mut super::common::ValueStore, + output: &Stage6SumcheckOutput, +) -> Result<(), VerifyStage6Error> { + store.observe_sumcheck_output( + program.instance_results, + program.evals, + output, + |instance, mut point| { + match instance.point_order { + "as_is" => {} + "reverse" => point.reverse(), + "bytecode_read_raf" => point = normalize_bytecode_read_raf_point(&point, stage6_trace_rounds(program)?, "stage6.bytecode_read_raf.point")?, + "stage6_booleanity" => {} + "instruction_read_raf" => point = normalize_instruction_read_raf_point(&point, "stage6.instruction_read_raf.point")?, + _ => { + return Err(VerifyStage6Error::InvalidProof { + driver: output.driver, + reason: "unsupported point order", + }); + } + } + Ok(point) + }, + |input, expected, actual| VerifyStage6Error::InvalidInputLength { + input, + expected, + actual, + }, + |symbol| VerifyStage6Error::MissingValue { symbol }, + )?; + store.evaluate_available_points( + program.point_slices, + program.point_concats, + |input, expected, actual| VerifyStage6Error::InvalidInputLength { + input, + expected, + actual, + }, + )?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage6Error::from)?; + store.verify_opening_equalities( + program.opening_equalities, + |driver, reason| VerifyStage6Error::InvalidProof { driver, reason }, + |symbol| VerifyStage6Error::MissingValue { symbol }, + ) +} + +fn expected_batched_output_claim( + program: &'static Stage6VerifierProgramPlan, + driver: &'static Stage6SumcheckDriverPlan, + verifier_data: Option<&Stage6VerifierData>, + store: &super::common::ValueStore, + evals: &[Stage6NamedEval], + point: &[Fr], + batching_coeffs: &[Fr], +) -> Result { + let batch = find_batch(program.batches, driver.symbol, driver.batch)?; + let claims = batch_claims(program.claims, batch)?; + let mut expected = Fr::from_u64(0); + for (claim, coefficient) in claims.iter().zip(batching_coeffs) { + let instance = program + .instance_results + .iter() + .find(|instance| instance.claim == claim.symbol && instance.source == driver.symbol) + .ok_or(VerifyStage6Error::MissingClaim { + batch: batch.symbol, + claim: claim.symbol, + })?; + let local_point = point + .get(instance.round_offset..instance.round_offset + instance.num_rounds) + .ok_or(VerifyStage6Error::InvalidInputLength { + input: instance.symbol, + expected: instance.round_offset + instance.num_rounds, + actual: point.len(), + })?; + let relation = claim.relation.unwrap_or(""); + let value = match relation { + "jolt.stage6.bytecode_read_raf" => { + let data = verifier_data + .and_then(|data| data.bytecode_read_raf.as_ref()) + .ok_or(VerifyStage6Error::MissingValue { + symbol: "stage6.bytecode_read_raf.data", + })?; + expected_bytecode_read_raf(program, data, store, evals, local_point)? + } + "jolt.stage6.booleanity" => { + expected_booleanity(program, store, evals, local_point)? + } + "jolt.stage6.hamming_booleanity" => { + expected_hamming_booleanity(store, evals, local_point)? + } + "jolt.stage6.ram_ra_virtual" => { + expected_ram_ra_virtual(store, evals, local_point)? + } + "jolt.stage6.instruction_ra_virtual" => { + expected_instruction_ra_virtual(program, store, evals, local_point)? + } + "jolt.stage6.inc_claim_reduction" => { + expected_inc_claim_reduction(store, evals, local_point)? + } + _ => return Err(VerifyStage6Error::UnsupportedRelation { relation }), + }; + expected += *coefficient * value; + } + Ok(expected) +} + +fn expected_bytecode_read_raf( + program: &'static Stage6VerifierProgramPlan, + data: &Stage6BytecodeReadRafData, + store: &super::common::ValueStore, + evals: &[Stage6NamedEval], + local_point: &[Fr], +) -> Result { + let log_t = stage6_trace_rounds(program)?; + Ok(expected_stage67_bytecode_read_raf( + &data.entries, + data.entry_bytecode_index, + data.num_lookup_tables, + store, + evals, + local_point, + log_t, + &STAGE6_BYTECODE_SYMBOLS, + )?) +} + +fn expected_booleanity( + program: &'static Stage6VerifierProgramPlan, + store: &super::common::ValueStore, + evals: &[Stage6NamedEval], + local_point: &[Fr], +) -> Result { + let log_t = stage6_trace_rounds(program)?; + Ok(expected_stage67_booleanity(store, evals, local_point, log_t, &STAGE6_RELATION_SYMBOLS)?) +} + +fn expected_hamming_booleanity( + store: &super::common::ValueStore, + evals: &[Stage6NamedEval], + local_point: &[Fr], +) -> Result { + Ok(expected_stage67_hamming_booleanity(store, evals, local_point, &STAGE6_RELATION_SYMBOLS)?) +} + +fn expected_ram_ra_virtual( + store: &super::common::ValueStore, + evals: &[Stage6NamedEval], + local_point: &[Fr], +) -> Result { + Ok(expected_stage67_ram_ra_virtual(store, evals, local_point, &STAGE6_RELATION_SYMBOLS)?) +} + +fn expected_instruction_ra_virtual( + program: &'static Stage6VerifierProgramPlan, + store: &super::common::ValueStore, + evals: &[Stage6NamedEval], + local_point: &[Fr], +) -> Result { + Ok(expected_stage67_instruction_ra_virtual(program.opening_inputs, store, evals, local_point, &STAGE6_RELATION_SYMBOLS)?) +} + +fn expected_inc_claim_reduction( + store: &super::common::ValueStore, + evals: &[Stage6NamedEval], + local_point: &[Fr], +) -> Result { + Ok(expected_stage67_inc_claim_reduction(store, evals, local_point, &STAGE6_RELATION_SYMBOLS)?) +} + +fn stage6_trace_rounds( + program: &'static Stage6VerifierProgramPlan, +) -> Result { + Ok(stage67_trace_rounds(program.instance_results, &STAGE6_RELATION_SYMBOLS)?) +} +"# + } + } + } + + fn role_label(&self) -> &'static str { + match self.role { + Role::Prover => "prover", + Role::Verifier => "verifier", + } + } + + fn program_plan_type(&self) -> &'static str { + match self.role { + Role::Prover => "Stage6CpuProgramPlan", + Role::Verifier => "Stage6VerifierProgramPlan", + } + } +} + +fn require_supported_symbol(kind: &str, actual: &str, expected: &str) -> Result<(), EmitError> { + if actual == expected { + Ok(()) + } else { + Err(EmitError::new(format!( + "unsupported {kind} @{actual}; expected @{expected}" + ))) + } +} + +fn emit_str_array(name: &str, values: &[String]) -> String { + if values.is_empty() { + return format!("pub const {name}: &[&str] = &[];\n\n"); + } + if let [value] = values { + return format!("pub const {name}: &[&str] = &[{}];\n\n", rust_str(value)); + } + let entries = values + .iter() + .map(|value| format!(" {},", rust_str(value))) + .collect::>() + .join("\n"); + format!("pub const {name}: &[&str] = &[\n{entries}\n];\n\n") +} + +fn emit_usize_array(name: &str, values: &[usize]) -> String { + let entries = values + .iter() + .map(|value| format!(" {value},")) + .collect::>() + .join("\n"); + format!("pub const {name}: &[usize] = &[\n{entries}\n];\n\n") +} + +fn intern_str_array( + source: &mut String, + arrays: &mut Vec<(Vec, String)>, + name_prefix: &str, + values: &[String], +) -> String { + if let Some((_, name)) = arrays + .iter() + .find(|(existing, _)| existing.as_slice() == values) + { + return name.clone(); + } + let name = format!("{name_prefix}_{}", arrays.len()); + source.push_str(&emit_str_array(&name, values)); + arrays.push((values.to_vec(), name.clone())); + name +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn rust_option_str(value: Option<&str>) -> String { + value.map_or_else( + || "None".to_owned(), + |value| format!("Some({})", rust_str(value)), + ) +} + +fn verify_count(kind: &str, symbol: &str, expected: usize, actual: usize) -> Result<(), EmitError> { + if expected == actual { + Ok(()) + } else { + Err(EmitError::new(format!( + "{kind} @{symbol} count mismatch: expected {expected}, got {actual}" + ))) + } +} + +fn symbols<'a>(values: impl Iterator) -> BTreeSet { + values.cloned().collect() +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn int_array_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "integer array"))?; + parse_int_array(&attribute).ok_or_else(|| attr_error(operation, attr, "integer array")) +} + +fn parse_int_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().parse().ok()) + .collect() +} + +fn operand_symbols( + operation: OperationRef<'_, '_>, + start_index: usize, +) -> Result, EmitError> { + (start_index..operation.operand_count()) + .map(|index| operand_symbol(operation, index)) + .collect() +} + +fn operand_symbol(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + EmitError::new(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + EmitError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name<'c: 'a, 'a>(operation: impl OperationLike<'c, 'a>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} diff --git a/crates/bolt/src/protocols/jolt/emit/rust/stage7.rs b/crates/bolt/src/protocols/jolt/emit/rust/stage7.rs new file mode 100644 index 0000000000..afc274ff7f --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/stage7.rs @@ -0,0 +1,2529 @@ +#![expect( + clippy::needless_raw_string_hashes, + reason = "generated Rust templates are kept as raw string blocks for copyable output" +)] + +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{Attribute, OperationRef}; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7CpuProgram { + pub role: Role, + pub params: Stage7Params, + pub steps: Vec, + pub transcript_squeezes: Vec, + pub transcript_absorb_bytes: Vec, + pub opening_inputs: Vec, + pub field_constants: Vec, + pub field_exprs: Vec, + pub kernels: Vec, + pub claims: Vec, + pub batches: Vec, + pub drivers: Vec, + pub instance_results: Vec, + pub evals: Vec, + pub point_zeros: Vec, + pub point_slices: Vec, + pub point_concats: Vec, + pub opening_claims: Vec, + pub opening_equalities: Vec, + pub opening_batches: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7Params { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7KernelPlan { + pub symbol: String, + pub relation: String, + pub kind: String, + pub backend: String, + pub abi: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7TranscriptSqueezePlan { + pub symbol: String, + pub label: String, + pub kind: String, + pub count: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7TranscriptAbsorbBytesPlan { + pub symbol: String, + pub label: String, + pub payload: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7ProgramStepPlan { + pub kind: String, + pub symbol: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7OpeningInputPlan { + pub symbol: String, + pub source_stage: String, + pub source_claim: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7FieldConstantPlan { + pub symbol: String, + pub field: String, + pub value: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7FieldExprPlan { + pub symbol: String, + pub kind: String, + pub formula: String, + pub operand_names: Vec, + pub operands: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckClaimPlan { + pub symbol: String, + pub stage: String, + pub domain: String, + pub num_rounds: usize, + pub degree: usize, + pub claim: String, + pub kernel: Option, + pub relation: Option, + pub claim_value: String, + pub input_openings: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, + pub claim_label: String, + pub round_label: String, + pub round_schedule: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckDriverPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub kernel: Option, + pub relation: Option, + pub batch: String, + pub policy: String, + pub round_schedule: Vec, + pub claim_label: String, + pub round_label: String, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckInstanceResultPlan { + pub symbol: String, + pub source: String, + pub claim: String, + pub relation: String, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: String, + pub degree: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckEvalPlan { + pub symbol: String, + pub source: String, + pub name: String, + pub index: usize, + pub oracle: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7PointZeroPlan { + pub symbol: String, + pub field: String, + pub arity: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7PointSlicePlan { + pub symbol: String, + pub source: String, + pub offset: usize, + pub length: usize, + pub input: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7PointConcatPlan { + pub symbol: String, + pub layout: String, + pub arity: usize, + pub inputs: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7OpeningClaimPlan { + pub symbol: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, + pub point_source: String, + pub eval_source: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7OpeningClaimEqualityPlan { + pub symbol: String, + pub mode: String, + pub lhs: String, + pub rhs: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage7OpeningBatchPlan { + pub symbol: String, + pub stage: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, +} + +pub fn stage7_cpu_program(module: &BoltModule<'_, Cpu>) -> Result { + verify_cpu_schema(module)?; + let program = Stage7CpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +pub fn emit_stage7_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = stage7_cpu_program(module)?; + + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source(), + }) +} + +impl Stage7CpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let mut params = None; + let mut steps = Vec::new(); + let mut transcript_squeezes = Vec::new(); + let mut transcript_absorb_bytes = Vec::new(); + let mut opening_inputs = Vec::new(); + let mut field_constants = Vec::new(); + let mut field_exprs = Vec::new(); + let mut kernels = Vec::new(); + let mut claims = Vec::new(); + let mut batches = Vec::new(); + let mut drivers = Vec::new(); + let mut instance_results = Vec::new(); + let mut evals = Vec::new(); + let mut point_zeros = Vec::new(); + let mut point_slices = Vec::new(); + let mut point_concats = Vec::new(); + let mut opening_claims = Vec::new(); + let mut opening_equalities = Vec::new(); + let mut opening_batches = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(Stage7Params { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.kernel" => { + kernels.push(Stage7KernelPlan { + symbol: string_attr(op, "sym_name")?, + relation: symbol_attr(op, "relation")?, + kind: string_attr(op, "kind")?, + backend: string_attr(op, "backend")?, + abi: string_attr(op, "abi")?, + }); + } + "cpu.transcript_squeeze" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage7ProgramStepPlan { + kind: "transcript_squeeze".to_owned(), + symbol: symbol.clone(), + }); + transcript_squeezes.push(Stage7TranscriptSqueezePlan { + symbol, + label: string_attr(op, "label")?, + kind: string_attr(op, "kind")?, + count: int_attr(op, "count")?, + }); + } + "cpu.transcript_absorb_bytes" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage7ProgramStepPlan { + kind: "transcript_absorb_bytes".to_owned(), + symbol: symbol.clone(), + }); + transcript_absorb_bytes.push(Stage7TranscriptAbsorbBytesPlan { + symbol, + label: string_attr(op, "label")?, + payload: string_attr(op, "payload")?, + }); + } + "cpu.opening_input" => { + opening_inputs.push(Stage7OpeningInputPlan { + symbol: string_attr(op, "sym_name")?, + source_stage: symbol_attr(op, "source_stage")?, + source_claim: symbol_attr(op, "source_claim")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + }); + } + "cpu.field_const" => { + field_constants.push(Stage7FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: int_attr(op, "value")?, + }); + } + "cpu.field_zero" => { + field_constants.push(Stage7FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 0, + }); + } + "cpu.field_one" => { + field_constants.push(Stage7FieldConstantPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + value: 1, + }); + } + "cpu.field_add" | "cpu.field_sub" | "cpu.field_mul" | "cpu.field_neg" => { + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage7FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: operation_name(op).replace("cpu.field_", "field."), + operand_names: operands.clone(), + operands, + }); + } + "cpu.field_pow" => { + let exponent = int_attr(op, "exponent")?; + let operands = operand_symbols(op, 0)?; + field_exprs.push(Stage7FieldExprPlan { + symbol: string_attr(op, "sym_name")?, + kind: "op".to_owned(), + formula: format!("field.pow:{exponent}"), + operand_names: operands.clone(), + operands, + }); + } + "cpu.sumcheck_claim" => { + claims.push(Stage7SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_verify_claim" => { + claims.push(Stage7SumcheckClaimPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + domain: symbol_attr(op, "domain")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + claim: symbol_attr(op, "claim")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + claim_value: operand_symbol(op, 0)?, + input_openings: operand_symbols(op, 1)?, + }); + } + "cpu.sumcheck_batch" => { + batches.push(Stage7SumcheckBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + round_schedule: int_array_attr(op, "round_schedule")?, + }); + } + "cpu.sumcheck_driver" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage7ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage7SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: Some(symbol_attr(op, "kernel")?), + relation: None, + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_verify" => { + let symbol = string_attr(op, "sym_name")?; + steps.push(Stage7ProgramStepPlan { + kind: "sumcheck_driver".to_owned(), + symbol: symbol.clone(), + }); + drivers.push(Stage7SumcheckDriverPlan { + symbol, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + kernel: None, + relation: Some(symbol_attr(op, "relation")?), + batch: operand_symbol(op, 1)?, + policy: string_attr(op, "policy")?, + round_schedule: int_array_attr(op, "round_schedule")?, + claim_label: string_attr(op, "claim_label")?, + round_label: string_attr(op, "round_label")?, + num_rounds: int_attr(op, "num_rounds")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_instance_result" => { + instance_results.push(Stage7SumcheckInstanceResultPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + claim: symbol_attr(op, "claim")?, + relation: symbol_attr(op, "relation")?, + index: int_attr(op, "index")?, + point_arity: int_attr(op, "point_arity")?, + num_rounds: int_attr(op, "num_rounds")?, + round_offset: int_attr(op, "round_offset")?, + point_order: string_attr(op, "point_order")?, + degree: int_attr(op, "degree")?, + }); + } + "cpu.sumcheck_eval" => { + evals.push(Stage7SumcheckEvalPlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + name: symbol_attr(op, "name")?, + index: int_attr(op, "index")?, + oracle: symbol_attr(op, "oracle")?, + }); + } + "cpu.point_zero" => { + point_zeros.push(Stage7PointZeroPlan { + symbol: string_attr(op, "sym_name")?, + field: symbol_attr(op, "field")?, + arity: int_attr(op, "arity")?, + }); + } + "cpu.point_slice" => { + point_slices.push(Stage7PointSlicePlan { + symbol: string_attr(op, "sym_name")?, + source: symbol_attr(op, "source")?, + offset: int_attr(op, "offset")?, + length: int_attr(op, "length")?, + input: operand_symbol(op, 0)?, + }); + } + "cpu.point_concat" => { + point_concats.push(Stage7PointConcatPlan { + symbol: string_attr(op, "sym_name")?, + layout: string_attr(op, "layout")?, + arity: int_attr(op, "arity")?, + inputs: operand_symbols(op, 0)?, + }); + } + "cpu.opening_claim" => { + opening_claims.push(Stage7OpeningClaimPlan { + symbol: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + point_source: operand_symbol(op, 0)?, + eval_source: operand_symbol(op, 1)?, + }); + } + "cpu.opening_claim_equal" => { + opening_equalities.push(Stage7OpeningClaimEqualityPlan { + symbol: string_attr(op, "sym_name")?, + mode: string_attr(op, "mode")?, + lhs: operand_symbol(op, 0)?, + rhs: operand_symbol(op, 1)?, + }); + } + "cpu.opening_batch" => { + opening_batches.push(Stage7OpeningBatchPlan { + symbol: string_attr(op, "sym_name")?, + stage: symbol_attr(op, "stage")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + }); + } + _ => {} + } + } + + Ok(Self { + params: params.ok_or_else(|| EmitError::new("missing cpu.params"))?, + role: module + .role() + .ok_or_else(|| EmitError::new("missing cpu party role"))?, + steps, + transcript_squeezes, + transcript_absorb_bytes, + opening_inputs, + field_constants, + field_exprs, + kernels, + claims, + batches, + drivers, + instance_results, + evals, + point_zeros, + point_slices, + point_concats, + opening_claims, + opening_equalities, + opening_batches, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + require_supported_symbol("field", &self.params.field, "bn254_fr")?; + require_supported_symbol("pcs", &self.params.pcs, "dory")?; + require_supported_symbol("transcript", &self.params.transcript, "blake2b_transcript")?; + self.verify_transcript_steps()?; + self.verify_field_flow()?; + self.verify_claim_batches()?; + match self.role { + Role::Prover => { + self.verify_kernel_definitions()?; + self.verify_prover_driver_bindings()?; + } + Role::Verifier => self.verify_verifier_driver_bindings()?, + } + self.verify_opening_flow() + } + + fn verify_transcript_steps(&self) -> Result<(), EmitError> { + for squeeze in &self.transcript_squeezes { + if !matches!( + squeeze.kind.as_str(), + "challenge_scalar" | "challenge_vector" + ) { + return Err(EmitError::new(format!( + "stage7 transcript squeeze @{} has unsupported kind `{}`", + squeeze.symbol, squeeze.kind + ))); + } + if squeeze.count == 0 { + return Err(EmitError::new(format!( + "stage7 transcript squeeze @{} has zero count", + squeeze.symbol + ))); + } + } + for absorb in &self.transcript_absorb_bytes { + if absorb.label.is_empty() { + return Err(EmitError::new(format!( + "stage7 transcript byte absorb @{} has empty label", + absorb.symbol + ))); + } + } + Ok(()) + } + + fn verify_field_flow(&self) -> Result<(), EmitError> { + for constant in &self.field_constants { + require_supported_symbol("field constant field", &constant.field, "bn254_fr")?; + } + let field_values = self.field_value_symbols(); + for expr in &self.field_exprs { + verify_count( + "field expr operands", + &expr.symbol, + expr.operand_names.len(), + expr.operands.len(), + )?; + for operand in &expr.operands { + if !field_values.contains(operand) { + return Err(EmitError::new(format!( + "field expr @{} references missing field value @{operand}", + expr.symbol + ))); + } + } + } + for claim in &self.claims { + if !field_values.contains(&claim.claim_value) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing claim value @{}", + claim.symbol, claim.claim_value + ))); + } + } + Ok(()) + } + + fn field_value_symbols(&self) -> BTreeSet { + let mut values = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + values.extend(symbols( + self.field_constants.iter().map(|constant| &constant.symbol), + )); + values.extend(symbols( + self.transcript_squeezes + .iter() + .filter(|squeeze| matches!(squeeze.kind.as_str(), "challenge_scalar" | "scalar")) + .map(|squeeze| &squeeze.symbol), + )); + values.extend(symbols(self.field_exprs.iter().map(|expr| &expr.symbol))); + values.extend(symbols(self.evals.iter().map(|eval| &eval.symbol))); + values + } + + fn verify_kernel_definitions(&self) -> Result<(), EmitError> { + for kernel in &self.kernels { + if kernel.backend != "cpu" { + return Err(EmitError::new(format!( + "stage7 kernel @{} targets unsupported backend `{}`", + kernel.symbol, kernel.backend + ))); + } + if kernel.kind != "sumcheck" { + return Err(EmitError::new(format!( + "stage7 kernel @{} has unsupported kind `{}`", + kernel.symbol, kernel.kind + ))); + } + let expected_abi = match kernel.relation.as_str() { + "jolt.stage7.hamming_weight_claim_reduction" => { + "jolt_stage7_hamming_weight_claim_reduction" + } + "jolt.stage7.batched" => "jolt_stage7_batched", + _ => { + return Err(EmitError::new(format!( + "unsupported stage7 kernel relation @{}", + kernel.relation + ))); + } + }; + if kernel.abi != expected_abi { + return Err(EmitError::new(format!( + "stage7 kernel @{} ABI `{}` does not match relation @{}", + kernel.symbol, kernel.abi, kernel.relation + ))); + } + } + Ok(()) + } + + fn verify_claim_batches(&self) -> Result<(), EmitError> { + let claims = symbols(self.claims.iter().map(|claim| &claim.symbol)); + for batch in &self.batches { + verify_count( + "sumcheck batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "sumcheck batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "sumcheck batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !claims.contains(claim) { + return Err(EmitError::new(format!( + "sumcheck batch @{} references missing claim @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn verify_prover_driver_bindings(&self) -> Result<(), EmitError> { + let kernels = symbols(self.kernels.iter().map(|kernel| &kernel.symbol)); + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + let Some(kernel) = claim.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck claim @{} is missing kernel", + claim.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck claim @{} references missing kernel @{kernel}", + claim.symbol + ))); + } + } + for driver in &self.drivers { + let Some(kernel) = driver.kernel.as_deref() else { + return Err(EmitError::new(format!( + "prover sumcheck driver @{} is missing kernel", + driver.symbol + ))); + }; + if !kernels.contains(kernel) { + return Err(EmitError::new(format!( + "sumcheck driver @{} references missing kernel @{kernel}", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_verifier_driver_bindings(&self) -> Result<(), EmitError> { + if !self.kernels.is_empty() { + return Err(EmitError::new( + "verifier stage7 program must not contain kernels", + )); + } + let batches: BTreeMap<_, _> = self + .batches + .iter() + .map(|batch| (batch.symbol.as_str(), batch)) + .collect(); + for claim in &self.claims { + if claim.kernel.is_some() || claim.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck claim @{} must carry relation and no kernel", + claim.symbol + ))); + } + } + for driver in &self.drivers { + if driver.kernel.is_some() || driver.relation.is_none() { + return Err(EmitError::new(format!( + "verifier sumcheck driver @{} must carry relation and no kernel", + driver.symbol + ))); + } + let batch = batches.get(driver.batch.as_str()).ok_or_else(|| { + EmitError::new(format!( + "sumcheck driver @{} references missing batch @{}", + driver.symbol, driver.batch + )) + })?; + verify_count( + "sumcheck driver round_schedule", + &driver.symbol, + driver.num_rounds, + driver.round_schedule.iter().sum(), + )?; + if driver.round_schedule != batch.round_schedule { + return Err(EmitError::new(format!( + "sumcheck driver @{} round_schedule differs from batch @{}", + driver.symbol, batch.symbol + ))); + } + } + Ok(()) + } + + fn verify_opening_flow(&self) -> Result<(), EmitError> { + let mut point_sources = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + point_sources.extend(symbols( + self.instance_results + .iter() + .map(|instance| &instance.symbol), + )); + point_sources.extend(symbols( + self.opening_inputs.iter().map(|input| &input.symbol), + )); + point_sources.extend(symbols(self.point_zeros.iter().map(|zero| &zero.symbol))); + point_sources.extend(symbols(self.point_slices.iter().map(|slice| &slice.symbol))); + point_sources.extend(symbols( + self.point_concats.iter().map(|concat| &concat.symbol), + )); + for zero in &self.point_zeros { + require_supported_symbol("point zero field", &zero.field, "bn254_fr")?; + } + for slice in &self.point_slices { + if !point_sources.contains(&slice.input) { + return Err(EmitError::new(format!( + "point slice @{} uses missing point source @{}", + slice.symbol, slice.input + ))); + } + } + for concat in &self.point_concats { + for input in &concat.inputs { + if !point_sources.contains(input) { + return Err(EmitError::new(format!( + "point concat @{} uses missing point source @{input}", + concat.symbol + ))); + } + } + } + let eval_sources = self.field_value_symbols(); + let mut opening_sources = symbols(self.opening_inputs.iter().map(|input| &input.symbol)); + opening_sources.extend(symbols( + self.opening_claims.iter().map(|claim| &claim.symbol), + )); + for equality in &self.opening_equalities { + if !opening_sources.contains(&equality.lhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing lhs opening @{}", + equality.symbol, equality.lhs + ))); + } + if !opening_sources.contains(&equality.rhs) { + return Err(EmitError::new(format!( + "opening equality @{} uses missing rhs opening @{}", + equality.symbol, equality.rhs + ))); + } + } + for claim in &self.claims { + for input in &claim.input_openings { + if !opening_sources.contains(input) { + return Err(EmitError::new(format!( + "sumcheck claim @{} uses missing opening @{input}", + claim.symbol + ))); + } + } + } + let drivers = symbols(self.drivers.iter().map(|driver| &driver.symbol)); + for instance in &self.instance_results { + if !drivers.contains(&instance.source) { + return Err(EmitError::new(format!( + "sumcheck instance result @{} references missing driver @{}", + instance.symbol, instance.source + ))); + } + } + for eval in &self.evals { + if !drivers.contains(&eval.source) { + return Err(EmitError::new(format!( + "sumcheck eval @{} references missing driver @{}", + eval.symbol, eval.source + ))); + } + } + for claim in &self.opening_claims { + if !point_sources.contains(&claim.point_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing point source @{}", + claim.symbol, claim.point_source + ))); + } + if !eval_sources.contains(&claim.eval_source) { + return Err(EmitError::new(format!( + "opening claim @{} uses missing eval source @{}", + claim.symbol, claim.eval_source + ))); + } + } + let openings = symbols(self.opening_claims.iter().map(|claim| &claim.symbol)); + for batch in &self.opening_batches { + verify_count( + "opening batch", + &batch.symbol, + batch.count, + batch.ordered_claims.len(), + )?; + verify_count( + "opening batch operands", + &batch.symbol, + batch.count, + batch.claim_operands.len(), + )?; + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new(format!( + "opening batch @{} operand order does not match ordered_claims", + batch.symbol + ))); + } + for claim in &batch.ordered_claims { + if !openings.contains(claim) { + return Err(EmitError::new(format!( + "opening batch @{} references missing opening @{claim}", + batch.symbol + ))); + } + } + } + Ok(()) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_stage7.rs", + Role::Verifier => "verify_stage7.rs", + } + } + + fn emit_source(&self) -> String { + let mut source = String::new(); + source.push_str("#![allow(dead_code)]\n\n"); + match self.role { + Role::Prover => { + source.push_str(Self::emit_prover_imports()); + source.push_str("\n\n"); + source.push_str(Self::emit_prover_types()); + } + Role::Verifier => { + source.push_str(Self::emit_verifier_imports()); + source.push_str("\n\n"); + source.push_str(&Self::emit_verifier_types()); + } + } + source.push('\n'); + source.push_str(&self.emit_constants()); + source.push('\n'); + source.push_str(self.emit_entrypoint()); + source + } + + fn emit_prover_imports() -> &'static str { + "use jolt_field::Fr;\n\ + use jolt_kernels::stage7::{execute_stage7_program, Stage7CpuProgramPlan, Stage7ExecutionArtifacts, Stage7ExecutionMode, Stage7FieldConstantPlan, Stage7FieldExprPlan, Stage7KernelError, Stage7KernelExecutor, Stage7KernelPlan, Stage7OpeningBatchPlan, Stage7OpeningClaimEqualityPlan, Stage7OpeningClaimPlan, Stage7OpeningInputPlan, Stage7Params, Stage7PointConcatPlan, Stage7PointSlicePlan, Stage7PointZeroPlan, Stage7ProgramStepPlan, Stage7SumcheckBatchPlan, Stage7SumcheckClaimPlan, Stage7SumcheckDriverPlan, Stage7SumcheckEvalPlan, Stage7SumcheckInstanceResultPlan, Stage7TranscriptAbsorbBytesPlan, Stage7TranscriptSqueezePlan};\n\ + use jolt_transcript::{Blake2bTranscript, Transcript};" + } + + fn emit_prover_types() -> &'static str { + "pub type DefaultStage7Transcript = Blake2bTranscript;\n" + } + + fn emit_verifier_imports() -> &'static str { + "use super::common::{batch_claims, eval_by_name, find_batch, find_plan, normalize_bytecode_read_raf_point, normalize_instruction_read_raf_point, reverse_slice};\n\ + use jolt_field::{Field, Fr};\n\ + use jolt_poly::EqPolynomial;\n\ + use jolt_sumcheck::SumcheckError;\n\ + use jolt_transcript::{Blake2bTranscript, LabelWithCount, Transcript};" + } + + #[expect(dead_code)] + fn emit_types() -> &'static str { + r#"#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7Params { + pub field: &'static str, + pub pcs: &'static str, + pub transcript: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7KernelPlan { + pub symbol: &'static str, + pub relation: &'static str, + pub kind: &'static str, + pub backend: &'static str, + pub abi: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7TranscriptSqueezePlan { + pub symbol: &'static str, + pub label: &'static str, + pub kind: &'static str, + pub count: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7TranscriptAbsorbBytesPlan { + pub symbol: &'static str, + pub label: &'static str, + pub payload: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7ProgramStepPlan { + pub kind: &'static str, + pub symbol: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7OpeningInputPlan { + pub symbol: &'static str, + pub source_stage: &'static str, + pub source_claim: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7FieldConstantPlan { + pub symbol: &'static str, + pub field: &'static str, + pub value: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7FieldExprPlan { + pub symbol: &'static str, + pub kind: &'static str, + pub formula: &'static str, + pub operand_names: &'static [&'static str], + pub operands: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckClaimPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub domain: &'static str, + pub num_rounds: usize, + pub degree: usize, + pub claim: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub claim_value: &'static str, + pub input_openings: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static [&'static str], + pub claim_operands: &'static [&'static str], + pub claim_label: &'static str, + pub round_label: &'static str, + pub round_schedule: &'static [usize], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckDriverPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub batch: &'static str, + pub policy: &'static str, + pub round_schedule: &'static [usize], + pub claim_label: &'static str, + pub round_label: &'static str, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckInstanceResultPlan { + pub symbol: &'static str, + pub source: &'static str, + pub claim: &'static str, + pub relation: &'static str, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: &'static str, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7SumcheckEvalPlan { + pub symbol: &'static str, + pub source: &'static str, + pub name: &'static str, + pub index: usize, + pub oracle: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7PointZeroPlan { + pub symbol: &'static str, + pub field: &'static str, + pub arity: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7PointSlicePlan { + pub symbol: &'static str, + pub source: &'static str, + pub offset: usize, + pub length: usize, + pub input: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7PointConcatPlan { + pub symbol: &'static str, + pub layout: &'static str, + pub arity: usize, + pub inputs: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7OpeningClaimPlan { + pub symbol: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, + pub point_source: &'static str, + pub eval_source: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7OpeningClaimEqualityPlan { + pub symbol: &'static str, + pub mode: &'static str, + pub lhs: &'static str, + pub rhs: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7OpeningBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static [&'static str], + pub claim_operands: &'static [&'static str], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage7CpuProgramPlan { + pub role: &'static str, + pub params: Stage7Params, + pub steps: &'static [Stage7ProgramStepPlan], + pub transcript_squeezes: &'static [Stage7TranscriptSqueezePlan], + pub transcript_absorb_bytes: &'static [Stage7TranscriptAbsorbBytesPlan], + pub opening_inputs: &'static [Stage7OpeningInputPlan], + pub field_constants: &'static [Stage7FieldConstantPlan], + pub field_exprs: &'static [Stage7FieldExprPlan], + pub kernels: &'static [Stage7KernelPlan], + pub claims: &'static [Stage7SumcheckClaimPlan], + pub batches: &'static [Stage7SumcheckBatchPlan], + pub drivers: &'static [Stage7SumcheckDriverPlan], + pub instance_results: &'static [Stage7SumcheckInstanceResultPlan], + pub evals: &'static [Stage7SumcheckEvalPlan], + pub point_zeros: &'static [Stage7PointZeroPlan], + pub point_slices: &'static [Stage7PointSlicePlan], + pub point_concats: &'static [Stage7PointConcatPlan], + pub opening_claims: &'static [Stage7OpeningClaimPlan], + pub opening_equalities: &'static [Stage7OpeningClaimEqualityPlan], + pub opening_batches: &'static [Stage7OpeningBatchPlan], +} +"# + } + + fn emit_verifier_type_aliases() -> &'static str { + r#"pub type Stage7NamedEval = super::common::StageNamedEval; +pub type Stage7SumcheckOutput = super::common::StageSumcheckOutput; +pub type Stage7ChallengeVector = super::common::StageChallengeVector; +pub type Stage7ExecutionArtifacts = super::common::StageExecutionArtifacts; +pub type Stage7Proof = super::common::StageProof; +pub type Stage7OpeningInputValue = super::common::StageOpeningInputValue; + +pub use super::common::{ + FieldConstantPlan as Stage7FieldConstantPlan, FieldExprPlan as Stage7FieldExprPlan, + KernelPlan as Stage7KernelPlan, OpeningBatchPlan as Stage7OpeningBatchPlan, + OpeningClaimEqualityPlan as Stage7OpeningClaimEqualityPlan, + OpeningClaimPlan as Stage7OpeningClaimPlan, OpeningInputPlan as Stage7OpeningInputPlan, + PointConcatPlan as Stage7PointConcatPlan, PointSlicePlan as Stage7PointSlicePlan, + PointZeroPlan as Stage7PointZeroPlan, ProgramStepPlan as Stage7ProgramStepPlan, + StageParams as Stage7Params, StageProgramPlan as Stage7CpuProgramPlan, + SumcheckBatchPlan as Stage7SumcheckBatchPlan, + SumcheckClaimPlan as Stage7SumcheckClaimPlan, SumcheckDriverPlan as Stage7SumcheckDriverPlan, + SumcheckEvalPlan as Stage7SumcheckEvalPlan, + SumcheckInstanceResultPlan as Stage7SumcheckInstanceResultPlan, + TranscriptAbsorbBytesPlan as Stage7TranscriptAbsorbBytesPlan, + TranscriptSqueezePlan as Stage7TranscriptSqueezePlan, +}; +"# + } + + fn emit_verifier_types() -> String { + let mut source = Self::emit_verifier_type_aliases().to_owned(); + source.push_str( + r#" +pub type DefaultStage7Transcript = Blake2bTranscript; +pub type Stage7VerifierProgramPlan = Stage7CpuProgramPlan; + +#[derive(Debug)] +pub enum VerifyStage7Error { + UnexpectedProofCount { expected: usize, got: usize }, + MissingProof { driver: &'static str }, + MissingBatch { driver: &'static str, batch: &'static str }, + MissingClaim { batch: &'static str, claim: &'static str }, + MissingValue { symbol: &'static str }, + InvalidInputLength { input: &'static str, expected: usize, actual: usize }, + InvalidProof { driver: &'static str, reason: &'static str }, + UnsupportedFieldExpr { symbol: &'static str, formula: &'static str }, + UnsupportedRelation { relation: &'static str }, + Sumcheck { driver: &'static str, error: SumcheckError }, +} + +super::common::impl_runtime_plan_error_conversion!(VerifyStage7Error); +"#, + ); + source + } + + fn emit_constants(&self) -> String { + let mut source = self.emit_shared_constants(); + source.push_str(&self.emit_kernel_constants()); + source.push_str(&self.emit_sumcheck_claim_constants()); + source.push_str(&self.emit_sumcheck_batch_constants()); + source.push_str(&self.emit_sumcheck_driver_constants()); + source.push_str(&self.emit_tail_constants()); + push_format( + &mut source, + format_args!( + "pub const STAGE7_PROGRAM: {} = Stage7CpuProgramPlan {{\n\ + \x20 role: {},\n\ + \x20 params: STAGE7_PARAMS,\n\ + \x20 steps: STAGE7_PROGRAM_STEPS,\n\ + \x20 transcript_squeezes: STAGE7_TRANSCRIPT_SQUEEZES,\n\ + \x20 transcript_absorb_bytes: STAGE7_TRANSCRIPT_ABSORB_BYTES,\n\ + \x20 opening_inputs: STAGE7_OPENING_INPUTS,\n\ + \x20 field_constants: STAGE7_FIELD_CONSTANTS,\n\ + \x20 field_exprs: STAGE7_FIELD_EXPRS,\n\ + \x20 kernels: STAGE7_KERNELS,\n\ + \x20 claims: STAGE7_SUMCHECK_CLAIMS,\n\ + \x20 batches: STAGE7_SUMCHECK_BATCHES,\n\ + \x20 drivers: STAGE7_SUMCHECK_DRIVERS,\n\ + \x20 instance_results: STAGE7_SUMCHECK_INSTANCE_RESULTS,\n\ + \x20 evals: STAGE7_SUMCHECK_EVALS,\n\ + \x20 point_zeros: STAGE7_POINT_ZEROS,\n\ + \x20 point_slices: STAGE7_POINT_SLICES,\n\ + \x20 point_concats: STAGE7_POINT_CONCATS,\n\ + \x20 opening_claims: STAGE7_OPENING_CLAIMS,\n\ + \x20 opening_equalities: STAGE7_OPENING_EQUALITIES,\n\ + \x20 opening_batches: STAGE7_OPENING_BATCHES,\n\ + }};\n", + self.program_plan_type(), + rust_str(self.role_label()) + ), + ); + source + } + + fn emit_shared_constants(&self) -> String { + let mut source = String::new(); + push_format( + &mut source, + format_args!( + "pub const STAGE7_PARAMS: Stage7Params = Stage7Params {{\n\ + \x20 field: {},\n\ + \x20 pcs: {},\n\ + \x20 transcript: {},\n\ + }};\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript) + ), + ); + source.push_str(&self.emit_program_step_constants()); + source.push_str(&self.emit_transcript_squeeze_constants()); + source.push_str(&self.emit_transcript_absorb_bytes_constants()); + source.push_str(&self.emit_opening_input_constants()); + source.push_str(&self.emit_field_constant_constants()); + source.push_str(&self.emit_field_expr_constants()); + source + } + + fn emit_program_step_constants(&self) -> String { + let steps = self + .steps + .iter() + .map(|step| { + format!( + " Stage7ProgramStepPlan {{ kind: {}, symbol: {} }},", + rust_str(&step.kind), + rust_str(&step.symbol), + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE7_PROGRAM_STEPS: &[Stage7ProgramStepPlan] = &[\n{steps}\n];\n\n") + } + + fn emit_transcript_squeeze_constants(&self) -> String { + let squeezes = self + .transcript_squeezes + .iter() + .map(|squeeze| { + format!( + " Stage7TranscriptSqueezePlan {{ symbol: {}, label: {}, kind: {}, count: {} }},", + rust_str(&squeeze.symbol), + rust_str(&squeeze.label), + rust_str(&squeeze.kind), + squeeze.count, + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE7_TRANSCRIPT_SQUEEZES: &[Stage7TranscriptSqueezePlan] = &[\n{squeezes}\n];\n\n" + ) + } + + fn emit_transcript_absorb_bytes_constants(&self) -> String { + let absorbs = self + .transcript_absorb_bytes + .iter() + .map(|absorb| { + format!( + " Stage7TranscriptAbsorbBytesPlan {{ symbol: {}, label: {}, payload: {} }},", + rust_str(&absorb.symbol), + rust_str(&absorb.label), + rust_str(&absorb.payload), + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE7_TRANSCRIPT_ABSORB_BYTES: &[Stage7TranscriptAbsorbBytesPlan] = &[\n{absorbs}\n];\n\n" + ) + } + + fn emit_opening_input_constants(&self) -> String { + let inputs = self + .opening_inputs + .iter() + .map(|input| { + format!( + " Stage7OpeningInputPlan {{ symbol: {}, source_stage: {}, source_claim: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {} }},", + rust_str(&input.symbol), + rust_str(&input.source_stage), + rust_str(&input.source_claim), + rust_str(&input.oracle), + rust_str(&input.domain), + input.point_arity, + rust_str(&input.claim_kind) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE7_OPENING_INPUTS: &[Stage7OpeningInputPlan] = &[\n{inputs}\n];\n\n") + } + + fn emit_field_constant_constants(&self) -> String { + let constants = self + .field_constants + .iter() + .map(|constant| { + format!( + " Stage7FieldConstantPlan {{ symbol: {}, field: {}, value: {} }},", + rust_str(&constant.symbol), + rust_str(&constant.field), + constant.value + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE7_FIELD_CONSTANTS: &[Stage7FieldConstantPlan] = &[\n{constants}\n];\n\n" + ) + } + + fn emit_field_expr_constants(&self) -> String { + if self.role == Role::Verifier { + let rows = self + .field_exprs + .chunks(8) + .map(|chunk| { + let exprs = chunk + .iter() + .map(|expr| { + format!( + "stage7_field_expr!({}, {}, {})", + rust_str(&expr.symbol), + rust_str(&expr.formula), + rust_str(&expr.operands.join("|")) + ) + }) + .collect::>() + .join(", "); + format!(" {exprs},") + }) + .collect::>() + .join("\n"); + return format!( + "macro_rules! stage7_field_expr {{\n ($symbol:literal, $formula:literal, $operands:literal) => {{\n Stage7FieldExprPlan {{ symbol: $symbol, kind: \"op\", formula: $formula, operands: $operands }}\n }};\n}}\n\n#[rustfmt::skip]\npub const STAGE7_FIELD_EXPRS: &[Stage7FieldExprPlan] = &[\n{rows}\n];\n" + ); + } + + let mut source = String::new(); + let mut arrays = Vec::new(); + let mut array_refs = Vec::new(); + for (index, expr) in self.field_exprs.iter().enumerate() { + let operands = intern_str_array( + &mut source, + &mut arrays, + "STAGE7_FIELD_EXPR_OPERANDS", + &expr.operands, + ); + let operand_names = intern_str_array( + &mut source, + &mut arrays, + "STAGE7_FIELD_EXPR_OPERANDS", + &expr.operand_names, + ); + array_refs.push((index, operand_names, operands)); + } + let exprs = self + .field_exprs + .iter() + .enumerate() + .map(|(index, expr)| { + let (_, operand_names, operands) = &array_refs[index]; + format!( + " Stage7FieldExprPlan {{ symbol: {}, kind: {}, formula: {}, operand_names: {operand_names}, operands: {operands} }},", + rust_str(&expr.symbol), + rust_str(&expr.kind), + rust_str(&expr.formula) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE7_FIELD_EXPRS: &[Stage7FieldExprPlan] = &[\n{exprs}\n];\n" + ), + ); + source + } + + fn emit_kernel_constants(&self) -> String { + let kernels = self + .kernels + .iter() + .map(|kernel| { + format!( + " Stage7KernelPlan {{ symbol: {}, relation: {}, kind: {}, backend: {}, abi: {} }},", + rust_str(&kernel.symbol), + rust_str(&kernel.relation), + rust_str(&kernel.kind), + rust_str(&kernel.backend), + rust_str(&kernel.abi) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE7_KERNELS: &[Stage7KernelPlan] = &[\n{kernels}\n];\n\n") + } + + fn emit_sumcheck_claim_constants(&self) -> String { + if self.role == Role::Verifier { + let claims = self + .claims + .iter() + .map(|claim| { + format!( + " Stage7SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: {}, relation: {}, claim_value: {}, input_openings: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_option_str(claim.kernel.as_deref()), + rust_option_str(claim.relation.as_deref()), + rust_str(&claim.claim_value), + rust_str(&claim.input_openings.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE7_SUMCHECK_CLAIMS: &[Stage7SumcheckClaimPlan] = &[\n{claims}\n];\n" + ); + } + + let mut source = String::new(); + for (index, claim) in self.claims.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE7_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS"), + &claim.input_openings, + )); + } + let claims = self + .claims + .iter() + .enumerate() + .map(|(index, claim)| { + format!( + " Stage7SumcheckClaimPlan {{ symbol: {}, stage: {}, domain: {}, num_rounds: {}, degree: {}, claim: {}, kernel: {}, relation: {}, claim_value: {}, input_openings: STAGE7_SUMCHECK_CLAIM_{index}_INPUT_OPENINGS }},", + rust_str(&claim.symbol), + rust_str(&claim.stage), + rust_str(&claim.domain), + claim.num_rounds, + claim.degree, + rust_str(&claim.claim), + rust_option_str(claim.kernel.as_deref()), + rust_option_str(claim.relation.as_deref()), + rust_str(&claim.claim_value) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE7_SUMCHECK_CLAIMS: &[Stage7SumcheckClaimPlan] = &[\n{claims}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE7_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage7SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {}, claim_label: {}, round_label: {}, round_schedule: STAGE7_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")), + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE7_SUMCHECK_BATCHES: &[Stage7SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + return source; + } + + let mut source = String::new(); + for (index, batch) in self.batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE7_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE7_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + source.push_str(&emit_usize_array( + &format!("STAGE7_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE"), + &batch.round_schedule, + )); + } + let batches = self + .batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage7SumcheckBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE7_SUMCHECK_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE7_SUMCHECK_BATCH_{index}_CLAIM_OPERANDS, claim_label: {}, round_label: {}, round_schedule: STAGE7_SUMCHECK_BATCH_{index}_ROUND_SCHEDULE }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.claim_label), + rust_str(&batch.round_label) + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE7_SUMCHECK_BATCHES: &[Stage7SumcheckBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_sumcheck_driver_constants(&self) -> String { + let mut source = String::new(); + for (index, driver) in self.drivers.iter().enumerate() { + source.push_str(&emit_usize_array( + &format!("STAGE7_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE"), + &driver.round_schedule, + )); + } + let drivers = self + .drivers + .iter() + .enumerate() + .map(|(index, driver)| { + format!( + " Stage7SumcheckDriverPlan {{ symbol: {}, stage: {}, proof_slot: {}, kernel: {}, relation: {}, batch: {}, policy: {}, round_schedule: STAGE7_SUMCHECK_DRIVER_{index}_ROUND_SCHEDULE, claim_label: {}, round_label: {}, num_rounds: {}, degree: {} }},", + rust_str(&driver.symbol), + rust_str(&driver.stage), + rust_str(&driver.proof_slot), + rust_option_str(driver.kernel.as_deref()), + rust_option_str(driver.relation.as_deref()), + rust_str(&driver.batch), + rust_str(&driver.policy), + rust_str(&driver.claim_label), + rust_str(&driver.round_label), + driver.num_rounds, + driver.degree + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE7_SUMCHECK_DRIVERS: &[Stage7SumcheckDriverPlan] = &[\n{drivers}\n];\n" + ), + ); + source + } + + fn emit_tail_constants(&self) -> String { + let mut source = String::new(); + source.push_str(&self.emit_sumcheck_instance_result_constants()); + source.push_str(&self.emit_sumcheck_eval_constants()); + source.push_str(&self.emit_point_zero_constants()); + source.push_str(&self.emit_point_slice_constants()); + source.push_str(&self.emit_point_concat_constants()); + source.push_str(&self.emit_opening_claim_constants()); + source.push_str(&self.emit_opening_claim_equality_constants()); + source.push_str(&self.emit_opening_batch_constants()); + source + } + + fn emit_sumcheck_instance_result_constants(&self) -> String { + let instances = self + .instance_results + .iter() + .map(|instance| { + format!( + " Stage7SumcheckInstanceResultPlan {{ symbol: {}, source: {}, claim: {}, relation: {}, index: {}, point_arity: {}, num_rounds: {}, round_offset: {}, point_order: {}, degree: {} }},", + rust_str(&instance.symbol), + rust_str(&instance.source), + rust_str(&instance.claim), + rust_str(&instance.relation), + instance.index, + instance.point_arity, + instance.num_rounds, + instance.round_offset, + rust_str(&instance.point_order), + instance.degree + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE7_SUMCHECK_INSTANCE_RESULTS: &[Stage7SumcheckInstanceResultPlan] = &[\n{instances}\n];\n\n" + ) + } + + fn emit_sumcheck_eval_constants(&self) -> String { + let rows = self + .evals + .chunks(4) + .map(|chunk| { + let evals = chunk + .iter() + .map(|eval| { + format!( + "stage7_sumcheck_eval!({}, {}, {}, {}, {})", + rust_str(&eval.symbol), + rust_str(&eval.source), + rust_str(&eval.name), + eval.index, + rust_str(&eval.oracle) + ) + }) + .collect::>() + .join(", "); + format!(" {evals},") + }) + .collect::>() + .join("\n"); + format!( + "macro_rules! stage7_sumcheck_eval {{\n ($symbol:literal, $source:literal, $name:literal, $index:literal, $oracle:literal) => {{\n Stage7SumcheckEvalPlan {{ symbol: $symbol, source: $source, name: $name, index: $index, oracle: $oracle }}\n }};\n}}\n\n#[rustfmt::skip]\npub const STAGE7_SUMCHECK_EVALS: &[Stage7SumcheckEvalPlan] = &[\n{rows}\n];\n\n" + ) + } + + fn emit_point_zero_constants(&self) -> String { + let zeros = self + .point_zeros + .iter() + .map(|zero| { + format!( + " Stage7PointZeroPlan {{ symbol: {}, field: {}, arity: {} }},", + rust_str(&zero.symbol), + rust_str(&zero.field), + zero.arity + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE7_POINT_ZEROS: &[Stage7PointZeroPlan] = &[\n{zeros}\n];\n\n") + } + + fn emit_point_slice_constants(&self) -> String { + let slices = self + .point_slices + .iter() + .map(|slice| { + format!( + " Stage7PointSlicePlan {{ symbol: {}, source: {}, offset: {}, length: {}, input: {} }},", + rust_str(&slice.symbol), + rust_str(&slice.source), + slice.offset, + slice.length, + rust_str(&slice.input) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE7_POINT_SLICES: &[Stage7PointSlicePlan] = &[\n{slices}\n];\n\n") + } + + fn emit_point_concat_constants(&self) -> String { + if self.role == Role::Verifier { + let concats = self + .point_concats + .iter() + .map(|concat| { + format!( + " Stage7PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: {} }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity, + rust_str(&concat.inputs.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE7_POINT_CONCATS: &[Stage7PointConcatPlan] = &[\n{concats}\n];\n" + ); + } + + let mut source = String::new(); + for (index, concat) in self.point_concats.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE7_POINT_CONCAT_{index}_INPUTS"), + &concat.inputs, + )); + } + let concats = self + .point_concats + .iter() + .enumerate() + .map(|(index, concat)| { + format!( + " Stage7PointConcatPlan {{ symbol: {}, layout: {}, arity: {}, inputs: STAGE7_POINT_CONCAT_{index}_INPUTS }},", + rust_str(&concat.symbol), + rust_str(&concat.layout), + concat.arity + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE7_POINT_CONCATS: &[Stage7PointConcatPlan] = &[\n{concats}\n];\n" + ), + ); + source + } + + fn emit_opening_claim_constants(&self) -> String { + let claims = self + .opening_claims + .iter() + .map(|claim| { + format!( + " Stage7OpeningClaimPlan {{ symbol: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {}, point_source: {}, eval_source: {} }},", + rust_str(&claim.symbol), + rust_str(&claim.oracle), + rust_str(&claim.domain), + claim.point_arity, + rust_str(&claim.claim_kind), + rust_str(&claim.point_source), + rust_str(&claim.eval_source) + ) + }) + .collect::>() + .join("\n"); + format!("pub const STAGE7_OPENING_CLAIMS: &[Stage7OpeningClaimPlan] = &[\n{claims}\n];\n\n") + } + + fn emit_opening_claim_equality_constants(&self) -> String { + let equalities = self + .opening_equalities + .iter() + .map(|equality| { + format!( + " Stage7OpeningClaimEqualityPlan {{ symbol: {}, mode: {}, lhs: {}, rhs: {} }},", + rust_str(&equality.symbol), + rust_str(&equality.mode), + rust_str(&equality.lhs), + rust_str(&equality.rhs) + ) + }) + .collect::>() + .join("\n"); + format!( + "pub const STAGE7_OPENING_EQUALITIES: &[Stage7OpeningClaimEqualityPlan] = &[\n{equalities}\n];\n\n" + ) + } + + fn emit_opening_batch_constants(&self) -> String { + if self.role == Role::Verifier { + let batches = self + .opening_batches + .iter() + .map(|batch| { + format!( + " Stage7OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: {}, claim_operands: {} }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + rust_str(&batch.ordered_claims.join("|")), + rust_str(&batch.claim_operands.join("|")) + ) + }) + .collect::>() + .join("\n"); + return format!( + "pub const STAGE7_OPENING_BATCHES: &[Stage7OpeningBatchPlan] = &[\n{batches}\n];\n" + ); + } + + let mut source = String::new(); + for (index, batch) in self.opening_batches.iter().enumerate() { + source.push_str(&emit_str_array( + &format!("STAGE7_OPENING_BATCH_{index}_ORDERED_CLAIMS"), + &batch.ordered_claims, + )); + source.push_str(&emit_str_array( + &format!("STAGE7_OPENING_BATCH_{index}_CLAIM_OPERANDS"), + &batch.claim_operands, + )); + } + let batches = self + .opening_batches + .iter() + .enumerate() + .map(|(index, batch)| { + format!( + " Stage7OpeningBatchPlan {{ symbol: {}, stage: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE7_OPENING_BATCH_{index}_ORDERED_CLAIMS, claim_operands: STAGE7_OPENING_BATCH_{index}_CLAIM_OPERANDS }},", + rust_str(&batch.symbol), + rust_str(&batch.stage), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count + ) + }) + .collect::>() + .join("\n"); + push_format( + &mut source, + format_args!( + "pub const STAGE7_OPENING_BATCHES: &[Stage7OpeningBatchPlan] = &[\n{batches}\n];\n" + ), + ); + source + } + + fn emit_entrypoint(&self) -> &'static str { + match self.role { + Role::Prover => { + "pub fn execute_stage7_prover(\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage7KernelError>\n\ + where\n\ + \x20 E: Stage7KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage7_prover_with_program(&STAGE7_PROGRAM, executor, transcript)\n\ + }\n\ + \n\ + pub fn execute_stage7_prover_with_program(\n\ + \x20 program: &'static Stage7CpuProgramPlan,\n\ + \x20 executor: &mut E,\n\ + \x20 transcript: &mut T,\n\ + ) -> Result, Stage7KernelError>\n\ + where\n\ + \x20 E: Stage7KernelExecutor,\n\ + \x20 T: Transcript,\n\ + {\n\ + \x20 execute_stage7_program(program, Stage7ExecutionMode::Prover, executor, transcript)\n\ + }\n" + } + Role::Verifier => { + r#"pub fn verify_stage7( + proof: &Stage7Proof, + opening_inputs: &[Stage7OpeningInputValue], + transcript: &mut T, +) -> Result, VerifyStage7Error> +where + T: Transcript, +{ + verify_stage7_with_program(&STAGE7_PROGRAM, proof, opening_inputs, transcript) +} + +pub fn verify_stage7_with_program( + program: &'static Stage7VerifierProgramPlan, + proof: &Stage7Proof, + opening_inputs: &[Stage7OpeningInputValue], + transcript: &mut T, +) -> Result, VerifyStage7Error> +where + T: Transcript, +{ + if proof.sumchecks.len() != program.drivers.len() { + return Err(VerifyStage7Error::UnexpectedProofCount { + expected: program.drivers.len(), + got: proof.sumchecks.len(), + }); + } + let mut store = + super::common::ValueStore::with_opening_inputs(opening_inputs, program.opening_inputs)?; + store.seed_constants(program.field_constants); + store.seed_point_zeros(program.point_zeros); + let mut artifacts = Stage7ExecutionArtifacts::default(); + for step in program.steps { + match step.kind { + "transcript_squeeze" => { + let squeeze = + find_plan(program.transcript_squeezes, step.symbol).ok_or(VerifyStage7Error::MissingValue { + symbol: step.symbol, + })?; + verify_stage7_squeeze(program, squeeze, &mut store, transcript, &mut artifacts)?; + } + "transcript_absorb_bytes" => { + let absorb = find_plan(program.transcript_absorb_bytes, step.symbol).ok_or( + VerifyStage7Error::MissingValue { + symbol: step.symbol, + }, + )?; + absorb_stage7_bytes(absorb, transcript); + } + "sumcheck_driver" => { + let driver = + find_plan(program.drivers, step.symbol).ok_or(VerifyStage7Error::MissingProof { + driver: step.symbol, + })?; + verify_stage7_driver( + program, + driver, + proof, + &mut store, + transcript, + &mut artifacts, + )?; + } + _ => { + return Err(VerifyStage7Error::InvalidProof { + driver: step.symbol, + reason: "unsupported stage7 program step", + }); + } + } + } + artifacts + .opening_batches + .extend(program.opening_batches.iter()); + Ok(artifacts) +} + +pub fn stage7_verifier_program() -> &'static Stage7VerifierProgramPlan { + &STAGE7_PROGRAM +} + +fn verify_stage7_squeeze( + program: &'static Stage7VerifierProgramPlan, + squeeze: &'static Stage7TranscriptSqueezePlan, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage7ExecutionArtifacts, +) -> Result<(), VerifyStage7Error> +where + T: Transcript, +{ + let values = transcript.challenge_vector(squeeze.count); + store.observe_challenge_vector(squeeze, &values, |input, expected, actual| { + VerifyStage7Error::InvalidInputLength { + input, + expected, + actual, + } + })?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage7Error::from)?; + artifacts.challenge_vectors.push(Stage7ChallengeVector { + symbol: squeeze.symbol, + values, + }); + Ok(()) +} + +fn absorb_stage7_bytes(absorb: &'static Stage7TranscriptAbsorbBytesPlan, transcript: &mut T) +where + T: Transcript, +{ + transcript.append(&LabelWithCount( + absorb.label.as_bytes(), + absorb.payload.len() as u64, + )); + transcript.append_bytes(absorb.payload.as_bytes()); +} + +fn verify_stage7_driver( + program: &'static Stage7VerifierProgramPlan, + driver: &'static Stage7SumcheckDriverPlan, + proof: &Stage7Proof, + store: &mut super::common::ValueStore, + transcript: &mut T, + artifacts: &mut Stage7ExecutionArtifacts, +) -> Result<(), VerifyStage7Error> +where + T: Transcript, +{ + let proof = proof + .sumchecks + .get(artifacts.sumchecks.len()) + .ok_or(VerifyStage7Error::MissingProof { + driver: driver.symbol, + })?; + let relation = driver.relation.unwrap_or(""); + let output = match relation { + "jolt.stage7.batched" => { + verify_batched_stage7(program, driver, proof, store, transcript)? + } + _ => return Err(VerifyStage7Error::UnsupportedRelation { relation }), + }; + artifacts.sumchecks.push(output); + Ok(()) +} + +fn verify_batched_stage7( + program: &'static Stage7VerifierProgramPlan, + driver: &'static Stage7SumcheckDriverPlan, + proof: &Stage7SumcheckOutput, + store: &mut super::common::ValueStore, + transcript: &mut T, +) -> Result, VerifyStage7Error> +where + T: Transcript, +{ + super::common::verify_batched_sumcheck( + driver, + proof, + program.claims, + program.batches, + program.field_exprs, + program.opening_inputs, + program.opening_claims, + program.opening_batches, + store, + transcript, + |store, evals, point, batching_coeffs| { + expected_batched_output_claim(program, driver, store, evals, point, batching_coeffs) + }, + |store, verified| observe_stage7_sumcheck_output(program, store, verified), + |driver, error| VerifyStage7Error::Sumcheck { driver, error }, + ) +} + +fn observe_stage7_sumcheck_output( + program: &'static Stage7VerifierProgramPlan, + store: &mut super::common::ValueStore, + output: &Stage7SumcheckOutput, +) -> Result<(), VerifyStage7Error> { + store.observe_sumcheck_output( + program.instance_results, + program.evals, + output, + |instance, mut point| { + match instance.point_order { + "as_is" => {} + "reverse" => point.reverse(), + "bytecode_read_raf" => point = normalize_bytecode_read_raf_point(&point, stage7_trace_rounds(program)?, "stage7.bytecode_read_raf.point")?, + "stage7_booleanity" => {} + "instruction_read_raf" => point = normalize_instruction_read_raf_point(&point, "stage7.instruction_read_raf.point")?, + _ => { + return Err(VerifyStage7Error::InvalidProof { + driver: output.driver, + reason: "unsupported point order", + }); + } + } + Ok(point) + }, + |input, expected, actual| VerifyStage7Error::InvalidInputLength { + input, + expected, + actual, + }, + |symbol| VerifyStage7Error::MissingValue { symbol }, + )?; + store.evaluate_available_points( + program.point_slices, + program.point_concats, + |input, expected, actual| VerifyStage7Error::InvalidInputLength { + input, + expected, + actual, + }, + )?; + store + .evaluate_available_field_exprs(program.field_exprs, super::common::evaluate_field_expr) + .map_err(VerifyStage7Error::from)?; + store.verify_opening_equalities( + program.opening_equalities, + |driver, reason| VerifyStage7Error::InvalidProof { driver, reason }, + |symbol| VerifyStage7Error::MissingValue { symbol }, + ) +} + +fn expected_batched_output_claim( + program: &'static Stage7VerifierProgramPlan, + driver: &'static Stage7SumcheckDriverPlan, + store: &super::common::ValueStore, + evals: &[Stage7NamedEval], + point: &[Fr], + batching_coeffs: &[Fr], +) -> Result { + let batch = find_batch(program.batches, driver.symbol, driver.batch)?; + let claims = batch_claims(program.claims, batch)?; + let mut expected = Fr::from_u64(0); + for (claim, coefficient) in claims.iter().zip(batching_coeffs) { + let instance = program + .instance_results + .iter() + .find(|instance| instance.claim == claim.symbol && instance.source == driver.symbol) + .ok_or(VerifyStage7Error::MissingClaim { + batch: batch.symbol, + claim: claim.symbol, + })?; + let local_point = point + .get(instance.round_offset..instance.round_offset + instance.num_rounds) + .ok_or(VerifyStage7Error::InvalidInputLength { + input: instance.symbol, + expected: instance.round_offset + instance.num_rounds, + actual: point.len(), + })?; + let relation = claim.relation.unwrap_or(""); + let value = match relation { + "jolt.stage7.hamming_weight_claim_reduction" => { + expected_hamming_weight_claim_reduction(program, driver, store, evals, local_point)? + } + _ => return Err(VerifyStage7Error::UnsupportedRelation { relation }), + }; + expected += *coefficient * value; + } + Ok(expected) +} + +fn expected_hamming_weight_claim_reduction( + program: &'static Stage7VerifierProgramPlan, + driver: &'static Stage7SumcheckDriverPlan, + store: &super::common::ValueStore, + evals: &[Stage7NamedEval], + local_point: &[Fr], +) -> Result { + let rho_rev = reverse_slice(local_point); + let booleanity_point = super::common::store_point(store, "stage7.input.stage6.booleanity.InstructionRa_0")?; + let r_addr_bool = + booleanity_point + .get(..local_point.len()) + .ok_or(VerifyStage7Error::InvalidInputLength { + input: "stage7.input.stage6.booleanity.InstructionRa_0", + expected: local_point.len(), + actual: booleanity_point.len(), + })?; + let eq_bool = EqPolynomial::::mle(&rho_rev, r_addr_bool); + let gamma = super::common::store_scalar(store, "stage7.hamming_weight_claim_reduction.gamma")?; + let mut gamma_power = Fr::from_u64(1); + let mut expected = Fr::from_u64(0); + let mut eval_plans = program + .evals + .iter() + .filter(|eval| eval.source == driver.symbol) + .collect::>(); + eval_plans.sort_by_key(|eval| eval.index); + for eval_plan in eval_plans { + let g_i = eval_by_name(evals, eval_plan.name)?; + let virt_point = + stage7_virtualization_point(store, eval_plan.oracle, local_point.len())?; + let eq_virt = EqPolynomial::::mle(&rho_rev, virt_point); + expected += g_i * (gamma_power + gamma_power * gamma * eq_bool + + gamma_power * gamma.square() * eq_virt); + gamma_power *= gamma; + gamma_power *= gamma; + gamma_power *= gamma; + } + Ok(expected) +} + +fn stage7_virtualization_point<'a>( + store: &'a super::common::ValueStore, + oracle: &str, + log_k_chunk: usize, +) -> Result<&'a [Fr], VerifyStage7Error> { + let symbol = if oracle.starts_with("InstructionRa_") { + format!("stage7.input.stage6.instruction_ra_virtual.{oracle}") + } else if oracle.starts_with("BytecodeRa_") { + format!("stage7.input.stage6.bytecode_read_raf.{oracle}") + } else if oracle.starts_with("RamRa_") { + format!("stage7.input.stage6.ram_ra_virtual.{oracle}") + } else { + return Err(VerifyStage7Error::MissingValue { + symbol: "stage7.hamming_weight_claim_reduction.oracle", + }); + }; + let point = store.try_point(&symbol).ok_or(VerifyStage7Error::MissingValue { + symbol: "stage7.hamming_weight_claim_reduction.virtualization_point", + })?; + point + .get(..log_k_chunk) + .ok_or(VerifyStage7Error::InvalidInputLength { + input: "stage7.hamming_weight_claim_reduction.virtualization_point", + expected: log_k_chunk, + actual: point.len(), + }) +} + +fn stage7_trace_rounds( + program: &'static Stage7VerifierProgramPlan, +) -> Result { + program + .instance_results + .iter() + .find(|instance| instance.relation == "jolt.stage7.hamming_booleanity") + .map(|instance| instance.num_rounds) + .ok_or(VerifyStage7Error::MissingValue { + symbol: "stage7.hamming_booleanity.instance", + }) +} +"# + } + } + } + + fn role_label(&self) -> &'static str { + match self.role { + Role::Prover => "prover", + Role::Verifier => "verifier", + } + } + + fn program_plan_type(&self) -> &'static str { + match self.role { + Role::Prover => "Stage7CpuProgramPlan", + Role::Verifier => "Stage7VerifierProgramPlan", + } + } +} + +fn require_supported_symbol(kind: &str, actual: &str, expected: &str) -> Result<(), EmitError> { + if actual == expected { + Ok(()) + } else { + Err(EmitError::new(format!( + "unsupported {kind} @{actual}; expected @{expected}" + ))) + } +} + +fn emit_str_array(name: &str, values: &[String]) -> String { + if values.is_empty() { + return format!("pub const {name}: &[&str] = &[];\n\n"); + } + if let [value] = values { + return format!("pub const {name}: &[&str] = &[{}];\n\n", rust_str(value)); + } + let entries = values + .iter() + .map(|value| format!(" {},", rust_str(value))) + .collect::>() + .join("\n"); + format!("pub const {name}: &[&str] = &[\n{entries}\n];\n\n") +} + +fn emit_usize_array(name: &str, values: &[usize]) -> String { + let entries = values + .iter() + .map(|value| format!(" {value},")) + .collect::>() + .join("\n"); + format!("pub const {name}: &[usize] = &[\n{entries}\n];\n\n") +} + +fn intern_str_array( + source: &mut String, + arrays: &mut Vec<(Vec, String)>, + name_prefix: &str, + values: &[String], +) -> String { + if let Some((_, name)) = arrays + .iter() + .find(|(existing, _)| existing.as_slice() == values) + { + return name.clone(); + } + let name = format!("{name_prefix}_{}", arrays.len()); + source.push_str(&emit_str_array(&name, values)); + arrays.push((values.to_vec(), name.clone())); + name +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn rust_option_str(value: Option<&str>) -> String { + value.map_or_else( + || "None".to_owned(), + |value| format!("Some({})", rust_str(value)), + ) +} + +fn verify_count(kind: &str, symbol: &str, expected: usize, actual: usize) -> Result<(), EmitError> { + if expected == actual { + Ok(()) + } else { + Err(EmitError::new(format!( + "{kind} @{symbol} count mismatch: expected {expected}, got {actual}" + ))) + } +} + +fn symbols<'a>(values: impl Iterator) -> BTreeSet { + values.cloned().collect() +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn int_array_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result, EmitError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "integer array"))?; + parse_int_array(&attribute).ok_or_else(|| attr_error(operation, attr, "integer array")) +} + +fn parse_int_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().parse().ok()) + .collect() +} + +fn operand_symbols( + operation: OperationRef<'_, '_>, + start_index: usize, +) -> Result, EmitError> { + (start_index..operation.operand_count()) + .map(|index| operand_symbol(operation, index)) + .collect() +} + +fn operand_symbol(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + EmitError::new(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + EmitError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name<'c: 'a, 'a>(operation: impl OperationLike<'c, 'a>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} diff --git a/crates/bolt/src/protocols/jolt/emit/rust/stage8.rs b/crates/bolt/src/protocols/jolt/emit/rust/stage8.rs new file mode 100644 index 0000000000..9914b4b4cf --- /dev/null +++ b/crates/bolt/src/protocols/jolt/emit/rust/stage8.rs @@ -0,0 +1,516 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::OperationRef; + +use crate::emit::rust::{push_format, EmitError, RustSourceFile}; +use crate::ir::{string_attribute_value, symbol_attribute_value, BoltModule, Cpu, Role}; +use crate::schema::verify_cpu_schema; + +const EVALUATION_POINT_SOURCE_SYMBOL: &str = "stage8.evaluation.point_source"; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage8CpuProgram { + pub role: Role, + pub params: Stage8Params, + pub function: String, + pub opening_inputs: Vec, + pub opening_claims: Vec, + pub opening_batches: Vec, + pub pcs_proofs: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage8Params { + pub field: String, + pub pcs: String, + pub transcript: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage8OpeningInputPlan { + pub symbol: String, + pub source_stage: String, + pub source_claim: String, + pub oracle: String, + pub domain: String, + pub point_arity: usize, + pub claim_kind: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage8OpeningClaimPlan { + pub symbol: String, + pub oracle: String, + pub family: String, + pub domain: String, + pub point_arity: usize, + pub point_source: String, + pub eval_source: String, + pub source_stage: String, + pub source_claim: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage8OpeningBatchPlan { + pub symbol: String, + pub proof_slot: String, + pub policy: String, + pub count: usize, + pub ordered_claims: Vec, + pub claim_operands: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Stage8PcsProofPlan { + pub symbol: String, + pub mode: String, + pub pcs: String, + pub proof_slot: String, + pub transcript_label: String, + pub batch: String, +} + +pub fn stage8_cpu_program(module: &BoltModule<'_, Cpu>) -> Result { + verify_cpu_schema(module)?; + let program = Stage8CpuProgram::from_module(module)?; + program.verify_supported_target()?; + Ok(program) +} + +pub fn emit_stage8_rust(module: &BoltModule<'_, Cpu>) -> Result { + let program = stage8_cpu_program(module)?; + Ok(RustSourceFile { + filename: program.filename().to_owned(), + source: program.emit_source()?, + }) +} + +impl Stage8CpuProgram { + fn from_module(module: &BoltModule<'_, Cpu>) -> Result { + let role = module + .role() + .ok_or_else(|| EmitError::new("stage8 CPU module missing role"))?; + let mut params = None; + let mut function = None; + let mut opening_inputs = Vec::new(); + let mut opening_claims = Vec::new(); + let mut opening_batches = Vec::new(); + let mut pcs_proofs = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "cpu.params" => { + params = Some(Stage8Params { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + "cpu.function" => { + function = Some(string_attr(op, "sym_name")?); + } + "cpu.opening_input" => { + opening_inputs.push(Stage8OpeningInputPlan { + symbol: string_attr(op, "sym_name")?, + source_stage: symbol_attr(op, "source_stage")?, + source_claim: symbol_attr(op, "source_claim")?, + oracle: symbol_attr(op, "oracle")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + claim_kind: string_attr(op, "claim_kind")?, + }); + } + "cpu.pcs_opening_claim" => { + opening_claims.push(Stage8OpeningClaimPlan { + symbol: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + family: symbol_attr(op, "family")?, + domain: symbol_attr(op, "domain")?, + point_arity: int_attr(op, "point_arity")?, + point_source: operand_symbol(op, 0)?, + eval_source: operand_symbol(op, 1)?, + source_stage: String::new(), + source_claim: String::new(), + }); + } + "cpu.pcs_opening_batch" => { + opening_batches.push(Stage8OpeningBatchPlan { + symbol: string_attr(op, "sym_name")?, + proof_slot: symbol_attr(op, "proof_slot")?, + policy: string_attr(op, "policy")?, + count: int_attr(op, "count")?, + ordered_claims: symbol_array_attr(op, "ordered_claims")?, + claim_operands: operand_symbols(op, 0)?, + }); + } + "cpu.pcs_batch_open" | "cpu.pcs_batch_verify" => { + let mode = match operation_name(op).as_str() { + "cpu.pcs_batch_open" => "open", + "cpu.pcs_batch_verify" => "verify", + _ => unreachable!(), + }; + pcs_proofs.push(Stage8PcsProofPlan { + symbol: string_attr(op, "sym_name")?, + mode: mode.to_owned(), + pcs: symbol_attr(op, "pcs")?, + proof_slot: symbol_attr(op, "proof_slot")?, + transcript_label: string_attr(op, "transcript_label")?, + batch: operand_symbol(op, 1)?, + }); + } + _ => {} + } + } + + let input_by_symbol = opening_inputs + .iter() + .map(|input| (input.symbol.as_str(), input)) + .collect::>(); + for claim in &mut opening_claims { + let input = input_by_symbol + .get(claim.point_source.as_str()) + .ok_or_else(|| { + EmitError::new(format!( + "stage8 opening claim `{}` references missing point source `{}`", + claim.symbol, claim.point_source + )) + })?; + claim.source_stage = input.source_stage.clone(); + claim.source_claim = input.source_claim.clone(); + } + + Ok(Self { + role, + params: params.ok_or_else(|| EmitError::new("stage8 program missing cpu.params"))?, + function: function + .ok_or_else(|| EmitError::new("stage8 program missing cpu.function"))?, + opening_inputs, + opening_claims, + opening_batches, + pcs_proofs, + }) + } + + fn verify_supported_target(&self) -> Result<(), EmitError> { + if self.function != "jolt.stage8" { + return Err(EmitError::new(format!( + "stage8 emitter expected function `jolt.stage8`, got `{}`", + self.function + ))); + } + if self.opening_batches.len() != 1 { + return Err(EmitError::new(format!( + "stage8 emitter expects one PCS opening batch, got {}", + self.opening_batches.len() + ))); + } + if self.pcs_proofs.len() != 1 { + return Err(EmitError::new(format!( + "stage8 emitter expects one PCS proof op, got {}", + self.pcs_proofs.len() + ))); + } + let expected_mode = match self.role { + Role::Prover => "open", + Role::Verifier => "verify", + }; + if self.pcs_proofs[0].mode != expected_mode { + return Err(EmitError::new(format!( + "stage8 {} artifact expected PCS mode `{expected_mode}`, got `{}`", + self.role, self.pcs_proofs[0].mode + ))); + } + let batch = &self.opening_batches[0]; + if batch.count != self.opening_claims.len() { + return Err(EmitError::new(format!( + "stage8 opening batch count {} does not match {} opening claims", + batch.count, + self.opening_claims.len() + ))); + } + if batch.ordered_claims != batch.claim_operands { + return Err(EmitError::new( + "stage8 opening batch ordered claims do not match SSA operands", + )); + } + if !self + .opening_inputs + .iter() + .any(|input| input.symbol == EVALUATION_POINT_SOURCE_SYMBOL) + { + return Err(EmitError::new(format!( + "stage8 program missing `{EVALUATION_POINT_SOURCE_SYMBOL}` opening-point source" + ))); + } + let input_symbols = self + .opening_inputs + .iter() + .map(|input| input.symbol.as_str()) + .collect::>(); + for claim in &self.opening_claims { + if !input_symbols.contains(claim.point_source.as_str()) { + return Err(EmitError::new(format!( + "stage8 claim `{}` point source `{}` is not an opening input", + claim.symbol, claim.point_source + ))); + } + if claim.point_source != claim.eval_source { + return Err(EmitError::new(format!( + "stage8 claim `{}` must take point and eval from the same opening input", + claim.symbol + ))); + } + } + Ok(()) + } + + fn filename(&self) -> &'static str { + match self.role { + Role::Prover => "prove_stage8.rs", + Role::Verifier => "verify_stage8.rs", + } + } + + fn emit_source(&self) -> Result { + let mut source = String::new(); + source.push_str("#![allow(clippy::too_many_lines)]\n\n"); + source.push_str("#[derive(Clone, Copy, Debug, PartialEq, Eq)]\n"); + source.push_str( + "pub struct Stage8Params {\n pub field: &'static str,\n pub pcs: &'static str,\n pub transcript: &'static str,\n}\n\n", + ); + source.push_str("#[derive(Clone, Copy, Debug, PartialEq, Eq)]\n"); + source.push_str( + "pub struct Stage8OpeningInputPlan {\n pub symbol: &'static str,\n pub source_stage: &'static str,\n pub source_claim: &'static str,\n pub oracle: &'static str,\n pub domain: &'static str,\n pub point_arity: usize,\n pub claim_kind: &'static str,\n}\n\n", + ); + source.push_str("#[derive(Clone, Copy, Debug, PartialEq, Eq)]\n"); + source.push_str( + "pub struct Stage8OpeningClaimPlan {\n pub symbol: &'static str,\n pub oracle: &'static str,\n pub family: &'static str,\n pub domain: &'static str,\n pub point_arity: usize,\n pub point_source: &'static str,\n pub eval_source: &'static str,\n pub source_stage: &'static str,\n pub source_claim: &'static str,\n}\n\n", + ); + source.push_str("#[derive(Clone, Copy, Debug, PartialEq, Eq)]\n"); + source.push_str( + "pub struct Stage8OpeningBatchPlan {\n pub symbol: &'static str,\n pub proof_slot: &'static str,\n pub policy: &'static str,\n pub count: usize,\n pub ordered_claims: &'static [&'static str],\n}\n\n", + ); + source.push_str("#[derive(Clone, Copy, Debug, PartialEq, Eq)]\n"); + source.push_str( + "pub struct Stage8PcsProofPlan {\n pub symbol: &'static str,\n pub mode: &'static str,\n pub pcs: &'static str,\n pub proof_slot: &'static str,\n pub transcript_label: &'static str,\n pub batch: &'static str,\n}\n\n", + ); + source.push_str("#[derive(Clone, Copy, Debug, PartialEq, Eq)]\n"); + source.push_str( + "pub struct Stage8EvaluationProgramPlan {\n pub role: &'static str,\n pub function: &'static str,\n pub params: Stage8Params,\n pub evaluation_point_source: Stage8OpeningInputPlan,\n pub opening_inputs: &'static [Stage8OpeningInputPlan],\n pub opening_claims: &'static [Stage8OpeningClaimPlan],\n pub opening_batch: Stage8OpeningBatchPlan,\n pub pcs_proof: Stage8PcsProofPlan,\n}\n\n", + ); + + push_format( + &mut source, + format_args!( + "pub const STAGE8_PARAMS: Stage8Params = Stage8Params {{ field: {}, pcs: {}, transcript: {} }};\n\n", + rust_str(&self.params.field), + rust_str(&self.params.pcs), + rust_str(&self.params.transcript), + ), + ); + let point_source = self + .opening_inputs + .iter() + .find(|input| input.symbol == EVALUATION_POINT_SOURCE_SYMBOL) + .ok_or_else(|| { + EmitError::new(format!( + "evaluation program missing `{EVALUATION_POINT_SOURCE_SYMBOL}` opening-point source" + )) + })?; + push_format( + &mut source, + format_args!( + "pub const STAGE8_EVALUATION_POINT_SOURCE: Stage8OpeningInputPlan = {};\n\n", + opening_input_literal(point_source), + ), + ); + source.push_str("pub const STAGE8_OPENING_INPUTS: &[Stage8OpeningInputPlan] = &[\n"); + for input in &self.opening_inputs { + push_format( + &mut source, + format_args!(" {},\n", opening_input_literal(input)), + ); + } + source.push_str("];\n\n"); + source.push_str("pub const STAGE8_OPENING_CLAIMS: &[Stage8OpeningClaimPlan] = &[\n"); + for claim in &self.opening_claims { + push_format( + &mut source, + format_args!( + " Stage8OpeningClaimPlan {{ symbol: {}, oracle: {}, family: {}, domain: {}, point_arity: {}, point_source: {}, eval_source: {}, source_stage: {}, source_claim: {} }},\n", + rust_str(&claim.symbol), + rust_str(&claim.oracle), + rust_str(&claim.family), + rust_str(&claim.domain), + claim.point_arity, + rust_str(&claim.point_source), + rust_str(&claim.eval_source), + rust_str(&claim.source_stage), + rust_str(&claim.source_claim), + ), + ); + } + source.push_str("];\n\n"); + let batch = &self.opening_batches[0]; + push_format( + &mut source, + format_args!( + "pub const STAGE8_OPENING_BATCH_ORDERED_CLAIMS: &[&str] = &{};\n\n", + rust_str_array(&batch.ordered_claims), + ), + ); + push_format( + &mut source, + format_args!( + "pub const STAGE8_OPENING_BATCH: Stage8OpeningBatchPlan = Stage8OpeningBatchPlan {{ symbol: {}, proof_slot: {}, policy: {}, count: {}, ordered_claims: STAGE8_OPENING_BATCH_ORDERED_CLAIMS }};\n\n", + rust_str(&batch.symbol), + rust_str(&batch.proof_slot), + rust_str(&batch.policy), + batch.count, + ), + ); + let proof = &self.pcs_proofs[0]; + push_format( + &mut source, + format_args!( + "pub const STAGE8_PCS_PROOF: Stage8PcsProofPlan = Stage8PcsProofPlan {{ symbol: {}, mode: {}, pcs: {}, proof_slot: {}, transcript_label: {}, batch: {} }};\n\n", + rust_str(&proof.symbol), + rust_str(&proof.mode), + rust_str(&proof.pcs), + rust_str(&proof.proof_slot), + rust_str(&proof.transcript_label), + rust_str(&proof.batch), + ), + ); + push_format( + &mut source, + format_args!( + "pub const STAGE8_PROGRAM: Stage8EvaluationProgramPlan = Stage8EvaluationProgramPlan {{\n role: {},\n function: {},\n params: STAGE8_PARAMS,\n evaluation_point_source: STAGE8_EVALUATION_POINT_SOURCE,\n opening_inputs: STAGE8_OPENING_INPUTS,\n opening_claims: STAGE8_OPENING_CLAIMS,\n opening_batch: STAGE8_OPENING_BATCH,\n pcs_proof: STAGE8_PCS_PROOF,\n}};\n", + rust_str(self.role.as_str()), + rust_str(&self.function), + ), + ); + Ok(source) + } +} + +fn opening_input_literal(input: &Stage8OpeningInputPlan) -> String { + format!( + "Stage8OpeningInputPlan {{ symbol: {}, source_stage: {}, source_claim: {}, oracle: {}, domain: {}, point_arity: {}, claim_kind: {} }}", + rust_str(&input.symbol), + rust_str(&input.source_stage), + rust_str(&input.source_claim), + rust_str(&input.oracle), + rust_str(&input.domain), + input.point_arity, + rust_str(&input.claim_kind), + ) +} + +fn rust_str(value: &str) -> String { + format!("{value:?}") +} + +fn rust_str_array(values: &[String]) -> String { + let values = values + .iter() + .map(|value| rust_str(value)) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +fn symbol_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol reference")) +} + +fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + let value = operation + .attribute(attr) + .ok() + .and_then(|attr| attr.to_string().strip_suffix(" : i64").map(str::to_owned)) + .ok_or_else(|| attr_error(operation, attr, "integer"))?; + value + .parse() + .map_err(|_| attr_error(operation, attr, "integer")) +} + +fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, EmitError> { + let value = operation + .attribute(attr) + .ok() + .map(|attr| attr.to_string()) + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&value).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(str::to_owned)) + .collect() +} + +fn operand_symbols( + operation: OperationRef<'_, '_>, + start_index: usize, +) -> Result, EmitError> { + (start_index..operation.operand_count()) + .map(|index| operand_symbol(operation, index)) + .collect() +} + +fn operand_symbol(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + EmitError::new(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + EmitError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> EmitError { + EmitError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +fn operation_name<'c: 'a, 'a>(operation: impl OperationLike<'c, 'a>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} diff --git a/crates/bolt/src/protocols/jolt/mod.rs b/crates/bolt/src/protocols/jolt/mod.rs new file mode 100644 index 0000000000..91da996a0e --- /dev/null +++ b/crates/bolt/src/protocols/jolt/mod.rs @@ -0,0 +1,41 @@ +pub mod artifacts; +pub mod emit; +pub mod oracles; +pub mod params; +pub mod phases; +pub mod validate; + +pub use artifacts::{ + assemble_jolt_generated_crates, assemble_jolt_workspace_generated_crates, jolt_artifact_config, + jolt_rust_artifact, validate_jolt_rust_artifact_imports, write_jolt_generated_crates, + JoltArtifactCrate, JoltGeneratedCrate, JoltGeneratedFile, JoltProtocolStage, JoltRustArtifact, +}; +pub use emit::rust::{ + commitment_cpu_program, emit_commitment_rust, emit_stage1_rust, emit_stage2_rust, + emit_stage3_rust, emit_stage4_rust, emit_stage5_rust, emit_stage6_rust, emit_stage7_rust, + emit_stage8_rust, stage1_cpu_program, stage2_cpu_program, stage3_cpu_program, + stage4_cpu_program, stage5_cpu_program, stage6_cpu_program, stage7_cpu_program, + stage8_cpu_program, CommitmentBatchPlan, CommitmentCpuProgram, CommitmentParams, + OptionalCommitmentPlan, OptionalSkipPolicy, OracleGeneration, OraclePlan, Stage1CpuProgram, + Stage1KernelPlan, Stage1OpeningBatchPlan, Stage1OpeningClaimPlan, Stage1Params, + Stage1SumcheckBatchPlan, Stage1SumcheckClaimPlan, Stage1SumcheckDriverPlan, + Stage1SumcheckEvalPlan, Stage2CpuProgram, Stage3CpuProgram, Stage4CpuProgram, Stage5CpuProgram, + Stage6CpuProgram, Stage7CpuProgram, Stage8CpuProgram, TranscriptStep, +}; +pub use params::JoltProtocolParams; +pub use phases::commitment::{ + build_commitment_protocol, lower_commitment_to_compute, lower_compute_to_cpu, +}; +pub use phases::stage1::{ + build_stage1_outer_protocol, lower_stage1_to_compute, resolve_compute_kernels, +}; +pub use phases::stage2::{build_stage2_protocol, lower_stage2_to_compute}; +pub use phases::stage3::{build_stage3_protocol, lower_stage3_to_compute}; +pub use phases::stage4::{build_stage4_protocol, lower_stage4_to_compute}; +pub use phases::stage5::{build_stage5_protocol, lower_stage5_to_compute}; +pub use phases::stage6::{build_stage6_protocol, lower_stage6_to_compute}; +pub use phases::stage7::{build_stage7_protocol, lower_stage7_to_compute}; +pub use phases::stage8::{build_stage8_protocol, lower_stage8_to_compute}; +pub use validate::{ + verify_jolt_concrete_schema, verify_jolt_party_schema, verify_jolt_protocol_schema, +}; diff --git a/crates/bolt/src/protocols/jolt/oracles.rs b/crates/bolt/src/protocols/jolt/oracles.rs new file mode 100644 index 0000000000..46e66304b2 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/oracles.rs @@ -0,0 +1,227 @@ +use crate::ir::{BoltModule, Protocol}; +use crate::mlir::{MeliorContext, MlirError}; + +use super::params::JoltProtocolParams; + +pub const FIELD_SYMBOL: &str = "bn254_fr"; +pub const HASH_SYMBOL: &str = "blake2b"; +pub const TRANSCRIPT_SYMBOL: &str = "blake2b_transcript"; +pub const PCS_SYMBOL: &str = "dory"; +pub const TRACE_DOMAIN_SYMBOL: &str = "jolt.trace_domain"; +pub const MAIN_WITNESS_COMMIT_DOMAIN_SYMBOL: &str = "jolt.main_witness_commit_domain"; +pub const MAIN_WITNESS_FAMILY_SYMBOL: &str = "jolt.main_witness_polys"; +pub const ADVICE_FAMILY_SYMBOL: &str = "jolt.advice_polys"; + +pub fn main_witness_oracles(params: &JoltProtocolParams) -> Vec { + let mut oracles = vec!["RdInc".to_owned(), "RamInc".to_owned()]; + oracles.extend((0..params.instruction_d).map(|index| format!("InstructionRa_{index}"))); + oracles.extend((0..params.ram_d).map(|index| format!("RamRa_{index}"))); + oracles.extend((0..params.bytecode_d).map(|index| format!("BytecodeRa_{index}"))); + oracles +} + +pub fn main_witness_oracle_attr(params: &JoltProtocolParams) -> String { + symbol_array_attr(&main_witness_oracles(params)) +} + +pub fn append_foundation_ops<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + context.append_op( + module, + "field.define", + Some(FIELD_SYMBOL), + &[("modulus_bits", "254 : i64"), ("role", r#""scalar""#)], + )?; + context.append_op( + module, + "hash.function", + Some(HASH_SYMBOL), + &[("algorithm", r#""blake2b""#)], + )?; + context.append_op( + module, + "transcript.scheme", + Some(TRANSCRIPT_SYMBOL), + &[("hash", "@blake2b")], + )?; + context.append_op( + module, + "pcs.scheme", + Some(PCS_SYMBOL), + &[("field", "@bn254_fr")], + )?; + context.append_op( + module, + "poly.domain", + Some(TRACE_DOMAIN_SYMBOL), + &[ + ("field", "@bn254_fr"), + ("log_size", &format!("{} : i64", params.log_t)), + ], + )?; + context.append_op( + module, + "poly.domain", + Some(MAIN_WITNESS_COMMIT_DOMAIN_SYMBOL), + &[ + ("field", "@bn254_fr"), + ( + "log_size", + &format!("{} : i64", params.log_t + params.log_k_chunk), + ), + ], + )?; + Ok(()) +} + +pub fn append_committed_oracles<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_oracle( + context, + module, + OracleSpec { + symbol: "RdInc".to_owned(), + domain: "@jolt.trace_domain", + commit_domain: "@jolt.main_witness_commit_domain", + layout: "dense_trace", + visibility: "committed", + extra_attrs: Vec::new(), + }, + )?; + append_oracle( + context, + module, + OracleSpec { + symbol: "RamInc".to_owned(), + domain: "@jolt.trace_domain", + commit_domain: "@jolt.main_witness_commit_domain", + layout: "dense_trace", + visibility: "committed", + extra_attrs: Vec::new(), + }, + )?; + for index in 0..params.instruction_d { + append_indexed_oracle( + context, + module, + "InstructionRa", + index, + "@jolt.main_witness_commit_domain", + )?; + } + for index in 0..params.ram_d { + append_indexed_oracle( + context, + module, + "RamRa", + index, + "@jolt.main_witness_commit_domain", + )?; + } + for index in 0..params.bytecode_d { + append_indexed_oracle( + context, + module, + "BytecodeRa", + index, + "@jolt.main_witness_commit_domain", + )?; + } + append_oracle( + context, + module, + OracleSpec { + symbol: "UntrustedAdvice".to_owned(), + domain: "@jolt.trace_domain", + commit_domain: "@jolt.trace_domain", + layout: "dense_trace", + visibility: "optional_committed", + extra_attrs: Vec::new(), + }, + )?; + append_oracle( + context, + module, + OracleSpec { + symbol: "TrustedAdvice".to_owned(), + domain: "@jolt.trace_domain", + commit_domain: "@jolt.trace_domain", + layout: "dense_trace", + visibility: "optional_committed", + extra_attrs: Vec::new(), + }, + )?; + Ok(()) +} + +fn append_indexed_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + family: &str, + index: usize, + domain: &str, +) -> Result<(), MlirError> { + append_oracle( + context, + module, + OracleSpec { + symbol: format!("{family}_{index}"), + domain, + commit_domain: "@jolt.main_witness_commit_domain", + layout: "onehot_expanded", + visibility: "committed", + extra_attrs: vec![ + ("family", format!("@{family}")), + ("index", format!("{index} : i64")), + ], + }, + ) +} + +struct OracleSpec<'a> { + symbol: String, + domain: &'a str, + commit_domain: &'a str, + layout: &'a str, + visibility: &'a str, + extra_attrs: Vec<(&'a str, String)>, +} + +fn append_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + spec: OracleSpec<'_>, +) -> Result<(), MlirError> { + let mut attrs = vec![ + ("field", "@bn254_fr".to_owned()), + ("domain", spec.domain.to_owned()), + ("commit_domain", spec.commit_domain.to_owned()), + ("visibility", format!("\"{}\"", spec.visibility)), + ("layout", format!("\"{}\"", spec.layout)), + ]; + attrs.extend(spec.extra_attrs); + context.append_op_with_owned_attrs( + module, + "piop.oracle", + Some(&spec.symbol), + &attrs + .into_iter() + .map(|(name, value)| (name.to_owned(), value)) + .collect::>(), + ) +} + +fn symbol_array_attr(values: &[String]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} diff --git a/crates/bolt/src/protocols/jolt/params.rs b/crates/bolt/src/protocols/jolt/params.rs new file mode 100644 index 0000000000..5dd94687dd --- /dev/null +++ b/crates/bolt/src/protocols/jolt/params.rs @@ -0,0 +1,305 @@ +use melior::ir::OperationRef; + +use crate::schema::{ + int_attr as mlir_int_attr, require_attrs, symbol_attr as mlir_symbol_attr, SchemaError, +}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct JoltProtocolParams { + pub field: &'static str, + pub pcs: &'static str, + pub transcript: &'static str, + pub xlen: usize, + pub log_t: usize, + pub trace_length: usize, + pub log_k_bytecode: usize, + pub bytecode_k: usize, + pub log_k_ram: usize, + pub ram_k: usize, + pub log_k_chunk: usize, + pub k_chunk: usize, + pub lookups_ra_virtual_log_k_chunk: usize, + pub instruction_log_k: usize, + pub register_log_k: usize, + pub lookup_table_count: usize, + pub instruction_d: usize, + pub instruction_ra_virtual_d: usize, + pub bytecode_d: usize, + pub ram_d: usize, + pub num_committed: usize, + pub num_r1cs_constraints: usize, + pub num_r1cs_inputs: usize, + pub num_vars_padded: usize, +} + +impl JoltProtocolParams { + pub fn new(log_t: usize, log_k_bytecode: usize, log_k_ram: usize) -> Self { + let log_k_chunk = if log_t < 25 { 4 } else { 8 }; + let instruction_log_k = 128; + let lookups_ra_virtual_log_k_chunk = if log_t < 25 { + instruction_log_k / 8 + } else { + instruction_log_k / 4 + }; + let instruction_d = instruction_log_k / log_k_chunk; + let instruction_ra_virtual_d = instruction_log_k / lookups_ra_virtual_log_k_chunk; + let bytecode_d = log_k_bytecode.div_ceil(log_k_chunk); + let ram_d = log_k_ram.div_ceil(log_k_chunk); + Self { + field: "bn254_fr", + pcs: "dory", + transcript: "blake2b_transcript", + xlen: 64, + log_t, + trace_length: 1usize << log_t, + log_k_bytecode, + bytecode_k: 1usize << log_k_bytecode, + log_k_ram, + ram_k: 1usize << log_k_ram, + log_k_chunk, + k_chunk: 1usize << log_k_chunk, + lookups_ra_virtual_log_k_chunk, + instruction_log_k, + register_log_k: 7, + lookup_table_count: 40, + instruction_d, + instruction_ra_virtual_d, + bytecode_d, + ram_d, + num_committed: 2 + instruction_d + bytecode_d + ram_d, + num_r1cs_constraints: 19, + num_r1cs_inputs: 35, + num_vars_padded: 64, + } + } + + pub fn fixture() -> Self { + Self::new(16, 10, 16) + } + + pub fn attrs(&self) -> Vec<(String, String)> { + vec![ + symbol_attr("field", self.field), + symbol_attr("pcs", self.pcs), + symbol_attr("transcript", self.transcript), + int_attr("xlen", self.xlen), + int_attr("log_t", self.log_t), + int_attr("trace_length", self.trace_length), + int_attr("log_k_bytecode", self.log_k_bytecode), + int_attr("bytecode_k", self.bytecode_k), + int_attr("log_k_ram", self.log_k_ram), + int_attr("ram_k", self.ram_k), + int_attr("log_k_chunk", self.log_k_chunk), + int_attr("k_chunk", self.k_chunk), + int_attr( + "lookups_ra_virtual_log_k_chunk", + self.lookups_ra_virtual_log_k_chunk, + ), + int_attr("instruction_log_k", self.instruction_log_k), + int_attr("register_log_k", self.register_log_k), + int_attr("lookup_table_count", self.lookup_table_count), + int_attr("instruction_d", self.instruction_d), + int_attr("instruction_ra_virtual_d", self.instruction_ra_virtual_d), + int_attr("bytecode_d", self.bytecode_d), + int_attr("ram_d", self.ram_d), + int_attr("num_committed", self.num_committed), + int_attr("num_r1cs_constraints", self.num_r1cs_constraints), + int_attr("num_r1cs_inputs", self.num_r1cs_inputs), + int_attr("num_vars_padded", self.num_vars_padded), + ] + } +} + +fn symbol_attr(name: &str, value: &str) -> (String, String) { + (name.to_owned(), format!("@{value}")) +} + +fn int_attr(name: &str, value: usize) -> (String, String) { + (name.to_owned(), format!("{value} : i64")) +} + +#[derive(Clone, Debug)] +pub(crate) struct ParsedJoltProtocolParams { + pub(crate) field: String, + pub(crate) pcs: String, + pub(crate) transcript: String, + pub(crate) log_t: usize, + pub(crate) trace_length: usize, + pub(crate) log_k_bytecode: usize, + pub(crate) bytecode_k: usize, + pub(crate) log_k_ram: usize, + pub(crate) ram_k: usize, + pub(crate) log_k_chunk: usize, + pub(crate) k_chunk: usize, + pub(crate) lookups_ra_virtual_log_k_chunk: usize, + pub(crate) instruction_log_k: usize, + pub(crate) instruction_d: usize, + pub(crate) instruction_ra_virtual_d: usize, + pub(crate) bytecode_d: usize, + pub(crate) ram_d: usize, + pub(crate) num_committed: usize, +} + +impl ParsedJoltProtocolParams { + pub(crate) fn from_op(operation: OperationRef<'_, '_>) -> Result { + require_jolt_params_attrs(operation)?; + Ok(Self { + field: mlir_symbol_attr(operation, "field")?, + pcs: mlir_symbol_attr(operation, "pcs")?, + transcript: mlir_symbol_attr(operation, "transcript")?, + log_t: mlir_int_attr(operation, "log_t")?, + trace_length: mlir_int_attr(operation, "trace_length")?, + log_k_bytecode: mlir_int_attr(operation, "log_k_bytecode")?, + bytecode_k: mlir_int_attr(operation, "bytecode_k")?, + log_k_ram: mlir_int_attr(operation, "log_k_ram")?, + ram_k: mlir_int_attr(operation, "ram_k")?, + log_k_chunk: mlir_int_attr(operation, "log_k_chunk")?, + k_chunk: mlir_int_attr(operation, "k_chunk")?, + lookups_ra_virtual_log_k_chunk: mlir_int_attr( + operation, + "lookups_ra_virtual_log_k_chunk", + )?, + instruction_log_k: mlir_int_attr(operation, "instruction_log_k")?, + instruction_d: mlir_int_attr(operation, "instruction_d")?, + instruction_ra_virtual_d: mlir_int_attr(operation, "instruction_ra_virtual_d")?, + bytecode_d: mlir_int_attr(operation, "bytecode_d")?, + ram_d: mlir_int_attr(operation, "ram_d")?, + num_committed: mlir_int_attr(operation, "num_committed")?, + }) + } + + pub(crate) fn validate(&self) -> Result<(), SchemaError> { + require_power_relation("trace_length", self.trace_length, "log_t", self.log_t)?; + require_power_relation( + "bytecode_k", + self.bytecode_k, + "log_k_bytecode", + self.log_k_bytecode, + )?; + require_power_relation("ram_k", self.ram_k, "log_k_ram", self.log_k_ram)?; + require_power_relation("k_chunk", self.k_chunk, "log_k_chunk", self.log_k_chunk)?; + + if self.log_k_chunk != 4 && self.log_k_chunk != 8 { + return Err(SchemaError::new(format!( + "log_k_chunk must be 4 or 8, got {}", + self.log_k_chunk + ))); + } + if self.instruction_log_k != 128 { + return Err(SchemaError::new(format!( + "instruction_log_k must be 128, got {}", + self.instruction_log_k + ))); + } + if self.lookups_ra_virtual_log_k_chunk < self.log_k_chunk { + return Err(SchemaError::new(format!( + "lookups_ra_virtual_log_k_chunk must be >= log_k_chunk; got {} < {}", + self.lookups_ra_virtual_log_k_chunk, self.log_k_chunk + ))); + } + if !self + .lookups_ra_virtual_log_k_chunk + .is_multiple_of(self.log_k_chunk) + { + return Err(SchemaError::new(format!( + "lookups_ra_virtual_log_k_chunk must be a multiple of log_k_chunk; got {} and {}", + self.lookups_ra_virtual_log_k_chunk, self.log_k_chunk + ))); + } + if !self + .instruction_log_k + .is_multiple_of(self.lookups_ra_virtual_log_k_chunk) + { + return Err(SchemaError::new(format!( + "instruction_log_k must be divisible by lookups_ra_virtual_log_k_chunk; got {} and {}", + self.instruction_log_k, self.lookups_ra_virtual_log_k_chunk + ))); + } + + let instruction_d = self.instruction_log_k / self.log_k_chunk; + let instruction_ra_virtual_d = self.instruction_log_k / self.lookups_ra_virtual_log_k_chunk; + let bytecode_d = self.log_k_bytecode.div_ceil(self.log_k_chunk); + let ram_d = self.log_k_ram.div_ceil(self.log_k_chunk); + require_eq("instruction_d", self.instruction_d, instruction_d)?; + require_eq( + "instruction_ra_virtual_d", + self.instruction_ra_virtual_d, + instruction_ra_virtual_d, + )?; + require_eq("bytecode_d", self.bytecode_d, bytecode_d)?; + require_eq("ram_d", self.ram_d, ram_d)?; + require_eq( + "num_committed", + self.num_committed, + 2 + instruction_d + bytecode_d + ram_d, + )?; + Ok(()) + } + + pub(crate) fn main_witness_oracles(&self) -> Vec { + let mut oracles = vec!["RdInc".to_owned(), "RamInc".to_owned()]; + oracles.extend((0..self.instruction_d).map(|index| format!("InstructionRa_{index}"))); + oracles.extend((0..self.ram_d).map(|index| format!("RamRa_{index}"))); + oracles.extend((0..self.bytecode_d).map(|index| format!("BytecodeRa_{index}"))); + oracles + } +} + +fn require_jolt_params_attrs(operation: OperationRef<'_, '_>) -> Result<(), SchemaError> { + require_attrs( + operation, + &[ + "sym_name", + "field", + "pcs", + "transcript", + "xlen", + "log_t", + "trace_length", + "log_k_bytecode", + "bytecode_k", + "log_k_ram", + "ram_k", + "log_k_chunk", + "k_chunk", + "lookups_ra_virtual_log_k_chunk", + "instruction_log_k", + "register_log_k", + "lookup_table_count", + "instruction_d", + "instruction_ra_virtual_d", + "bytecode_d", + "ram_d", + "num_committed", + "num_r1cs_constraints", + "num_r1cs_inputs", + "num_vars_padded", + ], + ) +} + +fn require_power_relation( + value_name: &str, + value: usize, + log_name: &str, + log_value: usize, +) -> Result<(), SchemaError> { + let expected = 1usize << log_value; + if value == expected { + Ok(()) + } else { + Err(SchemaError::new(format!( + "{value_name} must equal 2^{log_name}; got {value}, expected {expected}" + ))) + } +} + +fn require_eq(name: &str, actual: usize, expected: usize) -> Result<(), SchemaError> { + if actual == expected { + Ok(()) + } else { + Err(SchemaError::new(format!( + "{name} must be {expected}, got {actual}" + ))) + } +} diff --git a/crates/bolt/src/protocols/jolt/phases/commitment.rs b/crates/bolt/src/protocols/jolt/phases/commitment.rs new file mode 100644 index 0000000000..10327884cd --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/commitment.rs @@ -0,0 +1,1663 @@ +use std::collections::BTreeMap; + +use melior::ir::block::BlockLike; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::{OperationRef, Value}; + +use crate::ir::{BoltModule, Compute, Cpu, Party, Protocol, Role}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{ + int_attr, operation_name, symbol_array_attr, symbol_attr, verify_compute_schema, + verify_cpu_schema, SchemaError, +}; + +use super::super::oracles::{self, ADVICE_FAMILY_SYMBOL, MAIN_WITNESS_FAMILY_SYMBOL, PCS_SYMBOL}; +use super::super::params::JoltProtocolParams; +use super::super::validate::{verify_jolt_party_schema, verify_jolt_protocol_schema}; +use super::lowering::{ + copy_attrs, field_lowering_attrs as compute_field_attrs, string_attr, + transcript_squeeze_cpu_result_types, +}; + +pub fn build_commitment_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.commitment_phase", None); + oracles::append_foundation_ops(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.commitment_phase"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + oracles::append_committed_oracles(context, &module, params)?; + context.append_op( + &module, + "piop.oracle_family", + Some(MAIN_WITNESS_FAMILY_SYMBOL), + &[ + ( + "ordered_oracles", + &oracles::main_witness_oracle_attr(params), + ), + ("count", &format!("{} : i64", params.num_committed)), + ("domain", "@jolt.main_witness_commit_domain"), + ("visibility", r#""committed""#), + ], + )?; + context.append_op( + &module, + "piop.oracle_family", + Some(ADVICE_FAMILY_SYMBOL), + &[ + ("ordered_oracles", "[@UntrustedAdvice, @TrustedAdvice]"), + ("count", "2 : i64"), + ("domain", "@jolt.trace_domain"), + ("visibility", r#""optional_committed""#), + ], + )?; + let state = context.append_typed_op( + &module, + "transcript.state", + Some("fs0"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let mut state = state + .result(0) + .map_err(|_| schema_error("transcript.state requires one result"))? + .into(); + let main_commitments = context.append_typed_op( + &module, + "commit.publish_batch", + Some("jolt.main_witness_commitments"), + &[ + ("oracle_family", "@jolt.main_witness_polys"), + ("label", r#""commitment""#), + ], + &[], + &["!commit.artifact"], + )?; + let main_commitments = main_commitments + .result(0) + .map_err(|_| schema_error("commit.publish_batch requires one result"))? + .into(); + let _pcs_commit = context.append_typed_op( + &module, + "pcs.commit_batch", + Some("jolt.dory_main_witness_commit"), + &[("scheme", &format!("@{PCS_SYMBOL}"))], + &[main_commitments], + &[], + )?; + let untrusted_advice = context.append_typed_op( + &module, + "commit.publish_optional", + Some("jolt.untrusted_advice_commitment"), + &[ + ("oracle", "@UntrustedAdvice"), + ("label", r#""untrusted_advice""#), + ("skip_policy", r#""missing_or_zero""#), + ], + &[], + &["!commit.artifact"], + )?; + let untrusted_advice = untrusted_advice + .result(0) + .map_err(|_| schema_error("commit.publish_optional requires one result"))? + .into(); + let trusted_advice = context.append_typed_op( + &module, + "commit.publish_optional", + Some("jolt.trusted_advice_commitment"), + &[ + ("oracle", "@TrustedAdvice"), + ("label", r#""trusted_advice""#), + ("skip_policy", r#""missing_or_zero""#), + ], + &[], + &["!commit.artifact"], + )?; + let trusted_advice = trusted_advice + .result(0) + .map_err(|_| schema_error("commit.publish_optional requires one result"))? + .into(); + let absorb = context.append_typed_op( + &module, + "transcript.absorb", + Some("jolt.absorb_main_witness_commitments"), + &[("label", r#""commitment""#)], + &[state, main_commitments], + &["!transcript.state_type"], + )?; + state = absorb + .result(0) + .map_err(|_| schema_error("transcript.absorb requires one result"))? + .into(); + let absorb = context.append_typed_op( + &module, + "transcript.absorb_optional", + Some("jolt.absorb_untrusted_advice"), + &[("label", r#""untrusted_advice""#)], + &[state, untrusted_advice], + &["!transcript.state_type"], + )?; + state = absorb + .result(0) + .map_err(|_| schema_error("transcript.absorb_optional requires one result"))? + .into(); + let _absorb = context.append_typed_op( + &module, + "transcript.absorb_optional", + Some("jolt.absorb_trusted_advice"), + &[("label", r#""trusted_advice""#)], + &[state, trusted_advice], + &["!transcript.state_type"], + )?; + verify_module(&module)?; + verify_jolt_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_commitment_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + verify_jolt_party_schema(module)?; + let role = module + .role() + .ok_or_else(|| schema_error("commitment lowering requires party role"))?; + let concrete = analyze_concrete(module)?; + let (batch_op, optional_op) = match role { + Role::Prover => ("compute.pcs_commit_batch", "compute.pcs_commit_optional"), + Role::Verifier => ("compute.pcs_receive_batch", "compute.pcs_receive_optional"), + }; + let module_name = module.name(); + let compute = context.new_module::(&module_name, Some(role.clone())); + context.append_op_with_owned_attrs( + &compute, + "compute.params", + Some("jolt.compute_params"), + &[ + ("field".to_owned(), symbol_ref(&concrete.params.field)), + ("pcs".to_owned(), symbol_ref(&concrete.params.pcs)), + ( + "transcript".to_owned(), + symbol_ref(&concrete.params.transcript), + ), + ], + )?; + context.append_op( + &compute, + "compute.function", + Some("jolt.commitment_phase"), + &[("source", "@jolt.commitment_phase")], + )?; + + let mut artifact_values = BTreeMap::new(); + let transcript_scheme = symbol_ref(&concrete.params.transcript); + let transcript_init = context.append_typed_op( + &compute, + "compute.transcript_init", + Some("fs0"), + &[("scheme", transcript_scheme.as_str())], + &[], + &["!compute.transcript_state"], + )?; + let mut transcript_value = first_result(transcript_init, "compute.transcript_init")?; + + for plan in &concrete.batch_plans { + let family_symbol = format!("{}.oracle_family.compute", plan.oracle_family); + let family_init = context.append_typed_op_with_owned_attrs( + &compute, + "compute.oracle_family_init", + Some(&family_symbol), + &[ + ("family".to_owned(), symbol_ref(&plan.oracle_family)), + ("count".to_owned(), int_attr_source(plan.count)), + ], + &[], + &["!compute.oracle_family"], + )?; + let mut family_value = first_result(family_init, "compute.oracle_family_init")?; + for (index, oracle) in plan.oracles.iter().enumerate() { + let oracle_buffer = concrete.oracle_buffers.get(oracle).ok_or_else(|| { + schema_error(format!( + "batch commitment references missing oracle buffer @{oracle}" + )) + })?; + let oracle_value = append_oracle_buffer( + context, + &compute, + &role, + &concrete.params, + oracle, + &oracle_buffer.domain, + oracle_buffer.num_vars, + )?; + let append_symbol = format!("{}.append_{index}.compute", plan.oracle_family); + let append = context.append_typed_op_with_owned_attrs( + &compute, + "compute.oracle_family_append", + Some(&append_symbol), + &[ + ("family".to_owned(), symbol_ref(&plan.oracle_family)), + ("oracle".to_owned(), symbol_ref(oracle)), + ("index".to_owned(), int_attr_source(index)), + ], + &[family_value, oracle_value], + &["!compute.oracle_family"], + )?; + family_value = first_result(append, "compute.oracle_family_append")?; + } + let symbol = format!("{}.compute", plan.artifact); + let attrs = vec![ + ("artifact".to_owned(), symbol_ref(&plan.artifact)), + ("count".to_owned(), int_attr_source(plan.count)), + ("domain".to_owned(), symbol_ref(&plan.domain)), + ("label".to_owned(), string_attr_source(&plan.label)), + ("num_vars".to_owned(), int_attr_source(plan.num_vars)), + ("oracle_family".to_owned(), symbol_ref(&plan.oracle_family)), + ( + "ordered_oracles".to_owned(), + symbol_array_attr_source(&plan.oracles), + ), + ("pcs".to_owned(), symbol_ref(&plan.pcs)), + ]; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + batch_op, + Some(&symbol), + &attrs, + &[family_value], + &["!compute.commitment_artifact"], + )?; + let value = first_result(operation, batch_op)?; + let inserted = artifact_values.insert(plan.artifact.clone(), value); + debug_assert!(inserted.is_none()); + } + for plan in &concrete.optional_plans { + let oracle_value = append_optional_oracle_buffer( + context, + &compute, + &role, + &plan.oracle, + &plan.domain, + plan.num_vars, + &plan.skip_policy, + )?; + let symbol = format!("{}.compute", plan.artifact); + let attrs = vec![ + ("artifact".to_owned(), symbol_ref(&plan.artifact)), + ("domain".to_owned(), symbol_ref(&plan.domain)), + ("label".to_owned(), string_attr_source(&plan.label)), + ("num_vars".to_owned(), int_attr_source(plan.num_vars)), + ("oracle".to_owned(), symbol_ref(&plan.oracle)), + ("pcs".to_owned(), symbol_ref(&plan.pcs)), + ( + "skip_policy".to_owned(), + string_attr_source(&plan.skip_policy), + ), + ]; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + optional_op, + Some(&symbol), + &attrs, + &[oracle_value], + &["!compute.commitment_artifact"], + )?; + let value = first_result(operation, optional_op)?; + let inserted = artifact_values.insert(plan.artifact.clone(), value); + debug_assert!(inserted.is_none()); + } + for step in &concrete.transcript_steps { + let artifact = artifact_values.get(&step.source).copied().ok_or_else(|| { + schema_error(format!( + "transcript absorb @{} references missing commitment artifact @{}", + step.symbol, step.source + )) + })?; + let symbol = format!("{}.compute", step.symbol); + let attrs = vec![ + ("label".to_owned(), string_attr_source(&step.label)), + ( + "optional".to_owned(), + bool_attr_source(step.optional).to_owned(), + ), + ]; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_absorb", + Some(&symbol), + &attrs, + &[transcript_value, artifact], + &["!compute.transcript_state"], + )?; + transcript_value = first_result(operation, "compute.transcript_absorb")?; + } + verify_module(&compute)?; + verify_compute_schema(&compute)?; + Ok(compute) +} + +pub fn lower_compute_to_cpu<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Compute>, +) -> Result, MlirError> { + verify_compute_schema(module)?; + let role = module + .role() + .ok_or_else(|| schema_error("CPU lowering requires compute party role"))?; + let module_name = module.name(); + let cpu = context.new_module::(&module_name, Some(role)); + let mut value_map = BTreeMap::new(); + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "compute.function" => { + let source = symbol_ref(&symbol_attr(op, "source")?); + let symbol = string_attr(op, "sym_name")?; + context.append_op(&cpu, "cpu.function", Some(&symbol), &[("source", &source)])?; + } + "compute.params" => { + let symbol = string_attr(op, "sym_name")?; + context.append_op_with_owned_attrs( + &cpu, + "cpu.params", + Some(&symbol), + &[ + ("field".to_owned(), symbol_ref(&symbol_attr(op, "field")?)), + ("pcs".to_owned(), symbol_ref(&symbol_attr(op, "pcs")?)), + ( + "transcript".to_owned(), + symbol_ref(&symbol_attr(op, "transcript")?), + ), + ], + )?; + } + "compute.kernel" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["relation", "kind", "backend", "abi"])?; + context.append_op_with_owned_attrs(&cpu, "cpu.kernel", Some(&symbol), &attrs)?; + } + "compute.transcript_init" => { + let attrs = vec![("scheme".to_owned(), symbol_ref(&symbol_attr(op, "scheme")?))]; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.transcript_init", + Some(&symbol), + &attrs, + &[], + &["!cpu.transcript_state"], + )?; + let value = first_result(operation, "cpu.transcript_init")?; + let inserted = value_map.insert(operation_result_key(op)?, value); + debug_assert!(inserted.is_none()); + } + "compute.oracle_dense_trace" + | "compute.oracle_one_hot_chunk" + | "compute.oracle_optional_advice" + | "compute.oracle_ref" => { + let target_op = operation_name(op).replacen("compute.", "cpu.", 1); + let attrs = copy_attrs( + op, + &[ + "oracle", + "source", + "domain", + "num_vars", + "trace_num_vars", + "chunk", + "num_chunks", + "chunk_bits", + "padding", + "layout", + "skip_policy", + ], + )?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + &target_op, + Some(&symbol), + &attrs, + &[], + &["!cpu.oracle_buffer"], + )?; + let value = first_result(operation, &target_op)?; + let inserted = value_map.insert(operation_result_key(op)?, value); + debug_assert!(inserted.is_none()); + } + "compute.oracle_family_init" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["family", "count"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.oracle_family_init", + Some(&symbol), + &attrs, + &[], + &["!cpu.oracle_family"], + )?; + let value = first_result(operation, "cpu.oracle_family_init")?; + let inserted = value_map.insert(operation_result_key(op)?, value); + debug_assert!(inserted.is_none()); + } + "compute.oracle_family_append" => { + let input = operand_key(op, 0)?; + let oracle = operand_key(op, 1)?; + let input = value_map.get(&input).copied().ok_or_else(|| { + schema_error("compute.oracle_family_append input operand was not lowered") + })?; + let oracle = value_map.get(&oracle).copied().ok_or_else(|| { + schema_error("compute.oracle_family_append oracle operand was not lowered") + })?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["family", "oracle", "index"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.oracle_family_append", + Some(&symbol), + &attrs, + &[input, oracle], + &["!cpu.oracle_family"], + )?; + let value = first_result(operation, "cpu.oracle_family_append")?; + let inserted = value_map.insert(operation_result_key(op)?, value); + debug_assert!(inserted.is_none()); + } + "compute.pcs_commit_batch" | "compute.pcs_receive_batch" => { + let target_op = match operation_name(op).as_str() { + "compute.pcs_commit_batch" => "cpu.pcs_commit_batch", + "compute.pcs_receive_batch" => "cpu.pcs_receive_batch", + _ => unreachable!(), + }; + let attrs = vec![ + ( + "artifact".to_owned(), + symbol_ref(&symbol_attr(op, "artifact")?), + ), + ("count".to_owned(), int_attr_source(int_attr(op, "count")?)), + ("domain".to_owned(), symbol_ref(&symbol_attr(op, "domain")?)), + ( + "label".to_owned(), + string_attr_source(&string_attr(op, "label")?), + ), + ( + "num_vars".to_owned(), + int_attr_source(int_attr(op, "num_vars")?), + ), + ( + "oracle_family".to_owned(), + symbol_ref(&symbol_attr(op, "oracle_family")?), + ), + ( + "ordered_oracles".to_owned(), + symbol_array_attr_source(&symbol_array_attr(op, "ordered_oracles")?), + ), + ("pcs".to_owned(), symbol_ref(&symbol_attr(op, "pcs")?)), + ]; + let oracles = operand_key(op, 0)?; + let oracles = value_map.get(&oracles).copied().ok_or_else(|| { + schema_error("compute.pcs batch oracle family was not lowered") + })?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + target_op, + Some(&symbol), + &attrs, + &[oracles], + &["!cpu.commitment_artifact"], + )?; + let value = first_result(operation, target_op)?; + let inserted = value_map.insert(operation_result_key(op)?, value); + debug_assert!(inserted.is_none()); + } + "compute.pcs_commit_optional" | "compute.pcs_receive_optional" => { + let target_op = match operation_name(op).as_str() { + "compute.pcs_commit_optional" => "cpu.pcs_commit_optional", + "compute.pcs_receive_optional" => "cpu.pcs_receive_optional", + _ => unreachable!(), + }; + let attrs = vec![ + ( + "artifact".to_owned(), + symbol_ref(&symbol_attr(op, "artifact")?), + ), + ("domain".to_owned(), symbol_ref(&symbol_attr(op, "domain")?)), + ( + "label".to_owned(), + string_attr_source(&string_attr(op, "label")?), + ), + ( + "num_vars".to_owned(), + int_attr_source(int_attr(op, "num_vars")?), + ), + ("oracle".to_owned(), symbol_ref(&symbol_attr(op, "oracle")?)), + ("pcs".to_owned(), symbol_ref(&symbol_attr(op, "pcs")?)), + ( + "skip_policy".to_owned(), + string_attr_source(&string_attr(op, "skip_policy")?), + ), + ]; + let oracle = operand_key(op, 0)?; + let oracle = value_map + .get(&oracle) + .copied() + .ok_or_else(|| schema_error("compute.pcs optional oracle was not lowered"))?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + target_op, + Some(&symbol), + &attrs, + &[oracle], + &["!cpu.commitment_artifact"], + )?; + let value = first_result(operation, target_op)?; + let inserted = value_map.insert(operation_result_key(op)?, value); + debug_assert!(inserted.is_none()); + } + "compute.transcript_absorb" => { + let input = operand_key(op, 0)?; + let artifact = operand_key(op, 1)?; + let input = value_map.get(&input).copied().ok_or_else(|| { + schema_error("compute.transcript_absorb input operand was not lowered") + })?; + let artifact = value_map.get(&artifact).copied().ok_or_else(|| { + schema_error("compute.transcript_absorb artifact operand was not lowered") + })?; + let attrs = vec![ + ( + "label".to_owned(), + string_attr_source(&string_attr(op, "label")?), + ), + ( + "optional".to_owned(), + bool_attr_source(bool_attr(op, "optional")?).to_owned(), + ), + ]; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.transcript_absorb", + Some(&symbol), + &attrs, + &[input, artifact], + &["!cpu.transcript_state"], + )?; + let output = first_result(operation, "cpu.transcript_absorb")?; + let inserted = value_map.insert(operation_result_key(op)?, output); + debug_assert!(inserted.is_none()); + } + "compute.transcript_absorb_bytes" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "payload"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.transcript_absorb_bytes", + Some(&symbol), + &attrs, + &operands, + &["!cpu.transcript_state"], + )?; + let output = first_result(operation, "cpu.transcript_absorb_bytes")?; + let inserted = value_map.insert(operation_result_key(op)?, output); + debug_assert!(inserted.is_none()); + } + "compute.transcript_squeeze" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "kind", "count"])?; + let result_types = transcript_squeeze_cpu_result_types(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.transcript_squeeze", + Some(&symbol), + &attrs, + &operands, + &result_types, + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "compute.opening_input" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source_stage", + "source_claim", + "oracle", + "domain", + "point_arity", + "claim_kind", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.opening_input", + Some(&symbol), + &attrs, + &[], + &["!cpu.point", "!cpu.field_value", "!cpu.opening_claim_type"], + )?; + for index in 0..3 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "compute.point_slice" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "offset", "length"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.point_slice", + Some(&symbol), + &attrs, + &operands, + &["!cpu.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.point_zero" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field", "arity"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.point_zero", + Some(&symbol), + &attrs, + &[], + &["!cpu.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.point_concat" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["layout", "arity"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.point_concat", + Some(&symbol), + &attrs, + &operands, + &["!cpu.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.field_const" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field", "value"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.field_const", + Some(&symbol), + &attrs, + &[], + &["!cpu.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.field_zero" | "compute.field_one" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + &operation_name(op).replace("compute.", "cpu."), + Some(&symbol), + &attrs, + &[], + &["!cpu.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.field_add" + | "compute.field_sub" + | "compute.field_mul" + | "compute.field_neg" + | "compute.field_pow" + | "compute.poly_lagrange_basis_eval" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = compute_field_attrs(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + &operation_name(op).replace("compute.", "cpu."), + Some(&symbol), + &attrs, + &operands, + &["!cpu.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_kernel_claim" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &["stage", "domain", "num_rounds", "degree", "claim", "kernel"], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.sumcheck_claim", + Some(&symbol), + &attrs, + &operands, + &["!cpu.sumcheck_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_verify_claim" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.sumcheck_verify_claim", + Some(&symbol), + &attrs, + &operands, + &["!cpu.sumcheck_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_batch" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.sumcheck_batch", + Some(&symbol), + &attrs, + &operands, + &["!cpu.sumcheck_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_kernel_driver" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "kernel", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.sumcheck_driver", + Some(&symbol), + &attrs, + &operands, + &[ + "!cpu.transcript_state", + "!cpu.point", + "!cpu.sumcheck_result_type", + "!cpu.sumcheck_proof_type", + ], + )?; + for index in 0..4 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "compute.sumcheck_verify" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.sumcheck_verify", + Some(&symbol), + &attrs, + &operands, + &[ + "!cpu.transcript_state", + "!cpu.point", + "!cpu.sumcheck_result_type", + "!cpu.sumcheck_proof_type", + ], + )?; + for index in 0..4 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "compute.sumcheck_eval" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "name", "index", "oracle"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.sumcheck_eval", + Some(&symbol), + &attrs, + &operands, + &["!cpu.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_instance_result" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.sumcheck_instance_result", + Some(&symbol), + &attrs, + &operands, + &["!cpu.point", "!cpu.sumcheck_result_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "compute.opening_claim" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["oracle", "domain", "point_arity", "claim_kind"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!cpu.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.opening_claim_equal" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["mode"])?; + let _operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.opening_claim_equal", + Some(&symbol), + &attrs, + &operands, + &[], + )?; + } + "compute.opening_batch" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &["stage", "proof_slot", "policy", "count", "ordered_claims"], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!cpu.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.pcs_opening_claim" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["oracle", "family", "domain", "point_arity"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.pcs_opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!cpu.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.pcs_opening_batch" => { + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["proof_slot", "policy", "count", "ordered_claims"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + "cpu.pcs_opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!cpu.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.pcs_batch_open" | "compute.pcs_batch_verify" => { + let target_op = match operation_name(op).as_str() { + "compute.pcs_batch_open" => "cpu.pcs_batch_open", + "compute.pcs_batch_verify" => "cpu.pcs_batch_verify", + _ => unreachable!(), + }; + let operands = lowered_operands(op, &value_map)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["pcs", "proof_slot", "transcript_label"])?; + let operation = context.append_typed_op_with_owned_attrs( + &cpu, + target_op, + Some(&symbol), + &attrs, + &operands, + &["!cpu.transcript_state", "!cpu.opening_proof_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + _ => {} + } + } + verify_module(&cpu)?; + verify_cpu_schema(&cpu)?; + Ok(cpu) +} + +#[derive(Clone, Debug)] +struct ConcreteCommitmentAst { + params: ParamsAst, + oracle_buffers: BTreeMap, + batch_plans: Vec, + optional_plans: Vec, + transcript_steps: Vec, +} + +#[derive(Clone, Debug)] +struct ParamsAst { + field: String, + pcs: String, + transcript: String, + log_t: usize, + log_k_chunk: usize, + instruction_d: usize, + bytecode_d: usize, + ram_d: usize, +} + +#[derive(Clone, Debug)] +struct DomainAst { + num_vars: usize, +} + +#[derive(Clone, Debug)] +struct OracleAst { + domain: String, + commit_domain: String, + layout: String, +} + +#[derive(Clone, Debug)] +struct OracleBufferAst { + domain: String, + num_vars: usize, +} + +#[derive(Clone, Debug)] +struct OracleFamilyAst { + oracles: Vec, + count: usize, + domain: String, +} + +#[derive(Clone, Debug)] +struct PublishedBatchAst { + artifact: String, + oracle_family: String, + label: String, +} + +#[derive(Clone, Debug)] +struct PublishedOptionalAst { + artifact: String, + oracle: String, + label: String, + skip_policy: String, +} + +#[derive(Clone, Debug)] +struct BatchPlanAst { + artifact: String, + pcs: String, + oracle_family: String, + oracles: Vec, + label: String, + domain: String, + num_vars: usize, + count: usize, +} + +#[derive(Clone, Debug)] +struct OptionalPlanAst { + artifact: String, + pcs: String, + oracle: String, + label: String, + domain: String, + num_vars: usize, + skip_policy: String, +} + +#[derive(Clone, Debug)] +struct TranscriptStepAst { + symbol: String, + label: String, + source: String, + optional: bool, +} + +fn analyze_concrete

(module: &BoltModule<'_, P>) -> Result +where + P: crate::ir::Phase, +{ + let mut params = None; + let mut domains = BTreeMap::new(); + let mut oracles = BTreeMap::new(); + let mut families = BTreeMap::new(); + let mut published_batches = Vec::new(); + let mut published_optional = Vec::new(); + let mut pcs_by_artifact = BTreeMap::new(); + let mut transcript_steps = Vec::new(); + + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "protocol.params" => { + params = Some(ParamsAst { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + log_t: int_attr(op, "log_t")?, + log_k_chunk: int_attr(op, "log_k_chunk")?, + instruction_d: int_attr(op, "instruction_d")?, + bytecode_d: int_attr(op, "bytecode_d")?, + ram_d: int_attr(op, "ram_d")?, + }); + } + "poly.domain" => { + let _ = domains.insert( + string_attr(op, "sym_name")?, + DomainAst { + num_vars: int_attr(op, "log_size")?, + }, + ); + } + "piop.oracle" => { + let _ = oracles.insert( + string_attr(op, "sym_name")?, + OracleAst { + domain: symbol_attr(op, "domain")?, + commit_domain: symbol_attr(op, "commit_domain")?, + layout: string_attr(op, "layout")?, + }, + ); + } + "piop.oracle_family" => { + let _ = families.insert( + string_attr(op, "sym_name")?, + OracleFamilyAst { + oracles: symbol_array_attr(op, "ordered_oracles")?, + count: int_attr(op, "count")?, + domain: symbol_attr(op, "domain")?, + }, + ); + } + "commit.publish_batch" => published_batches.push(PublishedBatchAst { + artifact: string_attr(op, "sym_name")?, + oracle_family: symbol_attr(op, "oracle_family")?, + label: string_attr(op, "label")?, + }), + "commit.publish_optional" => published_optional.push(PublishedOptionalAst { + artifact: string_attr(op, "sym_name")?, + oracle: symbol_attr(op, "oracle")?, + label: string_attr(op, "label")?, + skip_policy: string_attr(op, "skip_policy")?, + }), + "pcs.commit_batch" => { + let _ = pcs_by_artifact + .insert(pcs_commitment_artifact(op)?, symbol_attr(op, "scheme")?); + } + "transcript.absorb" | "transcript.absorb_optional" => { + transcript_steps.push(TranscriptStepAst { + symbol: string_attr(op, "sym_name")?, + label: string_attr(op, "label")?, + source: transcript_artifact_source(op)?, + optional: operation_name(op) == "transcript.absorb_optional", + }); + } + _ => {} + } + } + + let params = params.ok_or_else(|| schema_error("missing protocol.params"))?; + let mut oracle_buffers = BTreeMap::new(); + for (symbol, oracle) in &oracles { + let buffer_domain = oracle_buffer_domain(oracle); + let domain = domains.get(buffer_domain).ok_or_else(|| { + schema_error(format!( + "oracle @{symbol} references missing buffer domain @{buffer_domain}" + )) + })?; + let _ = oracle_buffers.insert( + symbol.clone(), + OracleBufferAst { + domain: buffer_domain.to_owned(), + num_vars: domain.num_vars, + }, + ); + } + + let mut batch_plans = Vec::new(); + for batch in published_batches { + let family = families.get(&batch.oracle_family).ok_or_else(|| { + schema_error(format!( + "commitment artifact @{} references missing oracle family @{}", + batch.artifact, batch.oracle_family + )) + })?; + let domain = domains.get(&family.domain).ok_or_else(|| { + schema_error(format!( + "oracle family @{} references missing domain @{}", + batch.oracle_family, family.domain + )) + })?; + batch_plans.push(BatchPlanAst { + pcs: pcs_by_artifact + .get(&batch.artifact) + .cloned() + .unwrap_or_else(|| params.pcs.clone()), + artifact: batch.artifact, + oracle_family: batch.oracle_family, + oracles: family.oracles.clone(), + label: batch.label, + domain: family.domain.clone(), + num_vars: domain.num_vars, + count: family.count, + }); + } + + let mut optional_plans = Vec::new(); + for optional in published_optional { + let oracle = oracles.get(&optional.oracle).ok_or_else(|| { + schema_error(format!( + "commitment artifact @{} references missing oracle @{}", + optional.artifact, optional.oracle + )) + })?; + let domain = domains.get(&oracle.commit_domain).ok_or_else(|| { + schema_error(format!( + "oracle @{} references missing commit domain @{}", + optional.oracle, oracle.commit_domain + )) + })?; + optional_plans.push(OptionalPlanAst { + pcs: params.pcs.clone(), + artifact: optional.artifact, + oracle: optional.oracle, + label: optional.label, + domain: oracle.commit_domain.clone(), + num_vars: domain.num_vars, + skip_policy: optional.skip_policy, + }); + } + + Ok(ConcreteCommitmentAst { + params, + oracle_buffers, + batch_plans, + optional_plans, + transcript_steps, + }) +} + +fn oracle_buffer_domain(oracle: &OracleAst) -> &str { + if oracle.layout == "onehot_expanded" { + &oracle.commit_domain + } else { + &oracle.domain + } +} + +fn append_oracle_buffer<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Compute>, + role: &Role, + params: &ParamsAst, + oracle: &str, + domain: &str, + num_vars: usize, +) -> Result, MlirError> { + let symbol = format!("jolt.oracle.{oracle}.compute"); + match role { + Role::Verifier => append_oracle_ref(context, module, &symbol, oracle, domain, num_vars), + Role::Prover => { + let recipe = oracle_recipe(oracle, params)?; + let attrs = recipe.attrs(oracle, domain, num_vars, params); + let operation = context.append_typed_op_with_owned_attrs( + module, + recipe.op_name(), + Some(&symbol), + &attrs, + &[], + &["!compute.oracle_buffer"], + )?; + first_result(operation, recipe.op_name()) + } + } +} + +fn append_optional_oracle_buffer<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Compute>, + role: &Role, + oracle: &str, + domain: &str, + num_vars: usize, + skip_policy: &str, +) -> Result, MlirError> { + let symbol = format!("jolt.oracle.{oracle}.compute"); + match role { + Role::Verifier => append_oracle_ref(context, module, &symbol, oracle, domain, num_vars), + Role::Prover => { + let operation = context.append_typed_op_with_owned_attrs( + module, + "compute.oracle_optional_advice", + Some(&symbol), + &[ + ("oracle".to_owned(), symbol_ref(oracle)), + ( + "source".to_owned(), + symbol_ref(&optional_advice_source(oracle)?), + ), + ("domain".to_owned(), symbol_ref(domain)), + ("num_vars".to_owned(), int_attr_source(num_vars)), + ("skip_policy".to_owned(), string_attr_source(skip_policy)), + ], + &[], + &["!compute.oracle_buffer"], + )?; + first_result(operation, "compute.oracle_optional_advice") + } + } +} + +fn append_oracle_ref<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Compute>, + symbol: &str, + oracle: &str, + domain: &str, + num_vars: usize, +) -> Result, MlirError> { + let operation = context.append_typed_op_with_owned_attrs( + module, + "compute.oracle_ref", + Some(symbol), + &[ + ("oracle".to_owned(), symbol_ref(oracle)), + ("domain".to_owned(), symbol_ref(domain)), + ("num_vars".to_owned(), int_attr_source(num_vars)), + ], + &[], + &["!compute.oracle_buffer"], + )?; + first_result(operation, "compute.oracle_ref") +} + +#[derive(Clone, Debug)] +enum OracleRecipe { + DenseTrace { + source: &'static str, + }, + OneHotChunk { + source: &'static str, + chunk: usize, + num_chunks: usize, + padding: &'static str, + }, +} + +impl OracleRecipe { + fn op_name(&self) -> &'static str { + match self { + Self::DenseTrace { .. } => "compute.oracle_dense_trace", + Self::OneHotChunk { .. } => "compute.oracle_one_hot_chunk", + } + } + + fn attrs( + &self, + oracle: &str, + domain: &str, + num_vars: usize, + params: &ParamsAst, + ) -> Vec<(String, String)> { + let mut attrs = vec![ + ("oracle".to_owned(), symbol_ref(oracle)), + ("domain".to_owned(), symbol_ref(domain)), + ("num_vars".to_owned(), int_attr_source(num_vars)), + ]; + match self { + Self::DenseTrace { source } => { + attrs.push(("source".to_owned(), symbol_ref(source))); + attrs.push(("padding".to_owned(), string_attr_source("zero"))); + } + Self::OneHotChunk { + source, + chunk, + num_chunks, + padding, + } => { + attrs.push(("source".to_owned(), symbol_ref(source))); + attrs.push(("trace_num_vars".to_owned(), int_attr_source(params.log_t))); + attrs.push(("chunk".to_owned(), int_attr_source(*chunk))); + attrs.push(("num_chunks".to_owned(), int_attr_source(*num_chunks))); + attrs.push(("chunk_bits".to_owned(), int_attr_source(params.log_k_chunk))); + attrs.push(("padding".to_owned(), string_attr_source(padding))); + attrs.push(("layout".to_owned(), string_attr_source("address_major"))); + } + } + attrs + } +} + +fn oracle_recipe(oracle: &str, params: &ParamsAst) -> Result { + if oracle == "RdInc" { + return Ok(OracleRecipe::DenseTrace { + source: "trace.rd_inc", + }); + } + if oracle == "RamInc" { + return Ok(OracleRecipe::DenseTrace { + source: "trace.ram_inc", + }); + } + if let Some(index) = parse_indexed_oracle(oracle, "InstructionRa") { + return Ok(OracleRecipe::OneHotChunk { + source: "trace.instruction_keys", + chunk: index, + num_chunks: params.instruction_d, + padding: "zero", + }); + } + if let Some(index) = parse_indexed_oracle(oracle, "RamRa") { + return Ok(OracleRecipe::OneHotChunk { + source: "trace.ram_addresses", + chunk: index, + num_chunks: params.ram_d, + padding: "none", + }); + } + if let Some(index) = parse_indexed_oracle(oracle, "BytecodeRa") { + return Ok(OracleRecipe::OneHotChunk { + source: "trace.bytecode_indices", + chunk: index, + num_chunks: params.bytecode_d, + padding: "zero", + }); + } + Err(schema_error(format!( + "unsupported commitment oracle @{oracle}" + ))) +} + +fn parse_indexed_oracle(oracle: &str, prefix: &str) -> Option { + oracle.strip_prefix(prefix)?.strip_prefix('_')?.parse().ok() +} + +fn optional_advice_source(oracle: &str) -> Result { + match oracle { + "UntrustedAdvice" => Ok("advice.untrusted".to_owned()), + "TrustedAdvice" => Ok("advice.trusted".to_owned()), + _ => Err(schema_error(format!( + "unsupported optional advice oracle @{oracle}" + ))), + } +} + +fn bool_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(|attribute| match attribute.to_string().as_str() { + "true" => Some(true), + "false" => Some(false), + _ => None, + }) + .ok() + .flatten() + .ok_or_else(|| { + schema_error(format!( + "{} attr `{attr}` is not a bool", + operation_name(operation) + )) + }) +} + +fn operation_result_key(operation: OperationRef<'_, '_>) -> Result { + operation_result_key_at(operation, 0) +} + +fn operation_result_key_at( + operation: OperationRef<'_, '_>, + index: usize, +) -> Result { + let result = operation.result(index).map_err(|_| { + schema_error(format!( + "{} requires result {index}", + operation_name(operation) + )) + })?; + result_key(result.owner(), result.result_number()) +} + +fn result_key(operation: OperationRef<'_, '_>, result_number: usize) -> Result { + Ok(format!( + "{}#{result_number}", + string_attr(operation, "sym_name")? + )) +} + +fn operand_key(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + schema_error(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + schema_error(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + result_key(owner.owner(), owner.result_number()).map_err(|_| { + schema_error(format!( + "{} operand {index} owner missing sym_name", + operation_name(operation) + )) + }) +} + +fn lowered_operands<'c, 'a>( + operation: OperationRef<'_, '_>, + value_map: &BTreeMap>, +) -> Result>, MlirError> { + (0..operation.operand_count()) + .map(|index| { + let key = operand_key(operation, index)?; + value_map.get(&key).copied().ok_or_else(|| { + schema_error(format!( + "{} operand {index} was not lowered", + operation_name(operation) + )) + }) + }) + .collect() +} + +fn insert_result_mapping<'c, 'a>( + value_map: &mut BTreeMap>, + source: OperationRef<'_, '_>, + target: OperationRef<'c, 'a>, + source_index: usize, + target_index: usize, +) -> Result<(), MlirError> { + let key = operation_result_key_at(source, source_index)?; + let value = target.result(target_index).map(Into::into).map_err(|_| { + schema_error(format!( + "{} requires result {target_index}", + operation_name(target) + )) + })?; + let inserted = value_map.insert(key, value); + debug_assert!(inserted.is_none()); + Ok(()) +} + +fn first_result<'c, 'a>( + operation: OperationRef<'c, 'a>, + operation_name: &str, +) -> Result, MlirError> { + operation + .result(0) + .map(Into::into) + .map_err(|_| schema_error(format!("{operation_name} requires one result"))) +} + +fn pcs_commitment_artifact(operation: OperationRef<'_, '_>) -> Result { + let artifact = operation.operand(0).map_err(|_| { + schema_error(format!( + "{} requires commitment artifact operand 0", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(artifact).map_err(|_| { + schema_error(format!( + "{} commitment operand must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn transcript_artifact_source(operation: OperationRef<'_, '_>) -> Result { + if let Ok(source) = symbol_attr(operation, "source") { + return Ok(source); + } + let artifact = operation.operand(1).map_err(|_| { + schema_error(format!( + "{} requires commitment artifact operand 1", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(artifact).map_err(|_| { + schema_error(format!( + "{} artifact operand must be an op result", + operation_name(operation) + )) + })?; + string_attr(owner.owner(), "sym_name") +} + +fn schema_error(message: impl Into) -> MlirError { + let error = SchemaError::new(message); + error.into() +} + +fn symbol_ref(value: &str) -> String { + format!("@{value}") +} + +fn string_attr_source(value: &str) -> String { + format!("{value:?}") +} + +fn symbol_array_attr_source(values: &[String]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn int_attr_source(value: usize) -> String { + format!("{value} : i64") +} + +fn bool_attr_source(value: bool) -> &'static str { + if value { + "true" + } else { + "false" + } +} diff --git a/crates/bolt/src/protocols/jolt/phases/lowering.rs b/crates/bolt/src/protocols/jolt/phases/lowering.rs new file mode 100644 index 0000000000..ed50ed20db --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/lowering.rs @@ -0,0 +1,641 @@ +use std::collections::BTreeMap; + +use melior::ir::block::BlockLike; +use melior::ir::operation::OperationResult; +use melior::ir::operation::{OperationLike, OperationRef}; +use melior::ir::Value; + +use crate::ir::{string_attribute_value, BoltModule, Compute, Party, Role}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{operation_name, verify_compute_schema, verify_party_schema, SchemaError}; + +pub(super) fn copy_attrs( + operation: OperationRef<'_, '_>, + attrs: &[&str], +) -> Result, MlirError> { + attrs + .iter() + .filter_map(|attr| { + operation + .attribute(attr) + .ok() + .map(|value| Ok(((*attr).to_owned(), value.to_string()))) + }) + .collect() +} + +pub(super) fn string_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| { + schema_error(format!( + "{} attr `{attr}` is not a string", + operation_name(operation) + )) + }) +} + +pub(super) fn transcript_squeeze_protocol_result_type( + kind: &str, +) -> Result<&'static str, MlirError> { + transcript_squeeze_value_type(kind, "!poly.point", "!field.scalar") +} + +pub(super) fn transcript_squeeze_compute_result_types( + operation: OperationRef<'_, '_>, +) -> Result<[&'static str; 2], MlirError> { + Ok([ + "!compute.transcript_state", + transcript_squeeze_value_type( + string_attr(operation, "kind")?.as_str(), + "!compute.point", + "!compute.field_value", + )?, + ]) +} + +pub(super) fn transcript_squeeze_cpu_result_types( + operation: OperationRef<'_, '_>, +) -> Result<[&'static str; 2], MlirError> { + Ok([ + "!cpu.transcript_state", + transcript_squeeze_value_type( + string_attr(operation, "kind")?.as_str(), + "!cpu.point", + "!cpu.field_value", + )?, + ]) +} + +pub(super) fn field_lowering_attrs( + operation: OperationRef<'_, '_>, +) -> Result, MlirError> { + match operation_name(operation).as_str() { + "field.pow" | "compute.field_pow" => copy_attrs(operation, &["exponent"]), + "poly.lagrange_basis_eval" | "compute.poly_lagrange_basis_eval" => { + copy_attrs(operation, &["domain_start", "domain_size", "index"]) + } + _ => Ok(Vec::new()), + } +} + +pub(super) fn lower_party_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, + function_symbol: &str, + source_symbol: &str, + stage_label: &str, +) -> Result, MlirError> { + verify_party_schema(module)?; + let role = module + .role() + .ok_or_else(|| schema_error(format!("{stage_label} lowering requires party role")))?; + let compute = context.new_module::(&module.name(), Some(role.clone())); + let params_attrs = compute_params_attrs(module)?; + context.append_op_with_owned_attrs( + &compute, + "compute.params", + Some("jolt.compute_params"), + ¶ms_attrs, + )?; + context.append_op( + &compute, + "compute.function", + Some(function_symbol), + &[("source", &format!("@{source_symbol}"))], + )?; + + let mut value_map = BTreeMap::new(); + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "piop.relation" => { + let attrs = copy_attrs( + op, + &["kind", "domain", "num_rounds", "degree", "output_count"], + )?; + let symbol = string_attr(op, "sym_name")?; + context.append_op_with_owned_attrs( + &compute, + "compute.relation", + Some(&symbol), + &attrs, + )?; + } + "transcript.state" => { + let attrs = copy_attrs(op, &["scheme"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_init", + Some(&symbol), + &attrs, + &[], + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "transcript.absorb_bytes" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "payload"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_absorb_bytes", + Some(&symbol), + &attrs, + &operands, + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "transcript.squeeze" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "kind", "count"])?; + let result_types = transcript_squeeze_compute_result_types(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_squeeze", + Some(&symbol), + &attrs, + &operands, + &result_types, + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "field.const" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field", "value"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.field_const", + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "field.zero" | "field.one" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + &format!("compute.{}", operation_name(op).replace('.', "_")), + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "field.add" | "field.sub" | "field.mul" | "field.neg" | "field.pow" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = field_lowering_attrs(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + &format!("compute.{}", operation_name(op).replace('.', "_")), + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "poly.lagrange_basis_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["domain_start", "domain_size", "index"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.poly_lagrange_basis_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.opening_input" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source_stage", + "source_claim", + "oracle", + "domain", + "point_arity", + "claim_kind", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_input", + Some(&symbol), + &attrs, + &[], + &[ + "!compute.point", + "!compute.field_value", + "!compute.opening_claim_type", + ], + )?; + for index in 0..3 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "poly.point_slice" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "offset", "length"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.point_slice", + Some(&symbol), + &attrs, + &operands, + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "poly.point_zero" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field", "arity"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.point_zero", + Some(&symbol), + &attrs, + &[], + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "poly.point_concat" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["layout", "arity"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.point_concat", + Some(&symbol), + &attrs, + &operands, + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + let target_op = match &role { + Role::Prover => "compute.sumcheck_claim", + Role::Verifier => "compute.sumcheck_verify_claim", + }; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_batch" => { + let operands = lowered_operands(op, &value_map, 1)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + let target_op = match &role { + Role::Prover => "compute.sumcheck_driver", + Role::Verifier => "compute.sumcheck_verify", + }; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &[ + "!compute.transcript_state", + "!compute.point", + "!compute.sumcheck_result_type", + "!compute.sumcheck_proof_type", + ], + )?; + for index in 0..4 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "piop.sumcheck_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "name", "index", "oracle"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_instance_result" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_instance_result", + Some(&symbol), + &attrs, + &operands, + &["!compute.point", "!compute.sumcheck_result_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "piop.opening_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["oracle", "domain", "point_arity", "claim_kind"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.opening_claim_equal" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["mode"])?; + let _operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_claim_equal", + Some(&symbol), + &attrs, + &operands, + &[], + )?; + } + "piop.opening_batch" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &["stage", "proof_slot", "policy", "count", "ordered_claims"], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "pcs.opening_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["oracle", "family", "domain", "point_arity"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.pcs_opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "pcs.opening_batch" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["proof_slot", "policy", "count", "ordered_claims"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.pcs_opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "pcs.batch_open" | "pcs.batch_verify" => { + let target_op = match &role { + Role::Prover => "compute.pcs_batch_open", + Role::Verifier => "compute.pcs_batch_verify", + }; + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["pcs", "proof_slot", "transcript_label"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &["!compute.transcript_state", "!compute.opening_proof_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + _ => {} + } + } + + verify_module(&compute)?; + verify_compute_schema(&compute)?; + Ok(compute) +} + +fn compute_params_attrs( + module: &BoltModule<'_, P>, +) -> Result, MlirError> { + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + if operation_name(op) == "protocol.params" { + return copy_attrs(op, &["field", "pcs", "transcript"]); + } + } + Err(schema_error("module missing protocol.params")) +} + +fn transcript_squeeze_value_type( + kind: &str, + point_type: &'static str, + scalar_type: &'static str, +) -> Result<&'static str, MlirError> { + match kind { + "challenge_vector" => Ok(point_type), + "challenge_scalar" | "scalar" => Ok(scalar_type), + kind => Err(schema_error(format!( + "unsupported transcript squeeze kind `{kind}`" + ))), + } +} + +fn schema_error(message: impl Into) -> MlirError { + SchemaError::new(message).into() +} + +fn operation_result_key_at( + operation: OperationRef<'_, '_>, + index: usize, +) -> Result { + let result = operation.result(index).map_err(|_| { + schema_error(format!( + "{} requires result {index}", + operation_name(operation) + )) + })?; + result_key(operation, result.result_number()).map_err(|_| { + schema_error(format!( + "{} result {index} owner missing sym_name", + operation_name(operation) + )) + }) +} + +fn result_key(operation: OperationRef<'_, '_>, result_number: usize) -> Result { + let symbol = string_attr(operation, "sym_name")?; + Ok(format!("{symbol}#{result_number}")) +} + +fn operand_key(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + schema_error(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + schema_error(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + result_key(owner.owner(), owner.result_number()).map_err(|_| { + schema_error(format!( + "{} operand {index} owner missing sym_name", + operation_name(operation) + )) + }) +} + +fn lowered_operands<'c, 'a>( + operation: OperationRef<'_, '_>, + value_map: &BTreeMap>, + start_index: usize, +) -> Result>, MlirError> { + (start_index..operation.operand_count()) + .map(|index| { + let key = operand_key(operation, index)?; + value_map.get(&key).copied().ok_or_else(|| { + schema_error(format!( + "{} operand {index} was not lowered", + operation_name(operation) + )) + }) + }) + .collect() +} + +fn insert_result_mapping<'c, 'a>( + value_map: &mut BTreeMap>, + source: OperationRef<'_, '_>, + target: OperationRef<'c, 'a>, + source_index: usize, + target_index: usize, +) -> Result<(), MlirError> { + let key = operation_result_key_at(source, source_index)?; + let value = target.result(target_index).map(Into::into).map_err(|_| { + schema_error(format!( + "{} requires result {target_index}", + operation_name(target) + )) + })?; + let _ = value_map.insert(key, value); + Ok(()) +} diff --git a/crates/bolt/src/protocols/jolt/phases/mod.rs b/crates/bolt/src/protocols/jolt/phases/mod.rs new file mode 100644 index 0000000000..19f9646db3 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/mod.rs @@ -0,0 +1,11 @@ +mod lowering; + +pub mod commitment; +pub mod stage1; +pub mod stage2; +pub mod stage3; +pub mod stage4; +pub mod stage5; +pub mod stage6; +pub mod stage7; +pub mod stage8; diff --git a/crates/bolt/src/protocols/jolt/phases/stage1.rs b/crates/bolt/src/protocols/jolt/phases/stage1.rs new file mode 100644 index 0000000000..7501e62768 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/stage1.rs @@ -0,0 +1,1723 @@ +use std::collections::BTreeMap; + +use melior::ir::block::BlockLike; +use melior::ir::operation::OperationRef; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::Value; + +use crate::ir::{BoltModule, Compute, Party, Protocol, Role}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{ + operation_name, symbol_attr, verify_compute_schema, verify_party_schema, + verify_protocol_schema, SchemaError, +}; + +use super::super::oracles; +use super::super::params::JoltProtocolParams; +use super::lowering::{ + copy_attrs, field_lowering_attrs as field_compute_attrs, string_attr, + transcript_squeeze_compute_result_types, +}; + +const R1CS_INPUT_ORACLES: [&str; 35] = [ + "LeftInstructionInput", + "RightInstructionInput", + "Product", + "ShouldBranch", + "PC", + "UnexpandedPC", + "Imm", + "RamAddress", + "Rs1Value", + "Rs2Value", + "RdWriteValue", + "RamReadValue", + "RamWriteValue", + "LeftLookupOperand", + "RightLookupOperand", + "NextUnexpandedPC", + "NextPC", + "NextIsVirtual", + "NextIsFirstInSequence", + "LookupOutput", + "ShouldJump", + "OpFlagAddOperands", + "OpFlagSubtractOperands", + "OpFlagMultiplyOperands", + "OpFlagLoad", + "OpFlagStore", + "OpFlagJump", + "OpFlagWriteLookupOutputToRD", + "OpFlagVirtualInstruction", + "OpFlagAssert", + "OpFlagDoNotUpdateUnexpandedPC", + "OpFlagAdvice", + "OpFlagIsCompressed", + "OpFlagIsFirstInSequence", + "OpFlagIsLastInSequence", +]; +const OUTER_UNISKIP_FIRST_ROUND_DEGREE_BOUND: usize = 27; + +pub fn build_stage1_outer_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.stage1_outer", None); + oracles::append_foundation_ops(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.stage1_outer"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + append_stage1_virtual_oracles(context, &module, params)?; + append_stage1_relations(context, &module, params)?; + + let fs0 = context.append_typed_op( + &module, + "transcript.state", + Some("fs0"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let state = first_result(fs0, "transcript.state")?; + let tau = context.append_typed_op( + &module, + "transcript.squeeze", + Some("stage1.tau"), + &[ + ("label", r#""outer_tau""#), + ("kind", r#""challenge_vector""#), + ("count", &int_attr(params.log_t + 2)), + ], + &[state], + &["!transcript.state_type", "!poly.point"], + )?; + let state = first_result(tau, "transcript.squeeze")?; + + let stage = context.append_typed_op( + &module, + "piop.stage", + Some("stage1"), + &[ + ("name", r#""spartan_outer""#), + ("order", "1 : i64"), + ("roles", r#"["prover", "verifier"]"#), + ], + &[], + &["!piop.stage_type"], + )?; + let stage = first_result(stage, "piop.stage")?; + let zero_claim = append_field_zero(context, &module, "stage1.zero")?; + + let (state, uniskip_opening, uniskip_eval) = + append_uniskip_sumcheck(context, &module, params, state, stage, zero_claim)?; + let _state = append_remaining_sumcheck( + context, + &module, + params, + state, + stage, + uniskip_eval, + uniskip_opening, + )?; + + verify_module(&module)?; + verify_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_stage1_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + verify_party_schema(module)?; + let role = module + .role() + .ok_or_else(|| schema_error("stage1 lowering requires party role"))?; + let params = stage_params(module)?; + let compute = context.new_module::(&module.name(), Some(role.clone())); + context.append_op_with_owned_attrs( + &compute, + "compute.params", + Some("jolt.compute_params"), + &[ + ("field".to_owned(), symbol_ref(¶ms.field)), + ("pcs".to_owned(), symbol_ref(¶ms.pcs)), + ("transcript".to_owned(), symbol_ref(¶ms.transcript)), + ], + )?; + context.append_op( + &compute, + "compute.function", + Some("jolt.stage1_outer"), + &[("source", "@jolt.stage1_outer")], + )?; + + let mut value_map = BTreeMap::new(); + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "piop.relation" => { + let attrs = copy_attrs( + op, + &["kind", "domain", "num_rounds", "degree", "output_count"], + )?; + let symbol = string_attr(op, "sym_name")?; + context.append_op_with_owned_attrs( + &compute, + "compute.relation", + Some(&symbol), + &attrs, + )?; + } + "transcript.state" => { + let attrs = copy_attrs(op, &["scheme"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_init", + Some(&symbol), + &attrs, + &[], + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "transcript.absorb_bytes" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "payload"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_absorb_bytes", + Some(&symbol), + &attrs, + &operands, + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "transcript.squeeze" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "kind", "count"])?; + let result_types = transcript_squeeze_compute_result_types(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_squeeze", + Some(&symbol), + &attrs, + &operands, + &result_types, + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "field.const" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field", "value"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.field_const", + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "field.zero" | "field.one" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + &format!("compute.{}", operation_name(op).replace('.', "_")), + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "field.add" | "field.sub" | "field.mul" | "field.neg" | "field.pow" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = field_compute_attrs(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + &format!("compute.{}", operation_name(op).replace('.', "_")), + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "poly.lagrange_basis_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["domain_start", "domain_size", "index"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.poly_lagrange_basis_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + let target_op = match &role { + Role::Prover => "compute.sumcheck_claim", + Role::Verifier => "compute.sumcheck_verify_claim", + }; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_batch" => { + let operands = lowered_operands(op, &value_map, 1)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + let target_op = match &role { + Role::Prover => "compute.sumcheck_driver", + Role::Verifier => "compute.sumcheck_verify", + }; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &[ + "!compute.transcript_state", + "!compute.point", + "!compute.sumcheck_result_type", + "!compute.sumcheck_proof_type", + ], + )?; + for index in 0..4 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "piop.sumcheck_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "name", "index", "oracle"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_instance_result" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_instance_result", + Some(&symbol), + &attrs, + &operands, + &["!compute.point", "!compute.sumcheck_result_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "piop.opening_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["oracle", "domain", "point_arity", "claim_kind"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.opening_claim_equal" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["mode"])?; + let _operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_claim_equal", + Some(&symbol), + &attrs, + &operands, + &[], + )?; + } + "piop.opening_batch" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &["stage", "proof_slot", "policy", "count", "ordered_claims"], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + _ => {} + } + } + + verify_module(&compute)?; + verify_compute_schema(&compute)?; + Ok(compute) +} + +pub fn resolve_compute_kernels<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Compute>, +) -> Result, MlirError> { + verify_compute_schema(module)?; + let role = module + .role() + .ok_or_else(|| schema_error("kernel resolution requires compute party role"))?; + let kernelized = context.new_module::(&module.name(), Some(role)); + let mut value_map = BTreeMap::new(); + let mut kernels = BTreeMap::new(); + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "compute.params" => { + let attrs = copy_attrs(op, &["field", "pcs", "transcript"])?; + let symbol = string_attr(op, "sym_name")?; + context.append_op_with_owned_attrs( + &kernelized, + "compute.params", + Some(&symbol), + &attrs, + )?; + } + "compute.function" => { + let attrs = copy_attrs(op, &["source"])?; + let symbol = string_attr(op, "sym_name")?; + context.append_op_with_owned_attrs( + &kernelized, + "compute.function", + Some(&symbol), + &attrs, + )?; + } + "compute.relation" => { + let attrs = copy_attrs( + op, + &["kind", "domain", "num_rounds", "degree", "output_count"], + )?; + let symbol = string_attr(op, "sym_name")?; + context.append_op_with_owned_attrs( + &kernelized, + "compute.relation", + Some(&symbol), + &attrs, + )?; + } + "compute.transcript_init" => { + let attrs = copy_attrs(op, &["scheme"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.transcript_init", + Some(&symbol), + &attrs, + &[], + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.transcript_absorb_bytes" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["label", "payload"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.transcript_absorb_bytes", + Some(&symbol), + &attrs, + &operands, + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.transcript_squeeze" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["label", "kind", "count"])?; + let symbol = string_attr(op, "sym_name")?; + let result_types = transcript_squeeze_compute_result_types(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.transcript_squeeze", + Some(&symbol), + &attrs, + &operands, + &result_types, + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "compute.opening_input" => { + let attrs = copy_attrs( + op, + &[ + "source_stage", + "source_claim", + "oracle", + "domain", + "point_arity", + "claim_kind", + ], + )?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.opening_input", + Some(&symbol), + &attrs, + &[], + &[ + "!compute.point", + "!compute.field_value", + "!compute.opening_claim_type", + ], + )?; + for index in 0..3 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "compute.point_slice" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["source", "offset", "length"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.point_slice", + Some(&symbol), + &attrs, + &operands, + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.point_zero" => { + let attrs = copy_attrs(op, &["field", "arity"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.point_zero", + Some(&symbol), + &attrs, + &[], + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.point_concat" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["layout", "arity"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.point_concat", + Some(&symbol), + &attrs, + &operands, + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.field_const" => { + let attrs = copy_attrs(op, &["field", "value"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.field_const", + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.field_zero" | "compute.field_one" => { + let attrs = copy_attrs(op, &["field"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + &operation_name(op), + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.field_add" + | "compute.field_sub" + | "compute.field_mul" + | "compute.field_neg" + | "compute.field_pow" + | "compute.poly_lagrange_basis_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = field_compute_attrs(op)?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + &operation_name(op), + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_claim" => { + let relation = symbol_attr(op, "relation")?; + let kernel = ensure_kernel(context, &kernelized, &mut kernels, &relation)?; + let operands = lowered_operands(op, &value_map, 0)?; + let mut attrs = + copy_attrs(op, &["stage", "domain", "num_rounds", "degree", "claim"])?; + attrs.push(("kernel".to_owned(), symbol_ref(&kernel))); + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.sumcheck_kernel_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_verify_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs( + op, + &[ + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.sumcheck_verify_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_batch" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.sumcheck_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_driver" => { + let relation = symbol_attr(op, "relation")?; + let kernel = ensure_kernel(context, &kernelized, &mut kernels, &relation)?; + let operands = lowered_operands(op, &value_map, 0)?; + let mut attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + attrs.push(("kernel".to_owned(), symbol_ref(&kernel))); + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.sumcheck_kernel_driver", + Some(&symbol), + &attrs, + &operands, + &[ + "!compute.transcript_state", + "!compute.point", + "!compute.sumcheck_result_type", + "!compute.sumcheck_proof_type", + ], + )?; + for index in 0..4 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "compute.sumcheck_verify" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.sumcheck_verify", + Some(&symbol), + &attrs, + &operands, + &[ + "!compute.transcript_state", + "!compute.point", + "!compute.sumcheck_result_type", + "!compute.sumcheck_proof_type", + ], + )?; + for index in 0..4 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "compute.sumcheck_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["source", "name", "index", "oracle"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.sumcheck_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.sumcheck_instance_result" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs( + op, + &[ + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.sumcheck_instance_result", + Some(&symbol), + &attrs, + &operands, + &["!compute.point", "!compute.sumcheck_result_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "compute.opening_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["oracle", "domain", "point_arity", "claim_kind"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.opening_claim_equal" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["mode"])?; + let symbol = string_attr(op, "sym_name")?; + let _operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.opening_claim_equal", + Some(&symbol), + &attrs, + &operands, + &[], + )?; + } + "compute.opening_batch" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs( + op, + &["stage", "proof_slot", "policy", "count", "ordered_claims"], + )?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.pcs_opening_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["oracle", "family", "domain", "point_arity"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.pcs_opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.pcs_opening_batch" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["proof_slot", "policy", "count", "ordered_claims"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + "compute.pcs_opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "compute.pcs_batch_open" | "compute.pcs_batch_verify" => { + let operands = lowered_operands(op, &value_map, 0)?; + let attrs = copy_attrs(op, &["pcs", "proof_slot", "transcript_label"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &kernelized, + &operation_name(op), + Some(&symbol), + &attrs, + &operands, + &["!compute.transcript_state", "!compute.opening_proof_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + _ => {} + } + } + + verify_module(&kernelized)?; + verify_compute_schema(&kernelized)?; + Ok(kernelized) +} + +fn append_stage1_virtual_oracles<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + context.append_op( + module, + "poly.domain", + Some("jolt.stage1_uniskip_domain"), + &[("field", "@bn254_fr"), ("log_size", "1 : i64")], + )?; + append_virtual_oracle( + context, + module, + "UnivariateSkip", + "jolt.stage1_uniskip_domain", + )?; + for oracle in R1CS_INPUT_ORACLES { + append_virtual_oracle(context, module, oracle, "jolt.trace_domain")?; + } + context.append_op( + module, + "piop.oracle_family", + Some("jolt.stage1_r1cs_virtuals"), + &[ + ("ordered_oracles", &symbol_array_attr(&R1CS_INPUT_ORACLES)), + ("count", &int_attr(params.num_r1cs_inputs)), + ("domain", "@jolt.trace_domain"), + ("visibility", r#""virtual""#), + ], + ) +} + +fn append_virtual_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + domain: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", &format!("@{domain}")), + ("commit_domain", &format!("@{domain}")), + ("visibility", r#""virtual""#), + ("layout", r#""virtual""#), + ], + ) +} + +fn append_stage1_relations<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.relation", + Some("jolt.stage1.outer.uniskip"), + &[ + ("kind", r#""sumcheck""#), + ("domain", "@jolt.stage1_uniskip_domain"), + ("num_rounds", "1 : i64"), + ("degree", &int_attr(OUTER_UNISKIP_FIRST_ROUND_DEGREE_BOUND)), + ("output_count", "1 : i64"), + ], + )?; + context.append_op( + module, + "piop.relation", + Some("jolt.stage1.outer.remaining"), + &[ + ("kind", r#""sumcheck""#), + ("domain", "@jolt.trace_domain"), + ("num_rounds", &int_attr(params.log_t + 1)), + ("degree", "3 : i64"), + ("output_count", &int_attr(R1CS_INPUT_ORACLES.len())), + ], + ) +} + +fn append_field_zero<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.zero", + Some(symbol), + &[("field", "@bn254_fr")], + &[], + &["!field.scalar"], + )?; + first_result(op, "field.zero") +} + +fn append_uniskip_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + zero_claim: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let claim = context.append_typed_op( + module, + "piop.sumcheck_claim", + Some("stage1.uniskip.input"), + &[ + ("stage", "@stage1"), + ("domain", "@jolt.stage1_uniskip_domain"), + ("num_rounds", "1 : i64"), + ("degree", &int_attr(OUTER_UNISKIP_FIRST_ROUND_DEGREE_BOUND)), + ("claim", "@stage1.zero"), + ("relation", "@jolt.stage1.outer.uniskip"), + ], + &[zero_claim], + &["!piop.sumcheck_claim_type"], + )?; + let claim = first_result(claim, "piop.sumcheck_claim")?; + let batch = context.append_typed_op( + module, + "piop.sumcheck_batch", + Some("stage1.uniskip.batch"), + &[ + ("stage", "@stage1"), + ("proof_slot", "@stage1.uni_skip_first_round"), + ("policy", r#""single_instance""#), + ("count", "1 : i64"), + ("ordered_claims", "[@stage1.uniskip.input]"), + ("claim_label", r#""uniskip_claim""#), + ("round_label", r#""uniskip_poly""#), + ("round_schedule", "[1]"), + ], + &[stage, claim], + &["!piop.sumcheck_batch_type"], + )?; + let batch = first_result(batch, "piop.sumcheck_batch")?; + let sumcheck = context.append_typed_op( + module, + "piop.sumcheck", + Some("stage1.uniskip.sumcheck"), + &[ + ("stage", "@stage1"), + ("proof_slot", "@stage1.uni_skip_first_round"), + ("relation", "@jolt.stage1.outer.uniskip"), + ("policy", r#""univariate_skip""#), + ("round_schedule", "[1]"), + ("claim_label", r#""uniskip_claim""#), + ("round_label", r#""uniskip_poly""#), + ("num_rounds", "1 : i64"), + ("degree", &int_attr(OUTER_UNISKIP_FIRST_ROUND_DEGREE_BOUND)), + ], + &[state, batch], + &[ + "!transcript.state_type", + "!poly.point", + "!piop.sumcheck_result_type", + "!piop.sumcheck_proof_type", + ], + )?; + let state = result(sumcheck, 0, "piop.sumcheck")?; + let point = result(sumcheck, 1, "piop.sumcheck")?; + let result_value = result(sumcheck, 2, "piop.sumcheck")?; + let (point, result_value) = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage1.uniskip.instance", + source: "stage1.uniskip.sumcheck", + claim: "stage1.uniskip.input", + relation: "jolt.stage1.outer.uniskip", + index: 0, + point_arity: 1, + num_rounds: 1, + round_offset: 0, + point_order: "as_is", + degree: OUTER_UNISKIP_FIRST_ROUND_DEGREE_BOUND, + }, + point, + result_value, + )?; + let eval = append_sumcheck_eval( + context, + module, + "stage1.uniskip.eval", + "stage1.uniskip.sumcheck", + "UnivariateSkip", + 0, + result_value, + )?; + let opening = append_piop_opening_claim( + context, + module, + point, + eval, + OpeningClaimSpec { + symbol: "stage1.uniskip.opening", + oracle: "UnivariateSkip", + domain: "jolt.stage1_uniskip_domain", + point_arity: 1, + }, + )?; + let _ = params; + Ok((state, opening, eval)) +} + +fn append_remaining_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + input_claim: Value<'c, 'a>, + uniskip_opening: Value<'c, 'a>, +) -> Result, MlirError> { + let num_rounds = params.log_t + 1; + let claim = context.append_typed_op( + module, + "piop.sumcheck_claim", + Some("stage1.outer_remaining.input"), + &[ + ("stage", "@stage1"), + ("domain", "@jolt.trace_domain"), + ("num_rounds", &int_attr(num_rounds)), + ("degree", "3 : i64"), + ("claim", "@stage1.uniskip.eval"), + ("relation", "@jolt.stage1.outer.remaining"), + ], + &[input_claim, uniskip_opening], + &["!piop.sumcheck_claim_type"], + )?; + let claim = first_result(claim, "piop.sumcheck_claim")?; + let batch = context.append_typed_op( + module, + "piop.sumcheck_batch", + Some("stage1.outer_remaining.batch"), + &[ + ("stage", "@stage1"), + ("proof_slot", "@stage1.sumcheck"), + ("policy", r#""jolt_core_front_loaded""#), + ("count", "1 : i64"), + ("ordered_claims", "[@stage1.outer_remaining.input]"), + ("claim_label", r#""sumcheck_claim""#), + ("round_label", r#""sumcheck_poly""#), + ("round_schedule", &format!("[{}]", num_rounds)), + ], + &[stage, claim], + &["!piop.sumcheck_batch_type"], + )?; + let batch = first_result(batch, "piop.sumcheck_batch")?; + let sumcheck = context.append_typed_op( + module, + "piop.sumcheck", + Some("stage1.outer_remaining.sumcheck"), + &[ + ("stage", "@stage1"), + ("proof_slot", "@stage1.sumcheck"), + ("relation", "@jolt.stage1.outer.remaining"), + ("policy", r#""jolt_core_front_loaded""#), + ("round_schedule", &format!("[{}]", num_rounds)), + ("claim_label", r#""sumcheck_claim""#), + ("round_label", r#""sumcheck_poly""#), + ("num_rounds", &int_attr(num_rounds)), + ("degree", "3 : i64"), + ], + &[state, batch], + &[ + "!transcript.state_type", + "!poly.point", + "!piop.sumcheck_result_type", + "!piop.sumcheck_proof_type", + ], + )?; + let state = result(sumcheck, 0, "piop.sumcheck")?; + let point = result(sumcheck, 1, "piop.sumcheck")?; + let result_value = result(sumcheck, 2, "piop.sumcheck")?; + let (point, result_value) = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage1.outer_remaining.instance", + source: "stage1.outer_remaining.sumcheck", + claim: "stage1.outer_remaining.input", + relation: "jolt.stage1.outer.remaining", + index: 0, + point_arity: params.log_t, + num_rounds, + round_offset: 1, + point_order: "reverse", + degree: 3, + }, + point, + result_value, + )?; + let mut claims = Vec::with_capacity(R1CS_INPUT_ORACLES.len()); + for (index, oracle) in R1CS_INPUT_ORACLES.iter().enumerate() { + let eval = append_sumcheck_eval( + context, + module, + &format!("stage1.outer_remaining.eval.{oracle}"), + "stage1.outer_remaining.sumcheck", + oracle, + index, + result_value, + )?; + claims.push(append_piop_opening_claim( + context, + module, + point, + eval, + OpeningClaimSpec { + symbol: &format!("stage1.outer_remaining.opening.{oracle}"), + oracle, + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?); + } + let _batch = context.append_typed_op( + module, + "piop.opening_batch", + Some("stage1.outer_remaining.openings"), + &[ + ("stage", "@stage1"), + ("proof_slot", "@stage1.virtual_openings"), + ("policy", r#""jolt_r1cs_input_order""#), + ("count", &int_attr(R1CS_INPUT_ORACLES.len())), + ("ordered_claims", &opening_claim_attr()), + ], + &claims, + &["!piop.opening_batch_type"], + )?; + Ok(state) +} + +fn append_sumcheck_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + oracle: &str, + index: usize, + result_value: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_eval", + Some(symbol), + &[ + ("source", &format!("@{source}")), + ("name", &format!("@{symbol}")), + ("index", &int_attr(index)), + ("oracle", &format!("@{oracle}")), + ], + &[result_value], + &["!field.scalar"], + )?; + first_result(op, "piop.sumcheck_eval") +} + +fn append_sumcheck_instance_result<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckInstanceResultSpec<'_>, + point: Value<'c, 'a>, + result_value: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_instance_result", + Some(spec.symbol), + &[ + ("source", &format!("@{}", spec.source)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ("index", &int_attr(spec.index)), + ("point_arity", &int_attr(spec.point_arity)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("round_offset", &int_attr(spec.round_offset)), + ("point_order", &format!("\"{}\"", spec.point_order)), + ("degree", &int_attr(spec.degree)), + ], + &[point, result_value], + &["!poly.point", "!piop.sumcheck_result_type"], + )?; + Ok(( + result(op, 0, "piop.sumcheck_instance_result")?, + result(op, 1, "piop.sumcheck_instance_result")?, + )) +} + +fn append_piop_opening_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + spec: OpeningClaimSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_claim", + Some(spec.symbol), + &[ + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", r#""virtual""#), + ], + &[point, eval], + &["!piop.opening_claim_type"], + )?; + first_result(op, "piop.opening_claim") +} + +struct OpeningClaimSpec<'a> { + symbol: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, +} + +struct SumcheckInstanceResultSpec<'a> { + symbol: &'a str, + source: &'a str, + claim: &'a str, + relation: &'a str, + index: usize, + point_arity: usize, + num_rounds: usize, + round_offset: usize, + point_order: &'a str, + degree: usize, +} + +fn first_result<'c, 'a>( + operation: OperationRef<'c, 'a>, + operation_name: &str, +) -> Result, MlirError> { + result(operation, 0, operation_name) +} + +fn result<'c, 'a>( + operation: OperationRef<'c, 'a>, + index: usize, + operation_name: &str, +) -> Result, MlirError> { + operation + .result(index) + .map(Into::into) + .map_err(|_| schema_error(format!("{operation_name} requires result {index}"))) +} + +#[derive(Clone, Debug)] +struct StageParamsAst { + field: String, + pcs: String, + transcript: String, +} + +fn stage_params(module: &BoltModule<'_, Party>) -> Result { + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + if operation_name(op) == "protocol.params" { + return Ok(StageParamsAst { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + } + Err(schema_error("stage1 lowering requires protocol.params")) +} + +fn operation_result_key_at( + operation: OperationRef<'_, '_>, + index: usize, +) -> Result { + let result = operation.result(index).map_err(|_| { + schema_error(format!( + "{} requires result {index}", + operation_name(operation) + )) + })?; + result_key(result.owner(), result.result_number()) +} + +fn result_key(operation: OperationRef<'_, '_>, result_number: usize) -> Result { + Ok(format!( + "{}#{result_number}", + string_attr(operation, "sym_name")? + )) +} + +fn operand_key(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + schema_error(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + schema_error(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + result_key(owner.owner(), owner.result_number()).map_err(|_| { + schema_error(format!( + "{} operand {index} owner missing sym_name", + operation_name(operation) + )) + }) +} + +fn lowered_operands<'c, 'a>( + operation: OperationRef<'_, '_>, + value_map: &BTreeMap>, + start_index: usize, +) -> Result>, MlirError> { + (start_index..operation.operand_count()) + .map(|index| { + let key = operand_key(operation, index)?; + value_map.get(&key).copied().ok_or_else(|| { + schema_error(format!( + "{} operand {index} was not lowered", + operation_name(operation) + )) + }) + }) + .collect() +} + +fn insert_result_mapping<'c, 'a>( + value_map: &mut BTreeMap>, + source: OperationRef<'_, '_>, + target: OperationRef<'c, 'a>, + source_index: usize, + target_index: usize, +) -> Result<(), MlirError> { + let key = operation_result_key_at(source, source_index)?; + let value = target.result(target_index).map(Into::into).map_err(|_| { + schema_error(format!( + "{} requires result {target_index}", + operation_name(target) + )) + })?; + let inserted = value_map.insert(key, value); + debug_assert!(inserted.is_none()); + Ok(()) +} + +fn ensure_kernel<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Compute>, + kernels: &mut BTreeMap, + relation: &str, +) -> Result { + if let Some(kernel) = kernels.get(relation) { + return Ok(kernel.clone()); + } + let spec = kernel_spec(relation)?; + context.append_op_with_owned_attrs( + module, + "compute.kernel", + Some(spec.symbol), + &[ + ("relation".to_owned(), symbol_ref(relation)), + ("kind".to_owned(), string_literal(spec.kind)), + ("backend".to_owned(), string_literal("cpu")), + ("abi".to_owned(), string_literal(spec.abi)), + ], + )?; + let inserted = kernels.insert(relation.to_owned(), spec.symbol.to_owned()); + debug_assert!(inserted.is_none()); + Ok(spec.symbol.to_owned()) +} + +fn kernel_spec(relation: &str) -> Result { + match relation { + "jolt.stage1.outer.uniskip" => Ok(KernelSpec { + symbol: "jolt.cpu.stage1.outer.uniskip", + kind: "sumcheck", + abi: "jolt_stage1_outer_uniskip", + }), + "jolt.stage1.outer.remaining" => Ok(KernelSpec { + symbol: "jolt.cpu.stage1.outer.remaining", + kind: "sumcheck", + abi: "jolt_stage1_outer_remaining", + }), + "jolt.stage2.product_virtual.uniskip" => Ok(KernelSpec { + symbol: "jolt.cpu.stage2.product_virtual.uniskip", + kind: "sumcheck", + abi: "jolt_stage2_product_virtual_uniskip", + }), + "jolt.stage2.ram.read_write" => Ok(KernelSpec { + symbol: "jolt.cpu.stage2.ram.read_write", + kind: "sumcheck", + abi: "jolt_stage2_ram_read_write", + }), + "jolt.stage2.product_virtual.remainder" => Ok(KernelSpec { + symbol: "jolt.cpu.stage2.product_virtual.remainder", + kind: "sumcheck", + abi: "jolt_stage2_product_virtual_remainder", + }), + "jolt.stage2.instruction_lookup.claim_reduction" => Ok(KernelSpec { + symbol: "jolt.cpu.stage2.instruction_lookup.claim_reduction", + kind: "sumcheck", + abi: "jolt_stage2_instruction_lookup_claim_reduction", + }), + "jolt.stage2.ram.raf_evaluation" => Ok(KernelSpec { + symbol: "jolt.cpu.stage2.ram.raf_evaluation", + kind: "sumcheck", + abi: "jolt_stage2_ram_raf_evaluation", + }), + "jolt.stage2.ram.output_check" => Ok(KernelSpec { + symbol: "jolt.cpu.stage2.ram.output_check", + kind: "sumcheck", + abi: "jolt_stage2_ram_output_check", + }), + "jolt.stage2.batched" => Ok(KernelSpec { + symbol: "jolt.cpu.stage2.batched", + kind: "sumcheck", + abi: "jolt_stage2_batched", + }), + "jolt.stage3.spartan_shift" => Ok(KernelSpec { + symbol: "jolt.cpu.stage3.spartan_shift", + kind: "sumcheck", + abi: "jolt_stage3_spartan_shift", + }), + "jolt.stage3.instruction_input" => Ok(KernelSpec { + symbol: "jolt.cpu.stage3.instruction_input", + kind: "sumcheck", + abi: "jolt_stage3_instruction_input", + }), + "jolt.stage3.registers_claim_reduction" => Ok(KernelSpec { + symbol: "jolt.cpu.stage3.registers_claim_reduction", + kind: "sumcheck", + abi: "jolt_stage3_registers_claim_reduction", + }), + "jolt.stage3.batched" => Ok(KernelSpec { + symbol: "jolt.cpu.stage3.batched", + kind: "sumcheck", + abi: "jolt_stage3_batched", + }), + "jolt.stage4.registers_read_write" => Ok(KernelSpec { + symbol: "jolt.cpu.stage4.registers_read_write", + kind: "sumcheck", + abi: "jolt_stage4_registers_read_write", + }), + "jolt.stage4.ram_val_check" => Ok(KernelSpec { + symbol: "jolt.cpu.stage4.ram_val_check", + kind: "sumcheck", + abi: "jolt_stage4_ram_val_check", + }), + "jolt.stage4.batched" => Ok(KernelSpec { + symbol: "jolt.cpu.stage4.batched", + kind: "sumcheck", + abi: "jolt_stage4_batched", + }), + "jolt.stage5.instruction_read_raf" => Ok(KernelSpec { + symbol: "jolt.cpu.stage5.instruction_read_raf", + kind: "sumcheck", + abi: "jolt_stage5_instruction_read_raf", + }), + "jolt.stage5.ram_ra_claim_reduction" => Ok(KernelSpec { + symbol: "jolt.cpu.stage5.ram_ra_claim_reduction", + kind: "sumcheck", + abi: "jolt_stage5_ram_ra_claim_reduction", + }), + "jolt.stage5.registers_val_evaluation" => Ok(KernelSpec { + symbol: "jolt.cpu.stage5.registers_val_evaluation", + kind: "sumcheck", + abi: "jolt_stage5_registers_val_evaluation", + }), + "jolt.stage5.batched" => Ok(KernelSpec { + symbol: "jolt.cpu.stage5.batched", + kind: "sumcheck", + abi: "jolt_stage5_batched", + }), + "jolt.stage6.bytecode_read_raf" => Ok(KernelSpec { + symbol: "jolt.cpu.stage6.bytecode_read_raf", + kind: "sumcheck", + abi: "jolt_stage6_bytecode_read_raf", + }), + "jolt.stage6.booleanity" => Ok(KernelSpec { + symbol: "jolt.cpu.stage6.booleanity", + kind: "sumcheck", + abi: "jolt_stage6_booleanity", + }), + "jolt.stage6.hamming_booleanity" => Ok(KernelSpec { + symbol: "jolt.cpu.stage6.hamming_booleanity", + kind: "sumcheck", + abi: "jolt_stage6_hamming_booleanity", + }), + "jolt.stage6.ram_ra_virtual" => Ok(KernelSpec { + symbol: "jolt.cpu.stage6.ram_ra_virtual", + kind: "sumcheck", + abi: "jolt_stage6_ram_ra_virtual", + }), + "jolt.stage6.instruction_ra_virtual" => Ok(KernelSpec { + symbol: "jolt.cpu.stage6.instruction_ra_virtual", + kind: "sumcheck", + abi: "jolt_stage6_instruction_ra_virtual", + }), + "jolt.stage6.inc_claim_reduction" => Ok(KernelSpec { + symbol: "jolt.cpu.stage6.inc_claim_reduction", + kind: "sumcheck", + abi: "jolt_stage6_inc_claim_reduction", + }), + "jolt.stage6.batched" => Ok(KernelSpec { + symbol: "jolt.cpu.stage6.batched", + kind: "sumcheck", + abi: "jolt_stage6_batched", + }), + "jolt.stage7.hamming_weight_claim_reduction" => Ok(KernelSpec { + symbol: "jolt.cpu.stage7.hamming_weight_claim_reduction", + kind: "sumcheck", + abi: "jolt_stage7_hamming_weight_claim_reduction", + }), + "jolt.stage7.batched" => Ok(KernelSpec { + symbol: "jolt.cpu.stage7.batched", + kind: "sumcheck", + abi: "jolt_stage7_batched", + }), + _ => Err(schema_error(format!( + "unsupported compute relation @{relation}" + ))), + } +} + +struct KernelSpec { + symbol: &'static str, + kind: &'static str, + abi: &'static str, +} + +fn int_attr(value: usize) -> String { + format!("{value} : i64") +} + +fn symbol_array_attr(values: &[&str]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn opening_claim_attr() -> String { + let values = R1CS_INPUT_ORACLES + .iter() + .map(|oracle| format!("@stage1.outer_remaining.opening.{oracle}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn schema_error(message: impl Into) -> MlirError { + let error = SchemaError::new(message); + error.into() +} + +fn symbol_ref(value: &str) -> String { + format!("@{value}") +} + +fn string_literal(value: &str) -> String { + format!("{value:?}") +} diff --git a/crates/bolt/src/protocols/jolt/phases/stage2.rs b/crates/bolt/src/protocols/jolt/phases/stage2.rs new file mode 100644 index 0000000000..2c67e00a81 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/stage2.rs @@ -0,0 +1,2082 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::OperationRef; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::Value; + +use crate::ir::{BoltModule, Compute, Party, Protocol, Role}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{ + operation_name, symbol_attr, verify_compute_schema, verify_party_schema, + verify_protocol_schema, SchemaError, +}; + +use super::super::oracles; +use super::super::params::JoltProtocolParams; +use super::lowering::{ + copy_attrs, field_lowering_attrs as field_compute_attrs, string_attr, + transcript_squeeze_compute_result_types, transcript_squeeze_protocol_result_type, +}; + +const PRODUCT_UNISKIP_DEGREE_BOUND: usize = 6; +const PRODUCT_UNISKIP_DOMAIN_START: isize = -1; +const PRODUCT_UNISKIP_DOMAIN_SIZE: usize = 3; +const RAM_RW_DEGREE: usize = 3; +const PRODUCT_REMAINDER_DEGREE: usize = 3; +const INSTRUCTION_CLAIM_REDUCTION_DEGREE: usize = 2; +const RAM_RAF_DEGREE: usize = 2; +const RAM_OUTPUT_DEGREE: usize = 3; + +const STAGE1_PRODUCT_OPENINGS: [&str; 3] = ["Product", "ShouldBranch", "ShouldJump"]; +const STAGE2_RAM_RW_INPUTS: [&str; 2] = ["RamReadValue", "RamWriteValue"]; +const STAGE2_INSTRUCTION_INPUTS: [&str; 5] = [ + "LookupOutput", + "LeftLookupOperand", + "RightLookupOperand", + "LeftInstructionInput", + "RightInstructionInput", +]; +const PRODUCT_REMAINDER_OUTPUTS: [&str; 8] = [ + "LeftInstructionInput", + "RightInstructionInput", + "OpFlagJump", + "OpFlagWriteLookupOutputToRD", + "LookupOutput", + "InstructionFlagBranch", + "NextIsNoop", + "OpFlagVirtualInstruction", +]; + +pub fn build_stage2_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.stage2", None); + oracles::append_foundation_ops(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.stage2"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + append_stage2_domains(context, &module, params)?; + append_stage2_oracles(context, &module)?; + append_stage2_relations(context, &module, params)?; + let inputs = append_stage2_opening_inputs(context, &module, params)?; + + let fs = context.append_typed_op( + &module, + "transcript.state", + Some("fs_after_stage1"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let state = first_result(fs, "transcript.state")?; + let stage = context.append_typed_op( + &module, + "piop.stage", + Some("stage2"), + &[ + ("name", r#""product_virtual_and_ram""#), + ("order", "2 : i64"), + ("roles", r#"["prover", "verifier"]"#), + ], + &[], + &["!piop.stage_type"], + )?; + let stage = first_result(stage, "piop.stage")?; + let (state, tau_high) = append_transcript_squeeze( + context, + &module, + state, + "stage2.product_virtual.tau_high", + "product_virtual_tau_high", + "challenge_scalar", + 1, + )?; + let (state, uniskip) = + append_product_uniskip(context, &module, params, state, stage, &inputs, tau_high)?; + let (state, ram_read_write_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage2.ram_read_write.gamma", + "ram_read_write_gamma", + "challenge_scalar", + 1, + )?; + let (state, instruction_lookup_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage2.instruction_lookup.gamma", + "instruction_lookup_gamma", + "challenge_scalar", + 1, + )?; + let (state, _ram_output_address) = append_transcript_squeeze( + context, + &module, + state, + "stage2.ram_output.r_address", + "ram_output_r_address", + "challenge_vector", + params.log_k_ram, + )?; + let _state = append_stage2_batched_sumcheck( + context, + &module, + params, + Stage2BatchedSumcheckInputs { + state, + stage, + openings: &inputs, + uniskip, + ram_read_write_gamma, + instruction_lookup_gamma, + }, + )?; + + verify_module(&module)?; + verify_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_stage2_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + verify_party_schema(module)?; + let role = module + .role() + .ok_or_else(|| schema_error("stage2 lowering requires party role"))?; + let params = stage_params(module)?; + let compute = context.new_module::(&module.name(), Some(role.clone())); + context.append_op_with_owned_attrs( + &compute, + "compute.params", + Some("jolt.compute_params"), + &[ + ("field".to_owned(), symbol_ref(¶ms.field)), + ("pcs".to_owned(), symbol_ref(¶ms.pcs)), + ("transcript".to_owned(), symbol_ref(¶ms.transcript)), + ], + )?; + context.append_op( + &compute, + "compute.function", + Some("jolt.stage2"), + &[("source", "@jolt.stage2")], + )?; + + let mut value_map = BTreeMap::new(); + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "piop.relation" => { + let attrs = copy_attrs( + op, + &["kind", "domain", "num_rounds", "degree", "output_count"], + )?; + let symbol = string_attr(op, "sym_name")?; + context.append_op_with_owned_attrs( + &compute, + "compute.relation", + Some(&symbol), + &attrs, + )?; + } + "transcript.state" => { + let attrs = copy_attrs(op, &["scheme"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_init", + Some(&symbol), + &attrs, + &[], + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "transcript.absorb_bytes" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "payload"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_absorb_bytes", + Some(&symbol), + &attrs, + &operands, + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "transcript.squeeze" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "kind", "count"])?; + let result_types = transcript_squeeze_compute_result_types(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_squeeze", + Some(&symbol), + &attrs, + &operands, + &result_types, + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "field.const" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field", "value"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.field_const", + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "field.zero" | "field.one" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + &format!("compute.{}", operation_name(op).replace('.', "_")), + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "field.add" | "field.sub" | "field.mul" | "field.neg" | "field.pow" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = field_compute_attrs(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + &format!("compute.{}", operation_name(op).replace('.', "_")), + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "poly.lagrange_basis_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["domain_start", "domain_size", "index"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.poly_lagrange_basis_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.opening_input" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source_stage", + "source_claim", + "oracle", + "domain", + "point_arity", + "claim_kind", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_input", + Some(&symbol), + &attrs, + &[], + &[ + "!compute.point", + "!compute.field_value", + "!compute.opening_claim_type", + ], + )?; + for index in 0..3 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "poly.point_slice" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "offset", "length"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.point_slice", + Some(&symbol), + &attrs, + &operands, + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "poly.point_concat" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["layout", "arity"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.point_concat", + Some(&symbol), + &attrs, + &operands, + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + let target_op = match &role { + Role::Prover => "compute.sumcheck_claim", + Role::Verifier => "compute.sumcheck_verify_claim", + }; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_batch" => { + let operands = lowered_operands(op, &value_map, 1)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + let target_op = match &role { + Role::Prover => "compute.sumcheck_driver", + Role::Verifier => "compute.sumcheck_verify", + }; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &[ + "!compute.transcript_state", + "!compute.point", + "!compute.sumcheck_result_type", + "!compute.sumcheck_proof_type", + ], + )?; + for index in 0..4 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "piop.sumcheck_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "name", "index", "oracle"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_instance_result" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_instance_result", + Some(&symbol), + &attrs, + &operands, + &["!compute.point", "!compute.sumcheck_result_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "piop.opening_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["oracle", "domain", "point_arity", "claim_kind"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.opening_claim_equal" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["mode"])?; + let _operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_claim_equal", + Some(&symbol), + &attrs, + &operands, + &[], + )?; + } + "piop.opening_batch" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &["stage", "proof_slot", "policy", "count", "ordered_claims"], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + _ => {} + } + } + + verify_module(&compute)?; + verify_compute_schema(&compute)?; + Ok(compute) +} + +fn append_stage2_domains<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + context.append_op( + module, + "poly.domain", + Some("jolt.stage2_uniskip_domain"), + &[("field", "@bn254_fr"), ("log_size", "1 : i64")], + )?; + context.append_op( + module, + "poly.domain", + Some("jolt.stage2_ram_rw_domain"), + &[ + ("field", "@bn254_fr"), + ("log_size", &int_attr(stage2_max_rounds(params))), + ], + )?; + context.append_op( + module, + "poly.domain", + Some("jolt.ram_address_domain"), + &[ + ("field", "@bn254_fr"), + ("log_size", &int_attr(params.log_k_ram)), + ], + ) +} + +fn append_stage2_oracles<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, +) -> Result<(), MlirError> { + let mut trace_oracles = BTreeSet::new(); + trace_oracles.extend(STAGE1_PRODUCT_OPENINGS); + trace_oracles.extend(STAGE2_RAM_RW_INPUTS); + trace_oracles.extend(STAGE2_INSTRUCTION_INPUTS); + trace_oracles.extend(PRODUCT_REMAINDER_OUTPUTS); + let _ = trace_oracles.insert("RamAddress"); + for oracle in trace_oracles { + append_virtual_oracle(context, module, oracle, "jolt.trace_domain")?; + } + append_virtual_oracle( + context, + module, + "UnivariateSkip", + "jolt.stage2_uniskip_domain", + )?; + append_virtual_oracle(context, module, "RamVal", "jolt.stage2_ram_rw_domain")?; + append_virtual_oracle(context, module, "RamRa", "jolt.stage2_ram_rw_domain")?; + append_virtual_oracle(context, module, "RamValFinal", "jolt.ram_address_domain")?; + context.append_op( + module, + "piop.oracle", + Some("RamInc"), + &[ + ("field", "@bn254_fr"), + ("domain", "@jolt.trace_domain"), + ("commit_domain", "@jolt.main_witness_commit_domain"), + ("visibility", r#""committed""#), + ("layout", r#""dense_trace""#), + ], + ) +} + +fn append_virtual_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + domain: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", &format!("@{domain}")), + ("commit_domain", &format!("@{domain}")), + ("visibility", r#""virtual""#), + ("layout", r#""virtual""#), + ], + ) +} + +fn append_stage2_relations<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + let max_rounds = stage2_max_rounds(params); + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage2.product_virtual.uniskip", + kind: "sumcheck", + domain: "jolt.stage2_uniskip_domain", + num_rounds: 1, + degree: PRODUCT_UNISKIP_DEGREE_BOUND, + output_count: 1, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage2.ram.read_write", + kind: "sumcheck", + domain: "jolt.stage2_ram_rw_domain", + num_rounds: max_rounds, + degree: RAM_RW_DEGREE, + output_count: 3, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage2.product_virtual.remainder", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: PRODUCT_REMAINDER_DEGREE, + output_count: PRODUCT_REMAINDER_OUTPUTS.len(), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage2.instruction_lookup.claim_reduction", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: INSTRUCTION_CLAIM_REDUCTION_DEGREE, + output_count: STAGE2_INSTRUCTION_INPUTS.len(), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage2.ram.raf_evaluation", + kind: "sumcheck", + domain: "jolt.ram_address_domain", + num_rounds: params.log_k_ram, + degree: RAM_RAF_DEGREE, + output_count: 1, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage2.ram.output_check", + kind: "sumcheck", + domain: "jolt.ram_address_domain", + num_rounds: params.log_k_ram, + degree: RAM_OUTPUT_DEGREE, + output_count: 1, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage2.batched", + kind: "batched_sumcheck", + domain: "jolt.stage2_ram_rw_domain", + num_rounds: max_rounds, + degree: RAM_RW_DEGREE, + output_count: 18, + }, + ) +} + +fn append_relation<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + spec: RelationSpec<'_>, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.relation", + Some(spec.symbol), + &[ + ("kind", &format!("\"{}\"", spec.kind)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("output_count", &int_attr(spec.output_count)), + ], + ) +} + +fn append_stage2_opening_inputs<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let product = append_stage1_opening_input(context, module, params, "Product")?; + let should_branch = append_stage1_opening_input(context, module, params, "ShouldBranch")?; + let should_jump = append_stage1_opening_input(context, module, params, "ShouldJump")?; + let ram_read_value = append_stage1_opening_input(context, module, params, "RamReadValue")?; + let ram_write_value = append_stage1_opening_input(context, module, params, "RamWriteValue")?; + let lookup_output = append_stage1_opening_input(context, module, params, "LookupOutput")?; + let left_lookup_operand = + append_stage1_opening_input(context, module, params, "LeftLookupOperand")?; + let right_lookup_operand = + append_stage1_opening_input(context, module, params, "RightLookupOperand")?; + let left_instruction_input = + append_stage1_opening_input(context, module, params, "LeftInstructionInput")?; + let right_instruction_input = + append_stage1_opening_input(context, module, params, "RightInstructionInput")?; + let ram_address = append_stage1_opening_input(context, module, params, "RamAddress")?; + + Ok(Stage2OpeningInputs { + product, + should_branch, + should_jump, + ram_read_value, + ram_write_value, + lookup_output, + left_lookup_operand, + right_lookup_operand, + left_instruction_input, + right_instruction_input, + ram_address, + }) +} + +fn append_stage1_opening_input<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + oracle: &str, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_input", + Some(&format!("stage2.input.stage1.{oracle}")), + &[ + ("source_stage", "@stage1"), + ( + "source_claim", + &format!("@stage1.outer_remaining.opening.{oracle}"), + ), + ("oracle", &format!("@{oracle}")), + ("domain", "@jolt.trace_domain"), + ("point_arity", &int_attr(params.log_t)), + ("claim_kind", r#""virtual""#), + ], + &[], + &["!poly.point", "!field.scalar", "!piop.opening_claim_type"], + )?; + Ok(Stage2OpeningInput { + point: result(op, 0, "piop.opening_input")?, + eval: result(op, 1, "piop.opening_input")?, + claim: result(op, 2, "piop.opening_input")?, + }) +} + +fn append_transcript_squeeze<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + symbol: &str, + label: &str, + kind: &str, + count: usize, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "transcript.squeeze", + Some(symbol), + &[ + ("label", &format!("\"{label}\"")), + ("kind", &format!("\"{kind}\"")), + ("count", &int_attr(count)), + ], + &[state], + &[ + "!transcript.state_type", + transcript_squeeze_protocol_result_type(kind)?, + ], + )?; + Ok(( + result(op, 0, "transcript.squeeze")?, + result(op, 1, "transcript.squeeze")?, + )) +} + +fn append_field_const<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + value: usize, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.const", + Some(symbol), + &[("field", "@bn254_fr"), ("value", &int_attr(value))], + &[], + &["!field.scalar"], + )?; + first_result(op, "field.const") +} + +fn append_field_binary<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + op_name: &str, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + op_name, + Some(symbol), + &[], + &[lhs, rhs], + &["!field.scalar"], + )?; + first_result(op, op_name) +} + +fn append_field_add<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.add", symbol, lhs, rhs) +} + +fn append_field_mul<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.mul", symbol, lhs, rhs) +} + +fn append_lagrange_basis_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + point: Value<'c, 'a>, + index: usize, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.lagrange_basis_eval", + Some(symbol), + &[ + ( + "domain_start", + &int_attr_signed(PRODUCT_UNISKIP_DOMAIN_START), + ), + ("domain_size", &int_attr(PRODUCT_UNISKIP_DOMAIN_SIZE)), + ("index", &int_attr(index)), + ], + &[point], + &["!field.scalar"], + )?; + first_result(op, "poly.lagrange_basis_eval") +} + +fn append_product_uniskip<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + _params: &JoltProtocolParams, + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + inputs: &Stage2OpeningInputs<'c, 'a>, + tau_high: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Stage2UniskipOutput<'c, 'a>), MlirError> { + let product_weight = append_lagrange_basis_eval( + context, + module, + "stage2.product_virtual.uniskip.weight.Product", + tau_high, + 0, + )?; + let branch_weight = append_lagrange_basis_eval( + context, + module, + "stage2.product_virtual.uniskip.weight.ShouldBranch", + tau_high, + 1, + )?; + let jump_weight = append_lagrange_basis_eval( + context, + module, + "stage2.product_virtual.uniskip.weight.ShouldJump", + tau_high, + 2, + )?; + let product_term = append_field_mul( + context, + module, + "stage2.product_virtual.uniskip.term.Product", + product_weight, + inputs.product.eval, + )?; + let branch_term = append_field_mul( + context, + module, + "stage2.product_virtual.uniskip.term.ShouldBranch", + branch_weight, + inputs.should_branch.eval, + )?; + let jump_term = append_field_mul( + context, + module, + "stage2.product_virtual.uniskip.term.ShouldJump", + jump_weight, + inputs.should_jump.eval, + )?; + let product_branch_sum = append_field_add( + context, + module, + "stage2.product_virtual.uniskip.partial.ProductShouldBranch", + product_term, + branch_term, + )?; + let input_claim = append_field_add( + context, + module, + "stage2.product_virtual.uniskip.claim_expr", + product_branch_sum, + jump_term, + )?; + let claim = append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage2.product_virtual.uniskip.input", + stage: "stage2", + domain: "jolt.stage2_uniskip_domain", + num_rounds: 1, + degree: PRODUCT_UNISKIP_DEGREE_BOUND, + claim: "stage2.product_virtual.weighted_stage1_outputs", + relation: "jolt.stage2.product_virtual.uniskip", + }, + input_claim, + &[ + inputs.product.claim, + inputs.should_branch.claim, + inputs.should_jump.claim, + ], + )?; + let batch = append_sumcheck_batch( + context, + module, + stage, + &[claim], + SumcheckBatchSpec { + symbol: "stage2.product_virtual.uniskip.batch", + stage: "stage2", + proof_slot: "stage2.product_virtual.uni_skip_first_round", + policy: "single_instance", + ordered_claims: &["stage2.product_virtual.uniskip.input"], + claim_label: "uniskip_claim", + round_label: "uniskip_poly", + round_schedule: "[1]".to_owned(), + }, + )?; + let (state, point, result_value) = append_sumcheck( + context, + module, + state, + batch, + SumcheckDriverSpec { + symbol: "stage2.product_virtual.uniskip.sumcheck", + stage: "stage2", + proof_slot: "stage2.product_virtual.uni_skip_first_round", + relation: "jolt.stage2.product_virtual.uniskip", + policy: "univariate_skip", + round_schedule: "[1]".to_owned(), + claim_label: "uniskip_claim", + round_label: "uniskip_poly", + num_rounds: 1, + degree: PRODUCT_UNISKIP_DEGREE_BOUND, + }, + )?; + let (point, result_value) = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage2.product_virtual.uniskip.instance", + source: "stage2.product_virtual.uniskip.sumcheck", + claim: "stage2.product_virtual.uniskip.input", + relation: "jolt.stage2.product_virtual.uniskip", + index: 0, + point_arity: 1, + num_rounds: 1, + round_offset: 0, + point_order: "as_is", + degree: PRODUCT_UNISKIP_DEGREE_BOUND, + }, + point, + result_value, + )?; + let eval = append_sumcheck_eval( + context, + module, + "stage2.product_virtual.uniskip.eval.UnivariateSkip", + "stage2.product_virtual.uniskip.sumcheck", + "UnivariateSkip", + 0, + result_value, + )?; + let opening = append_opening_claim( + context, + module, + point, + eval, + OpeningClaimSpec { + symbol: "stage2.product_virtual.uniskip.opening.UnivariateSkip", + oracle: "UnivariateSkip", + domain: "jolt.stage2_uniskip_domain", + point_arity: 1, + claim_kind: "virtual", + }, + )?; + Ok((state, Stage2UniskipOutput { opening, eval })) +} + +fn append_stage2_batched_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + spec: Stage2BatchedSumcheckInputs<'c, 'a, '_>, +) -> Result, MlirError> { + let inputs = spec.openings; + let uniskip = spec.uniskip; + let max_rounds = stage2_max_rounds(params); + let product_offset = max_rounds - params.log_t; + let ram_offset = params.log_t; + let ram_write_term = append_field_mul( + context, + module, + "stage2.ram_read_write.term.RamWriteValue", + spec.ram_read_write_gamma, + inputs.ram_write_value.eval, + )?; + let ram_read_write_claim = append_field_add( + context, + module, + "stage2.ram_read_write.claim_expr", + inputs.ram_read_value.eval, + ram_write_term, + )?; + let product_remainder_claim = uniskip.eval; + let gamma2 = append_field_mul( + context, + module, + "stage2.instruction_lookup.gamma2", + spec.instruction_lookup_gamma, + spec.instruction_lookup_gamma, + )?; + let gamma3 = append_field_mul( + context, + module, + "stage2.instruction_lookup.gamma3", + gamma2, + spec.instruction_lookup_gamma, + )?; + let gamma4 = append_field_mul( + context, + module, + "stage2.instruction_lookup.gamma4", + gamma2, + gamma2, + )?; + let left_lookup_term = append_field_mul( + context, + module, + "stage2.instruction_lookup.term.LeftLookupOperand", + spec.instruction_lookup_gamma, + inputs.left_lookup_operand.eval, + )?; + let right_lookup_term = append_field_mul( + context, + module, + "stage2.instruction_lookup.term.RightLookupOperand", + gamma2, + inputs.right_lookup_operand.eval, + )?; + let left_input_term = append_field_mul( + context, + module, + "stage2.instruction_lookup.term.LeftInstructionInput", + gamma3, + inputs.left_instruction_input.eval, + )?; + let right_input_term = append_field_mul( + context, + module, + "stage2.instruction_lookup.term.RightInstructionInput", + gamma4, + inputs.right_instruction_input.eval, + )?; + let instruction_sum_0 = append_field_add( + context, + module, + "stage2.instruction_lookup.partial.LookupOutputLeftOperand", + inputs.lookup_output.eval, + left_lookup_term, + )?; + let instruction_sum_1 = append_field_add( + context, + module, + "stage2.instruction_lookup.partial.RightOperand", + instruction_sum_0, + right_lookup_term, + )?; + let instruction_sum_2 = append_field_add( + context, + module, + "stage2.instruction_lookup.partial.LeftInstructionInput", + instruction_sum_1, + left_input_term, + )?; + let instruction_claim = append_field_add( + context, + module, + "stage2.instruction_lookup.claim_reduction.claim_expr", + instruction_sum_2, + right_input_term, + )?; + let ram_raf_claim = inputs.ram_address.eval; + let ram_output_claim = append_field_const(context, module, "stage2.ram_output.zero", 0)?; + let claims = [ + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage2.ram_read_write.input", + stage: "stage2", + domain: "jolt.stage2_ram_rw_domain", + num_rounds: max_rounds, + degree: RAM_RW_DEGREE, + claim: "stage2.ram_read_write.weighted_values", + relation: "jolt.stage2.ram.read_write", + }, + ram_read_write_claim, + &[inputs.ram_read_value.claim, inputs.ram_write_value.claim], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage2.product_virtual.remainder.input", + stage: "stage2", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: PRODUCT_REMAINDER_DEGREE, + claim: "stage2.product_virtual.uniskip.opening", + relation: "jolt.stage2.product_virtual.remainder", + }, + product_remainder_claim, + &[uniskip.opening], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage2.instruction_lookup.claim_reduction.input", + stage: "stage2", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: INSTRUCTION_CLAIM_REDUCTION_DEGREE, + claim: "stage2.instruction_lookup.weighted_operands", + relation: "jolt.stage2.instruction_lookup.claim_reduction", + }, + instruction_claim, + &[ + inputs.lookup_output.claim, + inputs.left_lookup_operand.claim, + inputs.right_lookup_operand.claim, + inputs.left_instruction_input.claim, + inputs.right_instruction_input.claim, + ], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage2.ram_raf.input", + stage: "stage2", + domain: "jolt.ram_address_domain", + num_rounds: params.log_k_ram, + degree: RAM_RAF_DEGREE, + claim: "stage2.ram_raf.ram_address", + relation: "jolt.stage2.ram.raf_evaluation", + }, + ram_raf_claim, + &[inputs.ram_address.claim], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage2.ram_output.input", + stage: "stage2", + domain: "jolt.ram_address_domain", + num_rounds: params.log_k_ram, + degree: RAM_OUTPUT_DEGREE, + claim: "zero", + relation: "jolt.stage2.ram.output_check", + }, + ram_output_claim, + &[], + )?, + ]; + let batch = append_sumcheck_batch( + context, + module, + spec.stage, + &claims, + SumcheckBatchSpec { + symbol: "stage2.batch", + stage: "stage2", + proof_slot: "stage2.sumcheck", + policy: "jolt_core_stage2_aligned", + ordered_claims: &[ + "stage2.ram_read_write.input", + "stage2.product_virtual.remainder.input", + "stage2.instruction_lookup.claim_reduction.input", + "stage2.ram_raf.input", + "stage2.ram_output.input", + ], + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + round_schedule: format!("[{}, {}]", params.log_t, params.log_k_ram), + }, + )?; + let (state, point, result_value) = append_sumcheck( + context, + module, + spec.state, + batch, + SumcheckDriverSpec { + symbol: "stage2.sumcheck", + stage: "stage2", + proof_slot: "stage2.sumcheck", + relation: "jolt.stage2.batched", + policy: "jolt_core_stage2_aligned", + round_schedule: format!("[{}, {}]", params.log_t, params.log_k_ram), + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + num_rounds: max_rounds, + degree: RAM_RW_DEGREE, + }, + )?; + let ram_rw = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage2.ram_read_write.instance", + source: "stage2.sumcheck", + claim: "stage2.ram_read_write.input", + relation: "jolt.stage2.ram.read_write", + index: 0, + point_arity: max_rounds, + num_rounds: max_rounds, + round_offset: 0, + point_order: "reverse", + degree: RAM_RW_DEGREE, + }, + point, + result_value, + )?; + let product = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage2.product_virtual.remainder.instance", + source: "stage2.sumcheck", + claim: "stage2.product_virtual.remainder.input", + relation: "jolt.stage2.product_virtual.remainder", + index: 1, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: product_offset, + point_order: "reverse", + degree: PRODUCT_REMAINDER_DEGREE, + }, + point, + result_value, + )?; + let instruction = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage2.instruction_lookup.claim_reduction.instance", + source: "stage2.sumcheck", + claim: "stage2.instruction_lookup.claim_reduction.input", + relation: "jolt.stage2.instruction_lookup.claim_reduction", + index: 2, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: product_offset, + point_order: "reverse", + degree: INSTRUCTION_CLAIM_REDUCTION_DEGREE, + }, + point, + result_value, + )?; + let ram_raf = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage2.ram_raf.instance", + source: "stage2.sumcheck", + claim: "stage2.ram_raf.input", + relation: "jolt.stage2.ram.raf_evaluation", + index: 3, + point_arity: params.log_k_ram, + num_rounds: params.log_k_ram, + round_offset: ram_offset, + point_order: "reverse", + degree: RAM_RAF_DEGREE, + }, + point, + result_value, + )?; + let ram_output = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage2.ram_output.instance", + source: "stage2.sumcheck", + claim: "stage2.ram_output.input", + relation: "jolt.stage2.ram.output_check", + index: 4, + point_arity: params.log_k_ram, + num_rounds: params.log_k_ram, + round_offset: ram_offset, + point_order: "reverse", + degree: RAM_OUTPUT_DEGREE, + }, + point, + result_value, + )?; + append_stage2_output_openings( + context, + module, + params, + Stage2OutputOpeningSpec { + outputs: &[ + InstanceOutput { + prefix: "stage2.product_virtual.remainder", + instance: product, + eval_source: "stage2.sumcheck", + outputs: &PRODUCT_REMAINDER_OUTPUTS, + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "virtual", + }, + InstanceOutput { + prefix: "stage2.instruction_lookup.claim_reduction", + instance: instruction, + eval_source: "stage2.sumcheck", + outputs: &STAGE2_INSTRUCTION_INPUTS, + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "virtual", + }, + ], + ram_rw, + ram_raf, + ram_output, + stage1_ram_address_point: inputs.ram_address.point, + }, + )?; + Ok(state) +} + +fn append_stage2_output_openings<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + spec: Stage2OutputOpeningSpec<'c, 'a, '_>, +) -> Result<(), MlirError> { + let mut claims = Vec::new(); + let mut claim_symbols = Vec::new(); + + for (index, &oracle) in ["RamVal", "RamRa"].iter().enumerate() { + let symbol = format!("stage2.ram_read_write.opening.{oracle}"); + let eval = append_sumcheck_eval( + context, + module, + &format!("stage2.ram_read_write.eval.{oracle}"), + "stage2.sumcheck", + oracle, + index, + spec.ram_rw.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + spec.ram_rw.0, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle, + domain: "jolt.stage2_ram_rw_domain", + point_arity: stage2_max_rounds(params), + claim_kind: "virtual", + }, + )?); + } + let ram_inc_point = append_point_slice( + context, + module, + "stage2.ram_read_write.point.RamInc", + "stage2.ram_read_write.instance", + params.log_k_ram, + params.log_t, + spec.ram_rw.0, + )?; + let ram_inc_eval = append_sumcheck_eval( + context, + module, + "stage2.ram_read_write.eval.RamInc", + "stage2.sumcheck", + "RamInc", + 2, + spec.ram_rw.1, + )?; + claim_symbols.push("stage2.ram_read_write.opening.RamInc".to_owned()); + claims.push(append_opening_claim( + context, + module, + ram_inc_point, + ram_inc_eval, + OpeningClaimSpec { + symbol: "stage2.ram_read_write.opening.RamInc", + oracle: "RamInc", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?); + + for output in spec.outputs { + for (index, &oracle) in output.outputs.iter().enumerate() { + let symbol = format!("{}.opening.{oracle}", output.prefix); + let eval = append_sumcheck_eval( + context, + module, + &format!("{}.eval.{oracle}", output.prefix), + output.eval_source, + oracle, + index, + output.instance.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + output.instance.0, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle, + domain: output.domain, + point_arity: output.point_arity, + claim_kind: output.claim_kind, + }, + )?); + } + } + + let ram_raf_point = append_point_concat( + context, + module, + "stage2.ram_raf.point.RamRa", + "address_then_cycle", + params.log_k_ram + params.log_t, + &[spec.ram_raf.0, spec.stage1_ram_address_point], + )?; + let ram_raf_eval = append_sumcheck_eval( + context, + module, + "stage2.ram_raf.eval.RamRa", + "stage2.sumcheck", + "RamRa", + 0, + spec.ram_raf.1, + )?; + claim_symbols.push("stage2.ram_raf.opening.RamRa".to_owned()); + claims.push(append_opening_claim( + context, + module, + ram_raf_point, + ram_raf_eval, + OpeningClaimSpec { + symbol: "stage2.ram_raf.opening.RamRa", + oracle: "RamRa", + domain: "jolt.stage2_ram_rw_domain", + point_arity: params.log_k_ram + params.log_t, + claim_kind: "virtual", + }, + )?); + + let ram_output_eval = append_sumcheck_eval( + context, + module, + "stage2.ram_output.eval.RamValFinal", + "stage2.sumcheck", + "RamValFinal", + 0, + spec.ram_output.1, + )?; + claim_symbols.push("stage2.ram_output.opening.RamValFinal".to_owned()); + claims.push(append_opening_claim( + context, + module, + spec.ram_output.0, + ram_output_eval, + OpeningClaimSpec { + symbol: "stage2.ram_output.opening.RamValFinal", + oracle: "RamValFinal", + domain: "jolt.ram_address_domain", + point_arity: params.log_k_ram, + claim_kind: "virtual", + }, + )?); + + let claim_names = claim_symbols.iter().map(String::as_str).collect::>(); + let _batch = context.append_typed_op( + module, + "piop.opening_batch", + Some("stage2.openings"), + &[ + ("stage", "@stage2"), + ("proof_slot", "@stage2.openings"), + ("policy", r#""jolt_stage2_output_order""#), + ("count", &int_attr(claims.len())), + ("ordered_claims", &symbol_array_attr(&claim_names)), + ], + &claims, + &["!piop.opening_batch_type"], + )?; + Ok(()) +} + +fn append_sumcheck_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckClaimSpec<'_>, + input_claim: Value<'c, 'a>, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(inputs.len() + 1); + operands.push(input_claim); + operands.extend_from_slice(inputs); + let op = context.append_typed_op( + module, + "piop.sumcheck_claim", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ], + &operands, + &["!piop.sumcheck_claim_type"], + )?; + first_result(op, "piop.sumcheck_claim") +} + +fn append_sumcheck_batch<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + stage: Value<'c, 'a>, + claims: &[Value<'c, 'a>], + spec: SumcheckBatchSpec<'_>, +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(claims.len() + 1); + operands.push(stage); + operands.extend_from_slice(claims); + let op = context.append_typed_op( + module, + "piop.sumcheck_batch", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("policy", &format!("\"{}\"", spec.policy)), + ("count", &int_attr(spec.ordered_claims.len())), + ("ordered_claims", &symbol_array_attr(spec.ordered_claims)), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("round_schedule", &spec.round_schedule), + ], + &operands, + &["!piop.sumcheck_batch_type"], + )?; + first_result(op, "piop.sumcheck_batch") +} + +fn append_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + batch: Value<'c, 'a>, + spec: SumcheckDriverSpec<'_>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("relation", &format!("@{}", spec.relation)), + ("policy", &format!("\"{}\"", spec.policy)), + ("round_schedule", &spec.round_schedule), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ], + &[state, batch], + &[ + "!transcript.state_type", + "!poly.point", + "!piop.sumcheck_result_type", + "!piop.sumcheck_proof_type", + ], + )?; + Ok(( + result(op, 0, "piop.sumcheck")?, + result(op, 1, "piop.sumcheck")?, + result(op, 2, "piop.sumcheck")?, + )) +} + +fn append_sumcheck_instance_result<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckInstanceResultSpec<'_>, + point: Value<'c, 'a>, + result_value: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_instance_result", + Some(spec.symbol), + &[ + ("source", &format!("@{}", spec.source)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ("index", &int_attr(spec.index)), + ("point_arity", &int_attr(spec.point_arity)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("round_offset", &int_attr(spec.round_offset)), + ("point_order", &format!("\"{}\"", spec.point_order)), + ("degree", &int_attr(spec.degree)), + ], + &[point, result_value], + &["!poly.point", "!piop.sumcheck_result_type"], + )?; + Ok(( + result(op, 0, "piop.sumcheck_instance_result")?, + result(op, 1, "piop.sumcheck_instance_result")?, + )) +} + +fn append_sumcheck_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + oracle: &str, + index: usize, + result_value: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_eval", + Some(symbol), + &[ + ("source", &format!("@{source}")), + ("name", &format!("@{symbol}")), + ("index", &int_attr(index)), + ("oracle", &format!("@{oracle}")), + ], + &[result_value], + &["!field.scalar"], + )?; + first_result(op, "piop.sumcheck_eval") +} + +fn append_opening_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + spec: OpeningClaimSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_claim", + Some(spec.symbol), + &[ + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", &format!("\"{}\"", spec.claim_kind)), + ], + &[point, eval], + &["!piop.opening_claim_type"], + )?; + first_result(op, "piop.opening_claim") +} + +fn append_point_slice<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + offset: usize, + length: usize, + point: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_slice", + Some(symbol), + &[ + ("source", &format!("@{source}")), + ("offset", &int_attr(offset)), + ("length", &int_attr(length)), + ], + &[point], + &["!poly.point"], + )?; + first_result(op, "poly.point_slice") +} + +fn append_point_concat<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + layout: &str, + arity: usize, + points: &[Value<'c, 'a>], +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_concat", + Some(symbol), + &[ + ("layout", &format!("\"{layout}\"")), + ("arity", &int_attr(arity)), + ], + points, + &["!poly.point"], + )?; + first_result(op, "poly.point_concat") +} + +fn first_result<'c, 'a>( + operation: OperationRef<'c, 'a>, + operation_name: &str, +) -> Result, MlirError> { + result(operation, 0, operation_name) +} + +fn result<'c, 'a>( + operation: OperationRef<'c, 'a>, + index: usize, + operation_name: &str, +) -> Result, MlirError> { + operation + .result(index) + .map(Into::into) + .map_err(|_| schema_error(format!("{operation_name} requires result {index}"))) +} + +#[derive(Clone, Debug)] +struct StageParamsAst { + field: String, + pcs: String, + transcript: String, +} + +fn stage_params(module: &BoltModule<'_, Party>) -> Result { + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + if operation_name(op) == "protocol.params" { + return Ok(StageParamsAst { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + } + Err(schema_error("stage2 lowering requires protocol.params")) +} + +fn operation_result_key_at( + operation: OperationRef<'_, '_>, + index: usize, +) -> Result { + let result = operation.result(index).map_err(|_| { + schema_error(format!( + "{} requires result {index}", + operation_name(operation) + )) + })?; + result_key(result.owner(), result.result_number()) +} + +fn result_key(operation: OperationRef<'_, '_>, result_number: usize) -> Result { + Ok(format!( + "{}#{result_number}", + string_attr(operation, "sym_name")? + )) +} + +fn operand_key(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + schema_error(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + schema_error(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + result_key(owner.owner(), owner.result_number()).map_err(|_| { + schema_error(format!( + "{} operand {index} owner missing sym_name", + operation_name(operation) + )) + }) +} + +fn lowered_operands<'c, 'a>( + operation: OperationRef<'_, '_>, + value_map: &BTreeMap>, + start_index: usize, +) -> Result>, MlirError> { + (start_index..operation.operand_count()) + .map(|index| { + let key = operand_key(operation, index)?; + value_map.get(&key).copied().ok_or_else(|| { + schema_error(format!( + "{} operand {index} was not lowered", + operation_name(operation) + )) + }) + }) + .collect() +} + +fn insert_result_mapping<'c, 'a>( + value_map: &mut BTreeMap>, + source: OperationRef<'_, '_>, + target: OperationRef<'c, 'a>, + source_index: usize, + target_index: usize, +) -> Result<(), MlirError> { + let key = operation_result_key_at(source, source_index)?; + let value = target.result(target_index).map(Into::into).map_err(|_| { + schema_error(format!( + "{} requires result {target_index}", + operation_name(target) + )) + })?; + let inserted = value_map.insert(key, value); + debug_assert!(inserted.is_none()); + Ok(()) +} + +fn symbol_ref(symbol: &str) -> String { + format!("@{symbol}") +} + +fn stage2_max_rounds(params: &JoltProtocolParams) -> usize { + params.log_t + params.log_k_ram +} + +fn int_attr(value: usize) -> String { + format!("{value} : i64") +} + +fn int_attr_signed(value: isize) -> String { + format!("{value} : i64") +} + +fn symbol_array_attr(values: &[&str]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn schema_error(message: impl Into) -> MlirError { + SchemaError::new(message).into() +} + +#[derive(Clone, Copy)] +struct Stage2OpeningInput<'c, 'a> { + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + claim: Value<'c, 'a>, +} + +struct Stage2OpeningInputs<'c, 'a> { + product: Stage2OpeningInput<'c, 'a>, + should_branch: Stage2OpeningInput<'c, 'a>, + should_jump: Stage2OpeningInput<'c, 'a>, + ram_read_value: Stage2OpeningInput<'c, 'a>, + ram_write_value: Stage2OpeningInput<'c, 'a>, + lookup_output: Stage2OpeningInput<'c, 'a>, + left_lookup_operand: Stage2OpeningInput<'c, 'a>, + right_lookup_operand: Stage2OpeningInput<'c, 'a>, + left_instruction_input: Stage2OpeningInput<'c, 'a>, + right_instruction_input: Stage2OpeningInput<'c, 'a>, + ram_address: Stage2OpeningInput<'c, 'a>, +} + +#[derive(Clone, Copy)] +struct Stage2UniskipOutput<'c, 'a> { + opening: Value<'c, 'a>, + eval: Value<'c, 'a>, +} + +struct Stage2BatchedSumcheckInputs<'c, 'a, 'b> { + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + openings: &'b Stage2OpeningInputs<'c, 'a>, + uniskip: Stage2UniskipOutput<'c, 'a>, + ram_read_write_gamma: Value<'c, 'a>, + instruction_lookup_gamma: Value<'c, 'a>, +} + +struct RelationSpec<'a> { + symbol: &'a str, + kind: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + output_count: usize, +} + +struct SumcheckClaimSpec<'a> { + symbol: &'a str, + stage: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + claim: &'a str, + relation: &'a str, +} + +struct SumcheckBatchSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + policy: &'a str, + ordered_claims: &'a [&'a str], + claim_label: &'a str, + round_label: &'a str, + round_schedule: String, +} + +struct SumcheckDriverSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + relation: &'a str, + policy: &'a str, + round_schedule: String, + claim_label: &'a str, + round_label: &'a str, + num_rounds: usize, + degree: usize, +} + +struct SumcheckInstanceResultSpec<'a> { + symbol: &'a str, + source: &'a str, + claim: &'a str, + relation: &'a str, + index: usize, + point_arity: usize, + num_rounds: usize, + round_offset: usize, + point_order: &'a str, + degree: usize, +} + +struct OpeningClaimSpec<'a> { + symbol: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, + claim_kind: &'a str, +} + +struct Stage2OutputOpeningSpec<'c, 'a, 'b> { + outputs: &'b [InstanceOutput<'c, 'a, 'b>], + ram_rw: (Value<'c, 'a>, Value<'c, 'a>), + ram_raf: (Value<'c, 'a>, Value<'c, 'a>), + ram_output: (Value<'c, 'a>, Value<'c, 'a>), + stage1_ram_address_point: Value<'c, 'a>, +} + +struct InstanceOutput<'c, 'a, 'b> { + prefix: &'b str, + instance: (Value<'c, 'a>, Value<'c, 'a>), + eval_source: &'b str, + outputs: &'b [&'b str], + domain: &'b str, + point_arity: usize, + claim_kind: &'b str, +} diff --git a/crates/bolt/src/protocols/jolt/phases/stage3.rs b/crates/bolt/src/protocols/jolt/phases/stage3.rs new file mode 100644 index 0000000000..b0e3c1513d --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/stage3.rs @@ -0,0 +1,1762 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::OperationRef; +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::Value; + +use crate::ir::{BoltModule, Compute, Party, Protocol, Role}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{ + operation_name, symbol_attr, verify_compute_schema, verify_party_schema, + verify_protocol_schema, SchemaError, +}; + +use super::super::oracles; +use super::super::params::JoltProtocolParams; +use super::lowering::{ + copy_attrs, field_lowering_attrs as field_compute_attrs, string_attr, + transcript_squeeze_compute_result_types, transcript_squeeze_protocol_result_type, +}; + +const SPARTAN_SHIFT_DEGREE: usize = 2; +const INSTRUCTION_INPUT_DEGREE: usize = 3; +const REGISTERS_CLAIM_REDUCTION_DEGREE: usize = 2; +const STAGE3_BATCHED_DEGREE: usize = 3; + +const STAGE3_SHIFT_INPUTS: [&str; 4] = [ + "NextUnexpandedPC", + "NextPC", + "NextIsVirtual", + "NextIsFirstInSequence", +]; +const STAGE3_SHIFT_OUTPUTS: [&str; 5] = [ + "UnexpandedPC", + "PC", + "OpFlagVirtualInstruction", + "OpFlagIsFirstInSequence", + "InstructionFlagIsNoop", +]; +const STAGE3_INSTRUCTION_INPUT_OUTPUTS: [&str; 8] = [ + "InstructionFlagLeftOperandIsRs1Value", + "Rs1Value", + "InstructionFlagLeftOperandIsPC", + "UnexpandedPC", + "InstructionFlagRightOperandIsRs2Value", + "Rs2Value", + "InstructionFlagRightOperandIsImm", + "Imm", +]; +const STAGE3_REGISTER_INPUTS: [&str; 3] = ["RdWriteValue", "Rs1Value", "Rs2Value"]; + +pub fn build_stage3_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.stage3", None); + oracles::append_foundation_ops(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.stage3"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + append_stage3_oracles(context, &module)?; + append_stage3_relations(context, &module, params)?; + let inputs = append_stage3_opening_inputs(context, &module, params)?; + + let fs = context.append_typed_op( + &module, + "transcript.state", + Some("fs_after_stage2"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let state = first_result(fs, "transcript.state")?; + let stage = context.append_typed_op( + &module, + "piop.stage", + Some("stage3"), + &[ + ("name", r#""shift_instruction_input_and_registers""#), + ("order", "3 : i64"), + ("roles", r#"["prover", "verifier"]"#), + ], + &[], + &["!piop.stage_type"], + )?; + let stage = first_result(stage, "piop.stage")?; + + let (state, shift_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage3.spartan_shift.gamma", + "spartan_shift_gamma", + "challenge_scalar", + 1, + )?; + let (state, instruction_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage3.instruction_input.gamma", + "instruction_input_gamma", + "challenge_scalar", + 1, + )?; + let (state, registers_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage3.registers.gamma", + "registers_gamma", + "challenge_scalar", + 1, + )?; + let _state = append_stage3_batched_sumcheck( + context, + &module, + params, + Stage3BatchedSumcheckInputs { + state, + stage, + openings: &inputs, + shift_gamma, + instruction_gamma, + registers_gamma, + }, + )?; + + verify_module(&module)?; + verify_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_stage3_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + verify_party_schema(module)?; + let role = module + .role() + .ok_or_else(|| schema_error("stage3 lowering requires party role"))?; + let params = stage_params(module)?; + let compute = context.new_module::(&module.name(), Some(role.clone())); + context.append_op_with_owned_attrs( + &compute, + "compute.params", + Some("jolt.compute_params"), + &[ + ("field".to_owned(), symbol_ref(¶ms.field)), + ("pcs".to_owned(), symbol_ref(¶ms.pcs)), + ("transcript".to_owned(), symbol_ref(¶ms.transcript)), + ], + )?; + context.append_op( + &compute, + "compute.function", + Some("jolt.stage3"), + &[("source", "@jolt.stage3")], + )?; + + let mut value_map = BTreeMap::new(); + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + match operation_name(op).as_str() { + "piop.relation" => { + let attrs = copy_attrs( + op, + &["kind", "domain", "num_rounds", "degree", "output_count"], + )?; + let symbol = string_attr(op, "sym_name")?; + context.append_op_with_owned_attrs( + &compute, + "compute.relation", + Some(&symbol), + &attrs, + )?; + } + "transcript.state" => { + let attrs = copy_attrs(op, &["scheme"])?; + let symbol = string_attr(op, "sym_name")?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_init", + Some(&symbol), + &attrs, + &[], + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "transcript.absorb_bytes" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "payload"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_absorb_bytes", + Some(&symbol), + &attrs, + &operands, + &["!compute.transcript_state"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "transcript.squeeze" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["label", "kind", "count"])?; + let result_types = transcript_squeeze_compute_result_types(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.transcript_squeeze", + Some(&symbol), + &attrs, + &operands, + &result_types, + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "field.const" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field", "value"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.field_const", + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "field.zero" | "field.one" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["field"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + &format!("compute.{}", operation_name(op).replace('.', "_")), + Some(&symbol), + &attrs, + &[], + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "field.add" | "field.sub" | "field.mul" | "field.neg" | "field.pow" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = field_compute_attrs(op)?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + &format!("compute.{}", operation_name(op).replace('.', "_")), + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "poly.lagrange_basis_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["domain_start", "domain_size", "index"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.poly_lagrange_basis_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.opening_input" => { + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source_stage", + "source_claim", + "oracle", + "domain", + "point_arity", + "claim_kind", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_input", + Some(&symbol), + &attrs, + &[], + &[ + "!compute.point", + "!compute.field_value", + "!compute.opening_claim_type", + ], + )?; + for index in 0..3 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "poly.point_slice" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "offset", "length"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.point_slice", + Some(&symbol), + &attrs, + &operands, + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "poly.point_concat" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["layout", "arity"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.point_concat", + Some(&symbol), + &attrs, + &operands, + &["!compute.point"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + let target_op = match &role { + Role::Prover => "compute.sumcheck_claim", + Role::Verifier => "compute.sumcheck_verify_claim", + }; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_batch" => { + let operands = lowered_operands(op, &value_map, 1)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.sumcheck_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + let target_op = match &role { + Role::Prover => "compute.sumcheck_driver", + Role::Verifier => "compute.sumcheck_verify", + }; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + target_op, + Some(&symbol), + &attrs, + &operands, + &[ + "!compute.transcript_state", + "!compute.point", + "!compute.sumcheck_result_type", + "!compute.sumcheck_proof_type", + ], + )?; + for index in 0..4 { + insert_result_mapping(&mut value_map, op, operation, index, index)?; + } + } + "piop.sumcheck_eval" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["source", "name", "index", "oracle"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_eval", + Some(&symbol), + &attrs, + &operands, + &["!compute.field_value"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.sumcheck_instance_result" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &[ + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.sumcheck_instance_result", + Some(&symbol), + &attrs, + &operands, + &["!compute.point", "!compute.sumcheck_result_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + insert_result_mapping(&mut value_map, op, operation, 1, 1)?; + } + "piop.opening_claim" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["oracle", "domain", "point_arity", "claim_kind"])?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_claim", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_claim_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + "piop.opening_claim_equal" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs(op, &["mode"])?; + let _operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_claim_equal", + Some(&symbol), + &attrs, + &operands, + &[], + )?; + } + "piop.opening_batch" => { + let operands = lowered_operands(op, &value_map, 0)?; + let symbol = string_attr(op, "sym_name")?; + let attrs = copy_attrs( + op, + &["stage", "proof_slot", "policy", "count", "ordered_claims"], + )?; + let operation = context.append_typed_op_with_owned_attrs( + &compute, + "compute.opening_batch", + Some(&symbol), + &attrs, + &operands, + &["!compute.opening_batch_type"], + )?; + insert_result_mapping(&mut value_map, op, operation, 0, 0)?; + } + _ => {} + } + } + + verify_module(&compute)?; + verify_compute_schema(&compute)?; + Ok(compute) +} + +fn append_stage3_oracles<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, +) -> Result<(), MlirError> { + let mut trace_oracles = BTreeSet::new(); + trace_oracles.extend(STAGE3_SHIFT_INPUTS); + trace_oracles.extend(STAGE3_SHIFT_OUTPUTS); + trace_oracles.extend(STAGE3_INSTRUCTION_INPUT_OUTPUTS); + trace_oracles.extend(STAGE3_REGISTER_INPUTS); + trace_oracles.extend([ + "LeftInstructionInput", + "RightInstructionInput", + "NextIsNoop", + ]); + for oracle in trace_oracles { + append_virtual_oracle(context, module, oracle, "jolt.trace_domain")?; + } + Ok(()) +} + +fn append_virtual_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + domain: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", &format!("@{domain}")), + ("commit_domain", &format!("@{domain}")), + ("visibility", r#""virtual""#), + ("layout", r#""virtual""#), + ], + ) +} + +fn append_stage3_relations<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage3.spartan_shift", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: SPARTAN_SHIFT_DEGREE, + output_count: STAGE3_SHIFT_OUTPUTS.len(), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage3.instruction_input", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: INSTRUCTION_INPUT_DEGREE, + output_count: STAGE3_INSTRUCTION_INPUT_OUTPUTS.len(), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage3.registers_claim_reduction", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: REGISTERS_CLAIM_REDUCTION_DEGREE, + output_count: STAGE3_REGISTER_INPUTS.len(), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage3.batched", + kind: "batched_sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: STAGE3_BATCHED_DEGREE, + output_count: stage3_output_count(), + }, + ) +} + +fn append_relation<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + spec: RelationSpec<'_>, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.relation", + Some(spec.symbol), + &[ + ("kind", &format!("\"{}\"", spec.kind)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("output_count", &int_attr(spec.output_count)), + ], + ) +} + +fn append_stage3_opening_inputs<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result, MlirError> { + Ok(Stage3OpeningInputs { + next_unexpanded_pc: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage1.NextUnexpandedPC", + source_stage: "stage1", + source_claim: "stage1.outer_remaining.opening.NextUnexpandedPC", + oracle: "NextUnexpandedPC", + }, + )?, + next_pc: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage1.NextPC", + source_stage: "stage1", + source_claim: "stage1.outer_remaining.opening.NextPC", + oracle: "NextPC", + }, + )?, + next_is_virtual: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage1.NextIsVirtual", + source_stage: "stage1", + source_claim: "stage1.outer_remaining.opening.NextIsVirtual", + oracle: "NextIsVirtual", + }, + )?, + next_is_first_in_sequence: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage1.NextIsFirstInSequence", + source_stage: "stage1", + source_claim: "stage1.outer_remaining.opening.NextIsFirstInSequence", + oracle: "NextIsFirstInSequence", + }, + )?, + product_next_is_noop: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage2.product_virtual.NextIsNoop", + source_stage: "stage2", + source_claim: "stage2.product_virtual.remainder.opening.NextIsNoop", + oracle: "NextIsNoop", + }, + )?, + product_left_instruction_input: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage2.product_virtual.LeftInstructionInput", + source_stage: "stage2", + source_claim: "stage2.product_virtual.remainder.opening.LeftInstructionInput", + oracle: "LeftInstructionInput", + }, + )?, + product_right_instruction_input: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage2.product_virtual.RightInstructionInput", + source_stage: "stage2", + source_claim: "stage2.product_virtual.remainder.opening.RightInstructionInput", + oracle: "RightInstructionInput", + }, + )?, + instruction_left_instruction_input: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage2.instruction_lookup.LeftInstructionInput", + source_stage: "stage2", + source_claim: + "stage2.instruction_lookup.claim_reduction.opening.LeftInstructionInput", + oracle: "LeftInstructionInput", + }, + )?, + instruction_right_instruction_input: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage2.instruction_lookup.RightInstructionInput", + source_stage: "stage2", + source_claim: + "stage2.instruction_lookup.claim_reduction.opening.RightInstructionInput", + oracle: "RightInstructionInput", + }, + )?, + rd_write_value: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage1.RdWriteValue", + source_stage: "stage1", + source_claim: "stage1.outer_remaining.opening.RdWriteValue", + oracle: "RdWriteValue", + }, + )?, + rs1_value: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage1.Rs1Value", + source_stage: "stage1", + source_claim: "stage1.outer_remaining.opening.Rs1Value", + oracle: "Rs1Value", + }, + )?, + rs2_value: append_stage_input( + context, + module, + params, + StageOpeningInputSpec { + symbol: "stage3.input.stage1.Rs2Value", + source_stage: "stage1", + source_claim: "stage1.outer_remaining.opening.Rs2Value", + oracle: "Rs2Value", + }, + )?, + }) +} + +fn append_stage_input<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + spec: StageOpeningInputSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_input", + Some(spec.symbol), + &[ + ("source_stage", &format!("@{}", spec.source_stage)), + ("source_claim", &format!("@{}", spec.source_claim)), + ("oracle", &format!("@{}", spec.oracle)), + ("domain", "@jolt.trace_domain"), + ("point_arity", &int_attr(params.log_t)), + ("claim_kind", r#""virtual""#), + ], + &[], + &["!poly.point", "!field.scalar", "!piop.opening_claim_type"], + )?; + Ok(Stage3OpeningInput { + eval: result(op, 1, "piop.opening_input")?, + claim: result(op, 2, "piop.opening_input")?, + }) +} + +fn append_transcript_squeeze<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + symbol: &str, + label: &str, + kind: &str, + count: usize, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "transcript.squeeze", + Some(symbol), + &[ + ("label", &format!("\"{label}\"")), + ("kind", &format!("\"{kind}\"")), + ("count", &int_attr(count)), + ], + &[state], + &[ + "!transcript.state_type", + transcript_squeeze_protocol_result_type(kind)?, + ], + )?; + Ok(( + result(op, 0, "transcript.squeeze")?, + result(op, 1, "transcript.squeeze")?, + )) +} + +fn append_field_one<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.one", + Some(symbol), + &[("field", "@bn254_fr")], + &[], + &["!field.scalar"], + )?; + first_result(op, "field.one") +} + +fn append_field_binary<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + op_name: &str, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + op_name, + Some(symbol), + &[], + &[lhs, rhs], + &["!field.scalar"], + )?; + first_result(op, op_name) +} + +fn append_field_add<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.add", symbol, lhs, rhs) +} + +fn append_field_sub<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.sub", symbol, lhs, rhs) +} + +fn append_field_mul<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.mul", symbol, lhs, rhs) +} + +fn append_field_pow<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + base: Value<'c, 'a>, + exponent: usize, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.pow", + Some(symbol), + &[("exponent", &int_attr(exponent))], + &[base], + &["!field.scalar"], + )?; + first_result(op, "field.pow") +} + +fn append_opening_claim_equal<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + left: Value<'c, '_>, + right: Value<'c, '_>, +) -> Result<(), MlirError> { + let _operation = context.append_typed_op( + module, + "piop.opening_claim_equal", + Some(symbol), + &[("mode", r#""point_and_eval""#)], + &[left, right], + &[], + )?; + Ok(()) +} + +fn append_stage3_batched_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + spec: Stage3BatchedSumcheckInputs<'c, 'a, '_>, +) -> Result, MlirError> { + let inputs = spec.openings; + let shift_gamma2 = append_field_pow( + context, + module, + "stage3.spartan_shift.gamma2", + spec.shift_gamma, + 2, + )?; + let shift_gamma3 = append_field_mul( + context, + module, + "stage3.spartan_shift.gamma3", + shift_gamma2, + spec.shift_gamma, + )?; + let shift_gamma4 = append_field_mul( + context, + module, + "stage3.spartan_shift.gamma4", + shift_gamma2, + shift_gamma2, + )?; + let one = append_field_one(context, module, "stage3.field.one")?; + let next_pc_term = append_field_mul( + context, + module, + "stage3.spartan_shift.term.NextPC", + spec.shift_gamma, + inputs.next_pc.eval, + )?; + let next_virtual_term = append_field_mul( + context, + module, + "stage3.spartan_shift.term.NextIsVirtual", + shift_gamma2, + inputs.next_is_virtual.eval, + )?; + let next_first_term = append_field_mul( + context, + module, + "stage3.spartan_shift.term.NextIsFirstInSequence", + shift_gamma3, + inputs.next_is_first_in_sequence.eval, + )?; + let one_minus_noop = append_field_sub( + context, + module, + "stage3.spartan_shift.one_minus.NextIsNoop", + one, + inputs.product_next_is_noop.eval, + )?; + let next_noop_term = append_field_mul( + context, + module, + "stage3.spartan_shift.term.NextIsNoop", + shift_gamma4, + one_minus_noop, + )?; + let shift_sum0 = append_field_add( + context, + module, + "stage3.spartan_shift.partial.NextUnexpandedPCNextPC", + inputs.next_unexpanded_pc.eval, + next_pc_term, + )?; + let shift_sum1 = append_field_add( + context, + module, + "stage3.spartan_shift.partial.NextIsVirtual", + shift_sum0, + next_virtual_term, + )?; + let shift_sum2 = append_field_add( + context, + module, + "stage3.spartan_shift.partial.NextIsFirstInSequence", + shift_sum1, + next_first_term, + )?; + let shift_claim = append_field_add( + context, + module, + "stage3.spartan_shift.claim_expr", + shift_sum2, + next_noop_term, + )?; + append_opening_claim_equal( + context, + module, + "stage3.instruction_input.left_claim_consistency", + inputs.product_left_instruction_input.claim, + inputs.instruction_left_instruction_input.claim, + )?; + append_opening_claim_equal( + context, + module, + "stage3.instruction_input.right_claim_consistency", + inputs.product_right_instruction_input.claim, + inputs.instruction_right_instruction_input.claim, + )?; + let instruction_left_term = append_field_mul( + context, + module, + "stage3.instruction_input.term.LeftInstructionInput", + spec.instruction_gamma, + inputs.product_left_instruction_input.eval, + )?; + let instruction_claim = append_field_add( + context, + module, + "stage3.instruction_input.claim_expr", + inputs.product_right_instruction_input.eval, + instruction_left_term, + )?; + let registers_gamma2 = append_field_pow( + context, + module, + "stage3.registers.gamma2", + spec.registers_gamma, + 2, + )?; + let rs1_term = append_field_mul( + context, + module, + "stage3.registers.term.Rs1Value", + spec.registers_gamma, + inputs.rs1_value.eval, + )?; + let rs2_term = append_field_mul( + context, + module, + "stage3.registers.term.Rs2Value", + registers_gamma2, + inputs.rs2_value.eval, + )?; + let registers_sum = append_field_add( + context, + module, + "stage3.registers.partial.RdWriteValueRs1Value", + inputs.rd_write_value.eval, + rs1_term, + )?; + let registers_claim = append_field_add( + context, + module, + "stage3.registers.claim_expr", + registers_sum, + rs2_term, + )?; + + let claims = [ + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage3.spartan_shift.input", + stage: "stage3", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: SPARTAN_SHIFT_DEGREE, + claim: "stage3.spartan_shift.weighted_next_values", + relation: "jolt.stage3.spartan_shift", + }, + shift_claim, + &[ + inputs.next_unexpanded_pc.claim, + inputs.next_pc.claim, + inputs.next_is_virtual.claim, + inputs.next_is_first_in_sequence.claim, + inputs.product_next_is_noop.claim, + ], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage3.instruction_input.input", + stage: "stage3", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: INSTRUCTION_INPUT_DEGREE, + claim: "stage3.instruction_input.weighted_inputs", + relation: "jolt.stage3.instruction_input", + }, + instruction_claim, + &[ + inputs.product_right_instruction_input.claim, + inputs.product_left_instruction_input.claim, + ], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage3.registers_claim_reduction.input", + stage: "stage3", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: REGISTERS_CLAIM_REDUCTION_DEGREE, + claim: "stage3.registers.weighted_register_values", + relation: "jolt.stage3.registers_claim_reduction", + }, + registers_claim, + &[ + inputs.rd_write_value.claim, + inputs.rs1_value.claim, + inputs.rs2_value.claim, + ], + )?, + ]; + let batch = append_sumcheck_batch( + context, + module, + spec.stage, + &claims, + SumcheckBatchSpec { + symbol: "stage3.batch", + stage: "stage3", + proof_slot: "stage3.sumcheck", + policy: "jolt_core_stage3_aligned", + ordered_claims: &[ + "stage3.spartan_shift.input", + "stage3.instruction_input.input", + "stage3.registers_claim_reduction.input", + ], + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + round_schedule: format!("[{}]", params.log_t), + }, + )?; + let (state, point, result_value) = append_sumcheck( + context, + module, + spec.state, + batch, + SumcheckDriverSpec { + symbol: "stage3.sumcheck", + stage: "stage3", + proof_slot: "stage3.sumcheck", + relation: "jolt.stage3.batched", + policy: "jolt_core_stage3_aligned", + round_schedule: format!("[{}]", params.log_t), + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + num_rounds: params.log_t, + degree: STAGE3_BATCHED_DEGREE, + }, + )?; + + let shift = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage3.spartan_shift.instance", + source: "stage3.sumcheck", + claim: "stage3.spartan_shift.input", + relation: "jolt.stage3.spartan_shift", + index: 0, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: 0, + point_order: "reverse", + degree: SPARTAN_SHIFT_DEGREE, + }, + point, + result_value, + )?; + let instruction = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage3.instruction_input.instance", + source: "stage3.sumcheck", + claim: "stage3.instruction_input.input", + relation: "jolt.stage3.instruction_input", + index: 1, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: 0, + point_order: "reverse", + degree: INSTRUCTION_INPUT_DEGREE, + }, + point, + result_value, + )?; + let registers = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage3.registers_claim_reduction.instance", + source: "stage3.sumcheck", + claim: "stage3.registers_claim_reduction.input", + relation: "jolt.stage3.registers_claim_reduction", + index: 2, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: 0, + point_order: "reverse", + degree: REGISTERS_CLAIM_REDUCTION_DEGREE, + }, + point, + result_value, + )?; + append_stage3_output_openings( + context, + module, + &[ + InstanceOutput { + prefix: "stage3.spartan_shift", + instance: shift, + outputs: &STAGE3_SHIFT_OUTPUTS, + degree_offset: 0, + }, + InstanceOutput { + prefix: "stage3.instruction_input", + instance: instruction, + outputs: &STAGE3_INSTRUCTION_INPUT_OUTPUTS, + degree_offset: STAGE3_SHIFT_OUTPUTS.len(), + }, + InstanceOutput { + prefix: "stage3.registers_claim_reduction", + instance: registers, + outputs: &STAGE3_REGISTER_INPUTS, + degree_offset: STAGE3_SHIFT_OUTPUTS.len() + STAGE3_INSTRUCTION_INPUT_OUTPUTS.len(), + }, + ], + params.log_t, + )?; + Ok(state) +} + +fn append_stage3_output_openings<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + outputs: &[InstanceOutput<'c, 'a, '_>], + point_arity: usize, +) -> Result<(), MlirError> { + let mut claims = Vec::new(); + let mut claim_symbols = Vec::new(); + + for output in outputs { + for (index, &oracle) in output.outputs.iter().enumerate() { + let symbol = format!("{}.opening.{oracle}", output.prefix); + let eval = append_sumcheck_eval( + context, + module, + &format!("{}.eval.{oracle}", output.prefix), + "stage3.sumcheck", + oracle, + output.degree_offset + index, + output.instance.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + output.instance.0, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle, + domain: "jolt.trace_domain", + point_arity, + claim_kind: "virtual", + }, + )?); + } + } + + let claim_names = claim_symbols.iter().map(String::as_str).collect::>(); + let _batch = context.append_typed_op( + module, + "piop.opening_batch", + Some("stage3.openings"), + &[ + ("stage", "@stage3"), + ("proof_slot", "@stage3.openings"), + ("policy", r#""jolt_stage3_output_order""#), + ("count", &int_attr(claims.len())), + ("ordered_claims", &symbol_array_attr(&claim_names)), + ], + &claims, + &["!piop.opening_batch_type"], + )?; + Ok(()) +} + +fn append_sumcheck_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckClaimSpec<'_>, + input_claim: Value<'c, 'a>, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(inputs.len() + 1); + operands.push(input_claim); + operands.extend_from_slice(inputs); + let op = context.append_typed_op( + module, + "piop.sumcheck_claim", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ], + &operands, + &["!piop.sumcheck_claim_type"], + )?; + first_result(op, "piop.sumcheck_claim") +} + +fn append_sumcheck_batch<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + stage: Value<'c, 'a>, + claims: &[Value<'c, 'a>], + spec: SumcheckBatchSpec<'_>, +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(claims.len() + 1); + operands.push(stage); + operands.extend_from_slice(claims); + let op = context.append_typed_op( + module, + "piop.sumcheck_batch", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("policy", &format!("\"{}\"", spec.policy)), + ("count", &int_attr(spec.ordered_claims.len())), + ("ordered_claims", &symbol_array_attr(spec.ordered_claims)), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("round_schedule", &spec.round_schedule), + ], + &operands, + &["!piop.sumcheck_batch_type"], + )?; + first_result(op, "piop.sumcheck_batch") +} + +fn append_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + batch: Value<'c, 'a>, + spec: SumcheckDriverSpec<'_>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("relation", &format!("@{}", spec.relation)), + ("policy", &format!("\"{}\"", spec.policy)), + ("round_schedule", &spec.round_schedule), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ], + &[state, batch], + &[ + "!transcript.state_type", + "!poly.point", + "!piop.sumcheck_result_type", + "!piop.sumcheck_proof_type", + ], + )?; + Ok(( + result(op, 0, "piop.sumcheck")?, + result(op, 1, "piop.sumcheck")?, + result(op, 2, "piop.sumcheck")?, + )) +} + +fn append_sumcheck_instance_result<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckInstanceResultSpec<'_>, + point: Value<'c, 'a>, + result_value: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_instance_result", + Some(spec.symbol), + &[ + ("source", &format!("@{}", spec.source)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ("index", &int_attr(spec.index)), + ("point_arity", &int_attr(spec.point_arity)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("round_offset", &int_attr(spec.round_offset)), + ("point_order", &format!("\"{}\"", spec.point_order)), + ("degree", &int_attr(spec.degree)), + ], + &[point, result_value], + &["!poly.point", "!piop.sumcheck_result_type"], + )?; + Ok(( + result(op, 0, "piop.sumcheck_instance_result")?, + result(op, 1, "piop.sumcheck_instance_result")?, + )) +} + +fn append_sumcheck_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + oracle: &str, + index: usize, + result_value: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_eval", + Some(symbol), + &[ + ("source", &format!("@{source}")), + ("name", &format!("@{symbol}")), + ("index", &int_attr(index)), + ("oracle", &format!("@{oracle}")), + ], + &[result_value], + &["!field.scalar"], + )?; + first_result(op, "piop.sumcheck_eval") +} + +fn append_opening_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + spec: OpeningClaimSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_claim", + Some(spec.symbol), + &[ + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", &format!("\"{}\"", spec.claim_kind)), + ], + &[point, eval], + &["!piop.opening_claim_type"], + )?; + first_result(op, "piop.opening_claim") +} + +fn first_result<'c, 'a>( + operation: OperationRef<'c, 'a>, + operation_name: &str, +) -> Result, MlirError> { + result(operation, 0, operation_name) +} + +fn result<'c, 'a>( + operation: OperationRef<'c, 'a>, + index: usize, + operation_name: &str, +) -> Result, MlirError> { + operation + .result(index) + .map(Into::into) + .map_err(|_| schema_error(format!("{operation_name} requires result {index}"))) +} + +#[derive(Clone, Debug)] +struct StageParamsAst { + field: String, + pcs: String, + transcript: String, +} + +fn stage_params(module: &BoltModule<'_, Party>) -> Result { + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + if operation_name(op) == "protocol.params" { + return Ok(StageParamsAst { + field: symbol_attr(op, "field")?, + pcs: symbol_attr(op, "pcs")?, + transcript: symbol_attr(op, "transcript")?, + }); + } + } + Err(schema_error("stage3 lowering requires protocol.params")) +} + +fn operation_result_key_at( + operation: OperationRef<'_, '_>, + index: usize, +) -> Result { + let result = operation.result(index).map_err(|_| { + schema_error(format!( + "{} requires result {index}", + operation_name(operation) + )) + })?; + result_key(result.owner(), result.result_number()) +} + +fn result_key(operation: OperationRef<'_, '_>, result_number: usize) -> Result { + Ok(format!( + "{}#{result_number}", + string_attr(operation, "sym_name")? + )) +} + +fn operand_key(operation: OperationRef<'_, '_>, index: usize) -> Result { + let operand = operation.operand(index).map_err(|_| { + schema_error(format!( + "{} requires operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + schema_error(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + result_key(owner.owner(), owner.result_number()).map_err(|_| { + schema_error(format!( + "{} operand {index} owner missing sym_name", + operation_name(operation) + )) + }) +} + +fn lowered_operands<'c, 'a>( + operation: OperationRef<'_, '_>, + value_map: &BTreeMap>, + start_index: usize, +) -> Result>, MlirError> { + (start_index..operation.operand_count()) + .map(|index| { + let key = operand_key(operation, index)?; + value_map.get(&key).copied().ok_or_else(|| { + schema_error(format!( + "{} operand {index} was not lowered", + operation_name(operation) + )) + }) + }) + .collect() +} + +fn insert_result_mapping<'c, 'a>( + value_map: &mut BTreeMap>, + source: OperationRef<'_, '_>, + target: OperationRef<'c, 'a>, + source_index: usize, + target_index: usize, +) -> Result<(), MlirError> { + let key = operation_result_key_at(source, source_index)?; + let value = target.result(target_index).map(Into::into).map_err(|_| { + schema_error(format!( + "{} requires result {target_index}", + operation_name(target) + )) + })?; + let inserted = value_map.insert(key, value); + debug_assert!(inserted.is_none()); + Ok(()) +} + +fn symbol_ref(symbol: &str) -> String { + format!("@{symbol}") +} + +fn stage3_output_count() -> usize { + STAGE3_SHIFT_OUTPUTS.len() + + STAGE3_INSTRUCTION_INPUT_OUTPUTS.len() + + STAGE3_REGISTER_INPUTS.len() +} + +fn int_attr(value: usize) -> String { + format!("{value} : i64") +} + +fn symbol_array_attr(values: &[&str]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn schema_error(message: impl Into) -> MlirError { + SchemaError::new(message).into() +} + +#[derive(Clone, Copy)] +struct Stage3OpeningInput<'c, 'a> { + eval: Value<'c, 'a>, + claim: Value<'c, 'a>, +} + +struct Stage3OpeningInputs<'c, 'a> { + next_unexpanded_pc: Stage3OpeningInput<'c, 'a>, + next_pc: Stage3OpeningInput<'c, 'a>, + next_is_virtual: Stage3OpeningInput<'c, 'a>, + next_is_first_in_sequence: Stage3OpeningInput<'c, 'a>, + product_next_is_noop: Stage3OpeningInput<'c, 'a>, + product_left_instruction_input: Stage3OpeningInput<'c, 'a>, + product_right_instruction_input: Stage3OpeningInput<'c, 'a>, + instruction_left_instruction_input: Stage3OpeningInput<'c, 'a>, + instruction_right_instruction_input: Stage3OpeningInput<'c, 'a>, + rd_write_value: Stage3OpeningInput<'c, 'a>, + rs1_value: Stage3OpeningInput<'c, 'a>, + rs2_value: Stage3OpeningInput<'c, 'a>, +} + +struct StageOpeningInputSpec<'a> { + symbol: &'a str, + source_stage: &'a str, + source_claim: &'a str, + oracle: &'a str, +} + +struct Stage3BatchedSumcheckInputs<'c, 'a, 'b> { + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + openings: &'b Stage3OpeningInputs<'c, 'a>, + shift_gamma: Value<'c, 'a>, + instruction_gamma: Value<'c, 'a>, + registers_gamma: Value<'c, 'a>, +} + +struct RelationSpec<'a> { + symbol: &'a str, + kind: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + output_count: usize, +} + +struct SumcheckClaimSpec<'a> { + symbol: &'a str, + stage: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + claim: &'a str, + relation: &'a str, +} + +struct SumcheckBatchSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + policy: &'a str, + ordered_claims: &'a [&'a str], + claim_label: &'a str, + round_label: &'a str, + round_schedule: String, +} + +struct SumcheckDriverSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + relation: &'a str, + policy: &'a str, + round_schedule: String, + claim_label: &'a str, + round_label: &'a str, + num_rounds: usize, + degree: usize, +} + +struct SumcheckInstanceResultSpec<'a> { + symbol: &'a str, + source: &'a str, + claim: &'a str, + relation: &'a str, + index: usize, + point_arity: usize, + num_rounds: usize, + round_offset: usize, + point_order: &'a str, + degree: usize, +} + +struct OpeningClaimSpec<'a> { + symbol: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, + claim_kind: &'a str, +} + +struct InstanceOutput<'c, 'a, 'b> { + prefix: &'b str, + instance: (Value<'c, 'a>, Value<'c, 'a>), + outputs: &'b [&'b str], + degree_offset: usize, +} diff --git a/crates/bolt/src/protocols/jolt/phases/stage4.rs b/crates/bolt/src/protocols/jolt/phases/stage4.rs new file mode 100644 index 0000000000..adf6bc080a --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/stage4.rs @@ -0,0 +1,1264 @@ +use melior::ir::operation::OperationRef; +use melior::ir::Value; + +use crate::ir::{BoltModule, Compute, Party, Protocol}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{verify_protocol_schema, SchemaError}; + +use super::super::oracles; +use super::super::params::JoltProtocolParams; +use super::lowering::{lower_party_to_compute, transcript_squeeze_protocol_result_type}; + +const REGISTERS_RW_DEGREE: usize = 3; +const RAM_VAL_CHECK_DEGREE: usize = 3; +const STAGE4_BATCHED_DEGREE: usize = 3; + +const STAGE4_REGISTER_INPUTS: [&str; 3] = ["RdWriteValue", "Rs1Value", "Rs2Value"]; +const STAGE4_REGISTER_OUTPUTS: [&str; 5] = ["RegistersVal", "Rs1Ra", "Rs2Ra", "RdWa", "RdInc"]; +const STAGE4_RAM_VAL_OUTPUTS: [&str; 2] = ["RamRa", "RamInc"]; + +pub fn build_stage4_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.stage4", None); + oracles::append_foundation_ops(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.stage4"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + append_stage4_domains(context, &module, params)?; + append_stage4_oracles(context, &module)?; + append_stage4_relations(context, &module, params)?; + let inputs = append_stage4_opening_inputs(context, &module, params)?; + + let fs = context.append_typed_op( + &module, + "transcript.state", + Some("fs_after_stage3"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let state = first_result(fs, "transcript.state")?; + let stage = context.append_typed_op( + &module, + "piop.stage", + Some("stage4"), + &[ + ("name", r#""registers_rw_and_ram_val_check""#), + ("order", "4 : i64"), + ("roles", r#"["prover", "verifier"]"#), + ], + &[], + &["!piop.stage_type"], + )?; + let stage = first_result(stage, "piop.stage")?; + + let (state, registers_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage4.registers_read_write.gamma", + "registers_read_write_gamma", + "challenge_scalar", + 1, + )?; + let state = append_transcript_absorb_bytes( + context, + &module, + state, + "stage4.ram_val_check.domain_separator", + "ram_val_check_gamma", + "", + )?; + let (state, ram_val_check_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage4.ram_val_check.gamma", + "ram_val_check_gamma", + "challenge_scalar", + 1, + )?; + let _state = append_stage4_batched_sumcheck( + context, + &module, + params, + Stage4BatchedSumcheckInputs { + state, + stage, + openings: &inputs, + registers_gamma, + ram_val_check_gamma, + }, + )?; + + verify_module(&module)?; + verify_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_stage4_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + lower_party_to_compute(context, module, "jolt.stage4", "jolt.stage4", "stage4") +} + +fn append_stage4_domains<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + context.append_op( + module, + "poly.domain", + Some("jolt.stage4_registers_rw_domain"), + &[ + ("field", "@bn254_fr"), + ("log_size", &int_attr(stage4_registers_rw_rounds(params))), + ], + )?; + context.append_op( + module, + "poly.domain", + Some("jolt.stage2_ram_rw_domain"), + &[ + ("field", "@bn254_fr"), + ("log_size", &int_attr(params.log_k_ram + params.log_t)), + ], + )?; + context.append_op( + module, + "poly.domain", + Some("jolt.ram_address_domain"), + &[ + ("field", "@bn254_fr"), + ("log_size", &int_attr(params.log_k_ram)), + ], + ) +} + +fn append_stage4_oracles<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, +) -> Result<(), MlirError> { + for oracle in STAGE4_REGISTER_INPUTS { + append_virtual_oracle(context, module, oracle, "jolt.trace_domain")?; + } + append_virtual_oracle( + context, + module, + "RegistersVal", + "jolt.stage4_registers_rw_domain", + )?; + append_virtual_oracle(context, module, "Rs1Ra", "jolt.stage4_registers_rw_domain")?; + append_virtual_oracle(context, module, "Rs2Ra", "jolt.stage4_registers_rw_domain")?; + append_virtual_oracle(context, module, "RdWa", "jolt.stage4_registers_rw_domain")?; + append_virtual_oracle(context, module, "RamVal", "jolt.stage2_ram_rw_domain")?; + append_virtual_oracle(context, module, "RamRa", "jolt.stage2_ram_rw_domain")?; + append_virtual_oracle(context, module, "RamValFinal", "jolt.ram_address_domain")?; + append_virtual_oracle(context, module, "RamValInit", "jolt.ram_address_domain")?; + append_committed_trace_oracle(context, module, "RdInc")?; + append_committed_trace_oracle(context, module, "RamInc") +} + +fn append_virtual_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + domain: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", &format!("@{domain}")), + ("commit_domain", &format!("@{domain}")), + ("visibility", r#""virtual""#), + ("layout", r#""virtual""#), + ], + ) +} + +fn append_committed_trace_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", "@jolt.trace_domain"), + ("commit_domain", "@jolt.main_witness_commit_domain"), + ("visibility", r#""committed""#), + ("layout", r#""dense_trace""#), + ], + ) +} + +fn append_stage4_relations<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage4.registers_read_write", + kind: "sumcheck", + domain: "jolt.stage4_registers_rw_domain", + num_rounds: stage4_registers_rw_rounds(params), + degree: REGISTERS_RW_DEGREE, + output_count: STAGE4_REGISTER_OUTPUTS.len(), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage4.ram_val_check", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: RAM_VAL_CHECK_DEGREE, + output_count: STAGE4_RAM_VAL_OUTPUTS.len(), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage4.batched", + kind: "batched_sumcheck", + domain: "jolt.stage4_registers_rw_domain", + num_rounds: stage4_registers_rw_rounds(params), + degree: STAGE4_BATCHED_DEGREE, + output_count: stage4_output_count(), + }, + ) +} + +fn append_relation<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + spec: RelationSpec<'_>, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.relation", + Some(spec.symbol), + &[ + ("kind", &format!("\"{}\"", spec.kind)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("output_count", &int_attr(spec.output_count)), + ], + ) +} + +fn append_stage4_opening_inputs<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result, MlirError> { + Ok(Stage4OpeningInputs { + rd_write_value: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage4.input.stage3.registers.RdWriteValue", + source_stage: "stage3", + source_claim: "stage3.registers_claim_reduction.opening.RdWriteValue", + oracle: "RdWriteValue", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + rs1_registers: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage4.input.stage3.registers.Rs1Value", + source_stage: "stage3", + source_claim: "stage3.registers_claim_reduction.opening.Rs1Value", + oracle: "Rs1Value", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + rs2_registers: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage4.input.stage3.registers.Rs2Value", + source_stage: "stage3", + source_claim: "stage3.registers_claim_reduction.opening.Rs2Value", + oracle: "Rs2Value", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + rs1_instruction: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage4.input.stage3.instruction.Rs1Value", + source_stage: "stage3", + source_claim: "stage3.instruction_input.opening.Rs1Value", + oracle: "Rs1Value", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + rs2_instruction: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage4.input.stage3.instruction.Rs2Value", + source_stage: "stage3", + source_claim: "stage3.instruction_input.opening.Rs2Value", + oracle: "Rs2Value", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + ram_val: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage4.input.stage2.RamVal", + source_stage: "stage2", + source_claim: "stage2.ram_read_write.opening.RamVal", + oracle: "RamVal", + domain: "jolt.stage2_ram_rw_domain", + point_arity: params.log_k_ram + params.log_t, + }, + )?, + ram_val_final: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage4.input.stage2.RamValFinal", + source_stage: "stage2", + source_claim: "stage2.ram_output.opening.RamValFinal", + oracle: "RamValFinal", + domain: "jolt.ram_address_domain", + point_arity: params.log_k_ram, + }, + )?, + ram_val_init: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage4.input.initial_ram.RamValInit", + source_stage: "stage4_precomputed", + source_claim: "stage4.ram_val_check.initial_ram_eval", + oracle: "RamValInit", + domain: "jolt.ram_address_domain", + point_arity: params.log_k_ram, + }, + )?, + }) +} + +fn append_stage_input<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: StageOpeningInputSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_input", + Some(spec.symbol), + &[ + ("source_stage", &format!("@{}", spec.source_stage)), + ("source_claim", &format!("@{}", spec.source_claim)), + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", r#""virtual""#), + ], + &[], + &["!poly.point", "!field.scalar", "!piop.opening_claim_type"], + )?; + Ok(Stage4OpeningInput { + point: result(op, 0, "piop.opening_input")?, + eval: result(op, 1, "piop.opening_input")?, + claim: result(op, 2, "piop.opening_input")?, + }) +} + +fn append_transcript_squeeze<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + symbol: &str, + label: &str, + kind: &str, + count: usize, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "transcript.squeeze", + Some(symbol), + &[ + ("label", &format!("\"{label}\"")), + ("kind", &format!("\"{kind}\"")), + ("count", &int_attr(count)), + ], + &[state], + &[ + "!transcript.state_type", + transcript_squeeze_protocol_result_type(kind)?, + ], + )?; + Ok(( + result(op, 0, "transcript.squeeze")?, + result(op, 1, "transcript.squeeze")?, + )) +} + +fn append_transcript_absorb_bytes<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + symbol: &str, + label: &str, + payload: &str, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "transcript.absorb_bytes", + Some(symbol), + &[ + ("label", &format!("\"{label}\"")), + ("payload", &format!("\"{payload}\"")), + ], + &[state], + &["!transcript.state_type"], + )?; + first_result(op, "transcript.absorb_bytes") +} + +fn append_stage4_batched_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + spec: Stage4BatchedSumcheckInputs<'c, 'a, '_>, +) -> Result, MlirError> { + let inputs = spec.openings; + append_opening_claim_equal( + context, + module, + "stage4.registers.rs1_claim_consistency", + inputs.rs1_registers.claim, + inputs.rs1_instruction.claim, + )?; + append_opening_claim_equal( + context, + module, + "stage4.registers.rs2_claim_consistency", + inputs.rs2_registers.claim, + inputs.rs2_instruction.claim, + )?; + let registers_gamma2 = append_field_pow( + context, + module, + "stage4.registers_read_write.gamma2", + spec.registers_gamma, + 2, + )?; + let rs1_term = append_field_mul( + context, + module, + "stage4.registers_read_write.term.Rs1Value", + spec.registers_gamma, + inputs.rs1_registers.eval, + )?; + let rs2_term = append_field_mul( + context, + module, + "stage4.registers_read_write.term.Rs2Value", + registers_gamma2, + inputs.rs2_registers.eval, + )?; + let registers_sum = append_field_add( + context, + module, + "stage4.registers_read_write.partial.RdWriteValueRs1Value", + inputs.rd_write_value.eval, + rs1_term, + )?; + let registers_claim = append_field_add( + context, + module, + "stage4.registers_read_write.claim_expr", + registers_sum, + rs2_term, + )?; + + let ram_val_delta = append_field_sub( + context, + module, + "stage4.ram_val_check.delta.RamVal", + inputs.ram_val.eval, + inputs.ram_val_init.eval, + )?; + let ram_final_delta = append_field_sub( + context, + module, + "stage4.ram_val_check.delta.RamValFinal", + inputs.ram_val_final.eval, + inputs.ram_val_init.eval, + )?; + let ram_final_term = append_field_mul( + context, + module, + "stage4.ram_val_check.term.RamValFinal", + spec.ram_val_check_gamma, + ram_final_delta, + )?; + let ram_val_claim = append_field_add( + context, + module, + "stage4.ram_val_check.claim_expr", + ram_val_delta, + ram_final_term, + )?; + + let claims = [ + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage4.registers_read_write.input", + stage: "stage4", + domain: "jolt.stage4_registers_rw_domain", + num_rounds: stage4_registers_rw_rounds(params), + degree: REGISTERS_RW_DEGREE, + claim: "stage4.registers_read_write.weighted_values", + relation: "jolt.stage4.registers_read_write", + }, + registers_claim, + &[ + inputs.rd_write_value.claim, + inputs.rs1_registers.claim, + inputs.rs2_registers.claim, + ], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage4.ram_val_check.input", + stage: "stage4", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: RAM_VAL_CHECK_DEGREE, + claim: "stage4.ram_val_check.weighted_values", + relation: "jolt.stage4.ram_val_check", + }, + ram_val_claim, + &[ + inputs.ram_val.claim, + inputs.ram_val_final.claim, + inputs.ram_val_init.claim, + ], + )?, + ]; + let batch = append_sumcheck_batch( + context, + module, + spec.stage, + &claims, + SumcheckBatchSpec { + symbol: "stage4.batch", + stage: "stage4", + proof_slot: "stage4.sumcheck", + policy: "jolt_core_stage4_aligned", + ordered_claims: &[ + "stage4.registers_read_write.input", + "stage4.ram_val_check.input", + ], + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + round_schedule: format!("[{}, {}]", params.log_t, params.register_log_k), + }, + )?; + let (state, point, result_value) = append_sumcheck( + context, + module, + spec.state, + batch, + SumcheckDriverSpec { + symbol: "stage4.sumcheck", + stage: "stage4", + proof_slot: "stage4.sumcheck", + relation: "jolt.stage4.batched", + policy: "jolt_core_stage4_aligned", + round_schedule: format!("[{}, {}]", params.log_t, params.register_log_k), + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + num_rounds: stage4_registers_rw_rounds(params), + degree: STAGE4_BATCHED_DEGREE, + }, + )?; + let registers = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage4.registers_read_write.instance", + source: "stage4.sumcheck", + claim: "stage4.registers_read_write.input", + relation: "jolt.stage4.registers_read_write", + index: 0, + point_arity: stage4_registers_rw_rounds(params), + num_rounds: stage4_registers_rw_rounds(params), + round_offset: 0, + point_order: "stage4_registers_rw", + degree: REGISTERS_RW_DEGREE, + }, + point, + result_value, + )?; + let ram_val_check = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage4.ram_val_check.instance", + source: "stage4.sumcheck", + claim: "stage4.ram_val_check.input", + relation: "jolt.stage4.ram_val_check", + index: 1, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: params.register_log_k, + point_order: "reverse", + degree: RAM_VAL_CHECK_DEGREE, + }, + point, + result_value, + )?; + append_stage4_output_openings(context, module, params, inputs, registers, ram_val_check)?; + Ok(state) +} + +fn append_stage4_output_openings<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + inputs: &Stage4OpeningInputs<'c, 'a>, + registers: (Value<'c, 'a>, Value<'c, 'a>), + ram_val_check: (Value<'c, 'a>, Value<'c, 'a>), +) -> Result<(), MlirError> { + let mut claims = Vec::new(); + let mut claim_symbols = Vec::new(); + + for (index, &oracle) in ["RegistersVal", "Rs1Ra", "Rs2Ra", "RdWa"] + .iter() + .enumerate() + { + let symbol = format!("stage4.registers_read_write.opening.{oracle}"); + let eval = append_sumcheck_eval( + context, + module, + &format!("stage4.registers_read_write.eval.{oracle}"), + "stage4.sumcheck", + oracle, + index, + registers.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + registers.0, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle, + domain: "jolt.stage4_registers_rw_domain", + point_arity: stage4_registers_rw_rounds(params), + claim_kind: "virtual", + }, + )?); + } + + let rd_inc_point = append_point_slice( + context, + module, + "stage4.registers_read_write.point.RdInc", + "stage4.registers_read_write.instance", + params.register_log_k, + params.log_t, + registers.0, + )?; + let rd_inc_eval = append_sumcheck_eval( + context, + module, + "stage4.registers_read_write.eval.RdInc", + "stage4.sumcheck", + "RdInc", + 4, + registers.1, + )?; + claim_symbols.push("stage4.registers_read_write.opening.RdInc".to_owned()); + claims.push(append_opening_claim( + context, + module, + rd_inc_point, + rd_inc_eval, + OpeningClaimSpec { + symbol: "stage4.registers_read_write.opening.RdInc", + oracle: "RdInc", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?); + + let ram_address_point = append_point_slice( + context, + module, + "stage4.ram_val_check.point.RamAddress", + "stage4.input.stage2.RamVal", + 0, + params.log_k_ram, + inputs.ram_val.point, + )?; + let ram_ra_point = append_point_concat( + context, + module, + "stage4.ram_val_check.point.RamRa", + "address_then_cycle", + params.log_k_ram + params.log_t, + &[ram_address_point, ram_val_check.0], + )?; + let ram_ra_eval = append_sumcheck_eval( + context, + module, + "stage4.ram_val_check.eval.RamRa", + "stage4.sumcheck", + "RamRa", + 0, + ram_val_check.1, + )?; + claim_symbols.push("stage4.ram_val_check.opening.RamRa".to_owned()); + claims.push(append_opening_claim( + context, + module, + ram_ra_point, + ram_ra_eval, + OpeningClaimSpec { + symbol: "stage4.ram_val_check.opening.RamRa", + oracle: "RamRa", + domain: "jolt.stage2_ram_rw_domain", + point_arity: params.log_k_ram + params.log_t, + claim_kind: "virtual", + }, + )?); + + let ram_inc_eval = append_sumcheck_eval( + context, + module, + "stage4.ram_val_check.eval.RamInc", + "stage4.sumcheck", + "RamInc", + 1, + ram_val_check.1, + )?; + claim_symbols.push("stage4.ram_val_check.opening.RamInc".to_owned()); + claims.push(append_opening_claim( + context, + module, + ram_val_check.0, + ram_inc_eval, + OpeningClaimSpec { + symbol: "stage4.ram_val_check.opening.RamInc", + oracle: "RamInc", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?); + + let claim_names = claim_symbols.iter().map(String::as_str).collect::>(); + let _batch = context.append_typed_op( + module, + "piop.opening_batch", + Some("stage4.openings"), + &[ + ("stage", "@stage4"), + ("proof_slot", "@stage4.openings"), + ("policy", r#""jolt_stage4_output_order""#), + ("count", &int_attr(claims.len())), + ("ordered_claims", &symbol_array_attr(&claim_names)), + ], + &claims, + &["!piop.opening_batch_type"], + )?; + Ok(()) +} + +fn append_opening_claim_equal<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + left: Value<'c, '_>, + right: Value<'c, '_>, +) -> Result<(), MlirError> { + let _operation = context.append_typed_op( + module, + "piop.opening_claim_equal", + Some(symbol), + &[("mode", r#""point_and_eval""#)], + &[left, right], + &[], + )?; + Ok(()) +} + +fn append_field_binary<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + op_name: &str, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + op_name, + Some(symbol), + &[], + &[lhs, rhs], + &["!field.scalar"], + )?; + first_result(op, op_name) +} + +fn append_field_add<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.add", symbol, lhs, rhs) +} + +fn append_field_sub<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.sub", symbol, lhs, rhs) +} + +fn append_field_mul<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.mul", symbol, lhs, rhs) +} + +fn append_field_pow<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + base: Value<'c, 'a>, + exponent: usize, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.pow", + Some(symbol), + &[("exponent", &int_attr(exponent))], + &[base], + &["!field.scalar"], + )?; + first_result(op, "field.pow") +} + +fn append_sumcheck_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckClaimSpec<'_>, + input_claim: Value<'c, 'a>, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(inputs.len() + 1); + operands.push(input_claim); + operands.extend_from_slice(inputs); + let op = context.append_typed_op( + module, + "piop.sumcheck_claim", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ], + &operands, + &["!piop.sumcheck_claim_type"], + )?; + first_result(op, "piop.sumcheck_claim") +} + +fn append_sumcheck_batch<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + stage: Value<'c, 'a>, + claims: &[Value<'c, 'a>], + spec: SumcheckBatchSpec<'_>, +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(claims.len() + 1); + operands.push(stage); + operands.extend_from_slice(claims); + let op = context.append_typed_op( + module, + "piop.sumcheck_batch", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("policy", &format!("\"{}\"", spec.policy)), + ("count", &int_attr(spec.ordered_claims.len())), + ("ordered_claims", &symbol_array_attr(spec.ordered_claims)), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("round_schedule", &spec.round_schedule), + ], + &operands, + &["!piop.sumcheck_batch_type"], + )?; + first_result(op, "piop.sumcheck_batch") +} + +fn append_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + batch: Value<'c, 'a>, + spec: SumcheckDriverSpec<'_>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("relation", &format!("@{}", spec.relation)), + ("policy", &format!("\"{}\"", spec.policy)), + ("round_schedule", &spec.round_schedule), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ], + &[state, batch], + &[ + "!transcript.state_type", + "!poly.point", + "!piop.sumcheck_result_type", + "!piop.sumcheck_proof_type", + ], + )?; + Ok(( + result(op, 0, "piop.sumcheck")?, + result(op, 1, "piop.sumcheck")?, + result(op, 2, "piop.sumcheck")?, + )) +} + +fn append_sumcheck_instance_result<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckInstanceResultSpec<'_>, + point: Value<'c, 'a>, + result_value: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_instance_result", + Some(spec.symbol), + &[ + ("source", &format!("@{}", spec.source)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ("index", &int_attr(spec.index)), + ("point_arity", &int_attr(spec.point_arity)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("round_offset", &int_attr(spec.round_offset)), + ("point_order", &format!("\"{}\"", spec.point_order)), + ("degree", &int_attr(spec.degree)), + ], + &[point, result_value], + &["!poly.point", "!piop.sumcheck_result_type"], + )?; + Ok(( + result(op, 0, "piop.sumcheck_instance_result")?, + result(op, 1, "piop.sumcheck_instance_result")?, + )) +} + +fn append_sumcheck_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + oracle: &str, + index: usize, + result_value: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_eval", + Some(symbol), + &[ + ("source", &format!("@{source}")), + ("name", &format!("@{symbol}")), + ("index", &int_attr(index)), + ("oracle", &format!("@{oracle}")), + ], + &[result_value], + &["!field.scalar"], + )?; + first_result(op, "piop.sumcheck_eval") +} + +fn append_opening_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + spec: OpeningClaimSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_claim", + Some(spec.symbol), + &[ + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", &format!("\"{}\"", spec.claim_kind)), + ], + &[point, eval], + &["!piop.opening_claim_type"], + )?; + first_result(op, "piop.opening_claim") +} + +fn append_point_slice<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + offset: usize, + length: usize, + point: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_slice", + Some(symbol), + &[ + ("source", &format!("@{source}")), + ("offset", &int_attr(offset)), + ("length", &int_attr(length)), + ], + &[point], + &["!poly.point"], + )?; + first_result(op, "poly.point_slice") +} + +fn append_point_concat<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + layout: &str, + arity: usize, + points: &[Value<'c, 'a>], +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_concat", + Some(symbol), + &[ + ("layout", &format!("\"{layout}\"")), + ("arity", &int_attr(arity)), + ], + points, + &["!poly.point"], + )?; + first_result(op, "poly.point_concat") +} + +fn first_result<'c, 'a>( + operation: OperationRef<'c, 'a>, + operation_name: &str, +) -> Result, MlirError> { + result(operation, 0, operation_name) +} + +fn result<'c, 'a>( + operation: OperationRef<'c, 'a>, + index: usize, + operation_name: &str, +) -> Result, MlirError> { + operation + .result(index) + .map(Into::into) + .map_err(|_| schema_error(format!("{operation_name} requires result {index}"))) +} + +struct RelationSpec<'a> { + symbol: &'a str, + kind: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + output_count: usize, +} + +struct Stage4OpeningInputs<'c, 'a> { + rd_write_value: Stage4OpeningInput<'c, 'a>, + rs1_registers: Stage4OpeningInput<'c, 'a>, + rs2_registers: Stage4OpeningInput<'c, 'a>, + rs1_instruction: Stage4OpeningInput<'c, 'a>, + rs2_instruction: Stage4OpeningInput<'c, 'a>, + ram_val: Stage4OpeningInput<'c, 'a>, + ram_val_final: Stage4OpeningInput<'c, 'a>, + ram_val_init: Stage4OpeningInput<'c, 'a>, +} + +struct Stage4OpeningInput<'c, 'a> { + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + claim: Value<'c, 'a>, +} + +struct StageOpeningInputSpec<'a> { + symbol: &'a str, + source_stage: &'a str, + source_claim: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, +} + +struct Stage4BatchedSumcheckInputs<'c, 'a, 'b> { + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + openings: &'b Stage4OpeningInputs<'c, 'a>, + registers_gamma: Value<'c, 'a>, + ram_val_check_gamma: Value<'c, 'a>, +} + +struct SumcheckClaimSpec<'a> { + symbol: &'a str, + stage: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + claim: &'a str, + relation: &'a str, +} + +struct SumcheckBatchSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + policy: &'a str, + ordered_claims: &'a [&'a str], + claim_label: &'a str, + round_label: &'a str, + round_schedule: String, +} + +struct SumcheckDriverSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + relation: &'a str, + policy: &'a str, + round_schedule: String, + claim_label: &'a str, + round_label: &'a str, + num_rounds: usize, + degree: usize, +} + +struct SumcheckInstanceResultSpec<'a> { + symbol: &'a str, + source: &'a str, + claim: &'a str, + relation: &'a str, + index: usize, + point_arity: usize, + num_rounds: usize, + round_offset: usize, + point_order: &'a str, + degree: usize, +} + +struct OpeningClaimSpec<'a> { + symbol: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, + claim_kind: &'a str, +} + +fn stage4_registers_rw_rounds(params: &JoltProtocolParams) -> usize { + params.log_t + params.register_log_k +} + +fn stage4_output_count() -> usize { + STAGE4_REGISTER_OUTPUTS.len() + STAGE4_RAM_VAL_OUTPUTS.len() +} + +fn int_attr(value: usize) -> String { + format!("{value} : i64") +} + +fn symbol_array_attr(values: &[&str]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn schema_error(message: impl Into) -> MlirError { + SchemaError::new(message).into() +} diff --git a/crates/bolt/src/protocols/jolt/phases/stage5.rs b/crates/bolt/src/protocols/jolt/phases/stage5.rs new file mode 100644 index 0000000000..3a7b1171e8 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/stage5.rs @@ -0,0 +1,1387 @@ +use melior::ir::operation::OperationRef; +use melior::ir::Value; + +use crate::ir::{BoltModule, Compute, Party, Protocol}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{verify_protocol_schema, SchemaError}; + +use super::super::oracles; +use super::super::params::JoltProtocolParams; +use super::lowering::{lower_party_to_compute, transcript_squeeze_protocol_result_type}; + +const RAM_RA_CLAIM_REDUCTION_DEGREE: usize = 2; +const REGISTERS_VAL_EVALUATION_DEGREE: usize = 3; + +pub fn build_stage5_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.stage5", None); + oracles::append_foundation_ops(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.stage5"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + append_stage5_domains(context, &module, params)?; + append_stage5_oracles(context, &module, params)?; + append_stage5_relations(context, &module, params)?; + let inputs = append_stage5_opening_inputs(context, &module, params)?; + + let fs = context.append_typed_op( + &module, + "transcript.state", + Some("fs_after_stage4"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let state = first_result(fs, "transcript.state")?; + let stage = context.append_typed_op( + &module, + "piop.stage", + Some("stage5"), + &[ + ("name", r#""instruction_ram_and_register_value_reductions""#), + ("order", "5 : i64"), + ("roles", r#"["prover", "verifier"]"#), + ], + &[], + &["!piop.stage_type"], + )?; + let stage = first_result(stage, "piop.stage")?; + + let (state, instruction_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage5.instruction_read_raf.gamma", + "instruction_read_raf_gamma", + "challenge_scalar", + 1, + )?; + let (state, ram_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage5.ram_ra_claim_reduction.gamma", + "ram_ra_claim_reduction_gamma", + "challenge_scalar", + 1, + )?; + let _state = append_stage5_batched_sumcheck( + context, + &module, + params, + Stage5BatchedSumcheckInputs { + state, + stage, + openings: &inputs, + instruction_gamma, + ram_gamma, + }, + )?; + + verify_module(&module)?; + verify_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_stage5_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + lower_party_to_compute(context, module, "jolt.stage5", "jolt.stage5", "stage5") +} + +fn append_stage5_domains<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_domain( + context, + module, + "jolt.stage2_ram_rw_domain", + params.log_k_ram + params.log_t, + )?; + append_domain( + context, + module, + "jolt.stage4_registers_rw_domain", + params.register_log_k + params.log_t, + )?; + append_domain( + context, + module, + "jolt.stage5_instruction_read_raf_domain", + params.instruction_log_k + params.log_t, + )?; + append_domain( + context, + module, + "jolt.stage5_instruction_ra_chunk_domain", + params.lookups_ra_virtual_log_k_chunk + params.log_t, + ) +} + +fn append_domain<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + log_size: usize, +) -> Result<(), MlirError> { + context.append_op( + module, + "poly.domain", + Some(symbol), + &[("field", "@bn254_fr"), ("log_size", &int_attr(log_size))], + ) +} + +fn append_stage5_oracles<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_virtual_oracle(context, module, "LookupOutput", "jolt.trace_domain")?; + append_virtual_oracle(context, module, "LeftLookupOperand", "jolt.trace_domain")?; + append_virtual_oracle(context, module, "RightLookupOperand", "jolt.trace_domain")?; + append_virtual_oracle(context, module, "RamRa", "jolt.stage2_ram_rw_domain")?; + append_virtual_oracle( + context, + module, + "RegistersVal", + "jolt.stage4_registers_rw_domain", + )?; + append_virtual_oracle(context, module, "RdWa", "jolt.stage4_registers_rw_domain")?; + append_committed_trace_oracle(context, module, "RdInc")?; + append_virtual_oracle(context, module, "InstructionRafFlag", "jolt.trace_domain")?; + for index in 0..params.lookup_table_count { + append_virtual_oracle( + context, + module, + &format!("LookupTableFlag_{index}"), + "jolt.trace_domain", + )?; + } + for index in 0..params.instruction_ra_virtual_d { + append_virtual_oracle( + context, + module, + &format!("InstructionRa_{index}"), + "jolt.stage5_instruction_ra_chunk_domain", + )?; + } + Ok(()) +} + +fn append_virtual_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + domain: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", &format!("@{domain}")), + ("commit_domain", &format!("@{domain}")), + ("visibility", r#""virtual""#), + ("layout", r#""virtual""#), + ], + ) +} + +fn append_committed_trace_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", "@jolt.trace_domain"), + ("commit_domain", "@jolt.main_witness_commit_domain"), + ("visibility", r#""committed""#), + ("layout", r#""dense_trace""#), + ], + ) +} + +fn append_stage5_relations<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage5.instruction_read_raf", + kind: "sumcheck", + domain: "jolt.stage5_instruction_read_raf_domain", + num_rounds: stage5_instruction_rounds(params), + degree: instruction_read_raf_degree(params), + output_count: instruction_read_raf_output_count(params), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage5.ram_ra_claim_reduction", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: RAM_RA_CLAIM_REDUCTION_DEGREE, + output_count: 1, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage5.registers_val_evaluation", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: REGISTERS_VAL_EVALUATION_DEGREE, + output_count: 2, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage5.batched", + kind: "batched_sumcheck", + domain: "jolt.stage5_instruction_read_raf_domain", + num_rounds: stage5_instruction_rounds(params), + degree: instruction_read_raf_degree(params), + output_count: stage5_output_count(params), + }, + ) +} + +fn append_relation<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + spec: RelationSpec<'_>, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.relation", + Some(spec.symbol), + &[ + ("kind", &format!("\"{}\"", spec.kind)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("output_count", &int_attr(spec.output_count)), + ], + ) +} + +fn append_stage5_opening_inputs<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result, MlirError> { + Ok(Stage5OpeningInputs { + lookup_output_instruction: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage5.input.stage2.instruction.LookupOutput", + source_stage: "stage2", + source_claim: "stage2.instruction_lookup.claim_reduction.opening.LookupOutput", + oracle: "LookupOutput", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + lookup_output_product: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage5.input.stage2.product_virtual.LookupOutput", + source_stage: "stage2", + source_claim: "stage2.product_virtual.remainder.opening.LookupOutput", + oracle: "LookupOutput", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + left_lookup_operand: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage5.input.stage2.instruction.LeftLookupOperand", + source_stage: "stage2", + source_claim: "stage2.instruction_lookup.claim_reduction.opening.LeftLookupOperand", + oracle: "LeftLookupOperand", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + right_lookup_operand: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage5.input.stage2.instruction.RightLookupOperand", + source_stage: "stage2", + source_claim: + "stage2.instruction_lookup.claim_reduction.opening.RightLookupOperand", + oracle: "RightLookupOperand", + domain: "jolt.trace_domain", + point_arity: params.log_t, + }, + )?, + ram_ra_raf: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage5.input.stage2.ram_raf.RamRa", + source_stage: "stage2", + source_claim: "stage2.ram_raf.opening.RamRa", + oracle: "RamRa", + domain: "jolt.stage2_ram_rw_domain", + point_arity: params.log_k_ram + params.log_t, + }, + )?, + ram_ra_rw: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage5.input.stage2.ram_read_write.RamRa", + source_stage: "stage2", + source_claim: "stage2.ram_read_write.opening.RamRa", + oracle: "RamRa", + domain: "jolt.stage2_ram_rw_domain", + point_arity: params.log_k_ram + params.log_t, + }, + )?, + ram_ra_val: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage5.input.stage4.ram_val_check.RamRa", + source_stage: "stage4", + source_claim: "stage4.ram_val_check.opening.RamRa", + oracle: "RamRa", + domain: "jolt.stage2_ram_rw_domain", + point_arity: params.log_k_ram + params.log_t, + }, + )?, + registers_val: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage5.input.stage4.registers.RegistersVal", + source_stage: "stage4", + source_claim: "stage4.registers_read_write.opening.RegistersVal", + oracle: "RegistersVal", + domain: "jolt.stage4_registers_rw_domain", + point_arity: params.register_log_k + params.log_t, + }, + )?, + }) +} + +fn append_stage_input<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: StageOpeningInputSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_input", + Some(spec.symbol), + &[ + ("source_stage", &format!("@{}", spec.source_stage)), + ("source_claim", &format!("@{}", spec.source_claim)), + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", r#""virtual""#), + ], + &[], + &["!poly.point", "!field.scalar", "!piop.opening_claim_type"], + )?; + Ok(Stage5OpeningInput { + point: result(op, 0, "piop.opening_input")?, + eval: result(op, 1, "piop.opening_input")?, + claim: result(op, 2, "piop.opening_input")?, + }) +} + +fn append_transcript_squeeze<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + symbol: &str, + label: &str, + kind: &str, + count: usize, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "transcript.squeeze", + Some(symbol), + &[ + ("label", &format!("\"{label}\"")), + ("kind", &format!("\"{kind}\"")), + ("count", &int_attr(count)), + ], + &[state], + &[ + "!transcript.state_type", + transcript_squeeze_protocol_result_type(kind)?, + ], + )?; + Ok(( + result(op, 0, "transcript.squeeze")?, + result(op, 1, "transcript.squeeze")?, + )) +} + +fn append_stage5_batched_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + spec: Stage5BatchedSumcheckInputs<'c, 'a, '_>, +) -> Result, MlirError> { + let inputs = spec.openings; + append_opening_claim_equal( + context, + module, + "stage5.instruction.lookup_output_claim_consistency", + inputs.lookup_output_instruction.claim, + inputs.lookup_output_product.claim, + )?; + + let instruction_gamma2 = append_field_pow( + context, + module, + "stage5.instruction_read_raf.gamma2", + spec.instruction_gamma, + 2, + )?; + let left_term = append_field_mul( + context, + module, + "stage5.instruction_read_raf.term.LeftLookupOperand", + spec.instruction_gamma, + inputs.left_lookup_operand.eval, + )?; + let right_term = append_field_mul( + context, + module, + "stage5.instruction_read_raf.term.RightLookupOperand", + instruction_gamma2, + inputs.right_lookup_operand.eval, + )?; + let lookup_left_sum = append_field_add( + context, + module, + "stage5.instruction_read_raf.partial.LookupOutputLeftOperand", + inputs.lookup_output_instruction.eval, + left_term, + )?; + let instruction_claim = append_field_add( + context, + module, + "stage5.instruction_read_raf.claim_expr", + lookup_left_sum, + right_term, + )?; + + let ram_gamma2 = append_field_pow( + context, + module, + "stage5.ram_ra_claim_reduction.gamma2", + spec.ram_gamma, + 2, + )?; + let ram_rw_term = append_field_mul( + context, + module, + "stage5.ram_ra_claim_reduction.term.RamRaReadWrite", + spec.ram_gamma, + inputs.ram_ra_rw.eval, + )?; + let ram_val_term = append_field_mul( + context, + module, + "stage5.ram_ra_claim_reduction.term.RamRaValCheck", + ram_gamma2, + inputs.ram_ra_val.eval, + )?; + let ram_raf_rw_sum = append_field_add( + context, + module, + "stage5.ram_ra_claim_reduction.partial.RafReadWrite", + inputs.ram_ra_raf.eval, + ram_rw_term, + )?; + let ram_claim = append_field_add( + context, + module, + "stage5.ram_ra_claim_reduction.claim_expr", + ram_raf_rw_sum, + ram_val_term, + )?; + + let claims = [ + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage5.instruction_read_raf.input", + stage: "stage5", + domain: "jolt.stage5_instruction_read_raf_domain", + num_rounds: stage5_instruction_rounds(params), + degree: instruction_read_raf_degree(params), + claim: "stage5.instruction_read_raf.weighted_lookup_values", + relation: "jolt.stage5.instruction_read_raf", + }, + instruction_claim, + &[ + inputs.lookup_output_instruction.claim, + inputs.left_lookup_operand.claim, + inputs.right_lookup_operand.claim, + ], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage5.ram_ra_claim_reduction.input", + stage: "stage5", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: RAM_RA_CLAIM_REDUCTION_DEGREE, + claim: "stage5.ram_ra_claim_reduction.weighted_ram_ra", + relation: "jolt.stage5.ram_ra_claim_reduction", + }, + ram_claim, + &[ + inputs.ram_ra_raf.claim, + inputs.ram_ra_rw.claim, + inputs.ram_ra_val.claim, + ], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage5.registers_val_evaluation.input", + stage: "stage5", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: REGISTERS_VAL_EVALUATION_DEGREE, + claim: "stage5.registers_val_evaluation.registers_val", + relation: "jolt.stage5.registers_val_evaluation", + }, + inputs.registers_val.eval, + &[inputs.registers_val.claim], + )?, + ]; + let round_schedule = format!("[{}, {}]", params.instruction_log_k, params.log_t); + let batch = append_sumcheck_batch( + context, + module, + spec.stage, + &claims, + SumcheckBatchSpec { + symbol: "stage5.batch", + stage: "stage5", + proof_slot: "stage5.sumcheck", + policy: "jolt_core_stage5_aligned", + ordered_claims: &[ + "stage5.instruction_read_raf.input", + "stage5.ram_ra_claim_reduction.input", + "stage5.registers_val_evaluation.input", + ], + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + round_schedule: &round_schedule, + }, + )?; + let (state, point, result_value) = append_sumcheck( + context, + module, + spec.state, + batch, + SumcheckDriverSpec { + symbol: "stage5.sumcheck", + stage: "stage5", + proof_slot: "stage5.sumcheck", + relation: "jolt.stage5.batched", + policy: "jolt_core_stage5_aligned", + round_schedule: &round_schedule, + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + num_rounds: stage5_instruction_rounds(params), + degree: instruction_read_raf_degree(params), + }, + )?; + let instruction = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage5.instruction_read_raf.instance", + source: "stage5.sumcheck", + claim: "stage5.instruction_read_raf.input", + relation: "jolt.stage5.instruction_read_raf", + index: 0, + point_arity: stage5_instruction_rounds(params), + num_rounds: stage5_instruction_rounds(params), + round_offset: 0, + point_order: "instruction_read_raf", + degree: instruction_read_raf_degree(params), + }, + point, + result_value, + )?; + let ram = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage5.ram_ra_claim_reduction.instance", + source: "stage5.sumcheck", + claim: "stage5.ram_ra_claim_reduction.input", + relation: "jolt.stage5.ram_ra_claim_reduction", + index: 1, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: params.instruction_log_k, + point_order: "reverse", + degree: RAM_RA_CLAIM_REDUCTION_DEGREE, + }, + point, + result_value, + )?; + let registers = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage5.registers_val_evaluation.instance", + source: "stage5.sumcheck", + claim: "stage5.registers_val_evaluation.input", + relation: "jolt.stage5.registers_val_evaluation", + index: 2, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: params.instruction_log_k, + point_order: "reverse", + degree: REGISTERS_VAL_EVALUATION_DEGREE, + }, + point, + result_value, + )?; + append_stage5_output_openings(context, module, params, inputs, instruction, ram, registers)?; + Ok(state) +} + +fn append_stage5_output_openings<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + inputs: &Stage5OpeningInputs<'c, 'a>, + instruction: (Value<'c, 'a>, Value<'c, 'a>), + ram: (Value<'c, 'a>, Value<'c, 'a>), + registers: (Value<'c, 'a>, Value<'c, 'a>), +) -> Result<(), MlirError> { + let mut claims = Vec::new(); + let mut claim_symbols = Vec::new(); + + let instruction_cycle = append_point_slice( + context, + module, + "stage5.instruction_read_raf.point.Cycle", + "stage5.instruction_read_raf.instance", + params.instruction_log_k, + params.log_t, + instruction.0, + )?; + for index in 0..params.lookup_table_count { + let oracle = format!("LookupTableFlag_{index}"); + let symbol = format!("stage5.instruction_read_raf.opening.{oracle}"); + let eval_symbol = format!("stage5.instruction_read_raf.eval.{oracle}"); + let eval = append_sumcheck_eval( + context, + module, + &eval_symbol, + "stage5.sumcheck", + &oracle, + index, + instruction.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + instruction_cycle, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle: &oracle, + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "virtual", + }, + )?); + } + + for index in 0..params.instruction_ra_virtual_d { + let oracle = format!("InstructionRa_{index}"); + let symbol = format!("stage5.instruction_read_raf.opening.{oracle}"); + let address_chunk = append_point_slice( + context, + module, + &format!("stage5.instruction_read_raf.point.{oracle}.address"), + "stage5.instruction_read_raf.instance", + index * params.lookups_ra_virtual_log_k_chunk, + params.lookups_ra_virtual_log_k_chunk, + instruction.0, + )?; + let ra_point = append_point_concat( + context, + module, + &format!("stage5.instruction_read_raf.point.{oracle}"), + "address_chunk_then_cycle", + params.lookups_ra_virtual_log_k_chunk + params.log_t, + &[address_chunk, instruction_cycle], + )?; + let eval_symbol = format!("stage5.instruction_read_raf.eval.{oracle}"); + let eval = append_sumcheck_eval( + context, + module, + &eval_symbol, + "stage5.sumcheck", + &oracle, + params.lookup_table_count + index, + instruction.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + ra_point, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle: &oracle, + domain: "jolt.stage5_instruction_ra_chunk_domain", + point_arity: params.lookups_ra_virtual_log_k_chunk + params.log_t, + claim_kind: "virtual", + }, + )?); + } + + let raf_flag_eval_index = params.lookup_table_count + params.instruction_ra_virtual_d; + let raf_flag_eval = append_sumcheck_eval( + context, + module, + "stage5.instruction_read_raf.eval.InstructionRafFlag", + "stage5.sumcheck", + "InstructionRafFlag", + raf_flag_eval_index, + instruction.1, + )?; + claim_symbols.push("stage5.instruction_read_raf.opening.InstructionRafFlag".to_owned()); + claims.push(append_opening_claim( + context, + module, + instruction_cycle, + raf_flag_eval, + OpeningClaimSpec { + symbol: "stage5.instruction_read_raf.opening.InstructionRafFlag", + oracle: "InstructionRafFlag", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "virtual", + }, + )?); + + let ram_address = append_point_slice( + context, + module, + "stage5.ram_ra_claim_reduction.point.RamAddress", + "stage5.input.stage2.ram_raf.RamRa", + 0, + params.log_k_ram, + inputs.ram_ra_raf.point, + )?; + let ram_ra_point = append_point_concat( + context, + module, + "stage5.ram_ra_claim_reduction.point.RamRa", + "address_then_cycle", + params.log_k_ram + params.log_t, + &[ram_address, ram.0], + )?; + let ram_ra_eval = append_sumcheck_eval( + context, + module, + "stage5.ram_ra_claim_reduction.eval.RamRa", + "stage5.sumcheck", + "RamRa", + 0, + ram.1, + )?; + claim_symbols.push("stage5.ram_ra_claim_reduction.opening.RamRa".to_owned()); + claims.push(append_opening_claim( + context, + module, + ram_ra_point, + ram_ra_eval, + OpeningClaimSpec { + symbol: "stage5.ram_ra_claim_reduction.opening.RamRa", + oracle: "RamRa", + domain: "jolt.stage2_ram_rw_domain", + point_arity: params.log_k_ram + params.log_t, + claim_kind: "virtual", + }, + )?); + + let rd_inc_eval = append_sumcheck_eval( + context, + module, + "stage5.registers_val_evaluation.eval.RdInc", + "stage5.sumcheck", + "RdInc", + 0, + registers.1, + )?; + claim_symbols.push("stage5.registers_val_evaluation.opening.RdInc".to_owned()); + claims.push(append_opening_claim( + context, + module, + registers.0, + rd_inc_eval, + OpeningClaimSpec { + symbol: "stage5.registers_val_evaluation.opening.RdInc", + oracle: "RdInc", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?); + + let register_address = append_point_slice( + context, + module, + "stage5.registers_val_evaluation.point.RegisterAddress", + "stage5.input.stage4.registers.RegistersVal", + 0, + params.register_log_k, + inputs.registers_val.point, + )?; + let rd_wa_point = append_point_concat( + context, + module, + "stage5.registers_val_evaluation.point.RdWa", + "register_address_then_cycle", + params.register_log_k + params.log_t, + &[register_address, registers.0], + )?; + let rd_wa_eval = append_sumcheck_eval( + context, + module, + "stage5.registers_val_evaluation.eval.RdWa", + "stage5.sumcheck", + "RdWa", + 1, + registers.1, + )?; + claim_symbols.push("stage5.registers_val_evaluation.opening.RdWa".to_owned()); + claims.push(append_opening_claim( + context, + module, + rd_wa_point, + rd_wa_eval, + OpeningClaimSpec { + symbol: "stage5.registers_val_evaluation.opening.RdWa", + oracle: "RdWa", + domain: "jolt.stage4_registers_rw_domain", + point_arity: params.register_log_k + params.log_t, + claim_kind: "virtual", + }, + )?); + + let claim_names = claim_symbols.iter().map(String::as_str).collect::>(); + let _batch = context.append_typed_op( + module, + "piop.opening_batch", + Some("stage5.openings"), + &[ + ("stage", "@stage5"), + ("proof_slot", "@stage5.openings"), + ("policy", r#""jolt_stage5_output_order""#), + ("count", &int_attr(claims.len())), + ("ordered_claims", &symbol_array_attr(&claim_names)), + ], + &claims, + &["!piop.opening_batch_type"], + )?; + Ok(()) +} + +fn append_opening_claim_equal<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + left: Value<'c, '_>, + right: Value<'c, '_>, +) -> Result<(), MlirError> { + let _operation = context.append_typed_op( + module, + "piop.opening_claim_equal", + Some(symbol), + &[("mode", r#""point_and_eval""#)], + &[left, right], + &[], + )?; + Ok(()) +} + +fn append_field_binary<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + op_name: &str, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + op_name, + Some(symbol), + &[], + &[lhs, rhs], + &["!field.scalar"], + )?; + first_result(op, op_name) +} + +fn append_field_add<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.add", symbol, lhs, rhs) +} + +fn append_field_mul<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.mul", symbol, lhs, rhs) +} + +fn append_field_pow<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + base: Value<'c, 'a>, + exponent: usize, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.pow", + Some(symbol), + &[("exponent", &int_attr(exponent))], + &[base], + &["!field.scalar"], + )?; + first_result(op, "field.pow") +} + +fn append_sumcheck_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckClaimSpec<'_>, + input_claim: Value<'c, 'a>, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(inputs.len() + 1); + operands.push(input_claim); + operands.extend_from_slice(inputs); + let op = context.append_typed_op( + module, + "piop.sumcheck_claim", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ], + &operands, + &["!piop.sumcheck_claim_type"], + )?; + first_result(op, "piop.sumcheck_claim") +} + +fn append_sumcheck_batch<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + stage: Value<'c, 'a>, + claims: &[Value<'c, 'a>], + spec: SumcheckBatchSpec<'_>, +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(claims.len() + 1); + operands.push(stage); + operands.extend_from_slice(claims); + let op = context.append_typed_op( + module, + "piop.sumcheck_batch", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("policy", &format!("\"{}\"", spec.policy)), + ("count", &int_attr(spec.ordered_claims.len())), + ("ordered_claims", &symbol_array_attr(spec.ordered_claims)), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("round_schedule", spec.round_schedule), + ], + &operands, + &["!piop.sumcheck_batch_type"], + )?; + first_result(op, "piop.sumcheck_batch") +} + +fn append_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + batch: Value<'c, 'a>, + spec: SumcheckDriverSpec<'_>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("relation", &format!("@{}", spec.relation)), + ("policy", &format!("\"{}\"", spec.policy)), + ("round_schedule", spec.round_schedule), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ], + &[state, batch], + &[ + "!transcript.state_type", + "!poly.point", + "!piop.sumcheck_result_type", + "!piop.sumcheck_proof_type", + ], + )?; + Ok(( + result(op, 0, "piop.sumcheck")?, + result(op, 1, "piop.sumcheck")?, + result(op, 2, "piop.sumcheck")?, + )) +} + +fn append_sumcheck_instance_result<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckInstanceResultSpec<'_>, + point: Value<'c, 'a>, + result_value: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_instance_result", + Some(spec.symbol), + &[ + ("source", &format!("@{}", spec.source)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ("index", &int_attr(spec.index)), + ("point_arity", &int_attr(spec.point_arity)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("round_offset", &int_attr(spec.round_offset)), + ("point_order", &format!("\"{}\"", spec.point_order)), + ("degree", &int_attr(spec.degree)), + ], + &[point, result_value], + &["!poly.point", "!piop.sumcheck_result_type"], + )?; + Ok(( + result(op, 0, "piop.sumcheck_instance_result")?, + result(op, 1, "piop.sumcheck_instance_result")?, + )) +} + +fn append_sumcheck_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + oracle: &str, + index: usize, + result_value: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_eval", + Some(symbol), + &[ + ("source", &format!("@{}", source)), + ("name", &format!("@{}", symbol)), + ("index", &int_attr(index)), + ("oracle", &format!("@{}", oracle)), + ], + &[result_value], + &["!field.scalar"], + )?; + first_result(op, "piop.sumcheck_eval") +} + +fn append_point_slice<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + offset: usize, + length: usize, + input: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_slice", + Some(symbol), + &[ + ("source", &format!("@{}", source)), + ("offset", &int_attr(offset)), + ("length", &int_attr(length)), + ], + &[input], + &["!poly.point"], + )?; + first_result(op, "poly.point_slice") +} + +fn append_point_concat<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + layout: &str, + arity: usize, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_concat", + Some(symbol), + &[ + ("layout", &format!("\"{}\"", layout)), + ("arity", &int_attr(arity)), + ], + inputs, + &["!poly.point"], + )?; + first_result(op, "poly.point_concat") +} + +fn append_opening_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + spec: OpeningClaimSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_claim", + Some(spec.symbol), + &[ + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", &format!("\"{}\"", spec.claim_kind)), + ], + &[point, eval], + &["!piop.opening_claim_type"], + )?; + first_result(op, "piop.opening_claim") +} + +fn stage5_instruction_rounds(params: &JoltProtocolParams) -> usize { + params.instruction_log_k + params.log_t +} + +fn instruction_read_raf_degree(params: &JoltProtocolParams) -> usize { + params.instruction_ra_virtual_d + 2 +} + +fn instruction_read_raf_output_count(params: &JoltProtocolParams) -> usize { + params.lookup_table_count + params.instruction_ra_virtual_d + 1 +} + +fn stage5_output_count(params: &JoltProtocolParams) -> usize { + instruction_read_raf_output_count(params) + 3 +} + +fn int_attr(value: usize) -> String { + format!("{value} : i64") +} + +fn symbol_array_attr(values: &[&str]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn first_result<'c, 'a>( + op: OperationRef<'c, 'a>, + context: &str, +) -> Result, MlirError> { + result(op, 0, context) +} + +fn result<'c, 'a>( + op: OperationRef<'c, 'a>, + index: usize, + context: &str, +) -> Result, MlirError> { + op.result(index) + .map(Into::into) + .map_err(|_| schema_error(format!("{context} expected result {index}"))) +} + +fn schema_error(message: impl Into) -> MlirError { + SchemaError::new(message).into() +} + +struct Stage5OpeningInput<'c, 'a> { + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + claim: Value<'c, 'a>, +} + +struct Stage5OpeningInputs<'c, 'a> { + lookup_output_instruction: Stage5OpeningInput<'c, 'a>, + lookup_output_product: Stage5OpeningInput<'c, 'a>, + left_lookup_operand: Stage5OpeningInput<'c, 'a>, + right_lookup_operand: Stage5OpeningInput<'c, 'a>, + ram_ra_raf: Stage5OpeningInput<'c, 'a>, + ram_ra_rw: Stage5OpeningInput<'c, 'a>, + ram_ra_val: Stage5OpeningInput<'c, 'a>, + registers_val: Stage5OpeningInput<'c, 'a>, +} + +struct StageOpeningInputSpec<'a> { + symbol: &'a str, + source_stage: &'a str, + source_claim: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, +} + +struct Stage5BatchedSumcheckInputs<'c, 'a, 'b> { + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + openings: &'b Stage5OpeningInputs<'c, 'a>, + instruction_gamma: Value<'c, 'a>, + ram_gamma: Value<'c, 'a>, +} + +struct RelationSpec<'a> { + symbol: &'a str, + kind: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + output_count: usize, +} + +struct SumcheckClaimSpec<'a> { + symbol: &'a str, + stage: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + claim: &'a str, + relation: &'a str, +} + +struct SumcheckBatchSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + policy: &'a str, + ordered_claims: &'a [&'a str], + claim_label: &'a str, + round_label: &'a str, + round_schedule: &'a str, +} + +struct SumcheckDriverSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + relation: &'a str, + policy: &'a str, + round_schedule: &'a str, + claim_label: &'a str, + round_label: &'a str, + num_rounds: usize, + degree: usize, +} + +struct SumcheckInstanceResultSpec<'a> { + symbol: &'a str, + source: &'a str, + claim: &'a str, + relation: &'a str, + index: usize, + point_arity: usize, + num_rounds: usize, + round_offset: usize, + point_order: &'a str, + degree: usize, +} + +struct OpeningClaimSpec<'a> { + symbol: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, + claim_kind: &'a str, +} diff --git a/crates/bolt/src/protocols/jolt/phases/stage6.rs b/crates/bolt/src/protocols/jolt/phases/stage6.rs new file mode 100644 index 0000000000..2e474f42c5 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/stage6.rs @@ -0,0 +1,2365 @@ +use std::collections::BTreeSet; + +use melior::ir::operation::OperationRef; +use melior::ir::Value; + +use crate::ir::{BoltModule, Compute, Party, Protocol}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{verify_protocol_schema, SchemaError}; + +use super::super::oracles; +use super::super::params::JoltProtocolParams; +use super::lowering::{lower_party_to_compute, transcript_squeeze_protocol_result_type}; + +const BOOLEANITY_DEGREE: usize = 3; +const HAMMING_BOOLEANITY_DEGREE: usize = 3; +const INC_CLAIM_REDUCTION_DEGREE: usize = 2; + +#[derive(Clone, Copy)] +enum BytecodeStageGamma { + Stage1, + Stage2, + Stage3, + Stage4, + Stage5, +} + +pub fn build_stage6_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.stage6", None); + oracles::append_foundation_ops(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.stage6"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + append_stage6_domains(context, &module, params)?; + append_stage6_oracles(context, &module, params)?; + append_stage6_relations(context, &module, params)?; + let inputs = append_stage6_opening_inputs(context, &module, params)?; + + let fs = context.append_typed_op( + &module, + "transcript.state", + Some("fs_after_stage5"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let state = first_result(fs, "transcript.state")?; + let stage = context.append_typed_op( + &module, + "piop.stage", + Some("stage6"), + &[ + ( + "name", + r#""bytecode_booleanity_and_virtual_address_reductions""#, + ), + ("order", "6 : i64"), + ("roles", r#"["prover", "verifier"]"#), + ], + &[], + &["!piop.stage_type"], + )?; + let stage = first_result(stage, "piop.stage")?; + + let (state, bc_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.bytecode_read_raf.gamma", + "bc_raf_gamma", + "challenge_scalar", + 1, + )?; + let (state, bc_stage1_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.bytecode_read_raf.stage1_gamma", + "bc_raf_stage1_gamma", + "challenge_scalar", + 1, + )?; + let (state, bc_stage2_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.bytecode_read_raf.stage2_gamma", + "bc_raf_stage2_gamma", + "challenge_scalar", + 1, + )?; + let (state, bc_stage3_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.bytecode_read_raf.stage3_gamma", + "bc_raf_stage3_gamma", + "challenge_scalar", + 1, + )?; + let (state, bc_stage4_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.bytecode_read_raf.stage4_gamma", + "bc_raf_stage4_gamma", + "challenge_scalar", + 1, + )?; + let (state, bc_stage5_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.bytecode_read_raf.stage5_gamma", + "bc_raf_stage5_gamma", + "challenge_scalar", + 1, + )?; + let (state, booleanity_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.booleanity.gamma", + "booleanity_gamma", + "challenge_scalar", + 1, + )?; + append_booleanity_power_placeholders(context, &module, params, booleanity_gamma)?; + let (state, inst_ra_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.instruction_ra_virtual.gamma", + "inst_ra_virtual_gamma", + "challenge_scalar", + 1, + )?; + let (state, inc_gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage6.inc_claim_reduction.gamma", + "inc_reduction_gamma", + "challenge_scalar", + 1, + )?; + + let _state = append_stage6_batched_sumcheck( + context, + &module, + params, + Stage6BatchedSumcheckInputs { + state, + stage, + openings: &inputs, + bc_gamma, + bc_stage1_gamma, + bc_stage2_gamma, + bc_stage3_gamma, + bc_stage4_gamma, + bc_stage5_gamma, + inst_ra_gamma, + inc_gamma, + }, + )?; + + verify_module(&module)?; + verify_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_stage6_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + lower_party_to_compute(context, module, "jolt.stage6", "jolt.stage6", "stage6") +} + +fn append_stage6_domains<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_domain( + context, + module, + "jolt.stage2_ram_rw_domain", + params.log_k_ram + params.log_t, + )?; + append_domain( + context, + module, + "jolt.stage4_registers_rw_domain", + params.register_log_k + params.log_t, + )?; + append_domain( + context, + module, + "jolt.stage5_instruction_ra_chunk_domain", + params.lookups_ra_virtual_log_k_chunk + params.log_t, + )?; + append_domain( + context, + module, + "jolt.stage6_bytecode_read_raf_domain", + stage6_max_rounds(params), + )?; + append_domain( + context, + module, + "jolt.stage6_booleanity_domain", + booleanity_rounds(params), + ) +} + +fn append_domain<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + log_size: usize, +) -> Result<(), MlirError> { + context.append_op( + module, + "poly.domain", + Some(symbol), + &[("field", "@bn254_fr"), ("log_size", &int_attr(log_size))], + ) +} + +fn append_stage6_oracles<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + let mut trace_oracles = BTreeSet::new(); + trace_oracles.extend( + [ + "HammingWeight", + "Imm", + "InstructionFlagBranch", + "InstructionFlagIsNoop", + "InstructionFlagLeftOperandIsPC", + "InstructionFlagLeftOperandIsRs1Value", + "InstructionFlagRightOperandIsImm", + "InstructionFlagRightOperandIsRs2Value", + "InstructionRafFlag", + "LookupOutput", + "OpFlagAddOperands", + "OpFlagAdvice", + "OpFlagAssert", + "OpFlagDoNotUpdateUnexpandedPC", + "OpFlagIsCompressed", + "OpFlagIsFirstInSequence", + "OpFlagIsLastInSequence", + "OpFlagJump", + "OpFlagLoad", + "OpFlagMultiplyOperands", + "OpFlagStore", + "OpFlagSubtractOperands", + "OpFlagVirtualInstruction", + "OpFlagWriteLookupOutputToRD", + "PC", + "UnexpandedPC", + ] + .into_iter() + .map(str::to_owned), + ); + for index in 0..params.lookup_table_count { + let _inserted = trace_oracles.insert(format!("LookupTableFlag_{index}")); + } + for oracle in trace_oracles { + append_virtual_oracle(context, module, &oracle, "jolt.trace_domain")?; + } + + append_virtual_oracle(context, module, "RamRa", "jolt.stage2_ram_rw_domain")?; + for oracle in ["RdWa", "Rs1Ra", "Rs2Ra"] { + append_virtual_oracle(context, module, oracle, "jolt.stage4_registers_rw_domain")?; + } + + append_committed_trace_oracle(context, module, "RamInc")?; + append_committed_trace_oracle(context, module, "RdInc")?; + for index in 0..params.instruction_d { + append_committed_main_witness_oracle(context, module, &format!("InstructionRa_{index}"))?; + } + for index in 0..params.bytecode_d { + append_committed_main_witness_oracle(context, module, &format!("BytecodeRa_{index}"))?; + } + for index in 0..params.ram_d { + append_committed_main_witness_oracle(context, module, &format!("RamRa_{index}"))?; + } + Ok(()) +} + +fn append_virtual_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + domain: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", &format!("@{domain}")), + ("commit_domain", &format!("@{domain}")), + ("visibility", r#""virtual""#), + ("layout", r#""virtual""#), + ], + ) +} + +fn append_committed_trace_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", "@jolt.trace_domain"), + ("commit_domain", "@jolt.main_witness_commit_domain"), + ("visibility", r#""committed""#), + ("layout", r#""dense_trace""#), + ], + ) +} + +fn append_committed_main_witness_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", "@jolt.main_witness_commit_domain"), + ("commit_domain", "@jolt.main_witness_commit_domain"), + ("visibility", r#""committed""#), + ("layout", r#""onehot_expanded""#), + ], + ) +} + +fn append_stage6_relations<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage6.bytecode_read_raf", + kind: "sumcheck", + domain: "jolt.stage6_bytecode_read_raf_domain", + num_rounds: stage6_max_rounds(params), + degree: params.bytecode_d + 1, + output_count: params.bytecode_d, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage6.booleanity", + kind: "sumcheck", + domain: "jolt.stage6_booleanity_domain", + num_rounds: booleanity_rounds(params), + degree: BOOLEANITY_DEGREE, + output_count: total_ra_oracles(params), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage6.hamming_booleanity", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: HAMMING_BOOLEANITY_DEGREE, + output_count: 1, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage6.ram_ra_virtual", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: params.ram_d + 1, + output_count: params.ram_d, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage6.instruction_ra_virtual", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: n_committed_per_virtual(params) + 1, + output_count: params.instruction_d, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage6.inc_claim_reduction", + kind: "sumcheck", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: INC_CLAIM_REDUCTION_DEGREE, + output_count: 2, + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage6.batched", + kind: "batched_sumcheck", + domain: "jolt.stage6_bytecode_read_raf_domain", + num_rounds: stage6_max_rounds(params), + degree: stage6_batched_degree(params), + output_count: stage6_output_count(params), + }, + ) +} + +fn append_relation<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + spec: RelationSpec<'_>, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.relation", + Some(spec.symbol), + &[ + ("kind", &format!("\"{}\"", spec.kind)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("output_count", &int_attr(spec.output_count)), + ], + ) +} + +fn append_stage6_opening_inputs<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let mut bytecode_terms = Vec::new(); + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + "stage6.input.stage1.UnexpandedPC", + "stage1", + "stage1.outer_remaining.opening.UnexpandedPC", + "UnexpandedPC", + 0, + Some(BytecodeStageGamma::Stage1), + 0, + ), + )?; + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + "stage6.input.stage1.Imm", + "stage1", + "stage1.outer_remaining.opening.Imm", + "Imm", + 0, + Some(BytecodeStageGamma::Stage1), + 1, + ), + )?; + for (index, oracle) in STAGE1_OP_FLAGS.iter().enumerate() { + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + &format!("stage6.input.stage1.{oracle}"), + "stage1", + &format!("stage1.outer_remaining.opening.{oracle}"), + oracle, + 0, + Some(BytecodeStageGamma::Stage1), + 2 + index, + ), + )?; + } + for (oracle, stage_gamma_power) in [ + ("OpFlagJump", 0), + ("InstructionFlagBranch", 1), + ("OpFlagWriteLookupOutputToRD", 2), + ("OpFlagVirtualInstruction", 3), + ] { + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + &format!("stage6.input.stage2.{oracle}"), + "stage2", + &format!("stage2.product_virtual.remainder.opening.{oracle}"), + oracle, + 1, + Some(BytecodeStageGamma::Stage2), + stage_gamma_power, + ), + )?; + } + for (symbol, source_claim, oracle, stage_gamma_power) in [ + ( + "stage6.input.stage3.instruction_input.Imm", + "stage3.instruction_input.opening.Imm", + "Imm", + 0, + ), + ( + "stage6.input.stage3.spartan_shift.UnexpandedPC", + "stage3.spartan_shift.opening.UnexpandedPC", + "UnexpandedPC", + 1, + ), + ( + "stage6.input.stage3.instruction_input.InstructionFlagLeftOperandIsRs1Value", + "stage3.instruction_input.opening.InstructionFlagLeftOperandIsRs1Value", + "InstructionFlagLeftOperandIsRs1Value", + 2, + ), + ( + "stage6.input.stage3.instruction_input.InstructionFlagLeftOperandIsPC", + "stage3.instruction_input.opening.InstructionFlagLeftOperandIsPC", + "InstructionFlagLeftOperandIsPC", + 3, + ), + ( + "stage6.input.stage3.instruction_input.InstructionFlagRightOperandIsRs2Value", + "stage3.instruction_input.opening.InstructionFlagRightOperandIsRs2Value", + "InstructionFlagRightOperandIsRs2Value", + 4, + ), + ( + "stage6.input.stage3.instruction_input.InstructionFlagRightOperandIsImm", + "stage3.instruction_input.opening.InstructionFlagRightOperandIsImm", + "InstructionFlagRightOperandIsImm", + 5, + ), + ( + "stage6.input.stage3.spartan_shift.InstructionFlagIsNoop", + "stage3.spartan_shift.opening.InstructionFlagIsNoop", + "InstructionFlagIsNoop", + 6, + ), + ( + "stage6.input.stage3.spartan_shift.OpFlagVirtualInstruction", + "stage3.spartan_shift.opening.OpFlagVirtualInstruction", + "OpFlagVirtualInstruction", + 7, + ), + ( + "stage6.input.stage3.spartan_shift.OpFlagIsFirstInSequence", + "stage3.spartan_shift.opening.OpFlagIsFirstInSequence", + "OpFlagIsFirstInSequence", + 8, + ), + ] { + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + symbol, + "stage3", + source_claim, + oracle, + 2, + Some(BytecodeStageGamma::Stage3), + stage_gamma_power, + ), + )?; + } + for (oracle, stage_gamma_power) in [("RdWa", 0), ("Rs1Ra", 1), ("Rs2Ra", 2)] { + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec { + input: StageOpeningInputSpec { + symbol: &format!("stage6.input.stage4.{oracle}"), + source_stage: "stage4", + source_claim: &format!("stage4.registers_read_write.opening.{oracle}"), + oracle, + domain: "jolt.stage4_registers_rw_domain", + point_arity: params.register_log_k + params.log_t, + claim_kind: "virtual", + }, + gamma_power: 3, + stage_gamma: Some(BytecodeStageGamma::Stage4), + stage_gamma_power, + }, + )?; + } + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec { + input: StageOpeningInputSpec { + symbol: "stage6.input.stage5.registers_val_evaluation.RdWa", + source_stage: "stage5", + source_claim: "stage5.registers_val_evaluation.opening.RdWa", + oracle: "RdWa", + domain: "jolt.stage4_registers_rw_domain", + point_arity: params.register_log_k + params.log_t, + claim_kind: "virtual", + }, + gamma_power: 4, + stage_gamma: Some(BytecodeStageGamma::Stage5), + stage_gamma_power: 0, + }, + )?; + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + "stage6.input.stage5.InstructionRafFlag", + "stage5", + "stage5.instruction_read_raf.opening.InstructionRafFlag", + "InstructionRafFlag", + 4, + Some(BytecodeStageGamma::Stage5), + 1, + ), + )?; + for index in 0..params.lookup_table_count { + let oracle = format!("LookupTableFlag_{index}"); + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + &format!("stage6.input.stage5.{oracle}"), + "stage5", + &format!("stage5.instruction_read_raf.opening.{oracle}"), + &oracle, + 4, + Some(BytecodeStageGamma::Stage5), + 2 + index, + ), + )?; + } + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + "stage6.input.stage1.PC", + "stage1", + "stage1.outer_remaining.opening.PC", + "PC", + 5, + None, + 0, + ), + )?; + append_bytecode_term( + context, + module, + params, + &mut bytecode_terms, + BytecodeTermSpec::trace( + "stage6.input.stage3.spartan_shift.PC", + "stage3", + "stage3.spartan_shift.opening.PC", + "PC", + 6, + None, + 0, + ), + )?; + + let ram_ra_virtual = append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage6.input.stage5.ram_ra_claim_reduction.RamRa", + source_stage: "stage5", + source_claim: "stage5.ram_ra_claim_reduction.opening.RamRa", + oracle: "RamRa", + domain: "jolt.stage2_ram_rw_domain", + point_arity: params.log_k_ram + params.log_t, + claim_kind: "virtual", + }, + )?; + + let mut instruction_ra_virtual = Vec::with_capacity(params.instruction_ra_virtual_d); + for index in 0..params.instruction_ra_virtual_d { + instruction_ra_virtual.push(append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: &format!("stage6.input.stage5.instruction_read_raf.InstructionRa_{index}"), + source_stage: "stage5", + source_claim: &format!("stage5.instruction_read_raf.opening.InstructionRa_{index}"), + oracle: &format!("InstructionRa_{index}"), + domain: "jolt.stage5_instruction_ra_chunk_domain", + point_arity: params.lookups_ra_virtual_log_k_chunk + params.log_t, + claim_kind: "virtual", + }, + )?); + } + + Ok(Stage6OpeningInputs { + bytecode_terms, + hamming_lookup_output: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage6.input.stage1.LookupOutput", + source_stage: "stage1", + source_claim: "stage1.outer_remaining.opening.LookupOutput", + oracle: "LookupOutput", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "virtual", + }, + )?, + ram_ra_virtual, + instruction_ra_virtual, + ram_inc_stage2: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage6.input.stage2.ram_read_write.RamInc", + source_stage: "stage2", + source_claim: "stage2.ram_read_write.opening.RamInc", + oracle: "RamInc", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?, + ram_inc_stage4: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage6.input.stage4.ram_val_check.RamInc", + source_stage: "stage4", + source_claim: "stage4.ram_val_check.opening.RamInc", + oracle: "RamInc", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?, + rd_inc_stage4: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage6.input.stage4.registers_read_write.RdInc", + source_stage: "stage4", + source_claim: "stage4.registers_read_write.opening.RdInc", + oracle: "RdInc", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?, + rd_inc_stage5: append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage6.input.stage5.registers_val_evaluation.RdInc", + source_stage: "stage5", + source_claim: "stage5.registers_val_evaluation.opening.RdInc", + oracle: "RdInc", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?, + }) +} + +fn append_bytecode_term<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + terms: &mut Vec>, + mut spec: BytecodeTermSpec<'_>, +) -> Result<(), MlirError> { + if spec.input.point_arity == 0 { + spec.input.point_arity = params.log_t; + } + let input = append_stage_input(context, module, spec.input)?; + terms.push(Stage6BytecodeTerm { + eval: input.eval, + claim: input.claim, + gamma_power: spec.gamma_power, + stage_gamma: spec.stage_gamma, + stage_gamma_power: spec.stage_gamma_power, + }); + Ok(()) +} + +fn append_stage_input<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: StageOpeningInputSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_input", + Some(spec.symbol), + &[ + ("source_stage", &format!("@{}", spec.source_stage)), + ("source_claim", &format!("@{}", spec.source_claim)), + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", &format!("\"{}\"", spec.claim_kind)), + ], + &[], + &["!poly.point", "!field.scalar", "!piop.opening_claim_type"], + )?; + Ok(Stage6OpeningInput { + point: result(op, 0, "piop.opening_input")?, + eval: result(op, 1, "piop.opening_input")?, + claim: result(op, 2, "piop.opening_input")?, + }) +} + +fn append_transcript_squeeze<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + symbol: &str, + label: &str, + kind: &str, + count: usize, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "transcript.squeeze", + Some(symbol), + &[ + ("label", &format!("\"{label}\"")), + ("kind", &format!("\"{kind}\"")), + ("count", &int_attr(count)), + ], + &[state], + &[ + "!transcript.state_type", + transcript_squeeze_protocol_result_type(kind)?, + ], + )?; + Ok(( + result(op, 0, "transcript.squeeze")?, + result(op, 1, "transcript.squeeze")?, + )) +} + +fn append_booleanity_power_placeholders<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + booleanity_gamma: Value<'c, 'a>, +) -> Result<(), MlirError> { + let total = total_ra_oracles(params); + for index in 0..total { + let _ = append_field_pow( + context, + module, + &format!("stage6.booleanity.gamma_sq_{index}"), + booleanity_gamma, + 2 * index, + )?; + } + for index in 0..total { + let _ = append_field_pow( + context, + module, + &format!("stage6.booleanity.gamma_pow_{index}"), + booleanity_gamma, + index, + )?; + } + Ok(()) +} + +fn append_stage6_batched_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + spec: Stage6BatchedSumcheckInputs<'c, 'a, '_>, +) -> Result, MlirError> { + let inputs = spec.openings; + let bytecode_claim = append_bytecode_read_raf_claim(context, module, inputs, &spec)?; + let zero = append_field_zero(context, module, "stage6.zero")?; + let ram_ra_virtual_claim = inputs.ram_ra_virtual.eval; + let inst_ra_virtual_claim = + append_instruction_ra_virtual_claim(context, module, inputs, spec.inst_ra_gamma)?; + let inc_claim = append_inc_claim_reduction_claim(context, module, inputs, spec.inc_gamma)?; + + let bytecode_inputs = inputs + .bytecode_terms + .iter() + .map(|term| term.claim) + .collect::>(); + let instruction_inputs = inputs + .instruction_ra_virtual + .iter() + .map(|input| input.claim) + .collect::>(); + let inc_inputs = [ + inputs.ram_inc_stage2.claim, + inputs.ram_inc_stage4.claim, + inputs.rd_inc_stage4.claim, + inputs.rd_inc_stage5.claim, + ]; + let claims = [ + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage6.bytecode_read_raf.input", + stage: "stage6", + domain: "jolt.stage6_bytecode_read_raf_domain", + num_rounds: stage6_max_rounds(params), + degree: params.bytecode_d + 1, + claim: "stage6.bytecode_read_raf.weighted_prior_stage_values", + relation: "jolt.stage6.bytecode_read_raf", + }, + bytecode_claim, + &bytecode_inputs, + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage6.booleanity.input", + stage: "stage6", + domain: "jolt.stage6_booleanity_domain", + num_rounds: booleanity_rounds(params), + degree: BOOLEANITY_DEGREE, + claim: "stage6.booleanity.zero", + relation: "jolt.stage6.booleanity", + }, + zero, + &[], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage6.hamming_booleanity.input", + stage: "stage6", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: HAMMING_BOOLEANITY_DEGREE, + claim: "stage6.hamming_booleanity.zero", + relation: "jolt.stage6.hamming_booleanity", + }, + zero, + &[inputs.hamming_lookup_output.claim], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage6.ram_ra_virtual.input", + stage: "stage6", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: params.ram_d + 1, + claim: "stage6.ram_ra_virtual.weighted_ram_ra", + relation: "jolt.stage6.ram_ra_virtual", + }, + ram_ra_virtual_claim, + &[inputs.ram_ra_virtual.claim], + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage6.instruction_ra_virtual.input", + stage: "stage6", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: n_committed_per_virtual(params) + 1, + claim: "stage6.instruction_ra_virtual.weighted_instruction_ra", + relation: "jolt.stage6.instruction_ra_virtual", + }, + inst_ra_virtual_claim, + &instruction_inputs, + )?, + append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage6.inc_claim_reduction.input", + stage: "stage6", + domain: "jolt.trace_domain", + num_rounds: params.log_t, + degree: INC_CLAIM_REDUCTION_DEGREE, + claim: "stage6.inc_claim_reduction.weighted_increments", + relation: "jolt.stage6.inc_claim_reduction", + }, + inc_claim, + &inc_inputs, + )?, + ]; + let round_schedule = format!("[{}, {}]", params.log_k_bytecode, params.log_t); + let batch = append_sumcheck_batch( + context, + module, + spec.stage, + &claims, + SumcheckBatchSpec { + symbol: "stage6.batch", + stage: "stage6", + proof_slot: "stage6.sumcheck", + policy: "jolt_core_stage6_aligned", + ordered_claims: &[ + "stage6.bytecode_read_raf.input", + "stage6.booleanity.input", + "stage6.hamming_booleanity.input", + "stage6.ram_ra_virtual.input", + "stage6.instruction_ra_virtual.input", + "stage6.inc_claim_reduction.input", + ], + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + round_schedule: &round_schedule, + }, + )?; + let (state, point, result_value) = append_sumcheck( + context, + module, + spec.state, + batch, + SumcheckDriverSpec { + symbol: "stage6.sumcheck", + stage: "stage6", + proof_slot: "stage6.sumcheck", + relation: "jolt.stage6.batched", + policy: "jolt_core_stage6_aligned", + round_schedule: &round_schedule, + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + num_rounds: stage6_max_rounds(params), + degree: stage6_batched_degree(params), + }, + )?; + let bytecode = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage6.bytecode_read_raf.instance", + source: "stage6.sumcheck", + claim: "stage6.bytecode_read_raf.input", + relation: "jolt.stage6.bytecode_read_raf", + index: 0, + point_arity: stage6_max_rounds(params), + num_rounds: stage6_max_rounds(params), + round_offset: 0, + point_order: "bytecode_read_raf", + degree: params.bytecode_d + 1, + }, + point, + result_value, + )?; + let booleanity = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage6.booleanity.instance", + source: "stage6.sumcheck", + claim: "stage6.booleanity.input", + relation: "jolt.stage6.booleanity", + index: 1, + point_arity: booleanity_rounds(params), + num_rounds: booleanity_rounds(params), + round_offset: params.log_k_bytecode.saturating_sub(params.log_k_chunk), + point_order: "stage6_booleanity", + degree: BOOLEANITY_DEGREE, + }, + point, + result_value, + )?; + let hamming = append_stage6_trace_instance_result( + context, + module, + params, + point, + result_value, + Stage6TraceInstanceSpec { + symbol: "stage6.hamming_booleanity.instance", + claim: "stage6.hamming_booleanity.input", + relation: "jolt.stage6.hamming_booleanity", + index: 2, + degree: HAMMING_BOOLEANITY_DEGREE, + }, + )?; + let ram = append_stage6_trace_instance_result( + context, + module, + params, + point, + result_value, + Stage6TraceInstanceSpec { + symbol: "stage6.ram_ra_virtual.instance", + claim: "stage6.ram_ra_virtual.input", + relation: "jolt.stage6.ram_ra_virtual", + index: 3, + degree: params.ram_d + 1, + }, + )?; + let instruction = append_stage6_trace_instance_result( + context, + module, + params, + point, + result_value, + Stage6TraceInstanceSpec { + symbol: "stage6.instruction_ra_virtual.instance", + claim: "stage6.instruction_ra_virtual.input", + relation: "jolt.stage6.instruction_ra_virtual", + index: 4, + degree: n_committed_per_virtual(params) + 1, + }, + )?; + let inc = append_stage6_trace_instance_result( + context, + module, + params, + point, + result_value, + Stage6TraceInstanceSpec { + symbol: "stage6.inc_claim_reduction.instance", + claim: "stage6.inc_claim_reduction.input", + relation: "jolt.stage6.inc_claim_reduction", + index: 5, + degree: INC_CLAIM_REDUCTION_DEGREE, + }, + )?; + append_stage6_output_openings( + context, + module, + params, + inputs, + bytecode, + booleanity, + hamming, + ram, + instruction, + inc, + )?; + Ok(state) +} + +fn append_bytecode_read_raf_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + inputs: &Stage6OpeningInputs<'c, 'a>, + spec: &Stage6BatchedSumcheckInputs<'c, 'a, '_>, +) -> Result, MlirError> { + let mut terms = Vec::with_capacity(inputs.bytecode_terms.len() + 1); + for (index, term) in inputs.bytecode_terms.iter().enumerate() { + terms.push(append_weighted_eval( + context, + module, + &format!("stage6.bytecode_read_raf.claim.term{index}"), + term.eval, + WeightedEvalSpec { + gamma: spec.bc_gamma, + gamma_power: term.gamma_power, + stage_gamma: term.stage_gamma.map(|gamma| stage_gamma_value(gamma, spec)), + stage_gamma_power: term.stage_gamma_power, + }, + )?); + } + terms.push(append_field_pow( + context, + module, + "stage6.bytecode_read_raf.claim.entry_constant", + spec.bc_gamma, + 7, + )?); + append_field_sum( + context, + module, + "stage6.bytecode_read_raf.claim_expr", + &terms, + ) +} + +fn append_instruction_ra_virtual_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + inputs: &Stage6OpeningInputs<'c, 'a>, + gamma: Value<'c, 'a>, +) -> Result, MlirError> { + let mut terms = Vec::with_capacity(inputs.instruction_ra_virtual.len()); + for (index, input) in inputs.instruction_ra_virtual.iter().enumerate() { + terms.push(append_weighted_eval( + context, + module, + &format!("stage6.instruction_ra_virtual.claim.term{index}"), + input.eval, + WeightedEvalSpec { + gamma, + gamma_power: index, + stage_gamma: None, + stage_gamma_power: 0, + }, + )?); + } + append_field_sum( + context, + module, + "stage6.instruction_ra_virtual.claim_expr", + &terms, + ) +} + +fn append_inc_claim_reduction_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + inputs: &Stage6OpeningInputs<'c, 'a>, + gamma: Value<'c, 'a>, +) -> Result, MlirError> { + let terms = [ + inputs.ram_inc_stage2.eval, + append_weighted_eval( + context, + module, + "stage6.inc_claim_reduction.claim.ram_inc_stage4", + inputs.ram_inc_stage4.eval, + WeightedEvalSpec { + gamma, + gamma_power: 1, + stage_gamma: None, + stage_gamma_power: 0, + }, + )?, + append_weighted_eval( + context, + module, + "stage6.inc_claim_reduction.claim.rd_inc_stage4", + inputs.rd_inc_stage4.eval, + WeightedEvalSpec { + gamma, + gamma_power: 2, + stage_gamma: None, + stage_gamma_power: 0, + }, + )?, + append_weighted_eval( + context, + module, + "stage6.inc_claim_reduction.claim.rd_inc_stage5", + inputs.rd_inc_stage5.eval, + WeightedEvalSpec { + gamma, + gamma_power: 3, + stage_gamma: None, + stage_gamma_power: 0, + }, + )?, + ]; + append_field_sum( + context, + module, + "stage6.inc_claim_reduction.claim_expr", + &terms, + ) +} + +fn append_weighted_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol_prefix: &str, + eval: Value<'c, 'a>, + spec: WeightedEvalSpec<'c, 'a>, +) -> Result, MlirError> { + let mut value = eval; + if spec.stage_gamma_power > 0 { + let power = append_field_pow( + context, + module, + &format!("{symbol_prefix}.stage_gamma_pow"), + spec.stage_gamma.ok_or_else(|| MlirError::Schema { + message: format!( + "{symbol_prefix} requires stage gamma when stage_gamma_power is non-zero" + ), + })?, + spec.stage_gamma_power, + )?; + value = append_field_mul( + context, + module, + &format!("{symbol_prefix}.stage_gamma_term"), + power, + value, + )?; + } + if spec.gamma_power > 0 { + let power = append_field_pow( + context, + module, + &format!("{symbol_prefix}.gamma_pow"), + spec.gamma, + spec.gamma_power, + )?; + value = append_field_mul( + context, + module, + &format!("{symbol_prefix}.gamma_term"), + power, + value, + )?; + } + Ok(value) +} + +fn stage_gamma_value<'c, 'a>( + gamma: BytecodeStageGamma, + spec: &Stage6BatchedSumcheckInputs<'c, 'a, '_>, +) -> Value<'c, 'a> { + match gamma { + BytecodeStageGamma::Stage1 => spec.bc_stage1_gamma, + BytecodeStageGamma::Stage2 => spec.bc_stage2_gamma, + BytecodeStageGamma::Stage3 => spec.bc_stage3_gamma, + BytecodeStageGamma::Stage4 => spec.bc_stage4_gamma, + BytecodeStageGamma::Stage5 => spec.bc_stage5_gamma, + } +} + +fn append_stage6_trace_instance_result<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + point: Value<'c, 'a>, + result_value: Value<'c, 'a>, + spec: Stage6TraceInstanceSpec<'_>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: spec.symbol, + source: "stage6.sumcheck", + claim: spec.claim, + relation: spec.relation, + index: spec.index, + point_arity: params.log_t, + num_rounds: params.log_t, + round_offset: params.log_k_bytecode, + point_order: "reverse", + degree: spec.degree, + }, + point, + result_value, + ) +} + +#[expect(clippy::too_many_arguments)] +fn append_stage6_output_openings<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + inputs: &Stage6OpeningInputs<'c, 'a>, + bytecode: (Value<'c, 'a>, Value<'c, 'a>), + booleanity: (Value<'c, 'a>, Value<'c, 'a>), + hamming: (Value<'c, 'a>, Value<'c, 'a>), + ram: (Value<'c, 'a>, Value<'c, 'a>), + instruction: (Value<'c, 'a>, Value<'c, 'a>), + inc: (Value<'c, 'a>, Value<'c, 'a>), +) -> Result<(), MlirError> { + let mut claims = Vec::new(); + let mut claim_symbols = Vec::new(); + + let bytecode_cycle = append_point_slice( + context, + module, + "stage6.bytecode_read_raf.point.Cycle", + "stage6.bytecode_read_raf.instance", + params.log_k_bytecode, + params.log_t, + bytecode.0, + )?; + for index in 0..params.bytecode_d { + let oracle = format!("BytecodeRa_{index}"); + let eval_symbol = format!("stage6.bytecode_read_raf.eval.{oracle}"); + let eval = append_sumcheck_eval( + context, + module, + &eval_symbol, + "stage6.sumcheck", + &oracle, + index, + bytecode.1, + )?; + let address = append_padded_address_chunk( + context, + module, + &format!("stage6.bytecode_read_raf.point.{oracle}.address"), + "stage6.bytecode_read_raf.instance", + params.log_k_bytecode, + index, + params.log_k_chunk, + bytecode.0, + )?; + let point = append_point_concat( + context, + module, + &format!("stage6.bytecode_read_raf.point.{oracle}"), + "address_chunk_then_cycle", + params.log_k_chunk + params.log_t, + &[address, bytecode_cycle], + )?; + let symbol = format!("stage6.bytecode_read_raf.opening.{oracle}"); + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + point, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle: &oracle, + domain: "jolt.main_witness_commit_domain", + point_arity: params.log_k_chunk + params.log_t, + claim_kind: "committed", + }, + )?); + } + + let mut eval_index = 0; + for index in 0..params.instruction_d { + append_booleanity_output_opening( + context, + module, + params, + &mut claims, + &mut claim_symbols, + booleanity, + &format!("InstructionRa_{index}"), + eval_index, + )?; + eval_index += 1; + } + for index in 0..params.bytecode_d { + append_booleanity_output_opening( + context, + module, + params, + &mut claims, + &mut claim_symbols, + booleanity, + &format!("BytecodeRa_{index}"), + eval_index, + )?; + eval_index += 1; + } + for index in 0..params.ram_d { + append_booleanity_output_opening( + context, + module, + params, + &mut claims, + &mut claim_symbols, + booleanity, + &format!("RamRa_{index}"), + eval_index, + )?; + eval_index += 1; + } + + let hamming_eval = append_sumcheck_eval( + context, + module, + "stage6.hamming_booleanity.eval.HammingWeight", + "stage6.sumcheck", + "HammingWeight", + 0, + hamming.1, + )?; + claim_symbols.push("stage6.hamming_booleanity.opening.HammingWeight".to_owned()); + claims.push(append_opening_claim( + context, + module, + hamming.0, + hamming_eval, + OpeningClaimSpec { + symbol: "stage6.hamming_booleanity.opening.HammingWeight", + oracle: "HammingWeight", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "virtual", + }, + )?); + + for index in 0..params.ram_d { + let oracle = format!("RamRa_{index}"); + let symbol = format!("stage6.ram_ra_virtual.opening.{oracle}"); + let address = append_padded_address_chunk( + context, + module, + &format!("stage6.ram_ra_virtual.point.{oracle}.address"), + "stage6.input.stage5.ram_ra_claim_reduction.RamRa", + params.log_k_ram, + index, + params.log_k_chunk, + inputs.ram_ra_virtual.point, + )?; + let point = append_point_concat( + context, + module, + &format!("stage6.ram_ra_virtual.point.{oracle}"), + "address_chunk_then_cycle", + params.log_k_chunk + params.log_t, + &[address, ram.0], + )?; + let eval = append_sumcheck_eval( + context, + module, + &format!("stage6.ram_ra_virtual.eval.{oracle}"), + "stage6.sumcheck", + &oracle, + index, + ram.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + point, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle: &oracle, + domain: "jolt.main_witness_commit_domain", + point_arity: params.log_k_chunk + params.log_t, + claim_kind: "committed", + }, + )?); + } + for index in 0..params.instruction_d { + let oracle = format!("InstructionRa_{index}"); + let symbol = format!("stage6.instruction_ra_virtual.opening.{oracle}"); + let virtual_index = index / n_committed_per_virtual(params); + let chunk_index = index % n_committed_per_virtual(params); + let virtual_input = inputs.instruction_ra_virtual[virtual_index].point; + let address = append_padded_address_chunk( + context, + module, + &format!("stage6.instruction_ra_virtual.point.{oracle}.address"), + &format!("stage6.input.stage5.instruction_read_raf.InstructionRa_{virtual_index}"), + params.lookups_ra_virtual_log_k_chunk, + chunk_index, + params.log_k_chunk, + virtual_input, + )?; + let point = append_point_concat( + context, + module, + &format!("stage6.instruction_ra_virtual.point.{oracle}"), + "address_chunk_then_cycle", + params.log_k_chunk + params.log_t, + &[address, instruction.0], + )?; + let eval = append_sumcheck_eval( + context, + module, + &format!("stage6.instruction_ra_virtual.eval.{oracle}"), + "stage6.sumcheck", + &oracle, + index, + instruction.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + point, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle: &oracle, + domain: "jolt.main_witness_commit_domain", + point_arity: params.log_k_chunk + params.log_t, + claim_kind: "committed", + }, + )?); + } + + for (index, oracle) in ["RamInc", "RdInc"].iter().enumerate() { + let symbol = format!("stage6.inc_claim_reduction.opening.{oracle}"); + let eval = append_sumcheck_eval( + context, + module, + &format!("stage6.inc_claim_reduction.eval.{oracle}"), + "stage6.sumcheck", + oracle, + index, + inc.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + inc.0, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle, + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "committed", + }, + )?); + } + + let claim_names = claim_symbols.iter().map(String::as_str).collect::>(); + let _batch = context.append_typed_op( + module, + "piop.opening_batch", + Some("stage6.openings"), + &[ + ("stage", "@stage6"), + ("proof_slot", "@stage6.openings"), + ("policy", r#""jolt_stage6_output_order""#), + ("count", &int_attr(claims.len())), + ("ordered_claims", &symbol_array_attr(&claim_names)), + ], + &claims, + &["!piop.opening_batch_type"], + )?; + Ok(()) +} + +#[expect(clippy::too_many_arguments)] +fn append_booleanity_output_opening<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + claims: &mut Vec>, + claim_symbols: &mut Vec, + booleanity: (Value<'c, 'a>, Value<'c, 'a>), + oracle: &str, + eval_index: usize, +) -> Result<(), MlirError> { + let symbol = format!("stage6.booleanity.opening.{oracle}"); + let eval = append_sumcheck_eval( + context, + module, + &format!("stage6.booleanity.eval.{oracle}"), + "stage6.sumcheck", + oracle, + eval_index, + booleanity.1, + )?; + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + booleanity.0, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle, + domain: "jolt.main_witness_commit_domain", + point_arity: booleanity_rounds(params), + claim_kind: "committed", + }, + )?); + Ok(()) +} + +fn append_field_zero<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.zero", + Some(symbol), + &[("field", "@bn254_fr")], + &[], + &["!field.scalar"], + )?; + first_result(op, "field.zero") +} + +fn append_field_binary<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + op_name: &str, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + op_name, + Some(symbol), + &[], + &[lhs, rhs], + &["!field.scalar"], + )?; + first_result(op, op_name) +} + +fn append_field_add<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.add", symbol, lhs, rhs) +} + +fn append_field_mul<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.mul", symbol, lhs, rhs) +} + +fn append_field_pow<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + base: Value<'c, 'a>, + exponent: usize, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.pow", + Some(symbol), + &[("exponent", &int_attr(exponent))], + &[base], + &["!field.scalar"], + )?; + first_result(op, "field.pow") +} + +fn append_field_sum<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol_prefix: &str, + terms: &[Value<'c, 'a>], +) -> Result, MlirError> { + let Some((&first, rest)) = terms.split_first() else { + return append_field_zero(context, module, symbol_prefix); + }; + let mut value = first; + for (index, &term) in rest.iter().enumerate() { + value = append_field_add( + context, + module, + &format!("{symbol_prefix}.partial{index}"), + value, + term, + )?; + } + Ok(value) +} + +fn append_sumcheck_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckClaimSpec<'_>, + input_claim: Value<'c, 'a>, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(inputs.len() + 1); + operands.push(input_claim); + operands.extend_from_slice(inputs); + let op = context.append_typed_op( + module, + "piop.sumcheck_claim", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ], + &operands, + &["!piop.sumcheck_claim_type"], + )?; + first_result(op, "piop.sumcheck_claim") +} + +fn append_sumcheck_batch<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + stage: Value<'c, 'a>, + claims: &[Value<'c, 'a>], + spec: SumcheckBatchSpec<'_>, +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(claims.len() + 1); + operands.push(stage); + operands.extend_from_slice(claims); + let op = context.append_typed_op( + module, + "piop.sumcheck_batch", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("policy", &format!("\"{}\"", spec.policy)), + ("count", &int_attr(spec.ordered_claims.len())), + ("ordered_claims", &symbol_array_attr(spec.ordered_claims)), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("round_schedule", spec.round_schedule), + ], + &operands, + &["!piop.sumcheck_batch_type"], + )?; + first_result(op, "piop.sumcheck_batch") +} + +fn append_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + batch: Value<'c, 'a>, + spec: SumcheckDriverSpec<'_>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("relation", &format!("@{}", spec.relation)), + ("policy", &format!("\"{}\"", spec.policy)), + ("round_schedule", spec.round_schedule), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ], + &[state, batch], + &[ + "!transcript.state_type", + "!poly.point", + "!piop.sumcheck_result_type", + "!piop.sumcheck_proof_type", + ], + )?; + Ok(( + result(op, 0, "piop.sumcheck")?, + result(op, 1, "piop.sumcheck")?, + result(op, 2, "piop.sumcheck")?, + )) +} + +fn append_sumcheck_instance_result<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckInstanceResultSpec<'_>, + point: Value<'c, 'a>, + result_value: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_instance_result", + Some(spec.symbol), + &[ + ("source", &format!("@{}", spec.source)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ("index", &int_attr(spec.index)), + ("point_arity", &int_attr(spec.point_arity)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("round_offset", &int_attr(spec.round_offset)), + ("point_order", &format!("\"{}\"", spec.point_order)), + ("degree", &int_attr(spec.degree)), + ], + &[point, result_value], + &["!poly.point", "!piop.sumcheck_result_type"], + )?; + Ok(( + result(op, 0, "piop.sumcheck_instance_result")?, + result(op, 1, "piop.sumcheck_instance_result")?, + )) +} + +fn append_sumcheck_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + oracle: &str, + index: usize, + result_value: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_eval", + Some(symbol), + &[ + ("source", &format!("@{}", source)), + ("name", &format!("@{}", symbol)), + ("index", &int_attr(index)), + ("oracle", &format!("@{}", oracle)), + ], + &[result_value], + &["!field.scalar"], + )?; + first_result(op, "piop.sumcheck_eval") +} + +fn append_point_slice<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + offset: usize, + length: usize, + input: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_slice", + Some(symbol), + &[ + ("source", &format!("@{}", source)), + ("offset", &int_attr(offset)), + ("length", &int_attr(length)), + ], + &[input], + &["!poly.point"], + )?; + first_result(op, "poly.point_slice") +} + +#[expect(clippy::too_many_arguments)] +fn append_padded_address_chunk<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + address_len: usize, + chunk_index: usize, + chunk_len: usize, + input: Value<'c, 'a>, +) -> Result, MlirError> { + let pad_len = (chunk_len - (address_len % chunk_len)) % chunk_len; + let padded_offset = chunk_index * chunk_len; + let zero_len = pad_len.saturating_sub(padded_offset).min(chunk_len); + let source_offset = padded_offset.saturating_sub(pad_len); + let source_len = chunk_len - zero_len; + if source_offset + source_len > address_len { + return Err(schema_error(format!( + "address chunk {chunk_index} exceeds source point @{source}" + ))); + } + + let source_chunk = if source_len == 0 { + None + } else { + let source_symbol = if zero_len == 0 { + symbol.to_owned() + } else { + format!("{symbol}.source") + }; + Some(append_point_slice( + context, + module, + &source_symbol, + source, + source_offset, + source_len, + input, + )?) + }; + + if zero_len == 0 { + return source_chunk.ok_or_else(|| { + schema_error(format!("address chunk {chunk_index} has no source point")) + }); + } + + let zero = append_point_zero(context, module, &format!("{symbol}.zero_pad"), zero_len)?; + let inputs = match source_chunk { + Some(source_chunk) => vec![zero, source_chunk], + None => vec![zero], + }; + append_point_concat( + context, + module, + symbol, + "left_zero_padded_address_chunk", + chunk_len, + &inputs, + ) +} + +fn append_point_zero<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + arity: usize, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_zero", + Some(symbol), + &[("field", "@bn254_fr"), ("arity", &int_attr(arity))], + &[], + &["!poly.point"], + )?; + first_result(op, "poly.point_zero") +} + +fn append_point_concat<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + layout: &str, + arity: usize, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_concat", + Some(symbol), + &[ + ("layout", &format!("\"{}\"", layout)), + ("arity", &int_attr(arity)), + ], + inputs, + &["!poly.point"], + )?; + first_result(op, "poly.point_concat") +} + +fn append_opening_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + spec: OpeningClaimSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_claim", + Some(spec.symbol), + &[ + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", &format!("\"{}\"", spec.claim_kind)), + ], + &[point, eval], + &["!piop.opening_claim_type"], + )?; + first_result(op, "piop.opening_claim") +} + +fn stage6_max_rounds(params: &JoltProtocolParams) -> usize { + params.log_k_bytecode + params.log_t +} + +fn booleanity_rounds(params: &JoltProtocolParams) -> usize { + params.log_k_chunk + params.log_t +} + +fn n_committed_per_virtual(params: &JoltProtocolParams) -> usize { + params.lookups_ra_virtual_log_k_chunk / params.log_k_chunk +} + +fn total_ra_oracles(params: &JoltProtocolParams) -> usize { + params.instruction_d + params.bytecode_d + params.ram_d +} + +fn stage6_batched_degree(params: &JoltProtocolParams) -> usize { + [ + params.bytecode_d + 1, + BOOLEANITY_DEGREE, + HAMMING_BOOLEANITY_DEGREE, + params.ram_d + 1, + n_committed_per_virtual(params) + 1, + INC_CLAIM_REDUCTION_DEGREE, + ] + .into_iter() + .max() + .unwrap_or(INC_CLAIM_REDUCTION_DEGREE) +} + +fn stage6_output_count(params: &JoltProtocolParams) -> usize { + params.bytecode_d + total_ra_oracles(params) + 1 + params.ram_d + params.instruction_d + 2 +} + +fn int_attr(value: usize) -> String { + format!("{value} : i64") +} + +fn symbol_array_attr(values: &[&str]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn first_result<'c, 'a>( + op: OperationRef<'c, 'a>, + context: &str, +) -> Result, MlirError> { + result(op, 0, context) +} + +fn result<'c, 'a>( + op: OperationRef<'c, 'a>, + index: usize, + context: &str, +) -> Result, MlirError> { + op.result(index) + .map(Into::into) + .map_err(|_| schema_error(format!("{context} missing result {index}"))) +} + +fn schema_error(message: impl Into) -> MlirError { + SchemaError::new(message).into() +} + +const STAGE1_OP_FLAGS: [&str; 14] = [ + "OpFlagAddOperands", + "OpFlagSubtractOperands", + "OpFlagMultiplyOperands", + "OpFlagLoad", + "OpFlagStore", + "OpFlagJump", + "OpFlagWriteLookupOutputToRD", + "OpFlagVirtualInstruction", + "OpFlagAssert", + "OpFlagDoNotUpdateUnexpandedPC", + "OpFlagAdvice", + "OpFlagIsCompressed", + "OpFlagIsFirstInSequence", + "OpFlagIsLastInSequence", +]; + +struct Stage6BatchedSumcheckInputs<'c, 'a, 'b> { + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + openings: &'b Stage6OpeningInputs<'c, 'a>, + bc_gamma: Value<'c, 'a>, + bc_stage1_gamma: Value<'c, 'a>, + bc_stage2_gamma: Value<'c, 'a>, + bc_stage3_gamma: Value<'c, 'a>, + bc_stage4_gamma: Value<'c, 'a>, + bc_stage5_gamma: Value<'c, 'a>, + inst_ra_gamma: Value<'c, 'a>, + inc_gamma: Value<'c, 'a>, +} + +struct Stage6OpeningInputs<'c, 'a> { + bytecode_terms: Vec>, + hamming_lookup_output: Stage6OpeningInput<'c, 'a>, + ram_ra_virtual: Stage6OpeningInput<'c, 'a>, + instruction_ra_virtual: Vec>, + ram_inc_stage2: Stage6OpeningInput<'c, 'a>, + ram_inc_stage4: Stage6OpeningInput<'c, 'a>, + rd_inc_stage4: Stage6OpeningInput<'c, 'a>, + rd_inc_stage5: Stage6OpeningInput<'c, 'a>, +} + +struct Stage6OpeningInput<'c, 'a> { + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + claim: Value<'c, 'a>, +} + +struct Stage6BytecodeTerm<'c, 'a> { + eval: Value<'c, 'a>, + claim: Value<'c, 'a>, + gamma_power: usize, + stage_gamma: Option, + stage_gamma_power: usize, +} + +struct BytecodeTermSpec<'a> { + input: StageOpeningInputSpec<'a>, + gamma_power: usize, + stage_gamma: Option, + stage_gamma_power: usize, +} + +impl<'a> BytecodeTermSpec<'a> { + fn trace( + symbol: &'a str, + source_stage: &'a str, + source_claim: &'a str, + oracle: &'a str, + gamma_power: usize, + stage_gamma: Option, + stage_gamma_power: usize, + ) -> Self { + Self { + input: StageOpeningInputSpec { + symbol, + source_stage, + source_claim, + oracle, + domain: "jolt.trace_domain", + point_arity: 0, + claim_kind: "virtual", + }, + gamma_power, + stage_gamma, + stage_gamma_power, + } + } +} + +struct WeightedEvalSpec<'c, 'a> { + gamma: Value<'c, 'a>, + gamma_power: usize, + stage_gamma: Option>, + stage_gamma_power: usize, +} + +struct StageOpeningInputSpec<'a> { + symbol: &'a str, + source_stage: &'a str, + source_claim: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, + claim_kind: &'a str, +} + +struct RelationSpec<'a> { + symbol: &'a str, + kind: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + output_count: usize, +} + +struct SumcheckClaimSpec<'a> { + symbol: &'a str, + stage: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + claim: &'a str, + relation: &'a str, +} + +struct SumcheckBatchSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + policy: &'a str, + ordered_claims: &'a [&'a str], + claim_label: &'a str, + round_label: &'a str, + round_schedule: &'a str, +} + +struct SumcheckDriverSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + relation: &'a str, + policy: &'a str, + round_schedule: &'a str, + claim_label: &'a str, + round_label: &'a str, + num_rounds: usize, + degree: usize, +} + +struct SumcheckInstanceResultSpec<'a> { + symbol: &'a str, + source: &'a str, + claim: &'a str, + relation: &'a str, + index: usize, + point_arity: usize, + num_rounds: usize, + round_offset: usize, + point_order: &'a str, + degree: usize, +} + +struct Stage6TraceInstanceSpec<'a> { + symbol: &'a str, + claim: &'a str, + relation: &'a str, + index: usize, + degree: usize, +} + +struct OpeningClaimSpec<'a> { + symbol: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, + claim_kind: &'a str, +} diff --git a/crates/bolt/src/protocols/jolt/phases/stage7.rs b/crates/bolt/src/protocols/jolt/phases/stage7.rs new file mode 100644 index 0000000000..c37568c3d3 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/stage7.rs @@ -0,0 +1,1095 @@ +use melior::ir::operation::OperationRef; +use melior::ir::Value; + +use crate::ir::{BoltModule, Compute, Party, Protocol}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{verify_protocol_schema, SchemaError}; + +use super::super::oracles; +use super::super::params::JoltProtocolParams; +use super::lowering::{lower_party_to_compute, transcript_squeeze_protocol_result_type}; + +const HAMMING_WEIGHT_CLAIM_REDUCTION_DEGREE: usize = 2; + +pub fn build_stage7_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.stage7", None); + oracles::append_foundation_ops(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.stage7"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + append_stage7_domains(context, &module, params)?; + append_stage7_oracles(context, &module, params)?; + append_stage7_relations(context, &module, params)?; + let inputs = append_stage7_opening_inputs(context, &module, params)?; + + let fs = context.append_typed_op( + &module, + "transcript.state", + Some("fs_after_stage6"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let state = first_result(fs, "transcript.state")?; + let stage = context.append_typed_op( + &module, + "piop.stage", + Some("stage7"), + &[ + ("name", r#""hamming_weight_claim_reduction""#), + ("order", "7 : i64"), + ("roles", r#"["prover", "verifier"]"#), + ], + &[], + &["!piop.stage_type"], + )?; + let stage = first_result(stage, "piop.stage")?; + + let (state, gamma) = append_transcript_squeeze( + context, + &module, + state, + "stage7.hamming_weight_claim_reduction.gamma", + "hamming_weight_claim_reduction_gamma", + "challenge_scalar", + 1, + )?; + let _state = append_stage7_sumcheck( + context, + &module, + params, + Stage7SumcheckInputs { + state, + stage, + openings: &inputs, + gamma, + }, + )?; + + verify_module(&module)?; + verify_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_stage7_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + lower_party_to_compute(context, module, "jolt.stage7", "jolt.stage7", "stage7") +} + +fn append_stage7_domains<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_domain( + context, + module, + "jolt.stage7_hamming_weight_claim_reduction_domain", + params.log_k_chunk, + ) +} + +fn append_domain<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + log_size: usize, +) -> Result<(), MlirError> { + context.append_op( + module, + "poly.domain", + Some(symbol), + &[("field", "@bn254_fr"), ("log_size", &int_attr(log_size))], + ) +} + +fn append_stage7_oracles<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_virtual_oracle(context, module, "HammingWeight", "jolt.trace_domain")?; + for index in 0..params.instruction_d { + append_committed_main_witness_oracle(context, module, &format!("InstructionRa_{index}"))?; + } + for index in 0..params.bytecode_d { + append_committed_main_witness_oracle(context, module, &format!("BytecodeRa_{index}"))?; + } + for index in 0..params.ram_d { + append_committed_main_witness_oracle(context, module, &format!("RamRa_{index}"))?; + } + Ok(()) +} + +fn append_virtual_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, + domain: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", &format!("@{domain}")), + ("commit_domain", &format!("@{domain}")), + ("visibility", r#""virtual""#), + ("layout", r#""virtual""#), + ], + ) +} + +fn append_committed_main_witness_oracle<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + symbol: &str, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.oracle", + Some(symbol), + &[ + ("field", "@bn254_fr"), + ("domain", "@jolt.main_witness_commit_domain"), + ("commit_domain", "@jolt.main_witness_commit_domain"), + ("visibility", r#""committed""#), + ("layout", r#""onehot_expanded""#), + ], + ) +} + +fn append_stage7_relations<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result<(), MlirError> { + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage7.hamming_weight_claim_reduction", + kind: "sumcheck", + domain: "jolt.stage7_hamming_weight_claim_reduction_domain", + num_rounds: params.log_k_chunk, + degree: HAMMING_WEIGHT_CLAIM_REDUCTION_DEGREE, + output_count: total_ra_oracles(params), + }, + )?; + append_relation( + context, + module, + RelationSpec { + symbol: "jolt.stage7.batched", + kind: "batched_sumcheck", + domain: "jolt.stage7_hamming_weight_claim_reduction_domain", + num_rounds: params.log_k_chunk, + degree: HAMMING_WEIGHT_CLAIM_REDUCTION_DEGREE, + output_count: total_ra_oracles(params), + }, + ) +} + +fn append_relation<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Protocol>, + spec: RelationSpec<'_>, +) -> Result<(), MlirError> { + context.append_op( + module, + "piop.relation", + Some(spec.symbol), + &[ + ("kind", &format!("\"{}\"", spec.kind)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("output_count", &int_attr(spec.output_count)), + ], + ) +} + +fn append_stage7_opening_inputs<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let ram_hamming = append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: "stage7.input.stage6.hamming_booleanity.HammingWeight", + source_stage: "stage6", + source_claim: "stage6.hamming_booleanity.opening.HammingWeight", + oracle: "HammingWeight", + domain: "jolt.trace_domain", + point_arity: params.log_t, + claim_kind: "virtual", + }, + )?; + + let mut ra_inputs = Vec::with_capacity(total_ra_oracles(params)); + for index in 0..params.instruction_d { + let oracle = format!("InstructionRa_{index}"); + ra_inputs.push(append_ra_inputs( + context, + module, + params, + &oracle, + Stage7RaKind::Instruction, + &format!("stage6.instruction_ra_virtual.opening.{oracle}"), + &format!("stage7.input.stage6.instruction_ra_virtual.{oracle}"), + )?); + } + for index in 0..params.bytecode_d { + let oracle = format!("BytecodeRa_{index}"); + ra_inputs.push(append_ra_inputs( + context, + module, + params, + &oracle, + Stage7RaKind::Bytecode, + &format!("stage6.bytecode_read_raf.opening.{oracle}"), + &format!("stage7.input.stage6.bytecode_read_raf.{oracle}"), + )?); + } + for index in 0..params.ram_d { + let oracle = format!("RamRa_{index}"); + ra_inputs.push(append_ra_inputs( + context, + module, + params, + &oracle, + Stage7RaKind::Ram, + &format!("stage6.ram_ra_virtual.opening.{oracle}"), + &format!("stage7.input.stage6.ram_ra_virtual.{oracle}"), + )?); + } + + let booleanity_point = ra_inputs + .first() + .ok_or_else(|| schema_error("Stage 7 requires at least one RA oracle"))? + .booleanity + .point; + + Ok(Stage7OpeningInputs { + ra_inputs, + ram_hamming, + booleanity_point, + }) +} + +fn append_ra_inputs<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + oracle: &str, + kind: Stage7RaKind, + source_virtual_claim: &str, + virtual_input_symbol: &str, +) -> Result, MlirError> { + let booleanity = append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: &format!("stage7.input.stage6.booleanity.{oracle}"), + source_stage: "stage6", + source_claim: &format!("stage6.booleanity.opening.{oracle}"), + oracle, + domain: "jolt.main_witness_commit_domain", + point_arity: params.log_k_chunk + params.log_t, + claim_kind: "committed", + }, + )?; + let virtualization = append_stage_input( + context, + module, + StageOpeningInputSpec { + symbol: virtual_input_symbol, + source_stage: "stage6", + source_claim: source_virtual_claim, + oracle, + domain: "jolt.main_witness_commit_domain", + point_arity: params.log_k_chunk + params.log_t, + claim_kind: "committed", + }, + )?; + Ok(Stage7RaInput { + oracle: oracle.to_owned(), + kind, + booleanity, + virtualization, + }) +} + +fn append_stage_input<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: StageOpeningInputSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_input", + Some(spec.symbol), + &[ + ("source_stage", &format!("@{}", spec.source_stage)), + ("source_claim", &format!("@{}", spec.source_claim)), + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", &format!("\"{}\"", spec.claim_kind)), + ], + &[], + &["!poly.point", "!field.scalar", "!piop.opening_claim_type"], + )?; + Ok(Stage7OpeningInput { + point: result(op, 0, "piop.opening_input")?, + eval: result(op, 1, "piop.opening_input")?, + claim: result(op, 2, "piop.opening_input")?, + }) +} + +fn append_stage7_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + spec: Stage7SumcheckInputs<'c, 'a, '_>, +) -> Result, MlirError> { + let input_claim = append_hamming_weight_claim_reduction_input_claim( + context, + module, + spec.openings, + spec.gamma, + )?; + let mut input_openings = Vec::with_capacity(2 * spec.openings.ra_inputs.len() + 1); + input_openings.push(spec.openings.ram_hamming.claim); + for input in &spec.openings.ra_inputs { + input_openings.push(input.booleanity.claim); + input_openings.push(input.virtualization.claim); + } + let claim = append_sumcheck_claim( + context, + module, + SumcheckClaimSpec { + symbol: "stage7.hamming_weight_claim_reduction.input", + stage: "stage7", + domain: "jolt.stage7_hamming_weight_claim_reduction_domain", + num_rounds: params.log_k_chunk, + degree: HAMMING_WEIGHT_CLAIM_REDUCTION_DEGREE, + claim: "stage7.hamming_weight_claim_reduction.weighted_stage6_claims", + relation: "jolt.stage7.hamming_weight_claim_reduction", + }, + input_claim, + &input_openings, + )?; + let round_schedule = format!("[{}]", params.log_k_chunk); + let batch = append_sumcheck_batch( + context, + module, + spec.stage, + &[claim], + SumcheckBatchSpec { + symbol: "stage7.batch", + stage: "stage7", + proof_slot: "stage7.sumcheck", + policy: "jolt_core_stage7_aligned", + ordered_claims: &["stage7.hamming_weight_claim_reduction.input"], + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + round_schedule: &round_schedule, + }, + )?; + let (state, point, result_value) = append_sumcheck( + context, + module, + spec.state, + batch, + SumcheckDriverSpec { + symbol: "stage7.sumcheck", + stage: "stage7", + proof_slot: "stage7.sumcheck", + relation: "jolt.stage7.batched", + policy: "jolt_core_stage7_aligned", + round_schedule: &round_schedule, + claim_label: "sumcheck_claim", + round_label: "sumcheck_poly", + num_rounds: params.log_k_chunk, + degree: HAMMING_WEIGHT_CLAIM_REDUCTION_DEGREE, + }, + )?; + let instance = append_sumcheck_instance_result( + context, + module, + SumcheckInstanceResultSpec { + symbol: "stage7.hamming_weight_claim_reduction.instance", + source: "stage7.sumcheck", + claim: "stage7.hamming_weight_claim_reduction.input", + relation: "jolt.stage7.hamming_weight_claim_reduction", + index: 0, + point_arity: params.log_k_chunk, + num_rounds: params.log_k_chunk, + round_offset: 0, + point_order: "reverse", + degree: HAMMING_WEIGHT_CLAIM_REDUCTION_DEGREE, + }, + point, + result_value, + )?; + append_stage7_output_openings(context, module, params, spec.openings, instance)?; + Ok(state) +} + +fn append_hamming_weight_claim_reduction_input_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + inputs: &Stage7OpeningInputs<'c, 'a>, + gamma: Value<'c, 'a>, +) -> Result, MlirError> { + let one = append_field_one(context, module, "stage7.field.one")?; + let mut terms = Vec::with_capacity(3 * inputs.ra_inputs.len()); + for (index, input) in inputs.ra_inputs.iter().enumerate() { + let hamming_eval = match input.kind { + Stage7RaKind::Instruction | Stage7RaKind::Bytecode => one, + Stage7RaKind::Ram => inputs.ram_hamming.eval, + }; + terms.push(append_weighted_eval( + context, + module, + &format!("stage7.hamming_weight_claim_reduction.claim.{index}.hw"), + hamming_eval, + gamma, + 3 * index, + )?); + terms.push(append_weighted_eval( + context, + module, + &format!("stage7.hamming_weight_claim_reduction.claim.{index}.booleanity"), + input.booleanity.eval, + gamma, + 3 * index + 1, + )?); + terms.push(append_weighted_eval( + context, + module, + &format!("stage7.hamming_weight_claim_reduction.claim.{index}.virtualization"), + input.virtualization.eval, + gamma, + 3 * index + 2, + )?); + } + append_field_sum( + context, + module, + "stage7.hamming_weight_claim_reduction.claim_expr", + &terms, + ) +} + +fn append_weighted_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol_prefix: &str, + eval: Value<'c, 'a>, + gamma: Value<'c, 'a>, + gamma_power: usize, +) -> Result, MlirError> { + if gamma_power == 0 { + return Ok(eval); + } + let power = append_field_pow( + context, + module, + &format!("{symbol_prefix}.gamma_pow"), + gamma, + gamma_power, + )?; + append_field_mul( + context, + module, + &format!("{symbol_prefix}.gamma_term"), + power, + eval, + ) +} + +fn append_stage7_output_openings<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + inputs: &Stage7OpeningInputs<'c, 'a>, + instance: (Value<'c, 'a>, Value<'c, 'a>), +) -> Result<(), MlirError> { + let cycle = append_point_slice( + context, + module, + "stage7.hamming_weight_claim_reduction.point.cycle", + "stage7.input.stage6.booleanity.InstructionRa_0", + params.log_k_chunk, + params.log_t, + inputs.booleanity_point, + )?; + let full_point = append_point_concat( + context, + module, + "stage7.hamming_weight_claim_reduction.point", + "address_chunk_then_cycle", + params.log_k_chunk + params.log_t, + &[instance.0, cycle], + )?; + + let mut claims = Vec::with_capacity(inputs.ra_inputs.len()); + let mut claim_symbols = Vec::with_capacity(inputs.ra_inputs.len()); + for (index, input) in inputs.ra_inputs.iter().enumerate() { + let eval = append_sumcheck_eval( + context, + module, + &format!( + "stage7.hamming_weight_claim_reduction.eval.{}", + input.oracle + ), + "stage7.sumcheck", + &input.oracle, + index, + instance.1, + )?; + let symbol = format!( + "stage7.hamming_weight_claim_reduction.opening.{}", + input.oracle + ); + claim_symbols.push(symbol.clone()); + claims.push(append_opening_claim( + context, + module, + full_point, + eval, + OpeningClaimSpec { + symbol: &symbol, + oracle: &input.oracle, + domain: "jolt.main_witness_commit_domain", + point_arity: params.log_k_chunk + params.log_t, + claim_kind: "committed", + }, + )?); + } + let claim_names = claim_symbols.iter().map(String::as_str).collect::>(); + let _batch = context.append_typed_op( + module, + "piop.opening_batch", + Some("stage7.openings"), + &[ + ("stage", "@stage7"), + ("proof_slot", "@stage7.openings"), + ("policy", r#""jolt_stage7_output_order""#), + ("count", &int_attr(claims.len())), + ("ordered_claims", &symbol_array_attr(&claim_names)), + ], + &claims, + &["!piop.opening_batch_type"], + )?; + Ok(()) +} + +fn append_transcript_squeeze<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + symbol: &str, + label: &str, + kind: &str, + count: usize, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "transcript.squeeze", + Some(symbol), + &[ + ("label", &format!("\"{label}\"")), + ("kind", &format!("\"{kind}\"")), + ("count", &int_attr(count)), + ], + &[state], + &[ + "!transcript.state_type", + transcript_squeeze_protocol_result_type(kind)?, + ], + )?; + Ok(( + result(op, 0, "transcript.squeeze")?, + result(op, 1, "transcript.squeeze")?, + )) +} + +fn append_field_one<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.one", + Some(symbol), + &[("field", "@bn254_fr")], + &[], + &["!field.scalar"], + )?; + first_result(op, "field.one") +} + +fn append_field_binary<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + op_name: &str, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + op_name, + Some(symbol), + &[], + &[lhs, rhs], + &["!field.scalar"], + )?; + first_result(op, op_name) +} + +fn append_field_add<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.add", symbol, lhs, rhs) +} + +fn append_field_mul<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + lhs: Value<'c, 'a>, + rhs: Value<'c, 'a>, +) -> Result, MlirError> { + append_field_binary(context, module, "field.mul", symbol, lhs, rhs) +} + +fn append_field_pow<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + base: Value<'c, 'a>, + exponent: usize, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "field.pow", + Some(symbol), + &[("exponent", &int_attr(exponent))], + &[base], + &["!field.scalar"], + )?; + first_result(op, "field.pow") +} + +fn append_field_sum<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol_prefix: &str, + terms: &[Value<'c, 'a>], +) -> Result, MlirError> { + let Some((&first, rest)) = terms.split_first() else { + return append_field_one(context, module, symbol_prefix); + }; + let mut value = first; + for (index, &term) in rest.iter().enumerate() { + value = append_field_add( + context, + module, + &format!("{symbol_prefix}.partial{index}"), + value, + term, + )?; + } + Ok(value) +} + +fn append_sumcheck_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckClaimSpec<'_>, + input_claim: Value<'c, 'a>, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(inputs.len() + 1); + operands.push(input_claim); + operands.extend_from_slice(inputs); + let op = context.append_typed_op( + module, + "piop.sumcheck_claim", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("domain", &format!("@{}", spec.domain)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ], + &operands, + &["!piop.sumcheck_claim_type"], + )?; + first_result(op, "piop.sumcheck_claim") +} + +fn append_sumcheck_batch<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + stage: Value<'c, 'a>, + claims: &[Value<'c, 'a>], + spec: SumcheckBatchSpec<'_>, +) -> Result, MlirError> { + let mut operands = Vec::with_capacity(claims.len() + 1); + operands.push(stage); + operands.extend_from_slice(claims); + let op = context.append_typed_op( + module, + "piop.sumcheck_batch", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("policy", &format!("\"{}\"", spec.policy)), + ("count", &int_attr(spec.ordered_claims.len())), + ("ordered_claims", &symbol_array_attr(spec.ordered_claims)), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("round_schedule", spec.round_schedule), + ], + &operands, + &["!piop.sumcheck_batch_type"], + )?; + first_result(op, "piop.sumcheck_batch") +} + +fn append_sumcheck<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + state: Value<'c, 'a>, + batch: Value<'c, 'a>, + spec: SumcheckDriverSpec<'_>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck", + Some(spec.symbol), + &[ + ("stage", &format!("@{}", spec.stage)), + ("proof_slot", &format!("@{}", spec.proof_slot)), + ("relation", &format!("@{}", spec.relation)), + ("policy", &format!("\"{}\"", spec.policy)), + ("round_schedule", spec.round_schedule), + ("claim_label", &format!("\"{}\"", spec.claim_label)), + ("round_label", &format!("\"{}\"", spec.round_label)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("degree", &int_attr(spec.degree)), + ], + &[state, batch], + &[ + "!transcript.state_type", + "!poly.point", + "!piop.sumcheck_result_type", + "!piop.sumcheck_proof_type", + ], + )?; + Ok(( + result(op, 0, "piop.sumcheck")?, + result(op, 1, "piop.sumcheck")?, + result(op, 2, "piop.sumcheck")?, + )) +} + +fn append_sumcheck_instance_result<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: SumcheckInstanceResultSpec<'_>, + point: Value<'c, 'a>, + result_value: Value<'c, 'a>, +) -> Result<(Value<'c, 'a>, Value<'c, 'a>), MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_instance_result", + Some(spec.symbol), + &[ + ("source", &format!("@{}", spec.source)), + ("claim", &format!("@{}", spec.claim)), + ("relation", &format!("@{}", spec.relation)), + ("index", &int_attr(spec.index)), + ("point_arity", &int_attr(spec.point_arity)), + ("num_rounds", &int_attr(spec.num_rounds)), + ("round_offset", &int_attr(spec.round_offset)), + ("point_order", &format!("\"{}\"", spec.point_order)), + ("degree", &int_attr(spec.degree)), + ], + &[point, result_value], + &["!poly.point", "!piop.sumcheck_result_type"], + )?; + Ok(( + result(op, 0, "piop.sumcheck_instance_result")?, + result(op, 1, "piop.sumcheck_instance_result")?, + )) +} + +fn append_sumcheck_eval<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + oracle: &str, + index: usize, + result_value: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.sumcheck_eval", + Some(symbol), + &[ + ("source", &format!("@{}", source)), + ("name", &format!("@{}", symbol)), + ("index", &int_attr(index)), + ("oracle", &format!("@{}", oracle)), + ], + &[result_value], + &["!field.scalar"], + )?; + first_result(op, "piop.sumcheck_eval") +} + +fn append_point_slice<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + source: &str, + offset: usize, + length: usize, + input: Value<'c, 'a>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_slice", + Some(symbol), + &[ + ("source", &format!("@{}", source)), + ("offset", &int_attr(offset)), + ("length", &int_attr(length)), + ], + &[input], + &["!poly.point"], + )?; + first_result(op, "poly.point_slice") +} + +fn append_point_concat<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + symbol: &str, + layout: &str, + arity: usize, + inputs: &[Value<'c, 'a>], +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "poly.point_concat", + Some(symbol), + &[ + ("layout", &format!("\"{}\"", layout)), + ("arity", &int_attr(arity)), + ], + inputs, + &["!poly.point"], + )?; + first_result(op, "poly.point_concat") +} + +fn append_opening_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + spec: OpeningClaimSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_claim", + Some(spec.symbol), + &[ + ("oracle", &format!("@{}", spec.oracle)), + ("domain", &format!("@{}", spec.domain)), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", &format!("\"{}\"", spec.claim_kind)), + ], + &[point, eval], + &["!piop.opening_claim_type"], + )?; + first_result(op, "piop.opening_claim") +} + +fn total_ra_oracles(params: &JoltProtocolParams) -> usize { + params.instruction_d + params.bytecode_d + params.ram_d +} + +fn int_attr(value: usize) -> String { + format!("{value} : i64") +} + +fn symbol_array_attr(values: &[&str]) -> String { + let values = values + .iter() + .map(|value| format!("@{value}")) + .collect::>() + .join(", "); + format!("[{values}]") +} + +fn first_result<'c, 'a>( + op: OperationRef<'c, 'a>, + context: &str, +) -> Result, MlirError> { + result(op, 0, context) +} + +fn result<'c, 'a>( + op: OperationRef<'c, 'a>, + index: usize, + context: &str, +) -> Result, MlirError> { + op.result(index) + .map(Into::into) + .map_err(|_| schema_error(format!("{context} missing result {index}"))) +} + +fn schema_error(message: impl Into) -> MlirError { + SchemaError::new(message).into() +} + +#[derive(Clone, Copy)] +enum Stage7RaKind { + Instruction, + Bytecode, + Ram, +} + +struct Stage7SumcheckInputs<'c, 'a, 'b> { + state: Value<'c, 'a>, + stage: Value<'c, 'a>, + openings: &'b Stage7OpeningInputs<'c, 'a>, + gamma: Value<'c, 'a>, +} + +struct Stage7OpeningInputs<'c, 'a> { + ra_inputs: Vec>, + ram_hamming: Stage7OpeningInput<'c, 'a>, + booleanity_point: Value<'c, 'a>, +} + +struct Stage7RaInput<'c, 'a> { + oracle: String, + kind: Stage7RaKind, + booleanity: Stage7OpeningInput<'c, 'a>, + virtualization: Stage7OpeningInput<'c, 'a>, +} + +struct Stage7OpeningInput<'c, 'a> { + point: Value<'c, 'a>, + eval: Value<'c, 'a>, + claim: Value<'c, 'a>, +} + +struct StageOpeningInputSpec<'a> { + symbol: &'a str, + source_stage: &'a str, + source_claim: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, + claim_kind: &'a str, +} + +struct RelationSpec<'a> { + symbol: &'a str, + kind: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + output_count: usize, +} + +struct SumcheckClaimSpec<'a> { + symbol: &'a str, + stage: &'a str, + domain: &'a str, + num_rounds: usize, + degree: usize, + claim: &'a str, + relation: &'a str, +} + +struct SumcheckBatchSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + policy: &'a str, + ordered_claims: &'a [&'a str], + claim_label: &'a str, + round_label: &'a str, + round_schedule: &'a str, +} + +struct SumcheckDriverSpec<'a> { + symbol: &'a str, + stage: &'a str, + proof_slot: &'a str, + relation: &'a str, + policy: &'a str, + round_schedule: &'a str, + claim_label: &'a str, + round_label: &'a str, + num_rounds: usize, + degree: usize, +} + +struct SumcheckInstanceResultSpec<'a> { + symbol: &'a str, + source: &'a str, + claim: &'a str, + relation: &'a str, + index: usize, + point_arity: usize, + num_rounds: usize, + round_offset: usize, + point_order: &'a str, + degree: usize, +} + +struct OpeningClaimSpec<'a> { + symbol: &'a str, + oracle: &'a str, + domain: &'a str, + point_arity: usize, + claim_kind: &'a str, +} diff --git a/crates/bolt/src/protocols/jolt/phases/stage8.rs b/crates/bolt/src/protocols/jolt/phases/stage8.rs new file mode 100644 index 0000000000..97ed373918 --- /dev/null +++ b/crates/bolt/src/protocols/jolt/phases/stage8.rs @@ -0,0 +1,293 @@ +use melior::ir::operation::{OperationLike, OperationRef}; +use melior::ir::Value; + +use crate::ir::{BoltModule, Compute, Party, Protocol}; +use crate::mlir::{verify_module, MeliorContext, MlirError}; +use crate::schema::{verify_protocol_schema, SchemaError}; + +use super::super::oracles; +use super::super::params::JoltProtocolParams; +use super::lowering::lower_party_to_compute; + +const EVALUATION_POINT_SOURCE_SYMBOL: &str = "stage8.evaluation.point_source"; + +pub fn build_stage8_protocol<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> Result, MlirError> { + let module = context.new_module::("jolt.stage8", None); + oracles::append_foundation_ops(context, &module, params)?; + oracles::append_committed_oracles(context, &module, params)?; + context.append_op_with_owned_attrs( + &module, + "protocol.params", + Some("jolt.params"), + ¶ms.attrs(), + )?; + context.append_op( + &module, + "protocol.boundary", + Some("jolt.stage8"), + &[("roles", r#"["prover", "verifier"]"#)], + )?; + + let fs = context.append_typed_op( + &module, + "transcript.state", + Some("fs_after_stage7"), + &[("scheme", "@blake2b_transcript")], + &[], + &["!transcript.state_type"], + )?; + let state = result(fs, 0, "transcript.state")?; + let _stage = context.append_typed_op( + &module, + "piop.stage", + Some("stage8"), + &[ + ("name", r#""evaluation_proof""#), + ("order", "8 : i64"), + ("roles", r#"["prover", "verifier"]"#), + ], + &[], + &["!piop.stage_type"], + )?; + + let _point_source = append_opening_input( + context, + &module, + Stage8OpeningInputSpec { + symbol: EVALUATION_POINT_SOURCE_SYMBOL, + source_stage: "stage7", + source_claim: "stage7.input.stage6.booleanity.InstructionRa_0", + oracle: "InstructionRa_0", + point_arity: params.log_t + params.log_k_chunk, + }, + )?; + let mut claims = Vec::new(); + let mut claim_symbols = Vec::new(); + append_evaluation_claim( + context, + &module, + params, + &mut claims, + &mut claim_symbols, + Stage8EvaluationClaimSpec { + oracle: "RamInc", + source_stage: "stage6", + source_claim: "stage6.inc_claim_reduction.eval.RamInc", + }, + )?; + append_evaluation_claim( + context, + &module, + params, + &mut claims, + &mut claim_symbols, + Stage8EvaluationClaimSpec { + oracle: "RdInc", + source_stage: "stage6", + source_claim: "stage6.inc_claim_reduction.eval.RdInc", + }, + )?; + for index in 0..params.instruction_d { + append_evaluation_claim( + context, + &module, + params, + &mut claims, + &mut claim_symbols, + Stage8EvaluationClaimSpec { + oracle: &format!("InstructionRa_{index}"), + source_stage: "stage7", + source_claim: &format!( + "stage7.hamming_weight_claim_reduction.eval.InstructionRa_{index}" + ), + }, + )?; + } + for index in 0..params.bytecode_d { + append_evaluation_claim( + context, + &module, + params, + &mut claims, + &mut claim_symbols, + Stage8EvaluationClaimSpec { + oracle: &format!("BytecodeRa_{index}"), + source_stage: "stage7", + source_claim: &format!( + "stage7.hamming_weight_claim_reduction.eval.BytecodeRa_{index}" + ), + }, + )?; + } + for index in 0..params.ram_d { + append_evaluation_claim( + context, + &module, + params, + &mut claims, + &mut claim_symbols, + Stage8EvaluationClaimSpec { + oracle: &format!("RamRa_{index}"), + source_stage: "stage7", + source_claim: &format!("stage7.hamming_weight_claim_reduction.eval.RamRa_{index}"), + }, + )?; + } + + let opening_batch = context.append_typed_op( + &module, + "pcs.opening_batch", + Some("stage8.evaluation.openings"), + &[ + ("proof_slot", "@stage8.evaluation"), + ("policy", r#""jolt_stage8_joint_rlc""#), + ("count", &int_attr(claims.len())), + ("ordered_claims", &symbol_array_attr(&claim_symbols)), + ], + &claims, + &["!pcs.opening_batch_type"], + )?; + let opening_batch = result(opening_batch, 0, "pcs.opening_batch")?; + let _state = context.append_typed_op( + &module, + "pcs.batch_open", + Some("stage8.evaluation.proof"), + &[ + ("pcs", "@dory"), + ("proof_slot", "@stage8.evaluation"), + ("transcript_label", r#""rlc_claims""#), + ], + &[state, opening_batch], + &["!transcript.state_type", "!pcs.opening_proof_type"], + )?; + + verify_module(&module)?; + verify_protocol_schema(&module)?; + Ok(module) +} + +pub fn lower_stage8_to_compute<'c>( + context: &'c MeliorContext, + module: &BoltModule<'c, Party>, +) -> Result, MlirError> { + lower_party_to_compute(context, module, "jolt.stage8", "jolt.stage8", "stage8") +} + +fn append_evaluation_claim<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + params: &JoltProtocolParams, + claims: &mut Vec>, + claim_symbols: &mut Vec, + spec: Stage8EvaluationClaimSpec<'_>, +) -> Result<(), MlirError> { + let input_symbol = format!("stage8.input.{}.{}", spec.source_stage, spec.oracle); + let opening_input = append_opening_input( + context, + module, + Stage8OpeningInputSpec { + symbol: &input_symbol, + source_stage: spec.source_stage, + source_claim: spec.source_claim, + oracle: spec.oracle, + point_arity: params.log_t + params.log_k_chunk, + }, + )?; + let opening_symbol = format!("stage8.evaluation.opening.{}", spec.oracle); + let opening = context.append_typed_op( + module, + "pcs.opening_claim", + Some(&opening_symbol), + &[ + ("oracle", &format!("@{}", spec.oracle)), + ("family", "@jolt.main_witness_polys"), + ("domain", "@jolt.main_witness_commit_domain"), + ("point_arity", &int_attr(params.log_t + params.log_k_chunk)), + ], + &[opening_input.point, opening_input.eval], + &["!pcs.opening_claim_type"], + )?; + claims.push(result(opening, 0, "pcs.opening_claim")?); + claim_symbols.push(opening_symbol); + Ok(()) +} + +fn append_opening_input<'c, 'a>( + context: &'c MeliorContext, + module: &'a BoltModule<'c, Protocol>, + spec: Stage8OpeningInputSpec<'_>, +) -> Result, MlirError> { + let op = context.append_typed_op( + module, + "piop.opening_input", + Some(spec.symbol), + &[ + ("source_stage", &format!("@{}", spec.source_stage)), + ("source_claim", &format!("@{}", spec.source_claim)), + ("oracle", &format!("@{}", spec.oracle)), + ("domain", "@jolt.main_witness_commit_domain"), + ("point_arity", &int_attr(spec.point_arity)), + ("claim_kind", r#""committed""#), + ], + &[], + &["!poly.point", "!field.scalar", "!piop.opening_claim_type"], + )?; + Ok(Stage8OpeningInput { + point: result(op, 0, "piop.opening_input")?, + eval: result(op, 1, "piop.opening_input")?, + }) +} + +#[derive(Clone, Copy)] +struct Stage8OpeningInputSpec<'a> { + symbol: &'a str, + source_stage: &'a str, + source_claim: &'a str, + oracle: &'a str, + point_arity: usize, +} + +#[derive(Clone, Copy)] +struct Stage8EvaluationClaimSpec<'a> { + oracle: &'a str, + source_stage: &'a str, + source_claim: &'a str, +} + +struct Stage8OpeningInput<'c, 'a> { + point: Value<'c, 'a>, + eval: Value<'c, 'a>, +} + +fn result<'c, 'a>( + operation: OperationRef<'c, 'a>, + index: usize, + op_name: &str, +) -> Result, MlirError> { + operation.result(index).map(Into::into).map_err(|_| { + schema_error(format!( + "{op_name} requires result {index}, got {} results", + operation.result_count() + )) + }) +} + +fn symbol_array_attr(symbols: &[String]) -> String { + let symbols = symbols + .iter() + .map(|symbol| format!("@{symbol}")) + .collect::>() + .join(", "); + format!("[{symbols}]") +} + +fn int_attr(value: usize) -> String { + format!("{value} : i64") +} + +fn schema_error(message: impl Into) -> MlirError { + SchemaError::new(message).into() +} diff --git a/crates/bolt/src/protocols/jolt/validate.rs b/crates/bolt/src/protocols/jolt/validate.rs new file mode 100644 index 0000000000..8070f5ba3a --- /dev/null +++ b/crates/bolt/src/protocols/jolt/validate.rs @@ -0,0 +1,134 @@ +use melior::ir::operation::{OperationLike, OperationResult}; +use melior::ir::OperationRef; + +use crate::ir::{string_attribute_value, BoltModule, Concrete, Party, Protocol}; +use crate::schema::{ + find_symbol, int_attr, missing_module_op, missing_symbol, require_symbol_attr_eq, + symbol_array_attr, symbol_attr, verify_concrete_schema, verify_party_schema, + verify_protocol_schema, SchemaError, +}; + +use super::oracles::{MAIN_WITNESS_FAMILY_SYMBOL, PCS_SYMBOL}; +use super::params::ParsedJoltProtocolParams; + +pub fn verify_jolt_protocol_schema(module: &BoltModule<'_, Protocol>) -> Result<(), SchemaError> { + verify_protocol_schema(module)?; + validate_jolt_shape(module) +} + +pub fn verify_jolt_concrete_schema(module: &BoltModule<'_, Concrete>) -> Result<(), SchemaError> { + verify_concrete_schema(module)?; + validate_jolt_shape(module) +} + +pub fn verify_jolt_party_schema(module: &BoltModule<'_, Party>) -> Result<(), SchemaError> { + verify_party_schema(module)?; + validate_jolt_shape(module) +} + +fn validate_jolt_shape

(module: &BoltModule<'_, P>) -> Result<(), SchemaError> +where + P: crate::ir::Phase, +{ + let params_op = + find_symbol(module, "jolt.params").ok_or_else(|| missing_module_op("protocol.params"))?; + let params = ParsedJoltProtocolParams::from_op(params_op)?; + params.validate()?; + + require_symbol(module, ¶ms.field)?; + require_symbol(module, ¶ms.pcs)?; + require_symbol(module, ¶ms.transcript)?; + + let witness_family = find_symbol(module, MAIN_WITNESS_FAMILY_SYMBOL) + .ok_or_else(|| missing_symbol(MAIN_WITNESS_FAMILY_SYMBOL))?; + let witness_count = int_attr(witness_family, "count")?; + if witness_count != params.num_committed { + return Err(SchemaError::new(format!( + "main witness count {witness_count} does not match num_committed {}", + params.num_committed + ))); + } + let ordered_oracles = symbol_array_attr(witness_family, "ordered_oracles")?; + let expected_oracles = params.main_witness_oracles(); + if ordered_oracles != expected_oracles { + return Err(SchemaError::new(format!( + "main witness ordered_oracles mismatch: expected [{}], got [{}]", + expected_oracles.join(", "), + ordered_oracles.join(", ") + ))); + } + if ordered_oracles.len() != witness_count { + return Err(SchemaError::new(format!( + "main witness ordered_oracles length {} does not match count {witness_count}", + ordered_oracles.len() + ))); + } + for oracle in &ordered_oracles { + let oracle_op = find_symbol(module, oracle).ok_or_else(|| missing_symbol(oracle))?; + require_symbol_attr_eq(oracle_op, "field", ¶ms.field)?; + } + + let commit = find_symbol(module, "jolt.main_witness_commitments") + .ok_or_else(|| missing_symbol("jolt.main_witness_commitments"))?; + require_symbol_attr_eq(commit, "oracle_family", MAIN_WITNESS_FAMILY_SYMBOL)?; + + let pcs = find_symbol(module, "jolt.dory_main_witness_commit") + .ok_or_else(|| missing_symbol("jolt.dory_main_witness_commit"))?; + require_operand_owner_symbol_eq(pcs, 0, "jolt.main_witness_commitments")?; + let scheme = symbol_attr(pcs, "scheme")?; + if scheme != PCS_SYMBOL || scheme != params.pcs { + return Err(SchemaError::new(format!( + "PCS scheme `{scheme}` does not match params pcs `{}`", + params.pcs + ))); + } + + Ok(()) +} + +fn require_operand_owner_symbol_eq( + operation: OperationRef<'_, '_>, + index: usize, + expected: &str, +) -> Result<(), SchemaError> { + let operand = operation.operand(index).map_err(|_| { + SchemaError::new(format!( + "{} missing required operand {index}", + crate::schema::operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + SchemaError::new(format!( + "{} operand {index} must be an op result", + crate::schema::operation_name(operation) + )) + })?; + let actual = owner + .owner() + .attribute("sym_name") + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| { + SchemaError::new(format!( + "{} operand {index} owner missing sym_name", + crate::schema::operation_name(operation) + )) + })?; + if actual == expected { + Ok(()) + } else { + Err(SchemaError::new(format!( + "{} operand {index} expected @{expected}, got @{actual}", + crate::schema::operation_name(operation) + ))) + } +} + +fn require_symbol

(module: &BoltModule<'_, P>, symbol: &str) -> Result<(), SchemaError> +where + P: crate::ir::Phase, +{ + find_symbol(module, symbol) + .map(|_| ()) + .ok_or_else(|| missing_symbol(symbol)) +} diff --git a/crates/bolt/src/protocols/jolt/verifier_common.rs.template b/crates/bolt/src/protocols/jolt/verifier_common.rs.template new file mode 100644 index 0000000000..5c31bfa6ef --- /dev/null +++ b/crates/bolt/src/protocols/jolt/verifier_common.rs.template @@ -0,0 +1,1789 @@ +#![expect( + clippy::too_many_arguments, + reason = "generated verifier helpers mirror staged protocol ABIs" +)] + +use jolt_field::{Field, Fr}; +use jolt_poly::EqPolynomial; +use jolt_sumcheck::{ + CompressedLabeledRoundPoly, SumcheckClaim, SumcheckError, SumcheckProof, SumcheckVerifier, +}; +use jolt_transcript::{Label, Transcript}; +use serde::Serialize; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct StageParams { + pub field: &'static str, + pub pcs: &'static str, + pub transcript: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct KernelPlan { + pub symbol: &'static str, + pub relation: &'static str, + pub kind: &'static str, + pub backend: &'static str, + pub abi: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct TranscriptSqueezePlan { + pub symbol: &'static str, + pub label: &'static str, + pub kind: &'static str, + pub count: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct TranscriptAbsorbBytesPlan { + pub symbol: &'static str, + pub label: &'static str, + pub payload: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct ProgramStepPlan { + pub kind: &'static str, + pub symbol: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OpeningInputPlan { + pub symbol: &'static str, + pub source_stage: &'static str, + pub source_claim: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FieldConstantPlan { + pub symbol: &'static str, + pub field: &'static str, + pub value: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FieldExprPlan { + pub symbol: &'static str, + pub kind: &'static str, + pub formula: &'static str, + pub operands: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SumcheckClaimPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub domain: &'static str, + pub num_rounds: usize, + pub degree: usize, + pub claim: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub claim_value: &'static str, + pub input_openings: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SumcheckBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static str, + pub claim_operands: &'static str, + pub claim_label: &'static str, + pub round_label: &'static str, + pub round_schedule: &'static [usize], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SumcheckDriverPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub kernel: Option<&'static str>, + pub relation: Option<&'static str>, + pub batch: &'static str, + pub policy: &'static str, + pub round_schedule: &'static [usize], + pub claim_label: &'static str, + pub round_label: &'static str, + pub num_rounds: usize, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SumcheckInstanceResultPlan { + pub symbol: &'static str, + pub source: &'static str, + pub claim: &'static str, + pub relation: &'static str, + pub index: usize, + pub point_arity: usize, + pub num_rounds: usize, + pub round_offset: usize, + pub point_order: &'static str, + pub degree: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SumcheckEvalPlan { + pub symbol: &'static str, + pub source: &'static str, + pub name: &'static str, + pub index: usize, + pub oracle: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct PointZeroPlan { + pub symbol: &'static str, + pub field: &'static str, + pub arity: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct PointSlicePlan { + pub symbol: &'static str, + pub source: &'static str, + pub offset: usize, + pub length: usize, + pub input: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct PointConcatPlan { + pub symbol: &'static str, + pub layout: &'static str, + pub arity: usize, + pub inputs: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OpeningClaimPlan { + pub symbol: &'static str, + pub oracle: &'static str, + pub domain: &'static str, + pub point_arity: usize, + pub claim_kind: &'static str, + pub point_source: &'static str, + pub eval_source: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OpeningClaimEqualityPlan { + pub symbol: &'static str, + pub mode: &'static str, + pub lhs: &'static str, + pub rhs: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OpeningBatchPlan { + pub symbol: &'static str, + pub stage: &'static str, + pub proof_slot: &'static str, + pub policy: &'static str, + pub count: usize, + pub ordered_claims: &'static str, + pub claim_operands: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct StageProgramPlan { + pub role: &'static str, + pub params: StageParams, + pub steps: &'static [ProgramStepPlan], + pub transcript_squeezes: &'static [TranscriptSqueezePlan], + pub transcript_absorb_bytes: &'static [TranscriptAbsorbBytesPlan], + pub opening_inputs: &'static [OpeningInputPlan], + pub field_constants: &'static [FieldConstantPlan], + pub field_exprs: &'static [FieldExprPlan], + pub kernels: &'static [KernelPlan], + pub claims: &'static [SumcheckClaimPlan], + pub batches: &'static [SumcheckBatchPlan], + pub drivers: &'static [SumcheckDriverPlan], + pub instance_results: &'static [SumcheckInstanceResultPlan], + pub evals: &'static [SumcheckEvalPlan], + pub point_zeros: &'static [PointZeroPlan], + pub point_slices: &'static [PointSlicePlan], + pub point_concats: &'static [PointConcatPlan], + pub opening_claims: &'static [OpeningClaimPlan], + pub opening_equalities: &'static [OpeningClaimEqualityPlan], + pub opening_batches: &'static [OpeningBatchPlan], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct StageProgramPlanNoPointZeros { + pub role: &'static str, + pub params: StageParams, + pub steps: &'static [ProgramStepPlan], + pub transcript_squeezes: &'static [TranscriptSqueezePlan], + pub transcript_absorb_bytes: &'static [TranscriptAbsorbBytesPlan], + pub opening_inputs: &'static [OpeningInputPlan], + pub field_constants: &'static [FieldConstantPlan], + pub field_exprs: &'static [FieldExprPlan], + pub kernels: &'static [KernelPlan], + pub claims: &'static [SumcheckClaimPlan], + pub batches: &'static [SumcheckBatchPlan], + pub drivers: &'static [SumcheckDriverPlan], + pub instance_results: &'static [SumcheckInstanceResultPlan], + pub evals: &'static [SumcheckEvalPlan], + pub point_slices: &'static [PointSlicePlan], + pub point_concats: &'static [PointConcatPlan], + pub opening_claims: &'static [OpeningClaimPlan], + pub opening_equalities: &'static [OpeningClaimEqualityPlan], + pub opening_batches: &'static [OpeningBatchPlan], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct StageVerifierProgramPlan { + pub params: StageParams, + pub steps: &'static [ProgramStepPlan], + pub transcript_squeezes: &'static [TranscriptSqueezePlan], + pub opening_inputs: &'static [OpeningInputPlan], + pub field_constants: &'static [FieldConstantPlan], + pub field_exprs: &'static [FieldExprPlan], + pub claims: &'static [SumcheckClaimPlan], + pub batches: &'static [SumcheckBatchPlan], + pub drivers: &'static [SumcheckDriverPlan], + pub instance_results: &'static [SumcheckInstanceResultPlan], + pub evals: &'static [SumcheckEvalPlan], + pub point_slices: &'static [PointSlicePlan], + pub point_concats: &'static [PointConcatPlan], + pub opening_claims: &'static [OpeningClaimPlan], + pub opening_equalities: &'static [OpeningClaimEqualityPlan], + pub opening_batches: &'static [OpeningBatchPlan], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct StageVerifierProgramPlanNoEqualities { + pub params: StageParams, + pub steps: &'static [ProgramStepPlan], + pub transcript_squeezes: &'static [TranscriptSqueezePlan], + pub opening_inputs: &'static [OpeningInputPlan], + pub field_constants: &'static [FieldConstantPlan], + pub field_exprs: &'static [FieldExprPlan], + pub claims: &'static [SumcheckClaimPlan], + pub batches: &'static [SumcheckBatchPlan], + pub drivers: &'static [SumcheckDriverPlan], + pub instance_results: &'static [SumcheckInstanceResultPlan], + pub evals: &'static [SumcheckEvalPlan], + pub point_slices: &'static [PointSlicePlan], + pub point_concats: &'static [PointConcatPlan], + pub opening_claims: &'static [OpeningClaimPlan], + pub opening_batches: &'static [OpeningBatchPlan], +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct VerifierProgramPlanMinimal { + pub params: StageParams, + pub transcript_squeezes: &'static [TranscriptSqueezePlan], + pub claims: &'static [SumcheckClaimPlan], + pub batches: &'static [SumcheckBatchPlan], + pub drivers: &'static [SumcheckDriverPlan], + pub instance_results: &'static [SumcheckInstanceResultPlan], + pub evals: &'static [SumcheckEvalPlan], + pub opening_claims: &'static [OpeningClaimPlan], + pub opening_batches: &'static [OpeningBatchPlan], +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] +pub struct StageNamedEval { + pub name: &'static str, + pub oracle: &'static str, + pub value: F, +} + +#[derive(Clone, Debug, Serialize)] +pub struct StageSumcheckOutput { + pub driver: &'static str, + pub point: Vec, + pub evals: Vec>, + pub proof: SumcheckProof, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct StageChallengeVector { + pub symbol: &'static str, + pub values: Vec, +} + +#[derive(Clone, Debug)] +pub struct StageExecutionArtifacts { + pub challenge_vectors: Vec>, + pub sumchecks: Vec>, + pub opening_batches: Vec<&'static OpeningBatchPlan>, +} + +impl Default for StageExecutionArtifacts { + fn default() -> Self { + Self { + challenge_vectors: Vec::new(), + sumchecks: Vec::new(), + opening_batches: Vec::new(), + } + } +} + +#[derive(Clone, Debug, Default, Serialize)] +pub struct StageProof { + pub sumchecks: Vec>, +} + +#[derive(Clone, Debug)] +pub struct StageOpeningInputValue { + pub symbol: &'static str, + pub point: Vec, + pub eval: F, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum RuntimePlanError { + MissingBatch { + driver: &'static str, + batch: &'static str, + }, + MissingClaim { + batch: &'static str, + claim: &'static str, + }, + MissingValue { + symbol: &'static str, + }, + InvalidInputLength { + input: &'static str, + expected: usize, + actual: usize, + }, + InvalidProof { + driver: &'static str, + reason: &'static str, + }, + UnsupportedFieldExpr { + symbol: &'static str, + formula: &'static str, + }, +} + +macro_rules! impl_runtime_plan_error_conversion { + ($error:ident) => { + impl From for $error { + fn from(error: super::common::RuntimePlanError) -> Self { + match error { + super::common::RuntimePlanError::MissingBatch { driver, batch } => { + Self::MissingBatch { driver, batch } + } + super::common::RuntimePlanError::MissingClaim { batch, claim } => { + Self::MissingClaim { batch, claim } + } + super::common::RuntimePlanError::MissingValue { symbol } => { + Self::MissingValue { symbol } + } + super::common::RuntimePlanError::InvalidInputLength { + input, + expected, + actual, + } => Self::InvalidInputLength { + input, + expected, + actual, + }, + super::common::RuntimePlanError::InvalidProof { driver, reason } => { + Self::InvalidProof { driver, reason } + } + super::common::RuntimePlanError::UnsupportedFieldExpr { symbol, formula } => { + Self::UnsupportedFieldExpr { symbol, formula } + } + } + } + } + }; +} + +pub(crate) use impl_runtime_plan_error_conversion; + +pub trait SymbolPlan { + fn symbol(&self) -> &'static str; +} + +impl SymbolPlan for TranscriptSqueezePlan { + fn symbol(&self) -> &'static str { + self.symbol + } +} + +impl SymbolPlan for TranscriptAbsorbBytesPlan { + fn symbol(&self) -> &'static str { + self.symbol + } +} + +impl SymbolPlan for SumcheckBatchPlan { + fn symbol(&self) -> &'static str { + self.symbol + } +} + +impl SymbolPlan for SumcheckClaimPlan { + fn symbol(&self) -> &'static str { + self.symbol + } +} + +impl SymbolPlan for SumcheckDriverPlan { + fn symbol(&self) -> &'static str { + self.symbol + } +} + +impl SymbolPlan for OpeningClaimPlan { + fn symbol(&self) -> &'static str { + self.symbol + } +} + +pub trait SumcheckClaimInfo: SymbolPlan { + fn num_rounds(&self) -> usize; + fn claim_value(&self) -> &'static str; +} + +impl SumcheckClaimInfo for SumcheckClaimPlan { + fn num_rounds(&self) -> usize { + self.num_rounds + } + + fn claim_value(&self) -> &'static str { + self.claim_value + } +} + +pub trait SumcheckDriverInfo: SymbolPlan { + fn batch(&self) -> &'static str; + fn num_rounds(&self) -> usize; + fn degree(&self) -> usize; + fn round_label(&self) -> &'static str; +} + +impl SumcheckDriverInfo for SumcheckDriverPlan { + fn batch(&self) -> &'static str { + self.batch + } + + fn num_rounds(&self) -> usize { + self.num_rounds + } + + fn degree(&self) -> usize { + self.degree + } + + fn round_label(&self) -> &'static str { + self.round_label + } +} + +#[derive(Clone, Debug, Default)] +pub struct ValueStore { + scalars: Vec<(&'static str, F)>, + points: Vec<(&'static str, Vec)>, +} + +impl ValueStore { + pub fn with_opening_inputs( + inputs: &[StageOpeningInputValue], + expected_inputs: &[OpeningInputPlan], + ) -> Result { + if inputs.len() != expected_inputs.len() { + return Err(RuntimePlanError::InvalidInputLength { + input: "opening_inputs", + expected: expected_inputs.len(), + actual: inputs.len(), + }); + } + for expected in expected_inputs { + let matching_count = inputs + .iter() + .filter(|input| input.symbol == expected.symbol) + .count(); + if matching_count != 1 { + return Err(RuntimePlanError::InvalidInputLength { + input: expected.symbol, + expected: 1, + actual: matching_count, + }); + } + if let Some(input) = inputs.iter().find(|input| input.symbol == expected.symbol) { + if input.point.len() != expected.point_arity { + return Err(RuntimePlanError::InvalidInputLength { + input: expected.symbol, + expected: expected.point_arity, + actual: input.point.len(), + }); + } + } + } + let mut store = Self::default(); + for input in inputs { + store.insert_scalar(input.symbol, input.eval); + store.insert_point(input.symbol, input.point.clone()); + } + Ok(store) + } + + pub fn seed_constants(&mut self, constants: &[FieldConstantPlan]) { + for constant in constants { + self.insert_scalar(constant.symbol, F::from_u64(constant.value as u64)); + } + } + + pub fn seed_point_zeros(&mut self, point_zeros: &[PointZeroPlan]) { + for zero in point_zeros { + self.insert_point(zero.symbol, vec![F::from_u64(0); zero.arity]); + } + } + + pub fn observe_challenge_vector( + &mut self, + plan: &TranscriptSqueezePlan, + values: &[F], + invalid_input_length: impl Fn(&'static str, usize, usize) -> E, + ) -> Result<(), E> { + self.insert_point(plan.symbol, values.to_vec()); + if matches!(plan.kind, "challenge_scalar" | "scalar") { + if values.len() != 1 { + return Err(invalid_input_length(plan.symbol, 1, values.len())); + } + self.insert_scalar(plan.symbol, values[0]); + } + Ok(()) + } + + pub fn observe_sumcheck_output( + &mut self, + instance_results: &[SumcheckInstanceResultPlan], + evals: &[SumcheckEvalPlan], + output: &StageSumcheckOutput, + normalize_point: impl Fn(&SumcheckInstanceResultPlan, Vec) -> Result, E>, + invalid_input_length: impl Fn(&'static str, usize, usize) -> E, + missing_value: impl Fn(&'static str) -> E, + ) -> Result<(), E> { + self.insert_point(output.driver, output.point.clone()); + for instance in instance_results + .iter() + .filter(|instance| instance.source == output.driver) + { + let end = instance.round_offset + instance.point_arity; + let point = output + .point + .get(instance.round_offset..end) + .ok_or_else(|| invalid_input_length(instance.symbol, end, output.point.len()))? + .to_vec(); + self.insert_point(instance.symbol, normalize_point(instance, point)?); + } + for eval in evals.iter().filter(|eval| eval.source == output.driver) { + let value = output + .evals + .iter() + .find(|value| value.name == eval.name) + .or_else(|| output.evals.get(eval.index)) + .ok_or_else(|| missing_value(eval.symbol))? + .value; + self.insert_scalar(eval.symbol, value); + self.insert_scalar(eval.name, value); + } + Ok(()) + } + + pub fn evaluate_available_points( + &mut self, + point_slices: &[PointSlicePlan], + point_concats: &[PointConcatPlan], + invalid_input_length: impl Fn(&'static str, usize, usize) -> E, + ) -> Result<(), E> { + loop { + let mut progress = 0usize; + for slice in point_slices { + if self.try_point(slice.symbol).is_some() { + continue; + } + let Some(input) = self.try_point(slice.input) else { + continue; + }; + let end = slice.offset + slice.length; + let point = input + .get(slice.offset..end) + .ok_or_else(|| invalid_input_length(slice.symbol, end, input.len()))? + .to_vec(); + self.insert_point(slice.symbol, point); + progress += 1; + } + for concat in point_concats { + if self.try_point(concat.symbol).is_some() { + continue; + } + let Some(point) = self.try_concat_point(concat) else { + continue; + }; + if point.len() != concat.arity { + return Err(invalid_input_length( + concat.symbol, + concat.arity, + point.len(), + )); + } + self.insert_point(concat.symbol, point); + progress += 1; + } + if progress == 0 { + return Ok(()); + } + } + } + + pub fn evaluate_available_field_exprs( + &mut self, + field_exprs: &[FieldExprPlan], + evaluate: impl Fn(&FieldExprPlan, &[F]) -> Result, + ) -> Result<(), E> { + loop { + let mut progress = 0usize; + for expr in field_exprs { + if self.try_scalar(expr.symbol).is_some() { + continue; + } + let Some(operands) = self.try_expr_operands(expr) else { + continue; + }; + self.insert_scalar(expr.symbol, evaluate(expr, &operands)?); + progress += 1; + } + if progress == 0 { + return Ok(()); + } + } + } + + pub fn verify_opening_equalities( + &self, + opening_equalities: &[OpeningClaimEqualityPlan], + invalid_proof: impl Fn(&'static str, &'static str) -> E, + missing_value: impl Fn(&'static str) -> E, + ) -> Result<(), E> { + for equality in opening_equalities { + match equality.mode { + "point_and_eval" => { + if self.point_or(equality.lhs, &missing_value)? + != self.point_or(equality.rhs, &missing_value)? + || self.scalar_or(equality.lhs, &missing_value)? + != self.scalar_or(equality.rhs, &missing_value)? + { + return Err(invalid_proof( + equality.symbol, + "opening claim equality failed", + )); + } + } + _ => { + return Err(invalid_proof( + equality.symbol, + "unsupported opening equality mode", + )); + } + } + } + Ok(()) + } + + pub fn insert_scalar(&mut self, symbol: &'static str, value: F) { + if let Some((_, existing)) = self.scalars.iter_mut().find(|(name, _)| *name == symbol) { + *existing = value; + } else { + self.scalars.push((symbol, value)); + } + } + + pub fn insert_point(&mut self, symbol: &'static str, point: Vec) { + if let Some((_, existing)) = self.points.iter_mut().find(|(name, _)| *name == symbol) { + *existing = point; + } else { + self.points.push((symbol, point)); + } + } + + pub fn scalar_or( + &self, + symbol: &'static str, + missing_value: impl FnOnce(&'static str) -> E, + ) -> Result { + self.try_scalar(symbol).ok_or_else(|| missing_value(symbol)) + } + + pub fn try_scalar(&self, symbol: &str) -> Option { + self.scalars + .iter() + .find(|(name, _)| *name == symbol) + .map(|(_, value)| *value) + } + + pub fn point_or( + &self, + symbol: &'static str, + missing_value: impl FnOnce(&'static str) -> E, + ) -> Result<&[F], E> { + self.try_point(symbol).ok_or_else(|| missing_value(symbol)) + } + + pub fn try_point(&self, symbol: &str) -> Option<&[F]> { + self.points + .iter() + .find(|(name, _)| *name == symbol) + .map(|(_, point)| point.as_slice()) + } + + fn try_expr_operands(&self, expr: &FieldExprPlan) -> Option> { + if expr.operands.is_empty() { + return Some(Vec::new()); + } + expr.operands + .split('|') + .map(|operand| self.try_scalar(operand)) + .collect() + } + + fn try_concat_point(&self, concat: &PointConcatPlan) -> Option> { + let mut point = Vec::with_capacity(concat.arity); + for input in symbol_list(concat.inputs) { + point.extend_from_slice(self.try_point(input)?); + } + Some(point) + } +} + +pub fn symbol_list(symbols: &'static str) -> impl Iterator { + symbols.split('|').filter(|symbol| !symbol.is_empty()) +} + +pub fn find_plan<'a, T: SymbolPlan>(plans: &'a [T], symbol: &str) -> Option<&'a T> { + plans.iter().find(|plan| plan.symbol() == symbol) +} + +pub fn find_batch<'a>( + batches: &'a [SumcheckBatchPlan], + driver: &'static str, + batch: &'static str, +) -> Result<&'a SumcheckBatchPlan, RuntimePlanError> { + find_plan(batches, batch).ok_or(RuntimePlanError::MissingBatch { driver, batch }) +} + +pub fn batch_claims<'a, C: SymbolPlan>( + claims: &'a [C], + batch: &SumcheckBatchPlan, +) -> Result, RuntimePlanError> { + symbol_list(batch.claim_operands) + .map(|symbol| { + find_plan(claims, symbol).ok_or(RuntimePlanError::MissingClaim { + batch: batch.symbol, + claim: symbol, + }) + }) + .collect() +} + +pub fn batch_claim_values( + claims: &[&C], + field_exprs: &[FieldExprPlan], + store: &mut ValueStore, +) -> Result, RuntimePlanError> { + claims + .iter() + .map(|claim| { + store.evaluate_available_field_exprs(field_exprs, evaluate_field_expr)?; + store.scalar_or(claim.claim_value(), |symbol| { + RuntimePlanError::MissingValue { symbol } + }) + }) + .collect() +} + +pub fn verify_batched_sumcheck( + driver: &'static D, + proof: &StageSumcheckOutput, + claims: &'static [C], + batches: &'static [SumcheckBatchPlan], + field_exprs: &'static [FieldExprPlan], + opening_inputs: &'static [OpeningInputPlan], + opening_claims: &'static [OpeningClaimPlan], + opening_batches: &'static [OpeningBatchPlan], + store: &mut ValueStore, + transcript: &mut T, + expected_output: Expected, + observe_output: Observe, + map_sumcheck: MapSumcheck, +) -> Result, E> +where + T: Transcript, + E: From, + C: SumcheckClaimInfo, + D: SumcheckDriverInfo, + Expected: FnOnce(&ValueStore, &[StageNamedEval], &[Fr], &[Fr]) -> Result, + Observe: FnOnce(&mut ValueStore, &StageSumcheckOutput) -> Result<(), E>, + MapSumcheck: FnOnce(&'static str, SumcheckError) -> E, +{ + if proof.driver != driver.symbol() { + return Err(RuntimePlanError::InvalidProof { + driver: driver.symbol(), + reason: "driver symbol mismatch", + } + .into()); + } + let batch = find_batch(batches, driver.symbol(), driver.batch())?; + let claims = batch_claims(claims, batch)?; + let input_claims = batch_claim_values(&claims, field_exprs, store)?; + for claim in &input_claims { + append_labeled_scalar(transcript, batch.claim_label, claim); + } + let batching_coeffs = transcript.challenge_vector(claims.len()); + let claimed_sum = input_claims + .iter() + .zip(claims.iter()) + .zip(&batching_coeffs) + .map(|((claim, plan), coefficient)| { + claim.mul_pow_2(driver.num_rounds() - plan.num_rounds()) * *coefficient + }) + .sum::(); + let claim = SumcheckClaim::new(driver.num_rounds(), driver.degree(), claimed_sum); + let round_proofs = proof + .proof + .round_polynomials + .iter() + .map(|poly| CompressedLabeledRoundPoly::new(poly, driver.round_label().as_bytes())) + .collect::>(); + let output = SumcheckVerifier::verify(&claim, &round_proofs, transcript) + .map_err(|error| map_sumcheck(driver.symbol(), error))?; + if !proof.point.is_empty() && proof.point != output.point { + return Err(RuntimePlanError::InvalidProof { + driver: driver.symbol(), + reason: "batched point mismatch", + } + .into()); + } + let expected = expected_output(store, &proof.evals, &output.point, &batching_coeffs)?; + if output.value != expected { + return Err(RuntimePlanError::InvalidProof { + driver: driver.symbol(), + reason: "batched output claim mismatch", + } + .into()); + } + let verified = StageSumcheckOutput { + driver: driver.symbol(), + point: output.point, + evals: proof.evals.clone(), + proof: proof.proof.clone(), + }; + observe_output(store, &verified)?; + append_opening_claims( + opening_inputs, + opening_claims, + opening_batches, + store, + transcript, + &verified.evals, + |batch, claim| RuntimePlanError::MissingClaim { batch, claim }, + |symbol| RuntimePlanError::MissingValue { symbol }, + )?; + Ok(verified) +} + +pub fn eval_by_name( + evals: &[StageNamedEval], + name: &'static str, +) -> Result { + evals + .iter() + .find(|eval| eval.name == name) + .map(|eval| eval.value) + .ok_or(RuntimePlanError::MissingValue { symbol: name }) +} + +pub fn indexed_evals_by_prefix( + evals: &[StageNamedEval], + prefix: &'static str, + count: usize, +) -> Result, RuntimePlanError> { + let mut values = vec![None; count]; + for eval in evals { + let Some(suffix) = eval.name.strip_prefix(prefix) else { + continue; + }; + let index = suffix + .parse::() + .map_err(|_| RuntimePlanError::InvalidProof { + driver: prefix, + reason: "invalid indexed eval suffix", + })?; + if index >= count || values[index].is_some() { + return Err(RuntimePlanError::InvalidProof { + driver: prefix, + reason: "invalid indexed eval", + }); + } + values[index] = Some(eval.value); + } + values + .into_iter() + .map(|value| value.ok_or(RuntimePlanError::MissingValue { symbol: prefix })) + .collect() +} + +pub fn indexed_evals_by_prefix_any( + evals: &[StageNamedEval], + prefix: &'static str, +) -> Result, RuntimePlanError> { + let mut indexed_values = Vec::new(); + for eval in evals { + let Some(suffix) = eval.name.strip_prefix(prefix) else { + continue; + }; + let index = suffix + .parse::() + .map_err(|_| RuntimePlanError::InvalidProof { + driver: prefix, + reason: "invalid indexed eval suffix", + })?; + if indexed_values + .iter() + .any(|(existing_index, _)| *existing_index == index) + { + return Err(RuntimePlanError::InvalidProof { + driver: prefix, + reason: "duplicate indexed eval", + }); + } + indexed_values.push((index, eval.value)); + } + if indexed_values.is_empty() { + return Err(RuntimePlanError::MissingValue { symbol: prefix }); + } + indexed_values.sort_by_key(|(index, _)| *index); + for (expected, (actual, _)) in indexed_values.iter().enumerate() { + if *actual != expected { + return Err(RuntimePlanError::InvalidProof { + driver: prefix, + reason: "non-contiguous indexed eval", + }); + } + } + Ok(indexed_values.into_iter().map(|(_, value)| value).collect()) +} + +pub fn single_operand( + symbol: &'static str, + operands: &[F], +) -> Result { + require_operand_count(symbol, 1, operands.len())?; + Ok(operands[0]) +} + +pub fn require_operand_count( + input: &'static str, + expected: usize, + actual: usize, +) -> Result<(), RuntimePlanError> { + if expected == actual { + Ok(()) + } else { + Err(RuntimePlanError::InvalidInputLength { + input, + expected, + actual, + }) + } +} + +pub fn evaluate_field_expr( + expr: &FieldExprPlan, + operands: &[F], +) -> Result { + match expr.formula { + "opening_eval" => Ok(single_operand(expr.symbol, operands)?), + "field.add" => { + require_operand_count(expr.symbol, 2, operands.len())?; + Ok(operands[0] + operands[1]) + } + "field.sub" => { + require_operand_count(expr.symbol, 2, operands.len())?; + Ok(operands[0] - operands[1]) + } + "field.mul" => { + require_operand_count(expr.symbol, 2, operands.len())?; + Ok(operands[0] * operands[1]) + } + "field.neg" => { + require_operand_count(expr.symbol, 1, operands.len())?; + Ok(-operands[0]) + } + formula => { + if let Some(exponent) = formula.strip_prefix("field.pow:") { + require_operand_count(expr.symbol, 1, operands.len())?; + let exponent = exponent.parse::().map_err(|_| { + RuntimePlanError::UnsupportedFieldExpr { + symbol: expr.symbol, + formula, + } + })?; + return Ok(pow_field(operands[0], exponent)); + } + Err(RuntimePlanError::UnsupportedFieldExpr { + symbol: expr.symbol, + formula, + }) + } + } +} + +pub fn bytecode_gamma_powers(gamma: Fr) -> [Fr; 8] { + let mut powers = [Fr::from_u64(1); 8]; + for index in 1..powers.len() { + powers[index] = powers[index - 1] * gamma; + } + powers +} + +pub fn indexed_boolean_eq(index: usize, point: &[Fr]) -> Fr { + point + .iter() + .enumerate() + .map(|(bit, value)| { + if (index >> (point.len() - 1 - bit)) & 1 == 1 { + *value + } else { + Fr::from_u64(1) - *value + } + }) + .product() +} + +pub fn field_powers(base: Fr, count: usize) -> Vec { + let mut powers = Vec::with_capacity(count); + let mut power = Fr::from_u64(1); + for _ in 0..count { + powers.push(power); + power *= base; + } + powers +} + +pub fn prefix_point<'a, F: Field>( + point: &'a [F], + length: usize, + input: &'static str, +) -> Result<&'a [F], RuntimePlanError> { + point + .get(..length) + .filter(|prefix| prefix.len() == length) + .ok_or(RuntimePlanError::InvalidInputLength { + input, + expected: length, + actual: point.len(), + }) +} + +pub fn suffix_point<'a, F: Field>( + point: &'a [F], + length: usize, + input: &'static str, +) -> Result<&'a [F], RuntimePlanError> { + point + .get(point.len().saturating_sub(length)..) + .filter(|suffix| suffix.len() == length) + .ok_or(RuntimePlanError::InvalidInputLength { + input, + expected: length, + actual: point.len(), + }) +} + +pub fn normalize_bytecode_read_raf_point( + point: &[F], + log_t: usize, + input: &'static str, +) -> Result, RuntimePlanError> { + let log_k = point + .len() + .checked_sub(log_t) + .ok_or(RuntimePlanError::InvalidInputLength { + input, + expected: log_t, + actual: point.len(), + })?; + let mut normalized = point.to_vec(); + normalized[..log_k].reverse(); + normalized[log_k..].reverse(); + Ok(normalized) +} + +pub fn normalize_instruction_read_raf_point( + point: &[F], + input: &'static str, +) -> Result, RuntimePlanError> { + const LOG_K: usize = 128; + if point.len() < LOG_K { + return Err(RuntimePlanError::InvalidInputLength { + input, + expected: LOG_K, + actual: point.len(), + }); + } + let mut normalized = point.to_vec(); + normalized[LOG_K..].reverse(); + Ok(normalized) +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage67RelationSymbols { + pub hamming_booleanity_relation: &'static str, + pub hamming_booleanity_instance: &'static str, + pub booleanity_point: &'static str, + pub stage5_instruction_ra0: &'static str, + pub booleanity_combined_point: &'static str, + pub booleanity_gamma: &'static str, + pub booleanity_instruction_ra_prefix: &'static str, + pub booleanity_bytecode_ra_prefix: &'static str, + pub booleanity_ram_ra_prefix: &'static str, + pub hamming_weight_eval: &'static str, + pub hamming_lookup_output: &'static str, + pub ram_ra_virtual_cycle: &'static str, + pub ram_ra_virtual_eval_prefix: &'static str, + pub instruction_ra_virtual_cycle: &'static str, + pub instruction_ra_virtual_eval_prefix: &'static str, + pub instruction_ra_virtual_input_prefix: &'static str, + pub instruction_ra_virtual_gamma: &'static str, + pub inc_ram_stage2: &'static str, + pub inc_ram_stage4: &'static str, + pub inc_rd_stage4: &'static str, + pub inc_rd_stage5: &'static str, + pub inc_gamma: &'static str, + pub inc_ram_eval: &'static str, + pub inc_rd_eval: &'static str, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Stage67BytecodeSymbols { + pub point: &'static str, + pub gamma: &'static str, + pub bytecode_ra_eval_prefix: &'static str, + pub entries: &'static str, + pub entry_bytecode_index: &'static str, + pub stage_gammas: [&'static str; 5], + pub stage_cycle_points: [&'static str; 5], + pub stage4_register_point: &'static str, + pub stage5_register_point: &'static str, + pub entry_rd: &'static str, + pub entry_rs1: &'static str, + pub entry_rs2: &'static str, + pub entry_lookup_table: &'static str, +} + +pub trait Stage67BytecodeEntry { + fn address(&self) -> Fr; + fn imm(&self) -> Fr; + fn circuit_flags(&self) -> &[bool; 14]; + fn rd(&self) -> Option; + fn rs1(&self) -> Option; + fn rs2(&self) -> Option; + fn lookup_table(&self) -> Option; + fn is_interleaved(&self) -> bool; + fn is_branch(&self) -> bool; + fn left_is_rs1(&self) -> bool; + fn left_is_pc(&self) -> bool; + fn right_is_rs2(&self) -> bool; + fn right_is_imm(&self) -> bool; + fn is_noop(&self) -> bool; +} + +pub fn store_scalar(store: &ValueStore, symbol: &'static str) -> Result { + store.scalar_or(symbol, |symbol| RuntimePlanError::MissingValue { symbol }) +} + +pub fn store_point<'a>( + store: &'a ValueStore, + symbol: &'static str, +) -> Result<&'a [Fr], RuntimePlanError> { + store.point_or(symbol, |symbol| RuntimePlanError::MissingValue { symbol }) +} + +pub fn stage67_trace_rounds( + instance_results: &[SumcheckInstanceResultPlan], + symbols: &Stage67RelationSymbols, +) -> Result { + instance_results + .iter() + .find(|instance| instance.relation == symbols.hamming_booleanity_relation) + .map(|instance| instance.num_rounds) + .ok_or(RuntimePlanError::MissingValue { + symbol: symbols.hamming_booleanity_instance, + }) +} + +pub fn expected_stage67_bytecode_read_raf( + entries: &[E], + entry_bytecode_index: usize, + num_lookup_tables: usize, + store: &ValueStore, + evals: &[StageNamedEval], + local_point: &[Fr], + log_t: usize, + symbols: &Stage67BytecodeSymbols, +) -> Result { + let opening_point = normalize_bytecode_read_raf_point(local_point, log_t, symbols.point)?; + let log_k = opening_point.len() - log_t; + let (r_address_prime, r_cycle_prime) = opening_point.split_at(log_k); + + let gamma = store_scalar(store, symbols.gamma)?; + let gamma_powers = bytecode_gamma_powers(gamma); + let int_eval = identity_polynomial_eval(r_address_prime); + let stage_value_evals = stage67_bytecode_stage_value_evals( + entries, + entry_bytecode_index, + num_lookup_tables, + store, + r_address_prime, + r_cycle_prime.len(), + symbols, + )?; + let stage_cycle_points = + stage67_bytecode_stage_cycle_points(store, r_cycle_prime.len(), symbols)?; + let int_contrib = [ + gamma_powers[5] * int_eval, + Fr::from_u64(0), + gamma_powers[4] * int_eval, + Fr::from_u64(0), + Fr::from_u64(0), + ]; + + let mut val = Fr::from_u64(0); + for index in 0..stage_value_evals.len() { + val += (stage_value_evals[index] + int_contrib[index]) + * EqPolynomial::::mle(&stage_cycle_points[index], r_cycle_prime) + * gamma_powers[index]; + } + + let entry_bits = (0..log_k) + .map(|index| Fr::from_u64(((entry_bytecode_index >> (log_k - 1 - index)) & 1) as u64)) + .collect::>(); + let zero_cycle = vec![Fr::from_u64(0); r_cycle_prime.len()]; + let entry_contrib = gamma_powers[7] + * EqPolynomial::::mle(&entry_bits, r_address_prime) + * EqPolynomial::::mle(&zero_cycle, r_cycle_prime); + let bytecode_ra = indexed_evals_by_prefix_any(evals, symbols.bytecode_ra_eval_prefix)? + .into_iter() + .product::(); + Ok((val + entry_contrib) * bytecode_ra) +} + +pub fn expected_stage67_booleanity( + store: &ValueStore, + evals: &[StageNamedEval], + local_point: &[Fr], + log_t: usize, + symbols: &Stage67RelationSymbols, +) -> Result { + let log_k_chunk = + local_point + .len() + .checked_sub(log_t) + .ok_or(RuntimePlanError::InvalidInputLength { + input: symbols.booleanity_point, + expected: log_t, + actual: local_point.len(), + })?; + let stage5_point = store_point(store, symbols.stage5_instruction_ra0)?; + let stage5_address_len = + stage5_point + .len() + .checked_sub(log_t) + .ok_or(RuntimePlanError::InvalidInputLength { + input: symbols.stage5_instruction_ra0, + expected: log_t, + actual: stage5_point.len(), + })?; + if stage5_address_len < log_k_chunk { + return Err(RuntimePlanError::InvalidInputLength { + input: symbols.stage5_instruction_ra0, + expected: log_k_chunk + log_t, + actual: stage5_point.len(), + }); + } + + let mut stage5_addr = stage5_point[..stage5_address_len].to_vec(); + stage5_addr.reverse(); + let mut combined_r = stage5_addr[stage5_address_len - log_k_chunk..].to_vec(); + combined_r.extend(stage5_point[stage5_address_len..].iter().rev().copied()); + if combined_r.len() != local_point.len() { + return Err(RuntimePlanError::InvalidInputLength { + input: symbols.booleanity_combined_point, + expected: local_point.len(), + actual: combined_r.len(), + }); + } + let mut verifier_point = combined_r[..log_k_chunk].to_vec(); + verifier_point.reverse(); + verifier_point.extend(combined_r[log_k_chunk..].iter().rev().copied()); + let eq_eval = EqPolynomial::::mle(local_point, &verifier_point); + + let gamma = store_scalar(store, symbols.booleanity_gamma)?; + let gamma_sq = gamma.square(); + let mut gamma_power = Fr::from_u64(1); + let mut booleanity = Fr::from_u64(0); + for ra in stage67_booleanity_evals(evals, symbols)? { + booleanity += gamma_power * (ra.square() - ra); + gamma_power *= gamma_sq; + } + Ok(eq_eval * booleanity) +} + +pub fn expected_stage67_hamming_booleanity( + store: &ValueStore, + evals: &[StageNamedEval], + local_point: &[Fr], + symbols: &Stage67RelationSymbols, +) -> Result { + let hamming = eval_by_name(evals, symbols.hamming_weight_eval)?; + let lookup_output_point = reverse_slice(store_point(store, symbols.hamming_lookup_output)?); + if lookup_output_point.len() != local_point.len() { + return Err(RuntimePlanError::InvalidInputLength { + input: symbols.hamming_lookup_output, + expected: local_point.len(), + actual: lookup_output_point.len(), + }); + } + let eq_eval = EqPolynomial::::mle(local_point, &lookup_output_point); + Ok((hamming.square() - hamming) * eq_eval) +} + +pub fn expected_stage67_ram_ra_virtual( + store: &ValueStore, + evals: &[StageNamedEval], + local_point: &[Fr], + symbols: &Stage67RelationSymbols, +) -> Result { + let r_cycle_reduced = reverse_slice(local_point); + let r_cycle = suffix_point( + store_point(store, symbols.ram_ra_virtual_cycle)?, + r_cycle_reduced.len(), + symbols.ram_ra_virtual_cycle, + )?; + let eq_eval = EqPolynomial::::mle(r_cycle, &r_cycle_reduced); + let ram_ra = indexed_evals_by_prefix_any(evals, symbols.ram_ra_virtual_eval_prefix)? + .into_iter() + .product::(); + Ok(eq_eval * ram_ra) +} + +pub fn expected_stage67_instruction_ra_virtual( + opening_inputs: &[OpeningInputPlan], + store: &ValueStore, + evals: &[StageNamedEval], + local_point: &[Fr], + symbols: &Stage67RelationSymbols, +) -> Result { + let r_cycle_reduced = reverse_slice(local_point); + let r_cycle = suffix_point( + store_point(store, symbols.instruction_ra_virtual_cycle)?, + r_cycle_reduced.len(), + symbols.instruction_ra_virtual_cycle, + )?; + let eq_eval = EqPolynomial::::mle(r_cycle, &r_cycle_reduced); + let committed_ra = + indexed_evals_by_prefix_any(evals, symbols.instruction_ra_virtual_eval_prefix)?; + let virtual_count = opening_inputs + .iter() + .filter(|input| { + input + .symbol + .starts_with(symbols.instruction_ra_virtual_input_prefix) + }) + .count(); + if virtual_count == 0 || committed_ra.len() % virtual_count != 0 { + return Err(RuntimePlanError::InvalidInputLength { + input: symbols.instruction_ra_virtual_eval_prefix, + expected: virtual_count, + actual: committed_ra.len(), + }); + } + let committed_per_virtual = committed_ra.len() / virtual_count; + let gamma = store_scalar(store, symbols.instruction_ra_virtual_gamma)?; + let mut gamma_power = Fr::from_u64(1); + let mut value = Fr::from_u64(0); + for chunk in committed_ra.chunks(committed_per_virtual) { + value += gamma_power * chunk.iter().copied().product::(); + gamma_power *= gamma; + } + Ok(eq_eval * value) +} + +pub fn expected_stage67_inc_claim_reduction( + store: &ValueStore, + evals: &[StageNamedEval], + local_point: &[Fr], + symbols: &Stage67RelationSymbols, +) -> Result { + let r_cycle_reduced = reverse_slice(local_point); + let ram_inc_stage2 = suffix_point( + store_point(store, symbols.inc_ram_stage2)?, + r_cycle_reduced.len(), + symbols.inc_ram_stage2, + )?; + let ram_inc_stage4 = suffix_point( + store_point(store, symbols.inc_ram_stage4)?, + r_cycle_reduced.len(), + symbols.inc_ram_stage4, + )?; + let rd_inc_stage4 = suffix_point( + store_point(store, symbols.inc_rd_stage4)?, + r_cycle_reduced.len(), + symbols.inc_rd_stage4, + )?; + let rd_inc_stage5 = suffix_point( + store_point(store, symbols.inc_rd_stage5)?, + r_cycle_reduced.len(), + symbols.inc_rd_stage5, + )?; + let gamma = store_scalar(store, symbols.inc_gamma)?; + let eq_ram_combined = EqPolynomial::::mle(ram_inc_stage2, &r_cycle_reduced) + + gamma * EqPolynomial::::mle(ram_inc_stage4, &r_cycle_reduced); + let eq_rd_combined = EqPolynomial::::mle(rd_inc_stage4, &r_cycle_reduced) + + gamma * EqPolynomial::::mle(rd_inc_stage5, &r_cycle_reduced); + let ram_inc = eval_by_name(evals, symbols.inc_ram_eval)?; + let rd_inc = eval_by_name(evals, symbols.inc_rd_eval)?; + Ok(ram_inc * eq_ram_combined + gamma.square() * rd_inc * eq_rd_combined) +} + +fn stage67_booleanity_evals( + evals: &[StageNamedEval], + symbols: &Stage67RelationSymbols, +) -> Result, RuntimePlanError> { + let mut values = indexed_evals_by_prefix_any(evals, symbols.booleanity_instruction_ra_prefix)?; + values.extend(indexed_evals_by_prefix_any( + evals, + symbols.booleanity_bytecode_ra_prefix, + )?); + values.extend(indexed_evals_by_prefix_any( + evals, + symbols.booleanity_ram_ra_prefix, + )?); + Ok(values) +} + +fn stage67_bytecode_stage_cycle_points( + store: &ValueStore, + log_t: usize, + symbols: &Stage67BytecodeSymbols, +) -> Result<[Vec; 5], RuntimePlanError> { + let point = |index| { + let symbol = symbols.stage_cycle_points[index]; + suffix_point(store_point(store, symbol)?, log_t, symbol).map(|point| point.to_vec()) + }; + Ok([point(0)?, point(1)?, point(2)?, point(3)?, point(4)?]) +} + +fn stage67_bytecode_stage_value_evals( + entries: &[E], + entry_bytecode_index: usize, + num_lookup_tables: usize, + store: &ValueStore, + r_address: &[Fr], + log_t: usize, + symbols: &Stage67BytecodeSymbols, +) -> Result<[Fr; 5], RuntimePlanError> { + let expected_len = + 1usize + .checked_shl(r_address.len() as u32) + .ok_or(RuntimePlanError::InvalidInputLength { + input: symbols.entries, + expected: usize::BITS as usize, + actual: r_address.len(), + })?; + if entries.len() != expected_len { + return Err(RuntimePlanError::InvalidInputLength { + input: symbols.entries, + expected: expected_len, + actual: entries.len(), + }); + } + if entry_bytecode_index >= expected_len { + return Err(RuntimePlanError::InvalidInputLength { + input: symbols.entry_bytecode_index, + expected: expected_len, + actual: entry_bytecode_index + 1, + }); + } + + let stage1_gamma_powers = field_powers(store_scalar(store, symbols.stage_gammas[0])?, 16); + let stage2_gamma_powers = field_powers(store_scalar(store, symbols.stage_gammas[1])?, 4); + let stage3_gamma_powers = field_powers(store_scalar(store, symbols.stage_gammas[2])?, 9); + let stage4_gamma_powers = field_powers(store_scalar(store, symbols.stage_gammas[3])?, 3); + let stage5_gamma_powers = field_powers( + store_scalar(store, symbols.stage_gammas[4])?, + num_lookup_tables + 2, + ); + + let stage4_register_point = + stage67_register_prefix_point(store, symbols.stage4_register_point, log_t)?; + let stage5_register_point = + stage67_register_prefix_point(store, symbols.stage5_register_point, log_t)?; + + let mut evals = [Fr::from_u64(0); 5]; + for (index, entry) in entries.iter().enumerate() { + let eq = indexed_boolean_eq(index, r_address); + let values = stage67_bytecode_entry_stage_values( + entry, + num_lookup_tables, + stage4_register_point, + stage5_register_point, + &stage1_gamma_powers, + &stage2_gamma_powers, + &stage3_gamma_powers, + &stage4_gamma_powers, + &stage5_gamma_powers, + symbols, + )?; + for stage in 0..evals.len() { + evals[stage] += eq * values[stage]; + } + } + Ok(evals) +} + +fn stage67_bytecode_entry_stage_values( + entry: &E, + num_lookup_tables: usize, + stage4_register_point: &[Fr], + stage5_register_point: &[Fr], + stage1_gamma_powers: &[Fr], + stage2_gamma_powers: &[Fr], + stage3_gamma_powers: &[Fr], + stage4_gamma_powers: &[Fr], + stage5_gamma_powers: &[Fr], + symbols: &Stage67BytecodeSymbols, +) -> Result<[Fr; 5], RuntimePlanError> { + let flags = entry.circuit_flags(); + let mut stage1 = entry.address() + entry.imm() * stage1_gamma_powers[1]; + for (flag, gamma) in flags.iter().zip(stage1_gamma_powers.iter().skip(2)) { + if *flag { + stage1 += *gamma; + } + } + + let mut stage2 = Fr::from_u64(0); + if flags[5] { + stage2 += stage2_gamma_powers[0]; + } + if entry.is_branch() { + stage2 += stage2_gamma_powers[1]; + } + if flags[6] { + stage2 += stage2_gamma_powers[2]; + } + if flags[7] { + stage2 += stage2_gamma_powers[3]; + } + + let mut stage3 = entry.imm() + entry.address() * stage3_gamma_powers[1]; + if entry.left_is_rs1() { + stage3 += stage3_gamma_powers[2]; + } + if entry.left_is_pc() { + stage3 += stage3_gamma_powers[3]; + } + if entry.right_is_rs2() { + stage3 += stage3_gamma_powers[4]; + } + if entry.right_is_imm() { + stage3 += stage3_gamma_powers[5]; + } + if entry.is_noop() { + stage3 += stage3_gamma_powers[6]; + } + if flags[7] { + stage3 += stage3_gamma_powers[7]; + } + if flags[12] { + stage3 += stage3_gamma_powers[8]; + } + + let stage4 = stage67_register_eq(entry.rd(), stage4_register_point, symbols.entry_rd)? + * stage4_gamma_powers[0] + + stage67_register_eq(entry.rs1(), stage4_register_point, symbols.entry_rs1)? + * stage4_gamma_powers[1] + + stage67_register_eq(entry.rs2(), stage4_register_point, symbols.entry_rs2)? + * stage4_gamma_powers[2]; + + let mut stage5 = stage67_register_eq(entry.rd(), stage5_register_point, symbols.entry_rd)? + * stage5_gamma_powers[0]; + if !entry.is_interleaved() { + stage5 += stage5_gamma_powers[1]; + } + if let Some(table) = entry.lookup_table() { + if table >= num_lookup_tables { + return Err(RuntimePlanError::InvalidInputLength { + input: symbols.entry_lookup_table, + expected: num_lookup_tables, + actual: table + 1, + }); + } + stage5 += stage5_gamma_powers[2 + table]; + } + + Ok([stage1, stage2, stage3, stage4, stage5]) +} + +fn stage67_register_eq( + index: Option, + point: &[Fr], + input: &'static str, +) -> Result { + let Some(index) = index else { + return Ok(Fr::from_u64(0)); + }; + let register_count = + 1usize + .checked_shl(point.len() as u32) + .ok_or(RuntimePlanError::InvalidInputLength { + input, + expected: usize::BITS as usize, + actual: point.len(), + })?; + if index >= register_count { + return Err(RuntimePlanError::InvalidInputLength { + input, + expected: register_count, + actual: index + 1, + }); + } + Ok(indexed_boolean_eq(index, point)) +} + +fn stage67_register_prefix_point<'a>( + store: &'a ValueStore, + symbol: &'static str, + log_t: usize, +) -> Result<&'a [Fr], RuntimePlanError> { + let point = store_point(store, symbol)?; + let register_len = + point + .len() + .checked_sub(log_t) + .ok_or(RuntimePlanError::InvalidInputLength { + input: symbol, + expected: log_t, + actual: point.len(), + })?; + prefix_point(point, register_len, symbol) +} + +pub fn operand_polynomial_eval(point: &[Fr], left: bool) -> Fr { + let stride_offset = usize::from(!left); + let operand_bits = point.len() / 2; + (0..operand_bits) + .map(|index| point[2 * index + stride_offset].mul_pow_2(operand_bits - 1 - index)) + .sum() +} + +pub fn identity_polynomial_eval(point: &[Fr]) -> Fr { + point + .iter() + .enumerate() + .map(|(index, value)| value.mul_pow_2(point.len() - 1 - index)) + .sum() +} + +pub fn append_labeled_scalar(transcript: &mut T, label: &'static str, scalar: &Fr) +where + T: Transcript, +{ + transcript.append(&Label(label.as_bytes())); + transcript.append(scalar); +} + +pub fn append_opening_claims( + opening_inputs: &[OpeningInputPlan], + opening_claims: &[OpeningClaimPlan], + opening_batches: &[OpeningBatchPlan], + store: &mut ValueStore, + transcript: &mut T, + evals: &[StageNamedEval], + missing_claim: impl Fn(&'static str, &'static str) -> E, + missing_value: impl Fn(&'static str) -> E, +) -> Result<(), E> +where + T: Transcript, +{ + if opening_batches.is_empty() { + for eval in evals { + append_labeled_scalar(transcript, "opening_claim", &eval.value); + } + return Ok(()); + } + let mut seen = opening_inputs + .iter() + .filter_map(|input| { + store + .try_point(input.symbol) + .map(|point| (input.claim_kind, input.oracle, point.to_vec())) + }) + .collect::>(); + for batch in opening_batches { + for symbol in symbol_list(batch.claim_operands) { + let claim = opening_claims + .iter() + .find(|claim| claim.symbol == symbol) + .ok_or_else(|| missing_claim(batch.symbol, symbol))?; + let point = store.point_or(claim.point_source, &missing_value)?.to_vec(); + if seen.iter().any(|(kind, oracle, seen_point)| { + *kind == claim.claim_kind && *oracle == claim.oracle && seen_point == &point + }) { + continue; + } + let value = store.scalar_or(claim.eval_source, &missing_value)?; + append_labeled_scalar(transcript, "opening_claim", &value); + seen.push((claim.claim_kind, claim.oracle, point)); + } + } + Ok(()) +} + +pub fn lt_polynomial_eval(x: &[Fr], y: &[Fr]) -> Fr { + let mut lt_eval = Fr::from_u64(0); + let mut eq_term = Fr::from_u64(1); + for (x_i, y_i) in x.iter().zip(y.iter()) { + lt_eval += (Fr::from_u64(1) - *x_i) * *y_i * eq_term; + eq_term *= Fr::from_u64(1) - *x_i - *y_i + *x_i * *y_i + *x_i * *y_i; + } + lt_eval +} + +pub fn pow_field(base: F, mut exponent: usize) -> F { + let mut result = F::one(); + let mut power = base; + while exponent != 0 { + if exponent & 1 == 1 { + result *= power; + } + power = power.square(); + exponent >>= 1; + } + result +} + +pub fn reverse_slice(values: &[Fr]) -> Vec { + values.iter().rev().copied().collect() +} diff --git a/crates/bolt/src/protocols/mod.rs b/crates/bolt/src/protocols/mod.rs new file mode 100644 index 0000000000..5e01eb0c28 --- /dev/null +++ b/crates/bolt/src/protocols/mod.rs @@ -0,0 +1 @@ +pub mod jolt; diff --git a/crates/bolt/src/schema.rs b/crates/bolt/src/schema.rs new file mode 100644 index 0000000000..46e963c4e7 --- /dev/null +++ b/crates/bolt/src/schema.rs @@ -0,0 +1,1466 @@ +use std::collections::BTreeSet; +use std::error::Error; +use std::fmt::{self, Display, Formatter}; + +use melior::ir::block::BlockLike; +use melior::ir::operation::OperationLike; +use melior::ir::operation::OperationResult; +use melior::ir::{Attribute, OperationRef}; + +use crate::ir::{ + string_attribute_value, symbol_attribute_value, BoltModule, Compute, Concrete, Cpu, Party, + Protocol, Role, +}; +use crate::mlir::MlirError; +use crate::pass::{verify_concrete_transcript, VerifyError}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SchemaError { + message: String, +} + +impl SchemaError { + pub(crate) fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} + +impl Display for SchemaError { + fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result { + formatter.write_str(&self.message) + } +} + +impl Error for SchemaError {} + +impl From for MlirError { + fn from(error: SchemaError) -> Self { + Self::Schema { + message: error.to_string(), + } + } +} + +impl From for SchemaError { + fn from(error: VerifyError) -> Self { + Self::new(error.to_string()) + } +} + +pub fn verify_protocol_schema(module: &BoltModule<'_, Protocol>) -> Result<(), SchemaError> { + verify_schema(module, ModulePhase::Protocol) +} + +pub fn verify_concrete_schema(module: &BoltModule<'_, Concrete>) -> Result<(), SchemaError> { + verify_schema(module, ModulePhase::Concrete)?; + verify_concrete_transcript(module)?; + Ok(()) +} + +pub fn verify_party_schema(module: &BoltModule<'_, Party>) -> Result<(), SchemaError> { + verify_schema(module, ModulePhase::Party)?; + verify_concrete_transcript(module)?; + Ok(()) +} + +pub fn verify_compute_schema(module: &BoltModule<'_, Compute>) -> Result<(), SchemaError> { + verify_schema(module, ModulePhase::Compute) +} + +pub fn verify_cpu_schema(module: &BoltModule<'_, Cpu>) -> Result<(), SchemaError> { + verify_schema(module, ModulePhase::Cpu) +} + +#[derive(Clone, Copy)] +enum ModulePhase { + Protocol, + Concrete, + Party, + Compute, + Cpu, +} + +fn verify_schema

(module: &BoltModule<'_, P>, phase: ModulePhase) -> Result<(), SchemaError> +where + P: crate::ir::Phase, +{ + let phase_attr = module + .as_mlir_module() + .as_operation() + .attribute("bolt.phase") + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| SchemaError::new("module missing required attr `bolt.phase`"))?; + if phase_attr != P::NAME { + return Err(SchemaError::new(format!( + "module phase `{phase_attr}` does not match expected `{}`", + P::NAME + ))); + } + + let mut kernel_symbols = BTreeSet::new(); + let mut kernel_refs = Vec::new(); + let role = module.role(); + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + validate_op(op, phase)?; + if matches!(role, Some(Role::Verifier)) + && matches!(phase, ModulePhase::Compute | ModulePhase::Cpu) + { + validate_verifier_lowering_op(op)?; + } + match operation_name(op).as_str() { + "compute.kernel" | "cpu.kernel" => { + let _ = kernel_symbols.insert(string_attr(op, "sym_name")?); + } + "compute.sumcheck_kernel_claim" + | "compute.sumcheck_kernel_driver" + | "cpu.sumcheck_claim" + | "cpu.sumcheck_driver" => { + kernel_refs.push(symbol_attr(op, "kernel")?); + } + _ => {} + } + } + + if matches!(phase, ModulePhase::Compute | ModulePhase::Cpu) { + for kernel in kernel_refs { + if !kernel_symbols.contains(&kernel) { + return Err(SchemaError::new(format!( + "kernel reference @{kernel} has no matching kernel definition" + ))); + } + } + } + + Ok(()) +} + +fn validate_verifier_lowering_op(operation: OperationRef<'_, '_>) -> Result<(), SchemaError> { + let name = operation_name(operation); + match name.as_str() { + "compute.kernel" + | "compute.sumcheck_claim" + | "compute.sumcheck_driver" + | "compute.sumcheck_kernel_claim" + | "compute.sumcheck_kernel_driver" + | "compute.generate_oracle" + | "compute.generate_oracle_family" + | "cpu.kernel" + | "cpu.sumcheck_claim" + | "cpu.sumcheck_driver" => Err(SchemaError::new(format!( + "verifier lowering must use verifier-specific ops, got `{name}`" + ))), + _ => Ok(()), + } +} + +fn validate_op(operation: OperationRef<'_, '_>, _phase: ModulePhase) -> Result<(), SchemaError> { + let name = operation_name(operation); + match name.as_str() { + "field.define" => require_attrs(operation, &["sym_name", "modulus_bits", "role"]), + "field.const" => { + require_attrs(operation, &["sym_name", "field", "value"])?; + require_shape(operation, 0, 1) + } + "field.zero" | "field.one" => { + require_attrs(operation, &["sym_name", "field"])?; + require_shape(operation, 0, 1) + } + "field.add" | "field.sub" | "field.mul" => { + require_attrs(operation, &["sym_name"])?; + require_shape(operation, 2, 1) + } + "field.neg" => { + require_attrs(operation, &["sym_name"])?; + require_shape(operation, 1, 1) + } + "field.pow" => { + require_attrs(operation, &["sym_name", "exponent"])?; + require_shape(operation, 1, 1) + } + "hash.function" => require_attrs(operation, &["sym_name", "algorithm"]), + "transcript.scheme" => require_attrs(operation, &["sym_name", "hash"]), + "pcs.scheme" => require_attrs(operation, &["sym_name", "field"]), + "poly.domain" => require_attrs(operation, &["sym_name", "field", "log_size"]), + "poly.point_slice" => { + require_attrs(operation, &["sym_name", "source", "offset", "length"])?; + require_shape(operation, 1, 1) + } + "poly.point_zero" => { + require_attrs(operation, &["sym_name", "field", "arity"])?; + require_shape(operation, 0, 1) + } + "poly.point_concat" => { + require_attrs(operation, &["sym_name", "layout", "arity"])?; + require_min_shape(operation, 1, 1) + } + "poly.lagrange_basis_eval" => { + require_attrs( + operation, + &["sym_name", "domain_start", "domain_size", "index"], + )?; + require_shape(operation, 1, 1) + } + "protocol.params" => require_attrs(operation, &["sym_name", "field", "pcs", "transcript"]), + "protocol.boundary" => require_attrs(operation, &["sym_name", "roles"]), + "piop.oracle" => require_attrs( + operation, + &[ + "sym_name", + "field", + "domain", + "commit_domain", + "visibility", + "layout", + ], + ), + "piop.oracle_family" => require_attrs( + operation, + &[ + "sym_name", + "ordered_oracles", + "visibility", + "count", + "domain", + ], + ), + "commit.publish_batch" => { + require_attrs(operation, &["sym_name", "oracle_family", "label"])?; + require_shape(operation, 0, 1) + } + "commit.publish_optional" => { + require_attrs(operation, &["sym_name", "oracle", "label", "skip_policy"])?; + require_shape(operation, 0, 1) + } + "pcs.commit_batch" => { + require_attrs(operation, &["sym_name", "scheme"])?; + require_shape(operation, 1, 0) + } + "transcript.absorb" | "transcript.absorb_optional" => { + require_attrs(operation, &["sym_name", "label"])?; + require_shape(operation, 2, 1) + } + "transcript.absorb_bytes" => { + require_attrs(operation, &["sym_name", "label", "payload"])?; + require_shape(operation, 1, 1) + } + "transcript.squeeze" => { + require_attrs(operation, &["sym_name", "label", "kind", "count"])?; + require_shape(operation, 1, 2) + } + "transcript.state" => { + require_attrs(operation, &["sym_name", "scheme"])?; + require_shape(operation, 0, 1) + } + "piop.stage" => { + require_attrs(operation, &["sym_name", "name", "order", "roles"])?; + require_shape(operation, 0, 1) + } + "piop.relation" => require_attrs( + operation, + &[ + "sym_name", + "kind", + "domain", + "num_rounds", + "degree", + "output_count", + ], + ), + "piop.sumcheck_claim" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + require_min_shape(operation, 1, 1) + } + "piop.opening_input" => { + require_attrs( + operation, + &[ + "sym_name", + "source_stage", + "source_claim", + "oracle", + "domain", + "point_arity", + "claim_kind", + ], + )?; + require_shape(operation, 0, 3) + } + "piop.sumcheck_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + require_min_shape(operation, 1, 1)?; + require_counted_operands(operation, 1, "ordered_claims") + } + "piop.sumcheck" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + require_shape(operation, 2, 4) + } + "piop.sumcheck_eval" => { + require_attrs( + operation, + &["sym_name", "source", "name", "index", "oracle"], + )?; + require_shape(operation, 1, 1) + } + "piop.sumcheck_instance_result" => { + require_attrs( + operation, + &[ + "sym_name", + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + require_shape(operation, 2, 2) + } + "piop.opening_claim" => { + require_attrs( + operation, + &["sym_name", "oracle", "domain", "point_arity", "claim_kind"], + )?; + require_shape(operation, 2, 1) + } + "piop.opening_claim_equal" => { + require_attrs(operation, &["sym_name", "mode"])?; + require_shape(operation, 2, 0)?; + require_opening_claim_equality(operation) + } + "piop.opening_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + ], + )?; + require_min_shape(operation, 0, 1)?; + require_counted_operands(operation, 0, "ordered_claims") + } + "party.function" => require_attrs(operation, &["sym_name", "source", "role"]), + "compute.params" => require_attrs(operation, &["sym_name", "field", "pcs", "transcript"]), + "compute.function" => require_attrs(operation, &["sym_name", "source"]), + "compute.relation" => require_attrs( + operation, + &[ + "sym_name", + "kind", + "domain", + "num_rounds", + "degree", + "output_count", + ], + ), + "compute.kernel" => require_attrs( + operation, + &["sym_name", "relation", "kind", "backend", "abi"], + ), + "compute.oracle_dense_trace" => { + require_attrs( + operation, + &[ + "sym_name", "oracle", "source", "domain", "num_vars", "padding", + ], + )?; + require_shape(operation, 0, 1) + } + "compute.oracle_one_hot_chunk" => { + require_attrs( + operation, + &[ + "sym_name", + "oracle", + "source", + "domain", + "num_vars", + "trace_num_vars", + "chunk", + "num_chunks", + "chunk_bits", + "padding", + "layout", + ], + )?; + require_shape(operation, 0, 1) + } + "compute.oracle_optional_advice" => { + require_attrs( + operation, + &[ + "sym_name", + "oracle", + "source", + "domain", + "num_vars", + "skip_policy", + ], + )?; + require_shape(operation, 0, 1) + } + "compute.oracle_ref" => { + require_attrs(operation, &["sym_name", "oracle", "domain", "num_vars"])?; + require_shape(operation, 0, 1) + } + "compute.oracle_family_init" => { + require_attrs(operation, &["sym_name", "family", "count"])?; + require_shape(operation, 0, 1) + } + "compute.oracle_family_append" => { + require_attrs(operation, &["sym_name", "family", "oracle", "index"])?; + require_shape(operation, 2, 1) + } + "compute.pcs_commit_batch" | "compute.pcs_receive_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "artifact", + "pcs", + "oracle_family", + "ordered_oracles", + "label", + "domain", + "num_vars", + "count", + ], + )?; + require_shape(operation, 1, 1) + } + "compute.pcs_commit_optional" | "compute.pcs_receive_optional" => { + require_attrs( + operation, + &[ + "sym_name", + "artifact", + "pcs", + "oracle", + "label", + "domain", + "num_vars", + "skip_policy", + ], + )?; + require_shape(operation, 1, 1) + } + "compute.transcript_init" => { + require_attrs(operation, &["sym_name", "scheme"])?; + require_shape(operation, 0, 1) + } + "compute.transcript_absorb" => { + require_attrs(operation, &["sym_name", "label", "optional"])?; + require_shape(operation, 2, 1) + } + "compute.transcript_absorb_bytes" => { + require_attrs(operation, &["sym_name", "label", "payload"])?; + require_shape(operation, 1, 1) + } + "compute.transcript_squeeze" => { + require_attrs(operation, &["sym_name", "label", "kind", "count"])?; + require_shape(operation, 1, 2) + } + "compute.opening_input" => { + require_attrs( + operation, + &[ + "sym_name", + "source_stage", + "source_claim", + "oracle", + "domain", + "point_arity", + "claim_kind", + ], + )?; + require_shape(operation, 0, 3) + } + "compute.point_slice" => { + require_attrs(operation, &["sym_name", "source", "offset", "length"])?; + require_shape(operation, 1, 1) + } + "compute.point_zero" => { + require_attrs(operation, &["sym_name", "field", "arity"])?; + require_shape(operation, 0, 1) + } + "compute.point_concat" => { + require_attrs(operation, &["sym_name", "layout", "arity"])?; + require_min_shape(operation, 1, 1) + } + "compute.field_const" => { + require_attrs(operation, &["sym_name", "field", "value"])?; + require_shape(operation, 0, 1) + } + "compute.field_zero" | "compute.field_one" => { + require_attrs(operation, &["sym_name", "field"])?; + require_shape(operation, 0, 1) + } + "compute.field_add" | "compute.field_sub" | "compute.field_mul" => { + require_attrs(operation, &["sym_name"])?; + require_shape(operation, 2, 1) + } + "compute.field_neg" => { + require_attrs(operation, &["sym_name"])?; + require_shape(operation, 1, 1) + } + "compute.field_pow" => { + require_attrs(operation, &["sym_name", "exponent"])?; + require_shape(operation, 1, 1) + } + "compute.poly_lagrange_basis_eval" => { + require_attrs( + operation, + &["sym_name", "domain_start", "domain_size", "index"], + )?; + require_shape(operation, 1, 1) + } + "compute.sumcheck_claim" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + require_min_shape(operation, 1, 1) + } + "compute.sumcheck_kernel_claim" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "kernel", + ], + )?; + require_min_shape(operation, 1, 1) + } + "compute.sumcheck_verify_claim" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + require_min_shape(operation, 1, 1) + } + "compute.sumcheck_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + require_min_shape(operation, 0, 1)?; + require_counted_operands(operation, 0, "ordered_claims") + } + "compute.sumcheck_driver" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + require_shape(operation, 2, 4) + } + "compute.sumcheck_kernel_driver" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "kernel", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + require_shape(operation, 2, 4) + } + "compute.sumcheck_verify" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + require_shape(operation, 2, 4) + } + "compute.sumcheck_eval" => { + require_attrs( + operation, + &["sym_name", "source", "name", "index", "oracle"], + )?; + require_shape(operation, 1, 1) + } + "compute.sumcheck_instance_result" => { + require_attrs( + operation, + &[ + "sym_name", + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + require_shape(operation, 2, 2) + } + "compute.opening_claim" => { + require_attrs( + operation, + &["sym_name", "oracle", "domain", "point_arity", "claim_kind"], + )?; + require_shape(operation, 2, 1) + } + "compute.opening_claim_equal" => { + require_attrs(operation, &["sym_name", "mode"])?; + require_shape(operation, 2, 0)?; + require_opening_claim_equality(operation) + } + "compute.opening_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + ], + )?; + require_min_shape(operation, 0, 1)?; + require_counted_operands(operation, 0, "ordered_claims") + } + "compute.pcs_opening_claim" => { + require_attrs( + operation, + &["sym_name", "oracle", "family", "domain", "point_arity"], + )?; + require_shape(operation, 2, 1) + } + "compute.pcs_opening_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "proof_slot", + "policy", + "count", + "ordered_claims", + ], + )?; + require_min_shape(operation, 0, 1)?; + require_counted_operands(operation, 0, "ordered_claims") + } + "compute.pcs_batch_open" | "compute.pcs_batch_verify" => { + require_attrs( + operation, + &["sym_name", "pcs", "proof_slot", "transcript_label"], + )?; + require_shape(operation, 2, 2) + } + "cpu.params" => require_attrs(operation, &["sym_name", "field", "pcs", "transcript"]), + "cpu.function" => require_attrs(operation, &["sym_name", "source"]), + "cpu.oracle_dense_trace" => { + require_attrs( + operation, + &[ + "sym_name", "oracle", "source", "domain", "num_vars", "padding", + ], + )?; + require_shape(operation, 0, 1) + } + "cpu.oracle_one_hot_chunk" => { + require_attrs( + operation, + &[ + "sym_name", + "oracle", + "source", + "domain", + "num_vars", + "trace_num_vars", + "chunk", + "num_chunks", + "chunk_bits", + "padding", + "layout", + ], + )?; + require_shape(operation, 0, 1) + } + "cpu.oracle_optional_advice" => { + require_attrs( + operation, + &[ + "sym_name", + "oracle", + "source", + "domain", + "num_vars", + "skip_policy", + ], + )?; + require_shape(operation, 0, 1) + } + "cpu.oracle_ref" => { + require_attrs(operation, &["sym_name", "oracle", "domain", "num_vars"])?; + require_shape(operation, 0, 1) + } + "cpu.oracle_family_init" => { + require_attrs(operation, &["sym_name", "family", "count"])?; + require_shape(operation, 0, 1) + } + "cpu.oracle_family_append" => { + require_attrs(operation, &["sym_name", "family", "oracle", "index"])?; + require_shape(operation, 2, 1) + } + "cpu.pcs_commit_batch" | "cpu.pcs_receive_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "artifact", + "pcs", + "oracle_family", + "ordered_oracles", + "label", + "domain", + "num_vars", + "count", + ], + )?; + require_shape(operation, 1, 1) + } + "cpu.pcs_commit_optional" | "cpu.pcs_receive_optional" => { + require_attrs( + operation, + &[ + "sym_name", + "artifact", + "pcs", + "oracle", + "label", + "domain", + "num_vars", + "skip_policy", + ], + )?; + require_shape(operation, 1, 1) + } + "cpu.transcript_init" => { + require_attrs(operation, &["sym_name", "scheme"])?; + require_shape(operation, 0, 1) + } + "cpu.transcript_absorb" => { + require_attrs(operation, &["sym_name", "label", "optional"])?; + require_shape(operation, 2, 1) + } + "cpu.transcript_absorb_bytes" => { + require_attrs(operation, &["sym_name", "label", "payload"])?; + require_shape(operation, 1, 1) + } + "cpu.transcript_squeeze" => { + require_attrs(operation, &["sym_name", "label", "kind", "count"])?; + require_shape(operation, 1, 2) + } + "cpu.opening_input" => { + require_attrs( + operation, + &[ + "sym_name", + "source_stage", + "source_claim", + "oracle", + "domain", + "point_arity", + "claim_kind", + ], + )?; + require_shape(operation, 0, 3) + } + "cpu.point_slice" => { + require_attrs(operation, &["sym_name", "source", "offset", "length"])?; + require_shape(operation, 1, 1) + } + "cpu.point_zero" => { + require_attrs(operation, &["sym_name", "field", "arity"])?; + require_shape(operation, 0, 1) + } + "cpu.point_concat" => { + require_attrs(operation, &["sym_name", "layout", "arity"])?; + require_min_shape(operation, 1, 1) + } + "cpu.field_const" => { + require_attrs(operation, &["sym_name", "field", "value"])?; + require_shape(operation, 0, 1) + } + "cpu.field_zero" | "cpu.field_one" => { + require_attrs(operation, &["sym_name", "field"])?; + require_shape(operation, 0, 1) + } + "cpu.field_add" | "cpu.field_sub" | "cpu.field_mul" => { + require_attrs(operation, &["sym_name"])?; + require_shape(operation, 2, 1) + } + "cpu.field_neg" => { + require_attrs(operation, &["sym_name"])?; + require_shape(operation, 1, 1) + } + "cpu.field_pow" => { + require_attrs(operation, &["sym_name", "exponent"])?; + require_shape(operation, 1, 1) + } + "cpu.poly_lagrange_basis_eval" => { + require_attrs( + operation, + &["sym_name", "domain_start", "domain_size", "index"], + )?; + require_shape(operation, 1, 1) + } + "cpu.kernel" => require_attrs( + operation, + &["sym_name", "relation", "kind", "backend", "abi"], + ), + "cpu.sumcheck_claim" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "kernel", + ], + )?; + require_min_shape(operation, 1, 1) + } + "cpu.sumcheck_verify_claim" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "domain", + "num_rounds", + "degree", + "claim", + "relation", + ], + )?; + require_min_shape(operation, 1, 1) + } + "cpu.sumcheck_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + "claim_label", + "round_label", + "round_schedule", + ], + )?; + require_min_shape(operation, 0, 1)?; + require_counted_operands(operation, 0, "ordered_claims") + } + "cpu.sumcheck_driver" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "kernel", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + require_shape(operation, 2, 4) + } + "cpu.sumcheck_verify" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "relation", + "policy", + "round_schedule", + "claim_label", + "round_label", + "num_rounds", + "degree", + ], + )?; + require_shape(operation, 2, 4) + } + "cpu.sumcheck_eval" => { + require_attrs( + operation, + &["sym_name", "source", "name", "index", "oracle"], + )?; + require_shape(operation, 1, 1) + } + "cpu.sumcheck_instance_result" => { + require_attrs( + operation, + &[ + "sym_name", + "source", + "claim", + "relation", + "index", + "point_arity", + "num_rounds", + "round_offset", + "point_order", + "degree", + ], + )?; + require_shape(operation, 2, 2) + } + "cpu.opening_claim" => { + require_attrs( + operation, + &["sym_name", "oracle", "domain", "point_arity", "claim_kind"], + )?; + require_shape(operation, 2, 1) + } + "cpu.opening_claim_equal" => { + require_attrs(operation, &["sym_name", "mode"])?; + require_shape(operation, 2, 0)?; + require_opening_claim_equality(operation) + } + "cpu.opening_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "stage", + "proof_slot", + "policy", + "count", + "ordered_claims", + ], + )?; + require_min_shape(operation, 0, 1)?; + require_counted_operands(operation, 0, "ordered_claims") + } + "cpu.pcs_opening_claim" => { + require_attrs( + operation, + &["sym_name", "oracle", "family", "domain", "point_arity"], + )?; + require_shape(operation, 2, 1) + } + "cpu.pcs_opening_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "proof_slot", + "policy", + "count", + "ordered_claims", + ], + )?; + require_min_shape(operation, 0, 1)?; + require_counted_operands(operation, 0, "ordered_claims") + } + "cpu.pcs_batch_open" | "cpu.pcs_batch_verify" => { + require_attrs( + operation, + &["sym_name", "pcs", "proof_slot", "transcript_label"], + )?; + require_shape(operation, 2, 2) + } + "pcs.opening_claim" => { + require_attrs( + operation, + &["sym_name", "oracle", "family", "domain", "point_arity"], + )?; + require_shape(operation, 2, 1) + } + "pcs.opening_batch" => { + require_attrs( + operation, + &[ + "sym_name", + "proof_slot", + "policy", + "count", + "ordered_claims", + ], + )?; + require_min_shape(operation, 0, 1)?; + require_counted_operands(operation, 0, "ordered_claims") + } + "pcs.batch_open" | "pcs.batch_verify" => { + require_attrs( + operation, + &["sym_name", "pcs", "proof_slot", "transcript_label"], + )?; + require_shape(operation, 2, 2) + } + _ if is_bolt_dialect_op(&name) => Err(SchemaError::new(format!( + "unknown Bolt op `{name}` in schema verifier" + ))), + _ => Ok(()), + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct OpeningClaimMetadata { + owner: String, + oracle: String, + domain: String, + point_arity: usize, + claim_kind: String, +} + +fn require_opening_claim_equality(operation: OperationRef<'_, '_>) -> Result<(), SchemaError> { + let mode = string_attr(operation, "mode")?; + if mode != "point_and_eval" { + return Err(SchemaError::new(format!( + "{} attr `mode` expected \"point_and_eval\", got \"{mode}\"", + operation_name(operation) + ))); + } + + let left = opening_claim_metadata(operation, 0)?; + let right = opening_claim_metadata(operation, 1)?; + if left.oracle != right.oracle + || left.domain != right.domain + || left.point_arity != right.point_arity + || left.claim_kind != right.claim_kind + { + return Err(SchemaError::new(format!( + "{} compares incompatible claims @{} and @{}", + operation_name(operation), + left.owner, + right.owner + ))); + } + Ok(()) +} + +fn opening_claim_metadata( + equality_op: OperationRef<'_, '_>, + operand_index: usize, +) -> Result { + let operand = equality_op.operand(operand_index).map_err(|_| { + SchemaError::new(format!( + "{} missing required operand {operand_index}", + operation_name(equality_op) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + SchemaError::new(format!( + "{} operand {operand_index} must be an op result", + operation_name(equality_op) + )) + })?; + let operation = owner.owner(); + let result_number = owner.result_number(); + let expected_result = match operation_name(operation).as_str() { + "piop.opening_input" | "compute.opening_input" | "cpu.opening_input" => 2, + "piop.opening_claim" | "compute.opening_claim" | "cpu.opening_claim" => 0, + name => { + return Err(SchemaError::new(format!( + "{} operand {operand_index} must be an opening claim, got result from `{name}`", + operation_name(equality_op) + ))); + } + }; + if result_number != expected_result { + return Err(SchemaError::new(format!( + "{} operand {operand_index} must use opening claim result {expected_result}, got result {result_number}", + operation_name(equality_op) + ))); + } + + Ok(OpeningClaimMetadata { + owner: string_attr(operation, "sym_name")?, + oracle: symbol_attr(operation, "oracle")?, + domain: symbol_attr(operation, "domain")?, + point_arity: int_attr(operation, "point_arity")?, + claim_kind: string_attr(operation, "claim_kind")?, + }) +} + +fn require_shape( + operation: OperationRef<'_, '_>, + operands: usize, + results: usize, +) -> Result<(), SchemaError> { + if operation.operand_count() != operands { + return Err(SchemaError::new(format!( + "{} expected {operands} operands, got {}", + operation_name(operation), + operation.operand_count() + ))); + } + if operation.result_count() != results { + return Err(SchemaError::new(format!( + "{} expected {results} results, got {}", + operation_name(operation), + operation.result_count() + ))); + } + Ok(()) +} + +fn require_min_shape( + operation: OperationRef<'_, '_>, + min_operands: usize, + results: usize, +) -> Result<(), SchemaError> { + if operation.operand_count() < min_operands { + return Err(SchemaError::new(format!( + "{} expected at least {min_operands} operands, got {}", + operation_name(operation), + operation.operand_count() + ))); + } + if operation.result_count() != results { + return Err(SchemaError::new(format!( + "{} expected {results} results, got {}", + operation_name(operation), + operation.result_count() + ))); + } + Ok(()) +} + +fn require_counted_operands( + operation: OperationRef<'_, '_>, + fixed_operands: usize, + ordered_attr: &str, +) -> Result<(), SchemaError> { + let count = int_attr(operation, "count")?; + let dynamic_count = operation.operand_count().saturating_sub(fixed_operands); + if count != dynamic_count { + return Err(SchemaError::new(format!( + "{} attr `count` expected {dynamic_count}, got {count}", + operation_name(operation) + ))); + } + let ordered = symbol_array_attr(operation, ordered_attr)?; + if ordered.len() != count { + return Err(SchemaError::new(format!( + "{} attr `{ordered_attr}` length {} does not match count {count}", + operation_name(operation), + ordered.len() + ))); + } + for (index, expected) in ordered.iter().enumerate() { + let operand_index = fixed_operands + index; + let actual = operand_owner_symbol(operation, operand_index)?; + if &actual != expected { + return Err(SchemaError::new(format!( + "{} operand {operand_index} expected @{expected}, got @{actual}", + operation_name(operation) + ))); + } + } + Ok(()) +} + +pub(crate) fn require_attrs( + operation: OperationRef<'_, '_>, + attrs: &[&str], +) -> Result<(), SchemaError> { + for attr in attrs { + if !operation.has_attribute(attr) { + return Err(SchemaError::new(format!( + "{} missing required attr `{attr}`", + operation_name(operation) + ))); + } + } + Ok(()) +} + +pub(crate) fn operand_owner_symbol( + operation: OperationRef<'_, '_>, + index: usize, +) -> Result { + let operand = operation.operand(index).map_err(|_| { + SchemaError::new(format!( + "{} missing required operand {index}", + operation_name(operation) + )) + })?; + let owner = OperationResult::try_from(operand).map_err(|_| { + SchemaError::new(format!( + "{} operand {index} must be an op result", + operation_name(operation) + )) + })?; + owner + .owner() + .attribute("sym_name") + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| { + SchemaError::new(format!( + "{} operand {index} owner missing sym_name", + operation_name(operation) + )) + }) +} + +pub(crate) fn require_symbol_attr_eq( + operation: OperationRef<'_, '_>, + attr: &str, + expected: &str, +) -> Result<(), SchemaError> { + let actual = symbol_attr(operation, attr)?; + if actual == expected { + Ok(()) + } else { + Err(SchemaError::new(format!( + "{} attr `{attr}` expected @{expected}, got @{actual}", + operation_name(operation) + ))) + } +} + +pub(crate) fn find_symbol<'c, P>( + module: &'c BoltModule<'_, P>, + symbol: &str, +) -> Option> +where + P: crate::ir::Phase, +{ + let mut operation = module.as_mlir_module().body().first_operation(); + while let Some(op) = operation { + operation = op.next_in_block(); + if op + .attribute("sym_name") + .ok() + .and_then(string_attribute_value) + .as_deref() + == Some(symbol) + { + return Some(op); + } + } + None +} + +pub(crate) fn symbol_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result { + operation + .attribute(attr) + .ok() + .and_then(symbol_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "symbol")) +} + +fn string_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .ok() + .and_then(string_attribute_value) + .ok_or_else(|| attr_error(operation, attr, "string")) +} + +pub(crate) fn symbol_array_attr( + operation: OperationRef<'_, '_>, + attr: &str, +) -> Result, SchemaError> { + let attribute = operation + .attribute(attr) + .map(|attribute| attribute.to_string()) + .ok() + .ok_or_else(|| attr_error(operation, attr, "symbol array"))?; + parse_symbol_array(&attribute).ok_or_else(|| attr_error(operation, attr, "symbol array")) +} + +fn parse_symbol_array(attribute: &str) -> Option> { + let inner = attribute.strip_prefix('[')?.strip_suffix(']')?.trim(); + if inner.is_empty() { + return Some(Vec::new()); + } + inner + .split(',') + .map(|item| item.trim().strip_prefix('@').map(ToOwned::to_owned)) + .collect() +} + +pub(crate) fn int_attr(operation: OperationRef<'_, '_>, attr: &str) -> Result { + operation + .attribute(attr) + .map(parse_integer_attr) + .ok() + .flatten() + .ok_or_else(|| attr_error(operation, attr, "integer")) +} + +fn parse_integer_attr(attribute: Attribute<'_>) -> Option { + attribute + .to_string() + .split_whitespace() + .next() + .and_then(|value| value.parse().ok()) +} + +fn attr_error(operation: OperationRef<'_, '_>, attr: &str, expected: &str) -> SchemaError { + SchemaError::new(format!( + "{} attr `{attr}` is not a {expected}", + operation_name(operation) + )) +} + +pub(crate) fn operation_name(operation: OperationRef<'_, '_>) -> String { + operation + .name() + .as_string_ref() + .as_str() + .unwrap_or("") + .to_owned() +} + +pub(crate) fn missing_module_op(name: &str) -> SchemaError { + SchemaError::new(format!("module missing required op `{name}`")) +} + +pub(crate) fn missing_symbol(symbol: &str) -> SchemaError { + SchemaError::new(format!("module missing required symbol @{symbol}")) +} + +fn is_bolt_dialect_op(name: &str) -> bool { + matches!( + name.split_once('.').map(|(dialect, _)| dialect), + Some( + "field" + | "poly" + | "hash" + | "transcript" + | "commit" + | "pcs" + | "protocol" + | "piop" + | "party" + | "compute" + | "cpu" + ) + ) +} diff --git a/crates/bolt/tests/commitment_ir.rs b/crates/bolt/tests/commitment_ir.rs new file mode 100644 index 0000000000..63817bba72 --- /dev/null +++ b/crates/bolt/tests/commitment_ir.rs @@ -0,0 +1,4187 @@ +#![expect( + clippy::expect_used, + clippy::unwrap_used, + reason = "integration tests use explicit panic messages" +)] + +use bolt::protocols::jolt::{ + assemble_jolt_generated_crates, assemble_jolt_workspace_generated_crates, + build_commitment_protocol, build_stage1_outer_protocol, build_stage2_protocol, + build_stage3_protocol, build_stage4_protocol, build_stage5_protocol, build_stage6_protocol, + build_stage7_protocol, build_stage8_protocol, commitment_cpu_program, emit_commitment_rust, + emit_stage1_rust, emit_stage2_rust, emit_stage3_rust, emit_stage4_rust, emit_stage5_rust, + emit_stage6_rust, emit_stage7_rust, emit_stage8_rust, jolt_artifact_config, jolt_rust_artifact, + lower_commitment_to_compute, lower_compute_to_cpu, lower_stage1_to_compute, + lower_stage2_to_compute, lower_stage3_to_compute, lower_stage4_to_compute, + lower_stage5_to_compute, lower_stage6_to_compute, lower_stage7_to_compute, + lower_stage8_to_compute, resolve_compute_kernels, stage1_cpu_program, stage2_cpu_program, + stage3_cpu_program, stage4_cpu_program, stage5_cpu_program, stage6_cpu_program, + stage7_cpu_program, stage8_cpu_program, validate_jolt_rust_artifact_imports, + verify_jolt_protocol_schema, write_jolt_generated_crates, JoltGeneratedCrate, + JoltProtocolParams, JoltProtocolStage, +}; +use bolt::{ + assemble_generated_crates, lower_piop_and_fiat_shamir, project_prover_party, + project_verifier_party, protocol_rust_artifact, validate_rust_artifact_imports, + verify_compute_schema, verify_concrete_transcript, verify_cpu_schema, verify_protocol_schema, + Concrete, Cpu, GeneratedFile, MeliorContext, ProtocolArtifactConfig, ProtocolRuntimeModule, + ProtocolStage, ProtocolStageKind, ProtocolStandaloneDependency, Role, RustSourceFile, + RustTypeRef, TextMlir, +}; +use std::fmt::Write as _; +use std::path::Path; +use std::process::Command; +use std::time::{SystemTime, UNIX_EPOCH}; + +#[test] +fn bolt_irdl_dialects_are_registered() { + let context = MeliorContext::new(); + assert!(!context.context().allow_unregistered_dialects()); + + let registered = r#" +module @registered { + "field.define"() {modulus_bits = 254 : i64, role = "scalar", sym_name = "bn254_fr"} : () -> () +} +"#; + let _ = context + .parse_module::(registered) + .expect("registered dialect op parses"); + + let unknown = r#" +module @unknown { + "unknown.dialect_op"() : () -> () +} +"#; + let _ = context + .parse_module::(unknown) + .expect_err("unknown dialect rejected"); +} + +#[test] +fn commitment_protocol_uses_bolt_semantic_dialects() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_commitment_protocol(&context, ¶ms).expect("build protocol"); + let text = protocol.to_text_mlir(); + + assert!(text.contains("\"protocol.params\"()")); + assert!(text.contains("sym_name = \"jolt.params\"")); + assert!(text.contains("trace_length = 65536")); + assert!(text.contains("num_committed = 41")); + assert!(text.contains("\"field.define\"()")); + assert!(text.contains("sym_name = \"bn254_fr\"")); + assert!(text.contains("\"poly.domain\"()")); + assert!(text.contains("sym_name = \"jolt.main_witness_commit_domain\"")); + assert!(text.contains("\"protocol.boundary\"()")); + assert!(text.contains("sym_name = \"jolt.commitment_phase\"")); + assert!(text.contains("\"piop.oracle\"()")); + assert!(text.contains("sym_name = \"InstructionRa_0\"")); + assert!(text.contains("\"piop.oracle_family\"()")); + assert!(text.contains("sym_name = \"jolt.main_witness_polys\"")); + assert!(text.contains("ordered_oracles = [@RdInc, @RamInc, @InstructionRa_0")); + assert!(text.contains("\"commit.publish_batch\"()")); + assert!(text.contains("\"pcs.commit_batch\"(%")); + assert!(!text.contains("commitment = @jolt.main_witness_commitments")); + assert!(text.contains("\"transcript.absorb\"(%")); + + let parsed = context + .parse_module::(&text) + .expect("parse protocol MLIR"); + assert!(parsed.to_text_mlir().contains("\"protocol.boundary\"")); +} + +#[test] +fn concrete_commitment_phase_threads_transcript_state() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_commitment_protocol(&context, ¶ms).expect("build protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower Fiat-Shamir state"); + verify_concrete_transcript(&concrete).expect("valid transcript state threading"); + + let text = concrete.to_text_mlir(); + assert!(text.contains("!transcript.state_type")); + assert!(text.contains("\"transcript.state\"()")); + assert!(text.contains("sym_name = \"fs0\"")); + assert!(text.contains("\"transcript.absorb\"(%")); + assert!(text.contains("\"transcript.absorb_optional\"(%")); + assert!(!text.contains("in = @fs")); + assert!(!text.contains("out = @fs")); +} + +#[test] +fn transcript_absorb_bytes_threads_and_lowers_to_cpu() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = context + .parse_module::(&transcript_absorb_bytes_protocol(¶ms)) + .expect("parse absorb-bytes protocol"); + verify_protocol_schema(&protocol).expect("absorb-bytes protocol schema is valid"); + + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower absorb-bytes protocol"); + verify_concrete_transcript(&concrete).expect("absorb-bytes threads transcript state"); + + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + // Stage4 does not have its own lowering entrypoint yet; this exercises the + // shared operation mapping that each stage lowering uses. + let compute = lower_stage3_to_compute(&context, &prover).expect("lower to compute"); + verify_compute_schema(&compute).expect("compute schema accepts absorb-bytes"); + assert!(compute + .to_text_mlir() + .contains("\"compute.transcript_absorb_bytes\"(%")); + + let kernelized = resolve_compute_kernels(&context, &compute).expect("kernelize compute"); + assert!(kernelized + .to_text_mlir() + .contains("\"compute.transcript_absorb_bytes\"(%")); + + let cpu = lower_compute_to_cpu(&context, &kernelized).expect("lower to CPU"); + verify_cpu_schema(&cpu).expect("CPU schema accepts absorb-bytes"); + let cpu_text = cpu.to_text_mlir(); + assert!(cpu_text.contains("\"cpu.transcript_absorb_bytes\"(%")); + assert!(cpu_text.contains("label = \"ram_val_check_gamma\"")); + assert!(cpu_text.contains("payload = \"\"")); +} + +#[test] +fn concrete_projects_to_party_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_commitment_protocol(&context, ¶ms).expect("build protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower Fiat-Shamir state"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_text = prover.to_text_mlir(); + let verifier_text = verifier.to_text_mlir(); + + assert!(prover_text.contains("bolt.phase = \"party\"")); + assert!(prover_text.contains("bolt.role = \"prover\"")); + assert!(prover_text.contains("\"party.function\"()")); + assert!(prover_text.contains("role = \"prover\"")); + assert!(prover_text.contains("\"transcript.absorb\"(%")); + assert!(!prover_text.contains("in = @fs")); + assert!(verifier_text.contains("bolt.phase = \"party\"")); + assert!(verifier_text.contains("bolt.role = \"verifier\"")); + assert!(verifier_text.contains("\"party.function\"()")); + assert!(verifier_text.contains("role = \"verifier\"")); + assert!(verifier_text.contains("\"transcript.absorb\"(%")); + assert!(!verifier_text.contains("in = @fs")); +} + +#[test] +fn commitment_compute_lowers_to_cpu_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_commitment_protocol(&context, ¶ms).expect("build protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower Fiat-Shamir state"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_commitment_to_compute(&context, &prover).expect("lower compute"); + let verifier_compute = + lower_commitment_to_compute(&context, &verifier).expect("lower verifier compute"); + let prover_cpu = lower_compute_to_cpu(&context, &prover_compute).expect("lower to CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_compute).expect("lower verifier to CPU"); + let compute_text = prover_compute.to_text_mlir(); + let text = prover_cpu.to_text_mlir(); + let verifier_compute_text = verifier_compute.to_text_mlir(); + let verifier_text = verifier_cpu.to_text_mlir(); + + assert!(compute_text.contains("\"compute.oracle_dense_trace\"()")); + assert!(compute_text.contains("\"compute.oracle_one_hot_chunk\"()")); + assert!(compute_text.contains("\"compute.oracle_family_append\"(%")); + assert!(compute_text.contains("\"compute.pcs_commit_batch\"(%")); + assert!(compute_text.contains("artifact = @jolt.main_witness_commitments")); + assert!(compute_text.contains("ordered_oracles = [@RdInc, @RamInc, @InstructionRa_0")); + assert!(compute_text.contains("\"compute.pcs_commit_optional\"(%")); + assert!(compute_text.contains("skip_policy = \"missing_or_zero\"")); + assert!(compute_text.contains("!compute.transcript_state")); + assert!(compute_text.contains("\"compute.transcript_absorb\"(%")); + assert!(!compute_text.contains("in = @fs")); + assert!(text.contains("\"cpu.function\"()")); + assert!(!text.contains("\"compute.function\"()")); + assert!(text.contains("\"cpu.oracle_family_append\"(%")); + assert!(text.contains("\"cpu.pcs_commit_batch\"(%")); + assert!(text.contains("\"cpu.pcs_commit_optional\"(%")); + assert!(text.contains("skip_policy = \"missing_or_zero\"")); + assert!(text.contains("!cpu.transcript_state")); + assert!(text.contains("\"cpu.transcript_absorb\"(%")); + assert!(!text.contains("in = @fs")); + assert!(verifier_compute_text.contains("\"compute.oracle_ref\"()")); + assert!(verifier_compute_text.contains("\"compute.pcs_receive_batch\"(%")); + assert!(verifier_compute_text.contains("\"compute.pcs_receive_optional\"(%")); + assert!(!verifier_compute_text.contains("\"compute.pcs_commit_batch\"(%")); + assert!(verifier_text.contains("\"cpu.pcs_receive_batch\"(%")); + assert!(verifier_text.contains("\"cpu.pcs_receive_optional\"(%")); + assert!(!verifier_text.contains("\"cpu.pcs_commit_batch\"(%")); + + let parsed = context + .parse_module::(&text) + .expect("parse CPU MLIR"); + assert!(parsed.to_text_mlir().contains("\"cpu.pcs_commit_batch\"")); + let parsed = context + .parse_module::(&verifier_text) + .expect("parse verifier CPU MLIR"); + assert!(parsed.to_text_mlir().contains("\"cpu.pcs_receive_batch\"")); +} + +#[test] +fn generic_protocol_schema_accepts_non_jolt_params() { + let context = MeliorContext::new(); + let generic = context.new_module::("generic", None); + context + .append_op( + &generic, + "protocol.params", + Some("generic.params"), + &[ + ("field", "@some_field"), + ("pcs", "@some_pcs"), + ("transcript", "@some_transcript"), + ], + ) + .expect("append generic params"); + + assert!( + generic.verify(), + "generic protocol params pass IRDL verification" + ); + verify_protocol_schema(&generic).expect("generic schema does not require Jolt attrs"); +} + +#[test] +fn protocol_schema_rejects_bad_derived_params() { + let context = MeliorContext::new(); + let bad = context.new_module::("bad", None); + let mut attrs = JoltProtocolParams::fixture().attrs(); + for (name, value) in &mut attrs { + if name == "num_committed" { + *value = "40 : i64".to_owned(); + } + } + context + .append_op_with_owned_attrs(&bad, "protocol.params", Some("jolt.params"), &attrs) + .expect("append params"); + context + .append_op( + &bad, + "piop.oracle_family", + Some("jolt.main_witness_polys"), + &[ + ( + "ordered_oracles", + "[@RdInc, @RamInc, @InstructionRa_0, @RamRa_0, @BytecodeRa_0]", + ), + ("count", "40 : i64"), + ("domain", "@jolt.trace_domain"), + ("visibility", r#""committed""#), + ], + ) + .expect("append family"); + + let error = verify_jolt_protocol_schema(&bad).expect_err("bad derived param rejected"); + assert!(error.to_string().contains("num_committed must be 41")); +} + +#[test] +fn concrete_verifier_rejects_unthreaded_transcript_absorb() { + let context = MeliorContext::new(); + let concrete = context.new_module::("bad", None); + context + .append_op( + &concrete, + "transcript.absorb", + Some("bad_absorb"), + &[ + ("label", r#""commitment""#), + ("source", "@jolt.main_witness_commitments"), + ], + ) + .expect("append bad absorb"); + + let error = verify_concrete_transcript(&concrete).expect_err("missing transcript state"); + assert!(error + .to_string() + .contains("requires a prior transcript.state result")); +} + +#[test] +fn protocol_schema_accepts_explicit_sumcheck_and_opening_flow() { + let context = MeliorContext::new(); + let protocol = context + .parse_module::(explicit_sumcheck_protocol()) + .expect("parse explicit sumcheck protocol"); + + verify_protocol_schema(&protocol).expect("explicit sumcheck protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower protocol copy to concrete"); + verify_concrete_transcript(&concrete).expect("sumcheck/opening ops thread transcript state"); + + let text = concrete.to_text_mlir(); + assert!(text.contains("\"piop.sumcheck_batch\"(%")); + assert!(text.contains("round_schedule = [2, 1, 1]")); + assert!(text.contains("\"pcs.opening_claim\"(%")); + assert!(text.contains("\"pcs.opening_batch\"(%")); + assert!(text.contains("\"pcs.batch_open\"(%")); +} + +#[test] +fn opening_batch_schema_rejects_hidden_or_reordered_claims() { + let context = MeliorContext::new(); + let protocol = context + .parse_module::(&explicit_sumcheck_protocol().replace( + "ordered_claims = [@stage1.outer.opening]", + "ordered_claims = [@wrong.opening]", + )) + .expect("parse explicit sumcheck protocol"); + + let error = verify_protocol_schema(&protocol).expect_err("opening batch order mismatch"); + assert!(error + .to_string() + .contains("expected @wrong.opening, got @stage1.outer.opening")); +} + +#[test] +fn opening_claim_equal_lowers_through_ssa_pipeline() { + let context = MeliorContext::new(); + let protocol = context + .parse_module::(&opening_claim_equal_protocol( + "LeftInstructionInput", + "LeftInstructionInput", + "point_and_eval", + )) + .expect("parse opening equality protocol"); + verify_protocol_schema(&protocol).expect("opening equality protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower protocol to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let compute = lower_stage2_to_compute(&context, &prover).expect("lower equality to compute"); + verify_compute_schema(&compute).expect("compute equality schema is valid"); + let kernelized = + resolve_compute_kernels(&context, &compute).expect("preserve equality through kernels"); + let cpu = lower_compute_to_cpu(&context, &kernelized).expect("lower equality to CPU"); + verify_cpu_schema(&cpu).expect("CPU equality schema is valid"); + + let compute_text = kernelized.to_text_mlir(); + let cpu_text = cpu.to_text_mlir(); + assert!(compute_text.contains("\"compute.opening_claim_equal\"(%")); + assert!(compute_text.contains("mode = \"point_and_eval\"")); + assert!(cpu_text.contains("\"cpu.opening_claim_equal\"(%")); + assert!(cpu_text.contains("mode = \"point_and_eval\"")); +} + +#[test] +fn opening_claim_equal_rejects_incompatible_claim_metadata() { + let context = MeliorContext::new(); + let protocol = context + .parse_module::(&opening_claim_equal_protocol( + "LeftInstructionInput", + "RightInstructionInput", + "point_and_eval", + )) + .expect("parse bad opening equality protocol"); + + let error = verify_protocol_schema(&protocol).expect_err("mismatched claims are rejected"); + assert!(error.to_string().contains("compares incompatible claims")); +} + +#[test] +fn opening_claim_equal_rejects_unsupported_mode() { + let context = MeliorContext::new(); + let protocol = context + .parse_module::(&opening_claim_equal_protocol( + "LeftInstructionInput", + "LeftInstructionInput", + "eval_only", + )) + .expect("parse bad opening equality mode"); + + let error = verify_protocol_schema(&protocol).expect_err("unsupported equality mode rejected"); + assert!(error.to_string().contains("expected \"point_and_eval\"")); +} + +#[test] +fn sumcheck_compute_lowers_to_cpu_kernel_ir() { + let context = MeliorContext::new(); + let compute = context + .parse_module::(explicit_sumcheck_compute()) + .expect("parse explicit sumcheck compute"); + + verify_compute_schema(&compute).expect("compute sumcheck schema is valid"); + let kernelized = + resolve_compute_kernels(&context, &compute).expect("resolve sumcheck compute kernels"); + verify_compute_schema(&kernelized).expect("kernelized sumcheck schema is valid"); + let cpu = lower_compute_to_cpu(&context, &kernelized).expect("lower sumcheck compute to CPU"); + verify_cpu_schema(&cpu).expect("CPU sumcheck schema is valid"); + + let text = cpu.to_text_mlir(); + assert!(text.contains("\"cpu.transcript_squeeze\"(%")); + assert!(text.contains("\"cpu.sumcheck_batch\"(%")); + assert!(text.contains("\"cpu.sumcheck_driver\"(%")); + assert!(text.contains("\"cpu.sumcheck_eval\"(%")); + assert!(text.contains("\"cpu.pcs_opening_claim\"(%")); + assert!(text.contains("\"cpu.pcs_batch_open\"(%")); + assert!(text.contains("!cpu.sumcheck_claim_type")); +} + +#[test] +fn jolt_stage1_outer_protocol_defines_virtual_claim_flow() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = + build_stage1_outer_protocol(&context, ¶ms).expect("build stage1 outer protocol"); + verify_protocol_schema(&protocol).expect("stage1 protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage1 to concrete"); + verify_concrete_transcript(&concrete).expect("stage1 transcript is threaded"); + + let text = protocol.to_text_mlir(); + assert!(text.contains("sym_name = \"stage1.uniskip.sumcheck\"")); + assert!(text.contains("sym_name = \"stage1.outer_remaining.sumcheck\"")); + assert!(text.contains("relation = @jolt.stage1.outer.uniskip")); + assert!(!text.contains("kernel = @")); + assert!(text.contains("\"piop.sumcheck_claim\"(%")); + assert!(text.contains("\"piop.sumcheck_eval\"(%")); + assert!(text.contains("\"piop.opening_claim\"(%")); + assert!(text.contains("\"piop.opening_batch\"(%")); + assert!(text.contains("count = 35 : i64")); + assert!(text.contains("ordered_claims = [@stage1.outer_remaining.opening.LeftInstructionInput")); + assert!(text.contains("oracle = @OpFlagIsLastInSequence")); + assert!(!text.contains("\"pcs.opening_claim\"")); + assert_or_update_fixture("tests/fixtures/stage1_outer_protocol.mlir", &text); +} + +#[test] +fn jolt_stage2_protocol_defines_product_ram_claim_flow() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage2_protocol(&context, ¶ms).expect("build stage2 protocol"); + verify_protocol_schema(&protocol).expect("stage2 protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage2 to concrete"); + verify_concrete_transcript(&concrete).expect("stage2 transcript is threaded"); + + let text = protocol.to_text_mlir(); + assert!(text.contains("sym_name = \"stage2.product_virtual.uniskip.sumcheck\"")); + assert!(text.contains("sym_name = \"stage2.sumcheck\"")); + assert!(text.contains("relation = @jolt.stage2.product_virtual.uniskip")); + assert!(text.contains("relation = @jolt.stage2.batched")); + assert!(text.contains("\"piop.opening_input\"()")); + assert!(text.contains("\"field.add\"(%")); + assert!(text.contains("\"poly.lagrange_basis_eval\"(%")); + assert!(text.contains("sym_name = \"stage2.ram_read_write.claim_expr\"")); + assert!(text.contains("\"piop.sumcheck_instance_result\"(%")); + assert!(text.contains("round_offset = 16 : i64")); + assert!(text.contains("\"poly.point_slice\"(%")); + assert!(text.contains("\"poly.point_concat\"(%")); + assert!(text.contains( + "ordered_claims = [@stage2.ram_read_write.input, @stage2.product_virtual.remainder.input" + )); + assert!(text.contains("ordered_claims = [@stage2.ram_read_write.opening.RamVal, @stage2.ram_read_write.opening.RamRa, @stage2.ram_read_write.opening.RamInc")); + assert!(text.contains("claim_kind = \"committed\"")); + assert!(text.contains("source_claim = @stage1.outer_remaining.opening.RamAddress")); + assert!(!text.contains("kernel = @")); + assert!(!text.contains("\"compute.")); + assert_or_update_fixture("tests/fixtures/stage2_protocol.mlir", &text); +} + +#[test] +fn jolt_stage2_lowers_to_compute_and_cpu_role_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage2_protocol(&context, ¶ms).expect("build stage2 protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage2 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage2_to_compute(&context, &prover).expect("lower prover stage2"); + let verifier_compute = + lower_stage2_to_compute(&context, &verifier).expect("lower verifier stage2"); + verify_compute_schema(&prover_compute).expect("prover stage2 compute schema is valid"); + verify_compute_schema(&verifier_compute).expect("verifier stage2 compute schema is valid"); + + let prover_compute_text = prover_compute.to_text_mlir(); + let verifier_compute_text = verifier_compute.to_text_mlir(); + assert!(prover_compute_text.contains("\"compute.opening_input\"()")); + assert!(prover_compute_text.contains("\"compute.field_add\"(%")); + assert!(prover_compute_text.contains("\"compute.poly_lagrange_basis_eval\"(%")); + assert!(prover_compute_text.contains("\"compute.point_slice\"(%")); + assert!(prover_compute_text.contains("\"compute.point_concat\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_claim\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_driver\"(%")); + assert!(!prover_compute_text.contains("kernel = @")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify_claim\"")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify\"")); + assert!(!verifier_compute_text.contains("\"compute.kernel\"")); + assert!(!verifier_compute_text.contains("kernel = @")); + + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + verify_compute_schema(&prover_kernel_compute) + .expect("prover kernelized stage2 compute schema is valid"); + verify_compute_schema(&verifier_kernel_compute) + .expect("verifier kernelized stage2 compute schema is valid"); + + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + verify_cpu_schema(&prover_cpu).expect("prover stage2 CPU schema is valid"); + verify_cpu_schema(&verifier_cpu).expect("verifier stage2 CPU schema is valid"); + + let prover_cpu_text = prover_cpu.to_text_mlir(); + let verifier_cpu_text = verifier_cpu.to_text_mlir(); + assert!(prover_cpu_text.contains("\"cpu.opening_input\"()")); + assert!(prover_cpu_text.contains("\"cpu.field_add\"(%")); + assert!(prover_cpu_text.contains("\"cpu.poly_lagrange_basis_eval\"(%")); + assert!(prover_cpu_text.contains("\"cpu.point_slice\"(%")); + assert!(prover_cpu_text.contains("\"cpu.point_concat\"(%")); + assert!(prover_cpu_text.contains("\"cpu.kernel\"()")); + assert!(prover_cpu_text.contains("kernel = @jolt.cpu.stage2.batched")); + assert!(verifier_cpu_text.contains("\"cpu.opening_input\"()")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify_claim\"")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify\"")); + assert!(!verifier_cpu_text.contains("\"cpu.kernel\"")); + assert!(!verifier_cpu_text.contains("kernel = @")); + + assert_or_update_fixture( + "tests/fixtures/stage2_prover_compute.mlir", + &prover_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/stage2_verifier_compute.mlir", + &verifier_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/stage2_prover_kernel_compute.mlir", + &prover_kernel_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/stage2_verifier_kernel_compute.mlir", + &verifier_kernel_compute.to_text_mlir(), + ); + assert_or_update_fixture("tests/fixtures/stage2_prover_cpu.mlir", &prover_cpu_text); + assert_or_update_fixture( + "tests/fixtures/stage2_verifier_cpu.mlir", + &verifier_cpu_text, + ); +} + +#[test] +fn jolt_stage3_protocol_defines_shift_instruction_register_flow() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage3_protocol(&context, ¶ms).expect("build stage3 protocol"); + verify_protocol_schema(&protocol).expect("stage3 protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage3 to concrete"); + verify_concrete_transcript(&concrete).expect("stage3 transcript is threaded"); + + let text = protocol.to_text_mlir(); + assert!(text.contains("sym_name = \"stage3.spartan_shift.input\"")); + assert!(text.contains("sym_name = \"stage3.instruction_input.input\"")); + assert!(text.contains("sym_name = \"stage3.registers_claim_reduction.input\"")); + assert!(text.contains("\"piop.opening_claim_equal\"(%")); + assert!(text.contains("\"field.add\"(%")); + assert!(text.contains("\"field.mul\"(%")); + assert!(text.contains("\"field.sub\"(%")); + assert!(text.contains("policy = \"jolt_core_stage3_aligned\"")); + assert!(text.contains("point_order = \"reverse\"")); + assert!(text.contains("ordered_claims = [@stage3.spartan_shift.input, @stage3.instruction_input.input, @stage3.registers_claim_reduction.input]")); + assert!(text.contains("ordered_claims = [@stage3.spartan_shift.opening.UnexpandedPC, @stage3.spartan_shift.opening.PC")); + assert!(!text.contains("kernel = @")); + assert!(!text.contains("\"compute.")); + assert_or_update_fixture("tests/fixtures/stage3_protocol.mlir", &text); +} + +#[test] +fn jolt_stage3_lowers_to_compute_and_cpu_role_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage3_protocol(&context, ¶ms).expect("build stage3 protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage3 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage3_to_compute(&context, &prover).expect("lower prover stage3"); + let verifier_compute = + lower_stage3_to_compute(&context, &verifier).expect("lower verifier stage3"); + verify_compute_schema(&prover_compute).expect("prover stage3 compute schema is valid"); + verify_compute_schema(&verifier_compute).expect("verifier stage3 compute schema is valid"); + + let prover_compute_text = prover_compute.to_text_mlir(); + let verifier_compute_text = verifier_compute.to_text_mlir(); + assert!(prover_compute_text.contains("\"compute.opening_input\"()")); + assert!(prover_compute_text.contains("\"compute.opening_claim_equal\"(%")); + assert!(prover_compute_text.contains("\"compute.field_add\"(%")); + assert!(prover_compute_text.contains("\"compute.field_mul\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_claim\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_driver\"(%")); + assert!(!prover_compute_text.contains("kernel = @")); + assert!(verifier_compute_text.contains("\"compute.opening_claim_equal\"(%")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify_claim\"")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify\"")); + assert!(!verifier_compute_text.contains("\"compute.kernel\"")); + assert!(!verifier_compute_text.contains("kernel = @")); + + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + verify_compute_schema(&prover_kernel_compute) + .expect("prover kernelized stage3 compute schema is valid"); + verify_compute_schema(&verifier_kernel_compute) + .expect("verifier kernelized stage3 compute schema is valid"); + + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + verify_cpu_schema(&prover_cpu).expect("prover stage3 CPU schema is valid"); + verify_cpu_schema(&verifier_cpu).expect("verifier stage3 CPU schema is valid"); + + let prover_kernel_text = prover_kernel_compute.to_text_mlir(); + let verifier_kernel_text = verifier_kernel_compute.to_text_mlir(); + let prover_cpu_text = prover_cpu.to_text_mlir(); + let verifier_cpu_text = verifier_cpu.to_text_mlir(); + assert!(prover_kernel_text.contains("kernel = @jolt.cpu.stage3.batched")); + assert!(!verifier_kernel_text.contains("kernel = @")); + assert!(prover_cpu_text.contains("\"cpu.opening_claim_equal\"(%")); + assert!(prover_cpu_text.contains("\"cpu.kernel\"()")); + assert!(prover_cpu_text.contains("kernel = @jolt.cpu.stage3.batched")); + assert!(verifier_cpu_text.contains("\"cpu.opening_claim_equal\"(%")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify_claim\"")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify\"")); + assert!(!verifier_cpu_text.contains("\"cpu.kernel\"")); + assert!(!verifier_cpu_text.contains("kernel = @")); + + assert_or_update_fixture( + "tests/fixtures/stage3_prover_compute.mlir", + &prover_compute_text, + ); + assert_or_update_fixture( + "tests/fixtures/stage3_verifier_compute.mlir", + &verifier_compute_text, + ); + assert_or_update_fixture( + "tests/fixtures/stage3_prover_kernel_compute.mlir", + &prover_kernel_text, + ); + assert_or_update_fixture( + "tests/fixtures/stage3_verifier_kernel_compute.mlir", + &verifier_kernel_text, + ); + assert_or_update_fixture("tests/fixtures/stage3_prover_cpu.mlir", &prover_cpu_text); + assert_or_update_fixture( + "tests/fixtures/stage3_verifier_cpu.mlir", + &verifier_cpu_text, + ); +} + +#[test] +fn jolt_stage4_protocol_defines_registers_and_ram_val_flow() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage4_protocol(&context, ¶ms).expect("build stage4 protocol"); + verify_protocol_schema(&protocol).expect("stage4 protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage4 to concrete"); + verify_concrete_transcript(&concrete).expect("stage4 transcript is threaded"); + + let text = protocol.to_text_mlir(); + assert!(text.contains("sym_name = \"stage4.registers_read_write.input\"")); + assert!(text.contains("sym_name = \"stage4.ram_val_check.input\"")); + assert!(text.contains("\"transcript.absorb_bytes\"(%")); + assert!(text.contains("label = \"ram_val_check_gamma\"")); + assert!(text.contains("payload = \"\"")); + assert!(text.contains("sym_name = \"stage4.input.initial_ram.RamValInit\"")); + assert!(text.contains("sym_name = \"stage4.registers.rs1_claim_consistency\"")); + assert!(text.contains("sym_name = \"stage4.registers.rs2_claim_consistency\"")); + assert!(text.contains( + "ordered_claims = [@stage4.registers_read_write.input, @stage4.ram_val_check.input]" + )); + assert!(text.contains("ordered_claims = [@stage4.registers_read_write.opening.RegistersVal")); + assert!(text.contains("@stage4.ram_val_check.opening.RamRa")); + assert!(text.contains("@stage4.ram_val_check.opening.RamInc")); + assert!(!text.contains("kernel = @")); + assert!(!text.contains("\"compute.")); + assert_or_update_fixture("tests/fixtures/stage4_protocol.mlir", &text); +} + +#[test] +fn jolt_stage4_lowers_to_compute_and_cpu_role_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage4_protocol(&context, ¶ms).expect("build stage4 protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage4 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage4_to_compute(&context, &prover).expect("lower prover stage4"); + let verifier_compute = + lower_stage4_to_compute(&context, &verifier).expect("lower verifier stage4"); + verify_compute_schema(&prover_compute).expect("prover stage4 compute schema is valid"); + verify_compute_schema(&verifier_compute).expect("verifier stage4 compute schema is valid"); + + let prover_compute_text = prover_compute.to_text_mlir(); + let verifier_compute_text = verifier_compute.to_text_mlir(); + assert!(prover_compute_text.contains("\"compute.transcript_absorb_bytes\"(%")); + assert!(prover_compute_text.contains("\"compute.opening_claim_equal\"(%")); + assert!(prover_compute_text.contains("\"compute.field_sub\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_claim\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_driver\"(%")); + assert!(verifier_compute_text.contains("\"compute.transcript_absorb_bytes\"(%")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify_claim\"")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify\"")); + assert!(!verifier_compute_text.contains("kernel = @")); + + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + verify_compute_schema(&prover_kernel_compute) + .expect("prover kernelized stage4 compute schema is valid"); + verify_compute_schema(&verifier_kernel_compute) + .expect("verifier kernelized stage4 compute schema is valid"); + + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + verify_cpu_schema(&prover_cpu).expect("prover stage4 CPU schema is valid"); + verify_cpu_schema(&verifier_cpu).expect("verifier stage4 CPU schema is valid"); + + let prover_kernel_text = prover_kernel_compute.to_text_mlir(); + let verifier_kernel_text = verifier_kernel_compute.to_text_mlir(); + let prover_cpu_text = prover_cpu.to_text_mlir(); + let verifier_cpu_text = verifier_cpu.to_text_mlir(); + assert!(prover_kernel_text.contains("kernel = @jolt.cpu.stage4.batched")); + assert!(prover_cpu_text.contains("\"cpu.transcript_absorb_bytes\"(%")); + assert!(prover_cpu_text.contains("kernel = @jolt.cpu.stage4.batched")); + assert!(verifier_cpu_text.contains("\"cpu.transcript_absorb_bytes\"(%")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify_claim\"")); + assert!(!verifier_kernel_text.contains("kernel = @")); + assert!(!verifier_cpu_text.contains("kernel = @")); + + assert_or_update_fixture( + "tests/fixtures/stage4_prover_compute.mlir", + &prover_compute_text, + ); + assert_or_update_fixture( + "tests/fixtures/stage4_verifier_compute.mlir", + &verifier_compute_text, + ); + assert_or_update_fixture( + "tests/fixtures/stage4_prover_kernel_compute.mlir", + &prover_kernel_text, + ); + assert_or_update_fixture( + "tests/fixtures/stage4_verifier_kernel_compute.mlir", + &verifier_kernel_text, + ); + assert_or_update_fixture("tests/fixtures/stage4_prover_cpu.mlir", &prover_cpu_text); + assert_or_update_fixture( + "tests/fixtures/stage4_verifier_cpu.mlir", + &verifier_cpu_text, + ); +} + +#[test] +fn jolt_stage5_protocol_defines_value_lookup_reduction_flow() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage5_protocol(&context, ¶ms).expect("build stage5 protocol"); + verify_protocol_schema(&protocol).expect("stage5 protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage5 to concrete"); + verify_concrete_transcript(&concrete).expect("stage5 transcript is threaded"); + + let text = protocol.to_text_mlir(); + assert!(text.contains("sym_name = \"stage5.instruction_read_raf.input\"")); + assert!(text.contains("sym_name = \"stage5.ram_ra_claim_reduction.input\"")); + assert!(text.contains("sym_name = \"stage5.registers_val_evaluation.input\"")); + assert!(text.contains("sym_name = \"stage5.instruction_read_raf.gamma\"")); + assert!(text.contains("sym_name = \"stage5.ram_ra_claim_reduction.gamma\"")); + assert!(text.contains("sym_name = \"stage5.instruction.lookup_output_claim_consistency\"")); + assert!(text.contains("round_schedule = [128, 16]")); + assert!(text.contains("ordered_claims = [@stage5.instruction_read_raf.input, @stage5.ram_ra_claim_reduction.input, @stage5.registers_val_evaluation.input]")); + assert!(text.contains("@stage5.instruction_read_raf.opening.LookupTableFlag_0")); + assert!(text.contains("@stage5.instruction_read_raf.opening.InstructionRa_0")); + assert!(text.contains("@stage5.instruction_read_raf.opening.InstructionRafFlag")); + assert!(text.contains("@stage5.ram_ra_claim_reduction.opening.RamRa")); + assert!(text.contains("@stage5.registers_val_evaluation.opening.RdInc")); + assert!(text.contains("@stage5.registers_val_evaluation.opening.RdWa")); + assert!(!text.contains("kernel = @")); + assert!(!text.contains("\"compute.")); +} + +#[test] +fn jolt_stage5_lowers_to_compute_and_cpu_role_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage5_protocol(&context, ¶ms).expect("build stage5 protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage5 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage5_to_compute(&context, &prover).expect("lower prover stage5"); + let verifier_compute = + lower_stage5_to_compute(&context, &verifier).expect("lower verifier stage5"); + verify_compute_schema(&prover_compute).expect("prover stage5 compute schema is valid"); + verify_compute_schema(&verifier_compute).expect("verifier stage5 compute schema is valid"); + + let prover_compute_text = prover_compute.to_text_mlir(); + let verifier_compute_text = verifier_compute.to_text_mlir(); + assert!(prover_compute_text.contains("\"compute.opening_claim_equal\"(%")); + assert!(prover_compute_text.contains("\"compute.field_pow\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_claim\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_driver\"(%")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify_claim\"")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify\"")); + assert!(!verifier_compute_text.contains("kernel = @")); + + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + verify_compute_schema(&prover_kernel_compute) + .expect("prover kernelized stage5 compute schema is valid"); + verify_compute_schema(&verifier_kernel_compute) + .expect("verifier kernelized stage5 compute schema is valid"); + + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + verify_cpu_schema(&prover_cpu).expect("prover stage5 CPU schema is valid"); + verify_cpu_schema(&verifier_cpu).expect("verifier stage5 CPU schema is valid"); + + let prover_kernel_text = prover_kernel_compute.to_text_mlir(); + let verifier_kernel_text = verifier_kernel_compute.to_text_mlir(); + let prover_cpu_text = prover_cpu.to_text_mlir(); + let verifier_cpu_text = verifier_cpu.to_text_mlir(); + assert!(prover_kernel_text.contains("kernel = @jolt.cpu.stage5.batched")); + assert!(prover_cpu_text.contains("kernel = @jolt.cpu.stage5.batched")); + assert!(prover_cpu_text.contains("point_order = \"instruction_read_raf\"")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify_claim\"")); + assert!(!verifier_kernel_text.contains("kernel = @")); + assert!(!verifier_cpu_text.contains("kernel = @")); +} + +#[test] +fn jolt_stage6_protocol_defines_bytecode_booleanity_and_virtualization_flow() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage6_protocol(&context, ¶ms).expect("build stage6 protocol"); + verify_protocol_schema(&protocol).expect("stage6 protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage6 to concrete"); + verify_concrete_transcript(&concrete).expect("stage6 transcript is threaded"); + + let text = protocol.to_text_mlir(); + assert!(text.contains("sym_name = \"stage6.bytecode_read_raf.input\"")); + assert!(text.contains("sym_name = \"stage6.booleanity.input\"")); + assert!(text.contains("sym_name = \"stage6.hamming_booleanity.input\"")); + assert!(text.contains("sym_name = \"stage6.ram_ra_virtual.input\"")); + assert!(text.contains("sym_name = \"stage6.instruction_ra_virtual.input\"")); + assert!(text.contains("sym_name = \"stage6.inc_claim_reduction.input\"")); + assert!(text.contains("sym_name = \"stage6.bytecode_read_raf.gamma\"")); + assert!(text.contains("sym_name = \"stage6.bytecode_read_raf.stage5_gamma\"")); + assert!(text.contains("sym_name = \"stage6.booleanity.gamma\"")); + assert!(text.contains("sym_name = \"stage6.instruction_ra_virtual.gamma\"")); + assert!(text.contains("sym_name = \"stage6.inc_claim_reduction.gamma\"")); + assert!(text.contains("sym_name = \"stage6.booleanity.gamma_sq_0\"")); + assert!(text.contains("source_claim = @stage2.ram_read_write.opening.RamInc")); + assert!(text.contains("source_claim = @stage4.registers_read_write.opening.RdInc")); + assert!(text.contains("source_claim = @stage5.registers_val_evaluation.opening.RdInc")); + assert!(text.contains("round_schedule = [10, 16]")); + assert!(text.contains("ordered_claims = [@stage6.bytecode_read_raf.input, @stage6.booleanity.input, @stage6.hamming_booleanity.input, @stage6.ram_ra_virtual.input, @stage6.instruction_ra_virtual.input, @stage6.inc_claim_reduction.input]")); + assert!(text.contains("@stage6.bytecode_read_raf.opening.BytecodeRa_0")); + assert!(text.contains("@stage6.booleanity.opening.InstructionRa_0")); + assert!(text.contains("@stage6.hamming_booleanity.opening.HammingWeight")); + assert!(text.contains("@stage6.ram_ra_virtual.opening.RamRa_0")); + assert!(text.contains("@stage6.instruction_ra_virtual.opening.InstructionRa_0")); + assert!(text.contains("@stage6.inc_claim_reduction.opening.RamInc")); + assert!(text.contains("@stage6.inc_claim_reduction.opening.RdInc")); + assert!(!text.contains("kernel = @")); + assert!(!text.contains("\"compute.")); +} + +#[test] +fn jolt_stage6_lowers_to_compute_and_cpu_role_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage6_protocol(&context, ¶ms).expect("build stage6 protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage6 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage6_to_compute(&context, &prover).expect("lower prover stage6"); + let verifier_compute = + lower_stage6_to_compute(&context, &verifier).expect("lower verifier stage6"); + verify_compute_schema(&prover_compute).expect("prover stage6 compute schema is valid"); + verify_compute_schema(&verifier_compute).expect("verifier stage6 compute schema is valid"); + + let prover_compute_text = prover_compute.to_text_mlir(); + let verifier_compute_text = verifier_compute.to_text_mlir(); + assert!(prover_compute_text.contains("\"compute.field_pow\"(%")); + assert!(prover_compute_text.contains("\"compute.field_zero\"()")); + assert!(prover_compute_text.contains("\"compute.sumcheck_claim\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_driver\"(%")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify_claim\"")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify\"")); + assert!(!verifier_compute_text.contains("kernel = @")); + + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + verify_compute_schema(&prover_kernel_compute) + .expect("prover kernelized stage6 compute schema is valid"); + verify_compute_schema(&verifier_kernel_compute) + .expect("verifier kernelized stage6 compute schema is valid"); + + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + verify_cpu_schema(&prover_cpu).expect("prover stage6 CPU schema is valid"); + verify_cpu_schema(&verifier_cpu).expect("verifier stage6 CPU schema is valid"); + + let prover_kernel_text = prover_kernel_compute.to_text_mlir(); + let verifier_kernel_text = verifier_kernel_compute.to_text_mlir(); + let prover_cpu_text = prover_cpu.to_text_mlir(); + let verifier_cpu_text = verifier_cpu.to_text_mlir(); + assert!(prover_kernel_text.contains("kernel = @jolt.cpu.stage6.batched")); + assert!(prover_cpu_text.contains("kernel = @jolt.cpu.stage6.batched")); + assert!(prover_cpu_text.contains("point_order = \"bytecode_read_raf\"")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify_claim\"")); + assert!(!verifier_kernel_text.contains("kernel = @")); + assert!(!verifier_cpu_text.contains("kernel = @")); +} + +#[test] +fn jolt_stage7_protocol_defines_hamming_weight_claim_reduction_flow() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage7_protocol(&context, ¶ms).expect("build stage7 protocol"); + verify_protocol_schema(&protocol).expect("stage7 protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage7 to concrete"); + verify_concrete_transcript(&concrete).expect("stage7 transcript is threaded"); + + let text = protocol.to_text_mlir(); + assert!(text.contains("sym_name = \"jolt.stage7_hamming_weight_claim_reduction_domain\"")); + assert!(text.contains("sym_name = \"jolt.stage7.hamming_weight_claim_reduction\"")); + assert!(text.contains("sym_name = \"jolt.stage7.batched\"")); + assert!(text.contains("sym_name = \"stage7.hamming_weight_claim_reduction.gamma\"")); + assert!(text.contains("sym_name = \"stage7.field.one\"")); + assert!(text.contains("sym_name = \"stage7.hamming_weight_claim_reduction.input\"")); + assert!(text.contains("round_schedule = [4]")); + assert!(text.contains("ordered_claims = [@stage7.hamming_weight_claim_reduction.input]")); + assert!(text.contains("source_claim = @stage6.booleanity.opening.InstructionRa_0")); + assert!(text.contains("source_claim = @stage6.instruction_ra_virtual.opening.InstructionRa_0")); + assert!(text.contains("source_claim = @stage6.bytecode_read_raf.opening.BytecodeRa_0")); + assert!(text.contains("source_claim = @stage6.ram_ra_virtual.opening.RamRa_0")); + assert!(text.contains("source_claim = @stage6.hamming_booleanity.opening.HammingWeight")); + assert!(text.contains("sym_name = \"stage7.hamming_weight_claim_reduction.point.cycle\"")); + assert!(text.contains("sym_name = \"stage7.hamming_weight_claim_reduction.point\"")); + assert!(text.contains("@stage7.hamming_weight_claim_reduction.opening.InstructionRa_0")); + assert!(text.contains("@stage7.hamming_weight_claim_reduction.opening.BytecodeRa_0")); + assert!(text.contains("@stage7.hamming_weight_claim_reduction.opening.RamRa_0")); + assert!(!text.contains("kernel = @")); + assert!(!text.contains("\"compute.")); +} + +#[test] +fn jolt_stage7_lowers_to_compute_and_cpu_role_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage7_protocol(&context, ¶ms).expect("build stage7 protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage7 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage7_to_compute(&context, &prover).expect("lower prover stage7"); + let verifier_compute = + lower_stage7_to_compute(&context, &verifier).expect("lower verifier stage7"); + verify_compute_schema(&prover_compute).expect("prover stage7 compute schema is valid"); + verify_compute_schema(&verifier_compute).expect("verifier stage7 compute schema is valid"); + + let prover_compute_text = prover_compute.to_text_mlir(); + let verifier_compute_text = verifier_compute.to_text_mlir(); + assert!(prover_compute_text.contains("\"compute.field_one\"()")); + assert!(prover_compute_text.contains("\"compute.field_pow\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_claim\"(%")); + assert!(prover_compute_text.contains("\"compute.sumcheck_driver\"(%")); + assert!(prover_compute_text.contains("\"compute.point_concat\"(%")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify_claim\"")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify\"")); + assert!(!verifier_compute_text.contains("kernel = @")); + + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + verify_compute_schema(&prover_kernel_compute) + .expect("prover kernelized stage7 compute schema is valid"); + verify_compute_schema(&verifier_kernel_compute) + .expect("verifier kernelized stage7 compute schema is valid"); + + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + verify_cpu_schema(&prover_cpu).expect("prover stage7 CPU schema is valid"); + verify_cpu_schema(&verifier_cpu).expect("verifier stage7 CPU schema is valid"); + + let prover_kernel_text = prover_kernel_compute.to_text_mlir(); + let verifier_kernel_text = verifier_kernel_compute.to_text_mlir(); + let prover_cpu_text = prover_cpu.to_text_mlir(); + let verifier_cpu_text = verifier_cpu.to_text_mlir(); + assert!(prover_kernel_text.contains("kernel = @jolt.cpu.stage7.batched")); + assert!(prover_cpu_text.contains("kernel = @jolt.cpu.stage7.batched")); + assert!(prover_cpu_text.contains("point_order = \"reverse\"")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify_claim\"")); + assert!(!verifier_kernel_text.contains("kernel = @")); + assert!(!verifier_cpu_text.contains("kernel = @")); +} + +#[test] +fn jolt_stage8_protocol_defines_evaluation_proof_flow() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage8_protocol(&context, ¶ms).expect("build stage8 protocol"); + verify_protocol_schema(&protocol).expect("stage8 protocol schema is valid"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage8 to concrete"); + verify_concrete_transcript(&concrete).expect("stage8 transcript is threaded"); + + let text = protocol.to_text_mlir(); + assert!(text.contains("sym_name = \"jolt.stage8\"")); + assert!(text.contains("name = \"evaluation_proof\"")); + assert!(text.contains("sym_name = \"stage8.evaluation.point_source\"")); + assert!(text.contains("source_claim = @stage7.input.stage6.booleanity.InstructionRa_0")); + assert!(text.contains("sym_name = \"stage8.evaluation.opening.RamInc\"")); + assert!(text.contains("source_claim = @stage6.inc_claim_reduction.eval.RamInc")); + assert!(text.contains("sym_name = \"stage8.evaluation.opening.InstructionRa_0\"")); + assert!( + text.contains("source_claim = @stage7.hamming_weight_claim_reduction.eval.InstructionRa_0") + ); + assert!(text.contains("\"pcs.opening_batch\"(%")); + assert!(text.contains("policy = \"jolt_stage8_joint_rlc\"")); + assert!(text.contains("transcript_label = \"rlc_claims\"")); + assert!(text.contains("ordered_claims = [@stage8.evaluation.opening.RamInc, @stage8.evaluation.opening.RdInc, @stage8.evaluation.opening.InstructionRa_0")); +} + +#[test] +fn jolt_stage8_lowers_to_compute_and_cpu_role_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_stage8_protocol(&context, ¶ms).expect("build stage8 protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage8 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage8_to_compute(&context, &prover).expect("lower prover stage8"); + let verifier_compute = + lower_stage8_to_compute(&context, &verifier).expect("lower verifier stage8"); + verify_compute_schema(&prover_compute).expect("prover stage8 compute schema is valid"); + verify_compute_schema(&verifier_compute).expect("verifier stage8 compute schema is valid"); + + let prover_compute_text = prover_compute.to_text_mlir(); + let verifier_compute_text = verifier_compute.to_text_mlir(); + assert!(prover_compute_text.contains("\"compute.pcs_opening_claim\"(%")); + assert!(prover_compute_text.contains("\"compute.pcs_opening_batch\"(%")); + assert!(prover_compute_text.contains("\"compute.pcs_batch_open\"(%")); + assert!(!prover_compute_text.contains("\"compute.pcs_batch_verify\"(%")); + assert!(verifier_compute_text.contains("\"compute.pcs_batch_verify\"(%")); + assert!(!verifier_compute_text.contains("\"compute.pcs_batch_open\"(%")); + + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + verify_compute_schema(&prover_kernel_compute) + .expect("prover kernelized stage8 compute schema is valid"); + verify_compute_schema(&verifier_kernel_compute) + .expect("verifier kernelized stage8 compute schema is valid"); + + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + verify_cpu_schema(&prover_cpu).expect("prover stage8 CPU schema is valid"); + verify_cpu_schema(&verifier_cpu).expect("verifier stage8 CPU schema is valid"); + + let prover_cpu_text = prover_cpu.to_text_mlir(); + let verifier_cpu_text = verifier_cpu.to_text_mlir(); + assert!(prover_cpu_text.contains("\"cpu.pcs_batch_open\"(%")); + assert!(verifier_cpu_text.contains("\"cpu.pcs_batch_verify\"(%")); + assert!(!prover_cpu_text.contains("kernel = @")); + assert!(!verifier_cpu_text.contains("kernel = @")); +} + +#[test] +fn stage2_rust_targets_extract_and_compile() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let (prover_cpu, verifier_cpu) = build_stage2_pipeline_cpu(&context, ¶ms); + let prover_program = stage2_cpu_program(&prover_cpu).expect("extract prover stage2 program"); + let verifier_program = + stage2_cpu_program(&verifier_cpu).expect("extract verifier stage2 program"); + + assert_eq!(prover_program.role, Role::Prover); + assert_eq!(verifier_program.role, Role::Verifier); + assert_eq!(prover_program.kernels.len(), 7); + assert!(verifier_program.kernels.is_empty()); + assert_eq!(prover_program.opening_inputs.len(), 11); + assert_eq!(prover_program.field_exprs.len(), 21); + assert_eq!(prover_program.field_constants.len(), 1); + assert_eq!(prover_program.claims.len(), 6); + assert_eq!(prover_program.drivers.len(), 2); + assert_eq!(prover_program.point_slices.len(), 1); + assert_eq!(prover_program.point_concats.len(), 1); + assert!(prover_program + .claims + .iter() + .any(|claim| claim.claim_value == "stage2.ram_read_write.claim_expr")); + assert!(prover_program + .drivers + .iter() + .any(|driver| driver.kernel.as_deref() == Some("jolt.cpu.stage2.batched"))); + assert!(verifier_program + .claims + .iter() + .all(|claim| claim.kernel.is_none() && claim.relation.is_some())); + assert!(verifier_program + .drivers + .iter() + .all(|driver| driver.kernel.is_none() && driver.relation.is_some())); + + let prover_source = emit_stage2_rust(&prover_cpu).expect("emit stage2 prover rust"); + let verifier_source = emit_stage2_rust(&verifier_cpu).expect("emit stage2 verifier rust"); + assert_eq!(prover_source.filename, "prove_stage2.rs"); + assert_eq!(verifier_source.filename, "verify_stage2.rs"); + assert!(prover_source.source.contains("jolt_stage2_ram_read_write")); + assert!(prover_source.source.contains("Stage2KernelExecutor")); + assert!(!verifier_source.source.contains("jolt_kernels")); + assert!(verifier_source.source.contains("Stage2VerifierProgramPlan")); + assert!(verifier_source.source.contains("pub fn verify_stage2")); + assert!(verifier_source.source.contains("SumcheckVerifier::verify")); + assert_or_update_fixture("tests/fixtures/prove_stage2.rs", &prover_source.source); + assert_or_update_fixture("tests/fixtures/verify_stage2.rs", &verifier_source.source); + assert_rust_source_compiles(&prover_source.filename, &prover_source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn stage3_rust_targets_extract_and_compile() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let (prover_cpu, verifier_cpu) = build_stage3_pipeline_cpu(&context, ¶ms); + let prover_program = stage3_cpu_program(&prover_cpu).expect("extract prover stage3 program"); + let verifier_program = + stage3_cpu_program(&verifier_cpu).expect("extract verifier stage3 program"); + + assert_eq!(prover_program.role, Role::Prover); + assert_eq!(verifier_program.role, Role::Verifier); + assert_eq!(prover_program.kernels.len(), 4); + assert!(verifier_program.kernels.is_empty()); + assert_eq!(prover_program.opening_inputs.len(), 12); + assert_eq!(prover_program.field_exprs.len(), 19); + assert_eq!(prover_program.field_constants.len(), 1); + assert_eq!(prover_program.opening_equalities.len(), 2); + assert_eq!(prover_program.claims.len(), 3); + assert_eq!(prover_program.drivers.len(), 1); + assert_eq!(prover_program.opening_claims.len(), 16); + assert!(prover_program + .drivers + .iter() + .any(|driver| driver.kernel.as_deref() == Some("jolt.cpu.stage3.batched"))); + assert!(verifier_program + .claims + .iter() + .all(|claim| claim.kernel.is_none() && claim.relation.is_some())); + assert!(verifier_program + .drivers + .iter() + .all(|driver| driver.kernel.is_none() && driver.relation.is_some())); + + let prover_source = emit_stage3_rust(&prover_cpu).expect("emit stage3 prover rust"); + let verifier_source = emit_stage3_rust(&verifier_cpu).expect("emit stage3 verifier rust"); + assert_eq!(prover_source.filename, "prove_stage3.rs"); + assert_eq!(verifier_source.filename, "verify_stage3.rs"); + assert!(prover_source.source.contains("jolt_stage3_spartan_shift")); + assert!(prover_source.source.contains("Stage3KernelExecutor")); + assert!(prover_source + .source + .contains("Stage3OpeningClaimEqualityPlan")); + assert!(!verifier_source.source.contains("jolt_kernels")); + assert!(verifier_source.source.contains("Stage3VerifierProgramPlan")); + assert!(verifier_source.source.contains("pub fn verify_stage3")); + assert!(verifier_source + .source + .contains("super::common::verify_batched_sumcheck")); + assert!(verifier_source + .source + .contains("Stage3OpeningClaimEqualityPlan")); + assert_or_update_fixture("tests/fixtures/prove_stage3.rs", &prover_source.source); + assert_or_update_fixture("tests/fixtures/verify_stage3.rs", &verifier_source.source); + assert_rust_source_compiles(&prover_source.filename, &prover_source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn stage4_rust_targets_extract_and_compile() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let (prover_cpu, verifier_cpu) = build_stage4_pipeline_cpu(&context, ¶ms); + let prover_program = stage4_cpu_program(&prover_cpu).expect("extract prover stage4 program"); + let verifier_program = + stage4_cpu_program(&verifier_cpu).expect("extract verifier stage4 program"); + + assert_eq!(prover_program.role, Role::Prover); + assert_eq!(verifier_program.role, Role::Verifier); + assert_eq!(prover_program.kernels.len(), 3); + assert!(verifier_program.kernels.is_empty()); + assert_eq!(prover_program.steps.len(), 4); + assert_eq!(prover_program.transcript_squeezes.len(), 2); + assert_eq!(prover_program.transcript_absorb_bytes.len(), 1); + assert_eq!(prover_program.opening_inputs.len(), 8); + assert_eq!(prover_program.field_exprs.len(), 9); + assert!(prover_program.field_constants.is_empty()); + assert_eq!(prover_program.opening_equalities.len(), 2); + assert_eq!(prover_program.claims.len(), 2); + assert_eq!(prover_program.drivers.len(), 1); + assert_eq!(prover_program.instance_results.len(), 2); + assert_eq!(prover_program.evals.len(), 7); + assert_eq!(prover_program.point_slices.len(), 2); + assert_eq!(prover_program.point_concats.len(), 1); + assert_eq!(prover_program.opening_claims.len(), 7); + assert_eq!(prover_program.opening_batches.len(), 1); + assert!(prover_program + .transcript_absorb_bytes + .iter() + .any( + |absorb| absorb.symbol == "stage4.ram_val_check.domain_separator" + && absorb.label == "ram_val_check_gamma" + && absorb.payload.is_empty() + )); + assert!(prover_program + .drivers + .iter() + .any(|driver| driver.kernel.as_deref() == Some("jolt.cpu.stage4.batched"))); + assert!(verifier_program + .claims + .iter() + .all(|claim| claim.kernel.is_none() && claim.relation.is_some())); + assert!(verifier_program + .drivers + .iter() + .all(|driver| driver.kernel.is_none() && driver.relation.is_some())); + + let prover_source = emit_stage4_rust(&prover_cpu).expect("emit stage4 prover rust"); + let verifier_source = emit_stage4_rust(&verifier_cpu).expect("emit stage4 verifier rust"); + assert_eq!(prover_source.filename, "prove_stage4.rs"); + assert_eq!(verifier_source.filename, "verify_stage4.rs"); + assert!(prover_source.source.contains("jolt_stage4_ram_val_check")); + assert!(prover_source + .source + .contains("Stage4TranscriptAbsorbBytesPlan")); + assert!(prover_source + .source + .contains("STAGE4_TRANSCRIPT_ABSORB_BYTES")); + assert!(prover_source.source.contains("Stage4KernelExecutor")); + assert!(prover_source.source.contains("execute_stage4_program")); + assert!(prover_source.source.contains("execute_stage4_prover")); + assert!(!verifier_source.source.contains("jolt_kernels")); + assert!(verifier_source + .source + .contains("Stage4TranscriptAbsorbBytesPlan")); + assert!(verifier_source + .source + .contains("relation: Some(\"jolt.stage4.batched\")")); + assert!(verifier_source.source.contains("Stage4VerifierProgramPlan")); + assert!(verifier_source.source.contains("pub fn verify_stage4")); + assert!(verifier_source.source.contains("LabelWithCount")); + assert!(verifier_source + .source + .contains("super::common::verify_batched_sumcheck")); + assert!(verifier_source.source.contains("stage4_verifier_program")); + assert_or_update_fixture("tests/fixtures/prove_stage4.rs", &prover_source.source); + assert_or_update_fixture("tests/fixtures/verify_stage4.rs", &verifier_source.source); + assert_rust_source_compiles(&prover_source.filename, &prover_source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn stage5_rust_targets_extract_and_compile() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let (prover_cpu, verifier_cpu) = build_stage5_pipeline_cpu(&context, ¶ms); + let prover_program = stage5_cpu_program(&prover_cpu).expect("extract prover stage5 program"); + let verifier_program = + stage5_cpu_program(&verifier_cpu).expect("extract verifier stage5 program"); + + assert_eq!(prover_program.role, Role::Prover); + assert_eq!(verifier_program.role, Role::Verifier); + assert_eq!(prover_program.kernels.len(), 4); + assert!(verifier_program.kernels.is_empty()); + assert_eq!(prover_program.steps.len(), 3); + assert_eq!(prover_program.transcript_squeezes.len(), 2); + assert!(prover_program.transcript_absorb_bytes.is_empty()); + assert_eq!(prover_program.opening_inputs.len(), 8); + assert_eq!(prover_program.field_exprs.len(), 10); + assert!(prover_program.field_constants.is_empty()); + assert_eq!(prover_program.opening_equalities.len(), 1); + assert_eq!(prover_program.claims.len(), 3); + assert_eq!(prover_program.drivers.len(), 1); + assert_eq!(prover_program.instance_results.len(), 3); + assert_eq!( + prover_program.evals.len(), + params.lookup_table_count + params.instruction_ra_virtual_d + 4 + ); + assert_eq!( + prover_program.point_slices.len(), + params.instruction_ra_virtual_d + 3 + ); + assert_eq!( + prover_program.point_concats.len(), + params.instruction_ra_virtual_d + 2 + ); + assert_eq!( + prover_program.opening_claims.len(), + params.lookup_table_count + params.instruction_ra_virtual_d + 4 + ); + assert_eq!(prover_program.opening_batches.len(), 1); + assert!(prover_program + .drivers + .iter() + .any(|driver| driver.kernel.as_deref() == Some("jolt.cpu.stage5.batched"))); + assert!(prover_program.instance_results.iter().any(|instance| { + instance.symbol == "stage5.instruction_read_raf.instance" + && instance.point_order == "instruction_read_raf" + })); + assert!(verifier_program + .claims + .iter() + .all(|claim| claim.kernel.is_none() && claim.relation.is_some())); + assert!(verifier_program + .drivers + .iter() + .all(|driver| driver.kernel.is_none() && driver.relation.is_some())); + + let prover_source = emit_stage5_rust(&prover_cpu).expect("emit stage5 prover rust"); + let verifier_source = emit_stage5_rust(&verifier_cpu).expect("emit stage5 verifier rust"); + assert_eq!(prover_source.filename, "prove_stage5.rs"); + assert_eq!(verifier_source.filename, "verify_stage5.rs"); + assert!(prover_source + .source + .contains("jolt_stage5_instruction_read_raf")); + assert!(prover_source.source.contains("Stage5KernelExecutor")); + assert!(prover_source.source.contains("execute_stage5_program")); + assert!(prover_source.source.contains("execute_stage5_prover")); + assert!(!verifier_source.source.contains("jolt_kernels")); + assert!(verifier_source.source.contains("Stage5VerifierProgramPlan")); + assert!(verifier_source.source.contains("pub fn verify_stage5")); + assert!(verifier_source + .source + .contains("relation: Some(\"jolt.stage5.batched\")")); + assert!(verifier_source + .source + .contains("expected_instruction_read_raf")); + assert!(verifier_source + .source + .contains("jolt.stage5.instruction_read_raf")); + assert!(verifier_source + .source + .contains("LookupTableKind::::all")); + assert!(verifier_source + .source + .contains("use jolt_lookup_tables::LookupTableKind")); + assert!(verifier_source + .source + .contains("expected_ram_ra_claim_reduction")); + assert!(verifier_source + .source + .contains("expected_registers_val_evaluation")); + assert!(verifier_source + .source + .contains("jolt.stage5.ram_ra_claim_reduction")); + assert!(verifier_source + .source + .contains("jolt.stage5.registers_val_evaluation")); + assert!(verifier_source.source.contains("LookupTableFlag_39")); + assert!(!verifier_source.source.contains("LookupTableFlag_40")); + assert!(verifier_source + .source + .contains("stage5.instruction_read_raf.eval.InstructionRa_7")); + assert!(!verifier_source + .source + .contains("stage5.instruction_read_raf.eval.InstructionRa_8")); + assert!(!verifier_source + .source + .contains("jolt.stage5.registers_read_write")); + assert!(!verifier_source.source.contains("jolt.stage5.ram_val_check")); + assert!(verifier_source + .source + .contains("super::common::verify_batched_sumcheck")); + assert!(verifier_source.source.contains("stage5_verifier_program")); + assert_rust_source_compiles(&prover_source.filename, &prover_source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn stage6_rust_targets_extract_and_compile() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let (prover_cpu, verifier_cpu) = build_stage6_pipeline_cpu(&context, ¶ms); + let prover_program = stage6_cpu_program(&prover_cpu).expect("extract prover stage6 program"); + let verifier_program = + stage6_cpu_program(&verifier_cpu).expect("extract verifier stage6 program"); + + assert_eq!(prover_program.role, Role::Prover); + assert_eq!(verifier_program.role, Role::Verifier); + assert_eq!(prover_program.kernels.len(), 7); + assert!(verifier_program.kernels.is_empty()); + assert_eq!(prover_program.steps.len(), 10); + assert_eq!(prover_program.transcript_squeezes.len(), 9); + assert!(prover_program.transcript_absorb_bytes.is_empty()); + assert_eq!(prover_program.opening_inputs.len(), 90); + assert!(prover_program.field_exprs.len() > 150); + assert_eq!(prover_program.field_constants.len(), 1); + assert!(prover_program.opening_equalities.is_empty()); + assert_eq!(prover_program.claims.len(), 6); + assert_eq!(prover_program.drivers.len(), 1); + assert_eq!(prover_program.instance_results.len(), 6); + assert_eq!( + prover_program.evals.len(), + params.bytecode_d + + params.instruction_d + + params.bytecode_d + + params.ram_d + + 1 + + params.ram_d + + params.instruction_d + + 2 + ); + assert_eq!(prover_program.point_zeros.len(), 1); + assert_eq!( + prover_program.point_slices.len(), + params.bytecode_d + 1 + params.ram_d + params.instruction_d + ); + assert_eq!( + prover_program.point_concats.len(), + params.bytecode_d + 1 + params.ram_d + params.instruction_d + ); + assert_eq!( + prover_program.opening_claims.len(), + prover_program.evals.len() + ); + assert_eq!(prover_program.opening_batches.len(), 1); + assert!(prover_program + .drivers + .iter() + .any(|driver| driver.kernel.as_deref() == Some("jolt.cpu.stage6.batched"))); + assert!(prover_program.instance_results.iter().any(|instance| { + instance.symbol == "stage6.bytecode_read_raf.instance" + && instance.point_order == "bytecode_read_raf" + })); + assert!(prover_program.instance_results.iter().any(|instance| { + instance.symbol == "stage6.booleanity.instance" + && instance.point_order == "stage6_booleanity" + })); + assert!(verifier_program.opening_inputs.iter().any(|input| { + input.symbol == "stage6.input.stage1.LookupOutput" + && input.source_stage == "stage1" + && input.source_claim == "stage1.outer_remaining.opening.LookupOutput" + })); + assert!(verifier_program.claims.iter().any(|claim| { + claim.symbol == "stage6.hamming_booleanity.input" + && claim + .input_openings + .contains(&"stage6.input.stage1.LookupOutput".to_owned()) + })); + assert!(verifier_program.claims.iter().any(|claim| { + claim.symbol == "stage6.booleanity.input" && claim.input_openings.is_empty() + })); + assert!(verifier_program + .claims + .iter() + .all(|claim| claim.kernel.is_none() && claim.relation.is_some())); + assert!(verifier_program + .drivers + .iter() + .all(|driver| driver.kernel.is_none() && driver.relation.is_some())); + + let prover_source = emit_stage6_rust(&prover_cpu).expect("emit stage6 prover rust"); + let verifier_source = emit_stage6_rust(&verifier_cpu).expect("emit stage6 verifier rust"); + assert_eq!(prover_source.filename, "prove_stage6.rs"); + assert_eq!(verifier_source.filename, "verify_stage6.rs"); + assert!(prover_source + .source + .contains("jolt_stage6_bytecode_read_raf")); + assert!(prover_source.source.contains("Stage6KernelExecutor")); + assert!(prover_source.source.contains("execute_stage6_program")); + assert!(prover_source.source.contains("execute_stage6_prover")); + assert!(!verifier_source.source.contains("jolt_kernels")); + assert!(verifier_source.source.contains("Stage6VerifierProgramPlan")); + assert!(verifier_source.source.contains("pub fn verify_stage6")); + assert!(verifier_source + .source + .contains("relation: Some(\"jolt.stage6.batched\")")); + assert!(verifier_source + .source + .contains("jolt.stage6.bytecode_read_raf")); + assert!(verifier_source.source.contains("Stage6VerifierData")); + assert!(verifier_source.source.contains("Stage6BytecodeReadRafData")); + assert!(verifier_source.source.contains("Stage6BytecodeEntry")); + assert!(verifier_source + .source + .contains("expected_bytecode_read_raf")); + assert!(verifier_source + .source + .contains("stage6.bytecode_read_raf.data")); + assert!(verifier_source.source.contains("expected_booleanity")); + assert!(verifier_source + .source + .contains("expected_hamming_booleanity")); + assert!(verifier_source + .source + .contains("jolt.stage6.inc_claim_reduction")); + assert!(verifier_source + .source + .contains("stage6.input.stage1.LookupOutput")); + assert!(verifier_source.source.contains("expected_ram_ra_virtual")); + assert!(verifier_source + .source + .contains("expected_instruction_ra_virtual")); + assert!(verifier_source + .source + .contains("expected_inc_claim_reduction")); + assert!(verifier_source + .source + .contains("stage6.bytecode_read_raf.eval.BytecodeRa_0")); + assert!(verifier_source + .source + .contains("stage6.booleanity.eval.InstructionRa_31")); + assert!(verifier_source + .source + .contains("stage6.inc_claim_reduction.eval.RdInc")); + assert!(verifier_source + .source + .contains("super::common::verify_batched_sumcheck")); + assert!(verifier_source.source.contains("stage6_verifier_program")); + assert_rust_source_compiles(&prover_source.filename, &prover_source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn stage7_rust_targets_extract_and_compile() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let total_ra = params.instruction_d + params.bytecode_d + params.ram_d; + let (prover_cpu, verifier_cpu) = build_stage7_pipeline_cpu(&context, ¶ms); + let prover_program = stage7_cpu_program(&prover_cpu).expect("extract prover stage7 program"); + let verifier_program = + stage7_cpu_program(&verifier_cpu).expect("extract verifier stage7 program"); + + assert_eq!(prover_program.role, Role::Prover); + assert_eq!(verifier_program.role, Role::Verifier); + assert_eq!(prover_program.kernels.len(), 2); + assert!(verifier_program.kernels.is_empty()); + assert_eq!(prover_program.steps.len(), 2); + assert_eq!(prover_program.transcript_squeezes.len(), 1); + assert!(prover_program.transcript_absorb_bytes.is_empty()); + assert_eq!(prover_program.opening_inputs.len(), 1 + 2 * total_ra); + assert_eq!(prover_program.field_constants.len(), 1); + assert!(prover_program.field_exprs.len() >= 3 * total_ra); + assert_eq!(prover_program.claims.len(), 1); + assert_eq!(prover_program.batches.len(), 1); + assert_eq!(prover_program.drivers.len(), 1); + assert_eq!(prover_program.instance_results.len(), 1); + assert_eq!(prover_program.evals.len(), total_ra); + assert!(prover_program.point_zeros.is_empty()); + assert_eq!(prover_program.point_slices.len(), 1); + assert_eq!(prover_program.point_concats.len(), 1); + assert_eq!(prover_program.opening_claims.len(), total_ra); + assert_eq!(prover_program.opening_batches.len(), 1); + assert!(prover_program + .drivers + .iter() + .any(|driver| driver.kernel.as_deref() == Some("jolt.cpu.stage7.batched"))); + assert!(prover_program.claims.iter().any(|claim| { + claim.symbol == "stage7.hamming_weight_claim_reduction.input" + && claim.kernel.as_deref() == Some("jolt.cpu.stage7.hamming_weight_claim_reduction") + })); + assert!(prover_program.opening_claims.iter().any(|claim| { + claim.symbol == "stage7.hamming_weight_claim_reduction.opening.InstructionRa_0" + && claim.point_source == "stage7.hamming_weight_claim_reduction.point" + })); + assert!(verifier_program + .claims + .iter() + .all(|claim| claim.kernel.is_none() && claim.relation.is_some())); + assert!(verifier_program + .drivers + .iter() + .all(|driver| driver.kernel.is_none() && driver.relation.is_some())); + + let prover_source = emit_stage7_rust(&prover_cpu).expect("emit stage7 prover rust"); + let verifier_source = emit_stage7_rust(&verifier_cpu).expect("emit stage7 verifier rust"); + assert_eq!(prover_source.filename, "prove_stage7.rs"); + assert_eq!(verifier_source.filename, "verify_stage7.rs"); + assert!(prover_source + .source + .contains("jolt_stage7_hamming_weight_claim_reduction")); + assert!(prover_source.source.contains("Stage7KernelExecutor")); + assert!(prover_source.source.contains("execute_stage7_program")); + assert!(prover_source.source.contains("execute_stage7_prover")); + assert!(!verifier_source.source.contains("jolt_kernels")); + assert!(verifier_source.source.contains("Stage7VerifierProgramPlan")); + assert!(verifier_source.source.contains("pub fn verify_stage7")); + assert!(verifier_source + .source + .contains("relation: Some(\"jolt.stage7.batched\")")); + assert!(verifier_source + .source + .contains("jolt.stage7.hamming_weight_claim_reduction")); + assert!(verifier_source + .source + .contains("expected_hamming_weight_claim_reduction")); + assert!(verifier_source + .source + .contains("stage7.input.stage6.booleanity.InstructionRa_0")); + assert!(verifier_source + .source + .contains("stage7.hamming_weight_claim_reduction.eval.InstructionRa_0")); + assert!(verifier_source + .source + .contains("super::common::verify_batched_sumcheck")); + assert!(verifier_source.source.contains("stage7_verifier_program")); + assert_rust_source_compiles(&prover_source.filename, &prover_source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn stage8_rust_targets_extract_and_compile() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let expected_claims = params.num_committed; + let (prover_cpu, verifier_cpu) = build_stage8_pipeline_cpu(&context, ¶ms); + let prover_program = stage8_cpu_program(&prover_cpu).expect("extract prover stage8 program"); + let verifier_program = + stage8_cpu_program(&verifier_cpu).expect("extract verifier stage8 program"); + + assert_eq!(prover_program.role, Role::Prover); + assert_eq!(verifier_program.role, Role::Verifier); + assert_eq!(prover_program.opening_inputs.len(), expected_claims + 1); + assert_eq!(prover_program.opening_claims.len(), expected_claims); + assert_eq!(prover_program.opening_batches.len(), 1); + assert_eq!(prover_program.pcs_proofs.len(), 1); + assert_eq!(prover_program.pcs_proofs[0].mode, "open"); + assert_eq!(verifier_program.pcs_proofs[0].mode, "verify"); + assert_eq!( + prover_program.opening_batches[0].ordered_claims, + prover_program.opening_batches[0].claim_operands + ); + assert!(prover_program.opening_claims.iter().any(|claim| { + claim.symbol == "stage8.evaluation.opening.RamInc" + && claim.source_claim == "stage6.inc_claim_reduction.eval.RamInc" + })); + assert!(prover_program.opening_claims.iter().any(|claim| { + claim.symbol == "stage8.evaluation.opening.InstructionRa_0" + && claim.source_claim == "stage7.hamming_weight_claim_reduction.eval.InstructionRa_0" + })); + + let prover_source = emit_stage8_rust(&prover_cpu).expect("emit stage8 prover rust"); + let verifier_source = emit_stage8_rust(&verifier_cpu).expect("emit stage8 verifier rust"); + assert_eq!(prover_source.filename, "prove_stage8.rs"); + assert_eq!(verifier_source.filename, "verify_stage8.rs"); + assert!(prover_source.source.contains("pub const STAGE8_PROGRAM")); + assert!(prover_source + .source + .contains("stage8.evaluation.point_source")); + assert!(prover_source.source.contains("jolt_stage8_joint_rlc")); + assert!(prover_source + .source + .contains("stage6.inc_claim_reduction.eval.RamInc")); + assert!(prover_source + .source + .contains("stage7.hamming_weight_claim_reduction.eval.InstructionRa_0")); + assert!(verifier_source.source.contains("mode: \"verify\"")); + assert_rust_source_compiles(&prover_source.filename, &prover_source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn stage4_generated_artifact_crates_compile_in_isolation() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let (prover_cpu, verifier_cpu) = build_stage4_pipeline_cpu(&context, ¶ms); + let stage = ProtocolStage::new("stage4", "stage4", 4, ProtocolStageKind::Proof); + let config = jolt_artifact_config(); + let artifacts = vec![ + protocol_rust_artifact( + &config, + stage.clone(), + Role::Prover, + emit_stage4_rust(&prover_cpu).expect("emit stage4 prover"), + ), + protocol_rust_artifact( + &config, + stage, + Role::Verifier, + emit_stage4_rust(&verifier_cpu).expect("emit stage4 verifier"), + ), + ]; + for artifact in &artifacts { + validate_jolt_rust_artifact_imports(artifact).expect("stage4 import policy"); + } + if !generated_jolt_runtime_available() { + return; + } + + let output_root = new_temp_dir("bolt_stage4_generated_crates"); + let dependency_root = workspace_root().join("crates"); + let generated_crates = + assemble_jolt_generated_crates(artifacts, &dependency_root.display().to_string()) + .expect("assemble stage4 crates"); + write_jolt_generated_crates(&generated_crates, &output_root) + .expect("write stage4 generated crates"); + redirect_generated_prover_to_generated_verifier(&output_root, &dependency_root); + for generated in &generated_crates { + assert_generated_crate_manifest_compiles(&output_root, &generated.crate_name); + } + let _ = std::fs::remove_dir_all(output_root); +} + +#[test] +fn generic_artifact_assembly_supports_non_jolt_protocol_config() { + let config = non_jolt_artifact_config(); + let stage = ProtocolStage::new("alpha", "alpha", 1, ProtocolStageKind::Proof); + let artifacts = vec![ + protocol_rust_artifact( + &config, + stage.clone(), + Role::Prover, + non_jolt_alpha_prover_source(), + ), + protocol_rust_artifact( + &config, + stage, + Role::Verifier, + non_jolt_alpha_verifier_source(), + ), + ]; + assert_eq!( + artifacts[0].path, "acme-prover/src/stages/alpha.rs", + "generic artifact path should derive from config and stage module" + ); + assert_eq!( + artifacts[1].path, "acme-verifier/src/stages/alpha.rs", + "generic artifact path should derive from config and stage module" + ); + for artifact in &artifacts { + validate_rust_artifact_imports(&config, artifact).expect("generic import policy"); + } + + let generated = + assemble_generated_crates(&config, artifacts, "../deps").expect("assemble generic crates"); + let prover = generated + .iter() + .find(|generated| generated.crate_name == "acme-prover") + .expect("generated prover crate"); + let verifier = generated + .iter() + .find(|generated| generated.crate_name == "acme-verifier") + .expect("generated verifier crate"); + + let prover_manifest = prover + .files + .iter() + .find(|file| file.path == "Cargo.toml") + .expect("prover manifest") + .source + .as_str(); + assert!(prover_manifest.contains("name = \"acme-prover\"")); + assert!(prover_manifest.contains("acme-verifier = { path = \"../deps/acme-verifier\" }")); + assert!(prover_manifest.contains("serde = { version = \"1\", default-features = false }")); + assert!(!prover_manifest.contains("serde = { path = ")); + assert!(prover.files.iter().any(|file| file.path == "src/prover.rs")); + assert!(prover + .files + .iter() + .any(|file| file.path == "src/stages/alpha.rs")); + + let verifier_stages = verifier + .files + .iter() + .find(|file| file.path == "src/stages/mod.rs") + .expect("verifier stages module") + .source + .as_str(); + assert!(verifier_stages.contains("pub mod shared;")); + assert!(verifier_stages.contains("pub mod alpha;")); + assert!(verifier + .files + .iter() + .any(|file| file.path == "src/verifier.rs")); + assert!(verifier + .files + .iter() + .any(|file| file.path == "src/stages/shared.rs")); + + let generated_surface = generated + .iter() + .flat_map(|generated| generated.files.iter()) + .map(|file| file.source.as_str()) + .collect::>() + .join("\n"); + assert!( + !generated_surface.contains("jolt") && !generated_surface.contains("Jolt"), + "generic artifact assembly leaked Jolt names into a non-Jolt protocol fixture" + ); + assert!( + !generated_surface.contains("ark-bn254") && !generated_surface.contains("arkworks-algebra"), + "generic artifact assembly leaked Jolt/arkworks standalone manifest patches into a non-Jolt protocol fixture" + ); + assert!(generated_surface.contains("pub const TRANSCRIPT_LABEL: &[u8] = b\"acme transcript\";")); + assert!(generated_surface.contains("crate::stages::shared::StageProof")); + assert!(generated_surface.contains("pub fn prove_acme")); + assert!(generated_surface.contains("pub fn verify_acme")); +} + +#[test] +fn generated_jolt_artifacts_have_uniform_crate_layout_and_import_rules() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let (commitment_prover_cpu, commitment_verifier_cpu) = + build_commitment_pipeline_cpu(&context, ¶ms); + let (stage1_prover_cpu, stage1_verifier_cpu) = build_stage1_pipeline_cpu(&context, ¶ms); + let (stage2_prover_cpu, stage2_verifier_cpu) = build_stage2_pipeline_cpu(&context, ¶ms); + let (stage3_prover_cpu, stage3_verifier_cpu) = build_stage3_pipeline_cpu(&context, ¶ms); + let (stage4_prover_cpu, stage4_verifier_cpu) = build_stage4_pipeline_cpu(&context, ¶ms); + let (stage5_prover_cpu, stage5_verifier_cpu) = build_stage5_pipeline_cpu(&context, ¶ms); + let (stage6_prover_cpu, stage6_verifier_cpu) = build_stage6_pipeline_cpu(&context, ¶ms); + let (stage7_prover_cpu, stage7_verifier_cpu) = build_stage7_pipeline_cpu(&context, ¶ms); + let (stage8_prover_cpu, stage8_verifier_cpu) = build_stage8_pipeline_cpu(&context, ¶ms); + + let emitted = [ + ( + JoltProtocolStage::Commitment, + Role::Prover, + emit_commitment_rust(&commitment_prover_cpu).expect("emit commitment prover"), + ), + ( + JoltProtocolStage::Commitment, + Role::Verifier, + emit_commitment_rust(&commitment_verifier_cpu).expect("emit commitment verifier"), + ), + ( + JoltProtocolStage::Stage1Outer, + Role::Prover, + emit_stage1_rust(&stage1_prover_cpu).expect("emit stage1 prover"), + ), + ( + JoltProtocolStage::Stage1Outer, + Role::Verifier, + emit_stage1_rust(&stage1_verifier_cpu).expect("emit stage1 verifier"), + ), + ( + JoltProtocolStage::Stage2, + Role::Prover, + emit_stage2_rust(&stage2_prover_cpu).expect("emit stage2 prover"), + ), + ( + JoltProtocolStage::Stage2, + Role::Verifier, + emit_stage2_rust(&stage2_verifier_cpu).expect("emit stage2 verifier"), + ), + ( + JoltProtocolStage::Stage3, + Role::Prover, + emit_stage3_rust(&stage3_prover_cpu).expect("emit stage3 prover"), + ), + ( + JoltProtocolStage::Stage3, + Role::Verifier, + emit_stage3_rust(&stage3_verifier_cpu).expect("emit stage3 verifier"), + ), + ( + JoltProtocolStage::Stage4, + Role::Prover, + emit_stage4_rust(&stage4_prover_cpu).expect("emit stage4 prover"), + ), + ( + JoltProtocolStage::Stage4, + Role::Verifier, + emit_stage4_rust(&stage4_verifier_cpu).expect("emit stage4 verifier"), + ), + ( + JoltProtocolStage::Stage5, + Role::Prover, + emit_stage5_rust(&stage5_prover_cpu).expect("emit stage5 prover"), + ), + ( + JoltProtocolStage::Stage5, + Role::Verifier, + emit_stage5_rust(&stage5_verifier_cpu).expect("emit stage5 verifier"), + ), + ( + JoltProtocolStage::Stage6, + Role::Prover, + emit_stage6_rust(&stage6_prover_cpu).expect("emit stage6 prover"), + ), + ( + JoltProtocolStage::Stage6, + Role::Verifier, + emit_stage6_rust(&stage6_verifier_cpu).expect("emit stage6 verifier"), + ), + ( + JoltProtocolStage::Stage7, + Role::Prover, + emit_stage7_rust(&stage7_prover_cpu).expect("emit stage7 prover"), + ), + ( + JoltProtocolStage::Stage7, + Role::Verifier, + emit_stage7_rust(&stage7_verifier_cpu).expect("emit stage7 verifier"), + ), + ( + JoltProtocolStage::Stage8, + Role::Prover, + emit_stage8_rust(&stage8_prover_cpu).expect("emit stage8 prover"), + ), + ( + JoltProtocolStage::Stage8, + Role::Verifier, + emit_stage8_rust(&stage8_verifier_cpu).expect("emit stage8 verifier"), + ), + ]; + let artifacts = emitted + .into_iter() + .map(|(stage, role, source)| { + let artifact = jolt_rust_artifact(stage, role, source).expect("canonical artifact"); + validate_jolt_rust_artifact_imports(&artifact).expect("artifact import policy"); + artifact + }) + .collect::>(); + + let paths = artifacts + .iter() + .map(|artifact| artifact.path.as_str()) + .collect::>(); + assert_eq!( + paths, + vec![ + "jolt-prover/src/stages/commitment.rs", + "jolt-verifier/src/stages/commitment.rs", + "jolt-prover/src/stages/stage1_outer.rs", + "jolt-verifier/src/stages/stage1_outer.rs", + "jolt-prover/src/stages/stage2.rs", + "jolt-verifier/src/stages/stage2.rs", + "jolt-prover/src/stages/stage3.rs", + "jolt-verifier/src/stages/stage3.rs", + "jolt-prover/src/stages/stage4.rs", + "jolt-verifier/src/stages/stage4.rs", + "jolt-prover/src/stages/stage5.rs", + "jolt-verifier/src/stages/stage5.rs", + "jolt-prover/src/stages/stage6.rs", + "jolt-verifier/src/stages/stage6.rs", + "jolt-prover/src/stages/stage7.rs", + "jolt-verifier/src/stages/stage7.rs", + "jolt-prover/src/stages/stage8.rs", + "jolt-verifier/src/stages/stage8.rs", + ] + ); + assert!(artifacts + .iter() + .filter(|artifact| artifact.crate_name == "jolt-verifier") + .all(|artifact| !artifact.source.source.contains("jolt_kernels"))); + assert!(artifacts + .iter() + .filter(|artifact| { artifact.crate_name == "jolt-prover" && artifact.stage.is_proof() }) + .all(|artifact| artifact.source.source.contains("jolt_kernels"))); + let workspace_generated_crates = assemble_jolt_workspace_generated_crates(artifacts.clone()) + .expect("assemble workspace generated role crates"); + if std::env::var_os("JOLT_UPDATE_GOLDENS").is_some() { + write_jolt_generated_crates(&workspace_generated_crates, workspace_root().join("crates")) + .expect("update checked-in generated role crates"); + } + if !checked_in_generated_role_crates_available() { + return; + } + assert_checked_in_generated_role_crate_sources_match(&workspace_generated_crates); + let dependency_root = workspace_root().join("crates").display().to_string(); + let generated_crates = assemble_jolt_generated_crates(artifacts, &dependency_root) + .expect("assemble generated role crates"); + assert_eq!( + generated_crates + .iter() + .map(|generated| generated.crate_name.as_str()) + .collect::>(), + vec!["jolt-prover", "jolt-verifier"] + ); + for generated in &generated_crates { + assert_generated_role_crate_compiles(generated); + } + let output_root = new_temp_dir("bolt_generated_crates"); + write_jolt_generated_crates(&workspace_generated_crates, &output_root) + .expect("write generated role crates"); + for generated in &workspace_generated_crates { + for file in &generated.files { + assert!(output_root + .join(&generated.crate_name) + .join(&file.path) + .exists()); + } + } + let _ = std::fs::remove_dir_all(output_root); +} + +#[test] +fn jolt_stage1_outer_lowers_to_compute_and_cpu_kernel_ir() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = + build_stage1_outer_protocol(&context, ¶ms).expect("build stage1 outer protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage1 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage1_to_compute(&context, &prover).expect("lower prover stage1"); + let verifier_compute = + lower_stage1_to_compute(&context, &verifier).expect("lower verifier stage1"); + verify_compute_schema(&prover_compute).expect("prover stage1 compute schema is valid"); + verify_compute_schema(&verifier_compute).expect("verifier stage1 compute schema is valid"); + assert!(prover_compute + .to_text_mlir() + .contains("relation = @jolt.stage1.outer.uniskip")); + assert!(!prover_compute.to_text_mlir().contains("kernel = @")); + let verifier_compute_text = verifier_compute.to_text_mlir(); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify_claim\"")); + assert!(verifier_compute_text.contains("\"compute.sumcheck_verify\"")); + assert!(!verifier_compute_text.contains("\"compute.kernel\"")); + assert!(!verifier_compute_text.contains("kernel = @")); + + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + verify_compute_schema(&prover_kernel_compute) + .expect("prover kernelized stage1 compute schema is valid"); + verify_compute_schema(&verifier_kernel_compute) + .expect("verifier kernelized stage1 compute schema is valid"); + + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + verify_cpu_schema(&prover_cpu).expect("prover stage1 CPU schema is valid"); + verify_cpu_schema(&verifier_cpu).expect("verifier stage1 CPU schema is valid"); + let program = stage1_cpu_program(&prover_cpu).expect("extract prover stage1 CPU program"); + + let cpu_text = prover_cpu.to_text_mlir(); + let verifier_cpu_text = verifier_cpu.to_text_mlir(); + assert!(cpu_text.contains("\"cpu.kernel\"()")); + assert!(cpu_text.contains("kernel = @jolt.cpu.stage1.outer.uniskip")); + assert!(cpu_text.contains("kernel = @jolt.cpu.stage1.outer.remaining")); + assert!(cpu_text.contains("\"cpu.sumcheck_driver\"(%")); + assert!(cpu_text.contains("\"cpu.sumcheck_eval\"(%")); + assert!(cpu_text.contains("\"cpu.opening_claim\"(%")); + assert!(cpu_text.contains("\"cpu.opening_batch\"(%")); + assert!(cpu_text.contains("\"cpu.sumcheck_claim\"(%")); + assert!(cpu_text.contains("count = 35 : i64")); + assert!(!cpu_text.contains("\"cpu.pcs_opening_claim\"")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify_claim\"")); + assert!(verifier_cpu_text.contains("\"cpu.sumcheck_verify\"")); + assert!(!verifier_cpu_text.contains("\"cpu.kernel\"")); + assert!(!verifier_cpu_text.contains("kernel = @")); + assert_eq!(program.role, Role::Prover); + assert_eq!(program.kernels.len(), 2); + assert!(program.kernels.iter().any(|kernel| { + kernel.symbol == "jolt.cpu.stage1.outer.uniskip" + && kernel.relation == "jolt.stage1.outer.uniskip" + && kernel.abi == "jolt_stage1_outer_uniskip" + })); + assert!(program.kernels.iter().any(|kernel| { + kernel.symbol == "jolt.cpu.stage1.outer.remaining" + && kernel.relation == "jolt.stage1.outer.remaining" + && kernel.abi == "jolt_stage1_outer_remaining" + })); + assert_eq!(program.claims.len(), 2); + assert_eq!(program.batches.len(), 2); + assert_eq!(program.drivers.len(), 2); + assert_eq!(program.opening_claims.len(), 36); + assert_eq!(program.opening_batches.len(), 1); + let uniskip = program + .drivers + .iter() + .find(|driver| driver.symbol == "stage1.uniskip.sumcheck") + .expect("uniskip driver"); + assert_eq!( + uniskip.kernel.as_deref(), + Some("jolt.cpu.stage1.outer.uniskip") + ); + assert_eq!(uniskip.round_schedule, vec![1]); + assert_eq!(uniskip.num_rounds, 1); + assert_eq!(uniskip.degree, 27); + let remaining = program + .drivers + .iter() + .find(|driver| driver.symbol == "stage1.outer_remaining.sumcheck") + .expect("remaining driver"); + assert_eq!( + remaining.kernel.as_deref(), + Some("jolt.cpu.stage1.outer.remaining") + ); + assert_eq!(remaining.round_schedule, vec![params.log_t + 1]); + assert_eq!(remaining.num_rounds, params.log_t + 1); + assert_eq!(remaining.degree, 3); + assert_eq!( + program + .evals + .iter() + .filter(|eval| eval.source == "stage1.outer_remaining.sumcheck") + .count(), + 35 + ); + assert_eq!(program.opening_batches[0].count, 35); + assert_eq!( + program.opening_batches[0].ordered_claims, + program.opening_batches[0].claim_operands + ); + + assert_or_update_fixture( + "tests/fixtures/stage1_outer_prover_compute.mlir", + &prover_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/stage1_outer_verifier_compute.mlir", + &verifier_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/stage1_outer_prover_kernel_compute.mlir", + &prover_kernel_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/stage1_outer_verifier_kernel_compute.mlir", + &verifier_kernel_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/stage1_outer_prover_cpu.mlir", + &prover_cpu.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/stage1_outer_verifier_cpu.mlir", + &verifier_cpu.to_text_mlir(), + ); +} + +#[test] +fn stage1_rust_emission_matches_golden_and_compiles() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = + build_stage1_outer_protocol(&context, ¶ms).expect("build stage1 outer protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower stage1 to concrete"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage1_to_compute(&context, &prover).expect("lower prover stage1"); + let verifier_compute = + lower_stage1_to_compute(&context, &verifier).expect("lower verifier stage1"); + let prover_kernel_compute = + resolve_compute_kernels(&context, &prover_compute).expect("resolve prover kernels"); + let verifier_kernel_compute = + resolve_compute_kernels(&context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = + lower_compute_to_cpu(&context, &prover_kernel_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_kernel_compute).expect("lower verifier CPU"); + let source = emit_stage1_rust(&prover_cpu).expect("emit prover stage1 rust"); + let verifier_source = emit_stage1_rust(&verifier_cpu).expect("emit verifier stage1 rust"); + + assert_eq!(source.filename, "prove_stage1_outer.rs"); + assert_eq!(verifier_source.filename, "verify_stage1_outer.rs"); + assert!(source.source.contains("pub fn prove_stage1_outer")); + assert!(verifier_source + .source + .contains("pub fn verify_stage1_outer")); + assert!(source.source.contains("jolt_stage1_outer_uniskip")); + assert!(source.source.contains("jolt_stage1_outer_remaining")); + assert!(!verifier_source.source.contains("jolt_kernels")); + assert!(verifier_source.source.contains("jolt_sumcheck")); + assert_or_update_fixture("tests/fixtures/prove_stage1_outer.rs", &source.source); + assert_or_update_fixture( + "tests/fixtures/verify_stage1_outer.rs", + &verifier_source.source, + ); + assert_rust_source_compiles(&source.filename, &source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn generated_stage1_prover_shape_proof_verifier_accepts() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::new(2, 2, 2); + let (prover_cpu, verifier_cpu) = build_stage1_pipeline_cpu(&context, ¶ms); + let prover_source = emit_stage1_rust(&prover_cpu).expect("emit stage1 prover rust"); + let verifier_source = emit_stage1_rust(&verifier_cpu).expect("emit stage1 verifier rust"); + + assert_eq!(prover_source.filename, "prove_stage1_outer.rs"); + assert_eq!(verifier_source.filename, "verify_stage1_outer.rs"); + assert_generated_stage1_self_parity_runs( + &prover_source, + &verifier_source, + &generated_stage1_shape_self_parity_main(), + ); +} + +#[test] +fn generated_stage1_real_executor_reaches_kernel_dispatch() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::new(2, 2, 2); + let (prover_cpu, verifier_cpu) = build_stage1_pipeline_cpu(&context, ¶ms); + let prover_source = emit_stage1_rust(&prover_cpu).expect("emit stage1 prover rust"); + let verifier_source = emit_stage1_rust(&verifier_cpu).expect("emit stage1 verifier rust"); + + assert_generated_stage1_self_parity_runs( + &prover_source, + &verifier_source, + generated_stage1_real_dispatch_main(), + ); +} + +#[test] +fn generated_stage1_real_executor_self_verifies_synthetic_remaining() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::new(2, 2, 2); + let (prover_cpu, verifier_cpu) = build_stage1_pipeline_cpu(&context, ¶ms); + let prover_source = emit_stage1_rust(&prover_cpu).expect("emit stage1 prover rust"); + let verifier_source = emit_stage1_rust(&verifier_cpu).expect("emit stage1 verifier rust"); + + assert_generated_stage1_self_parity_runs( + &prover_source, + &verifier_source, + &generated_stage1_synthetic_remaining_main(), + ); +} + +#[test] +fn generated_stage1_real_executor_self_verifies_r1cs_data() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::new(2, 2, 2); + let (prover_cpu, verifier_cpu) = build_stage1_pipeline_cpu(&context, ¶ms); + let prover_source = emit_stage1_rust(&prover_cpu).expect("emit stage1 prover rust"); + let verifier_source = emit_stage1_rust(&verifier_cpu).expect("emit stage1 verifier rust"); + + assert_generated_stage1_self_parity_runs( + &prover_source, + &verifier_source, + &generated_stage1_r1cs_data_main(), + ); +} + +#[test] +fn jolt_protocol_chain_commitment_stage1_fixture_tracks_phase_order() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let chain = jolt_protocol_chain_commitment_stage1_fixture(&context, ¶ms); + + assert_or_update_fixture( + "tests/fixtures/jolt_protocol_chain_commitment_stage1.yaml", + &chain, + ); +} + +#[test] +fn generated_jolt_chain_commitment_then_stage1_self_parity_runs() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::new(2, 2, 2); + let (commitment_prover_cpu, commitment_verifier_cpu) = + build_commitment_pipeline_cpu(&context, ¶ms); + let (stage1_prover_cpu, stage1_verifier_cpu) = build_stage1_pipeline_cpu(&context, ¶ms); + let commitment_prover = + emit_commitment_rust(&commitment_prover_cpu).expect("emit commitment prover rust"); + let commitment_verifier = + emit_commitment_rust(&commitment_verifier_cpu).expect("emit commitment verifier rust"); + let stage1_prover = emit_stage1_rust(&stage1_prover_cpu).expect("emit stage1 prover rust"); + let stage1_verifier = + emit_stage1_rust(&stage1_verifier_cpu).expect("emit stage1 verifier rust"); + + assert_generated_jolt_chain_self_parity_runs( + &[ + &commitment_prover, + &commitment_verifier, + &stage1_prover, + &stage1_verifier, + ], + &generated_commitment_stage1_chain_main(), + ); +} + +#[test] +fn commitment_pipeline_matches_golden_mlir_fixtures() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_commitment_protocol(&context, ¶ms).expect("build protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower Fiat-Shamir state"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_commitment_to_compute(&context, &prover).expect("lower compute"); + let verifier_compute = + lower_commitment_to_compute(&context, &verifier).expect("lower verifier compute"); + let prover_cpu = lower_compute_to_cpu(&context, &prover_compute).expect("lower to CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_compute).expect("lower verifier to CPU"); + + assert_or_update_fixture( + "tests/fixtures/commitment_protocol.mlir", + &protocol.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/commitment_concrete.mlir", + &concrete.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/commitment_prover_party.mlir", + &prover.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/commitment_verifier_party.mlir", + &verifier.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/commitment_prover_compute.mlir", + &prover_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/commitment_verifier_compute.mlir", + &verifier_compute.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/commitment_prover_cpu.mlir", + &prover_cpu.to_text_mlir(), + ); + assert_or_update_fixture( + "tests/fixtures/commitment_verifier_cpu.mlir", + &verifier_cpu.to_text_mlir(), + ); +} + +#[test] +fn commitment_rust_emission_matches_golden_and_compiles() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::fixture(); + let protocol = build_commitment_protocol(&context, ¶ms).expect("build protocol"); + let concrete = + lower_piop_and_fiat_shamir(&context, &protocol).expect("lower Fiat-Shamir state"); + let prover = project_prover_party(&context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(&context, &concrete).expect("project verifier party"); + let prover_compute = lower_commitment_to_compute(&context, &prover).expect("lower compute"); + let verifier_compute = + lower_commitment_to_compute(&context, &verifier).expect("lower verifier compute"); + let prover_cpu = lower_compute_to_cpu(&context, &prover_compute).expect("lower to CPU"); + let verifier_cpu = + lower_compute_to_cpu(&context, &verifier_compute).expect("lower verifier to CPU"); + let source = emit_commitment_rust(&prover_cpu).expect("emit prover commitment rust"); + let verifier_source = + emit_commitment_rust(&verifier_cpu).expect("emit verifier commitment rust"); + + assert_eq!(source.filename, "prove_commitment_phase.rs"); + assert_eq!(verifier_source.filename, "verify_commitment_phase.rs"); + assert_or_update_fixture("tests/fixtures/prove_commitment_phase.rs", &source.source); + assert_or_update_fixture( + "tests/fixtures/verify_commitment_phase.rs", + &verifier_source.source, + ); + assert_rust_source_compiles(&source.filename, &source.source); + assert_rust_source_compiles(&verifier_source.filename, &verifier_source.source); +} + +#[test] +fn generated_commitment_prover_verifier_self_parity_runs() { + let context = MeliorContext::new(); + let prover_cpu = build_small_commitment_cpu(&context, Role::Prover); + let verifier_cpu = build_small_commitment_cpu(&context, Role::Verifier); + let prover_source = + emit_commitment_rust(&prover_cpu).expect("emit small prover commitment rust"); + let verifier_source = + emit_commitment_rust(&verifier_cpu).expect("emit small verifier commitment rust"); + + assert_eq!(prover_source.filename, "prove_commitment_phase.rs"); + assert_eq!(verifier_source.filename, "verify_commitment_phase.rs"); + assert_generated_commitment_self_parity_runs( + &prover_source, + &verifier_source, + &generated_small_self_parity_main(), + ); +} + +#[test] +fn pipeline_generated_commitment_prover_verifier_self_parity_runs() { + let context = MeliorContext::new(); + let params = JoltProtocolParams::new(0, 0, 0); + let (prover_cpu, verifier_cpu) = build_commitment_pipeline_cpu(&context, ¶ms); + let prover_source = + emit_commitment_rust(&prover_cpu).expect("emit pipeline prover commitment rust"); + let verifier_source = + emit_commitment_rust(&verifier_cpu).expect("emit pipeline verifier commitment rust"); + + assert_eq!(prover_source.filename, "prove_commitment_phase.rs"); + assert_eq!(verifier_source.filename, "verify_commitment_phase.rs"); + assert_generated_commitment_self_parity_runs( + &prover_source, + &verifier_source, + &generated_pipeline_self_parity_main(), + ); +} + +#[test] +fn commitment_rust_emission_requires_cpu_target_params() { + let context = MeliorContext::new(); + let cpu = context + .parse_module::( + r#" +module @bad attributes {bolt.phase = "cpu", bolt.role = "prover"} { + %0 = "cpu.oracle_family_init"() {count = 1 : i64, family = @bad.family, sym_name = "bad.family"} : () -> !cpu.oracle_family + %1 = "cpu.oracle_ref"() {domain = @bad.domain, num_vars = 1 : i64, oracle = @A, sym_name = "bad.A"} : () -> !cpu.oracle_buffer + %2 = "cpu.oracle_family_append"(%0, %1) {family = @bad.family, index = 0 : i64, oracle = @A, sym_name = "bad.family.append0"} : (!cpu.oracle_family, !cpu.oracle_buffer) -> !cpu.oracle_family + %3 = "cpu.pcs_commit_batch"(%2) {artifact = @bad.artifact, count = 1 : i64, domain = @bad.domain, label = "bad", num_vars = 1 : i64, oracle_family = @bad.family, ordered_oracles = [@A], pcs = @dory, sym_name = "bad.batch"} : (!cpu.oracle_family) -> !cpu.commitment_artifact +} +"#, + ) + .expect("parse bad CPU module"); + + let error = emit_commitment_rust(&cpu).expect_err("missing params rejected"); + assert!(error.to_string().contains("missing cpu.params")); +} + +fn build_small_commitment_cpu(context: &MeliorContext, role: Role) -> bolt::BoltModule<'_, Cpu> { + let (batch_op, optional_op) = match role { + Role::Prover => ("cpu.pcs_commit_batch", "cpu.pcs_commit_optional"), + Role::Verifier => ("cpu.pcs_receive_batch", "cpu.pcs_receive_optional"), + }; + context + .parse_module::(&format!( + r#" +module @small.commitment_phase attributes {{bolt.phase = "cpu", bolt.role = "{}"}} {{ + "cpu.params"() {{field = @bn254_fr, pcs = @dory, sym_name = "small.params", transcript = @blake2b_transcript}} : () -> () + "cpu.function"() {{source = @small.commitment_phase, sym_name = "small.commitment_phase"}} : () -> () + %0 = "cpu.transcript_init"() {{scheme = @blake2b_transcript, sym_name = "fs0"}} : () -> !cpu.transcript_state + %1 = "cpu.oracle_family_init"() {{count = 2 : i64, family = @small.main_polys, sym_name = "small.main_polys"}} : () -> !cpu.oracle_family + %2 = "cpu.oracle_ref"() {{domain = @small.domain, num_vars = 2 : i64, oracle = @A, sym_name = "small.A"}} : () -> !cpu.oracle_buffer + %3 = "cpu.oracle_family_append"(%1, %2) {{family = @small.main_polys, index = 0 : i64, oracle = @A, sym_name = "small.main_polys.append0"}} : (!cpu.oracle_family, !cpu.oracle_buffer) -> !cpu.oracle_family + %4 = "cpu.oracle_ref"() {{domain = @small.domain, num_vars = 2 : i64, oracle = @B, sym_name = "small.B"}} : () -> !cpu.oracle_buffer + %5 = "cpu.oracle_family_append"(%3, %4) {{family = @small.main_polys, index = 1 : i64, oracle = @B, sym_name = "small.main_polys.append1"}} : (!cpu.oracle_family, !cpu.oracle_buffer) -> !cpu.oracle_family + %6 = "{batch_op}"(%5) {{artifact = @small.main, count = 2 : i64, domain = @small.domain, label = "commitment", num_vars = 2 : i64, oracle_family = @small.main_polys, ordered_oracles = [@A, @B], pcs = @dory, sym_name = "small.main"}} : (!cpu.oracle_family) -> !cpu.commitment_artifact + %7 = "cpu.oracle_ref"() {{domain = @small.domain, num_vars = 2 : i64, oracle = @Advice, sym_name = "small.Advice"}} : () -> !cpu.oracle_buffer + %8 = "{optional_op}"(%7) {{artifact = @small.advice, domain = @small.domain, label = "advice", num_vars = 2 : i64, oracle = @Advice, pcs = @dory, skip_policy = "missing_or_zero", sym_name = "small.advice"}} : (!cpu.oracle_buffer) -> !cpu.commitment_artifact + %9 = "cpu.transcript_absorb"(%0, %6) {{label = "commitment", optional = false, sym_name = "small.absorb_main"}} : (!cpu.transcript_state, !cpu.commitment_artifact) -> !cpu.transcript_state + %10 = "cpu.transcript_absorb"(%9, %8) {{label = "advice", optional = true, sym_name = "small.absorb_advice"}} : (!cpu.transcript_state, !cpu.commitment_artifact) -> !cpu.transcript_state +}} +"#, + role.as_str() + )) + .expect("parse small CPU module") +} + +fn build_commitment_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_commitment_protocol(context, params).expect("build protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower Fiat-Shamir state"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = + lower_commitment_to_compute(context, &prover).expect("lower prover compute"); + let verifier_compute = + lower_commitment_to_compute(context, &verifier).expect("lower verifier compute"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn build_stage1_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_stage1_outer_protocol(context, params).expect("build stage1 protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower stage1 protocol"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage1_to_compute(context, &prover).expect("lower prover stage1"); + let verifier_compute = + lower_stage1_to_compute(context, &verifier).expect("lower verifier stage1"); + let prover_compute = + resolve_compute_kernels(context, &prover_compute).expect("resolve prover kernels"); + let verifier_compute = + resolve_compute_kernels(context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn build_stage2_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_stage2_protocol(context, params).expect("build stage2 protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower stage2 protocol"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage2_to_compute(context, &prover).expect("lower prover stage2"); + let verifier_compute = + lower_stage2_to_compute(context, &verifier).expect("lower verifier stage2"); + let prover_compute = + resolve_compute_kernels(context, &prover_compute).expect("resolve prover kernels"); + let verifier_compute = + resolve_compute_kernels(context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn build_stage3_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_stage3_protocol(context, params).expect("build stage3 protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower stage3 protocol"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage3_to_compute(context, &prover).expect("lower prover stage3"); + let verifier_compute = + lower_stage3_to_compute(context, &verifier).expect("lower verifier stage3"); + let prover_compute = + resolve_compute_kernels(context, &prover_compute).expect("resolve prover kernels"); + let verifier_compute = + resolve_compute_kernels(context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn build_stage4_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_stage4_protocol(context, params).expect("build stage4 protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower stage4 protocol"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage4_to_compute(context, &prover).expect("lower prover stage4"); + let verifier_compute = + lower_stage4_to_compute(context, &verifier).expect("lower verifier stage4"); + let prover_compute = + resolve_compute_kernels(context, &prover_compute).expect("resolve prover kernels"); + let verifier_compute = + resolve_compute_kernels(context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn build_stage5_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_stage5_protocol(context, params).expect("build stage5 protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower stage5 protocol"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage5_to_compute(context, &prover).expect("lower prover stage5"); + let verifier_compute = + lower_stage5_to_compute(context, &verifier).expect("lower verifier stage5"); + let prover_compute = + resolve_compute_kernels(context, &prover_compute).expect("resolve prover kernels"); + let verifier_compute = + resolve_compute_kernels(context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn build_stage6_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_stage6_protocol(context, params).expect("build stage6 protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower stage6 protocol"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage6_to_compute(context, &prover).expect("lower prover stage6"); + let verifier_compute = + lower_stage6_to_compute(context, &verifier).expect("lower verifier stage6"); + let prover_compute = + resolve_compute_kernels(context, &prover_compute).expect("resolve prover kernels"); + let verifier_compute = + resolve_compute_kernels(context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn build_stage7_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_stage7_protocol(context, params).expect("build stage7 protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower stage7 protocol"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage7_to_compute(context, &prover).expect("lower prover stage7"); + let verifier_compute = + lower_stage7_to_compute(context, &verifier).expect("lower verifier stage7"); + let prover_compute = + resolve_compute_kernels(context, &prover_compute).expect("resolve prover kernels"); + let verifier_compute = + resolve_compute_kernels(context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn build_stage8_pipeline_cpu<'c>( + context: &'c MeliorContext, + params: &JoltProtocolParams, +) -> (bolt::BoltModule<'c, Cpu>, bolt::BoltModule<'c, Cpu>) { + let protocol = build_stage8_protocol(context, params).expect("build stage8 protocol"); + let concrete = lower_piop_and_fiat_shamir(context, &protocol).expect("lower stage8 protocol"); + let prover = project_prover_party(context, &concrete).expect("project prover party"); + let verifier = project_verifier_party(context, &concrete).expect("project verifier party"); + let prover_compute = lower_stage8_to_compute(context, &prover).expect("lower prover stage8"); + let verifier_compute = + lower_stage8_to_compute(context, &verifier).expect("lower verifier stage8"); + let prover_compute = + resolve_compute_kernels(context, &prover_compute).expect("resolve prover kernels"); + let verifier_compute = + resolve_compute_kernels(context, &verifier_compute).expect("resolve verifier kernels"); + let prover_cpu = lower_compute_to_cpu(context, &prover_compute).expect("lower prover CPU"); + let verifier_cpu = + lower_compute_to_cpu(context, &verifier_compute).expect("lower verifier CPU"); + (prover_cpu, verifier_cpu) +} + +fn non_jolt_artifact_config() -> ProtocolArtifactConfig { + ProtocolArtifactConfig { + protocol_name: "Acme".to_owned(), + type_prefix: "Acme".to_owned(), + transcript_label: "acme transcript".to_owned(), + repository: None, + prover_crate_name: "acme-prover".to_owned(), + verifier_crate_name: "acme-verifier".to_owned(), + crates_io_patches: Vec::new(), + standalone_dependency_overrides: vec![ProtocolStandaloneDependency::new( + "serde", + "serde = { version = \"1\", default-features = false }", + )], + common_dependencies: vec!["serde".to_owned()], + prover_dependencies: Vec::new(), + verifier_dependencies: Vec::new(), + instrumentation_prefix: None, + prover_forbidden_imports: vec!["forbidden_prover".to_owned()], + verifier_forbidden_imports: vec!["forbidden_verifier".to_owned()], + kernel_crate: None, + field_type: RustTypeRef::new("std::primitive::u64"), + default_transcript_type: RustTypeRef::new("crate::stages::alpha::DefaultTranscript"), + transcript_trait: RustTypeRef::new("crate::stages::alpha::Transcript"), + commitment_type: RustTypeRef::new("crate::stages::shared::Commitment"), + prover_setup_type: RustTypeRef::new("crate::stages::alpha::ProverSetup"), + role_api_extension: None, + verifier_runtime_modules: vec![ProtocolRuntimeModule { + module_name: "shared".to_owned(), + file: GeneratedFile { + path: "src/stages/shared.rs".to_owned(), + source: non_jolt_verifier_common_source(), + }, + }], + verifier_named_eval_type: RustTypeRef::new("crate::stages::shared::StageNamedEval"), + verifier_sumcheck_output_type: RustTypeRef::new( + "crate::stages::shared::StageSumcheckOutput", + ), + verifier_stage_proof_type: RustTypeRef::new("crate::stages::shared::StageProof"), + } +} + +fn non_jolt_alpha_prover_source() -> RustSourceFile { + RustSourceFile { + filename: "prove_alpha.rs".to_owned(), + source: r" +pub struct DefaultTranscript(core::marker::PhantomData); + +pub trait Transcript { + type Challenge; +} + +pub struct ProverSetup; + +#[derive(Clone, Debug)] +pub struct AlphaExecutionArtifacts { + pub sumchecks: Vec>, +} + +#[derive(Clone, Debug)] +pub struct AlphaSumcheckOutput { + pub driver: &'static str, + pub point: Vec, + pub evals: Vec>, + pub proof: (), +} + +#[derive(Clone, Debug)] +pub struct AlphaNamedEval { + pub name: &'static str, + pub oracle: &'static str, + pub value: F, +} + +#[derive(Debug)] +pub struct AlphaKernelError; + +pub trait AlphaKernelExecutor {} + +pub fn execute_alpha( + _executor: &mut E, + _transcript: &mut T, +) -> Result, AlphaKernelError> +where + E: AlphaKernelExecutor, +{ + Ok(AlphaExecutionArtifacts { + sumchecks: Vec::new(), + }) +} +" + .trim_start() + .to_owned(), + } +} + +fn non_jolt_alpha_verifier_source() -> RustSourceFile { + RustSourceFile { + filename: "verify_alpha.rs".to_owned(), + source: r" +pub struct DefaultTranscript(core::marker::PhantomData); + +pub trait Transcript { + type Challenge; +} + +pub type AlphaNamedEval = super::shared::StageNamedEval; +pub type AlphaSumcheckOutput = super::shared::StageSumcheckOutput; +pub type AlphaProof = super::shared::StageProof; + +#[derive(Clone, Debug)] +pub struct AlphaExecutionArtifacts { + pub sumchecks: Vec>, +} + +#[derive(Debug)] +pub enum VerifyAlphaError {} + +pub fn verify_alpha( + _proof: &AlphaProof, + _transcript: &mut T, +) -> Result, VerifyAlphaError> { + Ok(AlphaExecutionArtifacts { + sumchecks: Vec::new(), + }) +} +" + .trim_start() + .to_owned(), + } +} + +fn non_jolt_verifier_common_source() -> String { + r" +#[derive(Clone, Debug)] +pub struct Commitment; + +#[derive(Clone, Debug)] +pub struct StageNamedEval { + pub name: &'static str, + pub oracle: &'static str, + pub value: F, +} + +#[derive(Clone, Debug)] +pub struct StageSumcheckOutput { + pub driver: &'static str, + pub point: Vec, + pub evals: Vec>, + pub proof: (), +} + +#[derive(Clone, Debug)] +pub struct StageProof { + pub sumchecks: Vec>, +} +" + .trim_start() + .to_owned() +} + +fn jolt_protocol_chain_commitment_stage1_fixture( + context: &MeliorContext, + params: &JoltProtocolParams, +) -> String { + let (commitment_prover_cpu, commitment_verifier_cpu) = + build_commitment_pipeline_cpu(context, params); + let (stage1_prover_cpu, stage1_verifier_cpu) = build_stage1_pipeline_cpu(context, params); + let commitment_prover = + commitment_cpu_program(&commitment_prover_cpu).expect("extract commitment prover program"); + let commitment_verifier = commitment_cpu_program(&commitment_verifier_cpu) + .expect("extract commitment verifier program"); + let stage1_prover = stage1_cpu_program(&stage1_prover_cpu).expect("extract stage1 prover"); + let stage1_verifier = + stage1_cpu_program(&stage1_verifier_cpu).expect("extract stage1 verifier"); + let commitment_prover_source = + emit_commitment_rust(&commitment_prover_cpu).expect("emit commitment prover"); + let commitment_verifier_source = + emit_commitment_rust(&commitment_verifier_cpu).expect("emit commitment verifier"); + let stage1_prover_source = emit_stage1_rust(&stage1_prover_cpu).expect("emit stage1 prover"); + let stage1_verifier_source = + emit_stage1_rust(&stage1_verifier_cpu).expect("emit stage1 verifier"); + + let mut text = String::new(); + writeln!(&mut text, "# Jolt protocol chain fixture").unwrap(); + writeln!(&mut text, "params:").unwrap(); + writeln!(&mut text, " log_t: {}", params.log_t).unwrap(); + writeln!(&mut text, " log_k_bytecode: {}", params.log_k_bytecode).unwrap(); + writeln!(&mut text, " log_k_ram: {}", params.log_k_ram).unwrap(); + writeln!(&mut text, " trace_length: {}", params.trace_length).unwrap(); + writeln!(&mut text, "phases:").unwrap(); + writeln!(&mut text, " - name: commitment").unwrap(); + writeln!( + &mut text, + " protocol_fixture: tests/fixtures/commitment_protocol.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " concrete_fixture: tests/fixtures/commitment_concrete.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " prover_cpu_fixture: tests/fixtures/commitment_prover_cpu.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " verifier_cpu_fixture: tests/fixtures/commitment_verifier_cpu.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " prover_rust_fixture: tests/fixtures/{}", + commitment_prover_source.filename + ) + .unwrap(); + writeln!( + &mut text, + " verifier_rust_fixture: tests/fixtures/{}", + commitment_verifier_source.filename + ) + .unwrap(); + writeln!( + &mut text, + " prover_batches: {}", + commitment_prover.batch_plans.len() + ) + .unwrap(); + writeln!( + &mut text, + " verifier_batches: {}", + commitment_verifier.batch_plans.len() + ) + .unwrap(); + writeln!( + &mut text, + " optional_commitments: {}", + commitment_prover.optional_plans.len() + ) + .unwrap(); + writeln!( + &mut text, + " transcript_steps: {}", + commitment_prover.transcript_steps.len() + ) + .unwrap(); + writeln!(&mut text, " - name: stage1_outer").unwrap(); + writeln!(&mut text, " consumes_transcript_from: commitment").unwrap(); + writeln!( + &mut text, + " protocol_fixture: tests/fixtures/stage1_outer_protocol.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " prover_compute_fixture: tests/fixtures/stage1_outer_prover_compute.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " verifier_compute_fixture: tests/fixtures/stage1_outer_verifier_compute.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " prover_kernel_compute_fixture: tests/fixtures/stage1_outer_prover_kernel_compute.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " verifier_kernel_compute_fixture: tests/fixtures/stage1_outer_verifier_kernel_compute.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " prover_cpu_fixture: tests/fixtures/stage1_outer_prover_cpu.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " verifier_cpu_fixture: tests/fixtures/stage1_outer_verifier_cpu.mlir" + ) + .unwrap(); + writeln!( + &mut text, + " prover_rust_fixture: tests/fixtures/{}", + stage1_prover_source.filename + ) + .unwrap(); + writeln!( + &mut text, + " verifier_rust_fixture: tests/fixtures/{}", + stage1_verifier_source.filename + ) + .unwrap(); + writeln!( + &mut text, + " transcript_squeezes: {}", + stage1_prover.transcript_squeezes.len() + ) + .unwrap(); + writeln!( + &mut text, + " prover_sumcheck_drivers: {}", + stage1_prover.drivers.len() + ) + .unwrap(); + writeln!( + &mut text, + " verifier_sumcheck_drivers: {}", + stage1_verifier.drivers.len() + ) + .unwrap(); + writeln!( + &mut text, + " opening_claims: {}", + stage1_prover.opening_claims.len() + ) + .unwrap(); + writeln!( + &mut text, + " opening_batches: {}", + stage1_prover.opening_batches.len() + ) + .unwrap(); + writeln!(&mut text, " drivers:").unwrap(); + for driver in &stage1_prover.drivers { + writeln!( + &mut text, + " - {}: kernel={} rounds={} degree={} proof_slot={}", + driver.symbol, + driver.kernel.as_deref().unwrap_or(""), + driver.num_rounds, + driver.degree, + driver.proof_slot + ) + .unwrap(); + } + writeln!(&mut text, "parity_gates:").unwrap(); + writeln!( + &mut text, + " - pipeline_generated_commitment_prover_verifier_self_parity_runs" + ) + .unwrap(); + writeln!( + &mut text, + " - generated_stage1_real_executor_self_verifies_synthetic_remaining" + ) + .unwrap(); + writeln!( + &mut text, + " - generated_jolt_chain_commitment_then_stage1_self_parity_runs" + ) + .unwrap(); + text +} + +fn opening_claim_equal_protocol(left_oracle: &str, right_oracle: &str, mode: &str) -> String { + let right_oracle_def = if left_oracle == right_oracle { + String::new() + } else { + format!( + r#" "piop.oracle"() {{commit_domain = @trace, domain = @trace, field = @bn254_fr, layout = "virtual", sym_name = "{right_oracle}", visibility = "virtual"}} : () -> () +"# + ) + }; + format!( + r#" +module @opening.claim.equal attributes {{bolt.phase = "protocol"}} {{ + "field.define"() {{modulus_bits = 254 : i64, role = "scalar", sym_name = "bn254_fr"}} : () -> () + "hash.function"() {{algorithm = "blake2b", sym_name = "blake2b"}} : () -> () + "transcript.scheme"() {{hash = @blake2b, sym_name = "blake2b_transcript"}} : () -> () + "pcs.scheme"() {{field = @bn254_fr, sym_name = "dory"}} : () -> () + "poly.domain"() {{field = @bn254_fr, log_size = 16 : i64, sym_name = "trace"}} : () -> () + "protocol.params"() {{field = @bn254_fr, pcs = @dory, sym_name = "params", transcript = @blake2b_transcript}} : () -> () + "protocol.boundary"() {{roles = ["prover", "verifier"], sym_name = "opening.claim.equal"}} : () -> () + "piop.oracle"() {{commit_domain = @trace, domain = @trace, field = @bn254_fr, layout = "virtual", sym_name = "{left_oracle}", visibility = "virtual"}} : () -> () +{right_oracle_def} + %left:3 = "piop.opening_input"() {{claim_kind = "virtual", domain = @trace, oracle = @{left_oracle}, point_arity = 16 : i64, source_claim = @stage2.product_virtual.remainder.opening.{left_oracle}, source_stage = @stage2, sym_name = "stage3.input.stage2_left.{left_oracle}"}} : () -> (!poly.point, !field.scalar, !piop.opening_claim_type) + %right:3 = "piop.opening_input"() {{claim_kind = "virtual", domain = @trace, oracle = @{right_oracle}, point_arity = 16 : i64, source_claim = @stage2.instruction_lookup.claim_reduction.opening.{right_oracle}, source_stage = @stage2, sym_name = "stage3.input.stage2_right.{right_oracle}"}} : () -> (!poly.point, !field.scalar, !piop.opening_claim_type) + "piop.opening_claim_equal"(%left#2, %right#2) {{mode = "{mode}", sym_name = "stage3.instruction_input.left_claim_consistency"}} : (!piop.opening_claim_type, !piop.opening_claim_type) -> () +}} +"# + ) +} + +fn transcript_absorb_bytes_protocol(params: &JoltProtocolParams) -> String { + format!( + r#" +module @transcript.absorb.bytes attributes {{bolt.phase = "protocol"}} {{ + "field.define"() {{modulus_bits = 254 : i64, role = "scalar", sym_name = "bn254_fr"}} : () -> () + "hash.function"() {{algorithm = "blake2b", sym_name = "blake2b"}} : () -> () + "transcript.scheme"() {{hash = @blake2b, sym_name = "blake2b_transcript"}} : () -> () + "pcs.scheme"() {{field = @bn254_fr, sym_name = "dory"}} : () -> () + "protocol.params"() {{{params_attrs}, sym_name = "jolt.params"}} : () -> () + "protocol.boundary"() {{roles = ["prover", "verifier"], sym_name = "transcript.absorb.bytes"}} : () -> () + %0 = "transcript.state"() {{scheme = @blake2b_transcript, sym_name = "fs_after_stage3"}} : () -> !transcript.state_type + %1 = "transcript.absorb_bytes"(%0) {{label = "ram_val_check_gamma", payload = "", sym_name = "stage4.ram_val_check.domain_separator"}} : (!transcript.state_type) -> !transcript.state_type + %2:2 = "transcript.squeeze"(%1) {{count = 1 : i64, kind = "challenge_scalar", label = "ram_val_check_gamma", sym_name = "stage4.ram_val_check.gamma"}} : (!transcript.state_type) -> (!transcript.state_type, !field.scalar) +}} +"#, + params_attrs = jolt_params_attrs_source(params) + ) +} + +fn jolt_params_attrs_source(params: &JoltProtocolParams) -> String { + params + .attrs() + .into_iter() + .map(|(name, value)| format!("{name} = {value}")) + .collect::>() + .join(", ") +} + +fn explicit_sumcheck_protocol() -> &'static str { + r#" +module @explicit.sumcheck attributes {bolt.phase = "protocol"} { + "field.define"() {modulus_bits = 254 : i64, role = "scalar", sym_name = "bn254_fr"} : () -> () + "hash.function"() {algorithm = "blake2b", sym_name = "blake2b"} : () -> () + "transcript.scheme"() {hash = @blake2b, sym_name = "blake2b_transcript"} : () -> () + "pcs.scheme"() {field = @bn254_fr, sym_name = "dory"} : () -> () + "poly.domain"() {field = @bn254_fr, log_size = 16 : i64, sym_name = "trace"} : () -> () + "piop.relation"() {degree = 3 : i64, domain = @trace, kind = "sumcheck", num_rounds = 4 : i64, output_count = 1 : i64, sym_name = "jolt.stage1.outer.remaining"} : () -> () + %0 = "transcript.state"() {scheme = @blake2b_transcript, sym_name = "fs0"} : () -> !transcript.state_type + %1, %alpha = "transcript.squeeze"(%0) {count = 1 : i64, kind = "scalar", label = "sumcheck_claim", sym_name = "stage1.alpha"} : (!transcript.state_type) -> (!transcript.state_type, !field.scalar) + %stage = "piop.stage"() {name = "stage1", order = 1 : i64, roles = ["prover", "verifier"], sym_name = "stage1"} : () -> !piop.stage_type + %claim_value = "field.const"() {field = @bn254_fr, value = 0 : i64, sym_name = "stage1.outer.claim_value"} : () -> !field.scalar + %claim = "piop.sumcheck_claim"(%claim_value) {claim = @stage1.outer.claim, degree = 3 : i64, domain = @trace, num_rounds = 4 : i64, relation = @jolt.stage1.outer.remaining, stage = @stage1, sym_name = "stage1.outer.claim"} : (!field.scalar) -> !piop.sumcheck_claim_type + %batch = "piop.sumcheck_batch"(%stage, %claim) {claim_label = "sumcheck_claim", count = 1 : i64, ordered_claims = [@stage1.outer.claim], policy = "jolt_core_front_loaded", proof_slot = @stage1.sumcheck, round_label = "sumcheck_poly", round_schedule = [2, 1, 1], stage = @stage1, sym_name = "stage1.outer.batch"} : (!piop.stage_type, !piop.sumcheck_claim_type) -> !piop.sumcheck_batch_type + %2, %point, %result, %proof = "piop.sumcheck"(%1, %batch) {claim_label = "sumcheck_claim", degree = 3 : i64, num_rounds = 4 : i64, policy = "jolt_core_front_loaded", proof_slot = @stage1.sumcheck, relation = @jolt.stage1.outer.remaining, round_label = "sumcheck_poly", round_schedule = [2, 1, 1], stage = @stage1, sym_name = "stage1.outer.sumcheck"} : (!transcript.state_type, !piop.sumcheck_batch_type) -> (!transcript.state_type, !poly.point, !piop.sumcheck_result_type, !piop.sumcheck_proof_type) + %eval = "piop.sumcheck_eval"(%result) {index = 0 : i64, name = @stage1.outer.eval, oracle = @RdInc, source = @stage1.outer.sumcheck, sym_name = "stage1.outer.eval"} : (!piop.sumcheck_result_type) -> !field.scalar + %opening = "pcs.opening_claim"(%point, %eval) {domain = @trace, family = @jolt.main_witness_polys, oracle = @RdInc, point_arity = 4 : i64, sym_name = "stage1.outer.opening"} : (!poly.point, !field.scalar) -> !pcs.opening_claim_type + %openings = "pcs.opening_batch"(%opening) {count = 1 : i64, ordered_claims = [@stage1.outer.opening], policy = "jolt_core_order", proof_slot = @stage1.openings, sym_name = "stage1.opening_batch"} : (!pcs.opening_claim_type) -> !pcs.opening_batch_type + %3, %opening_proof = "pcs.batch_open"(%2, %openings) {pcs = @dory, proof_slot = @stage1.openings, sym_name = "stage1.open", transcript_label = "opening_proof"} : (!transcript.state_type, !pcs.opening_batch_type) -> (!transcript.state_type, !pcs.opening_proof_type) +} +"# +} + +fn explicit_sumcheck_compute() -> &'static str { + r#" +module @explicit.sumcheck attributes {bolt.phase = "compute", bolt.role = "prover"} { + "compute.params"() {field = @bn254_fr, pcs = @dory, sym_name = "params", transcript = @blake2b_transcript} : () -> () + "compute.function"() {source = @explicit.sumcheck, sym_name = "explicit.sumcheck"} : () -> () + "compute.relation"() {degree = 3 : i64, domain = @trace, kind = "sumcheck", num_rounds = 4 : i64, output_count = 1 : i64, sym_name = "jolt.stage1.outer.remaining"} : () -> () + %0 = "compute.transcript_init"() {scheme = @blake2b_transcript, sym_name = "fs0"} : () -> !compute.transcript_state + %1, %alpha = "compute.transcript_squeeze"(%0) {count = 1 : i64, kind = "scalar", label = "sumcheck_claim", sym_name = "stage1.alpha"} : (!compute.transcript_state) -> (!compute.transcript_state, !compute.field_value) + %claim_value = "compute.field_const"() {field = @bn254_fr, value = 0 : i64, sym_name = "stage1.outer.claim_value"} : () -> !compute.field_value + %claim = "compute.sumcheck_claim"(%claim_value) {claim = @stage1.outer.claim, degree = 3 : i64, domain = @trace, num_rounds = 4 : i64, relation = @jolt.stage1.outer.remaining, stage = @stage1, sym_name = "stage1.outer.claim"} : (!compute.field_value) -> !compute.sumcheck_claim_type + %batch = "compute.sumcheck_batch"(%claim) {claim_label = "sumcheck_claim", count = 1 : i64, ordered_claims = [@stage1.outer.claim], policy = "jolt_core_front_loaded", proof_slot = @stage1.sumcheck, round_label = "sumcheck_poly", round_schedule = [2, 1, 1], stage = @stage1, sym_name = "stage1.outer.batch"} : (!compute.sumcheck_claim_type) -> !compute.sumcheck_batch_type + %2, %point, %result, %proof = "compute.sumcheck_driver"(%1, %batch) {claim_label = "sumcheck_claim", degree = 3 : i64, num_rounds = 4 : i64, policy = "jolt_core_front_loaded", proof_slot = @stage1.sumcheck, relation = @jolt.stage1.outer.remaining, round_label = "sumcheck_poly", round_schedule = [2, 1, 1], stage = @stage1, sym_name = "stage1.outer.sumcheck"} : (!compute.transcript_state, !compute.sumcheck_batch_type) -> (!compute.transcript_state, !compute.point, !compute.sumcheck_result_type, !compute.sumcheck_proof_type) + %eval = "compute.sumcheck_eval"(%result) {index = 0 : i64, name = @stage1.outer.eval, oracle = @RdInc, source = @stage1.outer.sumcheck, sym_name = "stage1.outer.eval"} : (!compute.sumcheck_result_type) -> !compute.field_value + %opening = "compute.pcs_opening_claim"(%point, %eval) {domain = @trace, family = @jolt.main_witness_polys, oracle = @RdInc, point_arity = 4 : i64, sym_name = "stage1.outer.opening"} : (!compute.point, !compute.field_value) -> !compute.opening_claim_type + %openings = "compute.pcs_opening_batch"(%opening) {count = 1 : i64, ordered_claims = [@stage1.outer.opening], policy = "jolt_core_order", proof_slot = @stage1.openings, sym_name = "stage1.opening_batch"} : (!compute.opening_claim_type) -> !compute.opening_batch_type + %3, %opening_proof = "compute.pcs_batch_open"(%2, %openings) {pcs = @dory, proof_slot = @stage1.openings, sym_name = "stage1.open", transcript_label = "opening_proof"} : (!compute.transcript_state, !compute.opening_batch_type) -> (!compute.transcript_state, !compute.opening_proof_type) +} +"# +} + +fn assert_or_update_fixture(path: &str, actual: &str) { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(path); + if std::env::var_os("JOLT_UPDATE_GOLDENS").is_some() { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).expect("create golden fixture directory"); + } + std::fs::write(&path, actual).expect("write golden fixture"); + return; + } + if !path.exists() { + return; + } + let expected = std::fs::read_to_string(&path).expect("read golden fixture"); + assert_eq!(expected, actual); +} + +fn assert_rust_source_compiles(_filename: &str, source: &str) { + if !generated_jolt_runtime_available() { + return; + } + let dir = new_temp_dir("bolt_emit"); + let workspace_root = workspace_root(); + std::fs::write( + dir.join("Cargo.toml"), + generated_crate_manifest(&workspace_root), + ) + .expect("write generated cargo manifest"); + std::fs::create_dir_all(dir.join("src")).expect("create generated src dir"); + if source.contains("super::common") { + let common = std::fs::read_to_string( + workspace_root.join("crates/jolt-verifier/src/stages/common.rs"), + ) + .expect("read generated verifier common stage source"); + std::fs::write(dir.join("src/common.rs"), common).expect("write generated common source"); + std::fs::write(dir.join("src/generated.rs"), source).expect("write generated source"); + std::fs::write( + dir.join("src/lib.rs"), + "pub mod common;\n#[rustfmt::skip]\npub mod generated;\n", + ) + .expect("write generated lib wrapper"); + } else { + std::fs::write(dir.join("src/lib.rs"), source).expect("write generated source"); + } + let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_owned()); + let output = Command::new(cargo) + .arg("check") + .arg("--manifest-path") + .arg(dir.join("Cargo.toml")) + .arg("-q") + .env("CARGO_TARGET_DIR", dir.join("target")) + .output() + .expect("run cargo check"); + assert!( + output.status.success(), + "generated rust did not compile\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let _ = std::fs::remove_dir_all(dir); +} + +fn assert_generated_role_crate_compiles(generated: &JoltGeneratedCrate) { + let dir = new_temp_dir(&generated.crate_name); + for file in &generated.files { + let path = dir.join(&file.path); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).expect("create generated crate dir"); + } + std::fs::write(path, &file.source).expect("write generated crate file"); + } + let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_owned()); + let output = Command::new(cargo) + .arg("check") + .arg("--manifest-path") + .arg(dir.join("Cargo.toml")) + .arg("-q") + .env("CARGO_TARGET_DIR", dir.join("target")) + .output() + .expect("run generated role crate check"); + assert!( + output.status.success(), + "generated role crate `{}` did not compile\nstdout:\n{}\nstderr:\n{}", + generated.crate_name, + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let _ = std::fs::remove_dir_all(dir); +} + +fn assert_generated_crate_manifest_compiles(output_root: &Path, crate_name: &str) { + let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_owned()); + let output = Command::new(cargo) + .arg("check") + .arg("--manifest-path") + .arg(output_root.join(crate_name).join("Cargo.toml")) + .arg("-q") + .env( + "CARGO_TARGET_DIR", + output_root.join("target").join(crate_name), + ) + .output() + .expect("run generated crate check"); + assert!( + output.status.success(), + "generated crate `{crate_name}` did not compile\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); +} + +fn redirect_generated_prover_to_generated_verifier(output_root: &Path, dependency_root: &Path) { + let manifest_path = output_root.join("jolt-prover").join("Cargo.toml"); + let workspace_verifier = format!( + "jolt-verifier = {{ path = \"{}/jolt-verifier\" }}", + dependency_root.display() + ); + let generated_verifier = format!( + "jolt-verifier = {{ path = \"{}\" }}", + output_root.join("jolt-verifier").display() + ); + let manifest = std::fs::read_to_string(&manifest_path).expect("read generated prover manifest"); + let manifest = manifest.replace(&workspace_verifier, &generated_verifier); + std::fs::write(&manifest_path, manifest).expect("rewrite generated prover manifest"); +} + +fn assert_checked_in_generated_role_crate_sources_match(generated_crates: &[JoltGeneratedCrate]) { + let crates_root = workspace_root().join("crates"); + for generated in generated_crates { + for file in &generated.files { + let checked_in_path = crates_root.join(&generated.crate_name).join(&file.path); + let checked_in = + std::fs::read_to_string(&checked_in_path).expect("read checked-in generated file"); + assert_eq!( + checked_in, + file.source, + "checked-in generated crate file `{}` is stale; regenerate with the Bolt artifact writer", + checked_in_path.display() + ); + if generated.crate_name == "jolt-verifier" { + assert!( + !checked_in.contains("use jolt_prover") + && !checked_in.contains("jolt_prover::") + && !checked_in.contains("use jolt_kernels") + && !checked_in.contains("jolt_kernels::") + && !checked_in.contains("use jolt_core") + && !checked_in.contains("jolt_core::"), + "generated verifier file `{}` imports non-audit role/runtime code", + checked_in_path.display() + ); + } + if generated.crate_name == "jolt-prover" { + assert!( + !checked_in.contains("jolt_verifier::stages"), + "generated prover file `{}` imports verifier stage internals instead of only verifier-owned proof types", + checked_in_path.display() + ); + } + } + } +} + +fn assert_generated_commitment_self_parity_runs( + prover_source: &RustSourceFile, + verifier_source: &RustSourceFile, + main_source: &str, +) { + if !generated_jolt_runtime_available() { + return; + } + let dir = new_temp_dir("bolt_self_parity"); + let workspace_root = workspace_root(); + std::fs::write( + dir.join("Cargo.toml"), + generated_crate_manifest(&workspace_root), + ) + .expect("write generated cargo manifest"); + let src_dir = dir.join("src"); + std::fs::create_dir_all(&src_dir).expect("create generated src dir"); + std::fs::write(src_dir.join(&prover_source.filename), &prover_source.source) + .expect("write generated prover source"); + std::fs::write( + src_dir.join(&verifier_source.filename), + &verifier_source.source, + ) + .expect("write generated verifier source"); + std::fs::write(src_dir.join("main.rs"), main_source) + .expect("write generated self-parity harness"); + + let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_owned()); + let output = Command::new(cargo) + .arg("run") + .arg("--manifest-path") + .arg(dir.join("Cargo.toml")) + .arg("-q") + .env("CARGO_TARGET_DIR", dir.join("target")) + .output() + .expect("run generated self-parity crate"); + assert!( + output.status.success(), + "generated commitment self-parity failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let _ = std::fs::remove_dir_all(dir); +} + +fn assert_generated_stage1_self_parity_runs( + prover_source: &RustSourceFile, + verifier_source: &RustSourceFile, + main_source: &str, +) { + if !generated_jolt_runtime_available() { + return; + } + let dir = new_temp_dir("bolt_stage1_self_parity"); + let workspace_root = workspace_root(); + std::fs::write( + dir.join("Cargo.toml"), + generated_crate_manifest(&workspace_root), + ) + .expect("write generated cargo manifest"); + let src_dir = dir.join("src"); + std::fs::create_dir_all(&src_dir).expect("create generated src dir"); + let main_source = if verifier_source.source.contains("super::common") { + write_verifier_common_module(&src_dir, &workspace_root); + format!("mod common;\n{main_source}") + } else { + main_source.to_owned() + }; + std::fs::write(src_dir.join(&prover_source.filename), &prover_source.source) + .expect("write generated stage1 prover source"); + std::fs::write( + src_dir.join(&verifier_source.filename), + &verifier_source.source, + ) + .expect("write generated stage1 verifier source"); + std::fs::write(src_dir.join("main.rs"), main_source) + .expect("write generated stage1 self-parity harness"); + + let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_owned()); + let output = Command::new(cargo) + .arg("run") + .arg("--manifest-path") + .arg(dir.join("Cargo.toml")) + .arg("-q") + .env("CARGO_TARGET_DIR", dir.join("target")) + .output() + .expect("run generated stage1 self-parity crate"); + assert!( + output.status.success(), + "generated stage1 self-parity failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let _ = std::fs::remove_dir_all(dir); +} + +fn assert_generated_jolt_chain_self_parity_runs(files: &[&RustSourceFile], main_source: &str) { + if !generated_jolt_runtime_available() { + return; + } + let dir = new_temp_dir("bolt_chain_self_parity"); + let workspace_root = workspace_root(); + std::fs::write( + dir.join("Cargo.toml"), + generated_crate_manifest(&workspace_root), + ) + .expect("write generated cargo manifest"); + let src_dir = dir.join("src"); + std::fs::create_dir_all(&src_dir).expect("create generated src dir"); + let main_source = if files + .iter() + .any(|file| file.source.contains("super::common")) + { + write_verifier_common_module(&src_dir, &workspace_root); + format!("mod common;\n{main_source}") + } else { + main_source.to_owned() + }; + for file in files { + std::fs::write(src_dir.join(&file.filename), &file.source) + .expect("write generated chain source"); + } + std::fs::write(src_dir.join("main.rs"), main_source) + .expect("write generated chain self-parity harness"); + + let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_owned()); + let output = Command::new(cargo) + .arg("run") + .arg("--manifest-path") + .arg(dir.join("Cargo.toml")) + .arg("-q") + .env("CARGO_TARGET_DIR", dir.join("target")) + .output() + .expect("run generated chain self-parity crate"); + assert!( + output.status.success(), + "generated commitment+stage1 self-parity failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let _ = std::fs::remove_dir_all(dir); +} + +fn write_verifier_common_module(src_dir: &Path, workspace_root: &Path) { + let common = + std::fs::read_to_string(workspace_root.join("crates/jolt-verifier/src/stages/common.rs")) + .expect("read generated verifier common stage source"); + std::fs::write(src_dir.join("common.rs"), common).expect("write generated common source"); +} + +fn workspace_root() -> std::path::PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .and_then(Path::parent) + .expect("workspace root") + .to_path_buf() +} + +fn generated_jolt_runtime_available() -> bool { + let workspace_root = workspace_root(); + workspace_root + .join("crates/jolt-kernels/Cargo.toml") + .exists() + && workspace_root + .join("crates/jolt-verifier/src/stages/common.rs") + .exists() +} + +fn checked_in_generated_role_crates_available() -> bool { + let workspace_root = workspace_root(); + generated_jolt_runtime_available() + && workspace_root + .join("crates/jolt-prover/Cargo.toml") + .exists() + && workspace_root + .join("crates/jolt-verifier/Cargo.toml") + .exists() +} + +fn generated_crate_manifest(workspace_root: &Path) -> String { + format!( + r#"[package] +name = "generated-commitment-phase-check" +version = "0.0.0" +edition = "2021" + +[patch.crates-io] +ark-bn254 = {{ git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" }} +ark-ec = {{ git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" }} +ark-ff = {{ git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" }} +ark-serialize = {{ git = "https://github.com/a16z/arkworks-algebra", branch = "dev/twist-shout" }} + +[dependencies] +jolt-dory = {{ path = "{}" }} +jolt-field = {{ path = "{}" }} +jolt-kernels = {{ path = "{}" }} +jolt-lookup-tables = {{ path = "{}" }} +jolt-openings = {{ path = "{}" }} +jolt-poly = {{ path = "{}" }} +jolt-r1cs = {{ path = "{}" }} +jolt-sumcheck = {{ path = "{}" }} +jolt-transcript = {{ path = "{}" }} +jolt-witness = {{ path = "{}" }} +rayon = "1.12.0" +serde = {{ version = "1.0", default-features = false, features = ["derive"] }} +tracing = {{ version = "0.1.37", default-features = false, features = ["attributes"] }} +"#, + workspace_root.join("crates/jolt-dory").display(), + workspace_root.join("crates/jolt-field").display(), + workspace_root.join("crates/jolt-kernels").display(), + workspace_root.join("crates/jolt-lookup-tables").display(), + workspace_root.join("crates/jolt-openings").display(), + workspace_root.join("crates/jolt-poly").display(), + workspace_root.join("crates/jolt-r1cs").display(), + workspace_root.join("crates/jolt-sumcheck").display(), + workspace_root.join("crates/jolt-transcript").display(), + workspace_root.join("crates/jolt-witness").display(), + ) +} + +fn new_temp_dir(prefix: &str) -> std::path::PathBuf { + let nonce = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock after unix epoch") + .as_nanos(); + let dir = std::env::temp_dir().join(format!("{}_{}_{}", prefix, std::process::id(), nonce)); + std::fs::create_dir_all(&dir).expect("create generated crate temp dir"); + dir +} + +fn generated_small_self_parity_main() -> String { + let mut source = r#"mod prove_commitment_phase; +mod verify_commitment_phase; + +use std::borrow::Cow; + +use jolt_dory::DoryScheme; +use jolt_field::{Field, Fr}; +use jolt_transcript::{Blake2bTranscript, Transcript}; + +struct Inputs; + +impl prove_commitment_phase::CommitmentInputProvider for Inputs { + fn materialize(&mut self, oracle: &'static str) -> Option> { + match oracle { + "A" => Some(Cow::Owned(vec![Fr::from_u64(1), Fr::from_u64(2)])), + "B" => Some(Cow::Owned(vec![ + Fr::from_u64(3), + Fr::from_u64(4), + Fr::from_u64(5), + Fr::from_u64(6), + ])), + "Advice" => Some(Cow::Owned(vec![Fr::from_u64(0), Fr::from_u64(0)])), + _ => None, + } + } +} + +"# + .to_owned(); + source.push_str(tracing_transcript_support()); + source.push_str( + r#" +fn main() { + let prover_setup = + DoryScheme::setup_prover(prove_commitment_phase::COMMITMENT_BATCH_PLANS[0].num_vars); + let mut inputs = Inputs; + let mut prover_transcript = TracingTranscript::new(b"self"); + let prover = prove_commitment_phase::prove_commitment_phase( + &mut inputs, + &prover_setup, + &mut prover_transcript, + ) + .expect("prover commitment phase"); + + assert_eq!(prover.commitments.len(), 3); + assert!(prover.commitments[2].is_none()); + + let mut verifier_transcript = TracingTranscript::new(b"self"); + let verifier = verify_commitment_phase::verify_commitment_phase( + &prover.commitments, + &mut verifier_transcript, + ) + .expect("verifier commitment phase"); + + assert_eq!(prover.commitments, verifier.commitments); + assert_eq!(prover.records.len(), verifier.records.len()); + assert_transcript_step_parity(&prover_transcript, &verifier_transcript); +} +"#, + ); + source +} + +fn generated_pipeline_self_parity_main() -> String { + let mut source = "mod prove_commitment_phase; +mod verify_commitment_phase; + +use jolt_dory::DoryScheme; +use jolt_field::{Field, Fr}; +use jolt_transcript::{Blake2bTranscript, Transcript}; + +" + .to_owned(); + source.push_str(tracing_transcript_support()); + source.push_str( + r#" +fn main() { + let prover_setup = + DoryScheme::setup_prover(prove_commitment_phase::COMMITMENT_BATCH_PLANS[0].num_vars); + let inputs = prove_commitment_phase::CommitmentOracleInputs { + rd_inc: &[1], + ram_inc: &[2], + instruction_keys: &[Some(0x1234_5678_9abc_def0_0123_4567_89ab_cdefu128)], + ram_addresses: &[], + bytecode_indices: &[], + untrusted_advice: None, + trusted_advice: None, + }; + let mut oracles = prove_commitment_phase::build_commitment_oracles(&inputs) + .expect("build commitment oracles"); + let mut prover_transcript = TracingTranscript::new(b"pipeline"); + let prover = prove_commitment_phase::prove_commitment_phase( + &mut oracles, + &prover_setup, + &mut prover_transcript, + ) + .expect("prover commitment phase"); + + let expected_slots = prove_commitment_phase::COMMITMENT_BATCH_PLANS + .iter() + .map(|plan| plan.oracles.len()) + .sum::() + + prove_commitment_phase::OPTIONAL_COMMITMENT_PLANS.len(); + assert_eq!(prover.commitments.len(), expected_slots); + + let mut verifier_transcript = TracingTranscript::new(b"pipeline"); + let verifier = verify_commitment_phase::verify_commitment_phase( + &prover.commitments, + &mut verifier_transcript, + ) + .expect("verifier commitment phase"); + + assert_eq!(prover.commitments, verifier.commitments); + assert_eq!(prover.records.len(), verifier.records.len()); + for (prover_record, verifier_record) in prover.records.iter().zip(&verifier.records) { + assert_eq!(prover_record.artifact, verifier_record.artifact); + assert_eq!(prover_record.oracle, verifier_record.oracle); + assert_eq!(prover_record.label, verifier_record.label); + assert_eq!(prover_record.num_vars, verifier_record.num_vars); + } + assert_transcript_step_parity(&prover_transcript, &verifier_transcript); +} +"#, + ); + source +} + +fn generated_stage1_shape_self_parity_main() -> String { + let mut source = r"mod prove_stage1_outer; +mod verify_stage1_outer; + +use jolt_field::{Field, Fr}; +use jolt_kernels::stage1::Stage1ShapeKernelExecutor; +use jolt_transcript::{Blake2bTranscript, Transcript}; + +" + .to_owned(); + source.push_str(tracing_transcript_support()); + source.push_str(&stage1_verifier_proof_adapter(true)); + source.push_str( + r#" +fn main() { + let mut prover_executor = Stage1ShapeKernelExecutor; + let mut prover_transcript = TracingTranscript::new(b"stage1"); + let prover = prove_stage1_outer::prove_stage1_outer( + &mut prover_executor, + &mut prover_transcript, + ) + .expect("generated prover runs shape kernels"); + + let proof = verifier_proof_from_prover_artifacts(&prover); + let mut verifier_transcript = TracingTranscript::new(b"stage1"); + let verifier = verify_stage1_outer::verify_stage1_outer( + &proof, + &mut verifier_transcript, + ) + .expect("generated verifier accepts shape proof"); + + assert_eq!( + prover.sumchecks.len(), + prove_stage1_outer::STAGE1_SUMCHECK_DRIVERS.len() + ); + assert_eq!(prover.sumchecks.len(), verifier.sumchecks.len()); + assert_eq!(prover.opening_batches.len(), verifier.opening_batches.len()); + for (prover_batch, verifier_batch) in prover.opening_batches.iter().zip(&verifier.opening_batches) { + assert_eq!(prover_batch.symbol, verifier_batch.symbol); + assert_eq!(prover_batch.count, verifier_batch.count); + } + for (prover_sumcheck, verifier_sumcheck) in prover.sumchecks.iter().zip(&verifier.sumchecks) { + assert_eq!(prover_sumcheck.driver, verifier_sumcheck.driver); + assert_eq!(prover_sumcheck.evals.len(), verifier_sumcheck.evals.len()); + for (prover_eval, verifier_eval) in prover_sumcheck.evals.iter().zip(&verifier_sumcheck.evals) { + assert_eq!(prover_eval.name, verifier_eval.name); + assert_eq!(prover_eval.oracle, verifier_eval.oracle); + assert_eq!(prover_eval.value, verifier_eval.value); + } + assert_eq!( + prover_sumcheck.proof.round_polynomials.len(), + verifier_sumcheck.proof.round_polynomials.len() + ); + for (prover_round, verifier_round) in prover_sumcheck + .proof + .round_polynomials + .iter() + .zip(&verifier_sumcheck.proof.round_polynomials) + { + assert_eq!(prover_round.coefficients(), verifier_round.coefficients()); + } + } + assert_ne!(prover_transcript.state(), verifier_transcript.state()); +} +"#, + ); + source +} + +fn stage1_verifier_proof_adapter(clear_points: bool) -> String { + let point_expr = if clear_points { + "Vec::new()" + } else { + "sumcheck.point.clone()" + }; + r" +fn verifier_proof_from_prover_artifacts( + artifacts: &jolt_kernels::stage1::Stage1ExecutionArtifacts, +) -> verify_stage1_outer::Stage1Proof { + verify_stage1_outer::Stage1Proof { + sumchecks: artifacts + .sumchecks + .iter() + .map(|sumcheck| verify_stage1_outer::Stage1SumcheckOutput { + driver: sumcheck.driver, + point: $POINT_EXPR, + evals: sumcheck + .evals + .iter() + .map(|eval| verify_stage1_outer::Stage1NamedEval { + name: eval.name, + oracle: eval.oracle, + value: eval.value, + }) + .collect(), + proof: sumcheck.proof.clone(), + }) + .collect(), + } +} + +" + .replace("$POINT_EXPR", point_expr) +} + +fn generated_stage1_real_dispatch_main() -> &'static str { + r#"mod prove_stage1_outer; +mod verify_stage1_outer; + +use jolt_field::{Field, Fr}; +use jolt_kernels::stage1::{ + Stage1KernelError, Stage1ProverInputs, Stage1ProverKernelExecutor, +}; +use jolt_sumcheck::SumcheckError; +use jolt_transcript::{Blake2bTranscript, Transcript}; + +fn main() { + let inputs = Stage1ProverInputs::::empty(2); + let mut prover_executor = Stage1ProverKernelExecutor::new(inputs); + let mut prover_transcript = Blake2bTranscript::::new(b"stage1"); + let prover_error = prove_stage1_outer::prove_stage1_outer( + &mut prover_executor, + &mut prover_transcript, + ) + .expect_err("real prover requires uniskip extended evaluations"); + assert_eq!( + prover_error, + Stage1KernelError::MissingKernelInput { + kernel: "jolt_stage1_outer_uniskip", + input: "uniskip_extended_evals", + } + ); + + let proof = verify_stage1_outer::Stage1Proof { + sumchecks: vec![ + verify_stage1_outer::Stage1SumcheckOutput { + driver: "stage1.uniskip.sumcheck", + point: Vec::new(), + evals: Vec::new(), + proof: Default::default(), + }, + verify_stage1_outer::Stage1SumcheckOutput { + driver: "stage1.outer_remaining.sumcheck", + point: Vec::new(), + evals: Vec::new(), + proof: Default::default(), + }, + ], + }; + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + let verifier_error = verify_stage1_outer::verify_stage1_outer( + &proof, + &mut verifier_transcript, + ) + .expect_err("real verifier rejects empty uniskip proof"); + assert!(matches!( + verifier_error, + verify_stage1_outer::VerifyStage1Error::Sumcheck { + driver: "stage1.uniskip.sumcheck", + error: SumcheckError::WrongNumberOfRounds { expected: 1, got: 0 }, + } + )); +} +"# +} + +fn generated_stage1_synthetic_remaining_main() -> String { + let mut source = r"mod prove_stage1_outer; +mod verify_stage1_outer; + +use jolt_field::{Field, Fr}; +use jolt_kernels::stage1::{ + Stage1OuterRemainingContext, Stage1OuterRemainingEvaluator, Stage1ProverInputs, + Stage1ProverKernelExecutor, +}; +use jolt_poly::UnivariatePoly; +use jolt_sumcheck::SumcheckError; +use jolt_transcript::{Blake2bTranscript, Transcript}; + +struct SumZeroRemainingEvaluator; + +impl Stage1OuterRemainingEvaluator for SumZeroRemainingEvaluator { + fn evaluate(&self, _context: Stage1OuterRemainingContext<'_, Fr>, point: &[Fr]) -> Fr { + point[0] + point[0] - Fr::from_u64(1) + } + + fn evaluate_virtual_oracle( + &self, + _context: Stage1OuterRemainingContext<'_, Fr>, + _oracle: &str, + point: &[Fr], + ) -> Option { + Some(point.iter().copied().sum()) + } +} + +" + .to_owned(); + source.push_str(&stage1_verifier_proof_adapter(false)); + source.push_str( + r#" +fn main() { + let extended_evals = vec![Fr::from_u64(0); 9]; + let evaluator = SumZeroRemainingEvaluator; + let inputs = Stage1ProverInputs::::empty(2) + .with_uniskip_extended_evals(&extended_evals) + .with_outer_remaining_evaluator(&evaluator); + let mut prover_executor = Stage1ProverKernelExecutor::new(inputs); + let mut prover_transcript = Blake2bTranscript::::new(b"stage1"); + let prover_artifacts = prove_stage1_outer::prove_stage1_outer( + &mut prover_executor, + &mut prover_transcript, + ) + .expect("generated real stage1 prover succeeds"); + + let proof = verifier_proof_from_prover_artifacts(&prover_artifacts); + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + let verifier_artifacts = verify_stage1_outer::verify_stage1_outer( + &proof, + &mut verifier_transcript, + ) + .expect("generated real stage1 verifier accepts prover proof"); + + assert_eq!(prover_transcript.state(), verifier_transcript.state()); + assert_eq!(prover_artifacts.sumchecks.len(), 2); + assert_eq!(verifier_artifacts.sumchecks.len(), 2); + assert_eq!( + prover_artifacts.sumchecks[1].point, + verifier_artifacts.sumchecks[1].point + ); + + let mut extra_proof = proof.clone(); + extra_proof.sumchecks.push(proof.sumchecks[0].clone()); + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + assert!(matches!( + verify_stage1_outer::verify_stage1_outer(&extra_proof, &mut verifier_transcript), + Err(verify_stage1_outer::VerifyStage1Error::UnexpectedProofCount { + expected: 2, + got: 3, + }) + )); + + let mut wrong_driver = proof.clone(); + wrong_driver.sumchecks.swap(0, 1); + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + assert!(matches!( + verify_stage1_outer::verify_stage1_outer(&wrong_driver, &mut verifier_transcript), + Err(verify_stage1_outer::VerifyStage1Error::InvalidProof { + driver: "stage1.uniskip.sumcheck", + reason: "driver symbol mismatch", + }) + )); + + let mut wrong_round = proof.clone(); + let mut coefficients = wrong_round.sumchecks[0].proof.round_polynomials[0] + .coefficients() + .to_vec(); + coefficients[0] += Fr::from_u64(1); + wrong_round.sumchecks[0].proof.round_polynomials[0] = UnivariatePoly::new(coefficients); + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + assert!(matches!( + verify_stage1_outer::verify_stage1_outer(&wrong_round, &mut verifier_transcript), + Err(verify_stage1_outer::VerifyStage1Error::Sumcheck { + driver: "stage1.uniskip.sumcheck", + error: SumcheckError::RoundCheckFailed { .. }, + }) + )); + + let mut wrong_uniskip_eval = proof.clone(); + wrong_uniskip_eval.sumchecks[0].evals[0].value += Fr::from_u64(1); + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + assert!(matches!( + verify_stage1_outer::verify_stage1_outer(&wrong_uniskip_eval, &mut verifier_transcript), + Err(verify_stage1_outer::VerifyStage1Error::InvalidProof { + driver: "stage1.uniskip.sumcheck", + reason: "eval value mismatch", + }) + )); + + let mut wrong_remaining_eval = proof.clone(); + wrong_remaining_eval.sumchecks[1].evals.swap(0, 1); + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + assert!(matches!( + verify_stage1_outer::verify_stage1_outer(&wrong_remaining_eval, &mut verifier_transcript), + Err(verify_stage1_outer::VerifyStage1Error::InvalidProof { + driver: "stage1.outer_remaining.sumcheck", + reason: "eval name mismatch", + }) + )); + + let mut wrong_point = proof.clone(); + wrong_point.sumchecks[1].point[0] += Fr::from_u64(1); + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + assert!(matches!( + verify_stage1_outer::verify_stage1_outer(&wrong_point, &mut verifier_transcript), + Err(verify_stage1_outer::VerifyStage1Error::InvalidProof { + driver: "stage1.outer_remaining.sumcheck", + reason: "outer remaining point mismatch", + }) + )); +} +"#, + ); + source +} + +fn generated_stage1_r1cs_data_main() -> String { + let mut source = r"mod prove_stage1_outer; +mod verify_stage1_outer; + +use jolt_field::{Field, Fr}; +use jolt_kernels::stage1::{ + Stage1OuterR1csData, Stage1ProverInputs, Stage1ProverKernelExecutor, +}; +use jolt_r1cs::{constraints::rv64, R1csKey}; +use jolt_transcript::{Blake2bTranscript, Transcript}; + +" + .to_owned(); + source.push_str(&stage1_verifier_proof_adapter(false)); + source.push_str( + r#" +fn main() { + let key = R1csKey::new(rv64::rv64_constraints::(), 4); + let mut witness = vec![Fr::from_u64(0); key.num_cycles * key.num_vars_padded]; + for cycle in 0..key.num_cycles { + let base = cycle * key.num_vars_padded; + witness[base + rv64::V_CONST] = Fr::from_u64(1); + witness[base + rv64::V_FLAG_DO_NOT_UPDATE_UNEXPANDED_PC] = Fr::from_u64(1); + key.matrices + .check_witness(&witness[base..base + rv64::NUM_VARS_PER_CYCLE]) + .expect("noop cycle satisfies RV64 constraints"); + } + let data = Stage1OuterR1csData::new(&key, &witness).expect("valid R1CS witness shape"); + let inputs = Stage1ProverInputs::::empty(key.num_cycle_vars()) + .with_outer_remaining_evaluator(&data); + let mut prover_executor = Stage1ProverKernelExecutor::new(inputs); + let mut prover_transcript = Blake2bTranscript::::new(b"stage1"); + let prover_artifacts = prove_stage1_outer::prove_stage1_outer( + &mut prover_executor, + &mut prover_transcript, + ) + .expect("generated real stage1 prover succeeds with R1CS data"); + + let proof = verifier_proof_from_prover_artifacts(&prover_artifacts); + let mut verifier_transcript = Blake2bTranscript::::new(b"stage1"); + let verifier_artifacts = verify_stage1_outer::verify_stage1_outer( + &proof, + &mut verifier_transcript, + ) + .expect("generated real stage1 verifier accepts R1CS-backed proof"); + + assert_eq!(prover_transcript.state(), verifier_transcript.state()); + assert_eq!(prover_artifacts.sumchecks.len(), 2); + assert_eq!(verifier_artifacts.sumchecks.len(), 2); + for (prover_sumcheck, verifier_sumcheck) in prover_artifacts + .sumchecks + .iter() + .zip(verifier_artifacts.sumchecks.iter()) + { + assert_eq!(prover_sumcheck.point, verifier_sumcheck.point); + assert_eq!(prover_sumcheck.evals.len(), verifier_sumcheck.evals.len()); + for (prover_eval, verifier_eval) in prover_sumcheck.evals.iter().zip(&verifier_sumcheck.evals) { + assert_eq!(prover_eval.oracle, verifier_eval.oracle); + assert_eq!(prover_eval.value, verifier_eval.value); + } + } +} +"#, + ); + source +} + +fn generated_commitment_stage1_chain_main() -> String { + let mut source = r"mod prove_commitment_phase; +mod prove_stage1_outer; +mod verify_commitment_phase; +mod verify_stage1_outer; + +use jolt_dory::DoryScheme; +use jolt_field::{Field, Fr}; +use jolt_kernels::stage1::{ + Stage1OuterRemainingContext, Stage1OuterRemainingEvaluator, Stage1ProverInputs, + Stage1ProverKernelExecutor, +}; +use jolt_transcript::{Blake2bTranscript, Transcript}; + +struct SumZeroRemainingEvaluator; + +impl Stage1OuterRemainingEvaluator for SumZeroRemainingEvaluator { + fn evaluate(&self, _context: Stage1OuterRemainingContext<'_, Fr>, point: &[Fr]) -> Fr { + point[0] + point[0] - Fr::from_u64(1) + } + + fn evaluate_virtual_oracle( + &self, + _context: Stage1OuterRemainingContext<'_, Fr>, + _oracle: &str, + point: &[Fr], + ) -> Option { + Some(point.iter().copied().sum()) + } +} + +" + .to_owned(); + source.push_str(tracing_transcript_support()); + source.push_str(&stage1_verifier_proof_adapter(false)); + source.push_str( + r#" +fn main() { + let prover_setup = + DoryScheme::setup_prover(prove_commitment_phase::COMMITMENT_BATCH_PLANS[0].num_vars); + let commitment_inputs = prove_commitment_phase::CommitmentOracleInputs { + rd_inc: &[1, 0, 0, 0], + ram_inc: &[2, 0, 0, 0], + instruction_keys: &[ + Some(0x1234_5678_9abc_def0_0123_4567_89ab_cdefu128), + Some(0), + Some(0), + Some(0), + ], + ram_addresses: &[Some(0), Some(1), Some(2), Some(3)], + bytecode_indices: &[Some(0), Some(1), Some(2), Some(3)], + untrusted_advice: None, + trusted_advice: None, + }; + let mut commitment_oracles = prove_commitment_phase::build_commitment_oracles( + &commitment_inputs, + ) + .expect("build commitment oracles"); + let mut prover_transcript = TracingTranscript::new(b"jolt-chain"); + let commitment = prove_commitment_phase::prove_commitment_phase( + &mut commitment_oracles, + &prover_setup, + &mut prover_transcript, + ) + .expect("prover commitment phase"); + + let extended_evals = vec![Fr::from_u64(0); 9]; + let evaluator = SumZeroRemainingEvaluator; + let stage1_inputs = Stage1ProverInputs::::empty(2) + .with_uniskip_extended_evals(&extended_evals) + .with_outer_remaining_evaluator(&evaluator); + let mut stage1_prover_executor = Stage1ProverKernelExecutor::new(stage1_inputs); + let stage1 = prove_stage1_outer::prove_stage1_outer( + &mut stage1_prover_executor, + &mut prover_transcript, + ) + .expect("stage1 prover phase"); + + let mut verifier_transcript = TracingTranscript::new(b"jolt-chain"); + let verified_commitment = verify_commitment_phase::verify_commitment_phase( + &commitment.commitments, + &mut verifier_transcript, + ) + .expect("verifier commitment phase"); + let stage1_proof = verifier_proof_from_prover_artifacts(&stage1); + let verified_stage1 = verify_stage1_outer::verify_stage1_outer( + &stage1_proof, + &mut verifier_transcript, + ) + .expect("stage1 verifier phase"); + + assert_eq!(commitment.commitments, verified_commitment.commitments); + assert_eq!(stage1.sumchecks.len(), 2); + assert_eq!(verified_stage1.sumchecks.len(), 2); + assert_eq!(stage1.sumchecks[1].point, verified_stage1.sumchecks[1].point); + assert_transcript_step_parity(&prover_transcript, &verifier_transcript); +} +"#, + ); + source +} + +fn tracing_transcript_support() -> &'static str { + r"#[derive(Clone, Debug, PartialEq, Eq)] +enum TranscriptEvent { + Init([u8; 32]), + Append { bytes: Vec, state: [u8; 32] }, + Challenge { state: [u8; 32] }, +} + +#[derive(Clone, Default)] +struct TracingTranscript { + inner: Blake2bTranscript, + events: Vec, +} + +impl Transcript for TracingTranscript { + type Challenge = Fr; + + fn new(label: &'static [u8]) -> Self { + let inner = Blake2bTranscript::::new(label); + let events = vec![TranscriptEvent::Init(*inner.state())]; + Self { inner, events } + } + + fn append_bytes(&mut self, bytes: &[u8]) { + self.inner.append_bytes(bytes); + self.events.push(TranscriptEvent::Append { + bytes: bytes.to_vec(), + state: *self.inner.state(), + }); + } + + fn challenge(&mut self) -> Fr { + let challenge = self.inner.challenge(); + self.events.push(TranscriptEvent::Challenge { + state: *self.inner.state(), + }); + challenge + } + + fn state(&self) -> &[u8; 32] { + self.inner.state() + } +} + +fn assert_transcript_step_parity(prover: &TracingTranscript, verifier: &TracingTranscript) { + assert_eq!(prover.events, verifier.events); + assert_eq!(prover.state(), verifier.state()); +} + +" +} diff --git a/crates/bolt/tests/verifier_cleanup.rs b/crates/bolt/tests/verifier_cleanup.rs new file mode 100644 index 0000000000..1a8318501f --- /dev/null +++ b/crates/bolt/tests/verifier_cleanup.rs @@ -0,0 +1,606 @@ +#![expect( + clippy::expect_used, + clippy::print_stderr, + reason = "verifier cleanup tests use explicit panic messages and print metrics for CI logs" +)] + +use std::path::{Path, PathBuf}; + +const GENERATED_VERIFIER_TARGET_LOC: usize = 6_000; +const GENERATED_VERIFIER_STRETCH_LOC: usize = 3_000; +const VERIFIER_RS_TARGET_LOC: usize = 500; +const VERIFIER_RS_STRETCH_LOC: usize = 350; +const STAGE6_STAGE7_TARGET_LOC: usize = 3_000; + +const GENERATED_VERIFIER_BASELINE_LOC_CEILING: usize = 9_185; +const SHARED_RUNTIME_BASELINE_LOC_CEILING: usize = 1_900; +const VERIFIER_RS_BASELINE_LOC_CEILING: usize = VERIFIER_RS_TARGET_LOC; +const STAGE6_STAGE7_BASELINE_LOC_CEILING: usize = STAGE6_STAGE7_TARGET_LOC; +const STAGE_LOCAL_PLAN_STRUCT_BASELINE_CEILING: usize = 18; +const FIELD_EXPR_OPERAND_CONSTANT_BASELINE_CEILING: usize = 0; +const STAGE_HELPER_FUNCTION_BASELINE_CEILING: usize = 38; +const RELATION_STRING_SITE_BASELINE_CEILING: usize = 72; + +const ALLOWED_JOLT_PROTOCOL_SYMBOLS: &[&str] = &[ + "jolt.commitment_phase", + "jolt.main_witness_commit_domain", + "jolt.main_witness_commitments", + "jolt.main_witness_polys", + "jolt.ram_address_domain", + "jolt.stage1.outer.remaining", + "jolt.stage1.outer.uniskip", + "jolt.stage1_outer", + "jolt.stage1_uniskip_domain", + "jolt.stage2", + "jolt.stage2.batched", + "jolt.stage2.instruction_lookup.claim_reduction", + "jolt.stage2.product_virtual.remainder", + "jolt.stage2.product_virtual.uniskip", + "jolt.stage2.ram.output_check", + "jolt.stage2.ram.output_check.layout", + "jolt.stage2.ram.raf_evaluation", + "jolt.stage2.ram.read_write", + "jolt.stage2_ram_rw_domain", + "jolt.stage2_uniskip_domain", + "jolt.stage3", + "jolt.stage3.batched", + "jolt.stage3.instruction_input", + "jolt.stage3.registers_claim_reduction", + "jolt.stage3.spartan_shift", + "jolt.stage4", + "jolt.stage4.batched", + "jolt.stage4.ram_val_check", + "jolt.stage4.registers_read_write", + "jolt.stage4_registers_rw_domain", + "jolt.stage5.batched", + "jolt.stage5.instruction_read_raf", + "jolt.stage5.ram_ra_claim_reduction", + "jolt.stage5.registers_val_evaluation", + "jolt.stage5_instruction_ra_chunk_domain", + "jolt.stage5_instruction_read_raf_domain", + "jolt.stage6.batched", + "jolt.stage6.booleanity", + "jolt.stage6.bytecode_read_raf", + "jolt.stage6.hamming_booleanity", + "jolt.stage6.inc_claim_reduction", + "jolt.stage6.instruction_ra_virtual", + "jolt.stage6.ram_ra_virtual", + "jolt.stage6_booleanity_domain", + "jolt.stage6_bytecode_read_raf_domain", + "jolt.stage7.batched", + "jolt.stage7.hamming_booleanity", + "jolt.stage7.hamming_weight_claim_reduction", + "jolt.stage7_hamming_weight_claim_reduction_domain", + "jolt.stage8", + "jolt.trace_domain", + "jolt.trusted_advice_commitment", + "jolt.untrusted_advice_commitment", +]; + +const GENERIC_COMPILER_JOLT_PATTERNS: &[&str] = &[ + "jolt.", + "Jolt", + "jolt_", + "jolt-", + "jolt_core", + "stage1_outer", + "stage1", + "stage2", + "stage3", + "stage4", + "stage5", + "stage6", + "stage7", + "stage8", + "uniskip", + "spartan", + "bytecode", + "hamming", + "instruction_read", + "ram_val", + "ram_ra", + "registers_read", + "lookup", + "dory", + "bn254", +]; + +#[derive(Debug, Default)] +struct VerifierCleanupMetrics { + total_loc: usize, + generated_surface_loc: usize, + shared_runtime_loc: usize, + verifier_rs_loc: usize, + stage6_stage7_loc: usize, + stage_local_generic_plan_structs: usize, + field_expr_operand_constants: usize, + stage_local_helper_functions: usize, + relation_string_sites: usize, +} + +#[test] +fn checked_in_generated_verifier_metrics_are_recorded_and_bounded() { + let verifier_src = workspace_root().join("crates/jolt-verifier/src"); + if !verifier_src.exists() { + return; + } + let metrics = verifier_cleanup_metrics(&verifier_src); + + eprintln!( + "\nGenerated verifier cleanup metrics\n\ + generated_surface_loc: {generated_surface_loc} (target <= {target_loc}, stretch <= {stretch_loc})\n\ + shared_runtime_loc: {shared_runtime_loc} (baseline ceiling <= {shared_runtime_baseline})\n\ + total_loc: {total_loc} (baseline ceiling <= {baseline_loc})\n\ + verifier_rs_loc: {verifier_rs_loc} (target <= {verifier_target}, stretch <= {verifier_stretch}, baseline ceiling <= {verifier_baseline})\n\ + stage6_stage7_loc: {stage6_stage7_loc} (target <= {stage67_target}, baseline ceiling <= {stage67_baseline})\n\ + stage_local_generic_plan_structs: {plan_structs} (baseline ceiling <= {plan_baseline})\n\ + field_expr_operand_constants: {operand_constants} (baseline ceiling <= {operand_baseline})\n\ + stage_local_helper_functions: {helper_functions} (baseline ceiling <= {helper_baseline})\n\ + relation_string_sites: {relation_sites} (baseline ceiling <= {relation_baseline})", + generated_surface_loc = metrics.generated_surface_loc, + shared_runtime_loc = metrics.shared_runtime_loc, + shared_runtime_baseline = SHARED_RUNTIME_BASELINE_LOC_CEILING, + total_loc = metrics.total_loc, + target_loc = GENERATED_VERIFIER_TARGET_LOC, + stretch_loc = GENERATED_VERIFIER_STRETCH_LOC, + baseline_loc = GENERATED_VERIFIER_BASELINE_LOC_CEILING, + verifier_rs_loc = metrics.verifier_rs_loc, + verifier_target = VERIFIER_RS_TARGET_LOC, + verifier_stretch = VERIFIER_RS_STRETCH_LOC, + verifier_baseline = VERIFIER_RS_BASELINE_LOC_CEILING, + stage6_stage7_loc = metrics.stage6_stage7_loc, + stage67_target = STAGE6_STAGE7_TARGET_LOC, + stage67_baseline = STAGE6_STAGE7_BASELINE_LOC_CEILING, + plan_structs = metrics.stage_local_generic_plan_structs, + plan_baseline = STAGE_LOCAL_PLAN_STRUCT_BASELINE_CEILING, + operand_constants = metrics.field_expr_operand_constants, + operand_baseline = FIELD_EXPR_OPERAND_CONSTANT_BASELINE_CEILING, + helper_functions = metrics.stage_local_helper_functions, + helper_baseline = STAGE_HELPER_FUNCTION_BASELINE_CEILING, + relation_sites = metrics.relation_string_sites, + relation_baseline = RELATION_STRING_SITE_BASELINE_CEILING, + ); + + assert!( + metrics.generated_surface_loc <= GENERATED_VERIFIER_TARGET_LOC, + "generated verifier surface is {} LOC; keep reducing generated stage/orchestration code or intentionally update the cleanup target", + metrics.generated_surface_loc + ); + assert!( + metrics.generated_surface_loc > GENERATED_VERIFIER_STRETCH_LOC, + "cleanup metric reached the stretch target; tighten the generated verifier surface gate" + ); + assert!( + metrics.shared_runtime_loc <= SHARED_RUNTIME_BASELINE_LOC_CEILING, + "shared verifier runtime grew to {} LOC; keep generic runtime small and audited", + metrics.shared_runtime_loc + ); + assert!( + metrics.total_loc <= GENERATED_VERIFIER_BASELINE_LOC_CEILING, + "checked-in verifier grew to {} LOC; lower generated/runtime surface, or intentionally update the cleanup baseline", + metrics.total_loc + ); + assert!( + metrics.verifier_rs_loc <= VERIFIER_RS_BASELINE_LOC_CEILING, + "top-level verifier grew to {} LOC; keep orchestration small and readable", + metrics.verifier_rs_loc + ); + assert!( + metrics.stage6_stage7_loc <= STAGE6_STAGE7_BASELINE_LOC_CEILING, + "Stage 6/7 generated verifier surface grew to {} LOC; compact plan data before adding more generated code", + metrics.stage6_stage7_loc + ); + assert!( + metrics.stage_local_generic_plan_structs <= STAGE_LOCAL_PLAN_STRUCT_BASELINE_CEILING, + "stage-local generic plan struct count grew to {}; move shared plan types into common verifier runtime", + metrics.stage_local_generic_plan_structs + ); + assert!( + metrics.field_expr_operand_constants == FIELD_EXPR_OPERAND_CONSTANT_BASELINE_CEILING, + "field-expression operand constants grew to {}; compact field expression encoding", + metrics.field_expr_operand_constants + ); + assert!( + metrics.stage_local_helper_functions <= STAGE_HELPER_FUNCTION_BASELINE_CEILING, + "stage-local helper function count grew to {}; factor verifier mechanics into shared runtime", + metrics.stage_local_helper_functions + ); + assert!( + metrics.relation_string_sites <= RELATION_STRING_SITE_BASELINE_CEILING, + "relation string sites grew to {}; prefer typed relation plan data or explicit allowlists", + metrics.relation_string_sites + ); +} + +#[test] +fn checked_in_generated_verifier_respects_boundary_hygiene() { + let verifier_root = workspace_root().join("crates/jolt-verifier"); + if !verifier_root.exists() { + return; + } + let manifest = + std::fs::read_to_string(verifier_root.join("Cargo.toml")).expect("read verifier manifest"); + for package in [ + "jolt-prover", + "jolt-kernels", + "jolt-core", + "jolt-equivalence", + "jolt-profiling", + "tracer", + ] { + assert!( + !manifest.contains(package), + "generated verifier manifest depends on forbidden package `{package}`" + ); + } + + for path in rust_files(&verifier_root.join("src")) { + let source = std::fs::read_to_string(&path).expect("read verifier source"); + for pattern in [ + "use jolt_prover", + "jolt_prover::", + "use jolt_kernels", + "jolt_kernels::", + "use jolt_core", + "jolt_core::", + "use jolt_equivalence", + "jolt_equivalence::", + "use jolt_profiling", + "jolt_profiling::", + "use tracer", + "tracer::", + ] { + assert!( + !source.contains(pattern), + "generated verifier source `{}` contains forbidden import/reference `{pattern}`", + path.display() + ); + } + assert!( + !source.contains("JoltField::Challenge") + && !source.contains("Transcript") + && !source.contains("Challenge = <"), + "generated verifier source `{}` drifted away from the full-field transcript path", + path.display() + ); + } +} + +#[test] +fn verifier_cpu_fixtures_are_kernel_free() { + let fixtures = workspace_root().join("crates/bolt/tests/fixtures"); + if !fixtures.exists() { + eprintln!("skipping optional verifier MLIR scratch fixture check; run commitment_ir with JOLT_UPDATE_GOLDENS=1 to materialize fixtures"); + return; + } + let mut checked = 0usize; + for path in files_with_extension(&fixtures, "mlir") { + let file_name = path + .file_name() + .and_then(|name| name.to_str()) + .expect("fixture file name"); + if !file_name.contains("verifier") { + continue; + } + checked += 1; + let source = std::fs::read_to_string(&path).expect("read verifier MLIR fixture"); + for pattern in ["kernel = @", "\"cpu.kernel\"", "\"compute.kernel\""] { + assert!( + !source.contains(pattern), + "verifier MLIR fixture `{}` contains forbidden kernel marker `{pattern}`", + path.display() + ); + } + } + assert!(checked > 0, "no verifier MLIR fixtures were checked"); +} + +#[test] +fn checked_in_generated_verifier_protocol_symbols_are_allowlisted() { + let verifier_root = workspace_root().join("crates/jolt-verifier/src"); + if !verifier_root.exists() { + return; + } + let mut checked = 0usize; + for path in rust_files(&verifier_root) { + let source = std::fs::read_to_string(&path).expect("read verifier source"); + for symbol in quoted_jolt_protocol_symbols(&source) { + checked += 1; + assert_allowed_jolt_protocol_symbol(&path, symbol); + } + } + assert!( + checked > 0, + "no generated verifier Jolt symbols were checked" + ); +} + +#[test] +fn verifier_mlir_fixtures_protocol_symbols_are_allowlisted() { + let fixtures = workspace_root().join("crates/bolt/tests/fixtures"); + if !fixtures.exists() { + eprintln!("skipping optional verifier MLIR scratch symbol check; run commitment_ir with JOLT_UPDATE_GOLDENS=1 to materialize fixtures"); + return; + } + let mut checked = 0usize; + for path in files_with_extension(&fixtures, "mlir") { + let file_name = path + .file_name() + .and_then(|name| name.to_str()) + .expect("fixture file name"); + if !file_name.contains("verifier") { + continue; + } + let source = std::fs::read_to_string(&path).expect("read verifier MLIR fixture"); + for symbol in mlir_jolt_protocol_symbols(&source) { + checked += 1; + assert_allowed_jolt_protocol_symbol(&path, symbol); + } + } + assert!(checked > 0, "no verifier MLIR Jolt symbols were checked"); +} + +#[test] +fn generic_compiler_rejects_jolt_protocol_strings() { + let root = workspace_root(); + let mut offenders = Vec::new(); + for path in generic_compiler_source_files(&root) { + let source = std::fs::read_to_string(&path).expect("read generic compiler source"); + let hits = count_generic_compiler_jolt_hits(&source); + if hits == 0 { + continue; + } + + let relative = relative_workspace_path(&root, &path); + offenders.push(format!("{relative}: {hits} hit(s)")); + } + assert!( + offenders.is_empty(), + "generic compiler source contains quarantined Jolt protocol strings:\n{}", + offenders.join("\n") + ); +} + +#[test] +fn jolt_artifact_apis_are_quarantined_out_of_generic_exports() { + let root = workspace_root(); + let artifact_source = + std::fs::read_to_string(root.join("crates/bolt/src/emit/rust/artifacts.rs")) + .expect("read generic artifact assembly"); + for pattern in [ + "JoltProtocolStage", + "JoltArtifactCrate", + "JoltRustArtifact", + "JoltGeneratedCrate", + "JoltGeneratedFile", + "jolt_artifact_config", + "jolt_rust_artifact", + "assemble_jolt_generated_crates", + "assemble_jolt_workspace_generated_crates", + "write_jolt_generated_crates", + "validate_jolt_rust_artifact_imports", + ] { + assert!( + !artifact_source.contains(pattern), + "generic artifact assembly still exposes quarantined Jolt API `{pattern}`" + ); + } + + let rust_mod_source = std::fs::read_to_string(root.join("crates/bolt/src/emit/rust/mod.rs")) + .expect("read Rust emitter exports"); + assert!( + !rust_mod_source.contains("assemble_jolt_") + && !rust_mod_source.contains("JoltProtocolStage") + && !rust_mod_source.contains("jolt_artifact_config"), + "generic Rust emitter exports still re-export Jolt artifact APIs" + ); + + let lib_source = + std::fs::read_to_string(root.join("crates/bolt/src/lib.rs")).expect("read bolt lib"); + assert!( + !lib_source.contains("pub use protocols::jolt"), + "root bolt exports must keep Jolt APIs under bolt::protocols::jolt" + ); +} + +fn verifier_cleanup_metrics(verifier_src: &Path) -> VerifierCleanupMetrics { + let mut metrics = VerifierCleanupMetrics::default(); + for path in rust_files(verifier_src) { + let source = std::fs::read_to_string(&path).expect("read verifier source"); + let relative = path + .strip_prefix(verifier_src) + .expect("relative verifier path"); + let line_count = source.lines().count(); + metrics.total_loc += line_count; + if relative == Path::new("stages/common.rs") { + metrics.shared_runtime_loc += line_count; + } else { + metrics.generated_surface_loc += line_count; + } + if relative == Path::new("verifier.rs") { + metrics.verifier_rs_loc = line_count; + } + if relative == Path::new("stages/stage6.rs") || relative == Path::new("stages/stage7.rs") { + metrics.stage6_stage7_loc += line_count; + } + if relative.starts_with("stages") { + metrics.stage_local_generic_plan_structs += + count_stage_local_generic_plan_structs(&source); + metrics.field_expr_operand_constants += count_field_expr_operand_constants(&source); + metrics.stage_local_helper_functions += count_stage_local_helper_functions(&source); + metrics.relation_string_sites += count_relation_string_sites(&source); + } + } + metrics +} + +fn count_stage_local_generic_plan_structs(source: &str) -> usize { + const PLAN_SUFFIXES: &[&str] = &[ + "FieldExprPlan", + "OpeningClaimPlan", + "OpeningClaimEqualityPlan", + "SumcheckClaimPlan", + "SumcheckDriverPlan", + "SumcheckEvalPlan", + "SumcheckInstanceResultPlan", + "PointSlicePlan", + "PointConcatPlan", + "ProgramStepPlan", + "TranscriptSqueezePlan", + "TranscriptAbsorbBytesPlan", + "CpuProgramPlan", + "VerifierProgramPlan", + "NamedEval", + ]; + source + .lines() + .filter(|line| { + let line = line.trim_start(); + (line.starts_with("pub struct Stage") || line.starts_with("pub type Stage")) + && PLAN_SUFFIXES.iter().any(|suffix| line.contains(suffix)) + }) + .count() +} + +fn count_field_expr_operand_constants(source: &str) -> usize { + source + .lines() + .filter(|line| line.contains("FIELD_EXPR_") && line.contains("OPERAND")) + .count() +} + +fn count_stage_local_helper_functions(source: &str) -> usize { + const HELPER_PREFIXES: &[&str] = &[ + "fn evaluate_stage", + "fn verify_opening_equalities", + "fn append_opening_claims", + "fn find_", + "fn expected_", + "fn pow_field", + "fn single_operand", + "fn require_operand_count", + ]; + source + .lines() + .filter(|line| { + let line = line.trim_start(); + HELPER_PREFIXES + .iter() + .any(|prefix| line.starts_with(prefix)) + }) + .count() +} + +fn count_relation_string_sites(source: &str) -> usize { + source + .lines() + .filter(|line| { + line.contains("match instance.relation") + || line.contains("match claim.relation") + || line.contains("match driver.relation") + || line.contains("relation: Some(\"jolt.") + || line.contains("relation: \"jolt.") + }) + .count() +} + +fn assert_allowed_jolt_protocol_symbol(path: &Path, symbol: &str) { + assert!( + ALLOWED_JOLT_PROTOCOL_SYMBOLS.contains(&symbol), + "`{}` contains unreviewed Jolt protocol symbol `{symbol}`", + path.display() + ); +} + +fn quoted_jolt_protocol_symbols(source: &str) -> Vec<&str> { + let mut symbols = Vec::new(); + let mut rest = source; + while let Some(offset) = rest.find("\"jolt.") { + let after_quote = &rest[offset + 1..]; + if let Some(end) = after_quote.find('"') { + symbols.push(&after_quote[..end]); + rest = &after_quote[end + 1..]; + } else { + break; + } + } + symbols +} + +fn mlir_jolt_protocol_symbols(source: &str) -> Vec<&str> { + let mut symbols = Vec::new(); + let mut rest = source; + while let Some(offset) = rest.find("@jolt") { + let after_at = &rest[offset + 1..]; + let end = after_at + .char_indices() + .find_map(|(index, ch)| { + (!matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '.')).then_some(index) + }) + .unwrap_or(after_at.len()); + symbols.push(&after_at[..end]); + rest = &after_at[end..]; + } + symbols +} + +fn generic_compiler_source_files(root: &Path) -> Vec { + let source_root = root.join("crates/bolt/src"); + let mut files = rust_files(&source_root) + .into_iter() + .filter(|path| { + !relative_workspace_path(root, path).starts_with("crates/bolt/src/protocols/") + }) + .collect::>(); + files.sort(); + files +} + +fn count_generic_compiler_jolt_hits(source: &str) -> usize { + source + .lines() + .filter(|line| { + GENERIC_COMPILER_JOLT_PATTERNS + .iter() + .any(|pattern| line.contains(pattern)) + }) + .count() +} + +fn relative_workspace_path(root: &Path, path: &Path) -> String { + path.strip_prefix(root) + .expect("workspace-relative path") + .to_string_lossy() + .replace('\\', "/") +} + +fn rust_files(root: &Path) -> Vec { + files_with_extension(root, "rs") +} + +fn files_with_extension(root: &Path, extension: &str) -> Vec { + let mut files = Vec::new(); + collect_files_with_extension(root, extension, &mut files); + files.sort(); + files +} + +fn collect_files_with_extension(root: &Path, extension: &str, files: &mut Vec) { + for entry in std::fs::read_dir(root).expect("read directory") { + let entry = entry.expect("read directory entry"); + let path = entry.path(); + if path.is_dir() { + collect_files_with_extension(&path, extension, files); + } else if path.extension().and_then(|ext| ext.to_str()) == Some(extension) { + files.push(path); + } + } +} + +fn workspace_root() -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .and_then(Path::parent) + .expect("workspace root") + .to_path_buf() +} diff --git a/scripts/setup-bolt-dev.sh b/scripts/setup-bolt-dev.sh new file mode 100755 index 0000000000..3c1cf76238 --- /dev/null +++ b/scripts/setup-bolt-dev.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -euo pipefail + +repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +env_file="${repo_root}/.bolt-dev-env" + +usage() { + cat <<'USAGE' +Usage: scripts/setup-bolt-dev.sh + +Installs/configures the local dependencies needed by Bolt MLIR/codegen checks +on macOS: + + - Homebrew llvm + - rustup components: rust-src, rustfmt, clippy + - cargo-nextest + - jolt CLI installed from this checkout + - .bolt-dev-env with MLIR_SYS_220_PREFIX, PATH, SDKROOT, and bindgen flags + +After it finishes, run: + + source .bolt-dev-env +USAGE +} + +if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then + usage + exit 0 +fi + +if [[ "$(uname -s)" != "Darwin" ]]; then + cat >&2 <<'EOF' +scripts/setup-bolt-dev.sh currently supports macOS/Homebrew. + +Bolt's MLIR path needs LLVM/MLIR 22 visible through llvm-config. On non-macOS +hosts, install an LLVM 22 toolchain and export: + + MLIR_SYS_220_PREFIX= + PATH=/bin:$PATH + +EOF + exit 1 +fi + +if ! command -v brew >/dev/null 2>&1; then + cat >&2 <<'EOF' +Homebrew is required to install the LLVM toolchain for this helper. +Install Homebrew first: https://brew.sh +EOF + exit 1 +fi + +if ! command -v xcrun >/dev/null 2>&1; then + cat >&2 <<'EOF' +xcrun is missing. Install the Xcode command line tools first: + + xcode-select --install +EOF + exit 1 +fi + +if ! brew list llvm >/dev/null 2>&1; then + brew install llvm +fi + +if command -v rustup >/dev/null 2>&1; then + rustup component add rust-src rustfmt clippy +else + cat >&2 <<'EOF' +rustup is not installed. Install Rust from https://rustup.rs, then rerun this +script so it can add rust-src, rustfmt, and clippy. +EOF + exit 1 +fi + +if ! cargo nextest --version >/dev/null 2>&1; then + cargo install cargo-nextest --locked +fi + +cargo install --path "${repo_root}" --locked --force + +llvm_prefix="$(brew --prefix llvm)" +sdkroot="$(xcrun --show-sdk-path)" + +cat >"${env_file}" </dev/null + +cat <