Skip to content

Commit

Permalink
updates for conversation search structures
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed May 19, 2024
1 parent aacf124 commit 774758b
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 12 deletions.
63 changes: 57 additions & 6 deletions src/datamaestro_text/data/conversation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,34 +102,49 @@ class ConversationHistoryItem(Item):


class ConversationNode:
@abstractmethod
def entry(self) -> Record:
"""The current conversation entry"""
...

@abstractmethod
def history(self) -> ConversationHistory:
"""Preceding conversation entries, from most recent to more ancient"""
...

@abstractmethod
def parent(self) -> Optional["ConversationNode"]:
...

@abstractmethod
def children(self) -> List["ConversationNode"]:
...


class ConversationTree(ABC):
@abstractmethod
def root(self) -> ConversationNode:
...

class ConversationTree:
@abstractmethod
def __iter__(self) -> Iterator[ConversationNode]:
"""Iterates over conversation nodes"""
pass
...


# ---- A conversation tree


class SingleConversationTree(ConversationTree):
class SingleConversationTree(ConversationTree, ABC):
"""Simple conversations, based on a sequence of entries"""

id: str
history: Sequence[Record]
history: List[Record]

def __init__(self, id: Optional[str], history: List[Record]):
"""Create a simple conversation
:param history: The entries, in reverse order (i.e. more ancient first)
:param history: The entries, in **reverse** order (i.e. more ancient first)
"""
self.history = history or []
self.id = id
Expand All @@ -138,21 +153,48 @@ def add(self, entry: Record):
self.history.insert(0, entry)

def __iter__(self) -> Iterator[ConversationNode]:
for ix in range(len(self.history)):
"""Iterates over the conversation (starting with the beginning)"""
for ix in reversed(range(len(self.history))):
yield SingleConversationTreeNode(self, ix)

def root(self):
return SingleConversationTreeNode(self, len(self.history) - 1)


@define
class SingleConversationTreeNode(ConversationNode):
tree: SingleConversationTree
index: int

@property
def entry(self) -> Record:
return self.tree.history[self.index]

@entry.setter
def entry(self, record: Record):
try:
self.tree.history[self.index] = record
except Exception as e:
print(e)
raise

def history(self) -> Sequence[Record]:
return self.tree.history[self.index + 1 :]

def parent(self) -> ConversationNode | None:
return (
SingleConversationTreeNode(self.tree, self.index + 1)
if self.index < len(self.tree.history) - 1
else []
)

def children(self) -> List[ConversationNode]:
return (
[SingleConversationTreeNode(self.tree, self.index - 1)]
if self.index > 0
else None
)


class ConversationTreeNode(ConversationNode, ConversationTree):
"""A conversation tree node"""
Expand Down Expand Up @@ -186,6 +228,15 @@ def __iter__(self) -> Iterator["ConversationTreeNode"]:
for child in self.children:
yield from child

def parent(self) -> ConversationNode | None:
return self.parent

def children(self) -> List[ConversationNode]:
return self.children

def root(self):
return self


class ConversationDataset(Base, ABC):
"""A dataset made of conversations"""
Expand Down
82 changes: 76 additions & 6 deletions src/datamaestro_text/datasets/irds/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import partial
import logging
from pathlib import Path
from typing import Iterator, Tuple, Type, List
from typing import Dict, Iterator, Tuple, Type, List
import ir_datasets
from ir_datasets.indices import PickleLz4FullStore
from ir_datasets.formats import (
Expand All @@ -17,6 +17,7 @@
from experimaestro.compat import cached_property
from experimaestro import Option
from datamaestro.record import RecordType, record_type
from datamaestro_text.data.conversation.base import AnswerEntry
import datamaestro_text.data.ir as ir
from datamaestro_text.data.ir.base import (
Record,
Expand Down Expand Up @@ -445,22 +446,25 @@ def records(self):
"auto": query.automatic_rewritten_utterance,
},
)

is_new_conversation = topic_number != query.topic_number

topic = Record(
IDItem(query.query_id),
SimpleTextItem(query.raw_utterance),
decontextualized,
ConversationHistoryItem(
node.conversation(False) if node else []
[] if is_new_conversation else node.conversation(False)
),
EntryType.USER_QUERY,
)

if topic_number == query.topic_number:
node = node.add(ConversationTreeNode(topic))
else:
if is_new_conversation:
conversation = []
node = ConversationTreeNode(topic)
topic_number = query.topic_number
else:
node = node.add(ConversationTreeNode(topic))

records.append(topic)

Expand Down Expand Up @@ -494,12 +498,63 @@ class Cast2021TopicsHandler(CastTopicsHandler):
def get_canonical_result_id(query: _irds.trec_cast.Cast2021Query):
return query.canonical_result_id

class Cast2022TopicsHandler(CastTopicsHandler):
def __init__(self, dataset):
self.dataset = dataset

@cached_property
def records(self):
try:
records = []
nodes: Dict[str, ConversationTreeNode] = {}

for (
query
) in (
self.dataset.dataset.queries_iter()
): # type: _irds.trec_cast.Cast2022Query
parent = nodes[query.parent_id] if query.parent_id else None

if query.participant == "User":
topic = Record(
IDItem(query.query_id),
SimpleTextItem(query.raw_utterance),
DecontextualizedDictItem(
"manual",
{
"manual": query.manual_rewritten_utterance,
},
),
ConversationHistoryItem(
parent.conversation(False) if parent else []
),
EntryType.USER_QUERY,
)
node = ConversationTreeNode(topic)
records.append(topic)
else:
node = ConversationTreeNode(
Record(
AnswerEntry(query.response),
EntryType.SYSTEM_ANSWER,
)
)

nodes[query.query_id] = node
if parent:
parent.add(node)
except Exception:
logging.exception("Error while computing topic records")
raise

return records

Topics.HANDLERS.update(
{
# _irds.trec_cast.Cast2019Query: Cast2019TopicsHandler,
_irds.trec_cast.Cast2020Query: Cast2020TopicsHandler,
_irds.trec_cast.Cast2021Query: Cast2021TopicsHandler,
# _irds.trec_cast.Cast2022Query: Cast2022TopicsHandler
_irds.trec_cast.Cast2022Query: Cast2022TopicsHandler,
}
)

Expand All @@ -516,7 +571,22 @@ def __call__(self, _, doc: _irds.trec_cast.CastDoc):
IDItem(doc.doc_id), formats.SimpleTextItem(" ".join(doc.passages))
)

class CastPassageDocHandler:
def check(self, cls):
assert issubclass(cls, _irds.trec_cast.CastPassageDoc)

@cached_property
def target_cls(self):
return formats.TitleUrlDocument

def __call__(self, _, doc: _irds.trec_cast.CastPassageDoc):
return Record(
IDItem(doc.doc_id),
formats.TitleUrlDocument(doc.text, doc.title, doc.url),
)

Documents.CONVERTERS[_irds.trec_cast.CastDoc] = CastDocHandler()
Documents.CONVERTERS[_irds.trec_cast.CastPassageDoc] = CastPassageDocHandler()


class Adhoc(ir.Adhoc, IRDSId):
Expand Down

0 comments on commit 774758b

Please sign in to comment.