feat: allow to preload ML models when running inference #224

Merged · 3 commits · Sep 28, 2023
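In short: besides declaring allowed WASI-NN backends, a worker can now list models for the host to preload before serving requests. A minimal worker-config sketch of the new option (the `[features.wasi_nn]` nesting is an assumption based on the `self.config.features.wasi_nn` accesses in `lib.rs` below; the key names follow the serde derives in `wasi_nn.rs`, and the model directory is illustrative):

```toml
# Assumed nesting; dir path is illustrative
[features.wasi_nn]
# Backends the guest may use to load models itself ("openvino" is the
# only non-none WasiNnBackend variant in this PR)
allowed_backends = ["openvino"]

# New in this PR: models the host preloads before the worker runs
[[features.wasi_nn.preload_models]]
backend = "openvino"
provider = { type = "local", dir = "./_models/mobilenet" }
```

The `type = "local"` / `dir` shape comes from the `#[serde(rename_all = "lowercase", tag = "type")]` derive on `WasiNnModelProvider` in the diff below.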
1 change: 1 addition & 0 deletions Cargo.toml
@@ -70,6 +70,7 @@ exclude = [
"examples/rust-kv",
"examples/rust-params",
"examples/rust-wasi-nn",
"examples/rust-wasi-nn-preload",
]

[workspace.dependencies]
92 changes: 90 additions & 2 deletions crates/worker/src/features/wasi_nn.rs
@@ -2,12 +2,100 @@
// SPDX-License-Identifier: Apache-2.0

use serde::Deserialize;
+use std::fmt::Display;
+use std::path::{Path, PathBuf};
+use wasmtime_wasi_nn::backend::openvino::OpenvinoBackend;
+use wasmtime_wasi_nn::Backend;

-pub const WASI_NN_BACKEND_OPENVINO: &str = "openvino";
+/// Available Machine Learning backends
+#[derive(Deserialize, Clone, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum WasiNnBackend {
+    /// None
+    #[default]
+    None,
+    /// OpenVINO backend
+    Openvino,
+}
+
+impl WasiNnBackend {
+    /// Convert the given enum variant into a WASI-NN backend.
+    pub fn to_backend(&self) -> Option<Backend> {
+        match self {
+            Self::None => None,
+            Self::Openvino => Some(Backend::from(OpenvinoBackend::default())),
+        }
+    }
+}
+
+impl Display for WasiNnBackend {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::None => write!(f, "none"),
+            Self::Openvino => write!(f, "openvino"),
+        }
+    }
+}
+
+/// Available providers to load Wasi NN models.
+#[derive(Deserialize, Clone, Debug)]
+#[serde(rename_all = "lowercase", tag = "type")]
[Inline review comment from a Contributor: "Nice :)"]
+pub enum WasiNnModelProvider {
+    /// Load it from the local filesystem
+    Local { dir: PathBuf },
+}
+
+impl Default for WasiNnModelProvider {
+    fn default() -> Self {
+        Self::Local {
+            dir: PathBuf::from("./"),
+        }
+    }
+}
+
+#[derive(Deserialize, Clone, Default)]
+#[serde(default)]
+pub struct WasiNnModel {
+    /// The provider to retrieve the given model.
+    provider: WasiNnModelProvider,
+    /// Backend to run this specific model
+    backend: WasiNnBackend,
+}
+
+impl WasiNnModel {
+    /// Provide the graph configuration from the current model. Depending on the
+    /// provider, it may need to perform other tasks before running it.
+    pub fn build_graph_data(&self, worker_path: &Path) -> (String, String) {
+        match &self.provider {
+            WasiNnModelProvider::Local { dir } => {
+                let data = if dir.is_relative() {
+                    worker_path.parent().map(|parent| {
+                        (
+                            self.backend.clone().to_string(),
+                            parent.join(dir).to_string_lossy().to_string(),
+                        )
+                    })
+                } else {
+                    None
+                };
+
+                data.unwrap_or_else(|| {
+                    // Absolute path or best effort if it cannot retrieve the parent path
+                    (
+                        self.backend.clone().to_string(),
+                        dir.to_string_lossy().to_string(),
+                    )
+                })
+            }
+        }
+    }
+}

#[derive(Deserialize, Clone, Default)]
#[serde(default)]
pub struct WasiNnConfig {
    /// List of Machine Learning backends. For now, only "openvino" option is supported
-    pub allowed_backends: Vec<String>,
+    pub allowed_backends: Vec<WasiNnBackend>,
+    /// List of preloaded models. It allows you to get the models from different strategies.
+    pub preload_models: Vec<WasiNnModel>,
}
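One subtlety in `build_graph_data` above: a relative `dir` is joined to the worker's parent directory (`worker_path.parent()`), not to the server's working directory; absolute paths (or workers without a parent) pass through as-is. A hedged sketch of both cases (paths illustrative, same assumed config nesting as above):

```toml
# For a worker at examples/rust-wasi-nn-preload/inference.wasm, this
# relative dir resolves to examples/rust-wasi-nn-preload/_models/mobilenet
[[features.wasi_nn.preload_models]]
backend = "openvino"
provider = { type = "local", dir = "./_models/mobilenet" }

# An absolute dir would be handed to wasmtime_wasi_nn::preload unchanged:
# provider = { type = "local", dir = "/opt/models/mobilenet" }
```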
70 changes: 41 additions & 29 deletions crates/worker/src/lib.rs
@@ -1,4 +1,4 @@
-// Copyright 2022 VMware, Inc.
+// Copyright 2022-2023 VMware, Inc.
// SPDX-License-Identifier: Apache-2.0

mod bindings;
@@ -12,7 +12,6 @@ use actix_web::HttpRequest;
use bindings::http::{add_to_linker as http_add_to_linker, HttpBindings};
use config::Config;
use errors::Result;
-use features::wasi_nn::WASI_NN_BACKEND_OPENVINO;
use io::{WasmInput, WasmOutput};
use sha256::digest as sha256_digest;
use std::fs;
@@ -23,7 +22,7 @@ use stdio::Stdio;
use wasi_common::WasiCtx;
use wasmtime::{Engine, Linker, Module, Store};
use wasmtime_wasi::{ambient_authority, Dir, WasiCtxBuilder};
-use wasmtime_wasi_nn::WasiNnCtx;
+use wasmtime_wasi_nn::{InMemoryRegistry, Registry, WasiNnCtx};
use wws_config::Config as ProjectConfig;
use wws_runtimes::{init_runtime, Runtime};

@@ -137,37 +136,50 @@ impl Worker {

        // WASI-NN
        let allowed_backends = &self.config.features.wasi_nn.allowed_backends;

-        let wasi_nn = if !allowed_backends.is_empty() {
-            // For now, we only support OpenVINO
-            if allowed_backends.len() != 1
-                || !allowed_backends.contains(&WASI_NN_BACKEND_OPENVINO.to_string())
-            {
-                eprintln!("❌ The only WASI-NN supported backend name is \"{WASI_NN_BACKEND_OPENVINO}\". Please, update your config.");
-                None
-            } else {
-                wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut WorkerState| {
-                    Arc::get_mut(s.wasi_nn.as_mut().unwrap())
-                        .expect("wasi-nn is not implemented with multi-threading support")
-                })
-                .map_err(|_| {
-                    errors::WorkerError::RuntimeError(
-                        wws_runtimes::errors::RuntimeError::WasiContextError,
-                    )
-                })?;
-
-                let (backends, registry) = wasmtime_wasi_nn::preload(&[]).map_err(|_| {
-                    errors::WorkerError::RuntimeError(
-                        wws_runtimes::errors::RuntimeError::WasiContextError,
-                    )
-                })?;
-
-                Some(Arc::new(WasiNnCtx::new(backends, registry)))
+        let preload_models = &self.config.features.wasi_nn.preload_models;
+
+        let wasi_nn = if !preload_models.is_empty() {
+            // Preload the models on the host.
+            let graphs = preload_models
+                .iter()
+                .map(|m| m.build_graph_data(&self.path))
+                .collect::<Vec<_>>();
+            let (backends, registry) = wasmtime_wasi_nn::preload(&graphs).map_err(|_| {
+                errors::WorkerError::RuntimeError(
+                    wws_runtimes::errors::RuntimeError::WasiContextError,
+                )
+            })?;
+
+            Some(Arc::new(WasiNnCtx::new(backends, registry)))
+        } else if !allowed_backends.is_empty() {
+            let registry = Registry::from(InMemoryRegistry::new());
+            let mut backends = Vec::new();
+
+            // Load the given backends:
+            for b in allowed_backends.iter() {
+                if let Some(backend) = b.to_backend() {
+                    backends.push(backend);
+                }
+            }
+
+            Some(Arc::new(WasiNnCtx::new(backends, registry)))
        } else {
            None
        };

+        // Load the Wasi NN linker
+        if wasi_nn.is_some() {
+            wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut WorkerState| {
+                Arc::get_mut(s.wasi_nn.as_mut().unwrap())
+                    .expect("wasi-nn is not implemented with multi-threading support")
+            })
+            .map_err(|_| {
+                errors::WorkerError::RuntimeError(
+                    wws_runtimes::errors::RuntimeError::WasiContextError,
+                )
+            })?;
+        }

        // Pass to the runtime to add any WASI specific requirement
        self.runtime.prepare_wasi_ctx(&mut wasi_builder)?;

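Note the precedence in the new branch above: a non-empty `preload_models` wins and the host preloads the graphs; otherwise a non-empty `allowed_backends` still enables WASI-NN, but with an empty `InMemoryRegistry`, so the guest loads models itself. A sketch of that second, backend-only configuration (same assumed nesting as the first example):

```toml
# WASI-NN linked in, nothing preloaded: the Wasm module builds its
# own graphs at request time
[features.wasi_nn]
allowed_backends = ["openvino"]
```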
5 changes: 5 additions & 0 deletions examples/Makefile
@@ -25,6 +25,11 @@ rust-wasi-nn:
	cargo build --target wasm32-wasi --release && \
	mv target/wasm32-wasi/release/rust-wasi-nn.wasm "./inference.wasm"

+rust-wasi-nn-preload:
+	cd rust-wasi-nn-preload && \
+	cargo build --target wasm32-wasi --release && \
+	mv target/wasm32-wasi/release/rust-wasi-nn-preload.wasm "./inference.wasm"

rust-pdf-create:
	cd rust-pdf-create && \
	cargo build --target wasm32-wasi --release && \
4 changes: 4 additions & 0 deletions examples/rust-wasi-nn-preload/.gitignore
@@ -0,0 +1,4 @@
+_images/image.jpg
+_models/**/*.bin
+_models/**/*.xml
+!inference.wasm