-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate_reader.py
43 lines (31 loc) · 1.57 KB
/
evaluate_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import logging
from argparse import ArgumentParser, Namespace

from pytorch_lightning.trainer import Trainer

from soseki.reader.modeling import ReaderLightningModule
# Record layout for every log line: timestamp, logger name, level, message.
_LOG_FORMAT = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s"

# Configure the root logger at INFO with the layout above.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger()

# Suppress anything below ERROR emitted through the "lightning" logger.
# NOTE(review): newer pytorch_lightning releases may log under
# "pytorch_lightning" instead; confirm this name matches the installed version.
_lightning_logger = logging.getLogger("lightning")
_lightning_logger.setLevel(logging.ERROR)
def main(args: Namespace):
    """Evaluate a trained reader checkpoint on the given test file(s).

    Hyperparameters stored in the checkpoint are overridden with the test-time
    settings taken from ``args``. The model is loaded via
    ``ReaderLightningModule.load_from_checkpoint``, and ``Trainer.test`` runs
    with the Lightning logger disabled.

    Args:
        args: Parsed command-line arguments (the ``Namespace`` returned by
            ``parse_args``, not the parser itself). It must provide
            ``reader_file``, ``test_file``, ``test_gold_passages_file``,
            ``test_num_passages``, ``test_max_load_passages`` and
            ``test_batch_size``, plus the options added by
            ``Trainer.add_argparse_args``.
    """
    # These two are forwarded unconditionally. An omitted
    # --test_gold_passages_file is therefore passed to the model as None.
    overriding_hparams = {
        "test_file": args.test_file,
        "test_gold_passages_file": args.test_gold_passages_file,
    }

    # Maps each optional CLI option to the checkpoint hparam it overrides.
    # Options the user did not supply (None) keep the checkpoint's value.
    optional_overrides = {
        "test_num_passages": "eval_num_passages",
        "test_max_load_passages": "eval_max_load_passages",
        "test_batch_size": "eval_batch_size",
    }
    for arg_name, hparam_name in optional_overrides.items():
        value = getattr(args, arg_name)
        if value is not None:
            overriding_hparams[hparam_name] = value

    model = ReaderLightningModule.load_from_checkpoint(args.reader_file, **overriding_hparams)

    # logger=False: evaluation only, so no experiment logger is attached.
    trainer = Trainer.from_argparse_args(args, logger=False)
    trainer.test(model)
if __name__ == "__main__":
    # Command-line entry point: parse reader/test options plus the Trainer's
    # own options, then run the evaluation.
    parser = ArgumentParser(description="Evaluate a trained reader checkpoint on test data.")
    parser.add_argument("--reader_file", type=str, required=True,
                        help="Path to the reader checkpoint to evaluate.")
    parser.add_argument("--test_file", type=str, nargs="+", required=True,
                        help="One or more test data files.")
    parser.add_argument("--test_gold_passages_file", type=str,
                        help="Gold passages file for the test data (optional).")
    parser.add_argument("--test_num_passages", type=int,
                        help="Override the checkpoint's eval_num_passages hyperparameter.")
    parser.add_argument("--test_max_load_passages", type=int,
                        help="Override the checkpoint's eval_max_load_passages hyperparameter.")
    parser.add_argument("--test_batch_size", type=int,
                        help="Override the checkpoint's eval_batch_size hyperparameter.")
    # Adds the Trainer options (devices, precision, ...) consumed by
    # Trainer.from_argparse_args inside main().
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    main(args)