Skip to content

Commit

Permalink
feat: adding model params
Browse files Browse the repository at this point in the history
Signed-off-by: James <[email protected]>
  • Loading branch information
James committed Dec 8, 2023
1 parent 3bfe32a commit 14fd13d
Show file tree
Hide file tree
Showing 24 changed files with 493 additions and 140 deletions.
10 changes: 6 additions & 4 deletions core/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ export type ThreadMessage = {
object: string;
/** Thread id, default is a ulid. **/
thread_id: string;
/** The role of the author of this message. **/
/** The assistant id of this thread. **/
assistant_id?: string;
// TODO: comment
/** The role of the author of this message. **/
role: ChatCompletionRole;
/** The content of this message. **/
content: ThreadContent[];
Expand Down Expand Up @@ -125,8 +125,6 @@ export interface Thread {
title: string;
/** Assistants in this thread. **/
assistants: ThreadAssistantInfo[];
// if the thread has been init will full assistant info
isFinishInit: boolean;
/** The timestamp indicating when this thread was created, represented in ISO 8601 format. **/
created: number;
/** The timestamp indicating when this thread was updated, represented in ISO 8601 format. **/
Expand Down Expand Up @@ -165,6 +163,7 @@ export type ThreadState = {
waitingForResponse: boolean;
error?: Error;
lastMessage?: string;
isFinishInit?: boolean;
};

/**
Expand Down Expand Up @@ -275,6 +274,9 @@ export type ModelRuntimeParam = {
top_p?: number;
stream?: boolean;
max_tokens?: number;
stop?: string[];
frequency_penalty?: number;
presence_penalty?: number;
};

/**
Expand Down
6 changes: 1 addition & 5 deletions extensions/assistant-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ export default class JanAssistantExtension implements AssistantExtension {
onUnload(): void {}

async createAssistant(assistant: Assistant): Promise<void> {
// assuming that assistants/ directory is already created in the onLoad above

// TODO: check if the directory already exists, then ignore creation for now

const assistantDir = join(JanAssistantExtension._homeDir, assistant.id);
await fs.mkdir(assistantDir);

Expand Down Expand Up @@ -91,7 +87,7 @@ export default class JanAssistantExtension implements AssistantExtension {
avatar: "",
thread_location: undefined,
id: "jan",
object: "assistant", // TODO: maybe we can set default value for this?
object: "assistant",
created_at: Date.now(),
name: "Jan",
description: "A default assistant that can use all downloaded models",
Expand Down
15 changes: 15 additions & 0 deletions web/containers/Checkbox/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
type Props = {
name: string
title: string
checked: boolean
register: any
}

const Checkbox: React.FC<Props> = ({ name, title, checked, register }) => (
<div className="flex justify-between">
<label>{title}</label>
<input type="checkbox" defaultChecked={checked} {...register(name)} />
</div>
)

export default Checkbox
52 changes: 30 additions & 22 deletions web/containers/DropdownListSidebar/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import { useMainViewState } from '@/hooks/useMainViewState'

import { toGigabytes } from '@/utils/converter'

import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'

export const selectedModelAtom = atom<Model | undefined>(undefined)

Expand All @@ -36,6 +36,7 @@ export default function DropdownListSidebar() {
const activeThread = useAtomValue(activeThreadAtom)
const [selected, setSelected] = useState<Model | undefined>()
const { setMainViewState } = useMainViewState()

const { activeModel, stateModel } = useActiveModel()

useEffect(() => {
Expand All @@ -61,13 +62,22 @@ export default function DropdownListSidebar() {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [activeThread, activeModel, stateModel.loading])

const threadStates = useAtomValue(threadStatesAtom)
if (!activeThread) {
return null
}
const finishInit = threadStates[activeThread.id].isFinishInit ?? true

const onValueSelected = (value: string) => {
setSelected(downloadedModels.filter((x) => x.id === value)[0])
setSelectedModel(downloadedModels.filter((x) => x.id === value)[0])
}

return (
<Select
disabled={finishInit}
value={selected?.id}
onValueChange={(value) => {
setSelected(downloadedModels.filter((x) => x.id === value)[0])
setSelectedModel(downloadedModels.filter((x) => x.id === value)[0])
}}
onValueChange={finishInit ? undefined : onValueSelected}
>
<SelectTrigger className="w-full">
<SelectValue placeholder="Choose model to start">
Expand All @@ -82,26 +92,24 @@ export default function DropdownListSidebar() {
<div className="border-b border-border" />
{downloadedModels.length === 0 ? (
<div className="px-4 py-2">
<p>{`Oops, you don't have a model yet.`}</p>
<p>Oops, you don't have a model yet.</p>
</div>
) : (
<SelectGroup>
{downloadedModels.map((x, i) => {
return (
<SelectItem
key={i}
value={x.id}
className={twMerge(x.id === selected?.id && 'bg-secondary')}
>
<div className="flex w-full justify-between">
<span className="line-clamp-1 block">{x.name}</span>
<span className="font-bold text-muted-foreground">
{toGigabytes(x.metadata.size)}
</span>
</div>
</SelectItem>
)
})}
{downloadedModels.map((x, i) => (
<SelectItem
key={i}
value={x.id}
className={twMerge(x.id === selected?.id && 'bg-secondary')}
>
<div className="flex w-full justify-between">
<span className="line-clamp-1 block">{x.name}</span>
<span className="font-bold text-muted-foreground">
{toGigabytes(x.metadata.size)}
</span>
</div>
</SelectItem>
))}
</SelectGroup>
)}
<div className="border-b border-border" />
Expand Down
2 changes: 1 addition & 1 deletion web/containers/Layout/TopBar/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { useMainViewState } from '@/hooks/useMainViewState'

import { showRightSideBarAtom } from '@/screens/Chat/Sidebar'

import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'

const TopBar = () => {
const activeThread = useAtomValue(activeThreadAtom)
Expand Down
2 changes: 1 addition & 1 deletion web/containers/Providers/EventHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
import {
updateThreadWaitingForResponseAtom,
threadsAtom,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'

export default function EventHandler({ children }: { children: ReactNode }) {
const addNewMessage = useSetAtom(addNewMessageAtom)
Expand Down
33 changes: 33 additions & 0 deletions web/containers/Slider/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
type Props = {
name: string
title: string
min: number
max: number
step: number
value: number
register: any
}

const Slider: React.FC<Props> = ({
name,
title,
min,
max,
step,
value,
register,
}) => (
<div className="flex flex-col">
<p>{title}</p>
<input
{...register(name)}
defaultValue={value}
type="range"
min={min}
max={max}
step={step}
/>
</div>
)

export default Slider
19 changes: 11 additions & 8 deletions web/helpers/atoms/ChatMessage.atom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { atom } from 'jotai'
import {
getActiveThreadIdAtom,
updateThreadStateLastMessageAtom,
} from './Conversation.atom'
} from './Thread.atom'

/**
* Stores all chat messages for all threads
Expand Down Expand Up @@ -76,15 +76,18 @@ export const addNewMessageAtom = atom(
}
)

export const deleteConversationMessage = atom(null, (get, set, id: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
export const deleteChatMessageAtom = atom(
null,
(get, set, threadId: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
}
newData[threadId] = []
set(chatMessages, newData)
}
newData[id] = []
set(chatMessages, newData)
})
)

export const cleanConversationMessages = atom(null, (get, set, id: string) => {
export const cleanChatMessageAtom = atom(null, (get, set, id: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import { Thread, ThreadContent, ThreadState } from '@janhq/core'
import {
ModelRuntimeParam,
Thread,
ThreadContent,
ThreadState,
} from '@janhq/core'
import { atom } from 'jotai'

/**
* Stores the current active conversation id.
* Stores the current active thread id.
*/
const activeThreadIdAtom = atom<string | undefined>(undefined)

export const getActiveThreadIdAtom = atom((get) => get(activeThreadIdAtom))

export const setActiveThreadIdAtom = atom(
null,
(_get, set, convoId: string | undefined) => set(activeThreadIdAtom, convoId)
(_get, set, threadId: string | undefined) => set(activeThreadIdAtom, threadId)
)

export const waitingToSendMessage = atom<boolean | undefined>(undefined)
Expand All @@ -20,47 +25,48 @@ export const waitingToSendMessage = atom<boolean | undefined>(undefined)
*/
export const threadStatesAtom = atom<Record<string, ThreadState>>({})
export const activeThreadStateAtom = atom<ThreadState | undefined>((get) => {
const activeConvoId = get(activeThreadIdAtom)
if (!activeConvoId) {
console.debug('Active convo id is undefined')
const threadId = get(activeThreadIdAtom)
if (!threadId) {
console.debug('Active thread id is undefined')
return undefined
}

return get(threadStatesAtom)[activeConvoId]
return get(threadStatesAtom)[threadId]
})

export const updateThreadWaitingForResponseAtom = atom(
export const deleteThreadStateAtom = atom(
null,
(get, set, conversationId: string, waitingForResponse: boolean) => {
(get, set, threadId: string) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = {
...currentState[conversationId],
waitingForResponse,
error: undefined,
}
delete currentState[threadId]
set(threadStatesAtom, currentState)
}
)
export const updateConversationErrorAtom = atom(

export const updateThreadInitSuccessAtom = atom(
null,
(get, set, conversationId: string, error?: Error) => {
(get, set, threadId: string) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = {
...currentState[conversationId],
error,
currentState[threadId] = {
...currentState[threadId],
isFinishInit: true,
}
set(threadStatesAtom, currentState)
}
)
export const updateConversationHasMoreAtom = atom(

export const updateThreadWaitingForResponseAtom = atom(
null,
(get, set, conversationId: string, hasMore: boolean) => {
(get, set, threadId: string, waitingForResponse: boolean) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = { ...currentState[conversationId], hasMore }
currentState[threadId] = {
...currentState[threadId],
waitingForResponse,
error: undefined,
}
set(threadStatesAtom, currentState)
}
)

export const updateThreadStateLastMessageAtom = atom(
null,
(get, set, threadId: string, lastContent?: ThreadContent[]) => {
Expand Down Expand Up @@ -100,3 +106,42 @@ export const threadsAtom = atom<Thread[]>([])
export const activeThreadAtom = atom<Thread | undefined>((get) =>
get(threadsAtom).find((c) => c.id === get(getActiveThreadIdAtom))
)

/**
* Store model params at thread level settings
*/
export const threadModelRuntimeParamsAtom = atom<
Record<string, ModelRuntimeParam>
>({})

export const getActiveThreadModelRuntimeParamsAtom = atom<
ModelRuntimeParam | undefined
>((get) => {
const threadId = get(activeThreadIdAtom)
if (!threadId) {
console.debug('Active thread id is undefined')
return undefined
}

return get(threadModelRuntimeParamsAtom)[threadId]
})

export const getThreadModelRuntimeParamsAtom = atom(
(get, threadId: string) => get(threadModelRuntimeParamsAtom)[threadId]
)

export const setThreadModelRuntimeParamsAtom = atom(
null,
(get, set, threadId: string, params: ModelRuntimeParam) => {
const currentState = { ...get(threadModelRuntimeParamsAtom) }
currentState[threadId] = params
console.debug(
`Update model params for thread ${threadId}, ${JSON.stringify(
params,
null,
2
)}`
)
set(threadModelRuntimeParamsAtom, currentState)
}
)
Loading

0 comments on commit 14fd13d

Please sign in to comment.