Skip to content

Commit

Permalink
Send server contextualization to Copilot extension (#24230)
Browse files Browse the repository at this point in the history
* Send server contextualization to Copilot extension

* Keep context in editor input

* Remove unnecessary server context and extension service

* Send context when connecting from open editor

* Remove contextualization complete event

* Contextualize editor after connection success

* Minor clean up

* Remove nested then and use async/await

* Create helper function

* Remove unneeded async and add comment

* Encapsulate all context logic in service

* Use void operator to fix floating promise

* Correct return comment
  • Loading branch information
lewis-sanchez authored Sep 1, 2023
1 parent e3d0670 commit 5152823
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 72 deletions.
4 changes: 2 additions & 2 deletions extensions/mssql/src/contracts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1567,8 +1567,8 @@ export interface ServerContextualizationParams {
ownerUri: string;
}

export namespace GenerateServerContextualizationNotification {
export const type = new NotificationType<ServerContextualizationParams, void>('metadata/generateServerContext');
export namespace GenerateServerContextualizationRequest {
export const type = new RequestType<ServerContextualizationParams, azdata.contextualization.GenerateServerContextualizationResult, void, void>('metadata/generateServerContext');
}

export namespace GetServerContextualizationRequest {
Expand Down
12 changes: 9 additions & 3 deletions extensions/mssql/src/features.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ export class ExecutionPlanServiceFeature extends SqlOpsFeature<undefined> {
*/
export class ServerContextualizationServiceFeature extends SqlOpsFeature<undefined> {
private static readonly messagesTypes: RPCMessageType[] = [
contracts.GenerateServerContextualizationNotification.type
contracts.GenerateServerContextualizationRequest.type
];

constructor(client: SqlOpsDataClient) {
Expand All @@ -1330,12 +1330,18 @@ export class ServerContextualizationServiceFeature extends SqlOpsFeature<undefin
protected registerProvider(options: undefined): Disposable {
const client = this._client;

const generateServerContextualization = (ownerUri: string): void => {
const generateServerContextualization = (ownerUri: string): Thenable<azdata.contextualization.GenerateServerContextualizationResult> => {
const params: contracts.ServerContextualizationParams = {
ownerUri: ownerUri
};

return client.sendNotification(contracts.GenerateServerContextualizationNotification.type, params);
return client.sendRequest(contracts.GenerateServerContextualizationRequest.type, params).then(
r => r,
e => {
client.logFailedRequest(contracts.GenerateServerContextualizationRequest.type, e);
return Promise.reject(e);
}
);
};

const getServerContextualization = (ownerUri: string): Thenable<azdata.contextualization.GetServerContextualizationResult> => {
Expand Down
15 changes: 11 additions & 4 deletions src/sql/azdata.proposed.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ declare module 'azdata' {
* Copilot for improved suggestions.
* @param provider The provider to register
*/
export function registerServerContextualizationProvider(provider: contextualization.ServerContextualizationProvider): vscode.Disposable
export function registerServerContextualizationProvider(provider: contextualization.ServerContextualizationProvider): vscode.Disposable;
}

export namespace designers {
Expand Down Expand Up @@ -1783,19 +1783,26 @@ declare module 'azdata' {
}

export namespace contextualization {
export interface GenerateServerContextualizationResult {
/**
* The generated server context.
*/
context: string | undefined;
}

export interface GetServerContextualizationResult {
/**
* An array containing the generated server context.
* The retrieved server context.
*/
context: string[];
context: string | undefined;
}

export interface ServerContextualizationProvider extends DataProvider {
/**
* Generates server context.
* @param ownerUri The URI of the connection to generate context for.
*/
generateServerContextualization(ownerUri: string): void;
generateServerContextualization(ownerUri: string): Thenable<GenerateServerContextualizationResult>;

/**
* Gets server context, which can be in the form of create scripts but is left up each provider.
Expand Down
4 changes: 2 additions & 2 deletions src/sql/workbench/api/common/extHostDataProtocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -972,8 +972,8 @@ export class ExtHostDataProtocol extends ExtHostDataProtocolShape {

// Database Server Contextualization API

public override $generateServerContextualization(handle: number, ownerUri: string): void {
this._resolveProvider<azdata.contextualization.ServerContextualizationProvider>(handle).generateServerContextualization(ownerUri);
public override $generateServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GenerateServerContextualizationResult> {
return this._resolveProvider<azdata.contextualization.ServerContextualizationProvider>(handle).generateServerContextualization(ownerUri);
}

public override $getServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GetServerContextualizationResult> {
Expand Down
2 changes: 1 addition & 1 deletion src/sql/workbench/api/common/sqlExtHost.protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ export abstract class ExtHostDataProtocolShape {
/**
* Generates server context.
*/
$generateServerContextualization(handle: number, ownerUri: string): void { throw ni(); }
$generateServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GenerateServerContextualizationResult> { throw ni(); }
/**
* Gets server context.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import { IInstantiationService } from 'vs/platform/instantiation/common/instanti
import { EditorInput } from 'vs/workbench/common/editor/editorInput';
import { IResourceEditorInput } from 'vs/platform/editor/common/editor';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';

export class FileQueryEditorInput extends QueryEditorInput {

Expand All @@ -33,10 +32,9 @@ export class FileQueryEditorInput extends QueryEditorInput {
@IQueryModelService queryModelService: IQueryModelService,
@IConfigurationService configurationService: IConfigurationService,
@IInstantiationService instantiationService: IInstantiationService,
@IServerContextualizationService serverContextualizationService: IServerContextualizationService,
@IExtensionService extensionService: IExtensionService
@IServerContextualizationService serverContextualizationService: IServerContextualizationService
) {
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService, extensionService);
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService);
}

public override resolve(): Promise<ITextFileEditorModel | BinaryEditorModel> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import { IEditorResolverService } from 'vs/workbench/services/editor/common/edit
import { Uri } from 'vscode';
import { ILogService } from 'vs/platform/log/common/log';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';

export class UntitledQueryEditorInput extends QueryEditorInput implements IUntitledQueryEditorInput {

Expand All @@ -39,10 +38,9 @@ export class UntitledQueryEditorInput extends QueryEditorInput implements IUntit
@IInstantiationService instantiationService: IInstantiationService,
@ILogService private readonly logService: ILogService,
@IEditorResolverService private readonly editorResolverService: IEditorResolverService,
@IServerContextualizationService serverContextualizationService: IServerContextualizationService,
@IExtensionService extensionService: IExtensionService
@IServerContextualizationService serverContextualizationService: IServerContextualizationService
) {
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService, extensionService);
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService);
// Set the mode explicitely to stop the auto language detection service from changing the mode unexpectedly.
// the auto language detection service won't do the language change only if the mode is explicitely set.
// if the mode (e.g. kusto, sql) do not exist for whatever reason, we will default it to sql.
Expand Down
30 changes: 4 additions & 26 deletions src/sql/workbench/common/editor/query/queryEditorInput.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import { IQueryEditorConfiguration } from 'sql/platform/query/common/query';
import { EditorInput } from 'vs/workbench/common/editor/editorInput';
import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';

const MAX_SIZE = 13;

Expand Down Expand Up @@ -144,8 +143,6 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
private _state = this._register(new QueryEditorState());
public get state(): QueryEditorState { return this._state; }

private _serverContext: string[];

constructor(
private _description: string | undefined,
protected _text: AbstractTextResourceEditorInput,
Expand All @@ -154,8 +151,7 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
@IQueryModelService private readonly queryModelService: IQueryModelService,
@IConfigurationService private readonly configurationService: IConfigurationService,
@IInstantiationService protected readonly instantiationService: IInstantiationService,
@IServerContextualizationService private readonly serverContextualizationService: IServerContextualizationService,
@IExtensionService private readonly extensionService: IExtensionService
@IServerContextualizationService private readonly serverContextualizationService: IServerContextualizationService
) {
super();

Expand Down Expand Up @@ -241,27 +237,6 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
public override isDirty(): boolean { return this._text.isDirty(); }
public get resource(): URI { return this._text.resource; }

public async getServerContext(): Promise<string[]> {
const copilotExt = await this.extensionService.getExtension('github.copilot');

if (copilotExt && this.configurationService.getValue<IQueryEditorConfiguration>('queryEditor').githubCopilotContextualizationEnabled) {
if (!this._serverContext) {
const result = await this.serverContextualizationService.getServerContextualization(this.uri);
// TODO lewissanchez - Remove this from here once Copilot starts pulling context. That isn't implemented yet, so
// getting scripts this way for now.
this._serverContext = result.context;

return this._serverContext;
}
else {
return this._serverContext;
}
}
else {
return Promise.resolve([]);
}
}

public override getName(longForm?: boolean): string {
if (this.configurationService.getValue<IQueryEditorConfiguration>('queryEditor').showConnectionInfoInTitle) {
let profile = this.connectionManagementService.getConnectionProfile(this.uri);
Expand Down Expand Up @@ -346,6 +321,9 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
}
}
this._onDidChangeLabel.fire();

// Intentionally not awaiting, so that contextualization can happen in the background
void this.serverContextualizationService?.contextualizeUriForCopilot(this.uri);
}

public onDisconnect(): void {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ suite('commandLineService tests', () => {
let uri = URI.file(args._[0]);
const workbenchinstantiationService = workbenchInstantiationService();
const editorInput = workbenchinstantiationService.createInstance(FileEditorInput, uri, undefined, undefined, undefined, undefined, undefined, undefined);
const queryInput = new FileQueryEditorInput(undefined, editorInput, undefined, connectionManagementService.object, querymodelService.object, configurationService.object, workbenchinstantiationService, undefined, undefined);
const queryInput = new FileQueryEditorInput(undefined, editorInput, undefined, connectionManagementService.object, querymodelService.object, configurationService.object, workbenchinstantiationService, undefined);
queryInput.state.connected = true;
const editorService: TypeMoq.Mock<IEditorService> = TypeMoq.Mock.ofType<IEditorService>(TestEditorService, TypeMoq.MockBehavior.Strict);
editorService.setup(e => e.editors).returns(() => [queryInput]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,6 @@ suite('SQL QueryEditor Tests', () => {
testinstantiationService,
undefined,
undefined,
undefined,
undefined
);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,9 @@ export interface IServerContextualizationService {
getProvider(providerId: string): azdata.contextualization.ServerContextualizationProvider;

/**
* Generates server context
* @param ownerUri The URI of the connection to generate context for.
* Contextualizes the provided URI for GitHub Copilot.
* @param uri The URI to contextualize for Copilot.
* @returns Copilot will have the URI contextualized when the promise completes.
*/
generateServerContextualization(ownerUri: string): void;

/**
* Gets all database context.
* @param ownerUri The URI of the connection to get context for.
*/
getServerContextualization(ownerUri: string): Promise<azdata.contextualization.GetServerContextualizationResult>;
contextualizeUriForCopilot(uri: string): Promise<void>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import * as azdata from 'azdata';
import { invalidProvider } from 'sql/base/common/errors';
import { IConnectionManagementService, IConnectionParams } from 'sql/platform/connection/common/connectionManagement';
import { IConnectionManagementService } from 'sql/platform/connection/common/connectionManagement';
import { IQueryEditorConfiguration } from 'sql/platform/query/common/query';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { Disposable } from 'vs/base/common/lifecycle';
import { ICommandService } from 'vs/platform/commands/common/commands';
import { IConfigurationService } from 'vs/platform/configuration/common/configuration';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';

Expand All @@ -19,18 +20,10 @@ export class ServerContextualizationService extends Disposable implements IServe
constructor(
@IConnectionManagementService private readonly _connectionManagementService: IConnectionManagementService,
@IConfigurationService private readonly _configurationService: IConfigurationService,
@IExtensionService private readonly _extensionService: IExtensionService
@IExtensionService private readonly _extensionService: IExtensionService,
@ICommandService private readonly _commandService: ICommandService
) {
super();

this._register(this._connectionManagementService.onConnect(async (e: IConnectionParams) => {
const copilotExt = await this._extensionService.getExtension('github.copilot');

if (copilotExt && this._configurationService.getValue<IQueryEditorConfiguration>('queryEditor').githubCopilotContextualizationEnabled) {
const ownerUri = e.connectionUri;
await this.generateServerContextualization(ownerUri);
}
}));
}

/**
Expand Down Expand Up @@ -63,32 +56,86 @@ export class ServerContextualizationService extends Disposable implements IServe
throw invalidProvider(providerId);
}

/**
* Contextualizes the provided URI for GitHub Copilot.
* @param uri The URI to contextualize for Copilot.
* @returns Copilot will have the URI contextualized when the promise completes.
*/
public async contextualizeUriForCopilot(uri: string): Promise<void> {
// Don't need to take any actions if contextualization is not enabled and can return
const isContextualizationNeeded = await this.isContextualizationNeeded();
if (!isContextualizationNeeded) {
return;
}

const getServerContextualizationResult = await this.getServerContextualization(uri);
if (getServerContextualizationResult.context) {
await this.sendServerContextualizationToCopilot(getServerContextualizationResult.context);
}
else {
const generateServerContextualizationResult = await this.generateServerContextualization(uri);
if (generateServerContextualizationResult.context) {
await this.sendServerContextualizationToCopilot(generateServerContextualizationResult.context);
}
}
}

/**
* Generates server context
* @param ownerUri The URI of the connection to generate context for.
*/
public generateServerContextualization(ownerUri: string): void {
private async generateServerContextualization(ownerUri: string): Promise<azdata.contextualization.GenerateServerContextualizationResult> {
const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri);
const handler = this.getProvider(providerName);
if (handler) {
handler.generateServerContextualization(ownerUri);
return await handler.generateServerContextualization(ownerUri);
}
else {
return Promise.resolve({
context: undefined
});
}
}

/**
* Gets all database context.
* @param ownerUri The URI of the connection to get context for.
*/
public async getServerContextualization(ownerUri: string): Promise<azdata.contextualization.GetServerContextualizationResult> {
private async getServerContextualization(ownerUri: string): Promise<azdata.contextualization.GetServerContextualizationResult> {
const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri);
const handler = this.getProvider(providerName);
if (handler) {
return await handler.getServerContextualization(ownerUri);
}
else {
return Promise.resolve({
context: []
context: undefined
});
}
}

/**
* Sends the provided context over to copilot, so that it can be used to generate improved suggestions.
* @param serverContext The context to be sent over to Copilot
*/
private async sendServerContextualizationToCopilot(serverContext: string | undefined): Promise<void> {
if (serverContext) {
// LEWISSANCHEZ TODO: Find way to set context on untitled query editor files. Need to save first for Copilot status to say "Has Context"
await this._commandService.executeCommand('github.copilot.provideContext', '**/*.sql', {
value: serverContext
});
}
}

/**
* Checks if contextualization is needed. This is based on whether the Copilot extension is installed and the GitHub Copilot
* contextualization setting is enabled.
* @returns A promise that resolves to true if contextualization is needed, false otherwise.
*/
private async isContextualizationNeeded(): Promise<boolean> {
const copilotExt = await this._extensionService.getExtension('github.copilot');
const isContextualizationEnabled = this._configurationService.getValue<IQueryEditorConfiguration>('queryEditor').githubCopilotContextualizationEnabled

return (copilotExt && isContextualizationEnabled);
}
}

0 comments on commit 5152823

Please sign in to comment.