Skip to content

Commit

Permalink
Fix connection issues (#595)
Browse files Browse the repository at this point in the history
* Work on connection issues

hi #570 #207 and #446

* Remove debug prints

* Make sure to dispatch ready
  • Loading branch information
Gobot1234 authored Dec 9, 2024
1 parent c4ffde7 commit 0aa503e
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 120 deletions.
4 changes: 1 addition & 3 deletions steam/_gc/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ def parse_gc_message(self, msg: CMsgGcClientFromGC) -> None:
return log.exception("Failed to execute %r", event_parser.__name__)

if isinstance(result, CoroutineType):
task = asyncio.create_task(result, name=f"steam.py GC {app_id}: {event_parser.__name__}")
self.ws._pending_parsers.add(task)
task.add_done_callback(self.ws._pending_parsers.remove)
self.ws.tg.create_task(result, name=f"steam.py GC {app_id}: {event_parser.__name__}")

# remove the dispatched listener
removed: list[int] = []
Expand Down
75 changes: 15 additions & 60 deletions steam/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,88 +469,43 @@ async def throttle() -> None:
log.info("Attempting to connect to another CM in %ds", sleep)
await asyncio.sleep(sleep)

async def poll() -> None:
while True:
await state.ws.poll_event()

async def dispatch_ready() -> None:
if state.intents & Intents.ChatGroups > 0:
await state.handled_chat_groups.wait() # ensure group cache is ready
if state.intents & Intents.ChatGroups > 0:
# due to a steam limitation we can't get these reliably on reconnect? TODO check?
await state.handled_friends.wait() # ensure friend cache is ready
await state.handled_emoticons.wait() # ensure emoticon cache is ready
await state.handled_licenses.wait() # ensure licenses are ready
await state.handled_wallet.wait() # ensure wallet is ready

await self._handle_ready()

while not self.is_closed():
last_connect = time.monotonic()

try:
async with timeout(60):
self.ws = cast(SteamWebSocket, await login_func(self, *args, **kwargs, cm_list=cm_list)) # type: ignore
self.ws = cast(
SteamWebSocket,
await login_func(
self,
*args,
**kwargs,
cm_list=cm_list, # type: ignore
),
)
except RAISED_EXCEPTIONS:
if self.ws:
cm_list = self.ws.cm_list
await throttle()
continue

if login_func != SteamWebSocket.anonymous_login_from_client:
self._tg.create_task(dispatch_ready())

# this entire thing is a bit of a cluster fuck
# but that's what you deserve for having async parsers

# this future holds the future that finished first. either poll_task for a WS exception or callback_error for errors that occur in state.parsers
done: asyncio.Future[asyncio.Future[None]] = asyncio.get_running_loop().create_future()

poll_task = asyncio.create_task(poll())
callback_error = state._task_error

def maybe_set_result(future: asyncio.Future[None]) -> None:
if not done.done():
done.set_result(future)
else:
try:
future.exception() # mark the exception as retrieved (the other set task should raise the error)
except asyncio.CancelledError:
pass

poll_task.add_done_callback(maybe_set_result)
callback_error.add_done_callback(maybe_set_result)

try:
task = await done # get which task is done
async with self.ws:
while True:
await state.ws.poll_event()

except asyncio.CancelledError: # KeyboardInterrupt
if not self.is_closed():
try:
await self.close()
except asyncio.CancelledError:
pass
for task in (poll_task, callback_error): # cancel them
task.cancel()
await asyncio.gather(
poll_task, callback_error, return_exceptions=True
) # and collect the results so that the event loop won't raise
return

to_cancel = poll_task if task is callback_error else callback_error # cancel the other task
to_cancel.cancel()
for task_ in self.ws._pending_parsers:
task_.cancel()
await asyncio.gather(
*self.ws._pending_parsers, to_cancel, return_exceptions=True
) # same sort of thing as above gather
self.ws._pending_parsers.clear()
try:
await task # handle the exception raised
except (*RAISED_EXCEPTIONS, asyncio.CancelledError):
return
except RAISED_EXCEPTIONS:
self.dispatch("disconnect")
if not self.is_closed():
await throttle()
state._task_error = asyncio.get_running_loop().create_future()

@overload
async def login(
Expand Down
105 changes: 55 additions & 50 deletions steam/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ipaddress import IPv4Address
from operator import attrgetter
from types import CoroutineType
from typing import TYPE_CHECKING, Any, Final, Generic, TypeAlias, overload
from typing import TYPE_CHECKING, Any, Final, Generic, Self, TypeAlias, overload
from zlib import MAX_WBITS, decompress

import aiohttp
Expand Down Expand Up @@ -293,6 +293,7 @@ def __init__(
self._state = state
self.cm_list = cm_list
self.cm = cm
self.tg = asyncio.TaskGroup()
# the keep alive
self._keep_alive: KeepAliveHandler
self._dispatch = state.dispatch
Expand All @@ -302,7 +303,6 @@ def __init__(
self.listeners: list[EventListener[Any]] = []
self.gc_listeners: list[GCEventListener[Any]] = []
self.closed = False
self._pending_parsers = set[asyncio.Task[Any]]()

self.session_id = 0
self.id64 = ID64(0)
Expand All @@ -316,6 +316,13 @@ def __init__(
self.public_ip: IPAdress
self.connect_time: datetime

async def __aenter__(self) -> Self:
await self.tg.__aenter__()
return self

async def __aexit__(self, *args: Any) -> None:
await self.tg.__aexit__(*args)

@property
def latency(self) -> float:
"""Measures latency between a heartbeat send and the heartbeat interval in seconds."""
Expand Down Expand Up @@ -385,27 +392,45 @@ def gc_wait_for(

@asynccontextmanager
async def poll(self) -> AsyncGenerator[None, None]:
async def inner_poll():
while True:
await self.poll_event()
timeout = asyncio.Timeout(None)
async with asyncio.TaskGroup() as tg:

poll_task = asyncio.create_task(inner_poll())
poll_task.add_done_callback(self.parser_callback)
async def inner_poll():
try:
async with timeout:
while True:
await self.poll_event()
except TimeoutError:
pass

yield
tg.create_task(inner_poll())
yield
timeout.reschedule(-1)

poll_task.cancel() # we let Client.connect handle poll_event from here on out
try:
await poll_task # needed to ensure the task is cancelled and socket._waiting is removed
except asyncio.CancelledError:
pass
def dispatch_ready(self):
state = self._state

async def inner():
if state.intents & Intents.ChatGroups > 0:
await state.handled_chat_groups.wait() # ensure group cache is ready
if state.intents & Intents.ChatGroups > 0:
# due to a steam limitation we can't get these reliably on reconnect? TODO check?
await state.handled_friends.wait() # ensure friend cache is ready
await state.handled_emoticons.wait() # ensure emoticon cache is ready
await state.handled_licenses.wait() # ensure licenses are ready
await state.handled_wallet.wait() # ensure wallet is ready

await state.client._handle_ready()

self.tg.create_task(inner())

@classmethod
async def from_client(
cls, client: Client, /, refresh_token: str | None = None, cm_list: AsyncGenerator[CMServer, None] | None = None
) -> SteamWebSocket:
state = client._state
cm_list = cm_list or fetch_cm_list(state)

async for cm in cm_list:
log.info("Attempting to create a websocket connection to %s (load: %f)", cm.url, cm.weighted_load)
try:
Expand All @@ -416,6 +441,8 @@ async def from_client(
log.debug("Connected to %s", cm.url)

self = cls(state, socket, cm_list, cm)
old_tg = self.tg
self.tg = client._tg
client.ws = self
self._dispatch("connect")

Expand All @@ -425,6 +452,7 @@ async def from_client(
self.refresh_token = refresh_token or await self.fetch_refresh_token()
self.id64 = parse_id64(utils.decode_jwt(self.refresh_token)["sub"])

self.dispatch_ready()
msg: login.CMsgClientLogonResponse = await self.send_proto_and_wait(
login.CMsgClientLogon(
protocol_version=PROTOCOL_VERSION,
Expand Down Expand Up @@ -475,6 +503,7 @@ async def from_client(

self._dispatch("login")
log.debug("Logon completed")
self.tg = old_tg

return self
raise NoCMsFound("No CMs found could be connected to. Steam is likely down")
Expand Down Expand Up @@ -507,7 +536,7 @@ async def fetch_refresh_token(self) -> str:
if not begin_resp.allowed_confirmations:
raise AuthenticatorError("No valid auth session guard type was found")

code_task = email_code_task = asyncio.create_task(asyncio.sleep(float("inf")))
code_task = email_code_task = asyncio.get_running_loop().create_future()
schedule_poll = False

for allowed_confirmation in begin_resp.allowed_confirmations:
Expand All @@ -518,10 +547,10 @@ async def fetch_refresh_token(self) -> str:
case auth.EAuthSessionGuardType.DeviceCode:
if not client.shared_secret:
print("Please enter a Steam guard code")
code_task = asyncio.create_task(self.update_auth_with_code(begin_resp, allowed_confirmation))
code_task = self.tg.create_task(self.update_auth_with_code(begin_resp, allowed_confirmation))
case auth.EAuthSessionGuardType.EmailCode:
print("Please enter a confirmation code from your email")
email_code_task = asyncio.create_task(self.update_auth_with_code(begin_resp, allowed_confirmation))
email_code_task = self.tg.create_task(self.update_auth_with_code(begin_resp, allowed_confirmation))
case auth.EAuthSessionGuardType.DeviceConfirmation:
schedule_poll = True
print("Confirm login this on your device")
Expand All @@ -535,23 +564,11 @@ async def fetch_refresh_token(self) -> str:
f"Unknown auth session guard type: {allowed_confirmation.confirmation_type}"
)
else:
(done,), pending = await asyncio.wait(
(
code_task,
email_code_task,
(
asyncio.create_task(self.poll_auth_status(begin_resp))
if schedule_poll
else asyncio.Future[None]()
),
),
return_when=asyncio.FIRST_COMPLETED,
poll_resp: auth.PollAuthSessionStatusResponse = await utils.race(
code_task,
email_code_task,
self.poll_auth_status(begin_resp) if schedule_poll else asyncio.get_running_loop().create_future(),
)
for task in pending:
task.cancel()
await asyncio.gather(*pending, return_exceptions=True)

poll_resp = await done
assert poll_resp is not None

self.client_id = poll_resp.new_client_id or begin_resp.client_id
Expand Down Expand Up @@ -689,28 +706,18 @@ async def anonymous_login_from_client(
async def poll_event(self) -> None:
try:
message = await self.socket.receive()
if message.type is aiohttp.WSMsgType.BINARY: # type: ignore
return self.receive(message.data) # type: ignore
if message.type is aiohttp.WSMsgType.ERROR: # type: ignore
if message.type is aiohttp.WSMsgType.BINARY:
return self.receive(message.data)
if message.type is aiohttp.WSMsgType.ERROR:
log.debug("Received %r", message)
raise message.data # type: ignore
if message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): # type: ignore
raise message.data
if message.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
log.debug("Received %r", message)
raise WebSocketClosure
log.debug("Dropped unexpected message type: %r", message)
except WebSocketClosure:
await self._state.handle_close()

def parser_callback(self, task: asyncio.Task[Any], /) -> None:
try:
exc = task.exception()
except asyncio.CancelledError:
pass
else:
if isinstance(exc, RAISED_EXCEPTIONS) and not self._state._task_error.done():
self._state._task_error.set_exception(exc)
self._pending_parsers.discard(task)

def receive(self, message: bytes, /) -> None:
emsg_value = READ_U32(message)
try:
Expand Down Expand Up @@ -740,9 +747,7 @@ def receive(self, message: bytes, /) -> None:
return traceback.print_exc()

if isinstance(result, CoroutineType):
task = asyncio.create_task(result, name=f"steam.py: {event_parser.__name__}")
self._pending_parsers.add(task)
task.add_done_callback(self.parser_callback)
self.tg.create_task(result, name=f"steam.py: {event_parser.__name__}")
# remove the dispatched listener
removed: list[int] = []
for idx, entry in enumerate(self.listeners):
Expand Down
9 changes: 2 additions & 7 deletions steam/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,6 @@ def _device_id(self) -> str:
def _tg(self) -> TaskGroup:
return self.client._tg

@utils.cached_property
def _task_error(self) -> asyncio.Future[None]:
"""Holds the exceptions so that the gateway can propagate exceptions to the client.login call"""
return asyncio.get_running_loop().create_future()

@property
def language(self) -> Language:
return self.http.language
Expand Down Expand Up @@ -1787,7 +1782,7 @@ async def fill_trades(self) -> None:

async def wait_for_trade(self, id: TradeOfferID) -> TradeOffer[Item[User], Item[ClientUser], User]:
self._trades_to_watch.add(id)
self._tg.create_task(self.poll_trades()) # start re-polling trades
self.ws.tg.create_task(self.poll_trades()) # start re-polling trades
return await self.trade_queue.wait_for(id=id)

@parser
Expand Down Expand Up @@ -1866,7 +1861,7 @@ async def poll_confirmations(self) -> None:
self.polling_confirmations = False

async def wait_for_confirmation(self, id: TradeOfferID) -> Confirmation:
self._tg.create_task(self.poll_confirmations())
self.ws.tg.create_task(self.poll_confirmations())
return await self.confirmation_queue.wait_for(id=id)

async def _fetch_store_info(
Expand Down
Loading

0 comments on commit 0aa503e

Please sign in to comment.