Skip to content

Commit

Permalink
Feat: getOriginQuoteAmount takes QuoteInput struct
Browse files Browse the repository at this point in the history
  • Loading branch information
dwasse committed Jul 25, 2024
1 parent 8a6cf15 commit fa5d27f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 59 deletions.
4 changes: 2 additions & 2 deletions services/rfq/relayer/quoter/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ func (m *Manager) GenerateQuotes(ctx context.Context, chainID int, address commo
return m.generateQuotes(ctx, chainID, address, balance, inv)
}

func (m *Manager) GetOriginAmount(ctx context.Context, origin, dest int, originAddr common.Address, address common.Address, originBalance, destBalance *big.Int) (*big.Int, error) {
return m.getOriginAmount(ctx, origin, dest, originAddr, address, originBalance, destBalance)
func (m *Manager) GetOriginAmount(ctx context.Context, input QuoteInput) (*big.Int, error) {
return m.getOriginAmount(ctx, input)
}

func (m *Manager) GetDestAmount(ctx context.Context, quoteAmount *big.Int, chainID int, tokenName string) (*big.Int, error) {
Expand Down
99 changes: 50 additions & 49 deletions services/rfq/relayer/quoter/quoter.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,14 @@ func (m *Manager) generateQuotes(parentCtx context.Context, chainID int, address
}

g.Go(func() error {
input := quoteInput{
originChainID: origin,
destChainID: chainID,
originTokenAddr: originTokenAddr,
destTokenAddr: address,
originBalance: originBalance,
destBalance: balance,
destRFQAddr: destRFQAddr,
input := QuoteInput{
OriginChainID: origin,
DestChainID: chainID,
OriginTokenAddr: originTokenAddr,
DestTokenAddr: address,
OriginBalance: originBalance,
DestBalance: balance,
DestRFQAddr: destRFQAddr,
}

quote, quoteErr := m.generateQuote(gctx, input)
Expand Down Expand Up @@ -390,19 +390,20 @@ func (m *Manager) generateQuotes(parentCtx context.Context, chainID int, address
return quotes, nil
}

type quoteInput struct {
originChainID int
destChainID int
originTokenAddr common.Address
destTokenAddr common.Address
originBalance *big.Int
destBalance *big.Int
destRFQAddr string
// QuoteInput is a wrapper struct for input arguments to generateQuote.
type QuoteInput struct {
OriginChainID int
DestChainID int
OriginTokenAddr common.Address
DestTokenAddr common.Address
OriginBalance *big.Int
DestBalance *big.Int
DestRFQAddr string
}

func (m *Manager) generateQuote(ctx context.Context, input quoteInput) (quote *model.PutQuoteRequest, err error) {
func (m *Manager) generateQuote(ctx context.Context, input QuoteInput) (quote *model.PutQuoteRequest, err error) {
// Calculate the quote amount for this route
originAmount, err := m.getOriginAmount(ctx, input.originChainID, input.destChainID, input.originTokenAddr, input.destTokenAddr, input.originBalance, input.destBalance)
originAmount, err := m.getOriginAmount(ctx, input)
// don't quote if gas exceeds quote
if errors.Is(err, errMinGasExceedsQuoteAmount) {
originAmount = big.NewInt(0)
Expand All @@ -412,38 +413,38 @@ func (m *Manager) generateQuote(ctx context.Context, input quoteInput) (quote *m
}

// Calculate the fee for this route
destToken, err := m.config.GetTokenName(uint32(input.destChainID), input.destTokenAddr.Hex())
destToken, err := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex())
if err != nil {
logger.Error("Error getting dest token ID", "error", err)
return nil, fmt.Errorf("error getting dest token ID: %w", err)
}
fee, err := m.feePricer.GetTotalFee(ctx, uint32(input.originChainID), uint32(input.destChainID), destToken, true)
fee, err := m.feePricer.GetTotalFee(ctx, uint32(input.OriginChainID), uint32(input.DestChainID), destToken, true)
if err != nil {
logger.Error("Error getting total fee", "error", err)
return nil, fmt.Errorf("error getting total fee: %w", err)
}
originRFQAddr, err := m.config.GetRFQAddress(input.originChainID)
originRFQAddr, err := m.config.GetRFQAddress(input.OriginChainID)
if err != nil {
logger.Error("Error getting RFQ address", "error", err)
return nil, fmt.Errorf("error getting RFQ address: %w", err)
}

// Build the quote
destAmount, err := m.getDestAmount(ctx, originAmount, input.destChainID, destToken)
destAmount, err := m.getDestAmount(ctx, originAmount, input.DestChainID, destToken)
if err != nil {
logger.Error("Error getting dest amount", "error", err)
return nil, fmt.Errorf("error getting dest amount: %w", err)
}
quote = &model.PutQuoteRequest{
OriginChainID: input.originChainID,
OriginTokenAddr: input.originTokenAddr.Hex(),
DestChainID: input.destChainID,
DestTokenAddr: input.destTokenAddr.Hex(),
OriginChainID: input.OriginChainID,
OriginTokenAddr: input.OriginTokenAddr.Hex(),
DestChainID: input.DestChainID,
DestTokenAddr: input.DestTokenAddr.Hex(),
DestAmount: destAmount.String(),
MaxOriginAmount: originAmount.String(),
FixedFee: fee.String(),
OriginFastBridgeAddress: originRFQAddr,
DestFastBridgeAddress: input.destRFQAddr,
DestFastBridgeAddress: input.DestRFQAddr,
}
return quote, nil
}
Expand Down Expand Up @@ -486,13 +487,14 @@ func (m *Manager) recordQuoteAmounts(_ context.Context, observer metric.Observer
// getOriginAmount calculates the origin quote amount for a given route.
//
//nolint:cyclop
func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, originAddress, destAddress common.Address, originBalance, destBalance *big.Int) (quoteAmount *big.Int, err error) {
func (m *Manager) getOriginAmount(parentCtx context.Context, input QuoteInput) (quoteAmount *big.Int, err error) {
ctx, span := m.metricsHandler.Tracer().Start(parentCtx, "getOriginAmount", trace.WithAttributes(
attribute.String(metrics.Origin, strconv.Itoa(origin)),
attribute.String(metrics.Destination, strconv.Itoa(dest)),
attribute.String("address", destAddress.String()),
attribute.String("origin_balance", originBalance.String()),
attribute.String("dest_balance", destBalance.String()),
attribute.Int(metrics.Origin, input.OriginChainID),
attribute.Int(metrics.Destination, input.DestChainID),
attribute.String("dest_address", input.DestTokenAddr.String()),
attribute.String("origin_address", input.OriginTokenAddr.String()),
attribute.String("origin_balance", input.OriginBalance.String()),
attribute.String("dest_balance", input.DestBalance.String()),
))

defer func() {
Expand All @@ -503,11 +505,11 @@ func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, o
// First, check if we have enough gas to complete the a bridge for this route
// If not, set the quote amount to zero to make sure a stale quote won't be used
// TODO: handle in-flight gas; for now we can set a high min_gas_token
sufficentGasOrigin, err := m.inventoryManager.HasSufficientGas(ctx, origin, nil)
sufficentGasOrigin, err := m.inventoryManager.HasSufficientGas(ctx, input.OriginChainID, nil)
if err != nil {
return nil, fmt.Errorf("error checking sufficient gas: %w", err)
}
sufficentGasDest, err := m.inventoryManager.HasSufficientGas(ctx, dest, nil)
sufficentGasDest, err := m.inventoryManager.HasSufficientGas(ctx, input.DestChainID, nil)
if err != nil {
return nil, fmt.Errorf("error checking sufficient gas: %w", err)
}
Expand All @@ -520,26 +522,26 @@ func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, o
}

// Apply the quotePct
quotePct, err := m.config.GetQuotePct(dest)
quotePct, err := m.config.GetQuotePct(input.DestChainID)
if err != nil {
return nil, fmt.Errorf("error getting quote pct: %w", err)
}
balanceFlt := new(big.Float).SetInt(destBalance)
balanceFlt := new(big.Float).SetInt(input.DestBalance)
quoteAmount, _ = new(big.Float).Mul(balanceFlt, new(big.Float).SetFloat64(quotePct/100)).Int(nil)

// Apply the quoteOffset to origin token.
tokenName, err := m.config.GetTokenName(uint32(dest), destAddress.Hex())
tokenName, err := m.config.GetTokenName(uint32(input.DestChainID), input.DestTokenAddr.Hex())
if err != nil {
return nil, fmt.Errorf("error getting token name: %w", err)
}
quoteOffsetBps, err := m.config.GetQuoteOffsetBps(origin, tokenName, true)
quoteOffsetBps, err := m.config.GetQuoteOffsetBps(input.OriginChainID, tokenName, true)
if err != nil {
return nil, fmt.Errorf("error getting quote offset bps: %w", err)
}
quoteAmount = m.applyOffset(ctx, quoteOffsetBps, quoteAmount)

// Clip the quoteAmount by the minQuoteAmount
minQuoteAmount := m.config.GetMinQuoteAmount(dest, destAddress)
minQuoteAmount := m.config.GetMinQuoteAmount(input.DestChainID, input.DestTokenAddr)
if quoteAmount.Cmp(minQuoteAmount) < 0 {
span.AddEvent("quote amount less than min quote amount", trace.WithAttributes(
attribute.String("quote_amount", quoteAmount.String()),
Expand All @@ -549,32 +551,31 @@ func (m *Manager) getOriginAmount(parentCtx context.Context, origin, dest int, o
}

// Clip the quoteAmount by the max origin balance
maxBalance := m.config.GetMaxBalance(origin, originAddress)
fmt.Printf("maxBalance: %v originBalance: %v\n", maxBalance, originBalance)
if originBalance != nil && maxBalance.Cmp(big.NewInt(0)) > 0 {
quotableBalance := new(big.Int).Sub(maxBalance, originBalance)
maxBalance := m.config.GetMaxBalance(input.OriginChainID, input.OriginTokenAddr)
if input.OriginBalance != nil && maxBalance.Cmp(big.NewInt(0)) > 0 {
quotableBalance := new(big.Int).Sub(maxBalance, input.OriginBalance)
if quoteAmount.Cmp(quotableBalance) > 0 {
span.AddEvent("quote amount greater than quotable balance", trace.WithAttributes(
attribute.String("quote_amount", quoteAmount.String()),
attribute.String("quotable_balance", quotableBalance.String()),
attribute.String("max_balance", maxBalance.String()),
attribute.String("origin_balance", originBalance.String()),
attribute.String("origin_balance", input.OriginBalance.String()),
))
quoteAmount = quotableBalance
}
}

// Finally, clip the quoteAmount by the dest balance
if quoteAmount.Cmp(destBalance) > 0 {
if quoteAmount.Cmp(input.DestBalance) > 0 {
span.AddEvent("quote amount greater than quotable balance", trace.WithAttributes(
attribute.String("quote_amount", quoteAmount.String()),
attribute.String("balance", destBalance.String()),
attribute.String("balance", input.DestBalance.String()),
))
quoteAmount = destBalance
quoteAmount = input.DestBalance
}

// Deduct gas cost from the quote amount, if necessary
quoteAmount, err = m.deductGasCost(ctx, quoteAmount, destAddress, dest)
quoteAmount, err = m.deductGasCost(ctx, quoteAmount, input.DestTokenAddr, input.DestChainID)
if err != nil {
return nil, fmt.Errorf("error deducting gas cost: %w", err)
}
Expand Down
25 changes: 17 additions & 8 deletions services/rfq/relayer/quoter/quoter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,57 +179,66 @@ func (s *QuoterSuite) TestGetOriginAmount() {
s.manager.SetConfig(s.config)
}

input := quoter.QuoteInput{
OriginChainID: origin,
DestChainID: dest,
OriginTokenAddr: originAddr,
DestTokenAddr: address,
OriginBalance: balance,
DestBalance: balance,
}

// Set default quote params; should return the balance.
quoteAmount, err := s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance)
quoteAmount, err := s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount := balance
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 50 with MinQuoteAmount of 0; should be 50% of balance.
setQuoteParams(50, 0, "0", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(500_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 50 with QuoteOffset of -1%. Should be 1% less than 50% of balance.
setQuoteParams(50, -100, "0", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(495_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 25 with MinQuoteAmount of 500; should be 50% of balance.
setQuoteParams(25, 0, "500", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(500_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 25 with MinQuoteAmount of 500; should be 50% of balance.
setQuoteParams(25, 0, "500", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(500_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 25 with MinQuoteAmount of 1500; should be total balance.
setQuoteParams(25, 0, "1500", "0")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(1000_000_000)
s.Equal(expectedAmount, quoteAmount)

// Set QuotePct to 25 with MinQuoteAmount of 1500 and MaxBalance of 1200; should be 200.
setQuoteParams(25, 0, "1500", "1200")
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(200_000_000)
s.Equal(expectedAmount, quoteAmount)

// Toggle insufficient gas; should be 0.
s.setGasSufficiency(false)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), origin, dest, originAddr, address, balance, balance)
quoteAmount, err = s.manager.GetOriginAmount(s.GetTestContext(), input)
s.NoError(err)
expectedAmount = big.NewInt(0)
s.Equal(expectedAmount, quoteAmount)
Expand Down

0 comments on commit fa5d27f

Please sign in to comment.