Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: New FlagEmbedding models #6

Merged
merged 2 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ The default embedding supports "query" and "passage" prefixes for the input text

## 🤖 Models

- [**BAAI/bge-base-en**](https://huggingface.co/BAAI/bge-base-en)
- [**BAAI/bge-base-en-v1.5**](https://huggingface.co/BAAI/bge-base-en-v1.5)
- [**BAAI/bge-small-en**](https://huggingface.co/BAAI/bge-small-en)
- [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default
- [**BAAI/bge-base-zh-v1.5**](https://huggingface.co/BAAI/bge-base-zh-v1.5)
- [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)

## 🚀 Installation
Expand Down
22 changes: 20 additions & 2 deletions fastembed.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ type EmbeddingModel string
const (
AllMiniLML6V2 EmbeddingModel = "fast-all-MiniLM-L6-v2"
BGEBaseEN EmbeddingModel = "fast-bge-base-en"
BGEBaseENV15 EmbeddingModel = "fast-bge-base-en-v1.5"
BGESmallEN EmbeddingModel = "fast-bge-small-en"
BGESmallENV15 EmbeddingModel = "fast-bge-small-en-v1.5"
BGESmallZH EmbeddingModel = "fast-bge-small-zh-v1.5"

// A model with type "Unigram" is not yet supported by the tokenizer
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
Expand Down Expand Up @@ -79,7 +82,7 @@ func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) {
}

if options.Model == "" {
options.Model = BGESmallEN
options.Model = BGESmallENV15
}

if options.MaxLength == 0 {
Expand Down Expand Up @@ -281,10 +284,25 @@ func ListSupportedModels() []ModelInfo {
Dim: 768,
Description: "Base English model",
},
{
Model: BGEBaseENV15,
Dim: 768,
Description: "v1.5 release of the base English model",
},
{
Model: BGESmallEN,
Dim: 384,
Description: "Fast and Default English model",
Description: "Fast English model",
},
{
Model: BGESmallENV15,
Dim: 384,
Description: "Fast, default English model",
},
{
Model: BGESmallZH,
Dim: 512,
Description: "Fast Chinese model",
},
// {
// Model: MLE5Large,
Expand Down
97 changes: 8 additions & 89 deletions fastembed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,95 +5,14 @@ import (
"testing"
)

func TestEmbedBGEBaseEN(t *testing.T) {
// Test with a single input
fe, err := NewFlagEmbedding(&InitOptions{
Model: BGEBaseEN,
})
defer fe.Destroy()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

func TestEmbedAllMiniLML6V2(t *testing.T) {
// Test with a single input
fe, err := NewFlagEmbedding(&InitOptions{
Model: AllMiniLML6V2,
})
defer fe.Destroy()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

func TestEmbedBGESmallEN(t *testing.T) {
// Test with a single input
fe, err := NewFlagEmbedding(&InitOptions{
Model: BGESmallEN,
})
defer fe.Destroy()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
input := []string{"hello world"}
result, err := fe.Embed(input, 1)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

if len(result) != len(input) {
t.Errorf("Expected result length %v, got %v", len(input), len(result))
}
}

// A model type "Unigram" is not yet supported by the tokenizer
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
// func TestEmbedMLE5Large(t *testing.T) {
// // Test with a single input
// show := false
// fe, err := NewFlagEmbedding(&InitOptions{
// Model: MLE5Large,
// ShowDownloadProgress: &show,
// })
// defer fe.Destroy()
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }
// input := []string{"hello world"}
// result, err := fe.Embed(input, 1)
// if err != nil {
// t.Fatalf("Expected no error, got %v", err)
// }

// if len(result) != len(input) {
// t.Errorf("Expected result length %v, got %v", len(input), len(result))
// }
// }

func TestCanonicalValues(T *testing.T) {
canonicalValues := map[EmbeddingModel]([]float32){
AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328, -0.05493, 0.014040, -0.01079, -0.02440, -0.01822},
BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.02212, -0.01472, 0.03925, 0.03444, 0.00459},
BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451, 0.00876, 0.02356, 0.05414, -0.02945, -0.05472},
AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328},
BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061},
BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451},
BGEBaseENV15: []float32{0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045},
BGESmallENV15: []float32{0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434},
BGESmallZH: []float32{-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762},
}

for model, expected := range canonicalValues {
Expand All @@ -114,10 +33,10 @@ func TestCanonicalValues(T *testing.T) {
T.Errorf("Expected result length %v, got %v", len(input), len(result))
}

epsilon := float64(1e-5)
epsilon := float64(1e-4)
for i, v := range expected {
if math.Abs(float64(result[0][i]-v)) > float64(epsilon) {
T.Errorf("Element %d mismatch: expected %.6f, got %.6f", i, v, result[0][i])
T.Errorf("Element %d mismatch for %s: expected %.6f, got %.6f", i, model, v, result[0][i])
}
}
}
Expand Down