Skip to content

Commit

Permalink
fix negation distillation
Browse files Browse the repository at this point in the history
  • Loading branch information
Mosakana committed Oct 19, 2024
1 parent dcc9e1c commit ef38f4a
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/xpmir/datasets/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ class MemoryTopicStore(TextStore):

@cached_property
def store(self):
return {topic[IDItem].id: topic.text for topic in self.topics.iter()}
return {topic[IDItem].id: topic[TextItem].text for topic in self.topics.iter()}

def __getitem__(self, key: str) -> str:
return self.store[key]
6 changes: 3 additions & 3 deletions src/xpmir/letor/distillation/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def train_batch(self, samples: List[PairwiseDistillationSample]):
for ix, sample in enumerate(samples):
records.add(
PairwiseRecord(
sample.query.as_record(),
DocumentRecord(sample.documents[0].document),
DocumentRecord(sample.documents[1].document),
sample.query,
sample.documents[0].document,
sample.documents[1].document,
)
)
teacher_scores[ix, 0] = sample.documents[0].score
Expand Down
8 changes: 4 additions & 4 deletions src/xpmir/letor/distillation/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ class PairwiseHydrator(PairwiseDistillationSamples, SampleHydrator):
def transform(self, sample: PairwiseDistillationSample):
topic, documents = sample.query, sample.documents

if transformed := self.querystore.transforme_topics():
topic = TopicRecord(*transformed)
if transformed := self.transform_topics(topic):
topic = transformed[0]

if transformed := self.documentstore.transforme_documents():
if transformed := self.transform_documents(documents):
documents = tuple(
ScoredDocument(d, sd.score)
ScoredDocument(d, sd[ScoredItem].score)
for d, sd in zip(transformed, sample.documents)
)

Expand Down
5 changes: 2 additions & 3 deletions src/xpmir/letor/samplers/hydrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class SampleTransform(Config, ABC):
@abstractmethod
def transform_topics(
self, topics: Iterator[ir.TopicRecord]

This comment has been minimized.

Copy link
@bpiwowar

bpiwowar Oct 23, 2024

Collaborator

It should be a list or an iterator.

self, topics: ir.TopicRecord
) -> Optional[List[ir.TopicRecord]]:
...

Expand All @@ -40,14 +40,13 @@ class SampleHydrator(SampleTransform):
querystore: Param[Optional[TextStore]]
"""The store for query texts if needed"""

def transform_topics(self, topics: List[ir.TopicRecord]):

This comment has been minimized.

Copy link
@bpiwowar

bpiwowar Oct 23, 2024

Collaborator

This should also be an iterator.

def transform_topics(self, topic: ir.TopicRecord):
if self.querystore is None:
return None
return [
ir.create_record(
id=topic[IDItem].id, text=self.querystore[topic[IDItem].id]
)
for topic in topics
]

def transform_documents(self, documents: List[ir.DocumentRecord]):
Expand Down
2 changes: 1 addition & 1 deletion src/xpmir/neural/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def score_product(self, queries, documents, info: Optional[TrainerContext] = Non
return queries.value @ documents.value.T

def score_pairs(self, queries, documents, info: Optional[TrainerContext] = None):
scores = (queries.unsqueeze(1) @ documents.unsqueeze(2)).squeeze(-1).squeeze(-1)
scores = (queries.value.unsqueeze(1) @ documents.value.unsqueeze(2)).squeeze(-1).squeeze(-1)

# Apply the dual vector hook
if info is not None:
Expand Down
7 changes: 7 additions & 0 deletions src/xpmir/text/huggingface/tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def vocab_size(self) -> int:
"""Returns the size of the vocabulary"""
return self.tokenizer.vocab_size

@property
def vocab(self) -> dict:
return self.tokenizer.vocab


class HFTokenizerBase(TokenizerBase[TokenizerInput, TokenizedTexts]):
"""Base class for all Hugging-Face tokenizers"""
Expand All @@ -136,6 +140,9 @@ def tok2id(self, tok: str) -> int:

def id2tok(self, idx: int) -> str:
return self.tokenizer.id2tok(idx)

def get_vocabulary(self):
return self.tokenizer.vocab


class HFStringTokenizer(HFTokenizerBase[HFTokenizerInput]):
Expand Down

0 comments on commit ef38f4a

Please sign in to comment.