diff --git a/.changeset/mean-jars-count.md b/.changeset/mean-jars-count.md new file mode 100644 index 000000000000..e0d567e5195e --- /dev/null +++ b/.changeset/mean-jars-count.md @@ -0,0 +1,5 @@ +--- +"wrangler": minor +--- + +Add support for Workers AI in local mode diff --git a/fixtures/ai-app/.gitignore b/fixtures/ai-app/.gitignore new file mode 100644 index 000000000000..1521c8b7652b --- /dev/null +++ b/fixtures/ai-app/.gitignore @@ -0,0 +1 @@ +dist diff --git a/fixtures/ai-app/package.json b/fixtures/ai-app/package.json new file mode 100644 index 000000000000..c2b026ebdff2 --- /dev/null +++ b/fixtures/ai-app/package.json @@ -0,0 +1,21 @@ +{ + "name": "ai-app", + "version": "1.0.1", + "private": true, + "description": "", + "license": "ISC", + "author": "", + "main": "src/index.js", + "scripts": { + "check:type": "tsc", + "test": "vitest run", + "test:watch": "vitest", + "type:tests": "tsc -p ./tests/tsconfig.json" + }, + "devDependencies": { + "undici": "^5.23.0", + "wrangler": "workspace:*", + "@cloudflare/workers-tsconfig": "workspace:^", + "@cloudflare/ai": "^1.0.35" + } +} diff --git a/fixtures/ai-app/src/index.js b/fixtures/ai-app/src/index.js new file mode 100644 index 000000000000..269f4491906f --- /dev/null +++ b/fixtures/ai-app/src/index.js @@ -0,0 +1,12 @@ +console.log("startup log"); + +export default { + async fetch(request, env) { + console.log("request log"); + + return Response.json({ + binding: env.AI, + fetcher: env.AI.fetch.toString(), + }); + }, +}; diff --git a/fixtures/ai-app/tests/index.test.ts b/fixtures/ai-app/tests/index.test.ts new file mode 100644 index 000000000000..472bd2c2790f --- /dev/null +++ b/fixtures/ai-app/tests/index.test.ts @@ -0,0 +1,31 @@ +import { resolve } from "path"; +import { fetch } from "undici"; +import { describe, it, beforeAll, afterAll } from "vitest"; +import { runWranglerDev } from "../../shared/src/run-wrangler-long-lived"; + +describe("'wrangler dev' correctly renders pages", () => { + let ip: string, + port: number, + stop: (() => Promise) | undefined, + getOutput: () => string; + + beforeAll(async () => { + ({ ip, port, stop, getOutput } = await runWranglerDev( + resolve(__dirname, ".."), + ["--local", "--port=0", "--inspector-port=0"] + )); + }); + + afterAll(async () => { + await stop?.(); + }); + + it("ai binding is defined ", async ({ expect }) => { + const response = await fetch(`http://${ip}:${port}/`); + const content = await response.json(); + expect(content).toEqual({ + binding: {}, + fetcher: "function fetch() { [native code] }", + }); + }); +}); diff --git a/fixtures/ai-app/tests/tsconfig.json b/fixtures/ai-app/tests/tsconfig.json new file mode 100644 index 000000000000..d2ce7f144694 --- /dev/null +++ b/fixtures/ai-app/tests/tsconfig.json @@ -0,0 +1,7 @@ +{ + "extends": "@cloudflare/workers-tsconfig/tsconfig.json", + "compilerOptions": { + "types": ["node"] + }, + "include": ["**/*.ts", "../../../node-types.d.ts"] +} diff --git a/fixtures/ai-app/tsconfig.json b/fixtures/ai-app/tsconfig.json new file mode 100644 index 000000000000..b901134e4e79 --- /dev/null +++ b/fixtures/ai-app/tsconfig.json @@ -0,0 +1,13 @@ +{ + "compilerOptions": { + "target": "ES2020", + "esModuleInterop": true, + "module": "CommonJS", + "lib": ["ES2020"], + "types": ["node"], + "skipLibCheck": true, + "moduleResolution": "node", + "noEmit": true + }, + "include": ["tests", "../../node-types.d.ts"] +} diff --git a/fixtures/ai-app/vitest.config.ts b/fixtures/ai-app/vitest.config.ts new file mode 100644 index 000000000000..ed02453b1c87 --- /dev/null +++ b/fixtures/ai-app/vitest.config.ts @@ -0,0 +1,10 @@ +import { defineConfig } from "vitest/config"; + +export default defineConfig({ + test: { + testTimeout: 10_000, + hookTimeout: 10_000, + teardownTimeout: 10_000, + useAtomics: true, + }, +}); diff --git a/fixtures/ai-app/wrangler.toml b/fixtures/ai-app/wrangler.toml new file mode 100644 index 000000000000..841585819660 --- /dev/null +++ b/fixtures/ai-app/wrangler.toml @@ -0,0 +1,7 @@ +name = "ai-app" +compatibility_date = "2023-11-21" + +main = "src/index.js" + +[ai] +binding = "AI" diff --git a/packages/wrangler/package.json b/packages/wrangler/package.json index 22fdd8ee25f6..d5d275a26fdf 100644 --- a/packages/wrangler/package.json +++ b/packages/wrangler/package.json @@ -117,6 +117,7 @@ "xxhash-wasm": "^1.0.1" }, "devDependencies": { + "@cloudflare/ai": "^1.0.35", "@cloudflare/cli": "workspace:*", "@cloudflare/eslint-config-worker": "*", "@cloudflare/pages-shared": "workspace:^", diff --git a/packages/wrangler/src/ai/fetcher.ts b/packages/wrangler/src/ai/fetcher.ts new file mode 100644 index 000000000000..85c2ce25ebf2 --- /dev/null +++ b/packages/wrangler/src/ai/fetcher.ts @@ -0,0 +1,20 @@ +import { Response } from "miniflare"; +import { performApiFetch } from "../cfetch/internal"; +import { getAccountId } from "../user"; +import type { Request } from "miniflare"; + +export async function AIFetcher(request: Request) { + const accountId = await getAccountId(); + + request.headers.delete("Host"); + request.headers.delete("Content-Length"); + + const res = await performApiFetch(`/accounts/${accountId}/ai/run/proxy`, { + method: "POST", + headers: Object.fromEntries(request.headers.entries()), + body: request.body, + duplex: "half", + }); + + return new Response(res.body, { status: res.status }); +} diff --git a/packages/wrangler/src/api/dev.ts b/packages/wrangler/src/api/dev.ts index 2d55dc595a84..16d4a0a35492 100644 --- a/packages/wrangler/src/api/dev.ts +++ b/packages/wrangler/src/api/dev.ts @@ -50,6 +50,9 @@ export interface UnstableDevOptions { bucket_name: string; preview_bucket_name?: string; }[]; + ai?: { + binding: string; + }; moduleRoot?: string; rules?: Rule[]; logLevel?: "none" | "info" | "error" | "log" | "warn" | "debug"; // Specify logging level [choices: "debug", "info", "log", "warn", "error", "none"] [default: "log"] diff --git a/packages/wrangler/src/dev.tsx b/packages/wrangler/src/dev.tsx index b64136ad4dca..5b6b0cfc7631 100644 --- a/packages/wrangler/src/dev.tsx +++ b/packages/wrangler/src/dev.tsx @@ -326,6 +326,9 @@ export type AdditionalDevProps = { preview_bucket_name?: string; jurisdiction?: string; }[]; + ai?: { + binding: string; + }; d1Databases?: Environment["d1_databases"]; processEntrypoint?: boolean; additionalModules?: CfModule[]; @@ -832,6 +835,7 @@ function getBindingsAndAssetPaths(args: StartDevOptions, configParam: Config) { r2: args.r2, services: args.services, d1Databases: args.d1Databases, + ai: args.ai, }); const maskedVars = maskVars(bindings, configParam); @@ -893,7 +897,7 @@ function getBindings( wasm_modules: configParam.wasm_modules, text_blobs: configParam.text_blobs, browser: configParam.browser, - ai: configParam.ai, + ai: configParam.ai || args.ai, data_blobs: configParam.data_blobs, durable_objects: { bindings: [ diff --git a/packages/wrangler/src/dev/miniflare.ts b/packages/wrangler/src/dev/miniflare.ts index 8924ed51e6e6..6ae71e549a8b 100644 --- a/packages/wrangler/src/dev/miniflare.ts +++ b/packages/wrangler/src/dev/miniflare.ts @@ -2,6 +2,7 @@ import assert from "node:assert"; import { realpathSync } from "node:fs"; import path from "node:path"; import { Log, LogLevel, TypedEventTarget, Mutex, Miniflare } from "miniflare"; +import { AIFetcher } from "../ai/fetcher"; import { ModuleTypeToRuleType } from "../deployment-bundle/module-collection"; import { withSourceURLs } from "../deployment-bundle/source-url"; import { getHttpsOptions } from "../https-options"; @@ -312,6 +313,10 @@ function buildBindingOptions(config: ConfigBundle) { .join("\n"), }; + if (bindings.ai?.binding) { + config.serviceBindings[bindings.ai.binding] = AIFetcher; + } + const bindingOptions = { bindings: bindings.vars, textBlobBindings, @@ -502,13 +507,7 @@ async function buildMiniflareOptions( logger.warn("Miniflare 3 does not support CRON triggers yet, ignoring..."); } - if (config.bindings.ai) { - logger.warn( - "Workers AI is not currently supported in local mode. Please use --remote to work with it." - ); - } - - if (!config.bindings.ai && config.bindings.vectorize?.length) { + if (config.bindings.vectorize?.length) { // TODO: add local support for Vectorize bindings (https://github.com/cloudflare/workers-sdk/issues/4360) logger.warn( "Vectorize bindings are not currently supported in local mode. Please use --remote if you are working with them." diff --git a/packages/wrangler/src/pages/dev.ts b/packages/wrangler/src/pages/dev.ts index 326cb1672109..757e3424340d 100644 --- a/packages/wrangler/src/pages/dev.ts +++ b/packages/wrangler/src/pages/dev.ts @@ -155,6 +155,10 @@ export function Options(yargs: CommonYargsArgv) { type: "array", description: "R2 bucket to bind (--r2 R2_BINDING)", }, + ai: { + type: "string", + description: "AI to bind (--ai AI_BINDING)", + }, service: { type: "array", description: "Service to bind (--service SERVICE=SCRIPT_NAME)", @@ -215,6 +219,7 @@ export const Handler = async ({ do: durableObjects = [], d1: d1s = [], r2: r2s = [], + ai, service: requestedServices = [], liveReload, localProtocol, @@ -670,6 +675,7 @@ export const Handler = async ({ return { binding, bucket_name: ref || binding.toString() }; }) .filter(Boolean) as AdditionalDevProps["r2"], + ai: ai ? { binding: ai.toString() } : undefined, rules: usingWorkerDirectory ? [ { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3779289b6998..ff779db4046a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -100,6 +100,21 @@ importers: specifier: workspace:* version: link:../../packages/wrangler + fixtures/ai-app: + devDependencies: + '@cloudflare/ai': + specifier: ^1.0.35 + version: 1.0.35 + '@cloudflare/workers-tsconfig': + specifier: workspace:^ + version: link:../../packages/workers-tsconfig + undici: + specifier: ^5.23.0 + version: 5.23.0 + wrangler: + specifier: workspace:* + version: link:../../packages/wrangler + fixtures/d1-worker-app: devDependencies: wrangler: @@ -1263,6 +1278,9 @@ importers: specifier: ~2.3.2 version: 2.3.2 devDependencies: + '@cloudflare/ai': + specifier: ^1.0.35 + version: 1.0.35 '@cloudflare/cli': specifier: workspace:* version: link:../cli @@ -3273,6 +3291,10 @@ packages: bundledDependencies: - is-unicode-supported + /@cloudflare/ai@1.0.35: + resolution: {integrity: sha512-lqH62H3vwWxH3ZT8BD5N2Y+a2P9rwWiqtJsJsHFjv5p58aXxk3DJuQpFO1YIqdYUQkJr7QQwKa2HXOOJrjvqhQ==} + dev: true + /@cloudflare/cloudflare-brand-assets@4.7.7: resolution: {integrity: sha512-L4EqYWkvse0265YIzraYUlkvvjW7Cr5LELiAiBStctENBoqRhMuR4Ff2X4g/r1BwsQ7S0LTECwLLbM6BNzhv3g==} dev: true