From 27fb22b7c6b224aecc852915d9fee600d9d86efc Mon Sep 17 00:00:00 2001 From: MrBBot Date: Fri, 8 Mar 2024 14:03:30 +0000 Subject: [PATCH] fix: ensure redirects handled correctly with `dispatchFetch()` (#5191) --- .changeset/purple-lemons-yell.md | 11 ++ packages/miniflare/src/http/fetch.ts | 142 ++++++++++++++++-- packages/miniflare/src/http/server.ts | 6 +- packages/miniflare/src/index.ts | 32 ++-- packages/miniflare/src/plugins/core/index.ts | 3 +- .../miniflare/src/plugins/shared/constants.ts | 5 - .../miniflare/src/workers/core/constants.ts | 1 + packages/miniflare/test/index.spec.ts | 104 ++++++++++++- 8 files changed, 264 insertions(+), 40 deletions(-) create mode 100644 .changeset/purple-lemons-yell.md diff --git a/.changeset/purple-lemons-yell.md b/.changeset/purple-lemons-yell.md new file mode 100644 index 000000000000..effbabadfa15 --- /dev/null +++ b/.changeset/purple-lemons-yell.md @@ -0,0 +1,11 @@ +--- +"miniflare": patch +--- + +fix: ensure redirect responses handled correctly with `dispatchFetch()` + +Previously, if your Worker returned a redirect response, calling `dispatchFetch(url)` would send another request to the original `url` rather than the redirect. This change ensures redirects are followed correctly. + +- If your Worker returns a relative redirect or an absolute redirect with the same origin as the original `url`, the request will be sent to the Worker. +- If your Worker instead returns an absolute redirect with a different origin, the request will be sent to the Internet. +- If a redirected request to a different origin returns an absolute redirect with the same origin as the original `url`, the request will also be sent to the Worker. diff --git a/packages/miniflare/src/http/fetch.ts b/packages/miniflare/src/http/fetch.ts index 13523e0271aa..0ecce6743e04 100644 --- a/packages/miniflare/src/http/fetch.ts +++ b/packages/miniflare/src/http/fetch.ts @@ -2,20 +2,11 @@ import http from "http"; import { IncomingRequestCfProperties } from "@cloudflare/workers-types/experimental"; import { Dispatcher, Headers, fetch as baseFetch } from "undici"; import NodeWebSocket from "ws"; -import { DeferredPromise } from "../workers"; +import { CoreHeaders, DeferredPromise } from "../workers"; import { Request, RequestInfo, RequestInit } from "./request"; import { Response } from "./response"; import { WebSocketPair, coupleWebSocket } from "./websocket"; -// `Dispatcher`s don't expose whether they had `rejectUnauthorized` set when -// constructed, but we need to know whether to pass this when constructing -// WebSockets. Instead, we add all known `rejectUnauthorized` dispatchers to -// a weak map, and check that before constructing WebSockets. -const allowUnauthorizedDispatchers = new WeakSet(); -export function registerAllowUnauthorizedDispatcher(dispatcher: Dispatcher) { - allowUnauthorizedDispatchers.add(dispatcher); -} - const ignored = ["transfer-encoding", "connection", "keep-alive", "expect"]; function headersFromIncomingRequest(req: http.IncomingMessage): Headers { const entries = Object.entries(req.headers).filter( @@ -59,11 +50,11 @@ export async function fetch( } } - const rejectUnauthorized = - requestInit?.dispatcher !== undefined && - allowUnauthorizedDispatchers.has(requestInit?.dispatcher) - ? { rejectUnauthorized: false } - : {}; + let rejectUnauthorized: { rejectUnauthorized: false } | undefined; + if (requestInit.dispatcher instanceof DispatchFetchDispatcher) { + requestInit.dispatcher.addHeaders(headers, url.pathname + url.search); + rejectUnauthorized = { rejectUnauthorized: false }; + } // Establish web socket connection const ws = new NodeWebSocket(url, protocols, { @@ -106,3 +97,124 @@ export type DispatchFetch = ( input: RequestInfo, init?: RequestInit> ) => Promise; + +export type AnyHeaders = http.IncomingHttpHeaders | string[]; +function addHeader(/* mut */ headers: AnyHeaders, key: string, value: string) { + if (Array.isArray(headers)) headers.push(key, value); + else headers[key] = value; +} + +/** + * Dispatcher created for each `dispatchFetch()` call. Ensures request origin + * in Worker matches that passed to `dispatchFetch()`, not the address the + * `workerd` server is listening on. Handles cases where `fetch()` redirects to + * same origin and different external origins. + */ +export class DispatchFetchDispatcher extends Dispatcher { + private readonly cfBlobJson?: string; + + /** + * @param globalDispatcher Dispatcher to use for all non-runtime requests + * (rejects unauthorised certificates) + * @param runtimeDispatcher Dispatcher to use for runtime requests + * (permits unauthorised certificates) + * @param actualRuntimeOrigin Origin to send all runtime requests to + * @param userRuntimeOrigin Origin to treat as runtime request + * (initial URL passed by user to `dispatchFetch()`) + * @param cfBlob `request.cf` blob override for runtime requests + */ + constructor( + private readonly globalDispatcher: Dispatcher, + private readonly runtimeDispatcher: Dispatcher, + private readonly actualRuntimeOrigin: string, + private readonly userRuntimeOrigin: string, + cfBlob?: IncomingRequestCfProperties + ) { + super(); + if (cfBlob !== undefined) this.cfBlobJson = JSON.stringify(cfBlob); + } + + addHeaders( + /* mut */ headers: AnyHeaders, + path: string // Including query parameters + ) { + // Reconstruct URL using runtime origin specified with `dispatchFetch()` + const originalURL = this.userRuntimeOrigin + path; + addHeader(headers, CoreHeaders.ORIGINAL_URL, originalURL); + addHeader(headers, CoreHeaders.DISABLE_PRETTY_ERROR, "true"); + if (this.cfBlobJson !== undefined) { + // Only add this header if a `cf` override was set + addHeader(headers, CoreHeaders.CF_BLOB, this.cfBlobJson); + } + } + + dispatch( + /* mut */ options: Dispatcher.DispatchOptions, + handler: Dispatcher.DispatchHandlers + ): boolean { + let origin = String(options.origin); + // The first request in a redirect chain will always match the user origin + if (origin === this.userRuntimeOrigin) origin = this.actualRuntimeOrigin; + if (origin === this.actualRuntimeOrigin) { + // If this is now a request to the runtime, rewrite dispatching origin to + // the runtime's + options.origin = origin; + + let path = options.path; + if (options.query !== undefined) { + // `options.path` may include query parameters, so we need to parse it + const url = new URL(path, "http://placeholder/"); + for (const [key, value] of Object.entries(options.query)) { + url.searchParams.append(key, value); + } + path = url.pathname + url.search; + } + + // ...and add special Miniflare headers for runtime requests + options.headers ??= {}; + this.addHeaders(options.headers, path); + + // Dispatch with runtime dispatcher to avoid certificate errors if using + // self-signed certificate + return this.runtimeDispatcher.dispatch(options, handler); + } else { + // If this wasn't a request to the runtime (e.g. redirect to somewhere + // else), use the regular global dispatcher, without special headers + return this.globalDispatcher.dispatch(options, handler); + } + } + + close(): Promise; + close(callback: () => void): void; + async close(callback?: () => void): Promise { + await Promise.all([ + this.globalDispatcher.close(), + this.runtimeDispatcher.close(), + ]); + callback?.(); + } + + destroy(): Promise; + destroy(err: Error | null): Promise; + destroy(callback: () => void): void; + destroy(err: Error | null, callback: () => void): void; + async destroy( + errCallback?: Error | null | (() => void), + callback?: () => void + ): Promise { + let err: Error | null = null; + if (typeof errCallback === "function") callback = errCallback; + if (errCallback instanceof Error) err = errCallback; + + await Promise.all([ + this.globalDispatcher.destroy(err), + this.runtimeDispatcher.destroy(err), + ]); + callback?.(); + } + + get isMockActive(): boolean { + // @ts-expect-error missing type on `MockAgent`, but exists at runtime + return this.globalDispatcher.isMockActive ?? false; + } +} diff --git a/packages/miniflare/src/http/server.ts b/packages/miniflare/src/http/server.ts index 3f25857e7a85..1792a621d10f 100644 --- a/packages/miniflare/src/http/server.ts +++ b/packages/miniflare/src/http/server.ts @@ -1,14 +1,14 @@ import fs from "fs/promises"; import { z } from "zod"; -import { CORE_PLUGIN, HEADER_CF_BLOB } from "../plugins"; +import { CORE_PLUGIN } from "../plugins"; import { HttpOptions, Socket_Https } from "../runtime"; -import { Awaitable } from "../workers"; +import { Awaitable, CoreHeaders } from "../workers"; import { CERT, KEY } from "./cert"; export const ENTRY_SOCKET_HTTP_OPTIONS: HttpOptions = { // Even though we inject a `cf` object in the entry worker, allow it to // be customised via `dispatchFetch` - cfBlobHeader: HEADER_CF_BLOB, + cfBlobHeader: CoreHeaders.CF_BLOB, }; export async function getEntrySocketHttpOptions( diff --git a/packages/miniflare/src/index.ts b/packages/miniflare/src/index.ts index 6f863ec6f9e3..88f8f36df3ed 100644 --- a/packages/miniflare/src/index.ts +++ b/packages/miniflare/src/index.ts @@ -22,7 +22,7 @@ import type { import exitHook from "exit-hook"; import { $ as colors$ } from "kleur/colors"; import stoppable from "stoppable"; -import { Dispatcher, Pool } from "undici"; +import { Dispatcher, Pool, getGlobalDispatcher } from "undici"; import SCRIPT_MINIFLARE_SHARED from "worker:shared/index"; import SCRIPT_MINIFLARE_ZOD from "worker:shared/zod"; import { WebSocketServer } from "ws"; @@ -30,6 +30,7 @@ import { z } from "zod"; import { fallbackCf, setupCf } from "./cf"; import { DispatchFetch, + DispatchFetchDispatcher, ENTRY_SOCKET_HTTP_OPTIONS, Headers, Request, @@ -39,13 +40,11 @@ import { fetch, getAccessibleHosts, getEntrySocketHttpOptions, - registerAllowUnauthorizedDispatcher, } from "./http"; import { D1_PLUGIN_NAME, DURABLE_OBJECTS_PLUGIN_NAME, DurableObjectClassNames, - HEADER_CF_BLOB, KV_PLUGIN_NAME, PLUGIN_ENTRIES, PluginServicesOptions, @@ -816,8 +815,8 @@ export class Miniflare { } // Extract cf blob (if any) from headers - const cfBlob = headers.get(HEADER_CF_BLOB); - headers.delete(HEADER_CF_BLOB); + const cfBlob = headers.get(CoreHeaders.CF_BLOB); + headers.delete(CoreHeaders.CF_BLOB); assert(!Array.isArray(cfBlob)); // Only `Set-Cookie` headers are arrays const cf = cfBlob ? JSON.parse(cfBlob) : undefined; @@ -1336,7 +1335,6 @@ export class Miniflare { this.#runtimeDispatcher = new Pool(this.#runtimeEntryURL, { connect: { rejectUnauthorized: false }, }); - registerAllowUnauthorizedDispatcher(this.#runtimeDispatcher); } if (this.#proxyClient === undefined) { this.#proxyClient = new ProxyClient( @@ -1508,14 +1506,13 @@ export class Miniflare { const forward = new Request(input, init); const url = new URL(forward.url); - forward.headers.set(CoreHeaders.ORIGINAL_URL, url.toString()); - forward.headers.set(CoreHeaders.DISABLE_PRETTY_ERROR, "true"); + const actualRuntimeOrigin = this.#runtimeEntryURL.origin; + const userRuntimeOrigin = url.origin; + + // Rewrite URL for WebSocket requests which won't use `DispatchFetchDispatcher` url.protocol = this.#runtimeEntryURL.protocol; url.host = this.#runtimeEntryURL.host; - if (forward.cf) { - const cf = { ...fallbackCf, ...forward.cf }; - forward.headers.set(HEADER_CF_BLOB, JSON.stringify(cf)); - } + // Remove `Content-Length: 0` headers from requests when a body is set to // avoid `RequestContentLengthMismatch` errors if ( @@ -1525,8 +1522,17 @@ export class Miniflare { forward.headers.delete("Content-Length"); } + const cfBlob = forward.cf ? { ...fallbackCf, ...forward.cf } : undefined; + const dispatcher = new DispatchFetchDispatcher( + getGlobalDispatcher(), + this.#runtimeDispatcher, + actualRuntimeOrigin, + userRuntimeOrigin, + cfBlob + ); + const forwardInit = forward as RequestInit; - forwardInit.dispatcher = this.#runtimeDispatcher; + forwardInit.dispatcher = dispatcher; const response = await fetch(url, forwardInit); // If the Worker threw an uncaught exception, propagate it to the caller diff --git a/packages/miniflare/src/plugins/core/index.ts b/packages/miniflare/src/plugins/core/index.ts index c3dd0bb664d6..786db72377a1 100644 --- a/packages/miniflare/src/plugins/core/index.ts +++ b/packages/miniflare/src/plugins/core/index.ts @@ -36,7 +36,6 @@ import { import { getCacheServiceName } from "../cache"; import { DURABLE_OBJECTS_STORAGE_SERVICE_NAME } from "../do"; import { - HEADER_CF_BLOB, Plugin, SERVICE_LOOPBACK, WORKER_BINDING_SERVICE_LOOPBACK, @@ -716,7 +715,7 @@ export function getGlobalServices({ return [ { name: SERVICE_LOOPBACK, - external: { http: { cfBlobHeader: HEADER_CF_BLOB } }, + external: { http: { cfBlobHeader: CoreHeaders.CF_BLOB } }, }, { name: SERVICE_ENTRY, diff --git a/packages/miniflare/src/plugins/shared/constants.ts b/packages/miniflare/src/plugins/shared/constants.ts index 13a9871a142b..69768bef2219 100644 --- a/packages/miniflare/src/plugins/shared/constants.ts +++ b/packages/miniflare/src/plugins/shared/constants.ts @@ -17,11 +17,6 @@ export function getDirectSocketName(workerIndex: number) { // Service looping back to Miniflare's Node.js process (for storage, etc) export const SERVICE_LOOPBACK = "loopback"; -// Even though we inject the `cf` blob in the entry script, we still need to -// specify a header, so we receive things like `cf.cacheKey` in loopback -// requests. -export const HEADER_CF_BLOB = "MF-CF-Blob"; - export const WORKER_BINDING_SERVICE_LOOPBACK: Worker_Binding = { name: CoreBindings.SERVICE_LOOPBACK, service: { name: SERVICE_LOOPBACK }, diff --git a/packages/miniflare/src/workers/core/constants.ts b/packages/miniflare/src/workers/core/constants.ts index 2f4f6a332b63..aaf9829b9a0c 100644 --- a/packages/miniflare/src/workers/core/constants.ts +++ b/packages/miniflare/src/workers/core/constants.ts @@ -5,6 +5,7 @@ export const CoreHeaders = { DISABLE_PRETTY_ERROR: "MF-Disable-Pretty-Error", ERROR_STACK: "MF-Experimental-Error-Stack", ROUTE_OVERRIDE: "MF-Route-Override", + CF_BLOB: "MF-CF-Blob", // API Proxy OP_SECRET: "MF-Op-Secret", diff --git a/packages/miniflare/test/index.spec.ts b/packages/miniflare/test/index.spec.ts index 6b585bb42720..afacbf83494e 100644 --- a/packages/miniflare/test/index.spec.ts +++ b/packages/miniflare/test/index.spec.ts @@ -468,10 +468,10 @@ test("Miniflare: custom service binding to another Miniflare instance", async (t // Checking URL (including protocol/host) and body preserved through // `dispatchFetch()` and custom service bindings - let res = await mf.dispatchFetch("https://custom1.mf/a"); + let res = await mf.dispatchFetch("https://custom1.mf/a?key=value"); t.deepEqual(await res.json(), { method: "GET", - url: "https://custom1.mf/a", + url: "https://custom1.mf/a?key=value", body: null, }); @@ -596,6 +596,106 @@ test("Miniflare: can send GET request with body", async (t) => { }); }); +test("Miniflare: handles redirect responses", async (t) => { + // https://github.com/cloudflare/workers-sdk/issues/5018 + + const { http } = await useServer(t, (req, res) => { + // Check no special headers set + const headerKeys = Object.keys(req.headers); + t.deepEqual( + headerKeys.filter((key) => key.toLowerCase().startsWith("mf-")), + [] + ); + + const { pathname } = new URL(req.url ?? "", "http://placeholder"); + if (pathname === "/ping") { + res.end("pong"); + } else if (pathname === "/redirect-back") { + res.writeHead(302, { Location: "https://custom.mf/external-redirected" }); + res.end(); + } else { + res.writeHead(404); + res.end("Not Found"); + } + }); + + const mf = new Miniflare({ + bindings: { EXTERNAL_URL: http.href }, + compatibilityDate: "2024-01-01", + modules: true, + script: `export default { + async fetch(request, env) { + const url = new URL(request.url); + const externalUrl = new URL(env.EXTERNAL_URL); + if (url.pathname === "/redirect-relative") { + return new Response(null, { status: 302, headers: { Location: "/relative-redirected" } }); + } else if (url.pathname === "/redirect-absolute") { + url.pathname = "/absolute-redirected"; + return Response.redirect(url, 302); + } else if (url.pathname === "/redirect-external") { + externalUrl.pathname = "/ping"; + return Response.redirect(externalUrl, 302); + } else if (url.pathname === "/redirect-external-and-back") { + externalUrl.pathname = "/redirect-back"; + return Response.redirect(externalUrl, 302); + } else { + return new Response("end:" + url.href); + } + } + }`, + }); + t.teardown(() => mf.dispose()); + + // Check relative redirect + let res = await mf.dispatchFetch("https://custom.mf/redirect-relative", { + redirect: "manual", + }); + t.is(res.status, 302); + t.is(res.headers.get("Location"), "/relative-redirected"); + await res.arrayBuffer(); // (drain) + + res = await mf.dispatchFetch("https://custom.mf/redirect-relative"); + t.is(res.status, 200); + t.is(await res.text(), "end:https://custom.mf/relative-redirected"); + + // Check absolute redirect to same origin + res = await mf.dispatchFetch("https://custom.mf/redirect-absolute", { + redirect: "manual", + }); + t.is(res.status, 302); + t.is(res.headers.get("Location"), "https://custom.mf/absolute-redirected"); + await res.arrayBuffer(); // (drain) + + res = await mf.dispatchFetch("https://custom.mf/redirect-absolute"); + t.is(res.status, 200); + t.is(await res.text(), "end:https://custom.mf/absolute-redirected"); + + // Check absolute redirect to external origin + res = await mf.dispatchFetch("https://custom.mf/redirect-external", { + redirect: "manual", + }); + t.is(res.status, 302); + t.is(res.headers.get("Location"), new URL("/ping", http).href); + await res.arrayBuffer(); // (drain) + + res = await mf.dispatchFetch("https://custom.mf/redirect-external"); + t.is(res.status, 200); + t.is(await res.text(), "pong"); + + // Check absolute redirect to external origin, then redirect back to initial + res = await mf.dispatchFetch("https://custom.mf/redirect-external-and-back", { + redirect: "manual", + }); + t.is(res.status, 302); + t.is(res.headers.get("Location"), new URL("/redirect-back", http).href); + await res.arrayBuffer(); // (drain) + + res = await mf.dispatchFetch("https://custom.mf/redirect-external-and-back"); + t.is(res.status, 200); + // External server redirects back to worker running in `workerd` + t.is(await res.text(), "end:https://custom.mf/external-redirected"); +}); + test("Miniflare: fetch mocking", async (t) => { const fetchMock = createFetchMock(); fetchMock.disableNetConnect();