From 9cfa4c965d8403e034811d916e954da3a6f558f6 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Sun, 21 Jul 2024 14:14:11 -0400 Subject: [PATCH 1/2] Fees are calculated over unsigned txs --- vms/platformvm/service_test.go | 2 +- .../txs/executor/staker_tx_verification.go | 14 ++--- .../txs/executor/standard_tx_executor.go | 12 ++-- vms/platformvm/txs/fee/calculator.go | 2 +- vms/platformvm/txs/fee/calculator_test.go | 4 +- vms/platformvm/txs/fee/static_calculator.go | 56 +++++++++++-------- vms/platformvm/txs/txstest/context.go | 4 +- 7 files changed, 52 insertions(+), 42 deletions(-) diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index ccda2c6c054b..97d7d0466894 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -377,7 +377,7 @@ func TestGetBalance(t *testing.T) { feeCalculator, err := state.PickFeeCalculator(&service.vm.Config, service.vm.state) require.NoError(err) - createSubnetFee, err := feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateSubnetTx{}}) + createSubnetFee, err := feeCalculator.CalculateFee(&txs.CreateSubnetTx{}) require.NoError(err) // Ensure GetStake is correct for each of the genesis validators diff --git a/vms/platformvm/txs/executor/staker_tx_verification.go b/vms/platformvm/txs/executor/staker_tx_verification.go index 9b9876e3ccca..b84ad9d6e2b1 100644 --- a/vms/platformvm/txs/executor/staker_tx_verification.go +++ b/vms/platformvm/txs/executor/staker_tx_verification.go @@ -165,7 +165,7 @@ func verifyAddValidatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return nil, err } @@ -258,7 +258,7 @@ func verifyAddSubnetValidatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return err } @@ -338,7 +338,7 @@ func verifyRemoveSubnetValidatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return nil, false, err } @@ -458,7 +458,7 @@ func verifyAddDelegatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return nil, err } @@ -580,7 +580,7 @@ func verifyAddPermissionlessValidatorTx( copy(outs[len(tx.Outs):], tx.StakeOuts) // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return err } @@ -727,7 +727,7 @@ func verifyAddPermissionlessDelegatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return err } @@ -787,7 +787,7 @@ func verifyTransferSubnetOwnershipTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return err } diff --git a/vms/platformvm/txs/executor/standard_tx_executor.go b/vms/platformvm/txs/executor/standard_tx_executor.go index 9ec40e506a1b..58eeb71e1fea 100644 --- a/vms/platformvm/txs/executor/standard_tx_executor.go +++ b/vms/platformvm/txs/executor/standard_tx_executor.go @@ -70,7 +70,7 @@ func (e *StandardTxExecutor) CreateChainTx(tx *txs.CreateChainTx) error { } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -119,7 +119,7 @@ func (e *StandardTxExecutor) CreateSubnetTx(tx *txs.CreateSubnetTx) error { } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -203,7 +203,7 @@ func (e *StandardTxExecutor) ImportTx(tx *txs.ImportTx) error { copy(ins[len(tx.Ins):], tx.ImportedInputs) // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -263,7 +263,7 @@ func (e *StandardTxExecutor) ExportTx(tx *txs.ExportTx) error { } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -457,7 +457,7 @@ func (e *StandardTxExecutor) TransformSubnetTx(tx *txs.TransformSubnetTx) error } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -588,7 +588,7 @@ func (e *StandardTxExecutor) BaseTx(tx *txs.BaseTx) error { } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } diff --git a/vms/platformvm/txs/fee/calculator.go b/vms/platformvm/txs/fee/calculator.go index f33db9c1520f..511bb34ed3c2 100644 --- a/vms/platformvm/txs/fee/calculator.go +++ b/vms/platformvm/txs/fee/calculator.go @@ -7,5 +7,5 @@ import "github.com/ava-labs/avalanchego/vms/platformvm/txs" // Calculator is the interfaces that any fee Calculator must implement type Calculator interface { - CalculateFee(tx *txs.Tx) (uint64, error) + CalculateFee(tx txs.UnsignedTx) (uint64, error) } diff --git a/vms/platformvm/txs/fee/calculator_test.go b/vms/platformvm/txs/fee/calculator_test.go index 454072e9df8d..656b9ca0889f 100644 --- a/vms/platformvm/txs/fee/calculator_test.go +++ b/vms/platformvm/txs/fee/calculator_test.go @@ -187,9 +187,9 @@ func TestTxFees(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - uTx := tt.unsignedTx() + tx := tt.unsignedTx() fc := NewStaticCalculator(feeTestsDefaultCfg, upgrades, tt.chainTime) - fee, err := fc.CalculateFee(&txs.Tx{Unsigned: uTx}) + fee, err := fc.CalculateFee(tx) require.NoError(t, err) require.Equal(t, tt.expected, fee) }) diff --git a/vms/platformvm/txs/fee/static_calculator.go b/vms/platformvm/txs/fee/static_calculator.go index 7bfb5cf799a0..f424d30220d0 100644 --- a/vms/platformvm/txs/fee/static_calculator.go +++ b/vms/platformvm/txs/fee/static_calculator.go @@ -13,7 +13,7 @@ import ( var ( _ Calculator = (*staticCalculator)(nil) - _ txs.Visitor = (*staticCalculator)(nil) + _ txs.Visitor = (*staticVisitor)(nil) ) func NewStaticCalculator( @@ -29,37 +29,47 @@ func NewStaticCalculator( } type staticCalculator struct { - // inputs staticCfg StaticConfig upgrades upgrade.Config time time.Time +} - // outputs of visitor execution - fee uint64 +func (c *staticCalculator) CalculateFee(tx txs.UnsignedTx) (uint64, error) { + v := staticVisitor{ + staticCfg: c.staticCfg, + upgrades: c.upgrades, + time: c.time, + } + err := tx.Visit(&v) + return v.fee, err } -func (c *staticCalculator) CalculateFee(tx *txs.Tx) (uint64, error) { - c.fee = 0 // zero fee among different calculateFee invocations (unlike gas which gets cumulated) - err := tx.Unsigned.Visit(c) - return c.fee, err +type staticVisitor struct { + // inputs + staticCfg StaticConfig + upgrades upgrade.Config + time time.Time + + // outputs + fee uint64 } -func (c *staticCalculator) AddValidatorTx(*txs.AddValidatorTx) error { +func (c *staticVisitor) AddValidatorTx(*txs.AddValidatorTx) error { c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee return nil } -func (c *staticCalculator) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error { +func (c *staticVisitor) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error { c.fee = c.staticCfg.AddSubnetValidatorFee return nil } -func (c *staticCalculator) AddDelegatorTx(*txs.AddDelegatorTx) error { +func (c *staticVisitor) AddDelegatorTx(*txs.AddDelegatorTx) error { c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee return nil } -func (c *staticCalculator) CreateChainTx(*txs.CreateChainTx) error { +func (c *staticVisitor) CreateChainTx(*txs.CreateChainTx) error { if c.upgrades.IsApricotPhase3Activated(c.time) { c.fee = c.staticCfg.CreateBlockchainTxFee } else { @@ -68,7 +78,7 @@ func (c *staticCalculator) CreateChainTx(*txs.CreateChainTx) error { return nil } -func (c *staticCalculator) CreateSubnetTx(*txs.CreateSubnetTx) error { +func (c *staticVisitor) CreateSubnetTx(*txs.CreateSubnetTx) error { if c.upgrades.IsApricotPhase3Activated(c.time) { c.fee = c.staticCfg.CreateSubnetTxFee } else { @@ -77,32 +87,32 @@ func (c *staticCalculator) CreateSubnetTx(*txs.CreateSubnetTx) error { return nil } -func (c *staticCalculator) AdvanceTimeTx(*txs.AdvanceTimeTx) error { +func (c *staticVisitor) AdvanceTimeTx(*txs.AdvanceTimeTx) error { c.fee = 0 // no fees return nil } -func (c *staticCalculator) RewardValidatorTx(*txs.RewardValidatorTx) error { +func (c *staticVisitor) RewardValidatorTx(*txs.RewardValidatorTx) error { c.fee = 0 // no fees return nil } -func (c *staticCalculator) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error { +func (c *staticVisitor) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error { c.fee = c.staticCfg.TxFee return nil } -func (c *staticCalculator) TransformSubnetTx(*txs.TransformSubnetTx) error { +func (c *staticVisitor) TransformSubnetTx(*txs.TransformSubnetTx) error { c.fee = c.staticCfg.TransformSubnetTxFee return nil } -func (c *staticCalculator) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { +func (c *staticVisitor) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { c.fee = c.staticCfg.TxFee return nil } -func (c *staticCalculator) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error { +func (c *staticVisitor) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error { if tx.Subnet != constants.PrimaryNetworkID { c.fee = c.staticCfg.AddSubnetValidatorFee } else { @@ -111,7 +121,7 @@ func (c *staticCalculator) AddPermissionlessValidatorTx(tx *txs.AddPermissionles return nil } -func (c *staticCalculator) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error { +func (c *staticVisitor) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error { if tx.Subnet != constants.PrimaryNetworkID { c.fee = c.staticCfg.AddSubnetDelegatorFee } else { @@ -120,17 +130,17 @@ func (c *staticCalculator) AddPermissionlessDelegatorTx(tx *txs.AddPermissionles return nil } -func (c *staticCalculator) BaseTx(*txs.BaseTx) error { +func (c *staticVisitor) BaseTx(*txs.BaseTx) error { c.fee = c.staticCfg.TxFee return nil } -func (c *staticCalculator) ImportTx(*txs.ImportTx) error { +func (c *staticVisitor) ImportTx(*txs.ImportTx) error { c.fee = c.staticCfg.TxFee return nil } -func (c *staticCalculator) ExportTx(*txs.ExportTx) error { +func (c *staticVisitor) ExportTx(*txs.ExportTx) error { c.fee = c.staticCfg.TxFee return nil } diff --git a/vms/platformvm/txs/txstest/context.go b/vms/platformvm/txs/txstest/context.go index fb8181f7e39c..80507c1a267d 100644 --- a/vms/platformvm/txs/txstest/context.go +++ b/vms/platformvm/txs/txstest/context.go @@ -20,8 +20,8 @@ func newContext( ) *builder.Context { var ( feeCalculator = fee.NewStaticCalculator(cfg.StaticFeeConfig, cfg.UpgradeConfig, timestamp) - createSubnetFee, _ = feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateSubnetTx{}}) - createChainFee, _ = feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateChainTx{}}) + createSubnetFee, _ = feeCalculator.CalculateFee(&txs.CreateSubnetTx{}) + createChainFee, _ = feeCalculator.CalculateFee(&txs.CreateChainTx{}) ) return &builder.Context{ From c29bd023140bcc44d85f23663380aba3aa47edfe Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Sun, 21 Jul 2024 14:31:37 -0400 Subject: [PATCH 2/2] nit --- vms/platformvm/txs/fee/static_calculator.go | 58 ++++++++++----------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/vms/platformvm/txs/fee/static_calculator.go b/vms/platformvm/txs/fee/static_calculator.go index f424d30220d0..6c1b535202f0 100644 --- a/vms/platformvm/txs/fee/static_calculator.go +++ b/vms/platformvm/txs/fee/static_calculator.go @@ -22,23 +22,23 @@ func NewStaticCalculator( chainTime time.Time, ) Calculator { return &staticCalculator{ - upgrades: upgradeTimes, - staticCfg: config, - time: chainTime, + upgrades: upgradeTimes, + config: config, + time: chainTime, } } type staticCalculator struct { - staticCfg StaticConfig - upgrades upgrade.Config - time time.Time + config StaticConfig + upgrades upgrade.Config + time time.Time } func (c *staticCalculator) CalculateFee(tx txs.UnsignedTx) (uint64, error) { v := staticVisitor{ - staticCfg: c.staticCfg, - upgrades: c.upgrades, - time: c.time, + config: c.config, + upgrades: c.upgrades, + time: c.time, } err := tx.Visit(&v) return v.fee, err @@ -46,43 +46,43 @@ func (c *staticCalculator) CalculateFee(tx txs.UnsignedTx) (uint64, error) { type staticVisitor struct { // inputs - staticCfg StaticConfig - upgrades upgrade.Config - time time.Time + config StaticConfig + upgrades upgrade.Config + time time.Time // outputs fee uint64 } func (c *staticVisitor) AddValidatorTx(*txs.AddValidatorTx) error { - c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee + c.fee = c.config.AddPrimaryNetworkValidatorFee return nil } func (c *staticVisitor) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error { - c.fee = c.staticCfg.AddSubnetValidatorFee + c.fee = c.config.AddSubnetValidatorFee return nil } func (c *staticVisitor) AddDelegatorTx(*txs.AddDelegatorTx) error { - c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee + c.fee = c.config.AddPrimaryNetworkDelegatorFee return nil } func (c *staticVisitor) CreateChainTx(*txs.CreateChainTx) error { if c.upgrades.IsApricotPhase3Activated(c.time) { - c.fee = c.staticCfg.CreateBlockchainTxFee + c.fee = c.config.CreateBlockchainTxFee } else { - c.fee = c.staticCfg.CreateAssetTxFee + c.fee = c.config.CreateAssetTxFee } return nil } func (c *staticVisitor) CreateSubnetTx(*txs.CreateSubnetTx) error { if c.upgrades.IsApricotPhase3Activated(c.time) { - c.fee = c.staticCfg.CreateSubnetTxFee + c.fee = c.config.CreateSubnetTxFee } else { - c.fee = c.staticCfg.CreateAssetTxFee + c.fee = c.config.CreateAssetTxFee } return nil } @@ -98,49 +98,49 @@ func (c *staticVisitor) RewardValidatorTx(*txs.RewardValidatorTx) error { } func (c *staticVisitor) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error { - c.fee = c.staticCfg.TxFee + c.fee = c.config.TxFee return nil } func (c *staticVisitor) TransformSubnetTx(*txs.TransformSubnetTx) error { - c.fee = c.staticCfg.TransformSubnetTxFee + c.fee = c.config.TransformSubnetTxFee return nil } func (c *staticVisitor) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { - c.fee = c.staticCfg.TxFee + c.fee = c.config.TxFee return nil } func (c *staticVisitor) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error { if tx.Subnet != constants.PrimaryNetworkID { - c.fee = c.staticCfg.AddSubnetValidatorFee + c.fee = c.config.AddSubnetValidatorFee } else { - c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee + c.fee = c.config.AddPrimaryNetworkValidatorFee } return nil } func (c *staticVisitor) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error { if tx.Subnet != constants.PrimaryNetworkID { - c.fee = c.staticCfg.AddSubnetDelegatorFee + c.fee = c.config.AddSubnetDelegatorFee } else { - c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee + c.fee = c.config.AddPrimaryNetworkDelegatorFee } return nil } func (c *staticVisitor) BaseTx(*txs.BaseTx) error { - c.fee = c.staticCfg.TxFee + c.fee = c.config.TxFee return nil } func (c *staticVisitor) ImportTx(*txs.ImportTx) error { - c.fee = c.staticCfg.TxFee + c.fee = c.config.TxFee return nil } func (c *staticVisitor) ExportTx(*txs.ExportTx) error { - c.fee = c.staticCfg.TxFee + c.fee = c.config.TxFee return nil }