Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: open-webui playground prototype
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff MAURY <[email protected]>
jeffmaury committed Sep 20, 2024
1 parent 7435841 commit bf5a5f8
Showing 7 changed files with 153 additions and 233 deletions.
Binary file added packages/backend/src/assets/webui.db
Binary file not shown.
24 changes: 19 additions & 5 deletions packages/backend/src/managers/playgroundV2Manager.ts
Original file line number Diff line number Diff line change
@@ -36,6 +36,8 @@ import { getRandomString } from '../utils/randomUtils';
import type { TaskRegistry } from '../registries/TaskRegistry';
import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry';
import { getHash } from '../utils/sha';
import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry';
import type { PodmanConnection } from './podmanConnection';

export class PlaygroundV2Manager implements Disposable {
#conversationRegistry: ConversationRegistry;
@@ -46,17 +48,24 @@ export class PlaygroundV2Manager implements Disposable {
private taskRegistry: TaskRegistry,
private telemetry: TelemetryLogger,
private cancellationTokenRegistry: CancellationTokenRegistry,
configurationRegistry: ConfigurationRegistry,
podmanConnection: PodmanConnection,
) {
this.#conversationRegistry = new ConversationRegistry(webview);
this.#conversationRegistry = new ConversationRegistry(
webview,
configurationRegistry,
taskRegistry,
podmanConnection,
);
}

deleteConversation(conversationId: string): void {
/**
 * Delete the conversation with the given id.
 * Logs telemetry (message count and hashed model id) before deletion, then
 * delegates to the conversation registry, which also removes the on-disk
 * conversation data and stops the associated container (see ConversationRegistry).
 * @param conversationId id of the conversation to delete
 */
async deleteConversation(conversationId: string): Promise<void> {
  const conversation = this.#conversationRegistry.get(conversationId);
  // capture metrics before the conversation (and its messages) is removed
  this.telemetry.logUsage('playground.delete', {
    totalMessages: conversation.messages.length,
    modelId: getHash(conversation.modelId),
  });
  await this.#conversationRegistry.deleteConversation(conversationId);
}

async requestCreatePlayground(name: string, model: ModelInfo): Promise<string> {
@@ -117,11 +126,11 @@ export class PlaygroundV2Manager implements Disposable {
}

// Create conversation
const conversationId = this.#conversationRegistry.createConversation(name, model.id);
const conversationId = await this.#conversationRegistry.createConversation(name, model.id);

// create/start inference server if necessary
const servers = this.inferenceManager.getServers();
const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id));
let server = servers.find(s => s.models.map(mi => mi.id).includes(model.id));
if (!server) {
await this.inferenceManager.createInferenceServer(
await withDefaultConfiguration({
@@ -131,10 +140,15 @@ export class PlaygroundV2Manager implements Disposable {
},
}),
);
server = this.inferenceManager.findServerByModel(model);
} else if (server.status === 'stopped') {
await this.inferenceManager.startInferenceServer(server.container.containerId);
}

if (server && server.status === 'running') {
await this.#conversationRegistry.startConversationContainer(server, trackingId, conversationId);
}

return conversationId;
}

4 changes: 4 additions & 0 deletions packages/backend/src/registries/ConfigurationRegistry.ts
Original file line number Diff line number Diff line change
@@ -62,6 +62,10 @@ export class ConfigurationRegistry extends Publisher<ExtensionConfiguration> imp
return path.join(this.appUserDirectory, 'models');
}

/**
 * Absolute path of the directory holding per-conversation data,
 * located under the application's user directory.
 */
public getConversationsPath(): string {
  const conversationsDirectory = 'conversations';
  return path.join(this.appUserDirectory, conversationsDirectory);
}

/**
 * Dispose the registry: release the configuration disposable, if one was registered.
 */
