From 3f9eee6589967164b8c1b86481468eaa5347afbd Mon Sep 17 00:00:00 2001 From: Yvonne Yu Date: Tue, 26 Nov 2024 09:36:41 -0800 Subject: [PATCH] feat: enable dynamic retrieval for Google Search Retrieval grounding PiperOrigin-RevId: 700370098 --- src/models/test/models_test.ts | 22 +++++++++++++--------- src/types/content.ts | 19 +++++++++++++++++-- system_test/end_to_end_sample_test.ts | 6 +++++- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/src/models/test/models_test.ts b/src/models/test/models_test.ts index 64877ea3..d9701428 100644 --- a/src/models/test/models_test.ts +++ b/src/models/test/models_test.ts @@ -29,6 +29,7 @@ import { HarmBlockThreshold, HarmCategory, HarmProbability, + Mode, RequestOptions, SafetyRating, SafetySetting, @@ -181,7 +182,10 @@ const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [ const TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL: GoogleSearchRetrievalTool[] = [ { googleSearchRetrieval: { - disableAttribution: false, + dynamicRetrievalConfig: { + dynamicThreshold: 0.5, + mode: Mode.MODE_DYNAMIC, + }, }, }, ]; @@ -332,7 +336,7 @@ describe('GenerativeModel startChat', () => { history: TEST_USER_CHAT_MESSAGE, }); const expectedBody = - '{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}'; + '{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}'; await chat.sendMessage(req); // @ts-ignore const actualBody = fetchSpy.calls.allArgs()[0][1].body; @@ -357,7 +361,7 @@ describe('GenerativeModel startChat', () => { tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); const expectedBody = - '{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}'; + '{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}'; await chat.sendMessage(req); // @ts-ignore const actualBody = fetchSpy.calls.allArgs()[0][1].body; @@ -550,7 +554,7 @@ describe('GenerativeModelPreview startChat', () => { history: TEST_USER_CHAT_MESSAGE, }); const expectedBody = - '{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}'; + '{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}'; await chat.sendMessage(req); // @ts-ignore const actualBody = fetchSpy.calls.allArgs()[0][1].body; @@ -575,7 +579,7 @@ describe('GenerativeModelPreview startChat', () => { tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }); const expectedBody = - '{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}'; + '{"contents":[{"role":"user","parts":[{"text":"How are you doing today?"}]},{"role":"user","parts":[{"text":"How are you doing today?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}'; await chat.sendMessage(req); // @ts-ignore const actualBody = fetchSpy.calls.allArgs()[0][1].body; @@ -1161,7 +1165,7 @@ describe('GenerativeModel generateContent', () => { tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }; const expectedBody = - '{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}'; + '{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}'; await model.generateContent(req); // @ts-ignore const actualBody = fetchSpy.calls.allArgs()[0][1].body; @@ -1638,7 +1642,7 @@ describe('GenerativeModelPreview generateContent', () => { tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, }; const expectedBody = - '{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}'; + '{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}'; await model.generateContent(req); // @ts-ignore const actualBody = fetchSpy.calls.allArgs()[0][1].body; @@ -1942,7 +1946,7 @@ describe('GenerativeModel generateContentStream', () => { }; spyOn(PostFetchFunctions, 'processStream').and.resolveTo(expectedResult); const expectedBody = - '{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}'; + '{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}'; await model.generateContent(req); // @ts-ignore const actualBody = fetchSpy.calls.allArgs()[0][1].body; @@ -2259,7 +2263,7 @@ describe('GenerativeModelPreview generateContentStream', () => { }; spyOn(PostFetchFunctions, 'processStream').and.resolveTo(expectedResult); const expectedBody = - '{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"disableAttribution":false}}]}'; + '{"contents":[{"role":"user","parts":[{"text":"What is the weater like in Boston?"}]}],"tools":[{"googleSearchRetrieval":{"dynamicRetrievalConfig":{"dynamicThreshold":0.5,"mode":"MODE_DYNAMIC"}}}]}'; await model.generateContent(req); // @ts-ignore const actualBody = fetchSpy.calls.allArgs()[0][1].body; diff --git a/src/types/content.ts b/src/types/content.ts index e1967821..d7b8a94d 100644 --- a/src/types/content.ts +++ b/src/types/content.ts @@ -969,11 +969,26 @@ export declare interface Retrieval { disableAttribution?: boolean; } +export enum Mode { + MODE_UNSPECIFIED = 'MODE_UNSPECIFIED', + MODE_DYNAMIC = 'MODE_DYNAMIC', +} + +/** Describes the options to customize dynamic retrieval. */ +export declare interface DynamicRetrievalConfig { + /** Optional. The threshold to be used in dynamic retrieval. If not set, a system default value is used. */ + dynamicThreshold?: number; + /** The mode of the predictor to be used in dynamic retrieval. */ + mode?: Mode; +} + /** * Tool to retrieve public web data for grounding, powered by Google. */ -// eslint-disable-next-line @typescript-eslint/no-empty-interface -export declare interface GoogleSearchRetrieval {} +export declare interface GoogleSearchRetrieval { + /** Specifies the dynamic retrieval configuration for the given source. */ + dynamicRetrievalConfig?: DynamicRetrievalConfig; +} /** * Retrieve from Vertex AI Search datastore for grounding. diff --git a/system_test/end_to_end_sample_test.ts b/system_test/end_to_end_sample_test.ts index 26d9abf5..af9d59c4 100644 --- a/system_test/end_to_end_sample_test.ts +++ b/system_test/end_to_end_sample_test.ts @@ -25,6 +25,7 @@ import { VertexAI, GenerateContentResponseHandler, GoogleApiError, + Mode, } from '../src'; import {FunctionDeclarationSchemaType} from '../src/types'; @@ -87,7 +88,10 @@ const TOOLS_WITH_FUNCTION_DECLARATION: FunctionDeclarationsTool[] = [ const TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL: GoogleSearchRetrievalTool[] = [ { googleSearchRetrieval: { - disableAttribution: false, + dynamicRetrievalConfig: { + dynamicThreshold: 0.2, + mode: Mode.MODE_DYNAMIC, + }, }, }, ];