Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace "transports" with "writers" #508

Merged
merged 3 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/howto/migrate-to-v2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,16 @@ The :external:py:class:`multiprocessing.pool.ThreadPool` instance has been remov

The ``thread_pool_executor`` attribute of the base ``JsonRPCServer`` class has been removed, the ``ThreadPoolExecutor`` can be accessed via the ``thread_pool`` attribute instead.

``JsonRPCProtocol`` is no longer an ``asyncio.Protocol``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Now the pygls v2 uses the high-level asyncio APIs, it no longer makes sense for the ``JsonRPCProtocol`` class to inherit from ``asyncio.Protocol``.
Similarly, "output" classes are now called writers rather than transports. The ``connection_made`` method has been replaced with a corresponding ``set_writer`` method.

New ``pygls.io_`` module
^^^^^^^^^^^^^^^^^^^^^^^^

There is a new ``pygls.io_`` module containing main message parsing loop code common to both client and server

- The equivlaent to pygls v1's ``pygls.server.aio_readline`` function is now ``pygls.io_.run_async``
- It now contains classes like v1's ``WebsocketTransportAdapter``, which have been renamed to ``WebSocketWriter``
8 changes: 6 additions & 2 deletions pygls/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ async def start_io(self, cmd: str, *args, **kwargs):
if server.stdout is None:
raise RuntimeError("Server process is missing a stdout stream")

self.protocol.connection_made(server.stdin) # type: ignore
# Keep mypy happy
if server.stdin is None:
raise RuntimeError("Server process is missing a stdin stream")

