From ab9d5b77c08381dbfdc5c25c0e74ee3b673a23eb Mon Sep 17 00:00:00 2001 From: Rhys Date: Fri, 6 Sep 2024 15:19:12 -0400 Subject: [PATCH] chore(participant): add sample docs test, add debug prompt var VSCODE-583 (#809) --- .../ai-accuracy-tests/ai-accuracy-tests.ts | 50 +++++++++++-- src/test/ai-accuracy-tests/assertions.ts | 11 ++- .../create-test-results-html-page.ts | 4 +- .../ai-accuracy-tests/fixtures/antiques.ts | 74 +++++++++++++++++++ .../fixtures/fixture-loader.ts | 4 +- src/test/ai-accuracy-tests/test-setup.ts | 16 +++- 6 files changed, 142 insertions(+), 17 deletions(-) create mode 100644 src/test/ai-accuracy-tests/fixtures/antiques.ts diff --git a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts index 1ad4924bb..981a21544 100644 --- a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts +++ b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts @@ -25,6 +25,9 @@ import { parseForDatabaseAndCollectionName } from '../../participant/participant const numberOfRunsPerTest = 1; +// When true, we will log the entire prompt we send to the model for each test. +const DEBUG_PROMPTS = process.env.DEBUG_PROMPTS === 'true'; + type AssertProps = { responseContent: string; connectionString: string; @@ -42,6 +45,7 @@ type TestCase = { reloadFixtureOnEachRun?: boolean; databaseName?: string; collectionName?: string; + includeSampleDocuments?: boolean; accuracyThresholdOverride?: number; assertResult: (props: AssertProps) => Promise | void; only?: boolean; // Translates to mocha's it.only so only this test will run. @@ -53,7 +57,7 @@ const namespaceTestCases: TestCase[] = [ type: 'namespace', userInput: 'How many documents are in the tempReadings collection in the pools database?', - assertResult: ({ responseContent }: AssertProps) => { + assertResult: ({ responseContent }: AssertProps): void => { const namespace = parseForDatabaseAndCollectionName(responseContent); expect(namespace.databaseName).to.equal('pools'); @@ -64,7 +68,7 @@ const namespaceTestCases: TestCase[] = [ testCase: 'No namespace included in basic query', type: 'namespace', userInput: 'How many documents are in the collection?', - assertResult: ({ responseContent }: AssertProps) => { + assertResult: ({ responseContent }: AssertProps): void => { const namespace = parseForDatabaseAndCollectionName(responseContent); expect(namespace.databaseName).to.equal(undefined); @@ -76,7 +80,7 @@ const namespaceTestCases: TestCase[] = [ type: 'namespace', userInput: 'How do I create a new user with read write permissions on the orders collection?', - assertResult: ({ responseContent }: AssertProps) => { + assertResult: ({ responseContent }: AssertProps): void => { const namespace = parseForDatabaseAndCollectionName(responseContent); expect(namespace.databaseName).to.equal(undefined); @@ -88,7 +92,7 @@ const namespaceTestCases: TestCase[] = [ type: 'namespace', userInput: 'How do I create a new user with read write permissions on the orders db?', - assertResult: ({ responseContent }: AssertProps) => { + assertResult: ({ responseContent }: AssertProps): void => { const namespace = parseForDatabaseAndCollectionName(responseContent); expect(namespace.databaseName).to.equal('orders'); @@ -131,7 +135,7 @@ const queryTestCases: TestCase[] = [ connectionString, mongoClient, fixtures, - }: AssertProps) => { + }: AssertProps): Promise => { const documentsBefore = await mongoClient .db('CookBook') .collection('recipes') @@ -172,7 +176,7 @@ const queryTestCases: TestCase[] = [ assertResult: async ({ responseContent, connectionString, - }: AssertProps) => { + }: AssertProps): Promise => { const output = await runCodeInMessage(responseContent, connectionString); expect(output.data?.result?.content[0]).to.deep.equal({ @@ -192,7 +196,7 @@ const queryTestCases: TestCase[] = [ responseContent, connectionString, mongoClient, - }: AssertProps) => { + }: AssertProps): Promise => { const indexesBefore = await mongoClient .db('FarmData') .collection('Pineapples') @@ -213,6 +217,27 @@ const queryTestCases: TestCase[] = [ ).to.have.keys(['harvestedDate', 'sweetnessScale']); }, }, + { + testCase: 'Aggregation with an or or $in, with sample docs', + type: 'query', + only: true, + databaseName: 'Antiques', + collectionName: 'items', + includeSampleDocuments: true, + reloadFixtureOnEachRun: true, + userInput: + 'which collectors specialize only in mint items? and are located in London or New York? an array of their names in a field called collectors', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + + expect(output.data?.result?.content?.[0].collectors).to.have.lengthOf(2); + expect(output.data?.result?.content[0].collectors).to.include('John Doe'); + expect(output.data?.result?.content[0].collectors).to.include('Monkey'); + }, + }, ]; const testCases: TestCase[] = [...namespaceTestCases, ...queryTestCases]; @@ -309,6 +334,13 @@ const buildMessages = async ({ ]?.schema, } : {}), + ...(testCase.includeSampleDocuments + ? { + sampleDocuments: fixtures[testCase.databaseName as string][ + testCase.collectionName as string + ].documents.slice(0, 3), + } + : {}), }); case 'namespace': @@ -335,6 +367,10 @@ async function runTest({ testCase, fixtures, }); + if (DEBUG_PROMPTS) { + console.log('Messages to send to chat completion:'); + console.log(messages); + } const chatCompletion = await aiBackend.runAIChatCompletionGeneration({ messages: messages.map((message) => ({ ...message, diff --git a/src/test/ai-accuracy-tests/assertions.ts b/src/test/ai-accuracy-tests/assertions.ts index aa1289fc9..31460ef20 100644 --- a/src/test/ai-accuracy-tests/assertions.ts +++ b/src/test/ai-accuracy-tests/assertions.ts @@ -54,8 +54,10 @@ export const runCodeInMessage = async ( }; }; -export const isDeepStrictEqualTo = (expected: unknown) => (actual: unknown) => - assert.deepStrictEqual(actual, expected); +export const isDeepStrictEqualTo = + (expected: unknown) => + (actual: unknown): void => + assert.deepStrictEqual(actual, expected); export const isDeepStrictEqualToFixtures = ( @@ -64,13 +66,14 @@ export const isDeepStrictEqualToFixtures = fixtures: Fixtures, comparator: (document: Document) => boolean ) => - (actual: unknown) => { + (actual: unknown): void => { const expected = fixtures[db][coll].documents.filter(comparator); assert.deepStrictEqual(actual, expected); }; export const anyOf = - (assertions: ((result: unknown) => void)[]) => (actual: unknown) => { + (assertions: ((result: unknown) => void)[]) => + (actual: unknown): void => { const errors: Error[] = []; for (const assertion of assertions) { try { diff --git a/src/test/ai-accuracy-tests/create-test-results-html-page.ts b/src/test/ai-accuracy-tests/create-test-results-html-page.ts index 59b362e8f..f58a50950 100644 --- a/src/test/ai-accuracy-tests/create-test-results-html-page.ts +++ b/src/test/ai-accuracy-tests/create-test-results-html-page.ts @@ -23,7 +23,7 @@ export type TestOutputs = { [testName: string]: TestOutput; }; -function getTestResultsTable(testResults: TestResult[]) { +function getTestResultsTable(testResults: TestResult[]): string { const headers = Object.keys(testResults[0]) .map((key) => `${key}`) .join(''); @@ -49,7 +49,7 @@ function getTestResultsTable(testResults: TestResult[]) { `; } -function getTestOutputTables(testOutputs: TestOutputs) { +function getTestOutputTables(testOutputs: TestOutputs): string { const outputTables = Object.entries(testOutputs) .map(([testName, output]) => { const outputRows = output.outputs diff --git a/src/test/ai-accuracy-tests/fixtures/antiques.ts b/src/test/ai-accuracy-tests/fixtures/antiques.ts new file mode 100644 index 000000000..81c5ef1d8 --- /dev/null +++ b/src/test/ai-accuracy-tests/fixtures/antiques.ts @@ -0,0 +1,74 @@ +import type { Fixture } from './fixture-type'; + +const antiques: Fixture = { + db: 'Antiques', + coll: 'items', + documents: [ + { + itemName: 'Vintage Beatles Vinyl', + owner: { + name: 'John Doe', + location: 'London', + }, + acquisition: { + date: '1998-03-13', + price: 1200, + }, + condition: 'Mint', + history: [ + { event: 'Auction Win', date: '1998-03-13' }, + { event: 'Restoration', date: '2005-07-22' }, + ], + }, + { + itemName: 'Ancient Roman Coin', + owner: { + name: 'Jane Doe', + location: 'Rome', + }, + acquisition: { + date: '2002-11-27', + price: 5000, + }, + condition: 'Good', + history: [ + { event: 'Found in a dig', date: '2002-11-27' }, + { event: 'Museum Display', date: '2010-02-15' }, + ], + }, + { + itemName: 'Victorian Pocket Watch', + owner: { + name: 'Arnold Arnoldson', + location: 'London', + }, + acquisition: { + date: '2010-06-30', + price: 800, + }, + condition: 'Fair', + history: [ + { event: 'Inherited', date: '2010-06-30' }, + { event: 'Repair', date: '2015-09-12' }, + ], + }, + { + itemName: 'An Ancient Pineapple (super rare)', + owner: { + name: 'Monkey', + location: 'New York', + }, + acquisition: { + date: '2018-02-05', + price: 2300, + }, + condition: 'Mint', + history: [ + { event: 'Estate Sale', date: '2018-02-05' }, + { event: 'Appraisal', date: '2020-04-18' }, + ], + }, + ], +}; + +export default antiques; diff --git a/src/test/ai-accuracy-tests/fixtures/fixture-loader.ts b/src/test/ai-accuracy-tests/fixtures/fixture-loader.ts index 527e9ae14..a97978a34 100644 --- a/src/test/ai-accuracy-tests/fixtures/fixture-loader.ts +++ b/src/test/ai-accuracy-tests/fixtures/fixture-loader.ts @@ -2,6 +2,7 @@ import type { Document, MongoClient } from 'mongodb'; import { getSimplifiedSchema } from 'mongodb-schema'; import type { Fixture } from './fixture-type'; +import antiques from './antiques'; import pets from './pets'; import pineapples from './pineapples'; import recipes from './recipes'; @@ -19,6 +20,7 @@ export type Fixtures = { type LoadableFixture = (() => Fixture) | Fixture; const fixturesToLoad: LoadableFixture[] = [ + antiques, pets, pineapples, recipes, @@ -62,7 +64,7 @@ export async function reloadFixture({ coll: string; mongoClient: MongoClient; fixtures: Fixtures; -}) { +}): Promise { await mongoClient.db(db).collection(coll).drop(); const { documents } = fixtures[db][coll]; await mongoClient.db(db).collection(coll).insertMany(documents); diff --git a/src/test/ai-accuracy-tests/test-setup.ts b/src/test/ai-accuracy-tests/test-setup.ts index 07ce4a162..bd0a4d217 100644 --- a/src/test/ai-accuracy-tests/test-setup.ts +++ b/src/test/ai-accuracy-tests/test-setup.ts @@ -8,19 +8,29 @@ const vscodeMock = { User: UserRole, }, LanguageModelChatMessage: { - Assistant: (content, name?: string) => ({ + Assistant: (content, name?: string): unknown => ({ name, content, role: AssistantRole, }), - User: (content: string, name?: string) => ({ + User: (content: string, name?: string): unknown => ({ content, name, role: UserRole, }), }, window: { - createOutputChannel: () => {}, + createOutputChannel: (): void => {}, + }, + lm: { + selectChatModels: (): unknown => [ + { + countTokens: (input: string): number => { + return input.length; + }, + maxInputTokens: 10_000, + }, + ], }, };