From e899b91e4b79e77359ac21ed340e32a354ae2aeb Mon Sep 17 00:00:00 2001 From: Christina Holland Date: Wed, 29 May 2024 13:52:50 -0700 Subject: [PATCH] Add some unit test cases (#159) --- .../main/src/models/generative-model.test.ts | 78 +++++++++++++++++-- 1 file changed, 73 insertions(+), 5 deletions(-) diff --git a/packages/main/src/models/generative-model.test.ts b/packages/main/src/models/generative-model.test.ts index 7de4715f..b1881aac 100644 --- a/packages/main/src/models/generative-model.test.ts +++ b/packages/main/src/models/generative-model.test.ts @@ -114,6 +114,7 @@ describe("GenerativeModel", () => { value.includes(FunctionCallingMode.NONE) && value.includes("be friendly") && value.includes("temperature") && + value.includes("testField") && value.includes(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE) ); }), @@ -151,7 +152,19 @@ describe("GenerativeModel", () => { it("generateContent overrides model values", async () => { const genModel = new GenerativeModel("apiKey", { model: "my-model", - generationConfig: { temperature: 0 }, + generationConfig: { + temperature: 0, + responseMimeType: "application/json", + responseSchema: { + type: FunctionDeclarationSchemaType.OBJECT, + properties: { + testField: { + type: FunctionDeclarationSchemaType.STRING, + properties: {}, + }, + }, + }, + }, safetySettings: [ { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, @@ -174,7 +187,18 @@ describe("GenerativeModel", () => { mockResponse as Response, ); await genModel.generateContent({ - generationConfig: { topK: 1 }, + generationConfig: { + topK: 1, + responseSchema: { + type: FunctionDeclarationSchemaType.OBJECT, + properties: { + newTestField: { + type: FunctionDeclarationSchemaType.STRING, + properties: {}, + }, + }, + }, + }, safetySettings: [ { category: HarmCategory.HARM_CATEGORY_HARASSMENT, @@ -201,6 +225,8 @@ describe("GenerativeModel", () => { value.includes(FunctionCallingMode.AUTO) && value.includes("be formal") && value.includes("topK") && + value.includes("newTestField") && + !value.includes("testField") && value.includes(HarmCategory.HARM_CATEGORY_HARASSMENT) ); }), @@ -225,7 +251,6 @@ describe("GenerativeModel", () => { mockResponse as Response, ); await genModel.countTokens("hello"); - console.log(makeRequestStub.args[0]); expect(makeRequestStub).to.be.calledWith( "models/my-model", request.Task.COUNT_TOKENS, @@ -276,9 +301,24 @@ describe("GenerativeModel", () => { it("passes params through to chat.sendMessage", async () => { const genModel = new GenerativeModel("apiKey", { model: "my-model", + generationConfig: { + temperature: 0, + responseMimeType: "application/json", + responseSchema: { + type: FunctionDeclarationSchemaType.OBJECT, + properties: { + testField: { + type: FunctionDeclarationSchemaType.STRING, + properties: {}, + }, + }, + }, + }, systemInstruction: { role: "system", parts: [{ text: "be friendly" }] }, }); expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly"); + expect(genModel.generationConfig.responseSchema.properties.testField).to + .exist; const mockResponse = getMockResponse( "unary-success-basic-reply-short.json", ); @@ -292,7 +332,7 @@ describe("GenerativeModel", () => { match.any, false, match((value: string) => { - return value.includes("be friendly"); + return value.includes("be friendly") && value.includes("testField"); }), {}, ); @@ -301,10 +341,25 @@ describe("GenerativeModel", () => { it("startChat overrides model values", async () => { const genModel = new GenerativeModel("apiKey", { model: "my-model", + generationConfig: { + temperature: 0, + responseMimeType: "application/json", + responseSchema: { + type: FunctionDeclarationSchemaType.OBJECT, + properties: { + testField: { + type: FunctionDeclarationSchemaType.STRING, + properties: {}, + }, + }, + }, + }, tools: [{ functionDeclarations: [{ name: "myfunc" }] }], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, systemInstruction: { role: "system", parts: [{ text: "be friendly" }] }, }); + expect(genModel.generationConfig.responseSchema.properties.testField).to + .exist; expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal( FunctionCallingMode.NONE, @@ -319,6 +374,17 @@ describe("GenerativeModel", () => { await genModel .startChat({ tools: [{ functionDeclarations: [{ name: "otherfunc" }] }], + generationConfig: { + responseSchema: { + type: FunctionDeclarationSchemaType.OBJECT, + properties: { + newTestField: { + type: FunctionDeclarationSchemaType.STRING, + properties: {}, + }, + }, + }, + }, toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO }, }, @@ -334,7 +400,9 @@ describe("GenerativeModel", () => { return ( value.includes("otherfunc") && value.includes(FunctionCallingMode.AUTO) && - value.includes("be formal") + value.includes("be formal") && + value.includes("newTestField") && + !value.includes("testField") ); }), {},