Skip to content

Commit

Permalink
chore: register wasi-nn on the linker (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
ereslibre committed Oct 5, 2023
1 parent 451f67a commit 0a72826
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 59 deletions.
2 changes: 1 addition & 1 deletion crates/runtimes/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub enum RuntimeError {
IOError(std::io::Error),
MissingRuntime { extension: String },
StoreError(wws_store::errors::StoreError),
WasiContextError,
WasiContextError { error: String },
WasiError(Option<wasmtime_wasi::Error>),
}

Expand Down
4 changes: 3 additions & 1 deletion crates/runtimes/src/modules/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ impl Runtime for ExternalRuntime {
"/src",
)?
.args(&self.metadata.args)
.map_err(|_| errors::RuntimeError::WasiContextError)?;
.map_err(|err| errors::RuntimeError::WasiContextError {
error: format!("{}", err),
})?;
}
CtxBuilder::Preview2(ref mut builder) => {
builder
Expand Down
132 changes: 75 additions & 57 deletions crates/worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,16 +283,76 @@ impl Worker {

self.runtime.prepare_wasi_ctx(&mut wasi_builder)?;

let allowed_backends = &self.config.features.wasi_nn.allowed_backends;
let preload_models = &self.config.features.wasi_nn.preload_models;
let wasi_nn = if !preload_models.is_empty() {
// Preload the models on the host.
let graphs = preload_models
.iter()
.map(|m| m.build_graph_data(&self.path))
.collect::<Vec<_>>();
let (backends, registry) = wasmtime_wasi_nn::preload(&graphs).map_err(|err| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError {
error: format!("{}", err),
},
)
})?;

Some(Arc::new(WasiNnCtx::new(backends, registry)))
} else if !allowed_backends.is_empty() {
let registry = Registry::from(InMemoryRegistry::new());
let mut backends = Vec::new();

// Load the given backends:
for b in allowed_backends.iter() {
if let Some(backend) = b.to_backend() {
backends.push(backend);
}
}

Some(Arc::new(WasiNnCtx::new(backends, registry)))
} else {
None
};

let host = match wasi_builder {
CtxBuilder::Preview1(mut wasi_builder) => Host {
wasi_preview1_ctx: Some(wasi_builder.build()),
wasi_nn: None,
http: Some(HttpBindings {
http_config: self.config.features.http_requests.clone(),
}),
..Host::default()
},
CtxBuilder::Preview1(mut wasi_builder) => {
if wasi_nn.is_some() {
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |host: &mut Host| {
Arc::get_mut(host.wasi_nn.as_mut().unwrap()).unwrap()
})
.map_err(|err| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError {
error: format!("{}", err),
},
)
})?;
}
Host {
wasi_preview1_ctx: Some(wasi_builder.build()),
wasi_nn,
http: Some(HttpBindings {
http_config: self.config.features.http_requests.clone(),
}),
..Host::default()
}
}
CtxBuilder::Preview2(mut wasi_builder) => {
if wasi_nn.is_some() {
wasmtime_wasi_nn::wit::ML::add_to_linker(
&mut component_linker,
|host: &mut Host| Arc::get_mut(host.wasi_nn.as_mut().unwrap()).unwrap(),
)
.map_err(|err| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError {
error: format!("{}", err),
},
)
})?;
}
let mut table = preview2::Table::default();
Host {
wasi_preview2_ctx: Some(Arc::new(wasi_builder.build(&mut table).map_err(
Expand All @@ -304,7 +364,7 @@ impl Worker {
wasi_preview2_adapter: Arc::new(
preview2::preview1::WasiPreview1Adapter::default(),
),
wasi_nn: None,
wasi_nn,
http: Some(HttpBindings {
http_config: self.config.features.http_requests.clone(),
}),
Expand All @@ -313,53 +373,6 @@ impl Worker {
}
};

// Setup wasi-nn
{
let allowed_backends = &self.config.features.wasi_nn.allowed_backends;
let preload_models = &self.config.features.wasi_nn.preload_models;
let wasi_nn = if !preload_models.is_empty() {
// Preload the models on the host.
let graphs = preload_models
.iter()
.map(|m| m.build_graph_data(&self.path))
.collect::<Vec<_>>();
let (backends, registry) = wasmtime_wasi_nn::preload(&graphs).map_err(|_| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError,
)
})?;

Some(Arc::new(WasiNnCtx::new(backends, registry)))
} else if !allowed_backends.is_empty() {
let registry = Registry::from(InMemoryRegistry::new());
let mut backends = Vec::new();

// Load the given backends:
for b in allowed_backends.iter() {
if let Some(backend) = b.to_backend() {
backends.push(backend);
}
}

Some(Arc::new(WasiNnCtx::new(backends, registry)))
} else {
None
};

// Load the Wasi NN linker
if wasi_nn.is_some() {
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |host: &mut Host| {
Arc::get_mut(host.wasi_nn.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support")
})
.map_err(|_| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError,
)
})?;
}
}

let contents = {
let mut store = Store::new(&self.engine, host);
match &self.module_or_component {
Expand Down Expand Up @@ -403,7 +416,12 @@ impl Worker {
&component_linker,
)
.await
.unwrap();
.map_err(|error| {
errors::WorkerError::ConfigureRuntimeError {
error: format!("error instantiating component cli::run: {error}"),
}
})?;

let _ = command
.wasi_cli_run()
.call_run(&mut store)
Expand Down

0 comments on commit 0a72826

Please sign in to comment.