Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add request rate limit to proxy #227

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion proxy/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
},
"scripts": {
"start": "tsx watch index.ts",
"test": "FORCE_COLOR=1 pnpm run /^test:/",
Copy link
Author

@janefawkes janefawkes Jun 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed it for the dev purposes, gonna add this line back in the next commit.

"test": "pnpm run /^test:/",
"build": "esbuild index.ts --bundle --platform=node --sourcemap --format=esm --outfile=dist/index.mjs",
"production": "node --run build && ./scripts/run-image.sh",
"test:proxy-coverage": "c8 bnt",
Expand Down
45 changes: 21 additions & 24 deletions proxy/proxy.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import type { IncomingMessage, Server, ServerResponse } from 'node:http'
import { createServer } from 'node:http'
import { isIP } from 'node:net'
import { styleText } from 'node:util'
import { setTimeout } from 'node:timers/promises'
import { styleText } from 'node:util'

class BadRequestError extends Error {
export class BadRequestError extends Error {
code: number

constructor(message: string, code = 400) {
Expand All @@ -14,36 +14,36 @@ class BadRequestError extends Error {
}
}

interface RateLimitInfo {
export interface RateLimitInfo {
count: number
timestamp: number
}

interface ProxyConfig {
export interface ProxyConfig {
allowLocalhost?: boolean
allowsFrom: RegExp[]
maxSize: number
timeout: number
}

const RATE_LIMIT = {
PER_DOMAIN: {
LIMIT: 500,
DURATION: 60 * 1000
},
GLOBAL: {
LIMIT: 5000,
DURATION: 60 * 1000
DURATION: 60 * 1000,
LIMIT: 5000
},
PER_DOMAIN: {
DURATION: 60 * 1000,
LIMIT: 500
}
}

let rateLimitMap: Map<string, RateLimitInfo> = new Map()
export let rateLimitMap: Map<string, RateLimitInfo> = new Map()
let requestQueue: Map<string, Promise<void>> = new Map()

function isRateLimited(
export function isRateLimited(
key: string,
store: Map<string, RateLimitInfo>,
limit: { LIMIT: number; DURATION: number }
limit: { DURATION: number; LIMIT: number }
): boolean {
let now = performance.now()
let rateLimitInfo = store.get(key) || { count: 0, timestamp: now }
Expand All @@ -63,15 +63,15 @@ function isRateLimited(
return false
}

function checkRateLimit(ip: string, domain: string): boolean {
export function checkRateLimit(ip: string, domain: string): boolean {
return ['domain', 'global'].some(type => {
let key = type === 'domain' ? `${ip}:${domain}` : ip
let limit = type === 'domain' ? RATE_LIMIT.PER_DOMAIN : RATE_LIMIT.GLOBAL
return isRateLimited(key, rateLimitMap, limit)
})
}

function handleError(e: unknown, res: ServerResponse): void {
export function handleError(e: unknown, res: ServerResponse): void {
// Known errors
if (e instanceof Error && e.name === 'TimeoutError') {
res.writeHead(400, { 'Content-Type': 'text/plain' })
Expand All @@ -92,12 +92,11 @@ function handleError(e: unknown, res: ServerResponse): void {
}
}

let processRequest = async (
export let processRequest = async (
req: IncomingMessage,
res: ServerResponse,
config: ProxyConfig,
url: string,
parsedUrl: URL
url: string
): Promise<void> => {
try {
// Remove all cookie headers so they will not be set on proxy domain
Expand Down Expand Up @@ -151,7 +150,7 @@ let processRequest = async (
}
}

let handleRequestWithDelay = async (
export let handleRequestWithDelay = async (
req: IncomingMessage,
res: ServerResponse,
config: ProxyConfig,
Expand All @@ -160,19 +159,17 @@ let handleRequestWithDelay = async (
parsedUrl: URL
): Promise<void> => {
if (checkRateLimit(ip, parsedUrl.hostname)) {
const existingQueue = requestQueue.get(ip) || Promise.resolve()
const delayedRequest = existingQueue.then(() => setTimeout(1000))
let existingQueue = requestQueue.get(ip) || Promise.resolve()
let delayedRequest = existingQueue.then(() => setTimeout(1000))
requestQueue.set(ip, delayedRequest)
await delayedRequest
}

await processRequest(req, res, config, url, parsedUrl)
await processRequest(req, res, config, url)
}

export function createProxyServer(config: ProxyConfig): Server {
return createServer(async (req: IncomingMessage, res: ServerResponse) => {
let sent = false

try {
let ip = req.socket.remoteAddress!
let url = decodeURIComponent((req.url ?? '').slice(1))
Expand Down
202 changes: 198 additions & 4 deletions proxy/test/proxy.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
import { equal } from 'node:assert'
import { createServer, type Server } from 'node:http'
import {
createServer,
IncomingMessage,
type Server,
ServerResponse
} from 'node:http'
import type { AddressInfo } from 'node:net'
import { after, test } from 'node:test'
import { setTimeout } from 'node:timers/promises'
import { URL } from 'node:url'

import { createProxyServer } from '../proxy.js'
import {
BadRequestError,
checkRateLimit,
createProxyServer,
handleError,
handleRequestWithDelay,
isRateLimited,
processRequest,
rateLimitMap
} from '../proxy.js'
import type { ProxyConfig } from '../proxy.js'

function getURL(server: Server): string {
let port = (server.address() as AddressInfo).port
Expand Down Expand Up @@ -90,6 +105,32 @@ async function expectBadRequest(
equal(await response.text(), message)
}

function createMockRequest(
url: string,
method = 'GET',
headers: Record<string, string> = {}
): IncomingMessage {
let req = new IncomingMessage(null as any)
req.url = url
req.method = method
req.headers = headers
return req
}

function createMockResponse(): ServerResponse {
let res = new ServerResponse(null as any)
;(res as any).write = (chunk: any) => chunk
;(res as any).end = () => {}
return res
}

const config: ProxyConfig = {
allowLocalhost: true,
allowsFrom: [/^http:\/\/test.app/],
maxSize: 100,
timeout: 100
}

test('works', async () => {
let response = await request(targetUrl)
equal(response.status, 200)
Expand Down Expand Up @@ -136,7 +177,7 @@ test('can use only HTTP or HTTPS protocols', async () => {
await expectBadRequest(response, 'Only HTTP or HTTPS are supported')
})

test('can not use proxy to query local address', async () => {
test('cannot use proxy to query local address', async () => {
let response = await request(targetUrl.replace('localhost', '127.0.0.1'))
await expectBadRequest(response, 'IP addresses are not allowed')
})
Expand Down Expand Up @@ -196,5 +237,158 @@ test('checks response size', async () => {
test('is ready for errors', async () => {
let response1 = await request(targetUrl + '?error=1', {})
equal(response1.status, 500)
equal(await response1.text(), 'Error')
equal(await response1.text(), 'Internal Server Error')
})

test('rate limits per domain', async () => {
for (let i = 0; i < 500; i++) {
let response = await request(targetUrl)
equal(response.status, 200)
}

let response = await request(targetUrl)
equal(response.status, 200)
})

test('rate limits globally', async () => {
for (let i = 0; i < 5000; i++) {
let response = await request(targetUrl)
equal(response.status, 200)
}

let response = await request(targetUrl)
equal(response.status, 200)
})

test('isRateLimited function', () => {
let limit = { DURATION: 60000, LIMIT: 2 }
let key = 'test-key'

// First request should not be rate limited
let result = isRateLimited(key, rateLimitMap, limit)
equal(result, false)

// Second request should not be rate limited
result = isRateLimited(key, rateLimitMap, limit)
equal(result, false)

// Third request should be rate limited
result = isRateLimited(key, rateLimitMap, limit)
equal(result, true)
})

test('isRateLimited function - rate limit info reset', () => {
rateLimitMap.clear()

let key = '127.0.0.1'
let limit = { DURATION: 60000, LIMIT: 2 }

rateLimitMap.set(key, {
count: 1,
timestamp: Date.now() - limit.DURATION - 1000
})

isRateLimited(key, rateLimitMap, limit)

let rateLimitInfo = rateLimitMap.get(key)
equal(rateLimitInfo?.count, 0)
equal(rateLimitInfo?.timestamp, Date.now())
})

test('processRequest function', async () => {
let mockReq = createMockRequest(targetUrl)
let mockRes = createMockResponse()
let parsedUrl = new URL(targetUrl)

await processRequest(mockReq, mockRes, config, targetUrl)
equal(mockRes.statusCode, 200)
})

test('handleRequestWithDelay function', async () => {
let mockReq = createMockRequest(targetUrl.toString())
let mockRes = createMockResponse()
let ip = '127.0.0.1'
let parsedUrl = new URL(targetUrl)

await handleRequestWithDelay(
mockReq,
mockRes,
config,
ip,
targetUrl.toString(),
parsedUrl
)
equal(mockRes.statusCode, 200)
})

test('checkRateLimit function domain limit', () => {
let ip = '127.0.0.1'
let domain = 'example.com'

rateLimitMap.clear()

let result = checkRateLimit(ip, domain)
equal(result, false)

result = checkRateLimit(ip, domain)
equal(result, false)

result = checkRateLimit(ip, domain)
equal(result, true)
})

test('checkRateLimit function global limit', () => {
let ip = '127.0.0.1'
let domain = 'another.com'

rateLimitMap.clear()

let result = checkRateLimit(ip, domain)
equal(result, false)

for (let i = 0; i < 5000; i++) {
checkRateLimit(ip, domain)
}

result = checkRateLimit(ip, domain)
equal(result, true)
})

test('handles invalid config', async () => {
let invalidProxy = createProxyServer({
allowLocalhost: false,
allowsFrom: [],
maxSize: -1,
timeout: -1
})
invalidProxy.listen(31599)
let invalidProxyUrl = getURL(invalidProxy)

try {
let response = await fetch(`${invalidProxyUrl}/${targetUrl}`)
equal(response.status, 400)
equal(await response.text(), 'Invalid URL')
} finally {
invalidProxy.close()
}
})

test('handleError function', () => {
let mockRes = createMockResponse()

// Test with known TimeoutError
let timeoutError = new Error('TimeoutError')
timeoutError.name = 'TimeoutError'
handleError(timeoutError, mockRes)
equal(mockRes.statusCode, 400)

// Test with custom BadRequestError
let badRequestError = new BadRequestError('Bad request', 400)
handleError(badRequestError, mockRes)
equal(mockRes.statusCode, 400)

// Test with unknown or internal errors
let unknownError = new Error('Unknown error')
handleError(unknownError, mockRes)
equal(mockRes.statusCode, 500)
})
Loading