diff --git a/src/participant/constants.ts b/src/participant/constants.ts index b43b1332f..90fd81490 100644 --- a/src/participant/constants.ts +++ b/src/participant/constants.ts @@ -14,6 +14,11 @@ export type ParticipantResponseType = | 'askToConnect' | 'askForNamespace'; +export const codeBlockIdentifier = { + start: '```javascript', + end: '```', +}; + interface Metadata { intent: Exclude; chatId: string; diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 0a680dc2e..14f50d31b 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -21,6 +21,7 @@ import { docsRequestChatResult, schemaRequestChatResult, createCancelledRequestChatResult, + codeBlockIdentifier, } from './constants'; import { SchemaFormatter } from './schema'; import { getSimplifiedSampleDocuments } from './sampleDocuments'; @@ -39,6 +40,8 @@ import { import { DocsChatbotAIService } from './docsChatbotAIService'; import type TelemetryService from '../telemetry/telemetryService'; import type { PromptResult } from './prompts/promptBase'; +import { processStreamWithIdentifiers } from './streamParsing'; +import type { PromptIntent } from './prompts/intent'; const log = createLogger('participant'); @@ -59,16 +62,6 @@ export type ParticipantCommand = '/query' | '/schema' | '/docs'; const MAX_MARKDOWN_LIST_LENGTH = 10; -export function getRunnableContentFromString(text: string): string { - const matchedJSresponseContent = text.match(/```javascript((.|\n)*)```/); - - const code = - matchedJSresponseContent && matchedJSresponseContent.length > 1 - ? 
matchedJSresponseContent[1] - : ''; - return code.trim(); -} - export default class ParticipantController { _participant?: vscode.ChatParticipant; _connectionController: ConnectionController; @@ -171,54 +164,118 @@ export default class ParticipantController { }); } - async getChatResponseContent({ + async _getChatResponse({ prompt, token, }: { prompt: PromptResult; token: vscode.CancellationToken; - }): Promise { + }): Promise { const model = await getCopilotModel(); - let responseContent = ''; - if (model) { - const chatResponse = await model.sendRequest(prompt.messages, {}, token); - for await (const fragment of chatResponse.text) { - responseContent += fragment; - } - this._telemetryService.trackCopilotParticipantPrompt(prompt.stats); + if (!model) { + throw new Error('Copilot model not found'); } - return responseContent; + this._telemetryService.trackCopilotParticipantPrompt(prompt.stats); + + return await model.sendRequest(prompt.messages, {}, token); } - _streamRunnableContentActions({ - responseContent, + async streamChatResponse({ + prompt, stream, + token, }: { - responseContent: string; + prompt: PromptResult; + stream: vscode.ChatResponseStream; + token: vscode.CancellationToken; + }): Promise { + const chatResponse = await this._getChatResponse({ + prompt, + token, + }); + for await (const fragment of chatResponse.text) { + stream.markdown(fragment); + } + } + + _streamCodeBlockActions({ + runnableContent, + stream, + }: { + runnableContent: string; stream: vscode.ChatResponseStream; }): void { - const runnableContent = getRunnableContentFromString(responseContent); - if (runnableContent) { - const commandArgs: RunParticipantQueryCommandArgs = { - runnableContent, - }; - stream.button({ - command: EXTENSION_COMMANDS.RUN_PARTICIPANT_QUERY, - title: vscode.l10n.t('▶️ Run'), - arguments: [commandArgs], - }); - stream.button({ - command: EXTENSION_COMMANDS.OPEN_PARTICIPANT_QUERY_IN_PLAYGROUND, - title: vscode.l10n.t('Open in playground'), - arguments: 
[commandArgs], - }); + runnableContent = runnableContent.trim(); + + if (!runnableContent) { + return; } + + const commandArgs: RunParticipantQueryCommandArgs = { + runnableContent, + }; + stream.button({ + command: EXTENSION_COMMANDS.RUN_PARTICIPANT_QUERY, + title: vscode.l10n.t('▶️ Run'), + arguments: [commandArgs], + }); + stream.button({ + command: EXTENSION_COMMANDS.OPEN_PARTICIPANT_QUERY_IN_PLAYGROUND, + title: vscode.l10n.t('Open in playground'), + arguments: [commandArgs], + }); } - // @MongoDB what is mongodb? - async handleGenericRequest( + async streamChatResponseContentWithCodeActions({ + prompt, + stream, + token, + }: { + prompt: PromptResult; + stream: vscode.ChatResponseStream; + token: vscode.CancellationToken; + }): Promise { + const chatResponse = await this._getChatResponse({ + prompt, + token, + }); + + await processStreamWithIdentifiers({ + processStreamFragment: (fragment: string) => { + stream.markdown(fragment); + }, + onStreamIdentifier: (content: string) => { + this._streamCodeBlockActions({ runnableContent: content, stream }); + }, + inputIterable: chatResponse.text, + identifier: codeBlockIdentifier, + }); + } + + // This will stream all of the response content and create a string from it. + // It should only be used when the entire response is needed at one time. 
+ async getChatResponseContent({ + prompt, + token, + }: { + prompt: PromptResult; + token: vscode.CancellationToken; + }): Promise { + let responseContent = ''; + const chatResponse = await this._getChatResponse({ + prompt, + token, + }); + for await (const fragment of chatResponse.text) { + responseContent += fragment; + } + + return responseContent; + } + + async _handleRoutedGenericRequest( request: vscode.ChatRequest, context: vscode.ChatContext, stream: vscode.ChatResponseStream, @@ -230,18 +287,100 @@ export default class ParticipantController { connectionNames: this._getConnectionNames(), }); + await this.streamChatResponseContentWithCodeActions({ + prompt, + token, + stream, + }); + + return genericRequestChatResult(context.history); + } + + async _routeRequestToHandler({ + context, + promptIntent, + request, + stream, + token, + }: { + context: vscode.ChatContext; + promptIntent: Omit; + request: vscode.ChatRequest; + stream: vscode.ChatResponseStream; + token: vscode.CancellationToken; + }): Promise { + switch (promptIntent) { + case 'Query': + return this.handleQueryRequest(request, context, stream, token); + case 'Docs': + return this.handleDocsRequest(request, context, stream, token); + case 'Schema': + return this.handleSchemaRequest(request, context, stream, token); + case 'Code': + return this.handleQueryRequest(request, context, stream, token); + default: + return this._handleRoutedGenericRequest( + request, + context, + stream, + token + ); + } + } + + async _getIntentFromChatRequest({ + context, + request, + token, + }: { + context: vscode.ChatContext; + request: vscode.ChatRequest; + token: vscode.CancellationToken; + }): Promise { + const prompt = await Prompts.intent.buildMessages({ + connectionNames: this._getConnectionNames(), + request, + context, + }); + const responseContent = await this.getChatResponseContent({ prompt, token, }); - stream.markdown(responseContent); - this._streamRunnableContentActions({ - responseContent, - stream, + 
return Prompts.intent.getIntentFromModelResponse(responseContent); + } + + async handleGenericRequest( + request: vscode.ChatRequest, + context: vscode.ChatContext, + stream: vscode.ChatResponseStream, + token: vscode.CancellationToken + ): Promise { + // We "prompt chain" to handle the generic requests. + // First we ask the model to parse for intent. + // If there is an intent, we can route it to one of the handlers (/commands). + // When there is no intention or it's generic we handle it with a generic handler. + const promptIntent = await this._getIntentFromChatRequest({ + context, + request, + token, }); - return genericRequestChatResult(context.history); + if (token.isCancellationRequested) { + return this._handleCancelledRequest({ + context, + stream, + }); + } + + return this._routeRequestToHandler({ + context, + promptIntent, + request, + stream, + token, + }); } async connectWithParticipant({ @@ -917,11 +1056,11 @@ export default class ParticipantController { connectionNames: this._getConnectionNames(), ...(sampleDocuments ? { sampleDocuments } : {}), }); - const responseContent = await this.getChatResponseContent({ + await this.streamChatResponse({ prompt, + stream, token, }); - stream.markdown(responseContent); stream.button({ command: EXTENSION_COMMANDS.PARTICIPANT_OPEN_RAW_SCHEMA_OUTPUT, @@ -1020,16 +1159,11 @@ export default class ParticipantController { connectionNames: this._getConnectionNames(), ...(sampleDocuments ? 
{ sampleDocuments } : {}), }); - const responseContent = await this.getChatResponseContent({ - prompt, - token, - }); - - stream.markdown(responseContent); - this._streamRunnableContentActions({ - responseContent, + await this.streamChatResponseContentWithCodeActions({ + prompt, stream, + token, }); return queryRequestChatResult(context.history); @@ -1097,32 +1231,41 @@ export default class ParticipantController { vscode.ChatResponseStream, vscode.CancellationToken ] - ): Promise<{ - responseContent: string; - responseReferences?: Reference[]; - }> { - const [request, context, , token] = args; + ): Promise { + const [request, context, stream, token] = args; const prompt = await Prompts.generic.buildMessages({ request, context, connectionNames: this._getConnectionNames(), }); - const responseContent = await this.getChatResponseContent({ + await this.streamChatResponseContentWithCodeActions({ prompt, + stream, token, }); - const responseReferences = [ - { + + this._streamResponseReference({ + reference: { url: MONGODB_DOCS_LINK, title: 'View MongoDB documentation', }, - ]; + stream, + }); + } - return { - responseContent, - responseReferences, - }; + _streamResponseReference({ + reference, + stream, + }: { + reference: Reference; + stream: vscode.ChatResponseStream; + }): void { + const link = new vscode.MarkdownString( + `- [${reference.title}](${reference.url})\n` + ); + link.supportHtml = true; + stream.markdown(link); } async handleDocsRequest( @@ -1151,6 +1294,19 @@ export default class ParticipantController { token, stream, }); + + if (docsResult.responseReferences) { + for (const reference of docsResult.responseReferences) { + this._streamResponseReference({ + reference, + stream, + }); + } + } + + if (docsResult.responseContent) { + stream.markdown(docsResult.responseContent); + } } catch (error) { // If the docs chatbot API is not available, fall back to Copilot’s LLM and include // the MongoDB documentation link for users to go to our documentation site 
directly. @@ -1171,25 +1327,7 @@ export default class ParticipantController { } ); - docsResult = await this._handleDocsRequestWithCopilot(...args); - } - - if (docsResult.responseContent) { - stream.markdown(docsResult.responseContent); - this._streamRunnableContentActions({ - responseContent: docsResult.responseContent, - stream, - }); - } - - if (docsResult.responseReferences) { - for (const ref of docsResult.responseReferences) { - const link = new vscode.MarkdownString( - `- [${ref.title}](${ref.url})\n` - ); - link.supportHtml = true; - stream.markdown(link); - } + await this._handleDocsRequestWithCopilot(...args); } return docsRequestChatResult({ diff --git a/src/participant/prompts/generic.ts b/src/participant/prompts/generic.ts index 9b5348790..2112233da 100644 --- a/src/participant/prompts/generic.ts +++ b/src/participant/prompts/generic.ts @@ -1,20 +1,26 @@ import * as vscode from 'vscode'; + import type { PromptArgsBase } from './promptBase'; import { PromptBase } from './promptBase'; +import { codeBlockIdentifier } from '../constants'; + export class GenericPrompt extends PromptBase { protected getAssistantPrompt(): string { return `You are a MongoDB expert. -Your task is to help the user craft MongoDB queries and aggregation pipelines that perform their task. -Keep your response concise. -You should suggest queries that are performant and correct. -Respond with markdown, suggest code in a Markdown code block that begins with \`\`\`javascript and ends with \`\`\`. -You can imagine the schema, collection, and database name. -Respond in MongoDB shell syntax using the \`\`\`javascript code block syntax.`; +Your task is to help the user with MongoDB related questions. +When applicable, you may suggest MongoDB code, queries, and aggregation pipelines that perform their task. +Rules: +1. Keep your response concise. +2. You should suggest code that is performant and correct. +3. Respond with markdown. +4. 
import type { PromptArgsBase } from './promptBase';
import { PromptBase } from './promptBase';

export type PromptIntent = 'Query' | 'Schema' | 'Docs' | 'Default';

/**
 * Prompt used as the first step of the generic request "prompt chain":
 * the model is asked to name the handler that should serve the
 * conversation, and its answer is mapped onto a `PromptIntent`.
 */
export class IntentPrompt extends PromptBase<PromptArgsBase> {
  protected getAssistantPrompt(): string {
    return `You are a MongoDB expert.
Your task is to help guide a conversation with a user to the correct handler.
You will be provided a conversation and your task is to determine the intent of the user.
The intent handlers are:
- Query
- Schema
- Docs
- Default
Rules:
1. Respond only with the intent handler.
2. Use the "Query" intent handler when the user is asking for code that relates to a specific collection.
3. Use the "Docs" intent handler when the user is asking a question that involves MongoDB documentation.
4. Use the "Schema" intent handler when the user is asking for the schema or shape of documents of a specific collection.
5. Use the "Default" intent handler when a user is asking for code that does NOT relate to a specific collection.
6. Use the "Default" intent handler for everything that may not be handled by another handler.
7. If you are uncertain of the intent, use the "Default" intent handler.

Example:
User: How do I create an index in my pineapples collection?
Response:
Query

Example:
User:
What is $vectorSearch?
Response:
Docs`;
  }

  /**
   * Maps the raw model reply onto a known intent.
   *
   * Surrounding whitespace is ignored; the comparison is otherwise exact
   * (case-sensitive). Anything that is not exactly `Query`, `Schema` or
   * `Docs` resolves to `Default`.
   */
  getIntentFromModelResponse(response: string): PromptIntent {
    const routedIntents: readonly PromptIntent[] = ['Query', 'Schema', 'Docs'];
    const candidate = response.trim();
    const recognised = routedIntents.find((intent) => intent === candidate);

    return recognised === undefined ? 'Default' : recognised;
  }
}
You must suggest code that is performant and correct. -Respond with markdown, write code in a Markdown code block that begins with \`\`\`javascript and ends with \`\`\`. -Respond in MongoDB shell syntax using the \`\`\`javascript code block syntax. +Respond with markdown, write code in a Markdown code block that begins with ${codeBlockIdentifier.start} and ends with ${codeBlockIdentifier.end}. +Respond in MongoDB shell syntax using the ${codeBlockIdentifier.start} code block syntax. Concisely explain the code snippet you have generated. Example 1: User: Documents in the orders db, sales collection, where the date is in 2014 and group the total sales for each product. Response: -\`\`\`javascript +${codeBlockIdentifier.start} use('orders'); db.getCollection('sales').aggregate([ // Find all of the sales that occurred in 2014. @@ -35,15 +36,15 @@ db.getCollection('sales').aggregate([ // Group the total sales for each product. { $group: { _id: '$item', totalSaleAmount: { $sum: { $multiply: [ '$price', '$quantity' ] } } } } ]); -\`\`\` +${codeBlockIdentifier.end} Example 2: User: How do I create an index on the name field in my users collection?. Response: -\`\`\`javascript +${codeBlockIdentifier.start} use('test'); db.getCollection('users').createIndex({ name: 1 }); -\`\`\` +${codeBlockIdentifier.end} MongoDB command to specify database: use(''); diff --git a/src/participant/streamParsing.ts b/src/participant/streamParsing.ts new file mode 100644 index 000000000..93bb5dad9 --- /dev/null +++ b/src/participant/streamParsing.ts @@ -0,0 +1,95 @@ +function escapeRegex(str: string): string { + return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); +} + +/** + * This function, provided a stream of text fragments, will stream the + * content to the provided stream and call the onStreamIdentifier function + * when an identifier is streamed. This is useful for inserting code actions + * into a chat response, whenever a code block has been written. 
+ */ +export async function processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable, + identifier, +}: { + processStreamFragment: (fragment: string) => void; + onStreamIdentifier: (content: string) => void; + inputIterable: AsyncIterable; + identifier: { + start: string; + end: string; + }; +}): Promise { + const escapedIdentifierStart = escapeRegex(identifier.start); + const escapedIdentifierEnd = escapeRegex(identifier.end); + const regex = new RegExp( + `${escapedIdentifierStart}([\\s\\S]*?)${escapedIdentifierEnd}`, + 'g' + ); + + let contentSinceLastIdentifier = ''; + for await (const fragment of inputIterable) { + contentSinceLastIdentifier += fragment; + + let lastIndex = 0; + let match: RegExpExecArray | null; + while ((match = regex.exec(contentSinceLastIdentifier)) !== null) { + const endIndex = regex.lastIndex; + + // Stream content up to the end of the identifier. + const contentToStream = contentSinceLastIdentifier.slice( + lastIndex, + endIndex + ); + processStreamFragment(contentToStream); + + const identifierContent = match[1]; + onStreamIdentifier(identifierContent); + + lastIndex = endIndex; + } + + if (lastIndex > 0) { + // Remove all of the processed content. + contentSinceLastIdentifier = contentSinceLastIdentifier.slice(lastIndex); + // Reset the regex. + regex.lastIndex = 0; + } else { + // Clear as much of the content as we can safely. + const maxUnprocessedLength = identifier.start.length - 1; + if (contentSinceLastIdentifier.length > maxUnprocessedLength) { + const identifierIndex = contentSinceLastIdentifier.indexOf( + identifier.start + ); + if (identifierIndex > -1) { + // We have an identifier, so clear up until the identifier. 
+ const contentToStream = contentSinceLastIdentifier.slice( + 0, + identifierIndex + ); + processStreamFragment(contentToStream); + contentSinceLastIdentifier = + contentSinceLastIdentifier.slice(identifierIndex); + } else { + // No identifier, so clear up until the last maxUnprocessedLength. + const processUpTo = + contentSinceLastIdentifier.length - maxUnprocessedLength; + const contentToStream = contentSinceLastIdentifier.slice( + 0, + processUpTo + ); + processStreamFragment(contentToStream); + contentSinceLastIdentifier = + contentSinceLastIdentifier.slice(processUpTo); + } + } + } + } + + // Finish up anything not streamed yet. + if (contentSinceLastIdentifier.length > 0) { + processStreamFragment(contentSinceLastIdentifier); + } +} diff --git a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts index 7f0957273..e5059ce30 100644 --- a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts +++ b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts @@ -17,7 +17,7 @@ import { type TestOutputs, type TestResult, } from './create-test-results-html-page'; -import { runCodeInMessage } from './assertions'; +import { anyOf, runCodeInMessage } from './assertions'; import { Prompts } from '../../participant/prompts'; import type { PromptResult } from '../../participant/prompts/promptBase'; @@ -35,7 +35,7 @@ type AssertProps = { type TestCase = { testCase: string; - type: 'generic' | 'query' | 'namespace'; + type: 'intent' | 'generic' | 'query' | 'namespace'; userInput: string; // Some tests can edit the documents in a collection. // As we want tests to run in isolation this flag will cause the fixture @@ -49,7 +49,9 @@ type TestCase = { only?: boolean; // Translates to mocha's it.only so only this test will run. 
}; -const namespaceTestCases: TestCase[] = [ +const namespaceTestCases: (TestCase & { + type: 'namespace'; +})[] = [ { testCase: 'Namespace included in query', type: 'namespace', @@ -111,7 +113,9 @@ const namespaceTestCases: TestCase[] = [ }, ]; -const queryTestCases: TestCase[] = [ +const queryTestCases: (TestCase & { + type: 'query'; +})[] = [ { testCase: 'Basic query', type: 'query', @@ -246,9 +250,177 @@ const queryTestCases: TestCase[] = [ expect(output.data?.result?.content[0].collectors).to.include('Monkey'); }, }, + { + testCase: 'Complex aggregation with string and number manipulation', + type: 'query', + databaseName: 'CookBook', + collectionName: 'recipes', + userInput: + 'what percentage of recipes have "salt" in their ingredients? "ingredients" is a field ' + + 'with an array of strings of the ingredients. Only consider recipes ' + + 'that have the "difficulty Medium or Easy. Return is as a string named "saltPercentage" like ' + + '"75%", rounded to the nearest whole number.', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + + anyOf([ + (): void => { + const lines = responseContent.trim().split('\n'); + const lastLine = lines[lines.length - 1]; + + expect(lastLine).to.include('saltPercentage'); + expect(output.data?.result?.content).to.include('67%'); + }, + (): void => { + expect(output.printOutput[output.printOutput.length - 1]).to.equal( + "{ saltPercentage: '67%' }" + ); + }, + (): void => { + expect(output.data?.result?.content[0].saltPercentage).to.equal( + '67%' + ); + }, + ])(null); + }, + }, ]; -const testCases: TestCase[] = [...namespaceTestCases, ...queryTestCases]; +const intentTestCases: (TestCase & { + type: 'intent'; +})[] = [ + { + testCase: 'Docs intent', + type: 'intent', + userInput: + 'Where can I find more information on how to connect to MongoDB?', + assertResult: ({ responseContent }: AssertProps): void => 
{ + expect(responseContent).to.equal('Docs'); + }, + }, + { + testCase: 'Docs intent 2', + type: 'intent', + userInput: 'What are the options when creating an aggregation cursor?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Docs'); + }, + }, + { + testCase: 'Query intent', + type: 'intent', + userInput: + 'which collectors specialize only in mint items? and are located in London or New York? an array of their names in a field called collectors', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Query'); + }, + }, + { + testCase: 'Schema intent', + type: 'intent', + userInput: 'What do the documents in the collection pineapple look like?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Schema'); + }, + }, + { + testCase: 'Default/Generic intent 1', + type: 'intent', + userInput: 'How can I connect to MongoDB?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Default'); + }, + }, + { + testCase: 'Default/Generic intent 2', + type: 'intent', + userInput: 'What is the size breakdown of all of the databases?', + assertResult: ({ responseContent }: AssertProps): void => { + expect(responseContent).to.equal('Default'); + }, + }, +]; + +const genericTestCases: (TestCase & { + type: 'generic'; +})[] = [ + { + testCase: 'Database meta data question', + type: 'generic', + userInput: + 'How do I print the name and size of the largest database? Using the print function', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + const printOutput = output.printOutput.join(''); + + // Don't check the name since they're all the base 8192. 
+ expect(printOutput).to.include('8192'); + }, + }, + { + testCase: 'Code question with database, collection, and fields named', + type: 'generic', + userInput: + 'How many sightings happened in the "year" "2020" and "2021"? database "UFO" collection "sightings". code to just return the one total number. also, the year is a string', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + anyOf([ + (): void => { + expect(output.printOutput.join('')).to.equal('2'); + }, + (): void => { + expect(output.data?.result?.content).to.equal('2'); + }, + (): void => { + expect(output.data?.result?.content).to.equal(2); + }, + (): void => { + expect( + Object.entries(output.data?.result?.content[0])[0][1] + ).to.equal(2); + }, + (): void => { + expect( + Object.entries(output.data?.result?.content[0])[0][1] + ).to.equal('2'); + }, + ])(null); + }, + }, + { + testCase: 'Complex aggregation code generation', + type: 'generic', + userInput: + 'what percentage of recipes have "salt" in their ingredients? "ingredients" is a field ' + + 'with an array of strings of the ingredients. Only consider recipes ' + + 'that have the "difficulty Medium or Easy. Return is as a string named "saltPercentage" like ' + + '"75%", rounded to the nearest whole number. 
db CookBook, collection recipes', + assertResult: async ({ + responseContent, + connectionString, + }: AssertProps): Promise => { + const output = await runCodeInMessage(responseContent, connectionString); + + expect(output.data?.result?.content[0].saltPercentage).to.equal('67%'); + }, + }, +]; + +const testCases: TestCase[] = [ + ...namespaceTestCases, + ...queryTestCases, + ...intentTestCases, + ...genericTestCases, +]; const projectRoot = path.join(__dirname, '..', '..', '..'); @@ -320,6 +492,13 @@ const buildMessages = async ({ fixtures: Fixtures; }): Promise => { switch (testCase.type) { + case 'intent': + return Prompts.intent.buildMessages({ + request: { prompt: testCase.userInput }, + context: { history: [] }, + connectionNames: [], + }); + case 'generic': return await Prompts.generic.buildMessages({ request: { prompt: testCase.userInput }, @@ -474,12 +653,14 @@ describe('AI Accuracy Tests', function () { testFunction( `should pass for input: "${testCase.userInput}" if average accuracy is above threshold`, - // eslint-disable-next-line no-loop-func + // eslint-disable-next-line no-loop-func, complexity async function () { console.log(`Starting test run of ${testCase.testCase}.`); const testRunDidSucceed: boolean[] = []; - const successFullRunStats: { + // Successful and unsuccessful runs are both tracked as long as the model + // returns a response. 
+ const runStats: { promptTokens: number; completionTokens: number; executionTimeMS: number; @@ -506,12 +687,15 @@ describe('AI Accuracy Tests', function () { } const startTime = Date.now(); + let responseContent: ChatCompletion | undefined; + let executionTimeMS = 0; try { - const responseContent = await runTest({ + responseContent = await runTest({ testCase, aiBackend, fixtures, }); + executionTimeMS = Date.now() - startTime; testOutputs[testCase.testCase].outputs.push( responseContent.content ); @@ -522,11 +706,6 @@ describe('AI Accuracy Tests', function () { mongoClient, }); - successFullRunStats.push({ - completionTokens: responseContent.usageStats.completionTokens, - promptTokens: responseContent.usageStats.promptTokens, - executionTimeMS: Date.now() - startTime, - }); success = true; console.log( @@ -539,6 +718,18 @@ describe('AI Accuracy Tests', function () { ); } + if ( + responseContent && + responseContent?.usageStats?.completionTokens > 0 && + executionTimeMS !== 0 + ) { + runStats.push({ + completionTokens: responseContent.usageStats.completionTokens, + promptTokens: responseContent.usageStats.promptTokens, + executionTimeMS, + }); + } + testRunDidSucceed.push(success); } @@ -559,21 +750,19 @@ describe('AI Accuracy Tests', function () { Accuracy: averageAccuracy, Pass: didFail ? '✗' : '✓', 'Avg Execution Time (ms)': - successFullRunStats.length > 0 - ? successFullRunStats.reduce((a, b) => a + b.executionTimeMS, 0) / - successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.executionTimeMS, 0) / + runStats.length : 0, 'Avg Prompt Tokens': - successFullRunStats.length > 0 - ? successFullRunStats.reduce((a, b) => a + b.promptTokens, 0) / - successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.promptTokens, 0) / + runStats.length : 0, 'Avg Completion Tokens': - successFullRunStats.length > 0 - ? 
successFullRunStats.reduce( - (a, b) => a + b.completionTokens, - 0 - ) / successFullRunStats.length + runStats.length > 0 + ? runStats.reduce((a, b) => a + b.completionTokens, 0) / + runStats.length : 0, }); diff --git a/src/test/ai-accuracy-tests/assertions.ts b/src/test/ai-accuracy-tests/assertions.ts index 31460ef20..304cc68b8 100644 --- a/src/test/ai-accuracy-tests/assertions.ts +++ b/src/test/ai-accuracy-tests/assertions.ts @@ -3,9 +3,11 @@ import util from 'util'; import type { Document } from 'mongodb'; import type { Fixtures } from './fixtures/fixture-loader'; -import { getRunnableContentFromString } from '../../participant/participant'; import { execute } from '../../language/worker'; import type { ShellEvaluateResult } from '../../types/playgroundType'; +import { asyncIterableFromArray } from '../suite/participant/asyncIterableFromArray'; +import { codeBlockIdentifier } from '../../participant/constants'; +import { processStreamWithIdentifiers } from '../../participant/streamParsing'; export const runCodeInMessage = async ( message: string, @@ -15,7 +17,18 @@ export const runCodeInMessage = async ( data: ShellEvaluateResult; error: any; }> => { - const codeToEvaluate = getRunnableContentFromString(message); + // We only run the last code block passed. 
+ let codeToEvaluate = ''; + await processStreamWithIdentifiers({ + processStreamFragment: () => { + /* no-op */ + }, + onStreamIdentifier: (codeBlockContent: string): void => { + codeToEvaluate = codeBlockContent; + }, + inputIterable: asyncIterableFromArray([message]), + identifier: codeBlockIdentifier, + }); if (codeToEvaluate.trim().length === 0) { throw new Error(`no code found in message: ${message}`); diff --git a/src/test/ai-accuracy-tests/create-test-results-html-page.ts b/src/test/ai-accuracy-tests/create-test-results-html-page.ts index f58a50950..c3774a59e 100644 --- a/src/test/ai-accuracy-tests/create-test-results-html-page.ts +++ b/src/test/ai-accuracy-tests/create-test-results-html-page.ts @@ -23,6 +23,9 @@ export type TestOutputs = { [testName: string]: TestOutput; }; +const createTestLinkId = (testName: string): string => + encodeURIComponent(testName.replace(/ /g, '-')); + function getTestResultsTable(testResults: TestResult[]): string { const headers = Object.keys(testResults[0]) .map((key) => `${key}`) @@ -30,8 +33,15 @@ function getTestResultsTable(testResults: TestResult[]): string { const resultRows = testResults .map((result) => { - const row = Object.values(result) - .map((value) => `${value}`) + const row = Object.entries(result) + .map( + ([field, value]) => + `${ + field === 'Test' + ? `${value}` + : value + }` + ) .join(''); return `${row}`; }) @@ -56,7 +66,9 @@ function getTestOutputTables(testOutputs: TestOutputs): string { .map((out) => `${out}`) .join(''); return ` -

${testName} [${output.testType}]

+

Prompt: ${output.prompt}

diff --git a/src/test/ai-accuracy-tests/fixtures/recipes.ts b/src/test/ai-accuracy-tests/fixtures/recipes.ts index efb347bb3..f9e8d7346 100644 --- a/src/test/ai-accuracy-tests/fixtures/recipes.ts +++ b/src/test/ai-accuracy-tests/fixtures/recipes.ts @@ -12,6 +12,7 @@ const recipes: Fixture = { 'tomato sauce', 'onions', 'garlic', + 'salt', ], preparationTime: 60, difficulty: 'Medium', @@ -23,6 +24,19 @@ const recipes: Fixture = { preparationTime: 10, difficulty: 'Easy', }, + { + title: 'Pineapple', + ingredients: ['pineapple'], + preparationTime: 5, + difficulty: 'Very Hard', + }, + { + title: 'Pizza', + ingredients: ['dough', 'tomato sauce', 'mozzarella cheese', 'basil'], + optionalIngredients: ['pineapple'], + preparationTime: 40, + difficulty: 'Medium', + }, { title: 'Beef Wellington', ingredients: [ @@ -30,6 +44,7 @@ const recipes: Fixture = { 'mushroom duxelles', 'puff pastry', 'egg wash', + 'salt', ], preparationTime: 120, difficulty: 'Hard', diff --git a/src/test/suite/participant/asyncIterableFromArray.ts b/src/test/suite/participant/asyncIterableFromArray.ts new file mode 100644 index 000000000..e3b7c8bde --- /dev/null +++ b/src/test/suite/participant/asyncIterableFromArray.ts @@ -0,0 +1,24 @@ +// Exported here so that the accuracy tests can use it without +// needing to define all of the testing types the main tests have. 
+export function asyncIterableFromArray(array: T[]): AsyncIterable { + return { + [Symbol.asyncIterator](): { + next(): Promise>; + } { + let index = 0; + return { + next(): Promise<{ + value: any; + done: boolean; + }> { + if (index < array.length) { + const value = array[index++]; + return Promise.resolve({ value, done: false }); + } + + return Promise.resolve({ value: undefined, done: true }); + }, + }; + }, + }; +} diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index c5e632aed..5a7e69cc8 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -6,9 +6,7 @@ import sinon from 'sinon'; import type { DataService } from 'mongodb-data-service'; import { ObjectId, Int32 } from 'bson'; -import ParticipantController, { - getRunnableContentFromString, -} from '../../../participant/participant'; +import ParticipantController from '../../../participant/participant'; import ConnectionController from '../../../connectionController'; import { StorageController } from '../../../storage'; import { StatusView } from '../../../views'; @@ -189,18 +187,6 @@ suite('Participant Controller Test Suite', function () { expect(collectionName).to.be.equal('cats'); }); - test('parses a returned by ai text for code blocks', function () { - const text = - '```javascript\n' + - "use('test');\n" + - "db.getCollection('test').find({ name: 'Shika' });\n" + - '```'; - const code = getRunnableContentFromString(text); - expect(code).to.be.equal( - "use('test');\ndb.getCollection('test').find({ name: 'Shika' });" - ); - }); - suite('when not connected', function () { let connectWithConnectionIdStub; let changeActiveConnectionStub; @@ -470,14 +456,71 @@ suite('Participant Controller Test Suite', function () { }); suite('generic command', function () { - test('generates a query', async function () { + suite('when the intent is recognized', function () { + beforeEach(function () { + 
sendRequestStub.onCall(0).resolves({ + text: ['Schema'], + }); + }); + + test('routes to the appropriate handler', async function () { + const chatRequestMock = { + prompt: + 'what is the shape of the documents in the pineapple collection?', + command: undefined, + references: [], + }; + const res = await invokeChatHandler(chatRequestMock); + + expect(sendRequestStub).to.have.been.calledTwice; + const intentRequest = sendRequestStub.firstCall.args[0]; + expect(intentRequest).to.have.length(2); + expect(intentRequest[0].content).to.include( + 'Your task is to help guide a conversation with a user to the correct handler.' + ); + expect(intentRequest[1].content).to.equal( + 'what is the shape of the documents in the pineapple collection?' + ); + const genericRequest = sendRequestStub.secondCall.args[0]; + expect(genericRequest).to.have.length(2); + expect(genericRequest[0].content).to.include( + 'Parse all user messages to find a database name and a collection name.' + ); + expect(genericRequest[1].content).to.equal( + 'what is the shape of the documents in the pineapple collection?' + ); + + expect(res?.metadata.intent).to.equal('askForNamespace'); + }); + }); + + test('default handler asks for intent and shows code run actions', async function () { const chatRequestMock = { prompt: 'how to find documents in my collection?', command: undefined, references: [], }; - await invokeChatHandler(chatRequestMock); + const res = await invokeChatHandler(chatRequestMock); + + expect(sendRequestStub).to.have.been.calledTwice; + const intentRequest = sendRequestStub.firstCall.args[0]; + expect(intentRequest).to.have.length(2); + expect(intentRequest[0].content).to.include( + 'Your task is to help guide a conversation with a user to the correct handler.' + ); + expect(intentRequest[1].content).to.equal( + 'how to find documents in my collection?' 
+ ); + const genericRequest = sendRequestStub.secondCall.args[0]; + expect(genericRequest).to.have.length(2); + expect(genericRequest[0].content).to.include( + 'Your task is to help the user with MongoDB related questions.' + ); + expect(genericRequest[1].content).to.equal( + 'how to find documents in my collection?' + ); + expect(res?.metadata.intent).to.equal('generic'); expect(chatStreamStub?.button.getCall(0).args[0]).to.deep.equal({ command: 'mdb.runParticipantQuery', title: '▶️ Run', diff --git a/src/test/suite/participant/streamParsing.test.ts b/src/test/suite/participant/streamParsing.test.ts new file mode 100644 index 000000000..66208ecdd --- /dev/null +++ b/src/test/suite/participant/streamParsing.test.ts @@ -0,0 +1,219 @@ +import { beforeEach } from 'mocha'; +import { expect } from 'chai'; + +import { processStreamWithIdentifiers } from '../../../participant/streamParsing'; +import { asyncIterableFromArray } from './asyncIterableFromArray'; + +const defaultCodeBlockIdentifier = { + start: '```', + end: '```', +}; + +suite('processStreamWithIdentifiers', () => { + let fragmentsProcessed: string[] = []; + let identifiersStreamed: string[] = []; + + const processStreamFragment = (fragment: string): void => { + fragmentsProcessed.push(fragment); + }; + + const onStreamIdentifier = (content: string): void => { + identifiersStreamed.push(content); + }; + + beforeEach(function () { + fragmentsProcessed = []; + identifiersStreamed = []; + }); + + test('empty', async () => { + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable: asyncIterableFromArray([]), + identifier: defaultCodeBlockIdentifier, + }); + + expect(fragmentsProcessed).to.be.empty; + expect(identifiersStreamed).to.be.empty; + }); + + test('input with no code block', async () => { + const inputText = 'This is some sample text without code blocks.'; + const inputFragments = inputText.match(/.{1,5}/g) || []; + const inputIterable = 
asyncIterableFromArray(inputFragments); + + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable, + identifier: defaultCodeBlockIdentifier, + }); + + expect(fragmentsProcessed.join('')).to.equal(inputText); + expect(identifiersStreamed).to.be.empty; + }); + + test('one code block with fragment sizes 2', async () => { + const inputText = '```javascript\npineapple\n```\nMore text.'; + const inputFragments: string[] = []; + let index = 0; + const fragmentSize = 2; + while (index < inputText.length) { + const fragment = inputText.substr(index, fragmentSize); + inputFragments.push(fragment); + index += fragmentSize; + } + + const inputIterable = asyncIterableFromArray(inputFragments); + + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable, + identifier: { + start: '```javascript', + end: '```', + }, + }); + + expect(fragmentsProcessed.join('')).to.equal(inputText); + expect(identifiersStreamed).to.have.lengthOf(1); + expect(identifiersStreamed[0]).to.equal('\npineapple\n'); + }); + + test('multiple code blocks', async () => { + const inputText = + 'Text before code.\n```\ncode1\n```\nText between code.\n```\ncode2\n```\nText after code.'; + const inputFragments = inputText.split(''); + + const inputIterable = asyncIterableFromArray(inputFragments); + + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable, + identifier: defaultCodeBlockIdentifier, + }); + + expect(fragmentsProcessed.join('')).to.equal(inputText); + expect(identifiersStreamed).to.deep.equal(['\ncode1\n', '\ncode2\n']); + }); + + test('unfinished code block', async () => { + const inputText = + 'Text before code.\n```\ncode content without end identifier.'; + const inputFragments = inputText.split(''); + + const inputIterable = asyncIterableFromArray(inputFragments); + + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + 
inputIterable, + identifier: defaultCodeBlockIdentifier, + }); + + expect(fragmentsProcessed.join('')).to.equal(inputText); + expect(identifiersStreamed).to.be.empty; + }); + + test('code block identifier is a fragment', async () => { + const inputFragments = [ + 'Text before code.\n', + '```js', + '\ncode content\n', + '```', + '```js', + '\npineapple\n', + '```', + '\nText after code.', + ]; + + const inputIterable = asyncIterableFromArray(inputFragments); + + const identifier = { start: '```js', end: '```' }; + + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable, + identifier, + }); + + expect(fragmentsProcessed.join('')).to.deep.equal(inputFragments.join('')); + + expect(identifiersStreamed).to.deep.equal([ + '\ncode content\n', + '\npineapple\n', + ]); + }); + + test('code block identifier split between fragments', async () => { + const inputFragments = [ + 'Text before code.\n`', + '``j', + 's\ncode content\n`', + '``', + '\nText after code.', + ]; + + const inputIterable = asyncIterableFromArray(inputFragments); + + const identifier = { start: '```js', end: '```' }; + + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable, + identifier, + }); + + expect(fragmentsProcessed.join('')).to.deep.equal(inputFragments.join('')); + + expect(identifiersStreamed).to.deep.equal(['\ncode content\n']); + }); + + test('fragments containing multiple code blocks', async () => { + const inputFragments = [ + 'Text before code.\n```', + 'js\ncode1\n```', + '\nText', + ' between code.\n``', + '`js\ncode2\n``', + '`\nText after code.', + ]; + + const inputIterable = asyncIterableFromArray(inputFragments); + const identifier = { start: '```js', end: '```' }; + + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable, + identifier, + }); + + expect(fragmentsProcessed.join('')).to.deep.equal(inputFragments.join('')); + 
expect(identifiersStreamed).to.deep.equal(['\ncode1\n', '\ncode2\n']); + }); + + test('one fragment containing multiple code blocks', async () => { + const inputFragments = [ + 'Text before code.\n```js\ncode1\n```\nText between code.\n```js\ncode2\n```\nText after code.', + ]; + + const inputIterable = asyncIterableFromArray(inputFragments); + const identifier = { start: '```js', end: '```' }; + + await processStreamWithIdentifiers({ + processStreamFragment, + onStreamIdentifier, + inputIterable, + identifier, + }); + + expect(fragmentsProcessed.join('')).to.deep.equal(inputFragments.join('')); + expect(identifiersStreamed).to.deep.equal(['\ncode1\n', '\ncode2\n']); + }); +});