Hacky support for o1 beta
abrenneke committed Sep 13, 2024
1 parent 03e7a1a commit 1adb117
Showing 2 changed files with 119 additions and 2 deletions.
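
In short: the o1 beta models reject the temperature parameter and take max_completion_tokens in place of max_tokens, and the beta endpoint does not support streaming, so requests for those models are routed through a new blocking chatCompletions call. A minimal sketch of the parameter gating this commit adds (buildOptions is a hypothetical helper, not part of the diff):

// Hypothetical condensed view of the gating this commit adds.
function buildOptions(model: string, maxTokens: number, temperature?: number) {
  const isO1Beta = model.startsWith('o1-preview') || model.startsWith('o1-mini');
  return isO1Beta
    ? { model, max_completion_tokens: maxTokens } // o1 beta: no temperature, renamed token cap
    : { model, max_tokens: maxTokens, temperature };
}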
60 changes: 58 additions & 2 deletions packages/core/src/model/nodes/ChatNode.ts
@@ -16,6 +16,7 @@ import {
   openaiModels,
   streamChatCompletions,
   type ChatCompletionTool,
+  chatCompletions,
 } from '../../utils/openai.js';
 import retry from 'p-retry';
 import type { Inputs, Outputs } from '../GraphProcessor.js';
@@ -892,9 +893,7 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {
     const options: Omit<ChatCompletionOptions, 'auth' | 'signal'> = {
       messages: completionMessages,
       model: finalModel,
-      temperature: useTopP ? undefined : temperature,
       top_p: useTopP ? topP : undefined,
-      max_tokens: maxTokens,
       n: numberOfChoices,
       frequency_penalty: frequencyPenalty,
       presence_penalty: presencePenalty,
@@ -907,6 +906,15 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {
       ...additionalParameters,
     };
 
+    const isO1Beta = finalModel.startsWith('o1-preview') || finalModel.startsWith('o1-mini');
+
+    if (isO1Beta) {
+      options.max_completion_tokens = maxTokens;
+    } else {
+      options.temperature = useTopP ? undefined : temperature; // temperature is not supported by the o1 beta models
+      options.max_tokens = maxTokens;
+    }
+
     const cacheKey = JSON.stringify(options);
 
     if (this.data.cache) {
@@ -918,6 +926,54 @@
 
     const startTime = Date.now();
 
+    if (isO1Beta) {
+      const response = await chatCompletions({
+        auth: {
+          apiKey: context.settings.openAiKey ?? '',
+          organization: context.settings.openAiOrganization,
+        },
+        headers: allAdditionalHeaders,
+        signal: context.signal,
+        timeout: context.settings.chatNodeTimeout,
+        ...options,
+      });
+
+      if (isMultiResponse) {
+        output['response' as PortId] = {
+          type: 'string[]',
+          value: response.choices.map((c) => c.message.content!),
+        };
+      } else {
+        output['response' as PortId] = {
+          type: 'string',
+          value: response.choices[0]!.message.content!,
+        };
+      }
+
+      if (!isMultiResponse) {
+        output['all-messages' as PortId] = {
+          type: 'chat-message[]',
+          value: [
+            ...messages,
+            {
+              type: 'assistant',
+              message: response.choices[0]!.message.content!,
+              function_calls: undefined,
+              isCacheBreakpoint: false,
+              function_call: undefined,
+            },
+          ],
+        };
+      }
+
+      output['duration' as PortId] = { type: 'number', value: Date.now() - startTime };
+
+      Object.freeze(output);
+      cache.set(cacheKey, output);
+
+      return output;
+    }
+
     const chunks = streamChatCompletions({
       auth: {
         apiKey: context.settings.openAiKey ?? '',
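
The new branch above consumes the non-streaming response directly. For reference, a sketch of the subset of the chat completions response it reads (based on the OpenAI API response shape, not code from this diff); the non-null assertions on message.content assume a plain text completion, where content is populated:

// Sketch: the response fields the o1 branch reads.
type ChatCompletionResponseSubset = {
  choices: {
    message: { role: 'assistant'; content: string | null };
  }[];
};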
61 changes: 61 additions & 0 deletions packages/core/src/utils/openai.ts
@@ -158,6 +158,38 @@ export const openaiModels = {
     },
     displayName: 'GPT-4o mini (2024-07-18)',
   },
+  'o1-preview': {
+    maxTokens: 128000,
+    cost: {
+      prompt: 0.0015,
+      completion: 0.006,
+    },
+    displayName: 'o1-preview',
+  },
+  'o1-preview-2024-09-12': {
+    maxTokens: 128000,
+    cost: {
+      prompt: 0.0015,
+      completion: 0.006,
+    },
+    displayName: 'o1-preview (2024-09-12)',
+  },
+  'o1-mini': {
+    maxTokens: 128000,
+    cost: {
+      prompt: 0.0003,
+      completion: 0.0012,
+    },
+    displayName: 'o1-mini',
+  },
+  'o1-mini-2024-09-12': {
+    maxTokens: 128000,
+    cost: {
+      prompt: 0.0003,
+      completion: 0.0012,
+    },
+    displayName: 'o1-mini (2024-09-12)',
+  },
   'local-model': {
     maxTokens: Number.MAX_SAFE_INTEGER,
     cost: {
@@ -258,6 +290,10 @@ export type ChatCompletionOptions = {
   temperature?: number;
   top_p?: number;
   max_tokens?: number;
 
+  /** Used in place of max_tokens for the o1 series of models; all other models use max_tokens. */
+  max_completion_tokens?: number;
+
   n?: number;
   stop?: string | string[];
   presence_penalty?: number;
@@ -414,6 +450,31 @@ export type ChatCompletionFunction = {
   strict: boolean;
 };
 
+export async function chatCompletions({
+  endpoint,
+  auth,
+  signal,
+  headers,
+  timeout, // transport options are destructured out so they are not serialized into the request body; timeout is currently unused here
+  ...rest
+}: ChatCompletionOptions): Promise<ChatCompletionResponse> {
+  const abortSignal = signal ?? new AbortController().signal;
+
+  const response = await fetch(endpoint ?? 'https://api.openai.com/v1/chat/completions', { // fall back to the public endpoint; the ChatNode call site does not pass one
+    method: 'POST',
+    headers: {
+      'Content-Type': 'application/json',
+      Authorization: `Bearer ${auth.apiKey}`,
+      ...(auth.organization ? { 'OpenAI-Organization': auth.organization } : {}),
+      ...headers,
+    },
+    body: JSON.stringify(rest),
+    signal: abortSignal,
+  });
+
+  return response.json();
+}
+
 export async function* streamChatCompletions({
   endpoint,
   auth,
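
A minimal usage sketch for the new helper (the key, model, and message values are placeholders; the exact message typing is defined elsewhere in openai.ts):

// Sketch: calling chatCompletions directly for an o1 beta model.
const response = await chatCompletions({
  auth: { apiKey: process.env.OPENAI_API_KEY ?? '' },
  model: 'o1-mini',
  messages: [{ role: 'user', content: 'Say hello.' }],
  max_completion_tokens: 256, // the o1 beta takes this in place of max_tokens
});

console.log(response.choices[0]?.message.content);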
