Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions constraint/gkr.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,32 @@ func (l GkrSumcheckLevel) NbClaims() int {
func (l GkrSumcheckLevel) ClaimGroups() []GkrClaimGroup { return l }
func (l GkrSumcheckLevel) FinalEvalProofIndex(wireI, _ int) int { return wireI }

// SingleClaimSource returns the unique distinct claim source across the whole
// level when every claim in the level refers to the same evaluation point.
// This is the schedule-side eligibility boundary for the Gru24 Section 3.2
// optimization; folded multi-source levels return false.
func (l GkrSumcheckLevel) SingleClaimSource() (GkrClaimSource, bool) {
if len(l) == 0 {
return GkrClaimSource{}, false
}

var first GkrClaimSource
found := false
for _, group := range l {
for _, src := range group.ClaimSources {
if !found {
first = src
found = true
continue
}
if src != first {
return GkrClaimSource{}, false
}
}
}
return first, found
}

func (l GkrSkipLevel) NbOutgoingEvalPoints() int { return len(l.ClaimSources) }
func (l GkrSkipLevel) NbClaims() int {
return GkrClaimGroup(l).NbClaims()
Expand Down
171 changes: 162 additions & 9 deletions internal/generator/backend/template/gkr/gkr.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ func (e *zeroCheckLazyClaims) degree(int) int {
return e.resources.circuit.ZeroCheckDegree(e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel))
}

func (e *zeroCheckLazyClaims) roundCombinationCoeff(round int) ({{ .ElementType }}, bool) {
level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)
src, ok := level.SingleClaimSource()
if !ok {
return {{ .ElementType }}{}, false
}
return e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex][round], true
}

// verifyFinalEval finalizes the verification of a level at the sumcheck evaluation point r.
// The sumcheck protocol has already reduced the per-wire claims w(xᵢ) = yᵢ to verifying
// ∑ᵢ cⁱ eq(xᵢ, r) · wᵢ(r) = purportedValue, where the sum runs over all
Expand All @@ -57,7 +66,8 @@ func (e *zeroCheckLazyClaims) degree(int) int {
// that the full sum matches purportedValue.
func (e *zeroCheckLazyClaims) verifyFinalEval(r []{{ .ElementType }}, purportedValue {{ .ElementType }}, uniqueInputEvaluations []{{ .ElementType }}) error {
e.resources.outgoingEvalPoints[e.levelI] = [][]{{ .ElementType }}{r}
level := e.resources.schedule[e.levelI]
level := e.resources.schedule[e.levelI].(constraint.GkrSumcheckLevel)
_, optimized := level.SingleClaimSource()
gateInputEvals := gkrcore.ReduplicateInputs(level, e.resources.circuit, uniqueInputEvaluations)

var claimedEvals polynomial.Polynomial
Expand All @@ -80,11 +90,17 @@ func (e *zeroCheckLazyClaims) verifyFinalEval(r []{{ .ElementType }}, purportedV
gateEval.Set(evaluator.evaluate())
}

for _, src := range group.ClaimSources {
eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r)
var term {{ .ElementType }}
term.Mul(&eq, &gateEval)
claimedEvals = append(claimedEvals, term)
if optimized {
for range group.ClaimSources {
claimedEvals = append(claimedEvals, gateEval)
}
} else {
for _, src := range group.ClaimSources {
eq := polynomial.EvalEq(e.resources.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex], r)
var term {{ .ElementType }}
term.Mul(&eq, &gateEval)
claimedEvals = append(claimedEvals, term)
}
}
levelWireI++
}
Expand All @@ -107,17 +123,26 @@ type zeroCheckClaims struct {
inputIndices [][]int // [wireInLevel][gateInputJ] → index in input
eqs []polynomial.MultiLin // per-wire interpolation bases for evaluating wire assignments at challenge points
gateEvaluatorPools []*gateEvaluatorPool
singleSourcePoint []{{ .ElementType }}
round int
}

func (c *zeroCheckClaims) varsNum() int {
return c.resources.nbVars
}

// roundPolynomial computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)).
func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
if c.singleSourcePoint != nil {
return c.roundPolynomialSingleSource()
}
return c.roundPolynomialLegacy()
}

// roundPolynomialLegacy computes gⱼ = ∑ₕ ∑ᵥ eqs[v](Xⱼ, h...) · gateᵥ(inputs(Xⱼ, h...)).
// The polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)).
// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁).
// By convention, g₀ is a constant polynomial equal to the claimed sum.
func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
func (c *zeroCheckClaims) roundPolynomialLegacy() polynomial.Polynomial {
level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel)
degree := c.resources.circuit.ZeroCheckDegree(level)
nbUniqueInputs := len(c.input)
Expand Down Expand Up @@ -198,6 +223,92 @@ func (c *zeroCheckClaims) roundPolynomial() polynomial.Polynomial {
return p
}

