Commit
added azure
rashadphz committed Jun 13, 2024
1 parent 5abf3dd commit 2b54c24
Showing 5 changed files with 51 additions and 11 deletions.
20 changes: 19 additions & 1 deletion src/backend/chat.py
@@ -4,6 +4,7 @@
 
 from fastapi import HTTPException
 from llama_index.core.llms import LLM
+from llama_index.llms.azure_openai import AzureOpenAI
 from llama_index.llms.groq import Groq
 from llama_index.llms.ollama import Ollama
 from llama_index.llms.openai import OpenAI
@@ -45,8 +46,25 @@ def rephrase_query_with_history(question: str, history: List[Message], llm: LLM)
     )
 
 
+def get_openai_model(model: ChatModel) -> LLM:
+    openai_mode = os.environ.get("OPENAI_MODE", "openai")
+    if openai_mode == "azure":
+        return AzureOpenAI(
+            deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME"),
+            api_key=os.environ.get("AZURE_API_KEY"),
+            azure_endpoint=os.environ.get("AZURE_CHAT_ENDPOINT"),
+            api_version="2024-04-01-preview",
+        )
+    elif openai_mode == "openai":
+        return OpenAI(model=model_mappings[model])
+    else:
+        raise ValueError(f"Unknown OPENAI_MODE: {openai_mode}")
+
+
 def get_llm(model: ChatModel) -> LLM:
-    if model in [ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o]:
+    if model == ChatModel.GPT_3_5_TURBO:
+        return get_openai_model(model)
+    elif model == ChatModel.GPT_4o:
         return OpenAI(model=model_mappings[model])
     elif model in [
         ChatModel.LOCAL_GEMMA,
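Note: the Azure path is driven entirely by environment variables. A minimal sketch of how a caller would exercise it — the import paths and placeholder values are assumptions, not taken from this diff:

```python
import os

# OPENAI_MODE defaults to "openai"; set it before get_llm is called.
os.environ["OPENAI_MODE"] = "azure"
os.environ["AZURE_DEPLOYMENT_NAME"] = "gpt-35-turbo"  # placeholder deployment name
os.environ["AZURE_API_KEY"] = "<your-azure-key>"  # placeholder
os.environ["AZURE_CHAT_ENDPOINT"] = "https://<resource>.openai.azure.com"  # placeholder

from backend.chat import get_llm  # assumed import path
from backend.schemas import ChatModel  # assumed import path

# With OPENAI_MODE=azure, GPT-3.5 TurbO now routes through get_openai_model and
# returns an AzureOpenAI LLM; GPT-4o still uses the stock OpenAI client.
llm = get_llm(ChatModel.GPT_3_5_TURBO)
```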
11 changes: 6 additions & 5 deletions src/backend/prompts.py
@@ -27,18 +27,19 @@
 RELATED_QUESTION_PROMPT = """\
 Given a question and search result context, generate 3 follow-up questions the user might ask. Use the original question and context.
-There must be EXACTLY 3 questions. Keep the questions concise, and simple. This should return an object with the following fields:
-questions: A list of 3 concise, simple questions
+Instructions:
+- Generate exactly 3 questions.
+- These questions should be concise, and simple.
+- Ensure the follow-up questions are relevant to the original question and context.
 Make sure to match the language of the user's question.
 Original Question: {query}
 <context>
 {context}
 </context>
-Your EXACTLY 3 (three) follow-up questions:
+Output:
+related_questions: A list of EXACTLY three concise, simple follow-up questions
 """
 
 HISTORY_QUERY_REPHRASE = """
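Since the prompt is consumed via str.format (see the print call added in related_queries.py below), a quick sketch of how the placeholders are filled; the sample values are illustrative and the import path is assumed:

```python
from backend.prompts import RELATED_QUESTION_PROMPT  # assumed import path

# {query} and {context} are the only placeholders. The rendered prompt now ends
# with an explicit "related_questions" output hint, matching the renamed field
# on the RelatedQueries schema below.
rendered = RELATED_QUESTION_PROMPT.format(
    query="What is retrieval-augmented generation?",
    context="RAG combines a retriever with a generator...",
)
print(rendered)
```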
27 changes: 24 additions & 3 deletions src/backend/related_queries.py
@@ -15,8 +15,27 @@
 OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
 
 
+def get_openai_client() -> openai.AsyncOpenAI:
+    openai_mode = os.environ.get("OPENAI_MODE", "openai")
+    if openai_mode == "openai":
+        return openai.AsyncOpenAI()
+    elif openai_mode == "azure":
+        return openai.AsyncAzureOpenAI(
+            azure_deployment=os.environ.get("AZURE_DEPLOYMENT_NAME"),
+            azure_endpoint=os.environ["AZURE_CHAT_ENDPOINT"],
+            api_key=os.environ.get("AZURE_API_KEY"),
+            api_version="2024-04-01-preview",
+        )
+    else:
+        raise ValueError(f"Unknown openai mode: {openai_mode}")
+
+
 def instructor_client(model: ChatModel) -> instructor.AsyncInstructor:
-    if model in [
+    if model == ChatModel.GPT_3_5_TURBO:
+        return instructor.from_openai(
+            get_openai_client(),
+        )
+    elif model in [
         ChatModel.GPT_3_5_TURBO,
         ChatModel.GPT_4o,
     ]:
@@ -35,7 +54,7 @@ def instructor_client(model: ChatModel) -> instructor.AsyncInstructor:
             mode=instructor.Mode.JSON,
         )
     elif model == ChatModel.LLAMA_3_70B:
-        return instructor.from_groq(groq.AsyncGroq(), mode=instructor.Mode.JSON)
+        return instructor.from_groq(groq.AsyncGroq(), mode=instructor.Mode.JSON)  # type: ignore
     else:
         raise ValueError(f"Unknown model: {model}")

@@ -50,6 +69,8 @@ async def generate_related_queries(
     client = instructor_client(model)
    model_name = model_mappings[model]
 
+    print(RELATED_QUESTION_PROMPT.format(query=query, context=context))
+
     related = await client.chat.completions.create(
         model=model_name,
         response_model=RelatedQueries,
@@ -61,4 +82,4 @@ async def generate_related_queries(
         ],
     )
 
-    return [query.lower().replace("?", "") for query in related.questions]
+    return [query.lower().replace("?", "") for query in related.related_questions]
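The new get_openai_client plugs into instructor exactly like the stock client. A minimal sketch of the structured-output round trip — the model name, message text, and import paths are assumptions, and under Azure the model argument is typically the deployment name:

```python
import asyncio

import instructor

from backend.related_queries import get_openai_client  # assumed import path
from backend.schemas import RelatedQueries  # assumed import path

async def main() -> None:
    # get_openai_client returns AsyncAzureOpenAI when OPENAI_MODE=azure,
    # otherwise a stock AsyncOpenAI; instructor wraps either transparently.
    client = instructor.from_openai(get_openai_client())
    related = await client.chat.completions.create(
        model="gpt-3.5-turbo",  # placeholder; use your deployment name on Azure
        response_model=RelatedQueries,
        messages=[{"role": "user", "content": "Suggest follow-ups about RAG."}],
    )
    print(related.related_questions)

asyncio.run(main())
```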
2 changes: 1 addition & 1 deletion src/backend/schemas.py
@@ -37,7 +37,7 @@ class ChatRequest(BaseModel, plugin_settings=record_all):
 
 
 class RelatedQueries(BaseModel):
-    questions: List[str] = Field(..., min_length=3, max_length=3)
+    related_questions: List[str] = Field(..., min_length=3, max_length=3)
 
 
 class SearchResult(BaseModel):
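Worth noting: in Pydantic v2 (which this Field syntax implies), min_length/max_length on a list field constrain the item count, so the schema still enforces exactly three questions after the rename. A self-contained check mirroring the class above; the sample values are illustrative:

```python
from typing import List

from pydantic import BaseModel, Field, ValidationError

class RelatedQueries(BaseModel):
    related_questions: List[str] = Field(..., min_length=3, max_length=3)

RelatedQueries(related_questions=["a?", "b?", "c?"])  # validates: exactly 3 items

try:
    RelatedQueries(related_questions=["too", "few"])
except ValidationError as err:
    print(err)  # reports the list should have at least 3 items
```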
2 changes: 1 addition & 1 deletion src/frontend/src/components/model-selection.tsx
@@ -111,7 +111,7 @@ const ModelItem: React.FC<{ model: Model }> = ({ model }) => (
 );
 
 export function ModelSelection() {
-  const { model, setModel, localMode } = useConfigStore();
+  const { localMode, model, setModel } = useConfigStore();
 
   return (
     <Select
