Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: register wasi-nn on the linker #231

Merged
merged 1 commit into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/runtimes/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub enum RuntimeError {
IOError(std::io::Error),
MissingRuntime { extension: String },
StoreError(wws_store::errors::StoreError),
WasiContextError,
WasiContextError { error: String },
WasiError(Option<wasmtime_wasi::Error>),
}

Expand Down
4 changes: 3 additions & 1 deletion crates/runtimes/src/modules/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ impl Runtime for ExternalRuntime {
"/src",
)?
.args(&self.metadata.args)
.map_err(|_| errors::RuntimeError::WasiContextError)?;
.map_err(|err| errors::RuntimeError::WasiContextError {
error: format!("{}", err),
})?;
}
CtxBuilder::Preview2(ref mut builder) => {
builder
Expand Down
132 changes: 75 additions & 57 deletions crates/worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,16 +283,76 @@ impl Worker {

self.runtime.prepare_wasi_ctx(&mut wasi_builder)?;

let allowed_backends = &self.config.features.wasi_nn.allowed_backends;
let preload_models = &self.config.features.wasi_nn.preload_models;
let wasi_nn = if !preload_models.is_empty() {
// Preload the models on the host.
let graphs = preload_models
.iter()
.map(|m| m.build_graph_data(&self.path))
.collect::<Vec<_>>();
let (backends, registry) = wasmtime_wasi_nn::preload(&graphs).map_err(|err| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError {
error: format!("{}", err),
},
)
})?;

Some(Arc::new(WasiNnCtx::new(backends, registry)))
} else if !allowed_backends.is_empty() {
let registry = Registry::from(InMemoryRegistry::new());
let mut backends = Vec::new();

// Load the given backends:
for b in allowed_backends.iter() {
if let Some(backend) = b.to_backend() {
backends.push(backend);
}
}

Some(Arc::new(WasiNnCtx::new(backends, registry)))
} else {
None
};

let host = match wasi_builder {
CtxBuilder::Preview1(mut wasi_builder) => Host {
wasi_preview1_ctx: Some(wasi_builder.build()),
wasi_nn: None,
http: Some(HttpBindings {
http_config: self.config.features.http_requests.clone(),
}),
..Host::default()
},
CtxBuilder::Preview1(mut wasi_builder) => {
if wasi_nn.is_some() {
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |host: &mut Host| {
Arc::get_mut(host.wasi_nn.as_mut().unwrap()).unwrap()
})
.map_err(|err| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError {
error: format!("{}", err),
},
)
})?;
}
Host {
wasi_preview1_ctx: Some(wasi_builder.build()),
wasi_nn,
http: Some(HttpBindings {
http_config: self.config.features.http_requests.clone(),
}),
..Host::default()
}
}
CtxBuilder::Preview2(mut wasi_builder) => {
if wasi_nn.is_some() {
wasmtime_wasi_nn::wit::ML::add_to_linker(
&mut component_linker,
|host: &mut Host| Arc::get_mut(host.wasi_nn.as_mut().unwrap()).unwrap(),
)
.map_err(|err| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError {
error: format!("{}", err),
},
)
})?;
}
let mut table = preview2::Table::default();
Host {
wasi_preview2_ctx: Some(Arc::new(wasi_builder.build(&mut table).map_err(
Expand All @@ -304,7 +364,7 @@ impl Worker {
wasi_preview2_adapter: Arc::new(
preview2::preview1::WasiPreview1Adapter::default(),
),
wasi_nn: None,
wasi_nn,
http: Some(HttpBindings {
http_config: self.config.features.http_requests.clone(),
}),
Expand All @@ -313,53 +373,6 @@ impl Worker {
}
};

// Setup wasi-nn
{
let allowed_backends = &self.config.features.wasi_nn.allowed_backends;
let preload_models = &self.config.features.wasi_nn.preload_models;
let wasi_nn = if !preload_models.is_empty() {
// Preload the models on the host.
let graphs = preload_models
.iter()
.map(|m| m.build_graph_data(&self.path))
.collect::<Vec<_>>();
let (backends, registry) = wasmtime_wasi_nn::preload(&graphs).map_err(|_| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError,
)
})?;

Some(Arc::new(WasiNnCtx::new(backends, registry)))
} else if !allowed_backends.is_empty() {
let registry = Registry::from(InMemoryRegistry::new());
let mut backends = Vec::new();

// Load the given backends:
for b in allowed_backends.iter() {
if let Some(backend) = b.to_backend() {
backends.push(backend);
}
}

Some(Arc::new(WasiNnCtx::new(backends, registry)))
} else {
None
};

// Load the Wasi NN linker
if wasi_nn.is_some() {
wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |host: &mut Host| {
Arc::get_mut(host.wasi_nn.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support")
})
.map_err(|_| {
errors::WorkerError::RuntimeError(
wws_runtimes::errors::RuntimeError::WasiContextError,
)
})?;
}
}

let contents = {
let mut store = Store::new(&self.engine, host);
match &self.module_or_component {
Expand Down Expand Up @@ -403,7 +416,12 @@ impl Worker {
&component_linker,
)
.await
.unwrap();
.map_err(|error| {
errors::WorkerError::ConfigureRuntimeError {
error: format!("error instantiating component cli::run: {error}"),
}
})?;

let _ = command
.wasi_cli_run()
.call_run(&mut store)
Expand Down