Skip to content

Commit

Permalink
Add document_id with Transformers Reader (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
karthik19967829 authored Apr 8, 2020
1 parent 5932aa0 commit da3277c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions haystack/reader/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,22 @@ def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None):

# get top-answers for each candidate passage
answers = []
for p in paragraphs:
query = {"context": p, "question": question}
for para, meta in zip(paragraphs, meta_data_paragraphs):
query = {"context": para, "question": question}
predictions = self.model(query, topk=self.n_best_per_passage)
# assemble and format all answers
for pred in predictions:
if pred["answer"]:
context_start = max(0, pred["start"] - self.context_size)
context_end = min(len(p), pred["end"] + self.context_size)
context_end = min(len(para), pred["end"] + self.context_size)
answers.append({
"answer": pred["answer"],
"context": p[context_start:context_end],
"context": para[context_start:context_end],
"offset_answer_start": pred["start"],
"offset_answer_end": pred["end"],
"probability": pred["score"],
"probability": pred["score"],
"score": None,
"document_id": None
"document_id": meta["document_id"]
})

# sort answers by their `probability` and select top-k
Expand Down
6 changes: 3 additions & 3 deletions haystack/retriever/tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
logger = logging.getLogger(__name__)

# TODO make Paragraph generic for configurable units of text eg, pages, paragraphs, or split by a char_limit
Paragraph = namedtuple("Paragraph", ["paragraph_id", "document_id", "text"])
Paragraph = namedtuple("Paragraph", ["paragraph_id", "document_id", "text","document_name"])


class TfidfRetriever(BaseRetriever):
Expand Down Expand Up @@ -48,7 +48,7 @@ def _get_all_paragraphs(self):
if not p.strip(): # skip empty paragraphs
continue
paragraphs.append(
Paragraph(document_id=doc["id"], paragraph_id=p_id, text=(p,))
Paragraph(document_id=doc["id"],document_name=doc["name"],paragraph_id=p_id, text=(p,))
)
p_id += 1
logger.info(f"Found {len(paragraphs)} candidate paragraphs from {len(documents)} docs in DB")
Expand Down Expand Up @@ -81,7 +81,7 @@ def retrieve(self, query, candidate_doc_ids=None, top_k=10, verbose=True):

# get actual content for the top candidates
paragraphs = list(df_sliced.text.values)
meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"]}
meta_data = [{"document_id": row["document_id"], "paragraph_id": row["paragraph_id"],"document_name":row["document_name"]}
for idx, row in df_sliced.iterrows()]

return paragraphs, meta_data
Expand Down

0 comments on commit da3277c

Please sign in to comment.