diff --git a/std/hash/hash.go b/std/hash/hash.go index a61c0bc398..cd27fbed95 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/lookup/logderivlookup" "github.com/consensys/gnark/std/math/uints" ) @@ -26,6 +27,13 @@ type FieldHasher interface { Reset() } +// DynamicLengthFieldHasher can compute hashes of lengths unknown at compile time. +type DynamicLengthFieldHasher interface { + FieldHasher + // SumWithLength computes the hash of the first l inputs written into the hash. + SumWithLength(l frontend.Variable) frontend.Variable +} + // StateStorer allows to store and retrieve the state of a hash function. type StateStorer interface { FieldHasher @@ -108,48 +116,68 @@ type Compressor interface { } type merkleDamgardHasher struct { - state frontend.Variable - iv frontend.Variable - f Compressor - api frontend.API + state []frontend.Variable // state after being updated with each written element + stateTable logderivlookup.Table // stateTable always contains a prefix of h.state + stateTableLen int + f Compressor + api frontend.API } -// NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash -// initialState is a value whose preimage is not known +// NewMerkleDamgardHasher range-extends a 2-1 one-way hash compression function into a hash by way of the Merkle-Damgård construction. +// Parameters: +// - api: constraint builder +// - f: 2-1 hash compression (one-way) function +// - initialState: the initialization vector (IV) in the Merkle-Damgård chain. It must be a value whose preimage is not known. func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) StateStorer { return &merkleDamgardHasher{ - state: initialState, - iv: initialState, + state: []frontend.Variable{initialState}, f: f, api: api, } } func (h *merkleDamgardHasher) Reset() { - h.state = h.iv + h.state = h.state[:1] + h.stateTableLen = 0 + h.stateTable = nil } func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { for _, d := range data { - h.state = h.f.Compress(h.state, d) + h.state = append(h.state, h.f.Compress(h.state[len(h.state)-1], d)) } } func (h *merkleDamgardHasher) Sum() frontend.Variable { - return h.state + return h.state[len(h.state)-1] +} + +// SumWithLength computes the Merkle-Damgård hash of the input data, truncated at the given length. +// Parameters: +// - length: length of the prefix of data to be hashed. The verifier will not accept a value outside the range {0, 1, ..., len(data)}. +// The gnark prover will refuse to attempt to generate such an unsuccessful proof. +func (h *merkleDamgardHasher) SumWithLength(length frontend.Variable) frontend.Variable { + if h.stateTable == nil { + h.stateTable = logderivlookup.New(h.api) + } + for h.stateTableLen < len(h.state) { + h.stateTable.Insert(h.state[h.stateTableLen]) + h.stateTableLen++ + } + return h.stateTable.Lookup(length)[0] } func (h *merkleDamgardHasher) State() []frontend.Variable { - return []frontend.Variable{h.state} + return []frontend.Variable{h.state[len(h.state)-1]} } func (h *merkleDamgardHasher) SetState(state []frontend.Variable) error { - if h.state != h.iv { - return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state") - } if len(state) != 1 { return fmt.Errorf("expected one state variable, got %d", len(state)) } - h.state = state[0] + if len(h.state) != 1 { + return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state") + } + h.state = append(h.state, state[0]) return nil } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index f6c57736df..791ff7b1e4 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -1,11 +1,14 @@ package poseidon2_test import ( + "fmt" "testing" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/poseidon2" gkr_poseidon2 "github.com/consensys/gnark/std/hash/poseidon2/gkr-poseidon2" "github.com/consensys/gnark/test" @@ -13,39 +16,66 @@ import ( type poseidon2Circuit struct { Input []frontend.Variable - Expected frontend.Variable `gnark:",public"` + Expected []frontend.Variable `gnark:",public"` // Expected[i] = H(Input[:i+1]) } func (c *poseidon2Circuit) Define(api frontend.API) error { + if len(c.Input) != len(c.Expected) { + return fmt.Errorf("length mismatch") + } hsh, err := poseidon2.New(api) if err != nil { return err } + varlen := hsh.(hash.DynamicLengthFieldHasher) + + hsh, err = poseidon2.New(api) + if err != nil { + return err + } + gkr, err := gkr_poseidon2.New(api) if err != nil { return err } - hsh.Write(c.Input...) - api.AssertIsEqual(hsh.Sum(), c.Expected) - gkr.Write(c.Input...) - api.AssertIsEqual(gkr.Sum(), c.Expected) + + varlen.Write(c.Input...) + + for i := range c.Input { + hsh.Write(c.Input[i]) + api.AssertIsEqual(c.Expected[i], hsh.Sum()) + gkr.Write(c.Input[i]) + api.AssertIsEqual(c.Expected[i], gkr.Sum()) + api.AssertIsEqual(c.Expected[i], varlen.SumWithLength(i+1)) + } return nil } func TestPoseidon2Hash(t *testing.T) { assert := test.NewAssert(t) + var buf [fr.Bytes]byte const nbInputs = 5 // prepare expected output h := poseidonbls12377.NewMerkleDamgardHasher() - circInput := make([]frontend.Variable, nbInputs) - for i := range nbInputs { - _, err := h.Write([]byte{byte(i)}) + expected := make([]frontend.Variable, nbInputs) + input := make([]frontend.Variable, nbInputs) + for i := range input { + buf[fr.Bytes-1] = byte(i) + _, err := h.Write(buf[:]) assert.NoError(err) - circInput[i] = i + input[i] = i + expected[i] = h.Sum(nil) } - res := h.Sum(nil) - assert.CheckCircuit(&poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 + + assert.CheckCircuit( + &poseidon2Circuit{ + Input: make([]frontend.Variable, nbInputs), + Expected: make([]frontend.Variable, nbInputs), + }, test.WithValidAssignment(&poseidon2Circuit{ + Input: input, + Expected: expected, + }), test.WithCurves(ecc.BLS12_377)) } func TestStateStorer(t *testing.T) {