-
Notifications
You must be signed in to change notification settings - Fork 0
/
Search.py
46 lines (36 loc) · 1.37 KB
/
Search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import json
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
import time
class Search:
    """Semantic nearest-neighbour lookup over an in-memory document set.

    On construction every document in ``data`` is embedded with a
    SentenceTransformer model and uploaded to an in-memory Qdrant
    collection. The query ``text`` is then searched once and the outcome is
    stored in ``self.result``.

    Attributes:
        encoder: The SentenceTransformer used for documents and queries.
        qdrant: The in-memory QdrantClient holding the collection.
        result: ``[score, number]`` of the best hit for the constructor's
            ``text``, or ``None`` when the collection produced no hit.
    """

    # One source of truth for names shared by create/upload/search.
    COLLECTION_NAME = "Research_papers"
    MODEL_NAME = "all-MiniLM-L6-v2"

    def __init__(self, text, data):
        """Index ``data`` and immediately run a search for ``text``.

        Args:
            text: Query string to embed and search for.
            data: Iterable of dict payloads. Each dict must provide a
                ``"text"`` key (the string that gets embedded); the best
                hit's ``"number"`` key is what ``search`` reports.
        """
        self.encoder = SentenceTransformer(self.MODEL_NAME)
        self.qdrant = QdrantClient(":memory:")
        self.qdrant.create_collection(
            collection_name=self.COLLECTION_NAME,
            vectors_config=models.VectorParams(
                size=self.encoder.get_sentence_embedding_dimension(),
                distance=models.Distance.COSINE,
                on_disk=True,
            ),
        )

        # Materialise once so generators work and the texts can be embedded
        # in a single batched call instead of one model call per document.
        docs = list(data)
        vectors = (
            self.encoder.encode([doc["text"] for doc in docs]) if docs else []
        )
        self.qdrant.upload_points(
            collection_name=self.COLLECTION_NAME,
            points=[
                models.PointStruct(id=idx, vector=vector.tolist(), payload=doc)
                for idx, (doc, vector) in enumerate(zip(docs, vectors))
            ],
        )
        self.result = self.search(text)

    def search(self, text):
        """Return the best match for ``text`` and print the query latency.

        Args:
            text: Query string to embed and look up in the collection.

        Returns:
            ``[score, number]`` for the top hit, where ``number`` is the
            hit payload's ``"number"`` entry, or ``None`` when the search
            returns no hits (for example, an empty collection).

        Raises:
            KeyError: If the top hit's payload has no ``"number"`` key.
        """
        start_time = time.perf_counter()
        # NOTE(review): newer qdrant-client releases favour ``query_points``
        # over ``search`` - confirm against the pinned client version.
        hits = self.qdrant.search(
            collection_name=self.COLLECTION_NAME,
            query_vector=self.encoder.encode(text).tolist(),
            limit=1,
        )
        elapsed_time = time.perf_counter() - start_time
        print(elapsed_time)

        if not hits:
            # Previously this fell through to hits[0] and raised IndexError.
            return None
        best = hits[0]
        return [best.score, best.payload['number']]