Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(participant): add handling for token cancellation, pass to docs fetch request, remove unused abort controllers #831

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/participant/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export type ParticipantResponseType =
| 'docs'
| 'generic'
| 'emptyRequest'
| 'cancelledRequest'
| 'askToConnect'
| 'askForNamespace';

Expand Down Expand Up @@ -49,6 +50,12 @@ export function namespaceRequestChatResult({
};
}

export function createCancelledRequestChatResult(
history: ReadonlyArray<vscode.ChatRequestTurn | vscode.ChatResponseTurn>
): ChatResult {
return createChatResult('cancelledRequest', history);
}

function createChatResult(
intent: ParticipantResponseType,
history: ReadonlyArray<vscode.ChatRequestTurn | vscode.ChatResponseTurn>
Expand Down
15 changes: 13 additions & 2 deletions src/participant/docsChatbotAIService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,16 @@ export class DocsChatbotAIService {
return `${this._serverBaseUri}api/${MONGODB_DOCS_CHATBOT_API_VERSION}${path}`;
}

async _fetch({
_fetch({
uri,
method,
body,
signal,
headers,
}: {
uri: string;
method: string;
signal?: AbortSignal;
body?: string;
headers?: { [key: string]: string };
}): Promise<Response> {
Expand All @@ -72,15 +74,21 @@ export class DocsChatbotAIService {
...headers,
},
method,
signal,
...(body && { body }),
});
}

async createConversation(): Promise<ConversationData> {
async createConversation({
signal,
}: {
signal: AbortSignal;
}): Promise<ConversationData> {
const uri = this.getUri('/conversations');
const res = await this._fetch({
uri,
method: 'POST',
signal,
});

let data;
Expand Down Expand Up @@ -113,16 +121,19 @@ export class DocsChatbotAIService {
async addMessage({
conversationId,
message,
signal,
}: {
conversationId: string;
message: string;
signal: AbortSignal;
}): Promise<MessageData> {
const uri = this.getUri(`/conversations/${conversationId}/messages`);
const res = await this._fetch({
uri,
method: 'POST',
body: JSON.stringify({ message }),
headers: { 'Content-Type': 'application/json' },
signal,
});

let data;
Expand Down
81 changes: 54 additions & 27 deletions src/participant/participant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
queryRequestChatResult,
docsRequestChatResult,
schemaRequestChatResult,
createCancelledRequestChatResult,
} from './constants';
import { QueryPrompt } from './prompts/query';
import { COL_NAME_ID, DB_NAME_ID, NamespacePrompt } from './prompts/namespace';
Expand Down Expand Up @@ -241,10 +242,6 @@ export default class ParticipantController {
context,
});

const abortController = new AbortController();
token.onCancellationRequested(() => {
abortController.abort();
});
const responseContent = await this.getChatResponseContent({
messages,
token,
Expand Down Expand Up @@ -678,21 +675,33 @@ export default class ParticipantController {
return askToConnectChatResult(context.history);
}

_handleCancelledRequest({
context,
stream,
}: {
context: vscode.ChatContext;
stream: vscode.ChatResponseStream;
}): ChatResult {
stream.markdown('\nRequest cancelled.');

return createCancelledRequestChatResult(context.history);
}

// The sample documents returned from this are simplified (strings and arrays shortened).
// The sample documents are only returned when a user has the setting enabled.
async _fetchCollectionSchemaAndSampleDocuments({
abortSignal,
databaseName,
collectionName,
amountOfDocumentsToSample = NUM_DOCUMENTS_TO_SAMPLE,
schemaFormat = 'simplified',
token,
stream,
}: {
abortSignal;
databaseName: string;
collectionName: string;
amountOfDocumentsToSample?: number;
schemaFormat?: 'simplified' | 'full';
token: vscode.CancellationToken;
stream: vscode.ChatResponseStream;
}): Promise<{
schema?: string;
Expand All @@ -712,6 +721,11 @@ export default class ParticipantController {
)
);

const abortController = new AbortController();
token.onCancellationRequested(() => {
abortController.abort();
});

try {
const sampleDocuments = await dataService.sample(
`${databaseName}.${collectionName}`,
Expand All @@ -721,7 +735,7 @@ export default class ParticipantController {
},
{ promoteValues: false, maxTimeMS: 10_000 },
{
abortSignal,
abortSignal: abortController.signal,
}
);

Expand Down Expand Up @@ -852,10 +866,12 @@ export default class ParticipantController {
});
}

const abortController = new AbortController();
token.onCancellationRequested(() => {
abortController.abort();
});
if (token.isCancellationRequested) {
return this._handleCancelledRequest({
context,
stream,
});
}

let sampleDocuments: Document[] | undefined;
let amountOfDocumentsSampled: number;
Expand All @@ -866,11 +882,11 @@ export default class ParticipantController {
amountOfDocumentsSampled, // There can be fewer than the amount we attempt to sample.
schema,
} = await this._fetchCollectionSchemaAndSampleDocuments({
abortSignal: abortController.signal,
databaseName,
schemaFormat: 'full',
collectionName,
amountOfDocumentsToSample: DOCUMENTS_TO_SAMPLE_FOR_SCHEMA_PROMPT,
token,
stream,
}));

Expand Down Expand Up @@ -969,19 +985,21 @@ export default class ParticipantController {
});
}

const abortController = new AbortController();
token.onCancellationRequested(() => {
abortController.abort();
});
if (token.isCancellationRequested) {
return this._handleCancelledRequest({
context,
stream,
});
}

let schema: string | undefined;
let sampleDocuments: Document[] | undefined;
try {
({ schema, sampleDocuments } =
await this._fetchCollectionSchemaAndSampleDocuments({
abortSignal: abortController.signal,
databaseName,
collectionName,
token,
stream,
}));
} catch (e) {
Expand Down Expand Up @@ -1024,10 +1042,12 @@ export default class ParticipantController {
async _handleDocsRequestWithChatbot({
prompt,
chatId,
token,
stream,
}: {
prompt: string;
chatId: string;
token: vscode.CancellationToken;
stream: vscode.ChatResponseStream;
}): Promise<{
responseContent: string;
Expand All @@ -1040,9 +1060,14 @@ export default class ParticipantController {

let { docsChatbotConversationId } =
this._chatMetadataStore.getChatMetadata(chatId) ?? {};
const abortController = new AbortController();
token.onCancellationRequested(() => {
abortController.abort();
});
if (!docsChatbotConversationId) {
const conversation =
await this._docsChatbotAIService.createConversation();
const conversation = await this._docsChatbotAIService.createConversation({
signal: abortController.signal,
});
docsChatbotConversationId = conversation._id;
this._chatMetadataStore.setChatMetadata(chatId, {
docsChatbotConversationId,
Expand All @@ -1053,6 +1078,7 @@ export default class ParticipantController {
const response = await this._docsChatbotAIService.addMessage({
message: prompt,
conversationId: docsChatbotConversationId,
signal: abortController.signal,
});

log.info('Docs chatbot message sent', {
Expand Down Expand Up @@ -1085,10 +1111,6 @@ export default class ParticipantController {
context,
});

const abortController = new AbortController();
token.onCancellationRequested(() => {
abortController.abort();
});
const responseContent = await this.getChatResponseContent({
messages,
token,
Expand All @@ -1115,10 +1137,6 @@ export default class ParticipantController {
]
): Promise<ChatResult> {
const [request, context, stream, token] = args;
const abortController = new AbortController();
token.onCancellationRequested(() => {
abortController.abort();
});

const chatId = ChatMetadataStore.getChatIdFromHistoryOrNewChatId(
context.history
Expand All @@ -1133,12 +1151,21 @@ export default class ParticipantController {
docsResult = await this._handleDocsRequestWithChatbot({
prompt: request.prompt,
chatId,
token,
stream,
});
} catch (error) {
// If the docs chatbot API is not available, fall back to Copilot’s LLM and include
// the MongoDB documentation link for users to go to our documentation site directly.
log.error(error);

if (token.isCancellationRequested) {
return this._handleCancelledRequest({
context,
stream,
});
}

this._telemetryService.track(
TelemetryEventTypes.PARTICIPANT_RESPONSE_FAILED,
{
Expand Down
30 changes: 26 additions & 4 deletions src/test/suite/participant/docsChatbotAIService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ suite('DocsChatbotAIService Test Suite', function () {
}),
});
global.fetch = fetchStub;
const conversation = await docsChatbotAIService.createConversation();
const conversation = await docsChatbotAIService.createConversation({
signal: new AbortController().signal,
});
expect(conversation._id).to.be.eql('650b4b260f975ef031016c8a');
});

Expand All @@ -42,13 +44,28 @@ suite('DocsChatbotAIService Test Suite', function () {
global.fetch = fetchStub;

try {
await docsChatbotAIService.createConversation();
await docsChatbotAIService.createConversation({
signal: new AbortController().signal,
});
expect.fail('It must fail with the server error');
} catch (error) {
expect((error as Error).message).to.include('Internal server error');
}
});

test('throws when aborted', async () => {
try {
const abortController = new AbortController();
abortController.abort();
await docsChatbotAIService.createConversation({
signal: abortController.signal,
});
expect.fail('It must fail with the server error');
} catch (error) {
expect((error as Error).message).to.include('This operation was aborted');
}
});

test('throws on bad requests', async () => {
const fetchStub = sinon.stub().resolves({
status: 400,
Expand All @@ -59,7 +76,9 @@ suite('DocsChatbotAIService Test Suite', function () {
global.fetch = fetchStub;

try {
await docsChatbotAIService.createConversation();
await docsChatbotAIService.createConversation({
signal: new AbortController().signal,
});
expect.fail('It must fail with the bad request error');
} catch (error) {
expect((error as Error).message).to.include('Bad request');
Expand All @@ -76,7 +95,9 @@ suite('DocsChatbotAIService Test Suite', function () {
global.fetch = fetchStub;

try {
await docsChatbotAIService.createConversation();
await docsChatbotAIService.createConversation({
signal: new AbortController().signal,
});
expect.fail('It must fail with the rate limited error');
} catch (error) {
expect((error as Error).message).to.include('Rate limited');
Expand All @@ -95,6 +116,7 @@ suite('DocsChatbotAIService Test Suite', function () {
await docsChatbotAIService.addMessage({
conversationId: '650b4b260f975ef031016c8a',
message: 'what is mongosh?',
signal: new AbortController().signal,
});
expect.fail('It must fail with the timeout error');
} catch (error) {
Expand Down
Loading