From b54b22ff990f6f1bee675c44f4ab363d8a603aae Mon Sep 17 00:00:00 2001 From: Alena Khineika Date: Fri, 6 Sep 2024 19:51:25 +0200 Subject: [PATCH] feat: send sample documents to the model for better results VSCODE-580 (#806) * feat: send sample documents to the model for better results VSCODE-580 * refactor: address pr comments * refactor: count tokens --- .eslintrc.js | 6 + README.md | 1 + package.json | 5 + src/editors/editDocumentCodeLensProvider.ts | 4 +- src/editors/playgroundController.ts | 14 +- src/mdbExtensionController.ts | 15 +- src/participant/model.ts | 23 ++ src/participant/participant.ts | 64 ++-- src/participant/prompts/query.ts | 53 ++-- src/participant/sampleDocuments.ts | 72 +++++ .../ai-accuracy-tests/ai-accuracy-tests.ts | 15 +- src/test/ai-accuracy-tests/ai-backend.ts | 4 +- .../suite/participant/participant.test.ts | 277 +++++++++++++++++- 13 files changed, 473 insertions(+), 80 deletions(-) create mode 100644 src/participant/model.ts create mode 100644 src/participant/sampleDocuments.ts diff --git a/.eslintrc.js b/.eslintrc.js index 8324538d6..24170e282 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -63,6 +63,12 @@ module.exports = { 'error', { prefer: 'type-imports' }, ], + '@typescript-eslint/explicit-function-return-type': [ + 'warn', + { + allowHigherOrderFunctions: true, + }, + ], }, parserOptions: { project: ['./tsconfig.json'], // Specify it only for TypeScript files. diff --git a/README.md b/README.md index 01a8fb61e..d3233d25f 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,7 @@ Connect to Atlas Stream Processing instances and develop stream processors using | `mdb.defaultLimit` | The number of documents to fetch when viewing documents from a collection. | `10` | | `mdb.confirmRunAll` | Show a confirmation message before running commands in a playground. | `true` | | `mdb.confirmRunCopilotCode` | Show a confirmation message before running code generated by the MongoDB participant. 
| `true` | +| `mdb.useSampleDocsInCopilot` | Enable sending sample field values with the VSCode copilot chat @MongoDB participant /query command. | `false` | | `mdb.confirmDeleteDocument` | Show a confirmation message before deleting a document in the tree view. | `true` | | `mdb.persistOIDCTokens` | Remain logged in when using the MONGODB-OIDC authentication mechanism for MongoDB server connection. Access tokens are encrypted using the system keychain before being stored. | `true` | | `mdb.showOIDCDeviceAuthFlow` | Opt-in and opt-out for diagnostic and telemetry collection. | `true` | diff --git a/package.json b/package.json index 7ac90b042..7b742b4a9 100644 --- a/package.json +++ b/package.json @@ -1101,6 +1101,11 @@ "default": true, "description": "Show a confirmation message before running code generated by the MongoDB participant." }, + "mdb.useSampleDocsInCopilot": { + "type": "boolean", + "default": false, + "description": "Enable sending sample field values with the VSCode copilot chat @MongoDB participant /query command." 
+ }, "mdb.confirmDeleteDocument": { "type": "boolean", "default": true, diff --git a/src/editors/editDocumentCodeLensProvider.ts b/src/editors/editDocumentCodeLensProvider.ts index 723df8deb..c380c95ec 100644 --- a/src/editors/editDocumentCodeLensProvider.ts +++ b/src/editors/editDocumentCodeLensProvider.ts @@ -33,7 +33,7 @@ export default class EditDocumentCodeLensProvider content: Document; namespace: string | null; uri: vscode.Uri; - }) { + }): void { let resultCodeLensesInfo: EditDocumentInfo[] = []; resultCodeLensesInfo = this._updateCodeLensesForCursor({ @@ -44,7 +44,7 @@ export default class EditDocumentCodeLensProvider this._codeLensesInfo[data.uri.toString()] = resultCodeLensesInfo; } - updateCodeLensesForPlayground(playgroundResult: PlaygroundResult) { + updateCodeLensesForPlayground(playgroundResult: PlaygroundResult): void { const source = DocumentSource.DOCUMENT_SOURCE_PLAYGROUND; let resultCodeLensesInfo: EditDocumentInfo[] = []; diff --git a/src/editors/playgroundController.ts b/src/editors/playgroundController.ts index 0d82c9bc5..505459191 100644 --- a/src/editors/playgroundController.ts +++ b/src/editors/playgroundController.ts @@ -599,7 +599,7 @@ export default class PlaygroundController { await this._openInResultPane(evaluateResponse.result); - return Promise.resolve(true); + return true; } async _evaluatePlayground(text: string): Promise { @@ -684,17 +684,11 @@ export default class PlaygroundController { return Promise.resolve(false); } - const selections = this._activeTextEditor.selections; - - let codeToEvaluate; - if ( - !selections || - !Array.isArray(selections) || - (selections.length === 1 && this._getSelectedText(selections[0]) === '') - ) { + let codeToEvaluate = ''; + if (!this._selectedText) { this._isPartialRun = false; codeToEvaluate = this._getAllText(); - } else if (this._selectedText) { + } else { this._isPartialRun = true; codeToEvaluate = this._selectedText; } diff --git a/src/mdbExtensionController.ts 
b/src/mdbExtensionController.ts index ac88576dd..9574ae6f8 100644 --- a/src/mdbExtensionController.ts +++ b/src/mdbExtensionController.ts @@ -151,6 +151,7 @@ export default class MDBExtensionController implements vscode.Disposable { this._helpExplorer.activateHelpTreeView(this._telemetryService); this._playgroundsExplorer.activatePlaygroundsTreeView(); this._telemetryService.activateSegmentAnalytics(); + this._participantController.createParticipant(this._context); await this._connectionController.loadSavedConnections(); await this._languageServerController.startLanguageServer(); @@ -332,11 +333,13 @@ export default class MDBExtensionController implements vscode.Disposable { return commandHandler(args); }; - - this._context.subscriptions.push( - this._participantController.getParticipant(this._context), - vscode.commands.registerCommand(command, commandHandlerWithTelemetry) - ); + const participant = this._participantController.getParticipant(); + if (participant) { + this._context.subscriptions.push( + participant, + vscode.commands.registerCommand(command, commandHandlerWithTelemetry) + ); + } }; registerCommand = ( @@ -778,7 +781,7 @@ export default class MDBExtensionController implements vscode.Disposable { this.registerAtlasStreamsTreeViewCommands(); } - registerAtlasStreamsTreeViewCommands() { + registerAtlasStreamsTreeViewCommands(): void { this.registerCommand( EXTENSION_COMMANDS.MDB_ADD_STREAM_PROCESSOR, async (element: ConnectionTreeItem): Promise => { diff --git a/src/participant/model.ts b/src/participant/model.ts new file mode 100644 index 000000000..f5c2568d0 --- /dev/null +++ b/src/participant/model.ts @@ -0,0 +1,23 @@ +import * as vscode from 'vscode'; + +import { CHAT_PARTICIPANT_MODEL } from './constants'; + +let model: vscode.LanguageModelChat; + +export async function getCopilotModel(): Promise< + vscode.LanguageModelChat | undefined +> { + if (!model) { + try { + const [model] = await vscode.lm.selectChatModels({ + vendor: 'copilot', + family: 
CHAT_PARTICIPANT_MODEL, + }); + return model; + } catch (err) { + // Model is not ready yet. It is being initialised with the first user prompt. + } + } + + return; +} diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 6d3f5005e..de958f678 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -1,5 +1,6 @@ import * as vscode from 'vscode'; import { getSimplifiedSchema } from 'mongodb-schema'; +import type { Document } from 'bson'; import { createLogger } from '../logging'; import type ConnectionController from '../connectionController'; @@ -8,10 +9,12 @@ import EXTENSION_COMMANDS from '../commands'; import type { StorageController } from '../storage'; import { StorageVariables } from '../storage'; import { GenericPrompt } from './prompts/generic'; -import { CHAT_PARTICIPANT_ID, CHAT_PARTICIPANT_MODEL } from './constants'; +import { CHAT_PARTICIPANT_ID } from './constants'; import { QueryPrompt } from './prompts/query'; import { COL_NAME_ID, DB_NAME_ID, NamespacePrompt } from './prompts/namespace'; import { SchemaFormatter } from './schema'; +import { getSimplifiedSampleDocuments } from './sampleDocuments'; +import { getCopilotModel } from './model'; const log = createLogger('participant'); @@ -20,10 +23,11 @@ export enum QUERY_GENERATION_STATE { ASK_TO_CONNECT = 'ASK_TO_CONNECT', ASK_FOR_DATABASE_NAME = 'ASK_FOR_DATABASE_NAME', ASK_FOR_COLLECTION_NAME = 'ASK_FOR_COLLECTION_NAME', + CHANGE_DATABASE_NAME = 'CHANGE_DATABASE_NAME', FETCH_SCHEMA = 'FETCH_SCHEMA', } -const NUM_DOCUMENTS_TO_SAMPLE = 4; +const NUM_DOCUMENTS_TO_SAMPLE = 3; interface ChatResult extends vscode.ChatResult { metadata: { @@ -50,7 +54,7 @@ export function parseForDatabaseAndCollectionName(text: string): { return { databaseName, collectionName }; } -export function getRunnableContentFromString(text: string) { +export function getRunnableContentFromString(text: string): string { const matchedJSresponseContent = 
text.match(/```javascript((.|\n)*)```/); const code = @@ -69,6 +73,7 @@ export default class ParticipantController { _databaseName?: string; _collectionName?: string; _schema?: string; + _sampleDocuments?: Document[]; constructor({ connectionController, @@ -81,17 +86,18 @@ export default class ParticipantController { this._storageController = storageController; } - _setDatabaseName(name: string | undefined) { + _setDatabaseName(name: string | undefined): void { if ( this._queryGenerationState === QUERY_GENERATION_STATE.DEFAULT && this._databaseName !== name ) { - this._queryGenerationState = QUERY_GENERATION_STATE.FETCH_SCHEMA; + this._queryGenerationState = QUERY_GENERATION_STATE.CHANGE_DATABASE_NAME; + this._collectionName = undefined; } this._databaseName = name; } - _setCollectionName(name: string | undefined) { + _setCollectionName(name: string | undefined): void { if ( this._queryGenerationState === QUERY_GENERATION_STATE.DEFAULT && this._collectionName !== name @@ -101,7 +107,7 @@ export default class ParticipantController { this._collectionName = name; } - createParticipant(context: vscode.ExtensionContext) { + createParticipant(context: vscode.ExtensionContext): vscode.ChatParticipant { // Chat participants appear as top-level options in the chat input // when you type `@`, and can contribute sub-commands in the chat input // that appear when you type `/`. 
@@ -120,8 +126,8 @@ export default class ParticipantController { return this._participant; } - getParticipant(context: vscode.ExtensionContext) { - return this._participant || this.createParticipant(context); + getParticipant(): vscode.ChatParticipant | undefined { + return this._participant; } async handleEmptyQueryRequest(): Promise<(string | vscode.MarkdownString)[]> { @@ -193,20 +199,17 @@ export default class ParticipantController { stream: vscode.ChatResponseStream; token: vscode.CancellationToken; }): Promise { + const model = await getCopilotModel(); let responseContent = ''; - try { - const [model] = await vscode.lm.selectChatModels({ - vendor: 'copilot', - family: CHAT_PARTICIPANT_MODEL, - }); - if (model) { + if (model) { + try { const chatResponse = await model.sendRequest(messages, {}, token); for await (const fragment of chatResponse.text) { responseContent += fragment; } + } catch (err) { + this.handleError(err, stream); } - } catch (err) { - this.handleError(err, stream); } return responseContent; @@ -483,14 +486,17 @@ export default class ParticipantController { ![ QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME, QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME, + QUERY_GENERATION_STATE.CHANGE_DATABASE_NAME, ].includes(this._queryGenerationState) ) { return false; } if ( - this._queryGenerationState === - QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME + [ + QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME, + QUERY_GENERATION_STATE.CHANGE_DATABASE_NAME, + ].includes(this._queryGenerationState) ) { this._setDatabaseName(prompt); if (!this._collectionName) { @@ -616,7 +622,9 @@ export default class ParticipantController { return this._queryGenerationState === QUERY_GENERATION_STATE.FETCH_SCHEMA; } - async _fetchCollectionSchema(abortSignal?: AbortSignal): Promise { + async _fetchCollectionSchemaAndSampleDocuments( + abortSignal?: AbortSignal + ): Promise { if (this._queryGenerationState === QUERY_GENERATION_STATE.FETCH_SCHEMA) { this._queryGenerationState = 
QUERY_GENERATION_STATE.DEFAULT; } @@ -642,8 +650,17 @@ export default class ParticipantController { const schema = await getSimplifiedSchema(sampleDocuments); this._schema = new SchemaFormatter().format(schema); + + const useSampleDocsInCopilot = !!vscode.workspace + .getConfiguration('mdb') + .get('useSampleDocsInCopilot'); + + if (useSampleDocsInCopilot) { + this._sampleDocuments = getSimplifiedSampleDocuments(sampleDocuments); + } } catch (err: any) { this._schema = undefined; + this._sampleDocuments = undefined; } } @@ -679,15 +696,18 @@ export default class ParticipantController { }); if (this._shouldFetchCollectionSchema()) { - await this._fetchCollectionSchema(abortController.signal); + await this._fetchCollectionSchemaAndSampleDocuments( + abortController.signal + ); } - const messages = QueryPrompt.buildMessages({ + const messages = await QueryPrompt.buildMessages({ request, context, databaseName: this._databaseName, collectionName: this._collectionName, schema: this._schema, + sampleDocuments: this._sampleDocuments, }); const responseContent = await this.getChatResponseContent({ messages, diff --git a/src/participant/prompts/query.ts b/src/participant/prompts/query.ts index 58037a58b..18ce1c812 100644 --- a/src/participant/prompts/query.ts +++ b/src/participant/prompts/query.ts @@ -1,18 +1,22 @@ import * as vscode from 'vscode'; +import type { Document } from 'bson'; import { getHistoryMessages } from './history'; +import { getStringifiedSampleDocuments } from '../sampleDocuments'; export class QueryPrompt { - static getAssistantPrompt({ + static async getAssistantPrompt({ databaseName = 'mongodbVSCodeCopilotDB', collectionName = 'test', schema, + sampleDocuments, }: { databaseName?: string; collectionName?: string; schema?: string; - }): vscode.LanguageModelChatMessage { - const prompt = `You are a MongoDB expert. + sampleDocuments?: Document[]; + }): Promise { + let prompt = `You are a MongoDB expert. 
Your task is to help the user craft MongoDB queries and aggregation pipelines that perform their task. Keep your response concise. @@ -23,6 +27,8 @@ Respond in MongoDB shell syntax using the \`\`\`javascript code block syntax. You can use only the following MongoDB Shell commands: use, aggregate, bulkWrite, countDocuments, findOneAndReplace, findOneAndUpdate, insert, insertMany, insertOne, remove, replaceOne, update, updateMany, updateOne. +Concisely explain the code snippet you have generated. + Example 1: use(''); db.getCollection('').aggregate([ @@ -38,22 +44,26 @@ db.getCollection('').find({ date: { $gte: new Date('2014-04-04'), $lt: new Date('2014-04-05') } }).count(); -Database name: ${databaseName} -Collection name: ${collectionName} -${ - schema - ? `Collection schema: -${schema}` - : '' -} - MongoDB command to specify database: use(''); MongoDB command to specify collection: -db.getCollection('') - -Concisely explain the code snippet you have generated.`; +db.getCollection('');\n\n`; + if (databaseName) { + prompt += `Database name: ${databaseName}\n`; + } + if (collectionName) { + prompt += `Collection name: ${collectionName}\n`; + } + if (schema) { + prompt += `Collection schema: ${schema}\n`; + } + if (sampleDocuments) { + prompt += await getStringifiedSampleDocuments({ + sampleDocuments, + prompt, + }); + } // eslint-disable-next-line new-cap return vscode.LanguageModelChatMessage.Assistant(prompt); @@ -64,12 +74,13 @@ Concisely explain the code snippet you have generated.`; return vscode.LanguageModelChatMessage.User(prompt); } - static buildMessages({ + static async buildMessages({ context, request, databaseName, collectionName, schema, + sampleDocuments, }: { request: { prompt: string; @@ -78,9 +89,15 @@ Concisely explain the code snippet you have generated.`; databaseName?: string; collectionName?: string; schema?: string; - }): vscode.LanguageModelChatMessage[] { + sampleDocuments?: Document[]; + }): Promise { const messages = [ - 
QueryPrompt.getAssistantPrompt({ databaseName, collectionName, schema }), + await QueryPrompt.getAssistantPrompt({ + databaseName, + collectionName, + schema, + sampleDocuments, + }), ...getHistoryMessages({ context }), QueryPrompt.getUserPrompt(request.prompt), ]; diff --git a/src/participant/sampleDocuments.ts b/src/participant/sampleDocuments.ts new file mode 100644 index 000000000..bfefa310b --- /dev/null +++ b/src/participant/sampleDocuments.ts @@ -0,0 +1,72 @@ +import { toJSString } from 'mongodb-query-parser'; +import type { Document } from 'bson'; +import { getCopilotModel } from './model'; + +const MAX_ARRAY_LENGTH_OF_SAMPLE_DOCUMENT_VALUE = 3; + +const MAX_STRING_LENGTH_OF_SAMPLE_DOCUMENT_VALUE = 20; + +export function getSimplifiedSampleDocuments(obj: Document[]): Document[] { + function truncate(value: any): any { + if (typeof value === 'string') { + return value.slice(0, MAX_STRING_LENGTH_OF_SAMPLE_DOCUMENT_VALUE); + } else if (typeof value === 'object' && value !== null) { + if (Array.isArray(value)) { + value = value.slice(0, MAX_ARRAY_LENGTH_OF_SAMPLE_DOCUMENT_VALUE); + } + // Recursively truncate strings in nested objects or arrays. + for (const key in value) { + if (value.hasOwnProperty(key)) { + value[key] = truncate(value[key]); + } + } + } + return value; + } + + return truncate(obj); +} + +export async function getStringifiedSampleDocuments({ + prompt, + sampleDocuments, +}: { + prompt: string; + sampleDocuments: Document[]; +}): Promise { + if (!sampleDocuments.length) { + return ''; + } + + const model = await getCopilotModel(); + if (!model) { + return ''; + } + + let additionToPrompt: Document[] | Document = sampleDocuments; + let promptInputTokens = await model.countTokens( + prompt + toJSString(sampleDocuments) + ); + + // First check the length of all stringified sample documents. + // If the resulting prompt is too large, proceed with only 1 sample document. 
+ // We also convert an array that contains only 1 element to a single document. + if ( + promptInputTokens > model.maxInputTokens || + sampleDocuments.length === 1 + ) { + additionToPrompt = sampleDocuments[0]; + } + + const stringifiedDocuments = toJSString(additionToPrompt); + promptInputTokens = await model.countTokens(prompt + stringifiedDocuments); + + // Add sample documents to the prompt only when it fits in the context window. + if (promptInputTokens <= model.maxInputTokens) { + return `Sample document${ + Array.isArray(additionToPrompt) ? 's' : '' + }: ${stringifiedDocuments}\n`; + } + + return ''; +} diff --git a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts index f26cdc11c..1ad4924bb 100644 --- a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts +++ b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts @@ -11,6 +11,7 @@ import * as vscode from 'vscode'; import { loadFixturesToDB, reloadFixture } from './fixtures/fixture-loader'; import type { Fixtures } from './fixtures/fixture-loader'; import { AIBackend } from './ai-backend'; +import type { ChatCompletion } from './ai-backend'; import { GenericPrompt } from '../../participant/prompts/generic'; import { QueryPrompt } from '../../participant/prompts/query'; import { @@ -106,7 +107,7 @@ const queryTestCases: TestCase[] = [ assertResult: async ({ responseContent, connectionString, - }: AssertProps) => { + }: AssertProps): Promise => { const result = await runCodeInMessage(responseContent, connectionString); const totalResponse = `${result.printOutput.join('')}${ @@ -242,7 +243,7 @@ async function pushResultsToDB({ anyFailedAccuracyThreshold: boolean; runTimeMS: number; httpErrors: number; -}) { +}): Promise { const client = new MongoClient( process.env.AI_ACCURACY_RESULTS_MONGODB_CONNECTION_STRING || '' ); @@ -278,13 +279,13 @@ async function pushResultsToDB({ } } -const buildMessages = ({ +const buildMessages = async ({ testCase, fixtures, }: { testCase: 
TestCase; fixtures: Fixtures; -}) => { +}): Promise => { switch (testCase.type) { case 'generic': return GenericPrompt.buildMessages({ @@ -293,7 +294,7 @@ const buildMessages = ({ }); case 'query': - return QueryPrompt.buildMessages({ + return await QueryPrompt.buildMessages({ request: { prompt: testCase.userInput }, context: { history: [] }, databaseName: testCase.databaseName, @@ -329,8 +330,8 @@ async function runTest({ testCase: TestCase; aiBackend: AIBackend; fixtures: Fixtures; -}) { - const messages = buildMessages({ +}): Promise { + const messages = await buildMessages({ testCase, fixtures, }); diff --git a/src/test/ai-accuracy-tests/ai-backend.ts b/src/test/ai-accuracy-tests/ai-backend.ts index a9b746e1f..02534f580 100644 --- a/src/test/ai-accuracy-tests/ai-backend.ts +++ b/src/test/ai-accuracy-tests/ai-backend.ts @@ -4,7 +4,7 @@ import type { ChatCompletionCreateParamsBase } from 'openai/resources/chat/compl import { CHAT_PARTICIPANT_MODEL } from '../../participant/constants'; let openai: OpenAI; -function getOpenAIClient() { +function getOpenAIClient(): OpenAI { if (!openai) { openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY, @@ -22,7 +22,7 @@ type ChatMessage = { }; type ChatMessages = ChatMessage[]; -type ChatCompletion = { +export type ChatCompletion = { content: string; usageStats: { promptTokens: number; diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 137ba7851..761524a44 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -3,6 +3,7 @@ import { beforeEach, afterEach } from 'mocha'; import { expect } from 'chai'; import sinon from 'sinon'; import type { DataService } from 'mongodb-data-service'; +import { ObjectId, Int32 } from 'bson'; import ParticipantController, { parseForDatabaseAndCollectionName, @@ -22,6 +23,10 @@ import { } from '../../../storage/storageController'; import type { LoadedConnection } from 
'../../../storage/connectionStorage'; +// The Copilot model is not available in tests, +// therefore we need to mock its methods and return values. +export const MAX_TOTAL_PROMPT_LENGTH = 16000; + const loadedConnection = { id: 'id', name: 'localhost', @@ -44,6 +49,7 @@ suite('Participant Controller Test Suite', function () { let chatContextStub; let chatStreamStub; let chatTokenStub; + let countTokensStub; let sendRequestStub; beforeEach(function () { @@ -76,8 +82,9 @@ suite('Participant Controller Test Suite', function () { button: sinon.fake(), }; chatTokenStub = { - onCancellationRequested: () => {}, + onCancellationRequested: sinon.fake(), }; + countTokensStub = sinon.stub(); // The model returned by vscode.lm.selectChatModels is always undefined in tests. sendRequestStub = sinon.fake.resolves({ text: [ @@ -97,8 +104,8 @@ suite('Participant Controller Test Suite', function () { family: 'gpt-4o', version: 'gpt-4o-date', name: 'GPT 4o (date)', - maxInputTokens: 16211, - countTokens: () => {}, + maxInputTokens: MAX_TOTAL_PROMPT_LENGTH, + countTokens: countTokensStub, sendRequest: sendRequestStub, }, ]) @@ -293,7 +300,10 @@ suite('Participant Controller Test Suite', function () { }); suite('when connected', function () { + let sampleStub; + beforeEach(function () { + sampleStub = sinon.stub(); sinon.replace( testParticipantController._connectionController, 'getActiveDataService', @@ -331,13 +341,7 @@ suite('Participant Controller Test Suite', function () { url: TEST_DATABASE_URI, options: {}, }), - sample: () => - Promise.resolve([ - { - _id: '66b3408a60da951fc354743e', - field: { subField: '66b3408a60da951fc354743e' }, - }, - ]), + sample: sampleStub, once: sinon.stub(), } as unknown as DataService) ); @@ -449,6 +453,16 @@ suite('Participant Controller Test Suite', function () { sinon .stub(testParticipantController, '_queryGenerationState') .value(QUERY_GENERATION_STATE.FETCH_SCHEMA); + sampleStub.resolves([ + { + _id: new 
ObjectId('63ed1d522d8573fa5c203660'), + field: { + stringField: + 'There was a house cat who finally got the chance to do what it had always wanted to do.', + arrayField: [new Int32('1')], + }, + }, + ]); const chatRequestMock = { prompt: 'find all docs by a name example', command: 'query', @@ -462,11 +476,248 @@ suite('Participant Controller Test Suite', function () { ); const messages = sendRequestStub.firstCall.args[0]; expect(messages[0].content).to.include( - 'Collection schema:\n' + - '_id: String\n' + - 'field.subField: String\n' + 'Collection schema: _id: ObjectId\n' + + 'field.stringField: String\n' + + 'field.arrayField: Array\n' ); }); + + suite('useSampleDocsInCopilot setting is true', function () { + beforeEach(async () => { + await vscode.workspace + .getConfiguration('mdb') + .update('useSampleDocsInCopilot', true); + }); + + afterEach(async () => { + await vscode.workspace + .getConfiguration('mdb') + .update('useSampleDocsInCopilot', false); + }); + + test('includes 3 sample documents as an array', async function () { + sinon + .stub(testParticipantController, '_queryGenerationState') + .value(QUERY_GENERATION_STATE.FETCH_SCHEMA); + countTokensStub.resolves(MAX_TOTAL_PROMPT_LENGTH); + sampleStub.resolves([ + { + _id: new ObjectId('63ed1d522d8573fa5c203661'), + field: { + stringField: 'Text 1', + }, + }, + { + _id: new ObjectId('63ed1d522d8573fa5c203662'), + field: { + stringField: 'Text 2', + }, + }, + { + _id: new ObjectId('63ed1d522d8573fa5c203663'), + field: { + stringField: 'Text 3', + }, + }, + ]); + const chatRequestMock = { + prompt: 'find all docs by a name example', + command: 'query', + references: [], + }; + await testParticipantController.chatHandler( + chatRequestMock, + chatContextStub, + chatStreamStub, + chatTokenStub + ); + const messages = sendRequestStub.firstCall.args[0]; + expect(messages[0].content).to.include( + 'Sample documents: [\n' + + ' {\n' + + " _id: ObjectId('63ed1d522d8573fa5c203661'),\n" + + ' field: {\n' + + " 
stringField: 'Text 1'\n" + + ' }\n' + + ' },\n' + + ' {\n' + + " _id: ObjectId('63ed1d522d8573fa5c203662'),\n" + + ' field: {\n' + + " stringField: 'Text 2'\n" + + ' }\n' + + ' },\n' + + ' {\n' + + " _id: ObjectId('63ed1d522d8573fa5c203663'),\n" + + ' field: {\n' + + " stringField: 'Text 3'\n" + + ' }\n' + + ' }\n' + + ']\n' + ); + }); + + test('includes 1 sample document as an object', async function () { + sinon + .stub(testParticipantController, '_queryGenerationState') + .value(QUERY_GENERATION_STATE.FETCH_SCHEMA); + countTokensStub.resolves(MAX_TOTAL_PROMPT_LENGTH); + sampleStub.resolves([ + { + _id: new ObjectId('63ed1d522d8573fa5c203660'), + field: { + stringField: + 'There was a house cat who finally got the chance to do what it had always wanted to do.', + arrayField: [ + new Int32('1'), + new Int32('2'), + new Int32('3'), + new Int32('4'), + new Int32('5'), + new Int32('6'), + new Int32('7'), + new Int32('8'), + new Int32('9'), + ], + }, + }, + ]); + const chatRequestMock = { + prompt: 'find all docs by a name example', + command: 'query', + references: [], + }; + await testParticipantController.chatHandler( + chatRequestMock, + chatContextStub, + chatStreamStub, + chatTokenStub + ); + const messages = sendRequestStub.firstCall.args[0]; + expect(messages[0].content).to.include( + 'Sample document: {\n' + + " _id: ObjectId('63ed1d522d8573fa5c203660'),\n" + + ' field: {\n' + + " stringField: 'There was a house ca',\n" + + ' arrayField: [\n' + + " NumberInt('1'),\n" + + " NumberInt('2'),\n" + + " NumberInt('3')\n" + + ' ]\n' + + ' }\n' + + '}\n' + ); + }); + + test('includes 1 sample documents when 3 make prompt too long', async function () { + sinon + .stub(testParticipantController, '_queryGenerationState') + .value(QUERY_GENERATION_STATE.FETCH_SCHEMA); + countTokensStub.onCall(0).resolves(MAX_TOTAL_PROMPT_LENGTH + 1); + countTokensStub.onCall(1).resolves(MAX_TOTAL_PROMPT_LENGTH); + sampleStub.resolves([ + { + _id: new ObjectId('63ed1d522d8573fa5c203661'), 
+ field: { + stringField: 'Text 1', + }, + }, + { + _id: new ObjectId('63ed1d522d8573fa5c203662'), + field: { + stringField: 'Text 2', + }, + }, + { + _id: new ObjectId('63ed1d522d8573fa5c203663'), + field: { + stringField: 'Text 3', + }, + }, + ]); + const chatRequestMock = { + prompt: 'find all docs by a name example', + command: 'query', + references: [], + }; + await testParticipantController.chatHandler( + chatRequestMock, + chatContextStub, + chatStreamStub, + chatTokenStub + ); + const messages = sendRequestStub.firstCall.args[0]; + expect(messages[0].content).to.include( + 'Sample document: {\n' + + " _id: ObjectId('63ed1d522d8573fa5c203661'),\n" + + ' field: {\n' + + " stringField: 'Text 1'\n" + + ' }\n' + + '}\n' + ); + }); + + test('does not include sample documents when even 1 makes prompt too long', async function () { + sinon + .stub(testParticipantController, '_queryGenerationState') + .value(QUERY_GENERATION_STATE.FETCH_SCHEMA); + countTokensStub.onCall(0).resolves(MAX_TOTAL_PROMPT_LENGTH + 1); + countTokensStub.onCall(1).resolves(MAX_TOTAL_PROMPT_LENGTH + 1); + sampleStub.resolves([ + { + _id: new ObjectId('63ed1d522d8573fa5c203661'), + field: { + stringField: 'Text 1', + }, + }, + { + _id: new ObjectId('63ed1d522d8573fa5c203662'), + field: { + stringField: 'Text 2', + }, + }, + { + _id: new ObjectId('63ed1d522d8573fa5c203663'), + field: { + stringField: 'Text 3', + }, + }, + ]); + const chatRequestMock = { + prompt: 'find all docs by a name example', + command: 'query', + references: [], + }; + await testParticipantController.chatHandler( + chatRequestMock, + chatContextStub, + chatStreamStub, + chatTokenStub + ); + const messages = sendRequestStub.firstCall.args[0]; + expect(messages[0].content).to.not.include('Sample documents'); + }); + }); + + suite('useSampleDocsInCopilot setting is false', function () { + test('does not include sample documents', async function () { + sinon + .stub(testParticipantController, '_queryGenerationState') + 
.value(QUERY_GENERATION_STATE.FETCH_SCHEMA); + const chatRequestMock = { + prompt: 'find all docs by a name example', + command: 'query', + references: [], + }; + await testParticipantController.chatHandler( + chatRequestMock, + chatContextStub, + chatStreamStub, + chatTokenStub + ); + const messages = sendRequestStub.firstCall.args[0]; + expect(messages[0].content).to.not.include('Sample documents'); + }); + }); }); suite('unknown namespace', function () {