Skip to content

Commit

Permalink
Edit /v1/retrieve endpoint return formats (includes score, filepath and file page)
Browse files Browse the repository at this point in the history
  • Loading branch information
vkehfdl1 committed Dec 13, 2024
1 parent 94f51d7 commit e5888c4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
34 changes: 14 additions & 20 deletions autorag/deploy/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class QueryRequest(BaseModel):
class RetrievedPassage(BaseModel):
content: str
doc_id: str
score: float
filepath: Optional[str] = None
file_page: Optional[int] = None
start_idx: Optional[int] = None
Expand All @@ -41,14 +42,8 @@ class RunResponse(BaseModel):
retrieved_passage: List[RetrievedPassage]


class Passage(BaseModel):
doc_id: str
content: str
score: float


class RetrievalResponse(BaseModel):
passages: List[Passage]
passages: List[RetrievedPassage]


class StreamResponse(BaseModel):
Expand Down Expand Up @@ -153,18 +148,9 @@ async def run_retrieve_only():
previous_result = pd.concat([drop_previous_result, new_result], axis=1)

# Simulate processing the query
retrieved_contents = previous_result["retrieved_contents"].tolist()[0]
retrieved_ids = previous_result["retrieved_ids"].tolist()[0]
retrieve_scores = previous_result["retrieve_scores"].tolist()[0]

retrieval_response = RetrievalResponse(
passages=[
Passage(doc_id=doc_id, content=content, score=score)
for doc_id, content, score in zip(
retrieved_ids, retrieved_contents, retrieve_scores
)
]
)
retrieved_passages = self.extract_retrieve_passage(previous_result)

retrieval_response = RetrievalResponse(passages=retrieved_passages)
return jsonify(retrieval_response.model_dump()), 200

@self.app.route("/v1/stream", methods=["POST"])
Expand Down Expand Up @@ -264,6 +250,7 @@ def run_api_server(
def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
retrieved_ids: List[str] = df["retrieved_ids"].tolist()[0]
contents = fetch_contents(self.corpus_df, [retrieved_ids])[0]
scores = df["retrieve_scores"].tolist()[0]
if "path" in self.corpus_df.columns:
paths = fetch_contents(self.corpus_df, [retrieved_ids], column_name="path")[
0
Expand All @@ -282,16 +269,23 @@ def extract_retrieve_passage(self, df: pd.DataFrame) -> List[RetrievedPassage]:
start_end_indices = to_list(start_end_indices)
return list(
map(
lambda content, doc_id, path, metadata, start_end_idx: RetrievedPassage(
lambda content,
doc_id,
score,
path,
metadata,
start_end_idx: RetrievedPassage(
content=content,
doc_id=doc_id,
score=score,
filepath=path,
file_page=metadata.get("page", None),
start_idx=start_end_idx[0] if start_end_idx else None,
end_idx=start_end_idx[1] if start_end_idx else None,
),
contents,
retrieved_ids,
scores,
paths,
metadatas,
start_end_indices,
Expand Down
4 changes: 4 additions & 0 deletions tests/autorag/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ async def post_to_server_retrieve():
assert isinstance(passages[0]["doc_id"], str)
assert isinstance(passages[0]["content"], str)
assert isinstance(passages[0]["score"], float)
assert passages[0]["filepath"]
assert passages[0]["file_page"]
assert passages[0]["start_idx"]
assert passages[0]["end_idx"]


def test_runner_api_server2(evaluator_data_gen_by_autorag):
Expand Down

0 comments on commit e5888c4

Please sign in to comment.