-
Notifications
You must be signed in to change notification settings - Fork 163
/
Copy pathbasic_bge.py
27 lines (23 loc) · 1.13 KB
/
basic_bge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# BAAI/bge-large-zh-v1.5
# BAAI/bge-large-en-v1.5
# BAAI/bge-base-zh-v1.5
# BAAI/bge-base-en-v1.5
# BAAI/bge-small-zh-v1.5
# BAAI/bge-small-en-v1.5
# Local directory of the checkpoint that the rest of this script loads.
root_model_path = 'E:/data/pretrain_ckpt/BAAI/bge-large-zh-v1.5'

# Two small batches of placeholder sentences ("sample data-N"); the script
# scores every sentence of the first batch against every sentence of the second.
sentences_1, sentences_2 = (
    ["样例数据-1", "样例数据-2"],
    ["样例数据-3", "样例数据-4"],
)
print('=========================================sentence transformer====================================')
from sentence_transformers import SentenceTransformer

# Load the checkpoint through sentence-transformers.
model = SentenceTransformer(root_model_path)

# Encode both batches the same way, first sentences_1 and then sentences_2.
# normalize_embeddings=True asks the library for normalized vectors
# (presumably unit-length, which would make the product below a
# cosine-similarity matrix -- confirm against the sentence-transformers docs).
embeddings_1, embeddings_2 = (
    model.encode(batch, normalize_embeddings=True)
    for batch in (sentences_1, sentences_2)
)

# Pairwise scores: rows follow sentences_1, columns follow sentences_2
# (assuming encode returns one embedding row per input sentence).
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
print('=========================================bert4torch====================================')
import torch
from bert4torch.pipelines import Text2Vec

# Use the GPU when one is available, otherwise fall back to the CPU so the
# demo still runs on machines without CUDA (the device used to be hard-coded
# to 'cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the same checkpoint through bert4torch's Text2Vec pipeline.
text2vec = Text2Vec(checkpoint_path=root_model_path, device=device)

# Same encode calls and arguments as the sentence-transformers section, so the
# two printed matrices can be compared directly.
embeddings_1 = text2vec.encode(sentences_1, normalize_embeddings=True)
embeddings_2 = text2vec.encode(sentences_2, normalize_embeddings=True)

# Pairwise scores: rows follow sentences_1, columns follow sentences_2
# (assuming encode returns one embedding row per input sentence).
similarity = embeddings_1 @ embeddings_2.T
print(similarity)