Skip to content

Commit

Permalink
Merge pull request hummingbot#179 from CoinAlpha/fix/gateway-v2-start…
Browse files Browse the repository at this point in the history
…-balance-bug

fix / gateway v2 start balance bug
  • Loading branch information
Martin Kou authored Mar 31, 2022
2 parents e3a76b1 + e198f90 commit 862f7c8
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 24 deletions.
2 changes: 1 addition & 1 deletion hummingbot/client/command/balance_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def show_balances(self):
except asyncio.TimeoutError:
self.notify("\nA network error prevented the balances to update. See logs for more details.")
raise
all_ex_avai_bals = UserBalances.instance().all_avai_balances_all_exchanges()
all_ex_avai_bals = UserBalances.instance().all_available_balances_all_exchanges()
all_ex_limits: Optional[Dict[str, Dict[str, str]]] = global_config_map["balance_asset_limit"].value

if all_ex_limits is None:
Expand Down
5 changes: 3 additions & 2 deletions hummingbot/strategy/amm_arb/amm_arb.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import asyncio
import pandas as pd

from decimal import Decimal
from functools import lru_cache
import pandas as pd
from typing import List, Dict, Tuple, Optional, Callable, cast

from hummingbot.client.performance import PerformanceMetrics
Expand Down Expand Up @@ -144,6 +144,7 @@ def market_info_to_active_orders(self) -> Dict[MarketTradingPairTuple, List[Limi
return self._sb_order_tracker.market_pair_to_active_orders

@staticmethod
@lru_cache(maxsize=10)
def is_gateway_market(market_info: MarketTradingPairTuple) -> bool:
return market_info.market.name in AllConnectorSettings.get_gateway_evm_amm_connector_names()

Expand Down
45 changes: 24 additions & 21 deletions hummingbot/user/user_balances.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from decimal import Decimal
from functools import lru_cache
from typing import Optional, Dict, List

from hummingbot.core.utils.market_price import get_last_price
from hummingbot.client.settings import AllConnectorSettings, gateway_connector_trading_pairs
from hummingbot.client.config.security import Security
from hummingbot.client.config.config_helpers import get_connector_class
from hummingbot.core.utils.async_utils import safe_gather

from typing import Optional, Dict, List
from decimal import Decimal


class UserBalances:
__instance = None
Expand Down Expand Up @@ -37,6 +38,11 @@ def instance():
UserBalances()
return UserBalances.__instance

@staticmethod
@lru_cache(maxsize=10)
def is_gateway_market(exchange_name: str) -> bool:
return exchange_name in AllConnectorSettings.get_gateway_evm_amm_connector_names()

def __init__(self):
if UserBalances.__instance is not None:
raise Exception("This class is a singleton!")
Expand All @@ -59,16 +65,23 @@ def all_balances(self, exchange) -> Dict[str, Decimal]:
return {}
return self._markets[exchange].get_all_balances()

async def update_exchange_balance(self, exchange) -> Optional[str]:
if exchange in self._markets:
return await self._update_balances(self._markets[exchange])
async def update_exchange_balance(self, exchange_name: str) -> Optional[str]:
if self.is_gateway_market(exchange_name) and exchange_name in self._markets:
# we want to refresh gateway connectors always, since the applicable tokens change over time.
# doing this will reinitialize and fetch balances for active trading pair
del self._markets[exchange_name]
if exchange_name in self._markets:
return await self._update_balances(self._markets[exchange_name])
else:
api_keys = await Security.api_keys(exchange)
return await self.add_exchange(exchange, **api_keys)
api_keys = await Security.api_keys(exchange_name)
return await self.add_exchange(exchange_name, **api_keys)

# returns error message for each exchange
async def update_exchanges(self, reconnect: bool = False,
exchanges: List[str] = []) -> Dict[str, Optional[str]]:
async def update_exchanges(
self,
reconnect: bool = False,
exchanges: List[str] = []
) -> Dict[str, Optional[str]]:
tasks = []
# Update user balances
if len(exchanges) == 0:
Expand All @@ -81,19 +94,9 @@ async def update_exchanges(self, reconnect: bool = False,
and not cs.name.endswith("paper_trade")
]

gateway_connectors: List[str] = [
cs.name
for cs in AllConnectorSettings.get_connector_settings().values()
if cs.uses_gateway_generic_connector()
]

if reconnect:
self._markets.clear()
for exchange in exchanges:
if exchange in gateway_connectors and exchange in self._markets:
# we want to refresh gateway connectors always
# doing this will reinitialize and fetch balances for active trading pair
del self._markets[exchange]
tasks.append(self.update_exchange_balance(exchange))
results = await safe_gather(*tasks)
return {ex: err_msg for ex, err_msg in zip(exchanges, results)}
Expand All @@ -102,7 +105,7 @@ async def all_balances_all_exchanges(self) -> Dict[str, Dict[str, Decimal]]:
await self.update_exchanges()
return {k: v.get_all_balances() for k, v in sorted(self._markets.items(), key=lambda x: x[0])}

def all_avai_balances_all_exchanges(self) -> Dict[str, Dict[str, Decimal]]:
def all_available_balances_all_exchanges(self) -> Dict[str, Dict[str, Decimal]]:
return {k: v.available_balances for k, v in sorted(self._markets.items(), key=lambda x: x[0])}

async def balances(self, exchange, *symbols) -> Dict[str, Decimal]:
Expand Down

0 comments on commit 862f7c8

Please sign in to comment.