Skip to content

Commit

Permalink
refactor marshalActions implementation (#1631)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsachiherman authored Oct 4, 2024
1 parent 2cb5530 commit fee360f
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 72 deletions.
7 changes: 1 addition & 6 deletions api/jsonrpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,7 @@ func (j *JSONRPCServer) SubmitTx(
if !rtx.Empty() {
return errors.New("tx has extra bytes")
}
msg, err := tx.Digest()
if err != nil {
// Should never occur because populated during unmarshal
return err
}
if err := tx.Auth.Verify(ctx, msg); err != nil {
if err := tx.Verify(ctx); err != nil {
return err
}
txID := tx.ID()
Expand Down
7 changes: 1 addition & 6 deletions api/ws/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,7 @@ func (w *WebSocketServer) MessageCallback() pubsub.Callback {

// Verify tx
if w.vm.GetVerifyAuth() {
msg, err := tx.Digest()
if err != nil {
// Should never occur because populated during unmarshal
return
}
if err := tx.Auth.Verify(ctx, msg); err != nil {
if err := tx.Verify(ctx); err != nil {
w.logger.Error("failed to verify sig",
zap.Error(err),
)
Expand Down
4 changes: 2 additions & 2 deletions chain/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,11 @@ func (b *StatefulBlock) populateTxs(ctx context.Context) error {

// Verify signature async
if b.vm.GetVerifyAuth() {
txDigest, err := tx.Digest()
unsignedTxBytes, err := tx.UnsignedBytes()
if err != nil {
return err
}
batchVerifier.Add(txDigest, tx.Auth)
batchVerifier.Add(unsignedTxBytes, tx.Auth)
}
}
return nil
Expand Down
129 changes: 82 additions & 47 deletions chain/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,68 +31,71 @@ var (
type Transaction struct {
Base *Base `json:"base"`

Actions []Action `json:"actions"`
Auth Auth `json:"auth"`

digest []byte
bytes []byte
size int
id ids.ID
stateKeys state.Keys
Actions Actions `json:"actions"`
Auth Auth `json:"auth"`

unsignedBytes []byte
bytes []byte
size int
id ids.ID
stateKeys state.Keys
}

func NewTx(base *Base, actions []Action) *Transaction {
func NewTx(base *Base, actions Actions) *Transaction {
return &Transaction{
Base: base,
Actions: actions,
}
}

func (t *Transaction) Digest() ([]byte, error) {
if len(t.digest) > 0 {
return t.digest, nil
// UnsignedBytes returns the byte slice representation of the unsigned tx
func (t *Transaction) UnsignedBytes() ([]byte, error) {
if len(t.unsignedBytes) > 0 {
return t.unsignedBytes, nil
}
size := t.Base.Size() + consts.Uint8Len
for _, action := range t.Actions {
actionSize, err := GetSize(action)
if err != nil {
return nil, err
}
size += consts.ByteLen + actionSize

actionsSize, err := t.Actions.Size()
if err != nil {
return nil, err
}
size += actionsSize

p := codec.NewWriter(size, consts.NetworkSizeLimit)
t.Base.Marshal(p)
p.PackByte(uint8(len(t.Actions)))
for _, action := range t.Actions {
p.PackByte(action.GetTypeID())
err := marshalInto(action, p)
if err != nil {
return nil, err
}
if err := t.marshal(p, false); err != nil {
return nil, err
}

return p.Bytes(), p.Err()
}

// Sign returns a new signed transaction with the unsigned tx copied from
// the original and a signature provided by the authFactory
func (t *Transaction) Sign(
factory AuthFactory,
actionRegistry ActionRegistry,
authRegistry AuthRegistry,
) (*Transaction, error) {
msg, err := t.Digest()
msg, err := t.UnsignedBytes()
if err != nil {
return nil, err
}
auth, err := factory.Sign(msg)
if err != nil {
return nil, err
}
t.Auth = auth

signedTransaction := Transaction{
Base: t.Base,
Actions: t.Actions,
Auth: auth,
}

// Ensure transaction is fully initialized and correct by reloading it from
// bytes
size := len(msg) + consts.ByteLen + t.Auth.Size()
size := len(msg) + consts.ByteLen + auth.Size()
p := codec.NewWriter(size, consts.NetworkSizeLimit)
if err := t.Marshal(p); err != nil {
if err := signedTransaction.Marshal(p); err != nil {
return nil, err
}
if err := p.Err(); err != nil {
Expand All @@ -102,6 +105,16 @@ func (t *Transaction) Sign(
return UnmarshalTx(p, actionRegistry, authRegistry)
}

// Verify that the transaction was signed correctly.
func (t *Transaction) Verify(ctx context.Context) error {
msg, err := t.UnsignedBytes()
if err != nil {
// Should never occur because populated during unmarshal
return err
}
return t.Auth.Verify(ctx, msg)
}

func (t *Transaction) Bytes() []byte { return t.bytes }

func (t *Transaction) Size() int { return t.size }
Expand Down Expand Up @@ -195,7 +208,7 @@ func (t *Transaction) Units(sm StateManager, r Rules) (fees.Dimensions, error) {
// to execute a transaction.
//
// This is typically used during transaction construction.
func EstimateUnits(r Rules, actions []Action, authFactory AuthFactory) (fees.Dimensions, error) {
func EstimateUnits(r Rules, actions Actions, authFactory AuthFactory) (fees.Dimensions, error) {
var (
bandwidth = uint64(BaseSize)
stateKeysMaxChunks = []uint16{} // TODO: preallocate
Expand Down Expand Up @@ -377,25 +390,47 @@ func (t *Transaction) Marshal(p *codec.Packer) error {
p.PackFixedBytes(t.bytes)
return p.Err()
}

return t.marshalActions(p)
return t.marshal(p, true)
}

func (t *Transaction) marshalActions(p *codec.Packer) error {
func (t *Transaction) marshal(p *codec.Packer, marshalSignature bool) error {
t.Base.Marshal(p)
p.PackByte(uint8(len(t.Actions)))
for _, action := range t.Actions {
actionID := action.GetTypeID()
p.PackByte(actionID)
if err := t.Actions.marshalInto(p); err != nil {
return err
}

if marshalSignature {
authID := t.Auth.GetTypeID()
p.PackByte(authID)
t.Auth.Marshal(p)
}
return p.Err()
}

type Actions []Action

func (a Actions) Size() (int, error) {
var size int
for _, action := range a {
actionSize, err := GetSize(action)
if err != nil {
return 0, err
}
size += consts.ByteLen + actionSize
}
return size, nil
}

func (a Actions) marshalInto(p *codec.Packer) error {
p.PackByte(uint8(len(a)))
for _, action := range a {
p.PackByte(action.GetTypeID())
err := marshalInto(action, p)
if err != nil {
return err
}
}
authID := t.Auth.GetTypeID()
p.PackByte(authID)
t.Auth.Marshal(p)
return p.Err()
return nil
}

func MarshalTxs(txs []*Transaction) ([]byte, error) {
Expand Down Expand Up @@ -448,7 +483,7 @@ func UnmarshalTx(
if err != nil {
return nil, fmt.Errorf("%w: could not unmarshal base", err)
}
actions, err := unmarshalActions(p, actionRegistry)
actions, err := UnmarshalActions(p, actionRegistry)
if err != nil {
return nil, fmt.Errorf("%w: could not unmarshal actions", err)
}
Expand All @@ -474,22 +509,22 @@ func UnmarshalTx(
return nil, p.Err()
}
codecBytes := p.Bytes()
tx.digest = codecBytes[start:digest]
tx.unsignedBytes = codecBytes[start:digest]
tx.bytes = codecBytes[start:p.Offset()] // ensure errors handled before grabbing memory
tx.size = len(tx.bytes)
tx.id = utils.ToID(tx.bytes)
return &tx, nil
}

func unmarshalActions(
func UnmarshalActions(
p *codec.Packer,
actionRegistry *codec.TypeParser[Action],
) ([]Action, error) {
) (Actions, error) {
actionCount := p.UnpackByte()
if actionCount == 0 {
return nil, fmt.Errorf("%w: no actions", ErrInvalidObject)
}
actions := []Action{}
actions := Actions{}
for i := uint8(0); i < actionCount; i++ {
action, err := actionRegistry.Unmarshal(p)
if err != nil {
Expand Down
36 changes: 31 additions & 5 deletions chain/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,45 @@ func TestMarshalUnmarshal(t *testing.T) {
err = actionRegistry.Register(&action2{}, unmarshalAction2)
require.NoError(err)

txBeforeSign := chain.Transaction{
Base: &chain.Base{
Timestamp: 1724315246000,
ChainID: [32]byte{1, 2, 3, 4, 5, 6, 7},
MaxFee: 1234567,
},
Actions: []chain.Action{
&mockTransferAction{
To: codec.Address{1, 2, 3, 4},
Value: 4,
Memo: []byte("hello"),
},
&mockTransferAction{
To: codec.Address{4, 5, 6, 7},
Value: 123,
Memo: []byte("world"),
},
&action2{
A: 2,
B: 4,
},
},
}

require.Nil(tx.Auth)
signedTx, err := tx.Sign(factory, actionRegistry, authRegistry)
require.NoError(err)

require.Equal(txBeforeSign, tx)
require.NotNil(signedTx.Auth)
require.Equal(len(signedTx.Actions), len(tx.Actions))
for i, action := range signedTx.Actions {
require.Equal(tx.Actions[i], action)
}

signedDigest, err := signedTx.Digest()
unsignedTxBytes, err := signedTx.UnsignedBytes()
require.NoError(err)
txDigest, err := tx.Digest()
originalUnsignedTxBytes, err := tx.UnsignedBytes()
require.NoError(err)

require.Equal(signedDigest, txDigest)
require.Len(signedDigest, 168)
require.Equal(unsignedTxBytes, originalUnsignedTxBytes)
require.Len(unsignedTxBytes, 168)
}
4 changes: 2 additions & 2 deletions internal/gossiper/proposer.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (g *Proposer) HandleAppGossip(ctx context.Context, nodeID ids.NodeID, msg [
var seen int
for _, tx := range txs {
// Verify signature async
txDigest, err := tx.Digest()
unsignedTxBytes, err := tx.UnsignedBytes()
if err != nil {
g.vm.Logger().Warn(
"unable to compute tx digest",
Expand All @@ -203,7 +203,7 @@ func (g *Proposer) HandleAppGossip(ctx context.Context, nodeID ids.NodeID, msg [
batchVerifier.Done(nil)
return nil
}
batchVerifier.Add(txDigest, tx.Auth)
batchVerifier.Add(unsignedTxBytes, tx.Auth)

// Add incoming txs to the cache to make
// sure we never gossip anything we receive (someone
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,9 @@ var _ = ginkgo.Describe("[Tx Processing]", ginkgo.Serial, func() {
)
// Must do manual construction to avoid `tx.Sign` error (would fail with
// 0 timestamp)
msg, err := tx.Digest()
unsignedTxBytes, err := tx.UnsignedBytes()
require.NoError(err)
auth, err := authFactory.Sign(msg)
auth, err := authFactory.Sign(unsignedTxBytes)
require.NoError(err)
tx.Auth = auth
p := codec.NewWriter(0, consts.MaxInt) // test codec growth
Expand Down
4 changes: 2 additions & 2 deletions vm/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -924,13 +924,13 @@ func (vm *VM) Submit(

// Verify auth if not already verified by caller
if verifyAuth && vm.config.VerifyAuth {
msg, err := tx.Digest()
unsignedTxBytes, err := tx.UnsignedBytes()
if err != nil {
// Should never fail
errs = append(errs, err)
continue
}
if err := tx.Auth.Verify(ctx, msg); err != nil {
if err := tx.Auth.Verify(ctx, unsignedTxBytes); err != nil {
// Failed signature verification is the only safe place to remove
// a transaction in listeners. Every other case may still end up with
// the transaction in a block.
Expand Down

0 comments on commit fee360f

Please sign in to comment.