Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(chat): Refactor prompt hierarchy #834

Merged
merged 6 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/participant/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ export type ParticipantResponseType =
| 'askForNamespace';

interface Metadata {
intent: Exclude<ParticipantResponseType, 'askForNamespace'>;
intent: Exclude<ParticipantResponseType, 'askForNamespace' | 'docs'>;
chatId: string;
docsChatbotMessageId?: string;
}

interface AskForNamespaceMetadata {
Expand All @@ -27,8 +26,14 @@ interface AskForNamespaceMetadata {
collectionName?: string | undefined;
}

/**
 * Metadata attached to a chat result whose intent is `'docs'`.
 *
 * The literal `intent` value acts as the tag that tells this shape apart from
 * the other metadata shapes a `ChatResult` may carry.
 */
interface DocsRequestMetadata {
intent: 'docs';
chatId: string;
// Optional. Intended to be present only when the intent is 'docs'.
docsChatbotMessageId?: string;
}
Comment on lines +29 to +33
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a drive-by change: I wanted to constrain the metadata types more tightly so that `docsChatbotMessageId` can only be present when the intent is `docs`.


export interface ChatResult extends vscode.ChatResult {
readonly metadata: Metadata | AskForNamespaceMetadata;
readonly metadata: Metadata | AskForNamespaceMetadata | DocsRequestMetadata;
}

export function namespaceRequestChatResult({
Expand Down
66 changes: 32 additions & 34 deletions src/participant/participant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import type { LoadedConnection } from '../storage/connectionStorage';
import EXTENSION_COMMANDS from '../commands';
import type { StorageController } from '../storage';
import { StorageVariables } from '../storage';
import { GenericPrompt, isPromptEmpty } from './prompts/generic';
import { Prompts } from './prompts';
import type { ChatResult } from './constants';
import {
askToConnectChatResult,
Expand All @@ -22,18 +22,14 @@ import {
schemaRequestChatResult,
createCancelledRequestChatResult,
} from './constants';
import { QueryPrompt } from './prompts/query';
import { COL_NAME_ID, DB_NAME_ID, NamespacePrompt } from './prompts/namespace';
import { SchemaFormatter } from './schema';
import { getSimplifiedSampleDocuments } from './sampleDocuments';
import { getCopilotModel } from './model';
import { createMarkdownLink } from './markdown';
import { ChatMetadataStore } from './chatMetadata';
import { doesLastMessageAskForNamespace } from './prompts/history';
import {
DOCUMENTS_TO_SAMPLE_FOR_SCHEMA_PROMPT,
type OpenSchemaCommandArgs,
SchemaPrompt,
} from './prompts/schema';
import {
chatResultFeedbackKindToTelemetryValue,
Expand All @@ -58,22 +54,10 @@ export type RunParticipantQueryCommandArgs = {
runnableContent: string;
};

// Regex source strings: each matches the corresponding id label followed by
// ": " and captures the remainder of that line in group 1.
const DB_NAME_REGEX = `${DB_NAME_ID}: (.*)`;
const COL_NAME_REGEX = `${COL_NAME_ID}: (.*)`;

// The command strings accepted as a participant command.
export type ParticipantCommand = '/query' | '/schema' | '/docs';

// NOTE(review): presumably caps how many entries are rendered in a markdown
// list; the usage is not visible in this section — confirm.
const MAX_MARKDOWN_LIST_LENGTH = 10;

/**
 * Extracts a database name and a collection name from the given text.
 *
 * Each name is located with its own pattern (`DB_NAME_REGEX` and
 * `COL_NAME_REGEX`). The first capture group of the match is trimmed and
 * returned; a name whose pattern does not match comes back as `undefined`.
 *
 * @param text - The text to search.
 * @returns The names that were found, each possibly `undefined`.
 */
export function parseForDatabaseAndCollectionName(text: string): {
  databaseName?: string;
  collectionName?: string;
} {
  // Returns the trimmed first capture group, or undefined when nothing matches.
  const firstCapture = (pattern: string): string | undefined => {
    const found = text.match(pattern);
    return found ? found[1].trim() : undefined;
  };

  return {
    databaseName: firstCapture(DB_NAME_REGEX),
    collectionName: firstCapture(COL_NAME_REGEX),
  };
}

export function getRunnableContentFromString(text: string): string {
const matchedJSresponseContent = text.match(/```javascript((.|\n)*)```/);

Expand Down Expand Up @@ -237,7 +221,7 @@ export default class ParticipantController {
stream: vscode.ChatResponseStream,
token: vscode.CancellationToken
): Promise<ChatResult> {
const messages = GenericPrompt.buildMessages({
const messages = await Prompts.generic.buildMessages({
request,
context,
});
Expand Down Expand Up @@ -578,7 +562,7 @@ export default class ParticipantController {
databaseName: string | undefined;
collectionName: string | undefined;
}> {
const messagesWithNamespace = NamespacePrompt.buildMessages({
const messagesWithNamespace = await Prompts.namespace.buildMessages({
context,
request,
connectionNames: this._connectionController
Expand All @@ -589,9 +573,10 @@ export default class ParticipantController {
messages: messagesWithNamespace,
token,
});
const namespace = parseForDatabaseAndCollectionName(
responseContentWithNamespace
);
const { databaseName, collectionName } =
Prompts.namespace.extractDatabaseAndCollectionNameFromResponse(
responseContentWithNamespace
);

// See if there's a namespace set in the
// chat metadata we can fallback to if the model didn't find it.
Expand All @@ -604,8 +589,8 @@ export default class ParticipantController {
} = this._chatMetadataStore.getChatMetadata(chatId) ?? {};

return {
databaseName: namespace.databaseName ?? databaseNameFromMetadata,
collectionName: namespace.collectionName ?? collectionNameFromMetadata,
databaseName: databaseName || databaseNameFromMetadata,
collectionName: collectionName || collectionNameFromMetadata,
};
}

Expand Down Expand Up @@ -655,6 +640,19 @@ export default class ParticipantController {
});
}

/**
 * Tells whether the final entry of the chat history is a response whose
 * result metadata has the `'askForNamespace'` intent.
 *
 * @param history - The chat turns to inspect; may be empty.
 * @returns `true` only when the last turn's result metadata intent is
 *   `'askForNamespace'`; `false` otherwise, including for an empty history.
 */
_doesLastMessageAskForNamespace(
  history: ReadonlyArray<vscode.ChatRequestTurn | vscode.ChatResponseTurn>
): boolean {
  // The last turn is treated as a response turn; optional chaining below
  // keeps the lookup safe when it is missing or has no result.
  const mostRecentTurn = history[history.length - 1] as vscode.ChatResponseTurn;
  const mostRecentResult = mostRecentTurn?.result as ChatResult;

  return mostRecentResult?.metadata?.intent === 'askForNamespace';
}

_askToConnect({
command,
context,
Expand Down Expand Up @@ -786,7 +784,7 @@ export default class ParticipantController {
.history[context.history.length - 1] as vscode.ChatResponseTurn;
const lastMessage = lastMessageMetaData?.result as ChatResult;
if (lastMessage?.metadata?.intent !== 'askForNamespace') {
stream.markdown(GenericPrompt.getEmptyRequestResponse());
stream.markdown(Prompts.generic.getEmptyRequestResponse());
return emptyRequestChatResult(context.history);
}

Expand Down Expand Up @@ -841,8 +839,8 @@ export default class ParticipantController {
}

if (
isPromptEmpty(request) &&
doesLastMessageAskForNamespace(context.history)
Prompts.isPromptEmpty(request) &&
this._doesLastMessageAskForNamespace(context.history)
) {
return this.handleEmptyNamespaceMessage({
command: '/schema',
Expand Down Expand Up @@ -907,7 +905,7 @@ export default class ParticipantController {
return schemaRequestChatResult(context.history);
}

const messages = SchemaPrompt.buildMessages({
const messages = await Prompts.schema.buildMessages({
request,
context,
databaseName,
Expand Down Expand Up @@ -953,16 +951,16 @@ export default class ParticipantController {
});
}

if (isPromptEmpty(request)) {
if (doesLastMessageAskForNamespace(context.history)) {
if (Prompts.isPromptEmpty(request)) {
if (this._doesLastMessageAskForNamespace(context.history)) {
return this.handleEmptyNamespaceMessage({
command: '/query',
context,
stream,
});
}

stream.markdown(QueryPrompt.getEmptyRequestResponse());
stream.markdown(Prompts.query.emptyRequestResponse);
return emptyRequestChatResult(context.history);
}

Expand Down Expand Up @@ -1013,7 +1011,7 @@ export default class ParticipantController {
);
}

const messages = await QueryPrompt.buildMessages({
const messages = await Prompts.query.buildMessages({
request,
context,
databaseName,
Expand Down Expand Up @@ -1106,7 +1104,7 @@ export default class ParticipantController {
responseReferences?: Reference[];
}> {
const [request, context, , token] = args;
const messages = GenericPrompt.buildMessages({
const messages = await Prompts.generic.buildMessages({
request,
context,
});
Expand Down Expand Up @@ -1241,7 +1239,7 @@ Please see our [FAQ](https://www.mongodb.com/docs/generative-ai-faq/) for more i
return await this.handleSchemaRequest(...args);
default:
if (!request.prompt?.trim()) {
stream.markdown(GenericPrompt.getEmptyRequestResponse());
stream.markdown(Prompts.generic.getEmptyRequestResponse());
return emptyRequestChatResult(args[1].history);
}

Expand Down
42 changes: 8 additions & 34 deletions src/participant/prompts/generic.ts
Original file line number Diff line number Diff line change
@@ -1,52 +1,26 @@
import * as vscode from 'vscode';
import type { PromptArgsBase } from './promptBase';
import { PromptBase } from './promptBase';

import { getHistoryMessages } from './history';

export class GenericPrompt {
static getAssistantPrompt(): vscode.LanguageModelChatMessage {
const prompt = `You are a MongoDB expert.
export class GenericPrompt extends PromptBase<PromptArgsBase> {
protected getAssistantPrompt(): string {
return `You are a MongoDB expert.
Your task is to help the user craft MongoDB queries and aggregation pipelines that perform their task.
Keep your response concise.
You should suggest queries that are performant and correct.
Respond with markdown, suggest code in a Markdown code block that begins with \`\`\`javascript and ends with \`\`\`.
You can imagine the schema, collection, and database name.
Respond in MongoDB shell syntax using the \`\`\`javascript code block syntax.`;

// eslint-disable-next-line new-cap
return vscode.LanguageModelChatMessage.Assistant(prompt);
}

static getUserPrompt(prompt: string): vscode.LanguageModelChatMessage {
// eslint-disable-next-line new-cap
return vscode.LanguageModelChatMessage.User(prompt);
protected getUserPrompt(args: PromptArgsBase): Promise<string> {
return Promise.resolve(args.request.prompt);
}

static getEmptyRequestResponse(): string {
public getEmptyRequestResponse(): string {
// TODO(VSCODE-572): Generic empty response handler
return vscode.l10n.t(
'Ask anything about MongoDB, from writing queries to questions about your cluster.'
);
}

static buildMessages({
context,
request,
}: {
request: {
prompt: string;
};
context: vscode.ChatContext;
}): vscode.LanguageModelChatMessage[] {
const messages = [
GenericPrompt.getAssistantPrompt(),
...getHistoryMessages({ context }),
GenericPrompt.getUserPrompt(request.prompt),
];

return messages;
}
}

/**
 * Reports whether a chat request carries no usable prompt text.
 *
 * @param request - The chat request to inspect.
 * @returns `true` when the prompt is missing, empty, or whitespace-only.
 */
export function isPromptEmpty(request: vscode.ChatRequest): boolean {
  const { prompt } = request;
  if (!prompt) {
    return true;
  }
  return prompt.trim() === '';
}
82 changes: 0 additions & 82 deletions src/participant/prompts/history.ts

This file was deleted.

16 changes: 16 additions & 0 deletions src/participant/prompts/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { GenericPrompt } from './generic';
import type * as vscode from 'vscode';
import { NamespacePrompt } from './namespace';
import { QueryPrompt } from './query';
import { SchemaPrompt } from './schema';

/**
 * Central access point for the participant's prompt objects.
 *
 * Holds one shared instance of each prompt class as a static member, plus a
 * helper for detecting empty requests.
 */
export class Prompts {
  /** Shared `GenericPrompt` instance. */
  public static generic = new GenericPrompt();

  /** Shared `NamespacePrompt` instance. */
  public static namespace = new NamespacePrompt();

  /** Shared `QueryPrompt` instance. */
  public static query = new QueryPrompt();

  /** Shared `SchemaPrompt` instance. */
  public static schema = new SchemaPrompt();

  /**
   * Reports whether a chat request carries no usable prompt text.
   *
   * @param request - The chat request to inspect.
   * @returns `true` when the prompt is missing, empty, or whitespace-only.
   */
  public static isPromptEmpty(request: vscode.ChatRequest): boolean {
    const text = request.prompt;
    return !text || text.trim() === '';
  }
}
Loading
Loading