From 02a0885c2b00bbc4f5a5156d45863f7a6d1e3b7f Mon Sep 17 00:00:00 2001 From: Benjamin Piwowarski Date: Thu, 29 Feb 2024 16:45:24 +0100 Subject: [PATCH] TREC CaST 2020 supported again --- requirements.txt | 2 +- .../data/conversation/base.py | 5 ++- src/datamaestro_text/datasets/irds/data.py | 40 ++++++++++++------- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/requirements.txt b/requirements.txt index f6587b1..38eedf9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -datamaestro>=1.0.1 +datamaestro>=1.0.2 ir_datasets attrs diff --git a/src/datamaestro_text/data/conversation/base.py b/src/datamaestro_text/data/conversation/base.py index 99a9d1b..22e926f 100644 --- a/src/datamaestro_text/data/conversation/base.py +++ b/src/datamaestro_text/data/conversation/base.py @@ -3,7 +3,7 @@ from attr import define from datamaestro.data import Base from datamaestro.record import Record, Item - +from datamaestro_text.data.ir import TopicRecord from datamaestro_text.utils.iter import FactoryIterable, LazyList, RangeView # ---- Basic types @@ -32,6 +32,7 @@ def get_decontextualized_query(self, mode=None) -> str: return self.decontextualized_query +@define class DecontextualizedDictItem(DecontextualizedItem): """A conversation entry providing decontextualized version of the user query""" @@ -49,7 +50,7 @@ class ConversationRecord(Record): pass -class TopicConversationRecord(ConversationRecord): +class TopicConversationRecord(ConversationRecord, TopicRecord): """A conversation record""" pass diff --git a/src/datamaestro_text/datasets/irds/data.py b/src/datamaestro_text/datasets/irds/data.py index 999062e..8297d2d 100644 --- a/src/datamaestro_text/datasets/irds/data.py +++ b/src/datamaestro_text/datasets/irds/data.py @@ -2,9 +2,7 @@ from functools import partial import logging from pathlib import Path -from typing import Any, Iterator, NamedTuple, Tuple, Type, List -from attr import define -import attrs +from typing import Iterator, Tuple, Type, List import ir_datasets from ir_datasets.indices import PickleLz4FullStore from ir_datasets.formats import ( @@ -18,6 +16,7 @@ from experimaestro import Config, Param from experimaestro.compat import cached_property from experimaestro import Option +from datamaestro.record import recordtypes import datamaestro_text.data.ir as ir from datamaestro_text.data.ir.base import ( Record, @@ -411,7 +410,9 @@ def iter(self) -> Iterator[TopicRecord]: from datamaestro_text.data.conversation.base import ( ConversationTreeNode, DecontextualizedDictItem, - ConversationHistory, + RetrievedEntry, + AnswerConversationRecord, + ConversationHistoryItem, ) class CastTopicsHandler(TopicsHandler): @@ -439,6 +440,16 @@ def iter(self) -> Iterator[ir.TopicRecord]: """Returns an iterator over topics""" return iter(self.records) + @recordtypes( + IDItem, SimpleTextItem, DecontextualizedDictItem, ConversationHistoryItem + ) + class Cast2020TopicRecord(TopicRecord): + ... + + @recordtypes(RetrievedEntry) + class Cast2020ResponseRecord(AnswerConversationRecord): + ... + class Cast2020TopicsHandler(CastTopicsHandler): @cached_property def records(self): @@ -448,17 +459,11 @@ def records(self): conversation = [] records = [] - class Cast2020TopicRecord(TopicRecord): - class_types = [ - IDItem, - SimpleTextItem, - DecontextualizedDictItem, - ConversationHistory, - ] - for ( query - ) in self.dataset.queries_iter(): # type: _irds.trec_cast.Cast2020Query + ) in ( + self.dataset.dataset.queries_iter() + ): # type: _irds.trec_cast.Cast2020Query decontextualized = DecontextualizedDictItem( "manual", { @@ -470,7 +475,9 @@ class Cast2020TopicRecord(TopicRecord): IDItem(query.query_id), SimpleTextItem(query.raw_utterance), decontextualized, - ConversationHistory(node.conversation(False)), + ConversationHistoryItem( + node.conversation(False) if node else [] + ), ) if topic_number == query.topic_number: @@ -485,12 +492,15 @@ class Cast2020TopicRecord(TopicRecord): conversation.append(node) node = node.add( ConversationTreeNode( - DocumentRecord.from_id(query.manual_canonical_result_id) + Cast2020ResponseRecord( + RetrievedEntry(query.manual_canonical_result_id) + ) ) ) conversation.append(node) except Exception: logging.exception("Error while computing topic records") + raise return records