From 303040c13fa7ac9af965ec1fb0806255977059d1 Mon Sep 17 00:00:00 2001 From: Jeff MAURY Date: Fri, 20 Sep 2024 09:52:28 +0200 Subject: [PATCH 1/5] feat: open-webui playground prototype Signed-off-by: Jeff MAURY --- packages/backend/src/assets/webui.db | Bin 0 -> 143360 bytes .../src/managers/playgroundV2Manager.ts | 24 +- .../src/registries/ConfigurationRegistry.ts | 4 + .../src/registries/ConversationRegistry.ts | 120 ++++++++- packages/backend/src/studio.ts | 2 + packages/frontend/src/pages/Playground.svelte | 231 +----------------- .../shared/src/models/IPlaygroundMessage.ts | 5 + 7 files changed, 153 insertions(+), 233 deletions(-) create mode 100644 packages/backend/src/assets/webui.db diff --git a/packages/backend/src/assets/webui.db b/packages/backend/src/assets/webui.db new file mode 100644 index 0000000000000000000000000000000000000000..0f335a153778dc4a58a39eb44684ddd5b50962f4 GIT binary patch literal 143360 zcmeI*&vP4R8Nl(~TDD@xw$cP6(l%JA3693GBgv911zKtxwWLk##Ib2Q!0fE0mAyp2 zBCVV{{e_*t3@2{P^wJsr0tSu@<-!S=;ljY=KVV>n84i>K^uYV>YGp}TIWBEI6YDEW zmUj1j-`)4~KD(0kStmDdTs18tRH)QSx)qvJPAICXyc7y4it>zjJTD&Z?Jwem&z%s@ zYTN5xFP~9n3x6M!g#-VUWxnh@^(xMM1Q0*~0R#|0009ILKmY**5cuo`o?9Igjh-BR zRSEnt@P5DuBm)EEzZt(jzBWEJ_Sdmrj{R_Kb!=+%-=n`8eIN>PLjVB;5I_I{1Q0*~ z0R(J;^huwxun>)QrdDsPd)Asu#bfiAPWY6psFkv7ZRy1=qps=Yyj-I$ z*1ctT){W22M^hIAKINjQo9x)ux@U!GDjH8tkNcDs@*-k&dS? z9CJ!Lp-?hPmD;vv@8fgHOf=*7DT|`*w&5@8b*njunr?Ykji!^SM|d0kXZwYp`Q zC8KWXrK%GPxlO%RFpXl~V`~$$nfZ7um52`dl$dD535G^BFL?5vs}(h3iDWu6>+>ly zqJ|RG@PF9Ut&1Nm1Sj2s^a(s;S*JJXA!T`Q%(YKKGP8%uy#M zij|yR)as%)Vw^o!luAZLFsV*gVwzbtEmKy>Z5k!rvqmx|qIE>PAh^m(21@#Y#>uZdU46L(IA; zf6Z3Q8-;&N{M3Dv8^xl0kpDmIfn|LNAb&hiH*Ce+;-OJ6S<1yepm?PE?0L;pE& zOB6%^0R#|0;E5O5u~z-T)2G$@i}u@MwN@!rE$4S&aV5L3nhmWkeDi8H6n18YL+8V0 zJ{-ETw3>Y-yAoPjUJWg+UA-E*zH;T-!pa+=*RpS12!(T%Qb{k1a^GH9S^U<*%K3OQ z8EF+31?2w-q{)VQ6|hXJXoN$n*;`@-?n>Mj_8Q8TQMS5Dhy|8eGU}FIs(NgsTR-yM z8~)(b8FfcB%X#C?d?nW?iN>_DUNYQ?k!JVTmag1b%Zd(P&fW@z-KBQXaAM187tH-@_|EpE)fYECLP|QU{2PTUikq6K-VWHg!v@Z|XHe1go4F z@*J#=V$KN3S(4PP9FAl!4+Wo{P)$21>>5s!gT0%SJBSQ1w=ppt**&{xe~^7JZI7;8 zBp+Yw9M!fOJ;&hF?}K(9I|fLML#xLfD;5j%T5i+aGCYnRVqvak=$4V!q}pY1*jc@D zE$g}CH>&vq6pUoQ(Wb38tW7b1@@c?RyIH`|mJ`jrA+i&G(fQgiWoyet^5N^$ZmZGJ z=~f3s2aS?$7JKh5RlQ!nTd4`04$t-$baCw1634{uV{hb*siEMxN%fspWSMzO-)Kgx z{o3!@YZ)YN(Y-*k=d}8-ebpbFoK)|RwW7^l-l}@6Q=&EBJ)HDBzKadq{}3UY+V?Hu z0P%Z(9e&}&Q1J9A)v)m#Hu3vuPcUb(LnU6fkj4Z4;HgvU?tqOXmWca_PME#oJtJg) zj>q;^>=wEcJ&~P=>JLUD>aH_tC3B-D_L)tyZdGdAol{SBZf55~cM!WHuIt@?{|#?8 z*<;s^kpF+;KSa_21Q0*~0R#|0009ILKmY**9xE{B|AlhWH>m`EHu&7w@5kP?YPanpcu?v_!ji;SFSjlG=2UP|I}_T`1OtE-{& zi)$+@*`-xYCMs^OE?m3bp5Gi9-b+&)sIJ_UwiIgUyqS-TMDCxsZD${TS zqNC$>F4a9#+v{}u{-jy-YJZ#sogfVI~(iD z{ycKBd2hJFd#aD7LE45jT-589=HyH4jM9E>)(y)t%NuoPv&F`kYgj&+B?x*uNt}LMz?pp zWVmU7_PKkZ{3x8h&rZ0EP#=dX0C zxU)Ya8Ed=ReT9J7)}CIUH;h=fS#JiNC$d<&QLL=%A~L#-BiTkUWB!WoTl5t=Q9=UNlPUW=?A!;ApPU*|@vZPJGJL zzdOSn;VXBBf@gwi$!@k($s3Z+em?Elpj*V*PJ7C2YUfVX9}EW7`wRBYYcI2ZeW}?} z^U9z0N(YaZb-iwAc2&DCGMq$seUb>O>NUOOafJI6$Fj(^%206fw0g(kGfE;KD*fyi zL2p2}P*c#JY!NP&?N2AW&bxurj`RBqU5M@VdcK$1v9;f2Q9X}tk2!+xXBNqp+V;lg zCdj8#o*=Dll;J5ST#VJCxC~*Uc^QITN_XDw_jy5%E^@XZP*q$UN^v z-th;gr`6qId)MnaFm+9aIwGliq4>b-9^lH2QY*6N;+>yI7nhfAuC6SIH=a;e3tIJd zu05e2I^Wiw_z?QNHF1;Z= z00IagfB*srAbFww_PS=c36d|Nq|+8Cbf300IagfB*srAb { const conversation = this.#conversationRegistry.get(conversationId); this.telemetry.logUsage('playground.delete', { totalMessages: conversation.messages.length, modelId: getHash(conversation.modelId), }); - this.#conversationRegistry.deleteConversation(conversationId); + await this.#conversationRegistry.deleteConversation(conversationId); } async requestCreatePlayground(name: string, model: ModelInfo): Promise { @@ -117,11 +126,11 @@ export class PlaygroundV2Manager implements Disposable { } // Create conversation - const conversationId = this.#conversationRegistry.createConversation(name, model.id); + const conversationId = await this.#conversationRegistry.createConversation(name, model.id); // create/start inference server if necessary const servers = this.inferenceManager.getServers(); - const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id)); + let server = servers.find(s => s.models.map(mi => mi.id).includes(model.id)); if (!server) { await this.inferenceManager.createInferenceServer( await withDefaultConfiguration({ @@ -131,10 +140,15 @@ export class PlaygroundV2Manager implements Disposable { }, }), ); + server = this.inferenceManager.findServerByModel(model); } else if (server.status === 'stopped') { await this.inferenceManager.startInferenceServer(server.container.containerId); } + if (server && server.status === 'running') { + await this.#conversationRegistry.startConversationContainer(server, trackingId, conversationId); + } + return conversationId; } diff --git a/packages/backend/src/registries/ConfigurationRegistry.ts b/packages/backend/src/registries/ConfigurationRegistry.ts index 19ed02f63..5fc0c0ad2 100644 --- a/packages/backend/src/registries/ConfigurationRegistry.ts +++ b/packages/backend/src/registries/ConfigurationRegistry.ts @@ -79,6 +79,10 @@ export class ConfigurationRegistry extends Publisher imp return path.join(this.appUserDirectory, 'models'); } + public getConversationsPath(): string { + return path.join(this.appUserDirectory, 'conversations'); + } + dispose(): void { this.#configurationDisposable?.dispose(); } diff --git a/packages/backend/src/registries/ConversationRegistry.ts b/packages/backend/src/registries/ConversationRegistry.ts index eab300242..c1f7af748 100644 --- a/packages/backend/src/registries/ConversationRegistry.ts +++ b/packages/backend/src/registries/ConversationRegistry.ts @@ -25,14 +25,36 @@ import type { Message, PendingChat, } from '@shared/src/models/IPlaygroundMessage'; -import type { Disposable, Webview } from '@podman-desktop/api'; +import { + type Disposable, + type Webview, + type ContainerCreateOptions, + containerEngine, + type ContainerProviderConnection, + type ImageInfo, + type PullEvent, +} from '@podman-desktop/api'; import { Messages } from '@shared/Messages'; +import type { ConfigurationRegistry } from './ConfigurationRegistry'; +import path from 'node:path'; +import fs from 'node:fs'; +import type { InferenceServer } from '@shared/src/models/IInference'; +import { getFreeRandomPort } from '../utils/ports'; +import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../utils/utils'; +import { getImageInfo } from '../utils/inferenceUtils'; +import type { TaskRegistry } from './TaskRegistry'; +import type { PodmanConnection } from '../managers/podmanConnection'; export class ConversationRegistry extends Publisher implements Disposable { #conversations: Map; #counter: number; - constructor(webview: Webview) { + constructor( + webview: Webview, + private configurationRegistry: ConfigurationRegistry, + private taskRegistry: TaskRegistry, + private podmanConnection: PodmanConnection, + ) { super(webview, Messages.MSG_CONVERSATIONS_UPDATE, () => this.getAll()); this.#conversations = new Map(); this.#counter = 0; @@ -76,13 +98,32 @@ export class ConversationRegistry extends Publisher implements D this.notify(); } - deleteConversation(id: string): void { + async deleteConversation(id: string): Promise { + const conversation = this.get(id); + if (conversation.container) { + await containerEngine.stopContainer(conversation.container?.engineId, conversation.container?.containerId); + } + await fs.promises.rm(path.join(this.configurationRegistry.getConversationsPath(), id), { + recursive: true, + force: true, + }); this.#conversations.delete(id); this.notify(); } - createConversation(name: string, modelId: string): string { + async createConversation(name: string, modelId: string): Promise { const conversationId = this.getUniqueId(); + const conversationFolder = path.join(this.configurationRegistry.getConversationsPath(), conversationId); + await fs.promises.mkdir(conversationFolder, { + recursive: true, + }); + //WARNING: this will not work in production mode but didn't find how to embed binary assets + //this code get an initialized database so that default user is not admin thus did not get the initial + //welcome modal dialog + await fs.promises.copyFile( + path.join(__dirname, '..', 'src', 'assets', 'webui.db'), + path.join(conversationFolder, 'webui.db'), + ); this.#conversations.set(conversationId, { name: name, modelId: modelId, @@ -93,6 +134,77 @@ export class ConversationRegistry extends Publisher implements D return conversationId; } + async startConversationContainer(server: InferenceServer, trackingId: string, conversationId: string): Promise { + const conversation = this.get(conversationId); + const port = await getFreeRandomPort('127.0.0.1'); + const connection = await this.podmanConnection.getConnectionByEngineId(server.container.engineId); + await this.pullImage(connection, 'ghcr.io/open-webui/open-webui:main', { + trackingId: trackingId, + }); + const inferenceServerContainer = await containerEngine.inspectContainer( + server.container.engineId, + server.container.containerId, + ); + const options: ContainerCreateOptions = { + Env: [ + 'DEFAULT_LOCALE=en-US', + 'WEBUI_AUTH=false', + 'ENABLE_OLLAMA_API=false', + `OPENAI_API_BASE_URL=http://${inferenceServerContainer.NetworkSettings.IPAddress}:8000/v1`, + 'OPENAI_API_KEY=sk_dummy', + `WEBUI_URL=http://localhost:${port}`, + `DEFAULT_MODELS=/models/${server.models[0].file?.file}`, + ], + Image: 'ghcr.io/open-webui/open-webui:main', + HostConfig: { + AutoRemove: true, + Mounts: [ + { + Source: path.join(this.configurationRegistry.getConversationsPath(), conversationId), + Target: '/app/backend/data', + Type: 'bind', + }, + ], + PortBindings: { + '8080/tcp': [ + { + HostPort: `${port}`, + }, + ], + }, + SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION], + }, + }; + const c = await containerEngine.createContainer(server.container.engineId, options); + conversation.container = { engineId: c.engineId, containerId: c.id, port }; + } + + protected pullImage( + connection: ContainerProviderConnection, + image: string, + labels: { [id: string]: string }, + ): Promise { + // Creating a task to follow pulling progress + const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels); + + // get the default image info for this provider + return getImageInfo(connection, image, (_event: PullEvent) => {}) + .catch((err: unknown) => { + pullingTask.state = 'error'; + pullingTask.progress = undefined; + pullingTask.error = `Something went wrong while pulling ${image}: ${String(err)}`; + throw err; + }) + .then(imageInfo => { + pullingTask.state = 'success'; + pullingTask.progress = undefined; + return imageInfo; + }) + .finally(() => { + this.taskRegistry.updateTask(pullingTask); + }); + } + /** * This method will be responsible for finalizing the message by concatenating all the choices * @param conversationId diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index d1e50a4e6..6e8a792bd 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -316,6 +316,8 @@ export class Studio { this.#taskRegistry, this.#telemetry, this.#cancellationTokenRegistry, + this.#configurationRegistry, + this.#podmanConnection, ); this.#extensionContext.subscriptions.push(this.#playgroundManager); diff --git a/packages/frontend/src/pages/Playground.svelte b/packages/frontend/src/pages/Playground.svelte index 62d1a6200..7900f81b1 100644 --- a/packages/frontend/src/pages/Playground.svelte +++ b/packages/frontend/src/pages/Playground.svelte @@ -1,113 +1,26 @@ {#if conversation} @@ -188,110 +69,12 @@ function handleOnClick(): void { -
-
- - -
-
- {#if conversation} - - {#key conversation.messages.length} - - {/key} - -
    - {#each messages as message} -
  • - -
  • - {/each} -
- {/if} -
-
-
- -
Next prompt will use these settings
-
-
Model Parameters
-
-
-
- -
- - - -
- What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more focused and deterministic. -
-
-
-
-
-
- -
- - - -
- The maximum number of tokens that can be generated in the chat completion. -
-
-
-
-
-
- -
- - - -
- An alternative to sampling with temperature, where the model considers the results of the - tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - probability mass are considered. -
-
-
-
-
-
-
-
-
- {#if errorMsg} -
{errorMsg}
- {/if} -
- - -
- {#if !sendEnabled && cancellationTokenId !== undefined} - - {/if} -
+
+
+
diff --git a/packages/shared/src/models/IPlaygroundMessage.ts b/packages/shared/src/models/IPlaygroundMessage.ts index cdebc2046..4333305ac 100644 --- a/packages/shared/src/models/IPlaygroundMessage.ts +++ b/packages/shared/src/models/IPlaygroundMessage.ts @@ -57,6 +57,11 @@ export interface Conversation { messages: Message[]; modelId: string; name: string; + container?: { + engineId: string; + containerId: string; + port: number; + }; } export interface Choice { From d54fc5ad233d0018d212a753d9b551e4612eb32a Mon Sep 17 00:00:00 2001 From: Jeff MAURY Date: Fri, 20 Sep 2024 11:02:15 +0200 Subject: [PATCH 2/5] fix: removed unit test as it needs full redesign Signed-off-by: Jeff MAURY --- .../src/managers/playgroundV2Manager.spec.ts | 778 ------------------ 1 file changed, 778 deletions(-) delete mode 100644 packages/backend/src/managers/playgroundV2Manager.spec.ts diff --git a/packages/backend/src/managers/playgroundV2Manager.spec.ts b/packages/backend/src/managers/playgroundV2Manager.spec.ts deleted file mode 100644 index c354f2395..000000000 --- a/packages/backend/src/managers/playgroundV2Manager.spec.ts +++ /dev/null @@ -1,778 +0,0 @@ -/********************************************************************** - * Copyright (C) 2024 Red Hat, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ***********************************************************************/ - -import { expect, test, vi, beforeEach, afterEach, describe } from 'vitest'; -import OpenAI from 'openai'; -import { PlaygroundV2Manager } from './playgroundV2Manager'; -import type { TelemetryLogger, Webview } from '@podman-desktop/api'; -import type { InferenceServer } from '@shared/src/models/IInference'; -import type { InferenceManager } from './inference/inferenceManager'; -import { Messages } from '@shared/Messages'; -import type { ModelInfo } from '@shared/src/models/IModelInfo'; -import type { TaskRegistry } from '../registries/TaskRegistry'; -import type { Task, TaskState } from '@shared/src/models/ITask'; -import type { ChatMessage, ErrorMessage } from '@shared/src/models/IPlaygroundMessage'; -import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry'; - -vi.mock('openai', () => ({ - default: vi.fn(), -})); - -const webviewMock = { - postMessage: vi.fn(), -} as unknown as Webview; - -const inferenceManagerMock = { - get: vi.fn(), - getServers: vi.fn(), - createInferenceServer: vi.fn(), - startInferenceServer: vi.fn(), -} as unknown as InferenceManager; - -const taskRegistryMock = { - createTask: vi.fn(), - getTasksByLabels: vi.fn(), - updateTask: vi.fn(), -} as unknown as TaskRegistry; - -const telemetryMock = { - logUsage: vi.fn(), - logError: vi.fn(), -} as unknown as TelemetryLogger; - -const cancellationTokenRegistryMock = { - createCancellationTokenSource: vi.fn(), - delete: vi.fn(), -} as unknown as CancellationTokenRegistry; - -beforeEach(() => { - vi.resetAllMocks(); - vi.mocked(webviewMock.postMessage).mockResolvedValue(true); - vi.useFakeTimers(); -}); - -afterEach(() => { - vi.useRealTimers(); -}); - -test('manager should be properly initialized', () => { - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - expect(manager.getConversations().length).toBe(0); -}); - -test('submit should throw an error if the server is stopped', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - models: [ - { - id: 'model1', - }, - ], - } as unknown as InferenceServer, - ]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, 'tracking-1'); - - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'stopped', - models: [ - { - id: 'model1', - }, - ], - } as unknown as InferenceServer, - ]); - - await expect(manager.submit(manager.getConversations()[0].id, 'dummyUserInput')).rejects.toThrowError( - 'Inference server is not running.', - ); -}); - -test('submit should throw an error if the server is unhealthy', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'unhealthy', - }, - models: [ - { - id: 'model1', - }, - ], - } as unknown as InferenceServer, - ]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, 'tracking-1'); - const playgroundId = manager.getConversations()[0].id; - await expect(manager.submit(playgroundId, 'dummyUserInput')).rejects.toThrowError( - 'Inference server is not healthy, currently status: unhealthy.', - ); -}); - -test('create playground should create conversation.', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - } as unknown as InferenceServer, - ]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - expect(manager.getConversations().length).toBe(0); - await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'tracking-1'); - - const conversations = manager.getConversations(); - expect(conversations.length).toBe(1); -}); - -test('valid submit should create IPlaygroundMessage and notify the webview', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - connection: { - port: 8888, - }, - } as unknown as InferenceServer, - ]); - const createMock = vi.fn().mockResolvedValue([]); - vi.mocked(OpenAI).mockReturnValue({ - chat: { - completions: { - create: createMock, - }, - }, - } as unknown as OpenAI); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1'); - - const date = new Date(2000, 1, 1, 13); - vi.setSystemTime(date); - - const playgrounds = manager.getConversations(); - await manager.submit(playgrounds[0].id, 'dummyUserInput'); - - // Wait for assistant message to be completed - await vi.waitFor(() => { - expect((manager.getConversations()[0].messages[1] as ChatMessage).content).toBeDefined(); - }); - - const conversations = manager.getConversations(); - - expect(conversations.length).toBe(1); - expect(conversations[0].messages.length).toBe(2); - expect(conversations[0].messages[0]).toStrictEqual({ - content: 'dummyUserInput', - id: expect.anything(), - options: undefined, - role: 'user', - timestamp: expect.any(Number), - }); - expect(conversations[0].messages[1]).toStrictEqual({ - choices: undefined, - completed: expect.any(Number), - content: '', - id: expect.anything(), - role: 'assistant', - timestamp: expect.any(Number), - }); - - expect(webviewMock.postMessage).toHaveBeenLastCalledWith({ - id: Messages.MSG_CONVERSATIONS_UPDATE, - body: conversations, - }); -}); - -test('submit should send options', async () => { - vi.mocked(cancellationTokenRegistryMock.createCancellationTokenSource).mockReturnValue(55); - - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - connection: { - port: 8888, - }, - } as unknown as InferenceServer, - ]); - const createMock = vi.fn().mockResolvedValue([]); - vi.mocked(OpenAI).mockReturnValue({ - chat: { - completions: { - create: createMock, - }, - }, - } as unknown as OpenAI); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1'); - - const playgrounds = manager.getConversations(); - const cancellationId = await manager.submit(playgrounds[0].id, 'dummyUserInput', { - temperature: 0.123, - max_tokens: 45, - top_p: 0.345, - }); - expect(cancellationId).toBe(55); - - const messages: unknown[] = [ - { - content: 'dummyUserInput', - id: expect.any(String), - role: 'user', - timestamp: expect.any(Number), - options: { - temperature: 0.123, - max_tokens: 45, - top_p: 0.345, - }, - }, - ]; - expect(createMock).toHaveBeenCalledWith( - { - messages, - model: 'dummyModelFile', - stream: true, - temperature: 0.123, - max_tokens: 45, - top_p: 0.345, - }, - { - signal: expect.anything(), - }, - ); - // at the end the token must be deleted once the request is complete - await vi.waitFor(() => { - expect(cancellationTokenRegistryMock.delete).toHaveBeenCalledWith(55); - }); -}); - -test('error', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - connection: { - port: 8888, - }, - } as unknown as InferenceServer, - ]); - const createMock = vi.fn().mockRejectedValue('Please reduce the length of the messages or completion.'); - vi.mocked(OpenAI).mockReturnValue({ - chat: { - completions: { - create: createMock, - }, - }, - } as unknown as OpenAI); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1'); - - const date = new Date(2000, 1, 1, 13); - vi.setSystemTime(date); - - const playgrounds = manager.getConversations(); - await manager.submit(playgrounds[0].id, 'dummyUserInput'); - - // Wait for error message - await vi.waitFor(() => { - expect((manager.getConversations()[0].messages[1] as ErrorMessage).error).toBeDefined(); - }); - - const conversations = manager.getConversations(); - - expect(conversations.length).toBe(1); - expect(conversations[0].messages.length).toBe(2); - expect(conversations[0].messages[0]).toStrictEqual({ - content: 'dummyUserInput', - id: expect.anything(), - options: undefined, - role: 'user', - timestamp: expect.any(Number), - }); - expect(conversations[0].messages[1]).toStrictEqual({ - error: 'Please reduce the length of the messages or completion. Note: You should start a new playground.', - id: expect.anything(), - timestamp: expect.any(Number), - }); - - expect(webviewMock.postMessage).toHaveBeenLastCalledWith({ - id: Messages.MSG_CONVERSATIONS_UPDATE, - body: conversations, - }); -}); - -test('creating a new playground should send new playground to frontend', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(webviewMock.postMessage).toHaveBeenCalledWith({ - id: Messages.MSG_CONVERSATIONS_UPDATE, - body: [ - { - id: expect.anything(), - modelId: 'model-1', - name: 'a name', - messages: [], - }, - ], - }); -}); - -test('creating a new playground with no name should send new playground to frontend with generated name', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - '', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(webviewMock.postMessage).toHaveBeenCalledWith({ - id: Messages.MSG_CONVERSATIONS_UPDATE, - body: [ - { - id: expect.anything(), - modelId: 'model-1', - name: 'playground 1', - messages: [], - }, - ], - }); -}); - -test('creating a new playground with no model served should start an inference server', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - const createInferenceServerMock = vi.mocked(inferenceManagerMock.createInferenceServer); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(createInferenceServerMock).toHaveBeenCalledWith({ - gpuLayers: expect.any(Number), - image: undefined, - providerId: undefined, - inferenceProvider: undefined, - labels: { - trackingId: 'tracking-1', - }, - modelsInfo: [ - { - id: 'model-1', - name: 'Model 1', - }, - ], - port: expect.anything(), - }); -}); - -test('creating a new playground with the model already served should not start an inference server', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - models: [ - { - id: 'model-1', - }, - ], - }, - ] as InferenceServer[]); - const createInferenceServerMock = vi.mocked(inferenceManagerMock.createInferenceServer); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(createInferenceServerMock).not.toHaveBeenCalled(); -}); - -test('creating a new playground with the model server stopped should start the inference server', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - models: [ - { - id: 'model-1', - }, - ], - status: 'stopped', - container: { - containerId: 'container-1', - }, - }, - ] as InferenceServer[]); - const createInferenceServerMock = vi.mocked(inferenceManagerMock.createInferenceServer); - const startInferenceServerMock = vi.mocked(inferenceManagerMock.startInferenceServer); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(createInferenceServerMock).not.toHaveBeenCalled(); - expect(startInferenceServerMock).toHaveBeenCalledWith('container-1'); -}); - -test('delete conversation should delete the conversation', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - expect(manager.getConversations().length).toBe(0); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - - const conversations = manager.getConversations(); - expect(conversations.length).toBe(1); - manager.deleteConversation(conversations[0].id); - expect(manager.getConversations().length).toBe(0); - expect(webviewMock.postMessage).toHaveBeenCalled(); -}); - -test('creating a new playground with an existing name shoud fail', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - await expect( - manager.createPlayground( - 'a name', - { - id: 'model-2', - name: 'Model 2', - } as unknown as ModelInfo, - 'tracking-2', - ), - ).rejects.toThrowError('a playground with the name a name already exists'); -}); - -test('requestCreatePlayground should call createPlayground and createTask, then updateTask', async () => { - vi.useRealTimers(); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - const createTaskMock = vi.mocked(taskRegistryMock).createTask; - const updateTaskMock = vi.mocked(taskRegistryMock).updateTask; - createTaskMock.mockImplementation((_name: string, _state: TaskState, labels?: { [id: string]: string }) => { - return { - labels, - } as Task; - }); - const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockResolvedValue('playground-1'); - - const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo); - - expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String)); - expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', { - trackingId: id, - }); - await new Promise(resolve => setTimeout(resolve, 0)); - expect(updateTaskMock).toHaveBeenCalledWith({ - labels: { - trackingId: id, - playgroundId: 'playground-1', - }, - state: 'success', - }); -}); - -test('requestCreatePlayground should call createPlayground and createTask, then updateTask when createPlayground fails', async () => { - vi.useRealTimers(); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - const createTaskMock = vi.mocked(taskRegistryMock).createTask; - const updateTaskMock = vi.mocked(taskRegistryMock).updateTask; - const getTasksByLabelsMock = vi.mocked(taskRegistryMock).getTasksByLabels; - createTaskMock.mockImplementation((_name: string, _state: TaskState, labels?: { [id: string]: string }) => { - return { - labels, - } as Task; - }); - const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockRejectedValue(new Error('an error')); - - const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo); - - expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String)); - expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', { - trackingId: id, - }); - - getTasksByLabelsMock.mockReturnValue([ - { - labels: { - trackingId: id, - }, - } as unknown as Task, - ]); - - await new Promise(resolve => setTimeout(resolve, 0)); - expect(updateTaskMock).toHaveBeenCalledWith({ - error: 'Something went wrong while trying to create a playground environment Error: an error.', - labels: { - trackingId: id, - }, - state: 'error', - }); -}); - -describe('system prompt', () => { - test('set system prompt on non existing conversation should throw an error', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - models: [ - { - id: 'model1', - }, - ], - } as unknown as InferenceServer, - ]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - - expect(() => { - manager.setSystemPrompt('invalid', 'content'); - }).toThrowError('conversation with id invalid does not exist.'); - }); - - test('set system prompt should throw an error if user already submit message', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - connection: { - port: 8888, - }, - } as unknown as InferenceServer, - ]); - const createMock = vi.fn().mockResolvedValue([]); - vi.mocked(OpenAI).mockReturnValue({ - chat: { - completions: { - create: createMock, - }, - }, - } as unknown as OpenAI); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1'); - - const date = new Date(2000, 1, 1, 13); - vi.setSystemTime(date); - - const conversations = manager.getConversations(); - await manager.submit(conversations[0].id, 'dummyUserInput'); - - // Wait for assistant message to be completed - await vi.waitFor(() => { - expect((manager.getConversations()[0].messages[1] as ChatMessage).content).toBeDefined(); - }); - - expect(() => { - manager.setSystemPrompt(manager.getConversations()[0].id, 'newSystemPrompt'); - }).toThrowError('Cannot change system prompt on started conversation.'); - }); -}); From 43872fbafc56a24d87ed3f8ddb1ea9c7018ca116 Mon Sep 17 00:00:00 2001 From: Jeff MAURY Date: Fri, 20 Sep 2024 11:11:41 +0200 Subject: [PATCH 3/5] fix: missing waited promise Signed-off-by: Jeff MAURY --- packages/backend/src/studio-api-impl.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index ca10f8f48..e0d7c4827 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -87,9 +87,9 @@ export class StudioApiImpl implements StudioAPI { // Do not wait on the promise as the api would probably timeout before the user answer. podmanDesktopApi.window .showWarningMessage(`Are you sure you want to delete this playground ?`, 'Confirm', 'Cancel') - .then((result: string | undefined) => { + .then(async (result: string | undefined) => { if (result === 'Confirm') { - this.playgroundV2.deleteConversation(conversationId); + await this.playgroundV2.deleteConversation(conversationId); } }) .catch((err: unknown) => { From 1db9972a45b7df59758711649f52b92a25af8b95 Mon Sep 17 00:00:00 2001 From: Jeff MAURY Date: Fri, 20 Sep 2024 11:25:26 +0200 Subject: [PATCH 4/5] fix: format Signed-off-by: Jeff MAURY --- packages/frontend/src/pages/Playground.svelte | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/frontend/src/pages/Playground.svelte b/packages/frontend/src/pages/Playground.svelte index 7900f81b1..dcc586f90 100644 --- a/packages/frontend/src/pages/Playground.svelte +++ b/packages/frontend/src/pages/Playground.svelte @@ -40,7 +40,6 @@ function getStatusForIcon(status?: string, health?: string): string { export function goToUpPage(): void { router.goto('/playgrounds'); } - {#if conversation} From c783640b5d00831c0c84658b451b632f53cf06b4 Mon Sep 17 00:00:00 2001 From: Jeff MAURY Date: Thu, 10 Oct 2024 17:18:32 +0200 Subject: [PATCH 5/5] fix: use open-webui dev image to get automatic model selection Signed-off-by: Jeff MAURY --- packages/backend/src/registries/ConversationRegistry.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/backend/src/registries/ConversationRegistry.ts b/packages/backend/src/registries/ConversationRegistry.ts index c1f7af748..1a99fb101 100644 --- a/packages/backend/src/registries/ConversationRegistry.ts +++ b/packages/backend/src/registries/ConversationRegistry.ts @@ -45,6 +45,8 @@ import { getImageInfo } from '../utils/inferenceUtils'; import type { TaskRegistry } from './TaskRegistry'; import type { PodmanConnection } from '../managers/podmanConnection'; +const OPEN_WEBUI_IMAGE = 'ghcr.io/open-webui/open-webui:dev'; + export class ConversationRegistry extends Publisher implements Disposable { #conversations: Map; #counter: number; @@ -138,7 +140,7 @@ export class ConversationRegistry extends Publisher implements D const conversation = this.get(conversationId); const port = await getFreeRandomPort('127.0.0.1'); const connection = await this.podmanConnection.getConnectionByEngineId(server.container.engineId); - await this.pullImage(connection, 'ghcr.io/open-webui/open-webui:main', { + await this.pullImage(connection, OPEN_WEBUI_IMAGE, { trackingId: trackingId, }); const inferenceServerContainer = await containerEngine.inspectContainer( @@ -155,7 +157,7 @@ export class ConversationRegistry extends Publisher implements D `WEBUI_URL=http://localhost:${port}`, `DEFAULT_MODELS=/models/${server.models[0].file?.file}`, ], - Image: 'ghcr.io/open-webui/open-webui:main', + Image: OPEN_WEBUI_IMAGE, HostConfig: { AutoRemove: true, Mounts: [