diff --git a/README.md b/README.md index e3cd1cfbd..42d736a85 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # BIG-AGI ๐Ÿง โœจ Welcome to big-AGI ๐Ÿ‘‹, the GPT application for professionals that need function, form, -simplicity, and speed. Powered by the latest models from 7 vendors and +simplicity, and speed. Powered by the latest models from 8 vendors and open-source model servers, `big-AGI` offers best-in-class Voice and Chat with AI Personas, visualizations, coding, drawing, calling, and quite more -- all in a polished UX. diff --git a/app/api/llms/stream/route.ts b/app/api/llms/stream/route.ts index c7430d3bc..fed3b5d47 100644 --- a/app/api/llms/stream/route.ts +++ b/app/api/llms/stream/route.ts @@ -1,2 +1,2 @@ export const runtime = 'edge'; -export { openaiStreamingRelayHandler as POST } from '~/modules/llms/transports/server/openai/openai.streaming'; \ No newline at end of file +export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llm.server.streaming'; \ No newline at end of file diff --git a/docs/config-local-localai.md b/docs/config-local-localai.md index 4cc2bf4fa..43a9d1aa5 100644 --- a/docs/config-local-localai.md +++ b/docs/config-local-localai.md @@ -30,5 +30,5 @@ For instance with [Use luna-ai-llama2 with docker compose](https://localai.io/ba > NOTE: LocalAI does not list details about the mdoels. Every model is assumed to be > capable of chatting, and with a context window of 4096 tokens. -> Please update the [src/modules/llms/transports/server/openai/models.data.ts](../src/modules/llms/transports/server/openai/models.data.ts) +> Please update the [src/modules/llms/transports/server/openai/models.data.ts](../src/modules/llms/server/openai/models.data.ts) > file with the mapping information between LocalAI model IDs and names/descriptions/tokens, etc. diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 5027f4797..b11f0269d 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -24,6 +24,7 @@ AZURE_OPENAI_API_ENDPOINT= AZURE_OPENAI_API_KEY= ANTHROPIC_API_KEY= ANTHROPIC_API_HOST= +GEMINI_API_KEY= MISTRAL_API_KEY= OLLAMA_API_HOST= OPENROUTER_API_KEY= @@ -80,6 +81,7 @@ requiring the user to enter an API key | `AZURE_OPENAI_API_KEY` | Azure OpenAI API key, see [config-azure-openai.md](config-azure-openai.md) | Optional, but if set `AZURE_OPENAI_API_ENDPOINT` must also be set | | `ANTHROPIC_API_KEY` | The API key for Anthropic | Optional | | `ANTHROPIC_API_HOST` | Changes the backend host for the Anthropic vendor, to enable platforms such as [config-aws-bedrock.md](config-aws-bedrock.md) | Optional | +| `GEMINI_API_KEY` | The API key for Google AI's Gemini | Optional | | `MISTRAL_API_KEY` | The API key for Mistral | Optional | | `OLLAMA_API_HOST` | Changes the backend host for the Ollama vendor. 
See [config-ollama.md](config-ollama.md) | | | `OPENROUTER_API_KEY` | The API key for OpenRouter | Optional | diff --git a/src/apps/call/CallUI.tsx b/src/apps/call/CallUI.tsx index 6ff0efde1..1da62a002 100644 --- a/src/apps/call/CallUI.tsx +++ b/src/apps/call/CallUI.tsx @@ -15,8 +15,7 @@ import { useChatLLMDropdown } from '../chat/components/applayout/useLLMDropdown' import { EXPERIMENTAL_speakTextStream } from '~/modules/elevenlabs/elevenlabs.client'; import { SystemPurposeId, SystemPurposes } from '../../data'; -import { VChatMessageIn } from '~/modules/llms/transports/chatGenerate'; -import { streamChat } from '~/modules/llms/transports/streamChat'; +import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client'; import { useElevenLabsVoiceDropdown } from '~/modules/elevenlabs/useElevenLabsVoiceDropdown'; import { Link } from '~/common/components/Link'; @@ -216,7 +215,7 @@ export function CallUI(props: { responseAbortController.current = new AbortController(); let finalText = ''; let error: any | null = null; - streamChat(chatLLMId, callPrompt, responseAbortController.current.signal, (updatedMessage: Partial) => { + llmStreamingChatGenerate(chatLLMId, callPrompt, null, null, responseAbortController.current.signal, (updatedMessage: Partial) => { const text = updatedMessage.text?.trim(); if (text) { finalText = text; diff --git a/src/apps/call/components/CallMessage.tsx b/src/apps/call/components/CallMessage.tsx index ae67ef141..63a3b76cb 100644 --- a/src/apps/call/components/CallMessage.tsx +++ b/src/apps/call/components/CallMessage.tsx @@ -3,7 +3,7 @@ import * as React from 'react'; import { Chip, ColorPaletteProp, VariantProp } from '@mui/joy'; import { SxProps } from '@mui/joy/styles/types'; -import { VChatMessageIn } from '~/modules/llms/transports/chatGenerate'; +import type { VChatMessageIn } from '~/modules/llms/llm.client'; export function CallMessage(props: { diff --git a/src/apps/chat/components/message/ChatMessage.tsx b/src/apps/chat/components/message/ChatMessage.tsx index 2fd6ccbbc..616617f94 100644 --- a/src/apps/chat/components/message/ChatMessage.tsx +++ b/src/apps/chat/components/message/ChatMessage.tsx @@ -167,6 +167,8 @@ function explainErrorInMessage(text: string, isAssistant: boolean, modelId?: str make sure the usage is under the limits. 
; } + // else + // errorMessage = <>{text || 'Unknown error'}; return { errorMessage, isAssistantError }; } diff --git a/src/apps/chat/editors/chat-stream.ts b/src/apps/chat/editors/chat-stream.ts index 090fc05ff..e3c5a5b4e 100644 --- a/src/apps/chat/editors/chat-stream.ts +++ b/src/apps/chat/editors/chat-stream.ts @@ -2,8 +2,8 @@ import { DLLMId } from '~/modules/llms/store-llms'; import { SystemPurposeId } from '../../../data'; import { autoSuggestions } from '~/modules/aifn/autosuggestions/autoSuggestions'; import { autoTitle } from '~/modules/aifn/autotitle/autoTitle'; +import { llmStreamingChatGenerate } from '~/modules/llms/llm.client'; import { speakText } from '~/modules/elevenlabs/elevenlabs.client'; -import { streamChat } from '~/modules/llms/transports/streamChat'; import { DMessage, useChatStore } from '~/common/state/store-chats'; @@ -63,7 +63,7 @@ async function streamAssistantMessage( const messages = history.map(({ role, text }) => ({ role, content: text })); try { - await streamChat(llmId, messages, abortSignal, + await llmStreamingChatGenerate(llmId, messages, null, null, abortSignal, (updatedMessage: Partial) => { // update the message in the store (and thus schedule a re-render) editMessage(updatedMessage); diff --git a/src/apps/personas/useLLMChain.ts b/src/apps/personas/useLLMChain.ts index 1eab05494..66682b977 100644 --- a/src/apps/personas/useLLMChain.ts +++ b/src/apps/personas/useLLMChain.ts @@ -1,7 +1,7 @@ import * as React from 'react'; import { DLLMId, useModelsStore } from '~/modules/llms/store-llms'; -import { callChatGenerate, VChatMessageIn } from '~/modules/llms/transports/chatGenerate'; +import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client'; export interface LLMChainStep { @@ -80,7 +80,7 @@ export function useLLMChain(steps: LLMChainStep[], llmId: DLLMId | undefined, ch _chainAbortController.signal.addEventListener('abort', globalToStepListener); // LLM call - callChatGenerate(llmId, llmChatInput, chain.overrideResponseTokens) + llmChatGenerateOrThrow(llmId, llmChatInput, null, null, chain.overrideResponseTokens) .then(({ content }) => { stepDone = true; if (!stepAbortController.signal.aborted) diff --git a/src/common/layout/AppLayout.tsx b/src/common/layout/AppLayout.tsx index 46e52373b..e38a110c9 100644 --- a/src/common/layout/AppLayout.tsx +++ b/src/common/layout/AppLayout.tsx @@ -3,7 +3,7 @@ import { shallow } from 'zustand/shallow'; import { Box, Container } from '@mui/joy'; -import { ModelsModal } from '../../apps/models-modal/ModelsModal'; +import { ModelsModal } from '~/modules/llms/models-modal/ModelsModal'; import { SettingsModal } from '../../apps/settings-modal/SettingsModal'; import { ShortcutsModal } from '../../apps/settings-modal/ShortcutsModal'; diff --git a/src/modules/aifn/autosuggestions/autoSuggestions.ts b/src/modules/aifn/autosuggestions/autoSuggestions.ts index 28097a0c8..65097a9b4 100644 --- a/src/modules/aifn/autosuggestions/autoSuggestions.ts +++ b/src/modules/aifn/autosuggestions/autoSuggestions.ts @@ -1,4 +1,4 @@ -import { callChatGenerateWithFunctions, VChatFunctionIn } from '~/modules/llms/transports/chatGenerate'; +import { llmChatGenerateOrThrow, VChatFunctionIn } from '~/modules/llms/llm.client'; import { useModelsStore } from '~/modules/llms/store-llms'; import { useChatStore } from '~/common/state/store-chats'; @@ -71,7 +71,7 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri // Follow-up: Question if (suggestQuestions) { - // 
callChatGenerateWithFunctions(funcLLMId, [ + // llmChatGenerateOrThrow(funcLLMId, [ // { role: 'system', content: systemMessage.text }, // { role: 'user', content: userMessage.text }, // { role: 'assistant', content: assistantMessageText }, @@ -83,15 +83,18 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri // Follow-up: Auto-Diagrams if (suggestDiagrams) { - void callChatGenerateWithFunctions(funcLLMId, [ + void llmChatGenerateOrThrow(funcLLMId, [ { role: 'system', content: systemMessage.text }, { role: 'user', content: userMessage.text }, { role: 'assistant', content: assistantMessageText }, ], [suggestPlantUMLFn], 'draw_plantuml_diagram', ).then(chatResponse => { + if (!('function_arguments' in chatResponse)) + return; + // parse the output PlantUML string, if any - const functionArguments = chatResponse?.function_arguments ?? null; + const functionArguments = chatResponse.function_arguments ?? null; if (functionArguments) { const { code, type }: { code: string, type: string } = functionArguments as any; if (code && type) { @@ -105,6 +108,8 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri editMessage(conversationId, assistantMessageId, { text: assistantMessageText }, false); } } + }).catch(err => { + console.error('autoSuggestions::diagram:', err); }); } diff --git a/src/modules/aifn/autotitle/autoTitle.ts b/src/modules/aifn/autotitle/autoTitle.ts index 2e29771fa..4172b6506 100644 --- a/src/modules/aifn/autotitle/autoTitle.ts +++ b/src/modules/aifn/autotitle/autoTitle.ts @@ -1,4 +1,4 @@ -import { callChatGenerate } from '~/modules/llms/transports/chatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client'; import { useModelsStore } from '~/modules/llms/store-llms'; import { useChatStore } from '~/common/state/store-chats'; @@ -27,7 +27,7 @@ export function autoTitle(conversationId: string) { }); // LLM - void callChatGenerate(fastLLMId, [ + void llmChatGenerateOrThrow(fastLLMId, [ { role: 'system', content: `You are an AI conversation titles assistant who specializes in creating expressive yet few-words chat titles.` }, { role: 'user', content: @@ -39,7 +39,7 @@ export function autoTitle(conversationId: string) { historyLines.join('\n') + '```\n', }, - ]).then(chatResponse => { + ], null, null).then(chatResponse => { const title = chatResponse?.content ?.trim() diff --git a/src/modules/aifn/digrams/DiagramsModal.tsx b/src/modules/aifn/digrams/DiagramsModal.tsx index 68429a128..b0f8da1e7 100644 --- a/src/modules/aifn/digrams/DiagramsModal.tsx +++ b/src/modules/aifn/digrams/DiagramsModal.tsx @@ -8,8 +8,9 @@ import ReplayIcon from '@mui/icons-material/Replay'; import StopOutlinedIcon from '@mui/icons-material/StopOutlined'; import TelegramIcon from '@mui/icons-material/Telegram'; +import { llmStreamingChatGenerate } from '~/modules/llms/llm.client'; + import { ChatMessage } from '../../../apps/chat/components/message/ChatMessage'; -import { streamChat } from '~/modules/llms/transports/streamChat'; import { GoodModal } from '~/common/components/GoodModal'; import { InlineError } from '~/common/components/InlineError'; @@ -85,7 +86,7 @@ export function DiagramsModal(props: { config: DiagramConfig, onClose: () => voi const diagramPrompt = bigDiagramPrompt(diagramType, diagramLanguage, systemMessage.text, subject, customInstruction); try { - await streamChat(diagramLlm.id, diagramPrompt, stepAbortController.signal, + await llmStreamingChatGenerate(diagramLlm.id, diagramPrompt, null, null, 
stepAbortController.signal, (update: Partial<{ text: string, typing: boolean, originLLM: string }>) => { assistantMessage = { ...assistantMessage, ...update }; setMessage(assistantMessage); diff --git a/src/modules/aifn/digrams/diagrams.data.ts b/src/modules/aifn/digrams/diagrams.data.ts index 54239118e..d1b424d35 100644 --- a/src/modules/aifn/digrams/diagrams.data.ts +++ b/src/modules/aifn/digrams/diagrams.data.ts @@ -1,6 +1,5 @@ -import type { VChatMessageIn } from '~/modules/llms/transports/chatGenerate'; - import type { FormRadioOption } from '~/common/components/forms/FormRadioControl'; +import type { VChatMessageIn } from '~/modules/llms/llm.client'; export type DiagramType = 'auto' | 'mind'; diff --git a/src/modules/aifn/imagine/imaginePromptFromText.ts b/src/modules/aifn/imagine/imaginePromptFromText.ts index 211ac7abd..c1556630b 100644 --- a/src/modules/aifn/imagine/imaginePromptFromText.ts +++ b/src/modules/aifn/imagine/imaginePromptFromText.ts @@ -1,4 +1,4 @@ -import { callChatGenerate } from '~/modules/llms/transports/chatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client'; import { useModelsStore } from '~/modules/llms/store-llms'; @@ -14,10 +14,10 @@ export async function imaginePromptFromText(messageText: string): Promise { + await llmStreamingChatGenerate(llmId, prompt, null, null, abortControllerRef.current.signal, (update) => { if (update.text) { lastText = update.text; setPartialText(lastText); diff --git a/src/modules/backend/backend.router.ts b/src/modules/backend/backend.router.ts index ad207a75e..fd8e70649 100644 --- a/src/modules/backend/backend.router.ts +++ b/src/modules/backend/backend.router.ts @@ -28,6 +28,7 @@ export const backendRouter = createTRPCRouter({ hasImagingProdia: !!env.PRODIA_API_KEY, hasLlmAnthropic: !!env.ANTHROPIC_API_KEY, hasLlmAzureOpenAI: !!env.AZURE_OPENAI_API_KEY && !!env.AZURE_OPENAI_API_ENDPOINT, + hasLlmGemini: !!env.GEMINI_API_KEY, hasLlmMistral: !!env.MISTRAL_API_KEY, hasLlmOllama: !!env.OLLAMA_API_HOST, hasLlmOpenAI: !!env.OPENAI_API_KEY || !!env.OPENAI_API_HOST, diff --git a/src/modules/backend/state-backend.ts b/src/modules/backend/state-backend.ts index 034269b5f..ebe0025b3 100644 --- a/src/modules/backend/state-backend.ts +++ b/src/modules/backend/state-backend.ts @@ -9,6 +9,7 @@ export interface BackendCapabilities { hasImagingProdia: boolean; hasLlmAnthropic: boolean; hasLlmAzureOpenAI: boolean; + hasLlmGemini: boolean; hasLlmMistral: boolean; hasLlmOllama: boolean; hasLlmOpenAI: boolean; @@ -31,6 +32,7 @@ const useBackendStore = create()( hasImagingProdia: false, hasLlmAnthropic: false, hasLlmAzureOpenAI: false, + hasLlmGemini: false, hasLlmMistral: false, hasLlmOllama: false, hasLlmOpenAI: false, diff --git a/src/modules/llms/llm.client.ts b/src/modules/llms/llm.client.ts new file mode 100644 index 000000000..73957f08c --- /dev/null +++ b/src/modules/llms/llm.client.ts @@ -0,0 +1,74 @@ +import type { DLLMId } from './store-llms'; +import type { OpenAIWire } from './server/openai/openai.wiretypes'; +import { findVendorForLlmOrThrow } from './vendors/vendors.registry'; + + +// LLM Client Types +// NOTE: Model List types in '../server/llm.server.types'; + +export interface VChatMessageIn { + role: 'assistant' | 'system' | 'user'; // | 'function'; + content: string; + //name?: string; // when role: 'function' +} + +export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef; + +export interface VChatMessageOut { + role: 'assistant' | 'system' | 'user'; + content: string; + finish_reason: 
'stop' | 'length' | null; +} + +export interface VChatMessageOrFunctionCallOut extends VChatMessageOut { + function_name: string; + function_arguments: object | null; +} + + +// LLM Client Functions + +export async function llmChatGenerateOrThrow( + llmId: DLLMId, + messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, forceFunctionName: string | null, + maxTokens?: number, +): Promise { + + // id to DLLM and vendor + const { llm, vendor } = findVendorForLlmOrThrow(llmId); + + // FIXME: relax the forced cast + const options = llm.options as TLLMOptions; + + // get the access + const partialSourceSetup = llm._source.setup; + const access = vendor.getTransportAccess(partialSourceSetup); + + // execute via the vendor + return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens); +} + + +export async function llmStreamingChatGenerate( + llmId: DLLMId, + messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, + forceFunctionName: string | null, + abortSignal: AbortSignal, + onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void, +): Promise { + + // id to DLLM and vendor + const { llm, vendor } = findVendorForLlmOrThrow(llmId); + + // FIXME: relax the forced cast + const llmOptions = llm.options as TLLMOptions; + + // get the access + const partialSourceSetup = llm._source.setup; + const access = vendor.getTransportAccess(partialSourceSetup); // as ChatStreamInputSchema['access']; + + // execute via the vendor + return await vendor.streamingChatGenerateOrThrow(access, llmId, llmOptions, messages, functions, forceFunctionName, abortSignal, onUpdate); +} diff --git a/src/apps/models-modal/LLMOptionsModal.tsx b/src/modules/llms/models-modal/LLMOptionsModal.tsx similarity index 94% rename from src/apps/models-modal/LLMOptionsModal.tsx rename to src/modules/llms/models-modal/LLMOptionsModal.tsx index 65b56a11c..10051bff8 100644 --- a/src/apps/models-modal/LLMOptionsModal.tsx +++ b/src/modules/llms/models-modal/LLMOptionsModal.tsx @@ -117,9 +117,9 @@ export function LLMOptionsModal(props: { id: DLLMId }) { setShowDetails(!showDetails)} /> {showDetails && [{llm.id}]: {llm.options.llmRef && `${llm.options.llmRef} ยท `} - {llm.contextTokens && `context tokens: ${llm.contextTokens.toLocaleString()} ยท `} - {llm.maxOutputTokens && `max output tokens: ${llm.maxOutputTokens.toLocaleString()} ยท `} - {llm.created && `created: ${(new Date(llm.created * 1000)).toLocaleString()} ยท `} + {!!llm.contextTokens && `context tokens: ${llm.contextTokens.toLocaleString()} ยท `} + {!!llm.maxOutputTokens && `max output tokens: ${llm.maxOutputTokens.toLocaleString()} ยท `} + {!!llm.created && `created: ${(new Date(llm.created * 1000)).toLocaleString()} ยท `} description: {llm.description} {/*ยท tags: {llm.tags.join(', ')}*/} } diff --git a/src/apps/models-modal/ModelsList.tsx b/src/modules/llms/models-modal/ModelsList.tsx similarity index 94% rename from src/apps/models-modal/ModelsList.tsx rename to src/modules/llms/models-modal/ModelsList.tsx index 0336aeec0..b24d6d3b6 100644 --- a/src/apps/models-modal/ModelsList.tsx +++ b/src/modules/llms/models-modal/ModelsList.tsx @@ -111,7 +111,13 @@ export function ModelsList(props: { pl: { xs: 0, md: 1 }, overflowY: 'auto', }}> - {items} + {items.length > 0 ? items : ( + + + Please configure the service and update the list of models. 
+ + + )} ); } \ No newline at end of file diff --git a/src/apps/models-modal/ModelsModal.tsx b/src/modules/llms/models-modal/ModelsModal.tsx similarity index 98% rename from src/apps/models-modal/ModelsModal.tsx rename to src/modules/llms/models-modal/ModelsModal.tsx index e870914a1..57b7ebbed 100644 --- a/src/apps/models-modal/ModelsModal.tsx +++ b/src/modules/llms/models-modal/ModelsModal.tsx @@ -65,7 +65,7 @@ export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) { title={<>Configure AI Models} startButton={ multiSource ? setShowAllSources(all => !all)} /> : undefined } diff --git a/src/apps/models-modal/ModelsSourceSelector.tsx b/src/modules/llms/models-modal/ModelsSourceSelector.tsx similarity index 96% rename from src/apps/models-modal/ModelsSourceSelector.tsx rename to src/modules/llms/models-modal/ModelsSourceSelector.tsx index ef501a4b7..4d9d1fca9 100644 --- a/src/apps/models-modal/ModelsSourceSelector.tsx +++ b/src/modules/llms/models-modal/ModelsSourceSelector.tsx @@ -5,9 +5,9 @@ import { Avatar, Badge, Box, Button, IconButton, ListItemDecorator, MenuItem, Op import AddIcon from '@mui/icons-material/Add'; import DeleteOutlineIcon from '@mui/icons-material/DeleteOutline'; -import { type DModelSourceId, useModelsStore } from '~/modules/llms/store-llms'; -import { type IModelVendor, type ModelVendorId } from '~/modules/llms/vendors/IModelVendor'; -import { createModelSourceForVendor, findAllVendors, findVendorById } from '~/modules/llms/vendors/vendors.registry'; +import type { IModelVendor } from '~/modules/llms/vendors/IModelVendor'; +import { DModelSourceId, useModelsStore } from '~/modules/llms/store-llms'; +import { createModelSourceForVendor, findAllVendors, findVendorById, ModelVendorId } from '~/modules/llms/vendors/vendors.registry'; import { CloseableMenu } from '~/common/components/CloseableMenu'; import { ConfirmationModal } from '~/common/components/ConfirmationModal'; diff --git a/src/modules/llms/transports/server/anthropic/anthropic.models.ts b/src/modules/llms/server/anthropic/anthropic.models.ts similarity index 93% rename from src/modules/llms/transports/server/anthropic/anthropic.models.ts rename to src/modules/llms/server/anthropic/anthropic.models.ts index eb4e4117a..6bbbfc55c 100644 --- a/src/modules/llms/transports/server/anthropic/anthropic.models.ts +++ b/src/modules/llms/server/anthropic/anthropic.models.ts @@ -1,6 +1,6 @@ -import type { ModelDescriptionSchema } from '../server.schemas'; +import type { ModelDescriptionSchema } from '../llm.server.types'; -import { LLM_IF_OAI_Chat } from '../../../store-llms'; +import { LLM_IF_OAI_Chat } from '../../store-llms'; const roundTime = (date: string) => Math.round(new Date(date).getTime() / 1000); diff --git a/src/modules/llms/transports/server/anthropic/anthropic.router.ts b/src/modules/llms/server/anthropic/anthropic.router.ts similarity index 98% rename from src/modules/llms/transports/server/anthropic/anthropic.router.ts rename to src/modules/llms/server/anthropic/anthropic.router.ts index 4433d2740..2ceb003b7 100644 --- a/src/modules/llms/transports/server/anthropic/anthropic.router.ts +++ b/src/modules/llms/server/anthropic/anthropic.router.ts @@ -6,7 +6,7 @@ import { env } from '~/server/env.mjs'; import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; -import { listModelsOutputSchema } from '../server.schemas'; 
+import { listModelsOutputSchema } from '../llm.server.types'; import { AnthropicWire } from './anthropic.wiretypes'; import { hardcodedAnthropicModels } from './anthropic.models'; diff --git a/src/modules/llms/transports/server/anthropic/anthropic.wiretypes.ts b/src/modules/llms/server/anthropic/anthropic.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/anthropic/anthropic.wiretypes.ts rename to src/modules/llms/server/anthropic/anthropic.wiretypes.ts diff --git a/src/modules/llms/server/gemini/gemini.router.ts b/src/modules/llms/server/gemini/gemini.router.ts new file mode 100644 index 000000000..a2625814d --- /dev/null +++ b/src/modules/llms/server/gemini/gemini.router.ts @@ -0,0 +1,216 @@ +import { z } from 'zod'; +import { TRPCError } from '@trpc/server'; +import { env } from '~/server/env.mjs'; + +import packageJson from '../../../../../package.json'; + +import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server'; +import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; + +import { LLM_IF_OAI_Chat, LLM_IF_OAI_Vision } from '../../store-llms'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types'; + +import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; + +import { GeminiBlockSafetyLevel, geminiBlockSafetyLevelSchema, GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes'; + + +// Default hosts +const DEFAULT_GEMINI_HOST = 'https://generativelanguage.googleapis.com'; + + +// Mappers + +export function geminiAccess(access: GeminiAccessSchema, modelRefId: string | null, apiPath: string): { headers: HeadersInit, url: string } { + + const geminiKey = access.geminiKey || env.GEMINI_API_KEY || ''; + const geminiHost = fixupHost(DEFAULT_GEMINI_HOST, apiPath); + + // update model-dependent paths + if (apiPath.includes('{model=models/*}')) { + if (!modelRefId) + throw new Error(`geminiAccess: modelRefId is required for ${apiPath}`); + apiPath = apiPath.replace('{model=models/*}', modelRefId); + } + + return { + headers: { + 'Content-Type': 'application/json', + 'x-goog-api-client': `big-agi/${packageJson['version'] || '1.0.0'}`, + 'x-goog-api-key': geminiKey, + }, + url: geminiHost + apiPath, + }; +} + +/** + * We specially encode the history to match the Gemini API requirements. + * Gemini does not want 2 consecutive messages from the same role, so we alternate. + * - System messages = [User, Model'Ok'] + * - User and Assistant messages are coalesced into a single message (e.g. 
[User, User, Assistant, Assistant, User] -> [User[2], Assistant[2], User[1]]) + */ +export const geminiGenerateContentTextPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, safety: GeminiBlockSafetyLevel, n: number): GeminiGenerateContentRequest => { + + // convert the history to a Gemini format + const contents: GeminiContentSchema[] = []; + for (const _historyElement of history) { + + const { role: msgRole, content: msgContent } = _historyElement; + + // System message - we treat it as per the example in https://ai.google.dev/tutorials/ai-studio_quickstart#chat_example + if (msgRole === 'system') { + contents.push({ role: 'user', parts: [{ text: msgContent }] }); + contents.push({ role: 'model', parts: [{ text: 'Ok' }] }); + continue; + } + + // User or Assistant message + const nextRole: GeminiContentSchema['role'] = msgRole === 'assistant' ? 'model' : 'user'; + if (contents.length && contents[contents.length - 1].role === nextRole) { + // coalesce with the previous message + contents[contents.length - 1].parts.push({ text: msgContent }); + } else { + // create a new message + contents.push({ role: nextRole, parts: [{ text: msgContent }] }); + } + } + + return { + contents, + generationConfig: { + ...(n >= 2 && { candidateCount: n }), + ...(model.maxTokens && { maxOutputTokens: model.maxTokens }), + temperature: model.temperature, + }, + safetySettings: safety !== 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' ? [ + { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold: safety }, + { category: 'HARM_CATEGORY_HATE_SPEECH', threshold: safety }, + { category: 'HARM_CATEGORY_HARASSMENT', threshold: safety }, + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: safety }, + ] : undefined, + }; +}; + + +async function geminiGET(access: GeminiAccessSchema, modelRefId: string | null, apiPath: string /*, signal?: AbortSignal*/): Promise { + const { headers, url } = geminiAccess(access, modelRefId, apiPath); + return await fetchJsonOrTRPCError(url, 'GET', headers, undefined, 'Gemini'); +} + +async function geminiPOST(access: GeminiAccessSchema, modelRefId: string | null, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise { + const { headers, url } = geminiAccess(access, modelRefId, apiPath); + return await fetchJsonOrTRPCError(url, 'POST', headers, body, 'Gemini'); +} + + +// Input/Output Schemas + +export const geminiAccessSchema = z.object({ + dialect: z.enum(['gemini']), + geminiKey: z.string(), + minSafetyLevel: geminiBlockSafetyLevelSchema, +}); +export type GeminiAccessSchema = z.infer; + + +const accessOnlySchema = z.object({ + access: geminiAccessSchema, +}); + +const chatGenerateInputSchema = z.object({ + access: geminiAccessSchema, + model: openAIModelSchema, history: openAIHistorySchema, + // functions: openAIFunctionsSchema.optional(), forceFunctionName: z.string().optional(), +}); + + +/** + * See https://github.com/google/generative-ai-js/tree/main/packages/main/src for + * the official Google implementation. + */ +export const llmGeminiRouter = createTRPCRouter({ + + /* [Gemini] models.list = /v1beta/models */ + listModels: publicProcedure + .input(accessOnlySchema) + .output(listModelsOutputSchema) + .query(async ({ input }) => { + + // get the models + const wireModels = await geminiGET(input.access, null, geminiModelsListPath); + const detailedModels = geminiModelsListOutputSchema.parse(wireModels).models; + + // NOTE: no need to retrieve info for each of the models (e.g. 
/v1beta/model/gemini-pro)., + // as the List API already all the info on all the models + + // map to our output schema + return { + models: detailedModels.map((geminiModel) => { + const { description, displayName, inputTokenLimit, name, outputTokenLimit, supportedGenerationMethods } = geminiModel; + + const contextWindow = inputTokenLimit + outputTokenLimit; + const hidden = !supportedGenerationMethods.includes('generateContent'); + + const { version, topK, topP, temperature } = geminiModel; + const descriptionLong = description + ` (Version: ${version}, Defaults: temperature=${temperature}, topP=${topP}, topK=${topK}, interfaces=[${supportedGenerationMethods.join(',')}])`; + + // const isGeminiPro = name.includes('gemini-pro'); + const isGeminiProVision = name.includes('gemini-pro-vision'); + + const interfaces: ModelDescriptionSchema['interfaces'] = []; + if (supportedGenerationMethods.includes('generateContent')) { + interfaces.push(LLM_IF_OAI_Chat); + if (isGeminiProVision) + interfaces.push(LLM_IF_OAI_Vision); + } + + return { + id: name, + label: displayName, + // created: ... + // updated: ... + description: descriptionLong, + contextWindow: contextWindow, + maxCompletionTokens: outputTokenLimit, + // pricing: isGeminiPro ? { needs per-character and per-image pricing } : undefined, + // rateLimits: isGeminiPro ? { reqPerMinute: 60 } : undefined, + interfaces: supportedGenerationMethods.includes('generateContent') ? [LLM_IF_OAI_Chat] : [], + hidden, + } satisfies ModelDescriptionSchema; + }), + }; + }), + + + /* [Gemini] models.generateContent = /v1/{model=models/*}:generateContent */ + chatGenerate: publicProcedure + .input(chatGenerateInputSchema) + .output(openAIChatGenerateOutputSchema) + .mutation(async ({ input: { access, history, model } }) => { + + // generate the content + const wireGeneration = await geminiPOST(access, model.id, geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1), geminiModelsGenerateContentPath); + const generation = geminiGeneratedContentResponseSchema.parse(wireGeneration); + + // only use the first result (and there should be only one) + const singleCandidate = generation.candidates?.[0] ?? null; + if (!singleCandidate || !singleCandidate.content?.parts.length) + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `Gemini chat-generation API issue: ${JSON.stringify(wireGeneration)}`, + }); + + if (!('text' in singleCandidate.content.parts[0])) + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `Gemini non-text chat-generation API issue: ${JSON.stringify(wireGeneration)}`, + }); + + return { + role: 'assistant', + content: singleCandidate.content.parts[0].text || '', + finish_reason: singleCandidate.finishReason === 'STOP' ? 
'stop' : null, + }; + }), + +}); diff --git a/src/modules/llms/server/gemini/gemini.wiretypes.ts b/src/modules/llms/server/gemini/gemini.wiretypes.ts new file mode 100644 index 000000000..c7e4f9a3e --- /dev/null +++ b/src/modules/llms/server/gemini/gemini.wiretypes.ts @@ -0,0 +1,188 @@ +import { z } from 'zod'; + +// PATHS + +export const geminiModelsListPath = '/v1beta/models?pageSize=1000'; +export const geminiModelsGenerateContentPath = '/v1beta/{model=models/*}:generateContent'; +// see alt=sse on https://cloud.google.com/apis/docs/system-parameters#definitions +export const geminiModelsStreamGenerateContentPath = '/v1beta/{model=models/*}:streamGenerateContent?alt=sse'; + + +// models.list = /v1beta/models + +export const geminiModelsListOutputSchema = z.object({ + models: z.array(z.object({ + name: z.string(), + version: z.string(), + displayName: z.string(), + description: z.string(), + inputTokenLimit: z.number().int().min(1), + outputTokenLimit: z.number().int().min(1), + supportedGenerationMethods: z.array(z.enum([ + 'countMessageTokens', + 'countTextTokens', + 'countTokens', + 'createTunedTextModel', + 'embedContent', + 'embedText', + 'generateAnswer', + 'generateContent', + 'generateMessage', + 'generateText', + ])), + temperature: z.number().optional(), + topP: z.number().optional(), + topK: z.number().optional(), + })), +}); + + +// /v1/{model=models/*}:generateContent, /v1beta/{model=models/*}:streamGenerateContent + +// Request + +const geminiContentPartSchema = z.union([ + + // TextPart + z.object({ + text: z.string().optional(), + }), + + // InlineDataPart + z.object({ + inlineData: z.object({ + mimeType: z.string(), + data: z.string(), // base64-encoded string + }), + }), + + // A predicted FunctionCall returned from the model + z.object({ + functionCall: z.object({ + name: z.string(), + args: z.record(z.any()), // JSON object format + }), + }), + + // The result output of a FunctionCall + z.object({ + functionResponse: z.object({ + name: z.string(), + response: z.record(z.any()), // JSON object format + }), + }), +]); + +const geminiToolSchema = z.object({ + functionDeclarations: z.array(z.object({ + name: z.string(), + description: z.string(), + parameters: z.record(z.any()).optional(), // Schema object format + })).optional(), +}); + +const geminiHarmCategorySchema = z.enum([ + 'HARM_CATEGORY_UNSPECIFIED', + 'HARM_CATEGORY_DEROGATORY', + 'HARM_CATEGORY_TOXICITY', + 'HARM_CATEGORY_VIOLENCE', + 'HARM_CATEGORY_SEXUAL', + 'HARM_CATEGORY_MEDICAL', + 'HARM_CATEGORY_DANGEROUS', + 'HARM_CATEGORY_HARASSMENT', + 'HARM_CATEGORY_HATE_SPEECH', + 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + 'HARM_CATEGORY_DANGEROUS_CONTENT', +]); + +export const geminiBlockSafetyLevelSchema = z.enum([ + 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + 'BLOCK_LOW_AND_ABOVE', + 'BLOCK_MEDIUM_AND_ABOVE', + 'BLOCK_ONLY_HIGH', + 'BLOCK_NONE', +]); + +export type GeminiBlockSafetyLevel = z.infer; + +const geminiSafetySettingSchema = z.object({ + category: geminiHarmCategorySchema, + threshold: geminiBlockSafetyLevelSchema, +}); + +const geminiGenerationConfigSchema = z.object({ + stopSequences: z.array(z.string()).optional(), + candidateCount: z.number().int().optional(), + maxOutputTokens: z.number().int().optional(), + temperature: z.number().optional(), + topP: z.number().optional(), + topK: z.number().int().optional(), +}); + +const geminiContentSchema = z.object({ + // Must be either 'user' or 'model'. Optional but must be set if there are multiple "Content" objects in the parent array. 
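+  // NOTE: geminiGenerateContentTextPayload (gemini.router.ts) maps OpenAI-style roles onto this field:
+  // 'assistant' -> 'model', 'user' -> 'user', and each 'system' message becomes a 'user' turn followed by a 'model' 'Ok' acknowledgement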
+ role: z.enum(['user', 'model']).optional(), + // Ordered Parts that constitute a single message. Parts may have different MIME types. + parts: z.array(geminiContentPartSchema), +}); + +export type GeminiContentSchema = z.infer; + +export const geminiGenerateContentRequest = z.object({ + contents: z.array(geminiContentSchema), + tools: z.array(geminiToolSchema).optional(), + safetySettings: z.array(geminiSafetySettingSchema).optional(), + generationConfig: geminiGenerationConfigSchema.optional(), +}); + +export type GeminiGenerateContentRequest = z.infer; + + +// Response + +const geminiHarmProbabilitySchema = z.enum([ + 'HARM_PROBABILITY_UNSPECIFIED', + 'NEGLIGIBLE', + 'LOW', + 'MEDIUM', + 'HIGH', +]); + +const geminiSafetyRatingSchema = z.object({ + 'category': geminiHarmCategorySchema, + 'probability': geminiHarmProbabilitySchema, + 'blocked': z.boolean().optional(), +}); + +const geminiFinishReasonSchema = z.enum([ + 'FINISH_REASON_UNSPECIFIED', + 'STOP', + 'MAX_TOKENS', + 'SAFETY', + 'RECITATION', + 'OTHER', +]); + +export const geminiGeneratedContentResponseSchema = z.object({ + // either all requested candidates are returned or no candidates at all + // no candidates are returned only if there was something wrong with the prompt (see promptFeedback) + candidates: z.array(z.object({ + index: z.number(), + content: geminiContentSchema, + finishReason: geminiFinishReasonSchema.optional(), + safetyRatings: z.array(geminiSafetyRatingSchema), + citationMetadata: z.object({ + startIndex: z.number().optional(), + endIndex: z.number().optional(), + uri: z.string().optional(), + license: z.string().optional(), + }).optional(), + tokenCount: z.number().optional(), + // groundingAttributions: z.array(GroundingAttribution).optional(), // This field is populated for GenerateAnswer calls. 
+ })).optional(), + // NOTE: promptFeedback is only send in the first chunk in a streaming response + promptFeedback: z.object({ + blockReason: z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']).optional(), + safetyRatings: z.array(geminiSafetyRatingSchema).optional(), + }).optional(), +}); diff --git a/src/modules/llms/transports/server/openai/openai.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts similarity index 59% rename from src/modules/llms/transports/server/openai/openai.streaming.ts rename to src/modules/llms/server/llm.server.streaming.ts index e8065bcc4..c5a1d4544 100644 --- a/src/modules/llms/transports/server/openai/openai.streaming.ts +++ b/src/modules/llms/server/llm.server.streaming.ts @@ -4,12 +4,30 @@ import { createParser as createEventsourceParser, EventSourceParseCallback, Even import { createEmptyReadableStream, debugGenerateCurlCommand, safeErrorString, SERVER_DEBUG_WIRE, serverFetchOrThrow } from '~/server/wire'; -import type { AnthropicWire } from '../anthropic/anthropic.wiretypes'; -import type { OpenAIWire } from './openai.wiretypes'; -import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from '../ollama/ollama.router'; -import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from '../anthropic/anthropic.router'; -import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai.router'; -import { wireOllamaChunkedOutputSchema } from '../ollama/ollama.wiretypes'; + +// Anthropic server imports +import type { AnthropicWire } from './anthropic/anthropic.wiretypes'; +import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from './anthropic/anthropic.router'; + +// Gemini server imports +import { geminiAccess, geminiAccessSchema, geminiGenerateContentTextPayload } from './gemini/gemini.router'; +import { geminiGeneratedContentResponseSchema, geminiModelsStreamGenerateContentPath } from './gemini/gemini.wiretypes'; + +// Ollama server imports +import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes'; +import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from './ollama/ollama.router'; + +// OpenAI server imports +import type { OpenAIWire } from './openai/openai.wiretypes'; +import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai/openai.router'; + + +/** + * Event stream formats + * - 'sse' is the default format, and is used by all vendors except Ollama + * - 'json-nl' is used by Ollama + */ +type MuxingFormat = 'sse' | 'json-nl'; /** @@ -20,49 +38,58 @@ import { wireOllamaChunkedOutputSchema } from '../ollama/ollama.wiretypes'; * The peculiarity of our parser is the injection of a JSON structure at the beginning of the stream, to * communicate parameters before the text starts flowing to the client. 
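+ * For example (illustrative values): if the upstream model is 'gemini-pro' and its first text token is 'Hello',
+ * the first chunk enqueued downstream is '{"model":"gemini-pro"}Hello', i.e. a JSON-serialized
+ * ChatStreamingFirstOutputPacketSchema prepended to the text, which the client-side stream reader is expected to strip.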
*/ -export type AIStreamParser = (data: string) => { text: string, close: boolean }; - -type EventStreamFormat = 'sse' | 'json-nl'; +type AIStreamParser = (data: string) => { text: string, close: boolean }; -const chatStreamInputSchema = z.object({ - access: z.union([anthropicAccessSchema, ollamaAccessSchema, openAIAccessSchema]), - model: openAIModelSchema, history: openAIHistorySchema, +const chatStreamingInputSchema = z.object({ + access: z.union([anthropicAccessSchema, geminiAccessSchema, ollamaAccessSchema, openAIAccessSchema]), + model: openAIModelSchema, + history: openAIHistorySchema, }); -export type ChatStreamInputSchema = z.infer; +export type ChatStreamingInputSchema = z.infer; -const chatStreamFirstPacketSchema = z.object({ +const chatStreamingFirstOutputPacketSchema = z.object({ model: z.string(), }); -export type ChatStreamFirstPacketSchema = z.infer; +export type ChatStreamingFirstOutputPacketSchema = z.infer; -export async function openaiStreamingRelayHandler(req: NextRequest): Promise { +export async function llmStreamingRelayHandler(req: NextRequest): Promise { // inputs - reuse the tRPC schema - const { access, model, history } = chatStreamInputSchema.parse(await req.json()); + const body = await req.json(); + const { access, model, history } = chatStreamingInputSchema.parse(body); - // begin event streaming from the OpenAI API - let headersUrl: { headers: HeadersInit, url: string } = { headers: {}, url: '' }; + // access/dialect dependent setup: + // - requestAccess: the headers and URL to use for the upstream API call + // - muxingFormat: the format of the event stream (sse or json-nl) + // - vendorStreamParser: the parser to use for the event stream let upstreamResponse: Response; + let requestAccess: { headers: HeadersInit, url: string } = { headers: {}, url: '' }; + let muxingFormat: MuxingFormat = 'sse'; let vendorStreamParser: AIStreamParser; - let eventStreamFormat: EventStreamFormat = 'sse'; try { // prepare the API request data let body: object; switch (access.dialect) { case 'anthropic': - headersUrl = anthropicAccess(access, '/v1/complete'); + requestAccess = anthropicAccess(access, '/v1/complete'); body = anthropicChatCompletionPayload(model, history, true); - vendorStreamParser = createAnthropicStreamParser(); + vendorStreamParser = createStreamParserAnthropic(); + break; + + case 'gemini': + requestAccess = geminiAccess(access, model.id, geminiModelsStreamGenerateContentPath); + body = geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1); + vendorStreamParser = createStreamParserGemini(model.id.replace('models/', '')); break; case 'ollama': - headersUrl = ollamaAccess(access, OLLAMA_PATH_CHAT); + requestAccess = ollamaAccess(access, OLLAMA_PATH_CHAT); body = ollamaChatCompletionPayload(model, history, true); - eventStreamFormat = 'json-nl'; - vendorStreamParser = createOllamaChatCompletionStreamParser(); + muxingFormat = 'json-nl'; + vendorStreamParser = createStreamParserOllama(); break; case 'azure': @@ -71,27 +98,27 @@ export async function openaiStreamingRelayHandler(req: NextRequest): Promise streaming:', debugGenerateCurlCommand('POST', headersUrl.url, headersUrl.headers, body)); + console.log('-> streaming:', debugGenerateCurlCommand('POST', requestAccess.url, requestAccess.headers, body)); // POST to our API route - upstreamResponse = await serverFetchOrThrow(headersUrl.url, 'POST', headersUrl.headers, body); + upstreamResponse = await serverFetchOrThrow(requestAccess.url, 'POST', requestAccess.headers, body); } catch (error: 
any) { const fetchOrVendorError = safeErrorString(error) + (error?.cause ? ' ยท ' + error.cause : ''); // server-side admins message - console.error(`/api/llms/stream: fetch issue:`, access.dialect, fetchOrVendorError, headersUrl?.url); + console.error(`/api/llms/stream: fetch issue:`, access.dialect, fetchOrVendorError, requestAccess?.url); // client-side users visible message return new NextResponse(`[Issue] ${access.dialect}: ${fetchOrVendorError}` - + (process.env.NODE_ENV === 'development' ? ` ยท [URL: ${headersUrl?.url}]` : ''), { status: 500 }); + + (process.env.NODE_ENV === 'development' ? ` ยท [URL: ${requestAccess?.url}]` : ''), { status: 500 }); } /* The following code is heavily inspired by the Vercel AI SDK, but simplified to our needs and in full control. @@ -103,8 +130,12 @@ export async function openaiStreamingRelayHandler(req: NextRequest): Promise { + accumulator += chunk; + if (accumulator.endsWith('\n')) { + for (const jsonString of accumulator.split('\n').filter(line => !!line)) { + const mimicEvent: ParsedEvent = { + type: 'event', + id: undefined, + event: undefined, + data: jsonString, + }; + onParse(mimicEvent); + } + accumulator = ''; + } + }, + + // resets the parser state - not useful with our driving of the parser + reset: (): void => { + console.error('createDemuxerJsonNewline.reset() not implemented'); + }, + }; +} + +/** + * Creates a TransformStream that parses events from an EventSource stream using a custom parser. + * @returns {TransformStream} TransformStream parsing events. + */ +function createEventStreamTransformer(muxingFormat: MuxingFormat, vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream { + const textDecoder = new TextDecoder(); + const textEncoder = new TextEncoder(); + let eventSourceParser: EventSourceParser; + + return new TransformStream({ + start: async (controller): Promise => { + + // only used for debugging + let debugLastMs: number | null = null; + + const onNewEvent = (event: ParsedEvent | ReconnectInterval) => { + if (SERVER_DEBUG_WIRE) { + const nowMs = Date.now(); + const elapsedMs = debugLastMs ? 
nowMs - debugLastMs : 0; + debugLastMs = nowMs; + console.log(`<- SSE (${elapsedMs} ms):`, event); + } + + // ignore 'reconnect-interval' and events with no data + if (event.type !== 'event' || !('data' in event)) + return; + + // event stream termination, close our transformed stream + if (event.data === '[DONE]') { + controller.terminate(); + return; + } + + try { + const { text, close } = vendorTextParser(event.data); + if (text) + controller.enqueue(textEncoder.encode(text)); + if (close) + controller.terminate(); + } catch (error: any) { + if (SERVER_DEBUG_WIRE) + console.log(' - E: parse issue:', event.data, error?.message || error); + controller.enqueue(textEncoder.encode(` **[Stream Issue] ${dialectLabel}: ${safeErrorString(error) || 'Unknown stream parsing error'}**`)); + controller.terminate(); + } + }; + + if (muxingFormat === 'sse') + eventSourceParser = createEventsourceParser(onNewEvent); + else if (muxingFormat === 'json-nl') + eventSourceParser = createDemuxerJsonNewline(onNewEvent); + }, + + // stream=true is set because the data is not guaranteed to be final and un-chunked + transform: (chunk: Uint8Array) => { + eventSourceParser.feed(textDecoder.decode(chunk, { stream: true })); + }, + }); +} + + +/// Stream Parsers + +function createStreamParserAnthropic(): AIStreamParser { let hasBegun = false; return (data: string) => { @@ -128,7 +253,7 @@ function createAnthropicStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun) { hasBegun = true; - const firstPacket: ChatStreamFirstPacketSchema = { model: json.model }; + const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model }; text = JSON.stringify(firstPacket) + text; } @@ -136,7 +261,46 @@ function createAnthropicStreamParser(): AIStreamParser { }; } -function createOllamaChatCompletionStreamParser(): AIStreamParser { +function createStreamParserGemini(modelName: string): AIStreamParser { + let hasBegun = false; + + // this can throw, it's catched upstream + return (data: string) => { + + // parse the JSON chunk + const wireGenerationChunk = JSON.parse(data); + const generationChunk = geminiGeneratedContentResponseSchema.parse(wireGenerationChunk); + + // Prompt Safety Errors: pass through errors from Gemini + if (generationChunk.promptFeedback?.blockReason) { + const { blockReason, safetyRatings } = generationChunk.promptFeedback; + return { text: `[Gemini Prompt Blocked] ${blockReason}: ${JSON.stringify(safetyRatings || 'Unknown Safety Ratings', null, 2)}`, close: true }; + } + + // expect a single completion + const singleCandidate = generationChunk.candidates?.[0] ?? 
null; + if (!singleCandidate || !singleCandidate.content?.parts.length) + throw new Error(`Gemini: expected 1 completion, got ${generationChunk.candidates?.length}`); + + // expect a single part + if (singleCandidate.content.parts.length !== 1 || !('text' in singleCandidate.content.parts[0])) + throw new Error(`Gemini: expected 1 text part, got ${singleCandidate.content.parts.length}`); + + // expect a single text in the part + let text = singleCandidate.content.parts[0].text || ''; + + // hack: prepend the model name to the first packet + if (!hasBegun) { + hasBegun = true; + const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: modelName }; + text = JSON.stringify(firstPacket) + text; + } + + return { text, close: false }; + }; +} + +function createStreamParserOllama(): AIStreamParser { let hasBegun = false; return (data: string) => { @@ -164,7 +328,7 @@ function createOllamaChatCompletionStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun && chunk.model) { hasBegun = true; - const firstPacket: ChatStreamFirstPacketSchema = { model: chunk.model }; + const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: chunk.model }; text = JSON.stringify(firstPacket) + text; } @@ -172,7 +336,7 @@ function createOllamaChatCompletionStreamParser(): AIStreamParser { }; } -function createOpenAIStreamParser(): AIStreamParser { +function createStreamParserOpenAI(): AIStreamParser { let hasBegun = false; let hasWarned = false; @@ -205,7 +369,7 @@ function createOpenAIStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun) { hasBegun = true; - const firstPacket: ChatStreamFirstPacketSchema = { model: json.model }; + const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model }; text = JSON.stringify(firstPacket) + text; } @@ -213,98 +377,4 @@ function createOpenAIStreamParser(): AIStreamParser { const close = !!json.choices[0].finish_reason; return { text, close }; }; -} - - -// Event Stream Transformers - -/** - * Creates a TransformStream that parses events from an EventSource stream using a custom parser. - * @returns {TransformStream} TransformStream parsing events. - */ -function createEventStreamTransformer(vendorTextParser: AIStreamParser, inputFormat: EventStreamFormat, dialectLabel: string): TransformStream { - const textDecoder = new TextDecoder(); - const textEncoder = new TextEncoder(); - let eventSourceParser: EventSourceParser; - - return new TransformStream({ - start: async (controller): Promise => { - - // only used for debugging - let debugLastMs: number | null = null; - - const onNewEvent = (event: ParsedEvent | ReconnectInterval) => { - if (SERVER_DEBUG_WIRE) { - const nowMs = Date.now(); - const elapsedMs = debugLastMs ? 
nowMs - debugLastMs : 0; - debugLastMs = nowMs; - console.log(`<- SSE (${elapsedMs} ms):`, event); - } - - // ignore 'reconnect-interval' and events with no data - if (event.type !== 'event' || !('data' in event)) - return; - - // event stream termination, close our transformed stream - if (event.data === '[DONE]') { - controller.terminate(); - return; - } - - try { - const { text, close } = vendorTextParser(event.data); - if (text) - controller.enqueue(textEncoder.encode(text)); - if (close) - controller.terminate(); - } catch (error: any) { - if (SERVER_DEBUG_WIRE) - console.log(' - E: parse issue:', event.data, error?.message || error); - controller.enqueue(textEncoder.encode(` **[Stream Issue] ${dialectLabel}: ${safeErrorString(error) || 'Unknown stream parsing error'}**`)); - controller.terminate(); - } - }; - - if (inputFormat === 'sse') - eventSourceParser = createEventsourceParser(onNewEvent); - else if (inputFormat === 'json-nl') - eventSourceParser = createJsonNewlineParser(onNewEvent); - }, - - // stream=true is set because the data is not guaranteed to be final and un-chunked - transform: (chunk: Uint8Array) => { - eventSourceParser.feed(textDecoder.decode(chunk, { stream: true })); - }, - }); -} - -/** - * Creates a parser for a 'JSON\n' non-event stream, to be swapped with an EventSource parser. - * Ollama is the only vendor that uses this format. - */ -function createJsonNewlineParser(onParse: EventSourceParseCallback): EventSourceParser { - let accumulator: string = ''; - return { - // feeds a new chunk to the parser - we accumulate in case of partial data, and only execute on full lines - feed: (chunk: string): void => { - accumulator += chunk; - if (accumulator.endsWith('\n')) { - for (const jsonString of accumulator.split('\n').filter(line => !!line)) { - const mimicEvent: ParsedEvent = { - type: 'event', - id: undefined, - event: undefined, - data: jsonString, - }; - onParse(mimicEvent); - } - accumulator = ''; - } - }, - - // resets the parser state - not useful with our driving of the parser - reset: (): void => { - console.error('createJsonNewlineParser.reset() not implemented'); - }, - }; -} +} \ No newline at end of file diff --git a/src/modules/llms/transports/server/server.schemas.ts b/src/modules/llms/server/llm.server.types.ts similarity index 76% rename from src/modules/llms/transports/server/server.schemas.ts rename to src/modules/llms/server/llm.server.types.ts index 4614f4ba3..15575c624 100644 --- a/src/modules/llms/transports/server/server.schemas.ts +++ b/src/modules/llms/server/llm.server.types.ts @@ -1,11 +1,18 @@ import { z } from 'zod'; -import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../store-llms'; +import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../store-llms'; + + +// Model Description: a superset of LLM model descriptors const pricingSchema = z.object({ cpmPrompt: z.number().optional(), // Cost per thousand prompt tokens cpmCompletion: z.number().optional(), // Cost per thousand completion tokens }); +// const rateLimitsSchema = z.object({ +// reqPerMinute: z.number().optional(), +// }); + const modelDescriptionSchema = z.object({ id: z.string(), label: z.string(), @@ -15,9 +22,12 @@ const modelDescriptionSchema = z.object({ contextWindow: z.number(), maxCompletionTokens: z.number().optional(), pricing: pricingSchema.optional(), + // rateLimits: rateLimitsSchema.optional(), interfaces: z.array(z.enum([LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Complete, 
LLM_IF_OAI_Vision])), hidden: z.boolean().optional(), }); + +// this is also used by the Client export type ModelDescriptionSchema = z.infer; export const listModelsOutputSchema = z.object({ diff --git a/src/modules/llms/transports/server/ollama/ollama.models.ts b/src/modules/llms/server/ollama/ollama.models.ts similarity index 100% rename from src/modules/llms/transports/server/ollama/ollama.models.ts rename to src/modules/llms/server/ollama/ollama.models.ts diff --git a/src/modules/llms/transports/server/ollama/ollama.router.ts b/src/modules/llms/server/ollama/ollama.router.ts similarity index 99% rename from src/modules/llms/transports/server/ollama/ollama.router.ts rename to src/modules/llms/server/ollama/ollama.router.ts index 20d89d62f..954e798ad 100644 --- a/src/modules/llms/transports/server/ollama/ollama.router.ts +++ b/src/modules/llms/server/ollama/ollama.router.ts @@ -5,12 +5,12 @@ import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server'; import { env } from '~/server/env.mjs'; import { fetchJsonOrTRPCError, fetchTextOrTRPCError } from '~/server/api/trpc.serverutils'; -import { LLM_IF_OAI_Chat } from '../../../store-llms'; +import { LLM_IF_OAI_Chat } from '../../store-llms'; import { capitalizeFirstLetter } from '~/common/util/textUtils'; import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; -import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types'; import { OLLAMA_BASE_MODELS, OLLAMA_PREV_UPDATE } from './ollama.models'; import { WireOllamaChatCompletionInput, wireOllamaChunkedOutputSchema } from './ollama.wiretypes'; diff --git a/src/modules/llms/transports/server/ollama/ollama.wiretypes.ts b/src/modules/llms/server/ollama/ollama.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/ollama/ollama.wiretypes.ts rename to src/modules/llms/server/ollama/ollama.wiretypes.ts diff --git a/src/modules/llms/transports/server/openai/mistral.wiretypes.ts b/src/modules/llms/server/openai/mistral.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/openai/mistral.wiretypes.ts rename to src/modules/llms/server/openai/mistral.wiretypes.ts diff --git a/src/modules/llms/transports/server/openai/models.data.ts b/src/modules/llms/server/openai/models.data.ts similarity index 99% rename from src/modules/llms/transports/server/openai/models.data.ts rename to src/modules/llms/server/openai/models.data.ts index cc20e574f..28dd5e59f 100644 --- a/src/modules/llms/transports/server/openai/models.data.ts +++ b/src/modules/llms/server/openai/models.data.ts @@ -1,8 +1,8 @@ import { SERVER_DEBUG_WIRE } from '~/server/wire'; -import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../../store-llms'; +import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../store-llms'; -import type { ModelDescriptionSchema } from '../server.schemas'; +import type { ModelDescriptionSchema } from '../llm.server.types'; import { wireMistralModelsListOutputSchema } from './mistral.wiretypes'; diff --git a/src/modules/llms/transports/server/openai/openai.router.ts b/src/modules/llms/server/openai/openai.router.ts similarity index 99% rename from src/modules/llms/transports/server/openai/openai.router.ts rename to src/modules/llms/server/openai/openai.router.ts 
index 7b903d6ee..93c64ee77 100644 --- a/src/modules/llms/transports/server/openai/openai.router.ts +++ b/src/modules/llms/server/openai/openai.router.ts @@ -8,7 +8,7 @@ import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; import { Brand } from '~/common/app.config'; import type { OpenAIWire } from './openai.wiretypes'; -import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types'; import { localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription } from './models.data'; diff --git a/src/modules/llms/transports/server/openai/openai.wiretypes.ts b/src/modules/llms/server/openai/openai.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/openai/openai.wiretypes.ts rename to src/modules/llms/server/openai/openai.wiretypes.ts diff --git a/src/modules/llms/store-llms.ts b/src/modules/llms/store-llms.ts index d7ad30780..c352eccdc 100644 --- a/src/modules/llms/store-llms.ts +++ b/src/modules/llms/store-llms.ts @@ -2,7 +2,7 @@ import { create } from 'zustand'; import { shallow } from 'zustand/shallow'; import { persist } from 'zustand/middleware'; -import type { IModelVendor, ModelVendorId } from './vendors/IModelVendor'; +import type { ModelVendorId } from './vendors/vendors.registry'; import type { SourceSetupOpenRouter } from './vendors/openrouter/openrouter.vendor'; @@ -16,6 +16,7 @@ export interface DLLM { updated?: number | 0; description: string; tags: string[]; // UNUSED for now + // modelcaps: DModelCapability[]; contextTokens: number; maxOutputTokens: number; hidden: boolean; @@ -30,6 +31,17 @@ export interface DLLM { export type DLLMId = string; +// export type DModelCapability = +// | 'input-text' +// | 'input-image-data' +// | 'input-multipart' +// | 'output-text' +// | 'output-function' +// | 'output-image-data' +// | 'if-chat' +// | 'if-fast-chat' +// ; + // Model interfaces (chat, and function calls) - here as a preview, will be used more broadly in the future export const LLM_IF_OAI_Chat = 'oai-chat'; export const LLM_IF_OAI_Vision = 'oai-vision'; @@ -269,32 +281,3 @@ export function useChatLLM() { }, shallow); } -/** - * Source-specific read/write - great time saver - */ -export function useSourceSetup(sourceId: DModelSourceId, vendor: IModelVendor) { - - // invalidates only when the setup changes - const { updateSourceSetup, ...rest } = useModelsStore(state => { - - // find the source (or null) - const source: DModelSource | null = state.sources.find(source => source.id === sourceId) as DModelSource ?? null; - - // (safe) source-derived properties - const sourceSetupValid = (source?.setup && vendor?.validateSetup) ? vendor.validateSetup(source.setup as TSourceSetup) : false; - const sourceLLMs = source ? 
state.llms.filter(llm => llm._source === source) : []; - const access = vendor.getTransportAccess(source?.setup); - - return { - source, - access, - sourceHasLLMs: !!sourceLLMs.length, - sourceSetupValid, - updateSourceSetup: state.updateSourceSetup, - }; - }, shallow); - - // convenience function for this source - const updateSetup = (partialSetup: Partial) => updateSourceSetup(sourceId, partialSetup); - return { ...rest, updateSetup }; -} \ No newline at end of file diff --git a/src/modules/llms/transports/chatGenerate.ts b/src/modules/llms/transports/chatGenerate.ts deleted file mode 100644 index 20ff4ba69..000000000 --- a/src/modules/llms/transports/chatGenerate.ts +++ /dev/null @@ -1,34 +0,0 @@ -import type { DLLMId } from '../store-llms'; -import type { OpenAIWire } from './server/openai/openai.wiretypes'; -import { findVendorForLlmOrThrow } from '../vendors/vendors.registry'; - - -export interface VChatMessageIn { - role: 'assistant' | 'system' | 'user'; // | 'function'; - content: string; - //name?: string; // when role: 'function' -} - -export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef; - -export interface VChatMessageOut { - role: 'assistant' | 'system' | 'user'; - content: string; - finish_reason: 'stop' | 'length' | null; -} - -export interface VChatMessageOrFunctionCallOut extends VChatMessageOut { - function_name: string; - function_arguments: object | null; -} - - -export async function callChatGenerate(llmId: DLLMId, messages: VChatMessageIn[], maxTokens?: number): Promise { - const { llm, vendor } = findVendorForLlmOrThrow(llmId); - return await vendor.callChatGenerate(llm, messages, maxTokens); -} - -export async function callChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], forceFunctionName: string | null, maxTokens?: number): Promise { - const { llm, vendor } = findVendorForLlmOrThrow(llmId); - return await vendor.callChatGenerateWF(llm, messages, functions, forceFunctionName, maxTokens); -} \ No newline at end of file diff --git a/src/modules/llms/vendors/IModelVendor.ts b/src/modules/llms/vendors/IModelVendor.ts index 1dda6abdc..e7cc7fbb0 100644 --- a/src/modules/llms/vendors/IModelVendor.ts +++ b/src/modules/llms/vendors/IModelVendor.ts @@ -1,13 +1,12 @@ import type React from 'react'; +import type { TRPCClientErrorBase } from '@trpc/client'; -import type { DLLM, DModelSourceId } from '../store-llms'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../transports/chatGenerate'; +import type { DLLM, DLLMId, DModelSourceId } from '../store-llms'; +import type { ModelDescriptionSchema } from '../server/llm.server.types'; +import type { ModelVendorId } from './vendors.registry'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '~/modules/llms/llm.client'; -export type ModelVendorId = 'anthropic' | 'azure' | 'localai' | 'mistral' | 'ollama' | 'oobabooga' | 'openai' | 'openrouter'; - -export type ModelVendorRegistryType = Record; - export interface IModelVendor> { readonly id: ModelVendorId; readonly name: string; @@ -30,7 +29,28 @@ export interface IModelVendor): TAccess; - callChatGenerate(llm: TDLLM, messages: VChatMessageIn[], maxTokens?: number): Promise; - - callChatGenerateWF(llm: TDLLM, messages: VChatMessageIn[], functions: null | VChatFunctionIn[], forceFunctionName: null | string, maxTokens?: number): Promise; -} \ No newline at end of file + rpcUpdateModelsQuery: ( + access: 
TAccess, + enabled: boolean, + onSuccess: (data: { models: ModelDescriptionSchema[] }) => void, + ) => { isFetching: boolean, refetch: () => void, isError: boolean, error: TRPCClientErrorBase | null }; + + rpcChatGenerateOrThrow: ( + access: TAccess, + llmOptions: TLLMOptions, + messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, forceFunctionName: string | null, + maxTokens?: number, + ) => Promise; + + streamingChatGenerateOrThrow: ( + access: TAccess, + llmId: DLLMId, + llmOptions: TLLMOptions, + messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, forceFunctionName: string | null, + abortSignal: AbortSignal, + onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void, + ) => Promise; + +} diff --git a/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx b/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx index 29dae70a9..d5e214961 100644 --- a/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx +++ b/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx @@ -7,11 +7,11 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { useToggleableBoolean } from '~/common/util/useToggleableBoolean'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { isValidAnthropicApiKey, ModelVendorAnthropic } from './anthropic.vendor'; @@ -34,14 +34,8 @@ export function AnthropicSourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = anthropicKey ? 
keyValid : (!needsUserKey || !!anthropicHost); // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmAnthropic.listModels.useQuery({ access }, { - enabled: !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorAnthropic, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts index b0654601e..e007d8f9f 100644 --- a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts +++ b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts @@ -1,11 +1,12 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { AnthropicIcon } from '~/common/components/icons/AnthropicIcon'; -import { apiAsync } from '~/common/util/trpc.client'; +import { apiAsync, apiQuery } from '~/common/util/trpc.client'; +import type { AnthropicAccessSchema } from '../../server/anthropic/anthropic.router'; import type { IModelVendor } from '../IModelVendor'; -import type { AnthropicAccessSchema } from '../../transports/server/anthropic/anthropic.router'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { VChatMessageOut } from '../../llm.client'; +import { unifiedStreamingClient } from '../unifiedStreamingClient'; import { LLMOptionsOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; @@ -42,37 +43,42 @@ export const ModelVendorAnthropic: IModelVendor { - return anthropicCallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, /*null, null,*/ maxTokens); + + + // List Models + rpcUpdateModelsQuery: (access, enabled, onSuccess) => { + return apiQuery.llmAnthropic.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); }, - callChatGenerateWF(): Promise { - throw new Error('Anthropic does not support "Functions" yet'); + + // Chat Generate (non-streaming) with Functions + rpcChatGenerateOrThrow: async (access, llmOptions, messages, functions, forceFunctionName, maxTokens) => { + if (functions?.length || forceFunctionName) + throw new Error('Anthropic does not support functions'); + + const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; + try { + return await apiAsync.llmAnthropic.chatGenerate.mutate({ + access, + model: { + id: llmRef!, + temperature: llmTemperature, + maxTokens: maxTokens || llmResponseTokens || 1024, + }, + history: messages, + }) as VChatMessageOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'Anthropic Chat Generate Error'; + console.error(`anthropic.rpcChatGenerateOrThrow: ${errorMessage}`); + throw new Error(errorMessage); + } }, -}; + // Chat Generate (streaming) with Functions + streamingChatGenerateOrThrow: unifiedStreamingClient, -/** - * This function either returns the LLM message, or function calls, or throws a descriptive error string - */ -async function anthropicCallChatGenerate( - access: AnthropicAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], - // functions: VChatFunctionIn[] | null, forceFunctionName: string | null, - maxTokens?: number, -): Promise { - const { 
llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; - try { - return await apiAsync.llmAnthropic.chatGenerate.mutate({ - access, - model: { - id: llmRef!, - temperature: llmTemperature, - maxTokens: maxTokens || llmResponseTokens || 1024, - }, - history: messages, - }) as TOut; - } catch (error: any) { - const errorMessage = error?.message || error?.toString() || 'Anthropic Chat Generate Error'; - console.error(`anthropicCallChatGenerate: ${errorMessage}`); - throw new Error(errorMessage); - } -} \ No newline at end of file +}; diff --git a/src/modules/llms/vendors/azure/AzureSourceSetup.tsx b/src/modules/llms/vendors/azure/AzureSourceSetup.tsx index 7ed3c798c..4de8838d8 100644 --- a/src/modules/llms/vendors/azure/AzureSourceSetup.tsx +++ b/src/modules/llms/vendors/azure/AzureSourceSetup.tsx @@ -5,11 +5,11 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { asValidURL } from '~/common/util/urlUtils'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { isValidAzureApiKey, ModelVendorAzure } from './azure.vendor'; @@ -31,14 +31,8 @@ export function AzureSourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = azureKey ? keyValid : !needsUserKey; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorAzure, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/azure/azure.vendor.ts b/src/modules/llms/vendors/azure/azure.vendor.ts index a7b2b6734..e86bd6c79 100644 --- a/src/modules/llms/vendors/azure/azure.vendor.ts +++ b/src/modules/llms/vendors/azure/azure.vendor.ts @@ -3,10 +3,9 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { AzureIcon } from '~/common/components/icons/AzureIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { AzureSourceSetup } from './AzureSourceSetup'; @@ -58,10 +57,9 @@ export const ModelVendorAzure: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], 
functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, functions, forceFunctionName, maxTokens); - }, + + // OpenAI transport ('azure' dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, + streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow, }; \ No newline at end of file diff --git a/src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx b/src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx new file mode 100644 index 000000000..f5b55a1c2 --- /dev/null +++ b/src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx @@ -0,0 +1,96 @@ +import * as React from 'react'; + +import { FormControl, FormHelperText, Option, Select } from '@mui/joy'; +import HealthAndSafetyIcon from '@mui/icons-material/HealthAndSafety'; + +import { FormInputKey } from '~/common/components/forms/FormInputKey'; +import { FormLabelStart } from '~/common/components/forms/FormLabelStart'; +import { InlineError } from '~/common/components/InlineError'; +import { Link } from '~/common/components/Link'; +import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; + +import type { DModelSourceId } from '../../store-llms'; +import type { GeminiBlockSafetyLevel } from '../../server/gemini/gemini.wiretypes'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; + +import { ModelVendorGemini } from './gemini.vendor'; + + +const GEMINI_API_KEY_LINK = 'https://makersuite.google.com/app/apikey'; + +const SAFETY_OPTIONS: { value: GeminiBlockSafetyLevel, label: string }[] = [ + { value: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', label: 'Default' }, + { value: 'BLOCK_LOW_AND_ABOVE', label: 'Low and above' }, + { value: 'BLOCK_MEDIUM_AND_ABOVE', label: 'Medium and above' }, + { value: 'BLOCK_ONLY_HIGH', label: 'Only high' }, + { value: 'BLOCK_NONE', label: 'None' }, +]; + + +export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) { + + // external state + const { source, sourceSetupValid, access, updateSetup } = + useSourceSetup(props.sourceId, ModelVendorGemini); + + // derived state + const { geminiKey, minSafetyLevel } = access; + + const needsUserKey = !ModelVendorGemini.hasBackendCap?.(); + const shallFetchSucceed = !needsUserKey || (!!geminiKey && sourceSetupValid); + const showKeyError = !!geminiKey && !sourceSetupValid; + + // fetch models + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorGemini, access, shallFetchSucceed, source); + + return <> + + {needsUserKey + ? !geminiKey && request Key + : 'โœ”๏ธ already set in server'} + } + value={geminiKey} onChange={value => updateSetup({ geminiKey: value.trim() })} + required={needsUserKey} isError={showKeyError} + placeholder='...' + /> + + + + + + + + Gemini has + adjustable safety settings on four categories: Harassment, Hate speech, + Sexually explicit, and Dangerous content, in addition to non-adjustable built-in filters. + By default, the model will block content with medium and above probability + of being unsafe. 
+ + + + + {isError && } + + ; +} \ No newline at end of file diff --git a/src/modules/llms/vendors/gemini/gemini.vendor.ts b/src/modules/llms/vendors/gemini/gemini.vendor.ts new file mode 100644 index 000000000..ca7fc6d9c --- /dev/null +++ b/src/modules/llms/vendors/gemini/gemini.vendor.ts @@ -0,0 +1,97 @@ +import GoogleIcon from '@mui/icons-material/Google'; + +import { backendCaps } from '~/modules/backend/state-backend'; + +import { apiAsync, apiQuery } from '~/common/util/trpc.client'; + +import type { GeminiAccessSchema } from '../../server/gemini/gemini.router'; +import type { GeminiBlockSafetyLevel } from '../../server/gemini/gemini.wiretypes'; +import type { IModelVendor } from '../IModelVendor'; +import type { VChatMessageOut } from '../../llm.client'; +import { unifiedStreamingClient } from '../unifiedStreamingClient'; + +import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; + +import { GeminiSourceSetup } from './GeminiSourceSetup'; + + +export interface SourceSetupGemini { + geminiKey: string; + minSafetyLevel: GeminiBlockSafetyLevel; +} + +export interface LLMOptionsGemini { + llmRef: string; + stopSequences: string[]; // up to 5 sequences that will stop generation (optional) + candidateCount: number; // 1...8 number of generated responses to return (optional) + maxOutputTokens: number; // if unset, this will default to outputTokenLimit (optional) + temperature: number; // 0...1 Controls the randomness of the output. (optional) + topP: number; // 0...1 The maximum cumulative probability of tokens to consider when sampling (optional) + topK: number; // 1...100 The maximum number of tokens to consider when sampling (optional) +} + + +export const ModelVendorGemini: IModelVendor = { + id: 'googleai', + name: 'Gemini', + rank: 11, + location: 'cloud', + instanceLimit: 1, + hasBackendCap: () => backendCaps().hasLlmGemini, + + // components + Icon: GoogleIcon, + SourceSetupComponent: GeminiSourceSetup, + LLMOptionsComponent: OpenAILLMOptions, + + // functions + initializeSetup: () => ({ + geminiKey: '', + minSafetyLevel: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + }), + validateSetup: (setup) => { + return setup.geminiKey?.length > 0; + }, + getTransportAccess: (partialSetup): GeminiAccessSchema => ({ + dialect: 'gemini', + geminiKey: partialSetup?.geminiKey || '', + minSafetyLevel: partialSetup?.minSafetyLevel || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + }), + + // List Models + rpcUpdateModelsQuery: (access, enabled, onSuccess) => { + return apiQuery.llmGemini.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); + }, + + // Chat Generate (non-streaming) with Functions + rpcChatGenerateOrThrow: async (access, llmOptions, messages, functions, forceFunctionName, maxTokens) => { + if (functions?.length || forceFunctionName) + throw new Error('Gemini does not support functions'); + + const { llmRef, temperature = 0.5, maxOutputTokens } = llmOptions; + try { + return await apiAsync.llmGemini.chatGenerate.mutate({ + access, + model: { + id: llmRef!, + temperature: temperature, + maxTokens: maxTokens || maxOutputTokens || 1024, + }, + history: messages, + }) as VChatMessageOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'Gemini Chat Generate Error'; + console.error(`gemini.rpcChatGenerateOrThrow: ${errorMessage}`); + throw new Error(errorMessage); + } + }, + + // Chat Generate (streaming) with Functions + streamingChatGenerateOrThrow: unifiedStreamingClient, + 
+}; diff --git a/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx b/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx index 8afdca950..ca1c2ec57 100644 --- a/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx +++ b/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx @@ -7,10 +7,10 @@ import { FormInputKey } from '~/common/components/forms/FormInputKey'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { ModelVendorLocalAI } from './localai.vendor'; @@ -30,14 +30,8 @@ export function LocalAISourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = isValidHost; // fetch models - the OpenAI way - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: false, // !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorLocalAI, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); return <> diff --git a/src/modules/llms/vendors/localai/localai.vendor.ts b/src/modules/llms/vendors/localai/localai.vendor.ts index 7d58c7d42..c43c3eaaa 100644 --- a/src/modules/llms/vendors/localai/localai.vendor.ts +++ b/src/modules/llms/vendors/localai/localai.vendor.ts @@ -1,10 +1,9 @@ import DevicesIcon from '@mui/icons-material/Devices'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { LocalAISourceSetup } from './LocalAISourceSetup'; @@ -38,10 +37,9 @@ export const ModelVendorLocalAI: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, functions, forceFunctionName, maxTokens); - }, -}; \ No newline at end of file + + // OpenAI transport ('localai' dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, + streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow, +}; diff --git a/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx 
b/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx index 8cfa57d46..796a04da1 100644 --- a/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx +++ b/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx @@ -4,10 +4,10 @@ import { FormInputKey } from '~/common/components/forms/FormInputKey'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { ModelVendorMistral } from './mistral.vendor'; @@ -29,14 +29,8 @@ export function MistralSourceSetup(props: { sourceId: DModelSourceId }) { const showKeyError = !!mistralKey && !sourceSetupValid; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorMistral, access, shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/mistral/mistral.vendor.ts b/src/modules/llms/vendors/mistral/mistral.vendor.ts index 5ae500a07..a437b3df1 100644 --- a/src/modules/llms/vendors/mistral/mistral.vendor.ts +++ b/src/modules/llms/vendors/mistral/mistral.vendor.ts @@ -3,10 +3,9 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { MistralIcon } from '~/common/components/icons/MistralIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatMessageIn, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import { LLMOptionsOpenAI, openAICallChatGenerate, SourceSetupOpenAI } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI, SourceSetupOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { MistralSourceSetup } from './MistralSourceSetup'; @@ -48,10 +47,9 @@ export const ModelVendorMistral: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF() { - throw new Error('Mistral does not support "Functions" yet'); - }, + + // OpenAI transport ('mistral' dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, + streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow, }; \ No newline at end of file diff --git a/src/modules/llms/vendors/ollama/OllamaAdministration.tsx b/src/modules/llms/vendors/ollama/OllamaAdministration.tsx index 9d2aebdef..2c0aa4b94 100644 --- a/src/modules/llms/vendors/ollama/OllamaAdministration.tsx +++ b/src/modules/llms/vendors/ollama/OllamaAdministration.tsx @@ -12,7 +12,7 @@ import { Link } from '~/common/components/Link'; 
import { apiQuery } from '~/common/util/trpc.client'; import { settingsGap } from '~/common/app.theme'; -import type { OllamaAccessSchema } from '../../transports/server/ollama/ollama.router'; +import type { OllamaAccessSchema } from '../../server/ollama/ollama.router'; export function OllamaAdministration(props: { access: OllamaAccessSchema, onClose: () => void }) { diff --git a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx index 3d8b2da2d..9fd3a2617 100644 --- a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx +++ b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx @@ -6,13 +6,14 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { asValidURL } from '~/common/util/urlUtils'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; + import { ModelVendorOllama } from './ollama.vendor'; import { OllamaAdministration } from './OllamaAdministration'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; export function OllamaSourceSetup(props: { sourceId: DModelSourceId }) { @@ -32,14 +33,8 @@ export function OllamaSourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = !hostError; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOllama.listModels.useQuery({ access }, { - enabled: false, // !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorOllama, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); return <> diff --git a/src/modules/llms/vendors/ollama/ollama.vendor.ts b/src/modules/llms/vendors/ollama/ollama.vendor.ts index 883f5f680..98b444c99 100644 --- a/src/modules/llms/vendors/ollama/ollama.vendor.ts +++ b/src/modules/llms/vendors/ollama/ollama.vendor.ts @@ -1,13 +1,14 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OllamaIcon } from '~/common/components/icons/OllamaIcon'; -import { apiAsync } from '~/common/util/trpc.client'; +import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { IModelVendor } from '../IModelVendor'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; -import type { OllamaAccessSchema } from '../../transports/server/ollama/ollama.router'; +import type { OllamaAccessSchema } from '../../server/ollama/ollama.router'; +import type { VChatMessageOut } from '../../llm.client'; +import { unifiedStreamingClient } from '../unifiedStreamingClient'; -import { LLMOptionsOpenAI } from '../openai/openai.vendor'; +import type { LLMOptionsOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { OllamaSourceSetup } from './OllamaSourceSetup'; @@ -36,36 +37,41 @@ export const ModelVendorOllama: IModelVendor { - return 
ollamaCallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, maxTokens); + + // List Models + rpcUpdateModelsQuery: (access, enabled, onSuccess) => { + return apiQuery.llmOllama.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); }, - callChatGenerateWF(): Promise { - throw new Error('Ollama does not support "Functions" yet'); + + // Chat Generate (non-streaming) with Functions + rpcChatGenerateOrThrow: async (access, llmOptions, messages, functions, forceFunctionName, maxTokens) => { + if (functions?.length || forceFunctionName) + throw new Error('Ollama does not support functions'); + + const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; + try { + return await apiAsync.llmOllama.chatGenerate.mutate({ + access, + model: { + id: llmRef!, + temperature: llmTemperature, + maxTokens: maxTokens || llmResponseTokens || 1024, + }, + history: messages, + }) as VChatMessageOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'Ollama Chat Generate Error'; + console.error(`ollama.rpcChatGenerateOrThrow: ${errorMessage}`); + throw new Error(errorMessage); + } }, -}; + // Chat Generate (streaming) with Functions + streamingChatGenerateOrThrow: unifiedStreamingClient, -/** - * This function either returns the LLM message, or throws a descriptive error string - */ -async function ollamaCallChatGenerate( - access: OllamaAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], - maxTokens?: number, -): Promise { - const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; - try { - return await apiAsync.llmOllama.chatGenerate.mutate({ - access, - model: { - id: llmRef!, - temperature: llmTemperature, - maxTokens: maxTokens || llmResponseTokens || 1024, - }, - history: messages, - }) as TOut; - } catch (error: any) { - const errorMessage = error?.message || error?.toString() || 'Ollama Chat Generate Error'; - console.error(`ollamaCallChatGenerate: ${errorMessage}`); - throw new Error(errorMessage); - } -} +}; diff --git a/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx b/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx index f9e8ca674..f218c5829 100644 --- a/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx +++ b/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx @@ -6,10 +6,10 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { ModelVendorOoobabooga } from './oobabooga.vendor'; @@ -24,14 +24,8 @@ export function OobaboogaSourceSetup(props: { sourceId: DModelSourceId }) { const { oaiHost } = access; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: false, // !hasModels && !!asValidURL(normSetup.oaiHost), - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model 
=> modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorOoobabooga, access, false /* !hasModels && !!asValidURL(normSetup.oaiHost) */, source); return <> diff --git a/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts b/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts index b72827981..0d447539e 100644 --- a/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts +++ b/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts @@ -1,10 +1,9 @@ import { OobaboogaIcon } from '~/common/components/icons/OobaboogaIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { OobaboogaSourceSetup } from './OobaboogaSourceSetup'; @@ -38,10 +37,9 @@ export const ModelVendorOoobabooga: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, functions, forceFunctionName, maxTokens); - }, + + // OpenAI transport (oobabooga dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, + streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow, }; \ No newline at end of file diff --git a/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx b/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx index 85d3e1ea4..aa174db01 100644 --- a/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx +++ b/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx @@ -9,13 +9,13 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { useToggleableBoolean } from '~/common/util/useToggleableBoolean'; -import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; -import { DLLM, DModelSource, DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; -import { isValidOpenAIApiKey, LLMOptionsOpenAI, ModelVendorOpenAI } from './openai.vendor'; +import { isValidOpenAIApiKey, ModelVendorOpenAI } from './openai.vendor'; // avoid repeating it all over @@ -40,15 +40,8 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = oaiKey ? 
keyValid : !needsUserKey; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); - + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorOpenAI, access, !sourceHasLLMs && shallFetchSucceed, source); return <> @@ -110,30 +103,3 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) { ; } - - -export function modelDescriptionToDLLM(model: ModelDescriptionSchema, source: DModelSource): DLLM { - const maxOutputTokens = model.maxCompletionTokens || Math.round((model.contextWindow || 4096) / 2); - const llmResponseTokens = Math.round(maxOutputTokens / (model.maxCompletionTokens ? 2 : 4)); - return { - id: `${source.id}-${model.id}`, - - label: model.label, - created: model.created || 0, - updated: model.updated || 0, - description: model.description, - tags: [], // ['stream', 'chat'], - contextTokens: model.contextWindow, - maxOutputTokens: maxOutputTokens, - hidden: !!model.hidden, - - sId: source.id, - _source: source, - - options: { - llmRef: model.id, - llmTemperature: 0.5, - llmResponseTokens: llmResponseTokens, - }, - }; -} \ No newline at end of file diff --git a/src/modules/llms/vendors/openai/openai.vendor.ts b/src/modules/llms/vendors/openai/openai.vendor.ts index f7eaeb92b..18f420403 100644 --- a/src/modules/llms/vendors/openai/openai.vendor.ts +++ b/src/modules/llms/vendors/openai/openai.vendor.ts @@ -1,11 +1,12 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OpenAIIcon } from '~/common/components/icons/OpenAIIcon'; -import { apiAsync } from '~/common/util/trpc.client'; +import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; +import type { VChatMessageOrFunctionCallOut } from '../../llm.client'; +import { unifiedStreamingClient } from '../unifiedStreamingClient'; import { OpenAILLMOptions } from './OpenAILLMOptions'; import { OpenAISourceSetup } from './OpenAISourceSetup'; @@ -51,41 +52,40 @@ export const ModelVendorOpenAI: IModelVendor { - const access = this.getTransportAccess(llm._source.setup); - return openAICallChatGenerate(access, llm.options, messages, null, null, maxTokens); + + // List Models + rpcUpdateModelsQuery: (access, enabled, onSuccess) => { + return apiQuery.llmOpenAI.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - const access = this.getTransportAccess(llm._source.setup); - return openAICallChatGenerate(access, llm.options, messages, functions, forceFunctionName, maxTokens); + + // Chat Generate (non-streaming) with Functions + rpcChatGenerateOrThrow: async (access, llmOptions, messages, functions, forceFunctionName, maxTokens) => { + const { llmRef, llmTemperature = 0.5, llmResponseTokens } = 
llmOptions; + try { + return await apiAsync.llmOpenAI.chatGenerateWithFunctions.mutate({ + access, + model: { + id: llmRef!, + temperature: llmTemperature, + maxTokens: maxTokens || llmResponseTokens || 1024, + }, + functions: functions ?? undefined, + forceFunctionName: forceFunctionName ?? undefined, + history: messages, + }) as VChatMessageOrFunctionCallOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'OpenAI Chat Generate Error'; + console.error(`openai.rpcChatGenerateOrThrow: ${errorMessage}`); + throw new Error(errorMessage); + } }, -}; + // Chat Generate (streaming) with Functions + streamingChatGenerateOrThrow: unifiedStreamingClient, -/** - * This function either returns the LLM message, or function calls, or throws a descriptive error string - */ -export async function openAICallChatGenerate( - access: OpenAIAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], - functions: VChatFunctionIn[] | null, forceFunctionName: string | null, - maxTokens?: number, -): Promise { - const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; - try { - return await apiAsync.llmOpenAI.chatGenerateWithFunctions.mutate({ - access, - model: { - id: llmRef!, - temperature: llmTemperature, - maxTokens: maxTokens || llmResponseTokens || 1024, - }, - functions: functions ?? undefined, - forceFunctionName: forceFunctionName ?? undefined, - history: messages, - }) as TOut; - } catch (error: any) { - const errorMessage = error?.message || error?.toString() || 'OpenAI Chat Generate Error'; - console.error(`openAICallChatGenerate: ${errorMessage}`); - throw new Error(errorMessage); - } -} \ No newline at end of file +}; diff --git a/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx b/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx index 470dffb35..9c3df5da5 100644 --- a/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx +++ b/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx @@ -6,11 +6,11 @@ import { FormInputKey } from '~/common/components/forms/FormInputKey'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { getCallbackUrl } from '~/common/app.routes'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { isValidOpenRouterKey, ModelVendorOpenRouter } from './openrouter.vendor'; @@ -30,14 +30,8 @@ export function OpenRouterSourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = oaiKey ? 
keyValid : !needsUserKey; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useLlmUpdateModels(ModelVendorOpenRouter, access, !sourceHasLLMs && shallFetchSucceed, source); const handleOpenRouterLogin = () => { diff --git a/src/modules/llms/vendors/openrouter/openrouter.vendor.ts b/src/modules/llms/vendors/openrouter/openrouter.vendor.ts index 98a0ed156..26064eac9 100644 --- a/src/modules/llms/vendors/openrouter/openrouter.vendor.ts +++ b/src/modules/llms/vendors/openrouter/openrouter.vendor.ts @@ -3,10 +3,9 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OpenRouterIcon } from '~/common/components/icons/OpenRouterIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { OpenRouterSourceSetup } from './OpenRouterSourceSetup'; @@ -59,10 +58,9 @@ export const ModelVendorOpenRouter: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, functions, forceFunctionName, maxTokens); - }, + + // OpenAI transport ('openrouter' dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, + streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow, }; \ No newline at end of file diff --git a/src/modules/llms/transports/streamChat.ts b/src/modules/llms/vendors/unifiedStreamingClient.ts similarity index 72% rename from src/modules/llms/transports/streamChat.ts rename to src/modules/llms/vendors/unifiedStreamingClient.ts index 4b6159752..5359733d9 100644 --- a/src/modules/llms/transports/streamChat.ts +++ b/src/modules/llms/vendors/unifiedStreamingClient.ts @@ -1,11 +1,10 @@ import { apiAsync } from '~/common/util/trpc.client'; -import type { DLLM, DLLMId } from '../store-llms'; -import { findVendorForLlmOrThrow } from '../vendors/vendors.registry'; +import type { ChatStreamingFirstOutputPacketSchema, ChatStreamingInputSchema } from '../server/llm.server.streaming'; +import type { DLLMId } from '../store-llms'; +import type { VChatFunctionIn, VChatMessageIn } from '../llm.client'; -import type { ChatStreamFirstPacketSchema, ChatStreamInputSchema } from './server/openai/openai.streaming'; -import type { OpenAIWire } from './server/openai/openai.wiretypes'; -import type { VChatMessageIn } from './chatGenerate'; +import type { OpenAIWire } from 
'../server/openai/openai.wiretypes'; /** @@ -15,27 +14,14 @@ import type { VChatMessageIn } from './chatGenerate'; * Vendor-specific implementation is on our server backend (API) code. This function tries to be * as generic as possible. * - * @param llmId LLM to use - * @param messages the history of messages to send to the API endpoint - * @param abortSignal used to initiate a client-side abort of the fetch request to the API endpoint - * @param onUpdate callback when a piece of a message (text, model name, typing..) is received + * NOTE: onUpdate is callback when a piece of a message (text, model name, typing..) is received */ -export async function streamChat( +export async function unifiedStreamingClient( + access: ChatStreamingInputSchema['access'], llmId: DLLMId, + llmOptions: TLLMOptions, messages: VChatMessageIn[], - abortSignal: AbortSignal, - onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void, -): Promise { - const { llm, vendor } = findVendorForLlmOrThrow(llmId); - const access = vendor.getTransportAccess(llm._source.setup) as ChatStreamInputSchema['access']; - return await vendorStreamChat(access, llm, messages, abortSignal, onUpdate); -} - - -async function vendorStreamChat( - access: ChatStreamInputSchema['access'], - llm: DLLM, - messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, forceFunctionName: string | null, abortSignal: AbortSignal, onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void, ) { @@ -79,12 +65,12 @@ async function vendorStreamChat( } // model params (llm) - const { llmRef, llmTemperature, llmResponseTokens } = (llm.options as any) || {}; + const { llmRef, llmTemperature, llmResponseTokens } = (llmOptions as any) || {}; if (!llmRef || llmTemperature === undefined || llmResponseTokens === undefined) - throw new Error(`Error in configuration for model ${llm.id}: ${JSON.stringify(llm.options)}`); + throw new Error(`Error in configuration for model ${llmId}: ${JSON.stringify(llmOptions)}`); // prepare the input, similarly to the tRPC openAI.chatGenerate - const input: ChatStreamInputSchema = { + const input: ChatStreamingInputSchema = { access, model: { id: llmRef, @@ -131,7 +117,7 @@ async function vendorStreamChat( incrementalText = incrementalText.substring(endOfJson + 1); parsedFirstPacket = true; try { - const parsed: ChatStreamFirstPacketSchema = JSON.parse(json); + const parsed: ChatStreamingFirstOutputPacketSchema = JSON.parse(json); onUpdate({ originLLM: parsed.model }, false); } catch (e) { // error parsing JSON, ignore diff --git a/src/modules/llms/vendors/useLlmUpdateModels.tsx b/src/modules/llms/vendors/useLlmUpdateModels.tsx new file mode 100644 index 000000000..cc12bb048 --- /dev/null +++ b/src/modules/llms/vendors/useLlmUpdateModels.tsx @@ -0,0 +1,47 @@ +import type { IModelVendor } from './IModelVendor'; +import type { ModelDescriptionSchema } from '../server/llm.server.types'; +import { DLLM, DModelSource, useModelsStore } from '../store-llms'; + + +/** + * Hook that fetches the list of models from the vendor and updates the store, + * while returning the fetch state. 
+ */ +export function useLlmUpdateModels(vendor: IModelVendor, access: TAccess, enabled: boolean, source: DModelSource) { + return vendor.rpcUpdateModelsQuery(access, enabled, data => source && updateModelsFn(data, source)); +} + + +function updateModelsFn(data: { models: ModelDescriptionSchema[] }, source: DModelSource) { + useModelsStore.getState().setLLMs( + data.models.map(model => modelDescriptionToDLLMOpenAIOptions(model, source)), + source.id, + ); +} + +function modelDescriptionToDLLMOpenAIOptions(model: ModelDescriptionSchema, source: DModelSource): DLLM { + const maxOutputTokens = model.maxCompletionTokens || Math.round((model.contextWindow || 4096) / 2); + const llmResponseTokens = Math.round(maxOutputTokens / (model.maxCompletionTokens ? 2 : 4)); + return { + id: `${source.id}-${model.id}`, + + label: model.label, + created: model.created || 0, + updated: model.updated || 0, + description: model.description, + tags: [], // ['stream', 'chat'], + contextTokens: model.contextWindow, + maxOutputTokens: maxOutputTokens, + hidden: !!model.hidden, + + sId: source.id, + _source: source, + + options: { + llmRef: model.id, + // @ts-ignore FIXME: large assumption that this is LLMOptionsOpenAI object + llmTemperature: 0.5, + llmResponseTokens: llmResponseTokens, + }, + }; +} \ No newline at end of file diff --git a/src/modules/llms/vendors/useSourceSetup.ts b/src/modules/llms/vendors/useSourceSetup.ts new file mode 100644 index 000000000..4395ac458 --- /dev/null +++ b/src/modules/llms/vendors/useSourceSetup.ts @@ -0,0 +1,35 @@ +import { shallow } from 'zustand/shallow'; + +import type { IModelVendor } from './IModelVendor'; +import { DModelSource, DModelSourceId, useModelsStore } from '../store-llms'; + + +/** + * Source-specific read/write - great time saver + */ +export function useSourceSetup(sourceId: DModelSourceId, vendor: IModelVendor) { + + // invalidates only when the setup changes + const { updateSourceSetup, ...rest } = useModelsStore(state => { + + // find the source (or null) + const source: DModelSource | null = state.sources.find(source => source.id === sourceId) as DModelSource ?? null; + + // (safe) source-derived properties + const sourceSetupValid = (source?.setup && vendor?.validateSetup) ? vendor.validateSetup(source.setup as TSourceSetup) : false; + const sourceLLMs = source ? 
state.llms.filter(llm => llm._source === source) : []; + const access = vendor.getTransportAccess(source?.setup); + + return { + source, + access, + sourceHasLLMs: !!sourceLLMs.length, + sourceSetupValid, + updateSourceSetup: state.updateSourceSetup, + }; + }, shallow); + + // convenience function for this source + const updateSetup = (partialSetup: Partial) => updateSourceSetup(sourceId, partialSetup); + return { ...rest, updateSetup }; +} \ No newline at end of file diff --git a/src/modules/llms/vendors/vendors.registry.ts b/src/modules/llms/vendors/vendors.registry.ts index 705799711..054a8a6bf 100644 --- a/src/modules/llms/vendors/vendors.registry.ts +++ b/src/modules/llms/vendors/vendors.registry.ts @@ -1,5 +1,6 @@ import { ModelVendorAnthropic } from './anthropic/anthropic.vendor'; import { ModelVendorAzure } from './azure/azure.vendor'; +import { ModelVendorGemini } from './gemini/gemini.vendor'; import { ModelVendorLocalAI } from './localai/localai.vendor'; import { ModelVendorMistral } from './mistral/mistral.vendor'; import { ModelVendorOllama } from './ollama/ollama.vendor'; @@ -7,20 +8,32 @@ import { ModelVendorOoobabooga } from './oobabooga/oobabooga.vendor'; import { ModelVendorOpenAI } from './openai/openai.vendor'; import { ModelVendorOpenRouter } from './openrouter/openrouter.vendor'; -import type { IModelVendor, ModelVendorId, ModelVendorRegistryType } from './IModelVendor'; +import type { IModelVendor } from './IModelVendor'; import { DLLMId, DModelSource, DModelSourceId, findLLMOrThrow } from '../store-llms'; +export type ModelVendorId = + | 'anthropic' + | 'azure' + | 'googleai' + | 'localai' + | 'mistral' + | 'ollama' + | 'oobabooga' + | 'openai' + | 'openrouter'; + /** Global: Vendor Instances Registry **/ -const MODEL_VENDOR_REGISTRY: ModelVendorRegistryType = { +const MODEL_VENDOR_REGISTRY: Record = { anthropic: ModelVendorAnthropic, azure: ModelVendorAzure, + googleai: ModelVendorGemini, localai: ModelVendorLocalAI, mistral: ModelVendorMistral, ollama: ModelVendorOllama, oobabooga: ModelVendorOoobabooga, openai: ModelVendorOpenAI, openrouter: ModelVendorOpenRouter, -}; +} as Record; const MODEL_VENDOR_DEFAULT: ModelVendorId = 'openai'; @@ -31,13 +44,15 @@ export function findAllVendors(): IModelVendor[] { return modelVendors; } -export function findVendorById(vendorId?: ModelVendorId): IModelVendor | null { - return vendorId ? (MODEL_VENDOR_REGISTRY[vendorId] ?? null) : null; +export function findVendorById( + vendorId?: ModelVendorId, +): IModelVendor | null { + return vendorId ? (MODEL_VENDOR_REGISTRY[vendorId] as IModelVendor) ?? 
null : null; } -export function findVendorForLlmOrThrow(llmId: DLLMId) { - const llm = findLLMOrThrow(llmId); - const vendor = findVendorById(llm?._source.vId); +export function findVendorForLlmOrThrow(llmId: DLLMId) { + const llm = findLLMOrThrow(llmId); + const vendor = findVendorById(llm?._source.vId); if (!vendor) throw new Error(`callChat: Vendor not found for LLM ${llmId}`); return { llm, vendor }; } diff --git a/src/server/api/trpc.router-edge.ts b/src/server/api/trpc.router-edge.ts index 96464554a..24dc33e6f 100644 --- a/src/server/api/trpc.router-edge.ts +++ b/src/server/api/trpc.router-edge.ts @@ -3,9 +3,10 @@ import { createTRPCRouter } from './trpc.server'; import { backendRouter } from '~/modules/backend/backend.router'; import { elevenlabsRouter } from '~/modules/elevenlabs/elevenlabs.router'; import { googleSearchRouter } from '~/modules/google/search.router'; -import { llmAnthropicRouter } from '~/modules/llms/transports/server/anthropic/anthropic.router'; -import { llmOllamaRouter } from '~/modules/llms/transports/server/ollama/ollama.router'; -import { llmOpenAIRouter } from '~/modules/llms/transports/server/openai/openai.router'; +import { llmAnthropicRouter } from '~/modules/llms/server/anthropic/anthropic.router'; +import { llmGeminiRouter } from '~/modules/llms/server/gemini/gemini.router'; +import { llmOllamaRouter } from '~/modules/llms/server/ollama/ollama.router'; +import { llmOpenAIRouter } from '~/modules/llms/server/openai/openai.router'; import { prodiaRouter } from '~/modules/prodia/prodia.router'; import { ytPersonaRouter } from '../../apps/personas/ytpersona.router'; @@ -17,6 +18,7 @@ export const appRouterEdge = createTRPCRouter({ elevenlabs: elevenlabsRouter, googleSearch: googleSearchRouter, llmAnthropic: llmAnthropicRouter, + llmGemini: llmGeminiRouter, llmOllama: llmOllamaRouter, llmOpenAI: llmOpenAIRouter, prodia: prodiaRouter, diff --git a/src/server/env.mjs b/src/server/env.mjs index 22b0469b1..d9553571f 100644 --- a/src/server/env.mjs +++ b/src/server/env.mjs @@ -21,6 +21,9 @@ export const env = createEnv({ ANTHROPIC_API_KEY: z.string().optional(), ANTHROPIC_API_HOST: z.string().url().optional(), + // LLM: Google AI's Gemini + GEMINI_API_KEY: z.string().optional(), + // LLM: Mistral MISTRAL_API_KEY: z.string().optional(),
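
Note on the pattern introduced above: the refactor replaces each vendor's bespoke `callChatGenerate` / `callChatGenerateWF` methods with three shared entry points on `IModelVendor` (`rpcUpdateModelsQuery`, `rpcChatGenerateOrThrow`, `streamingChatGenerateOrThrow`), plus the `useSourceSetup` and `useLlmUpdateModels` hooks on the client side. As a rough illustration of how a source-setup component consumes that surface, here is a minimal sketch only; it reuses `ModelVendorOpenAI`, assumes the `~/modules/llms/...` path alias used elsewhere in the repo, and the `ExampleSourceSetup` name is hypothetical, not code from this changeset:

import type { DModelSourceId } from '~/modules/llms/store-llms';
import { ModelVendorOpenAI } from '~/modules/llms/vendors/openai/openai.vendor';
import { useLlmUpdateModels } from '~/modules/llms/vendors/useLlmUpdateModels';
import { useSourceSetup } from '~/modules/llms/vendors/useSourceSetup';

export function ExampleSourceSetup(props: { sourceId: DModelSourceId }) {
  // read this source's state: record, transport access, setup validity, and an update helper
  const { source, access, sourceHasLLMs, sourceSetupValid, updateSetup } =
    useSourceSetup(props.sourceId, ModelVendorOpenAI);

  // list models via the vendor's rpcUpdateModelsQuery and write them into the llms store,
  // only fetching while no models are loaded and the setup looks valid
  const { isFetching, refetch, isError, error } =
    useLlmUpdateModels(ModelVendorOpenAI, access, !sourceHasLLMs && sourceSetupValid, source);

  // (updateSetup, isFetching, refetch, isError and error would drive the form UI;
  //  the actual field rendering is omitted in this sketch)
  return null;
}

The per-vendor setup components touched in this diff (Anthropic, Azure, LocalAI, Mistral, Ollama, Oobabooga, OpenAI, OpenRouter, and the new Gemini) all follow this same shape, differing only in which vendor object and enable condition they pass to `useLlmUpdateModels`.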