Skip to content

Commit

Permalink
feat: adding model params (#886)
Browse files Browse the repository at this point in the history
* feat: adding model params

Signed-off-by: James <[email protected]>

* chore: inference request parameter

* Improve ui right panel model params

* Remove unused import

* Update slider track for darkmode

---------

Signed-off-by: James <[email protected]>
Co-authored-by: James <[email protected]>
Co-authored-by: Louis <[email protected]>
Co-authored-by: Faisal Amir <[email protected]>
  • Loading branch information
4 people authored Dec 11, 2023
1 parent 5f7001d commit 121dc11
Show file tree
Hide file tree
Showing 36 changed files with 758 additions and 196 deletions.
10 changes: 6 additions & 4 deletions core/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ export type ThreadMessage = {
object: string;
/** Thread id, default is a ulid. **/
thread_id: string;
/** The role of the author of this message. **/
/** The assistant id of this thread. **/
assistant_id?: string;
// TODO: comment
/** The role of the author of this message. **/
role: ChatCompletionRole;
/** The content of this message. **/
content: ThreadContent[];
Expand Down Expand Up @@ -125,8 +125,6 @@ export interface Thread {
title: string;
/** Assistants in this thread. **/
assistants: ThreadAssistantInfo[];
// if the thread has been init will full assistant info
isFinishInit: boolean;
/** The timestamp indicating when this thread was created, represented in ISO 8601 format. **/
created: number;
/** The timestamp indicating when this thread was updated, represented in ISO 8601 format. **/
Expand Down Expand Up @@ -166,6 +164,7 @@ export type ThreadState = {
waitingForResponse: boolean;
error?: Error;
lastMessage?: string;
isFinishInit?: boolean;
};
/**
* Represents the inference engine.
Expand Down Expand Up @@ -291,6 +290,9 @@ export type ModelRuntimeParams = {
top_p?: number;
stream?: boolean;
max_tokens?: number;
stop?: string[];
frequency_penalty?: number;
presence_penalty?: number;
};

/**
Expand Down
6 changes: 1 addition & 5 deletions extensions/assistant-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ export default class JanAssistantExtension implements AssistantExtension {
onUnload(): void {}

async createAssistant(assistant: Assistant): Promise<void> {
// assuming that assistants/ directory is already created in the onLoad above

// TODO: check if the directory already exists, then ignore creation for now

const assistantDir = join(JanAssistantExtension._homeDir, assistant.id);
await fs.mkdir(assistantDir);

Expand Down Expand Up @@ -91,7 +87,7 @@ export default class JanAssistantExtension implements AssistantExtension {
avatar: "",
thread_location: undefined,
id: "jan",
object: "assistant", // TODO: maybe we can set default value for this?
object: "assistant",
created_at: Date.now(),
name: "Jan",
description: "A default assistant that can use all downloaded models",
Expand Down
46 changes: 26 additions & 20 deletions extensions/inference-nitro-extension/src/helpers/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import { Observable } from "rxjs";
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: Model,
controller?: AbortController
): Observable<string> {
Expand All @@ -23,34 +22,41 @@ export function requestInference(
headers: {
"Content-Type": "application/json",
"Access-Control-Allow-Origin": "*",
Accept: "text/event-stream",
Accept: model.parameters.stream
? "text/event-stream"
: "application/json",
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";
if (model.parameters.stream) {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";

while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
}
subscriber.next(content);
}
subscriber.next(content);
}
}
} else {
const data = await response.json();
subscriber.next(data.choices[0]?.message?.content ?? "");
}
subscriber.complete();
})
Expand Down
5 changes: 1 addition & 4 deletions extensions/inference-nitro-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
*/
onUnload(): void {}


private async writeDefaultEngineSettings() {
try {
const engineFile = join(
Expand Down Expand Up @@ -164,7 +163,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
return new Promise(async (resolve, reject) => {
requestInference(
data.messages ?? [],
JanInferenceNitroExtension._engineSettings,
JanInferenceNitroExtension._currentModel
).subscribe({
next: (_content) => {},
Expand Down Expand Up @@ -210,8 +208,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {

requestInference(
data.messages ?? [],
JanInferenceNitroExtension._engineSettings,
JanInferenceNitroExtension._currentModel,
{ ...JanInferenceNitroExtension._currentModel, ...data.model },
instance.controller
).subscribe({
next: (content) => {
Expand Down
51 changes: 29 additions & 22 deletions extensions/inference-openai-extension/src/helpers/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ export function requestInference(
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
let model_id: string = model.id
if (engine.full_url.includes("openai.azure.com")){
model_id = engine.full_url.split("/")[5]
let model_id: string = model.id;
if (engine.full_url.includes("openai.azure.com")) {
model_id = engine.full_url.split("/")[5];
}
const requestBody = JSON.stringify({
messages: recentMessages,
Expand All @@ -29,7 +29,9 @@ export function requestInference(
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
Accept: model.parameters.stream
? "text/event-stream"
: "application/json",
"Access-Control-Allow-Origin": "*",
Authorization: `Bearer ${engine.api_key}`,
"api-key": `${engine.api_key}`,
Expand All @@ -38,28 +40,33 @@ export function requestInference(
signal: controller?.signal,
})
.then(async (response) => {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";
if (model.parameters.stream) {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";

while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
}
subscriber.next(content);
}
subscriber.next(content);
}
}
} else {
const data = await response.json();
subscriber.next(data.choices[0]?.message?.content ?? "");
}
subscriber.complete();
})
Expand Down
1 change: 1 addition & 0 deletions uikit/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"@radix-ui/react-progress": "^1.0.3",
"@radix-ui/react-scroll-area": "^1.0.5",
"@radix-ui/react-select": "^2.0.0",
"@radix-ui/react-slider": "^1.1.2",
"@radix-ui/react-slot": "^1.0.2",
"@radix-ui/react-switch": "^1.0.3",
"@radix-ui/react-toast": "^1.1.5",
Expand Down
1 change: 1 addition & 0 deletions uikit/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ export * from './modal'
export * from './command'
export * from './textarea'
export * from './select'
export * from './slider'
2 changes: 1 addition & 1 deletion uikit/src/input/styles.scss
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
.input {
@apply border-border placeholder:text-muted-foreground flex h-9 w-full rounded-md border bg-transparent px-3 py-1 transition-colors;
@apply border-border placeholder:text-muted-foreground flex h-9 w-full rounded-lg border bg-transparent px-3 py-1 transition-colors;
@apply disabled:cursor-not-allowed disabled:opacity-50;
@apply focus-visible:ring-secondary focus-visible:outline-none focus-visible:ring-1;
@apply file:border-0 file:bg-transparent file:font-medium;
Expand Down
1 change: 1 addition & 0 deletions uikit/src/main.scss
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
@import './command/styles.scss';
@import './textarea/styles.scss';
@import './select/styles.scss';
@import './slider/styles.scss';

.animate-spin {
animation: spin 1s linear infinite;
Expand Down
25 changes: 25 additions & 0 deletions uikit/src/slider/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
'use client'

import * as React from 'react'
import * as SliderPrimitive from '@radix-ui/react-slider'

import { twMerge } from 'tailwind-merge'

const Slider = React.forwardRef<
React.ElementRef<typeof SliderPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root>
>(({ className, ...props }, ref) => (
<SliderPrimitive.Root
ref={ref}
className={twMerge('slider', className)}
{...props}
>
<SliderPrimitive.Track className="slider-track">
<SliderPrimitive.Range className="slider-range" />
</SliderPrimitive.Track>
<SliderPrimitive.Thumb className="slider-thumb" />
</SliderPrimitive.Root>
))
Slider.displayName = SliderPrimitive.Root.displayName

export { Slider }
15 changes: 15 additions & 0 deletions uikit/src/slider/styles.scss
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
.slider {
@apply relative flex w-full touch-none select-none items-center;

&-track {
@apply relative h-1.5 w-full grow overflow-hidden rounded-full bg-gray-200 dark:bg-gray-800;
}

&-range {
@apply absolute h-full bg-blue-600;
}

&-thumb {
@apply border-primary/50 bg-background focus-visible:ring-ring block h-4 w-4 rounded-full border shadow transition-colors focus-visible:outline-none focus-visible:ring-1 disabled:pointer-events-none disabled:opacity-50;
}
}
4 changes: 2 additions & 2 deletions web/containers/CardSidebar/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export default function CardSidebar({
return (
<div
className={twMerge(
'flex w-full flex-col rounded-md border border-border bg-zinc-200 dark:bg-zinc-600/10'
'flex w-full flex-col overflow-hidden rounded-md border border-border bg-zinc-200 dark:bg-zinc-600/10'
)}
>
<div
Expand All @@ -43,7 +43,7 @@ export default function CardSidebar({
>
<button
onClick={() => setShow(!show)}
className="flex w-full flex-1 items-center space-x-2 px-3 py-2"
className="flex w-full flex-1 items-center space-x-2 bg-zinc-200 px-3 py-2 dark:bg-zinc-600/10"
>
<ChevronDownIcon
className={twMerge(
Expand Down
62 changes: 62 additions & 0 deletions web/containers/Checkbox/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { useEffect, useState } from 'react'

import { Switch } from '@janhq/uikit'

import { useAtomValue } from 'jotai'

import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'

import {
getActiveThreadIdAtom,
getActiveThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'

type Props = {
name: string
title: string
checked: boolean
register: any
}

const Checkbox: React.FC<Props> = ({ name, title, checked, register }) => {
const [currentChecked, setCurrentChecked] = useState<boolean>(checked)
const { updateModelParameter } = useUpdateModelParameters()
const threadId = useAtomValue(getActiveThreadIdAtom)
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)

useEffect(() => {
setCurrentChecked(checked)
}, [checked])

useEffect(() => {
updateSetting()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentChecked])

const updateValue = [name].reduce((accumulator, value) => {
return { ...accumulator, [value]: currentChecked }
}, {})

const updateSetting = () => {
return updateModelParameter(String(threadId), {
...activeModelParams,
...updateValue,
})
}

return (
<div className="flex justify-between">
<label>{title}</label>
<Switch
checked={currentChecked}
{...register(name)}
onCheckedChange={(e) => {
setCurrentChecked(e)
}}
/>
</div>
)
}

export default Checkbox
Loading

0 comments on commit 121dc11

Please sign in to comment.