Skip to content

Commit

Permalink
Refresh OAuth tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Feb 3, 2022
1 parent 196d2d7 commit 4fe85d2
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 10 deletions.
37 changes: 35 additions & 2 deletions src/matrix/Client.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {AbortableOperation} from "../utils/AbortableOperation";
import {ObservableValue} from "../observable/ObservableValue";
import {HomeServerApi} from "./net/HomeServerApi";
import {OidcApi} from "./net/OidcApi";
import {TokenRefresher} from "./net/TokenRefresher";
import {Reconnector, ConnectionStatus} from "./net/Reconnector";
import {ExponentialRetryDelay} from "./net/ExponentialRetryDelay";
import {MediaRepository} from "./net/MediaRepository";
Expand Down Expand Up @@ -182,7 +183,7 @@ export class Client {
}

if (loginData.expires_in) {
sessionInfo.expiresAt = clock.now() + loginData.expires_in * 1000;
sessionInfo.accessTokenExpiresAt = clock.now() + loginData.expires_in * 1000;
}

if (loginData.oidc_issuer) {
Expand Down Expand Up @@ -242,9 +243,41 @@ export class Client {
retryDelay: new ExponentialRetryDelay(clock.createTimeout),
createMeasure: clock.createMeasure
});

let accessToken;

if (sessionInfo.oidcIssuer) {
const oidcApi = new OidcApi({
issuer: sessionInfo.oidcIssuer,
clientId: "hydrogen-web",
request: this._platform.request,
encoding: this._platform.encoding,
});

// TODO: stop/pause the refresher?
const tokenRefresher = new TokenRefresher({
oidcApi,
clock: this._platform.clock,
accessToken: sessionInfo.accessToken,
accessTokenExpiresAt: sessionInfo.accessTokenExpiresAt,
refreshToken: sessionInfo.refreshToken,
anticipation: 30 * 1000,
});

tokenRefresher.token.subscribe(t => {
this._platform.sessionInfoStorage.updateToken(sessionInfo.id, t.accessToken, t.accessTokenExpiresAt, t.refreshToken);
});

await tokenRefresher.start();

accessToken = tokenRefresher.accessToken;
} else {
accessToken = new ObservableValue(sessionInfo.accessToken);
}

const hsApi = new HomeServerApi({
homeserver: sessionInfo.homeServer,
accessToken: sessionInfo.accessToken,
accessToken,
request: this._platform.request,
reconnector: this._reconnector,
});
Expand Down
20 changes: 13 additions & 7 deletions src/matrix/net/HomeServerApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ const DEHYDRATION_PREFIX = "/_matrix/client/unstable/org.matrix.msc2697.v2";

type Options = {
homeserver: string;
accessToken: string;
accessToken: BaseObservableValue<string>;
request: RequestFunction;
reconnector: Reconnector;
};

