Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup fee.staticCalculator #3210

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vms/platformvm/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ func TestGetBalance(t *testing.T) {

feeCalculator, err := state.PickFeeCalculator(&service.vm.Config, service.vm.state)
require.NoError(err)
createSubnetFee, err := feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateSubnetTx{}})
createSubnetFee, err := feeCalculator.CalculateFee(&txs.CreateSubnetTx{})
require.NoError(err)

// Ensure GetStake is correct for each of the genesis validators
Expand Down
14 changes: 7 additions & 7 deletions vms/platformvm/txs/executor/staker_tx_verification.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func verifyAddValidatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -258,7 +258,7 @@ func verifyAddSubnetValidatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -338,7 +338,7 @@ func verifyRemoveSubnetValidatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return nil, false, err
}
Expand Down Expand Up @@ -458,7 +458,7 @@ func verifyAddDelegatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -580,7 +580,7 @@ func verifyAddPermissionlessValidatorTx(
copy(outs[len(tx.Outs):], tx.StakeOuts)

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -727,7 +727,7 @@ func verifyAddPermissionlessDelegatorTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -787,7 +787,7 @@ func verifyTransferSubnetOwnershipTx(
}

// Verify the flowcheck
fee, err := feeCalculator.CalculateFee(sTx)
fee, err := feeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions vms/platformvm/txs/executor/standard_tx_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (e *StandardTxExecutor) CreateChainTx(tx *txs.CreateChainTx) error {
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -119,7 +119,7 @@ func (e *StandardTxExecutor) CreateSubnetTx(tx *txs.CreateSubnetTx) error {
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func (e *StandardTxExecutor) ImportTx(tx *txs.ImportTx) error {
copy(ins[len(tx.Ins):], tx.ImportedInputs)

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -263,7 +263,7 @@ func (e *StandardTxExecutor) ExportTx(tx *txs.ExportTx) error {
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -457,7 +457,7 @@ func (e *StandardTxExecutor) TransformSubnetTx(tx *txs.TransformSubnetTx) error
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down Expand Up @@ -588,7 +588,7 @@ func (e *StandardTxExecutor) BaseTx(tx *txs.BaseTx) error {
}

// Verify the flowcheck
fee, err := e.FeeCalculator.CalculateFee(e.Tx)
fee, err := e.FeeCalculator.CalculateFee(tx)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion vms/platformvm/txs/fee/calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import "github.com/ava-labs/avalanchego/vms/platformvm/txs"

// Calculator is the interfaces that any fee Calculator must implement
type Calculator interface {
CalculateFee(tx *txs.Tx) (uint64, error)
CalculateFee(tx txs.UnsignedTx) (uint64, error)
}
4 changes: 2 additions & 2 deletions vms/platformvm/txs/fee/calculator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ func TestTxFees(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
uTx := tt.unsignedTx()
tx := tt.unsignedTx()
fc := NewStaticCalculator(feeTestsDefaultCfg, upgrades, tt.chainTime)
fee, err := fc.CalculateFee(&txs.Tx{Unsigned: uTx})
fee, err := fc.CalculateFee(tx)
require.NoError(t, err)
require.Equal(t, tt.expected, fee)
})
Expand Down
102 changes: 56 additions & 46 deletions vms/platformvm/txs/fee/static_calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

var (
_ Calculator = (*staticCalculator)(nil)
_ txs.Visitor = (*staticCalculator)(nil)
_ txs.Visitor = (*staticVisitor)(nil)
)

func NewStaticCalculator(
Expand All @@ -22,115 +22,125 @@ func NewStaticCalculator(
chainTime time.Time,
) Calculator {
return &staticCalculator{
upgrades: upgradeTimes,
staticCfg: config,
time: chainTime,
upgrades: upgradeTimes,
config: config,
time: chainTime,
}
}

type staticCalculator struct {
// inputs
staticCfg StaticConfig
upgrades upgrade.Config
time time.Time
config StaticConfig
upgrades upgrade.Config
time time.Time
}

// outputs of visitor execution
fee uint64
func (c *staticCalculator) CalculateFee(tx txs.UnsignedTx) (uint64, error) {
v := staticVisitor{
config: c.config,
upgrades: c.upgrades,
time: c.time,
}
err := tx.Visit(&v)
return v.fee, err
}

func (c *staticCalculator) CalculateFee(tx *txs.Tx) (uint64, error) {
c.fee = 0 // zero fee among different calculateFee invocations (unlike gas which gets cumulated)
err := tx.Unsigned.Visit(c)
return c.fee, err
type staticVisitor struct {
// inputs
config StaticConfig
upgrades upgrade.Config
time time.Time

// outputs
fee uint64
}

func (c *staticCalculator) AddValidatorTx(*txs.AddValidatorTx) error {
c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee
func (c *staticVisitor) AddValidatorTx(*txs.AddValidatorTx) error {
c.fee = c.config.AddPrimaryNetworkValidatorFee
return nil
}

func (c *staticCalculator) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error {
c.fee = c.staticCfg.AddSubnetValidatorFee
func (c *staticVisitor) AddSubnetValidatorTx(*txs.AddSubnetValidatorTx) error {
c.fee = c.config.AddSubnetValidatorFee
return nil
}

func (c *staticCalculator) AddDelegatorTx(*txs.AddDelegatorTx) error {
c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee
func (c *staticVisitor) AddDelegatorTx(*txs.AddDelegatorTx) error {
c.fee = c.config.AddPrimaryNetworkDelegatorFee
return nil
}

func (c *staticCalculator) CreateChainTx(*txs.CreateChainTx) error {
func (c *staticVisitor) CreateChainTx(*txs.CreateChainTx) error {
if c.upgrades.IsApricotPhase3Activated(c.time) {
c.fee = c.staticCfg.CreateBlockchainTxFee
c.fee = c.config.CreateBlockchainTxFee
} else {
c.fee = c.staticCfg.CreateAssetTxFee
c.fee = c.config.CreateAssetTxFee
}
return nil
}

func (c *staticCalculator) CreateSubnetTx(*txs.CreateSubnetTx) error {
func (c *staticVisitor) CreateSubnetTx(*txs.CreateSubnetTx) error {
if c.upgrades.IsApricotPhase3Activated(c.time) {
c.fee = c.staticCfg.CreateSubnetTxFee
c.fee = c.config.CreateSubnetTxFee
} else {
c.fee = c.staticCfg.CreateAssetTxFee
c.fee = c.config.CreateAssetTxFee
}
return nil
}

func (c *staticCalculator) AdvanceTimeTx(*txs.AdvanceTimeTx) error {
func (c *staticVisitor) AdvanceTimeTx(*txs.AdvanceTimeTx) error {
c.fee = 0 // no fees
return nil
}

func (c *staticCalculator) RewardValidatorTx(*txs.RewardValidatorTx) error {
func (c *staticVisitor) RewardValidatorTx(*txs.RewardValidatorTx) error {
c.fee = 0 // no fees
return nil
}

func (c *staticCalculator) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) RemoveSubnetValidatorTx(*txs.RemoveSubnetValidatorTx) error {
c.fee = c.config.TxFee
return nil
}

func (c *staticCalculator) TransformSubnetTx(*txs.TransformSubnetTx) error {
c.fee = c.staticCfg.TransformSubnetTxFee
func (c *staticVisitor) TransformSubnetTx(*txs.TransformSubnetTx) error {
c.fee = c.config.TransformSubnetTxFee
return nil
}

func (c *staticCalculator) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) TransferSubnetOwnershipTx(*txs.TransferSubnetOwnershipTx) error {
c.fee = c.config.TxFee
return nil
}

func (c *staticCalculator) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error {
func (c *staticVisitor) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error {
if tx.Subnet != constants.PrimaryNetworkID {
c.fee = c.staticCfg.AddSubnetValidatorFee
c.fee = c.config.AddSubnetValidatorFee
} else {
c.fee = c.staticCfg.AddPrimaryNetworkValidatorFee
c.fee = c.config.AddPrimaryNetworkValidatorFee
}
return nil
}

func (c *staticCalculator) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error {
func (c *staticVisitor) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error {
if tx.Subnet != constants.PrimaryNetworkID {
c.fee = c.staticCfg.AddSubnetDelegatorFee
c.fee = c.config.AddSubnetDelegatorFee
} else {
c.fee = c.staticCfg.AddPrimaryNetworkDelegatorFee
c.fee = c.config.AddPrimaryNetworkDelegatorFee
}
return nil
}

func (c *staticCalculator) BaseTx(*txs.BaseTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) BaseTx(*txs.BaseTx) error {
c.fee = c.config.TxFee
return nil
}

func (c *staticCalculator) ImportTx(*txs.ImportTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) ImportTx(*txs.ImportTx) error {
c.fee = c.config.TxFee
return nil
}

func (c *staticCalculator) ExportTx(*txs.ExportTx) error {
c.fee = c.staticCfg.TxFee
func (c *staticVisitor) ExportTx(*txs.ExportTx) error {
c.fee = c.config.TxFee
return nil
}
4 changes: 2 additions & 2 deletions vms/platformvm/txs/txstest/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ func newContext(
) *builder.Context {
var (
feeCalculator = fee.NewStaticCalculator(cfg.StaticFeeConfig, cfg.UpgradeConfig, timestamp)
createSubnetFee, _ = feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateSubnetTx{}})
createChainFee, _ = feeCalculator.CalculateFee(&txs.Tx{Unsigned: &txs.CreateChainTx{}})
createSubnetFee, _ = feeCalculator.CalculateFee(&txs.CreateSubnetTx{})
createChainFee, _ = feeCalculator.CalculateFee(&txs.CreateChainTx{})
)

return &builder.Context{
Expand Down
Loading