Skip to content

Commit

Permalink
test(firestore-multimodal-genai): add genkit client unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Nov 29, 2024
1 parent 619cae2 commit c91b358
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 5 deletions.
217 changes: 217 additions & 0 deletions firestore-multimodal-genai/functions/__tests__/genkit/client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import {GenkitGenerativeClient} from '../../src/generative-client/genkit';
import {logger} from 'firebase-functions/v1';
import {GenerateResponse, genkit} from 'genkit';
import {googleAI} from '@genkit-ai/googleai';
import {vertexAI} from '@genkit-ai/vertexai';
import {Config} from '../../src/config.js';
import {HarmBlockThreshold, HarmCategory} from '@google/generative-ai';

// Mock the genkit library
jest.mock('genkit', () => ({
genkit: jest.fn().mockReturnValue({generate: jest.fn()}),
}));

jest.mock('@genkit-ai/googleai', () => ({
googleAI: jest.fn().mockReturnValue({name: 'googleai'}),
gemini10Pro: {name: 'googleai/gemini-1.0-pro', withVersion: jest.fn()},
gemini15Flash: {name: 'googleai/gemini-1.5-flash', withVersion: jest.fn()},
gemini15Pro: {name: 'googleai/gemini-1.5-pro', withVersion: jest.fn()},
}));

jest.mock('@genkit-ai/vertexai', () => ({
vertexAI: jest.fn().mockReturnValue({name: 'vertexai'}),
gemini10Pro: {name: 'vertexai/gemini-1.0-pro', withVersion: jest.fn()},
gemini15Flash: {name: 'vertexai/gemini-1.5-flash', withVersion: jest.fn()},
gemini15Pro: {name: 'vertexai/gemini-1.5-pro', withVersion: jest.fn()},
}));

jest.mock('../../src/generative-client/image_utils.ts', () => ({
getImageBase64: jest.fn(() => Promise.resolve('base64EncodedImage')),
}));

describe('GenkitGenerativeClient', () => {
const mockConfig: Config = {
vertex: {
model: 'gemini-1.5-flash',
},
googleAi: {
model: 'gemini-1.5-flash',
apiKey: 'test-api-key',
},
model: 'gemini-1.5-flash',
location: 'us-central1',
projectId: 'test-project',
instanceId: 'test-instance',
prompt: 'Test prompt',
responseField: 'output',
collectionName: 'users/{uid}/discussions/{discussionId}/messages',
temperature: 0.7,
topP: 0.9,
topK: 50,
candidates: {
field: 'candidates',
count: 1,
shouldIncludeCandidatesField: false,
},
maxOutputTokens: 256,
maxOutputTokensVertex: 1024,
provider: 'google-ai',
apiKey: 'test-api-key',
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
},
],
bucketName: 'test-bucket',
imageField: 'image',
};

const mockGenerateResponse = {
text: 'Generated text response',
finishReason: 'stop',
usage: {
inputTokens: 10,
outputTokens: 20,
totalTokens: 30,
},
custom: null,
raw: null,
};

beforeEach(() => {
jest.clearAllMocks();
});

it('should initialize with correct plugin and client for Google AI', () => {
new GenkitGenerativeClient(mockConfig);

expect(googleAI).toHaveBeenCalledWith({apiKey: 'test-api-key'});
expect(genkit).toHaveBeenCalledWith({
plugins: [expect.anything()],
});
});

it('should initialize with correct plugin and client for Vertex AI', () => {
const vertexConfig: Config = {
...mockConfig,
provider: 'vertex-ai',
googleAi: {model: 'gemini-1.5-flash', apiKey: '123'},
model: 'gemini-1.5-flash',
};
new GenkitGenerativeClient(vertexConfig);

expect(vertexAI).toHaveBeenCalledWith({location: 'us-central1'});
expect(genkit).toHaveBeenCalledWith({
plugins: [expect.anything()],
});
});

