From fe69c54ca52af9a5dacd33ccf35fc20a446f7518 Mon Sep 17 00:00:00 2001 From: Bart Leusink Date: Thu, 11 Apr 2024 18:34:46 +0200 Subject: [PATCH 1/3] Remove closing markdown identifiers (#686) --- .../completions/handlers/default.py | 7 ++++ .../tests/completions/test_handlers.py | 39 ++++++++++++++++--- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index eb03df156..17c777a91 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -91,6 +91,8 @@ async def handle_stream_request(self, request: InlineCompletionRequest): continue else: suggestion = self._post_process_suggestion(suggestion, request) + elif suggestion.endswith("```"): + suggestion = self._post_process_suggestion(suggestion, request) self.write_message( InlineCompletionStreamChunk( type="stream", @@ -151,4 +153,9 @@ def _post_process_suggestion( if suggestion.startswith(request.prefix): suggestion = suggestion[len(request.prefix) :] break + + # check if the suggestion ends with a closing markdown identifier and remove it + if suggestion.endswith("```"): + suggestion = suggestion[:-3].rstrip() + return suggestion diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 1b950af74..4b55d0010 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -17,7 +17,8 @@ class MockProvider(BaseProvider, FakeListLLM): models = ["model"] def __init__(self, **kwargs): - kwargs["responses"] = ["Test response"] + if not "responses" in kwargs: + kwargs["responses"] = ["Test response"] super().__init__(**kwargs) @@ -34,7 +35,7 @@ def __init__(self): create_task=lambda x: self.tasks.append(x) ) self.settings["model_parameters"] = {} - self.llm_params = {} + self.llm_params = {"model_id": "model"} self.create_llm_chain(MockProvider, {"model_id": "model"}) def write_message(self, message: str) -> None: # type: ignore @@ -88,8 +89,36 @@ async def test_handle_request(inline_handler): assert suggestions[0].insertText == "Test response" +async def test_handle_request_with_spurious_fragments(inline_handler): + inline_handler.create_llm_chain( + MockProvider, + { + "model_id": "model", + "responses": ["```python\nTest python code\n```"], + }, + ) + dummy_request = InlineCompletionRequest( + number=1, prefix="", suffix="", mime="", stream=False + ) + + await inline_handler.handle_request(dummy_request) + # should write a single reply + assert len(inline_handler.messages) == 1 + # reply should contain a single suggestion + suggestions = inline_handler.messages[0].list.items + assert len(suggestions) == 1 + # the suggestion should include insert text from LLM without spurious fragments + assert suggestions[0].insertText == "Test python code" + + async def test_handle_stream_request(inline_handler): - inline_handler.llm_chain = FakeListLLM(responses=["test"]) + inline_handler.create_llm_chain( + MockProvider, + { + "model_id": "model", + "responses": ["test"], + }, + ) dummy_request = InlineCompletionRequest( number=1, prefix="", suffix="", mime="", stream=True ) @@ -106,11 +135,11 @@ async def test_handle_stream_request(inline_handler): # second reply should be a chunk containing the token second = inline_handler.messages[1] assert second.type == "stream" - assert second.response.insertText == "Test response" + assert second.response.insertText == "test" assert second.done == False # third reply should be a closing chunk third = inline_handler.messages[2] assert third.type == "stream" - assert third.response.insertText == "Test response" + assert third.response.insertText == "test" assert third.done == True From 6ba1f1b43d8c1149792035b481a0c2b4038f595b Mon Sep 17 00:00:00 2001 From: Bart Leusink Date: Fri, 12 Apr 2024 15:09:18 +0200 Subject: [PATCH 2/3] Remove whitespace after closing markdown identifier --- .../jupyter_ai/completions/handlers/default.py | 6 +++--- .../tests/completions/test_handlers.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py index 17c777a91..9d7e7915c 100644 --- a/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/completions/handlers/default.py @@ -91,7 +91,7 @@ async def handle_stream_request(self, request: InlineCompletionRequest): continue else: suggestion = self._post_process_suggestion(suggestion, request) - elif suggestion.endswith("```"): + elif suggestion.rstrip().endswith("```"): suggestion = self._post_process_suggestion(suggestion, request) self.write_message( InlineCompletionStreamChunk( @@ -155,7 +155,7 @@ def _post_process_suggestion( break # check if the suggestion ends with a closing markdown identifier and remove it - if suggestion.endswith("```"): - suggestion = suggestion[:-3].rstrip() + if suggestion.rstrip().endswith("```"): + suggestion = suggestion.rstrip()[:-3].rstrip() return suggestion diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 4b55d0010..2a24a830f 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -5,6 +5,7 @@ from jupyter_ai.completions.models import InlineCompletionRequest from jupyter_ai_magics import BaseProvider from langchain_community.llms import FakeListLLM +import pytest from pytest import fixture from tornado.httputil import HTTPServerRequest from tornado.web import Application @@ -89,12 +90,21 @@ async def test_handle_request(inline_handler): assert suggestions[0].insertText == "Test response" -async def test_handle_request_with_spurious_fragments(inline_handler): +@pytest.mark.parametrize( + "response,expected_suggestion", + [ + ("```python\nTest python code\n```", "Test python code"), + ("```\ntest\n```\n \n", "test"), + ("```hello```world```", "hello```world"), + ], +) +async def test_handle_request_with_spurious_fragments(response, expected_suggestion): + inline_handler = MockCompletionHandler() inline_handler.create_llm_chain( MockProvider, { "model_id": "model", - "responses": ["```python\nTest python code\n```"], + "responses": [response], }, ) dummy_request = InlineCompletionRequest( @@ -108,7 +118,7 @@ async def test_handle_request_with_spurious_fragments(inline_handler): suggestions = inline_handler.messages[0].list.items assert len(suggestions) == 1 # the suggestion should include insert text from LLM without spurious fragments - assert suggestions[0].insertText == "Test python code" + assert suggestions[0].insertText == expected_suggestion async def test_handle_stream_request(inline_handler): From 313b8b70b00fbc5ae5e3837d9dcebd9f5a104558 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Apr 2024 13:09:32 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jupyter-ai/jupyter_ai/tests/completions/test_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py index 2a24a830f..fd2b2666c 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py +++ b/packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py @@ -1,11 +1,11 @@ import json from types import SimpleNamespace +import pytest from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler from jupyter_ai.completions.models import InlineCompletionRequest from jupyter_ai_magics import BaseProvider from langchain_community.llms import FakeListLLM -import pytest from pytest import fixture from tornado.httputil import HTTPServerRequest from tornado.web import Application