Skip to content

Commit

Permalink
feat: send sample documents to the model for better results VSCODE-580 (
Browse files Browse the repository at this point in the history
#806)

* feat: send sample documents to the model for better results VSCODE-580

* refactor: address pr comments

* refactor: count tokens
  • Loading branch information
alenakhineika authored Sep 6, 2024
1 parent 15c3e74 commit b54b22f
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 80 deletions.
6 changes: 6 additions & 0 deletions .eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ module.exports = {
'error',
{ prefer: 'type-imports' },
],
'@typescript-eslint/explicit-function-return-type': [
'warn',
{
allowHigherOrderFunctions: true,
},
],
},
parserOptions: {
project: ['./tsconfig.json'], // Specify it only for TypeScript files.
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ Connect to Atlas Stream Processing instances and develop stream processors using
| `mdb.defaultLimit` | The number of documents to fetch when viewing documents from a collection. | `10` |
| `mdb.confirmRunAll` | Show a confirmation message before running commands in a playground. | `true` |
| `mdb.confirmRunCopilotCode` | Show a confirmation message before running code generated by the MongoDB participant. | `true` |
| `mdb.useSampleDocsInCopilot` | Enable sending sample field values with the VS Code Copilot chat `@MongoDB` participant `/query` command. | `false` |
| `mdb.confirmDeleteDocument` | Show a confirmation message before deleting a document in the tree view. | `true` |
| `mdb.persistOIDCTokens` | Remain logged in when using the MONGODB-OIDC authentication mechanism for MongoDB server connection. Access tokens are encrypted using the system keychain before being stored. | `true` |
| `mdb.showOIDCDeviceAuthFlow` | Opt-in and opt-out for diagnostic and telemetry collection. | `true` |
Expand Down
5 changes: 5 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,11 @@
"default": true,
"description": "Show a confirmation message before running code generated by the MongoDB participant."
},
"mdb.useSampleDocsInCopilot": {
"type": "boolean",
"default": false,
"description": "Enable sending sample field values with the VSCode copilot chat @MongoDB participant /query command."
},
"mdb.confirmDeleteDocument": {
"type": "boolean",
"default": true,
Expand Down
4 changes: 2 additions & 2 deletions src/editors/editDocumentCodeLensProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export default class EditDocumentCodeLensProvider
content: Document;
namespace: string | null;
uri: vscode.Uri;
}) {
}): void {
let resultCodeLensesInfo: EditDocumentInfo[] = [];

resultCodeLensesInfo = this._updateCodeLensesForCursor({
Expand All @@ -44,7 +44,7 @@ export default class EditDocumentCodeLensProvider
this._codeLensesInfo[data.uri.toString()] = resultCodeLensesInfo;
}

updateCodeLensesForPlayground(playgroundResult: PlaygroundResult) {
updateCodeLensesForPlayground(playgroundResult: PlaygroundResult): void {
const source = DocumentSource.DOCUMENT_SOURCE_PLAYGROUND;
let resultCodeLensesInfo: EditDocumentInfo[] = [];

Expand Down
14 changes: 4 additions & 10 deletions src/editors/playgroundController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ export default class PlaygroundController {

await this._openInResultPane(evaluateResponse.result);

return Promise.resolve(true);
return true;
}

async _evaluatePlayground(text: string): Promise<boolean> {
Expand Down Expand Up @@ -684,17 +684,11 @@ export default class PlaygroundController {
return Promise.resolve(false);
}

const selections = this._activeTextEditor.selections;

let codeToEvaluate;
if (
!selections ||
!Array.isArray(selections) ||
(selections.length === 1 && this._getSelectedText(selections[0]) === '')
) {
let codeToEvaluate = '';
if (!this._selectedText) {
this._isPartialRun = false;
codeToEvaluate = this._getAllText();
} else if (this._selectedText) {
} else {
this._isPartialRun = true;
codeToEvaluate = this._selectedText;
}
Expand Down
15 changes: 9 additions & 6 deletions src/mdbExtensionController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ export default class MDBExtensionController implements vscode.Disposable {
this._helpExplorer.activateHelpTreeView(this._telemetryService);
this._playgroundsExplorer.activatePlaygroundsTreeView();
this._telemetryService.activateSegmentAnalytics();
this._participantController.createParticipant(this._context);

await this._connectionController.loadSavedConnections();
await this._languageServerController.startLanguageServer();
Expand Down Expand Up @@ -332,11 +333,13 @@ export default class MDBExtensionController implements vscode.Disposable {

return commandHandler(args);
};

this._context.subscriptions.push(
this._participantController.getParticipant(this._context),
vscode.commands.registerCommand(command, commandHandlerWithTelemetry)
);
const participant = this._participantController.getParticipant();
if (participant) {
this._context.subscriptions.push(
participant,
vscode.commands.registerCommand(command, commandHandlerWithTelemetry)
);
}
};

registerCommand = (
Expand Down Expand Up @@ -778,7 +781,7 @@ export default class MDBExtensionController implements vscode.Disposable {
this.registerAtlasStreamsTreeViewCommands();
}

registerAtlasStreamsTreeViewCommands() {
registerAtlasStreamsTreeViewCommands(): void {
this.registerCommand(
EXTENSION_COMMANDS.MDB_ADD_STREAM_PROCESSOR,
async (element: ConnectionTreeItem): Promise<boolean> => {
Expand Down
23 changes: 23 additions & 0 deletions src/participant/model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import * as vscode from 'vscode';

import { CHAT_PARTICIPANT_MODEL } from './constants';

/**
 * Cached Copilot chat model. Stays `undefined` until a lookup succeeds, so
 * later calls can skip the model selection round trip.
 */
let model: vscode.LanguageModelChat | undefined;

/**
 * Returns the Copilot chat model for the `CHAT_PARTICIPANT_MODEL` family.
 *
 * The first successful lookup is stored in the module-level cache and reused
 * by subsequent calls. When the lookup throws or yields no model, nothing is
 * cached and the next call tries again.
 *
 * @returns The selected chat model, or `undefined` when none is available.
 */
export async function getCopilotModel(): Promise<
  vscode.LanguageModelChat | undefined
> {
  if (!model) {
    try {
      // Use a distinct local name: destructuring straight into `model` would
      // shadow the module-level cache and leave it permanently empty.
      const [selectedModel] = await vscode.lm.selectChatModels({
        vendor: 'copilot',
        family: CHAT_PARTICIPANT_MODEL,
      });
      model = selectedModel;
    } catch (err) {
      // Model is not ready yet. It is being initialised with the first user prompt.
    }
  }

  return model;
}
64 changes: 42 additions & 22 deletions src/participant/participant.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import * as vscode from 'vscode';
import { getSimplifiedSchema } from 'mongodb-schema';
import type { Document } from 'bson';

import { createLogger } from '../logging';
import type ConnectionController from '../connectionController';
Expand All @@ -8,10 +9,12 @@ import EXTENSION_COMMANDS from '../commands';
import type { StorageController } from '../storage';
import { StorageVariables } from '../storage';
import { GenericPrompt } from './prompts/generic';
import { CHAT_PARTICIPANT_ID, CHAT_PARTICIPANT_MODEL } from './constants';
import { CHAT_PARTICIPANT_ID } from './constants';
import { QueryPrompt } from './prompts/query';
import { COL_NAME_ID, DB_NAME_ID, NamespacePrompt } from './prompts/namespace';
import { SchemaFormatter } from './schema';
import { getSimplifiedSampleDocuments } from './sampleDocuments';
import { getCopilotModel } from './model';

const log = createLogger('participant');

Expand All @@ -20,10 +23,11 @@ export enum QUERY_GENERATION_STATE {
ASK_TO_CONNECT = 'ASK_TO_CONNECT',
ASK_FOR_DATABASE_NAME = 'ASK_FOR_DATABASE_NAME',
ASK_FOR_COLLECTION_NAME = 'ASK_FOR_COLLECTION_NAME',
CHANGE_DATABASE_NAME = 'CHANGE_DATABASE_NAME',
FETCH_SCHEMA = 'FETCH_SCHEMA',
}

const NUM_DOCUMENTS_TO_SAMPLE = 4;
const NUM_DOCUMENTS_TO_SAMPLE = 3;

interface ChatResult extends vscode.ChatResult {
metadata: {
Expand All @@ -50,7 +54,7 @@ export function parseForDatabaseAndCollectionName(text: string): {
return { databaseName, collectionName };
}

export function getRunnableContentFromString(text: string) {
export function getRunnableContentFromString(text: string): string {
const matchedJSresponseContent = text.match(/```javascript((.|\n)*)```/);

const code =
Expand All @@ -69,6 +73,7 @@ export default class ParticipantController {
_databaseName?: string;
_collectionName?: string;
_schema?: string;
_sampleDocuments?: Document[];

constructor({
connectionController,
Expand All @@ -81,17 +86,18 @@ export default class ParticipantController {
this._storageController = storageController;
}

_setDatabaseName(name: string | undefined) {
_setDatabaseName(name: string | undefined): void {
if (
this._queryGenerationState === QUERY_GENERATION_STATE.DEFAULT &&
this._databaseName !== name
) {
this._queryGenerationState = QUERY_GENERATION_STATE.FETCH_SCHEMA;
this._queryGenerationState = QUERY_GENERATION_STATE.CHANGE_DATABASE_NAME;
this._collectionName = undefined;
}
this._databaseName = name;
}

_setCollectionName(name: string | undefined) {
_setCollectionName(name: string | undefined): void {
if (
this._queryGenerationState === QUERY_GENERATION_STATE.DEFAULT &&
this._collectionName !== name
Expand All @@ -101,7 +107,7 @@ export default class ParticipantController {
this._collectionName = name;
}

createParticipant(context: vscode.ExtensionContext) {
createParticipant(context: vscode.ExtensionContext): vscode.ChatParticipant {
// Chat participants appear as top-level options in the chat input
// when you type `@`, and can contribute sub-commands in the chat input
// that appear when you type `/`.
Expand All @@ -120,8 +126,8 @@ export default class ParticipantController {
return this._participant;
}

getParticipant(context: vscode.ExtensionContext) {
return this._participant || this.createParticipant(context);
getParticipant(): vscode.ChatParticipant | undefined {
return this._participant;
}

async handleEmptyQueryRequest(): Promise<(string | vscode.MarkdownString)[]> {
Expand Down Expand Up @@ -193,20 +199,17 @@ export default class ParticipantController {
stream: vscode.ChatResponseStream;
token: vscode.CancellationToken;
}): Promise<string> {
const model = await getCopilotModel();
let responseContent = '';
try {
const [model] = await vscode.lm.selectChatModels({
vendor: 'copilot',
family: CHAT_PARTICIPANT_MODEL,
});
if (model) {
if (model) {
try {
const chatResponse = await model.sendRequest(messages, {}, token);
for await (const fragment of chatResponse.text) {
responseContent += fragment;
}
} catch (err) {
this.handleError(err, stream);
}
} catch (err) {
this.handleError(err, stream);
}

return responseContent;
Expand Down Expand Up @@ -483,14 +486,17 @@ export default class ParticipantController {
![
QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME,
QUERY_GENERATION_STATE.ASK_FOR_COLLECTION_NAME,
QUERY_GENERATION_STATE.CHANGE_DATABASE_NAME,
].includes(this._queryGenerationState)
) {
return false;
}

if (
this._queryGenerationState ===
QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME
[
QUERY_GENERATION_STATE.ASK_FOR_DATABASE_NAME,
QUERY_GENERATION_STATE.CHANGE_DATABASE_NAME,
].includes(this._queryGenerationState)
) {
this._setDatabaseName(prompt);
if (!this._collectionName) {
Expand Down Expand Up @@ -616,7 +622,9 @@ export default class ParticipantController {
return this._queryGenerationState === QUERY_GENERATION_STATE.FETCH_SCHEMA;
}

async _fetchCollectionSchema(abortSignal?: AbortSignal): Promise<undefined> {
async _fetchCollectionSchemaAndSampleDocuments(
abortSignal?: AbortSignal
): Promise<undefined> {
if (this._queryGenerationState === QUERY_GENERATION_STATE.FETCH_SCHEMA) {
this._queryGenerationState = QUERY_GENERATION_STATE.DEFAULT;
}
Expand All @@ -642,8 +650,17 @@ export default class ParticipantController {

const schema = await getSimplifiedSchema(sampleDocuments);
this._schema = new SchemaFormatter().format(schema);

const useSampleDocsInCopilot = !!vscode.workspace
.getConfiguration('mdb')
.get('useSampleDocsInCopilot');

if (useSampleDocsInCopilot) {
this._sampleDocuments = getSimplifiedSampleDocuments(sampleDocuments);
}
} catch (err: any) {
this._schema = undefined;
this._sampleDocuments = undefined;
}
}

Expand Down Expand Up @@ -679,15 +696,18 @@ export default class ParticipantController {
});

if (this._shouldFetchCollectionSchema()) {
await this._fetchCollectionSchema(abortController.signal);
await this._fetchCollectionSchemaAndSampleDocuments(
abortController.signal
);
}

const messages = QueryPrompt.buildMessages({
const messages = await QueryPrompt.buildMessages({
request,
context,
databaseName: this._databaseName,
collectionName: this._collectionName,
schema: this._schema,
sampleDocuments: this._sampleDocuments,
});
const responseContent = await this.getChatResponseContent({
messages,
Expand Down
Loading

0 comments on commit b54b22f

Please sign in to comment.