Merge pull request #824 from narengogi/chore/support-developer-role
support new openai developer role
VisargD authored Dec 26, 2024
2 parents 56d0e65 + 71c1d23 commit 48d8b92
Showing 28 changed files with 188 additions and 23 deletions.
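The change follows two patterns across the diff. Providers that pull a leading system prompt out of the messages array (AI21, Anthropic, Bedrock, Google, Vertex AI) now test the role against a shared SYSTEM_MESSAGE_ROLES constant exported from src/types/requestBody.ts instead of comparing against 'system' directly. A minimal standalone sketch of that check, with the constant's value assumed from how it is used below (the actual export is not shown in this diff):

// Sketch only; the types and the constant's value are assumptions inferred from the diff.
type Role = 'system' | 'developer' | 'user' | 'assistant' | 'tool';

interface Message {
  role: Role;
  content: string;
}

const SYSTEM_MESSAGE_ROLES: Role[] = ['system', 'developer'];

// Replaces checks like `msg.role === 'system'` so that OpenAI's new
// `developer` role is treated as a system-level instruction as well.
function isSystemLike(msg: Message): boolean {
  return SYSTEM_MESSAGE_ROLES.includes(msg.role);
}

console.log(isSystemLike({ role: 'developer', content: 'Answer briefly.' })); // true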
12 changes: 9 additions & 3 deletions src/providers/ai21/chatComplete.ts
@@ -1,5 +1,5 @@
import { AI21 } from '../../globals';
import { Params } from '../../types/requestBody';
import { Params, SYSTEM_MESSAGE_ROLES } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -19,7 +19,10 @@ export const AI21ChatCompleteConfig: ProviderConfig = {
transform: (params: Params) => {
let inputMessages: any = [];

if (params.messages?.[0]?.role === 'system') {
if (
params.messages?.[0]?.role &&
SYSTEM_MESSAGE_ROLES.includes(params.messages?.[0]?.role)
) {
inputMessages = params.messages.slice(1);
} else if (params.messages) {
inputMessages = params.messages;
@@ -35,7 +38,10 @@ export const AI21ChatCompleteConfig: ProviderConfig = {
param: 'system',
required: false,
transform: (params: Params) => {
if (params.messages?.[0].role === 'system') {
if (
params.messages?.[0]?.role &&
SYSTEM_MESSAGE_ROLES.includes(params.messages?.[0]?.role)
) {
return params.messages?.[0].content;
}
},
7 changes: 4 additions & 3 deletions src/providers/anthropic/chatComplete.ts
@@ -4,6 +4,7 @@ import {
Message,
ContentType,
AnthropicPromptCache,
SYSTEM_MESSAGE_ROLES,
} from '../../types/requestBody';
import {
ChatCompletionResponse,
@@ -111,7 +112,7 @@ export const AnthropicChatCompleteConfig: ProviderConfig = {
// Transform the chat messages into a simple prompt
if (!!params.messages) {
params.messages.forEach((msg: Message & AnthropicPromptCache) => {
if (msg.role === 'system') return;
if (SYSTEM_MESSAGE_ROLES.includes(msg.role)) return;

if (msg.role === 'assistant') {
messages.push(transformAssistantMessage(msg));
@@ -188,7 +189,7 @@ export const AnthropicChatCompleteConfig: ProviderConfig = {
if (!!params.messages) {
params.messages.forEach((msg: Message & AnthropicPromptCache) => {
if (
msg.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(msg.role) &&
msg.content &&
typeof msg.content === 'object' &&
msg.content[0].text
@@ -203,7 +204,7 @@ export const AnthropicChatCompleteConfig: ProviderConfig = {
});
});
} else if (
msg.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(msg.role) &&
typeof msg.content === 'string'
) {
systemMessages.push({
7 changes: 7 additions & 0 deletions src/providers/anyscale/chatComplete.ts
@@ -1,4 +1,5 @@
import { ANYSCALE } from '../../globals';
import { Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -20,6 +21,12 @@ export const AnyscaleChatCompleteConfig: ProviderConfig = {
messages: {
param: 'messages',
default: '',
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
functions: {
param: 'functions',
7 changes: 7 additions & 0 deletions src/providers/azure-ai-inference/chatComplete.ts
@@ -1,3 +1,4 @@
import { Params } from '../../types/requestBody';
import { OpenAIErrorResponseTransform } from '../openai/utils';
import {
ChatCompletionResponse,
@@ -14,6 +15,12 @@ export const AzureAIInferenceChatCompleteConfig: ProviderConfig = {
messages: {
param: 'messages',
default: '',
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
max_tokens: {
param: 'max_tokens',
15 changes: 10 additions & 5 deletions src/providers/bedrock/chatComplete.ts
@@ -1,5 +1,10 @@
import { BEDROCK, documentMimeTypes, imagesMimeTypes } from '../../globals';
import { Message, Params, ToolCall } from '../../types/requestBody';
import {
Message,
Params,
ToolCall,
SYSTEM_MESSAGE_ROLES,
} from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -150,7 +155,7 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = {
transform: (params: BedrockChatCompletionsParams) => {
if (!params.messages) return [];
const transformedMessages = params.messages
.filter((msg) => msg.role !== 'system')
.filter((msg) => !SYSTEM_MESSAGE_ROLES.includes(msg.role))
.map((msg) => {
return {
role: msg.role === 'assistant' ? 'assistant' : 'user',
@@ -183,7 +188,7 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = {
if (!params.messages) return;
const systemMessages = params.messages.reduce(
(acc: { text: string }[], msg) => {
if (msg.role === 'system')
if (SYSTEM_MESSAGE_ROLES.includes(msg.role))
return acc.concat(...getMessageTextContentArray(msg));
return acc;
},
@@ -603,7 +608,7 @@ export const BedrockCohereChatCompleteConfig: ProviderConfig = {
if (!!params.messages) {
let messages: Message[] = params.messages;
messages.forEach((msg, index) => {
if (index === 0 && msg.role === 'system') {
if (index === 0 && SYSTEM_MESSAGE_ROLES.includes(msg.role)) {
prompt += `system: ${messages}\n`;
} else if (msg.role == 'user') {
prompt += `user: ${msg.content}\n`;
@@ -787,7 +792,7 @@ export const BedrockAI21ChatCompleteConfig: ProviderConfig = {
if (!!params.messages) {
let messages: Message[] = params.messages;
messages.forEach((msg, index) => {
if (index === 0 && msg.role === 'system') {
if (index === 0 && SYSTEM_MESSAGE_ROLES.includes(msg.role)) {
prompt += `system: ${messages}\n`;
} else if (msg.role == 'user') {
prompt += `user: ${msg.content}\n`;
7 changes: 7 additions & 0 deletions src/providers/deepbricks/chatComplete.ts
@@ -1,4 +1,5 @@
import { DEEPBRICKS } from '../../globals';
import { Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -17,6 +18,12 @@ export const DeepbricksChatCompleteConfig: ProviderConfig = {
messages: {
param: 'messages',
default: '',
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
functions: {
param: 'functions',
7 changes: 7 additions & 0 deletions src/providers/deepinfra/chatComplete.ts
@@ -1,4 +1,5 @@
import { DEEPINFRA } from '../../globals';
import { Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -22,6 +23,12 @@ export const DeepInfraChatCompleteConfig: ProviderConfig = {
param: 'messages',
required: true,
default: [],
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
frequency_penalty: {
param: 'frequency_penalty',
7 changes: 7 additions & 0 deletions src/providers/deepseek/chatComplete.ts
@@ -1,4 +1,5 @@
import { DEEPSEEK } from '../../globals';
import { Params } from '../../types/requestBody';

import {
ChatCompletionResponse,
@@ -19,6 +20,12 @@ export const DeepSeekChatCompleteConfig: ProviderConfig = {
messages: {
param: 'messages',
default: '',
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
max_tokens: {
param: 'max_tokens',
7 changes: 7 additions & 0 deletions src/providers/fireworks-ai/chatComplete.ts
@@ -1,4 +1,5 @@
import { FIREWORKS_AI } from '../../globals';
import { Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -19,6 +20,12 @@ export const FireworksAIChatCompleteConfig: ProviderConfig = {
param: 'messages',
required: true,
default: [],
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
tools: {
param: 'tools',
13 changes: 7 additions & 6 deletions src/providers/google-vertex-ai/chatComplete.ts
@@ -8,6 +8,7 @@ import {
Params,
Tool,
ToolCall,
SYSTEM_MESSAGE_ROLES,
} from '../../types/requestBody';
import {
AnthropicChatCompleteResponse,
@@ -71,7 +72,7 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = {
// From gemini-1.5 onwards, systemInstruction is supported
// Skipping system message and sending it in systemInstruction for gemini 1.5 models
if (
message.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(message.role) &&
!SYSTEM_INSTRUCTION_DISABLED_MODELS.includes(params.model as string)
)
return;
@@ -186,7 +187,7 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = {
if (!firstMessage) return;

if (
firstMessage.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(firstMessage.role) &&
typeof firstMessage.content === 'string'
) {
return {
@@ -200,7 +201,7 @@ export const VertexGoogleChatCompleteConfig: ProviderConfig = {
}

if (
firstMessage.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(firstMessage.role) &&
typeof firstMessage.content === 'object' &&
firstMessage.content?.[0]?.text
) {
@@ -413,7 +414,7 @@ export const VertexAnthropicChatCompleteConfig: ProviderConfig = {
// Transform the chat messages into a simple prompt
if (!!params.messages) {
params.messages.forEach((msg) => {
if (msg.role === 'system') return;
if (SYSTEM_MESSAGE_ROLES.includes(msg.role)) return;

if (msg.role === 'assistant') {
messages.push(transformAssistantMessageForAnthropic(msg));
@@ -481,14 +482,14 @@ export const VertexAnthropicChatCompleteConfig: ProviderConfig = {
if (!!params.messages) {
params.messages.forEach((msg) => {
if (
msg.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(msg.role) &&
msg.content &&
typeof msg.content === 'object' &&
msg.content[0].text
) {
systemMessage = msg.content[0].text;
} else if (
msg.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(msg.role) &&
typeof msg.content === 'string'
) {
systemMessage = msg.content;
7 changes: 4 additions & 3 deletions src/providers/google/chatComplete.ts
@@ -6,6 +6,7 @@ import {
Params,
ToolCall,
ToolChoice,
SYSTEM_MESSAGE_ROLES,
} from '../../types/requestBody';
import { buildGoogleSearchRetrievalTool } from '../google-vertex-ai/chatComplete';
import { derefer, getMimeType } from '../google-vertex-ai/utils';
@@ -152,7 +153,7 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
// From gemini-1.5 onwards, systemInstruction is supported
// Skipping system message and sending it in systemInstruction for gemini 1.5 models
if (
message.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(message.role) &&
!SYSTEM_INSTRUCTION_DISABLED_MODELS.includes(params.model as string)
)
return;
@@ -261,7 +262,7 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
if (!firstMessage) return;

if (
firstMessage.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(firstMessage.role) &&
typeof firstMessage.content === 'string'
) {
return {
@@ -275,7 +276,7 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
}

if (
firstMessage.role === 'system' &&
SYSTEM_MESSAGE_ROLES.includes(firstMessage.role) &&
typeof firstMessage.content === 'object' &&
firstMessage.content?.[0]?.text
) {
7 changes: 7 additions & 0 deletions src/providers/huggingface/chatComplete.ts
@@ -1,4 +1,5 @@
import { HUGGING_FACE } from '../../globals';
import { Params } from '../../types/requestBody';
import { OpenAIErrorResponseTransform } from '../openai/utils';
import {
ChatCompletionResponse,
@@ -18,6 +19,12 @@ export const HuggingfaceChatCompleteConfig: ProviderConfig = {
messages: {
param: 'messages',
default: '',
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
functions: {
param: 'functions',
7 changes: 7 additions & 0 deletions src/providers/lingyi/chatComplete.ts
@@ -1,4 +1,5 @@
import { LINGYI } from '../../globals';
import { Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -18,6 +19,12 @@ export const LingyiChatCompleteConfig: ProviderConfig = {
messages: {
param: 'messages',
default: '',
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
max_tokens: {
param: 'max_tokens',
7 changes: 7 additions & 0 deletions src/providers/mistral-ai/chatComplete.ts
@@ -1,4 +1,5 @@
import { MISTRAL_AI } from '../../globals';
import { Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -18,6 +19,12 @@ export const MistralAIChatCompleteConfig: ProviderConfig = {
messages: {
param: 'messages',
default: [],
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
temperature: {
param: 'temperature',
7 changes: 7 additions & 0 deletions src/providers/monsterapi/chatComplete.ts
@@ -1,4 +1,5 @@
import { MONSTERAPI } from '../../globals';
import { Params } from '../../types/requestBody';
import {
ChatCompletionResponse,
ErrorResponse,
@@ -47,6 +48,12 @@ export const MonsterAPIChatCompleteConfig: ProviderConfig = {
param: 'messages',
required: true,
default: [],
transform: (params: Params) => {
return params.messages?.map((message) => {
if (message.role === 'developer') return { ...message, role: 'system' };
return message;
});
},
},
};

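The OpenAI-compatible providers (Anyscale, Azure AI Inference, Deepbricks, DeepInfra, DeepSeek, Fireworks AI, Hugging Face, Lingyi, Mistral AI, MonsterAPI, and presumably the providers in the files not rendered above) take the other route: a messages transform renames developer to system and passes everything else through. A standalone sketch of that behaviour, using simplified stand-ins for the gateway's Params type:

// Simplified stand-ins for the gateway's Params/Message types.
interface Message {
  role: string;
  content: string;
}

interface Params {
  messages?: Message[];
}

// Mirrors the transform added in this commit: rename 'developer' to 'system',
// leave every other message untouched.
const transformMessages = (params: Params) =>
  params.messages?.map((message) => {
    if (message.role === 'developer') return { ...message, role: 'system' };
    return message;
  });

console.log(
  transformMessages({
    messages: [
      { role: 'developer', content: 'Respond in JSON.' },
      { role: 'user', content: 'Hi' },
    ],
  })
);
// [ { role: 'system', content: 'Respond in JSON.' }, { role: 'user', content: 'Hi' } ]

Renaming at the provider layer keeps the gateway's request schema OpenAI-shaped, while providers that never learned the developer role still receive an ordinary system message.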