diff --git a/src/participant/participant.ts b/src/participant/participant.ts index dc9e94a17..8edf5dc44 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -38,6 +38,7 @@ import { } from '../telemetry/telemetryService'; import { DocsChatbotAIService } from './docsChatbotAIService'; import type TelemetryService from '../telemetry/telemetryService'; +import { IntentPrompt, type PromptIntent } from './prompts/intent'; const log = createLogger('participant'); @@ -214,8 +215,7 @@ export default class ParticipantController { } } - // @MongoDB what is mongodb? - async handleGenericRequest( + async _handleRoutedGenericRequest( request: vscode.ChatRequest, context: vscode.ChatContext, stream: vscode.ChatResponseStream, @@ -241,6 +241,93 @@ export default class ParticipantController { return genericRequestChatResult(context.history); } + async _routeRequestToHandler({ + context, + promptIntent, + request, + stream, + token, + }: { + context: vscode.ChatContext; + promptIntent: Omit; + request: vscode.ChatRequest; + stream: vscode.ChatResponseStream; + token: vscode.CancellationToken; + }): Promise { + switch (promptIntent) { + case 'Query': + return this.handleQueryRequest(request, context, stream, token); + case 'Docs': + return this.handleDocsRequest(request, context, stream, token); + case 'Schema': + return this.handleSchemaRequest(request, context, stream, token); + case 'Code': + return this.handleQueryRequest(request, context, stream, token); + default: + return this._handleRoutedGenericRequest( + request, + context, + stream, + token + ); + } + } + + async _getIntentFromChatRequest({ + context, + request, + token, + }: { + context: vscode.ChatContext; + request: vscode.ChatRequest; + token: vscode.CancellationToken; + }): Promise { + const messages = await Prompts.intent.buildMessages({ + connectionNames: this._getConnectionNames(), + request, + context, + }); + + const responseContent = await this.getChatResponseContent({ + messages, + token, + }); + + return IntentPrompt.getIntentFromModelResponse(responseContent); + } + + async handleGenericRequest( + request: vscode.ChatRequest, + context: vscode.ChatContext, + stream: vscode.ChatResponseStream, + token: vscode.CancellationToken + ): Promise { + // We "prompt chain" to handle the generic requests. + // First we ask the model to parse for intent. + // If there is an intent, we can route it to one of the handlers (/commands). + // When there is no intention or it's generic we handle it with a generic handler. + const promptIntent = await this._getIntentFromChatRequest({ + context, + request, + token, + }); + + if (token.isCancellationRequested) { + return this._handleCancelledRequest({ + context, + stream, + }); + } + + return this._routeRequestToHandler({ + context, + promptIntent, + request, + stream, + token, + }); + } + async connectWithParticipant({ id, command, diff --git a/src/participant/prompts/generic.ts b/src/participant/prompts/generic.ts index 9b5348790..40b531228 100644 --- a/src/participant/prompts/generic.ts +++ b/src/participant/prompts/generic.ts @@ -1,20 +1,24 @@ import * as vscode from 'vscode'; + import type { PromptArgsBase } from './promptBase'; import { PromptBase } from './promptBase'; export class GenericPrompt extends PromptBase { protected getAssistantPrompt(): string { return `You are a MongoDB expert. -Your task is to help the user craft MongoDB queries and aggregation pipelines that perform their task. -Keep your response concise. -You should suggest queries that are performant and correct. -Respond with markdown, suggest code in a Markdown code block that begins with \`\`\`javascript and ends with \`\`\`. -You can imagine the schema, collection, and database name. -Respond in MongoDB shell syntax using the \`\`\`javascript code block syntax.`; +Your task is to help the user with MongoDB related questions. +When applicable, you may suggest MongoDB code, queries, and aggregation pipelines that perform their task. +Rules: +1. Keep your response concise. +2. You should suggest code that is performant and correct. +3. Respond with markdown. +4. When relevant, provide code in a Markdown code block that begins with \`\`\`javascript and ends with \`\`\`. +5. Use MongoDB shell syntax for code unless the user requests a specific language. +6. If you require additional information to provide a response, ask the user for it. +7. When specifying a database, use the MongoDB syntax use('databaseName').`; } public getEmptyRequestResponse(): string { - // TODO(VSCODE-572): Generic empty response handler return vscode.l10n.t( 'Ask anything about MongoDB, from writing queries to questions about your cluster.' ); diff --git a/src/participant/prompts/index.ts b/src/participant/prompts/index.ts index 2d40beec4..867b057dd 100644 --- a/src/participant/prompts/index.ts +++ b/src/participant/prompts/index.ts @@ -1,11 +1,14 @@ -import { GenericPrompt } from './generic'; import type * as vscode from 'vscode'; + +import { GenericPrompt } from './generic'; +import { IntentPrompt } from './intent'; import { NamespacePrompt } from './namespace'; import { QueryPrompt } from './query'; import { SchemaPrompt } from './schema'; export class Prompts { public static generic = new GenericPrompt(); + public static intent = new IntentPrompt(); public static namespace = new NamespacePrompt(); public static query = new QueryPrompt(); public static schema = new SchemaPrompt(); diff --git a/src/participant/prompts/intent.ts b/src/participant/prompts/intent.ts new file mode 100644 index 000000000..0726f0fc7 --- /dev/null +++ b/src/participant/prompts/intent.ts @@ -0,0 +1,50 @@ +import type { PromptArgsBase } from './promptBase'; +import { PromptBase } from './promptBase'; + +export type PromptIntent = 'Query' | 'Schema' | 'Docs' | 'Default'; + +export class IntentPrompt extends PromptBase { + protected getAssistantPrompt(): string { + return `You are a MongoDB expert. +Your task is to help guide a conversation with a user to the correct handler. +You will be provided a conversation and your task is to determine the intent of the user. +The intent handlers are: +- Query +- Schema +- Docs +- Default +Rules: +1. Respond only with the intent handler. +2. Use the "Query" intent handler when the user is asking for code that relates to a specific collection. +3. Use the "Docs" intent handler when the user is asking a question that involves MongoDB documentation. +4. Use the "Schema" intent handler when the user is asking for the schema or shape of documents of a specific collection. +5. Use the "Default" intent handler when a user is asking for code that does NOT relate to a specific collection. +6. Use the "Default" intent handler for everything that may not be handled by another handler. +7. If you are uncertain of the intent, use the "Default" intent handler. + +Example: +User: How do I create an index in my pineapples collection? +Response: +Query + +Example: +User: +What is $vectorSearch? +Response: +Docs`; + } + + static getIntentFromModelResponse(response: string): PromptIntent { + response = response.trim(); + switch (response) { + case 'Query': + return 'Query'; + case 'Schema': + return 'Schema'; + case 'Docs': + return 'Docs'; + default: + return 'Default'; + } + } +} diff --git a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts index 77db0244b..c3135400d 100644 --- a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts +++ b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts @@ -17,7 +17,7 @@ import { type TestOutputs, type TestResult, } from './create-test-results-html-page'; -import { runCodeInMessage } from './assertions'; +import { anyOf, runCodeInMessage } from './assertions'; import { Prompts } from '../../participant/prompts'; const numberOfRunsPerTest = 1; @@ -34,7 +34,7 @@ type AssertProps = { type TestCase = { testCase: string; - type: 'generic' | 'query' | 'namespace'; + type: 'intent' | 'generic' | 'query' | 'namespace'; userInput: string; // Some tests can edit the documents in a collection. // As we want tests to run in isolation this flag will cause the fixture @@ -48,7 +48,9 @@ type TestCase = { only?: boolean; // Translates to mocha's it.only so only this test will run. }; -const namespaceTestCases: TestCase[] = [ +const namespaceTestCases: (TestCase & { + type: 'namespace'; +})[] = [ { testCase: 'Namespace included in query', type: 'namespace', @@ -110,7 +112,9 @@ const namespaceTestCases: TestCase[] = [ }, ]; -const queryTestCases: TestCase[] = [ +const queryTestCases: (TestCase & { + type: 'query'; +})[] = [ { testCase: 'Basic query', type: 'query', @@ -245,9 +249,177 @@ const queryTestCases: TestCase[] = [ expect(output.data?.result?.content[0].collectors).to.include('Monkey'); }, }, + { + testCase: 'Complex aggregation with string and number manipulation', + type: 'query', + databaseName: 'CookBook', + collectionName: 'recipes', + userInput: + 'what percentage of recipes have "salt" in their ingredients? "ingredients" is a field ' + + 'with an array of strings of the ingredients. Only consider recipes ' + + 'that have the "difficulty Medium or Easy. Return is as a string named "saltPercentage" like ' + + '"75%", rounded to the nearest whole number.', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + + anyOf([ + (): void => { + const lines = responseContent.trim().split('\n'); + const lastLine = lines[lines.length - 1]; + + expect(lastLine).to.include('saltPercentage'); + expect(output.data?.result?.content).to.include('67%'); + }, + (): void => { + expect(output.printOutput[output.printOutput.length - 1]).to.equal( + "{ saltPercentage: '67%' }" + ); + }, + (): void => { + expect(output.data?.result?.content[0].saltPercentage).to.equal( + '67%' + ); + }, + ])(null); + }, + }, ]; -const testCases: TestCase[] = [...namespaceTestCases, ...queryTestCases]; +const intentTestCases: (TestCase & { + type: 'intent'; +})[] = [ + { + testCase: 'Docs intent', + type: 'intent', + userInput: + 'Where can I find more information on how to connect to MongoDB?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Docs'); + }, + }, + { + testCase: 'Docs intent 2', + type: 'intent', + userInput: 'What are the options when creating an aggregation cursor?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Docs'); + }, + }, + { + testCase: 'Query intent', + type: 'intent', + userInput: + 'which collectors specialize only in mint items? and are located in London or New York? an array of their names in a field called collectors', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Query'); + }, + }, + { + testCase: 'Schema intent', + type: 'intent', + userInput: 'What do the documents in the collection pineapple look like?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Schema'); + }, + }, + { + testCase: 'Default/Generic intent 1', + type: 'intent', + userInput: 'How can I connect to MongoDB?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Default'); + }, + }, + { + testCase: 'Default/Generic intent 2', + type: 'intent', + userInput: 'What is the size breakdown of all of the databases?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Default'); + }, + }, +]; + +const genericTestCases: (TestCase & { + type: 'generic'; +})[] = [ + { + testCase: 'Database meta data question', + type: 'generic', + userInput: + 'How do I print the name and size of the largest database? Using the print function', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + const printOutput = output.printOutput.join(''); + + // Don't check the name since they're all the base 8192. + expect(printOutput).to.include('8192'); + }, + }, + { + testCase: 'Code question with database, collection, and fields named', + type: 'generic', + userInput: + 'How many sightings happened in the "year" "2020" and "2021"? database "UFO" collection "sightings". code to just return the one total number. also, the year is a string', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + anyOf([ + (): void => { + expect(output.printOutput.join('')).to.equal('2'); + }, + (): void => { + expect(output.data?.result?.content).to.equal('2'); + }, + (): void => { + expect(output.data?.result?.content).to.equal(2); + }, + (): void => { + expect( + Object.entries(output.data?.result?.content[0])[0][1] + ).to.equal(2); + }, + (): void => { + expect( + Object.entries(output.data?.result?.content[0])[0][1] + ).to.equal('2'); + }, + ])(null); + }, + }, + { + testCase: 'Complex aggregation code generation', + type: 'generic', + userInput: + 'what percentage of recipes have "salt" in their ingredients? "ingredients" is a field ' + + 'with an array of strings of the ingredients. Only consider recipes ' + + 'that have the "difficulty Medium or Easy. Return is as a string named "saltPercentage" like ' + + '"75%", rounded to the nearest whole number. db CookBook, collection recipes', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + + expect(output.data?.result?.content[0].saltPercentage).to.equal('67%'); + }, + }, +]; + +const testCases: TestCase[] = [ + ...namespaceTestCases, + ...queryTestCases, + ...intentTestCases, + ...genericTestCases, +]; const projectRoot = path.join(__dirname, '..', '..', '..'); @@ -319,6 +491,13 @@ const buildMessages = async ({ fixtures: Fixtures; }): Promise => { switch (testCase.type) { + case 'intent': + return Prompts.intent.buildMessages({ + request: { prompt: testCase.userInput }, + context: { history: [] }, + connectionNames: [], + }); + case 'generic': return Prompts.generic.buildMessages({ request: { prompt: testCase.userInput }, @@ -473,12 +652,14 @@ describe('AI Accuracy Tests', function () { testFunction( `should pass for input: "${testCase.userInput}" if average accuracy is above threshold`, - // eslint-disable-next-line no-loop-func + // eslint-disable-next-line no-loop-func, complexity async function () { console.log(`Starting test run of ${testCase.testCase}.`); const testRunDidSucceed: boolean[] = []; - const successFullRunStats: { + // Successful and unsuccessful runs are both tracked as long as the model + // returns a response. + const runStats: { promptTokens: number; completionTokens: number; executionTimeMS: number; @@ -505,12 +686,15 @@ describe('AI Accuracy Tests', function () { } const startTime = Date.now(); + let responseContent: ChatCompletion | undefined; + let executionTimeMS = 0; try { - const responseContent = await runTest({ + responseContent = await runTest({ testCase, aiBackend, fixtures, }); + executionTimeMS = Date.now() - startTime; testOutputs[testCase.testCase].outputs.push( responseContent.content ); @@ -521,11 +705,6 @@ describe('AI Accuracy Tests', function () { mongoClient, }); - successFullRunStats.push({ - completionTokens: responseContent.usageStats.completionTokens, - promptTokens: responseContent.usageStats.promptTokens, - executionTimeMS: Date.now() - startTime, - }); success = true; console.log( @@ -538,6 +717,18 @@ describe('AI Accuracy Tests', function () { ); } + if ( + responseContent && + responseContent?.usageStats?.completionTokens > 0 && + executionTimeMS !== 0 + ) { + runStats.push({ + completionTokens: responseContent.usageStats.completionTokens, + promptTokens: responseContent.usageStats.promptTokens, + executionTimeMS, + }); + } + testRunDidSucceed.push(success); } @@ -558,21 +749,19 @@ describe('AI Accuracy Tests', function () { Accuracy: averageAccuracy, Pass: didFail ? '✗' : '✓', 'Avg Execution Time (ms)': - successFullRunStats.length > 0 - ? successFullRunStats.reduce((a, b) => a + b.executionTimeMS, 0) / - successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.executionTimeMS, 0) / + runStats.length : 0, 'Avg Prompt Tokens': - successFullRunStats.length > 0 - ? successFullRunStats.reduce((a, b) => a + b.promptTokens, 0) / - successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.promptTokens, 0) / + runStats.length : 0, 'Avg Completion Tokens': - successFullRunStats.length > 0 - ? successFullRunStats.reduce( - (a, b) => a + b.completionTokens, - 0 - ) / successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.completionTokens, 0) / + runStats.length : 0, }); diff --git a/src/test/ai-accuracy-tests/create-test-results-html-page.ts b/src/test/ai-accuracy-tests/create-test-results-html-page.ts index f58a50950..c3774a59e 100644 --- a/src/test/ai-accuracy-tests/create-test-results-html-page.ts +++ b/src/test/ai-accuracy-tests/create-test-results-html-page.ts @@ -23,6 +23,9 @@ export type TestOutputs = { [testName: string]: TestOutput; }; +const createTestLinkId = (testName: string): string => + encodeURIComponent(testName.replace(/ /g, '-')); + function getTestResultsTable(testResults: TestResult[]): string { const headers = Object.keys(testResults[0]) .map((key) => `${key}`) @@ -30,8 +33,15 @@ function getTestResultsTable(testResults: TestResult[]): string { const resultRows = testResults .map((result) => { - const row = Object.values(result) - .map((value) => `${value}`) + const row = Object.entries(result) + .map( + ([field, value]) => + `${ + field === 'Test' + ? `${value}` + : value + }` + ) .join(''); return `${row}`; }) @@ -56,7 +66,9 @@ function getTestOutputTables(testOutputs: TestOutputs): string { .map((out) => `${out}`) .join(''); return ` -

