From 1ef78b2ed327e848305df283ed3f62a8c451061f Mon Sep 17 00:00:00 2001 From: Crystina Date: Tue, 17 Aug 2021 12:44:32 -0400 Subject: [PATCH] Validate benchmark topics, qrels, and folds (#178) * validation for benchmark (topic; qrel; and files) --- capreolus/benchmark/__init__.py | 122 ++++++++++++++++++++++++++- capreolus/benchmark/cds.py | 3 +- capreolus/benchmark/codesearchnet.py | 3 +- capreolus/benchmark/covid.py | 3 +- capreolus/benchmark/nf.py | 3 +- capreolus/searcher/__init__.py | 14 ++- 6 files changed, 138 insertions(+), 10 deletions(-) diff --git a/capreolus/benchmark/__init__.py b/capreolus/benchmark/__init__.py index 0cf5681a7..1a85ebf3f 100644 --- a/capreolus/benchmark/__init__.py +++ b/capreolus/benchmark/__init__.py @@ -1,11 +1,114 @@ -import json import os +import json +from copy import deepcopy +from collections import defaultdict import ir_datasets from capreolus import ModuleBase from capreolus.utils.caching import cached_file, TargetFileExists -from capreolus.utils.trec import load_qrels, load_trec_topics +from capreolus.utils.trec import write_qrels, load_qrels, load_trec_topics +from capreolus.utils.loginit import get_logger + + +logger = get_logger(__name__) + + +def validate(build_f): + def validate_folds_file(self): + if not hasattr(self, "fold_file"): + logger.warning(f"Folds file is not found for Module {self.module_name}") + return + + if self.fold_file.suffix != ".json": + raise ValueError(f"Expect folds file to be in .json format.") + + raw_folds = json.load(open(self.fold_file)) + # we actually don't need to verify the name of folds right? 
+ + for fold_name, fold_sets in raw_folds.items(): + if set(fold_sets) != {"train_qids", "predict"}: + raise ValueError(f"Expect each fold to contain ['train_qids', 'predict'] fields.") + + if set(fold_sets["predict"]) != {"dev", "test"}: + raise ValueError(f"Expect each fold to contain ['dev', 'test'] fields under 'predict'.") + logger.info("Folds file validation finishes.") + + def validate_qrels_file(self): + if not hasattr(self, "qrel_file"): + logger.warning(f"Qrel file is not found for Module {self.module_name}") + return + + n_dup, qrels = 0, defaultdict(dict) + with open(self.qrel_file) as f: + for line in f: + qid, _, docid, label = line.strip().split() + if docid in qrels[qid]: + n_dup += 1 + if int(label) != qrels[qid][docid]: + raise ValueError(f"Found conflicting label in {self.qrel_file} for query {qid} and document {docid}.") + qrels[qid][docid] = int(label) + + if n_dup > 0: + qrel_file_no_ext, ext = os.path.splitext(self.qrel_file) + dup_qrel_file = qrel_file_no_ext + "-contain-dup-entries" + ext + os.rename(self.qrel_file, dup_qrel_file) + write_qrels(qrels, self.qrel_file) + logger.warning( + f"Removed {n_dup} entries from the file {self.qrel_file}. The original version could be found in {dup_qrel_file}." + ) + + logger.info("Qrel file validation finishes.") + + def validate_query_alignment(self): + topic_qids = set(self.topics[self.query_type]) + qrels_qids = set(self.qrels) + + for fold_name, fold_sets in self.folds.items(): + # check if there are overlap between training, dev, and test set + train_qids, dev_qids, test_qids = ( + set(fold_sets["train_qids"]), + set(fold_sets["predict"]["dev"]), + set(fold_sets["predict"]["test"]), + ) + if len(train_qids & dev_qids) > 0: + logger.warning( + f"Found {len(train_qids & dev_qids)} overlap queries between training and dev set in fold {fold_name}." 
+ ) + if len(train_qids & test_qids) > 0: + logger.warning( + f"Found {len(train_qids & test_qids)} overlap queries between training and test set in fold {fold_name}." + ) + if len(dev_qids & test_qids) > 0: + logger.warning( + f"Found {len(dev_qids & test_qids)} overlap queries between dev and test set in fold {fold_name}." + ) + + # check if the topics, qrels, and folds file share a reasonable set (if not all) of queries + folds_qids = train_qids | dev_qids | test_qids + n_overlap = len(set(topic_qids) & set(qrels_qids) & set(folds_qids)) + if not len(topic_qids) == len(qrels_qids) == len(folds_qids) == n_overlap: + logger.warning( + f"Number of queries are not aligned across topics, qrels and folds in fold {fold_name}: {len(topic_qids)} queries in topics file, {len(qrels_qids)} queries in qrels file, {len(folds_qids)} queries in folds file; {n_overlap} overlap queries found among the three." + ) + + # check if any topic in folds cannot be found in topics file + for set_name, set_qids in zip(["training", "dev", "test"], [train_qids, dev_qids, test_qids]): + if len(set_qids - topic_qids) > 0: + raise ValueError( + f"{len(set_qids - topic_qids)} queries in {set_name} set of fold {fold_name} cannot be found in topic file." + ) + + logger.info("Query Alignment validation finishes.") + + def _validate(self): + """Rewrite the files that contain invalid (duplicate) entries, and remove the currently loaded variables""" + build_f(self) + validate_folds_file(self) + validate_qrels_file(self) + validate_query_alignment(self) + + return _validate class Benchmark(ModuleBase): @@ -26,6 +129,9 @@ class Benchmark(ModuleBase): relevance_level = 1 """ Documents with a relevance label >= relevance_level will be considered relevant. This corresponds to trec_eval's --level_for_rel (and is passed to pytrec_eval as relevance_level). """ + use_train_as_dev = True + """ Whether to use training set as validation set when there is no training needed, + e.g. 
for traditional IR algorithms like BM25 """ @property def qrels(self): @@ -45,6 +151,14 @@ def folds(self): self._folds = json.load(open(self.fold_file, "rt"), parse_int=str) return self._folds + @property + def non_nn_dev(self): + dev_per_fold = {fold_name: deepcopy(folds["predict"]["dev"]) for fold_name, folds in self.folds.items()} + if self.use_train_as_dev: + for fold_name, folds in self.folds.items(): + dev_per_fold[fold_name].extend(folds["train_qids"]) + return dev_per_fold + def get_topics_file(self, query_sets=None): """Returns path to a topics file in TSV format containing queries from query_sets. query_sets may contain any combination of 'train', 'dev', and 'test'. @@ -81,6 +195,10 @@ def get_topics_file(self, query_sets=None): return fn + @validate + def build(self): + return + class IRDBenchmark(Benchmark): ird_dataset_names = [] diff --git a/capreolus/benchmark/cds.py b/capreolus/benchmark/cds.py index 9065310e9..108a2040d 100644 --- a/capreolus/benchmark/cds.py +++ b/capreolus/benchmark/cds.py @@ -2,7 +2,7 @@ from capreolus import Dependency, constants -from . import Benchmark, IRDBenchmark +from . import Benchmark, IRDBenchmark, validate PACKAGE_PATH = constants["PACKAGE_PATH"] @@ -16,6 +16,7 @@ class CDS(IRDBenchmark): query_type = "summary" query_types = {} # diagnosis, treatment, or test + @validate def build(self): self.topics diff --git a/capreolus/benchmark/codesearchnet.py b/capreolus/benchmark/codesearchnet.py index 39c5b6391..55ebaae85 100644 --- a/capreolus/benchmark/codesearchnet.py +++ b/capreolus/benchmark/codesearchnet.py @@ -12,7 +12,7 @@ from capreolus.utils.loginit import get_logger from capreolus.utils.trec import topic_to_trectxt -from . import Benchmark +from . 
import Benchmark, validate logger = get_logger(__name__) PACKAGE_PATH = constants["PACKAGE_PATH"] @@ -41,6 +41,7 @@ class CodeSearchNetCorpus(Benchmark): config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")] + @validate def build(self): lang = self.config["lang"] diff --git a/capreolus/benchmark/covid.py b/capreolus/benchmark/covid.py index ce3cf6d37..85f491fdc 100644 --- a/capreolus/benchmark/covid.py +++ b/capreolus/benchmark/covid.py @@ -9,7 +9,7 @@ from capreolus.utils.loginit import get_logger from capreolus.utils.trec import load_qrels, topic_to_trectxt -from . import Benchmark +from . import Benchmark, validate logger = get_logger(__name__) PACKAGE_PATH = constants["PACKAGE_PATH"] @@ -30,6 +30,7 @@ class COVID(Benchmark): config_spec = [ConfigOption("udelqexpand", False), ConfigOption("useprevqrels", True)] + @validate def build(self): if self.collection.config["round"] == self.lastest_round and not self.config["useprevqrels"]: logger.warning(f"No evaluation can be done for the lastest round without using previous qrels") diff --git a/capreolus/benchmark/nf.py b/capreolus/benchmark/nf.py index ec9edc46d..940f7e0c8 100644 --- a/capreolus/benchmark/nf.py +++ b/capreolus/benchmark/nf.py @@ -5,7 +5,7 @@ from capreolus.utils.loginit import get_logger from capreolus.utils.trec import topic_to_trectxt -from . import Benchmark +from . 
import Benchmark, validate logger = get_logger(__name__) PACKAGE_PATH = constants["PACKAGE_PATH"] @@ -34,6 +34,7 @@ class NF(Benchmark): query_type = "title" + @validate def build(self): fields, label_range = self.config["fields"], self.config["labelrange"] self.field2kws = { diff --git a/capreolus/searcher/__init__.py b/capreolus/searcher/__init__.py index bbe5777c9..8310e3392 100644 --- a/capreolus/searcher/__init__.py +++ b/capreolus/searcher/__init__.py @@ -32,17 +32,23 @@ def load_trec_run(fn): run = OrderedDefaultDict() with open(fn, "rt") as f: - for line in f: + for i, line in enumerate(f): line = line.strip() if len(line) > 0: - qid, _, docid, rank, score, desc = line.split(" ") + try: + qid, _, docid, rank, score, desc = line.split() + except ValueError as e: + logger.error( + f"Encountered malformed line when reading {fn} [Line #{i}], possibly because the writing to runfile was interrupted." + ) + raise e run[qid][docid] = float(score) return run @staticmethod - def write_trec_run(preds, outfn): + def write_trec_run(preds, outfn, mode="wt"): count = 0 - with open(outfn, "wt") as outf: + with open(outfn, mode) as outf: qids = sorted(preds.keys(), key=lambda k: int(k)) for qid in qids: rank = 1