diff --git a/constraint/gkr.go b/constraint/gkr.go index f295f82de4..c7313f92c8 100644 --- a/constraint/gkr.go +++ b/constraint/gkr.go @@ -57,6 +57,32 @@ func (l GkrSumcheckLevel) NbClaims() int { func (l GkrSumcheckLevel) ClaimGroups() []GkrClaimGroup { return l } func (l GkrSumcheckLevel) FinalEvalProofIndex(wireI, _ int) int { return wireI } +// SingleClaimSource returns the unique distinct claim source across the whole +// level when every claim in the level refers to the same evaluation point. +// This is the schedule-side eligibility boundary for the Gru24 Section 3.2 +// optimization; folded multi-source levels return false. +func (l GkrSumcheckLevel) SingleClaimSource() (GkrClaimSource, bool) { + if len(l) == 0 { + return GkrClaimSource{}, false + } + + var first GkrClaimSource + found := false + for _, group := range l { + for _, src := range group.ClaimSources { + if !found { + first = src + found = true + continue + } + if src != first { + return GkrClaimSource{}, false + } + } + } + return first, found +} + func (l GkrSkipLevel) NbOutgoingEvalPoints() int { return len(l.ClaimSources) } func (l GkrSkipLevel) NbClaims() int { return GkrClaimGroup(l).NbClaims() diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 9dfd6ce8cb..2866a3f148 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -43,6 +43,15 @@ func (e *zeroCheckLazyClaims) degree(int) int { return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) } +func (e *zeroCheckLazyClaims) roundCombinationCoeff(round int) ({{ .ElementType }}, bool) { + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + src, ok := level.SingleClaimSource() + if !ok { + return {{ .ElementType }}{}, false + } + return e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex][round], true +} + // verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. // The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying // ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all @@ -57,7 +66,8 @@ func (e *zeroCheckLazyClaims) degree(int) int { // that the full sum matches purportedValue. func (e *zeroCheckLazyClaims) verifyFinalEval(r []{{ .ElementType }}, purportedValue {{ .ElementType }}, uniqueInputEvaluations []{{ .ElementType }}) error { e.resources.outgoingEvalPoints[e.levelI] = [][]{{ .ElementType }}{r} - level := e.resources.schedule[e.levelI] + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + _, optimized := level.SingleClaimSource() gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) var claimedEvals polynomial.Polynomial @@ -80,11 +90,17 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []{{ .ElementType }}, purportedV gateEval.Set(evaluator.evaluate()) } - for _, src := range group.ClaimSources { - eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) - var term {{ .ElementType }} - term.Mul(&eq, &gateEval) - claimedEvals = append(claimedEvals, term) + if optimized { + for range group.ClaimSources { + claimedEvals = append(claimedEvals, gateEval) + } + } else { + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term {{ .ElementType }} + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } } levelWireI++ } @@ -107,17 +123,26 @@ type zeroCheckClaims struct { inputIndices [][]int // [wireInLevel][gateInputJ] → index in input eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points gateEvaluatorPools []*gateEvaluatorPool + singleSourcePoint []{{ .ElementType }} + round int } func (c *zeroCheckClaims) varsNum() int { return c.resources.nbVars } -// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + if c.singleSourcePoint != nil { + return c.roundPolynomialSingleSource() + } + return c.roundPolynomialLegacy() +} + +// roundPolynomialLegacy computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). // The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). // The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). // By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { +func (c *zeroCheckClaims) roundPolynomialLegacy() polynomial.Polynomial { level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) degree := c.resources.circuit.ZeroCheckDegree(level) nbUniqueInputs := len(c.input) @@ -198,6 +223,92 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { return p } +// roundPolynomialSingleSource implements the Gru24 Section 3.2 path for levels +// whose claims all refer to the same evaluation point. It collapses the current +// eq factor by summing the two Boolean branches, so the prover only sends a +// degree-d polynomial instead of degree-(d+1). +func (c *zeroCheckClaims) roundPolynomialSingleSource() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + sumSize := len(c.eqs[0]) / 2 + + var one {{ .ElementType }} + one.SetOne() + sendZero := c.singleSourcePoint[c.round].Equal(&one) + + p := make([]{{ .ElementType }}, degree) + var mu sync.Mutex + computeAll := func(start, end int) { + var step {{ .ElementType }} + + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() + + res := make([]{{ .ElementType }}, degree) + inputEvals := make([]{{ .ElementType }}, (degree+1)*nbUniqueInputs) + weights := make([]{{ .ElementType }}, nbWires) + + accumulateAt := func(offset, outI int) { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(inputEvals[offset+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &weights[w]) + res[outI].Add(&res[outI], summand) + } + } + + for h := start; h < end; h++ { + evalAt1Index := sumSize + h + for w := range nbWires { + weights[w].Set(&c.eqs[w][h]) + } + for k := range c.input { + inputEvals[k].Set(&c.input[k][h]) + step.Sub(&c.input[k][evalAt1Index], &c.input[k][h]) + for d := 1; d <= degree; d++ { + inputEvals[d*nbUniqueInputs+k].Add(&inputEvals[(d-1)*nbUniqueInputs+k], &step) + } + } + + if sendZero { + accumulateAt(0, 0) + for d := 2; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } else { + for d := 1; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } + } + mu.Lock() + for i := range p { + p[i].Add(&p[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + if sumSize < minBlockSize { + computeAll(0, sumSize) + } else { + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return p +} + // roundFold folds all input and eq polynomials at the verifier challenge r. // After this call, j ← j+1 and rⱼ = r. func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) { @@ -210,6 +321,11 @@ func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) { for i := range c.eqs { c.eqs[i].Fold(r) } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } else { wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { @@ -221,7 +337,13 @@ func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) { for _, wg := range wgs { wg.Wait() } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } + c.round++ } // proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). @@ -282,6 +404,31 @@ func (r *resources) eqAcc(e, m polynomial.MultiLin, q []{{ .ElementType }}) { }, 512).Wait() } +// stripSingleSourceEqFactor removes the current Boolean-variable eq factor from +// an optimized single-source eq table while keeping the table duplicated across +// the next current variable. If e encodes a value independent of Xⱼ up to the +// remaining eq suffix, afterwards it encodes the same shape for the next round. +func (r *resources) stripSingleSourceEqFactor(e polynomial.MultiLin) { + if len(e) <= 1 { + return + } + mid := len(e) / 2 + work := func(start, end int) { + var sum {{ .ElementType }} + for i := start; i < end; i++ { + sum.Add(&e[i], &e[mid+i]) + e[i].Set(&sum) + e[mid+i].Set(&sum) + } + } + const minBlockSize = 512 + if mid < minBlockSize { + work(0, mid) + } else { + r.workers.Submit(mid, work, minBlockSize).Wait() + } +} + type resources struct { // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). @@ -376,7 +523,7 @@ func (r *resources) verifySkipLevel(levelI int, proof Proof) error { } func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { - level := r.schedule[levelI] + level := r.schedule[levelI].(constraint.GkrSumcheckLevel) nbClaims := level.NbClaims() var foldingCoeff {{ .ElementType }} if nbClaims >= 2 { @@ -461,6 +608,12 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { eqs: eqs, gateEvaluatorPools: pools, } + if src, ok := level.SingleClaimSource(); ok { + claims.singleSourcePoint = r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex] + for i := range claims.eqs { + r.stripSingleSourceEqFactor(claims.eqs[i]) + } + } return sumcheckProve(claims, &r.transcript) } diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl index 5977893dbf..89d1bcacc0 100644 --- a/internal/generator/backend/template/gkr/gkr.test.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -148,6 +148,27 @@ func TestSumcheckLevel(t *testing.T) { } } +func TestZeroCheckDegreeDispatch(t *testing.T) { + _, sCircuit := cache.Compile(t, gkrtesting.Poseidon2Circuit(4, 2)) + schedule, err := gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range sCircuit.Inputs() { + assignment[i] = make([]{{ .ElementType }}, 2) + {{ .FieldPackageName }}.Vector(assignment[i]).MustSetRandom() + } + assignment.Complete(sCircuit) + + proof, err := Prove(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + assert.NoError(t, Verify(sCircuit, schedule, assignment, proof, newMessageCounter(1, 1))) + + assert.Len(t, proof[3].partialSumPolys[0], 2, "single-source s-box level should use the optimized degree") + assert.Len(t, proof[5].partialSumPolys[0], 3, "folded multi-source level must stay on the legacy degree") + assert.Len(t, proof[11].partialSumPolys[0], 2, "later single-source level should also use the optimized degree") +} + // testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { t.Helper() diff --git a/internal/generator/backend/template/gkr/sumcheck.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.go.tmpl index bd5c2c2c28..ffe0406d79 100644 --- a/internal/generator/backend/template/gkr/sumcheck.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.go.tmpl @@ -55,6 +55,7 @@ type sumcheckClaims interface { type sumcheckLazyClaims interface { varsNum() int // varsNum = n degree(i int) int // degree of the total claim in the i'th variable + roundCombinationCoeff(i int) ({{ .ElementType }}, bool) verifyFinalEval(r []{{ .ElementType }}, purportedValue {{ .ElementType }}, proof []{{ .ElementType }}) error } @@ -92,13 +93,34 @@ func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum { gJ := make(polynomial.Polynomial, degree+1) gJR := claimedSum + var one {{ .ElementType }} + one.SetOne() for j := range claims.varsNum() { if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } - copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + if coeff, ok := claims.roundCombinationCoeff(j); ok { + if coeff.Equal(&one) { + gJ[0].Set(&proof.partialSumPolys[j][0]) + gJ[1].Set(&gJR) + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Mul(&coeff, &proof.partialSumPolys[j][0]) + gJ[0].Sub(&gJR, &gJ[0]) + var oneMinusCoeff {{ .ElementType }} + oneMinusCoeff.SetOne() + oneMinusCoeff.Sub(&oneMinusCoeff, &coeff) + oneMinusCoeff.Inverse(&oneMinusCoeff) + gJ[0].Mul(&gJ[0], &oneMinusCoeff) + } + if degree > 1 { + copy(gJ[2:], proof.partialSumPolys[j][1:]) + } + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + } r[j] = t.getChallenge(proof.partialSumPolys[j]...) gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) diff --git a/internal/generator/backend/template/gkr/sumcheck.test.defs.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.defs.go.tmpl index 54a5c00b85..a00fc58431 100644 --- a/internal/generator/backend/template/gkr/sumcheck.test.defs.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.test.defs.go.tmpl @@ -45,6 +45,10 @@ func (c singleMultilinLazyClaim) degree(int) int { return 1 } +func (c singleMultilinLazyClaim) roundCombinationCoeff(int) ({{ .ElementType }}, bool) { + return {{ .ElementType }}{}, false +} + func (c singleMultilinLazyClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } diff --git a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl index 146a67727d..6d42e92926 100644 --- a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl @@ -3,6 +3,9 @@ import ( "{{ .FieldPackagePath }}/polynomial" "github.com/stretchr/testify/assert" "hash" + {{ if .GenerateTestVectors}} + "github.com/consensys/gnark/internal/small_rational" + {{ end }} {{ if not .GenerateTestVectors}} "{{ .FieldPackagePath }}" "math/bits" @@ -79,6 +82,52 @@ func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { } } +type coeffOneRoundLazyClaim struct { + expectedChallenge {{ .ElementType }} + expectedFinalEval {{ .ElementType }} +} + +func (c coeffOneRoundLazyClaim) verifyFinalEval(r []{{ .ElementType }}, purportedValue {{ .ElementType }}, proof []{{ .ElementType }}) error { + if len(r) != 1 { + return fmt.Errorf("unexpected challenge length %d", len(r)) + } + if !r[0].Equal(&c.expectedChallenge) { + return fmt.Errorf("unexpected challenge") + } + if !purportedValue.Equal(&c.expectedFinalEval) { + return fmt.Errorf("unexpected final eval") + } + return nil +} + +func (c coeffOneRoundLazyClaim) degree(int) int { + return 1 +} + +func (c coeffOneRoundLazyClaim) roundCombinationCoeff(int) ({{ .ElementType }}, bool) { + return *toElement(1), true +} + +func (c coeffOneRoundLazyClaim) varsNum() int { + return 1 +} + +func TestSumcheckVerifyCoeffOneReconstruction(t *testing.T) { + proof := sumcheckProof{ + partialSumPolys: []polynomial.Polynomial{{*toElement(4)}}, + } + lazyClaim := coeffOneRoundLazyClaim{ + expectedChallenge: *toElement(2), + expectedFinalEval: *toElement(14), + } + tr := transcript{h: newMessageCounter(2, 0)} + assert.NoError(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) + + proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) + tr = transcript{h: newMessageCounter(2, 0)} + assert.Error(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) +} + {{ if not .GenerateTestVectors }} {{ template "sumcheckTestDefs" .}} {{ end }} diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 07dd53baca..94f0b75c0b 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -50,6 +50,15 @@ func (e *zeroCheckLazyClaims) degree(int) int { return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) } +func (e *zeroCheckLazyClaims) roundCombinationCoeff(round int) (fr.Element, bool) { + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + src, ok := level.SingleClaimSource() + if !ok { + return fr.Element{}, false + } + return e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex][round], true +} + // verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. // The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying // ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all @@ -64,7 +73,8 @@ func (e *zeroCheckLazyClaims) degree(int) int { // that the full sum matches purportedValue. func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { e.resources.outgoingEvalPoints[e.levelI] = [][]fr.Element{r} - level := e.resources.schedule[e.levelI] + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + _, optimized := level.SingleClaimSource() gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) var claimedEvals polynomial.Polynomial @@ -87,11 +97,17 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr. gateEval.Set(evaluator.evaluate()) } - for _, src := range group.ClaimSources { - eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) - var term fr.Element - term.Mul(&eq, &gateEval) - claimedEvals = append(claimedEvals, term) + if optimized { + for range group.ClaimSources { + claimedEvals = append(claimedEvals, gateEval) + } + } else { + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term fr.Element + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } } levelWireI++ } @@ -114,17 +130,26 @@ type zeroCheckClaims struct { inputIndices [][]int // [wireInLevel][gateInputJ] → index in input eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points gateEvaluatorPools []*gateEvaluatorPool + singleSourcePoint []fr.Element + round int } func (c *zeroCheckClaims) varsNum() int { return c.resources.nbVars } -// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + if c.singleSourcePoint != nil { + return c.roundPolynomialSingleSource() + } + return c.roundPolynomialLegacy() +} + +// roundPolynomialLegacy computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). // The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). // The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). // By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { +func (c *zeroCheckClaims) roundPolynomialLegacy() polynomial.Polynomial { level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) degree := c.resources.circuit.ZeroCheckDegree(level) nbUniqueInputs := len(c.input) @@ -205,6 +230,92 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { return p } +// roundPolynomialSingleSource implements the Gru24 Section 3.2 path for levels +// whose claims all refer to the same evaluation point. It collapses the current +// eq factor by summing the two Boolean branches, so the prover only sends a +// degree-d polynomial instead of degree-(d+1). +func (c *zeroCheckClaims) roundPolynomialSingleSource() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + sumSize := len(c.eqs[0]) / 2 + + var one fr.Element + one.SetOne() + sendZero := c.singleSourcePoint[c.round].Equal(&one) + + p := make([]fr.Element, degree) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() + + res := make([]fr.Element, degree) + inputEvals := make([]fr.Element, (degree+1)*nbUniqueInputs) + weights := make([]fr.Element, nbWires) + + accumulateAt := func(offset, outI int) { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(inputEvals[offset+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &weights[w]) + res[outI].Add(&res[outI], summand) + } + } + + for h := start; h < end; h++ { + evalAt1Index := sumSize + h + for w := range nbWires { + weights[w].Set(&c.eqs[w][h]) + } + for k := range c.input { + inputEvals[k].Set(&c.input[k][h]) + step.Sub(&c.input[k][evalAt1Index], &c.input[k][h]) + for d := 1; d <= degree; d++ { + inputEvals[d*nbUniqueInputs+k].Add(&inputEvals[(d-1)*nbUniqueInputs+k], &step) + } + } + + if sendZero { + accumulateAt(0, 0) + for d := 2; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } else { + for d := 1; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } + } + mu.Lock() + for i := range p { + p[i].Add(&p[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + if sumSize < minBlockSize { + computeAll(0, sumSize) + } else { + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return p +} + // roundFold folds all input and eq polynomials at the verifier challenge r. // After this call, j ← j+1 and rⱼ = r. func (c *zeroCheckClaims) roundFold(r fr.Element) { @@ -217,6 +328,11 @@ func (c *zeroCheckClaims) roundFold(r fr.Element) { for i := range c.eqs { c.eqs[i].Fold(r) } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } else { wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { @@ -228,7 +344,13 @@ func (c *zeroCheckClaims) roundFold(r fr.Element) { for _, wg := range wgs { wg.Wait() } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } + c.round++ } // proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). @@ -289,6 +411,31 @@ func (r *resources) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { }, 512).Wait() } +// stripSingleSourceEqFactor removes the current Boolean-variable eq factor from +// an optimized single-source eq table while keeping the table duplicated across +// the next current variable. If e encodes a value independent of Xⱼ up to the +// remaining eq suffix, afterwards it encodes the same shape for the next round. +func (r *resources) stripSingleSourceEqFactor(e polynomial.MultiLin) { + if len(e) <= 1 { + return + } + mid := len(e) / 2 + work := func(start, end int) { + var sum fr.Element + for i := start; i < end; i++ { + sum.Add(&e[i], &e[mid+i]) + e[i].Set(&sum) + e[mid+i].Set(&sum) + } + } + const minBlockSize = 512 + if mid < minBlockSize { + work(0, mid) + } else { + r.workers.Submit(mid, work, minBlockSize).Wait() + } +} + type resources struct { // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). @@ -383,7 +530,7 @@ func (r *resources) verifySkipLevel(levelI int, proof Proof) error { } func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { - level := r.schedule[levelI] + level := r.schedule[levelI].(constraint.GkrSumcheckLevel) nbClaims := level.NbClaims() var foldingCoeff fr.Element if nbClaims >= 2 { @@ -468,6 +615,12 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { eqs: eqs, gateEvaluatorPools: pools, } + if src, ok := level.SingleClaimSource(); ok { + claims.singleSourcePoint = r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex] + for i := range claims.eqs { + r.stripSingleSourceEqFactor(claims.eqs[i]) + } + } return sumcheckProve(claims, &r.transcript) } diff --git a/internal/gkr/bls12-377/gkr_test.go b/internal/gkr/bls12-377/gkr_test.go index 57f5c1e664..f7d66d7501 100644 --- a/internal/gkr/bls12-377/gkr_test.go +++ b/internal/gkr/bls12-377/gkr_test.go @@ -154,6 +154,27 @@ func TestSumcheckLevel(t *testing.T) { } } +func TestZeroCheckDegreeDispatch(t *testing.T) { + _, sCircuit := cache.Compile(t, gkrtesting.Poseidon2Circuit(4, 2)) + schedule, err := gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range sCircuit.Inputs() { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() + } + assignment.Complete(sCircuit) + + proof, err := Prove(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + assert.NoError(t, Verify(sCircuit, schedule, assignment, proof, newMessageCounter(1, 1))) + + assert.Len(t, proof[3].partialSumPolys[0], 2, "single-source s-box level should use the optimized degree") + assert.Len(t, proof[5].partialSumPolys[0], 3, "folded multi-source level must stay on the legacy degree") + assert.Len(t, proof[11].partialSumPolys[0], 2, "later single-source level should also use the optimized degree") +} + // testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { t.Helper() diff --git a/internal/gkr/bls12-377/sumcheck.go b/internal/gkr/bls12-377/sumcheck.go index c1700e3db3..e3c95825ac 100644 --- a/internal/gkr/bls12-377/sumcheck.go +++ b/internal/gkr/bls12-377/sumcheck.go @@ -62,6 +62,7 @@ type sumcheckClaims interface { type sumcheckLazyClaims interface { varsNum() int // varsNum = n degree(i int) int // degree of the total claim in the i'th variable + roundCombinationCoeff(i int) (fr.Element, bool) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error } @@ -99,13 +100,34 @@ func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum f gJ := make(polynomial.Polynomial, degree+1) gJR := claimedSum + var one fr.Element + one.SetOne() for j := range claims.varsNum() { if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } - copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + if coeff, ok := claims.roundCombinationCoeff(j); ok { + if coeff.Equal(&one) { + gJ[0].Set(&proof.partialSumPolys[j][0]) + gJ[1].Set(&gJR) + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Mul(&coeff, &proof.partialSumPolys[j][0]) + gJ[0].Sub(&gJR, &gJ[0]) + var oneMinusCoeff fr.Element + oneMinusCoeff.SetOne() + oneMinusCoeff.Sub(&oneMinusCoeff, &coeff) + oneMinusCoeff.Inverse(&oneMinusCoeff) + gJ[0].Mul(&gJ[0], &oneMinusCoeff) + } + if degree > 1 { + copy(gJ[2:], proof.partialSumPolys[j][1:]) + } + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + } r[j] = t.getChallenge(proof.partialSumPolys[j]...) gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) diff --git a/internal/gkr/bls12-377/sumcheck_test.go b/internal/gkr/bls12-377/sumcheck_test.go index 395a75bff3..0408171624 100644 --- a/internal/gkr/bls12-377/sumcheck_test.go +++ b/internal/gkr/bls12-377/sumcheck_test.go @@ -7,14 +7,12 @@ package gkr import ( "fmt" - "hash" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" "github.com/stretchr/testify/assert" - - "math/bits" + "hash" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math/bits" "strings" "testing" @@ -88,6 +86,52 @@ func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { } } +type coeffOneRoundLazyClaim struct { + expectedChallenge fr.Element + expectedFinalEval fr.Element +} + +func (c coeffOneRoundLazyClaim) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error { + if len(r) != 1 { + return fmt.Errorf("unexpected challenge length %d", len(r)) + } + if !r[0].Equal(&c.expectedChallenge) { + return fmt.Errorf("unexpected challenge") + } + if !purportedValue.Equal(&c.expectedFinalEval) { + return fmt.Errorf("unexpected final eval") + } + return nil +} + +func (c coeffOneRoundLazyClaim) degree(int) int { + return 1 +} + +func (c coeffOneRoundLazyClaim) roundCombinationCoeff(int) (fr.Element, bool) { + return *toElement(1), true +} + +func (c coeffOneRoundLazyClaim) varsNum() int { + return 1 +} + +func TestSumcheckVerifyCoeffOneReconstruction(t *testing.T) { + proof := sumcheckProof{ + partialSumPolys: []polynomial.Polynomial{{*toElement(4)}}, + } + lazyClaim := coeffOneRoundLazyClaim{ + expectedChallenge: *toElement(2), + expectedFinalEval: *toElement(14), + } + tr := transcript{h: newMessageCounter(2, 0)} + assert.NoError(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) + + proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) + tr = transcript{h: newMessageCounter(2, 0)} + assert.Error(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) +} + type singleMultilinClaim struct { g polynomial.MultiLin } @@ -133,6 +177,10 @@ func (c singleMultilinLazyClaim) degree(int) int { return 1 } +func (c singleMultilinLazyClaim) roundCombinationCoeff(int) (fr.Element, bool) { + return fr.Element{}, false +} + func (c singleMultilinLazyClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index bb240ddab0..284d69954f 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -50,6 +50,15 @@ func (e *zeroCheckLazyClaims) degree(int) int { return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) } +func (e *zeroCheckLazyClaims) roundCombinationCoeff(round int) (fr.Element, bool) { + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + src, ok := level.SingleClaimSource() + if !ok { + return fr.Element{}, false + } + return e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex][round], true +} + // verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. // The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying // ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all @@ -64,7 +73,8 @@ func (e *zeroCheckLazyClaims) degree(int) int { // that the full sum matches purportedValue. func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { e.resources.outgoingEvalPoints[e.levelI] = [][]fr.Element{r} - level := e.resources.schedule[e.levelI] + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + _, optimized := level.SingleClaimSource() gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) var claimedEvals polynomial.Polynomial @@ -87,11 +97,17 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr. gateEval.Set(evaluator.evaluate()) } - for _, src := range group.ClaimSources { - eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) - var term fr.Element - term.Mul(&eq, &gateEval) - claimedEvals = append(claimedEvals, term) + if optimized { + for range group.ClaimSources { + claimedEvals = append(claimedEvals, gateEval) + } + } else { + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term fr.Element + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } } levelWireI++ } @@ -114,17 +130,26 @@ type zeroCheckClaims struct { inputIndices [][]int // [wireInLevel][gateInputJ] → index in input eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points gateEvaluatorPools []*gateEvaluatorPool + singleSourcePoint []fr.Element + round int } func (c *zeroCheckClaims) varsNum() int { return c.resources.nbVars } -// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + if c.singleSourcePoint != nil { + return c.roundPolynomialSingleSource() + } + return c.roundPolynomialLegacy() +} + +// roundPolynomialLegacy computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). // The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). // The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). // By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { +func (c *zeroCheckClaims) roundPolynomialLegacy() polynomial.Polynomial { level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) degree := c.resources.circuit.ZeroCheckDegree(level) nbUniqueInputs := len(c.input) @@ -205,6 +230,92 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { return p } +// roundPolynomialSingleSource implements the Gru24 Section 3.2 path for levels +// whose claims all refer to the same evaluation point. It collapses the current +// eq factor by summing the two Boolean branches, so the prover only sends a +// degree-d polynomial instead of degree-(d+1). +func (c *zeroCheckClaims) roundPolynomialSingleSource() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + sumSize := len(c.eqs[0]) / 2 + + var one fr.Element + one.SetOne() + sendZero := c.singleSourcePoint[c.round].Equal(&one) + + p := make([]fr.Element, degree) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() + + res := make([]fr.Element, degree) + inputEvals := make([]fr.Element, (degree+1)*nbUniqueInputs) + weights := make([]fr.Element, nbWires) + + accumulateAt := func(offset, outI int) { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(inputEvals[offset+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &weights[w]) + res[outI].Add(&res[outI], summand) + } + } + + for h := start; h < end; h++ { + evalAt1Index := sumSize + h + for w := range nbWires { + weights[w].Set(&c.eqs[w][h]) + } + for k := range c.input { + inputEvals[k].Set(&c.input[k][h]) + step.Sub(&c.input[k][evalAt1Index], &c.input[k][h]) + for d := 1; d <= degree; d++ { + inputEvals[d*nbUniqueInputs+k].Add(&inputEvals[(d-1)*nbUniqueInputs+k], &step) + } + } + + if sendZero { + accumulateAt(0, 0) + for d := 2; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } else { + for d := 1; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } + } + mu.Lock() + for i := range p { + p[i].Add(&p[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + if sumSize < minBlockSize { + computeAll(0, sumSize) + } else { + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return p +} + // roundFold folds all input and eq polynomials at the verifier challenge r. // After this call, j ← j+1 and rⱼ = r. func (c *zeroCheckClaims) roundFold(r fr.Element) { @@ -217,6 +328,11 @@ func (c *zeroCheckClaims) roundFold(r fr.Element) { for i := range c.eqs { c.eqs[i].Fold(r) } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } else { wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { @@ -228,7 +344,13 @@ func (c *zeroCheckClaims) roundFold(r fr.Element) { for _, wg := range wgs { wg.Wait() } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } + c.round++ } // proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). @@ -289,6 +411,31 @@ func (r *resources) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { }, 512).Wait() } +// stripSingleSourceEqFactor removes the current Boolean-variable eq factor from +// an optimized single-source eq table while keeping the table duplicated across +// the next current variable. If e encodes a value independent of Xⱼ up to the +// remaining eq suffix, afterwards it encodes the same shape for the next round. +func (r *resources) stripSingleSourceEqFactor(e polynomial.MultiLin) { + if len(e) <= 1 { + return + } + mid := len(e) / 2 + work := func(start, end int) { + var sum fr.Element + for i := start; i < end; i++ { + sum.Add(&e[i], &e[mid+i]) + e[i].Set(&sum) + e[mid+i].Set(&sum) + } + } + const minBlockSize = 512 + if mid < minBlockSize { + work(0, mid) + } else { + r.workers.Submit(mid, work, minBlockSize).Wait() + } +} + type resources struct { // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). @@ -383,7 +530,7 @@ func (r *resources) verifySkipLevel(levelI int, proof Proof) error { } func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { - level := r.schedule[levelI] + level := r.schedule[levelI].(constraint.GkrSumcheckLevel) nbClaims := level.NbClaims() var foldingCoeff fr.Element if nbClaims >= 2 { @@ -468,6 +615,12 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { eqs: eqs, gateEvaluatorPools: pools, } + if src, ok := level.SingleClaimSource(); ok { + claims.singleSourcePoint = r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex] + for i := range claims.eqs { + r.stripSingleSourceEqFactor(claims.eqs[i]) + } + } return sumcheckProve(claims, &r.transcript) } diff --git a/internal/gkr/bls12-381/gkr_test.go b/internal/gkr/bls12-381/gkr_test.go index 5f9fa833f8..c5a6ef8834 100644 --- a/internal/gkr/bls12-381/gkr_test.go +++ b/internal/gkr/bls12-381/gkr_test.go @@ -154,6 +154,27 @@ func TestSumcheckLevel(t *testing.T) { } } +func TestZeroCheckDegreeDispatch(t *testing.T) { + _, sCircuit := cache.Compile(t, gkrtesting.Poseidon2Circuit(4, 2)) + schedule, err := gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range sCircuit.Inputs() { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() + } + assignment.Complete(sCircuit) + + proof, err := Prove(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + assert.NoError(t, Verify(sCircuit, schedule, assignment, proof, newMessageCounter(1, 1))) + + assert.Len(t, proof[3].partialSumPolys[0], 2, "single-source s-box level should use the optimized degree") + assert.Len(t, proof[5].partialSumPolys[0], 3, "folded multi-source level must stay on the legacy degree") + assert.Len(t, proof[11].partialSumPolys[0], 2, "later single-source level should also use the optimized degree") +} + // testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { t.Helper() diff --git a/internal/gkr/bls12-381/sumcheck.go b/internal/gkr/bls12-381/sumcheck.go index 0b6196a393..ba5c45ee55 100644 --- a/internal/gkr/bls12-381/sumcheck.go +++ b/internal/gkr/bls12-381/sumcheck.go @@ -62,6 +62,7 @@ type sumcheckClaims interface { type sumcheckLazyClaims interface { varsNum() int // varsNum = n degree(i int) int // degree of the total claim in the i'th variable + roundCombinationCoeff(i int) (fr.Element, bool) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error } @@ -99,13 +100,34 @@ func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum f gJ := make(polynomial.Polynomial, degree+1) gJR := claimedSum + var one fr.Element + one.SetOne() for j := range claims.varsNum() { if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } - copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + if coeff, ok := claims.roundCombinationCoeff(j); ok { + if coeff.Equal(&one) { + gJ[0].Set(&proof.partialSumPolys[j][0]) + gJ[1].Set(&gJR) + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Mul(&coeff, &proof.partialSumPolys[j][0]) + gJ[0].Sub(&gJR, &gJ[0]) + var oneMinusCoeff fr.Element + oneMinusCoeff.SetOne() + oneMinusCoeff.Sub(&oneMinusCoeff, &coeff) + oneMinusCoeff.Inverse(&oneMinusCoeff) + gJ[0].Mul(&gJ[0], &oneMinusCoeff) + } + if degree > 1 { + copy(gJ[2:], proof.partialSumPolys[j][1:]) + } + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + } r[j] = t.getChallenge(proof.partialSumPolys[j]...) gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) diff --git a/internal/gkr/bls12-381/sumcheck_test.go b/internal/gkr/bls12-381/sumcheck_test.go index 5f8a12fc5a..1fc401e5d8 100644 --- a/internal/gkr/bls12-381/sumcheck_test.go +++ b/internal/gkr/bls12-381/sumcheck_test.go @@ -7,14 +7,12 @@ package gkr import ( "fmt" - "hash" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" "github.com/stretchr/testify/assert" - - "math/bits" + "hash" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "math/bits" "strings" "testing" @@ -88,6 +86,52 @@ func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { } } +type coeffOneRoundLazyClaim struct { + expectedChallenge fr.Element + expectedFinalEval fr.Element +} + +func (c coeffOneRoundLazyClaim) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error { + if len(r) != 1 { + return fmt.Errorf("unexpected challenge length %d", len(r)) + } + if !r[0].Equal(&c.expectedChallenge) { + return fmt.Errorf("unexpected challenge") + } + if !purportedValue.Equal(&c.expectedFinalEval) { + return fmt.Errorf("unexpected final eval") + } + return nil +} + +func (c coeffOneRoundLazyClaim) degree(int) int { + return 1 +} + +func (c coeffOneRoundLazyClaim) roundCombinationCoeff(int) (fr.Element, bool) { + return *toElement(1), true +} + +func (c coeffOneRoundLazyClaim) varsNum() int { + return 1 +} + +func TestSumcheckVerifyCoeffOneReconstruction(t *testing.T) { + proof := sumcheckProof{ + partialSumPolys: []polynomial.Polynomial{{*toElement(4)}}, + } + lazyClaim := coeffOneRoundLazyClaim{ + expectedChallenge: *toElement(2), + expectedFinalEval: *toElement(14), + } + tr := transcript{h: newMessageCounter(2, 0)} + assert.NoError(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) + + proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) + tr = transcript{h: newMessageCounter(2, 0)} + assert.Error(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) +} + type singleMultilinClaim struct { g polynomial.MultiLin } @@ -133,6 +177,10 @@ func (c singleMultilinLazyClaim) degree(int) int { return 1 } +func (c singleMultilinLazyClaim) roundCombinationCoeff(int) (fr.Element, bool) { + return fr.Element{}, false +} + func (c singleMultilinLazyClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 617b104f8e..da71837cf7 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -50,6 +50,15 @@ func (e *zeroCheckLazyClaims) degree(int) int { return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) } +func (e *zeroCheckLazyClaims) roundCombinationCoeff(round int) (fr.Element, bool) { + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + src, ok := level.SingleClaimSource() + if !ok { + return fr.Element{}, false + } + return e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex][round], true +} + // verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. // The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying // ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all @@ -64,7 +73,8 @@ func (e *zeroCheckLazyClaims) degree(int) int { // that the full sum matches purportedValue. func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { e.resources.outgoingEvalPoints[e.levelI] = [][]fr.Element{r} - level := e.resources.schedule[e.levelI] + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + _, optimized := level.SingleClaimSource() gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) var claimedEvals polynomial.Polynomial @@ -87,11 +97,17 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr. gateEval.Set(evaluator.evaluate()) } - for _, src := range group.ClaimSources { - eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) - var term fr.Element - term.Mul(&eq, &gateEval) - claimedEvals = append(claimedEvals, term) + if optimized { + for range group.ClaimSources { + claimedEvals = append(claimedEvals, gateEval) + } + } else { + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term fr.Element + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } } levelWireI++ } @@ -114,17 +130,26 @@ type zeroCheckClaims struct { inputIndices [][]int // [wireInLevel][gateInputJ] → index in input eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points gateEvaluatorPools []*gateEvaluatorPool + singleSourcePoint []fr.Element + round int } func (c *zeroCheckClaims) varsNum() int { return c.resources.nbVars } -// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + if c.singleSourcePoint != nil { + return c.roundPolynomialSingleSource() + } + return c.roundPolynomialLegacy() +} + +// roundPolynomialLegacy computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). // The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). // The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). // By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { +func (c *zeroCheckClaims) roundPolynomialLegacy() polynomial.Polynomial { level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) degree := c.resources.circuit.ZeroCheckDegree(level) nbUniqueInputs := len(c.input) @@ -205,6 +230,92 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { return p } +// roundPolynomialSingleSource implements the Gru24 Section 3.2 path for levels +// whose claims all refer to the same evaluation point. It collapses the current +// eq factor by summing the two Boolean branches, so the prover only sends a +// degree-d polynomial instead of degree-(d+1). +func (c *zeroCheckClaims) roundPolynomialSingleSource() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + sumSize := len(c.eqs[0]) / 2 + + var one fr.Element + one.SetOne() + sendZero := c.singleSourcePoint[c.round].Equal(&one) + + p := make([]fr.Element, degree) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() + + res := make([]fr.Element, degree) + inputEvals := make([]fr.Element, (degree+1)*nbUniqueInputs) + weights := make([]fr.Element, nbWires) + + accumulateAt := func(offset, outI int) { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(inputEvals[offset+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &weights[w]) + res[outI].Add(&res[outI], summand) + } + } + + for h := start; h < end; h++ { + evalAt1Index := sumSize + h + for w := range nbWires { + weights[w].Set(&c.eqs[w][h]) + } + for k := range c.input { + inputEvals[k].Set(&c.input[k][h]) + step.Sub(&c.input[k][evalAt1Index], &c.input[k][h]) + for d := 1; d <= degree; d++ { + inputEvals[d*nbUniqueInputs+k].Add(&inputEvals[(d-1)*nbUniqueInputs+k], &step) + } + } + + if sendZero { + accumulateAt(0, 0) + for d := 2; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } else { + for d := 1; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } + } + mu.Lock() + for i := range p { + p[i].Add(&p[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + if sumSize < minBlockSize { + computeAll(0, sumSize) + } else { + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return p +} + // roundFold folds all input and eq polynomials at the verifier challenge r. // After this call, j ← j+1 and rⱼ = r. func (c *zeroCheckClaims) roundFold(r fr.Element) { @@ -217,6 +328,11 @@ func (c *zeroCheckClaims) roundFold(r fr.Element) { for i := range c.eqs { c.eqs[i].Fold(r) } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } else { wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { @@ -228,7 +344,13 @@ func (c *zeroCheckClaims) roundFold(r fr.Element) { for _, wg := range wgs { wg.Wait() } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } + c.round++ } // proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). @@ -289,6 +411,31 @@ func (r *resources) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { }, 512).Wait() } +// stripSingleSourceEqFactor removes the current Boolean-variable eq factor from +// an optimized single-source eq table while keeping the table duplicated across +// the next current variable. If e encodes a value independent of Xⱼ up to the +// remaining eq suffix, afterwards it encodes the same shape for the next round. +func (r *resources) stripSingleSourceEqFactor(e polynomial.MultiLin) { + if len(e) <= 1 { + return + } + mid := len(e) / 2 + work := func(start, end int) { + var sum fr.Element + for i := start; i < end; i++ { + sum.Add(&e[i], &e[mid+i]) + e[i].Set(&sum) + e[mid+i].Set(&sum) + } + } + const minBlockSize = 512 + if mid < minBlockSize { + work(0, mid) + } else { + r.workers.Submit(mid, work, minBlockSize).Wait() + } +} + type resources struct { // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). @@ -383,7 +530,7 @@ func (r *resources) verifySkipLevel(levelI int, proof Proof) error { } func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { - level := r.schedule[levelI] + level := r.schedule[levelI].(constraint.GkrSumcheckLevel) nbClaims := level.NbClaims() var foldingCoeff fr.Element if nbClaims >= 2 { @@ -468,6 +615,12 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { eqs: eqs, gateEvaluatorPools: pools, } + if src, ok := level.SingleClaimSource(); ok { + claims.singleSourcePoint = r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex] + for i := range claims.eqs { + r.stripSingleSourceEqFactor(claims.eqs[i]) + } + } return sumcheckProve(claims, &r.transcript) } diff --git a/internal/gkr/bn254/gkr_test.go b/internal/gkr/bn254/gkr_test.go index 3d00b93430..d38ea130e8 100644 --- a/internal/gkr/bn254/gkr_test.go +++ b/internal/gkr/bn254/gkr_test.go @@ -154,6 +154,27 @@ func TestSumcheckLevel(t *testing.T) { } } +func TestZeroCheckDegreeDispatch(t *testing.T) { + _, sCircuit := cache.Compile(t, gkrtesting.Poseidon2Circuit(4, 2)) + schedule, err := gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range sCircuit.Inputs() { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() + } + assignment.Complete(sCircuit) + + proof, err := Prove(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + assert.NoError(t, Verify(sCircuit, schedule, assignment, proof, newMessageCounter(1, 1))) + + assert.Len(t, proof[3].partialSumPolys[0], 2, "single-source s-box level should use the optimized degree") + assert.Len(t, proof[5].partialSumPolys[0], 3, "folded multi-source level must stay on the legacy degree") + assert.Len(t, proof[11].partialSumPolys[0], 2, "later single-source level should also use the optimized degree") +} + // testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { t.Helper() diff --git a/internal/gkr/bn254/sumcheck.go b/internal/gkr/bn254/sumcheck.go index be4d3bf5c4..bf501a1a36 100644 --- a/internal/gkr/bn254/sumcheck.go +++ b/internal/gkr/bn254/sumcheck.go @@ -62,6 +62,7 @@ type sumcheckClaims interface { type sumcheckLazyClaims interface { varsNum() int // varsNum = n degree(i int) int // degree of the total claim in the i'th variable + roundCombinationCoeff(i int) (fr.Element, bool) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error } @@ -99,13 +100,34 @@ func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum f gJ := make(polynomial.Polynomial, degree+1) gJR := claimedSum + var one fr.Element + one.SetOne() for j := range claims.varsNum() { if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } - copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + if coeff, ok := claims.roundCombinationCoeff(j); ok { + if coeff.Equal(&one) { + gJ[0].Set(&proof.partialSumPolys[j][0]) + gJ[1].Set(&gJR) + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Mul(&coeff, &proof.partialSumPolys[j][0]) + gJ[0].Sub(&gJR, &gJ[0]) + var oneMinusCoeff fr.Element + oneMinusCoeff.SetOne() + oneMinusCoeff.Sub(&oneMinusCoeff, &coeff) + oneMinusCoeff.Inverse(&oneMinusCoeff) + gJ[0].Mul(&gJ[0], &oneMinusCoeff) + } + if degree > 1 { + copy(gJ[2:], proof.partialSumPolys[j][1:]) + } + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + } r[j] = t.getChallenge(proof.partialSumPolys[j]...) gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) diff --git a/internal/gkr/bn254/sumcheck_test.go b/internal/gkr/bn254/sumcheck_test.go index 15cff2d307..53dc625c83 100644 --- a/internal/gkr/bn254/sumcheck_test.go +++ b/internal/gkr/bn254/sumcheck_test.go @@ -7,14 +7,12 @@ package gkr import ( "fmt" - "hash" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" "github.com/stretchr/testify/assert" - - "math/bits" + "hash" "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "math/bits" "strings" "testing" @@ -88,6 +86,52 @@ func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { } } +type coeffOneRoundLazyClaim struct { + expectedChallenge fr.Element + expectedFinalEval fr.Element +} + +func (c coeffOneRoundLazyClaim) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error { + if len(r) != 1 { + return fmt.Errorf("unexpected challenge length %d", len(r)) + } + if !r[0].Equal(&c.expectedChallenge) { + return fmt.Errorf("unexpected challenge") + } + if !purportedValue.Equal(&c.expectedFinalEval) { + return fmt.Errorf("unexpected final eval") + } + return nil +} + +func (c coeffOneRoundLazyClaim) degree(int) int { + return 1 +} + +func (c coeffOneRoundLazyClaim) roundCombinationCoeff(int) (fr.Element, bool) { + return *toElement(1), true +} + +func (c coeffOneRoundLazyClaim) varsNum() int { + return 1 +} + +func TestSumcheckVerifyCoeffOneReconstruction(t *testing.T) { + proof := sumcheckProof{ + partialSumPolys: []polynomial.Polynomial{{*toElement(4)}}, + } + lazyClaim := coeffOneRoundLazyClaim{ + expectedChallenge: *toElement(2), + expectedFinalEval: *toElement(14), + } + tr := transcript{h: newMessageCounter(2, 0)} + assert.NoError(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) + + proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) + tr = transcript{h: newMessageCounter(2, 0)} + assert.Error(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) +} + type singleMultilinClaim struct { g polynomial.MultiLin } @@ -133,6 +177,10 @@ func (c singleMultilinLazyClaim) degree(int) int { return 1 } +func (c singleMultilinLazyClaim) roundCombinationCoeff(int) (fr.Element, bool) { + return fr.Element{}, false +} + func (c singleMultilinLazyClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 8d07fe0310..9a694fb456 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -50,6 +50,15 @@ func (e *zeroCheckLazyClaims) degree(int) int { return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) } +func (e *zeroCheckLazyClaims) roundCombinationCoeff(round int) (fr.Element, bool) { + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + src, ok := level.SingleClaimSource() + if !ok { + return fr.Element{}, false + } + return e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex][round], true +} + // verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. // The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying // ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all @@ -64,7 +73,8 @@ func (e *zeroCheckLazyClaims) degree(int) int { // that the full sum matches purportedValue. func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr.Element, uniqueInputEvaluations []fr.Element) error { e.resources.outgoingEvalPoints[e.levelI] = [][]fr.Element{r} - level := e.resources.schedule[e.levelI] + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + _, optimized := level.SingleClaimSource() gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) var claimedEvals polynomial.Polynomial @@ -87,11 +97,17 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []fr.Element, purportedValue fr. gateEval.Set(evaluator.evaluate()) } - for _, src := range group.ClaimSources { - eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) - var term fr.Element - term.Mul(&eq, &gateEval) - claimedEvals = append(claimedEvals, term) + if optimized { + for range group.ClaimSources { + claimedEvals = append(claimedEvals, gateEval) + } + } else { + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term fr.Element + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } } levelWireI++ } @@ -114,17 +130,26 @@ type zeroCheckClaims struct { inputIndices [][]int // [wireInLevel][gateInputJ] → index in input eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points gateEvaluatorPools []*gateEvaluatorPool + singleSourcePoint []fr.Element + round int } func (c *zeroCheckClaims) varsNum() int { return c.resources.nbVars } -// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + if c.singleSourcePoint != nil { + return c.roundPolynomialSingleSource() + } + return c.roundPolynomialLegacy() +} + +// roundPolynomialLegacy computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). // The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). // The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). // By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { +func (c *zeroCheckClaims) roundPolynomialLegacy() polynomial.Polynomial { level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) degree := c.resources.circuit.ZeroCheckDegree(level) nbUniqueInputs := len(c.input) @@ -205,6 +230,92 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { return p } +// roundPolynomialSingleSource implements the Gru24 Section 3.2 path for levels +// whose claims all refer to the same evaluation point. It collapses the current +// eq factor by summing the two Boolean branches, so the prover only sends a +// degree-d polynomial instead of degree-(d+1). +func (c *zeroCheckClaims) roundPolynomialSingleSource() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + sumSize := len(c.eqs[0]) / 2 + + var one fr.Element + one.SetOne() + sendZero := c.singleSourcePoint[c.round].Equal(&one) + + p := make([]fr.Element, degree) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() + + res := make([]fr.Element, degree) + inputEvals := make([]fr.Element, (degree+1)*nbUniqueInputs) + weights := make([]fr.Element, nbWires) + + accumulateAt := func(offset, outI int) { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(inputEvals[offset+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &weights[w]) + res[outI].Add(&res[outI], summand) + } + } + + for h := start; h < end; h++ { + evalAt1Index := sumSize + h + for w := range nbWires { + weights[w].Set(&c.eqs[w][h]) + } + for k := range c.input { + inputEvals[k].Set(&c.input[k][h]) + step.Sub(&c.input[k][evalAt1Index], &c.input[k][h]) + for d := 1; d <= degree; d++ { + inputEvals[d*nbUniqueInputs+k].Add(&inputEvals[(d-1)*nbUniqueInputs+k], &step) + } + } + + if sendZero { + accumulateAt(0, 0) + for d := 2; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } else { + for d := 1; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } + } + mu.Lock() + for i := range p { + p[i].Add(&p[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + if sumSize < minBlockSize { + computeAll(0, sumSize) + } else { + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return p +} + // roundFold folds all input and eq polynomials at the verifier challenge r. // After this call, j ← j+1 and rⱼ = r. func (c *zeroCheckClaims) roundFold(r fr.Element) { @@ -217,6 +328,11 @@ func (c *zeroCheckClaims) roundFold(r fr.Element) { for i := range c.eqs { c.eqs[i].Fold(r) } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } else { wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { @@ -228,7 +344,13 @@ func (c *zeroCheckClaims) roundFold(r fr.Element) { for _, wg := range wgs { wg.Wait() } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } + c.round++ } // proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). @@ -289,6 +411,31 @@ func (r *resources) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { }, 512).Wait() } +// stripSingleSourceEqFactor removes the current Boolean-variable eq factor from +// an optimized single-source eq table while keeping the table duplicated across +// the next current variable. If e encodes a value independent of Xⱼ up to the +// remaining eq suffix, afterwards it encodes the same shape for the next round. +func (r *resources) stripSingleSourceEqFactor(e polynomial.MultiLin) { + if len(e) <= 1 { + return + } + mid := len(e) / 2 + work := func(start, end int) { + var sum fr.Element + for i := start; i < end; i++ { + sum.Add(&e[i], &e[mid+i]) + e[i].Set(&sum) + e[mid+i].Set(&sum) + } + } + const minBlockSize = 512 + if mid < minBlockSize { + work(0, mid) + } else { + r.workers.Submit(mid, work, minBlockSize).Wait() + } +} + type resources struct { // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). @@ -383,7 +530,7 @@ func (r *resources) verifySkipLevel(levelI int, proof Proof) error { } func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { - level := r.schedule[levelI] + level := r.schedule[levelI].(constraint.GkrSumcheckLevel) nbClaims := level.NbClaims() var foldingCoeff fr.Element if nbClaims >= 2 { @@ -468,6 +615,12 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { eqs: eqs, gateEvaluatorPools: pools, } + if src, ok := level.SingleClaimSource(); ok { + claims.singleSourcePoint = r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex] + for i := range claims.eqs { + r.stripSingleSourceEqFactor(claims.eqs[i]) + } + } return sumcheckProve(claims, &r.transcript) } diff --git a/internal/gkr/bw6-761/gkr_test.go b/internal/gkr/bw6-761/gkr_test.go index 630dfe6fd7..4da5cd2b63 100644 --- a/internal/gkr/bw6-761/gkr_test.go +++ b/internal/gkr/bw6-761/gkr_test.go @@ -154,6 +154,27 @@ func TestSumcheckLevel(t *testing.T) { } } +func TestZeroCheckDegreeDispatch(t *testing.T) { + _, sCircuit := cache.Compile(t, gkrtesting.Poseidon2Circuit(4, 2)) + schedule, err := gkrcore.DefaultProvingSchedule(sCircuit) + assert.NoError(t, err) + + assignment := make(WireAssignment, len(sCircuit)) + for _, i := range sCircuit.Inputs() { + assignment[i] = make([]fr.Element, 2) + fr.Vector(assignment[i]).MustSetRandom() + } + assignment.Complete(sCircuit) + + proof, err := Prove(sCircuit, schedule, assignment, newMessageCounter(1, 1)) + assert.NoError(t, err) + assert.NoError(t, Verify(sCircuit, schedule, assignment, proof, newMessageCounter(1, 1))) + + assert.Len(t, proof[3].partialSumPolys[0], 2, "single-source s-box level should use the optimized degree") + assert.Len(t, proof[5].partialSumPolys[0], 3, "folded multi-source level must stay on the legacy degree") + assert.Len(t, proof[11].partialSumPolys[0], 2, "later single-source level should also use the optimized degree") +} + // testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level. func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) { t.Helper() diff --git a/internal/gkr/bw6-761/sumcheck.go b/internal/gkr/bw6-761/sumcheck.go index 87e6d425ae..2a0482000d 100644 --- a/internal/gkr/bw6-761/sumcheck.go +++ b/internal/gkr/bw6-761/sumcheck.go @@ -62,6 +62,7 @@ type sumcheckClaims interface { type sumcheckLazyClaims interface { varsNum() int // varsNum = n degree(i int) int // degree of the total claim in the i'th variable + roundCombinationCoeff(i int) (fr.Element, bool) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error } @@ -99,13 +100,34 @@ func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum f gJ := make(polynomial.Polynomial, degree+1) gJR := claimedSum + var one fr.Element + one.SetOne() for j := range claims.varsNum() { if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } - copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + if coeff, ok := claims.roundCombinationCoeff(j); ok { + if coeff.Equal(&one) { + gJ[0].Set(&proof.partialSumPolys[j][0]) + gJ[1].Set(&gJR) + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Mul(&coeff, &proof.partialSumPolys[j][0]) + gJ[0].Sub(&gJR, &gJ[0]) + var oneMinusCoeff fr.Element + oneMinusCoeff.SetOne() + oneMinusCoeff.Sub(&oneMinusCoeff, &coeff) + oneMinusCoeff.Inverse(&oneMinusCoeff) + gJ[0].Mul(&gJ[0], &oneMinusCoeff) + } + if degree > 1 { + copy(gJ[2:], proof.partialSumPolys[j][1:]) + } + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + } r[j] = t.getChallenge(proof.partialSumPolys[j]...) gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) diff --git a/internal/gkr/bw6-761/sumcheck_test.go b/internal/gkr/bw6-761/sumcheck_test.go index d5fd5d305a..2bb24ba771 100644 --- a/internal/gkr/bw6-761/sumcheck_test.go +++ b/internal/gkr/bw6-761/sumcheck_test.go @@ -7,14 +7,12 @@ package gkr import ( "fmt" - "hash" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" "github.com/stretchr/testify/assert" - - "math/bits" + "hash" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "math/bits" "strings" "testing" @@ -88,6 +86,52 @@ func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { } } +type coeffOneRoundLazyClaim struct { + expectedChallenge fr.Element + expectedFinalEval fr.Element +} + +func (c coeffOneRoundLazyClaim) verifyFinalEval(r []fr.Element, purportedValue fr.Element, proof []fr.Element) error { + if len(r) != 1 { + return fmt.Errorf("unexpected challenge length %d", len(r)) + } + if !r[0].Equal(&c.expectedChallenge) { + return fmt.Errorf("unexpected challenge") + } + if !purportedValue.Equal(&c.expectedFinalEval) { + return fmt.Errorf("unexpected final eval") + } + return nil +} + +func (c coeffOneRoundLazyClaim) degree(int) int { + return 1 +} + +func (c coeffOneRoundLazyClaim) roundCombinationCoeff(int) (fr.Element, bool) { + return *toElement(1), true +} + +func (c coeffOneRoundLazyClaim) varsNum() int { + return 1 +} + +func TestSumcheckVerifyCoeffOneReconstruction(t *testing.T) { + proof := sumcheckProof{ + partialSumPolys: []polynomial.Polynomial{{*toElement(4)}}, + } + lazyClaim := coeffOneRoundLazyClaim{ + expectedChallenge: *toElement(2), + expectedFinalEval: *toElement(14), + } + tr := transcript{h: newMessageCounter(2, 0)} + assert.NoError(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) + + proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) + tr = transcript{h: newMessageCounter(2, 0)} + assert.Error(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) +} + type singleMultilinClaim struct { g polynomial.MultiLin } @@ -133,6 +177,10 @@ func (c singleMultilinLazyClaim) degree(int) int { return 1 } +func (c singleMultilinLazyClaim) roundCombinationCoeff(int) (fr.Element, bool) { + return fr.Element{}, false +} + func (c singleMultilinLazyClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go index 9009ca71e5..a91c3bd619 100644 --- a/internal/gkr/gkr.go +++ b/internal/gkr/gkr.go @@ -72,6 +72,15 @@ func (e *zeroCheckLazyClaims) degree(int) int { return e.r.circuit.ZeroCheckDegree(e.r.schedule[e.levelI].(constraint.GkrSumcheckLevel)) } +func (e *zeroCheckLazyClaims) roundCombinationCoeff(round int) (frontend.Variable, bool) { + level := e.r.schedule[e.levelI].(constraint.GkrSumcheckLevel) + src, ok := level.SingleClaimSource() + if !ok { + return nil, false + } + return e.r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex][round], true +} + // verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. // The sumcheck protocol has already reduced the per-wire claims to verifying // ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all @@ -84,7 +93,8 @@ func (e *zeroCheckLazyClaims) degree(int) int { // uniqueInputEvaluations; those claims are verified by lower levels' sumchecks. func (e *zeroCheckLazyClaims) verifyFinalEval(api frontend.API, r []frontend.Variable, purportedValue frontend.Variable, uniqueInputEvaluations []frontend.Variable) error { e.r.outgoingEvalPoints[e.levelI] = [][]frontend.Variable{r} - level := e.r.schedule[e.levelI] + level := e.r.schedule[e.levelI].(constraint.GkrSumcheckLevel) + _, optimized := level.SingleClaimSource() perWireInputEvals := gkrcore.ReduplicateInputs(level, e.r.circuit, uniqueInputEvaluations) var terms []frontend.Variable @@ -101,10 +111,16 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(api frontend.API, r []frontend.Var gateEval = wire.Gate.Evaluate(FrontendAPIWrapper{api}, perWireInputEvals[levelWireI]...) } - for _, src := range group.ClaimSources { - eq := polynomial.EvalEq(api, e.r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) - term := api.Mul(eq, gateEval) - terms = append(terms, term) + if optimized { + for range group.ClaimSources { + terms = append(terms, gateEval) + } + } else { + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(api, e.r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + term := api.Mul(eq, gateEval) + terms = append(terms, term) + } } levelWireI++ } diff --git a/internal/gkr/gkrcore/schedule.go b/internal/gkr/gkrcore/schedule.go index 7fecf28aab..a99175e88e 100644 --- a/internal/gkr/gkrcore/schedule.go +++ b/internal/gkr/gkrcore/schedule.go @@ -57,6 +57,9 @@ func (c Circuit[G]) ZeroCheckDegree(level constraint.GkrSumcheckLevel) int { maxDeg = max(maxDeg, curr) } } + if _, ok := level.SingleClaimSource(); ok { + return maxDeg + } return maxDeg + 1 } diff --git a/internal/gkr/gkrcore/schedule_test.go b/internal/gkr/gkrcore/schedule_test.go index 55241be576..91e585079c 100644 --- a/internal/gkr/gkrcore/schedule_test.go +++ b/internal/gkr/gkrcore/schedule_test.go @@ -106,3 +106,20 @@ func TestDefaultProvingSchedulePoseidon2(t *testing.T) { constraint.GkrSkipLevel{Wires: []int{24}, ClaimSources: []constraint.GkrClaimSource{{Level: 17}}}, }, schedule) } + +func TestZeroCheckDegreeSingleSourceBoundary(t *testing.T) { + _, c := scheduleTestCache.Compile(t, gkrtesting.Poseidon2Circuit(4, 2)) + schedule, err := gkrcore.DefaultProvingSchedule(c) + require.NoError(t, err) + + singleSource := schedule[3].(constraint.GkrSumcheckLevel) + src, ok := singleSource.SingleClaimSource() + require.True(t, ok) + require.Equal(t, constraint.GkrClaimSource{Level: 4}, src) + require.Equal(t, 2, c.ZeroCheckDegree(singleSource)) + + multiSource := schedule[5].(constraint.GkrSumcheckLevel) + _, ok = multiSource.SingleClaimSource() + require.False(t, ok) + require.Equal(t, 3, c.ZeroCheckDegree(multiSource)) +} diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 1e4882d272..36e8de3764 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -50,6 +50,15 @@ func (e *zeroCheckLazyClaims) degree(int) int { return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)) } +func (e *zeroCheckLazyClaims) roundCombinationCoeff(round int) (small_rational.SmallRational, bool) { + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + src, ok := level.SingleClaimSource() + if !ok { + return small_rational.SmallRational{}, false + } + return e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex][round], true +} + // verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r. // The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying // ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all @@ -64,7 +73,8 @@ func (e *zeroCheckLazyClaims) degree(int) int { // that the full sum matches purportedValue. func (e *zeroCheckLazyClaims) verifyFinalEval(r []small_rational.SmallRational, purportedValue small_rational.SmallRational, uniqueInputEvaluations []small_rational.SmallRational) error { e.resources.outgoingEvalPoints[e.levelI] = [][]small_rational.SmallRational{r} - level := e.resources.schedule[e.levelI] + level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel) + _, optimized := level.SingleClaimSource() gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations) var claimedEvals polynomial.Polynomial @@ -87,11 +97,17 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []small_rational.SmallRational, gateEval.Set(evaluator.evaluate()) } - for _, src := range group.ClaimSources { - eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) - var term small_rational.SmallRational - term.Mul(&eq, &gateEval) - claimedEvals = append(claimedEvals, term) + if optimized { + for range group.ClaimSources { + claimedEvals = append(claimedEvals, gateEval) + } + } else { + for _, src := range group.ClaimSources { + eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r) + var term small_rational.SmallRational + term.Mul(&eq, &gateEval) + claimedEvals = append(claimedEvals, term) + } } levelWireI++ } @@ -114,17 +130,26 @@ type zeroCheckClaims struct { inputIndices [][]int // [wireInLevel][gateInputJ] → index in input eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points gateEvaluatorPools []*gateEvaluatorPool + singleSourcePoint []small_rational.SmallRational + round int } func (c *zeroCheckClaims) varsNum() int { return c.resources.nbVars } -// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). +func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { + if c.singleSourcePoint != nil { + return c.roundPolynomialSingleSource() + } + return c.roundPolynomialLegacy() +} + +// roundPolynomialLegacy computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)). // The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). // The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). // By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { +func (c *zeroCheckClaims) roundPolynomialLegacy() polynomial.Polynomial { level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) degree := c.resources.circuit.ZeroCheckDegree(level) nbUniqueInputs := len(c.input) @@ -205,6 +230,92 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial { return p } +// roundPolynomialSingleSource implements the Gru24 Section 3.2 path for levels +// whose claims all refer to the same evaluation point. It collapses the current +// eq factor by summing the two Boolean branches, so the prover only sends a +// degree-d polynomial instead of degree-(d+1). +func (c *zeroCheckClaims) roundPolynomialSingleSource() polynomial.Polynomial { + level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel) + degree := c.resources.circuit.ZeroCheckDegree(level) + nbUniqueInputs := len(c.input) + nbWires := len(c.eqs) + sumSize := len(c.eqs[0]) / 2 + + var one small_rational.SmallRational + one.SetOne() + sendZero := c.singleSourcePoint[c.round].Equal(&one) + + p := make([]small_rational.SmallRational, degree) + var mu sync.Mutex + computeAll := func(start, end int) { + var step small_rational.SmallRational + + evaluators := make([]*gateEvaluator, nbWires) + for w := range nbWires { + evaluators[w] = c.gateEvaluatorPools[w].get() + } + defer func() { + for w := range nbWires { + c.gateEvaluatorPools[w].put(evaluators[w]) + } + }() + + res := make([]small_rational.SmallRational, degree) + inputEvals := make([]small_rational.SmallRational, (degree+1)*nbUniqueInputs) + weights := make([]small_rational.SmallRational, nbWires) + + accumulateAt := func(offset, outI int) { + for w := range nbWires { + for _, inputI := range c.inputIndices[w] { + evaluators[w].pushInput(inputEvals[offset+inputI]) + } + summand := evaluators[w].evaluate() + summand.Mul(summand, &weights[w]) + res[outI].Add(&res[outI], summand) + } + } + + for h := start; h < end; h++ { + evalAt1Index := sumSize + h + for w := range nbWires { + weights[w].Set(&c.eqs[w][h]) + } + for k := range c.input { + inputEvals[k].Set(&c.input[k][h]) + step.Sub(&c.input[k][evalAt1Index], &c.input[k][h]) + for d := 1; d <= degree; d++ { + inputEvals[d*nbUniqueInputs+k].Add(&inputEvals[(d-1)*nbUniqueInputs+k], &step) + } + } + + if sendZero { + accumulateAt(0, 0) + for d := 2; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } else { + for d := 1; d <= degree; d++ { + accumulateAt(d*nbUniqueInputs, d-1) + } + } + } + mu.Lock() + for i := range p { + p[i].Add(&p[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + if sumSize < minBlockSize { + computeAll(0, sumSize) + } else { + c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return p +} + // roundFold folds all input and eq polynomials at the verifier challenge r. // After this call, j ← j+1 and rⱼ = r. func (c *zeroCheckClaims) roundFold(r small_rational.SmallRational) { @@ -217,6 +328,11 @@ func (c *zeroCheckClaims) roundFold(r small_rational.SmallRational) { for i := range c.eqs { c.eqs[i].Fold(r) } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } else { wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs)) for i := range c.input { @@ -228,7 +344,13 @@ func (c *zeroCheckClaims) roundFold(r small_rational.SmallRational) { for _, wg := range wgs { wg.Wait() } + if c.singleSourcePoint != nil { + for i := range c.eqs { + c.resources.stripSingleSourceEqFactor(c.eqs[i]) + } + } } + c.round++ } // proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ). @@ -289,6 +411,31 @@ func (r *resources) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRati }, 512).Wait() } +// stripSingleSourceEqFactor removes the current Boolean-variable eq factor from +// an optimized single-source eq table while keeping the table duplicated across +// the next current variable. If e encodes a value independent of Xⱼ up to the +// remaining eq suffix, afterwards it encodes the same shape for the next round. +func (r *resources) stripSingleSourceEqFactor(e polynomial.MultiLin) { + if len(e) <= 1 { + return + } + mid := len(e) / 2 + work := func(start, end int) { + var sum small_rational.SmallRational + for i := start; i < end; i++ { + sum.Add(&e[i], &e[mid+i]) + e[i].Set(&sum) + e[mid+i].Set(&sum) + } + } + const minBlockSize = 512 + if mid < minBlockSize { + work(0, mid) + } else { + r.workers.Submit(mid, work, minBlockSize).Wait() + } +} + type resources struct { // outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i. // outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho). @@ -383,7 +530,7 @@ func (r *resources) verifySkipLevel(levelI int, proof Proof) error { } func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { - level := r.schedule[levelI] + level := r.schedule[levelI].(constraint.GkrSumcheckLevel) nbClaims := level.NbClaims() var foldingCoeff small_rational.SmallRational if nbClaims >= 2 { @@ -468,6 +615,12 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof { eqs: eqs, gateEvaluatorPools: pools, } + if src, ok := level.SingleClaimSource(); ok { + claims.singleSourcePoint = r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex] + for i := range claims.eqs { + r.stripSingleSourceEqFactor(claims.eqs[i]) + } + } return sumcheckProve(claims, &r.transcript) } diff --git a/internal/gkr/small_rational/sumcheck.go b/internal/gkr/small_rational/sumcheck.go index 375c402639..2279c138c9 100644 --- a/internal/gkr/small_rational/sumcheck.go +++ b/internal/gkr/small_rational/sumcheck.go @@ -62,6 +62,7 @@ type sumcheckClaims interface { type sumcheckLazyClaims interface { varsNum() int // varsNum = n degree(i int) int // degree of the total claim in the i'th variable + roundCombinationCoeff(i int) (small_rational.SmallRational, bool) verifyFinalEval(r []small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error } @@ -99,13 +100,34 @@ func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum s gJ := make(polynomial.Polynomial, degree+1) gJR := claimedSum + var one small_rational.SmallRational + one.SetOne() for j := range claims.varsNum() { if len(proof.partialSumPolys[j]) != degree { return errors.New("malformed proof") } - copy(gJ[1:], proof.partialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + if coeff, ok := claims.roundCombinationCoeff(j); ok { + if coeff.Equal(&one) { + gJ[0].Set(&proof.partialSumPolys[j][0]) + gJ[1].Set(&gJR) + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Mul(&coeff, &proof.partialSumPolys[j][0]) + gJ[0].Sub(&gJR, &gJ[0]) + var oneMinusCoeff small_rational.SmallRational + oneMinusCoeff.SetOne() + oneMinusCoeff.Sub(&oneMinusCoeff, &coeff) + oneMinusCoeff.Inverse(&oneMinusCoeff) + gJ[0].Mul(&gJ[0], &oneMinusCoeff) + } + if degree > 1 { + copy(gJ[2:], proof.partialSumPolys[j][1:]) + } + } else { + copy(gJ[1:], proof.partialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0]) + } r[j] = t.getChallenge(proof.partialSumPolys[j]...) gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)]) diff --git a/internal/gkr/small_rational/sumcheck_test.go b/internal/gkr/small_rational/sumcheck_test.go index b364291f7b..227db5d319 100644 --- a/internal/gkr/small_rational/sumcheck_test.go +++ b/internal/gkr/small_rational/sumcheck_test.go @@ -7,10 +7,10 @@ package gkr import ( "fmt" - "hash" - + "github.com/consensys/gnark/internal/small_rational" "github.com/consensys/gnark/internal/small_rational/polynomial" "github.com/stretchr/testify/assert" + "hash" "strings" "testing" @@ -83,3 +83,49 @@ func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { } } } + +type coeffOneRoundLazyClaim struct { + expectedChallenge small_rational.SmallRational + expectedFinalEval small_rational.SmallRational +} + +func (c coeffOneRoundLazyClaim) verifyFinalEval(r []small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error { + if len(r) != 1 { + return fmt.Errorf("unexpected challenge length %d", len(r)) + } + if !r[0].Equal(&c.expectedChallenge) { + return fmt.Errorf("unexpected challenge") + } + if !purportedValue.Equal(&c.expectedFinalEval) { + return fmt.Errorf("unexpected final eval") + } + return nil +} + +func (c coeffOneRoundLazyClaim) degree(int) int { + return 1 +} + +func (c coeffOneRoundLazyClaim) roundCombinationCoeff(int) (small_rational.SmallRational, bool) { + return *toElement(1), true +} + +func (c coeffOneRoundLazyClaim) varsNum() int { + return 1 +} + +func TestSumcheckVerifyCoeffOneReconstruction(t *testing.T) { + proof := sumcheckProof{ + partialSumPolys: []polynomial.Polynomial{{*toElement(4)}}, + } + lazyClaim := coeffOneRoundLazyClaim{ + expectedChallenge: *toElement(2), + expectedFinalEval: *toElement(14), + } + tr := transcript{h: newMessageCounter(2, 0)} + assert.NoError(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) + + proof.partialSumPolys[0][0].Add(&proof.partialSumPolys[0][0], toElement(1)) + tr = transcript{h: newMessageCounter(2, 0)} + assert.Error(t, sumcheckVerify(lazyClaim, proof, *toElement(9), 1, &tr)) +} diff --git a/internal/gkr/small_rational/sumcheck_test_vector_gen.go b/internal/gkr/small_rational/sumcheck_test_vector_gen.go index 89ad304949..863fb54a04 100644 --- a/internal/gkr/small_rational/sumcheck_test_vector_gen.go +++ b/internal/gkr/small_rational/sumcheck_test_vector_gen.go @@ -8,15 +8,14 @@ package gkr import ( "encoding/json" "fmt" + "github.com/consensys/gnark/internal/gkr/gkrtesting" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" "hash" "math/bits" "os" "path/filepath" "runtime/pprof" - - "github.com/consensys/gnark/internal/gkr/gkrtesting" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" ) func runMultilin(testCaseInfo *sumcheckTestCaseInfo) error { @@ -196,6 +195,10 @@ func (c singleMultilinLazyClaim) degree(int) int { return 1 } +func (c singleMultilinLazyClaim) roundCombinationCoeff(int) (small_rational.SmallRational, bool) { + return small_rational.SmallRational{}, false +} + func (c singleMultilinLazyClaim) varsNum() int { return bits.TrailingZeros(uint(len(c.g))) } diff --git a/internal/gkr/sumcheck.go b/internal/gkr/sumcheck.go index a3deee14f7..5da707c864 100644 --- a/internal/gkr/sumcheck.go +++ b/internal/gkr/sumcheck.go @@ -14,6 +14,7 @@ import ( type sumcheckLazyClaims interface { varsNum() int degree(i int) int + roundCombinationCoeff(i int) (frontend.Variable, bool) verifyFinalEval(api frontend.API, r []frontend.Variable, purportedValue frontend.Variable, proof []frontend.Variable) error } @@ -59,8 +60,22 @@ func verifySumcheck(api frontend.API, claims sumcheckLazyClaims, proof sumcheckP if len(partialSumPoly) != degree { return errors.New("malformed proof") } - copy(gJ[1:], partialSumPoly) - gJ[0] = api.Sub(gJR, partialSumPoly[0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + if coeff, ok := claims.roundCombinationCoeff(j); ok { + isOne := api.IsZero(api.Sub(coeff, 1)) + gJ[1] = api.Select(isOne, gJR, partialSumPoly[0]) + + safeDen := api.Select(isOne, 1, api.Sub(1, coeff)) + safeNum := api.Select(isOne, 0, api.Sub(gJR, api.Mul(coeff, partialSumPoly[0]))) + g0Recovered := api.Div(safeNum, safeDen) + gJ[0] = api.Select(isOne, partialSumPoly[0], g0Recovered) + + for i := 2; i < len(gJ); i++ { + gJ[i] = partialSumPoly[i-1] + } + } else { + copy(gJ[1:], partialSumPoly) + gJ[0] = api.Sub(gJR, partialSumPoly[0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + } r[j] = t.getChallenge(proof.PartialSumPolys[j]...) diff --git a/internal/gkr/sumcheck_coeff1_test.go b/internal/gkr/sumcheck_coeff1_test.go new file mode 100644 index 0000000000..f1cc76db8d --- /dev/null +++ b/internal/gkr/sumcheck_coeff1_test.go @@ -0,0 +1,51 @@ +package gkr + +import ( + "fmt" + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/polynomial" + "github.com/consensys/gnark/test" +) + +type coeffOneFrontendLazyClaim struct{} + +func (coeffOneFrontendLazyClaim) degree(int) int { return 1 } + +func (coeffOneFrontendLazyClaim) roundCombinationCoeff(int) (frontend.Variable, bool) { + return 1, true +} + +func (coeffOneFrontendLazyClaim) varsNum() int { return 1 } + +func (coeffOneFrontendLazyClaim) verifyFinalEval(api frontend.API, r []frontend.Variable, purportedValue frontend.Variable, proof []frontend.Variable) error { + if len(r) != 1 { + return fmt.Errorf("unexpected challenge length %d", len(r)) + } + api.AssertIsEqual(r[0], 2) + api.AssertIsEqual(purportedValue, 14) + return nil +} + +type coeffOneFrontendVerifyCircuit struct { + G0 frontend.Variable + ClaimedSum frontend.Variable +} + +func (c *coeffOneFrontendVerifyCircuit) Define(api frontend.API) error { + proof := sumcheckProof{ + PartialSumPolys: []polynomial.Polynomial{{c.G0}}, + } + tr := transcript{h: newMessageCounter(api, 2, 0)} + return verifySumcheck(api, coeffOneFrontendLazyClaim{}, proof, c.ClaimedSum, 1, &tr) +} + +func TestFrontendSumcheckVerifyCoeffOneReconstruction(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit( + &coeffOneFrontendVerifyCircuit{}, + test.WithValidAssignment(&coeffOneFrontendVerifyCircuit{G0: 4, ClaimedSum: 9}), + test.WithInvalidAssignment(&coeffOneFrontendVerifyCircuit{G0: 5, ClaimedSum: 9}), + ) +} diff --git a/internal/gkr/test_vectors/single_input_two_outs_two_instances.json b/internal/gkr/test_vectors/single_input_two_outs_two_instances.json index 9f39dc5cdd..fbb3b6603e 100644 --- a/internal/gkr/test_vectors/single_input_two_outs_two_instances.json +++ b/internal/gkr/test_vectors/single_input_two_outs_two_instances.json @@ -38,9 +38,8 @@ ], "partialSumPolys": [ [ - -4, - -36, - -112 + 4, + 9 ] ] }, diff --git a/internal/gkr/test_vectors/single_mimc_gate_four_instances.json b/internal/gkr/test_vectors/single_mimc_gate_four_instances.json index 3907e92c92..780fce6405 100644 --- a/internal/gkr/test_vectors/single_mimc_gate_four_instances.json +++ b/internal/gkr/test_vectors/single_mimc_gate_four_instances.json @@ -46,24 +46,22 @@ ], "partialSumPolys": [ [ - -32640, - -2239484, - -29360128, - -200000010, - -931628672, - -3373267120, - -10200858624, - -26939400158 + 32640, + 559871, + 4194304, + 20000001, + 71663744, + 210829195, + 536887296 ], [ - -81920, - -41943040, - -1254113280, - -13421772800, - -83200000000, - -366917713920, - -1281828208640, - -3779571220480 + 16384, + 2097152, + 35831808, + 268435456, + 1280000000, + 4586471424, + 13492928512 ] ] } diff --git a/internal/gkr/test_vectors/single_mimc_gate_two_instances.json b/internal/gkr/test_vectors/single_mimc_gate_two_instances.json index 80a585dced..4352c97624 100644 --- a/internal/gkr/test_vectors/single_mimc_gate_two_instances.json +++ b/internal/gkr/test_vectors/single_mimc_gate_two_instances.json @@ -40,14 +40,13 @@ ], "partialSumPolys": [ [ - -2187, - -65536, - -546875, - -2799360, - -10706059, - -33554432, - -90876411, - -220000000 + 2187, + 16384, + 78125, + 279936, + 823543, + 2097152, + 4782969 ] ] } diff --git a/internal/gkr/test_vectors/single_mul_gate_two_instances.json b/internal/gkr/test_vectors/single_mul_gate_two_instances.json index 390f3ef9ba..a589c31fa9 100644 --- a/internal/gkr/test_vectors/single_mul_gate_two_instances.json +++ b/internal/gkr/test_vectors/single_mul_gate_two_instances.json @@ -40,9 +40,8 @@ ], "partialSumPolys": [ [ - -9, - -32, - -35 + 9, + 8 ] ] }