Skip to content

Commit

Permalink
Handle WebSocket errors (#2746)
Browse files Browse the repository at this point in the history
Re-raise exceptions received via WebSocket instead of ignoring them or
interpret as string or bytes.
  • Loading branch information
serhiy-storchaka authored Jun 30, 2022
1 parent 319ae63 commit 71d466f
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions neuro-sdk/src/neuro_sdk/_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ async def read_out(self) -> Optional[Message]:
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
self._closing = True
return None
if msg.type == aiohttp.WSMsgType.ERROR:
raise self._ws.exception() # type: ignore
if msg.type != aiohttp.WSMsgType.BINARY:
raise RuntimeError(f"Incorrecr WebSocket message: {msg!r}")
if msg.data[0] == 3:
try:
details = json.loads(msg.data[1:])
Expand Down Expand Up @@ -552,8 +556,13 @@ async def monitor(
heartbeat=30,
) as ws:
async for msg in ws:
if msg.data:
yield msg.data
if msg.type == aiohttp.WSMsgType.BINARY:
if msg.data:
yield msg.data
elif msg.type == aiohttp.WSMsgType.ERROR:
raise ws.exception() # type: ignore
else:
raise RuntimeError(f"Incorrecr WebSocket message: {msg!r}")

async def status(self, id: str) -> JobDescription:
url = self._config.api_url / "jobs" / id
Expand All @@ -575,6 +584,10 @@ async def top(
if msg.type == aiohttp.WSMsgType.TEXT:
yield _job_telemetry_from_api(msg.json())
received_any = True
elif msg.type == aiohttp.WSMsgType.ERROR:
raise ws.exception() # type: ignore
else:
raise RuntimeError(f"Incorrecr WebSocket message: {msg!r}")
if not received_any:
raise ValueError(f"Job is not running. Job Id = {id}")
except WSServerHandshakeError as e:
Expand Down Expand Up @@ -691,9 +704,13 @@ async def _port_reader(
self, ws: aiohttp.ClientWebSocketResponse, writer: asyncio.StreamWriter
) -> None:
async for msg in ws:
assert msg.type == aiohttp.WSMsgType.BINARY
writer.write(msg.data)
await writer.drain()
if msg.type == aiohttp.WSMsgType.BINARY:
writer.write(msg.data)
await writer.drain()
elif msg.type == aiohttp.WSMsgType.ERROR:
raise ws.exception() # type: ignore
else:
raise RuntimeError(f"Incorrecr WebSocket message: {msg!r}")
writer.close()
await writer.wait_closed()

Expand Down

0 comments on commit 71d466f

Please sign in to comment.