Skip to content

Commit

Permalink
Merge pull request #2 from aurelio-labs/feature/reranking-async
Browse files Browse the repository at this point in the history
Feature: reranking async
  • Loading branch information
italianconcerto authored Dec 2, 2024
2 parents 11a4289 + 6f838aa commit f2a9db6
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 6 deletions.
26 changes: 24 additions & 2 deletions examples/example_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from pinecone_async import PineconeClient

async def main():
async def list_indexes_example():
api_key = os.getenv("PINECONE_API_KEY")
if not api_key:
raise ValueError("PINECONE_API_KEY environment variable is not set")
Expand All @@ -11,6 +11,28 @@ async def main():
indexes = await client.list_indexes()
print(f"Indexes: {indexes}")

async def rerank_example():
    """Rerank four sample documents against a query and print the response.

    Reads the API key from the ``PINECONE_API_KEY`` environment variable,
    calls ``PineconeClient.rerank`` with the ``bge-reranker-v2-m3`` model and
    prints whatever the client returns.

    Raises:
        ValueError: If ``PINECONE_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("PINECONE_API_KEY")
    if not api_key:
        raise ValueError("PINECONE_API_KEY environment variable is not set")
    client = PineconeClient(api_key=api_key)

    try:
        result = await client.rerank(
            model="bge-reranker-v2-m3",
            query="The tech company Apple is known for its innovative products like the iPhone.",
            documents=[
                {"id": "vec1", "text": "Apple is a popular fruit known for its sweetness and crisp texture."},
                {"id": "vec2", "text": "Many people enjoy eating apples as a healthy snack."},
                {"id": "vec3", "text": "Apple Inc. has revolutionized the tech industry with its sleek designs and user-friendly interfaces."},
                {"id": "vec4", "text": "An apple a day keeps the doctor away, as the saying goes."},
            ],
            top_n=4,
            return_documents=True,
            parameters={
                "truncate": "END"
            },
        )
        print(result)
    finally:
        # The client exposes close() (its __aexit__ awaits it); call it here so
        # the client is released even when the rerank request raises.
        await client.close()

# Script entry point: run an example coroutine when executed directly.
# NOTE(review): this text is a rendered diff, so the pre-change call (main,
# which this commit renames to list_indexes_example) and the post-change call
# (rerank_example) both appear; presumably only the latter exists in the
# committed file -- confirm against the repository.
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(rerank_example())
1 change: 0 additions & 1 deletion src/pinecone_async/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# src/pinecone_async/__init__.py
from .client import PineconeClient
from .index import PineconeIndex
from .schema import (
Expand Down
52 changes: 50 additions & 2 deletions src/pinecone_async/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
from typing import Any, Dict, List, Literal
from typing import Any, Dict, List, Literal, Optional

import httpx

from pinecone_async.exceptions import IndexNotFoundError
from pinecone_async.schema import IndexResponse, PineconePod, Serverless
from pinecone_async.schema import Document, IndexResponse, PineconePod, RerankParameters, RerankRequest, RerankResponse, Serverless


class PineconeClient:
Expand Down Expand Up @@ -99,3 +99,51 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()

async def rerank(
    self,
    model: str,
    query: str,
    documents: list[dict[str, str]],
    top_n: Optional[int] = None,
    return_documents: Optional[bool] = True,
    parameters: Optional[dict] = None,
    rank_fields: Optional[list[str]] = None
) -> RerankResponse:
    """
    Score and order documents by relevance to a query via the rerank endpoint.

    Args:
        model: Name of the reranking model (e.g., "bge-reranker-v2-m3")
        query: Text the documents are ranked against
        documents: Documents to rerank; each dict is validated as a ``Document``
        top_n: How many of the best results to return
        return_documents: Whether the response should echo the documents
        parameters: Extra options, validated as ``RerankParameters``
            (e.g. ``{"truncate": "END"}``)
        rank_fields: Optional list of custom fields to rank on

    Returns:
        The parsed ``RerankResponse`` built from the response JSON.

    Raises:
        Exception: If the endpoint answers with any status other than 200.
    """
    request_headers = {
        "Api-Key": self.headers["Api-Key"],
        "Content-Type": "application/json",
        "Accept": "application/json",
        "X-Pinecone-API-Version": "2024-10"
    }

    # Validate every incoming dict up front.
    validated_docs = []
    for raw_doc in documents:
        validated_docs.append(Document(**raw_doc))

    # A missing/empty ``parameters`` falls back to the model's defaults.
    rerank_params = RerankParameters(**parameters) if parameters else RerankParameters()

    rerank_request = RerankRequest(
        model=model,
        query=query,
        documents=validated_docs,
        top_n=top_n,
        return_documents=return_documents,
        parameters=rerank_params,
        rank_fields=rank_fields
    )
    # None-valued fields are left out of the JSON body entirely.
    payload = rerank_request.model_dump(exclude_none=True)

    # A dedicated short-lived HTTP client is used for this single request.
    async with httpx.AsyncClient(headers=request_headers) as http:
        reply = await http.post("https://api.pinecone.io/rerank", json=payload)

    if reply.status_code != 200:
        raise Exception(f"Failed to rerank: {reply.status_code} : {reply.text}")
    return RerankResponse(**reply.json())
33 changes: 32 additions & 1 deletion src/pinecone_async/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,35 @@ class ListResponse(BaseModel):
class FetchResponse(BaseModel):
"""Response from fetch operation"""
vectors: Dict[str, PineconeVector]
namespace: Optional[str] = None
namespace: Optional[str] = None



class Document(BaseModel):
    """A single document sent to, or returned by, the rerank endpoint.

    Besides ``id`` and ``text``, any additional keys are kept rather than
    silently dropped, so custom fields referenced through ``rank_fields``
    survive validation and are included in ``model_dump()``.
    """

    # Plain-dict model config (pydantic v2) -- avoids needing a new import.
    model_config = {"extra": "allow"}

    id: str
    text: str | None = None
    # Original placeholder for custom field support; kept so existing callers
    # that pass or read ``my_field`` keep working.
    my_field: str | None = None

class RerankParameters(BaseModel):
    """Optional settings passed under ``parameters`` in a rerank request."""

    # Truncation mode: "START", "END" or "NONE"; defaults to "END".
    truncate: Literal["START", "END", "NONE"] | None = "END"

class RerankRequest(BaseModel):
    """Body of a rerank request; fields left as ``None`` are optional."""

    model: str
    query: str
    documents: list[Document]
    top_n: int | None = None
    return_documents: bool | None = True
    parameters: RerankParameters | None = None
    rank_fields: list[str] | None = None

class RerankResult(BaseModel):
    """One scored entry in a rerank response."""

    # presumably the position of the document in the request list -- confirm
    index: int
    # Absent (None) when the response does not include the document itself.
    document: Document | None = None
    score: float

class RerankUsage(BaseModel):
    """Usage information attached to a rerank response."""
    # presumably the number of rerank units consumed by the call -- confirm
    rerank_units: int

class RerankResponse(BaseModel):
    """Parsed rerank response: the scored results plus usage information."""
    data: list[RerankResult]  # one entry per returned document
    usage: RerankUsage

0 comments on commit f2a9db6

Please sign in to comment.