From 5152823306845141c6d8487abfd1d09984cc9369 Mon Sep 17 00:00:00 2001 From: Lewis Sanchez <87730006+lewis-sanchez@users.noreply.github.com> Date: Fri, 1 Sep 2023 09:26:29 -0700 Subject: [PATCH] Send server contextualization to Copilot extension (#24230) * Send server contextualization to Copilot extension * Keep context in editor input * Remove unnecessary server context and extension service * Send context when connecting from open editor * Remove contextualization complete event * Contextualize editor after connection success * Minor clean up * Remove nested then and use async/await * Create helper function * Remove unneeded async and add comment * Encapsulate all context logic in service * Use void operator to fix floating promise * Correct return comment --- extensions/mssql/src/contracts.ts | 4 +- extensions/mssql/src/features.ts | 12 ++- src/sql/azdata.proposed.d.ts | 15 +++- .../api/common/extHostDataProtocol.ts | 4 +- .../api/common/sqlExtHost.protocol.ts | 2 +- .../editor/query/fileQueryEditorInput.ts | 6 +- .../editor/query/untitledQueryEditorInput.ts | 6 +- .../common/editor/query/queryEditorInput.ts | 30 +------- .../test/electron-browser/commandLine.test.ts | 2 +- .../query/test/browser/queryEditor.test.ts | 1 - .../contextualization/common/interfaces.ts | 13 +--- .../common/serverContextualizationService.ts | 77 +++++++++++++++---- 12 files changed, 100 insertions(+), 72 deletions(-) diff --git a/extensions/mssql/src/contracts.ts b/extensions/mssql/src/contracts.ts index 993efc1b5800..7da000ebce3d 100644 --- a/extensions/mssql/src/contracts.ts +++ b/extensions/mssql/src/contracts.ts @@ -1567,8 +1567,8 @@ export interface ServerContextualizationParams { ownerUri: string; } -export namespace GenerateServerContextualizationNotification { - export const type = new NotificationType('metadata/generateServerContext'); +export namespace GenerateServerContextualizationRequest { + export const type = new RequestType('metadata/generateServerContext'); } export namespace GetServerContextualizationRequest { diff --git a/extensions/mssql/src/features.ts b/extensions/mssql/src/features.ts index 9d79b97e1b4d..ea2e8f4755a4 100644 --- a/extensions/mssql/src/features.ts +++ b/extensions/mssql/src/features.ts @@ -1310,7 +1310,7 @@ export class ExecutionPlanServiceFeature extends SqlOpsFeature { */ export class ServerContextualizationServiceFeature extends SqlOpsFeature { private static readonly messagesTypes: RPCMessageType[] = [ - contracts.GenerateServerContextualizationNotification.type + contracts.GenerateServerContextualizationRequest.type ]; constructor(client: SqlOpsDataClient) { @@ -1330,12 +1330,18 @@ export class ServerContextualizationServiceFeature extends SqlOpsFeature { + const generateServerContextualization = (ownerUri: string): Thenable => { const params: contracts.ServerContextualizationParams = { ownerUri: ownerUri }; - return client.sendNotification(contracts.GenerateServerContextualizationNotification.type, params); + return client.sendRequest(contracts.GenerateServerContextualizationRequest.type, params).then( + r => r, + e => { + client.logFailedRequest(contracts.GenerateServerContextualizationRequest.type, e); + return Promise.reject(e); + } + ); }; const getServerContextualization = (ownerUri: string): Thenable => { diff --git a/src/sql/azdata.proposed.d.ts b/src/sql/azdata.proposed.d.ts index 4f4928ede4ba..2cd5d3aae9ba 100644 --- a/src/sql/azdata.proposed.d.ts +++ b/src/sql/azdata.proposed.d.ts @@ -901,7 +901,7 @@ declare module 'azdata' { * Copilot for improved suggestions. * @param provider The provider to register */ - export function registerServerContextualizationProvider(provider: contextualization.ServerContextualizationProvider): vscode.Disposable + export function registerServerContextualizationProvider(provider: contextualization.ServerContextualizationProvider): vscode.Disposable; } export namespace designers { @@ -1783,11 +1783,18 @@ declare module 'azdata' { } export namespace contextualization { + export interface GenerateServerContextualizationResult { + /** + * The generated server context. + */ + context: string | undefined; + } + export interface GetServerContextualizationResult { /** - * An array containing the generated server context. + * The retrieved server context. */ - context: string[]; + context: string | undefined; } export interface ServerContextualizationProvider extends DataProvider { @@ -1795,7 +1802,7 @@ declare module 'azdata' { * Generates server context. * @param ownerUri The URI of the connection to generate context for. */ - generateServerContextualization(ownerUri: string): void; + generateServerContextualization(ownerUri: string): Thenable; /** * Gets server context, which can be in the form of create scripts but is left up each provider. diff --git a/src/sql/workbench/api/common/extHostDataProtocol.ts b/src/sql/workbench/api/common/extHostDataProtocol.ts index 98a13bf62029..8a102b176620 100644 --- a/src/sql/workbench/api/common/extHostDataProtocol.ts +++ b/src/sql/workbench/api/common/extHostDataProtocol.ts @@ -972,8 +972,8 @@ export class ExtHostDataProtocol extends ExtHostDataProtocolShape { // Database Server Contextualization API - public override $generateServerContextualization(handle: number, ownerUri: string): void { - this._resolveProvider(handle).generateServerContextualization(ownerUri); + public override $generateServerContextualization(handle: number, ownerUri: string): Thenable { + return this._resolveProvider(handle).generateServerContextualization(ownerUri); } public override $getServerContextualization(handle: number, ownerUri: string): Thenable { diff --git a/src/sql/workbench/api/common/sqlExtHost.protocol.ts b/src/sql/workbench/api/common/sqlExtHost.protocol.ts index 6e700cfeda4c..8c65c8725778 100644 --- a/src/sql/workbench/api/common/sqlExtHost.protocol.ts +++ b/src/sql/workbench/api/common/sqlExtHost.protocol.ts @@ -598,7 +598,7 @@ export abstract class ExtHostDataProtocolShape { /** * Generates server context. */ - $generateServerContextualization(handle: number, ownerUri: string): void { throw ni(); } + $generateServerContextualization(handle: number, ownerUri: string): Thenable { throw ni(); } /** * Gets server context. */ diff --git a/src/sql/workbench/browser/editor/query/fileQueryEditorInput.ts b/src/sql/workbench/browser/editor/query/fileQueryEditorInput.ts index aeb05d5fe150..bc4ea8b69098 100644 --- a/src/sql/workbench/browser/editor/query/fileQueryEditorInput.ts +++ b/src/sql/workbench/browser/editor/query/fileQueryEditorInput.ts @@ -19,7 +19,6 @@ import { IInstantiationService } from 'vs/platform/instantiation/common/instanti import { EditorInput } from 'vs/workbench/common/editor/editorInput'; import { IResourceEditorInput } from 'vs/platform/editor/common/editor'; import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces'; -import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions'; export class FileQueryEditorInput extends QueryEditorInput { @@ -33,10 +32,9 @@ export class FileQueryEditorInput extends QueryEditorInput { @IQueryModelService queryModelService: IQueryModelService, @IConfigurationService configurationService: IConfigurationService, @IInstantiationService instantiationService: IInstantiationService, - @IServerContextualizationService serverContextualizationService: IServerContextualizationService, - @IExtensionService extensionService: IExtensionService + @IServerContextualizationService serverContextualizationService: IServerContextualizationService ) { - super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService, extensionService); + super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService); } public override resolve(): Promise { diff --git a/src/sql/workbench/browser/editor/query/untitledQueryEditorInput.ts b/src/sql/workbench/browser/editor/query/untitledQueryEditorInput.ts index 3d6f80236081..78fc3591094c 100644 --- a/src/sql/workbench/browser/editor/query/untitledQueryEditorInput.ts +++ b/src/sql/workbench/browser/editor/query/untitledQueryEditorInput.ts @@ -23,7 +23,6 @@ import { IEditorResolverService } from 'vs/workbench/services/editor/common/edit import { Uri } from 'vscode'; import { ILogService } from 'vs/platform/log/common/log'; import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces'; -import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions'; export class UntitledQueryEditorInput extends QueryEditorInput implements IUntitledQueryEditorInput { @@ -39,10 +38,9 @@ export class UntitledQueryEditorInput extends QueryEditorInput implements IUntit @IInstantiationService instantiationService: IInstantiationService, @ILogService private readonly logService: ILogService, @IEditorResolverService private readonly editorResolverService: IEditorResolverService, - @IServerContextualizationService serverContextualizationService: IServerContextualizationService, - @IExtensionService extensionService: IExtensionService + @IServerContextualizationService serverContextualizationService: IServerContextualizationService ) { - super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService, extensionService); + super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService); // Set the mode explicitely to stop the auto language detection service from changing the mode unexpectedly. // the auto language detection service won't do the language change only if the mode is explicitely set. // if the mode (e.g. kusto, sql) do not exist for whatever reason, we will default it to sql. diff --git a/src/sql/workbench/common/editor/query/queryEditorInput.ts b/src/sql/workbench/common/editor/query/queryEditorInput.ts index eff674a0ec4e..c521a9ba13d7 100644 --- a/src/sql/workbench/common/editor/query/queryEditorInput.ts +++ b/src/sql/workbench/common/editor/query/queryEditorInput.ts @@ -21,7 +21,6 @@ import { IQueryEditorConfiguration } from 'sql/platform/query/common/query'; import { EditorInput } from 'vs/workbench/common/editor/editorInput'; import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation'; import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces'; -import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions'; const MAX_SIZE = 13; @@ -144,8 +143,6 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab private _state = this._register(new QueryEditorState()); public get state(): QueryEditorState { return this._state; } - private _serverContext: string[]; - constructor( private _description: string | undefined, protected _text: AbstractTextResourceEditorInput, @@ -154,8 +151,7 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab @IQueryModelService private readonly queryModelService: IQueryModelService, @IConfigurationService private readonly configurationService: IConfigurationService, @IInstantiationService protected readonly instantiationService: IInstantiationService, - @IServerContextualizationService private readonly serverContextualizationService: IServerContextualizationService, - @IExtensionService private readonly extensionService: IExtensionService + @IServerContextualizationService private readonly serverContextualizationService: IServerContextualizationService ) { super(); @@ -241,27 +237,6 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab public override isDirty(): boolean { return this._text.isDirty(); } public get resource(): URI { return this._text.resource; } - public async getServerContext(): Promise { - const copilotExt = await this.extensionService.getExtension('github.copilot'); - - if (copilotExt && this.configurationService.getValue('queryEditor').githubCopilotContextualizationEnabled) { - if (!this._serverContext) { - const result = await this.serverContextualizationService.getServerContextualization(this.uri); - // TODO lewissanchez - Remove this from here once Copilot starts pulling context. That isn't implemented yet, so - // getting scripts this way for now. - this._serverContext = result.context; - - return this._serverContext; - } - else { - return this._serverContext; - } - } - else { - return Promise.resolve([]); - } - } - public override getName(longForm?: boolean): string { if (this.configurationService.getValue('queryEditor').showConnectionInfoInTitle) { let profile = this.connectionManagementService.getConnectionProfile(this.uri); @@ -346,6 +321,9 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab } } this._onDidChangeLabel.fire(); + + // Intentionally not awaiting, so that contextualization can happen in the background + void this.serverContextualizationService?.contextualizeUriForCopilot(this.uri); } public onDisconnect(): void { diff --git a/src/sql/workbench/contrib/commandLine/test/electron-browser/commandLine.test.ts b/src/sql/workbench/contrib/commandLine/test/electron-browser/commandLine.test.ts index cce62f977c3f..ccfbeb03bc67 100644 --- a/src/sql/workbench/contrib/commandLine/test/electron-browser/commandLine.test.ts +++ b/src/sql/workbench/contrib/commandLine/test/electron-browser/commandLine.test.ts @@ -467,7 +467,7 @@ suite('commandLineService tests', () => { let uri = URI.file(args._[0]); const workbenchinstantiationService = workbenchInstantiationService(); const editorInput = workbenchinstantiationService.createInstance(FileEditorInput, uri, undefined, undefined, undefined, undefined, undefined, undefined); - const queryInput = new FileQueryEditorInput(undefined, editorInput, undefined, connectionManagementService.object, querymodelService.object, configurationService.object, workbenchinstantiationService, undefined, undefined); + const queryInput = new FileQueryEditorInput(undefined, editorInput, undefined, connectionManagementService.object, querymodelService.object, configurationService.object, workbenchinstantiationService, undefined); queryInput.state.connected = true; const editorService: TypeMoq.Mock = TypeMoq.Mock.ofType(TestEditorService, TypeMoq.MockBehavior.Strict); editorService.setup(e => e.editors).returns(() => [queryInput]); diff --git a/src/sql/workbench/contrib/query/test/browser/queryEditor.test.ts b/src/sql/workbench/contrib/query/test/browser/queryEditor.test.ts index 56fbef5cebd2..0722ccf1f7db 100644 --- a/src/sql/workbench/contrib/query/test/browser/queryEditor.test.ts +++ b/src/sql/workbench/contrib/query/test/browser/queryEditor.test.ts @@ -317,7 +317,6 @@ suite('SQL QueryEditor Tests', () => { testinstantiationService, undefined, undefined, - undefined, undefined ); }); diff --git a/src/sql/workbench/services/contextualization/common/interfaces.ts b/src/sql/workbench/services/contextualization/common/interfaces.ts index b53dc15f1477..508cd7e2f287 100644 --- a/src/sql/workbench/services/contextualization/common/interfaces.ts +++ b/src/sql/workbench/services/contextualization/common/interfaces.ts @@ -30,14 +30,9 @@ export interface IServerContextualizationService { getProvider(providerId: string): azdata.contextualization.ServerContextualizationProvider; /** - * Generates server context - * @param ownerUri The URI of the connection to generate context for. + * Contextualizes the provided URI for GitHub Copilot. + * @param uri The URI to contextualize for Copilot. + * @returns Copilot will have the URI contextualized when the promise completes. */ - generateServerContextualization(ownerUri: string): void; - - /** - * Gets all database context. - * @param ownerUri The URI of the connection to get context for. - */ - getServerContextualization(ownerUri: string): Promise; + contextualizeUriForCopilot(uri: string): Promise; } diff --git a/src/sql/workbench/services/contextualization/common/serverContextualizationService.ts b/src/sql/workbench/services/contextualization/common/serverContextualizationService.ts index a9980d4b756c..098aa35756dc 100644 --- a/src/sql/workbench/services/contextualization/common/serverContextualizationService.ts +++ b/src/sql/workbench/services/contextualization/common/serverContextualizationService.ts @@ -5,10 +5,11 @@ import * as azdata from 'azdata'; import { invalidProvider } from 'sql/base/common/errors'; -import { IConnectionManagementService, IConnectionParams } from 'sql/platform/connection/common/connectionManagement'; +import { IConnectionManagementService } from 'sql/platform/connection/common/connectionManagement'; import { IQueryEditorConfiguration } from 'sql/platform/query/common/query'; import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces'; import { Disposable } from 'vs/base/common/lifecycle'; +import { ICommandService } from 'vs/platform/commands/common/commands'; import { IConfigurationService } from 'vs/platform/configuration/common/configuration'; import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions'; @@ -19,18 +20,10 @@ export class ServerContextualizationService extends Disposable implements IServe constructor( @IConnectionManagementService private readonly _connectionManagementService: IConnectionManagementService, @IConfigurationService private readonly _configurationService: IConfigurationService, - @IExtensionService private readonly _extensionService: IExtensionService + @IExtensionService private readonly _extensionService: IExtensionService, + @ICommandService private readonly _commandService: ICommandService ) { super(); - - this._register(this._connectionManagementService.onConnect(async (e: IConnectionParams) => { - const copilotExt = await this._extensionService.getExtension('github.copilot'); - - if (copilotExt && this._configurationService.getValue('queryEditor').githubCopilotContextualizationEnabled) { - const ownerUri = e.connectionUri; - await this.generateServerContextualization(ownerUri); - } - })); } /** @@ -63,15 +56,44 @@ export class ServerContextualizationService extends Disposable implements IServe throw invalidProvider(providerId); } + /** + * Contextualizes the provided URI for GitHub Copilot. + * @param uri The URI to contextualize for Copilot. + * @returns Copilot will have the URI contextualized when the promise completes. + */ + public async contextualizeUriForCopilot(uri: string): Promise { + // Don't need to take any actions if contextualization is not enabled and can return + const isContextualizationNeeded = await this.isContextualizationNeeded(); + if (!isContextualizationNeeded) { + return; + } + + const getServerContextualizationResult = await this.getServerContextualization(uri); + if (getServerContextualizationResult.context) { + await this.sendServerContextualizationToCopilot(getServerContextualizationResult.context); + } + else { + const generateServerContextualizationResult = await this.generateServerContextualization(uri); + if (generateServerContextualizationResult.context) { + await this.sendServerContextualizationToCopilot(generateServerContextualizationResult.context); + } + } + } + /** * Generates server context * @param ownerUri The URI of the connection to generate context for. */ - public generateServerContextualization(ownerUri: string): void { + private async generateServerContextualization(ownerUri: string): Promise { const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri); const handler = this.getProvider(providerName); if (handler) { - handler.generateServerContextualization(ownerUri); + return await handler.generateServerContextualization(ownerUri); + } + else { + return Promise.resolve({ + context: undefined + }); } } @@ -79,7 +101,7 @@ export class ServerContextualizationService extends Disposable implements IServe * Gets all database context. * @param ownerUri The URI of the connection to get context for. */ - public async getServerContextualization(ownerUri: string): Promise { + private async getServerContextualization(ownerUri: string): Promise { const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri); const handler = this.getProvider(providerName); if (handler) { @@ -87,8 +109,33 @@ export class ServerContextualizationService extends Disposable implements IServe } else { return Promise.resolve({ - context: [] + context: undefined + }); + } + } + + /** + * Sends the provided context over to copilot, so that it can be used to generate improved suggestions. + * @param serverContext The context to be sent over to Copilot + */ + private async sendServerContextualizationToCopilot(serverContext: string | undefined): Promise { + if (serverContext) { + // LEWISSANCHEZ TODO: Find way to set context on untitled query editor files. Need to save first for Copilot status to say "Has Context" + await this._commandService.executeCommand('github.copilot.provideContext', '**/*.sql', { + value: serverContext }); } } + + /** + * Checks if contextualization is needed. This is based on whether the Copilot extension is installed and the GitHub Copilot + * contextualization setting is enabled. + * @returns A promise that resolves to true if contextualization is needed, false otherwise. + */ + private async isContextualizationNeeded(): Promise { + const copilotExt = await this._extensionService.getExtension('github.copilot'); + const isContextualizationEnabled = this._configurationService.getValue('queryEditor').githubCopilotContextualizationEnabled + + return (copilotExt && isContextualizationEnabled); + } }