-
Notifications
You must be signed in to change notification settings - Fork 333
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Move to gguf module
* Add content abstraction for multiple gguf files
* Fix test
* Allow specifying and loading multiple gguf files
* Update docs and examples
* Print some info
- Loading branch information
1 parent
6f3c308
commit a8c2b41
Showing
29 changed files
with
498 additions
and
371 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
use std::fs; | ||
|
||
use anyhow::Context; | ||
use candle_core::{ | ||
quantized::{ | ||
gguf_file::{self, Value}, | ||
QTensor, | ||
}, | ||
Device, Result, | ||
}; | ||
use indexmap::IndexMap; | ||
use tracing::info; | ||
|
||
use crate::{pipeline::GGUFArchitecture, DEBUG}; | ||
|
||
fn parse_gguf_value(value: &Value) -> String { | ||
match value { | ||
Value::Array(vs) => vs | ||
.iter() | ||
.map(parse_gguf_value) | ||
.collect::<Vec<String>>() | ||
.join(", "), | ||
Value::Bool(b) => b.to_string(), | ||
Value::F32(x) => x.to_string(), | ||
Value::F64(x) => x.to_string(), | ||
Value::I8(x) => x.to_string(), | ||
Value::I16(x) => x.to_string(), | ||
Value::I32(x) => x.to_string(), | ||
Value::I64(x) => x.to_string(), | ||
Value::String(x) => x.to_string(), | ||
Value::U8(x) => x.to_string(), | ||
Value::U16(x) => x.to_string(), | ||
Value::U32(x) => x.to_string(), | ||
Value::U64(x) => x.to_string(), | ||
} | ||
} | ||
|
||
// Internal invariant: contents and readers must be paired. | ||
/// This abstracts the files for a GGUF model and enables multiple files to be used. | ||
pub struct Content<'a, R: std::io::Seek + std::io::Read> { | ||
contents: Vec<gguf_file::Content>, | ||
readers: &'a mut [&'a mut R], | ||
arch: GGUFArchitecture, | ||
} | ||
|
||
impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> { | ||
/// Create a `Content` from a set of file readers. | ||
pub fn from_readers(readers: &'a mut [&'a mut R]) -> Result<Self> { | ||
let mut contents = Vec::new(); | ||
let n_readers = readers.len(); | ||
for reader in readers.iter_mut() { | ||
contents.push(gguf_file::Content::read(reader)?); | ||
} | ||
let n_splits = contents | ||
.iter() | ||
.filter_map(|ct| { | ||
ct.metadata | ||
.get("split.count") | ||
.map(|val| val.to_u64().unwrap()) | ||
}) | ||
.collect::<Vec<_>>(); | ||
if n_splits.len() > 1 { | ||
candle_core::bail!("Multiple contents have multiple `split.count` fields"); | ||
} | ||
#[allow(clippy::cast_possible_truncation)] | ||
if !n_splits.is_empty() && n_readers != n_splits[0] as usize { | ||
candle_core::bail!("Number of readers does not match the number of splits."); | ||
} else if n_splits.len() == 1 { | ||
info!("Model n splits: {}", n_splits[0]); | ||
} | ||
|
||
let mut arch = None; | ||
for ct in &contents { | ||
if !ct.metadata.contains_key("general.architecture") { | ||
continue; | ||
} | ||
|
||
arch = Some( | ||
ct.metadata["general.architecture"] | ||
.to_string() | ||
.context("Model metadata should have declared an architecture") | ||
.and_then(GGUFArchitecture::from_value) | ||
.unwrap(), | ||
); | ||
} | ||
let arch = arch.expect("GGUF files must specify `general.architecture`"); | ||
Ok(Self { | ||
contents, | ||
readers, | ||
arch, | ||
}) | ||
} | ||
|
||
pub fn arch(&self) -> GGUFArchitecture { | ||
self.arch | ||
} | ||
|
||
/// Retrieve a tensor, searching through each content. | ||
pub fn tensor(&mut self, name: &str, device: &Device) -> Result<QTensor> { | ||
for (ct, reader) in self.contents.iter().zip(self.readers.iter_mut()) { | ||
if let Some(tensor_info) = ct.tensor_infos.get(name) { | ||
return tensor_info.read(reader, ct.tensor_data_offset, device); | ||
} | ||
} | ||
candle_core::bail!("Cannot find tensor info for {name}") | ||
} | ||
|
||
/// Print metadata for these contents. | ||
/// This will also log tensor name, shape and dtype to `mistralrs_gguf_tensors.txt` is DEBUG is enabled. | ||
pub fn print_metadata(&self) -> anyhow::Result<()> { | ||
// Find the ct with general.architecture | ||
let mut keys = Vec::new(); | ||
let mut metadatas = Vec::new(); | ||
let mut tensors = Vec::new(); | ||
for ct in &self.contents { | ||
keys.extend(ct.metadata.keys()); | ||
metadatas.push(&ct.metadata); | ||
|
||
if DEBUG.load(std::sync::atomic::Ordering::Relaxed) { | ||
for (name, info) in &ct.tensor_infos { | ||
tensors.push(format!( | ||
"name = `{name}`, shape = {:?}, dtype = {:?}", | ||
info.shape.clone(), | ||
info.ggml_dtype | ||
)); | ||
} | ||
} | ||
} | ||
|
||
info!("Model config:"); | ||
keys.sort(); | ||
let mut output_keys = IndexMap::new(); | ||
for name in keys { | ||
if !name.contains("tokenizer") { | ||
for metadata in &metadatas { | ||
if let Some(val) = metadata.get(name) { | ||
output_keys.insert(name, parse_gguf_value(val)); | ||
} | ||
} | ||
} | ||
} | ||
for (name, val) in output_keys { | ||
println!("{name}: {val}") | ||
} | ||
|
||
if DEBUG.load(std::sync::atomic::Ordering::Relaxed) { | ||
fs::write( | ||
"mistralrs_gguf_tensors.txt", | ||
serde_json::to_string_pretty(&tensors).expect("Serialization failed."), | ||
)?; | ||
|
||
info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`."); | ||
} | ||
|
||
anyhow::Ok(()) | ||
} | ||
|
||
/// Get metadata | ||
pub fn get_metadata(&self, name: &str) -> Result<&Value> { | ||
for content in &self.contents { | ||
if let Some(v) = content.metadata.get(name) { | ||
return Ok(v); | ||
} | ||
} | ||
candle_core::bail!("Cannot find metadata for {name}") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mod content; | ||
mod gguf_tokenizer; | ||
|
||
pub use content::Content; | ||
pub use gguf_tokenizer::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.