Skip to content

Commit

Permalink
fix: Ensure eval mode for TableReader model for predictions (#3743)
Browse files Browse the repository at this point in the history
* Adding model.eval() calls to prediction functions in table reader

* Add unit test to check if model is set in train mode that inference time prediction still works.
  • Loading branch information
sjrl authored Jan 9, 2023
1 parent 659020f commit 5b0b338
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
4 changes: 4 additions & 0 deletions haystack/nodes/reader/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def _predict_tapas(self, inputs: BatchEncoding, document: Document) -> Answer:
string_table = orig_table.astype(str)

# Forward query and table through model and convert logits to predictions
self.model.eval()
with torch.inference_mode():
outputs = self.model(**inputs)

Expand Down Expand Up @@ -424,6 +425,7 @@ def _predict_tapas_scored(self, inputs: BatchEncoding, document: Document) -> Tu
string_table = orig_table.astype(str)

# Forward pass through model
self.model.eval()
with torch.inference_mode():
outputs = self.model.tapas(**inputs)
table_score = self.model.classifier(outputs.pooler_output)
Expand Down Expand Up @@ -719,6 +721,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
padding=True,
)
row_inputs.to(self.devices[0])
self.row_model.eval()
with torch.inference_mode():
row_outputs = self.row_model(**row_inputs)
row_logits = row_outputs[0].detach().cpu().numpy()[:, 1]
Expand All @@ -733,6 +736,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
padding=True,
)
column_inputs.to(self.devices[0])
self.column_model.eval()
with torch.inference_mode():
column_outputs = self.column_model(**column_inputs)
column_logits = column_outputs[0].detach().cpu().numpy()[:, 1]
Expand Down
38 changes: 38 additions & 0 deletions test/nodes/test_table_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

import pandas as pd
import torch
import pytest

from haystack.schema import Document, Answer
Expand Down Expand Up @@ -41,6 +42,7 @@ def table3():
@pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True)
def test_table_reader(table_reader_and_param, table1, table2):
table_reader, param = table_reader_and_param

query = "When was Di Caprio born?"
prediction = table_reader.predict(
query=query,
Expand Down Expand Up @@ -72,6 +74,42 @@ def test_table_reader(table_reader_and_param, table1, table2):
assert prediction["answers"][1].offsets_in_context[0].end == reference2[param]["end"]


@pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True)
def test_table_reader_train_mode(table_reader_and_param, table1, table2):
table_reader, param = table_reader_and_param

# Set to deterministic seed
old_seed = torch.seed()
torch.manual_seed(0)

# Ensure that if model is put in train mode that predictions are not effected
if param != "rci":
table_reader.table_encoder.model.train()
elif param == "rci":
table_reader.row_model.train()
table_reader.column_model.train()

query = "When was Di Caprio born?"
prediction = table_reader.predict(
query=query,
documents=[Document(content=table1, content_type="table"), Document(content=table2, content_type="table")],
)

# Check the second answer in the list
reference2 = {
"tapas_small": {"answer": "5 april 1980", "start": 7, "end": 8, "score": 0.86314},
"rci": {"answer": "47", "start": 5, "end": 6, "score": -6.836},
"tapas_scored": {"answer": "brad pitt", "start": 0, "end": 1, "score": 0.49078},
}
assert prediction["answers"][1].score == pytest.approx(reference2[param]["score"], rel=1e-3)
assert prediction["answers"][1].answer == reference2[param]["answer"]
assert prediction["answers"][1].offsets_in_context[0].start == reference2[param]["start"]
assert prediction["answers"][1].offsets_in_context[0].end == reference2[param]["end"]

# Set back to old_seed
torch.manual_seed(old_seed)


@pytest.mark.parametrize("table_reader_and_param", ["tapas_small", "rci", "tapas_scored"], indirect=True)
def test_table_reader_batch_single_query_single_doc_list(table_reader_and_param, table1, table2):
table_reader, param = table_reader_and_param
Expand Down

0 comments on commit 5b0b338

Please sign in to comment.