diff --git a/arbutil/espresso_utils.go b/arbutil/espresso_utils.go index 8e630d9053d..3db21b9132b 100644 --- a/arbutil/espresso_utils.go +++ b/arbutil/espresso_utils.go @@ -8,6 +8,7 @@ import ( espressoTypes "github.com/EspressoSystems/espresso-network/sdks/go/types" "github.com/ccoveille/go-safecast" + "github.com/fxamacker/cbor/v2" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" @@ -16,6 +17,7 @@ import ( const MAX_ATTESTATION_QUOTE_SIZE int = 4 * 1024 const LEN_SIZE int = 8 const INDEX_SIZE int = 8 +const HEADER_LEN int = 8 type SubmittedEspressoTx struct { Hash string @@ -24,6 +26,20 @@ type SubmittedEspressoTx struct { SubmittedAt time.Time `rlp:"optional"` } +type EspressoHeader struct { + // Transaction type to parse payload + TransactionType TransactionType `cbor:"0,keyasint"` + // The payload length, excluding the header + PayloadLength uint64 `cbor:"1,keyasint"` +} + +type TransactionType uint8 + +const ( + BatchPosterSignedTxn TransactionType = 0 + DecentralizedTimeboost TransactionType = 1 +) + func BuildRawHotShotPayload( msgPositions []MessageIndex, msgFetcher func(MessageIndex) ([]byte, error), @@ -74,6 +90,22 @@ func SignHotShotPayload( result = append(result, quote...) result = append(result, unsigned...) + // Prepare the header + header := EspressoHeader{ + TransactionType: BatchPosterSignedTxn, + PayloadLength: uint64(len(result)), + } + + encoded, err := cbor.Marshal(header) + if err != nil { + return nil, err + } + // Put the header at the beginning of the payload + headerBuf := make([]byte, HEADER_LEN) + binary.BigEndian.PutUint64(headerBuf, uint64(len(encoded))) + headerBuf = append(headerBuf, encoded...) + result = append(headerBuf, result...) + return result, nil } @@ -88,7 +120,41 @@ func ValidateIfPayloadIsInBlock(p []byte, payloads []espressoTypes.Bytes) bool { return validated } -func ParseHotShotPayload(payload []byte) (signature []byte, userDataHash []byte, indices []uint64, messages [][]byte, err error) { +func ParseHotshotPayloadForHeader(tx []byte) *EspressoHeader { + if len(tx) < HEADER_LEN { + log.Warn("hotshot transaction is too small for a header") + return nil + } + headerLength := uint64(HEADER_LEN) + // Try and see if there is a header + headerSize := binary.BigEndian.Uint64(tx[:headerLength]) + offset := headerLength + headerSize + encoded := tx[headerLength:offset] + var header EspressoHeader + err := cbor.Unmarshal(encoded, &header) + if err != nil { + log.Warn("failed to decode header", "err", err) + return nil + } + strippedTx := tx[offset:] + // if the set header payload length matches the total rest of payload length it must be a header + if uint64(len(strippedTx)) == header.PayloadLength { + switch header.TransactionType { + case BatchPosterSignedTxn, DecentralizedTimeboost: + return &header + default: + return nil + } + } + return nil +} + +func ParseHotShotPayload(payload []byte, header *EspressoHeader) (signature []byte, userDataHash []byte, indices []uint64, messages [][]byte, err error) { + if header != nil { + // parse the payload with no header + offset := uint64(len(payload)) - header.PayloadLength + payload = payload[offset:] + } if len(payload) < LEN_SIZE { return nil, nil, nil, nil, errors.New("payload too short to parse signature size") } diff --git a/arbutil/espresso_utils_test.go b/arbutil/espresso_utils_test.go index e9e25949dad..437c350db47 100644 --- a/arbutil/espresso_utils_test.go +++ b/arbutil/espresso_utils_test.go @@ -9,6 +9,7 @@ import ( "testing" espressoTypes "github.com/EspressoSystems/espresso-network/sdks/go/types" + "github.com/fxamacker/cbor/v2" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" @@ -36,7 +37,8 @@ func TestParsePayload(t *testing.T) { } // Parse the signed payload - signature, userDataHash, indices, messages, err := ParseHotShotPayload(signedPayload) + header := ParseHotshotPayloadForHeader(signedPayload) + signature, userDataHash, indices, messages, err := ParseHotShotPayload(signedPayload, header) if err != nil { t.Fatalf("failed to parse payload: %v", err) } @@ -70,6 +72,136 @@ func TestParsePayload(t *testing.T) { } } +type EspressoHeaderNew struct { + // Transaction type to parse payload + TransactionType TransactionType `cbor:"0,keyasint"` + // The payload length, excluding the header + PayloadLength uint64 `cbor:"1,keyasint"` + SomeNewField uint8 `cbor:"2,keyasint,omitempty"` +} + +func TestParseEspressoHeader(t *testing.T) { + + h := EspressoHeader{ + TransactionType: 0, + PayloadLength: 5, + } + encoded, err := cbor.Marshal(h) + if err != nil { + t.Fatalf("cbor encoding failed: %v", err) + } + + // Parse a new header in case we ever need to make changes + var h2 EspressoHeaderNew + err = cbor.Unmarshal(encoded, &h2) + if err != nil { + t.Fatalf("cbor decoding failed: %v", err) + } + + if h2.TransactionType != h.TransactionType { + t.Fatalf("expected transaction types to match! got %d, wanted %d", h2.TransactionType, h.TransactionType) + } + + if h2.PayloadLength != h.PayloadLength { + t.Fatalf("expected payload lengths to match! got %d, wanted %d", h2.PayloadLength, h.PayloadLength) + } + + // this should be set to default value which is 0 + if h2.SomeNewField != 0 { + t.Fatalf("expected new field to be set to default of 0! got %d, wanted %d", h2.SomeNewField, 0) + } + +} + +func TestParsePayloadWithAndWithoutHeader(t *testing.T) { + msgPositions := []MessageIndex{1, 2, 10, 24, 100} + + rawPayload, cnt := BuildRawHotShotPayload(msgPositions, mockMsgFetcher, 200*1024) + if cnt != len(msgPositions) { + t.Fatal("exceed transactions") + } + + mockSignature := []byte("fake_signature") + fakeSigner := func(payload []byte) ([]byte, error) { + return mockSignature, nil + } + signedPayload, err := SignHotShotPayload(rawPayload, fakeSigner) + if err != nil { + t.Fatalf("failed to sign payload: %v", err) + } + + // Parse the signed payload + header := ParseHotshotPayloadForHeader(signedPayload) + if header == nil { + t.Fatalf("expected there to be a header") + } + signature, userDataHash, indices, messages, err := ParseHotShotPayload(signedPayload, header) + if err != nil { + t.Fatalf("failed to parse payload: %v", err) + } + + if !slices.Equal(userDataHash, crypto.Keccak256(rawPayload)) { + t.Fatalf("User data hash is not for the correct payload") + } + + // Validate parsed data + if !bytes.Equal(signature, mockSignature) { + t.Errorf("expected signature 'fake_signature', got %v", mockSignature) + } + + for i, index := range indices { + if MessageIndex(index) != msgPositions[i] { + t.Errorf("expected index %d, got %d", msgPositions[i], index) + } + } + + expectedMessages := [][]byte{ + []byte("message1"), + []byte("message2"), + []byte("message10"), + []byte("message24"), + []byte("message100"), + } + for i, message := range messages { + if !bytes.Equal(message, expectedMessages[i]) { + t.Errorf("expected message %s, got %s", expectedMessages[i], message) + } + } + + // Remove header from payload + offset := uint64(len(signedPayload)) - header.PayloadLength + signedPayload = signedPayload[offset:] + header = ParseHotshotPayloadForHeader(signedPayload) + if header != nil { + t.Fatalf("no header should be parsed") + } + signature, userDataHash, indices, messages, err = ParseHotShotPayload(signedPayload, header) + if err != nil { + t.Fatalf("failed to parse payload: %v", err) + } + + if !slices.Equal(userDataHash, crypto.Keccak256(rawPayload)) { + t.Fatalf("User data hash is not for the correct payload") + } + + // Validate parsed data + if !bytes.Equal(signature, mockSignature) { + t.Errorf("expected signature 'fake_signature', got %v", mockSignature) + } + + for i, index := range indices { + if MessageIndex(index) != msgPositions[i] { + t.Errorf("expected index %d, got %d", msgPositions[i], index) + } + } + + for i, message := range messages { + if !bytes.Equal(message, expectedMessages[i]) { + t.Errorf("expected message %s, got %s", expectedMessages[i], message) + } + } +} + func TestValidateIfPayloadIsInBlock(t *testing.T) { msgPositions := []MessageIndex{1, 2} @@ -123,7 +255,7 @@ func TestParsePayloadInvalidCases(t *testing.T) { for _, tc := range invalidPayloads { t.Run(tc.description, func(t *testing.T) { - _, _, _, _, err := ParseHotShotPayload(tc.payload) + _, _, _, _, err := ParseHotShotPayload(tc.payload, nil) if err == nil { t.Errorf("expected error for case '%s', but got none", tc.description) } diff --git a/espressostreamer/espresso_streamer.go b/espressostreamer/espresso_streamer.go index eedca2d4f7e..37bab84951c 100644 --- a/espressostreamer/espresso_streamer.go +++ b/espressostreamer/espresso_streamer.go @@ -246,39 +246,90 @@ func (s *EspressoStreamer) verifyLegacy(attestation []byte, signature [32]byte) return err } -func (s *EspressoStreamer) parseEspressoTransaction(tx espressoTypes.Bytes, l1Height uint64) ([]*MessageWithMetadataAndPos, error) { - signature, userDataHash, indices, messages, err := arbutil.ParseHotShotPayload(tx) - if err != nil { - log.Warn("failed to parse hotshot payload", "err", err) - return nil, err - } - if len(messages) == 0 { - return nil, ErrPayloadHadNoMessages +func (s *EspressoStreamer) fallbackLegacyVerification(data []byte, userDataHashArr [32]byte) error { + if s.espressoSGXVerifier == nil { + return fmt.Errorf("failed to verify attestation quote, legacy header found but sgx verifier is nil") } - if len(userDataHash) != 32 { - log.Warn("user data hash is not 32 bytes") - return nil, ErrUserDataHashNot32Bytes + err := s.verifyLegacy(data, userDataHashArr) + if err != nil { + log.Warn("failed to verify attestation quote", "err", err) + return err } + return nil +} - userDataHashArr := [32]byte(userDataHash) - +func (s *EspressoStreamer) verifySignature(data []byte, userDataHashArr [32]byte, l1Height uint64, fallback bool) error { + err := s.verifyBatchPosterSignature(data, userDataHashArr, l1Height) var success bool - err = s.verifyBatchPosterSignature(signature, userDataHashArr, l1Height) if err == nil { success = true } else if strings.Contains(err.Error(), ErrRetryParsingHotShotPayload.Error()) { log.Warn("retrying to verify batch poster signature", "err", err) - return nil, err + return err } else { log.Warn("failed to verify batch poster signature", "err", err) + // this is the case where there is an EspressoHeader and we failed, dont fall back + if !fallback { + return err + } } - if !success && s.espressoSGXVerifier != nil { - err = s.verifyLegacy(signature, userDataHashArr) - if err != nil { - log.Warn("failed to verify attestation quote", "err", err) + if !success { + if err := s.fallbackLegacyVerification(data, userDataHashArr); err != nil { + return err + } + } + return nil +} + +func (s *EspressoStreamer) verify(data []byte, userDataHashArr [32]byte, l1Height uint64, header *arbutil.EspressoHeader) error { + noHeader := header == nil + if !noHeader { + txType := header.TransactionType + switch txType { + case arbutil.BatchPosterSignedTxn: + if err := s.verifySignature(data, userDataHashArr, l1Height, noHeader); err != nil { + return err + } + // TODO: timeboost support + default: + return fmt.Errorf("failed to verify transaction, received unexpected transaction type: %d", txType) + } + } else if err := s.verifySignature(data, userDataHashArr, l1Height, noHeader); err != nil { + return err + } + return nil +} + +func (s *EspressoStreamer) parseEspressoTransaction(tx espressoTypes.Bytes, l1Height uint64) ([]*MessageWithMetadataAndPos, error) { + header := arbutil.ParseHotshotPayloadForHeader(tx) + signature, userDataHash, indices, messages, err := arbutil.ParseHotShotPayload(tx, header) + if err != nil { + if header != nil { + // in case somehow we parsed a header and there wasnt one, try again + header = nil + signature, userDataHash, indices, messages, err = arbutil.ParseHotShotPayload(tx, header) + if err != nil { + log.Warn("failed to parse hotshot payload", "err", err) + return nil, err + } + } else { + log.Warn("failed to parse hotshot payload", "err", err) return nil, err } + + } + if len(messages) == 0 { + return nil, ErrPayloadHadNoMessages + } + if len(userDataHash) != 32 { + log.Warn("user data hash is not 32 bytes") + return nil, ErrUserDataHashNot32Bytes + } + + err = s.verify(signature, [32]byte(userDataHash), l1Height, header) + if err != nil { + return nil, err } result := []*MessageWithMetadataAndPos{} diff --git a/system_tests/caff_node_test.go b/system_tests/caff_node_test.go index 9dacb15eab5..81821ab291a 100644 --- a/system_tests/caff_node_test.go +++ b/system_tests/caff_node_test.go @@ -638,7 +638,9 @@ func TestEspressoCaffNodeSGXVerifierShouldRetryWhenEncounterRPCError(t *testing. espressoStreamer.SetSGXVerifier(NewMockSgxTeeVerifier()) // Set this will cause the caff node to use the sgx verifier - espressoStreamer.SetBatcherAddressesFetcher(func(l1Height uint64) []common.Address { return []common.Address{{}} }) + espressoStreamer.SetBatcherAddressesFetcher(func(l1Height uint64) []common.Address { + return []common.Address{common.HexToAddress("0xb386a74Dcab67b66F8AC07B4f08365d37495Dd23")} + }) err = waitForWith(ctx, 10*time.Minute, 10*time.Second, func() bool { balance1 := builder2.L2.GetBalance(t, builder.L2Info.GetAddress("User16")) diff --git a/system_tests/espresso/generate/messages.go b/system_tests/espresso/generate/messages.go index 8ae7ab8d070..d7d3b01bff2 100644 --- a/system_tests/espresso/generate/messages.go +++ b/system_tests/espresso/generate/messages.go @@ -214,7 +214,8 @@ func ConvertEspressoTransactionsInBlockToMessages( // We can parse the transactions to get the messages // This is a mock function that simulates the parsing of the transaction // In a real scenario, this would be replaced with the actual parsing logic - _, _, _, messages, err := arbutil.ParseHotShotPayload(tx) + header := arbutil.ParseHotshotPayloadForHeader(tx) + _, _, _, messages, err := arbutil.ParseHotShotPayload(tx, header) if err != nil { return nil, fmt.Errorf("encountered error while parsing transaction: %w", err) }