diff --git a/examples/servers/json_server.py b/examples/servers/json_server.py index 93a171f4..d14ec699 100644 --- a/examples/servers/json_server.py +++ b/examples/servers/json_server.py @@ -110,7 +110,7 @@ def workspace_diagnostic( params: lsp.WorkspaceDiagnosticParams, ) -> lsp.WorkspaceDiagnosticReport: """Returns diagnostic report.""" - first = list(json_server.workspace._docs.keys())[0] + first = list(json_server.workspace.text_documents.keys())[0] document = json_server.workspace.get_document(first) return lsp.WorkspaceDiagnosticReport( items=[ diff --git a/pygls/capabilities.py b/pygls/capabilities.py index 59e5de80..c36832b0 100644 --- a/pygls/capabilities.py +++ b/pygls/capabilities.py @@ -15,7 +15,7 @@ # limitations under the License. # ############################################################################ from functools import reduce -from typing import Any, Dict, List, Set +from typing import Any, Dict, List, Set, Union from lsprotocol.types import ( INLAY_HINT_RESOLVE, @@ -349,10 +349,12 @@ def _with_semantic_tokens(self): self.server_cap.semantic_tokens_provider = value return self + full_support: Union[bool, SemanticTokensOptionsFullType1] = ( + TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL in self.features + ) + if TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL_DELTA in self.features: full_support = SemanticTokensOptionsFullType1(delta=True) - else: - full_support = TEXT_DOCUMENT_SEMANTIC_TOKENS_FULL in self.features options = SemanticTokensOptions( legend=value, diff --git a/pygls/exceptions.py b/pygls/exceptions.py index 533640e6..639ea669 100644 --- a/pygls/exceptions.py +++ b/pygls/exceptions.py @@ -17,6 +17,8 @@ # limitations under the License. # ############################################################################ import traceback +from typing import Set +from typing import Type class JsonRpcException(Exception): @@ -158,7 +160,7 @@ def _is_server_error_code(code): return -32099 <= code <= -32000 -_EXCEPTIONS = ( +_EXCEPTIONS: Set[Type[JsonRpcException]] = { JsonRpcInternalError, JsonRpcInvalidParams, JsonRpcInvalidRequest, @@ -166,7 +168,7 @@ def _is_server_error_code(code): JsonRpcParseError, JsonRpcRequestCancelled, JsonRpcServerError, -) +} class PyglsError(Exception): diff --git a/pygls/protocol.py b/pygls/protocol.py index cc245ae9..b15757b0 100644 --- a/pygls/protocol.py +++ b/pygls/protocol.py @@ -503,16 +503,16 @@ def _send_data(self, data): body = json.dumps(data, default=self._serialize_message) logger.info("Sending data: %s", body) - body = body.encode(self.CHARSET) - if not self._send_only_body: - header = ( - f"Content-Length: {len(body)}\r\n" - f"Content-Type: {self.CONTENT_TYPE}; charset={self.CHARSET}\r\n\r\n" - ).encode(self.CHARSET) - - self.transport.write(header + body) - else: - self.transport.write(body.decode("utf-8")) + if self._send_only_body: + self.transport.write(body) + return + + header = ( + f"Content-Length: {len(body)}\r\n" + f"Content-Type: {self.CONTENT_TYPE}; charset={self.CHARSET}\r\n\r\n" + ).encode(self.CHARSET) + + self.transport.write(header + body.encode(self.CHARSET)) except Exception as error: logger.exception("Error sending data", exc_info=True) self._server._report_server_error(error, JsonRpcInternalError) @@ -631,7 +631,7 @@ def send_request(self, method, params=None, callback=None, msg_id=None): jsonrpc=JsonRPCProtocol.VERSION, ) - future = Future() + future = Future() # type: ignore[var-annotated] # If callback function is given, call it when result is received if callback: diff --git a/pygls/server.py b/pygls/server.py index 9cadf85d..b26c9a4b 100644 --- a/pygls/server.py +++ b/pygls/server.py @@ -189,9 +189,9 @@ def __init__( self._max_workers = max_workers self._server = None - self._stop_event = None - self._thread_pool = None - self._thread_pool_executor = None + self._stop_event: Optional[Event] = None + self._thread_pool: Optional[ThreadPool] = None + self._thread_pool_executor: Optional[ThreadPoolExecutor] = None if sync_kind is not None: self.text_document_sync_kind = sync_kind @@ -210,7 +210,8 @@ def shutdown(self): """Shutdown server.""" logger.info("Shutting down the server") - self._stop_event.set() + if self._stop_event is not None: + self._stop_event.set() if self._thread_pool: self._thread_pool.terminate() @@ -223,7 +224,7 @@ def shutdown(self): self._server.close() self.loop.run_until_complete(self._server.wait_closed()) - if self._owns_loop and not self.loop.is_closed: + if self._owns_loop and not self.loop.is_closed(): logger.info("Closing the event loop.") self.loop.close() @@ -235,7 +236,7 @@ def start_io(self, stdin: Optional[TextIO] = None, stdout: Optional[TextIO] = No transport = StdOutTransportAdapter( stdin or sys.stdin.buffer, stdout or sys.stdout.buffer ) - self.lsp.connection_made(transport) + self.lsp.connection_made(transport) # type: ignore[arg-type] try: self.loop.run_until_complete( @@ -260,7 +261,7 @@ def start_pyodide(self): # Note: We don't actually start anything running as the main event # loop will be handled by the web platform. transport = PyodideTransportAdapter(sys.stdout) - self.lsp.connection_made(transport) + self.lsp.connection_made(transport) # type: ignore[arg-type] self.lsp._send_only_body = True # Don't send headers within the payload def start_tcp(self, host: str, port: int) -> None: @@ -268,7 +269,7 @@ def start_tcp(self, host: str, port: int) -> None: logger.info("Starting TCP server on %s:%s", host, port) self._stop_event = Event() - self._server = self.loop.run_until_complete( + self._server = self.loop.run_until_complete( # type: ignore[assignment] self.loop.create_server(self.lsp, host, port) ) try: @@ -300,7 +301,7 @@ async def connection_made(websocket, _): ) start_server = serve(connection_made, host, port, loop=self.loop) - self._server = start_server.ws_server + self._server = start_server.ws_server # type: ignore[assignment] self.loop.run_until_complete(start_server) try: @@ -388,7 +389,7 @@ def __init__( name: str, version: str, loop=None, - protocol_cls=Type[LanguageServerProtocol], + protocol_cls: Type[LanguageServerProtocol] = LanguageServerProtocol, converter_factory=default_converter, text_document_sync_kind: TextDocumentSyncKind = TextDocumentSyncKind.Incremental, notebook_document_sync: Optional[NotebookDocumentSyncOptions] = None, diff --git a/pygls/workspace.py b/pygls/workspace.py index 4cb0a853..fad94ef6 100644 --- a/pygls/workspace.py +++ b/pygls/workspace.py @@ -468,8 +468,14 @@ def get_notebook_document( if notebook_uri is not None: return self._notebook_documents.get(notebook_uri) - notebook_uri = self._cell_in_notebook.get(cell_uri) - return self._notebook_documents.get(notebook_uri) + if cell_uri is not None: + notebook_uri = self._cell_in_notebook.get(cell_uri) + if notebook_uri is None: + return None + + return self._notebook_documents.get(notebook_uri) + + return None def get_text_document(self, doc_uri: str) -> TextDocument: """ @@ -522,7 +528,7 @@ def put_text_document( if notebook_uri: self._cell_in_notebook[doc_uri] = notebook_uri - def remove_notebook_document(self, params: types.DidChangeNotebookDocumentParams): + def remove_notebook_document(self, params: types.DidCloseNotebookDocumentParams): notebook_uri = params.notebook_document.uri self._notebook_documents.pop(notebook_uri, None) diff --git a/pyproject.toml b/pyproject.toml index fa323a6e..93e4cfab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,9 @@ line-length = 120 line-length = 88 extend-exclude = "pygls/lsp/client.py" +[tool.mypy] +check_untyped_defs = true + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"