Merge branch 'feature-gemini'
Fixes #275
enricoros committed Dec 20, 2023
2 parents 33cb2b8 + fdb66da commit 7f21b2a
Showing 67 changed files with 1,316 additions and 565 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,7 +1,7 @@
# BIG-AGI 🧠✨

Welcome to big-AGI 👋, the GPT application for professionals that need function, form,
simplicity, and speed. Powered by the latest models from 7 vendors and
simplicity, and speed. Powered by the latest models from 8 vendors and
open-source model servers, `big-AGI` offers best-in-class Voice and Chat with AI Personas,
visualizations, coding, drawing, calling, and quite more -- all in a polished UX.

2 changes: 1 addition & 1 deletion app/api/llms/stream/route.ts
@@ -1,2 +1,2 @@
export const runtime = 'edge';
export { openaiStreamingRelayHandler as POST } from '~/modules/llms/transports/server/openai/openai.streaming';
export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llm.server.streaming';
2 changes: 1 addition & 1 deletion docs/config-local-localai.md
@@ -30,5 +30,5 @@ For instance with [Use luna-ai-llama2 with docker compose](https://localai.io/ba

> NOTE: LocalAI does not list details about the models. Every model is assumed to be
> capable of chatting, and with a context window of 4096 tokens.
> Please update the [src/modules/llms/transports/server/openai/models.data.ts](../src/modules/llms/transports/server/openai/models.data.ts)
> Please update the [src/modules/llms/transports/server/openai/models.data.ts](../src/modules/llms/server/openai/models.data.ts)
> file with the mapping information between LocalAI model IDs and names/descriptions/tokens, etc.
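For illustration, a mapping entry in that file could look roughly like the sketch below. This is a hypothetical shape only: the interface name and fields (`idPrefix`, `label`, `contextWindow`, `description`) are assumptions for this example, not the actual schema of `models.data.ts`.

```ts
// Hypothetical sketch of a LocalAI model-ID -> metadata mapping entry;
// the real models.data.ts may use different types and field names.
interface LocalAIModelMapping {
  idPrefix: string;       // matches the LocalAI model ID, e.g. 'luna-ai-llama2'
  label: string;          // human-readable name shown in the UI
  contextWindow: number;  // token budget; overrides the 4096-token default assumed above
  description?: string;
}

const LOCALAI_MODEL_MAPPINGS: LocalAIModelMapping[] = [
  { idPrefix: 'luna-ai-llama2', label: 'Luna AI Llama2', contextWindow: 4096, description: 'LLaMA2-based chat model served by LocalAI' },
];
```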
2 changes: 2 additions & 0 deletions docs/environment-variables.md
@@ -24,6 +24,7 @@ AZURE_OPENAI_API_ENDPOINT=
AZURE_OPENAI_API_KEY=
ANTHROPIC_API_KEY=
ANTHROPIC_API_HOST=
GEMINI_API_KEY=
MISTRAL_API_KEY=
OLLAMA_API_HOST=
OPENROUTER_API_KEY=
@@ -80,6 +81,7 @@ requiring the user to enter an API key
| `AZURE_OPENAI_API_KEY` | Azure OpenAI API key, see [config-azure-openai.md](config-azure-openai.md) | Optional, but if set `AZURE_OPENAI_API_ENDPOINT` must also be set |
| `ANTHROPIC_API_KEY` | The API key for Anthropic | Optional |
| `ANTHROPIC_API_HOST` | Changes the backend host for the Anthropic vendor, to enable platforms such as [config-aws-bedrock.md](config-aws-bedrock.md) | Optional |
| `GEMINI_API_KEY` | The API key for Google AI's Gemini | Optional |
| `MISTRAL_API_KEY` | The API key for Mistral | Optional |
| `OLLAMA_API_HOST` | Changes the backend host for the Ollama vendor. See [config-ollama.md](config-ollama.md) | |
| `OPENROUTER_API_KEY` | The API key for OpenRouter | Optional |
5 changes: 2 additions & 3 deletions src/apps/call/CallUI.tsx
@@ -15,8 +15,7 @@ import { useChatLLMDropdown } from '../chat/components/applayout/useLLMDropdown'

import { EXPERIMENTAL_speakTextStream } from '~/modules/elevenlabs/elevenlabs.client';
import { SystemPurposeId, SystemPurposes } from '../../data';
import { VChatMessageIn } from '~/modules/llms/transports/chatGenerate';
import { streamChat } from '~/modules/llms/transports/streamChat';
import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client';
import { useElevenLabsVoiceDropdown } from '~/modules/elevenlabs/useElevenLabsVoiceDropdown';

import { Link } from '~/common/components/Link';
@@ -216,7 +215,7 @@ export function CallUI(props: {
responseAbortController.current = new AbortController();
let finalText = '';
let error: any | null = null;
streamChat(chatLLMId, callPrompt, responseAbortController.current.signal, (updatedMessage: Partial<DMessage>) => {
llmStreamingChatGenerate(chatLLMId, callPrompt, null, null, responseAbortController.current.signal, (updatedMessage: Partial<DMessage>) => {
const text = updatedMessage.text?.trim();
if (text) {
finalText = text;
2 changes: 1 addition & 1 deletion src/apps/call/components/CallMessage.tsx
@@ -3,7 +3,7 @@ import * as React from 'react';
import { Chip, ColorPaletteProp, VariantProp } from '@mui/joy';
import { SxProps } from '@mui/joy/styles/types';

import { VChatMessageIn } from '~/modules/llms/transports/chatGenerate';
import type { VChatMessageIn } from '~/modules/llms/llm.client';


export function CallMessage(props: {
2 changes: 2 additions & 0 deletions src/apps/chat/components/message/ChatMessage.tsx
@@ -167,6 +167,8 @@ function explainErrorInMessage(text: string, isAssistant: boolean, modelId?: str
make sure the usage is under <Link noLinkStyle href='https://platform.openai.com/account/billing/limits' target='_blank'>the limits</Link>.
</>;
}
// else
// errorMessage = <>{text || 'Unknown error'}</>;

return { errorMessage, isAssistantError };
}
4 changes: 2 additions & 2 deletions src/apps/chat/editors/chat-stream.ts
@@ -2,8 +2,8 @@ import { DLLMId } from '~/modules/llms/store-llms';
import { SystemPurposeId } from '../../../data';
import { autoSuggestions } from '~/modules/aifn/autosuggestions/autoSuggestions';
import { autoTitle } from '~/modules/aifn/autotitle/autoTitle';
import { llmStreamingChatGenerate } from '~/modules/llms/llm.client';
import { speakText } from '~/modules/elevenlabs/elevenlabs.client';
import { streamChat } from '~/modules/llms/transports/streamChat';

import { DMessage, useChatStore } from '~/common/state/store-chats';

@@ -63,7 +63,7 @@ async function streamAssistantMessage(
const messages = history.map(({ role, text }) => ({ role, content: text }));

try {
await streamChat(llmId, messages, abortSignal,
await llmStreamingChatGenerate(llmId, messages, null, null, abortSignal,
(updatedMessage: Partial<DMessage>) => {
// update the message in the store (and thus schedule a re-render)
editMessage(updatedMessage);
4 changes: 2 additions & 2 deletions src/apps/personas/useLLMChain.ts
@@ -1,7 +1,7 @@
import * as React from 'react';

import { DLLMId, useModelsStore } from '~/modules/llms/store-llms';
import { callChatGenerate, VChatMessageIn } from '~/modules/llms/transports/chatGenerate';
import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';


export interface LLMChainStep {
@@ -80,7 +80,7 @@ export function useLLMChain(steps: LLMChainStep[], llmId: DLLMId | undefined, ch
_chainAbortController.signal.addEventListener('abort', globalToStepListener);

// LLM call
callChatGenerate(llmId, llmChatInput, chain.overrideResponseTokens)
llmChatGenerateOrThrow(llmId, llmChatInput, null, null, chain.overrideResponseTokens)
.then(({ content }) => {
stepDone = true;
if (!stepAbortController.signal.aborted)
2 changes: 1 addition & 1 deletion src/common/layout/AppLayout.tsx
@@ -3,7 +3,7 @@ import { shallow } from 'zustand/shallow';

import { Box, Container } from '@mui/joy';

import { ModelsModal } from '../../apps/models-modal/ModelsModal';
import { ModelsModal } from '~/modules/llms/models-modal/ModelsModal';
import { SettingsModal } from '../../apps/settings-modal/SettingsModal';
import { ShortcutsModal } from '../../apps/settings-modal/ShortcutsModal';

13 changes: 9 additions & 4 deletions src/modules/aifn/autosuggestions/autoSuggestions.ts
@@ -1,4 +1,4 @@
import { callChatGenerateWithFunctions, VChatFunctionIn } from '~/modules/llms/transports/chatGenerate';
import { llmChatGenerateOrThrow, VChatFunctionIn } from '~/modules/llms/llm.client';
import { useModelsStore } from '~/modules/llms/store-llms';

import { useChatStore } from '~/common/state/store-chats';
@@ -71,7 +71,7 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri

// Follow-up: Question
if (suggestQuestions) {
// callChatGenerateWithFunctions(funcLLMId, [
// llmChatGenerateOrThrow(funcLLMId, [
// { role: 'system', content: systemMessage.text },
// { role: 'user', content: userMessage.text },
// { role: 'assistant', content: assistantMessageText },
@@ -83,15 +83,18 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri

// Follow-up: Auto-Diagrams
if (suggestDiagrams) {
void callChatGenerateWithFunctions(funcLLMId, [
void llmChatGenerateOrThrow(funcLLMId, [
{ role: 'system', content: systemMessage.text },
{ role: 'user', content: userMessage.text },
{ role: 'assistant', content: assistantMessageText },
], [suggestPlantUMLFn], 'draw_plantuml_diagram',
).then(chatResponse => {

if (!('function_arguments' in chatResponse))
return;

// parse the output PlantUML string, if any
const functionArguments = chatResponse?.function_arguments ?? null;
const functionArguments = chatResponse.function_arguments ?? null;
if (functionArguments) {
const { code, type }: { code: string, type: string } = functionArguments as any;
if (code && type) {
@@ -105,6 +108,8 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri
editMessage(conversationId, assistantMessageId, { text: assistantMessageText }, false);
}
}
}).catch(err => {
console.error('autoSuggestions::diagram:', err);
});
}

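The diagram suggestion above exercises the function-calling path of the new client: function definitions and a forced function name are passed before the optional token limit, and the response only carries `function_arguments` when the model actually invoked the function. A minimal sketch of that pattern, using a made-up function definition in place of the repository's `suggestPlantUMLFn`:

```ts
import type { DLLMId } from '~/modules/llms/store-llms';
import { llmChatGenerateOrThrow, VChatFunctionIn } from '~/modules/llms/llm.client';

// Hypothetical function definition, standing in for suggestPlantUMLFn
const drawDiagramFn: VChatFunctionIn = {
  name: 'draw_plantuml_diagram',
  description: 'Renders a diagram that summarizes the conversation',
  parameters: {
    type: 'object',
    properties: {
      type: { type: 'string', description: 'the diagram language, e.g. "plantuml" or "mermaid"' },
      code: { type: 'string', description: 'the diagram source code' },
    },
    required: ['type', 'code'],
  },
};

async function suggestDiagram(funcLLMId: DLLMId, systemText: string, userText: string, assistantText: string) {
  const chatResponse = await llmChatGenerateOrThrow(funcLLMId, [
    { role: 'system', content: systemText },
    { role: 'user', content: userText },
    { role: 'assistant', content: assistantText },
  ], [drawDiagramFn], 'draw_plantuml_diagram');

  // plain-text answers do not carry function_arguments, so guard before reading them
  if (!('function_arguments' in chatResponse) || !chatResponse.function_arguments)
    return null;
  return chatResponse.function_arguments as { code: string, type: string };
}
```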
6 changes: 3 additions & 3 deletions src/modules/aifn/autotitle/autoTitle.ts
@@ -1,4 +1,4 @@
import { callChatGenerate } from '~/modules/llms/transports/chatGenerate';
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
import { useModelsStore } from '~/modules/llms/store-llms';

import { useChatStore } from '~/common/state/store-chats';
@@ -27,7 +27,7 @@ export function autoTitle(conversationId: string) {
});

// LLM
void callChatGenerate(fastLLMId, [
void llmChatGenerateOrThrow(fastLLMId, [
{ role: 'system', content: `You are an AI conversation titles assistant who specializes in creating expressive yet few-words chat titles.` },
{
role: 'user', content:
@@ -39,7 +39,7 @@ export function autoTitle(conversationId: string) {
historyLines.join('\n') +
'```\n',
},
]).then(chatResponse => {
], null, null).then(chatResponse => {

const title = chatResponse?.content
?.trim()
5 changes: 3 additions & 2 deletions src/modules/aifn/digrams/DiagramsModal.tsx
@@ -8,8 +8,9 @@ import ReplayIcon from '@mui/icons-material/Replay';
import StopOutlinedIcon from '@mui/icons-material/StopOutlined';
import TelegramIcon from '@mui/icons-material/Telegram';

import { llmStreamingChatGenerate } from '~/modules/llms/llm.client';

import { ChatMessage } from '../../../apps/chat/components/message/ChatMessage';
import { streamChat } from '~/modules/llms/transports/streamChat';

import { GoodModal } from '~/common/components/GoodModal';
import { InlineError } from '~/common/components/InlineError';
@@ -85,7 +86,7 @@ export function DiagramsModal(props: { config: DiagramConfig, onClose: () => voi
const diagramPrompt = bigDiagramPrompt(diagramType, diagramLanguage, systemMessage.text, subject, customInstruction);

try {
await streamChat(diagramLlm.id, diagramPrompt, stepAbortController.signal,
await llmStreamingChatGenerate(diagramLlm.id, diagramPrompt, null, null, stepAbortController.signal,
(update: Partial<{ text: string, typing: boolean, originLLM: string }>) => {
assistantMessage = { ...assistantMessage, ...update };
setMessage(assistantMessage);
3 changes: 1 addition & 2 deletions src/modules/aifn/digrams/diagrams.data.ts
@@ -1,6 +1,5 @@
import type { VChatMessageIn } from '~/modules/llms/transports/chatGenerate';

import type { FormRadioOption } from '~/common/components/forms/FormRadioControl';
import type { VChatMessageIn } from '~/modules/llms/llm.client';


export type DiagramType = 'auto' | 'mind';
6 changes: 3 additions & 3 deletions src/modules/aifn/imagine/imaginePromptFromText.ts
@@ -1,4 +1,4 @@
import { callChatGenerate } from '~/modules/llms/transports/chatGenerate';
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
import { useModelsStore } from '~/modules/llms/store-llms';


@@ -14,10 +14,10 @@ export async function imaginePromptFromText(messageText: string): Promise<string
const { fastLLMId } = useModelsStore.getState();
if (!fastLLMId) return null;
try {
const chatResponse = await callChatGenerate(fastLLMId, [
const chatResponse = await llmChatGenerateOrThrow(fastLLMId, [
{ role: 'system', content: simpleImagineSystemPrompt },
{ role: 'user', content: 'Write a prompt, based on the following input.\n\n```\n' + messageText.slice(0, 1000) + '\n```\n' },
]);
], null, null);
return chatResponse.content?.trim() ?? null;
} catch (error: any) {
console.error('imaginePromptFromText: fetch request error:', error);
4 changes: 2 additions & 2 deletions src/modules/aifn/react/react.ts
@@ -5,7 +5,7 @@
import { DLLMId } from '~/modules/llms/store-llms';
import { callApiSearchGoogle } from '~/modules/google/search.client';
import { callBrowseFetchPage } from '~/modules/browse/browse.client';
import { callChatGenerate, VChatMessageIn } from '~/modules/llms/transports/chatGenerate';
import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';


// prompt to implement the ReAct paradigm: https://arxiv.org/abs/2210.03629
@@ -128,7 +128,7 @@ export class Agent {
S.messages.push({ role: 'user', content: prompt });
let content: string;
try {
content = (await callChatGenerate(llmId, S.messages, 500)).content;
content = (await llmChatGenerateOrThrow(llmId, S.messages, null, null, 500)).content;
} catch (error: any) {
content = `Error in callChat: ${error}`;
}
6 changes: 3 additions & 3 deletions src/modules/aifn/summarize/summerize.ts
@@ -1,5 +1,5 @@
import { DLLMId, findLLMOrThrow } from '~/modules/llms/store-llms';
import { callChatGenerate } from '~/modules/llms/transports/chatGenerate';
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';


// prompt to be tried when doing recursive summerization.
@@ -80,10 +80,10 @@ async function cleanUpContent(chunk: string, llmId: DLLMId, _ignored_was_targetW
const autoResponseTokensSize = Math.floor(contextTokens * outputTokenShare);

try {
const chatResponse = await callChatGenerate(llmId, [
const chatResponse = await llmChatGenerateOrThrow(llmId, [
{ role: 'system', content: cleanupPrompt },
{ role: 'user', content: chunk },
], autoResponseTokensSize);
], null, null, autoResponseTokensSize);
return chatResponse?.content ?? '';
} catch (error: any) {
return '';
5 changes: 2 additions & 3 deletions src/modules/aifn/useStreamChatText.ts
@@ -1,8 +1,7 @@
import * as React from 'react';

import type { DLLMId } from '~/modules/llms/store-llms';
import type { VChatMessageIn } from '~/modules/llms/transports/chatGenerate';
import { streamChat } from '~/modules/llms/transports/streamChat';
import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client';


export function useStreamChatText() {
@@ -25,7 +24,7 @@ export function useStreamChatText() {

try {
let lastText = '';
await streamChat(llmId, prompt, abortControllerRef.current.signal, (update) => {
await llmStreamingChatGenerate(llmId, prompt, null, null, abortControllerRef.current.signal, (update) => {
if (update.text) {
lastText = update.text;
setPartialText(lastText);
1 change: 1 addition & 0 deletions src/modules/backend/backend.router.ts
@@ -28,6 +28,7 @@ export const backendRouter = createTRPCRouter({
hasImagingProdia: !!env.PRODIA_API_KEY,
hasLlmAnthropic: !!env.ANTHROPIC_API_KEY,
hasLlmAzureOpenAI: !!env.AZURE_OPENAI_API_KEY && !!env.AZURE_OPENAI_API_ENDPOINT,
hasLlmGemini: !!env.GEMINI_API_KEY,
hasLlmMistral: !!env.MISTRAL_API_KEY,
hasLlmOllama: !!env.OLLAMA_API_HOST,
hasLlmOpenAI: !!env.OPENAI_API_KEY || !!env.OPENAI_API_HOST,
2 changes: 2 additions & 0 deletions src/modules/backend/state-backend.ts
@@ -9,6 +9,7 @@ export interface BackendCapabilities {
hasImagingProdia: boolean;
hasLlmAnthropic: boolean;
hasLlmAzureOpenAI: boolean;
hasLlmGemini: boolean;
hasLlmMistral: boolean;
hasLlmOllama: boolean;
hasLlmOpenAI: boolean;
@@ -31,6 +32,7 @@ const useBackendStore = create<BackendStore>()(
hasImagingProdia: false,
hasLlmAnthropic: false,
hasLlmAzureOpenAI: false,
hasLlmGemini: false,
hasLlmMistral: false,
hasLlmOllama: false,
hasLlmOpenAI: false,
74 changes: 74 additions & 0 deletions src/modules/llms/llm.client.ts
@@ -0,0 +1,74 @@
import type { DLLMId } from './store-llms';
import type { OpenAIWire } from './server/openai/openai.wiretypes';
import { findVendorForLlmOrThrow } from './vendors/vendors.registry';


// LLM Client Types
// NOTE: Model List types in '../server/llm.server.types';

export interface VChatMessageIn {
role: 'assistant' | 'system' | 'user'; // | 'function';
content: string;
//name?: string; // when role: 'function'
}

export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef;

export interface VChatMessageOut {
role: 'assistant' | 'system' | 'user';
content: string;
finish_reason: 'stop' | 'length' | null;
}

export interface VChatMessageOrFunctionCallOut extends VChatMessageOut {
function_name: string;
function_arguments: object | null;
}


// LLM Client Functions

export async function llmChatGenerateOrThrow<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
llmId: DLLMId,
messages: VChatMessageIn[],
functions: VChatFunctionIn[] | null, forceFunctionName: string | null,
maxTokens?: number,
): Promise<VChatMessageOut | VChatMessageOrFunctionCallOut> {

// id to DLLM and vendor
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess, TLLMOptions>(llmId);

// FIXME: relax the forced cast
const options = llm.options as TLLMOptions;

// get the access
const partialSourceSetup = llm._source.setup;
const access = vendor.getTransportAccess(partialSourceSetup);

// execute via the vendor
return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens);
}


export async function llmStreamingChatGenerate<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
llmId: DLLMId,
messages: VChatMessageIn[],
functions: VChatFunctionIn[] | null,
forceFunctionName: string | null,
abortSignal: AbortSignal,
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
): Promise<void> {

// id to DLLM and vendor
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess, TLLMOptions>(llmId);

// FIXME: relax the forced cast
const llmOptions = llm.options as TLLMOptions;

// get the access
const partialSourceSetup = llm._source.setup;
const access = vendor.getTransportAccess(partialSourceSetup); // as ChatStreamInputSchema['access'];

// execute via the vendor
return await vendor.streamingChatGenerateOrThrow(access, llmId, llmOptions, messages, functions, forceFunctionName, abortSignal, onUpdate);
}
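The call sites updated across this commit use these two entry points roughly as in the hedged sketch below; the message text and the `console.log` handling are illustrative, not code from the repository:

```ts
import type { DLLMId } from '~/modules/llms/store-llms';
import { llmChatGenerateOrThrow, llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client';

const messages: VChatMessageIn[] = [
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'Summarize the Gemini integration in one sentence.' },
];

// One-shot generation: no function definitions, no forced function name, optional max tokens
async function generateOnce(llmId: DLLMId) {
  const response = await llmChatGenerateOrThrow(llmId, messages, null, null, 256);
  console.log(response.content);
}

// Streaming generation: partial updates arrive through the onUpdate callback until done is true
async function generateStreaming(llmId: DLLMId) {
  const abortController = new AbortController();
  await llmStreamingChatGenerate(llmId, messages, null, null, abortController.signal,
    (update, done) => {
      if (update.text)
        console.log(done ? 'final:' : 'partial:', update.text);
    });
}
```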
