From 65e6d2de58472690db457eabecac3e4703e3d570 Mon Sep 17 00:00:00 2001 From: Yuki Sekiya Date: Fri, 24 May 2024 00:41:28 -0700 Subject: [PATCH] Enable client side to control model id and region --- .env.development-template | 4 +++- amplify/backend/awscloudformation/override.ts | 17 +++++++++++++++++ .../backend/function/streamClaude3/src/index.js | 8 ++++---- src/utils/service.ts | 15 ++++++++++++++- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/.env.development-template b/.env.development-template index 3dbba81..2cecedd 100644 --- a/.env.development-template +++ b/.env.development-template @@ -1,2 +1,4 @@ VITE_INDEX_ID=***** -# VITE_SERVER_URL=http://localhost:8080 \ No newline at end of file +# VITE_SERVER_URL=http://localhost:8080 +VITE_MODEL_ID=anthropic.claude-3-haiku-20240307-v1:0 +VITE_BEDROCK_REGION=us-west-2 \ No newline at end of file diff --git a/amplify/backend/awscloudformation/override.ts b/amplify/backend/awscloudformation/override.ts index 3cba74f..2907a9f 100644 --- a/amplify/backend/awscloudformation/override.ts +++ b/amplify/backend/awscloudformation/override.ts @@ -55,6 +55,23 @@ export function override(resources: AmplifyRootStackTemplate, amplifyProjectInfo envContent += `\nVITE_STREAM_FUNC_NAME=streamClaude3-${envName}`; } + // VITE_MODEL_ID の値を追加 + const modelId = "anthropic.claude-3-haiku-20240307-v1:0" + const modelIdPattern = /^VITE_MODEL_ID=.*/gm; + if (modelIdPattern.test(envContent)) { + envContent = envContent.replace(modelIdPattern, `VITE_MODEL_ID=${modelId}`); + } else { + envContent += `\nVITE_MODEL_ID=${modelId}`; + } + + // VITE_BEDROCK_REGION の値を追加 + const bedrockRegionPattern = /^VITE_BEDROCK_REGION=.*/gm; + if (bedrockRegionPattern.test(envContent)) { + envContent = envContent.replace(bedrockRegionPattern, `VITE_BEDROCK_REGION=${region_name}`); + } else { + envContent += `\nVITE_BEDROCK_REGION=${region_name}`; + } + // .env ファイルに書き込む fs.writeFileSync('.env', envContent); } diff --git a/amplify/backend/function/streamClaude3/src/index.js b/amplify/backend/function/streamClaude3/src/index.js index 6713b6e..f429eeb 100644 --- a/amplify/backend/function/streamClaude3/src/index.js +++ b/amplify/backend/function/streamClaude3/src/index.js @@ -8,8 +8,6 @@ const { ThrottlingException, } = require('@aws-sdk/client-bedrock-runtime'); -const client = new BedrockRuntimeClient(); - const extractOutputTextClaude3Message = (body) => { if (body.type === 'message') { @@ -20,7 +18,9 @@ const extractOutputTextClaude3Message = (body) => { return ''; }; -async function* invokeStream(input) { +async function* invokeStream(region, input) { + const client = new BedrockRuntimeClient({ region }); + try { const command = new InvokeModelWithResponseStreamCommand(input); @@ -63,7 +63,7 @@ exports.handler = awslambda.streamifyResponse( async (event, responseStream, context) => { context.callbackWaitsForEmptyEventLoop = false; - for await (const token of invokeStream?.(event.body) ?? []) { + for await (const token of invokeStream?.(event.body.bedrockRegion, event.body) ?? []) { responseStream.write(token); } responseStream.end(); diff --git a/src/utils/service.ts b/src/utils/service.ts index 503d308..55c05f4 100644 --- a/src/utils/service.ts +++ b/src/utils/service.ts @@ -32,6 +32,16 @@ if (!import.meta.env.VITE_STREAM_FUNC_NAME) { "環境変数にSTREAM_FUNC_ARNがありません" ); } +if (!import.meta.env.VITE_MODEL_ID) { + _loadingErrors.push( + "環境変数にMODEL_IDがありません" + ); +} +if (!import.meta.env.VITE_BEDROCK_REGION) { + _loadingErrors.push( + "環境変数にBEDROCK_REGIONがありません" + ); +} const hasErrors = _loadingErrors.length > 0; if (hasErrors) { console.error(JSON.stringify(_loadingErrors)); @@ -44,6 +54,8 @@ const stream_func_name: string = import.meta.env.VITE_STREAM_FUNC_NAME ?? "" const local_server = import.meta.env.VITE_SERVER_URL ?? "" const remote_server = awsconfig.aws_cloud_logic_custom[0].endpoint ?? "" export const serverUrl: string = local_server ? local_server : remote_server; +const model_id: string = import.meta.env.VITE_MODEL_ID ?? "" +const bedrock_region: string = import.meta.env.VITE_BEDROCK_REGION ?? "" let jwtToken = ""; Amplify.configure({ @@ -284,7 +296,8 @@ export async function* infStreamClaude(user_prompt: string) { } const req = { "body": { - "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "bedrockRegion": bedrock_region, + "modelId": model_id, "accept": "application/json", "contentType": "application/json", "body": JSON.stringify(body)