Skip to content

Commit

Permalink
De-duplicate agent signing logic (#1228)
Browse files Browse the repository at this point in the history
* WIP: merge common signing logic into sign() util

* WIP: add Encoder interface for signed types

* Feat: add consts for salt values

* Cleanup: remove specific language from generic signEncoder func

* Fix: parity tests

* Cleanup: lint
  • Loading branch information
dwasse authored Aug 3, 2023
1 parent ce11472 commit c38b94c
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 147 deletions.
4 changes: 2 additions & 2 deletions agents/agents/guard/fraud.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (g Guard) shouldSubmitStateReport(ctx context.Context, snapshot *types.Frau
// isStateSlashable checks if a state is slashable, i.e. if the state is valid on the
// Origin, and if the agent is in a slashable status.
func (g Guard) isStateSlashable(ctx context.Context, state types.State, agent common.Address) (bool, error) {
statePayload, err := types.EncodeState(state)
statePayload, err := state.Encode()
if err != nil {
return false, fmt.Errorf("could not encode state: %w", err)
}
Expand Down Expand Up @@ -142,7 +142,7 @@ func (g Guard) handleValidAttestation(ctx context.Context, fraudAttestation *typ

// Verify each state in the snapshot.
for stateIndex, state := range snapshot.States() {
snapPayload, err := types.EncodeSnapshot(snapshot)
snapPayload, err := snapshot.Encode()
if err != nil {
return fmt.Errorf("could not encode snapshot: %w", err)
}
Expand Down
30 changes: 5 additions & 25 deletions agents/types/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package types

import (
"context"
"fmt"
"math/big"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/synapsecns/sanguine/core"
"github.com/synapsecns/sanguine/ethergo/signer/signer"
)

Expand All @@ -22,6 +20,7 @@ const (

// Attestation is the attestation interface.
type Attestation interface {
Encoder
// SnapshotRoot is the root of the Snapshot Merkle Tree.
SnapshotRoot() [32]byte
// DataHash is the agent root and SnapGasHash combined into a single hash.
Expand Down Expand Up @@ -75,33 +74,14 @@ func (a attestation) Timestamp() *big.Int {
return a.timestamp
}

//nolint:dupl
func (a attestation) SignAttestation(ctx context.Context, signer signer.Signer, valid bool) (signer.Signature, []byte, common.Hash, error) {
encodedAttestation, err := EncodeAttestation(a)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not encode attestation: %w", err)
}

var attestationSalt common.Hash
var attestationSalt string
if valid {
attestationSalt = crypto.Keccak256Hash([]byte("ATTESTATION_VALID_SALT"))
attestationSalt = AttestationValidSalt
} else {
attestationSalt = crypto.Keccak256Hash([]byte("ATTESTATION_INVALID_SALT"))
}

hashedEncodedAttestation := crypto.Keccak256Hash(encodedAttestation).Bytes()
toSign := append(attestationSalt.Bytes(), hashedEncodedAttestation...)

hashedAttestation, err := HashRawBytes(toSign)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not hash attestation: %w", err)
}

signature, err := signer.SignMessage(ctx, core.BytesToSlice(hashedAttestation), false)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not sign attestation: %w", err)
attestationSalt = AttestationInvalidSalt
}
return signature, encodedAttestation, hashedAttestation, nil
return signEncoder(ctx, signer, a, attestationSalt)
}

// GetAttestationDataHash generates the data hash from the agent root and SnapGasHash.
Expand Down
67 changes: 36 additions & 31 deletions agents/types/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ const (
uint40Len = 5
)

// Encoder encodes a type to bytes.
type Encoder interface {
Encode() ([]byte, error)
}

// EncodeGasData encodes a gasdata.
func EncodeGasData(gasData GasData) ([]byte, error) {
b := make([]byte, 0)
Expand Down Expand Up @@ -106,23 +111,23 @@ func DecodeChainGas(toDecode []byte) (ChainGas, error) {
}, nil
}

// EncodeState encodes a state.
func EncodeState(state State) ([]byte, error) {
// Encode encodes a state.
func (s state) Encode() ([]byte, error) {
b := make([]byte, 0)
originBytes := make([]byte, uint32Len)
nonceBytes := make([]byte, uint32Len)

binary.BigEndian.PutUint32(originBytes, state.Origin())
binary.BigEndian.PutUint32(nonceBytes, state.Nonce())
root := state.Root()
binary.BigEndian.PutUint32(originBytes, s.Origin())
binary.BigEndian.PutUint32(nonceBytes, s.Nonce())
root := s.Root()

// Note that since we are packing an 8 byte (int64) number into 5 bytes, we need to
// ensure that the result does not exceed the expected byte length for a valid State.
blockNumberBytes := math.PaddedBigBytes(state.BlockNumber(), uint40Len)
// ensure that the result does not exceed the expected byte length for a valid s.
blockNumberBytes := math.PaddedBigBytes(s.BlockNumber(), uint40Len)
if len(blockNumberBytes) != uint40Len {
return nil, fmt.Errorf("invalid block number length, expected %d, got %d", uint40Len, len(blockNumberBytes))
}
timestampBytes := math.PaddedBigBytes(state.Timestamp(), uint40Len)
timestampBytes := math.PaddedBigBytes(s.Timestamp(), uint40Len)
if len(timestampBytes) != uint40Len {
return nil, fmt.Errorf("invalid timestamp length, expected %d, got %d", uint40Len, len(timestampBytes))
}
Expand All @@ -133,7 +138,7 @@ func EncodeState(state State) ([]byte, error) {
b = append(b, blockNumberBytes...)
b = append(b, timestampBytes...)

gasDataEncoded, err := EncodeGasData(state.GasData())
gasDataEncoded, err := EncodeGasData(s.GasData())
if err != nil {
return nil, fmt.Errorf("failed to encode gas data for state %w", err)
}
Expand Down Expand Up @@ -173,9 +178,9 @@ func DecodeState(toDecode []byte) (State, error) {
}, nil
}

// EncodeSnapshot encodes a snapshot.
func EncodeSnapshot(snapshot Snapshot) ([]byte, error) {
states := snapshot.States()
// Encode encodes a snapshot.
func (s snapshot) Encode() ([]byte, error) {
states := s.States()

if len(states) == 0 {
return nil, fmt.Errorf("no states to encode")
Expand All @@ -184,7 +189,7 @@ func EncodeSnapshot(snapshot Snapshot) ([]byte, error) {
encodedStates := make([]byte, 0)

for _, state := range states {
encodedState, err := EncodeState(state)
encodedState, err := state.Encode()
if err != nil {
return nil, fmt.Errorf("could not encode state: %w", err)
}
Expand Down Expand Up @@ -215,22 +220,22 @@ func DecodeSnapshot(toDecode []byte) (Snapshot, error) {
}, nil
}

// EncodeAttestation encodes an attestation.
func EncodeAttestation(attestation Attestation) ([]byte, error) {
// Encode encodes an attestation.
func (a attestation) Encode() ([]byte, error) {
b := make([]byte, 0)
nonceBytes := make([]byte, uint32Len)

binary.BigEndian.PutUint32(nonceBytes, attestation.Nonce())
snapshotRoot := attestation.SnapshotRoot()
dataHash := attestation.DataHash()
binary.BigEndian.PutUint32(nonceBytes, a.Nonce())
snapshotRoot := a.SnapshotRoot()
dataHash := a.DataHash()

// Note that since we are packing an 8 byte (int64) number into 5 bytes, we need to
// ensure that the result does not exceed the expected byte length for a valid Attestation.
blockNumberBytes := math.PaddedBigBytes(attestation.BlockNumber(), uint40Len)
// ensure that the result does not exceed the expected byte length for a valid a.
blockNumberBytes := math.PaddedBigBytes(a.BlockNumber(), uint40Len)
if len(blockNumberBytes) != uint40Len {
return nil, fmt.Errorf("invalid block number length, expected %d, got %d", uint40Len, len(blockNumberBytes))
}
timestampBytes := math.PaddedBigBytes(attestation.Timestamp(), uint40Len)
timestampBytes := math.PaddedBigBytes(a.Timestamp(), uint40Len)
if len(timestampBytes) != uint40Len {
return nil, fmt.Errorf("invalid timestamp length, expected %d, got %d", uint40Len, len(timestampBytes))
}
Expand Down Expand Up @@ -558,26 +563,26 @@ func DecodeAgentStatus(toDecode []byte) (AgentStatus, error) {
}, nil
}

// EncodeReceipt encodes an receipt.
func EncodeReceipt(receipt Receipt) ([]byte, error) {
// Encode encodes an receipt.
func (r receipt) Encode() ([]byte, error) {
b := make([]byte, 0)
originBytes := make([]byte, uint32Len)
binary.BigEndian.PutUint32(originBytes, receipt.Origin())
binary.BigEndian.PutUint32(originBytes, r.Origin())

destBytes := make([]byte, uint32Len)
binary.BigEndian.PutUint32(destBytes, receipt.Destination())
binary.BigEndian.PutUint32(destBytes, r.Destination())

messageHashBytes := receipt.MessageHash()
snapshotRootBytes := receipt.SnapshotRoot()
messageHashBytes := r.MessageHash()
snapshotRootBytes := r.SnapshotRoot()

b = append(b, originBytes...)
b = append(b, destBytes...)
b = append(b, messageHashBytes[:]...)
b = append(b, snapshotRootBytes[:]...)
b = append(b, []byte{receipt.StateIndex()}...)
b = append(b, receipt.AttestationNotary().Bytes()...)
b = append(b, receipt.FirstExecutor().Bytes()...)
b = append(b, receipt.FinalExecutor().Bytes()...)
b = append(b, []byte{r.StateIndex()}...)
b = append(b, r.AttestationNotary().Bytes()...)
b = append(b, r.FirstExecutor().Bytes()...)
b = append(b, r.FinalExecutor().Bytes()...)

return b, nil
}
Expand Down
17 changes: 9 additions & 8 deletions agents/types/parity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package types_test
import (
"context"
"crypto/rand"
"github.com/synapsecns/sanguine/core/testsuite"
"math/big"
"testing"
"time"

"github.com/synapsecns/sanguine/core/testsuite"

"github.com/brianvoe/gofakeit/v6"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -173,7 +174,7 @@ func TestEncodeStateParity(t *testing.T) {

gasData := types.NewGasData(gasPrice, dataPrice, execBuffer, amortAttCost, etherPrice, markup)

goFormattedData, err := types.EncodeState(types.NewState(rootB32, origin, nonce, blockNumber, timestamp, gasData))
goFormattedData, err := types.NewState(rootB32, origin, nonce, blockNumber, timestamp, gasData).Encode()
Nil(t, err)
Equal(t, contractData, goFormattedData)

Expand Down Expand Up @@ -212,7 +213,7 @@ func TestEncodeReceiptParity(t *testing.T) {

receipt := types.NewReceipt(origin, destination, messageHash, snapshotRoot, stateIndex, attNotary, firstExecutor, finalExecutor)

encodedReceipt, err := types.EncodeReceipt(receipt)
encodedReceipt, err := receipt.Encode()
Nil(t, err)

solEncodedReceipt, err := receiptHarness.FormatReceipt(&bind.CallOpts{Context: ctx}, origin, destination, messageHash, snapshotRoot, stateIndex, attNotary, firstExecutor, finalExecutor)
Expand Down Expand Up @@ -273,17 +274,17 @@ func TestEncodeSnapshotParity(t *testing.T) {
stateB := types.NewState(rootB, originB, nonceB, blockNumberB, timestampB, gasDataB)

var statesAB [][]byte
stateABytes, err := types.EncodeState(stateA)
stateABytes, err := stateA.Encode()
Nil(t, err)
statesAB = append(statesAB, stateABytes)
stateBBytes, err := types.EncodeState(stateB)
stateBBytes, err := stateB.Encode()
Nil(t, err)
statesAB = append(statesAB, stateBBytes)

contractData, err := snapshotContract.FormatSnapshot(&bind.CallOpts{Context: ctx}, statesAB)
Nil(t, err)

goFormattedData, err := types.EncodeSnapshot(types.NewSnapshot([]types.State{stateA, stateB}))
goFormattedData, err := types.NewSnapshot([]types.State{stateA, stateB}).Encode()
Nil(t, err)

Equal(t, contractData, goFormattedData)
Expand Down Expand Up @@ -364,7 +365,7 @@ func TestEncodeAttestationParity(t *testing.T) {
contractData, err := attestationContract.FormatAttestation(&bind.CallOpts{Context: ctx}, rootB32, dataHashB32, nonce, blockNumber, timestamp)
Nil(t, err)

goFormattedData, err := types.EncodeAttestation(types.NewAttestation(rootB32, dataHashB32, nonce, blockNumber, timestamp))
goFormattedData, err := types.NewAttestation(rootB32, dataHashB32, nonce, blockNumber, timestamp).Encode()
Nil(t, err)

Equal(t, contractData, goFormattedData)
Expand Down Expand Up @@ -397,7 +398,7 @@ func TestEncodeAttestationParity(t *testing.T) {

Equal(t, contractDataHashFromVals, attestationDataHash)

encodedDataHashAttestation, err := types.EncodeAttestation(attestation)
encodedDataHashAttestation, err := attestation.Encode()
Nil(t, err)

contractDataHashFromAtt, err := attestationContract.DataHash0(&bind.CallOpts{Context: ctx}, encodedDataHashAttestation)
Expand Down
31 changes: 5 additions & 26 deletions agents/types/receipt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package types

import (
"context"
"fmt"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/synapsecns/sanguine/core"
"github.com/synapsecns/sanguine/ethergo/signer/signer"
)

Expand All @@ -24,6 +21,7 @@ const (

// Receipt is the receipt interface.
type Receipt interface {
Encoder
// Origin is the origin of the receipt.
Origin() uint32
// Destination is the destination of the receipt.
Expand Down Expand Up @@ -101,31 +99,12 @@ func (r receipt) FinalExecutor() common.Address {
return r.finalExecutor
}

//nolint:dupl
func (r receipt) SignReceipt(ctx context.Context, signer signer.Signer, valid bool) (signer.Signature, []byte, common.Hash, error) {
encodedReceipt, err := EncodeReceipt(r)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not encode receipt: %w", err)
}

var receiptSalt common.Hash
var receiptSalt string
if valid {
receiptSalt = crypto.Keccak256Hash([]byte("RECEIPT_VALID_SALT"))
receiptSalt = ReceiptValidSalt
} else {
receiptSalt = crypto.Keccak256Hash([]byte("RECEIPT_INVALID_SALT"))
}

hashedEncodedReceipt := crypto.Keccak256Hash(encodedReceipt).Bytes()
toSign := append(receiptSalt.Bytes(), hashedEncodedReceipt...)

hashedReceipt, err := HashRawBytes(toSign)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not hash receipt: %w", err)
}

signature, err := signer.SignMessage(ctx, core.BytesToSlice(hashedReceipt), false)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not sign receipt: %w", err)
receiptSalt = ReceiptInvalidSalt
}
return signature, encodedReceipt, hashedReceipt, nil
return signEncoder(ctx, signer, r, receiptSalt)
}
21 changes: 21 additions & 0 deletions agents/types/salt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package types

// Note: these salt values are taken from contracts/libs/Constants.sol.

// AttestationValidSalt is the salt for ATTESTATION_VALID_SALT.
const AttestationValidSalt = "ATTESTATION_VALID_SALT"

// AttestationInvalidSalt is the salt for ATTESTATION_INVALID_SALT.
const AttestationInvalidSalt = "ATTESTATION_INVALID_SALT"

// ReceiptValidSalt is the salt for RECEIPT_VALID_SALT.
const ReceiptValidSalt = "RECEIPT_VALID_SALT"

// ReceiptInvalidSalt is the salt for RECEIPT_INVALID_SALT.
const ReceiptInvalidSalt = "RECEIPT_INVALID_SALT"

// SnapshotValidSalt is the salt for SNAPSHOT_VALID_SALT.
const SnapshotValidSalt = "SNAPSHOT_VALID_SALT"

// StateInvalidSalt is the salt for STATE_INVALID_SALT.
const StateInvalidSalt = "STATE_INVALID_SALT"
Loading

0 comments on commit c38b94c

Please sign in to comment.