diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index ccda2c6c054b..97d7d0466894 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -377,7 +377,7 @@ func TestGetBalance(t *testing.T) { feeCalculator, err := state.PickFeeCalculator(&service.vm.Config, service.vm.state) require.NoError(err) - createSubnetFee, err := feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateSubnetTx{}}) + createSubnetFee, err := feeCalculator.CalculateFee(&txs.CreateSubnetTx{}) require.NoError(err) // Ensure GetStake is correct for each of the genesis validators diff --git a/vms/platformvm/txs/executor/staker_tx_verification.go b/vms/platformvm/txs/executor/staker_tx_verification.go index 9b9876e3ccca..b84ad9d6e2b1 100644 --- a/vms/platformvm/txs/executor/staker_tx_verification.go +++ b/vms/platformvm/txs/executor/staker_tx_verification.go @@ -165,7 +165,7 @@ func verifyAddValidatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return nil, err } @@ -258,7 +258,7 @@ func verifyAddSubnetValidatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return err } @@ -338,7 +338,7 @@ func verifyRemoveSubnetValidatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return nil, false, err } @@ -458,7 +458,7 @@ func verifyAddDelegatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return nil, err } @@ -580,7 +580,7 @@ func verifyAddPermissionlessValidatorTx( copy(outs[len(tx.Outs):], tx.StakeOuts) // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return err } @@ -727,7 +727,7 @@ func verifyAddPermissionlessDelegatorTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return err } @@ -787,7 +787,7 @@ func verifyTransferSubnetOwnershipTx( } // Verify the flowcheck - fee, err := feeCalculator.CalculateFee(sTx) + fee, err := feeCalculator.CalculateFee(tx) if err != nil { return err } diff --git a/vms/platformvm/txs/executor/standard_tx_executor.go b/vms/platformvm/txs/executor/standard_tx_executor.go index 9ec40e506a1b..58eeb71e1fea 100644 --- a/vms/platformvm/txs/executor/standard_tx_executor.go +++ b/vms/platformvm/txs/executor/standard_tx_executor.go @@ -70,7 +70,7 @@ func (e *StandardTxExecutor) CreateChainTx(tx *txs.CreateChainTx) error { } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -119,7 +119,7 @@ func (e *StandardTxExecutor) CreateSubnetTx(tx *txs.CreateSubnetTx) error { } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -203,7 +203,7 @@ func (e *StandardTxExecutor) ImportTx(tx *txs.ImportTx) error { copy(ins[len(tx.Ins):], tx.ImportedInputs) // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -263,7 +263,7 @@ func (e *StandardTxExecutor) ExportTx(tx *txs.ExportTx) error { } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -457,7 +457,7 @@ func (e *StandardTxExecutor) TransformSubnetTx(tx *txs.TransformSubnetTx) error } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } @@ -588,7 +588,7 @@ func (e *StandardTxExecutor) BaseTx(tx *txs.BaseTx) error { } // Verify the flowcheck - fee, err := e.FeeCalculator.CalculateFee(e.Tx) + fee, err := e.FeeCalculator.CalculateFee(tx) if err != nil { return err } diff --git a/vms/platformvm/txs/fee/calculator.go b/vms/platformvm/txs/fee/calculator.go index f33db9c1520f..511bb34ed3c2 100644 --- a/vms/platformvm/txs/fee/calculator.go +++ b/vms/platformvm/txs/fee/calculator.go @@ -7,5 +7,5 @@ import "github.com/ava-labs/avalanchego/vms/platformvm/txs" // Calculator is the interfaces that any fee Calculator must implement type Calculator interface { - CalculateFee(tx *txs.Tx) (uint64, error) + CalculateFee(tx txs.UnsignedTx) (uint64, error) } diff --git a/vms/platformvm/txs/fee/calculator_test.go b/vms/platformvm/txs/fee/calculator_test.go index 454072e9df8d..656b9ca0889f 100644 --- a/vms/platformvm/txs/fee/calculator_test.go +++ b/vms/platformvm/txs/fee/calculator_test.go @@ -187,9 +187,9 @@ func TestTxFees(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - uTx := tt.unsignedTx() + tx := tt.unsignedTx() fc := NewStaticCalculator(feeTestsDefaultCfg, upgrades, tt.chainTime) - fee, err := fc.CalculateFee(&txs.Tx{Unsigned: uTx}) + fee, err := fc.CalculateFee(tx) require.NoError(t, err) require.Equal(t, tt.expected, fee) }) diff --git a/vms/platformvm/txs/fee/static_calculator.go b/vms/platformvm/txs/fee/static_calculator.go index 7bfb5cf799a0..6c1b535202f0 100644 --- a/vms/platformvm/txs/fee/static_calculator.go +++ b/vms/platformvm/txs/fee/static_calculator.go @@ -13,7 +13,7 @@ import ( var ( _ Calculator = (*staticCalculator)(nil) - _ txs.Visitor = (*staticCalculator)(nil) + _ txs.Visitor = (*staticVisitor)(nil) ) func NewStaticCalculator( @@ -22,115 +22,125 @@ func NewStaticCalculator( chainTime time.Time, ) Calculator { return &staticCalculator{ - upgrades: upgradeTimes, - staticCfg: config, - time: chainTime, + upgrades: upgradeTimes, + config: config, + time: chainTime, } } type staticCalculator struct { - // inputs - staticCfg StaticConfig - upgrades upgrade.Config - time time.Time + config StaticConfig + upgrades upgrade.Config + time time.Time +} - // outputs of visitor execution - fee uint64 +func (c *staticCalculator) CalculateFee(tx txs.UnsignedTx) (uint64, error) { + v := staticVisitor{ + config: c.config, + upgrades: c.upgrades, + time: c.time, + } + err := tx.Visit(&v) + return v.fee, err } -func (c *staticCalculator) CalculateFee(tx *txs.Tx) (uint64, error) { - c.fee = 0 // zero fee among different calculateFee invocations (unlike gas which gets cumulated) - err := tx.Unsigned.Visit(c) - return c.fee, err +type staticVisitor struct { + // inputs + config StaticConfig + upgrades upgrade.Config + time time.Time + + // outputs + fee uint64 } -func (c *staticCalculator) AddValidatorTx(*txs.AddValidatorTx) error { - c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee +func (c *staticVisitor) AddValidatorTx(*txs.AddValidatorTx) error { + c.fee = c.config.AddPrimaryNetworkValidatorFee return nil } -func (c *staticCalculator) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error { - c.fee = c.staticCfg.AddSubnetValidatorFee +func (c *staticVisitor) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error { + c.fee = c.config.AddSubnetValidatorFee return nil } -func (c *staticCalculator) AddDelegatorTx(*txs.AddDelegatorTx) error { - c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee +func (c *staticVisitor) AddDelegatorTx(*txs.AddDelegatorTx) error { + c.fee = c.config.AddPrimaryNetworkDelegatorFee return nil } -func (c *staticCalculator) CreateChainTx(*txs.CreateChainTx) error { +func (c *staticVisitor) CreateChainTx(*txs.CreateChainTx) error { if c.upgrades.IsApricotPhase3Activated(c.time) { - c.fee = c.staticCfg.CreateBlockchainTxFee + c.fee = c.config.CreateBlockchainTxFee } else { - c.fee = c.staticCfg.CreateAssetTxFee + c.fee = c.config.CreateAssetTxFee } return nil } -func (c *staticCalculator) CreateSubnetTx(*txs.CreateSubnetTx) error { +func (c *staticVisitor) CreateSubnetTx(*txs.CreateSubnetTx) error { if c.upgrades.IsApricotPhase3Activated(c.time) { - c.fee = c.staticCfg.CreateSubnetTxFee + c.fee = c.config.CreateSubnetTxFee } else { - c.fee = c.staticCfg.CreateAssetTxFee + c.fee = c.config.CreateAssetTxFee } return nil } -func (c *staticCalculator) AdvanceTimeTx(*txs.AdvanceTimeTx) error { +func (c *staticVisitor) AdvanceTimeTx(*txs.AdvanceTimeTx) error { c.fee = 0 // no fees return nil } -func (c *staticCalculator) RewardValidatorTx(*txs.RewardValidatorTx) error { +func (c *staticVisitor) RewardValidatorTx(*txs.RewardValidatorTx) error { c.fee = 0 // no fees return nil } -func (c *staticCalculator) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error { - c.fee = c.staticCfg.TxFee +func (c *staticVisitor) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error { + c.fee = c.config.TxFee return nil } -func (c *staticCalculator) TransformSubnetTx(*txs.TransformSubnetTx) error { - c.fee = c.staticCfg.TransformSubnetTxFee +func (c *staticVisitor) TransformSubnetTx(*txs.TransformSubnetTx) error { + c.fee = c.config.TransformSubnetTxFee return nil } -func (c *staticCalculator) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { - c.fee = c.staticCfg.TxFee +func (c *staticVisitor) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error { + c.fee = c.config.TxFee return nil } -func (c *staticCalculator) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error { +func (c *staticVisitor) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error { if tx.Subnet != constants.PrimaryNetworkID { - c.fee = c.staticCfg.AddSubnetValidatorFee + c.fee = c.config.AddSubnetValidatorFee } else { - c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee + c.fee = c.config.AddPrimaryNetworkValidatorFee } return nil } -func (c *staticCalculator) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error { +func (c *staticVisitor) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error { if tx.Subnet != constants.PrimaryNetworkID { - c.fee = c.staticCfg.AddSubnetDelegatorFee + c.fee = c.config.AddSubnetDelegatorFee } else { - c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee + c.fee = c.config.AddPrimaryNetworkDelegatorFee } return nil } -func (c *staticCalculator) BaseTx(*txs.BaseTx) error { - c.fee = c.staticCfg.TxFee +func (c *staticVisitor) BaseTx(*txs.BaseTx) error { + c.fee = c.config.TxFee return nil } -func (c *staticCalculator) ImportTx(*txs.ImportTx) error { - c.fee = c.staticCfg.TxFee +func (c *staticVisitor) ImportTx(*txs.ImportTx) error { + c.fee = c.config.TxFee return nil } -func (c *staticCalculator) ExportTx(*txs.ExportTx) error { - c.fee = c.staticCfg.TxFee +func (c *staticVisitor) ExportTx(*txs.ExportTx) error { + c.fee = c.config.TxFee return nil } diff --git a/vms/platformvm/txs/txstest/context.go b/vms/platformvm/txs/txstest/context.go index fb8181f7e39c..80507c1a267d 100644 --- a/vms/platformvm/txs/txstest/context.go +++ b/vms/platformvm/txs/txstest/context.go @@ -20,8 +20,8 @@ func newContext( ) *builder.Context { var ( feeCalculator = fee.NewStaticCalculator(cfg.StaticFeeConfig, cfg.UpgradeConfig, timestamp) - createSubnetFee, _ = feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateSubnetTx{}}) - createChainFee, _ = feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateChainTx{}}) + createSubnetFee, _ = feeCalculator.CalculateFee(&txs.CreateSubnetTx{}) + createChainFee, _ = feeCalculator.CalculateFee(&txs.CreateChainTx{}) ) return &builder.Context{