Derive MDR from EmbeddingRetriever instead of BaseRetriever
deutschmn committed Jun 23, 2022
1 parent b8df9e8 commit a48c6a2
Showing 2 changed files with 623 additions and 954 deletions.
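For context: the substantive change lives in the second changed file (the retriever module), which this excerpt does not show. There, the multi-hop dense retriever (MDR) is rebased from BaseRetriever onto EmbeddingRetriever, which is what makes the shared-encoder model deleted below redundant. A minimal sketch of the shape of that change, assuming the retriever class is named MultihopEmbeddingRetriever and exposes a hop count; both details are assumptions, since the retriever diff is not part of this excerpt:

# Sketch only: illustrates the inheritance change named in the commit title.
# The class name MultihopEmbeddingRetriever and the num_iterations parameter
# are assumptions; the retriever diff itself is not shown on this page.
from haystack.nodes import EmbeddingRetriever


class MultihopEmbeddingRetriever(EmbeddingRetriever):
    """Multi-hop dense retriever (MDR): previously derived from BaseRetriever,
    now reuses EmbeddingRetriever's embedding and retrieval plumbing."""

    def __init__(self, num_iterations: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.num_iterations = num_iterations  # number of retrieval hops per query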
haystack/modeling/model/biadaptive_model.py (63 changes: 0 additions & 63 deletions)
@@ -485,66 +485,3 @@ def convert_from_transformers(
        bi_adaptive_model.connect_heads_with_processor(processor.tasks)  # type: ignore

        return bi_adaptive_model


class BiAdaptiveSharedModel(BiAdaptiveModel):
    """Same behaviour as BiAdaptiveModel, but with a shared underlying model."""

    def __init__(
        self,
        language_model1: LanguageModel,
        language_model2: LanguageModel,
        prediction_heads: List[PredictionHead],
        embeds_dropout_prob: float = 0.1,
        device: torch.device = torch.device("cuda"),
        lm1_output_types: Union[str, List[str]] = ["per_sequence"],
        lm2_output_types: Union[str, List[str]] = ["per_sequence"],
        loss_aggregation_fn: Optional[Callable] = None,
    ):
        if (language_model1 != language_model2) or (lm1_output_types != lm2_output_types):
            raise ValueError("BiAdaptiveSharedModel assumes using the same models and outputs for lm1 and lm2")

        super().__init__(
            language_model1,
            language_model2,
            prediction_heads,
            embeds_dropout_prob,
            device,
            lm1_output_types,
            lm2_output_types,
            loss_aggregation_fn,
        )

    def save(self, save_dir: Path, lm_name: str = "lm"):  # type: ignore
        """
        Saves the language model weights and their config files in the directory `lm_name` within `save_dir`.

        :param save_dir: Path to save the model to.
        """
        os.makedirs(save_dir, exist_ok=True)
        if not os.path.exists(Path.joinpath(save_dir, Path(lm_name))):
            os.makedirs(Path.joinpath(save_dir, Path(lm_name)))
        self.language_model1.save(Path.joinpath(save_dir, Path(lm_name)))
        for i, ph in enumerate(self.prediction_heads):
            logger.info("prediction_head saving")
            ph.save(save_dir, i)

    def forward_lm(self, **kwargs):
        # pooled_output holds [query embedding, passage embedding]
        pooled_output = [None, None]
        if "query_input_ids" in kwargs.keys():
            _, pooled_output1 = self.language_model1(
                input_ids=kwargs["query_input_ids"],
                segment_ids=kwargs["query_segment_ids"],
                padding_mask=kwargs["query_attention_mask"],
            )
            pooled_output[0] = pooled_output1
        if "passage_input_ids" in kwargs.keys():
            # flatten leading batch/passage dimensions so each passage is encoded as a separate sequence
            max_seq_len = kwargs["passage_input_ids"].shape[-1]
            _, pooled_output2 = self.language_model2(
                input_ids=kwargs["passage_input_ids"].view(-1, max_seq_len),
                segment_ids=kwargs["passage_segment_ids"].view(-1, max_seq_len),
                padding_mask=kwargs["passage_attention_mask"].view(-1, max_seq_len),
            )
            pooled_output[1] = pooled_output2

        return tuple(pooled_output)
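
With BiAdaptiveSharedModel gone, the one property it enforced (query and passage encoders sharing weights, via the `language_model1 != language_model2` check) can still be expressed with plain BiAdaptiveModel by passing the same LanguageModel instance for both sides. A rough sketch under that assumption; the keyword arguments mirror the deleted `__init__` above, and `load_language_model` is a hypothetical stand-in for whichever loader the surrounding code uses, not a Haystack API:

import torch

# Sketch: weight sharing without BiAdaptiveSharedModel. Passing the *same*
# LanguageModel object as both encoders means both forward passes run through
# one set of parameters, which is exactly what the deleted class guaranteed.
# `load_language_model` is hypothetical; substitute the project's real loader.
shared_lm = load_language_model("bert-base-uncased")

model = BiAdaptiveModel(
    language_model1=shared_lm,  # query encoder
    language_model2=shared_lm,  # passage encoder: same object, shared weights
    prediction_heads=[],
    embeds_dropout_prob=0.1,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    lm1_output_types=["per_sequence"],
    lm2_output_types=["per_sequence"],
)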
