Skip to content

Commit

Permalink
refactor(firestore-vector-search): use Firebase Genkit wherever possi…
Browse files Browse the repository at this point in the history
…ble. (#603)

* refactor(firestore-vector-search): use genkit where possible

* chore(firestore-vector-search): update CHANGELOG and bump ext version

* test(firestore-vector-search): fix tests and add more coverage
  • Loading branch information
cabljac authored Dec 10, 2024
1 parent 81cfd8d commit 9e11408
Show file tree
Hide file tree
Showing 17 changed files with 2,631 additions and 518 deletions.
4 changes: 4 additions & 0 deletions firestore-vector-search/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## Version 0.0.6

refactor - use Firebase Genkit where possible

## Version 0.0.5

fix - fix backfill and fix npm audit
Expand Down
2 changes: 1 addition & 1 deletion firestore-vector-search/extension.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

name: firestore-vector-search
version: 0.0.5
version: 0.0.6
specVersion: v1beta

tags:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,61 +1,84 @@
import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini';
// import { config } from "../../src/config";
jest.resetModules();

// Mock GoogleGenerativeAI and its methods
const mockGetGenerativeModel = jest.fn();
const mockBatchEmbedContents = jest.fn();

// mock config
// jest.mock("../../src/config", () => ({
// ...jest.requireActual("../../src/config"),
// geminiApiKey: "test-api-key",
// }));
jest.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
getGenerativeModel: mockGetGenerativeModel,
})),
}));

jest.mock('../../src/config', () => ({
config: {
geminiApiKey: 'test-api-key',
},
}));

import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini';
import {GoogleGenerativeAI} from '@google/generative-ai';

