Skip to content

Commit

Permalink
Merge 01b2231 into 1aadc47
Browse files Browse the repository at this point in the history
  • Loading branch information
hsubox76 authored Apr 26, 2024
2 parents 1aadc47 + 01b2231 commit 392380a
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 8 deletions.
50 changes: 50 additions & 0 deletions packages/vertexai/src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,31 @@ describe('GenerativeModel', () => {
);
restore();
});
it('passes text-only systemInstruction through to generateContent', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
systemInstruction: 'be friendly'
});
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
const mockResponse = getMockResponse(
'unary-success-basic-reply-short.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
await genModel.generateContent('hello');
expect(makeRequestStub).to.be.calledWith(
'publishers/google/models/my-model',
request.Task.GENERATE_CONTENT,
match.any,
false,
match((value: string) => {
return value.includes('be friendly');
}),
{}
);
restore();
});
it('generateContent overrides model values', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
Expand Down Expand Up @@ -169,6 +194,31 @@ describe('GenerativeModel', () => {
);
restore();
});
it('passes text-only systemInstruction through to chat.sendMessage', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
systemInstruction: 'be friendly'
});
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
const mockResponse = getMockResponse(
'unary-success-basic-reply-short.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
await genModel.startChat().sendMessage('hello');
expect(makeRequestStub).to.be.calledWith(
'publishers/google/models/my-model',
request.Task.GENERATE_CONTENT,
match.any,
false,
match((value: string) => {
return value.includes('be friendly');
}),
{}
);
restore();
});
it('startChat overrides model values', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
Expand Down
9 changes: 7 additions & 2 deletions packages/vertexai/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ import {
} from '../types';
import { ChatSession } from '../methods/chat-session';
import { countTokens } from '../methods/count-tokens';
import { formatGenerateContentInput } from '../requests/request-helpers';
import {
formatGenerateContentInput,
formatSystemInstruction
} from '../requests/request-helpers';
import { VertexAI } from '../public-types';
import { ERROR_FACTORY, VertexError } from '../errors';
import { ApiSettings } from '../types/internal';
Expand Down Expand Up @@ -93,7 +96,9 @@ export class GenerativeModel {
this.safetySettings = modelParams.safetySettings || [];
this.tools = modelParams.tools;
this.toolConfig = modelParams.toolConfig;
this.systemInstruction = modelParams.systemInstruction;
this.systemInstruction = formatSystemInstruction(
modelParams.systemInstruction
);
this.requestOptions = requestOptions || {};
}

Expand Down
175 changes: 175 additions & 0 deletions packages/vertexai/src/requests/request-helpers.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { expect, use } from 'chai';
import sinonChai from 'sinon-chai';
import { Content } from '../types';
import { formatGenerateContentInput } from './request-helpers';

use(sinonChai);

describe('request formatting methods', () => {
describe('formatGenerateContentInput', () => {
it('formats a text string into a request', () => {
const result = formatGenerateContentInput('some text content');
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'some text content' }]
}
]
});
});
it('formats an array of strings into a request', () => {
const result = formatGenerateContentInput(['txt1', 'txt2']);
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'txt1' }, { text: 'txt2' }]
}
]
});
});
it('formats an array of Parts into a request', () => {
const result = formatGenerateContentInput([
{ text: 'txt1' },
{ text: 'txtB' }
]);
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'txt1' }, { text: 'txtB' }]
}
]
});
});
it('formats a mixed array into a request', () => {
const result = formatGenerateContentInput(['txtA', { text: 'txtB' }]);
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }, { text: 'txtB' }]
}
]
});
});
it('preserves other properties of request', () => {
const result = formatGenerateContentInput({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
generationConfig: { topK: 100 }
});
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
generationConfig: { topK: 100 }
});
});
it('formats systemInstructions if provided as text', () => {
const result = formatGenerateContentInput({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
systemInstruction: 'be excited'
});
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
});
});
it('formats systemInstructions if provided as Part', () => {
const result = formatGenerateContentInput({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
systemInstruction: { text: 'be excited' }
});
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
});
});
it('formats systemInstructions if provided as Content (no role)', () => {
const result = formatGenerateContentInput({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
systemInstruction: { parts: [{ text: 'be excited' }] } as Content
});
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
});
});
it('passes thru systemInstructions if provided as Content', () => {
const result = formatGenerateContentInput({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
});
expect(result).to.deep.equal({
contents: [
{
role: 'user',
parts: [{ text: 'txtA' }]
}
],
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
});
});
});
});
31 changes: 29 additions & 2 deletions packages/vertexai/src/requests/request-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@
import { Content, GenerateContentRequest, Part } from '../types';
import { ERROR_FACTORY, VertexError } from '../errors';

export function formatSystemInstruction(
input?: string | Part | Content
): Content | undefined {
// null or undefined
if (input == null) {
return undefined;
} else if (typeof input === 'string') {
return { role: 'system', parts: [{ text: input }] } as Content;
} else if ((input as Part).text) {
return { role: 'system', parts: [input as Part] };
} else if ((input as Content).parts) {
if (!(input as Content).role) {
return { role: 'system', parts: (input as Content).parts };
} else {
return input as Content;
}
}
}

export function formatNewContent(
request: string | Array<string | Part>
): Content {
Expand Down Expand Up @@ -84,10 +103,18 @@ function assignRoleToPartsAndValidateSendMessageRequest(
export function formatGenerateContentInput(
params: GenerateContentRequest | string | Array<string | Part>
): GenerateContentRequest {
let formattedRequest: GenerateContentRequest;
if ((params as GenerateContentRequest).contents) {
return params as GenerateContentRequest;
formattedRequest = params as GenerateContentRequest;
} else {
// Array or string
const content = formatNewContent(params as string | Array<string | Part>);
return { contents: [content] };
formattedRequest = { contents: [content] };
}
if ((params as GenerateContentRequest).systemInstruction) {
formattedRequest.systemInstruction = formatSystemInstruction(
(params as GenerateContentRequest).systemInstruction
);
}
return formattedRequest;
}
8 changes: 4 additions & 4 deletions packages/vertexai/src/types/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

import { Content } from './content';
import { Content, Part } from './content';
import {
FunctionCallingMode,
HarmBlockMethod,
Expand All @@ -40,7 +40,7 @@ export interface ModelParams extends BaseParams {
model: string;
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
}

/**
Expand All @@ -51,7 +51,7 @@ export interface GenerateContentRequest extends BaseParams {
contents: Content[];
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
}

/**
Expand Down Expand Up @@ -87,7 +87,7 @@ export interface StartChatParams extends BaseParams {
history?: Content[];
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
systemInstruction?: string | Part | Content;
}

/**
Expand Down

0 comments on commit 392380a

Please sign in to comment.