diff --git a/services/rfq/relayer/pricer/fee_pricer_test.go b/services/rfq/relayer/pricer/fee_pricer_test.go index f316622a54..6b6174c7c3 100644 --- a/services/rfq/relayer/pricer/fee_pricer_test.go +++ b/services/rfq/relayer/pricer/fee_pricer_test.go @@ -11,6 +11,7 @@ import ( fetcherMocks "github.com/synapsecns/sanguine/ethergo/submitter/mocks" "github.com/synapsecns/sanguine/services/rfq/relayer/pricer" priceMocks "github.com/synapsecns/sanguine/services/rfq/relayer/pricer/mocks" + "github.com/synapsecns/sanguine/services/rfq/relayer/relconfig" ) var defaultPrices = map[string]float64{"ETH": 2000., "USDC": 1., "MATIC": 0.5} @@ -261,8 +262,8 @@ func (s *PricerSuite) TestGetGasPrice() { func (s *PricerSuite) TestGetTotalFeeWithMultiplier() { // Override the fixed fee multiplier to greater than 1. - s.config.BaseChainConfig.QuoteFixedFeeMultiplier = 2 - s.config.BaseChainConfig.RelayFixedFeeMultiplier = 4 + s.config.BaseChainConfig.QuoteFixedFeeMultiplier = relconfig.NewFloatPtr(2) + s.config.BaseChainConfig.RelayFixedFeeMultiplier = relconfig.NewFloatPtr(4) // Build a new FeePricer with a mocked client for fetching gas price. clientFetcher := new(fetcherMocks.ClientFetcher) @@ -295,7 +296,7 @@ func (s *PricerSuite) TestGetTotalFeeWithMultiplier() { s.Equal(expectedFee, fee) // Override the fixed fee multiplier to less than 1; should default to 1. - s.config.BaseChainConfig.QuoteFixedFeeMultiplier = -1 + s.config.BaseChainConfig.QuoteFixedFeeMultiplier = relconfig.NewFloatPtr(-1) // Build a new FeePricer with a mocked client for fetching gas price. clientOrigin.On(testsuite.GetFunctionName(clientOrigin.SuggestGasPrice), mock.Anything).Once().Return(headerOrigin, nil) @@ -314,7 +315,7 @@ func (s *PricerSuite) TestGetTotalFeeWithMultiplier() { s.Equal(expectedFee, fee) // Reset the fixed fee multiplier to zero; should default to 1 - s.config.BaseChainConfig.QuoteFixedFeeMultiplier = 0 + s.config.BaseChainConfig.QuoteFixedFeeMultiplier = relconfig.NewFloatPtr(0) // Build a new FeePricer with a mocked client for fetching gas price. clientOrigin.On(testsuite.GetFunctionName(clientOrigin.SuggestGasPrice), mock.Anything).Once().Return(headerOrigin, nil) diff --git a/services/rfq/relayer/quoter/quoter_test.go b/services/rfq/relayer/quoter/quoter_test.go index 507b981e0d..44510fabe8 100644 --- a/services/rfq/relayer/quoter/quoter_test.go +++ b/services/rfq/relayer/quoter/quoter_test.go @@ -165,7 +165,7 @@ func (s *QuoterSuite) TestGetOriginAmount() { balance := big.NewInt(1000_000_000) // 1000 USDC setQuoteParams := func(quotePct, quoteOffset float64, minQuoteAmount string) { - s.config.BaseChainConfig.QuotePct = quotePct + s.config.BaseChainConfig.QuotePct = "ePct destTokenCfg := s.config.Chains[dest].Tokens["USDC"] destTokenCfg.MinQuoteAmount = minQuoteAmount originTokenCfg := s.config.Chains[origin].Tokens["USDC"] diff --git a/services/rfq/relayer/relconfig/config.go b/services/rfq/relayer/relconfig/config.go index 5d69fbd38f..36614772d0 100644 --- a/services/rfq/relayer/relconfig/config.go +++ b/services/rfq/relayer/relconfig/config.go @@ -90,14 +90,14 @@ type ChainConfig struct { // MinGasToken is minimum amount of gas that should be leftover after bridging a gas token. MinGasToken string `yaml:"min_gas_token"` // QuotePct is the percent of balance to quote. - QuotePct float64 `yaml:"quote_pct"` + QuotePct *float64 `yaml:"quote_pct"` // QuoteWidthBps is the number of basis points to deduct from the dest amount. // Note that this parameter is applied on a chain level and must be positive. QuoteWidthBps float64 `yaml:"quote_width_bps"` // QuoteFixedFeeMultiplier is the multiplier for the fixed fee, applied when generating quotes. - QuoteFixedFeeMultiplier float64 `yaml:"quote_fixed_fee_multiplier"` + QuoteFixedFeeMultiplier *float64 `yaml:"quote_fixed_fee_multiplier"` // RelayFixedFeeMultiplier is the multiplier for the fixed fee, applied when relaying. - RelayFixedFeeMultiplier float64 `yaml:"relay_fixed_fee_multiplier"` + RelayFixedFeeMultiplier *float64 `yaml:"relay_fixed_fee_multiplier"` // CCTP start block is the block at which the chain listener will listen for CCTP events. CCTPStartBlock uint64 `yaml:"cctp_start_block"` } diff --git a/services/rfq/relayer/relconfig/config_test.go b/services/rfq/relayer/relconfig/config_test.go index 2d64aec9e1..bac0016c52 100644 --- a/services/rfq/relayer/relconfig/config_test.go +++ b/services/rfq/relayer/relconfig/config_test.go @@ -1,10 +1,11 @@ package relconfig_test import ( - "github.com/stretchr/testify/assert" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/synapsecns/sanguine/services/rfq/relayer/relconfig" @@ -30,9 +31,9 @@ func TestChainGetters(t *testing.T) { L1FeeOriginGasEstimate: 30000, L1FeeDestGasEstimate: 40000, MinGasToken: "1000", - QuotePct: 50, + QuotePct: relconfig.NewFloatPtr(0), QuoteWidthBps: 10, - QuoteFixedFeeMultiplier: 1.1, + QuoteFixedFeeMultiplier: relconfig.NewFloatPtr(1.1), }, }, BaseChainConfig: relconfig.ChainConfig{ @@ -48,9 +49,9 @@ func TestChainGetters(t *testing.T) { L1FeeOriginGasEstimate: 30001, L1FeeDestGasEstimate: 40001, MinGasToken: "1001", - QuotePct: 51, + QuotePct: relconfig.NewFloatPtr(51), QuoteWidthBps: 11, - QuoteFixedFeeMultiplier: 1.2, + QuoteFixedFeeMultiplier: relconfig.NewFloatPtr(1.2), }, } cfg := relconfig.Config{ @@ -68,9 +69,9 @@ func TestChainGetters(t *testing.T) { L1FeeOriginGasEstimate: 30000, L1FeeDestGasEstimate: 40000, MinGasToken: "1000", - QuotePct: 50, + QuotePct: relconfig.NewFloatPtr(50), QuoteWidthBps: 10, - QuoteFixedFeeMultiplier: 1.1, + QuoteFixedFeeMultiplier: relconfig.NewFloatPtr(1.1), Tokens: map[string]relconfig.TokenConfig{ "USDC": { Address: usdcAddr, @@ -253,15 +254,15 @@ func TestChainGetters(t *testing.T) { t.Run("GetQuotePct", func(t *testing.T) { defaultVal, err := cfg.GetQuotePct(badChainID) assert.NoError(t, err) - assert.Equal(t, defaultVal, relconfig.DefaultChainConfig.QuotePct) + assert.Equal(t, defaultVal, 100.) baseVal, err := cfgWithBase.GetQuotePct(badChainID) assert.NoError(t, err) - assert.Equal(t, baseVal, cfgWithBase.BaseChainConfig.QuotePct) + assert.Equal(t, baseVal, 51.) chainVal, err := cfgWithBase.GetQuotePct(chainID) assert.NoError(t, err) - assert.Equal(t, chainVal, cfgWithBase.Chains[chainID].QuotePct) + assert.Equal(t, chainVal, 0.) }) t.Run("GetQuoteWidthBps", func(t *testing.T) { @@ -281,15 +282,15 @@ func TestChainGetters(t *testing.T) { t.Run("GetQuoteFixedFeeMultiplier", func(t *testing.T) { defaultVal, err := cfg.GetQuoteFixedFeeMultiplier(badChainID) assert.NoError(t, err) - assert.Equal(t, defaultVal, relconfig.DefaultChainConfig.QuoteFixedFeeMultiplier) + assert.Equal(t, defaultVal, *relconfig.DefaultChainConfig.QuoteFixedFeeMultiplier) baseVal, err := cfgWithBase.GetQuoteFixedFeeMultiplier(badChainID) assert.NoError(t, err) - assert.Equal(t, baseVal, cfgWithBase.BaseChainConfig.QuoteFixedFeeMultiplier) + assert.Equal(t, baseVal, *cfgWithBase.BaseChainConfig.QuoteFixedFeeMultiplier) chainVal, err := cfgWithBase.GetQuoteFixedFeeMultiplier(chainID) assert.NoError(t, err) - assert.Equal(t, chainVal, cfgWithBase.Chains[chainID].QuoteFixedFeeMultiplier) + assert.Equal(t, chainVal, *cfgWithBase.Chains[chainID].QuoteFixedFeeMultiplier) }) t.Run("GetMaxRebalanceAmount", func(t *testing.T) { @@ -319,9 +320,9 @@ func TestGetQuoteOffset(t *testing.T) { L1FeeOriginGasEstimate: 30000, L1FeeDestGasEstimate: 40000, MinGasToken: "1000", - QuotePct: 50, + QuotePct: relconfig.NewFloatPtr(50), QuoteWidthBps: 10, - QuoteFixedFeeMultiplier: 1.1, + QuoteFixedFeeMultiplier: relconfig.NewFloatPtr(1.1), Tokens: map[string]relconfig.TokenConfig{ "USDC": { Address: usdcAddr, diff --git a/services/rfq/relayer/relconfig/getters.go b/services/rfq/relayer/relconfig/getters.go index 3bca0f98ce..60e3ab3ac7 100644 --- a/services/rfq/relayer/relconfig/getters.go +++ b/services/rfq/relayer/relconfig/getters.go @@ -2,6 +2,7 @@ package relconfig import ( "fmt" + "github.com/synapsecns/sanguine/core" "math/big" "reflect" "time" @@ -17,10 +18,15 @@ var DefaultChainConfig = ChainConfig{ OriginGasEstimate: 160000, DestGasEstimate: 100000, MinGasToken: "100000000000000000", // 1 ETH - QuotePct: 100, + QuotePct: NewFloatPtr(100), QuoteWidthBps: 0, - QuoteFixedFeeMultiplier: 1, - RelayFixedFeeMultiplier: 1, + QuoteFixedFeeMultiplier: NewFloatPtr(1), + RelayFixedFeeMultiplier: NewFloatPtr(1), +} + +// NewFloatPtr returns a pointer to a float64. +func NewFloatPtr(val float64) *float64 { + return core.PtrTo(val) } // getChainConfigValue gets the value of a field from ChainConfig. @@ -34,8 +40,8 @@ func (c Config) getChainConfigValue(chainID int, fieldName string) (interface{}, if err != nil { return nil, err } - if isNonZero(value) { - return value, nil + if !isNilOrZero(value) { + return derefPointer(value), nil } } @@ -43,15 +49,15 @@ func (c Config) getChainConfigValue(chainID int, fieldName string) (interface{}, if err != nil { return nil, err } - if isNonZero(baseValue) { - return baseValue, nil + if !isNilOrZero(baseValue) { + return derefPointer(baseValue), nil } defaultValue, err := getFieldValue(DefaultChainConfig, fieldName) if err != nil { return nil, err } - return defaultValue, nil + return derefPointer(defaultValue), nil } func getFieldValue(obj interface{}, fieldName string) (interface{}, error) { @@ -85,8 +91,20 @@ func isChainConfigField(fieldName string) bool { return ok } -func isNonZero(value interface{}) bool { - return reflect.ValueOf(value).Interface() != reflect.Zero(reflect.TypeOf(value)).Interface() +func derefPointer(value interface{}) interface{} { + val := reflect.ValueOf(value) + if val.Kind() == reflect.Ptr && !val.IsNil() { + return val.Elem().Interface() + } + return value +} + +func isNilOrZero(value interface{}) bool { + val := reflect.ValueOf(value) + if val.Kind() == reflect.Ptr { + return val.IsNil() + } + return reflect.DeepEqual(value, reflect.Zero(val.Type()).Interface()) } // GetRFQAddress returns the RFQ address for the given chainID. @@ -327,7 +345,7 @@ func (c Config) GetQuoteFixedFeeMultiplier(chainID int) (value float64, err erro return value, fmt.Errorf("failed to cast QuoteFixedFeeMultiplier to int") } if value <= 0 { - value = DefaultChainConfig.QuoteFixedFeeMultiplier + value = *DefaultChainConfig.QuoteFixedFeeMultiplier } return value, nil }