diff --git a/src/participant/constants.ts b/src/participant/constants.ts index b03e66274..49e283d05 100644 --- a/src/participant/constants.ts +++ b/src/participant/constants.ts @@ -10,6 +10,7 @@ export type ParticipantResponseType = | 'docs' | 'generic' | 'emptyRequest' + | 'cancelledRequest' | 'askToConnect' | 'askForNamespace'; @@ -49,6 +50,12 @@ export function namespaceRequestChatResult({ }; } +export function createCancelledRequestChatResult( + history: ReadonlyArray +): ChatResult { + return createChatResult('cancelledRequest', history); +} + function createChatResult( intent: ParticipantResponseType, history: ReadonlyArray diff --git a/src/participant/docsChatbotAIService.ts b/src/participant/docsChatbotAIService.ts index 4b0c8b1b8..47858280f 100644 --- a/src/participant/docsChatbotAIService.ts +++ b/src/participant/docsChatbotAIService.ts @@ -54,14 +54,16 @@ export class DocsChatbotAIService { return `${this._serverBaseUri}api/${MONGODB_DOCS_CHATBOT_API_VERSION}${path}`; } - async _fetch({ + _fetch({ uri, method, body, + signal, headers, }: { uri: string; method: string; + signal?: AbortSignal; body?: string; headers?: { [key: string]: string }; }): Promise { @@ -72,15 +74,21 @@ export class DocsChatbotAIService { ...headers, }, method, + signal, ...(body && { body }), }); } - async createConversation(): Promise { + async createConversation({ + signal, + }: { + signal: AbortSignal; + }): Promise { const uri = this.getUri('/conversations'); const res = await this._fetch({ uri, method: 'POST', + signal, }); let data; @@ -113,9 +121,11 @@ export class DocsChatbotAIService { async addMessage({ conversationId, message, + signal, }: { conversationId: string; message: string; + signal: AbortSignal; }): Promise { const uri = this.getUri(`/conversations/${conversationId}/messages`); const res = await this._fetch({ @@ -123,6 +133,7 @@ export class DocsChatbotAIService { method: 'POST', body: JSON.stringify({ message }), headers: { 'Content-Type': 'application/json' }, + signal, }); let data; diff --git a/src/participant/participant.ts b/src/participant/participant.ts index aa798afc3..0e5920b70 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -20,6 +20,7 @@ import { queryRequestChatResult, docsRequestChatResult, schemaRequestChatResult, + createCancelledRequestChatResult, } from './constants'; import { QueryPrompt } from './prompts/query'; import { COL_NAME_ID, DB_NAME_ID, NamespacePrompt } from './prompts/namespace'; @@ -241,10 +242,6 @@ export default class ParticipantController { context, }); - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); const responseContent = await this.getChatResponseContent({ messages, token, @@ -678,21 +675,33 @@ export default class ParticipantController { return askToConnectChatResult(context.history); } + _handleCancelledRequest({ + context, + stream, + }: { + context: vscode.ChatContext; + stream: vscode.ChatResponseStream; + }): ChatResult { + stream.markdown('\nRequest cancelled.'); + + return createCancelledRequestChatResult(context.history); + } + // The sample documents returned from this are simplified (strings and arrays shortened). // The sample documents are only returned when a user has the setting enabled. async _fetchCollectionSchemaAndSampleDocuments({ - abortSignal, databaseName, collectionName, amountOfDocumentsToSample = NUM_DOCUMENTS_TO_SAMPLE, schemaFormat = 'simplified', + token, stream, }: { - abortSignal; databaseName: string; collectionName: string; amountOfDocumentsToSample?: number; schemaFormat?: 'simplified' | 'full'; + token: vscode.CancellationToken; stream: vscode.ChatResponseStream; }): Promise<{ schema?: string; @@ -712,6 +721,11 @@ export default class ParticipantController { ) ); + const abortController = new AbortController(); + token.onCancellationRequested(() => { + abortController.abort(); + }); + try { const sampleDocuments = await dataService.sample( `${databaseName}.${collectionName}`, @@ -721,7 +735,7 @@ export default class ParticipantController { }, { promoteValues: false, maxTimeMS: 10_000 }, { - abortSignal, + abortSignal: abortController.signal, } ); @@ -852,10 +866,12 @@ export default class ParticipantController { }); } - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); + if (token.isCancellationRequested) { + return this._handleCancelledRequest({ + context, + stream, + }); + } let sampleDocuments: Document[] | undefined; let amountOfDocumentsSampled: number; @@ -866,11 +882,11 @@ export default class ParticipantController { amountOfDocumentsSampled, // There can be fewer than the amount we attempt to sample. schema, } = await this._fetchCollectionSchemaAndSampleDocuments({ - abortSignal: abortController.signal, databaseName, schemaFormat: 'full', collectionName, amountOfDocumentsToSample: DOCUMENTS_TO_SAMPLE_FOR_SCHEMA_PROMPT, + token, stream, })); @@ -969,19 +985,21 @@ export default class ParticipantController { }); } - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); + if (token.isCancellationRequested) { + return this._handleCancelledRequest({ + context, + stream, + }); + } let schema: string | undefined; let sampleDocuments: Document[] | undefined; try { ({ schema, sampleDocuments } = await this._fetchCollectionSchemaAndSampleDocuments({ - abortSignal: abortController.signal, databaseName, collectionName, + token, stream, })); } catch (e) { @@ -1024,10 +1042,12 @@ export default class ParticipantController { async _handleDocsRequestWithChatbot({ prompt, chatId, + token, stream, }: { prompt: string; chatId: string; + token: vscode.CancellationToken; stream: vscode.ChatResponseStream; }): Promise<{ responseContent: string; @@ -1040,9 +1060,14 @@ export default class ParticipantController { let { docsChatbotConversationId } = this._chatMetadataStore.getChatMetadata(chatId) ?? {}; + const abortController = new AbortController(); + token.onCancellationRequested(() => { + abortController.abort(); + }); if (!docsChatbotConversationId) { - const conversation = - await this._docsChatbotAIService.createConversation(); + const conversation = await this._docsChatbotAIService.createConversation({ + signal: abortController.signal, + }); docsChatbotConversationId = conversation._id; this._chatMetadataStore.setChatMetadata(chatId, { docsChatbotConversationId, @@ -1053,6 +1078,7 @@ export default class ParticipantController { const response = await this._docsChatbotAIService.addMessage({ message: prompt, conversationId: docsChatbotConversationId, + signal: abortController.signal, }); log.info('Docs chatbot message sent', { @@ -1085,10 +1111,6 @@ export default class ParticipantController { context, }); - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); const responseContent = await this.getChatResponseContent({ messages, token, @@ -1115,10 +1137,6 @@ export default class ParticipantController { ] ): Promise { const [request, context, stream, token] = args; - const abortController = new AbortController(); - token.onCancellationRequested(() => { - abortController.abort(); - }); const chatId = ChatMetadataStore.getChatIdFromHistoryOrNewChatId( context.history @@ -1133,12 +1151,21 @@ export default class ParticipantController { docsResult = await this._handleDocsRequestWithChatbot({ prompt: request.prompt, chatId, + token, stream, }); } catch (error) { // If the docs chatbot API is not available, fall back to Copilot’s LLM and include // the MongoDB documentation link for users to go to our documentation site directly. log.error(error); + + if (token.isCancellationRequested) { + return this._handleCancelledRequest({ + context, + stream, + }); + } + this._telemetryService.track( TelemetryEventTypes.PARTICIPANT_RESPONSE_FAILED, { diff --git a/src/test/suite/participant/docsChatbotAIService.test.ts b/src/test/suite/participant/docsChatbotAIService.test.ts index 827cf6841..e0f97e7ff 100644 --- a/src/test/suite/participant/docsChatbotAIService.test.ts +++ b/src/test/suite/participant/docsChatbotAIService.test.ts @@ -28,7 +28,9 @@ suite('DocsChatbotAIService Test Suite', function () { }), }); global.fetch = fetchStub; - const conversation = await docsChatbotAIService.createConversation(); + const conversation = await docsChatbotAIService.createConversation({ + signal: new AbortController().signal, + }); expect(conversation._id).to.be.eql('650b4b260f975ef031016c8a'); }); @@ -42,13 +44,28 @@ suite('DocsChatbotAIService Test Suite', function () { global.fetch = fetchStub; try { - await docsChatbotAIService.createConversation(); + await docsChatbotAIService.createConversation({ + signal: new AbortController().signal, + }); expect.fail('It must fail with the server error'); } catch (error) { expect((error as Error).message).to.include('Internal server error'); } }); + test('throws when aborted', async () => { + try { + const abortController = new AbortController(); + abortController.abort(); + await docsChatbotAIService.createConversation({ + signal: abortController.signal, + }); + expect.fail('It must fail with the server error'); + } catch (error) { + expect((error as Error).message).to.include('This operation was aborted'); + } + }); + test('throws on bad requests', async () => { const fetchStub = sinon.stub().resolves({ status: 400, @@ -59,7 +76,9 @@ suite('DocsChatbotAIService Test Suite', function () { global.fetch = fetchStub; try { - await docsChatbotAIService.createConversation(); + await docsChatbotAIService.createConversation({ + signal: new AbortController().signal, + }); expect.fail('It must fail with the bad request error'); } catch (error) { expect((error as Error).message).to.include('Bad request'); @@ -76,7 +95,9 @@ suite('DocsChatbotAIService Test Suite', function () { global.fetch = fetchStub; try { - await docsChatbotAIService.createConversation(); + await docsChatbotAIService.createConversation({ + signal: new AbortController().signal, + }); expect.fail('It must fail with the rate limited error'); } catch (error) { expect((error as Error).message).to.include('Rate limited'); @@ -95,6 +116,7 @@ suite('DocsChatbotAIService Test Suite', function () { await docsChatbotAIService.addMessage({ conversationId: '650b4b260f975ef031016c8a', message: 'what is mongosh?', + signal: new AbortController().signal, }); expect.fail('It must fail with the timeout error'); } catch (error) {