diff --git a/haystack/reader/transformers.py b/haystack/reader/transformers.py
index 5dd725e224..6706ac4cb9 100644
--- a/haystack/reader/transformers.py
+++ b/haystack/reader/transformers.py
@@ -76,22 +76,22 @@ def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None):
 
         # get top-answers for each candidate passage
         answers = []
-        for p in paragraphs:
-            query = {"context": p, "question": question}
+        for para, meta in zip(paragraphs, meta_data_paragraphs):
+            query = {"context": para, "question": question}
             predictions = self.model(query, topk=self.n_best_per_passage)
             # assemble and format all answers
             for pred in predictions:
                 if pred["answer"]:
                     context_start = max(0, pred["start"] - self.context_size)
-                    context_end = min(len(p), pred["end"] + self.context_size)
+                    context_end = min(len(para), pred["end"] + self.context_size)
                     answers.append({
                         "answer": pred["answer"],
-                        "context": p[context_start:context_end],
+                        "context": para[context_start:context_end],
                         "offset_answer_start": pred["start"],
                         "offset_answer_end": pred["end"],
-                        "probability": pred["score"],
+                        "probability": pred["score"],
                         "score": None,
-                        "document_id": None
+                        "document_id": meta["document_id"]
                     })
 
         # sort answers by their `probability` and select top-k
diff --git a/haystack/retriever/tfidf.py b/haystack/retriever/tfidf.py
index 27a9d42790..952578a16f 100644
--- a/haystack/retriever/tfidf.py
+++ b/haystack/retriever/tfidf.py
@@ -9,7 +9,7 @@
 logger = logging.getLogger(__name__)
 
 # TODO make Paragraph generic for configurable units of text eg, pages, paragraphs, or split by a char_limit
-Paragraph = namedtuple("Paragraph", ["paragraph_id", "document_id", "text"])
+Paragraph = namedtuple("Paragraph", ["paragraph_id", "document_id", "text","document_name"])
 
 
 class TfidfRetriever(BaseRetriever):
@@ -48,7 +48,7 @@ def _get_all_paragraphs(self):
             if not p.strip():  # skip empty paragraphs
                 continue
             paragraphs.append(
-                Paragraph(document_id=doc["id"], paragraph_id=p_id, text=(p,))
+                Paragraph(document_id=doc["id"],document_name=doc["name"],paragraph_id=p_id, text=(p,))
             )
             p_id += 1
         logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
@@ -81,7 +81,7 @@ def retrieve(self, query, candidate_doc_ids=None, top_k=10, verbose=True):
 
         # get actual content for the top candidates
         paragraphs = list(df_sliced.text.values)
-        meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"]}
+        meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"],"document_name":row["document_name"]}
                      for idx, row in df_sliced.iterrows()]
 
         return paragraphs, meta_data