export class HomeServerApi {
private readonly _homeserver: string;
private readonly _accessToken: string;
private readonly _accessToken: BaseObservableValue<string>;
private readonly _requestFn: RequestFunction;
private readonly _reconnector: Reconnector;

constructor({homeserver, accessToken, request, reconnector}: Options) {
constructor({ homeserver, accessToken, request, reconnector }: Options) {
// store these both in a closure somehow so it's harder to get at in case of XSS?
// one could change the homeserver as well so the token gets sent there, so both must be protected from read/write
this._homeserver = homeserver;
Expand All @@ -54,7 +54,7 @@ export class HomeServerApi {
return this._homeserver + prefix + csPath;
}

private _baseRequest(method: RequestMethod, url: string, queryParams?: Record<string, any>, body?: Record<string, any>, options?: IRequestOptions, accessToken?: string): IHomeServerRequest {
private _baseRequest(method: RequestMethod, url: string, queryParams?: Record<string, any>, body?: Record<string, any>, options?: IRequestOptions, accessTokenSource?: BaseObservableValue<string>): IHomeServerRequest {
const queryString = encodeQueryParams(queryParams);
url = `${url}?${queryString}`;
let log: ILogItem | undefined;
Expand All @@ -68,9 +68,14 @@ export class HomeServerApi {
}
let encodedBody: EncodedBody["body"];
const headers: Map<string, string | number> = new Map();

let accessToken: string | null = null;
if (options?.accessTokenOverride) {
accessToken = options.accessTokenOverride;
} else if (accessTokenSource) {
accessToken = accessTokenSource.get();
}

if (accessToken) {
headers.set("Authorization", `Bearer ${accessToken}`);
}
Expand All @@ -91,7 +96,7 @@ export class HomeServerApi {
});

const hsRequest = new HomeServerRequest(method, url, requestResult, log);

if (this._reconnector) {
hsRequest.response().catch(err => {
// Some endpoints such as /sync legitimately time-out
Expand Down Expand Up @@ -282,11 +287,12 @@ export class HomeServerApi {

claimDehydratedDevice(deviceId: string, options: IRequestOptions): IHomeServerRequest {
options.prefix = DEHYDRATION_PREFIX;
return this._post(`/dehydrated_device/claim`, {}, {device_id: deviceId}, options);
return this._post(`/dehydrated_device/claim`, {}, { device_id: deviceId }, options);
}
}

import {Request as MockRequest} from "../../mocks/Request.js";
import { Request as MockRequest } from "../../mocks/Request.js";
import { BaseObservableValue } from "../../observable/ObservableValue";

export function tests() {
return {
Expand Down
125 changes: 125 additions & 0 deletions src/matrix/net/TokenRefresher.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
Copyright 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the 'License');
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an 'AS IS' BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import { BaseObservableValue, ObservableValue } from "../../observable/ObservableValue";
import type { Clock, Timeout } from "../../platform/web/dom/Clock";
import { OidcApi } from "./OidcApi";

type Token = {
accessToken: string,
accessTokenExpiresAt: number,
refreshToken: string,
};


export class TokenRefresher {
private _token: ObservableValue<Token>;
private _accessToken: BaseObservableValue<string>;
private _anticipation: number;
private _clock: Clock;
private _oidcApi: OidcApi;
private _timeout: Timeout

constructor({
oidcApi,
refreshToken,
accessToken,
accessTokenExpiresAt,
anticipation,
clock,
}: {
oidcApi: OidcApi,
refreshToken: string,
accessToken: string,
accessTokenExpiresAt: number,
anticipation: number,
clock: Clock,
}) {
this._token = new ObservableValue({
accessToken,
accessTokenExpiresAt,
refreshToken,
});
this._accessToken = this._token.map(t => t.accessToken);

this._anticipation = anticipation;
this._oidcApi = oidcApi;
this._clock = clock;
}

async start() {
if (this.needsRenewing) {
await this.renew();
}

this._renewingLoop();
}

stop() {
// TODO
}

get needsRenewing() {
const remaining = this._token.get().accessTokenExpiresAt - this._clock.now();
const anticipated = remaining - this._anticipation;
return anticipated < 0;
}

async _renewingLoop() {
while (true) {
const remaining =
this._token.get().accessTokenExpiresAt - this._clock.now();
const anticipated = remaining - this._anticipation;

if (anticipated > 0) {
this._timeout = this._clock.createTimeout(anticipated);
await this._timeout.elapsed();
}

await this.renew();
}
}

async renew() {
let refreshToken = this._token.get().refreshToken;
const response = await this._oidcApi
.refreshToken({
refreshToken,
});

if (typeof response.expires_in !== "number") {
throw new Error("Refreshed access token does not expire");
}

if (response.refresh_token) {
refreshToken = response.refresh_token;
}

this._token.set({
refreshToken,
accessToken: response.access_token,
accessTokenExpiresAt: this._clock.now() + response.expires_in * 1000,
});
}

get accessToken(): BaseObservableValue<string> {
return this._accessToken;
}

get token(): BaseObservableValue<Token> {
return this._token;
}
}
16 changes: 15 additions & 1 deletion src/matrix/sessioninfo/localstorage/SessionInfoStorage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ interface ISessionInfo {
homeserver: string;
homeServer: string; // deprecate this over time
accessToken: string;
accessTokenExpiresAt?: number;
refreshToken?: string;
expiresAt?: number;
oidcIssuer?: string;
lastUsed: number;
}
Expand All @@ -31,6 +31,7 @@ interface ISessionInfo {
interface ISessionInfoStorage {
getAll(): Promise<ISessionInfo[]>;
updateLastUsed(id: string, timestamp: number): Promise<void>;
updateToken(id: string, accessToken: string, accessTokenExpiresAt: number, refreshToken: string): Promise<void>;
get(id: string): Promise<ISessionInfo | undefined>;
add(sessionInfo: ISessionInfo): Promise<void>;
delete(sessionId: string): Promise<void>;
Expand Down Expand Up @@ -65,6 +66,19 @@ export class SessionInfoStorage implements ISessionInfoStorage {
}
}

async updateToken(id: string, accessToken: string, accessTokenExpiresAt: number, refreshToken: string): Promise<void> {
const sessions = await this.getAll();
if (sessions) {
const session = sessions.find(session => session.id === id);
if (session) {
session.accessToken = accessToken;
session.accessTokenExpiresAt = accessTokenExpiresAt;
session.refreshToken = refreshToken;
localStorage.setItem(this._name, JSON.stringify(sessions));
}
}
}

async get(id: string): Promise<ISessionInfo | undefined> {
const sessions = await this.getAll();
if (sessions) {
Expand Down
32 changes: 32 additions & 0 deletions src/observable/ObservableValue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ export abstract class BaseObservableValue<T> extends BaseObservable<(value: T) =
flatMap<C>(mapper: (value: T) => (BaseObservableValue<C> | undefined)): BaseObservableValue<C | undefined> {
return new FlatMapObservableValue<T, C>(this, mapper);
}

map<C>(mapper: (value: T) => C): BaseObservableValue<C> {
return new MappedObservableValue<T, C>(this, mapper);
}
}

interface IWaitHandle<T> {
Expand Down Expand Up @@ -174,6 +178,34 @@ export class FlatMapObservableValue<P, C> extends BaseObservableValue<C | undefi
}
}

export class MappedObservableValue<P, C> extends BaseObservableValue<C> {
private sourceSubscription?: SubscriptionHandle;

constructor(
private readonly source: BaseObservableValue<P>,
private readonly mapper: (value: P) => C
) {
super();
}

onUnsubscribeLast() {
super.onUnsubscribeLast();
this.sourceSubscription = this.sourceSubscription!();
}

onSubscribeFirst() {
super.onSubscribeFirst();
this.sourceSubscription = this.source.subscribe(() => {
this.emit(this.get());
});
}

get(): C {
const sourceValue = this.source.get();
return this.mapper(sourceValue);
}
}

export function tests() {
return {
"set emits an update": assert => {
Expand Down

0 comments on commit 4fe85d2

Please sign in to comment.