feat(candle): support Qwen2 on Cuda (#316)
OlivierDehaene authored Jul 2, 2024
1 parent 6c6cd93 commit 4b2ab61
Showing 18 changed files with 6,805 additions and 26 deletions.
24 changes: 13 additions & 11 deletions README.md
@@ -64,20 +64,22 @@ Ember, GTE and E5. TEI implements many features such as:
#### Text Embeddings

Text Embeddings Inference currently supports Nomic, BERT, CamemBERT, XLM-RoBERTa models with absolute positions, JinaBERT
-model with Alibi positions and Mistral, Alibabe GTE models with Rope positions.
+model with Alibi positions and Mistral, Alibaba GTE and Qwen2 models with Rope positions.

Below are some examples of the currently supported models:

-| MTEB Rank | Model Size | Model Type | Model ID |
-|-----------|----------------|-------------|--------------------------------------------------------------------------------------------------|
-| 1 | 7B (Very Slow) | Mistral | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R) |
-| 15 | 0.4B | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](Alibaba-NLP/gte-large-en-v1.5) |
-| 20 | 0.3B | Bert | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1) |
-| 24 | 0.5B | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
-| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) |
-| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
-| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
-| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
+| MTEB Rank | Model Size | Model Type | Model ID |
+|-----------|---------------------|-------------|--------------------------------------------------------------------------------------------------|
+| 1 | 7B (Very Expensive) | Mistral | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R) |
+| 2 | 7B (Very Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct) |
+| 9 | 1.5B (Expensive) | Qwen2 | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct) |
+| 15 | 0.4B | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5) |
+| 20 | 0.3B | Bert | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1) |
+| 24 | 0.5B | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct) |
+| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1) |
+| N/A | 0.1B | NomicBert | [nomic-ai/nomic-embed-text-v1.5](https://hf.co/nomic-ai/nomic-embed-text-v1.5) |
+| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
+| N/A | 0.1B | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |


To explore the list of best performing text embeddings models, visit the
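Since the README excerpt above sorts models by their position-encoding family (absolute, Alibi, Rope), here is a brief illustrative sketch of rotary position embeddings, the "Rope positions" shared by Mistral, Alibaba GTE, and Qwen2. This is a generic textbook implementation, not TEI's code; the base of 10,000 is the common convention and an assumption here.

```rust
// Illustrative RoPE sketch, not TEI's code: each channel pair (x[2i], x[2i+1])
// of a query/key vector is rotated by an angle proportional to the token
// position, so relative offsets show up directly in attention dot products.
fn apply_rope(x: &mut [f32], position: usize, base: f32) {
    let dim = x.len();
    for i in (0..dim).step_by(2) {
        // Per-pair frequency base^(-i/dim): early pairs rotate fast, later slowly.
        let theta = base.powf(-(i as f32) / dim as f32) * position as f32;
        let (sin, cos) = theta.sin_cos();
        let (x0, x1) = (x[i], x[i + 1]);
        x[i] = x0 * cos - x1 * sin;
        x[i + 1] = x0 * sin + x1 * cos;
    }
}

fn main() {
    let mut q = [1.0f32, 0.0, 1.0, 0.0];
    apply_rope(&mut q, 3, 10_000.0);
    println!("{q:?}");
}
```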
26 changes: 21 additions & 5 deletions backends/candle/src/lib.rs
@@ -12,12 +12,12 @@ use crate::compute_cap::{
};
use crate::models::{
BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
-JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig,
+JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
};
#[cfg(feature = "cuda")]
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
-FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel,
+FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel, FlashQwen2Model,
};
use anyhow::Context;
use candle::{DType, Device};
@@ -59,6 +59,7 @@ enum Config {
Mistral(MistralConfig),
#[serde(rename = "new")]
Gte(GTEConfig),
+Qwen2(Qwen2Config),
}

pub struct CandleBackend {
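The new `Qwen2(Qwen2Config)` variant above hooks into config.json detection. A minimal, hypothetical sketch of the mechanism (the attribute names and config fields are assumptions for illustration, not the crate's exact definitions): serde's internally tagged representation selects the enum variant from the file's `model_type` field, which is why GTE needs the explicit `#[serde(rename = "new")]`.

```rust
// Hypothetical sketch, not the crate's exact code: serde picks the Config
// variant from the `model_type` field of a model's config.json.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Qwen2Config {
    hidden_size: usize, // assumed field, for illustration only
}

#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum Config {
    Qwen2(Qwen2Config),
}

fn main() {
    let raw = r#"{ "model_type": "qwen2", "hidden_size": 3584 }"#;
    let config: Config = serde_json::from_str(raw).unwrap();
    println!("{config:?}"); // prints: Qwen2(Qwen2Config { hidden_size: 3584 })
}
```

With this layout, adding a new architecture is a one-line enum change plus a config struct, which is exactly the shape of this diff.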
@@ -221,6 +222,10 @@ impl CandleBackend {
"GTE is only supported on Cuda devices in fp16 with flash attention enabled"
.to_string(),
)),
+(Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+"Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
+.to_string(),
+)),
#[cfg(feature = "cuda")]
(Config::Bert(config), Device::Cuda(_)) => {
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -342,14 +347,25 @@
#[cfg(feature = "cuda")]
(Config::Gte(config), Device::Cuda(_)) => {
if dtype != DType::F16
-|| !cfg!(feature = "flash-attn")
-|| get_runtime_compute_cap().unwrap() < 80
+|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
{
return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
}
tracing::info!("Starting FlashGTE model on {:?}", device);
Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
}
#[cfg(feature = "cuda")]
(Config::Qwen2(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
{
return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+}
+tracing::info!("Starting FlashQwen2 model on {:?}", device);
+Ok(Box::new(
+FlashQwen2Model::load(vb, &config, model_type).s()?,
+))
+}
};

Ok(Self {
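A compressed, hypothetical sketch of the dispatch pattern in this file (simplified types and messages, not the crate's exact signatures): the backend matches on the (config, device) pair, and `cfg!` folds the feature check into a compile-time boolean, so builds without flash attention reject Qwen2 at startup rather than at inference time.

```rust
// Simplified sketch of the (config, device) dispatch above; the types and
// error strings are illustrative, not the crate's exact definitions.
enum Device { Cpu, Cuda }
enum Config { Qwen2 }

fn start(config: Config, device: Device, fp16: bool) -> Result<&'static str, String> {
    match (config, device) {
        (Config::Qwen2, Device::Cpu) => {
            Err("Qwen2 is only supported on Cuda devices".to_string())
        }
        (Config::Qwen2, Device::Cuda) => {
            // cfg!(...) is evaluated at compile time; this branch is static.
            if !fp16 || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) {
                return Err("Qwen2 needs fp16 with flash attention".to_string());
            }
            Ok("FlashQwen2 started")
        }
    }
}

fn main() {
    println!("{:?}", start(Config::Qwen2, Device::Cpu, true));
}
```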
5 changes: 4 additions & 1 deletion backends/candle/src/models/flash_bert.rs
@@ -400,7 +400,10 @@ impl FlashBertModel {
// Get token indices from cu_seqlens
let mut indices = match self.pool {
Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
-Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?,
+Pool::LastToken => {
+let end = cu_seqlens.narrow(0, 1, batch_size)?;
+(&end - &end.ones_like()?)?
+}
_ => unreachable!(),
};

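This is the core fix of the commit, repeated verbatim in each flash model below. In these variable-length packed batches, `cu_seqlens` holds cumulative sequence lengths `[0, len0, len0 + len1, ...]`, so `narrow(0, 1, batch_size)` yields each sequence's exclusive end offset, one past its last token; the old code therefore pooled the wrong row (the next sequence's first token, or out of range for the final one). Subtracting one lands on the last token. A plain-Rust sketch with toy values (no candle dependency):

```rust
// Sketch of the off-by-one this hunk fixes, on plain slices instead of
// candle tensors: the last token of sequence i sits at cu_seqlens[i + 1] - 1.
fn last_token_indices(cu_seqlens: &[u32]) -> Vec<u32> {
    let batch_size = cu_seqlens.len() - 1;
    // Equivalent of `cu_seqlens.narrow(0, 1, batch_size)`: exclusive ends.
    let ends = &cu_seqlens[1..=batch_size];
    // Equivalent of `(&end - &end.ones_like()?)?`: step back one position.
    ends.iter().map(|end| end - 1).collect()
}

fn main() {
    // Two packed sequences of lengths 3 and 5: token rows 0..3 and 3..8.
    let cu_seqlens = [0u32, 3, 8];
    // Last tokens are rows 2 and 7; the old code would have picked 3 and 8.
    assert_eq!(last_token_indices(&cu_seqlens), vec![2, 7]);
}
```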
5 changes: 4 additions & 1 deletion backends/candle/src/models/flash_distilbert.rs
@@ -266,7 +266,10 @@ impl FlashDistilBertModel {
// Get token indices from cu_seqlens
let mut indices = match self.pool {
Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
-Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?,
+Pool::LastToken => {
+let end = cu_seqlens.narrow(0, 1, batch_size)?;
+(&end - &end.ones_like()?)?
+}
_ => unreachable!(),
};

5 changes: 4 additions & 1 deletion backends/candle/src/models/flash_gte.rs
@@ -352,7 +352,10 @@ impl FlashGTEModel {
// Get token indices from cu_seqlens
let mut indices = match self.pool {
Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
-Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?,
+Pool::LastToken => {
+let end = cu_seqlens.narrow(0, 1, batch_size)?;
+(&end - &end.ones_like()?)?
+}
_ => unreachable!(),
};

5 changes: 4 additions & 1 deletion backends/candle/src/models/flash_jina.rs
@@ -327,7 +327,10 @@ impl FlashJinaBertModel {
// Get token indices from cu_seqlens
let mut indices = match self.pool {
Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
-Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?,
+Pool::LastToken => {
+let end = cu_seqlens.narrow(0, 1, batch_size)?;
+(&end - &end.ones_like()?)?
+}
_ => unreachable!(),
};

5 changes: 4 additions & 1 deletion backends/candle/src/models/flash_jina_code.rs
@@ -380,7 +380,10 @@ impl FlashJinaCodeBertModel {
// Get token indices from cu_seqlens
let mut indices = match self.pool {
Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
-Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?,
+Pool::LastToken => {
+let end = cu_seqlens.narrow(0, 1, batch_size)?;
+(&end - &end.ones_like()?)?
+}
_ => unreachable!(),
};

5 changes: 4 additions & 1 deletion backends/candle/src/models/flash_mistral.rs
@@ -334,7 +334,10 @@ impl FlashMistralModel {
// Get token indices from cu_seqlens
let mut indices = match self.pool {
Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
-Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?,
+Pool::LastToken => {
+let end = cu_seqlens.narrow(0, 1, batch_size)?;
+(&end - &end.ones_like()?)?
+}
_ => unreachable!(),
};

5 changes: 4 additions & 1 deletion backends/candle/src/models/flash_nomic.rs
@@ -311,7 +311,10 @@ impl FlashNomicBertModel {
// Get token indices from cu_seqlens
let mut indices = match self.pool {
Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?,
-Pool::LastToken => cu_seqlens.narrow(0, 1, batch_size)?,
+Pool::LastToken => {
+let end = cu_seqlens.narrow(0, 1, batch_size)?;
+(&end - &end.ones_like()?)?
+}
_ => unreachable!(),
};

