Commit
using litellm!
rashadphz committed Jun 14, 2024
1 parent 2b54c24 commit 3602c5a
Showing 6 changed files with 135 additions and 146 deletions.
43 changes: 42 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ groq = "^0.5.0"
slowapi = "^0.1.9"
redis = "^5.0.4"
llama-index-llms-ollama = "^0.1.3"
llama-index-llms-litellm = "^0.1.4"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.7.1"
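For reference, the new llama-index-llms-litellm dependency provides the LiteLLM wrapper class imported in src/backend/llm/base.py below. A minimal sketch of what it enables (not part of this commit; assumes an OpenAI API key is configured):

from llama_index.llms.litellm import LiteLLM

# Any litellm-supported model string works here, e.g. "gpt-3.5-turbo" or "groq/llama3-70b-8192".
llm = LiteLLM(model="gpt-3.5-turbo")
print(llm.complete("Say hello in five words.").text)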
73 changes: 17 additions & 56 deletions src/backend/chat.py
@@ -1,15 +1,10 @@
import asyncio
import os
from typing import AsyncIterator, List

from fastapi import HTTPException
from llama_index.core.llms import LLM
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.groq import Groq
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI

from backend.constants import ChatModel, model_mappings
from backend.constants import get_model_string
from backend.llm.base import BaseLLM, EveryLLM
from backend.prompts import CHAT_PROMPT, HISTORY_QUERY_REPHRASE
from backend.related_queries import generate_related_queries
from backend.schemas import (
@@ -29,59 +24,25 @@
from backend.utils import is_local_model


def rephrase_query_with_history(question: str, history: List[Message], llm: LLM) -> str:
def rephrase_query_with_history(
question: str, history: List[Message], llm: BaseLLM
) -> str:
if not history:
return question

try:
if history:
history_str = "\n".join([f"{msg.role}: {msg.content}" for msg in history])
question = llm.complete(
HISTORY_QUERY_REPHRASE.format(
chat_history=history_str, question=question
)
).text
question = question.replace('"', "")
history_str = "\n".join(f"{msg.role}: {msg.content}" for msg in history)
formatted_query = HISTORY_QUERY_REPHRASE.format(
chat_history=history_str, question=question
)
question = llm.complete(formatted_query).text.replace('"', "")
return question
except Exception:
raise HTTPException(
status_code=500, detail="Model is at capacity. Please try again later."
)


def get_openai_model(model: ChatModel) -> LLM:
openai_mode = os.environ.get("OPENAI_MODE", "openai")
if openai_mode == "azure":
return AzureOpenAI(
deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME"),
api_key=os.environ.get("AZURE_API_KEY"),
azure_endpoint=os.environ.get("AZURE_CHAT_ENDPOINT"),
api_version="2024-04-01-preview",
)
elif openai_mode == "openai":
return OpenAI(model=model_mappings[model])
else:
raise ValueError(f"Unknown model: {model}")


def get_llm(model: ChatModel) -> LLM:
if model == ChatModel.GPT_3_5_TURBO:
return get_openai_model(model)
elif model == ChatModel.GPT_4o:
return OpenAI(model=model_mappings[model])
elif model in [
ChatModel.LOCAL_GEMMA,
ChatModel.LOCAL_LLAMA_3,
ChatModel.LOCAL_MISTRAL,
ChatModel.LOCAL_PHI3_14B,
]:
return Ollama(
base_url=os.environ.get("OLLAMA_HOST", "http://localhost:11434"),
model=model_mappings[model],
)
elif model == ChatModel.LLAMA_3_70B:
return Groq(model=model_mappings[model])
else:
raise ValueError(f"Unknown model: {model}")


def format_context(search_results: List[SearchResult]) -> str:
return "\n\n".join(
[f"Citation {i+1}. {str(result)}" for i, result in enumerate(search_results)]
@@ -90,7 +51,7 @@ def format_context(search_results: List[SearchResult]) -> str:

async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseEvent]:
try:
llm = get_llm(request.model)
llm = EveryLLM(model=get_model_string(request.model))

yield ChatResponseEvent(
event=StreamEvent.BEGIN_STREAM,
@@ -108,7 +69,7 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE
related_queries_task = None
if not is_local_model(request.model):
related_queries_task = asyncio.create_task(
generate_related_queries(query, search_results, request.model)
generate_related_queries(query, search_results, llm)
)

yield ChatResponseEvent(
@@ -125,7 +86,7 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE
)

full_response = ""
response_gen = await llm.astream_complete(fmt_qa_prompt)
response_gen = await llm.astream(fmt_qa_prompt)
async for completion in response_gen:
full_response += completion.delta or ""
yield ChatResponseEvent(
@@ -136,7 +97,7 @@ async def stream_qa_objects(request: ChatRequest) -> AsyncIterator[ChatResponseE
related_queries = await (
related_queries_task
if related_queries_task
else generate_related_queries(query, search_results, request.model)
else generate_related_queries(query, search_results, llm)
)

yield ChatResponseEvent(
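The chat path now builds a single EveryLLM from a litellm model string and streams through its astream method. A minimal sketch of that flow in isolation (not part of the commit; assumes the backend package is importable and an OpenAI key is set):

import asyncio

from backend.constants import ChatModel, get_model_string
from backend.llm.base import EveryLLM


async def main() -> None:
    # Same construction as stream_qa_objects: ChatModel -> litellm string -> EveryLLM.
    llm = EveryLLM(model=get_model_string(ChatModel.GPT_3_5_TURBO))
    response_gen = await llm.astream("Summarize what litellm does in one sentence.")
    async for completion in response_gen:
        print(completion.delta or "", end="", flush=True)


asyncio.run(main())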
36 changes: 21 additions & 15 deletions src/backend/constants.py
@@ -1,14 +1,9 @@
import os
from enum import Enum

GPT4_MODEL = "gpt-4o"
GPT3_MODEL = "gpt-3.5-turbo"
LLAMA_8B_MODEL = "llama3-8b-8192"
LLAMA_70B_MODEL = "llama3-70b-8192"
from dotenv import load_dotenv

LOCAL_LLAMA3_MODEL = "llama3"
LOCAL_GEMMA_MODEL = "gemma:7b"
LOCAL_MISTRAL_MODEL = "mistral"
LOCAL_PHI3_14B = "phi3:14b"
load_dotenv()


class ChatModel(str, Enum):
@@ -24,11 +19,22 @@ class ChatModel(str, Enum):


model_mappings: dict[ChatModel, str] = {
ChatModel.GPT_3_5_TURBO: GPT3_MODEL,
ChatModel.GPT_4o: GPT4_MODEL,
ChatModel.LLAMA_3_70B: LLAMA_70B_MODEL,
ChatModel.LOCAL_LLAMA_3: LOCAL_LLAMA3_MODEL,
ChatModel.LOCAL_GEMMA: LOCAL_GEMMA_MODEL,
ChatModel.LOCAL_MISTRAL: LOCAL_MISTRAL_MODEL,
ChatModel.LOCAL_PHI3_14B: LOCAL_PHI3_14B,
ChatModel.GPT_3_5_TURBO: "gpt-3.5-turbo",
ChatModel.GPT_4o: "gpt-4o",
ChatModel.LLAMA_3_70B: "groq/llama3-70b-8192",
ChatModel.LOCAL_LLAMA_3: "ollama_chat/llama3",
ChatModel.LOCAL_GEMMA: "ollama_chat/gemma",
ChatModel.LOCAL_MISTRAL: "ollama_chat/mistral",
ChatModel.LOCAL_PHI3_14B: "ollama_chat/phi3:14b",
}


def get_model_string(model: ChatModel) -> str:
if model in {ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o}:
openai_mode = os.environ.get("OPENAI_MODE", "openai")
if openai_mode == "azure":
# Currently deployments are named "gpt-35-turbo" and "gpt-4o"
name = model_mappings[model].replace(".", "")
return f"azure/{name}"

return model_mappings[model]
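get_model_string resolves a ChatModel to the provider-prefixed identifier litellm expects, with a special case that rewrites OpenAI models to Azure deployments when OPENAI_MODE=azure. Illustrative output under the default OPENAI_MODE, with values taken from model_mappings above:

from backend.constants import ChatModel, get_model_string

print(get_model_string(ChatModel.GPT_4o))         # "gpt-4o"
print(get_model_string(ChatModel.LLAMA_3_70B))    # "groq/llama3-70b-8192"
print(get_model_string(ChatModel.LOCAL_MISTRAL))  # "ollama_chat/mistral"
# With OPENAI_MODE=azure, GPT_3_5_TURBO would instead resolve to "azure/gpt-35-turbo".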
50 changes: 50 additions & 0 deletions src/backend/llm/base.py
@@ -0,0 +1,50 @@
from abc import ABC, abstractmethod

import instructor
from dotenv import load_dotenv
from instructor.client import T
from litellm import completion
from llama_index.core.base.llms.types import (
CompletionResponse,
CompletionResponseAsyncGen,
)
from llama_index.llms.litellm import LiteLLM

load_dotenv()


class BaseLLM(ABC):
@abstractmethod
async def astream(self, prompt: str) -> CompletionResponseAsyncGen:
pass

@abstractmethod
def complete(self, prompt: str) -> CompletionResponse:
pass

@abstractmethod
def structured_complete(self, response_model: type[T], prompt: str) -> T:
pass


class EveryLLM(BaseLLM):
def __init__(
self,
model: str,
):
self.llm = LiteLLM(model=model)

self.client = instructor.from_litellm(completion)

async def astream(self, prompt: str) -> CompletionResponseAsyncGen:
return await self.llm.astream_complete(prompt)

def complete(self, prompt: str) -> CompletionResponse:
return self.llm.complete(prompt)

def structured_complete(self, response_model: type[T], prompt: str) -> T:
return self.client.chat.completions.create(
model=self.llm.model,
messages=[{"role": "user", "content": prompt}],
response_model=response_model,
)
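EveryLLM routes streaming and plain completions through LlamaIndex's LiteLLM wrapper and structured output through instructor's litellm integration. A small usage sketch (not part of the commit; the Capital model is a hypothetical example and an OpenAI key is assumed):

from pydantic import BaseModel

from backend.llm.base import EveryLLM


class Capital(BaseModel):
    city: str
    country: str


llm = EveryLLM(model="gpt-3.5-turbo")
print(llm.complete("Name the capital of France.").text)
# structured_complete returns a validated pydantic object via instructor.
print(llm.structured_complete(Capital, "Extract the capital mentioned: Paris is the capital of France."))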
78 changes: 4 additions & 74 deletions src/backend/related_queries.py
@@ -1,85 +1,15 @@
import os

import groq
import instructor
import openai
from dotenv import load_dotenv

from backend.constants import ChatModel, model_mappings
from backend.llm.base import BaseLLM
from backend.prompts import RELATED_QUESTION_PROMPT
from backend.schemas import RelatedQueries, SearchResult

load_dotenv()


OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://localhost:11434")


def get_openai_client() -> openai.AsyncOpenAI:
openai_mode = os.environ.get("OPENAI_MODE", "openai")
if openai_mode == "openai":
return openai.AsyncOpenAI()
elif openai_mode == "azure":
return openai.AsyncAzureOpenAI(
azure_deployment=os.environ.get("AZURE_DEPLOYMENT_NAME"),
azure_endpoint=os.environ["AZURE_CHAT_ENDPOINT"],
api_key=os.environ.get("AZURE_API_KEY"),
api_version="2024-04-01-preview",
)
else:
raise ValueError(f"Unknown openai mode: {openai_mode}")


def instructor_client(model: ChatModel) -> instructor.AsyncInstructor:
if model == ChatModel.GPT_3_5_TURBO:
return instructor.from_openai(
get_openai_client(),
)
elif model in [
ChatModel.GPT_3_5_TURBO,
ChatModel.GPT_4o,
]:
return instructor.from_openai(openai.AsyncOpenAI())
elif model in [
ChatModel.LOCAL_GEMMA,
ChatModel.LOCAL_LLAMA_3,
ChatModel.LOCAL_MISTRAL,
ChatModel.LOCAL_PHI3_14B,
]:
return instructor.from_openai(
openai.AsyncOpenAI(
base_url=f"{OLLAMA_HOST}/v1",
api_key="ollama",
),
mode=instructor.Mode.JSON,
)
elif model == ChatModel.LLAMA_3_70B:
return instructor.from_groq(groq.AsyncGroq(), mode=instructor.Mode.JSON) # type: ignore
else:
raise ValueError(f"Unknown model: {model}")


async def generate_related_queries(
query: str, search_results: list[SearchResult], model: ChatModel
query: str, search_results: list[SearchResult], llm: BaseLLM
) -> list[str]:
context = "\n\n".join([f"{str(result)}" for result in search_results])
# Truncate the context to 4000 characters (mainly for smaller models)
context = context[:4000]

client = instructor_client(model)
model_name = model_mappings[model]

print(RELATED_QUESTION_PROMPT.format(query=query, context=context))

related = await client.chat.completions.create(
model=model_name,
response_model=RelatedQueries,
messages=[
{
"role": "user",
"content": RELATED_QUESTION_PROMPT.format(query=query, context=context),
},
],
related = llm.structured_complete(
RelatedQueries, RELATED_QUESTION_PROMPT.format(query=query, context=context)
)

return [query.lower().replace("?", "") for query in related.related_questions]
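With the provider-specific clients gone, related-query generation is a single structured_complete call on whatever BaseLLM the caller passes in. A hypothetical call site (not in the diff; an empty result list is used so no SearchResult fields need to be assumed):

import asyncio

from backend.llm.base import EveryLLM
from backend.related_queries import generate_related_queries


async def main() -> None:
    llm = EveryLLM(model="gpt-3.5-turbo")
    # Passing no search results keeps the context empty; real callers pass search results.
    queries = await generate_related_queries("what is litellm", [], llm)
    print(queries)


asyncio.run(main())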
