
Commit

Fix saving tokenizers in DPR training + unify save and load dirs (#682)
bogdankostic authored Dec 16, 2020
1 parent 4c2804e commit a9bcabc
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions haystack/retriever/dense.py
@@ -240,9 +240,9 @@ def train(self,
               grad_acc_steps: int = 1,
               optimizer_name: str = "TransformersAdamW",
               optimizer_correct_bias: bool = True,
-              save_dir: str = "../saved_models/dpr-tutorial",
-              query_encoder_save_dir: str = "lm1",
-              passage_encoder_save_dir: str = "lm2"
+              save_dir: str = "../saved_models/dpr",
+              query_encoder_save_dir: str = "query_encoder",
+              passage_encoder_save_dir: str = "passage_encoder"
               ):
         """
         train a DensePassageRetrieval model
@@ -317,20 +317,24 @@ def train(self,
         trainer.train()

         self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
         self.processor.save(Path(save_dir))
+        self.query_tokenizer.save_pretrained(f"{save_dir}/{query_encoder_save_dir}")
+        self.passage_tokenizer.save_pretrained(f"{save_dir}/{passage_encoder_save_dir}")

-    def save(self, save_dir: Union[Path, str]):
+    def save(self, save_dir: Union[Path, str], query_encoder_dir: str = "query_encoder",
+             passage_encoder_dir: str = "passage_encoder"):
         """
         Save DensePassageRetriever to the specified directory.
         :param save_dir: Directory to save to.
+        :param query_encoder_dir: Directory in save_dir that contains query encoder model.
+        :param passage_encoder_dir: Directory in save_dir that contains passage encoder model.
         :return: None
         """
         save_dir = Path(save_dir)
-        self.model.save(save_dir, lm1_name="query_encoder", lm2_name="passage_encoder")
+        self.model.save(save_dir, lm1_name=query_encoder_dir, lm2_name=passage_encoder_dir)
         save_dir = str(save_dir)
-        self.query_tokenizer.save_pretrained(save_dir + "/query_encoder")
-        self.passage_tokenizer.save_pretrained(save_dir + "/passage_encoder")
+        self.query_tokenizer.save_pretrained(save_dir + f"/{query_encoder_dir}")
+        self.passage_tokenizer.save_pretrained(save_dir + f"/{passage_encoder_dir}")

     @classmethod
     def load(cls,
@@ -343,6 +347,8 @@ def load(cls,
              embed_title: bool = True,
              use_fast_tokenizers: bool = True,
              similarity_function: str = "dot_product",
+             query_encoder_dir: str = "query_encoder",
+             passage_encoder_dir: str = "passage_encoder"
              ):
         """
         Load DensePassageRetriever from the specified directory.
@@ -351,8 +357,8 @@
         load_dir = Path(load_dir)
         dpr = cls(
             document_store=document_store,
-            query_embedding_model=Path(load_dir) / "query_encoder",
-            passage_embedding_model=Path(load_dir) / "passage_encoder",
+            query_embedding_model=Path(load_dir) / query_encoder_dir,
+            passage_embedding_model=Path(load_dir) / passage_encoder_dir,
             max_seq_len_query=max_seq_len_query,
             max_seq_len_passage=max_seq_len_passage,
             use_gpu=use_gpu,
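The change unifies the directory names used by train(), save(), and load(): each encoder and its tokenizer now land in query_encoder/ and passage_encoder/ under the chosen save_dir, and load() reads from the same sub-directories. Below is a minimal round-trip sketch, assuming the 2020-era import path for InMemoryDocumentStore, the public facebook/dpr-* checkpoints as encoder models, and the load_dir keyword that sits in the collapsed part of this diff.

```python
# Round-trip sketch of the unified save/load directories (not taken verbatim
# from this commit; import paths and model names are assumptions).
from haystack.document_store.memory import InMemoryDocumentStore  # assumed import path
from haystack.retriever.dense import DensePassageRetriever

document_store = InMemoryDocumentStore()

retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",  # assumed checkpoint
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",     # assumed checkpoint
)

# save() now writes each encoder *and* its tokenizer into matching
# sub-directories of save_dir (query_encoder/ and passage_encoder/ by default).
retriever.save("saved_models/dpr",
               query_encoder_dir="query_encoder",
               passage_encoder_dir="passage_encoder")

# load() resolves the same sub-directory names, so the round trip no longer
# requires copying tokenizer files by hand.
retriever = DensePassageRetriever.load(
    load_dir="saved_models/dpr",
    document_store=document_store,
    query_encoder_dir="query_encoder",
    passage_encoder_dir="passage_encoder",
)
```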
