Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AuthLibrary detection and use MSAL by default #22140

Merged
merged 2 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,9 @@ import { Iterable } from 'vs/base/common/iterator';
import { LoadingSpinner } from 'sql/base/browser/ui/loadingSpinner/loadingSpinner';
import { Tenant, TenantListDelegate, TenantListRenderer } from 'sql/workbench/services/accountManagement/browser/tenantListRenderer';
import { IAccountManagementService } from 'sql/platform/accounts/common/interfaces';
import { ADAL_AUTH_LIBRARY, AuthLibrary, getAuthLibrary } from 'sql/workbench/services/accountManagement/utils';

export const VIEWLET_ID = 'workbench.view.accountpanel';
export type AuthLibrary = 'ADAL' | 'MSAL';
export const MSAL_AUTH_LIBRARY: AuthLibrary = 'MSAL'; // default
export const ADAL_AUTH_LIBRARY: AuthLibrary = 'ADAL';

export class AccountPaneContainer extends ViewPaneContainer { }

Expand Down Expand Up @@ -392,7 +390,7 @@ export class AccountDialog extends Modal {
this._splitView!.layout(DOM.getContentHeight(this._container!));

// Set the initial items of the list
const authLibrary: AuthLibrary = this._configurationService.getValue('azure.authenticationLibrary');
const authLibrary: AuthLibrary = getAuthLibrary(this._configurationService);
let updatedAccounts: azdata.Account[];
if (authLibrary) {
updatedAccounts = filterAccounts(newProvider.initialAccounts, authLibrary);
Expand Down Expand Up @@ -443,7 +441,7 @@ export class AccountDialog extends Modal {
if (!providerMapping || !providerMapping.view) {
return;
}
const authLibrary: AuthLibrary = this._configurationService.getValue('azure.authenticationLibrary');
const authLibrary: AuthLibrary = getAuthLibrary(this._configurationService);
let updatedAccounts: azdata.Account[];
if (authLibrary) {
updatedAccounts = filterAccounts(args.accountList, authLibrary);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import { INotificationService, Severity, INotification } from 'vs/platform/notif
import { Action } from 'vs/base/common/actions';
import { DisposableStore } from 'vs/base/common/lifecycle';
import { IConfigurationService } from 'vs/platform/configuration/common/configuration';
import { ADAL_AUTH_LIBRARY, AuthLibrary, filterAccounts, MSAL_AUTH_LIBRARY } from 'sql/workbench/services/accountManagement/browser/accountDialog';
import { filterAccounts } from 'sql/workbench/services/accountManagement/browser/accountDialog';
import { ADAL_AUTH_LIBRARY, MSAL_AUTH_LIBRARY, AuthLibrary, AZURE_AUTH_LIBRARY_CONFIG, getAuthLibrary } from 'sql/workbench/services/accountManagement/utils';

export class AccountManagementService implements IAccountManagementService {
// CONSTANTS ///////////////////////////////////////////////////////////
Expand All @@ -41,7 +42,6 @@ export class AccountManagementService implements IAccountManagementService {
private _autoOAuthDialogController?: AutoOAuthDialogController;
private _mementoContext?: Memento;
protected readonly disposables = new DisposableStore();
private readonly configurationService: IConfigurationService;

// EVENT EMITTERS //////////////////////////////////////////////////////
private _addAccountProviderEmitter: Emitter<AccountProviderAddedEventParams>;
Expand All @@ -61,7 +61,7 @@ export class AccountManagementService implements IAccountManagementService {
@IOpenerService private _openerService: IOpenerService,
@ILogService private readonly _logService: ILogService,
@INotificationService private readonly _notificationService: INotificationService,
@IConfigurationService configurationService: IConfigurationService
@IConfigurationService private _configurationService: IConfigurationService
) {
this._mementoContext = new Memento(AccountManagementService.ACCOUNT_MEMENTO, this._storageService);
const mementoObj = this._mementoContext.getMemento(StorageScope.GLOBAL, StorageTarget.MACHINE);
Expand All @@ -71,11 +71,10 @@ export class AccountManagementService implements IAccountManagementService {
this._addAccountProviderEmitter = new Emitter<AccountProviderAddedEventParams>();
this._removeAccountProviderEmitter = new Emitter<azdata.AccountProviderMetadata>();
this._updateAccountListEmitter = new Emitter<UpdateAccountListEventParams>();
this.configurationService = configurationService;

// Determine authentication library in use, to support filtering accounts respectively.
// When this value is changed a restart is required so there isn't a need to dynamically update this value at runtime.
this._authLibrary = this.configurationService.getValue('azure.authenticationLibrary');
this._authLibrary = getAuthLibrary(this._configurationService);

_storageService.onWillSaveState(() => this.shutdown());
this.registerListeners();
Expand Down Expand Up @@ -169,7 +168,6 @@ export class AccountManagementService implements IAccountManagementService {
if (result.accountModified) {
this.spliceModifiedAccount(provider, result.changedAccount);
}

this.fireAccountListUpdate(provider, result.accountAdded);
} finally {
notificationHandler.close();
Expand Down Expand Up @@ -519,7 +517,7 @@ export class AccountManagementService implements IAccountManagementService {
});
}

const authLibrary: AuthLibrary = this.configurationService.getValue('azure.authenticationLibrary');
const authLibrary: AuthLibrary = getAuthLibrary(this._configurationService)
let updatedAccounts: azdata.Account[]
if (authLibrary) {
updatedAccounts = filterAccounts(provider.accounts, authLibrary);
Expand All @@ -543,9 +541,9 @@ export class AccountManagementService implements IAccountManagementService {
}

private registerListeners(): void {
this.disposables.add(this.configurationService.onDidChangeConfiguration(async e => {
if (e.affectsConfiguration('azure.authenticationLibrary')) {
const authLibrary: AuthLibrary = this.configurationService.getValue('azure.authenticationLibrary') ?? MSAL_AUTH_LIBRARY;
this.disposables.add(this._configurationService.onDidChangeConfiguration(async e => {
if (e.affectsConfiguration(AZURE_AUTH_LIBRARY_CONFIG)) {
const authLibrary: AuthLibrary = getAuthLibrary(this._configurationService);
let accounts = await this._accountStore.getAllAccounts();
if (accounts) {
let updatedAccounts = await this.filterAndMergeAccounts(accounts, authLibrary);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ const noAccountProvider: azdata.AccountProviderMetadata = {
const account: azdata.Account = {
key: {
providerId: hasAccountProvider.id,
accountId: 'testAccount1'
accountId: 'testAccount1',
authLibrary: 'MSAL'
},
displayInfo: {
displayName: 'Test Account 1',
Expand Down Expand Up @@ -320,7 +321,12 @@ suite('Account Management Service Tests:', () => {
return ams.getAccountsForProvider(hasAccountProvider.id)
.then(result => {
// Then: I should get back the list of accounts
assert.strictEqual(result, accountList);
// Since account are filtered by AuthLibrary and list is prepared again, they are not strict equal.
// We compare strict equality of actual accounts here.
assert.strictEqual(accountList.length, result.length);
for (var i = 0; i < accountList.length; i++) {
assert.strictEqual(result[i], accountList[i]);
}
});
});

Expand Down Expand Up @@ -534,7 +540,8 @@ function getTestState(): AccountManagementState {
const testConfigurationService = new TestConfigurationService();

// Create the account management service
let ams = new AccountManagementService(mockInstantiationService.object, new TestStorageService(), undefined!, undefined!, undefined!, testNotificationService, testConfigurationService);
let ams = new AccountManagementService(mockInstantiationService.object, new TestStorageService(),
undefined, undefined, undefined, testNotificationService, testConfigurationService);

// Wire up event handlers
let evUpdate = new EventVerifierSingle<UpdateAccountListEventParams>();
Expand Down
18 changes: 18 additions & 0 deletions src/sql/workbench/services/accountManagement/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

import { IConfigurationService } from 'vs/platform/configuration/common/configuration';

export const AZURE_AUTH_LIBRARY_CONFIG = 'azure.authenticationLibrary';

export type AuthLibrary = 'ADAL' | 'MSAL';
export const MSAL_AUTH_LIBRARY: AuthLibrary = 'MSAL';
export const ADAL_AUTH_LIBRARY: AuthLibrary = 'ADAL';

export const DEFAULT_AUTH_LIBRARY: AuthLibrary = MSAL_AUTH_LIBRARY;

export function getAuthLibrary(configurationService: IConfigurationService): AuthLibrary {
return configurationService.getValue(AZURE_AUTH_LIBRARY_CONFIG) || DEFAULT_AUTH_LIBRARY;
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ import Severity from 'vs/base/common/severity';
import { ConnectionStringOptions } from 'sql/platform/capabilities/common/capabilitiesService';
import { isFalsyOrWhitespace } from 'vs/base/common/strings';
import { IConfigurationService } from 'vs/platform/configuration/common/configuration';
import { AuthLibrary, filterAccounts } from 'sql/workbench/services/accountManagement/browser/accountDialog';
import { filterAccounts } from 'sql/workbench/services/accountManagement/browser/accountDialog';
import { AuthenticationType, Actions } from 'sql/platform/connection/common/constants';
import { AdsWidget } from 'sql/base/browser/ui/adsWidget';
import { createCSSRule } from 'vs/base/browser/dom';
import { AuthLibrary, getAuthLibrary } from 'sql/workbench/services/accountManagement/utils';

const ConnectionStringText = localize('connectionWidget.connectionString', "Connection string");

Expand Down Expand Up @@ -112,7 +113,6 @@ export class ConnectionWidget extends lifecycle.Disposable {
color: undefined,
description: undefined,
};
private readonly configurationService: IConfigurationService;
constructor(options: azdata.ConnectionOption[],
callbacks: IConnectionComponentCallbacks,
providerName: string,
Expand All @@ -122,7 +122,7 @@ export class ConnectionWidget extends lifecycle.Disposable {
@IAccountManagementService private _accountManagementService: IAccountManagementService,
@ILogService protected _logService: ILogService,
@IErrorMessageService private _errorMessageService: IErrorMessageService,
@IConfigurationService configurationService: IConfigurationService
@IConfigurationService private _configurationService: IConfigurationService
) {
super();
this._callbacks = callbacks;
Expand All @@ -142,7 +142,6 @@ export class ConnectionWidget extends lifecycle.Disposable {
}
this._providerName = providerName;
this._connectionStringOptions = this._connectionManagementService.getProviderProperties(this._providerName).connectionStringOptions;
this.configurationService = configurationService;
}

protected getAuthTypeDefault(option: azdata.ConnectionOption, os: OperatingSystem): string {
Expand Down Expand Up @@ -689,7 +688,7 @@ export class ConnectionWidget extends lifecycle.Disposable {
let oldSelection = this._azureAccountDropdown.value;
const accounts = await this._accountManagementService.getAccounts();
const updatedAccounts = accounts.filter(a => a.key.providerId.startsWith('azure'));
const authLibrary: AuthLibrary = this.configurationService.getValue('azure.authenticationLibrary');
const authLibrary: AuthLibrary = getAuthLibrary(this._configurationService);
if (authLibrary) {
this._azureAccountList = filterAccounts(updatedAccounts, authLibrary);
}
Expand Down