From ac7afff188af522381dff0dc784e5cc3dace494d Mon Sep 17 00:00:00 2001 From: Benjamin Piwowarski Date: Wed, 6 Mar 2024 15:27:11 +0100 Subject: [PATCH] Update for new record interface --- src/xpmir/conversation/learning/__init__.py | 7 +---- src/xpmir/conversation/models/cosplade.py | 1 + src/xpmir/documents/samplers.py | 16 ++++------ src/xpmir/index/faiss.py | 2 +- src/xpmir/index/sparse.py | 3 +- src/xpmir/letor/distillation/samplers.py | 5 ++-- src/xpmir/letor/records.py | 9 ++---- src/xpmir/letor/samplers/__init__.py | 29 +++++++------------ src/xpmir/letor/samplers/hydrators.py | 4 +-- src/xpmir/rankers/__init__.py | 8 ++--- src/xpmir/test/letor/test_samplers.py | 27 ++++++----------- .../test/letor/test_samplers_hydrator.py | 22 +++++++------- src/xpmir/test/neural/test_forward.py | 16 +++++----- src/xpmir/test/rankers/test_full.py | 8 ++--- src/xpmir/test/utils/utils.py | 27 ++++++++--------- src/xpmir/text/huggingface/encoders.py | 1 + 16 files changed, 78 insertions(+), 107 deletions(-) diff --git a/src/xpmir/conversation/learning/__init__.py b/src/xpmir/conversation/learning/__init__.py index 7bcc87f..9b41f74 100644 --- a/src/xpmir/conversation/learning/__init__.py +++ b/src/xpmir/conversation/learning/__init__.py @@ -1,5 +1,4 @@ from functools import cached_property -from datamaestro.record import RecordTypesCache import numpy as np from datamaestro_text.data.ir import TopicRecord from datamaestro_text.data.conversation import ( @@ -26,8 +25,6 @@ def conversations(self): def __post_init__(self): super().__post_init__() - self._recordtypes = RecordTypesCache("Conversation", ConversationHistoryItem) - def __iter__(self) -> RandomSerializableIterator[TopicConversationRecord]: def generator(random: np.random.RandomState): while True: @@ -44,9 +41,7 @@ def generator(random: np.random.RandomState): node_ix = random.randint(len(nodes)) node = nodes[node_ix] - node = self._recordtypes.update( - node.entry(), ConversationHistoryItem(node.history()) - ) + node = node.entry().update(ConversationHistoryItem(node.history())) yield node diff --git a/src/xpmir/conversation/models/cosplade.py b/src/xpmir/conversation/models/cosplade.py index c0faba7..fa39ab6 100644 --- a/src/xpmir/conversation/models/cosplade.py +++ b/src/xpmir/conversation/models/cosplade.py @@ -80,6 +80,7 @@ def __initialize__(self, options): self.queries_encoder.initialize(options) self.history_encoder.initialize(options) + @property def dimension(self): return self.queries_encoder.dimension diff --git a/src/xpmir/documents/samplers.py b/src/xpmir/documents/samplers.py index 3510b09..877c8c8 100644 --- a/src/xpmir/documents/samplers.py +++ b/src/xpmir/documents/samplers.py @@ -3,11 +3,7 @@ from experimaestro import Param, Config import torch import numpy as np -from datamaestro_text.data.ir import DocumentStore, TextItem -from datamaestro_text.data.ir.base import ( - SimpleTextTopicRecord, - SimpleTextDocumentRecord, -) +from datamaestro_text.data.ir import DocumentStore, TextItem, create_record from xpmir.letor import Random from xpmir.letor.records import DocumentRecord, PairwiseRecord, ProductRecords from xpmir.letor.samplers import BatchwiseSampler, PairwiseSampler @@ -150,9 +146,9 @@ def iter(random: np.random.RandomState): continue yield PairwiseRecord( - SimpleTextTopicRecord.from_text(spans_pos_qry[0]), - SimpleTextDocumentRecord.from_text(spans_pos_qry[1]), - SimpleTextDocumentRecord.from_text(spans_neg[random.randint(0, 2)]), + create_record(text=spans_pos_qry[0]), + create_record(text=spans_pos_qry[1]), + create_record(text=spans_neg[random.randint(0, 2)]), ) return RandomSerializableIterator(self.random, iter) @@ -174,8 +170,8 @@ def iterator(random: np.random.RandomState): res = self.get_text_span(text, random) if not res: continue - batch.add_topics(SimpleTextTopicRecord.from_text(res[0])) - batch.add_documents(SimpleTextDocumentRecord.from_text(res[1])) + batch.add_topics(create_record(text=res[0])) + batch.add_documents(create_record(text=res[1])) batch.set_relevances(relevances) yield batch diff --git a/src/xpmir/index/faiss.py b/src/xpmir/index/faiss.py index 95a27d2..0b62306 100644 --- a/src/xpmir/index/faiss.py +++ b/src/xpmir/index/faiss.py @@ -122,7 +122,7 @@ def train( index.train(sample) def execute(self): - self.device.execute(self._execute, None) + self.device.execute(self._execute) def _execute(self, device_information: DeviceInformation): # Initialization hooks diff --git a/src/xpmir/index/sparse.py b/src/xpmir/index/sparse.py index 858542e..c1e9d54 100644 --- a/src/xpmir/index/sparse.py +++ b/src/xpmir/index/sparse.py @@ -216,7 +216,8 @@ def task_outputs(self, dep): ) def execute(self): - mp.set_start_method("spawn") + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method("spawn") max_docs = ( self.documents.documentcount diff --git a/src/xpmir/letor/distillation/samplers.py b/src/xpmir/letor/distillation/samplers.py index f3b8964..b7eee7b 100644 --- a/src/xpmir/letor/distillation/samplers.py +++ b/src/xpmir/letor/distillation/samplers.py @@ -15,6 +15,7 @@ ScoredItem, SimpleTextItem, IDItem, + create_record, ) from experimaestro import Config, Meta, Param from xpmir.learning import Sampler @@ -85,9 +86,9 @@ def iterate(): with self.path.open("rt") as fp: for row in csv.reader(fp, delimiter="\t"): if self.with_queryid: - query = TopicRecord.from_id(row[2]) + query = create_record(id=row[2]) else: - query = TopicRecord.from_text(row[2]) + query = create_record(text=row[2]) if self.with_docid: documents = ( diff --git a/src/xpmir/letor/records.py b/src/xpmir/letor/records.py index 116f41a..793fb0c 100644 --- a/src/xpmir/letor/records.py +++ b/src/xpmir/letor/records.py @@ -4,8 +4,7 @@ TopicRecord, DocumentRecord, TextItem, - SimpleTextTopicRecord, - SimpleTextDocumentRecord, + create_record, ) from typing import ( Iterable, @@ -145,10 +144,8 @@ def from_texts( relevances: Optional[List[float]] = None, ): records = PointwiseRecords() - records.topics = list(map(lambda t: SimpleTextTopicRecord.from_text(t), topics)) - records.documents = list( - map(lambda t: SimpleTextDocumentRecord.from_text(t), documents) - ) + records.topics = list(map(lambda t: create_record(text=t), topics)) + records.documents = list(map(lambda t: create_record(text=t), documents)) records.relevances = relevances return records diff --git a/src/xpmir/letor/samplers/__init__.py b/src/xpmir/letor/samplers/__init__.py index 29262f0..82152e6 100644 --- a/src/xpmir/letor/samplers/__init__.py +++ b/src/xpmir/letor/samplers/__init__.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Iterator, List, Tuple, Dict, Any import numpy as np -from datamaestro.record import recordtypes +from datamaestro.record import Record from datamaestro_text.data.ir import ( Adhoc, TrainingTriplets, @@ -12,10 +12,8 @@ DocumentStore, TextItem, SimpleTextItem, - IDDocumentRecord, - SimpleTextTopicRecord, + create_record, DocumentRecord, - IDTopicRecord, IDItem, ) from experimaestro import Param, tqdm, Task, Annotated, pathgenerator @@ -344,7 +342,7 @@ def iter(random): random.randint(0, len(self.topics)) ] yield PairwiseRecord( - SimpleTextTopicRecord.from_text(title), + create_record(text=title), self.sample(positives), self.sample(negatives), ) @@ -453,7 +451,7 @@ def __next__(self): ) if neg_id != pos.id: break - neg = IDDocumentRecord.from_id(neg_id) + neg = create_record(id=neg_id) else: negatives = sample.negatives[self.negative_algo] neg = negatives[self.random.randint(len(negatives))] @@ -469,7 +467,7 @@ def __next__(self): # --- Dataloader -# FIXME: need to fix the change where there is a list of queries and type of return + class TSVPairwiseSampleDataset(PairwiseSampleDataset): """Read the pairwise sample dataset from a tsv file""" @@ -516,23 +514,18 @@ def iter(self) -> Iterator[PairwiseSample]: positives = [] negatives = {} for topic_text in sample["queries"]: - topics.append(SimpleTextTopicRecord.from_text(topic_text)) + topics.append(create_record(text=topic_text)) for pos_id in sample["pos_ids"]: - positives.append(IDDocumentRecord.from_id(pos_id)) + positives.append(create_record(id=pos_id)) for algo in sample["neg_ids"].keys(): negatives[algo] = [] for neg_id in sample["neg_ids"][algo]: - negatives[algo].append(IDDocumentRecord.from_id(neg_id)) + negatives[algo].append(create_record(id=neg_id)) yield PairwiseSample( topics=topics, positives=positives, negatives=negatives ) -@recordtypes(ScoredItem) -class ScoredIDDocumentRecord(IDDocumentRecord): - pass - - # A class for loading the data, need to move the other places. class PairwiseSamplerFromTSV(PairwiseSampler): @@ -544,9 +537,9 @@ def iter() -> Iterator[PairwiseSample]: for triplet in read_tsv(self.pairwise_samples_path): q_id, pos_id, pos_score, neg_id, neg_score = triplet yield PairwiseRecord( - IDTopicRecord.from_id(q_id), - ScoredIDDocumentRecord(IDItem(pos_id), ScoredItem(pos_score)), - ScoredIDDocumentRecord(IDItem(neg_id), ScoredItem(neg_score)), + Record(IDItem(q_id)), + Record(IDItem(pos_id), ScoredItem(pos_score)), + Record(IDItem(neg_id), ScoredItem(neg_score)), ) return SkippingIterator(iter) diff --git a/src/xpmir/letor/samplers/hydrators.py b/src/xpmir/letor/samplers/hydrators.py index 194265c..111a0e1 100644 --- a/src/xpmir/letor/samplers/hydrators.py +++ b/src/xpmir/letor/samplers/hydrators.py @@ -44,8 +44,8 @@ def transform_topics(self, topics: List[ir.TopicRecord]): if self.querystore is None: return None return [ - ir.GenericTopicRecord.create( - topic[IDItem].id, self.querystore[topic[IDItem].id] + ir.create_record( + id=topic[IDItem].id, text=self.querystore[topic[IDItem].id] ) for topic in topics ] diff --git a/src/xpmir/rankers/__init__.py b/src/xpmir/rankers/__init__.py index 1c669f1..37dd11d 100644 --- a/src/xpmir/rankers/__init__.py +++ b/src/xpmir/rankers/__init__.py @@ -22,7 +22,7 @@ from datamaestro_text.data.ir import ( Documents, DocumentStore, - SimpleTextTopicRecord, + create_record, IDItem, ) from datamaestro_text.data.ir.base import DocumentRecord @@ -98,16 +98,16 @@ def rsv( ) -> List[ScoredDocument]: # Convert into document records if isinstance(documents, str): - documents = [ScoredDocument(DocumentRecord.from_text(documents), None)] + documents = [ScoredDocument(create_record(text=documents), None)] elif isinstance(documents[0], str): documents = [ - ScoredDocument(DocumentRecord.from_text(scored_document), None) + ScoredDocument(create_record(text=scored_document), None) for scored_document in documents ] # Convert into topic record if isinstance(topic, str): - topic = SimpleTextTopicRecord.from_text(topic) + topic = create_record(text=topic) return self.compute(topic, documents) diff --git a/src/xpmir/test/letor/test_samplers.py b/src/xpmir/test/letor/test_samplers.py index 7ae8855..a5e05f6 100644 --- a/src/xpmir/test/letor/test_samplers.py +++ b/src/xpmir/test/letor/test_samplers.py @@ -1,3 +1,4 @@ +from datamaestro.record import record_type import pytest import numpy as np from typing import Iterator, Tuple @@ -17,24 +18,16 @@ class MyTrainingTriplets(TrainingTriplets): def iter( self, - ) -> Iterator[ - Tuple[ - ir.SimpleTextTopicRecord, ir.GenericDocumentRecord, ir.GenericDocumentRecord - ] - ]: + ) -> Iterator[Tuple[ir.TopicRecord, ir.DocumentRecord, ir.DocumentRecord]]: count = 0 while True: - yield ir.SimpleTextTopicRecord.from_text( - f"q{count}" - ), ir.GenericDocumentRecord.create( - 1, f"doc+{count}" - ), ir.GenericDocumentRecord.create( - 2, f"doc-{count}" - ) + yield ir.create_record(text=f"q{count}"), ir.create_record( + id=1, text=f"doc+{count}" + ), ir.create_record(id=2, text=f"doc-{count}") - topic_recordtype = ir.SimpleTextTopicRecord - document_recordtype = ir.GenericDocumentRecord + topic_recordtype = record_type(ir.IDItem, ir.SimpleTextItem) + document_recordtype = record_type(ir.SimpleTextItem) def test_serializing_tripletbasedsampler(): @@ -108,10 +101,8 @@ class FakeDocumentStore(ir.DocumentStore): def documentcount(self): return 10 - def document_int(self, internal_docid: int) -> ir.GenericDocumentRecord: - return ir.GenericDocumentRecord.create( - str(internal_docid), f"D{internal_docid} " * 10 - ) + def document_int(self, internal_docid: int) -> ir.DocumentRecord: + return ir.create_record(id=str(internal_docid), text=f"D{internal_docid} " * 10) def test_pairwise_randomspansampler(): diff --git a/src/xpmir/test/letor/test_samplers_hydrator.py b/src/xpmir/test/letor/test_samplers_hydrator.py index b1a6adb..bc7091e 100644 --- a/src/xpmir/test/letor/test_samplers_hydrator.py +++ b/src/xpmir/test/letor/test_samplers_hydrator.py @@ -1,6 +1,8 @@ +from functools import cached_property import itertools from experimaestro import Param from typing import Iterator, Tuple +from datamaestro.record import record_type import datamaestro_text.data.ir as ir from xpmir.letor.samplers import ( TrainingTriplets, @@ -17,22 +19,22 @@ class TripletIterator(TrainingTriplets): def iter( self, - ) -> Iterator[Tuple[ir.IDTopicRecord, ir.IDDocumentRecord, ir.IDDocumentRecord]]: + ) -> Iterator[Tuple[ir.TopicRecord, ir.DocumentRecord, ir.DocumentRecord]]: count = 0 while True: - yield ir.IDTopicRecord.from_id(str(count)), ir.IDDocumentRecord.from_id( - str(2 * count) - ), ir.IDDocumentRecord.from_id(str(2 * count + 1)) + yield ir.create_record(id=str(count)), ir.create_record( + id=str(2 * count) + ), ir.create_record(id=str(2 * count + 1)) count += 1 - @property + @cached_property def topic_recordtype(self): - return ir.IDTopicRecord + return record_type(ir.IDItem) - @property + @cached_property def document_recordtype(self): - return ir.IDDocumentRecord + return record_type(ir.IDItem) class FakeTextStore(TextStore): @@ -43,8 +45,8 @@ def __getitem__(self, key: str) -> str: class FakeDocumentStore(ir.DocumentStore): id: Param[str] = "" - def document_ext(self, docid: str) -> ir.GenericDocumentRecord: - return ir.GenericDocumentRecord.create(docid, f"D{docid}") + def document_ext(self, docid: str) -> ir.DocumentRecord: + return ir.create_record(id=docid, text=f"D{docid}") def test_pairwise_hydrator(): diff --git a/src/xpmir/test/neural/test_forward.py b/src/xpmir/test/neural/test_forward.py index f7bc083..e6ae29a 100644 --- a/src/xpmir/test/neural/test_forward.py +++ b/src/xpmir/test/neural/test_forward.py @@ -5,11 +5,10 @@ import torch from collections import defaultdict from experimaestro import Constant -from datamaestro_text.data.ir import TextItem +from datamaestro_text.data.ir import TextItem, create_record from xpmir.index import Index from xpmir.learning import Random, ModuleInitMode from xpmir.neural.dual import CosineDense, DotDense -from datamaestro_text.data.ir import GenericDocumentRecord, SimpleTextTopicRecord from xpmir.letor.records import ( PairwiseRecord, PairwiseRecords, @@ -53,6 +52,7 @@ def __initialize__(self, options): ) self.tokenizer = TestTokenizer().instance() + @property def dimension(self) -> int: return RandomTokensEncoder.DIMENSION @@ -168,14 +168,14 @@ def cross_scorer(): # --- QUERIES = [ - SimpleTextTopicRecord.from_text("purple cat"), - SimpleTextTopicRecord.from_text("yellow house"), + create_record(text="purple cat"), + create_record(text="yellow house"), ] DOCUMENTS = [ - GenericDocumentRecord.create("1", "the cat sat on the mat"), - GenericDocumentRecord.create("2", "the purple car"), - GenericDocumentRecord.create("3", "my little dog"), - GenericDocumentRecord.create("4", "the truck was on track"), + create_record(id="1", text="the cat sat on the mat"), + create_record(id="2", text="the purple car"), + create_record(id="3", text="my little dog"), + create_record(id="4", text="the truck was on track"), ] diff --git a/src/xpmir/test/rankers/test_full.py b/src/xpmir/test/rankers/test_full.py index 88234f0..ab2eac2 100644 --- a/src/xpmir/test/rankers/test_full.py +++ b/src/xpmir/test/rankers/test_full.py @@ -6,10 +6,9 @@ import torch from experimaestro.notifications import TaskEnv -from datamaestro_text.data.ir import SimpleTextTopicRecord, TextItem, IDItem +from datamaestro_text.data.ir import TextItem, IDItem, create_record, TopicRecord from xpmir.learning.context import TrainerContext -from xpmir.letor.records import TopicRecord from xpmir.neural.dual import DualRepresentationScorer from xpmir.rankers import ScoredDocument from xpmir.rankers.full import FullRetrieverRescorer @@ -89,10 +88,7 @@ def test_fullretrieverescorer(tmp_path: Path): # Retrieve normally scoredDocuments = {} - queries = { - qid: SimpleTextTopicRecord.from_text(f"Query {qid}") - for qid in range(NUM_QUERIES) - } + queries = {qid: create_record(text=f"Query {qid}") for qid in range(NUM_QUERIES)} # Retrieve query per query for qid, query in queries.items(): diff --git a/src/xpmir/test/utils/utils.py b/src/xpmir/test/utils/utils.py index 85c7af9..f49afc3 100644 --- a/src/xpmir/test/utils/utils.py +++ b/src/xpmir/test/utils/utils.py @@ -1,23 +1,20 @@ +from functools import cached_property from collections import OrderedDict, defaultdict from typing import ClassVar, Dict, Iterator, List, Tuple, Any import torch import numpy as np -from datamaestro.record import recordtypes +from datamaestro.record import Record, record_type from datamaestro_text.data.ir import ( + create_record, DocumentStore, - GenericDocumentRecord, InternalIDItem, + SimpleTextItem, ) from experimaestro import Param from xpmir.text.encoders import TextEncoder, RepresentationOutput -@recordtypes(InternalIDItem) -class GenericDocumentWithIDRecord(GenericDocumentRecord): - ... - - class SampleDocumentStore(DocumentStore): id: Param[str] = "" num_docs: Param[int] = 200 @@ -27,10 +24,10 @@ def __post_init__(self): self.documents = OrderedDict( ( str(ix), - GenericDocumentWithIDRecord.create( - str(ix), - f"Document {ix}", + create_record( InternalIDItem(ix), + id=str(ix), + text=f"Document {ix}", ), ) for ix in range(self.num_docs) @@ -40,19 +37,19 @@ def __post_init__(self): def documentcount(self): return len(self.documents) - def document_int(self, internal_docid: int) -> GenericDocumentWithIDRecord: + def document_int(self, internal_docid: int) -> Record: return self.documents[str(internal_docid)] - def document_ext(self, docid: str) -> GenericDocumentWithIDRecord: + def document_ext(self, docid: str) -> Record: """Returns the text of the document given its id""" return self.documents[docid] - def iter_documents(self) -> Iterator[GenericDocumentWithIDRecord]: + def iter_documents(self) -> Iterator[Record]: return iter(self.documents.values()) - @property + @cached_property def document_recordtype(self): - return GenericDocumentWithIDRecord + return record_type(InternalIDItem, SimpleTextItem) def docid_internal2external(self, docid: int): """Converts an internal collection ID (integer) to an external ID""" diff --git a/src/xpmir/text/huggingface/encoders.py b/src/xpmir/text/huggingface/encoders.py index 3f9946a..da23403 100644 --- a/src/xpmir/text/huggingface/encoders.py +++ b/src/xpmir/text/huggingface/encoders.py @@ -37,6 +37,7 @@ def static(self): """Embeddings from transformers are learnable""" return False + @property def dimension(self): return self.model.hf_config.hidden_size