-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(firestore-vector-search): use Firebase Genkit wherever possi…
…ble. (#603) * refactor(firestore-vector-search): use genkit where possible * chore(firestore-vector-search): update CHANGELOG and bump ext version * test(firestore-vector-search): fix tests and add more coverage
- Loading branch information
Showing
17 changed files
with
2,631 additions
and
518 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 0 additions & 31 deletions
31
firestore-vector-search/functions/__tests__/__snapshots__/config.test.ts.snap
This file was deleted.
Oops, something went wrong.
81 changes: 52 additions & 29 deletions
81
firestore-vector-search/functions/__tests__/embeddings/gemini.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,84 @@ | ||
import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini'; | ||
// import { config } from "../../src/config"; | ||
jest.resetModules(); | ||
|
||
// Mock GoogleGenerativeAI and its methods | ||
const mockGetGenerativeModel = jest.fn(); | ||
const mockBatchEmbedContents = jest.fn(); | ||
|
||
// mock config | ||
// jest.mock("../../src/config", () => ({ | ||
// ...jest.requireActual("../../src/config"), | ||
// geminiApiKey: "test-api-key", | ||
// })); | ||
jest.mock('@google/generative-ai', () => ({ | ||
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({ | ||
getGenerativeModel: mockGetGenerativeModel, | ||
})), | ||
})); | ||
|
||
jest.mock('../../src/config', () => ({ | ||
config: { | ||
geminiApiKey: 'test-api-key', | ||
}, | ||
})); | ||
|
||
import {GeminiAITextEmbedClient} from '../../src/embeddings/client/text/gemini'; | ||
import {GoogleGenerativeAI} from '@google/generative-ai'; | ||
|
||
describe('Gemini Embeddings', () => { | ||
let embedClient; | ||
let embedClient: GeminiAITextEmbedClient; | ||
|
||
beforeEach(async () => { | ||
// Reset mocks | ||
jest.clearAllMocks(); | ||
|
||
beforeEach(() => { | ||
// Mock return value for getGenerativeModel | ||
mockGetGenerativeModel.mockReturnValue({ | ||
batchEmbedContents: mockBatchEmbedContents, | ||
}); | ||
|
||
// Instantiate and initialize the client | ||
embedClient = new GeminiAITextEmbedClient(); | ||
await embedClient.initialize(); | ||
}); | ||
|
||
describe('initialize', () => { | ||
test('should properly initialize the client', async () => { | ||
await embedClient.initialize(); | ||
|
||
expect(embedClient.client).toBeDefined(); | ||
// expect(GoogleGenerativeAI).toHaveBeenCalledWith(config.geminiApiKey); | ||
expect(GoogleGenerativeAI).toHaveBeenCalledWith('test-api-key'); | ||
}); | ||
}); | ||
|
||
describe('getEmbeddings', () => { | ||
test('should return embeddings for a batch of text', async () => { | ||
const mockEmbedContent = jest | ||
.fn() | ||
.mockResolvedValue({embedding: [1, 2, 3]}); | ||
embedClient.client = { | ||
getGenerativeModel: jest.fn(() => ({ | ||
embedContent: mockEmbedContent, | ||
})), | ||
}; | ||
// Mock batchEmbedContents to resolve with embeddings | ||
mockBatchEmbedContents.mockResolvedValueOnce({ | ||
embeddings: [{values: [1, 2, 3]}, {values: [4, 5, 6]}], | ||
}); | ||
|
||
const batch = ['text1', 'text2']; | ||
const results = await embedClient.getEmbeddings(batch); | ||
|
||
expect(mockEmbedContent).toHaveBeenCalledTimes(batch.length); | ||
expect(mockBatchEmbedContents).toHaveBeenCalledWith({ | ||
requests: [ | ||
{content: {parts: [{text: 'text1'}], role: 'user'}}, | ||
{content: {parts: [{text: 'text2'}], role: 'user'}}, | ||
], | ||
}); | ||
|
||
expect(results).toEqual([ | ||
[1, 2, 3], | ||
[1, 2, 3], | ||
[4, 5, 6], | ||
]); | ||
}); | ||
|
||
test('should throw an error if the embedding process fails', async () => { | ||
embedClient.client = { | ||
getGenerativeModel: jest.fn(() => ({ | ||
embedContent: jest | ||
.fn() | ||
.mockRejectedValue(new Error('Embedding failed')), | ||
})), | ||
}; | ||
// Mock batchEmbedContents to throw an error | ||
mockBatchEmbedContents.mockRejectedValueOnce( | ||
new Error('Embedding failed') | ||
); | ||
|
||
await expect(embedClient.getEmbeddings(['text'])).rejects.toThrow( | ||
'Error with embedding' | ||
); | ||
|
||
expect(mockBatchEmbedContents).toHaveBeenCalledWith({ | ||
requests: [{content: {parts: [{text: 'text'}], role: 'user'}}], | ||
}); | ||
}); | ||
}); | ||
}); |
160 changes: 160 additions & 0 deletions
160
firestore-vector-search/functions/__tests__/embeddings/genkit.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
jest.resetModules(); | ||
|
||
// Mocking `@genkit-ai/googleai` and `@genkit-ai/vertexai` | ||
jest.mock('@genkit-ai/googleai', () => ({ | ||
googleAI: jest.fn(), | ||
textEmbeddingGecko001: 'gecko-001-model', | ||
})); | ||
|
||
jest.mock('@genkit-ai/vertexai', () => ({ | ||
vertexAI: jest.fn(), | ||
textEmbedding004: 'text-embedding-004-model', | ||
})); | ||
|
||
jest.mock('../../src/config', () => ({ | ||
config: { | ||
geminiApiKey: 'test-api-key', | ||
location: 'us-central1', | ||
}, | ||
})); | ||
|
||
import {GenkitEmbedClient} from '../../src/embeddings/client/genkit'; | ||
import {genkit} from 'genkit'; | ||
import {vertexAI} from '@genkit-ai/vertexai'; | ||
import {googleAI} from '@genkit-ai/googleai'; | ||
|
||
// Mock the genkit client with properly structured responses | ||
const mockEmbedMany = jest.fn(); | ||
const mockEmbed = jest.fn(); | ||
jest.mock('genkit', () => ({ | ||
genkit: jest.fn().mockImplementation(() => ({ | ||
embedMany: mockEmbedMany, | ||
embed: mockEmbed, | ||
})), | ||
})); | ||
|
||
describe('GenkitEmbedClient', () => { | ||
let embedClient: GenkitEmbedClient; | ||
let mockVertexAI: jest.Mock; | ||
let mockGoogleAI: jest.Mock; | ||
|
||
beforeEach(() => { | ||
jest.clearAllMocks(); | ||
mockVertexAI = vertexAI as jest.Mock; | ||
mockGoogleAI = googleAI as jest.Mock; | ||
}); | ||
|
||
describe('constructor', () => { | ||
test('should initialize with Vertex AI provider', () => { | ||
embedClient = new GenkitEmbedClient({ | ||
provider: 'vertexai', | ||
batchSize: 100, | ||
dimension: 768, | ||
}); | ||
|
||
expect(embedClient.provider).toBe('vertexai'); | ||
expect(embedClient.embedder).toBe('text-embedding-004-model'); | ||
expect(mockVertexAI).toHaveBeenCalledWith({ | ||
location: 'us-central1', | ||
}); | ||
expect(genkit).toHaveBeenCalledWith({ | ||
plugins: [undefined], // because the mock returns undefined | ||
}); | ||
}); | ||
|
||
test('should initialize with Google AI provider', () => { | ||
embedClient = new GenkitEmbedClient({ | ||
provider: 'googleai', | ||
batchSize: 100, | ||
dimension: 768, | ||
}); | ||
|
||
expect(embedClient.provider).toBe('googleai'); | ||
expect(embedClient.embedder).toBe('gecko-001-model'); | ||
expect(mockGoogleAI).toHaveBeenCalledWith({ | ||
apiKey: 'test-api-key', | ||
}); | ||
expect(genkit).toHaveBeenCalledWith({ | ||
plugins: [undefined], // because the mock returns undefined | ||
}); | ||
}); | ||
}); | ||
|
||
describe('getEmbeddings', () => { | ||
beforeEach(() => { | ||
embedClient = new GenkitEmbedClient({ | ||
provider: 'vertexai', | ||
batchSize: 100, | ||
dimension: 768, | ||
}); | ||
}); | ||
|
||
test('should return embeddings for a batch of inputs', async () => { | ||
const mockResults = [{embedding: [1, 2, 3]}, {embedding: [4, 5, 6]}]; | ||
mockEmbedMany.mockResolvedValueOnce(mockResults); | ||
|
||
const inputs = ['input1', 'input2']; | ||
const embeddings = await embedClient.getEmbeddings(inputs); | ||
|
||
expect(mockEmbedMany).toHaveBeenCalledWith({ | ||
embedder: embedClient.embedder, | ||
content: inputs, | ||
}); | ||
|
||
expect(embeddings).toEqual([ | ||
[1, 2, 3], | ||
[4, 5, 6], | ||
]); | ||
}); | ||
|
||
test('should throw an error if embedding fails', async () => { | ||
mockEmbedMany.mockRejectedValueOnce(new Error('Embedding failed')); | ||
|
||
await expect(embedClient.getEmbeddings(['input'])).rejects.toThrow( | ||
'Embedding failed' | ||
); | ||
|
||
expect(mockEmbedMany).toHaveBeenCalledWith({ | ||
embedder: embedClient.embedder, | ||
content: ['input'], | ||
}); | ||
}); | ||
}); | ||
|
||
describe('getSingleEmbedding', () => { | ||
beforeEach(() => { | ||
embedClient = new GenkitEmbedClient({ | ||
provider: 'googleai', | ||
batchSize: 100, | ||
dimension: 768, | ||
}); | ||
}); | ||
|
||
test('should return a single embedding for an input', async () => { | ||
mockEmbed.mockResolvedValueOnce([7, 8, 9]); // Changed to return array directly | ||
|
||
const input = 'input1'; | ||
const embedding = await embedClient.getSingleEmbedding(input); | ||
|
||
expect(mockEmbed).toHaveBeenCalledWith({ | ||
embedder: embedClient.embedder, | ||
content: input, | ||
}); | ||
|
||
expect(embedding).toEqual([7, 8, 9]); | ||
}); | ||
|
||
test('should throw an error if embedding fails', async () => { | ||
mockEmbed.mockRejectedValueOnce(new Error('Embedding failed')); | ||
|
||
await expect(embedClient.getSingleEmbedding('input')).rejects.toThrow( | ||
'Embedding failed' | ||
); | ||
|
||
expect(mockEmbed).toHaveBeenCalledWith({ | ||
embedder: embedClient.embedder, | ||
content: 'input', | ||
}); | ||
}); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.