Skip to content

Commit

Permalink
feat(openai): add azure support
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Prosser committed Nov 11, 2024
1 parent 5f686f9 commit edb060b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 2 deletions.
3 changes: 3 additions & 0 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ModelOverrides,
Message,
)
import gptcli
from gptcli.providers.google import GoogleCompletionProvider
from gptcli.providers.llama import LLaMACompletionProvider
from gptcli.providers.openai import OpenAICompletionProvider
Expand Down Expand Up @@ -83,6 +84,8 @@ def get_completion_provider(model: str) -> CompletionProvider:
return CohereCompletionProvider()
elif model.startswith("gemini"):
return GoogleCompletionProvider()
elif gptcli.providers.openai.use_azure:
return OpenAICompletionProvider()
else:
raise ValueError(f"Unknown model: {model}")

Expand Down
4 changes: 4 additions & 0 deletions gptcli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class GptCliConfig:
api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_base_url: Optional[str] = os.environ.get("OPENAI_BASE_URL")
# When using azure open ai, set your assistant's model to your deployment
# name
openai_use_azure: bool = False
openai_azure_api_version: str = "2024-10-21"
anthropic_api_key: Optional[str] = os.environ.get("ANTHROPIC_API_KEY")
google_api_key: Optional[str] = os.environ.get("GOOGLE_API_KEY")
cohere_api_key: Optional[str] = os.environ.get("COHERE_API_KEY")
Expand Down
7 changes: 7 additions & 0 deletions gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,15 @@ def main():
if config.openai_base_url:
openai.base_url = config.openai_base_url

if config.openai_use_azure:
gptcli.providers.openai.use_azure = config.openai_use_azure

if config.openai_azure_api_version:
openai.api_version = config.openai_azure_api_version

if config.api_key:
openai.api_key = config.api_key

elif config.openai_api_key:
openai.api_key = config.openai_api_key

Expand Down
10 changes: 8 additions & 2 deletions gptcli/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Iterator, List, Optional, cast
import openai
from openai import OpenAI
from openai import AzureOpenAI, OpenAI
from openai.types.chat import ChatCompletionMessageParam

from gptcli.completion import (
Expand All @@ -15,10 +15,16 @@
UsageEvent,
)

use_azure: bool = False

class OpenAICompletionProvider(CompletionProvider):
def __init__(self):
self.client = OpenAI(api_key=openai.api_key, base_url=openai.base_url)
if use_azure:
self.client = AzureOpenAI(api_key=openai.api_key, base_url=openai.base_url, api_version=openai.api_version)
else:
self.client = OpenAI(api_key=openai.api_key, base_url=openai.base_url)



def complete(
self, messages: List[Message], args: dict, stream: bool = False
Expand Down

0 comments on commit edb060b

Please sign in to comment.