Skip to content

Commit

Permalink
refactor: use asyncio instead of threads (#25)
Browse files Browse the repository at this point in the history
* chore: print task id from asyncio task in logs

* refactor: use asyncio coroutines instead of threads

* refactor: use asyncio coroutines instead of threads
  • Loading branch information
JuArce authored Dec 18, 2023
1 parent fc691a5 commit fdd9c01
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 27 deletions.
11 changes: 5 additions & 6 deletions mm-bot/herodotus.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import constants
import requests
Expand All @@ -8,7 +9,7 @@
logger = logging.getLogger(__name__)


def herodotus_prove(block, order_id, slot) -> str:
async def herodotus_prove(block, order_id, slot) -> str:
headers = {
"Content-Type": "application/json",
}
Expand Down Expand Up @@ -48,7 +49,7 @@ def herodotus_prove(block, order_id, slot) -> str:
retries += 1
if retries == constants.MAX_RETRIES:
raise err
time.sleep(constants.RETRIES_DELAY)
await asyncio.sleep(constants.RETRIES_DELAY)


def herodotus_status(task_id) -> str:
Expand All @@ -65,7 +66,7 @@ def herodotus_status(task_id) -> str:
raise err


def herodotus_poll_status(task_id) -> bool:
async def herodotus_poll_status(task_id) -> bool:
# instead of returning a bool we can store it in a mapping
retries = 0
start_time = time.time()
Expand All @@ -81,12 +82,10 @@ def herodotus_poll_status(task_id) -> bool:
logger.info(f"[+] Herodotus average request time (total): {sum(reqs) / len(reqs)}")
return True
retries += 1
time.sleep(constants.RETRIES_DELAY)
except requests.exceptions.RequestException as err:
logger.error(err)
retries += 1
if retries == constants.MAX_RETRIES:
raise err
time.sleep(constants.RETRIES_DELAY)

await asyncio.sleep(constants.RETRIES_DELAY)
return False
16 changes: 15 additions & 1 deletion mm-bot/logging_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from logging.handlers import TimedRotatingFileHandler
import sys
Expand All @@ -9,14 +10,15 @@ def setup_logger():
logger.setLevel(logging.DEBUG)

# Formatter for log messages
log_format = "%(asctime)s %(levelname)6s - [%(threadName)15s] : %(message)s"
log_format = "%(asctime)s %(levelname)6s - [%(threadName)15s] [%(taskName)s] : %(message)s"
formatter = logging.Formatter(log_format, datefmt="%Y-%m-%dT%H:%M:%S")

# Add handlers based on the environment
if in_production():
handler = get_production_handler(formatter)
else:
handler = get_development_handler(formatter)
handler.addFilter(AsyncioFilter())
logger.addHandler(handler)


Expand Down Expand Up @@ -44,3 +46,15 @@ def get_production_handler(formatter):
file_handler.setLevel(constants.LOGGING_LEVEL)
file_handler.setFormatter(formatter)
return file_handler


class AsyncioFilter(logging.Filter):
    """
    Logging filter that injects the current asyncio task name into the record.

    Sets ``record.taskName`` to the running task's name, or "Main" when the
    log call happens outside a running event loop or outside any task, so the
    ``%(taskName)s`` placeholder in the log format always resolves.
    """
    def filter(self, record):
        # asyncio.current_task() raises RuntimeError when there is no running
        # event loop, and returns None when called inside a running loop but
        # outside any Task (e.g. from a call_soon callback) — handle both so
        # a log call can never blow up inside the logging machinery.
        try:
            task = asyncio.current_task()
        except RuntimeError:
            task = None
        record.taskName = task.get_name() if task is not None else "Main"
        return True
32 changes: 12 additions & 20 deletions mm-bot/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import asyncio
import logging
import threading

import ethereum
import herodotus
import json
import starknet
import time
from web3 import Web3
from logging_config import setup_logger

Expand All @@ -17,8 +15,7 @@

async def run():
logger.info(f"[+] Listening events on starknet")
eth_lock = threading.Lock()
threads = []
eth_lock = asyncio.Lock()
orders = set()

try:
Expand All @@ -28,42 +25,37 @@ async def run():
latest_orders: set = await starknet.get_latest_unfulfilled_orders()
if len(latest_orders) == 0:
logger.debug(f"[+] No new events")
time.sleep(SLEEP_TIME)
await asyncio.sleep(SLEEP_TIME)
continue

for order in latest_orders:
logger.info(f"[+] New event: {order}")

order_id = order.order_id
dst_addr = order.recipient_address
amount = order.amount

if order_id in orders:
logger.info(f"[+] Order already processed: {order_id}")
logger.debug(f"[+] Order already processed: {order_id}")
continue

logger.info(f"[+] New order: {order}")

orders.add(order_id)
t = threading.Thread(target=asyncio.run, args=(process_order(order_id, dst_addr, amount, eth_lock),))
threads.append(t)
t.start()
asyncio.create_task(process_order(order_id, dst_addr, amount, eth_lock), name=f"Order-{order_id}")

time.sleep(SLEEP_TIME)
await asyncio.sleep(SLEEP_TIME)
except Exception as e:
logger.error(f"[-] Error: {e}")
logger.info(f"[+] Waiting for threads to finish")
for t in threads:
t.join()
logger.info(f"[+] All threads finished")


async def process_order(order_id, dst_addr, amount, eth_lock):
# 2. Transfer eth on ethereum
# (bridging is complete for the user)
logger.info(f"[+] Transferring eth on ethereum")
with eth_lock:
async with eth_lock:
try:
# in case it's processed on ethereum, but not processed on starknet
ethereum.transfer(order_id, dst_addr, amount)
await asyncio.to_thread(ethereum.transfer, order_id, dst_addr, amount)
logger.info(f"[+] Transfer complete")
except Exception as e:
logger.error(f"[-] Transfer failed: {e}")
Expand All @@ -77,14 +69,14 @@ async def process_order(order_id, dst_addr, amount, eth_lock):
logger.info(f"[+] Index: {index.hex()}")
logger.info(f"[+] Slot: {slot.hex()}")
logger.info(f"[+] Proving block {block}")
task_id = herodotus.herodotus_prove(block, order_id, slot)
task_id = await herodotus.herodotus_prove(block, order_id, slot)
logger.info(f"[+] Block being proved with task id: {task_id}")

# 4. Poll herodotus to check task status
logger.info(f"[+] Polling herodotus for task status")
# avoid weird case where herodotus insta says done
time.sleep(10)
completed = herodotus.herodotus_poll_status(task_id)
await asyncio.sleep(10)
completed = await herodotus.herodotus_poll_status(task_id)
logger.info(f"[+] Task completed")

# 5. Withdraw eth from starknet
Expand Down

0 comments on commit fdd9c01

Please sign in to comment.