From a7ae472a23dd3bed68f359f1e59dcd91f476fb1f Mon Sep 17 00:00:00 2001 From: Benjamin Piwowarski Date: Tue, 27 Feb 2024 14:32:29 +0100 Subject: [PATCH] fix for records --- src/datamaestro_text/data/ir/__init__.py | 2 +- src/datamaestro_text/datasets/irds/data.py | 2 +- src/datamaestro_text/transforms/ir/__init__.py | 8 +++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/datamaestro_text/data/ir/__init__.py b/src/datamaestro_text/data/ir/__init__.py index f510b8a..fba1b54 100644 --- a/src/datamaestro_text/data/ir/__init__.py +++ b/src/datamaestro_text/data/ir/__init__.py @@ -55,7 +55,7 @@ def iter_ids(self) -> Iterator[str]: By default, use iter_documents, which is not really efficient. """ for doc in self.iter(): - yield doc.get_id() + yield doc[IDItem].id @property def documentcount(self): diff --git a/src/datamaestro_text/datasets/irds/data.py b/src/datamaestro_text/datasets/irds/data.py index c2c8e1d..37ed975 100644 --- a/src/datamaestro_text/datasets/irds/data.py +++ b/src/datamaestro_text/datasets/irds/data.py @@ -419,7 +419,7 @@ def records(self): @cached_property def ext2records(self): - return {record.topic.get_id(): record for record in self.records} + return {record[IDItem].id: record for record in self.records} def topic_int(self, internal_topic_id: int) -> TopicRecord: """Returns a document given its internal ID""" diff --git a/src/datamaestro_text/transforms/ir/__init__.py b/src/datamaestro_text/transforms/ir/__init__.py index 1cf821b..7f21b86 100644 --- a/src/datamaestro_text/transforms/ir/__init__.py +++ b/src/datamaestro_text/transforms/ir/__init__.py @@ -69,14 +69,16 @@ def __validate__(self): def iter(self): for topic, doc1, doc2 in self.data.iter(): - doc1, doc2 = self.store.documents_ext([doc1.get_id(), doc2.get_id()]) + doc1, doc2 = self.store.documents_ext( + [doc1[ir.IDItem].id, doc2[ir.IDItem].id] + ) yield topic, doc1, doc2 def batch_iter(self, size: int): for triplets in self.data.batch_iter(size): docids = [] for topic, doc1, doc2 in triplets: - docids.extend(doc1.get_id(), doc2.get_id()) + docids.extend(doc1[ir.IDItem].id, doc2[ir.IDItem].id) docs_iter = iter(self.store.documents_ext(docids)) for triplet in triplets: triplet[1] = next(docs_iter) @@ -165,7 +167,7 @@ def execute(self): if self.topic_ids: def get_query(query): - return query.get_id() + return query[ir.IDItem].id else: