Skip to content

Commit

Permalink
feat: send schema to the model for better results VSCODE-581 (#792)
Browse files Browse the repository at this point in the history
* feat: send schema to the model for better results VSCODE-581

* refactor: change state in fetchCollectionSchema
  • Loading branch information
alenakhineika authored Aug 27, 2024
1 parent 16277cb commit 7a9ba2a
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 26 deletions.
96 changes: 80 additions & 16 deletions src/participant/participant.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as vscode from 'vscode';
import { getSimplifiedSchema } from 'mongodb-schema';

import { createLogger } from '../logging';
import type ConnectionController from '../connectionController';
Expand All @@ -10,6 +11,7 @@ import { GenericPrompt } from './prompts/generic';
import { CHAT_PARTICIPANT_ID, CHAT_PARTICIPANT_MODEL } from './constants';
import { QueryPrompt } from './prompts/query';
import { NamespacePrompt } from './prompts/namespace';
import { SchemaFormatter } from './schema';

const log = createLogger('participant');

Expand All @@ -18,9 +20,11 @@ export enum QUERY_GENERATION_STATE {
ASK_TO_CONNECT = 'ASK_TO_CONNECT',
ASK_FOR_DATABASE_NAME = 'ASK_FOR_DATABASE_NAME',
ASK_FOR_COLLECTION_NAME = 'ASK_FOR_COLLECTION_NAME',
READY_TO_GENERATE_QUERY = 'READY_TO_GENERATE_QUERY',
FETCH_SCHEMA = 'FETCH_SCHEMA',
}

const NUM_DOCUMENTS_TO_SAMPLE = 4;

interface ChatResult extends vscode.ChatResult {
metadata: {
responseContent?: string;
Expand Down Expand Up @@ -67,6 +71,7 @@ export default class ParticipantController {
_chatResult?: ChatResult;
_databaseName?: string;
_collectionName?: string;
_schema?: string;

constructor({
connectionController,
Expand All @@ -79,6 +84,26 @@ export default class ParticipantController {
this._storageController = storageController;
}

_setDatabaseName(name: string | undefined) {
if (
this._queryGenerationState === QUERY_GENERATION_STATE.DEFAULT &&
this._databaseName !== name
) {
this._queryGenerationState = QUERY_GENERATION_STATE.FETCH_SCHEMA;
}
this._databaseName = name;
}

_setCollectionName(name: string | undefined) {
if (
this._queryGenerationState === QUERY_GENERATION_STATE.DEFAULT &&
this._collectionName !== name
) {
this._queryGenerationState = QUERY_GENERATION_STATE.FETCH_SCHEMA;
}
this._collectionName = name;
}

createParticipant(context: vscode.ExtensionContext) {
// Chat participants appear as top-level options in the chat input
// when you type `@`, and can contribute sub-commands in the chat input
Expand Down Expand Up @@ -318,9 +343,10 @@ export default class ParticipantController {

async selectDatabaseWithParticipant(name: string): Promise<boolean> {
if (!name) {
this._databaseName = await this._selectDatabaseWithCommandPalette();
const selectedName = await this._selectDatabaseWithCommandPalette();
this._setDatabaseName(selectedName);
} else {
this._databaseName = name;
this._setDatabaseName(name);
}

return vscode.commands.executeCommand('workbench.action.chat.open', {
Expand Down Expand Up @@ -363,9 +389,10 @@ export default class ParticipantController {

async selectCollectionWithParticipant(name: string): Promise<boolean> {
if (!name) {
this._collectionName = await this._selectCollectionWithCommandPalette();
const selectedName = await this._selectCollectionWithCommandPalette();
this._setCollectionName(selectedName);
} else {
this._collectionName = name;
this._setCollectionName(name);
}

return vscode.commands.executeCommand('workbench.action.chat.open', {
Expand Down Expand Up @@ -448,8 +475,8 @@ export default class ParticipantController {
if (isNewChat) {
this._queryGenerationState = QUERY_GENERATION_STATE.DEFAULT;
this._chatResult = undefined;
this._databaseName = undefined;
this._collectionName = undefined;
this._setDatabaseName(undefined);
this._setCollectionName(undefined);
}
}

Expand All @@ -468,7 +495,7 @@ export default class ParticipantController {
this._queryGenerationState ===
QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME
) {
this._databaseName = prompt;
this._setDatabaseName(prompt);
if (!this._collectionName) {
this._queryGenerationState =
QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME;
Expand All @@ -481,14 +508,13 @@ export default class ParticipantController {
this._queryGenerationState ===
QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME
) {
this._collectionName = prompt;
this._setCollectionName(prompt);
if (!this._databaseName) {
this._queryGenerationState =
QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME;
return true;
}
this._queryGenerationState =
QUERY_GENERATION_STATE.READY_TO_GENERATE_QUERY;
this._queryGenerationState = QUERY_GENERATION_STATE.FETCH_SCHEMA;
return false;
}

Expand Down Expand Up @@ -522,12 +548,11 @@ export default class ParticipantController {
responseContentWithNamespace
);

this._databaseName = namespace.databaseName || this._databaseName;
this._collectionName = namespace.collectionName || this._collectionName;
this._setDatabaseName(namespace.databaseName || this._databaseName);
this._setCollectionName(namespace.collectionName || this._collectionName);

if (namespace.databaseName && namespace.collectionName) {
this._queryGenerationState =
QUERY_GENERATION_STATE.READY_TO_GENERATE_QUERY;
this._queryGenerationState = QUERY_GENERATION_STATE.FETCH_SCHEMA;
return false;
}

Expand Down Expand Up @@ -590,6 +615,41 @@ export default class ParticipantController {
return true;
}

_shouldFetchCollectionSchema(): boolean {
return this._queryGenerationState === QUERY_GENERATION_STATE.FETCH_SCHEMA;
}

async _fetchCollectionSchema(abortSignal?: AbortSignal): Promise<undefined> {
if (this._queryGenerationState === QUERY_GENERATION_STATE.FETCH_SCHEMA) {
this._queryGenerationState = QUERY_GENERATION_STATE.DEFAULT;
}

const dataService = this._connectionController.getActiveDataService();
if (!dataService || !this._databaseName || !this._collectionName) {
return;
}

try {
const sampleDocuments =
(await dataService?.sample?.(
`${this._databaseName}.${this._collectionName}`,
{
query: {},
size: NUM_DOCUMENTS_TO_SAMPLE,
},
{ promoteValues: false },
{
abortSignal,
}
)) || [];

const schema = await getSimplifiedSchema(sampleDocuments);
this._schema = new SchemaFormatter().format(schema);
} catch (err: any) {
this._schema = undefined;
}
}

// @MongoDB /query find all documents where the "address" has the word Broadway in it.
async handleQueryRequest(
request: vscode.ChatRequest,
Expand Down Expand Up @@ -621,13 +681,17 @@ export default class ParticipantController {
abortController.abort();
});

if (this._shouldFetchCollectionSchema()) {
await this._fetchCollectionSchema(abortController.signal);
}

const messages = QueryPrompt.buildMessages({
request,
context,
databaseName: this._databaseName,
collectionName: this._collectionName,
schema: this._schema,
});

const responseContent = await this.getChatResponseContent({
messages,
stream,
Expand Down
12 changes: 11 additions & 1 deletion src/participant/prompts/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ export class QueryPrompt {
static getAssistantPrompt({
databaseName = 'mongodbVSCodeCopilotDB',
collectionName = 'test',
schema,
}: {
databaseName?: string;
collectionName?: string;
schema?: string;
}): vscode.LanguageModelChatMessage {
const prompt = `You are a MongoDB expert.
Expand Down Expand Up @@ -38,6 +40,12 @@ db.getCollection('').find({
Database name: ${databaseName}
Collection name: ${collectionName}
${
schema
? `Collection schema:
${schema}`
: ''
}
MongoDB command to specify database:
use('');
Expand All @@ -61,16 +69,18 @@ Concisely explain the code snippet you have generated.`;
request,
databaseName,
collectionName,
schema,
}: {
request: {
prompt: string;
};
context: vscode.ChatContext;
databaseName?: string;
collectionName?: string;
schema?: string;
}): vscode.LanguageModelChatMessage[] {
const messages = [
QueryPrompt.getAssistantPrompt({ databaseName, collectionName }),
QueryPrompt.getAssistantPrompt({ databaseName, collectionName, schema }),
...getHistoryMessages({ context }),
QueryPrompt.getUserPrompt(request.prompt),
];
Expand Down
102 changes: 102 additions & 0 deletions src/participant/schema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import type {
SimplifiedSchema,
SimplifiedSchemaArrayType,
SimplifiedSchemaDocumentType,
SimplifiedSchemaType,
} from 'mongodb-schema';

const PROPERTY_REGEX = '^[a-zA-Z_$][0-9a-zA-Z_$]*$';

export class SchemaFormatter {
static getSchemaFromTypes(pInput: SimplifiedSchema): string {
return new SchemaFormatter().format(pInput);
}

schemaString = '';

format(pInitial: SimplifiedSchema): string {
this.processDocumentType('', pInitial);
return this.schemaString;
}

private processSchemaTypeList(
prefix: string,
pTypes: SimplifiedSchemaType[]
) {
if (pTypes.length !== 0) {
this.processSchemaType(prefix, pTypes[0]);
}
}

private processSchemaType(prefix: string, pType: SimplifiedSchemaType) {
const bsonType = pType.bsonType;
if (bsonType === 'Document') {
const fields = (pType as SimplifiedSchemaDocumentType).fields;

if (Object.keys(fields).length === 0) {
this.addToFormattedSchemaString(prefix + ': Document');
return;
}

this.processDocumentType(prefix, fields);
return;
}

if (bsonType === 'Array') {
const types = (pType as SimplifiedSchemaArrayType).types;

if (types.length === 0) {
this.addToFormattedSchemaString(prefix + ': ' + 'Array');
return;
}

const firstType = types[0].bsonType;
if (firstType !== 'Array' && firstType !== 'Document') {
this.addToFormattedSchemaString(
prefix + ': ' + 'Array<' + firstType + '>'
);
return;
}

// Array of documents or arrays.
// We only use the first type.
this.processSchemaType(prefix + '[]', types[0]);
return;
}

this.addToFormattedSchemaString(prefix + ': ' + bsonType);
}

private processDocumentType(prefix: string, pDoc: SimplifiedSchema) {
if (!pDoc) {
return;
}

Object.keys(pDoc).forEach((key) => {
const keyAsString = this.getPropAsString(key);
this.processSchemaTypeList(
prefix + (prefix.length === 0 ? '' : '.') + keyAsString,
pDoc[key]?.types
);
});
}

getPropAsString(pProp: string): string {
if (pProp.match(PROPERTY_REGEX)) {
return pProp;
}

try {
return JSON.stringify(pProp);
} catch (e) {
return pProp;
}
}

addToFormattedSchemaString(fieldAndType: string) {
if (this.schemaString.length > 0) {
this.schemaString += '\n';
}
this.schemaString += fieldAndType;
}
}
Loading

0 comments on commit 7a9ba2a

Please sign in to comment.