${testName} [${output.testType}]

+

Prompt: ${output.prompt}

diff --git a/src/test/ai-accuracy-tests/fixtures/recipes.ts b/src/test/ai-accuracy-tests/fixtures/recipes.ts index efb347bb3..f9e8d7346 100644 --- a/src/test/ai-accuracy-tests/fixtures/recipes.ts +++ b/src/test/ai-accuracy-tests/fixtures/recipes.ts @@ -12,6 +12,7 @@ const recipes: Fixture = { 'tomato sauce', 'onions', 'garlic', + 'salt', ], preparationTime: 60, difficulty: 'Medium', @@ -23,6 +24,19 @@ const recipes: Fixture = { preparationTime: 10, difficulty: 'Easy', }, + { + title: 'Pineapple', + ingredients: ['pineapple'], + preparationTime: 5, + difficulty: 'Very Hard', + }, + { + title: 'Pizza', + ingredients: ['dough', 'tomato sauce', 'mozzarella cheese', 'basil'], + optionalIngredients: ['pineapple'], + preparationTime: 40, + difficulty: 'Medium', + }, { title: 'Beef Wellington', ingredients: [ @@ -30,6 +44,7 @@ const recipes: Fixture = { 'mushroom duxelles', 'puff pastry', 'egg wash', + 'salt', ], preparationTime: 120, difficulty: 'Hard', diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index d52debb46..557fbd320 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -437,14 +437,71 @@ suite('Participant Controller Test Suite', function () { }); suite('generic command', function () { - test('generates a query', async function () { + suite('when the intent is recognized', function () { + beforeEach(function () { + sendRequestStub.onCall(0).resolves({ + text: ['Schema'], + }); + }); + + test('routes to the appropriate handler', async function () { + const chatRequestMock = { + prompt: + 'what is the shape of the documents in the pineapple collection?', + command: undefined, + references: [], + }; + const res = await invokeChatHandler(chatRequestMock); + + expect(sendRequestStub).to.have.been.calledTwice; + const intentRequest = sendRequestStub.firstCall.args[0]; + expect(intentRequest).to.have.length(2); + expect(intentRequest[0].content).to.include( + 'Your task is to help guide a conversation with a user to the correct handler.' + ); + expect(intentRequest[1].content).to.equal( + 'what is the shape of the documents in the pineapple collection?' + ); + const genericRequest = sendRequestStub.secondCall.args[0]; + expect(genericRequest).to.have.length(2); + expect(genericRequest[0].content).to.include( + 'Parse all user messages to find a database name and a collection name.' + ); + expect(genericRequest[1].content).to.equal( + 'what is the shape of the documents in the pineapple collection?' + ); + + expect(res?.metadata.intent).to.equal('askForNamespace'); + }); + }); + + test('default handler asks for intent and shows code run actions', async function () { const chatRequestMock = { prompt: 'how to find documents in my collection?', command: undefined, references: [], }; - await invokeChatHandler(chatRequestMock); + const res = await invokeChatHandler(chatRequestMock); + + expect(sendRequestStub).to.have.been.calledTwice; + const intentRequest = sendRequestStub.firstCall.args[0]; + expect(intentRequest).to.have.length(2); + expect(intentRequest[0].content).to.include( + 'Your task is to help guide a conversation with a user to the correct handler.' + ); + expect(intentRequest[1].content).to.equal( + 'how to find documents in my collection?' + ); + const genericRequest = sendRequestStub.secondCall.args[0]; + expect(genericRequest).to.have.length(2); + expect(genericRequest[0].content).to.include( + 'Your task is to help the user with MongoDB related questions.' + ); + expect(genericRequest[1].content).to.equal( + 'how to find documents in my collection?' + ); + expect(res?.metadata.intent).to.equal('generic'); expect(chatStreamStub?.button.getCall(0).args[0]).to.deep.equal({ command: 'mdb.runParticipantQuery', title: '▶️ Run',