diff --git a/app/api/openai/[...path]/route.ts b/app/api/openai/[...path]/route.ts index 9df005a317a..8dc36f43355 100644 --- a/app/api/openai/[...path]/route.ts +++ b/app/api/openai/[...path]/route.ts @@ -1,4 +1,4 @@ -import { type OpenAIListModelResponse } from "@/app/client/platforms/openai"; +import { type OpenAI } from "@/app/client/openai/types"; import { getServerSideConfig } from "@/app/config/server"; import { OpenaiPath } from "@/app/constant"; import { prettyObject } from "@/app/utils/format"; @@ -6,9 +6,9 @@ import { NextRequest, NextResponse } from "next/server"; import { auth } from "../../auth"; import { requestOpenai } from "../../common"; -const ALLOWD_PATH = new Set(Object.values(OpenaiPath)); +const ALLOWD_PATH = new Set(Object.values(OpenaiPath) as string[]); -function getModels(remoteModelRes: OpenAIListModelResponse) { +function getModels(remoteModelRes: OpenAI.ListModelResponse) { const config = getServerSideConfig(); if (config.disableGPT4) { @@ -56,8 +56,8 @@ async function handle( const response = await requestOpenai(req); // list models - if (subpath === OpenaiPath.ListModelPath && response.status === 200) { - const resJson = (await response.json()) as OpenAIListModelResponse; + if (subpath === OpenaiPath.ListModel && response.status === 200) { + const resJson = await response.json(); const availableModels = getModels(resJson); return NextResponse.json(availableModels, { status: response.status, diff --git a/app/client/api.ts b/app/client/api.ts deleted file mode 100644 index b04dd88b88c..00000000000 --- a/app/client/api.ts +++ /dev/null @@ -1,151 +0,0 @@ -import { getClientConfig } from "../config/client"; -import { ACCESS_CODE_PREFIX } from "../constant"; -import { ChatMessage, ModelType, useAccessStore } from "../store"; -import { ChatGPTApi } from "./platforms/openai"; - -export const ROLES = ["system", "user", "assistant"] as const; -export type MessageRole = (typeof ROLES)[number]; - -export const Models = ["gpt-3.5-turbo", "gpt-4"] 
as const; -export type ChatModel = ModelType; - -export interface RequestMessage { - role: MessageRole; - content: string; -} - -export interface LLMConfig { - model: string; - temperature?: number; - top_p?: number; - stream?: boolean; - presence_penalty?: number; - frequency_penalty?: number; -} - -export interface ChatOptions { - messages: RequestMessage[]; - config: LLMConfig; - - onUpdate?: (message: string, chunk: string) => void; - onFinish: (message: string) => void; - onError?: (err: Error) => void; - onController?: (controller: AbortController) => void; -} - -export interface LLMUsage { - used: number; - total: number; -} - -export interface LLMModel { - name: string; - available: boolean; -} - -export abstract class LLMApi { - abstract chat(options: ChatOptions): Promise; - abstract usage(): Promise; - abstract models(): Promise; -} - -type ProviderName = "openai" | "azure" | "claude" | "palm"; - -interface Model { - name: string; - provider: ProviderName; - ctxlen: number; -} - -interface ChatProvider { - name: ProviderName; - apiConfig: { - baseUrl: string; - apiKey: string; - summaryModel: Model; - }; - models: Model[]; - - chat: () => void; - usage: () => void; -} - -export class ClientApi { - public llm: LLMApi; - - constructor() { - this.llm = new ChatGPTApi(); - } - - config() {} - - prompts() {} - - masks() {} - - async share(messages: ChatMessage[], avatarUrl: string | null = null) { - const msgs = messages - .map((m) => ({ - from: m.role === "user" ? "human" : "gpt", - value: m.content, - })) - .concat([ - { - from: "human", - value: - "Share from [ChatGPT Next Web]: https://github.com/Yidadaa/ChatGPT-Next-Web", - }, - ]); - // 敬告二开开发者们,为了开源大模型的发展,请不要修改上述消息,此消息用于后续数据清洗使用 - // Please do not modify this message - - console.log("[Share]", messages, msgs); - const clientConfig = getClientConfig(); - const proxyUrl = "/sharegpt"; - const rawUrl = "https://sharegpt.com/api/conversations"; - const shareUrl = clientConfig?.isApp ? 
rawUrl : proxyUrl; - const res = await fetch(shareUrl, { - body: JSON.stringify({ - avatarUrl, - items: msgs, - }), - headers: { - "Content-Type": "application/json", - }, - method: "POST", - }); - - const resJson = await res.json(); - console.log("[Share]", resJson); - if (resJson.id) { - return `https://shareg.pt/${resJson.id}`; - } - } -} - -export const api = new ClientApi(); - -export function getHeaders() { - const accessStore = useAccessStore.getState(); - let headers: Record = { - "Content-Type": "application/json", - "x-requested-with": "XMLHttpRequest", - }; - - const makeBearer = (token: string) => `Bearer ${token.trim()}`; - const validString = (x: string) => x && x.length > 0; - - // use user's api key first - if (validString(accessStore.token)) { - headers.Authorization = makeBearer(accessStore.token); - } else if ( - accessStore.enabledAccessControl() && - validString(accessStore.accessCode) - ) { - headers.Authorization = makeBearer( - ACCESS_CODE_PREFIX + accessStore.accessCode, - ); - } - - return headers; -} diff --git a/app/client/common/auth.ts b/app/client/common/auth.ts new file mode 100644 index 00000000000..9533ebfd2d3 --- /dev/null +++ b/app/client/common/auth.ts @@ -0,0 +1,28 @@ +import { getClientConfig } from "@/app/config/client"; +import { ACCESS_CODE_PREFIX } from "@/app/constant"; +import { useAccessStore } from "@/app/store"; + +export function bearer(value: string) { + return `Bearer ${value.trim()}`; +} + +export function getAuthHeaders(apiKey = "") { + const accessStore = useAccessStore.getState(); + const isApp = !!getClientConfig()?.isApp; + + let headers: Record = {}; + + if (apiKey) { + // use user's api key first + headers.Authorization = bearer(apiKey); + } else if ( + accessStore.enabledAccessControl() && + !isApp && + !!accessStore.accessCode + ) { + // or use access code + headers.Authorization = bearer(ACCESS_CODE_PREFIX + accessStore.accessCode); + } + + return headers; +} diff --git a/app/client/common/config.ts 
b/app/client/common/config.ts new file mode 100644 index 00000000000..127773a4c3b --- /dev/null +++ b/app/client/common/config.ts @@ -0,0 +1,5 @@ +export const COMMON_PROVIDER_CONFIG = { + customModels: "", + models: [] as string[], + autoFetchModels: false, // fetch available models from server or not +}; diff --git a/app/client/controller.ts b/app/client/common/controller.ts similarity index 100% rename from app/client/controller.ts rename to app/client/common/controller.ts diff --git a/app/client/common/share.ts b/app/client/common/share.ts new file mode 100644 index 00000000000..338e22cb285 --- /dev/null +++ b/app/client/common/share.ts @@ -0,0 +1,44 @@ +import { getClientConfig } from "@/app/config/client"; +import { ChatMessage } from "@/app/store"; + +export async function shareToShareGPT( + messages: ChatMessage[], + avatarUrl: string | null = null, +) { + const msgs = messages + .map((m) => ({ + from: m.role === "user" ? "human" : "gpt", + value: m.content, + })) + .concat([ + { + from: "human", + // 敬告二开开发者们,为了开源大模型的发展,请不要修改上述消息,此消息用于后续数据清洗使用 + // Please do not modify this message + value: + "Share from [ChatGPT Next Web]: https://github.com/Yidadaa/ChatGPT-Next-Web", + }, + ]); + + console.log("[Share]", messages, msgs); + const clientConfig = getClientConfig(); + const proxyUrl = "/sharegpt"; + const rawUrl = "https://sharegpt.com/api/conversations"; + const shareUrl = clientConfig?.isApp ? 
rawUrl : proxyUrl; + const res = await fetch(shareUrl, { + body: JSON.stringify({ + avatarUrl, + items: msgs, + }), + headers: { + "Content-Type": "application/json", + }, + method: "POST", + }); + + const resJson = await res.json(); + console.log("[Share]", resJson); + if (resJson.id) { + return `https://shareg.pt/${resJson.id}`; + } +} diff --git a/app/client/core.ts b/app/client/core.ts new file mode 100644 index 00000000000..a75cf3fc067 --- /dev/null +++ b/app/client/core.ts @@ -0,0 +1,28 @@ +import { MaskConfig, ProviderConfig } from "../store"; +import { shareToShareGPT } from "./common/share"; +import { createOpenAiClient } from "./openai"; +import { ChatControllerPool } from "./common/controller"; + +export const LLMClients = { + openai: createOpenAiClient, +}; + +export function createLLMClient( + config: ProviderConfig, + maskConfig: MaskConfig, +) { + return LLMClients[maskConfig.provider as any as keyof typeof LLMClients]( + config, + maskConfig.modelConfig, + ); +} + +export function createApi() { + return { + createLLMClient, + shareToShareGPT, + controllerManager: ChatControllerPool, + }; +} + +export const api = createApi(); diff --git a/app/client/index.ts b/app/client/index.ts new file mode 100644 index 00000000000..4e22af65629 --- /dev/null +++ b/app/client/index.ts @@ -0,0 +1,2 @@ +export * from "./types"; +export * from "./core"; diff --git a/app/client/openai/config.ts b/app/client/openai/config.ts new file mode 100644 index 00000000000..b27534162e6 --- /dev/null +++ b/app/client/openai/config.ts @@ -0,0 +1,20 @@ +import { COMMON_PROVIDER_CONFIG } from "../common/config"; + +export const OpenAIConfig = { + model: { + model: "gpt-3.5-turbo" as string, + summarizeModel: "gpt-3.5-turbo", + + temperature: 0.5, + top_p: 1, + max_tokens: 2000, + presence_penalty: 0, + frequency_penalty: 0, + }, + provider: { + name: "OpenAI", + endpoint: "https://api.openai.com", + apiKey: "", + ...COMMON_PROVIDER_CONFIG, + }, +}; diff --git 
a/app/client/openai/index.ts b/app/client/openai/index.ts new file mode 100644 index 00000000000..a452936de97 --- /dev/null +++ b/app/client/openai/index.ts @@ -0,0 +1,295 @@ +import { + EventStreamContentType, + fetchEventSource, +} from "@fortaine/fetch-event-source"; + +import { + API_PREFIX, + ApiPath, + DEFAULT_MODELS, + OpenaiPath, +} from "@/app/constant"; +import { ModelConfig, ProviderConfig } from "@/app/store"; + +import { OpenAI } from "./types"; + +import { ChatOptions, LLMModel, LLMUsage } from "../types"; +import Locale from "@/app/locales"; + +import { prettyObject } from "@/app/utils/format"; +import { getApiPath } from "@/app/utils/path"; +import { trimEnd } from "@/app/utils/string"; +import { omit } from "@/app/utils/object"; +import { createLogger } from "@/app/utils/log"; +import { getAuthHeaders } from "../common/auth"; + +export function createOpenAiClient( + providerConfigs: ProviderConfig, + modelConfig: ModelConfig, +) { + const openaiConfig = { ...providerConfigs.openai }; + const logger = createLogger("[OpenAI Client]"); + const openaiModelConfig = { ...modelConfig.openai }; + + return { + headers() { + return { + "Content-Type": "application/json", + ...getAuthHeaders(openaiConfig.apiKey), + }; + }, + + path(path: OpenaiPath): string { + let baseUrl = openaiConfig.endpoint; + + // if endpoint is empty, use default endpoint + if (baseUrl.trim().length === 0) { + baseUrl = getApiPath(ApiPath.OpenAI); + } + + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(API_PREFIX)) { + baseUrl = "https://" + baseUrl; + } + + baseUrl = trimEnd(baseUrl, "/"); + + return `${baseUrl}/${path}`; + }, + + extractMessage(res: OpenAI.ChatCompletionResponse) { + return res.choices[0]?.message?.content ?? 
""; + }, + + beforeRequest(options: ChatOptions, stream = false) { + const messages = options.messages.map((v) => ({ + role: v.role, + content: v.content, + })); + + if (options.shouldSummarize) { + openaiModelConfig.model = openaiModelConfig.summarizeModel; + } + + const requestBody: OpenAI.ChatCompletionRequest = { + messages, + stream, + ...omit(openaiModelConfig, "summarizeModel"), + }; + + const path = this.path(OpenaiPath.Chat); + + logger.log("path = ", path, requestBody); + + const controller = new AbortController(); + options.onController?.(controller); + + const payload = { + method: "POST", + body: JSON.stringify(requestBody), + signal: controller.signal, + headers: this.headers(), + }; + + return { + path, + payload, + controller, + }; + }, + + async chat(options: ChatOptions) { + try { + const { path, payload, controller } = this.beforeRequest( + options, + false, + ); + + controller.signal.onabort = () => options.onFinish(""); + + const res = await fetch(path, payload); + const resJson = await res.json(); + + const message = this.extractMessage(resJson); + options.onFinish(message); + } catch (e) { + logger.error("failed to chat", e); + options.onError?.(e as Error); + } + }, + + async chatStream(options: ChatOptions) { + try { + const { path, payload, controller } = this.beforeRequest(options, true); + + const context = { + text: "", + finished: false, + }; + + const finish = () => { + if (!context.finished) { + options.onFinish(context.text); + context.finished = true; + } + }; + + controller.signal.onabort = finish; + + fetchEventSource(path, { + ...payload, + async onopen(res) { + const contentType = res.headers.get("content-type"); + logger.log("response content type: ", contentType); + + if (contentType?.startsWith("text/plain")) { + context.text = await res.clone().text(); + return finish(); + } + + if ( + !res.ok || + !res.headers + .get("content-type") + ?.startsWith(EventStreamContentType) || + res.status !== 200 + ) { + const responseTexts 
= [context.text]; + let extraInfo = await res.clone().text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + context.text = responseTexts.join("\n\n"); + + return finish(); + } + }, + onmessage(msg) { + if (msg.data === "[DONE]" || context.finished) { + return finish(); + } + const chunk = msg.data; + try { + const chunkJson = JSON.parse( + chunk, + ) as OpenAI.ChatCompletionStreamResponse; + const delta = chunkJson.choices[0].delta.content; + if (delta) { + context.text += delta; + options.onUpdate?.(context.text, delta); + } + } catch (e) { + logger.error("[Request] parse error", chunk, msg); + } + }, + onclose() { + finish(); + }, + onerror(e) { + options.onError?.(e); + }, + openWhenHidden: true, + }); + } catch (e) { + logger.error("failed to chat", e); + options.onError?.(e as Error); + } + }, + + async usage() { + const formatDate = (d: Date) => + `${d.getFullYear()}-${(d.getMonth() + 1) + .toString() + .padStart(2, "0")}-${d.getDate().toString().padStart(2, "0")}`; + const ONE_DAY = 1 * 24 * 60 * 60 * 1000; + const now = new Date(); + const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1); + const startDate = formatDate(startOfMonth); + const endDate = formatDate(new Date(Date.now() + ONE_DAY)); + + const [used, subs] = await Promise.all([ + fetch( + `${this.path( + OpenaiPath.Usage, + )}?start_date=${startDate}&end_date=${endDate}`, + { + method: "GET", + headers: this.headers(), + }, + ), + fetch(this.path(OpenaiPath.Subs), { + method: "GET", + headers: this.headers(), + }), + ]); + + if (!used.ok || !subs.ok) { + throw new Error("Failed to query usage from openai"); + } + + const response = (await used.json()) as { + total_usage?: number; + error?: { + type: string; + message: string; + }; + }; + + const total = (await subs.json()) as { + 
hard_limit_usd?: number; + }; + + if (response.error?.type) { + throw Error(response.error?.message); + } + + response.total_usage = Math.round(response.total_usage ?? 0) / 100; + total.hard_limit_usd = + Math.round((total.hard_limit_usd ?? 0) * 100) / 100; + + return { + used: response.total_usage, + total: total.hard_limit_usd, + } as LLMUsage; + }, + + async models(): Promise { + const customModels = openaiConfig.customModels + .split(",") + .map((v) => v.trim()) + .map((v) => ({ + name: v, + available: true, + })); + + if (!openaiConfig.autoFetchModels) { + return [...DEFAULT_MODELS.slice(), ...customModels]; + } + + const res = await fetch(this.path(OpenaiPath.ListModel), { + method: "GET", + headers: this.headers(), + }); + + const resJson = (await res.json()) as OpenAI.ListModelResponse; + const chatModels = + resJson.data?.filter((m) => m.id.startsWith("gpt-")) ?? []; + + return chatModels + .map((m) => ({ + name: m.id, + available: true, + })) + .concat(customModels); + }, + }; +} diff --git a/app/client/openai/types.ts b/app/client/openai/types.ts new file mode 100644 index 00000000000..d1383922dbd --- /dev/null +++ b/app/client/openai/types.ts @@ -0,0 +1,79 @@ +export namespace OpenAI { + export type Role = "system" | "user" | "assistant" | "function"; + export type FinishReason = "stop" | "length" | "function_call"; + + export interface Message { + role: Role; + content?: string; + function_call?: { + name: string; + arguments: string; + }; + } + + export interface Function { + name: string; + description?: string; + parameters: object; + } + + export interface ListModelResponse { + object: string; + data: Array<{ + id: string; + object: string; + root: string; + }>; + } + + export interface ChatCompletionChoice { + index: number; + message: Message; + finish_reason: FinishReason; + } + + export interface ChatCompletionUsage { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + } + + export interface ChatCompletionResponse { 
+ id: string; + object: string; + created: number; + model: string; + choices: ChatCompletionChoice[]; + usage: ChatCompletionUsage; + } + + export interface ChatCompletionChunkChoice { + index: number; + delta: Message; + finish_reason?: FinishReason; + } + + export interface ChatCompletionStreamResponse { + object: string; + created: number; + model: string; + choices: ChatCompletionChunkChoice[]; + } + + export interface ChatCompletionRequest { + model: string; + messages: Message[]; + + functions?: Function[]; + function_call?: "none" | "auto"; + + temperature?: number; + top_p?: number; + n?: number; + stream?: boolean; + stop?: string | string[]; + max_tokens?: number; + presence_penalty?: number; + frequency_penalty?: number; + } +} diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts deleted file mode 100644 index fd4eb59ce77..00000000000 --- a/app/client/platforms/openai.ts +++ /dev/null @@ -1,281 +0,0 @@ -import { - DEFAULT_API_HOST, - DEFAULT_MODELS, - OpenaiPath, - REQUEST_TIMEOUT_MS, -} from "@/app/constant"; -import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; - -import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api"; -import Locale from "../../locales"; -import { - EventStreamContentType, - fetchEventSource, -} from "@fortaine/fetch-event-source"; -import { prettyObject } from "@/app/utils/format"; -import { getClientConfig } from "@/app/config/client"; - -export interface OpenAIListModelResponse { - object: string; - data: Array<{ - id: string; - object: string; - root: string; - }>; -} - -export class ChatGPTApi implements LLMApi { - private disableListModels = true; - - path(path: string): string { - let openaiUrl = useAccessStore.getState().openaiUrl; - const apiPath = "/api/openai"; - - if (openaiUrl.length === 0) { - const isApp = !!getClientConfig()?.isApp; - openaiUrl = isApp ? 
DEFAULT_API_HOST : apiPath; - } - if (openaiUrl.endsWith("/")) { - openaiUrl = openaiUrl.slice(0, openaiUrl.length - 1); - } - if (!openaiUrl.startsWith("http") && !openaiUrl.startsWith(apiPath)) { - openaiUrl = "https://" + openaiUrl; - } - return [openaiUrl, path].join("/"); - } - - extractMessage(res: any) { - return res.choices?.at(0)?.message?.content ?? ""; - } - - async chat(options: ChatOptions) { - const messages = options.messages.map((v) => ({ - role: v.role, - content: v.content, - })); - - const modelConfig = { - ...useAppConfig.getState().modelConfig, - ...useChatStore.getState().currentSession().mask.modelConfig, - ...{ - model: options.config.model, - }, - }; - - const requestPayload = { - messages, - stream: options.config.stream, - model: modelConfig.model, - temperature: modelConfig.temperature, - presence_penalty: modelConfig.presence_penalty, - frequency_penalty: modelConfig.frequency_penalty, - top_p: modelConfig.top_p, - }; - - console.log("[Request] openai payload: ", requestPayload); - - const shouldStream = !!options.config.stream; - const controller = new AbortController(); - options.onController?.(controller); - - try { - const chatPath = this.path(OpenaiPath.ChatPath); - const chatPayload = { - method: "POST", - body: JSON.stringify(requestPayload), - signal: controller.signal, - headers: getHeaders(), - }; - - // make a fetch request - const requestTimeoutId = setTimeout( - () => controller.abort(), - REQUEST_TIMEOUT_MS, - ); - - if (shouldStream) { - let responseText = ""; - let finished = false; - - const finish = () => { - if (!finished) { - options.onFinish(responseText); - finished = true; - } - }; - - controller.signal.onabort = finish; - - fetchEventSource(chatPath, { - ...chatPayload, - async onopen(res) { - clearTimeout(requestTimeoutId); - const contentType = res.headers.get("content-type"); - console.log( - "[OpenAI] request response content type: ", - contentType, - ); - - if (contentType?.startsWith("text/plain")) { - 
responseText = await res.clone().text(); - return finish(); - } - - if ( - !res.ok || - !res.headers - .get("content-type") - ?.startsWith(EventStreamContentType) || - res.status !== 200 - ) { - const responseTexts = [responseText]; - let extraInfo = await res.clone().text(); - try { - const resJson = await res.clone().json(); - extraInfo = prettyObject(resJson); - } catch {} - - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - - if (extraInfo) { - responseTexts.push(extraInfo); - } - - responseText = responseTexts.join("\n\n"); - - return finish(); - } - }, - onmessage(msg) { - if (msg.data === "[DONE]" || finished) { - return finish(); - } - const text = msg.data; - try { - const json = JSON.parse(text); - const delta = json.choices[0].delta.content; - if (delta) { - responseText += delta; - options.onUpdate?.(responseText, delta); - } - } catch (e) { - console.error("[Request] parse error", text, msg); - } - }, - onclose() { - finish(); - }, - onerror(e) { - options.onError?.(e); - throw e; - }, - openWhenHidden: true, - }); - } else { - const res = await fetch(chatPath, chatPayload); - clearTimeout(requestTimeoutId); - - const resJson = await res.json(); - const message = this.extractMessage(resJson); - options.onFinish(message); - } - } catch (e) { - console.log("[Request] failed to make a chat request", e); - options.onError?.(e as Error); - } - } - async usage() { - const formatDate = (d: Date) => - `${d.getFullYear()}-${(d.getMonth() + 1).toString().padStart(2, "0")}-${d - .getDate() - .toString() - .padStart(2, "0")}`; - const ONE_DAY = 1 * 24 * 60 * 60 * 1000; - const now = new Date(); - const startOfMonth = new Date(now.getFullYear(), now.getMonth(), 1); - const startDate = formatDate(startOfMonth); - const endDate = formatDate(new Date(Date.now() + ONE_DAY)); - - const [used, subs] = await Promise.all([ - fetch( - this.path( - `${OpenaiPath.UsagePath}?start_date=${startDate}&end_date=${endDate}`, - ), - { - method: "GET", 
- headers: getHeaders(), - }, - ), - fetch(this.path(OpenaiPath.SubsPath), { - method: "GET", - headers: getHeaders(), - }), - ]); - - if (used.status === 401) { - throw new Error(Locale.Error.Unauthorized); - } - - if (!used.ok || !subs.ok) { - throw new Error("Failed to query usage from openai"); - } - - const response = (await used.json()) as { - total_usage?: number; - error?: { - type: string; - message: string; - }; - }; - - const total = (await subs.json()) as { - hard_limit_usd?: number; - }; - - if (response.error && response.error.type) { - throw Error(response.error.message); - } - - if (response.total_usage) { - response.total_usage = Math.round(response.total_usage) / 100; - } - - if (total.hard_limit_usd) { - total.hard_limit_usd = Math.round(total.hard_limit_usd * 100) / 100; - } - - return { - used: response.total_usage, - total: total.hard_limit_usd, - } as LLMUsage; - } - - async models(): Promise { - if (this.disableListModels) { - return DEFAULT_MODELS.slice(); - } - - const res = await fetch(this.path(OpenaiPath.ListModelPath), { - method: "GET", - headers: { - ...getHeaders(), - }, - }); - - const resJson = (await res.json()) as OpenAIListModelResponse; - const chatModels = resJson.data?.filter((m) => m.id.startsWith("gpt-")); - console.log("[Models]", chatModels); - - if (!chatModels) { - return []; - } - - return chatModels.map((m) => ({ - name: m.id, - available: true, - })); - } -} -export { OpenaiPath }; diff --git a/app/client/types.ts b/app/client/types.ts new file mode 100644 index 00000000000..694059e1c36 --- /dev/null +++ b/app/client/types.ts @@ -0,0 +1,39 @@ +import { DEFAULT_MODELS } from "../constant"; + +export interface LLMUsage { + used: number; + total: number; + available: boolean; +} + +export interface LLMModel { + name: string; + available: boolean; +} + +export const ROLES = ["system", "user", "assistant"] as const; +export type MessageRole = (typeof ROLES)[number]; + +export type ChatModel = (typeof 
DEFAULT_MODELS)[number]["name"]; + +export interface RequestMessage { + role: MessageRole; + content: string; +} + +export interface ChatOptions { + messages: RequestMessage[]; + shouldSummarize?: boolean; + + onUpdate?: (message: string, chunk: string) => void; + onFinish: (message: string) => void; + onError?: (err: Error) => void; + onController?: (controller: AbortController) => void; +} + +export type LLMClient = { + chat(options: ChatOptions): Promise; + chatStream(options: ChatOptions): Promise; + usage(): Promise; + models(): Promise; +}; diff --git a/app/components/auth.tsx b/app/components/auth.tsx index b82d0e894c7..b13a695d6ff 100644 --- a/app/components/auth.tsx +++ b/app/components/auth.tsx @@ -3,7 +3,7 @@ import { IconButton } from "./button"; import { useNavigate } from "react-router-dom"; import { Path } from "../constant"; -import { useAccessStore } from "../store"; +import { useAccessStore, useAppConfig, useChatStore } from "../store"; import Locale from "../locales"; import BotIcon from "../icons/bot.svg"; @@ -13,10 +13,14 @@ import { getClientConfig } from "../config/client"; export function AuthPage() { const navigate = useNavigate(); const access = useAccessStore(); + const config = useAppConfig(); const goHome = () => navigate(Path.Home); const goChat = () => navigate(Path.Chat); - const resetAccessCode = () => { access.updateCode(""); access.updateToken(""); }; // Reset access code to empty string + const resetAccessCode = () => { + access.update((config) => (config.accessCode = "")); + config.update((config) => (config.providerConfig.openai.apiKey = "")); + }; // Reset access code to empty string useEffect(() => { if (getClientConfig()?.isApp) { @@ -40,7 +44,9 @@ export function AuthPage() { placeholder={Locale.Auth.Input} value={access.accessCode} onChange={(e) => { - access.updateCode(e.currentTarget.value); + access.update( + (config) => (config.accessCode = e.currentTarget.value), + ); }} /> {!access.hideUserApiKey ? 
( @@ -50,9 +56,12 @@ export function AuthPage() { className={styles["auth-input"]} type="password" placeholder={Locale.Settings.Token.Placeholder} - value={access.token} + value={config.providerConfig.openai.apiKey} onChange={(e) => { - access.updateToken(e.currentTarget.value); + config.update( + (config) => + (config.providerConfig.openai.apiKey = e.currentTarget.value), + ); }} /> diff --git a/app/components/chat-list.tsx b/app/components/chat-list.tsx index 7ba55585239..b27430e656e 100644 --- a/app/components/chat-list.tsx +++ b/app/components/chat-list.tsx @@ -39,6 +39,9 @@ export function ChatItem(props: { }); } }, [props.selected]); + + const modelConfig = useChatStore().extractModelConfig(props.mask.config); + return ( {(provided) => ( @@ -60,7 +63,10 @@ export function ChatItem(props: { {props.narrow ? (
- +
{props.count} diff --git a/app/components/chat.tsx b/app/components/chat.tsx index cca096eb874..7b7b66bec5e 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -1,12 +1,5 @@ import { useDebouncedCallback } from "use-debounce"; -import React, { - useState, - useRef, - useEffect, - useMemo, - useCallback, - Fragment, -} from "react"; +import React, { useState, useRef, useEffect, useMemo, Fragment } from "react"; import SendWhiteIcon from "../icons/send-white.svg"; import BrainIcon from "../icons/brain.svg"; @@ -37,15 +30,12 @@ import RobotIcon from "../icons/robot.svg"; import { ChatMessage, - SubmitKey, useChatStore, BOT_HELLO, createMessage, useAccessStore, - Theme, useAppConfig, DEFAULT_TOPIC, - ModelType, } from "../store"; import { @@ -57,7 +47,7 @@ import { import dynamic from "next/dynamic"; -import { ChatControllerPool } from "../client/controller"; +import { ChatControllerPool } from "../client/common/controller"; import { Prompt, usePromptStore } from "../store/prompt"; import Locale from "../locales"; @@ -73,11 +63,10 @@ import { showPrompt, showToast, } from "./ui-lib"; -import { useLocation, useNavigate } from "react-router-dom"; +import { useNavigate } from "react-router-dom"; import { CHAT_PAGE_SIZE, LAST_INPUT_KEY, - MAX_RENDER_MSG_COUNT, Path, REQUEST_TIMEOUT_MS, UNFINISHED_INPUT, @@ -89,6 +78,8 @@ import { ChatCommandPrefix, useChatCommand, useCommand } from "../command"; import { prettyObject } from "../utils/format"; import { ExportMessageModal } from "./exporter"; import { getClientConfig } from "../config/client"; +import { deepClone } from "../utils/clone"; +import { SubmitKey, Theme } from "../typing"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { loading: () => , @@ -142,7 +133,7 @@ export function SessionConfigModel(props: { onClose: () => void }) { }} shouldSyncFromGlobal extraListItems={ - session.mask.modelConfig.sendMemory ? ( + session.mask.config.chatConfig.sendMemory ? 
( ChatControllerPool.stopAll(); + const client = chatStore.getClient(); + const modelConfig = chatStore.getCurrentModelConfig(); + const currentModel = modelConfig.model; // switch model - const currentModel = chatStore.currentSession().mask.modelConfig.model; - const models = useMemo( - () => - config - .allModels() - .filter((m) => m.available) - .map((m) => m.name), - [config], - ); + const [models, setModels] = useState([]); + useEffect(() => { + client + .models() + .then((_models) => + setModels(_models.filter((v) => v.available).map((v) => v.name)), + ); + }, []); const [showModelSelector, setShowModelSelector] = useState(false); return ( @@ -526,7 +519,7 @@ export function ChatActions(props: { onSelection={(s) => { if (s.length === 0) return; chatStore.updateCurrentSession((session) => { - session.mask.modelConfig.model = s[0] as ModelType; + chatStore.extractModelConfig(session.mask.config).model = s[0]; session.mask.syncGlobalConfig = false; }); showToast(s[0]); @@ -603,6 +596,9 @@ function _Chat() { type RenderMessage = ChatMessage & { preview?: boolean }; const chatStore = useChatStore(); + const modelConfig = chatStore.getCurrentModelConfig(); + const maskConfig = chatStore.getCurrentMaskConfig(); + const session = chatStore.currentSession(); const config = useAppConfig(); const fontSize = config.fontSize; @@ -747,7 +743,7 @@ function _Chat() { // auto sync mask config from global config if (session.mask.syncGlobalConfig) { console.log("[Mask] syncing from global, name = ", session.mask.name); - session.mask.modelConfig = { ...config.modelConfig }; + session.mask.config = deepClone(config.globalMaskConfig); } }); // eslint-disable-next-line react-hooks/exhaustive-deps @@ -979,7 +975,7 @@ function _Chat() { console.log("[Command] got code from url: ", text); showConfirm(Locale.URLCommand.Code + `code = ${text}`).then((res) => { if (res) { - accessStore.updateCode(text); + accessStore.update((config) => (config.accessCode = text)); } }); }, @@ -999,10 
+995,10 @@ function _Chat() { ).then((res) => { if (!res) return; if (payload.key) { - accessStore.updateToken(payload.key); + // TODO: auto-fill openai api key here, must specific provider type } if (payload.url) { - accessStore.updateOpenAiUrl(payload.url); + // TODO: auto-fill openai url here, must specific provider type } }); } @@ -1159,7 +1155,10 @@ function _Chat() { {["system"].includes(message.role) ? ( ) : ( - + )} )} diff --git a/app/components/config/index.tsx b/app/components/config/index.tsx new file mode 100644 index 00000000000..b08fe06088f --- /dev/null +++ b/app/components/config/index.tsx @@ -0,0 +1,171 @@ +import { + ChatConfig, + LLMProvider, + LLMProviders, + ModelConfig, + ProviderConfig, +} from "@/app/store"; +import { Updater } from "@/app/typing"; +import { OpenAIModelConfig } from "./openai/model"; +import { OpenAIProviderConfig } from "./openai/provider"; +import { ListItem, Select } from "../ui-lib"; +import Locale from "@/app/locales"; +import { InputRange } from "../input-range"; + +export function ModelConfigList(props: { + provider: LLMProvider; + config: ModelConfig; + updateConfig: Updater; +}) { + if (props.provider === "openai") { + return ( + { + props.updateConfig((config) => update(config.openai)); + }} + models={[ + { + name: "gpt-3.5-turbo", + available: true, + }, + { + name: "gpt-4", + available: true, + }, + ]} + /> + ); + } + + return null; +} + +export function ProviderConfigList(props: { + provider: LLMProvider; + config: ProviderConfig; + updateConfig: Updater; +}) { + if (props.provider === "openai") { + return ( + { + props.updateConfig((config) => update(config.openai)); + }} + /> + ); + } + + return null; +} + +export function ProviderSelectItem(props: { + value: LLMProvider; + update: (value: LLMProvider) => void; +}) { + return ( + + + + ); +} + +export function ChatConfigList(props: { + config: ChatConfig; + updateConfig: (updater: (config: ChatConfig) => void) => void; +}) { + return ( + <> + + + 
props.updateConfig( + (config) => + (config.enableInjectSystemPrompts = e.currentTarget.checked), + ) + } + > + + + + + props.updateConfig( + (config) => (config.template = e.currentTarget.value), + ) + } + > + + + + + props.updateConfig( + (config) => (config.historyMessageCount = e.target.valueAsNumber), + ) + } + > + + + + + props.updateConfig( + (config) => + (config.compressMessageLengthThreshold = + e.currentTarget.valueAsNumber), + ) + } + > + + + + props.updateConfig( + (config) => (config.sendMemory = e.currentTarget.checked), + ) + } + > + + + ); +} diff --git a/app/components/config/openai/model.tsx b/app/components/config/openai/model.tsx new file mode 100644 index 00000000000..acd5b74e48e --- /dev/null +++ b/app/components/config/openai/model.tsx @@ -0,0 +1,113 @@ +import { ModelConfig } from "@/app/store"; +import { ModelConfigProps } from "../types"; +import { ListItem, Select } from "../../ui-lib"; +import Locale from "@/app/locales"; +import { InputRange } from "../../input-range"; + +export function OpenAIModelConfig( + props: ModelConfigProps, +) { + return ( + <> + + + + + { + props.updateConfig( + (config) => (config.temperature = e.currentTarget.valueAsNumber), + ); + }} + > + + + { + props.updateConfig( + (config) => (config.top_p = e.currentTarget.valueAsNumber), + ); + }} + > + + + + props.updateConfig( + (config) => (config.max_tokens = e.currentTarget.valueAsNumber), + ) + } + > + + + { + props.updateConfig( + (config) => + (config.presence_penalty = e.currentTarget.valueAsNumber), + ); + }} + > + + + + { + props.updateConfig( + (config) => + (config.frequency_penalty = e.currentTarget.valueAsNumber), + ); + }} + > + + + ); +} diff --git a/app/components/config/openai/provider.tsx b/app/components/config/openai/provider.tsx new file mode 100644 index 00000000000..b905b130dfa --- /dev/null +++ b/app/components/config/openai/provider.tsx @@ -0,0 +1,71 @@ +import { ProviderConfig } from "@/app/store"; +import { ProviderConfigProps } from 
"../types"; +import { ListItem, PasswordInput } from "../../ui-lib"; +import Locale from "@/app/locales"; +import { REMOTE_API_HOST } from "@/app/constant"; + +export function OpenAIProviderConfig( + props: ProviderConfigProps, +) { + return ( + <> + + + props.updateConfig( + (config) => (config.endpoint = e.currentTarget.value), + ) + } + > + + + { + props.updateConfig( + (config) => (config.apiKey = e.currentTarget.value), + ); + }} + /> + + + + props.updateConfig( + (config) => (config.customModels = e.currentTarget.value), + ) + } + > + + + + + props.updateConfig( + (config) => (config.autoFetchModels = e.currentTarget.checked), + ) + } + > + + + ); +} diff --git a/app/components/config/types.ts b/app/components/config/types.ts new file mode 100644 index 00000000000..529e60fa831 --- /dev/null +++ b/app/components/config/types.ts @@ -0,0 +1,14 @@ +import { LLMModel } from "@/app/client"; +import { Updater } from "@/app/typing"; + +export type ModelConfigProps = { + models: LLMModel[]; + config: T; + updateConfig: Updater; +}; + +export type ProviderConfigProps = { + readonly?: boolean; + config: T; + updateConfig: Updater; +}; diff --git a/app/components/emoji.tsx b/app/components/emoji.tsx index 03aac05f278..6f4dc62a920 100644 --- a/app/components/emoji.tsx +++ b/app/components/emoji.tsx @@ -28,7 +28,7 @@ export function AvatarPicker(props: { ); } -export function Avatar(props: { model?: ModelType; avatar?: string }) { +export function Avatar(props: { model?: string; avatar?: string }) { if (props.model) { return (
diff --git a/app/components/exporter.tsx b/app/components/exporter.tsx index 0a885d87463..7cba87a8d98 100644 --- a/app/components/exporter.tsx +++ b/app/components/exporter.tsx @@ -27,12 +27,12 @@ import { Avatar } from "./emoji"; import dynamic from "next/dynamic"; import NextImage from "next/image"; -import { toBlob, toJpeg, toPng } from "html-to-image"; +import { toBlob, toPng } from "html-to-image"; import { DEFAULT_MASK_AVATAR } from "../store/mask"; -import { api } from "../client/api"; import { prettyObject } from "../utils/format"; import { EXPORT_MESSAGE_CLASS_NAME } from "../constant"; import { getClientConfig } from "../config/client"; +import { api } from "../client"; const Markdown = dynamic(async () => (await import("./markdown")).Markdown, { loading: () => , @@ -290,7 +290,7 @@ export function PreviewActions(props: { setShouldExport(false); api - .share(msgs) + .shareToShareGPT(msgs) .then((res) => { if (!res) return; showModal({ @@ -403,6 +403,7 @@ export function ImagePreviewer(props: { const chatStore = useChatStore(); const session = chatStore.currentSession(); const mask = session.mask; + const modelConfig = chatStore.getCurrentModelConfig(); const config = useAppConfig(); const previewRef = useRef(null); @@ -437,13 +438,13 @@ export function ImagePreviewer(props: { showToast(Locale.Export.Image.Toast); const dom = previewRef.current; if (!dom) return; - + const isApp = getClientConfig()?.isApp; - + try { const blob = await toPng(dom); if (!blob) return; - + if (isMobile || (isApp && window.__TAURI__)) { if (isApp && window.__TAURI__) { const result = await window.__TAURI__.dialog.save({ @@ -459,7 +460,7 @@ export function ImagePreviewer(props: { }, ], }); - + if (result !== null) { const response = await fetch(blob); const buffer = await response.arrayBuffer(); @@ -526,7 +527,7 @@ export function ImagePreviewer(props: {
- {Locale.Exporter.Model}: {mask.modelConfig.model} + {Locale.Exporter.Model}: {modelConfig.model}
{Locale.Exporter.Messages}: {props.messages.length} diff --git a/app/components/home.tsx b/app/components/home.tsx index 811cbdf51cb..1fc737952ca 100644 --- a/app/components/home.tsx +++ b/app/components/home.tsx @@ -27,7 +27,6 @@ import { SideBar } from "./sidebar"; import { useAppConfig } from "../store/config"; import { AuthPage } from "./auth"; import { getClientConfig } from "../config/client"; -import { api } from "../client/api"; import { useAccessStore } from "../store"; export function Loading(props: { noLogo?: boolean }) { @@ -128,7 +127,8 @@ function Screen() { const isHome = location.pathname === Path.Home; const isAuth = location.pathname === Path.Auth; const isMobileScreen = useMobileScreen(); - const shouldTightBorder = getClientConfig()?.isApp || (config.tightBorder && !isMobileScreen); + const shouldTightBorder = + getClientConfig()?.isApp || (config.tightBorder && !isMobileScreen); useEffect(() => { loadAsyncGoogleFont(); @@ -170,10 +170,7 @@ export function useLoadData() { const config = useAppConfig(); useEffect(() => { - (async () => { - const models = await api.llm.models(); - config.mergeModels(models); - })(); + // TODO: fetch available models from server // eslint-disable-next-line react-hooks/exhaustive-deps }, []); } @@ -185,7 +182,7 @@ export function Home() { useEffect(() => { console.log("[Config] got config from build time", getClientConfig()); - useAccessStore.getState().fetch(); + useAccessStore.getState().fetchConfig(); }, []); if (!useHasHydrated()) { diff --git a/app/components/mask.tsx b/app/components/mask.tsx index 9fe1d485a6b..1dc04c71a51 100644 --- a/app/components/mask.tsx +++ b/app/components/mask.tsx @@ -21,7 +21,6 @@ import { useAppConfig, useChatStore, } from "../store"; -import { ROLES } from "../client/api"; import { Input, List, @@ -36,19 +35,20 @@ import Locale, { AllLangs, ALL_LANG_OPTIONS, Lang } from "../locales"; import { useNavigate } from "react-router-dom"; import chatStyle from "./chat.module.scss"; -import 
{ useEffect, useState } from "react"; +import { useState } from "react"; import { copyToClipboard, downloadAs, readFromFile } from "../utils"; import { Updater } from "../typing"; -import { ModelConfigList } from "./model-config"; import { FileName, Path } from "../constant"; import { BUILTIN_MASK_STORE } from "../masks"; -import { nanoid } from "nanoid"; import { DragDropContext, Droppable, Draggable, OnDragEndResponder, } from "@hello-pangea/dnd"; +import { ROLES } from "../client"; +import { deepClone } from "../utils/clone"; +import { ChatConfigList, ModelConfigList, ProviderSelectItem } from "./config"; // drag and drop helper function function reorder(list: T[], startIndex: number, endIndex: number): T[] { @@ -58,11 +58,11 @@ function reorder(list: T[], startIndex: number, endIndex: number): T[] { return result; } -export function MaskAvatar(props: { mask: Mask }) { - return props.mask.avatar !== DEFAULT_MASK_AVATAR ? ( - +export function MaskAvatar(props: { avatar: string; model: string }) { + return props.avatar !== DEFAULT_MASK_AVATAR ? ( + ) : ( - + ); } @@ -74,14 +74,15 @@ export function MaskConfig(props: { shouldSyncFromGlobal?: boolean; }) { const [showPicker, setShowPicker] = useState(false); + const modelConfig = useChatStore().extractModelConfig(props.mask.config); const updateConfig = (updater: (config: ModelConfig) => void) => { if (props.readonly) return; - const config = { ...props.mask.modelConfig }; - updater(config); + const config = deepClone(props.mask.config); + updater(config.modelConfig); props.updateMask((mask) => { - mask.modelConfig = config; + mask.config = config; // if user changed current session mask, it will disable auto sync mask.syncGlobalConfig = false; }); @@ -123,7 +124,10 @@ export function MaskConfig(props: { onClick={() => setShowPicker(true)} style={{ cursor: "pointer" }} > - +
@@ -182,7 +186,7 @@ export function MaskConfig(props: { ) { props.updateMask((mask) => { mask.syncGlobalConfig = checked; - mask.modelConfig = { ...globalConfig.modelConfig }; + mask.config = deepClone(globalConfig.globalMaskConfig); }); } else if (!checked) { props.updateMask((mask) => { @@ -196,10 +200,28 @@ export function MaskConfig(props: { + { + props.updateMask((mask) => (mask.config.provider = value)); + }} + /> + + + + { + const chatConfig = deepClone(props.mask.config.chatConfig); + updater(chatConfig); + props.updateMask((mask) => (mask.config.chatConfig = chatConfig)); + }} + /> {props.extraListItems} @@ -398,7 +420,7 @@ export function MaskPage() { setSearchText(text); if (text.length > 0) { const result = allMasks.filter((m) => - m.name.toLowerCase().includes(text.toLowerCase()) + m.name.toLowerCase().includes(text.toLowerCase()), ); setSearchMasks(result); } else { @@ -523,14 +545,17 @@ export function MaskPage() {
- +
{m.name}
{`${Locale.Mask.Item.Info(m.context.length)} / ${ ALL_LANG_OPTIONS[m.lang] - } / ${m.modelConfig.model}`} + } / ${chatStore.extractModelConfig(m.config).model}`}
diff --git a/app/components/message-selector.tsx b/app/components/message-selector.tsx index cadf52e643e..9a2c4cbff20 100644 --- a/app/components/message-selector.tsx +++ b/app/components/message-selector.tsx @@ -71,6 +71,7 @@ export function MessageSelector(props: { onSelected?: (messages: ChatMessage[]) => void; }) { const chatStore = useChatStore(); + const modelConfig = chatStore.getCurrentModelConfig(); const session = chatStore.currentSession(); const isValid = (m: ChatMessage) => m.content && !m.isError && !m.streaming; const messages = session.messages.filter( @@ -195,7 +196,10 @@ export function MessageSelector(props: { {m.role === "user" ? ( ) : ( - + )}
diff --git a/app/components/model-config.tsx b/app/components/model-config.tsx index 63950a40d04..00734382cf8 100644 --- a/app/components/model-config.tsx +++ b/app/components/model-config.tsx @@ -4,10 +4,12 @@ import Locale from "../locales"; import { InputRange } from "./input-range"; import { ListItem, Select } from "./ui-lib"; -export function ModelConfigList(props: { +export function _ModelConfigList(props: { modelConfig: ModelConfig; updateConfig: (updater: (config: ModelConfig) => void) => void; }) { + return null; + /* const config = useAppConfig(); return ( @@ -130,84 +132,8 @@ export function ModelConfigList(props: { > - - - props.updateConfig( - (config) => - (config.enableInjectSystemPrompts = e.currentTarget.checked), - ) - } - > - - - - - props.updateConfig( - (config) => (config.template = e.currentTarget.value), - ) - } - > - - - - - props.updateConfig( - (config) => (config.historyMessageCount = e.target.valueAsNumber), - ) - } - > - - - - - props.updateConfig( - (config) => - (config.compressMessageLengthThreshold = - e.currentTarget.valueAsNumber), - ) - } - > - - - - props.updateConfig( - (config) => (config.sendMemory = e.currentTarget.checked), - ) - } - > - + ); + */ } diff --git a/app/components/new-chat.tsx b/app/components/new-chat.tsx index 76cbbeeb17e..dac918e1207 100644 --- a/app/components/new-chat.tsx +++ b/app/components/new-chat.tsx @@ -29,9 +29,11 @@ function getIntersectionArea(aRect: DOMRect, bRect: DOMRect) { } function MaskItem(props: { mask: Mask; onClick?: () => void }) { + const modelConfig = useChatStore().extractModelConfig(props.mask.config); + return (
- +
{props.mask.name}
); diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 795469a9681..ffe3850f098 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -30,16 +30,15 @@ import { showConfirm, showToast, } from "./ui-lib"; -import { ModelConfigList } from "./model-config"; import { IconButton } from "./button"; import { - SubmitKey, useChatStore, - Theme, useUpdateStore, useAccessStore, useAppConfig, + LLMProvider, + LLMProviders, } from "../store"; import Locale, { @@ -61,6 +60,14 @@ import { useSyncStore } from "../store/sync"; import { nanoid } from "nanoid"; import { useMaskStore } from "../store/mask"; import { ProviderType } from "../utils/cloud"; +import { + ChatConfigList, + ModelConfigList, + ProviderConfigList, + ProviderSelectItem, +} from "./config"; +import { SubmitKey, Theme } from "../typing"; +import { deepClone } from "../utils/clone"; function EditPromptModal(props: { id: string; onClose: () => void }) { const promptStore = usePromptStore(); @@ -757,8 +764,7 @@ export function Settings() { step="1" onChange={(e) => updateConfig( - (config) => - (config.fontSize = Number.parseInt(e.currentTarget.value)), + (config) => (config.fontSize = e.currentTarget.valueAsNumber), ) } > @@ -770,11 +776,14 @@ export function Settings() { > updateConfig( (config) => - (config.enableAutoGenerateTitle = e.currentTarget.checked), + (config.globalMaskConfig.chatConfig.enableAutoGenerateTitle = + e.currentTarget.checked), ) } > @@ -877,7 +886,9 @@ export function Settings() { type="text" placeholder={Locale.Settings.AccessCode.Placeholder} onChange={(e) => { - accessStore.updateCode(e.currentTarget.value); + accessStore.update( + (config) => (config.accessCode = e.currentTarget.value), + ); }} /> @@ -885,36 +896,7 @@ export function Settings() { <> )} - {!accessStore.hideUserApiKey ? 
( - <> - - - accessStore.updateOpenAiUrl(e.currentTarget.value) - } - > - - - { - accessStore.updateToken(e.currentTarget.value); - }} - /> - - - ) : null} + {!accessStore.hideUserApiKey ? <> : null} {!accessStore.hideBalanceQuery ? ( ) : null} - - - - config.update( - (config) => (config.customModels = e.currentTarget.value), - ) - } - > - + + config.update((_config) => { + _config.globalMaskConfig.provider = value; + }) + } + /> + + { + config.update((_config) => update(_config.providerConfig)); + }} + /> { - const modelConfig = { ...config.modelConfig }; + const modelConfig = { ...config.globalMaskConfig.modelConfig }; updater(modelConfig); - config.update((config) => (config.modelConfig = modelConfig)); + config.update( + (config) => (config.globalMaskConfig.modelConfig = modelConfig), + ); + }} + /> + { + const chatConfig = deepClone(config.globalMaskConfig.chatConfig); + updater(chatConfig); + config.update( + (config) => (config.globalMaskConfig.chatConfig = chatConfig), + ); }} /> diff --git a/app/constant.ts b/app/constant.ts index e03e00971cc..15cdf412fcb 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -8,8 +8,8 @@ export const FETCH_COMMIT_URL = `https://api.github.com/repos/${OWNER}/${REPO}/c export const FETCH_TAG_URL = `https://api.github.com/repos/${OWNER}/${REPO}/tags?per_page=1`; export const RUNTIME_CONFIG_DOM = "danger-runtime-config"; -export const DEFAULT_CORS_HOST = "https://ab.nextweb.fun"; -export const DEFAULT_API_HOST = `${DEFAULT_CORS_HOST}/api/proxy`; +export const REMOTE_CORS_HOST = "https://ab.nextweb.fun"; +export const REMOTE_API_HOST = `${REMOTE_CORS_HOST}/api/proxy`; export enum Path { Home = "/", @@ -20,8 +20,12 @@ export enum Path { Auth = "/auth", } +export const API_PREFIX = "/api"; + export enum ApiPath { + OpenAI = "/api/openai", Cors = "/api/cors", + Config = "/api/config", } export enum SlotID { @@ -59,12 +63,12 @@ export const REQUEST_TIMEOUT_MS = 60000; export const EXPORT_MESSAGE_CLASS_NAME = "export-markdown"; 
-export const OpenaiPath = { - ChatPath: "v1/chat/completions", - UsagePath: "dashboard/billing/usage", - SubsPath: "dashboard/billing/subscription", - ListModelPath: "v1/models", -}; +export enum OpenaiPath { + Chat = "v1/chat/completions", + Usage = "dashboard/billing/usage", + Subs = "dashboard/billing/subscription", + ListModel = "v1/models", +} export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang export const DEFAULT_SYSTEM_TEMPLATE = ` diff --git a/app/locales/ar.ts b/app/locales/ar.ts index d5844acd695..221c1bc7ef1 100644 --- a/app/locales/ar.ts +++ b/app/locales/ar.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const ar: PartialLocaleType = { diff --git a/app/locales/bn.ts b/app/locales/bn.ts index 2db132cecc2..7660924d4f5 100644 --- a/app/locales/bn.ts +++ b/app/locales/bn.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import { PartialLocaleType } from "./index"; const bn: PartialLocaleType = { diff --git a/app/locales/cn.ts b/app/locales/cn.ts index 4cd963fb8e2..39b0a676d76 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -1,5 +1,5 @@ import { getClientConfig } from "../config/client"; -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; const isApp = !!getClientConfig()?.isApp; diff --git a/app/locales/cs.ts b/app/locales/cs.ts index 57aa803e42b..5cee4f7218c 100644 --- a/app/locales/cs.ts +++ b/app/locales/cs.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const cs: PartialLocaleType = { diff --git a/app/locales/de.ts b/app/locales/de.ts index e0bdc52b749..f7d3de0aa68 100644 --- a/app/locales/de.ts +++ b/app/locales/de.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } 
from "@/app/typing"; import type { PartialLocaleType } from "./index"; const de: PartialLocaleType = { diff --git a/app/locales/en.ts b/app/locales/en.ts index 928c4b72d4e..882afbaa0da 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -1,5 +1,5 @@ import { getClientConfig } from "../config/client"; -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import { LocaleType } from "./index"; // if you are adding a new translation, please use PartialLocaleType instead of LocaleType diff --git a/app/locales/es.ts b/app/locales/es.ts index a6ae154f44f..200535a44b9 100644 --- a/app/locales/es.ts +++ b/app/locales/es.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const es: PartialLocaleType = { diff --git a/app/locales/fr.ts b/app/locales/fr.ts index f5200f2719c..64a98f3e71c 100644 --- a/app/locales/fr.ts +++ b/app/locales/fr.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const fr: PartialLocaleType = { diff --git a/app/locales/id.ts b/app/locales/id.ts index b5e4a70b751..ae536ee119b 100644 --- a/app/locales/id.ts +++ b/app/locales/id.ts @@ -1,11 +1,12 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import { PartialLocaleType } from "./index"; const id: PartialLocaleType = { WIP: "Coming Soon...", Error: { - Unauthorized: "Akses tidak diizinkan, silakan masukkan kode akses atau masukkan kunci API OpenAI Anda. di halaman [autentikasi](/#/auth) atau di halaman [Pengaturan](/#/settings).", - }, + Unauthorized: + "Akses tidak diizinkan, silakan masukkan kode akses atau masukkan kunci API OpenAI Anda. 
di halaman [autentikasi](/#/auth) atau di halaman [Pengaturan](/#/settings).", + }, Auth: { Title: "Diperlukan Kode Akses", Tips: "Masukkan kode akses di bawah", diff --git a/app/locales/it.ts b/app/locales/it.ts index bf20747b108..d3f2033f703 100644 --- a/app/locales/it.ts +++ b/app/locales/it.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const it: PartialLocaleType = { diff --git a/app/locales/jp.ts b/app/locales/jp.ts index b63e8ba3a56..57e9e507ef9 100644 --- a/app/locales/jp.ts +++ b/app/locales/jp.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const jp: PartialLocaleType = { @@ -20,7 +20,8 @@ const jp: PartialLocaleType = { Stop: "停止", Retry: "リトライ", Pin: "ピン", - PinToastContent: "コンテキストプロンプトに1つのメッセージをピン留めしました", + PinToastContent: + "コンテキストプロンプトに1つのメッセージをピン留めしました", PinToastAction: "表示", Delete: "削除", Edit: "編集", diff --git a/app/locales/ko.ts b/app/locales/ko.ts index 717ce30b2f8..ee6bf9ad235 100644 --- a/app/locales/ko.ts +++ b/app/locales/ko.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; diff --git a/app/locales/no.ts b/app/locales/no.ts index 43c92916f3e..c030c03d5e8 100644 --- a/app/locales/no.ts +++ b/app/locales/no.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const no: PartialLocaleType = { diff --git a/app/locales/ru.ts b/app/locales/ru.ts index bf98b4eb865..25879263947 100644 --- a/app/locales/ru.ts +++ b/app/locales/ru.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const ru: PartialLocaleType = { diff 
--git a/app/locales/tr.ts b/app/locales/tr.ts index 06996d83dac..6b216471112 100644 --- a/app/locales/tr.ts +++ b/app/locales/tr.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const tr: PartialLocaleType = { diff --git a/app/locales/tw.ts b/app/locales/tw.ts index e9f38d097e1..868ffd671ba 100644 --- a/app/locales/tw.ts +++ b/app/locales/tw.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const tw: PartialLocaleType = { diff --git a/app/locales/vi.ts b/app/locales/vi.ts index 8f53a3dc1ee..1f8b49ab53a 100644 --- a/app/locales/vi.ts +++ b/app/locales/vi.ts @@ -1,4 +1,4 @@ -import { SubmitKey } from "../store/config"; +import { SubmitKey } from "@/app/typing"; import type { PartialLocaleType } from "./index"; const vi: PartialLocaleType = { diff --git a/app/masks/typing.ts b/app/masks/typing.ts index 1ded6a90295..7fba6cec4a5 100644 --- a/app/masks/typing.ts +++ b/app/masks/typing.ts @@ -1,7 +1,9 @@ import { ModelConfig } from "../store"; import { type Mask } from "../store/mask"; -export type BuiltinMask = Omit & { - builtin: Boolean; - modelConfig: Partial; -}; +export type BuiltinMask = + | any + | (Omit & { + builtin: Boolean; + modelConfig: Partial; + }); diff --git a/app/store/access.ts b/app/store/access.ts index 9eaa81e5ea3..a27b3276bac 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -1,23 +1,20 @@ -import { DEFAULT_API_HOST, DEFAULT_MODELS, StoreKey } from "../constant"; -import { getHeaders } from "../client/api"; +import { REMOTE_API_HOST, DEFAULT_MODELS, StoreKey } from "../constant"; import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; +import { getAuthHeaders } from "../client/common/auth"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done const DEFAULT_OPENAI_URL = - 
getClientConfig()?.buildMode === "export" ? DEFAULT_API_HOST : "/api/openai/"; + getClientConfig()?.buildMode === "export" ? REMOTE_API_HOST : "/api/openai/"; console.log("[API] default openai url", DEFAULT_OPENAI_URL); const DEFAULT_ACCESS_STATE = { - token: "", accessCode: "", needCode: true, hideUserApiKey: false, hideBalanceQuery: false, disableGPT4: false, - - openaiUrl: DEFAULT_OPENAI_URL, }; export const useAccessStore = createPersistStore( @@ -25,35 +22,24 @@ export const useAccessStore = createPersistStore( (set, get) => ({ enabledAccessControl() { - this.fetch(); + this.fetchConfig(); return get().needCode; }, - updateCode(code: string) { - set(() => ({ accessCode: code?.trim() })); - }, - updateToken(token: string) { - set(() => ({ token: token?.trim() })); - }, - updateOpenAiUrl(url: string) { - set(() => ({ openaiUrl: url?.trim() })); - }, isAuthorized() { - this.fetch(); + this.fetchConfig(); // has token or has code or disabled access control - return ( - !!get().token || !!get().accessCode || !this.enabledAccessControl() - ); + return !!get().accessCode || !this.enabledAccessControl(); }, - fetch() { + fetchConfig() { if (fetchState > 0 || getClientConfig()?.buildMode === "export") return; fetchState = 1; fetch("/api/config", { method: "post", body: null, headers: { - ...getHeaders(), + ...getAuthHeaders(), }, }) .then((res) => res.json()) diff --git a/app/store/chat.ts b/app/store/chat.ts index 56ac8db6cc1..2a66a359b4c 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -2,7 +2,13 @@ import { trimTopic } from "../utils"; import Locale, { getLang } from "../locales"; import { showToast } from "../components/ui-lib"; -import { ModelConfig, ModelType, useAppConfig } from "./config"; +import { + LLMProvider, + MaskConfig, + ModelConfig, + ModelType, + useAppConfig, +} from "./config"; import { createEmptyMask, Mask } from "./mask"; import { DEFAULT_INPUT_TEMPLATE, @@ -10,19 +16,19 @@ import { StoreKey, SUMMARIZE_MODEL, } from "../constant"; 
-import { api, RequestMessage } from "../client/api"; -import { ChatControllerPool } from "../client/controller"; +import { ChatControllerPool } from "../client/common/controller"; import { prettyObject } from "../utils/format"; import { estimateTokenLength } from "../utils/token"; import { nanoid } from "nanoid"; import { createPersistStore } from "../utils/store"; +import { RequestMessage, api } from "../client"; export type ChatMessage = RequestMessage & { date: string; streaming?: boolean; isError?: boolean; id: string; - model?: ModelType; + model?: string; }; export function createMessage(override: Partial): ChatMessage { @@ -84,46 +90,25 @@ function getSummarizeModel(currentModel: string) { return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel; } -interface ChatStore { - sessions: ChatSession[]; - currentSessionIndex: number; - clearSessions: () => void; - moveSession: (from: number, to: number) => void; - selectSession: (index: number) => void; - newSession: (mask?: Mask) => void; - deleteSession: (index: number) => void; - currentSession: () => ChatSession; - nextSession: (delta: number) => void; - onNewMessage: (message: ChatMessage) => void; - onUserInput: (content: string) => Promise; - summarizeSession: () => void; - updateStat: (message: ChatMessage) => void; - updateCurrentSession: (updater: (session: ChatSession) => void) => void; - updateMessage: ( - sessionIndex: number, - messageIndex: number, - updater: (message?: ChatMessage) => void, - ) => void; - resetSession: () => void; - getMessagesWithMemory: () => ChatMessage[]; - getMemoryPrompt: () => ChatMessage; - - clearAllData: () => void; -} - function countMessages(msgs: ChatMessage[]) { return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0); } -function fillTemplateWith(input: string, modelConfig: ModelConfig) { +function fillTemplateWith( + input: string, + context: { + model: string; + template?: string; + }, +) { const vars = { - model: 
modelConfig.model, + model: context.model, time: new Date().toLocaleString(), lang: getLang(), input: input, }; - let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE; + let output = context.template ?? DEFAULT_INPUT_TEMPLATE; // must contains {{input}} const inputVar = "{{input}}"; @@ -197,13 +182,13 @@ export const useChatStore = createPersistStore( if (mask) { const config = useAppConfig.getState(); - const globalModelConfig = config.modelConfig; + const globalModelConfig = config.globalMaskConfig; session.mask = { ...mask, - modelConfig: { + config: { ...globalModelConfig, - ...mask.modelConfig, + ...mask.config, }, }; session.topic = mask.name; @@ -288,11 +273,39 @@ export const useChatStore = createPersistStore( get().summarizeSession(); }, + getCurrentMaskConfig() { + return get().currentSession().mask.config; + }, + + extractModelConfig(maskConfig: MaskConfig) { + const provider = maskConfig.provider; + if (!maskConfig.modelConfig[provider]) { + throw Error("[Chat] failed to initialize provider: " + provider); + } + + return maskConfig.modelConfig[provider]; + }, + + getCurrentModelConfig() { + const maskConfig = this.getCurrentMaskConfig(); + return this.extractModelConfig(maskConfig); + }, + + getClient() { + const appConfig = useAppConfig.getState(); + const currentMaskConfig = get().getCurrentMaskConfig(); + return api.createLLMClient(appConfig.providerConfig, currentMaskConfig); + }, + async onUserInput(content: string) { const session = get().currentSession(); - const modelConfig = session.mask.modelConfig; + const maskConfig = this.getCurrentMaskConfig(); + const modelConfig = this.getCurrentModelConfig(); - const userContent = fillTemplateWith(content, modelConfig); + const userContent = fillTemplateWith(content, { + model: modelConfig.model, + template: maskConfig.chatConfig.template, + }); console.log("[User Input] after template: ", userContent); const userMessage: ChatMessage = createMessage({ @@ -323,10 +336,11 @@ export const 
useChatStore = createPersistStore( ]); }); + const client = this.getClient(); + // make request - api.llm.chat({ + client.chatStream({ messages: sendMessages, - config: { ...modelConfig, stream: true }, onUpdate(message) { botMessage.streaming = true; if (message) { @@ -391,7 +405,9 @@ export const useChatStore = createPersistStore( getMessagesWithMemory() { const session = get().currentSession(); - const modelConfig = session.mask.modelConfig; + const maskConfig = this.getCurrentMaskConfig(); + const chatConfig = maskConfig.chatConfig; + const modelConfig = this.getCurrentModelConfig(); const clearContextIndex = session.clearContextIndex ?? 0; const messages = session.messages.slice(); const totalMessageCount = session.messages.length; @@ -400,14 +416,14 @@ export const useChatStore = createPersistStore( const contextPrompts = session.mask.context.slice(); // system prompts, to get close to OpenAI Web ChatGPT - const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts; + const shouldInjectSystemPrompts = chatConfig.enableInjectSystemPrompts; const systemPrompts = shouldInjectSystemPrompts ? 
[ createMessage({ role: "system", content: fillTemplateWith("", { - ...modelConfig, - template: DEFAULT_SYSTEM_TEMPLATE, + model: modelConfig.model, + template: chatConfig.template, }), }), ] @@ -421,7 +437,7 @@ export const useChatStore = createPersistStore( // long term memory const shouldSendLongTermMemory = - modelConfig.sendMemory && + chatConfig.sendMemory && session.memoryPrompt && session.memoryPrompt.length > 0 && session.lastSummarizeIndex > clearContextIndex; @@ -433,7 +449,7 @@ export const useChatStore = createPersistStore( // short term memory const shortTermMemoryStartIndex = Math.max( 0, - totalMessageCount - modelConfig.historyMessageCount, + totalMessageCount - chatConfig.historyMessageCount, ); // lets concat send messages, including 4 parts: @@ -494,6 +510,8 @@ export const useChatStore = createPersistStore( summarizeSession() { const config = useAppConfig.getState(); + const maskConfig = this.getCurrentMaskConfig(); + const chatConfig = maskConfig.chatConfig; const session = get().currentSession(); // remove error messages if any @@ -502,7 +520,7 @@ export const useChatStore = createPersistStore( // should summarize topic after chating more than 50 words const SUMMARIZE_MIN_LEN = 50; if ( - config.enableAutoGenerateTitle && + chatConfig.enableAutoGenerateTitle && session.topic === DEFAULT_TOPIC && countMessages(messages) >= SUMMARIZE_MIN_LEN ) { @@ -512,11 +530,12 @@ export const useChatStore = createPersistStore( content: Locale.Store.Prompt.Topic, }), ); - api.llm.chat({ + + const client = this.getClient(); + client.chat({ messages: topicMessages, - config: { - model: getSummarizeModel(session.mask.modelConfig.model), - }, + shouldSummarize: true, + onFinish(message) { get().updateCurrentSession( (session) => @@ -527,7 +546,7 @@ export const useChatStore = createPersistStore( }); } - const modelConfig = session.mask.modelConfig; + const modelConfig = this.getCurrentModelConfig(); const summarizeIndex = Math.max( session.lastSummarizeIndex, 
session.clearContextIndex ?? 0, @@ -541,7 +560,7 @@ export const useChatStore = createPersistStore( if (historyMsgLength > modelConfig?.max_tokens ?? 4000) { const n = toBeSummarizedMsgs.length; toBeSummarizedMsgs = toBeSummarizedMsgs.slice( - Math.max(0, n - modelConfig.historyMessageCount), + Math.max(0, n - chatConfig.historyMessageCount), ); } @@ -554,14 +573,14 @@ export const useChatStore = createPersistStore( "[Chat History] ", toBeSummarizedMsgs, historyMsgLength, - modelConfig.compressMessageLengthThreshold, + chatConfig.compressMessageLengthThreshold, ); if ( - historyMsgLength > modelConfig.compressMessageLengthThreshold && - modelConfig.sendMemory + historyMsgLength > chatConfig.compressMessageLengthThreshold && + chatConfig.sendMemory ) { - api.llm.chat({ + this.getClient().chatStream({ messages: toBeSummarizedMsgs.concat( createMessage({ role: "system", @@ -569,11 +588,7 @@ export const useChatStore = createPersistStore( date: "", }), ), - config: { - ...modelConfig, - stream: true, - model: getSummarizeModel(session.mask.modelConfig.model), - }, + shouldSummarize: true, onUpdate(message) { session.memoryPrompt = message; }, @@ -614,52 +629,9 @@ export const useChatStore = createPersistStore( name: StoreKey.Chat, version: 3.1, migrate(persistedState, version) { - const state = persistedState as any; - const newState = JSON.parse( - JSON.stringify(state), - ) as typeof DEFAULT_CHAT_STATE; - - if (version < 2) { - newState.sessions = []; - - const oldSessions = state.sessions; - for (const oldSession of oldSessions) { - const newSession = createEmptySession(); - newSession.topic = oldSession.topic; - newSession.messages = [...oldSession.messages]; - newSession.mask.modelConfig.sendMemory = true; - newSession.mask.modelConfig.historyMessageCount = 4; - newSession.mask.modelConfig.compressMessageLengthThreshold = 1000; - newState.sessions.push(newSession); - } - } - - if (version < 3) { - // migrate id to nanoid - newState.sessions.forEach((s) => { - s.id 
= nanoid(); - s.messages.forEach((m) => (m.id = nanoid())); - }); - } - - // Enable `enableInjectSystemPrompts` attribute for old sessions. - // Resolve issue of old sessions not automatically enabling. - if (version < 3.1) { - newState.sessions.forEach((s) => { - if ( - // Exclude those already set by user - !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts") - ) { - // Because users may have changed this configuration, - // the user's current configuration is used instead of the default - const config = useAppConfig.getState(); - s.mask.modelConfig.enableInjectSystemPrompts = - config.modelConfig.enableInjectSystemPrompts; - } - }); - } + // TODO(yifei): migrate from old versions - return newState as any; + return persistedState as any; }, }, ); diff --git a/app/store/config.ts b/app/store/config.ts index 184355c94a3..6f388a8b130 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -1,4 +1,3 @@ -import { LLMModel } from "../client/api"; import { isMacOS } from "../utils"; import { getClientConfig } from "../config/client"; import { @@ -8,24 +7,85 @@ import { StoreKey, } from "../constant"; import { createPersistStore } from "../utils/store"; +import { OpenAIConfig } from "../client/openai/config"; +import { api } from "../client"; +import { SubmitKey, Theme } from "../typing"; export type ModelType = (typeof DEFAULT_MODELS)[number]["name"]; -export enum SubmitKey { - Enter = "Enter", - CtrlEnter = "Ctrl + Enter", - ShiftEnter = "Shift + Enter", - AltEnter = "Alt + Enter", - MetaEnter = "Meta + Enter", -} +export const DEFAULT_CHAT_CONFIG = { + enableAutoGenerateTitle: true, + sendMemory: true, + historyMessageCount: 4, + compressMessageLengthThreshold: 1000, + enableInjectSystemPrompts: true, + template: DEFAULT_INPUT_TEMPLATE, +}; +export type ChatConfig = typeof DEFAULT_CHAT_CONFIG; + +export const DEFAULT_PROVIDER_CONFIG = { + openai: OpenAIConfig.provider, + // azure: { + // endpoint: "https://api.openai.com", + // apiKey: "", + // 
version: "", + // ...COMMON_PROVIDER_CONFIG, + // }, + // claude: { + // endpoint: "https://api.anthropic.com", + // apiKey: "", + // ...COMMON_PROVIDER_CONFIG, + // }, + // google: { + // endpoint: "https://api.anthropic.com", + // apiKey: "", + // ...COMMON_PROVIDER_CONFIG, + // }, +}; -export enum Theme { - Auto = "auto", - Dark = "dark", - Light = "light", -} +export const DEFAULT_MODEL_CONFIG = { + openai: OpenAIConfig.model, + // azure: { + // model: "gpt-3.5-turbo" as string, + // summarizeModel: "gpt-3.5-turbo", + // + // temperature: 0.5, + // top_p: 1, + // max_tokens: 2000, + // presence_penalty: 0, + // frequency_penalty: 0, + // }, + // claude: { + // model: "claude-2", + // summarizeModel: "claude-2", + // + // max_tokens_to_sample: 100000, + // temperature: 1, + // top_p: 0.7, + // top_k: 1, + // }, + // google: { + // model: "chat-bison-001", + // summarizeModel: "claude-2", + // + // temperature: 1, + // topP: 0.7, + // topK: 1, + // }, +}; -export const DEFAULT_CONFIG = { +export type LLMProvider = keyof typeof DEFAULT_PROVIDER_CONFIG; +export const LLMProviders = Array.from( + Object.entries(DEFAULT_PROVIDER_CONFIG), +).map(([k, v]) => [v.name, k]); + +export const DEFAULT_MASK_CONFIG = { + provider: "openai" as LLMProvider, + chatConfig: { ...DEFAULT_CHAT_CONFIG }, + modelConfig: { ...DEFAULT_MODEL_CONFIG }, +}; + +export const DEFAULT_APP_CONFIG = { lastUpdate: Date.now(), // timestamp, to merge state submitKey: isMacOS() ? 
SubmitKey.MetaEnter : SubmitKey.CtrlEnter, @@ -34,7 +94,6 @@ export const DEFAULT_CONFIG = { theme: Theme.Auto as Theme, tightBorder: !!getClientConfig()?.isApp, sendPreviewBubble: true, - enableAutoGenerateTitle: true, sidebarWidth: DEFAULT_SIDEBAR_WIDTH, disablePromptHint: false, @@ -42,27 +101,14 @@ export const DEFAULT_CONFIG = { dontShowMaskSplashScreen: false, // dont show splash screen when create chat hideBuiltinMasks: false, // dont add builtin masks - customModels: "", - models: DEFAULT_MODELS as any as LLMModel[], - - modelConfig: { - model: "gpt-3.5-turbo" as ModelType, - temperature: 0.5, - top_p: 1, - max_tokens: 2000, - presence_penalty: 0, - frequency_penalty: 0, - sendMemory: true, - historyMessageCount: 4, - compressMessageLengthThreshold: 1000, - enableInjectSystemPrompts: true, - template: DEFAULT_INPUT_TEMPLATE, - }, + providerConfig: { ...DEFAULT_PROVIDER_CONFIG }, + globalMaskConfig: { ...DEFAULT_MASK_CONFIG }, }; -export type ChatConfig = typeof DEFAULT_CONFIG; - -export type ModelConfig = ChatConfig["modelConfig"]; +export type AppConfig = typeof DEFAULT_APP_CONFIG; +export type ProviderConfig = typeof DEFAULT_PROVIDER_CONFIG; +export type MaskConfig = typeof DEFAULT_MASK_CONFIG; +export type ModelConfig = typeof DEFAULT_MODEL_CONFIG; export function limitNumber( x: number, @@ -99,48 +145,21 @@ export const ModalConfigValidator = { }; export const useAppConfig = createPersistStore( - { ...DEFAULT_CONFIG }, + { ...DEFAULT_APP_CONFIG }, (set, get) => ({ reset() { - set(() => ({ ...DEFAULT_CONFIG })); + set(() => ({ ...DEFAULT_APP_CONFIG })); }, - mergeModels(newModels: LLMModel[]) { - if (!newModels || newModels.length === 0) { - return; - } - - const oldModels = get().models; - const modelMap: Record = {}; - - for (const model of oldModels) { - model.available = false; - modelMap[model.name] = model; - } - - for (const model of newModels) { - model.available = true; - modelMap[model.name] = model; - } - - set(() => ({ - models: 
Object.values(modelMap), - })); - }, - - allModels() { - const customModels = get() - .customModels.split(",") - .filter((v) => !!v && v.length > 0) - .map((m) => ({ name: m, available: true })); - return get().models.concat(customModels); + getDefaultClient() { + return api.createLLMClient(get().providerConfig, get().globalMaskConfig); }, }), { name: StoreKey.Config, - version: 3.8, + version: 4, migrate(persistedState, version) { - const state = persistedState as ChatConfig; + const state = persistedState as any; if (version < 3.4) { state.modelConfig.sendMemory = true; @@ -169,6 +188,10 @@ export const useAppConfig = createPersistStore( state.lastUpdate = Date.now(); } + if (version < 4) { + // todo: migarte from old versions + } + return state as any; }, }, diff --git a/app/store/mask.ts b/app/store/mask.ts index dfd4089b757..6fcf7b9b832 100644 --- a/app/store/mask.ts +++ b/app/store/mask.ts @@ -1,10 +1,11 @@ import { BUILTIN_MASKS } from "../masks"; import { getLang, Lang } from "../locales"; import { DEFAULT_TOPIC, ChatMessage } from "./chat"; -import { ModelConfig, useAppConfig } from "./config"; +import { MaskConfig, ModelConfig, useAppConfig } from "./config"; import { StoreKey } from "../constant"; import { nanoid } from "nanoid"; import { createPersistStore } from "../utils/store"; +import { deepClone } from "../utils/clone"; export type Mask = { id: string; @@ -14,7 +15,9 @@ export type Mask = { hideContext?: boolean; context: ChatMessage[]; syncGlobalConfig?: boolean; - modelConfig: ModelConfig; + + config: MaskConfig; + lang: Lang; builtin: boolean; }; @@ -33,7 +36,7 @@ export const createEmptyMask = () => name: DEFAULT_TOPIC, context: [], syncGlobalConfig: true, // use global config as default - modelConfig: { ...useAppConfig.getState().modelConfig }, + config: deepClone(useAppConfig.getState().globalMaskConfig), lang: getLang(), builtin: false, createdAt: Date.now(), @@ -87,10 +90,11 @@ export const useMaskStore = createPersistStore( const 
buildinMasks = BUILTIN_MASKS.map( (m) => ({ + id: m.name, ...m, - modelConfig: { - ...config.modelConfig, - ...m.modelConfig, + config: { + ...config.globalMaskConfig, + ...m.config, }, }) as Mask, ); @@ -120,6 +124,8 @@ export const useMaskStore = createPersistStore( newState.masks = updatedMasks; } + // TODO(yifei): migrate old masks + return newState as any; }, }, diff --git a/app/store/sync.ts b/app/store/sync.ts index b74f6895f6d..17cfdd2fd04 100644 --- a/app/store/sync.ts +++ b/app/store/sync.ts @@ -13,7 +13,7 @@ import { downloadAs, readFromFile } from "../utils"; import { showToast } from "../components/ui-lib"; import Locale from "../locales"; import { createSyncClient, ProviderType } from "../utils/cloud"; -import { corsPath } from "../utils/cors"; +import { getApiPath } from "../utils/path"; export interface WebDavConfig { server: string; @@ -27,7 +27,7 @@ export type SyncStore = GetStoreState; const DEFAULT_SYNC_STATE = { provider: ProviderType.WebDAV, useProxy: true, - proxyUrl: corsPath(ApiPath.Cors), + proxyUrl: getApiPath(ApiPath.Cors), webdav: { endpoint: "", diff --git a/app/store/update.ts b/app/store/update.ts index 2b088a13d7a..0e63e12034f 100644 --- a/app/store/update.ts +++ b/app/store/update.ts @@ -1,5 +1,4 @@ import { FETCH_COMMIT_URL, FETCH_TAG_URL, StoreKey } from "../constant"; -import { api } from "../client/api"; import { getClientConfig } from "../config/client"; import { createPersistStore } from "../utils/store"; import ChatGptIcon from "../icons/chatgpt.png"; @@ -85,35 +84,40 @@ export const useUpdateStore = createPersistStore( })); if (window.__TAURI__?.notification && isApp) { // Check if notification permission is granted - await window.__TAURI__?.notification.isPermissionGranted().then((granted) => { - if (!granted) { - return; - } else { - // Request permission to show notifications - window.__TAURI__?.notification.requestPermission().then((permission) => { - if (permission === 'granted') { - if (version === remoteId) { - // 
Show a notification using Tauri - window.__TAURI__?.notification.sendNotification({ - title: "ChatGPT Next Web", - body: `${Locale.Settings.Update.IsLatest}`, - icon: `${ChatGptIcon.src}`, - sound: "Default" - }); - } else { - const updateMessage = Locale.Settings.Update.FoundUpdate(`${remoteId}`); - // Show a notification for the new version using Tauri - window.__TAURI__?.notification.sendNotification({ - title: "ChatGPT Next Web", - body: updateMessage, - icon: `${ChatGptIcon.src}`, - sound: "Default" - }); - } - } - }); - } - }); + await window.__TAURI__?.notification + .isPermissionGranted() + .then((granted) => { + if (!granted) { + return; + } else { + // Request permission to show notifications + window.__TAURI__?.notification + .requestPermission() + .then((permission) => { + if (permission === "granted") { + if (version === remoteId) { + // Show a notification using Tauri + window.__TAURI__?.notification.sendNotification({ + title: "ChatGPT Next Web", + body: `${Locale.Settings.Update.IsLatest}`, + icon: `${ChatGptIcon.src}`, + sound: "Default", + }); + } else { + const updateMessage = + Locale.Settings.Update.FoundUpdate(`${remoteId}`); + // Show a notification for the new version using Tauri + window.__TAURI__?.notification.sendNotification({ + title: "ChatGPT Next Web", + body: updateMessage, + icon: `${ChatGptIcon.src}`, + sound: "Default", + }); + } + } + }); + } + }); } console.log("[Got Upstream] ", remoteId); } catch (error) { @@ -130,14 +134,7 @@ export const useUpdateStore = createPersistStore( })); try { - const usage = await api.llm.usage(); - - if (usage) { - set(() => ({ - used: usage.used, - subscription: usage.total, - })); - } + // TODO: add check usage api here } catch (e) { console.error((e as Error).message); } diff --git a/app/typing.ts b/app/typing.ts index 25e474abf1d..6ed87882f60 100644 --- a/app/typing.ts +++ b/app/typing.ts @@ -1 +1,15 @@ export type Updater = (updater: (value: T) => void) => void; + +export enum SubmitKey { + 
Enter = "Enter", + CtrlEnter = "Ctrl + Enter", + ShiftEnter = "Shift + Enter", + AltEnter = "Alt + Enter", + MetaEnter = "Meta + Enter", +} + +export enum Theme { + Auto = "auto", + Dark = "dark", + Light = "light", +} diff --git a/app/utils/clone.ts b/app/utils/clone.ts index 2958b6b9c35..e8971acfbea 100644 --- a/app/utils/clone.ts +++ b/app/utils/clone.ts @@ -1,3 +1,3 @@ -export function deepClone(obj: T) { +export function deepClone(obj: T): T { return JSON.parse(JSON.stringify(obj)); } diff --git a/app/utils/cloud/index.ts b/app/utils/cloud/index.ts index 63908249e85..e6905bb2906 100644 --- a/app/utils/cloud/index.ts +++ b/app/utils/cloud/index.ts @@ -1,5 +1,6 @@ import { createWebDavClient } from "./webdav"; import { createUpstashClient } from "./upstash"; +import { SyncStore } from "@/app/store/sync"; export enum ProviderType { WebDAV = "webdav", @@ -27,7 +28,7 @@ export type SyncClient = { export function createSyncClient( provider: T, - config: SyncClientConfig[T], + store: SyncStore, ): SyncClient { - return SyncClients[provider](config as any) as any; + return SyncClients[provider](store); } diff --git a/app/utils/cloud/upstash.ts b/app/utils/cloud/upstash.ts index 5f5b9fc7925..abc1b4cc9b0 100644 --- a/app/utils/cloud/upstash.ts +++ b/app/utils/cloud/upstash.ts @@ -57,7 +57,7 @@ export function createUpstashClient(store: SyncStore) { async get() { const chunkCount = Number(await this.redisGet(chunkCountKey)); - if (!Number.isInteger(chunkCount)) return; + if (!Number.isInteger(chunkCount)) return ""; const chunks = await Promise.all( new Array(chunkCount) diff --git a/app/utils/cors.ts b/app/utils/cors.ts index 773f152aafa..6eb77705e59 100644 --- a/app/utils/cors.ts +++ b/app/utils/cors.ts @@ -1,19 +1,5 @@ -import { getClientConfig } from "../config/client"; -import { ApiPath, DEFAULT_CORS_HOST } from "../constant"; - -export function corsPath(path: string) { - const baseUrl = getClientConfig()?.isApp ? 
`${DEFAULT_CORS_HOST}` : ""; - - if (!path.startsWith("/")) { - path = "/" + path; - } - - if (!path.endsWith("/")) { - path += "/"; - } - - return `${baseUrl}${path}`; -} +import { ApiPath } from "../constant"; +import { getApiPath } from "./path"; export function corsFetch( url: string, @@ -25,7 +11,7 @@ export function corsFetch( throw Error("[CORS Fetch] url must starts with http/https"); } - let proxyUrl = options.proxyUrl ?? corsPath(ApiPath.Cors); + let proxyUrl = options.proxyUrl ?? getApiPath(ApiPath.Cors); if (!proxyUrl.endsWith("/")) { proxyUrl += "/"; } diff --git a/app/utils/log.ts b/app/utils/log.ts new file mode 100644 index 00000000000..443033c53d0 --- /dev/null +++ b/app/utils/log.ts @@ -0,0 +1,13 @@ +export function createLogger(prefix = "") { + return { + log(...args: any[]) { + console.log(prefix, ...args); + }, + error(...args: any[]) { + console.error(prefix, ...args); + }, + warn(...args: any[]) { + console.warn(prefix, ...args); + }, + }; +} diff --git a/app/utils/object.ts b/app/utils/object.ts new file mode 100644 index 00000000000..7fc74aee6b0 --- /dev/null +++ b/app/utils/object.ts @@ -0,0 +1,17 @@ +export function pick( + obj: T, + ...keys: U +): Pick { + const ret: any = {}; + keys.forEach((key) => (ret[key] = obj[key])); + return ret; +} + +export function omit( + obj: T, + ...keys: U +): Omit { + const ret: any = { ...obj }; + keys.forEach((key) => delete ret[key]); + return ret; +} diff --git a/app/utils/path.ts b/app/utils/path.ts new file mode 100644 index 00000000000..6609352d720 --- /dev/null +++ b/app/utils/path.ts @@ -0,0 +1,16 @@ +import { getClientConfig } from "../config/client"; +import { ApiPath, REMOTE_API_HOST } from "../constant"; + +/** + * Get api path according to desktop/web env + * + * 1. In desktop app, we always try to use a remote full path for better network experience + * 2. 
In web app, we always try to use the original relative path + * + * @param path - /api/* + * @returns + */ +export function getApiPath(path: ApiPath) { + const baseUrl = getClientConfig()?.isApp ? `${REMOTE_API_HOST}` : ""; + return `${baseUrl}${path}`; +} diff --git a/app/utils/string.ts b/app/utils/string.ts new file mode 100644 index 00000000000..68fc47d55c7 --- /dev/null +++ b/app/utils/string.ts @@ -0,0 +1,19 @@ +export function trimEnd(s: string, end = " ") { + if (end.length === 0) return s; + + while (s.endsWith(end)) { + s = s.slice(0, -end.length); + } + + return s; +} + +export function trimStart(s: string, start = " ") { + if (start.length === 0) return s; + + while (s.startsWith(start)) { + s = s.slice(start.length); + } + + return s; +}