-
Notifications
You must be signed in to change notification settings - Fork 16.3k
/
Copy pathpinecone_hybrid_search.py
185 lines (162 loc) Β· 5.87 KB
/
pinecone_hybrid_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""Taken from: https://docs.pinecone.io/docs/hybrid-search"""
import hashlib
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import pre_init
from pydantic import ConfigDict
def hash_text(text: str) -> str:
    """Return the SHA-256 digest of a text as a hexadecimal string.

    Args:
        text: Text to hash. It is encoded as UTF-8 before hashing.

    Returns:
        The hex-encoded SHA-256 digest of the text.
    """
    digest = hashlib.sha256()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()
def create_index(
    contexts: List[str],
    index: Any,
    embeddings: Embeddings,
    sparse_encoder: Any,
    ids: Optional[List[str]] = None,
    metadatas: Optional[List[dict]] = None,
    namespace: Optional[str] = None,
    text_key: str = "context",
) -> None:
    """Embed a list of contexts and upsert them into an index in batches.

    It modifies the index argument in-place!

    Args:
        contexts: List of contexts to embed.
        index: Index to use. Its ``upsert(vectors, namespace=...)`` method is
            called once per batch.
        embeddings: Embeddings model to use for the dense vectors.
        sparse_encoder: Sparse encoder to use. Its ``encode_documents``
            method is called once per batch.
        ids: List of ids to use for the documents. When omitted, each id is
            the SHA-256 hash of the corresponding context.
        metadatas: List of metadata to use for the documents. When omitted or
            empty, every document gets empty metadata.
        namespace: Namespace value for index partition.
        text_key: Metadata key under which each context's text is stored.

    Raises:
        ValueError: If ``ids`` or a non-empty ``metadatas`` does not have the
            same length as ``contexts``.
    """
    batch_size = 32
    total = len(contexts)

    if ids is None:
        # create unique ids using hash of the text
        ids = [hash_text(context) for context in contexts]
    elif len(ids) != total:
        # Without this check, slicing + zip would silently drop documents.
        raise ValueError(
            f"Number of ids ({len(ids)}) does not match "
            f"number of contexts ({total})."
        )

    if metadatas and len(metadatas) != total:
        raise ValueError(
            f"Number of metadatas ({len(metadatas)}) does not match "
            f"number of contexts ({total})."
        )

    _iterator = range(0, total, batch_size)
    try:
        from tqdm.auto import tqdm

        _iterator = tqdm(_iterator)
    except ImportError:
        # tqdm is optional; fall back to a plain range without progress bar.
        pass

    for i in _iterator:
        # find end of batch
        i_end = min(i + batch_size, total)
        # extract batch
        context_batch = contexts[i:i_end]
        batch_ids = ids[i:i_end]
        metadata_batch = (
            metadatas[i:i_end] if metadatas else [{} for _ in context_batch]
        )
        # add context passages as metadata
        # NOTE: a user metadata entry named ``text_key`` overrides the context
        # text, because later keys win in a dict display.
        meta = [
            {text_key: context, **metadata}
            for context, metadata in zip(context_batch, metadata_batch)
        ]

        # create dense vectors
        dense_embeds = embeddings.embed_documents(context_batch)
        # create sparse vectors
        sparse_embeds = sparse_encoder.encode_documents(context_batch)
        # coerce sparse values to plain Python floats before upserting
        for s in sparse_embeds:
            s["values"] = [float(s1) for s1 in s["values"]]

        # build one upsert record per document in the batch
        vectors = [
            {
                "id": doc_id,
                "sparse_values": sparse,
                "values": dense,
                "metadata": metadata,
            }
            for doc_id, sparse, dense, metadata in zip(
                batch_ids, sparse_embeds, dense_embeds, meta
            )
        ]

        # upload the documents to the new hybrid index
        index.upsert(vectors, namespace=namespace)
class PineconeHybridSearchRetriever(BaseRetriever):
    """`Pinecone Hybrid Search` retriever.

    Queries an index with both a dense embedding and a sparse encoding of the
    query text, and converts the matches into ``Document`` objects.
    """

    embeddings: Embeddings
    """Embeddings model to use."""
    sparse_encoder: Any = None
    """Sparse encoder to use."""
    index: Any = None
    """Pinecone index to use."""
    top_k: int = 4
    """Number of documents to return."""
    alpha: float = 0.5
    """Alpha value for hybrid search."""
    namespace: Optional[str] = None
    """Namespace value for index partition."""
    text_key: str = "context"
    """Metadata key under which each document's text is stored."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    def add_texts(
        self,
        texts: List[str],
        ids: Optional[List[str]] = None,
        metadatas: Optional[List[dict]] = None,
        namespace: Optional[str] = None,
    ) -> None:
        """Embed texts and upsert them into this retriever's index.

        Args:
            texts: Texts to add.
            ids: Optional ids for the texts.
            metadatas: Optional metadata dicts, one per text.
            namespace: Namespace to write to. When ``None``, the retriever's
                own ``namespace`` is used so that added texts land in the
                same partition that queries read from.
        """
        target_namespace = namespace if namespace is not None else self.namespace
        create_index(
            texts,
            self.index,
            self.embeddings,
            self.sparse_encoder,
            ids=ids,
            metadatas=metadatas,
            namespace=target_namespace,
            text_key=self.text_key,
        )

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the ``pinecone_text`` python package is importable.

        Args:
            values: Field values being validated; returned unchanged.

        Returns:
            The same ``values`` dict.

        Raises:
            ImportError: If ``pinecone_text`` cannot be imported.
        """
        try:
            from pinecone_text.hybrid import hybrid_convex_scale  # noqa:F401
            from pinecone_text.sparse.base_sparse_encoder import (
                BaseSparseEncoder,  # noqa:F401
            )
        except ImportError as err:
            # Chain the original error so the real import failure is visible.
            raise ImportError(
                "Could not import pinecone_text python package. "
                "Please install it with `pip install pinecone_text`."
            ) from err
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Query the index with dense and sparse vectors built from ``query``.

        Args:
            query: Query text.
            run_manager: Callback manager for the retriever run (unused here).
            **kwargs: Extra keyword arguments forwarded to ``index.query``.

        Returns:
            One ``Document`` per match. The page content is the value stored
            under ``text_key`` in the match metadata; the remaining metadata
            is attached to the document, with the match score added under
            ``"score"`` when the metadata does not already have that key.
        """
        from pinecone_text.hybrid import hybrid_convex_scale

        sparse_vec = self.sparse_encoder.encode_queries(query)
        # convert the question into a dense vector
        dense_vec = self.embeddings.embed_query(query)
        # scale alpha with hybrid_scale
        dense_vec, sparse_vec = hybrid_convex_scale(dense_vec, sparse_vec, self.alpha)
        # coerce sparse values to plain Python floats before querying
        sparse_vec["values"] = [float(s1) for s1 in sparse_vec["values"]]
        # query pinecone with the query parameters
        result = self.index.query(
            vector=dense_vec,
            sparse_vector=sparse_vec,
            top_k=self.top_k,
            include_metadata=True,
            namespace=self.namespace,
            **kwargs,
        )
        final_result = []
        for res in result["matches"]:
            # The stored text is removed from the metadata and becomes the
            # page content; a match without ``text_key`` raises here.
            context = res["metadata"].pop(self.text_key)
            metadata = res["metadata"]
            if "score" not in metadata and "score" in res:
                metadata["score"] = res["score"]
            final_result.append(Document(page_content=context, metadata=metadata))
        # return search results as a list of Documents
        return final_result