Skip to content
60 changes: 44 additions & 16 deletions std/hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/uints"
)

Expand All @@ -26,6 +27,13 @@ type FieldHasher interface {
Reset()
}

// DynamicLengthFieldHasher can compute hashes of lengths unknown at compile time.
type DynamicLengthFieldHasher interface {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep this consistent with BinaryFixedLengthHasher. Also see the interface description there, which in my opinion is more thorough.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you are suggesting we call it "fixed length"? That seems like the exact opposite of what it does.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input length is fixed at solving time, isn't it? I am not sure; perhaps the name could be better, but in my opinion it is also not wrong.

My point is that we already have an existing interface which does the same thing (but when we get the inputs as bytes instead of field elements), so for me it is logical that the new interface has similar naming. And for boolean hashes the "dynamic length" is also confusing as it could refer to the output length of the hash (as in SHAKE). What is important is that the methods are well documented and explained, we shouldn't rely on the method name only to pass on all information.

FieldHasher
// SumWithLength computes the hash of the first l inputs written into the hash.
SumWithLength(l frontend.Variable) frontend.Variable
}

// StateStorer allows to store and retrieve the state of a hash function.
type StateStorer interface {
FieldHasher
Expand Down Expand Up @@ -108,48 +116,68 @@ type Compressor interface {
}

type merkleDamgardHasher struct {
state frontend.Variable
iv frontend.Variable
f Compressor
api frontend.API
state []frontend.Variable // state after being updated with each written element
stateTable logderivlookup.Table // stateTable always contains a prefix of h.state
stateTableLen int
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment explaining what stateTableLen is, e.g. "// stateTableLen indicates the number of state entries written to stateTable. It will be updated with every read to DynamicSum()", etc.

f Compressor
api frontend.API
}

// NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash
// initialState is a value whose preimage is not known
// NewMerkleDamgardHasher range-extends a 2-1 one-way hash compression function into a hash by way of the Merkle-Damgård construction.
// Parameters:
// - api: constraint builder
// - f: 2-1 hash compression (one-way) function
// - initialState: the initialization vector (IV) in the Merkle-Damgård chain. It must be a value whose preimage is not known.
Comment thread
Tabaie marked this conversation as resolved.
func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) StateStorer {
Comment thread
Tabaie marked this conversation as resolved.
return &merkleDamgardHasher{
state: initialState,
iv: initialState,
state: []frontend.Variable{initialState},
f: f,
api: api,
}
}

func (h *merkleDamgardHasher) Reset() {
h.state = h.iv
h.state = h.state[:1]
h.stateTableLen = 0
h.stateTable = nil
}

func (h *merkleDamgardHasher) Write(data ...frontend.Variable) {
for _, d := range data {
h.state = h.f.Compress(h.state, d)
h.state = append(h.state, h.f.Compress(h.state[len(h.state)-1], d))
}
}

func (h *merkleDamgardHasher) Sum() frontend.Variable {
return h.state
return h.state[len(h.state)-1]
}

// SumWithLength returns the Merkle-Damgård digest of a prefix of the data
// written so far, where the prefix length is a circuit variable.
//
// The lookup table over the intermediate states is created lazily on the first
// call. Each call then inserts only those entries of h.state that have not been
// inserted yet, so repeated calls do not insert the same state twice.
//
// Parameters:
//   - length: length of the prefix of data to be hashed. It is used directly as
//     the lookup index into the stored states, where entry 0 is the initial
//     state. The verifier will not accept a value outside the range
//     {0, 1, ..., len(data)}, and the gnark prover will refuse to attempt to
//     generate such an unsuccessful proof.
func (h *merkleDamgardHasher) SumWithLength(length frontend.Variable) frontend.Variable {
	if h.stateTable == nil {
		h.stateTable = logderivlookup.New(h.api)
	}
	// Catch the table up with any states produced since the previous call.
	for ; h.stateTableLen < len(h.state); h.stateTableLen++ {
		h.stateTable.Insert(h.state[h.stateTableLen])
	}
	digests := h.stateTable.Lookup(length)
	return digests[0]
}

func (h *merkleDamgardHasher) State() []frontend.Variable {
return []frontend.Variable{h.state}
return []frontend.Variable{h.state[len(h.state)-1]}
}

func (h *merkleDamgardHasher) SetState(state []frontend.Variable) error {
if h.state != h.iv {
return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state")
}
if len(state) != 1 {
return fmt.Errorf("expected one state variable, got %d", len(state))
}
h.state = state[0]
if len(h.state) != 1 {
return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state")
}
h.state = append(h.state, state[0])
Comment thread
cursor[bot] marked this conversation as resolved.
return nil
}
52 changes: 41 additions & 11 deletions std/hash/poseidon2/poseidon2_test.go
Original file line number Diff line number Diff line change
@@ -1,51 +1,81 @@
package poseidon2_test

import (
"fmt"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/hash/poseidon2"
gkr_poseidon2 "github.com/consensys/gnark/std/hash/poseidon2/gkr-poseidon2"
"github.com/consensys/gnark/test"
)

type poseidon2Circuit struct {
Input []frontend.Variable
Expected frontend.Variable `gnark:",public"`
Expected []frontend.Variable `gnark:",public"` // Expected[i] = H(Input[:i+1])
}

func (c *poseidon2Circuit) Define(api frontend.API) error {
if len(c.Input) != len(c.Expected) {
return fmt.Errorf("length mismatch")
}
hsh, err := poseidon2.New(api)
if err != nil {
return err
}
varlen := hsh.(hash.DynamicLengthFieldHasher)

hsh, err = poseidon2.New(api)
if err != nil {
return err
}

gkr, err := gkr_poseidon2.New(api)
if err != nil {
return err
}
hsh.Write(c.Input...)
api.AssertIsEqual(hsh.Sum(), c.Expected)
gkr.Write(c.Input...)
api.AssertIsEqual(gkr.Sum(), c.Expected)

varlen.Write(c.Input...)

for i := range c.Input {
hsh.Write(c.Input[i])
api.AssertIsEqual(c.Expected[i], hsh.Sum())
gkr.Write(c.Input[i])
api.AssertIsEqual(c.Expected[i], gkr.Sum())
api.AssertIsEqual(c.Expected[i], varlen.SumWithLength(i+1))
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also add a test that reads the hash digest not from the end but from an earlier position, e.g. a call to varlen.SumWithLength(2) at the end.

return nil
}

func TestPoseidon2Hash(t *testing.T) {
assert := test.NewAssert(t)

var buf [fr.Bytes]byte
const nbInputs = 5
// prepare expected output
h := poseidonbls12377.NewMerkleDamgardHasher()
circInput := make([]frontend.Variable, nbInputs)
for i := range nbInputs {
_, err := h.Write([]byte{byte(i)})
expected := make([]frontend.Variable, nbInputs)
input := make([]frontend.Variable, nbInputs)
for i := range input {
buf[fr.Bytes-1] = byte(i)
_, err := h.Write(buf[:])
assert.NoError(err)
circInput[i] = i
input[i] = i
expected[i] = h.Sum(nil)
}
res := h.Sum(nil)
assert.CheckCircuit(&poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377

assert.CheckCircuit(
&poseidon2Circuit{
Input: make([]frontend.Variable, nbInputs),
Expected: make([]frontend.Variable, nbInputs),
}, test.WithValidAssignment(&poseidon2Circuit{
Input: input,
Expected: expected,
}), test.WithCurves(ecc.BLS12_377))
}

func TestStateStorer(t *testing.T) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also add a SumWithLength test to the state storer circuit.

Expand Down
Loading