Skip to content

Commit

Permalink
feat: add new auto refresh token algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Dec 20, 2022
1 parent 2b65646 commit 3138c21
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 73 deletions.
184 changes: 119 additions & 65 deletions src/GoTrueClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import {
resolveFetch,
setItemAsync,
uuid,
retryable,
sleep,
} from './lib/helpers'
import localStorageAdapter from './lib/local-storage'
import { polyfillGlobalThis } from './lib/polyfills'
Expand Down Expand Up @@ -79,6 +81,13 @@ const DEFAULT_OPTIONS: Omit<Required<GoTrueClientOptions>, 'fetch' | 'storage'>
headers: DEFAULT_HEADERS,
}

/** Current session will be checked for refresh will be checked at this interval. */
const AUTO_REFRESH_TICK_DURATION = 10 * 1000

/**
* A token refresh will be attempted this many ticks before the current session expires. */
const AUTO_REFRESH_TICK_THRESHOLD = 3

export default class GoTrueClient {
/**
* Namespace for the GoTrue admin methods.
Expand All @@ -104,8 +113,7 @@ export default class GoTrueClient {
protected persistSession: boolean
protected storage: SupportedStorage
protected stateChangeEmitters: Map<string, Subscription> = new Map()
protected refreshTokenTimer?: ReturnType<typeof setTimeout>
protected networkRetries = 0
protected autoRefreshTicker: ReturnType<typeof setInterval> | null = null
protected refreshingDeferred: Deferred<CallRefreshTokenResult> | null = null
/**
* Keeps track of the async client initialization.
Expand Down Expand Up @@ -142,7 +150,6 @@ export default class GoTrueClient {
this.fetch = resolveFetch(settings.fetch)
this.detectSessionInUrl = settings.detectSessionInUrl

this.initialize()
this.mfa = {
verify: this._verify.bind(this),
enroll: this._enroll.bind(this),
Expand All @@ -152,6 +159,8 @@ export default class GoTrueClient {
challengeAndVerify: this._challengeAndVerify.bind(this),
getAuthenticatorAssuranceLevel: this._getAuthenticatorAssuranceLevel.bind(this),
}

this.initialize()
}

/**
Expand Down Expand Up @@ -888,11 +897,26 @@ export default class GoTrueClient {
*/
private async _refreshAccessToken(refreshToken: string): Promise<AuthResponse> {
try {
return await _request(this.fetch, 'POST', `${this.url}/token?grant_type=refresh_token`, {
body: { refresh_token: refreshToken },
headers: this.headers,
xform: _sessionResponse,
})
const startedAt = Date.now()

// will attempt to refresh the token with exponential backoff
return await retryable(
async (attempt) => {
await sleep(attempt * 200) // 0, 200, 400, 800, ...

return await _request(this.fetch, 'POST', `${this.url}/token?grant_type=refresh_token`, {
body: { refresh_token: refreshToken },
headers: this.headers,
xform: _sessionResponse,
})
},
(attempt, _, result) =>
result &&
result.error &&
result.error instanceof AuthRetryableFetchError &&
// retryable only if the request can be sent before the backoff overflows the tick duration
Date.now() + (attempt + 1) * 200 - startedAt < AUTO_REFRESH_TICK_DURATION
)
} catch (error) {
if (isAuthError(error)) {
return { data: { session: null, user: null }, error }
Expand Down Expand Up @@ -951,24 +975,12 @@ export default class GoTrueClient {

if ((currentSession.expires_at ?? Infinity) < timeNow + EXPIRY_MARGIN) {
if (this.autoRefreshToken && currentSession.refresh_token) {
this.networkRetries++
const { error } = await this._callRefreshToken(currentSession.refresh_token)

if (error) {
console.log(error.message)
if (
error instanceof AuthRetryableFetchError &&
this.networkRetries < NETWORK_FAILURE.MAX_RETRIES
) {
if (this.refreshTokenTimer) clearTimeout(this.refreshTokenTimer)
this.refreshTokenTimer = setTimeout(
() => this._recoverAndRefresh(),
NETWORK_FAILURE.RETRY_INTERVAL ** this.networkRetries * 100 // exponential backoff
)
return
}
await this._removeSession()
}
this.networkRetries = 0
} else {
await this._removeSession()
}
Expand Down Expand Up @@ -1037,14 +1049,6 @@ export default class GoTrueClient {
this.inMemorySession = session
}

const expiresAt = session.expires_at
if (expiresAt) {
const timeNow = Math.round(Date.now() / 1000)
const expiresIn = expiresAt - timeNow
const refreshDurationBeforeExpires = expiresIn > EXPIRY_MARGIN ? EXPIRY_MARGIN : 0.5
this._startAutoRefreshToken((expiresIn - refreshDurationBeforeExpires) * 1000)
}

if (this.persistSession && session.expires_at) {
await this._persistSession(session)
}
Expand All @@ -1060,42 +1064,91 @@ export default class GoTrueClient {
} else {
this.inMemorySession = null
}
}

if (this.refreshTokenTimer) {
clearTimeout(this.refreshTokenTimer)
/**
* Starts an auto-refresh process in the background. The session is checked
* every few seconds. Close to the time of expiration a process is started to
* refresh the session. If refreshing fails it will be retried for as long as
* necessary.
*
* If you set the {@link GoTrueClientOptions#autoRefreshToken} you don't need
* to call this function, it will be called for you.
*
* On browsers the refresh process works only when the tab/window is in the
* foreground to conserve resources as well as prevent race conditions and
* flooding auth with requests.
*
* On non-browser platforms the refresh process works *continuously* in the
* background, which may not be desireable. You should hook into your
* platform's foreground indication mechanism and call these methods
* appropriately to conserve resources.
*
* {@see #stopAutoRefresh}
*/
async startAutoRefresh() {
await this.stopAutoRefresh()
this.autoRefreshTicker = setInterval(
() => this._autoRefreshTokenTick(),
AUTO_REFRESH_TICK_DURATION
)

// run the tick immediately
await this._autoRefreshTokenTick()
}

/**
* Stops an active auto refresh process running in the background (if any).
* See {@link #startAutoRefresh} for more details.
*/
async stopAutoRefresh() {
const ticker = this.autoRefreshTicker
this.autoRefreshTicker = null

if (ticker) {
clearInterval(ticker)
}
}

/**
* Clear and re-create refresh token timer
* @param value time intervals in milliseconds.
* @param session The current session.
* Runs the auto refresh token tick.
*/
private _startAutoRefreshToken(value: number) {
if (this.refreshTokenTimer) clearTimeout(this.refreshTokenTimer)
if (value <= 0 || !this.autoRefreshToken) return
private async _autoRefreshTokenTick() {
const now = Date.now()

this.refreshTokenTimer = setTimeout(async () => {
this.networkRetries++
try {
const {
data: { session },
error: sessionError,
error,
} = await this.getSession()
if (!sessionError && session) {
const { error } = await this._callRefreshToken(session.refresh_token)
if (!error) this.networkRetries = 0
if (
error instanceof AuthRetryableFetchError &&
this.networkRetries < NETWORK_FAILURE.MAX_RETRIES
)
this._startAutoRefreshToken(NETWORK_FAILURE.RETRY_INTERVAL ** this.networkRetries * 100) // exponential backoff

if (!session || !session.refresh_token || !session.expires_at) {
return
}
}, value)
if (typeof this.refreshTokenTimer.unref === 'function') this.refreshTokenTimer.unref()

// session will expire in this many ticks (or has already expired if <= 0)
const expiresInTicks = Math.floor((now - session.expires_at) / AUTO_REFRESH_TICK_DURATION)

if (expiresInTicks < AUTO_REFRESH_TICK_THRESHOLD) {
await this._callRefreshToken(session.refresh_token)
}
} catch (e: any) {
console.error('Auto refresh tick failed with error. This is likely a transient error.', e)
}
}

/**
* Registers callbacks on the browser / platform, which in-turn run
* algorithms when the browser window/tab are in foreground. On non-browser
* platforms it assumes always foreground.
*/
private _handleVisibilityChange() {
if (!isBrowser() || !window?.addEventListener) {
if (this.autoRefreshToken) {
// in non-browser environments the refresh token ticker runs always
this.startAutoRefresh()
}

return false
}

Expand All @@ -1104,6 +1157,16 @@ export default class GoTrueClient {
if (document.visibilityState === 'visible') {
await this.initializePromise
await this._recoverAndRefresh()

if (this.autoRefreshToken) {
// in browser environments the refresh token ticker runs only on focused tabs
// which prevents race conditions
this.startAutoRefresh()
}
} else if (document.visibilityState === 'hidden') {
if (this.autoRefreshToken) {
this.stopAutoRefresh()
}
}
})
} catch (error) {
Expand Down Expand Up @@ -1159,10 +1222,7 @@ export default class GoTrueClient {
}

/**
* Enrolls a factor
* @param friendlyName Human readable name assigned to a device
* @param factorType device which we're validating against. Can only be TOTP for now.
* @param issuer domain which the user is enrolling with
* {@see GoTrueMFAApi#enroll}
*/
private async _enroll(params: MFAEnrollParams): Promise<AuthMFAEnrollResponse> {
try {
Expand Down Expand Up @@ -1199,9 +1259,7 @@ export default class GoTrueClient {
}

/**
* Validates a device as part of the enrollment step.
* @param factorId System assigned identifier for authenticator device as returned by enroll
* @param code Code Generated by an authenticator device
* {@see GoTrueMFAApi#verify}
*/
private async _verify(params: MFAVerifyParams): Promise<AuthMFAVerifyResponse> {
try {
Expand Down Expand Up @@ -1240,8 +1298,7 @@ export default class GoTrueClient {
}

/**
* Creates a challenge which a user can verify against
* @param factorId System assigned identifier for authenticator device as returned by enroll
* {@see GoTrueMFAApi#challenge}
*/
private async _challenge(params: MFAChallengeParams): Promise<AuthMFAChallengeResponse> {
try {
Expand All @@ -1268,9 +1325,7 @@ export default class GoTrueClient {
}

/**
* Creates a challenge and immediately verifies it
* @param factorId System assigned identifier for authenticator device as returned by enroll
* @param code Code Generated by an authenticator device
* {@see GoTrueMFAApi#challengeAndVerify}
*/
private async _challengeAndVerify(
params: MFAChallengeAndVerifyParams
Expand All @@ -1289,7 +1344,7 @@ export default class GoTrueClient {
}

/**
* Displays all devices for a given user
* {@see GoTrueMFAApi#listFactors}
*/
private async _listFactors(): Promise<AuthMFAListFactorsResponse> {
const {
Expand All @@ -1315,8 +1370,7 @@ export default class GoTrueClient {
}

/**
* Gets the current and next authenticator assurance level (AAL)
* and the current authentication methods for the session (AMR)
* {@see GoTrueMFAApi#getAuthenticatorAssuranceLevel}
*/
private async _getAuthenticatorAssuranceLevel(): Promise<AuthMFAGetAuthenticatorAssuranceLevelResponse> {
const {
Expand Down
44 changes: 43 additions & 1 deletion src/lib/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export class Deferred<T = any> {
export function decodeJWTPayload(token: string) {
// Regex checks for base64url format
const base64UrlRegex = /^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}=?$|[a-z0-9_-]{2}(==)?$)$/i

const parts = token.split('.')

if (parts.length !== 3) {
Expand All @@ -145,3 +145,45 @@ export function decodeJWTPayload(token: string) {
const base64Url = parts[1]
return JSON.parse(decodeBase64URL(base64Url))
}

/**
* Creates a promise that resolves to null after some time.
*/
export function sleep(time: number): Promise<null> {
return new Promise((accept) => {
setTimeout(() => accept(null), time)
})
}

/**
* Converts the provided async function into a retryable function. Each result
* or thrown error is sent to the isRetryable function which should return true
* if the function should run again.
*/
export function retryable<T>(
fn: (attempt: number) => Promise<T>,
isRetryable: (attempt: number, error: any | null, result?: T) => boolean
): Promise<T> {
const promise = new Promise<T>((accept, reject) => {
// eslint-disable-next-line @typescript-eslint/no-extra-semi
;(async () => {
for (let attempt = 0; attempt < Infinity; attempt++) {
try {
const result = await fn(attempt)

if (!isRetryable(attempt, null, result)) {
accept(result)
return
}
} catch (e: any) {
if (!isRetryable(attempt, e)) {
reject(e)
return
}
}
}
})()
})

return promise
}
4 changes: 2 additions & 2 deletions src/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -916,8 +916,8 @@ export type CallRefreshTokenResult =

export type Pagination = {
[key: string]: any
nextPage: number | null,
lastPage: number,
nextPage: number | null
lastPage: number
total: number
}

Expand Down
Loading

0 comments on commit 3138c21

Please sign in to comment.