Skip to content

Commit

Permalink
Fix bug with documents_ext
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Sep 26, 2023
1 parent 69e2f07 commit b28ae2c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 5 deletions.
12 changes: 11 additions & 1 deletion src/datamaestro_text/data/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dataclasses import dataclass
from datamaestro.data import Base
from datamaestro_text.utils.files import auto_open
from datamaestro_text.utils.iter import BatchIterator
from .base import (
Document,
Topic,
Expand Down Expand Up @@ -82,7 +83,7 @@ def document_ext(self, docid: str) -> Document:
"""Returns a document given its external ID"""
raise NotImplementedError(f"document() in {self.__class__}")

def documents_ext(self, docids: List[str]) -> Document:
def documents_ext(self, docids: List[str]) -> List[Document]:
"""Returns documents given their external ID
By default, just look using `document_ext`, but some store might
Expand Down Expand Up @@ -201,8 +202,17 @@ class TrainingTriplets(Base):
negative document"""

def iter(self) -> Iterator[Tuple[Topic, Document, Document]]:
"""Returns an iterator"""
raise NotImplementedError(f"For class {self.__class__}")

def batch_iter(self, size: int) -> Iterator[List[Tuple[Topic, Document, Document]]]:
"""Returns an iterator over batches of triplets
The default implementation just concatenates triplets using `iter`, but
some classes might use more efficient ways to provide batches of data
"""
return BatchIterator(self.iter(), size)

def count(self):
"""Returns the number of triplets or None"""
return None
Expand Down
3 changes: 2 additions & 1 deletion src/datamaestro_text/datasets/irds/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def document_ext(self, docid: str) -> Document:

def documents_ext(self, docids: List[str]) -> Document:
"""Returns documents given their external IDs (optimized for batch)"""
return [self.converter(doc) for doc in self.store.get_many_iter(docids)]
retrieved = self.store.get_many(docids)
return [self.converter(retrieved[docid]) for docid in docids]

def document_int(self, ix):
return self.converter(self.dataset.docs_iter()[ix])
Expand Down
16 changes: 13 additions & 3 deletions src/datamaestro_text/transforms/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,19 @@ def __validate__(self):

def iter(self):
for topic, doc1, doc2 in self.data.iter():
yield topic, self.store.document_ext(
doc1.get_id()
), self.store.document_ext(doc2.get_id())
doc1, doc2 = self.store.documents_ext(doc1.get_id(), doc2.get_id())
yield topic, doc1, doc2

def batch_iter(self, size: int):
for triplets in self.data.batch_iter(size):
docids = []
for topic, doc1, doc2 in triplets:
docids.extend(doc1.get_id(), doc2.get_id())
docs_iter = iter(self.store.documents_ext(docids))
for triplet in triplets:
triplet[1] = next(docs_iter)
triplet[2] = next(docs_iter)
yield triplets

def count(self):
return self.data.count()
Expand Down
21 changes: 21 additions & 0 deletions src/datamaestro_text/utils/iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import TypeVar, Iterator, List

T = TypeVar("T")


class BatchIterator(Iterator[List[T]]):
"""Adapter for iterators to return batches of elements instead of"""

def __init__(self, iterator: Iterator[T], size: int):
self.iterator = iterator
self.size = size

def __next__(self):
batch = []
for _, element in zip(range(self.size), self.iterator):
batch.append(element)

if len(batch) == 0:
raise StopIteration()

return batch

0 comments on commit b28ae2c

Please sign in to comment.