diff --git a/go.mod b/go.mod index 32dfbac22..1df1c3922 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.2.2-0.20260118153501-cba9f5475432 github.com/consensys/compress v0.3.0 - github.com/consensys/gnark-crypto v0.20.1 + github.com/consensys/gnark-crypto v0.20.2-0.20260403203858-2c33f2d1c64f github.com/fxamacker/cbor/v2 v2.9.0 github.com/google/go-cmp v0.7.0 github.com/google/pprof v0.0.0-20260202012954-cb029daf43ef @@ -18,7 +18,7 @@ require ( github.com/rs/zerolog v1.34.0 github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.48.0 - golang.org/x/sync v0.19.0 + golang.org/x/sync v0.20.0 ) require ( @@ -33,10 +33,10 @@ require ( github.com/spf13/cobra v1.10.2 // indirect github.com/spf13/pflag v1.0.9 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/mod v0.33.0 // indirect - golang.org/x/sys v0.41.0 // indirect - golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 // indirect - golang.org/x/tools v0.42.0 // indirect + golang.org/x/mod v0.34.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c // indirect + golang.org/x/tools v0.43.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect ) diff --git a/go.sum b/go.sum index d3fdb86fc..90468d931 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.2.2-0.20260118153501-cba9f5475432 h1:4ACburMEVC+u github.com/consensys/bavard v0.2.2-0.20260118153501-cba9f5475432/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.3.0 h1:HRIcHvWkW9C9req0ZWg7mhYHzBarohXhcszIwHONVkM= github.com/consensys/compress v0.3.0/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.20.1 h1:PXDUBvk8AzhvWowHLWBEAfUQcV1/aZgWIqD6eMpXmDg= -github.com/consensys/gnark-crypto v0.20.1/go.mod h1:RBWrSgy+IDbGR69RRV313th3M/aZU1ubk2om+qHuTSc= +github.com/consensys/gnark-crypto v0.20.2-0.20260403203858-2c33f2d1c64f h1:geOe8olGACn4dFNJAaWNOoPHb0ich/BBX7/7pCohHr4= +github.com/consensys/gnark-crypto v0.20.2-0.20260403203858-2c33f2d1c64f/go.mod h1:NzeBHSZ49bIM7RtrNTYYR2kymTqwvI/A4eTgQlyQc+Q= github.com/consensys/gnark-solidity-checker v0.2.0 h1:i5iUEzNOkUvpaKm23UEe0wajBMwj7NzyT4EI0T2N8WQ= github.com/consensys/gnark-solidity-checker v0.2.0/go.mod h1:cEvl4g5AH+L4qGQLDOVZjqvn5IKZIAZdhSi8zAM6BiY= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= @@ -358,8 +358,8 @@ golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= -golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -424,8 +424,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -476,10 +476,10 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 h1:bTLqdHv7xrGlFbvf5/TXNxy/iUwwdkjhqQTJDjW7aj0= -golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4/go.mod h1:g5NllXBEermZrmR51cJDQxmJUHUOfRAaNyWBM+R+548= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c h1:6a8FdnNk6bTXBjR4AGKFgUKuo+7GnR3FX5L7CbveeZc= +golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c/go.mod h1:TpUTTEp9frx7rTdLpC9gFG9kdI7zVLFTFFlqaH2Cncw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -552,8 +552,8 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= -golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= -golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/smallfields/tinyfield/element.go b/internal/smallfields/tinyfield/element.go index c1f38612e..5bb13b02a 100644 --- a/internal/smallfields/tinyfield/element.go +++ b/internal/smallfields/tinyfield/element.go @@ -121,7 +121,7 @@ func (z *Element) Set(x *Element) *Element { // *big.Int // big.Int // []byte -func (z *Element) SetInterface(i1 interface{}) (*Element, error) { +func (z *Element) SetInterface(i1 any) (*Element, error) { if i1 == nil { return nil, errors.New("can't set tinyfield.Element with ") } @@ -417,7 +417,7 @@ func BatchInvert(a []Element) []Element { zeroes := bitset.New(uint(len(a))) accumulator := One() - for i := 0; i < len(a); i++ { + for i := range len(a) { if a[i].IsZero() { zeroes.Set(uint(i)) continue @@ -469,7 +469,7 @@ func Hash(msg, dst []byte, count int) ([]Element, error) { vv := pool.BigInt.Get() res := make([]Element, count) - for i := 0; i < count; i++ { + for i := range count { vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) res[i].SetBigInt(vv) } @@ -700,7 +700,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() // we assume v < q, so even if big.Int words are on 64bits, we can safely cast them to 32bits - for i := 0; i < len(vBits); i++ { + for i := range len(vBits) { z[i] = uint32(vBits[i]) } @@ -907,6 +907,31 @@ func (z *Element) Sqrt(x *Element) *Element { return nil } +var _bCbrtExponentElement *big.Int + +func init() { + _bCbrtExponentElement, _ = new(big.Int).SetString("1f", 16) +} + +// Cbrt z = ∛x (mod q) +// if the cube root doesn't exist (x is not a cube mod q) +// Cbrt leaves z unchanged and returns nil +func (z *Element) Cbrt(x *Element) *Element { + // q ≡ 2 (mod 3) + // using z = x^((2q-1)/3) (mod q) + z.Exp(*x, _bCbrtExponentElement) + // as we use x^((2q-1)/3), there is no check to do: every element has a unique cube root + return z +} + +// Cube sets z to x^3 and returns z +func (z *Element) Cube(x *Element) *Element { + var t Element + t.Square(x).Mul(&t, x) + z.Set(&t) + return z +} + // Inverse z = x⁻¹ (mod q) // // if x == 0, sets and returns z = x diff --git a/internal/smallfields/tinyfield/element_test.go b/internal/smallfields/tinyfield/element_test.go index fa1aeec33..ac777e36c 100644 --- a/internal/smallfields/tinyfield/element_test.go +++ b/internal/smallfields/tinyfield/element_test.go @@ -34,7 +34,7 @@ func BenchmarkElementSelect(b *testing.B) { y.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for i := range b.N { benchResElement.Select(i%3, &x, &y) } } @@ -44,7 +44,7 @@ func BenchmarkElementSetRandom(b *testing.B) { x.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { x.MustSetRandom() } } @@ -55,7 +55,7 @@ func BenchmarkElementSetBytes(b *testing.B) { bb := x.Bytes() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.SetBytes(bb[:]) } @@ -65,21 +65,21 @@ func BenchmarkElementMulByConstants(b *testing.B) { b.Run("mulBy3", func(b *testing.B) { benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { MulBy3(&benchResElement) } }) b.Run("mulBy5", func(b *testing.B) { benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { MulBy5(&benchResElement) } }) b.Run("mulBy13", func(b *testing.B) { benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { MulBy13(&benchResElement) } }) @@ -91,7 +91,7 @@ func BenchmarkElementInverse(b *testing.B) { benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Inverse(&x) } @@ -102,7 +102,7 @@ func BenchmarkElementButterfly(b *testing.B) { x.MustSetRandom() benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { Butterfly(&x, &benchResElement) } } @@ -113,7 +113,7 @@ func BenchmarkElementExp(b *testing.B) { benchResElement.MustSetRandom() b1, _ := rand.Int(rand.Reader, Modulus()) b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Exp(x, b1) } } @@ -121,7 +121,7 @@ func BenchmarkElementExp(b *testing.B) { func BenchmarkElementDouble(b *testing.B) { benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Double(&benchResElement) } } @@ -131,7 +131,7 @@ func BenchmarkElementAdd(b *testing.B) { x.MustSetRandom() benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Add(&x, &benchResElement) } } @@ -141,7 +141,7 @@ func BenchmarkElementSub(b *testing.B) { x.MustSetRandom() benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Sub(&x, &benchResElement) } } @@ -149,7 +149,7 @@ func BenchmarkElementSub(b *testing.B) { func BenchmarkElementNeg(b *testing.B) { benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Neg(&benchResElement) } } @@ -159,7 +159,7 @@ func BenchmarkElementDiv(b *testing.B) { x.MustSetRandom() benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Div(&x, &benchResElement) } } @@ -167,7 +167,7 @@ func BenchmarkElementDiv(b *testing.B) { func BenchmarkElementFromMont(b *testing.B) { benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.fromMont() } } @@ -175,7 +175,7 @@ func BenchmarkElementFromMont(b *testing.B) { func BenchmarkElementSquare(b *testing.B) { benchResElement.MustSetRandom() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Square(&benchResElement) } } @@ -185,18 +185,27 @@ func BenchmarkElementSqrt(b *testing.B) { a.MustSetRandom() a.Square(&a) b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Sqrt(&a) } } +func BenchmarkElementCbrt(b *testing.B) { + var a Element + a.SetUint64(8) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cbrt(&a) + } +} + func BenchmarkElementMul(b *testing.B) { x := Element{ 25, } benchResElement.SetOne() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Mul(&benchResElement, &x) } } @@ -208,7 +217,7 @@ func BenchmarkElementCmp(b *testing.B) { benchResElement = x benchResElement[0] = 0 b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { benchResElement.Cmp(&x) } } @@ -1241,15 +1250,12 @@ func TestElementSquare(t *testing.T) { func(a testPairElement) bool { var c Element c.Square(&a.element) - var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) - properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( func(a testPairElement) bool { var c Element @@ -1270,7 +1276,6 @@ func TestElementSquare(t *testing.T) { a.BigInt(&aBig) var c Element c.Square(&a) - var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) @@ -1314,15 +1319,12 @@ func TestElementInverse(t *testing.T) { func(a testPairElement) bool { var c Element c.Inverse(&a.element) - var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) - properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( func(a testPairElement) bool { var c Element @@ -1343,7 +1345,6 @@ func TestElementInverse(t *testing.T) { a.BigInt(&aBig) var c Element c.Inverse(&a) - var d, e big.Int d.ModInverse(&aBig, Modulus()) @@ -1387,15 +1388,12 @@ func TestElementSqrt(t *testing.T) { func(a testPairElement) bool { var c Element c.Sqrt(&a.element) - var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) - properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( func(a testPairElement) bool { var c Element @@ -1416,7 +1414,6 @@ func TestElementSqrt(t *testing.T) { a.BigInt(&aBig) var c Element c.Sqrt(&a) - var d, e big.Int d.ModSqrt(&aBig, Modulus()) @@ -1431,6 +1428,103 @@ func TestElementSqrt(t *testing.T) { } +func TestElementCbrt(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Cbrt: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + b := a.element + + b.Cbrt(&a.element) + a.element.Cbrt(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Cbrt: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + // verify that c^3 == a (since there's no big.Int.ModCbrt) + // Cbrt returns nil if the element is not a cubic residue + var c Element + result := c.Cbrt(&a.element) + if result == nil { + // a is not a cubic residue, this is valid + return true + } + var cube, e big.Int + c.BigInt(&e) + cube.Exp(&e, big.NewInt(3), Modulus()) + return cube.Cmp(&a.bigint) == 0 + }, + genA, + )) + properties.Property("Cbrt: cubic residues must always have a cube root", prop.ForAll( + func(a testPairElement) bool { + // b = a³ is guaranteed to be a cubic residue + var b, c Element + b.Square(&a.element).Mul(&b, &a.element) + if c.Cbrt(&b) == nil { + return false + } + var check Element + check.Square(&c).Mul(&check, &c) + return check.Equal(&b) + }, + genA, + )) + + properties.Property("Cbrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Cbrt(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for i := range testValues { + a := testValues[i] + var aBig big.Int + a.BigInt(&aBig) + var c Element + // verify that c^3 == a (since there's no big.Int.ModCbrt) + // Cbrt returns nil if the element is not a cubic residue + result := c.Cbrt(&a) + if result == nil { + // a is not a cubic residue, this is valid, continue + continue + } + var cube, e big.Int + c.BigInt(&e) + cube.Exp(&e, big.NewInt(3), Modulus()) + if cube.Cmp(&aBig) != 0 { + t.Fatal("Cbrt failed for special value") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + func TestElementDouble(t *testing.T) { t.Parallel() parameters := gopter.DefaultTestParameters() @@ -1460,15 +1554,12 @@ func TestElementDouble(t *testing.T) { func(a testPairElement) bool { var c Element c.Double(&a.element) - var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) - properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( func(a testPairElement) bool { var c Element @@ -1489,7 +1580,6 @@ func TestElementDouble(t *testing.T) { a.BigInt(&aBig) var c Element c.Double(&a) - var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) @@ -1533,15 +1623,12 @@ func TestElementNeg(t *testing.T) { func(a testPairElement) bool { var c Element c.Neg(&a.element) - var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) - properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( func(a testPairElement) bool { var c Element @@ -1562,7 +1649,6 @@ func TestElementNeg(t *testing.T) { a.BigInt(&aBig) var c Element c.Neg(&a) - var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) @@ -1957,7 +2043,7 @@ func TestElementBatchInvert(t *testing.T) { for _, t := range tData { a := make([]Element, len(t)) - for i := 0; i < len(a); i++ { + for i := range len(a) { a[i].SetInt64(t[i]) } @@ -1965,7 +2051,7 @@ func TestElementBatchInvert(t *testing.T) { assert.True(len(aInv) == len(a)) - for i := 0; i < len(a); i++ { + for i := range len(a) { if a[i].IsZero() { assert.True(aInv[i].IsZero(), "0⁻¹ != 0") } else { @@ -2002,7 +2088,7 @@ func TestElementBatchInvert(t *testing.T) { assert.True(len(aInv) == len(a)) - for i := 0; i < len(a); i++ { + for i := range len(a) { if a[i].IsZero() { if !aInv[i].IsZero() { return false @@ -2123,7 +2209,7 @@ func TestElementMul2ExpNegN(t *testing.T) { var b, e, two Element var c [33]Element two.SetUint64(2) - for n := 0; n < 33; n++ { + for n := range 33 { e.Exp(two, big.NewInt(int64(n))).Inverse(&e) b.Mul(&a.element, &e) c[n].Mul2ExpNegN(&a.element, uint32(n)) diff --git a/internal/smallfields/tinyfield/vector.go b/internal/smallfields/tinyfield/vector.go index 29a1f6467..dabdddc76 100644 --- a/internal/smallfields/tinyfield/vector.go +++ b/internal/smallfields/tinyfield/vector.go @@ -12,12 +12,12 @@ import ( "fmt" "io" "math/bits" - "runtime" "slices" "strings" - "sync" "sync/atomic" "unsafe" + + "github.com/consensys/gnark-crypto/parallel" ) // Vector represents a slice of Element. @@ -59,7 +59,7 @@ func (vector *Vector) WriteTo(w io.Writer) (int64, error) { n := int64(4) var buf [Bytes]byte - for i := 0; i < len(*vector); i++ { + for i := range len(*vector) { BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) @@ -147,7 +147,7 @@ func (vector *Vector) AsyncReadFrom(r io.Reader) (int64, error, chan error) { // go func() { var cptErrors uint64 // process the elements in parallel - execute(int(headerSliceLen), func(start, end int) { + parallel.Execute(int(headerSliceLen), func(start, end int) { var z Element for i := start; i < end; i++ { @@ -217,7 +217,7 @@ func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { *vector = []Element{} } - for i := uint64(0); i < headerSliceLen; i++ { + for i := range headerSliceLen { read, err := io.ReadFull(r, buf[:]) totalRead += int64(read) if errors.Is(err, io.ErrUnexpectedEOF) { @@ -243,7 +243,7 @@ func (vector *Vector) ReadFrom(r io.Reader) (int64, error) { func (vector Vector) String() string { var sbb strings.Builder sbb.WriteByte('[') - for i := 0; i < len(vector); i++ { + for i := range len(vector) { sbb.WriteString(vector[i].String()) if i != len(vector)-1 { sbb.WriteByte(',') @@ -341,7 +341,7 @@ func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") } - for i := 0; i < len(a); i++ { + for i := range len(a) { res[i].Add(&a[i], &b[i]) } } @@ -350,7 +350,7 @@ func subVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Sub: vectors don't have the same length") } - for i := 0; i < len(a); i++ { + for i := range len(a) { res[i].Sub(&a[i], &b[i]) } } @@ -359,13 +359,13 @@ func scalarMulVecGeneric(res, a Vector, b *Element) { if len(a) != len(res) { panic("vector.ScalarMul: vectors don't have the same length") } - for i := 0; i < len(a); i++ { + for i := range len(a) { res[i].Mul(&a[i], b) } } func sumVecGeneric(res *Element, a Vector) { - for i := 0; i < len(a); i++ { + for i := range len(a) { res.Add(res, &a[i]) } } @@ -375,7 +375,7 @@ func innerProductVecGeneric(res *Element, a, b Vector) { panic("vector.InnerProduct: vectors don't have the same length") } var tmp Element - for i := 0; i < len(a); i++ { + for i := range len(a) { tmp.Mul(&a[i], &b[i]) res.Add(res, &tmp) } @@ -385,60 +385,7 @@ func mulVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Mul: vectors don't have the same length") } - for i := 0; i < len(a); i++ { + for i := range len(a) { res[i].Mul(&a[i], &b[i]) } } - -// TODO @gbotrel make a public package out of that. -// execute executes the work function in parallel. -// this is copy paste from internal/parallel/parallel.go -// as we don't want to generate code importing internal/ -func execute(nbIterations int, work func(int, int), maxCpus ...int) { - - nbTasks := runtime.NumCPU() - if len(maxCpus) == 1 { - nbTasks = maxCpus[0] - if nbTasks < 1 { - nbTasks = 1 - } else if nbTasks > 512 { - nbTasks = 512 - } - } - - if nbTasks == 1 { - // no go routines - work(0, nbIterations) - return - } - - nbIterationsPerCpus := nbIterations / nbTasks - - // more CPUs than tasks: a CPU will work on exactly one iteration - if nbIterationsPerCpus < 1 { - nbIterationsPerCpus = 1 - nbTasks = nbIterations - } - - var wg sync.WaitGroup - - extraTasks := nbIterations - (nbTasks * nbIterationsPerCpus) - extraTasksOffset := 0 - - for i := 0; i < nbTasks; i++ { - wg.Add(1) - _start := i*nbIterationsPerCpus + extraTasksOffset - _end := _start + nbIterationsPerCpus - if extraTasks > 0 { - _end++ - extraTasks-- - extraTasksOffset++ - } - go func() { - work(_start, _end) - wg.Done() - }() - } - - wg.Wait() -} diff --git a/internal/smallfields/tinyfield/vector_test.go b/internal/smallfields/tinyfield/vector_test.go index 46df33275..14116c210 100644 --- a/internal/smallfields/tinyfield/vector_test.go +++ b/internal/smallfields/tinyfield/vector_test.go @@ -183,7 +183,7 @@ func TestVectorOps(t *testing.T) { c := make(Vector, len(a)) c.Add(a, b) - for i := 0; i < len(a); i++ { + for i := range len(a) { var tmp Element tmp.Add(&a[i], &b[i]) if !tmp.Equal(&c[i]) { @@ -197,7 +197,7 @@ func TestVectorOps(t *testing.T) { c := make(Vector, len(a)) c.Sub(a, b) - for i := 0; i < len(a); i++ { + for i := range len(a) { var tmp Element tmp.Sub(&a[i], &b[i]) if !tmp.Equal(&c[i]) { @@ -211,7 +211,7 @@ func TestVectorOps(t *testing.T) { c := make(Vector, len(a)) c.ScalarMul(a, &b) - for i := 0; i < len(a); i++ { + for i := range len(a) { var tmp Element tmp.Mul(&a[i], &b) if !tmp.Equal(&c[i]) { @@ -224,7 +224,7 @@ func TestVectorOps(t *testing.T) { sumVector := func(a Vector) bool { var sum Element computed := a.Sum() - for i := 0; i < len(a); i++ { + for i := range len(a) { sum.Add(&sum, &a[i]) } @@ -234,7 +234,7 @@ func TestVectorOps(t *testing.T) { innerProductVector := func(a, b Vector) bool { computed := a.InnerProduct(b) var innerProduct Element - for i := 0; i < len(a); i++ { + for i := range len(a) { var tmp Element tmp.Mul(&a[i], &b[i]) innerProduct.Add(&innerProduct, &tmp) @@ -249,7 +249,7 @@ func TestVectorOps(t *testing.T) { b[0].SetUint64(0x42) c.Mul(a, b) - for i := 0; i < len(a); i++ { + for i := range len(a) { var tmp Element tmp.Mul(&a[i], &b[i]) if !tmp.Equal(&c[i]) { @@ -335,7 +335,7 @@ func BenchmarkVectorOps(b *testing.B) { _b := b1[:n] _c := c1[:n] b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _c.Add(_a, _b) } }) @@ -345,7 +345,7 @@ func BenchmarkVectorOps(b *testing.B) { _b := b1[:n] _c := c1[:n] b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _c.Sub(_a, _b) } }) @@ -354,7 +354,7 @@ func BenchmarkVectorOps(b *testing.B) { _a := a1[:n] _c := c1[:n] b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _c.ScalarMul(_a, &mixer) } }) @@ -362,7 +362,7 @@ func BenchmarkVectorOps(b *testing.B) { b.Run(fmt.Sprintf("sum %d", n), func(b *testing.B) { _a := a1[:n] b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = _a.Sum() } }) @@ -371,7 +371,7 @@ func BenchmarkVectorOps(b *testing.B) { _a := a1[:n] _b := b1[:n] b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _ = _a.InnerProduct(_b) } }) @@ -381,7 +381,7 @@ func BenchmarkVectorOps(b *testing.B) { _b := b1[:n] _c := c1[:n] b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { _c.Mul(_a, _b) } }) @@ -403,7 +403,7 @@ func genMaxVector(size int) gopter.Gen { qMinusOne := qElement qMinusOne[0]-- - for i := 0; i < size; i++ { + for i := range size { g[i] = qMinusOne } genResult := gopter.NewGenResult(g, gopter.NoShrinker) diff --git a/std/algebra/emulated/maptocurve/doc.go b/std/algebra/emulated/maptocurve/doc.go new file mode 100644 index 000000000..e08dacd41 --- /dev/null +++ b/std/algebra/emulated/maptocurve/doc.go @@ -0,0 +1,13 @@ +// Package maptocurve implements increment-and-check map-to-curve gadgets for +// short Weierstrass curves y² = x³ + ax + b over emulated fields. +// +// Two methods are provided: +// - [Mapper.XIncrement]: encodes X = M·256 + K, verifies the curve equation, and +// ensures Y has a 2^S-th root (inverse-exclusion witness) for j=0 curves. +// Only practical for low 2-adicity fields (S ≤ 4). For high 2-adicity +// fields (e.g. BLS12-377 with S=46, Grumpkin with S=28) the witness search +// becomes infeasible — use [Mapper.YIncrement] instead. +// - [Mapper.YIncrement]: encodes Y = M·256 + K, verifies the curve equation. +// Simpler (no inverse-exclusion witness), works for any 2-adicity, and is +// the recommended method for j=0 curves. +package maptocurve diff --git a/std/algebra/emulated/maptocurve/hints.go b/std/algebra/emulated/maptocurve/hints.go new file mode 100644 index 000000000..4b81e9a9d --- /dev/null +++ b/std/algebra/emulated/maptocurve/hints.go @@ -0,0 +1,365 @@ +package maptocurve + +import ( + "fmt" + "math/big" + + bn254fp "github.com/consensys/gnark-crypto/ecc/bn254/fp" + secp256k1fp "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" + "github.com/consensys/gnark-crypto/ecc/secp256r1" + secp256r1fp "github.com/consensys/gnark-crypto/ecc/secp256r1/fp" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in the package. +func GetHints() []solver.Hint { + return []solver.Hint{ + xIncrementHint, + yIncrementHint, + } +} + +// parseHintInputs extracts the field modulus and message from the hint inputs. +// Format: [nbLimbs, nbBits, q_limbs..., msg_limbs...] +func parseHintInputs(inputs []*big.Int) (q *big.Int, nbLimbs int, nbBits uint, msg *big.Int, err error) { + if len(inputs) < 2 { + return nil, 0, 0, nil, fmt.Errorf("need at least 2 inputs (nbLimbs, nbBits)") + } + nbLimbs = int(inputs[0].Int64()) + nbBits = uint(inputs[1].Int64()) + expected := 2 + 2*nbLimbs + if len(inputs) != expected { + return nil, 0, 0, nil, fmt.Errorf("expected %d inputs, got %d", expected, len(inputs)) + } + q = recompose(inputs[2:2+nbLimbs], nbLimbs, nbBits) + msg = recompose(inputs[2+nbLimbs:2+2*nbLimbs], nbLimbs, nbBits) + return q, nbLimbs, nbBits, msg, nil +} + +// xIncrementHint computes the x-increment witness for a given message. +// +// Inputs: [nbLimbs, nbBits, q_limbs..., msg_limbs...] +// Outputs: [k, x_limbs..., y_limbs..., z_limbs...] +// +// Searches k ∈ [0, T) such that x = msg*T + k lies on the curve and y has a +// 2^s-th root. Only practical for low 2-adicity fields (S ≤ 4). +func xIncrementHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + q, nbLimbs, nbBits, msg, err := parseHintInputs(inputs) + if err != nil { + return fmt.Errorf("xIncrementHint: %w", err) + } + + switch { + case q.Cmp(bn254fp.Modulus()) == 0: + return xIncrementBN254(nbLimbs, nbBits, msg, outputs) + case q.Cmp(secp256k1fp.Modulus()) == 0: + return xIncrementSecp256k1(nbLimbs, nbBits, msg, outputs) + case q.Cmp(secp256r1fp.Modulus()) == 0: + return xIncrementSecp256r1(nbLimbs, nbBits, msg, outputs) + default: + return fmt.Errorf("xIncrementHint: unsupported field modulus") + } +} + +// yIncrementHint computes the y-increment witness for a given message. +// +// Inputs: [nbLimbs, nbBits, q_limbs..., msg_limbs...] +// Outputs: [k, x_limbs...] +// +// For j=0 curves (a=0): x = cbrt(y² - b) where y = msg*T + k. +// For P-256 (a≠0): x is found via Cardano's formula on x³ − 3x + (b − y²) = 0. +func yIncrementHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + q, nbLimbs, nbBits, msg, err := parseHintInputs(inputs) + if err != nil { + return fmt.Errorf("yIncrementHint: %w", err) + } + + switch { + case q.Cmp(bn254fp.Modulus()) == 0: + return yIncrementBN254(nbLimbs, nbBits, msg, outputs) + case q.Cmp(secp256k1fp.Modulus()) == 0: + return yIncrementSecp256k1(nbLimbs, nbBits, msg, outputs) + case q.Cmp(secp256r1fp.Modulus()) == 0: + return yIncrementSecp256r1(nbLimbs, nbBits, msg, outputs) + default: + return fmt.Errorf("yIncrementHint: unsupported field modulus") + } +} + +// --- BN254 --- + +// BN254: y² = x³ + 3, S=1 +func xIncrementBN254(nbLimbs int, nbBits uint, msg *big.Int, outputs []*big.Int) error { + const s = 1 + var msgFp, bFp, tFp, xBase bn254fp.Element + msgFp.SetBigInt(msg) + bFp.SetUint64(3) + tFp.SetUint64(T) + xBase.Mul(&msgFp, &tFp) + + for k := uint64(0); k < T; k++ { + var kFp, x, x2, rhs, y bn254fp.Element + kFp.SetUint64(k) + x.Add(&xBase, &kFp) + + x2.Square(&x) + rhs.Mul(&x2, &x) + rhs.Add(&rhs, &bFp) + + if y.Sqrt(&rhs) == nil { + continue + } + + z := nthRoot2SBN254(&y, s) + if z == nil { + y.Neg(&y) + z = nthRoot2SBN254(&y, s) + if z == nil { + continue + } + } + + var xBig, yBig, zBig big.Int + outputs[0].SetUint64(k) + decompose(x.BigInt(&xBig), nbLimbs, nbBits, outputs[1:1+nbLimbs]) + decompose(y.BigInt(&yBig), nbLimbs, nbBits, outputs[1+nbLimbs:1+2*nbLimbs]) + decompose(z.BigInt(&zBig), nbLimbs, nbBits, outputs[1+2*nbLimbs:1+3*nbLimbs]) + return nil + } + return fmt.Errorf("xIncrementHint: no valid k found for BN254 (s=%d)", s) +} + +func nthRoot2SBN254(a *bn254fp.Element, s int) *bn254fp.Element { + z := new(bn254fp.Element).Set(a) + for i := 0; i < s; i++ { + if z.Sqrt(z) == nil { + return nil + } + } + return z +} + +func yIncrementBN254(nbLimbs int, nbBits uint, msg *big.Int, outputs []*big.Int) error { + var msgFp, bFp, tFp, yBase bn254fp.Element + msgFp.SetBigInt(msg) + bFp.SetUint64(3) + tFp.SetUint64(T) + yBase.Mul(&msgFp, &tFp) + + for k := uint64(0); k < T; k++ { + var kFp, y, y2, rhs, x bn254fp.Element + kFp.SetUint64(k) + y.Add(&yBase, &kFp) + + y2.Square(&y) + rhs.Sub(&y2, &bFp) + + if x.Cbrt(&rhs) == nil { + continue + } + + var xBig big.Int + outputs[0].SetUint64(k) + decompose(x.BigInt(&xBig), nbLimbs, nbBits, outputs[1:1+nbLimbs]) + return nil + } + return fmt.Errorf("yIncrementHint: no valid k found for BN254") +} + +// --- secp256k1 (y² = x³ + 7, a=0, S=1) --- + +// secp256k1: y² = x³ + 7, S=1 +func xIncrementSecp256k1(nbLimbs int, nbBits uint, msg *big.Int, outputs []*big.Int) error { + const s = 1 + var msgFp, bFp, tFp, xBase secp256k1fp.Element + msgFp.SetBigInt(msg) + bFp.SetUint64(7) + tFp.SetUint64(T) + xBase.Mul(&msgFp, &tFp) + + for k := uint64(0); k < T; k++ { + var kFp, x, x2, rhs, y secp256k1fp.Element + kFp.SetUint64(k) + x.Add(&xBase, &kFp) + + x2.Square(&x) + rhs.Mul(&x2, &x) + rhs.Add(&rhs, &bFp) + + if y.Sqrt(&rhs) == nil { + continue + } + + z := nthRoot2SSecp256k1(&y, s) + if z == nil { + y.Neg(&y) + z = nthRoot2SSecp256k1(&y, s) + if z == nil { + continue + } + } + + var xBig, yBig, zBig big.Int + outputs[0].SetUint64(k) + decompose(x.BigInt(&xBig), nbLimbs, nbBits, outputs[1:1+nbLimbs]) + decompose(y.BigInt(&yBig), nbLimbs, nbBits, outputs[1+nbLimbs:1+2*nbLimbs]) + decompose(z.BigInt(&zBig), nbLimbs, nbBits, outputs[1+2*nbLimbs:1+3*nbLimbs]) + return nil + } + return fmt.Errorf("xIncrementHint: no valid k found for secp256k1 (s=%d)", s) +} + +func nthRoot2SSecp256k1(a *secp256k1fp.Element, s int) *secp256k1fp.Element { + z := new(secp256k1fp.Element).Set(a) + for i := 0; i < s; i++ { + if z.Sqrt(z) == nil { + return nil + } + } + return z +} + +func yIncrementSecp256k1(nbLimbs int, nbBits uint, msg *big.Int, outputs []*big.Int) error { + var msgFp, bFp, tFp, yBase secp256k1fp.Element + msgFp.SetBigInt(msg) + bFp.SetUint64(7) + tFp.SetUint64(T) + yBase.Mul(&msgFp, &tFp) + + for k := uint64(0); k < T; k++ { + var kFp, y, y2, rhs, x secp256k1fp.Element + kFp.SetUint64(k) + y.Add(&yBase, &kFp) + + y2.Square(&y) + rhs.Sub(&y2, &bFp) + + if x.Cbrt(&rhs) == nil { + continue + } + + var xBig big.Int + outputs[0].SetUint64(k) + decompose(x.BigInt(&xBig), nbLimbs, nbBits, outputs[1:1+nbLimbs]) + return nil + } + return fmt.Errorf("yIncrementHint: no valid k found for secp256k1") +} + +// --- secp256r1 / P-256 (y² = x³ + ax + b, a≠0, S=1) --- + +// secp256r1 / P-256: y² = x³ - 3x + b, S=1 +func xIncrementSecp256r1(nbLimbs int, nbBits uint, msg *big.Int, outputs []*big.Int) error { + const s = 1 + // a = -3 mod q, b from curve params + p := sw_emulated.GetP256Params() + var msgFp, aFp, bFp, tFp, xBase secp256r1fp.Element + msgFp.SetBigInt(msg) + aFp.SetBigInt(p.A) + bFp.SetBigInt(p.B) + tFp.SetUint64(T) + xBase.Mul(&msgFp, &tFp) + + for k := uint64(0); k < T; k++ { + var kFp, x, x2, rhs, y secp256r1fp.Element + kFp.SetUint64(k) + x.Add(&xBase, &kFp) + + // rhs = x³ + a·x + b + x2.Square(&x) + rhs.Mul(&x2, &x) + var ax secp256r1fp.Element + ax.Mul(&aFp, &x) + rhs.Add(&rhs, &ax) + rhs.Add(&rhs, &bFp) + + if y.Sqrt(&rhs) == nil { + continue + } + + z := nthRoot2SSecp256r1(&y, s) + if z == nil { + y.Neg(&y) + z = nthRoot2SSecp256r1(&y, s) + if z == nil { + continue + } + } + + var xBig, yBig, zBig big.Int + outputs[0].SetUint64(k) + decompose(x.BigInt(&xBig), nbLimbs, nbBits, outputs[1:1+nbLimbs]) + decompose(y.BigInt(&yBig), nbLimbs, nbBits, outputs[1+nbLimbs:1+2*nbLimbs]) + decompose(z.BigInt(&zBig), nbLimbs, nbBits, outputs[1+2*nbLimbs:1+3*nbLimbs]) + return nil + } + return fmt.Errorf("xIncrementHint: no valid k found for secp256r1 (s=%d)", s) +} + +func nthRoot2SSecp256r1(a *secp256r1fp.Element, s int) *secp256r1fp.Element { + z := new(secp256r1fp.Element).Set(a) + for i := 0; i < s; i++ { + if z.Sqrt(z) == nil { + return nil + } + } + return z +} + +// secp256r1 / P-256: y² = x³ − 3x + b, y-increment uses Cardano solver. +func yIncrementSecp256r1(nbLimbs int, nbBits uint, msg *big.Int, outputs []*big.Int) error { + p := sw_emulated.GetP256Params() + var bFp, tFp, msgFp, yBase secp256r1fp.Element + msgFp.SetBigInt(msg) + bFp.SetBigInt(p.B) + tFp.SetUint64(T) + yBase.Mul(&msgFp, &tFp) + + for k := uint64(0); k < T; k++ { + var kFp, y, y2, c secp256r1fp.Element + kFp.SetUint64(k) + y.Add(&yBase, &kFp) + + // x³ − 3x + c = 0 where c = b − y² + y2.Square(&y) + c.Sub(&bFp, &y2) + + roots := secp256r1.CardanoRoots(c) + if len(roots) == 0 { + continue + } + + var xBig big.Int + outputs[0].SetUint64(k) + decompose(roots[0].BigInt(&xBig), nbLimbs, nbBits, outputs[1:1+nbLimbs]) + return nil + } + return fmt.Errorf("yIncrementHint: no valid k found for secp256r1") +} + +// --- limb helpers --- + +// recompose reconstructs a big.Int from its limbs (little-endian, nbBits per limb). +func recompose(limbs []*big.Int, nbLimbs int, nbBits uint) *big.Int { + result := new(big.Int) + for i := nbLimbs - 1; i >= 0; i-- { + result.Lsh(result, nbBits) + result.Add(result, limbs[i]) + } + return result +} + +// decompose splits v into nbLimbs limbs of nbBits each (little-endian). +func decompose(v *big.Int, nbLimbs int, nbBits uint, outputs []*big.Int) { + mask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), nbBits), big.NewInt(1)) + tmp := new(big.Int).Set(v) + for i := 0; i < nbLimbs; i++ { + outputs[i].And(tmp, mask) + tmp.Rsh(tmp, nbBits) + } +} diff --git a/std/algebra/emulated/maptocurve/maptocurve.go b/std/algebra/emulated/maptocurve/maptocurve.go new file mode 100644 index 000000000..2ddb94455 --- /dev/null +++ b/std/algebra/emulated/maptocurve/maptocurve.go @@ -0,0 +1,207 @@ +package maptocurve + +import ( + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/algebra/emulated/sw_emulated" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/rangecheck" +) + +const T = 256 // increment window size (8 bits) + +// Mapper provides increment-and-check map-to-curve operations for emulated +// short Weierstrass curves y² = x³ + ax + b. +type Mapper[F emulated.FieldParams] struct { + api frontend.API + field *emulated.Field[F] + a, b *big.Int + s int // 2-adicity v₂(q-1) for x-increment inverse-exclusion + nbLimbs uint // effective number of limbs for the emulated field + nbBits uint // effective bits per limb +} + +// NewMapper creates a new Mapper for the curve defined by field type F. +// Curve coefficients are read from [sw_emulated.GetCurveParams] when available. +// The 2-adicity s (for x-increment) is auto-detected from the field modulus. +func NewMapper[F emulated.FieldParams](api frontend.API) (*Mapper[F], error) { + field, err := emulated.NewField[F](api) + if err != nil { + return nil, err + } + a, b := curveCoefficients[F]() + s := twoAdicity[F]() + nbLimbs, nbBits := emulated.GetEffectiveFieldParams[F](api.Compiler().Field()) + return &Mapper[F]{api: api, field: field, a: a, b: b, s: s, nbLimbs: nbLimbs, nbBits: nbBits}, nil +} + +// XIncrement maps msg to a curve point (x, y) using the x-increment method: +// +// X = msg·256 + k, Y² = X³ + aX + b, Z^{2^s} = Y +// +// where k ∈ [0, 256) is found by the hint, and the Z witness chain ensures Y +// is not the inverse of a valid y-coordinate (needed for j=0 curves). +func (m *Mapper[F]) XIncrement(msg *emulated.Element[F]) (x, y *emulated.Element[F], err error) { + fp := m.field + nbLimbs := int(m.nbLimbs) + + // hint inputs: [nbLimbs, nbBits, q_limbs..., msg_limbs...] + // hint outputs: [k, x_limbs..., y_limbs..., z_limbs...] + hintOutputs := 1 + 3*nbLimbs // k + x + y + z + hintInputs := m.buildHintInputs(msg) + + res, err := m.api.Compiler().NewHint(xIncrementHint, hintOutputs, hintInputs...) + if err != nil { + return nil, nil, err + } + + k := res[0] + xLimbs := res[1 : 1+nbLimbs] + yLimbs := res[1+nbLimbs : 1+2*nbLimbs] + zLimbs := res[1+2*nbLimbs : 1+3*nbLimbs] + + xEl := fp.NewElement(xLimbs) + yEl := fp.NewElement(yLimbs) + zEl := fp.NewElement(zLimbs) + + // (1) Curve equation: Y² = X³ + a·X + b + m.assertOnCurve(xEl, yEl) + + // (2) Encoding: X = msg*T + K + kEl := m.nativeToEmulated(k) + Tconst := fp.NewElement(big.NewInt(T)) + fp.AssertIsEqual(xEl, fp.Add(fp.Mul(msg, Tconst), kEl)) + + // (3) Range: 0 ≤ K < 256 + rangecheck.New(m.api).Check(k, 8) + + // (4) 2^S-th power witness: Z^{2^S} = Y + w := zEl + for i := 0; i < m.s; i++ { + w = fp.Mul(w, w) + } + fp.AssertIsEqual(w, yEl) + + return xEl, yEl, nil +} + +// YIncrement maps msg to a curve point (x, y) using the y-increment method: +// +// Y = msg·256 + k, Y² = X³ + aX + b +// +// where k ∈ [0, 256) is found by the hint. No inverse-exclusion witness is +// needed, making this simpler and recommended for j=0 curves. +func (m *Mapper[F]) YIncrement(msg *emulated.Element[F]) (x, y *emulated.Element[F], err error) { + fp := m.field + nbLimbs := int(m.nbLimbs) + + // hint inputs: [nbLimbs, nbBits, q_limbs..., msg_limbs...] + // hint outputs: [k, x_limbs...] + hintOutputs := 1 + nbLimbs // k + x + hintInputs := m.buildHintInputs(msg) + + res, err := m.api.Compiler().NewHint(yIncrementHint, hintOutputs, hintInputs...) + if err != nil { + return nil, nil, err + } + + k := res[0] + xLimbs := res[1 : 1+nbLimbs] + + xEl := fp.NewElement(xLimbs) + + // Reconstruct Y = msg*T + K + kEl := m.nativeToEmulated(k) + Tconst := fp.NewElement(big.NewInt(T)) + yEl := fp.Add(fp.Mul(msg, Tconst), kEl) + + // (1) Curve equation: Y² = X³ + a·X + b + m.assertOnCurve(xEl, yEl) + + // (2) Range: 0 ≤ K < 256 + rangecheck.New(m.api).Check(k, 8) + + return xEl, yEl, nil +} + +// buildHintInputs constructs hint inputs: [nbLimbs, nbBits, q_limbs..., msg_limbs...] +// Curve coefficients are not passed; the hint dispatches on q to look them up. +func (m *Mapper[F]) buildHintInputs(msg *emulated.Element[F]) []frontend.Variable { + fp := m.field + var fparams F + q := fparams.Modulus() + nbLimbs := int(m.nbLimbs) + + inputs := make([]frontend.Variable, 0, 2+2*nbLimbs) + inputs = append(inputs, m.nbLimbs) + inputs = append(inputs, m.nbBits) + + qLimbs := decomposeBigInt(q, nbLimbs, m.nbBits) + for _, l := range qLimbs { + inputs = append(inputs, l) + } + + msgLimbs := fp.Reduce(msg).Limbs + for i := 0; i < nbLimbs; i++ { + inputs = append(inputs, msgLimbs[i]) + } + return inputs +} + +// decomposeBigInt splits v into nbLimbs limbs of nbBits each (little-endian). +func decomposeBigInt(v *big.Int, nbLimbs int, nbBits uint) []*big.Int { + mask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), nbBits), big.NewInt(1)) + tmp := new(big.Int).Set(v) + result := make([]*big.Int, nbLimbs) + for i := 0; i < nbLimbs; i++ { + result[i] = new(big.Int).And(tmp, mask) + tmp.Rsh(tmp, nbBits) + } + return result +} + +// assertOnCurve checks Y² = X³ + a·X + b. +func (m *Mapper[F]) assertOnCurve(x, y *emulated.Element[F]) { + fp := m.field + lhs := fp.Mul(y, y) + x2 := fp.Mul(x, x) + rhs := fp.Mul(x2, x) + if m.a.Sign() != 0 { + aVal := fp.NewElement(m.a) + rhs = fp.Add(rhs, fp.Mul(aVal, x)) + } + bVal := fp.NewElement(m.b) + rhs = fp.Add(rhs, bVal) + fp.AssertIsEqual(lhs, rhs) +} + +// nativeToEmulated converts a native variable (fitting in one limb) to an +// emulated element. +func (m *Mapper[F]) nativeToEmulated(v frontend.Variable) *emulated.Element[F] { + nbLimbs := int(m.nbLimbs) + limbs := make([]frontend.Variable, nbLimbs) + limbs[0] = v + for i := 1; i < nbLimbs; i++ { + limbs[i] = 0 + } + return m.field.NewElement(limbs) +} + +// curveCoefficients returns the short Weierstrass coefficients (a, b) for the +// curve over field F. +func curveCoefficients[F emulated.FieldParams]() (a, b *big.Int) { + p := sw_emulated.GetCurveParams[F]() + return p.A, p.B +} + +// twoAdicity returns v₂(q-1) for the field modulus q. +func twoAdicity[F emulated.FieldParams]() int { + var t F + qm1 := new(big.Int).Sub(t.Modulus(), big.NewInt(1)) + s := 0 + for qm1.Bit(s) == 0 { + s++ + } + return s +} diff --git a/std/algebra/emulated/maptocurve/maptocurve_test.go b/std/algebra/emulated/maptocurve/maptocurve_test.go new file mode 100644 index 000000000..d7f495299 --- /dev/null +++ b/std/algebra/emulated/maptocurve/maptocurve_test.go @@ -0,0 +1,175 @@ +package maptocurve + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/test" +) + +// test message values +var testMessages = []*big.Int{ + big.NewInt(0), + big.NewInt(1), + big.NewInt(42), + big.NewInt(123456789), +} + +// --- X-Increment tests --- + +type xIncrementCircuit[F emulated.FieldParams] struct { + M emulated.Element[F] +} + +func (c *xIncrementCircuit[F]) Define(api frontend.API) error { + m, err := NewMapper[F](api) + if err != nil { + return err + } + _, _, err = m.XIncrement(&c.M) + return err +} + +func testXIncrement[F emulated.FieldParams](t *testing.T) { + t.Helper() + assert := test.NewAssert(t) + opts := []test.TestingOption{test.WithCurves(ecc.BN254)} + for _, msg := range testMessages { + opts = append(opts, test.WithValidAssignment(&xIncrementCircuit[F]{ + M: emulated.ValueOf[F](msg), + })) + } + assert.CheckCircuit(&xIncrementCircuit[F]{}, opts...) +} + +func TestXIncrementEmulatedBN254(t *testing.T) { testXIncrement[emulated.BN254Fp](t) } +func TestXIncrementEmulatedSecp256k1(t *testing.T) { testXIncrement[emulated.Secp256k1Fp](t) } +func TestXIncrementEmulatedP256(t *testing.T) { testXIncrement[emulated.P256Fp](t) } + +// --- Y-Increment tests --- + +type yIncrementCircuit[F emulated.FieldParams] struct { + M emulated.Element[F] +} + +func (c *yIncrementCircuit[F]) Define(api frontend.API) error { + m, err := NewMapper[F](api) + if err != nil { + return err + } + _, _, err = m.YIncrement(&c.M) + return err +} + +func testYIncrement[F emulated.FieldParams](t *testing.T) { + t.Helper() + assert := test.NewAssert(t) + opts := []test.TestingOption{test.WithCurves(ecc.BN254)} + for _, msg := range testMessages { + opts = append(opts, test.WithValidAssignment(&yIncrementCircuit[F]{ + M: emulated.ValueOf[F](msg), + })) + } + assert.CheckCircuit(&yIncrementCircuit[F]{}, opts...) +} + +func TestYIncrementEmulatedBN254(t *testing.T) { testYIncrement[emulated.BN254Fp](t) } +func TestYIncrementEmulatedSecp256k1(t *testing.T) { testYIncrement[emulated.Secp256k1Fp](t) } +func TestYIncrementEmulatedP256(t *testing.T) { testYIncrement[emulated.P256Fp](t) } + +// --- Benchmarks --- + +func BenchmarkXIncrementEmulated(b *testing.B) { + b.Run("BN254/r1cs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &xIncrementCircuit[emulated.BN254Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("BN254/scs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &xIncrementCircuit[emulated.BN254Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("secp256k1/r1cs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &xIncrementCircuit[emulated.Secp256k1Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("secp256k1/scs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &xIncrementCircuit[emulated.Secp256k1Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("P256/r1cs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &xIncrementCircuit[emulated.P256Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("P256/scs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &xIncrementCircuit[emulated.P256Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) +} + +func BenchmarkYIncrementEmulated(b *testing.B) { + b.Run("BN254/r1cs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &yIncrementCircuit[emulated.BN254Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("BN254/scs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &yIncrementCircuit[emulated.BN254Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("secp256k1/r1cs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &yIncrementCircuit[emulated.Secp256k1Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("secp256k1/scs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &yIncrementCircuit[emulated.Secp256k1Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("P256/r1cs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &yIncrementCircuit[emulated.P256Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("P256/scs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &yIncrementCircuit[emulated.P256Fp]{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) +} diff --git a/std/algebra/native/maptocurve_bls12377/doc.go b/std/algebra/native/maptocurve_bls12377/doc.go new file mode 100644 index 000000000..166db9204 --- /dev/null +++ b/std/algebra/native/maptocurve_bls12377/doc.go @@ -0,0 +1,8 @@ +// Package maptocurve_bls12377 implements the y-increment map-to-curve gadget +// for the BLS12-377 curve y² = x³ + 1 over its base field (= BW6-761 scalar field). +// Circuits compile over ecc.BW6_761. +// +// Only the y-increment method is provided. The x-increment method is not practical +// for BLS12-377 because its high 2-adicity (S=46) makes the inverse-exclusion +// witness search infeasible. +package maptocurve_bls12377 diff --git a/std/algebra/native/maptocurve_bls12377/hints.go b/std/algebra/native/maptocurve_bls12377/hints.go new file mode 100644 index 000000000..3a904abba --- /dev/null +++ b/std/algebra/native/maptocurve_bls12377/hints.go @@ -0,0 +1,56 @@ +package maptocurve_bls12377 + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/consensys/gnark/constraint/solver" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in the package. +func GetHints() []solver.Hint { + return []solver.Hint{ + yIncrementHint, + } +} + +// yIncrementHint computes y-increment witness for BLS12-377 (y² = x³ + 1). +// +// Inputs: [msg] +// Outputs: [k, x] where y = msg*T + k, x = cbrt(y² - 1) +func yIncrementHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 1 { + return fmt.Errorf("yIncrementHint: expected 1 input, got %d", len(inputs)) + } + + var msg, y, y2, rhs, one, tFp, yBase fp.Element + msg.SetBigInt(inputs[0]) + one.SetOne() + tFp.SetUint64(T) + yBase.Mul(&msg, &tFp) + + for k := uint64(0); k < T; k++ { + var kFp fp.Element + kFp.SetUint64(k) + y.Add(&yBase, &kFp) + + // x³ = y² - 1 + y2.Square(&y) + rhs.Sub(&y2, &one) + + var x fp.Element + if x.Cbrt(&rhs) == nil { + continue + } + + outputs[0].SetUint64(k) + x.BigInt(outputs[1]) + return nil + } + return fmt.Errorf("yIncrementHint: no valid k found for BLS12-377") +} diff --git a/std/algebra/native/maptocurve_bls12377/maptocurve.go b/std/algebra/native/maptocurve_bls12377/maptocurve.go new file mode 100644 index 000000000..4206b0d1d --- /dev/null +++ b/std/algebra/native/maptocurve_bls12377/maptocurve.go @@ -0,0 +1,38 @@ +package maptocurve_bls12377 + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/rangecheck" +) + +const ( + T = 256 // increment window size (8 bits) + B = 1 // curve coefficient b for BLS12-377: y² = x³ + 1 + S = 46 // 2-adicity v₂(q-1) for BLS12-377 Fp +) + +// YIncrement maps msg to a point (x, y) on y² = x³ + 1 using y-increment: +// +// Y = msg·256 + k, Y² = X³ + 1 +func YIncrement(api frontend.API, msg frontend.Variable) (x, y frontend.Variable, err error) { + res, err := api.Compiler().NewHint(yIncrementHint, 2, msg) + if err != nil { + return nil, nil, err + } + k := res[0] + x = res[1] + + y = api.Add(api.Mul(msg, T), k) + + // Y² = X³ + B + lhs := api.Mul(y, y) + rhs := api.Mul(x, x) + rhs = api.Mul(rhs, x) + rhs = api.Add(rhs, B) + api.AssertIsEqual(lhs, rhs) + + // 0 ≤ K < 256 + rangecheck.New(api).Check(k, 8) + + return x, y, nil +} diff --git a/std/algebra/native/maptocurve_bls12377/maptocurve_test.go b/std/algebra/native/maptocurve_bls12377/maptocurve_test.go new file mode 100644 index 000000000..ece924673 --- /dev/null +++ b/std/algebra/native/maptocurve_bls12377/maptocurve_test.go @@ -0,0 +1,49 @@ +package maptocurve_bls12377 + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" +) + +type yIncrementCircuit struct { + M frontend.Variable +} + +func (c *yIncrementCircuit) Define(api frontend.API) error { + _, _, err := YIncrement(api, c.M) + return err +} + +func TestYIncrement(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit( + &yIncrementCircuit{}, + test.WithValidAssignment(&yIncrementCircuit{M: 0}), + test.WithValidAssignment(&yIncrementCircuit{M: 1}), + test.WithValidAssignment(&yIncrementCircuit{M: 42}), + test.WithValidAssignment(&yIncrementCircuit{M: 123456789}), + test.WithCurves(ecc.BW6_761), + ) +} + +func BenchmarkYIncrement(b *testing.B) { + b.Run("r1cs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BW6_761.ScalarField(), r1cs.NewBuilder, &yIncrementCircuit{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("scs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BW6_761.ScalarField(), scs.NewBuilder, &yIncrementCircuit{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) +} diff --git a/std/algebra/native/maptocurve_grumpkin/doc.go b/std/algebra/native/maptocurve_grumpkin/doc.go new file mode 100644 index 000000000..c6a556df8 --- /dev/null +++ b/std/algebra/native/maptocurve_grumpkin/doc.go @@ -0,0 +1,8 @@ +// Package maptocurve_grumpkin implements the y-increment map-to-curve gadget +// for the Grumpkin curve y² = x³ - 17 over its base field (= BN254 scalar field). +// Circuits compile over ecc.BN254. +// +// Only the y-increment method is provided. The x-increment method is not practical +// for Grumpkin because its high 2-adicity (S=28) makes the inverse-exclusion +// witness search infeasible. +package maptocurve_grumpkin diff --git a/std/algebra/native/maptocurve_grumpkin/hints.go b/std/algebra/native/maptocurve_grumpkin/hints.go new file mode 100644 index 000000000..fbbf108f4 --- /dev/null +++ b/std/algebra/native/maptocurve_grumpkin/hints.go @@ -0,0 +1,56 @@ +package maptocurve_grumpkin + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc/grumpkin/fp" + "github.com/consensys/gnark/constraint/solver" +) + +func init() { + solver.RegisterHint(GetHints()...) +} + +// GetHints returns all hint functions used in the package. +func GetHints() []solver.Hint { + return []solver.Hint{ + yIncrementHint, + } +} + +// yIncrementHint computes y-increment witness for Grumpkin (y² = x³ - 17). +// +// Inputs: [msg] +// Outputs: [k, x] where y = msg*T + k, x = cbrt(y² + 17) +func yIncrementHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != 1 { + return fmt.Errorf("yIncrementHint: expected 1 input, got %d", len(inputs)) + } + + var msg, y, y2, rhs, b17, tFp, yBase fp.Element + msg.SetBigInt(inputs[0]) + b17.SetUint64(17) + tFp.SetUint64(T) + yBase.Mul(&msg, &tFp) + + for k := uint64(0); k < T; k++ { + var kFp fp.Element + kFp.SetUint64(k) + y.Add(&yBase, &kFp) + + // x³ = y² + 17 + y2.Square(&y) + rhs.Add(&y2, &b17) + + var x fp.Element + if x.Cbrt(&rhs) == nil { + continue + } + + outputs[0].SetUint64(k) + x.BigInt(outputs[1]) + return nil + } + return fmt.Errorf("yIncrementHint: no valid k found for Grumpkin") +} diff --git a/std/algebra/native/maptocurve_grumpkin/maptocurve.go b/std/algebra/native/maptocurve_grumpkin/maptocurve.go new file mode 100644 index 000000000..40d88cb42 --- /dev/null +++ b/std/algebra/native/maptocurve_grumpkin/maptocurve.go @@ -0,0 +1,39 @@ +package maptocurve_grumpkin + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/rangecheck" +) + +const ( + T = 256 // increment window size (8 bits) + B = -17 // curve coefficient b for Grumpkin: y² = x³ - 17 +) + +// YIncrement maps msg to a point (x, y) on y² = x³ - 17 using y-increment: +// +// Y = msg·256 + k, Y² = X³ - 17 +func YIncrement(api frontend.API, msg frontend.Variable) (x, y frontend.Variable, err error) { + // hint outputs: [k, x] + res, err := api.Compiler().NewHint(yIncrementHint, 2, msg) + if err != nil { + return nil, nil, err + } + k := res[0] + x = res[1] + + // Reconstruct Y = msg*T + K + y = api.Add(api.Mul(msg, T), k) + + // (1) Y² = X³ + B + lhs := api.Mul(y, y) + rhs := api.Mul(x, x) + rhs = api.Mul(rhs, x) + rhs = api.Add(rhs, B) + api.AssertIsEqual(lhs, rhs) + + // (2) 0 ≤ K < 256 + rangecheck.New(api).Check(k, 8) + + return x, y, nil +} diff --git a/std/algebra/native/maptocurve_grumpkin/maptocurve_test.go b/std/algebra/native/maptocurve_grumpkin/maptocurve_test.go new file mode 100644 index 000000000..9b4672a33 --- /dev/null +++ b/std/algebra/native/maptocurve_grumpkin/maptocurve_test.go @@ -0,0 +1,49 @@ +package maptocurve_grumpkin + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" +) + +type yIncrementCircuit struct { + M frontend.Variable +} + +func (c *yIncrementCircuit) Define(api frontend.API) error { + _, _, err := YIncrement(api, c.M) + return err +} + +func TestYIncrement(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit( + &yIncrementCircuit{}, + test.WithValidAssignment(&yIncrementCircuit{M: 0}), + test.WithValidAssignment(&yIncrementCircuit{M: 1}), + test.WithValidAssignment(&yIncrementCircuit{M: 42}), + test.WithValidAssignment(&yIncrementCircuit{M: 123456789}), + test.WithCurves(ecc.BN254), + ) +} + +func BenchmarkYIncrement(b *testing.B) { + b.Run("r1cs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &yIncrementCircuit{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) + b.Run("scs", func(b *testing.B) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &yIncrementCircuit{}) + if err != nil { + b.Fatal(err) + } + b.Logf("%d constraints", ccs.GetNbConstraints()) + }) +}