diff --git a/packages/core/src/model/nodes/ChatNode.ts b/packages/core/src/model/nodes/ChatNode.ts
index 72d25f67..56a2529c 100644
--- a/packages/core/src/model/nodes/ChatNode.ts
+++ b/packages/core/src/model/nodes/ChatNode.ts
@@ -16,6 +16,7 @@ import {
   openaiModels,
   streamChatCompletions,
   type ChatCompletionTool,
+  chatCompletions,
 } from '../../utils/openai.js';
 import retry from 'p-retry';
 import type { Inputs, Outputs } from '../GraphProcessor.js';
@@ -892,9 +893,7 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {
     const options: Omit<ChatCompletionOptions, 'auth' | 'signal'> = {
       messages: completionMessages,
       model: finalModel,
-      temperature: useTopP ? undefined : temperature,
       top_p: useTopP ? topP : undefined,
-      max_tokens: maxTokens,
       n: numberOfChoices,
       frequency_penalty: frequencyPenalty,
       presence_penalty: presencePenalty,
@@ -907,6 +906,15 @@
       ...additionalParameters,
     };
 
+    const isO1Beta = finalModel.startsWith('o1-preview') || finalModel.startsWith('o1-mini');
+
+    if (isO1Beta) {
+      options.max_completion_tokens = maxTokens;
+    } else {
+      options.temperature = useTopP ? undefined : temperature; // not supported by the o1 beta models
+      options.max_tokens = maxTokens;
+    }
+
     const cacheKey = JSON.stringify(options);
 
     if (this.data.cache) {
@@ -918,6 +926,54 @@
 
     const startTime = Date.now();
 
+    if (isO1Beta) {
+      const response = await chatCompletions({
+        auth: {
+          apiKey: context.settings.openAiKey ?? '',
+          organization: context.settings.openAiOrganization,
+        },
+        headers: allAdditionalHeaders,
+        signal: context.signal,
+        timeout: context.settings.chatNodeTimeout,
+        ...options,
+      });
+
+      if (isMultiResponse) {
+        output['response' as PortId] = {
+          type: 'string[]',
+          value: response.choices.map((c) => c.message.content!),
+        };
+      } else {
+        output['response' as PortId] = {
+          type: 'string',
+          value: response.choices[0]!.message.content!,
+        };
+      }
+
+      if (!isMultiResponse) {
+        output['all-messages' as PortId] = {
+          type: 'chat-message[]',
+          value: [
+            ...messages,
+            {
+              type: 'assistant',
+              message: response.choices[0]!.message.content!,
+              function_calls: undefined,
+              isCacheBreakpoint: false,
+              function_call: undefined,
+            },
+          ],
+        };
+      }
+
+      output['duration' as PortId] = { type: 'number', value: Date.now() - startTime };
+
+      Object.freeze(output);
+      cache.set(cacheKey, output);
+
+      return output;
+    }
+
     const chunks = streamChatCompletions({
       auth: {
         apiKey: context.settings.openAiKey ?? '',
diff --git a/packages/core/src/utils/openai.ts b/packages/core/src/utils/openai.ts
index 1487152d..e5f1699b 100644
--- a/packages/core/src/utils/openai.ts
+++ b/packages/core/src/utils/openai.ts
@@ -158,6 +158,38 @@ export const openaiModels = {
     },
     displayName: 'GPT-4o mini (2024-07-18)',
   },
+  'o1-preview': {
+    maxTokens: 128000,
+    cost: {
+      prompt: 0.0015,
+      completion: 0.006,
+    },
+    displayName: 'o1-preview',
+  },
+  'o1-preview-2024-09-12': {
+    maxTokens: 128000,
+    cost: {
+      prompt: 0.0015,
+      completion: 0.006,
+    },
+    displayName: 'o1-preview (2024-09-12)',
+  },
+  'o1-mini': {
+    maxTokens: 128000,
+    cost: {
+      prompt: 0.0003,
+      completion: 0.0012,
+    },
+    displayName: 'o1-mini',
+  },
+  'o1-mini-2024-09-12': {
+    maxTokens: 128000,
+    cost: {
+      prompt: 0.0003,
+      completion: 0.0012,
+    },
+    displayName: 'o1-mini (2024-09-12)',
+  },
   'local-model': {
     maxTokens: Number.MAX_SAFE_INTEGER,
     cost: {
@@ -258,6 +290,10 @@ export type ChatCompletionOptions = {
   temperature?: number;
   top_p?: number;
   max_tokens?: number;
+
+  /** Only for the o1 series of models; use max_tokens otherwise. */
+  max_completion_tokens?: number;
+
   n?: number;
   stop?: string | string[];
   presence_penalty?: number;
@@ -414,6 +450,31 @@ export type ChatCompletionFunction = {
   strict: boolean;
 };
 
+export async function chatCompletions({
+  endpoint,
+  auth,
+  signal,
+  headers,
+  timeout,
+  ...rest
+}: ChatCompletionOptions): Promise<ChatCompletionResponse> {
+  const abortSignal = signal ?? new AbortController().signal;
+
+  const response = await fetch(endpoint, {
+    method: 'POST',
+    headers: {
+      'Content-Type': 'application/json',
+      Authorization: `Bearer ${auth.apiKey}`,
+      ...(auth.organization ? { 'OpenAI-Organization': auth.organization } : {}),
+      ...headers,
+    },
+    body: JSON.stringify(rest),
+    signal: abortSignal,
+  });
+
+  return response.json();
+}
+
 export async function* streamChatCompletions({
   endpoint,
   auth,
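
For reviewers, a minimal sketch of how the new non-streaming `chatCompletions` helper and the o1 parameter branching fit together when called directly. The import path, endpoint URL, `OPENAI_API_KEY` environment variable, and message shape are illustrative assumptions, not values taken from this diff:

```ts
import { chatCompletions } from './packages/core/src/utils/openai.js';

async function main() {
  const model = 'o1-mini';

  // Same check the ChatNode now performs: the o1 beta models reject
  // `temperature` and `max_tokens`, accepting only `max_completion_tokens`.
  const isO1Beta = model.startsWith('o1-preview') || model.startsWith('o1-mini');

  const response = await chatCompletions({
    endpoint: 'https://api.openai.com/v1/chat/completions', // assumed endpoint
    auth: { apiKey: process.env.OPENAI_API_KEY ?? '' },
    model,
    messages: [{ role: 'user', content: 'Say hello.' }], // assumed message shape
    ...(isO1Beta
      ? { max_completion_tokens: 256 }
      : { max_tokens: 256, temperature: 0.7 }),
  });

  console.log(response.choices[0]?.message.content);
}

main().catch(console.error);
```

Because `chatCompletions` resolves a single JSON body rather than yielding chunks like `streamChatCompletions`, the new o1 branch in `ChatNode` can populate `response`, `all-messages`, and `duration` in one shot and return early, while reusing the same cache key and output-freezing behavior as the streaming path.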
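
Likewise, a hypothetical helper showing how the new `openaiModels` entries might be used for cost estimation. It assumes the `cost` fields are USD per 1,000 tokens, matching how the neighboring entries appear to be scaled; `estimateCost` is not a function from this codebase:

```ts
import { openaiModels } from './packages/core/src/utils/openai.js';

// Hypothetical: price a call from token counts using the model table above,
// assuming cost.prompt / cost.completion are USD per 1,000 tokens.
function estimateCost(
  model: keyof typeof openaiModels,
  promptTokens: number,
  completionTokens: number,
): number {
  const { cost } = openaiModels[model];
  return (promptTokens / 1000) * cost.prompt + (completionTokens / 1000) * cost.completion;
}

console.log(estimateCost('o1-mini', 1200, 400)); // 1200 prompt + 400 completion tokens
```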