dispose(): void {
  this.#configurationDisposable?.dispose();
}
120 changes: 116 additions & 4 deletions packages/backend/src/registries/ConversationRegistry.ts
Original file line number Diff line number Diff line change
@@ -25,14 +25,36 @@ import type {
Message,
PendingChat,
} from '@shared/src/models/IPlaygroundMessage';
import type { Disposable, Webview } from '@podman-desktop/api';
import {
type Disposable,
type Webview,
type ContainerCreateOptions,
containerEngine,
type ContainerProviderConnection,
type ImageInfo,
type PullEvent,
} from '@podman-desktop/api';
import { Messages } from '@shared/Messages';
import type { ConfigurationRegistry } from './ConfigurationRegistry';
import path from 'node:path';
import fs from 'node:fs';
import type { InferenceServer } from '@shared/src/models/IInference';
import { getFreeRandomPort } from '../utils/ports';
import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../utils/utils';
import { getImageInfo } from '../utils/inferenceUtils';
import type { TaskRegistry } from './TaskRegistry';
import type { PodmanConnection } from '../managers/podmanConnection';

export class ConversationRegistry extends Publisher<Conversation[]> implements Disposable {
#conversations: Map<string, Conversation>;
#counter: number;

constructor(webview: Webview) {
constructor(
webview: Webview,
private configurationRegistry: ConfigurationRegistry,
private taskRegistry: TaskRegistry,
private podmanConnection: PodmanConnection,
) {
super(webview, Messages.MSG_CONVERSATIONS_UPDATE, () => this.getAll());
this.#conversations = new Map<string, Conversation>();
this.#counter = 0;
@@ -76,13 +98,32 @@ export class ConversationRegistry extends Publisher<Conversation[]> implements D
this.notify();
}

deleteConversation(id: string): void {
/**
 * Delete the conversation with the given id: stop its container (best effort),
 * remove its on-disk data directory, then drop it from the registry and notify listeners.
 * @param id id of the conversation to delete
 */
async deleteConversation(id: string): Promise<void> {
  const conversation = this.get(id);
  if (conversation.container) {
    try {
      // Best effort: the container is created with AutoRemove, so it may already
      // be gone; a failed stop must not prevent the conversation data deletion below.
      await containerEngine.stopContainer(conversation.container.engineId, conversation.container.containerId);
    } catch (err: unknown) {
      console.error(`Something went wrong while stopping container for conversation ${id}`, err);
    }
  }
  // remove the per-conversation data directory (including the copied webui.db)
  await fs.promises.rm(path.join(this.configurationRegistry.getConversationsPath(), id), {
    recursive: true,
    force: true,
  });
  this.#conversations.delete(id);
  this.notify();
}

createConversation(name: string, modelId: string): string {
async createConversation(name: string, modelId: string): Promise<string> {
const conversationId = this.getUniqueId();
const conversationFolder = path.join(this.configurationRegistry.getConversationsPath(), conversationId);
await fs.promises.mkdir(conversationFolder, {
recursive: true,
});
//WARNING: this will not work in production mode; we have not found how to embed binary assets yet.
//This copies a pre-initialized database so that the default user is not an admin and therefore
//does not get the initial welcome modal dialog.
await fs.promises.copyFile(
path.join(__dirname, '..', 'src', 'assets', 'webui.db'),
path.join(conversationFolder, 'webui.db'),
);
this.#conversations.set(conversationId, {
name: name,
modelId: modelId,
@@ -93,6 +134,77 @@ export class ConversationRegistry extends Publisher<Conversation[]> implements D
return conversationId;
}

/**
 * Pull and start an open-webui container bound to the given inference server,
 * and record its engine/container ids and host port on the conversation.
 * @param server the running inference server the web UI will talk to
 * @param trackingId label attached to the image-pull task for progress tracking
 * @param conversationId id of the conversation to attach the container to
 */
async startConversationContainer(server: InferenceServer, trackingId: string, conversationId: string): Promise<void> {
  const conversation = this.get(conversationId);
  const port = await getFreeRandomPort('127.0.0.1');
  // single source of truth: the same image is pulled and instantiated below
  const image = 'ghcr.io/open-webui/open-webui:main';
  const connection = await this.podmanConnection.getConnectionByEngineId(server.container.engineId);
  await this.pullImage(connection, image, {
    trackingId: trackingId,
  });
  // inspect the inference server container to get its IP address, so the web UI
  // can reach the OpenAI-compatible endpoint it exposes on port 8000
  const inferenceServerContainer = await containerEngine.inspectContainer(
    server.container.engineId,
    server.container.containerId,
  );
  const options: ContainerCreateOptions = {
    Env: [
      'DEFAULT_LOCALE=en-US',
      'WEBUI_AUTH=false',
      'ENABLE_OLLAMA_API=false',
      `OPENAI_API_BASE_URL=http://${inferenceServerContainer.NetworkSettings.IPAddress}:8000/v1`,
      'OPENAI_API_KEY=sk_dummy',
      `WEBUI_URL=http://localhost:${port}`,
      `DEFAULT_MODELS=/models/${server.models[0].file?.file}`,
    ],
    Image: image,
    HostConfig: {
      AutoRemove: true,
      Mounts: [
        {
          // persist open-webui data (including the copied webui.db) per conversation
          Source: path.join(this.configurationRegistry.getConversationsPath(), conversationId),
          Target: '/app/backend/data',
          Type: 'bind',
        },
      ],
      PortBindings: {
        '8080/tcp': [
          {
            HostPort: `${port}`,
          },
        ],
      },
      SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION],
    },
  };
  const c = await containerEngine.createContainer(server.container.engineId, options);
  conversation.container = { engineId: c.engineId, containerId: c.id, port };
  // publish the updated conversation so the frontend learns the container port
  // (the playground page embeds an iframe pointing at localhost:<port>)
  this.notify();
}

protected pullImage(
connection: ContainerProviderConnection,
image: string,
labels: { [id: string]: string },
): Promise<ImageInfo> {
// Creating a task to follow pulling progress
const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels);

// get the default image info for this provider
return getImageInfo(connection, image, (_event: PullEvent) => {})
.catch((err: unknown) => {
pullingTask.state = 'error';
pullingTask.progress = undefined;
pullingTask.error = `Something went wrong while pulling ${image}: ${String(err)}`;
throw err;
})
.then(imageInfo => {
pullingTask.state = 'success';
pullingTask.progress = undefined;
return imageInfo;
})
.finally(() => {
this.taskRegistry.updateTask(pullingTask);
});
}

/**
* This method will be responsible for finalizing the message by concatenating all the choices
* @param conversationId
2 changes: 2 additions & 0 deletions packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
@@ -306,6 +306,8 @@ export class Studio {
this.#taskRegistry,
this.#telemetry,
this.#cancellationTokenRegistry,
this.#configurationRegistry,
this.#podmanConnection,
);
this.#extensionContext.subscriptions.push(this.#playgroundManager);

231 changes: 7 additions & 224 deletions packages/frontend/src/pages/Playground.svelte
Original file line number Diff line number Diff line change
@@ -1,113 +1,26 @@
<script lang="ts">
import { conversations } from '../stores/conversations';
import { studioClient } from '/@/utils/client';
import {
isAssistantChat,
isPendingChat,
isUserChat,
isSystemPrompt,
isChatMessage,
isErrorMessage,
} from '@shared/src/models/IPlaygroundMessage';
import { catalog } from '../stores/catalog';
import { afterUpdate } from 'svelte';
import ContentDetailsLayout from '../lib/ContentDetailsLayout.svelte';
import RangeInput from '../lib/RangeInput.svelte';
import Fa from 'svelte-fa';
import ChatMessage from '../lib/conversation/ChatMessage.svelte';
import SystemPromptBanner from '/@/lib/conversation/SystemPromptBanner.svelte';
import { inferenceServers } from '/@/stores/inferenceServers';
import { faCircleInfo, faPaperPlane, faStop } from '@fortawesome/free-solid-svg-icons';
import { Button, Tooltip, DetailsPage, StatusIcon } from '@podman-desktop/ui-svelte';
import { DetailsPage, StatusIcon } from '@podman-desktop/ui-svelte';
import { router } from 'tinro';
import ConversationActions from '../lib/conversation/ConversationActions.svelte';
import { ContainerIcon } from '@podman-desktop/ui-svelte/icons';
export let playgroundId: string;
let prompt: string;
let sendEnabled = false;
let scrollable: Element;
let errorMsg = '';
// settings
let temperature = 0.8;
let max_tokens = -1;
let top_p = 0.5;
let cancellationTokenId: number | undefined = undefined;
$: conversation = $conversations.find(conversation => conversation.id === playgroundId);
$: messages =
conversation?.messages.filter(message => isChatMessage(message)).filter(message => !isSystemPrompt(message)) ?? [];
$: model = $catalog.models.find(model => model.id === conversation?.modelId);
$: {
if (conversation?.messages.length) {
const latest = conversation.messages[conversation.messages.length - 1];
if (isSystemPrompt(latest) || (isAssistantChat(latest) && !isPendingChat(latest))) {
sendEnabled = true;
}
if (isErrorMessage(latest)) {
errorMsg = latest.error;
sendEnabled = true;
}
} else {
sendEnabled = true;
}
}
$: server = $inferenceServers.find(is => conversation && is.models.map(mi => mi.id).includes(conversation?.modelId));
function askPlayground(): void {
errorMsg = '';
sendEnabled = false;
studioClient
.submitPlaygroundMessage(playgroundId, prompt, {
temperature,
max_tokens,
top_p,
})
.then(token => {
cancellationTokenId = token;
})
.catch((err: unknown) => {
errorMsg = String(err);
sendEnabled = true;
});
prompt = '';
}
afterUpdate(() => {
if (!conversation) {
router.goto('/playgrounds');
return;
}
if (!conversation?.messages.length) {
return;
}
const latest = conversation.messages[conversation.messages.length - 1];
if (isUserChat(latest) || (isAssistantChat(latest) && isPendingChat(latest))) {
scrollToBottom(scrollable).catch(err => console.error(`Error scrolling to bottom:`, err));
}
});
function requestFocus(element: HTMLElement): void {
element.focus();
}
function handleKeydown(e: KeyboardEvent): void {
if (e.key === 'Enter') {
askPlayground();
e.preventDefault();
}
}
async function scrollToBottom(element: Element): Promise<void> {
element.scroll?.({ top: element.scrollHeight, behavior: 'smooth' });
}
function isHealthy(status?: string, health?: string): boolean {
return status === 'running' && (!health || health === 'healthy');
}
function getStatusForIcon(status?: string, health?: string): string {
switch (status) {
case 'running':
@@ -124,42 +37,10 @@ function getStatusForIcon(status?: string, health?: string): string {
}
}
function getStatusText(status?: string, health?: string): string {
switch (status) {
case 'running':
switch (health) {
case 'healthy':
return 'Model Service running';
case 'starting':
return 'Model Service starting';
default:
return 'Model Service not running';
}
default:
return 'Model Service not running';
}
}
function getSendPromptTitle(sendEnabled: boolean, status?: string, health?: string): string | undefined {
if (!isHealthy(status, health)) {
return getStatusText(status, health);
} else if (!sendEnabled) {
return 'Please wait, assistant is replying';
}
return undefined;
}
export function goToUpPage(): void {
router.goto('/playgrounds');
}
function handleOnClick(): void {
if (cancellationTokenId) {
studioClient
.requestCancelToken(cancellationTokenId)
.catch(err => console.error(`Error request cancel token ${cancellationTokenId}`, err));
}
}
</script>

{#if conversation}
@@ -188,110 +69,12 @@ function handleOnClick(): void {
<ConversationActions detailed conversation={conversation} />
</svelte:fragment>
<svelte:fragment slot="content">
<div class="flex flex-col w-full h-full bg-[var(--pd-content-bg)]">
<div class="h-full overflow-auto" bind:this={scrollable}>
<ContentDetailsLayout detailsTitle="Settings" detailsLabel="settings">
<svelte:fragment slot="content">
<div class="flex flex-col w-full h-full">
<div aria-label="conversation" class="w-full h-full">
{#if conversation}
<!-- Show a banner for the system prompt -->
{#key conversation.messages.length}
<SystemPromptBanner conversation={conversation} />
{/key}
<!-- show all messages except the system prompt -->
<ul>
{#each messages as message}
<li>
<ChatMessage message={message} />
</li>
{/each}
</ul>
{/if}
</div>
</div>
</svelte:fragment>
<svelte:fragment slot="details">
<div class="text-[var(--pd-content-card-text)]">Next prompt will use these settings</div>
<div
class="bg-[var(--pd-content-card-inset-bg)] text-[var(--pd-content-card-text)] w-full rounded-md p-4">
<div class="mb-4 flex flex-col">Model Parameters</div>
<div class="flex flex-col space-y-4" aria-label="parameters">
<div class="flex flex-row">
<div class="w-full">
<RangeInput name="temperature" min="0" max="2" step="0.1" bind:value={temperature} />
</div>
<Tooltip left>
<Fa class="text-[var(--pd-content-card-icon)]" icon={faCircleInfo} />
<svelte:fragment slot="tip">
<div class="inline-block py-2 px-4 rounded-md" aria-label="tooltip">
What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output
more random, while lower values like 0.2 will make it more focused and deterministic.
</div>
</svelte:fragment>
</Tooltip>
</div>
<div class="flex flex-row">
<div class="w-full">
<RangeInput name="max tokens" min="-1" max="32768" step="1" bind:value={max_tokens} />
</div>
<Tooltip left>
<Fa class="text-[var(--pd-content-card-icon)]" icon={faCircleInfo} />
<svelte:fragment slot="tip">
<div class="inline-block py-2 px-4 rounded-md" aria-label="tooltip">
The maximum number of tokens that can be generated in the chat completion.
</div>
</svelte:fragment>
</Tooltip>
</div>
<div class="flex flex-row">
<div class="w-full">
<RangeInput name="top-p" min="0" max="1" step="0.1" bind:value={top_p} />
</div>
<Tooltip left>
<Fa class="text-[var(--pd-content-card-icon)]" icon={faCircleInfo} />
<svelte:fragment slot="tip">
<div class="inline-block py-2 px-4 rounded-md" aria-label="tooltip">
An alternative to sampling with temperature, where the model considers the results of the
tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
probability mass are considered.
</div>
</svelte:fragment>
</Tooltip>
</div>
</div>
</div>
</svelte:fragment>
</ContentDetailsLayout>
</div>
{#if errorMsg}
<div class="text-[var(--pd-input-field-error-text)] p-2">{errorMsg}</div>
{/if}
<div class="flex flex-row flex-none w-full px-4 py-2 bg-[var(--pd-content-card-bg)]">
<textarea
aria-label="prompt"
bind:value={prompt}
use:requestFocus
on:keydown={handleKeydown}
rows="2"
class="w-full p-2 outline-none rounded-sm bg-[var(--pd-content-card-inset-bg)] text-[var(--pd-content-card-text)] placeholder-[var(--pd-content-card-text)]"
placeholder="Type your prompt here"
disabled={!sendEnabled}></textarea>

<div class="flex-none text-right m-4">
{#if !sendEnabled && cancellationTokenId !== undefined}
<Button title="Stop" icon={faStop} type="secondary" on:click={handleOnClick} />
{:else}
<Button
inProgress={!sendEnabled}
disabled={!isHealthy(server?.status, server?.health?.Status) || !prompt?.length}
on:click={askPlayground}
icon={faPaperPlane}
type="secondary"
title={getSendPromptTitle(sendEnabled, server?.status, server?.health?.Status)}
aria-label="Send prompt"></Button>
{/if}
</div>
<div class="w-full h-full bg-[var(--pd-content-bg)]">
<div class="h-full overflow-auto">
<iframe
class="h-full w-full"
title={conversation.name}
src="http://localhost:{conversation.container?.port}?lang=en"></iframe>
</div>
</div>
</svelte:fragment>
5 changes: 5 additions & 0 deletions packages/shared/src/models/IPlaygroundMessage.ts
Original file line number Diff line number Diff line change
@@ -57,6 +57,11 @@ export interface Conversation {
messages: Message[];
modelId: string;
name: string;
container?: {
engineId: string;
containerId: string;
port: number;
};
}

export interface Choice {

0 comments on commit bf5a5f8

Please sign in to comment.