Skip to content

Commit

Permalink
Add some unit test cases (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsubox76 authored May 29, 2024
1 parent d25a6ff commit e899b91
Showing 1 changed file with 73 additions and 5 deletions.
78 changes: 73 additions & 5 deletions packages/main/src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ describe("GenerativeModel", () => {
value.includes(FunctionCallingMode.NONE) &&
value.includes("be friendly") &&
value.includes("temperature") &&
value.includes("testField") &&
value.includes(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
);
}),
Expand Down Expand Up @@ -151,7 +152,19 @@ describe("GenerativeModel", () => {
it("generateContent overrides model values", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
generationConfig: { temperature: 0 },
generationConfig: {
temperature: 0,
responseMimeType: "application/json",
responseSchema: {
type: FunctionDeclarationSchemaType.OBJECT,
properties: {
testField: {
type: FunctionDeclarationSchemaType.STRING,
properties: {},
},
},
},
},
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
Expand All @@ -174,7 +187,18 @@ describe("GenerativeModel", () => {
mockResponse as Response,
);
await genModel.generateContent({
generationConfig: { topK: 1 },
generationConfig: {
topK: 1,
responseSchema: {
type: FunctionDeclarationSchemaType.OBJECT,
properties: {
newTestField: {
type: FunctionDeclarationSchemaType.STRING,
properties: {},
},
},
},
},
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
Expand All @@ -201,6 +225,8 @@ describe("GenerativeModel", () => {
value.includes(FunctionCallingMode.AUTO) &&
value.includes("be formal") &&
value.includes("topK") &&
value.includes("newTestField") &&
!value.includes("testField") &&
value.includes(HarmCategory.HARM_CATEGORY_HARASSMENT)
);
}),
Expand All @@ -225,7 +251,6 @@ describe("GenerativeModel", () => {
mockResponse as Response,
);
await genModel.countTokens("hello");
console.log(makeRequestStub.args[0]);
expect(makeRequestStub).to.be.calledWith(
"models/my-model",
request.Task.COUNT_TOKENS,
Expand Down Expand Up @@ -276,9 +301,24 @@ describe("GenerativeModel", () => {
it("passes params through to chat.sendMessage", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
generationConfig: {
temperature: 0,
responseMimeType: "application/json",
responseSchema: {
type: FunctionDeclarationSchemaType.OBJECT,
properties: {
testField: {
type: FunctionDeclarationSchemaType.STRING,
properties: {},
},
},
},
},
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
expect(genModel.generationConfig.responseSchema.properties.testField).to
.exist;
const mockResponse = getMockResponse(
"unary-success-basic-reply-short.json",
);
Expand All @@ -292,7 +332,7 @@ describe("GenerativeModel", () => {
match.any,
false,
match((value: string) => {
return value.includes("be friendly");
return value.includes("be friendly") && value.includes("testField");
}),
{},
);
Expand All @@ -301,10 +341,25 @@ describe("GenerativeModel", () => {
it("startChat overrides model values", async () => {
const genModel = new GenerativeModel("apiKey", {
model: "my-model",
generationConfig: {
temperature: 0,
responseMimeType: "application/json",
responseSchema: {
type: FunctionDeclarationSchemaType.OBJECT,
properties: {
testField: {
type: FunctionDeclarationSchemaType.STRING,
properties: {},
},
},
},
},
tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
});
expect(genModel.generationConfig.responseSchema.properties.testField).to
.exist;
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE,
Expand All @@ -319,6 +374,17 @@ describe("GenerativeModel", () => {
await genModel
.startChat({
tools: [{ functionDeclarations: [{ name: "otherfunc" }] }],
generationConfig: {
responseSchema: {
type: FunctionDeclarationSchemaType.OBJECT,
properties: {
newTestField: {
type: FunctionDeclarationSchemaType.STRING,
properties: {},
},
},
},
},
toolConfig: {
functionCallingConfig: { mode: FunctionCallingMode.AUTO },
},
Expand All @@ -334,7 +400,9 @@ describe("GenerativeModel", () => {
return (
value.includes("otherfunc") &&
value.includes(FunctionCallingMode.AUTO) &&
value.includes("be formal")
value.includes("be formal") &&
value.includes("newTestField") &&
!value.includes("testField")
);
}),
{},
Expand Down

0 comments on commit e899b91

Please sign in to comment.