From 81c176ea1b00336b5490ac636f49073a71d464c6 Mon Sep 17 00:00:00 2001
From: lalitpagaria <pagaria.lalit@gmail.com>
Date: Sun, 15 Nov 2020 18:03:57 +0100
Subject: [PATCH 1/3] Removing device information from generator model
 arguments, as device placement is handled by the model itself.

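With this change the helpers return tensors exactly as the tokenizer produced
them, with no explicit .to(self.device) call. A minimal sketch of the intended
call pattern, assuming generation ultimately goes through a Hugging Face
model.generate() call (not shown in this diff) and that the model runs on the
default CPU device; the generator instance and the texts/question values are
placeholders:

    # Tensors stay on whatever device the tokenizer created them on (CPU);
    # the helpers no longer move them explicitly.
    input_ids, attention_mask = generator._get_contextualized_inputs(
        texts=texts, question=question
    )
    outputs = generator.model.generate(
        input_ids=input_ids, attention_mask=attention_mask
    )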
---
 haystack/generator/transformers.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/haystack/generator/transformers.py b/haystack/generator/transformers.py
index a9d055b167..983ba465f2 100644
--- a/haystack/generator/transformers.py
+++ b/haystack/generator/transformers.py
@@ -153,8 +153,7 @@ def _get_contextualized_inputs(self, texts: List[str], question: str, titles: Op
             truncation=True,
         )
 
-        return contextualized_inputs["input_ids"].to(self.device), \
-               contextualized_inputs["attention_mask"].to(self.device)
+        return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]
 
     def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[Optional[numpy.ndarray]]) -> torch.Tensor:
 
@@ -170,7 +169,7 @@ def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[Opt
         embeddings_in_tensor = torch.cat(
             [torch.from_numpy(embedding).unsqueeze(0) for embedding in embeddings],
             dim=0
-        ).to(self.device)
+        )
 
         return embeddings_in_tensor
 

From 0c3f80c6a173d3d69578956a628e2451d5d43b25 Mon Sep 17 00:00:00 2001
From: lalitpagaria <pagaria.lalit@gmail.com>
Date: Sun, 15 Nov 2020 18:18:33 +0100
Subject: [PATCH 2/3] num_return_sequences should not be greater than
 num_beams

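Hugging Face's beam search requires num_return_sequences <= num_beams, so the
requested top_k_answers is clamped both in __init__ and in predict(). A
minimal sketch of the guard, factored into a helper for illustration (the
name clamp_top_k is hypothetical and not part of this patch; logger is the
module-level logger):

    def clamp_top_k(top_k: int, num_beams: int) -> int:
        # generate(num_beams=..., num_return_sequences=...) rejects
        # num_return_sequences > num_beams, so cap the value and warn.
        if top_k > num_beams:
            logger.warning(f'top_k_answers value should not be greater than '
                           f'num_beams, hence setting it to {num_beams}')
            return num_beams
        return top_k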
---
 haystack/generator/transformers.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/haystack/generator/transformers.py b/haystack/generator/transformers.py
index 983ba465f2..fa8ff2a9bc 100644
--- a/haystack/generator/transformers.py
+++ b/haystack/generator/transformers.py
@@ -90,7 +90,6 @@ def __init__(
         """
 
         self.model_name_or_path = model_name_or_path
-        self.top_k_answers = top_k_answers
         self.max_length = max_length
         self.min_length = min_length
         self.generator_type = generator_type
@@ -99,6 +98,12 @@ def __init__(
         self.prefix = prefix
         self.retriever = retriever
 
+        if top_k_answers > self.num_beams:
+            top_k_answers = self.num_beams
+            logger.warning(f'top_k_answers value should not be greater than num_beams, hence setting it to {top_k_answers}')
+
+        self.top_k_answers = top_k_answers
+
         if use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda")
         else:
@@ -201,6 +206,11 @@ def predict(self, question: str, documents: List[Document], top_k: Optional[int]
 
         top_k_answers = top_k if top_k is not None else self.top_k_answers
 
+        if top_k_answers > self.num_beams:
+            top_k_answers = self.num_beams
+            logger.warning(f'top_k_answers value should not be greater than num_beams, '
+                           f'hence setting it to {top_k_answers}')
+
         # Flatten the documents so easy to reference
         flat_docs_dict: Dict[str, Any] = {}
         for document in documents:

From bf18ba577206311964bc3e5ad4e856fddcefb303 Mon Sep 17 00:00:00 2001
From: lalitpagaria <pagaria.lalit@gmail.com>
Date: Mon, 16 Nov 2020 09:16:28 +0100
Subject: [PATCH 3/3] Raise error when the generator is used with GPU, as it
 is currently not supported

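Failing fast in __init__ surfaces the limitation at construction time instead
of as a confusing device mismatch later. A minimal usage sketch (all other
RAGenerator arguments are elided; the error is only raised on machines where
torch.cuda.is_available() returns True):

    generator = RAGenerator(use_gpu=False)  # supported, model runs on CPU
    generator = RAGenerator(use_gpu=True)   # raises AttributeError when CUDA is available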
---
 haystack/generator/transformers.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/haystack/generator/transformers.py b/haystack/generator/transformers.py
index fa8ff2a9bc..c93fbc89c8 100644
--- a/haystack/generator/transformers.py
+++ b/haystack/generator/transformers.py
@@ -106,6 +106,7 @@ def __init__(
 
         if use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda")
+            raise AttributeError("Currently RAGenerator does not support GPU, try with use_gpu=False")
         else:
             self.device = torch.device("cpu")
 
@@ -158,7 +159,8 @@ def _get_contextualized_inputs(self, texts: List[str], question: str, titles: Op
             truncation=True,
         )
 
-        return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]
+        return contextualized_inputs["input_ids"].to(self.device), \
+               contextualized_inputs["attention_mask"].to(self.device)
 
     def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[Optional[numpy.ndarray]]) -> torch.Tensor:
 
@@ -174,7 +176,7 @@ def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[Opt
         embeddings_in_tensor = torch.cat(
             [torch.from_numpy(embedding).unsqueeze(0) for embedding in embeddings],
             dim=0
-        )
+        ).to(self.device)
 
         return embeddings_in_tensor