Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Pinecone hybrid search retriever by adding metadata support #5098

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"metadata": {},
"outputs": [],
"source": [
"#!pip install pinecone-client"
"#!pip install pinecone-client pinecone-text"
]
},
{
Expand Down
31 changes: 27 additions & 4 deletions langchain/retrievers/pinecone_hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def create_index(
embeddings: Embeddings,
sparse_encoder: Any,
ids: Optional[List[str]] = None,
metadatas: Optional[List[dict]] = None,
) -> None:
batch_size = 32
_iterator = range(0, len(contexts), batch_size)
Expand All @@ -38,8 +39,15 @@ def create_index(
# extract batch
context_batch = contexts[i:i_end]
batch_ids = ids[i:i_end]
metadata_batch = (
metadatas[i:i_end] if metadatas else [{} for _ in context_batch]
)
# add context passages as metadata
meta = [{"context": context} for context in context_batch]
meta = [
{"context": context, **metadata}
for context, metadata in zip(context_batch, metadata_batch)
]

# create dense vectors
dense_embeds = embeddings.embed_documents(context_batch)
# create sparse vectors
Expand Down Expand Up @@ -78,8 +86,20 @@ class Config:
extra = Extra.forbid
arbitrary_types_allowed = True

def add_texts(self, texts: List[str], ids: Optional[List[str]] = None) -> None:
create_index(texts, self.index, self.embeddings, self.sparse_encoder, ids=ids)
def add_texts(
self,
texts: List[str],
ids: Optional[List[str]] = None,
metadatas: Optional[List[dict]] = None,
) -> None:
create_index(
texts,
self.index,
self.embeddings,
self.sparse_encoder,
ids=ids,
metadatas=metadatas,
)

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
Expand Down Expand Up @@ -114,7 +134,10 @@ def get_relevant_documents(self, query: str) -> List[Document]:
)
final_result = []
for res in result["matches"]:
final_result.append(Document(page_content=res["metadata"]["context"]))
context = res["metadata"].pop("context")
final_result.append(
Document(page_content=context, metadata=res["metadata"])
)
# return search results as a list of Document objects
return final_result

Expand Down