Skip to content

Commit

Permalink
Implement chat UI (jupyterlab#25)
Browse files Browse the repository at this point in the history
* implement chat UI

* improve rendering of code blocks

* handle empty and loading state in chat-input

* align all messages to the left

* css tweaks

* add build:core yarn script

* handle user selections

* support math rendering

* keep using old OpenAIChat provider for chat and magics

* delete example messages
  • Loading branch information
dlqqq authored and Marchlak committed Oct 28, 2024
1 parent f40305a commit 35383e2
Show file tree
Hide file tree
Showing 21 changed files with 1,506 additions and 57 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"scripts": {
"setup:dev": "lerna run setup:dev --stream",
"build": "lerna run build --stream",
"build:core": "lerna run build --stream --scope \"@jupyter-ai/core\"",
"build:prod": "lerna run build:prod --stream",
"clean": "lerna run clean",
"clean:all": "lerna run clean:all",
Expand Down
13 changes: 8 additions & 5 deletions packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from importlib_metadata import entry_points
import inspect
from .engine import BaseModelEngine
from .providers import ChatOpenAIProvider
from .providers import ChatOpenAIProvider, ChatOpenAINewProvider
import os

from langchain.memory import ConversationBufferMemory
Expand Down Expand Up @@ -85,9 +85,12 @@ def initialize_settings(self):
self.settings["ai_default_tasks"] = default_tasks
self.log.info("Registered all default tasks.")

## load OpenAI chat provider
if ChatOpenAIProvider.auth_strategy.name in os.environ:
self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo")
## load OpenAI provider
self.settings["openai_chat"] = ChatOpenAIProvider(model_id="gpt-3.5-turbo")

## load OpenAI new provider
if ChatOpenAINewProvider.auth_strategy.name in os.environ:
provider = ChatOpenAINewProvider(model_id="gpt-3.5-turbo")
# Create a conversation memory
memory = ConversationBufferMemory(return_messages=True)
prompt_template = ChatPromptTemplate.from_messages([
Expand All @@ -96,7 +99,7 @@ def initialize_settings(self):
HumanMessagePromptTemplate.from_template("{input}")
])
chain = ConversationChain(
llm=self.settings["openai_chat"],
llm=provider,
prompt=prompt_template,
verbose=True,
memory=memory
Expand Down
23 changes: 19 additions & 4 deletions packages/jupyter-ai/jupyter_ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Cohere,
HuggingFaceHub,
OpenAI,
OpenAIChat,
SagemakerEndpoint
)

Expand Down Expand Up @@ -155,7 +156,7 @@ class OpenAIProvider(BaseProvider, OpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
class ChatOpenAIProvider(BaseProvider, OpenAIChat):
id = "openai-chat"
name = "OpenAI"
models = [
Expand All @@ -170,9 +171,6 @@ class ChatOpenAIProvider(BaseProvider, ChatOpenAI):
pypi_package_deps = ["openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def append_exchange(self, prompt: str, output: str):
"""Appends a conversational exchange between user and an OpenAI Chat
model to a transcript that will be included in future exchanges."""
Expand All @@ -185,6 +183,23 @@ def append_exchange(self, prompt: str, output: str):
"content": output
})

# uses the new OpenAIChat provider. temporarily living as a separate class until
# conflicts can be resolved
class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
    """Provider entry backed by LangChain's `ChatOpenAI` chat model.

    Registered under a distinct id ("openai-chat-new") so it can coexist
    with `ChatOpenAIProvider` (backed by the legacy `OpenAIChat` class)
    until the two implementations are reconciled.
    """

    # Unique provider id used for registration and lookup.
    id = "openai-chat-new"
    # Human-readable provider name.
    name = "OpenAI"
    # Chat model ids selectable for this provider.
    models = [
        "gpt-4",
        "gpt-4-0314",
        "gpt-4-32k",
        "gpt-4-32k-0314",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0301",
    ]
    # Presumably the kwarg on the underlying LangChain class that receives
    # the model id (callers pass `model_id="gpt-3.5-turbo"`) — TODO confirm.
    model_id_key = "model_name"
    # PyPI packages required at runtime.
    pypi_package_deps = ["openai"]
    # API key is read from the OPENAI_API_KEY environment variable.
    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")

class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
id = "sagemaker-endpoint"
name = "Sagemaker Endpoint"
Expand Down
7 changes: 6 additions & 1 deletion packages/jupyter-ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,19 @@
"@mui/icons-material": "^5.11.0",
"@mui/material": "^5.11.0",
"react": "^18.2.0",
"react-dom": "^18.2.0"
"react-dom": "^18.2.0",
"react-markdown": "^8.0.6",
"react-syntax-highlighter": "^15.5.0",
"rehype-katex": "^6.0.2",
"remark-math": "^5.1.1"
},
"devDependencies": {
"@babel/core": "^7.0.0",
"@babel/preset-env": "^7.0.0",
"@jupyterlab/builder": "^3.5.1",
"@jupyterlab/testutils": "^3.0.0",
"@types/jest": "^26.0.0",
"@types/react-syntax-highlighter": "^15.5.6",
"@typescript-eslint/eslint-plugin": "^4.8.1",
"@typescript-eslint/parser": "^4.8.1",
"eslint": "^7.14.0",
Expand Down
2 changes: 2 additions & 0 deletions packages/jupyter-ai/schema/plugin.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"$schema": "http://json-schema.org/draft-07/schema",
"title": "Generative AI",
"description": "JupyterLab generative artificial intelligence integration.",
"jupyter.lab.setting-icon": "jupyter-ai::psychology",
"jupyter.lab.setting-icon-label": "Jupyter AI Chat",
"jupyter.lab.toolbars": {
"Cell": [
{
Expand Down
52 changes: 27 additions & 25 deletions packages/jupyter-ai/src/chat_handler.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import { IDisposable } from '@lumino/disposable';
import { ServerConnection } from '@jupyterlab/services';
import { URLExt } from '@jupyterlab/coreutils';
import {Poll} from '@lumino/polling';
import { Poll } from '@lumino/polling';
import { AiService, requestAPI } from './handler';

const CHAT_SERVICE_URL = 'api/ai/chats';

const CHAT_SERVICE_URL = "api/ai/chats"

export class ChatHandler implements IDisposable{
export class ChatHandler implements IDisposable {
/**
* Create a new chat handler.
*/
Expand All @@ -31,10 +30,10 @@ export class ChatHandler implements IDisposable{
return this._isDisposed;
}

/**
/**
* Dispose the chat handler.
*/
dispose(): void {
dispose(): void {
if (this.isDisposed) {
return;
}
Expand All @@ -43,7 +42,7 @@ export class ChatHandler implements IDisposable{
// Clean up poll.
this._poll.dispose();

this._listeners = []
this._listeners = [];

// Clean up socket.
const socket = this._socket;
Expand All @@ -61,19 +60,21 @@ export class ChatHandler implements IDisposable{
this._listeners.push(handler);
}

public removeListener(handler: (message: AiService.ChatMessage) => void): void {
const index = this._listeners.indexOf(handler)
if(index > -1) {
this._listeners.splice(index, 1)
public removeListener(
handler: (message: AiService.ChatMessage) => void
): void {
const index = this._listeners.indexOf(handler);
if (index > -1) {
this._listeners.splice(index, 1);
}
}

public sendMessage(message: AiService.ChatRequest): void {
this._socket?.send(JSON.stringify(message))
this._socket?.send(JSON.stringify(message));
}

public async getHistory(): Promise<AiService.ChatHistory> {
let data: AiService.ChatHistory = {messages: []}
let data: AiService.ChatHistory = { messages: [] };
try {
data = await requestAPI('chats/history', {
method: 'GET'
Expand All @@ -90,22 +91,23 @@ export class ChatHandler implements IDisposable{

private _subscribe(): Promise<void> {
return new Promise<void>((_, reject) => {
if (this.isDisposed) {
return;
}
const { token, WebSocket, wsUrl } = this.serverSettings;
const url =
URLExt.join(wsUrl, CHAT_SERVICE_URL) +
(token ? `?token=${encodeURIComponent(token)}` : '');
const socket = (this._socket = new WebSocket(url));

socket.onclose = () => reject(new Error('ChatHandler socket closed'));
socket.onmessage = msg => msg.data && this._onMessage(JSON.parse(msg.data));
if (this.isDisposed) {
return;
}
const { token, WebSocket, wsUrl } = this.serverSettings;
const url =
URLExt.join(wsUrl, CHAT_SERVICE_URL) +
(token ? `?token=${encodeURIComponent(token)}` : '');
const socket = (this._socket = new WebSocket(url));

socket.onclose = () => reject(new Error('ChatHandler socket closed'));
socket.onmessage = msg =>
msg.data && this._onMessage(JSON.parse(msg.data));
});
}

private _isDisposed = false;
private _poll: Poll;
private _socket: WebSocket | null = null;
private _listeners: ((msg: any) => void)[] = [];
private _listeners: ((msg: any) => void)[] = [];
}
89 changes: 89 additions & 0 deletions packages/jupyter-ai/src/components/chat-code-view.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import React, { useState, useMemo } from 'react';

import type { CodeProps } from 'react-markdown/lib/ast-to-react';
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
import { duotoneLight } from 'react-syntax-highlighter/dist/esm/styles/prism';
import { Box, Button } from '@mui/material';

// Props forwarded from react-markdown's custom `code` renderer.
type ChatCodeViewProps = CodeProps;

// Inline code (single backticks) takes the renderer props unchanged.
type ChatCodeInlineProps = ChatCodeViewProps;

// Fenced code blocks additionally receive the language parsed from the
// `language-<lang>` class name, when one was given on the fence.
type ChatCodeBlockProps = ChatCodeViewProps & {
  language?: string;
};

/**
 * Renders inline markdown code as a plain <code> element, forwarding the
 * class name and any remaining renderer props.
 */
function ChatCodeInline(props: ChatCodeInlineProps) {
  const { className, children, ...rest } = props;
  return (
    <code {...rest} className={className}>
      {children}
    </code>
  );
}

/**
 * Lifecycle of the copy-to-clipboard button. String-valued: string enums
 * avoid the surprising bidirectional value mapping of numeric enums and
 * keep logged/debugged values readable. All references use the symbolic
 * members, so the value change is invisible to the rest of the file.
 */
enum CopyStatus {
  None = 'none',
  Copying = 'copying',
  Copied = 'copied'
}

// Button label shown for each stage of the copy-to-clipboard flow.
const COPYBTN_TEXT_BY_STATUS: Record<CopyStatus, string> = {
  [CopyStatus.None]: 'Copy to clipboard',
  [CopyStatus.Copying]: 'Copying...',
  [CopyStatus.Copied]: 'Copied!'
};

/**
 * Renders a fenced code block with Prism syntax highlighting and a
 * "Copy to clipboard" button whose label tracks the copy lifecycle.
 */
function ChatCodeBlock({ language, children, ...props }: ChatCodeBlockProps) {
  // Fenced blocks carry a trailing newline artifact; strip it once.
  const value = useMemo(() => String(children).replace(/\n$/, ''), [children]);
  const [copyStatus, setCopyStatus] = useState<CopyStatus>(CopyStatus.None);

  // Writes the code to the clipboard; shows "Copied!" for one second on
  // success, reverts immediately on failure.
  const copy = async () => {
    setCopyStatus(CopyStatus.Copying);
    try {
      await navigator.clipboard.writeText(value);
    } catch (e) {
      console.error(e);
      setCopyStatus(CopyStatus.None);
      return;
    }

    setCopyStatus(CopyStatus.Copied);
    // NOTE(review): this timer is not cancelled on unmount, so the state
    // update may fire on an unmounted component — consider clearing it in
    // a useEffect cleanup.
    setTimeout(() => setCopyStatus(CopyStatus.None), 1000);
  };

  return (
    <Box sx={{ display: 'flex', flexDirection: 'column' }}>
      {/* Pass the code as a JSX child rather than an explicit `children`
          attribute — the explicit attribute form is a React anti-pattern
          and triggers lint/type warnings. */}
      <SyntaxHighlighter
        {...props}
        style={duotoneLight}
        language={language}
        PreTag="div"
      >
        {value}
      </SyntaxHighlighter>
      <Button
        onClick={copy}
        disabled={copyStatus !== CopyStatus.None}
        sx={{ alignSelf: 'flex-end' }}
      >
        {COPYBTN_TEXT_BY_STATUS[copyStatus]}
      </Button>
    </Box>
  );
}

/**
 * Dispatches react-markdown code nodes to the inline or block renderer.
 *
 * Fix: the original destructured `className` out of `props`, so
 * `ChatCodeInline` — which explicitly renders `className` — always
 * received `undefined` and inline code lost its markdown-assigned
 * classes. The class name is now forwarded explicitly.
 */
export function ChatCodeView({
  inline,
  className,
  ...props
}: ChatCodeViewProps) {
  // react-markdown encodes the fence language as `language-<lang>`.
  const match = /language-(\w+)/.exec(className || '');
  return inline ? (
    <ChatCodeInline {...props} className={className} />
  ) : (
    <ChatCodeBlock {...props} language={match ? match[1] : undefined} />
  );
}
71 changes: 71 additions & 0 deletions packages/jupyter-ai/src/components/chat-input.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import React from 'react';

import {
Box,
IconButton,
Input,
SxProps,
Theme,
FormGroup,
FormControlLabel,
Checkbox
} from '@mui/material';
import SendIcon from '@mui/icons-material/Send';

type ChatInputProps = {
  // True while a chat reply is pending; disables the send button.
  loading: boolean;
  // Current text of the input (controlled component).
  value: string;
  // Called with the new text on every change.
  onChange: (newValue: string) => unknown;
  // Called when the user presses the send button.
  onSend: () => unknown;
  // True when the editor has a text selection; shows the checkboxes below.
  hasSelection: boolean;
  // Whether the current selection is sent along with the prompt.
  includeSelection: boolean;
  toggleIncludeSelection: () => unknown;
  // Whether the AI response should replace the current selection.
  replaceSelection: boolean;
  toggleReplaceSelection: () => unknown;
  // Optional MUI style overrides for the root container.
  sx?: SxProps<Theme>;
};

/**
 * Message composer for the chat panel: a multiline text field with a send
 * button, plus — when the editor has a selection — checkboxes controlling
 * whether the selection is included in the prompt or replaced by the reply.
 */
export function ChatInput(props: ChatInputProps): JSX.Element {
  const {
    loading,
    value,
    onChange,
    onSend,
    hasSelection,
    includeSelection,
    toggleIncludeSelection,
    replaceSelection,
    toggleReplaceSelection,
    sx
  } = props;

  // Sending is blocked while a reply is pending or the input is blank.
  const sendDisabled = loading || !value.trim().length;

  return (
    <Box sx={sx}>
      <Box sx={{ display: 'flex' }}>
        <Input
          value={value}
          onChange={e => onChange(e.target.value)}
          multiline
          sx={{ flexGrow: 1 }}
        />
        <IconButton
          size="large"
          color="primary"
          onClick={onSend}
          disabled={sendDisabled}
        >
          <SendIcon />
        </IconButton>
      </Box>
      {hasSelection && (
        <FormGroup sx={{ display: 'flex', flexDirection: 'row' }}>
          <FormControlLabel
            label="Include selection"
            control={
              <Checkbox
                checked={includeSelection}
                onChange={toggleIncludeSelection}
              />
            }
          />
          <FormControlLabel
            label="Replace selection"
            control={
              <Checkbox
                checked={replaceSelection}
                onChange={toggleReplaceSelection}
              />
            }
          />
        </FormGroup>
      )}
    </Box>
  );
}
Loading

0 comments on commit 35383e2

Please sign in to comment.