Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rfq-relayer): add MaxBalance param #2917

Merged
merged 21 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions services/rfq/relayer/quoter/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
"github.com/synapsecns/sanguine/services/rfq/relayer/relconfig"
)

func (m *Manager) GenerateQuotes(ctx context.Context, chainID int, address common.Address, balance *big.Int) ([]model.PutQuoteRequest, error) {
func (m *Manager) GenerateQuotes(ctx context.Context, chainID int, address common.Address, balance *big.Int, inv map[int]map[common.Address]*big.Int) ([]model.PutQuoteRequest, error) {
// nolint: errcheck
return m.generateQuotes(ctx, chainID, address, balance)
return m.generateQuotes(ctx, chainID, address, balance, inv)
}

func (m *Manager) GetOriginAmount(ctx context.Context, origin, dest int, address common.Address, balance *big.Int) (*big.Int, error) {
return m.getOriginAmount(ctx, origin, dest, address, balance)
func (m *Manager) GetOriginAmount(ctx context.Context, input QuoteInput) (*big.Int, error) {
return m.getOriginAmount(ctx, input)
}

func (m *Manager) GetDestAmount(ctx context.Context, quoteAmount *big.Int, chainID int, tokenName string) (*big.Int, error) {
Expand Down
128 changes: 87 additions & 41 deletions services/rfq/relayer/quoter/quoter.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@
// First, generate all quotes
for chainID, balances := range inv {
for address, balance := range balances {
quotes, err := m.generateQuotes(ctx, chainID, address, balance)
quotes, err := m.generateQuotes(ctx, chainID, address, balance, inv)

Check warning on line 263 in services/rfq/relayer/quoter/quoter.go

View check run for this annotation

Codecov / codecov/patch

services/rfq/relayer/quoter/quoter.go#L263

Added line #L263 was not covered by tests
if err != nil {
return err
}
Expand Down Expand Up @@ -311,7 +311,7 @@
// Essentially, if we know a destination chain token balance, then we just need to find which tokens are bridgeable to it.
// We can do this by looking at the quotableTokens map, and finding the key that matches the destination chain token.
// Generates quotes for a given chain ID, address, and balance.
func (m *Manager) generateQuotes(parentCtx context.Context, chainID int, address common.Address, balance *big.Int) (quotes []model.PutQuoteRequest, err error) {
func (m *Manager) generateQuotes(parentCtx context.Context, chainID int, address common.Address, balance *big.Int, inv map[int]map[common.Address]*big.Int) (quotes []model.PutQuoteRequest, err error) {
ctx, span := m.metricsHandler.Tracer().Start(parentCtx, "generateQuotes", trace.WithAttributes(
attribute.Int(metrics.Origin, chainID),
attribute.String("address", address.String()),
Expand All @@ -335,9 +335,36 @@
for _, tokenID := range itemTokenIDs {
//nolint:nestif
if tokenID == destTokenID {
keyTokenID := k
keyTokenID := k // Parse token info
originStr := strings.Split(keyTokenID, "-")[0]
origin, tokenErr := strconv.Atoi(originStr)
if err != nil {
span.AddEvent("error converting origin chainID", trace.WithAttributes(
attribute.String("key_token_id", keyTokenID),
attribute.String("error", tokenErr.Error()),
))
continue

Check warning on line 346 in services/rfq/relayer/quoter/quoter.go

View check run for this annotation

Codecov / codecov/patch

services/rfq/relayer/quoter/quoter.go#L342-L346

Added lines #L342 - L346 were not covered by tests
}
originTokenAddr := common.HexToAddress(strings.Split(keyTokenID, "-")[1])

var originBalance *big.Int
originTokens, ok := inv[origin]
if ok {
originBalance = originTokens[originTokenAddr]
}

Check warning on line 354 in services/rfq/relayer/quoter/quoter.go

View check run for this annotation

Codecov / codecov/patch

services/rfq/relayer/quoter/quoter.go#L353-L354

Added lines #L353 - L354 were not covered by tests
Comment on lines +351 to +367
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle potential errors in parsing token info.

The current implementation parses origin and originTokenAddr without handling potential errors. Ensure proper error handling to avoid unexpected behavior.

-				originStr := strings.Split(keyTokenID, "-")[0]
-				origin, tokenErr := strconv.Atoi(originStr)
-				if err != nil {
+				parts := strings.Split(keyTokenID, "-")
+				if len(parts) != 2 {
+					span.AddEvent("invalid key token ID format", trace.WithAttributes(
+						attribute.String("key_token_id", keyTokenID),
+					))
+					continue
+				}
+				origin, tokenErr := strconv.Atoi(parts[0])
+				if tokenErr != nil {
					span.AddEvent("error converting origin chainID", trace.WithAttributes(
						attribute.String("key_token_id", keyTokenID),
						attribute.String("error", tokenErr.Error()),
					))
					continue
				}
				originTokenAddr := common.HexToAddress(parts[1])
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
keyTokenID := k // Parse token info
originStr := strings.Split(keyTokenID, "-")[0]
origin, tokenErr := strconv.Atoi(originStr)
if err != nil {
span.AddEvent("error converting origin chainID", trace.WithAttributes(
attribute.String("key_token_id", keyTokenID),
attribute.String("error", tokenErr.Error()),
))
continue
}
originTokenAddr := common.HexToAddress(strings.Split(keyTokenID, "-")[1])
var originBalance *big.Int
originTokens, ok := inv[origin]
if ok {
originBalance = originTokens[originTokenAddr]
}
keyTokenID := k // Parse token info
parts := strings.Split(keyTokenID, "-")
if len(parts) != 2 {
span.AddEvent("invalid key token ID format", trace.WithAttributes(
attribute.String("key_token_id", keyTokenID),
))
continue
}
origin, tokenErr := strconv.Atoi(parts[0])
if tokenErr != nil {
span.AddEvent("error converting origin chainID", trace.WithAttributes(
attribute.String("key_token_id", keyTokenID),
attribute.String("error", tokenErr.Error()),
))
continue
}
originTokenAddr := common.HexToAddress(parts[1])
var originBalance *big.Int
originTokens, ok := inv[origin]
if ok {
originBalance = originTokens[originTokenAddr]
}


g.Go(func() error {
quote, quoteErr := m.generateQuote(gctx, keyTokenID, chainID, address, balance, destRFQAddr)
input := QuoteInput{
OriginChainID: origin,
DestChainID: chainID,
OriginTokenAddr: originTokenAddr,
DestTokenAddr: address,
OriginBalance: originBalance,
DestBalance: balance,
DestRFQAddr: destRFQAddr,
}

quote, quoteErr := m.generateQuote(gctx, input)
if quoteErr != nil {
// continue generating quotes even if one fails
span.AddEvent("error generating quote", trace.WithAttributes(
Expand All @@ -363,18 +390,20 @@
return quotes, nil
}

func (m *Manager) generateQuote(ctx context.Context, keyTokenID string, chainID int, address common.Address, balance *big.Int, destRFQAddr string) (quote *model.PutQuoteRequest, err error) {
// Parse token info
originStr := strings.Split(keyTokenID, "-")[0]
origin, err := strconv.Atoi(originStr)
if err != nil {
logger.Error("Error converting origin chainID", "error", err)
return nil, fmt.Errorf("error converting origin chainID: %w", err)
}
originTokenAddr := common.HexToAddress(strings.Split(keyTokenID, "-")[1])
// QuoteInput is a wrapper struct for input arguments to generateQuote.
type QuoteInput struct {
OriginChainID int
DestChainID int
OriginTokenAddr common.Address
DestTokenAddr common.Address
OriginBalance *big.Int
DestBalance *big.Int
DestRFQAddr string
}

func (m *Manager) generateQuote(ctx context.Context, input QuoteInput) (quote *model.PutQuoteRequest, err error) {
// Calculate the quote amount for this route
originAmount, err := m.getOriginAmount(ctx, origin, chainID, address, balance)
originAmount, err := m.getOriginAmount(ctx, input)
// don't quote if gas exceeds quote
if errors.Is(err, errMinGasExceedsQuoteAmount) {
Comment on lines +419 to 421
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle potential errors in getting origin amount.

Ensure proper error handling when calling m.getOriginAmount.

-	// don't quote if gas exceeds quote
-	if errors.Is(err, errMinGasExceedsQuoteAmount) {
-		originAmount = big.NewInt(0)
+	if err != nil {
+		if errors.Is(err, errMinGasExceedsQuoteAmount) {
+			originAmount = big.NewInt(0)
+		} else {
+			logger.Error("Error getting quote amount", "error", err)
+			return nil, err
+		}
	}
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
originAmount, err := m.getOriginAmount(ctx, input)
// don't quote if gas exceeds quote
if errors.Is(err, errMinGasExceedsQuoteAmount) {
originAmount, err := m.getOriginAmount(ctx, input)
// don't quote if gas exceeds quote
if err != nil {
if errors.Is(err, errMinGasExceedsQuoteAmount) {
originAmount = big.NewInt(0)
} else {
logger.Error("Error getting quote amount", "error", err)
return nil, err
}
}

originAmount = big.NewInt(0)
Expand All @@ -384,38 +413,38 @@
}

// Calculate the fee for this route
destToken, err := m.config.GetTokenName(uint32(chainID), address.Hex())
destToken, err := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex())
if err != nil {
logger.Error("Error getting dest token ID", "error", err)
return nil, fmt.Errorf("error getting dest token ID: %w", err)
}
fee, err := m.feePricer.GetTotalFee(ctx, uint32(origin), uint32(chainID), destToken, true)
fee, err := m.feePricer.GetTotalFee(ctx, uint32(input.OriginChainID), uint32(input.DestChainID), destToken, true)
Comment on lines +429 to +434
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle potential errors in getting token name.

Ensure proper error handling when calling m.config.GetTokenName.

-	destToken, err := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex())
-	if err != nil {
-		logger.Error("Error getting dest token ID", "error", err)
-		return nil, fmt.Errorf("error getting dest token ID: %w", err)
+	destToken, tokenErr := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex())
+	if tokenErr != nil {
+		logger.Error("Error getting dest token ID", "error", tokenErr)
+		return nil, fmt.Errorf("error getting dest token ID: %w", tokenErr)
	}
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
destToken, err := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex())
if err != nil {
logger.Error("Error getting dest token ID", "error", err)
return nil, fmt.Errorf("error getting dest token ID: %w", err)
}
fee, err := m.feePricer.GetTotalFee(ctx, uint32(origin), uint32(chainID), destToken, true)
fee, err := m.feePricer.GetTotalFee(ctx, uint32(input.OriginChainID), uint32(input.DestChainID), destToken, true)
destToken, tokenErr := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex())
if tokenErr != nil {
logger.Error("Error getting dest token ID", "error", tokenErr)
return nil, fmt.Errorf("error getting dest token ID: %w", tokenErr)
}
fee, err := m.feePricer.GetTotalFee(ctx, uint32(input.OriginChainID), uint32(input.DestChainID), destToken, true)

if err != nil {
logger.Error("Error getting total fee", "error", err)
return nil, fmt.Errorf("error getting total fee: %w", err)
}
originRFQAddr, err := m.config.GetRFQAddress(origin)
originRFQAddr, err := m.config.GetRFQAddress(input.OriginChainID)
if err != nil {
logger.Error("Error getting RFQ address", "error", err)
return nil, fmt.Errorf("error getting RFQ address: %w", err)
}

// Build the quote
destAmount, err := m.getDestAmount(ctx, originAmount, chainID, destToken)
destAmount, err := m.getDestAmount(ctx, originAmount, input.DestChainID, destToken)
if err != nil {
logger.Error("Error getting dest amount", "error", err)
return nil, fmt.Errorf("error getting dest amount: %w", err)
}
quote = &model.PutQuoteRequest{
OriginChainID: origin,
OriginTokenAddr: originTokenAddr.Hex(),
DestChainID: chainID,
DestTokenAddr: address.Hex(),
OriginChainID: input.OriginChainID,
OriginTokenAddr: input.OriginTokenAddr.Hex(),
DestChainID: input.DestChainID,
DestTokenAddr: input.DestTokenAddr.Hex(),
DestAmount: destAmount.String(),
MaxOriginAmount: originAmount.String(),
FixedFee: fee.String(),
OriginFastBridgeAddress: originRFQAddr,
DestFastBridgeAddress: destRFQAddr,
DestFastBridgeAddress: input.DestRFQAddr,
}
return quote, nil
}
Expand Down Expand Up @@ -458,12 +487,14 @@
// getOriginAmount calculates the origin quote amount for a given route.
//
//nolint:cyclop
func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, address common.Address, balance *big.Int) (quoteAmount *big.Int, err error) {
func (m *Manager) getOriginAmount(parentCtx context.Context, input QuoteInput) (quoteAmount *big.Int, err error) {
ctx, span := m.metricsHandler.Tracer().Start(parentCtx, "getOriginAmount", trace.WithAttributes(
attribute.String(metrics.Origin, strconv.Itoa(origin)),
attribute.String(metrics.Destination, strconv.Itoa(dest)),
attribute.String("address", address.String()),
attribute.String("balance", balance.String()),
attribute.Int(metrics.Origin, input.OriginChainID),
attribute.Int(metrics.Destination, input.DestChainID),
attribute.String("dest_address", input.DestTokenAddr.String()),
attribute.String("origin_address", input.OriginTokenAddr.String()),
attribute.String("origin_balance", input.OriginBalance.String()),
attribute.String("dest_balance", input.DestBalance.String()),
))

defer func() {
Expand All @@ -474,11 +505,11 @@
// First, check if we have enough gas to complete the a bridge for this route
// If not, set the quote amount to zero to make sure a stale quote won't be used
// TODO: handle in-flight gas; for now we can set a high min_gas_token
sufficentGasOrigin, err := m.inventoryManager.HasSufficientGas(ctx, origin, nil)
sufficentGasOrigin, err := m.inventoryManager.HasSufficientGas(ctx, input.OriginChainID, nil)
if err != nil {
return nil, fmt.Errorf("error checking sufficient gas: %w", err)
}
sufficentGasDest, err := m.inventoryManager.HasSufficientGas(ctx, dest, nil)
sufficentGasDest, err := m.inventoryManager.HasSufficientGas(ctx, input.DestChainID, nil)
if err != nil {
return nil, fmt.Errorf("error checking sufficient gas: %w", err)
}
Expand All @@ -491,26 +522,26 @@
}

// Apply the quotePct
quotePct, err := m.config.GetQuotePct(dest)
quotePct, err := m.config.GetQuotePct(input.DestChainID)
if err != nil {
return nil, fmt.Errorf("error getting quote pct: %w", err)
}
balanceFlt := new(big.Float).SetInt(balance)
balanceFlt := new(big.Float).SetInt(input.DestBalance)
quoteAmount, _ = new(big.Float).Mul(balanceFlt, new(big.Float).SetFloat64(quotePct/100)).Int(nil)

// Apply the quoteOffset to origin token.
tokenName, err := m.config.GetTokenName(uint32(dest), address.Hex())
tokenName, err := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex())
if err != nil {
return nil, fmt.Errorf("error getting token name: %w", err)
}
quoteOffsetBps, err := m.config.GetQuoteOffsetBps(origin, tokenName, true)
quoteOffsetBps, err := m.config.GetQuoteOffsetBps(input.OriginChainID, tokenName, true)
if err != nil {
return nil, fmt.Errorf("error getting quote offset bps: %w", err)
}
quoteAmount = m.applyOffset(ctx, quoteOffsetBps, quoteAmount)

// Clip the quoteAmount by the minQuoteAmount
minQuoteAmount := m.config.GetMinQuoteAmount(dest, address)
minQuoteAmount := m.config.GetMinQuoteAmount(input.DestChainID, input.DestTokenAddr)
if quoteAmount.Cmp(minQuoteAmount) < 0 {
span.AddEvent("quote amount less than min quote amount", trace.WithAttributes(
attribute.String("quote_amount", quoteAmount.String()),
Expand All @@ -519,17 +550,32 @@
quoteAmount = minQuoteAmount
}

// Finally, clip the quoteAmount by the balance
if quoteAmount.Cmp(balance) > 0 {
span.AddEvent("quote amount greater than balance", trace.WithAttributes(
// Clip the quoteAmount by the max origin balance
maxBalance := m.config.GetMaxBalance(input.OriginChainID, input.OriginTokenAddr)
if input.OriginBalance != nil && maxBalance.Cmp(big.NewInt(0)) > 0 {
ChiTimesChi marked this conversation as resolved.
Show resolved Hide resolved
quotableBalance := new(big.Int).Sub(maxBalance, input.OriginBalance)
ChiTimesChi marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check for negative quotable balance.

Ensure that quotableBalance is not negative before assigning it to quoteAmount.

-	if input.OriginBalance != nil && maxBalance.Cmp(big.NewInt(0)) > 0 {
+	if input.OriginBalance != nil && maxBalance.Sign() > 0 {
		quotableBalance := new(big.Int).Sub(maxBalance, input.OriginBalance)
		if quotableBalance.Cmp(big.NewInt(0)) < 0 {
			span.AddEvent("negative quotable balance", trace.WithAttributes(
				attribute.String("quotable_balance", quotableBalance.String()),
				attribute.String("max_balance", maxBalance.String()),
				attribute.String("origin_balance", input.OriginBalance.String()),
			))
			quoteAmount = big.NewInt(0)
		} else if quoteAmount.Cmp(quotableBalance) > 0 {
			span.AddEvent("quote amount greater than quotable balance", trace.WithAttributes(
				attribute.String("quote_amount", quoteAmount.String()),
				attribute.String("quotable_balance", quotableBalance.String()),
				attribute.String("max_balance", maxBalance.String()),
				attribute.String("origin_balance", input.OriginBalance.String()),
			))
			quoteAmount = quotableBalance
		}
	}
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if input.OriginBalance != nil && maxBalance.Cmp(big.NewInt(0)) > 0 {
quotableBalance := new(big.Int).Sub(maxBalance, input.OriginBalance)
if input.OriginBalance != nil && maxBalance.Sign() > 0 {
quotableBalance := new(big.Int).Sub(maxBalance, input.OriginBalance)
if quotableBalance.Cmp(big.NewInt(0)) < 0 {
span.AddEvent("negative quotable balance", trace.WithAttributes(
attribute.String("quotable_balance", quotableBalance.String()),
attribute.String("max_balance", maxBalance.String()),
attribute.String("origin_balance", input.OriginBalance.String()),
))
quoteAmount = big.NewInt(0)
} else if quoteAmount.Cmp(quotableBalance) > 0 {
span.AddEvent("quote amount greater than quotable balance", trace.WithAttributes(
attribute.String("quote_amount", quoteAmount.String()),
attribute.String("quotable_balance", quotableBalance.String()),
attribute.String("max_balance", maxBalance.String()),
attribute.String("origin_balance", input.OriginBalance.String()),
))
quoteAmount = quotableBalance
}
}

if quoteAmount.Cmp(quotableBalance) > 0 {
span.AddEvent("quote amount greater than quotable balance", trace.WithAttributes(
attribute.String("quote_amount", quoteAmount.String()),
attribute.String("quotable_balance", quotableBalance.String()),
attribute.String("max_balance", maxBalance.String()),
attribute.String("origin_balance", input.OriginBalance.String()),
))
quoteAmount = quotableBalance
}
}

// Finally, clip the quoteAmount by the dest balance
if quoteAmount.Cmp(input.DestBalance) > 0 {
span.AddEvent("quote amount greater than quotable balance", trace.WithAttributes(
dwasse marked this conversation as resolved.
Show resolved Hide resolved
attribute.String("quote_amount", quoteAmount.String()),
attribute.String("balance", balance.String()),
attribute.String("balance", input.DestBalance.String()),
))
quoteAmount = balance
quoteAmount = input.DestBalance
}

// Deduct gas cost from the quote amount, if necessary
quoteAmount, err = m.deductGasCost(ctx, quoteAmount, address, dest)
quoteAmount, err = m.deductGasCost(ctx, quoteAmount, input.DestTokenAddr, input.DestChainID)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle potential errors in deducting gas cost.

Ensure proper error handling when calling m.deductGasCost.

-	quoteAmount, err = m.deductGasCost(ctx, quoteAmount, input.DestTokenAddr, input.DestChainID)
-	if err != nil {
-		return nil, fmt.Errorf("error deducting gas cost: %w", err)
+	quoteAmount, gasErr := m.deductGasCost(ctx, quoteAmount, input.DestTokenAddr, input.DestChainID)
+	if gasErr != nil {
+		return nil, fmt.Errorf("error deducting gas cost: %w", gasErr)
	}
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
quoteAmount, err = m.deductGasCost(ctx, quoteAmount, input.DestTokenAddr, input.DestChainID)
quoteAmount, err = m.deductGasCost(ctx, quoteAmount, input.DestTokenAddr, input.DestChainID)
if err != nil {
return nil, fmt.Errorf("error deducting gas cost: %w", err)
}

if err != nil {
return nil, fmt.Errorf("error deducting gas cost: %w", err)
}
Expand Down
54 changes: 37 additions & 17 deletions services/rfq/relayer/quoter/quoter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import (
func (s *QuoterSuite) TestGenerateQuotes() {
// Generate quotes for USDC on the destination chain.
balance := big.NewInt(1000_000_000) // 1000 USDC
quotes, err := s.manager.GenerateQuotes(s.GetTestContext(), int(s.destination), common.HexToAddress("0x0b2c639c533813f4aa9d7837caf62653d097ff85"), balance)
inv := map[int]map[common.Address]*big.Int{}
quotes, err := s.manager.GenerateQuotes(s.GetTestContext(), int(s.destination), common.HexToAddress("0x0b2c639c533813f4aa9d7837caf62653d097ff85"), balance, inv)
s.Require().NoError(err)

// Verify the quotes are generated as expected.
Expand All @@ -43,7 +44,8 @@ func (s *QuoterSuite) TestGenerateQuotes() {
func (s *QuoterSuite) TestGenerateQuotesForNativeToken() {
// Generate quotes for ETH on the destination chain.
balance, _ := new(big.Int).SetString("1000000000000000000", 10) // 1 ETH
quotes, err := s.manager.GenerateQuotes(s.GetTestContext(), int(s.destinationEth), chain.EthAddress, balance)
inv := map[int]map[common.Address]*big.Int{}
quotes, err := s.manager.GenerateQuotes(s.GetTestContext(), int(s.destinationEth), chain.EthAddress, balance, inv)
s.Require().NoError(err)

minGasToken, err := s.config.GetMinGasToken(int(s.destination))
Expand All @@ -68,7 +70,7 @@ func (s *QuoterSuite) TestGenerateQuotesForNativeToken() {
s.config.BaseChainConfig.MinGasToken = "100000000000000000" // 0.1 ETH
s.manager.SetConfig(s.config)

quotes, err = s.manager.GenerateQuotes(s.GetTestContext(), int(s.destinationEth), chain.EthAddress, balance)
quotes, err = s.manager.GenerateQuotes(s.GetTestContext(), int(s.destinationEth), chain.EthAddress, balance, inv)
s.Require().NoError(err)

minGasToken, err = s.config.GetMinGasToken(int(s.destination))
Expand All @@ -93,7 +95,7 @@ func (s *QuoterSuite) TestGenerateQuotesForNativeToken() {
s.config.BaseChainConfig.MinGasToken = "1000000000000000001" // 0.1 ETH
s.manager.SetConfig(s.config)

quotes, err = s.manager.GenerateQuotes(s.GetTestContext(), int(s.destinationEth), chain.EthAddress, balance)
quotes, err = s.manager.GenerateQuotes(s.GetTestContext(), int(s.destinationEth), chain.EthAddress, balance, inv)
s.NoError(err)
s.Equal(quotes[0].DestAmount, "0")
s.Equal(quotes[0].MaxOriginAmount, "0")
Expand Down Expand Up @@ -162,63 +164,81 @@ func (s *QuoterSuite) TestGetOriginAmount() {
origin := int(s.origin)
dest := int(s.destination)
address := common.HexToAddress("0x0b2c639c533813f4aa9d7837caf62653d097ff85")
originAddr := common.HexToAddress("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48")
balance := big.NewInt(1000_000_000) // 1000 USDC

setQuoteParams := func(quotePct, quoteOffset float64, minQuoteAmount string) {
setQuoteParams := func(quotePct, quoteOffset float64, minQuoteAmount, maxBalance string) {
s.config.BaseChainConfig.QuotePct = quotePct
destTokenCfg := s.config.Chains[dest].Tokens["USDC"]
destTokenCfg.MinQuoteAmount = minQuoteAmount
originTokenCfg := s.config.Chains[origin].Tokens["USDC"]
originTokenCfg.QuoteOffsetBps = quoteOffset
originTokenCfg.MaxBalance = maxBalance
s.config.Chains[dest].Tokens["USDC"] = destTokenCfg
s.config.Chains[origin].Tokens["USDC"] = originTokenCfg
s.manager.SetConfig(s.config)
}

input := quoter.QuoteInput{
OriginChainID: origin,
DestChainID: dest,
OriginTokenAddr: originAddr,
DestTokenAddr: address,
OriginBalance: balance,
DestBalance: balance,
}

// Set default quote params; should return the balance.
quoteAmount, err := s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, address, balance)
quoteAmount, err := s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount := balance
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 50 with MinQuoteAmount of 0; should be 50% of balance.
setQuoteParams(50, 0, "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, address, balance)
setQuoteParams(50, 0, "0", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(500_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 50 with QuoteOffset of -1%. Should be 1% less than 50% of balance.
setQuoteParams(50, -100, "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, address, balance)
setQuoteParams(50, -100, "0", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(495_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 25 with MinQuoteAmount of 500; should be 50% of balance.
setQuoteParams(25, 0, "500")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, address, balance)
setQuoteParams(25, 0, "500", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(500_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 25 with MinQuoteAmount of 500; should be 50% of balance.
setQuoteParams(25, 0, "500")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, address, balance)
setQuoteParams(25, 0, "500", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(500_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 25 with MinQuoteAmount of 1500; should be total balance.
setQuoteParams(25, 0, "1500")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, address, balance)
setQuoteParams(25, 0, "1500", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(1000_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 25 with MinQuoteAmount of 1500 and MaxBalance of 1200; should be 200.
setQuoteParams(25, 0, "1500", "1200")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(200_000_000)
s.Equal(expectedAmount, quoteAmount)

// Toggle insufficient gas; should be 0.
s.setGasSufficiency(false)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, address, balance)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(0)
s.Equal(expectedAmount, quoteAmount)
Expand Down
Loading
Loading