From a9bcabc42d2516256bf0f27610bef27afe9c6f4c Mon Sep 17 00:00:00 2001
From: bogdankostic
Date: Wed, 16 Dec 2020 17:09:47 +0100
Subject: [PATCH] Fix saving tokenizers in DPR training + unify save and load dirs (#682)

---
 haystack/retriever/dense.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index a41c8b9f51..76fdd2a015 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -240,9 +240,9 @@ def train(self,
               grad_acc_steps: int = 1,
               optimizer_name: str = "TransformersAdamW",
               optimizer_correct_bias: bool = True,
-              save_dir: str = "../saved_models/dpr-tutorial",
-              query_encoder_save_dir: str = "lm1",
-              passage_encoder_save_dir: str = "lm2"
+              save_dir: str = "../saved_models/dpr",
+              query_encoder_save_dir: str = "query_encoder",
+              passage_encoder_save_dir: str = "passage_encoder"
               ):
         """
         train a DensePassageRetrieval model
@@ -317,20 +317,24 @@ def train(self,
         trainer.train()

         self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
-        self.processor.save(Path(save_dir))
+        self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}")
+        self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}")

-    def save(self, save_dir: Union[Path, str]):
+    def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
+             passage_encoder_dir: str = "passage_encoder"):
         """
         Save DensePassageRetriever to the specified directory.

         :param save_dir: Directory to save to.
+        :param query_encoder_dir: Directory in save_dir that contains query encoder model.
+        :param passage_encoder_dir: Directory in save_dir that contains passage encoder model.
         :return: None
         """
         save_dir = Path(save_dir)
-        self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
+        self.model.save(save_dir, lm1_name=query_encoder_dir, lm2_name=passage_encoder_dir)
         save_dir = str(save_dir)
-        self.query_tokenizer.save_pretrained(save_dir + "/query_encoder")
-        self.passage_tokenizer.save_pretrained(save_dir + "/passage_encoder")
+        self.query_tokenizer.save_pretrained(save_dir + f"/{query_encoder_dir}")
+        self.passage_tokenizer.save_pretrained(save_dir + f"/{passage_encoder_dir}")

     @classmethod
     def load(cls,
@@ -343,6 +347,8 @@ def load(cls,
              embed_title: bool = True,
              use_fast_tokenizers: bool = True,
              similarity_function: str = "dot_product",
+             query_encoder_dir: str = "query_encoder",
+             passage_encoder_dir: str = "passage_encoder"
              ):
         """
         Load DensePassageRetriever from the specified directory.
@@ -351,8 +357,8 @@ def load(cls,
         load_dir = Path(load_dir)
         dpr = cls(
             document_store=document_store,
-            query_embedding_model=Path(load_dir) / "query_encoder",
-            passage_embedding_model=Path(load_dir) / "passage_encoder",
+            query_embedding_model=Path(load_dir) / query_encoder_dir,
+            passage_embedding_model=Path(load_dir) / passage_encoder_dir,
             max_seq_len_query=max_seq_len_query,
             max_seq_len_passage=max_seq_len_passage,
             use_gpu=use_gpu,
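
Reviewer note (not part of the patch): below is a minimal round-trip sketch of
the unified save/load layout this change introduces. The import paths and the
use of InMemoryDocumentStore are assumptions about Haystack's module layout at
this revision, and the two Hugging Face model names are illustrative defaults;
treat this as a sketch, not tested code from this PR.

    from pathlib import Path

    from haystack.document_store.memory import InMemoryDocumentStore  # assumed import path
    from haystack.retriever.dense import DensePassageRetriever

    document_store = InMemoryDocumentStore()
    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
    )

    # save() now writes each tokenizer next to its encoder:
    #   saved_models/dpr/query_encoder/    <- query encoder + query tokenizer
    #   saved_models/dpr/passage_encoder/  <- passage encoder + passage tokenizer
    retriever.save("saved_models/dpr")

    # load() defaults to the same sub-directory names, so a model saved by
    # save() (or by train()) round-trips without extra directory arguments.
    retriever = DensePassageRetriever.load(
        load_dir=Path("saved_models/dpr"),
        document_store=document_store,
    )

Because train(), save(), and load() now agree on the query_encoder /
passage_encoder sub-directory names (and expose them as parameters for
non-default layouts), a retriever trained or saved with the defaults reloads
via load() with no extra arguments, which is the unification this patch
targets.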