// roundPolynomialSingleSource implements the Gru24 Section 3.2 path for levels
// whose claims all refer to the same evaluation point. It collapses the current
// eq factor by summing the two Boolean branches, so the prover only sends a
// degree-d polynomial instead of degree-(d+1).
func (c *zeroCheckClaims) roundPolynomialSingleSource() polynomial.Polynomial {
level := c.resources.schedule[c.levelI].(constraint.GkrSumcheckLevel)
degree := c.resources.circuit.ZeroCheckDegree(level)
nbUniqueInputs := len(c.input)
nbWires := len(c.eqs)
sumSize := len(c.eqs[0]) / 2

var one {{ .ElementType }}
one.SetOne()
sendZero := c.singleSourcePoint[c.round].Equal(&one)

p := make([]{{ .ElementType }}, degree)
var mu sync.Mutex
computeAll := func(start, end int) {
var step {{ .ElementType }}

evaluators := make([]*gateEvaluator, nbWires)
for w := range nbWires {
evaluators[w] = c.gateEvaluatorPools[w].get()
}
defer func() {
for w := range nbWires {
c.gateEvaluatorPools[w].put(evaluators[w])
}
}()

res := make([]{{ .ElementType }}, degree)
inputEvals := make([]{{ .ElementType }}, (degree+1)*nbUniqueInputs)
weights := make([]{{ .ElementType }}, nbWires)

accumulateAt := func(offset, outI int) {
for w := range nbWires {
for _, inputI := range c.inputIndices[w] {
evaluators[w].pushInput(inputEvals[offset+inputI])
}
summand := evaluators[w].evaluate()
summand.Mul(summand, &weights[w])
res[outI].Add(&res[outI], summand)
}
}

for h := start; h < end; h++ {
evalAt1Index := sumSize + h
for w := range nbWires {
weights[w].Set(&c.eqs[w][h])
}
for k := range c.input {
inputEvals[k].Set(&c.input[k][h])
step.Sub(&c.input[k][evalAt1Index], &c.input[k][h])
for d := 1; d <= degree; d++ {
inputEvals[d*nbUniqueInputs+k].Add(&inputEvals[(d-1)*nbUniqueInputs+k], &step)
}
}

if sendZero {
accumulateAt(0, 0)
for d := 2; d <= degree; d++ {
accumulateAt(d*nbUniqueInputs, d-1)
}
} else {
for d := 1; d <= degree; d++ {
accumulateAt(d*nbUniqueInputs, d-1)
}
}
}
mu.Lock()
for i := range p {
p[i].Add(&p[i], &res[i])
}
mu.Unlock()
}

const minBlockSize = 64
if sumSize < minBlockSize {
computeAll(0, sumSize)
} else {
c.resources.workers.Submit(sumSize, computeAll, minBlockSize).Wait()
}

return p
}

// roundFold folds all input and eq polynomials at the verifier challenge r.
// After this call, j ← j+1 and rⱼ = r.
func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) {
Expand All @@ -210,6 +321,11 @@ func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) {
for i := range c.eqs {
c.eqs[i].Fold(r)
}
if c.singleSourcePoint != nil {
for i := range c.eqs {
c.resources.stripSingleSourceEqFactor(c.eqs[i])
}
}
} else {
wgs := make([]*sync.WaitGroup, len(c.input)+len(c.eqs))
for i := range c.input {
Expand All @@ -221,7 +337,13 @@ func (c *zeroCheckClaims) roundFold(r {{ .ElementType }}) {
for _, wg := range wgs {
wg.Wait()
}
if c.singleSourcePoint != nil {
for i := range c.eqs {
c.resources.stripSingleSourceEqFactor(c.eqs[i])
}
}
}
c.round++
}

// proveFinalEval provides the unique input wire values wᵢ(r₁, ..., rₙ).
Expand Down Expand Up @@ -282,6 +404,31 @@ func (r *resources) eqAcc(e, m polynomial.MultiLin, q []{{ .ElementType }}) {
}, 512).Wait()
}

// stripSingleSourceEqFactor removes the current Boolean-variable eq factor from
// an optimized single-source eq table while keeping the table duplicated across
// the next current variable. If e encodes a value independent of Xⱼ up to the
// remaining eq suffix, afterwards it encodes the same shape for the next round.
func (r *resources) stripSingleSourceEqFactor(e polynomial.MultiLin) {
if len(e) <= 1 {
return
}
mid := len(e) / 2
work := func(start, end int) {
var sum {{ .ElementType }}
for i := start; i < end; i++ {
sum.Add(&e[i], &e[mid+i])
e[i].Set(&sum)
e[mid+i].Set(&sum)
}
}
const minBlockSize = 512
if mid < minBlockSize {
work(0, mid)
} else {
r.workers.Submit(mid, work, minBlockSize).Wait()
}
}

