Skip to content

Commit

Permalink
Cleaned up TextItem hierarchy
Browse files: browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Feb 27, 2024
1 parent 9240d8c commit 6d73bcf
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 86 deletions.
16 changes: 7 additions & 9 deletions src/datamaestro_text/data/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

class BaseRecord(Record):
@classmethod
def from_text(cls, text: str):
return cls(SimpleTextItem(text))
def from_text(cls, text: str, *items: Item):
return cls(SimpleTextItem(text), *items)

@classmethod
def from_id(cls, id: str):
return cls(IDItem(id))
def from_id(cls, id: str, *items: Item):
return cls(IDItem(id), *items)


class TopicRecord(BaseRecord):
Expand All @@ -32,20 +32,18 @@ class ScoredItem(Item):


class TextItem(Item, ABC):
@property
@abstractmethod
def get_text(self) -> str:
def text(self) -> str:
"""Returns the text"""


@define
class SimpleTextItem(TextItem, ABC):
class SimpleTextItem(TextItem):
"""A topic/document with a text record"""

text: str

def get_text(self):
return self.text


@define
class InternalIDItem(Item, ABC):
Expand Down
2 changes: 1 addition & 1 deletion src/datamaestro_text/data/ir/cord19.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ def iter(self) -> Iterator[CordDocumentRecord]:
for row in DictReader(fp):
yield CordDocumentRecord(
IDItem(row["cord_uid"]),
DocumentWithTitle(row["title"], row["abstract"]),
DocumentWithTitle(row["abstract"], row["title"]),
)
108 changes: 36 additions & 72 deletions src/datamaestro_text/data/ir/formats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import cached_property
from typing import ClassVar, Tuple
from attrs import define
from datamaestro.record import recordtypes
Expand All @@ -7,25 +8,22 @@


@define
class CordDocument(TextItem):
text: str
title: str
url: str
pubmed_id: str

has_text: ClassVar[bool] = True
class DocumentWithTitle(TextItem):
"""Web document with title and body"""

def get_text(self):
return f"{self.title} {self.text}"
body: str

title: str

@define
class DocumentWithTitle(TextItem):
"""Web document with title and URL"""
@cached_property
def text(self):
return f"{self.title} {self.body}"

title: str

text: str
@define
class CordDocument(DocumentWithTitle):
url: str
pubmed_id: str


@define
Expand All @@ -36,10 +34,9 @@ class CordFullTextDocument(TextItem):
abstract: str
body: Tuple[Cord19FullTextSection, ...]

has_text: ClassVar[bool] = True

def get_text(self):
return f"{self.abstract}"
@cached_property
def text(self):
return self.abstract


@define
Expand All @@ -48,10 +45,9 @@ class MsMarcoDocument(TextItem):
title: str
body: str

has_text: ClassVar[bool] = True

def get_text(self):
return f"{self.body}"
@cached_property
def text(self):
return self.body


@define
Expand All @@ -60,31 +56,24 @@ class NFCorpusDocument(TextItem):
title: str
abstract: str

has_text: ClassVar[bool] = True

def get_text(self):
return f"{self.abstract}"
@cached_property
def text(self):
return self.abstract


@define
class TitleDocument(TextItem):
text: str
body: str
title: str
has_text: ClassVar[bool] = True

def get_text(self):
return f"{self.title} {self.text}"
@cached_property
def text(self):
return f"{self.title} {self.body}"


@define
class TitleUrlDocument(TextItem):
text: str
title: str
class TitleUrlDocument(TitleDocument):
url: str
has_text: ClassVar[bool] = True

def get_text(self):
return f"{self.title} {self.text}"


@define
Expand All @@ -93,9 +82,8 @@ class TrecParsedDocument(TextItem):
body: str
marked_up_doc: bytes

has_text: ClassVar[bool] = True

def get_text(self):
@cached_property
def text(self):
return f"{self.title} {self.body}"


Expand All @@ -110,10 +98,9 @@ class WapoDocument(TextItem):
body_paras_html: Tuple[str, ...]
body_media: Tuple[WapoDocMedia, ...]

has_text: ClassVar[bool] = True

def get_text(self):
return f"{self.body}"
@cached_property
def text(self):
return self.body


@define
Expand All @@ -127,22 +114,18 @@ class TweetDoc(TextItem):
source: bytes
source_content_type: str

def get_text(self):
return f"{self.text}"


@define
class OrConvQADocument(TextItem):
id: str
title: str
text: str
body: str
aid: str
bid: int

has_text: ClassVar[bool] = True

def get_text(self):
return f"{self.title} {self.text}"
@cached_property
def text(self):
return f"{self.title} {self.body}"


@define
Expand All @@ -151,37 +134,18 @@ class TrecTopic(TextItem):
query: str
narrative: str

def get_text(self):
return f"{self.text}"


@define
class UrlTopic(TextItem):
text: str
url: str

def get_text(self):
return f"{self.text}"


@define
class NFCorpusTopic(TextItem):
title: str
text: str
all: str

def get_text(self):
return f"{self.title}"


@define
class TrecQuery(TextItem):
title: str
description: str
narrative: str

def get_text(self):
return f"{self.description}"


@define
class TrecMb13Query(TextItem):
Expand Down
4 changes: 2 additions & 2 deletions src/datamaestro_text/datasets/irds/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def converter(self):

if hasattr(_irds, "miracl"):
Documents.CONVERTERS[_irds.miracl.MiraclDoc] = tuple_constructor(
formats.DocumentWithTitle, "doc_id", "title", "text"
formats.DocumentWithTitle, "doc_id", "text", "title"
)


Expand Down Expand Up @@ -351,7 +351,7 @@ class Topics(ir.TopicsStore, IRDSId):
formats.NFCorpusTopic, "query_id", "title", "all"
),
TrecQuery: tuple_constructor(
formats.TrecQuery, "query_id", "title", "description", "narrative"
formats.TrecTopic, "query_id", "title", "description", "narrative"
),
_irds.tweets2013_ia.TrecMb13Query: tuple_constructor(
formats.TrecMb13Query, "query_id", "query", "time", "tweet_time"
Expand Down
4 changes: 2 additions & 2 deletions src/datamaestro_text/transforms/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_query(query):
else:

def get_query(query):
return query[ir.TextItem].get_text()
return query[ir.TextItem].text

if self.doc_ids:

Expand All @@ -180,7 +180,7 @@ def get_doc(doc):
else:

def get_doc(doc):
return doc.get_text()
return doc.text

def triplegenerator():
logging.info("Starting to output triples")
Expand Down

0 comments on commit 6d73bcf

Please sign in to comment.