diff --git a/Cargo.lock b/Cargo.lock index cde55d628..e4fd63e4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2011,8 +2011,11 @@ version = "0.3.0" dependencies = [ "clap", "heck", + "indexmap", "test-helpers", + "wasmtime-environ", "wit-bindgen-core", + "wit-component", ] [[package]] diff --git a/crates/gen-host-js/tests/helpers.ts b/crates/gen-host-js/tests/helpers.ts index 5de607ae8..b670150d0 100644 --- a/crates/gen-host-js/tests/helpers.ts +++ b/crates/gen-host-js/tests/helpers.ts @@ -1,7 +1,7 @@ // @ts-ignore import { readFile } from 'node:fs/promises'; // @ts-ignore -import { argv, stdout } from 'node:process'; +import { argv, stdout, stderr } from 'node:process'; // This is a helper function used from `host.ts` test in the `tests/runtime/*` // directory to pass as the `instantiateCore` argument to the `instantiate` @@ -20,4 +20,7 @@ export const testwasi = { log(bytes: Uint8Array) { stdout.write(bytes); }, + logErr(bytes: Uint8Array) { + stderr.write(bytes); + }, }; diff --git a/crates/gen-host-wasmtime-py/Cargo.toml b/crates/gen-host-wasmtime-py/Cargo.toml index 0ca6b83c9..786e8fb95 100644 --- a/crates/gen-host-wasmtime-py/Cargo.toml +++ b/crates/gen-host-wasmtime-py/Cargo.toml @@ -5,9 +5,12 @@ version.workspace = true edition.workspace = true [dependencies] -wit-bindgen-core = { workspace = true } +wit-bindgen-core = { workspace = true, features = ['component-generator'] } heck = { workspace = true } clap = { workspace = true, optional = true } +wit-component = { workspace = true } +indexmap = "1.0" +wasmtime-environ = { workspace = true, features = ['component-model'] } [dev-dependencies] test-helpers = { path = '../test-helpers' } diff --git a/crates/gen-host-wasmtime-py/src/dependencies.rs b/crates/gen-host-wasmtime-py/src/dependencies.rs deleted file mode 100644 index f7a48ca14..000000000 --- a/crates/gen-host-wasmtime-py/src/dependencies.rs +++ /dev/null @@ -1,368 +0,0 @@ -use crate::Source; -use std::collections::{BTreeMap, BTreeSet}; - -/// Tracks all of the import and intrinsics that a given codegen -/// requires and how to generate them when needed. -#[derive(Default)] -pub struct Dependencies { - pub needs_clamp: bool, - pub needs_store: bool, - pub needs_load: bool, - pub needs_validate_guest_char: bool, - pub needs_result: bool, - pub needs_i32_to_f32: bool, - pub needs_f32_to_i32: bool, - pub needs_i64_to_f64: bool, - pub needs_f64_to_i64: bool, - pub needs_decode_utf8: bool, - pub needs_encode_utf8: bool, - pub needs_list_canon_lift: bool, - pub needs_list_canon_lower: bool, - pub needs_t_typevar: bool, - pub needs_resources: bool, - pub pyimports: BTreeMap>>, -} - -impl Dependencies { - /// Record that a Python import is required - /// - /// Examples - /// ``` - /// # use wit_bindgen_gen_host_wasmtime_py::dependencies::Dependencies; - /// # let mut deps = Dependencies::default(); - /// // Import a specific item from a module - /// deps.pyimport("typing", "NamedTuple"); - /// // Import an entire module - /// deps.pyimport("collections", None); - /// ``` - pub fn pyimport<'a>(&mut self, module: &str, name: impl Into>) { - let name = name.into(); - let list = self - .pyimports - .entry(module.to_string()) - .or_insert(match name { - Some(_) => Some(BTreeSet::new()), - None => None, - }); - match name { - Some(name) => { - assert!(list.is_some()); - list.as_mut().unwrap().insert(name.to_string()); - } - None => assert!(list.is_none()), - } - } - - /// Create a `Source` containing all of the intrinsics - /// required according to this `Dependencies` struct. - pub fn intrinsics(&mut self) -> Source { - let mut src = Source::default(); - - if self.needs_clamp { - src.push_str( - " - def _clamp(i: int, min: int, max: int) -> int: - if i < min or i > max: - raise OverflowError(f'must be between {min} and {max}') - return i - ", - ); - } - if self.needs_store { - // TODO: this uses native endianness - self.pyimport("ctypes", None); - src.push_str( - " - def _store(ty: Any, mem: wasmtime.Memory, store: wasmtime.Storelike, base: int, offset: int, val: Any) -> None: - ptr = (base & 0xffffffff) + offset - if ptr + ctypes.sizeof(ty) > mem.data_len(store): - raise IndexError('out-of-bounds store') - raw_base = mem.data_ptr(store) - c_ptr = ctypes.POINTER(ty)( - ty.from_address(ctypes.addressof(raw_base.contents) + ptr) - ) - c_ptr[0] = val - ", - ); - } - if self.needs_load { - // TODO: this uses native endianness - self.pyimport("ctypes", None); - src.push_str( - " - def _load(ty: Any, mem: wasmtime.Memory, store: wasmtime.Storelike, base: int, offset: int) -> Any: - ptr = (base & 0xffffffff) + offset - if ptr + ctypes.sizeof(ty) > mem.data_len(store): - raise IndexError('out-of-bounds store') - raw_base = mem.data_ptr(store) - c_ptr = ctypes.POINTER(ty)( - ty.from_address(ctypes.addressof(raw_base.contents) + ptr) - ) - return c_ptr[0] - ", - ); - } - if self.needs_validate_guest_char { - src.push_str( - " - def _validate_guest_char(i: int) -> str: - if i > 0x10ffff or (i >= 0xd800 and i <= 0xdfff): - raise TypeError('not a valid char') - return chr(i) - ", - ); - } - if self.needs_result { - self.pyimport("dataclasses", "dataclass"); - self.pyimport("typing", "TypeVar"); - self.pyimport("typing", "Generic"); - self.pyimport("typing", "Union"); - self.needs_t_typevar = true; - src.push_str( - " - @dataclass - class Ok(Generic[T]): - value: T - E = TypeVar('E') - @dataclass - class Err(Generic[E]): - value: E - - Result = Union[Ok[T], Err[E]] - ", - ); - } - if self.needs_i32_to_f32 || self.needs_f32_to_i32 { - self.pyimport("ctypes", None); - src.push_str("_i32_to_f32_i32 = ctypes.pointer(ctypes.c_int32(0))\n"); - src.push_str( - "_i32_to_f32_f32 = ctypes.cast(_i32_to_f32_i32, ctypes.POINTER(ctypes.c_float))\n", - ); - if self.needs_i32_to_f32 { - src.push_str( - " - def _i32_to_f32(i: int) -> float: - _i32_to_f32_i32[0] = i # type: ignore - return _i32_to_f32_f32[0] # type: ignore - ", - ); - } - if self.needs_f32_to_i32 { - src.push_str( - " - def _f32_to_i32(i: float) -> int: - _i32_to_f32_f32[0] = i # type: ignore - return _i32_to_f32_i32[0] # type: ignore - ", - ); - } - } - if self.needs_i64_to_f64 || self.needs_f64_to_i64 { - self.pyimport("ctypes", None); - src.push_str("_i64_to_f64_i64 = ctypes.pointer(ctypes.c_int64(0))\n"); - src.push_str( - "_i64_to_f64_f64 = ctypes.cast(_i64_to_f64_i64, ctypes.POINTER(ctypes.c_double))\n", - ); - if self.needs_i64_to_f64 { - src.push_str( - " - def _i64_to_f64(i: int) -> float: - _i64_to_f64_i64[0] = i # type: ignore - return _i64_to_f64_f64[0] # type: ignore - ", - ); - } - if self.needs_f64_to_i64 { - src.push_str( - " - def _f64_to_i64(i: float) -> int: - _i64_to_f64_f64[0] = i # type: ignore - return _i64_to_f64_i64[0] # type: ignore - ", - ); - } - } - if self.needs_decode_utf8 { - self.pyimport("ctypes", None); - src.push_str( - " - def _decode_utf8(mem: wasmtime.Memory, store: wasmtime.Storelike, ptr: int, len: int) -> str: - ptr = ptr & 0xffffffff - len = len & 0xffffffff - if ptr + len > mem.data_len(store): - raise IndexError('string out of bounds') - base = mem.data_ptr(store) - base = ctypes.POINTER(ctypes.c_ubyte)( - ctypes.c_ubyte.from_address(ctypes.addressof(base.contents) + ptr) - ) - return ctypes.string_at(base, len).decode('utf-8') - ", - ); - } - if self.needs_encode_utf8 { - self.pyimport("ctypes", None); - self.pyimport("typing", "Tuple"); - src.push_str( - " - def _encode_utf8(val: str, realloc: wasmtime.Func, mem: wasmtime.Memory, store: wasmtime.Storelike) -> Tuple[int, int]: - bytes = val.encode('utf8') - ptr = realloc(store, 0, 0, 1, len(bytes)) - assert(isinstance(ptr, int)) - ptr = ptr & 0xffffffff - if ptr + len(bytes) > mem.data_len(store): - raise IndexError('string out of bounds') - base = mem.data_ptr(store) - base = ctypes.POINTER(ctypes.c_ubyte)( - ctypes.c_ubyte.from_address(ctypes.addressof(base.contents) + ptr) - ) - ctypes.memmove(base, bytes, len(bytes)) - return (ptr, len(bytes)) - ", - ); - } - if self.needs_list_canon_lift { - self.pyimport("ctypes", None); - self.pyimport("typing", "List"); - // TODO: this is doing a native-endian read, not a little-endian - // read - src.push_str( - " - def _list_canon_lift(ptr: int, len: int, size: int, ty: Any, mem: wasmtime.Memory ,store: wasmtime.Storelike) -> Any: - ptr = ptr & 0xffffffff - len = len & 0xffffffff - if ptr + len * size > mem.data_len(store): - raise IndexError('list out of bounds') - raw_base = mem.data_ptr(store) - base = ctypes.POINTER(ty)( - ty.from_address(ctypes.addressof(raw_base.contents) + ptr) - ) - if ty == ctypes.c_uint8: - return ctypes.string_at(base, len) - return base[:len] - ", - ); - } - if self.needs_list_canon_lower { - self.pyimport("ctypes", None); - self.pyimport("typing", "List"); - self.pyimport("typing", "Tuple"); - // TODO: is there a faster way to memcpy other than iterating over - // the input list? - // TODO: this is doing a native-endian write, not a little-endian - // write - src.push_str( - " - def _list_canon_lower(list: Any, ty: Any, size: int, align: int, realloc: wasmtime.Func, mem: wasmtime.Memory, store: wasmtime.Storelike) -> Tuple[int, int]: - total_size = size * len(list) - ptr = realloc(store, 0, 0, align, total_size) - assert(isinstance(ptr, int)) - ptr = ptr & 0xffffffff - if ptr + total_size > mem.data_len(store): - raise IndexError('list realloc return of bounds') - raw_base = mem.data_ptr(store) - base = ctypes.POINTER(ty)( - ty.from_address(ctypes.addressof(raw_base.contents) + ptr) - ) - for i, val in enumerate(list): - base[i] = val - return (ptr, len(list)) - ", - ); - } - - if self.needs_resources { - self.pyimport("typing", "TypeVar"); - self.pyimport("typing", "Generic"); - self.pyimport("typing", "List"); - self.pyimport("typing", "Optional"); - self.pyimport("dataclasses", "dataclass"); - self.needs_t_typevar = true; - src.push_str( - " - @dataclass - class SlabEntry(Generic[T]): - next: int - val: Optional[T] - - class Slab(Generic[T]): - head: int - list: List[SlabEntry[T]] - - def __init__(self) -> None: - self.list = [] - self.head = 0 - - def insert(self, val: T) -> int: - if self.head >= len(self.list): - self.list.append(SlabEntry(next = len(self.list) + 1, val = None)) - ret = self.head - slot = self.list[ret] - self.head = slot.next - slot.next = -1 - slot.val = val - return ret - - def get(self, idx: int) -> T: - if idx >= len(self.list): - raise IndexError('handle index not valid') - slot = self.list[idx] - if slot.next == -1: - assert(slot.val is not None) - return slot.val - raise IndexError('handle index not valid') - - def remove(self, idx: int) -> T: - ret = self.get(idx) - slot = self.list[idx] - slot.val = None - slot.next = self.head - self.head = idx - return ret - ", - ); - } - src - } -} - -#[cfg(test)] -mod test { - use std::collections::{BTreeMap, BTreeSet}; - - use super::Dependencies; - - #[test] - fn test_pyimport_only_contents() { - let mut deps = Dependencies::default(); - deps.pyimport("typing", None); - deps.pyimport("typing", None); - assert_eq!(deps.pyimports, BTreeMap::from([("typing".into(), None)])); - } - - #[test] - fn test_pyimport_only_module() { - let mut deps = Dependencies::default(); - deps.pyimport("typing", "Union"); - deps.pyimport("typing", "List"); - deps.pyimport("typing", "NamedTuple"); - assert_eq!( - deps.pyimports, - BTreeMap::from([( - "typing".into(), - Some(BTreeSet::from([ - "Union".into(), - "List".into(), - "NamedTuple".into() - ])) - )]) - ); - } - - #[test] - #[should_panic] - fn test_pyimport_conflicting() { - let mut deps = Dependencies::default(); - deps.pyimport("typing", "NamedTuple"); - deps.pyimport("typing", None); - } -} diff --git a/crates/gen-host-wasmtime-py/src/imports.rs b/crates/gen-host-wasmtime-py/src/imports.rs new file mode 100644 index 000000000..0a634047a --- /dev/null +++ b/crates/gen-host-wasmtime-py/src/imports.rs @@ -0,0 +1,96 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::fmt::Write; +use wit_bindgen_core::uwriteln; + +/// Tracks all of the import and intrinsics that a given codegen +/// requires and how to generate them when needed. +#[derive(Default)] +pub struct PyImports { + pyimports: BTreeMap>>, +} + +impl PyImports { + /// Record that a Python import is required + pub fn pyimport<'a>(&mut self, module: &str, name: impl Into>) { + let name = name.into(); + let list = self + .pyimports + .entry(module.to_string()) + .or_insert(match name { + Some(_) => Some(BTreeSet::new()), + None => None, + }); + match name { + Some(name) => { + assert!(list.is_some()); + list.as_mut().unwrap().insert(name.to_string()); + } + None => assert!(list.is_none()), + } + } + + pub fn is_empty(&self) -> bool { + self.pyimports.is_empty() + } + + pub fn finish(&self) -> String { + let mut result = String::new(); + for (k, list) in self.pyimports.iter() { + match list { + Some(list) => { + let list = list.iter().cloned().collect::>().join(", "); + uwriteln!(result, "from {k} import {list}"); + } + None => uwriteln!(result, "import {k}"), + } + } + + if !self.pyimports.is_empty() { + result.push_str("\n"); + } + + result + } +} + +#[cfg(test)] +mod test { + use std::collections::{BTreeMap, BTreeSet}; + + use super::PyImports; + + #[test] + fn test_pyimport_only_contents() { + let mut deps = PyImports::default(); + deps.pyimport("typing", None); + deps.pyimport("typing", None); + assert_eq!(deps.pyimports, BTreeMap::from([("typing".into(), None)])); + } + + #[test] + fn test_pyimport_only_module() { + let mut deps = PyImports::default(); + deps.pyimport("typing", "Union"); + deps.pyimport("typing", "List"); + deps.pyimport("typing", "NamedTuple"); + assert_eq!( + deps.pyimports, + BTreeMap::from([( + "typing".into(), + Some(BTreeSet::from([ + "Union".into(), + "List".into(), + "NamedTuple".into() + ])) + )]) + ); + } + + #[test] + #[should_panic] + fn test_pyimport_conflicting() { + let mut deps = PyImports::default(); + deps.pyimport("typing", "NamedTuple"); + deps.pyimport("typing", None); + } +} diff --git a/crates/gen-host-wasmtime-py/src/lib.rs b/crates/gen-host-wasmtime-py/src/lib.rs index 8210ff1f8..14e795cf1 100644 --- a/crates/gen-host-wasmtime-py/src/lib.rs +++ b/crates/gen-host-wasmtime-py/src/lib.rs @@ -1,100 +1,402 @@ +//! Code generator for the `wasmtime` PyPI package. +//! +//! This crate will generate bindings for a single component, like JS, for +//! Python source code. Component-model types are translated to Python and the +//! component is executed using the `wasmtime` PyPI package which is bindings to +//! the `wasmtime` C API which is built on the `wasmtime` Rust API. +//! +//! The generated structure of the bindings looks like follows: +//! +//! ```ignore +//! out_dir/ +//! __init__.py +//! types.py # types shared by all imports/exports +//! imports/ # only present if interfaces are imported +//! __init__.py # reexports `Foo` protocols for each interface +//! foo.py # types and definitions specific to interface `foo` +//! .. +//! exports/ # only present with exported interfaces +//! __init__.py # empty file +//! bar.py # contains `Bar` as the exported interface +//! .. +//! ``` +//! +//! The top-level `__init__.py` contains a `class Foo` where `Foo` is the name +//! fo the component. It contains top-level functions for all top-level exports +//! and exported instances are modeled as a method which returns a struct from +//! `exports/*.py`. + use heck::*; -use std::collections::{BTreeMap, BTreeSet, HashMap}; +use indexmap::IndexMap; +use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::fmt::Write; use std::mem; +use wasmtime_environ::component::{ + CanonicalOptions, Component, CoreDef, CoreExport, Export, ExportItem, GlobalInitializer, + InstantiateModule, LowerImport, RuntimeInstanceIndex, StaticModuleIndex, StringEncoding, +}; +use wasmtime_environ::{EntityIndex, ModuleTranslation, PrimaryMap}; +use wit_bindgen_core::component::ComponentGenerator; use wit_bindgen_core::wit_parser::abi::{ AbiVariant, Bindgen, Bitcast, Instruction, LiftLower, WasmType, }; -use wit_bindgen_core::{wit_parser::*, Direction, Files, Generator, Ns}; +use wit_bindgen_core::{ + uwrite, uwriteln, wit_parser::*, Files, InterfaceGenerator as _, Ns, WorldGenerator, +}; +use wit_component::ComponentInterfaces; -pub mod dependencies; -pub mod source; +mod imports; +mod source; -use dependencies::Dependencies; use source::Source; #[derive(Default)] -pub struct WasmtimePy { - src: Source, - in_import: bool, +struct WasmtimePy { opts: Opts, - guest_imports: HashMap, - guest_exports: HashMap, - sizes: SizeAlign, - /// Tracks the intrinsics and Python imports needed - deps: Dependencies, - /// Whether the Python Union being emited will wrap its cases with dataclasses - union_representation: HashMap, -} - -#[derive(Debug, Clone, Copy)] -enum PyUnionRepresentation { - /// A union whose inner types are used directly - Raw, - /// A union whose inner types have been wrapped in dataclasses - Wrapped, -} - -#[derive(Default)] -struct Imports { - freestanding_funcs: Vec, -} - -struct Import { - name: String, - src: Source, - wasm_ty: String, - pysig: String, -} -#[derive(Default)] -struct Exports { - freestanding_funcs: Vec, - fields: BTreeMap, -} - -struct Export { - python_type: &'static str, - name: String, + // `$out_dir/__init__.py` + init: Source, + // `$out_dir/types.py` + types: Source, + // `$out_dir/intrinsics.py` + intrinsics: Source, + // `$out_dir/imports/__init__.py` + imports_init: Source, + // `$out_dir/exports/$name.py` + exports: BTreeMap, + + /// Known imported interfaces to have as an argument to construction of the + /// main component. + imports: Vec, + + /// All intrinsics emitted to `self.intrinsics` so far. + all_intrinsics: BTreeSet<&'static str>, } #[derive(Default, Debug, Clone)] #[cfg_attr(feature = "clap", derive(clap::Args))] pub struct Opts { - #[cfg_attr(feature = "clap", arg(long = "no-typescript"))] - pub no_typescript: bool, + // ... } impl Opts { - pub fn build(self) -> WasmtimePy { - let mut r = WasmtimePy::new(); + pub fn build(self) -> Box { + let mut r = WasmtimePy::default(); r.opts = self; - r + Box::new(r) } } impl WasmtimePy { - pub fn new() -> WasmtimePy { - WasmtimePy::default() - } - - fn abi_variant(dir: Direction) -> AbiVariant { - // This generator uses a reversed mapping! In the Wasmtime-py host-side - // bindings, we don't use any extra adapter layer between guest wasm - // modules and the host. When the guest imports functions using the - // `GuestImport` ABI, the host directly implements the `GuestImport` - // ABI, even though the host is *exporting* functions. Similarly, when - // the guest exports functions using the `GuestExport` ABI, the host - // directly imports them with the `GuestExport` ABI, even though the - // host is *importing* functions. - match dir { - Direction::Import => AbiVariant::GuestExport, - Direction::Export => AbiVariant::GuestImport, + fn interface<'a>(&'a mut self, iface: &'a Interface, at_root: bool) -> InterfaceGenerator<'a> { + InterfaceGenerator { + gen: self, + iface, + src: Source::default(), + at_root, + self_module_path: "", + } + } + + fn print_result(&mut self) { + if !self.all_intrinsics.insert("result_type") { + return; + } + + self.types.pyimport("dataclasses", "dataclass"); + self.types.pyimport("typing", "TypeVar"); + self.types.pyimport("typing", "Generic"); + self.types.pyimport("typing", "Union"); + self.types.push_str( + " + T = TypeVar('T') + @dataclass + class Ok(Generic[T]): + value: T + E = TypeVar('E') + @dataclass + class Err(Generic[E]): + value: E + + Result = Union[Ok[T], Err[E]] + ", + ); + } + + fn print_intrinsic( + &mut self, + src: &mut Source, + name: &'static str, + gen: impl FnOnce(&str, &mut Source), + ) -> &'static str { + src.pyimport(".intrinsics", name); + if !self.all_intrinsics.insert(name) { + return name; + } + gen(name, &mut self.intrinsics); + return name; + } + + fn print_validate_guest_char(&mut self, src: &mut Source) -> &'static str { + self.print_intrinsic(src, "_validate_guest_char", |name, src| { + uwriteln!( + src, + " + def {name}(i: int) -> str: + if i > 0x10ffff or (i >= 0xd800 and i <= 0xdfff): + raise TypeError('not a valid char') + return chr(i) + ", + ); + }) + } + + fn print_i32_to_f32(&mut self, src: &mut Source) -> &'static str { + self.print_i32_to_f32_cvts(); + self.print_intrinsic(src, "_i32_to_f32", |name, src| { + uwriteln!( + src, + " + def {name}(i: int) -> float: + _i32_to_f32_i32[0] = i # type: ignore + return _i32_to_f32_f32[0] # type: ignore + ", + ); + }) + } + + fn print_f32_to_i32(&mut self, src: &mut Source) -> &'static str { + self.print_i32_to_f32_cvts(); + self.print_intrinsic(src, "_f32_to_i32", |name, src| { + uwriteln!( + src, + " + def {name}(i: float) -> int: + _i32_to_f32_f32[0] = i # type: ignore + return _i32_to_f32_i32[0] # type: ignore + ", + ); + }) + } + + fn print_i32_to_f32_cvts(&mut self) { + if !self.all_intrinsics.insert("i32_to_f32_cvts") { + return; + } + self.intrinsics.pyimport("ctypes", None); + self.intrinsics + .push_str("_i32_to_f32_i32 = ctypes.pointer(ctypes.c_int32(0))\n"); + self.intrinsics.push_str( + "_i32_to_f32_f32 = ctypes.cast(_i32_to_f32_i32, ctypes.POINTER(ctypes.c_float))\n", + ); + } + + fn print_i64_to_f64(&mut self, src: &mut Source) -> &'static str { + self.print_i64_to_f64_cvts(); + self.print_intrinsic(src, "_i64_to_f64", |name, src| { + uwriteln!( + src, + " + def {name}(i: int) -> float: + _i64_to_f64_i64[0] = i # type: ignore + return _i64_to_f64_f64[0] # type: ignore + ", + ); + }) + } + + fn print_f64_to_i64(&mut self, src: &mut Source) -> &'static str { + self.print_i64_to_f64_cvts(); + self.print_intrinsic(src, "_f64_to_i64", |name, src| { + uwriteln!( + src, + " + def {name}(i: float) -> int: + _i64_to_f64_f64[0] = i # type: ignore + return _i64_to_f64_i64[0] # type: ignore + ", + ); + }) + } + + fn print_i64_to_f64_cvts(&mut self) { + if !self.all_intrinsics.insert("i64_to_f64_cvts") { + return; } + self.intrinsics.pyimport("ctypes", None); + self.intrinsics + .push_str("_i64_to_f64_i64 = ctypes.pointer(ctypes.c_int64(0))\n"); + self.intrinsics.push_str( + "_i64_to_f64_f64 = ctypes.cast(_i64_to_f64_i64, ctypes.POINTER(ctypes.c_double))\n", + ); + } + + fn print_clamp(&mut self, src: &mut Source) -> &'static str { + self.print_intrinsic(src, "_clamp", |name, src| { + uwriteln!( + src, + " + def {name}(i: int, min: int, max: int) -> int: + if i < min or i > max: + raise OverflowError(f'must be between {{min}} and {{max}}') + return i + ", + ); + }) + } + + fn print_load(&mut self, src: &mut Source) -> &'static str { + self.print_intrinsic(src, "_load", |name, src| { + src.pyimport("wasmtime", None); + src.pyimport("ctypes", None); + src.pyimport("typing", "Any"); + uwriteln!( + src, + " + def {name}(ty: Any, mem: wasmtime.Memory, store: wasmtime.Storelike, base: int, offset: int) -> Any: + ptr = (base & 0xffffffff) + offset + if ptr + ctypes.sizeof(ty) > mem.data_len(store): + raise IndexError('out-of-bounds store') + raw_base = mem.data_ptr(store) + c_ptr = ctypes.POINTER(ty)( + ty.from_address(ctypes.addressof(raw_base.contents) + ptr) + ) + return c_ptr[0] + ", + ); + }) } - /// Creates a `Source` with all of the required intrinsics - fn intrinsics(&mut self, _iface: &Interface) -> Source { - self.deps.intrinsics() + fn print_store(&mut self, src: &mut Source) -> &'static str { + self.print_intrinsic(src, "_store", |name, src| { + src.pyimport("wasmtime", None); + src.pyimport("ctypes", None); + src.pyimport("typing", "Any"); + uwriteln!( + src, + " + def {name}(ty: Any, mem: wasmtime.Memory, store: wasmtime.Storelike, base: int, offset: int, val: Any) -> None: + ptr = (base & 0xffffffff) + offset + if ptr + ctypes.sizeof(ty) > mem.data_len(store): + raise IndexError('out-of-bounds store') + raw_base = mem.data_ptr(store) + c_ptr = ctypes.POINTER(ty)( + ty.from_address(ctypes.addressof(raw_base.contents) + ptr) + ) + c_ptr[0] = val + ", + ); + }) + } + + fn print_decode_utf8(&mut self, src: &mut Source) -> &'static str { + self.print_intrinsic(src, "_decode_utf8", |name, src| { + src.pyimport("wasmtime", None); + src.pyimport("ctypes", None); + src.pyimport("typing", "Tuple"); + uwriteln!( + src, + " + def {name}(mem: wasmtime.Memory, store: wasmtime.Storelike, ptr: int, len: int) -> str: + ptr = ptr & 0xffffffff + len = len & 0xffffffff + if ptr + len > mem.data_len(store): + raise IndexError('string out of bounds') + base = mem.data_ptr(store) + base = ctypes.POINTER(ctypes.c_ubyte)( + ctypes.c_ubyte.from_address(ctypes.addressof(base.contents) + ptr) + ) + return ctypes.string_at(base, len).decode('utf-8') + ", + ); + }) + } + + fn print_encode_utf8(&mut self, src: &mut Source) -> &'static str { + self.print_intrinsic(src, "_encode_utf8", |name, src| { + src.pyimport("wasmtime", None); + src.pyimport("ctypes", None); + src.pyimport("typing", "Tuple"); + uwriteln!( + src, + " + def {name}(val: str, realloc: wasmtime.Func, mem: wasmtime.Memory, store: wasmtime.Storelike) -> Tuple[int, int]: + bytes = val.encode('utf8') + ptr = realloc(store, 0, 0, 1, len(bytes)) + assert(isinstance(ptr, int)) + ptr = ptr & 0xffffffff + if ptr + len(bytes) > mem.data_len(store): + raise IndexError('string out of bounds') + base = mem.data_ptr(store) + base = ctypes.POINTER(ctypes.c_ubyte)( + ctypes.c_ubyte.from_address(ctypes.addressof(base.contents) + ptr) + ) + ctypes.memmove(base, bytes, len(bytes)) + return (ptr, len(bytes)) + ", + ); + }) + } + + fn print_canon_lift(&mut self, src: &mut Source) -> &'static str { + self.print_intrinsic(src, "_list_canon_lift", |name, src| { + src.pyimport("wasmtime", None); + src.pyimport("ctypes", None); + src.pyimport("typing", "List"); + src.pyimport("typing", "Any"); + // TODO: this is doing a native-endian read, not a little-endian + // read + uwriteln!( + src, + " + def {name}(ptr: int, len: int, size: int, ty: Any, mem: wasmtime.Memory ,store: wasmtime.Storelike) -> Any: + ptr = ptr & 0xffffffff + len = len & 0xffffffff + if ptr + len * size > mem.data_len(store): + raise IndexError('list out of bounds') + raw_base = mem.data_ptr(store) + base = ctypes.POINTER(ty)( + ty.from_address(ctypes.addressof(raw_base.contents) + ptr) + ) + if ty == ctypes.c_uint8: + return ctypes.string_at(base, len) + return base[:len] + ", + ); + }) + } + + fn print_canon_lower(&mut self, src: &mut Source) -> &'static str { + self.print_intrinsic(src, "_list_canon_lower", |name, src| { + src.pyimport("wasmtime", None); + src.pyimport("ctypes", None); + src.pyimport("typing", "Tuple"); + src.pyimport("typing", "List"); + src.pyimport("typing", "Any"); + // TODO: is there a faster way to memcpy other than iterating over + // the input list? + // TODO: this is doing a native-endian write, not a little-endian + // write + uwriteln!( + src, + " + def {name}(list: Any, ty: Any, size: int, align: int, realloc: wasmtime.Func, mem: wasmtime.Memory, store: wasmtime.Storelike) -> Tuple[int, int]: + total_size = size * len(list) + ptr = realloc(store, 0, 0, align, total_size) + assert(isinstance(ptr, int)) + ptr = ptr & 0xffffffff + if ptr + total_size > mem.data_len(store): + raise IndexError('list realloc return of bounds') + raw_base = mem.data_ptr(store) + base = ctypes.POINTER(ty)( + ty.from_address(ctypes.addressof(raw_base.contents) + ptr) + ) + for i, val in enumerate(list): + base[i] = val + return (ptr, len(list)) + ", + ); + }) } } @@ -120,616 +422,1023 @@ fn array_ty(iface: &Interface, ty: &Type) -> Option<&'static str> { } } -impl Generator for WasmtimePy { - fn preprocess_one(&mut self, iface: &Interface, dir: Direction) { - let variant = Self::abi_variant(dir); - self.sizes.fill(iface); - self.in_import = variant == AbiVariant::GuestImport; - } - - fn type_record( +impl ComponentGenerator for WasmtimePy { + fn instantiate( &mut self, - iface: &Interface, - _id: TypeId, name: &str, - record: &Record, - docs: &Docs, + component: &Component, + modules: &PrimaryMap>, + interfaces: &ComponentInterfaces, ) { - let mut builder = self.src.builder(&mut self.deps, iface); - builder.pyimport("dataclasses", "dataclass"); - builder.push_str("@dataclass\n"); - builder.push_str(&format!("class {}:\n", name.to_upper_camel_case())); - builder.indent(); - builder.docstring(docs); - for field in record.fields.iter() { - builder.comment(&field.docs); - let field_name = field.name.to_snake_case(); - builder.push_str(&format!("{field_name}: ")); - builder.print_ty(&field.ty, true); - builder.push_str("\n"); - } - if record.fields.is_empty() { - builder.push_str("pass\n"); + self.init.pyimport("wasmtime", None); + + let camel = name.to_upper_camel_case(); + let imports = if !component.import_types.is_empty() { + self.init + .pyimport(".imports", format!("{camel}Imports").as_str()); + format!(", import_object: {camel}Imports") + } else { + String::new() + }; + + uwriteln!(self.init, "class {camel}:"); + self.init.indent(); + + self.init.push_str("\n"); + + uwriteln!( + self.init, + "def __init__(self, store: wasmtime.Store{imports}):" + ); + self.init.indent(); + let mut i = Instantiator { + gen: self, + modules, + component, + interfaces, + instances: PrimaryMap::default(), + lifts: 0, + }; + for init in component.initializers.iter() { + i.global_initializer(init); } - builder.dedent(); - builder.push_str("\n"); - } + let (lifts, nested) = i.exports(&component.exports, interfaces.default.as_ref()); + i.gen.init.dedent(); - fn type_tuple( - &mut self, - iface: &Interface, - _id: TypeId, - name: &str, - tuple: &Tuple, - docs: &Docs, - ) { - let mut builder = self.src.builder(&mut self.deps, iface); - builder.comment(docs); - builder.push_str(&format!("{} = ", name.to_upper_camel_case())); - builder.print_tuple(tuple); - builder.push_str("\n"); + i.generate_lifts(&camel, None, &lifts); + for (name, lifts) in nested { + i.generate_lifts(&camel, Some(name), &lifts); + } + i.gen.init.dedent(); } - fn type_flags( - &mut self, - iface: &Interface, - _id: TypeId, - name: &str, - flags: &Flags, - docs: &Docs, - ) { - let mut builder = self.src.builder(&mut self.deps, iface); - builder.pyimport("enum", "Flag"); - builder.pyimport("enum", "auto"); - builder.push_str(&format!("class {}(Flag):\n", name.to_upper_camel_case())); - builder.indent(); - builder.docstring(docs); - for flag in flags.flags.iter() { - let flag_name = flag.name.to_shouty_snake_case(); - builder.comment(&flag.docs); - builder.push_str(&format!("{flag_name} = auto()\n")); + fn finish_component(&mut self, _name: &str, files: &mut Files) { + if !self.imports_init.is_empty() { + files.push("imports/__init__.py", self.imports_init.finish().as_bytes()); } - if flags.flags.is_empty() { - builder.push_str("pass\n"); + if !self.types.is_empty() { + files.push("types.py", self.types.finish().as_bytes()); + } + if !self.intrinsics.is_empty() { + files.push("intrinsics.py", self.intrinsics.finish().as_bytes()); } - builder.dedent(); - builder.push_str("\n"); - } - fn type_variant( - &mut self, - iface: &Interface, - _id: TypeId, - name: &str, - variant: &Variant, - docs: &Docs, - ) { - let mut builder = self.src.builder(&mut self.deps, iface); - builder.pyimport("dataclasses", "dataclass"); - let mut cases = Vec::new(); - for case in variant.cases.iter() { - builder.docstring(&case.docs); - builder.push_str("@dataclass\n"); - let case_name = format!( - "{}{}", - name.to_upper_camel_case(), - case.name.to_upper_camel_case() - ); - builder.push_str(&format!("class {case_name}:\n")); - builder.indent(); - match &case.ty { - Some(ty) => { - builder.push_str("value: "); - builder.print_ty(ty, true); - } - None => builder.push_str("pass"), - } - builder.push_str("\n"); - builder.dedent(); - builder.push_str("\n"); - cases.push(case_name); + for (name, src) in self.exports.iter() { + let snake = name.to_snake_case(); + files.push(&format!("exports/{snake}.py"), src.finish().as_bytes()); + } + if !self.exports.is_empty() { + files.push("exports/__init__.py", b""); } - builder.deps.pyimport("typing", "Union"); - builder.comment(docs); - builder.push_str(&format!( - "{} = Union[{}]\n", - name.to_upper_camel_case(), - cases.join(", "), - )); - builder.push_str("\n"); + files.push("__init__.py", self.init.finish().as_bytes()); } +} - /// Appends a Python definition for the provided Union to the current `Source`. - /// e.g. `MyUnion = Union[float, str, int]` - fn type_union( - &mut self, - iface: &Interface, - _id: TypeId, - name: &str, - union: &Union, - docs: &Docs, - ) { - let mut py_type_classes = BTreeSet::new(); - for case in union.cases.iter() { - py_type_classes.insert(py_type_class_of(&case.ty)); - } +struct Instantiator<'a> { + gen: &'a mut WasmtimePy, + modules: &'a PrimaryMap>, + instances: PrimaryMap, + interfaces: &'a ComponentInterfaces, + component: &'a Component, + lifts: usize, +} - let mut builder = self.src.builder(&mut self.deps, iface); - if py_type_classes.len() != union.cases.len() { - // Some of the cases are not distinguishable - self.union_representation - .insert(name.to_string(), PyUnionRepresentation::Wrapped); - builder.print_union_wrapped(name, union, docs); - } else { - // All of the cases are distinguishable - self.union_representation - .insert(name.to_string(), PyUnionRepresentation::Raw); - builder.print_union_raw(name, union, docs); - } - } +struct Lift<'a> { + callee: String, + opts: &'a CanonicalOptions, + iface: &'a Interface, + func: &'a Function, +} - fn type_option( - &mut self, - iface: &Interface, - _id: TypeId, - name: &str, - payload: &Type, - docs: &Docs, - ) { - let mut builder = self.src.builder(&mut self.deps, iface); - builder.pyimport("typing", "Optional"); - builder.comment(docs); - builder.push_str(&name.to_upper_camel_case()); - builder.push_str(" = Optional["); - builder.print_ty(payload, true); - builder.push_str("]\n\n"); - } +impl<'a> Instantiator<'a> { + fn global_initializer(&mut self, init: &GlobalInitializer) { + match init { + GlobalInitializer::InstantiateModule(m) => match m { + InstantiateModule::Static(idx, args) => self.instantiate_static_module(*idx, args), - fn type_result( - &mut self, - iface: &Interface, - _id: TypeId, - name: &str, - result: &Result_, - docs: &Docs, - ) { - self.deps.needs_result = true; - - let mut builder = self.src.builder(&mut self.deps, iface); - builder.comment(docs); - builder.push_str(&format!("{} = Result[", name.to_upper_camel_case())); - builder.print_optional_ty(result.ok.as_ref(), true); - builder.push_str(", "); - builder.print_optional_ty(result.err.as_ref(), true); - builder.push_str("]\n\n"); - } - - fn type_enum(&mut self, iface: &Interface, _id: TypeId, name: &str, enum_: &Enum, docs: &Docs) { - let mut builder = self.src.builder(&mut self.deps, iface); - builder.pyimport("enum", "Enum"); - builder.push_str(&format!("class {}(Enum):\n", name.to_upper_camel_case())); - builder.indent(); - builder.docstring(docs); - for (i, case) in enum_.cases.iter().enumerate() { - builder.comment(&case.docs); + // This is only needed when instantiating an imported core wasm + // module which while easy to implement here is not possible to + // test at this time so it's left unimplemented. + InstantiateModule::Import(..) => unimplemented!(), + }, - // TODO this handling of digits should be more general and - // shouldn't be here just to fix the one case in wasi where an - // enum variant is "2big" and doesn't generate valid Python. We - // should probably apply this to all generated Python - // identifiers. - let mut name = case.name.to_shouty_snake_case(); - if name.chars().next().unwrap().is_digit(10) { - name = format!("_{}", name); + GlobalInitializer::LowerImport(i) => self.lower_import(i), + + GlobalInitializer::ExtractMemory(m) => { + let def = self.core_export(&m.export); + let i = m.index.as_u32(); + uwriteln!(self.gen.init, "core_memory{i} = {def}"); + uwriteln!( + self.gen.init, + "assert(isinstance(core_memory{i}, wasmtime.Memory))" + ); + uwriteln!(self.gen.init, "self._core_memory{i} = core_memory{i}",); + } + GlobalInitializer::ExtractRealloc(r) => { + let def = self.core_def(&r.def); + let i = r.index.as_u32(); + uwriteln!(self.gen.init, "realloc{i} = {def}"); + uwriteln!( + self.gen.init, + "assert(isinstance(realloc{i}, wasmtime.Func))" + ); + uwriteln!(self.gen.init, "self._realloc{i} = realloc{i}",); + } + GlobalInitializer::ExtractPostReturn(p) => { + let def = self.core_def(&p.def); + let i = p.index.as_u32(); + uwriteln!(self.gen.init, "post_return{i} = {def}"); + uwriteln!( + self.gen.init, + "assert(isinstance(post_return{i}, wasmtime.Func))" + ); + uwriteln!(self.gen.init, "self._post_return{i} = post_return{i}",); } - builder.push_str(&format!("{} = {}\n", name, i)); - } - builder.dedent(); - builder.push_str("\n"); - } - fn type_alias(&mut self, iface: &Interface, _id: TypeId, name: &str, ty: &Type, docs: &Docs) { - let mut builder = self.src.builder(&mut self.deps, iface); - builder.comment(docs); - builder.push_str(&format!("{} = ", name.to_upper_camel_case())); - builder.print_ty(ty, false); - builder.push_str("\n"); + // This is only used for a "degenerate component" which internally + // has a function that always traps. While this should be trivial to + // implement (generate a JS function that always throws) there's no + // way to test this at this time so leave this unimplemented. + GlobalInitializer::AlwaysTrap(_) => unimplemented!(), + + // This is only used when the component exports core wasm modules, + // but that's not possible to test right now so leave these as + // unimplemented. + GlobalInitializer::SaveStaticModule(_) => unimplemented!(), + GlobalInitializer::SaveModuleImport(_) => unimplemented!(), + + // This is required when strings pass between components within a + // component and may change encodings. This is left unimplemented + // for now since it can't be tested and additionally JS doesn't + // support multi-memory which transcoders rely on anyway. + GlobalInitializer::Transcoder(_) => unimplemented!(), + } } - fn type_list(&mut self, iface: &Interface, _id: TypeId, name: &str, ty: &Type, docs: &Docs) { - let mut builder = self.src.builder(&mut self.deps, iface); - builder.comment(docs); - builder.push_str(&format!("{} = ", name.to_upper_camel_case())); - builder.print_list(ty); - builder.push_str("\n"); - } + fn instantiate_static_module(&mut self, idx: StaticModuleIndex, args: &[CoreDef]) { + let i = self.instances.push(idx); + self.gen.init.pyimport("os", None); - fn type_builtin(&mut self, iface: &Interface, id: TypeId, name: &str, ty: &Type, docs: &Docs) { - self.type_alias(iface, id, name, ty, docs); + uwriteln!( + self.gen.init, + "path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'module{}.wasm')", + idx.as_u32(), + ); + uwriteln!( + self.gen.init, + "module = wasmtime.Module.from_file(store.engine, path)" + ); + uwrite!( + self.gen.init, + "instance{} = wasmtime.Instance(store, module, [", + i.as_u32() + ); + if !args.is_empty() { + self.gen.init.push_str("\n"); + self.gen.init.indent(); + for arg in args { + let def = self.core_def(arg); + uwriteln!(self.gen.init, "{def},"); + } + self.gen.init.dedent(); + } + uwriteln!(self.gen.init, "]).exports(store)"); } - // As with `abi_variant` above, we're generating host-side bindings here - // so a user "export" uses the "guest import" ABI variant on the inside of - // this `Generator` implementation. - fn export(&mut self, iface: &Interface, func: &Function) { - let mut pysig = Source::default(); - let mut builder = pysig.builder(&mut self.deps, iface); - builder.print_sig(func, self.in_import); - let pysig = pysig.to_string(); - - let mut func_body = Source::default(); - let mut builder = func_body.builder(&mut self.deps, iface); + fn lower_import(&mut self, import: &LowerImport) { + // Determine the `Interface` that this import corresponds to. At this + // time `wit-component` only supports root-level imports of instances + // where instances export functions. + let (import_index, path) = &self.component.imports[import.import]; + let (import_name, _import_ty) = &self.component.import_types[*import_index]; + assert_eq!(path.len(), 1); + let iface = &self.interfaces.imports[import_name.as_str()]; + let func = iface.functions.iter().find(|f| f.name == path[0]).unwrap(); + + let index = import.index.as_u32(); + let callee = format!( + "import_object.{}.{}", + import_name.to_snake_case(), + func.name.to_snake_case() + ); + // Generate an inline function "closure" which will capture the + // `imports` argument provided to the constructor of this class and have + // the core wasm signature for this function. Using prior local + // variables the function here will perform all liftings/lowerings. + uwrite!( + self.gen.init, + "def lowering{index}_callee(caller: wasmtime.Caller" + ); let sig = iface.wasm_signature(AbiVariant::GuestImport, func); - builder.push_str(&format!( - "def {}(caller: wasmtime.Caller", - func.name.to_snake_case(), - )); let mut params = Vec::new(); - for (i, param) in sig.params.iter().enumerate() { - builder.push_str(", "); - let name = format!("arg{}", i); - builder.push_str(&name); - builder.push_str(": "); - builder.push_str(wasm_ty_typing(*param)); - params.push(name); + for (i, param_ty) in sig.params.iter().enumerate() { + self.gen.init.push_str(", "); + let param = format!("arg{i}"); + uwrite!(self.gen.init, "{param}: {}", wasm_ty_typing(*param_ty)); + params.push(param); } - builder.push_str(") -> "); + self.gen.init.push_str(") -> "); match sig.results.len() { - 0 => builder.push_str("None"), - 1 => builder.push_str(wasm_ty_typing(sig.results[0])), + 0 => self.gen.init.push_str("None"), + 1 => self.gen.init.push_str(wasm_ty_typing(sig.results[0])), _ => unimplemented!(), } - builder.push_str(":\n"); - builder.indent(); - drop(builder); + self.gen.init.push_str(":\n"); + self.gen.init.indent(); - let mut f = FunctionBindgen::new(self, params); - iface.call( - AbiVariant::GuestImport, - LiftLower::LiftArgsLowerResults, + let iface_snake = iface.name.to_snake_case(); + self.gen.init.pyimport(".imports", iface_snake.as_str()); + let self_module_path = format!("{iface_snake}."); + + self.bindgen( + params, + callee, + &import.options, + iface, func, - &mut f, + AbiVariant::GuestImport, + "self", + self_module_path, ); - - let FunctionBindgen { - src, - needs_memory, - needs_realloc, - mut locals, - .. - } = f; - - let mut builder = func_body.builder(&mut self.deps, iface); - if needs_memory { - // TODO: hardcoding "memory" - builder.push_str("m = caller[\"memory\"]\n"); - builder.push_str("assert(isinstance(m, wasmtime.Memory))\n"); - builder.deps.pyimport("typing", "cast"); - builder.push_str("memory = cast(wasmtime.Memory, m)\n"); - locals.insert("memory").unwrap(); + self.gen.init.dedent(); + + // Use the `wasmtime` package's embedder methods of creating a wasm + // function to finish the construction here. + uwrite!(self.gen.init, "lowering{index}_ty = wasmtime.FuncType(["); + for param in sig.params.iter() { + self.gen.init.push_str(wasm_ty_ctor(*param)); + self.gen.init.push_str(", "); } - - if let Some(name) = needs_realloc { - builder.push_str(&format!("realloc = caller[\"{}\"]\n", name)); - builder.push_str("assert(isinstance(realloc, wasmtime.Func))\n"); - locals.insert("realloc").unwrap(); + self.gen.init.push_str("], ["); + for param in sig.results.iter() { + self.gen.init.push_str(wasm_ty_ctor(*param)); + self.gen.init.push_str(", "); } - - builder.push_str(&src); - builder.dedent(); - - let mut wasm_ty = String::from("wasmtime.FuncType(["); - wasm_ty.push_str( - &sig.params - .iter() - .map(|t| wasm_ty_ctor(*t)) - .collect::>() - .join(", "), + self.gen.init.push_str("])\n"); + uwriteln!( + self.gen.init, + "lowering{index} = wasmtime.Func(store, lowering{index}_ty, lowering{index}_callee, access_caller = True)" ); - wasm_ty.push_str("], ["); - wasm_ty.push_str( - &sig.results - .iter() - .map(|t| wasm_ty_ctor(*t)) - .collect::>() - .join(", "), - ); - wasm_ty.push_str("])"); - let import = Import { - name: func.name.clone(), - src: func_body, - wasm_ty, - pysig, - }; - let imports = self - .guest_imports - .entry(iface.name.to_string()) - .or_insert(Imports::default()); - let dst = match &func.kind { - FunctionKind::Freestanding => &mut imports.freestanding_funcs, - }; - dst.push(import); } - // As with `abi_variant` above, we're generating host-side bindings here - // so a user "import" uses the "export" ABI variant on the inside of - // this `Generator` implementation. - fn import(&mut self, iface: &Interface, func: &Function) { - let mut func_body = Source::default(); - let mut builder = func_body.builder(&mut self.deps, iface); - - // Print the function signature - let params = builder.print_sig(func, self.in_import); - builder.push_str(":\n"); - builder.indent(); - drop(builder); + fn bindgen( + &mut self, + params: Vec, + callee: String, + opts: &CanonicalOptions, + iface: &Interface, + func: &Function, + abi: AbiVariant, + this: &str, + self_module_path: String, + ) { + // Technically it wouldn't be the hardest thing in the world to support + // other string encodings, but for now the code generator was originally + // written to support utf-8 so let's just leave it at that for now. In + // the future when it's easier to produce components with non-utf-8 this + // can be plumbed through to string lifting/lowering below. + assert_eq!(opts.string_encoding, StringEncoding::Utf8); + + let memory = match opts.memory { + Some(idx) => Some(format!("{this}._core_memory{}", idx.as_u32())), + None => None, + }; + let realloc = match opts.realloc { + Some(idx) => Some(format!("{this}._realloc{}", idx.as_u32())), + None => None, + }; + let post_return = match opts.post_return { + Some(idx) => Some(format!("{this}._post_return{}", idx.as_u32())), + None => None, + }; - // Use FunctionBindgen call - let src_object = match &func.kind { - FunctionKind::Freestanding => "self".to_string(), + let mut sizes = SizeAlign::default(); + sizes.fill(iface); + let mut locals = Ns::default(); + locals.insert("len").unwrap(); // python built-in + locals.insert("base").unwrap(); // may be used as loop var + locals.insert("i").unwrap(); // may be used as loop var + let mut f = FunctionBindgen { + locals, + payloads: Vec::new(), + sizes, + // Generate source directly onto `init` + src: mem::take(&mut self.gen.init), + gen: self.gen, + block_storage: Vec::new(), + blocks: Vec::new(), + callee, + memory, + realloc, + params, + post_return, + iface, + self_module_path, }; - let mut f = FunctionBindgen::new(self, params); - f.src_object = src_object; iface.call( - AbiVariant::GuestExport, - LiftLower::LowerArgsLiftResults, + abi, + match abi { + AbiVariant::GuestImport => LiftLower::LiftArgsLowerResults, + AbiVariant::GuestExport => LiftLower::LowerArgsLiftResults, + }, func, &mut f, ); - let FunctionBindgen { - src, - needs_memory, - needs_realloc, - src_object, - .. - } = f; - let mut builder = func_body.builder(&mut self.deps, iface); - if needs_memory { - // TODO: hardcoding "memory" - builder.push_str(&format!("memory = {}._memory;\n", src_object)); + + // Swap the printed source back into the destination of our `init`, and + // at this time `f.src` should be empty. + mem::swap(&mut f.src, &mut f.gen.init); + assert!(f.src.is_empty()); + } + + fn core_def(&self, def: &CoreDef) -> String { + match def { + CoreDef::Export(e) => self.core_export(e), + CoreDef::Lowered(i) => format!("lowering{}", i.as_u32()), + CoreDef::AlwaysTrap(_) => unimplemented!(), + CoreDef::InstanceFlags(_) => unimplemented!(), + CoreDef::Transcoder(_) => unimplemented!(), } + } - if let Some(name) = &needs_realloc { - builder.push_str(&format!( - "realloc = {}._{}\n", - src_object, - name.to_snake_case(), - )); + fn core_export(&self, export: &CoreExport) -> String + where + T: Into + Copy, + { + let name = match &export.item { + ExportItem::Index(idx) => { + let module = &self.modules[self.instances[export.instance]].module; + let idx = (*idx).into(); + module + .exports + .iter() + .filter_map(|(name, i)| if *i == idx { Some(name) } else { None }) + .next() + .unwrap() + } + ExportItem::Name(s) => s, + }; + let i = export.instance.as_u32() as usize; + format!("instance{i}[\"{name}\"]") + } + + /// Extract the `LiftedFunction` exports to a format that's easier to + /// process for this generator. For now all lifted functions are either + /// "root" lifted functions or one-level-nested for an exported interface. + /// + /// As worlds shape up and more of a component's structure is expressible in + /// `*.wit` this method will likely need to change. + fn exports( + &mut self, + exports: &'a IndexMap, + iface: Option<&'a Interface>, + ) -> (Vec>, BTreeMap<&'a str, Vec>>) { + let mut toplevel = Vec::new(); + let mut nested = BTreeMap::new(); + for (name, export) in exports { + let name = name.as_str(); + match export { + Export::LiftedFunction { + ty: _, + func, + options, + } => { + // For each lifted function the callee `wasmtime.Func` is + // saved into a per-instance field which is then referenced + // as the callee when the relevant function is invoked. + let def = self.core_def(func); + let callee = format!("lift_callee{}", self.lifts); + self.lifts += 1; + uwriteln!(self.gen.init, "{callee} = {def}"); + uwriteln!(self.gen.init, "assert(isinstance({callee}, wasmtime.Func))"); + uwriteln!(self.gen.init, "self.{callee} = {callee}"); + let iface = iface.unwrap(); + let func = iface.functions.iter().find(|f| f.name == *name).unwrap(); + toplevel.push(Lift { + callee, + opts: options, + iface, + func, + }); + } + + Export::Instance(exports) => { + let iface = &self.interfaces.exports[name]; + let (my_toplevel, my_nested) = self.exports(exports, Some(iface)); + // More than one level of nesting not supported at this + // time. + assert!(my_nested.is_empty()); + + let prev = nested.insert(name, my_toplevel); + assert!(prev.is_none()); + } + + // ignore type exports for now + Export::Type(_) => {} + + // This can't be tested at this time so leave it unimplemented + Export::Module(_) => unimplemented!(), + } } + (toplevel, nested) + } - builder.push_str(&src); - builder.dedent(); - - let exports = self - .guest_exports - .entry(iface.name.to_string()) - .or_insert_with(Exports::default); - if needs_memory { - exports.fields.insert( - "memory".to_string(), - Export { - python_type: "wasmtime.Memory", - name: "memory".to_owned(), - }, + fn generate_lifts(&mut self, camel_component: &str, ns: Option<&str>, lifts: &[Lift<'_>]) { + let mut this = "self".to_string(); + + // If these exports are going into a non-default interface then a new + // `class` is generated in the corresponding file which will be + // constructed with the "root" class. Generate the class here, its one + // field of the root class, and then an associated constructor for the + // root class to have. Finally the root class grows a method here as + // well to return the nested instance. + if let Some(ns) = ns { + let src = self.gen.exports.get_mut(ns).unwrap(); + let camel = ns.to_upper_camel_case(); + let snake = ns.to_snake_case(); + uwriteln!(src, "class {camel}:"); + src.indent(); + uwriteln!(src, "component: {camel_component}\n"); + uwriteln!( + src, + "def __init__(self, component: {camel_component}) -> None:" ); + src.indent(); + uwriteln!(src, "self.component = component"); + src.dedent(); + + this.push_str(".component"); + + uwriteln!(self.gen.init, "def {snake}(self) -> {camel}:"); + self.gen.init.indent(); + uwriteln!(self.gen.init, "return {camel}(self)"); + self.gen.init.dedent(); + + // Swap the two sources so the generation into `init` will go into + // the right place + mem::swap(&mut self.gen.init, src); } - if let Some(name) = &needs_realloc { - exports.fields.insert( - name.clone(), - Export { - python_type: "wasmtime.Func", - name: name.clone(), - }, + + for lift in lifts { + // Go through some small gymnastics to print the function signature + // here. + let mut src = mem::take(&mut self.gen.init); + let params = with_igen(&mut src, self.gen, lift.iface, ns.is_none(), "", |gen| { + gen.print_sig(lift.func, false) + }); + self.gen.init = src; + self.gen.init.push_str(":\n"); + + // Defer to `self.bindgen` for the body of the function. + self.gen.init.indent(); + self.gen.init.docstring(&lift.func.docs); + self.bindgen( + params, + format!("{this}.{}", lift.callee), + lift.opts, + lift.iface, + lift.func, + AbiVariant::GuestExport, + &this, + String::new(), ); + self.gen.init.dedent(); } - exports.fields.insert( - func.name.clone(), - Export { - python_type: "wasmtime.Func", - name: func.name.clone(), - }, - ); - let dst = match &func.kind { - FunctionKind::Freestanding => &mut exports.freestanding_funcs, - }; - dst.push(func_body); + // Undo the swap done above. + if let Some(ns) = ns { + self.gen.init.dedent(); + mem::swap(&mut self.gen.init, self.gen.exports.get_mut(ns).unwrap()); + } + } +} + +impl WorldGenerator for WasmtimePy { + fn import(&mut self, name: &str, iface: &Interface, files: &mut Files) { + let mut gen = self.interface(iface, false); + gen.types(); + + // Generate a "protocol" class which I'm led to believe is the rough + // equivalent of a Rust trait in Python for this imported interface. + // This will be referenced in the constructor for the main component. + let camel = name.to_upper_camel_case(); + let snake = name.to_snake_case(); + gen.src.pyimport("typing", "Protocol"); + uwriteln!(gen.src, "class {camel}(Protocol):"); + gen.src.indent(); + for func in iface.functions.iter() { + gen.src.pyimport("abc", "abstractmethod"); + gen.src.push_str("@abstractmethod\n"); + gen.print_sig(func, true); + gen.src.push_str(":\n"); + gen.src.indent(); + gen.src.push_str("raise NotImplementedError\n"); + gen.src.dedent(); + } + gen.src.dedent(); + gen.src.push_str("\n"); + + let src = gen.src.finish(); + files.push(&format!("imports/{snake}.py"), src.as_bytes()); + self.imports.push(name.to_string()); + } + + fn export(&mut self, name: &str, iface: &Interface, _files: &mut Files) { + let mut gen = self.interface(iface, false); + gen.types(); + + // Only generate types for exports and this will get finished later on + // as lifted functions need to be inserted into these files as they're + // discovered. + let src = gen.src; + self.exports.insert(name.to_string(), src); + } + + fn export_default(&mut self, _name: &str, iface: &Interface, _files: &mut Files) { + let mut gen = self.interface(iface, true); + + // Generate types and imports directly into `__init__.py` for the + // default export, and exported functions (lifted functions) will get + // generate later. + mem::swap(&mut gen.src, &mut gen.gen.init); + gen.types(); + mem::swap(&mut gen.src, &mut gen.gen.init); + } + + fn finish(&mut self, name: &str, _interfaces: &ComponentInterfaces, _files: &mut Files) { + if !self.imports.is_empty() { + let camel = name.to_upper_camel_case(); + self.imports_init.pyimport("dataclasses", "dataclass"); + uwriteln!(self.imports_init, "@dataclass"); + uwriteln!(self.imports_init, "class {camel}Imports:"); + self.imports_init.indent(); + for import in self.imports.iter() { + let snake = import.to_snake_case(); + let camel = import.to_upper_camel_case(); + self.imports_init + .pyimport(&format!(".{snake}"), camel.as_str()); + uwriteln!(self.imports_init, "{snake}: {camel}"); + } + self.imports_init.dedent(); + } } +} + +struct InterfaceGenerator<'a> { + src: Source, + gen: &'a mut WasmtimePy, + iface: &'a Interface, + at_root: bool, + self_module_path: &'a str, +} - fn finish_one(&mut self, iface: &Interface, files: &mut Files) { - self.deps.pyimport("typing", "Any"); - self.deps.pyimport("abc", "abstractmethod"); +#[derive(Debug, Clone, Copy)] +enum PyUnionRepresentation { + /// A union whose inner types are used directly + Raw, + /// A union whose inner types have been wrapped in dataclasses + Wrapped, +} - let types = mem::take(&mut self.src); - let intrinsics = self.intrinsics(iface); +impl InterfaceGenerator<'_> { + fn import_result_type(&mut self) { + self.gen.print_result(); + let path = if self.at_root { ".types" } else { "..types" }; + self.src.pyimport(path, "Result"); + } - for (k, v) in self.deps.pyimports.iter() { - match v { - Some(list) => { - let list = list.iter().cloned().collect::>().join(", "); - self.src.push_str(&format!("from {} import {}\n", k, list)); + fn print_ty(&mut self, ty: &Type, forward_ref: bool) { + match ty { + Type::Bool => self.src.push_str("bool"), + Type::U8 + | Type::S8 + | Type::U16 + | Type::S16 + | Type::U32 + | Type::S32 + | Type::U64 + | Type::S64 => self.src.push_str("int"), + Type::Float32 | Type::Float64 => self.src.push_str("float"), + Type::Char => self.src.push_str("str"), + Type::String => self.src.push_str("str"), + Type::Id(id) => { + let ty = &self.iface.types[*id]; + if let Some(name) = &ty.name { + self.src.push_str(self.self_module_path); + self.src.push_str(&name.to_upper_camel_case()); + return; } - None => { - self.src.push_str(&format!("import {}\n", k)); + match &ty.kind { + TypeDefKind::Type(t) => self.print_ty(t, forward_ref), + TypeDefKind::Tuple(t) => self.print_tuple(t), + TypeDefKind::Record(_) + | TypeDefKind::Flags(_) + | TypeDefKind::Enum(_) + | TypeDefKind::Variant(_) + | TypeDefKind::Union(_) => { + unreachable!() + } + TypeDefKind::Option(t) => { + self.src.pyimport("typing", "Optional"); + self.src.push_str("Optional["); + self.print_ty(t, true); + self.src.push_str("]"); + } + TypeDefKind::Result(r) => { + self.import_result_type(); + self.src.push_str("Result["); + self.print_optional_ty(r.ok.as_ref(), true); + self.src.push_str(", "); + self.print_optional_ty(r.err.as_ref(), true); + self.src.push_str("]"); + } + TypeDefKind::List(t) => self.print_list(t), + TypeDefKind::Future(t) => { + self.src.push_str("Future["); + self.print_optional_ty(t.as_ref(), true); + self.src.push_str("]"); + } + TypeDefKind::Stream(s) => { + self.src.push_str("Stream["); + self.print_optional_ty(s.element.as_ref(), true); + self.src.push_str(", "); + self.print_optional_ty(s.end.as_ref(), true); + self.src.push_str("]"); + } } } } - self.src.push_str("import wasmtime\n"); - self.src.push_str( - " - try: - from typing import Protocol - except ImportError: - class Protocol: # type: ignore - pass - ", - ); - self.src.push_str("\n"); + } - if self.deps.needs_t_typevar { - self.src.push_str("T = TypeVar('T')\n"); + fn print_optional_ty(&mut self, ty: Option<&Type>, forward_ref: bool) { + match ty { + Some(ty) => self.print_ty(ty, forward_ref), + None => self.src.push_str("None"), } + } - self.src.push_str(&intrinsics); - self.src.push_str(&types); - - for (module, funcs) in mem::take(&mut self.guest_imports) { - self.src.push_str(&format!( - "class {}(Protocol):\n", - module.to_upper_camel_case() - )); - self.src.indent(); - for func in funcs.freestanding_funcs.iter() { - self.src.push_str("@abstractmethod\n"); - self.src.push_str(&func.pysig); - self.src.push_str(":\n"); - self.src.indent(); - self.src.push_str("raise NotImplementedError\n"); - self.src.dedent(); + fn print_tuple(&mut self, tuple: &Tuple) { + if tuple.types.is_empty() { + return self.src.push_str("None"); + } + self.src.pyimport("typing", "Tuple"); + self.src.push_str("Tuple["); + for (i, t) in tuple.types.iter().enumerate() { + if i > 0 { + self.src.push_str(", "); } - self.src.dedent(); - self.src.push_str("\n"); - - self.src.push_str(&format!( - "def add_{}_to_linker(linker: wasmtime.Linker, store: wasmtime.Store, host: {}) -> None:\n", - module.to_snake_case(), - module.to_upper_camel_case(), - )); - self.src.indent(); + self.print_ty(t, true); + } + self.src.push_str("]"); + } - for func in funcs.freestanding_funcs.iter() { - self.src.push_str(&format!("ty = {}\n", func.wasm_ty)); - self.src.push_str(&func.src); - self.src.push_str(&format!( - "linker.define('{}', '{}', wasmtime.Func(store, ty, {}, access_caller = True))\n", - iface.name, - func.name, - func.name.to_snake_case(), - )); + fn print_list(&mut self, element: &Type) { + match element { + Type::U8 => self.src.push_str("bytes"), + t => { + self.src.pyimport("typing", "List"); + self.src.push_str("List["); + self.print_ty(t, true); + self.src.push_str("]"); } + } + } - self.src.dedent(); + fn print_sig(&mut self, func: &Function, in_import: bool) -> Vec { + self.src.push_str("def "); + self.src.push_str(&func.name.to_snake_case()); + if in_import { + self.src.push_str("(self"); + } else { + self.src.push_str("(self, caller: wasmtime.Store"); + } + let mut params = Vec::new(); + for (param, ty) in func.params.iter() { + self.src.push_str(", "); + self.src.push_str(¶m.to_snake_case()); + params.push(param.to_snake_case()); + self.src.push_str(": "); + self.print_ty(ty, true); + } + self.src.push_str(") -> "); + match func.results.len() { + 0 => self.src.push_str("None"), + 1 => self.print_ty(func.results.iter_types().next().unwrap(), true), + _ => { + self.src.pyimport("typing", "Tuple"); + self.src.push_str("Tuple["); + for (i, ty) in func.results.iter_types().enumerate() { + if i > 0 { + self.src.push_str(", "); + } + self.print_ty(ty, true); + } + self.src.push_str("]"); + } } + params + } - // This is exculsively here to get mypy to not complain about empty - // modules, this probably won't really get triggered much in practice - if !self.in_import && self.guest_exports.is_empty() { - self.src - .push_str(&format!("class {}:\n", iface.name.to_upper_camel_case())); + /// Print a wrapped union definition. + /// e.g. + /// ```py + /// @dataclass + /// class Foo0: + /// value: int + /// + /// @dataclass + /// class Foo1: + /// value: int + /// + /// Foo = Union[Foo0, Foo1] + /// ``` + pub fn print_union_wrapped(&mut self, name: &str, union: &Union, docs: &Docs) { + self.src.pyimport("dataclasses", "dataclass"); + let mut cases = Vec::new(); + let name = name.to_upper_camel_case(); + for (i, case) in union.cases.iter().enumerate() { + self.src.push_str("@dataclass\n"); + let name = format!("{name}{i}"); + self.src.push_str(&format!("class {name}:\n")); self.src.indent(); - self.src.push_str("pass\n"); + self.src.docstring(&case.docs); + self.src.push_str("value: "); + self.print_ty(&case.ty, true); + self.src.newline(); self.src.dedent(); + self.src.newline(); + cases.push(name); } - for (module, exports) in mem::take(&mut self.guest_exports) { - let module = module.to_upper_camel_case(); - self.src.push_str(&format!("class {}:\n", module)); - self.src.indent(); + self.src.pyimport("typing", "Union"); + self.src.comment(docs); + self.src + .push_str(&format!("{name} = Union[{}]\n", cases.join(", "))); + self.src.newline(); + } - self.src.push_str("instance: wasmtime.Instance\n"); - for (_name, export) in exports.fields.iter() { - self.src.push_str(&format!( - "_{}: {}\n", - export.name.to_snake_case(), - export.python_type - )); + pub fn print_union_raw(&mut self, name: &str, union: &Union, docs: &Docs) { + self.src.pyimport("typing", "Union"); + self.src.comment(docs); + for case in union.cases.iter() { + self.src.comment(&case.docs); + } + self.src.push_str(&name.to_upper_camel_case()); + self.src.push_str(" = Union["); + let mut first = true; + for case in union.cases.iter() { + if !first { + self.src.push_str(","); } + self.print_ty(&case.ty, true); + first = false; + } + self.src.push_str("]\n\n"); + } +} + +impl<'a> wit_bindgen_core::InterfaceGenerator<'a> for InterfaceGenerator<'a> { + fn iface(&self) -> &'a Interface { + self.iface + } - self.src.push_str("def __init__(self, store: wasmtime.Store, linker: wasmtime.Linker, module: wasmtime.Module):\n"); + fn type_record(&mut self, _id: TypeId, name: &str, record: &Record, docs: &Docs) { + self.src.pyimport("dataclasses", "dataclass"); + self.src.push_str("@dataclass\n"); + self.src + .push_str(&format!("class {}:\n", name.to_upper_camel_case())); + self.src.indent(); + self.src.docstring(docs); + for field in record.fields.iter() { + self.src.comment(&field.docs); + let field_name = field.name.to_snake_case(); + self.src.push_str(&format!("{field_name}: ")); + self.print_ty(&field.ty, true); + self.src.push_str("\n"); + } + if record.fields.is_empty() { + self.src.push_str("pass\n"); + } + self.src.dedent(); + self.src.push_str("\n"); + } + + fn type_tuple(&mut self, _id: TypeId, name: &str, tuple: &Tuple, docs: &Docs) { + self.src.comment(docs); + self.src + .push_str(&format!("{} = ", name.to_upper_camel_case())); + self.print_tuple(tuple); + self.src.push_str("\n"); + } + + fn type_flags(&mut self, _id: TypeId, name: &str, flags: &Flags, docs: &Docs) { + self.src.pyimport("enum", "Flag"); + self.src.pyimport("enum", "auto"); + self.src + .push_str(&format!("class {}(Flag):\n", name.to_upper_camel_case())); + self.src.indent(); + self.src.docstring(docs); + for flag in flags.flags.iter() { + let flag_name = flag.name.to_shouty_snake_case(); + self.src.comment(&flag.docs); + self.src.push_str(&format!("{flag_name} = auto()\n")); + } + if flags.flags.is_empty() { + self.src.push_str("pass\n"); + } + self.src.dedent(); + self.src.push_str("\n"); + } + + fn type_variant(&mut self, _id: TypeId, name: &str, variant: &Variant, docs: &Docs) { + self.src.pyimport("dataclasses", "dataclass"); + let mut cases = Vec::new(); + for case in variant.cases.iter() { + self.src.docstring(&case.docs); + self.src.push_str("@dataclass\n"); + let case_name = format!( + "{}{}", + name.to_upper_camel_case(), + case.name.to_upper_camel_case() + ); + self.src.push_str(&format!("class {case_name}:\n")); self.src.indent(); - self.src - .push_str("self.instance = linker.instantiate(store, module)\n"); - self.src - .push_str("exports = self.instance.exports(store)\n"); - for (name, export) in exports.fields.iter() { - self.src.push_str(&format!( - " - {snake} = exports['{name}'] - assert(isinstance({snake}, {ty})) - self._{snake} = {snake} - ", - name = name, - snake = export.name.to_snake_case(), - ty = export.python_type, - )); + match &case.ty { + Some(ty) => { + self.src.push_str("value: "); + self.print_ty(ty, true); + } + None => self.src.push_str("pass"), } + self.src.push_str("\n"); self.src.dedent(); + self.src.push_str("\n"); + cases.push(case_name); + } + + self.src.pyimport("typing", "Union"); + self.src.comment(docs); + self.src.push_str(&format!( + "{} = Union[{}]\n", + name.to_upper_camel_case(), + cases.join(", "), + )); + self.src.push_str("\n"); + } - for func in exports.freestanding_funcs.iter() { - self.src.push_str(&func); + /// Appends a Python definition for the provided Union to the current `Source`. + /// e.g. `MyUnion = Union[float, str, int]` + fn type_union(&mut self, _id: TypeId, name: &str, union: &Union, docs: &Docs) { + match classify_union(union.cases.iter().map(|t| t.ty)) { + PyUnionRepresentation::Wrapped => { + self.print_union_wrapped(name, union, docs); } + PyUnionRepresentation::Raw => { + self.print_union_raw(name, union, docs); + } + } + } - self.src.dedent(); + fn type_option(&mut self, _id: TypeId, name: &str, payload: &Type, docs: &Docs) { + self.src.pyimport("typing", "Optional"); + self.src.comment(docs); + self.src.push_str(&name.to_upper_camel_case()); + self.src.push_str(" = Optional["); + self.print_ty(payload, true); + self.src.push_str("]\n\n"); + } + + fn type_result(&mut self, _id: TypeId, name: &str, result: &Result_, docs: &Docs) { + self.import_result_type(); + + self.src.comment(docs); + self.src + .push_str(&format!("{} = Result[", name.to_upper_camel_case())); + self.print_optional_ty(result.ok.as_ref(), true); + self.src.push_str(", "); + self.print_optional_ty(result.err.as_ref(), true); + self.src.push_str("]\n\n"); + } + + fn type_enum(&mut self, _id: TypeId, name: &str, enum_: &Enum, docs: &Docs) { + self.src.pyimport("enum", "Enum"); + self.src + .push_str(&format!("class {}(Enum):\n", name.to_upper_camel_case())); + self.src.indent(); + self.src.docstring(docs); + for (i, case) in enum_.cases.iter().enumerate() { + self.src.comment(&case.docs); + + // TODO this handling of digits should be more general and + // shouldn't be here just to fix the one case in wasi where an + // enum variant is "2big" and doesn't generate valid Python. We + // should probably apply this to all generated Python + // identifiers. + let mut name = case.name.to_shouty_snake_case(); + if name.chars().next().unwrap().is_digit(10) { + name = format!("_{}", name); + } + self.src.push_str(&format!("{} = {}\n", name, i)); } + self.src.dedent(); + self.src.push_str("\n"); + } + + fn type_alias(&mut self, _id: TypeId, name: &str, ty: &Type, docs: &Docs) { + self.src.comment(docs); + self.src + .push_str(&format!("{} = ", name.to_upper_camel_case())); + self.print_ty(ty, false); + self.src.push_str("\n"); + } + + fn type_list(&mut self, _id: TypeId, name: &str, ty: &Type, docs: &Docs) { + self.src.comment(docs); + self.src + .push_str(&format!("{} = ", name.to_upper_camel_case())); + self.print_list(ty); + self.src.push_str("\n"); + } - files.push("bindings.py", self.src.as_bytes()); + fn type_builtin(&mut self, id: TypeId, name: &str, ty: &Type, docs: &Docs) { + self.type_alias(id, name, ty, docs); } } struct FunctionBindgen<'a> { gen: &'a mut WasmtimePy, + iface: &'a Interface, locals: Ns, src: Source, - block_storage: Vec, + block_storage: Vec, blocks: Vec<(String, Vec)>, - needs_memory: bool, - needs_realloc: Option, params: Vec, payloads: Vec, - src_object: String, + sizes: SizeAlign, + + memory: Option, + realloc: Option, + post_return: Option, + callee: String, + self_module_path: String, } impl FunctionBindgen<'_> { - fn new(gen: &mut WasmtimePy, params: Vec) -> FunctionBindgen<'_> { - let mut locals = Ns::default(); - locals.insert("len").unwrap(); // python built-in - locals.insert("base").unwrap(); // may be used as loop var - locals.insert("i").unwrap(); // may be used as loop var - for param in params.iter() { - locals.insert(param).unwrap(); - } - FunctionBindgen { - gen, - locals, - src: Source::default(), - block_storage: Vec::new(), - blocks: Vec::new(), - needs_memory: false, - needs_realloc: None, - params, - payloads: Vec::new(), - src_object: "self".to_string(), - } - } - fn clamp(&mut self, results: &mut Vec, operands: &[String], min: T, max: T) where T: std::fmt::Display, { - self.gen.deps.needs_clamp = true; - results.push(format!("_clamp({}, {}, {})", operands[0], min, max)); + let clamp = self.gen.print_clamp(&mut self.src); + results.push(format!("{clamp}({}, {min}, {max})", operands[0])); } fn load(&mut self, ty: &str, offset: i32, operands: &[String], results: &mut Vec) { - self.needs_memory = true; - self.gen.deps.needs_load = true; + let memory = self.memory.as_ref().unwrap(); + let load = self.gen.print_load(&mut self.src); let tmp = self.locals.tmp("load"); - self.src.push_str(&format!( - "{} = _load(ctypes.{}, memory, caller, {}, {})\n", - tmp, ty, operands[0], offset, - )); + self.src.pyimport("ctypes", None); + uwriteln!( + self.src, + "{tmp} = {load}(ctypes.{ty}, {memory}, caller, {}, {offset})", + operands[0], + ); results.push(tmp); } fn store(&mut self, ty: &str, offset: i32, operands: &[String]) { - self.needs_memory = true; - self.gen.deps.needs_store = true; - self.src.push_str(&format!( - "_store(ctypes.{}, memory, caller, {}, {}, {})\n", - ty, operands[1], offset, operands[0] - )); + let memory = self.memory.as_ref().unwrap(); + let store = self.gen.print_store(&mut self.src); + self.src.pyimport("ctypes", None); + uwriteln!( + self.src, + "{store}(ctypes.{ty}, {memory}, caller, {}, {offset}, {})", + operands[1], + operands[0] + ); + } + + fn print_ty(&mut self, ty: &Type) { + with_igen( + &mut self.src, + self.gen, + self.iface, + true, + &self.self_module_path, + |gen| gen.print_ty(ty, false), + ) + } + + fn print_list(&mut self, element: &Type) { + with_igen( + &mut self.src, + self.gen, + self.iface, + true, + &self.self_module_path, + |gen| gen.print_list(element), + ) } } @@ -737,18 +1446,17 @@ impl Bindgen for FunctionBindgen<'_> { type Operand = String; fn sizes(&self) -> &SizeAlign { - &self.gen.sizes + &self.sizes } fn push_block(&mut self) { - let prev = mem::take(&mut self.src); - self.block_storage.push(prev); + self.block_storage.push(self.src.take_body()); } fn finish_block(&mut self, operands: &mut Vec) { let to_restore = self.block_storage.pop().unwrap(); - let src = mem::replace(&mut self.src, to_restore); - self.blocks.push((src.into(), mem::take(operands))); + let src = self.src.replace_body(to_restore); + self.blocks.push((src, mem::take(operands))); } fn return_pointer(&mut self, _iface: &Interface, _size: usize, _align: usize) -> String { @@ -766,7 +1474,6 @@ impl Bindgen for FunctionBindgen<'_> { operands: &mut Vec, results: &mut Vec, ) { - let mut builder = self.src.builder(&mut self.gen.deps, iface); match inst { Instruction::GetArg { nth } => results.push(self.params[*nth].clone()), Instruction::I32Const { val } => results.push(val.to_string()), @@ -827,8 +1534,9 @@ impl Bindgen for FunctionBindgen<'_> { // Validate that i32 values coming from wasm are indeed valid code // points. Instruction::CharFromI32 => { - builder.deps.needs_validate_guest_char = true; - results.push(format!("_validate_guest_char({})", operands[0])); + self.src.pyimport(".intrinsics", "_validate_guest_char"); + let validate = self.gen.print_validate_guest_char(&mut self.src); + results.push(format!("{validate}({})", operands[0])); } Instruction::I32FromChar => { @@ -839,28 +1547,28 @@ impl Bindgen for FunctionBindgen<'_> { for (cast, op) in casts.iter().zip(operands) { match cast { Bitcast::I32ToF32 => { - builder.deps.needs_i32_to_f32 = true; - results.push(format!("_i32_to_f32({})", op)); + let cvt = self.gen.print_i32_to_f32(&mut self.src); + results.push(format!("{cvt}({})", op)); } Bitcast::F32ToI32 => { - builder.deps.needs_f32_to_i32 = true; - results.push(format!("_f32_to_i32({})", op)); + let cvt = self.gen.print_f32_to_i32(&mut self.src); + results.push(format!("{cvt}({})", op)); } Bitcast::I64ToF64 => { - builder.deps.needs_i64_to_f64 = true; - results.push(format!("_i64_to_f64({})", op)); + let cvt = self.gen.print_i64_to_f64(&mut self.src); + results.push(format!("{cvt}({})", op)); } Bitcast::F64ToI64 => { - builder.deps.needs_f64_to_i64 = true; - results.push(format!("_f64_to_i64({})", op)); + let cvt = self.gen.print_f64_to_i64(&mut self.src); + results.push(format!("{cvt}({})", op)); } Bitcast::I64ToF32 => { - builder.deps.needs_i32_to_f32 = true; - results.push(format!("_i32_to_f32(({}) & 0xffffffff)", op)); + let cvt = self.gen.print_i32_to_f32(&mut self.src); + results.push(format!("{cvt}(({}) & 0xffffffff)", op)); } Bitcast::F32ToI64 => { - builder.deps.needs_f32_to_i32 = true; - results.push(format!("_f32_to_i32({})", op)); + let cvt = self.gen.print_f32_to_i32(&mut self.src); + results.push(format!("{cvt}({})", op)); } Bitcast::I32ToI64 | Bitcast::I64ToI32 | Bitcast::None => { results.push(op.clone()) @@ -872,18 +1580,23 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::BoolFromI32 => { let op = self.locals.tmp("operand"); let ret = self.locals.tmp("boolean"); - builder.push_str(&format!( - " - {op} = {} - if {op} == 0: - {ret} = False - elif {op} == 1: - {ret} = True - else: - raise TypeError(\"invalid variant discriminant for bool\") - ", - operands[0] - )); + + uwriteln!(self.src, "{op} = {}", operands[0]); + uwriteln!(self.src, "if {op} == 0:"); + self.src.indent(); + uwriteln!(self.src, "{ret} = False"); + self.src.dedent(); + uwriteln!(self.src, "elif {op} == 1:"); + self.src.indent(); + uwriteln!(self.src, "{ret} = True"); + self.src.dedent(); + uwriteln!(self.src, "else:"); + self.src.indent(); + uwriteln!( + self.src, + "raise TypeError(\"invalid variant discriminant for bool\")" + ); + self.src.dedent(); results.push(ret); } Instruction::I32FromBool => { @@ -895,22 +1608,18 @@ impl Bindgen for FunctionBindgen<'_> { return; } let tmp = self.locals.tmp("record"); - builder.push_str(&format!("{} = {}\n", tmp, operands[0])); + uwriteln!(self.src, "{tmp} = {}", operands[0]); for field in record.fields.iter() { let name = self.locals.tmp("field"); - builder.push_str(&format!( - "{} = {}.{}\n", - name, - tmp, - field.name.to_snake_case(), - )); + uwriteln!(self.src, "{name} = {tmp}.{}", field.name.to_snake_case(),); results.push(name); } } Instruction::RecordLift { name, .. } => { results.push(format!( - "{}({})", + "{}{}({})", + self.self_module_path, name.to_upper_camel_case(), operands.join(", ") )); @@ -919,16 +1628,13 @@ impl Bindgen for FunctionBindgen<'_> { if tuple.types.is_empty() { return; } - builder.push_str("("); + self.src.push_str("("); for _ in 0..tuple.types.len() { let name = self.locals.tmp("tuplei"); - builder.push_str(&name); - builder.push_str(","); + uwrite!(self.src, "{name},"); results.push(name); } - builder.push_str(") = "); - builder.push_str(&operands[0]); - builder.push_str("\n"); + uwriteln!(self.src, ") = {}", operands[0]); } Instruction::TupleLift { .. } => { if operands.is_empty() { @@ -941,21 +1647,25 @@ impl Bindgen for FunctionBindgen<'_> { let operand = match operands.len() { 1 => operands[0].clone(), _ => { - let tmp = self.locals.tmp("flags"); - builder.push_str(&format!("{tmp} = 0\n")); + let tmp = self.locals.tmp("bits"); + uwriteln!(self.src, "{tmp} = 0"); for (i, op) in operands.iter().enumerate() { let i = 32 * i; - builder.push_str(&format!("{tmp} |= {op} << {i}\n")); + uwriteln!(self.src, "{tmp} |= {op} << {i}\n"); } tmp } }; - results.push(format!("{}({})", name.to_upper_camel_case(), operand)); + results.push(format!( + "{}{}({operand})", + self.self_module_path, + name.to_upper_camel_case() + )); } Instruction::FlagsLower { flags, .. } => match flags.repr().count() { 1 => results.push(format!("({}).value", operands[0])), n => { - let tmp = self.locals.tmp("flags"); + let tmp = self.locals.tmp("bits"); self.src .push_str(&format!("{tmp} = ({}).value\n", operands[0])); for i in 0..n { @@ -994,36 +1704,38 @@ impl Bindgen for FunctionBindgen<'_> { variant.cases.iter().zip(blocks).zip(payloads).enumerate() { if i == 0 { - builder.push_str("if "); + self.src.push_str("if "); } else { - builder.push_str("elif "); + self.src.push_str("elif "); } - builder.push_str(&format!( - "isinstance({}, {}{}):\n", + uwriteln!( + self.src, + "isinstance({}, {}{}{}):", operands[0], + self.self_module_path, name.to_upper_camel_case(), case.name.to_upper_camel_case() - )); - builder.indent(); + ); + self.src.indent(); if case.ty.is_some() { - builder.push_str(&format!("{} = {}.value\n", payload, operands[0])); + uwriteln!(self.src, "{payload} = {}.value", operands[0]); } - builder.push_str(&block); + self.src.push_str(&block); for (i, result) in block_results.iter().enumerate() { - builder.push_str(&format!("{} = {}\n", results[i], result)); + uwriteln!(self.src, "{} = {result}", results[i]); } - builder.dedent(); + self.src.dedent(); } let variant_name = name.to_upper_camel_case(); - builder.push_str("else:\n"); - builder.indent(); - builder.push_str(&format!( - "raise TypeError(\"invalid variant specified for {}\")\n", - variant_name - )); - builder.dedent(); + self.src.push_str("else:\n"); + self.src.indent(); + uwriteln!( + self.src, + "raise TypeError(\"invalid variant specified for {variant_name}\")", + ); + self.src.dedent(); } Instruction::VariantLift { @@ -1035,40 +1747,43 @@ impl Bindgen for FunctionBindgen<'_> { .collect::>(); let result = self.locals.tmp("variant"); - builder.print_var_declaration(&result, &Type::Id(*ty)); + uwrite!(self.src, "{result}: "); + self.print_ty(&Type::Id(*ty)); + self.src.push_str("\n"); for (i, (case, (block, block_results))) in variant.cases.iter().zip(blocks).enumerate() { if i == 0 { - builder.push_str("if "); + self.src.push_str("if "); } else { - builder.push_str("elif "); + self.src.push_str("elif "); } - builder.push_str(&format!("{} == {}:\n", operands[0], i)); - builder.indent(); - builder.push_str(&block); - - builder.push_str(&format!( - "{} = {}{}(", - result, + uwriteln!(self.src, "{} == {i}:", operands[0]); + self.src.indent(); + self.src.push_str(&block); + + uwrite!( + self.src, + "{result} = {}{}{}(", + self.self_module_path, name.to_upper_camel_case(), case.name.to_upper_camel_case() - )); + ); if block_results.len() > 0 { assert!(block_results.len() == 1); - builder.push_str(&block_results[0]); + self.src.push_str(&block_results[0]); } - builder.push_str(")\n"); - builder.dedent(); + self.src.push_str(")\n"); + self.src.dedent(); } - builder.push_str("else:\n"); - builder.indent(); + self.src.push_str("else:\n"); + self.src.indent(); let variant_name = name.to_upper_camel_case(); - builder.push_str(&format!( - "raise TypeError(\"invalid variant discriminant for {}\")\n", - variant_name - )); - builder.dedent(); + uwriteln!( + self.src, + "raise TypeError(\"invalid variant discriminant for {variant_name}\")", + ); + self.src.dedent(); results.push(result); } @@ -1091,53 +1806,47 @@ impl Bindgen for FunctionBindgen<'_> { results.push(self.locals.tmp("variant")); } - // Assumes that type_union has been called for this union - let union_representation = *self - .gen - .union_representation - .get(&name.to_string()) - .unwrap(); + let union_representation = classify_union(union.cases.iter().map(|c| c.ty)); let name = name.to_upper_camel_case(); let op0 = &operands[0]; for (i, ((case, (block, block_results)), payload)) in union.cases.iter().zip(blocks).zip(payloads).enumerate() { - builder.push_str(if i == 0 { "if " } else { "elif " }); - builder.push_str(&format!("isinstance({op0}, ")); + self.src.push_str(if i == 0 { "if " } else { "elif " }); + uwrite!(self.src, "isinstance({op0}, "); match union_representation { // Prints the Python type for this union case - PyUnionRepresentation::Raw => { - builder.print_ty(&case.ty, false); - } + PyUnionRepresentation::Raw => self.print_ty(&case.ty), // Prints the name of this union cases dataclass PyUnionRepresentation::Wrapped => { - builder.push_str(&format!("{name}{i}")); + uwrite!(self.src, "{}{name}{i}", self.self_module_path); } } - builder.push_str(&format!("):\n")); - builder.indent(); + uwriteln!(self.src, "):"); + self.src.indent(); match union_representation { // Uses the value directly PyUnionRepresentation::Raw => { - builder.push_str(&format!("{payload} = {op0}\n")) + uwriteln!(self.src, "{payload} = {op0}") } // Uses this union case dataclass's inner value PyUnionRepresentation::Wrapped => { - builder.push_str(&format!("{payload} = {op0}.value\n")) + uwriteln!(self.src, "{payload} = {op0}.value") } } - builder.push_str(&block); + self.src.push_str(&block); for (i, result) in block_results.iter().enumerate() { - builder.push_str(&format!("{} = {result}\n", results[i])); + uwriteln!(self.src, "{} = {result}", results[i]); } - builder.dedent(); + self.src.dedent(); } - builder.push_str("else:\n"); - builder.indent(); - builder.push_str(&format!( - "raise TypeError(\"invalid variant specified for {name}\")\n" - )); - builder.dedent(); + self.src.push_str("else:\n"); + self.src.indent(); + uwriteln!( + self.src, + "raise TypeError(\"invalid variant specified for {name}\")" + ); + self.src.dedent(); } Instruction::UnionLift { @@ -1149,42 +1858,44 @@ impl Bindgen for FunctionBindgen<'_> { .collect::>(); let result = self.locals.tmp("variant"); - builder.print_var_declaration(&result, &Type::Id(*ty)); - // Assumes that type_union has been called for this union - let union_representation = *self - .gen - .union_representation - .get(&name.to_string()) - .unwrap(); + uwrite!(self.src, "{result}: "); + self.print_ty(&Type::Id(*ty)); + self.src.push_str("\n"); + let union_representation = classify_union(union.cases.iter().map(|c| c.ty)); let name = name.to_upper_camel_case(); let op0 = &operands[0]; for (i, (_case, (block, block_results))) in union.cases.iter().zip(blocks).enumerate() { - builder.push_str(if i == 0 { "if " } else { "elif " }); - builder.push_str(&format!("{op0} == {i}:\n")); - builder.indent(); - builder.push_str(&block); + self.src.push_str(if i == 0 { "if " } else { "elif " }); + uwriteln!(self.src, "{op0} == {i}:"); + self.src.indent(); + self.src.push_str(&block); assert!(block_results.len() == 1); let block_result = &block_results[0]; - builder.push_str(&format!("{result} = ")); + uwrite!(self.src, "{result} = "); match union_representation { // Uses the passed value directly - PyUnionRepresentation::Raw => builder.push_str(block_result), + PyUnionRepresentation::Raw => self.src.push_str(block_result), // Constructs an instance of the union cases dataclass PyUnionRepresentation::Wrapped => { - builder.push_str(&format!("{name}{i}({block_result})")) + uwrite!( + self.src, + "{}{name}{i}({block_result})", + self.self_module_path + ) } } - builder.newline(); - builder.dedent(); + self.src.newline(); + self.src.dedent(); } - builder.push_str("else:\n"); - builder.indent(); - builder.push_str(&format!( + self.src.push_str("else:\n"); + self.src.indent(); + uwriteln!( + self.src, "raise TypeError(\"invalid variant discriminant for {name}\")\n", - )); - builder.dedent(); + ); + self.src.dedent(); results.push(result); } @@ -1202,22 +1913,22 @@ impl Bindgen for FunctionBindgen<'_> { } let op0 = &operands[0]; - builder.push_str(&format!("if {op0} is None:\n")); + uwriteln!(self.src, "if {op0} is None:"); - builder.indent(); - builder.push_str(&none); + self.src.indent(); + self.src.push_str(&none); for (dst, result) in results.iter().zip(&none_results) { - builder.push_str(&format!("{dst} = {result}\n")); + uwriteln!(self.src, "{dst} = {result}"); } - builder.dedent(); - builder.push_str("else:\n"); - builder.indent(); - builder.push_str(&format!("{some_payload} = {op0}\n")); - builder.push_str(&some); + self.src.dedent(); + self.src.push_str("else:\n"); + self.src.indent(); + uwriteln!(self.src, "{some_payload} = {op0}"); + self.src.push_str(&some); for (dst, result) in results.iter().zip(&some_results) { - builder.push_str(&format!("{dst} = {result}\n")); + uwriteln!(self.src, "{dst} = {result}"); } - builder.dedent(); + self.src.dedent(); } Instruction::OptionLift { ty, .. } => { @@ -1228,24 +1939,27 @@ impl Bindgen for FunctionBindgen<'_> { let some_result = &some_results[0]; let result = self.locals.tmp("option"); - builder.print_var_declaration(&result, &Type::Id(*ty)); + uwrite!(self.src, "{result}: "); + self.print_ty(&Type::Id(*ty)); + self.src.push_str("\n"); let op0 = &operands[0]; - builder.push_str(&format!("if {op0} == 0:\n")); - builder.indent(); - builder.push_str(&none); - builder.push_str(&format!("{result} = None\n")); - builder.dedent(); - builder.push_str(&format!("elif {op0} == 1:\n")); - builder.indent(); - builder.push_str(&some); - builder.push_str(&format!("{result} = {some_result}\n")); - builder.dedent(); - - builder.push_str("else:\n"); - builder.indent(); - builder.push_str("raise TypeError(\"invalid variant discriminant for option\")\n"); - builder.dedent(); + uwriteln!(self.src, "if {op0} == 0:"); + self.src.indent(); + self.src.push_str(&none); + uwriteln!(self.src, "{result} = None"); + self.src.dedent(); + uwriteln!(self.src, "elif {op0} == 1:"); + self.src.indent(); + self.src.push_str(&some); + uwriteln!(self.src, "{result} = {some_result}"); + self.src.dedent(); + + self.src.push_str("else:\n"); + self.src.indent(); + self.src + .push_str("raise TypeError(\"invalid variant discriminant for option\")\n"); + self.src.dedent(); results.push(result); } @@ -1258,35 +1972,37 @@ impl Bindgen for FunctionBindgen<'_> { let (ok, ok_results) = self.blocks.pop().unwrap(); let err_payload = self.payloads.pop().unwrap(); let ok_payload = self.payloads.pop().unwrap(); + self.src.pyimport(".types", "Ok"); + self.src.pyimport(".types", "Err"); for _ in 0..result_types.len() { results.push(self.locals.tmp("variant")); } let op0 = &operands[0]; - builder.push_str(&format!("if isinstance({op0}, Ok):\n")); + uwriteln!(self.src, "if isinstance({op0}, Ok):"); - builder.indent(); - builder.push_str(&format!("{ok_payload} = {op0}.value\n")); - builder.push_str(&ok); + self.src.indent(); + uwriteln!(self.src, "{ok_payload} = {op0}.value"); + self.src.push_str(&ok); for (dst, result) in results.iter().zip(&ok_results) { - builder.push_str(&format!("{dst} = {result}\n")); + uwriteln!(self.src, "{dst} = {result}"); } - builder.dedent(); - builder.push_str(&format!("elif isinstance({op0}, Err):\n")); - builder.indent(); - builder.push_str(&format!("{err_payload} = {op0}.value\n")); - builder.push_str(&err); + self.src.dedent(); + uwriteln!(self.src, "elif isinstance({op0}, Err):"); + self.src.indent(); + uwriteln!(self.src, "{err_payload} = {op0}.value"); + self.src.push_str(&err); for (dst, result) in results.iter().zip(&err_results) { - builder.push_str(&format!("{dst} = {result}\n")); + uwriteln!(self.src, "{dst} = {result}"); } - builder.dedent(); - builder.push_str("else:\n"); - builder.indent(); - builder.push_str(&format!( + self.src.dedent(); + self.src.push_str("else:\n"); + self.src.indent(); + self.src.push_str(&format!( "raise TypeError(\"invalid variant specified for expected\")\n", )); - builder.dedent(); + self.src.dedent(); } Instruction::ResultLift { ty, .. } => { @@ -1296,26 +2012,31 @@ impl Bindgen for FunctionBindgen<'_> { let err_result = err_results.get(0).unwrap_or(&none); let ok_result = ok_results.get(0).unwrap_or(&none); + self.src.pyimport(".types", "Ok"); + self.src.pyimport(".types", "Err"); + let result = self.locals.tmp("expected"); - builder.print_var_declaration(&result, &Type::Id(*ty)); + uwrite!(self.src, "{result}: "); + self.print_ty(&Type::Id(*ty)); + self.src.push_str("\n"); let op0 = &operands[0]; - builder.push_str(&format!("if {op0} == 0:\n")); - builder.indent(); - builder.push_str(&ok); - builder.push_str(&format!("{result} = Ok({ok_result})\n")); - builder.dedent(); - builder.push_str(&format!("elif {op0} == 1:\n")); - builder.indent(); - builder.push_str(&err); - builder.push_str(&format!("{result} = Err({err_result})\n")); - builder.dedent(); - - builder.push_str("else:\n"); - builder.indent(); - builder + uwriteln!(self.src, "if {op0} == 0:"); + self.src.indent(); + self.src.push_str(&ok); + uwriteln!(self.src, "{result} = Ok({ok_result})"); + self.src.dedent(); + uwriteln!(self.src, "elif {op0} == 1:"); + self.src.indent(); + self.src.push_str(&err); + uwriteln!(self.src, "{result} = Err({err_result})"); + self.src.dedent(); + + self.src.push_str("else:\n"); + self.src.indent(); + self.src .push_str("raise TypeError(\"invalid variant discriminant for expected\")\n"); - builder.dedent(); + self.src.dedent(); results.push(result); } @@ -1323,120 +2044,115 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::EnumLower { .. } => results.push(format!("({}).value", operands[0])), Instruction::EnumLift { name, .. } => { - results.push(format!("{}({})", name.to_upper_camel_case(), operands[0])); + results.push(format!( + "{}{}({})", + self.self_module_path, + name.to_upper_camel_case(), + operands[0] + )); } - Instruction::ListCanonLower { element, realloc } => { - // Lowering only happens when we're passing lists into wasm, - // which forces us to always allocate, so this should always be - // `Some`. - let realloc = realloc.unwrap(); - self.needs_memory = true; - self.needs_realloc = Some(realloc.to_string()); + Instruction::ListCanonLower { element, .. } => { + let realloc = self.realloc.as_ref().unwrap(); + let memory = self.memory.as_ref().unwrap(); let ptr = self.locals.tmp("ptr"); let len = self.locals.tmp("len"); let array_ty = array_ty(iface, element).unwrap(); - builder.deps.needs_list_canon_lower = true; - let size = self.gen.sizes.size(element); - let align = self.gen.sizes.align(element); - builder.push_str(&format!( - "{}, {} = _list_canon_lower({}, ctypes.{}, {}, {}, realloc, memory, caller)\n", - ptr, len, operands[0], array_ty, size, align, - )); + let lower = self.gen.print_canon_lower(&mut self.src); + let size = self.sizes.size(element); + let align = self.sizes.align(element); + uwriteln!( + self.src, + "{ptr}, {len} = {lower}({}, ctypes.{array_ty}, {size}, {align}, {realloc}, {memory}, caller)", + operands[0], + ); results.push(ptr); results.push(len); } Instruction::ListCanonLift { element, .. } => { - self.needs_memory = true; + let memory = self.memory.as_ref().unwrap(); let ptr = self.locals.tmp("ptr"); let len = self.locals.tmp("len"); - builder.push_str(&format!("{} = {}\n", ptr, operands[0])); - builder.push_str(&format!("{} = {}\n", len, operands[1])); + uwriteln!(self.src, "{ptr} = {}", operands[0]); + uwriteln!(self.src, "{len} = {}", operands[1]); let array_ty = array_ty(iface, element).unwrap(); - builder.deps.needs_list_canon_lift = true; + let lift = self.gen.print_canon_lift(&mut self.src); + self.src.pyimport("ctypes", None); let lift = format!( - "_list_canon_lift({}, {}, {}, ctypes.{}, memory, caller)", - ptr, - len, - self.gen.sizes.size(element), - array_ty, + "{lift}({ptr}, {len}, {}, ctypes.{array_ty}, {memory}, caller)", + self.sizes.size(element), ); - builder.deps.pyimport("typing", "cast"); + self.src.pyimport("typing", "cast"); let list = self.locals.tmp("list"); - builder.push_str(&list); - builder.push_str(" = cast("); - builder.print_list(element); - builder.push_str(", "); - builder.push_str(&lift); - builder.push_str(")\n"); + uwrite!(self.src, "{list} = cast("); + self.print_list(element); + uwriteln!(self.src, ", {lift})"); results.push(list); } - Instruction::StringLower { realloc } => { - // Lowering only happens when we're passing strings into wasm, - // which forces us to always allocate, so this should always be - // `Some`. - let realloc = realloc.unwrap(); - self.needs_memory = true; - self.needs_realloc = Some(realloc.to_string()); + Instruction::StringLower { .. } => { + let realloc = self.realloc.as_ref().unwrap(); + let memory = self.memory.as_ref().unwrap(); let ptr = self.locals.tmp("ptr"); let len = self.locals.tmp("len"); - builder.deps.needs_encode_utf8 = true; - builder.push_str(&format!( - "{}, {} = _encode_utf8({}, realloc, memory, caller)\n", - ptr, len, operands[0], - )); + let encode = self.gen.print_encode_utf8(&mut self.src); + uwriteln!( + self.src, + "{ptr}, {len} = {encode}({}, {realloc}, {memory}, caller)", + operands[0], + ); results.push(ptr); results.push(len); } Instruction::StringLift => { - self.needs_memory = true; + let memory = self.memory.as_ref().unwrap(); let ptr = self.locals.tmp("ptr"); let len = self.locals.tmp("len"); - builder.push_str(&format!("{} = {}\n", ptr, operands[0])); - builder.push_str(&format!("{} = {}\n", len, operands[1])); - builder.deps.needs_decode_utf8 = true; - let result = format!("_decode_utf8(memory, caller, {}, {})", ptr, len); + uwriteln!(self.src, "{ptr} = {}", operands[0]); + uwriteln!(self.src, "{len} = {}", operands[1]); + let decode = self.gen.print_decode_utf8(&mut self.src); let list = self.locals.tmp("list"); - builder.push_str(&format!("{} = {}\n", list, result)); + uwriteln!( + self.src, + "{list} = {decode}({memory}, caller, {ptr}, {len})" + ); results.push(list); } - Instruction::ListLower { element, realloc } => { + Instruction::ListLower { element, .. } => { let base = self.payloads.pop().unwrap(); let e = self.payloads.pop().unwrap(); - let realloc = realloc.unwrap(); + let realloc = self.realloc.as_ref().unwrap(); let (body, body_results) = self.blocks.pop().unwrap(); assert!(body_results.is_empty()); let vec = self.locals.tmp("vec"); let result = self.locals.tmp("result"); let len = self.locals.tmp("len"); - self.needs_realloc = Some(realloc.to_string()); - let size = self.gen.sizes.size(element); - let align = self.gen.sizes.align(element); + let size = self.sizes.size(element); + let align = self.sizes.align(element); // first store our vec-to-lower in a temporary since we'll // reference it multiple times. - builder.push_str(&format!("{} = {}\n", vec, operands[0])); - builder.push_str(&format!("{} = len({})\n", len, vec)); + uwriteln!(self.src, "{vec} = {}", operands[0]); + uwriteln!(self.src, "{len} = len({vec})"); // ... then realloc space for the result in the guest module - builder.push_str(&format!( - "{} = realloc(caller, 0, 0, {}, {} * {})\n", - result, align, len, size, - )); - builder.push_str(&format!("assert(isinstance({}, int))\n", result)); + uwriteln!( + self.src, + "{result} = {realloc}(caller, 0, 0, {align}, {len} * {size})", + ); + uwriteln!(self.src, "assert(isinstance({result}, int))"); // ... then consume the vector and use the block to lower the // result. let i = self.locals.tmp("i"); - builder.push_str(&format!("for {} in range(0, {}):\n", i, len)); - builder.indent(); - builder.push_str(&format!("{} = {}[{}]\n", e, vec, i)); - builder.push_str(&format!("{} = {} + {} * {}\n", base, result, i, size)); - builder.push_str(&body); - builder.dedent(); + uwriteln!(self.src, "for {i} in range(0, {len}):"); + self.src.indent(); + uwriteln!(self.src, "{e} = {vec}[{i}]"); + uwriteln!(self.src, "{base} = {result} + {i} * {size}"); + self.src.push_str(&body); + self.src.dedent(); results.push(result); results.push(len); @@ -1445,24 +2161,26 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::ListLift { element, .. } => { let (body, body_results) = self.blocks.pop().unwrap(); let base = self.payloads.pop().unwrap(); - let size = self.gen.sizes.size(element); + let size = self.sizes.size(element); let ptr = self.locals.tmp("ptr"); let len = self.locals.tmp("len"); - builder.push_str(&format!("{} = {}\n", ptr, operands[0])); - builder.push_str(&format!("{} = {}\n", len, operands[1])); + uwriteln!(self.src, "{ptr} = {}", operands[0]); + uwriteln!(self.src, "{len} = {}", operands[1]); let result = self.locals.tmp("result"); - builder.push_str(&format!("{}: List[", result)); - builder.print_ty(element, true); - builder.push_str("] = []\n"); + uwrite!(self.src, "{result}: "); + self.print_list(element); + uwriteln!(self.src, " = []"); let i = self.locals.tmp("i"); - builder.push_str(&format!("for {} in range(0, {}):\n", i, len)); - builder.indent(); - builder.push_str(&format!("{} = {} + {} * {}\n", base, ptr, i, size)); - builder.push_str(&body); assert_eq!(body_results.len(), 1); - builder.push_str(&format!("{}.append({})\n", result, body_results[0])); - builder.dedent(); + let body_result0 = &body_results[0]; + + uwriteln!(self.src, "for {i} in range(0, {len}):"); + self.src.indent(); + uwriteln!(self.src, "{base} = {ptr} + {i} * {size}"); + self.src.push_str(&body); + uwriteln!(self.src, "{result}.append({body_result0})"); + self.src.dedent(); results.push(result); } @@ -1476,31 +2194,25 @@ impl Bindgen for FunctionBindgen<'_> { results.push(name.clone()); self.payloads.push(name); } - Instruction::CallWasm { - iface: _, - name, - sig, - } => { + Instruction::CallWasm { sig, .. } => { if sig.results.len() > 0 { for i in 0..sig.results.len() { if i > 0 { - builder.push_str(", "); + self.src.push_str(", "); } let ret = self.locals.tmp("ret"); - builder.push_str(&ret); + self.src.push_str(&ret); results.push(ret); } - builder.push_str(" = "); + self.src.push_str(" = "); } - builder.push_str(&self.src_object); - builder.push_str("._"); - builder.push_str(&name.to_snake_case()); - builder.push_str("(caller"); + self.src.push_str(&self.callee); + self.src.push_str("(caller"); if operands.len() > 0 { - builder.push_str(", "); + self.src.push_str(", "); } - builder.push_str(&operands.join(", ")); - builder.push_str(")\n"); + self.src.push_str(&operands.join(", ")); + self.src.push_str(")\n"); for (ty, name) in sig.results.iter().zip(results.iter()) { let ty = match ty { WasmType::I32 | WasmType::I64 => "int", @@ -1513,48 +2225,31 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::CallInterface { module: _, func } => { for i in 0..func.results.len() { if i > 0 { - builder.push_str(", "); + self.src.push_str(", "); } let result = self.locals.tmp("ret"); - builder.push_str(&result); + self.src.push_str(&result); results.push(result); } if func.results.len() > 0 { - builder.push_str(" = "); + self.src.push_str(" = "); } match &func.kind { FunctionKind::Freestanding => { - builder.push_str(&format!( - "host.{}({})", - func.name.to_snake_case(), - operands.join(", "), - )); + self.src + .push_str(&format!("{}({})", self.callee, operands.join(", "),)); } } - builder.push_str("\n"); + self.src.push_str("\n"); } - Instruction::Return { amt, func, .. } => { - if !self.gen.in_import && iface.guest_export_needs_post_return(func) { - let name = format!("cabi_post_{}", func.name); - let exports = self - .gen - .guest_exports - .entry(iface.name.to_string()) - .or_insert_with(Exports::default); - exports.fields.insert( - name.clone(), - Export { - python_type: "wasmtime.Func", - name: name.clone(), - }, - ); - let name = name.to_snake_case(); - builder.push_str(&format!("{}._{name}(caller, ret)\n", self.src_object)); + Instruction::Return { amt, .. } => { + if let Some(s) = &self.post_return { + self.src.push_str(&format!("{s}(caller, ret)\n")); } match amt { 0 => {} - 1 => builder.push_str(&format!("return {}\n", operands[0])), + 1 => self.src.push_str(&format!("return {}\n", operands[0])), _ => { self.src .push_str(&format!("return ({})\n", operands.join(", "))); @@ -1577,19 +2272,11 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::I32Store8 { offset } => self.store("c_uint8", *offset, operands), Instruction::I32Store16 { offset } => self.store("c_uint16", *offset, operands), - Instruction::Malloc { - realloc, - size, - align, - } => { - self.needs_realloc = Some(realloc.to_string()); + Instruction::Malloc { size, align, .. } => { + let realloc = self.realloc.as_ref().unwrap(); let ptr = self.locals.tmp("ptr"); - builder.push_str(&format!( - " - {ptr} = realloc(caller, 0, 0, {align}, {size}) - assert(isinstance({ptr}, int)) - ", - )); + uwriteln!(self.src, "{ptr} = {realloc}(caller, 0, 0, {align}, {size})"); + uwriteln!(self.src, "assert(isinstance({ptr}, int))"); results.push(ptr); } @@ -1598,29 +2285,37 @@ impl Bindgen for FunctionBindgen<'_> { } } -fn py_type_class_of(ty: &Type) -> PyTypeClass { - match ty { - Type::Bool - | Type::U8 - | Type::U16 - | Type::U32 - | Type::U64 - | Type::S8 - | Type::S16 - | Type::S32 - | Type::S64 => PyTypeClass::Int, - Type::Float32 | Type::Float64 => PyTypeClass::Float, - Type::Char | Type::String => PyTypeClass::Str, - Type::Id(_) => PyTypeClass::Custom, +fn classify_union(types: impl Iterator) -> PyUnionRepresentation { + #[derive(Debug, Hash, PartialEq, Eq)] + enum PyTypeClass { + Int, + Str, + Float, + Custom, } -} -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -enum PyTypeClass { - Int, - Str, - Float, - Custom, + let mut py_type_classes = HashSet::new(); + for ty in types { + let class = match ty { + Type::Bool + | Type::U8 + | Type::U16 + | Type::U32 + | Type::U64 + | Type::S8 + | Type::S16 + | Type::S32 + | Type::S64 => PyTypeClass::Int, + Type::Float32 | Type::Float64 => PyTypeClass::Float, + Type::Char | Type::String => PyTypeClass::Str, + Type::Id(_) => PyTypeClass::Custom, + }; + if !py_type_classes.insert(class) { + // Some of the cases are not distinguishable + return PyUnionRepresentation::Wrapped; + } + } + PyUnionRepresentation::Raw } fn wasm_ty_ctor(ty: WasmType) -> &'static str { @@ -1640,3 +2335,27 @@ fn wasm_ty_typing(ty: WasmType) -> &'static str { WasmType::F64 => "float", } } + +/// Creates a temporary `InterfaceGenerator` with the given parameters to get +/// access to the various `print_*` methods on it. +fn with_igen( + src: &mut Source, + gen: &mut WasmtimePy, + iface: &Interface, + at_root: bool, + self_module_path: &str, + f: impl FnOnce(&mut InterfaceGenerator<'_>) -> R, +) -> R { + // The `print_ty` method is on `InterfaceGenerator` so jerry-rig one of + // those "quickly" to defer to it. + let mut gen = InterfaceGenerator { + src: mem::take(src), + gen, + iface, + at_root, + self_module_path, + }; + let ret = f(&mut gen); + *src = gen.src; + ret +} diff --git a/crates/gen-host-wasmtime-py/src/source.rs b/crates/gen-host-wasmtime-py/src/source.rs index 00703f18c..a8ef64551 100644 --- a/crates/gen-host-wasmtime-py/src/source.rs +++ b/crates/gen-host-wasmtime-py/src/source.rs @@ -1,13 +1,19 @@ -use heck::*; +use crate::imports::PyImports; +use std::fmt::{self, Write}; +use std::mem; use wit_bindgen_core::wit_parser::*; -use crate::dependencies::Dependencies; - /// A [Source] represents some unit of Python code /// and keeps track of its indent. #[derive(Default)] pub struct Source { - s: String, + body: Body, + imports: PyImports, +} + +#[derive(Default)] +pub struct Body { + contents: String, indent: usize, } @@ -21,7 +27,7 @@ impl Source { let lines = src.lines().collect::>(); let mut trim = None; for (i, line) in lines.iter().enumerate() { - self.s.push_str(if lines.len() == 1 { + self.body.contents.push_str(if lines.len() == 1 { line } else { let trim = match trim { @@ -84,396 +90,82 @@ impl Source { /// Indent the source one level. pub fn indent(&mut self) { - self.indent += 4; - self.s.push_str(" "); + self.body.indent += 4; + self.body.contents.push_str(" "); } /// Unindent, or in Python terms "dedent", /// the source one level. pub fn dedent(&mut self) { - self.indent -= 4; - assert!(self.s.ends_with(" ")); - self.s.pop(); - self.s.pop(); - self.s.pop(); - self.s.pop(); + self.body.indent -= 4; + assert!(self.body.contents.ends_with(" ")); + self.body.contents.pop(); + self.body.contents.pop(); + self.body.contents.pop(); + self.body.contents.pop(); } /// Go to the next line and apply any indent. pub fn newline(&mut self) { - self.s.push_str("\n"); - for _ in 0..self.indent { - self.s.push_str(" "); + self.body.contents.push_str("\n"); + for _ in 0..self.body.indent { + self.body.contents.push_str(" "); } } -} - -impl std::ops::Deref for Source { - type Target = str; - fn deref(&self) -> &str { - &self.s - } -} - -impl From for String { - fn from(s: Source) -> String { - s.s - } -} - -/// [SourceBuilder] combines together a [Source] -/// with other contextual information and state. -/// -/// This allows you to generate code for the Source using -/// high-level tools that take care of updating dependencies -/// and retrieving interface details. -/// -/// You can create a [SourceBuilder] easily using a [Source] -/// ``` -/// # use wit_bindgen_gen_host_wasmtime_py::dependencies::Dependencies; -/// # use wit_bindgen_core::wit_parser::{Interface, Type}; -/// # use wit_bindgen_gen_host_wasmtime_py::source::Source; -/// # let mut deps = Dependencies::default(); -/// # let mut interface = Interface::default(); -/// # let iface = &interface; -/// let mut source = Source::default(); -/// let mut builder = source.builder(&mut deps, iface); -/// builder.print_ty(&Type::Bool, false); -/// ``` -pub struct SourceBuilder<'s, 'd, 'i> { - source: &'s mut Source, - pub deps: &'d mut Dependencies, - iface: &'i Interface, -} - -impl<'s, 'd, 'i> Source { - /// Create a [SourceBuilder] for the current source. - pub fn builder( - &'s mut self, - deps: &'d mut Dependencies, - iface: &'i Interface, - ) -> SourceBuilder<'s, 'd, 'i> { - SourceBuilder { - source: self, - deps, - iface, - } - } -} -impl<'s, 'd, 'i> SourceBuilder<'s, 'd, 'i> { - /// See [Dependencies::pyimport]. pub fn pyimport<'a>(&mut self, module: &str, name: impl Into>) { - self.deps.pyimport(module, name) - } - - /// Appends a type's Python representation to this `Source`. - /// Records any required intrinsics and imports in the `deps`. - /// Uses Python forward reference syntax (e.g. 'Foo') - /// on the root type only if `forward_ref` is true. - pub fn print_ty(&mut self, ty: &Type, forward_ref: bool) { - match ty { - Type::Bool => self.push_str("bool"), - Type::U8 - | Type::S8 - | Type::U16 - | Type::S16 - | Type::U32 - | Type::S32 - | Type::U64 - | Type::S64 => self.push_str("int"), - Type::Float32 | Type::Float64 => self.push_str("float"), - Type::Char => self.push_str("str"), - Type::String => self.push_str("str"), - Type::Id(id) => { - let ty = &self.iface.types[*id]; - if let Some(name) = &ty.name { - self.push_str(&name.to_upper_camel_case()); - return; - } - match &ty.kind { - TypeDefKind::Type(t) => self.print_ty(t, forward_ref), - TypeDefKind::Tuple(t) => self.print_tuple(t), - TypeDefKind::Record(_) - | TypeDefKind::Flags(_) - | TypeDefKind::Enum(_) - | TypeDefKind::Variant(_) - | TypeDefKind::Union(_) => { - unreachable!() - } - TypeDefKind::Option(t) => { - self.deps.pyimport("typing", "Optional"); - self.push_str("Optional["); - self.print_ty(t, true); - self.push_str("]"); - } - TypeDefKind::Result(r) => { - self.deps.needs_result = true; - self.push_str("Result["); - self.print_optional_ty(r.ok.as_ref(), true); - self.push_str(", "); - self.print_optional_ty(r.err.as_ref(), true); - self.push_str("]"); - } - TypeDefKind::List(t) => self.print_list(t), - TypeDefKind::Future(t) => { - self.push_str("Future["); - self.print_optional_ty(t.as_ref(), true); - self.push_str("]"); - } - TypeDefKind::Stream(s) => { - self.push_str("Stream["); - self.print_optional_ty(s.element.as_ref(), true); - self.push_str(", "); - self.print_optional_ty(s.end.as_ref(), true); - self.push_str("]"); - } - } - } - } - } - - /// Appends an optional type's Python representation to this `Source`, or - /// Python `None` if the type is `None`. - /// Records any required intrinsics and imports in the `deps`. - /// Uses Python forward reference syntax (e.g. 'Foo') - /// on the root type only if `forward_ref` is true. - pub fn print_optional_ty(&mut self, ty: Option<&Type>, forward_ref: bool) { - match ty { - Some(ty) => self.print_ty(ty, forward_ref), - None => self.push_str("None"), - } - } - - /// Appends a tuple type's Python representation to this `Source`. - /// Records any required intrinsics and imports in the `deps`. - /// Uses Python forward reference syntax (e.g. 'Foo') for named type parameters. - pub fn print_tuple(&mut self, tuple: &Tuple) { - if tuple.types.is_empty() { - return self.push_str("None"); - } - self.deps.pyimport("typing", "Tuple"); - self.push_str("Tuple["); - for (i, t) in tuple.types.iter().enumerate() { - if i > 0 { - self.push_str(", "); - } - self.print_ty(t, true); - } - self.push_str("]"); + self.imports.pyimport(module, name.into()) } - /// Appends (potentially multiple) return type's Python representation to this `Source`. - /// Records any required intrinsics and imports in the `deps`. - /// Uses Python forward reference syntax (e.g. 'Foo') for named type parameters. - pub fn print_return_types(&mut self, return_types: &Params) { - // Return None for empty wit return types. - if return_types.len() == 0 { - return self.push_str("None"); - } - // Return the type directly for a single wit return type (whether named or not). - if return_types.len() == 1 { - let ty = &return_types.iter().next().unwrap().1; - self.print_ty(ty, true); - } else { - // Return a tuple for multiple wit return types. - self.deps.pyimport("typing", "Tuple"); - self.push_str("Tuple["); - for (i, (_, t)) in return_types.iter().enumerate() { - if i > 0 { - self.push_str(", "); - } - self.print_ty(t, true); - } - self.push_str("]"); - } + pub fn finish(&self) -> String { + let mut ret = self.imports.finish(); + ret.push_str(&self.body.contents); + return ret; } - /// Appends a Python type representing a sequence of the `element` type to this `Source`. - /// If the element type is `Type::U8`, the result type is `bytes` otherwise it is a `List[T]` - /// Records any required intrinsics and imports in the `deps`. - /// Uses Python forward reference syntax (e.g. 'Foo') for named type parameters. - pub fn print_list(&mut self, element: &Type) { - match element { - Type::U8 => self.push_str("bytes"), - t => { - self.deps.pyimport("typing", "List"); - self.push_str("List["); - self.print_ty(t, true); - self.push_str("]"); - } - } + pub fn is_empty(&self) -> bool { + self.imports.is_empty() && self.body.contents.is_empty() } - /// Print variable declaration. - /// Brings name into scope and binds type to it. - pub fn print_var_declaration<'a>(&mut self, name: &'a str, ty: &Type) { - self.push_str(name); - self.push_str(": "); - self.print_ty(ty, true); - self.push_str("\n"); + pub fn take_body(&mut self) -> Body { + mem::take(&mut self.body) } - pub fn print_sig(&mut self, func: &Function, in_import: bool) -> Vec { - self.source.push_str("def "); - self.source.push_str(&func.name.to_snake_case()); - if in_import { - self.source.push_str("(self"); - } else { - self.source.push_str("(self, caller: wasmtime.Store"); - } - let mut params = Vec::new(); - for (param, ty) in func.params.iter() { - self.source.push_str(", "); - self.source.push_str(¶m.to_snake_case()); - params.push(param.to_snake_case()); - self.source.push_str(": "); - self.print_ty(ty, true); - } - self.source.push_str(") -> "); - match func.results.len() { - 0 => self.source.push_str("None"), - 1 => self.print_ty(func.results.iter_types().next().unwrap(), true), - _ => { - self.deps.pyimport("typing", "Tuple"); - self.push_str("Tuple["); - for (i, ty) in func.results.iter_types().enumerate() { - if i > 0 { - self.source.push_str(", "); - } - self.print_ty(ty, true); - } - self.source.push_str("]"); - } - } - params - } - - /// Print a wrapped union definition. - /// e.g. - /// ```py - /// @dataclass - /// class Foo0: - /// value: int - /// - /// @dataclass - /// class Foo1: - /// value: int - /// - /// Foo = Union[Foo0, Foo1] - /// ``` - pub fn print_union_wrapped(&mut self, name: &str, union: &Union, docs: &Docs) { - self.deps.pyimport("dataclasses", "dataclass"); - let mut cases = Vec::new(); - let name = name.to_upper_camel_case(); - for (i, case) in union.cases.iter().enumerate() { - self.source.push_str("@dataclass\n"); - let name = format!("{name}{i}"); - self.source.push_str(&format!("class {name}:\n")); - self.source.indent(); - self.source.docstring(&case.docs); - self.source.push_str("value: "); - self.print_ty(&case.ty, true); - self.source.newline(); - self.source.dedent(); - self.source.newline(); - cases.push(name); - } - - self.deps.pyimport("typing", "Union"); - self.source.comment(docs); - self.source - .push_str(&format!("{name} = Union[{}]\n", cases.join(", "))); - self.source.newline(); - } - - pub fn print_union_raw(&mut self, name: &str, union: &Union, docs: &Docs) { - self.deps.pyimport("typing", "Union"); - self.source.comment(docs); - for case in union.cases.iter() { - self.source.comment(&case.docs); - } - self.source.push_str(&name.to_upper_camel_case()); - self.source.push_str(" = Union["); - let mut first = true; - for case in union.cases.iter() { - if !first { - self.source.push_str(","); - } - self.print_ty(&case.ty, true); - first = false; - } - self.source.push_str("]\n\n"); - } -} - -impl<'s, 'd, 'i> std::ops::Deref for SourceBuilder<'s, 'd, 'i> { - type Target = Source; - fn deref(&self) -> &Source { - &self.source + pub fn replace_body(&mut self, body: Body) -> String { + mem::replace(&mut self.body, body).contents } } -impl<'s, 'd, 'i> std::ops::DerefMut for SourceBuilder<'s, 'd, 'i> { - fn deref_mut(&mut self) -> &mut Source { - &mut self.source +impl Write for Source { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.push_str(s); + Ok(()) } } #[cfg(test)] mod tests { - use std::collections::{BTreeMap, BTreeSet}; - use super::*; #[test] fn simple_append() { let mut s = Source::default(); s.push_str("x"); - assert_eq!(s.s, "x"); + assert_eq!(s.body.contents, "x"); s.push_str("y"); - assert_eq!(s.s, "xy"); + assert_eq!(s.body.contents, "xy"); s.push_str("z "); - assert_eq!(s.s, "xyz "); + assert_eq!(s.body.contents, "xyz "); s.push_str(" a "); - assert_eq!(s.s, "xyz a "); + assert_eq!(s.body.contents, "xyz a "); s.push_str("\na"); - assert_eq!(s.s, "xyz a \na"); + assert_eq!(s.body.contents, "xyz a \na"); } #[test] fn trim_ws() { let mut s = Source::default(); s.push_str("def foo():\n return 1\n"); - assert_eq!(s.s, "def foo():\n return 1\n"); - } - - #[test] - fn print_list_bytes() { - // If the element type is u8, it is interpreted as `bytes` - let mut deps = Dependencies::default(); - let iface = Interface::default(); - let mut source = Source::default(); - let mut builder = source.builder(&mut deps, &iface); - builder.print_list(&Type::U8); - drop(builder); - assert_eq!(source.s, "bytes"); - assert_eq!(deps.pyimports, BTreeMap::default()); - } - - #[test] - fn print_list_non_bytes() { - // If the element type is u8, it is interpreted as `bytes` - let mut deps = Dependencies::default(); - let iface = Interface::default(); - let mut source = Source::default(); - let mut builder = source.builder(&mut deps, &iface); - builder.print_list(&Type::Float32); - drop(builder); - assert_eq!(source.s, "List[float]"); - assert_eq!( - deps.pyimports, - BTreeMap::from([("typing".into(), Some(BTreeSet::from(["List".into()])))]) - ); + assert_eq!(s.body.contents, "def foo():\n return 1\n"); } } diff --git a/crates/gen-host-wasmtime-py/tests/codegen.rs b/crates/gen-host-wasmtime-py/tests/codegen.rs index 5b84b3ecf..35425ae44 100644 --- a/crates/gen-host-wasmtime-py/tests/codegen.rs +++ b/crates/gen-host-wasmtime-py/tests/codegen.rs @@ -5,16 +5,20 @@ macro_rules! gen_test { ($name:ident $test:tt $dir:ident) => { #[test] fn $name() { - test_helpers::run_codegen_test( + drop(include_str!($test)); + test_helpers::run_component_codegen_test( "wasmtime-py", - std::path::Path::new($test) - .file_stem() - .unwrap() - .to_str() - .unwrap(), - include_str!($test), + $test.as_ref(), test_helpers::Direction::$dir, - wit_bindgen_gen_host_wasmtime_py::Opts::default().build(), + |name, component, files| { + wit_bindgen_core::component::generate( + &mut *wit_bindgen_gen_host_wasmtime_py::Opts::default().build(), + name, + component, + files, + ) + .unwrap() + }, super::verify, ) } @@ -38,7 +42,7 @@ mod imports { fn verify(dir: &Path, _name: &str) { test_helpers::run_command( Command::new("mypy") - .arg(dir.join("bindings.py")) + .arg(dir) .arg("--config-file") .arg("mypy.ini"), ); diff --git a/crates/gen-host-wasmtime-py/tests/helpers.py b/crates/gen-host-wasmtime-py/tests/helpers.py new file mode 100644 index 000000000..72de98f6e --- /dev/null +++ b/crates/gen-host-wasmtime-py/tests/helpers.py @@ -0,0 +1,8 @@ +import sys + +class TestWasi: + def log(self, list: bytes) -> None: + sys.stdout.buffer.write(list) + + def log_err(self, list: bytes) -> None: + sys.stderr.buffer.write(list) diff --git a/crates/gen-host-wasmtime-py/tests/runtime.rs b/crates/gen-host-wasmtime-py/tests/runtime.rs index a3bd1fb13..f5d6b6ff8 100644 --- a/crates/gen-host-wasmtime-py/tests/runtime.rs +++ b/crates/gen-host-wasmtime-py/tests/runtime.rs @@ -1,71 +1,42 @@ -use std::fs; use std::path::Path; use std::process::Command; -use wit_bindgen_core::Generator; -test_helpers::runtime_tests!("py"); +test_helpers::runtime_component_tests!("py"); -fn execute(name: &str, wasm: &Path, py: &Path, imports: &Path, exports: &Path) { - let dir = test_helpers::test_directory("runtime", "wasmtime-py", name); - fs::create_dir_all(&dir.join("imports")).unwrap(); - fs::create_dir_all(&dir.join("exports")).unwrap(); +fn execute(name: &str, lang: &str, wasm: &Path, py: &Path) { + let dir = test_helpers::test_directory("runtime", "wasmtime-py", &format!("{lang}/{name}")); + let wasm = std::fs::read(wasm).unwrap(); println!("OUT_DIR = {:?}", dir); println!("Generating bindings..."); - // We call `generate_all` with exports from the imports.wit file, and - // imports from the exports.wit wit file. It's reversed because we're - // implementing the host side of these APIs. - let iface = wit_bindgen_core::wit_parser::Interface::parse_file(imports).unwrap(); let mut files = Default::default(); - wit_bindgen_gen_host_wasmtime_py::Opts::default() - .build() - .generate_all(&[], &[iface], &mut files); + wit_bindgen_core::component::generate( + &mut *wit_bindgen_gen_host_wasmtime_py::Opts::default().build(), + name, + &wasm, + &mut files, + ) + .unwrap(); for (file, contents) in files.iter() { - fs::write(dir.join("imports").join(file), contents).unwrap(); + let dst = dir.join(file); + std::fs::create_dir_all(dst.parent().unwrap()).unwrap(); + std::fs::write(&dst, contents).unwrap(); } - fs::write(dir.join("imports").join("__init__.py"), "").unwrap(); - - let iface = wit_bindgen_core::wit_parser::Interface::parse_file(exports).unwrap(); - let mut files = Default::default(); - wit_bindgen_gen_host_wasmtime_py::Opts::default() - .build() - .generate_all(&[iface], &[], &mut files); - for (file, contents) in files.iter() { - fs::write(dir.join("exports").join(file), contents).unwrap(); - } - fs::write(dir.join("exports").join("__init__.py"), "").unwrap(); + let cwd = std::env::current_dir().unwrap(); println!("Running mypy..."); - exec( + let pathdir = std::env::join_paths([ + dir.parent().unwrap().to_str().unwrap(), + cwd.join("tests").to_str().unwrap(), + ]) + .unwrap(); + test_helpers::run_command( Command::new("mypy") - .env("MYPYPATH", &dir) + .env("MYPYPATH", &pathdir) .arg(py) .arg("--cache-dir") .arg(dir.parent().unwrap().join("mypycache").join(name)), ); - exec( - Command::new("python3") - .env("PYTHONPATH", &dir) - .arg(py) - .arg(wasm), - ); -} - -fn exec(cmd: &mut Command) { - println!("{:?}", cmd); - let output = cmd.output().unwrap(); - if output.status.success() { - return; - } - println!("status: {}", output.status); - println!( - "stdout ---\n {}", - String::from_utf8_lossy(&output.stdout).replace("\n", "\n ") - ); - println!( - "stderr ---\n {}", - String::from_utf8_lossy(&output.stderr).replace("\n", "\n ") - ); - panic!("no success"); + test_helpers::run_command(Command::new("python3").env("PYTHONPATH", &pathdir).arg(py)); } diff --git a/crates/gen-host-wasmtime-rust/tests/runtime.rs b/crates/gen-host-wasmtime-rust/tests/runtime.rs index 794f306a3..c7655f5a0 100644 --- a/crates/gen-host-wasmtime-rust/tests/runtime.rs +++ b/crates/gen-host-wasmtime-rust/tests/runtime.rs @@ -99,4 +99,11 @@ impl testwasi::Testwasi for TestWasi { } Ok(()) } + + fn log_err(&mut self, bytes: Vec) { + match std::str::from_utf8(&bytes) { + Ok(s) => eprint!("{}", s), + Err(_) => eprintln!("\nbinary: {:?}", bytes), + } + } } diff --git a/crates/test-helpers/macros/src/lib.rs b/crates/test-helpers/macros/src/lib.rs index f04766d68..4b3807df9 100644 --- a/crates/test-helpers/macros/src/lib.rs +++ b/crates/test-helpers/macros/src/lib.rs @@ -49,21 +49,10 @@ pub fn codegen_tests(input: TokenStream) -> TokenStream { (quote::quote!(#(#tests)*)).into() } -/// Invoked as `runtime_tests!("js")` to run a top-level `execute` function with -/// all host tests that use the "js" extension. -#[proc_macro] -pub fn runtime_tests(input: TokenStream) -> TokenStream { - generate_runtime_tests(input, false) -} - -/// Same as `runtime_tests!` but iterates over component wasms instead of core -/// wasms. +/// Invoked as `runtime_component_tests!("js")` to run a top-level `execute` +/// function with all host tests that use the "js" extension. #[proc_macro] pub fn runtime_component_tests(input: TokenStream) -> TokenStream { - generate_runtime_tests(input, true) -} - -fn generate_runtime_tests(input: TokenStream, use_components: bool) -> TokenStream { let host_extension = input.to_string(); let host_extension = host_extension.trim_matches('"'); let host_file = format!("host.{}", host_extension); @@ -75,41 +64,23 @@ fn generate_runtime_tests(input: TokenStream, use_components: bool) -> TokenStre continue; } let name_str = entry.file_name().unwrap().to_str().unwrap(); - for (lang, name, wasm, component) in WASMS { + for (lang, name, _wasm, component) in WASMS { if *name != name_str { continue; } let name = quote::format_ident!("{}_{}", name_str, lang); let host_file = entry.join(&host_file).to_str().unwrap().to_string(); - let import_wit = entry.join("imports.wit").to_str().unwrap().to_string(); - let export_wit = entry.join("exports.wit").to_str().unwrap().to_string(); - if use_components { - tests.push(quote::quote! { - #[test] - fn #name() { - crate::execute( - #name_str, - #lang, - #component.as_ref(), - #host_file.as_ref(), - ) - } - }); - } else { - let name_str = format!("{}_{}", name_str, lang); - tests.push(quote::quote! { - #[test] - fn #name() { - crate::execute( - #name_str, - #wasm.as_ref(), - #host_file.as_ref(), - #import_wit.as_ref(), - #export_wit.as_ref(), - ) - } - }); - } + tests.push(quote::quote! { + #[test] + fn #name() { + crate::execute( + #name_str, + #lang, + #component.as_ref(), + #host_file.as_ref(), + ) + } + }); } } diff --git a/crates/test-helpers/src/lib.rs b/crates/test-helpers/src/lib.rs index da9ba5c84..fefd0c84f 100644 --- a/crates/test-helpers/src/lib.rs +++ b/crates/test-helpers/src/lib.rs @@ -60,7 +60,10 @@ pub fn test_directory(suite_name: &str, gen_name: &str, wit_name: &str) -> PathB me.pop(); // chop off 'debug' / 'release' me.push(format!("{suite_name}-tests")); me.push(gen_name); - me.push(wit_name); + + // replace `-` with `_` for Python where the directory needs to be a valid + // Python package name. + me.push(wit_name.replace("-", "_")); drop(fs::remove_dir_all(&me)); fs::create_dir_all(&me).unwrap(); diff --git a/crates/wasi_snapshot_preview1/src/lib.rs b/crates/wasi_snapshot_preview1/src/lib.rs index 9b2681c26..ba4c5443b 100644 --- a/crates/wasi_snapshot_preview1/src/lib.rs +++ b/crates/wasi_snapshot_preview1/src/lib.rs @@ -83,7 +83,7 @@ pub extern "C" fn fd_write( mut iovs_len: usize, nwritten: *mut Size, ) -> Errno { - if fd != 1 { + if fd != 1 && fd != 2 { unreachable(); } unsafe { @@ -100,7 +100,12 @@ pub extern "C" fn fd_write( let ptr = (*iovs_ptr).buf; let len = (*iovs_ptr).buf_len; - testwasi::log(core::slice::from_raw_parts(ptr, len)); + let slice = core::slice::from_raw_parts(ptr, len); + if fd == 1 { + testwasi::log(slice); + } else { + testwasi::log_err(slice); + } *nwritten = len; } diff --git a/crates/wasi_snapshot_preview1/testwasi.wit b/crates/wasi_snapshot_preview1/testwasi.wit index 4c887abcc..9b7129602 100644 --- a/crates/wasi_snapshot_preview1/testwasi.wit +++ b/crates/wasi_snapshot_preview1/testwasi.wit @@ -1 +1,2 @@ log: func(bytes: list) +log-err: func(bytes: list) diff --git a/crates/wit-bindgen-demo/src/lib.rs b/crates/wit-bindgen-demo/src/lib.rs index 887834779..1d4593dbf 100644 --- a/crates/wit-bindgen-demo/src/lib.rs +++ b/crates/wit-bindgen-demo/src/lib.rs @@ -78,6 +78,13 @@ fn render(lang: demo::Lang, wit: &str, files: &mut Files, options: &demo::Option gen.generate("demo", &interfaces, files); }; + // This generator takes a component as input as opposed to an `Interface`. + // To work with this demo a dummy component is synthesized to generate + // bindings for. The dummy core wasm module is created from the + // `test_helpers` support this workspace already offsets, and then + // `wit-component` is used to synthesize a component from our input + // interface and dummy module. Finally this component is fed into the host + // generator which gives us the files we want. let gen_component = |mut gen: Box, files: &mut Files| { let (imports, interface) = if options.import { (vec![iface.clone()], None) @@ -110,10 +117,10 @@ fn render(lang: demo::Lang, wit: &str, files: &mut Files, options: &demo::Option opts.tracing = options.wasmtime_tracing; gen_world(opts.build(), files) } - demo::Lang::WasmtimePy => gen_world_legacy( - Box::new(wit_bindgen_gen_host_wasmtime_py::Opts::default().build()), + demo::Lang::WasmtimePy => gen_component( + wit_bindgen_gen_host_wasmtime_py::Opts::default().build(), files, - ), + )?, demo::Lang::C => gen_world_legacy( Box::new(wit_bindgen_gen_guest_c::Opts::default().build()), files, @@ -122,15 +129,6 @@ fn render(lang: demo::Lang, wit: &str, files: &mut Files, options: &demo::Option Box::new(wit_bindgen_gen_markdown::Opts::default().build()), files, ), - - // JS is different from other languages at this time where it takes a - // component as input as opposed to an `Interface`. To work with this - // demo a dummy component is synthesized to generate bindings for. The - // dummy core wasm module is created from the `test_helpers` support - // this workspace already offsets, and then `wit-component` is used to - // synthesize a component from our input interface and dummy module. - // Finally this component is fed into the host generator which gives us - // the files we want. demo::Lang::Js => gen_component(wit_bindgen_gen_host_js::Opts::default().build(), files)?, } diff --git a/src/bin/wit-bindgen.rs b/src/bin/wit-bindgen.rs index 012619618..34d9e952a 100644 --- a/src/bin/wit-bindgen.rs +++ b/src/bin/wit-bindgen.rs @@ -54,9 +54,7 @@ enum HostGenerator { #[clap(flatten)] opts: wit_bindgen_gen_host_wasmtime_py::Opts, #[clap(flatten)] - common: Common, - #[clap(flatten)] - world: LegacyWorld, + component: ComponentOpts, }, /// Generates bindings for JavaScript hosts. Js { @@ -193,9 +191,9 @@ impl Opt { | Category::Guest(GuestGenerator::C { common, .. }) | Category::Guest(GuestGenerator::TeavmJava { common, .. }) | Category::Host(HostGenerator::WasmtimeRust { common, .. }) - | Category::Host(HostGenerator::WasmtimePy { common, .. }) | Category::Markdown { common, .. } => common, - Category::Host(HostGenerator::Js { component, .. }) => &component.common, + Category::Host(HostGenerator::Js { component, .. }) + | Category::Host(HostGenerator::WasmtimePy { component, .. }) => &component.common, } } } @@ -209,8 +207,8 @@ fn main() -> Result<()> { Category::Host(HostGenerator::WasmtimeRust { opts, world, .. }) => { gen_world(opts.build(), world, &mut files)?; } - Category::Host(HostGenerator::WasmtimePy { opts, world, .. }) => { - gen_legacy_world(Box::new(opts.build()), world, &mut files)?; + Category::Host(HostGenerator::WasmtimePy { opts, component }) => { + gen_component(opts.build(), component, &mut files)?; } Category::Host(HostGenerator::Js { opts, component }) => { gen_component(opts.build(), component, &mut files)?; diff --git a/tests/codegen/unions.wit b/tests/codegen/unions.wit index c7d3422cc..f8e3207a9 100644 --- a/tests/codegen/unions.wit +++ b/tests/codegen/unions.wit @@ -54,4 +54,4 @@ union distinguishable-num { add-one-distinguishable-num: func(num: distinguishable-num) -> distinguishable-num // Returns the index of the case provided -identify-distinguishable-num: func(num: distinguishable-num) -> u8 \ No newline at end of file +identify-distinguishable-num: func(num: distinguishable-num) -> u8 diff --git a/tests/runtime/flavorful/host.py b/tests/runtime/flavorful/host.py index 355404caf..9134499c1 100644 --- a/tests/runtime/flavorful/host.py +++ b/tests/runtime/flavorful/host.py @@ -1,9 +1,9 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports from typing import Tuple, List -import exports.bindings as e -import imports.bindings as i -import sys +from helpers import TestWasi +from flavorful import Flavorful, FlavorfulImports +import flavorful as e +from flavorful.imports import imports as i +from flavorful.types import Result, Ok, Err import wasmtime class MyImports: @@ -23,7 +23,7 @@ def f_list_in_record4(self, a: i.ListInAlias) -> i.ListInAlias: def f_list_in_variant1(self, a: i.ListInVariant1V1, b: i.ListInVariant1V2, c: i.ListInVariant1V3) -> None: assert(a == 'foo') - assert(b == i.Err('bar')) + assert(b == Err('bar')) assert(c == 'baz') def f_list_in_variant2(self) -> i.ListInVariant2: @@ -33,37 +33,27 @@ def f_list_in_variant3(self, a: i.ListInVariant3) -> i.ListInVariant3: assert(a == 'input3') return 'output3' - def errno_result(self) -> i.Result[None, i.MyErrno]: - return i.Err(i.MyErrno.B) + def errno_result(self) -> Result[None, i.MyErrno]: + return Err(i.MyErrno.B) def list_typedefs(self, a: i.ListTypedef, c: i.ListTypedef3) -> Tuple[i.ListTypedef2, i.ListTypedef3]: assert(a == 'typedef1') assert(c == ['typedef2']) return (b'typedef3', ['typedef4']) - def list_of_variants(self, a: List[bool], b: List[i.Result[None, None]], c: List[i.MyErrno]) -> Tuple[List[bool], List[i.Result[None, None]], List[i.MyErrno]]: + def list_of_variants(self, a: List[bool], b: List[Result[None, None]], c: List[i.MyErrno]) -> Tuple[List[bool], List[Result[None, None]], List[i.MyErrno]]: assert(a == [True, False]) - assert(b == [i.Ok(None), i.Err(None)]) + assert(b == [Ok(None), Err(None)]) assert(c == [i.MyErrno.SUCCESS, i.MyErrno.A]) return ( [False, True], - [i.Err(None), i.Ok(None)], + [Err(None), Ok(None)], [i.MyErrno.A, i.MyErrno.B], ) -def run(wasm_file: str) -> None: +def run() -> None: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) - - imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = Flavorful(store, FlavorfulImports(MyImports(), TestWasi())) wasm.test_imports(store) wasm.f_list_in_record1(store, e.ListInRecord1("list_in_record1")) @@ -83,4 +73,4 @@ def run(wasm_file: str) -> None: assert(r2 == ['typedef4']) if __name__ == '__main__': - run(sys.argv[1]) + run() diff --git a/tests/runtime/invalid/host.py b/tests/runtime/invalid/host.py index 8709d9f93..d5477c30f 100644 --- a/tests/runtime/invalid/host.py +++ b/tests/runtime/invalid/host.py @@ -1,9 +1,8 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports from typing import Callable, List, Tuple -import imports.bindings as i -import sys import wasmtime +from invalid import Invalid, InvalidImports +from invalid.imports import Imports, imports as i +from helpers import TestWasi class MyImports(Imports): def roundtrip_u8(self, x: int) -> int: @@ -58,23 +57,13 @@ def unaligned10(self, x: List[bytes]) -> None: raise Exception('unreachable') -def new_wasm(wasm_file: str) -> Tuple[wasmtime.Store, Exports]: +def new_wasm() -> Tuple[wasmtime.Store, Invalid]: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) - - imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = Invalid(store, InvalidImports(MyImports(), TestWasi())) return (store, wasm) -def run(wasm_file: str) -> None: - (store, wasm) = new_wasm(wasm_file) +def run() -> None: + (store, wasm) = new_wasm() def assert_throws(f: Callable, msg: str) -> None: try: @@ -94,21 +83,21 @@ def assert_throws(f: Callable, msg: str) -> None: # FIXME(#376) these should succeed assert_throws(lambda: wasm.invalid_bool(store), 'discriminant for bool') - (store, wasm) = new_wasm(wasm_file) + (store, wasm) = new_wasm() assert_throws(lambda: wasm.invalid_u8(store), 'must be between') - (store, wasm) = new_wasm(wasm_file) + (store, wasm) = new_wasm() assert_throws(lambda: wasm.invalid_s8(store), 'must be between') - (store, wasm) = new_wasm(wasm_file) + (store, wasm) = new_wasm() assert_throws(lambda: wasm.invalid_u16(store), 'must be between') - (store, wasm) = new_wasm(wasm_file) + (store, wasm) = new_wasm() assert_throws(lambda: wasm.invalid_s16(store), 'must be between') - (store, wasm) = new_wasm(wasm_file) + (store, wasm) = new_wasm() assert_throws(lambda: wasm.invalid_char(store), 'not a valid char') - (store, wasm) = new_wasm(wasm_file) + (store, wasm) = new_wasm() assert_throws(lambda: wasm.invalid_enum(store), 'not a valid E') # FIXME(#370) should call `unalignedN` and expect an error if __name__ == '__main__': - run(sys.argv[1]) + run() diff --git a/tests/runtime/lists/host.py b/tests/runtime/lists/host.py index 0d465aaf6..8e5ea000a 100644 --- a/tests/runtime/lists/host.py +++ b/tests/runtime/lists/host.py @@ -1,8 +1,6 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports from typing import Tuple, List -import exports.bindings as e -import imports.bindings as i +from helpers import TestWasi +from lists import Lists, ListsImports import sys import wasmtime @@ -71,19 +69,9 @@ def list_minmax_float(self, a: List[float], b: List[float]) -> Tuple[List[float] assert(b == [-sys.float_info.max, sys.float_info.max, -float('inf'), float('inf')]) return (a, b) -def run(wasm_file: str) -> None: +def run() -> None: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) - - imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = Lists(store, ListsImports(MyImports(), TestWasi())) allocated_bytes = wasm.allocated_bytes(store) wasm.test_imports(store) @@ -108,4 +96,4 @@ def run(wasm_file: str) -> None: assert(allocated_bytes == wasm.allocated_bytes(store)) if __name__ == '__main__': - run(sys.argv[1]) + run() diff --git a/tests/runtime/many_arguments/host.py b/tests/runtime/many_arguments/host.py index 68c1ba12d..b5819f9b3 100644 --- a/tests/runtime/many_arguments/host.py +++ b/tests/runtime/many_arguments/host.py @@ -1,8 +1,6 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports -import math; -import sys import wasmtime +from many_arguments import ManyArguments, ManyArgumentsImports +from helpers import TestWasi class MyImports: def many_arguments(self, @@ -40,21 +38,11 @@ def many_arguments(self, assert(a16 == 16) -def run(wasm_file: str) -> None: +def run() -> None: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) - - imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = ManyArguments(store, ManyArgumentsImports(MyImports(), TestWasi())) wasm.many_arguments(store, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,14, 15, 16) if __name__ == '__main__': - run(sys.argv[1]) + run() diff --git a/tests/runtime/numbers/host.py b/tests/runtime/numbers/host.py index 283aaf091..228c97ff0 100644 --- a/tests/runtime/numbers/host.py +++ b/tests/runtime/numbers/host.py @@ -1,8 +1,7 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports -import math; -import sys +import math import wasmtime +from numbers import Numbers, NumbersImports +from helpers import TestWasi class MyImports: def roundtrip_u8(self, a: int) -> int: @@ -44,20 +43,9 @@ def set_scalar(self, a: int) -> None: def get_scalar(self) -> int: return self.scalar - -def run(wasm_file: str) -> None: +def run() -> None: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) - - imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = Numbers(store, NumbersImports(MyImports(), TestWasi())) wasm.test_imports(store) assert(wasm.roundtrip_u8(store, 1) == 1) @@ -103,4 +91,4 @@ def run(wasm_file: str) -> None: assert(wasm.get_scalar(store) == 4) if __name__ == '__main__': - run(sys.argv[1]) + run() diff --git a/tests/runtime/records/host.py b/tests/runtime/records/host.py index cce15b9a3..61198223f 100644 --- a/tests/runtime/records/host.py +++ b/tests/runtime/records/host.py @@ -1,9 +1,8 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports +from helpers import TestWasi +from records import Records, RecordsImports +from records.imports import imports as i from typing import Tuple -import exports.bindings as e -import imports.bindings as i -import sys +import records as e import wasmtime class MyImports: @@ -31,19 +30,9 @@ def tuple0(self, a: None) -> None: def tuple1(self, a: Tuple[int]) -> Tuple[int]: return (a[0],) -def run(wasm_file: str) -> None: +def run() -> None: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) - - imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = Records(store, RecordsImports(MyImports(), TestWasi())) wasm.test_imports(store) assert(wasm.multiple_results(store) == (100, 200)) @@ -69,4 +58,4 @@ def run(wasm_file: str) -> None: assert(wasm.tuple1(store, (1,)) == (1,)) if __name__ == '__main__': - run(sys.argv[1]) + run() diff --git a/tests/runtime/smoke/host.py b/tests/runtime/smoke/host.py index 38fa86c75..4f390775b 100644 --- a/tests/runtime/smoke/host.py +++ b/tests/runtime/smoke/host.py @@ -1,27 +1,19 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports -import sys +from smoke import Smoke, SmokeImports +from helpers import TestWasi import wasmtime class MyImports: def thunk(self) -> None: self.hit = True -def run(wasm_file: str) -> None: +def run() -> None: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = Smoke(store, SmokeImports(imports, TestWasi())) + wasm.thunk(store) assert(imports.hit) if __name__ == '__main__': - run(sys.argv[1]) + run() diff --git a/tests/runtime/unions/host.py b/tests/runtime/unions/host.py index d95a7a529..08c95ae6a 100644 --- a/tests/runtime/unions/host.py +++ b/tests/runtime/unions/host.py @@ -1,10 +1,9 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports from typing import Union -import exports.bindings as e -import imports.bindings as i -import sys import wasmtime +from helpers import TestWasi +from unions import Unions, UnionsImports +import unions as e +from unions.imports import imports as i class MyImports: # Simple uses of unions whose inner values all have the same Python representation @@ -151,20 +150,12 @@ def identify_distinguishable_num(self, num: i.DistinguishableNum) -> int: else: raise ValueError("Invalid input value!") -def run(wasm_file: str) -> None: +def run() -> None: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) - - imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = Unions(store, UnionsImports(MyImports(), TestWasi())) + # TODO: should get these tests working ideally but requires some + # bit-fiddling on the host. # wasm.test_imports(store) # All-Integers @@ -234,4 +225,4 @@ def run(wasm_file: str) -> None: assert wasm.identify_distinguishable_num(store, 1) == 1 if __name__ == '__main__': - run(sys.argv[1]) + run() diff --git a/tests/runtime/variants/host.py b/tests/runtime/variants/host.py index a00989dc3..59b064725 100644 --- a/tests/runtime/variants/host.py +++ b/tests/runtime/variants/host.py @@ -1,10 +1,10 @@ -from exports.bindings import Exports -from imports.bindings import add_imports_to_linker, Imports from typing import Optional, Tuple -import exports.bindings as e -import imports.bindings as i -import sys import wasmtime +from helpers import TestWasi +from variants.imports import imports as i +from variants import Variants, VariantsImports +import variants as e +from variants.types import Result, Ok, Err class MyImports: def roundtrip_option(self, a: Optional[float]) -> Optional[int]: @@ -12,10 +12,10 @@ def roundtrip_option(self, a: Optional[float]) -> Optional[int]: return int(a) return None - def roundtrip_result(self, a: i.Result[int, float]) -> i.Result[float, int]: - if isinstance(a, i.Ok): - return i.Ok(float(a.value)) - return i.Err(int(a.value)) + def roundtrip_result(self, a: Result[int, float]) -> Result[float, int]: + if isinstance(a, Ok): + return Ok(float(a.value)) + return Err(int(a.value)) def roundtrip_enum(self, a: i.E1) -> i.E1: return a @@ -32,34 +32,24 @@ def variant_zeros(self, a: i.Zeros) -> i.Zeros: def variant_typedefs(self, a: i.OptionTypedef, b: i.BoolTypedef, c: i.ResultTypedef) -> None: pass - def variant_enums(self, a: bool, b: i.Result[None, None], c: i.MyErrno) -> Tuple[bool, i.Result[None, None], i.MyErrno]: + def variant_enums(self, a: bool, b: Result[None, None], c: i.MyErrno) -> Tuple[bool, Result[None, None], i.MyErrno]: assert(a) - assert(isinstance(b, i.Ok)) + assert(isinstance(b, Ok)) assert(c == i.MyErrno.SUCCESS) - return (False, i.Err(None), i.MyErrno.A) + return (False, Err(None), i.MyErrno.A) -def run(wasm_file: str) -> None: +def run() -> None: store = wasmtime.Store() - module = wasmtime.Module.from_file(store.engine, wasm_file) - linker = wasmtime.Linker(store.engine) - linker.define_wasi() - wasi = wasmtime.WasiConfig() - wasi.inherit_stdout() - wasi.inherit_stderr() - store.set_wasi(wasi) - - imports = MyImports() - add_imports_to_linker(linker, store, imports) - wasm = Exports(store, linker, module) + wasm = Variants(store, VariantsImports(MyImports(), TestWasi())) wasm.test_imports(store) assert(wasm.roundtrip_option(store, 1.) == 1) assert(wasm.roundtrip_option(store, None) == None) assert(wasm.roundtrip_option(store, 2.) == 2) - assert(wasm.roundtrip_result(store, e.Ok(2)) == e.Ok(2)) - assert(wasm.roundtrip_result(store, e.Ok(4)) == e.Ok(4)) - assert(wasm.roundtrip_result(store, e.Err(5)) == e.Err(5)) + assert(wasm.roundtrip_result(store, Ok(2)) == Ok(2)) + assert(wasm.roundtrip_result(store, Ok(4)) == Ok(4)) + assert(wasm.roundtrip_result(store, Err(5)) == Err(5)) assert(wasm.roundtrip_enum(store, e.E1.A) == e.E1.A) assert(wasm.roundtrip_enum(store, e.E1.B) == e.E1.B) @@ -108,7 +98,7 @@ def run(wasm_file: str) -> None: assert(z3 == e.Z3A(3)) assert(z4 == e.Z4A(4)) - wasm.variant_typedefs(store, None, False, e.Err(None)) + wasm.variant_typedefs(store, None, False, Err(None)) if __name__ == '__main__': - run(sys.argv[1]) + run()