diff --git a/mm-bot/herodotus.py b/mm-bot/herodotus.py
index d7abcc35..3c318562 100644
--- a/mm-bot/herodotus.py
+++ b/mm-bot/herodotus.py
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import constants
 import requests
@@ -8,7 +9,7 @@
 logger = logging.getLogger(__name__)
 
 
-def herodotus_prove(block, order_id, slot) -> str:
+async def herodotus_prove(block, order_id, slot) -> str:
     headers = {
         "Content-Type": "application/json",
     }
@@ -48,7 +49,7 @@ def herodotus_prove(block, order_id, slot) -> str:
             retries += 1
             if retries == constants.MAX_RETRIES:
                 raise err
-            time.sleep(constants.RETRIES_DELAY)
+            await asyncio.sleep(constants.RETRIES_DELAY)
 
 
 def herodotus_status(task_id) -> str:
@@ -65,7 +66,7 @@
         raise err
 
 
-def herodotus_poll_status(task_id) -> bool:
+async def herodotus_poll_status(task_id) -> bool:
     # instead of returning a bool we can store it in a mapping
     retries = 0
     start_time = time.time()
@@ -81,12 +82,10 @@
                 logger.info(f"[+] Herodotus average request time (total): {sum(reqs) / len(reqs)}")
                 return True
             retries += 1
-            time.sleep(constants.RETRIES_DELAY)
         except requests.exceptions.RequestException as err:
             logger.error(err)
             retries += 1
             if retries == constants.MAX_RETRIES:
                 raise err
-        time.sleep(constants.RETRIES_DELAY)
-
+        await asyncio.sleep(constants.RETRIES_DELAY)
     return False
diff --git a/mm-bot/logging_config.py b/mm-bot/logging_config.py
index f4e472ee..c1db4225 100644
--- a/mm-bot/logging_config.py
+++ b/mm-bot/logging_config.py
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 from logging.handlers import TimedRotatingFileHandler
 import sys
@@ -9,7 +10,7 @@ def setup_logger():
     logger.setLevel(logging.DEBUG)
 
     # Formatter for log messages
-    log_format = "%(asctime)s %(levelname)6s - [%(threadName)15s] : %(message)s"
+    log_format = "%(asctime)s %(levelname)6s - [%(threadName)15s] [%(taskName)s] : %(message)s"
     formatter = logging.Formatter(log_format, datefmt="%Y-%m-%dT%H:%M:%S")
 
     # Add handlers based on the environment
@@ -17,6 +18,7 @@ def setup_logger():
         handler = get_production_handler(formatter)
     else:
         handler = get_development_handler(formatter)
+    handler.addFilter(AsyncioFilter())
 
     logger.addHandler(handler)
 
@@ -44,3 +46,15 @@ def get_production_handler(formatter):
     file_handler.setLevel(constants.LOGGING_LEVEL)
     file_handler.setFormatter(formatter)
     return file_handler
+
+
+class AsyncioFilter(logging.Filter):
+    """
+    This is a filter which injects contextual information into the log.
+    """
+    def filter(self, record):
+        try:
+            record.taskName = asyncio.current_task().get_name()
+        except (RuntimeError, AttributeError):
+            record.taskName = "Main"
+        return True
diff --git a/mm-bot/main.py b/mm-bot/main.py
index a2a7c908..0d299bda 100644
--- a/mm-bot/main.py
+++ b/mm-bot/main.py
@@ -1,12 +1,10 @@
 import asyncio
 import logging
-import threading
 
 import ethereum
 import herodotus
 import json
 import starknet
-import time
 
 from web3 import Web3
 from logging_config import setup_logger
@@ -17,8 +15,7 @@ async def run():
     logger.info(f"[+] Listening events on starknet")
 
-    eth_lock = threading.Lock()
-    threads = []
+    eth_lock = asyncio.Lock()
     orders = set()
 
     try:
 
@@ -28,31 +25,26 @@ async def run():
             latest_orders: set = await starknet.get_latest_unfulfilled_orders()
             if len(latest_orders) == 0:
                 logger.debug(f"[+] No new events")
-                time.sleep(SLEEP_TIME)
+                await asyncio.sleep(SLEEP_TIME)
                 continue
 
             for order in latest_orders:
-                logger.info(f"[+] New event: {order}")
-
                 order_id = order.order_id
                 dst_addr = order.recipient_address
                 amount = order.amount
 
                 if order_id in orders:
-                    logger.info(f"[+] Order already processed: {order_id}")
+                    logger.debug(f"[+] Order already processed: {order_id}")
                     continue
 
+                logger.info(f"[+] New order: {order}")
+
                 orders.add(order_id)
-                t = threading.Thread(target=asyncio.run, args=(process_order(order_id, dst_addr, amount, eth_lock),))
-                threads.append(t)
-                t.start()
+                asyncio.create_task(process_order(order_id, dst_addr, amount, eth_lock), name=f"Order-{order_id}")
 
-            time.sleep(SLEEP_TIME)
+            await asyncio.sleep(SLEEP_TIME)
     except Exception as e:
         logger.error(f"[-] Error: {e}")
-        logger.info(f"[+] Waiting for threads to finish")
-        for t in threads:
-            t.join()
 
-    logger.info(f"[+] All threads finished")
+    logger.info(f"[+] Exiting run loop")
 
@@ -60,10 +52,10 @@ async def process_order(order_id, dst_addr, amount, eth_lock):
     # 2. Transfer eth on ethereum
     # (bridging is complete for the user)
    logger.info(f"[+] Transferring eth on ethereum")
-    with eth_lock:
+    async with eth_lock:
         try:
             # in case it's processed on ethereum, but not processed on starknet
-            ethereum.transfer(order_id, dst_addr, amount)
+            await asyncio.to_thread(ethereum.transfer, order_id, dst_addr, amount)
             logger.info(f"[+] Transfer complete")
         except Exception as e:
             logger.error(f"[-] Transfer failed: {e}")
@@ -77,14 +69,14 @@ async def process_order(order_id, dst_addr, amount, eth_lock):
     logger.info(f"[+] Index: {index.hex()}")
     logger.info(f"[+] Slot: {slot.hex()}")
     logger.info(f"[+] Proving block {block}")
-    task_id = herodotus.herodotus_prove(block, order_id, slot)
+    task_id = await herodotus.herodotus_prove(block, order_id, slot)
     logger.info(f"[+] Block being proved with task id: {task_id}")
 
     # 4. Poll herodotus to check task status
     logger.info(f"[+] Polling herodotus for task status")
     # avoid weird case where herodotus insta says done
-    time.sleep(10)
-    completed = herodotus.herodotus_poll_status(task_id)
+    await asyncio.sleep(10)
+    completed = await herodotus.herodotus_poll_status(task_id)
     logger.info(f"[+] Task completed")
 
     # 5. Withdraw eth from starknet