diff --git a/services/rfq/relayer/quoter/export_test.go b/services/rfq/relayer/quoter/export_test.go index 231f831c24..81d719ae03 100644 --- a/services/rfq/relayer/quoter/export_test.go +++ b/services/rfq/relayer/quoter/export_test.go @@ -14,8 +14,8 @@ func (m *Manager) GenerateQuotes(ctx context.Context, chainID int, address commo return m.generateQuotes(ctx, chainID, address, balance, inv) } -func (m *Manager) GetOriginAmount(ctx context.Context, origin, dest int, originAddr common.Address, address common.Address, originBalance, destBalance *big.Int) (*big.Int, error) { - return m.getOriginAmount(ctx, origin, dest, originAddr, address, originBalance, destBalance) +func (m *Manager) GetOriginAmount(ctx context.Context, input QuoteInput) (*big.Int, error) { + return m.getOriginAmount(ctx, input) } func (m *Manager) GetDestAmount(ctx context.Context, quoteAmount *big.Int, chainID int, tokenName string) (*big.Int, error) { diff --git a/services/rfq/relayer/quoter/quoter.go b/services/rfq/relayer/quoter/quoter.go index 0d22e2999b..d6d017494f 100644 --- a/services/rfq/relayer/quoter/quoter.go +++ b/services/rfq/relayer/quoter/quoter.go @@ -354,14 +354,14 @@ func (m *Manager) generateQuotes(parentCtx context.Context, chainID int, address } g.Go(func() error { - input := quoteInput{ - originChainID: origin, - destChainID: chainID, - originTokenAddr: originTokenAddr, - destTokenAddr: address, - originBalance: originBalance, - destBalance: balance, - destRFQAddr: destRFQAddr, + input := QuoteInput{ + OriginChainID: origin, + DestChainID: chainID, + OriginTokenAddr: originTokenAddr, + DestTokenAddr: address, + OriginBalance: originBalance, + DestBalance: balance, + DestRFQAddr: destRFQAddr, } quote, quoteErr := m.generateQuote(gctx, input) @@ -390,19 +390,20 @@ func (m *Manager) generateQuotes(parentCtx context.Context, chainID int, address return quotes, nil } -type quoteInput struct { - originChainID int - destChainID int - originTokenAddr common.Address - destTokenAddr common.Address - originBalance *big.Int - destBalance *big.Int - destRFQAddr string +// QuoteInput is a wrapper struct for input arguments to generateQuote. +type QuoteInput struct { + OriginChainID int + DestChainID int + OriginTokenAddr common.Address + DestTokenAddr common.Address + OriginBalance *big.Int + DestBalance *big.Int + DestRFQAddr string } -func (m *Manager) generateQuote(ctx context.Context, input quoteInput) (quote *model.PutQuoteRequest, err error) { +func (m *Manager) generateQuote(ctx context.Context, input QuoteInput) (quote *model.PutQuoteRequest, err error) { // Calculate the quote amount for this route - originAmount, err := m.getOriginAmount(ctx, input.originChainID, input.destChainID, input.originTokenAddr, input.destTokenAddr, input.originBalance, input.destBalance) + originAmount, err := m.getOriginAmount(ctx, input) // don't quote if gas exceeds quote if errors.Is(err, errMinGasExceedsQuoteAmount) { originAmount = big.NewInt(0) @@ -412,38 +413,38 @@ func (m *Manager) generateQuote(ctx context.Context, input quoteInput) (quote *m } // Calculate the fee for this route - destToken, err := m.config.GetTokenName(uint32(input.destChainID), input.destTokenAddr.Hex()) + destToken, err := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex()) if err != nil { logger.Error("Error getting dest token ID", "error", err) return nil, fmt.Errorf("error getting dest token ID: %w", err) } - fee, err := m.feePricer.GetTotalFee(ctx, uint32(input.originChainID), uint32(input.destChainID), destToken, true) + fee, err := m.feePricer.GetTotalFee(ctx, uint32(input.OriginChainID), uint32(input.DestChainID), destToken, true) if err != nil { logger.Error("Error getting total fee", "error", err) return nil, fmt.Errorf("error getting total fee: %w", err) } - originRFQAddr, err := m.config.GetRFQAddress(input.originChainID) + originRFQAddr, err := m.config.GetRFQAddress(input.OriginChainID) if err != nil { logger.Error("Error getting RFQ address", "error", err) return nil, fmt.Errorf("error getting RFQ address: %w", err) } // Build the quote - destAmount, err := m.getDestAmount(ctx, originAmount, input.destChainID, destToken) + destAmount, err := m.getDestAmount(ctx, originAmount, input.DestChainID, destToken) if err != nil { logger.Error("Error getting dest amount", "error", err) return nil, fmt.Errorf("error getting dest amount: %w", err) } quote = &model.PutQuoteRequest{ - OriginChainID: input.originChainID, - OriginTokenAddr: input.originTokenAddr.Hex(), - DestChainID: input.destChainID, - DestTokenAddr: input.destTokenAddr.Hex(), + OriginChainID: input.OriginChainID, + OriginTokenAddr: input.OriginTokenAddr.Hex(), + DestChainID: input.DestChainID, + DestTokenAddr: input.DestTokenAddr.Hex(), DestAmount: destAmount.String(), MaxOriginAmount: originAmount.String(), FixedFee: fee.String(), OriginFastBridgeAddress: originRFQAddr, - DestFastBridgeAddress: input.destRFQAddr, + DestFastBridgeAddress: input.DestRFQAddr, } return quote, nil } @@ -486,13 +487,14 @@ func (m *Manager) recordQuoteAmounts(_ context.Context, observer metric.Observer // getOriginAmount calculates the origin quote amount for a given route. // //nolint:cyclop -func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, originAddress, destAddress common.Address, originBalance, destBalance *big.Int) (quoteAmount *big.Int, err error) { +func (m *Manager) getOriginAmount(parentCtx context.Context, input QuoteInput) (quoteAmount *big.Int, err error) { ctx, span := m.metricsHandler.Tracer().Start(parentCtx, "getOriginAmount", trace.WithAttributes( - attribute.String(metrics.Origin, strconv.Itoa(origin)), - attribute.String(metrics.Destination, strconv.Itoa(dest)), - attribute.String("address", destAddress.String()), - attribute.String("origin_balance", originBalance.String()), - attribute.String("dest_balance", destBalance.String()), + attribute.Int(metrics.Origin, input.OriginChainID), + attribute.Int(metrics.Destination, input.DestChainID), + attribute.String("dest_address", input.DestTokenAddr.String()), + attribute.String("origin_address", input.OriginTokenAddr.String()), + attribute.String("origin_balance", input.OriginBalance.String()), + attribute.String("dest_balance", input.DestBalance.String()), )) defer func() { @@ -503,11 +505,11 @@ func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, o // First, check if we have enough gas to complete the a bridge for this route // If not, set the quote amount to zero to make sure a stale quote won't be used // TODO: handle in-flight gas; for now we can set a high min_gas_token - sufficentGasOrigin, err := m.inventoryManager.HasSufficientGas(ctx, origin, nil) + sufficentGasOrigin, err := m.inventoryManager.HasSufficientGas(ctx, input.OriginChainID, nil) if err != nil { return nil, fmt.Errorf("error checking sufficient gas: %w", err) } - sufficentGasDest, err := m.inventoryManager.HasSufficientGas(ctx, dest, nil) + sufficentGasDest, err := m.inventoryManager.HasSufficientGas(ctx, input.DestChainID, nil) if err != nil { return nil, fmt.Errorf("error checking sufficient gas: %w", err) } @@ -520,26 +522,26 @@ func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, o } // Apply the quotePct - quotePct, err := m.config.GetQuotePct(dest) + quotePct, err := m.config.GetQuotePct(input.DestChainID) if err != nil { return nil, fmt.Errorf("error getting quote pct: %w", err) } - balanceFlt := new(big.Float).SetInt(destBalance) + balanceFlt := new(big.Float).SetInt(input.DestBalance) quoteAmount, _ = new(big.Float).Mul(balanceFlt, new(big.Float).SetFloat64(quotePct/100)).Int(nil) // Apply the quoteOffset to origin token. - tokenName, err := m.config.GetTokenName(uint32(dest), destAddress.Hex()) + tokenName, err := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex()) if err != nil { return nil, fmt.Errorf("error getting token name: %w", err) } - quoteOffsetBps, err := m.config.GetQuoteOffsetBps(origin, tokenName, true) + quoteOffsetBps, err := m.config.GetQuoteOffsetBps(input.OriginChainID, tokenName, true) if err != nil { return nil, fmt.Errorf("error getting quote offset bps: %w", err) } quoteAmount = m.applyOffset(ctx, quoteOffsetBps, quoteAmount) // Clip the quoteAmount by the minQuoteAmount - minQuoteAmount := m.config.GetMinQuoteAmount(dest, destAddress) + minQuoteAmount := m.config.GetMinQuoteAmount(input.DestChainID, input.DestTokenAddr) if quoteAmount.Cmp(minQuoteAmount) < 0 { span.AddEvent("quote amount less than min quote amount", trace.WithAttributes( attribute.String("quote_amount", quoteAmount.String()), @@ -549,32 +551,31 @@ func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, o } // Clip the quoteAmount by the max origin balance - maxBalance := m.config.GetMaxBalance(origin, originAddress) - fmt.Printf("maxBalance: %v originBalance: %v\n", maxBalance, originBalance) - if originBalance != nil && maxBalance.Cmp(big.NewInt(0)) > 0 { - quotableBalance := new(big.Int).Sub(maxBalance, originBalance) + maxBalance := m.config.GetMaxBalance(input.OriginChainID, input.OriginTokenAddr) + if input.OriginBalance != nil && maxBalance.Cmp(big.NewInt(0)) > 0 { + quotableBalance := new(big.Int).Sub(maxBalance, input.OriginBalance) if quoteAmount.Cmp(quotableBalance) > 0 { span.AddEvent("quote amount greater than quotable balance", trace.WithAttributes( attribute.String("quote_amount", quoteAmount.String()), attribute.String("quotable_balance", quotableBalance.String()), attribute.String("max_balance", maxBalance.String()), - attribute.String("origin_balance", originBalance.String()), + attribute.String("origin_balance", input.OriginBalance.String()), )) quoteAmount = quotableBalance } } // Finally, clip the quoteAmount by the dest balance - if quoteAmount.Cmp(destBalance) > 0 { + if quoteAmount.Cmp(input.DestBalance) > 0 { span.AddEvent("quote amount greater than quotable balance", trace.WithAttributes( attribute.String("quote_amount", quoteAmount.String()), - attribute.String("balance", destBalance.String()), + attribute.String("balance", input.DestBalance.String()), )) - quoteAmount = destBalance + quoteAmount = input.DestBalance } // Deduct gas cost from the quote amount, if necessary - quoteAmount, err = m.deductGasCost(ctx, quoteAmount, destAddress, dest) + quoteAmount, err = m.deductGasCost(ctx, quoteAmount, input.DestTokenAddr, input.DestChainID) if err != nil { return nil, fmt.Errorf("error deducting gas cost: %w", err) } diff --git a/services/rfq/relayer/quoter/quoter_test.go b/services/rfq/relayer/quoter/quoter_test.go index fec5e7f631..094ad3a29c 100644 --- a/services/rfq/relayer/quoter/quoter_test.go +++ b/services/rfq/relayer/quoter/quoter_test.go @@ -179,57 +179,66 @@ func (s *QuoterSuite) TestGetOriginAmount() { s.manager.SetConfig(s.config) } + input := quoter.QuoteInput{ + OriginChainID: origin, + DestChainID: dest, + OriginTokenAddr: originAddr, + DestTokenAddr: address, + OriginBalance: balance, + DestBalance: balance, + } + // Set default quote params; should return the balance. - quoteAmount, err := s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance) + quoteAmount, err := s.manager.GetOriginAmount(s.GetTestContext(), input) s.NoError(err) expectedAmount := balance s.Equal(expectedAmount, quoteAmount) // Set QuotePct to 50 with MinQuoteAmount of 0; should be 50% of balance. setQuoteParams(50, 0, "0", "0") - quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance) + quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input) s.NoError(err) expectedAmount = big.NewInt(500_000_000) s.Equal(expectedAmount, quoteAmount) // Set QuotePct to 50 with QuoteOffset of -1%. Should be 1% less than 50% of balance. setQuoteParams(50, -100, "0", "0") - quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance) + quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input) s.NoError(err) expectedAmount = big.NewInt(495_000_000) s.Equal(expectedAmount, quoteAmount) // Set QuotePct to 25 with MinQuoteAmount of 500; should be 50% of balance. setQuoteParams(25, 0, "500", "0") - quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance) + quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input) s.NoError(err) expectedAmount = big.NewInt(500_000_000) s.Equal(expectedAmount, quoteAmount) // Set QuotePct to 25 with MinQuoteAmount of 500; should be 50% of balance. setQuoteParams(25, 0, "500", "0") - quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance) + quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input) s.NoError(err) expectedAmount = big.NewInt(500_000_000) s.Equal(expectedAmount, quoteAmount) // Set QuotePct to 25 with MinQuoteAmount of 1500; should be total balance. setQuoteParams(25, 0, "1500", "0") - quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance) + quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input) s.NoError(err) expectedAmount = big.NewInt(1000_000_000) s.Equal(expectedAmount, quoteAmount) // Set QuotePct to 25 with MinQuoteAmount of 1500 and MaxBalance of 1200; should be 200. setQuoteParams(25, 0, "1500", "1200") - quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance) + quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input) s.NoError(err) expectedAmount = big.NewInt(200_000_000) s.Equal(expectedAmount, quoteAmount) // Toggle insufficient gas; should be 0. s.setGasSufficiency(false) - quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance) + quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input) s.NoError(err) expectedAmount = big.NewInt(0) s.Equal(expectedAmount, quoteAmount)