feat: New FlagEmbedding models #5

Merged · 5 commits · Oct 18, 2023
Changes from all commits
2 changes: 1 addition & 1 deletion Cargo.lock

3 changes: 3 additions & 0 deletions README.md
@@ -22,8 +22,11 @@ The default embedding supports "query" and "passage" prefixes for the input text
 
 ## 🤖 Models
 
 - [**BAAI/bge-base-en**](https://huggingface.co/BAAI/bge-base-en)
+- [**BAAI/bge-base-en-v1.5**](https://huggingface.co/BAAI/bge-base-en-v1.5)
 - [**BAAI/bge-small-en**](https://huggingface.co/BAAI/bge-small-en)
+- [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default
+- [**BAAI/bge-base-zh-v1.5**](https://huggingface.co/BAAI/bge-base-zh-v1.5)
 - [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
 - [**intfloat/multilingual-e5-large**](https://huggingface.co/intfloat/multilingual-e5-large)
 
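To make the new entries concrete, here is a minimal sketch of selecting one of the listed models. The `try_new`/`embed` calls mirror this PR's tests; the `fastembed::` import path and the `passage:` prefix are assumptions based on the README note in the hunk above.

```rust
use fastembed::{EmbeddingModel, FlagEmbedding, InitOptions};

fn main() {
    // Pick the new v1.5 base English model instead of the default.
    let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
        model_name: EmbeddingModel::BGEBaseENV15,
        ..Default::default()
    })
    .unwrap();

    // The default embedding supports "query" and "passage" prefixes.
    let embeddings = model.embed(vec!["passage: hello world"], None).unwrap();
    assert_eq!(embeddings[0].len(), 768); // bge-base-en-v1.5 is 768-dimensional
}
```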
158 changes: 74 additions & 84 deletions src/lib.rs
@@ -96,7 +96,7 @@ use tokenizers::{AddedToken, PaddingParams, PaddingStrategy, TruncationParams};
 const DEFAULT_BATCH_SIZE: usize = 256;
 const DEFAULT_MAX_LENGTH: usize = 512;
 const DEFAULT_CACHE_DIR: &str = "local_cache";
-const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallEN;
+const DEFAULT_EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::BGESmallENV15;
 
 /// Type alias for the embedding vector
 pub type Embedding = Vec<f32>;
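Since `DEFAULT_EMBEDDING_MODEL` now points at `BGESmallENV15`, code relying on `Default` silently picks up the v1.5 weights. A small sketch, assuming `InitOptions` exposes a public `model_name` field as this PR's tests suggest:

```rust
use fastembed::InitOptions;

fn main() {
    // Options built via Default now resolve to the v1.5 small English model.
    let options = InitOptions::default();
    assert_eq!(options.model_name.to_string(), "fast-bge-small-en-v1.5");
}
```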
@@ -116,8 +116,14 @@ pub enum EmbeddingModel {
     AllMiniLML6V2,
     /// Base English model
     BGEBaseEN,
+    /// v1.5 release of the Base English model
+    BGEBaseENV15,
     /// Fast and Default English model
     BGESmallEN,
+    /// v1.5 release of the BGESmallEN model
+    BGESmallENV15,
+    /// v1.5 release of the Fast Chinese model
+    BGESmallZH,
     /// Multilingual model, e5-large. Recommend using this model for non-English languages.
     MLE5Large,
 }
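`BGESmallZH` is the first Chinese model in the enum. A hedged sketch of using it; the 512 dimension comes from the `ModelInfo` entry added further down in this diff:

```rust
use fastembed::{EmbeddingModel, FlagEmbedding, InitOptions};

fn main() {
    // Initialize the new small Chinese model added by this PR.
    let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
        model_name: EmbeddingModel::BGESmallZH,
        ..Default::default()
    })
    .unwrap();

    let embeddings = model.embed(vec!["你好，世界"], None).unwrap();
    assert_eq!(embeddings[0].len(), 512); // dim per the ModelInfo entry below
}
```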
Expand All @@ -127,7 +133,10 @@ impl ToString for EmbeddingModel {
match self {
EmbeddingModel::AllMiniLML6V2 => String::from("fast-all-MiniLM-L6-v2"),
EmbeddingModel::BGEBaseEN => String::from("fast-bge-base-en"),
EmbeddingModel::BGEBaseENV15 => String::from("fast-bge-base-en-v1.5"),
EmbeddingModel::BGESmallEN => String::from("fast-bge-small-en"),
EmbeddingModel::BGESmallENV15 => String::from("fast-bge-small-en-v1.5"),
EmbeddingModel::BGESmallZH => String::from("fast-bge-small-zh-v1.5"),
EmbeddingModel::MLE5Large => String::from("fast-multilingual-e5-large"),
}
}
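Each new variant maps to a `fast-...` identifier string, presumably used for the per-model directory under `DEFAULT_CACHE_DIR` above. A quick check of the three new mappings, taken verbatim from the match arms:

```rust
use fastembed::EmbeddingModel;

fn main() {
    // The ToString impl yields the identifiers added in this PR.
    for (model, expected) in [
        (EmbeddingModel::BGEBaseENV15, "fast-bge-base-en-v1.5"),
        (EmbeddingModel::BGESmallENV15, "fast-bge-small-en-v1.5"),
        (EmbeddingModel::BGESmallZH, "fast-bge-small-zh-v1.5"),
    ] {
        assert_eq!(model.to_string(), expected);
    }
}
```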
@@ -352,10 +361,25 @@ impl FlagEmbedding {
                 dim: 768,
                 description: String::from("Base English model"),
             },
+            ModelInfo {
+                model: EmbeddingModel::BGEBaseENV15,
+                dim: 768,
+                description: String::from("v1.5 release of the base English model"),
+            },
             ModelInfo {
                 model: EmbeddingModel::BGESmallEN,
                 dim: 384,
-                description: String::from("Fast and Default English model"),
+                description: String::from("Fast English model"),
             },
+            ModelInfo {
+                model: EmbeddingModel::BGESmallENV15,
+                dim: 384,
+                description: String::from("v1.5 release of the fast and default English model"),
+            },
+            ModelInfo {
+                model: EmbeddingModel::BGESmallZH,
+                dim: 512,
+                description: String::from("v1.5 release of the fast Chinese model"),
+            },
             ModelInfo {
                 model: EmbeddingModel::MLE5Large,
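These `ModelInfo` entries make the new models discoverable at runtime. A sketch, assuming the enclosing function (collapsed in this diff) is `FlagEmbedding::list_supported_models()` returning `Vec<ModelInfo>` with public fields; the name is hypothetical if the crate differs:

```rust
use fastembed::FlagEmbedding;

fn main() {
    // Print every supported model with its dimension and description.
    for info in FlagEmbedding::list_supported_models() {
        println!(
            "{} ({} dims): {}",
            info.model.to_string(),
            info.dim,
            info.description
        );
    }
}
```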
@@ -504,93 +528,59 @@ mod tests {
     const EPSILON: f32 = 1e-4;
 
     #[test]
-    fn test_bgesmall() {
-        let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
-            model_name: EmbeddingModel::BGESmallEN,
-            ..Default::default()
-        })
-        .unwrap();
-
-        let expected: Vec<f32> = vec![
-            -0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.022123, -0.01472, 0.039255,
-            0.034447, 0.004598,
-        ];
-        let documents = vec!["hello world"];
-
-        // Generate embeddings with the default batch size, 256
-        let embeddings = model.embed(documents, None).unwrap();
-
-        for (i, v) in expected.into_iter().enumerate() {
-            let difference = (v - embeddings[0][i]).abs();
-            assert!(difference < EPSILON, "Difference: {}", difference)
-        }
-    }
-
-    #[test]
-    fn test_bgebase() {
-        let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
-            model_name: EmbeddingModel::BGEBaseEN,
-            ..Default::default()
-        })
-        .unwrap();
-
-        let expected: Vec<f32> = vec![
-            0.0114, 0.03722, 0.02941, 0.0123, 0.03451, 0.00876, 0.02356, 0.05414, -0.0294, -0.0547,
-        ];
-        let documents = vec!["hello world"];
-
-        // Generate embeddings with the default batch size, 256
-        let embeddings = model.embed(documents, None).unwrap();
-
-        for (i, v) in expected.into_iter().enumerate() {
-            let difference = (v - embeddings[0][i]).abs();
-            assert!(difference < EPSILON, "Difference: {}", difference)
-        }
-    }
-
-    #[test]
-    fn test_allminilm() {
-        let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
-            model_name: EmbeddingModel::AllMiniLML6V2,
-            ..Default::default()
-        })
-        .unwrap();
-
-        let expected: Vec<f32> = vec![
-            0.02591, 0.00573, 0.01147, 0.03796, -0.0232, -0.0549, 0.01404, -0.0107, -0.0244,
-            -0.01822,
-        ];
-        let documents = vec!["hello world"];
-
-        // Generate embeddings with the default batch size, 256
-        let embeddings = model.embed(documents, None).unwrap();
-
-        for (i, v) in expected.into_iter().enumerate() {
-            let difference = (v - embeddings[0][i]).abs();
-            assert!(difference < EPSILON, "Difference: {}", difference)
-        }
-    }
-
-    #[test]
-    fn test_mle5large() {
-        let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
-            model_name: EmbeddingModel::MLE5Large,
-            ..Default::default()
-        })
-        .unwrap();
-
-        let expected: Vec<f32> = vec![
-            0.00961, 0.00443, 0.00658, -0.03532, 0.00703, -0.02878, -0.03671, 0.03482, 0.06343,
-            -0.04731,
-        ];
-        let documents = vec!["hello world"];
-
-        // Generate embeddings with the default batch size, 256
-        let embeddings = model.embed(documents, None).unwrap();
-
-        for (i, v) in expected.into_iter().enumerate() {
-            let difference = (v - embeddings[0][i]).abs();
-            assert!(difference < EPSILON, "Difference: {}", difference)
-        }
-    }
+    fn test_embeddings() {
+        let models_and_expected_values = vec![
+            (
+                EmbeddingModel::BGESmallEN,
+                vec![-0.02313, -0.02552, 0.017357, -0.06393, -0.00061],
+            ),
+            (
+                EmbeddingModel::BGEBaseEN,
+                vec![0.0114, 0.03722, 0.02941, 0.0123, 0.03451],
+            ),
+            (
+                EmbeddingModel::AllMiniLML6V2,
+                vec![0.02591, 0.00573, 0.01147, 0.03796, -0.0232],
+            ),
+            (
+                EmbeddingModel::MLE5Large,
+                vec![0.00961, 0.00443, 0.00658, -0.03532, 0.00703],
+            ),
+            (
+                EmbeddingModel::BGEBaseENV15,
+                vec![0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045],
+            ),
+            (
+                EmbeddingModel::BGESmallENV15,
+                vec![0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434],
+            ),
+            (
+                EmbeddingModel::BGESmallZH,
+                vec![-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762],
+            ),
+        ];
+
+        for (model_name, expected) in models_and_expected_values {
+            let model: FlagEmbedding = FlagEmbedding::try_new(InitOptions {
+                model_name: model_name.clone(),
+                ..Default::default()
+            })
+            .unwrap();
+
+            let documents = vec!["hello world"];
+
+            // Generate embeddings with the default batch size, 256
+            let embeddings = model.embed(documents, None).unwrap();
+
+            for (i, v) in expected.into_iter().enumerate() {
+                let difference = (v - embeddings[0][i]).abs();
+                assert!(
+                    difference < EPSILON,
+                    "Difference for {}: {}",
+                    model_name.to_string(),
+                    difference
+                )
+            }
+        }
+    }
 }
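The consolidated test compares only the first five embedding values per model against an absolute tolerance. A self-contained sketch of that check, with illustrative numbers rather than real model output:

```rust
// Absolute-tolerance comparison, as applied per element in test_embeddings.
fn within_epsilon(expected: &[f32], actual: &[f32], epsilon: f32) -> bool {
    expected
        .iter()
        .zip(actual)
        .all(|(e, a)| (e - a).abs() < epsilon)
}

fn main() {
    let epsilon = 1e-4; // same EPSILON as the test module
    // Illustrative values only; the real test uses the model's output.
    let expected = [0.0114_f32, 0.03722, 0.02941, 0.0123, 0.03451];
    let actual = [0.01141_f32, 0.037218, 0.029405, 0.012295, 0.034514];
    assert!(within_epsilon(&expected, &actual, epsilon));
}
```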