Skip to content

Commit

Permalink
#3376 remove dependency on wsproto
Browse files Browse the repository at this point in the history
  • Loading branch information
totaam committed Nov 9, 2022
1 parent 08bd0d2 commit 9de59b2
Showing 1 changed file with 22 additions and 45 deletions.
67 changes: 22 additions & 45 deletions xpra/net/quic/websocket_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,19 @@
#OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
#OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import asyncio
import os
import time
from collections import deque
import struct
from queue import Queue
from email.utils import formatdate
from typing import Callable, Deque, Dict, Optional, Union

import wsproto.events
from typing import Callable, Dict, Union

from aioquic.h0.connection import H0Connection
from aioquic.h3.connection import H3Connection
from aioquic.h3.events import DataReceived, HeadersReceived, H3Event

from xpra.net.websockets.mask import hybi_mask # pylint: disable=no-name-in-module
from xpra.net.websockets.header import encode_hybi_header
from xpra.net.quic.common import SERVER_NAME
from xpra.util import ellipsizer
from xpra.log import Logger
Expand All @@ -48,27 +49,18 @@ class WebSocketHandler:
def __init__(self, connection: HttpConnection, scope: Dict, stream_id: int, transmit: Callable[[], None]) -> None:
self.closed = False
self.connection = connection
self.http_event_queue: Deque[DataReceived] = deque()
self.queue: asyncio.Queue[Dict] = asyncio.Queue()
self.data_queue: Queue[bytes] = Queue()
self.scope = scope
self.stream_id = stream_id
self.transmit = transmit
self.websocket: Optional[wsproto.Connection] = None
#self.queue.put_nowait({"type": "websocket.connect"})
self.accepted : bool = False

def http_event_received(self, event: H3Event) -> None:
log("ws:http_event_received(%s)", ellipsizer(event))
if self.closed:
return
if isinstance(event, DataReceived):
if self.websocket is not None:
self.websocket.receive_data(event.data)
for ws_event in self.websocket.events():
self.websocket_event_received(ws_event)
else:
# delay event processing until we get `websocket.accept`
# from the ASGI application
self.http_event_queue.append(event)
self.data_queue.put(event.data)
elif isinstance(event, HeadersReceived):
subprotocols = self.scope.get("subprotocols", ())
if "xpra" not in subprotocols:
Expand All @@ -79,26 +71,17 @@ def http_event_received(self, event: H3Event) -> None:
self.send_accept()


def websocket_event_received(self, event: wsproto.events.Event) -> None:
log("ws:websocket_event_received(%s)", ellipsizer(event))
if isinstance(event, wsproto.events.TextMessage):
self.queue.put_nowait({"type": "websocket.receive", "text": event.data})
elif isinstance(event, wsproto.events.Message):
self.queue.put_nowait({"type": "websocket.receive", "bytes": event.data})
elif isinstance(event, wsproto.events.CloseConnection):
self.queue.put_nowait({"type": "websocket.disconnect", "code": event.code})

def close(self):
if not self.closed:
self.send_close(1000)

async def receive(self) -> Dict:
def receive(self) -> Dict:
log("ws:receive()")
return await self.queue.get()
return self.data_queue.get()


def send_accept(self, subprotocol : str = "xpra"):
self.websocket = wsproto.Connection(wsproto.ConnectionType.SERVER)
self.accepted = True
headers = [
(b":status", b"200"),
(b"server", SERVER_NAME.encode()),
Expand All @@ -107,29 +90,23 @@ def send_accept(self, subprotocol : str = "xpra"):
if subprotocol:
headers.append((b"sec-websocket-protocol", subprotocol.encode()))
self.connection.send_headers(stream_id=self.stream_id, headers=headers)
# consume backlog
while self.http_event_queue:
self.http_event_received(self.http_event_queue.popleft())
self.transmit()

def send_close(self, code : int = 403):
if self.websocket is not None:
data = self.websocket.send(wsproto.events.CloseConnection(code))
self.connection.send_data(stream_id=self.stream_id, data=data, end_stream=True)
def send_close(self, code : int = 1000, reason : str = ""):
if self.accepted:
data = struct.pack("!H", code)
if reason:
#should validate that encoded data length is less than 125, meh
data += reason.encode("utf-8")
header = encode_hybi_header(code, len(data), has_mask=False, fin=True)
self.connection.send_data(stream_id=self.stream_id, data=header+data, end_stream=True)
else:
self.connection.send_headers(stream_id=self.stream_id, headers=[(b":status", str(code).encode())])
self.closed = True
self.transmit()

#def send_text(self, text : str):
# data = self.websocket.send(wsproto.events.TextMessage(text))
# self.connection.send_data(stream_id=self.stream_id, data=data)
# self.transmit()

def send_bytes(self, bdata : bytes):
#from xpra.net.websockets.mask import hybi_mask
#mask = os.urandom(4)
#data = hybi_mask(mask, bdata)
data = self.websocket.send(wsproto.events.Message(bdata))
mask = os.urandom(4)
data = hybi_mask(mask, bdata)
self.connection.send_data(stream_id=self.stream_id, data=data)
self.transmit()

0 comments on commit 9de59b2

Please sign in to comment.