diff --git a/mm-bot/src/services/decorators/use_fallback.py b/mm-bot/src/services/decorators/use_fallback.py new file mode 100644 index 00000000..13bd51b2 --- /dev/null +++ b/mm-bot/src/services/decorators/use_fallback.py @@ -0,0 +1,40 @@ +from functools import wraps + + +def use_fallback(rpc_nodes, logger, error_message="Failed"): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + exceptions = [] + for rpc_node in rpc_nodes: + try: + return func(*args, **kwargs, rpc_node=rpc_node) + except Exception as exception: + logger.warning(f"[-] {error_message}: {exception}") + exceptions.append(exception) + logger.error(f"[-] {error_message} from all nodes") + raise Exception(f"{error_message} from all nodes: [{', '.join(str(e) for e in exceptions)}]") + + return wrapper + + return decorator + + +def use_async_fallback(rpc_nodes, logger, error_message="Failed"): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + exceptions = [] + for rpc_node in rpc_nodes: + try: + return await func(*args, **kwargs, rpc_node=rpc_node) + except Exception as exception: + logger.warning(f"[-] {error_message}: {exception}") + exceptions.append(exception) + logger.error(f"[-] {error_message} from all nodes") + raise Exception(f"{error_message} from all nodes: [{', '.join(str(e) for e in exceptions)}]") + + return wrapper + + return decorator + diff --git a/mm-bot/src/services/ethereum.py b/mm-bot/src/services/ethereum.py index 3e5864d0..b5096a59 100644 --- a/mm-bot/src/services/ethereum.py +++ b/mm-bot/src/services/ethereum.py @@ -4,65 +4,50 @@ from web3 import Web3 from config import constants +from services.decorators.use_fallback import use_fallback ETH_CHAIN_ID = int(constants.ETH_CHAIN_ID) +# get only the abi not the entire file +abi_file = json.load(open(os.getcwd() + '/abi/YABTransfer.json'))['abi'] -main_w3 = Web3(Web3.HTTPProvider(constants.ETH_RPC_URL)) -fallback_w3 = Web3(Web3.HTTPProvider(constants.ETH_FALLBACK_RPC_URL)) -w3_clients = [main_w3, fallback_w3] -""" -accounts are instances from the same account but from different nodes -So if a node is down we can use the other one -""" -main_account_rpc = main_w3.eth.account.from_key(constants.ETH_PRIVATE_KEY) -fallback_account_rpc = fallback_w3.eth.account.from_key(constants.ETH_PRIVATE_KEY) -accounts_rpc = [main_account_rpc, fallback_account_rpc] +class EthereumRpcNode: + def __init__(self, rpc_url, private_key, contract_address, abi): + self.w3 = Web3(Web3.HTTPProvider(rpc_url)) + self.account = self.w3.eth.account.from_key(private_key) + self.contract = self.w3.eth.contract(address=contract_address, abi=abi) -# get only the abi not the entire file -abi = json.load(open(os.getcwd() + '/abi/YABTransfer.json'))['abi'] -""" -contracts_rpc are instances from the same contract but from different nodes -So if a node is down we can use the other one -""" -main_contract_rpc = main_w3.eth.contract(address=constants.ETH_CONTRACT_ADDR, abi=abi) -fallback_contract_rpc = fallback_w3.eth.contract(address=constants.ETH_CONTRACT_ADDR, abi=abi) -contracts_rpc = [main_contract_rpc, fallback_contract_rpc] + +main_rpc_node = EthereumRpcNode(constants.ETH_RPC_URL, + constants.ETH_PRIVATE_KEY, + constants.ETH_CONTRACT_ADDR, + abi_file) +fallback_rpc_node = EthereumRpcNode(constants.ETH_FALLBACK_RPC_URL, + constants.ETH_PRIVATE_KEY, + constants.ETH_CONTRACT_ADDR, + abi_file) +rpc_nodes = [main_rpc_node, fallback_rpc_node] logger = logging.getLogger(__name__) -def get_latest_block() -> int: - for w3 in w3_clients: - try: - return w3.eth.block_number - except Exception as exception: - logger.warning(f"[-] Failed to get block number from node: {exception}") - logger.error(f"[-] Failed to get block number from all nodes") +@use_fallback(rpc_nodes, logger, "Failed to get latest block") +def get_latest_block(rpc_node=main_rpc_node) -> int: + return rpc_node.w3.eth.block_number -def get_is_used_order(order_id, recipient_address, amount) -> bool: +@use_fallback(rpc_nodes, logger, "Failed to get order status") +def get_is_used_order(order_id, recipient_address, amount, rpc_node=main_rpc_node) -> bool: is_used_index = 2 order_data = Web3.solidity_keccak(['uint256', 'uint256', 'uint256'], [order_id, int(recipient_address, 0), amount]) - for contract_rpc in contracts_rpc: - try: - res = contract_rpc.functions.transfers(order_data).call() - return res[is_used_index] - except Exception as exception: - logger.warning(f"[-] Failed to get is used order from node: {exception}") - logger.error(f"[-] Failed to get is used order from all nodes") - raise Exception("Failed to get is used order from all nodes") - - -def get_balance() -> int: - for index, w3 in enumerate(w3_clients): - try: - return w3.eth.get_balance(accounts_rpc[index].address) - except Exception as exception: - logger.warning(f"[-] Failed to get balance from node: {exception}") - logger.error(f"[-] Failed to get balance from all nodes") - raise Exception("Failed to get balance from all nodes") + res = rpc_node.contract.functions.transfers(order_data).call() + return res[is_used_index] + + +@use_fallback(rpc_nodes, logger, "Failed to get balance") +def get_balance(rpc_node=main_rpc_node) -> int: + return rpc_node.w3.eth.get_balance(rpc_node.account.address) def has_funds(amount: int) -> bool: @@ -76,7 +61,7 @@ def transfer(deposit_id, dst_addr, amount): unsent_tx, signed_tx = create_transfer(deposit_id, dst_addr_bytes, amount) - gas_fee = estimate_gas_fee(unsent_tx) + gas_fee = estimate_transaction_fee(unsent_tx) if not has_enough_funds(amount, gas_fee): raise Exception("Not enough funds for transfer") @@ -85,21 +70,16 @@ def transfer(deposit_id, dst_addr, amount): # we need amount so the transaction is valid with the transfer that will be transferred -def create_transfer(deposit_id, dst_addr_bytes, amount): - for index, w3 in enumerate(w3_clients): - try: - unsent_tx = contracts_rpc[index].functions.transfer(deposit_id, dst_addr_bytes, amount).build_transaction({ - "chainId": ETH_CHAIN_ID, - "from": accounts_rpc[index].address, - "nonce": get_nonce(w3, accounts_rpc[index].address), - "value": amount, - }) - signed_tx = w3.eth.account.sign_transaction(unsent_tx, private_key=accounts_rpc[index].key) - return unsent_tx, signed_tx - except Exception as exception: - logger.warning(f"[-] Failed to create transfer eth on node: {exception}") - logger.error(f"[-] Failed to create transfer eth on all nodes") - raise Exception("Failed to create transfer eth on all nodes") +@use_fallback(rpc_nodes, logger, "Failed to create ethereum transfer") +def create_transfer(deposit_id, dst_addr_bytes, amount, rpc_node=main_rpc_node): + unsent_tx = rpc_node.contract.functions.transfer(deposit_id, dst_addr_bytes, amount).build_transaction({ + "chainId": ETH_CHAIN_ID, + "from": rpc_node.account.address, + "nonce": get_nonce(rpc_node.w3, rpc_node.account.address), + "value": amount, + }) + signed_tx = rpc_node.w3.eth.account.sign_transaction(unsent_tx, private_key=rpc_node.account.key) + return unsent_tx, signed_tx def withdraw(deposit_id, dst_addr, amount, value): @@ -109,7 +89,7 @@ def withdraw(deposit_id, dst_addr, amount, value): unsent_tx, signed_tx = create_withdraw(deposit_id, dst_addr_bytes, amount, value) - gas_fee = estimate_gas_fee(unsent_tx) + gas_fee = estimate_transaction_fee(unsent_tx) if not has_enough_funds(gas_fee=gas_fee): raise Exception("Not enough funds for withdraw") @@ -117,40 +97,28 @@ def withdraw(deposit_id, dst_addr, amount, value): return tx_hash -def create_withdraw(deposit_id, dst_addr_bytes, amount, value): - exceptions = [] - for index, w3 in enumerate(w3_clients): - try: - unsent_tx = contracts_rpc[index].functions.withdraw(deposit_id, dst_addr_bytes, amount).build_transaction({ - "chainId": ETH_CHAIN_ID, - "from": accounts_rpc[index].address, - "nonce": get_nonce(w3, accounts_rpc[index].address), - "value": value, - }) - signed_tx = w3.eth.account.sign_transaction(unsent_tx, private_key=accounts_rpc[index].key) - return unsent_tx, signed_tx - except Exception as exception: - logger.warning(f"[-] Failed to create withdraw eth on node: {exception}") - exceptions.append(exception) - logger.error(f"[-] Failed to create withdraw eth on all nodes") - raise Exception(f"Failed to create withdraw eth on all nodes: [{', '.join(str(e) for e in exceptions)}]") +@use_fallback(rpc_nodes, logger, "Failed to create withdraw eth") +def create_withdraw(deposit_id, dst_addr_bytes, amount, value, rpc_node=main_rpc_node): + unsent_tx = rpc_node.contract.functions.withdraw(deposit_id, dst_addr_bytes, amount).build_transaction({ + "chainId": ETH_CHAIN_ID, + "from": rpc_node.account.address, + "nonce": get_nonce(rpc_node.w3, rpc_node.account.address), + "value": value, + }) + signed_tx = rpc_node.w3.eth.account.sign_transaction(unsent_tx, private_key=rpc_node.account.key) + return unsent_tx, signed_tx def get_nonce(w3: Web3, address): return w3.eth.get_transaction_count(address) -def estimate_gas_fee(transaction): - for w3 in w3_clients: - try: - gas_limit = w3.eth.estimate_gas(transaction) - fee = w3.eth.gas_price - gas_fee = fee * gas_limit - return gas_fee - except Exception as exception: - logger.warning(f"[-] Failed to estimate fee on node: {exception}") - logger.error(f"[-] Failed to estimate fee on all nodes") - raise Exception("Failed to estimate fee on all nodes") +@use_fallback(rpc_nodes, logger, "Failed to estimate gas fee") +def estimate_transaction_fee(transaction, rpc_node=main_rpc_node): + gas_limit = rpc_node.w3.eth.estimate_gas(transaction) + fee = rpc_node.w3.eth.gas_price + gas_fee = fee * gas_limit + return gas_fee def is_transaction_viable(amount: int, percentage: float, gas_fee: int) -> bool: @@ -161,23 +129,12 @@ def has_enough_funds(amount: int = 0, gas_fee: int = 0) -> bool: return get_balance() >= amount + gas_fee -def send_raw_transaction(signed_tx): - for w3 in w3_clients: - try: - tx_hash = w3.eth.send_raw_transaction(signed_tx.rawTransaction) - return tx_hash - except Exception as exception: - logger.warning(f"[-] Failed to send raw transaction on node: {exception}") - logger.error(f"[-] Failed to send raw transaction on all nodes") - raise Exception("Failed to send raw transaction on all nodes") - - -def wait_for_transaction_receipt(tx_hash): - for w3 in w3_clients: - try: - w3.eth.wait_for_transaction_receipt(tx_hash) - return True - except Exception as exception: - logger.warning(f"[-] Failed to wait for transaction receipt on node: {exception}") - logger.error(f"[-] Failed to wait for transaction receipt on all nodes") - raise Exception("Failed to wait for transaction receipt on all nodes") +@use_fallback(rpc_nodes, logger, "Failed to send raw transaction") +def send_raw_transaction(signed_tx, rpc_node=main_rpc_node): + tx_hash = rpc_node.w3.eth.send_raw_transaction(signed_tx.rawTransaction) + return tx_hash + + +@use_fallback(rpc_nodes, logger, "Failed to wait for transaction receipt") +def wait_for_transaction_receipt(tx_hash, rpc_node=main_rpc_node): + rpc_node.w3.eth.wait_for_transaction_receipt(tx_hash) diff --git a/mm-bot/src/services/starknet.py b/mm-bot/src/services/starknet.py index 6f3007ff..d1aeb242 100644 --- a/mm-bot/src/services/starknet.py +++ b/mm-bot/src/services/starknet.py @@ -11,29 +11,37 @@ from config import constants from services import ethereum +from services.decorators.use_fallback import use_fallback, use_async_fallback from services.mm_full_node_client import MmFullNodeClient SN_CHAIN_ID = int_from_bytes(constants.SN_CHAIN_ID.encode("utf-8")) SET_ORDER_EVENT_KEY = 0x2c75a60b5bdad73ebbf539cc807fccd09875c3cbf3f44041f852cdb96d8acd3 -main_full_node_client = MmFullNodeClient(node_url=constants.SN_RPC_URL) -fallback_full_node_client = MmFullNodeClient(node_url=constants.SN_FALLBACK_RPC_URL) -full_node_clients = [fallback_full_node_client, main_full_node_client] - -key_pair = KeyPair.from_private_key(key=constants.SN_PRIVATE_KEY) -main_account = Account( - client=main_full_node_client, - address=constants.SN_WALLET_ADDR, - key_pair=key_pair, - chain=SN_CHAIN_ID, # ignore this warning TODO change to StarknetChainId when starknet_py adds sepolia -) -fallback_account = Account( - client=fallback_full_node_client, - address=constants.SN_WALLET_ADDR, - key_pair=key_pair, - chain=SN_CHAIN_ID, # ignore this warning TODO change to StarknetChainId when starknet_py adds sepolia -) -accounts = [fallback_account, main_account] + +class StarknetRpcNode: + def __init__(self, rpc_url, private_key, wallet_address, contract_address, chain_id): + self.full_node_client = MmFullNodeClient(node_url=rpc_url) + key_pair = KeyPair.from_private_key(key=private_key) + self.account = Account( + client=self.full_node_client, + address=wallet_address, + key_pair=key_pair, + chain=chain_id, # ignore this warning TODO change to StarknetChainId when starknet_py adds sepolia + ) + self.contract_address = contract_address + + +main_rpc_node = StarknetRpcNode(constants.SN_RPC_URL, + constants.SN_PRIVATE_KEY, + constants.SN_WALLET_ADDR, + constants.SN_CONTRACT_ADDR, + SN_CHAIN_ID) +fallback_rpc_node = StarknetRpcNode(constants.SN_FALLBACK_RPC_URL, + constants.SN_PRIVATE_KEY, + constants.SN_WALLET_ADDR, + constants.SN_CONTRACT_ADDR, + SN_CHAIN_ID) +rpc_nodes = [main_rpc_node, fallback_rpc_node] logger = logging.getLogger(__name__) @@ -52,40 +60,30 @@ def __str__(self): return f"order_id:{self.order_id}, recipient: {self.recipient_address}, amount: {self.amount}, fee: {self.fee}" +@use_async_fallback(rpc_nodes, logger, "Failed to get events") async def get_starknet_events(from_block_number: Literal["pending", "latest"] | int | None = "pending", to_block_number: Literal["pending", "latest"] | int | None = "pending", - continuation_token=None): - for client in full_node_clients: - try: - events_response = await client.get_events( - address=constants.SN_CONTRACT_ADDR, - chunk_size=1000, - keys=[[SET_ORDER_EVENT_KEY]], - from_block_number=from_block_number, - to_block_number=to_block_number, - continuation_token=continuation_token - ) - return events_response - except Exception as exception: - logger.warning(f"[-] Failed to get events from node: {exception}") - logger.error(f"[-] Failed to get events from all nodes") - return None - - -async def get_is_used_order(order_id) -> bool: + continuation_token=None, rpc_node=main_rpc_node): + events_response = await rpc_node.full_node_client.get_events( + address=constants.SN_CONTRACT_ADDR, + chunk_size=1000, + keys=[[SET_ORDER_EVENT_KEY]], + from_block_number=from_block_number, + to_block_number=to_block_number, + continuation_token=continuation_token + ) + return events_response + + +@use_async_fallback(rpc_nodes, logger, "Failed to set order") +async def get_is_used_order(order_id, rpc_node=main_rpc_node) -> bool: call = Call( to_addr=constants.SN_CONTRACT_ADDR, selector=get_selector_from_name("get_order_used"), calldata=[order_id, 0], ) - for account in accounts: - try: - status = await account.client.call_contract(call) - return status[0] - except Exception as exception: - logger.warning(f"[-] Failed to get order status from node: {exception}") - logger.error(f"[-] Failed to get order status from all nodes") - return True + status = await rpc_node.account.client.call_contract(call) + return status[0] async def get_order_events(from_block_number, to_block_number) -> list[SetOrderEvent]: @@ -151,15 +149,10 @@ def get_fee(event) -> int: return 0 -async def get_latest_block() -> int: - for client in full_node_clients: - try: - latest_block = await client.get_block("latest") - return latest_block.block_number - except Exception as exception: - logger.warning(f"[-] Failed to get latest block from node: {exception}") - logger.error(f"[-] Failed to get latest block from all nodes") - return 0 +@use_async_fallback(rpc_nodes, logger, "Failed to get latest block") +async def get_latest_block(rpc_node=main_rpc_node) -> int: + latest_block = await rpc_node.full_node_client.get_block("latest") + return latest_block.block_number async def withdraw(order_id, block, slot) -> bool: @@ -183,45 +176,25 @@ async def withdraw(order_id, block, slot) -> bool: return False -async def sign_invoke_transaction(call: Call, max_fee: int): - for account in accounts: - try: - transaction = await account.sign_invoke_transaction(call, max_fee=max_fee) - return transaction - except Exception as e: - logger.warning(f"[-] Failed to sign invoke transaction: {e}") - logger.error(f"[-] Failed to sign invoke transaction from all nodes") - raise Exception("Failed to sign invoke transaction from all nodes") - - -async def estimate_message_fee(from_address, to_address, entry_point_selector, payload): - for client in full_node_clients: - try: - fee = await client.estimate_message_fee(from_address, to_address, entry_point_selector, payload) - return fee.overall_fee - except Exception as e: - logger.warning(f"[-] Failed to estimate message fee: {e}") - logger.error(f"[-] Failed to estimate message fee from all nodes") - raise Exception("Failed to estimate message fee from all nodes") - - -async def send_transaction(transaction): - for account in accounts: - try: - result = await account.client.send_transaction(transaction) - return result - except Exception as e: - logger.warning(f"[-] Failed to send transaction: {e}") - logger.error(f"[-] Failed to send transaction from all nodes") - raise Exception("Failed to send transaction from all nodes") - - -async def wait_for_tx(transaction_hash): - for account in accounts: - try: - await account.client.wait_for_tx(transaction_hash) - return - except Exception as e: - logger.warning(f"[-] Failed to wait for tx: {e}") - logger.error(f"[-] Failed to wait for tx from all nodes") - raise Exception("Failed to wait for tx from all nodes") +@use_async_fallback(rpc_nodes, logger, "Failed to sign invoke transaction") +async def sign_invoke_transaction(call: Call, max_fee: int, rpc_node=main_rpc_node): + return await rpc_node.account.sign_invoke_transaction(call, max_fee=max_fee) + + +@use_async_fallback(rpc_nodes, logger, "Failed to estimate message fee") +async def estimate_message_fee(from_address, to_address, entry_point_selector, payload, rpc_node=main_rpc_node): + fee = await rpc_node.full_node_client.estimate_message_fee(from_address, to_address, entry_point_selector, payload) + return fee.overall_fee + + + + +@use_async_fallback(rpc_nodes, logger, "Failed to send transaction") +async def send_transaction(transaction, rpc_node=main_rpc_node): + return await rpc_node.account.client.send_transaction(transaction) + + +@use_async_fallback(rpc_nodes, logger, "Failed to wait for tx") +async def wait_for_tx(transaction_hash, rpc_node=main_rpc_node): + await rpc_node.account.client.wait_for_tx(transaction_hash) +