Skip to content

Commit

Permalink
simplify unit test
Browse files: browse the repository at this point in the history
  • Loading branch information
julian-risch committed Apr 30, 2024
1 parent 3100af0 commit 0b78f7d
Showing 1 changed file with 3 additions and 44 deletions.
47 changes: 3 additions & 44 deletions test/components/readers/test_extractive.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -91,7 +91,7 @@ def forward(self, input_ids, attention_mask, *args, **kwargs):
[
Document(content="Angela Merkel was the chancellor of Germany."),
Document(content="Olaf Scholz is the chancellor of Germany"),
Document(content="Jerry is the head of the department."),
Document(content="Jerry is the head of the department.", meta={"page_number": 3}),
]
] * 2

Expand Down Expand Up @@ -386,49 +386,8 @@ def test_nest_answers(mock_reader: ExtractiveReader):
assert answer.query == query
assert answer.document == doc
assert answer.score == pytest.approx(score)
no_answer = answers[-1]
assert no_answer.query == query
assert no_answer.document is None
assert no_answer.score == pytest.approx(expected_no_answer)


def test_nest_answers_with_page_numbers(mock_reader: ExtractiveReader):
example_documents = [
Document(content="Angela Merkel was the chancellor of Germany.", meta={"page_number": 1}),
Document(content="Olaf Scholz is the chancellor of Germany", meta={"page_number": 2}),
Document(content="Jerry is the head of the department.", meta={"page_number": 3}),
]
start = list(range(5))
end = [i + 5 for i in start]
start = [start] * 6 # type: ignore
end = [end] * 6 # type: ignore
probabilities = torch.arange(5).unsqueeze(0) / 5 + torch.arange(6).unsqueeze(-1) / 25
query_ids = [0] * 3 + [1] * 3
document_ids = list(range(3)) * 2
nested_answers = mock_reader._nest_answers( # type: ignore
start=start,
end=end,
probabilities=probabilities,
flattened_documents=example_documents,
queries=example_queries,
answers_per_seq=5,
top_k=3,
score_threshold=None,
query_ids=query_ids,
document_ids=document_ids,
no_answer=True,
overlap_threshold=None,
)
expected_no_answers = [0.2 * 0.16 * 0.12, 0]
for query, answers, expected_no_answer, probabilities in zip(
example_queries, nested_answers, expected_no_answers, [probabilities[:3, -1], probabilities[3:, -1]]
):
assert len(answers) == 4
for doc, answer, score in zip(example_documents, reversed(answers[:3]), probabilities):
assert answer.query == query
assert answer.document == doc
assert answer.score == pytest.approx(score)
assert answer.meta["answer_page_number"] == doc.meta["page_number"]
if "page_number" in doc.meta:
assert answer.meta["answer_page_number"] == doc.meta["page_number"]
no_answer = answers[-1]
assert no_answer.query == query
assert no_answer.document is None
Expand Down

0 comments on commit 0b78f7d

Please sign in to comment.