diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 7d4991d52..01d4c1742 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -1,4 +1,5 @@ import * as vscode from 'vscode'; +import { getSimplifiedSchema } from 'mongodb-schema'; import { createLogger } from '../logging'; import type ConnectionController from '../connectionController'; @@ -10,6 +11,7 @@ import { GenericPrompt } from './prompts/generic'; import { CHAT_PARTICIPANT_ID, CHAT_PARTICIPANT_MODEL } from './constants'; import { QueryPrompt } from './prompts/query'; import { NamespacePrompt } from './prompts/namespace'; +import { SchemaFormatter } from './schema'; const log = createLogger('participant'); @@ -18,9 +20,11 @@ export enum QUERY_GENERATION_STATE { ASK_TO_CONNECT = 'ASK_TO_CONNECT', ASK_FOR_DATABASE_NAME = 'ASK_FOR_DATABASE_NAME', ASK_FOR_COLLECTION_NAME = 'ASK_FOR_COLLECTION_NAME', - READY_TO_GENERATE_QUERY = 'READY_TO_GENERATE_QUERY', + FETCH_SCHEMA = 'FETCH_SCHEMA', } +const NUM_DOCUMENTS_TO_SAMPLE = 4; + interface ChatResult extends vscode.ChatResult { metadata: { responseContent?: string; @@ -67,6 +71,7 @@ export default class ParticipantController { _chatResult?: ChatResult; _databaseName?: string; _collectionName?: string; + _schema?: string; constructor({ connectionController, @@ -79,6 +84,26 @@ export default class ParticipantController { this._storageController = storageController; } + /** Stores `name` as the current database name. If the query generation state is DEFAULT and `name` differs from the stored database name, the state is moved to FETCH_SCHEMA before the name is stored. */ _setDatabaseName(name: string | undefined) { + if ( + this._queryGenerationState === QUERY_GENERATION_STATE.DEFAULT && + this._databaseName !== name + ) { + this._queryGenerationState = QUERY_GENERATION_STATE.FETCH_SCHEMA; + } + this._databaseName = name; + } + + /** Stores `name` as the current collection name. If the query generation state is DEFAULT and `name` differs from the stored collection name, the state is moved to FETCH_SCHEMA before the name is stored. */ _setCollectionName(name: string | undefined) { + if ( + this._queryGenerationState === QUERY_GENERATION_STATE.DEFAULT && + this._collectionName !== name + ) { + this._queryGenerationState = QUERY_GENERATION_STATE.FETCH_SCHEMA; + } + this._collectionName = name; + } + 
createParticipant(context: vscode.ExtensionContext) { // Chat participants appear as top-level options in the chat input // when you type `@`, and can contribute sub-commands in the chat input @@ -318,9 +343,10 @@ export default class ParticipantController { async selectDatabaseWithParticipant(name: string): Promise { if (!name) { - this._databaseName = await this._selectDatabaseWithCommandPalette(); + const selectedName = await this._selectDatabaseWithCommandPalette(); + this._setDatabaseName(selectedName); } else { - this._databaseName = name; + this._setDatabaseName(name); } return vscode.commands.executeCommand('workbench.action.chat.open', { @@ -363,9 +389,10 @@ export default class ParticipantController { async selectCollectionWithParticipant(name: string): Promise { if (!name) { - this._collectionName = await this._selectCollectionWithCommandPalette(); + const selectedName = await this._selectCollectionWithCommandPalette(); + this._setCollectionName(selectedName); } else { - this._collectionName = name; + this._setCollectionName(name); } return vscode.commands.executeCommand('workbench.action.chat.open', { @@ -448,8 +475,8 @@ export default class ParticipantController { if (isNewChat) { this._queryGenerationState = QUERY_GENERATION_STATE.DEFAULT; this._chatResult = undefined; - this._databaseName = undefined; - this._collectionName = undefined; + this._setDatabaseName(undefined); + this._setCollectionName(undefined); } } @@ -468,7 +495,7 @@ export default class ParticipantController { this._queryGenerationState === QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME ) { - this._databaseName = prompt; + this._setDatabaseName(prompt); if (!this._collectionName) { this._queryGenerationState = QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME; @@ -481,14 +508,13 @@ export default class ParticipantController { this._queryGenerationState === QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME ) { - this._collectionName = prompt; + this._setCollectionName(prompt); if 
/**
 * Reports whether a schema fetch is pending, i.e. the controller is in the
 * `FETCH_SCHEMA` query generation state.
 */
_shouldFetchCollectionSchema(): boolean {
  return this._queryGenerationState === QUERY_GENERATION_STATE.FETCH_SCHEMA;
}

/**
 * Samples up to `NUM_DOCUMENTS_TO_SAMPLE` documents from the current
 * `<database>.<collection>` namespace, derives a simplified schema from them
 * and stores its formatted text in `_schema`.
 *
 * The `FETCH_SCHEMA` state is consumed (reset to `DEFAULT`) on entry, so the
 * fetch is attempted once per state transition regardless of the outcome.
 * `_schema` is left `undefined` whenever no schema could be produced: no
 * active data service, incomplete namespace, or a failed/aborted sample.
 *
 * @param abortSignal Forwarded to the sample call so it can be cancelled.
 */
async _fetchCollectionSchema(abortSignal?: AbortSignal): Promise<void> {
  if (this._queryGenerationState === QUERY_GENERATION_STATE.FETCH_SCHEMA) {
    this._queryGenerationState = QUERY_GENERATION_STATE.DEFAULT;
  }

  // Clear the previous result first: without this, the early return below
  // would keep a schema that belongs to a previously selected namespace and
  // it would be sent along with the next query prompt.
  this._schema = undefined;

  const dataService = this._connectionController.getActiveDataService();
  if (!dataService || !this._databaseName || !this._collectionName) {
    return;
  }

  try {
    // `sample` is called optionally because the data service object may not
    // provide it; in that case an empty sample is used.
    const sampleDocuments =
      (await dataService.sample?.(
        `${this._databaseName}.${this._collectionName}`,
        {
          query: {},
          size: NUM_DOCUMENTS_TO_SAMPLE,
        },
        { promoteValues: false },
        {
          abortSignal,
        }
      )) ?? [];

    const schema = await getSimplifiedSchema(sampleDocuments);
    this._schema = new SchemaFormatter().format(schema);
  } catch (err: unknown) {
    // Schema is best-effort context for the prompt: record the failure and
    // continue without it rather than failing the whole chat request.
    log.error('Unable to fetch the collection schema', err);
    this._schema = undefined;
  }
}

// @MongoDB /query find all documents where the "address" has the word Broadway in it.
async handleQueryRequest( request: vscode.ChatRequest, @@ -621,13 +681,17 @@ export default class ParticipantController { abortController.abort(); }); + if (this._shouldFetchCollectionSchema()) { + await this._fetchCollectionSchema(abortController.signal); + } + const messages = QueryPrompt.buildMessages({ request, context, databaseName: this._databaseName, collectionName: this._collectionName, + schema: this._schema, }); - const responseContent = await this.getChatResponseContent({ messages, stream, diff --git a/src/participant/prompts/query.ts b/src/participant/prompts/query.ts index 09e2fc8ba..58037a58b 100644 --- a/src/participant/prompts/query.ts +++ b/src/participant/prompts/query.ts @@ -6,9 +6,11 @@ export class QueryPrompt { static getAssistantPrompt({ databaseName = 'mongodbVSCodeCopilotDB', collectionName = 'test', + schema, }: { databaseName?: string; collectionName?: string; + schema?: string; }): vscode.LanguageModelChatMessage { const prompt = `You are a MongoDB expert. @@ -38,6 +40,12 @@ db.getCollection('').find({ Database name: ${databaseName} Collection name: ${collectionName} +${ + schema + ? 
/**
 * One inferred type of a field, as read by {@link SchemaFormatter}.
 *
 * Only the properties the formatter actually reads are declared, so values
 * produced by a schema analyzer with the same shape are accepted structurally.
 */
export interface FormattableSchemaType {
  /** Type name that is written to the output, e.g. `String`, `Document`, `Array`. */
  bsonType: string;
  /** Nested fields; read when `bsonType` is `Document`. */
  fields?: FormattableSchema;
  /** Element types; read when `bsonType` is `Array`. */
  types?: FormattableSchemaType[];
}

/** A field entry: the list of types observed for that field. */
export interface FormattableSchemaField {
  types?: FormattableSchemaType[];
}

/** A document-level schema: field name to field entry. */
export type FormattableSchema = Record<string, FormattableSchemaField>;

/** Keys matching this pattern are emitted as-is; anything else is JSON-quoted. */
const PROPERTY_REGEX = /^[a-zA-Z_$][0-9a-zA-Z_$]*$/;

/**
 * Renders a simplified schema as plain text, one `path: Type` entry per line.
 *
 * Nested documents are flattened into dotted paths (`field.subField: String`),
 * arrays of primitives are written as `Array<Type>`, and arrays of documents
 * or arrays get a `[]` suffix on the path. When a field or array has several
 * types, only the first one is used.
 */
export class SchemaFormatter {
  /** Convenience wrapper: formats `pInput` with a fresh formatter. */
  static getSchemaFromTypes(pInput: FormattableSchema): string {
    return new SchemaFormatter().format(pInput);
  }

  /** Output accumulated by the current `format` call. */
  schemaString = '';

  /**
   * Formats `pInitial` and returns the resulting text.
   *
   * The accumulator is reset on every call so an instance can be reused
   * without the previous output leaking into the next result.
   */
  format(pInitial: FormattableSchema): string {
    this.schemaString = '';
    this.processDocumentType('', pInitial);
    return this.schemaString;
  }

  /** Emits the entry for a field from the first of its types, if it has any. */
  private processSchemaTypeList(
    prefix: string,
    pTypes: FormattableSchemaType[] | undefined
  ): void {
    // A field entry without a type list carries nothing to print.
    const firstType = pTypes?.[0];
    if (firstType !== undefined) {
      this.processSchemaType(prefix, firstType);
    }
  }

  /** Emits the entry (or nested entries) for a single type at `prefix`. */
  private processSchemaType(
    prefix: string,
    pType: FormattableSchemaType
  ): void {
    const { bsonType } = pType;

    if (bsonType === 'Document') {
      // A document type without a `fields` map is treated like an empty one.
      const fields = pType.fields ?? {};
      if (Object.keys(fields).length === 0) {
        this.addToFormattedSchemaString(`${prefix}: Document`);
        return;
      }
      this.processDocumentType(prefix, fields);
      return;
    }

    if (bsonType === 'Array') {
      // Only the first element type is used.
      const firstType = pType.types?.[0];
      if (firstType === undefined) {
        this.addToFormattedSchemaString(`${prefix}: Array`);
        return;
      }

      const elementType = firstType.bsonType;
      if (elementType !== 'Array' && elementType !== 'Document') {
        this.addToFormattedSchemaString(`${prefix}: Array<${elementType}>`);
        return;
      }

      // Array of documents or arrays: mark the path and descend.
      this.processSchemaType(`${prefix}[]`, firstType);
      return;
    }

    this.addToFormattedSchemaString(`${prefix}: ${bsonType}`);
  }

  /** Walks every field of `pDoc`, extending `prefix` with the field name. */
  private processDocumentType(
    prefix: string,
    pDoc: FormattableSchema | undefined
  ): void {
    if (!pDoc) {
      return;
    }

    for (const [key, field] of Object.entries(pDoc)) {
      const separator = prefix.length === 0 ? '' : '.';
      const path = prefix + separator + this.getPropAsString(key);
      this.processSchemaTypeList(path, field?.types);
    }
  }

  /**
   * Returns `pProp` unchanged when it is a plain identifier-like key, and its
   * JSON string form (double-quoted, escaped) otherwise.
   */
  getPropAsString(pProp: string): string {
    return PROPERTY_REGEX.test(pProp) ? pProp : JSON.stringify(pProp);
  }

  /** Appends one entry to the output, newline-separated from the previous one. */
  addToFormattedSchemaString(fieldAndType: string): void {
    if (this.schemaString.length > 0) {
      this.schemaString += '\n';
    }
    this.schemaString += fieldAndType;
  }
}
+ sendRequestStub = sinon.fake.resolves({ + text: [ + '```javascript\n' + + "use('dbOne');\n" + + "db.getCollection('collOne').find({ name: 'example' });\n" + + '```', + ], + }); sinon.replace( vscode.lm, 'selectChatModels', @@ -90,15 +99,7 @@ suite('Participant Controller Test Suite', function () { name: 'GPT 4o (date)', maxInputTokens: 16211, countTokens: () => {}, - sendRequest: () => - Promise.resolve({ - text: [ - '```javascript\n' + - "use('dbOne');\n" + - "db.getCollection('collOne').find({ name: 'example' });\n" + - '```', - ], - }), + sendRequest: sendRequestStub, }, ]) ); @@ -330,6 +331,13 @@ suite('Participant Controller Test Suite', function () { url: TEST_DATABASE_URI, options: {}, }), + sample: () => + Promise.resolve([ + { + _id: '66b3408a60da951fc354743e', + field: { subField: '66b3408a60da951fc354743e' }, + }, + ]), once: sinon.stub(), } as unknown as DataService) ); @@ -436,6 +444,29 @@ suite('Participant Controller Test Suite', function () { "db.getCollection('collOne').find({ name: 'example' });" ); }); + + test('includes a collection schema', async function () { + sinon + .stub(testParticipantController, '_queryGenerationState') + .value(QUERY_GENERATION_STATE.FETCH_SCHEMA); + const chatRequestMock = { + prompt: 'find all docs by a name example', + command: 'query', + references: [], + }; + await testParticipantController.chatHandler( + chatRequestMock, + chatContextStub, + chatStreamStub, + chatTokenStub + ); + const messages = sendRequestStub.firstCall.args[0]; + expect(messages[0].content).to.include( + 'Collection schema:\n' + + '_id: String\n' + + 'field.subField: String\n' + ); + }); }); suite('unknown namespace', function () {