Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

De-duplicate agent signing logic #1228

Merged
merged 6 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions agents/agents/guard/fraud.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (g Guard) shouldSubmitStateReport(ctx context.Context, snapshot *types.Frau
// isStateSlashable checks if a state is slashable, i.e. if the state is valid on the
// Origin, and if the agent is in a slashable status.
func (g Guard) isStateSlashable(ctx context.Context, state types.State, agent common.Address) (bool, error) {
statePayload, err := types.EncodeState(state)
statePayload, err := state.Encode()
if err != nil {
return false, fmt.Errorf("could not encode state: %w", err)
}
Expand Down Expand Up @@ -142,7 +142,7 @@ func (g Guard) handleValidAttestation(ctx context.Context, fraudAttestation *typ

// Verify each state in the snapshot.
for stateIndex, state := range snapshot.States() {
snapPayload, err := types.EncodeSnapshot(snapshot)
snapPayload, err := snapshot.Encode()
if err != nil {
return fmt.Errorf("could not encode snapshot: %w", err)
}
Expand Down
30 changes: 5 additions & 25 deletions agents/types/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package types

import (
"context"
"fmt"
"math/big"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/synapsecns/sanguine/core"
"github.com/synapsecns/sanguine/ethergo/signer/signer"
)

Expand All @@ -22,6 +20,7 @@ const (

// Attestation is the attestation interface.
type Attestation interface {
Encoder
// SnapshotRoot is the root of the Snapshot Merkle Tree.
SnapshotRoot() [32]byte
// DataHash is the agent root and SnapGasHash combined into a single hash.
Expand Down Expand Up @@ -75,33 +74,14 @@ func (a attestation) Timestamp() *big.Int {
return a.timestamp
}

//nolint:dupl
func (a attestation) SignAttestation(ctx context.Context, signer signer.Signer, valid bool) (signer.Signature, []byte, common.Hash, error) {
encodedAttestation, err := EncodeAttestation(a)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not encode attestation: %w", err)
}

var attestationSalt common.Hash
var attestationSalt string
if valid {
attestationSalt = crypto.Keccak256Hash([]byte("ATTESTATION_VALID_SALT"))
attestationSalt = AttestationValidSalt
} else {
attestationSalt = crypto.Keccak256Hash([]byte("ATTESTATION_INVALID_SALT"))
}

hashedEncodedAttestation := crypto.Keccak256Hash(encodedAttestation).Bytes()
toSign := append(attestationSalt.Bytes(), hashedEncodedAttestation...)

hashedAttestation, err := HashRawBytes(toSign)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not hash attestation: %w", err)
}

signature, err := signer.SignMessage(ctx, core.BytesToSlice(hashedAttestation), false)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not sign attestation: %w", err)
attestationSalt = AttestationInvalidSalt
}
return signature, encodedAttestation, hashedAttestation, nil
return signEncoder(ctx, signer, a, attestationSalt)
}

// GetAttestationDataHash generates the data hash from the agent root and SnapGasHash.
Expand Down
67 changes: 36 additions & 31 deletions agents/types/encoder.go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thoughts on deleting this file and moving each specific encoder to its specific type file?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think eventually that's the right approach, if we like the approach I can go ahead and implement for all the other types (although we still have to define the Encoder interface somewhere).

There's a question of how to implement the decoder functions, I think ideally we could do something like this:

type Encoder interface {
  Encode() ([]byte, error)
  Decode([]byte) error
}

Where the new pattern to decode would be:

receipt := new(Receipt)
err = receipt.Decode(rawBytes)

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ const (
uint40Len = 5
)

// Encoder encodes a type to bytes.
type Encoder interface {
Encode() ([]byte, error)
}

// EncodeGasData encodes a gasdata.
func EncodeGasData(gasData GasData) ([]byte, error) {
b := make([]byte, 0)
Expand Down Expand Up @@ -106,23 +111,23 @@ func DecodeChainGas(toDecode []byte) (ChainGas, error) {
}, nil
}

// EncodeState encodes a state.
func EncodeState(state State) ([]byte, error) {
// Encode encodes a state.
func (s state) Encode() ([]byte, error) {
b := make([]byte, 0)
originBytes := make([]byte, uint32Len)
nonceBytes := make([]byte, uint32Len)

binary.BigEndian.PutUint32(originBytes, state.Origin())
binary.BigEndian.PutUint32(nonceBytes, state.Nonce())
root := state.Root()
binary.BigEndian.PutUint32(originBytes, s.Origin())
binary.BigEndian.PutUint32(nonceBytes, s.Nonce())
root := s.Root()

// Note that since we are packing an 8 byte (int64) number into 5 bytes, we need to
// ensure that the result does not exceed the expected byte length for a valid State.
blockNumberBytes := math.PaddedBigBytes(state.BlockNumber(), uint40Len)
// ensure that the result does not exceed the expected byte length for a valid s.
blockNumberBytes := math.PaddedBigBytes(s.BlockNumber(), uint40Len)
if len(blockNumberBytes) != uint40Len {
return nil, fmt.Errorf("invalid block number length, expected %d, got %d", uint40Len, len(blockNumberBytes))
}
timestampBytes := math.PaddedBigBytes(state.Timestamp(), uint40Len)
timestampBytes := math.PaddedBigBytes(s.Timestamp(), uint40Len)
if len(timestampBytes) != uint40Len {
return nil, fmt.Errorf("invalid timestamp length, expected %d, got %d", uint40Len, len(timestampBytes))
}
Expand All @@ -133,7 +138,7 @@ func EncodeState(state State) ([]byte, error) {
b = append(b, blockNumberBytes...)
b = append(b, timestampBytes...)

gasDataEncoded, err := EncodeGasData(state.GasData())
gasDataEncoded, err := EncodeGasData(s.GasData())
if err != nil {
return nil, fmt.Errorf("failed to encode gas data for state %w", err)
}
Expand Down Expand Up @@ -173,9 +178,9 @@ func DecodeState(toDecode []byte) (State, error) {
}, nil
}

// EncodeSnapshot encodes a snapshot.
func EncodeSnapshot(snapshot Snapshot) ([]byte, error) {
states := snapshot.States()
// Encode encodes a snapshot.
func (s snapshot) Encode() ([]byte, error) {
states := s.States()

if len(states) == 0 {
return nil, fmt.Errorf("no states to encode")
Expand All @@ -184,7 +189,7 @@ func EncodeSnapshot(snapshot Snapshot) ([]byte, error) {
encodedStates := make([]byte, 0)

for _, state := range states {
encodedState, err := EncodeState(state)
encodedState, err := state.Encode()
if err != nil {
return nil, fmt.Errorf("could not encode state: %w", err)
}
Expand Down Expand Up @@ -215,22 +220,22 @@ func DecodeSnapshot(toDecode []byte) (Snapshot, error) {
}, nil
}

// EncodeAttestation encodes an attestation.
func EncodeAttestation(attestation Attestation) ([]byte, error) {
// Encode encodes an attestation.
func (a attestation) Encode() ([]byte, error) {
b := make([]byte, 0)
nonceBytes := make([]byte, uint32Len)

binary.BigEndian.PutUint32(nonceBytes, attestation.Nonce())
snapshotRoot := attestation.SnapshotRoot()
dataHash := attestation.DataHash()
binary.BigEndian.PutUint32(nonceBytes, a.Nonce())
snapshotRoot := a.SnapshotRoot()
dataHash := a.DataHash()

// Note that since we are packing an 8 byte (int64) number into 5 bytes, we need to
// ensure that the result does not exceed the expected byte length for a valid Attestation.
blockNumberBytes := math.PaddedBigBytes(attestation.BlockNumber(), uint40Len)
// ensure that the result does not exceed the expected byte length for a valid a.
blockNumberBytes := math.PaddedBigBytes(a.BlockNumber(), uint40Len)
if len(blockNumberBytes) != uint40Len {
return nil, fmt.Errorf("invalid block number length, expected %d, got %d", uint40Len, len(blockNumberBytes))
}
timestampBytes := math.PaddedBigBytes(attestation.Timestamp(), uint40Len)
timestampBytes := math.PaddedBigBytes(a.Timestamp(), uint40Len)
if len(timestampBytes) != uint40Len {
return nil, fmt.Errorf("invalid timestamp length, expected %d, got %d", uint40Len, len(timestampBytes))
}
Expand Down Expand Up @@ -558,26 +563,26 @@ func DecodeAgentStatus(toDecode []byte) (AgentStatus, error) {
}, nil
}

// EncodeReceipt encodes an receipt.
func EncodeReceipt(receipt Receipt) ([]byte, error) {
// Encode encodes an receipt.
func (r receipt) Encode() ([]byte, error) {
b := make([]byte, 0)
originBytes := make([]byte, uint32Len)
binary.BigEndian.PutUint32(originBytes, receipt.Origin())
binary.BigEndian.PutUint32(originBytes, r.Origin())

destBytes := make([]byte, uint32Len)
binary.BigEndian.PutUint32(destBytes, receipt.Destination())
binary.BigEndian.PutUint32(destBytes, r.Destination())

messageHashBytes := receipt.MessageHash()
snapshotRootBytes := receipt.SnapshotRoot()
messageHashBytes := r.MessageHash()
snapshotRootBytes := r.SnapshotRoot()

b = append(b, originBytes...)
b = append(b, destBytes...)
b = append(b, messageHashBytes[:]...)
b = append(b, snapshotRootBytes[:]...)
b = append(b, []byte{receipt.StateIndex()}...)
b = append(b, receipt.AttestationNotary().Bytes()...)
b = append(b, receipt.FirstExecutor().Bytes()...)
b = append(b, receipt.FinalExecutor().Bytes()...)
b = append(b, []byte{r.StateIndex()}...)
b = append(b, r.AttestationNotary().Bytes()...)
b = append(b, r.FirstExecutor().Bytes()...)
b = append(b, r.FinalExecutor().Bytes()...)

return b, nil
}
Expand Down
31 changes: 5 additions & 26 deletions agents/types/receipt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ package types

import (
"context"
"fmt"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/synapsecns/sanguine/core"
"github.com/synapsecns/sanguine/ethergo/signer/signer"
)

Expand All @@ -24,6 +21,7 @@ const (

// Receipt is the receipt interface.
type Receipt interface {
Encoder
// Origin is the origin of the receipt.
Origin() uint32
// Destination is the destination of the receipt.
Expand Down Expand Up @@ -101,31 +99,12 @@ func (r receipt) FinalExecutor() common.Address {
return r.finalExecutor
}

//nolint:dupl
func (r receipt) SignReceipt(ctx context.Context, signer signer.Signer, valid bool) (signer.Signature, []byte, common.Hash, error) {
encodedReceipt, err := EncodeReceipt(r)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not encode receipt: %w", err)
}

var receiptSalt common.Hash
var receiptSalt string
if valid {
receiptSalt = crypto.Keccak256Hash([]byte("RECEIPT_VALID_SALT"))
receiptSalt = ReceiptValidSalt
} else {
receiptSalt = crypto.Keccak256Hash([]byte("RECEIPT_INVALID_SALT"))
}

hashedEncodedReceipt := crypto.Keccak256Hash(encodedReceipt).Bytes()
toSign := append(receiptSalt.Bytes(), hashedEncodedReceipt...)

hashedReceipt, err := HashRawBytes(toSign)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not hash receipt: %w", err)
}

signature, err := signer.SignMessage(ctx, core.BytesToSlice(hashedReceipt), false)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not sign receipt: %w", err)
receiptSalt = ReceiptInvalidSalt
}
return signature, encodedReceipt, hashedReceipt, nil
return signEncoder(ctx, signer, r, receiptSalt)
}
21 changes: 21 additions & 0 deletions agents/types/salt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package types

// Note: these salt values are taken from contracts/libs/Constants.sol.

// AttestationValidSalt is the salt for ATTESTATION_VALID_SALT.
const AttestationValidSalt = "ATTESTATION_VALID_SALT"

// AttestationInvalidSalt is the salt for ATTESTATION_INVALID_SALT.
const AttestationInvalidSalt = "ATTESTATION_INVALID_SALT"

// ReceiptValidSalt is the salt for RECEIPT_VALID_SALT.
const ReceiptValidSalt = "RECEIPT_VALID_SALT"

// ReceiptInvalidSalt is the salt for RECEIPT_INVALID_SALT.
const ReceiptInvalidSalt = "RECEIPT_INVALID_SALT"

// SnapshotValidSalt is the salt for SNAPSHOT_VALID_SALT.
const SnapshotValidSalt = "SNAPSHOT_VALID_SALT"

// StateInvalidSalt is the salt for STATE_INVALID_SALT.
const StateInvalidSalt = "STATE_INVALID_SALT"
26 changes: 2 additions & 24 deletions agents/types/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@ import (
"fmt"
"math"

"github.com/ethereum/go-ethereum/crypto"

"github.com/ethereum/go-ethereum/common"
"github.com/synapsecns/sanguine/core"
"github.com/synapsecns/sanguine/core/merkle"
"github.com/synapsecns/sanguine/ethergo/signer/signer"
)

// Snapshot is the snapshot interface.
type Snapshot interface {
Encoder
// States are the states of the snapshot.
States() []State

Expand Down Expand Up @@ -76,28 +74,8 @@ func (s snapshot) TreeHeight() uint32 {
return uint32(math.Log2(float64(len(s.states) * 2)))
}

//nolint:dupl
func (s snapshot) SignSnapshot(ctx context.Context, signer signer.Signer) (signer.Signature, []byte, common.Hash, error) {
encodedSnapshot, err := EncodeSnapshot(s)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not encode snapshot: %w", err)
}

snapshotSalt := crypto.Keccak256Hash([]byte("SNAPSHOT_VALID_SALT"))

hashedEncodedSnapshot := crypto.Keccak256Hash(encodedSnapshot).Bytes()
toSign := append(snapshotSalt.Bytes(), hashedEncodedSnapshot...)

hashedSnapshot, err := HashRawBytes(toSign)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not hash snapshot: %w", err)
}

signature, err := signer.SignMessage(ctx, core.BytesToSlice(hashedSnapshot), false)
if err != nil {
return nil, nil, common.Hash{}, fmt.Errorf("could not sign snapshot: %w", err)
}
return signature, encodedSnapshot, hashedSnapshot, nil
return signEncoder(ctx, signer, s, SnapshotValidSalt)
}

var _ Snapshot = &snapshot{}
Loading