diff --git a/steam/_gc/state.py b/steam/_gc/state.py index 44726f90..efca533d 100644 --- a/steam/_gc/state.py +++ b/steam/_gc/state.py @@ -137,9 +137,7 @@ def parse_gc_message(self, msg: CMsgGcClientFromGC) -> None: return log.exception("Failed to execute %r", event_parser.__name__) if isinstance(result, CoroutineType): - task = asyncio.create_task(result, name=f"steam.py GC {app_id}: {event_parser.__name__}") - self.ws._pending_parsers.add(task) - task.add_done_callback(self.ws._pending_parsers.remove) + self.ws.tg.create_task(result, name=f"steam.py GC {app_id}: {event_parser.__name__}") # remove the dispatched listener removed: list[int] = [] diff --git a/steam/client.py b/steam/client.py index 3bb50502..083242b1 100644 --- a/steam/client.py +++ b/steam/client.py @@ -469,88 +469,43 @@ async def throttle() -> None: log.info("Attempting to connect to another CM in %ds", sleep) await asyncio.sleep(sleep) - async def poll() -> None: - while True: - await state.ws.poll_event() - - async def dispatch_ready() -> None: - if state.intents & Intents.ChatGroups > 0: - await state.handled_chat_groups.wait() # ensure group cache is ready - if state.intents & Intents.ChatGroups > 0: - # due to a steam limitation we can't get these reliably on reconnect? TODO check? - await state.handled_friends.wait() # ensure friend cache is ready - await state.handled_emoticons.wait() # ensure emoticon cache is ready - await state.handled_licenses.wait() # ensure licenses are ready - await state.handled_wallet.wait() # ensure wallet is ready - - await self._handle_ready() - while not self.is_closed(): last_connect = time.monotonic() try: async with timeout(60): - self.ws = cast(SteamWebSocket, await login_func(self, *args, **kwargs, cm_list=cm_list)) # type: ignore + self.ws = cast( + SteamWebSocket, + await login_func( + self, + *args, + **kwargs, + cm_list=cm_list, # type: ignore + ), + ) except RAISED_EXCEPTIONS: if self.ws: cm_list = self.ws.cm_list await throttle() continue - if login_func != SteamWebSocket.anonymous_login_from_client: - self._tg.create_task(dispatch_ready()) - - # this entire thing is a bit of a cluster fuck - # but that's what you deserve for having async parsers - - # this future holds the future that finished first. either poll_task for a WS exception or callback_error for errors that occur in state.parsers - done: asyncio.Future[asyncio.Future[None]] = asyncio.get_running_loop().create_future() - - poll_task = asyncio.create_task(poll()) - callback_error = state._task_error - - def maybe_set_result(future: asyncio.Future[None]) -> None: - if not done.done(): - done.set_result(future) - else: - try: - future.exception() # mark the exception as retrieved (the other set task should raise the error) - except asyncio.CancelledError: - pass - - poll_task.add_done_callback(maybe_set_result) - callback_error.add_done_callback(maybe_set_result) - try: - task = await done # get which task is done + async with self.ws: + while True: + await state.ws.poll_event() + except asyncio.CancelledError: # KeyboardInterrupt if not self.is_closed(): try: await self.close() except asyncio.CancelledError: pass - for task in (poll_task, callback_error): # cancel them - task.cancel() - await asyncio.gather( - poll_task, callback_error, return_exceptions=True - ) # and collect the results so that the event loop won't raise - return - to_cancel = poll_task if task is callback_error else callback_error # cancel the other task - to_cancel.cancel() - for task_ in self.ws._pending_parsers: - task_.cancel() - await asyncio.gather( - *self.ws._pending_parsers, to_cancel, return_exceptions=True - ) # same sort of thing as above gather - self.ws._pending_parsers.clear() - try: - await task # handle the exception raised - except (*RAISED_EXCEPTIONS, asyncio.CancelledError): + return + except RAISED_EXCEPTIONS: self.dispatch("disconnect") if not self.is_closed(): await throttle() - state._task_error = asyncio.get_running_loop().create_future() @overload async def login( diff --git a/steam/gateway.py b/steam/gateway.py index 648233d6..28b51c26 100644 --- a/steam/gateway.py +++ b/steam/gateway.py @@ -25,7 +25,7 @@ from ipaddress import IPv4Address from operator import attrgetter from types import CoroutineType -from typing import TYPE_CHECKING, Any, Final, Generic, TypeAlias, overload +from typing import TYPE_CHECKING, Any, Final, Generic, Self, TypeAlias, overload from zlib import MAX_WBITS, decompress import aiohttp @@ -293,6 +293,7 @@ def __init__( self._state = state self.cm_list = cm_list self.cm = cm + self.tg = asyncio.TaskGroup() # the keep alive self._keep_alive: KeepAliveHandler self._dispatch = state.dispatch @@ -302,7 +303,6 @@ def __init__( self.listeners: list[EventListener[Any]] = [] self.gc_listeners: list[GCEventListener[Any]] = [] self.closed = False - self._pending_parsers = set[asyncio.Task[Any]]() self.session_id = 0 self.id64 = ID64(0) @@ -316,6 +316,13 @@ def __init__( self.public_ip: IPAdress self.connect_time: datetime + async def __aenter__(self) -> Self: + await self.tg.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> None: + await self.tg.__aexit__(*args) + @property def latency(self) -> float: """Measures latency between a heartbeat send and the heartbeat interval in seconds.""" @@ -385,20 +392,37 @@ def gc_wait_for( @asynccontextmanager async def poll(self) -> AsyncGenerator[None, None]: - async def inner_poll(): - while True: - await self.poll_event() + timeout = asyncio.Timeout(None) + async with asyncio.TaskGroup() as tg: - poll_task = asyncio.create_task(inner_poll()) - poll_task.add_done_callback(self.parser_callback) + async def inner_poll(): + try: + async with timeout: + while True: + await self.poll_event() + except TimeoutError: + pass - yield + tg.create_task(inner_poll()) + yield + timeout.reschedule(-1) - poll_task.cancel() # we let Client.connect handle poll_event from here on out - try: - await poll_task # needed to ensure the task is cancelled and socket._waiting is removed - except asyncio.CancelledError: - pass + def dispatch_ready(self): + state = self._state + + async def inner(): + if state.intents & Intents.ChatGroups > 0: + await state.handled_chat_groups.wait() # ensure group cache is ready + if state.intents & Intents.ChatGroups > 0: + # due to a steam limitation we can't get these reliably on reconnect? TODO check? + await state.handled_friends.wait() # ensure friend cache is ready + await state.handled_emoticons.wait() # ensure emoticon cache is ready + await state.handled_licenses.wait() # ensure licenses are ready + await state.handled_wallet.wait() # ensure wallet is ready + + await state.client._handle_ready() + + self.tg.create_task(inner()) @classmethod async def from_client( @@ -406,6 +430,7 @@ async def from_client( ) -> SteamWebSocket: state = client._state cm_list = cm_list or fetch_cm_list(state) + async for cm in cm_list: log.info("Attempting to create a websocket connection to %s (load: %f)", cm.url, cm.weighted_load) try: @@ -416,6 +441,8 @@ async def from_client( log.debug("Connected to %s", cm.url) self = cls(state, socket, cm_list, cm) + old_tg = self.tg + self.tg = client._tg client.ws = self self._dispatch("connect") @@ -425,6 +452,7 @@ async def from_client( self.refresh_token = refresh_token or await self.fetch_refresh_token() self.id64 = parse_id64(utils.decode_jwt(self.refresh_token)["sub"]) + self.dispatch_ready() msg: login.CMsgClientLogonResponse = await self.send_proto_and_wait( login.CMsgClientLogon( protocol_version=PROTOCOL_VERSION, @@ -475,6 +503,7 @@ async def from_client( self._dispatch("login") log.debug("Logon completed") + self.tg = old_tg return self raise NoCMsFound("No CMs found could be connected to. Steam is likely down") @@ -507,7 +536,7 @@ async def fetch_refresh_token(self) -> str: if not begin_resp.allowed_confirmations: raise AuthenticatorError("No valid auth session guard type was found") - code_task = email_code_task = asyncio.create_task(asyncio.sleep(float("inf"))) + code_task = email_code_task = asyncio.get_running_loop().create_future() schedule_poll = False for allowed_confirmation in begin_resp.allowed_confirmations: @@ -518,10 +547,10 @@ async def fetch_refresh_token(self) -> str: case auth.EAuthSessionGuardType.DeviceCode: if not client.shared_secret: print("Please enter a Steam guard code") - code_task = asyncio.create_task(self.update_auth_with_code(begin_resp, allowed_confirmation)) + code_task = self.tg.create_task(self.update_auth_with_code(begin_resp, allowed_confirmation)) case auth.EAuthSessionGuardType.EmailCode: print("Please enter a confirmation code from your email") - email_code_task = asyncio.create_task(self.update_auth_with_code(begin_resp, allowed_confirmation)) + email_code_task = self.tg.create_task(self.update_auth_with_code(begin_resp, allowed_confirmation)) case auth.EAuthSessionGuardType.DeviceConfirmation: schedule_poll = True print("Confirm login this on your device") @@ -535,23 +564,11 @@ async def fetch_refresh_token(self) -> str: f"Unknown auth session guard type: {allowed_confirmation.confirmation_type}" ) else: - (done,), pending = await asyncio.wait( - ( - code_task, - email_code_task, - ( - asyncio.create_task(self.poll_auth_status(begin_resp)) - if schedule_poll - else asyncio.Future[None]() - ), - ), - return_when=asyncio.FIRST_COMPLETED, + poll_resp: auth.PollAuthSessionStatusResponse = await utils.race( + code_task, + email_code_task, + self.poll_auth_status(begin_resp) if schedule_poll else asyncio.get_running_loop().create_future(), ) - for task in pending: - task.cancel() - await asyncio.gather(*pending, return_exceptions=True) - - poll_resp = await done assert poll_resp is not None self.client_id = poll_resp.new_client_id or begin_resp.client_id @@ -689,28 +706,18 @@ async def anonymous_login_from_client( async def poll_event(self) -> None: try: message = await self.socket.receive() - if message.type is aiohttp.WSMsgType.BINARY: # type: ignore - return self.receive(message.data) # type: ignore - if message.type is aiohttp.WSMsgType.ERROR: # type: ignore + if message.type is aiohttp.WSMsgType.BINARY: + return self.receive(message.data) + if message.type is aiohttp.WSMsgType.ERROR: log.debug("Received %r", message) - raise message.data # type: ignore - if message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): # type: ignore + raise message.data + if message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): log.debug("Received %r", message) raise WebSocketClosure log.debug("Dropped unexpected message type: %r", message) except WebSocketClosure: await self._state.handle_close() - def parser_callback(self, task: asyncio.Task[Any], /) -> None: - try: - exc = task.exception() - except asyncio.CancelledError: - pass - else: - if isinstance(exc, RAISED_EXCEPTIONS) and not self._state._task_error.done(): - self._state._task_error.set_exception(exc) - self._pending_parsers.discard(task) - def receive(self, message: bytes, /) -> None: emsg_value = READ_U32(message) try: @@ -740,9 +747,7 @@ def receive(self, message: bytes, /) -> None: return traceback.print_exc() if isinstance(result, CoroutineType): - task = asyncio.create_task(result, name=f"steam.py: {event_parser.__name__}") - self._pending_parsers.add(task) - task.add_done_callback(self.parser_callback) + self.tg.create_task(result, name=f"steam.py: {event_parser.__name__}") # remove the dispatched listener removed: list[int] = [] for idx, entry in enumerate(self.listeners): diff --git a/steam/state.py b/steam/state.py index 423ee4ac..88750c37 100644 --- a/steam/state.py +++ b/steam/state.py @@ -279,11 +279,6 @@ def _device_id(self) -> str: def _tg(self) -> TaskGroup: return self.client._tg - @utils.cached_property - def _task_error(self) -> asyncio.Future[None]: - """Holds the exceptions so that the gateway can propagate exceptions to the client.login call""" - return asyncio.get_running_loop().create_future() - @property def language(self) -> Language: return self.http.language @@ -1787,7 +1782,7 @@ async def fill_trades(self) -> None: async def wait_for_trade(self, id: TradeOfferID) -> TradeOffer[Item[User], Item[ClientUser], User]: self._trades_to_watch.add(id) - self._tg.create_task(self.poll_trades()) # start re-polling trades + self.ws.tg.create_task(self.poll_trades()) # start re-polling trades return await self.trade_queue.wait_for(id=id) @parser @@ -1866,7 +1861,7 @@ async def poll_confirmations(self) -> None: self.polling_confirmations = False async def wait_for_confirmation(self, id: TradeOfferID) -> Confirmation: - self._tg.create_task(self.poll_confirmations()) + self.ws.tg.create_task(self.poll_confirmations()) return await self.confirmation_queue.wait_for(id=id) async def _fetch_store_info( diff --git a/steam/utils.py b/steam/utils.py index 56a1f43e..905fb1d2 100644 --- a/steam/utils.py +++ b/steam/utils.py @@ -46,6 +46,7 @@ Final, Generic, Literal, + Never, ParamSpec, TypeAlias, TypedDict, @@ -394,6 +395,31 @@ def __await__(self) -> Generator[Any, None, Self]: return self.__await_inner__().__await__() +@overload +async def race() -> None: ... +@overload +async def race(*coros: Awaitable[_T]) -> _T: ... +async def race(*coros: Awaitable[_T]) -> _T | None: + class _Done(Exception): + pass + + result = None + + async def run(coro: Awaitable[_T]) -> Never: + nonlocal result + result = await coro + raise _Done + + try: + async with asyncio.TaskGroup() as tg: + for coro in coros: + tg.create_task(run(coro)) + except* _Done: + pass + + return result + + PACK_FORMATS: Final = cast(Mapping[str, str], { "i8": "b", "u8": "B",