From cb0c815ad799050ecc0abdf3d15981e9832b9829 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Fri, 28 Jul 2023 17:06:38 -0700 Subject: [PATCH] feat: :sparkles: allow custom OpenAI base_url --- continuedev/src/continuedev/core/config.py | 11 +++++----- continuedev/src/continuedev/core/sdk.py | 2 +- .../src/continuedev/libs/llm/openai.py | 20 ++++++++++--------- docs/docs/customization.md | 12 +++++++---- extension/package-lock.json | 4 ++-- extension/package.json | 2 +- 6 files changed, 29 insertions(+), 22 deletions(-) diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index cb9c897776..e367e06cc2 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -25,10 +25,11 @@ class OnTracebackSteps(BaseModel): params: Optional[Dict] = {} -class AzureInfo(BaseModel): - endpoint: str - engine: str - api_version: str +class OpenAIServerInfo(BaseModel): + api_base: Optional[str] = None + engine: Optional[str] = None + api_version: Optional[str] = None + api_type: Literal["azure", "openai"] = "openai" class ContinueConfig(BaseModel): @@ -49,7 +50,7 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = [] on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None - azure_openai_info: Optional[AzureInfo] = None + openai_server_info: Optional[OpenAIServerInfo] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index d75aac00f1..9ee9ea0634 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -81,7 +81,7 @@ def __load_openai_model(self, model: str) -> OpenAI: api_key = self.provider_keys["openai"] if api_key == "": return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log) - return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log) + return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log) def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: api_key = self.provider_keys["hf_inference_api"] diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index a0773c1d10..654c7326c1 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -6,27 +6,29 @@ import openai from ..llm import LLM from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top -from ...core.config import AzureInfo +from ...core.config import OpenAIServerInfo class OpenAI(LLM): api_key: str default_model: str - def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): + def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None): self.api_key = api_key self.default_model = default_model self.system_message = system_message - self.azure_info = azure_info + self.openai_server_info = openai_server_info self.write_log = write_log openai.api_key = api_key # Using an Azure OpenAI deployment - if azure_info is not None: - openai.api_type = "azure" - openai.api_base = azure_info.endpoint - openai.api_version = azure_info.api_version + if openai_server_info is not None: + openai.api_type = openai_server_info.api_type + if openai_server_info.api_base is not None: + openai.api_base = openai_server_info.api_base + if openai_server_info.api_version is not None: + openai.api_version = openai_server_info.api_version @cached_property def name(self): @@ -35,8 +37,8 @@ def name(self): @property def default_args(self): args = {**DEFAULT_ARGS, "model": self.default_model} - if self.azure_info is not None: - args["engine"] = self.azure_info.engine + if self.openai_server_info is not None: + args["engine"] = self.openai_server_info.engine return args def count_tokens(self, text: str): diff --git a/docs/docs/customization.md b/docs/docs/customization.md index f383de4838..fa4d110ee3 100644 --- a/docs/docs/customization.md +++ b/docs/docs/customization.md @@ -11,6 +11,7 @@ Change the `default_model` field to any of "gpt-3.5-turbo", "gpt-3.5-turbo-16k", New users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code. Once you are using Continue regularly though, you will need to add an OpenAI API key that has access to GPT-4 by following these steps: + 1. Copy your API key from https://platform.openai.com/account/api-keys 2. Use the cmd+, (Mac) / ctrl+, (Windows) to open your VS Code settings 3. Type "Continue" in the search bar @@ -35,21 +36,24 @@ If by chance the provider has the exact same API interface as OpenAI, the `GGML` ### Azure OpenAI Service -If you'd like to use OpenAI models but are concerned about privacy, you can use the Azure OpenAI service, which is GDPR and HIPAA compliant. After applying for access [here](https://azure.microsoft.com/en-us/products/ai-services/openai-service), you will typically hear back within only a few days. Once you have access, set `default_model` to "gpt-4", and then set the `azure_openai_info` property in the `ContinueConfig` like so: +If you'd like to use OpenAI models but are concerned about privacy, you can use the Azure OpenAI service, which is GDPR and HIPAA compliant. After applying for access [here](https://azure.microsoft.com/en-us/products/ai-services/openai-service), you will typically hear back within only a few days. Once you have access, set `default_model` to "gpt-4", and then set the `openai_server_info` property in the `ContinueConfig` like so: ```python config = ContinueConfig( ... - azure_openai_info=AzureInfo( - endpoint="https://my-azure-openai-instance.openai.azure.com/", + openai_server_info=OpenAIServerInfo( + api_base="https://my-azure-openai-instance.openai.azure.com/", engine="my-azure-openai-deployment", - api_version="2023-03-15-preview" + api_version="2023-03-15-preview", + api_type="azure" ) ) ``` The easiest way to find this information is from the chat playground in the Azure OpenAI portal. Under the "Chat Session" section, click "View Code" to see each of these parameters. Finally, find one of your Azure OpenAI keys and enter it in the VS Code settings under `continue.OPENAI_API_KEY`. +Note that you can also use `OpenAIServerInfo` for uses other than Azure, such as self-hosting a model. + ## Customize System Message You can write your own system message, a set of instructions that will always be top-of-mind for the LLM, by setting the `system_message` property to any string. For example, you might request "Please make all responses as concise as possible and never repeat something you have already explained." diff --git a/extension/package-lock.json b/extension/package-lock.json index 34a8671bfa..088e114aaf 100644 --- a/extension/package-lock.json +++ b/extension/package-lock.json @@ -1,12 +1,12 @@ { "name": "continue", - "version": "0.0.219", + "version": "0.0.220", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "continue", - "version": "0.0.219", + "version": "0.0.220", "license": "Apache-2.0", "dependencies": { "@electron/rebuild": "^3.2.10", diff --git a/extension/package.json b/extension/package.json index ceaff7d98c..903cd6ec29 100644 --- a/extension/package.json +++ b/extension/package.json @@ -14,7 +14,7 @@ "displayName": "Continue", "pricing": "Free", "description": "The open-source coding autopilot", - "version": "0.0.219", + "version": "0.0.220", "publisher": "Continue", "engines": { "vscode": "^1.67.0"