diff --git a/src/participant/chatMetadata.ts b/src/participant/chatMetadata.ts index 1462c4ae4..1daade9bb 100644 --- a/src/participant/chatMetadata.ts +++ b/src/participant/chatMetadata.ts @@ -20,6 +20,20 @@ export class ChatMetadataStore { return this._chats[chatId]; } + patchChatMetadata( + context: vscode.ChatContext, + patchedMetadata: Partial + ): void { + const chatId = ChatMetadataStore.getChatIdFromHistoryOrNewChatId( + context.history + ); + + this.setChatMetadata(chatId, { + ...this.getChatMetadata(chatId), + ...patchedMetadata, + }); + } + // Exposed for stubbing in tests. static createNewChatId(): string { return uuidv4(); diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 19f7457d1..b0254c16f 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -43,7 +43,8 @@ import formatError from '../utils/formatError'; import type { ModelInput } from './prompts/promptBase'; import { processStreamWithIdentifiers } from './streamParsing'; import type { PromptIntent } from './prompts/intent'; -import { ParticipantErrorTypes } from '../test/suite/participant/participantErrorTypes'; +import type { DataService } from 'mongodb-data-service'; +import { ParticipantErrorTypes } from './participantErrorTypes'; const log = createLogger('participant'); @@ -612,123 +613,95 @@ export default class ParticipantController { ) as Promise; } - async renderDatabasesTree({ + renderDatabasesTree({ command, context, stream, + databases, }: { command: ParticipantCommand; context: vscode.ChatContext; stream: vscode.ChatResponseStream; - }): Promise { - const dataService = this._connectionController.getActiveDataService(); - if (!dataService) { - return; - } - - stream.push( - new vscode.ChatResponseProgressPart('Fetching database names...') + databases: { + _id: string; + name: string; + }[]; + }): void { + databases.slice(0, MAX_MARKDOWN_LIST_LENGTH).forEach((db) => + stream.markdown( + createMarkdownLink({ + commandId: EXTENSION_COMMANDS.SELECT_DATABASE_WITH_PARTICIPANT, + data: { + command, + chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( + context.history + ), + databaseName: db.name, + }, + name: db.name, + }) + ) ); - try { - const databases = await dataService.listDatabases({ - nameOnly: true, - }); - databases.slice(0, MAX_MARKDOWN_LIST_LENGTH).forEach((db) => - stream.markdown( - createMarkdownLink({ - commandId: EXTENSION_COMMANDS.SELECT_DATABASE_WITH_PARTICIPANT, - data: { - command, - chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( - context.history - ), - databaseName: db.name, - }, - name: db.name, - }) - ) + if (databases.length > MAX_MARKDOWN_LIST_LENGTH) { + stream.markdown( + createMarkdownLink({ + data: { + command, + chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( + context.history + ), + }, + commandId: EXTENSION_COMMANDS.SELECT_DATABASE_WITH_PARTICIPANT, + name: 'Show more', + }) ); - if (databases.length > MAX_MARKDOWN_LIST_LENGTH) { - stream.markdown( - createMarkdownLink({ - data: { - command, - chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( - context.history - ), - }, - commandId: EXTENSION_COMMANDS.SELECT_DATABASE_WITH_PARTICIPANT, - name: 'Show more', - }) - ); - } - } catch (error) { - log.error('Unable to fetch databases:', error); - - // Users can always do this manually when asked to provide a database name. - return; } } - async renderCollectionsTree({ + renderCollectionsTree({ + collections, command, context, databaseName, stream, }: { + collections: Awaited>; command: ParticipantCommand; databaseName: string; context: vscode.ChatContext; stream: vscode.ChatResponseStream; - }): Promise { - const dataService = this._connectionController.getActiveDataService(); - if (!dataService) { - return; - } - - stream.push( - new vscode.ChatResponseProgressPart('Fetching collection names...') + }): void { + collections.slice(0, MAX_MARKDOWN_LIST_LENGTH).forEach((coll) => + stream.markdown( + createMarkdownLink({ + commandId: EXTENSION_COMMANDS.SELECT_COLLECTION_WITH_PARTICIPANT, + data: { + command, + chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( + context.history + ), + databaseName, + collectionName: coll.name, + }, + name: coll.name, + }) + ) ); - - try { - const collections = await dataService.listCollections(databaseName); - collections.slice(0, MAX_MARKDOWN_LIST_LENGTH).forEach((coll) => - stream.markdown( - createMarkdownLink({ - commandId: EXTENSION_COMMANDS.SELECT_COLLECTION_WITH_PARTICIPANT, - data: { - command, - chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( - context.history - ), - databaseName, - collectionName: coll.name, - }, - name: coll.name, - }) - ) + if (collections.length > MAX_MARKDOWN_LIST_LENGTH) { + stream.markdown( + createMarkdownLink({ + commandId: EXTENSION_COMMANDS.SELECT_COLLECTION_WITH_PARTICIPANT, + data: { + command, + chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( + context.history + ), + databaseName, + }, + name: 'Show more', + }) ); - if (collections.length > MAX_MARKDOWN_LIST_LENGTH) { - stream.markdown( - createMarkdownLink({ - commandId: EXTENSION_COMMANDS.SELECT_COLLECTION_WITH_PARTICIPANT, - data: { - command, - chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( - context.history - ), - databaseName, - }, - name: 'Show more', - }) - ); - } - } catch (error) { - log.error('Unable to fetch collections:', error); - - // Users can always do this manually when asked to provide a collection name. - return; } } @@ -811,50 +784,206 @@ export default class ParticipantController { }; } - async _askForNamespace({ - command, + async _getDatabases({ + stream, + }: { + stream: vscode.ChatResponseStream; + }): Promise< + | { + _id: string; + name: string; + }[] + | undefined + > { + const dataService = this._connectionController.getActiveDataService(); + if (!dataService) { + return undefined; + } + + stream.push( + new vscode.ChatResponseProgressPart('Fetching database names...') + ); + + try { + return await dataService.listDatabases({ + nameOnly: true, + }); + } catch (error) { + log.error('Unable to fetch databases:', error); + + return undefined; + } + } + + async _getCollections({ + stream, + databaseName, + }: { + stream: vscode.ChatResponseStream; + databaseName: string; + }): Promise | undefined> { + const dataService = this._connectionController.getActiveDataService(); + + if (!dataService) { + return undefined; + } + + stream.push( + new vscode.ChatResponseProgressPart('Fetching collection names...') + ); + + try { + return await dataService.listCollections(databaseName); + } catch (error) { + log.error('Unable to fetch collections:', error); + + return undefined; + } + } + + /** Gets the collection name if there is only one collection. + * Otherwise returns undefined and asks the user to select the collection. */ + async _getOrAskForCollectionName({ context, databaseName, - collectionName, stream, + command, }: { command: ParticipantCommand; context: vscode.ChatContext; - databaseName: string | undefined; - collectionName: string | undefined; + databaseName: string; stream: vscode.ChatResponseStream; - }): Promise { + }): Promise { + const collections = await this._getCollections({ stream, databaseName }); + + if (collections === undefined) { + log.error('No collections found'); + return undefined; + } + if (collections.length === 1) { + return collections[0].name; + } + + stream.markdown( + vscode.l10n.t( + `Which collection would you like to use within ${databaseName}? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n` + ) + ); + + this.renderCollectionsTree({ + collections, + command, + databaseName, + context, + stream, + }); + + return undefined; + } + + /** Gets the database name if there is only one collection. + * Otherwise returns undefined and asks the user to select the database. */ + async _getOrAskForDatabaseName({ + command, + context, + stream, + }: { + command: ParticipantCommand; + context: vscode.ChatContext; + stream: vscode.ChatResponseStream; + }): Promise { + const databases = await this._getDatabases({ stream }); + + if (databases === undefined || databases.length === 0) { + log.error('No databases found'); + return undefined; + } + + if (databases.length === 1) { + return databases[0].name; + } + // If no database or collection name is found in the user prompt, // we retrieve the available namespaces from the current connection. - // Users can then select a value by clicking on an item in the list. + // Users can then select a value by clicking on an item in the list + // or typing the name manually. + stream.markdown( + `Which database would you like ${ + command === '/query' ? 'this query to run against' : 'to use' + }? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n` + ); + + this.renderDatabasesTree({ + databases, + command, + context, + stream, + }); + + return undefined; + } + + /** Helper which either automatically picks and returns missing parts of the namespace (if any) + * or prompts the user to pick the missing namespace. + */ + async _getOrAskForMissingNamespace({ + databaseName, + collectionName, + context, + stream, + command, + }: { + databaseName: string | undefined; + collectionName: string | undefined; + context: vscode.ChatContext; + stream: vscode.ChatResponseStream; + command: ParticipantCommand; + }): Promise<{ + databaseName: string | undefined; + collectionName: string | undefined; + }> { if (!databaseName) { - stream.markdown( - `What is the name of the database you would like${ - command === '/query' ? ' this query' : '' - } to run against?\n\n` - ); - await this.renderDatabasesTree({ + databaseName = await this._getOrAskForDatabaseName({ command, context, stream, }); - } else if (!collectionName) { - stream.markdown( - `Which collection would you like to use within ${databaseName}?\n\n` - ); - await this.renderCollectionsTree({ - command, + + // databaseName will be undefined if it cannot be found from + // the metadata or history, in which case the user will be prompted + // to select it or if some error occurs. + if (!databaseName) { + return { databaseName, collectionName }; + } + + this._chatMetadataStore.patchChatMetadata(context, { databaseName, + }); + } + + if (!collectionName) { + collectionName = await this._getOrAskForCollectionName({ + command, context, + databaseName, stream, }); + + // If the collection name could not get automatically selected, + // then the user has been prompted for it instead. + if (!collectionName) { + return { + databaseName, + collectionName, + }; + } + + this._chatMetadataStore.patchChatMetadata(context, { + collectionName, + }); } - return namespaceRequestChatResult({ - databaseName, - collectionName, - history: context.history, - }); + return { collectionName, databaseName }; } _doesLastMessageAskForNamespace( @@ -1009,40 +1138,28 @@ export default class ParticipantController { // When the last message was asking for a database or collection name, // we re-ask the question. - const databaseName = lastMessage.metadata.databaseName; - if (databaseName) { - stream.markdown( - vscode.l10n.t( - 'Please select a collection by either clicking on an item in the list or typing the name manually in the chat.' - ) - ); - await this.renderCollectionsTree({ - command, - databaseName, - context, - stream, - }); - } else { - stream.markdown( - vscode.l10n.t( - 'Please select a database by either clicking on an item in the list or typing the name manually in the chat.' - ) - ); - await this.renderDatabasesTree({ + const metadataDatabaseName = lastMessage.metadata.databaseName; + + // This will prompt the user for the missing databaseName or the collectionName. + // If anything in the namespace can be automatically picked, it will be returned. + const { databaseName, collectionName } = + await this._getOrAskForMissingNamespace({ command, context, stream, + databaseName: metadataDatabaseName, + collectionName: undefined, }); - } return namespaceRequestChatResult({ databaseName, - collectionName: undefined, + collectionName, history: context.history, }); } // @MongoDB /schema + // eslint-disable-next-line complexity async handleSchemaRequest( request: vscode.ChatRequest, context: vscode.ChatContext, @@ -1068,19 +1185,26 @@ export default class ParticipantController { }); } - const { databaseName, collectionName } = await this._getNamespaceFromChat({ + const namespace = await this._getNamespaceFromChat({ request, context, token, }); - - if (!databaseName || !collectionName) { - return await this._askForNamespace({ - command: '/schema', + const { databaseName, collectionName } = + await this._getOrAskForMissingNamespace({ + ...namespace, context, + stream, + command: '/schema', + }); + + // If either the database or collection name could not be automatically picked + // then the user has been prompted to select one manually or been presented with an error. + if (databaseName === undefined || collectionName === undefined) { + return namespaceRequestChatResult({ databaseName, collectionName, - stream, + history: context.history, }); } @@ -1194,18 +1318,26 @@ export default class ParticipantController { // First we ask the model to parse for the database and collection name. // If they exist, we can then use them in our final completion. // When they don't exist we ask the user for them. - const { databaseName, collectionName } = await this._getNamespaceFromChat({ + const namespace = await this._getNamespaceFromChat({ request, context, token, }); - if (!databaseName || !collectionName) { - return await this._askForNamespace({ - command: '/query', + const { databaseName, collectionName } = + await this._getOrAskForMissingNamespace({ + ...namespace, context, + stream, + command: '/query', + }); + + // If either the database or collection name could not be automatically picked + // then the user has been prompted to select one manually. + if (databaseName === undefined || collectionName === undefined) { + return namespaceRequestChatResult({ databaseName, collectionName, - stream, + history: context.history, }); } diff --git a/src/test/suite/participant/participantErrorTypes.ts b/src/participant/participantErrorTypes.ts similarity index 100% rename from src/test/suite/participant/participantErrorTypes.ts rename to src/participant/participantErrorTypes.ts diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 40c430eec..364284240 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -4,7 +4,7 @@ import type { InternalPromptPurpose, ParticipantPromptProperties, } from '../../telemetry/telemetryService'; -import { ParticipantErrorTypes } from '../../test/suite/participant/participantErrorTypes'; +import { ParticipantErrorTypes } from '../participantErrorTypes'; export interface PromptArgsBase { request: { diff --git a/src/telemetry/telemetryService.ts b/src/telemetry/telemetryService.ts index 8e0482606..8a58258dc 100644 --- a/src/telemetry/telemetryService.ts +++ b/src/telemetry/telemetryService.ts @@ -13,7 +13,7 @@ import type { NewConnectionTelemetryEventProperties } from './connectionTelemetr import type { ShellEvaluateResult } from '../types/playgroundType'; import type { StorageController } from '../storage'; import type { ParticipantResponseType } from '../participant/constants'; -import { ParticipantErrorTypes } from '../test/suite/participant/participantErrorTypes'; +import { ParticipantErrorTypes } from '../participant/participantErrorTypes'; const log = createLogger('telemetry'); // eslint-disable-next-line @typescript-eslint/no-var-requires diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 770e32ef9..8b5916ecf 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -34,7 +34,7 @@ import { Prompts } from '../../../participant/prompts'; import { createMarkdownLink } from '../../../participant/markdown'; import EXTENSION_COMMANDS from '../../../commands'; import { getContentLength } from '../../../participant/prompts/promptBase'; -import { ParticipantErrorTypes } from './participantErrorTypes'; +import { ParticipantErrorTypes } from '../../../participant/participantErrorTypes'; // The Copilot's model in not available in tests, // therefore we need to mock its methods and returning values. @@ -405,42 +405,49 @@ suite('Participant Controller Test Suite', function () { suite('when connected', function () { let sampleStub; + let listCollectionsStub; + let listDatabasesStub; beforeEach(function () { sampleStub = sinon.stub(); + listDatabasesStub = sinon + .stub() + .resolves([ + { name: 'dbOne' }, + { name: 'customer' }, + { name: 'inventory' }, + { name: 'sales' }, + { name: 'employee' }, + { name: 'financialReports' }, + { name: 'productCatalog' }, + { name: 'projectTracker' }, + { name: 'user' }, + { name: 'analytics' }, + { name: '123' }, + ]); + listCollectionsStub = sinon + .stub() + .resolves([ + { name: 'collOne' }, + { name: 'notifications' }, + { name: 'products' }, + { name: 'orders' }, + { name: 'categories' }, + { name: 'invoices' }, + { name: 'transactions' }, + { name: 'logs' }, + { name: 'messages' }, + { name: 'sessions' }, + { name: 'feedback' }, + ]); + sinon.replace( testParticipantController._connectionController, 'getActiveDataService', () => ({ - listDatabases: () => - Promise.resolve([ - { name: 'dbOne' }, - { name: 'customer' }, - { name: 'inventory' }, - { name: 'sales' }, - { name: 'employee' }, - { name: 'financialReports' }, - { name: 'productCatalog' }, - { name: 'projectTracker' }, - { name: 'user' }, - { name: 'analytics' }, - { name: '123' }, - ]), - listCollections: () => - Promise.resolve([ - { name: 'collOne' }, - { name: 'notifications' }, - { name: 'products' }, - { name: 'orders' }, - { name: 'categories' }, - { name: 'invoices' }, - { name: 'transactions' }, - { name: 'logs' }, - { name: 'messages' }, - { name: 'sessions' }, - { name: 'feedback' }, - ]), + listDatabases: listDatabasesStub, + listCollections: listCollectionsStub, getMongoClientConnectionOptions: () => ({ url: TEST_DATABASE_URI, options: {}, @@ -1016,7 +1023,7 @@ suite('Participant Controller Test Suite', function () { }); }); - suite('unknown namespace', function () { + suite('no namespace provided', function () { test('asks for a namespace and generates a query', async function () { const chatRequestMock = { prompt: 'find all docs by a name example', @@ -1026,7 +1033,7 @@ suite('Participant Controller Test Suite', function () { const chatResult = await invokeChatHandler(chatRequestMock); const askForDBMessage = chatStreamStub.markdown.getCall(0).args[0]; expect(askForDBMessage).to.include( - 'What is the name of the database you would like this query to run against?' + 'Which database would you like this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n' ); const listDBsMessage = chatStreamStub.markdown.getCall(1).args[0]; const expectedContent = encodeStringify({ @@ -1080,7 +1087,7 @@ suite('Participant Controller Test Suite', function () { { value: { value: - 'What is the name of the database you would like this query to run against?', + 'Which database would you like this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', } as vscode.MarkdownString, }, ], @@ -1101,7 +1108,7 @@ suite('Participant Controller Test Suite', function () { const askForCollMessage = chatStreamStub.markdown.getCall(12).args[0]; expect(askForCollMessage).to.include( - 'Which collection would you like to use within dbOne?' + 'Which collection would you like to use within dbOne? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n' ); const listCollsMessage = chatStreamStub.markdown.getCall(13).args[0]; @@ -1157,7 +1164,7 @@ suite('Participant Controller Test Suite', function () { { value: { value: - 'Which database would you like to query within this database?', + 'Which database would you like to this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', } as vscode.MarkdownString, }, ], @@ -1183,7 +1190,7 @@ suite('Participant Controller Test Suite', function () { { value: { value: - 'Which collection would you like to query within dbOne?', + 'Which collection would you like to query within dbOne? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', } as vscode.MarkdownString, }, ], @@ -1225,7 +1232,7 @@ suite('Participant Controller Test Suite', function () { }); }); - test('handles empty database name', async function () { + test('asks for the empty database name again if the last prompt was doing so', async function () { const chatRequestMock = { prompt: '', command: 'query', @@ -1247,7 +1254,7 @@ suite('Participant Controller Test Suite', function () { { value: { value: - 'What is the name of the database you would like this query to run against?', + 'Which database would you like this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', } as vscode.MarkdownString, }, ], @@ -1265,8 +1272,8 @@ suite('Participant Controller Test Suite', function () { const chatResult = await invokeChatHandler(chatRequestMock); const emptyMessage = chatStreamStub.markdown.getCall(0).args[0]; - expect(emptyMessage).to.include( - 'Please select a database by either clicking on an item in the list or typing the name manually in the chat.' + expect(emptyMessage).to.equal( + 'Which database would you like this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n' ); const listDBsMessage = chatStreamStub.markdown.getCall(1).args[0]; expect(listDBsMessage.value).to.include( @@ -1298,110 +1305,6 @@ suite('Participant Controller Test Suite', function () { chatId: undefined, }); }); - - test('handles empty collection name', async function () { - const chatRequestMock = { - prompt: '', - command: 'query', - references: [], - }; - chatContextStub = { - history: [ - Object.assign(Object.create(vscode.ChatRequestTurn.prototype), { - prompt: 'find all docs by a name example', - command: 'query', - references: [], - participant: CHAT_PARTICIPANT_ID, - }), - Object.assign( - Object.create(vscode.ChatResponseTurn.prototype), - { - participant: CHAT_PARTICIPANT_ID, - response: [ - { - value: { - value: - 'Which database would you like to query within this database?', - } as vscode.MarkdownString, - }, - ], - command: 'query', - result: { - metadata: { - intent: 'askForNamespace', - }, - }, - } - ), - Object.assign(Object.create(vscode.ChatRequestTurn.prototype), { - prompt: 'dbOne', - command: 'query', - references: [], - participant: CHAT_PARTICIPANT_ID, - }), - Object.assign( - Object.create(vscode.ChatResponseTurn.prototype), - { - participant: CHAT_PARTICIPANT_ID, - response: [ - { - value: { - value: - 'Which collection would you like to query within dbOne?', - } as vscode.MarkdownString, - }, - ], - command: 'query', - result: { - metadata: { - intent: 'askForNamespace', - databaseName: 'dbOne', - collectionName: undefined, - chatId: 'pineapple', - }, - }, - } - ), - ], - }; - const chatResult = await invokeChatHandler(chatRequestMock); - - const emptyMessage = chatStreamStub.markdown.getCall(0).args[0]; - expect(emptyMessage).to.include( - 'Please select a collection by either clicking on an item in the list or typing the name manually in the chat.' - ); - const listCollsMessage = chatStreamStub.markdown.getCall(1).args[0]; - expect(listCollsMessage.value).to.include( - `- [collOne](command:mdb.selectCollectionWithParticipant?${encodeStringify( - { - command: '/query', - chatId: 'pineapple', - databaseName: 'dbOne', - collectionName: 'collOne', - } - )})` - ); - const showMoreCollsMessage = - chatStreamStub.markdown.getCall(11).args[0]; - expect(showMoreCollsMessage.value).to.include( - `- [Show more](command:mdb.selectCollectionWithParticipant?${encodeStringify( - { - command: '/query', - chatId: 'pineapple', - databaseName: 'dbOne', - } - )})` - ); - expect({ - ...chatResult?.metadata, - chatId: undefined, - }).to.deep.equal({ - intent: 'askForNamespace', - collectionName: undefined, - databaseName: 'dbOne', - chatId: undefined, - }); - }); }); }); @@ -1435,7 +1338,7 @@ suite('Participant Controller Test Suite', function () { expect(sendRequestStub.called).to.be.false; const askForDBMessage = chatStreamStub.markdown.getCall(0).args[0]; expect(askForDBMessage).to.include( - 'What is the name of the database you would like to run against?' + 'Which database would you like to use? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n' ); }); @@ -1455,8 +1358,8 @@ suite('Participant Controller Test Suite', function () { ); const askForDBMessage = chatStreamStub.markdown.getCall(0).args[0]; - expect(askForDBMessage).to.include( - 'What is the name of the database you would like to run against?' + expect(askForDBMessage).to.equals( + 'Which database would you like to use? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n' ); }); @@ -1861,6 +1764,164 @@ Schema: }); }); }); + + suite('determining the namespace', function () { + ['query', 'schema'].forEach(function (command) { + suite(`${command} command`, function () { + beforeEach(function () { + sendRequestStub.resolves({ + text: ['determining the namespace'], + }); + }); + + suite('with an empty database name', function () { + beforeEach(function () { + sinon.replace( + testParticipantController._chatMetadataStore, + 'getChatMetadata', + () => ({ + databaseName: undefined, + collectionName: undefined, + }) + ); + }); + + afterEach(function () { + sinon.restore(); + }); + + test('database name gets picked automatically if there is only 1', async function () { + listDatabasesStub.resolves([{ name: 'onlyOneDb' }]); + + const renderDatabasesTreeSpy = sinon.spy( + testParticipantController, + 'renderDatabasesTree' + ); + const renderCollectionsTreeSpy = sinon.spy( + testParticipantController, + 'renderCollectionsTree' + ); + + const chatResult = await invokeChatHandler({ + prompt: 'what is this', + command, + references: [], + }); + + expect(renderDatabasesTreeSpy.called).to.be.false; + expect(renderCollectionsTreeSpy.calledOnce).to.be.true; + + expect(chatResult?.metadata).deep.equals({ + chatId: testChatId, + intent: 'askForNamespace', + databaseName: 'onlyOneDb', + collectionName: undefined, + }); + }); + + test('prompts for database name if there are multiple available', async function () { + const renderCollectionsTreeSpy = sinon.spy( + testParticipantController, + 'renderCollectionsTree' + ); + const renderDatabasesTreeSpy = sinon.spy( + testParticipantController, + 'renderDatabasesTree' + ); + + const chatResult = await invokeChatHandler({ + prompt: 'dbOne', + command, + references: [], + }); + + expect(renderDatabasesTreeSpy.calledOnce).to.be.true; + expect(renderCollectionsTreeSpy.called).to.be.false; + + expect(chatResult?.metadata).deep.equals({ + intent: 'askForNamespace', + chatId: testChatId, + databaseName: undefined, + collectionName: undefined, + }); + }); + }); + + suite('with an empty collection name', function () { + beforeEach(function () { + sinon.replace( + testParticipantController._chatMetadataStore, + 'getChatMetadata', + () => ({ + databaseName: 'dbOne', + collectionName: undefined, + }) + ); + }); + + test('collection name gets picked automatically if there is only 1', async function () { + listCollectionsStub.resolves([{ name: 'onlyOneColl' }]); + const renderCollectionsTreeSpy = sinon.spy( + testParticipantController, + 'renderCollectionsTree' + ); + const fetchCollectionSchemaAndSampleDocumentsSpy = sinon.spy( + testParticipantController, + '_fetchCollectionSchemaAndSampleDocuments' + ); + + const chatResult = await invokeChatHandler({ + prompt: 'dbOne', + command, + references: [], + }); + + expect(renderCollectionsTreeSpy.called).to.be.false; + + expect( + fetchCollectionSchemaAndSampleDocumentsSpy.firstCall.args[0] + ).to.include({ + collectionName: 'onlyOneColl', + }); + + expect(chatResult?.metadata).deep.equals({ + chatId: testChatId, + intent: command, + }); + }); + + test('prompts for collection name if there are multiple available', async function () { + const renderCollectionsTreeSpy = sinon.spy( + testParticipantController, + 'renderCollectionsTree' + ); + const fetchCollectionSchemaAndSampleDocumentsSpy = sinon.spy( + testParticipantController, + '_fetchCollectionSchemaAndSampleDocuments' + ); + + const chatResult = await invokeChatHandler({ + prompt: 'dbOne', + command, + references: [], + }); + + expect(renderCollectionsTreeSpy.calledOnce).to.be.true; + expect( + fetchCollectionSchemaAndSampleDocumentsSpy.called + ).to.be.false; + + expect(chatResult?.metadata).deep.equals({ + intent: 'askForNamespace', + chatId: testChatId, + databaseName: 'dbOne', + collectionName: undefined, + }); + }); + }); + }); + }); + }); }); suite('prompt builders', function () {