Skip to content

Commit

Permalink
feat: Added support for region-prefixed Bedrock models (#2947)
Browse files Browse the repository at this point in the history
  • Loading branch information
bizob2828 authored Feb 12, 2025
1 parent 772f007 commit 6acf535
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 5 deletions.
4 changes: 2 additions & 2 deletions lib/llm-events/aws-bedrock/bedrock-command.js
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ class BedrockCommand {
}

isClaude() {
return this.#modelId.startsWith('anthropic.claude-v')
return this.#modelId.split('.').slice(-2).join('.').startsWith('anthropic.claude-v')
}

isClaude3() {
return this.#modelId.startsWith('anthropic.claude-3')
return this.#modelId.split('.').slice(-2).join('.').startsWith('anthropic.claude-3')
}

isCohere() {
Expand Down
30 changes: 27 additions & 3 deletions test/lib/aws-server-stubs/ai-server/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,44 @@ function handler(req, res) {
case 'anthropic.claude-v1':
case 'anthropic.claude-instant-v1':
case 'anthropic.claude-v2':
case 'anthropic.claude-v2:1': {
case 'anthropic.claude-v2:1':
case 'us.anthropic.claude-v1':
case 'us.anthropic.claude-instant-v1':
case 'us.anthropic.claude-v2':
case 'us.anthropic.claude-v2:1':
case 'eu.anthropic.claude-v1':
case 'eu.anthropic.claude-instant-v1':
case 'eu.anthropic.claude-v2':
case 'eu.anthropic.claude-v2:1':
case 'apac.anthropic.claude-v1':
case 'apac.anthropic.claude-instant-v1':
case 'apac.anthropic.claude-v2':
case 'apac.anthropic.claude-v2:1':{
response = responses.claude.get(payload.prompt)
break
}

case 'anthropic.claude-3-haiku-20240307-v1:0':
case 'anthropic.claude-3-opus-20240229-v1:0':
case 'anthropic.claude-3-sonnet-20240229-v1:0': {
case 'anthropic.claude-3-sonnet-20240229-v1:0':
case 'us.anthropic.claude-3-haiku-20240307-v1:0':
case 'us.anthropic.claude-3-opus-20240229-v1:0':
case 'us.anthropic.claude-3-sonnet-20240229-v1:0':
case 'eu.anthropic.claude-3-haiku-20240307-v1:0':
case 'eu.anthropic.claude-3-opus-20240229-v1:0':
case 'eu.anthropic.claude-3-sonnet-20240229-v1:0':
case 'apac.anthropic.claude-3-haiku-20240307-v1:0':
case 'apac.anthropic.claude-3-opus-20240229-v1:0':
case 'apac.anthropic.claude-3-sonnet-20240229-v1:0': {
response = responses.claude3.get(payload?.messages?.[0]?.content)
break
}

// Chunked claude model
case 'anthropic.claude-3-5-sonnet-20240620-v1:0': {
case 'anthropic.claude-3-5-sonnet-20240620-v1:0':
case 'us.anthropic.claude-3-5-sonnet-20240620-v1:0':
case 'eu.anthropic.claude-3-5-sonnet-20240620-v1:0':
case 'apac.anthropic.claude-3-5-sonnet-20240620-v1:0':{
response = responses.claude3.get(payload?.messages?.[0]?.content?.[0].text)
break
}
Expand Down
130 changes: 130 additions & 0 deletions test/unit/llm-events/aws-bedrock/bedrock-command.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ const claude = {
}
}

const regionClaude = {
modelId: 'us.anthropic.claude-v1',
body: {
prompt: '\n\nHuman: yes\n\nAssistant:'
}
}

const claude35 = {
modelId: 'anthropic.claude-3-5-sonnet-20240620-v1:0',
body: {
Expand All @@ -35,13 +42,30 @@ const claude35 = {
}
}

const regionClaude35 = {
modelId: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
body: {
messages: [
{ role: 'user', content: [{ type: 'text', text: 'who are' }] },
{ role: 'assistant', content: [{ type: 'text', text: 'researching' }] },
{ role: 'user', content: [{ type: 'text', text: 'you' }] }
]
}
}
const claude3 = {
modelId: 'anthropic.claude-3-haiku-20240307-v1:0',
body: {
messages: [{ role: 'user', content: 'who are you' }]
}
}

const regionClaude3 = {
modelId: 'us.anthropic.claude-3-haiku-20240307-v1:0',
body: {
messages: [{ role: 'user', content: 'who are you' }]
}
}

const cohere = {
modelId: 'cohere.command-text-v14',
body: {
Expand Down Expand Up @@ -154,6 +178,17 @@ test('claude minimal command works', async (t) => {
assert.equal(cmd.temperature, undefined)
})

test('region specific claude minimal command works', async (t) => {
t.nr.updatePayload(structuredClone(regionClaude))
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude(), true)
assert.equal(cmd.maxTokens, undefined)
assert.equal(cmd.modelId, regionClaude.modelId)
assert.equal(cmd.modelType, 'completion')
assert.deepEqual(cmd.prompt, [{ role: 'user', content: claude.body.prompt }])
assert.equal(cmd.temperature, undefined)
})

test('claude complete command works', async (t) => {
const payload = structuredClone(claude)
payload.body.max_tokens_to_sample = 25
Expand All @@ -168,6 +203,20 @@ test('claude complete command works', async (t) => {
assert.equal(cmd.temperature, payload.body.temperature)
})

test('region specific claude complete command works', async (t) => {
const payload = structuredClone(regionClaude)
payload.body.max_tokens_to_sample = 25
payload.body.temperature = 0.5
t.nr.updatePayload(payload)
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude(), true)
assert.equal(cmd.maxTokens, 25)
assert.equal(cmd.modelId, payload.modelId)
assert.equal(cmd.modelType, 'completion')
assert.deepEqual(cmd.prompt, [{ role: 'user', content: payload.body.prompt }])
assert.equal(cmd.temperature, payload.body.temperature)
})

test('claude3 minimal command works', async (t) => {
t.nr.updatePayload(structuredClone(claude3))
const cmd = new BedrockCommand(t.nr.input)
Expand All @@ -179,6 +228,17 @@ test('claude3 minimal command works', async (t) => {
assert.equal(cmd.temperature, undefined)
})

test('region specific claude3 minimal command works', async (t) => {
t.nr.updatePayload(structuredClone(regionClaude3))
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude3(), true)
assert.equal(cmd.maxTokens, undefined)
assert.equal(cmd.modelId, regionClaude3.modelId)
assert.equal(cmd.modelType, 'completion')
assert.deepEqual(cmd.prompt, claude3.body.messages)
assert.equal(cmd.temperature, undefined)
})

test('claude3 complete command works', async (t) => {
const payload = structuredClone(claude3)
payload.body.max_tokens = 25
Expand All @@ -193,6 +253,20 @@ test('claude3 complete command works', async (t) => {
assert.equal(cmd.temperature, payload.body.temperature)
})

test('region specific claude3 complete command works', async (t) => {
const payload = structuredClone(regionClaude3)
payload.body.max_tokens = 25
payload.body.temperature = 0.5
t.nr.updatePayload(payload)
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude3(), true)
assert.equal(cmd.maxTokens, 25)
assert.equal(cmd.modelId, payload.modelId)
assert.equal(cmd.modelType, 'completion')
assert.deepEqual(cmd.prompt, payload.body.messages)
assert.equal(cmd.temperature, payload.body.temperature)
})

test('claude35 minimal command works with claude 3 api', async (t) => {
t.nr.updatePayload(structuredClone(claude3))
const cmd = new BedrockCommand(t.nr.input)
Expand All @@ -217,6 +291,19 @@ test('claude35 malformed payload produces reasonable values', async (t) => {
assert.equal(cmd.temperature, undefined)
})

test('region specific claude35 malformed payload produces reasonable values', async (t) => {
const malformedPayload = structuredClone(regionClaude35)
malformedPayload.body = {}
t.nr.updatePayload(malformedPayload)
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude3(), true)
assert.equal(cmd.maxTokens, undefined)
assert.equal(cmd.modelId, regionClaude35.modelId)
assert.equal(cmd.modelType, 'completion')
assert.deepEqual(cmd.prompt, [])
assert.equal(cmd.temperature, undefined)
})

test('claude35 skips a message that is null in `body.messages`', async (t) => {
const malformedPayload = structuredClone(claude35)
malformedPayload.body.messages = [{ role: 'user', content: 'who are you' }, null]
Expand All @@ -226,6 +313,15 @@ test('claude35 skips a message that is null in `body.messages`', async (t) => {
assert.deepEqual(cmd.prompt, [{ role: 'user', content: 'who are you' }])
})

test('region specific claude35 skips a message that is null in `body.messages`', async (t) => {
const malformedPayload = structuredClone(regionClaude35)
malformedPayload.body.messages = [{ role: 'user', content: 'who are you' }, null]
t.nr.updatePayload(malformedPayload)
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude3(), true)
assert.deepEqual(cmd.prompt, [{ role: 'user', content: 'who are you' }])
})

test('claude35 handles defaulting prompt to empty array when `body.messages` is null', async (t) => {
const malformedPayload = structuredClone(claude35)
malformedPayload.body.messages = null
Expand All @@ -235,6 +331,15 @@ test('claude35 handles defaulting prompt to empty array when `body.messages` is
assert.deepEqual(cmd.prompt, [])
})

test('region specific claude35 handles defaulting prompt to empty array when `body.messages` is null', async (t) => {
const malformedPayload = structuredClone(regionClaude35)
malformedPayload.body.messages = null
t.nr.updatePayload(malformedPayload)
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude3(), true)
assert.deepEqual(cmd.prompt, [])
})

test('claude35 minimal command works', async (t) => {
t.nr.updatePayload(structuredClone(claude35))
const cmd = new BedrockCommand(t.nr.input)
Expand All @@ -246,6 +351,17 @@ test('claude35 minimal command works', async (t) => {
assert.equal(cmd.temperature, undefined)
})

test('region specific claude35 minimal command works', async (t) => {
t.nr.updatePayload(structuredClone(regionClaude35))
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude3(), true)
assert.equal(cmd.maxTokens, undefined)
assert.equal(cmd.modelId, regionClaude35.modelId)
assert.equal(cmd.modelType, 'completion')
assert.deepEqual(cmd.prompt, [{ role: 'user', content: 'who are' }, { role: 'assistant', content: 'researching' }, { role: 'user', content: 'you' }])
assert.equal(cmd.temperature, undefined)
})

test('claude35 complete command works', async (t) => {
const payload = structuredClone(claude35)
payload.body.max_tokens = 25
Expand All @@ -260,6 +376,20 @@ test('claude35 complete command works', async (t) => {
assert.equal(cmd.temperature, payload.body.temperature)
})

test('region specific claude35 complete command works', async (t) => {
const payload = structuredClone(regionClaude35)
payload.body.max_tokens = 25
payload.body.temperature = 0.5
t.nr.updatePayload(payload)
const cmd = new BedrockCommand(t.nr.input)
assert.equal(cmd.isClaude3(), true)
assert.equal(cmd.maxTokens, 25)
assert.equal(cmd.modelId, payload.modelId)
assert.equal(cmd.modelType, 'completion')
assert.deepEqual(cmd.prompt, [{ role: 'user', content: 'who are' }, { role: 'assistant', content: 'researching' }, { role: 'user', content: 'you' }])
assert.equal(cmd.temperature, payload.body.temperature)
})

test('cohere minimal command works', async (t) => {
t.nr.updatePayload(structuredClone(cohere))
const cmd = new BedrockCommand(t.nr.input)
Expand Down
72 changes: 72 additions & 0 deletions test/unit/llm-events/aws-bedrock/stream-handler.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,44 @@ test('handles claude streams', async (t) => {
assert.equal(br.statusCode, 200)
})

test('handles region specific claude streams', async (t) => {
t.nr.passThroughParams.bedrockCommand.isClaude = () => true
t.nr.chunks = [
{ completion: '1', stop_reason: null },
{ completion: '2', stop_reason: 'done', ...t.nr.metrics }
]
const handler = new StreamHandler(t.nr)

assert.equal(handler.generator.name, 'handleClaude')
for await (const event of handler.generator()) {
assert.equal(event.chunk.bytes.constructor, Uint8Array)
}
assert.deepStrictEqual(handler.response, {
response: {
headers: {
'x-amzn-requestid': 'aws-req-1'
},
statusCode: 200
},
output: {
body: new TextEncoder().encode(JSON.stringify({ completion: '12', stop_reason: 'done' }))
}
})

const bc = new BedrockCommand({
modelId: 'us.anthropic.claude-v1',
body: JSON.stringify({
prompt: 'prompt',
maxTokens: 5
})
})
const br = new BedrockResponse({ bedrockCommand: bc, response: handler.response })
assert.equal(br.completions.length, 1)
assert.equal(br.finishReason, 'done')
assert.equal(br.requestId, 'aws-req-1')
assert.equal(br.statusCode, 200)
})

test('handles claude3streams', async (t) => {
t.nr.passThroughParams.bedrockCommand.isClaude3 = () => true
t.nr.chunks = [
Expand Down Expand Up @@ -151,6 +189,40 @@ test('handles claude3streams', async (t) => {
assert.equal(br.statusCode, 200)
})

test('handles region specific claude3streams', async (t) => {
t.nr.passThroughParams.bedrockCommand.isClaude3 = () => true
t.nr.chunks = [
{ type: 'content_block_delta', delta: { type: 'text_delta', text: '42' } },
{ type: 'message_delta', delta: { stop_reason: 'done' } },
{ type: 'message_stop', ...t.nr.metrics }
]
const handler = new StreamHandler(t.nr)

assert.equal(handler.generator.name, 'handleClaude3')
for await (const event of handler.generator()) {
assert.equal(event.chunk.bytes.constructor, Uint8Array)
}
const foundBody = JSON.parse(new TextDecoder().decode(handler.response.output.body))
assert.deepStrictEqual(foundBody, {
completions: ['42'],
stop_reason: 'done',
type: 'message_stop'
})

const bc = new BedrockCommand({
modelId: 'us.anthropic.claude-3-haiku-20240307-v1:0',
body: JSON.stringify({
messages: [{ content: 'prompt' }],
maxTokens: 5
})
})
const br = new BedrockResponse({ bedrockCommand: bc, response: handler.response })
assert.equal(br.completions.length, 1)
assert.equal(br.finishReason, 'done')
assert.equal(br.requestId, 'aws-req-1')
assert.equal(br.statusCode, 200)
})

test('handles cohere streams', async (t) => {
t.nr.passThroughParams.bedrockCommand.isCohere = () => true
t.nr.chunks = [
Expand Down
Loading

0 comments on commit 6acf535

Please sign in to comment.