Skip to content

Commit

Permalink
Add support for Workers AI in local mode (#4522)
Browse files Browse the repository at this point in the history
* Add support for Workers AI in local mode

* trigger build

* Add support for AI binding in pages

* Lint code

* Fix pages binding
  • Loading branch information
G4brym authored Dec 6, 2023
1 parent 5e67ea1 commit c10bf0f
Show file tree
Hide file tree
Showing 16 changed files with 170 additions and 8 deletions.
5 changes: 5 additions & 0 deletions .changeset/mean-jars-count.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"wrangler": minor
---

Add support for Workers AI in local mode
1 change: 1 addition & 0 deletions fixtures/ai-app/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dist
21 changes: 21 additions & 0 deletions fixtures/ai-app/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"name": "ai-app",
"version": "1.0.1",
"private": true,
"description": "",
"license": "ISC",
"author": "",
"main": "src/index.js",
"scripts": {
"check:type": "tsc",
"test": "vitest run",
"test:watch": "vitest",
"type:tests": "tsc -p ./tests/tsconfig.json"
},
"devDependencies": {
"undici": "^5.23.0",
"wrangler": "workspace:*",
"@cloudflare/workers-tsconfig": "workspace:^",
"@cloudflare/ai": "^1.0.35"
}
}
12 changes: 12 additions & 0 deletions fixtures/ai-app/src/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
console.log("startup log");

export default {
async fetch(request, env) {
console.log("request log");

return Response.json({
binding: env.AI,
fetcher: env.AI.fetch.toString(),
});
},
};
31 changes: 31 additions & 0 deletions fixtures/ai-app/tests/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { resolve } from "path";
import { fetch } from "undici";
import { describe, it, beforeAll, afterAll } from "vitest";
import { runWranglerDev } from "../../shared/src/run-wrangler-long-lived";

describe("'wrangler dev' correctly renders pages", () => {
let ip: string,
port: number,
stop: (() => Promise<unknown>) | undefined,
getOutput: () => string;

beforeAll(async () => {
({ ip, port, stop, getOutput } = await runWranglerDev(
resolve(__dirname, ".."),
["--local", "--port=0", "--inspector-port=0"]
));
});

afterAll(async () => {
await stop?.();
});

it("ai binding is defined ", async ({ expect }) => {
const response = await fetch(`http://${ip}:${port}/`);
const content = await response.json();
expect(content).toEqual({
binding: {},
fetcher: "function fetch() { [native code] }",
});
});
});
7 changes: 7 additions & 0 deletions fixtures/ai-app/tests/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"extends": "@cloudflare/workers-tsconfig/tsconfig.json",
"compilerOptions": {
"types": ["node"]
},
"include": ["**/*.ts", "../../../node-types.d.ts"]
}
13 changes: 13 additions & 0 deletions fixtures/ai-app/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"compilerOptions": {
"target": "ES2020",
"esModuleInterop": true,
"module": "CommonJS",
"lib": ["ES2020"],
"types": ["node"],
"skipLibCheck": true,
"moduleResolution": "node",
"noEmit": true
},
"include": ["tests", "../../node-types.d.ts"]
}
10 changes: 10 additions & 0 deletions fixtures/ai-app/vitest.config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { defineConfig } from "vitest/config";

export default defineConfig({
test: {
testTimeout: 10_000,
hookTimeout: 10_000,
teardownTimeout: 10_000,
useAtomics: true,
},
});
7 changes: 7 additions & 0 deletions fixtures/ai-app/wrangler.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
name = "ai-app"
compatibility_date = "2023-11-21"

main = "src/index.js"

