Support SageMaker Endpoints in chat (#197)
* allow models from registry providers in chat

* support language model fields

* add json content handler for SM Endpoints

* remove console log

* rename variables for clarity

* add user documentation for SageMaker Endpoints

* update docstring

Co-authored-by: Piyush Jain <[email protected]>

* remove redundant height attribute

Co-authored-by: Jason Weill <[email protected]>

* fix memo dependencies

* Updated headers for settings panel sections

* Fixing CI failure for check-release

---------

Co-authored-by: Piyush Jain <[email protected]>
Co-authored-by: Jason Weill <[email protected]>
3 people authored Jun 2, 2023
1 parent d0ecc47 commit dd12385
Showing 14 changed files with 398 additions and 32 deletions.
1 change: 1 addition & 0 deletions .github/workflows/check-release.yml
@@ -20,6 +20,7 @@ jobs:
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          version_spec: minor
          python-version: '3.10.x'
      - name: Runner debug info
        if: always()
        run: |
Binary file added docs/source/_static/chat-sagemaker-endpoints.png
36 changes: 36 additions & 0 deletions docs/source/users/index.md
@@ -177,6 +177,42 @@ To compose a message, type it in the text box at the bottom of the chat interfac
alt='Screen shot of an example "Hello world" message sent to Jupyternaut, who responds with "Hello world, how are you today?"'
class="screenshot" />

### Usage with SageMaker Endpoints

Jupyter AI supports language models hosted on SageMaker Endpoints that use JSON
APIs. The first step is to authenticate with AWS via the `boto3` SDK and have
the credentials stored in the `default` profile. Guidance on how to do this can
be found in the
[`boto3` documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html).
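
For example, a quick way to confirm that `boto3` can resolve credentials from the
`default` profile is a short check like the following sketch (the region shown is a
placeholder; use the region your endpoint runs in):

```python
# Minimal sanity check that boto3 finds credentials in the `default` profile.
import boto3

session = boto3.Session(profile_name="default", region_name="us-west-2")
print("Credentials found:", session.get_credentials() is not None)
```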

When selecting the SageMaker Endpoints provider in the settings panel, you will
see the following interface:

<img src="../_static/chat-sagemaker-endpoints.png"
width="50%"
alt='Screenshot of the settings panel with the SageMaker Endpoints provider selected.'
class="screenshot" />

Each of the additional fields under "Language model" is required. These fields
should contain the following data:

- **Local model ID**: The name of your endpoint. This can be retrieved from the
AWS Console at the URL
`https://<region>.console.aws.amazon.com/sagemaker/home?region=<region>#/endpoints`.

- **Region name**: The AWS region your SageMaker endpoint is hosted in, e.g. `us-west-2`.

- **Request schema**: The JSON object the endpoint expects, with the prompt
being substituted into any value that matches the string literal `"<prompt>"`.
In this example, the request schema `{"text_inputs":"<prompt>"}` generates a JSON
object with the prompt stored under the `text_inputs` key.

- **Response path**: A [JSONPath](https://goessner.net/articles/JsonPath/index.html)
string that retrieves the language model's output from the endpoint's JSON
response. In this example, the endpoint returns an object with the schema
`{"generated_texts":["<output>"]}`, hence the response path is
`generated_texts.[0]`. The sketch after this list shows how the request schema
and response path work together.
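
Taken together, these two fields describe a round trip like the sketch below. This is
an illustration only, using the example values above and the `jsonpath-ng` library that
the extension depends on; the actual content handler also recurses into nested objects.

```python
import json
from jsonpath_ng import parse

# Build the request body: the prompt replaces the "<prompt>" placeholder
# in the request schema.
request_schema = {"text_inputs": "<prompt>"}
prompt = "What is the capital of France?"
request_body = json.dumps(
    {k: (prompt if v == "<prompt>" else v) for k, v in request_schema.items()}
)
# request_body == '{"text_inputs": "What is the capital of France?"}'

# Extract the model output from the endpoint's JSON response with the
# response path.
response = {"generated_texts": ["Paris is the capital of France."]}
output = parse("generated_texts.[0]").find(response)[0].value
print(output)  # Paris is the capital of France.
```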

### Asking about something in your notebook

Jupyter AI's chat interface can include a portion of your notebook in your prompt.
@@ -1,5 +1,5 @@
from typing import ClassVar, List, Type
-from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy
+from jupyter_ai_magics.providers import AuthStrategy, EnvAuthStrategy, Field
from pydantic import BaseModel, Extra
from langchain.embeddings import OpenAIEmbeddings, CohereEmbeddings, HuggingFaceHubEmbeddings
from langchain.embeddings.base import Embeddings
@@ -35,7 +35,14 @@ class Config:

    provider_klass: ClassVar[Type[Embeddings]]

    registry: ClassVar[bool] = False
    """Whether this provider is a registry provider."""

    fields: ClassVar[List[Field]] = []
    """Fields expected by this provider in its constructor. Each `Field` `f`
    should be passed as a keyword argument, keyed by `f.key`."""


class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider):
    id = "openai"
    name = "OpenAI"
@@ -73,3 +80,4 @@ class HfHubEmbeddingsProvider(BaseEmbeddingsProvider):
    pypi_package_deps = ["huggingface_hub", "ipywidgets"]
    auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
    provider_klass = HuggingFaceHubEmbeddings
    registry = True
80 changes: 76 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -1,9 +1,10 @@
-from typing import ClassVar, Dict, List, Union, Literal, Optional
+from typing import Any, ClassVar, Dict, List, Union, Literal, Optional
import base64

import io
import json
import copy

from jsonpath_ng import jsonpath, parse
from langchain.schema import BaseModel as BaseLangchainProvider
from langchain.llms import (
AI21,
@@ -14,6 +15,7 @@
    OpenAIChat,
    SagemakerEndpoint
)
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.utils import get_from_dict_or_env
from langchain.llms.utils import enforce_stop_tokens

@@ -45,6 +47,18 @@ class AwsAuthStrategy(BaseModel):
]
]

class TextField(BaseModel):
    type: Literal["text"] = "text"
    key: str
    label: str

class MultilineTextField(BaseModel):
    type: Literal["text-multiline"] = "text-multiline"
    key: str
    label: str

Field = Union[TextField, MultilineTextField]

class BaseProvider(BaseLangchainProvider):
    #
    # pydantic config
@@ -75,6 +89,13 @@ class Config:
    """Authentication/authorization strategy. Declares what credentials are
    required to use this model provider. Generally should not be `None`."""

    registry: ClassVar[bool] = False
    """Whether this provider is a registry provider."""

    fields: ClassVar[List[Field]] = []
    """User inputs expected by this provider when initializing it. Each `Field` `f`
    should be passed in the constructor as a keyword argument, keyed by `f.key`."""

    #
    # instance attrs
    #
@@ -144,6 +165,7 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
    # tqdm is a dependency of huggingface_hub
    pypi_package_deps = ["huggingface_hub", "ipywidgets"]
    auth_strategy = EnvAuthStrategy(name="HUGGINGFACEHUB_API_TOKEN")
    registry = True

    # Override the parent's validate_environment with a custom list of valid tasks
    @root_validator()
@@ -292,12 +314,62 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
    pypi_package_deps = ["openai"]
    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

class JsonContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def __init__(self, request_schema, response_path):
        self.request_schema = json.loads(request_schema)
        self.response_path = response_path
        self.response_parser = parse(response_path)

    def replace_values(self, old_val, new_val, d: Dict[str, Any]):
        """Replaces values of a dictionary recursively."""
        for key, val in d.items():
            if val == old_val:
                d[key] = new_val
            if isinstance(val, dict):
                self.replace_values(old_val, new_val, val)

        return d

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        request_obj = copy.deepcopy(self.request_schema)
        self.replace_values("<prompt>", prompt, request_obj)
        request = json.dumps(request_obj).encode('utf-8')
        return request

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        matches = self.response_parser.find(response_json)
        return matches[0].value

class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
    id = "sagemaker-endpoint"
    name = "Sagemaker Endpoint"
    models = ["*"]
    model_id_key = "endpoint_name"
    pypi_package_deps = ["boto3"]
    auth_strategy = AwsAuthStrategy()

    registry = True
    fields = [
        TextField(
            key="region_name",
            label="Region name",
        ),
        MultilineTextField(
            key="request_schema",
            label="Request schema",
        ),
        TextField(
            key="response_path",
            label="Response path",
        )
    ]

    def __init__(self, *args, **kwargs):
        request_schema = kwargs.pop('request_schema')
        response_path = kwargs.pop('response_path')
        content_handler = JsonContentHandler(request_schema=request_schema, response_path=response_path)
        super().__init__(*args, **kwargs, content_handler=content_handler)
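
For reference, a rough sketch of how this provider might be constructed outside the chat
UI, mirroring the keyword arguments that `chat_provider.py` assembles from the settings
panel; the endpoint name, region, and schema values here are placeholders for
illustration only.

```python
# Hypothetical construction of the new provider; calling it would invoke a real
# SageMaker endpoint, so the values below are placeholders.
provider = SmEndpointProvider(
    model_id="my-endpoint-name",  # the endpoint name, mapped via model_id_key
    region_name="us-west-2",
    request_schema='{"text_inputs":"<prompt>"}',
    response_path="generated_texts.[0]",
)
print(provider("Write a haiku about the ocean."))
```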

1 change: 1 addition & 0 deletions packages/jupyter-ai-magics/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
    "langchain==0.0.159",
    "typing_extensions==4.5.0",
    "click~=8.0",
    "jsonpath-ng~=1.5.3",
]

[project.optional-dependencies]
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/actors/chat_provider.py
@@ -20,7 +20,8 @@ def update(self, config: GlobalConfig):
        if not provider:
            raise ValueError(f"No provider and model found with '{model_id}'")

-       provider_params = { "model_id": local_model_id}
+       fields = config.fields.get(model_id, {})
+       provider_params = { "model_id": local_model_id, **fields }

        auth_strategy = provider.auth_strategy
        if auth_strategy and auth_strategy.type == "env":
8 changes: 6 additions & 2 deletions packages/jupyter-ai/jupyter_ai/handlers.py
@@ -279,7 +279,9 @@ def get(self):
                    id=provider.id,
                    name=provider.name,
                    models=provider.models,
-                   auth_strategy=provider.auth_strategy
+                   auth_strategy=provider.auth_strategy,
+                   registry=provider.registry,
+                   fields=provider.fields,
                )
            )

@@ -304,7 +306,9 @@ def get(self):
                    id=provider.id,
                    name=provider.name,
                    models=provider.models,
-                   auth_strategy=provider.auth_strategy
+                   auth_strategy=provider.auth_strategy,
+                   registry=provider.registry,
+                   fields=provider.fields,
                )
            )

7 changes: 5 additions & 2 deletions packages/jupyter-ai/jupyter_ai/models.py
@@ -1,7 +1,7 @@
-from jupyter_ai_magics.providers import AuthStrategy
+from jupyter_ai_magics.providers import AuthStrategy, Field

from pydantic import BaseModel
-from typing import Dict, List, Union, Literal, Optional
+from typing import Any, Dict, List, Union, Literal, Optional

class PromptRequest(BaseModel):
    task_id: str
@@ -92,6 +92,8 @@ class ListProvidersEntry(BaseModel):
    name: str
    models: List[str]
    auth_strategy: AuthStrategy
    registry: bool
    fields: List[Field]


class ListProvidersResponse(BaseModel):
@@ -108,3 +110,4 @@ class GlobalConfig(BaseModel):
    embeddings_provider_id: Optional[str] = None
    api_keys: Dict[str, str] = {}
    send_with_shift_enter: Optional[bool] = None
    fields: Dict[str, Dict[str, Any]] = {}
