From 408d1ef7f428e736b7bcd88d701ffc6796fb25d5 Mon Sep 17 00:00:00 2001 From: Benjamin Piwowarski Date: Thu, 29 Feb 2024 00:17:59 +0100 Subject: [PATCH] Use __getstate__ to pickle IRDS documents --- src/datamaestro_text/data/ir/__init__.py | 20 +++++++++++--------- src/datamaestro_text/datasets/irds/data.py | 6 ++++++ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/datamaestro_text/data/ir/__init__.py b/src/datamaestro_text/data/ir/__init__.py index fba1b54..ff88b3d 100644 --- a/src/datamaestro_text/data/ir/__init__.py +++ b/src/datamaestro_text/data/ir/__init__.py @@ -42,11 +42,11 @@ class Documents(Base): count: Meta[Optional[int]] """Number of documents""" - def iter(self) -> Iterator[Record]: + def iter(self) -> Iterator[DocumentRecord]: """Returns an iterator over documents""" raise self.iter_documents() - def iter_documents(self) -> Iterator[Record]: + def iter_documents(self) -> Iterator[DocumentRecord]: return self.iter() def iter_ids(self) -> Iterator[str]: @@ -67,7 +67,7 @@ def documentcount(self): @property @abstractmethod - def document_recordtype(self) -> Type[Record]: + def document_recordtype(self) -> Type[DocumentRecord]: """The class for documents""" ... @@ -94,7 +94,7 @@ def document_ext(self, docid: str) -> DocumentRecord: """Returns a document given its external ID""" raise NotImplementedError(f"document() in {self.__class__}") - def documents_ext(self, docids: List[str]) -> List[Record]: + def documents_ext(self, docids: List[str]) -> List[DocumentRecord]: """Returns documents given their external ID By default, just look using `document_ext`, but some store might @@ -102,7 +102,9 @@ def documents_ext(self, docids: List[str]) -> List[Record]: """ return [self.document_ext(docid) for docid in docids] - def iter_sample(self, randint: Optional[Callable[[int], int]]) -> Iterator[Record]: + def iter_sample( + self, randint: Optional[Callable[[int], int]] + ) -> Iterator[DocumentRecord]: """Sample documents from the dataset""" length = self.documentcount randint = randint or (lambda max: random.randint(0, max - 1)) @@ -127,7 +129,7 @@ class Topics(Base, ABC): """A set of topics with associated IDs""" @abstractmethod - def iter(self) -> Iterator[Record]: + def iter(self) -> Iterator[TopicRecord]: """Returns an iterator over topics""" ... @@ -140,7 +142,7 @@ def count(self) -> Optional[int]: @property @abstractmethod - def topic_recordtype(self) -> Type[Record]: + def topic_recordtype(self) -> Type[TopicRecord]: """The class for topics""" @@ -151,11 +153,11 @@ class TopicsStore(Topics): """Adhoc topics store""" @abstractmethod - def topic_int(self, internal_topic_id: int) -> Record: + def topic_int(self, internal_topic_id: int) -> TopicRecord: """Returns a document given its internal ID""" @abstractmethod - def topic_ext(self, external_topic_id: int) -> Record: + def topic_ext(self, external_topic_id: int) -> TopicRecord: """Returns a document given its external ID""" diff --git a/src/datamaestro_text/datasets/irds/data.py b/src/datamaestro_text/datasets/irds/data.py index 37ed975..999062e 100644 --- a/src/datamaestro_text/datasets/irds/data.py +++ b/src/datamaestro_text/datasets/irds/data.py @@ -160,6 +160,12 @@ class Documents(ir.DocumentStore, IRDSId): # List of fields # self.dataset.docs_cls()._fields + def __getstate__(self): + return (self.id, self.irds) + + def __setstate__(self, state): + self.id, self.irds = state + def iter(self) -> Iterator[ir.DocumentRecord]: """Returns an iterator over adhoc documents""" for doc in self.dataset.docs_iter():