diff --git a/.github/scripts/overwrites/stateful.sh b/.github/scripts/overwrites/stateful.sh index 07780d752..f3806c8bb 100755 --- a/.github/scripts/overwrites/stateful.sh +++ b/.github/scripts/overwrites/stateful.sh @@ -5,7 +5,7 @@ npm pkg set name="$EXTENSION_NAME" npm pkg set displayName="Stateful Notebooks for DevOps" npm pkg set description="DevOps Notebooks built on Runme, connected for collaboration." npm pkg set homepage="https://stateful.com" -npm pkg set contributes.configuration[0].properties[runme.app.baseDomain].default="platform.stateful.com" +npm pkg set contributes.configuration[0].properties[runme.app.baseDomain].default="cloud.stateful.com" npm pkg set contributes.configuration[0].properties[runme.app.platformAuth].default=true --json npm pkg set contributes.configuration[0].properties[runme.server.lifecycleIdentity].default=1 --json npm pkg set contributes.configuration[0].properties[runme.app.notebookAutoSave].default="yes" diff --git a/README-platform.md b/README-platform.md index 661bef5c2..cbeb53202 100644 --- a/README-platform.md +++ b/README-platform.md @@ -45,7 +45,7 @@ The Stateful cloud offers a suite of tools and services designed to enhance your ## Getting Started -1. **Sign Up**: Create your free account at [Stateful](https://platform.stateful.com/). +1. **Sign Up**: Create your free account at [Stateful](https://cloud.stateful.com/). 2. **Install VS Code Extension**: Download the Stateful extension from the [VS Code Marketplace](https://marketplace.visualstudio.com/items?itemName=stateful.platform). The extension is fully compatible with Runme, but adds authentication, collaboration, and security features, making it secure for teams and inside companies. 3. **Explore**: Start creating, running, sharing, and discussing your first DevOps Notebook and its commands using the Stateful cloud. diff --git a/package.json b/package.json index 0fca84df0..134ca1113 100644 --- a/package.json +++ b/package.json @@ -992,7 +992,7 @@ }, "runme.app.baseDomain": { "type": "string", - "default": "platform.stateful.com", + "default": "cloud.stateful.com", "scope": "window", "markdownDescription": "Base domain to be use for Runme app" }, diff --git a/src/extension/api/client.ts b/src/extension/api/client.ts index 924fefb7c..8129bacd3 100644 --- a/src/extension/api/client.ts +++ b/src/extension/api/client.ts @@ -5,21 +5,26 @@ import { Uri } from 'vscode' import { getRunmeAppUrl } from '../../utils/configuration' import { getFeaturesContext } from '../features' +import { StatefulAuthProvider } from '../provider/statefulAuth' -export function InitializeClient({ - uri, - runmeToken, -}: { - uri?: string | undefined - runmeToken: string -}) { +export async function InitializeCloudClient(uri?: string) { + const session = await StatefulAuthProvider.instance.currentSession() + + if (!session) { + throw new Error('You must authenticate with your Stateful account') + } + + return InitializeClient({ uri, token: session.accessToken }) +} + +function InitializeClient({ uri, token }: { uri?: string | undefined; token: string }) { const authLink = setContext((_, { headers }) => { const context = getFeaturesContext() return { headers: { ...headers, 'Auth-Provider': 'platform', - authorization: runmeToken ? `Bearer ${runmeToken}` : '', + authorization: token ? `Bearer ${token}` : '', 'X-Extension-Id': context?.extensionId, 'X-Extension-Os': context?.os, 'X-Extension-Version': context?.extensionVersion, diff --git a/src/extension/authSessionChangeHandler.ts b/src/extension/authSessionChangeHandler.ts new file mode 100644 index 000000000..68717f4e1 --- /dev/null +++ b/src/extension/authSessionChangeHandler.ts @@ -0,0 +1,74 @@ +import { Subject, Subscription } from 'rxjs' +import { debounceTime, distinctUntilChanged } from 'rxjs/operators' +import { authentication, AuthenticationSessionsChangeEvent, Disposable } from 'vscode' + +export default class AuthSessionChangeHandler implements Disposable { + static #instance: AuthSessionChangeHandler | null = null + + #disposables: Disposable[] = [] + #eventSubject: Subject + #subscriptions: Subscription[] = [] + #listeners: ((event: AuthenticationSessionsChangeEvent) => void)[] = [] + + private constructor(private debounceTimeMs: number = 100) { + this.#eventSubject = new Subject() + this.#subscriptions.push( + this.#eventSubject + .pipe(distinctUntilChanged(this.eventComparer), debounceTime(this.debounceTimeMs)) + .subscribe((event) => { + this.notifyListeners(event) + }), + ) + + this.#disposables.push( + authentication.onDidChangeSessions((e) => { + this.#eventSubject.next(e) + }), + ) + } + + public static get instance(): AuthSessionChangeHandler { + if (!this.#instance) { + this.#instance = new AuthSessionChangeHandler() + } + + return this.#instance + } + + public addListener(listener: (event: AuthenticationSessionsChangeEvent) => void): void { + this.#listeners.push(listener) + } + + public removeListener(listener: (event: AuthenticationSessionsChangeEvent) => void): void { + this.#listeners = this.#listeners.filter((l) => l !== listener) + } + + private notifyListeners(event: AuthenticationSessionsChangeEvent): void { + for (const listener of this.#listeners) { + try { + listener(event) + } catch (err) { + console.error('Error in listener:', err) + } + } + } + + private eventComparer( + previous: AuthenticationSessionsChangeEvent, + current: AuthenticationSessionsChangeEvent, + ): boolean { + return ( + previous.provider.id === current.provider.id && + JSON.stringify(previous) === JSON.stringify(current) + ) + } + + public async dispose() { + this.#disposables.forEach((d) => d.dispose()) + this.#subscriptions = [] + this.#eventSubject.complete() + this.#listeners = [] + + AuthSessionChangeHandler.#instance = null + } +} diff --git a/src/extension/commands/index.ts b/src/extension/commands/index.ts index dcb1fd285..e75bb2a98 100644 --- a/src/extension/commands/index.ts +++ b/src/extension/commands/index.ts @@ -32,7 +32,6 @@ import { Kernel } from '../kernel' import { getAnnotations, getNotebookCategories, - getPlatformAuthSession, getTerminalByCell, openFileAsRunmeNotebook, promptUserSession, @@ -52,7 +51,7 @@ import { } from '../../constants' import ContextState from '../contextState' import { createGist } from '../services/github/gist' -import { InitializeClient } from '../api/client' +import { InitializeCloudClient } from '../api/client' import { GetUserEnvironmentsDocument } from '../__generated-platform__/graphql' import { EnvironmentManager } from '../environment/manager' import features from '../features' @@ -562,8 +561,7 @@ export async function createCellGistCommand(cell: NotebookCell, context: Extensi } export async function selectEnvironment(manager: EnvironmentManager) { - const session = await getPlatformAuthSession() - const graphClient = InitializeClient({ runmeToken: session?.accessToken! }) + const graphClient = await InitializeCloudClient() const result = await graphClient.query({ query: GetUserEnvironmentsDocument, diff --git a/src/extension/extension.ts b/src/extension/extension.ts index 0b6cf54af..87104bbb2 100644 --- a/src/extension/extension.ts +++ b/src/extension/extension.ts @@ -9,7 +9,6 @@ import { env, Uri, NotebookCell, - authentication, } from 'vscode' import { TelemetryReporter } from 'vscode-telemetry' import Channel from 'tangle/webviews' @@ -41,8 +40,6 @@ import { getDefaultWorkspace, bootFile, resetNotebookSettings, - getPlatformAuthSession, - getGithubAuthSession, openFileAsRunmeNotebook, } from './utils' import { RunmeTaskProvider } from './provider/runmeTask' @@ -95,6 +92,7 @@ import { EnvironmentManager } from './environment/manager' import ContextState from './contextState' import { RunmeIdentity } from './grpc/serializerTypes' import * as features from './features' +import AuthSessionChangeHandler from './authSessionChangeHandler' export class RunmeExtension { protected serializer?: SerializerBase @@ -218,7 +216,6 @@ export class RunmeExtension { // extension is deactivated. context.subscriptions.push(aiManager) - const uriHandler = new RunmeUriHandler(context, kernel, getForceNewWindowConfig()) const winCodeLensRunSurvey = new survey.SurveyWinCodeLensRun(context) const surveys: Disposable[] = [ winCodeLensRunSurvey, @@ -260,6 +257,8 @@ export class RunmeExtension { serializer, server, treeViewer, + StatefulAuthProvider.instance, + AuthSessionChangeHandler.instance, ...this.registerPanels(kernel, context), ...surveys, workspace.registerNotebookSerializer(Kernel.type, serializer, { @@ -338,7 +337,7 @@ export class RunmeExtension { /** * Uri handler */ - window.registerUriHandler(uriHandler), + window.registerUriHandler(new RunmeUriHandler(context, kernel, getForceNewWindowConfig())), /** * Runme Message Display commands @@ -398,7 +397,7 @@ export class RunmeExtension { commands.executeCommand('runme.lifecycleIdentitySelection', RunmeIdentity.CELL), ), - RunmeExtension.registerCommand( + commands.registerCommand( 'runme.lifecycleIdentitySelection', async (identity?: RunmeIdentity) => { if (identity === undefined) { @@ -412,6 +411,10 @@ export class RunmeExtension { return } + TelemetryReporter.sendTelemetryEvent('extension.command', { + command: 'runme.lifecycleIdentitySelection', + }) + await ContextState.addKey(NOTEBOOK_LIFECYCLE_ID, identity) await Promise.all( @@ -486,71 +489,51 @@ export class RunmeExtension { } if (kernel.isFeatureOn(FeatureName.RequireStatefulAuth)) { - const statefulAuthProvider = new StatefulAuthProvider(context, uriHandler) - context.subscriptions.push(statefulAuthProvider) + const logger = getLogger(FeatureName.RequireStatefulAuth) + try { + const session = await StatefulAuthProvider.instance.ensureSession() + const nunmeIdentity = session ? RunmeIdentity.ALL : getServerLifecycleIdentity() + await commands.executeCommand('runme.lifecycleIdentitySelection', nunmeIdentity) + } catch (error) { + let message + if (error instanceof Error) { + message = error.message + } else { + message = JSON.stringify(error) + } - const session = await getPlatformAuthSession(false, true) - let sessionFromToken = false - if (!session) { - sessionFromToken = await statefulAuthProvider.bootstrapFromToken() + logger.error(message) } - - const forceLogin = kernel.isFeatureOn(FeatureName.ForceLogin) || sessionFromToken - const silent = forceLogin ? undefined : true - - getPlatformAuthSession(forceLogin, silent) - .then((session) => { - if (session) { - statefulAuthProvider.showLoginNotification() - } - }) - .catch((error) => { - let message - if (error instanceof Error) { - message = error.message - } else { - message = JSON.stringify(error) - } - - // https://github.com/microsoft/vscode/blob/main/src/vs/workbench/api/browser/mainThreadAuthentication.ts#L238 - // throw new Error('User did not consent to login.') - // Calling again to ensure User Menu Badge - if (forceLogin && message === 'User did not consent to login.') { - getPlatformAuthSession(false) - } - }) } if (kernel.isFeatureOn(FeatureName.Gist)) { - context.subscriptions.push(new GithubAuthProvider(context)) - getGithubAuthSession(false).then((session) => { - kernel.updateFeatureContext('githubAuth', !!session) - }) + context.subscriptions.push(new GithubAuthProvider(context, kernel)) } - authentication.onDidChangeSessions((e) => { + AuthSessionChangeHandler.instance.addListener((e) => { if ( + StatefulAuthProvider.instance && kernel.isFeatureOn(FeatureName.RequireStatefulAuth) && e.provider.id === AuthenticationProviders.Stateful ) { - getPlatformAuthSession(false, true).then(async (session) => { - if (!!session) { + StatefulAuthProvider.instance.currentSession().then(async (session) => { + if (session) { await commands.executeCommand('runme.lifecycleIdentitySelection', RunmeIdentity.ALL) + kernel.emitPanelEvent('runme.cloud', 'onCommand', { + name: 'signIn', + panelId: 'runme.cloud', + }) } else { const settingsDefault = getServerLifecycleIdentity() await commands.executeCommand('runme.lifecycleIdentitySelection', settingsDefault) + kernel.emitPanelEvent('runme.cloud', 'onCommand', { + name: 'signOut', + panelId: 'runme.cloud', + }) } kernel.updateFeatureContext('statefulAuth', !!session) }) } - if ( - kernel.isFeatureOn(FeatureName.Gist) && - e.provider.id === AuthenticationProviders.GitHub - ) { - getGithubAuthSession(false).then((session) => { - kernel.updateFeatureContext('githubAuth', !!session) - }) - } }) // only ever enabled in hosted playground diff --git a/src/extension/handler/uri.ts b/src/extension/handler/uri.ts index 9c6f7838a..49fc3d644 100644 --- a/src/extension/handler/uri.ts +++ b/src/extension/handler/uri.ts @@ -13,7 +13,6 @@ import { TaskScope, ShellExecution, tasks, - EventEmitter, Disposable, } from 'vscode' import got from 'got' @@ -23,6 +22,7 @@ import { TelemetryReporter } from 'vscode-telemetry' import getLogger from '../logger' import { Kernel } from '../kernel' import { AuthenticationProviders } from '../../constants' +import { StatefulAuthProvider } from '../provider/statefulAuth' import { getProjectDir, @@ -45,8 +45,6 @@ const extensionNames: { [key: string]: string } = { export class RunmeUriHandler implements UriHandler, Disposable { #disposables: Disposable[] = [] - readonly #onAuth = this.register(new EventEmitter()) - readonly onAuthEvent = this.#onAuth.event constructor( private context: ExtensionContext, @@ -70,7 +68,7 @@ export class RunmeUriHandler implements UriHandler, Disposable { command, type: AuthenticationProviders.Stateful, }) - this.#onAuth.fire(uri) + StatefulAuthProvider.instance.fireOnAuthEvent(uri) return } else if (command === 'setup') { const { fileToOpen, repository } = parseParams(params) diff --git a/src/extension/index.ts b/src/extension/index.ts index 416f75b03..098e593f0 100644 --- a/src/extension/index.ts +++ b/src/extension/index.ts @@ -4,6 +4,7 @@ import { TelemetryReporter } from 'vscode-telemetry' import { RunmeExtension } from './extension' import getLogger from './logger' import { isTelemetryEnabled } from './utils' +import { StatefulAuthProvider } from './provider/statefulAuth' declare const CONNECTION_STR: string @@ -33,6 +34,7 @@ export async function activate(context: ExtensionContext) { log.info('Activating Extension') try { + StatefulAuthProvider.initialize(context) await ext.initialize(context) log.info('Extension successfully activated') } catch (err: any) { diff --git a/src/extension/kernel.ts b/src/extension/kernel.ts index 0dc107534..d93c75f3c 100644 --- a/src/extension/kernel.ts +++ b/src/extension/kernel.ts @@ -39,6 +39,8 @@ import { type ExtensionName, type FeatureContext, FeatureName, + SyncSchema, + FeatureObserver, } from '../types' import { ClientMessages, @@ -114,6 +116,7 @@ import { CommandModeEnum } from './grpc/runner/types' import { GrpcReporter } from './reporter' import { EnvStoreMonitorWithSession } from './panels/notebook' import { SignedIn } from './signedIn' +import { StatefulAuthProvider } from './provider/statefulAuth' enum ConfirmationItems { Yes = 'Yes', @@ -150,7 +153,7 @@ export class Kernel implements Disposable { protected panelManager: PanelManager protected serializer?: SerializerBase protected reporter?: GrpcReporter - protected featuresState$? + protected featuresState$?: FeatureObserver protected readonly monitor$ = new Subject() @@ -201,42 +204,44 @@ export class Kernel implements Disposable { this.#onlySignedIn = new SignedIn(this) this.#disposables.push(this.#onlySignedIn) - const packageJSON = context?.extension?.packageJSON || {} - const featContext: FeatureContext = { - os: os.platform(), - vsCodeVersion: version as string, - extensionVersion: packageJSON?.version, - githubAuth: false, - statefulAuth: false, - extensionId: context?.extension?.id as ExtensionName, - } + StatefulAuthProvider.instance.currentSession().then((session) => { + const packageJSON = context?.extension?.packageJSON || {} + const featContext: FeatureContext = { + os: os.platform(), + vsCodeVersion: version as string, + extensionVersion: packageJSON?.version, + githubAuth: false, + statefulAuth: !!session, + extensionId: context?.extension?.id as ExtensionName, + } - const runmeFeatureSettings = workspace.getConfiguration('runme.features') - const featureNames = Object.keys(FeatureName) + const runmeFeatureSettings = workspace.getConfiguration('runme.features') + const featureNames = Object.keys(FeatureName) - featureNames.forEach((feature) => { - if (runmeFeatureSettings.has(feature)) { - const result = runmeFeatureSettings.get(feature, false) - this.#featuresSettings.set(feature, result) - } - }) + featureNames.forEach((feature) => { + if (runmeFeatureSettings.has(feature)) { + const result = runmeFeatureSettings.get(feature, false) + this.#featuresSettings.set(feature, result) + } + }) - this.featuresState$ = features.loadState(packageJSON, featContext, this.#featuresSettings) + this.featuresState$ = features.loadState(packageJSON, featContext, this.#featuresSettings) - if (this.featuresState$) { - const subscription = this.featuresState$ - .pipe(map((_state) => features.getSnapshot(this.featuresState$))) - .subscribe((snapshot) => { - ContextState.addKey(FEATURES_CONTEXT_STATE_KEY, snapshot) - postClientMessage(this.messaging, ClientMessages.featuresUpdateAction, { - snapshot: snapshot, + if (this.featuresState$) { + const subscription = this.featuresState$ + .pipe(map((_state) => features.getSnapshot(this.featuresState$))) + .subscribe((snapshot) => { + ContextState.addKey(FEATURES_CONTEXT_STATE_KEY, snapshot) + postClientMessage(this.messaging, ClientMessages.featuresUpdateAction, { + snapshot: snapshot, + }) }) - }) - this.#disposables.push({ - dispose: () => subscription.unsubscribe(), - }) - } + this.#disposables.push({ + dispose: () => subscription.unsubscribe(), + }) + } + }) } get envProps() { @@ -247,6 +252,21 @@ export class Kernel implements Disposable { return getEnvProps(ext) } + emitPanelEvent( + panelId: string, + eventName: K, + payload: SyncSchema[K], + ) { + const panel = this.panelManager.getPanel(panelId) + + if (!panel) { + log.error(`Panel ${panelId} not found`) + return + } + + panel.getBus()?.emit(eventName, payload) + } + useMonitor() { return this.monitor$.asObservable() } diff --git a/src/extension/messages/platformRequest/createEscalation.ts b/src/extension/messages/platformRequest/createEscalation.ts index 41f1114e1..89595e822 100644 --- a/src/extension/messages/platformRequest/createEscalation.ts +++ b/src/extension/messages/platformRequest/createEscalation.ts @@ -2,8 +2,7 @@ import { TelemetryReporter } from 'vscode-telemetry' import { ClientMessages } from '../../../constants' import { ClientMessage, IApiMessage } from '../../../types' -import { InitializeClient } from '../../api/client' -import { getPlatformAuthSession } from '../../utils' +import { InitializeCloudClient } from '../../api/client' import { postClientMessage } from '../../../utils/messaging' import { CreateEscalationDocument, EscalationStatus } from '../../__generated-platform__/graphql' import { Kernel } from '../../kernel' @@ -21,13 +20,7 @@ export default async function createEscalation( log.info('Creating escalation', message.output.data.id) try { - const session = await getPlatformAuthSession() - - if (!session) { - throw new Error('You must authenticate with your Stateful account') - } - - const graphClient = InitializeClient({ runmeToken: session.accessToken }) + const graphClient = await InitializeCloudClient() const result = await graphClient.mutate({ mutation: CreateEscalationDocument, variables: { diff --git a/src/extension/messages/platformRequest/saveCellExecution.ts b/src/extension/messages/platformRequest/saveCellExecution.ts index dc3ff5506..489d82e8a 100644 --- a/src/extension/messages/platformRequest/saveCellExecution.ts +++ b/src/extension/messages/platformRequest/saveCellExecution.ts @@ -12,9 +12,9 @@ import { postClientMessage } from '../../../utils/messaging' import ContextState from '../../contextState' import { Kernel } from '../../kernel' import getLogger from '../../logger' -import { getAnnotations, getCellRunmeId, getGitContext, getPlatformAuthSession } from '../../utils' +import { getAnnotations, getCellRunmeId, getGitContext } from '../../utils' import { GrpcSerializer } from '../../serializer' -import { InitializeClient } from '../../api/client' +import { InitializeCloudClient } from '../../api/client' import { CreateCellExecutionDocument, CreateCellExecutionMutation, @@ -25,6 +25,7 @@ import { } from '../../__generated-platform__/graphql' import { Frontmatter } from '../../grpc/serializerTypes' import { getCellById } from '../../cell' +import { StatefulAuthProvider } from '../../provider/statefulAuth' export type APIRequestMessage = IApiMessage> const log = getLogger('SaveCell') @@ -39,10 +40,13 @@ export default async function saveCellExecution( try { const autoSaveIsOn = ContextState.getKey(NOTEBOOK_AUTOSAVE_ON) const forceLogin = kernel.isFeatureOn(FeatureName.ForceLogin) - const silent = forceLogin ? undefined : true - const createIfNone = !message.output.data.isUserAction && autoSaveIsOn ? false : true - const session = await getPlatformAuthSession(createIfNone && forceLogin, silent) + let session = await StatefulAuthProvider.instance.currentSession() + + if (!session && forceLogin) { + session = await StatefulAuthProvider.instance.newSession() + } + if (!session && message.output.data.isUserAction) { await commands.executeCommand('runme.openCloudPanel') return postClientMessage(messaging, ClientMessages.platformApiResponse, { @@ -53,7 +57,7 @@ export default async function saveCellExecution( }) } - const graphClient = InitializeClient({ runmeToken: session?.accessToken! }) + const graphClient = await InitializeCloudClient() const path = editor.notebook.uri.fsPath const gitCtx = await getGitContext(path) diff --git a/src/extension/messages/platformRequest/trackRunmeEvent.ts b/src/extension/messages/platformRequest/trackRunmeEvent.ts index 940d88fde..0f37c341e 100644 --- a/src/extension/messages/platformRequest/trackRunmeEvent.ts +++ b/src/extension/messages/platformRequest/trackRunmeEvent.ts @@ -2,8 +2,7 @@ import { TelemetryReporter } from 'vscode-telemetry' import { ClientMessages } from '../../../constants' import { ClientMessage, IApiMessage } from '../../../types' -import { InitializeClient } from '../../api/client' -import { getPlatformAuthSession } from '../../utils' +import { InitializeCloudClient } from '../../api/client' import { RunmeEventInput, RunmeEventInputType, @@ -24,13 +23,7 @@ export default async function trackRunmeEvent( log.info('Sending Runme event') try { - const session = await getPlatformAuthSession() - - if (!session) { - throw new Error('You must authenticate with your Stateful account') - } - - const graphClient = InitializeClient({ runmeToken: session.accessToken }) + const graphClient = await InitializeCloudClient() let input: RunmeEventInput | null = null diff --git a/src/extension/messages/platformRequest/updateCellExecution.ts b/src/extension/messages/platformRequest/updateCellExecution.ts index a61640583..6c0555039 100644 --- a/src/extension/messages/platformRequest/updateCellExecution.ts +++ b/src/extension/messages/platformRequest/updateCellExecution.ts @@ -2,8 +2,7 @@ import { TelemetryReporter } from 'vscode-telemetry' import { ClientMessages } from '../../../constants' import { ClientMessage, IApiMessage } from '../../../types' -import { InitializeClient } from '../../api/client' -import { getPlatformAuthSession } from '../../utils' +import { InitializeCloudClient } from '../../api/client' import { postClientMessage } from '../../../utils/messaging' import { ShareType, UpdateCellOutputDocument } from '../../__generated-platform__/graphql' import { Kernel } from '../../kernel' @@ -21,13 +20,7 @@ export default async function updateCellExecution( log.info('Updating cell execution', message.output.data.id) try { - const session = await getPlatformAuthSession() - - if (!session) { - throw new Error('You must authenticate with your Stateful account') - } - - const graphClient = InitializeClient({ runmeToken: session.accessToken }) + const graphClient = await InitializeCloudClient() const result = await graphClient.mutate({ mutation: UpdateCellOutputDocument, variables: { diff --git a/src/extension/panels/cloud.ts b/src/extension/panels/cloud.ts index 0c54b5ab8..d5294ba8f 100644 --- a/src/extension/panels/cloud.ts +++ b/src/extension/panels/cloud.ts @@ -1,12 +1,12 @@ -import { ExtensionContext, WebviewView, window, ColorThemeKind, Uri, env } from 'vscode' +import { ExtensionContext, WebviewView, window, ColorThemeKind, Uri, env, Webview } from 'vscode' -import { fetchStaticHtml, resolveAppToken } from '../utils' -import { IAppToken } from '../services/runme' +import { fetchStaticHtml } from '../utils' import { getRunmeAppUrl, getRunmePanelIdentifier } from '../../utils/configuration' import getLogger from '../logger' import { type SyncSchemaBus } from '../../types' import archiveCell from '../services/archiveCell' import unArchiveCell from '../services/unArchiveCell' +import { StatefulAuthProvider } from '../provider/statefulAuth' import { TanglePanel } from './base' @@ -24,6 +24,7 @@ const log = getLogger('CloudPanel') export default class CloudPanel extends TanglePanel { protected readonly appUrl: string = getRunmeAppUrl(['app']) protected readonly defaultUx: DefaultUx = 'panels' + protected currentWebview?: Webview constructor( protected readonly context: ExtensionContext, @@ -32,8 +33,22 @@ export default class CloudPanel extends TanglePanel { super(context, getRunmePanelIdentifier(identifier)) } - public async getAppToken(createIfNone: boolean = true): Promise { - return resolveAppToken(createIfNone) + public async getAppToken(createIfNone: boolean = false): Promise { + const session = await StatefulAuthProvider.instance.currentSession() + + if (session) { + return session.accessToken + } + + if (createIfNone) { + const session = await StatefulAuthProvider.instance.newSession() + + if (session) { + return session.accessToken + } + } + + return null } public hydrateHtml(html: string, payload: InitPayload) { @@ -55,27 +70,33 @@ export default class CloudPanel extends TanglePanel { log.trace(`${this.identifier} webview resolving`) - const html = await this.getHydratedHtml() + let appToken: string | null + try { + appToken = await this.getAppToken() + } catch (err: any) { + log.error(err?.message || err) + appToken = null + } + + const html = await this.getHydratedHtml(appToken) webviewView.webview.html = html webviewView.webview.options = { ...webviewOptions, localResourceRoots: [this.context.extensionUri], } + this.webview.subscribe((webview) => { + this.currentWebview = webview + }) + this.webview.next(webviewView.webview) log.trace(`${this.identifier} webview resolved`) return Promise.resolve() } - private async getHydratedHtml(): Promise { - let appToken: string | null + private async getHydratedHtml(appToken: string | null): Promise { let staticHtml: string - try { - appToken = await this.getAppToken(false).then((appToken) => appToken?.token ?? null) - } catch (err: any) { - log.error(err?.message || err) - appToken = null - } + try { staticHtml = await fetchStaticHtml(this.appUrl).then((r) => r.text()) } catch (err: any) { @@ -94,10 +115,11 @@ export default class CloudPanel extends TanglePanel { protected registerSubscribers(bus: SyncSchemaBus) { return [ bus.on('onCommand', (cmdEvent) => { - if (cmdEvent?.name !== 'signIn') { - return + if (cmdEvent?.name === 'signIn') { + this.onSignIn(bus) + } else if (cmdEvent?.name === 'signOut') { + this.onSignOut(bus) } - this.onSignIn(bus) }), bus.on('onArchiveCell', async (cmdEvent) => { const answer = await window.showInformationMessage( @@ -150,7 +172,21 @@ export default class CloudPanel extends TanglePanel { private async onSignIn(bus: SyncSchemaBus) { try { const appToken = await this.getAppToken(true) - bus.emit('onAppToken', appToken!) + bus.emit('onAppToken', { token: appToken ?? 'EMPTY' }) + if (this.currentWebview) { + this.currentWebview.html = await this.getHydratedHtml(appToken) + } + } catch (err: any) { + log.error(err?.message || err) + } + } + + private async onSignOut(bus: SyncSchemaBus) { + try { + bus.emit('onAppToken', { token: 'EMPTY' }) + if (this.currentWebview) { + this.currentWebview.html = await this.getHydratedHtml(null) + } } catch (err: any) { log.error(err?.message || err) } diff --git a/src/extension/provider/githubAuth.ts b/src/extension/provider/githubAuth.ts index c6ccfe499..006789366 100644 --- a/src/extension/provider/githubAuth.ts +++ b/src/extension/provider/githubAuth.ts @@ -1,16 +1,48 @@ -import { Disposable, ExtensionContext, authentication } from 'vscode' +import { + authentication, + AuthenticationGetSessionOptions, + Disposable, + ExtensionContext, +} from 'vscode' -import { checkSession } from '../utils' -import { GITHUB_USER_SIGNED_IN } from '../../constants' +import { AuthenticationProviders, GITHUB_USER_SIGNED_IN } from '../../constants' import ContextState from '../contextState' +import AuthSessionChangeHandler from '../authSessionChangeHandler' +import { Kernel } from '../kernel' export class GithubAuthProvider implements Disposable { - constructor(context: ExtensionContext) { + constructor( + readonly context: ExtensionContext, + readonly kernel: Kernel, + ) { const userSignedIn = context.globalState.get(GITHUB_USER_SIGNED_IN, false) ContextState.addKey(GITHUB_USER_SIGNED_IN, userSignedIn) - authentication.onDidChangeSessions(() => { - checkSession(context) + AuthSessionChangeHandler.instance.addListener((e) => { + if (e.provider.id === AuthenticationProviders.GitHub) { + this.checkGithubSession() + } }) } + + async checkGithubSession() { + const session = await getGithubAuthSession(false, true) + this.context.globalState.update(GITHUB_USER_SIGNED_IN, !!session) + ContextState.addKey(GITHUB_USER_SIGNED_IN, !!session) + this.kernel.updateFeatureContext('githubAuth', !!session) + } + dispose() {} } + +export async function getGithubAuthSession(createIfNone: boolean = true, silent?: boolean) { + const scope = ['user:email'] + const options: AuthenticationGetSessionOptions = {} + + if (silent !== undefined) { + options.silent = silent + } else { + options.createIfNone = createIfNone + } + + return await authentication.getSession(AuthenticationProviders.GitHub, scope, options) +} diff --git a/src/extension/provider/statefulAuth.ts b/src/extension/provider/statefulAuth.ts index 95b243a36..5b2af2f7a 100644 --- a/src/extension/provider/statefulAuth.ts +++ b/src/extension/provider/statefulAuth.ts @@ -14,6 +14,7 @@ import { AuthenticationProviderAuthenticationSessionsChangeEvent, Event, workspace, + AuthenticationGetSessionOptions, } from 'vscode' import { v4 as uuidv4 } from 'uuid' import fetch from 'node-fetch' @@ -21,21 +22,23 @@ import jwt, { JwtPayload } from 'jsonwebtoken' import { getAuthTokenPath, getDeleteAuthToken, getRunmeAppUrl } from '../../utils/configuration' import { AuthenticationProviders, PLATFORM_USER_SIGNED_IN, TELEMETRY_EVENTS } from '../../constants' -import { RunmeUriHandler } from '../handler/uri' import ContextState from '../contextState' import getLogger from '../logger' +import { FeatureName } from '../../types' +import * as features from '../features' const logger = getLogger('StatefulAuthProvider') const AUTH_NAME = 'Stateful' const SESSIONS_SECRET_KEY = `${AuthenticationProviders.Stateful}.sessions` +export const DEFAULT_SCOPES = ['profile'] interface TokenInformation { accessToken: string expiresIn: number } -interface StatefulAuthSession extends AuthenticationSession { +export interface StatefulAuthSession extends AuthenticationSession { expiresIn: number isExpired: boolean } @@ -60,52 +63,131 @@ interface PromiseAdapter { const passthrough = (value: any, resolve: (value?: any) => void) => resolve(value) +type SessionsChangeEvent = AuthenticationProviderAuthenticationSessionsChangeEvent + export class StatefulAuthProvider implements AuthenticationProvider, Disposable { - #disposables: Disposable[] = [] - // used as compound key in a hash-table; does not contain sensitive data - #insensitiveHashedApiUrl: string = crypto - .createHash('sha1') - .update(getRunmeAppUrl(['api'])) - .digest('hex') + static #instance: StatefulAuthProvider | null = null + static #context: ExtensionContext | null = null + #pendingStates: string[] = [] #codeVerfifiers = new Map() #scopes = new Map() - #uriHandler: RunmeUriHandler + #disposables: Disposable[] = [] #codeExchangePromises = new Map< string, { promise: Promise; cancel: EventEmitter } >() + readonly #onSessionChange = this.registerDisposable(new EventEmitter()) + readonly #onAuthEvent = this.registerDisposable(new EventEmitter()) + + public static get instance(): StatefulAuthProvider { + this.assertContext(this.#context) + + if (!StatefulAuthProvider.#instance) { + const instance = new StatefulAuthProvider() + const disposable = authentication.registerAuthenticationProvider( + AuthenticationProviders.Stateful, + AUTH_NAME, + instance, + ) + instance.registerDisposable(disposable) + StatefulAuthProvider.#instance = instance + } + + return StatefulAuthProvider.#instance + } + + static initialize(context: ExtensionContext) { + this.#context = context + } + + static assertContext( + context: ExtensionContext | null, + ): asserts context is NonNullable { + if (!context) { + throw new Error('Missing context dependency, requires StatefulAuthProvider.initialize') + } + } - readonly #onSessionChange = this.register( - new EventEmitter(), - ) - - constructor( - private readonly context: ExtensionContext, - uriHandler: RunmeUriHandler, - ) { - this.#uriHandler = uriHandler - this.#disposables.push( - Disposable.from( - authentication.registerAuthenticationProvider( - AuthenticationProviders.Stateful, - AUTH_NAME, - this, - { - supportsMultipleAccounts: false, - }, - ), - ), + async newSession(silent?: boolean): Promise { + const options: AuthenticationGetSessionOptions = {} + + if (silent !== undefined) { + options.silent = silent + } else { + options.createIfNone = true + } + + const session = await authentication.getSession( + AuthenticationProviders.Stateful, + DEFAULT_SCOPES, + options, ) + + return session as StatefulAuthSession | undefined + } + + async currentSession() { + const sessions = await this.getSessions(this.getScopes()) + if (!sessions.length) { + return + } + + return sessions[0] + } + + async ensureSession(): Promise { + let session = await this.currentSession() + if (session) { + StatefulAuthProvider.showLoginNotification() + return session + } + + session = await StatefulAuthProvider.bootstrapFromToken() + const forceLogin = features.isOnInContextState(FeatureName.ForceLogin) || !!session + + const silent = forceLogin ? undefined : true + + return this.newSession(silent) + .then(() => { + if (session) { + StatefulAuthProvider.showLoginNotification() + return session + } + }) + .catch((error) => { + let message + if (error instanceof Error) { + message = error.message + } else { + message = JSON.stringify(error) + } + + logger.error(message) + + // https://github.com/microsoft/vscode/blob/main/src/vs/workbench/api/browser/mainThreadAuthentication.ts#L238 + // throw new Error('User did not consent to login.') + // Calling again to ensure User Menu Badge + if (forceLogin && message === 'User did not consent to login.') { + authentication.getSession(AuthenticationProviders.Stateful, DEFAULT_SCOPES, {}) + } + return error + }) } get onDidChangeSessions() { return this.#onSessionChange.event } + fireOnAuthEvent(data: Uri) { + this.#onAuthEvent.fire(data) + } + get redirectUri() { - const publisher = this.context.extension.packageJSON.publisher - const name = this.context.extension.packageJSON.name + StatefulAuthProvider.assertContext(StatefulAuthProvider.#context) + + const publisher = StatefulAuthProvider.#context.extension.packageJSON.publisher + const name = StatefulAuthProvider.#context.extension.packageJSON.name let callbackUrl = `${env.uriScheme}://${publisher}.${name}` return callbackUrl @@ -116,9 +198,9 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable * @param scopes * @returns */ - public async getSessions(scopes?: string[]): Promise { + public async getSessions(scopes?: string[]): Promise { try { - const sessions = await this.getAllSessions() + const sessions = await StatefulAuthProvider.getAllSessions() if (!sessions.length) { return [] } @@ -148,8 +230,15 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable await ContextState.addKey(PLATFORM_USER_SIGNED_IN, false) await this.removeSession(session.id) } - } catch (e) { + } catch (error) { // Nothing to do + let message + if (error instanceof Error) { + message = error.message + } else { + message = JSON.stringify(error) + } + logger.error(message) } return [] @@ -160,15 +249,16 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable * @param scopes * @returns */ - public async createSession(scopes: string[]): Promise { + async createSession(scopes: string[]): Promise { try { - const { accessToken, expiresIn } = await this.login(scopes) + const { accessToken, expiresIn } = await StatefulAuthProvider.instance.login(scopes) if (!accessToken) { throw new Error('Stateful login failure') } - const userinfo: { name: string; email: string } = await this.getUserInfo(accessToken) + const userinfo: { name: string; email: string } = + await StatefulAuthProvider.getUserInfo(accessToken) const session: StatefulAuthSession = { id: uuidv4(), expiresIn: secsToUnixTime(expiresIn), @@ -182,7 +272,12 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable } await ContextState.addKey(PLATFORM_USER_SIGNED_IN, true) - await this.persistSessions([session], { added: [session], removed: [], changed: [] }) + + const persist = this.persistSessions([session], { + added: [session], + } as unknown as SessionsChangeEvent) + await persist + return session } catch (e) { window.showErrorMessage(`Sign in failed: ${e}`) @@ -195,7 +290,7 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable * @param sessionId */ public async removeSession(sessionId: string): Promise { - const sessions = await this.getAllSessions() + const sessions = await StatefulAuthProvider.getAllSessions() if (!sessions.length) { return } @@ -208,7 +303,9 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable const session = sessions[sessionIdx] sessions.splice(sessionIdx, 1) - await this.persistSessions(sessions, { added: [], removed: [session], changed: [] }) + await this.persistSessions(sessions, { + removed: [session], + } as unknown as SessionsChangeEvent) } /** @@ -216,20 +313,26 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable */ public async dispose() { this.#disposables.forEach((d) => d.dispose()) + this.#disposables = [] + StatefulAuthProvider.#instance = null } - public async bootstrapFromToken(): Promise { + public static async bootstrapFromToken(): Promise { try { - const authTokenUri = await this.getAuthTokenUri() + const authTokenUri = await this.instance.getAuthTokenUri() if (!authTokenUri) { logger.info('No auth token file found, halting bootstrap from token.') - return false + return } const { token, payload } = await this.insecureDecode(authTokenUri) const session = await this.buildSession(token, payload) - await this.persistSessions([session], { added: [session], removed: [], changed: [] }) + await this.instance.persistSessions([session], { + added: [session], + removed: undefined, + changed: undefined, + }) await this.deleteAuthTokenFile(authTokenUri) - return true + return session } catch (error) { let message if (error instanceof Error) { @@ -239,7 +342,6 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable } logger.error(message) } - return false } private async getAuthTokenUri(): Promise { @@ -264,7 +366,7 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable /** * Decode a JWT token without verifying its signature. */ - private async insecureDecode(authTokenUri: Uri) { + private static async insecureDecode(authTokenUri: Uri) { const bytes = await workspace.fs.readFile(authTokenUri) if (!bytes?.length) { throw new Error('Failed to read token file') @@ -279,7 +381,7 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable return { payload, token } } - private async buildSession(token: string, payload: DecodedToken) { + private static async buildSession(token: string, payload: DecodedToken) { if (!payload.exp || !payload.scope) { throw new Error('Invalid token format, missing exp or scope') } @@ -304,7 +406,7 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable return session } - private async deleteAuthTokenFile(authTokenUri: Uri) { + private static async deleteAuthTokenFile(authTokenUri: Uri) { if (getDeleteAuthToken()) { logger.info(`Deleting authToken file ${authTokenUri}`) await workspace.fs.delete(authTokenUri) @@ -361,10 +463,7 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable if (!codeExchangePromise) { // Creating a new codeExchangePromise using promiseFromEvent and setting up // event handling with handleUri function - codeExchangePromise = promiseFromEvent( - this.#uriHandler.onAuthEvent, - this.handleUri(scopes), - ) + codeExchangePromise = promiseFromEvent(this.#onAuthEvent.event, this.handleUri(scopes)) // Storing the newly created codeExchangePromise in the map with the corresponding scopeString this.#codeExchangePromises.set(scopeString, codeExchangePromise) } @@ -464,7 +563,7 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable * @param token * @returns */ - private async getUserInfo(token: string) { + private static async getUserInfo(token: string) { const response = await fetch(`${getRunmeAppUrl(['api'])}idp-user-info`, { headers: { Authorization: `Bearer ${token}`, @@ -516,12 +615,22 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable return currentTime < oneHourBeforeExpiration } - private get sessionSecretKey() { - return `${SESSIONS_SECRET_KEY}.${this.#insensitiveHashedApiUrl}` + public static get insensitiveHashedApiUrl() { + // used as compound key in a hash-table; does not contain sensitive data + return crypto + .createHash('sha1') + .update(getRunmeAppUrl(['api'])) + .digest('hex') + } + + public static get sessionSecretKey() { + return `${SESSIONS_SECRET_KEY}.${this.insensitiveHashedApiUrl}` } - private async getAllSessions(): Promise { - const allSessions = await this.context.secrets.get(this.sessionSecretKey) + private static async getAllSessions(): Promise { + this.assertContext(this.#context) + + const allSessions = await this.#context.secrets.get(this.sessionSecretKey) if (!allSessions) { return [] } @@ -538,19 +647,16 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable return sessions.findIndex((s) => s.id === id) } - private async persistSessions( - sessions: StatefulAuthSession[], - changes: { - added: StatefulAuthSession[] - removed: StatefulAuthSession[] - changed: StatefulAuthSession[] - }, - ) { - await this.context.secrets.store(this.sessionSecretKey, JSON.stringify(sessions)) - this.#onSessionChange.fire(changes) + private async persistSessions(sessions: StatefulAuthSession[], changes: SessionsChangeEvent) { + StatefulAuthProvider.assertContext(StatefulAuthProvider.#context) + await StatefulAuthProvider.#context.secrets.store( + StatefulAuthProvider.sessionSecretKey, + JSON.stringify(sessions), + ) + StatefulAuthProvider.instance.#onSessionChange.fire(changes) } - protected register(disposable: T): T { + protected registerDisposable(disposable: T): T { this.#disposables.push(disposable) return disposable } @@ -563,35 +669,40 @@ export class StatefulAuthProvider implements AuthenticationProvider, Disposable } if (this.isTokenNotExpired(session.expiresIn)) { - // Emit a 'session changed' event to notify that the token has been accessed. - // This ensures that any components listening for session changes are notified appropriately. - this.#onSessionChange.fire({ added: [], removed: [], changed: [session] }) - ContextState.addKey(PLATFORM_USER_SIGNED_IN, true) + await ContextState.addKey(PLATFORM_USER_SIGNED_IN, true) return session } return { ...session, isExpired: true } } - showLoginNotification() { - if (!this.context.globalState.get(TELEMETRY_EVENTS.OpenWorkspace, true)) { + static showLoginNotification() { + this.assertContext(this.#context) + + if (!this.#context.globalState.get(TELEMETRY_EVENTS.OpenWorkspace, true)) { return } const openWorkspace = 'Open Workspace' const dontAskAgain = "Don't ask again" + const informationMessageCallback: ( + answer: typeof openWorkspace | typeof dontAskAgain | undefined, + ) => void = (answer) => { + this.assertContext(this.#context) + + if (answer === openWorkspace) { + const dashboardUri = getRunmeAppUrl(['app']) + const uri = Uri.parse(dashboardUri) + env.openExternal(uri) + } else if (answer === dontAskAgain) { + this.#context.globalState.update(TELEMETRY_EVENTS.OpenWorkspace, false) + } + } + window .showInformationMessage('Logged into the Stateful Cloud', openWorkspace, dontAskAgain) - .then((answer) => { - if (answer === openWorkspace) { - const dashboardUri = getRunmeAppUrl(['app']) - const uri = Uri.parse(dashboardUri) - env.openExternal(uri) - } else if (answer === dontAskAgain) { - this.context.globalState.update(TELEMETRY_EVENTS.OpenWorkspace, false) - } - }) + .then(informationMessageCallback) } } diff --git a/src/extension/services/archiveCell.ts b/src/extension/services/archiveCell.ts index b6b086801..3593c81b9 100644 --- a/src/extension/services/archiveCell.ts +++ b/src/extension/services/archiveCell.ts @@ -1,13 +1,11 @@ import { TelemetryReporter } from 'vscode-telemetry' -import { InitializeClient } from '../api/client' -import { resolveAuthToken } from '../utils' +import { InitializeCloudClient } from '../api/client' import { ArchiveCellOutputDocument } from '../__generated-platform__/graphql' export default async function archiveCell(cellId: string): Promise { try { - const token = await resolveAuthToken() - const graphClient = InitializeClient({ runmeToken: token }) + const graphClient = await InitializeCloudClient() await graphClient.mutate({ mutation: ArchiveCellOutputDocument, variables: { diff --git a/src/extension/services/unArchiveCell.ts b/src/extension/services/unArchiveCell.ts index 4d7bf2fc4..6952a49fc 100644 --- a/src/extension/services/unArchiveCell.ts +++ b/src/extension/services/unArchiveCell.ts @@ -1,13 +1,11 @@ import { TelemetryReporter } from 'vscode-telemetry' -import { InitializeClient } from '../api/client' -import { resolveAuthToken } from '../utils' +import { InitializeCloudClient } from '../api/client' import { UnArchiveCellOutputDocument } from '../__generated-platform__/graphql' export default async function unArchiveCell(cellId: string): Promise { try { - const token = await resolveAuthToken() - const graphClient = InitializeClient({ runmeToken: token }) + const graphClient = await InitializeCloudClient() await graphClient.mutate({ mutation: UnArchiveCellOutputDocument, variables: { diff --git a/src/extension/signedIn.ts b/src/extension/signedIn.ts index cc28a8774..c70b4968d 100644 --- a/src/extension/signedIn.ts +++ b/src/extension/signedIn.ts @@ -1,4 +1,4 @@ -import { Disposable, NotebookCell, NotebookEditor, authentication } from 'vscode' +import { Disposable, NotebookCell, NotebookEditor } from 'vscode' import { mergeMap, withLatestFrom } from 'rxjs/operators' import { Observable, Subject, Subscription, from, of } from 'rxjs' @@ -7,10 +7,11 @@ import { ClientMessages } from '../constants' import { GrpcSerializer } from './serializer' import { Kernel } from './kernel' -import { getPlatformAuthSession } from './utils' import './wasm/wasm_exec.js' import { RunmeEventInputType } from './__generated-platform__/graphql' import getLogger from './logger' +import { StatefulAuthProvider } from './provider/statefulAuth' +import AuthSessionChangeHandler from './authSessionChangeHandler' export interface CellRun { cell: { id: any } @@ -30,8 +31,9 @@ export class SignedIn implements Disposable { constructor(protected readonly kernel: Kernel) { const signedIn$ = new Observable((observer) => { - authentication.onDidChangeSessions(() => { - getPlatformAuthSession(false) + AuthSessionChangeHandler.instance.addListener(() => { + StatefulAuthProvider.instance + .currentSession() .then((session) => observer.next(!!session)) .catch(() => observer.next(false)) }) diff --git a/src/extension/utils.ts b/src/extension/utils.ts index 7dd050e81..9055b4fb4 100644 --- a/src/extension/utils.ts +++ b/src/extension/utils.ts @@ -18,9 +18,6 @@ import vscode, { commands, WorkspaceFolder, ExtensionContext, - authentication, - AuthenticationSession, - AuthenticationGetSessionOptions, } from 'vscode' import { v5 as uuidv5 } from 'uuid' import getPort from 'get-port' @@ -38,12 +35,10 @@ import { } from '../types' import { SafeCellAnnotationsSchema, CellAnnotationsSchema } from '../schema' import { - AuthenticationProviders, NOTEBOOK_AVAILABLE_CATEGORIES, SERVER_ADDRESS, CATEGORY_SEPARATOR, NOTEBOOK_AUTOSAVE_ON, - GITHUB_USER_SIGNED_IN, NOTEBOOK_OUTPUTS_MASKED, NOTEBOOK_LIFECYCLE_ID, } from '../constants' @@ -74,6 +69,7 @@ import ContextState from './contextState' import { GCPResolver } from './resolvers/gcpResolver' import { AWSResolver } from './resolvers/awsResolver' import { RunmeIdentity } from './grpc/serializerTypes' +import { StatefulAuthProvider } from './provider/statefulAuth' declare var globalThis: any @@ -564,48 +560,6 @@ export function convertEnvList(envs: string[]): Record, ) } - -export function getGithubAuthSession(createIfNone: boolean = true) { - return authentication.getSession(AuthenticationProviders.GitHub, ['user:email'], { - createIfNone, - }) -} - -export async function getPlatformAuthSession(createIfNone: boolean = true, silent?: boolean) { - const scopes = ['profile'] - const options: AuthenticationGetSessionOptions = {} - - if (silent !== undefined) { - options.silent = silent - } else { - options.createIfNone = createIfNone - } - - return await authentication.getSession(AuthenticationProviders.Stateful, scopes, options) -} - -export async function resolveAuthToken(createIfNone: boolean = true) { - let session: AuthenticationSession | undefined - session = await getPlatformAuthSession(createIfNone) - if (!session) { - throw new Error('You must authenticate with your Stateful account') - } - - return session.accessToken -} - -export async function resolveAppToken(createIfNone: boolean = true) { - if (features.isOnInContextState(FeatureName.RequireStatefulAuth)) { - const session = await getPlatformAuthSession(createIfNone) - if (!session) { - return null - } - return { token: session.accessToken } - } - - return null -} - export function fetchStaticHtml(appUrl: string) { return fetch(appUrl) } @@ -730,15 +684,13 @@ export function asWorkspaceRelativePath(documentPath: string): { /** * Handles the first time experience for saving a cell. - * It informs the user that a Login with a GitHub account is required before prompting the user. + * It informs the user that a Login with a Stateful account is required before prompting the user. * This only happens once. Subsequent saves will not display the prompt. * @returns AuthenticationSession */ export async function promptUserSession() { const createIfNone = features.isOnInContextState(FeatureName.ForceLogin) - const silent = createIfNone ? undefined : true - - const session = await getPlatformAuthSession(false, silent) + const session = await StatefulAuthProvider.instance.currentSession() const displayLoginPrompt = getLoginPrompt() && createIfNone && features.isOnInContextState(FeatureName.Share) @@ -759,36 +711,10 @@ export async function promptUserSession() { return commands.executeCommand('runme.openSettings', 'runme.app.loginPrompt') } - getPlatformAuthSession(createIfNone) - .then((session) => { - if (!session) { - throw new Error('You must authenticate with your Stateful account') - } - }) - .catch((error) => { - let message - if (error instanceof Error) { - message = error.message - } else { - message = String(error) - } - - // https://github.com/microsoft/vscode/blob/main/src/vs/workbench/api/browser/mainThreadAuthentication.ts#L238 - // throw new Error('User did not consent to login.') - // Calling again to ensure User Menu Badge - if (createIfNone && message === 'User did not consent to login.') { - getPlatformAuthSession(false) - } - }) + StatefulAuthProvider.instance.ensureSession() } } -export async function checkSession(context: ExtensionContext) { - const session = await getGithubAuthSession(false) - context.globalState.update(GITHUB_USER_SIGNED_IN, !!session) - ContextState.addKey(GITHUB_USER_SIGNED_IN, !!session) -} - export function editJsonc( originalText: string, propertyToUpdate: string, diff --git a/src/features.ts b/src/features.ts index 1a8637c54..96d91d4e9 100644 --- a/src/features.ts +++ b/src/features.ts @@ -101,60 +101,37 @@ function isActive( } = feature.conditions if (!checkEnabled(feature.enabled, overrides.get(featureName))) { - console.log(`Feature "${featureName}" is inactive due to checkEnabled.`) return false } if (!checkOS(os, context?.os)) { - console.log( - `Feature "${featureName}" is inactive due to checkOS. Expected OS: ${os}, actual OS: ${context?.os}`, - ) return false } if (!checkVersion(vsCodeVersion, context?.vsCodeVersion)) { - console.log( - `Feature "${featureName}" is inactive due to checkVersion (vsCodeVersion). Expected: ${vsCodeVersion}, actual: ${context?.vsCodeVersion}`, - ) return false } if (!checkVersion(runmeVersion, context?.runmeVersion)) { - console.log( - `Feature "${featureName}" is inactive due to checkVersion (runmeVersion). Expected: ${runmeVersion}, actual: ${context?.runmeVersion}`, - ) return false } if (!checkVersion(extensionVersion, context?.extensionVersion)) { - console.log( - `Feature "${featureName}" is inactive due to checkVersion (extensionVersion). Expected: ${extensionVersion}, actual: ${context?.extensionVersion}`, - ) return false } if (!checkAuth(githubAuthRequired, context?.githubAuth)) { - console.log( - `Feature "${featureName}" is inactive due to checkAuth (githubAuth). Required: ${githubAuthRequired}, actual: ${context?.githubAuth}`, - ) return false } if (!checkAuth(statefulAuthRequired, context?.statefulAuth)) { - console.log( - `Feature "${featureName}" is inactive due to checkAuth (statefulAuth). Required: ${statefulAuthRequired}, actual: ${context?.statefulAuth}`, - ) return false } if (!checkExtensionId(enabledForExtensions, context?.extensionId)) { - console.log( - `Feature "${featureName}" is inactive due to checkExtensionId. Expected: ${JSON.stringify(enabledForExtensions)}, actual: ${context?.extensionId}`, - ) return false } - console.log(`Feature "${featureName} is active`) return true } diff --git a/src/utils/configuration.ts b/src/utils/configuration.ts index 0e4f5f163..8ed680abb 100644 --- a/src/utils/configuration.ts +++ b/src/utils/configuration.ts @@ -20,9 +20,9 @@ const APP_SECTION_NAME = 'runme.app' export const OpenViewInEditorAction = z.enum(['split', 'toggle']) const DEFAULT_WORKSPACE_FILE_ORDER = ['.env.local', '.env'] -const DEFAULT_RUNME_APP_API_URL = 'https://platform.stateful.com' -const DEFAULT_RUNME_BASE_DOMAIN = 'platform.stateful.com' -const DEFAULT_RUNME_REMOTE_DEV = 'staging.platform.stateful.com' +const DEFAULT_RUNME_APP_API_URL = 'https://cloud.stateful.com' +const DEFAULT_RUNME_BASE_DOMAIN = 'cloud.stateful.com' +const DEFAULT_RUNME_REMOTE_DEV = 'staging.cloud.stateful.com' const DEFAULT_DOCS_URL = 'https://docs.runme.dev' const APP_LOOPBACKS = ['127.0.0.1', 'localhost'] const APP_LOOPBACK_MAPPING = new Map([ diff --git a/tests/extension/commands/index.test.ts b/tests/extension/commands/index.test.ts index 825d35621..42fcc99f1 100644 --- a/tests/extension/commands/index.test.ts +++ b/tests/extension/commands/index.test.ts @@ -62,6 +62,7 @@ vi.mock('../../../src/utils/configuration', () => ({ isNotebookTerminalEnabledForCell: vi.fn(), getCLIUseIntegratedRunme: vi.fn().mockReturnValue(false), OpenViewInEditorAction: { enum: { toggle: 'toggle', split: 'split' } }, + getRunmeAppUrl: vi.fn(() => 'localhost'), })) vi.mock('../../../src/extension/provider/cli', () => ({ CliProvider: { diff --git a/tests/extension/configuration.test.ts b/tests/extension/configuration.test.ts index 644d71858..73db568b6 100644 --- a/tests/extension/configuration.test.ts +++ b/tests/extension/configuration.test.ts @@ -241,17 +241,17 @@ suite('Configuration', () => { test('should return URL for api with subdomain', () => { const url = getRunmeAppUrl(['api']) - expect(url).toStrictEqual('https://api.platform.stateful.com/') + expect(url).toStrictEqual('https://api.cloud.stateful.com/') }) test('should return URL for api with deep subdomain', () => { const url = getRunmeAppUrl(['l4', 'l3', 'api']) - expect(url).toStrictEqual('https://l4.l3.api.platform.stateful.com/') + expect(url).toStrictEqual('https://l4.l3.api.cloud.stateful.com/') }) test('should return URL without subdomain', () => { const url = getRunmeAppUrl([]) - expect(url).toStrictEqual('https://platform.stateful.com/') + expect(url).toStrictEqual('https://cloud.stateful.com/') }) test('should allow api URL with http for 127.0.0.1', async () => { @@ -277,24 +277,24 @@ suite('Configuration', () => { const app = getRunmeAppUrl(['app']) expect(app).toStrictEqual('http://localhost:4001') const api = getRunmeAppUrl(['api']) - expect(api).toStrictEqual('https://api.staging.platform.stateful.com/') + expect(api).toStrictEqual('https://api.staging.cloud.stateful.com/') }) test('should return URL for api with subdomain for staging', () => { - workspace.getConfiguration().update('baseDomain', 'staging.platform.stateful.com') + workspace.getConfiguration().update('baseDomain', 'staging.cloud.stateful.com') const url = getRunmeAppUrl(['api']) - expect(url).toStrictEqual('https://api.staging.platform.stateful.com/') + expect(url).toStrictEqual('https://api.staging.cloud.stateful.com/') }) test('should return URL for app with subdomain', () => { const url = getRunmeAppUrl(['app']) - expect(url).toStrictEqual('https://platform.stateful.com/') + expect(url).toStrictEqual('https://cloud.stateful.com/') }) test('should return URL for app with subdomain for staging', () => { - workspace.getConfiguration().update('baseDomain', 'staging.platform.stateful.com') + workspace.getConfiguration().update('baseDomain', 'staging.cloud.stateful.com') const url = getRunmeAppUrl(['app']) - expect(url).toStrictEqual('https://staging.platform.stateful.com/') + expect(url).toStrictEqual('https://staging.cloud.stateful.com/') }) }) diff --git a/tests/extension/extension.test.ts b/tests/extension/extension.test.ts index 72d553f50..10a8c8c89 100644 --- a/tests/extension/extension.test.ts +++ b/tests/extension/extension.test.ts @@ -7,6 +7,7 @@ import { RunmeExtension } from '../../src/extension/extension' import { bootFile } from '../../src/extension/utils' import KernelServer from '../../src/extension/server/kernelServer' import { testCertPEM, testPrivKeyPEM } from '../testTLSCert' +import { StatefulAuthProvider } from '../../src/extension/provider/statefulAuth' vi.mock('vscode') vi.mock('vscode-telemetry') @@ -69,7 +70,6 @@ vi.mock('../../src/extension/utils', async () => ({ togglePreviewButton: vi.fn(), resetNotebookSettings: vi.fn(), getGithubAuthSession: vi.fn().mockResolvedValue(undefined), - getPlatformAuthSession: vi.fn().mockResolvedValue(undefined), getEnvProps: vi.fn().mockReturnValue({ extname: 'stateful.runme', extversion: '1.2.3-foo.1', @@ -115,8 +115,12 @@ test('initializes all providers', async () => { globalState: { get: vi.fn(), }, + secrets: { + store: vi.fn(), + }, } const ext = new RunmeExtension() + StatefulAuthProvider.initialize(context) await ext.initialize(context) expect(notebooks.registerNotebookCellStatusBarItemProvider).toBeCalledTimes(5) expect(workspace.registerNotebookSerializer).toBeCalledTimes(1) diff --git a/tests/extension/github.test.ts b/tests/extension/github.test.ts index 835fdfef2..8ec768c46 100644 --- a/tests/extension/github.test.ts +++ b/tests/extension/github.test.ts @@ -5,6 +5,8 @@ import { window, authentication, AuthenticationSession, + ExtensionContext, + Uri, } from 'vscode' import { expect, suite, vi, test } from 'vitest' @@ -12,6 +14,7 @@ import * as CellManager from '../../src/extension/cell' import { github } from '../../src/extension/executors/github' import { Kernel } from '../../src/extension/kernel' import { IKernelExecutorOptions } from '../../src/extension/executors' +import { StatefulAuthProvider } from '../../src/extension/provider/statefulAuth' vi.mock('vscode', async () => { const vscode = await import('../../__mocks__/vscode') @@ -27,6 +30,16 @@ vi.mock('../../../src/extension/grpc/runner/v1', () => ({ ResolveProgramRequest_Mode: vi.fn(), })) +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + class OctokitMock { protected rest: any constructor() { diff --git a/tests/extension/kernel.test.ts b/tests/extension/kernel.test.ts index c60100a63..f6e3e13bc 100644 --- a/tests/extension/kernel.test.ts +++ b/tests/extension/kernel.test.ts @@ -1,5 +1,5 @@ import { test, expect, vi, suite, beforeEach } from 'vitest' -import { NotebookCell, commands, notebooks, window, workspace } from 'vscode' +import { ExtensionContext, NotebookCell, Uri, commands, notebooks, window, workspace } from 'vscode' import { Kernel } from '../../src/extension/kernel' import executors from '../../src/extension/executors' @@ -10,6 +10,7 @@ import * as platform from '../../src/extension/messages/platformRequest/saveCell import { isPlatformAuthEnabled } from '../../src/utils/configuration' import { askAlternativeOutputsAction } from '../../src/extension/commands' import { getEventReporter } from '../../src/extension/ai/events' +import { StatefulAuthProvider } from '../../src/extension/provider/statefulAuth' const reportExecution = vi.fn() @@ -49,7 +50,6 @@ vi.mock('../../src/extension/utils', async () => { platform: 'darwin_arm64', uikind: 'desktop', }), - getPlatformAuthSession: vi.fn().mockResolvedValue(undefined), } }) vi.mock('../../src/utils/configuration', async (importActual) => { @@ -93,6 +93,17 @@ const genCells = (cnt: number, metadata: Record = {}) => }, }) as any as NotebookCell, ) + +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + suite('#handleRendererMessage', () => { const editor = { notebook: { diff --git a/tests/extension/messages/cellOutput.test.ts b/tests/extension/messages/cellOutput.test.ts index ccbebf143..8fdafb7ca 100644 --- a/tests/extension/messages/cellOutput.test.ts +++ b/tests/extension/messages/cellOutput.test.ts @@ -1,9 +1,10 @@ -import { NotebookCell } from 'vscode' +import { ExtensionContext, NotebookCell, Uri } from 'vscode' import { suite, vi, test, expect } from 'vitest' import { handleCellOutputMessage } from '../../../src/extension/messages/cellOutput' import { ClientMessages, OutputType } from '../../../src/constants' import { Kernel } from '../../../src/extension/kernel' +import { StatefulAuthProvider } from '../../../src/extension/provider/statefulAuth' vi.mock('vscode') vi.mock('vscode-telemetry') @@ -13,6 +14,16 @@ vi.mock('../../../src/extension/grpc/runner/v1', () => ({ ResolveProgramRequest_Mode: vi.fn(), })) +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + suite('Handle CellOutput messages', () => { const mockOutput = (type: OutputType) => { const cell = { diff --git a/tests/extension/messages/platformApiRequest/saveCellExecution.test.ts b/tests/extension/messages/platformApiRequest/saveCellExecution.test.ts index 1cfedf32f..8c26f2864 100644 --- a/tests/extension/messages/platformApiRequest/saveCellExecution.test.ts +++ b/tests/extension/messages/platformApiRequest/saveCellExecution.test.ts @@ -1,4 +1,4 @@ -import { AuthenticationSession, authentication, notebooks } from 'vscode' +import { ExtensionContext, notebooks, Uri } from 'vscode' import { suite, vi, it, beforeAll, afterAll, afterEach, expect } from 'vitest' import { HttpResponse, graphql } from 'msw' import { setupServer } from 'msw/node' @@ -10,6 +10,10 @@ import { Kernel } from '../../../../src/extension/kernel' import { ClientMessages } from '../../../../src/constants' import { APIMethod } from '../../../../src/types' import { GrpcSerializer } from '../../../../src/extension/serializer' +import { + StatefulAuthProvider, + StatefulAuthSession, +} from '../../../../src/extension/provider/statefulAuth' vi.mock('vscode-telemetry') vi.mock('../../../src/extension/runner', () => ({})) @@ -93,6 +97,16 @@ const mockCellInCache = (kernel, cellId) => { }) } +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + suite('Save cell execution', () => { const kernel = new Kernel({} as any) kernel.hasExperimentEnabled = vi.fn((params) => params === 'reporter') @@ -100,7 +114,7 @@ suite('Save cell execution', () => { const cellId = 'cell-id' mockCellInCache(kernel, cellId) const messaging = notebooks.createRendererMessaging('runme-renderer') - const authenticationSession: AuthenticationSession = { + const authenticationSession: StatefulAuthSession = { accessToken: '', id: '', scopes: ['repo'], @@ -108,6 +122,8 @@ suite('Save cell execution', () => { id: '', label: '', }, + isExpired: false, + expiresIn: 2145848400000, } const message = { type: ClientMessages.platformApiRequest, @@ -132,8 +148,9 @@ suite('Save cell execution', () => { }, } as any, } - vi.mocked(authentication.getSession).mockResolvedValue(authenticationSession) - + vi.spyOn(StatefulAuthProvider.instance, 'currentSession').mockResolvedValue( + authenticationSession, + ) await saveCellExecution(requestMessage, kernel) expect(messaging.postMessage).toMatchInlineSnapshot(` @@ -211,7 +228,7 @@ suite('Save cell execution', () => { }, } as any, } - vi.mocked(authentication.getSession).mockResolvedValue(undefined) + vi.spyOn(StatefulAuthProvider.instance, 'currentSession').mockResolvedValue(undefined) await saveCellExecution(requestMessage, kernel) expect(messaging.postMessage).toMatchInlineSnapshot(` @@ -247,9 +264,8 @@ suite('Save cell execution', () => { [ { "output": { - "data": { - "displayShare": false, - }, + "data": "You must authenticate with your Stateful account", + "hasErrors": true, "id": "cell-id", }, "type": "common:platformApiResponse", @@ -282,7 +298,7 @@ suite('Save cell execution', () => { const cellId = 'cell-id' const cacheId = 'cache-id' - const authenticationSession: AuthenticationSession = { + const authenticationSession: StatefulAuthSession = { accessToken: '', id: '', scopes: ['repo'], @@ -290,9 +306,13 @@ suite('Save cell execution', () => { id: '', label: '', }, + isExpired: false, + expiresIn: 2145848400000, } vi.spyOn(GrpcSerializer, 'getDocumentCacheId').mockReturnValueOnce(cacheId) - vi.mocked(authentication.getSession).mockResolvedValue(authenticationSession) + vi.spyOn(StatefulAuthProvider.instance, 'currentSession').mockResolvedValue( + authenticationSession, + ) vi.spyOn(kernel, 'getNotebookDataCache').mockImplementationOnce(() => undefined) const messaging = notebooks.createRendererMessaging('runme-renderer') @@ -336,7 +356,7 @@ suite('Save cell execution', () => { const cacheId = 'cache-id' const notebookId = 'ulid' - const authenticationSession: AuthenticationSession = { + const authenticationSession: StatefulAuthSession = { accessToken: '', id: '', scopes: ['repo'], @@ -344,9 +364,13 @@ suite('Save cell execution', () => { id: '', label: '', }, + isExpired: false, + expiresIn: 2145848400000, } vi.spyOn(GrpcSerializer, 'getDocumentCacheId').mockReturnValueOnce(cacheId) - vi.mocked(authentication.getSession).mockResolvedValue(authenticationSession) + vi.spyOn(StatefulAuthProvider.instance, 'currentSession').mockResolvedValue( + authenticationSession, + ) vi.spyOn(kernel, 'getNotebookDataCache').mockImplementationOnce(() => ({ cells: [], })) diff --git a/tests/extension/panels/panel.test.ts b/tests/extension/panels/panel.test.ts index 4b02ba77f..2ebc93152 100644 --- a/tests/extension/panels/panel.test.ts +++ b/tests/extension/panels/panel.test.ts @@ -2,6 +2,7 @@ import { suite, test, expect, vi } from 'vitest' import { workspace, Uri, type ExtensionContext, type WebviewView } from 'vscode' import CloudPanel from '../../../src/extension/panels/cloud' +import { StatefulAuthProvider } from '../../../src/extension/provider/statefulAuth' vi.mock('vscode') vi.mock('vscode-telemetry') @@ -40,6 +41,16 @@ vi.mock('../../../src/extension/utils', () => { } }) +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + suite('Panel', () => { const staticHtml = '' @@ -58,7 +69,7 @@ suite('Panel', () => { themeKind: 1, }) - expect(hydrated).toContain('') + expect(hydrated).toContain('') expect(hydrated).toContain( '{"appToken":"a.b.c","ide":"code","panelId":"main","defaultUx":"panels","themeKind":1}', ) @@ -66,11 +77,11 @@ suite('Panel', () => { test('resolves authed', async () => { const p = new CloudPanel(contextMock, 'testing') - p.getAppToken = vi.fn().mockResolvedValue({ token: 'webview.auth.token' }) + p.getAppToken = vi.fn().mockResolvedValue('webview.auth.token') await p.resolveWebviewTelemetryView(view) - expect(view.webview.html).toContain('') + expect(view.webview.html).toContain('') expect(view.webview.html).toContain( '{"ide":"code","panelId":"testing","appToken":"webview.auth.token","defaultUx":"panels","themeKind":1}', ) @@ -82,7 +93,7 @@ suite('Panel', () => { await p.resolveWebviewTelemetryView(view) - expect(view.webview.html).toContain('') + expect(view.webview.html).toContain('') expect(view.webview.html).toContain( '{"ide":"code","panelId":"testing","appToken":"EMPTY","defaultUx":"panels","themeKind":1}', ) @@ -91,7 +102,7 @@ suite('Panel', () => { test('resolves authed localhost', async () => { workspace.getConfiguration().update('baseDomain', 'localhost') const p = new CloudPanel(contextMock, 'testing') - p.getAppToken = vi.fn().mockResolvedValue({ token: 'webview.auth.token' }) + p.getAppToken = vi.fn().mockResolvedValue('webview.auth.token') await p.resolveWebviewTelemetryView(view) diff --git a/tests/extension/provider/annotations.test.ts b/tests/extension/provider/annotations.test.ts index 6793437e7..1b749421a 100644 --- a/tests/extension/provider/annotations.test.ts +++ b/tests/extension/provider/annotations.test.ts @@ -1,9 +1,10 @@ import { vi, describe, it, expect } from 'vitest' -import { NotebookCellKind } from 'vscode' +import { ExtensionContext, NotebookCellKind, Uri } from 'vscode' import { AnnotationsStatusBarItem } from '../../../src/extension/provider/cellStatusBar/items/annotations' import { Kernel } from '../../../src/extension/kernel' import { OutputType } from '../../../src/constants' +import { StatefulAuthProvider } from '../../../src/extension/provider/statefulAuth' vi.mock('vscode') vi.mock('vscode-telemetry') @@ -34,6 +35,16 @@ vi.mock('../../../src/extension/utils', () => ({ vi.mock('../../../src/extension/runner', () => ({})) vi.mock('../../../src/extension/grpc/runner/v1', () => ({})) +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + describe('AnnotationsStatusBarItem test suite', () => { const kernel = new Kernel({} as any) diff --git a/tests/extension/provider/copy.test.ts b/tests/extension/provider/copy.test.ts index cd9da961c..55763162a 100644 --- a/tests/extension/provider/copy.test.ts +++ b/tests/extension/provider/copy.test.ts @@ -1,12 +1,23 @@ import { vi, test, expect } from 'vitest' -import { NotebookCellStatusBarAlignment } from 'vscode' +import { ExtensionContext, NotebookCellStatusBarAlignment, Uri } from 'vscode' import { CopyStatusBarItem } from '../../../src/extension/provider/cellStatusBar/items/copy' import { Kernel } from '../../../src/extension/kernel' +import { StatefulAuthProvider } from '../../../src/extension/provider/statefulAuth' vi.mock('vscode-telemetry') vi.mock('vscode') +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + test('NotebookCellStatusBarAlignment test suite', () => { const kernel = new Kernel({} as any) const p = new CopyStatusBarItem(kernel) diff --git a/tests/extension/provider/named.test.ts b/tests/extension/provider/named.test.ts index d6a7839c3..0bc255e65 100644 --- a/tests/extension/provider/named.test.ts +++ b/tests/extension/provider/named.test.ts @@ -1,8 +1,10 @@ +import { ExtensionContext, Uri } from 'vscode' import { vi, suite, test, expect, beforeEach } from 'vitest' import { getAnnotations } from '../../../src/extension/utils' import { NamedStatusBarItem } from '../../../src/extension/provider/cellStatusBar/items/named' import { Kernel } from '../../../src/extension/kernel' +import { StatefulAuthProvider } from '../../../src/extension/provider/statefulAuth' vi.mock('vscode-telemetry') vi.mock('vscode') @@ -12,6 +14,16 @@ vi.mock('../../../src/extension/utils', () => ({ isValidEnvVarName: vi.fn().mockReturnValue(true), })) +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + suite('NamedStatusBarItem Test Suite', () => { const kernel = new Kernel({} as any) diff --git a/tests/extension/provider/notebook.test.ts b/tests/extension/provider/notebook.test.ts index 55c598454..8eecae65c 100644 --- a/tests/extension/provider/notebook.test.ts +++ b/tests/extension/provider/notebook.test.ts @@ -1,8 +1,9 @@ +import { commands, ExtensionContext, NotebookCellKind, Uri } from 'vscode' import { vi, describe, it, expect } from 'vitest' -import { commands, NotebookCellKind } from 'vscode' import { NotebookCellStatusBarProvider } from '../../../src/extension/provider/cellStatusBar/notebook' import { Kernel } from '../../../src/extension/kernel' +import { StatefulAuthProvider } from '../../../src/extension/provider/statefulAuth' vi.mock('vscode') vi.mock('vscode-telemetry') @@ -29,6 +30,16 @@ vi.mock('../../../src/extension/utils', () => ({ vi.mock('../../../src/extension/runner', () => ({})) vi.mock('../../../src/extension/grpc/runner/v1', () => ({})) +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + describe('Notebook Cell Status Bar provider', () => { const kernel = new Kernel({} as any) it('should register commands when initializing', () => { diff --git a/tests/extension/provider/sessionOutputs.test.ts b/tests/extension/provider/sessionOutputs.test.ts index 0dfb2efc2..2504ce5fb 100644 --- a/tests/extension/provider/sessionOutputs.test.ts +++ b/tests/extension/provider/sessionOutputs.test.ts @@ -1,8 +1,9 @@ import { vi, describe, it, expect } from 'vitest' -import { commands, NotebookCellKind } from 'vscode' +import { commands, ExtensionContext, NotebookCellKind, Uri } from 'vscode' import { SessionOutputCellStatusBarProvider } from '../../../src/extension/provider/cellStatusBar/sessionOutput' import { Kernel } from '../../../src/extension/kernel' +import { StatefulAuthProvider } from '../../../src/extension/provider/statefulAuth' vi.mock('vscode') vi.mock('vscode-telemetry') @@ -28,6 +29,16 @@ vi.mock('../../../src/extension/utils', () => ({ vi.mock('../../../src/extension/runner', () => ({})) vi.mock('../../../src/extension/grpc/runner/v1', () => ({})) +const contextFake: ExtensionContext = { + extensionUri: Uri.parse('file:///Users/fakeUser/projects/vscode-runme'), + secrets: { + store: vi.fn(), + }, + subscriptions: [], +} as any + +StatefulAuthProvider.initialize(contextFake) + describe('Session Outputs Cell Status Bar provider', () => { const kernel = new Kernel({} as any) it('should register commands when initializing', () => { diff --git a/tests/extension/provider/statefulAuth.test.ts b/tests/extension/provider/statefulAuth.test.ts index 5d6f9e282..6a7335de2 100644 --- a/tests/extension/provider/statefulAuth.test.ts +++ b/tests/extension/provider/statefulAuth.test.ts @@ -1,12 +1,11 @@ import * as crypto from 'node:crypto' +import { ExtensionContext, Uri, workspace } from 'vscode' import { expect, vi, beforeEach, describe, it } from 'vitest' -import { Uri, ExtensionContext, workspace } from 'vscode' import fetch from 'node-fetch' import jwt from 'jsonwebtoken' import { StatefulAuthProvider } from '../../../src/extension/provider/statefulAuth' -import { RunmeUriHandler } from '../../../src/extension/handler/uri' import { getRunmeAppUrl } from '../../../src/utils/configuration' vi.mock('vscode') @@ -15,7 +14,7 @@ vi.mock('node-fetch') vi.mock('../../../src/utils/configuration', () => { return { - getRunmeAppUrl: vi.fn(), + getRunmeAppUrl: vi.fn(() => 'https://api.for.platform'), getDeleteAuthToken: vi.fn(() => true), getAuthTokenPath: vi.fn(() => '/path/to/auth/token'), } @@ -26,16 +25,16 @@ const contextFake: ExtensionContext = { secrets: { store: vi.fn(), }, + subscriptions: [], } as any -const uriHandlerFake: RunmeUriHandler = {} as any +StatefulAuthProvider.initialize(contextFake) describe('StatefulAuthProvider', () => { let provider: StatefulAuthProvider beforeEach(() => { - vi.mocked(getRunmeAppUrl).mockReturnValue('https://api.for.platform') - provider = new StatefulAuthProvider(contextFake, uriHandlerFake) + provider = StatefulAuthProvider.instance }) it('gets sessions', async () => { @@ -50,13 +49,9 @@ describe('StatefulAuthProvider', () => { }) describe('StatefulAuthProvider#sessionSecretKey', () => { - let provider: StatefulAuthProvider - it('returns a secret key for production', () => { - provider = new StatefulAuthProvider(contextFake, uriHandlerFake) - // access private prop - expect((provider as any).sessionSecretKey).toEqual( + expect(StatefulAuthProvider.sessionSecretKey).toEqual( 'stateful.sessions.8e0b4f45d990c8b235d4036020299d4af5c8c4a0', ) }) @@ -64,12 +59,10 @@ describe('StatefulAuthProvider#sessionSecretKey', () => { it('includes a hashed URL of the stage into the secret key', () => { const fakeStagingUrl = 'https://api.staging.for.platform' vi.mocked(getRunmeAppUrl).mockReturnValue(fakeStagingUrl) - - provider = new StatefulAuthProvider(contextFake, uriHandlerFake) const hashed = crypto.createHash('sha1').update(fakeStagingUrl).digest('hex') // access private prop - const sessionSecretKey = (provider as any).sessionSecretKey + const sessionSecretKey = StatefulAuthProvider.sessionSecretKey expect(sessionSecretKey).toContain(hashed) expect(sessionSecretKey).toEqual('stateful.sessions.5d458b91cb755f8e839839dd3d1b4d597bba2c11') @@ -77,16 +70,13 @@ describe('StatefulAuthProvider#sessionSecretKey', () => { }) describe('StatefulAuthProvider#bootstrapFromToken', () => { - let provider: StatefulAuthProvider - beforeEach(() => { vi.mocked(getRunmeAppUrl).mockReturnValue('https://api.stateful.dev/') - provider = new StatefulAuthProvider(contextFake, uriHandlerFake) }) it('returns undefined if no token is provided', async () => { vi.mocked(workspace.fs.stat).mockRejectedValueOnce({} as any) - const sessionCreated = await provider.bootstrapFromToken() + const sessionCreated = await StatefulAuthProvider.bootstrapFromToken() expect(sessionCreated).toBeFalsy() }) @@ -119,7 +109,7 @@ describe('StatefulAuthProvider#bootstrapFromToken', () => { ) const spyStore = vi.spyOn(contextFake.secrets, 'store') const spyDelete = vi.spyOn(workspace.fs, 'delete') - const sessionCreated = await provider.bootstrapFromToken() + const sessionCreated = await StatefulAuthProvider.bootstrapFromToken() expect(sessionCreated).toBeTruthy() expect(spyStore).toHaveBeenCalledOnce() @@ -150,7 +140,7 @@ describe('StatefulAuthProvider#bootstrapFromToken', () => { ) const spyStore = vi.spyOn(contextFake.secrets, 'store') const spyDelete = vi.spyOn(workspace.fs, 'delete') - const sessionCreated = await provider.bootstrapFromToken() + const sessionCreated = await StatefulAuthProvider.bootstrapFromToken() expect(sessionCreated).toBeFalsy() expect(spyStore).not.toHaveBeenCalledOnce()