feat: allow preloading ML models when running inference (#224)
* feat: preload machine-learning models using WASI-NN named models

* feat: load only configured ML backends

* feat: adjust configuration to validate data from preload model providers
Angelmmiguel authored Sep 28, 2023
1 parent 3b4858d commit 0b160a4
Showing 18 changed files with 2,098 additions and 31 deletions.
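
Taken together, the changes below replace the stringly-typed `allowed_backends` list with a `WasiNnBackend` enum and add a `preload_models` list, so a worker can ask the host to load models before any request runs. What follows is a minimal sketch, not part of the commit, of how the new configuration fragment deserializes: the types mirror the serde derives introduced in `crates/worker/src/features/wasi_nn.rs` below, the standalone `main` and the `toml` crate dependency are assumptions for illustration, and in a real worker config these fields live under `[features.wasi_nn]` (see `self.config.features.wasi_nn` in `lib.rs`).

use serde::Deserialize;
use std::path::PathBuf;

#[derive(Deserialize, Debug, Default)]
#[serde(rename_all = "lowercase")]
enum WasiNnBackend {
    #[default]
    None,
    Openvino,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "lowercase", tag = "type")]
enum WasiNnModelProvider {
    Local { dir: PathBuf },
}

#[derive(Deserialize, Debug)]
struct WasiNnModel {
    provider: WasiNnModelProvider,
    backend: WasiNnBackend,
}

#[derive(Deserialize, Debug, Default)]
#[serde(default)]
struct WasiNnConfig {
    allowed_backends: Vec<WasiNnBackend>,
    preload_models: Vec<WasiNnModel>,
}

fn main() {
    // Hypothetical worker config fragment: preload one OpenVINO model from
    // a directory relative to the worker file.
    let raw = r#"
        allowed_backends = ["openvino"]

        [[preload_models]]
        backend = "openvino"
        provider = { type = "local", dir = "./_models/mobilenet" }
    "#;

    let config: WasiNnConfig = toml::from_str(raw).expect("valid WASI-NN config");
    println!("{config:#?}");
}

Note how `tag = "type"` makes the provider an internally tagged enum, which is why the TOML entry carries an explicit `type = "local"` field alongside the provider-specific `dir`.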
1 change: 1 addition & 0 deletions Cargo.toml
@@ -70,6 +70,7 @@ exclude = [
  "examples/rust-kv",
  "examples/rust-params",
  "examples/rust-wasi-nn",
  "examples/rust-wasi-nn-preload",
]

[workspace.dependencies]
92 changes: 90 additions & 2 deletions crates/worker/src/features/wasi_nn.rs
@@ -2,12 +2,100 @@
// SPDX-License-Identifier: Apache-2.0

use serde::Deserialize;
use std::fmt::Display;
use std::path::{Path, PathBuf};
use wasmtime_wasi_nn::backend::openvino::OpenvinoBackend;
use wasmtime_wasi_nn::Backend;

pub const WASI_NN_BACKEND_OPENVINO: &str = "openvino";
/// Available Machine Learning backends
#[derive(Deserialize, Clone, Default)]
#[serde(rename_all = "lowercase")]
pub enum WasiNnBackend {
    /// None
    #[default]
    None,
    /// OpenVINO backend
    Openvino,
}

impl WasiNnBackend {
    /// Convert the given enum variant into a WASI-NN backend.
    pub fn to_backend(&self) -> Option<Backend> {
        match self {
            Self::None => None,
            Self::Openvino => Some(Backend::from(OpenvinoBackend::default())),
        }
    }
}

impl Display for WasiNnBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => write!(f, "none"),
            Self::Openvino => write!(f, "openvino"),
        }
    }
}

/// Available providers to load WASI-NN models.
#[derive(Deserialize, Clone, Debug)]
#[serde(rename_all = "lowercase", tag = "type")]
pub enum WasiNnModelProvider {
    /// Load it from the local filesystem
    Local { dir: PathBuf },
}

impl Default for WasiNnModelProvider {
    fn default() -> Self {
        Self::Local {
            dir: PathBuf::from("./"),
        }
    }
}

#[derive(Deserialize, Clone, Default)]
#[serde(default)]
pub struct WasiNnModel {
    /// The provider to retrieve the given model.
    provider: WasiNnModelProvider,
    /// Backend to run this specific model
    backend: WasiNnBackend,
}

impl WasiNnModel {
    /// Provide the graph configuration for the current model. Depending on the
    /// provider, it may need to perform other tasks before running it.
    pub fn build_graph_data(&self, worker_path: &Path) -> (String, String) {
        match &self.provider {
            WasiNnModelProvider::Local { dir } => {
                let data = if dir.is_relative() {
                    worker_path.parent().map(|parent| {
                        (
                            self.backend.clone().to_string(),
                            parent.join(dir).to_string_lossy().to_string(),
                        )
                    })
                } else {
                    None
                };

                data.unwrap_or_else(|| {
                    // Absolute path, or best effort if the parent path cannot be retrieved
                    (
                        self.backend.clone().to_string(),
                        dir.to_string_lossy().to_string(),
                    )
                })
            }
        }
    }
}

#[derive(Deserialize, Clone, Default)]
#[serde(default)]
pub struct WasiNnConfig {
    /// List of Machine Learning backends. For now, only the "openvino" option is supported
    pub allowed_backends: Vec<String>,
    pub allowed_backends: Vec<WasiNnBackend>,
    /// List of preloaded models. It allows you to load models using different strategies.
    pub preload_models: Vec<WasiNnModel>,
}
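
A note on the path handling in `build_graph_data` above: a relative `dir` is resolved against the worker file's parent directory, while an absolute `dir` (or a worker path with no retrievable parent) is passed through as-is. A standalone sketch of that rule, with a hypothetical Unix worker path:

use std::path::{Path, PathBuf};

// Mirrors the resolution inside `build_graph_data`: join relative model
// directories onto the worker's parent folder; keep absolute ones as-is.
fn resolve_model_dir(worker_path: &Path, dir: &Path) -> PathBuf {
    if dir.is_relative() {
        worker_path
            .parent()
            .map(|parent| parent.join(dir))
            // Best effort when the parent cannot be retrieved
            .unwrap_or_else(|| dir.to_path_buf())
    } else {
        dir.to_path_buf()
    }
}

fn main() {
    let worker = Path::new("/srv/app/inference.wasm");
    assert_eq!(
        resolve_model_dir(worker, Path::new("_models/mobilenet")),
        PathBuf::from("/srv/app/_models/mobilenet")
    );
    assert_eq!(
        resolve_model_dir(worker, Path::new("/opt/models")),
        PathBuf::from("/opt/models")
    );
}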
70 changes: 41 additions & 29 deletions crates/worker/src/lib.rs
@@ -1,4 +1,4 @@
// Copyright 2022 VMware, Inc.
// Copyright 2022-2023 VMware, Inc.
// SPDX-License-Identifier: Apache-2.0

mod bindings;
@@ -12,7 +12,6 @@ use actix_web::HttpRequest;
use bindings::http::{add_to_linker as http_add_to_linker, HttpBindings};
use config::Config;
use errors::Result;
use features::wasi_nn::WASI_NN_BACKEND_OPENVINO;
use io::{WasmInput, WasmOutput};
use sha256::digest as sha256_digest;
use std::fs;
@@ -23,7 +22,7 @@ use stdio::Stdio;
use wasi_common::WasiCtx;
use wasmtime::{Engine, Linker, Module, Store};
use wasmtime_wasi::{ambient_authority, Dir, WasiCtxBuilder};
use wasmtime_wasi_nn::WasiNnCtx;
use wasmtime_wasi_nn::{InMemoryRegistry, Registry, WasiNnCtx};
use wws_config::Config as ProjectConfig;
use wws_runtimes::{init_runtime, Runtime};

@@ -137,37 +136,50 @@ impl Worker {

        // WASI-NN
        let allowed_backends = &self.config.features.wasi_nn.allowed_backends;

        let wasi_nn = if !allowed_backends.is_empty() {
            // For now, we only support OpenVINO
            if allowed_backends.len() != 1
                || !allowed_backends.contains(&WASI_NN_BACKEND_OPENVINO.to_string())
            {
                eprintln!("❌ The only WASI-NN supported backend name is \"{WASI_NN_BACKEND_OPENVINO}\". Please, update your config.");
                None
            } else {
                wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut WorkerState| {
                    Arc::get_mut(s.wasi_nn.as_mut().unwrap())
                        .expect("wasi-nn is not implemented with multi-threading support")
                })
                .map_err(|_| {
                    errors::WorkerError::RuntimeError(
                        wws_runtimes::errors::RuntimeError::WasiContextError,
                    )
                })?;

                let (backends, registry) = wasmtime_wasi_nn::preload(&[]).map_err(|_| {
                    errors::WorkerError::RuntimeError(
                        wws_runtimes::errors::RuntimeError::WasiContextError,
                    )
                })?;

                Some(Arc::new(WasiNnCtx::new(backends, registry)))
        let preload_models = &self.config.features.wasi_nn.preload_models;

        let wasi_nn = if !preload_models.is_empty() {
            // Preload the models on the host.
            let graphs = preload_models
                .iter()
                .map(|m| m.build_graph_data(&self.path))
                .collect::<Vec<_>>();
            let (backends, registry) = wasmtime_wasi_nn::preload(&graphs).map_err(|_| {
                errors::WorkerError::RuntimeError(
                    wws_runtimes::errors::RuntimeError::WasiContextError,
                )
            })?;

            Some(Arc::new(WasiNnCtx::new(backends, registry)))
        } else if !allowed_backends.is_empty() {
            let registry = Registry::from(InMemoryRegistry::new());
            let mut backends = Vec::new();

            // Load the given backends:
            for b in allowed_backends.iter() {
                if let Some(backend) = b.to_backend() {
                    backends.push(backend);
                }
            }

            Some(Arc::new(WasiNnCtx::new(backends, registry)))
        } else {
            None
        };

        // Load the Wasi NN linker
        if wasi_nn.is_some() {
            wasmtime_wasi_nn::witx::add_to_linker(&mut linker, |s: &mut WorkerState| {
                Arc::get_mut(s.wasi_nn.as_mut().unwrap())
                    .expect("wasi-nn is not implemented with multi-threading support")
            })
            .map_err(|_| {
                errors::WorkerError::RuntimeError(
                    wws_runtimes::errors::RuntimeError::WasiContextError,
                )
            })?;
        }

        // Pass to the runtime to add any WASI specific requirement
        self.runtime.prepare_wasi_ctx(&mut wasi_builder)?;
5 changes: 5 additions & 0 deletions examples/Makefile
@@ -25,6 +25,11 @@ rust-wasi-nn:
	cargo build --target wasm32-wasi --release && \
	mv target/wasm32-wasi/release/rust-wasi-nn.wasm "./inference.wasm"

rust-wasi-nn-preload:
	cd rust-wasi-nn-preload && \
	cargo build --target wasm32-wasi --release && \
	mv target/wasm32-wasi/release/rust-wasi-nn-preload.wasm "./inference.wasm"

rust-pdf-create:
	cd rust-pdf-create && \
	cargo build --target wasm32-wasi --release && \
4 changes: 4 additions & 0 deletions examples/rust-wasi-nn-preload/.gitignore
@@ -0,0 +1,4 @@
_images/image.jpg
_models/**/*.bin
_models/**/*.xml
!inference.wasm