diff --git a/packages/vertexai/src/models/generative-model.test.ts b/packages/vertexai/src/models/generative-model.test.ts index a97b7bfc003..7b0287492da 100644 --- a/packages/vertexai/src/models/generative-model.test.ts +++ b/packages/vertexai/src/models/generative-model.test.ts @@ -94,6 +94,31 @@ describe('GenerativeModel', () => { ); restore(); }); + it('passes text-only systemInstruction through to generateContent', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'my-model', + systemInstruction: 'be friendly' + }); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.generateContent('hello'); + expect(makeRequestStub).to.be.calledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + match.any, + false, + match((value: string) => { + return value.includes('be friendly'); + }), + {} + ); + restore(); + }); it('generateContent overrides model values', async () => { const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model', @@ -169,6 +194,31 @@ describe('GenerativeModel', () => { ); restore(); }); + it('passes text-only systemInstruction through to chat.sendMessage', async () => { + const genModel = new GenerativeModel(fakeVertexAI, { + model: 'my-model', + systemInstruction: 'be friendly' + }); + expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); + const mockResponse = getMockResponse( + 'unary-success-basic-reply-short.json' + ); + const makeRequestStub = stub(request, 'makeRequest').resolves( + mockResponse as Response + ); + await genModel.startChat().sendMessage('hello'); + expect(makeRequestStub).to.be.calledWith( + 'publishers/google/models/my-model', + request.Task.GENERATE_CONTENT, + match.any, + false, + match((value: string) => { + return value.includes('be friendly'); + }), + {} + ); + restore(); + }); it('startChat overrides model values', async () => { const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model', diff --git a/packages/vertexai/src/models/generative-model.ts b/packages/vertexai/src/models/generative-model.ts index eec3297de9f..f68bc00b295 100644 --- a/packages/vertexai/src/models/generative-model.ts +++ b/packages/vertexai/src/models/generative-model.ts @@ -37,7 +37,10 @@ import { } from '../types'; import { ChatSession } from '../methods/chat-session'; import { countTokens } from '../methods/count-tokens'; -import { formatGenerateContentInput } from '../requests/request-helpers'; +import { + formatGenerateContentInput, + formatSystemInstruction +} from '../requests/request-helpers'; import { VertexAI } from '../public-types'; import { ERROR_FACTORY, VertexError } from '../errors'; import { ApiSettings } from '../types/internal'; @@ -93,7 +96,9 @@ export class GenerativeModel { this.safetySettings = modelParams.safetySettings || []; this.tools = modelParams.tools; this.toolConfig = modelParams.toolConfig; - this.systemInstruction = modelParams.systemInstruction; + this.systemInstruction = formatSystemInstruction( + modelParams.systemInstruction + ); this.requestOptions = requestOptions || {}; } diff --git a/packages/vertexai/src/requests/request-helpers.test.ts b/packages/vertexai/src/requests/request-helpers.test.ts new file mode 100644 index 00000000000..41278ee657d --- /dev/null +++ b/packages/vertexai/src/requests/request-helpers.test.ts @@ -0,0 +1,175 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expect, use } from 'chai'; +import sinonChai from 'sinon-chai'; +import { Content } from '../types'; +import { formatGenerateContentInput } from './request-helpers'; + +use(sinonChai); + +describe('request formatting methods', () => { + describe('formatGenerateContentInput', () => { + it('formats a text string into a request', () => { + const result = formatGenerateContentInput('some text content'); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'some text content' }] + } + ] + }); + }); + it('formats an array of strings into a request', () => { + const result = formatGenerateContentInput(['txt1', 'txt2']); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'txt1' }, { text: 'txt2' }] + } + ] + }); + }); + it('formats an array of Parts into a request', () => { + const result = formatGenerateContentInput([ + { text: 'txt1' }, + { text: 'txtB' } + ]); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'txt1' }, { text: 'txtB' }] + } + ] + }); + }); + it('formats a mixed array into a request', () => { + const result = formatGenerateContentInput(['txtA', { text: 'txtB' }]); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }, { text: 'txtB' }] + } + ] + }); + }); + it('preserves other properties of request', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + generationConfig: { topK: 100 } + }); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + generationConfig: { topK: 100 } + }); + }); + it('formats systemInstructions if provided as text', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + systemInstruction: 'be excited' + }); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] } + }); + }); + it('formats systemInstructions if provided as Part', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + systemInstruction: { text: 'be excited' } + }); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] } + }); + }); + it('formats systemInstructions if provided as Content (no role)', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + systemInstruction: { parts: [{ text: 'be excited' }] } as Content + }); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] } + }); + }); + it('passes thru systemInstructions if provided as Content', () => { + const result = formatGenerateContentInput({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] } + }); + expect(result).to.deep.equal({ + contents: [ + { + role: 'user', + parts: [{ text: 'txtA' }] + } + ], + systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] } + }); + }); + }); +}); diff --git a/packages/vertexai/src/requests/request-helpers.ts b/packages/vertexai/src/requests/request-helpers.ts index 96f4ffb1ea7..0b7ce4ed4d2 100644 --- a/packages/vertexai/src/requests/request-helpers.ts +++ b/packages/vertexai/src/requests/request-helpers.ts @@ -18,6 +18,25 @@ import { Content, GenerateContentRequest, Part } from '../types'; import { ERROR_FACTORY, VertexError } from '../errors'; +export function formatSystemInstruction( + input?: string | Part | Content +): Content | undefined { + // null or undefined + if (input == null) { + return undefined; + } else if (typeof input === 'string') { + return { role: 'system', parts: [{ text: input }] } as Content; + } else if ((input as Part).text) { + return { role: 'system', parts: [input as Part] }; + } else if ((input as Content).parts) { + if (!(input as Content).role) { + return { role: 'system', parts: (input as Content).parts }; + } else { + return input as Content; + } + } +} + export function formatNewContent( request: string | Array ): Content { @@ -84,10 +103,18 @@ function assignRoleToPartsAndValidateSendMessageRequest( export function formatGenerateContentInput( params: GenerateContentRequest | string | Array ): GenerateContentRequest { + let formattedRequest: GenerateContentRequest; if ((params as GenerateContentRequest).contents) { - return params as GenerateContentRequest; + formattedRequest = params as GenerateContentRequest; } else { + // Array or string const content = formatNewContent(params as string | Array); - return { contents: [content] }; + formattedRequest = { contents: [content] }; + } + if ((params as GenerateContentRequest).systemInstruction) { + formattedRequest.systemInstruction = formatSystemInstruction( + (params as GenerateContentRequest).systemInstruction + ); } + return formattedRequest; } diff --git a/packages/vertexai/src/types/requests.ts b/packages/vertexai/src/types/requests.ts index aa1e5bafa10..70ce881ff8a 100644 --- a/packages/vertexai/src/types/requests.ts +++ b/packages/vertexai/src/types/requests.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { Content } from './content'; +import { Content, Part } from './content'; import { FunctionCallingMode, HarmBlockMethod, @@ -40,7 +40,7 @@ export interface ModelParams extends BaseParams { model: string; tools?: Tool[]; toolConfig?: ToolConfig; - systemInstruction?: Content; + systemInstruction?: string | Part | Content; } /** @@ -51,7 +51,7 @@ export interface GenerateContentRequest extends BaseParams { contents: Content[]; tools?: Tool[]; toolConfig?: ToolConfig; - systemInstruction?: Content; + systemInstruction?: string | Part | Content; } /** @@ -87,7 +87,7 @@ export interface StartChatParams extends BaseParams { history?: Content[]; tools?: Tool[]; toolConfig?: ToolConfig; - systemInstruction?: Content; + systemInstruction?: string | Part | Content; } /**