self.protocol.set_writer(server.stdin)
connection = asyncio.create_task(
run_async(
stop_event=self._stop_event,
Expand All @@ -118,7 +122,7 @@ async def start_tcp(self, host: str, port: int):
"""Start communicating with a server over TCP."""
reader, writer = await asyncio.open_connection(host, port)

self.protocol.connection_made(writer) # type: ignore
self.protocol.set_writer(writer)
connection = asyncio.create_task(
run_async(
stop_event=self._stop_event,
Expand Down
52 changes: 34 additions & 18 deletions pygls/io_.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,27 @@ def readline(self) -> bytes: ...

def read(self, n: int) -> bytes: ...

class Writer(Protocol):
"""An synchronous writer."""

def close(self) -> None: ...

def write(self, data: bytes) -> None: ...

class AsyncReader(typing.Protocol):
"""An asynchronous reader."""

def readline(self) -> Awaitable[bytes]: ...

def readexactly(self, n: int) -> Awaitable[bytes]: ...

class AsyncWriter(typing.Protocol):
"""An asynchronous writer."""

def close(self) -> Awaitable[None]: ...

def write(self, data: bytes) -> Awaitable[None]: ...


class StdinAsyncReader:
"""Read from stdin asynchronously."""
Expand All @@ -73,28 +87,31 @@ def readexactly(self, n: int) -> Awaitable[bytes]:
return self.loop.run_in_executor(self.executor, self.stdin.read, n)


class WebSocketTransportAdapter:
"""Protocol adapter which calls write method.
class StdoutWriter:
"""Align a stdout stream with pygls' writer interface."""

Write method sends data via the WebSocket interface.
"""
def __init__(self, stdout: BinaryIO):
self._stdout = stdout

def __init__(self, ws: ServerConnection | ClientConnection):
self._ws = ws
self._loop: asyncio.AbstractEventLoop | None = None
def close(self):
self._stdout.close()

@property
def loop(self):
if self._loop is None:
self._loop = asyncio.get_running_loop()
def write(self, data: bytes) -> None:
self._stdout.write(data)
self._stdout.flush()

return self._loop

def close(self) -> None:
asyncio.ensure_future(self._ws.close())
class WebSocketWriter:
"""Align a websocket connection with pygls' writer interface"""

def __init__(self, ws: ServerConnection | ClientConnection):
self._ws = ws

def close(self) -> Awaitable[None]:
return self._ws.close()

def write(self, data: Any) -> None:
asyncio.ensure_future(self._ws.send(data))
def write(self, data: bytes) -> Awaitable[None]:
return self._ws.send(data)


async def run_async(
Expand Down Expand Up @@ -248,8 +265,7 @@ async def run_websocket(
"""

logger = logger or logging.getLogger(__name__)
protocol._send_only_body = True # Don't send headers within the payload
protocol.connection_made(WebSocketTransportAdapter(websocket)) # type: ignore
protocol.set_writer(WebSocketWriter(websocket), include_headers=False)

try:
from websockets.exceptions import ConnectionClosed
Expand Down
116 changes: 34 additions & 82 deletions pygls/protocol/json_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import asyncio
import enum
import inspect
import json
import logging
import re
import sys
import traceback
import uuid
Expand All @@ -29,8 +29,6 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Type,
Union,
Expand Down Expand Up @@ -60,7 +58,7 @@
if TYPE_CHECKING:
from cattrs import Converter

from pygls.io_ import WebSocketTransportAdapter
from pygls.io_ import AsyncWriter, Writer
from pygls.server import JsonRPCServer


Expand Down Expand Up @@ -101,8 +99,8 @@ class JsonRPCResponseMessage:
result: Any


class JsonRPCProtocol(asyncio.Protocol):
"""Json RPC protocol implementation using on top of `asyncio.Protocol`.
class JsonRPCProtocol:
"""Json RPC protocol implementation

Specification of the protocol can be found here:
https://www.jsonrpc.org/specification
Expand All @@ -112,15 +110,6 @@ class JsonRPCProtocol(asyncio.Protocol):

CHARSET = "utf-8"
CONTENT_TYPE = "application/vscode-jsonrpc"

MESSAGE_PATTERN = re.compile(
rb"^(?:[^\r\n]+\r\n)*"
+ rb"Content-Length: (?P<length>\d+)\r\n"
+ rb"(?:[^\r\n]+\r\n)*\r\n"
+ rb"(?P<body>{.*)",
re.DOTALL,
)

VERSION = "2.0"

def __init__(self, server: JsonRPCServer, converter: Converter):
Expand All @@ -130,16 +119,12 @@ def __init__(self, server: JsonRPCServer, converter: Converter):
self._shutdown = False

# Book keeping for in-flight requests
self._request_futures: Dict[str, Future[Any]] = {}
self._result_types: Dict[str, Any] = {}
self._request_futures: dict[str, Future[Any]] = {}
self._result_types: dict[str, Any] = {}

self.fm = FeatureManager(server, converter)
self.transport: Optional[
Union[asyncio.WriteTransport, WebSocketTransportAdapter]
] = None
self._message_buf: List[bytes] = []

self._send_only_body = False
self.writer: AsyncWriter | Writer | None = None
self._include_headers = False

def __call__(self):
return self
Expand Down Expand Up @@ -379,26 +364,27 @@ def _send_data(self, data):
if not data:
return

if self.transport is None:
if self.writer is None:
logger.error("Unable to send data, no available transport!")
return

try:
body = json.dumps(data, default=self._serialize_message)
logger.info("Sending data: %s", body)

if self._send_only_body:
# Mypy/Pyright seem to think `write()` wants `"bytes | bytearray | memoryview"`
# But runtime errors with anything but `str`.
self.transport.write(body) # type: ignore
return
if self._include_headers:
header = (
f"Content-Length: {len(body)}\r\n"
f"Content-Type: {self.CONTENT_TYPE}; charset={self.CHARSET}\r\n\r\n"
)
data = header + body
else:
data = body

header = (
f"Content-Length: {len(body)}\r\n"
f"Content-Type: {self.CONTENT_TYPE}; charset={self.CHARSET}\r\n\r\n"
).encode(self.CHARSET)
res = self.writer.write(data.encode(self.CHARSET))
if inspect.isawaitable(res):
asyncio.ensure_future(res)

self.transport.write(header + body.encode(self.CHARSET))
except Exception as error:
logger.exception("Error sending data", exc_info=True)
self._server._report_server_error(error, JsonRpcInternalError)
Expand All @@ -425,58 +411,24 @@ def _send_response(

self._send_data(response)

def connection_lost(self, exc):
"""Method from base class, called when connection is lost, in which case we
want to shutdown the server's process as well.
"""
logger.error("Connection to the client is lost! Shutting down the server.")
sys.exit(1)

def connection_made( # type: ignore # see: https://github.com/python/typeshed/issues/3021
def set_writer(
self,
transport: asyncio.Transport,
writer: AsyncWriter | Writer,
include_headers: bool = True,
):
"""Method from base class, called when connection is established"""
self.transport = transport
"""Set the writer object to use when sending data

def data_received(self, data: bytes):
try:
self._data_received(data)
except Exception as error:
logger.exception("Error receiving data", exc_info=True)
self._server._report_server_error(error, JsonRpcInternalError)
Parameters
----------
writer
The writer object

def _data_received(self, data: bytes):
"""Method from base class, called when server receives the data"""
logger.debug("Received %r", data)

while len(data):
# Append the incoming chunk to the message buffer
self._message_buf.append(data)

# Look for the body of the message
message = b"".join(self._message_buf)
found = JsonRPCProtocol.MESSAGE_PATTERN.fullmatch(message)

body = found.group("body") if found else b""
length = int(found.group("length")) if found else 1

if len(body) < length:
# Message is incomplete; bail until more data arrives
return

# Message is complete;
# extract the body and any remaining data,
# and reset the buffer for the next message
body, data = body[:length], body[length:]
self._message_buf = []

# Parse the body
self.handle_message(
json.loads(
body.decode(self.CHARSET), object_hook=self.structure_message
)
)
include_headers
Flag indicating if headers like ``Content-Length`` should be included when
sending data. (Default ``True``)
"""
self.writer = writer
self._include_headers = include_headers

def get_message_type(self, method: str) -> Optional[Type]:
"""Return the type definition of the message associated with the given method."""
Expand Down
16 changes: 14 additions & 2 deletions pygls/protocol/language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
############################################################################
from __future__ import annotations

import asyncio
import inspect
import json
import logging
import sys
import typing
from functools import lru_cache
from itertools import zip_longest
Expand Down Expand Up @@ -113,8 +116,17 @@ def get_result_type(self, method: str) -> Optional[Type]:
@lsp_method(types.EXIT)
def lsp_exit(self, *args) -> None:
"""Stops the server process."""
if self.transport is not None:
self.transport.close()
returncode = 0 if self._shutdown else 1
if self.writer is None:
sys.exit(returncode)

res = self.writer.close()
if inspect.isawaitable(res):
# Only call sys.exit once the close task has completed.
fut = asyncio.ensure_future(res)
fut.add_done_callback(lambda t: sys.exit(returncode))
else:
sys.exit(returncode)

@lsp_method(types.INITIALIZE)
def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResult:
Expand Down
Loading
Loading