describe('Gemini Embeddings', () => {
let embedClient;
let embedClient: GeminiAITextEmbedClient;

beforeEach(async () => {
// Reset mocks
jest.clearAllMocks();

beforeEach(() => {
// Mock return value for getGenerativeModel
mockGetGenerativeModel.mockReturnValue({
batchEmbedContents: mockBatchEmbedContents,
});

// Instantiate and initialize the client
embedClient = new GeminiAITextEmbedClient();
await embedClient.initialize();
});

describe('initialize', () => {
test('should properly initialize the client', async () => {
await embedClient.initialize();

expect(embedClient.client).toBeDefined();
// expect(GoogleGenerativeAI).toHaveBeenCalledWith(config.geminiApiKey);
expect(GoogleGenerativeAI).toHaveBeenCalledWith('test-api-key');
});
});

describe('getEmbeddings', () => {
test('should return embeddings for a batch of text', async () => {
const mockEmbedContent = jest
.fn()
.mockResolvedValue({embedding: [1, 2, 3]});
embedClient.client = {
getGenerativeModel: jest.fn(() => ({
embedContent: mockEmbedContent,
})),
};
// Mock batchEmbedContents to resolve with embeddings
mockBatchEmbedContents.mockResolvedValueOnce({
embeddings: [{values: [1, 2, 3]}, {values: [4, 5, 6]}],
});

const batch = ['text1', 'text2'];
const results = await embedClient.getEmbeddings(batch);

expect(mockEmbedContent).toHaveBeenCalledTimes(batch.length);
expect(mockBatchEmbedContents).toHaveBeenCalledWith({
requests: [
{content: {parts: [{text: 'text1'}], role: 'user'}},
{content: {parts: [{text: 'text2'}], role: 'user'}},
],
});

expect(results).toEqual([
[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
]);
});

test('should throw an error if the embedding process fails', async () => {
embedClient.client = {
getGenerativeModel: jest.fn(() => ({
embedContent: jest
.fn()
.mockRejectedValue(new Error('Embedding failed')),
})),
};
// Mock batchEmbedContents to throw an error
mockBatchEmbedContents.mockRejectedValueOnce(
new Error('Embedding failed')
);

await expect(embedClient.getEmbeddings(['text'])).rejects.toThrow(
'Error with embedding'
);

expect(mockBatchEmbedContents).toHaveBeenCalledWith({
requests: [{content: {parts: [{text: 'text'}], role: 'user'}}],
});
});
});
});
160 changes: 160 additions & 0 deletions firestore-vector-search/functions/__tests__/embeddings/genkit.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Clear Jest's module registry so the mocked modules registered below are
// the ones resolved when the client under test is imported.
jest.resetModules();

// Replace `@genkit-ai/googleai` with a stub: `googleAI` becomes a bare
// jest.fn() (so it returns undefined) and the embedder reference becomes a
// placeholder string that the tests can compare against.
jest.mock('@genkit-ai/googleai', () => ({
googleAI: jest.fn(),
textEmbeddingGecko001: 'gecko-001-model',
}));

// Same treatment for `@genkit-ai/vertexai`: a bare jest.fn() plugin factory
// plus a placeholder string standing in for the embedder reference.
jest.mock('@genkit-ai/vertexai', () => ({
vertexAI: jest.fn(),
textEmbedding004: 'text-embedding-004-model',
}));

// Fixed extension config. The constructor tests in this file assert that
// these exact values are forwarded to the provider plugin factories
// (`apiKey` for googleAI, `location` for vertexAI).
jest.mock('../../src/config', () => ({
config: {
geminiApiKey: 'test-api-key',
location: 'us-central1',
},
}));

import {GenkitEmbedClient} from '../../src/embeddings/client/genkit';
import {genkit} from 'genkit';
import {vertexAI} from '@genkit-ai/vertexai';
import {googleAI} from '@genkit-ai/googleai';

// Shared spies for the two embedding methods on the fake genkit instance.
// Tests program their resolved/rejected values and inspect their calls.
// NOTE: per the Jest docs, `jest.mock` calls are hoisted above imports and a
// mock factory may only reference outer variables whose names start with
// `mock` — keep that prefix if these are renamed.
const mockEmbedMany = jest.fn();
const mockEmbed = jest.fn();
// `genkit(...)` is replaced by a jest.fn() whose implementation returns an
// object exposing only `embedMany` and `embed`. The inner arrow function reads
// the spies lazily, i.e. only when `genkit()` is actually invoked.
// NOTE(review): this presumably relies on the client calling `genkit()` at
// construction time rather than at module import — confirm against the client.
jest.mock('genkit', () => ({
genkit: jest.fn().mockImplementation(() => ({
embedMany: mockEmbedMany,
embed: mockEmbed,
})),
}));

/**
 * Test suite for GenkitEmbedClient.
 *
 * Covers three areas: provider wiring done by the constructor, batch
 * embedding through `getEmbeddings`, and single-input embedding through
 * `getSingleEmbedding`. All genkit / plugin modules are mocked at file level,
 * so no network traffic happens here.
 */
describe('GenkitEmbedClient', () => {
  // The plugin factories are jest.fn() stubs (see the module mocks at the top
  // of the file); view them as mocks so call assertions type-check.
  const vertexPluginFactory = vertexAI as jest.Mock;
  const googlePluginFactory = googleAI as jest.Mock;

  // Builders for the two client configurations used throughout the suite.
  const buildVertexClient = (): GenkitEmbedClient =>
    new GenkitEmbedClient({
      provider: 'vertexai',
      batchSize: 100,
      dimension: 768,
    });

  const buildGoogleClient = (): GenkitEmbedClient =>
    new GenkitEmbedClient({
      provider: 'googleai',
      batchSize: 100,
      dimension: 768,
    });

  // Wipe recorded calls before every test so assertions only see the calls
  // made by the test itself (and by its own nested beforeEach).
  beforeEach(() => {
    jest.clearAllMocks();
  });

  describe('constructor', () => {
    test('should initialize with Vertex AI provider', () => {
      const client = buildVertexClient();

      expect(client.provider).toBe('vertexai');
      expect(client.embedder).toBe('text-embedding-004-model');
      expect(vertexPluginFactory).toHaveBeenCalledWith({
        location: 'us-central1',
      });
      // The stubbed plugin factory yields undefined, hence the plugin list.
      expect(genkit).toHaveBeenCalledWith({
        plugins: [undefined],
      });
    });

    test('should initialize with Google AI provider', () => {
      const client = buildGoogleClient();

      expect(client.provider).toBe('googleai');
      expect(client.embedder).toBe('gecko-001-model');
      expect(googlePluginFactory).toHaveBeenCalledWith({
        apiKey: 'test-api-key',
      });
      // The stubbed plugin factory yields undefined, hence the plugin list.
      expect(genkit).toHaveBeenCalledWith({
        plugins: [undefined],
      });
    });
  });

  describe('getEmbeddings', () => {
    let client: GenkitEmbedClient;

    beforeEach(() => {
      client = buildVertexClient();
    });

    test('should return embeddings for a batch of inputs', async () => {
      // embedMany resolves to one `{embedding}` record per input.
      mockEmbedMany.mockResolvedValueOnce([
        {embedding: [1, 2, 3]},
        {embedding: [4, 5, 6]},
      ]);

      const texts = ['input1', 'input2'];
      const vectors = await client.getEmbeddings(texts);

      expect(mockEmbedMany).toHaveBeenCalledWith({
        embedder: client.embedder,
        content: texts,
      });
      expect(vectors).toEqual([
        [1, 2, 3],
        [4, 5, 6],
      ]);
    });

    test('should throw an error if embedding fails', async () => {
      const failure = new Error('Embedding failed');
      mockEmbedMany.mockRejectedValueOnce(failure);

      const pending = client.getEmbeddings(['input']);
      await expect(pending).rejects.toThrow('Embedding failed');

      // The request must still have been issued with the expected payload.
      expect(mockEmbedMany).toHaveBeenCalledWith({
        embedder: client.embedder,
        content: ['input'],
      });
    });
  });

  describe('getSingleEmbedding', () => {
    let client: GenkitEmbedClient;

    beforeEach(() => {
      client = buildGoogleClient();
    });

    test('should return a single embedding for an input', async () => {
      // Here the mocked `embed` resolves to the raw vector itself.
      mockEmbed.mockResolvedValueOnce([7, 8, 9]);

      const text = 'input1';
      const vector = await client.getSingleEmbedding(text);

      expect(mockEmbed).toHaveBeenCalledWith({
        embedder: client.embedder,
        content: text,
      });
      expect(vector).toEqual([7, 8, 9]);
    });

    test('should throw an error if embedding fails', async () => {
      const failure = new Error('Embedding failed');
      mockEmbed.mockRejectedValueOnce(failure);

      const pending = client.getSingleEmbedding('input');
      await expect(pending).rejects.toThrow('Embedding failed');

      // The request must still have been issued with the expected payload.
      expect(mockEmbed).toHaveBeenCalledWith({
        embedder: client.embedder,
        content: 'input',
      });
    });
  });
});
2 changes: 1 addition & 1 deletion firestore-vector-search/functions/jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module.exports = {
rootDir: './',
globals: {
'ts-jest': {
tsConfig: '<rootDir>/__tests__/tsconfig.json',
tsconfig: '<rootDir>/tsconfig.test.json', // Correct reference to test-specific config
},
fetch: global.fetch,
},
Expand Down
Loading

0 comments on commit 9e11408

Please sign in to comment.