From 58fe2bcfb57a15643f047a83e28fe4327ba62cb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Sun, 28 Apr 2024 02:58:16 -0300 Subject: [PATCH 01/12] Use cublas for prompt --- mistralrs-core/src/models/quantized_llama.rs | 63 +++++++++++++++----- mistralrs-core/src/pipeline/ggml.rs | 1 + mistralrs-core/src/pipeline/gguf.rs | 1 + 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index 8b5b91b36..bb9115526 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -16,6 +16,28 @@ use super::{repeat_kv, verify_sanity_gguf, Cache}; const MAX_SEQ_LEN: u32 = 4096; +fn lt_mul(xs: &Tensor, w: &QMatMul, is_prompt: bool) -> Result { + if is_prompt { + let w = match w { + QMatMul::QTensor(ref qt) => qt.dequantize(xs.device())?, + QMatMul::Tensor(w) => w.clone(), + }; + + let w = w.to_dtype(DType::F16)?; + let w = match *xs.dims() { + [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => w.broadcast_left(bsize)?.t()?, + _ => w.t()?, + }; + // xs.matmul(&w) + let xs = xs.to_dtype(DType::F16)?; + + xs.matmul(&w)?.to_dtype(DType::F32) + } else { + w.forward(xs) + } +} + #[derive(Debug, Clone)] struct Mlp { feed_forward_w1: QMatMul, @@ -23,12 +45,15 @@ struct Mlp { feed_forward_w3: QMatMul, } -impl Module for Mlp { - fn forward(&self, xs: &Tensor) -> Result { - let w1 = self.feed_forward_w1.forward(xs)?; - let w3 = self.feed_forward_w3.forward(xs)?; - self.feed_forward_w2 - .forward(&(candle_nn::ops::silu(&w1)? * w3)?) +impl Mlp { + fn forward(&self, xs: &Tensor, is_prompt: bool) -> Result { + // let w1 = self.feed_forward_w1.forward(xs)?; + let w1 = lt_mul(xs, &self.feed_forward_w1, is_prompt)?; + // let w3 = self.feed_forward_w3.forward(xs)?; + let w3 = lt_mul(xs, &self.feed_forward_w3, is_prompt)?; + let y = &(candle_nn::ops::silu(&w1)? 
* w3)?;
+        // self.feed_forward_w2.forward(y)
+        lt_mul(y, &self.feed_forward_w2, is_prompt)
     }
 }
@@ -42,8 +67,8 @@ enum MlpOrMoe {
     },
 }
 
-impl Module for MlpOrMoe {
-    fn forward(&self, xs: &Tensor) -> Result {
+impl MlpOrMoe {
+    fn forward(&self, xs: &Tensor, is_prompt: bool) -> Result {
         match self {
             Self::MoE {
                 feed_forward_gate_inp,
@@ -98,7 +123,7 @@ impl Module for MlpOrMoe {
                     // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                     let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                     // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
-                    let current_hidden_states = expert_layer.forward(&current_state)?;
+                    let current_hidden_states = expert_layer.forward(&current_state, is_prompt)?;
                     let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
                     ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
@@ -107,7 +132,7 @@ impl Module for MlpOrMoe {
                 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                 Ok(ys)
             }
-            Self::Mlp(mlp) => mlp.forward(xs),
+            Self::Mlp(mlp) => mlp.forward(xs, is_prompt),
         }
     }
 }
@@ -142,12 +167,17 @@ impl LayerWeights {
         start_offsets: &[usize],
         start_offsets_kernel: Tensor,
         kv_cache: &mut Option<(Tensor, Tensor)>,
+        is_prompt: bool,
     ) -> Result {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
-        let q = self.attention_wq.forward(x)?;
-        let k = self.attention_wk.forward(x)?;
-        let v = self.attention_wv.forward(x)?;
+        let q = lt_mul(x, &self.attention_wq, is_prompt)?;
+        let k = lt_mul(x, &self.attention_wk, is_prompt)?;
+        let v = lt_mul(x, &self.attention_wv, is_prompt)?;
+
+        // let q = self.attention_wq.forward(x)?;
+        // let k = self.attention_wk.forward(x)?;
+        // let v = self.attention_wv.forward(x)?;
 
         let mut q = q.reshape((b_sz * seq_len, self.n_head, self.head_dim))?;
         let mut k = k.reshape((b_sz * seq_len, self.n_kv_head, self.head_dim))?;
         let v = v
@@ -192,7 +222,8 @@ impl LayerWeights {
         // Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?; let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; - let y = self.attention_wo.forward(&y)?; + // let y = self.attention_wo.forward(&y)?; + let y = lt_mul(&y, &self.attention_wo, is_prompt)?; Ok(y) } } @@ -423,6 +454,7 @@ impl ModelWeights { start_offsets: &[usize], start_offsets_kernel: Tensor, context_lens: Vec, + is_prompt: bool, ) -> Result { let (_b_sz, seq_len) = x.dims2()?; let mask = if seq_len == 1 { @@ -445,13 +477,14 @@ impl ModelWeights { start_offsets, start_offsets_kernel.clone(), &mut cache[i], + is_prompt, )?; let x = (attn + residual)?; // MLP let residual = &x; let x = layer.ffn_norm.forward(&x)?; - let x = layer.mlp_or_moe.forward(&x)?; + let x = layer.mlp_or_moe.forward(&x, is_prompt)?; let x = (x + residual)?; layer_in = x; } diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 1ce9ac8c0..faee99d31 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -415,6 +415,7 @@ impl Pipeline for GGMLPipeline { &seqlen_offsets, seqlen_offsets_kernel, context_lens, + is_prompt, ), Model::XLoraLlama(ref mut model) => model.forward( &input_ids, diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 7efa10186..589701db6 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -468,6 +468,7 @@ impl Pipeline for GGUFPipeline { &seqlen_offsets, seqlen_offsets_kernel, context_lens, + is_prompt, ), Model::Phi2(ref mut model) => model.forward(&input_ids, &seqlen_offsets, context_lens), Model::XLoraLlama(ref mut model) => model.forward( From 62d52069b48fa03516da9ad48a0af7421f9303e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Sun, 28 Apr 2024 16:36:30 -0300 Subject: [PATCH 02/12] cublas prompt --- mistralrs-core/src/models/quantized_llama.rs | 42 ++++++-------------- mistralrs-lora/src/loralinear.rs | 1 + mistralrs-lora/src/qloralinear.rs | 1 + 3 files changed, 14 insertions(+), 30 deletions(-) diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index bb9115526..6a9271eda 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -16,23 +16,12 @@ use super::{repeat_kv, verify_sanity_gguf, Cache}; const MAX_SEQ_LEN: u32 = 4096; -fn lt_mul(xs: &Tensor, w: &QMatMul, is_prompt: bool) -> Result { +fn quantized_mat_mul(xs: &Tensor, w: &QMatMul, is_prompt: bool) -> Result { + // TODO: For very small prompts, we should use forward + // For completions with batch size > 8, we should use forward_via_f16 + // TODO: benchmark and implement the above if is_prompt { - let w = match w { - QMatMul::QTensor(ref qt) => qt.dequantize(xs.device())?, - QMatMul::Tensor(w) => w.clone(), - }; - - let w = w.to_dtype(DType::F16)?; - let w = match *xs.dims() { - [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?, - [bsize, _, _] => w.broadcast_left(bsize)?.t()?, - _ => w.t()?, - }; - // xs.matmul(&w) - let xs = xs.to_dtype(DType::F16)?; - - xs.matmul(&w)?.to_dtype(DType::F32) + w.forward_via_f16(xs) } else { w.forward(xs) } @@ -47,13 +36,10 @@ struct Mlp { impl Mlp { fn forward(&self, xs: &Tensor, is_prompt: bool) -> Result { - // let w1 = self.feed_forward_w1.forward(xs)?; - let w1 = lt_mul(xs, &self.feed_forward_w1, is_prompt)?; - // let w3 = self.feed_forward_w3.forward(xs)?; - let w3 = lt_mul(xs, &self.feed_forward_w3, is_prompt)?; + let w1 = quantized_mat_mul(xs, 
&self.feed_forward_w1, is_prompt)?; + let w3 = quantized_mat_mul(xs, &self.feed_forward_w3, is_prompt)?; let y = &(candle_nn::ops::silu(&w1)? * w3)?; - // self.feed_forward_w2.forward(y) - lt_mul(y, &self.feed_forward_w2, is_prompt) + quantized_mat_mul(y, &self.feed_forward_w2, is_prompt) } } @@ -171,13 +157,10 @@ impl LayerWeights { ) -> Result { let (b_sz, seq_len, n_embd) = x.dims3()?; - let q = lt_mul(x, &self.attention_wq, is_prompt)?; - let k = lt_mul(x, &self.attention_wk, is_prompt)?; - let v = lt_mul(x, &self.attention_wv, is_prompt)?; + let q = quantized_mat_mul(x, &self.attention_wq, is_prompt)?; + let k = quantized_mat_mul(x, &self.attention_wk, is_prompt)?; + let v = quantized_mat_mul(x, &self.attention_wv, is_prompt)?; - // let q = self.attention_wq.forward(x)?; - // let k = self.attention_wk.forward(x)?; - // let v = self.attention_wv.forward(x)?; let mut q = q.reshape((b_sz * seq_len, self.n_head, self.head_dim))?; let mut k = k.reshape((b_sz * seq_len, self.n_kv_head, self.head_dim))?; let v = v @@ -222,8 +205,7 @@ impl LayerWeights { // Convert to contiguous as matmul doesn't support strided vs for now. let y = att.matmul(&v.contiguous()?)?; let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; - // let y = self.attention_wo.forward(&y)?; - let y = lt_mul(&y, &self.attention_wo, is_prompt)?; + let y = quantized_mat_mul(&y, &self.attention_wo, is_prompt)?; Ok(y) } } diff --git a/mistralrs-lora/src/loralinear.rs b/mistralrs-lora/src/loralinear.rs index 0a32c6816..a7e414211 100644 --- a/mistralrs-lora/src/loralinear.rs +++ b/mistralrs-lora/src/loralinear.rs @@ -156,6 +156,7 @@ impl Merge for LoraLinear { } self.old = QLinear::from_parts(w_base_layer, self.old.bias().cloned()); } + QMatMul::TensorF16(_) => todo!(), }; self.merged = true; Ok(()) diff --git a/mistralrs-lora/src/qloralinear.rs b/mistralrs-lora/src/qloralinear.rs index c10383eca..9fe224006 100644 --- a/mistralrs-lora/src/qloralinear.rs +++ b/mistralrs-lora/src/qloralinear.rs @@ -173,6 +173,7 @@ impl Merge for QLoraLinear { let (mut w_base_layer, dtype) = match &self.old { QMatMul::QTensor(q) => (q.dequantize(&q.device())?, q.dtype()), QMatMul::Tensor(_) => unreachable!(), + QMatMul::TensorF16(_) => todo!(), }; for adapter in 0..self.scale_adapters.len() { w_base_layer = (w_base_layer + self.get_delta_weight(adapter))?; From 034713075a5a2247634947b009b59778244f9c5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Sun, 28 Apr 2024 22:35:07 -0300 Subject: [PATCH 03/12] use f16 for output mul --- mistralrs-core/src/models/quantized_llama.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index 6a9271eda..db549d12c 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -472,6 +472,9 @@ impl ModelWeights { } let layer_in = layer_in.to_device(&self.device)?; let x = self.norm.forward(&layer_in)?; - extract_logits(&self.output.forward(&x.contiguous()?)?, context_lens) + extract_logits( + &quantized_mat_mul(&x.contiguous()?, &self.output, is_prompt)?, + context_lens, + ) } } From e06966704f5a2272c7c8a81c05226581ae951e2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Mon, 29 Apr 2024 00:18:12 -0300 Subject: [PATCH 04/12] mulmat via f16 --- mistralrs-core/src/models/quantized_llama.rs | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git 
a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index db549d12c..8f51df3ca 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -192,8 +192,16 @@ impl LayerWeights { let k = repeat_kv(k, self.n_head / self.n_kv_head)?.contiguous()?; let v = repeat_kv(v, self.n_head / self.n_kv_head)?.contiguous()?; + let att = if is_prompt { + (q.to_dtype(DType::F16)? + .contiguous()? + .matmul(&k.to_dtype(DType::F16)?.t()?.contiguous()?)? + .to_dtype(DType::F32)? + / (self.head_dim as f64).sqrt())? + } else { + (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())? + }; - let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())?; let att = match mask { None => att, Some(mask) => { @@ -203,7 +211,14 @@ impl LayerWeights { }; let att = candle_nn::ops::softmax_last_dim(&att)?; // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; + let y = if is_prompt { + att.to_dtype(DType::F16)? + .matmul(&v.to_dtype(DType::F16)?.contiguous()?)? + .to_dtype(DType::F32)? + } else { + att.matmul(&v.contiguous()?)? + }; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; let y = quantized_mat_mul(&y, &self.attention_wo, is_prompt)?; Ok(y) From 97c0324b1aac2afccf48d2327c2d414b946edadd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Mon, 29 Apr 2024 00:56:58 -0300 Subject: [PATCH 05/12] disable attn mask in bench --- mistralrs-bench/src/main.rs | 1 + mistralrs-core/src/model_loader.rs | 13 +++++++++++++ mistralrs-core/src/models/quantized_llama.rs | 20 ++++++++++++-------- mistralrs-core/src/pipeline/ggml.rs | 16 +++++++++++++++- mistralrs-core/src/pipeline/gguf.rs | 20 +++++++++++++++++--- 5 files changed, 58 insertions(+), 12 deletions(-) diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs index 3c756eb8e..8feb04428 100644 --- a/mistralrs-bench/src/main.rs +++ b/mistralrs-bench/src/main.rs @@ -243,6 +243,7 @@ fn main() -> anyhow::Result<()> { let loader: Box = LoaderBuilder::new(args.model) .with_use_flash_attn(use_flash_attn) + .with_disable_attention_mask(true) .build()?; let model_name = loader.get_id(); diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index a2540169c..6a5f16f5f 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -13,6 +13,7 @@ pub struct LoaderBuilder { no_kv_cache: bool, chat_template: Option, use_flash_attn: bool, + disable_attention_mask: bool, } impl LoaderBuilder { @@ -22,6 +23,7 @@ impl LoaderBuilder { no_kv_cache: false, chat_template: None, use_flash_attn: false, + disable_attention_mask: false, } } @@ -38,6 +40,11 @@ impl LoaderBuilder { self } + pub fn with_disable_attention_mask(mut self, disable_attention_mask: bool) -> Self { + self.disable_attention_mask = disable_attention_mask; + self + } + pub fn build(self) -> anyhow::Result> { loader_from_model_selected(self) } @@ -152,6 +159,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result anyhow::Result anyhow::Result anyhow::Result anyhow::Result anyhow::Result>, + disable_mask: bool, } impl ModelWeights { - pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { + pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, disable_mask: bool) -> Result { let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let rotary = RotaryEmbedding::new_partial( 10000., 
@@ -299,6 +300,7 @@ impl ModelWeights { cache: Cache::new(ct.hparams.n_layer as usize, false), max_seq_len: MAX_SEQ_LEN as usize, // Cannot determine from ggml. mapper: None, + disable_mask, }) } @@ -307,6 +309,7 @@ impl ModelWeights { reader: &mut R, device: &Device, mapper: DeviceMapMetadata, + disable_mask: bool, ) -> Result { let md_get = |s: &str| match ct.metadata.get(s) { None => candle_core::bail!("cannot find {s} in metadata"), @@ -429,6 +432,7 @@ impl ModelWeights { .and_then(|m| m.to_u64()) .unwrap_or(MAX_SEQ_LEN as u64) as usize, mapper: Some(mapper), + disable_mask, }) } @@ -454,7 +458,7 @@ impl ModelWeights { is_prompt: bool, ) -> Result { let (_b_sz, seq_len) = x.dims2()?; - let mask = if seq_len == 1 { + let mask: Option = if seq_len == 1 || self.disable_mask { None } else { Some(self.mask(seq_len, x.device())?) diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index faee99d31..140c5c118 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -101,6 +101,7 @@ pub struct GGMLLoader { tokenizer_json: Option, kind: ModelKind, tgt_non_granular_index: Option, + disable_attention_mask: bool, } #[derive(Clone, Copy, Default)] @@ -124,6 +125,7 @@ pub struct GGMLLoaderBuilder { chat_template: Option, tokenizer_json: Option, tgt_non_granular_index: Option, + disable_attention_mask: bool, } impl GGMLLoaderBuilder { @@ -202,6 +204,11 @@ impl GGMLLoaderBuilder { ) } + pub fn with_disable_attention_mask(mut self, disable_attention_mask: bool) -> Self { + self.disable_attention_mask = disable_attention_mask; + self + } + pub fn build(self) -> Box { Box::new(GGMLLoader { model_id: self.model_id.unwrap(), @@ -215,6 +222,7 @@ impl GGMLLoaderBuilder { tgt_non_granular_index: self.tgt_non_granular_index, quantized_filename: Some(self.quantized_filename), quantized_model_id: Some(self.quantized_model_id), + disable_attention_mask: self.disable_attention_mask, }) } } @@ -233,6 +241,7 @@ impl GGMLLoader { chat_template: Option, tokenizer_json: Option, tgt_non_granular_index: Option, + disable_attention_mask: bool, ) -> Self { let model_id = if let Some(id) = model_id { id @@ -255,6 +264,7 @@ impl GGMLLoader { tokenizer_json, kind, tgt_non_granular_index, + disable_attention_mask, } } } @@ -301,7 +311,11 @@ impl Loader for GGMLLoader { let mut is_lora = false; let model = match self.kind { - ModelKind::QuantizedGGML => Model::Llama(QLlama::from_ggml(model, self.config.gqa)?), + ModelKind::QuantizedGGML => Model::Llama(QLlama::from_ggml( + model, + self.config.gqa, + self.disable_attention_mask, + )?), ModelKind::XLoraGGML => { let vb = from_mmaped_safetensors( vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()], diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 589701db6..3c0f0217d 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -102,6 +102,7 @@ pub struct GGUFLoader { tokenizer_json: Option, kind: ModelKind, tgt_non_granular_index: Option, + disable_attention_mask: bool, } #[derive(Debug)] @@ -158,6 +159,7 @@ pub struct GGUFLoaderBuilder { chat_template: Option, tokenizer_json: Option, tgt_non_granular_index: Option, + disable_attention_mask: bool, } impl GGUFLoaderBuilder { @@ -236,6 +238,11 @@ impl GGUFLoaderBuilder { ) } + pub fn with_disable_attention_mask(mut self, disable_attention_mask: bool) -> Self { + self.disable_attention_mask = disable_attention_mask; + self + } + pub fn build(self) -> Box { 
Box::new(GGUFLoader { model_id: self.model_id.unwrap(), @@ -249,6 +256,7 @@ impl GGUFLoaderBuilder { tgt_non_granular_index: self.tgt_non_granular_index, quantized_filename: Some(self.quantized_filename), quantized_model_id: Some(self.quantized_model_id), + disable_attention_mask: self.disable_attention_mask, }) } } @@ -267,6 +275,7 @@ impl GGUFLoader { chat_template: Option, tokenizer_json: Option, tgt_non_granular_index: Option, + disable_attention_mask: bool, ) -> Self { let model_id = if let Some(id) = model_id { id @@ -289,6 +298,7 @@ impl GGUFLoader { tokenizer_json, kind, tgt_non_granular_index, + disable_attention_mask, } } } @@ -337,9 +347,13 @@ impl Loader for GGUFLoader { let mut is_lora = false; let model = match self.kind { ModelKind::QuantizedGGUF => match arch { - GGUFArchitecture::Llama => { - Model::Llama(QLlama::from_gguf(model, &mut file, device, mapper)?) - } + GGUFArchitecture::Llama => Model::Llama(QLlama::from_gguf( + model, + &mut file, + device, + mapper, + self.disable_attention_mask, + )?), GGUFArchitecture::Phi2 => { Model::Phi2(QPhi::from_gguf(model, &mut file, device, mapper)?) } From 699e54160ee4797ce9e7ac5d93dfd48f19bb1c11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Mon, 29 Apr 2024 01:23:48 -0300 Subject: [PATCH 06/12] reduce contiguous calls --- mistralrs-core/src/models/quantized_llama.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index d770fd27a..4e6668e67 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -190,8 +190,8 @@ impl LayerWeights { }; *kv_cache = Some((k.clone(), v.clone())); - let k = repeat_kv(k, self.n_head / self.n_kv_head)?.contiguous()?; - let v = repeat_kv(v, self.n_head / self.n_kv_head)?.contiguous()?; + let k = repeat_kv(k, self.n_head / self.n_kv_head)?; + let v = repeat_kv(v, self.n_head / self.n_kv_head)?; let att = if is_prompt { let mm = q .to_dtype(DType::F16)? @@ -199,6 +199,7 @@ impl LayerWeights { ((mm / (self.head_dim as f64).sqrt())?).to_dtype(DType::F32)? } else { + let k = k.contiguous()?; (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())? 
}; From 638805389e8940f12ce7a0f9e6ecd5cf536f0956 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Mon, 29 Apr 2024 18:38:50 -0300 Subject: [PATCH 07/12] remove attn mask disabling --- mistralrs-bench/src/main.rs | 1 - mistralrs-core/src/model_loader.rs | 13 ------------- mistralrs-core/src/models/quantized_llama.rs | 8 ++------ mistralrs-core/src/pipeline/ggml.rs | 16 +--------------- mistralrs-core/src/pipeline/gguf.rs | 20 +++----------------- 5 files changed, 6 insertions(+), 52 deletions(-) diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs index 8feb04428..3c756eb8e 100644 --- a/mistralrs-bench/src/main.rs +++ b/mistralrs-bench/src/main.rs @@ -243,7 +243,6 @@ fn main() -> anyhow::Result<()> { let loader: Box = LoaderBuilder::new(args.model) .with_use_flash_attn(use_flash_attn) - .with_disable_attention_mask(true) .build()?; let model_name = loader.get_id(); diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index 6a5f16f5f..a2540169c 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -13,7 +13,6 @@ pub struct LoaderBuilder { no_kv_cache: bool, chat_template: Option, use_flash_attn: bool, - disable_attention_mask: bool, } impl LoaderBuilder { @@ -23,7 +22,6 @@ impl LoaderBuilder { no_kv_cache: false, chat_template: None, use_flash_attn: false, - disable_attention_mask: false, } } @@ -40,11 +38,6 @@ impl LoaderBuilder { self } - pub fn with_disable_attention_mask(mut self, disable_attention_mask: bool) -> Self { - self.disable_attention_mask = disable_attention_mask; - self - } - pub fn build(self) -> anyhow::Result> { loader_from_model_selected(self) } @@ -159,7 +152,6 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result anyhow::Result anyhow::Result anyhow::Result anyhow::Result anyhow::Result>, - disable_mask: bool, } impl ModelWeights { - pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize, disable_mask: bool) -> Result { + pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let rotary = RotaryEmbedding::new_partial( 10000., @@ -301,7 +300,6 @@ impl ModelWeights { cache: Cache::new(ct.hparams.n_layer as usize, false), max_seq_len: MAX_SEQ_LEN as usize, // Cannot determine from ggml. mapper: None, - disable_mask, }) } @@ -310,7 +308,6 @@ impl ModelWeights { reader: &mut R, device: &Device, mapper: DeviceMapMetadata, - disable_mask: bool, ) -> Result { let md_get = |s: &str| match ct.metadata.get(s) { None => candle_core::bail!("cannot find {s} in metadata"), @@ -433,7 +430,6 @@ impl ModelWeights { .and_then(|m| m.to_u64()) .unwrap_or(MAX_SEQ_LEN as u64) as usize, mapper: Some(mapper), - disable_mask, }) } @@ -459,7 +455,7 @@ impl ModelWeights { is_prompt: bool, ) -> Result { let (_b_sz, seq_len) = x.dims2()?; - let mask: Option = if seq_len == 1 || self.disable_mask { + let mask: Option = if seq_len == 1 { None } else { Some(self.mask(seq_len, x.device())?) 
diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index 140c5c118..faee99d31 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -101,7 +101,6 @@ pub struct GGMLLoader { tokenizer_json: Option, kind: ModelKind, tgt_non_granular_index: Option, - disable_attention_mask: bool, } #[derive(Clone, Copy, Default)] @@ -125,7 +124,6 @@ pub struct GGMLLoaderBuilder { chat_template: Option, tokenizer_json: Option, tgt_non_granular_index: Option, - disable_attention_mask: bool, } impl GGMLLoaderBuilder { @@ -204,11 +202,6 @@ impl GGMLLoaderBuilder { ) } - pub fn with_disable_attention_mask(mut self, disable_attention_mask: bool) -> Self { - self.disable_attention_mask = disable_attention_mask; - self - } - pub fn build(self) -> Box { Box::new(GGMLLoader { model_id: self.model_id.unwrap(), @@ -222,7 +215,6 @@ impl GGMLLoaderBuilder { tgt_non_granular_index: self.tgt_non_granular_index, quantized_filename: Some(self.quantized_filename), quantized_model_id: Some(self.quantized_model_id), - disable_attention_mask: self.disable_attention_mask, }) } } @@ -241,7 +233,6 @@ impl GGMLLoader { chat_template: Option, tokenizer_json: Option, tgt_non_granular_index: Option, - disable_attention_mask: bool, ) -> Self { let model_id = if let Some(id) = model_id { id @@ -264,7 +255,6 @@ impl GGMLLoader { tokenizer_json, kind, tgt_non_granular_index, - disable_attention_mask, } } } @@ -311,11 +301,7 @@ impl Loader for GGMLLoader { let mut is_lora = false; let model = match self.kind { - ModelKind::QuantizedGGML => Model::Llama(QLlama::from_ggml( - model, - self.config.gqa, - self.disable_attention_mask, - )?), + ModelKind::QuantizedGGML => Model::Llama(QLlama::from_ggml(model, self.config.gqa)?), ModelKind::XLoraGGML => { let vb = from_mmaped_safetensors( vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()], diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 3c0f0217d..589701db6 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -102,7 +102,6 @@ pub struct GGUFLoader { tokenizer_json: Option, kind: ModelKind, tgt_non_granular_index: Option, - disable_attention_mask: bool, } #[derive(Debug)] @@ -159,7 +158,6 @@ pub struct GGUFLoaderBuilder { chat_template: Option, tokenizer_json: Option, tgt_non_granular_index: Option, - disable_attention_mask: bool, } impl GGUFLoaderBuilder { @@ -238,11 +236,6 @@ impl GGUFLoaderBuilder { ) } - pub fn with_disable_attention_mask(mut self, disable_attention_mask: bool) -> Self { - self.disable_attention_mask = disable_attention_mask; - self - } - pub fn build(self) -> Box { Box::new(GGUFLoader { model_id: self.model_id.unwrap(), @@ -256,7 +249,6 @@ impl GGUFLoaderBuilder { tgt_non_granular_index: self.tgt_non_granular_index, quantized_filename: Some(self.quantized_filename), quantized_model_id: Some(self.quantized_model_id), - disable_attention_mask: self.disable_attention_mask, }) } } @@ -275,7 +267,6 @@ impl GGUFLoader { chat_template: Option, tokenizer_json: Option, tgt_non_granular_index: Option, - disable_attention_mask: bool, ) -> Self { let model_id = if let Some(id) = model_id { id @@ -298,7 +289,6 @@ impl GGUFLoader { tokenizer_json, kind, tgt_non_granular_index, - disable_attention_mask, } } } @@ -347,13 +337,9 @@ impl Loader for GGUFLoader { let mut is_lora = false; let model = match self.kind { ModelKind::QuantizedGGUF => match arch { - GGUFArchitecture::Llama => Model::Llama(QLlama::from_gguf( - model, 
- &mut file, - device, - mapper, - self.disable_attention_mask, - )?), + GGUFArchitecture::Llama => { + Model::Llama(QLlama::from_gguf(model, &mut file, device, mapper)?) + } GGUFArchitecture::Phi2 => { Model::Phi2(QPhi::from_gguf(model, &mut file, device, mapper)?) } From 0ca72557a5c3cbdcd64cd5f1478a0f3cc7da5c8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Mon, 29 Apr 2024 19:37:52 -0300 Subject: [PATCH 08/12] reduced precision, refactor --- mistralrs-core/src/lib.rs | 20 +++++++++ mistralrs-core/src/models/quantized_llama.rs | 44 ++++++++++---------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 347a5e058..abf364667 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -78,6 +78,7 @@ pub struct MistralRsBuilder { no_prefix_cache: Option, prefix_cache_n: Option, disable_eos_stop: Option, + gemm_full_precision_f16: Option, } impl MistralRsBuilder { @@ -91,6 +92,7 @@ impl MistralRsBuilder { no_prefix_cache: None, prefix_cache_n: None, disable_eos_stop: None, + gemm_full_precision_f16: None, } } @@ -122,12 +124,25 @@ impl MistralRsBuilder { self.disable_eos_stop = Some(disable_eos_stop); self } + pub fn with_gemm_full_precision_f16(mut self, gemm_full_precision: bool) -> Self { + self.gemm_full_precision_f16 = Some(gemm_full_precision); + self + } pub fn build(self) -> Arc { MistralRs::new(self) } } +#[cfg(feature = "cuda")] +fn set_gemm_reduced_precision_f16() { + candle_core::cuda::set_gemm_reduced_precision_f16(true); + candle_core::cuda::set_gemm_reduced_precision_bf16(true); +} + +#[cfg(not(feature = "cuda"))] +fn set_gemm_reduced_precision_f16() {} + impl MistralRs { fn new(config: MistralRsBuilder) -> Arc { let MistralRsBuilder { @@ -139,8 +154,13 @@ impl MistralRs { no_prefix_cache, prefix_cache_n, disable_eos_stop, + gemm_full_precision_f16, } = config; + if !gemm_full_precision_f16.unwrap_or(false) { + set_gemm_reduced_precision_f16(); + } + let truncate_sequence = truncate_sequence.unwrap_or(false); let no_kv_cache = no_kv_cache.unwrap_or(false); let no_prefix_cache = no_prefix_cache.unwrap_or(false); diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index b4cfd4f36..eed8e0807 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -16,11 +16,8 @@ use super::{repeat_kv, verify_sanity_gguf, Cache}; const MAX_SEQ_LEN: u32 = 4096; -fn quantized_mat_mul(xs: &Tensor, w: &QMatMul, is_prompt: bool) -> Result { - // TODO: For very small prompts, we should use forward - // For completions with batch size > 8, we should use forward_via_f16 - // TODO: benchmark and implement the above - if is_prompt { +fn quantized_mat_mul(xs: &Tensor, w: &QMatMul, via_f16: bool) -> Result { + if via_f16 { w.forward_via_f16(xs) } else { w.forward(xs) @@ -35,11 +32,11 @@ struct Mlp { } impl Mlp { - fn forward(&self, xs: &Tensor, is_prompt: bool) -> Result { - let w1 = quantized_mat_mul(xs, &self.feed_forward_w1, is_prompt)?; - let w3 = quantized_mat_mul(xs, &self.feed_forward_w3, is_prompt)?; + fn forward(&self, xs: &Tensor, via_f16: bool) -> Result { + let w1 = quantized_mat_mul(xs, &self.feed_forward_w1, via_f16)?; + let w3 = quantized_mat_mul(xs, &self.feed_forward_w3, via_f16)?; let y = &(candle_nn::ops::silu(&w1)? 
* w3)?;
-        quantized_mat_mul(y, &self.feed_forward_w2, is_prompt)
+        quantized_mat_mul(y, &self.feed_forward_w2, via_f16)
     }
 }
@@ -54,7 +51,7 @@ enum MlpOrMoe {
 }
 
 impl MlpOrMoe {
-    fn forward(&self, xs: &Tensor, is_prompt: bool) -> Result {
+    fn forward(&self, xs: &Tensor, via_f16: bool) -> Result {
         match self {
             Self::MoE {
                 feed_forward_gate_inp,
@@ -109,7 +106,7 @@ impl MlpOrMoe {
                     // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                     let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                     // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
-                    let current_hidden_states = expert_layer.forward(&current_state, is_prompt)?;
+                    let current_hidden_states = expert_layer.forward(&current_state, via_f16)?;
                     let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
                     ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
@@ -118,7 +115,7 @@ impl MlpOrMoe {
                 let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                 Ok(ys)
             }
-            Self::Mlp(mlp) => mlp.forward(xs, is_prompt),
+            Self::Mlp(mlp) => mlp.forward(xs, via_f16),
         }
     }
 }
@@ -153,13 +150,13 @@ impl LayerWeights {
         start_offsets: &[usize],
         start_offsets_kernel: Tensor,
         kv_cache: &mut Option<(Tensor, Tensor)>,
-        is_prompt: bool,
+        via_f16: bool,
     ) -> Result {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
-        let q = quantized_mat_mul(x, &self.attention_wq, is_prompt)?;
-        let k = quantized_mat_mul(x, &self.attention_wk, is_prompt)?;
-        let v = quantized_mat_mul(x, &self.attention_wv, is_prompt)?;
+        let q = quantized_mat_mul(x, &self.attention_wq, via_f16)?;
+        let k = quantized_mat_mul(x, &self.attention_wk, via_f16)?;
+        let v = quantized_mat_mul(x, &self.attention_wv, via_f16)?;
 
         let mut q = q.reshape((b_sz * seq_len, self.n_head, self.head_dim))?;
         let mut k = k.reshape((b_sz * seq_len, self.n_kv_head, self.head_dim))?;
@@ -192,7 +189,7 @@ impl LayerWeights {
         let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
         let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
-        let att = if is_prompt {
+        let att = if via_f16 {
             let mm = q
                 .to_dtype(DType::F16)?
                 .matmul(&k.to_dtype(DType::F16)?.t()?)?;
@@ -212,7 +209,7 @@ impl LayerWeights {
         };
         let att = candle_nn::ops::softmax_last_dim(&att)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = if is_prompt {
+        let y = if via_f16 {
             att.to_dtype(DType::F16)?
                 .matmul(&v.to_dtype(DType::F16)?)?
                 .to_dtype(DType::F32)?
@@ -221,7 +218,7 @@ impl LayerWeights {
         };
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = quantized_mat_mul(&y, &self.attention_wo, is_prompt)?;
+        let y = quantized_mat_mul(&y, &self.attention_wo, via_f16)?;
         Ok(y)
     }
 }
@@ -460,6 +457,9 @@ impl ModelWeights {
         } else {
             Some(self.mask(seq_len, x.device())?)
}; + + let via_f16 = if is_prompt { seq_len > 32 } else { false }; + let mut layer_in = self.tok_embeddings.forward(x)?; let mut cache = self.cache.lock(); for (i, layer) in self.layers.iter_mut().enumerate() { @@ -475,21 +475,21 @@ impl ModelWeights { start_offsets, start_offsets_kernel.clone(), &mut cache[i], - is_prompt, + via_f16, )?; let x = (attn + residual)?; // MLP let residual = &x; let x = layer.ffn_norm.forward(&x)?; - let x = layer.mlp_or_moe.forward(&x, is_prompt)?; + let x = layer.mlp_or_moe.forward(&x, via_f16)?; let x = (x + residual)?; layer_in = x; } let layer_in = layer_in.to_device(&self.device)?; let x = self.norm.forward(&layer_in)?; extract_logits( - &quantized_mat_mul(&x.contiguous()?, &self.output, is_prompt)?, + &quantized_mat_mul(&x.contiguous()?, &self.output, via_f16)?, context_lens, ) } From 4419381dc4de17a8cf169e89489a91cb2495fd76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Mon, 29 Apr 2024 19:40:52 -0300 Subject: [PATCH 09/12] cliipy --- mistralrs-lora/src/loralinear.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/mistralrs-lora/src/loralinear.rs b/mistralrs-lora/src/loralinear.rs index 49269159e..b2bb5ab96 100644 --- a/mistralrs-lora/src/loralinear.rs +++ b/mistralrs-lora/src/loralinear.rs @@ -156,7 +156,6 @@ impl Merge for LoraLinear { } self.old = QLinear::from_parts(w_base_layer, self.old.bias().cloned()); } - QMatMul::TensorF16(_) => todo!(), }; self.merged = true; Ok(()) From 9d65c2d409f7fb3fffc913ba88eac4a1dc6ed438 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20=C3=81vila?= Date: Mon, 13 May 2024 22:08:53 -0300 Subject: [PATCH 10/12] changes --- mistralrs-core/src/models/quantized_llama.rs | 5 +++-- mistralrs-core/src/pipeline/ggml.rs | 1 - mistralrs-core/src/pipeline/gguf.rs | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index 90d0f6ed0..c7de58751 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -192,7 +192,6 @@ impl LayerWeights { (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())? }; - let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())?; let att = CausalMasker.apply_mask(mask, att, &self.neg_inf)?; let att = candle_nn::ops::softmax_last_dim(&att)?; // Convert to contiguous as matmul doesn't support strided vs for now. 
@@ -422,8 +421,10 @@ impl ModelWeights { start_offsets: &[usize], start_offsets_kernel: Tensor, context_lens: Vec<(usize, usize)>, - is_prompt: bool, ) -> Result { + let (bz, seq_len, _) = x.dims3()?; + let via_f16 = bz * seq_len > 256; + let mut layer_in = self.tok_embeddings.forward(x)?; let mut cache = self.cache.lock(); let mask = CausalMasker.make_causal_mask(x, &cache)?; diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index ba4d63753..811e7b1ff 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -391,7 +391,6 @@ impl Pipeline for GGMLPipeline { &seqlen_offsets, seqlen_offsets_kernel, context_lens, - is_prompt, ), Model::XLoraLlama(ref mut model) => model.forward( &input_ids, diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 49103bfef..9c050f559 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -523,7 +523,6 @@ impl Pipeline for GGUFPipeline { &seqlen_offsets, seqlen_offsets_kernel, context_lens, - is_prompt, ), Model::Phi2(ref mut model) => model.forward(&input_ids, &seqlen_offsets, context_lens), Model::XLoraLlama(ref mut model) => model.forward( From 43ee0ad496e8248d22484a5f2e9effa053eea02b Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Tue, 14 May 2024 08:24:02 -0400 Subject: [PATCH 11/12] Update mistralrs-core/src/models/quantized_llama.rs --- mistralrs-core/src/models/quantized_llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index c7de58751..b350fb385 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -423,7 +423,7 @@ impl ModelWeights { context_lens: Vec<(usize, usize)>, ) -> Result { let (bz, seq_len, _) = x.dims3()?; - let via_f16 = bz * seq_len > 256; + let via_f16 = seq_len > 32; let mut layer_in = self.tok_embeddings.forward(x)?; let mut cache = self.cache.lock(); From 64f656b32552b982d7a6a900732eb5033ac405ec Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Tue, 14 May 2024 08:38:18 -0400 Subject: [PATCH 12/12] Update mistralrs-core/src/models/quantized_llama.rs --- mistralrs-core/src/models/quantized_llama.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index b350fb385..19468cee6 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -422,7 +422,7 @@ impl ModelWeights { start_offsets_kernel: Tensor, context_lens: Vec<(usize, usize)>, ) -> Result { - let (bz, seq_len, _) = x.dims3()?; + let (_bz, seq_len, _) = x.dims3()?; let via_f16 = seq_len > 32; let mut layer_in = self.tok_embeddings.forward(x)?;
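
Where the series lands: every quantized projection goes through one dispatch helper, and the prompt-versus-decode decision reduces to a sequence-length check. The sketch below condenses that final shape for reference rather than adding new behaviour; it assumes the `QMatMul::forward_via_f16` method these patches rely on (a property of the candle revision this tree builds against, not necessarily of every candle release), and the `use_f16_path` helper name is purely illustrative, since the patches inline the check as `let via_f16 = seq_len > 32;`.

use candle_core::quantized::QMatMul;
use candle_core::{Result, Tensor};

// Single dispatch point: prompt-sized inputs take the F16 GEMM path (cuBLAS on
// CUDA), while single-token decode steps keep the fused quantized matmul kernel.
fn quantized_mat_mul(xs: &Tensor, w: &QMatMul, via_f16: bool) -> Result<Tensor> {
    if via_f16 {
        // Dequantize-to-F16 path, as relied on throughout these patches.
        w.forward_via_f16(xs)
    } else {
        // Regular quantized kernel; cheapest when the activation is a single row.
        w.forward(xs)
    }
}

// Threshold from the final revision of ModelWeights::forward: only inputs long
// enough to be GEMM-bound are worth the dequantize-to-F16 detour.
fn use_f16_path(seq_len: usize) -> bool {
    seq_len > 32
}

Decode steps keep the quantized kernel because dequantizing the full weight matrix for every generated token would cost more than the batch-1 matrix-vector product it feeds.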