From e198f903b360331d6994094ea4f4884a0b892ef8 Mon Sep 17 00:00:00 2001 From: Martin Kou Date: Wed, 30 Mar 2022 17:24:46 -0700 Subject: [PATCH] Fixed an issue with amm_arb start, in which gateway connector assets are changed after changing the AMM markets. --- hummingbot/client/command/balance_command.py | 2 +- hummingbot/strategy/amm_arb/amm_arb.py | 5 ++- hummingbot/user/user_balances.py | 45 +++++++++++--------- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/hummingbot/client/command/balance_command.py b/hummingbot/client/command/balance_command.py index 9be2522915..9e91abec90 100644 --- a/hummingbot/client/command/balance_command.py +++ b/hummingbot/client/command/balance_command.py @@ -89,7 +89,7 @@ async def show_balances(self): except asyncio.TimeoutError: self.notify("\nA network error prevented the balances to update. See logs for more details.") raise - all_ex_avai_bals = UserBalances.instance().all_avai_balances_all_exchanges() + all_ex_avai_bals = UserBalances.instance().all_available_balances_all_exchanges() all_ex_limits: Optional[Dict[str, Dict[str, str]]] = global_config_map["balance_asset_limit"].value if all_ex_limits is None: diff --git a/hummingbot/strategy/amm_arb/amm_arb.py b/hummingbot/strategy/amm_arb/amm_arb.py index 66ab3f64f0..c6b3a6345d 100644 --- a/hummingbot/strategy/amm_arb/amm_arb.py +++ b/hummingbot/strategy/amm_arb/amm_arb.py @@ -1,8 +1,8 @@ import logging import asyncio -import pandas as pd - from decimal import Decimal +from functools import lru_cache +import pandas as pd from typing import List, Dict, Tuple, Optional, Callable, cast from hummingbot.client.performance import PerformanceMetrics @@ -144,6 +144,7 @@ def market_info_to_active_orders(self) -> Dict[MarketTradingPairTuple, List[Limi return self._sb_order_tracker.market_pair_to_active_orders @staticmethod + @lru_cache(maxsize=10) def is_gateway_market(market_info: MarketTradingPairTuple) -> bool: return market_info.market.name in AllConnectorSettings.get_gateway_evm_amm_connector_names() diff --git a/hummingbot/user/user_balances.py b/hummingbot/user/user_balances.py index 2d3588b83d..dbbd6e7a26 100644 --- a/hummingbot/user/user_balances.py +++ b/hummingbot/user/user_balances.py @@ -1,12 +1,13 @@ +from decimal import Decimal +from functools import lru_cache +from typing import Optional, Dict, List + from hummingbot.core.utils.market_price import get_last_price from hummingbot.client.settings import AllConnectorSettings, gateway_connector_trading_pairs from hummingbot.client.config.security import Security from hummingbot.client.config.config_helpers import get_connector_class from hummingbot.core.utils.async_utils import safe_gather -from typing import Optional, Dict, List -from decimal import Decimal - class UserBalances: __instance = None @@ -37,6 +38,11 @@ def instance(): UserBalances() return UserBalances.__instance + @staticmethod + @lru_cache(maxsize=10) + def is_gateway_market(exchange_name: str) -> bool: + return exchange_name in AllConnectorSettings.get_gateway_evm_amm_connector_names() + def __init__(self): if UserBalances.__instance is not None: raise Exception("This class is a singleton!") @@ -59,16 +65,23 @@ def all_balances(self, exchange) -> Dict[str, Decimal]: return {} return self._markets[exchange].get_all_balances() - async def update_exchange_balance(self, exchange) -> Optional[str]: - if exchange in self._markets: - return await self._update_balances(self._markets[exchange]) + async def update_exchange_balance(self, exchange_name: str) -> Optional[str]: + if self.is_gateway_market(exchange_name) and exchange_name in self._markets: + # we want to refresh gateway connectors always, since the applicable tokens change over time. + # doing this will reinitialize and fetch balances for active trading pair + del self._markets[exchange_name] + if exchange_name in self._markets: + return await self._update_balances(self._markets[exchange_name]) else: - api_keys = await Security.api_keys(exchange) - return await self.add_exchange(exchange, **api_keys) + api_keys = await Security.api_keys(exchange_name) + return await self.add_exchange(exchange_name, **api_keys) # returns error message for each exchange - async def update_exchanges(self, reconnect: bool = False, - exchanges: List[str] = []) -> Dict[str, Optional[str]]: + async def update_exchanges( + self, + reconnect: bool = False, + exchanges: List[str] = [] + ) -> Dict[str, Optional[str]]: tasks = [] # Update user balances if len(exchanges) == 0: @@ -81,19 +94,9 @@ async def update_exchanges(self, reconnect: bool = False, and not cs.name.endswith("paper_trade") ] - gateway_connectors: List[str] = [ - cs.name - for cs in AllConnectorSettings.get_connector_settings().values() - if cs.uses_gateway_generic_connector() - ] - if reconnect: self._markets.clear() for exchange in exchanges: - if exchange in gateway_connectors and exchange in self._markets: - # we want to refresh gateway connectors always - # doing this will reinitialize and fetch balances for active trading pair - del self._markets[exchange] tasks.append(self.update_exchange_balance(exchange)) results = await safe_gather(*tasks) return {ex: err_msg for ex, err_msg in zip(exchanges, results)} @@ -102,7 +105,7 @@ async def all_balances_all_exchanges(self) -> Dict[str, Dict[str, Decimal]]: await self.update_exchanges() return {k: v.get_all_balances() for k, v in sorted(self._markets.items(), key=lambda x: x[0])} - def all_avai_balances_all_exchanges(self) -> Dict[str, Dict[str, Decimal]]: + def all_available_balances_all_exchanges(self) -> Dict[str, Dict[str, Decimal]]: return {k: v.available_balances for k, v in sorted(self._markets.items(), key=lambda x: x[0])} async def balances(self, exchange, *symbols) -> Dict[str, Decimal]: