Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Martin Jose committed Aug 22, 2021
1 parent 8e07c90 commit f5e5fc4
Show file tree
Hide file tree
Showing 18 changed files with 58 additions and 60 deletions.
2 changes: 0 additions & 2 deletions capreolus/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

logger = get_logger(__name__)

MRR_10 = "MRR@10"
DEFAULT_METRICS = [
"P_1",
"P_5",
Expand All @@ -27,7 +26,6 @@
"recall_100",
"recall_1000",
"recip_rank",
MRR_10,
]


Expand Down
4 changes: 2 additions & 2 deletions capreolus/extractor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Extractor(ModuleBase):
"""Base class for Extractor modules. The purpose of an Extractor is to convert queries and documents to a representation suitable for use with a :class:`~capreolus.reranker.Reranker` module.
Modules should provide:
- an ``id2vec(qid, posid, negid=None)`` method that converts the given query and document ids to an appropriate representation
- an ``id2vec_for_triplets(qid, posid, negid=None)`` method that converts the given query and document ids to an appropriate representation
"""

module_type = "extractor"
Expand Down Expand Up @@ -68,7 +68,7 @@ def _build_vocab(self, qids, docids, topics):
def build_from_benchmark(self, *args, **kwargs):
raise NotImplementedError

def id2vec(self, qid, posdocid, negdocid=None, label=None):
def id2vec_for_triplets(self, qid, posdocid, negdocid=None, label=None):
"""
Creates a feature from the (qid, docid) pair.
If negdocid is supplied, that's also included in the feature (needed for training with pairwise hinge loss)
Expand Down
2 changes: 1 addition & 1 deletion capreolus/extractor/alternate_pooled_bertpassage.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def convert_to_bert_input(self, text_toks):

return inp, mask, seg, pos

def id2vec(self, qid, posid, negid=None, label=None):
def id2vec_for_triplets(self, qid, posid, negid=None, label=None):
"""
See parent class for docstring
"""
Expand Down
2 changes: 1 addition & 1 deletion capreolus/extractor/bagofwords.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def preprocess(self, qids, docids, topics):

self._build_vocab(qids, docids, topics)

def id2vec(self, q_id, posdoc_id, negdoc_id=None, **kwargs):
def id2vec_for_triplets(self, q_id, posdoc_id, negdoc_id=None, **kwargs):
query_toks = self.qid2toks[q_id]
posdoc_toks = self.docid2toks.get(posdoc_id)

Expand Down
5 changes: 4 additions & 1 deletion capreolus/extractor/bertpassage.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,10 @@ def _prepare_bert_input(self, query_toks, psg_toks):
seg = [0] * (len(query_toks) + 2) + [1] * (len(padded_input_line) - len(query_toks) - 2)
return inp, mask, seg

def id2vec(self, qid, posid, negid=None, label=None):
def id2vec_for_pairs(self, qid, posid, negid=None, label=None, reldocs=None):
return self.id2vec_for_triplets(qid, posid,negid=negid, label=label)

def id2vec_for_triplets(self, qid, posid, negid=None, label=None):
"""
See parent class for docstring
"""
Expand Down
4 changes: 2 additions & 2 deletions capreolus/extractor/berttext.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_tokenized_doc(self, doc_id):

return self.tokenizer.tokenize(doc)

def id2vec(self, qid, posid, negid=None, label=None):
def id2vec_for_triplets(self, qid, posid, negid=None, label=None):
assert posid is not None
tokenizer = self.tokenizer
max_doc_length = 510
Expand Down Expand Up @@ -115,7 +115,7 @@ def id2vec(self, qid, posid, negid=None, label=None):

return data

def id2vec_for_train(self, qid, posid, negid=None, label=None, reldocs=None):
def id2vec_for_pairs(self, qid, posid, negid=None, label=None, reldocs=None):
assert posid is not None
assert qid is not None
assert reldocs is not None
Expand Down
2 changes: 1 addition & 1 deletion capreolus/extractor/deeptileextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def preprocess(self, qids, docids, topics):
self._build_vocab(qids, docids, topics)
self._build_embedding_matrix()

def id2vec(self, qid, posdocid, negdocid=None, **kwargs):
def id2vec_for_triplets(self, qid, posdocid, negdocid=None, **kwargs):
query_toks = padlist(self.qid2toks[qid], self.config["maxqlen"], pad_token=self.pad_tok)
posdoc_tilebar = self.create_visualization_matrix(query_toks, self.docid2segments[posdocid], self.embeddings)

Expand Down
4 changes: 2 additions & 2 deletions capreolus/extractor/embedtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_tf_feature_description(self):

def create_tf_feature(self, sample):
"""
sample - output from self.id2vec()
sample - output from self.id2vec_for_triplets()
return - a tensorflow feature
"""
query, query_idf, posdoc, negdoc = (sample["query"], sample["query_idf"], sample["posdoc"], sample["negdoc"])
Expand Down Expand Up @@ -125,7 +125,7 @@ def _add_oov_to_vocab(self, tokens):
def _tok2vec(self, toks):
return [self.stoi[tok] for tok in toks]

def id2vec(self, qid, posid, negid=None, **kwargs):
def id2vec_for_triplets(self, qid, posid, negid=None, **kwargs):
query = self.qid2toks[qid]

# TODO find a way to calculate qlen/doclen stats earlier, so we can log them and check sanity of our values
Expand Down
2 changes: 1 addition & 1 deletion capreolus/extractor/pooled_bertpassage.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def parse_label_tensor(x):

return (pos_bert_input, pos_mask, pos_seg, neg_bert_input, neg_mask, neg_seg), label

def id2vec(self, qid, posid, negid=None, label=None):
def id2vec_for_triplets(self, qid, posid, negid=None, label=None):
"""
See parent class for docstring
"""
Expand Down
7 changes: 5 additions & 2 deletions capreolus/extractor/slowembedtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_tf_feature_description(self):

def create_tf_train_feature(self, sample):
"""
sample - output from self.id2vec()
sample - output from self.id2vec_for_triplets()
return - a tensorflow feature
"""
query, query_idf, posdoc, negdoc = (sample["query"], sample["query_idf"], sample["posdoc"], sample["negdoc"])
Expand Down Expand Up @@ -168,7 +168,10 @@ def _tok2vec(self, toks):
# return [self.embeddings[self.stoi[tok]] for tok in toks]
return [self.stoi[tok] for tok in toks]

def id2vec(self, qid, posid, negid=None, label=None):
def id2vec_for_pairs(self, qid, posid, negid=None, label=None, reldocs=None):
return self.id2vec_for_triplets(qid, posid, negid=negid, label=label)

def id2vec_for_triplets(self, qid, posid, negid=None, label=None):
assert label is not None
query = self.qid2toks[qid]

Expand Down
3 changes: 2 additions & 1 deletion capreolus/index/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def exists(self, fold=None):
return False

def create_index(self, fold=None):
raise Exception("This should not have been called")
logger.error("FAISSIndex does not implement create_index()")
pass

def get_results_path(self):
"""Return an absolute path that can be used for storing results.
Expand Down
3 changes: 3 additions & 0 deletions capreolus/index/tests/test_index.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import pytest

from capreolus import module_registry
from capreolus.benchmark import DummyBenchmark
from capreolus.collection import DummyCollection
from capreolus.index import Index
from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache

indexs = set(module_registry.get_module_names("index"))
# Because the FAISS index doesn't implement the create_index API
indexs.remove("faiss")


@pytest.mark.parametrize("index_name", indexs)
Expand Down
24 changes: 12 additions & 12 deletions capreolus/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def generate_samples(self):
try:
# Convention for label - [1, 0] indicates that doc belongs to class 1 (i.e relevant
# ^ This is used with categorical cross entropy loss
yield self.extractor.id2vec(qid, posdocid, negdocid, label=[1, 0])
yield self.extractor.id2vec_for_triplets(qid, posdocid, negdocid, label=[1, 0])
except MissingDocError:
# at training time we warn but ignore on missing docs
logger.warning(
Expand Down Expand Up @@ -199,7 +199,7 @@ def generate_samples(self):
try:
# Convention for label - [1, 0] indicates that doc belongs to class 1 (i.e relevant
# ^ This is used with categorical cross entropy loss
yield self.extractor.id2vec(qid, posdocid, negdocid, label=[1, 0])
yield self.extractor.id2vec_for_triplets(qid, posdocid, negdocid, label=[1, 0])
except MissingDocError:
# at training time we warn but ignore on missing docs
logger.warning("skipping training pair with missing features: qid=%s posid=%s negid=%s", qid, posdocid, negdocid)
Expand Down Expand Up @@ -232,9 +232,9 @@ def generate_samples(self):
qid = self.rng.choice(all_qids)
posdocid = self.rng.choice(self.qid_to_reldocs[qid])
negdocid = self.rng.choice(self.qid_to_negdocs[qid])
yield self.extractor.id2vec_for_train(qid, posdocid, negid=None, label=[0, 1], reldocs=set(self.qid_to_reldocs[qid]))
yield self.extractor.id2vec_for_pairs(qid, posdocid, negid=None, label=[0, 1], reldocs=set(self.qid_to_reldocs[qid]))

yield self.extractor.id2vec_for_train(qid, negdocid, negid=None, label=[1, 0], reldocs=set(self.qid_to_reldocs[qid]))
yield self.extractor.id2vec_for_pairs(qid, negdocid, negid=None, label=[1, 0], reldocs=set(self.qid_to_reldocs[qid]))


@Sampler.register
Expand Down Expand Up @@ -342,8 +342,8 @@ def generate_samples(self):
qid = self.rng.choice(all_qids)
posdocid = self.rng.choice(self.qid_to_reldocs[qid])
negdocid = self.rng.choice(self.qid_to_negdocs[qid])
yield self.extractor.id2vec_for_train(qid, posdocid, negid=None, label=[0, 1], reldocs=set(self.qid_to_reldocs[qid]))
yield self.extractor.id2vec_for_train(qid, negdocid, negid=None, label=[1, 0], reldocs=set(self.qid_to_reldocs[qid]))
yield self.extractor.id2vec_for_pairs(qid, posdocid, negid=None, label=[0, 1], reldocs=set(self.qid_to_reldocs[qid]))
yield self.extractor.id2vec_for_pairs(qid, negdocid, negid=None, label=[1, 0], reldocs=set(self.qid_to_reldocs[qid]))


class PredSampler(Sampler, torch.utils.data.IterableDataset):
Expand All @@ -366,9 +366,9 @@ def generate_samples(self):
for docid in docids:
try:
if docid in self.qid_to_reldocs[qid]:
yield self.extractor.id2vec(qid, docid, label=[0, 1])
yield self.extractor.id2vec_for_triplets(qid, docid, label=[0, 1])
else:
yield self.extractor.id2vec(qid, docid, label=[1, 0])
yield self.extractor.id2vec_for_triplets(qid, docid, label=[1, 0])
except MissingDocError:
# when predictiong we raise an exception on missing docs, as this may invalidate results
logger.error("got none features for prediction: qid=%s posid=%s", qid, docid)
Expand Down Expand Up @@ -429,7 +429,7 @@ def get_hash(self):
def generate_samples(self):
for docid in self.docids:
try:
yield self.extractor.id2vec(None, docid)
yield self.extractor.id2vec_for_triplets(None, docid)
except MissingDocError:
logger.info("Doc {} was missing".format(docid))

Expand Down Expand Up @@ -475,11 +475,11 @@ def generate_samples(self):
qid = self.rng.choice(all_qids)
posdocid = self.rng.choice(self.qid_to_reldocs[qid])
negdocid = self.rng.choice(self.qid_to_negdocs[qid])
data = self.extractor.id2vec_for_train(qid, posdocid, negid=None, label=1, reldocs=set(self.qid_to_reldocs[qid]))
data = self.extractor.id2vec_for_pairs(qid, posdocid, negid=None, label=1, reldocs=set(self.qid_to_reldocs[qid]))
data["residual"] = epsilon - lambda_train * (trec_run[qid][posdocid] - trec_run[qid][negdocid])
yield data

data = self.extractor.id2vec_for_train(qid, negdocid, negid=None, label=0, reldocs=set(self.qid_to_reldocs[qid]))
data = self.extractor.id2vec_for_pairs(qid, negdocid, negid=None, label=0, reldocs=set(self.qid_to_reldocs[qid]))
data["residual"] = epsilon - lambda_train * (trec_run[qid][posdocid] - trec_run[qid][negdocid])
yield data

Expand Down Expand Up @@ -540,7 +540,7 @@ def generate_samples(self):
try:
# Convention for label - [1, 0] indicates that doc belongs to class 1 (i.e relevant
# ^ This is used with categorical cross entropy loss
data = self.extractor.id2vec(qid, posdocid, negdocid, label=[1, 0])
data = self.extractor.id2vec_for_triplets(qid, posdocid, negdocid, label=[1, 0])

# This is equation 4 in the CLEAR paper
data["residual"] = epsilon - lambda_train * (self.trec_run[qid][posdocid] - self.trec_run[qid][negdocid])
Expand Down
8 changes: 4 additions & 4 deletions capreolus/sampler/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def test_train_sampler(monkeypatch, tmpdir):
train_dataset = TrainTripletSampler()
train_dataset.prepare(training_judgments, training_judgments, extractor)

def mock_id2vec(*args, **kwargs):
def mock_id2vec_for_triplets(*args, **kwargs):
return {"query": np.array([1, 2, 3, 4]), "posdoc": np.array([1, 1, 1, 1]), "negdoc": np.array([2, 2, 2, 2])}

monkeypatch.setattr(EmbedText, "id2vec", mock_id2vec)
monkeypatch.setattr(EmbedText, "id2vec_for_triplets", mock_id2vec_for_triplets)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
for idx, batch in enumerate(dataloader):
assert len(batch["query"]) == 32
Expand All @@ -47,10 +47,10 @@ def test_pred_sampler(monkeypatch, tmpdir):
pred_dataset = PredSampler()
pred_dataset.prepare(benchmark.qrels, search_run, extractor)

def mock_id2vec(*args, **kwargs):
def mock_id2vec_for_triplets(*args, **kwargs):
return {"query": np.array([1, 2, 3, 4]), "posdoc": np.array([1, 1, 1, 1])}

monkeypatch.setattr(EmbedText, "id2vec", mock_id2vec)
monkeypatch.setattr(EmbedText, "id2vec_for_triplets", mock_id2vec_for_triplets)
dataloader = torch.utils.data.DataLoader(pred_dataset, batch_size=2)
for idx, batch in enumerate(dataloader):
print(idx, batch)
Expand Down
2 changes: 1 addition & 1 deletion capreolus/searcher/tests/test_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

skip_searchers = {"bm25staticrob04yang19", "BM25Grid", "BM25Postprocess", "axiomatic"}
searchers = set(module_registry.get_module_names("searcher")) - skip_searchers
searchers = [x for x in searchers if "static" not in x]
searchers = [x for x in searchers if "static" not in x and x != "faiss"]


@pytest.mark.parametrize("searcher_name", searchers)
Expand Down
3 changes: 2 additions & 1 deletion capreolus/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from capreolus.tests.common_fixtures import tmpdir_as_cache

collections = set(module_registry.get_module_names("collection"))

# TODO: Allow the below collections too to be tested
collections = collections - {"gov2passages", "robust04passages"}

@pytest.mark.parametrize("collection_name", collections)
def test_collection_creatable(tmpdir_as_cache, collection_name):
Expand Down
Loading

0 comments on commit f5e5fc4

Please sign in to comment.