From b8867bd8d1421a0f2d6a37bacecb6822a045b43e Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Sun, 21 Jan 2024 18:30:09 +0000 Subject: [PATCH 01/14] WIP: Add ratchet-downloader --- Cargo.toml | 2 + crates/ratchet-core/Cargo.toml | 1 + crates/ratchet-core/src/gpu/device.rs | 12 ++-- crates/ratchet-downloader/Cargo.toml | 34 +++++++++++ crates/ratchet-downloader/src/fetch.rs | 35 ++++++++++++ .../ratchet-downloader/src/huggingface/mod.rs | 1 + .../src/huggingface/repo.rs | 7 +++ crates/ratchet-downloader/src/lib.rs | 57 +++++++++++++++++++ crates/ratchet-integration-tests/Cargo.toml | 2 +- justfile | 9 ++- 10 files changed, 152 insertions(+), 8 deletions(-) create mode 100644 crates/ratchet-downloader/Cargo.toml create mode 100644 crates/ratchet-downloader/src/fetch.rs create mode 100644 crates/ratchet-downloader/src/huggingface/mod.rs create mode 100644 crates/ratchet-downloader/src/huggingface/repo.rs create mode 100644 crates/ratchet-downloader/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 8d0becb9..52086fe2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/ratchet-core", + "crates/ratchet-downloader", "crates/ratchet-integration-tests", "crates/ratchet-loader", "crates/ratchet-models", @@ -28,6 +29,7 @@ derive-new = "0.6.0" log = "0.4.20" thiserror = "1.0.56" byteorder = "1.5.0" +wasm-bindgen-test = "0.3.34" [workspace.dev-dependencies] hf-hub = "0.3.0" diff --git a/crates/ratchet-core/Cargo.toml b/crates/ratchet-core/Cargo.toml index 97b1619b..cf0b67ff 100644 --- a/crates/ratchet-core/Cargo.toml +++ b/crates/ratchet-core/Cargo.toml @@ -31,6 +31,7 @@ glam = "0.25.0" pollster = "0.3.0" futures-intrusive = "0.5.0" anyhow = "1.0.79" +getrandom = { version = "0.2", features = ["js"] } # Needed for wasm support in `num` trait num = "0.4.1" rand_distr = { version = "0.4.3", optional = true } rand = { version = "0.8.4", optional = true } diff --git a/crates/ratchet-core/src/gpu/device.rs b/crates/ratchet-core/src/gpu/device.rs index fefe081e..0531e666 100644 --- a/crates/ratchet-core/src/gpu/device.rs +++ b/crates/ratchet-core/src/gpu/device.rs @@ -51,7 +51,7 @@ impl PartialEq for WgpuDevice { impl WgpuDevice { pub async fn new() -> Result { #[cfg(target_arch = "wasm32")] - let adapter = Self::select_adapter().await; + let adapter = Self::select_adapter().await?; #[cfg(not(target_arch = "wasm32"))] let adapter = Self::select_adapter()?; @@ -106,7 +106,7 @@ impl WgpuDevice { } #[cfg(target_arch = "wasm32")] - async fn select_adapter() -> Adapter { + async fn select_adapter() -> Result { let instance = wgpu::Instance::default(); let backends = wgpu::util::backend_bits_from_env().unwrap_or(wgpu::Backends::PRIMARY); instance @@ -116,10 +116,10 @@ impl WgpuDevice { force_fallback_adapter: false, }) .await - .map_err(|e| { - log::error!("Failed to create device: {:?}", e); - e - })? + .ok_or({ + log::error!("Failed to request adapter."); + DeviceError::AdapterRequestFailed + }) } #[cfg(not(target_arch = "wasm32"))] diff --git a/crates/ratchet-downloader/Cargo.toml b/crates/ratchet-downloader/Cargo.toml new file mode 100644 index 00000000..df9e5075 --- /dev/null +++ b/crates/ratchet-downloader/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "ratchet-downloader" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[dependencies] +ratchet-loader = { path = "../ratchet-loader" } +wasm-bindgen = "0.2.84" +wasm-bindgen-futures = "0.4.39" +js-sys = "0.3.64" +gloo = "0.11.0" +reqwest = "0.11.23" + +[dependencies.web-sys] +features = [ + 'console', + 'Headers', + 'Request', + 'RequestInit', + 'RequestMode', + 'Response', + 'Window', + 'Navigator', + 'StorageManager', + 'CacheStorage' +] +version = "0.3.64" + +[dev-dependencies] +wasm-bindgen-test.workspace = true + +[lib] +crate-type = ["cdylib", "rlib"] diff --git a/crates/ratchet-downloader/src/fetch.rs b/crates/ratchet-downloader/src/fetch.rs new file mode 100644 index 00000000..bfa51da4 --- /dev/null +++ b/crates/ratchet-downloader/src/fetch.rs @@ -0,0 +1,35 @@ +use js_sys::{ArrayBuffer, Uint8Array, JSON}; + +use wasm_bindgen::{prelude::*, JsValue}; +use wasm_bindgen_futures::JsFuture; +use web_sys::{Request, RequestInit, RequestMode, Response}; + +fn to_error(value: JsValue) -> JsError { + JsError::new( + JSON::stringify(&value) + .map(|js_string| { + js_string + .as_string() + .unwrap_or(String::from("An unknown error occurred.")) + }) + .unwrap_or(String::from("An unknown error occurred.")) + .as_str(), + ) +} +pub(crate) async fn fetch(url: &str) -> Result { + let mut opts = RequestInit::new(); + opts.method("GET"); + opts.mode(RequestMode::Cors); + + let request = Request::new_with_str_and_init(&url, &opts).map_err(to_error)?; + + let window = web_sys::window().unwrap(); + let resp_value = JsFuture::from(window.fetch_with_request(&request)) + .await + .map_err(to_error)?; + + assert!(resp_value.is_instance_of::()); + let resp: Response = resp_value.dyn_into().unwrap(); + + Ok(resp) +} diff --git a/crates/ratchet-downloader/src/huggingface/mod.rs b/crates/ratchet-downloader/src/huggingface/mod.rs new file mode 100644 index 00000000..c426b23e --- /dev/null +++ b/crates/ratchet-downloader/src/huggingface/mod.rs @@ -0,0 +1 @@ +pub mod repo; diff --git a/crates/ratchet-downloader/src/huggingface/repo.rs b/crates/ratchet-downloader/src/huggingface/repo.rs new file mode 100644 index 00000000..70a8c16c --- /dev/null +++ b/crates/ratchet-downloader/src/huggingface/repo.rs @@ -0,0 +1,7 @@ +pub struct Repo { + pub id: String, + pub revision: String, + pub repo_type: String, +} + +impl Repo {} diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs new file mode 100644 index 00000000..61688506 --- /dev/null +++ b/crates/ratchet-downloader/src/lib.rs @@ -0,0 +1,57 @@ +#[cfg(test)] +use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure}; + +use gloo::console::error as log_error; +use wasm_bindgen::{prelude::*, JsValue}; + +mod fetch; +pub mod huggingface; + +#[cfg(test)] +wasm_bindgen_test_configure!(run_in_browser); + +pub struct Model { + url: String, +} + +impl Model { + fn from_hf(repo_id: String) -> Self { + Self { + url: format!("https://huggingface.co/{}/resolve/main", repo_id), + } + } + + fn from_hf_with_revision(repo_id: String, revision: String) -> Self { + Self { + url: format!("https://huggingface.co/{repo_id}/resolve/{revision}"), + } + } + + fn from_custom(url: String) -> Self { + Self { url } + } + + async fn get(&self, file_name: String) -> Result<(), JsError> { + let file_url = format!("{}/{}", self.url, file_name); + // let response = fetch::fetch(file_url.as_str()).await?; + + let res = reqwest::Client::new().get(file_url).send().await?; + Ok(()) + } +} + +#[cfg(test)] +#[wasm_bindgen_test] +async fn pass() -> Result<(), JsValue> { + use js_sys::JsString; + + let model = Model::from_hf("jantxu/ratchet-test".to_string()); + let file = model + .get("model.safetensors".to_string()) + .await + .map_err(|err| { + log_error!(err); + JsString::from("Failed to download file") + }); + Ok(()) +} diff --git a/crates/ratchet-integration-tests/Cargo.toml b/crates/ratchet-integration-tests/Cargo.toml index 3c44d58d..b3826de9 100644 --- a/crates/ratchet-integration-tests/Cargo.toml +++ b/crates/ratchet-integration-tests/Cargo.toml @@ -6,4 +6,4 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dev-dependencies] -wasm-bindgen-test = "0.3.34" +wasm-bindgen-test.workspace = true diff --git a/justfile b/justfile index 99357e51..025f505d 100644 --- a/justfile +++ b/justfile @@ -1,2 +1,9 @@ line-count: - cd ./crates/ratchet-core && scc -irs --exclude-file kernels + cd ./crates/ratchet-core && scc -irs --exclude-file kernels +install-pyo3: + env PYTHON_CONFIGURE_OPTS="--enable-shared" pyenv install --verbose 3.10.6 + echo "Please PYO3_PYTHON to your .bashrc or .zshrc" +wasm CRATE: + RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack build --target web -d `pwd`/target/pkg/{{CRATE}} --out-name {{CRATE}} ./crates/{{CRATE}} --release +wasm-test CRATE: + RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack test --chrome `pwd`/crates/{{CRATE}} From 02dad10a4346f241e4bcf774ad9973c1cc8a1df0 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Tue, 23 Jan 2024 19:37:58 +0000 Subject: [PATCH 02/14] First working test byob reader + stream parsing --- crates/ratchet-core/src/kernels.rs | 6 ++-- crates/ratchet-downloader/Cargo.toml | 7 +++- crates/ratchet-downloader/src/lib.rs | 53 +++++++++++++++++++++++----- justfile | 4 ++- 4 files changed, 56 insertions(+), 14 deletions(-) diff --git a/crates/ratchet-core/src/kernels.rs b/crates/ratchet-core/src/kernels.rs index 92408394..45908dfc 100644 --- a/crates/ratchet-core/src/kernels.rs +++ b/crates/ratchet-core/src/kernels.rs @@ -7,19 +7,19 @@ lazy_static! { m.insert( "qgemm_vec4", include_str!( - "/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/qgemm_vec4.wgsl" + "/Users/janschulte/code/ratchet/crates/ratchet-core/kernels/qgemm_vec4.wgsl" ), ); m.insert( "sgemm_scalar", include_str!( - "/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_scalar.wgsl" + "/Users/janschulte/code/ratchet/crates/ratchet-core/kernels/sgemm_scalar.wgsl" ), ); m.insert( "add_scalar", include_str!( - "/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/add_scalar.wgsl" + "/Users/janschulte/code/ratchet/crates/ratchet-core/kernels/add_scalar.wgsl" ), ); m diff --git a/crates/ratchet-downloader/Cargo.toml b/crates/ratchet-downloader/Cargo.toml index df9e5075..9eadb85e 100644 --- a/crates/ratchet-downloader/Cargo.toml +++ b/crates/ratchet-downloader/Cargo.toml @@ -10,7 +10,9 @@ wasm-bindgen = "0.2.84" wasm-bindgen-futures = "0.4.39" js-sys = "0.3.64" gloo = "0.11.0" -reqwest = "0.11.23" +wasm-streams = "0.4.0" +futures-util = { version = "^0.3.28", features = ["io", "sink"] } +winnow = "0.5.34" [dependencies.web-sys] features = [ @@ -20,6 +22,9 @@ features = [ 'RequestInit', 'RequestMode', 'Response', + 'ReadableStream', + 'ReadableStreamGetReaderOptions', + 'ReadableStreamReaderMode', 'Window', 'Navigator', 'StorageManager', diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index 61688506..bf6ab344 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -1,8 +1,15 @@ +use js_sys::Uint8Array; #[cfg(test)] use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure}; -use gloo::console::error as log_error; -use wasm_bindgen::{prelude::*, JsValue}; +use futures_util::{AsyncReadExt, StreamExt}; +use gloo::console::{debug, error as log_error}; +use js_sys::JsString; +use wasm_bindgen::{prelude::*, JsCast, JsValue}; +use wasm_bindgen_futures::JsFuture; +use wasm_streams::ReadableStream; +use web_sys::{console, ReadableStreamGetReaderOptions, ReadableStreamReaderMode}; +use winnow::{binary::bits::bytes, prelude::*, stream::Stream, Bytes, Partial}; mod fetch; pub mod huggingface; @@ -10,6 +17,13 @@ pub mod huggingface; #[cfg(test)] wasm_bindgen_test_configure!(run_in_browser); +#[wasm_bindgen] +pub fn js_error(message: String) -> JsError { + JsError::new(message.as_str()) +} + +type GGUFStream<'i> = Partial<&'i Bytes>; + pub struct Model { url: String, } @@ -31,11 +45,29 @@ impl Model { Self { url } } - async fn get(&self, file_name: String) -> Result<(), JsError> { + async fn open_stream(&self, file_name: String) -> Result<(), JsError> { let file_url = format!("{}/{}", self.url, file_name); - // let response = fetch::fetch(file_url.as_str()).await?; + let response = fetch::fetch(file_url.as_str()).await?; + + let raw_body = response + .body() + .ok_or(js_error(format!("Failed to load {}", file_name)))?; + + let mut body = ReadableStream::from_raw(raw_body); + let reader = body.get_byob_reader(); + let mut async_read = reader.into_async_read(); + + let mut buf = [0u8; 100]; + let result = async_read.read_exact(&mut buf).await?; + + let mut test = GGUFStream::new(Bytes::new(&buf)); + + let g1 = &test.next_token(); + let g2 = &test.next_token(); + let u = &test.next_token(); + let f = &test.next_token(); + debug!("Done!:", format!("{:?}{:?}{:?}{:?}", g1, g2, u, f)); - let res = reqwest::Client::new().get(file_url).send().await?; Ok(()) } } @@ -45,13 +77,16 @@ impl Model { async fn pass() -> Result<(), JsValue> { use js_sys::JsString; - let model = Model::from_hf("jantxu/ratchet-test".to_string()); - let file = model - .get("model.safetensors".to_string()) + let model = Model::from_custom("http://localhost:8888".to_string()); + let stream = model + .open_stream( + "TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + .to_string(), + ) .await .map_err(|err| { log_error!(err); JsString::from("Failed to download file") - }); + })?; Ok(()) } diff --git a/justfile b/justfile index 025f505d..b792abef 100644 --- a/justfile +++ b/justfile @@ -6,4 +6,6 @@ install-pyo3: wasm CRATE: RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack build --target web -d `pwd`/target/pkg/{{CRATE}} --out-name {{CRATE}} ./crates/{{CRATE}} --release wasm-test CRATE: - RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack test --chrome `pwd`/crates/{{CRATE}} + RUSTFLAGS="--cfg=web_sys_unstable_apis -Z threads=8" wasm-pack test --chrome `pwd`/crates/{{CRATE}} +wasm-test-headless CRATE: + RUSTFLAGS="--cfg=web_sys_unstable_apis -Z threads=8" wasm-pack test --chrome `pwd`/crates/{{CRATE}} From b164c33ea358490921fde4ff052a745b79895eac Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Wed, 24 Jan 2024 14:53:41 +0000 Subject: [PATCH 03/14] Start with header parsing --- crates/ratchet-downloader/Cargo.toml | 2 + crates/ratchet-downloader/src/lib.rs | 101 ++++++++++++++++++++++++++- 2 files changed, 101 insertions(+), 2 deletions(-) diff --git a/crates/ratchet-downloader/Cargo.toml b/crates/ratchet-downloader/Cargo.toml index 9eadb85e..33fd6bcb 100644 --- a/crates/ratchet-downloader/Cargo.toml +++ b/crates/ratchet-downloader/Cargo.toml @@ -13,6 +13,8 @@ gloo = "0.11.0" wasm-streams = "0.4.0" futures-util = { version = "^0.3.28", features = ["io", "sink"] } winnow = "0.5.34" +circular = "0.3.0" +anyhow.workspace = true [dependencies.web-sys] features = [ diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index bf6ab344..b19af831 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -10,6 +10,7 @@ use wasm_bindgen_futures::JsFuture; use wasm_streams::ReadableStream; use web_sys::{console, ReadableStreamGetReaderOptions, ReadableStreamReaderMode}; use winnow::{binary::bits::bytes, prelude::*, stream::Stream, Bytes, Partial}; +use winnow::{binary::u32, binary::u64, combinator::preceded, Parser}; mod fetch; pub mod huggingface; @@ -22,7 +23,7 @@ pub fn js_error(message: String) -> JsError { JsError::new(message.as_str()) } -type GGUFStream<'i> = Partial<&'i Bytes>; +type BytesStream<'i> = Partial<&'i Bytes>; pub struct Model { url: String, @@ -60,7 +61,7 @@ impl Model { let mut buf = [0u8; 100]; let result = async_read.read_exact(&mut buf).await?; - let mut test = GGUFStream::new(Bytes::new(&buf)); + let mut test = BytesStream::new(Bytes::new(&buf)); let g1 = &test.next_token(); let g2 = &test.next_token(); @@ -72,6 +73,102 @@ impl Model { } } +mod gguf { + use crate::BytesStream; + use winnow::binary::u32; + use winnow::binary::u64; + use winnow::binary::Endianness; + use winnow::Parser; + + #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + pub struct Header { + pub version: u32, + pub tensor_count: u64, + pub metadata_kv_count: u64, + } + + #[inline] + fn magic_number(input: &mut BytesStream) -> winnow::PResult<()> { + // [TODO] Fix endianness + (71, 71, 85, 70).parse_next(input).map(|_magic_number| ()) + } + + #[inline] + fn version(input: &mut BytesStream) -> winnow::PResult { + u32(Endianness::Little).parse_next(input) + } + + #[inline] + fn tensor_count(input: &mut BytesStream) -> winnow::PResult { + u64(Endianness::Little).parse_next(input) + } + + #[inline] + fn metadata_kv_count(input: &mut BytesStream) -> winnow::PResult { + u64(Endianness::Little).parse_next(input) + } + + #[inline] + fn metadata_kv(input: &mut BytesStream) -> winnow::PResult { + u64(Endianness::Little).parse_next(input) + } + + pub fn parse_header(input: &mut BytesStream) -> winnow::PResult
{ + (magic_number, version, tensor_count, metadata_kv_count) + .parse_next(input) + .map(|(gguf, version, tensor_count, metadata_kv_count)| Header { + version, + tensor_count, + metadata_kv_count, + }) + } +} + +pub fn to_std_error(error: winnow::error::ErrMode) -> std::io::Error { + match error { + winnow::error::ErrMode::Backtrack(err) => { + std::io::Error::new(std::io::ErrorKind::Other, "Backtrack") + } + winnow::error::ErrMode::Cut(err) => std::io::Error::new(std::io::ErrorKind::Other, "Cut"), + winnow::error::ErrMode::Incomplete(needed) => { + std::io::Error::new(std::io::ErrorKind::Other, "Needed") + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Read; + + use anyhow::Error; + use winnow::Bytes; + + use crate::{gguf, to_std_error, BytesStream}; + + #[test] + fn test_parse_header() -> anyhow::Result<()> { + let mut file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; + + let buffer_size = 30; + let min_buffer_growth = 100; + let buffer_growth_factor = 2; + let mut buffer = circular::Buffer::with_capacity(buffer_size); + let read = file.read(buffer.space())?; + buffer.fill(read); + + let mut input = BytesStream::new(Bytes::new(buffer.data())); + + let result = gguf::parse_header(&mut input).map_err(to_std_error)?; + let expected = gguf::Header { + version: 3, + tensor_count: 201, + metadata_kv_count: 23, + }; + assert_eq!(result, expected); + Ok(()) + } +} + #[cfg(test)] #[wasm_bindgen_test] async fn pass() -> Result<(), JsValue> { From 3340889ab2ae25bf5908c9ad22a7aec77ff04b12 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:20:06 +0000 Subject: [PATCH 04/14] Prepare metdata_kv parsing --- crates/ratchet-downloader/src/lib.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index b19af831..8ce29d2f 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -78,6 +78,7 @@ mod gguf { use winnow::binary::u32; use winnow::binary::u64; use winnow::binary::Endianness; + use winnow::error::ContextError; use winnow::Parser; #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -109,18 +110,25 @@ mod gguf { } #[inline] - fn metadata_kv(input: &mut BytesStream) -> winnow::PResult { - u64(Endianness::Little).parse_next(input) + fn metadata_kv<'i>(metadata_kv_count: u64) -> impl Parser, u64, ContextError> { + move |input: &mut BytesStream| u64(Endianness::Little).parse_next(input) } pub fn parse_header(input: &mut BytesStream) -> winnow::PResult
{ (magic_number, version, tensor_count, metadata_kv_count) - .parse_next(input) - .map(|(gguf, version, tensor_count, metadata_kv_count)| Header { - version, - tensor_count, - metadata_kv_count, + .flat_map(|(gguf, version, tensor_count, metadata_kv_count)| { + metadata_kv(metadata_kv_count).map(move |metadata_kv| { + (gguf, version, tensor_count, metadata_kv_count, metadata_kv) + }) }) + .parse_next(input) + .map( + |(gguf, version, tensor_count, metadata_kv_count, metadata_kv)| Header { + version, + tensor_count, + metadata_kv_count, + }, + ) } } @@ -149,7 +157,7 @@ mod tests { fn test_parse_header() -> anyhow::Result<()> { let mut file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; - let buffer_size = 30; + let buffer_size = 40; let min_buffer_growth = 100; let buffer_growth_factor = 2; let mut buffer = circular::Buffer::with_capacity(buffer_size); From f7a44560f0bfafa9d2a8b8b9c988398fe47dfb1b Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Wed, 24 Jan 2024 20:50:27 +0000 Subject: [PATCH 05/14] Parse string & metadata_kv --- crates/ratchet-downloader/src/lib.rs | 66 +++++++++++++++++++--------- 1 file changed, 46 insertions(+), 20 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index 8ce29d2f..e9ff236e 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -75,49 +75,75 @@ impl Model { mod gguf { use crate::BytesStream; - use winnow::binary::u32; - use winnow::binary::u64; - use winnow::binary::Endianness; - use winnow::error::ContextError; + use winnow::binary::{u32, u64, u8, Endianness}; + + use winnow::error::{AddContext, ContextError, ErrMode, StrContext}; + use winnow::token::take; use winnow::Parser; - #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + pub struct MetadataKv { + pub key: String, + } + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Header { pub version: u32, pub tensor_count: u64, pub metadata_kv_count: u64, + pub metadata_kv: MetadataKv, } #[inline] - fn magic_number(input: &mut BytesStream) -> winnow::PResult<()> { + fn parse_magic_number(input: &mut BytesStream) -> winnow::PResult<()> { // [TODO] Fix endianness (71, 71, 85, 70).parse_next(input).map(|_magic_number| ()) } #[inline] - fn version(input: &mut BytesStream) -> winnow::PResult { + fn parse_version(input: &mut BytesStream) -> winnow::PResult { u32(Endianness::Little).parse_next(input) } #[inline] - fn tensor_count(input: &mut BytesStream) -> winnow::PResult { + fn parse_tensor_count(input: &mut BytesStream) -> winnow::PResult { u64(Endianness::Little).parse_next(input) } #[inline] - fn metadata_kv_count(input: &mut BytesStream) -> winnow::PResult { + fn parse_metadata_kv_count(input: &mut BytesStream) -> winnow::PResult { u64(Endianness::Little).parse_next(input) } + fn parse_string(input: &mut BytesStream) -> winnow::PResult { + u64(Endianness::Little) + .flat_map(|count| take(count)) + .parse_next(input) + .and_then(|bytes| { + String::from_utf8(bytes.to_vec()).map_err(|err| { + ErrMode::Cut( + ContextError::new() + .add_context(input, StrContext::Label("Failed to parse string")), + ) + }) + }) + } + #[inline] - fn metadata_kv<'i>(metadata_kv_count: u64) -> impl Parser, u64, ContextError> { - move |input: &mut BytesStream| u64(Endianness::Little).parse_next(input) + fn parse_metadata_kv<'i>( + metadata_kv_count: u64, + ) -> impl Parser, MetadataKv, ContextError> { + move |input: &mut BytesStream| parse_string.parse_next(input).map(|key| MetadataKv { key }) } pub fn parse_header(input: &mut BytesStream) -> winnow::PResult
{ - (magic_number, version, tensor_count, metadata_kv_count) + ( + parse_magic_number, + parse_version, + parse_tensor_count, + parse_metadata_kv_count, + ) .flat_map(|(gguf, version, tensor_count, metadata_kv_count)| { - metadata_kv(metadata_kv_count).map(move |metadata_kv| { + parse_metadata_kv(metadata_kv_count).map(move |metadata_kv| { (gguf, version, tensor_count, metadata_kv_count, metadata_kv) }) }) @@ -127,6 +153,7 @@ mod gguf { version, tensor_count, metadata_kv_count, + metadata_kv, }, ) } @@ -157,7 +184,7 @@ mod tests { fn test_parse_header() -> anyhow::Result<()> { let mut file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; - let buffer_size = 40; + let buffer_size = 100; let min_buffer_growth = 100; let buffer_growth_factor = 2; let mut buffer = circular::Buffer::with_capacity(buffer_size); @@ -167,12 +194,11 @@ mod tests { let mut input = BytesStream::new(Bytes::new(buffer.data())); let result = gguf::parse_header(&mut input).map_err(to_std_error)?; - let expected = gguf::Header { - version: 3, - tensor_count: 201, - metadata_kv_count: 23, - }; - assert_eq!(result, expected); + + println!("{}", result.metadata_kv.key); + assert_eq!(result.version, 3); + assert_eq!(result.tensor_count, 201); + assert_eq!(result.metadata_kv_count, 23); Ok(()) } } From d1c6d9931ce505d4df469bbf421596cb6be27738 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Wed, 24 Jan 2024 21:21:39 +0000 Subject: [PATCH 06/14] Add metadata value type parsing --- crates/ratchet-downloader/src/lib.rs | 100 ++++++++++++++++++++++----- 1 file changed, 82 insertions(+), 18 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index e9ff236e..34917aee 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -84,6 +84,7 @@ mod gguf { #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct MetadataKv { pub key: String, + pub value_type: MetadataValueType, } #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Header { @@ -93,6 +94,40 @@ mod gguf { pub metadata_kv: MetadataKv, } + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + pub enum MetadataValueType { + // The value is a 8-bit unsigned integer. + GGUF_METADATA_VALUE_TYPE_UINT8 = 0, + // The value is a 8-bit signed integer. + GGUF_METADATA_VALUE_TYPE_INT8 = 1, + // The value is a 16-bit unsigned little-endian integer. + GGUF_METADATA_VALUE_TYPE_UINT16 = 2, + // The value is a 16-bit signed little-endian integer. + GGUF_METADATA_VALUE_TYPE_INT16 = 3, + // The value is a 32-bit unsigned little-endian integer. + GGUF_METADATA_VALUE_TYPE_UINT32 = 4, + // The value is a 32-bit signed little-endian integer. + GGUF_METADATA_VALUE_TYPE_INT32 = 5, + // The value is a 32-bit IEEE754 floating point number. + GGUF_METADATA_VALUE_TYPE_FLOAT32 = 6, + // The value is a boolean. + // 1-byte value where 0 is false and 1 is true. + // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. + GGUF_METADATA_VALUE_TYPE_BOOL = 7, + // The value is a UTF-8 non-null-terminated string, with length prepended. + GGUF_METADATA_VALUE_TYPE_STRING = 8, + // The value is an array of other values, with the length and type prepended. + /// + // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. + GGUF_METADATA_VALUE_TYPE_ARRAY = 9, + // The value is a 64-bit unsigned little-endian integer. + GGUF_METADATA_VALUE_TYPE_UINT64 = 10, + // The value is a 64-bit signed little-endian integer. + GGUF_METADATA_VALUE_TYPE_INT64 = 11, + // The value is a 64-bit IEEE754 floating point number. + GGUF_METADATA_VALUE_TYPE_FLOAT64 = 12, + } + #[inline] fn parse_magic_number(input: &mut BytesStream) -> winnow::PResult<()> { // [TODO] Fix endianness @@ -109,6 +144,28 @@ mod gguf { u64(Endianness::Little).parse_next(input) } + #[inline] + fn parse_metadata_value_type(input: &mut BytesStream) -> winnow::PResult { + u32(Endianness::Little) + .parse_next(input) + .and_then(|value| match value { + 0 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_UINT8), + 1 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_INT8), + 2 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_UINT16), + 3 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_INT16), + 4 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_UINT32), + 5 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_INT32), + 6 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_FLOAT32), + 7 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_BOOL), + 8 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_STRING), + 9 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_ARRAY), + 10 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_UINT64), + 11 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_INT64), + 12 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_FLOAT64), + other => Err(cut_error(input, &"Found invalid metadata type value.")), + }) + } + #[inline] fn parse_metadata_kv_count(input: &mut BytesStream) -> winnow::PResult { u64(Endianness::Little).parse_next(input) @@ -120,19 +177,28 @@ mod gguf { .parse_next(input) .and_then(|bytes| { String::from_utf8(bytes.to_vec()).map_err(|err| { - ErrMode::Cut( - ContextError::new() - .add_context(input, StrContext::Label("Failed to parse string")), - ) + let error_msg = "Failed to parse string"; + cut_error(input, error_msg) }) }) } + fn cut_error( + input: &mut winnow::Partial<&winnow::Bytes>, + error_msg: &'static str, + ) -> ErrMode { + ErrMode::Cut(ContextError::new().add_context(input, StrContext::Label(error_msg))) + } + #[inline] fn parse_metadata_kv<'i>( metadata_kv_count: u64, ) -> impl Parser, MetadataKv, ContextError> { - move |input: &mut BytesStream| parse_string.parse_next(input).map(|key| MetadataKv { key }) + move |input: &mut BytesStream| { + (parse_string, parse_metadata_value_type) + .parse_next(input) + .map(|(key, value_type)| MetadataKv { key, value_type }) + } } pub fn parse_header(input: &mut BytesStream) -> winnow::PResult
{ @@ -157,16 +223,13 @@ mod gguf { }, ) } -} - -pub fn to_std_error(error: winnow::error::ErrMode) -> std::io::Error { - match error { - winnow::error::ErrMode::Backtrack(err) => { - std::io::Error::new(std::io::ErrorKind::Other, "Backtrack") - } - winnow::error::ErrMode::Cut(err) => std::io::Error::new(std::io::ErrorKind::Other, "Cut"), - winnow::error::ErrMode::Incomplete(needed) => { - std::io::Error::new(std::io::ErrorKind::Other, "Needed") + pub fn to_std_error( + error: winnow::error::ErrMode, + ) -> std::io::Error { + match error { + ErrMode::Backtrack(err) => std::io::Error::new(std::io::ErrorKind::Other, "Backtrack"), + ErrMode::Cut(err) => std::io::Error::new(std::io::ErrorKind::Other, "Cut"), + ErrMode::Incomplete(needed) => std::io::Error::new(std::io::ErrorKind::Other, "Needed"), } } } @@ -178,7 +241,7 @@ mod tests { use anyhow::Error; use winnow::Bytes; - use crate::{gguf, to_std_error, BytesStream}; + use crate::{gguf, BytesStream}; #[test] fn test_parse_header() -> anyhow::Result<()> { @@ -193,9 +256,10 @@ mod tests { let mut input = BytesStream::new(Bytes::new(buffer.data())); - let result = gguf::parse_header(&mut input).map_err(to_std_error)?; + let result = gguf::parse_header(&mut input).map_err(gguf::to_std_error)?; - println!("{}", result.metadata_kv.key); + println!("{:#?}", result.metadata_kv); + println!("{:#?}", result.metadata_kv.value_type); assert_eq!(result.version, 3); assert_eq!(result.tensor_count, 201); assert_eq!(result.metadata_kv_count, 23); From 9fee16901fe8e793bf440c3b1694d64e4fd8fba9 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Thu, 25 Jan 2024 20:21:05 +0000 Subject: [PATCH 07/14] Parse metadata kv types --- crates/ratchet-downloader/src/lib.rs | 237 ++++++++++++++++++++------- 1 file changed, 174 insertions(+), 63 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index 34917aee..cf879a03 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -75,57 +75,58 @@ impl Model { mod gguf { use crate::BytesStream; + use futures_util::io::repeat; use winnow::binary::{u32, u64, u8, Endianness}; + use winnow::combinator::fail; use winnow::error::{AddContext, ContextError, ErrMode, StrContext}; use winnow::token::take; use winnow::Parser; - #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + #[derive(Clone, Debug)] pub struct MetadataKv { pub key: String, - pub value_type: MetadataValueType, + pub metadata_value: MetadataValue, } - #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + #[derive(Clone, Debug)] pub struct Header { pub version: u32, pub tensor_count: u64, pub metadata_kv_count: u64, pub metadata_kv: MetadataKv, } - - #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] + #[derive(Clone, Debug)] pub enum MetadataValueType { - // The value is a 8-bit unsigned integer. - GGUF_METADATA_VALUE_TYPE_UINT8 = 0, - // The value is a 8-bit signed integer. - GGUF_METADATA_VALUE_TYPE_INT8 = 1, - // The value is a 16-bit unsigned little-endian integer. - GGUF_METADATA_VALUE_TYPE_UINT16 = 2, - // The value is a 16-bit signed little-endian integer. - GGUF_METADATA_VALUE_TYPE_INT16 = 3, - // The value is a 32-bit unsigned little-endian integer. - GGUF_METADATA_VALUE_TYPE_UINT32 = 4, - // The value is a 32-bit signed little-endian integer. - GGUF_METADATA_VALUE_TYPE_INT32 = 5, - // The value is a 32-bit IEEE754 floating point number. - GGUF_METADATA_VALUE_TYPE_FLOAT32 = 6, - // The value is a boolean. - // 1-byte value where 0 is false and 1 is true. - // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. - GGUF_METADATA_VALUE_TYPE_BOOL = 7, - // The value is a UTF-8 non-null-terminated string, with length prepended. - GGUF_METADATA_VALUE_TYPE_STRING = 8, - // The value is an array of other values, with the length and type prepended. - /// - // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. - GGUF_METADATA_VALUE_TYPE_ARRAY = 9, - // The value is a 64-bit unsigned little-endian integer. - GGUF_METADATA_VALUE_TYPE_UINT64 = 10, - // The value is a 64-bit signed little-endian integer. - GGUF_METADATA_VALUE_TYPE_INT64 = 11, - // The value is a 64-bit IEEE754 floating point number. - GGUF_METADATA_VALUE_TYPE_FLOAT64 = 12, + GgufMetadataValueTypeUint8, + GgufMetadataValueTypeInt8, + GgufMetadataValueTypeUint16, + GgufMetadataValueTypeInt16, + GgufMetadataValueTypeUint32, + GgufMetadataValueTypeInt32, + GgufMetadataValueTypeFloat32, + GgufMetadataValueTypeBool, + GgufMetadataValueTypeString, + GgufMetadataValueTypeArray, + GgufMetadataValueTypeUint64, + GgufMetadataValueTypeInt64, + GgufMetadataValueTypeFloat64, + } + + #[derive(Clone, Debug)] + pub enum MetadataValue { + GgufMetadataValueUint8(u8), + GgufMetadataValueInt8(i8), + GgufMetadataValueUint16(u16), + GgufMetadataValueInt16(i16), + GgufMetadataValueUint32(u32), + GgufMetadataValueInt32(i32), + GgufMetadataValueFloat32(f32), + GgufMetadataValueBool(bool), + GgufMetadataValueString(String), + GgufMetadataValueArray(Vec), + GgufMetadataValueUint64(u64), + GgufMetadataValueInt64(i64), + GgufMetadataValueFloat64(f64), } #[inline] @@ -144,28 +145,141 @@ mod gguf { u64(Endianness::Little).parse_next(input) } - #[inline] + fn parse_metadata_value_array(input: &mut BytesStream) -> winnow::PResult { + (parse_metadata_value_type, u64(Endianness::Little)) + .flat_map(|(metadata_value_type, length)| { + winnow::combinator::repeat( + length as usize, + parse_metadata_value(metadata_value_type), + ) + }) + .parse_next(input) + .map(MetadataValue::GgufMetadataValueArray) + } + fn parse_metadata_value_type(input: &mut BytesStream) -> winnow::PResult { u32(Endianness::Little) .parse_next(input) - .and_then(|value| match value { - 0 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_UINT8), - 1 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_INT8), - 2 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_UINT16), - 3 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_INT16), - 4 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_UINT32), - 5 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_INT32), - 6 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_FLOAT32), - 7 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_BOOL), - 8 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_STRING), - 9 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_ARRAY), - 10 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_UINT64), - 11 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_INT64), - 12 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_FLOAT64), - other => Err(cut_error(input, &"Found invalid metadata type value.")), + .and_then(|metadata_value_type| match metadata_value_type { + 0 => Ok(MetadataValueType::GgufMetadataValueTypeUint8), + 1 => Ok(MetadataValueType::GgufMetadataValueTypeInt8), + 2 => Ok(MetadataValueType::GgufMetadataValueTypeUint16), + 3 => Ok(MetadataValueType::GgufMetadataValueTypeInt16), + 4 => Ok(MetadataValueType::GgufMetadataValueTypeUint32), + 5 => Ok(MetadataValueType::GgufMetadataValueTypeInt32), + 6 => Ok(MetadataValueType::GgufMetadataValueTypeFloat32), + 7 => Ok(MetadataValueType::GgufMetadataValueTypeBool), + 8 => Ok(MetadataValueType::GgufMetadataValueTypeString), + 9 => Ok(MetadataValueType::GgufMetadataValueTypeArray), + 10 => Ok(MetadataValueType::GgufMetadataValueTypeUint64), + 11 => Ok(MetadataValueType::GgufMetadataValueTypeInt64), + 12 => Ok(MetadataValueType::GgufMetadataValueTypeFloat64), + other => Err(cut_error(input, "Unknown metadata value type.")), }) } + #[inline] + fn parse_metadata_value<'i>( + metadata_value_type: MetadataValueType, + ) -> impl Parser, MetadataValue, ContextError> { + move |input: &mut BytesStream| match metadata_value_type { + MetadataValueType::GgufMetadataValueTypeUint8 => winnow::binary::u8 + .map(MetadataValue::GgufMetadataValueUint8) + .parse_next(input), + + MetadataValueType::GgufMetadataValueTypeInt8 => winnow::binary::i8 + .map(MetadataValue::GgufMetadataValueInt8) + .parse_next(input), + MetadataValueType::GgufMetadataValueTypeUint16 => { + winnow::binary::u16(Endianness::Little) + .map(MetadataValue::GgufMetadataValueUint16) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeInt16 => { + winnow::binary::i16(Endianness::Little) + .map(MetadataValue::GgufMetadataValueInt16) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeUint32 => { + winnow::binary::u32(Endianness::Little) + .map(MetadataValue::GgufMetadataValueUint32) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeInt32 => { + winnow::binary::i32(Endianness::Little) + .map(MetadataValue::GgufMetadataValueInt32) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeFloat32 => { + winnow::binary::f32(Endianness::Little) + .map(MetadataValue::GgufMetadataValueFloat32) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeBool => winnow::binary::i8 + .map(|b| { + if b == 0 { + MetadataValue::GgufMetadataValueBool(true) + } else { + MetadataValue::GgufMetadataValueBool(false) + } + }) + .parse_next(input), + MetadataValueType::GgufMetadataValueTypeString => parse_string + .map(MetadataValue::GgufMetadataValueString) + .parse_next(input), + MetadataValueType::GgufMetadataValueTypeArray => { + parse_metadata_value_array.parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeUint64 => { + winnow::binary::u64(Endianness::Little) + .map(MetadataValue::GgufMetadataValueUint64) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeInt64 => { + winnow::binary::i64(Endianness::Little) + .map(MetadataValue::GgufMetadataValueInt64) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeFloat64 => { + winnow::binary::f64(Endianness::Little) + .map(MetadataValue::GgufMetadataValueFloat64) + .parse_next(input) + } + } + } + + // #[inline] + // fn parse_metadata_value_type<'i>( + // metadata_value_type: u64, + // ) -> impl Parser, MetadataKv, ContextError> { + // move |input: &mut BytesStream| { + // let parser: Parser, MetadataKv, ContextError> = + // match metadata_value_type { + // 0 => u8.map(MetadataValue::GgufMetadataValueTypeUint8), + // 1 => i8.map(MetadataValue::GgufMetadataValueTypeInt8), + // 2 => u16.map(MetadataValue::GgufMetadataValueTypeUint16), + // 3 => i16.map(MetadataValue::GgufMetadataValueTypeInt16), + // 4 => u32.map(MetadataValue::GgufMetadataValueTypeUint32), + // 5 => i32.map(MetadataValue::GgufMetadataValueTypeInt32), + // 6 => f32.map(MetadataValue::GgufMetadataValueTypeFloat32), + // 7 => bool.map(MetadataValue::GgufMetadataValueTypeBool), + // 8 => parse_string.map(MetadataValue::GgufMetadataValueTypeString), + // // 9 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_ARRAY), + // 10 => u64.map(MetadataValue::GgufMetadataValueTypeUint64), + // 11 => i64.map(MetadataValue::GgufMetadataValueTypeInt64), + // 12 => f64.map(MetadataValue::GgufMetadataValueTypeFloat64), + // other => parse_string.map(MetadataValue::GgufMetadataValueTypeString), + // }; + // parser.parse_next(input) + // } + // } + + fn parse_metadata_value_single(input: &mut BytesStream) -> winnow::PResult { + parse_metadata_value_type + .flat_map(|metadata_value_type| parse_metadata_value(metadata_value_type)) + .parse_next(input) + } + #[inline] fn parse_metadata_kv_count(input: &mut BytesStream) -> winnow::PResult { u64(Endianness::Little).parse_next(input) @@ -195,9 +309,12 @@ mod gguf { metadata_kv_count: u64, ) -> impl Parser, MetadataKv, ContextError> { move |input: &mut BytesStream| { - (parse_string, parse_metadata_value_type) + (parse_string, parse_metadata_value_single) .parse_next(input) - .map(|(key, value_type)| MetadataKv { key, value_type }) + .map(|(key, metadata_value)| MetadataKv { + key, + metadata_value, + }) } } @@ -209,19 +326,14 @@ mod gguf { parse_metadata_kv_count, ) .flat_map(|(gguf, version, tensor_count, metadata_kv_count)| { - parse_metadata_kv(metadata_kv_count).map(move |metadata_kv| { - (gguf, version, tensor_count, metadata_kv_count, metadata_kv) - }) - }) - .parse_next(input) - .map( - |(gguf, version, tensor_count, metadata_kv_count, metadata_kv)| Header { + parse_metadata_kv(metadata_kv_count).map(move |metadata_kv| Header { version, tensor_count, metadata_kv_count, metadata_kv, - }, - ) + }) + }) + .parse_next(input) } pub fn to_std_error( error: winnow::error::ErrMode, @@ -259,7 +371,6 @@ mod tests { let result = gguf::parse_header(&mut input).map_err(gguf::to_std_error)?; println!("{:#?}", result.metadata_kv); - println!("{:#?}", result.metadata_kv.value_type); assert_eq!(result.version, 3); assert_eq!(result.tensor_count, 201); assert_eq!(result.metadata_kv_count, 23); From d3a8fba809b4161690a2e51542222086eb2d50d9 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Thu, 25 Jan 2024 20:41:02 +0000 Subject: [PATCH 08/14] Finish header parsing --- crates/ratchet-downloader/src/lib.rs | 40 +++++----------------------- 1 file changed, 7 insertions(+), 33 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index cf879a03..dbbd42cb 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -75,7 +75,6 @@ impl Model { mod gguf { use crate::BytesStream; - use futures_util::io::repeat; use winnow::binary::{u32, u64, u8, Endianness}; use winnow::combinator::fail; @@ -92,8 +91,7 @@ mod gguf { pub struct Header { pub version: u32, pub tensor_count: u64, - pub metadata_kv_count: u64, - pub metadata_kv: MetadataKv, + pub metadata_kv: Vec, } #[derive(Clone, Debug)] pub enum MetadataValueType { @@ -248,32 +246,6 @@ mod gguf { } } - // #[inline] - // fn parse_metadata_value_type<'i>( - // metadata_value_type: u64, - // ) -> impl Parser, MetadataKv, ContextError> { - // move |input: &mut BytesStream| { - // let parser: Parser, MetadataKv, ContextError> = - // match metadata_value_type { - // 0 => u8.map(MetadataValue::GgufMetadataValueTypeUint8), - // 1 => i8.map(MetadataValue::GgufMetadataValueTypeInt8), - // 2 => u16.map(MetadataValue::GgufMetadataValueTypeUint16), - // 3 => i16.map(MetadataValue::GgufMetadataValueTypeInt16), - // 4 => u32.map(MetadataValue::GgufMetadataValueTypeUint32), - // 5 => i32.map(MetadataValue::GgufMetadataValueTypeInt32), - // 6 => f32.map(MetadataValue::GgufMetadataValueTypeFloat32), - // 7 => bool.map(MetadataValue::GgufMetadataValueTypeBool), - // 8 => parse_string.map(MetadataValue::GgufMetadataValueTypeString), - // // 9 => Ok(MetadataValueType::GGUF_METADATA_VALUE_TYPE_ARRAY), - // 10 => u64.map(MetadataValue::GgufMetadataValueTypeUint64), - // 11 => i64.map(MetadataValue::GgufMetadataValueTypeInt64), - // 12 => f64.map(MetadataValue::GgufMetadataValueTypeFloat64), - // other => parse_string.map(MetadataValue::GgufMetadataValueTypeString), - // }; - // parser.parse_next(input) - // } - // } - fn parse_metadata_value_single(input: &mut BytesStream) -> winnow::PResult { parse_metadata_value_type .flat_map(|metadata_value_type| parse_metadata_value(metadata_value_type)) @@ -326,10 +298,13 @@ mod gguf { parse_metadata_kv_count, ) .flat_map(|(gguf, version, tensor_count, metadata_kv_count)| { - parse_metadata_kv(metadata_kv_count).map(move |metadata_kv| Header { + winnow::combinator::repeat( + metadata_kv_count as usize, + parse_metadata_kv(metadata_kv_count), + ) + .map(move |metadata_kv| Header { version, tensor_count, - metadata_kv_count, metadata_kv, }) }) @@ -359,7 +334,7 @@ mod tests { fn test_parse_header() -> anyhow::Result<()> { let mut file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; - let buffer_size = 100; + let buffer_size = 10_000_000; let min_buffer_growth = 100; let buffer_growth_factor = 2; let mut buffer = circular::Buffer::with_capacity(buffer_size); @@ -373,7 +348,6 @@ mod tests { println!("{:#?}", result.metadata_kv); assert_eq!(result.version, 3); assert_eq!(result.tensor_count, 201); - assert_eq!(result.metadata_kv_count, 23); Ok(()) } } From 6c48ccf96733ffc6a1e8e4ca72eee180ed489c4c Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Fri, 26 Jan 2024 16:02:29 +0000 Subject: [PATCH 09/14] Try out ringbuffer --- crates/ratchet-downloader/src/lib.rs | 107 +++++++++++++++++++++++---- 1 file changed, 91 insertions(+), 16 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index dbbd42cb..e8b55537 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -12,6 +12,9 @@ use web_sys::{console, ReadableStreamGetReaderOptions, ReadableStreamReaderMode} use winnow::{binary::bits::bytes, prelude::*, stream::Stream, Bytes, Partial}; use winnow::{binary::u32, binary::u64, combinator::preceded, Parser}; +use std::io::Read; + +use anyhow::Error; mod fetch; pub mod huggingface; @@ -78,7 +81,10 @@ mod gguf { use winnow::binary::{u32, u64, u8, Endianness}; use winnow::combinator::fail; + use winnow::error::Needed; use winnow::error::{AddContext, ContextError, ErrMode, StrContext}; + use winnow::prelude; + use winnow::stream::Offset; use winnow::token::take; use winnow::Parser; @@ -273,6 +279,7 @@ mod gguf { input: &mut winnow::Partial<&winnow::Bytes>, error_msg: &'static str, ) -> ErrMode { + println!("Error: {}", error_msg); ErrMode::Cut(ContextError::new().add_context(input, StrContext::Label(error_msg))) } @@ -310,6 +317,80 @@ mod gguf { }) .parse_next(input) } + + pub fn load_gguf(mut file: std::fs::File) -> Result, anyhow::Error> { + use std::io::Read; + let buffer_size = 10_000_000; + let min_buffer_growth = 10_000_000; + let buffer_growth_factor = 2; + let mut buffer = circular::Buffer::with_capacity(buffer_size); + + let mut maybe_header: Option
= None; + loop { + let read = file.read(buffer.space())?; + + if read == 0 { + // Should be EOF since we always make sure there is `available_space` + assert_ne!(buffer.available_space(), 0); + assert_eq!( + buffer.available_data(), + 0, + "leftover data: {}", + String::from_utf8_lossy(buffer.data()) + ); + break; + } + buffer.fill(read); + + loop { + let mut input = BytesStream::new(winnow::Bytes::new(buffer.data())); + let result = parse_header.parse_peek(input); + match result { + Ok((remainder, header)) => { + // Tell the buffer how much we read + let consumed = remainder.offset_from(&input); + buffer.consume(consumed); + maybe_header = Some(header) + } + Err(ErrMode::Backtrack(e)) => { + let pos = buffer.position(); + println!("Backtrack, position={}, error={}", pos, e); + return Err(anyhow::format_err!(e.to_string())); + } + Err(ErrMode::Cut(e)) => { + println!("Cut: {:#?}", e); + return Err(anyhow::format_err!(e.to_string())); + } + Err(ErrMode::Incomplete(Needed::Size(size))) => { + // Without the format telling us how much space is required, we really should + // treat this the same as `Unknown` but are doing this to demonstrate how to + // handle `Size`. + // + // Even when the format has a header to tell us `Size`, we could hit incidental + // `Size(1)`s, so make sure we buffer more space than that to avoid reading + // one byte at a time + let head_room = size.get().max(min_buffer_growth); + let new_capacity = buffer.available_data() + head_room; + println!("growing buffer to {}", new_capacity); + buffer.grow(new_capacity); + if buffer.available_space() < head_room { + println!("buffer shift"); + buffer.shift(); + } + break; + } + Err(ErrMode::Incomplete(Needed::Unknown)) => { + let new_capacity = buffer_growth_factor * buffer.capacity(); + println!("growing buffer to {}", new_capacity); + buffer.grow(new_capacity); + break; + } + } + } + } + Ok(maybe_header) + } + pub fn to_std_error( error: winnow::error::ErrMode, ) -> std::io::Error { @@ -323,10 +404,6 @@ mod gguf { #[cfg(test)] mod tests { - use std::io::Read; - - use anyhow::Error; - use winnow::Bytes; use crate::{gguf, BytesStream}; @@ -334,20 +411,18 @@ mod tests { fn test_parse_header() -> anyhow::Result<()> { let mut file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; - let buffer_size = 10_000_000; - let min_buffer_growth = 100; - let buffer_growth_factor = 2; - let mut buffer = circular::Buffer::with_capacity(buffer_size); - let read = file.read(buffer.space())?; - buffer.fill(read); - - let mut input = BytesStream::new(Bytes::new(buffer.data())); + let result = gguf::load_gguf(file); - let result = gguf::parse_header(&mut input).map_err(gguf::to_std_error)?; + match result { + Ok(None) => println!("Header was None"), + Ok(Some(header)) => { + println!("{:#?}", header.metadata_kv); + assert_eq!(header.version, 3); + assert_eq!(header.tensor_count, 201) + } + Err(err) => println!("Got an error: {:#?}", err), + } - println!("{:#?}", result.metadata_kv); - assert_eq!(result.version, 3); - assert_eq!(result.tensor_count, 201); Ok(()) } } From ede621cfe6581cb366875ca8383644cb321bb618 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Fri, 26 Jan 2024 22:09:16 +0000 Subject: [PATCH 10/14] Fix --- crates/ratchet-downloader/src/lib.rs | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index e8b55537..2cc9e925 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -84,7 +84,7 @@ mod gguf { use winnow::error::Needed; use winnow::error::{AddContext, ContextError, ErrMode, StrContext}; use winnow::prelude; - use winnow::stream::Offset; + use winnow::stream::{Offset, Stream}; use winnow::token::take; use winnow::Parser; @@ -320,16 +320,18 @@ mod gguf { pub fn load_gguf(mut file: std::fs::File) -> Result, anyhow::Error> { use std::io::Read; - let buffer_size = 10_000_000; - let min_buffer_growth = 10_000_000; + let buffer_size = 1_000_000; + let min_buffer_growth = 1_000_000; let buffer_growth_factor = 2; let mut buffer = circular::Buffer::with_capacity(buffer_size); let mut maybe_header: Option
= None; - loop { + 'outer: loop { + println!("Reading new buffer space"); let read = file.read(buffer.space())?; if read == 0 { + println!("Read 0"); // Should be EOF since we always make sure there is `available_space` assert_ne!(buffer.available_space(), 0); assert_eq!( @@ -338,19 +340,24 @@ mod gguf { "leftover data: {}", String::from_utf8_lossy(buffer.data()) ); - break; + break 'outer; } buffer.fill(read); - loop { + println!("buffer position: {}", buffer.position()); + 'inner: loop { let mut input = BytesStream::new(winnow::Bytes::new(buffer.data())); + + println!("stream length: {}", input.len()); let result = parse_header.parse_peek(input); match result { Ok((remainder, header)) => { // Tell the buffer how much we read + println!("Read header!"); let consumed = remainder.offset_from(&input); buffer.consume(consumed); - maybe_header = Some(header) + maybe_header = Some(header); + break 'outer; } Err(ErrMode::Backtrack(e)) => { let pos = buffer.position(); @@ -377,13 +384,15 @@ mod gguf { println!("buffer shift"); buffer.shift(); } - break; + println!("breaking inner"); + break 'inner; } Err(ErrMode::Incomplete(Needed::Unknown)) => { let new_capacity = buffer_growth_factor * buffer.capacity(); println!("growing buffer to {}", new_capacity); buffer.grow(new_capacity); - break; + println!("breaking inner - unknown"); + break 'inner; } } } From f3abaeb1ea9245642b93c0df77a1ae3c563a9db9 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Sat, 27 Jan 2024 14:52:54 +0000 Subject: [PATCH 11/14] Add tensor_info parser. Extract buffered parsing. --- crates/ratchet-downloader/src/lib.rs | 147 +++++++++++++++++++-------- 1 file changed, 107 insertions(+), 40 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index 2cc9e925..4544a31d 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -80,6 +80,7 @@ mod gguf { use crate::BytesStream; use winnow::binary::{u32, u64, u8, Endianness}; + use anyhow::anyhow; use winnow::combinator::fail; use winnow::error::Needed; use winnow::error::{AddContext, ContextError, ErrMode, StrContext}; @@ -99,6 +100,14 @@ mod gguf { pub tensor_count: u64, pub metadata_kv: Vec, } + + pub struct TensorInfo { + pub name: String, + pub dimensions: Vec, + pub ggml_type: GgmlType, + pub offset: u64, + } + #[derive(Clone, Debug)] pub enum MetadataValueType { GgufMetadataValueTypeUint8, @@ -133,6 +142,29 @@ mod gguf { GgufMetadataValueFloat64(f64), } + #[derive(Clone, Debug)] + pub enum GgmlType { + GgmlTypeF32, + GgmlTypeF16, + GgmlTypeQ4_0, + GgmlTypeQ4_1, + GgmlTypeQ5_0, + GgmlTypeQ5_1, + GgmlTypeQ8_0, + GgmlTypeQ8_1, + // k-quantizations + GgmlTypeQ2K, + GgmlTypeQ3K, + GgmlTypeQ4K, + GgmlTypeQ5K, + GgmlTypeQ6K, + GgmlTypeQ8K, + GgmlTypeI8, + GgmlTypeI16, + GgmlTypeI32, + GgmlTypeCount, + } + #[inline] fn parse_magic_number(input: &mut BytesStream) -> winnow::PResult<()> { // [TODO] Fix endianness @@ -275,6 +307,52 @@ mod gguf { }) } + fn parse_ggml_type(input: &mut BytesStream) -> winnow::PResult { + u32(Endianness::Little) + .parse_next(input) + .and_then(|metadata_value_type| match metadata_value_type { + 0 => Ok(GgmlType::GgmlTypeF32), + 1 => Ok(GgmlType::GgmlTypeF16), + 2 => Ok(GgmlType::GgmlTypeQ4_0), + 3 => Ok(GgmlType::GgmlTypeQ4_1), + // 4 & 5 have been removed + 6 => Ok(GgmlType::GgmlTypeQ5_0), + 7 => Ok(GgmlType::GgmlTypeQ5_1), + 8 => Ok(GgmlType::GgmlTypeQ8_0), + 9 => Ok(GgmlType::GgmlTypeQ8_1), + // k-quantizations + 10 => Ok(GgmlType::GgmlTypeQ2K), + 11 => Ok(GgmlType::GgmlTypeQ3K), + 12 => Ok(GgmlType::GgmlTypeQ4K), + 13 => Ok(GgmlType::GgmlTypeQ5K), + 14 => Ok(GgmlType::GgmlTypeQ6K), + 15 => Ok(GgmlType::GgmlTypeQ8K), + 16 => Ok(GgmlType::GgmlTypeI8), + 17 => Ok(GgmlType::GgmlTypeI16), + 18 => Ok(GgmlType::GgmlTypeI32), + 19 => Ok(GgmlType::GgmlTypeCount), + other => Err(cut_error(input, "Unknown metadata value type.")), + }) + } + + fn parse_tensor_info(input: &mut BytesStream) -> winnow::PResult { + (parse_string, u32(Endianness::Little)) + .flat_map(|(name, n_dimensions)| { + let dimensions_parser = + winnow::combinator::repeat(n_dimensions as usize, u64(Endianness::Little)); + + (dimensions_parser, parse_ggml_type, u64(Endianness::Little)).map( + move |(dimensions, ggml_type, offset)| TensorInfo { + name: name.clone(), + dimensions, + ggml_type, + offset, + }, + ) + }) + .parse_next(input) + } + fn cut_error( input: &mut winnow::Partial<&winnow::Bytes>, error_msg: &'static str, @@ -304,7 +382,7 @@ mod gguf { parse_tensor_count, parse_metadata_kv_count, ) - .flat_map(|(gguf, version, tensor_count, metadata_kv_count)| { + .flat_map(|(_gguf, version, tensor_count, metadata_kv_count)| { winnow::combinator::repeat( metadata_kv_count as usize, parse_metadata_kv(metadata_kv_count), @@ -318,16 +396,31 @@ mod gguf { .parse_next(input) } - pub fn load_gguf(mut file: std::fs::File) -> Result, anyhow::Error> { - use std::io::Read; + pub fn load_gguf(mut file: std::fs::File) -> anyhow::Result
{ let buffer_size = 1_000_000; let min_buffer_growth = 1_000_000; let buffer_growth_factor = 2; let mut buffer = circular::Buffer::with_capacity(buffer_size); - let mut maybe_header: Option
= None; + let mut parser = parse_header; + + let res = parse_with_buffer(file, buffer, parser, buffer_growth_factor)?; + res + } + + fn parse_with_buffer( + mut file: std::fs::File, + mut buffer: circular::Buffer, + mut parser: fn( + &mut winnow::Partial<&winnow::Bytes>, + ) -> Result>, + buffer_growth_factor: usize, + ) -> Result, anyhow::Error> { + use std::io::Read; + let mut result: anyhow::Result
= Err(anyhow!( + "An unknown error occurred while parsing the header.", + )); 'outer: loop { - println!("Reading new buffer space"); let read = file.read(buffer.space())?; if read == 0 { @@ -344,60 +437,35 @@ mod gguf { } buffer.fill(read); - println!("buffer position: {}", buffer.position()); 'inner: loop { let mut input = BytesStream::new(winnow::Bytes::new(buffer.data())); - println!("stream length: {}", input.len()); - let result = parse_header.parse_peek(input); - match result { - Ok((remainder, header)) => { + let parser_result = parser.parse_peek(input); + match parser_result { + Ok((remainder, parser_output)) => { // Tell the buffer how much we read - println!("Read header!"); let consumed = remainder.offset_from(&input); buffer.consume(consumed); - maybe_header = Some(header); + result = Ok(parser_output); break 'outer; } Err(ErrMode::Backtrack(e)) => { let pos = buffer.position(); - println!("Backtrack, position={}, error={}", pos, e); return Err(anyhow::format_err!(e.to_string())); } Err(ErrMode::Cut(e)) => { - println!("Cut: {:#?}", e); return Err(anyhow::format_err!(e.to_string())); } - Err(ErrMode::Incomplete(Needed::Size(size))) => { - // Without the format telling us how much space is required, we really should - // treat this the same as `Unknown` but are doing this to demonstrate how to - // handle `Size`. - // - // Even when the format has a header to tell us `Size`, we could hit incidental - // `Size(1)`s, so make sure we buffer more space than that to avoid reading - // one byte at a time - let head_room = size.get().max(min_buffer_growth); - let new_capacity = buffer.available_data() + head_room; - println!("growing buffer to {}", new_capacity); - buffer.grow(new_capacity); - if buffer.available_space() < head_room { - println!("buffer shift"); - buffer.shift(); - } - println!("breaking inner"); - break 'inner; - } - Err(ErrMode::Incomplete(Needed::Unknown)) => { + Err(ErrMode::Incomplete(_)) => { let new_capacity = buffer_growth_factor * buffer.capacity(); - println!("growing buffer to {}", new_capacity); buffer.grow(new_capacity); - println!("breaking inner - unknown"); break 'inner; } } } } - Ok(maybe_header) + let res = result; + Ok(res) } pub fn to_std_error( @@ -423,9 +491,8 @@ mod tests { let result = gguf::load_gguf(file); match result { - Ok(None) => println!("Header was None"), - Ok(Some(header)) => { - println!("{:#?}", header.metadata_kv); + Ok(header) => { + // println!("{:#?}", header.metadata_kv); assert_eq!(header.version, 3); assert_eq!(header.tensor_count, 201) } From b95ac408f37fe506b90b730807d4484703b133f2 Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Sat, 27 Jan 2024 15:29:58 +0000 Subject: [PATCH 12/14] Finish tensor info loading --- crates/ratchet-downloader/src/lib.rs | 51 ++++++++++++++++------------ 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index 4544a31d..36e047b4 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(seek_stream_len)] use js_sys::Uint8Array; #[cfg(test)] use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure}; @@ -77,6 +78,8 @@ impl Model { } mod gguf { + use std::io::Seek; + use crate::BytesStream; use winnow::binary::{u32, u64, u8, Endianness}; @@ -101,6 +104,7 @@ mod gguf { pub metadata_kv: Vec, } + #[derive(Clone, Debug)] pub struct TensorInfo { pub name: String, pub dimensions: Vec, @@ -396,31 +400,37 @@ mod gguf { .parse_next(input) } - pub fn load_gguf(mut file: std::fs::File) -> anyhow::Result
{ + pub fn load_gguf(mut file: std::fs::File) -> anyhow::Result<(Header, Vec)> { let buffer_size = 1_000_000; - let min_buffer_growth = 1_000_000; let buffer_growth_factor = 2; let mut buffer = circular::Buffer::with_capacity(buffer_size); - let mut parser = parse_header; - - let res = parse_with_buffer(file, buffer, parser, buffer_growth_factor)?; - res + let header = parse_with_buffer(&mut file, &mut buffer, parse_header, buffer_growth_factor)?; + let mut tensor_infos: Vec = vec![]; + for i in 0..header.tensor_count { + let tensor_info = parse_with_buffer( + &mut file, + &mut buffer, + parse_tensor_info, + buffer_growth_factor, + )?; + tensor_infos.push(tensor_info); + } + Ok((header, tensor_infos)) } - fn parse_with_buffer( - mut file: std::fs::File, - mut buffer: circular::Buffer, - mut parser: fn( - &mut winnow::Partial<&winnow::Bytes>, - ) -> Result>, + fn parse_with_buffer( + file: &mut std::fs::File, + buffer: &mut circular::Buffer, + mut parser: fn(&mut winnow::Partial<&winnow::Bytes>) -> Result>, buffer_growth_factor: usize, - ) -> Result, anyhow::Error> { + ) -> anyhow::Result { use std::io::Read; - let mut result: anyhow::Result
= Err(anyhow!( - "An unknown error occurred while parsing the header.", - )); + let mut result: anyhow::Result = Err(anyhow!("Failed to read file.",)); 'outer: loop { + if buffer.available_space() == 0 { + buffer.grow(buffer_growth_factor * buffer.capacity()); + } let read = file.read(buffer.space())?; if read == 0 { @@ -464,8 +474,7 @@ mod gguf { } } } - let res = result; - Ok(res) + result } pub fn to_std_error( @@ -485,14 +494,14 @@ mod tests { use crate::{gguf, BytesStream}; #[test] - fn test_parse_header() -> anyhow::Result<()> { + fn test_load_gguf() -> anyhow::Result<()> { let mut file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; let result = gguf::load_gguf(file); match result { - Ok(header) => { - // println!("{:#?}", header.metadata_kv); + Ok((header, tensor_info)) => { + println!("{:#?}", tensor_info); assert_eq!(header.version, 3); assert_eq!(header.tensor_count, 201) } From 402e01ddc3f9b4607c98d3f09c11f9e9e223317f Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Sat, 27 Jan 2024 20:00:20 +0000 Subject: [PATCH 13/14] Get alignment --- crates/ratchet-downloader/src/lib.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index 36e047b4..1ace0356 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -406,6 +406,20 @@ mod gguf { let mut buffer = circular::Buffer::with_capacity(buffer_size); let header = parse_with_buffer(&mut file, &mut buffer, parse_header, buffer_growth_factor)?; + + let alignment = header + .metadata_kv + .iter() + .find_map(|metadata_kv| match metadata_kv { + MetadataKv { + key, + metadata_value: MetadataValue::GgufMetadataValueUint32(v), + } if key.eq("general.alignment") => Some(v.clone()), + _ => None, + }) + // As per spec assume 32 if general.alignment is not present + .unwrap_or(32); + let mut tensor_infos: Vec = vec![]; for i in 0..header.tensor_count { let tensor_info = parse_with_buffer( @@ -501,7 +515,6 @@ mod tests { match result { Ok((header, tensor_info)) => { - println!("{:#?}", tensor_info); assert_eq!(header.version, 3); assert_eq!(header.tensor_count, 201) } From 9d8334931cbd0fb4b68b022680ccb3f19adb942b Mon Sep 17 00:00:00 2001 From: sigma-andex <77549848+sigma-andex@users.noreply.github.com> Date: Sat, 27 Jan 2024 21:25:10 +0000 Subject: [PATCH 14/14] Head exploding --- crates/ratchet-downloader/src/lib.rs | 33 ++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs index 1ace0356..477856e1 100644 --- a/crates/ratchet-downloader/src/lib.rs +++ b/crates/ratchet-downloader/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(seek_stream_len)] use js_sys::Uint8Array; #[cfg(test)] use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure}; @@ -13,9 +12,6 @@ use web_sys::{console, ReadableStreamGetReaderOptions, ReadableStreamReaderMode} use winnow::{binary::bits::bytes, prelude::*, stream::Stream, Bytes, Partial}; use winnow::{binary::u32, binary::u64, combinator::preceded, Parser}; -use std::io::Read; - -use anyhow::Error; mod fetch; pub mod huggingface; @@ -379,6 +375,14 @@ mod gguf { } } + fn parse_padding<'i>(padding: u64) -> impl Parser, (), ContextError> { + move |input: &mut BytesStream| { + winnow::combinator::repeat(padding as usize, u8) + .parse_next(input) + .map(|_: Vec| ()) + } + } + pub fn parse_header(input: &mut BytesStream) -> winnow::PResult
{ ( parse_magic_number, @@ -400,6 +404,10 @@ mod gguf { .parse_next(input) } + fn align_offset(alignment: u64, offset: u64) -> u64 { + return offset + (alignment - (offset % alignment)) % alignment; + } + pub fn load_gguf(mut file: std::fs::File) -> anyhow::Result<(Header, Vec)> { let buffer_size = 1_000_000; let buffer_growth_factor = 2; @@ -430,6 +438,12 @@ mod gguf { )?; tensor_infos.push(tensor_info); } + + let position = file.stream_position()?; + let padding = align_offset(alignment as u64, position) - position; + println!("calculated padding: {}", padding); + let padding_parser = parse_padding(padding); + let _ = parse_with_buffer(&mut file, &mut buffer, padding_parser, buffer_growth_factor)?; Ok((header, tensor_infos)) } @@ -437,6 +451,7 @@ mod gguf { file: &mut std::fs::File, buffer: &mut circular::Buffer, mut parser: fn(&mut winnow::Partial<&winnow::Bytes>) -> Result>, + // mut parser: impl Parser, O, ContextError> buffer_growth_factor: usize, ) -> anyhow::Result { use std::io::Read; @@ -462,7 +477,7 @@ mod gguf { buffer.fill(read); 'inner: loop { - let mut input = BytesStream::new(winnow::Bytes::new(buffer.data())); + let input = BytesStream::new(winnow::Bytes::new(buffer.data())); let parser_result = parser.parse_peek(input); match parser_result { @@ -495,9 +510,9 @@ mod gguf { error: winnow::error::ErrMode, ) -> std::io::Error { match error { - ErrMode::Backtrack(err) => std::io::Error::new(std::io::ErrorKind::Other, "Backtrack"), - ErrMode::Cut(err) => std::io::Error::new(std::io::ErrorKind::Other, "Cut"), - ErrMode::Incomplete(needed) => std::io::Error::new(std::io::ErrorKind::Other, "Needed"), + ErrMode::Backtrack(_) => std::io::Error::new(std::io::ErrorKind::Other, "Backtrack"), + ErrMode::Cut(_) => std::io::Error::new(std::io::ErrorKind::Other, "Cut"), + ErrMode::Incomplete(_) => std::io::Error::new(std::io::ErrorKind::Other, "Needed"), } } } @@ -509,7 +524,7 @@ mod tests { #[test] fn test_load_gguf() -> anyhow::Result<()> { - let mut file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; + let file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; let result = gguf::load_gguf(file);