From e18573c7e767c43f7bffbc7e6a7b25bc7f473a3a Mon Sep 17 00:00:00 2001 From: Nikola Irinchev Date: Thu, 24 Oct 2024 10:06:18 +0200 Subject: [PATCH 1/2] Adapt message content access to latest vscode API --- .vscode/launch.json | 16 +- src/participant/participant.ts | 19 ++- src/participant/prompts/index.ts | 9 +- src/participant/prompts/promptBase.ts | 48 +++++- .../suite/participant/participant.test.ts | 138 ++++++++++++------ 5 files changed, 173 insertions(+), 57 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index e7f6ca212..dd4f401e3 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -40,7 +40,21 @@ "outFiles": [ "${workspaceFolder}/dist/**/*.js" ] - } + }, + { + "name": "Run Tests", + "type": "extensionHost", + "request": "launch", + "runtimeExecutable": "${execPath}", + "args": [ + "${workspaceFolder}/out/test/suite", // TODO: remove suite + "--disable-extensions", + "--extensionDevelopmentPath=${workspaceFolder}", + "--extensionTestsPath=${workspaceFolder}/out/test/suite" + ], + "outFiles": ["${workspaceFolder}/out/**/*.js"], + "preLaunchTask": "npm: compile:extension", + } ], "compounds": [ { diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 07c725ca0..b0530ef90 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -189,7 +189,7 @@ export default class ParticipantController { (message: vscode.LanguageModelChatMessage) => util.inspect({ role: message.role, - contentLength: message.content.length, + contentLength: Prompts.getContentLength(message), }) ), }); @@ -790,15 +790,18 @@ export default class ParticipantController { // it currently errors (not on insiders, only main VSCode). // Here we're defaulting to have some content as a workaround. // TODO: Remove this when the issue is fixed. - messagesWithNamespace.messages[ - messagesWithNamespace.messages.length - 1 - // eslint-disable-next-line new-cap - ] = vscode.LanguageModelChatMessage.User( + if ( + !Prompts.doMessagesContainUserInput([ + messagesWithNamespace.messages[ + messagesWithNamespace.messages.length - 1 + ], + ]) + ) { messagesWithNamespace.messages[ messagesWithNamespace.messages.length - 1 - ].content.trim() || 'see previous messages' - ); - + // eslint-disable-next-line new-cap + ] = vscode.LanguageModelChatMessage.User('see previous messages'); + } const responseContentWithNamespace = await this.getChatResponseContent({ modelInput: messagesWithNamespace, token, diff --git a/src/participant/prompts/index.ts b/src/participant/prompts/index.ts index 18e4150af..91d6bd81d 100644 --- a/src/participant/prompts/index.ts +++ b/src/participant/prompts/index.ts @@ -5,6 +5,7 @@ import { IntentPrompt } from './intent'; import { NamespacePrompt } from './namespace'; import { QueryPrompt } from './query'; import { SchemaPrompt } from './schema'; +import { isContentEmpty, getContentLength } from './promptBase'; export class Prompts { public static generic = new GenericPrompt(); @@ -26,7 +27,7 @@ export class Prompts { for (const message of messages) { if ( message.role === vscode.LanguageModelChatMessageRole.User && - message.content.trim().length > 0 + !isContentEmpty(message) ) { return true; } @@ -34,4 +35,10 @@ export class Prompts { return false; } + + public static getContentLength( + message: vscode.LanguageModelChatMessage + ): number { + return getContentLength(message); + } } diff --git a/src/participant/prompts/promptBase.ts b/src/participant/prompts/promptBase.ts index 949b4f3d0..38a8ae672 100644 --- a/src/participant/prompts/promptBase.ts +++ b/src/participant/prompts/promptBase.ts @@ -24,6 +24,52 @@ export interface ModelInput { stats: ParticipantPromptProperties; } +export function getContentLength( + message: vscode.LanguageModelChatMessage +): number { + const content = message.content as any; + if (typeof content === 'string') { + return content.trim().length; + } + + // TODO: https://github.com/microsoft/vscode/pull/231788 made it so message.content is no longer a string, + // but an array of things that a message can contain. This will eventually be reflected in the type definitions + // but until then, we're manually checking the array contents to ensure we don't break when this PR gets released + // in the stable channel. + if (Array.isArray(content)) { + return content.reduce((acc: number, element) => { + const value = element?.value ?? element?.content?.value; + if (typeof value === 'string') { + return acc + value.length; + } + + return acc; + }, 0); + } + + return 0; +} + +export function isContentEmpty( + message: vscode.LanguageModelChatMessage +): boolean { + const content = message.content as any; + if (typeof content === 'string') { + return content.trim().length === 0; + } + + if (Array.isArray(content)) { + for (const element of content) { + const value = element?.value ?? element?.content?.value; + if (typeof value === 'string' && value.trim().length > 0) { + return false; + } + } + } + + return true; +} + export abstract class PromptBase { protected abstract getAssistantPrompt(args: TArgs): string; @@ -92,7 +138,7 @@ export abstract class PromptBase { ): ParticipantPromptProperties { return { total_message_length: messages.reduce( - (acc, message) => acc + message.content.length, + (acc, message) => acc + getContentLength(message), 0 ), user_input_length: request.prompt.length, diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 38962a84b..094ff3681 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -31,6 +31,7 @@ import { ChatMetadataStore } from '../../../participant/chatMetadata'; import { Prompts } from '../../../participant/prompts'; import { createMarkdownLink } from '../../../participant/markdown'; import EXTENSION_COMMANDS from '../../../commands'; +import { getContentLength } from '../../../participant/prompts/promptBase'; // The Copilot's model in not available in tests, // therefore we need to mock its methods and returning values. @@ -50,6 +51,28 @@ const encodeStringify = (obj: Record): string => { return encodeURIComponent(JSON.stringify(obj)); }; +const getMessageContent = ( + message: vscode.LanguageModelChatMessage +): string => { + const content = message.content as any; + if (typeof content === 'string') { + return content; + } + + if (Array.isArray(content)) { + return content.reduce((agg: string, element) => { + const value = element?.value ?? element?.content?.value; + if (typeof value === 'string') { + return agg + value; + } + + return agg; + }, ''); + } + + return ''; +}; + suite('Participant Controller Test Suite', function () { const extensionContextStub = new ExtensionContextStub(); @@ -514,20 +537,22 @@ suite('Participant Controller Test Suite', function () { const res = await invokeChatHandler(chatRequestMock); expect(sendRequestStub).to.have.been.calledTwice; - const intentRequest = sendRequestStub.firstCall.args[0]; + const intentRequest = sendRequestStub.firstCall + .args[0] as vscode.LanguageModelChatMessage[]; expect(intentRequest).to.have.length(2); - expect(intentRequest[0].content).to.include( + expect(getMessageContent(intentRequest[0])).to.include( 'Your task is to help guide a conversation with a user to the correct handler.' ); - expect(intentRequest[1].content).to.equal( + expect(getMessageContent(intentRequest[1])).to.equal( 'what is the shape of the documents in the pineapple collection?' ); - const genericRequest = sendRequestStub.secondCall.args[0]; + const genericRequest = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; expect(genericRequest).to.have.length(2); - expect(genericRequest[0].content).to.include( + expect(getMessageContent(genericRequest[0])).to.include( 'Parse all user messages to find a database name and a collection name.' ); - expect(genericRequest[1].content).to.equal( + expect(getMessageContent(genericRequest[1])).to.equal( 'what is the shape of the documents in the pineapple collection?' ); @@ -544,20 +569,22 @@ suite('Participant Controller Test Suite', function () { const res = await invokeChatHandler(chatRequestMock); expect(sendRequestStub).to.have.been.calledTwice; - const intentRequest = sendRequestStub.firstCall.args[0]; + const intentRequest = sendRequestStub.firstCall + .args[0] as vscode.LanguageModelChatMessage[]; expect(intentRequest).to.have.length(2); - expect(intentRequest[0].content).to.include( + expect(getMessageContent(intentRequest[0])).to.include( 'Your task is to help guide a conversation with a user to the correct handler.' ); - expect(intentRequest[1].content).to.equal( + expect(getMessageContent(intentRequest[1])).to.equal( 'how to find documents in my collection?' ); - const genericRequest = sendRequestStub.secondCall.args[0]; + const genericRequest = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; expect(genericRequest).to.have.length(2); - expect(genericRequest[0].content).to.include( + expect(getMessageContent(genericRequest[0])).to.include( 'Your task is to help the user with MongoDB related questions.' ); - expect(genericRequest[1].content).to.equal( + expect(getMessageContent(genericRequest[1])).to.equal( 'how to find documents in my collection?' ); @@ -648,8 +675,9 @@ suite('Participant Controller Test Suite', function () { references: [], }; await invokeChatHandler(chatRequestMock); - const messages = sendRequestStub.secondCall.args[0]; - expect(messages[1].content).to.include( + const messages = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[1])).to.include( 'Collection schema: _id: ObjectId\n' + 'field.stringField: String\n' + 'field.arrayField: Array\n' @@ -712,8 +740,9 @@ suite('Participant Controller Test Suite', function () { references: [], }; await invokeChatHandler(chatRequestMock); - const messages = sendRequestStub.secondCall.args[0]; - expect(messages[1].content).to.include( + const messages = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[1])).to.include( 'Sample documents: [\n' + ' {\n' + " _id: ObjectId('63ed1d522d8573fa5c203661'),\n" + @@ -781,8 +810,9 @@ suite('Participant Controller Test Suite', function () { references: [], }; await invokeChatHandler(chatRequestMock); - const messages = sendRequestStub.secondCall.args[0]; - expect(messages[1].content).to.include( + const messages = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[1])).to.include( 'Sample document: {\n' + " _id: ObjectId('63ed1d522d8573fa5c203660'),\n" + ' field: {\n' + @@ -844,8 +874,9 @@ suite('Participant Controller Test Suite', function () { references: [], }; await invokeChatHandler(chatRequestMock); - const messages = sendRequestStub.secondCall.args[0]; - expect(messages[1].content).to.include( + const messages = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[1])).to.include( 'Sample document: {\n' + " _id: ObjectId('63ed1d522d8573fa5c203661'),\n" + ' field: {\n' + @@ -904,8 +935,11 @@ suite('Participant Controller Test Suite', function () { references: [], }; await invokeChatHandler(chatRequestMock); - const messages = sendRequestStub.secondCall.args[0]; - expect(messages[1].content).to.not.include('Sample documents'); + const messages = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[1])).to.not.include( + 'Sample documents' + ); assertCommandTelemetry('query', chatRequestMock, { callIndex: 0, @@ -932,8 +966,11 @@ suite('Participant Controller Test Suite', function () { references: [], }; await invokeChatHandler(chatRequestMock); - const messages = sendRequestStub.secondCall.args[0]; - expect(messages[1].content).to.not.include('Sample documents'); + const messages = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[1])).to.not.include( + 'Sample documents' + ); assertCommandTelemetry('query', chatRequestMock, { callIndex: 0, @@ -1374,7 +1411,9 @@ suite('Participant Controller Test Suite', function () { await invokeChatHandler(chatRequestMock); expect(sendRequestStub.calledOnce).to.be.true; - expect(sendRequestStub.firstCall.args[0][0].content).to.include( + const messages = sendRequestStub.firstCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[0])).to.include( 'Parse all user messages to find a database name and a collection name.' ); @@ -1422,10 +1461,13 @@ suite('Participant Controller Test Suite', function () { await invokeChatHandler(chatRequestMock); expect(sendRequestStub.calledOnce).to.be.true; - expect(sendRequestStub.firstCall.args[0][0].content).to.include( + + const messages = sendRequestStub.firstCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[0])).to.include( 'Parse all user messages to find a database name and a collection name.' ); - expect(sendRequestStub.firstCall.args[0][3].content).to.include( + expect(getMessageContent(messages[3])).to.include( 'see previous messages' ); }); @@ -1528,11 +1570,12 @@ suite('Participant Controller Test Suite', function () { references: [], }; await invokeChatHandler(chatRequestMock); - const messages = sendRequestStub.secondCall.args[0]; - expect(messages[0].content).to.include( + const messages = sendRequestStub.secondCall + .args[0] as vscode.LanguageModelChatMessage[]; + expect(getMessageContent(messages[0])).to.include( 'Amount of documents sampled: 2' ); - expect(messages[1].content).to.include( + expect(getMessageContent(messages[1])).to.include( `Database name: dbOne Collection name: collOne Schema: @@ -1540,7 +1583,8 @@ Schema: "count": 2, "fields": [` ); - expect(messages[1].content).to.include(`"name": "arrayField", + expect(getMessageContent(messages[1])).to + .include(`"name": "arrayField", "path": [ "field", "arrayField" @@ -1703,7 +1747,7 @@ Schema: expect(stats.has_sample_documents).to.be.false; expect(stats.user_input_length).to.equal(chatRequestMock.prompt.length); expect(stats.total_message_length).to.equal( - messages[0].content.length + messages[1].content.length + getContentLength(messages[0]) + getContentLength(messages[1]) ); }); @@ -1759,7 +1803,7 @@ Schema: expect(messages[1].role).to.equal( vscode.LanguageModelChatMessageRole.User ); - expect(messages[1].content).to.equal( + expect(getMessageContent(messages[1])).to.equal( 'give me the count of all people in the prod database' ); @@ -1772,15 +1816,15 @@ Schema: expect(stats.has_sample_documents).to.be.true; expect(stats.user_input_length).to.equal(chatRequestMock.prompt.length); expect(stats.total_message_length).to.equal( - messages[0].content.length + - messages[1].content.length + - messages[2].content.length + getContentLength(messages[0]) + + getContentLength(messages[1]) + + getContentLength(messages[2]) ); // The length of the user prompt length should be taken from the prompt supplied // by the user, even if we enhance it with sample docs and schema. expect(stats.user_input_length).to.be.lessThan( - messages[2].content.length + getContentLength(messages[2]) ); }); @@ -1812,20 +1856,22 @@ Schema: expect(messages[0].role).to.equal( vscode.LanguageModelChatMessageRole.Assistant ); - expect(messages[0].content).to.include('Amount of documents sampled: 3'); + expect(getMessageContent(messages[0])).to.include( + 'Amount of documents sampled: 3' + ); expect(messages[1].role).to.equal( vscode.LanguageModelChatMessageRole.User ); - expect(messages[1].content).to.include(databaseName); - expect(messages[1].content).to.include(collectionName); - expect(messages[1].content).to.include(schema); + expect(getMessageContent(messages[1])).to.include(databaseName); + expect(getMessageContent(messages[1])).to.include(collectionName); + expect(getMessageContent(messages[1])).to.include(schema); expect(stats.command).to.equal('schema'); expect(stats.has_sample_documents).to.be.false; expect(stats.user_input_length).to.equal(chatRequestMock.prompt.length); expect(stats.total_message_length).to.equal( - messages[0].content.length + messages[1].content.length + getContentLength(messages[0]) + getContentLength(messages[1]) ); }); @@ -1852,7 +1898,7 @@ Schema: expect(stats.has_sample_documents).to.be.false; expect(stats.user_input_length).to.equal(chatRequestMock.prompt.length); expect(stats.total_message_length).to.equal( - messages[0].content.length + messages[1].content.length + getContentLength(messages[0]) + getContentLength(messages[1]) ); }); @@ -1927,18 +1973,18 @@ Schema: expect(messages[1].role).to.equal( vscode.LanguageModelChatMessageRole.User ); - expect(messages[1].content).to.contain(expectedPrompt); + expect(getMessageContent(messages[1])).to.contain(expectedPrompt); expect(stats.command).to.equal('query'); expect(stats.has_sample_documents).to.be.false; expect(stats.user_input_length).to.equal(expectedPrompt.length); expect(stats.total_message_length).to.equal( - messages[0].content.length + messages[1].content.length + getContentLength(messages[0]) + getContentLength(messages[1]) ); // The prompt builder may add extra info, but we're only reporting the actual user input expect(stats.user_input_length).to.be.lessThan( - messages[1].content.length + getContentLength(messages[1]) ); }); }); From 3a3c429ddc128cc287e830618625f938f63f94a7 Mon Sep 17 00:00:00 2001 From: Nikola Irinchev Date: Thu, 24 Oct 2024 11:07:31 +0200 Subject: [PATCH 2/2] Address CR comments --- .vscode/launch.json | 26 +++++++++++++------------- src/participant/participant.ts | 4 ++-- src/participant/prompts/index.ts | 10 +++------- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index dd4f401e3..9d596e5fd 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -42,19 +42,19 @@ ] }, { - "name": "Run Tests", - "type": "extensionHost", - "request": "launch", - "runtimeExecutable": "${execPath}", - "args": [ - "${workspaceFolder}/out/test/suite", // TODO: remove suite - "--disable-extensions", - "--extensionDevelopmentPath=${workspaceFolder}", - "--extensionTestsPath=${workspaceFolder}/out/test/suite" - ], - "outFiles": ["${workspaceFolder}/out/**/*.js"], - "preLaunchTask": "npm: compile:extension", - } + "name": "Run Tests", + "type": "extensionHost", + "request": "launch", + "runtimeExecutable": "${execPath}", + "args": [ + "${workspaceFolder}/out/test/suite", // TODO: VSCODE-641 - remove suite + "--disable-extensions", + "--extensionDevelopmentPath=${workspaceFolder}", + "--extensionTestsPath=${workspaceFolder}/out/test/suite" + ], + "outFiles": ["${workspaceFolder}/out/**/*.js"], + "preLaunchTask": "npm: compile:extension", + } ], "compounds": [ { diff --git a/src/participant/participant.ts b/src/participant/participant.ts index b0530ef90..5dbf97b31 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -10,7 +10,7 @@ import type { LoadedConnection } from '../storage/connectionStorage'; import EXTENSION_COMMANDS from '../commands'; import type { StorageController } from '../storage'; import { StorageVariables } from '../storage'; -import { Prompts } from './prompts'; +import { getContentLength, Prompts } from './prompts'; import type { ChatResult } from './constants'; import { askToConnectChatResult, @@ -189,7 +189,7 @@ export default class ParticipantController { (message: vscode.LanguageModelChatMessage) => util.inspect({ role: message.role, - contentLength: Prompts.getContentLength(message), + contentLength: getContentLength(message), }) ), }); diff --git a/src/participant/prompts/index.ts b/src/participant/prompts/index.ts index 91d6bd81d..5324f2847 100644 --- a/src/participant/prompts/index.ts +++ b/src/participant/prompts/index.ts @@ -5,7 +5,9 @@ import { IntentPrompt } from './intent'; import { NamespacePrompt } from './namespace'; import { QueryPrompt } from './query'; import { SchemaPrompt } from './schema'; -import { isContentEmpty, getContentLength } from './promptBase'; +import { isContentEmpty } from './promptBase'; + +export { getContentLength } from './promptBase'; export class Prompts { public static generic = new GenericPrompt(); @@ -35,10 +37,4 @@ export class Prompts { return false; } - - public static getContentLength( - message: vscode.LanguageModelChatMessage - ): number { - return getContentLength(message); - } }