[ai]
binding = "AI"
1 change: 1 addition & 0 deletions packages/wrangler/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"xxhash-wasm": "^1.0.1"
},
"devDependencies": {
"@cloudflare/ai": "^1.0.35",
"@cloudflare/cli": "workspace:*",
"@cloudflare/eslint-config-worker": "*",
"@cloudflare/pages-shared": "workspace:^",
Expand Down
20 changes: 20 additions & 0 deletions packages/wrangler/src/ai/fetcher.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { Response } from "miniflare";
import { performApiFetch } from "../cfetch/internal";
import { getAccountId } from "../user";
import type { Request } from "miniflare";

export async function AIFetcher(request: Request) {
const accountId = await getAccountId();

request.headers.delete("Host");
request.headers.delete("Content-Length");

const res = await performApiFetch(`/accounts/${accountId}/ai/run/proxy`, {
method: "POST",
headers: Object.fromEntries(request.headers.entries()),
body: request.body,
duplex: "half",
});

return new Response(res.body, { status: res.status });
}
3 changes: 3 additions & 0 deletions packages/wrangler/src/api/dev.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ export interface UnstableDevOptions {
bucket_name: string;
preview_bucket_name?: string;
}[];
ai?: {
binding: string;
};
moduleRoot?: string;
rules?: Rule[];
logLevel?: "none" | "info" | "error" | "log" | "warn" | "debug"; // Specify logging level [choices: "debug", "info", "log", "warn", "error", "none"] [default: "log"]
Expand Down
6 changes: 5 additions & 1 deletion packages/wrangler/src/dev.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ export type AdditionalDevProps = {
preview_bucket_name?: string;
jurisdiction?: string;
}[];
ai?: {
binding: string;
};
d1Databases?: Environment["d1_databases"];
processEntrypoint?: boolean;
additionalModules?: CfModule[];
Expand Down Expand Up @@ -832,6 +835,7 @@ function getBindingsAndAssetPaths(args: StartDevOptions, configParam: Config) {
r2: args.r2,
services: args.services,
d1Databases: args.d1Databases,
ai: args.ai,
});

const maskedVars = maskVars(bindings, configParam);
Expand Down Expand Up @@ -893,7 +897,7 @@ function getBindings(
wasm_modules: configParam.wasm_modules,
text_blobs: configParam.text_blobs,
browser: configParam.browser,
ai: configParam.ai,
ai: configParam.ai || args.ai,
data_blobs: configParam.data_blobs,
durable_objects: {
bindings: [
Expand Down
13 changes: 6 additions & 7 deletions packages/wrangler/src/dev/miniflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import assert from "node:assert";
import { realpathSync } from "node:fs";
import path from "node:path";
import { Log, LogLevel, TypedEventTarget, Mutex, Miniflare } from "miniflare";
import { AIFetcher } from "../ai/fetcher";
import { ModuleTypeToRuleType } from "../deployment-bundle/module-collection";
import { withSourceURLs } from "../deployment-bundle/source-url";
import { getHttpsOptions } from "../https-options";
Expand Down Expand Up @@ -312,6 +313,10 @@ function buildBindingOptions(config: ConfigBundle) {
.join("\n"),
};

if (bindings.ai?.binding) {
config.serviceBindings[bindings.ai.binding] = AIFetcher;
}

const bindingOptions = {
bindings: bindings.vars,
textBlobBindings,
Expand Down Expand Up @@ -502,13 +507,7 @@ async function buildMiniflareOptions(
logger.warn("Miniflare 3 does not support CRON triggers yet, ignoring...");
}

if (config.bindings.ai) {
logger.warn(
"Workers AI is not currently supported in local mode. Please use --remote to work with it."
);
}

if (!config.bindings.ai && config.bindings.vectorize?.length) {
if (config.bindings.vectorize?.length) {
// TODO: add local support for Vectorize bindings (https://github.com/cloudflare/workers-sdk/issues/4360)
logger.warn(
"Vectorize bindings are not currently supported in local mode. Please use --remote if you are working with them."
Expand Down
6 changes: 6 additions & 0 deletions packages/wrangler/src/pages/dev.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ export function Options(yargs: CommonYargsArgv) {
type: "array",
description: "R2 bucket to bind (--r2 R2_BINDING)",
},
ai: {
type: "string",
description: "AI to bind (--ai AI_BINDING)",
},
service: {
type: "array",
description: "Service to bind (--service SERVICE=SCRIPT_NAME)",
Expand Down Expand Up @@ -222,6 +226,7 @@ export const Handler = async ({
do: durableObjects = [],
d1: d1s = [],
r2: r2s = [],
ai,
service: requestedServices = [],
liveReload,
localProtocol,
Expand Down Expand Up @@ -677,6 +682,7 @@ export const Handler = async ({
return { binding, bucket_name: ref || binding.toString() };
})
.filter(Boolean) as AdditionalDevProps["r2"],
ai: ai ? { binding: ai.toString() } : undefined,
rules: usingWorkerDirectory
? [
{
Expand Down
22 changes: 22 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c10bf0f

Please sign in to comment.