type resources struct {
// outgoingEvalPoints[i][k] is the k-th outgoing evaluation point (evaluation challenge) produced at schedule level i.
// outgoingEvalPoints[len(schedule)][0] holds the initial challenge (firstChallenge / rho).
Expand Down Expand Up @@ -376,7 +523,7 @@ func (r *resources) verifySkipLevel(levelI int, proof Proof) error {
}

func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof {
level := r.schedule[levelI]
level := r.schedule[levelI].(constraint.GkrSumcheckLevel)
nbClaims := level.NbClaims()
var foldingCoeff {{ .ElementType }}
if nbClaims >= 2 {
Expand Down Expand Up @@ -461,6 +608,12 @@ func (r *resources) proveSumcheckLevel(levelI int) sumcheckProof {
eqs: eqs,
gateEvaluatorPools: pools,
}
if src, ok := level.SingleClaimSource(); ok {
claims.singleSourcePoint = r.outgoingEvalPoints[src.Level][src.OutgoingClaimIndex]
for i := range claims.eqs {
r.stripSingleSourceEqFactor(claims.eqs[i])
}
}
return sumcheckProve(claims, &r.transcript)
}

Expand Down
21 changes: 21 additions & 0 deletions internal/generator/backend/template/gkr/gkr.test.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,27 @@ func TestSumcheckLevel(t *testing.T) {
}
}

func TestZeroCheckDegreeDispatch(t *testing.T) {
_, sCircuit := cache.Compile(t, gkrtesting.Poseidon2Circuit(4, 2))
schedule, err := gkrcore.DefaultProvingSchedule(sCircuit)
assert.NoError(t, err)

assignment := make(WireAssignment, len(sCircuit))
for _, i := range sCircuit.Inputs() {
assignment[i] = make([]{{ .ElementType }}, 2)
{{ .FieldPackageName }}.Vector(assignment[i]).MustSetRandom()
}
assignment.Complete(sCircuit)

proof, err := Prove(sCircuit, schedule, assignment, newMessageCounter(1, 1))
assert.NoError(t, err)
assert.NoError(t, Verify(sCircuit, schedule, assignment, proof, newMessageCounter(1, 1)))

assert.Len(t, proof[3].partialSumPolys[0], 2, "single-source s-box level should use the optimized degree")
assert.Len(t, proof[5].partialSumPolys[0], 3, "folded multi-source level must stay on the legacy degree")
assert.Len(t, proof[11].partialSumPolys[0], 2, "later single-source level should also use the optimized degree")
}

// testSkipLevel exercises proveSkipLevel/verifySkipLevel for a single skip level.
func testSkipLevel(t *testing.T, circuit gkrcore.RawCircuit, level constraint.GkrProvingLevel) {
t.Helper()
Expand Down
26 changes: 24 additions & 2 deletions internal/generator/backend/template/gkr/sumcheck.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type sumcheckClaims interface {
type sumcheckLazyClaims interface {
varsNum() int // varsNum = n
degree(i int) int // degree of the total claim in the i'th variable
roundCombinationCoeff(i int) ({{ .ElementType }}, bool)
verifyFinalEval(r []{{ .ElementType }}, purportedValue {{ .ElementType }}, proof []{{ .ElementType }}) error
}

Expand Down Expand Up @@ -92,13 +93,34 @@ func sumcheckVerify(claims sumcheckLazyClaims, proof sumcheckProof, claimedSum {

gJ := make(polynomial.Polynomial, degree+1)
gJR := claimedSum
var one {{ .ElementType }}
one.SetOne()

for j := range claims.varsNum() {
if len(proof.partialSumPolys[j]) != degree {
return errors.New("malformed proof")
}
copy(gJ[1:], proof.partialSumPolys[j])
gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0])
if coeff, ok := claims.roundCombinationCoeff(j); ok {
if coeff.Equal(&one) {
gJ[0].Set(&proof.partialSumPolys[j][0])
gJ[1].Set(&gJR)
} else {
copy(gJ[1:], proof.partialSumPolys[j])
gJ[0].Mul(&coeff, &proof.partialSumPolys[j][0])
gJ[0].Sub(&gJR, &gJ[0])
var oneMinusCoeff {{ .ElementType }}
oneMinusCoeff.SetOne()
oneMinusCoeff.Sub(&oneMinusCoeff, &coeff)
oneMinusCoeff.Inverse(&oneMinusCoeff)
gJ[0].Mul(&gJ[0], &oneMinusCoeff)
}
if degree > 1 {
copy(gJ[2:], proof.partialSumPolys[j][1:])
}
} else {
copy(gJ[1:], proof.partialSumPolys[j])
gJ[0].Sub(&gJR, &proof.partialSumPolys[j][0])
}

r[j] = t.getChallenge(proof.partialSumPolys[j]...)
gJCoeffs := polynomial.InterpolateOnRange(gJ[:(degree + 1)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ func (c singleMultilinLazyClaim) degree(int) int {
return 1
}

func (c singleMultilinLazyClaim) roundCombinationCoeff(int) ({{ .ElementType }}, bool) {
return {{ .ElementType }}{}, false
}

func (c singleMultilinLazyClaim) varsNum() int {
return bits.TrailingZeros(uint(len(c.g)))
}
Expand Down
Loading