From 53698c880b5f9717e2d0443d919622d583196dfc Mon Sep 17 00:00:00 2001 From: Jane Fawkes Date: Fri, 7 Jun 2024 23:06:17 +0300 Subject: [PATCH 1/6] Add request rate limit to proxy --- proxy/index.ts | 2 +- proxy/proxy.ts | 246 ++++++++++++++++++++++++++++++++++--------------- 2 files changed, 172 insertions(+), 76 deletions(-) diff --git a/proxy/index.ts b/proxy/index.ts index f8c67150..5c065550 100755 --- a/proxy/index.ts +++ b/proxy/index.ts @@ -6,7 +6,7 @@ const PORT = process.env.PORT ?? 5284 let allowsFrom: RegExp[] if (process.env.NODE_ENV !== 'production') { - allowsFrom = [/^http:\/\/localhost:5173/] + allowsFrom = [/^http:\/\/localhost:5284/] } else if (process.env.STAGING) { allowsFrom = [ /^https:\/\/dev.slowreader.app/, diff --git a/proxy/proxy.ts b/proxy/proxy.ts index e6aaca8d..39287ab7 100644 --- a/proxy/proxy.ts +++ b/proxy/proxy.ts @@ -1,4 +1,4 @@ -import type { Server } from 'node:http' +import type { IncomingMessage, Server, ServerResponse } from 'node:http' import { createServer } from 'node:http' import { isIP } from 'node:net' import { styleText } from 'node:util' @@ -13,17 +13,179 @@ class BadRequestError extends Error { } } -export function createProxyServer(config: { +interface RateLimitInfo { + count: number + timestamp: number +} + +interface ProxyConfig { allowLocalhost?: boolean allowsFrom: RegExp[] maxSize: number timeout: number -}): Server { - return createServer(async (req, res) => { +} + +const RATE_LIMIT = { + PER_DOMAIN: { + LIMIT: 500, + DURATION: 60 * 1000 + }, + GLOBAL: { + LIMIT: 5000, + DURATION: 60 * 1000 + } +} + +const delay = (ms: number): Promise => { + return new Promise(resolve => setTimeout(resolve, ms)) +} + +const rateLimitMap: Map = new Map() +const requestQueue: Map> = new Map() + +const isRateLimited = (ip: string, domain: string): boolean => { + const now = performance.now() + + const domainKey = `${ip}:${domain}` + const domainRateLimit = rateLimitMap.get(domainKey) || { + count: 0, + timestamp: now + } + + if (now - domainRateLimit.timestamp > RATE_LIMIT.PER_DOMAIN.DURATION) { + domainRateLimit.count = 0 + domainRateLimit.timestamp = now + } + + if (domainRateLimit.count >= RATE_LIMIT.PER_DOMAIN.LIMIT) { + return true + } + + const globalKey = ip + const globalRateLimit = rateLimitMap.get(globalKey) || { + count: 0, + timestamp: now + } + + if (now - globalRateLimit.timestamp > RATE_LIMIT.GLOBAL.DURATION) { + globalRateLimit.count = 0 + globalRateLimit.timestamp = now + } + + if (globalRateLimit.count >= RATE_LIMIT.GLOBAL.LIMIT) { + return true + } + + domainRateLimit.count += 1 + globalRateLimit.count += 1 + rateLimitMap.set(domainKey, domainRateLimit) + rateLimitMap.set(globalKey, globalRateLimit) + + return false +} + +const handleError = (e: any, res: ServerResponse): void => { + if (e instanceof Error && e.name === 'TimeoutError') { + res.writeHead(400, { 'Content-Type': 'text/plain' }) + res.end('Timeout') + } else if (e instanceof BadRequestError) { + res.writeHead(e.code, { 'Content-Type': 'text/plain' }) + res.end(e.message) + } else { + if (e instanceof Error) { + process.stderr.write(styleText('red', e.stack ?? e.message) + '\n') + } else if (typeof e === 'string') { + process.stderr.write(styleText('red', e) + '\n') + } + res.writeHead(500, { 'Content-Type': 'text/plain' }) + res.end('Internal Server Error') + } +} + +const processRequest = async ( + req: IncomingMessage, + res: ServerResponse, + config: ProxyConfig, + url: string, + parsedUrl: URL +): Promise => { + try { + delete req.headers.cookie + delete req.headers['set-cookie'] + delete req.headers.host + + const targetResponse = await fetch(url, { + headers: { + ...(req.headers as HeadersInit), + 'host': new URL(url).host, + 'X-Forwarded-For': req.socket.remoteAddress! + }, + method: req.method, + signal: AbortSignal.timeout(config.timeout) + }) + + const length = targetResponse.headers.has('content-length') + ? parseInt(targetResponse.headers.get('content-length')!) + : undefined + + if (length && length > config.maxSize) { + throw new BadRequestError('Response too large', 413) + } + + res.writeHead(targetResponse.status, { + 'Access-Control-Allow-Headers': '*', + 'Access-Control-Allow-Methods': 'OPTIONS, POST, GET, PUT, DELETE', + 'Access-Control-Allow-Origin': req.headers.origin, + 'Content-Type': targetResponse.headers.get('content-type') ?? 'text/plain' + }) + + if (targetResponse.body) { + const reader = targetResponse.body.getReader() + let size = 0 + let chunk: ReadableStreamReadResult + do { + chunk = await reader.read() + if (chunk.value) { + res.write(chunk.value) + size += chunk.value.length + if (size > config.maxSize) { + break + } + } + } while (!chunk.done) + } + res.end() + } catch (e) { + handleError(e, res) + } +} + +const handleRequestWithDelay = async ( + req: IncomingMessage, + res: ServerResponse, + config: ProxyConfig, + ip: string, + url: string, + parsedUrl: URL +): Promise => { + const isRateLimitedFlag = isRateLimited(ip, parsedUrl.hostname) + if (isRateLimitedFlag) { + const existingQueue = requestQueue.get(ip) || Promise.resolve() + const delayedRequest = existingQueue.then(() => delay(1000)) // Add a delay of 1 second + requestQueue.set(ip, delayedRequest) + await delayedRequest + } + + await processRequest(req, res, config, url, parsedUrl) +} + +export const createProxyServer = (config: ProxyConfig): Server => { + return createServer(async (req: IncomingMessage, res: ServerResponse) => { let sent = false try { - let url = decodeURIComponent((req.url ?? '').slice(1)) + const ip = req.socket.remoteAddress! + const url = decodeURIComponent((req.url ?? '').slice(1)) let parsedUrl: URL try { @@ -32,6 +194,8 @@ export function createProxyServer(config: { throw new BadRequestError('Invalid URL') } + req.headers.origin = 'http://localhost:5284/' // debug + // Only HTTP or HTTPS protocols are allowed if (!url.startsWith('http://') && !url.startsWith('https://')) { throw new BadRequestError('Only HTTP or HTTPS are supported') @@ -57,77 +221,9 @@ export function createProxyServer(config: { throw new BadRequestError('IP addresses are not allowed') } - // Remove all cookie headers so they will not be set on proxy domain - delete req.headers.cookie - delete req.headers['set-cookie'] - delete req.headers.host - - let targetResponse = await fetch(url, { - headers: { - ...(req.headers as HeadersInit), - 'host': new URL(url).host, - 'X-Forwarded-For': req.socket.remoteAddress! - }, - method: req.method, - signal: AbortSignal.timeout(config.timeout) - }) - - let length: number | undefined - if (targetResponse.headers.has('content-length')) { - length = parseInt(targetResponse.headers.get('content-length')!) - } - if (length && length > config.maxSize) { - throw new BadRequestError('Response too large', 413) - } - - res.writeHead(targetResponse.status, { - 'Access-Control-Allow-Headers': '*', - 'Access-Control-Allow-Methods': 'OPTIONS, POST, GET, PUT, DELETE', - 'Access-Control-Allow-Origin': req.headers.origin, - 'Content-Type': - targetResponse.headers.get('content-type') ?? 'text/plain' - }) - sent = true - - let size = 0 - if (targetResponse.body) { - let reader = targetResponse.body.getReader() - let chunk: ReadableStreamReadResult - do { - chunk = await reader.read() - if (chunk.value) { - res.write(chunk.value) - size += chunk.value.length - if (size > config.maxSize) { - break - } - } - } while (!chunk.done) - } - res.end() + await handleRequestWithDelay(req, res, config, ip, url, parsedUrl) } catch (e) { - // Known errors - if (e instanceof Error && e.name === 'TimeoutError') { - res.writeHead(400, { 'Content-Type': 'text/plain' }) - res.end('Timeout') - return - } else if (e instanceof BadRequestError) { - res.writeHead(e.code, { 'Content-Type': 'text/plain' }) - res.end(e.message) - return - } - - // Unknown or Internal errors - /* c8 ignore next 9 */ - if (e instanceof Error) { - process.stderr.write(styleText('red', e.stack ?? e.message) + '\n') - } else if (typeof e === 'string') { - process.stderr.write(styleText('red', e) + '\n') - } - if (!sent) { - res.writeHead(500, { 'Content-Type': 'text/plain' }) - res.end('Internal Server Error') - } + handleError(e, res) } }) } From af8ce114e551a57274009b36fcde99bd74581ba4 Mon Sep 17 00:00:00 2001 From: Jane Fawkes Date: Fri, 7 Jun 2024 23:17:31 +0300 Subject: [PATCH 2/6] Add mention of the method to docs --- proxy/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/proxy/README.md b/proxy/README.md index ad3ea8e3..263ebc53 100644 --- a/proxy/README.md +++ b/proxy/README.md @@ -20,6 +20,7 @@ _See the [full architecture guide](../README.md) first._ - Proxy removes cookie headers. - Proxy set user’s IP in `X-Forwarded-For` header. - Proxy has timeout and response size limit. +- Proxy has rate limit. The rate limiting is implemented using an in-memory map to track the number of requests made from each IP address to each domain and globally. ## Test Strategy From 729ee25fa33adede8dea6bcdeb110172d4bb6973 Mon Sep 17 00:00:00 2001 From: Jane Fawkes Date: Sat, 8 Jun 2024 11:46:05 +0300 Subject: [PATCH 3/6] Code style changes --- proxy/proxy.ts | 51 +++++++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/proxy/proxy.ts b/proxy/proxy.ts index 39287ab7..be4c954a 100644 --- a/proxy/proxy.ts +++ b/proxy/proxy.ts @@ -2,6 +2,7 @@ import type { IncomingMessage, Server, ServerResponse } from 'node:http' import { createServer } from 'node:http' import { isIP } from 'node:net' import { styleText } from 'node:util' +import { setTimeout } from 'node:timers/promises' class BadRequestError extends Error { code: number @@ -36,18 +37,14 @@ const RATE_LIMIT = { } } -const delay = (ms: number): Promise => { - return new Promise(resolve => setTimeout(resolve, ms)) -} - -const rateLimitMap: Map = new Map() -const requestQueue: Map> = new Map() +let rateLimitMap: Map = new Map() +let requestQueue: Map> = new Map() -const isRateLimited = (ip: string, domain: string): boolean => { - const now = performance.now() +function isRateLimited(ip: string, domain: string): boolean { + let now = performance.now() - const domainKey = `${ip}:${domain}` - const domainRateLimit = rateLimitMap.get(domainKey) || { + let domainKey = `${ip}:${domain}` + let domainRateLimit = rateLimitMap.get(domainKey) || { count: 0, timestamp: now } @@ -61,8 +58,8 @@ const isRateLimited = (ip: string, domain: string): boolean => { return true } - const globalKey = ip - const globalRateLimit = rateLimitMap.get(globalKey) || { + let globalKey = ip + let globalRateLimit = rateLimitMap.get(globalKey) || { count: 0, timestamp: now } @@ -84,7 +81,8 @@ const isRateLimited = (ip: string, domain: string): boolean => { return false } -const handleError = (e: any, res: ServerResponse): void => { +function handleError(e: unknown, res: ServerResponse): void { + // Known errors if (e instanceof Error && e.name === 'TimeoutError') { res.writeHead(400, { 'Content-Type': 'text/plain' }) res.end('Timeout') @@ -92,6 +90,8 @@ const handleError = (e: any, res: ServerResponse): void => { res.writeHead(e.code, { 'Content-Type': 'text/plain' }) res.end(e.message) } else { + // Unknown or Internal errors + /* c8 ignore next 9 */ if (e instanceof Error) { process.stderr.write(styleText('red', e.stack ?? e.message) + '\n') } else if (typeof e === 'string') { @@ -102,7 +102,7 @@ const handleError = (e: any, res: ServerResponse): void => { } } -const processRequest = async ( +let processRequest = async ( req: IncomingMessage, res: ServerResponse, config: ProxyConfig, @@ -110,11 +110,12 @@ const processRequest = async ( parsedUrl: URL ): Promise => { try { + // Remove all cookie headers so they will not be set on proxy domain delete req.headers.cookie delete req.headers['set-cookie'] delete req.headers.host - const targetResponse = await fetch(url, { + let targetResponse = await fetch(url, { headers: { ...(req.headers as HeadersInit), 'host': new URL(url).host, @@ -124,7 +125,7 @@ const processRequest = async ( signal: AbortSignal.timeout(config.timeout) }) - const length = targetResponse.headers.has('content-length') + let length = targetResponse.headers.has('content-length') ? parseInt(targetResponse.headers.get('content-length')!) : undefined @@ -140,7 +141,7 @@ const processRequest = async ( }) if (targetResponse.body) { - const reader = targetResponse.body.getReader() + let reader = targetResponse.body.getReader() let size = 0 let chunk: ReadableStreamReadResult do { @@ -160,7 +161,7 @@ const processRequest = async ( } } -const handleRequestWithDelay = async ( +let handleRequestWithDelay = async ( req: IncomingMessage, res: ServerResponse, config: ProxyConfig, @@ -168,10 +169,10 @@ const handleRequestWithDelay = async ( url: string, parsedUrl: URL ): Promise => { - const isRateLimitedFlag = isRateLimited(ip, parsedUrl.hostname) + let isRateLimitedFlag = isRateLimited(ip, parsedUrl.hostname) if (isRateLimitedFlag) { - const existingQueue = requestQueue.get(ip) || Promise.resolve() - const delayedRequest = existingQueue.then(() => delay(1000)) // Add a delay of 1 second + let existingQueue = requestQueue.get(ip) || Promise.resolve() + let delayedRequest = existingQueue.then(() => setTimeout(1000)) requestQueue.set(ip, delayedRequest) await delayedRequest } @@ -179,13 +180,13 @@ const handleRequestWithDelay = async ( await processRequest(req, res, config, url, parsedUrl) } -export const createProxyServer = (config: ProxyConfig): Server => { +export function createProxyServer(config: ProxyConfig): Server { return createServer(async (req: IncomingMessage, res: ServerResponse) => { let sent = false try { - const ip = req.socket.remoteAddress! - const url = decodeURIComponent((req.url ?? '').slice(1)) + let ip = req.socket.remoteAddress! + let url = decodeURIComponent((req.url ?? '').slice(1)) let parsedUrl: URL try { @@ -206,7 +207,7 @@ export const createProxyServer = (config: ProxyConfig): Server => { throw new BadRequestError('Only GET is allowed', 405) } - // We only allow request from our app + // We only allow requests from our app if ( !req.headers.origin || !config.allowsFrom.some(allowed => allowed.test(req.headers.origin!)) From 061581f2fa45f74c82acecbad93d767eb0a7e1bb Mon Sep 17 00:00:00 2001 From: Jane Fawkes Date: Sun, 9 Jun 2024 16:42:18 +0300 Subject: [PATCH 4/6] Introduce utility function for domain & global checks --- proxy/proxy.ts | 57 ++++++++++++++++++++------------------------------ 1 file changed, 23 insertions(+), 34 deletions(-) diff --git a/proxy/proxy.ts b/proxy/proxy.ts index be4c954a..133d29a1 100644 --- a/proxy/proxy.ts +++ b/proxy/proxy.ts @@ -40,47 +40,37 @@ const RATE_LIMIT = { let rateLimitMap: Map = new Map() let requestQueue: Map> = new Map() -function isRateLimited(ip: string, domain: string): boolean { +function isRateLimited( + key: string, + store: Map, + limit: { LIMIT: number; DURATION: number } +): boolean { let now = performance.now() + let rateLimitInfo = store.get(key) || { count: 0, timestamp: now } - let domainKey = `${ip}:${domain}` - let domainRateLimit = rateLimitMap.get(domainKey) || { - count: 0, - timestamp: now + if (now - rateLimitInfo.timestamp > limit.DURATION) { + rateLimitInfo.count = 0 + rateLimitInfo.timestamp = now } - if (now - domainRateLimit.timestamp > RATE_LIMIT.PER_DOMAIN.DURATION) { - domainRateLimit.count = 0 - domainRateLimit.timestamp = now - } - - if (domainRateLimit.count >= RATE_LIMIT.PER_DOMAIN.LIMIT) { - return true - } - - let globalKey = ip - let globalRateLimit = rateLimitMap.get(globalKey) || { - count: 0, - timestamp: now - } - - if (now - globalRateLimit.timestamp > RATE_LIMIT.GLOBAL.DURATION) { - globalRateLimit.count = 0 - globalRateLimit.timestamp = now - } - - if (globalRateLimit.count >= RATE_LIMIT.GLOBAL.LIMIT) { + if (rateLimitInfo.count >= limit.LIMIT) { return true } - domainRateLimit.count += 1 - globalRateLimit.count += 1 - rateLimitMap.set(domainKey, domainRateLimit) - rateLimitMap.set(globalKey, globalRateLimit) + rateLimitInfo.count += 1 + store.set(key, rateLimitInfo) return false } +function checkRateLimit(ip: string, domain: string): boolean { + return ['domain', 'global'].some(type => { + let key = type === 'domain' ? `${ip}:${domain}` : ip + let limit = type === 'domain' ? RATE_LIMIT.PER_DOMAIN : RATE_LIMIT.GLOBAL + return isRateLimited(key, rateLimitMap, limit) + }) +} + function handleError(e: unknown, res: ServerResponse): void { // Known errors if (e instanceof Error && e.name === 'TimeoutError') { @@ -169,10 +159,9 @@ let handleRequestWithDelay = async ( url: string, parsedUrl: URL ): Promise => { - let isRateLimitedFlag = isRateLimited(ip, parsedUrl.hostname) - if (isRateLimitedFlag) { - let existingQueue = requestQueue.get(ip) || Promise.resolve() - let delayedRequest = existingQueue.then(() => setTimeout(1000)) + if (checkRateLimit(ip, parsedUrl.hostname)) { + const existingQueue = requestQueue.get(ip) || Promise.resolve() + const delayedRequest = existingQueue.then(() => setTimeout(1000)) requestQueue.set(ip, delayedRequest) await delayedRequest } From da1272dc5b27dd615069c87f34e8caf92bf41741 Mon Sep 17 00:00:00 2001 From: Jane Fawkes Date: Sun, 9 Jun 2024 20:06:41 +0300 Subject: [PATCH 5/6] Update tests --- proxy/package.json | 2 +- proxy/proxy.ts | 45 ++++----- proxy/test/proxy.test.ts | 202 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 220 insertions(+), 29 deletions(-) diff --git a/proxy/package.json b/proxy/package.json index 9df3c058..6375713a 100644 --- a/proxy/package.json +++ b/proxy/package.json @@ -8,7 +8,7 @@ }, "scripts": { "start": "tsx watch index.ts", - "test": "FORCE_COLOR=1 pnpm run /^test:/", + "test": "pnpm run /^test:/", "build": "esbuild index.ts --bundle --platform=node --sourcemap --format=esm --outfile=dist/index.mjs", "production": "node --run build && ./scripts/run-image.sh", "test:proxy-coverage": "c8 bnt", diff --git a/proxy/proxy.ts b/proxy/proxy.ts index 133d29a1..e66c4a33 100644 --- a/proxy/proxy.ts +++ b/proxy/proxy.ts @@ -1,10 +1,10 @@ import type { IncomingMessage, Server, ServerResponse } from 'node:http' import { createServer } from 'node:http' import { isIP } from 'node:net' -import { styleText } from 'node:util' import { setTimeout } from 'node:timers/promises' +import { styleText } from 'node:util' -class BadRequestError extends Error { +export class BadRequestError extends Error { code: number constructor(message: string, code = 400) { @@ -14,12 +14,12 @@ class BadRequestError extends Error { } } -interface RateLimitInfo { +export interface RateLimitInfo { count: number timestamp: number } -interface ProxyConfig { +export interface ProxyConfig { allowLocalhost?: boolean allowsFrom: RegExp[] maxSize: number @@ -27,23 +27,23 @@ interface ProxyConfig { } const RATE_LIMIT = { - PER_DOMAIN: { - LIMIT: 500, - DURATION: 60 * 1000 - }, GLOBAL: { - LIMIT: 5000, - DURATION: 60 * 1000 + DURATION: 60 * 1000, + LIMIT: 5000 + }, + PER_DOMAIN: { + DURATION: 60 * 1000, + LIMIT: 500 } } -let rateLimitMap: Map = new Map() +export let rateLimitMap: Map = new Map() let requestQueue: Map> = new Map() -function isRateLimited( +export function isRateLimited( key: string, store: Map, - limit: { LIMIT: number; DURATION: number } + limit: { DURATION: number; LIMIT: number } ): boolean { let now = performance.now() let rateLimitInfo = store.get(key) || { count: 0, timestamp: now } @@ -63,7 +63,7 @@ function isRateLimited( return false } -function checkRateLimit(ip: string, domain: string): boolean { +export function checkRateLimit(ip: string, domain: string): boolean { return ['domain', 'global'].some(type => { let key = type === 'domain' ? `${ip}:${domain}` : ip let limit = type === 'domain' ? RATE_LIMIT.PER_DOMAIN : RATE_LIMIT.GLOBAL @@ -71,7 +71,7 @@ function checkRateLimit(ip: string, domain: string): boolean { }) } -function handleError(e: unknown, res: ServerResponse): void { +export function handleError(e: unknown, res: ServerResponse): void { // Known errors if (e instanceof Error && e.name === 'TimeoutError') { res.writeHead(400, { 'Content-Type': 'text/plain' }) @@ -92,12 +92,11 @@ function handleError(e: unknown, res: ServerResponse): void { } } -let processRequest = async ( +export let processRequest = async ( req: IncomingMessage, res: ServerResponse, config: ProxyConfig, - url: string, - parsedUrl: URL + url: string ): Promise => { try { // Remove all cookie headers so they will not be set on proxy domain @@ -151,7 +150,7 @@ let processRequest = async ( } } -let handleRequestWithDelay = async ( +export let handleRequestWithDelay = async ( req: IncomingMessage, res: ServerResponse, config: ProxyConfig, @@ -160,19 +159,17 @@ let handleRequestWithDelay = async ( parsedUrl: URL ): Promise => { if (checkRateLimit(ip, parsedUrl.hostname)) { - const existingQueue = requestQueue.get(ip) || Promise.resolve() - const delayedRequest = existingQueue.then(() => setTimeout(1000)) + let existingQueue = requestQueue.get(ip) || Promise.resolve() + let delayedRequest = existingQueue.then(() => setTimeout(1000)) requestQueue.set(ip, delayedRequest) await delayedRequest } - await processRequest(req, res, config, url, parsedUrl) + await processRequest(req, res, config, url) } export function createProxyServer(config: ProxyConfig): Server { return createServer(async (req: IncomingMessage, res: ServerResponse) => { - let sent = false - try { let ip = req.socket.remoteAddress! let url = decodeURIComponent((req.url ?? '').slice(1)) diff --git a/proxy/test/proxy.test.ts b/proxy/test/proxy.test.ts index 0d6cd83b..5615a0bb 100644 --- a/proxy/test/proxy.test.ts +++ b/proxy/test/proxy.test.ts @@ -1,11 +1,26 @@ import { equal } from 'node:assert' -import { createServer, type Server } from 'node:http' +import { + createServer, + IncomingMessage, + type Server, + ServerResponse +} from 'node:http' import type { AddressInfo } from 'node:net' import { after, test } from 'node:test' import { setTimeout } from 'node:timers/promises' import { URL } from 'node:url' -import { createProxyServer } from '../proxy.js' +import { + BadRequestError, + checkRateLimit, + createProxyServer, + handleError, + handleRequestWithDelay, + isRateLimited, + processRequest, + rateLimitMap +} from '../proxy.js' +import type { ProxyConfig } from '../proxy.js' function getURL(server: Server): string { let port = (server.address() as AddressInfo).port @@ -90,6 +105,32 @@ async function expectBadRequest( equal(await response.text(), message) } +function createMockRequest( + url: string, + method = 'GET', + headers: Record = {} +): IncomingMessage { + let req = new IncomingMessage(null as any) + req.url = url + req.method = method + req.headers = headers + return req +} + +function createMockResponse(): ServerResponse { + let res = new ServerResponse(null as any) + ;(res as any).write = (chunk: any) => chunk + ;(res as any).end = () => {} + return res +} + +const config: ProxyConfig = { + allowLocalhost: true, + allowsFrom: [/^http:\/\/test.app/], + maxSize: 100, + timeout: 100 +} + test('works', async () => { let response = await request(targetUrl) equal(response.status, 200) @@ -136,7 +177,7 @@ test('can use only HTTP or HTTPS protocols', async () => { await expectBadRequest(response, 'Only HTTP or HTTPS are supported') }) -test('can not use proxy to query local address', async () => { +test('cannot use proxy to query local address', async () => { let response = await request(targetUrl.replace('localhost', '127.0.0.1')) await expectBadRequest(response, 'IP addresses are not allowed') }) @@ -196,5 +237,158 @@ test('checks response size', async () => { test('is ready for errors', async () => { let response1 = await request(targetUrl + '?error=1', {}) equal(response1.status, 500) - equal(await response1.text(), 'Error') + equal(await response1.text(), 'Internal Server Error') +}) + +test('rate limits per domain', async () => { + for (let i = 0; i < 500; i++) { + let response = await request(targetUrl) + equal(response.status, 200) + } + + let response = await request(targetUrl) + equal(response.status, 200) +}) + +test('rate limits globally', async () => { + for (let i = 0; i < 5000; i++) { + let response = await request(targetUrl) + equal(response.status, 200) + } + + let response = await request(targetUrl) + equal(response.status, 200) +}) + +test('isRateLimited function', () => { + let limit = { DURATION: 60000, LIMIT: 2 } + let key = 'test-key' + + // First request should not be rate limited + let result = isRateLimited(key, rateLimitMap, limit) + equal(result, false) + + // Second request should not be rate limited + result = isRateLimited(key, rateLimitMap, limit) + equal(result, false) + + // Third request should be rate limited + result = isRateLimited(key, rateLimitMap, limit) + equal(result, true) +}) + +test('isRateLimited function - rate limit info reset', () => { + rateLimitMap.clear() + + let key = '127.0.0.1' + let limit = { DURATION: 60000, LIMIT: 2 } + + rateLimitMap.set(key, { + count: 1, + timestamp: Date.now() - limit.DURATION - 1000 + }) + + isRateLimited(key, rateLimitMap, limit) + + let rateLimitInfo = rateLimitMap.get(key) + equal(rateLimitInfo?.count, 0) + equal(rateLimitInfo?.timestamp, Date.now()) +}) + +test('processRequest function', async () => { + let mockReq = createMockRequest(targetUrl) + let mockRes = createMockResponse() + let parsedUrl = new URL(targetUrl) + + await processRequest(mockReq, mockRes, config, targetUrl) + equal(mockRes.statusCode, 200) +}) + +test('handleRequestWithDelay function', async () => { + let mockReq = createMockRequest(targetUrl.toString()) + let mockRes = createMockResponse() + let ip = '127.0.0.1' + let parsedUrl = new URL(targetUrl) + + await handleRequestWithDelay( + mockReq, + mockRes, + config, + ip, + targetUrl.toString(), + parsedUrl + ) + equal(mockRes.statusCode, 200) +}) + +test('checkRateLimit function domain limit', () => { + let ip = '127.0.0.1' + let domain = 'example.com' + + rateLimitMap.clear() + + let result = checkRateLimit(ip, domain) + equal(result, false) + + result = checkRateLimit(ip, domain) + equal(result, false) + + result = checkRateLimit(ip, domain) + equal(result, true) +}) + +test('checkRateLimit function global limit', () => { + let ip = '127.0.0.1' + let domain = 'another.com' + + rateLimitMap.clear() + + let result = checkRateLimit(ip, domain) + equal(result, false) + + for (let i = 0; i < 5000; i++) { + checkRateLimit(ip, domain) + } + + result = checkRateLimit(ip, domain) + equal(result, true) +}) + +test('handles invalid config', async () => { + let invalidProxy = createProxyServer({ + allowLocalhost: false, + allowsFrom: [], + maxSize: -1, + timeout: -1 + }) + invalidProxy.listen(31599) + let invalidProxyUrl = getURL(invalidProxy) + + try { + let response = await fetch(`${invalidProxyUrl}/${targetUrl}`) + equal(response.status, 400) + equal(await response.text(), 'Invalid URL') + } finally { + invalidProxy.close() + } +}) + +test('handleError function', () => { + let mockRes = createMockResponse() + + // Test with known TimeoutError + let timeoutError = new Error('TimeoutError') + timeoutError.name = 'TimeoutError' + handleError(timeoutError, mockRes) + equal(mockRes.statusCode, 400) + + // Test with custom BadRequestError + let badRequestError = new BadRequestError('Bad request', 400) + handleError(badRequestError, mockRes) + equal(mockRes.statusCode, 400) + + // Test with unknown or internal errors + let unknownError = new Error('Unknown error') + handleError(unknownError, mockRes) + equal(mockRes.statusCode, 500) }) From 116b39e05d2c260b043c5e016e0083cefcade185 Mon Sep 17 00:00:00 2001 From: Jane Fawkes Date: Mon, 10 Jun 2024 23:11:33 +0300 Subject: [PATCH 6/6] Fix tests --- proxy/proxy.ts | 6 ++- proxy/test/proxy.test.ts | 84 +++++++++++++++++++++++++++++----------- 2 files changed, 66 insertions(+), 24 deletions(-) diff --git a/proxy/proxy.ts b/proxy/proxy.ts index e66c4a33..c00678d1 100644 --- a/proxy/proxy.ts +++ b/proxy/proxy.ts @@ -57,7 +57,7 @@ export function isRateLimited( return true } - rateLimitInfo.count += 1 + rateLimitInfo.count++ store.set(key, rateLimitInfo) return false @@ -72,6 +72,10 @@ export function checkRateLimit(ip: string, domain: string): boolean { } export function handleError(e: unknown, res: ServerResponse): void { + if (res.headersSent) { + // Headers already sent, cannot handle error + return + } // Known errors if (e instanceof Error && e.name === 'TimeoutError') { res.writeHead(400, { 'Content-Type': 'text/plain' }) diff --git a/proxy/test/proxy.test.ts b/proxy/test/proxy.test.ts index 5615a0bb..38595246 100644 --- a/proxy/test/proxy.test.ts +++ b/proxy/test/proxy.test.ts @@ -87,7 +87,7 @@ after(() => { proxy.close() }) -function request(url: string, opts: RequestInit = {}): Promise { +async function request(url: string, opts: RequestInit = {}): Promise { return fetch(`${proxyUrl}/${url}`, { ...opts, headers: { @@ -118,15 +118,16 @@ function createMockRequest( } function createMockResponse(): ServerResponse { - let res = new ServerResponse(null as any) + let mockReq = new IncomingMessage(null as any) + let res = new ServerResponse(mockReq) ;(res as any).write = (chunk: any) => chunk ;(res as any).end = () => {} return res } -const config: ProxyConfig = { +let config: ProxyConfig = { allowLocalhost: true, - allowsFrom: [/^http:\/\/test.app/], + allowsFrom: [/^http:\/\/test\.app$/], maxSize: 100, timeout: 100 } @@ -139,12 +140,18 @@ test('works', async () => { }) test('has timeout', async () => { - let response = await request(`${targetUrl}?sleep=500`, {}) + let response = await request(`${targetUrl}?sleep=500`, { + headers: { + Origin: 'http://test.app' + } + }) await expectBadRequest(response, 'Timeout') }) test('transfers query params and path', async () => { - let response = await request(`${targetUrl}/foo/bar?foo=bar&bar=foo`) + let response = await request(`${targetUrl}/foo/bar?foo=bar&bar=foo`, { + headers: { Origin: 'http://test.app' } + }) let parsedResponse = await response.json() equal(response.status, 200) equal(parsedResponse?.response, 'ok') @@ -213,25 +220,38 @@ test('sends user IP to destination', async () => { let response1 = await request(targetUrl) equal(response1.status, 200) let json1 = await response1.json() - equal(json1.request.headers['x-forwarded-for'], '::1') + + equal(json1.headers['x-forwarded-for'], '::1') let response2 = await request(targetUrl, { headers: { 'X-Forwarded-For': '4.4.4.4' } }) equal(response2.status, 200) let json2 = await response2.json() - equal(json2.request.headers['x-forwarded-for'], '4.4.4.4, ::1') + equal(json2.headers['x-forwarded-for'], '4.4.4.4, ::1') }) test('checks response size', async () => { - let response1 = await request(targetUrl + '?big=file', {}) - equal(response1.status, 413) + let response1 = await request(`${targetUrl}?big=file`, {}) + equal( + response1.status, + 413, + `Expected status 413 but received ${response1.status}` + ) equal(await response1.text(), 'Response too large') - let response2 = await request(targetUrl + '?big=stream', {}) - equal(response2.status, 200) + let response2 = await request(`${targetUrl}?big=stream`, {}) + equal( + response2.status, + 200, + `Expected status 200 but received ${response2.status}` + ) let body2 = await response2.text() - equal(body2.length, 150) + equal( + body2.length, + 150, + `Expected body length 150 but received ${body2.length}` + ) }) test('is ready for errors', async () => { @@ -283,25 +303,43 @@ test('isRateLimited function - rate limit info reset', () => { let key = '127.0.0.1' let limit = { DURATION: 60000, LIMIT: 2 } - rateLimitMap.set(key, { - count: 1, - timestamp: Date.now() - limit.DURATION - 1000 - }) + let now = performance.now() + + rateLimitMap.set(key, { count: 1, timestamp: now - limit.DURATION - 1000 }) + + // Check rate limit status before increment + let isLimitedBefore = isRateLimited(key, rateLimitMap, limit) + let rateLimitInfoBefore = rateLimitMap.get(key) + + equal(isLimitedBefore, false) + equal(rateLimitInfoBefore?.count, 0) + + // Check rate limit status after increment + let isLimitedAfter = isRateLimited(key, rateLimitMap, limit) + let rateLimitInfoAfter = rateLimitMap.get(key) - isRateLimited(key, rateLimitMap, limit) + equal(isLimitedAfter, false) + equal(rateLimitInfoAfter?.count, 1) - let rateLimitInfo = rateLimitMap.get(key) - equal(rateLimitInfo?.count, 0) - equal(rateLimitInfo?.timestamp, Date.now()) + // Mocking time progression within the limit duration + performance.now = () => now + 1000 + let isLimitedWithinDuration = isRateLimited(key, rateLimitMap, limit) + equal(isLimitedWithinDuration, false) + equal(rateLimitMap.get(key)?.count, 2) + + // Exceeding the rate limit + let isLimitedExceed = isRateLimited(key, rateLimitMap, limit) + equal(isLimitedExceed, true) }) test('processRequest function', async () => { let mockReq = createMockRequest(targetUrl) let mockRes = createMockResponse() - let parsedUrl = new URL(targetUrl) + + mockReq.url = 'http://invalid-url' await processRequest(mockReq, mockRes, config, targetUrl) - equal(mockRes.statusCode, 200) + equal(mockRes.statusCode, 500) }) test('handleRequestWithDelay function', async () => {