
Commit: candle compilation
santiagomed committed Mar 14, 2024
1 parent 2bc3dcf commit 81b24c9
Showing 3 changed files with 9 additions and 8 deletions.
8 changes: 4 additions & 4 deletions orca-core/src/llm/quantized.rs
@@ -198,22 +198,22 @@ impl Quantized {
         let mut total_size_in_bytes = 0;
         for (_, tensor) in model.tensor_infos.iter() {
             let elem_count = tensor.shape.elem_count();
-            total_size_in_bytes += elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+            total_size_in_bytes += elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
         }
         log::info!(
             "loaded {:?} tensors ({}) in {:.2}s",
             model.tensor_infos.len(),
             &format_size(total_size_in_bytes),
             start.elapsed().as_secs_f32(),
         );
-        Some(ModelWeights::from_gguf(model, &mut file)?)
+        Some(ModelWeights::from_gguf(model, &mut file, &Device::Cpu)?)
     }
     Some("ggml" | "bin") | Some(_) | None => {
-        let model = ggml_file::Content::read(&mut file)?;
+        let model = ggml_file::Content::read(&mut file, &Device::Cpu)?;
         let mut total_size_in_bytes = 0;
         for (_, tensor) in model.tensors.iter() {
             let elem_count = tensor.shape().elem_count();
-            total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+            total_size_in_bytes += elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
         }
         log::info!(
             "loaded {:?} tensors ({}) in {:.2}s",
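
This hunk tracks two breaking changes in candle's quantized API: blck_size() was renamed to block_size(), and the GGUF/GGML loaders now take the target Device. A minimal sketch of the new GGUF loading flow under those signatures, assuming the candle alias for candle-core that this repo already imports and the quantized_llama ModelWeights; the path argument is a placeholder and CPU is assumed as the device, matching this commit:

use candle::quantized::gguf_file;
use candle::Device;
use candle_transformers::models::quantized_llama::ModelWeights;

fn load_gguf(path: &str) -> anyhow::Result<ModelWeights> {
    let mut file = std::fs::File::open(path)?;
    // Reading the GGUF header/metadata is unchanged.
    let content = gguf_file::Content::read(&mut file)?;
    // from_gguf now takes the device the tensors should be loaded onto.
    let weights = ModelWeights::from_gguf(content, &mut file, &Device::Cpu)?;
    Ok(weights)
}
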
5 changes: 3 additions & 2 deletions orca-models/src/mistral.rs
@@ -1,4 +1,5 @@
 use crate::utils::text_generation::{Model, TextGeneration};
+use candle::Device;
 use candle_transformers::models::mistral;
 use candle_transformers::models::quantized_mistral;

@@ -72,7 +73,7 @@ impl Mistral {
         P: AsRef<std::path::Path>,
     {
         let cfg = mistral::Config::config_7b_v0_1(config.flash_attn);
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(weights)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(weights, &Device::Cpu)?;
         let model = quantized_mistral::Model::new(&cfg, vb)?;
         let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
@@ -88,7 +89,7 @@ impl Mistral {

     pub fn from_stream(weights: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
         let cfg = mistral::Config::config_7b_v0_1(config.flash_attn);
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights, &Device::Cpu)?;
         let model = quantized_mistral::Model::new(&cfg, vb)?;
         let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
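
For the quantized Mistral path, the only adjustment is threading the device through VarBuilder::from_gguf and from_gguf_buffer. A sketch of building the model from an in-memory GGUF buffer with the new signatures; the function name and arguments here are illustrative, and CPU is assumed as in the diff:

use candle::Device;
use candle_transformers::models::{mistral, quantized_mistral};
use candle_transformers::quantized_var_builder::VarBuilder;

fn build_quantized_mistral(weights: &[u8], flash_attn: bool) -> anyhow::Result<quantized_mistral::Model> {
    let cfg = mistral::Config::config_7b_v0_1(flash_attn);
    // from_gguf_buffer now needs the device the quantized tensors should live on.
    let vb = VarBuilder::from_gguf_buffer(weights, &Device::Cpu)?;
    Ok(quantized_mistral::Model::new(&cfg, vb)?)
}
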
4 changes: 2 additions & 2 deletions orca-models/src/quantized.rs
@@ -65,7 +65,7 @@ impl Quantized {
     pub fn from_gguf_stream(model: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
         let mut model_reader = std::io::Cursor::new(model);
         let model_content = gguf_file::Content::read(&mut model_reader)?;
-        let model = ModelWeights::from_gguf(model_content, &mut model_reader)?;
+        let model = ModelWeights::from_gguf(model_content, &mut model_reader, &Device::Cpu)?;
         let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
             model,
@@ -80,7 +80,7 @@ impl Quantized {

     pub fn from_ggml_stream(model: Vec<u8>, tokenizer: Vec<u8>, config: Config) -> anyhow::Result<Self> {
         let mut model_reader = std::io::Cursor::new(model);
-        let model_content = ggml_file::Content::read(&mut model_reader)?;
+        let model_content = ggml_file::Content::read(&mut model_reader, &Device::Cpu)?;
         let model = ModelWeights::from_ggml(model_content, 1)?;
         let tokenizer = tokenizers::Tokenizer::from_bytes(tokenizer).map_err(|m| anyhow::anyhow!(m))?;
         Ok(Self {
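
The stream-based loaders change the same way: ggml_file::Content::read now takes the device, while ModelWeights::from_ggml keeps its old shape. A hedged sketch of the ggml path, again assuming candle's quantized_llama ModelWeights and a CPU device:

use candle::quantized::ggml_file;
use candle::Device;
use candle_transformers::models::quantized_llama::ModelWeights;

fn load_ggml_stream(model: Vec<u8>) -> anyhow::Result<ModelWeights> {
    let mut model_reader = std::io::Cursor::new(model);
    // read() now takes the target device alongside the reader.
    let model_content = ggml_file::Content::read(&mut model_reader, &Device::Cpu)?;
    // The second argument is the grouped-query-attention factor (1 for plain multi-head attention).
    Ok(ModelWeights::from_ggml(model_content, 1)?)
}
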
