Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(firestore-vector-search): use Firebase Genkit wherever possible. #603

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions firestore-vector-search/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## Version 0.0.6

refactor - use Firebase Genkit where possible

## Version 0.0.5

fix - fix backfill and fix npm audit
Expand Down
2 changes: 1 addition & 1 deletion firestore-vector-search/extension.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

name: firestore-vector-search
version: 0.0.5
version: 0.0.6
specVersion: v1beta

tags:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,61 +1,84 @@
import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini';
// import { config } from "../../src/config";
jest.resetModules();

// Mock GoogleGenerativeAI and its methods
const mockGetGenerativeModel = jest.fn();
const mockBatchEmbedContents = jest.fn();

// mock config
// jest.mock("../../src/config", () => ({
// ...jest.requireActual("../../src/config"),
// geminiApiKey: "test-api-key",
// }));
jest.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
getGenerativeModel: mockGetGenerativeModel,
})),
}));

jest.mock('../../src/config', () => ({
config: {
geminiApiKey: 'test-api-key',
},
}));

import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini';
import {GoogleGenerativeAI} from '@google/generative-ai';

describe('Gemini Embeddings', () => {
let embedClient;
let embedClient: GeminiAITextEmbedClient;

beforeEach(async () => {
// Reset mocks
jest.clearAllMocks();

beforeEach(() => {
// Mock return value for getGenerativeModel
mockGetGenerativeModel.mockReturnValue({
batchEmbedContents: mockBatchEmbedContents,
});

// Instantiate and initialize the client
embedClient = new GeminiAITextEmbedClient();
await embedClient.initialize();
});

describe('initialize', () => {
test('should properly initialize the client', async () => {
await embedClient.initialize();

expect(embedClient.client).toBeDefined();
// expect(GoogleGenerativeAI).toHaveBeenCalledWith(config.geminiApiKey);
expect(GoogleGenerativeAI).toHaveBeenCalledWith('test-api-key');
});
});

describe('getEmbeddings', () => {
test('should return embeddings for a batch of text', async () => {
const mockEmbedContent = jest
.fn()
.mockResolvedValue({embedding: [1, 2, 3]});
embedClient.client = {
getGenerativeModel: jest.fn(() => ({
embedContent: mockEmbedContent,
})),
};
// Mock batchEmbedContents to resolve with embeddings
mockBatchEmbedContents.mockResolvedValueOnce({
embeddings: [{values: [1, 2, 3]}, {values: [4, 5, 6]}],
});

const batch = ['text1', 'text2'];
const results = await embedClient.getEmbeddings(batch);

expect(mockEmbedContent).toHaveBeenCalledTimes(batch.length);
expect(mockBatchEmbedContents).toHaveBeenCalledWith({
requests: [
{content: {parts: [{text: 'text1'}], role: 'user'}},
{content: {parts: [{text: 'text2'}], role: 'user'}},
],
});

expect(results).toEqual([
[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
]);
});

test('should throw an error if the embedding process fails', async () => {
embedClient.client = {
getGenerativeModel: jest.fn(() => ({
embedContent: jest
.fn()
.mockRejectedValue(new Error('Embedding failed')),
})),
};
// Mock batchEmbedContents to throw an error
mockBatchEmbedContents.mockRejectedValueOnce(
new Error('Embedding failed')
);

await expect(embedClient.getEmbeddings(['text'])).rejects.toThrow(
'Error with embedding'
);

expect(mockBatchEmbedContents).toHaveBeenCalledWith({
requests: [{content: {parts: [{text: 'text'}], role: 'user'}}],
});
});
});
});
160 changes: 160 additions & 0 deletions firestore-vector-search/functions/__tests__/embeddings/genkit.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
jest.resetModules();

// Mocking `@genkit-ai/googleai` and `@genkit-ai/vertexai`
jest.mock('@genkit-ai/googleai', () => ({
googleAI: jest.fn(),
textEmbeddingGecko001: 'gecko-001-model',
}));

jest.mock('@genkit-ai/vertexai', () => ({
vertexAI: jest.fn(),
textEmbedding004: 'text-embedding-004-model',
}));

jest.mock('../../src/config', () => ({
config: {
geminiApiKey: 'test-api-key',
location: 'us-central1',
},
}));

import {GenkitEmbedClient} from '../../src/embeddings/client/genkit';
import {genkit} from 'genkit';
import {vertexAI} from '@genkit-ai/vertexai';
import {googleAI} from '@genkit-ai/googleai';

// Mock the genkit client with properly structured responses
const mockEmbedMany = jest.fn();
const mockEmbed = jest.fn();
jest.mock('genkit', () => ({
genkit: jest.fn().mockImplementation(() => ({
embedMany: mockEmbedMany,
embed: mockEmbed,
})),
}));

describe('GenkitEmbedClient', () => {
let embedClient: GenkitEmbedClient;
let mockVertexAI: jest.Mock;
let mockGoogleAI: jest.Mock;

beforeEach(() => {
jest.clearAllMocks();
mockVertexAI = vertexAI as jest.Mock;
mockGoogleAI = googleAI as jest.Mock;
});

describe('constructor', () => {
test('should initialize with Vertex AI provider', () => {
embedClient = new GenkitEmbedClient({
provider: 'vertexai',
batchSize: 100,
dimension: 768,
});

expect(embedClient.provider).toBe('vertexai');
expect(embedClient.embedder).toBe('text-embedding-004-model');
expect(mockVertexAI).toHaveBeenCalledWith({
location: 'us-central1',
});
expect(genkit).toHaveBeenCalledWith({
plugins: [undefined], // because the mock returns undefined
});
});

test('should initialize with Google AI provider', () => {
embedClient = new GenkitEmbedClient({
provider: 'googleai',
batchSize: 100,
dimension: 768,
});

expect(embedClient.provider).toBe('googleai');
expect(embedClient.embedder).toBe('gecko-001-model');
expect(mockGoogleAI).toHaveBeenCalledWith({
apiKey: 'test-api-key',
});
expect(genkit).toHaveBeenCalledWith({
plugins: [undefined], // because the mock returns undefined
});
});
});

describe('getEmbeddings', () => {
beforeEach(() => {
embedClient = new GenkitEmbedClient({
provider: 'vertexai',
batchSize: 100,
dimension: 768,
});
});

test('should return embeddings for a batch of inputs', async () => {
const mockResults = [{embedding: [1, 2, 3]}, {embedding: [4, 5, 6]}];
mockEmbedMany.mockResolvedValueOnce(mockResults);

const inputs = ['input1', 'input2'];
const embeddings = await embedClient.getEmbeddings(inputs);

expect(mockEmbedMany).toHaveBeenCalledWith({
embedder: embedClient.embedder,
content: inputs,
});

expect(embeddings).toEqual([
[1, 2, 3],
[4, 5, 6],
]);
});

test('should throw an error if embedding fails', async () => {
mockEmbedMany.mockRejectedValueOnce(new Error('Embedding failed'));

await expect(embedClient.getEmbeddings(['input'])).rejects.toThrow(
'Embedding failed'
);

expect(mockEmbedMany).toHaveBeenCalledWith({
embedder: embedClient.embedder,
content: ['input'],
});
});
});

describe('getSingleEmbedding', () => {
beforeEach(() => {
embedClient = new GenkitEmbedClient({
provider: 'googleai',
batchSize: 100,
dimension: 768,
});
});

test('should return a single embedding for an input', async () => {
mockEmbed.mockResolvedValueOnce([7, 8, 9]); // Changed to return array directly

const input = 'input1';
const embedding = await embedClient.getSingleEmbedding(input);

expect(mockEmbed).toHaveBeenCalledWith({
embedder: embedClient.embedder,
content: input,
});

expect(embedding).toEqual([7, 8, 9]);
});

test('should throw an error if embedding fails', async () => {
mockEmbed.mockRejectedValueOnce(new Error('Embedding failed'));

await expect(embedClient.getSingleEmbedding('input')).rejects.toThrow(
'Embedding failed'
);

expect(mockEmbed).toHaveBeenCalledWith({
embedder: embedClient.embedder,
content: 'input',
});
});
});
});
2 changes: 1 addition & 1 deletion firestore-vector-search/functions/jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module.exports = {
rootDir: './',
globals: {
'ts-jest': {
tsConfig: '<rootDir>/__tests__/tsconfig.json',
tsconfig: '<rootDir>/tsconfig.test.json', // Correct reference to test-specific config
},
fetch: global.fetch,
},
Expand Down
Loading
Loading