diff --git a/internal/stats/latest_stats.csv b/internal/stats/latest_stats.csv index 65c88a6c86..ab9e8384fe 100644 --- a/internal/stats/latest_stats.csv +++ b/internal/stats/latest_stats.csv @@ -87,6 +87,18 @@ math/emulated/secp256k1_64,bn254,plonk,4025,3923 math/emulated/secp256k1_64,bls12_377,plonk,4025,3923 math/emulated/secp256k1_64,bls12_381,plonk,4025,3923 math/emulated/secp256k1_64,bw6_761,plonk,4025,3923 +msm_G1_bn254_2,bn254,groth16,208925,312617 +msm_G1_bn254_2,bn254,plonk,688811,658743 +msm_P256_2,bn254,groth16,185846,288056 +msm_P256_2,bn254,plonk,635297,608874 +msm_babyjubjub_2,bn254,groth16,5269,5683 +msm_babyjubjub_2,bn254,plonk,12389,11848 +msm_bandersnatch_2,bls12_381,groth16,5016,5791 +msm_bandersnatch_2,bls12_381,plonk,12450,11904 +msm_jubjub_2,bls12_381,groth16,5276,5754 +msm_jubjub_2,bls12_381,plonk,12332,11855 +msm_secp256k1_2,bn254,groth16,208997,312737 +msm_secp256k1_2,bn254,plonk,689104,659028 pairing_bls12377,bw6_761,groth16,11876,11876 pairing_bls12377,bw6_761,plonk,45565,45565 pairing_bls12381,bn254,groth16,756837,1242260 @@ -95,18 +107,18 @@ pairing_bn254,bn254,groth16,506052,823961 pairing_bn254,bn254,plonk,1646819,1573151 pairing_bw6761,bn254,groth16,1589471,2646707 pairing_bw6761,bn254,plonk,5318762,5097941 -scalar_mul_G1_bn254,bn254,groth16,115934,175413 -scalar_mul_G1_bn254,bn254,plonk,381171,365027 -scalar_mul_G1_bn254_incomplete,bn254,groth16,55441,87984 -scalar_mul_G1_bn254_incomplete,bn254,plonk,200004,192882 +scalar_mul_G1_bn254,bn254,groth16,108168,163915 +scalar_mul_G1_bn254,bn254,plonk,355353,340385 +scalar_mul_G1_bn254_incomplete,bn254,groth16,51579,81902 +scalar_mul_G1_bn254_incomplete,bn254,plonk,185916,179316 scalar_mul_P256,bn254,groth16,96724,151768 scalar_mul_P256,bn254,plonk,328895,315729 scalar_mul_P256_incomplete,bn254,groth16,75542,121798 scalar_mul_P256_incomplete,bn254,plonk,263160,253523 -scalar_mul_secp256k1,bn254,groth16,117264,177389 -scalar_mul_secp256k1,bn254,plonk,385623,369279 -scalar_mul_secp256k1_incomplete,bn254,groth16,56125,89066 -scalar_mul_secp256k1_incomplete,bn254,plonk,202518,195302 +scalar_mul_secp256k1,bn254,groth16,108204,163975 +scalar_mul_secp256k1,bn254,plonk,355502,340530 +scalar_mul_secp256k1_incomplete,bn254,groth16,51619,81970 +scalar_mul_secp256k1_incomplete,bn254,plonk,186082,179475 selector/binaryMux_4,bn254,groth16,5,3 selector/binaryMux_4,bls12_377,groth16,5,3 selector/binaryMux_4,bls12_381,groth16,5,3 diff --git a/internal/stats/snippet.go b/internal/stats/snippet.go index 6a4182a69d..dc19e7d41f 100644 --- a/internal/stats/snippet.go +++ b/internal/stats/snippet.go @@ -6,6 +6,7 @@ import ( "github.com/consensys/gnark" "github.com/consensys/gnark-crypto/ecc" + twistededwardsCryptoID "github.com/consensys/gnark-crypto/ecc/twistededwards" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/std/algebra/emulated/sw_bls12381" @@ -13,6 +14,7 @@ import ( "github.com/consensys/gnark/std/algebra/emulated/sw_bw6761" "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" "github.com/consensys/gnark/std/algebra/native/sw_bls12377" + "github.com/consensys/gnark/std/algebra/native/twistededwards" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" @@ -412,6 +414,144 @@ func initSnippets() { }, ecc.BN254) + // MSM(2, n) snippets for the four curve classes — used to evaluate which + // MSM-size variant is best in complete-arithmetic mode (Phase 4 of plan). + // Baselines: existing scalar_mul_* divided by 2 gives lower bound for + // MSM(2, n) via two ScalarMul + Add. + registerSnippet("msm_secp256k1_2", func(api frontend.API, newVariable func() frontend.Variable) { + cr, err := sw_emulated.New[emulated.Secp256k1Fp, emulated.Secp256k1Fr](api, sw_emulated.GetCurveParams[emulated.Secp256k1Fp]()) + if err != nil { + panic(err) + } + fr, _ := emulated.NewField[emulated.Secp256k1Fr](api) + newFr := func() *emulated.Element[emulated.Secp256k1Fr] { + n, _ := emulated.GetEffectiveFieldParams[emulated.Secp256k1Fr](api.Compiler().Field()) + limbs := make([]frontend.Variable, n) + for i := range limbs { + limbs[i] = newVariable() + } + return fr.NewElement(limbs) + } + fp, _ := emulated.NewField[emulated.Secp256k1Fp](api) + newPoint := func() *sw_emulated.AffinePoint[emulated.Secp256k1Fp] { + n, _ := emulated.GetEffectiveFieldParams[emulated.Secp256k1Fp](api.Compiler().Field()) + x := make([]frontend.Variable, n) + y := make([]frontend.Variable, n) + for i := range x { + x[i] = newVariable() + y[i] = newVariable() + } + return &sw_emulated.AffinePoint[emulated.Secp256k1Fp]{X: *fp.NewElement(x), Y: *fp.NewElement(y)} + } + _, _ = cr.MultiScalarMul( + []*sw_emulated.AffinePoint[emulated.Secp256k1Fp]{newPoint(), newPoint()}, + []*emulated.Element[emulated.Secp256k1Fr]{newFr(), newFr()}, + ) + }, ecc.BN254) + + registerSnippet("msm_P256_2", func(api frontend.API, newVariable func() frontend.Variable) { + cr, err := sw_emulated.New[emulated.P256Fp, emulated.P256Fr](api, sw_emulated.GetCurveParams[emulated.P256Fp]()) + if err != nil { + panic(err) + } + fr, _ := emulated.NewField[emulated.P256Fr](api) + newFr := func() *emulated.Element[emulated.P256Fr] { + n, _ := emulated.GetEffectiveFieldParams[emulated.P256Fr](api.Compiler().Field()) + limbs := make([]frontend.Variable, n) + for i := range limbs { + limbs[i] = newVariable() + } + return fr.NewElement(limbs) + } + fp, _ := emulated.NewField[emulated.P256Fp](api) + newPoint := func() *sw_emulated.AffinePoint[emulated.P256Fp] { + n, _ := emulated.GetEffectiveFieldParams[emulated.P256Fp](api.Compiler().Field()) + x := make([]frontend.Variable, n) + y := make([]frontend.Variable, n) + for i := range x { + x[i] = newVariable() + y[i] = newVariable() + } + return &sw_emulated.AffinePoint[emulated.P256Fp]{X: *fp.NewElement(x), Y: *fp.NewElement(y)} + } + _, _ = cr.MultiScalarMul( + []*sw_emulated.AffinePoint[emulated.P256Fp]{newPoint(), newPoint()}, + []*emulated.Element[emulated.P256Fr]{newFr(), newFr()}, + ) + }, ecc.BN254) + + registerSnippet("msm_G1_bn254_2", func(api frontend.API, newVariable func() frontend.Variable) { + cr, err := sw_emulated.New[emulated.BN254Fp, emulated.BN254Fr](api, sw_emulated.GetCurveParams[emulated.BN254Fp]()) + if err != nil { + panic(err) + } + fr, _ := emulated.NewField[emulated.BN254Fr](api) + newFr := func() *emulated.Element[emulated.BN254Fr] { + n, _ := emulated.GetEffectiveFieldParams[emulated.BN254Fr](api.Compiler().Field()) + limbs := make([]frontend.Variable, n) + for i := range limbs { + limbs[i] = newVariable() + } + return fr.NewElement(limbs) + } + fp, _ := emulated.NewField[emulated.BN254Fp](api) + newPoint := func() *sw_emulated.AffinePoint[emulated.BN254Fp] { + n, _ := emulated.GetEffectiveFieldParams[emulated.BN254Fp](api.Compiler().Field()) + x := make([]frontend.Variable, n) + y := make([]frontend.Variable, n) + for i := range x { + x[i] = newVariable() + y[i] = newVariable() + } + return &sw_emulated.AffinePoint[emulated.BN254Fp]{X: *fp.NewElement(x), Y: *fp.NewElement(y)} + } + _, _ = cr.MultiScalarMul( + []*sw_emulated.AffinePoint[emulated.BN254Fp]{newPoint(), newPoint()}, + []*emulated.Element[emulated.BN254Fr]{newFr(), newFr()}, + ) + }, ecc.BN254) + + // Twisted Edwards DoubleBaseScalarMul snippets — exercise the new + // MSM(3, 2n/3) (no GLV) and MSM(6, n/3) (GLV) variants from PR #1697. + registerSnippet("msm_babyjubjub_2", func(api frontend.API, newVariable func() frontend.Variable) { + curve, err := twistededwards.NewEdCurve(api, twistededwardsCryptoID.BN254) + if err != nil { + panic(err) + } + var P1, P2 twistededwards.Point + P1.X = newVariable() + P1.Y = newVariable() + P2.X = newVariable() + P2.Y = newVariable() + _ = curve.DoubleBaseScalarMulNonZero(P1, P2, newVariable(), newVariable()) + }, ecc.BN254) + + registerSnippet("msm_jubjub_2", func(api frontend.API, newVariable func() frontend.Variable) { + curve, err := twistededwards.NewEdCurve(api, twistededwardsCryptoID.BLS12_381) + if err != nil { + panic(err) + } + var P1, P2 twistededwards.Point + P1.X = newVariable() + P1.Y = newVariable() + P2.X = newVariable() + P2.Y = newVariable() + _ = curve.DoubleBaseScalarMulNonZero(P1, P2, newVariable(), newVariable()) + }, ecc.BLS12_381) + + registerSnippet("msm_bandersnatch_2", func(api frontend.API, newVariable func() frontend.Variable) { + curve, err := twistededwards.NewEdCurve(api, twistededwardsCryptoID.BLS12_381_BANDERSNATCH) + if err != nil { + panic(err) + } + var P1, P2 twistededwards.Point + P1.X = newVariable() + P1.Y = newVariable() + P2.X = newVariable() + P2.Y = newVariable() + _ = curve.DoubleBaseScalarMulNonZero(P1, P2, newVariable(), newVariable()) + }, ecc.BLS12_381) + registerSnippet("selector/mux_3", func(api frontend.API, newVariable func() frontend.Variable) { selector.Mux(api, newVariable(), newVariable(), newVariable(), newVariable()) }) diff --git a/std/algebra/emulated/sw_emulated/hints.go b/std/algebra/emulated/sw_emulated/hints.go index b6b42cffe3..e436f992cc 100644 --- a/std/algebra/emulated/sw_emulated/hints.go +++ b/std/algebra/emulated/sw_emulated/hints.go @@ -6,7 +6,7 @@ import ( "fmt" "math/big" - "github.com/consensys/gnark-crypto/algebra/eisenstein" + "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc" bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" bls12381_fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" @@ -32,8 +32,8 @@ func GetHints() []solver.Hint { return []solver.Hint{ decomposeScalarG1, scalarMulHint, - halfGCD, - halfGCDEisenstein, + rationalReconstruct, + rationalReconstructExt, } } @@ -160,7 +160,15 @@ func scalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error }) } -func halfGCD(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { +// rationalReconstruct decomposes a scalar s ∈ Fr into (s1, |s2|, signBit) such +// that s1 ≡ s2·s (mod r), with |s1|, |s2| < γ₂·√r ≈ 1.15·√r (proven LLL/Hermite +// bound from gnark-crypto/algebra/lattice). Replaces the older heuristic +// HalfGCD-based decomposition. +// +// In-circuit: 1 native sign bit + 2 emulated outputs (s1, |s2|). The caller +// reconstructs the signed s2 as ±|s2| based on the sign bit and asserts +// s1 + s·s2 ≡ 0 (mod r). +func rationalReconstruct(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { moduli := hc.EmulatedModuli() if len(moduli) != 1 { @@ -177,25 +185,38 @@ func halfGCD(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { if len(emuOutputs) != 2 { return fmt.Errorf("expecting two outputs, got %d", len(emuOutputs)) } - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(moduli[0], emuInputs[0], glvBasis) - emuOutputs[0].Set(&glvBasis.V1[0]) - emuOutputs[1].Set(&glvBasis.V1[1]) - // we need the absolute values for the in-circuit computations, - // otherwise the negative values will be reduced modulo the SNARK scalar - // field and not the emulated field. - // output0 = |s0| mod r - // output1 = |s1| mod r + // lattice.RationalReconstruct returns (x, z) with x ≡ z·s (mod r), + // i.e., x − z·s ≡ 0 (mod r). The circuit expects: s1 + s·_s2 ≡ 0 + // (mod r), so s1 = x and _s2 = −z. + rc := lattice.NewReconstructor(moduli[0]) + res := rc.RationalReconstruct(emuInputs[0]) + x, z := new(big.Int).Set(res[0]), new(big.Int).Set(res[1]) + + // Normalise so s1 ≥ 0; flipping (x, z) preserves x ≡ z·s mod r. + if x.Sign() < 0 { + x.Neg(x) + z.Neg(z) + } + emuOutputs[0].Set(x) + emuOutputs[1].Abs(z) + + // signBit = 1 iff −z < 0 iff z > 0 (so the in-circuit code negates + // |z| to recover s2 = −z). nativeOutputs[0].SetUint64(0) - if emuOutputs[1].Sign() == -1 { - emuOutputs[1].Neg(emuOutputs[1]) - nativeOutputs[0].SetUint64(1) // we return the sign of the second subscalar + if z.Sign() > 0 { + nativeOutputs[0].SetUint64(1) } return nil }) } -func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { +// rationalReconstructExt is the 4-D Eisenstein-style decomposition: given a +// scalar s and GLV eigenvalue λ, finds (u1, u2, v1, v2) such that +// s·(v1 + λ·v2) + u1 + λ·u2 ≡ 0 (mod r), with |u_i|, |v_i| < γ₄·r^(1/4) ≈ +// 1.25·r^(1/4) (proven LLL bound). Replaces the older Eisenstein HalfGCD. +// +// In-circuit: 4 native sign bits + 4 emulated absolute values. +func rationalReconstructExt(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { return emulated.UnwrapHintContext(mod, inputs, outputs, func(hc emulated.HintContext) error { moduli := hc.EmulatedModuli() if len(moduli) != 1 { @@ -213,47 +234,43 @@ func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) erro return fmt.Errorf("expecting four outputs, got %d", len(emuOutputs)) } - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(moduli[0], emuInputs[1], glvBasis) - r := eisenstein.ComplexNumber{ - A0: glvBasis.V1[0], - A1: glvBasis.V1[1], - } - sp := ecc.SplitScalar(emuInputs[0], glvBasis) - // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 - // so here we return -s instead of s. - s := eisenstein.ComplexNumber{ - A0: sp[0], - A1: sp[1], - } - s.Neg(&s) + // Inputs: emuInputs[0] = s, emuInputs[1] = λ. + // In-circuit we check Q − [s]P = 0, equivalently [−s]P + Q = 0, so we + // negate the scalar before reconstruction (matches the previous + // halfGCDEisenstein convention). + k := new(big.Int).Neg(emuInputs[0]) + k.Mod(k, moduli[0]) + + rc := lattice.NewReconstructor(moduli[0]).SetLambda(emuInputs[1]) + res := rc.RationalReconstructExt(k) + // res = (x, y, z, t) with k = (x + λ·y)/(z + λ·t) mod r, + // i.e., (x + λ·y) − k·(z + λ·t) ≡ 0 (mod r). + // Mapping onto our convention u1 + λ·u2 + s·(v1 + λ·v2) ≡ 0 with k = −s: + // u1 = x, u2 = y, v1 = z, v2 = t. + u1 := new(big.Int).Set(res[0]) + u2 := new(big.Int).Set(res[1]) + v1 := new(big.Int).Set(res[2]) + v2 := new(big.Int).Set(res[3]) + + emuOutputs[0].Abs(u1) + emuOutputs[1].Abs(u2) + emuOutputs[2].Abs(v1) + emuOutputs[3].Abs(v2) - res := eisenstein.HalfGCD(&r, &s) - // values - emuOutputs[0].Set(&res[0].A0) - emuOutputs[1].Set(&res[0].A1) - emuOutputs[2].Set(&res[1].A0) - emuOutputs[3].Set(&res[1].A1) - // signs nativeOutputs[0].SetUint64(0) nativeOutputs[1].SetUint64(0) nativeOutputs[2].SetUint64(0) nativeOutputs[3].SetUint64(0) - - if res[0].A0.Sign() == -1 { - emuOutputs[0].Neg(emuOutputs[0]) + if u1.Sign() < 0 { nativeOutputs[0].SetUint64(1) } - if res[0].A1.Sign() == -1 { - emuOutputs[1].Neg(emuOutputs[1]) + if u2.Sign() < 0 { nativeOutputs[1].SetUint64(1) } - if res[1].A0.Sign() == -1 { - emuOutputs[2].Neg(emuOutputs[2]) + if v1.Sign() < 0 { nativeOutputs[2].SetUint64(1) } - if res[1].A1.Sign() == -1 { - emuOutputs[3].Neg(emuOutputs[3]) + if v2.Sign() < 0 { nativeOutputs[3].SetUint64(1) } return nil diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go index 27793b0826..0404d1a385 100644 --- a/std/algebra/emulated/sw_emulated/point.go +++ b/std/algebra/emulated/sw_emulated/point.go @@ -1372,9 +1372,9 @@ func (c *Curve[B, S]) scalarMulFakeGLV(Q *AffinePoint[B], s *emulated.Element[S] // First we find the sub-salars s1, s2 s.t. s1 + s2*s = 0 mod r and s1, s2 < sqrt(r). // we also output the sign in case s2 is negative. In that case we compute _s2 = -s2 mod r. - sign, sd, err := c.scalarApi.NewHintGeneric(halfGCD, 1, 2, nil, []*emulated.Element[S]{_s}) + sign, sd, err := c.scalarApi.NewHintGeneric(rationalReconstruct, 1, 2, nil, []*emulated.Element[S]{_s}) if err != nil { - panic(fmt.Sprintf("halfGCD hint: %v", err)) + panic(fmt.Sprintf("rationalReconstruct hint: %v", err)) } s1, s2 := sd[0], sd[1] _s2 := c.scalarApi.Select(sign[0], c.scalarApi.Neg(s2), s2) @@ -1676,9 +1676,9 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem // Eisenstein integers real and imaginary parts can be negative. So we // return the absolute value in the hint and negate the corresponding // points here when needed. - signs, sd, err := c.scalarApi.NewHintGeneric(halfGCDEisenstein, 4, 4, nil, []*emulated.Element[S]{_s, c.eigenvalue}) + signs, sd, err := c.scalarApi.NewHintGeneric(rationalReconstructExt, 4, 4, nil, []*emulated.Element[S]{_s, c.eigenvalue}) if err != nil { - panic(fmt.Sprintf("halfGCDEisenstein hint: %v", err)) + panic(fmt.Sprintf("rationalReconstructExt hint: %v", err)) } u1, u2, v1, v2 := sd[0], sd[1], sd[2], sd[3] isNegu1, isNegu2, isNegv1, isNegv2 := signs[0], signs[1], signs[2], signs[3] @@ -1800,10 +1800,10 @@ func (c *Curve[B, S]) scalarMulGLVAndFakeGLV(P *AffinePoint[B], s *emulated.Elem g := c.Generator() Acc = addFn(Acc, g) - // u1, u2, v1, v2 < r^{1/4} (up to a constant factor). - // We prove that the factor is log_(3/sqrt(3)))(r). - // so we need to add 9 bits to r^{1/4}.nbits(). - nbits := st.Modulus().BitLen()>>2 + 9 + // LLL Hermite bound (gnark-crypto/algebra/lattice): u1, u2, v1, v2 are + // bounded by γ₄·r^(1/4) ≈ 1.25·r^(1/4), which fits in (BitLen+3)/4 + 2 bits. + // This is tighter than the previous heuristic BitLen/4 + 9 (saves ~7 iters). + nbits := (st.Modulus().BitLen()+3)/4 + 2 u1bits := c.scalarApi.ToBits(u1) u2bits := c.scalarApi.ToBits(u2) v1bits := c.scalarApi.ToBits(v1) diff --git a/std/algebra/emulated/sw_emulated/point_test.go b/std/algebra/emulated/sw_emulated/point_test.go index 0e20db6d74..0949c86b73 100644 --- a/std/algebra/emulated/sw_emulated/point_test.go +++ b/std/algebra/emulated/sw_emulated/point_test.go @@ -2601,7 +2601,7 @@ func TestScalarMulGLVAndFakeGLVEdgeCasesEdgeCases2(t *testing.T) { } // This is a regression for the missing complete-formula handling in -// scalarMulGLVAndFakeGLV. For secp256k1 and s=2, the halfGCDEisenstein +// scalarMulGLVAndFakeGLV. For secp256k1 and s=2, the rationalReconstructExt // decomposition yields signs corresponding to // // b1 = -P + Q + Phi(P) + Phi(Q). @@ -2664,9 +2664,9 @@ func TestScalarMulGLVAndFakeGLVCompletePrecomputeCollisionFails(t *testing.T) { assert.NoError(err) } -// zeroHalfGCDEisenstein replaces the honest halfGCDEisenstein hint with one +// zeroRationalReconstructExt replaces the honest rationalReconstructExt hint with one // returning the all-zeros decomposition. Used by the regression below. -func zeroHalfGCDEisenstein(_ *big.Int, _, outputs []*big.Int) error { +func zeroRationalReconstructExt(_ *big.Int, _, outputs []*big.Int) error { for i := range outputs { outputs[i].SetUint64(0) } @@ -2674,7 +2674,7 @@ func zeroHalfGCDEisenstein(_ *big.Int, _, outputs []*big.Int) error { } // TestScalarMulGLVAndFakeGLV_TrivialDecompositionRegression: regression for a -// soundness issue in scalarMulGLVAndFakeGLV. A malicious halfGCDEisenstein +// soundness issue in scalarMulGLVAndFakeGLV. A malicious rationalReconstructExt // hint returning the trivial all-zeros decomposition (u1=u2=v1=v2=0) makes // the relation s·(v1 + λ·v2) + u1 + λ·u2 = 0 vacuous and lets the // scalar-mul hint output be any point. The fix asserts NOT (v1=0 AND v2=0). @@ -2704,7 +2704,7 @@ func TestScalarMulGLVAndFakeGLV_TrivialDecompositionRegression(t *testing.T) { // malicious all-zeros Eisenstein decomposition must be rejected err := test.IsSolved(&circuit, &witness, testCurve.ScalarField(), - test.WithReplacementHint(solver.GetHintID(halfGCDEisenstein), zeroHalfGCDEisenstein), + test.WithReplacementHint(solver.GetHintID(rationalReconstructExt), zeroRationalReconstructExt), ) if err == nil { t.Fatal("malicious all-zeros Eisenstein decomposition was accepted — soundness break") diff --git a/std/algebra/native/sw_bls12377/g1.go b/std/algebra/native/sw_bls12377/g1.go index 4cdb944d9d..5c4722091f 100644 --- a/std/algebra/native/sw_bls12377/g1.go +++ b/std/algebra/native/sw_bls12377/g1.go @@ -651,9 +651,9 @@ func (p *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte // Eisenstein integers real and imaginary parts can be negative. So we // return the absolute value in the hint and negate the corresponding // points here when needed. - sd, err := api.NewHint(halfGCDEisenstein, 10, _s, cc.lambda) + sd, err := api.NewHint(rationalReconstructExt, 10, _s, cc.lambda) if err != nil { - panic(fmt.Sprintf("halfGCDEisenstein hint: %v", err)) + panic(fmt.Sprintf("rationalReconstructExt hint: %v", err)) } u1, u2, v1, v2, q := sd[0], sd[1], sd[2], sd[3], sd[4] isNegu1, isNegu2, isNegv1, isNegv2, isNegq := sd[5], sd[6], sd[7], sd[8], sd[9] @@ -775,10 +775,10 @@ func (p *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte H := G1Affine{X: 0, Y: 1} Acc.AddAssign(api, H) - // u1, u2, v1, v2 < r^{1/4} (up to a constant factor). - // We prove that the factor is log_(3/sqrt(3)))(r). - // so we need to add 9 bits to r^{1/4}.nbits(). - nbits := cc.lambda.BitLen()>>1 + 9 // 72 + // LLL Hermite bound (gnark-crypto/algebra/lattice): u1, u2, v1, v2 are + // bounded by γ₄·r^(1/4) ≈ 1.25·r^(1/4), which fits in (r.BitLen()+3)/4 + 2 + // bits. Tighter than the previous heuristic lambda.BitLen()/2 + 9 (saves ~6 iters). + nbits := (cc.fr.BitLen()+3)/4 + 2 // 66 for BLS12-377 u1bits := api.ToBinary(u1, nbits) u2bits := api.ToBinary(u2, nbits) v1bits := api.ToBinary(v1, nbits) diff --git a/std/algebra/native/sw_bls12377/g1_eisenstein_test.go b/std/algebra/native/sw_bls12377/g1_eisenstein_test.go index 36f7b96550..10f995af01 100644 --- a/std/algebra/native/sw_bls12377/g1_eisenstein_test.go +++ b/std/algebra/native/sw_bls12377/g1_eisenstein_test.go @@ -28,11 +28,11 @@ func (c *scalarMulGLVAndFakeGLVTrivialDecompCircuit) Define(api frontend.API) er return nil } -// zeroHalfGCDEisenstein replaces the honest halfGCDEisenstein hint with one +// zeroRationalReconstructExt replaces the honest rationalReconstructExt hint with one // that returns the all-zeros decomposition (u1 = u2 = v1 = v2 = q = 0). The // signs are also zero (positive). This is the malicious-hint shape the // soundness fix protects against. -func zeroHalfGCDEisenstein(_ *big.Int, inputs, outputs []*big.Int) error { +func zeroRationalReconstructExt(_ *big.Int, inputs, outputs []*big.Int) error { if len(inputs) != 2 { return errors.New("expecting two inputs") } @@ -46,7 +46,7 @@ func zeroHalfGCDEisenstein(_ *big.Int, inputs, outputs []*big.Int) error { } // TestScalarMulGLVAndFakeGLV_TrivialDecompositionRegression: regression for a -// soundness issue in scalarMulGLVAndFakeGLV. A malicious halfGCDEisenstein +// soundness issue in scalarMulGLVAndFakeGLV. A malicious rationalReconstructExt // hint returning the trivial all-zeros decomposition (u1=u2=v1=v2=q=0) makes // the relation s·(v1 + λ·v2) + u1 + λ·u2 - r·q = 0 vacuous and lets the // scalar-mul hint output be any point. The fix asserts NOT (v1=0 AND v2=0). @@ -73,7 +73,7 @@ func TestScalarMulGLVAndFakeGLV_TrivialDecompositionRegression(t *testing.T) { &scalarMulGLVAndFakeGLVTrivialDecompCircuit{}, &witness, ecc.BW6_761.ScalarField(), - test.WithReplacementHint(solver.GetHintID(halfGCDEisenstein), zeroHalfGCDEisenstein), + test.WithReplacementHint(solver.GetHintID(rationalReconstructExt), zeroRationalReconstructExt), ) assert.Error(err, "trivial all-zeros Eisenstein decomposition was accepted — soundness break") } diff --git a/std/algebra/native/sw_bls12377/hints.go b/std/algebra/native/sw_bls12377/hints.go index af1acc761d..3302be200b 100644 --- a/std/algebra/native/sw_bls12377/hints.go +++ b/std/algebra/native/sw_bls12377/hints.go @@ -4,7 +4,7 @@ import ( "errors" "math/big" - "github.com/consensys/gnark-crypto/algebra/eisenstein" + "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark/constraint/solver" @@ -16,7 +16,7 @@ func GetHints() []solver.Hint { decomposeScalarG1Simple, decomposeScalarG2, scalarMulGLVG1Hint, - halfGCDEisenstein, + rationalReconstructExt, pairingCheckHint, pairingCheckTorusHint, } @@ -306,68 +306,76 @@ func scalarMulGLVG1Hint(scalarField *big.Int, inputs []*big.Int, outputs []*big. return nil } -func halfGCDEisenstein(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { +// rationalReconstructExt is the 4-D Eisenstein-style scalar decomposition for +// BLS12-377 G1's GLV+FakeGLV scalar mul, backed by LLL-based lattice rational +// reconstruction with proven Hermite bound |u_i|, |v_i| < γ₄·r^(1/4) ≈ +// 1.25·r^(1/4). Replaces the older Eisenstein HalfGCD. +// +// Inputs: [s, λ] (scalar and GLV eigenvalue, both bounded by inner curve order). +// Outputs: [|u1|, |u2|, |v1|, |v2|, |q|, sign(u1), sign(u2), sign(v1), sign(v2), sign(q)] (10). +// +// The relation (in signed integers) is +// +// s·(v1 + λ·v2) + u1 + λ·u2 = q·r +// +// where r is the inner curve order. The in-circuit check at sw_bls12377/g1.go:: +// scalarMulGLVAndFakeGLV verifies this in the outer SNARK scalar field. +func rationalReconstructExt(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { if len(inputs) != 2 { - return errors.New("expecting two input") + return errors.New("expecting two inputs (s, λ)") } if len(outputs) != 10 { - return errors.New("expecting ten outputs") + return errors.New("expecting ten outputs (4 abs values + 1 |q| + 5 sign bits)") } cc := getInnerCurveConfig(scalarField) - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(cc.fr, inputs[1], glvBasis) - r := eisenstein.ComplexNumber{ - A0: glvBasis.V1[0], - A1: glvBasis.V1[1], - } - sp := ecc.SplitScalar(inputs[0], glvBasis) - // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 - // so here we return -s instead of s. - s := eisenstein.ComplexNumber{ - A0: sp[0], - A1: sp[1], + + // In-circuit we check Q − [s]P = 0, equivalently Q + [−s]P = 0, so we + // negate the scalar before reconstruction (matches the previous convention). + k := new(big.Int).Neg(inputs[0]) + k.Mod(k, cc.fr) + + rc := lattice.NewReconstructor(cc.fr).SetLambda(inputs[1]) + res := rc.RationalReconstructExt(k) + // res = (x, y, z, t) with k = (x + λ·y)/(z + λ·t) mod r, + // i.e., (x + λ·y) − k·(z + λ·t) ≡ 0 (mod r). With k = −s mod r this gives + // (x + λ·y) + s·(z + λ·t) ≡ 0 (mod r). Mapping: u1 = x, u2 = y, v1 = z, v2 = t. + u1 := new(big.Int).Set(res[0]) + u2 := new(big.Int).Set(res[1]) + v1 := new(big.Int).Set(res[2]) + v2 := new(big.Int).Set(res[3]) + + // q = (s·(v1 + λ·v2) + u1 + λ·u2) / r computed in signed integers. + q := new(big.Int).Mul(v2, inputs[1]) + q.Add(q, v1) + q.Mul(q, inputs[0]) + tmp := new(big.Int).Mul(u2, inputs[1]) + q.Add(q, tmp) + q.Add(q, u1) + q.Quo(q, cc.fr) + + outputs[0].Abs(u1) + outputs[1].Abs(u2) + outputs[2].Abs(v1) + outputs[3].Abs(v2) + outputs[4].Abs(q) + + for i := 5; i <= 9; i++ { + outputs[i].SetUint64(0) } - s.Neg(&s) - res := eisenstein.HalfGCD(&r, &s) - outputs[0].Set(&res[0].A0) - outputs[1].Set(&res[0].A1) - outputs[2].Set(&res[1].A0) - outputs[3].Set(&res[1].A1) - outputs[4].Mul(&res[1].A1, inputs[1]). - Add(outputs[4], &res[1].A0). - Mul(outputs[4], inputs[0]). - Add(outputs[4], &res[0].A0) - s.A0.Mul(&res[0].A1, inputs[1]) - outputs[4].Add(outputs[4], &s.A0). - Div(outputs[4], cc.fr) - - // set the signs - outputs[5].SetUint64(0) - outputs[6].SetUint64(0) - outputs[7].SetUint64(0) - outputs[8].SetUint64(0) - outputs[9].SetUint64(0) - - if outputs[0].Sign() == -1 { - outputs[0].Neg(outputs[0]) + if u1.Sign() < 0 { outputs[5].SetUint64(1) } - if outputs[1].Sign() == -1 { - outputs[1].Neg(outputs[1]) + if u2.Sign() < 0 { outputs[6].SetUint64(1) } - if outputs[2].Sign() == -1 { - outputs[2].Neg(outputs[2]) + if v1.Sign() < 0 { outputs[7].SetUint64(1) } - if outputs[3].Sign() == -1 { - outputs[3].Neg(outputs[3]) + if v2.Sign() < 0 { outputs[8].SetUint64(1) } - if outputs[4].Sign() == -1 { - outputs[4].Neg(outputs[4]) + if q.Sign() < 0 { outputs[9].SetUint64(1) } - return nil } diff --git a/std/algebra/native/twistededwards/curve.go b/std/algebra/native/twistededwards/curve.go index 4dfcf10e09..1fab86ede1 100644 --- a/std/algebra/native/twistededwards/curve.go +++ b/std/algebra/native/twistededwards/curve.go @@ -10,6 +10,7 @@ type curve struct { api frontend.API id twistededwards.ID params *CurveParams + endo *EndoParams // non-nil iff the curve has a GLV endomorphism (Bandersnatch) } func (c *curve) Params() *CurveParams { @@ -44,8 +45,29 @@ func (c *curve) ScalarMul(p1 Point, scalar frontend.Variable) Point { p.scalarMul(c.api, &p1, scalar, c.params) return p } + +// DoubleBaseScalarMul computes s1*p1 + s2*p2. It is complete for all scalar +// inputs, including zero, and for identity points. func (c *curve) DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point { var p Point p.doubleBaseScalarMul(c.api, &p1, &p2, s1, s2, c.params) return p } + +// DoubleBaseScalarMulNonZero computes s1*p1 + s2*p2 using the most efficient +// lattice-based MSM variant available for the curve: +// - GLV-equipped curves (Bandersnatch): 6-MSM with r^(1/3)-bounded sub-scalars. +// - non-GLV curves (Jubjub, BabyJubjub, edBLS12-377, edBW6-761): 3-MSM with +// r^(2/3)-bounded sub-scalars and LogUp lookups. +// +// The scalars s1, s2 must be nonzero and p1, p2 must not be the TE identity +// (0, 1). Use DoubleBaseScalarMul for complete edge-case handling. +func (c *curve) DoubleBaseScalarMulNonZero(p1, p2 Point, s1, s2 frontend.Variable) Point { + var p Point + if c.endo != nil { + p.doubleBaseScalarMul6MSMLogUp(c.api, &p1, &p2, s1, s2, c.params, c.endo) + } else { + p.doubleBaseScalarMul3MSMLogUp(c.api, &p1, &p2, s1, s2, c.params) + } + return p +} diff --git a/std/algebra/native/twistededwards/curve_test.go b/std/algebra/native/twistededwards/curve_test.go index b7715e5c95..b979ad59f4 100644 --- a/std/algebra/native/twistededwards/curve_test.go +++ b/std/algebra/native/twistededwards/curve_test.go @@ -191,6 +191,22 @@ func (circuit *doubleBaseScalarMulCircuit) Define(api frontend.API) error { return nil } +type doubleBaseScalarMulNonZeroCircuit struct { + curveID twistededwards.ID + P1, P2 Point + S1, S2 frontend.Variable + Result Point +} + +func (circuit *doubleBaseScalarMulNonZeroCircuit) Define(api frontend.API) error { + curve, err := NewEdCurve(api, circuit.curveID) + if err != nil { + return err + } + assertPointEqual(api, curve.DoubleBaseScalarMulNonZero(circuit.P1, circuit.P2, circuit.S1, circuit.S2), circuit.Result) + return nil +} + func TestAdd(t *testing.T) { for _, curveID := range curves { params, err := GetCurveParams(curveID) @@ -296,6 +312,21 @@ func TestDoubleBaseScalarMul(t *testing.T) { } } +func TestDoubleBaseScalarMulNonZero(t *testing.T) { + for _, curveID := range curves { + params, err := GetCurveParams(curveID) + if err != nil { + t.Fatalf("%s: get curve params: %v", curveLabel(curveID), err) + } + data := randomTestData(params, curveID) + circuit := &doubleBaseScalarMulNonZeroCircuit{curveID: curveID} + witness := &doubleBaseScalarMulNonZeroCircuit{P1: data.P1, P2: data.P2, S1: data.S1, S2: data.S2, Result: data.DoubleScalarMulResult} + invalidWitness := *witness + invalidWitness.Result = offCurvePoint() + checkCircuitForCurve(t, curveID, circuit, witness, &invalidWitness) + } +} + func TestAddEdgeCases(t *testing.T) { for _, curveID := range curves { params, err := GetCurveParams(curveID) @@ -380,6 +411,9 @@ func TestFixedScalarMulEdgeCases(t *testing.T) { } } +// TestDoubleBaseScalarMulEdgeCases covers the complete public method, including +// zero scalars and identity points. The optimized NonZero variant is tested +// separately. func TestDoubleBaseScalarMulEdgeCases(t *testing.T) { for _, curveID := range curves { params, err := GetCurveParams(curveID) @@ -387,11 +421,14 @@ func TestDoubleBaseScalarMulEdgeCases(t *testing.T) { t.Fatalf("%s: get curve params: %v", curveLabel(curveID), err) } data := testDataForScalars(params, curveID, big.NewInt(1), big.NewInt(2)) + base := Point{X: params.Base[0], Y: params.Base[1]} t.Run(curveLabel(curveID), func(t *testing.T) { - assertSolvedForCurve(t, curveID, &doubleBaseScalarMulCircuit{curveID: curveID}, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 0, S2: 0, Result: identityPoint()}) - assertSolvedForCurve(t, curveID, &doubleBaseScalarMulCircuit{curveID: curveID}, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 1, S2: 0, Result: data.P1}) - assertSolvedForCurve(t, curveID, &doubleBaseScalarMulCircuit{curveID: curveID}, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 0, S2: 1, Result: data.P2}) + circuit := &doubleBaseScalarMulCircuit{curveID: curveID} + assertSolvedForCurve(t, curveID, circuit, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 0, S2: 0, Result: identityPoint()}) + assertSolvedForCurve(t, curveID, circuit, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 1, S2: 0, Result: data.P1}) + assertSolvedForCurve(t, curveID, circuit, &doubleBaseScalarMulCircuit{P1: data.P1, P2: data.P2, S1: 0, S2: 1, Result: data.P2}) + assertSolvedForCurve(t, curveID, circuit, &doubleBaseScalarMulCircuit{P1: identityPoint(), P2: base, S1: 1, S2: 2, Result: data.P2}) }) } } @@ -626,12 +663,12 @@ func (c *scalarMulFakeGLVRegressionCircuit) Define(api frontend.API) error { return nil } -func zeroHalfGCDHint(_ *big.Int, inputs, outputs []*big.Int) error { +func zeroRationalReconstructHint(_ *big.Int, inputs, outputs []*big.Int) error { if len(inputs) != 2 { return errors.New("expecting two inputs") } - if len(outputs) != 4 { - return errors.New("expecting four outputs") + if len(outputs) != 3 { + return errors.New("expecting three outputs") } for i := range outputs { outputs[i].SetUint64(0) @@ -640,8 +677,9 @@ func zeroHalfGCDHint(_ *big.Int, inputs, outputs []*big.Int) error { } // This is a regression for a soundness issue in scalarMulFakeGLV. A malicious -// halfGCD hint can return the trivial decomposition s1=s2=0, which makes the -// internal accumulator check vacuous and lets any scalar-mul hint output pass. +// rationalReconstruct hint can return the trivial decomposition s1=s2=0, +// which makes the internal accumulator check vacuous and lets any scalar-mul +// hint output pass. func TestScalarMulFakeGLVRegressionTrivialDecomposition(t *testing.T) { assert := require.New(t) @@ -654,7 +692,100 @@ func TestScalarMulFakeGLVRegressionTrivialDecomposition(t *testing.T) { &scalarMulFakeGLVRegressionCircuit{}, &witness, ecc.BN254.ScalarField(), - test.WithReplacementHint(solver.GetHintID(halfGCD), zeroHalfGCDHint), + test.WithReplacementHint(solver.GetHintID(rationalReconstruct), zeroRationalReconstructHint), + ) + assert.Error(err) +} + +func forgedBN254DoubleBaseScalarMulHint(_ *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 7 { + return errors.New("expecting seven inputs") + } + if len(outputs) != 4 { + return errors.New("expecting four outputs") + } + var p1, p2, q1, q2 tbn254.PointAffine + p1.X.SetBigInt(inputs[0]) + p1.Y.SetBigInt(inputs[1]) + p2.X.SetBigInt(inputs[3]) + p2.Y.SetBigInt(inputs[4]) + q1.ScalarMultiplication(&p1, inputs[2]) + q2.ScalarMultiplication(&p2, inputs[5]) + + var delta tbn254.PointAffine + delta.Set(&p1) + + var q1Hint tbn254.PointAffine + q1Hint.Add(&q1, &delta) + + q1Hint.X.BigInt(outputs[0]) + q1Hint.Y.BigInt(outputs[1]) + q2.X.BigInt(outputs[2]) + q2.Y.BigInt(outputs[3]) + return nil +} + +func forgedBN254DoubleBaseResult(params *CurveParams, s1, s2 *big.Int) (Point, error) { + var p1, p2 tbn254.PointAffine + p1.X.SetBigInt(params.Base[0]) + p1.Y.SetBigInt(params.Base[1]) + p2.Set(&p1) + p1.ScalarMultiplication(&p1, s1) + p2.ScalarMultiplication(&p2, s2) + + p1X, p1Y := new(big.Int), new(big.Int) + p2X, p2Y := new(big.Int), new(big.Int) + p1.X.BigInt(p1X) + p1.Y.BigInt(p1Y) + p2.X.BigInt(p2X) + p2.Y.BigInt(p2Y) + + inputs := []*big.Int{ + p1X, + p1Y, + new(big.Int).Set(s1), + p2X, + p2Y, + new(big.Int).Set(s2), + new(big.Int).Set(params.Order), + } + outputs := []*big.Int{new(big.Int), new(big.Int), new(big.Int), new(big.Int)} + if err := forgedBN254DoubleBaseScalarMulHint(nil, inputs, outputs); err != nil { + return Point{}, err + } + var q1, q2, r tbn254.PointAffine + q1.X.SetBigInt(outputs[0]) + q1.Y.SetBigInt(outputs[1]) + q2.X.SetBigInt(outputs[2]) + q2.Y.SetBigInt(outputs[3]) + r.Add(&q1, &q2) + rX, rY := new(big.Int), new(big.Int) + r.X.BigInt(rX) + r.Y.BigInt(rY) + return Point{X: rX, Y: rY}, nil +} + +func TestDoubleBaseScalarMulNonZeroRejectsForgedPartialHints(t *testing.T) { + assert := require.New(t) + params, err := GetCurveParams(twistededwards.BN254) + assert.NoError(err) + + data := testDataForScalars(params, twistededwards.BN254, big.NewInt(5), big.NewInt(7)) + forged, err := forgedBN254DoubleBaseResult(params, data.S1, data.S2) + assert.NoError(err) + + witness := doubleBaseScalarMulNonZeroCircuit{ + P1: data.P1, + P2: data.P2, + S1: data.S1, + S2: data.S2, + Result: forged, + } + err = test.IsSolved( + &doubleBaseScalarMulNonZeroCircuit{curveID: twistededwards.BN254}, + &witness, + ecc.BN254.ScalarField(), + test.WithReplacementHint(solver.GetHintID(doubleBaseScalarMulHint), forgedBN254DoubleBaseScalarMulHint), ) assert.Error(err) } diff --git a/std/algebra/native/twistededwards/emulatedparams.go b/std/algebra/native/twistededwards/emulatedparams.go new file mode 100644 index 0000000000..c24d016a89 --- /dev/null +++ b/std/algebra/native/twistededwards/emulatedparams.go @@ -0,0 +1,64 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +package twistededwards + +import "math/big" + +// Emulated field parameters for twisted Edwards curve orders. +// These are used for overflow-safe scalar decomposition verification. + +// edBN254Order is the BabyJubjub curve order (251 bits). +type edBN254Order struct{} + +func (edBN254Order) NbLimbs() uint { return 4 } +func (edBN254Order) BitsPerLimb() uint { return 64 } +func (edBN254Order) IsPrime() bool { return true } +func (edBN254Order) Modulus() *big.Int { + r, _ := new(big.Int).SetString("2736030358979909402780800718157159386076813972158567259200215660948447373041", 10) + return r +} + +// edBLS12381Order is the Jubjub curve order (252 bits). +type edBLS12381Order struct{} + +func (edBLS12381Order) NbLimbs() uint { return 4 } +func (edBLS12381Order) BitsPerLimb() uint { return 64 } +func (edBLS12381Order) IsPrime() bool { return true } +func (edBLS12381Order) Modulus() *big.Int { + r, _ := new(big.Int).SetString("6554484396890773809930967563523245729705921265872317281365359162392183254199", 10) + return r +} + +// edBandersnatchOrder is the Bandersnatch curve order (253 bits). +type edBandersnatchOrder struct{} + +func (edBandersnatchOrder) NbLimbs() uint { return 4 } +func (edBandersnatchOrder) BitsPerLimb() uint { return 64 } +func (edBandersnatchOrder) IsPrime() bool { return true } +func (edBandersnatchOrder) Modulus() *big.Int { + r, _ := new(big.Int).SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + return r +} + +// edBLS12377Order is the BLS12-377 twisted Edwards curve order (251 bits). +type edBLS12377Order struct{} + +func (edBLS12377Order) NbLimbs() uint { return 4 } +func (edBLS12377Order) BitsPerLimb() uint { return 64 } +func (edBLS12377Order) IsPrime() bool { return true } +func (edBLS12377Order) Modulus() *big.Int { + r, _ := new(big.Int).SetString("2111115437357092606062206234695386632838870926408408195193685246394721360383", 10) + return r +} + +// edBW6761Order is the BW6-761 twisted Edwards curve order (374 bits). +type edBW6761Order struct{} + +func (edBW6761Order) NbLimbs() uint { return 6 } +func (edBW6761Order) BitsPerLimb() uint { return 64 } +func (edBW6761Order) IsPrime() bool { return true } +func (edBW6761Order) Modulus() *big.Int { + r, _ := new(big.Int).SetString("32333053251621136751331591711861691692049189094364332567435817881934511297123972799646723302813083835942624121493", 10) + return r +} diff --git a/std/algebra/native/twistededwards/hints.go b/std/algebra/native/twistededwards/hints.go index b83564f213..1bffcdb200 100644 --- a/std/algebra/native/twistededwards/hints.go +++ b/std/algebra/native/twistededwards/hints.go @@ -5,6 +5,7 @@ import ( "math/big" "sync" + "github.com/consensys/gnark-crypto/algebra/lattice" "github.com/consensys/gnark-crypto/ecc" edbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/twistededwards" "github.com/consensys/gnark-crypto/ecc/bls12-381/bandersnatch" @@ -16,9 +17,12 @@ import ( func GetHints() []solver.Hint { return []solver.Hint{ - halfGCD, + rationalReconstruct, scalarMulHint, decomposeScalar, + doubleBaseScalarMulHint, + multiRationalReconstructHint, + multiRationalReconstructExtHint, } } @@ -63,40 +67,48 @@ func decomposeScalar(scalarField *big.Int, inputs []*big.Int, res []*big.Int) er return nil } -func halfGCD(mod *big.Int, inputs, outputs []*big.Int) error { +// rationalReconstruct decomposes a scalar s ∈ Fr into (s1, s2, signBit) such +// that s1 + s2·s = 0 mod r, with |s1|, |s2| < γ₂·√r ≈ 1.15·√r +// (proven LLL/Hermite bound). Replaces the older heuristic-bound HalfGCD. +// +// The bit-decomposition convention: s1 ≥ 0 always, s2 = ±|s2| with signBit = 1 +// iff the underlying signed s2 was negative. +func rationalReconstruct(_ *big.Int, inputs, outputs []*big.Int) error { if len(inputs) != 2 { - return errors.New("expecting two inputs") + return errors.New("expecting two inputs (s, r)") } - if len(outputs) != 4 { - return errors.New("expecting four outputs") + if len(outputs) != 3 { + return errors.New("expecting three outputs (s1, |s2|, signBit)") } - // using PrecomputeLattice for scalar decomposition is a hack and it doesn't - // work in case the scalar is zero. override it for now to avoid division by - // zero until a long-term solution is found. + // Zero scalar: trivial (s1=s2=0). The in-circuit IsZero(s2)=0 guard + // rejects this; the caller must pre-route scalar=1 (mirrors the existing + // scalarMulFakeGLV: checkedScalar = Select(isScalarZero, 1, scalar)). if inputs[0].Sign() == 0 { - outputs[0].SetUint64(0) - outputs[1].SetUint64(0) - outputs[2].SetUint64(0) - outputs[3].SetUint64(0) + for i := range outputs { + outputs[i].SetUint64(0) + } return nil } - glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(inputs[1], inputs[0], glvBasis) - outputs[0].Set(&glvBasis.V1[0]) - outputs[1].Set(&glvBasis.V1[1]) - // figure out how many times we have overflowed - // s2 * s + s1 = k*r - outputs[3].Mul(outputs[1], inputs[0]). - Add(outputs[3], outputs[0]). - Div(outputs[3], inputs[1]) + // lattice.RationalReconstruct returns (x, z) with x ≡ z·s mod r. + // Map onto our convention: s1 + s2·s = 0 mod r ⇒ s1 = x, s2 = −z. + res := lattice.RationalReconstruct(inputs[0], inputs[1]) + x, z := new(big.Int).Set(res[0]), new(big.Int).Set(res[1]) + // Normalise so s1 ≥ 0. Flipping signs of (x, z) preserves x − z·s = m·r + // (with m negated). + if x.Sign() < 0 { + x.Neg(x) + z.Neg(z) + } + outputs[0].Set(x) // s1 = x ≥ 0 + + // s2 = −z, encoded as |s2| + signBit. signBit = 1 iff −z < 0 iff z > 0. + outputs[1].Abs(z) outputs[2].SetUint64(0) - if outputs[1].Sign() == -1 { - outputs[1].Neg(outputs[1]) + if z.Sign() > 0 { outputs[2].SetUint64(1) } - return nil } @@ -151,3 +163,191 @@ func scalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error } return nil } + +// doubleBaseScalarMulHint computes [s1]P1 and [s2]P2 separately and returns +// their (X, Y) coords. Inputs: P1.X, P1.Y, s1, P2.X, P2.Y, s2, order. +// Outputs: Q1.X, Q1.Y, Q2.X, Q2.Y where Q1=[s1]P1 and Q2=[s2]P2. +// +// Used by `doubleBaseScalarMul3MSMLogUp` and `doubleBaseScalarMul6MSMLogUp` to +// hint the result that the in-circuit MSM verifies. +func doubleBaseScalarMulHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 7 { + return errors.New("expecting seven inputs") + } + if len(outputs) != 4 { + return errors.New("expecting four outputs") + } + if field.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + order, _ := new(big.Int).SetString("13108968793781547619861935127046491459309155893440570251786403306729687672801", 10) + if inputs[6].Cmp(order) == 0 { + var P1, P2 bandersnatch.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else { + var P1, P2 jubjub.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } + } else if field.Cmp(ecc.BN254.ScalarField()) == 0 { + var P1, P2 babyjubjub.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else if field.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + var P1, P2 edbls12377.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else if field.Cmp(ecc.BW6_761.ScalarField()) == 0 { + var P1, P2 edbw6761.PointAffine + P1.X.SetBigInt(inputs[0]) + P1.Y.SetBigInt(inputs[1]) + P1.ScalarMultiplication(&P1, inputs[2]) + P2.X.SetBigInt(inputs[3]) + P2.Y.SetBigInt(inputs[4]) + P2.ScalarMultiplication(&P2, inputs[5]) + P1.X.BigInt(outputs[0]) + P1.Y.BigInt(outputs[1]) + P2.X.BigInt(outputs[2]) + P2.Y.BigInt(outputs[3]) + } else { + return errors.New("doubleBaseScalarMulHint: unknown curve") + } + return nil +} + +// multiRationalReconstructHint decomposes (k1, k2) jointly via 3-D LLL +// reconstruction: finds (x1, x2, z) with a shared denominator z such that +// +// k1 ≡ x1 / z (mod r) +// k2 ≡ x2 / z (mod r) +// +// with each component bounded by ~r^(2/3). Used by the non-GLV +// `doubleBaseScalarMul3MSMLogUp` path. +// +// inputs: k1, k2, order +// outputs[0..2]: |x1|, |x2|, |z| +// outputs[3..5]: signX1, signX2, signZ +func multiRationalReconstructHint(_ *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 3 { + return errors.New("expecting three inputs: k1, k2, order") + } + if len(outputs) != 6 { + return errors.New("expecting six outputs") + } + k1, k2, order := inputs[0], inputs[1], inputs[2] + + if k1.Sign() == 0 && k2.Sign() == 0 { + for i := range outputs { + outputs[i].SetUint64(0) + } + return nil + } + + res := lattice.NewReconstructor(order).MultiRationalReconstruct(k1, k2) + x1, x2, z := res[0], res[1], res[2] + + outputs[0].Abs(x1) + outputs[1].Abs(x2) + outputs[2].Abs(z) + + setSign := func(out *big.Int, val *big.Int) { + if val.Sign() < 0 { + out.SetUint64(1) + } else { + out.SetUint64(0) + } + } + setSign(outputs[3], x1) + setSign(outputs[4], x2) + setSign(outputs[5], z) + + return nil +} + +// multiRationalReconstructExtHint decomposes (k1, k2) jointly via 6-D LLL +// reconstruction: finds (x1, y1, x2, y2, z, t) with shared denominator +// (z + λ·t) such that +// +// k1 ≡ (x1 + λ·y1) / (z + λ·t) (mod r) +// k2 ≡ (x2 + λ·y2) / (z + λ·t) (mod r) +// +// with each component bounded by ~r^(1/3). Used by the GLV-curve +// `doubleBaseScalarMul6MSMLogUp` path. +// +// inputs: k1, k2, order, lambda +// outputs[0..5]: |x1|, |y1|, |x2|, |y2|, |z|, |t| +// outputs[6..11]: signX1, signY1, signX2, signY2, signZ, signT +func multiRationalReconstructExtHint(_ *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) != 4 { + return errors.New("expecting four inputs: k1, k2, order, lambda") + } + if len(outputs) != 12 { + return errors.New("expecting 12 outputs") + } + k1, k2, order, lambda := inputs[0], inputs[1], inputs[2], inputs[3] + + if k1.Sign() == 0 && k2.Sign() == 0 { + for i := range outputs { + outputs[i].SetUint64(0) + } + return nil + } + + rc := lattice.NewReconstructor(order).SetLambda(lambda) + res := rc.MultiRationalReconstructExt(k1, k2) + x1, y1, x2, y2, z, t := res[0], res[1], res[2], res[3], res[4], res[5] + + outputs[0].Abs(x1) + outputs[1].Abs(y1) + outputs[2].Abs(x2) + outputs[3].Abs(y2) + outputs[4].Abs(z) + outputs[5].Abs(t) + + setSign := func(out *big.Int, val *big.Int) { + if val.Sign() < 0 { + out.SetUint64(1) + } else { + out.SetUint64(0) + } + } + setSign(outputs[6], x1) + setSign(outputs[7], y1) + setSign(outputs[8], x2) + setSign(outputs[9], y2) + setSign(outputs[10], z) + setSign(outputs[11], t) + + return nil +} diff --git a/std/algebra/native/twistededwards/point.go b/std/algebra/native/twistededwards/point.go index de8f1b8967..762dc84ce9 100644 --- a/std/algebra/native/twistededwards/point.go +++ b/std/algebra/native/twistededwards/point.go @@ -3,7 +3,10 @@ package twistededwards -import "github.com/consensys/gnark/frontend" +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/lookup/logderivlookup" +) // neg computes the negative of a point in SNARK coordinates func (p *Point) neg(api frontend.API, p1 *Point) *Point { @@ -131,26 +134,18 @@ func (p *Point) scalarMulFakeGLV(api frontend.API, p1 *Point, scalar frontend.Va checkedScalar := api.Select(isScalarZero, 1, scalar) // the hints allow to decompose the scalar s into s1 and s2 such that - // s1 + s * s2 == 0 mod Order, - s, err := api.NewHint(halfGCD, 4, checkedScalar, curve.Order) + // s1 + s * s2 == 0 mod Order. Uses LLL-based lattice rational + // reconstruction with proven Hermite bound |s1|, |s2| < γ₂·√r ≈ 1.15·√r + // (see [EEMP25] / gnark-crypto/algebra/lattice). + s, err := api.NewHint(rationalReconstruct, 3, checkedScalar, curve.Order) if err != nil { // err is non-nil only for invalid number of inputs panic(err) } - s1, s2, bit, k := s[0], s[1], s[2], s[3] + s1, s2, bit := s[0], s[1], s[2] - // check that s1 + s2 * s == k*Order - _s2 := api.Mul(s2, checkedScalar) - _k := api.Mul(k, curve.Order) - lhs := api.Select(bit, s1, api.Add(s1, _s2)) - rhs := api.Select(bit, api.Add(_k, _s2), _k) - api.AssertIsEqual(lhs, rhs) - // A malicious hint can provide s1=s2=0, which makes the relation vacuous. - api.AssertIsEqual(api.IsZero(s2), 0) - - n := (curve.Order.BitLen() + 1) / 2 - b1 := api.ToBinary(s1, n) - b2 := api.ToBinary(s2, n) + b1, b2 := verifyScalarDecomposition(api, s1, s2, bit, checkedScalar, curve) + n := len(b1) var res, p2, p3, tmp Point q, err := api.NewHint(scalarMulHint, 2, p1.X, p1.Y, checkedScalar, curve.Order) @@ -181,3 +176,321 @@ func (p *Point) scalarMulFakeGLV(api frontend.API, p1 *Point, scalar frontend.Va return p } + +// phi is the GLV endomorphism on Bandersnatch: (x, y) → ((1-y²)·E1/(x·y), +// (y²+E0)·E0/(y²-E0)) acts as scalar multiplication by Lambda on the prime- +// order subgroup. Used by `doubleBaseScalarMul6MSMLogUp` only. +func (p *Point) phi(api frontend.API, p1 *Point, curve *CurveParams, endo *EndoParams) *Point { + xy := api.Mul(p1.X, p1.Y) + yy := api.Mul(p1.Y, p1.Y) + f := api.Sub(1, yy) + f = api.Mul(f, endo.Endo[1]) + g := api.Add(yy, endo.Endo[0]) + g = api.Mul(g, endo.Endo[0]) + h := api.Sub(yy, endo.Endo[0]) + + p.X = api.DivUnchecked(f, xy) + p.Y = api.DivUnchecked(g, h) + return p +} + +// doubleBaseScalarMul3MSMLogUp computes s1*P1+s2*P2 using MultiRationalReconstruct. +// This decomposes both scalars with a shared denominator in Z, giving +// ~r^(2/3)-bit scalars. It verifies [x1]P1 + [x2]P2 - [z]R = O where +// R = [s1]P1 + [s2]P2 is hinted. +func (p *Point) doubleBaseScalarMul3MSMLogUp(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve *CurveParams) *Point { + // Get hinted results Q1 = [s1]P1 and Q2 = [s2]P2 + q, err := api.NewHint(doubleBaseScalarMulHint, 4, p1.X, p1.Y, s1, p2.X, p2.Y, s2, curve.Order) + if err != nil { + panic(err) + } + var Q1, Q2 Point + Q1.X, Q1.Y = q[0], q[1] + Q2.X, Q2.Y = q[2], q[3] + + var R Point + R.add(api, &Q1, &Q2, curve) + + // Decompose (s1, s2) into (x1, x2, z) such that + // s1*z ≡ x1 and s2*z ≡ x2 (mod Order). + h, err := api.NewHint(multiRationalReconstructHint, 6, s1, s2, curve.Order) + if err != nil { + panic(err) + } + absX1, absX2, absZ := h[0], h[1], h[2] + signX1, signX2, signZ := h[3], h[4], h[5] + + // Verify the decomposition using emulated arithmetic to avoid native field + // overflow. Also range-checks x1, x2, z and ensures z is non-zero. + bX1, bX2, bZ := verifyScalarDecomposition3D(api, s1, s2, absX1, absX2, absZ, signX1, signX2, signZ, curve) + + var sP1, sP2, sR Point + sP1.X = api.Select(signX1, api.Neg(p1.X), p1.X) + sP1.Y = p1.Y + sP2.X = api.Select(signX2, api.Neg(p2.X), p2.X) + sP2.Y = p2.Y + sR.X = api.Select(signZ, R.X, api.Neg(R.X)) + sR.Y = R.Y + + // Build the 8-entry table for 3-MSM: sP1, sP2, sR. + var table [8]Point + table[0] = Point{X: 0, Y: 1} + table[1] = sP1 + table[2] = sP2 + table[3].add(api, &sP1, &sP2, curve) + table[4] = sR + table[5].add(api, &sP1, &sR, curve) + table[6].add(api, &sP2, &sR, curve) + table[7].add(api, &table[3], &sR, curve) + + // Create LogDerivLookup tables + tableX := logderivlookup.New(api) + tableY := logderivlookup.New(api) + for i := 0; i < 8; i++ { + tableX.Insert(table[i].X) + tableY.Insert(table[i].Y) + } + + n := len(bX1) + + // Compute indices for lookups + indices := make([]frontend.Variable, n) + for i := 0; i < n; i++ { + // index = bX1[i] + 2*bX2[i] + 4*bZ[i] + indices[i] = api.Add(bX1[i], api.Mul(bX2[i], 2), api.Mul(bZ[i], 4)) + } + + // Batch lookup + resX := tableX.Lookup(indices...) + resY := tableY.Lookup(indices...) + + // Initialize accumulator with first entry + var res Point + res.X = resX[n-1] + res.Y = resY[n-1] + + for i := n - 2; i >= 0; i-- { + res.double(api, &res, curve) + var tmp Point + tmp.X = resX[i] + tmp.Y = resY[i] + res.add(api, &res, &tmp, curve) + } + + // Verify accumulator equals identity (0, 1) + api.AssertIsEqual(res.X, 0) + api.AssertIsEqual(res.Y, 1) + + p.X = R.X + p.Y = R.Y + + return p +} + +// doubleBaseScalarMul6MSMLogUp computes s1*P1+s2*P2 using MultiRationalReconstructExt (true 6-MSM). +// This decomposes both scalars with a shared denominator in Z[λ], giving ~r^(1/3)-bit scalars. +// Verifies: [x1]P + [y1]φ(P) + [x2]Q + [y2]φ(Q) = [z]R + [t]φ(R) +// where R = [s1]P + [s2]Q (hinted). +// Only works for curves with efficient endomorphism (e.g., Bandersnatch). +func (p *Point) doubleBaseScalarMul6MSMLogUp(api frontend.API, p1, p2 *Point, s1, s2 frontend.Variable, curve *CurveParams, endo *EndoParams) *Point { + // Get hinted result R = [s1]P + [s2]Q + qHint, err := api.NewHint(doubleBaseScalarMulHint, 4, p1.X, p1.Y, s1, p2.X, p2.Y, s2, curve.Order) + if err != nil { + panic(err) + } + var R Point + // We need Q1 + Q2 = R + var Q1, Q2 Point + Q1.X, Q1.Y = qHint[0], qHint[1] + Q2.X, Q2.Y = qHint[2], qHint[3] + R.add(api, &Q1, &Q2, curve) + + // Decompose (s1, s2) using MultiRationalReconstructExt. Returns + // |x1|, |y1|, |x2|, |y2|, |z|, |t| and their signs. + h, err := api.NewHint(multiRationalReconstructExtHint, 12, s1, s2, curve.Order, endo.Lambda) + if err != nil { + panic(err) + } + absX1, absY1, absX2, absY2, absZ, absT := h[0], h[1], h[2], h[3], h[4], h[5] + signX1, signY1, signX2, signY2, signZ, signT := h[6], h[7], h[8], h[9], h[10], h[11] + + // Verify the decomposition using emulated arithmetic to avoid native field overflow. + // Checks: s_i * (z + λ*t) ≡ x_i + λ*y_i (mod r) for i=1,2 + // Also range-checks sub-scalars and ensures the shared denominator is non-zero. + bX1, bY1, bX2, bY2, bZ, bT := verifyScalarDecomposition6D(api, s1, s2, + absX1, absY1, absX2, absY2, absZ, absT, + signX1, signY1, signX2, signY2, signZ, signT, + curve, endo, + ) + + // Compute φ(P1), φ(P2), φ(R) + var phiP1, phiP2, phiR Point + phiP1.phi(api, p1, curve, endo) + phiP2.phi(api, p2, curve, endo) + phiR.phi(api, &R, curve, endo) + + // Apply signs to create signed points for the 6-MSM + // The verification is: [x1]P + [y1]φ(P) + [x2]Q + [y2]φ(Q) - [z]R - [t]φ(R) = O + // With signs: we negate the point when the sign is 1 + var sP1, sPhiP1, sP2, sPhiP2, sR, sPhiR Point + + // For P1: if signX1 == 1, use -P1, else use P1 + sP1.X = api.Select(signX1, api.Neg(p1.X), p1.X) + sP1.Y = p1.Y + + // For φ(P1): if signY1 == 1, use -φ(P1), else use φ(P1) + sPhiP1.X = api.Select(signY1, api.Neg(phiP1.X), phiP1.X) + sPhiP1.Y = phiP1.Y + + // For P2: if signX2 == 1, use -P2, else use P2 + sP2.X = api.Select(signX2, api.Neg(p2.X), p2.X) + sP2.Y = p2.Y + + // For φ(P2): if signY2 == 1, use -φ(P2), else use φ(P2) + sPhiP2.X = api.Select(signY2, api.Neg(phiP2.X), phiP2.X) + sPhiP2.Y = phiP2.Y + + // For R: we subtract [z]R, so if signZ == 0 (z positive), use -R; if signZ == 1 (z negative), use R + sR.X = api.Select(signZ, R.X, api.Neg(R.X)) + sR.Y = R.Y + + // For φ(R): similarly for t + sPhiR.X = api.Select(signT, phiR.X, api.Neg(phiR.X)) + sPhiR.Y = phiR.Y + + // Build 64-entry table for 6-MSM + // Index = b0 + 2*b1 + 4*b2 + 8*b3 + 16*b4 + 32*b5 + // Points: sP1, sPhiP1, sP2, sPhiP2, sR, sPhiR + var table [64]Point + + // Precompute all 64 combinations + // table[i] = (i&1)*sP1 + ((i>>1)&1)*sPhiP1 + ((i>>2)&1)*sP2 + ((i>>3)&1)*sPhiP2 + ((i>>4)&1)*sR + ((i>>5)&1)*sPhiR + + // Start with identity + table[0] = Point{X: 0, Y: 1} + + // Single points + table[1] = sP1 + table[2] = sPhiP1 + table[4] = sP2 + table[8] = sPhiP2 + table[16] = sR + table[32] = sPhiR + + // 2-combinations + table[3].add(api, &sP1, &sPhiP1, curve) + table[5].add(api, &sP1, &sP2, curve) + table[6].add(api, &sPhiP1, &sP2, curve) + table[9].add(api, &sP1, &sPhiP2, curve) + table[10].add(api, &sPhiP1, &sPhiP2, curve) + table[12].add(api, &sP2, &sPhiP2, curve) + table[17].add(api, &sP1, &sR, curve) + table[18].add(api, &sPhiP1, &sR, curve) + table[20].add(api, &sP2, &sR, curve) + table[24].add(api, &sPhiP2, &sR, curve) + table[33].add(api, &sP1, &sPhiR, curve) + table[34].add(api, &sPhiP1, &sPhiR, curve) + table[36].add(api, &sP2, &sPhiR, curve) + table[40].add(api, &sPhiP2, &sPhiR, curve) + table[48].add(api, &sR, &sPhiR, curve) + + // 3-combinations (build from 2-combinations) + table[7].add(api, &table[3], &sP2, curve) // sP1 + sPhiP1 + sP2 + table[11].add(api, &table[3], &sPhiP2, curve) // sP1 + sPhiP1 + sPhiP2 + table[13].add(api, &table[5], &sPhiP2, curve) // sP1 + sP2 + sPhiP2 + table[14].add(api, &table[6], &sPhiP2, curve) // sPhiP1 + sP2 + sPhiP2 + table[19].add(api, &table[3], &sR, curve) // sP1 + sPhiP1 + sR + table[21].add(api, &table[5], &sR, curve) // sP1 + sP2 + sR + table[22].add(api, &table[6], &sR, curve) // sPhiP1 + sP2 + sR + table[25].add(api, &table[9], &sR, curve) // sP1 + sPhiP2 + sR + table[26].add(api, &table[10], &sR, curve) // sPhiP1 + sPhiP2 + sR + table[28].add(api, &table[12], &sR, curve) // sP2 + sPhiP2 + sR + table[35].add(api, &table[3], &sPhiR, curve) // sP1 + sPhiP1 + sPhiR + table[37].add(api, &table[5], &sPhiR, curve) // sP1 + sP2 + sPhiR + table[38].add(api, &table[6], &sPhiR, curve) // sPhiP1 + sP2 + sPhiR + table[41].add(api, &table[9], &sPhiR, curve) // sP1 + sPhiP2 + sPhiR + table[42].add(api, &table[10], &sPhiR, curve) // sPhiP1 + sPhiP2 + sPhiR + table[44].add(api, &table[12], &sPhiR, curve) // sP2 + sPhiP2 + sPhiR + table[49].add(api, &table[17], &sPhiR, curve) // sP1 + sR + sPhiR + table[50].add(api, &table[18], &sPhiR, curve) // sPhiP1 + sR + sPhiR + table[52].add(api, &table[20], &sPhiR, curve) // sP2 + sR + sPhiR + table[56].add(api, &table[24], &sPhiR, curve) // sPhiP2 + sR + sPhiR + + // 4-combinations + table[15].add(api, &table[7], &sPhiP2, curve) // sP1 + sPhiP1 + sP2 + sPhiP2 + table[23].add(api, &table[7], &sR, curve) // sP1 + sPhiP1 + sP2 + sR + table[27].add(api, &table[11], &sR, curve) // sP1 + sPhiP1 + sPhiP2 + sR + table[29].add(api, &table[13], &sR, curve) // sP1 + sP2 + sPhiP2 + sR + table[30].add(api, &table[14], &sR, curve) // sPhiP1 + sP2 + sPhiP2 + sR + table[39].add(api, &table[7], &sPhiR, curve) // sP1 + sPhiP1 + sP2 + sPhiR + table[43].add(api, &table[11], &sPhiR, curve) // sP1 + sPhiP1 + sPhiP2 + sPhiR + table[45].add(api, &table[13], &sPhiR, curve) // sP1 + sP2 + sPhiP2 + sPhiR + table[46].add(api, &table[14], &sPhiR, curve) // sPhiP1 + sP2 + sPhiP2 + sPhiR + table[51].add(api, &table[19], &sPhiR, curve) // sP1 + sPhiP1 + sR + sPhiR + table[53].add(api, &table[21], &sPhiR, curve) // sP1 + sP2 + sR + sPhiR + table[54].add(api, &table[22], &sPhiR, curve) // sPhiP1 + sP2 + sR + sPhiR + table[57].add(api, &table[25], &sPhiR, curve) // sP1 + sPhiP2 + sR + sPhiR + table[58].add(api, &table[26], &sPhiR, curve) // sPhiP1 + sPhiP2 + sR + sPhiR + table[60].add(api, &table[28], &sPhiR, curve) // sP2 + sPhiP2 + sR + sPhiR + + // 5-combinations + table[31].add(api, &table[15], &sR, curve) // all except sPhiR + table[47].add(api, &table[15], &sPhiR, curve) // all except sR + table[55].add(api, &table[23], &sPhiR, curve) // sP1 + sPhiP1 + sP2 + sR + sPhiR + table[59].add(api, &table[27], &sPhiR, curve) // sP1 + sPhiP1 + sPhiP2 + sR + sPhiR + table[61].add(api, &table[29], &sPhiR, curve) // sP1 + sP2 + sPhiP2 + sR + sPhiR + table[62].add(api, &table[30], &sPhiR, curve) // sPhiP1 + sP2 + sPhiP2 + sR + sPhiR + + // 6-combination (all points) + table[63].add(api, &table[31], &sPhiR, curve) + + // Use LogDerivLookup for the 64-entry table + tableX := logderivlookup.New(api) + tableY := logderivlookup.New(api) + for i := 0; i < 64; i++ { + tableX.Insert(table[i].X) + tableY.Insert(table[i].Y) + } + + n := len(bX1) + + // Compute indices for lookups + indices := make([]frontend.Variable, n) + for i := 0; i < n; i++ { + indices[i] = api.Add( + bX1[i], + api.Mul(bY1[i], 2), + api.Mul(bX2[i], 4), + api.Mul(bY2[i], 8), + api.Mul(bZ[i], 16), + api.Mul(bT[i], 32), + ) + } + + // Batch lookup + lookupX := tableX.Lookup(indices...) + lookupY := tableY.Lookup(indices...) + + // Initialize accumulator with last entry + var acc Point + acc.X = lookupX[n-1] + acc.Y = lookupY[n-1] + + for i := n - 2; i >= 0; i-- { + acc.double(api, &acc, curve) + var tmp Point + tmp.X = lookupX[i] + tmp.Y = lookupY[i] + acc.add(api, &acc, &tmp, curve) + } + + // Verify accumulator equals identity (0, 1) + api.AssertIsEqual(acc.X, 0) + api.AssertIsEqual(acc.Y, 1) + + // Return R (the hinted result) + p.X = R.X + p.Y = R.Y + + return p +} diff --git a/std/algebra/native/twistededwards/scalar_decomp.go b/std/algebra/native/twistededwards/scalar_decomp.go new file mode 100644 index 0000000000..94cac78901 --- /dev/null +++ b/std/algebra/native/twistededwards/scalar_decomp.go @@ -0,0 +1,242 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +package twistededwards + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +// verifyScalarDecomposition checks s1 + s2*scalar ≡ 0 (mod r) using emulated +// arithmetic to avoid native field overflow. The sign bit controls whether +// the relation is s1 + s2*scalar or s1 - s2*scalar. +// +// s1 and s2 are range-checked to nBits via ToBinary inside this function. +// Returns the bit decompositions of s1 and s2. +func verifyScalarDecomposition( + api frontend.API, + s1, s2, bit, scalar frontend.Variable, + curve *CurveParams, +) (s1Bits, s2Bits []frontend.Variable) { + r := curve.Order + n := (r.BitLen() + 1) / 2 + + // Range-check s1, s2 via ToBinary + s1Bits = api.ToBinary(s1, n) + s2Bits = api.ToBinary(s2, n) + + // Dispatch to the correct emulated field based on the curve order + switch { + case r.BitLen() <= 253 && r.Cmp(edBN254Order{}.Modulus()) == 0: + verifyDecompEmulated[edBN254Order](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + case r.BitLen() <= 253 && r.Cmp(edBLS12381Order{}.Modulus()) == 0: + verifyDecompEmulated[edBLS12381Order](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + case r.BitLen() <= 253 && r.Cmp(edBandersnatchOrder{}.Modulus()) == 0: + verifyDecompEmulated[edBandersnatchOrder](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + case r.BitLen() <= 253 && r.Cmp(edBLS12377Order{}.Modulus()) == 0: + verifyDecompEmulated[edBLS12377Order](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + case r.Cmp(edBW6761Order{}.Modulus()) == 0: + verifyDecompEmulated[edBW6761Order](api, s1, s2, bit, scalar, s1Bits, s2Bits, r) + default: + panic(fmt.Sprintf("unsupported twisted Edwards curve order: %s", r.String())) + } + + return s1Bits, s2Bits +} + +func verifyDecompEmulated[T emulated.FieldParams]( + api frontend.API, + s1, s2, bit, scalar frontend.Variable, + s1Bits, s2Bits []frontend.Variable, + r *big.Int, +) { + f, err := emulated.NewField[T](api) + if err != nil { + panic(fmt.Sprintf("failed to create emulated field: %v", err)) + } + + scalarBits := api.ToBinary(scalar, api.Compiler().FieldBitLen()) + + s1Emu := f.FromBits(s1Bits...) + s2Emu := f.FromBits(s2Bits...) + scalarEmu := f.FromBits(scalarBits...) + zero := f.Zero() + + // Compute s2 * scalar mod r + s2s := f.Mul(s2Emu, scalarEmu) + + // Check: s1 ± s2*scalar ≡ 0 (mod r) + // When bit=0: s1 + s2*scalar ≡ 0 → s1 ≡ -s2*scalar + // When bit=1: s1 - s2*scalar ≡ 0 → s1 ≡ s2*scalar + // Equivalently: s1 + Select(bit, -s2s, s2s) ≡ 0 + negS2s := f.Neg(s2s) + term := f.Select(bit, negS2s, s2s) + sum := f.Add(s1Emu, term) + f.AssertIsEqual(sum, zero) + + // Ensure s2 is non-zero to prevent trivial decomposition. + // When scalar=0, s2=0 is legitimate. + scalarIsZero := api.IsZero(scalar) + s2Check := f.Select(scalarIsZero, f.One(), s2Emu) + f.AssertIsDifferent(s2Check, zero) +} + +// verifyScalarDecomposition3D checks a shared-denominator decomposition: +// s1*z ≡ x1 (mod r) and s2*z ≡ x2 (mod r). +// Used by doubleBaseScalarMul3MSMLogUp. +func verifyScalarDecomposition3D( + api frontend.API, + s1, s2 frontend.Variable, + absX1, absX2, absZ frontend.Variable, + signX1, signX2, signZ frontend.Variable, + curve *CurveParams, +) (x1Bits, x2Bits, zBits []frontend.Variable) { + r := curve.Order + n := (2*r.BitLen() + 2) / 3 + + x1Bits = api.ToBinary(absX1, n) + x2Bits = api.ToBinary(absX2, n) + zBits = api.ToBinary(absZ, n) + + switch { + case r.Cmp(edBN254Order{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBN254Order](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + case r.Cmp(edBLS12381Order{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBLS12381Order](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + case r.Cmp(edBandersnatchOrder{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBandersnatchOrder](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + case r.Cmp(edBLS12377Order{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBLS12377Order](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + case r.Cmp(edBW6761Order{}.Modulus()) == 0: + verifyDecomp3DEmulated[edBW6761Order](api, s1, s2, signX1, signX2, signZ, x1Bits, x2Bits, zBits) + default: + panic(fmt.Sprintf("unsupported twisted Edwards curve order: %s", r.String())) + } + + return +} + +func verifyDecomp3DEmulated[T emulated.FieldParams]( + api frontend.API, + s1, s2 frontend.Variable, + signX1, signX2, signZ frontend.Variable, + x1Bits, x2Bits, zBits []frontend.Variable, +) { + f, err := emulated.NewField[T](api) + if err != nil { + panic(fmt.Sprintf("failed to create emulated field: %v", err)) + } + + nativeBits := api.Compiler().FieldBitLen() + s1Bits := api.ToBinary(s1, nativeBits) + s2Bits := api.ToBinary(s2, nativeBits) + + x1Emu := f.FromBits(x1Bits...) + x2Emu := f.FromBits(x2Bits...) + zEmu := f.FromBits(zBits...) + s1Emu := f.FromBits(s1Bits...) + s2Emu := f.FromBits(s2Bits...) + zero := f.Zero() + + x1Signed := f.Select(signX1, f.Neg(x1Emu), x1Emu) + x2Signed := f.Select(signX2, f.Neg(x2Emu), x2Emu) + zSigned := f.Select(signZ, f.Neg(zEmu), zEmu) + + f.AssertIsEqual(f.Mul(s1Emu, zSigned), x1Signed) + f.AssertIsEqual(f.Mul(s2Emu, zSigned), x2Signed) + + f.AssertIsDifferent(zEmu, zero) +} + +// verifyScalarDecomposition6D checks the 6D decomposition for doubleBaseScalarMul6MSMLogUp. +// Verifies: s_i * (z + λ*t) ≡ x_i + λ*y_i (mod r) for i=1,2 +// All verification is done in emulated arithmetic over the curve order to avoid overflow. +func verifyScalarDecomposition6D( + api frontend.API, + s1, s2 frontend.Variable, + absX1, absY1, absX2, absY2, absZ, absT frontend.Variable, + signX1, signY1, signX2, signY2, signZ, signT frontend.Variable, + curve *CurveParams, + endo *EndoParams, +) (x1Bits, y1Bits, x2Bits, y2Bits, zBits, tBits []frontend.Variable) { + r := curve.Order + n := (r.BitLen() + 2) / 3 + + x1Bits = api.ToBinary(absX1, n) + y1Bits = api.ToBinary(absY1, n) + x2Bits = api.ToBinary(absX2, n) + y2Bits = api.ToBinary(absY2, n) + zBits = api.ToBinary(absZ, n) + tBits = api.ToBinary(absT, n) + + switch { + case r.Cmp(edBandersnatchOrder{}.Modulus()) == 0: + verify6DEmulated[edBandersnatchOrder](api, s1, s2, x1Bits, y1Bits, x2Bits, y2Bits, zBits, tBits, + signX1, signY1, signX2, signY2, signZ, signT, r, endo.Lambda) + default: + // Currently only Bandersnatch has an endomorphism. Add other cases as needed. + panic(fmt.Sprintf("unsupported twisted Edwards curve order for 6D decomposition: %s", r.String())) + } + + return +} + +func verify6DEmulated[T emulated.FieldParams]( + api frontend.API, + s1, s2 frontend.Variable, + absX1Bits, absY1Bits, absX2Bits, absY2Bits, absZBits, absTBits []frontend.Variable, + signX1, signY1, signX2, signY2, signZ, signT frontend.Variable, + r, lambda *big.Int, +) { + f, err := emulated.NewField[T](api) + if err != nil { + panic(fmt.Sprintf("failed to create emulated field: %v", err)) + } + + absX1Emu := f.FromBits(absX1Bits...) + absY1Emu := f.FromBits(absY1Bits...) + absX2Emu := f.FromBits(absX2Bits...) + absY2Emu := f.FromBits(absY2Bits...) + absZEmu := f.FromBits(absZBits...) + absTEmu := f.FromBits(absTBits...) + + lambdaEmu := f.NewElement(lambda) + zero := f.Zero() + + // Signed values in emulated field + x1Emu := f.Select(signX1, f.Neg(absX1Emu), absX1Emu) + y1Emu := f.Select(signY1, f.Neg(absY1Emu), absY1Emu) + x2Emu := f.Select(signX2, f.Neg(absX2Emu), absX2Emu) + y2Emu := f.Select(signY2, f.Neg(absY2Emu), absY2Emu) + zEmu := f.Select(signZ, f.Neg(absZEmu), absZEmu) + tEmu := f.Select(signT, f.Neg(absTEmu), absTEmu) + + // d = z + λ*t (mod r) + dComputed := f.Add(zEmu, f.Mul(lambdaEmu, tEmu)) + + // n1 = x1 + λ*y1 (mod r) + n1Computed := f.Add(x1Emu, f.Mul(lambdaEmu, y1Emu)) + + // n2 = x2 + λ*y2 (mod r) + n2Computed := f.Add(x2Emu, f.Mul(lambdaEmu, y2Emu)) + + // s1 * d ≡ n1 (mod r) + nativeBits := api.Compiler().FieldBitLen() + s1Bits := api.ToBinary(s1, nativeBits) + s1Emu := f.FromBits(s1Bits...) + f.AssertIsEqual(f.Mul(s1Emu, dComputed), n1Computed) + + // s2 * d ≡ n2 (mod r) + s2Bits := api.ToBinary(s2, nativeBits) + s2Emu := f.FromBits(s2Bits...) + f.AssertIsEqual(f.Mul(s2Emu, dComputed), n2Computed) + + // Ensure d non-zero (unless both scalars are zero) + bothZero := api.And(api.IsZero(s1), api.IsZero(s2)) + dCheck := f.Select(bothZero, f.One(), dComputed) + f.AssertIsDifferent(dCheck, zero) +} diff --git a/std/algebra/native/twistededwards/twistededwards.go b/std/algebra/native/twistededwards/twistededwards.go index 0e730e93d6..c20cd732fc 100644 --- a/std/algebra/native/twistededwards/twistededwards.go +++ b/std/algebra/native/twistededwards/twistededwards.go @@ -33,6 +33,10 @@ type Curve interface { ScalarMul(p1 Point, scalar frontend.Variable) Point // DoubleBaseScalarMul computes [s1]p1+[s2]p2 for points that lie on the curve. DoubleBaseScalarMul(p1, p2 Point, s1, s2 frontend.Variable) Point + // DoubleBaseScalarMulNonZero computes [s1]p1+[s2]p2 with the optimized + // lattice MSM path. It requires s1, s2 to be nonzero and p1, p2 to be + // non-identity points. + DoubleBaseScalarMulNonZero(p1, p2 Point, s1, s2 frontend.Variable) Point API() frontend.API } @@ -48,6 +52,15 @@ type CurveParams struct { Base [2]*big.Int // base point coordinates } +// EndoParams holds the GLV endomorphism parameters for curves that have one +// (Bandersnatch). The endomorphism Φ(x, y) = ((Endo[0]·(1−y²)/(x·y) + …) acts as +// scalar multiplication by Lambda on the prime-order subgroup. This is used +// only by the `doubleBaseScalarMul6MSMLogUp` MSM(6, n/3) variant. +type EndoParams struct { + Endo [2]*big.Int + Lambda *big.Int +} + // NewEdCurve returns a new Edwards curve func NewEdCurve(api frontend.API, id twistededwards.ID) (Curve, error) { snarkField, err := GetSnarkField(id) @@ -62,8 +75,18 @@ func NewEdCurve(api frontend.API, id twistededwards.ID) (Curve, error) { return nil, err } - // default - return &curve{api: api, params: params, id: id}, nil + var endo *EndoParams + if id == twistededwards.BLS12_381_BANDERSNATCH { + endo = &EndoParams{ + Endo: [2]*big.Int{new(big.Int), new(big.Int)}, + Lambda: new(big.Int), + } + endo.Endo[0].SetString("37446463827641770816307242315180085052603635617490163568005256780843403514036", 10) + endo.Endo[1].SetString("49199877423542878313146170939139662862850515542392585932876811575731455068989", 10) + endo.Lambda.SetString("8913659658109529928382530854484400854125314752504019737736543920008458395397", 10) + } + + return &curve{api: api, params: params, endo: endo, id: id}, nil } func GetCurveParams(id twistededwards.ID) (*CurveParams, error) {