Skip to content

Commit

Permalink
WIP: Add ratchet-downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
sigma-andex committed Jan 21, 2024
1 parent 50f7b5b commit 2f9670c
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 8 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
members = [
"crates/ratchet-core",
"crates/ratchet-downloader",
"crates/ratchet-integration-tests",
"crates/ratchet-loader",
"crates/ratchet-models",
Expand Down Expand Up @@ -28,6 +29,7 @@ derive-new = "0.6.0"
log = "0.4.20"
thiserror = "1.0.56"
byteorder = "1.5.0"
wasm-bindgen-test = "0.3.34"

[workspace.dev-dependencies]
hf-hub = "0.3.0"
1 change: 1 addition & 0 deletions crates/ratchet-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ glam = "0.25.0"
pollster = "0.3.0"
futures-intrusive = "0.5.0"
anyhow = "1.0.79"
getrandom = { version = "0.2", features = ["js"] } # Needed for wasm support in `num` trait
num = "0.4.1"
rand_distr = { version = "0.4.3", optional = true }
rand = { version = "0.8.4", optional = true }
Expand Down
12 changes: 6 additions & 6 deletions crates/ratchet-core/src/gpu/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl PartialEq for WgpuDevice {
impl WgpuDevice {
pub async fn new() -> Result<Self, DeviceError> {
#[cfg(target_arch = "wasm32")]
let adapter = Self::select_adapter().await;
let adapter = Self::select_adapter().await?;
#[cfg(not(target_arch = "wasm32"))]
let adapter = Self::select_adapter()?;

Expand Down Expand Up @@ -106,7 +106,7 @@ impl WgpuDevice {
}

#[cfg(target_arch = "wasm32")]
async fn select_adapter() -> Adapter {
async fn select_adapter() -> Result<Adapter, DeviceError> {
let instance = wgpu::Instance::default();
let backends = wgpu::util::backend_bits_from_env().unwrap_or(wgpu::Backends::PRIMARY);
instance
Expand All @@ -116,10 +116,10 @@ impl WgpuDevice {
force_fallback_adapter: false,
})
.await
.map_err(|e| {
log::error!("Failed to create device: {:?}", e);
e
})?
.ok_or({
log::error!("Failed to request adapter.");
DeviceError::AdapterRequestFailed
})
}

#[cfg(not(target_arch = "wasm32"))]
Expand Down
34 changes: 34 additions & 0 deletions crates/ratchet-downloader/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
[package]
name = "ratchet-downloader"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ratchet-loader = { path = "../ratchet-loader" }
wasm-bindgen = "0.2.84"
wasm-bindgen-futures = "0.4.39"
js-sys = "0.3.64"
gloo = "0.11.0"
reqwest = "0.11.23"

[dependencies.web-sys]
features = [
'console',
'Headers',
'Request',
'RequestInit',
'RequestMode',
'Response',
'Window',
'Navigator',
'StorageManager',
'CacheStorage'
]
version = "0.3.64"

[dev-dependencies]
wasm-bindgen-test.workspace = true

[lib]
crate-type = ["cdylib", "rlib"]
35 changes: 35 additions & 0 deletions crates/ratchet-downloader/src/fetch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use js_sys::{ArrayBuffer, Uint8Array, JSON};

use wasm_bindgen::{prelude::*, JsValue};
use wasm_bindgen_futures::JsFuture;
use web_sys::{Request, RequestInit, RequestMode, Response};

fn to_error(value: JsValue) -> JsError {
JsError::new(
JSON::stringify(&value)
.map(|js_string| {
js_string
.as_string()
.unwrap_or(String::from("An unknown error occurred."))
})
.unwrap_or(String::from("An unknown error occurred."))
.as_str(),
)
}
pub(crate) async fn fetch(url: &str) -> Result<Response, JsError> {
let mut opts = RequestInit::new();
opts.method("GET");
opts.mode(RequestMode::Cors);

let request = Request::new_with_str_and_init(&url, &opts).map_err(to_error)?;

let window = web_sys::window().unwrap();
let resp_value = JsFuture::from(window.fetch_with_request(&request))
.await
.map_err(to_error)?;

assert!(resp_value.is_instance_of::<Response>());
let resp: Response = resp_value.dyn_into().unwrap();

Ok(resp)
}
1 change: 1 addition & 0 deletions crates/ratchet-downloader/src/huggingface/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod repo;
7 changes: 7 additions & 0 deletions crates/ratchet-downloader/src/huggingface/repo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pub struct Repo {
pub id: String,
pub revision: String,
pub repo_type: String,
}

impl Repo {}
61 changes: 61 additions & 0 deletions crates/ratchet-downloader/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#[cfg(test)]
use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure};

use gloo::console::error as log_error;
use wasm_bindgen::{prelude::*, JsValue};

mod fetch;
pub mod huggingface;

#[cfg(test)]
wasm_bindgen_test_configure!(run_in_browser);

pub struct Model {
url: String,
}

impl Model {
fn from_hf(repo_id: String) -> Self {
Self {
url: format!("https://huggingface.co/{}/resolve/main", repo_id),
}
}

fn from_hf_with_revision(repo_id: String, revision: String) -> Self {
Self {
url: format!("https://huggingface.co/{repo_id}/resolve/{revision}"),
}
}

fn from_custom(url: String) -> Self {
Self { url }
}

async fn get(&self, file_name: String) -> Result<(), JsError> {
let file_url = format!("{}/{}", self.url, file_name);
// let response = fetch::fetch(file_url.as_str()).await?;

let res = reqwest::Client::new()
.get(file_url)
// .header("Accept", "application/vnd.github.v3+json")
.send()
.await?;
Ok(())
}
}

#[cfg(test)]
#[wasm_bindgen_test]
async fn pass() -> Result<(), JsValue> {
use js_sys::JsString;

let model = Model::from_hf("jantxu/ratchet-test".to_string());
let file = model
.get("model.safetensors".to_string())
.await
.map_err(|err| {
log_error!(err);
JsString::from("Failed to download file")
});
Ok(())
}
2 changes: 1 addition & 1 deletion crates/ratchet-integration-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dev-dependencies]
wasm-bindgen-test = "0.3.34"
wasm-bindgen-test.workspace = true
9 changes: 8 additions & 1 deletion justfile
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
line-count:
cd ./crates/ratchet-core && scc -irs --exclude-file kernels
cd ./crates/ratchet-core && scc -irs --exclude-file kernels
install-pyo3:
env PYTHON_CONFIGURE_OPTS="--enable-shared" pyenv install --verbose 3.10.6
echo "Please PYO3_PYTHON to your .bashrc or .zshrc"
wasm CRATE:
RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack build --target web -d `pwd`/target/pkg/{{CRATE}} --out-name {{CRATE}} ./crates/{{CRATE}} --release
wasm-test CRATE:
RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack test --chrome `pwd`/crates/{{CRATE}}

0 comments on commit 2f9670c

Please sign in to comment.