Skip to content

Commit

Permalink
feat: Add origin restriction to session token (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillDogadin-std authored Apr 26, 2023
1 parent 0ffdb6c commit 100a018
Show file tree
Hide file tree
Showing 17 changed files with 230 additions and 61 deletions.
11 changes: 11 additions & 0 deletions api/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"pino-pretty": "^10.0.0",
"vite-node": "^0.29.2",
"vitest": "^0.29.2",
"wildcard-match": "^5.1.2",
"zod": "^3.21.4"
},
"devDependencies": {
Expand Down
1 change: 1 addition & 0 deletions api/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ model Session {
revokedAt DateTime?
referenceTokenId String
isUserCreated Boolean @default(false)
allowedOrigins String // comma separated strings
creator User @relation(fields: [createdBy], references: [id], onDelete: Cascade)
Expand Down
4 changes: 4 additions & 0 deletions api/src/generated/nexus.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ declare global {

export interface NexusGenInputs {
SessionCreate: { // input type
allowedOrigins: string; // String!
expiryDurationSeconds?: number | null; // Int
name: string; // String!
}
Expand Down Expand Up @@ -80,6 +81,7 @@ export interface NexusGenObjects {
Mutation: {};
Query: {};
Session: { // root type
allowedOrigins?: string | null; // String
createdAt: NexusGenScalars['GQLDateBase']; // GQLDateBase!
createdBy: string; // String!
id: string; // String!
Expand Down Expand Up @@ -138,6 +140,7 @@ export interface NexusGenFieldTypes {
sessions: Array<NexusGenRootTypes['Session'] | null> | null; // [Session]
}
Session: { // field return type
allowedOrigins: string | null; // String
createdAt: NexusGenScalars['GQLDateBase']; // GQLDateBase!
createdBy: string; // String!
id: string; // String!
Expand Down Expand Up @@ -186,6 +189,7 @@ export interface NexusGenFieldTypeNames {
sessions: 'Session'
}
Session: { // field return type name
allowedOrigins: 'String'
createdAt: 'GQLDateBase'
createdBy: 'String'
id: 'String'
Expand Down
2 changes: 2 additions & 0 deletions api/src/generated/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type Query {
}

type Session {
allowedOrigins: String
createdAt: GQLDateBase!
createdBy: String!
id: String!
Expand All @@ -47,6 +48,7 @@ type Session {
}

input SessionCreate {
allowedOrigins: String!
expiryDurationSeconds: Int
name: String!
}
Expand Down
5 changes: 4 additions & 1 deletion api/src/graphql/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export interface Context {
prisma: typeof prisma;
getSession: () => Promise<Session>;
apolloLogger: pino.Logger;
origin: string | undefined;
}

type CreateContextParams = {
Expand All @@ -29,11 +30,13 @@ export function createContext(params: CreateContextParams): Context {
const authorizationHeader = req.get('Authorization');
const cookieAuthHeader = req.cookies['gql:default'];
const token = authorizationHeader?.replace('Bearer ', '');
const origin = req.get('Origin');

return {
request: params,
prisma,
apolloLogger,
getSession: async () => prisma.session.getSessionByToken(token || cookieAuthHeader),
getSession: async () => prisma.session.getSessionByToken(origin, token || cookieAuthHeader),
origin,
};
}
81 changes: 81 additions & 0 deletions api/src/modules/Session/helpers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import type { PrismaClient, Prisma } from '@prisma/client';
import { randomUUID } from 'crypto';
import { GraphQLError } from 'graphql';
import wildcard from 'wildcard-match';
import { token as tokenUtils } from '../../helpers';

function parseOriginMarkup(originParam: string): string {
if (originParam === '*') {
return '*';
}
const trimmedOriginParam = originParam.trim();
const origins = trimmedOriginParam.split(',').map((origin) => origin.trim());
origins.forEach((origin) => {
if (!origin.startsWith('http://') && !origin.startsWith('https://')) {
throw new GraphQLError("Origin must start with 'http://' or 'https://'", {
extensions: { code: 'INVALID_ORIGIN_PROTOCOL' },
});
}
});
return origins.join(',');
}

export function validateOriginAgainstAllowed(
allowedOrigins: string,
originReceived?: string,
) {
if (allowedOrigins === '*') {
return;
}
if (!originReceived) {
throw new GraphQLError('Origin not provided', {
extensions: { code: 'ORIGIN_HEADER_MISSING' },
});
}
const allowedOriginsSplit = allowedOrigins.split(',');
if (!wildcard(allowedOriginsSplit)(originReceived)) {
throw new GraphQLError('Access denied due to origin restriction', {
extensions: { code: 'ORIGIN_FORBIDDEN' },
});
}
}

async function newSession(
prisma: PrismaClient,
session: Prisma.SessionCreateInput,
) {
return prisma.session.create({
data: session,
});
}

export const generateTokenAndSession = async (
prisma: PrismaClient,
userId: string,
session: { expiryDurationSeconds?: number | null; name: string; allowedOrigins: string },
isUserCreated: boolean = false,
) => {
const createId = randomUUID();
const createdToken = tokenUtils.generate(createId, session.expiryDurationSeconds);
const expiryDate = tokenUtils.getExpiryDateFromToken(createdToken);
const formattedToken = tokenUtils.format(createdToken);
const parsedAllowedOrigins = parseOriginMarkup(session.allowedOrigins);
const createData = {
allowedOrigins: parsedAllowedOrigins,
name: session.name,
referenceExpiryDate: expiryDate,
id: createId,
referenceTokenId: formattedToken,
isUserCreated,
creator: {
connect: {
id: userId,
},
},
};
const createdSession = await newSession(prisma, createData);
return {
token: createdToken,
session: createdSession,
};
};
56 changes: 11 additions & 45 deletions api/src/modules/Session/model.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import type { PrismaClient, Prisma } from '@prisma/client';
import type { PrismaClient } from '@prisma/client';
import { inputObjectType, objectType } from 'nexus/dist';
import { randomUUID } from 'crypto';
import { GraphQLError } from 'graphql';
import ms from 'ms';
import { token as tokenUtils } from '../../helpers';
import { JWT_EXPIRATION_PERIOD } from '../../env';
import { validateOriginAgainstAllowed, generateTokenAndSession } from './helpers';

export const Session = objectType({
name: 'Session',
Expand All @@ -17,6 +17,7 @@ export const Session = objectType({
t.nonNull.boolean('isUserCreated');
t.string('name');
t.date('revokedAt');
t.string('allowedOrigins');
},
});

Expand All @@ -25,6 +26,7 @@ export const SessionCreate = inputObjectType({
definition(t) {
t.int('expiryDurationSeconds');
t.nonNull.string('name');
t.nonNull.string('allowedOrigins');
},
});

Expand All @@ -36,44 +38,6 @@ export const SessionCreateOutput = objectType({
},
});

async function newSession(
prisma: PrismaClient,
session: Prisma.SessionCreateInput,
) {
return prisma.session.create({
data: session,
});
}

const generateTokenAndSession = async (
prisma: PrismaClient,
userId: string,
session: { expiryDurationSeconds?: number | null; name: string },
isUserCreated: boolean = false,
) => {
const createId = randomUUID();
const createdToken = tokenUtils.generate(createId, session.expiryDurationSeconds);
const expiryDate = tokenUtils.getExpiryDateFromToken(createdToken);
const formattedToken = tokenUtils.format(createdToken);
const createData = {
name: session.name,
referenceExpiryDate: expiryDate,
id: createId,
referenceTokenId: formattedToken,
isUserCreated,
creator: {
connect: {
id: userId,
},
},
};
const createdSession = await newSession(prisma, createData);
return {
token: createdToken,
session: createdSession,
};
};

export function getSessionCrud(prisma: PrismaClient) {
return {
listSessions: async (userId: string) => prisma.session.findMany({
Expand Down Expand Up @@ -112,22 +76,23 @@ export function getSessionCrud(prisma: PrismaClient) {
throw new GraphQLError('Failed to revoke session', { extensions: { code: 'REVOKE_SESSION_FAILED' } });
}
},
createSignInSession: async (userId: string) => generateTokenAndSession(
createSignInSession: async (userId: string, origin: string = '*') => generateTokenAndSession(
prisma,
userId,
{ expiryDurationSeconds: ms(JWT_EXPIRATION_PERIOD) / 1000, name: 'Sign in' },
{ expiryDurationSeconds: ms(JWT_EXPIRATION_PERIOD) / 1000, name: 'Sign in', allowedOrigins: origin },
),
createSignUpSession: async (userId: string) => generateTokenAndSession(
createSignUpSession: async (userId: string, origin: string = '*') => generateTokenAndSession(
prisma,
userId,
{ expiryDurationSeconds: ms(JWT_EXPIRATION_PERIOD) / 1000, name: 'Sign up' },
{ expiryDurationSeconds: ms(JWT_EXPIRATION_PERIOD) / 1000, name: 'Sign up', allowedOrigins: origin },
),
createCustomSession: async (
userId: string,
session: { expiryDurationSeconds?: number | null; name: string },
session: { expiryDurationSeconds?: number | null; name: string, allowedOrigins: string },
isUserCreated: boolean = false,
) => generateTokenAndSession(prisma, userId, session, isUserCreated),
async getSessionByToken(
origin?: string,
token?: string,
) {
if (!token) {
Expand All @@ -150,6 +115,7 @@ export function getSessionCrud(prisma: PrismaClient) {
extensions: { code: 'SESSION_EXPIRED' },
});
}
validateOriginAgainstAllowed(session.allowedOrigins, origin);
return session;
},

Expand Down
4 changes: 2 additions & 2 deletions api/src/modules/User/resolvers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export const signIn = mutationField('signIn', {
},
resolve: async (_parent, { user: userNamePass }, ctx) => {
const { id } = await ctx.prisma.user.getUserByUsernamePassword(userNamePass);
return ctx.prisma.session.createSignInSession(id);
return ctx.prisma.session.createSignInSession(id, ctx.origin);
},
});

Expand All @@ -30,6 +30,6 @@ export const signUp = mutationField('signUp', {
},
resolve: async (_parent, { user }, ctx) => {
const { id } = await ctx.prisma.user.createUser(user);
return ctx.prisma.session.createSignUpSession(id);
return ctx.prisma.session.createSignUpSession(id, ctx.origin);
},
});
1 change: 1 addition & 0 deletions api/tests/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ test('Authentication: sign up, sign in, request protected enpoint', async () =>

const token = signInResponse?.signIn?.token;
ctx.client.setHeader('Authorization', `Bearer ${token}`);
ctx.client.setHeader('Origin', 'http://localhost:3000');

const meResponse = (await executeGraphQlQuery(meQuery)) as Record<
string,
Expand Down
Loading

0 comments on commit 100a018

Please sign in to comment.