diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 5fe25c130..c8ee509c3 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -40,11 +40,12 @@ import { import { DocsChatbotAIService } from './docsChatbotAIService'; import type TelemetryService from '../telemetry/telemetryService'; import formatError from '../utils/formatError'; -import type { ModelInput } from './prompts/promptBase'; +import { getContent, type ModelInput } from './prompts/promptBase'; import { processStreamWithIdentifiers } from './streamParsing'; import type { PromptIntent } from './prompts/intent'; import type { DataService } from 'mongodb-data-service'; import { ParticipantErrorTypes } from './participantErrorTypes'; +import { PromptHistory } from './prompts/promptHistory'; const log = createLogger('participant'); @@ -1415,10 +1416,12 @@ export default class ParticipantController { chatId, token, stream, + context, }: { prompt: string; chatId: string; token: vscode.CancellationToken; + context: vscode.ChatContext; stream: vscode.ChatResponseStream; }): Promise<{ responseContent: string; @@ -1446,8 +1449,22 @@ export default class ParticipantController { log.info('Docs chatbot created for chatId', chatId); } + const history = PromptHistory.getFilteredHistoryForDocs({ + connectionNames: this._getConnectionNames(), + context: context, + }); + + const previousMessages = + history.length > 0 + ? `${history + .map((message: vscode.LanguageModelChatMessage) => + getContent(message) + ) + .join('\n\n')}\n\n` + : ''; + const response = await this._docsChatbotAIService.addMessage({ - message: prompt, + message: `${previousMessages}${prompt}`, conversationId: docsChatbotConversationId, signal: abortController.signal, }); @@ -1553,6 +1570,7 @@ export default class ParticipantController { chatId, token, stream, + context, }); if (docsResult.responseContent) { diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 0a02a5bab..0df0d1b38 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -1,10 +1,10 @@ import * as vscode from 'vscode'; -import type { ChatResult, ParticipantResponseType } from '../constants'; +import type { ChatResult } from '../constants'; import type { InternalPromptPurpose, ParticipantPromptProperties, } from '../../telemetry/telemetryService'; -import { ParticipantErrorTypes } from '../participantErrorTypes'; +import { PromptHistory } from './promptHistory'; export interface PromptArgsBase { request: { @@ -53,6 +53,26 @@ export function getContentLength( return 0; } +export function getContent(message: vscode.LanguageModelChatMessage): string { + const content = message.content as any; + if (typeof content === 'string') { + return content; + } + + if (Array.isArray(content)) { + return content.reduce((agg: string, element) => { + const value = element?.value ?? element?.content?.value; + if (typeof value === 'string') { + return agg + value; + } + + return agg; + }, ''); + } + + return ''; +} + export function isContentEmpty( message: vscode.LanguageModelChatMessage ): boolean { @@ -88,7 +108,10 @@ export abstract class PromptBase { } async buildMessages(args: TArgs): Promise { - let historyMessages = this.getHistoryMessages(args); + let historyMessages = PromptHistory.getFilteredHistory({ + history: args.context?.history, + ...args, + }); // If the current user's prompt is a connection name, and the last // message was to connect. We want to use the last // message they sent before the connection name as their prompt. @@ -157,115 +180,4 @@ export abstract class PromptBase { internal_purpose: this.internalPurposeForTelemetry, }; } - - // When passing the history to the model we only want contextual messages - // to be passed. This function parses through the history and returns - // the messages that are valuable to keep. - // eslint-disable-next-line complexity - protected getHistoryMessages({ - connectionNames, - context, - databaseName, - collectionName, - }: { - connectionNames?: string[]; // Used to scrape the connecting messages from the history. - context?: vscode.ChatContext; - databaseName?: string; - collectionName?: string; - }): vscode.LanguageModelChatMessage[] { - const messages: vscode.LanguageModelChatMessage[] = []; - - if (!context) { - return []; - } - - let previousItem: - | vscode.ChatRequestTurn - | vscode.ChatResponseTurn - | undefined = undefined; - - const namespaceIsKnown = - databaseName !== undefined && collectionName !== undefined; - for (const historyItem of context.history) { - if (historyItem instanceof vscode.ChatRequestTurn) { - if ( - historyItem.prompt?.trim().length === 0 || - connectionNames?.includes(historyItem.prompt) - ) { - // When the message is empty or a connection name then we skip it. - // It's probably going to be the response to the connect step. - previousItem = historyItem; - continue; - } - - if (previousItem instanceof vscode.ChatResponseTurn) { - const responseIntent = (previousItem.result as ChatResult).metadata - ?.intent; - - // If the namespace is already known, skip responses to prompts asking for it. - if (responseIntent === 'askForNamespace' && namespaceIsKnown) { - previousItem = historyItem; - continue; - } - } - - // eslint-disable-next-line new-cap - messages.push(vscode.LanguageModelChatMessage.User(historyItem.prompt)); - } - - if (historyItem instanceof vscode.ChatResponseTurn) { - if ( - historyItem.result.errorDetails?.message === - ParticipantErrorTypes.FILTERED - ) { - // If the response led to a filtered error, we do not want the - // error-causing message to be sent again so we remove it. - messages.pop(); - continue; - } - - let message = ''; - - // Skip a response to an empty user prompt message or connect message. - const responseTypesToSkip: ParticipantResponseType[] = [ - 'emptyRequest', - 'askToConnect', - ]; - - const responseType = (historyItem.result as ChatResult)?.metadata - ?.intent; - if (responseTypesToSkip.includes(responseType)) { - previousItem = historyItem; - continue; - } - - // If the namespace is already known, skip including prompts asking for it. - if (responseType === 'askForNamespace' && namespaceIsKnown) { - previousItem = historyItem; - continue; - } - - for (const fragment of historyItem.response) { - if (fragment instanceof vscode.ChatResponseMarkdownPart) { - message += fragment.value.value; - - if ( - (historyItem.result as ChatResult)?.metadata?.intent === - 'askForNamespace' - ) { - // When the message is the assistant asking for part of a namespace, - // we only want to include the question asked, not the user's - // database and collection names in the history item. - break; - } - } - } - // eslint-disable-next-line new-cap - messages.push(vscode.LanguageModelChatMessage.Assistant(message)); - } - previousItem = historyItem; - } - - return messages; - } } diff --git a/src/participant/prompts/promptHistory.ts b/src/participant/prompts/promptHistory.ts new file mode 100644 index 000000000..6f55e577a --- /dev/null +++ b/src/participant/prompts/promptHistory.ts @@ -0,0 +1,198 @@ +import * as vscode from 'vscode'; +import { ParticipantErrorTypes } from '../participantErrorTypes'; +import type { ChatResult, ParticipantResponseType } from '../constants'; + +export class PromptHistory { + private static _handleChatResponseTurn({ + currentTurn, + namespaceIsKnown, + }: { + currentTurn: vscode.ChatResponseTurn; + namespaceIsKnown: boolean; + }): vscode.LanguageModelChatMessage | undefined { + if ( + currentTurn.result.errorDetails?.message === + ParticipantErrorTypes.FILTERED + ) { + return undefined; + } + + let message = ''; + + // Skip a response to an empty user prompt message or connect message. + const responseTypesToSkip: ParticipantResponseType[] = [ + 'emptyRequest', + 'askToConnect', + ]; + + const responseType = (currentTurn.result as ChatResult)?.metadata?.intent; + if (responseTypesToSkip.includes(responseType)) { + // eslint-disable-next-line new-cap + return undefined; + } + + // If the namespace is already known, skip including prompts asking for it. + if (responseType === 'askForNamespace' && namespaceIsKnown) { + // eslint-disable-next-line new-cap + return undefined; + } + + for (const fragment of currentTurn.response) { + if (fragment instanceof vscode.ChatResponseMarkdownPart) { + message += fragment.value.value; + + if ( + (currentTurn.result as ChatResult)?.metadata?.intent === + 'askForNamespace' + ) { + // When the message is the assistant asking for part of a namespace, + // we only want to include the question asked, not the user's + // database and collection names in the history item. + break; + } + } + } + + // eslint-disable-next-line new-cap + return vscode.LanguageModelChatMessage.Assistant(message); + } + + private static _handleChatRequestTurn({ + previousTurn, + currentTurn, + nextTurn, + connectionNames, + namespaceIsKnown, + }: { + previousTurn: vscode.ChatRequestTurn | vscode.ChatResponseTurn | undefined; + currentTurn: vscode.ChatRequestTurn; + nextTurn: vscode.ChatRequestTurn | vscode.ChatResponseTurn | undefined; + connectionNames: string[] | undefined; + namespaceIsKnown: boolean; + }): vscode.LanguageModelChatMessage | undefined { + if (previousTurn instanceof vscode.ChatResponseTurn) { + const responseIntent = (previousTurn.result as ChatResult).metadata + ?.intent; + + if (responseIntent === 'askForNamespace' && namespaceIsKnown) { + // If the namespace is already known, skip responses to prompts asking for it. + return undefined; + } + } + + if ( + nextTurn instanceof vscode.ChatResponseTurn && + nextTurn.result.errorDetails?.message === ParticipantErrorTypes.FILTERED + ) { + // If the response to this request led to a filtered error, + // we do not want to include it in the history + return undefined; + } + + if ( + currentTurn.prompt?.trim().length === 0 || + connectionNames?.includes(currentTurn.prompt) + ) { + // When the message is empty or a connection name then we skip it. + // It's probably going to be the response to the connect step. + return undefined; + } + + // eslint-disable-next-line new-cap + return vscode.LanguageModelChatMessage.User(currentTurn.prompt); + } + + /** When passing the history to the model we only want contextual messages + to be passed. This function parses through the history and returns + the messages that are valuable to keep. */ + static getFilteredHistory({ + connectionNames, + history, + databaseName, + collectionName, + }: { + connectionNames?: string[]; // Used to scrape the connecting messages from the history. + history?: vscode.ChatContext['history']; + databaseName?: string; + collectionName?: string; + }): vscode.LanguageModelChatMessage[] { + const messages: vscode.LanguageModelChatMessage[] = []; + + if (!history) { + return []; + } + + const namespaceIsKnown = + databaseName !== undefined && collectionName !== undefined; + for (let i = 0; i < history.length; i++) { + const currentTurn = history[i]; + + let addedMessage: vscode.LanguageModelChatMessage | undefined; + if (currentTurn instanceof vscode.ChatRequestTurn) { + const previousTurn = i - 1 >= 0 ? history[i - 1] : undefined; + const nextTurn = i + 1 < history.length ? history[i + 1] : undefined; + + addedMessage = this._handleChatRequestTurn({ + previousTurn, + currentTurn, + nextTurn, + connectionNames, + namespaceIsKnown, + }); + } else if (currentTurn instanceof vscode.ChatResponseTurn) { + addedMessage = this._handleChatResponseTurn({ + currentTurn, + namespaceIsKnown, + }); + } + if (addedMessage) { + messages.push(addedMessage); + } + } + + return messages; + } + + /** The docs chatbot keeps its own history so we avoid any + * we need to include history only since last docs message. */ + static getFilteredHistoryForDocs({ + connectionNames, + context, + databaseName, + collectionName, + }: { + connectionNames?: string[]; + context?: vscode.ChatContext; + databaseName?: string; + collectionName?: string; + }): vscode.LanguageModelChatMessage[] { + if (!context) { + return []; + } + const historySinceLastDocs: ( + | vscode.ChatRequestTurn + | vscode.ChatResponseTurn + )[] = []; + + /** Limit included messages' history to prevent prompt overflow. */ + const MAX_DOCS_HISTORY_LENGTH = 4; + + for (let i = context.history.length - 1; i >= 0; i--) { + const message = context.history[i]; + + if ( + message.command === 'docs' || + historySinceLastDocs.length >= MAX_DOCS_HISTORY_LENGTH + ) { + break; + } + historySinceLastDocs.push(context.history[i]); + } + return this.getFilteredHistory({ + connectionNames, + history: historySinceLastDocs.reverse(), + databaseName, + collectionName, + }); + } +} diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index f6c05bbec..67bcd5f7a 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -1081,22 +1081,18 @@ suite('Participant Controller Test Suite', function () { '/query', 'find all docs by a name example' ), - createChatResponseTurn('/query', { - response: [ - { - value: { - value: - 'Which database would you like this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', - } as vscode.MarkdownString, - }, - ], - result: { - metadata: { - intent: 'askForNamespace', - chatId: firstChatId, + createChatResponseTurn( + '/query', + 'Which database would you like this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', + { + result: { + metadata: { + intent: 'askForNamespace', + chatId: firstChatId, + }, }, - }, - }), + } + ), ], }; @@ -1151,40 +1147,32 @@ suite('Participant Controller Test Suite', function () { '/query', 'find all docs by a name example' ), - createChatResponseTurn('/query', { - response: [ - { - value: { - value: - 'Which database would you like to this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', - } as vscode.MarkdownString, - }, - ], - result: { - metadata: { - intent: 'askForNamespace', + createChatResponseTurn( + '/query', + 'Which database would you like to this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', + { + result: { + metadata: { + intent: 'askForNamespace', + }, }, - }, - }), + } + ), createChatRequestTurn('/query', 'dbOne'), - createChatResponseTurn('/query', { - response: [ - { - value: { - value: - 'Which collection would you like to query within dbOne? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', - } as vscode.MarkdownString, - }, - ], - result: { - metadata: { - intent: 'askForNamespace', - databaseName: 'dbOne', - collectionName: 'collOne', - chatId: firstChatId, + createChatResponseTurn( + '/query', + 'Which collection would you like to query within dbOne? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', + { + result: { + metadata: { + intent: 'askForNamespace', + databaseName: 'dbOne', + collectionName: 'collOne', + chatId: firstChatId, + }, }, - }, - }), + } + ), ], }; await invokeChatHandler(chatRequestMock); @@ -1224,22 +1212,18 @@ suite('Participant Controller Test Suite', function () { '/query', 'find all docs by a name example' ), - createChatResponseTurn('/query', { - response: [ - { - value: { - value: - 'Which database would you like this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', - } as vscode.MarkdownString, - }, - ], - result: { - metadata: { - intent: 'askForNamespace', - chatId: 'pineapple', + createChatResponseTurn( + '/query', + 'Which database would you like this query to run against? Select one by either clicking on an item in the list or typing the name manually in the chat.\n\n', + { + result: { + metadata: { + intent: 'askForNamespace', + chatId: 'pineapple', + }, }, - }, - }), + } + ), ], }; const chatResult = await invokeChatHandler(chatRequestMock); @@ -1348,12 +1332,7 @@ suite('Participant Controller Test Suite', function () { '/query', 'how do I make a find request vs favorite_fruits.pineapple?' ), - createChatResponseTurn('/query', { - response: [ - { - value: { value: 'some code' } as vscode.MarkdownString, - }, - ], + createChatResponseTurn('/query', 'some code', { result: { metadata: { intent: 'query', @@ -1537,7 +1516,7 @@ Schema: let fetchStub: sinon.SinonStub; beforeEach(function () { - sendRequestStub.onCall(0).resolves({ + sendRequestStub.resolves({ text: ['connection info'], }); }); @@ -1546,6 +1525,88 @@ Schema: global.fetch = initialFetch; }); + suite('includes the history of previous requests', function () { + let addMessageStub: sinon.SinonStub; + beforeEach(function () { + addMessageStub = sinon.stub( + testParticipantController._docsChatbotAIService, + 'addMessage' + ); + }); + + test('since the beginning', async function () { + chatContextStub = { + history: [ + createChatRequestTurn('/query', 'query request'), + createChatResponseTurn('/query', 'query response'), + createChatRequestTurn('/query', 'query request 2'), + createChatResponseTurn('/query', 'query response 2'), + createChatRequestTurn('/schema', 'schema request'), + createChatResponseTurn('/schema', 'schema response'), + ], + }; + + const chatRequestMock = { + prompt: 'docs request', + command: 'docs', + references: [], + }; + + await invokeChatHandler(chatRequestMock); + + expect(addMessageStub.calledOnce).is.true; + expect(addMessageStub.getCall(0).firstArg.message).equal( + [ + 'query request 2', + 'query response 2', + 'schema request', + 'schema response', + 'docs request', + ].join('\n\n') + ); + }); + + test('since the last docs request or response', async function () { + chatContextStub = { + history: [ + createChatRequestTurn('/query', 'query request'), + createChatResponseTurn('/query', 'query response'), + createChatRequestTurn('/docs', 'first docs request'), + createChatResponseTurn('/docs', 'first docs response'), + createChatRequestTurn('/schema', 'schema request'), + createChatResponseTurn('/schema', 'schema response'), + ], + }; + + const chatRequestMock = { + prompt: 'docs request', + command: 'docs', + references: [], + }; + + await invokeChatHandler(chatRequestMock); + + expect(addMessageStub.calledOnce).is.true; + expect(addMessageStub.getCall(0).firstArg.message).equals( + ['schema request', 'schema response', 'docs request'].join('\n\n') + ); + + chatContextStub = { + history: [ + createChatRequestTurn('/query', 'query request'), + createChatResponseTurn('/query', 'query response'), + createChatRequestTurn('/docs', 'first docs request'), + ], + }; + + await invokeChatHandler(chatRequestMock); + + expect(addMessageStub.getCall(1).firstArg.message).equals( + 'docs request' + ); + }); + }); + test('shows a message and docs link on empty prompt', async function () { fetchStub = sinon.stub().resolves(); global.fetch = fetchStub; @@ -2154,42 +2215,32 @@ Schema: chatContextStub = { history: [ createChatRequestTurn('/query', userMessages[0]), - createChatResponseTurn('/query', { - participant: CHAT_PARTICIPANT_ID, - response: [ - { - value: { - value: - 'Which database would you like to query within this database?', - } as vscode.MarkdownString, - }, - ], - result: { - metadata: { - intent: 'askForNamespace', + createChatResponseTurn( + '/query', + 'Which database would you like to query within this database?', + { + result: { + metadata: { + intent: 'askForNamespace', + }, }, - }, - }), + } + ), createChatRequestTurn('/query', 'dbOne'), - createChatResponseTurn('/query', { - participant: CHAT_PARTICIPANT_ID, - response: [ - { - value: { - value: - 'Which collection would you like to query within dbOne?', - } as vscode.MarkdownString, - }, - ], - result: { - metadata: { - intent: 'askForNamespace', - databaseName: 'dbOne', - collectionName: undefined, - chatId: testChatId, + createChatResponseTurn( + '/query', + 'Which collection would you like to query within dbOne?', + { + result: { + metadata: { + intent: 'askForNamespace', + databaseName: 'dbOne', + collectionName: undefined, + chatId: testChatId, + }, }, - }, - }), + } + ), createChatRequestTurn('/query', 'collectionOne'), createChatRequestTurn('/query', userMessages[1]), ], @@ -2268,12 +2319,9 @@ Schema: chatContextStub = { history: [ createChatRequestTurn('/query', expectedPrompt), - createChatResponseTurn('/query', { - participant: CHAT_PARTICIPANT_ID, - response: [ - { - value: { - value: `Looks like you aren't currently connected, first let's get you connected to the cluster we'd like to create this query to run against. + createChatResponseTurn( + '/query', + `Looks like you aren't currently connected, first let's get you connected to the cluster we'd like to create this query to run against. ${createMarkdownLink({ commandId: EXTENSION_COMMANDS.CONNECT_WITH_PARTICIPANT, @@ -2285,16 +2333,15 @@ Schema: name: 'atlas', data: {}, })}`, - } as vscode.MarkdownString, - }, - ], - result: { - metadata: { - intent: 'askToConnect', - chatId: 'abc', + { + result: { + metadata: { + intent: 'askToConnect', + chatId: 'abc', + }, }, - }, - }), + } + ), ], }; @@ -2346,7 +2393,7 @@ Schema: 'give me the count of all people in the prod database' ), createChatRequestTurn('/query', 'some disallowed message'), - createChatResponseTurn('/query', { + createChatResponseTurn('/query', undefined, { result: { errorDetails: { message: ParticipantErrorTypes.FILTERED, diff --git a/src/test/suite/participant/participantHelpers.ts b/src/test/suite/participant/participantHelpers.ts index b66d810d4..77448c3d9 100644 --- a/src/test/suite/participant/participantHelpers.ts +++ b/src/test/suite/participant/participantHelpers.ts @@ -8,10 +8,7 @@ export function createChatRequestTurn( options: { participant?: vscode.ChatRequestTurn['participant']; references?: vscode.ChatRequestTurn['references']; - } = { - participant: CHAT_PARTICIPANT_ID, - references: [], - } + } = {} ): vscode.ChatRequestTurn { const { participant = CHAT_PARTICIPANT_ID, references = [] } = options; @@ -25,18 +22,27 @@ export function createChatRequestTurn( export function createChatResponseTurn( command: ParticipantCommand, + /** Helper shortcut for response text, use options.response for a more manual setup */ + responseText?: string, options: { - response?: vscode.ChatResponseTurn['response']; + response?: vscode.ChatResponseTurn['response'] | undefined; result?: vscode.ChatResponseTurn['result']; participant?: string; - } = { - response: [], - result: {}, - participant: CHAT_PARTICIPANT_ID, - } + } = {} ): vscode.ChatRequestTurn { const { - response = [], + response = responseText + ? [ + Object.assign( + Object.create(vscode.ChatResponseMarkdownPart.prototype), + { + value: { + value: responseText, + }, + } + ), + ] + : [], result = {}, participant = CHAT_PARTICIPANT_ID, } = options;