Skip to content

Commit

Permalink
Use __getstate__ to pickle IRDS documents
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Feb 28, 2024
1 parent a7ae472 commit 408d1ef
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
20 changes: 11 additions & 9 deletions src/datamaestro_text/data/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ class Documents(Base):
count: Meta[Optional[int]]
"""Number of documents"""

def iter(self) -> Iterator[Record]:
def iter(self) -> Iterator[DocumentRecord]:
    """Returns an iterator over documents

    Delegates to `iter_documents`. The visible `iter_documents` delegates
    back to this method, so a subclass is expected to override at least one
    of the two; otherwise the two calls recurse into each other.
    """
    # `return`, not `raise`: the iterator is not an exception, so raising it
    # fails with a TypeError instead of handing the documents to the caller.
    return self.iter_documents()

def iter_documents(self) -> Iterator[Record]:
def iter_documents(self) -> Iterator[DocumentRecord]:
    """Returns an iterator over documents by forwarding to `iter`"""
    documents = self.iter()
    return documents

def iter_ids(self) -> Iterator[str]:
Expand All @@ -67,7 +67,7 @@ def documentcount(self):

@property
@abstractmethod
def document_recordtype(self) -> Type[Record]:
def document_recordtype(self) -> Type[DocumentRecord]:
    """The class for documents

    Declared as an abstract property: the body below is only a placeholder.
    """
    ...

Expand All @@ -94,15 +94,17 @@ def document_ext(self, docid: str) -> DocumentRecord:
"""Returns a document given its external ID"""
raise NotImplementedError(f"document() in {self.__class__}")

def documents_ext(self, docids: List[str]) -> List[Record]:
def documents_ext(self, docids: List[str]) -> List[DocumentRecord]:
    """Returns documents given their external IDs

    By default, each ID is looked up with `document_ext`, in the order
    given, but some stores might optimize batch retrieval.
    """
    records = []
    for docid in docids:
        records.append(self.document_ext(docid))
    return records

def iter_sample(self, randint: Optional[Callable[[int], int]]) -> Iterator[Record]:
def iter_sample(
self, randint: Optional[Callable[[int], int]]
) -> Iterator[DocumentRecord]:
"""Sample documents from the dataset"""
length = self.documentcount
randint = randint or (lambda max: random.randint(0, max - 1))
Expand All @@ -127,7 +129,7 @@ class Topics(Base, ABC):
"""A set of topics with associated IDs"""

@abstractmethod
def iter(self) -> Iterator[Record]:
def iter(self) -> Iterator[TopicRecord]:
    """Returns an iterator over topics

    Declared abstract: the body below is only a placeholder.
    """
    ...

Expand All @@ -140,7 +142,7 @@ def count(self) -> Optional[int]:

@property
@abstractmethod
def topic_recordtype(self) -> Type[Record]:
def topic_recordtype(self) -> Type[TopicRecord]:
    """The class for topics (abstract property; no default is given here)"""


Expand All @@ -151,11 +153,11 @@ class TopicsStore(Topics):
"""Adhoc topics store"""

@abstractmethod
def topic_int(self, internal_topic_id: int) -> Record:
def topic_int(self, internal_topic_id: int) -> TopicRecord:
    """Returns a topic given its internal ID"""

@abstractmethod
def topic_ext(self, external_topic_id: int) -> Record:
def topic_ext(self, external_topic_id: int) -> TopicRecord:
    """Returns a topic given its external ID"""
    # NOTE(review): the document-side lookups in this file take external IDs
    # as `str`; confirm that `int` is really intended for external topic IDs.


Expand Down
6 changes: 6 additions & 0 deletions src/datamaestro_text/datasets/irds/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ class Documents(ir.DocumentStore, IRDSId):
# List of fields
# self.dataset.docs_cls()._fields

def __getstate__(self):
    """Returns the state to pickle: the `(id, irds)` pair of this object"""
    state = (self.id, self.irds)
    return state

def __setstate__(self, state):
    """Restores `id` and `irds` from a two-element pickled state"""
    restored_id, restored_irds = state
    self.id = restored_id
    self.irds = restored_irds

def iter(self) -> Iterator[ir.DocumentRecord]:
"""Returns an iterator over adhoc documents"""
for doc in self.dataset.docs_iter():
Expand Down

0 comments on commit 408d1ef

Please sign in to comment.