Skip to content
52 changes: 35 additions & 17 deletions std/hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/uints"
)

Expand Down Expand Up @@ -108,48 +109,65 @@ type Compressor interface {
}

type merkleDamgardHasher struct {
state frontend.Variable
iv frontend.Variable
f Compressor
api frontend.API
state []frontend.Variable // state after being updated with each written element
stateTable logderivlookup.Table // stateTable always contains a prefix of h.state
stateTableLen int
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please comment what is stateTableLen: "// stateTableLen indicates the length of of state entries written to stateTable. It will be updated with every read to DynamicSum()" etc.

f Compressor
api frontend.API
}

// NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash
// initialState is a value whose preimage is not known
// NewMerkleDamgardHasher range-extends a 2-1 one-way hash compression function into a hash by way of the Merkle-Damgård construction.
// Parameters:
// - api: constraint builder
// - f: 2-1 hash compression (one-way) function
// - initialState: the initialization vector (IV) in the Merkle-Damgård chain. It must be a value whose preimage is not known.
Comment thread
Tabaie marked this conversation as resolved.
func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) StateStorer {
Comment thread
Tabaie marked this conversation as resolved.
return &merkleDamgardHasher{
state: initialState,
iv: initialState,
state: []frontend.Variable{initialState},
f: f,
api: api,
}
}

func (h *merkleDamgardHasher) Reset() {
h.state = h.iv
h.state = h.state[:1]
h.stateTableLen = 0
h.stateTable = nil
}

func (h *merkleDamgardHasher) Write(data ...frontend.Variable) {
for _, d := range data {
h.state = h.f.Compress(h.state, d)
h.state = append(h.state, h.f.Compress(h.state[len(h.state)-1], d))
}
}

func (h *merkleDamgardHasher) Sum() frontend.Variable {
return h.state
return h.state[len(h.state)-1]
}

// SumWithLength computes the Merkle-Damgård hash of the input data, truncated at the given length.
// Parameters:
// - length: length of the prefix of data to be hashed. The verifier will not accept a value outside the range {0, 1, ..., len(data)}.
// The gnark prover will refuse to attempt to generate such an unsuccessful proof.
func (h *merkleDamgardHasher) SumWithLength(length frontend.Variable) frontend.Variable {
if h.stateTable == nil {
h.stateTable = logderivlookup.New(h.api)
}
for h.stateTableLen < len(h.state) {
h.stateTable.Insert(h.state[h.stateTableLen])
h.stateTableLen++
}
return h.stateTable.Lookup(length)[0]
}

func (h *merkleDamgardHasher) State() []frontend.Variable {
return []frontend.Variable{h.state}
return []frontend.Variable{h.state[len(h.state)-1]}
}

func (h *merkleDamgardHasher) SetState(state []frontend.Variable) error {
if h.state != h.iv {
return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state")
}
if len(state) != 1 {
return fmt.Errorf("expected one state variable, got %d", len(state))
return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state")
}
h.state = state[0]
h.state = append(h.state, state[0])
Comment thread
cursor[bot] marked this conversation as resolved.
return nil
}
50 changes: 39 additions & 11 deletions std/hash/poseidon2/poseidon2_test.go
Original file line number Diff line number Diff line change
@@ -1,51 +1,79 @@
package poseidon2_test

import (
"fmt"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/hash/poseidon2"
gkr_poseidon2 "github.com/consensys/gnark/std/hash/poseidon2/gkr-poseidon2"
permutation "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2"
"github.com/consensys/gnark/test"
)

type poseidon2Circuit struct {
Input []frontend.Variable
Expected frontend.Variable `gnark:",public"`
Expected []frontend.Variable `gnark:",public"` // Expected[i] = H(Input[:i+1])
}

func (c *poseidon2Circuit) Define(api frontend.API) error {
if len(c.Input) != len(c.Expected) {
return fmt.Errorf("length mismatch")
}
hsh, err := poseidon2.New(api)
if err != nil {
return err
}

compressor, err := permutation.NewCompressor(api)
if err != nil {
return err
}

gkr, err := gkr_poseidon2.New(api)
if err != nil {
return err
}
hsh.Write(c.Input...)
api.AssertIsEqual(hsh.Sum(), c.Expected)
gkr.Write(c.Input...)
api.AssertIsEqual(gkr.Sum(), c.Expected)

for i := range c.Input {
hsh.Write(c.Input[i])
api.AssertIsEqual(c.Expected[i], hsh.Sum())
gkr.Write(c.Input[i])
api.AssertIsEqual(c.Expected[i], gkr.Sum())
api.AssertIsEqual(c.Expected[i], hash.SumMerkleDamgardDynamicLength(api, compressor, 0, i+1, c.Input))

Check failure on line 47 in std/hash/poseidon2/poseidon2_test.go

View workflow job for this annotation

GitHub Actions / staticcheck

undefined: hash.SumMerkleDamgardDynamicLength (typecheck)
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'l also add test if you're trying to read hash digest not from the end, but from earlier. A la varlen.SumWithLength(2) in the end.

return nil
}

func TestPoseidon2Hash(t *testing.T) {
assert := test.NewAssert(t)

var buf [fr.Bytes]byte
const nbInputs = 5
// prepare expected output
h := poseidonbls12377.NewMerkleDamgardHasher()
circInput := make([]frontend.Variable, nbInputs)
for i := range nbInputs {
_, err := h.Write([]byte{byte(i)})
expected := make([]frontend.Variable, nbInputs)
input := make([]frontend.Variable, nbInputs)
for i := range input {
buf[fr.Bytes-1] = byte(i)
_, err := h.Write(buf[:])
assert.NoError(err)
circInput[i] = i
input[i] = i
expected[i] = h.Sum(nil)
}
res := h.Sum(nil)
assert.CheckCircuit(&poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377

assert.CheckCircuit(
&poseidon2Circuit{
Input: make([]frontend.Variable, nbInputs),
Expected: make([]frontend.Variable, nbInputs),
}, test.WithValidAssignment(&poseidon2Circuit{
Input: input,
Expected: expected,
}), test.WithCurves(ecc.BLS12_377))
}

func TestStateStorer(t *testing.T) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also add SumWithLen test to the state storer circuit.

Expand Down
Loading