diff --git a/src/participant/participant.ts b/src/participant/participant.ts index ee431a884..a613aadc4 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -39,6 +39,7 @@ import { } from '../telemetry/telemetryService'; import { DocsChatbotAIService } from './docsChatbotAIService'; import type TelemetryService from '../telemetry/telemetryService'; +import type { ModelInput } from './prompts/promptBase'; import { processStreamWithIdentifiers } from './streamParsing'; import type { PromptIntent } from './prompts/intent'; @@ -164,10 +165,10 @@ export default class ParticipantController { } async _getChatResponse({ - messages, + modelInput, token, }: { - messages: vscode.LanguageModelChatMessage[]; + modelInput: ModelInput; token: vscode.CancellationToken; }): Promise { const model = await getCopilotModel(); @@ -176,20 +177,22 @@ export default class ParticipantController { throw new Error('Copilot model not found'); } - return await model.sendRequest(messages, {}, token); + this._telemetryService.trackCopilotParticipantPrompt(modelInput.stats); + + return await model.sendRequest(modelInput.messages, {}, token); } async streamChatResponse({ - messages, + modelInput, stream, token, }: { - messages: vscode.LanguageModelChatMessage[]; + modelInput: ModelInput; stream: vscode.ChatResponseStream; token: vscode.CancellationToken; }): Promise { const chatResponse = await this._getChatResponse({ - messages, + modelInput, token, }); for await (const fragment of chatResponse.text) { @@ -226,16 +229,16 @@ export default class ParticipantController { } async streamChatResponseContentWithCodeActions({ - messages, + modelInput, stream, token, }: { - messages: vscode.LanguageModelChatMessage[]; + modelInput: ModelInput; stream: vscode.ChatResponseStream; token: vscode.CancellationToken; }): Promise { const chatResponse = await this._getChatResponse({ - messages, + modelInput, token, }); @@ -254,15 +257,15 @@ export default class ParticipantController { 
// This will stream all of the response content and create a string from it. // It should only be used when the entire response is needed at one time. async getChatResponseContent({ - messages, + modelInput, token, }: { - messages: vscode.LanguageModelChatMessage[]; + modelInput: ModelInput; token: vscode.CancellationToken; }): Promise { let responseContent = ''; const chatResponse = await this._getChatResponse({ - messages, + modelInput, token, }); for await (const fragment of chatResponse.text) { @@ -278,14 +281,14 @@ export default class ParticipantController { stream: vscode.ChatResponseStream, token: vscode.CancellationToken ): Promise { - const messages = await Prompts.generic.buildMessages({ + const modelInput = await Prompts.generic.buildMessages({ request, context, connectionNames: this._getConnectionNames(), }); await this.streamChatResponseContentWithCodeActions({ - messages, + modelInput, token, stream, }); @@ -334,14 +337,14 @@ export default class ParticipantController { request: vscode.ChatRequest; token: vscode.CancellationToken; }): Promise { - const messages = await Prompts.intent.buildMessages({ + const modelInput = await Prompts.intent.buildMessages({ connectionNames: this._getConnectionNames(), request, context, }); const responseContent = await this.getChatResponseContent({ - messages, + modelInput, token, }); @@ -708,7 +711,7 @@ export default class ParticipantController { connectionNames: this._getConnectionNames(), }); const responseContentWithNamespace = await this.getChatResponseContent({ - messages: messagesWithNamespace, + modelInput: messagesWithNamespace, token, }); const { databaseName, collectionName } = @@ -1043,7 +1046,7 @@ export default class ParticipantController { return schemaRequestChatResult(context.history); } - const messages = await Prompts.schema.buildMessages({ + const modelInput = await Prompts.schema.buildMessages({ request, context, databaseName, @@ -1054,7 +1057,7 @@ export default class ParticipantController { 
...(sampleDocuments ? { sampleDocuments } : {}), }); await this.streamChatResponse({ - messages, + modelInput, stream, token, }); @@ -1147,7 +1150,7 @@ export default class ParticipantController { ); } - const messages = await Prompts.query.buildMessages({ + const modelInput = await Prompts.query.buildMessages({ request, context, databaseName, @@ -1158,7 +1161,7 @@ export default class ParticipantController { }); await this.streamChatResponseContentWithCodeActions({ - messages, + modelInput, stream, token, }); @@ -1230,14 +1233,14 @@ export default class ParticipantController { ] ): Promise { const [request, context, stream, token] = args; - const messages = await Prompts.generic.buildMessages({ + const modelInput = await Prompts.generic.buildMessages({ request, context, connectionNames: this._getConnectionNames(), }); await this.streamChatResponseContentWithCodeActions({ - messages, + modelInput, stream, token, }); diff --git a/src/participant/prompts/intent.ts b/src/participant/prompts/intent.ts index 4d6216afa..8a1266f69 100644 --- a/src/participant/prompts/intent.ts +++ b/src/participant/prompts/intent.ts @@ -1,3 +1,4 @@ +import type { InternalPromptPurpose } from '../../telemetry/telemetryService'; import type { PromptArgsBase } from './promptBase'; import { PromptBase } from './promptBase'; @@ -47,4 +48,8 @@ Docs`; return 'Default'; } } + + protected get internalPurposeForTelemetry(): InternalPromptPurpose { + return 'intent'; + } } diff --git a/src/participant/prompts/namespace.ts b/src/participant/prompts/namespace.ts index e29f24d2c..c5428f191 100644 --- a/src/participant/prompts/namespace.ts +++ b/src/participant/prompts/namespace.ts @@ -1,3 +1,4 @@ +import type { InternalPromptPurpose } from '../../telemetry/telemetryService'; import type { PromptArgsBase } from './promptBase'; import { PromptBase } from './promptBase'; @@ -50,4 +51,8 @@ No names found. 
const collectionName = text.match(COL_NAME_REGEX)?.[1].trim(); return { databaseName, collectionName }; } + + protected get internalPurposeForTelemetry(): InternalPromptPurpose { + return 'namespace'; + } } diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 0f3c83286..949b4f3d0 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -1,5 +1,9 @@ import * as vscode from 'vscode'; import type { ChatResult, ParticipantResponseType } from '../constants'; +import type { + InternalPromptPurpose, + ParticipantPromptProperties, +} from '../../telemetry/telemetryService'; export interface PromptArgsBase { request: { @@ -10,14 +14,31 @@ export interface PromptArgsBase { connectionNames: string[]; } +export interface UserPromptResponse { + prompt: string; + hasSampleDocs: boolean; +} + +export interface ModelInput { + messages: vscode.LanguageModelChatMessage[]; + stats: ParticipantPromptProperties; +} + export abstract class PromptBase { protected abstract getAssistantPrompt(args: TArgs): string; - protected getUserPrompt(args: TArgs): Promise { - return Promise.resolve(args.request.prompt); + protected get internalPurposeForTelemetry(): InternalPromptPurpose { + return undefined; } - async buildMessages(args: TArgs): Promise { + protected getUserPrompt(args: TArgs): Promise { + return Promise.resolve({ + prompt: args.request.prompt, + hasSampleDocs: false, + }); + } + + async buildMessages(args: TArgs): Promise { let historyMessages = this.getHistoryMessages(args); // If the current user's prompt is a connection name, and the last // message was to connect. 
We want to use the last @@ -49,13 +70,37 @@ export abstract class PromptBase { } } - return [ + const { prompt, hasSampleDocs } = await this.getUserPrompt(args); + const messages = [ // eslint-disable-next-line new-cap vscode.LanguageModelChatMessage.Assistant(this.getAssistantPrompt(args)), ...historyMessages, // eslint-disable-next-line new-cap - vscode.LanguageModelChatMessage.User(await this.getUserPrompt(args)), + vscode.LanguageModelChatMessage.User(prompt), ]; + + return { + messages, + stats: this.getStats(messages, args, hasSampleDocs), + }; + } + + protected getStats( + messages: vscode.LanguageModelChatMessage[], + { request, context }: TArgs, + hasSampleDocs: boolean + ): ParticipantPromptProperties { + return { + total_message_length: messages.reduce( + (acc, message) => acc + message.content.length, + 0 + ), + user_input_length: request.prompt.length, + has_sample_documents: hasSampleDocs, + command: request.command || 'generic', + history_size: context.history.length, + internal_purpose: this.internalPurposeForTelemetry, + }; } // When passing the history to the model we only want contextual messages diff --git a/src/participant/prompts/query.ts b/src/participant/prompts/query.ts index eff4d29ff..1efef4ba6 100644 --- a/src/participant/prompts/query.ts +++ b/src/participant/prompts/query.ts @@ -2,8 +2,8 @@ import * as vscode from 'vscode'; import type { Document } from 'bson'; import { getStringifiedSampleDocuments } from '../sampleDocuments'; +import type { PromptArgsBase, UserPromptResponse } from './promptBase'; import { codeBlockIdentifier } from '../constants'; -import type { PromptArgsBase } from './promptBase'; import { PromptBase } from './promptBase'; interface QueryPromptArgs extends PromptArgsBase { @@ -59,21 +59,23 @@ db.getCollection('');\n`; request, schema, sampleDocuments, - }: QueryPromptArgs): Promise { + }: QueryPromptArgs): Promise { let prompt = request.prompt; prompt += `\nDatabase name: ${databaseName}\n`; prompt += `Collection 
name: ${collectionName}\n`; if (schema) { prompt += `Collection schema: ${schema}\n`; } - if (sampleDocuments) { - prompt += await getStringifiedSampleDocuments({ - sampleDocuments, - prompt, - }); - } - return prompt; + const sampleDocumentsPrompt = await getStringifiedSampleDocuments({ + sampleDocuments, + prompt, + }); + + return { + prompt: `${prompt}${sampleDocumentsPrompt}`, + hasSampleDocs: !!sampleDocumentsPrompt, + }; } get emptyRequestResponse(): string { diff --git a/src/participant/prompts/schema.ts b/src/participant/prompts/schema.ts index 895f99568..ca8b54b26 100644 --- a/src/participant/prompts/schema.ts +++ b/src/participant/prompts/schema.ts @@ -1,3 +1,4 @@ +import type { UserPromptResponse } from './promptBase'; import { PromptBase, type PromptArgsBase } from './promptBase'; export const DOCUMENTS_TO_SAMPLE_FOR_SCHEMA_PROMPT = 100; @@ -11,7 +12,6 @@ export interface SchemaPromptArgs extends PromptArgsBase { collectionName: string; schema: string; amountOfDocumentsSampled: number; - connectionNames: string[]; } export class SchemaPrompt extends PromptBase { @@ -30,13 +30,16 @@ Amount of documents sampled: ${amountOfDocumentsSampled}.`; collectionName, request, schema, - }: SchemaPromptArgs): Promise { + }: SchemaPromptArgs): Promise { const prompt = request.prompt; - return Promise.resolve(`${ - prompt ? `The user provided additional information: "${prompt}"\n` : '' - }Database name: ${databaseName} + return Promise.resolve({ + prompt: `${ + prompt ? 
`The user provided additional information: "${prompt}"\n` : '' + }Database name: ${databaseName} Collection name: ${collectionName} Schema: -${schema}`); +${schema}`, + hasSampleDocs: false, + }); } } diff --git a/src/telemetry/telemetryService.ts b/src/telemetry/telemetryService.ts index 53dd2cba5..930f7e950 100644 --- a/src/telemetry/telemetryService.ts +++ b/src/telemetry/telemetryService.ts @@ -108,6 +108,17 @@ type ParticipantResponseFailedProperties = { error_name: ParticipantErrorTypes; }; +export type InternalPromptPurpose = 'intent' | 'namespace' | undefined; + +export type ParticipantPromptProperties = { + command: string; + user_input_length: number; + total_message_length: number; + has_sample_documents: boolean; + history_size: number; + internal_purpose: InternalPromptPurpose; +}; + export function chatResultFeedbackKindToTelemetryValue( kind: vscode.ChatResultFeedbackKind ): TelemetryFeedbackKind { @@ -160,6 +171,7 @@ export enum TelemetryEventTypes { PARTICIPANT_FEEDBACK = 'Participant Feedback', PARTICIPANT_WELCOME_SHOWN = 'Participant Welcome Shown', PARTICIPANT_RESPONSE_FAILED = 'Participant Response Failed', + PARTICIPANT_PROMPT_SUBMITTED = 'Participant Prompt Submitted', } export enum ParticipantErrorTypes { @@ -422,4 +434,8 @@ export default class TelemetryService { trackCopilotParticipantFeedback(props: ParticipantFeedbackProperties): void { this.track(TelemetryEventTypes.PARTICIPANT_FEEDBACK, props); } + + trackCopilotParticipantPrompt(stats: ParticipantPromptProperties): void { + this.track(TelemetryEventTypes.PARTICIPANT_PROMPT_SUBMITTED, stats); + } } diff --git a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts index c3135400d..1b3d45636 100644 --- a/src/test/ai-accuracy-tests/ai-accuracy-tests.ts +++ b/src/test/ai-accuracy-tests/ai-accuracy-tests.ts @@ -19,6 +19,7 @@ import { } from './create-test-results-html-page'; import { anyOf, runCodeInMessage } from './assertions'; import { Prompts 
} from '../../participant/prompts'; +import type { ModelInput } from '../../participant/prompts/promptBase'; const numberOfRunsPerTest = 1; @@ -489,7 +490,7 @@ const buildMessages = async ({ }: { testCase: TestCase; fixtures: Fixtures; -}): Promise => { +}): Promise => { switch (testCase.type) { case 'intent': return Prompts.intent.buildMessages({ @@ -499,7 +500,7 @@ const buildMessages = async ({ }); case 'generic': - return Prompts.generic.buildMessages({ + return await Prompts.generic.buildMessages({ request: { prompt: testCase.userInput }, context: { history: [] }, connectionNames: [], @@ -552,7 +553,7 @@ async function runTest({ aiBackend: AIBackend; fixtures: Fixtures; }): Promise { - const messages = await buildMessages({ + const { messages } = await buildMessages({ testCase, fixtures, }); diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 09849239c..adde458d8 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -11,6 +11,10 @@ import ConnectionController from '../../../connectionController'; import { StorageController } from '../../../storage'; import { StatusView } from '../../../views'; import { ExtensionContextStub } from '../stubs'; +import type { + InternalPromptPurpose, + ParticipantPromptProperties, +} from '../../../telemetry/telemetryService'; import TelemetryService, { TelemetryEventTypes, } from '../../../telemetry/telemetryService'; @@ -77,6 +81,45 @@ suite('Participant Controller Test Suite', function () { chatTokenStub ); + const assertCommandTelemetry = ( + command: string, + chatRequest: vscode.ChatRequest, + { + expectSampleDocs = false, + callIndex = 0, + expectedCallCount, + expectedInternalPurpose = undefined, + }: { + expectSampleDocs?: boolean; + callIndex: number; + expectedCallCount: number; + expectedInternalPurpose?: InternalPromptPurpose; + } + ): void => { + 
expect(telemetryTrackStub.callCount).to.equal(expectedCallCount); + + const call = telemetryTrackStub.getCalls()[callIndex]; + expect(call.args[0]).to.equal('Participant Prompt Submitted'); + + const properties = call.args[1] as ParticipantPromptProperties; + + expect(properties.command).to.equal(command); + expect(properties.has_sample_documents).to.equal(expectSampleDocs); + expect(properties.history_size).to.equal(chatContextStub.history.length); + + // Total message length includes participant as well as user prompt + expect(properties.total_message_length).to.be.greaterThan( + properties.user_input_length + ); + + // User prompt length should be at least equal to the supplied user prompt, but may occasionally + be greater - e.g. when we enhance the context. + expect(properties.user_input_length).to.be.greaterThanOrEqual( + chatRequest.prompt.length + ); + expect(properties.internal_purpose).to.equal(expectedInternalPurpose); + }; + beforeEach(function () { testStorageController = new StorageController(extensionContextStub); testStatusView = new StatusView(extensionContextStub); @@ -382,11 +425,17 @@ suite('Participant Controller Test Suite', function () { const welcomeMessage = chatStreamStub.markdown.firstCall.args[0]; expect(welcomeMessage).to.include('Welcome to MongoDB Participant!'); - sinon.assert.calledOnce(telemetryTrackStub); - expect(telemetryTrackStub.lastCall.args[0]).to.equal( + // Once to report welcome screen shown, second time to track the user prompt + expect(telemetryTrackStub).to.have.been.calledTwice; + expect(telemetryTrackStub.firstCall.args[0]).to.equal( TelemetryEventTypes.PARTICIPANT_WELCOME_SHOWN ); - expect(telemetryTrackStub.lastCall.args[1]).to.be.undefined; + expect(telemetryTrackStub.firstCall.args[1]).to.be.undefined; + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 1, + expectedCallCount: 2, + expectedInternalPurpose: 'namespace', + }); }); }); @@ -498,6 +547,17 @@ suite('Participant Controller Test Suite', 
function () { }, ], }); + + assertCommandTelemetry('generic', chatRequestMock, { + expectedCallCount: 2, + callIndex: 0, + expectedInternalPurpose: 'intent', + }); + + assertCommandTelemetry('generic', chatRequestMock, { + expectedCallCount: 2, + callIndex: 1, + }); }); }); @@ -526,6 +586,17 @@ suite('Participant Controller Test Suite', function () { }, ], }); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 0, + expectedCallCount: 2, + expectedInternalPurpose: 'namespace', + }); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 1, + expectedCallCount: 2, + }); }); test('includes a collection schema', async function () { @@ -551,6 +622,17 @@ suite('Participant Controller Test Suite', function () { 'field.stringField: String\n' + 'field.arrayField: Array\n' ); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 0, + expectedCallCount: 2, + expectedInternalPurpose: 'namespace', + }); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 1, + expectedCallCount: 2, + }); }); suite('useSampleDocsInCopilot setting is true', function () { @@ -617,6 +699,18 @@ suite('Participant Controller Test Suite', function () { ' }\n' + ']\n' ); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 0, + expectedCallCount: 2, + expectedInternalPurpose: 'namespace', + }); + + assertCommandTelemetry('query', chatRequestMock, { + expectSampleDocs: true, + callIndex: 1, + expectedCallCount: 2, + }); }); test('includes 1 sample document as an object', async function () { @@ -661,6 +755,18 @@ suite('Participant Controller Test Suite', function () { ' }\n' + '}\n' ); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 0, + expectedCallCount: 2, + expectedInternalPurpose: 'namespace', + }); + + assertCommandTelemetry('query', chatRequestMock, { + expectSampleDocs: true, + callIndex: 1, + expectedCallCount: 2, + }); }); test('includes 1 sample documents when 3 make prompt too long', async function 
() { @@ -703,6 +809,18 @@ suite('Participant Controller Test Suite', function () { ' }\n' + '}\n' ); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 0, + expectedCallCount: 2, + expectedInternalPurpose: 'namespace', + }); + + assertCommandTelemetry('query', chatRequestMock, { + expectSampleDocs: true, + callIndex: 1, + expectedCallCount: 2, + }); }); test('does not include sample documents when even 1 makes prompt too long', async function () { @@ -740,6 +858,17 @@ suite('Participant Controller Test Suite', function () { await invokeChatHandler(chatRequestMock); const messages = sendRequestStub.secondCall.args[0]; expect(messages[1].content).to.not.include('Sample documents'); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 0, + expectedCallCount: 2, + expectedInternalPurpose: 'namespace', + }); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 1, + expectedCallCount: 2, + }); }); }); @@ -753,6 +882,17 @@ suite('Participant Controller Test Suite', function () { await invokeChatHandler(chatRequestMock); const messages = sendRequestStub.secondCall.args[0]; expect(messages[1].content).to.not.include('Sample documents'); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 0, + expectedCallCount: 2, + expectedInternalPurpose: 'namespace', + }); + + assertCommandTelemetry('query', chatRequestMock, { + callIndex: 1, + expectedCallCount: 2, + }); }); }); }); @@ -1314,12 +1454,12 @@ Schema: expect(sendRequestStub).to.have.been.called; // Expect the error to be reported through the telemetry service - sinon.assert.calledOnce(telemetryTrackStub); - expect(telemetryTrackStub.lastCall.args[0]).to.equal( + expect(telemetryTrackStub).to.have.been.calledTwice; + expect(telemetryTrackStub.firstCall.args[0]).to.equal( TelemetryEventTypes.PARTICIPANT_RESPONSE_FAILED ); - const properties = telemetryTrackStub.lastCall.args[1]; + const properties = telemetryTrackStub.firstCall.args[1]; 
expect(properties.command).to.equal('docs'); expect(properties.error_name).to.equal('Docs Chatbot API Issue'); }); @@ -1332,7 +1472,7 @@ Schema: const chatRequestMock = { prompt: 'find all docs by a name example', }; - const messages = await Prompts.generic.buildMessages({ + const { messages, stats } = await Prompts.generic.buildMessages({ context: chatContextStub, request: chatRequestMock, connectionNames: [], @@ -1345,6 +1485,13 @@ Schema: expect(messages[1].role).to.equal( vscode.LanguageModelChatMessageRole.User ); + + expect(stats.command).to.equal('generic'); + expect(stats.has_sample_documents).to.be.false; + expect(stats.user_input_length).to.equal(chatRequestMock.prompt.length); + expect(stats.total_message_length).to.equal( + messages[0].content.length + messages[1].content.length + ); }); test('query', async function () { @@ -1364,7 +1511,7 @@ Schema: }), ], }; - const messages = await Prompts.query.buildMessages({ + const { messages, stats } = await Prompts.query.buildMessages({ context: chatContextStub, request: chatRequestMock, collectionName: 'people', @@ -1407,6 +1554,21 @@ Schema: expect(messages[2].role).to.equal( vscode.LanguageModelChatMessageRole.User ); + + expect(stats.command).to.equal('query'); + expect(stats.has_sample_documents).to.be.true; + expect(stats.user_input_length).to.equal(chatRequestMock.prompt.length); + expect(stats.total_message_length).to.equal( + messages[0].content.length + + messages[1].content.length + + messages[2].content.length + ); + + // The length of the user prompt length should be taken from the prompt supplied + // by the user, even if we enhance it with sample docs and schema. 
+ expect(stats.user_input_length).to.be.lessThan( + messages[2].content.length + ); }); test('schema', async function () { @@ -1423,7 +1585,7 @@ Schema: name: String } `; - const messages = await Prompts.schema.buildMessages({ + const { messages, stats } = await Prompts.schema.buildMessages({ context: chatContextStub, request: chatRequestMock, amountOfDocumentsSampled: 3, @@ -1445,6 +1607,13 @@ Schema: expect(messages[1].content).to.include(databaseName); expect(messages[1].content).to.include(collectionName); expect(messages[1].content).to.include(schema); + + expect(stats.command).to.equal('schema'); + expect(stats.has_sample_documents).to.be.false; + expect(stats.user_input_length).to.equal(chatRequestMock.prompt.length); + expect(stats.total_message_length).to.equal( + messages[0].content.length + messages[1].content.length + ); }); test('namespace', async function () { @@ -1452,7 +1621,7 @@ Schema: prompt: 'find all docs by a name example', command: 'query', }; - const messages = await Prompts.namespace.buildMessages({ + const { messages, stats } = await Prompts.namespace.buildMessages({ context: chatContextStub, request: chatRequestMock, connectionNames: [], @@ -1465,6 +1634,13 @@ Schema: expect(messages[1].role).to.equal( vscode.LanguageModelChatMessageRole.User ); + + expect(stats.command).to.equal('query'); + expect(stats.has_sample_documents).to.be.false; + expect(stats.user_input_length).to.equal(chatRequestMock.prompt.length); + expect(stats.total_message_length).to.equal( + messages[0].content.length + messages[1].content.length + ); }); test('removes askForConnect messages from history', async function () { @@ -1475,10 +1651,14 @@ Schema: command: 'query', }; + // This is the prompt of the user prior to us asking them to connect + const expectedPrompt = + 'give me the count of all people in the prod database'; + chatContextStub = { history: [ Object.assign(Object.create(vscode.ChatRequestTurn.prototype), { - prompt: 'give me the count of all people in 
the prod database', + prompt: expectedPrompt, command: 'query', references: [], participant: CHAT_PARTICIPANT_ID, @@ -1514,7 +1694,7 @@ Schema: ], }; - const messages = await Prompts.query.buildMessages({ + const { messages, stats } = await Prompts.query.buildMessages({ context: chatContextStub, request: chatRequestMock, collectionName: 'people', @@ -1534,8 +1714,18 @@ Schema: expect(messages[1].role).to.equal( vscode.LanguageModelChatMessageRole.User ); - expect(messages[1].content).to.contain( - 'give me the count of all people in the prod database' + expect(messages[1].content).to.contain(expectedPrompt); + + expect(stats.command).to.equal('query'); + expect(stats.has_sample_documents).to.be.false; + expect(stats.user_input_length).to.equal(expectedPrompt.length); + expect(stats.total_message_length).to.equal( + messages[0].content.length + messages[1].content.length + ); + + // The prompt builder may add extra info, but we're only reporting the actual user input + expect(stats.user_input_length).to.be.lessThan( + messages[1].content.length ); }); });