Skip to content

Commit

Permalink
TREC CaST 2020 supported again
Browse files: browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Feb 29, 2024
1 parent 6c1dfbb commit 02a0885
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
datamaestro>=1.0.1
datamaestro>=1.0.2
ir_datasets
attrs
5 changes: 3 additions & 2 deletions src/datamaestro_text/data/conversation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from attr import define
from datamaestro.data import Base
from datamaestro.record import Record, Item

from datamaestro_text.data.ir import TopicRecord
from datamaestro_text.utils.iter import FactoryIterable, LazyList, RangeView

# ---- Basic types
Expand Down Expand Up @@ -32,6 +32,7 @@ def get_decontextualized_query(self, mode=None) -> str:
return self.decontextualized_query


@define
class DecontextualizedDictItem(DecontextualizedItem):
"""A conversation entry providing decontextualized version of the user query"""

Expand All @@ -49,7 +50,7 @@ class ConversationRecord(Record):
pass


class TopicConversationRecord(ConversationRecord):
class TopicConversationRecord(ConversationRecord, TopicRecord):
"""A conversation record"""

pass
Expand Down
40 changes: 25 additions & 15 deletions src/datamaestro_text/datasets/irds/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from functools import partial
import logging
from pathlib import Path
from typing import Any, Iterator, NamedTuple, Tuple, Type, List
from attr import define
import attrs
from typing import Iterator, Tuple, Type, List
import ir_datasets
from ir_datasets.indices import PickleLz4FullStore
from ir_datasets.formats import (
Expand All @@ -18,6 +16,7 @@
from experimaestro import Config, Param
from experimaestro.compat import cached_property
from experimaestro import Option
from datamaestro.record import recordtypes
import datamaestro_text.data.ir as ir
from datamaestro_text.data.ir.base import (
Record,
Expand Down Expand Up @@ -411,7 +410,9 @@ def iter(self) -> Iterator[TopicRecord]:
from datamaestro_text.data.conversation.base import (
ConversationTreeNode,
DecontextualizedDictItem,
ConversationHistory,
RetrievedEntry,
AnswerConversationRecord,
ConversationHistoryItem,
)

class CastTopicsHandler(TopicsHandler):
Expand Down Expand Up @@ -439,6 +440,16 @@ def iter(self) -> Iterator[ir.TopicRecord]:
"""Returns an iterator over topics"""
return iter(self.records)

@recordtypes(
IDItem, SimpleTextItem, DecontextualizedDictItem, ConversationHistoryItem
)
class Cast2020TopicRecord(TopicRecord):
...

@recordtypes(RetrievedEntry)
class Cast2020ResponseRecord(AnswerConversationRecord):
...

class Cast2020TopicsHandler(CastTopicsHandler):
@cached_property
def records(self):
Expand All @@ -448,17 +459,11 @@ def records(self):
conversation = []
records = []

class Cast2020TopicRecord(TopicRecord):
class_types = [
IDItem,
SimpleTextItem,
DecontextualizedDictItem,
ConversationHistory,
]

for (
query
) in self.dataset.queries_iter(): # type: _irds.trec_cast.Cast2020Query
) in (
self.dataset.dataset.queries_iter()
): # type: _irds.trec_cast.Cast2020Query
decontextualized = DecontextualizedDictItem(
"manual",
{
Expand All @@ -470,7 +475,9 @@ class Cast2020TopicRecord(TopicRecord):
IDItem(query.query_id),
SimpleTextItem(query.raw_utterance),
decontextualized,
ConversationHistory(node.conversation(False)),
ConversationHistoryItem(
node.conversation(False) if node else []
),
)

if topic_number == query.topic_number:
Expand All @@ -485,12 +492,15 @@ class Cast2020TopicRecord(TopicRecord):
conversation.append(node)
node = node.add(
ConversationTreeNode(
DocumentRecord.from_id(query.manual_canonical_result_id)
Cast2020ResponseRecord(
RetrievedEntry(query.manual_canonical_result_id)
)
)
)
conversation.append(node)
except Exception:
logging.exception("Error while computing topic records")
raise

return records

Expand Down

0 comments on commit 02a0885

Please sign in to comment.