From 7034fa4ffa35ece0661232b35e51076a741788f1 Mon Sep 17 00:00:00 2001 From: Leo Orshansky Date: Fri, 6 Aug 2021 10:41:49 -0500 Subject: [PATCH] Feat: Add Support for Durable Objects (#12) * allow rust to access pre-existing durable object * Transition to ES Modules build-upload type to prepare for exporting a DO from Rust * updated * update * finished testing storage API * update Co-authored-by: Leo Orshansky --- .gitignore | 3 +- Cargo.toml | 2 +- macros/cf/Cargo.toml | 3 +- macros/cf/src/lib.rs | 15 +- macros/durable_object/Cargo.toml | 10 + macros/durable_object/src/lib.rs | 96 ++++ rust-sandbox/Cargo.toml | 3 + rust-sandbox/src/counter.rs | 36 ++ rust-sandbox/src/lib.rs | 43 +- rust-sandbox/src/test/durable.rs | 48 ++ .../src/test/export_durable_object.rs | 96 ++++ rust-sandbox/src/test/mod.rs | 11 + rust-sandbox/tests/web.rs | 12 - rust-sandbox/worker/export_wasm.mjs | 7 + rust-sandbox/worker/metadata_wasm.json | 10 - rust-sandbox/worker/shim.mjs | 4 + rust-sandbox/worker/worker.js | 23 - rust-sandbox/wrangler.toml | 29 +- worker/Cargo.toml | 4 +- worker/src/durable.rs | 516 ++++++++++++++++++ worker/src/lib.rs | 131 +++-- worker/src/router.rs | 32 +- worker/tests/headers.rs | 2 +- 23 files changed, 1011 insertions(+), 125 deletions(-) create mode 100644 macros/durable_object/Cargo.toml create mode 100644 macros/durable_object/src/lib.rs create mode 100644 rust-sandbox/src/counter.rs create mode 100644 rust-sandbox/src/test/durable.rs create mode 100644 rust-sandbox/src/test/export_durable_object.rs create mode 100644 rust-sandbox/src/test/mod.rs delete mode 100644 rust-sandbox/tests/web.rs create mode 100644 rust-sandbox/worker/export_wasm.mjs delete mode 100644 rust-sandbox/worker/metadata_wasm.json create mode 100644 rust-sandbox/worker/shim.mjs delete mode 100644 rust-sandbox/worker/worker.js create mode 100644 worker/src/durable.rs diff --git a/.gitignore b/.gitignore index 696ab98d..8437ed77 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ **/target Cargo.lock -.DS_Store \ No newline at end of file +.DS_Store +**/worker/generated/* \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 279f3030..fa948568 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,2 @@ [workspace] -members = ["edgeworker-sys", "macros/cf", "rust-sandbox", "worker"] +members = ["edgeworker-sys", "macros/cf", "macros/durable_object", "rust-sandbox", "worker"] diff --git a/macros/cf/Cargo.toml b/macros/cf/Cargo.toml index 5251b084..a0c893e7 100644 --- a/macros/cf/Cargo.toml +++ b/macros/cf/Cargo.toml @@ -13,4 +13,5 @@ quote = "1.0.9" syn = "1.0.72" wasm-bindgen-macro-support = "0.2.74" web-sys = "0.3.51" -worker = { path = "../../worker" } \ No newline at end of file +worker = { path = "../../worker" } +durable_object = { path = "../durable_object" } \ No newline at end of file diff --git a/macros/cf/src/lib.rs b/macros/cf/src/lib.rs index 99cb0f8e..d11facf8 100644 --- a/macros/cf/src/lib.rs +++ b/macros/cf/src/lib.rs @@ -21,17 +21,17 @@ pub fn worker(attr: TokenStream, item: TokenStream) -> TokenStream { // let input_arg = input_fn.sig.inputs.first().expect("#[cf::worker(fetch)] attribute requires exactly one input, of type `worker::Request`"); // save original fn name for re-use in the wrapper fn - let original_input_fn_ident = input_fn.sig.ident.clone(); - let output_fn_ident = Ident::new("glue_fetch", input_fn.sig.ident.span()); + let input_fn_ident = Ident::new(&(input_fn.sig.ident.to_string() + "_fetch_glue"), input_fn.sig.ident.span()); + let wrapper_fn_ident = Ident::new("fetch", input_fn.sig.ident.span()); // rename the original attributed fn - input_fn.sig.ident = output_fn_ident.clone(); + input_fn.sig.ident = input_fn_ident.clone(); // create a new "main" function that takes the edgeworker_sys::Request, and calls the // original attributed function, passing in a converted worker::Request let wrapper_fn = quote! { - pub async fn #original_input_fn_ident(ty: String, req: edgeworker_sys::Request) -> worker::Result { + pub async fn #wrapper_fn_ident(req: ::edgeworker_sys::Request, env: ::worker::Env) -> ::worker::Result<::edgeworker_sys::Response> { // get the worker::Result by calling the original fn - #output_fn_ident(worker::Request::from((ty, req))).await + #input_fn_ident(worker::Request::from(req), env).await .map(edgeworker_sys::Response::from) } }; @@ -75,3 +75,8 @@ pub fn worker(attr: TokenStream, item: TokenStream) -> TokenStream { _ => exit_missing_attr(), } } + +#[proc_macro_attribute] +pub fn durable_object(_attr: TokenStream, item: TokenStream) -> TokenStream { + durable_object::expand_macro(item.into()).unwrap_or_else(syn::Error::into_compile_error).into() +} \ No newline at end of file diff --git a/macros/durable_object/Cargo.toml b/macros/durable_object/Cargo.toml new file mode 100644 index 00000000..aa27b39b --- /dev/null +++ b/macros/durable_object/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "durable_object" +version = "0.1.0" +edition = "2018" + +[dependencies] +worker = { path = "../../worker" } +syn = "1.0.74" +quote = "1.0.9" +proc-macro2 = "1.0.28" \ No newline at end of file diff --git a/macros/durable_object/src/lib.rs b/macros/durable_object/src/lib.rs new file mode 100644 index 00000000..e0799877 --- /dev/null +++ b/macros/durable_object/src/lib.rs @@ -0,0 +1,96 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::{ToTokens, quote}; +use syn::{Error, ImplItem, Item, spanned::Spanned}; + +pub fn expand_macro(tokens: TokenStream) -> syn::Result { + let item = syn::parse2::(tokens)?; + match item { + Item::Impl(imp) => { + let impl_token = imp.impl_token; + let trai = imp.trait_.clone(); + let (_, trai, _) = trai.ok_or_else(|| Error::new_spanned(impl_token, "Must be a DurableObject trait impl"))?; + if !trai.segments.last().map(|x| x.ident == "DurableObject").unwrap_or(false) { + return Err(Error::new(trai.span(), "Must be a DurableObject trait impl")) + } + + let pound = syn::Token![#](imp.span()).to_token_stream(); + let wasm_bindgen_attr = quote! {#pound[::wasm_bindgen::prelude::wasm_bindgen]}; + + + let struct_name = imp.self_ty; + let items = imp.items; + let mut tokenized = vec![]; + for item in items { + let mut method = match item { + ImplItem::Method(m) => m, + _ => return Err(Error::new_spanned(item, "Impl block must only contain methods")) + }; + let tokens = match method.sig.ident.to_string().as_str() { + "constructor" => { + method.sig.ident = Ident::new("_constructor", method.sig.ident.span()); + quote! { + #pound[::wasm_bindgen::prelude::wasm_bindgen(constructor)] + pub #method + } + }, + "fetch" => { + method.sig.ident = Ident::new("_fetch_raw", method.sig.ident.span()); + quote! { + #pound[::wasm_bindgen::prelude::wasm_bindgen(js_name = fetch)] + pub fn _fetch(&mut self, req: ::edgeworker_sys::Request) -> ::js_sys::Promise { + // SAFETY: + // On the surface, this is unsound because the Durable Object could be dropped + // while JavaScript still has possession of the future. However, + // we know something that Rust doesn't: that the Durable Object will never be destroyed + // while there is still a running promise inside of it, therefore we can let a reference + // to the durable object escape into a static-lifetime future. + let static_self: &'static mut Self = unsafe {&mut *(self as *mut _)}; + + ::wasm_bindgen_futures::future_to_promise(async move { + static_self._fetch_raw(req.into()).await.map(::edgeworker_sys::Response::from).map(::wasm_bindgen::JsValue::from) + .map_err(::wasm_bindgen::JsValue::from) + }) + } + + #method + } + }, + _ => panic!() + }; + tokenized.push(tokens); + } + Ok(quote! { + #wasm_bindgen_attr + impl #struct_name { + #(#tokenized)* + } + + #pound[async_trait::async_trait(?Send)] + impl ::worker::durable::DurableObject for #struct_name { + fn constructor(state: ::worker::durable::State, env: ::worker::Env) -> Self { + Self::_constructor(state, env) + } + + async fn fetch(&mut self, req: ::worker::Request) -> ::worker::Result { + self._fetch_raw(req).await + } + } + + trait __Need_Durable_Object_Trait_Impl_With_durable_object_Attribute { const MACROED: bool = true; } + impl __Need_Durable_Object_Trait_Impl_With_durable_object_Attribute for #struct_name {} + }) + }, + Item::Struct(struc) => { + let tokens = struc.to_token_stream(); + let pound = syn::Token![#](struc.span()).to_token_stream(); + let struct_name = struc.ident; + Ok(quote! { + #pound[::wasm_bindgen::prelude::wasm_bindgen] + #tokens + + const _: bool = <#struct_name as __Need_Durable_Object_Trait_Impl_With_durable_object_Attribute>::MACROED; + }) + }, + _ => Err(Error::new(item.span(), "Durable Object macro can only be applied to structs and their impl of DurableObject trait")) + } +} \ No newline at end of file diff --git a/rust-sandbox/Cargo.toml b/rust-sandbox/Cargo.toml index c161790d..25e5df8c 100644 --- a/rust-sandbox/Cargo.toml +++ b/rust-sandbox/Cargo.toml @@ -13,14 +13,17 @@ cfg-if = "0.1.2" console_error_panic_hook = { version = "0.1.1", optional = true } wee_alloc = { version = "0.4.2", optional = true } cf = { path = "../macros/cf" } +durable_object = { path = "../macros/durable_object" } edgeworker-sys = { path = "../edgeworker-sys" } serde = { version = "1.0.126", features = ["derive"] } worker = { path = "../worker" } wasm-bindgen = "=0.2.74" wasm-bindgen-futures = "0.4.24" +js-sys = "0.3.51" worker-kv = "0.2.0" http = "0.2.4" url = "2.2.2" +async-trait = "0.1.50" [dev-dependencies] wasm-bindgen-test = "0.2" diff --git a/rust-sandbox/src/counter.rs b/rust-sandbox/src/counter.rs new file mode 100644 index 00000000..3f78d733 --- /dev/null +++ b/rust-sandbox/src/counter.rs @@ -0,0 +1,36 @@ +use cf::durable_object; +use worker::{prelude::*, durable::{State}}; + +const ONE_HOUR: u64 = 3600000; + +#[durable_object] +pub struct Counter { + count: usize, + state: State, + initialized: bool, + last_backup: Date +} + +#[durable_object] +impl DurableObject for Counter { + fn constructor(state: worker::durable::State, _env: worker::Env) -> Self { + Self { count: 0, initialized: false, state, last_backup: Date::now() } + } + + async fn fetch(&mut self, _req: worker::Request) -> worker::Result { + // Get info from last backup + if !self.initialized { + self.initialized = true; + self.count = self.state.storage().get("count").await.unwrap_or(0); + } + + // Do a backup every hour + if Date::now().as_millis() - self.last_backup.as_millis() > ONE_HOUR { + self.last_backup = Date::now(); + self.state.storage().put("count", self.count).await?; + } + + self.count += 1; + Response::ok(self.count.to_string()) + } +} \ No newline at end of file diff --git a/rust-sandbox/src/lib.rs b/rust-sandbox/src/lib.rs index 3e99b0e5..4b382e97 100644 --- a/rust-sandbox/src/lib.rs +++ b/rust-sandbox/src/lib.rs @@ -1,6 +1,8 @@ use serde::{Deserialize, Serialize}; -use worker::{kv::KvStore, prelude::*, Router}; +use worker::{durable::ObjectNamespace, kv::KvStore, prelude::*, Router}; +mod test; +mod counter; mod utils; #[derive(Deserialize, Serialize)] @@ -28,37 +30,34 @@ struct User { date_from_str: String, } -fn handle_a_request(_req: Request, _params: Params) -> Result { - Response::ok("weeee".into()) +fn handle_a_request(_req: Request, _env: Env, _params: Params) -> Result { + Response::ok("weeee") } #[cf::worker(fetch)] -pub async fn main(req: Request) -> Result { - console_log!("request at: {:?}", req.path()); - +pub async fn main(req: Request, env: Env) -> Result { utils::set_panic_hook(); let mut router = Router::new(); router.get("/request", handle_a_request)?; - router.post("/headers", |req, _| { + router.post("/headers", |req, _, _| { let mut headers: http::HeaderMap = req.headers().into(); headers.append("Hello", "World!".parse().unwrap()); // TODO: make api for Response new and mut to add headers - Response::ok("returned your headers to you.".into()) - .map(|res| res.with_headers(headers.into())) + Response::ok("returned your headers to you.").map(|res| res.with_headers(headers.into())) })?; - router.on("/user/:id/test", |req, params| { + router.on("/user/:id/test", |req, _env, params| { if !matches!(req.method(), Method::Get) { - return Response::error("Method Not Allowed".into(), 405); + return Response::error("Method Not Allowed", 405); } let id = params.get("id").unwrap_or("not found"); Response::ok(format!("TEST user id: {}", id)) })?; - router.on("/user/:id", |_req, params| { + router.on("/user/:id", |_req, _env, params| { let id = params.get("id").unwrap_or("not found"); Response::from_json(&User { id: id.into(), @@ -71,25 +70,25 @@ pub async fn main(req: Request) -> Result { }) })?; - router.post("/account/:id/zones", |_, params| { + router.post("/account/:id/zones", |_, _, params| { Response::ok(format!( "Create new zone for Account: {}", params.get("id").unwrap_or("not found") )) })?; - router.get("/account/:id/zones", |_, params| { + router.get("/account/:id/zones", |_, _, params| { Response::ok(format!( "Account id: {}..... You get a zone, you get a zone!", params.get("id").unwrap_or("not found") )) })?; - router.on_async("/async", |mut req, _params| async move { + router.on_async("/async", |mut req, _env, _params| async move { Response::ok(format!("Request body: {}", req.text().await?)) })?; - router.on_async("/fetch", |_req, _params| async move { + router.on_async("/fetch", |_req, _env, _params| async move { let req = Request::new("https://example.com", "POST")?; let resp = Fetch::Request(&req).send().await?; let resp2 = Fetch::Url("https://example.com").send().await?; @@ -100,7 +99,7 @@ pub async fn main(req: Request) -> Result { )) })?; - router.on_async("/fetch_json", |_req, _params| async move { + router.on_async("/fetch_json", |_req, _env, _params| async move { let data: ApiData = Fetch::Url("https://jsonplaceholder.typicode.com/todos/1") .send() .await? @@ -112,13 +111,19 @@ pub async fn main(req: Request) -> Result { )) })?; - router.on_async("/proxy_request/:url", |_req, params| { + router.on_async("/proxy_request/:url", |_req, _env, params| { // Must copy the parameters into the heap here for lifetime purposes let url = params.get("url").unwrap().to_string(); async move { Fetch::Url(&url).send().await } })?; - router.run(req).await + router.on_async("durable", |_req, e, _params| async move { + let namespace = e.get_binding::("COUNTER")?; + let stub = namespace.id_from_name("A")?.get_stub()?; + stub.fetch_with_str("/").await + })?; + + router.run(req, env).await // match (req.method(), req.path().as_str()) { // (Method::Get, "/") => { diff --git a/rust-sandbox/src/test/durable.rs b/rust-sandbox/src/test/durable.rs new file mode 100644 index 00000000..d09bed08 --- /dev/null +++ b/rust-sandbox/src/test/durable.rs @@ -0,0 +1,48 @@ +use crate::ensure; +use worker::{durable::ObjectNamespace, prelude::*, Result}; + +pub async fn basic_test(env: &Env) -> Result<()> { + let namespace: ObjectNamespace = env.get_binding("MY_CLASS")?; + let id = namespace.id_from_name("A")?; + let bad = env.get_binding::("DFSDF"); + ensure!(bad.is_err(), "Invalid binding did not raise error"); + + let stub = id.get_stub()?; + let res = stub.fetch_with_str("hello").await?.text().await?; + let res2 = stub + .fetch_with_request(Request::new_with_init( + "hello", + RequestInit::new().body(Some(&"lol".into())).method("POST"), + )?) + .await? + .text() + .await?; + + ensure!(res == res2, "Durable object responded wrong to 'hello'"); + + let res = stub.fetch_with_str("storage").await?.text().await?; + let num = res + .parse::() + .map_err(|_| "Durable Object responded wrong to 'storage': ".to_string() + &res)?; + let res = stub.fetch_with_str("storage").await?.text().await?; + let num2 = res + .parse::() + .map_err(|_| "Durable Object responded wrong to 'storage'".to_string())?; + + ensure!( + num2 == num + 1, + "Durable object responded wrong to 'storage'" + ); + + let res = stub.fetch_with_str("transaction").await?.text().await?; + let num = res + .parse::() + .map_err(|_| "Durable Object responded wrong to 'transaction': ".to_string() + &res)?; + + ensure!( + num == num2 + 1, + "Durable object responded wrong to 'storage'" + ); + + Ok(()) +} diff --git a/rust-sandbox/src/test/export_durable_object.rs b/rust-sandbox/src/test/export_durable_object.rs new file mode 100644 index 00000000..8a805786 --- /dev/null +++ b/rust-sandbox/src/test/export_durable_object.rs @@ -0,0 +1,96 @@ +use std::collections::HashMap; +use serde::Serialize; + +use worker::prelude::*; + +use crate::ensure; + +#[cf::durable_object] +pub struct MyClass { + state: worker::durable::State, + number: usize, +} + +#[cf::durable_object] +impl DurableObject for MyClass { + fn constructor(state: worker::durable::State, _env: worker::Env) -> Self { + Self { state, number: 0 } + } + + async fn fetch(&mut self, req: worker::Request) -> Result { + let handler = async move { + match req.path().as_str() { + "/hello" => Response::ok("Hello!"), + "/storage" => { + let mut storage = self.state.storage(); + let map = [("one".to_string(), 1), ("two".to_string(), 2)] + .iter() + .cloned() + .collect::>(); + storage.put("map", map.clone()).await?; + storage.put("array", [("one", 1), ("two", 2)]).await?; + storage.put("anything", Some(45)).await?; + + let list = storage.list().await?; + let mut keys = vec![]; + + for key in list.keys() { + let key = key? + .as_string() + .ok_or_else(|| "Key wasn't a string".to_string())?; + keys.push(key); + } + + ensure!( + keys == vec!["anything", "array", "map"], + format!("Didn't list all of the keys: {:?}", keys) + ); + let vals = storage + .get_multiple(keys) + .await + .map_err(|e| e.to_string() + " -- get_multiple")?; + ensure!( + vals.get(&"anything".into()).into_serde::>()? == Some(45), + "Didn't get the right Option using get_multiple" + ); + ensure!( + vals.get(&"array".into()) + .into_serde::<[(String, i32); 2]>()? + == [("one".to_string(), 1), ("two".to_string(), 2)], + "Didn't get the right array using get_multiple" + ); + ensure!( + vals.get(&"map".into()).into_serde::>()? == map, + "Didn't get the right HashMap using get_multiple" + ); + + #[derive(Serialize)] + struct Stuff { + thing: String, + other: i32 + } + storage.put_multiple(Stuff {thing: "Hello there".to_string(), other: 56}).await?; + + ensure!(storage.get::("thing").await? == "Hello there", "Didn't put the right thing with put_multiple"); + ensure!(storage.get::("other").await? == 56, "Didn't put the right thing with put_multiple"); + + storage.delete_multiple(vec!["thing", "other"]).await?; + + self.number = storage.get("count").await.unwrap_or(0) + 1; + + storage.delete_all().await?; + + storage.put("count", self.number).await?; + Response::ok(self.number.to_string()) + } + "/transaction" => { + Response::error("transactional storage API is still unstable", 501) + } + _ => Response::error("Not Found", 404), + } + }; + handler + .await + .or_else(|err| Response::error(err.to_string(), 500)) + } +} diff --git a/rust-sandbox/src/test/mod.rs b/rust-sandbox/src/test/mod.rs new file mode 100644 index 00000000..04d61f02 --- /dev/null +++ b/rust-sandbox/src/test/mod.rs @@ -0,0 +1,11 @@ +pub mod durable; +pub mod export_durable_object; + +#[macro_export] +macro_rules! ensure { + ($ex:expr, $er:expr) => { + if !$ex { + return Err($er.into()); + } + }; +} diff --git a/rust-sandbox/tests/web.rs b/rust-sandbox/tests/web.rs deleted file mode 100644 index 0043bc4d..00000000 --- a/rust-sandbox/tests/web.rs +++ /dev/null @@ -1,12 +0,0 @@ -//! Test suite for the Web and headless browsers. - -#![cfg(target_arch = "wasm32")] - -use wasm_bindgen_test::*; - -wasm_bindgen_test_configure!(run_in_browser); - -#[wasm_bindgen_test] -fn pass() { - assert_eq!(1 + 1, 2); -} diff --git a/rust-sandbox/worker/export_wasm.mjs b/rust-sandbox/worker/export_wasm.mjs new file mode 100644 index 00000000..6de50177 --- /dev/null +++ b/rust-sandbox/worker/export_wasm.mjs @@ -0,0 +1,7 @@ +import * as index_bg from "./index_bg.mjs"; +import _wasm from "./index_bg.wasm"; + +const _wasm_memory = new WebAssembly.Memory({initial: 512}); +let importsObject = {env: {memory: _wasm_memory}, "./index_bg.js": index_bg}; + +export default new WebAssembly.Instance(_wasm, importsObject).exports; \ No newline at end of file diff --git a/rust-sandbox/worker/metadata_wasm.json b/rust-sandbox/worker/metadata_wasm.json deleted file mode 100644 index afc6d35f..00000000 --- a/rust-sandbox/worker/metadata_wasm.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "body_part": "script", - "bindings": [ - { - "name": "wasm", - "type": "wasm_module", - "part": "wasmprogram" - } - ] -} diff --git a/rust-sandbox/worker/shim.mjs b/rust-sandbox/worker/shim.mjs new file mode 100644 index 00000000..6e214080 --- /dev/null +++ b/rust-sandbox/worker/shim.mjs @@ -0,0 +1,4 @@ +import { fetch } from "./index_bg.mjs"; + +export * from "./index_bg.mjs"; +export default { fetch }; \ No newline at end of file diff --git a/rust-sandbox/worker/worker.js b/rust-sandbox/worker/worker.js deleted file mode 100644 index 7e37e57a..00000000 --- a/rust-sandbox/worker/worker.js +++ /dev/null @@ -1,23 +0,0 @@ -addEventListener('fetch', event => { - const { type, request } = event - event.respondWith(handleRequest(type, request)) -}) - -addEventListener('scheduled', event => { - const { type, schedule, cron } = event - event.waitUntil(handleScheduled(type, schedule, cron)) -}) - -async function handleRequest(type, request) { - const { main } = wasm_bindgen; - await wasm_bindgen(wasm) - - return main(type, request) -} - -async function handleScheduled(type, schedule, cron) { - const { job } = wasm_bindgen; - await wasm_bindgen(wasm) - - return job(type, schedule, cron) -} diff --git a/rust-sandbox/wrangler.toml b/rust-sandbox/wrangler.toml index 807a5262..d3392056 100644 --- a/rust-sandbox/wrangler.toml +++ b/rust-sandbox/wrangler.toml @@ -1,7 +1,30 @@ -name = "rust-sandbox-router" -type = "rust" +name = "do-test-rust" +type = "javascript" -account_id = "" +account_id = "615f1f0479e7014f0bebcd10d379f10e" workers_dev = true route = "" zone_id = "" + +[build] +command = """wasm-pack build --no-typescript --out-dir worker/build --out-name index && \ + mkdir -p worker/generated && \ + rm -rf worker/generated/* && \ + cp worker/build/index_bg.js worker/generated/index_bg.mjs && \ + sed -i '' "1s/.*/import wasm from '.\\/export_wasm.mjs'/" worker/generated/index_bg.mjs && \ + cp worker/build/index_bg.wasm worker/shim.mjs worker/export_wasm.mjs worker/generated/""" + +[build.upload] +# The "modules" upload format is required for all projects that implement a Durable Object namespace. +format = "modules" +dir = "worker/generated" +main = "./shim.mjs" + +[[build.upload.rules]] +type = "CompiledWasm" +globs = ["**/*.wasm"] + +[durable_objects] +bindings = [ + { name = "COUNTER", class_name = "Counter" } +] \ No newline at end of file diff --git a/worker/Cargo.toml b/worker/Cargo.toml index 4a4cab05..630d586c 100644 --- a/worker/Cargo.toml +++ b/worker/Cargo.toml @@ -12,14 +12,14 @@ js-sys = "0.3.51" serde = { version = "1.0.126", features = ["derive"] } serde_json = "1.0.64" url = "2.2.2" -wasm-bindgen = "=0.2.74" +wasm-bindgen = "0.2.74" wasm-bindgen-futures = "0.4.24" worker-kv = "0.2.0" web-sys = { version = "0.3.51", features = ["console"] } http = "0.2.4" matchit = "0.4.2" async-trait = "0.1.50" -futures = "0.3.15" +futures = "0.3.16" [dev-dependencies] wasm-bindgen-test = "0.3.24" diff --git a/worker/src/durable.rs b/worker/src/durable.rs new file mode 100644 index 00000000..828d08d5 --- /dev/null +++ b/worker/src/durable.rs @@ -0,0 +1,516 @@ +use crate::{Env, Error, Request, Response, Result}; +use async_trait::async_trait; +use edgeworker_sys::{Request as EdgeRequest, Response as EdgeResponse}; +use js_sys::{Map, Object}; +use serde::{Deserialize, Serialize}; +use std::{future::Future, ops::Deref, result::Result as StdResult}; +use wasm_bindgen::{prelude::*, JsCast}; +use wasm_bindgen_futures::{JsFuture, future_to_promise}; + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen (extends = ::js_sys::Object, js_name = DurableObjectId)] + type JsObjectId; + + #[wasm_bindgen (extends = ::js_sys::Object, js_name = DurableObject)] + pub type ObjectStub; + + #[wasm_bindgen (extends = ::js_sys::Object, js_name = DurableObjectNamespace)] + pub type ObjectNamespace; + + #[wasm_bindgen (extends = ::js_sys::Object, js_name = DurableObjectState)] + pub type State; + + #[wasm_bindgen(method, getter, js_class = "DurableObjectState", js_name = id)] + fn id_internal(this: &State) -> JsObjectId; + + #[wasm_bindgen(method, getter, js_class = "DurableObjectState", js_name = storage)] + fn storage_internal(this: &State) -> Storage; + + #[wasm_bindgen (catch, method, js_class = "DurableObjectNamespace", js_name = idFromName)] + fn id_from_name_internal(this: &ObjectNamespace, name: &str) -> StdResult; + + #[wasm_bindgen (catch, method, js_class = "ObjectNamespace", js_name = idFromString)] + fn id_from_string_internal( + this: &ObjectNamespace, + string: &str, + ) -> StdResult; + + #[wasm_bindgen (catch, method, js_class = "DurableObjectNamespace", js_name = newUniqueId)] + fn new_unique_id_internal(this: &ObjectNamespace) -> StdResult; + + #[wasm_bindgen (catch, method, js_class = "DurableObjectNamespace", js_name = newUniqueId)] + fn new_unique_id_with_options_internal( + this: &ObjectNamespace, + options: &JsValue, + ) -> StdResult; + + #[wasm_bindgen (catch, method, js_class = "DurableObjectNamespace", js_name = get)] + fn get_internal(this: &ObjectNamespace, id: &JsObjectId) -> StdResult; + + #[wasm_bindgen (method, js_class = "DurableObject", js_name = fetch)] + fn fetch_with_request_internal(this: &ObjectStub, req: &EdgeRequest) -> ::js_sys::Promise; + + #[wasm_bindgen (method, js_class = "DurableObject", js_name = fetch)] + fn fetch_with_str_internal(this: &ObjectStub, url: &str) -> ::js_sys::Promise; +} + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen (extends = ::js_sys::Object, js_name = DurableObjectStorage)] + pub type Storage; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = get)] + fn get_internal(this: &Storage, key: &str) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = get)] + fn get_multiple_internal( + this: &Storage, + keys: Vec, + ) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = put)] + fn put_internal( + this: &Storage, + key: &str, + value: JsValue, + ) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = put)] + fn put_multiple_internal(this: &Storage, value: JsValue) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = delete)] + fn delete_internal(this: &Storage, key: &str) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = delete)] + fn delete_multiple_internal( + this: &Storage, + keys: Vec, + ) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = deleteAll)] + fn delete_all_internal(this: &Storage) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = list)] + fn list_internal(this: &Storage) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = list)] + fn list_with_options_internal( + this: &Storage, + options: ::js_sys::Object, + ) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectStorage", js_name = transaction)] + fn transaction_internal(this: &Storage, closure: &mut dyn FnMut(Transaction) -> ::js_sys::Promise) -> StdResult<::js_sys::Promise, JsValue>; +} + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(extends = ::js_sys::Object, js_name = DurableObjectTransaction)] + pub type Transaction; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = get)] + fn get_internal(this: &Transaction, key: &str) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = get)] + fn get_multiple_internal( + this: &Transaction, + keys: Vec, + ) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = put)] + fn put_internal( + this: &Transaction, + key: &str, + value: JsValue, + ) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = put)] + fn put_multiple_internal(this: &Transaction, value: JsValue) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = delete)] + fn delete_internal(this: &Transaction, key: &str) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = delete)] + fn delete_multiple_internal( + this: &Transaction, + keys: Vec, + ) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = deleteAll)] + fn delete_all_internal(this: &Transaction) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = list)] + fn list_internal(this: &Transaction) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = list)] + fn list_with_options_internal( + this: &Transaction, + options: ::js_sys::Object, + ) -> StdResult<::js_sys::Promise, JsValue>; + + #[wasm_bindgen(catch, method, js_class = "DurableObjectTransaction", js_name = rollback)] + fn rollback_internal(this: &Transaction) -> StdResult<(), JsValue>; +} + +impl ObjectStub { + pub async fn fetch_with_request(&self, req: Request) -> Result { + let promise = self.fetch_with_request_internal(req.inner()); + let response = JsFuture::from(promise).await?; + Ok(response.dyn_into::()?.into()) + } + + pub async fn fetch_with_str(&self, url: &str) -> Result { + let promise = self.fetch_with_str_internal(url); + let response = JsFuture::from(promise).await?; + Ok(response.dyn_into::()?.into()) + } +} + +pub struct ObjectId<'a> { + inner: JsObjectId, + namespace: Option<&'a ObjectNamespace>, +} + +impl ObjectId<'_> { + pub fn get_stub(&self) -> Result { + self.namespace + .ok_or_else(|| JsValue::from("Cannot get stub from within a Durable Object")) + .and_then(|n| n.get_internal(&self.inner)) + .map_err(Error::from) + } +} + +impl ObjectNamespace { + // Get a Durable Object binding from the global namespace + // if your build is configured with ES6 modules, use Env::get_binding instead + pub fn global(name: &str) -> Result { + let global = js_sys::global(); + #[allow(unused_unsafe)] + // Weird rust-analyzer bug is causing it to think Reflect::get is unsafe + let class_binding = unsafe { js_sys::Reflect::get(&global, &JsValue::from(name))? }; + if class_binding.is_undefined() { + Err(Error::JsError("namespace binding does not exist".into())) + } else { + Ok(class_binding.unchecked_into()) + } + } + + pub fn id_from_name(&self, name: &str) -> Result { + self.id_from_name_internal(name) + .map_err(Error::from) + .map(|id| ObjectId { + inner: id, + namespace: Some(self), + }) + } + + pub fn id_from_string(&self, string: &str) -> Result { + self.id_from_string_internal(string) + .map_err(Error::from) + .map(|id| ObjectId { + inner: id, + namespace: Some(self), + }) + } + + pub fn unique_id(&self) -> Result { + self.new_unique_id_internal() + .map_err(Error::from) + .map(|id| ObjectId { + inner: id, + namespace: Some(self), + }) + } + + pub fn unique_id_with_jurisdiction(&self, jd: &str) -> Result { + let options = Object::new(); + #[allow(unused_unsafe)] + // Weird rust-analyzer bug is causing it to think Reflect::set is unsafe + unsafe { + js_sys::Reflect::set(&options, &JsValue::from("jurisdiction"), &jd.into())? + }; + self.new_unique_id_with_options_internal(&options) + .map_err(Error::from) + .map(|id| ObjectId { + inner: id, + namespace: Some(self), + }) + } +} + +impl State { + pub fn id(&self) -> ObjectId<'_> { + ObjectId { + inner: self.id_internal(), + namespace: None, + } + } + + // Just to improve visibility to code analysis tools + pub fn storage(&self) -> Storage { + self.storage_internal() + } +} + +impl Storage { + pub async fn get Deserialize<'a>>(&self, key: &str) -> Result { + JsFuture::from(self.get_internal(key)?) + .await + .and_then(|val| { + if val.is_undefined() { + Err(JsValue::from("No such value in storage.")) + } else { + val.into_serde().map_err(|e| JsValue::from(e.to_string())) + } + }) + .map_err(Error::from) + } + + pub async fn get_multiple(&self, keys: Vec>) -> Result { + let keys = self.get_multiple_internal(keys.into_iter().map(|key| JsValue::from(key.deref())).collect())?; + let keys = JsFuture::from(keys).await?; + keys.dyn_into::().map_err(Error::from) + } + + pub async fn put(&mut self, key: &str, value: T) -> Result<()> { + JsFuture::from( + self.put_internal( + key, + JsValue::from_serde(&value)?, + ) + ?, + ) + .await + .map_err(Error::from) + .map(|_| ()) + } + + // Each key-value pair in the serialized object will be added to the storage + pub async fn put_multiple(&mut self, values: T) -> Result<()> { + let values = JsValue::from_serde(&values)?; + if !values.is_object() { + return Err("Must pass in a struct type".to_string().into()); + } + JsFuture::from(self.put_multiple_internal(values)?) + .await + .map_err(Error::from) + .map(|_| ()) + } + + pub async fn delete(&mut self, key: &str) -> Result { + let fut: JsFuture = self.delete_internal(key)?.into(); + fut.await + .and_then(|jsv| { + jsv.as_bool() + .ok_or_else(|| JsValue::from("Promise did not return bool")) + }) + .map_err(Error::from) + } + + pub async fn delete_multiple(&mut self, keys: Vec>) -> Result { + let fut: JsFuture = self + .delete_multiple_internal(keys.into_iter().map(|key| JsValue::from(key.deref())).collect())? + .into(); + fut.await + .and_then(|jsv| { + jsv.as_f64() + .map(|f| f as usize) + .ok_or_else(|| JsValue::from("Promise did not return number")) + }) + .map_err(Error::from) + } + + pub async fn delete_all(&mut self) -> Result<()> { + let fut: JsFuture = self.delete_all_internal()?.into(); + fut.await.map(|_| ()).map_err(Error::from) + } + + pub async fn list(&self) -> Result { + let fut: JsFuture = self.list_internal()?.into(); + fut.await + .and_then(|jsv| jsv.dyn_into()) + .map_err(Error::from) + } + + pub async fn list_with_options(&self, opts: ListOptions<'_>) -> Result { + let fut: JsFuture = self + .list_with_options_internal( + JsValue::from_serde(&opts)?.into(), + )? + .into(); + fut.await + .and_then(|jsv| jsv.dyn_into()) + .map_err(Error::from) + } + + //This function doesn't work on stable yet because the wasm_bindgen `Closure` type is still nightly-gated + #[allow(dead_code)] + async fn transaction(&mut self, closure: fn(Transaction) -> F) -> Result<()> + where F: Future> + 'static { + let mut clos = |t: Transaction| { + future_to_promise(async move { + closure(t).await.map_err(JsValue::from).map(|_| JsValue::NULL) + }) + }; + JsFuture::from(self.transaction_internal(&mut clos)?).await.map_err(Error::from).map(|_| ()) + } +} + +impl Transaction { + pub async fn get Deserialize<'a>>(&self, key: &str) -> Result { + JsFuture::from(self.get_internal(key)?) + .await + .and_then(|val| { + if val.is_undefined() { + Err(JsValue::from("No such value in storage.")) + } else { + val.into_serde().map_err(|e| JsValue::from(e.to_string())) + } + }) + .map_err(Error::from) + } + + pub async fn get_multiple(&self, keys: Vec>) -> Result { + let keys = self.get_multiple_internal(keys.into_iter().map(|key| JsValue::from(key.deref())).collect())?; + let keys = JsFuture::from(keys).await?; + keys.dyn_into::().map_err(Error::from) + } + + pub async fn put(&mut self, key: &str, value: T) -> Result<()> { + JsFuture::from( + self.put_internal( + key, + JsValue::from_serde(&value)?, + ) + ?, + ) + .await + .map_err(Error::from) + .map(|_| ()) + } + + // Each key-value pair in the serialized object will be added to the storage + pub async fn put_multiple(&mut self, values: T) -> Result<()> { + let values = JsValue::from_serde(&values)?; + if !values.is_object() { + return Err("Must pass in a struct type".to_string().into()); + } + JsFuture::from(self.put_multiple_internal(values)?) + .await + .map_err(Error::from) + .map(|_| ()) + } + + pub async fn delete(&mut self, key: &str) -> Result { + let fut: JsFuture = self.delete_internal(key)?.into(); + fut.await + .and_then(|jsv| { + jsv.as_bool() + .ok_or_else(|| JsValue::from("Promise did not return bool")) + }) + .map_err(Error::from) + } + + pub async fn delete_multiple(&mut self, keys: Vec>) -> Result { + let fut: JsFuture = self + .delete_multiple_internal(keys.into_iter().map(|key| JsValue::from(key.deref())).collect())? + .into(); + fut.await + .and_then(|jsv| { + jsv.as_f64() + .map(|f| f as usize) + .ok_or_else(|| JsValue::from("Promise did not return number")) + }) + .map_err(Error::from) + } + + pub async fn delete_all(&mut self) -> Result<()> { + let fut: JsFuture = self.delete_all_internal()?.into(); + fut.await.map(|_| ()).map_err(Error::from) + } + + pub async fn list(&self) -> Result { + let fut: JsFuture = self.list_internal()?.into(); + fut.await + .and_then(|jsv| jsv.dyn_into()) + .map_err(Error::from) + } + + pub async fn list_with_options(&self, opts: ListOptions<'_>) -> Result { + let fut: JsFuture = self + .list_with_options_internal( + JsValue::from_serde(&opts)?.into(), + )? + .into(); + fut.await + .and_then(|jsv| jsv.dyn_into()) + .map_err(Error::from) + } + + pub fn rollback(&mut self) -> Result<()> { + self.rollback_internal().map_err(Error::from) + } +} + +#[derive(Serialize)] +pub struct ListOptions<'a> { + #[serde(skip_serializing_if = "Option::is_none")] + start: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + end: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + prefix: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + reverse: Option, + #[serde(skip_serializing_if = "Option::is_none")] + limit: Option, +} + +impl<'a> ListOptions<'a> { + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Self { + start: None, + end: None, + prefix: None, + reverse: None, + limit: None, + } + } + + pub fn start(mut self, val: &'a str) -> Self { + self.start = Some(val); + self + } + + pub fn end(mut self, val: &'a str) -> Self { + self.end = Some(val); + self + } + + pub fn prefix(mut self, val: &'a str) -> Self { + self.prefix = Some(val); + self + } + + pub fn reverse(mut self, val: bool) -> Self { + self.reverse = Some(val); + self + } + + pub fn limit(mut self, val: usize) -> Self { + self.limit = Some(val); + self + } +} + +impl crate::EnvBinding for ObjectNamespace { + const TYPE_NAME: &'static str = "DurableObjectNamespace"; +} + +#[async_trait(?Send)] +pub trait DurableObject { + fn constructor(state: State, env: Env) -> Self; + async fn fetch(&mut self, req: Request) -> Result; +} diff --git a/worker/src/lib.rs b/worker/src/lib.rs index e9353a5c..b0eda5c0 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -1,6 +1,8 @@ mod headers; mod router; +pub mod durable; + use std::result::Result as StdResult; mod global; @@ -8,18 +10,17 @@ mod global; use edgeworker_sys::{ Cf, Request as EdgeRequest, Response as EdgeResponse, ResponseInit as EdgeResponseInit, }; -use js_sys::Date as JsDate; +use js_sys::{Date as JsDate, Object}; use matchit::InsertError; use serde::{de::DeserializeOwned, Serialize}; use url::Url; -use wasm_bindgen::{JsCast, JsValue}; +use wasm_bindgen::{prelude::*, JsCast, JsValue}; use wasm_bindgen_futures::JsFuture; pub use crate::headers::Headers; pub use crate::router::Router; use web_sys::RequestInit; -pub use edgeworker_sys::console_log; pub use worker_kv as kv; pub type Result = StdResult; @@ -27,13 +28,13 @@ pub type Result = StdResult; pub mod prelude { pub use crate::global::Fetch; pub use crate::headers::Headers; + pub use crate::Env; pub use crate::Method; pub use crate::Request; pub use crate::Response; pub use crate::Result; pub use crate::Schedule; pub use crate::{Date, DateInit}; - pub use edgeworker_sys::console_log; pub use matchit::Params; pub use web_sys::RequestInit; } @@ -107,28 +108,71 @@ impl From<(String, u64, String)> for Schedule { } } +#[wasm_bindgen] +extern "C" { + pub type Env; +} + +pub trait EnvBinding: Sized + JsCast { + const TYPE_NAME: &'static str; + + fn get(val: JsValue) -> Result { + let obj = Object::from(val); + if obj.constructor().name() == Self::TYPE_NAME { + Ok(obj.unchecked_into()) + } else { + Err(format!( + "Binding cannot be cast to the type {}", + Self::TYPE_NAME + ).into()) + } + } +} + +impl Env { + pub fn get_binding(&self, name: &str) -> Result { + // Weird rust-analyzer bug is causing it to think Reflect::get is unsafe + #[allow(unused_unsafe)] + let binding = unsafe { js_sys::Reflect::get(self, &JsValue::from(name)) } + .map_err(|_| Error::JsError(format!("Env does not contain binding {}", name)))?; + if binding.is_undefined() { + Err("Binding is undefined.".to_string().into()) + } else { + // Can't just use JsCast::dyn_into here because the type name might not be in scope + // resulting in a terribly annoying javascript error which can't be caught + T::get(binding) + } + } +} + pub struct Request { method: Method, + url: String, path: String, headers: Headers, cf: Cf, - event_type: String, edge_request: EdgeRequest, body_used: bool, immutable: bool, } -impl From<(String, EdgeRequest)> for Request { - fn from(req: (String, EdgeRequest)) -> Self { +impl From for Request { + fn from(req: EdgeRequest) -> Self { Self { - method: req.1.method().into(), - path: Url::parse(&req.1.url()).unwrap().path().into(), - headers: Headers(req.1.headers()), - cf: req.1.cf(), - immutable: &req.0 == "fetch", - event_type: req.0, - edge_request: req.1, + method: req.method().into(), + url: req.url(), + path: Url::parse(&req.url()).map(|u| u.path().into()).unwrap_or_else(|_| { + let u = req.url(); + if !u.starts_with('/') { + return "/".to_string() + &u + } + u + }), + headers: Headers(req.headers()), + cf: req.cf(), + edge_request: req, body_used: false, + immutable: true } } } @@ -136,7 +180,11 @@ impl From<(String, EdgeRequest)> for Request { impl Request { pub fn new(uri: &str, method: &str) -> Result { EdgeRequest::new_with_str_and_init(uri, RequestInit::new().method(method)) - .map(|req| (String::new(), req).into()) + .map(|req| { + let mut req: Request = req.into(); + req.immutable = false; + req + }) .map_err(|e| { Error::JsError( e.as_string() @@ -147,7 +195,11 @@ impl Request { pub fn new_with_init(uri: &str, init: &RequestInit) -> Result { EdgeRequest::new_with_str_and_init(uri, init) - .map(|req| (String::new(), req).into()) + .map(|req| { + let mut req: Request = req.into(); + req.immutable = false; + req + }) .map_err(|e| { Error::JsError( e.as_string() @@ -169,7 +221,7 @@ impl Request { }) .and_then(|val| { val.into_serde() - .map_err(|e| Error::RustError(e.to_string())) + .map_err(Error::from) }); } @@ -219,14 +271,11 @@ impl Request { self.path.clone() } - pub fn event_type(&self) -> String { - self.event_type.clone() - } - - #[allow(clippy::clippy::should_implement_trait)] + #[allow(clippy::should_implement_trait)] pub fn clone(&self) -> Result { - EdgeRequest::new_with_request(&self.edge_request) - .map(|req| (self.event_type(), req).into()) + self.edge_request + .clone() + .map(|req| req.into()) .map_err(Error::from) } @@ -261,9 +310,9 @@ impl Response { Err(Error::Json(("Failed to encode data to json".into(), 500))) } - pub fn ok(body: String) -> Result { + pub fn ok(body: impl Into) -> Result { Ok(Self { - body: ResponseBody::Body(body.into_bytes()), + body: ResponseBody::Body(body.into().into_bytes()), headers: Headers::new(), status_code: 200, }) @@ -282,9 +331,9 @@ impl Response { status_code: 200, }) } - pub fn error(msg: String, status: u16) -> Result { + pub fn error(msg: impl Into, status: u16) -> Result { Ok(Self { - body: ResponseBody::Body(msg.into_bytes()), + body: ResponseBody::Body(msg.into().into_bytes()), headers: Headers::new(), status_code: status, }) @@ -297,7 +346,7 @@ impl Response { pub async fn text(&mut self) -> Result { match &self.body { ResponseBody::Body(bytes) => Ok( - String::from_utf8(bytes.clone()).map_err(|e| Error::RustError(e.to_string()))? + String::from_utf8(bytes.clone()).map_err(|e| Error::from(e.to_string()))? ), ResponseBody::Empty => Ok(String::new()), ResponseBody::Stream(response) => JsFuture::from(response.text()?) @@ -309,7 +358,7 @@ impl Response { pub async fn json(&mut self) -> Result { serde_json::from_str(&self.text().await?) - .map_err(|_| Error::RustError("JSON deserialization error".into())) + .map_err(Error::from) } pub async fn bytes(&mut self) -> Result> { @@ -487,7 +536,7 @@ impl From for Redirect { } } -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub enum Error { BodyUsed, Json((String, u16)), @@ -495,6 +544,7 @@ pub enum Error { Internal(JsValue), RouteInsertError(matchit::InsertError), RustError(String), + SerdeJsonError(serde_json::Error) } impl std::fmt::Display for Error { @@ -505,6 +555,7 @@ impl std::fmt::Display for Error { Error::JsError(s) | Error::RustError(s) => write!(f, "{}", s), Error::Internal(_) => write!(f, "unrecognized JavaScript object"), Error::RouteInsertError(e) => write!(f, "failed to insert route: {}", e), + Error::SerdeJsonError(e) => write!(f, "Serde Error: {}", e) } } } @@ -529,8 +580,26 @@ impl From for JsValue { } } +impl From<&str> for Error { + fn from(a: &str) -> Self { + Error::RustError(a.to_string()) + } +} + +impl From for Error { + fn from(a: String) -> Self { + Error::RustError(a) + } +} + impl From for Error { fn from(e: InsertError) -> Self { Error::RouteInsertError(e) } } + +impl From for Error { + fn from(e: serde_json::Error) -> Self { + Error::SerdeJsonError(e) + } +} \ No newline at end of file diff --git a/worker/src/router.rs b/worker/src/router.rs index 69465d18..e0cd5a77 100644 --- a/worker/src/router.rs +++ b/worker/src/router.rs @@ -3,10 +3,10 @@ use std::rc::Rc; use futures::{future::LocalBoxFuture, Future}; use matchit::{Match, Node, Params}; -use crate::{Method, Request, Response, Result}; +use crate::{Env, Method, Request, Response, Result}; -pub type HandlerFn = fn(Request, Params) -> Result; -type AsyncHandler<'a> = Rc LocalBoxFuture<'a, Result>>; +pub type HandlerFn = fn(Request, Env, Params) -> Result; +type AsyncHandler<'a> = Rc LocalBoxFuture<'a, Result>>; pub enum Handler<'a> { Async(AsyncHandler<'a>), @@ -25,7 +25,7 @@ impl Clone for Handler<'_> { pub type HandlerSet<'a> = [Option>; 9]; pub struct Router<'a> { - handlers: matchit::Node>, + handlers: Node> } impl<'a> Router<'a> { @@ -45,35 +45,35 @@ impl<'a> Router<'a> { self.add_handler(pattern, Handler::Sync(func), Method::all()) } - pub fn get_async(&mut self, pattern: &str, func: fn(Request, Params) -> T) -> Result<()> + pub fn get_async(&mut self, pattern: &str, func: fn(Request, Env, Params) -> T) -> Result<()> where T: Future> + 'static, { self.add_handler( pattern, - Handler::Async(Rc::new(move |req, par| Box::pin(func(req, par)))), + Handler::Async(Rc::new(move |req, env, par| Box::pin(func(req, env, par)))), vec![Method::Get], ) } - pub fn post_async(&mut self, pattern: &str, func: fn(Request, Params) -> T) -> Result<()> + pub fn post_async(&mut self, pattern: &str, func: fn(Request, Env, Params) -> T) -> Result<()> where T: Future> + 'static, { self.add_handler( pattern, - Handler::Async(Rc::new(move |req, par| Box::pin(func(req, par)))), + Handler::Async(Rc::new(move |req, env, par| Box::pin(func(req, env, par)))), vec![Method::Post], ) } - pub fn on_async(&mut self, pattern: &str, func: fn(Request, Params) -> T) -> Result<()> + pub fn on_async(&mut self, pattern: &str, func: fn(Request, Env, Params) -> T) -> Result<()> where T: Future> + 'static, { self.add_handler( pattern, - Handler::Async(Rc::new(move |req, par| Box::pin(func(req, par)))), + Handler::Async(Rc::new(move |req, env, par| Box::pin(func(req, env, par)))), Method::all(), ) } @@ -100,24 +100,24 @@ impl<'a> Router<'a> { Ok(()) } - pub async fn run(&self, req: Request) -> Result { + pub async fn run(&self, req: Request, env: Env) -> Result { if let Ok(Match { value, params }) = self.handlers.at(&req.path()) { if let Some(handler) = value[req.method() as usize].as_ref() { return match handler { - Handler::Sync(func) => (func)(req, params), - Handler::Async(func) => (func)(req, params).await + Handler::Sync(func) => (func)(req, env, params), + Handler::Async(func) => (func)(req, env, params).await } } - return Response::error("Method Not Allowed".into(), 405); + return Response::error("Method Not Allowed", 405); } - Response::error("Not Found".into(), 404) + Response::error("Not Found", 404) } } impl Default for Router<'_> { fn default() -> Self { Self { - handlers: Node::new(), + handlers: Node::new() } } } \ No newline at end of file diff --git a/worker/tests/headers.rs b/worker/tests/headers.rs index 638545b5..954bcd04 100644 --- a/worker/tests/headers.rs +++ b/worker/tests/headers.rs @@ -85,7 +85,7 @@ fn response_headers() { } fn response_headers_test() -> Result<()> { - let mut response = Response::ok("Hello, World!".into())?; + let mut response = Response::ok("Hello, World!")?; response .headers_mut() .set("Content-Type", "application/json")?;