it('should throw an error if no API key is provided for Google AI', () => {
const invalidConfig: Config = {
...mockConfig,
googleAi: {model: 'gemini-1.5-flash', apiKey: undefined},
};

expect(() => new GenkitGenerativeClient(invalidConfig)).toThrow(
'API key required for Google AI.'
);
});

it('should throw an error if an invalid provider is specified', () => {
const invalidConfig: Config = {...mockConfig, provider: 'invalid-provider'};

expect(() => new GenkitGenerativeClient(invalidConfig)).toThrow(
'Invalid provider specified.'
);
});

it('should create the correct model reference', () => {
const modelReference = GenkitGenerativeClient.createModelReference(
'gemini-1.5-flash',
'google-ai'
);

expect(modelReference.name).toBe('googleai/gemini-1.5-flash');
});

it('should call generate with correct options and return response', async () => {
const client = new GenkitGenerativeClient(mockConfig);
client.client.generate = jest.fn(() =>
Promise.resolve(mockGenerateResponse as unknown as GenerateResponse<any>)
);

const response = await client.generate('Test prompt');

expect(client.client.generate).toHaveBeenCalledWith({
messages: [
{
role: 'user',
content: [{text: 'Test prompt'}],
},
],
model: expect.any(Object),
config: expect.any(Object),
});

expect(response).toEqual({candidates: ['Generated text response']});
});

it('should process an image if provided', async () => {
const client = new GenkitGenerativeClient(mockConfig);
client.client.generate = jest.fn(() =>
Promise.resolve(mockGenerateResponse as unknown as GenerateResponse<any>)
);

const response = await client.generate('Test prompt', {
image: 'path/to/image.jpg',
});

expect(client.client.generate).toHaveBeenCalledWith({
messages: [
{
role: 'user',
content: [
{text: 'Test prompt'},
{media: {url: ''}},
],
},
],
model: {
name: 'googleai/gemini-1.5-flash',
withVersion: expect.any(Function),
},
config: {
topP: 0.9,
topK: 50,
temperature: 0.7,
maxOutputTokens: 256,
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
},
],
},
image: 'path/to/image.jpg',
});

expect(response).toEqual({candidates: ['Generated text response']});
});

it('should log an error and throw if generate fails', async () => {
const client = new GenkitGenerativeClient(mockConfig);
const error = new Error('Generation failed');
client.client.generate = jest.fn(() => Promise.reject(error));
logger.error = jest.fn();

await expect(client.generate('Test prompt')).rejects.toThrow(
'Content generation failed.'
);

expect(logger.error).toHaveBeenCalledWith(
'Failed to generate content:',
error
);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import config from '../../src/config';
import {generateText} from '../../src/index';
import {expectToProcessCorrectly} from '../util';

// Type definitions for improved readability
type DocumentData = admin.firestore.DocumentData;
type WrappedGenerateText = WrappedFunction<
Change<DocumentSnapshot | undefined>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import {
gemini15Flash as gemini15FlashGoogleAI,
gemini15Pro as gemini15ProGoogleAI,
} from '@genkit-ai/googleai';
import vertexAI, {
import {
vertexAI,
PluginOptions as PluginOptionsVertexAI,
gemini10Pro as gemini10ProVertexAI,
gemini15Flash as gemini15FlashVertexAI,
Expand Down Expand Up @@ -141,8 +142,6 @@ export class GenkitGenerativeClient extends GenerativeClient<

const generateOptions = {...this.generateOptions, ...options};

logger.debug('Generating response with Genkit', {promptText});

let imageBase64: string | undefined;

if (options?.image) {
Expand All @@ -151,7 +150,6 @@ export class GenkitGenerativeClient extends GenerativeClient<
options.image,
this.provider as 'google-ai' | 'vertex-ai'
);
logger.info('Image successfully converted to Base64.');
} catch (error) {
logger.error('Failed to process image:', error);
throw new Error('Image processing failed.');
Expand Down

0 comments on commit c91b358

Please sign in to comment.