Skip to content

Commit

Permalink
Merge pull request #751 from narengogi/feat/huggingface-on-vertex
Browse files Browse the repository at this point in the history
add support for huggingface on vertex
  • Loading branch information
VisargD authored Dec 26, 2024
2 parents e180a99 + 6d3b609 commit 56d0e65
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/providers/google-vertex-ai/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ export const GoogleApiConfig: ProviderAPIConfig = {
return `${projectRoute}/endpoints/openapi/chat/completions`;
}

case 'endpoints': {
return `${projectRoute}/endpoints/${model}/chat/completions`;
}

default:
return `${projectRoute}`;
}
Expand Down
12 changes: 12 additions & 0 deletions src/providers/google-vertex-ai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import {
GoogleImageGenConfig,
GoogleImageGenResponseTransform,
} from './imageGenerate';
import { chatCompleteParams, responseTransformers } from '../open-ai-base';
import { GOOGLE_VERTEX_AI } from '../../globals';

const VertexConfig: ProviderConfigs = {
api: VertexApiConfig,
Expand Down Expand Up @@ -57,6 +59,16 @@ const VertexConfig: ProviderConfigs = {
'stream-chatComplete': VertexLlamaChatCompleteStreamChunkTransform,
},
};
case 'endpoints':
return {
chatComplete: chatCompleteParams([], {
model: 'meta-llama-3-8b-instruct',
}),
api: GoogleApiConfig,
responseTransforms: responseTransformers(GOOGLE_VERTEX_AI, {
chatComplete: true,
}),
};
}
},
};
Expand Down
2 changes: 1 addition & 1 deletion src/providers/google-vertex-ai/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ export const getModelAndProvider = (modelString: string) => {
const modelStringParts = modelString.split('.');
if (
modelStringParts.length > 1 &&
['google', 'anthropic', 'meta'].includes(modelStringParts[0])
['google', 'anthropic', 'meta', 'endpoints'].includes(modelStringParts[0])
) {
provider = modelStringParts[0];
model = modelStringParts.slice(1).join('.');
Expand Down

0 comments on commit 56d0e65

Please sign in to comment.