Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: add Ai21Labs model provider #3727

Merged
merged 11 commits into from
Sep 18, 2024
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ ENV ACCESS_CODE="" \

# Model Variables
ENV \
# AI21
AI21_API_KEY="" \
# Ai360
AI360_API_KEY="" \
# Anthropic
Expand Down
2 changes: 2 additions & 0 deletions Dockerfile.database
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ ENV NEXT_PUBLIC_S3_DOMAIN="" \

# Model Variables
ENV \
# AI21
AI21_API_KEY="" \
# Ai360
AI360_API_KEY="" \
# Anthropic
Expand Down
2 changes: 2 additions & 0 deletions src/app/(main)/settings/llm/ProviderList/providers.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { useMemo } from 'react';

import {
Ai21ProviderCard,
Ai360ProviderCard,
AnthropicProviderCard,
BaichuanProviderCard,
Expand Down Expand Up @@ -54,6 +55,7 @@ export const useProviderList = (): ProviderItem[] => {
TogetherAIProviderCard,
FireworksAIProviderCard,
UpstageProviderCard,
Ai21ProviderCard,
QwenProviderCard,
SparkProviderCard,
ZhiPuProviderCard,
Expand Down
7 changes: 7 additions & 0 deletions src/app/api/chat/agentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {

const apiKey = apiKeyManager.pick(payload?.apiKey || SPARK_API_KEY);

return { apiKey };
}
case ModelProvider.Ai21: {
const { AI21_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || AI21_API_KEY);

return { apiKey };
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/config/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ export const getLLMConfig = () => {

ENABLED_SPARK: z.boolean(),
SPARK_API_KEY: z.string().optional(),

ENABLED_AI21: z.boolean(),
AI21_API_KEY: z.string().optional(),
},
runtimeEnv: {
API_KEY_SELECT_MODE: process.env.API_KEY_SELECT_MODE,
Expand Down Expand Up @@ -215,6 +218,9 @@ export const getLLMConfig = () => {

ENABLED_SPARK: !!process.env.SPARK_API_KEY,
SPARK_API_KEY: process.env.SPARK_API_KEY,

ENABLED_AI21: !!process.env.AI21_API_KEY,
AI21_API_KEY: process.env.AI21_API_KEY,
},
});
};
Expand Down
37 changes: 37 additions & 0 deletions src/config/modelProviders/ai21.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { ModelProviderCard } from '@/types/llm';

// ref https://docs.ai21.com/reference/jamba-15-api-ref
// Chat models offered by AI21 Labs. `tokens` is the context window size;
// pricing values presumably follow the per-1M-token convention used by the
// other provider cards — verify against siblings in this directory.
const ai21ChatModels: ModelProviderCard['chatModels'] = [
  {
    displayName: 'Jamba 1.5 Mini',
    enabled: true,
    functionCall: true,
    id: 'jamba-1.5-mini',
    pricing: {
      input: 0.2,
      output: 0.4,
    },
    tokens: 256_000,
  },
  {
    displayName: 'Jamba 1.5 Large',
    enabled: true,
    functionCall: true,
    id: 'jamba-1.5-large',
    pricing: {
      input: 2,
      output: 8,
    },
    tokens: 256_000,
  },
];

// Provider card registering AI21 Labs with the model-provider list.
// `checkModel` is the (cheap) model used for connectivity checks.
const Ai21: ModelProviderCard = {
  chatModels: ai21ChatModels,
  checkModel: 'jamba-1.5-mini',
  id: 'ai21',
  modelList: { showModelFetcher: true },
  modelsUrl: 'https://docs.ai21.com/reference',
  name: 'Ai21Labs',
  url: 'https://studio.ai21.com',
};

export default Ai21;
4 changes: 4 additions & 0 deletions src/config/modelProviders/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { ChatModelCard, ModelProviderCard } from '@/types/llm';

import Ai21Provider from './ai21';
import Ai360Provider from './ai360';
import AnthropicProvider from './anthropic';
import AzureProvider from './azure';
Expand Down Expand Up @@ -53,6 +54,7 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [
SiliconCloudProvider.chatModels,
UpstageProvider.chatModels,
SparkProvider.chatModels,
Ai21Provider.chatModels,
].flat();

export const DEFAULT_MODEL_PROVIDER_LIST = [
Expand All @@ -71,6 +73,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [
TogetherAIProvider,
FireworksAIProvider,
UpstageProvider,
Ai21Provider,
QwenProvider,
SparkProvider,
ZhiPuProvider,
Expand All @@ -93,6 +96,7 @@ export const isProviderDisableBroswerRequest = (id: string) => {
return !!provider;
};

export { default as Ai21ProviderCard } from './ai21';
export { default as Ai360ProviderCard } from './ai360';
export { default as AnthropicProviderCard } from './anthropic';
export { default as AzureProviderCard } from './azure';
Expand Down
5 changes: 5 additions & 0 deletions src/const/settings/llm.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {
Ai21ProviderCard,
Ai360ProviderCard,
AnthropicProviderCard,
BaichuanProviderCard,
Expand Down Expand Up @@ -30,6 +31,10 @@ import { ModelProvider } from '@/libs/agent-runtime';
import { UserModelProviderConfig } from '@/types/user/settings';

export const DEFAULT_LLM_CONFIG: UserModelProviderConfig = {
ai21: {
enabled: false,
enabledModels: filterEnabledModels(Ai21ProviderCard),
},
ai360: {
enabled: false,
enabledModels: filterEnabledModels(Ai360ProviderCard),
Expand Down
7 changes: 7 additions & 0 deletions src/libs/agent-runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { ClientOptions } from 'openai';
import type { TracePayload } from '@/const/trace';

import { LobeRuntimeAI } from './BaseAI';
import { LobeAi21AI } from './ai21';
import { LobeAi360AI } from './ai360';
import { LobeAnthropicAI } from './anthropic';
import { LobeAzureOpenAI } from './azureOpenai';
Expand Down Expand Up @@ -116,6 +117,7 @@ class AgentRuntime {
static async initializeWithProviderOptions(
provider: string,
params: Partial<{
ai21: Partial<ClientOptions>;
ai360: Partial<ClientOptions>;
anthropic: Partial<ClientOptions>;
azure: { apiVersion?: string; apikey?: string; endpoint?: string };
Expand Down Expand Up @@ -282,6 +284,11 @@ class AgentRuntime {
runtimeModel = new LobeSparkAI(params.spark);
break;
}

case ModelProvider.Ai21: {
runtimeModel = new LobeAi21AI(params.ai21);
break;
}
}

return new AgentRuntime(runtimeModel);
Expand Down
Loading
Loading