From 0a7282620d7a300b83735c10ca7302dbd8e72718 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20Fern=C3=A1ndez=20L=C3=B3pez?= Date: Thu, 5 Oct 2023 11:50:33 +0200 Subject: [PATCH] chore: register wasi-nn on the linker (#231) --- crates/runtimes/src/errors.rs | 2 +- crates/runtimes/src/modules/external.rs | 4 +- crates/worker/src/lib.rs | 132 ++++++++++++++---------- 3 files changed, 79 insertions(+), 59 deletions(-) diff --git a/crates/runtimes/src/errors.rs b/crates/runtimes/src/errors.rs index d6d74ba2..5852ec93 100644 --- a/crates/runtimes/src/errors.rs +++ b/crates/runtimes/src/errors.rs @@ -11,7 +11,7 @@ pub enum RuntimeError { IOError(std::io::Error), MissingRuntime { extension: String }, StoreError(wws_store::errors::StoreError), - WasiContextError, + WasiContextError { error: String }, WasiError(Option), } diff --git a/crates/runtimes/src/modules/external.rs b/crates/runtimes/src/modules/external.rs index 19bdbf2c..9833db5f 100644 --- a/crates/runtimes/src/modules/external.rs +++ b/crates/runtimes/src/modules/external.rs @@ -99,7 +99,9 @@ impl Runtime for ExternalRuntime { "/src", )? .args(&self.metadata.args) - .map_err(|_| errors::RuntimeError::WasiContextError)?; + .map_err(|err| errors::RuntimeError::WasiContextError { + error: format!("{}", err), + })?; } CtxBuilder::Preview2(ref mut builder) => { builder diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index 7a4f880d..f3b44bc8 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -283,16 +283,76 @@ impl Worker { self.runtime.prepare_wasi_ctx(&mut wasi_builder)?; + let allowed_backends = &self.config.features.wasi_nn.allowed_backends; + let preload_models = &self.config.features.wasi_nn.preload_models; + let wasi_nn = if !preload_models.is_empty() { + // Preload the models on the host. + let graphs = preload_models + .iter() + .map(|m| m.build_graph_data(&self.path)) + .collect::>(); + let (backends, registry) = wasmtime_wasi_nn::preload(&graphs).map_err(|err| { + errors::WorkerError::RuntimeError( + wws_runtimes::errors::RuntimeError::WasiContextError { + error: format!("{}", err), + }, + ) + })?; + + Some(Arc::new(WasiNnCtx::new(backends, registry))) + } else if !allowed_backends.is_empty() { + let registry = Registry::from(InMemoryRegistry::new()); + let mut backends = Vec::new(); + + // Load the given backends: + for b in allowed_backends.iter() { + if let Some(backend) = b.to_backend() { + backends.push(backend); + } + } + + Some(Arc::new(WasiNnCtx::new(backends, registry))) + } else { + None + }; + let host = match wasi_builder { - CtxBuilder::Preview1(mut wasi_builder) => Host { - wasi_preview1_ctx: Some(wasi_builder.build()), - wasi_nn: None, - http: Some(HttpBindings { - http_config: self.config.features.http_requests.clone(), - }), - ..Host::default() - }, + CtxBuilder::Preview1(mut wasi_builder) => { + if wasi_nn.is_some() { + wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |host: &mut Host| { + Arc::get_mut(host.wasi_nn.as_mut().unwrap()).unwrap() + }) + .map_err(|err| { + errors::WorkerError::RuntimeError( + wws_runtimes::errors::RuntimeError::WasiContextError { + error: format!("{}", err), + }, + ) + })?; + } + Host { + wasi_preview1_ctx: Some(wasi_builder.build()), + wasi_nn, + http: Some(HttpBindings { + http_config: self.config.features.http_requests.clone(), + }), + ..Host::default() + } + } CtxBuilder::Preview2(mut wasi_builder) => { + if wasi_nn.is_some() { + wasmtime_wasi_nn::wit::ML::add_to_linker( + &mut component_linker, + |host: &mut Host| Arc::get_mut(host.wasi_nn.as_mut().unwrap()).unwrap(), + ) + .map_err(|err| { + errors::WorkerError::RuntimeError( + wws_runtimes::errors::RuntimeError::WasiContextError { + error: format!("{}", err), + }, + ) + })?; + } let mut table = preview2::Table::default(); Host { wasi_preview2_ctx: Some(Arc::new(wasi_builder.build(&mut table).map_err( @@ -304,7 +364,7 @@ impl Worker { wasi_preview2_adapter: Arc::new( preview2::preview1::WasiPreview1Adapter::default(), ), - wasi_nn: None, + wasi_nn, http: Some(HttpBindings { http_config: self.config.features.http_requests.clone(), }), @@ -313,53 +373,6 @@ impl Worker { } }; - // Setup wasi-nn - { - let allowed_backends = &self.config.features.wasi_nn.allowed_backends; - let preload_models = &self.config.features.wasi_nn.preload_models; - let wasi_nn = if !preload_models.is_empty() { - // Preload the models on the host. - let graphs = preload_models - .iter() - .map(|m| m.build_graph_data(&self.path)) - .collect::>(); - let (backends, registry) = wasmtime_wasi_nn::preload(&graphs).map_err(|_| { - errors::WorkerError::RuntimeError( - wws_runtimes::errors::RuntimeError::WasiContextError, - ) - })?; - - Some(Arc::new(WasiNnCtx::new(backends, registry))) - } else if !allowed_backends.is_empty() { - let registry = Registry::from(InMemoryRegistry::new()); - let mut backends = Vec::new(); - - // Load the given backends: - for b in allowed_backends.iter() { - if let Some(backend) = b.to_backend() { - backends.push(backend); - } - } - - Some(Arc::new(WasiNnCtx::new(backends, registry))) - } else { - None - }; - - // Load the Wasi NN linker - if wasi_nn.is_some() { - wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |host: &mut Host| { - Arc::get_mut(host.wasi_nn.as_mut().unwrap()) - .expect("wasi-nn is not implemented with multi-threading support") - }) - .map_err(|_| { - errors::WorkerError::RuntimeError( - wws_runtimes::errors::RuntimeError::WasiContextError, - ) - })?; - } - } - let contents = { let mut store = Store::new(&self.engine, host); match &self.module_or_component { @@ -403,7 +416,12 @@ impl Worker { &component_linker, ) .await - .unwrap(); + .map_err(|error| { + errors::WorkerError::ConfigureRuntimeError { + error: format!("error instantiating component cli::run: {error}"), + } + })?; + let _ = command .wasi_cli_run() .call_run(&mut store)