From 7a07f10ed1f7281fc7f67fd344ddd5b0e5b7f594 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 00:09:38 -0800 Subject: [PATCH 01/24] Move ModelVendor enum --- README.md | 2 +- src/apps/models-modal/ModelsSourceSelector.tsx | 6 +++--- src/modules/llms/store-llms.ts | 3 ++- src/modules/llms/vendors/IModelVendor.ts | 5 +---- src/modules/llms/vendors/vendors.registry.ts | 14 +++++++++++++- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index e3cd1cfbd..42d736a85 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # BIG-AGI 🧠✨ Welcome to big-AGI 👋, the GPT application for professionals that need function, form, -simplicity, and speed. Powered by the latest models from 7 vendors and +simplicity, and speed. Powered by the latest models from 8 vendors and open-source model servers, `big-AGI` offers best-in-class Voice and Chat with AI Personas, visualizations, coding, drawing, calling, and quite more -- all in a polished UX. diff --git a/src/apps/models-modal/ModelsSourceSelector.tsx b/src/apps/models-modal/ModelsSourceSelector.tsx index ef501a4b7..4d9d1fca9 100644 --- a/src/apps/models-modal/ModelsSourceSelector.tsx +++ b/src/apps/models-modal/ModelsSourceSelector.tsx @@ -5,9 +5,9 @@ import { Avatar, Badge, Box, Button, IconButton, ListItemDecorator, MenuItem, Op import AddIcon from '@mui/icons-material/Add'; import DeleteOutlineIcon from '@mui/icons-material/DeleteOutline'; -import { type DModelSourceId, useModelsStore } from '~/modules/llms/store-llms'; -import { type IModelVendor, type ModelVendorId } from '~/modules/llms/vendors/IModelVendor'; -import { createModelSourceForVendor, findAllVendors, findVendorById } from '~/modules/llms/vendors/vendors.registry'; +import type { IModelVendor } from '~/modules/llms/vendors/IModelVendor'; +import { DModelSourceId, useModelsStore } from '~/modules/llms/store-llms'; +import { createModelSourceForVendor, findAllVendors, findVendorById, ModelVendorId } from '~/modules/llms/vendors/vendors.registry'; import { CloseableMenu } from '~/common/components/CloseableMenu'; import { ConfirmationModal } from '~/common/components/ConfirmationModal'; diff --git a/src/modules/llms/store-llms.ts b/src/modules/llms/store-llms.ts index d7ad30780..9d13b9574 100644 --- a/src/modules/llms/store-llms.ts +++ b/src/modules/llms/store-llms.ts @@ -2,7 +2,8 @@ import { create } from 'zustand'; import { shallow } from 'zustand/shallow'; import { persist } from 'zustand/middleware'; -import type { IModelVendor, ModelVendorId } from './vendors/IModelVendor'; +import type { IModelVendor } from './vendors/IModelVendor'; +import type { ModelVendorId } from './vendors/vendors.registry'; import type { SourceSetupOpenRouter } from './vendors/openrouter/openrouter.vendor'; diff --git a/src/modules/llms/vendors/IModelVendor.ts b/src/modules/llms/vendors/IModelVendor.ts index 1dda6abdc..a29a1e0b7 100644 --- a/src/modules/llms/vendors/IModelVendor.ts +++ b/src/modules/llms/vendors/IModelVendor.ts @@ -1,13 +1,10 @@ import type React from 'react'; import type { DLLM, DModelSourceId } from '../store-llms'; +import type { ModelVendorId } from './vendors.registry'; import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../transports/chatGenerate'; -export type ModelVendorId = 'anthropic' | 'azure' | 'localai' | 'mistral' | 'ollama' | 'oobabooga' | 'openai' | 'openrouter'; -export type ModelVendorRegistryType = Record; - export interface IModelVendor> { readonly id: ModelVendorId; 
readonly name: string; diff --git a/src/modules/llms/vendors/vendors.registry.ts b/src/modules/llms/vendors/vendors.registry.ts index 705799711..884cdff27 100644 --- a/src/modules/llms/vendors/vendors.registry.ts +++ b/src/modules/llms/vendors/vendors.registry.ts @@ -7,9 +7,19 @@ import { ModelVendorOoobabooga } from './oobabooga/oobabooga.vendor'; import { ModelVendorOpenAI } from './openai/openai.vendor'; import { ModelVendorOpenRouter } from './openrouter/openrouter.vendor'; -import type { IModelVendor, ModelVendorId, ModelVendorRegistryType } from './IModelVendor'; +import type { IModelVendor } from './IModelVendor'; import { DLLMId, DModelSource, DModelSourceId, findLLMOrThrow } from '../store-llms'; +export type ModelVendorId = + | 'anthropic' + | 'azure' + | 'localai' + | 'mistral' + | 'ollama' + | 'oobabooga' + | 'openai' + | 'openrouter'; + /** Global: Vendor Instances Registry **/ const MODEL_VENDOR_REGISTRY: ModelVendorRegistryType = { anthropic: ModelVendorAnthropic, @@ -22,6 +32,8 @@ const MODEL_VENDOR_REGISTRY: ModelVendorRegistryType = { openrouter: ModelVendorOpenRouter, }; +type ModelVendorRegistryType = Record; + const MODEL_VENDOR_DEFAULT: ModelVendorId = 'openai'; From e0a010189f872b3956bbb669cd1c1bd5b938651e Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 01:36:37 -0800 Subject: [PATCH 02/24] LLMOptions Modal: fix display --- src/apps/models-modal/LLMOptionsModal.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/apps/models-modal/LLMOptionsModal.tsx b/src/apps/models-modal/LLMOptionsModal.tsx index 65b56a11c..10051bff8 100644 --- a/src/apps/models-modal/LLMOptionsModal.tsx +++ b/src/apps/models-modal/LLMOptionsModal.tsx @@ -117,9 +117,9 @@ export function LLMOptionsModal(props: { id: DLLMId }) { setShowDetails(!showDetails)} /> {showDetails && [{llm.id}]: {llm.options.llmRef && `${llm.options.llmRef} · `} - {llm.contextTokens && `context tokens: ${llm.contextTokens.toLocaleString()} · `} - {llm.maxOutputTokens && `max output tokens: ${llm.maxOutputTokens.toLocaleString()} · `} - {llm.created && `created: ${(new Date(llm.created * 1000)).toLocaleString()} · `} + {!!llm.contextTokens && `context tokens: ${llm.contextTokens.toLocaleString()} · `} + {!!llm.maxOutputTokens && `max output tokens: ${llm.maxOutputTokens.toLocaleString()} · `} + {!!llm.created && `created: ${(new Date(llm.created * 1000)).toLocaleString()} · `} description: {llm.description} {/*· tags: {llm.tags.join(', ')}*/} } From 34c1c425b94aa5fc4b7c37631e945bb63b0a765e Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 01:37:14 -0800 Subject: [PATCH 03/24] Gemini: backend env var --- docs/environment-variables.md | 2 ++ src/modules/backend/backend.router.ts | 1 + src/modules/backend/state-backend.ts | 2 ++ src/server/env.mjs | 3 +++ 4 files changed, 8 insertions(+) diff --git a/docs/environment-variables.md b/docs/environment-variables.md index 5027f4797..b11f0269d 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -24,6 +24,7 @@ AZURE_OPENAI_API_ENDPOINT= AZURE_OPENAI_API_KEY= ANTHROPIC_API_KEY= ANTHROPIC_API_HOST= +GEMINI_API_KEY= MISTRAL_API_KEY= OLLAMA_API_HOST= OPENROUTER_API_KEY= @@ -80,6 +81,7 @@ requiring the user to enter an API key | `AZURE_OPENAI_API_KEY` | Azure OpenAI API key, see [config-azure-openai.md](config-azure-openai.md) | Optional, but if set `AZURE_OPENAI_API_ENDPOINT` must also be set | | `ANTHROPIC_API_KEY` | The API key for Anthropic | Optional | | `ANTHROPIC_API_HOST` | Changes the 
backend host for the Anthropic vendor, to enable platforms such as [config-aws-bedrock.md](config-aws-bedrock.md) | Optional | +| `GEMINI_API_KEY` | The API key for Google AI's Gemini | Optional | | `MISTRAL_API_KEY` | The API key for Mistral | Optional | | `OLLAMA_API_HOST` | Changes the backend host for the Ollama vendor. See [config-ollama.md](config-ollama.md) | | | `OPENROUTER_API_KEY` | The API key for OpenRouter | Optional | diff --git a/src/modules/backend/backend.router.ts b/src/modules/backend/backend.router.ts index ad207a75e..fd8e70649 100644 --- a/src/modules/backend/backend.router.ts +++ b/src/modules/backend/backend.router.ts @@ -28,6 +28,7 @@ export const backendRouter = createTRPCRouter({ hasImagingProdia: !!env.PRODIA_API_KEY, hasLlmAnthropic: !!env.ANTHROPIC_API_KEY, hasLlmAzureOpenAI: !!env.AZURE_OPENAI_API_KEY && !!env.AZURE_OPENAI_API_ENDPOINT, + hasLlmGemini: !!env.GEMINI_API_KEY, hasLlmMistral: !!env.MISTRAL_API_KEY, hasLlmOllama: !!env.OLLAMA_API_HOST, hasLlmOpenAI: !!env.OPENAI_API_KEY || !!env.OPENAI_API_HOST, diff --git a/src/modules/backend/state-backend.ts b/src/modules/backend/state-backend.ts index 034269b5f..ebe0025b3 100644 --- a/src/modules/backend/state-backend.ts +++ b/src/modules/backend/state-backend.ts @@ -9,6 +9,7 @@ export interface BackendCapabilities { hasImagingProdia: boolean; hasLlmAnthropic: boolean; hasLlmAzureOpenAI: boolean; + hasLlmGemini: boolean; hasLlmMistral: boolean; hasLlmOllama: boolean; hasLlmOpenAI: boolean; @@ -31,6 +32,7 @@ const useBackendStore = create()( hasImagingProdia: false, hasLlmAnthropic: false, hasLlmAzureOpenAI: false, + hasLlmGemini: false, hasLlmMistral: false, hasLlmOllama: false, hasLlmOpenAI: false, diff --git a/src/server/env.mjs b/src/server/env.mjs index 22b0469b1..d9553571f 100644 --- a/src/server/env.mjs +++ b/src/server/env.mjs @@ -21,6 +21,9 @@ export const env = createEnv({ ANTHROPIC_API_KEY: z.string().optional(), ANTHROPIC_API_HOST: z.string().url().optional(), + // LLM: Google AI's Gemini + GEMINI_API_KEY: z.string().optional(), + // LLM: Mistral MISTRAL_API_KEY: z.string().optional(), From 453a3e575133f14b6b9c0e1a0d6361a6f70e82d8 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 02:09:56 -0800 Subject: [PATCH 04/24] LLM Vendors: auto IDs --- src/modules/llms/vendors/vendors.registry.ts | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/modules/llms/vendors/vendors.registry.ts b/src/modules/llms/vendors/vendors.registry.ts index 884cdff27..d20770dda 100644 --- a/src/modules/llms/vendors/vendors.registry.ts +++ b/src/modules/llms/vendors/vendors.registry.ts @@ -10,18 +10,8 @@ import { ModelVendorOpenRouter } from './openrouter/openrouter.vendor'; import type { IModelVendor } from './IModelVendor'; import { DLLMId, DModelSource, DModelSourceId, findLLMOrThrow } from '../store-llms'; -export type ModelVendorId = - | 'anthropic' - | 'azure' - | 'localai' - | 'mistral' - | 'ollama' - | 'oobabooga' - | 'openai' - | 'openrouter'; - /** Global: Vendor Instances Registry **/ -const MODEL_VENDOR_REGISTRY: ModelVendorRegistryType = { +const MODEL_VENDOR_REGISTRY = { anthropic: ModelVendorAnthropic, azure: ModelVendorAzure, localai: ModelVendorLocalAI, @@ -30,9 +20,9 @@ const MODEL_VENDOR_REGISTRY: ModelVendorRegistryType = { oobabooga: ModelVendorOoobabooga, openai: ModelVendorOpenAI, openrouter: ModelVendorOpenRouter, -}; +} as const; -type ModelVendorRegistryType = Record; +export type ModelVendorId = keyof typeof MODEL_VENDOR_REGISTRY; const 
MODEL_VENDOR_DEFAULT: ModelVendorId = 'openai'; From 0df7297cca889cc0a42c5643c510103465af562f Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 03:48:41 -0800 Subject: [PATCH 05/24] Gemini: configuration, list models, and immediate generation --- .../transports/server/gemini/gemini.router.ts | 183 ++++++++++++++++++ .../server/gemini/gemini.wiretypes.ts | 175 +++++++++++++++++ .../vendors/googleai/GeminiSourceSetup.tsx | 60 ++++++ .../llms/vendors/googleai/gemini.vendor.ts | 85 ++++++++ src/modules/llms/vendors/vendors.registry.ts | 2 + src/server/api/trpc.router-edge.ts | 2 + 6 files changed, 507 insertions(+) create mode 100644 src/modules/llms/transports/server/gemini/gemini.router.ts create mode 100644 src/modules/llms/transports/server/gemini/gemini.wiretypes.ts create mode 100644 src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx create mode 100644 src/modules/llms/vendors/googleai/gemini.vendor.ts diff --git a/src/modules/llms/transports/server/gemini/gemini.router.ts b/src/modules/llms/transports/server/gemini/gemini.router.ts new file mode 100644 index 000000000..389ba2be8 --- /dev/null +++ b/src/modules/llms/transports/server/gemini/gemini.router.ts @@ -0,0 +1,183 @@ +import { z } from 'zod'; +import { TRPCError } from '@trpc/server'; + +import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server'; +import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; + +import { LLM_IF_OAI_Chat, LLM_IF_OAI_Vision } from '~/modules/llms/store-llms'; + +import { GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes'; +import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; + + +// Default hosts +const DEFAULT_GEMINI_HOST = 'https://generativelanguage.googleapis.com'; + + +// Mappers + +export function geminiAccess(access: GeminiAccessSchema, modelRefId: string | null, apiPath: string): { headers: HeadersInit, url: string } { + + // handle paths that require a model name + if (apiPath.includes('{model=models/*}')) { + if (!modelRefId) + throw new Error(`geminiAccess: modelRefId is required for ${apiPath}`); + apiPath = apiPath.replace('{model=models/*}', modelRefId); + } + + const geminiHost = fixupHost(DEFAULT_GEMINI_HOST, apiPath); + + return { + headers: { + 'Content-Type': 'application/json', + 'x-goog-api-key': access.geminiKey, + }, + url: geminiHost + apiPath, + }; +} + +export const geminiGenerateContentPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, n: number): GeminiGenerateContentRequest => { + const contents: GeminiGenerateContentRequest['contents'] = []; + history.forEach((message) => { + // hack for now - the model seems to want prompts to alternate + if (message.role === 'system') { + contents.push({ role: 'user', parts: [{ text: message.content }] }); + contents.push({ role: 'model', parts: [{ text: 'Ok.' }] }); + } else + contents.push({ role: message.role === 'assistant' ? 
'model' : 'user', parts: [{ text: message.content }] }); + }); + return { + contents, + generationConfig: { + ...(n >= 2 && { candidateCount: n }), + ...(model.maxTokens && { maxOutputTokens: model.maxTokens }), + temperature: model.temperature, + }, + // safetySettings: [ + // { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold: 'BLOCK_NONE' }, + // { category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_NONE' }, + // { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_NONE' }, + // { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, + // ], + }; +}; + + +async function geminiGET(access: GeminiAccessSchema, modelRefId: string | null, apiPath: string /*, signal?: AbortSignal*/): Promise { + const { headers, url } = geminiAccess(access, modelRefId, apiPath); + return await fetchJsonOrTRPCError(url, 'GET', headers, undefined, 'Gemini'); +} + +async function geminiPOST(access: GeminiAccessSchema, modelRefId: string | null, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise { + const { headers, url } = geminiAccess(access, modelRefId, apiPath); + return await fetchJsonOrTRPCError(url, 'POST', headers, body, 'Gemini'); +} + + +// Input/Output Schemas + +export const geminiAccessSchema = z.object({ + dialect: z.enum(['gemini']), + geminiKey: z.string(), +}); +export type GeminiAccessSchema = z.infer; + + +const accessOnlySchema = z.object({ + access: geminiAccessSchema, +}); + +const chatGenerateInputSchema = z.object({ + access: geminiAccessSchema, + model: openAIModelSchema, history: openAIHistorySchema, + // functions: openAIFunctionsSchema.optional(), forceFunctionName: z.string().optional(), +}); + + +export const llmGeminiRouter = createTRPCRouter({ + + /* [Gemini] models.list = /v1beta/models */ + listModels: publicProcedure + .input(accessOnlySchema) + .output(listModelsOutputSchema) + .query(async ({ input }) => { + + // get the models + const wireModels = await geminiGET(input.access, null, geminiModelsListPath); + const detailedModels = geminiModelsListOutputSchema.parse(wireModels).models; + + // NOTE: no need to retrieve info for each of the models (e.g. /v1beta/model/gemini-pro)., + // as the List API already all the info on all the models + + // map to our output schema + return { + models: detailedModels.map((geminiModel) => { + const { description, displayName, inputTokenLimit, name, outputTokenLimit, supportedGenerationMethods } = geminiModel; + + const contextWindow = inputTokenLimit + outputTokenLimit; + const hidden = !supportedGenerationMethods.includes('generateContent'); + + const { version, topK, topP, temperature } = geminiModel; + const descriptionLong = description + ` (Version: ${version}, Defaults: temperature=${temperature}, topP=${topP}, topK=${topK}, interfaces=[${supportedGenerationMethods.join(',')}])`; + + // const isGeminiPro = name.includes('gemini-pro'); + const isGeminiProVision = name.includes('gemini-pro-vision'); + + const interfaces: ModelDescriptionSchema['interfaces'] = []; + if (supportedGenerationMethods.includes('generateContent')) { + interfaces.push(LLM_IF_OAI_Chat); + if (isGeminiProVision) + interfaces.push(LLM_IF_OAI_Vision); + } + + return { + id: name, + label: displayName, + // created: ... + // updated: ... + description: descriptionLong, + contextWindow: contextWindow, + maxCompletionTokens: outputTokenLimit, + // pricing: isGeminiPro ? { needs per-character and per-image pricing } : undefined, + // rateLimits: isGeminiPro ? 
{ reqPerMinute: 60 } : undefined, + interfaces: supportedGenerationMethods.includes('generateContent') ? [LLM_IF_OAI_Chat] : [], + hidden, + } satisfies ModelDescriptionSchema; + }), + }; + }), + + + /* [Gemini] models.generateContent = /v1/{model=models/*}:generateContent */ + chatGenerate: publicProcedure + .input(chatGenerateInputSchema) + .output(openAIChatGenerateOutputSchema) + .mutation(async ({ input: { access, history, model } }) => { + + // generate the content + const wireGeneration = await geminiPOST(access, model.id, geminiGenerateContentPayload(model, history, 1), geminiModelsGenerateContentPath); + const generation = geminiGeneratedContentResponseSchema.parse(wireGeneration); + + // only use the first result (and there should be only one) + const singleCandidate = generation.candidates?.[0] ?? null; + if (!singleCandidate || !singleCandidate.content?.parts.length) + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `Gemini chat-generation API issue: ${JSON.stringify(wireGeneration)}`, + }); + + if (!('text' in singleCandidate.content.parts[0])) + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `Gemini non-text chat-generation API issue: ${JSON.stringify(wireGeneration)}`, + }); + + return { + role: 'assistant', + content: singleCandidate.content.parts[0].text || '', + finish_reason: singleCandidate.finishReason === 'STOP' ? 'stop' : null, + }; + }), + +}); diff --git a/src/modules/llms/transports/server/gemini/gemini.wiretypes.ts b/src/modules/llms/transports/server/gemini/gemini.wiretypes.ts new file mode 100644 index 000000000..03805ad86 --- /dev/null +++ b/src/modules/llms/transports/server/gemini/gemini.wiretypes.ts @@ -0,0 +1,175 @@ +import { z } from 'zod'; + +// PATHS + +export const geminiModelsListPath = '/v1beta/models?pageSize=1000'; +export const geminiModelsGenerateContentPath = '/v1beta/{model=models/*}:generateContent'; +export const geminiModelsStreamGenerateContentPath = '/v1beta/{model=models/*}:streamGenerateContent'; + + +// models.list = /v1beta/models + +export const geminiModelsListOutputSchema = z.object({ + models: z.array(z.object({ + name: z.string(), + version: z.string(), + displayName: z.string(), + description: z.string(), + inputTokenLimit: z.number().int().min(1), + outputTokenLimit: z.number().int().min(1), + supportedGenerationMethods: z.array(z.enum([ + 'countMessageTokens', + 'countTextTokens', + 'countTokens', + 'createTunedTextModel', + 'embedContent', + 'embedText', + 'generateAnswer', + 'generateContent', + 'generateMessage', + 'generateText', + ])), + temperature: z.number().optional(), + topP: z.number().optional(), + topK: z.number().optional(), + })), +}); + + +// /v1/{model=models/*}:generateContent, /v1beta/{model=models/*}:streamGenerateContent + +const geminiContentPartSchema = z.union([ + + // TextPart + z.object({ + text: z.string().optional(), + }), + + // InlineDataPart + z.object({ + inlineData: z.object({ + mimeType: z.string(), + data: z.string(), // base64-encoded string + }), + }), + + // A predicted FunctionCall returned from the model + z.object({ + functionCall: z.object({ + name: z.string(), + args: z.record(z.any()), // JSON object format + }), + }), + + // The result output of a FunctionCall + z.object({ + functionResponse: z.object({ + name: z.string(), + response: z.record(z.any()), // JSON object format + }), + }), +]); + +const geminiToolSchema = z.object({ + functionDeclarations: z.array(z.object({ + name: z.string(), + description: z.string(), + parameters: 
z.record(z.any()).optional(), // Schema object format + })).optional(), +}); + +const geminiHarmCategorySchema = z.enum([ + 'HARM_CATEGORY_UNSPECIFIED', + 'HARM_CATEGORY_DEROGATORY', + 'HARM_CATEGORY_TOXICITY', + 'HARM_CATEGORY_VIOLENCE', + 'HARM_CATEGORY_SEXUAL', + 'HARM_CATEGORY_MEDICAL', + 'HARM_CATEGORY_DANGEROUS', + 'HARM_CATEGORY_HARASSMENT', + 'HARM_CATEGORY_HATE_SPEECH', + 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + 'HARM_CATEGORY_DANGEROUS_CONTENT', +]); + + +const geminiSafetySettingSchema = z.object({ + category: geminiHarmCategorySchema, + threshold: z.enum([ + 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + 'BLOCK_LOW_AND_ABOVE', + 'BLOCK_MEDIUM_AND_ABOVE', + 'BLOCK_ONLY_HIGH', + 'BLOCK_NONE', + ]), +}); + +const geminiGenerationConfigSchema = z.object({ + stopSequences: z.array(z.string()).optional(), + candidateCount: z.number().int().optional(), + maxOutputTokens: z.number().int().optional(), + temperature: z.number().optional(), + topP: z.number().optional(), + topK: z.number().int().optional(), +}); + +const geminiContentSchema = z.object({ + parts: z.array(geminiContentPartSchema), // Ordered Parts that constitute a single message. Parts may have different MIME types. + role: z.enum(['user', 'model']).optional(), // Optional. The producer of the content. Must be either 'user' or 'model'. +}); + +export const geminiGenerateContentRequest = z.object({ + contents: z.array(geminiContentSchema), + tools: z.array(geminiToolSchema).optional(), + safetySettings: z.array(geminiSafetySettingSchema).optional(), + generationConfig: geminiGenerationConfigSchema.optional(), +}); + +export type GeminiGenerateContentRequest = z.infer; + + +const geminiHarmProbabilitySchema = z.enum([ + 'HARM_PROBABILITY_UNSPECIFIED', + 'NEGLIGIBLE', + 'LOW', + 'MEDIUM', + 'HIGH', +]); + +const geminiSafetyRatingSchema = z.object({ + 'category': geminiHarmCategorySchema, + 'probability': geminiHarmProbabilitySchema, + 'blocked': z.boolean().optional(), +}); + +const geminiFinishReasonSchema = z.enum([ + 'FINISH_REASON_UNSPECIFIED', + 'STOP', + 'MAX_TOKENS', + 'SAFETY', + 'RECITATION', + 'OTHER', +]); + +export const geminiGeneratedContentResponseSchema = z.object({ + // either all requested candidates are returned or no candidates at all + // no candidates are returned only if there was something wrong with the prompt (see promptFeedback) + candidates: z.array(z.object({ + index: z.number(), + content: geminiContentSchema, + finishReason: geminiFinishReasonSchema.optional(), + safetyRatings: z.array(geminiSafetyRatingSchema), + citationMetadata: z.object({ + startIndex: z.number().optional(), + endIndex: z.number().optional(), + uri: z.string().optional(), + license: z.string().optional(), + }).optional(), + tokenCount: z.number().optional(), + // groundingAttributions: z.array(GroundingAttribution).optional(), // This field is populated for GenerateAnswer calls. 
+ })), + promptFeedback: z.object({ + blockReason: z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']).optional(), + safetyRatings: z.array(geminiSafetyRatingSchema), + }), +}); \ No newline at end of file diff --git a/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx b/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx new file mode 100644 index 000000000..37e9d3044 --- /dev/null +++ b/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx @@ -0,0 +1,60 @@ +import * as React from 'react'; + +import { FormInputKey } from '~/common/components/forms/FormInputKey'; +import { InlineError } from '~/common/components/InlineError'; +import { Link } from '~/common/components/Link'; +import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; +import { apiQuery } from '~/common/util/trpc.client'; + +import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; +import { ModelVendorGemini } from './gemini.vendor'; +import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; + + +const GEMINI_API_KEY_LINK = 'https://makersuite.google.com/app/apikey'; + + +export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) { + + // external state + const { source, sourceSetupValid, access, updateSetup } = + useSourceSetup(props.sourceId, ModelVendorGemini); + + // derived state + const { geminiKey } = access; + + const needsUserKey = !ModelVendorGemini.hasBackendCap?.(); + const shallFetchSucceed = !needsUserKey || (!!geminiKey && sourceSetupValid); + const showKeyError = !!geminiKey && !sourceSetupValid; + + // fetch models + const { isFetching, refetch, isError, error } = apiQuery.llmGemini.listModels.useQuery({ access }, { + enabled: shallFetchSucceed, + onSuccess: models => source && useModelsStore.getState().setLLMs( + models.models.map(model => modelDescriptionToDLLM(model, source)), + props.sourceId, + ), + staleTime: Infinity, + }); + + return <> + + {needsUserKey + ? !geminiKey && request Key + : 'โœ”๏ธ already set in server'} + } + value={geminiKey} onChange={value => updateSetup({ geminiKey: value })} + required={needsUserKey} isError={showKeyError} + placeholder='...' + /> + + + + {isError && } + + ; +} \ No newline at end of file diff --git a/src/modules/llms/vendors/googleai/gemini.vendor.ts b/src/modules/llms/vendors/googleai/gemini.vendor.ts new file mode 100644 index 000000000..ee809c81f --- /dev/null +++ b/src/modules/llms/vendors/googleai/gemini.vendor.ts @@ -0,0 +1,85 @@ +import GoogleIcon from '@mui/icons-material/Google'; + +import { backendCaps } from '~/modules/backend/state-backend'; + +import type { IModelVendor } from '../IModelVendor'; +import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; + +import type { GeminiAccessSchema } from '../../transports/server/gemini/gemini.router'; +import { GeminiSourceSetup } from './GeminiSourceSetup'; +import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; +import { apiAsync } from '~/common/util/trpc.client'; + + +export interface SourceSetupGemini { + geminiKey: string; +} + +export interface LLMOptionsGemini { + llmRef: string; + stopSequences: string[]; // up to 5 sequences that will stop generation (optional) + candidateCount: number; // 1...8 number of generated responses to return (optional) + maxOutputTokens: number; // if unset, this will default to outputTokenLimit (optional) + temperature: number; // 0...1 Controls the randomness of the output. 
(optional) + topP: number; // 0...1 The maximum cumulative probability of tokens to consider when sampling (optional) + topK: number; // 1...100 The maximum number of tokens to consider when sampling (optional) +} + + +export const ModelVendorGemini: IModelVendor = { + id: 'googleai', + name: 'Gemini', + rank: 11, + location: 'cloud', + instanceLimit: 1, + hasBackendCap: () => backendCaps().hasLlmGemini, + + // components + Icon: GoogleIcon, + SourceSetupComponent: GeminiSourceSetup, + LLMOptionsComponent: OpenAILLMOptions, + + // functions + initializeSetup: () => ({ + geminiKey: '', + }), + validateSetup: (setup) => { + return setup.geminiKey?.length > 0; + }, + getTransportAccess: (partialSetup): GeminiAccessSchema => ({ + dialect: 'gemini', + geminiKey: partialSetup?.geminiKey || '', + }), + callChatGenerate(llm, messages: VChatMessageIn[], maxTokens?: number): Promise { + return geminiCallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, maxTokens); + }, + callChatGenerateWF(): Promise { + throw new Error('Gemini does not support "Functions" yet'); + }, +}; + +/** + * This function either returns the LLM message, or throws a descriptive error string + */ +async function geminiCallChatGenerate( + access: GeminiAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], + maxTokens?: number, +): Promise { + const { llmRef, temperature = 0.5, maxOutputTokens } = llmOptions; + try { + return await apiAsync.llmGemini.chatGenerate.mutate({ + access, + model: { + id: llmRef!, + temperature: temperature, + maxTokens: maxTokens || maxOutputTokens || 1024, + }, + history: messages, + }) as TOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'Gemini Chat Generate Error'; + console.error(`geminiCallChatGenerate: ${errorMessage}`); + throw new Error(errorMessage); + } +} + diff --git a/src/modules/llms/vendors/vendors.registry.ts b/src/modules/llms/vendors/vendors.registry.ts index d20770dda..ac0f223f4 100644 --- a/src/modules/llms/vendors/vendors.registry.ts +++ b/src/modules/llms/vendors/vendors.registry.ts @@ -1,5 +1,6 @@ import { ModelVendorAnthropic } from './anthropic/anthropic.vendor'; import { ModelVendorAzure } from './azure/azure.vendor'; +import { ModelVendorGemini } from '~/modules/llms/vendors/googleai/gemini.vendor'; import { ModelVendorLocalAI } from './localai/localai.vendor'; import { ModelVendorMistral } from './mistral/mistral.vendor'; import { ModelVendorOllama } from './ollama/ollama.vendor'; @@ -14,6 +15,7 @@ import { DLLMId, DModelSource, DModelSourceId, findLLMOrThrow } from '../store-l const MODEL_VENDOR_REGISTRY = { anthropic: ModelVendorAnthropic, azure: ModelVendorAzure, + googleai: ModelVendorGemini, localai: ModelVendorLocalAI, mistral: ModelVendorMistral, ollama: ModelVendorOllama, diff --git a/src/server/api/trpc.router-edge.ts b/src/server/api/trpc.router-edge.ts index 96464554a..c9513f71f 100644 --- a/src/server/api/trpc.router-edge.ts +++ b/src/server/api/trpc.router-edge.ts @@ -4,6 +4,7 @@ import { backendRouter } from '~/modules/backend/backend.router'; import { elevenlabsRouter } from '~/modules/elevenlabs/elevenlabs.router'; import { googleSearchRouter } from '~/modules/google/search.router'; import { llmAnthropicRouter } from '~/modules/llms/transports/server/anthropic/anthropic.router'; +import { llmGeminiRouter } from '~/modules/llms/transports/server/gemini/gemini.router'; import { llmOllamaRouter } from '~/modules/llms/transports/server/ollama/ollama.router'; import { 
llmOpenAIRouter } from '~/modules/llms/transports/server/openai/openai.router'; import { prodiaRouter } from '~/modules/prodia/prodia.router'; @@ -17,6 +18,7 @@ export const appRouterEdge = createTRPCRouter({ elevenlabs: elevenlabsRouter, googleSearch: googleSearchRouter, llmAnthropic: llmAnthropicRouter, + llmGemini: llmGeminiRouter, llmOllama: llmOllamaRouter, llmOpenAI: llmOpenAIRouter, prodia: prodiaRouter, From 044ed4df798db3cccc0ab8cf96265d1f12c70e55 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 03:58:06 -0800 Subject: [PATCH 06/24] Bits for the future --- src/modules/llms/store-llms.ts | 12 ++++++++++++ src/modules/llms/transports/server/server.schemas.ts | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/src/modules/llms/store-llms.ts b/src/modules/llms/store-llms.ts index 9d13b9574..74a3fe097 100644 --- a/src/modules/llms/store-llms.ts +++ b/src/modules/llms/store-llms.ts @@ -17,6 +17,7 @@ export interface DLLM { updated?: number | 0; description: string; tags: string[]; // UNUSED for now + // modelcaps: DModelCapability[]; contextTokens: number; maxOutputTokens: number; hidden: boolean; @@ -31,6 +32,17 @@ export interface DLLM { export type DLLMId = string; +// export type DModelCapability = +// | 'input-text' +// | 'input-image-data' +// | 'input-multipart' +// | 'output-text' +// | 'output-function' +// | 'output-image-data' +// | 'if-chat' +// | 'if-fast-chat' +// ; + // Model interfaces (chat, and function calls) - here as a preview, will be used more broadly in the future export const LLM_IF_OAI_Chat = 'oai-chat'; export const LLM_IF_OAI_Vision = 'oai-vision'; diff --git a/src/modules/llms/transports/server/server.schemas.ts b/src/modules/llms/transports/server/server.schemas.ts index 4614f4ba3..f72313d57 100644 --- a/src/modules/llms/transports/server/server.schemas.ts +++ b/src/modules/llms/transports/server/server.schemas.ts @@ -6,6 +6,10 @@ const pricingSchema = z.object({ cpmCompletion: z.number().optional(), // Cost per thousand completion tokens }); +// const rateLimitsSchema = z.object({ +// reqPerMinute: z.number().optional(), +// }); + const modelDescriptionSchema = z.object({ id: z.string(), label: z.string(), @@ -15,6 +19,7 @@ const modelDescriptionSchema = z.object({ contextWindow: z.number(), maxCompletionTokens: z.number().optional(), pricing: pricingSchema.optional(), + // rateLimits: rateLimitsSchema.optional(), interfaces: z.array(z.enum([LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Complete, LLM_IF_OAI_Vision])), hidden: z.boolean().optional(), }); From 201e3a7252c9bdc1a650000274fba01aa69c43d5 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 14:35:16 -0800 Subject: [PATCH 07/24] Streaming: cleanup --- app/api/llms/stream/route.ts | 2 +- .../openai.streaming.ts => llms.streaming.ts} | 266 ++++++++++-------- src/modules/llms/transports/streamChat.ts | 4 +- 3 files changed, 145 insertions(+), 127 deletions(-) rename src/modules/llms/transports/server/{openai/openai.streaming.ts => llms.streaming.ts} (80%) diff --git a/app/api/llms/stream/route.ts b/app/api/llms/stream/route.ts index c7430d3bc..76794a39e 100644 --- a/app/api/llms/stream/route.ts +++ b/app/api/llms/stream/route.ts @@ -1,2 +1,2 @@ export const runtime = 'edge'; -export { openaiStreamingRelayHandler as POST } from '~/modules/llms/transports/server/openai/openai.streaming'; \ No newline at end of file +export { llmStreamingRelayHandler as POST } from '~/modules/llms/transports/server/llms.streaming'; \ No newline at end of file diff --git 
a/src/modules/llms/transports/server/openai/openai.streaming.ts b/src/modules/llms/transports/server/llms.streaming.ts similarity index 80% rename from src/modules/llms/transports/server/openai/openai.streaming.ts rename to src/modules/llms/transports/server/llms.streaming.ts index e8065bcc4..7383d89f5 100644 --- a/src/modules/llms/transports/server/openai/openai.streaming.ts +++ b/src/modules/llms/transports/server/llms.streaming.ts @@ -4,12 +4,26 @@ import { createParser as createEventsourceParser, EventSourceParseCallback, Even import { createEmptyReadableStream, debugGenerateCurlCommand, safeErrorString, SERVER_DEBUG_WIRE, serverFetchOrThrow } from '~/server/wire'; -import type { AnthropicWire } from '../anthropic/anthropic.wiretypes'; -import type { OpenAIWire } from './openai.wiretypes'; -import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from '../ollama/ollama.router'; -import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from '../anthropic/anthropic.router'; -import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai.router'; -import { wireOllamaChunkedOutputSchema } from '../ollama/ollama.wiretypes'; + +// Anthropic server imports +import type { AnthropicWire } from './anthropic/anthropic.wiretypes'; +import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from './anthropic/anthropic.router'; + +// Ollama server imports +import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes'; +import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from './ollama/ollama.router'; + +// OpenAI server imports +import type { OpenAIWire } from './openai/openai.wiretypes'; +import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai/openai.router'; + + +/** + * Event stream formats + * - 'sse' is the default format, and is used by all vendors except Ollama + * - 'json-nl' is used by Ollama + */ +type EventStreamFormat = 'sse' | 'json-nl'; /** @@ -20,46 +34,49 @@ import { wireOllamaChunkedOutputSchema } from '../ollama/ollama.wiretypes'; * The peculiarity of our parser is the injection of a JSON structure at the beginning of the stream, to * communicate parameters before the text starts flowing to the client. 
*/ -export type AIStreamParser = (data: string) => { text: string, close: boolean }; - -type EventStreamFormat = 'sse' | 'json-nl'; +type AIStreamParser = (data: string) => { text: string, close: boolean }; const chatStreamInputSchema = z.object({ access: z.union([anthropicAccessSchema, ollamaAccessSchema, openAIAccessSchema]), - model: openAIModelSchema, history: openAIHistorySchema, + model: openAIModelSchema, + history: openAIHistorySchema, }); export type ChatStreamInputSchema = z.infer; -const chatStreamFirstPacketSchema = z.object({ +const chatStreamFirstOutputPacketSchema = z.object({ model: z.string(), }); -export type ChatStreamFirstPacketSchema = z.infer; +export type ChatStreamFirstOutputPacketSchema = z.infer; -export async function openaiStreamingRelayHandler(req: NextRequest): Promise { +export async function llmStreamingRelayHandler(req: NextRequest): Promise { // inputs - reuse the tRPC schema - const { access, model, history } = chatStreamInputSchema.parse(await req.json()); + const body = await req.json(); + const { access, model, history } = chatStreamInputSchema.parse(body); - // begin event streaming from the OpenAI API - let headersUrl: { headers: HeadersInit, url: string } = { headers: {}, url: '' }; + // access/dialect dependent setup: + // - requestAccess: the headers and URL to use for the upstream API call + // - eventStreamFormat: the format of the event stream (sse or json-nl) + // - vendorStreamParser: the parser to use for the event stream let upstreamResponse: Response; - let vendorStreamParser: AIStreamParser; + let requestAccess: { headers: HeadersInit, url: string } = { headers: {}, url: '' }; let eventStreamFormat: EventStreamFormat = 'sse'; + let vendorStreamParser: AIStreamParser; try { // prepare the API request data let body: object; switch (access.dialect) { case 'anthropic': - headersUrl = anthropicAccess(access, '/v1/complete'); + requestAccess = anthropicAccess(access, '/v1/complete'); body = anthropicChatCompletionPayload(model, history, true); vendorStreamParser = createAnthropicStreamParser(); break; case 'ollama': - headersUrl = ollamaAccess(access, OLLAMA_PATH_CHAT); + requestAccess = ollamaAccess(access, OLLAMA_PATH_CHAT); body = ollamaChatCompletionPayload(model, history, true); eventStreamFormat = 'json-nl'; vendorStreamParser = createOllamaChatCompletionStreamParser(); @@ -71,27 +88,27 @@ export async function openaiStreamingRelayHandler(req: NextRequest): Promise streaming:', debugGenerateCurlCommand('POST', headersUrl.url, headersUrl.headers, body)); + console.log('-> streaming:', debugGenerateCurlCommand('POST', requestAccess.url, requestAccess.headers, body)); // POST to our API route - upstreamResponse = await serverFetchOrThrow(headersUrl.url, 'POST', headersUrl.headers, body); + upstreamResponse = await serverFetchOrThrow(requestAccess.url, 'POST', requestAccess.headers, body); } catch (error: any) { const fetchOrVendorError = safeErrorString(error) + (error?.cause ? ' ยท ' + error.cause : ''); // server-side admins message - console.error(`/api/llms/stream: fetch issue:`, access.dialect, fetchOrVendorError, headersUrl?.url); + console.error(`/api/llms/stream: fetch issue:`, access.dialect, fetchOrVendorError, requestAccess?.url); // client-side users visible message return new NextResponse(`[Issue] ${access.dialect}: ${fetchOrVendorError}` - + (process.env.NODE_ENV === 'development' ? ` ยท [URL: ${headersUrl?.url}]` : ''), { status: 500 }); + + (process.env.NODE_ENV === 'development' ? 
` ยท [URL: ${requestAccess?.url}]` : ''), { status: 500 }); } /* The following code is heavily inspired by the Vercel AI SDK, but simplified to our needs and in full control. @@ -103,8 +120,9 @@ export async function openaiStreamingRelayHandler(req: NextRequest): Promise { + accumulator += chunk; + if (accumulator.endsWith('\n')) { + for (const jsonString of accumulator.split('\n').filter(line => !!line)) { + const mimicEvent: ParsedEvent = { + type: 'event', + id: undefined, + event: undefined, + data: jsonString, + }; + onParse(mimicEvent); + } + accumulator = ''; + } + }, + + // resets the parser state - not useful with our driving of the parser + reset: (): void => { + console.error('createJsonNewlineParser.reset() not implemented'); + }, + }; +} + +/** + * Creates a TransformStream that parses events from an EventSource stream using a custom parser. + * @returns {TransformStream} TransformStream parsing events. + */ +function createEventStreamTransformer(vendorTextParser: AIStreamParser, inputFormat: EventStreamFormat, dialectLabel: string): TransformStream { + const textDecoder = new TextDecoder(); + const textEncoder = new TextEncoder(); + let eventSourceParser: EventSourceParser; + + return new TransformStream({ + start: async (controller): Promise => { + + // only used for debugging + let debugLastMs: number | null = null; + + const onNewEvent = (event: ParsedEvent | ReconnectInterval) => { + if (SERVER_DEBUG_WIRE) { + const nowMs = Date.now(); + const elapsedMs = debugLastMs ? nowMs - debugLastMs : 0; + debugLastMs = nowMs; + console.log(`<- SSE (${elapsedMs} ms):`, event); + } + + // ignore 'reconnect-interval' and events with no data + if (event.type !== 'event' || !('data' in event)) + return; + + // event stream termination, close our transformed stream + if (event.data === '[DONE]') { + controller.terminate(); + return; + } + + try { + const { text, close } = vendorTextParser(event.data); + if (text) + controller.enqueue(textEncoder.encode(text)); + if (close) + controller.terminate(); + } catch (error: any) { + if (SERVER_DEBUG_WIRE) + console.log(' - E: parse issue:', event.data, error?.message || error); + controller.enqueue(textEncoder.encode(` **[Stream Issue] ${dialectLabel}: ${safeErrorString(error) || 'Unknown stream parsing error'}**`)); + controller.terminate(); + } + }; + + if (inputFormat === 'sse') + eventSourceParser = createEventsourceParser(onNewEvent); + else if (inputFormat === 'json-nl') + eventSourceParser = createJsonNewlineParser(onNewEvent); + }, + + // stream=true is set because the data is not guaranteed to be final and un-chunked + transform: (chunk: Uint8Array) => { + eventSourceParser.feed(textDecoder.decode(chunk, { stream: true })); + }, + }); +} + + +/// Stream Parsers function createAnthropicStreamParser(): AIStreamParser { let hasBegun = false; @@ -128,7 +240,7 @@ function createAnthropicStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun) { hasBegun = true; - const firstPacket: ChatStreamFirstPacketSchema = { model: json.model }; + const firstPacket: ChatStreamFirstOutputPacketSchema = { model: json.model }; text = JSON.stringify(firstPacket) + text; } @@ -164,7 +276,7 @@ function createOllamaChatCompletionStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun && chunk.model) { hasBegun = true; - const firstPacket: ChatStreamFirstPacketSchema = { model: chunk.model }; + const firstPacket: ChatStreamFirstOutputPacketSchema = { model: chunk.model }; text 
= JSON.stringify(firstPacket) + text; } @@ -205,7 +317,7 @@ function createOpenAIStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun) { hasBegun = true; - const firstPacket: ChatStreamFirstPacketSchema = { model: json.model }; + const firstPacket: ChatStreamFirstOutputPacketSchema = { model: json.model }; text = JSON.stringify(firstPacket) + text; } @@ -213,98 +325,4 @@ function createOpenAIStreamParser(): AIStreamParser { const close = !!json.choices[0].finish_reason; return { text, close }; }; -} - - -// Event Stream Transformers - -/** - * Creates a TransformStream that parses events from an EventSource stream using a custom parser. - * @returns {TransformStream} TransformStream parsing events. - */ -function createEventStreamTransformer(vendorTextParser: AIStreamParser, inputFormat: EventStreamFormat, dialectLabel: string): TransformStream { - const textDecoder = new TextDecoder(); - const textEncoder = new TextEncoder(); - let eventSourceParser: EventSourceParser; - - return new TransformStream({ - start: async (controller): Promise => { - - // only used for debugging - let debugLastMs: number | null = null; - - const onNewEvent = (event: ParsedEvent | ReconnectInterval) => { - if (SERVER_DEBUG_WIRE) { - const nowMs = Date.now(); - const elapsedMs = debugLastMs ? nowMs - debugLastMs : 0; - debugLastMs = nowMs; - console.log(`<- SSE (${elapsedMs} ms):`, event); - } - - // ignore 'reconnect-interval' and events with no data - if (event.type !== 'event' || !('data' in event)) - return; - - // event stream termination, close our transformed stream - if (event.data === '[DONE]') { - controller.terminate(); - return; - } - - try { - const { text, close } = vendorTextParser(event.data); - if (text) - controller.enqueue(textEncoder.encode(text)); - if (close) - controller.terminate(); - } catch (error: any) { - if (SERVER_DEBUG_WIRE) - console.log(' - E: parse issue:', event.data, error?.message || error); - controller.enqueue(textEncoder.encode(` **[Stream Issue] ${dialectLabel}: ${safeErrorString(error) || 'Unknown stream parsing error'}**`)); - controller.terminate(); - } - }; - - if (inputFormat === 'sse') - eventSourceParser = createEventsourceParser(onNewEvent); - else if (inputFormat === 'json-nl') - eventSourceParser = createJsonNewlineParser(onNewEvent); - }, - - // stream=true is set because the data is not guaranteed to be final and un-chunked - transform: (chunk: Uint8Array) => { - eventSourceParser.feed(textDecoder.decode(chunk, { stream: true })); - }, - }); -} - -/** - * Creates a parser for a 'JSON\n' non-event stream, to be swapped with an EventSource parser. - * Ollama is the only vendor that uses this format. 
- */ -function createJsonNewlineParser(onParse: EventSourceParseCallback): EventSourceParser { - let accumulator: string = ''; - return { - // feeds a new chunk to the parser - we accumulate in case of partial data, and only execute on full lines - feed: (chunk: string): void => { - accumulator += chunk; - if (accumulator.endsWith('\n')) { - for (const jsonString of accumulator.split('\n').filter(line => !!line)) { - const mimicEvent: ParsedEvent = { - type: 'event', - id: undefined, - event: undefined, - data: jsonString, - }; - onParse(mimicEvent); - } - accumulator = ''; - } - }, - - // resets the parser state - not useful with our driving of the parser - reset: (): void => { - console.error('createJsonNewlineParser.reset() not implemented'); - }, - }; -} +} \ No newline at end of file diff --git a/src/modules/llms/transports/streamChat.ts b/src/modules/llms/transports/streamChat.ts index 4b6159752..332e59ca5 100644 --- a/src/modules/llms/transports/streamChat.ts +++ b/src/modules/llms/transports/streamChat.ts @@ -3,7 +3,7 @@ import { apiAsync } from '~/common/util/trpc.client'; import type { DLLM, DLLMId } from '../store-llms'; import { findVendorForLlmOrThrow } from '../vendors/vendors.registry'; -import type { ChatStreamFirstPacketSchema, ChatStreamInputSchema } from './server/openai/openai.streaming'; +import type { ChatStreamFirstOutputPacketSchema, ChatStreamInputSchema } from './server/llms.streaming'; import type { OpenAIWire } from './server/openai/openai.wiretypes'; import type { VChatMessageIn } from './chatGenerate'; @@ -131,7 +131,7 @@ async function vendorStreamChat( incrementalText = incrementalText.substring(endOfJson + 1); parsedFirstPacket = true; try { - const parsed: ChatStreamFirstPacketSchema = JSON.parse(json); + const parsed: ChatStreamFirstOutputPacketSchema = JSON.parse(json); onUpdate({ originLLM: parsed.model }, false); } catch (e) { // error parsing JSON, ignore From 8e3f247bfb83297118945732ead056204b095fb7 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 14:39:39 -0800 Subject: [PATCH 08/24] Gemini: cleaner --- .../llms/transports/server/gemini/gemini.wiretypes.ts | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/modules/llms/transports/server/gemini/gemini.wiretypes.ts b/src/modules/llms/transports/server/gemini/gemini.wiretypes.ts index 03805ad86..e5aa890ac 100644 --- a/src/modules/llms/transports/server/gemini/gemini.wiretypes.ts +++ b/src/modules/llms/transports/server/gemini/gemini.wiretypes.ts @@ -38,6 +38,8 @@ export const geminiModelsListOutputSchema = z.object({ // /v1/{model=models/*}:generateContent, /v1beta/{model=models/*}:streamGenerateContent +// Request + const geminiContentPartSchema = z.union([ // TextPart @@ -114,8 +116,10 @@ const geminiGenerationConfigSchema = z.object({ }); const geminiContentSchema = z.object({ - parts: z.array(geminiContentPartSchema), // Ordered Parts that constitute a single message. Parts may have different MIME types. - role: z.enum(['user', 'model']).optional(), // Optional. The producer of the content. Must be either 'user' or 'model'. + // Must be either 'user' or 'model'. Optional but must be set if there are multiple "Content" objects in the parent array. + role: z.enum(['user', 'model']).optional(), + // Ordered Parts that constitute a single message. Parts may have different MIME types. 
+ parts: z.array(geminiContentPartSchema), }); export const geminiGenerateContentRequest = z.object({ @@ -128,6 +132,8 @@ export const geminiGenerateContentRequest = z.object({ export type GeminiGenerateContentRequest = z.infer; +// Response + const geminiHarmProbabilitySchema = z.enum([ 'HARM_PROBABILITY_UNSPECIFIED', 'NEGLIGIBLE', From 6b2bfa60600112ac0c73783c1a192e11d3da65f3 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 16:13:15 -0800 Subject: [PATCH 09/24] Llms: cleanup model lists --- .../anthropic/AnthropicSourceSetup.tsx | 17 ++---- .../vendors/anthropic/anthropic.vendor.ts | 15 +++++- .../llms/vendors/azure/AzureSourceSetup.tsx | 17 +++--- .../vendors/googleai/GeminiSourceSetup.tsx | 18 +++---- .../llms/vendors/googleai/gemini.vendor.ts | 20 +++++-- .../vendors/localai/LocalAISourceSetup.tsx | 17 +++--- .../vendors/mistral/MistralSourceSetup.tsx | 17 +++--- .../llms/vendors/ollama/OllamaSourceSetup.tsx | 18 +++---- .../llms/vendors/ollama/ollama.vendor.ts | 17 ++++-- .../oobabooga/OobaboogaSourceSetup.tsx | 17 +++--- .../llms/vendors/openai/OpenAISourceSetup.tsx | 45 ++-------------- .../llms/vendors/openai/openai.vendor.ts | 13 ++++- .../openrouter/OpenRouterSourceSetup.tsx | 17 +++--- .../llms/vendors/useUpdateVendorModels.tsx | 54 +++++++++++++++++++ 14 files changed, 162 insertions(+), 140 deletions(-) create mode 100644 src/modules/llms/vendors/useUpdateVendorModels.tsx diff --git a/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx b/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx index 29dae70a9..87bcbf898 100644 --- a/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx +++ b/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx @@ -7,13 +7,12 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { useToggleableBoolean } from '~/common/util/useToggleableBoolean'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; -import { isValidAnthropicApiKey, ModelVendorAnthropic } from './anthropic.vendor'; +import { anthropicListModelsQuery, isValidAnthropicApiKey, ModelVendorAnthropic } from './anthropic.vendor'; export function AnthropicSourceSetup(props: { sourceId: DModelSourceId }) { @@ -34,14 +33,8 @@ export function AnthropicSourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = anthropicKey ? 
keyValid : (!needsUserKey || !!anthropicHost); // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmAnthropic.listModels.useQuery({ access }, { - enabled: !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useUpdateVendorModels(anthropicListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts index b0654601e..a1f743de7 100644 --- a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts +++ b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts @@ -1,10 +1,11 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { AnthropicIcon } from '~/common/components/icons/AnthropicIcon'; -import { apiAsync } from '~/common/util/trpc.client'; +import { apiAsync, apiQuery } from '~/common/util/trpc.client'; -import type { IModelVendor } from '../IModelVendor'; import type { AnthropicAccessSchema } from '../../transports/server/anthropic/anthropic.router'; +import type { IModelVendor } from '../IModelVendor'; +import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; import { LLMOptionsOpenAI } from '../openai/openai.vendor'; @@ -51,6 +52,16 @@ export const ModelVendorAnthropic: IModelVendor void) { + return apiQuery.llmAnthropic.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); +} + + /** * This function either returns the LLM message, or function calls, or throws a descriptive error string */ diff --git a/src/modules/llms/vendors/azure/AzureSourceSetup.tsx b/src/modules/llms/vendors/azure/AzureSourceSetup.tsx index 7ed3c798c..0c4904424 100644 --- a/src/modules/llms/vendors/azure/AzureSourceSetup.tsx +++ b/src/modules/llms/vendors/azure/AzureSourceSetup.tsx @@ -5,11 +5,12 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { asValidURL } from '~/common/util/urlUtils'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; + +import { openAIListModelsQuery } from '../openai/openai.vendor'; import { isValidAzureApiKey, ModelVendorAzure } from './azure.vendor'; @@ -31,14 +32,8 @@ export function AzureSourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = azureKey ? 
keyValid : !needsUserKey; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useUpdateVendorModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx b/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx index 37e9d3044..f01ba9e1d 100644 --- a/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx +++ b/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx @@ -4,11 +4,11 @@ import { FormInputKey } from '~/common/components/forms/FormInputKey'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { ModelVendorGemini } from './gemini.vendor'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; + +import { geminiListModelsQuery, ModelVendorGemini } from './gemini.vendor'; const GEMINI_API_KEY_LINK = 'https://makersuite.google.com/app/apikey'; @@ -28,14 +28,8 @@ export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) { const showKeyError = !!geminiKey && !sourceSetupValid; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmGemini.listModels.useQuery({ access }, { - enabled: shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useUpdateVendorModels(geminiListModelsQuery, access, shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/googleai/gemini.vendor.ts b/src/modules/llms/vendors/googleai/gemini.vendor.ts index ee809c81f..80a81c038 100644 --- a/src/modules/llms/vendors/googleai/gemini.vendor.ts +++ b/src/modules/llms/vendors/googleai/gemini.vendor.ts @@ -2,13 +2,16 @@ import GoogleIcon from '@mui/icons-material/Google'; import { backendCaps } from '~/modules/backend/state-backend'; +import { apiAsync, apiQuery } from '~/common/util/trpc.client'; + +import type { GeminiAccessSchema } from '../../transports/server/gemini/gemini.router'; import type { IModelVendor } from '../IModelVendor'; +import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; -import type { GeminiAccessSchema } from '../../transports/server/gemini/gemini.router'; -import { GeminiSourceSetup } from './GeminiSourceSetup'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; -import { apiAsync } from '~/common/util/trpc.client'; + +import { GeminiSourceSetup } from './GeminiSourceSetup'; export interface SourceSetupGemini { @@ -58,6 +61,17 @@ export const ModelVendorGemini: IModelVendor void) { + return 
apiQuery.llmGemini.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); +} + + /** * This function either returns the LLM message, or throws a descriptive error string */ diff --git a/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx b/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx index 8afdca950..8f17df8d9 100644 --- a/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx +++ b/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx @@ -7,10 +7,11 @@ import { FormInputKey } from '~/common/components/forms/FormInputKey'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; + +import { openAIListModelsQuery } from '../openai/openai.vendor'; import { ModelVendorLocalAI } from './localai.vendor'; @@ -30,14 +31,8 @@ export function LocalAISourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = isValidHost; // fetch models - the OpenAI way - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: false, // !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useUpdateVendorModels(openAIListModelsQuery, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); return <> diff --git a/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx b/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx index 8cfa57d46..b2d1dd78c 100644 --- a/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx +++ b/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx @@ -4,10 +4,11 @@ import { FormInputKey } from '~/common/components/forms/FormInputKey'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; + +import { openAIListModelsQuery } from '../openai/openai.vendor'; import { ModelVendorMistral } from './mistral.vendor'; @@ -29,14 +30,8 @@ export function MistralSourceSetup(props: { sourceId: DModelSourceId }) { const showKeyError = !!mistralKey && !sourceSetupValid; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, 
isError, error } = + useUpdateVendorModels(openAIListModelsQuery, access, shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx index 3d8b2da2d..1aa4c6201 100644 --- a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx +++ b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx @@ -6,13 +6,13 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { asValidURL } from '~/common/util/urlUtils'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { ModelVendorOllama } from './ollama.vendor'; +import { DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; + +import { ModelVendorOllama, ollamaListModelsQuery } from './ollama.vendor'; import { OllamaAdministration } from './OllamaAdministration'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; export function OllamaSourceSetup(props: { sourceId: DModelSourceId }) { @@ -32,14 +32,8 @@ export function OllamaSourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = !hostError; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOllama.listModels.useQuery({ access }, { - enabled: false, // !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useUpdateVendorModels(ollamaListModelsQuery, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); return <> diff --git a/src/modules/llms/vendors/ollama/ollama.vendor.ts b/src/modules/llms/vendors/ollama/ollama.vendor.ts index 883f5f680..92cc4f1ef 100644 --- a/src/modules/llms/vendors/ollama/ollama.vendor.ts +++ b/src/modules/llms/vendors/ollama/ollama.vendor.ts @@ -1,13 +1,14 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OllamaIcon } from '~/common/components/icons/OllamaIcon'; -import { apiAsync } from '~/common/util/trpc.client'; +import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { IModelVendor } from '../IModelVendor'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; import type { OllamaAccessSchema } from '../../transports/server/ollama/ollama.router'; +import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; -import { LLMOptionsOpenAI } from '../openai/openai.vendor'; +import type { LLMOptionsOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { OllamaSourceSetup } from './OllamaSourceSetup'; @@ -45,6 +46,16 @@ export const ModelVendorOllama: IModelVendor void) { + return apiQuery.llmOllama.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); +} + + /** * This function either returns the LLM message, or 
throws a descriptive error string */ diff --git a/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx b/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx index f9e8ca674..2841774a6 100644 --- a/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx +++ b/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx @@ -6,10 +6,11 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; + +import { openAIListModelsQuery } from '../openai/openai.vendor'; import { ModelVendorOoobabooga } from './oobabooga.vendor'; @@ -24,14 +25,8 @@ export function OobaboogaSourceSetup(props: { sourceId: DModelSourceId }) { const { oaiHost } = access; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: false, // !hasModels && !!asValidURL(normSetup.oaiHost), - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useUpdateVendorModels(openAIListModelsQuery, access, false /* !hasModels && !!asValidURL(normSetup.oaiHost) */, source); return <> diff --git a/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx b/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx index 85d3e1ea4..0a951327d 100644 --- a/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx +++ b/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx @@ -9,13 +9,12 @@ import { FormTextField } from '~/common/components/forms/FormTextField'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { useToggleableBoolean } from '~/common/util/useToggleableBoolean'; -import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; -import { DLLM, DModelSource, DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; +import { DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; -import { isValidOpenAIApiKey, LLMOptionsOpenAI, ModelVendorOpenAI } from './openai.vendor'; +import { isValidOpenAIApiKey, ModelVendorOpenAI, openAIListModelsQuery } from './openai.vendor'; // avoid repeating it all over @@ -40,15 +39,8 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = oaiKey ? 
keyValid : !needsUserKey; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); - + const { isFetching, refetch, isError, error } = + useUpdateVendorModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); return <> @@ -110,30 +102,3 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) { ; } - - -export function modelDescriptionToDLLM(model: ModelDescriptionSchema, source: DModelSource): DLLM { - const maxOutputTokens = model.maxCompletionTokens || Math.round((model.contextWindow || 4096) / 2); - const llmResponseTokens = Math.round(maxOutputTokens / (model.maxCompletionTokens ? 2 : 4)); - return { - id: `${source.id}-${model.id}`, - - label: model.label, - created: model.created || 0, - updated: model.updated || 0, - description: model.description, - tags: [], // ['stream', 'chat'], - contextTokens: model.contextWindow, - maxOutputTokens: maxOutputTokens, - hidden: !!model.hidden, - - sId: source.id, - _source: source, - - options: { - llmRef: model.id, - llmTemperature: 0.5, - llmResponseTokens: llmResponseTokens, - }, - }; -} \ No newline at end of file diff --git a/src/modules/llms/vendors/openai/openai.vendor.ts b/src/modules/llms/vendors/openai/openai.vendor.ts index f7eaeb92b..d0e9e4f1a 100644 --- a/src/modules/llms/vendors/openai/openai.vendor.ts +++ b/src/modules/llms/vendors/openai/openai.vendor.ts @@ -1,9 +1,10 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OpenAIIcon } from '~/common/components/icons/OpenAIIcon'; -import { apiAsync } from '~/common/util/trpc.client'; +import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { IModelVendor } from '../IModelVendor'; +import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; @@ -62,6 +63,16 @@ export const ModelVendorOpenAI: IModelVendor void) { + return apiQuery.llmOpenAI.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); +} + + /** * This function either returns the LLM message, or function calls, or throws a descriptive error string */ diff --git a/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx b/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx index 470dffb35..a22eabff9 100644 --- a/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx +++ b/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx @@ -6,11 +6,12 @@ import { FormInputKey } from '~/common/components/forms/FormInputKey'; import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { apiQuery } from '~/common/util/trpc.client'; import { getCallbackUrl } from '~/common/app.routes'; -import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; -import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; +import { 
DModelSourceId, useSourceSetup } from '../../store-llms'; +import { useUpdateVendorModels } from '../useUpdateVendorModels'; + +import { openAIListModelsQuery } from '../openai/openai.vendor'; import { isValidOpenRouterKey, ModelVendorOpenRouter } from './openrouter.vendor'; @@ -30,14 +31,8 @@ export function OpenRouterSourceSetup(props: { sourceId: DModelSourceId }) { const shallFetchSucceed = oaiKey ? keyValid : !needsUserKey; // fetch models - const { isFetching, refetch, isError, error } = apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: !sourceHasLLMs && shallFetchSucceed, - onSuccess: models => source && useModelsStore.getState().setLLMs( - models.models.map(model => modelDescriptionToDLLM(model, source)), - props.sourceId, - ), - staleTime: Infinity, - }); + const { isFetching, refetch, isError, error } = + useUpdateVendorModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); const handleOpenRouterLogin = () => { diff --git a/src/modules/llms/vendors/useUpdateVendorModels.tsx b/src/modules/llms/vendors/useUpdateVendorModels.tsx new file mode 100644 index 000000000..e0d51b9c7 --- /dev/null +++ b/src/modules/llms/vendors/useUpdateVendorModels.tsx @@ -0,0 +1,54 @@ +import type { TRPCClientErrorBase } from '@trpc/client'; + +import type { ModelDescriptionSchema } from '../transports/server/server.schemas'; + +import { DLLM, DModelSource, useModelsStore } from '../store-llms'; + + +export type IModelVendorListModelsFn = + (access: TAccess, enabled: boolean, onSuccess: (data: { models: ModelDescriptionSchema[] }) => void) => + { isFetching: boolean, refetch: () => void, isError: boolean, error: TRPCClientErrorBase | null }; + + +/** + * Hook that fetches the list of models from the vendor and updates the store, + * while returning the fetch state. + */ +export function useUpdateVendorModels(listFn: IModelVendorListModelsFn, access: TAccess, enabled: boolean, source: DModelSource) { + return listFn(access, enabled, data => source && updateModelsFn(data, source)); +} + + +function updateModelsFn(data: { models: ModelDescriptionSchema[] }, source: DModelSource) { + useModelsStore.getState().setLLMs( + data.models.map(model => modelDescriptionToDLLMOpenAIOptions(model, source)), + source.id, + ); +} + +function modelDescriptionToDLLMOpenAIOptions(model: ModelDescriptionSchema, source: DModelSource): DLLM { + const maxOutputTokens = model.maxCompletionTokens || Math.round((model.contextWindow || 4096) / 2); + const llmResponseTokens = Math.round(maxOutputTokens / (model.maxCompletionTokens ? 
2 : 4)); + return { + id: `${source.id}-${model.id}`, + + label: model.label, + created: model.created || 0, + updated: model.updated || 0, + description: model.description, + tags: [], // ['stream', 'chat'], + contextTokens: model.contextWindow, + maxOutputTokens: maxOutputTokens, + hidden: !!model.hidden, + + sId: source.id, + _source: source, + + options: { + llmRef: model.id, + // @ts-ignore FIXME: large assumption that this is LLMOptionsOpenAI object + llmTemperature: 0.5, + llmResponseTokens: llmResponseTokens, + }, + }; +} \ No newline at end of file From 49c77f5a10718b7aa25aa6ba6a6e4b21a8992254 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 16:14:25 -0800 Subject: [PATCH 10/24] Llms: cleanup model lists (bits) --- src/modules/llms/vendors/useUpdateVendorModels.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modules/llms/vendors/useUpdateVendorModels.tsx b/src/modules/llms/vendors/useUpdateVendorModels.tsx index e0d51b9c7..06699998d 100644 --- a/src/modules/llms/vendors/useUpdateVendorModels.tsx +++ b/src/modules/llms/vendors/useUpdateVendorModels.tsx @@ -5,7 +5,7 @@ import type { ModelDescriptionSchema } from '../transports/server/server.schemas import { DLLM, DModelSource, useModelsStore } from '../store-llms'; -export type IModelVendorListModelsFn = +export type IModelVendorUpdateModelsQuery = (access: TAccess, enabled: boolean, onSuccess: (data: { models: ModelDescriptionSchema[] }) => void) => { isFetching: boolean, refetch: () => void, isError: boolean, error: TRPCClientErrorBase | null }; @@ -14,7 +14,7 @@ export type IModelVendorListModelsFn = * Hook that fetches the list of models from the vendor and updates the store, * while returning the fetch state. */ -export function useUpdateVendorModels(listFn: IModelVendorListModelsFn, access: TAccess, enabled: boolean, source: DModelSource) { +export function useUpdateVendorModels(listFn: IModelVendorUpdateModelsQuery, access: TAccess, enabled: boolean, source: DModelSource) { return listFn(access, enabled, data => source && updateModelsFn(data, source)); } From 3f9defd18c458e03c879aa6af4d2b01ceff430b1 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 16:34:53 -0800 Subject: [PATCH 11/24] Llms: restructure --- app/api/llms/stream/route.ts | 2 +- docs/config-local-localai.md | 2 +- src/apps/call/CallUI.tsx | 6 ++-- src/apps/call/components/CallMessage.tsx | 2 +- src/apps/chat/editors/chat-stream.ts | 4 +-- src/apps/personas/useLLMChain.ts | 5 +-- .../aifn/autosuggestions/autoSuggestions.ts | 7 ++-- src/modules/aifn/autotitle/autoTitle.ts | 4 +-- src/modules/aifn/digrams/DiagramsModal.tsx | 5 +-- src/modules/aifn/digrams/diagrams.data.ts | 2 +- .../aifn/imagine/imaginePromptFromText.ts | 4 +-- src/modules/aifn/react/react.ts | 5 +-- src/modules/aifn/summarize/summerize.ts | 4 +-- src/modules/aifn/useStreamChatText.ts | 6 ++-- src/modules/llms/client/llm.client.types.ts | 20 +++++++++++ src/modules/llms/client/llmChatGenerate.ts | 15 ++++++++ .../llmStreamChatGenerate.ts} | 9 ++--- .../useLlmUpdateModels.tsx} | 4 +-- .../server/anthropic/anthropic.models.ts | 2 +- .../server/anthropic/anthropic.router.ts | 0 .../server/anthropic/anthropic.wiretypes.ts | 0 .../server/gemini/gemini.router.ts | 0 .../server/gemini/gemini.wiretypes.ts | 0 .../{transports => }/server/llms.streaming.ts | 0 .../server/ollama/ollama.models.ts | 0 .../server/ollama/ollama.router.ts | 2 +- .../server/ollama/ollama.wiretypes.ts | 0 .../server/openai/mistral.wiretypes.ts | 0 
.../server/openai/models.data.ts | 2 +- .../server/openai/openai.router.ts | 0 .../server/openai/openai.wiretypes.ts | 0 .../{transports => }/server/server.schemas.ts | 2 +- src/modules/llms/transports/chatGenerate.ts | 34 ------------------- src/modules/llms/vendors/IModelVendor.ts | 2 +- .../anthropic/AnthropicSourceSetup.tsx | 4 +-- .../vendors/anthropic/anthropic.vendor.ts | 6 ++-- .../llms/vendors/azure/AzureSourceSetup.tsx | 4 +-- .../llms/vendors/azure/azure.vendor.ts | 4 +-- .../vendors/googleai/GeminiSourceSetup.tsx | 4 +-- .../llms/vendors/googleai/gemini.vendor.ts | 6 ++-- .../vendors/localai/LocalAISourceSetup.tsx | 4 +-- .../llms/vendors/localai/localai.vendor.ts | 4 +-- .../vendors/mistral/MistralSourceSetup.tsx | 4 +-- .../llms/vendors/mistral/mistral.vendor.ts | 4 +-- .../vendors/ollama/OllamaAdministration.tsx | 2 +- .../llms/vendors/ollama/OllamaSourceSetup.tsx | 4 +-- .../llms/vendors/ollama/ollama.vendor.ts | 6 ++-- .../oobabooga/OobaboogaSourceSetup.tsx | 4 +-- .../vendors/oobabooga/oobabooga.vendor.ts | 4 +-- .../llms/vendors/openai/OpenAISourceSetup.tsx | 4 +-- .../llms/vendors/openai/openai.vendor.ts | 6 ++-- .../openrouter/OpenRouterSourceSetup.tsx | 4 +-- .../vendors/openrouter/openrouter.vendor.ts | 4 +-- src/server/api/trpc.router-edge.ts | 8 ++--- 54 files changed, 123 insertions(+), 117 deletions(-) create mode 100644 src/modules/llms/client/llm.client.types.ts create mode 100644 src/modules/llms/client/llmChatGenerate.ts rename src/modules/llms/{transports/streamChat.ts => client/llmStreamChatGenerate.ts} (96%) rename src/modules/llms/{vendors/useUpdateVendorModels.tsx => client/useLlmUpdateModels.tsx} (87%) rename src/modules/llms/{transports => }/server/anthropic/anthropic.models.ts (97%) rename src/modules/llms/{transports => }/server/anthropic/anthropic.router.ts (100%) rename src/modules/llms/{transports => }/server/anthropic/anthropic.wiretypes.ts (100%) rename src/modules/llms/{transports => }/server/gemini/gemini.router.ts (100%) rename src/modules/llms/{transports => }/server/gemini/gemini.wiretypes.ts (100%) rename src/modules/llms/{transports => }/server/llms.streaming.ts (100%) rename src/modules/llms/{transports => }/server/ollama/ollama.models.ts (100%) rename src/modules/llms/{transports => }/server/ollama/ollama.router.ts (99%) rename src/modules/llms/{transports => }/server/ollama/ollama.wiretypes.ts (100%) rename src/modules/llms/{transports => }/server/openai/mistral.wiretypes.ts (100%) rename src/modules/llms/{transports => }/server/openai/models.data.ts (99%) rename src/modules/llms/{transports => }/server/openai/openai.router.ts (100%) rename src/modules/llms/{transports => }/server/openai/openai.wiretypes.ts (100%) rename src/modules/llms/{transports => }/server/server.schemas.ts (95%) delete mode 100644 src/modules/llms/transports/chatGenerate.ts diff --git a/app/api/llms/stream/route.ts b/app/api/llms/stream/route.ts index 76794a39e..b0873a013 100644 --- a/app/api/llms/stream/route.ts +++ b/app/api/llms/stream/route.ts @@ -1,2 +1,2 @@ export const runtime = 'edge'; -export { llmStreamingRelayHandler as POST } from '~/modules/llms/transports/server/llms.streaming'; \ No newline at end of file +export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llms.streaming'; \ No newline at end of file diff --git a/docs/config-local-localai.md b/docs/config-local-localai.md index 4cc2bf4fa..43a9d1aa5 100644 --- a/docs/config-local-localai.md +++ b/docs/config-local-localai.md @@ -30,5 +30,5 @@ For instance with [Use 
luna-ai-llama2 with docker compose](https://localai.io/ba > NOTE: LocalAI does not list details about the mdoels. Every model is assumed to be > capable of chatting, and with a context window of 4096 tokens. -> Please update the [src/modules/llms/transports/server/openai/models.data.ts](../src/modules/llms/transports/server/openai/models.data.ts) +> Please update the [src/modules/llms/transports/server/openai/models.data.ts](../src/modules/llms/server/openai/models.data.ts) > file with the mapping information between LocalAI model IDs and names/descriptions/tokens, etc. diff --git a/src/apps/call/CallUI.tsx b/src/apps/call/CallUI.tsx index 6ff0efde1..a118e3266 100644 --- a/src/apps/call/CallUI.tsx +++ b/src/apps/call/CallUI.tsx @@ -13,10 +13,10 @@ import RecordVoiceOverIcon from '@mui/icons-material/RecordVoiceOver'; import { useChatLLMDropdown } from '../chat/components/applayout/useLLMDropdown'; +import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; import { EXPERIMENTAL_speakTextStream } from '~/modules/elevenlabs/elevenlabs.client'; import { SystemPurposeId, SystemPurposes } from '../../data'; -import { VChatMessageIn } from '~/modules/llms/transports/chatGenerate'; -import { streamChat } from '~/modules/llms/transports/streamChat'; +import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate'; import { useElevenLabsVoiceDropdown } from '~/modules/elevenlabs/useElevenLabsVoiceDropdown'; import { Link } from '~/common/components/Link'; @@ -216,7 +216,7 @@ export function CallUI(props: { responseAbortController.current = new AbortController(); let finalText = ''; let error: any | null = null; - streamChat(chatLLMId, callPrompt, responseAbortController.current.signal, (updatedMessage: Partial) => { + llmStreamChatGenerate(chatLLMId, callPrompt, responseAbortController.current.signal, (updatedMessage: Partial) => { const text = updatedMessage.text?.trim(); if (text) { finalText = text; diff --git a/src/apps/call/components/CallMessage.tsx b/src/apps/call/components/CallMessage.tsx index ae67ef141..525a586c2 100644 --- a/src/apps/call/components/CallMessage.tsx +++ b/src/apps/call/components/CallMessage.tsx @@ -3,7 +3,7 @@ import * as React from 'react'; import { Chip, ColorPaletteProp, VariantProp } from '@mui/joy'; import { SxProps } from '@mui/joy/styles/types'; -import { VChatMessageIn } from '~/modules/llms/transports/chatGenerate'; +import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; export function CallMessage(props: { diff --git a/src/apps/chat/editors/chat-stream.ts b/src/apps/chat/editors/chat-stream.ts index 090fc05ff..b8dd4e3f0 100644 --- a/src/apps/chat/editors/chat-stream.ts +++ b/src/apps/chat/editors/chat-stream.ts @@ -2,8 +2,8 @@ import { DLLMId } from '~/modules/llms/store-llms'; import { SystemPurposeId } from '../../../data'; import { autoSuggestions } from '~/modules/aifn/autosuggestions/autoSuggestions'; import { autoTitle } from '~/modules/aifn/autotitle/autoTitle'; +import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate'; import { speakText } from '~/modules/elevenlabs/elevenlabs.client'; -import { streamChat } from '~/modules/llms/transports/streamChat'; import { DMessage, useChatStore } from '~/common/state/store-chats'; @@ -63,7 +63,7 @@ async function streamAssistantMessage( const messages = history.map(({ role, text }) => ({ role, content: text })); try { - await streamChat(llmId, messages, abortSignal, + await llmStreamChatGenerate(llmId, messages, abortSignal, 
(updatedMessage: Partial) => { // update the message in the store (and thus schedule a re-render) editMessage(updatedMessage); diff --git a/src/apps/personas/useLLMChain.ts b/src/apps/personas/useLLMChain.ts index 1eab05494..99d0cd4a4 100644 --- a/src/apps/personas/useLLMChain.ts +++ b/src/apps/personas/useLLMChain.ts @@ -1,7 +1,8 @@ import * as React from 'react'; +import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; import { DLLMId, useModelsStore } from '~/modules/llms/store-llms'; -import { callChatGenerate, VChatMessageIn } from '~/modules/llms/transports/chatGenerate'; +import { llmChatGenerate } from '~/modules/llms/client/llmChatGenerate'; export interface LLMChainStep { @@ -80,7 +81,7 @@ export function useLLMChain(steps: LLMChainStep[], llmId: DLLMId | undefined, ch _chainAbortController.signal.addEventListener('abort', globalToStepListener); // LLM call - callChatGenerate(llmId, llmChatInput, chain.overrideResponseTokens) + llmChatGenerate(llmId, llmChatInput, chain.overrideResponseTokens) .then(({ content }) => { stepDone = true; if (!stepAbortController.signal.aborted) diff --git a/src/modules/aifn/autosuggestions/autoSuggestions.ts b/src/modules/aifn/autosuggestions/autoSuggestions.ts index 28097a0c8..8cbd73065 100644 --- a/src/modules/aifn/autosuggestions/autoSuggestions.ts +++ b/src/modules/aifn/autosuggestions/autoSuggestions.ts @@ -1,4 +1,5 @@ -import { callChatGenerateWithFunctions, VChatFunctionIn } from '~/modules/llms/transports/chatGenerate'; +import type { VChatFunctionIn } from '~/modules/llms/client/llm.client.types'; +import { llmChatGenerateWithFunctions } from '~/modules/llms/client/llmChatGenerate'; import { useModelsStore } from '~/modules/llms/store-llms'; import { useChatStore } from '~/common/state/store-chats'; @@ -71,7 +72,7 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri // Follow-up: Question if (suggestQuestions) { - // callChatGenerateWithFunctions(funcLLMId, [ + // llmChatGenerateWithFunctions(funcLLMId, [ // { role: 'system', content: systemMessage.text }, // { role: 'user', content: userMessage.text }, // { role: 'assistant', content: assistantMessageText }, @@ -83,7 +84,7 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri // Follow-up: Auto-Diagrams if (suggestDiagrams) { - void callChatGenerateWithFunctions(funcLLMId, [ + void llmChatGenerateWithFunctions(funcLLMId, [ { role: 'system', content: systemMessage.text }, { role: 'user', content: userMessage.text }, { role: 'assistant', content: assistantMessageText }, diff --git a/src/modules/aifn/autotitle/autoTitle.ts b/src/modules/aifn/autotitle/autoTitle.ts index 2e29771fa..5b99bad14 100644 --- a/src/modules/aifn/autotitle/autoTitle.ts +++ b/src/modules/aifn/autotitle/autoTitle.ts @@ -1,4 +1,4 @@ -import { callChatGenerate } from '~/modules/llms/transports/chatGenerate'; +import { llmChatGenerate } from '~/modules/llms/client/llmChatGenerate'; import { useModelsStore } from '~/modules/llms/store-llms'; import { useChatStore } from '~/common/state/store-chats'; @@ -27,7 +27,7 @@ export function autoTitle(conversationId: string) { }); // LLM - void callChatGenerate(fastLLMId, [ + void llmChatGenerate(fastLLMId, [ { role: 'system', content: `You are an AI conversation titles assistant who specializes in creating expressive yet few-words chat titles.` }, { role: 'user', content: diff --git a/src/modules/aifn/digrams/DiagramsModal.tsx b/src/modules/aifn/digrams/DiagramsModal.tsx index 68429a128..9206957fc 
100644 --- a/src/modules/aifn/digrams/DiagramsModal.tsx +++ b/src/modules/aifn/digrams/DiagramsModal.tsx @@ -8,8 +8,9 @@ import ReplayIcon from '@mui/icons-material/Replay'; import StopOutlinedIcon from '@mui/icons-material/StopOutlined'; import TelegramIcon from '@mui/icons-material/Telegram'; +import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate'; + import { ChatMessage } from '../../../apps/chat/components/message/ChatMessage'; -import { streamChat } from '~/modules/llms/transports/streamChat'; import { GoodModal } from '~/common/components/GoodModal'; import { InlineError } from '~/common/components/InlineError'; @@ -85,7 +86,7 @@ export function DiagramsModal(props: { config: DiagramConfig, onClose: () => voi const diagramPrompt = bigDiagramPrompt(diagramType, diagramLanguage, systemMessage.text, subject, customInstruction); try { - await streamChat(diagramLlm.id, diagramPrompt, stepAbortController.signal, + await llmStreamChatGenerate(diagramLlm.id, diagramPrompt, stepAbortController.signal, (update: Partial<{ text: string, typing: boolean, originLLM: string }>) => { assistantMessage = { ...assistantMessage, ...update }; setMessage(assistantMessage); diff --git a/src/modules/aifn/digrams/diagrams.data.ts b/src/modules/aifn/digrams/diagrams.data.ts index 54239118e..8a4b675de 100644 --- a/src/modules/aifn/digrams/diagrams.data.ts +++ b/src/modules/aifn/digrams/diagrams.data.ts @@ -1,4 +1,4 @@ -import type { VChatMessageIn } from '~/modules/llms/transports/chatGenerate'; +import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; import type { FormRadioOption } from '~/common/components/forms/FormRadioControl'; diff --git a/src/modules/aifn/imagine/imaginePromptFromText.ts b/src/modules/aifn/imagine/imaginePromptFromText.ts index 211ac7abd..936bd6995 100644 --- a/src/modules/aifn/imagine/imaginePromptFromText.ts +++ b/src/modules/aifn/imagine/imaginePromptFromText.ts @@ -1,4 +1,4 @@ -import { callChatGenerate } from '~/modules/llms/transports/chatGenerate'; +import { llmChatGenerate } from '~/modules/llms/client/llmChatGenerate'; import { useModelsStore } from '~/modules/llms/store-llms'; @@ -14,7 +14,7 @@ export async function imaginePromptFromText(messageText: string): Promise { + await llmStreamChatGenerate(llmId, prompt, abortControllerRef.current.signal, (update) => { if (update.text) { lastText = update.text; setPartialText(lastText); diff --git a/src/modules/llms/client/llm.client.types.ts b/src/modules/llms/client/llm.client.types.ts new file mode 100644 index 000000000..5d4fa3be9 --- /dev/null +++ b/src/modules/llms/client/llm.client.types.ts @@ -0,0 +1,20 @@ +import type { OpenAIWire } from '~/modules/llms/server/openai/openai.wiretypes'; + +export interface VChatMessageIn { + role: 'assistant' | 'system' | 'user'; // | 'function'; + content: string; + //name?: string; // when role: 'function' +} + +export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef; + +export interface VChatMessageOut { + role: 'assistant' | 'system' | 'user'; + content: string; + finish_reason: 'stop' | 'length' | null; +} + +export interface VChatMessageOrFunctionCallOut extends VChatMessageOut { + function_name: string; + function_arguments: object | null; +} \ No newline at end of file diff --git a/src/modules/llms/client/llmChatGenerate.ts b/src/modules/llms/client/llmChatGenerate.ts new file mode 100644 index 000000000..c2cc1a211 --- /dev/null +++ b/src/modules/llms/client/llmChatGenerate.ts @@ -0,0 +1,15 @@ +import type { DLLMId } 
from '../store-llms'; +import { findVendorForLlmOrThrow } from '../vendors/vendors.registry'; + +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from './llm.client.types'; + + +export async function llmChatGenerate(llmId: DLLMId, messages: VChatMessageIn[], maxTokens?: number): Promise { + const { llm, vendor } = findVendorForLlmOrThrow(llmId); + return await vendor.callChatGenerate(llm, messages, maxTokens); +} + +export async function llmChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], forceFunctionName: string | null, maxTokens?: number): Promise { + const { llm, vendor } = findVendorForLlmOrThrow(llmId); + return await vendor.callChatGenerateWF(llm, messages, functions, forceFunctionName, maxTokens); +} \ No newline at end of file diff --git a/src/modules/llms/transports/streamChat.ts b/src/modules/llms/client/llmStreamChatGenerate.ts similarity index 96% rename from src/modules/llms/transports/streamChat.ts rename to src/modules/llms/client/llmStreamChatGenerate.ts index 332e59ca5..73e0ac866 100644 --- a/src/modules/llms/transports/streamChat.ts +++ b/src/modules/llms/client/llmStreamChatGenerate.ts @@ -1,11 +1,12 @@ import { apiAsync } from '~/common/util/trpc.client'; +import type { ChatStreamFirstOutputPacketSchema, ChatStreamInputSchema } from '../server/llms.streaming'; import type { DLLM, DLLMId } from '../store-llms'; import { findVendorForLlmOrThrow } from '../vendors/vendors.registry'; -import type { ChatStreamFirstOutputPacketSchema, ChatStreamInputSchema } from './server/llms.streaming'; -import type { OpenAIWire } from './server/openai/openai.wiretypes'; -import type { VChatMessageIn } from './chatGenerate'; +import type { OpenAIWire } from '../server/openai/openai.wiretypes'; + +import type { VChatMessageIn } from './llm.client.types'; /** @@ -20,7 +21,7 @@ import type { VChatMessageIn } from './chatGenerate'; * @param abortSignal used to initiate a client-side abort of the fetch request to the API endpoint * @param onUpdate callback when a piece of a message (text, model name, typing..) is received */ -export async function streamChat( +export async function llmStreamChatGenerate( llmId: DLLMId, messages: VChatMessageIn[], abortSignal: AbortSignal, diff --git a/src/modules/llms/vendors/useUpdateVendorModels.tsx b/src/modules/llms/client/useLlmUpdateModels.tsx similarity index 87% rename from src/modules/llms/vendors/useUpdateVendorModels.tsx rename to src/modules/llms/client/useLlmUpdateModels.tsx index 06699998d..0aa70f399 100644 --- a/src/modules/llms/vendors/useUpdateVendorModels.tsx +++ b/src/modules/llms/client/useLlmUpdateModels.tsx @@ -1,6 +1,6 @@ import type { TRPCClientErrorBase } from '@trpc/client'; -import type { ModelDescriptionSchema } from '../transports/server/server.schemas'; +import type { ModelDescriptionSchema } from '../server/server.schemas'; import { DLLM, DModelSource, useModelsStore } from '../store-llms'; @@ -14,7 +14,7 @@ export type IModelVendorUpdateModelsQuery = * Hook that fetches the list of models from the vendor and updates the store, * while returning the fetch state. 
*/ -export function useUpdateVendorModels(listFn: IModelVendorUpdateModelsQuery, access: TAccess, enabled: boolean, source: DModelSource) { +export function useLlmUpdateModels(listFn: IModelVendorUpdateModelsQuery, access: TAccess, enabled: boolean, source: DModelSource) { return listFn(access, enabled, data => source && updateModelsFn(data, source)); } diff --git a/src/modules/llms/transports/server/anthropic/anthropic.models.ts b/src/modules/llms/server/anthropic/anthropic.models.ts similarity index 97% rename from src/modules/llms/transports/server/anthropic/anthropic.models.ts rename to src/modules/llms/server/anthropic/anthropic.models.ts index eb4e4117a..3a0a4e0cc 100644 --- a/src/modules/llms/transports/server/anthropic/anthropic.models.ts +++ b/src/modules/llms/server/anthropic/anthropic.models.ts @@ -1,6 +1,6 @@ import type { ModelDescriptionSchema } from '../server.schemas'; -import { LLM_IF_OAI_Chat } from '../../../store-llms'; +import { LLM_IF_OAI_Chat } from '../../store-llms'; const roundTime = (date: string) => Math.round(new Date(date).getTime() / 1000); diff --git a/src/modules/llms/transports/server/anthropic/anthropic.router.ts b/src/modules/llms/server/anthropic/anthropic.router.ts similarity index 100% rename from src/modules/llms/transports/server/anthropic/anthropic.router.ts rename to src/modules/llms/server/anthropic/anthropic.router.ts diff --git a/src/modules/llms/transports/server/anthropic/anthropic.wiretypes.ts b/src/modules/llms/server/anthropic/anthropic.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/anthropic/anthropic.wiretypes.ts rename to src/modules/llms/server/anthropic/anthropic.wiretypes.ts diff --git a/src/modules/llms/transports/server/gemini/gemini.router.ts b/src/modules/llms/server/gemini/gemini.router.ts similarity index 100% rename from src/modules/llms/transports/server/gemini/gemini.router.ts rename to src/modules/llms/server/gemini/gemini.router.ts diff --git a/src/modules/llms/transports/server/gemini/gemini.wiretypes.ts b/src/modules/llms/server/gemini/gemini.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/gemini/gemini.wiretypes.ts rename to src/modules/llms/server/gemini/gemini.wiretypes.ts diff --git a/src/modules/llms/transports/server/llms.streaming.ts b/src/modules/llms/server/llms.streaming.ts similarity index 100% rename from src/modules/llms/transports/server/llms.streaming.ts rename to src/modules/llms/server/llms.streaming.ts diff --git a/src/modules/llms/transports/server/ollama/ollama.models.ts b/src/modules/llms/server/ollama/ollama.models.ts similarity index 100% rename from src/modules/llms/transports/server/ollama/ollama.models.ts rename to src/modules/llms/server/ollama/ollama.models.ts diff --git a/src/modules/llms/transports/server/ollama/ollama.router.ts b/src/modules/llms/server/ollama/ollama.router.ts similarity index 99% rename from src/modules/llms/transports/server/ollama/ollama.router.ts rename to src/modules/llms/server/ollama/ollama.router.ts index 20d89d62f..91c19634d 100644 --- a/src/modules/llms/transports/server/ollama/ollama.router.ts +++ b/src/modules/llms/server/ollama/ollama.router.ts @@ -5,7 +5,7 @@ import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server'; import { env } from '~/server/env.mjs'; import { fetchJsonOrTRPCError, fetchTextOrTRPCError } from '~/server/api/trpc.serverutils'; -import { LLM_IF_OAI_Chat } from '../../../store-llms'; +import { LLM_IF_OAI_Chat } from '../../store-llms'; import { 
capitalizeFirstLetter } from '~/common/util/textUtils'; diff --git a/src/modules/llms/transports/server/ollama/ollama.wiretypes.ts b/src/modules/llms/server/ollama/ollama.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/ollama/ollama.wiretypes.ts rename to src/modules/llms/server/ollama/ollama.wiretypes.ts diff --git a/src/modules/llms/transports/server/openai/mistral.wiretypes.ts b/src/modules/llms/server/openai/mistral.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/openai/mistral.wiretypes.ts rename to src/modules/llms/server/openai/mistral.wiretypes.ts diff --git a/src/modules/llms/transports/server/openai/models.data.ts b/src/modules/llms/server/openai/models.data.ts similarity index 99% rename from src/modules/llms/transports/server/openai/models.data.ts rename to src/modules/llms/server/openai/models.data.ts index cc20e574f..aad475367 100644 --- a/src/modules/llms/transports/server/openai/models.data.ts +++ b/src/modules/llms/server/openai/models.data.ts @@ -1,6 +1,6 @@ import { SERVER_DEBUG_WIRE } from '~/server/wire'; -import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../../store-llms'; +import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../store-llms'; import type { ModelDescriptionSchema } from '../server.schemas'; import { wireMistralModelsListOutputSchema } from './mistral.wiretypes'; diff --git a/src/modules/llms/transports/server/openai/openai.router.ts b/src/modules/llms/server/openai/openai.router.ts similarity index 100% rename from src/modules/llms/transports/server/openai/openai.router.ts rename to src/modules/llms/server/openai/openai.router.ts diff --git a/src/modules/llms/transports/server/openai/openai.wiretypes.ts b/src/modules/llms/server/openai/openai.wiretypes.ts similarity index 100% rename from src/modules/llms/transports/server/openai/openai.wiretypes.ts rename to src/modules/llms/server/openai/openai.wiretypes.ts diff --git a/src/modules/llms/transports/server/server.schemas.ts b/src/modules/llms/server/server.schemas.ts similarity index 95% rename from src/modules/llms/transports/server/server.schemas.ts rename to src/modules/llms/server/server.schemas.ts index f72313d57..e04b44e3d 100644 --- a/src/modules/llms/transports/server/server.schemas.ts +++ b/src/modules/llms/server/server.schemas.ts @@ -1,5 +1,5 @@ import { z } from 'zod'; -import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../store-llms'; +import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../store-llms'; const pricingSchema = z.object({ cpmPrompt: z.number().optional(), // Cost per thousand prompt tokens diff --git a/src/modules/llms/transports/chatGenerate.ts b/src/modules/llms/transports/chatGenerate.ts deleted file mode 100644 index 20ff4ba69..000000000 --- a/src/modules/llms/transports/chatGenerate.ts +++ /dev/null @@ -1,34 +0,0 @@ -import type { DLLMId } from '../store-llms'; -import type { OpenAIWire } from './server/openai/openai.wiretypes'; -import { findVendorForLlmOrThrow } from '../vendors/vendors.registry'; - - -export interface VChatMessageIn { - role: 'assistant' | 'system' | 'user'; // | 'function'; - content: string; - //name?: string; // when role: 'function' -} - -export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef; - -export interface VChatMessageOut { - role: 'assistant' | 'system' | 'user'; - content: string; - finish_reason: 
'stop' | 'length' | null; -} - -export interface VChatMessageOrFunctionCallOut extends VChatMessageOut { - function_name: string; - function_arguments: object | null; -} - - -export async function callChatGenerate(llmId: DLLMId, messages: VChatMessageIn[], maxTokens?: number): Promise { - const { llm, vendor } = findVendorForLlmOrThrow(llmId); - return await vendor.callChatGenerate(llm, messages, maxTokens); -} - -export async function callChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], forceFunctionName: string | null, maxTokens?: number): Promise { - const { llm, vendor } = findVendorForLlmOrThrow(llmId); - return await vendor.callChatGenerateWF(llm, messages, functions, forceFunctionName, maxTokens); -} \ No newline at end of file diff --git a/src/modules/llms/vendors/IModelVendor.ts b/src/modules/llms/vendors/IModelVendor.ts index a29a1e0b7..4830158b1 100644 --- a/src/modules/llms/vendors/IModelVendor.ts +++ b/src/modules/llms/vendors/IModelVendor.ts @@ -2,7 +2,7 @@ import type React from 'react'; import type { DLLM, DModelSourceId } from '../store-llms'; import type { ModelVendorId } from './vendors.registry'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../transports/chatGenerate'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../client/llm.client.types'; export interface IModelVendor> { diff --git a/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx b/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx index 87bcbf898..46a650ccd 100644 --- a/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx +++ b/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx @@ -10,7 +10,7 @@ import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefet import { useToggleableBoolean } from '~/common/util/useToggleableBoolean'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { anthropicListModelsQuery, isValidAnthropicApiKey, ModelVendorAnthropic } from './anthropic.vendor'; @@ -34,7 +34,7 @@ export function AnthropicSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useUpdateVendorModels(anthropicListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); + useLlmUpdateModels(anthropicListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts index a1f743de7..625fcb166 100644 --- a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts +++ b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts @@ -3,10 +3,10 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { AnthropicIcon } from '~/common/components/icons/AnthropicIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; -import type { AnthropicAccessSchema } from '../../transports/server/anthropic/anthropic.router'; +import type { AnthropicAccessSchema } from '../../server/anthropic/anthropic.router'; import type { IModelVendor } from '../IModelVendor'; -import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from 
'../../transports/chatGenerate'; +import type { ModelDescriptionSchema } from '../../server/server.schemas'; +import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; diff --git a/src/modules/llms/vendors/azure/AzureSourceSetup.tsx b/src/modules/llms/vendors/azure/AzureSourceSetup.tsx index 0c4904424..f0a50d0ba 100644 --- a/src/modules/llms/vendors/azure/AzureSourceSetup.tsx +++ b/src/modules/llms/vendors/azure/AzureSourceSetup.tsx @@ -8,7 +8,7 @@ import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefet import { asValidURL } from '~/common/util/urlUtils'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { openAIListModelsQuery } from '../openai/openai.vendor'; @@ -33,7 +33,7 @@ export function AzureSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useUpdateVendorModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); + useLlmUpdateModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/azure/azure.vendor.ts b/src/modules/llms/vendors/azure/azure.vendor.ts index a7b2b6734..028f94df1 100644 --- a/src/modules/llms/vendors/azure/azure.vendor.ts +++ b/src/modules/llms/vendors/azure/azure.vendor.ts @@ -3,8 +3,8 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { AzureIcon } from '~/common/components/icons/AzureIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '~/modules/llms/server/openai/openai.router'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; diff --git a/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx b/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx index f01ba9e1d..357e0572b 100644 --- a/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx +++ b/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx @@ -6,7 +6,7 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { geminiListModelsQuery, ModelVendorGemini } from './gemini.vendor'; @@ -29,7 +29,7 @@ export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useUpdateVendorModels(geminiListModelsQuery, access, shallFetchSucceed, source); + useLlmUpdateModels(geminiListModelsQuery, access, shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/googleai/gemini.vendor.ts 
b/src/modules/llms/vendors/googleai/gemini.vendor.ts index 80a81c038..826737496 100644 --- a/src/modules/llms/vendors/googleai/gemini.vendor.ts +++ b/src/modules/llms/vendors/googleai/gemini.vendor.ts @@ -4,10 +4,10 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; -import type { GeminiAccessSchema } from '../../transports/server/gemini/gemini.router'; +import type { GeminiAccessSchema } from '../../server/gemini/gemini.router'; import type { IModelVendor } from '../IModelVendor'; -import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { ModelDescriptionSchema } from '../../server/server.schemas'; +import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; diff --git a/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx b/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx index 8f17df8d9..6bdce7073 100644 --- a/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx +++ b/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx @@ -9,7 +9,7 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { openAIListModelsQuery } from '../openai/openai.vendor'; @@ -32,7 +32,7 @@ export function LocalAISourceSetup(props: { sourceId: DModelSourceId }) { // fetch models - the OpenAI way const { isFetching, refetch, isError, error } = - useUpdateVendorModels(openAIListModelsQuery, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); + useLlmUpdateModels(openAIListModelsQuery, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); return <> diff --git a/src/modules/llms/vendors/localai/localai.vendor.ts b/src/modules/llms/vendors/localai/localai.vendor.ts index 7d58c7d42..e39d5245b 100644 --- a/src/modules/llms/vendors/localai/localai.vendor.ts +++ b/src/modules/llms/vendors/localai/localai.vendor.ts @@ -1,8 +1,8 @@ import DevicesIcon from '@mui/icons-material/Devices'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; diff --git a/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx b/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx index b2d1dd78c..95f563c6f 100644 --- a/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx +++ b/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx @@ -6,7 +6,7 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from 
'~/common/components/forms/SetupFormRefetchButton'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { openAIListModelsQuery } from '../openai/openai.vendor'; @@ -31,7 +31,7 @@ export function MistralSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useUpdateVendorModels(openAIListModelsQuery, access, shallFetchSucceed, source); + useLlmUpdateModels(openAIListModelsQuery, access, shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/mistral/mistral.vendor.ts b/src/modules/llms/vendors/mistral/mistral.vendor.ts index 5ae500a07..cdcb0c3bc 100644 --- a/src/modules/llms/vendors/mistral/mistral.vendor.ts +++ b/src/modules/llms/vendors/mistral/mistral.vendor.ts @@ -3,8 +3,8 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { MistralIcon } from '~/common/components/icons/MistralIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatMessageIn, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; +import type { VChatMessageIn, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI, openAICallChatGenerate, SourceSetupOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; diff --git a/src/modules/llms/vendors/ollama/OllamaAdministration.tsx b/src/modules/llms/vendors/ollama/OllamaAdministration.tsx index 9d2aebdef..2c0aa4b94 100644 --- a/src/modules/llms/vendors/ollama/OllamaAdministration.tsx +++ b/src/modules/llms/vendors/ollama/OllamaAdministration.tsx @@ -12,7 +12,7 @@ import { Link } from '~/common/components/Link'; import { apiQuery } from '~/common/util/trpc.client'; import { settingsGap } from '~/common/app.theme'; -import type { OllamaAccessSchema } from '../../transports/server/ollama/ollama.router'; +import type { OllamaAccessSchema } from '../../server/ollama/ollama.router'; export function OllamaAdministration(props: { access: OllamaAccessSchema, onClose: () => void }) { diff --git a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx index 1aa4c6201..f66cc58cd 100644 --- a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx +++ b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx @@ -9,7 +9,7 @@ import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefet import { asValidURL } from '~/common/util/urlUtils'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { ModelVendorOllama, ollamaListModelsQuery } from './ollama.vendor'; import { OllamaAdministration } from './OllamaAdministration'; @@ -33,7 +33,7 @@ export function OllamaSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useUpdateVendorModels(ollamaListModelsQuery, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); + useLlmUpdateModels(ollamaListModelsQuery, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); return <> diff --git 
a/src/modules/llms/vendors/ollama/ollama.vendor.ts b/src/modules/llms/vendors/ollama/ollama.vendor.ts index 92cc4f1ef..98cd34259 100644 --- a/src/modules/llms/vendors/ollama/ollama.vendor.ts +++ b/src/modules/llms/vendors/ollama/ollama.vendor.ts @@ -4,9 +4,9 @@ import { OllamaIcon } from '~/common/components/icons/OllamaIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { IModelVendor } from '../IModelVendor'; -import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; -import type { OllamaAccessSchema } from '../../transports/server/ollama/ollama.router'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { ModelDescriptionSchema } from '../../server/server.schemas'; +import type { OllamaAccessSchema } from '../../server/ollama/ollama.router'; +import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import type { LLMOptionsOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; diff --git a/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx b/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx index 2841774a6..5e2c125c1 100644 --- a/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx +++ b/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx @@ -8,7 +8,7 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { openAIListModelsQuery } from '../openai/openai.vendor'; @@ -26,7 +26,7 @@ export function OobaboogaSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useUpdateVendorModels(openAIListModelsQuery, access, false /* !hasModels && !!asValidURL(normSetup.oaiHost) */, source); + useLlmUpdateModels(openAIListModelsQuery, access, false /* !hasModels && !!asValidURL(normSetup.oaiHost) */, source); return <> diff --git a/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts b/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts index b72827981..a16dee2e3 100644 --- a/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts +++ b/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts @@ -1,8 +1,8 @@ import { OobaboogaIcon } from '~/common/components/icons/OobaboogaIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; diff --git a/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx b/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx index 0a951327d..6d4b48b30 100644 --- a/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx +++ 
b/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx @@ -12,7 +12,7 @@ import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefet import { useToggleableBoolean } from '~/common/util/useToggleableBoolean'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { isValidOpenAIApiKey, ModelVendorOpenAI, openAIListModelsQuery } from './openai.vendor'; @@ -40,7 +40,7 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useUpdateVendorModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); + useLlmUpdateModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/openai/openai.vendor.ts b/src/modules/llms/vendors/openai/openai.vendor.ts index d0e9e4f1a..1e8d99296 100644 --- a/src/modules/llms/vendors/openai/openai.vendor.ts +++ b/src/modules/llms/vendors/openai/openai.vendor.ts @@ -4,9 +4,9 @@ import { OpenAIIcon } from '~/common/components/icons/OpenAIIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { IModelVendor } from '../IModelVendor'; -import type { ModelDescriptionSchema } from '../../transports/server/server.schemas'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { ModelDescriptionSchema } from '../../server/server.schemas'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { OpenAILLMOptions } from './OpenAILLMOptions'; import { OpenAISourceSetup } from './OpenAISourceSetup'; diff --git a/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx b/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx index a22eabff9..8f4f05427 100644 --- a/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx +++ b/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx @@ -9,7 +9,7 @@ import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefet import { getCallbackUrl } from '~/common/app.routes'; import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useUpdateVendorModels } from '../useUpdateVendorModels'; +import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; import { openAIListModelsQuery } from '../openai/openai.vendor'; @@ -32,7 +32,7 @@ export function OpenRouterSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useUpdateVendorModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); + useLlmUpdateModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); const handleOpenRouterLogin = () => { diff --git a/src/modules/llms/vendors/openrouter/openrouter.vendor.ts b/src/modules/llms/vendors/openrouter/openrouter.vendor.ts index 98a0ed156..a54fb414d 100644 --- a/src/modules/llms/vendors/openrouter/openrouter.vendor.ts +++ b/src/modules/llms/vendors/openrouter/openrouter.vendor.ts @@ -3,8 +3,8 @@ import { backendCaps } from 
'~/modules/backend/state-backend'; import { OpenRouterIcon } from '~/common/components/icons/OpenRouterIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '../../transports/server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OpenAIAccessSchema } from '~/modules/llms/server/openai/openai.router'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; diff --git a/src/server/api/trpc.router-edge.ts b/src/server/api/trpc.router-edge.ts index c9513f71f..24dc33e6f 100644 --- a/src/server/api/trpc.router-edge.ts +++ b/src/server/api/trpc.router-edge.ts @@ -3,10 +3,10 @@ import { createTRPCRouter } from './trpc.server'; import { backendRouter } from '~/modules/backend/backend.router'; import { elevenlabsRouter } from '~/modules/elevenlabs/elevenlabs.router'; import { googleSearchRouter } from '~/modules/google/search.router'; -import { llmAnthropicRouter } from '~/modules/llms/transports/server/anthropic/anthropic.router'; -import { llmGeminiRouter } from '~/modules/llms/transports/server/gemini/gemini.router'; -import { llmOllamaRouter } from '~/modules/llms/transports/server/ollama/ollama.router'; -import { llmOpenAIRouter } from '~/modules/llms/transports/server/openai/openai.router'; +import { llmAnthropicRouter } from '~/modules/llms/server/anthropic/anthropic.router'; +import { llmGeminiRouter } from '~/modules/llms/server/gemini/gemini.router'; +import { llmOllamaRouter } from '~/modules/llms/server/ollama/ollama.router'; +import { llmOpenAIRouter } from '~/modules/llms/server/openai/openai.router'; import { prodiaRouter } from '~/modules/prodia/prodia.router'; import { ytPersonaRouter } from '../../apps/personas/ytpersona.router'; From dd41a402d0c539e0a5c5dadfb1ab7841515a96db Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 16:40:22 -0800 Subject: [PATCH 12/24] Llms: move models modal --- src/common/layout/AppLayout.tsx | 2 +- src/modules/llms/client/llm.client.types.ts | 2 +- .../llms}/models-modal/LLMOptionsModal.tsx | 0 src/{apps => modules/llms}/models-modal/ModelsList.tsx | 0 src/{apps => modules/llms}/models-modal/ModelsModal.tsx | 0 .../llms}/models-modal/ModelsSourceSelector.tsx | 0 src/modules/llms/server/gemini/gemini.router.ts | 7 ++++--- src/modules/llms/vendors/azure/azure.vendor.ts | 2 +- src/modules/llms/vendors/openrouter/openrouter.vendor.ts | 2 +- src/modules/llms/vendors/vendors.registry.ts | 2 +- 10 files changed, 9 insertions(+), 8 deletions(-) rename src/{apps => modules/llms}/models-modal/LLMOptionsModal.tsx (100%) rename src/{apps => modules/llms}/models-modal/ModelsList.tsx (100%) rename src/{apps => modules/llms}/models-modal/ModelsModal.tsx (100%) rename src/{apps => modules/llms}/models-modal/ModelsSourceSelector.tsx (100%) diff --git a/src/common/layout/AppLayout.tsx b/src/common/layout/AppLayout.tsx index 46e52373b..e38a110c9 100644 --- a/src/common/layout/AppLayout.tsx +++ b/src/common/layout/AppLayout.tsx @@ -3,7 +3,7 @@ import { shallow } from 'zustand/shallow'; import { Box, Container } from '@mui/joy'; -import { ModelsModal } from '../../apps/models-modal/ModelsModal'; +import { ModelsModal } from '~/modules/llms/models-modal/ModelsModal'; 
import { SettingsModal } from '../../apps/settings-modal/SettingsModal'; import { ShortcutsModal } from '../../apps/settings-modal/ShortcutsModal'; diff --git a/src/modules/llms/client/llm.client.types.ts b/src/modules/llms/client/llm.client.types.ts index 5d4fa3be9..4e00ac4ad 100644 --- a/src/modules/llms/client/llm.client.types.ts +++ b/src/modules/llms/client/llm.client.types.ts @@ -1,4 +1,4 @@ -import type { OpenAIWire } from '~/modules/llms/server/openai/openai.wiretypes'; +import type { OpenAIWire } from '../server/openai/openai.wiretypes'; export interface VChatMessageIn { role: 'assistant' | 'system' | 'user'; // | 'function'; diff --git a/src/apps/models-modal/LLMOptionsModal.tsx b/src/modules/llms/models-modal/LLMOptionsModal.tsx similarity index 100% rename from src/apps/models-modal/LLMOptionsModal.tsx rename to src/modules/llms/models-modal/LLMOptionsModal.tsx diff --git a/src/apps/models-modal/ModelsList.tsx b/src/modules/llms/models-modal/ModelsList.tsx similarity index 100% rename from src/apps/models-modal/ModelsList.tsx rename to src/modules/llms/models-modal/ModelsList.tsx diff --git a/src/apps/models-modal/ModelsModal.tsx b/src/modules/llms/models-modal/ModelsModal.tsx similarity index 100% rename from src/apps/models-modal/ModelsModal.tsx rename to src/modules/llms/models-modal/ModelsModal.tsx diff --git a/src/apps/models-modal/ModelsSourceSelector.tsx b/src/modules/llms/models-modal/ModelsSourceSelector.tsx similarity index 100% rename from src/apps/models-modal/ModelsSourceSelector.tsx rename to src/modules/llms/models-modal/ModelsSourceSelector.tsx diff --git a/src/modules/llms/server/gemini/gemini.router.ts b/src/modules/llms/server/gemini/gemini.router.ts index 389ba2be8..52f9d43fd 100644 --- a/src/modules/llms/server/gemini/gemini.router.ts +++ b/src/modules/llms/server/gemini/gemini.router.ts @@ -4,11 +4,12 @@ import { TRPCError } from '@trpc/server'; import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server'; import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; -import { LLM_IF_OAI_Chat, LLM_IF_OAI_Vision } from '~/modules/llms/store-llms'; +import { LLM_IF_OAI_Chat, LLM_IF_OAI_Vision } from '../../store-llms'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; -import { GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes'; import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; -import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; + +import { GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes'; // Default hosts diff --git a/src/modules/llms/vendors/azure/azure.vendor.ts b/src/modules/llms/vendors/azure/azure.vendor.ts index 028f94df1..e6752105b 100644 --- a/src/modules/llms/vendors/azure/azure.vendor.ts +++ b/src/modules/llms/vendors/azure/azure.vendor.ts @@ -3,7 +3,7 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { AzureIcon } from '~/common/components/icons/AzureIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '~/modules/llms/server/openai/openai.router'; +import type { OpenAIAccessSchema } from 
'../../server/openai/openai.router'; import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; diff --git a/src/modules/llms/vendors/openrouter/openrouter.vendor.ts b/src/modules/llms/vendors/openrouter/openrouter.vendor.ts index a54fb414d..66c6937ce 100644 --- a/src/modules/llms/vendors/openrouter/openrouter.vendor.ts +++ b/src/modules/llms/vendors/openrouter/openrouter.vendor.ts @@ -3,7 +3,7 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OpenRouterIcon } from '~/common/components/icons/OpenRouterIcon'; import type { IModelVendor } from '../IModelVendor'; -import type { OpenAIAccessSchema } from '~/modules/llms/server/openai/openai.router'; +import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; diff --git a/src/modules/llms/vendors/vendors.registry.ts b/src/modules/llms/vendors/vendors.registry.ts index ac0f223f4..8dcc6c53d 100644 --- a/src/modules/llms/vendors/vendors.registry.ts +++ b/src/modules/llms/vendors/vendors.registry.ts @@ -1,6 +1,6 @@ import { ModelVendorAnthropic } from './anthropic/anthropic.vendor'; import { ModelVendorAzure } from './azure/azure.vendor'; -import { ModelVendorGemini } from '~/modules/llms/vendors/googleai/gemini.vendor'; +import { ModelVendorGemini } from './googleai/gemini.vendor'; import { ModelVendorLocalAI } from './localai/localai.vendor'; import { ModelVendorMistral } from './mistral/mistral.vendor'; import { ModelVendorOllama } from './ollama/ollama.vendor'; From fd897b55b209a4eafaf6b6f11b04c8ce42b7e270 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 17:01:41 -0800 Subject: [PATCH 13/24] Llms: improve list generics --- src/modules/llms/client/llm.client.types.ts | 7 +++++++ src/modules/llms/client/useLlmUpdateModels.tsx | 11 ++--------- src/modules/llms/server/anthropic/anthropic.models.ts | 2 +- src/modules/llms/server/anthropic/anthropic.router.ts | 2 +- src/modules/llms/server/gemini/gemini.router.ts | 2 +- .../server/{server.schemas.ts => llm.server.types.ts} | 5 +++++ src/modules/llms/server/ollama/ollama.router.ts | 2 +- src/modules/llms/server/openai/models.data.ts | 2 +- src/modules/llms/server/openai/openai.router.ts | 2 +- src/modules/llms/vendors/IModelVendor.ts | 9 ++++++++- .../llms/vendors/anthropic/anthropic.vendor.ts | 8 +++----- src/modules/llms/vendors/googleai/gemini.vendor.ts | 8 +++----- src/modules/llms/vendors/ollama/ollama.vendor.ts | 8 +++----- src/modules/llms/vendors/openai/openai.vendor.ts | 8 +++----- 14 files changed, 40 insertions(+), 36 deletions(-) rename src/modules/llms/server/{server.schemas.ts => llm.server.types.ts} (91%) diff --git a/src/modules/llms/client/llm.client.types.ts b/src/modules/llms/client/llm.client.types.ts index 4e00ac4ad..3732e64a9 100644 --- a/src/modules/llms/client/llm.client.types.ts +++ b/src/modules/llms/client/llm.client.types.ts @@ -1,5 +1,12 @@ import type { OpenAIWire } from '../server/openai/openai.wiretypes'; + +// Model List types +// export { type ModelDescriptionSchema } from '../server/llm.server.types'; + + +// Chat Generate types + export interface VChatMessageIn { role: 'assistant' | 'system' | 'user'; // | 'function'; content: string; 
diff --git a/src/modules/llms/client/useLlmUpdateModels.tsx b/src/modules/llms/client/useLlmUpdateModels.tsx index 0aa70f399..ae6ac513a 100644 --- a/src/modules/llms/client/useLlmUpdateModels.tsx +++ b/src/modules/llms/client/useLlmUpdateModels.tsx @@ -1,15 +1,8 @@ -import type { TRPCClientErrorBase } from '@trpc/client'; - -import type { ModelDescriptionSchema } from '../server/server.schemas'; - +import type { IModelVendorUpdateModelsQuery } from '../vendors/IModelVendor'; +import type { ModelDescriptionSchema } from '../server/llm.server.types'; import { DLLM, DModelSource, useModelsStore } from '../store-llms'; -export type IModelVendorUpdateModelsQuery = - (access: TAccess, enabled: boolean, onSuccess: (data: { models: ModelDescriptionSchema[] }) => void) => - { isFetching: boolean, refetch: () => void, isError: boolean, error: TRPCClientErrorBase | null }; - - /** * Hook that fetches the list of models from the vendor and updates the store, * while returning the fetch state. diff --git a/src/modules/llms/server/anthropic/anthropic.models.ts b/src/modules/llms/server/anthropic/anthropic.models.ts index 3a0a4e0cc..6bbbfc55c 100644 --- a/src/modules/llms/server/anthropic/anthropic.models.ts +++ b/src/modules/llms/server/anthropic/anthropic.models.ts @@ -1,4 +1,4 @@ -import type { ModelDescriptionSchema } from '../server.schemas'; +import type { ModelDescriptionSchema } from '../llm.server.types'; import { LLM_IF_OAI_Chat } from '../../store-llms'; diff --git a/src/modules/llms/server/anthropic/anthropic.router.ts b/src/modules/llms/server/anthropic/anthropic.router.ts index 4433d2740..2ceb003b7 100644 --- a/src/modules/llms/server/anthropic/anthropic.router.ts +++ b/src/modules/llms/server/anthropic/anthropic.router.ts @@ -6,7 +6,7 @@ import { env } from '~/server/env.mjs'; import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; -import { listModelsOutputSchema } from '../server.schemas'; +import { listModelsOutputSchema } from '../llm.server.types'; import { AnthropicWire } from './anthropic.wiretypes'; import { hardcodedAnthropicModels } from './anthropic.models'; diff --git a/src/modules/llms/server/gemini/gemini.router.ts b/src/modules/llms/server/gemini/gemini.router.ts index 52f9d43fd..fde91180d 100644 --- a/src/modules/llms/server/gemini/gemini.router.ts +++ b/src/modules/llms/server/gemini/gemini.router.ts @@ -5,7 +5,7 @@ import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server'; import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; import { LLM_IF_OAI_Chat, LLM_IF_OAI_Vision } from '../../store-llms'; -import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types'; import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; diff --git a/src/modules/llms/server/server.schemas.ts b/src/modules/llms/server/llm.server.types.ts similarity index 91% rename from src/modules/llms/server/server.schemas.ts rename to src/modules/llms/server/llm.server.types.ts index e04b44e3d..15575c624 100644 --- a/src/modules/llms/server/server.schemas.ts +++ b/src/modules/llms/server/llm.server.types.ts @@ -1,6 +1,9 @@ import { z } from 'zod'; import { LLM_IF_OAI_Chat, 
LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../store-llms'; + +// Model Description: a superset of LLM model descriptors + const pricingSchema = z.object({ cpmPrompt: z.number().optional(), // Cost per thousand prompt tokens cpmCompletion: z.number().optional(), // Cost per thousand completion tokens @@ -23,6 +26,8 @@ const modelDescriptionSchema = z.object({ interfaces: z.array(z.enum([LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Complete, LLM_IF_OAI_Vision])), hidden: z.boolean().optional(), }); + +// this is also used by the Client export type ModelDescriptionSchema = z.infer; export const listModelsOutputSchema = z.object({ diff --git a/src/modules/llms/server/ollama/ollama.router.ts b/src/modules/llms/server/ollama/ollama.router.ts index 91c19634d..954e798ad 100644 --- a/src/modules/llms/server/ollama/ollama.router.ts +++ b/src/modules/llms/server/ollama/ollama.router.ts @@ -10,7 +10,7 @@ import { LLM_IF_OAI_Chat } from '../../store-llms'; import { capitalizeFirstLetter } from '~/common/util/textUtils'; import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; -import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types'; import { OLLAMA_BASE_MODELS, OLLAMA_PREV_UPDATE } from './ollama.models'; import { WireOllamaChatCompletionInput, wireOllamaChunkedOutputSchema } from './ollama.wiretypes'; diff --git a/src/modules/llms/server/openai/models.data.ts b/src/modules/llms/server/openai/models.data.ts index aad475367..28dd5e59f 100644 --- a/src/modules/llms/server/openai/models.data.ts +++ b/src/modules/llms/server/openai/models.data.ts @@ -2,7 +2,7 @@ import { SERVER_DEBUG_WIRE } from '~/server/wire'; import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Vision } from '../../store-llms'; -import type { ModelDescriptionSchema } from '../server.schemas'; +import type { ModelDescriptionSchema } from '../llm.server.types'; import { wireMistralModelsListOutputSchema } from './mistral.wiretypes'; diff --git a/src/modules/llms/server/openai/openai.router.ts b/src/modules/llms/server/openai/openai.router.ts index 7b903d6ee..93c64ee77 100644 --- a/src/modules/llms/server/openai/openai.router.ts +++ b/src/modules/llms/server/openai/openai.router.ts @@ -8,7 +8,7 @@ import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; import { Brand } from '~/common/app.config'; import type { OpenAIWire } from './openai.wiretypes'; -import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types'; import { localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription } from './models.data'; diff --git a/src/modules/llms/vendors/IModelVendor.ts b/src/modules/llms/vendors/IModelVendor.ts index 4830158b1..a9dd5e123 100644 --- a/src/modules/llms/vendors/IModelVendor.ts +++ b/src/modules/llms/vendors/IModelVendor.ts @@ -1,6 +1,8 @@ import type React from 'react'; +import type { TRPCClientErrorBase } from '@trpc/client'; import type { DLLM, DModelSourceId } from '../store-llms'; +import type { ModelDescriptionSchema } from '../server/llm.server.types'; import type { ModelVendorId } from './vendors.registry'; 
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../client/llm.client.types'; @@ -30,4 +32,9 @@ export interface IModelVendor; callChatGenerateWF(llm: TDLLM, messages: VChatMessageIn[], functions: null | VChatFunctionIn[], forceFunctionName: null | string, maxTokens?: number): Promise; -} \ No newline at end of file +} + + +export type IModelVendorUpdateModelsQuery = + (access: TAccess, enabled: boolean, onSuccess: (data: { models: ModelDescriptionSchema[] }) => void) => + { isFetching: boolean, refetch: () => void, isError: boolean, error: TRPCClientErrorBase | null }; diff --git a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts index 625fcb166..645caee5a 100644 --- a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts +++ b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts @@ -4,8 +4,7 @@ import { AnthropicIcon } from '~/common/components/icons/AnthropicIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { AnthropicAccessSchema } from '../../server/anthropic/anthropic.router'; -import type { IModelVendor } from '../IModelVendor'; -import type { ModelDescriptionSchema } from '../../server/server.schemas'; +import type { IModelVendor, IModelVendorUpdateModelsQuery } from '../IModelVendor'; import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI } from '../openai/openai.vendor'; @@ -52,14 +51,13 @@ export const ModelVendorAnthropic: IModelVendor void) { - return apiQuery.llmAnthropic.listModels.useQuery({ access }, { +export const anthropicListModelsQuery: IModelVendorUpdateModelsQuery = (access, enabled, onSuccess) => + apiQuery.llmAnthropic.listModels.useQuery({ access }, { enabled: enabled, onSuccess: onSuccess, refetchOnWindowFocus: false, staleTime: Infinity, }); -} /** diff --git a/src/modules/llms/vendors/googleai/gemini.vendor.ts b/src/modules/llms/vendors/googleai/gemini.vendor.ts index 826737496..6753668bd 100644 --- a/src/modules/llms/vendors/googleai/gemini.vendor.ts +++ b/src/modules/llms/vendors/googleai/gemini.vendor.ts @@ -5,8 +5,7 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { GeminiAccessSchema } from '../../server/gemini/gemini.router'; -import type { IModelVendor } from '../IModelVendor'; -import type { ModelDescriptionSchema } from '../../server/server.schemas'; +import type { IModelVendor, IModelVendorUpdateModelsQuery } from '../IModelVendor'; import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; @@ -62,14 +61,13 @@ export const ModelVendorGemini: IModelVendor void) { - return apiQuery.llmGemini.listModels.useQuery({ access }, { +export const geminiListModelsQuery: IModelVendorUpdateModelsQuery = (access, enabled, onSuccess) => + apiQuery.llmGemini.listModels.useQuery({ access }, { enabled: enabled, onSuccess: onSuccess, refetchOnWindowFocus: false, staleTime: Infinity, }); -} /** diff --git a/src/modules/llms/vendors/ollama/ollama.vendor.ts b/src/modules/llms/vendors/ollama/ollama.vendor.ts index 98cd34259..37e6088ec 100644 --- a/src/modules/llms/vendors/ollama/ollama.vendor.ts +++ b/src/modules/llms/vendors/ollama/ollama.vendor.ts @@ -3,8 +3,7 @@ import { backendCaps } from '~/modules/backend/state-backend'; 
import { OllamaIcon } from '~/common/components/icons/OllamaIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; -import type { IModelVendor } from '../IModelVendor'; -import type { ModelDescriptionSchema } from '../../server/server.schemas'; +import type { IModelVendor, IModelVendorUpdateModelsQuery } from '../IModelVendor'; import type { OllamaAccessSchema } from '../../server/ollama/ollama.router'; import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; @@ -46,14 +45,13 @@ export const ModelVendorOllama: IModelVendor void) { - return apiQuery.llmOllama.listModels.useQuery({ access }, { +export const ollamaListModelsQuery: IModelVendorUpdateModelsQuery = (access, enabled, onSuccess) => + apiQuery.llmOllama.listModels.useQuery({ access }, { enabled: enabled, onSuccess: onSuccess, refetchOnWindowFocus: false, staleTime: Infinity, }); -} /** diff --git a/src/modules/llms/vendors/openai/openai.vendor.ts b/src/modules/llms/vendors/openai/openai.vendor.ts index 1e8d99296..d229bfee2 100644 --- a/src/modules/llms/vendors/openai/openai.vendor.ts +++ b/src/modules/llms/vendors/openai/openai.vendor.ts @@ -3,8 +3,7 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OpenAIIcon } from '~/common/components/icons/OpenAIIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; -import type { IModelVendor } from '../IModelVendor'; -import type { ModelDescriptionSchema } from '../../server/server.schemas'; +import type { IModelVendor, IModelVendorUpdateModelsQuery } from '../IModelVendor'; import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; @@ -63,14 +62,13 @@ export const ModelVendorOpenAI: IModelVendor void) { - return apiQuery.llmOpenAI.listModels.useQuery({ access }, { +export const openAIListModelsQuery: IModelVendorUpdateModelsQuery = (access, enabled, onSuccess) => + apiQuery.llmOpenAI.listModels.useQuery({ access }, { enabled: enabled, onSuccess: onSuccess, refetchOnWindowFocus: false, staleTime: Infinity, }); -} /** From 0ece1ce58cb836be21e4192bf5fb925b6e0f1758 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 18:18:51 -0800 Subject: [PATCH 14/24] Llms: vendor-specific RPC to ChatGenerate --- src/apps/personas/useLLMChain.ts | 4 +- .../aifn/autosuggestions/autoSuggestions.ts | 13 ++- src/modules/aifn/autotitle/autoTitle.ts | 6 +- .../aifn/imagine/imaginePromptFromText.ts | 6 +- src/modules/aifn/react/react.ts | 4 +- src/modules/aifn/summarize/summerize.ts | 6 +- src/modules/llms/client/llmChatGenerate.ts | 24 ++++-- src/modules/llms/store-llms.ts | 30 ------- src/modules/llms/vendors/IModelVendor.ts | 20 +++-- .../anthropic/AnthropicSourceSetup.tsx | 9 ++- .../vendors/anthropic/anthropic.vendor.ts | 75 ++++++++---------- .../llms/vendors/azure/AzureSourceSetup.tsx | 9 +-- .../llms/vendors/azure/azure.vendor.ts | 13 ++- .../vendors/googleai/GeminiSourceSetup.tsx | 9 ++- .../llms/vendors/googleai/gemini.vendor.ts | 76 ++++++++---------- .../vendors/localai/LocalAISourceSetup.tsx | 9 +-- .../llms/vendors/localai/localai.vendor.ts | 15 ++-- .../vendors/mistral/MistralSourceSetup.tsx | 9 +-- .../llms/vendors/mistral/mistral.vendor.ts | 13 ++- .../llms/vendors/ollama/OllamaSourceSetup.tsx | 9 ++- .../llms/vendors/ollama/ollama.vendor.ts | 73 ++++++++--------- .../oobabooga/OobaboogaSourceSetup.tsx | 9 +-- 
.../vendors/oobabooga/oobabooga.vendor.ts | 13 ++- .../llms/vendors/openai/OpenAISourceSetup.tsx | 9 ++- .../llms/vendors/openai/openai.vendor.ts | 79 ++++++++----------- .../openrouter/OpenRouterSourceSetup.tsx | 9 +-- .../vendors/openrouter/openrouter.vendor.ts | 13 ++- .../useLlmUpdateModels.tsx | 6 +- src/modules/llms/vendors/useSourceSetup.ts | 35 ++++++++ src/modules/llms/vendors/vendors.registry.ts | 29 ++++--- 30 files changed, 309 insertions(+), 325 deletions(-) rename src/modules/llms/{client => vendors}/useLlmUpdateModels.tsx (80%) create mode 100644 src/modules/llms/vendors/useSourceSetup.ts diff --git a/src/apps/personas/useLLMChain.ts b/src/apps/personas/useLLMChain.ts index 99d0cd4a4..45f0be785 100644 --- a/src/apps/personas/useLLMChain.ts +++ b/src/apps/personas/useLLMChain.ts @@ -2,7 +2,7 @@ import * as React from 'react'; import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; import { DLLMId, useModelsStore } from '~/modules/llms/store-llms'; -import { llmChatGenerate } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; export interface LLMChainStep { @@ -81,7 +81,7 @@ export function useLLMChain(steps: LLMChainStep[], llmId: DLLMId | undefined, ch _chainAbortController.signal.addEventListener('abort', globalToStepListener); // LLM call - llmChatGenerate(llmId, llmChatInput, chain.overrideResponseTokens) + llmChatGenerateOrThrow(llmId, llmChatInput, null, null, chain.overrideResponseTokens) .then(({ content }) => { stepDone = true; if (!stepAbortController.signal.aborted) diff --git a/src/modules/aifn/autosuggestions/autoSuggestions.ts b/src/modules/aifn/autosuggestions/autoSuggestions.ts index 8cbd73065..8a0bbb9c5 100644 --- a/src/modules/aifn/autosuggestions/autoSuggestions.ts +++ b/src/modules/aifn/autosuggestions/autoSuggestions.ts @@ -1,5 +1,5 @@ import type { VChatFunctionIn } from '~/modules/llms/client/llm.client.types'; -import { llmChatGenerateWithFunctions } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; import { useModelsStore } from '~/modules/llms/store-llms'; import { useChatStore } from '~/common/state/store-chats'; @@ -72,7 +72,7 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri // Follow-up: Question if (suggestQuestions) { - // llmChatGenerateWithFunctions(funcLLMId, [ + // llmChatGenerateOrThrow(funcLLMId, [ // { role: 'system', content: systemMessage.text }, // { role: 'user', content: userMessage.text }, // { role: 'assistant', content: assistantMessageText }, @@ -84,15 +84,18 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri // Follow-up: Auto-Diagrams if (suggestDiagrams) { - void llmChatGenerateWithFunctions(funcLLMId, [ + void llmChatGenerateOrThrow(funcLLMId, [ { role: 'system', content: systemMessage.text }, { role: 'user', content: userMessage.text }, { role: 'assistant', content: assistantMessageText }, ], [suggestPlantUMLFn], 'draw_plantuml_diagram', ).then(chatResponse => { + if (!('function_arguments' in chatResponse)) + return; + // parse the output PlantUML string, if any - const functionArguments = chatResponse?.function_arguments ?? null; + const functionArguments = chatResponse.function_arguments ?? 
null; if (functionArguments) { const { code, type }: { code: string, type: string } = functionArguments as any; if (code && type) { @@ -106,6 +109,8 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri editMessage(conversationId, assistantMessageId, { text: assistantMessageText }, false); } } + }).catch(err => { + console.error('autoSuggestions::diagram:', err); }); } diff --git a/src/modules/aifn/autotitle/autoTitle.ts b/src/modules/aifn/autotitle/autoTitle.ts index 5b99bad14..3c0da5d99 100644 --- a/src/modules/aifn/autotitle/autoTitle.ts +++ b/src/modules/aifn/autotitle/autoTitle.ts @@ -1,4 +1,4 @@ -import { llmChatGenerate } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; import { useModelsStore } from '~/modules/llms/store-llms'; import { useChatStore } from '~/common/state/store-chats'; @@ -27,7 +27,7 @@ export function autoTitle(conversationId: string) { }); // LLM - void llmChatGenerate(fastLLMId, [ + void llmChatGenerateOrThrow(fastLLMId, [ { role: 'system', content: `You are an AI conversation titles assistant who specializes in creating expressive yet few-words chat titles.` }, { role: 'user', content: @@ -39,7 +39,7 @@ export function autoTitle(conversationId: string) { historyLines.join('\n') + '```\n', }, - ]).then(chatResponse => { + ], null, null).then(chatResponse => { const title = chatResponse?.content ?.trim() diff --git a/src/modules/aifn/imagine/imaginePromptFromText.ts b/src/modules/aifn/imagine/imaginePromptFromText.ts index 936bd6995..27c1b47d8 100644 --- a/src/modules/aifn/imagine/imaginePromptFromText.ts +++ b/src/modules/aifn/imagine/imaginePromptFromText.ts @@ -1,4 +1,4 @@ -import { llmChatGenerate } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; import { useModelsStore } from '~/modules/llms/store-llms'; @@ -14,10 +14,10 @@ export async function imaginePromptFromText(messageText: string): Promise { - const { llm, vendor } = findVendorForLlmOrThrow(llmId); - return await vendor.callChatGenerate(llm, messages, maxTokens); -} +export async function llmChatGenerateOrThrow( + llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number, +): Promise { + + // id to DLLM and vendor + const { llm, vendor } = findVendorForLlmOrThrow(llmId); + + // FIXME: relax the forced cast + const options = llm.options as TLLMOptions; -export async function llmChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], forceFunctionName: string | null, maxTokens?: number): Promise { - const { llm, vendor } = findVendorForLlmOrThrow(llmId); - return await vendor.callChatGenerateWF(llm, messages, functions, forceFunctionName, maxTokens); -} \ No newline at end of file + // get the access + const partialSourceSetup = llm._source.setup; + const access = vendor.getTransportAccess(partialSourceSetup); + + // execute via the vendor + return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens); +} diff --git a/src/modules/llms/store-llms.ts b/src/modules/llms/store-llms.ts index 74a3fe097..c352eccdc 100644 --- a/src/modules/llms/store-llms.ts +++ b/src/modules/llms/store-llms.ts @@ -2,7 +2,6 @@ import { create } from 'zustand'; import { shallow } from 'zustand/shallow'; import { persist } from 'zustand/middleware'; -import type { IModelVendor } 
from './vendors/IModelVendor'; import type { ModelVendorId } from './vendors/vendors.registry'; import type { SourceSetupOpenRouter } from './vendors/openrouter/openrouter.vendor'; @@ -282,32 +281,3 @@ export function useChatLLM() { }, shallow); } -/** - * Source-specific read/write - great time saver - */ -export function useSourceSetup(sourceId: DModelSourceId, vendor: IModelVendor) { - - // invalidates only when the setup changes - const { updateSourceSetup, ...rest } = useModelsStore(state => { - - // find the source (or null) - const source: DModelSource | null = state.sources.find(source => source.id === sourceId) as DModelSource ?? null; - - // (safe) source-derived properties - const sourceSetupValid = (source?.setup && vendor?.validateSetup) ? vendor.validateSetup(source.setup as TSourceSetup) : false; - const sourceLLMs = source ? state.llms.filter(llm => llm._source === source) : []; - const access = vendor.getTransportAccess(source?.setup); - - return { - source, - access, - sourceHasLLMs: !!sourceLLMs.length, - sourceSetupValid, - updateSourceSetup: state.updateSourceSetup, - }; - }, shallow); - - // convenience function for this source - const updateSetup = (partialSetup: Partial) => updateSourceSetup(sourceId, partialSetup); - return { ...rest, updateSetup }; -} \ No newline at end of file diff --git a/src/modules/llms/vendors/IModelVendor.ts b/src/modules/llms/vendors/IModelVendor.ts index a9dd5e123..6c2290b6a 100644 --- a/src/modules/llms/vendors/IModelVendor.ts +++ b/src/modules/llms/vendors/IModelVendor.ts @@ -29,12 +29,18 @@ export interface IModelVendor): TAccess; - callChatGenerate(llm: TDLLM, messages: VChatMessageIn[], maxTokens?: number): Promise; + rpcUpdateModelsQuery: ( + access: TAccess, + enabled: boolean, + onSuccess: (data: { models: ModelDescriptionSchema[] }) => void, + ) => { isFetching: boolean, refetch: () => void, isError: boolean, error: TRPCClientErrorBase | null }; + + rpcChatGenerateOrThrow: ( + access: TAccess, + llmOptions: TLLMOptions, + messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, forceFunctionName: string | null, + maxTokens?: number, + ) => Promise; - callChatGenerateWF(llm: TDLLM, messages: VChatMessageIn[], functions: null | VChatFunctionIn[], forceFunctionName: null | string, maxTokens?: number): Promise; } - - -export type IModelVendorUpdateModelsQuery = - (access: TAccess, enabled: boolean, onSuccess: (data: { models: ModelDescriptionSchema[] }) => void) => - { isFetching: boolean, refetch: () => void, isError: boolean, error: TRPCClientErrorBase | null }; diff --git a/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx b/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx index 46a650ccd..d5e214961 100644 --- a/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx +++ b/src/modules/llms/vendors/anthropic/AnthropicSourceSetup.tsx @@ -9,10 +9,11 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; import { useToggleableBoolean } from '~/common/util/useToggleableBoolean'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; -import { anthropicListModelsQuery, isValidAnthropicApiKey, ModelVendorAnthropic } from './anthropic.vendor'; +import { 
isValidAnthropicApiKey, ModelVendorAnthropic } from './anthropic.vendor'; export function AnthropicSourceSetup(props: { sourceId: DModelSourceId }) { @@ -34,7 +35,7 @@ export function AnthropicSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useLlmUpdateModels(anthropicListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); + useLlmUpdateModels(ModelVendorAnthropic, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts index 645caee5a..c8d4fd2d1 100644 --- a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts +++ b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts @@ -4,8 +4,8 @@ import { AnthropicIcon } from '~/common/components/icons/AnthropicIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { AnthropicAccessSchema } from '../../server/anthropic/anthropic.router'; -import type { IModelVendor, IModelVendorUpdateModelsQuery } from '../IModelVendor'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; +import type { IModelVendor } from '../IModelVendor'; +import type { VChatMessageOut } from '../../client/llm.client.types'; import { LLMOptionsOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; @@ -42,46 +42,39 @@ export const ModelVendorAnthropic: IModelVendor { - return anthropicCallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, /*null, null,*/ maxTokens); - }, - callChatGenerateWF(): Promise { - throw new Error('Anthropic does not support "Functions" yet'); - }, -}; -export const anthropicListModelsQuery: IModelVendorUpdateModelsQuery = (access, enabled, onSuccess) => - apiQuery.llmAnthropic.listModels.useQuery({ access }, { - enabled: enabled, - onSuccess: onSuccess, - refetchOnWindowFocus: false, - staleTime: Infinity, - }); + // List Models + rpcUpdateModelsQuery: (access, enabled, onSuccess) => { + return apiQuery.llmAnthropic.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); + }, + // Chat Generate (non-streaming) with Functions + rpcChatGenerateOrThrow: async (access, llmOptions, messages, functions, forceFunctionName, maxTokens) => { + if (functions?.length || forceFunctionName) + throw new Error('Anthropic does not support functions'); -/** - * This function either returns the LLM message, or function calls, or throws a descriptive error string - */ -async function anthropicCallChatGenerate( - access: AnthropicAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], - // functions: VChatFunctionIn[] | null, forceFunctionName: string | null, - maxTokens?: number, -): Promise { - const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; - try { - return await apiAsync.llmAnthropic.chatGenerate.mutate({ - access, - model: { - id: llmRef!, - temperature: llmTemperature, - maxTokens: maxTokens || llmResponseTokens || 1024, - }, - history: messages, - }) as TOut; - } catch (error: any) { - const errorMessage = error?.message || error?.toString() || 'Anthropic Chat Generate Error'; - console.error(`anthropicCallChatGenerate: ${errorMessage}`); - throw new Error(errorMessage); - } -} \ No newline at end of file + const { llmRef, llmTemperature = 0.5, llmResponseTokens } = 
llmOptions; + try { + return await apiAsync.llmAnthropic.chatGenerate.mutate({ + access, + model: { + id: llmRef!, + temperature: llmTemperature, + maxTokens: maxTokens || llmResponseTokens || 1024, + }, + history: messages, + }) as VChatMessageOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'Anthropic Chat Generate Error'; + console.error(`anthropic.rpcChatGenerateOrThrow: ${errorMessage}`); + throw new Error(errorMessage); + } + }, + +}; diff --git a/src/modules/llms/vendors/azure/AzureSourceSetup.tsx b/src/modules/llms/vendors/azure/AzureSourceSetup.tsx index f0a50d0ba..4de8838d8 100644 --- a/src/modules/llms/vendors/azure/AzureSourceSetup.tsx +++ b/src/modules/llms/vendors/azure/AzureSourceSetup.tsx @@ -7,10 +7,9 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; import { asValidURL } from '~/common/util/urlUtils'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; - -import { openAIListModelsQuery } from '../openai/openai.vendor'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { isValidAzureApiKey, ModelVendorAzure } from './azure.vendor'; @@ -33,7 +32,7 @@ export function AzureSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useLlmUpdateModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); + useLlmUpdateModels(ModelVendorAzure, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/azure/azure.vendor.ts b/src/modules/llms/vendors/azure/azure.vendor.ts index e6752105b..275c41205 100644 --- a/src/modules/llms/vendors/azure/azure.vendor.ts +++ b/src/modules/llms/vendors/azure/azure.vendor.ts @@ -4,9 +4,8 @@ import { AzureIcon } from '~/common/components/icons/AzureIcon'; import type { IModelVendor } from '../IModelVendor'; import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; -import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { AzureSourceSetup } from './AzureSourceSetup'; @@ -58,10 +57,8 @@ export const ModelVendorAzure: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, functions, forceFunctionName, maxTokens); - }, + + // OpenAI transport ('azure' dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, }; \ No newline at end of file diff --git a/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx b/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx index 357e0572b..cdf1f5be6 100644 --- 
a/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx +++ b/src/modules/llms/vendors/googleai/GeminiSourceSetup.tsx @@ -5,10 +5,11 @@ import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; -import { geminiListModelsQuery, ModelVendorGemini } from './gemini.vendor'; +import { ModelVendorGemini } from './gemini.vendor'; const GEMINI_API_KEY_LINK = 'https://makersuite.google.com/app/apikey'; @@ -29,7 +30,7 @@ export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useLlmUpdateModels(geminiListModelsQuery, access, shallFetchSucceed, source); + useLlmUpdateModels(ModelVendorGemini, access, shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/googleai/gemini.vendor.ts b/src/modules/llms/vendors/googleai/gemini.vendor.ts index 6753668bd..9f3f9b8cc 100644 --- a/src/modules/llms/vendors/googleai/gemini.vendor.ts +++ b/src/modules/llms/vendors/googleai/gemini.vendor.ts @@ -5,8 +5,8 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { GeminiAccessSchema } from '../../server/gemini/gemini.router'; -import type { IModelVendor, IModelVendorUpdateModelsQuery } from '../IModelVendor'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; +import type { IModelVendor } from '../IModelVendor'; +import type { VChatMessageOut } from '../../client/llm.client.types'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; @@ -52,46 +52,38 @@ export const ModelVendorGemini: IModelVendor { - return geminiCallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, maxTokens); - }, - callChatGenerateWF(): Promise { - throw new Error('Gemini does not support "Functions" yet'); - }, -}; + // List Models + rpcUpdateModelsQuery: (access, enabled, onSuccess) => { + return apiQuery.llmGemini.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); + }, -export const geminiListModelsQuery: IModelVendorUpdateModelsQuery = (access, enabled, onSuccess) => - apiQuery.llmGemini.listModels.useQuery({ access }, { - enabled: enabled, - onSuccess: onSuccess, - refetchOnWindowFocus: false, - staleTime: Infinity, - }); - - -/** - * This function either returns the LLM message, or throws a descriptive error string - */ -async function geminiCallChatGenerate( - access: GeminiAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], - maxTokens?: number, -): Promise { - const { llmRef, temperature = 0.5, maxOutputTokens } = llmOptions; - try { - return await apiAsync.llmGemini.chatGenerate.mutate({ - access, - model: { - id: llmRef!, - temperature: temperature, - maxTokens: maxTokens || maxOutputTokens || 1024, - }, - history: messages, - }) as TOut; - } catch (error: any) { - const errorMessage = error?.message || error?.toString() || 'Gemini Chat Generate Error'; - console.error(`geminiCallChatGenerate: 
${errorMessage}`); - throw new Error(errorMessage); - } -} + // Chat Generate (non-streaming) with Functions + rpcChatGenerateOrThrow: async (access, llmOptions, messages, functions, forceFunctionName, maxTokens) => { + if (functions?.length || forceFunctionName) + throw new Error('Gemini does not support functions'); + + const { llmRef, temperature = 0.5, maxOutputTokens } = llmOptions; + try { + return await apiAsync.llmGemini.chatGenerate.mutate({ + access, + model: { + id: llmRef!, + temperature: temperature, + maxTokens: maxTokens || maxOutputTokens || 1024, + }, + history: messages, + }) as VChatMessageOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'Gemini Chat Generate Error'; + console.error(`gemini.rpcChatGenerateOrThrow: ${errorMessage}`); + throw new Error(errorMessage); + } + }, +}; diff --git a/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx b/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx index 6bdce7073..ca1c2ec57 100644 --- a/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx +++ b/src/modules/llms/vendors/localai/LocalAISourceSetup.tsx @@ -8,10 +8,9 @@ import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; - -import { openAIListModelsQuery } from '../openai/openai.vendor'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { ModelVendorLocalAI } from './localai.vendor'; @@ -32,7 +31,7 @@ export function LocalAISourceSetup(props: { sourceId: DModelSourceId }) { // fetch models - the OpenAI way const { isFetching, refetch, isError, error } = - useLlmUpdateModels(openAIListModelsQuery, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); + useLlmUpdateModels(ModelVendorLocalAI, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); return <> diff --git a/src/modules/llms/vendors/localai/localai.vendor.ts b/src/modules/llms/vendors/localai/localai.vendor.ts index e39d5245b..92f6804ec 100644 --- a/src/modules/llms/vendors/localai/localai.vendor.ts +++ b/src/modules/llms/vendors/localai/localai.vendor.ts @@ -2,9 +2,8 @@ import DevicesIcon from '@mui/icons-material/Devices'; import type { IModelVendor } from '../IModelVendor'; import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; -import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { LocalAISourceSetup } from './LocalAISourceSetup'; @@ -38,10 +37,8 @@ export const ModelVendorLocalAI: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, functions, forceFunctionName, 
maxTokens); - }, -}; \ No newline at end of file + + // OpenAI transport ('localai' dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, +}; diff --git a/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx b/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx index 95f563c6f..796a04da1 100644 --- a/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx +++ b/src/modules/llms/vendors/mistral/MistralSourceSetup.tsx @@ -5,10 +5,9 @@ import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; - -import { openAIListModelsQuery } from '../openai/openai.vendor'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { ModelVendorMistral } from './mistral.vendor'; @@ -31,7 +30,7 @@ export function MistralSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useLlmUpdateModels(openAIListModelsQuery, access, shallFetchSucceed, source); + useLlmUpdateModels(ModelVendorMistral, access, shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/mistral/mistral.vendor.ts b/src/modules/llms/vendors/mistral/mistral.vendor.ts index cdcb0c3bc..da8b977e3 100644 --- a/src/modules/llms/vendors/mistral/mistral.vendor.ts +++ b/src/modules/llms/vendors/mistral/mistral.vendor.ts @@ -4,9 +4,8 @@ import { MistralIcon } from '~/common/components/icons/MistralIcon'; import type { IModelVendor } from '../IModelVendor'; import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import type { VChatMessageIn, VChatMessageOut } from '../../client/llm.client.types'; -import { LLMOptionsOpenAI, openAICallChatGenerate, SourceSetupOpenAI } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI, SourceSetupOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { MistralSourceSetup } from './MistralSourceSetup'; @@ -48,10 +47,8 @@ export const ModelVendorMistral: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF() { - throw new Error('Mistral does not support "Functions" yet'); - }, + + // OpenAI transport ('mistral' dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, }; \ No newline at end of file diff --git a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx index f66cc58cd..9fd3a2617 100644 --- a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx +++ b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx @@ -8,10 +8,11 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; import { asValidURL } from '~/common/util/urlUtils'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; +import { DModelSourceId } from 
'../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; -import { ModelVendorOllama, ollamaListModelsQuery } from './ollama.vendor'; +import { ModelVendorOllama } from './ollama.vendor'; import { OllamaAdministration } from './OllamaAdministration'; @@ -33,7 +34,7 @@ export function OllamaSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useLlmUpdateModels(ollamaListModelsQuery, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); + useLlmUpdateModels(ModelVendorOllama, access, false /* !sourceHasLLMs && shallFetchSucceed */, source); return <> diff --git a/src/modules/llms/vendors/ollama/ollama.vendor.ts b/src/modules/llms/vendors/ollama/ollama.vendor.ts index 37e6088ec..04f97bb24 100644 --- a/src/modules/llms/vendors/ollama/ollama.vendor.ts +++ b/src/modules/llms/vendors/ollama/ollama.vendor.ts @@ -3,9 +3,9 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OllamaIcon } from '~/common/components/icons/OllamaIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; -import type { IModelVendor, IModelVendorUpdateModelsQuery } from '../IModelVendor'; +import type { IModelVendor } from '../IModelVendor'; import type { OllamaAccessSchema } from '../../server/ollama/ollama.router'; -import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; +import type { VChatMessageOut } from '../../client/llm.client.types'; import type { LLMOptionsOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; @@ -36,45 +36,38 @@ export const ModelVendorOllama: IModelVendor { - return ollamaCallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, maxTokens); - }, - callChatGenerateWF(): Promise { - throw new Error('Ollama does not support "Functions" yet'); - }, -}; + // List Models + rpcUpdateModelsQuery: (access, enabled, onSuccess) => { + return apiQuery.llmOllama.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); + }, -export const ollamaListModelsQuery: IModelVendorUpdateModelsQuery = (access, enabled, onSuccess) => - apiQuery.llmOllama.listModels.useQuery({ access }, { - enabled: enabled, - onSuccess: onSuccess, - refetchOnWindowFocus: false, - staleTime: Infinity, - }); + // Chat Generate (non-streaming) with Functions + rpcChatGenerateOrThrow: async (access, llmOptions, messages, functions, forceFunctionName, maxTokens) => { + if (functions?.length || forceFunctionName) + throw new Error('Ollama does not support functions'); + const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; + try { + return await apiAsync.llmOllama.chatGenerate.mutate({ + access, + model: { + id: llmRef!, + temperature: llmTemperature, + maxTokens: maxTokens || llmResponseTokens || 1024, + }, + history: messages, + }) as VChatMessageOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'Ollama Chat Generate Error'; + console.error(`ollama.rpcChatGenerateOrThrow: ${errorMessage}`); + throw new Error(errorMessage); + } + }, -/** - * This function either returns the LLM message, or throws a descriptive error string - */ -async function ollamaCallChatGenerate( - access: OllamaAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], - maxTokens?: number, -): Promise { - 
const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; - try { - return await apiAsync.llmOllama.chatGenerate.mutate({ - access, - model: { - id: llmRef!, - temperature: llmTemperature, - maxTokens: maxTokens || llmResponseTokens || 1024, - }, - history: messages, - }) as TOut; - } catch (error: any) { - const errorMessage = error?.message || error?.toString() || 'Ollama Chat Generate Error'; - console.error(`ollamaCallChatGenerate: ${errorMessage}`); - throw new Error(errorMessage); - } -} +}; diff --git a/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx b/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx index 5e2c125c1..f218c5829 100644 --- a/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx +++ b/src/modules/llms/vendors/oobabooga/OobaboogaSourceSetup.tsx @@ -7,10 +7,9 @@ import { InlineError } from '~/common/components/InlineError'; import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; - -import { openAIListModelsQuery } from '../openai/openai.vendor'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { ModelVendorOoobabooga } from './oobabooga.vendor'; @@ -26,7 +25,7 @@ export function OobaboogaSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useLlmUpdateModels(openAIListModelsQuery, access, false /* !hasModels && !!asValidURL(normSetup.oaiHost) */, source); + useLlmUpdateModels(ModelVendorOoobabooga, access, false /* !hasModels && !!asValidURL(normSetup.oaiHost) */, source); return <> diff --git a/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts b/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts index a16dee2e3..c9b7bf682 100644 --- a/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts +++ b/src/modules/llms/vendors/oobabooga/oobabooga.vendor.ts @@ -2,9 +2,8 @@ import { OobaboogaIcon } from '~/common/components/icons/OobaboogaIcon'; import type { IModelVendor } from '../IModelVendor'; import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; -import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { OobaboogaSourceSetup } from './OobaboogaSourceSetup'; @@ -38,10 +37,8 @@ export const ModelVendorOoobabooga: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, functions, forceFunctionName, maxTokens); - }, + + // OpenAI transport (oobabooga dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, }; \ No newline at end of file diff --git 
a/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx b/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx index 6d4b48b30..aa174db01 100644 --- a/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx +++ b/src/modules/llms/vendors/openai/OpenAISourceSetup.tsx @@ -11,10 +11,11 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; import { useToggleableBoolean } from '~/common/util/useToggleableBoolean'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; -import { isValidOpenAIApiKey, ModelVendorOpenAI, openAIListModelsQuery } from './openai.vendor'; +import { isValidOpenAIApiKey, ModelVendorOpenAI } from './openai.vendor'; // avoid repeating it all over @@ -40,7 +41,7 @@ export function OpenAISourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useLlmUpdateModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); + useLlmUpdateModels(ModelVendorOpenAI, access, !sourceHasLLMs && shallFetchSucceed, source); return <> diff --git a/src/modules/llms/vendors/openai/openai.vendor.ts b/src/modules/llms/vendors/openai/openai.vendor.ts index d229bfee2..72b6c3b9b 100644 --- a/src/modules/llms/vendors/openai/openai.vendor.ts +++ b/src/modules/llms/vendors/openai/openai.vendor.ts @@ -3,9 +3,9 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { OpenAIIcon } from '~/common/components/icons/OpenAIIcon'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; -import type { IModelVendor, IModelVendorUpdateModelsQuery } from '../IModelVendor'; +import type { IModelVendor } from '../IModelVendor'; import type { OpenAIAccessSchema } from '../../server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; +import type { VChatMessageOrFunctionCallOut } from '../../client/llm.client.types'; import { OpenAILLMOptions } from './OpenAILLMOptions'; import { OpenAISourceSetup } from './OpenAISourceSetup'; @@ -51,50 +51,37 @@ export const ModelVendorOpenAI: IModelVendor { - const access = this.getTransportAccess(llm._source.setup); - return openAICallChatGenerate(access, llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - const access = this.getTransportAccess(llm._source.setup); - return openAICallChatGenerate(access, llm.options, messages, functions, forceFunctionName, maxTokens); - }, -}; - -export const openAIListModelsQuery: IModelVendorUpdateModelsQuery = (access, enabled, onSuccess) => - apiQuery.llmOpenAI.listModels.useQuery({ access }, { - enabled: enabled, - onSuccess: onSuccess, - refetchOnWindowFocus: false, - staleTime: Infinity, - }); + // List Models + rpcUpdateModelsQuery: (access, enabled, onSuccess) => { + return apiQuery.llmOpenAI.listModels.useQuery({ access }, { + enabled: enabled, + onSuccess: onSuccess, + refetchOnWindowFocus: false, + staleTime: Infinity, + }); + }, + // Chat Generate (non-streaming) with Functions + rpcChatGenerateOrThrow: async (access, llmOptions, messages, 
functions, forceFunctionName, maxTokens) => { + const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; + try { + return await apiAsync.llmOpenAI.chatGenerateWithFunctions.mutate({ + access, + model: { + id: llmRef!, + temperature: llmTemperature, + maxTokens: maxTokens || llmResponseTokens || 1024, + }, + functions: functions ?? undefined, + forceFunctionName: forceFunctionName ?? undefined, + history: messages, + }) as VChatMessageOrFunctionCallOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'OpenAI Chat Generate Error'; + console.error(`openai.rpcChatGenerateOrThrow: ${errorMessage}`); + throw new Error(errorMessage); + } + }, -/** - * This function either returns the LLM message, or function calls, or throws a descriptive error string - */ -export async function openAICallChatGenerate( - access: OpenAIAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], - functions: VChatFunctionIn[] | null, forceFunctionName: string | null, - maxTokens?: number, -): Promise { - const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; - try { - return await apiAsync.llmOpenAI.chatGenerateWithFunctions.mutate({ - access, - model: { - id: llmRef!, - temperature: llmTemperature, - maxTokens: maxTokens || llmResponseTokens || 1024, - }, - functions: functions ?? undefined, - forceFunctionName: forceFunctionName ?? undefined, - history: messages, - }) as TOut; - } catch (error: any) { - const errorMessage = error?.message || error?.toString() || 'OpenAI Chat Generate Error'; - console.error(`openAICallChatGenerate: ${errorMessage}`); - throw new Error(errorMessage); - } -} \ No newline at end of file +}; diff --git a/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx b/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx index 8f4f05427..9c3df5da5 100644 --- a/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx +++ b/src/modules/llms/vendors/openrouter/OpenRouterSourceSetup.tsx @@ -8,10 +8,9 @@ import { Link } from '~/common/components/Link'; import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; import { getCallbackUrl } from '~/common/app.routes'; -import { DModelSourceId, useSourceSetup } from '../../store-llms'; -import { useLlmUpdateModels } from '../../client/useLlmUpdateModels'; - -import { openAIListModelsQuery } from '../openai/openai.vendor'; +import { DModelSourceId } from '../../store-llms'; +import { useLlmUpdateModels } from '../useLlmUpdateModels'; +import { useSourceSetup } from '../useSourceSetup'; import { isValidOpenRouterKey, ModelVendorOpenRouter } from './openrouter.vendor'; @@ -32,7 +31,7 @@ export function OpenRouterSourceSetup(props: { sourceId: DModelSourceId }) { // fetch models const { isFetching, refetch, isError, error } = - useLlmUpdateModels(openAIListModelsQuery, access, !sourceHasLLMs && shallFetchSucceed, source); + useLlmUpdateModels(ModelVendorOpenRouter, access, !sourceHasLLMs && shallFetchSucceed, source); const handleOpenRouterLogin = () => { diff --git a/src/modules/llms/vendors/openrouter/openrouter.vendor.ts b/src/modules/llms/vendors/openrouter/openrouter.vendor.ts index 66c6937ce..2cd3e547d 100644 --- a/src/modules/llms/vendors/openrouter/openrouter.vendor.ts +++ b/src/modules/llms/vendors/openrouter/openrouter.vendor.ts @@ -4,9 +4,8 @@ import { OpenRouterIcon } from '~/common/components/icons/OpenRouterIcon'; import type { IModelVendor } from '../IModelVendor'; import type { OpenAIAccessSchema } from 
'../../server/openai/openai.router'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../client/llm.client.types'; -import { LLMOptionsOpenAI, openAICallChatGenerate } from '../openai/openai.vendor'; +import { LLMOptionsOpenAI, ModelVendorOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; import { OpenRouterSourceSetup } from './OpenRouterSourceSetup'; @@ -59,10 +58,8 @@ export const ModelVendorOpenRouter: IModelVendor { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, null, null, maxTokens); - }, - callChatGenerateWF(llm, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number): Promise { - return openAICallChatGenerate(this.getTransportAccess(llm._source.setup), llm.options, messages, functions, forceFunctionName, maxTokens); - }, + + // OpenAI transport ('openrouter' dialect in 'access') + rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery, + rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow, }; \ No newline at end of file diff --git a/src/modules/llms/client/useLlmUpdateModels.tsx b/src/modules/llms/vendors/useLlmUpdateModels.tsx similarity index 80% rename from src/modules/llms/client/useLlmUpdateModels.tsx rename to src/modules/llms/vendors/useLlmUpdateModels.tsx index ae6ac513a..cc12bb048 100644 --- a/src/modules/llms/client/useLlmUpdateModels.tsx +++ b/src/modules/llms/vendors/useLlmUpdateModels.tsx @@ -1,4 +1,4 @@ -import type { IModelVendorUpdateModelsQuery } from '../vendors/IModelVendor'; +import type { IModelVendor } from './IModelVendor'; import type { ModelDescriptionSchema } from '../server/llm.server.types'; import { DLLM, DModelSource, useModelsStore } from '../store-llms'; @@ -7,8 +7,8 @@ import { DLLM, DModelSource, useModelsStore } from '../store-llms'; * Hook that fetches the list of models from the vendor and updates the store, * while returning the fetch state. */ -export function useLlmUpdateModels(listFn: IModelVendorUpdateModelsQuery, access: TAccess, enabled: boolean, source: DModelSource) { - return listFn(access, enabled, data => source && updateModelsFn(data, source)); +export function useLlmUpdateModels(vendor: IModelVendor, access: TAccess, enabled: boolean, source: DModelSource) { + return vendor.rpcUpdateModelsQuery(access, enabled, data => source && updateModelsFn(data, source)); } diff --git a/src/modules/llms/vendors/useSourceSetup.ts b/src/modules/llms/vendors/useSourceSetup.ts new file mode 100644 index 000000000..4395ac458 --- /dev/null +++ b/src/modules/llms/vendors/useSourceSetup.ts @@ -0,0 +1,35 @@ +import { shallow } from 'zustand/shallow'; + +import type { IModelVendor } from './IModelVendor'; +import { DModelSource, DModelSourceId, useModelsStore } from '../store-llms'; + + +/** + * Source-specific read/write - great time saver + */ +export function useSourceSetup(sourceId: DModelSourceId, vendor: IModelVendor) { + + // invalidates only when the setup changes + const { updateSourceSetup, ...rest } = useModelsStore(state => { + + // find the source (or null) + const source: DModelSource | null = state.sources.find(source => source.id === sourceId) as DModelSource ?? null; + + // (safe) source-derived properties + const sourceSetupValid = (source?.setup && vendor?.validateSetup) ? vendor.validateSetup(source.setup as TSourceSetup) : false; + const sourceLLMs = source ? 
state.llms.filter(llm => llm._source === source) : []; + const access = vendor.getTransportAccess(source?.setup); + + return { + source, + access, + sourceHasLLMs: !!sourceLLMs.length, + sourceSetupValid, + updateSourceSetup: state.updateSourceSetup, + }; + }, shallow); + + // convenience function for this source + const updateSetup = (partialSetup: Partial) => updateSourceSetup(sourceId, partialSetup); + return { ...rest, updateSetup }; +} \ No newline at end of file diff --git a/src/modules/llms/vendors/vendors.registry.ts b/src/modules/llms/vendors/vendors.registry.ts index 8dcc6c53d..061c34ea1 100644 --- a/src/modules/llms/vendors/vendors.registry.ts +++ b/src/modules/llms/vendors/vendors.registry.ts @@ -11,8 +11,19 @@ import { ModelVendorOpenRouter } from './openrouter/openrouter.vendor'; import type { IModelVendor } from './IModelVendor'; import { DLLMId, DModelSource, DModelSourceId, findLLMOrThrow } from '../store-llms'; +export type ModelVendorId = + | 'anthropic' + | 'azure' + | 'googleai' + | 'localai' + | 'mistral' + | 'ollama' + | 'oobabooga' + | 'openai' + | 'openrouter'; + /** Global: Vendor Instances Registry **/ -const MODEL_VENDOR_REGISTRY = { +const MODEL_VENDOR_REGISTRY: Record = { anthropic: ModelVendorAnthropic, azure: ModelVendorAzure, googleai: ModelVendorGemini, @@ -22,9 +33,7 @@ const MODEL_VENDOR_REGISTRY = { oobabooga: ModelVendorOoobabooga, openai: ModelVendorOpenAI, openrouter: ModelVendorOpenRouter, -} as const; - -export type ModelVendorId = keyof typeof MODEL_VENDOR_REGISTRY; +} as Record; const MODEL_VENDOR_DEFAULT: ModelVendorId = 'openai'; @@ -35,13 +44,15 @@ export function findAllVendors(): IModelVendor[] { return modelVendors; } -export function findVendorById(vendorId?: ModelVendorId): IModelVendor | null { - return vendorId ? (MODEL_VENDOR_REGISTRY[vendorId] ?? null) : null; +export function findVendorById( + vendorId?: ModelVendorId, +): IModelVendor | null { + return vendorId ? (MODEL_VENDOR_REGISTRY[vendorId] as IModelVendor) ?? 
null : null; } -export function findVendorForLlmOrThrow(llmId: DLLMId) { - const llm = findLLMOrThrow(llmId); - const vendor = findVendorById(llm?._source.vId); +export function findVendorForLlmOrThrow(llmId: DLLMId) { + const llm = findLLMOrThrow(llmId); + const vendor = findVendorById(llm?._source.vId); if (!vendor) throw new Error(`callChat: Vendor not found for LLM ${llmId}`); return { llm, vendor }; } From bee49a4b1c30508b02c1c2da1ca5d80e8311bc30 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 19:00:19 -0800 Subject: [PATCH 15/24] Llms: streaming as a vendor function (then all directed to the unified) --- app/api/llms/stream/route.ts | 2 +- src/apps/call/CallUI.tsx | 5 +- src/apps/call/components/CallMessage.tsx | 2 +- src/apps/chat/editors/chat-stream.ts | 4 +- src/apps/personas/useLLMChain.ts | 3 +- .../aifn/autosuggestions/autoSuggestions.ts | 3 +- src/modules/aifn/autotitle/autoTitle.ts | 2 +- src/modules/aifn/digrams/DiagramsModal.tsx | 4 +- src/modules/aifn/digrams/diagrams.data.ts | 3 +- .../aifn/imagine/imaginePromptFromText.ts | 2 +- src/modules/aifn/react/react.ts | 3 +- src/modules/aifn/summarize/summerize.ts | 2 +- src/modules/aifn/useStreamChatText.ts | 5 +- src/modules/llms/client/llm.client.types.ts | 27 ------- src/modules/llms/client/llmChatGenerate.ts | 23 ------ src/modules/llms/llm.client.ts | 74 +++++++++++++++++++ ...s.streaming.ts => llm.server.streaming.ts} | 21 +++--- src/modules/llms/vendors/IModelVendor.ts | 14 +++- .../vendors/anthropic/anthropic.vendor.ts | 6 +- .../llms/vendors/azure/azure.vendor.ts | 1 + .../GeminiSourceSetup.tsx | 0 .../{googleai => gemini}/gemini.vendor.ts | 6 +- .../llms/vendors/localai/localai.vendor.ts | 1 + .../llms/vendors/mistral/mistral.vendor.ts | 1 + .../llms/vendors/ollama/ollama.vendor.ts | 6 +- .../vendors/oobabooga/oobabooga.vendor.ts | 1 + .../llms/vendors/openai/openai.vendor.ts | 6 +- .../vendors/openrouter/openrouter.vendor.ts | 1 + .../unifiedStreamingClient.ts} | 39 +++------- src/modules/llms/vendors/vendors.registry.ts | 2 +- 30 files changed, 153 insertions(+), 116 deletions(-) delete mode 100644 src/modules/llms/client/llm.client.types.ts delete mode 100644 src/modules/llms/client/llmChatGenerate.ts create mode 100644 src/modules/llms/llm.client.ts rename src/modules/llms/server/{llms.streaming.ts => llm.server.streaming.ts} (93%) rename src/modules/llms/vendors/{googleai => gemini}/GeminiSourceSetup.tsx (100%) rename src/modules/llms/vendors/{googleai => gemini}/gemini.vendor.ts (92%) rename src/modules/llms/{client/llmStreamChatGenerate.ts => vendors/unifiedStreamingClient.ts} (73%) diff --git a/app/api/llms/stream/route.ts b/app/api/llms/stream/route.ts index b0873a013..fed3b5d47 100644 --- a/app/api/llms/stream/route.ts +++ b/app/api/llms/stream/route.ts @@ -1,2 +1,2 @@ export const runtime = 'edge'; -export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llms.streaming'; \ No newline at end of file +export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llm.server.streaming'; \ No newline at end of file diff --git a/src/apps/call/CallUI.tsx b/src/apps/call/CallUI.tsx index a118e3266..1da62a002 100644 --- a/src/apps/call/CallUI.tsx +++ b/src/apps/call/CallUI.tsx @@ -13,10 +13,9 @@ import RecordVoiceOverIcon from '@mui/icons-material/RecordVoiceOver'; import { useChatLLMDropdown } from '../chat/components/applayout/useLLMDropdown'; -import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; import { EXPERIMENTAL_speakTextStream } from 
'~/modules/elevenlabs/elevenlabs.client'; import { SystemPurposeId, SystemPurposes } from '../../data'; -import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate'; +import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client'; import { useElevenLabsVoiceDropdown } from '~/modules/elevenlabs/useElevenLabsVoiceDropdown'; import { Link } from '~/common/components/Link'; @@ -216,7 +215,7 @@ export function CallUI(props: { responseAbortController.current = new AbortController(); let finalText = ''; let error: any | null = null; - llmStreamChatGenerate(chatLLMId, callPrompt, responseAbortController.current.signal, (updatedMessage: Partial) => { + llmStreamingChatGenerate(chatLLMId, callPrompt, null, null, responseAbortController.current.signal, (updatedMessage: Partial) => { const text = updatedMessage.text?.trim(); if (text) { finalText = text; diff --git a/src/apps/call/components/CallMessage.tsx b/src/apps/call/components/CallMessage.tsx index 525a586c2..63a3b76cb 100644 --- a/src/apps/call/components/CallMessage.tsx +++ b/src/apps/call/components/CallMessage.tsx @@ -3,7 +3,7 @@ import * as React from 'react'; import { Chip, ColorPaletteProp, VariantProp } from '@mui/joy'; import { SxProps } from '@mui/joy/styles/types'; -import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; +import type { VChatMessageIn } from '~/modules/llms/llm.client'; export function CallMessage(props: { diff --git a/src/apps/chat/editors/chat-stream.ts b/src/apps/chat/editors/chat-stream.ts index b8dd4e3f0..e3c5a5b4e 100644 --- a/src/apps/chat/editors/chat-stream.ts +++ b/src/apps/chat/editors/chat-stream.ts @@ -2,7 +2,7 @@ import { DLLMId } from '~/modules/llms/store-llms'; import { SystemPurposeId } from '../../../data'; import { autoSuggestions } from '~/modules/aifn/autosuggestions/autoSuggestions'; import { autoTitle } from '~/modules/aifn/autotitle/autoTitle'; -import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate'; +import { llmStreamingChatGenerate } from '~/modules/llms/llm.client'; import { speakText } from '~/modules/elevenlabs/elevenlabs.client'; import { DMessage, useChatStore } from '~/common/state/store-chats'; @@ -63,7 +63,7 @@ async function streamAssistantMessage( const messages = history.map(({ role, text }) => ({ role, content: text })); try { - await llmStreamChatGenerate(llmId, messages, abortSignal, + await llmStreamingChatGenerate(llmId, messages, null, null, abortSignal, (updatedMessage: Partial) => { // update the message in the store (and thus schedule a re-render) editMessage(updatedMessage); diff --git a/src/apps/personas/useLLMChain.ts b/src/apps/personas/useLLMChain.ts index 45f0be785..66682b977 100644 --- a/src/apps/personas/useLLMChain.ts +++ b/src/apps/personas/useLLMChain.ts @@ -1,8 +1,7 @@ import * as React from 'react'; -import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; import { DLLMId, useModelsStore } from '~/modules/llms/store-llms'; -import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client'; export interface LLMChainStep { diff --git a/src/modules/aifn/autosuggestions/autoSuggestions.ts b/src/modules/aifn/autosuggestions/autoSuggestions.ts index 8a0bbb9c5..65097a9b4 100644 --- a/src/modules/aifn/autosuggestions/autoSuggestions.ts +++ b/src/modules/aifn/autosuggestions/autoSuggestions.ts @@ -1,5 +1,4 @@ -import type { VChatFunctionIn } from 
'~/modules/llms/client/llm.client.types'; -import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow, VChatFunctionIn } from '~/modules/llms/llm.client'; import { useModelsStore } from '~/modules/llms/store-llms'; import { useChatStore } from '~/common/state/store-chats'; diff --git a/src/modules/aifn/autotitle/autoTitle.ts b/src/modules/aifn/autotitle/autoTitle.ts index 3c0da5d99..4172b6506 100644 --- a/src/modules/aifn/autotitle/autoTitle.ts +++ b/src/modules/aifn/autotitle/autoTitle.ts @@ -1,4 +1,4 @@ -import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client'; import { useModelsStore } from '~/modules/llms/store-llms'; import { useChatStore } from '~/common/state/store-chats'; diff --git a/src/modules/aifn/digrams/DiagramsModal.tsx b/src/modules/aifn/digrams/DiagramsModal.tsx index 9206957fc..b0f8da1e7 100644 --- a/src/modules/aifn/digrams/DiagramsModal.tsx +++ b/src/modules/aifn/digrams/DiagramsModal.tsx @@ -8,7 +8,7 @@ import ReplayIcon from '@mui/icons-material/Replay'; import StopOutlinedIcon from '@mui/icons-material/StopOutlined'; import TelegramIcon from '@mui/icons-material/Telegram'; -import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate'; +import { llmStreamingChatGenerate } from '~/modules/llms/llm.client'; import { ChatMessage } from '../../../apps/chat/components/message/ChatMessage'; @@ -86,7 +86,7 @@ export function DiagramsModal(props: { config: DiagramConfig, onClose: () => voi const diagramPrompt = bigDiagramPrompt(diagramType, diagramLanguage, systemMessage.text, subject, customInstruction); try { - await llmStreamChatGenerate(diagramLlm.id, diagramPrompt, stepAbortController.signal, + await llmStreamingChatGenerate(diagramLlm.id, diagramPrompt, null, null, stepAbortController.signal, (update: Partial<{ text: string, typing: boolean, originLLM: string }>) => { assistantMessage = { ...assistantMessage, ...update }; setMessage(assistantMessage); diff --git a/src/modules/aifn/digrams/diagrams.data.ts b/src/modules/aifn/digrams/diagrams.data.ts index 8a4b675de..d1b424d35 100644 --- a/src/modules/aifn/digrams/diagrams.data.ts +++ b/src/modules/aifn/digrams/diagrams.data.ts @@ -1,6 +1,5 @@ -import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; - import type { FormRadioOption } from '~/common/components/forms/FormRadioControl'; +import type { VChatMessageIn } from '~/modules/llms/llm.client'; export type DiagramType = 'auto' | 'mind'; diff --git a/src/modules/aifn/imagine/imaginePromptFromText.ts b/src/modules/aifn/imagine/imaginePromptFromText.ts index 27c1b47d8..c1556630b 100644 --- a/src/modules/aifn/imagine/imaginePromptFromText.ts +++ b/src/modules/aifn/imagine/imaginePromptFromText.ts @@ -1,4 +1,4 @@ -import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client'; import { useModelsStore } from '~/modules/llms/store-llms'; diff --git a/src/modules/aifn/react/react.ts b/src/modules/aifn/react/react.ts index 4aac50e78..5a264fba6 100644 --- a/src/modules/aifn/react/react.ts +++ b/src/modules/aifn/react/react.ts @@ -2,11 +2,10 @@ * porting of implementation from here: https://til.simonwillison.net/llms/python-react-pattern */ -import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; import { DLLMId } from '~/modules/llms/store-llms'; import { callApiSearchGoogle } from 
'~/modules/google/search.client'; import { callBrowseFetchPage } from '~/modules/browse/browse.client'; -import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client'; // prompt to implement the ReAct paradigm: https://arxiv.org/abs/2210.03629 diff --git a/src/modules/aifn/summarize/summerize.ts b/src/modules/aifn/summarize/summerize.ts index 1d9179cff..1635a95f1 100644 --- a/src/modules/aifn/summarize/summerize.ts +++ b/src/modules/aifn/summarize/summerize.ts @@ -1,5 +1,5 @@ import { DLLMId, findLLMOrThrow } from '~/modules/llms/store-llms'; -import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate'; +import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client'; // prompt to be tried when doing recursive summerization. diff --git a/src/modules/aifn/useStreamChatText.ts b/src/modules/aifn/useStreamChatText.ts index 6239c637f..4c54b65f4 100644 --- a/src/modules/aifn/useStreamChatText.ts +++ b/src/modules/aifn/useStreamChatText.ts @@ -1,8 +1,7 @@ import * as React from 'react'; import type { DLLMId } from '~/modules/llms/store-llms'; -import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types'; -import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate'; +import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client'; export function useStreamChatText() { @@ -25,7 +24,7 @@ export function useStreamChatText() { try { let lastText = ''; - await llmStreamChatGenerate(llmId, prompt, abortControllerRef.current.signal, (update) => { + await llmStreamingChatGenerate(llmId, prompt, null, null, abortControllerRef.current.signal, (update) => { if (update.text) { lastText = update.text; setPartialText(lastText); diff --git a/src/modules/llms/client/llm.client.types.ts b/src/modules/llms/client/llm.client.types.ts deleted file mode 100644 index 3732e64a9..000000000 --- a/src/modules/llms/client/llm.client.types.ts +++ /dev/null @@ -1,27 +0,0 @@ -import type { OpenAIWire } from '../server/openai/openai.wiretypes'; - - -// Model List types -// export { type ModelDescriptionSchema } from '../server/llm.server.types'; - - -// Chat Generate types - -export interface VChatMessageIn { - role: 'assistant' | 'system' | 'user'; // | 'function'; - content: string; - //name?: string; // when role: 'function' -} - -export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef; - -export interface VChatMessageOut { - role: 'assistant' | 'system' | 'user'; - content: string; - finish_reason: 'stop' | 'length' | null; -} - -export interface VChatMessageOrFunctionCallOut extends VChatMessageOut { - function_name: string; - function_arguments: object | null; -} \ No newline at end of file diff --git a/src/modules/llms/client/llmChatGenerate.ts b/src/modules/llms/client/llmChatGenerate.ts deleted file mode 100644 index abf9618b3..000000000 --- a/src/modules/llms/client/llmChatGenerate.ts +++ /dev/null @@ -1,23 +0,0 @@ -import type { DLLMId } from '../store-llms'; -import { findVendorForLlmOrThrow } from '../vendors/vendors.registry'; - -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from './llm.client.types'; - - -export async function llmChatGenerateOrThrow( - llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number, -): Promise { - - // id to DLLM and vendor - const { llm, vendor } = 
findVendorForLlmOrThrow(llmId); - - // FIXME: relax the forced cast - const options = llm.options as TLLMOptions; - - // get the access - const partialSourceSetup = llm._source.setup; - const access = vendor.getTransportAccess(partialSourceSetup); - - // execute via the vendor - return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens); -} diff --git a/src/modules/llms/llm.client.ts b/src/modules/llms/llm.client.ts new file mode 100644 index 000000000..73957f08c --- /dev/null +++ b/src/modules/llms/llm.client.ts @@ -0,0 +1,74 @@ +import type { DLLMId } from './store-llms'; +import type { OpenAIWire } from './server/openai/openai.wiretypes'; +import { findVendorForLlmOrThrow } from './vendors/vendors.registry'; + + +// LLM Client Types +// NOTE: Model List types in '../server/llm.server.types'; + +export interface VChatMessageIn { + role: 'assistant' | 'system' | 'user'; // | 'function'; + content: string; + //name?: string; // when role: 'function' +} + +export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef; + +export interface VChatMessageOut { + role: 'assistant' | 'system' | 'user'; + content: string; + finish_reason: 'stop' | 'length' | null; +} + +export interface VChatMessageOrFunctionCallOut extends VChatMessageOut { + function_name: string; + function_arguments: object | null; +} + + +// LLM Client Functions + +export async function llmChatGenerateOrThrow( + llmId: DLLMId, + messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, forceFunctionName: string | null, + maxTokens?: number, +): Promise { + + // id to DLLM and vendor + const { llm, vendor } = findVendorForLlmOrThrow(llmId); + + // FIXME: relax the forced cast + const options = llm.options as TLLMOptions; + + // get the access + const partialSourceSetup = llm._source.setup; + const access = vendor.getTransportAccess(partialSourceSetup); + + // execute via the vendor + return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens); +} + + +export async function llmStreamingChatGenerate( + llmId: DLLMId, + messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, + forceFunctionName: string | null, + abortSignal: AbortSignal, + onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void, +): Promise { + + // id to DLLM and vendor + const { llm, vendor } = findVendorForLlmOrThrow(llmId); + + // FIXME: relax the forced cast + const llmOptions = llm.options as TLLMOptions; + + // get the access + const partialSourceSetup = llm._source.setup; + const access = vendor.getTransportAccess(partialSourceSetup); // as ChatStreamInputSchema['access']; + + // execute via the vendor + return await vendor.streamingChatGenerateOrThrow(access, llmId, llmOptions, messages, functions, forceFunctionName, abortSignal, onUpdate); +} diff --git a/src/modules/llms/server/llms.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts similarity index 93% rename from src/modules/llms/server/llms.streaming.ts rename to src/modules/llms/server/llm.server.streaming.ts index 7383d89f5..47a08a312 100644 --- a/src/modules/llms/server/llms.streaming.ts +++ b/src/modules/llms/server/llm.server.streaming.ts @@ -9,6 +9,9 @@ import { createEmptyReadableStream, debugGenerateCurlCommand, safeErrorString, S import type { AnthropicWire } from './anthropic/anthropic.wiretypes'; import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from 
'./anthropic/anthropic.router'; +// Gemini server imports +import { geminiAccessSchema } from './gemini/gemini.router'; + // Ollama server imports import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes'; import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from './ollama/ollama.router'; @@ -37,24 +40,24 @@ type EventStreamFormat = 'sse' | 'json-nl'; type AIStreamParser = (data: string) => { text: string, close: boolean }; -const chatStreamInputSchema = z.object({ - access: z.union([anthropicAccessSchema, ollamaAccessSchema, openAIAccessSchema]), +const chatStreamingInputSchema = z.object({ + access: z.union([anthropicAccessSchema, geminiAccessSchema, ollamaAccessSchema, openAIAccessSchema]), model: openAIModelSchema, history: openAIHistorySchema, }); -export type ChatStreamInputSchema = z.infer; +export type ChatStreamingInputSchema = z.infer; -const chatStreamFirstOutputPacketSchema = z.object({ +const chatStreamingFirstOutputPacketSchema = z.object({ model: z.string(), }); -export type ChatStreamFirstOutputPacketSchema = z.infer; +export type ChatStreamingFirstOutputPacketSchema = z.infer; export async function llmStreamingRelayHandler(req: NextRequest): Promise { // inputs - reuse the tRPC schema const body = await req.json(); - const { access, model, history } = chatStreamInputSchema.parse(body); + const { access, model, history } = chatStreamingInputSchema.parse(body); // access/dialect dependent setup: // - requestAccess: the headers and URL to use for the upstream API call @@ -240,7 +243,7 @@ function createAnthropicStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun) { hasBegun = true; - const firstPacket: ChatStreamFirstOutputPacketSchema = { model: json.model }; + const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model }; text = JSON.stringify(firstPacket) + text; } @@ -276,7 +279,7 @@ function createOllamaChatCompletionStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun && chunk.model) { hasBegun = true; - const firstPacket: ChatStreamFirstOutputPacketSchema = { model: chunk.model }; + const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: chunk.model }; text = JSON.stringify(firstPacket) + text; } @@ -317,7 +320,7 @@ function createOpenAIStreamParser(): AIStreamParser { // hack: prepend the model name to the first packet if (!hasBegun) { hasBegun = true; - const firstPacket: ChatStreamFirstOutputPacketSchema = { model: json.model }; + const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model }; text = JSON.stringify(firstPacket) + text; } diff --git a/src/modules/llms/vendors/IModelVendor.ts b/src/modules/llms/vendors/IModelVendor.ts index 6c2290b6a..e7cc7fbb0 100644 --- a/src/modules/llms/vendors/IModelVendor.ts +++ b/src/modules/llms/vendors/IModelVendor.ts @@ -1,10 +1,10 @@ import type React from 'react'; import type { TRPCClientErrorBase } from '@trpc/client'; -import type { DLLM, DModelSourceId } from '../store-llms'; +import type { DLLM, DLLMId, DModelSourceId } from '../store-llms'; import type { ModelDescriptionSchema } from '../server/llm.server.types'; import type { ModelVendorId } from './vendors.registry'; -import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../client/llm.client.types'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '~/modules/llms/llm.client'; 
export interface IModelVendor> { @@ -43,4 +43,14 @@ export interface IModelVendor Promise; + streamingChatGenerateOrThrow: ( + access: TAccess, + llmId: DLLMId, + llmOptions: TLLMOptions, + messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, forceFunctionName: string | null, + abortSignal: AbortSignal, + onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void, + ) => Promise; + } diff --git a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts index c8d4fd2d1..e007d8f9f 100644 --- a/src/modules/llms/vendors/anthropic/anthropic.vendor.ts +++ b/src/modules/llms/vendors/anthropic/anthropic.vendor.ts @@ -5,7 +5,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { AnthropicAccessSchema } from '../../server/anthropic/anthropic.router'; import type { IModelVendor } from '../IModelVendor'; -import type { VChatMessageOut } from '../../client/llm.client.types'; +import type { VChatMessageOut } from '../../llm.client'; +import { unifiedStreamingClient } from '../unifiedStreamingClient'; import { LLMOptionsOpenAI } from '../openai/openai.vendor'; import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; @@ -77,4 +78,7 @@ export const ModelVendorAnthropic: IModelVendor( + access: ChatStreamingInputSchema['access'], llmId: DLLMId, + llmOptions: TLLMOptions, messages: VChatMessageIn[], - abortSignal: AbortSignal, - onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void, -): Promise { - const { llm, vendor } = findVendorForLlmOrThrow(llmId); - const access = vendor.getTransportAccess(llm._source.setup) as ChatStreamInputSchema['access']; - return await vendorStreamChat(access, llm, messages, abortSignal, onUpdate); -} - - -async function vendorStreamChat( - access: ChatStreamInputSchema['access'], - llm: DLLM, - messages: VChatMessageIn[], + functions: VChatFunctionIn[] | null, forceFunctionName: string | null, abortSignal: AbortSignal, onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void, ) { @@ -80,12 +65,12 @@ async function vendorStreamChat( } // model params (llm) - const { llmRef, llmTemperature, llmResponseTokens } = (llm.options as any) || {}; + const { llmRef, llmTemperature, llmResponseTokens } = (llmOptions as any) || {}; if (!llmRef || llmTemperature === undefined || llmResponseTokens === undefined) - throw new Error(`Error in configuration for model ${llm.id}: ${JSON.stringify(llm.options)}`); + throw new Error(`Error in configuration for model ${llmId}: ${JSON.stringify(llmOptions)}`); // prepare the input, similarly to the tRPC openAI.chatGenerate - const input: ChatStreamInputSchema = { + const input: ChatStreamingInputSchema = { access, model: { id: llmRef, @@ -132,7 +117,7 @@ async function vendorStreamChat( incrementalText = incrementalText.substring(endOfJson + 1); parsedFirstPacket = true; try { - const parsed: ChatStreamFirstOutputPacketSchema = JSON.parse(json); + const parsed: ChatStreamingFirstOutputPacketSchema = JSON.parse(json); onUpdate({ originLLM: parsed.model }, false); } catch (e) { // error parsing JSON, ignore diff --git a/src/modules/llms/vendors/vendors.registry.ts b/src/modules/llms/vendors/vendors.registry.ts index 061c34ea1..054a8a6bf 100644 --- a/src/modules/llms/vendors/vendors.registry.ts +++ b/src/modules/llms/vendors/vendors.registry.ts @@ -1,6 +1,6 @@ import { ModelVendorAnthropic } from 
'./anthropic/anthropic.vendor'; import { ModelVendorAzure } from './azure/azure.vendor'; -import { ModelVendorGemini } from './googleai/gemini.vendor'; +import { ModelVendorGemini } from './gemini/gemini.vendor'; import { ModelVendorLocalAI } from './localai/localai.vendor'; import { ModelVendorMistral } from './mistral/mistral.vendor'; import { ModelVendorOllama } from './ollama/ollama.vendor'; From bc5a38fa89a7ff3b12c359fd5afe0ee7a2fc9af4 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 19:13:20 -0800 Subject: [PATCH 16/24] Models List: show a helpful message --- src/modules/llms/models-modal/ModelsList.tsx | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/modules/llms/models-modal/ModelsList.tsx b/src/modules/llms/models-modal/ModelsList.tsx index 0336aeec0..b24d6d3b6 100644 --- a/src/modules/llms/models-modal/ModelsList.tsx +++ b/src/modules/llms/models-modal/ModelsList.tsx @@ -111,7 +111,13 @@ export function ModelsList(props: { pl: { xs: 0, md: 1 }, overflowY: 'auto', }}> - {items} + {items.length > 0 ? items : ( + + + Please configure the service and update the list of models. + + + )} ); } \ No newline at end of file From b08ecc90125b99c7215bd5f45a0ff69a88cdf957 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 19:13:29 -0800 Subject: [PATCH 17/24] Models Modal: improve caps --- src/modules/llms/models-modal/ModelsModal.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modules/llms/models-modal/ModelsModal.tsx b/src/modules/llms/models-modal/ModelsModal.tsx index e870914a1..57b7ebbed 100644 --- a/src/modules/llms/models-modal/ModelsModal.tsx +++ b/src/modules/llms/models-modal/ModelsModal.tsx @@ -65,7 +65,7 @@ export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) { title={<>Configure AI Models} startButton={ multiSource ? 
setShowAllSources(all => !all)} /> : undefined } From 9952b757b88437ab2c7b16a32d1a7d4446d72be5 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 20:34:16 -0800 Subject: [PATCH 18/24] Gemini: client version --- src/modules/llms/server/gemini/gemini.router.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/modules/llms/server/gemini/gemini.router.ts b/src/modules/llms/server/gemini/gemini.router.ts index fde91180d..b4ecec0b7 100644 --- a/src/modules/llms/server/gemini/gemini.router.ts +++ b/src/modules/llms/server/gemini/gemini.router.ts @@ -1,6 +1,8 @@ import { z } from 'zod'; import { TRPCError } from '@trpc/server'; +import packageJson from '../../../../../package.json'; + import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server'; import { fetchJsonOrTRPCError } from '~/server/api/trpc.serverutils'; @@ -32,6 +34,7 @@ export function geminiAccess(access: GeminiAccessSchema, modelRefId: string | nu return { headers: { 'Content-Type': 'application/json', + 'x-goog-api-client': `big-agi/${packageJson['version'] || '1.0.0'}`, 'x-goog-api-key': access.geminiKey, }, url: geminiHost + apiPath, From 7b5b852793d4582637b591221962afb5ca426473 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 20:42:31 -0800 Subject: [PATCH 19/24] Gemini: trim key --- src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx b/src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx index cdf1f5be6..c1d2dc3ea 100644 --- a/src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx +++ b/src/modules/llms/vendors/gemini/GeminiSourceSetup.tsx @@ -40,7 +40,7 @@ export function GeminiSourceSetup(props: { sourceId: DModelSourceId }) { ? !geminiKey && request Key : 'โœ”๏ธ already set in server'} } - value={geminiKey} onChange={value => updateSetup({ geminiKey: value })} + value={geminiKey} onChange={value => updateSetup({ geminiKey: value.trim() })} required={needsUserKey} isError={showKeyError} placeholder='...' /> From 45046c70ed9d4ab196f5ed58d4bfc6d643591011 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 21:22:08 -0800 Subject: [PATCH 20/24] Gemini: stream on --- .../llms/server/gemini/gemini.wiretypes.ts | 8 ++- .../llms/server/llm.server.streaming.ts | 65 +++++++++++++++---- 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/modules/llms/server/gemini/gemini.wiretypes.ts b/src/modules/llms/server/gemini/gemini.wiretypes.ts index e5aa890ac..d890cd36e 100644 --- a/src/modules/llms/server/gemini/gemini.wiretypes.ts +++ b/src/modules/llms/server/gemini/gemini.wiretypes.ts @@ -4,7 +4,8 @@ import { z } from 'zod'; export const geminiModelsListPath = '/v1beta/models?pageSize=1000'; export const geminiModelsGenerateContentPath = '/v1beta/{model=models/*}:generateContent'; -export const geminiModelsStreamGenerateContentPath = '/v1beta/{model=models/*}:streamGenerateContent'; +// see alt=sse on https://cloud.google.com/apis/docs/system-parameters#definitions +export const geminiModelsStreamGenerateContentPath = '/v1beta/{model=models/*}:streamGenerateContent?alt=sse'; // models.list = /v1beta/models @@ -174,8 +175,9 @@ export const geminiGeneratedContentResponseSchema = z.object({ tokenCount: z.number().optional(), // groundingAttributions: z.array(GroundingAttribution).optional(), // This field is populated for GenerateAnswer calls. 
})), + // NOTE: promptFeedback is only send in the first chunk in a streaming response promptFeedback: z.object({ blockReason: z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']).optional(), safetyRatings: z.array(geminiSafetyRatingSchema), - }), -}); \ No newline at end of file + }).optional(), +}); diff --git a/src/modules/llms/server/llm.server.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts index 47a08a312..5a60d75ad 100644 --- a/src/modules/llms/server/llm.server.streaming.ts +++ b/src/modules/llms/server/llm.server.streaming.ts @@ -10,7 +10,8 @@ import type { AnthropicWire } from './anthropic/anthropic.wiretypes'; import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from './anthropic/anthropic.router'; // Gemini server imports -import { geminiAccessSchema } from './gemini/gemini.router'; +import { geminiAccess, geminiAccessSchema, geminiGenerateContentPayload } from './gemini/gemini.router'; +import { geminiGeneratedContentResponseSchema, geminiModelsStreamGenerateContentPath } from './gemini/gemini.wiretypes'; // Ollama server imports import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes'; @@ -75,14 +76,20 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise} TransformStream parsing events. */ -function createEventStreamTransformer(vendorTextParser: AIStreamParser, inputFormat: EventStreamFormat, dialectLabel: string): TransformStream { +function createEventStreamTransformer(inputFormat: EventStreamFormat, vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream { const textDecoder = new TextDecoder(); const textEncoder = new TextEncoder(); let eventSourceParser: EventSourceParser; @@ -232,7 +242,7 @@ function createEventStreamTransformer(vendorTextParser: AIStreamParser, inputFor /// Stream Parsers -function createAnthropicStreamParser(): AIStreamParser { +function createStreamParserAnthropic(): AIStreamParser { let hasBegun = false; return (data: string) => { @@ -251,7 +261,40 @@ function createAnthropicStreamParser(): AIStreamParser { }; } -function createOllamaChatCompletionStreamParser(): AIStreamParser { +function createStreamParserGemini(modelName: string): AIStreamParser { + let hasBegun = false; + + // this can throw, it's catched upstream + return (data: string) => { + + // parse the JSON chunk + const wireGenerationChunk = JSON.parse(data); + const generationChunk = geminiGeneratedContentResponseSchema.parse(wireGenerationChunk); + + // expect a single completion + const singleCandidate = generationChunk.candidates?.[0] ?? 
null; + if (!singleCandidate || !singleCandidate.content?.parts.length) + throw new Error(`Gemini: expected 1 completion, got ${generationChunk.candidates?.length}`); + + // expect a single part + if (singleCandidate.content.parts.length !== 1 || !('text' in singleCandidate.content.parts[0])) + throw new Error(`Gemini: expected 1 text part, got ${singleCandidate.content.parts.length}`); + + // expect a single text in the part + let text = singleCandidate.content.parts[0].text || ''; + + // hack: prepend the model name to the first packet + if (!hasBegun) { + hasBegun = true; + const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: modelName }; + text = JSON.stringify(firstPacket) + text; + } + + return { text, close: false }; + }; +} + +function createStreamParserOllama(): AIStreamParser { let hasBegun = false; return (data: string) => { @@ -287,7 +330,7 @@ function createOllamaChatCompletionStreamParser(): AIStreamParser { }; } -function createOpenAIStreamParser(): AIStreamParser { +function createStreamParserOpenAI(): AIStreamParser { let hasBegun = false; let hasWarned = false; From efff7126af14633148f9fc9446d133d3732b059f Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 22:07:39 -0800 Subject: [PATCH 21/24] Gemini: final touches --- .../chat/components/message/ChatMessage.tsx | 2 + .../llms/server/gemini/gemini.router.ts | 56 +++++++++++++------ .../llms/server/gemini/gemini.wiretypes.ts | 2 + .../llms/server/llm.server.streaming.ts | 4 +- .../llms/vendors/gemini/gemini.vendor.ts | 1 - 5 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/apps/chat/components/message/ChatMessage.tsx b/src/apps/chat/components/message/ChatMessage.tsx index 2fd6ccbbc..616617f94 100644 --- a/src/apps/chat/components/message/ChatMessage.tsx +++ b/src/apps/chat/components/message/ChatMessage.tsx @@ -167,6 +167,8 @@ function explainErrorInMessage(text: string, isAssistant: boolean, modelId?: str make sure the usage is under the limits. 
; } + // else + // errorMessage = <>{text || 'Unknown error'}; return { errorMessage, isAssistantError }; } diff --git a/src/modules/llms/server/gemini/gemini.router.ts b/src/modules/llms/server/gemini/gemini.router.ts index b4ecec0b7..7f7b0899e 100644 --- a/src/modules/llms/server/gemini/gemini.router.ts +++ b/src/modules/llms/server/gemini/gemini.router.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; import { TRPCError } from '@trpc/server'; +import { env } from '~/server/env.mjs'; import packageJson from '../../../../../package.json'; @@ -11,7 +12,7 @@ import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.ty import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; -import { GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes'; +import { GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes'; // Default hosts @@ -22,35 +23,58 @@ const DEFAULT_GEMINI_HOST = 'https://generativelanguage.googleapis.com'; export function geminiAccess(access: GeminiAccessSchema, modelRefId: string | null, apiPath: string): { headers: HeadersInit, url: string } { - // handle paths that require a model name + const geminiKey = access.geminiKey || env.GEMINI_API_KEY || ''; + const geminiHost = fixupHost(DEFAULT_GEMINI_HOST, apiPath); + + // update model-dependent paths if (apiPath.includes('{model=models/*}')) { if (!modelRefId) throw new Error(`geminiAccess: modelRefId is required for ${apiPath}`); apiPath = apiPath.replace('{model=models/*}', modelRefId); } - const geminiHost = fixupHost(DEFAULT_GEMINI_HOST, apiPath); - return { headers: { 'Content-Type': 'application/json', 'x-goog-api-client': `big-agi/${packageJson['version'] || '1.0.0'}`, - 'x-goog-api-key': access.geminiKey, + 'x-goog-api-key': geminiKey, }, url: geminiHost + apiPath, }; } -export const geminiGenerateContentPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, n: number): GeminiGenerateContentRequest => { - const contents: GeminiGenerateContentRequest['contents'] = []; - history.forEach((message) => { - // hack for now - the model seems to want prompts to alternate - if (message.role === 'system') { - contents.push({ role: 'user', parts: [{ text: message.content }] }); - contents.push({ role: 'model', parts: [{ text: 'Ok.' }] }); - } else - contents.push({ role: message.role === 'assistant' ? 'model' : 'user', parts: [{ text: message.content }] }); - }); +/** + * We specially encode the history to match the Gemini API requirements. + * Gemini does not want 2 consecutive messages from the same role, so we alternate. + * - System messages = [User, Model'Ok'] + * - User and Assistant messages are coalesced into a single message (e.g. 
[User, User, Assistant, Assistant, User] -> [User[2], Assistant[2], User[1]]) + */ +export const geminiGenerateContentTextPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, n: number): GeminiGenerateContentRequest => { + + // convert the history to a Gemini format + const contents: GeminiContentSchema[] = []; + for (const _historyElement of history) { + + const { role: msgRole, content: msgContent } = _historyElement; + + // System message - we treat it as per the example in https://ai.google.dev/tutorials/ai-studio_quickstart#chat_example + if (msgRole === 'system') { + contents.push({ role: 'user', parts: [{ text: msgContent }] }); + contents.push({ role: 'model', parts: [{ text: 'Ok' }] }); + continue; + } + + // User or Assistant message + const nextRole: GeminiContentSchema['role'] = msgRole === 'assistant' ? 'model' : 'user'; + if (contents.length && contents[contents.length - 1].role === nextRole) { + // coalesce with the previous message + contents[contents.length - 1].parts.push({ text: msgContent }); + } else { + // create a new message + contents.push({ role: nextRole, parts: [{ text: msgContent }] }); + } + } + return { contents, generationConfig: { @@ -160,7 +184,7 @@ export const llmGeminiRouter = createTRPCRouter({ .mutation(async ({ input: { access, history, model } }) => { // generate the content - const wireGeneration = await geminiPOST(access, model.id, geminiGenerateContentPayload(model, history, 1), geminiModelsGenerateContentPath); + const wireGeneration = await geminiPOST(access, model.id, geminiGenerateContentTextPayload(model, history, 1), geminiModelsGenerateContentPath); const generation = geminiGeneratedContentResponseSchema.parse(wireGeneration); // only use the first result (and there should be only one) diff --git a/src/modules/llms/server/gemini/gemini.wiretypes.ts b/src/modules/llms/server/gemini/gemini.wiretypes.ts index d890cd36e..21c80b51a 100644 --- a/src/modules/llms/server/gemini/gemini.wiretypes.ts +++ b/src/modules/llms/server/gemini/gemini.wiretypes.ts @@ -123,6 +123,8 @@ const geminiContentSchema = z.object({ parts: z.array(geminiContentPartSchema), }); +export type GeminiContentSchema = z.infer; + export const geminiGenerateContentRequest = z.object({ contents: z.array(geminiContentSchema), tools: z.array(geminiToolSchema).optional(), diff --git a/src/modules/llms/server/llm.server.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts index 5a60d75ad..4c2ba9c9e 100644 --- a/src/modules/llms/server/llm.server.streaming.ts +++ b/src/modules/llms/server/llm.server.streaming.ts @@ -10,7 +10,7 @@ import type { AnthropicWire } from './anthropic/anthropic.wiretypes'; import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from './anthropic/anthropic.router'; // Gemini server imports -import { geminiAccess, geminiAccessSchema, geminiGenerateContentPayload } from './gemini/gemini.router'; +import { geminiAccess, geminiAccessSchema, geminiGenerateContentTextPayload } from './gemini/gemini.router'; import { geminiGeneratedContentResponseSchema, geminiModelsStreamGenerateContentPath } from './gemini/gemini.wiretypes'; // Ollama server imports @@ -81,7 +81,7 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise Date: Tue, 19 Dec 2023 23:59:04 -0800 Subject: [PATCH 22/24] Streaming: muxing format --- .../llms/server/llm.server.streaming.ts | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/modules/llms/server/llm.server.streaming.ts 
b/src/modules/llms/server/llm.server.streaming.ts index 4c2ba9c9e..1a3b7198e 100644 --- a/src/modules/llms/server/llm.server.streaming.ts +++ b/src/modules/llms/server/llm.server.streaming.ts @@ -27,7 +27,7 @@ import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHi * - 'sse' is the default format, and is used by all vendors except Ollama * - 'json-nl' is used by Ollama */ -type EventStreamFormat = 'sse' | 'json-nl'; +type MuxingFormat = 'sse' | 'json-nl'; /** @@ -62,11 +62,11 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise { - console.error('createJsonNewlineParser.reset() not implemented'); + console.error('createDemuxerJsonNewline.reset() not implemented'); }, }; } @@ -183,7 +183,7 @@ function createJsonNewlineParser(onParse: EventSourceParseCallback): EventSource * Creates a TransformStream that parses events from an EventSource stream using a custom parser. * @returns {TransformStream} TransformStream parsing events. */ -function createEventStreamTransformer(inputFormat: EventStreamFormat, vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream { +function createEventStreamTransformer(muxingFormat: MuxingFormat, vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream { const textDecoder = new TextDecoder(); const textEncoder = new TextEncoder(); let eventSourceParser: EventSourceParser; @@ -226,10 +226,10 @@ function createEventStreamTransformer(inputFormat: EventStreamFormat, vendorText } }; - if (inputFormat === 'sse') + if (muxingFormat === 'sse') eventSourceParser = createEventsourceParser(onNewEvent); - else if (inputFormat === 'json-nl') - eventSourceParser = createJsonNewlineParser(onNewEvent); + else if (muxingFormat === 'json-nl') + eventSourceParser = createDemuxerJsonNewline(onNewEvent); }, // stream=true is set because the data is not guaranteed to be final and un-chunked From 6b62a6733b02f180faf8bb34b4dbc0b05f499d03 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 19 Dec 2023 23:59:22 -0800 Subject: [PATCH 23/24] Gemini: show block reason --- src/modules/llms/server/llm.server.streaming.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/modules/llms/server/llm.server.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts index 1a3b7198e..54ed69b81 100644 --- a/src/modules/llms/server/llm.server.streaming.ts +++ b/src/modules/llms/server/llm.server.streaming.ts @@ -271,6 +271,12 @@ function createStreamParserGemini(modelName: string): AIStreamParser { const wireGenerationChunk = JSON.parse(data); const generationChunk = geminiGeneratedContentResponseSchema.parse(wireGenerationChunk); + // Prompt Safety Errors: pass through errors from Gemini + if (generationChunk.promptFeedback?.blockReason) { + const { blockReason, safetyRatings } = generationChunk.promptFeedback; + return { text: `[Gemini Prompt Blocked] ${blockReason}: ${JSON.stringify(safetyRatings || 'Unknown Safety Ratings', null, 2)}`, close: true }; + } + // expect a single completion const singleCandidate = generationChunk.candidates?.[0] ?? 
null; if (!singleCandidate || !singleCandidate.content?.parts.length) From fdb66da1a737ca76cb20352fa0852f7318ff58f0 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Wed, 20 Dec 2023 00:13:35 -0800 Subject: [PATCH 24/24] Gemini: choose a content filtering threshold --- .../llms/server/gemini/gemini.router.ts | 23 ++++++---- .../llms/server/gemini/gemini.wiretypes.ts | 21 +++++---- .../llms/server/llm.server.streaming.ts | 2 +- .../llms/vendors/gemini/GeminiSourceSetup.tsx | 45 ++++++++++++++++++- .../llms/vendors/gemini/gemini.vendor.ts | 5 +++ 5 files changed, 75 insertions(+), 21 deletions(-) diff --git a/src/modules/llms/server/gemini/gemini.router.ts b/src/modules/llms/server/gemini/gemini.router.ts index 7f7b0899e..a2625814d 100644 --- a/src/modules/llms/server/gemini/gemini.router.ts +++ b/src/modules/llms/server/gemini/gemini.router.ts @@ -12,7 +12,7 @@ import { listModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.ty import { fixupHost, openAIChatGenerateOutputSchema, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; -import { GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes'; +import { GeminiBlockSafetyLevel, geminiBlockSafetyLevelSchema, GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes'; // Default hosts @@ -49,7 +49,7 @@ export function geminiAccess(access: GeminiAccessSchema, modelRefId: string | nu * - System messages = [User, Model'Ok'] * - User and Assistant messages are coalesced into a single message (e.g. [User, User, Assistant, Assistant, User] -> [User[2], Assistant[2], User[1]]) */ -export const geminiGenerateContentTextPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, n: number): GeminiGenerateContentRequest => { +export const geminiGenerateContentTextPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, safety: GeminiBlockSafetyLevel, n: number): GeminiGenerateContentRequest => { // convert the history to a Gemini format const contents: GeminiContentSchema[] = []; @@ -82,12 +82,12 @@ export const geminiGenerateContentTextPayload = (model: OpenAIModelSchema, histo ...(model.maxTokens && { maxOutputTokens: model.maxTokens }), temperature: model.temperature, }, - // safetySettings: [ - // { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold: 'BLOCK_NONE' }, - // { category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_NONE' }, - // { category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_NONE' }, - // { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: 'BLOCK_NONE' }, - // ], + safetySettings: safety !== 'HARM_BLOCK_THRESHOLD_UNSPECIFIED' ? 
[ + { category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', threshold: safety }, + { category: 'HARM_CATEGORY_HATE_SPEECH', threshold: safety }, + { category: 'HARM_CATEGORY_HARASSMENT', threshold: safety }, + { category: 'HARM_CATEGORY_DANGEROUS_CONTENT', threshold: safety }, + ] : undefined, }; }; @@ -108,6 +108,7 @@ async function geminiPOST(access: export const geminiAccessSchema = z.object({ dialect: z.enum(['gemini']), geminiKey: z.string(), + minSafetyLevel: geminiBlockSafetyLevelSchema, }); export type GeminiAccessSchema = z.infer; @@ -123,6 +124,10 @@ const chatGenerateInputSchema = z.object({ }); +/** + * See https://github.com/google/generative-ai-js/tree/main/packages/main/src for + * the official Google implementation. + */ export const llmGeminiRouter = createTRPCRouter({ /* [Gemini] models.list = /v1beta/models */ @@ -184,7 +189,7 @@ export const llmGeminiRouter = createTRPCRouter({ .mutation(async ({ input: { access, history, model } }) => { // generate the content - const wireGeneration = await geminiPOST(access, model.id, geminiGenerateContentTextPayload(model, history, 1), geminiModelsGenerateContentPath); + const wireGeneration = await geminiPOST(access, model.id, geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1), geminiModelsGenerateContentPath); const generation = geminiGeneratedContentResponseSchema.parse(wireGeneration); // only use the first result (and there should be only one) diff --git a/src/modules/llms/server/gemini/gemini.wiretypes.ts b/src/modules/llms/server/gemini/gemini.wiretypes.ts index 21c80b51a..c7e4f9a3e 100644 --- a/src/modules/llms/server/gemini/gemini.wiretypes.ts +++ b/src/modules/llms/server/gemini/gemini.wiretypes.ts @@ -95,16 +95,19 @@ const geminiHarmCategorySchema = z.enum([ 'HARM_CATEGORY_DANGEROUS_CONTENT', ]); +export const geminiBlockSafetyLevelSchema = z.enum([ + 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', + 'BLOCK_LOW_AND_ABOVE', + 'BLOCK_MEDIUM_AND_ABOVE', + 'BLOCK_ONLY_HIGH', + 'BLOCK_NONE', +]); + +export type GeminiBlockSafetyLevel = z.infer; const geminiSafetySettingSchema = z.object({ category: geminiHarmCategorySchema, - threshold: z.enum([ - 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', - 'BLOCK_LOW_AND_ABOVE', - 'BLOCK_MEDIUM_AND_ABOVE', - 'BLOCK_ONLY_HIGH', - 'BLOCK_NONE', - ]), + threshold: geminiBlockSafetyLevelSchema, }); const geminiGenerationConfigSchema = z.object({ @@ -176,10 +179,10 @@ export const geminiGeneratedContentResponseSchema = z.object({ }).optional(), tokenCount: z.number().optional(), // groundingAttributions: z.array(GroundingAttribution).optional(), // This field is populated for GenerateAnswer calls. - })), + })).optional(), // NOTE: promptFeedback is only send in the first chunk in a streaming response promptFeedback: z.object({ blockReason: z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']).optional(), - safetyRatings: z.array(geminiSafetyRatingSchema), + safetyRatings: z.array(geminiSafetyRatingSchema).optional(), }).optional(), }); diff --git a/src/modules/llms/server/llm.server.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts index 54ed69b81..c5a1d4544 100644 --- a/src/modules/llms/server/llm.server.streaming.ts +++ b/src/modules/llms/server/llm.server.streaming.ts @@ -81,7 +81,7 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise + + + + + + + Gemini has + adjustable safety settings on four categories: Harassment, Hate speech, + Sexually explicit, and Dangerous content, in addition to non-adjustable built-in filters. 
+ By default, the model will block content with medium and above probability + of being unsafe. + + diff --git a/src/modules/llms/vendors/gemini/gemini.vendor.ts b/src/modules/llms/vendors/gemini/gemini.vendor.ts index 1723ac06f..ca7fc6d9c 100644 --- a/src/modules/llms/vendors/gemini/gemini.vendor.ts +++ b/src/modules/llms/vendors/gemini/gemini.vendor.ts @@ -5,6 +5,7 @@ import { backendCaps } from '~/modules/backend/state-backend'; import { apiAsync, apiQuery } from '~/common/util/trpc.client'; import type { GeminiAccessSchema } from '../../server/gemini/gemini.router'; +import type { GeminiBlockSafetyLevel } from '../../server/gemini/gemini.wiretypes'; import type { IModelVendor } from '../IModelVendor'; import type { VChatMessageOut } from '../../llm.client'; import { unifiedStreamingClient } from '../unifiedStreamingClient'; @@ -16,6 +17,7 @@ import { GeminiSourceSetup } from './GeminiSourceSetup'; export interface SourceSetupGemini { geminiKey: string; + minSafetyLevel: GeminiBlockSafetyLevel; } export interface LLMOptionsGemini { @@ -45,6 +47,7 @@ export const ModelVendorGemini: IModelVendor ({ geminiKey: '', + minSafetyLevel: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', }), validateSetup: (setup) => { return setup.geminiKey?.length > 0; @@ -52,6 +55,7 @@ export const ModelVendorGemini: IModelVendor ({ dialect: 'gemini', geminiKey: partialSetup?.geminiKey || '', + minSafetyLevel: partialSetup?.minSafetyLevel || 'HARM_BLOCK_THRESHOLD_UNSPECIFIED', }), // List Models @@ -89,4 +93,5 @@ export const ModelVendorGemini: IModelVendor
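Taken together, patches 21 and 24 define two request-building rules for Gemini: the role-alternation encoding of chat history (a system message becomes a user turn acknowledged by a model 'Ok' turn, and consecutive same-role messages are coalesced into one turn), and a single user-selected safety threshold fanned out over the four adjustable harm categories unless it is left unspecified. Below is a minimal TypeScript sketch of those two rules, assuming simplified stand-in types; SimpleMessage, buildGeminiContents, and buildSafetySettings are illustrative names, not the project's actual exports.

type SimpleRole = 'system' | 'user' | 'assistant';
interface SimpleMessage { role: SimpleRole; content: string; }

interface GeminiTurn { role: 'user' | 'model'; parts: { text: string }[]; }

// Encode chat history the way the Gemini API expects: roles must alternate,
// so a system prompt becomes a user turn followed by a model 'Ok' turn, and
// consecutive messages from the same role are merged into one turn's parts.
function buildGeminiContents(history: SimpleMessage[]): GeminiTurn[] {
  const contents: GeminiTurn[] = [];
  for (const { role, content } of history) {
    if (role === 'system') {
      contents.push({ role: 'user', parts: [{ text: content }] });
      contents.push({ role: 'model', parts: [{ text: 'Ok' }] });
      continue;
    }
    const nextRole = role === 'assistant' ? 'model' : 'user';
    const last = contents[contents.length - 1];
    if (last && last.role === nextRole)
      last.parts.push({ text: content }); // coalesce with the previous same-role turn
    else
      contents.push({ role: nextRole, parts: [{ text: content }] });
  }
  return contents;
}

type SafetyThreshold =
  | 'HARM_BLOCK_THRESHOLD_UNSPECIFIED'
  | 'BLOCK_LOW_AND_ABOVE'
  | 'BLOCK_MEDIUM_AND_ABOVE'
  | 'BLOCK_ONLY_HIGH'
  | 'BLOCK_NONE';

// Apply one chosen threshold to every adjustable category; omit the
// safetySettings field entirely when the threshold is left unspecified.
function buildSafetySettings(threshold: SafetyThreshold) {
  if (threshold === 'HARM_BLOCK_THRESHOLD_UNSPECIFIED') return undefined;
  return [
    'HARM_CATEGORY_SEXUALLY_EXPLICIT',
    'HARM_CATEGORY_HATE_SPEECH',
    'HARM_CATEGORY_HARASSMENT',
    'HARM_CATEGORY_DANGEROUS_CONTENT',
  ].map((category) => ({ category, threshold }));
}

// Example: [system, user, user] -> [user(prompt), model('Ok'), user(2 parts)]
console.log(buildGeminiContents([
  { role: 'system', content: 'Be concise.' },
  { role: 'user', content: 'Hello' },
  { role: 'user', content: 'What is Gemini?' },
]));
console.log(buildSafetySettings('BLOCK_ONLY_HIGH'));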