diff --git a/src/app/controllers/spcp.server.controller.js b/src/app/controllers/spcp.server.controller.js index a011a1c312..48efc2addc 100644 --- a/src/app/controllers/spcp.server.controller.js +++ b/src/app/controllers/spcp.server.controller.js @@ -228,48 +228,6 @@ exports.corpPassLogin = (ndiConfig) => { }) } -/** - * Adds session to returned JSON if form-filler is SPCP Authenticated - * @param {Object} req - Express request object - * @param {Object} res - Express response object - * @param {Object} next - Express next middleware function - */ -exports.addSpcpSessionInfo = (authClients) => { - return (req, res, next) => { - const { authType } = req.form - let authClient = authClients[authType] ? authClients[authType] : undefined - let jwtName = jwtNames[authType] - let jwt = req.cookies[jwtName] - if (authType && authClient && jwt) { - // add session info if logged in - authClient.verifyJWT(jwt, (err, payload) => { - if (err) { - // Do not specify userName to call MyInfo endpoint with if jwt is - // invalid. - // Client will inform the form-filler to log in with SingPass again. - logger.error({ - message: 'Failed to verify JWT with auth client', - meta: { - action: 'addSpcpSessionInfo', - ...createReqMeta(req), - }, - error: err, - }) - } else { - const { userName } = payload - // For use in addMyInfo middleware - res.locals.spcpSession = { - userName: userName, - } - } - return next() - }) - } else { - return next() - } - } -} - /** * Encrypt and sign verified fields if exist * @param {Object} req - Express request object diff --git a/src/app/factories/spcp.factory.js b/src/app/factories/spcp.factory.js index d33141fc9f..22eac70537 100644 --- a/src/app/factories/spcp.factory.js +++ b/src/app/factories/spcp.factory.js @@ -69,7 +69,6 @@ const spcpFactory = ({ isEnabled, props }) => { passThroughSpcp: admin.passThroughSpcp, singPassLogin: spcp.singPassLogin(ndiConfig), corpPassLogin: spcp.corpPassLogin(ndiConfig), - addSpcpSessionInfo: spcp.addSpcpSessionInfo(authClients), isSpcpAuthenticated: spcp.isSpcpAuthenticated(authClients), } } else { @@ -81,7 +80,6 @@ const spcpFactory = ({ isEnabled, props }) => { res.status(StatusCodes.INTERNAL_SERVER_ERROR).json({ message: errMsg }), corpPassLogin: (req, res) => res.status(StatusCodes.INTERNAL_SERVER_ERROR).json({ message: errMsg }), - addSpcpSessionInfo: (req, res, next) => next(), isSpcpAuthenticated: (req, res, next) => next(), } } diff --git a/src/app/modules/spcp/__tests__/spcp.factory.spec.ts b/src/app/modules/spcp/__tests__/spcp.factory.spec.ts index a9c2563eaf..e2f0093dc7 100644 --- a/src/app/modules/spcp/__tests__/spcp.factory.spec.ts +++ b/src/app/modules/spcp/__tests__/spcp.factory.spec.ts @@ -26,9 +26,14 @@ describe('spcp.factory', () => { ) const fetchLoginPageResult = await SpcpFactory.fetchLoginPage('') const validateLoginPageResult = SpcpFactory.validateLoginPage('') + const extractJwtPayloadResult = await SpcpFactory.extractJwtPayload( + '', + AuthType.SP, + ) expect(createRedirectUrlResult._unsafeUnwrapErr()).toEqual(error) expect(fetchLoginPageResult._unsafeUnwrapErr()).toEqual(error) expect(validateLoginPageResult._unsafeUnwrapErr()).toEqual(error) + expect(extractJwtPayloadResult._unsafeUnwrapErr()).toEqual(error) }) it('should return error functions when props is undefined', async () => { @@ -46,9 +51,14 @@ describe('spcp.factory', () => { ) const fetchLoginPageResult = await SpcpFactory.fetchLoginPage('') const validateLoginPageResult = SpcpFactory.validateLoginPage('') + const extractJwtPayloadResult = await SpcpFactory.extractJwtPayload( + '', + AuthType.SP, + ) expect(createRedirectUrlResult._unsafeUnwrapErr()).toEqual(error) expect(fetchLoginPageResult._unsafeUnwrapErr()).toEqual(error) expect(validateLoginPageResult._unsafeUnwrapErr()).toEqual(error) + expect(extractJwtPayloadResult._unsafeUnwrapErr()).toEqual(error) }) it('should call the SpcpService constructor when isEnabled is true and props is truthy', () => { diff --git a/src/app/modules/spcp/__tests__/spcp.service.spec.ts b/src/app/modules/spcp/__tests__/spcp.service.spec.ts index 2b195dfc64..881a00a958 100644 --- a/src/app/modules/spcp/__tests__/spcp.service.spec.ts +++ b/src/app/modules/spcp/__tests__/spcp.service.spec.ts @@ -8,12 +8,14 @@ import { CreateRedirectUrlError, FetchLoginPageError, LoginPageValidationError, + VerifyJwtError, } from '../spcp.errors' import { SpcpService } from '../spcp.service' import { MOCK_ERROR_CODE, MOCK_ESRVCID, + MOCK_JWT, MOCK_LOGIN_HTML, MOCK_REDIRECT_URL, MOCK_SERVICE_PARAMS as MOCK_PARAMS, @@ -184,4 +186,46 @@ describe('spcp.service', () => { expect(result._unsafeUnwrapErr()).toEqual(new LoginPageValidationError()) }) }) + + describe('extractJwtPayload', () => { + it('should return the correct payload for Singpass when JWT is valid', async () => { + const spcpService = new SpcpService(MOCK_PARAMS) + // Assumes that SP auth client was instantiated first + const mockSpClient = mocked(MockAuthClient.mock.instances[0], true) + mockSpClient.verifyJWT.mockImplementationOnce((jwt, cb) => cb(null, jwt)) + const result = await spcpService.extractJwtPayload(MOCK_JWT, AuthType.SP) + expect(result._unsafeUnwrap()).toEqual(MOCK_JWT) + }) + + it('should return VerifyJwtError when SingPass JWT is invalid', async () => { + const spcpService = new SpcpService(MOCK_PARAMS) + // Assumes that SP auth client was instantiated first + const mockSpClient = mocked(MockAuthClient.mock.instances[0], true) + mockSpClient.verifyJWT.mockImplementationOnce((_jwt, cb) => + cb(new Error(), null), + ) + const result = await spcpService.extractJwtPayload(MOCK_JWT, AuthType.SP) + expect(result._unsafeUnwrapErr()).toEqual(new VerifyJwtError()) + }) + + it('should return the correct payload for Corppass when JWT is valid', async () => { + const spcpService = new SpcpService(MOCK_PARAMS) + // Assumes that SP auth client was instantiated first + const mockCpClient = mocked(MockAuthClient.mock.instances[1], true) + mockCpClient.verifyJWT.mockImplementationOnce((jwt, cb) => cb(null, jwt)) + const result = await spcpService.extractJwtPayload(MOCK_JWT, AuthType.CP) + expect(result._unsafeUnwrap()).toEqual(MOCK_JWT) + }) + + it('should return VerifyJwtError when CorpPass JWT is invalid', async () => { + const spcpService = new SpcpService(MOCK_PARAMS) + // Assumes that SP auth client was instantiated first + const mockCpClient = mocked(MockAuthClient.mock.instances[1], true) + mockCpClient.verifyJWT.mockImplementationOnce((_jwt, cb) => + cb(new Error(), null), + ) + const result = await spcpService.extractJwtPayload(MOCK_JWT, AuthType.CP) + expect(result._unsafeUnwrapErr()).toEqual(new VerifyJwtError()) + }) + }) }) diff --git a/src/app/modules/spcp/__tests__/spcp.test.constants.ts b/src/app/modules/spcp/__tests__/spcp.test.constants.ts index 9949d45137..d165bcccd0 100644 --- a/src/app/modules/spcp/__tests__/spcp.test.constants.ts +++ b/src/app/modules/spcp/__tests__/spcp.test.constants.ts @@ -36,3 +36,4 @@ export const MOCK_REDIRECT_URL = 'redirectUrl' export const MOCK_LOGIN_HTML = 'html' export const MOCK_ERROR_CODE = 'errorCode' export const MOCK_TITLE = 'title' +export const MOCK_JWT = 'jwt' diff --git a/src/app/modules/spcp/spcp.controller.ts b/src/app/modules/spcp/spcp.controller.ts index b148597176..6366e06522 100644 --- a/src/app/modules/spcp/spcp.controller.ts +++ b/src/app/modules/spcp/spcp.controller.ts @@ -3,15 +3,20 @@ import { ParamsDictionary } from 'express-serve-static-core' import { StatusCodes } from 'http-status-codes' import { createLoggerWithLabel } from '../../../config/logger' -import { AuthType } from '../../../types' +import { AuthType, IPopulatedForm } from '../../../types' import { createReqMeta } from '../../utils/request' import { SpcpFactory } from './spcp.factory' import { LoginPageValidationResult } from './spcp.types' -import { mapRouteError } from './spcp.util' +import { extractJwt, mapRouteError } from './spcp.util' const logger = createLoggerWithLabel(module) +// TODO (#42): remove these types when migrating away from middleware pattern +type WithForm = T & { + form: IPopulatedForm +} + /** * Generates redirect URL to Official SingPass/CorpPass log in page * @param req - Express request object @@ -77,3 +82,38 @@ export const handleValidate: RequestHandler< return res.status(statusCode).json({ message: errorMessage }) }) } + +/** + * Adds session to returned JSON if form-filler is SPCP Authenticated + * @param req - Express request object + * @param res - Express response object + * @param next - Express next middleware function + */ +export const addSpcpSessionInfo: RequestHandler = async ( + req, + res, + next, +) => { + const { authType } = (req as WithForm).form + if (!authType) return next() + + const jwt = extractJwt(req.cookies, authType) + if (!jwt) return next() + + return SpcpFactory.extractJwtPayload(jwt, authType) + .map(({ userName }) => { + res.locals.spcpSession = { userName } + return next() + }) + .mapErr((error) => { + logger.error({ + message: 'Failed to verify JWT with auth client', + meta: { + action: 'addSpcpSessionInfo', + ...createReqMeta(req), + }, + error, + }) + return next() + }) +} diff --git a/src/app/modules/spcp/spcp.errors.ts b/src/app/modules/spcp/spcp.errors.ts index 6cceb36852..2c3a8cc68a 100644 --- a/src/app/modules/spcp/spcp.errors.ts +++ b/src/app/modules/spcp/spcp.errors.ts @@ -34,3 +34,12 @@ export class LoginPageValidationError extends ApplicationError { super(message) } } + +/** + * Invalid JWT. + */ +export class VerifyJwtError extends ApplicationError { + constructor(message = 'Invalid JWT') { + super(message) + } +} diff --git a/src/app/modules/spcp/spcp.factory.ts b/src/app/modules/spcp/spcp.factory.ts index 4de0f12dbb..5036201fcc 100644 --- a/src/app/modules/spcp/spcp.factory.ts +++ b/src/app/modules/spcp/spcp.factory.ts @@ -12,9 +12,10 @@ import { FetchLoginPageError, InvalidAuthTypeError, LoginPageValidationError, + VerifyJwtError, } from './spcp.errors' import { SpcpService } from './spcp.service' -import { LoginPageValidationResult } from './spcp.types' +import { JwtPayload, LoginPageValidationResult } from './spcp.types' interface ISpcpFactory { createRedirectUrl( @@ -34,6 +35,10 @@ interface ISpcpFactory { LoginPageValidationResult, LoginPageValidationError | MissingFeatureError > + extractJwtPayload( + jwt: string, + authType: AuthType, + ): ResultAsync } export const createSpcpFactory = ({ @@ -46,6 +51,7 @@ export const createSpcpFactory = ({ createRedirectUrl: () => err(error), fetchLoginPage: () => errAsync(error), validateLoginPage: () => err(error), + extractJwtPayload: () => errAsync(error), } } return new SpcpService(props) diff --git a/src/app/modules/spcp/spcp.service.ts b/src/app/modules/spcp/spcp.service.ts index 1c2bd5dd35..e30831a154 100644 --- a/src/app/modules/spcp/spcp.service.ts +++ b/src/app/modules/spcp/spcp.service.ts @@ -2,7 +2,7 @@ import SPCPAuthClient from '@opengovsg/spcp-auth-client' import axios from 'axios' import fs from 'fs' import { StatusCodes } from 'http-status-codes' -import { err, ok, Result, ResultAsync } from 'neverthrow' +import { err, errAsync, ok, Result, ResultAsync } from 'neverthrow' import { ISpcpMyInfo } from '../../../config/feature-manager' import { createLoggerWithLabel } from '../../../config/logger' @@ -13,9 +13,10 @@ import { FetchLoginPageError, InvalidAuthTypeError, LoginPageValidationError, + VerifyJwtError, } from './spcp.errors' -import { LoginPageValidationResult } from './spcp.types' -import { getSubstringBetween } from './spcp.util' +import { JwtPayload, LoginPageValidationResult } from './spcp.types' +import { getSubstringBetween, verifyJwtPromise } from './spcp.util' const logger = createLoggerWithLabel(module) const LOGIN_PAGE_HEADERS = @@ -48,6 +49,13 @@ export class SpcpService { }) } + /** + * Create the URL to which the client should be redirected for Singpass/ + * Corppass login. + * @param authType 'SP' or 'CP' + * @param target The target URL which will become the SPCP RelayState + * @param esrvcId SP/CP e-service ID + */ createRedirectUrl( authType: AuthType.SP | AuthType.CP, target: string, @@ -90,6 +98,10 @@ export class SpcpService { } } + /** + * Fetches the HTML of the given URL. + * @param redirectUrl URL from which to obtain the HTML + */ fetchLoginPage( redirectUrl: string, ): ResultAsync { @@ -118,6 +130,10 @@ export class SpcpService { ) } + /** + * Validates that the login page does not have an error. + * @param loginHtml The HTML of the page to validate + */ validateLoginPage( loginHtml: string, ): Result { @@ -152,4 +168,40 @@ export class SpcpService { return ok({ isValid: false, errorCode }) } } + + /** + * Verifies a JWT and extracts its payload. + * @param jwt The contents of the JWT cookie + * @param authType 'SP' or 'CP' + */ + extractJwtPayload( + jwt: string, + authType: AuthType.SP | AuthType.CP, + ): ResultAsync { + let authClient: SPCPAuthClient + switch (authType) { + case AuthType.SP: + authClient = this.#singpassAuthClient + break + case AuthType.CP: + authClient = this.#corppassAuthClient + break + default: + return errAsync(new InvalidAuthTypeError(authType)) + } + return ResultAsync.fromPromise( + verifyJwtPromise(authClient, jwt), + (error) => { + logger.error({ + message: 'Failed to verify JWT with auth client', + meta: { + action: 'extractPayload', + authType, + }, + error, + }) + return new VerifyJwtError() + }, + ) + } } diff --git a/src/app/modules/spcp/spcp.types.ts b/src/app/modules/spcp/spcp.types.ts index e3ae491f61..66243dee76 100644 --- a/src/app/modules/spcp/spcp.types.ts +++ b/src/app/modules/spcp/spcp.types.ts @@ -1,3 +1,16 @@ +export enum JwtName { + SP = 'jwtSp', + CP = 'jwtCp', +} + export type LoginPageValidationResult = | { isValid: true } | { isValid: false; errorCode: string | null } + +export type SpcpCookies = Partial> + +export type JwtPayload = { + userName: string + userInfo?: string + rememberMe: boolean +} diff --git a/src/app/modules/spcp/spcp.util.ts b/src/app/modules/spcp/spcp.util.ts index 62c80aeb8b..6e90524160 100644 --- a/src/app/modules/spcp/spcp.util.ts +++ b/src/app/modules/spcp/spcp.util.ts @@ -1,7 +1,8 @@ +import SPCPAuthClient from '@opengovsg/spcp-auth-client' import { StatusCodes } from 'http-status-codes' import { createLoggerWithLabel } from '../../../config/logger' -import { MapRouteError } from '../../../types' +import { AuthType, MapRouteError } from '../../../types' import { MissingFeatureError } from '../core/core.errors' import { @@ -9,6 +10,7 @@ import { FetchLoginPageError, LoginPageValidationError, } from './spcp.errors' +import { JwtName, JwtPayload, SpcpCookies } from './spcp.types' const logger = createLoggerWithLabel(module) @@ -26,6 +28,34 @@ export const getSubstringBetween = ( } } +export const verifyJwtPromise = ( + authClient: SPCPAuthClient, + jwt: string, +): Promise => { + return new Promise((resolve, reject) => { + authClient.verifyJWT(jwt, (error: Error, data: JwtPayload) => { + if (error) { + return reject(error) + } + return resolve(data) + }) + }) +} + +export const extractJwt = ( + cookies: SpcpCookies, + authType: AuthType, +): string | undefined => { + switch (authType) { + case AuthType.SP: + return cookies[JwtName.SP] + case AuthType.CP: + return cookies[JwtName.CP] + default: + return undefined + } +} + export const mapRouteError: MapRouteError = (error) => { switch (error.constructor) { case MissingFeatureError: diff --git a/src/app/routes/public-forms.server.routes.js b/src/app/routes/public-forms.server.routes.js index 2101835ac9..668920b020 100644 --- a/src/app/routes/public-forms.server.routes.js +++ b/src/app/routes/public-forms.server.routes.js @@ -16,6 +16,7 @@ const { CaptchaFactory } = require('../factories/captcha.factory') const { limitRate } = require('../utils/limit-rate') const { rateLimitConfig } = require('../../config/config') const PublicFormController = require('../modules/form/public-form/public-form.controller') +const SpcpController = require('../modules/spcp/spcp.controller') module.exports = function (app) { /** @@ -127,7 +128,7 @@ module.exports = function (app) { .get( forms.formById, publicForms.isFormPublic, - spcpFactory.addSpcpSessionInfo, + SpcpController.addSpcpSessionInfo, myInfoController.addMyInfo, forms.read(forms.REQUEST_TYPE.PUBLIC), ) diff --git a/tests/unit/backend/controllers/spcp.server.controller.spec.js b/tests/unit/backend/controllers/spcp.server.controller.spec.js index 157a2442ea..9d78160446 100644 --- a/tests/unit/backend/controllers/spcp.server.controller.spec.js +++ b/tests/unit/backend/controllers/spcp.server.controller.spec.js @@ -228,74 +228,6 @@ describe('SPCP Controller', () => { }) }) - describe('addSpcpSessionInfo', () => { - let next - - beforeEach(() => { - req = { - form: { authType: 'SP' }, - cookies: { jwtSp: 'spCookie' }, - headers: {}, - ip: '127.0.0.1', - get: (key) => this[key], - } - res.locals = {} - }) - - const expectPassthroughWith = (session, cb) => { - next = jasmine.createSpy().and.callFake(() => { - if (!session) { - expect(res.locals).toEqual({}) - } else { - expect(res.locals.spcpSession.userName).toEqual(session.userName) - } - cb() - }) - Controller.addSpcpSessionInfo(authClients)(req, res, next) - } - - const replyWith = (error, data) => { - singPassAuthClient.verifyJWT.and.callFake((jwt, cb) => { - expect(jwt).toEqual('spCookie') - cb(error, data) - }) - } - - it('should call next if authType is NIL', (done) => { - req.form.authType = 'NIL' - expectPassthroughWith(null, done) - }) - - it('should call next if authType undefined', (done) => { - req.form.authType = '' - expectPassthroughWith(null, done) - }) - - it('should call next if cookie undefined', (done) => { - req.cookies.jwtSp = '' - expectPassthroughWith(null, done) - }) - - it('should not update spcpSession if verifyJWT fails', (done) => { - req.form.authType = 'SP' - replyWith('error', {}) - expectPassthroughWith(null, done) - }) - - it('should update spcpSession if verifyJWT succeeds', (done) => { - req.form.authType = 'SP' - replyWith(null, { - userName: 'abc', - }) - expectPassthroughWith( - { - userName: 'abc', - }, - done, - ) - }) - }) - describe('singPassLogin/corpPassLogin - validation', () => { let spB64Artifact let expectedRelayState