Skip to content

Commit

Permalink
test(firestore-vector-search): fix tests and add more coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Dec 6, 2024
1 parent b60b543 commit 2e3f6e2
Show file tree
Hide file tree
Showing 12 changed files with 683 additions and 283 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,61 +1,84 @@
import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini';
// import { config } from "../../src/config";
jest.resetModules();

// Mock GoogleGenerativeAI and its methods
const mockGetGenerativeModel = jest.fn();
const mockBatchEmbedContents = jest.fn();

// mock config
// jest.mock("../../src/config", () => ({
// ...jest.requireActual("../../src/config"),
// geminiApiKey: "test-api-key",
// }));
jest.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
getGenerativeModel: mockGetGenerativeModel,
})),
}));

jest.mock('../../src/config', () => ({
config: {
geminiApiKey: 'test-api-key',
},
}));

import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini';
import {GoogleGenerativeAI} from '@google/generative-ai';

describe('Gemini Embeddings', () => {
let embedClient;
let embedClient: GeminiAITextEmbedClient;

beforeEach(async () => {
// Reset mocks
jest.clearAllMocks();

beforeEach(() => {
// Mock return value for getGenerativeModel
mockGetGenerativeModel.mockReturnValue({
batchEmbedContents: mockBatchEmbedContents,
});

// Instantiate and initialize the client
embedClient = new GeminiAITextEmbedClient();
await embedClient.initialize();
});

describe('initialize', () => {
test('should properly initialize the client', async () => {
await embedClient.initialize();

expect(embedClient.client).toBeDefined();
// expect(GoogleGenerativeAI).toHaveBeenCalledWith(config.geminiApiKey);
expect(GoogleGenerativeAI).toHaveBeenCalledWith('test-api-key');
});
});

describe('getEmbeddings', () => {
test('should return embeddings for a batch of text', async () => {
const mockEmbedContent = jest
.fn()
.mockResolvedValue({embedding: [1, 2, 3]});
embedClient.client = {
getGenerativeModel: jest.fn(() => ({
embedContent: mockEmbedContent,
})),
};
// Mock batchEmbedContents to resolve with embeddings
mockBatchEmbedContents.mockResolvedValueOnce({
embeddings: [{values: [1, 2, 3]}, {values: [4, 5, 6]}],
});

const batch = ['text1', 'text2'];
const results = await embedClient.getEmbeddings(batch);

expect(mockEmbedContent).toHaveBeenCalledTimes(batch.length);
expect(mockBatchEmbedContents).toHaveBeenCalledWith({
requests: [
{content: {parts: [{text: 'text1'}], role: 'user'}},
{content: {parts: [{text: 'text2'}], role: 'user'}},
],
});

expect(results).toEqual([
[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
]);
});

test('should throw an error if the embedding process fails', async () => {
embedClient.client = {
getGenerativeModel: jest.fn(() => ({
embedContent: jest
.fn()
.mockRejectedValue(new Error('Embedding failed')),
})),
};
// Mock batchEmbedContents to throw an error
mockBatchEmbedContents.mockRejectedValueOnce(
new Error('Embedding failed')
);

await expect(embedClient.getEmbeddings(['text'])).rejects.toThrow(
'Error with embedding'
);

expect(mockBatchEmbedContents).toHaveBeenCalledWith({
requests: [{content: {parts: [{text: 'text'}], role: 'user'}}],
});
});
});
});
160 changes: 160 additions & 0 deletions firestore-vector-search/functions/__tests__/embeddings/genkit.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
jest.resetModules();

// Mocking `@genkit-ai/googleai` and `@genkit-ai/vertexai`
jest.mock('@genkit-ai/googleai', () => ({
googleAI: jest.fn(),
textEmbeddingGecko001: 'gecko-001-model',
}));

jest.mock('@genkit-ai/vertexai', () => ({
vertexAI: jest.fn(),
textEmbedding004: 'text-embedding-004-model',
}));

jest.mock('../../src/config', () => ({
config: {
geminiApiKey: 'test-api-key',
location: 'us-central1',
},
}));

import {GenkitEmbedClient} from '../../src/embeddings/client/genkit';
import {genkit} from 'genkit';
import {vertexAI} from '@genkit-ai/vertexai';
import {googleAI} from '@genkit-ai/googleai';

// Mock the genkit client with properly structured responses
const mockEmbedMany = jest.fn();
const mockEmbed = jest.fn();
jest.mock('genkit', () => ({
genkit: jest.fn().mockImplementation(() => ({
embedMany: mockEmbedMany,
embed: mockEmbed,
})),
}));

describe('GenkitEmbedClient', () => {
let embedClient: GenkitEmbedClient;
let mockVertexAI: jest.Mock;
let mockGoogleAI: jest.Mock;

beforeEach(() => {
jest.clearAllMocks();
mockVertexAI = vertexAI as jest.Mock;
mockGoogleAI = googleAI as jest.Mock;
});

describe('constructor', () => {
test('should initialize with Vertex AI provider', () => {
embedClient = new GenkitEmbedClient({
provider: 'vertexai',
batchSize: 100,
dimension: 768,
});

expect(embedClient.provider).toBe('vertexai');
expect(embedClient.embedder).toBe('text-embedding-004-model');
expect(mockVertexAI).toHaveBeenCalledWith({
location: 'us-central1',
});
expect(genkit).toHaveBeenCalledWith({
plugins: [undefined], // because the mock returns undefined
});
});

test('should initialize with Google AI provider', () => {
embedClient = new GenkitEmbedClient({
provider: 'googleai',
batchSize: 100,
dimension: 768,
});

expect(embedClient.provider).toBe('googleai');
expect(embedClient.embedder).toBe('gecko-001-model');
expect(mockGoogleAI).toHaveBeenCalledWith({
apiKey: 'test-api-key',
});
expect(genkit).toHaveBeenCalledWith({
plugins: [undefined], // because the mock returns undefined
});
});
});

describe('getEmbeddings', () => {
beforeEach(() => {
embedClient = new GenkitEmbedClient({
provider: 'vertexai',
batchSize: 100,
dimension: 768,
});
});

test('should return embeddings for a batch of inputs', async () => {
const mockResults = [{embedding: [1, 2, 3]}, {embedding: [4, 5, 6]}];
mockEmbedMany.mockResolvedValueOnce(mockResults);

const inputs = ['input1', 'input2'];
const embeddings = await embedClient.getEmbeddings(inputs);

expect(mockEmbedMany).toHaveBeenCalledWith({
embedder: embedClient.embedder,
content: inputs,
});

expect(embeddings).toEqual([
[1, 2, 3],
[4, 5, 6],
]);
});

test('should throw an error if embedding fails', async () => {
mockEmbedMany.mockRejectedValueOnce(new Error('Embedding failed'));

await expect(embedClient.getEmbeddings(['input'])).rejects.toThrow(
'Embedding failed'
);

expect(mockEmbedMany).toHaveBeenCalledWith({
embedder: embedClient.embedder,
content: ['input'],
});
});
});

describe('getSingleEmbedding', () => {
beforeEach(() => {
embedClient = new GenkitEmbedClient({
provider: 'googleai',
batchSize: 100,
dimension: 768,
});
});

test('should return a single embedding for an input', async () => {
mockEmbed.mockResolvedValueOnce([7, 8, 9]); // Changed to return array directly

const input = 'input1';
const embedding = await embedClient.getSingleEmbedding(input);

expect(mockEmbed).toHaveBeenCalledWith({
embedder: embedClient.embedder,
content: input,
});

expect(embedding).toEqual([7, 8, 9]);
});

test('should throw an error if embedding fails', async () => {
mockEmbed.mockRejectedValueOnce(new Error('Embedding failed'));

await expect(embedClient.getSingleEmbedding('input')).rejects.toThrow(
'Embedding failed'
);

expect(mockEmbed).toHaveBeenCalledWith({
embedder: embedClient.embedder,
content: 'input',
});
});
});
});
2 changes: 1 addition & 1 deletion firestore-vector-search/functions/jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module.exports = {
rootDir: './',
globals: {
'ts-jest': {
tsConfig: '<rootDir>/__tests__/tsconfig.json',
tsconfig: '<rootDir>/tsconfig.test.json', // Correct reference to test-specific config
},
fetch: global.fetch,
},
Expand Down
Loading

0 comments on commit 2e3f6e2

Please sign in to comment.