Commit
Merge pull request #87 from init0xyz/main
Fix the incompatibility of Ollama and Groq JSON responses and update the default model selection
rashadphz authored Sep 6, 2024
2 parents 91fb2b3 + b1741f2 commit 883003f
Showing 10 changed files with 25 additions and 19 deletions.
1 change: 1 addition & 0 deletions .env-template
@@ -13,6 +13,7 @@ OLLAMA_API_BASE=http://host.docker.internal:11434

 # Cloud Models (Optional)
 OPENAI_API_KEY=
+OPENAI_API_BASE=
 GROQ_API_KEY=

 AZURE_DEPLOYMENT_NAME=
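Note on the new OPENAI_API_BASE variable: litellm routes OpenAI-style requests through this base URL when it is set, which is what makes self-hosted OpenAI-compatible endpoints usable here. A minimal sketch of that behavior, assuming litellm's documented environment handling (the localhost URL is a placeholder, not part of this PR):

import os
from litellm import completion

# Assumption: any OpenAI-compatible server works here; litellm also
# accepts api_base= as an explicit argument instead of the env var.
os.environ["OPENAI_API_BASE"] = "http://localhost:8080/v1"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

response = completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)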
1 change: 1 addition & 0 deletions docker-compose.dev.yaml
@@ -20,6 +20,7 @@ services:
       - OLLAMA_API_BASE=${OLLAMA_API_BASE:-http://host.docker.internal:11434}

       - OPENAI_API_KEY=${OPENAI_API_KEY}
+      - OPENAI_API_BASE=${OPENAI_API_BASE}
       - GROQ_API_KEY=${GROQ_API_KEY}

       - AZURE_DEPLOYMENT_NAME=${AZURE_DEPLOYMENT_NAME}
13 changes: 7 additions & 6 deletions src/backend/constants.py
@@ -9,10 +9,11 @@
 class ChatModel(str, Enum):
     LLAMA_3_70B = "llama-3-70b"
     GPT_4o = "gpt-4o"
-    GPT_3_5_TURBO = "gpt-3.5-turbo"
+    GPT_4o_mini = "gpt-4o-mini"
     COMMAND_R = "command-r"

     # Local models
-    LOCAL_LLAMA_3 = "llama3"
+    LOCAL_LLAMA_3 = "llama3.1"
     LOCAL_GEMMA = "gemma"
     LOCAL_MISTRAL = "mistral"
     LOCAL_PHI3_14B = "phi3:14b"
@@ -22,10 +23,10 @@ class ChatModel(str, Enum):


 model_mappings: dict[ChatModel, str] = {
-    ChatModel.GPT_3_5_TURBO: "gpt-3.5-turbo",
     ChatModel.GPT_4o: "gpt-4o",
-    ChatModel.LLAMA_3_70B: "groq/llama3-70b-8192",
-    ChatModel.LOCAL_LLAMA_3: "ollama_chat/llama3",
+    ChatModel.GPT_4o_mini: "gpt-4o-mini",
+    ChatModel.LLAMA_3_70B: "groq/llama-3.1-70b-versatile",
+    ChatModel.LOCAL_LLAMA_3: "ollama_chat/llama3.1",
     ChatModel.LOCAL_GEMMA: "ollama_chat/gemma",
     ChatModel.LOCAL_MISTRAL: "ollama_chat/mistral",
     ChatModel.LOCAL_PHI3_14B: "ollama_chat/phi3:14b",
@@ -39,7 +40,7 @@ def get_model_string(model: ChatModel) -> str:
             raise ValueError("CUSTOM_MODEL is not set")
         return custom_model

-    if model in {ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o}:
+    if model in {ChatModel.GPT_4o_mini, ChatModel.GPT_4o}:
         openai_mode = os.environ.get("OPENAI_MODE", "openai")
         if openai_mode == "azure":
             # Currently deployments are named "gpt-35-turbo" and "gpt-4o"
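For reference, model_mappings is what translates a ChatModel value into the provider-prefixed string litellm expects, so the Groq and Ollama renames above change which upstream models are actually called. A quick sketch of the lookup, assuming the enum and dict shown in this diff (the backend.constants module path is a guess from the file layout):

# Hypothetical usage sketch, not part of the diff.
from backend.constants import ChatModel, model_mappings

assert model_mappings[ChatModel.LOCAL_LLAMA_3] == "ollama_chat/llama3.1"
assert model_mappings[ChatModel.LLAMA_3_70B] == "groq/llama-3.1-70b-versatile"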
5 changes: 4 additions & 1 deletion src/backend/llm/base.py
@@ -41,7 +41,10 @@ def __init__(
             raise ValueError(f"Missing keys: {validation['missing_keys']}")

         self.llm = LiteLLM(model=model)
-        self.client = instructor.from_litellm(completion)
+        if 'groq' in model or 'ollama_chat' in model:
+            self.client = instructor.from_litellm(completion, mode=instructor.Mode.MD_JSON)
+        else:
+            self.client = instructor.from_litellm(completion)

     async def astream(self, prompt: str) -> CompletionResponseAsyncGen:
         return await self.llm.astream_complete(prompt)
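The change above is the core of the fix: Groq and Ollama chat models tend to return JSON wrapped in a markdown fence (or with surrounding prose) rather than a clean tool-call payload, which breaks instructor's default parsing. instructor.Mode.MD_JSON instead prompts for, and extracts, a fenced JSON block. A standalone sketch of the same pattern, assuming instructor and litellm are installed; RelatedQueries is an illustrative stand-in for the app's real response schema:

import instructor
from litellm import completion
from pydantic import BaseModel

class RelatedQueries(BaseModel):  # hypothetical schema, not from this repo
    queries: list[str]

# MD_JSON: ask the model for markdown-fenced JSON and parse it into the schema.
client = instructor.from_litellm(completion, mode=instructor.Mode.MD_JSON)

result = client.chat.completions.create(
    model="ollama_chat/llama3.1",
    messages=[{"role": "user", "content": "Suggest three related search queries about Rust."}],
    response_model=RelatedQueries,
)
print(result.queries)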
2 changes: 1 addition & 1 deletion src/backend/schemas.py
@@ -35,7 +35,7 @@ class ChatRequest(BaseModel, plugin_settings=record_all):
     thread_id: int | None = None
     query: str
     history: List[Message] = Field(default_factory=list)
-    model: ChatModel = ChatModel.GPT_3_5_TURBO
+    model: ChatModel = ChatModel.GPT_4o_mini
     pro_search: bool = False

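One practical effect of this default: any ChatRequest that omits the model field now resolves to GPT-4o-mini instead of GPT-3.5-turbo. A sketch, assuming the schema above imports cleanly (module paths are a guess):

from backend.constants import ChatModel
from backend.schemas import ChatRequest

req = ChatRequest(query="what changed in this PR?")
assert req.model == ChatModel.GPT_4o_mini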
2 changes: 1 addition & 1 deletion src/backend/validators.py
@@ -5,7 +5,7 @@


 def validate_model(model: ChatModel):
-    if model in {ChatModel.GPT_3_5_TURBO, ChatModel.GPT_4o}:
+    if model in {ChatModel.GPT_4o_mini, ChatModel.GPT_4o}:
         OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
         if not OPENAI_API_KEY:
             raise ValueError("OPENAI_API_KEY environment variable not found")
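The validator keeps failing fast when a cloud model is selected without its API key; only the set membership changed. A hypothetical check, assuming the function above:

import os
from backend.constants import ChatModel
from backend.validators import validate_model

os.environ.pop("OPENAI_API_KEY", None)  # simulate a missing key
try:
    validate_model(ChatModel.GPT_4o_mini)
except ValueError as err:
    print(err)  # OPENAI_API_KEY environment variable not found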
4 changes: 2 additions & 2 deletions src/frontend/generated/types.gen.ts
@@ -67,8 +67,8 @@ export type ChatMessage = {
 export enum ChatModel {
   LLAMA_3_70B = "llama-3-70b",
   GPT_4O = "gpt-4o",
-  GPT_3_5_TURBO = "gpt-3.5-turbo",
-  LLAMA3 = "llama3",
+  GPT_4O_MINI = "gpt-4o-mini",
+  LLAMA3 = "llama3.1",
   GEMMA = "gemma",
   MISTRAL = "mistral",
   PHI3_14B = "phi3:14b",
10 changes: 5 additions & 5 deletions src/frontend/src/components/model-selection.tsx
@@ -37,10 +37,10 @@ type Model = {
 };

 export const modelMap: Record<ChatModel, Model> = {
-  [ChatModel.GPT_3_5_TURBO]: {
+  [ChatModel.GPT_4O_MINI]: {
     name: "Fast",
-    description: "OpenAI/GPT-3.5-turbo",
-    value: ChatModel.GPT_3_5_TURBO,
+    description: "OpenAI/GPT-4o-mini",
+    value: ChatModel.GPT_4O_MINI,
     smallIcon: <RabbitIcon className="w-4 h-4 text-cyan-500" />,
     icon: <RabbitIcon className="w-5 h-5 text-cyan-500" />,
   },
@@ -60,7 +60,7 @@ export const modelMap: Record<ChatModel, Model> = {
   },
   [ChatModel.LLAMA3]: {
     name: "Llama3",
-    description: "ollama/llama3",
+    description: "ollama/llama3.1",
     value: ChatModel.LLAMA3,
     smallIcon: <WandSparklesIcon className="w-4 h-4 text-purple-500" />,
     icon: <WandSparklesIcon className="w-5 h-5 text-purple-500" />,
@@ -123,7 +123,7 @@ const ModelItem: React.FC<{ model: Model }> = ({ model }) => (

 export function ModelSelection() {
   const { localMode, model, setModel, toggleLocalMode } = useConfigStore();
-  const selectedModel = modelMap[model] ?? modelMap[ChatModel.GPT_3_5_TURBO];
+  const selectedModel = modelMap[model] ?? modelMap[ChatModel.GPT_4O_MINI];

   return (
     <Select
2 changes: 1 addition & 1 deletion src/frontend/src/lib/utils.ts
@@ -14,6 +14,6 @@ export function isCloudModel(model: ChatModel) {
   return [
     ChatModel.LLAMA_3_70B,
     ChatModel.GPT_4O,
-    ChatModel.GPT_3_5_TURBO,
+    ChatModel.GPT_4O_MINI,
   ].includes(model);
 }
4 changes: 2 additions & 2 deletions src/frontend/src/stores/slices/configSlice.ts
@@ -22,7 +22,7 @@ export const createConfigSlice: StateCreator<
   [],
   ConfigStore
 > = (set) => ({
-  model: ChatModel.GPT_3_5_TURBO,
+  model: ChatModel.GPT_4O_MINI,
   localMode: false,
   proMode: false,
   setModel: (model: ChatModel) => set({ model }),
@@ -36,7 +36,7 @@ export const createConfigSlice: StateCreator<
       const newLocalMode = !state.localMode;
       const newModel = newLocalMode
         ? ChatModel.LLAMA3
-        : ChatModel.GPT_3_5_TURBO;
+        : ChatModel.GPT_4O_MINI;
       return { localMode: newLocalMode, model: newModel };
     }),
   toggleProMode: () =>
