From ef38f4a40475a8991649002a9dadd1a4c88eb4db Mon Sep 17 00:00:00 2001 From: Mosakana Date: Sat, 19 Oct 2024 16:53:23 +0200 Subject: [PATCH] fix negation distillation --- src/xpmir/datasets/adapters.py | 2 +- src/xpmir/letor/distillation/pairwise.py | 6 +++--- src/xpmir/letor/distillation/samplers.py | 8 ++++---- src/xpmir/letor/samplers/hydrators.py | 5 ++--- src/xpmir/neural/dual.py | 2 +- src/xpmir/text/huggingface/tokenizers.py | 7 +++++++ 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/xpmir/datasets/adapters.py b/src/xpmir/datasets/adapters.py index 702b9ff..d2d8e49 100644 --- a/src/xpmir/datasets/adapters.py +++ b/src/xpmir/datasets/adapters.py @@ -545,7 +545,7 @@ class MemoryTopicStore(TextStore): @cached_property def store(self): - return {topic[IDItem].id: topic.text for topic in self.topics.iter()} + return {topic[IDItem].id: topic[TextItem].text for topic in self.topics.iter()} def __getitem__(self, key: str) -> str: return self.store[key] diff --git a/src/xpmir/letor/distillation/pairwise.py b/src/xpmir/letor/distillation/pairwise.py index bd6d660..bb65d3a 100644 --- a/src/xpmir/letor/distillation/pairwise.py +++ b/src/xpmir/letor/distillation/pairwise.py @@ -131,9 +131,9 @@ def train_batch(self, samples: List[PairwiseDistillationSample]): for ix, sample in enumerate(samples): records.add( PairwiseRecord( - sample.query.as_record(), - DocumentRecord(sample.documents[0].document), - DocumentRecord(sample.documents[1].document), + sample.query, + sample.documents[0].document, + sample.documents[1].document, ) ) teacher_scores[ix, 0] = sample.documents[0].score diff --git a/src/xpmir/letor/distillation/samplers.py b/src/xpmir/letor/distillation/samplers.py index b7eee7b..70ba31b 100644 --- a/src/xpmir/letor/distillation/samplers.py +++ b/src/xpmir/letor/distillation/samplers.py @@ -52,12 +52,12 @@ class PairwiseHydrator(PairwiseDistillationSamples, SampleHydrator): def transform(self, sample: PairwiseDistillationSample): topic, documents = sample.query, sample.documents - if transformed := self.querystore.transforme_topics(): - topic = TopicRecord(*transformed) + if transformed := self.transform_topics(topic): + topic = transformed[0] - if transformed := self.documentstore.transforme_documents(): + if transformed := self.transform_documents(documents): documents = tuple( - ScoredDocument(d, sd.score) + ScoredDocument(d, sd[ScoredItem].score) for d, sd in zip(transformed, sample.documents) ) diff --git a/src/xpmir/letor/samplers/hydrators.py b/src/xpmir/letor/samplers/hydrators.py index 111a0e1..7f5005a 100644 --- a/src/xpmir/letor/samplers/hydrators.py +++ b/src/xpmir/letor/samplers/hydrators.py @@ -20,7 +20,7 @@ class SampleTransform(Config, ABC): @abstractmethod def transform_topics( - self, topics: Iterator[ir.TopicRecord] + self, topics: ir.TopicRecord ) -> Optional[List[ir.TopicRecord]]: ... @@ -40,14 +40,13 @@ class SampleHydrator(SampleTransform): querystore: Param[Optional[TextStore]] """The store for query texts if needed""" - def transform_topics(self, topics: List[ir.TopicRecord]): + def transform_topics(self, topic: ir.TopicRecord): if self.querystore is None: return None return [ ir.create_record( id=topic[IDItem].id, text=self.querystore[topic[IDItem].id] ) - for topic in topics ] def transform_documents(self, documents: List[ir.DocumentRecord]): diff --git a/src/xpmir/neural/dual.py b/src/xpmir/neural/dual.py index 1479467..0577861 100644 --- a/src/xpmir/neural/dual.py +++ b/src/xpmir/neural/dual.py @@ -76,7 +76,7 @@ def score_product(self, queries, documents, info: Optional[TrainerContext] = Non return queries.value @ documents.value.T def score_pairs(self, queries, documents, info: Optional[TrainerContext] = None): - scores = (queries.unsqueeze(1) @ documents.unsqueeze(2)).squeeze(-1).squeeze(-1) + scores = (queries.value.unsqueeze(1) @ documents.value.unsqueeze(2)).squeeze(-1).squeeze(-1) # Apply the dual vector hook if info is not None: diff --git a/src/xpmir/text/huggingface/tokenizers.py b/src/xpmir/text/huggingface/tokenizers.py index 2d41648..cb7004f 100644 --- a/src/xpmir/text/huggingface/tokenizers.py +++ b/src/xpmir/text/huggingface/tokenizers.py @@ -113,6 +113,10 @@ def vocab_size(self) -> int: """Returns the size of the vocabulary""" return self.tokenizer.vocab_size + @property + def vocab(self) -> dict: + return self.tokenizer.vocab + class HFTokenizerBase(TokenizerBase[TokenizerInput, TokenizedTexts]): """Base class for all Hugging-Face tokenizers""" @@ -136,6 +140,9 @@ def tok2id(self, tok: str) -> int: def id2tok(self, idx: int) -> str: return self.tokenizer.id2tok(idx) + + def get_vocabulary(self): + return self.tokenizer.vocab class HFStringTokenizer(HFTokenizerBase[HFTokenizerInput]):