diff --git a/src/providers/google-vertex-ai/api.ts b/src/providers/google-vertex-ai/api.ts index 3076e5268..486854acd 100644 --- a/src/providers/google-vertex-ai/api.ts +++ b/src/providers/google-vertex-ai/api.ts @@ -86,6 +86,10 @@ export const GoogleApiConfig: ProviderAPIConfig = { return `${projectRoute}/endpoints/openapi/chat/completions`; } + case 'endpoints': { + return `${projectRoute}/endpoints/${model}/chat/completions`; + } + default: return `${projectRoute}`; } diff --git a/src/providers/google-vertex-ai/index.ts b/src/providers/google-vertex-ai/index.ts index 95e5e01c1..a087fb926 100644 --- a/src/providers/google-vertex-ai/index.ts +++ b/src/providers/google-vertex-ai/index.ts @@ -17,6 +17,8 @@ import { GoogleImageGenConfig, GoogleImageGenResponseTransform, } from './imageGenerate'; +import { chatCompleteParams, responseTransformers } from '../open-ai-base'; +import { GOOGLE_VERTEX_AI } from '../../globals'; const VertexConfig: ProviderConfigs = { api: VertexApiConfig, @@ -57,6 +59,16 @@ const VertexConfig: ProviderConfigs = { 'stream-chatComplete': VertexLlamaChatCompleteStreamChunkTransform, }, }; + case 'endpoints': + return { + chatComplete: chatCompleteParams([], { + model: 'meta-llama-3-8b-instruct', + }), + api: GoogleApiConfig, + responseTransforms: responseTransformers(GOOGLE_VERTEX_AI, { + chatComplete: true, + }), + }; } }, }; diff --git a/src/providers/google-vertex-ai/utils.ts b/src/providers/google-vertex-ai/utils.ts index adc897228..ccb9673f2 100644 --- a/src/providers/google-vertex-ai/utils.ts +++ b/src/providers/google-vertex-ai/utils.ts @@ -130,7 +130,7 @@ export const getModelAndProvider = (modelString: string) => { const modelStringParts = modelString.split('.'); if ( modelStringParts.length > 1 && - ['google', 'anthropic', 'meta'].includes(modelStringParts[0]) + ['google', 'anthropic', 'meta', 'endpoints'].includes(modelStringParts[0]) ) { provider = modelStringParts[0]; model = modelStringParts.slice(1).join('.');