From d4376d1fc89d81c9e567cf1de9f15a8a2800bb00 Mon Sep 17 00:00:00 2001 From: smallv0221 <33639025+smallv0221@users.noreply.github.com> Date: Tue, 6 Apr 2021 09:47:05 +0800 Subject: [PATCH] remove old datasets, swich spawn to launch (#224) * remove old datasets, spawn->lanch * fix rank --- docs/data_prepare/dataset_self_defined.rst | 5 +- examples/language_model/rnnlm/README.md | 3 +- examples/language_model/rnnlm/args.py | 8 +- examples/language_model/rnnlm/train.py | 11 +- .../DuReader-robust/README.md | 9 +- .../DuReader-robust/args.py | 8 +- .../DuReader-robust/run_du.py | 15 +- .../DuReader-yesno/README.md | 7 +- .../DuReader-yesno/args.py | 8 +- .../DuReader-yesno/run_du.py | 13 +- .../SQuAD/README.md | 16 +- .../SQuAD/args.py | 8 +- .../SQuAD/run_squad.py | 19 +- paddlenlp/datasets/__init__.py | 17 +- paddlenlp/datasets/chnsenticorp.py | 73 +- .../datasets/{experimental => }/cmrc2018.py | 0 paddlenlp/datasets/couplet.py | 103 ++- paddlenlp/datasets/dataset.py | 556 +++++++---- paddlenlp/datasets/{experimental => }/drcd.py | 0 .../datasets/{experimental => }/duconv.py | 0 paddlenlp/datasets/dureader.py | 244 ----- .../{experimental => }/dureader_robust.py | 0 .../{experimental => }/dureader_yesno.py | 0 paddlenlp/datasets/experimental/__init__.py | 33 - .../datasets/experimental/chnsenticorp.py | 76 -- paddlenlp/datasets/experimental/couplet.py | 105 --- paddlenlp/datasets/experimental/dataset.py | 497 ---------- paddlenlp/datasets/experimental/glue.py | 322 ------- paddlenlp/datasets/experimental/imdb.py | 70 -- paddlenlp/datasets/experimental/lcqmc.py | 75 -- paddlenlp/datasets/experimental/msra_ner.py | 64 -- .../experimental/peoples_daily_ner.py | 64 -- paddlenlp/datasets/experimental/poetry.py | 67 -- paddlenlp/datasets/experimental/ptb.py | 59 -- paddlenlp/datasets/experimental/squad.py | 86 -- paddlenlp/datasets/glue.py | 874 ++++++------------ paddlenlp/datasets/imdb.py | 63 +- .../datasets/{experimental => }/iwslt15.py | 0 paddlenlp/datasets/lcqmc.py | 82 +- paddlenlp/datasets/msra_ner.py | 64 +- paddlenlp/datasets/peoples_daily_ner.py | 69 +- paddlenlp/datasets/poetry.py | 65 +- paddlenlp/datasets/ptb.py | 147 ++- paddlenlp/datasets/squad.py | 702 ++------------ paddlenlp/datasets/translation.py | 412 --------- .../datasets/{experimental => }/wmt14ende.py | 0 .../{experimental => }/yahoo_answer_100k.py | 0 47 files changed, 1090 insertions(+), 4029 deletions(-) rename paddlenlp/datasets/{experimental => }/cmrc2018.py (100%) rename paddlenlp/datasets/{experimental => }/drcd.py (100%) rename paddlenlp/datasets/{experimental => }/duconv.py (100%) delete mode 100644 paddlenlp/datasets/dureader.py rename paddlenlp/datasets/{experimental => }/dureader_robust.py (100%) rename paddlenlp/datasets/{experimental => }/dureader_yesno.py (100%) delete mode 100644 paddlenlp/datasets/experimental/__init__.py delete mode 100644 paddlenlp/datasets/experimental/chnsenticorp.py delete mode 100644 paddlenlp/datasets/experimental/couplet.py delete mode 100644 paddlenlp/datasets/experimental/dataset.py delete mode 100644 paddlenlp/datasets/experimental/glue.py delete mode 100644 paddlenlp/datasets/experimental/imdb.py delete mode 100644 paddlenlp/datasets/experimental/lcqmc.py delete mode 100644 paddlenlp/datasets/experimental/msra_ner.py delete mode 100644 paddlenlp/datasets/experimental/peoples_daily_ner.py delete mode 100644 paddlenlp/datasets/experimental/poetry.py delete mode 100644 paddlenlp/datasets/experimental/ptb.py delete mode 100644 paddlenlp/datasets/experimental/squad.py rename 
paddlenlp/datasets/{experimental => }/iwslt15.py (100%) delete mode 100644 paddlenlp/datasets/translation.py rename paddlenlp/datasets/{experimental => }/wmt14ende.py (100%) rename paddlenlp/datasets/{experimental => }/yahoo_answer_100k.py (100%) diff --git a/docs/data_prepare/dataset_self_defined.rst b/docs/data_prepare/dataset_self_defined.rst index 3781272450c9d..dba2a618b417b 100644 --- a/docs/data_prepare/dataset_self_defined.rst +++ b/docs/data_prepare/dataset_self_defined.rst @@ -2,12 +2,12 @@ 如何自定义数据集 ============ -通过使用PaddleNLP提供的 :class:`MapDataset` 和 :class:`IterDataset` 。任何人都可以方便的定义属于自己的数据集。 +通过使用PaddleNLP提供的 :func:`load_dataset` , :class:`MapDataset` 和 :class:`IterDataset` 。任何人都可以方便的定义属于自己的数据集。 从本地文件创建数据集 ------------------- -从本地文件创建数据集时,我们 **推荐** 根据本地数据集的格式给出读取function并传入 :func:`load_dataset`中创建数据集。 +从本地文件创建数据集时,我们 **推荐** 根据本地数据集的格式给出读取function并传入 :func:`load_dataset` 中创建数据集。 以 :obj:`waybill_ie` 快递单信息抽取任务中的数据为例: @@ -25,6 +25,7 @@ labels = labels.split('\002') yield {'tokens': words, 'labels': labels} + # data_path为read()方法的参数 map_ds = load_dataset(read, data_path='train.txt',lazy=False) iter_ds = load_dataset(read, data_path='train.txt',lazy=True) diff --git a/examples/language_model/rnnlm/README.md b/examples/language_model/rnnlm/README.md index 64a43811c4774..efd6459bd911c 100644 --- a/examples/language_model/rnnlm/README.md +++ b/examples/language_model/rnnlm/README.md @@ -30,7 +30,8 @@ 任务训练启动命令如下: ``` -python train.py +unset CUDA_VISIBLE_DEVICES +python -m paddle.distributed.launch --gpus "0" train.py \ ``` 程序运行时将会自动进行训练,评估,测试。同时训练过程中会自动保存模型到checkpoint、中。 diff --git a/examples/language_model/rnnlm/args.py b/examples/language_model/rnnlm/args.py index 6d401b9a0ff7c..fdd9a6ba355af 100644 --- a/examples/language_model/rnnlm/args.py +++ b/examples/language_model/rnnlm/args.py @@ -29,9 +29,9 @@ def parse_args(): default=None, help="The path of checkpoint to be loaded.") parser.add_argument( - "--n_gpu", - type=int, - default=1, - help="number of gpus to use, 0 for cpu.") + '--device', + choices=['cpu', 'gpu'], + default="gpu", + help="Select which device to train model, defaults to gpu.") args = parser.parse_args() return args diff --git a/examples/language_model/rnnlm/train.py b/examples/language_model/rnnlm/train.py index ddf6340320aa5..221a2665020e2 100644 --- a/examples/language_model/rnnlm/train.py +++ b/examples/language_model/rnnlm/train.py @@ -71,7 +71,7 @@ def group_texts(examples): def train(args): - paddle.set_device("gpu" if args.n_gpu else "cpu") + paddle.set_device(args.device) data_path = args.data_path train_loader, valid_loader, test_loader, vocab_size = create_data_loader( batch_size=args.batch_size, @@ -121,7 +121,8 @@ def train(args): if __name__ == '__main__': args = parse_args() - if args.n_gpu > 1: - paddle.distributed.spawn(train, args=(args, ), nprocs=args.n_gpu) - else: - train(args) + assert args.device in [ + "cpu", "gpu", "xpu" + ], "Invalid device! Available device should be cpu, gpu, or xpu." 
+ + train(args) diff --git a/examples/machine_reading_comprehension/DuReader-robust/README.md b/examples/machine_reading_comprehension/DuReader-robust/README.md index 4aba00153f0d9..943fe3ad8ede4 100644 --- a/examples/machine_reading_comprehension/DuReader-robust/README.md +++ b/examples/machine_reading_comprehension/DuReader-robust/README.md @@ -30,7 +30,8 @@ DuReader-robust数据集是单篇章、抽取式阅读理解数据集,具体 按如下方式启动 Fine-tuning: ```shell -python -u ./run_du.py \ +unset CUDA_VISIBLE_DEVICES +python -m paddle.distributed.launch --gpus "0" run_du.py \ --task_name dureader_robust \ --model_type bert \ --model_name_or_path bert-base-chinese \ @@ -42,10 +43,10 @@ python -u ./run_du.py \ --save_steps 1000 \ --warmup_proportion 0.1 \ --weight_decay 0.01 \ - --output_dir ./tmp/dureader_robust/ \ - --do_predict \ + --output_dir ./tmp/dureader-robust/ \ --do_train \ - --n_gpu 1 \ + --do_predict \ + --device gpu \ ``` * `task_name`: 数据集的名称,不区分大小写,如dureader_robust,cmrc2018, drcd。 diff --git a/examples/machine_reading_comprehension/DuReader-robust/args.py b/examples/machine_reading_comprehension/DuReader-robust/args.py index 6e44dc8ba894c..3cc6fc9eb9af7 100644 --- a/examples/machine_reading_comprehension/DuReader-robust/args.py +++ b/examples/machine_reading_comprehension/DuReader-robust/args.py @@ -98,10 +98,10 @@ def parse_args(): parser.add_argument( "--seed", type=int, default=42, help="random seed for initialization") parser.add_argument( - "--n_gpu", - type=int, - default=1, - help="number of gpus to use, 0 for cpu.") + '--device', + choices=['cpu', 'gpu'], + default="gpu", + help="Select which device to train model, defaults to gpu.") parser.add_argument( "--doc_stride", type=int, diff --git a/examples/machine_reading_comprehension/DuReader-robust/run_du.py b/examples/machine_reading_comprehension/DuReader-robust/run_du.py index 4a316beda4bd4..2d9a91b4377b5 100644 --- a/examples/machine_reading_comprehension/DuReader-robust/run_du.py +++ b/examples/machine_reading_comprehension/DuReader-robust/run_du.py @@ -109,16 +109,17 @@ def forward(self, y, label): def run(args): - paddle.set_device("gpu" if args.n_gpu else "cpu") + paddle.set_device(args.device) if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() + rank = paddle.distributed.get_rank() task_name = args.task_name.lower() args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) set_seed(args) - if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + if rank == 0: if os.path.exists(args.model_name_or_path): print("init checkpoint from %s" % args.model_name_or_path) @@ -259,8 +260,7 @@ def prepare_train_features(examples): optimizer.clear_grad() if global_step % args.save_steps == 0 or global_step == num_training_steps: - if (not args.n_gpu > 1 - ) or paddle.distributed.get_rank() == 0: + if rank == 0: output_dir = os.path.join(args.output_dir, "model_%d" % global_step) if not os.path.exists(output_dir): @@ -307,7 +307,7 @@ def prepare_validation_features(examples): return tokenized_examples - if args.do_predict and paddle.distributed.get_rank() == 0: + if args.do_predict and rank == 0: if args.predict_file: dev_ds = load_dataset(task_name, data_files=args.predict_file) @@ -334,7 +334,4 @@ def prepare_validation_features(examples): if __name__ == "__main__": args = parse_args() - if args.n_gpu > 1: - paddle.distributed.spawn(run, args=(args, ), nprocs=args.n_gpu) - else: - run(args) + run(args) diff 
--git a/examples/machine_reading_comprehension/DuReader-yesno/README.md b/examples/machine_reading_comprehension/DuReader-yesno/README.md index 729021ba45d78..0d8eac77168f4 100644 --- a/examples/machine_reading_comprehension/DuReader-yesno/README.md +++ b/examples/machine_reading_comprehension/DuReader-yesno/README.md @@ -40,7 +40,8 @@ 按如下方式启动 Fine-tuning: ```shell -python -u ./run_du.py \ +unset CUDA_VISIBLE_DEVICES +python -m paddle.distributed.launch --gpus "0" run_du.py \ --model_type bert \ --model_name_or_path bert-base-chinese \ --max_seq_length 384 \ @@ -52,7 +53,7 @@ python -u ./run_du.py \ --warmup_proportion 0.1 \ --weight_decay 0.01 \ --output_dir ./tmp/dureader-yesno/ \ - --n_gpu 1 \ + --device gpu \ ``` * `model_type`: 预训练模型的种类。如bert,ernie,roberta等。 @@ -66,4 +67,4 @@ accu: 0.861040 ``` 评估结束后模型会自动对测试集进行预测,并将可提交的结果生成在`prediction.json`中。 -**NOTE:** 如需恢复模型训练,则model_name_or_path只需指定到文件夹名即可。如`--model_name_or_path=./tmp/dureader-yesno/model_19000/`,程序会自动加载模型参数`/model_state.pdparams`,也会自动加载词表,模型config和tokenizer的config。 \ No newline at end of file +**NOTE:** 如需恢复模型训练,则model_name_or_path只需指定到文件夹名即可。如`--model_name_or_path=./tmp/dureader-yesno/model_19000/`,程序会自动加载模型参数`/model_state.pdparams`,也会自动加载词表,模型config和tokenizer的config。 diff --git a/examples/machine_reading_comprehension/DuReader-yesno/args.py b/examples/machine_reading_comprehension/DuReader-yesno/args.py index 4f5fe758ff18e..f98c307293721 100644 --- a/examples/machine_reading_comprehension/DuReader-yesno/args.py +++ b/examples/machine_reading_comprehension/DuReader-yesno/args.py @@ -80,10 +80,10 @@ def parse_args(): parser.add_argument( "--seed", type=int, default=42, help="random seed for initialization") parser.add_argument( - "--n_gpu", - type=int, - default=1, - help="number of gpus to use, 0 for cpu.") + '--device', + choices=['cpu', 'gpu'], + default="gpu", + help="Select which device to train model, defaults to gpu.") parser.add_argument( "--do_lower_case", action='store_false', diff --git a/examples/machine_reading_comprehension/DuReader-yesno/run_du.py b/examples/machine_reading_comprehension/DuReader-yesno/run_du.py index 503c1ed5b1409..75e64dc3a8526 100644 --- a/examples/machine_reading_comprehension/DuReader-yesno/run_du.py +++ b/examples/machine_reading_comprehension/DuReader-yesno/run_du.py @@ -89,10 +89,10 @@ def predict(model, data_loader): def do_train(args): - paddle.set_device("gpu" if args.n_gpu else "cpu") + paddle.set_device(args.device) if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() - + rank = paddle.distributed.get_rank() args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) @@ -192,7 +192,7 @@ def do_train(args): optimizer.clear_grad() if global_step % args.save_steps == 0 or global_step == num_training_steps: - if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + if rank == 0: evaluate(model, metric, dev_data_loader) output_dir = os.path.join(args.output_dir, "model_%d" % global_step) @@ -207,7 +207,7 @@ def do_train(args): if global_step == num_training_steps: break - if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + if rank == 0: predictions = predict(model, test_data_loader) with open('prediction.json', "w") as writer: writer.write( @@ -217,7 +217,4 @@ def do_train(args): if __name__ == "__main__": args = parse_args() - if args.n_gpu > 1: - paddle.distributed.spawn(do_train, args=(args, ), nprocs=args.n_gpu) - else: - 
do_train(args) + do_train(args) diff --git a/examples/machine_reading_comprehension/SQuAD/README.md b/examples/machine_reading_comprehension/SQuAD/README.md index 9e2fc732db452..1128e7ea8e823 100644 --- a/examples/machine_reading_comprehension/SQuAD/README.md +++ b/examples/machine_reading_comprehension/SQuAD/README.md @@ -29,21 +29,22 @@ SQuAD v2.0 对于 SQuAD v1.1,按如下方式启动 Fine-tuning: ```shell -python -u ./run_squad.py \ +unset CUDA_VISIBLE_DEVICES +python -m paddle.distributed.launch --gpus "0" run_squad.py \ --model_type bert \ --model_name_or_path bert-base-uncased \ --max_seq_length 384 \ --batch_size 12 \ --learning_rate 3e-5 \ --num_train_epochs 2 \ - --logging_steps 100 \ + --logging_steps 1000 \ --save_steps 1000 \ --warmup_proportion 0.1 \ --weight_decay 0.01 \ --output_dir ./tmp/squad/ \ + --device gpu \ --do_train \ - --do_predict \ - --n_gpu 1 + --do_predict ``` * `model_type`: 预训练模型的种类。如bert,ernie,roberta等。 @@ -68,7 +69,8 @@ python -u ./run_squad.py \ 对于 SQuAD v2.0,按如下方式启动 Fine-tuning: ```shell -python -u ./run_squad.py \ +unset CUDA_VISIBLE_DEVICES +python -m paddle.distributed.launch --gpus "0" run_squad.py \ --model_type bert \ --model_name_or_path bert-base-uncased \ --max_seq_length 384 \ @@ -80,9 +82,9 @@ python -u ./run_squad.py \ --warmup_proportion 0.1 \ --weight_decay 0.01 \ --output_dir ./tmp/squad/ \ - --n_gpu 1 \ + --device gpu \ --do_train \ - --do_pred \ + --do_predict \ --version_2_with_negative ``` diff --git a/examples/machine_reading_comprehension/SQuAD/args.py b/examples/machine_reading_comprehension/SQuAD/args.py index 19f7ba8260c88..bbe5bde22fda0 100644 --- a/examples/machine_reading_comprehension/SQuAD/args.py +++ b/examples/machine_reading_comprehension/SQuAD/args.py @@ -92,10 +92,10 @@ def parse_args(): parser.add_argument( "--seed", type=int, default=42, help="random seed for initialization") parser.add_argument( - "--n_gpu", - type=int, - default=1, - help="number of gpus to use, 0 for cpu.") + '--device', + choices=['cpu', 'gpu'], + default="gpu", + help="Select which device to train model, defaults to gpu.") parser.add_argument( "--doc_stride", type=int, diff --git a/examples/machine_reading_comprehension/SQuAD/run_squad.py b/examples/machine_reading_comprehension/SQuAD/run_squad.py index 6409d10c4d84e..463df75b05829 100644 --- a/examples/machine_reading_comprehension/SQuAD/run_squad.py +++ b/examples/machine_reading_comprehension/SQuAD/run_squad.py @@ -109,16 +109,16 @@ def forward(self, y, label): def run(args): - paddle.set_device("gpu" if args.n_gpu else "cpu") + paddle.set_device(args.device) if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() - + rank = paddle.distributed.get_rank() args.model_type = args.model_type.lower() model_class, tokenizer_class = MODEL_CLASSES[args.model_type] tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) set_seed(args) - if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: + if rank == 0: if os.path.exists(args.model_name_or_path): print("init checkpoint from %s" % args.model_name_or_path) @@ -268,8 +268,7 @@ def prepare_train_features(examples): optimizer.clear_grad() if global_step % args.save_steps == 0 or global_step == num_training_steps: - if (not args.n_gpu > 1 - ) or paddle.distributed.get_rank() == 0: + if rank == 0: output_dir = os.path.join(args.output_dir, "model_%d" % global_step) if not os.path.exists(output_dir): @@ -316,7 +315,7 @@ def prepare_validation_features(examples): return tokenized_examples - if args.do_predict: + if 
args.do_predict and rank == 0: if args.predict_file: dev_ds = load_dataset('sqaud', data_files=args.predict_file) elif args.version_2_with_negative: @@ -339,13 +338,9 @@ def prepare_validation_features(examples): collate_fn=dev_batchify_fn, return_list=True) - if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0: - evaluate(model, dev_data_loader, args) + evaluate(model, dev_data_loader, args) if __name__ == "__main__": args = parse_args() - if args.n_gpu > 1: - paddle.distributed.spawn(run, args=(args, ), nprocs=args.n_gpu) - else: - run(args) + run(args) diff --git a/paddlenlp/datasets/__init__.py b/paddlenlp/datasets/__init__.py index 18e902e10e535..0bff4b4cc222f 100644 --- a/paddlenlp/datasets/__init__.py +++ b/paddlenlp/datasets/__init__.py @@ -12,17 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .chnsenticorp import * from .dataset import * +from .chnsenticorp import * +from .cmrc2018 import * +from .drcd import * +from .dureader_robust import * from .glue import * -from .imdb import * from .lcqmc import * from .msra_ner import * -from .peoples_daily_ner import * from .ptb import * from .squad import * -from .translation import * -from .dureader import * +from .peoples_daily_ner import * from .poetry import * +from .cmrc2018 import * +from .drcd import * +from .dureader_robust import * +from .glue import * +from .wmt14ende import * from .couplet import * -from .experimental import load_dataset, DatasetBuilder, MapDataset, IterDataset \ No newline at end of file +from .yahoo_answer_100k import * diff --git a/paddlenlp/datasets/chnsenticorp.py b/paddlenlp/datasets/chnsenticorp.py index 3b790cbdd9388..1d19f14aa3837 100644 --- a/paddlenlp/datasets/chnsenticorp.py +++ b/paddlenlp/datasets/chnsenticorp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,23 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import collections -import io +import json import os -import warnings -from paddle.io import Dataset from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url from paddlenlp.utils.env import DATA_HOME - -from .dataset import TSVDataset +from . 
import DatasetBuilder __all__ = ['ChnSentiCorp'] -class ChnSentiCorp(TSVDataset): +class ChnSentiCorp(DatasetBuilder): """ ChnSentiCorp (by Tan Songbo at ICT of Chinese Academy of Sciences, and for opinion mining) @@ -37,54 +33,41 @@ class ChnSentiCorp(TSVDataset): URL = "https://bj.bcebos.com/paddlehub-dataset/chnsenticorp.tar.gz" MD5 = "fbb3217aeac76a2840d2d5cd19688b07" - META_INFO = collections.namedtuple( - 'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples')) + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) SPLITS = { 'train': META_INFO( os.path.join('chnsenticorp', 'train.tsv'), - '689360c4a4a9ce8d8719ed500ae80907', (1, 0), 1), + '689360c4a4a9ce8d8719ed500ae80907'), 'dev': META_INFO( os.path.join('chnsenticorp', 'dev.tsv'), - '05e4b02561c2a327833e05bbe8156cec', (1, 0), 1), + '05e4b02561c2a327833e05bbe8156cec'), 'test': META_INFO( os.path.join('chnsenticorp', 'test.tsv'), - '917dfc6fbce596bb01a91abaa6c86f9e', (1, 0), 1) + '917dfc6fbce596bb01a91abaa6c86f9e'), } - def __init__(self, - mode='train', - root=None, - return_all_fields=False, - **kwargs): - if return_all_fields: - splits = copy.deepcopy(self.__class__.SPLITS) - mode_info = list(splits[mode]) - mode_info[2] = None - splits[mode] = self.META_INFO(*mode_info) - self.SPLITS = splits - - self._get_data(root, mode, **kwargs) - - def _get_data(self, root, mode, **kwargs): - default_root = DATA_HOME - filename, data_hash, field_indices, num_discard_samples = self.SPLITS[ - mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) + def _get_data(self, mode, **kwargs): + """Downloads dataset.""" + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - path = get_path_from_url(self.URL, default_root, self.MD5) - fullname = os.path.join(default_root, filename) - super(ChnSentiCorp, self).__init__( - fullname, - field_indices=field_indices, - num_discard_samples=num_discard_samples, - **kwargs) + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename): + """Reads data.""" + with open(filename, 'r', encoding='utf-8') as f: + head = None + for line in f: + data = line.strip().split("\t") + if not head: + head = data + else: + label, text = data + yield {"text": text, "label": label} def get_labels(self): """ diff --git a/paddlenlp/datasets/experimental/cmrc2018.py b/paddlenlp/datasets/cmrc2018.py similarity index 100% rename from paddlenlp/datasets/experimental/cmrc2018.py rename to paddlenlp/datasets/cmrc2018.py diff --git a/paddlenlp/datasets/couplet.py b/paddlenlp/datasets/couplet.py index 42d0e96bd0190..d812fd88b6ab7 100644 --- a/paddlenlp/datasets/couplet.py +++ b/paddlenlp/datasets/couplet.py @@ -12,75 +12,94 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import os +import warnings -from paddlenlp.datasets import TranslationDataset +from paddle.io import Dataset +from paddle.dataset.common import md5file +from paddle.utils.download import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . 
import DatasetBuilder -__all__ = ['CoupletDataset'] +__all__ = ['Couplet'] -class CoupletDataset(TranslationDataset): +class Couplet(DatasetBuilder): """ - Couplet dataset. This dataset is from this github repository: + Couplet dataset. The couplet data is from this github repository: https://github.com/v-zich/couplet-clean-dataset, which filters dirty data from the original repository https://github.com/wb14123/couplet-dataset. - - Args: - mode(str, optional): It could be 'train', 'dev' or 'test'. Default: - 'train'. - root(str, optional): Data directory of dataset. If not - provided, dataset will be saved to default directory - `~/.paddlenlp/datasets/machine_translation/CoupletDataset`. If - provided, md5 check would be performed, and dataset would be - downloaded in default directory if failed. Default: None. - Example: - .. code-block:: python - - from paddlenlp.datasets import CoupletDataset - couplet_dataset = CoupletDataset() """ - URL = "https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz" + META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file', + 'src_md5', 'tgt_md5')) + MD5 = '5c0dcde8eec6a517492227041c2e2d54' SPLITS = { - 'train': TranslationDataset.META_INFO( + 'train': META_INFO( os.path.join("couplet", "train_src.tsv"), os.path.join("couplet", "train_tgt.tsv"), "ad137385ad5e264ac4a54fe8c95d1583", "daf4dd79dbf26040696eee0d645ef5ad"), - 'dev': TranslationDataset.META_INFO( + 'dev': META_INFO( os.path.join("couplet", "dev_src.tsv"), os.path.join("couplet", "dev_tgt.tsv"), "65bf9e72fa8fdf0482751c1fd6b6833c", "3bc3b300b19d170923edfa8491352951"), - 'test': TranslationDataset.META_INFO( + 'test': META_INFO( os.path.join("couplet", "test_src.tsv"), os.path.join("couplet", "test_tgt.tsv"), "f0a7366dfa0acac884b9f4901aac2cc1", "56664bff3f2edfd7a751a55a689f90c2") } - VOCAB_INFO = (os.path.join("couplet", "vocab.txt"), os.path.join( - "couplet", "vocab.txt"), "0bea1445c7c7fb659b856bb07e54a604", + VOCAB_INFO = (os.path.join("couplet", "vocab.txt"), "0bea1445c7c7fb659b856bb07e54a604") UNK_TOKEN = '' BOS_TOKEN = '' EOS_TOKEN = '' - MD5 = '5c0dcde8eec6a517492227041c2e2d54' - def __init__(self, mode='train', root=None): - data_select = ('train', 'dev', 'test') - if mode not in data_select: - raise TypeError( - '`train`, `dev` or `test` is supported but `{}` is passed in'. 
- format(mode)) - # Download and read data - self.data = self.get_data(mode=mode, root=root) - self.vocab, _ = self.get_vocab(root) - self.transform() + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + src_filename, tgt_filename, src_data_hash, tgt_data_hash = self.SPLITS[ + mode] + src_fullname = os.path.join(default_root, src_filename) + tgt_fullname = os.path.join(default_root, tgt_filename) + + vocab_filename, vocab_hash = self.VOCAB_INFO + vocab_fullname = os.path.join(default_root, vocab_filename) + + if (not os.path.exists(src_fullname) or + (src_data_hash and not md5file(src_fullname) == src_data_hash)) or ( + not os.path.exists(tgt_fullname) or + (tgt_data_hash and + not md5file(tgt_fullname) == tgt_data_hash)) or ( + not os.path.exists(vocab_fullname) or + (vocab_hash and + not md5file(vocab_fullname) == vocab_hash)): + get_path_from_url(self.URL, default_root, self.MD5) + + return src_fullname, tgt_fullname + + def _read(self, filename, *args): + src_filename, tgt_filename = filename + with open(src_filename, 'r', encoding='utf-8') as src_f: + with open(tgt_filename, 'r', encoding='utf-8') as tgt_f: + for src_line, tgt_line in zip(src_f, tgt_f): + src_line = src_line.strip() + tgt_line = tgt_line.strip() + if not src_line and not tgt_line: + continue + yield {"first": src_line, "second": tgt_line} + + def get_vocab(self): + vocab_fullname = os.path.join(DATA_HOME, self.__class__.__name__, + self.VOCAB_INFO[0]) - def transform(self): - eos_id = self.vocab[self.EOS_TOKEN] - bos_id = self.vocab[self.BOS_TOKEN] - self.data = [( - [bos_id] + self.vocab.to_indices(data[0].split("\x02")) + [eos_id], - [bos_id] + self.vocab.to_indices(data[1].split("\x02")) + [eos_id]) - for data in self.data] + # Construct vocab_info to match the form of the input of `Vocab.load_vocabulary()` function + vocab_info = { + 'filepath': vocab_fullname, + 'unk_token': self.UNK_TOKEN, + 'bos_token': self.BOS_TOKEN, + 'eos_token': self.EOS_TOKEN + } + return vocab_info diff --git a/paddlenlp/datasets/dataset.py b/paddlenlp/datasets/dataset.py index 09b9eb9e8efed..f13ba8744dae3 100644 --- a/paddlenlp/datasets/dataset.py +++ b/paddlenlp/datasets/dataset.py @@ -12,96 +12,104 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import collections import io import math import os import warnings +import sys +import inspect import paddle.distributed as dist -from paddle.io import Dataset +from paddle.io import Dataset, IterableDataset from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url from paddlenlp.utils.env import DATA_HOME +from typing import Iterable, Iterator, Optional, List, Any, Callable, Union +import importlib +from functools import partial -__all__ = [ - 'MapDatasetWrapper', - 'TSVDataset', -] +__all__ = ['MapDataset', 'DatasetBuilder', 'IterDataset', 'load_dataset'] +DATASETS_MODULE_PATH = "paddlenlp.datasets." -@classmethod -def get_datasets(cls, *args, **kwargs): - """ - Get muitiple datasets like train, valid and test of current dataset. - Example: - .. code-block:: python +def import_main_class(module_path): + """Import a module at module_path and return its main class. 
- from paddlenlp.datasets import GlueQNLI - train_dataset, dev_dataset, test_dataset = GlueQNLI.get_datasets(['train', 'dev', 'test']) - train_dataset, dev_dataset, test_dataset = GlueQNLI.get_datasets(mode=['train', 'dev', 'test']) - train_dataset = GlueQNLI.get_datasets('train') - train_dataset = GlueQNLI.get_datasets(['train']) - train_dataset = GlueQNLI.get_datasets(mode='train') """ - if not args and not kwargs: - try: - args = cls.SPLITS.keys() - except: - raise AttributeError( - 'Dataset must have SPLITS attridute to use get_dataset if configs is None.' - ) - - datasets = tuple(MapDatasetWrapper(cls(arg)) for arg in args) + module_path = DATASETS_MODULE_PATH + module_path + module = importlib.import_module(module_path) + main_cls_type = DatasetBuilder + + # Find the main class in our imported module + module_main_cls = None + for name, obj in module.__dict__.items(): + if isinstance(obj, type) and issubclass(obj, main_cls_type): + if name == 'DatasetBuilder': + continue + module_main_cls = obj + break + + return module_main_cls + + +def load_dataset(path_or_read_func, + name=None, + data_files=None, + splits=None, + lazy=None, + **kwargs): + if inspect.isfunction(path_or_read_func): + assert lazy is not None, "lazy can not be None in custom mode." + kwargs['name'] = name + kwargs['data_files'] = data_files + kwargs['splits'] = splits + custom_kwargs = {} + for name in inspect.signature(path_or_read_func).parameters.keys(): + if name in kwargs.keys(): + custom_kwargs[name] = kwargs[name] + + reader_instance = SimpleBuilder(lazy=lazy, read_func=path_or_read_func) + return reader_instance.read(**custom_kwargs) else: + reader_cls = import_main_class(path_or_read_func) + if not name: + reader_instance = reader_cls(lazy=lazy, **kwargs) + else: + reader_instance = reader_cls(lazy=lazy, name=name, **kwargs) - for arg in args: - if not isinstance(arg, list): - return MapDatasetWrapper(cls(*args, **kwargs)) - for value in kwargs.values(): - if not isinstance(value, list): - return MapDatasetWrapper(cls(*args, **kwargs)) - - num_datasets = len(args[0]) if args else len(list(kwargs.values())[0]) - datasets = tuple( - MapDatasetWrapper( - cls(*(args[i] for args in args), **( - {key: value[i] - for key, value in kwargs.items()}))) - for i in range(num_datasets)) - - return datasets if len(datasets) > 1 else datasets[0] - - -Dataset.get_datasets = get_datasets + datasets = reader_instance.read_datasets( + data_files=data_files, splits=splits) + return datasets -class MapDatasetWrapper(Dataset): +class MapDataset(Dataset): """ Wraps a dataset-like object as a instance of Dataset, and equips it with - `apply` and other utility methods. All non-magic methods of the raw object + `map` and other utility methods. All non-magic methods of the raw object also accessible. Args: data (list|Dataset): A dataset-like object. It can be a list or a subclass of Dataset. 
""" - def __init__(self, data): + def __init__(self, data, **kwargs): self.data = data self._transform_pipline = [] self.new_data = self.data - def _transform(self, data, pipline): - for fn in reversed(pipline): + self.label_list = kwargs.pop('label_list', None) + self.vocab_info = kwargs.pop('vocab_info', None) + + def _transform(self, data): + for fn in self._transform_pipline: data = fn(data) return data def __getitem__(self, idx): - return self._transform( - self.new_data[idx], self._transform_pipline - ) if self._transform_pipline else self.new_data[idx] + return self._transform(self.new_data[ + idx]) if self._transform_pipline else self.new_data[idx] def __len__(self): return len(self.new_data) @@ -109,12 +117,10 @@ def __len__(self): def filter(self, fn): """ Filters samples by the filter function and uses the filtered data to - create a new MapDatasetWrapper instance. + update this dataset. Args: fn (callable): A filter function that takes a sample as input and returns a boolean. Samples that return False are discarded. - Returns: - MapDatasetWrapper: The filtered dataset """ self.new_data = [ @@ -125,8 +131,7 @@ def filter(self, fn): def shard(self, num_shards=None, index=None): """ - Use samples whose indices mod `index` equals 0 to create a new - MapDatasetWrapper instance. + Use samples whose indices mod `index` equals 0 to update this dataset. Args: num_shards (int, optional): A integer representing the number of data shards. If None, `num_shards` would be number of trainers. @@ -134,8 +139,6 @@ def shard(self, num_shards=None, index=None): index (int, optional): A integer representing the index of the current shard. If None, index` would be the current trainer rank id. Default: None. - Returns: - MapDatasetWrapper: The result dataset """ if num_shards is None: num_shards = dist.get_world_size() @@ -143,7 +146,6 @@ def shard(self, num_shards=None, index=None): index = dist.get_rank() num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards)) - total_size = num_samples * num_shards # add extra samples to make it evenly divisible self.new_data = [ self.new_data[idx] for idx in range(len(self.new_data)) @@ -154,142 +156,342 @@ def shard(self, num_shards=None, index=None): return self - def apply(self, fn, lazy=False): + def map(self, fn, lazy=True, batched=False): """ - Performs specific function on the dataset to transform every sample. + Performs specific function on the dataset to transform and update every sample. Args: fn (callable): Transformations to be performed. It receives single - sample as argument rather than dataset. + sample as argument if batched is False. Else it receives all examples. lazy (bool, optional): If True, transformations would be delayed and - performed on demand. Otherwise, transforms all samples at once - and return a new MapDatasetWrapper instance. Note that if `fn` is + performed on demand. Otherwise, transforms all samples at once. Note that if `fn` is stochastic, `lazy` should be True or you will get the same result on all epochs. Defalt: False. - Returns: - MapDatasetWrapper: A new MapDatasetWrapper instance if `lazy` is True, \ - otherwise bind `fn` as a property to transform on demand. + batched(bool, optional): If True, transformations would take all examples as input and + return a collection of transformed examples. Note that if set True, `lazy` option + would be ignored. 
""" - if lazy: + if batched: + self.new_data = fn(self.new_data) + elif lazy: self._transform_pipline.append(fn) else: self.new_data = [ fn(self.new_data[idx]) for idx in range(len(self.new_data)) ] + return self def __getattr__(self, name): return getattr(self.data, name) -class TSVDataset(Dataset): +class IterDataset(IterableDataset): """ - Common tab separated text dataset that reads text fields based on provided - sample splitter and field separator. - The returned dataset includes samples, each of which can either be a list - of text fields if field_separator is specified, or otherwise a single - string segment produced by the sample_splitter. + Wraps a dataset-like object as a instance of Dataset, and equips it with + `map` and other utility methods. All non-magic methods of the raw object + also accessible. Args: - filename (str|list of str): Path to the input text file or list of - paths to the input text files. - encoding (str): File encoding format. Default: 'utf8'. - sample_splitter (function): A function that splits the dataset string - into samples. Default: str.splitlines - field_separator (function|None): A function that splits each sample - string into list of text fields. If None, raw samples are returned - according to `sample_splitter`. Default: split method of str with - tab as separator. - num_discard_samples (int): Number of samples discarded at the head of - the first file. Default: 0. - field_indices (list|int|None): If set, for each sample, only fields - with provided indices are selected as the output. Otherwise all - fields are returned. Default: None. - allow_missing (bool): If set to True, no exception will be thrown if - the number of fields is smaller than the maximum field index - provided. Default: False. - - Example: - assume `test.tsv` contains the following content: - Id\tFirstName\tLastName - a\tmale\tTom - b\tFemal\tCat - discard the first line and select the 0th and 2nd fields - .. code-block:: python - - from paddlenlp.datasets import TSVDataset - dataset = TSVDataset('test.tsv', num_discard_samples=1, - field_indices=[0, 2]) - dataset[0] # ['a', 'Tom'] - dataset[1] # ['b', 'Cat'] + data (Iterable): A dataset-like object. It can be a Iterable or a + subclass of Dataset. """ - def __init__(self, - filename, - encoding='utf-8', - sample_splitter=lambda x: x.splitlines(), - field_separator=lambda x: x.split('\t'), - num_discard_samples=0, - field_indices=None, - allow_missing=False): - assert sample_splitter, 'sample_splitter must be specified.' - - if not isinstance(filename, (tuple, list)): - filename = (filename, ) - - self._filenames = [os.path.expanduser(f) for f in filename] - self._encoding = encoding - self._sample_splitter = sample_splitter - self._field_separator = field_separator - self._num_discard_samples = num_discard_samples - self._field_indices = field_indices - self._allow_missing = allow_missing - self.data = self._read() - - def _should_discard(self): - discard = self._num_discard_samples > 0 - self._num_discard_samples -= 1 - return discard - - def _field_selector(self, fields): - if not self._field_indices: - return fields - try: - result = [fields[i] for i in self._field_indices] - except IndexError as e: - raise (IndexError('%s. 
Fields = %s' % (str(e), str(fields)))) - return result - - def _read(self): - all_samples = [] - for filename in self._filenames: - with io.open(filename, 'r', encoding=self._encoding) as fin: - content = fin.read() - samples = (s for s in self._sample_splitter(content) - if not self._should_discard()) - if self._field_separator: - if not self._allow_missing: - samples = [ - self._field_selector(self._field_separator(s)) - for s in samples - ] - else: - selected_samples = [] - num_missing = 0 - for s in samples: - try: - fields = self._field_separator(s) - selected_samples.append( - self._field_selector(fields)) - except IndexError: - num_missing += 1 - if num_missing > 0: - warnings.warn('%d incomplete samples in %s' % - (num_missing, filename)) - samples = selected_samples - all_samples += samples - return all_samples + def __init__(self, data, **kwargs): + self.data = data + self._transform_pipline = [] + self._filter_pipline = [] - def __getitem__(self, idx): - return self.data[idx] + self.label_list = kwargs.pop('label_list', None) + self.vocab_info = kwargs.pop('vocab_info', None) - def __len__(self): - return len(self.data) + def _transform(self, data): + for fn in self._transform_pipline: + data = fn(data) + return data + + def _shard_filter(self, num_samples): + return True + + def _filter(self, data): + for fn in self._filter_pipline: + if not fn(data): + return False + return True + + def __iter__(self): + num_samples = 0 + if inspect.isfunction(self.data): + for example in self.data(): + if (not self._filter_pipline or + self._filter(self._filter_pipline) + ) and self._shard_filter(num_samples=num_samples): + yield self._transform( + example) if self._transform_pipline else example + num_samples += 1 + else: + if inspect.isgenerator(self.data): + warnings.warn( + 'Reciving generator as data source, data can only be iterated once' + ) + for example in self.data: + if (not self._filter_pipline or + self._filter(self._filter_pipline) + ) and self._shard_filter(num_samples=num_samples): + yield self._transform( + example) if self._transform_pipline else example + num_samples += 1 + + def filter(self, fn): + """ + Filters samples by the filter function and uses the filtered data to + update this dataset. + Args: + fn (callable): A filter function that takes a sample as input and + returns a boolean. Samples that return False are discarded. + """ + + self._filter_pipline.append(fn) + + return self + + def shard(self, num_shards=None, index=None): + """ + Use samples whose indices mod `index` equals 0 to update this dataset. + Args: + num_shards (int, optional): A integer representing the number of + data shards. If None, `num_shards` would be number of trainers. + Default: None + index (int, optional): A integer representing the index of the + current shard. If None, index` would be the current trainer rank + id. Default: None. + """ + if num_shards is None: + num_shards = dist.get_world_size() + if index is None: + index = dist.get_rank() + + def sharder(num_shards, index, num_samples): + if num_samples % num_shards == index: + return True + else: + return False + + fn = partial(sharder, num_shards=num_shards, index=index) + self._shard_filter = fn + return self + + def map(self, fn): + """ + Performs specific function on the dataset to transform and update every sample. + Args: + fn (callable): Transformations to be performed. It receives single + sample as argument. 
+ """ + + self._transform_pipline.append(fn) + + return self + + def __getattr__(self, name): + return getattr(self.data, name) + + +class DatasetBuilder: + """ + A base class for all DatasetBuilder. It provides a `read()` function to turn + a data file into a MapDataset or IterDataset. + + `_get_data()` function and `_read()` function should be implemented to download + data file and read data file into a `Iterable` of the examples. + """ + lazy = False + + def __init__(self, lazy=None, name=None, **config): + if lazy is not None: + self.lazy = lazy + self.name = name + self.config = config + + def read_datasets(self, splits=None, data_files=None): + datasets = [] + assert splits or data_files, "`data_files` and `splits` can not both be None." + + if data_files: + assert isinstance(data_files, str) or isinstance( + data_files, dict + ) or isinstance(data_files, tuple) or isinstance( + data_files, list + ), "`data_files` should be a string or tuple or list or a dictionary whose key is split name ande value is a path of data file." + if isinstance(data_files, str): + split = 'train' + datasets.append(self.read(filename=data_files, split=split)) + elif isinstance(data_files, tuple) or isinstance(data_files, list): + split = 'train' + datasets += [ + self.read( + filename=filename, split=split) + for filename in data_files + ] + else: + datasets += [ + self.read( + filename=filename, split=split) + for split, filename in data_files.items() + ] + + if splits: + assert isinstance(splits, str) or ( + isinstance(splits, list) and isinstance(splits[0], str) + ) or ( + isinstance(splits, tuple) and isinstance(splits[0], str) + ), "`splits` should be a string or list of string or a tuple of string." + if isinstance(splits, str): + filename = self._get_data(splits) + datasets.append(self.read(filename=filename, split=splits)) + else: + for split in splits: + filename = self._get_data(split) + datasets.append(self.read(filename=filename, split=split)) + + return datasets if len(datasets) > 1 else datasets[0] + + def read(self, filename, split='train'): + """ + Returns an dataset containing all the examples that can be read from the file path. + If `self.lazy` is `False`, this eagerly reads all instances from `self._read()` + and returns an `MapDataset`. + If `self.lazy` is `True`, this returns an `IterDataset`, which internally + relies on the generator created from `self._read()` to lazily produce examples. + In this case your implementation of `_read()` must also be lazy + (that is, not load all examples into memory at once). + """ + + label_list = self.get_labels() + vocab_info = self.get_vocab() + + if self.lazy: + + def generate_examples(): + generator = self._read( + filename, split + ) if self._read.__code__.co_argcount > 2 else self._read( + filename) + for example in generator: + # We need to check if the example contains label column and confirm its name. + # For now we only allow `label` or `labels` to be the name of label column. + if 'labels' in example.keys(): + label_col = 'labels' + elif 'label' in example.keys(): + label_col = 'label' + else: + label_col = None + + # Convert class label to label ids. 
+ if label_list is not None and example.get(label_col, None): + label_dict = {} + for i, label in enumerate(label_list): + label_dict[label] = i + if isinstance(example[label_col], list) or isinstance( + example[label_col], tuple): + for label_idx in range(len(example[label_col])): + example[label_col][label_idx] = label_dict[ + example[label_col][label_idx]] + else: + example[label_col] = label_dict[example[label_col]] + + yield example + else: + yield example + + return IterDataset( + generate_examples(), + label_list=label_list, + vocab_info=vocab_info) + else: + examples = self._read( + filename, + split) if self._read.__code__.co_argcount > 2 else self._read( + filename) + + # Then some validation. + if not isinstance(examples, list): + examples = list(examples) + + if not examples: + raise ValueError( + "No instances were read from the given filepath {}. " + "Is the path correct?".format(filename)) + + # We need to check if the example contains label column and confirm its name. + # For now we only allow `label` or `labels` to be the name of label column. + if 'labels' in examples[0].keys(): + label_col = 'labels' + elif 'label' in examples[0].keys(): + label_col = 'label' + else: + label_col = None + + # Convert class label to label ids. + if label_list is not None and examples[0].get(label_col, None): + label_dict = {} + for i, label in enumerate(label_list): + label_dict[label] = i + for idx in range(len(examples)): + if isinstance(examples[idx][label_col], list) or isinstance( + examples[idx][label_col], tuple): + for label_idx in range(len(examples[idx][label_col])): + examples[idx][label_col][label_idx] = label_dict[ + examples[idx][label_col][label_idx]] + else: + examples[idx][label_col] = label_dict[examples[idx][ + label_col]] + + return MapDataset( + examples, label_list=label_list, vocab_info=vocab_info) + + def _read(self, filename: str, *args): + """ + Reads examples from the given file_path and returns them as an + `Iterable` (which could be a list or could be a generator). + """ + raise NotImplementedError + + def _get_data(self, mode: str): + """ + Download examples from the given URL and customized split informations and returns a filepath. + """ + raise NotImplementedError + + def get_labels(self): + """ + Return list of class labels of the dataset if specified. + """ + return None + + def get_vocab(self): + """ + Return vocab file path of the dataset if specified. 
+ """ + return None + + +class SimpleBuilder(DatasetBuilder): + def __init__(self, lazy, read_func): + self._read = read_func + self.lazy = lazy + + def read(self, **kwargs): + if self.lazy: + + def generate_examples(): + generator = self._read(**kwargs) + for example in generator: + yield example + + return IterDataset(generate_examples) + else: + examples = self._read(**kwargs) + if hasattr(examples, '__len__') and hasattr(examples, + '__getitem__'): + return MapDataset(examples) + else: + return MapDataset(list(examples)) diff --git a/paddlenlp/datasets/experimental/drcd.py b/paddlenlp/datasets/drcd.py similarity index 100% rename from paddlenlp/datasets/experimental/drcd.py rename to paddlenlp/datasets/drcd.py diff --git a/paddlenlp/datasets/experimental/duconv.py b/paddlenlp/datasets/duconv.py similarity index 100% rename from paddlenlp/datasets/experimental/duconv.py rename to paddlenlp/datasets/duconv.py diff --git a/paddlenlp/datasets/dureader.py b/paddlenlp/datasets/dureader.py deleted file mode 100644 index 60569d78143e6..0000000000000 --- a/paddlenlp/datasets/dureader.py +++ /dev/null @@ -1,244 +0,0 @@ -import copy -import collections -import json -import os -import warnings - -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from paddle.io import Dataset -from .squad import InputFeatures, SQuAD - -__all__ = ['DuReader', 'DuReaderYesNo'] - - -class DuReaderExample(object): - """A single training/test example for simple sequence classification. - - For examples without an answer, the start and end position are -1. - """ - - def __init__(self, - qas_id, - question_text, - doc_tokens, - orig_answer_text=None, - start_position=None, - end_position=None, - question_type=None): - self.qas_id = qas_id - self.question_text = question_text - self.doc_tokens = doc_tokens - self.orig_answer_text = orig_answer_text - self.start_position = start_position - self.end_position = end_position - self.question_type = question_type - self.is_impossible = False - - -class DuReader(SQuAD): - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - - DATA_URL = 'https://dataset-bj.cdn.bcebos.com/dureader/dureader_preprocessed.zip' - - SPLITS = { - 'train': META_INFO( - os.path.join('preprocessed', 'trainset', 'zhidao.train.json'), - None), - 'dev': META_INFO( - os.path.join('preprocessed', 'devset', 'zhidao.dev.json'), None), - 'test': META_INFO( - os.path.join('preprocessed', 'testset', 'zhidao.test.json'), None) - } - - def __init__(self, - tokenizer, - mode='train', - root=None, - doc_stride=128, - max_query_length=64, - max_seq_length=512, - **kwargs): - - super(DuReader, self).__init__( - tokenizer=tokenizer, - mode=mode, - root=root, - doc_stride=doc_stride, - max_query_length=max_query_length, - max_seq_length=max_seq_length, - **kwargs) - - def _get_data(self, root, mode, **kwargs): - default_root = os.path.join(DATA_HOME, 'DuReader') - - filename, data_hash = self.SPLITS[mode] - - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - - fullname = get_path_from_url( - self.DATA_URL, os.path.join(default_root, 'preprocessed')) - - self.full_path = 
fullname - - def _read(self): - examples = [] - data_lines = [] - with open(self.full_path, "r", encoding="utf8") as reader: - data_lines += reader.readlines() - with open( - self.full_path.replace('zhidao', 'search'), "r", - encoding="utf8") as reader: - data_lines += reader.readlines() - for entry in data_lines: - source = json.loads(entry.strip()) - start_id = None - end_id = None - orig_answer_text = None - - if self.is_training: - if (len(source['answer_spans']) == 0): - continue - if source['answers'] == []: - continue - if (source['match_scores'][0] < 0.7): - continue - - docs_index = source['answer_docs'][0] - start_id = source['answer_spans'][0][0] - end_id = source['answer_spans'][0][1] + 1 - - try: - answer_passage_idx = source['documents'][docs_index][ - 'most_related_para'] - except: - continue - - doc_tokens = source['documents'][docs_index][ - 'segmented_paragraphs'][answer_passage_idx] - - if source['fake_answers'][0] != "".join(doc_tokens[start_id: - end_id]): - continue - orig_answer_text = source['fake_answers'][0] - end_id = end_id - 1 - - else: - doc_tokens = [] - for doc in source['documents']: - para_infos = [] - for para_tokens in doc['segmented_paragraphs']: - question_tokens = source['segmented_question'] - common_with_question = collections.Counter( - para_tokens) & collections.Counter(question_tokens) - correct_preds = sum(common_with_question.values()) - if correct_preds == 0: - recall_wrt_question = 0 - else: - recall_wrt_question = float(correct_preds) / len( - question_tokens) - para_infos.append((para_tokens, recall_wrt_question, - len(para_tokens))) - para_infos.sort(key=lambda x: (-x[1], x[2])) - for para_info in para_infos[:1]: - doc_tokens += para_info[0] - if 'answers' in source.keys(): - orig_answer_text = source['answers'] - - example = DuReaderExample( - qas_id=source['question_id'], - question_text=source['question'].strip(), - question_type=source['question_type'], - doc_tokens=doc_tokens, - orig_answer_text=orig_answer_text, - start_position=start_id, - end_position=end_id) - - examples.append(example) - - self.examples = examples - - -class DuReaderYesNo(Dataset): - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - - DATA_URL = 'https://dataset-bj.cdn.bcebos.com/qianyan/dureader_yesno-data.tar.gz' - - SPLITS = { - 'train': META_INFO( - os.path.join('dureader_yesno-data', 'train.json'), - 'c469a0ef3f975cfd705e3553ddb27cc1'), - 'dev': META_INFO( - os.path.join('dureader_yesno-data', 'dev.json'), - 'c38544f8b5a7b567492314e3232057b5'), - 'test': META_INFO( - os.path.join('dureader_yesno-data', 'test.json'), - '1c7a1a3ea5b8992eeaeea017fdc2d55f') - } - - def __init__(self, mode='train', root=None, **kwargs): - - self._get_data(root, mode, **kwargs) - self._transform_func = None - - if mode == 'train': - self.is_training = True - else: - self.is_training = False - - self._read() - - def _get_data(self, root, mode, **kwargs): - default_root = os.path.join(DATA_HOME, 'DuReader') - - filename, data_hash = self.SPLITS[mode] - - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - - get_path_from_url(self.DATA_URL, default_root) - - self.full_path = fullname - - def _read(self): - data_lines = 
[] - with open(self.full_path, "r", encoding="utf8") as reader: - data_lines += reader.readlines() - - examples = [] - for entry in data_lines: - source = json.loads(entry.strip()) - examples.append([ - source['question'], source['answer'], source['yesno_answer'], - source['id'] - ]) - - self.examples = examples - - def __len__(self): - return len(self.examples) - - def __getitem__(self, idx): - return self.examples[idx] - - def get_labels(self): - """ - Return labels of the DuReaderYesNo sample. - """ - return ["Yes", "No", "Depends"] diff --git a/paddlenlp/datasets/experimental/dureader_robust.py b/paddlenlp/datasets/dureader_robust.py similarity index 100% rename from paddlenlp/datasets/experimental/dureader_robust.py rename to paddlenlp/datasets/dureader_robust.py diff --git a/paddlenlp/datasets/experimental/dureader_yesno.py b/paddlenlp/datasets/dureader_yesno.py similarity index 100% rename from paddlenlp/datasets/experimental/dureader_yesno.py rename to paddlenlp/datasets/dureader_yesno.py diff --git a/paddlenlp/datasets/experimental/__init__.py b/paddlenlp/datasets/experimental/__init__.py deleted file mode 100644 index 0bff4b4cc222f..0000000000000 --- a/paddlenlp/datasets/experimental/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .dataset import * -from .chnsenticorp import * -from .cmrc2018 import * -from .drcd import * -from .dureader_robust import * -from .glue import * -from .lcqmc import * -from .msra_ner import * -from .ptb import * -from .squad import * -from .peoples_daily_ner import * -from .poetry import * -from .cmrc2018 import * -from .drcd import * -from .dureader_robust import * -from .glue import * -from .wmt14ende import * -from .couplet import * -from .yahoo_answer_100k import * diff --git a/paddlenlp/datasets/experimental/chnsenticorp.py b/paddlenlp/datasets/experimental/chnsenticorp.py deleted file mode 100644 index 1d19f14aa3837..0000000000000 --- a/paddlenlp/datasets/experimental/chnsenticorp.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import json -import os - -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . 
import DatasetBuilder - -__all__ = ['ChnSentiCorp'] - - -class ChnSentiCorp(DatasetBuilder): - """ - ChnSentiCorp (by Tan Songbo at ICT of Chinese Academy of Sciences, and for - opinion mining) - - """ - - URL = "https://bj.bcebos.com/paddlehub-dataset/chnsenticorp.tar.gz" - MD5 = "fbb3217aeac76a2840d2d5cd19688b07" - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - SPLITS = { - 'train': META_INFO( - os.path.join('chnsenticorp', 'train.tsv'), - '689360c4a4a9ce8d8719ed500ae80907'), - 'dev': META_INFO( - os.path.join('chnsenticorp', 'dev.tsv'), - '05e4b02561c2a327833e05bbe8156cec'), - 'test': META_INFO( - os.path.join('chnsenticorp', 'test.tsv'), - '917dfc6fbce596bb01a91abaa6c86f9e'), - } - - def _get_data(self, mode, **kwargs): - """Downloads dataset.""" - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, data_hash = self.SPLITS[mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - get_path_from_url(self.URL, default_root, self.MD5) - - return fullname - - def _read(self, filename): - """Reads data.""" - with open(filename, 'r', encoding='utf-8') as f: - head = None - for line in f: - data = line.strip().split("\t") - if not head: - head = data - else: - label, text = data - yield {"text": text, "label": label} - - def get_labels(self): - """ - Return labels of the ChnSentiCorp object. - """ - return ["0", "1"] diff --git a/paddlenlp/datasets/experimental/couplet.py b/paddlenlp/datasets/experimental/couplet.py deleted file mode 100644 index d812fd88b6ab7..0000000000000 --- a/paddlenlp/datasets/experimental/couplet.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import os -import warnings - -from paddle.io import Dataset -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . import DatasetBuilder - -__all__ = ['Couplet'] - - -class Couplet(DatasetBuilder): - """ - Couplet dataset. The couplet data is from this github repository: - https://github.com/v-zich/couplet-clean-dataset, which filters dirty data - from the original repository https://github.com/wb14123/couplet-dataset. 
- """ - URL = "https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz" - META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file', - 'src_md5', 'tgt_md5')) - MD5 = '5c0dcde8eec6a517492227041c2e2d54' - SPLITS = { - 'train': META_INFO( - os.path.join("couplet", "train_src.tsv"), - os.path.join("couplet", "train_tgt.tsv"), - "ad137385ad5e264ac4a54fe8c95d1583", - "daf4dd79dbf26040696eee0d645ef5ad"), - 'dev': META_INFO( - os.path.join("couplet", "dev_src.tsv"), - os.path.join("couplet", "dev_tgt.tsv"), - "65bf9e72fa8fdf0482751c1fd6b6833c", - "3bc3b300b19d170923edfa8491352951"), - 'test': META_INFO( - os.path.join("couplet", "test_src.tsv"), - os.path.join("couplet", "test_tgt.tsv"), - "f0a7366dfa0acac884b9f4901aac2cc1", - "56664bff3f2edfd7a751a55a689f90c2") - } - VOCAB_INFO = (os.path.join("couplet", "vocab.txt"), - "0bea1445c7c7fb659b856bb07e54a604") - UNK_TOKEN = '' - BOS_TOKEN = '' - EOS_TOKEN = '' - - def _get_data(self, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - src_filename, tgt_filename, src_data_hash, tgt_data_hash = self.SPLITS[ - mode] - src_fullname = os.path.join(default_root, src_filename) - tgt_fullname = os.path.join(default_root, tgt_filename) - - vocab_filename, vocab_hash = self.VOCAB_INFO - vocab_fullname = os.path.join(default_root, vocab_filename) - - if (not os.path.exists(src_fullname) or - (src_data_hash and not md5file(src_fullname) == src_data_hash)) or ( - not os.path.exists(tgt_fullname) or - (tgt_data_hash and - not md5file(tgt_fullname) == tgt_data_hash)) or ( - not os.path.exists(vocab_fullname) or - (vocab_hash and - not md5file(vocab_fullname) == vocab_hash)): - get_path_from_url(self.URL, default_root, self.MD5) - - return src_fullname, tgt_fullname - - def _read(self, filename, *args): - src_filename, tgt_filename = filename - with open(src_filename, 'r', encoding='utf-8') as src_f: - with open(tgt_filename, 'r', encoding='utf-8') as tgt_f: - for src_line, tgt_line in zip(src_f, tgt_f): - src_line = src_line.strip() - tgt_line = tgt_line.strip() - if not src_line and not tgt_line: - continue - yield {"first": src_line, "second": tgt_line} - - def get_vocab(self): - vocab_fullname = os.path.join(DATA_HOME, self.__class__.__name__, - self.VOCAB_INFO[0]) - - # Construct vocab_info to match the form of the input of `Vocab.load_vocabulary()` function - vocab_info = { - 'filepath': vocab_fullname, - 'unk_token': self.UNK_TOKEN, - 'bos_token': self.BOS_TOKEN, - 'eos_token': self.EOS_TOKEN - } - return vocab_info diff --git a/paddlenlp/datasets/experimental/dataset.py b/paddlenlp/datasets/experimental/dataset.py deleted file mode 100644 index 6441403635854..0000000000000 --- a/paddlenlp/datasets/experimental/dataset.py +++ /dev/null @@ -1,497 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import collections -import io -import math -import os -import warnings -import sys -import inspect - -import paddle.distributed as dist -from paddle.io import Dataset, IterableDataset -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from typing import Iterable, Iterator, Optional, List, Any, Callable, Union -import importlib -from functools import partial - -__all__ = ['MapDataset', 'DatasetBuilder', 'IterDataset', 'load_dataset'] - -DATASETS_MODULE_PATH = "paddlenlp.datasets.experimental." - - -def import_main_class(module_path): - """Import a module at module_path and return its main class. - - """ - module_path = DATASETS_MODULE_PATH + module_path - module = importlib.import_module(module_path) - main_cls_type = DatasetBuilder - - # Find the main class in our imported module - module_main_cls = None - for name, obj in module.__dict__.items(): - if isinstance(obj, type) and issubclass(obj, main_cls_type): - if name == 'DatasetBuilder': - continue - module_main_cls = obj - break - - return module_main_cls - - -def load_dataset(path_or_read_func, - name=None, - data_files=None, - splits=None, - lazy=None, - **kwargs): - if inspect.isfunction(path_or_read_func): - assert lazy is not None, "lazy can not be None in custom mode." - kwargs['name'] = name - kwargs['data_files'] = data_files - kwargs['splits'] = splits - custom_kwargs = {} - for name in inspect.signature(path_or_read_func).parameters.keys(): - if name in kwargs.keys(): - custom_kwargs[name] = kwargs[name] - - reader_instance = SimpleBuilder(lazy=lazy, read_func=path_or_read_func) - return reader_instance.read(**custom_kwargs) - else: - reader_cls = import_main_class(path_or_read_func) - if not name: - reader_instance = reader_cls(lazy=lazy, **kwargs) - else: - reader_instance = reader_cls(lazy=lazy, name=name, **kwargs) - - datasets = reader_instance.read_datasets( - data_files=data_files, splits=splits) - return datasets - - -class MapDataset(Dataset): - """ - Wraps a dataset-like object as a instance of Dataset, and equips it with - `map` and other utility methods. All non-magic methods of the raw object - also accessible. - Args: - data (list|Dataset): A dataset-like object. It can be a list or a - subclass of Dataset. - """ - - def __init__(self, data, **kwargs): - self.data = data - self._transform_pipline = [] - self.new_data = self.data - - self.label_list = kwargs.pop('label_list', None) - self.vocab_info = kwargs.pop('vocab_info', None) - - def _transform(self, data): - for fn in self._transform_pipline: - data = fn(data) - return data - - def __getitem__(self, idx): - return self._transform(self.new_data[ - idx]) if self._transform_pipline else self.new_data[idx] - - def __len__(self): - return len(self.new_data) - - def filter(self, fn): - """ - Filters samples by the filter function and uses the filtered data to - update this dataset. - Args: - fn (callable): A filter function that takes a sample as input and - returns a boolean. Samples that return False are discarded. - """ - - self.new_data = [ - self.new_data[idx] for idx in range(len(self.new_data)) - if fn(self.new_data[idx]) - ] - return self - - def shard(self, num_shards=None, index=None): - """ - Use samples whose indices mod `index` equals 0 to update this dataset. - Args: - num_shards (int, optional): A integer representing the number of - data shards. If None, `num_shards` would be number of trainers. 
- Default: None - index (int, optional): A integer representing the index of the - current shard. If None, index` would be the current trainer rank - id. Default: None. - """ - if num_shards is None: - num_shards = dist.get_world_size() - if index is None: - index = dist.get_rank() - - num_samples = int(math.ceil(len(self.new_data) * 1.0 / num_shards)) - # add extra samples to make it evenly divisible - self.new_data = [ - self.new_data[idx] for idx in range(len(self.new_data)) - if idx % num_shards == index - ] - if len(self.new_data) < num_samples: - self.new_data.append(self.new_data[index + 1 - num_shards]) - - return self - - def map(self, fn, lazy=True, batched=False): - """ - Performs specific function on the dataset to transform and update every sample. - Args: - fn (callable): Transformations to be performed. It receives single - sample as argument if batched is False. Else it receives all examples. - lazy (bool, optional): If True, transformations would be delayed and - performed on demand. Otherwise, transforms all samples at once. Note that if `fn` is - stochastic, `lazy` should be True or you will get the same - result on all epochs. Defalt: False. - batched(bool, optional): If True, transformations would take all examples as input and - return a collection of transformed examples. Note that if set True, `lazy` option - would be ignored. - """ - if batched: - self.new_data = fn(self.new_data) - elif lazy: - self._transform_pipline.append(fn) - else: - self.new_data = [ - fn(self.new_data[idx]) for idx in range(len(self.new_data)) - ] - - return self - - def __getattr__(self, name): - return getattr(self.data, name) - - -class IterDataset(IterableDataset): - """ - Wraps a dataset-like object as a instance of Dataset, and equips it with - `map` and other utility methods. All non-magic methods of the raw object - also accessible. - Args: - data (Iterable): A dataset-like object. It can be a Iterable or a - subclass of Dataset. - """ - - def __init__(self, data, **kwargs): - self.data = data - self._transform_pipline = [] - self._filter_pipline = [] - - self.label_list = kwargs.pop('label_list', None) - self.vocab_info = kwargs.pop('vocab_info', None) - - def _transform(self, data): - for fn in self._transform_pipline: - data = fn(data) - return data - - def _shard_filter(self, num_samples): - return True - - def _filter(self, data): - for fn in self._filter_pipline: - if not fn(data): - return False - return True - - def __iter__(self): - num_samples = 0 - if inspect.isfunction(self.data): - for example in self.data(): - if (not self._filter_pipline or - self._filter(self._filter_pipline) - ) and self._shard_filter(num_samples=num_samples): - yield self._transform( - example) if self._transform_pipline else example - num_samples += 1 - else: - if inspect.isgenerator(self.data): - warnings.warn( - 'Reciving generator as data source, data can only be iterated once' - ) - for example in self.data: - if (not self._filter_pipline or - self._filter(self._filter_pipline) - ) and self._shard_filter(num_samples=num_samples): - yield self._transform( - example) if self._transform_pipline else example - num_samples += 1 - - def filter(self, fn): - """ - Filters samples by the filter function and uses the filtered data to - update this dataset. - Args: - fn (callable): A filter function that takes a sample as input and - returns a boolean. Samples that return False are discarded. 
- """ - - self._filter_pipline.append(fn) - - return self - - def shard(self, num_shards=None, index=None): - """ - Use samples whose indices mod `index` equals 0 to update this dataset. - Args: - num_shards (int, optional): A integer representing the number of - data shards. If None, `num_shards` would be number of trainers. - Default: None - index (int, optional): A integer representing the index of the - current shard. If None, index` would be the current trainer rank - id. Default: None. - """ - if num_shards is None: - num_shards = dist.get_world_size() - if index is None: - index = dist.get_rank() - - def sharder(num_shards, index, num_samples): - if num_samples % num_shards == index: - return True - else: - return False - - fn = partial(sharder, num_shards=num_shards, index=index) - self._shard_filter = fn - return self - - def map(self, fn): - """ - Performs specific function on the dataset to transform and update every sample. - Args: - fn (callable): Transformations to be performed. It receives single - sample as argument. - """ - - self._transform_pipline.append(fn) - - return self - - def __getattr__(self, name): - return getattr(self.data, name) - - -class DatasetBuilder: - """ - A base class for all DatasetBuilder. It provides a `read()` function to turn - a data file into a MapDataset or IterDataset. - - `_get_data()` function and `_read()` function should be implemented to download - data file and read data file into a `Iterable` of the examples. - """ - lazy = False - - def __init__(self, lazy=None, name=None, **config): - if lazy is not None: - self.lazy = lazy - self.name = name - self.config = config - - def read_datasets(self, splits=None, data_files=None): - datasets = [] - assert splits or data_files, "`data_files` and `splits` can not both be None." - - if data_files: - assert isinstance(data_files, str) or isinstance( - data_files, dict - ) or isinstance(data_files, tuple) or isinstance( - data_files, list - ), "`data_files` should be a string or tuple or list or a dictionary whose key is split name ande value is a path of data file." - if isinstance(data_files, str): - split = 'train' - datasets.append(self.read(filename=data_files, split=split)) - elif isinstance(data_files, tuple) or isinstance(data_files, list): - split = 'train' - datasets += [ - self.read( - filename=filename, split=split) - for filename in data_files - ] - else: - datasets += [ - self.read( - filename=filename, split=split) - for split, filename in data_files.items() - ] - - if splits: - assert isinstance(splits, str) or ( - isinstance(splits, list) and isinstance(splits[0], str) - ) or ( - isinstance(splits, tuple) and isinstance(splits[0], str) - ), "`splits` should be a string or list of string or a tuple of string." - if isinstance(splits, str): - filename = self._get_data(splits) - datasets.append(self.read(filename=filename, split=splits)) - else: - for split in splits: - filename = self._get_data(split) - datasets.append(self.read(filename=filename, split=split)) - - return datasets if len(datasets) > 1 else datasets[0] - - def read(self, filename, split='train'): - """ - Returns an dataset containing all the examples that can be read from the file path. - If `self.lazy` is `False`, this eagerly reads all instances from `self._read()` - and returns an `MapDataset`. - If `self.lazy` is `True`, this returns an `IterDataset`, which internally - relies on the generator created from `self._read()` to lazily produce examples. 
- In this case your implementation of `_read()` must also be lazy - (that is, not load all examples into memory at once). - """ - - label_list = self.get_labels() - vocab_info = self.get_vocab() - - if self.lazy: - - def generate_examples(): - generator = self._read( - filename, split - ) if self._read.__code__.co_argcount > 2 else self._read( - filename) - for example in generator: - # We need to check if the example contains label column and confirm its name. - # For now we only allow `label` or `labels` to be the name of label column. - if 'labels' in example.keys(): - label_col = 'labels' - elif 'label' in example.keys(): - label_col = 'label' - else: - label_col = None - - # Convert class label to label ids. - if label_list is not None and example.get(label_col, None): - label_dict = {} - for i, label in enumerate(label_list): - label_dict[label] = i - if isinstance(example[label_col], list) or isinstance( - example[label_col], tuple): - for label_idx in range(len(example[label_col])): - example[label_col][label_idx] = label_dict[ - example[label_col][label_idx]] - else: - example[label_col] = label_dict[example[label_col]] - - yield example - else: - yield example - - return IterDataset( - generate_examples(), - label_list=label_list, - vocab_info=vocab_info) - else: - examples = self._read( - filename, - split) if self._read.__code__.co_argcount > 2 else self._read( - filename) - - # Then some validation. - if not isinstance(examples, list): - examples = list(examples) - - if not examples: - raise ValueError( - "No instances were read from the given filepath {}. " - "Is the path correct?".format(filename)) - - # We need to check if the example contains label column and confirm its name. - # For now we only allow `label` or `labels` to be the name of label column. - if 'labels' in examples[0].keys(): - label_col = 'labels' - elif 'label' in examples[0].keys(): - label_col = 'label' - else: - label_col = None - - # Convert class label to label ids. - if label_list is not None and examples[0].get(label_col, None): - label_dict = {} - for i, label in enumerate(label_list): - label_dict[label] = i - for idx in range(len(examples)): - if isinstance(examples[idx][label_col], list) or isinstance( - examples[idx][label_col], tuple): - for label_idx in range(len(examples[idx][label_col])): - examples[idx][label_col][label_idx] = label_dict[ - examples[idx][label_col][label_idx]] - else: - examples[idx][label_col] = label_dict[examples[idx][ - label_col]] - - return MapDataset( - examples, label_list=label_list, vocab_info=vocab_info) - - def _read(self, filename: str, *args): - """ - Reads examples from the given file_path and returns them as an - `Iterable` (which could be a list or could be a generator). - """ - raise NotImplementedError - - def _get_data(self, mode: str): - """ - Download examples from the given URL and customized split informations and returns a filepath. - """ - raise NotImplementedError - - def get_labels(self): - """ - Return list of class labels of the dataset if specified. - """ - return None - - def get_vocab(self): - """ - Return vocab file path of the dataset if specified. 
- """ - return None - - -class SimpleBuilder(DatasetBuilder): - def __init__(self, lazy, read_func): - self._read = read_func - self.lazy = lazy - - def read(self, **kwargs): - if self.lazy: - - def generate_examples(): - generator = self._read(**kwargs) - for example in generator: - yield example - - return IterDataset(generate_examples) - else: - examples = self._read(**kwargs) - if hasattr(examples, '__len__') and hasattr(examples, - '__getitem__'): - return MapDataset(examples) - else: - return MapDataset(list(examples)) diff --git a/paddlenlp/datasets/experimental/glue.py b/paddlenlp/datasets/experimental/glue.py deleted file mode 100644 index c1e7aa27fba13..0000000000000 --- a/paddlenlp/datasets/experimental/glue.py +++ /dev/null @@ -1,322 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import json -import os - -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . import DatasetBuilder - - -class Glue(DatasetBuilder): - BUILDER_CONFIGS = { - 'cola': { - 'url': "https://dataset.bj.bcebos.com/glue/CoLA.zip", - 'md5': 'b178a7c2f397b0433c39c7caf50a3543', - 'splits': { - 'train': [ - os.path.join('CoLA', 'train.tsv'), - 'c79d4693b8681800338aa044bf9e797b', (3, 1), 0 - ], - 'dev': [ - os.path.join('CoLA', 'dev.tsv'), - 'c5475ccefc9e7ca0917294b8bbda783c', (3, 1), 0 - ], - 'test': [ - os.path.join('CoLA', 'test.tsv'), - 'd8721b7dedda0dcca73cebb2a9f4259f', (1, ), 1 - ] - }, - 'labels': ["0", "1"] - }, - 'sst-2': { - 'url': "https://dataset.bj.bcebos.com/glue/SST.zip", - 'md5': '9f81648d4199384278b86e315dac217c', - 'splits': { - 'train': [ - os.path.join('SST-2', 'train.tsv'), - 'da409a0a939379ed32a470bc0f7fe99a', (0, 1), 1 - ], - 'dev': [ - os.path.join('SST-2', 'dev.tsv'), - '268856b487b2a31a28c0a93daaff7288', (0, 1), 1 - ], - 'test': [ - os.path.join('SST-2', 'test.tsv'), - '3230e4efec76488b87877a56ae49675a', (1, ), 1 - ] - }, - 'labels': ["0", "1"] - }, - 'sts-b': { - 'url': 'https://dataset.bj.bcebos.com/glue/STS.zip', - 'md5': 'd573676be38f1a075a5702b90ceab3de', - 'splits': { - 'train': [ - os.path.join('STS-B', 'train.tsv'), - '4f7a86dde15fe4832c18e5b970998672', (7, 8, 9), 1 - ], - 'dev': [ - os.path.join('STS-B', 'dev.tsv'), - '5f4d6b0d2a5f268b1b56db773ab2f1fe', (7, 8, 9), 1 - ], - 'test': [ - os.path.join('STS-B', 'test.tsv'), - '339b5817e414d19d9bb5f593dd94249c', (7, 8), 1 - ] - }, - 'labels': None - }, - 'qqp': { - 'url': 'https://dataset.bj.bcebos.com/glue/QQP.zip', - 'md5': '884bf26e39c783d757acc510a2a516ef', - 'splits': { - 'train': [ - os.path.join('QQP', 'train.tsv'), - 'e003db73d277d38bbd83a2ef15beb442', (3, 4, 5), 1 - ], - 'dev': [ - os.path.join('QQP', 'dev.tsv'), - 'cff6a448d1580132367c22fc449ec214', (3, 4, 5), 1 - ], - 'test': [ - os.path.join('QQP', 'test.tsv'), - '73de726db186b1b08f071364b2bb96d0', (1, 2), 1 - ] - }, - 'labels': ["0", "1"] - }, - 'mnli': { - 'url': 
'https://dataset.bj.bcebos.com/glue/MNLI.zip', - 'md5': 'e343b4bdf53f927436d0792203b9b9ff', - 'splits': { - 'train': [ - os.path.join('MNLI', 'train.tsv'), - '220192295e23b6705f3545168272c740', (8, 9, 11), 1 - ], - 'dev_matched': [ - os.path.join('MNLI', 'dev_matched.tsv'), - 'c3fa2817007f4cdf1a03663611a8ad23', (8, 9, 15), 1 - ], - 'dev_mismatched': [ - os.path.join('MNLI', 'dev_mismatched.tsv'), - 'b219e6fe74e4aa779e2f417ffe713053', (8, 9, 15), 1 - ], - 'test_matched': [ - os.path.join('MNLI', 'test_matched.tsv'), - '33ea0389aedda8a43dabc9b3579684d9', (8, 9), 1 - ], - 'test_mismatched': [ - os.path.join('MNLI', 'test_mismatched.tsv'), - '7d2f60a73d54f30d8a65e474b615aeb6', (8, 9), 1 - ] - }, - 'labels': ["contradiction", "entailment", "neutral"] - }, - 'qnli': { - 'url': 'https://dataset.bj.bcebos.com/glue/QNLI.zip', - 'md5': 'b4efd6554440de1712e9b54e14760e82', - 'splits': { - 'train': [ - os.path.join('QNLI', 'train.tsv'), - '5e6063f407b08d1f7c7074d049ace94a', (1, 2, 3), 1 - ], - 'dev': [ - os.path.join('QNLI', 'dev.tsv'), - '1e81e211959605f144ba6c0ad7dc948b', (1, 2, 3), 1 - ], - 'test': [ - os.path.join('QNLI', 'test.tsv'), - 'f2a29f83f3fe1a9c049777822b7fa8b0', (1, 2), 1 - ] - }, - 'labels': ["entailment", "not_entailment"] - }, - 'rte': { - 'url': 'https://dataset.bj.bcebos.com/glue/RTE.zip', - 'md5': 'bef554d0cafd4ab6743488101c638539', - 'splits': { - 'train': [ - os.path.join('RTE', 'train.tsv'), - 'd2844f558d111a16503144bb37a8165f', (1, 2, 3), 1 - ], - 'dev': [ - os.path.join('RTE', 'dev.tsv'), - '973cb4178d4534cf745a01c309d4a66c', (1, 2, 3), 1 - ], - 'test': [ - os.path.join('RTE', 'test.tsv'), - '6041008f3f3e48704f57ce1b88ad2e74', (1, 2), 1 - ] - }, - 'labels': ["entailment", "not_entailment"] - }, - 'wnli': { - 'url': 'https://dataset.bj.bcebos.com/glue/WNLI.zip', - 'md5': 'a1b4bd2861017d302d29e42139657a42', - 'splits': { - 'train': [ - os.path.join('WNLI', 'train.tsv'), - '5cdc5a87b7be0c87a6363fa6a5481fc1', (1, 2, 3), 1 - ], - 'dev': [ - os.path.join('WNLI', 'dev.tsv'), - 'a79a6dd5d71287bcad6824c892e517ee', (1, 2, 3), 1 - ], - 'test': [ - os.path.join('WNLI', 'test.tsv'), - 'a18789ba4f60f6fdc8cb4237e4ba24b5', (1, 2), 1 - ] - }, - 'labels': ["0", "1"] - }, - 'mrpc': { - 'url': { - 'train_data': - 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_train.txt', - 'dev_id': 'https://dataset.bj.bcebos.com/glue/mrpc/dev_ids.tsv', - 'test_data': - 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_test.txt' - }, - 'md5': { - 'train_data': '793daf7b6224281e75fe61c1f80afe35', - 'dev_id': '7ab59a1b04bd7cb773f98a0717106c9b', - 'test_data': 'e437fdddb92535b820fe8852e2df8a49' - }, - 'splits': { - 'train': [ - os.path.join('MRPC', 'train.tsv'), - 'dc2dac669a113866a6480a0b10cd50bf', (3, 4, 0), 1 - ], - 'dev': [ - os.path.join('MRPC', 'dev.tsv'), - '185958e46ba556b38c6a7cc63f3a2135', (3, 4, 0), 1 - ], - 'test': [ - os.path.join('MRPC', 'test.tsv'), - '4825dab4b4832f81455719660b608de5', (3, 4), 1 - ] - }, - 'labels': ["0", "1"] - } - } - - def _get_data(self, mode, **kwargs): - builder_config = self.BUILDER_CONFIGS[self.name] - if self.name != 'mrpc': - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, data_hash, _, _ = builder_config['splits'][mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or ( - data_hash and not md5file(fullname) == data_hash): - get_path_from_url(builder_config['url'], default_root, - builder_config['md5']) - - else: - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, 
data_hash, _, _ = builder_config['splits'][mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or ( - data_hash and not md5file(fullname) == data_hash): - if mode in ('train', 'dev'): - dev_id_path = get_path_from_url( - builder_config['url']['dev_id'], - os.path.join(default_root, 'MRPC'), - builder_config['md5']['dev_id']) - train_data_path = get_path_from_url( - builder_config['url']['train_data'], - os.path.join(default_root, 'MRPC'), - builder_config['md5']['train_data']) - # read dev data ids - dev_ids = [] - print(dev_id_path) - with open(dev_id_path, encoding='utf-8') as ids_fh: - for row in ids_fh: - dev_ids.append(row.strip().split('\t')) - - # generate train and dev set - train_path = os.path.join(default_root, 'MRPC', 'train.tsv') - dev_path = os.path.join(default_root, 'MRPC', 'dev.tsv') - with open(train_data_path, encoding='utf-8') as data_fh: - with open( - train_path, 'w', encoding='utf-8') as train_fh: - with open(dev_path, 'w', encoding='utf8') as dev_fh: - header = data_fh.readline() - train_fh.write(header) - dev_fh.write(header) - for row in data_fh: - label, id1, id2, s1, s2 = row.strip().split( - '\t') - example = '%s\t%s\t%s\t%s\t%s\n' % ( - label, id1, id2, s1, s2) - if [id1, id2] in dev_ids: - dev_fh.write(example) - else: - train_fh.write(example) - - else: - test_data_path = get_path_from_url( - builder_config['url']['test_data'], - os.path.join(default_root, 'MRPC'), - builder_config['md5']['test_data']) - test_path = os.path.join(default_root, 'MRPC', 'test.tsv') - with open(test_data_path, encoding='utf-8') as data_fh: - with open(test_path, 'w', encoding='utf-8') as test_fh: - header = data_fh.readline() - test_fh.write( - 'index\t#1 ID\t#2 ID\t#1 String\t#2 String\n') - for idx, row in enumerate(data_fh): - label, id1, id2, s1, s2 = row.strip().split( - '\t') - test_fh.write('%d\t%s\t%s\t%s\t%s\n' % - (idx, id1, id2, s1, s2)) - - return fullname - - def _read(self, filename, split): - _, _, field_indices, num_discard_samples = self.BUILDER_CONFIGS[ - self.name]['splits'][split] - with open(filename, 'r', encoding='utf-8') as f: - for idx, line in enumerate(f): - if idx < num_discard_samples: - continue - line_stripped = line.strip().split('\t') - if not line_stripped: - continue - example = [line_stripped[indice] for indice in field_indices] - if self.name in ['cola', 'sst-2']: - yield { - 'sentence': example[0] - } if 'test' in split else { - 'sentence': example[0], - 'labels': example[-1] - } - else: - yield { - 'sentence1': example[0], - 'sentence2': example[1] - } if 'test' in split else { - 'sentence1': example[0], - 'sentence2': example[1], - 'labels': example[-1] - } - - def get_labels(self): - """ - Return labels of the Glue task. - """ - return self.BUILDER_CONFIGS[self.name]['labels'] diff --git a/paddlenlp/datasets/experimental/imdb.py b/paddlenlp/datasets/experimental/imdb.py deleted file mode 100644 index 89c4e04b83acb..0000000000000 --- a/paddlenlp/datasets/experimental/imdb.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
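The `Glue` builder above is driven through `load_dataset`: `name` selects one of the
BUILDER_CONFIGS entries and `splits` the files to read, while the `MapDataset` returned by
`read()` exposes `label_list` and the `map()`/`filter()`/`shard()` helpers defined earlier.
A minimal sketch, assuming the builder stays registered as 'glue' after the move out of
`experimental`:

    from paddlenlp.datasets import load_dataset

    # name picks the sub-task config; a list of splits yields one dataset per split.
    train_ds, dev_ds = load_dataset('glue', name='sst-2', splits=['train', 'dev'])

    print(train_ds.label_list)  # ["0", "1"], taken from get_labels()
    print(train_ds[0])          # {'sentence': ..., 'labels': <label id>}

    # map() appends a transform to the pipeline; it is applied lazily by default.
    train_ds.map(lambda example: {
        'sentence': example['sentence'].lower(),
        'labels': example['labels'],
    })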
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import json -import io -import os -import string - -import numpy as np - -from paddle.dataset.common import md5file -from paddlenlp.utils.downloader import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . import DatasetBuilder - -__all__ = ['Imdb'] - - -class Imdb(DatasetBuilder): - """ - Implementation of `IMDB `_ dataset. - - """ - URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz' - MD5 = '7c2ac02c03563afcf9b574c7e56c153a' - - def _get_data(self, mode, **kwargs): - """Downloads dataset.""" - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - data_dir = os.path.join(default_root, "aclImdb", mode) - if not os.path.exists(data_dir): - path = get_path_from_url(self.URL, default_root, self.MD5) - return data_dir - - def _read(self, data_dir, *args): - translator = str.maketrans('', '', string.punctuation) - - for label in ["pos", "neg"]: - root = os.path.join(data_dir, label) - data_files = os.listdir(root) - data_files.sort() - - if label == "pos": - label_id = "1" - elif label == "neg": - label_id = "0" - for f in data_files: - f = os.path.join(root, f) - with io.open(f, 'r', encoding='utf8') as fr: - data = fr.readlines() - data = data[0].translate(translator) - yield {"text": data, "label": label_id} - - def get_labels(self): - """ - Return labels of the Imdb object. - """ - return ["0", "1"] diff --git a/paddlenlp/datasets/experimental/lcqmc.py b/paddlenlp/datasets/experimental/lcqmc.py deleted file mode 100644 index ceacf6583e608..0000000000000 --- a/paddlenlp/datasets/experimental/lcqmc.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import json -import os - -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . 
import DatasetBuilder - -__all__ = ['LCQMC'] - - -class LCQMC(DatasetBuilder): - """ - LCQMC:A Large-scale Chinese Question Matching Corpus - More information please refer to `https://www.aclweb.org/anthology/C18-1166/` - - """ - - URL = "https://bj.bcebos.com/paddlehub-dataset/lcqmc.tar.gz" - MD5 = "62a7ba36f786a82ae59bbde0b0a9af0c" - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - SPLITS = { - 'train': META_INFO( - os.path.join('lcqmc', 'train.tsv'), - '2193c022439b038ac12c0ae918b211a1'), - 'dev': META_INFO( - os.path.join('lcqmc', 'dev.tsv'), - 'c5dcba253cb4105d914964fd8b3c0e94'), - 'test': META_INFO( - os.path.join('lcqmc', 'test.tsv'), - '8f4b71e15e67696cc9e112a459ec42bd'), - } - - def _get_data(self, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, data_hash = self.SPLITS[mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - get_path_from_url(self.URL, default_root, self.MD5) - - return fullname - - def _read(self, filename): - """Reads data.""" - with open(filename, 'r', encoding='utf-8') as f: - head = None - for line in f: - data = line.strip().split("\t") - if not head: - head = data - else: - query, title, label = data - yield {"query": query, "title": title, "label": label} - - def get_labels(self): - """ - Return labels of the LCQMC object. - """ - return ["0", "1"] diff --git a/paddlenlp/datasets/experimental/msra_ner.py b/paddlenlp/datasets/experimental/msra_ner.py deleted file mode 100644 index 95d083a8245eb..0000000000000 --- a/paddlenlp/datasets/experimental/msra_ner.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import os -import warnings - -from paddle.io import Dataset -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . 
import DatasetBuilder - -__all__ = ['MSRA_NER'] - - -class MSRA_NER(DatasetBuilder): - URL = "https://paddlenlp.bj.bcebos.com/datasets/msra_ner.tar.gz" - MD5 = None - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - SPLITS = { - 'train': META_INFO(os.path.join('msra_ner', 'train.tsv'), None), - 'test': META_INFO(os.path.join('msra_ner', 'test.tsv'), None) - } - - def _get_data(self, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, data_hash = self.SPLITS[mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - - get_path_from_url(self.URL, default_root, self.MD5) - - return fullname - - def _read(self, filename, *args): - with open(filename, 'r', encoding='utf-8') as f: - for line in f: - line_stripped = line.strip().split('\t') - if not line_stripped: - break - if len(line_stripped) == 2: - tokens = line_stripped[0].split("\002") - tags = line_stripped[1].split("\002") - else: - tokens = line_stripped.split("\002") - tags = [] - yield {"tokens": tokens, "labels": tags} - - def get_labels(self): - - return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] diff --git a/paddlenlp/datasets/experimental/peoples_daily_ner.py b/paddlenlp/datasets/experimental/peoples_daily_ner.py deleted file mode 100644 index c891f03509cfd..0000000000000 --- a/paddlenlp/datasets/experimental/peoples_daily_ner.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import os -import warnings - -from paddle.io import Dataset -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . 
import DatasetBuilder - -__all__ = ['PeoplesDailyNER'] - - -class PeoplesDailyNER(DatasetBuilder): - URL = "https://paddlenlp.bj.bcebos.com/datasets/peoples_daily_ner.tar.gz" - MD5 = None - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - SPLITS = { - 'train': - META_INFO(os.path.join('peoples_daily_ner', 'train.tsv'), None), - 'test': META_INFO(os.path.join('peoples_daily_ner', 'test.tsv'), None) - } - - def _get_data(self, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, data_hash = self.SPLITS[mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - get_path_from_url(self.URL, default_root, self.MD5) - - return fullname - - def _read(self, filename, *args): - with open(filename, 'r', encoding='utf-8') as f: - for line in f: - line_stripped = line.strip().split('\t') - if not line_stripped: - break - if len(line_stripped) == 2: - tokens = line_stripped[0].split("\002") - tags = line_stripped[1].split("\002") - else: - tokens = line_stripped.split("\002") - tags = [] - yield {"tokens": tokens, "labels": tags} - - def get_labels(self): - - return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] diff --git a/paddlenlp/datasets/experimental/poetry.py b/paddlenlp/datasets/experimental/poetry.py deleted file mode 100644 index 53830ddac7072..0000000000000 --- a/paddlenlp/datasets/experimental/poetry.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import os -import warnings - -from paddle.io import Dataset -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . 
import DatasetBuilder - -__all__ = ['Poetry'] - - -class Poetry(DatasetBuilder): - URL = "https://paddlenlp.bj.bcebos.com/datasets/poetry.tar.gz" - MD5 = '8edd7eda1b273145b70ef29c82cd622b' - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - SPLITS = { - 'train': META_INFO( - os.path.join('poetry', 'train.tsv'), - '176c6202b5e71656ae7e7848eec4c54f'), - 'dev': META_INFO( - os.path.join('poetry', 'dev.tsv'), - '737e4b6da5facdc0ac33fe688df19931'), - 'test': META_INFO( - os.path.join('poetry', 'test.tsv'), - '1dca907b2d712730c7c828f8acee7431'), - } - - def _get_data(self, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, data_hash = self.SPLITS[mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - - get_path_from_url(self.URL, default_root, self.MD5) - - return fullname - - def _read(self, filename, *args): - with open(filename, 'r', encoding='utf-8') as f: - for line in f: - line_stripped = line.strip().split('\t') - if not line_stripped: - break - if len(line_stripped) == 2: - tokens = line_stripped[0] - labels = line_stripped[1] - else: - tokens = line_stripped - labels = [] - yield {"tokens": tokens, "labels": labels} \ No newline at end of file diff --git a/paddlenlp/datasets/experimental/ptb.py b/paddlenlp/datasets/experimental/ptb.py deleted file mode 100644 index 0ab384f4211fc..0000000000000 --- a/paddlenlp/datasets/experimental/ptb.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import collections - -from paddle.io import Dataset - -from paddle.utils.download import get_path_from_url -from paddle.dataset.common import md5file -from paddlenlp.utils.env import DATA_HOME -from . 
import DatasetBuilder - -__all__ = ['PTB'] - - -class PTB(DatasetBuilder): - URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' - MD5 = "30177ea32e27c525793142b6bf2c8e2d" - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - SPLITS = { - 'train': META_INFO( - os.path.join('simple-examples', 'data', 'ptb.train.txt'), - "f26c4b92c5fdc7b3f8c7cdcb991d8420"), - 'valid': META_INFO( - os.path.join('simple-examples', 'data', 'ptb.valid.txt'), - "aa0affc06ff7c36e977d7cd49e3839bf"), - 'test': META_INFO( - os.path.join('simple-examples', 'data', 'ptb.test.txt'), - "8b80168b89c18661a38ef683c0dc3721") - } - - def _get_data(self, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, data_hash = self.SPLITS[mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - - get_path_from_url(self.URL, default_root, self.MD5) - - return fullname - - def _read(self, filename, *args): - with open(filename, 'r', encoding='utf-8') as f: - for line in f: - line_stripped = line.strip() - yield {"sentence": line_stripped} diff --git a/paddlenlp/datasets/experimental/squad.py b/paddlenlp/datasets/experimental/squad.py deleted file mode 100644 index bbc8a3495787e..0000000000000 --- a/paddlenlp/datasets/experimental/squad.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import json -import os - -from paddle.dataset.common import md5file -from paddle.utils.download import get_path_from_url -from paddlenlp.utils.env import DATA_HOME -from . 
import DatasetBuilder - -__all__ = ['SQuAD'] - - -class SQuAD(DatasetBuilder): - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5', 'URL')) - SPLITS = { - 'train_v1': META_INFO( - os.path.join('train-v1.1.json'), '981b29407e0affa3b1b156f72073b945', - 'https://paddlenlp.bj.bcebos.com/datasets/squad/train-v1.1.json'), - 'dev_v1': META_INFO( - os.path.join('dev-v1.1.json'), '3e85deb501d4e538b6bc56f786231552', - 'https://paddlenlp.bj.bcebos.com/datasets/squad/dev-v1.1.json'), - 'train_v2': META_INFO( - os.path.join('train-v2.0.json'), '62108c273c268d70893182d5cf8df740', - 'https://paddlenlp.bj.bcebos.com/datasets/squad/train-v2.0.json'), - 'dev_v2': META_INFO( - os.path.join('dev-v2.0.json'), '246adae8b7002f8679c027697b0b7cf8', - 'https://paddlenlp.bj.bcebos.com/datasets/squad/dev-v2.0.json') - } - - def _get_data(self, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - filename, data_hash, URL = self.SPLITS[mode] - fullname = os.path.join(default_root, filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - get_path_from_url(URL, default_root) - - return fullname - - def _read(self, filename, *args): - with open(filename, "r", encoding="utf8") as f: - input_data = json.load(f)["data"] - for entry in input_data: - title = entry.get("title", "").strip() - for paragraph in entry["paragraphs"]: - context = paragraph["context"].strip() - for qa in paragraph["qas"]: - qas_id = qa["id"] - question = qa["question"].strip() - answer_starts = [] - answers = [] - is_impossible = False - - if "is_impossible" in qa.keys(): - is_impossible = qa["is_impossible"] - - answer_starts = [ - answer["answer_start"] for answer in qa["answers"] - ] - answers = [ - answer["text"].strip() for answer in qa["answers"] - ] - - yield { - 'id': qas_id, - 'title': title, - 'context': context, - 'question': question, - 'answers': answers, - 'answer_starts': answer_starts, - 'is_impossible': is_impossible - } diff --git a/paddlenlp/datasets/glue.py b/paddlenlp/datasets/glue.py index 9d099e59f31f2..c1e7aa27fba13 100644 --- a/paddlenlp/datasets/glue.py +++ b/paddlenlp/datasets/glue.py @@ -12,599 +12,311 @@ # See the License for the specific language governing permissions and # limitations under the License. 
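`SQuAD._read()` above flattens the nested JSON into one dict per question, so downstream code
sees flat fields rather than the paragraphs/qas hierarchy. A minimal sketch of loading it
through `load_dataset`, assuming the builder is exposed as 'squad' with the split keys from
its SPLITS table (train_v1/dev_v1 for SQuAD 1.1):

    from paddlenlp.datasets import load_dataset

    train_ds, dev_ds = load_dataset('squad', splits=('train_v1', 'dev_v1'))

    example = dev_ds[0]
    # Keys mirror the dict yielded by _read(): id, title, context, question,
    # answers, answer_starts, is_impossible.
    print(example['question'])
    print(example['answers'], example['answer_starts'])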
-import copy import collections -import io +import json import os -import warnings -from paddle.io import Dataset from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url from paddlenlp.utils.env import DATA_HOME - -from .dataset import TSVDataset - -__all__ = [ - 'GlueCoLA', - 'GlueSST2', - 'GlueMRPC', - 'GlueSTSB', - 'GlueQQP', - 'GlueMNLI', - 'GlueQNLI', - 'GlueRTE', - 'GlueWNLI', -] - - -class _GlueDataset(TSVDataset): - URL = None - MD5 = None - META_INFO = collections.namedtuple( - 'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples')) - SPLITS = {} # mode: file, md5, field_indices, num_discard_samples - - def __init__(self, - mode='train', - root=None, - return_all_fields=False, - **kwargs): - if return_all_fields: - # self.SPLITS = copy.deepcopy(self.__class__.SPLITS) - # self.SPLITS[mode].field_indices = splits - splits = copy.deepcopy(self.__class__.SPLITS) - mode_info = list(splits[mode]) - mode_info[2] = None - splits[mode] = self.META_INFO(*mode_info) - self.SPLITS = splits - - self._get_data(root, mode, **kwargs) - - def _get_data(self, root, mode, **kwargs): - default_root = os.path.join(DATA_HOME, 'glue') - filename, data_hash, field_indices, num_discard_samples = self.SPLITS[ - mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - path = get_path_from_url(self.URL, default_root, self.MD5) - fullname = os.path.join(default_root, filename) - super(_GlueDataset, self).__init__( - fullname, - field_indices=field_indices, - num_discard_samples=num_discard_samples, - **kwargs) - - -class GlueCoLA(_GlueDataset): - """ - The Corpus of Linguistic Acceptability (Warstadt et al., 2018) consists of - English acceptability judgments drawn from books and journal articles on - linguistic theory. - Each example is a sequence of words annotated with whether it is a - grammatical English sentence. From https://gluebenchmark.com/tasks - Args: - mode ('train'|'dev'|'test'): Dataset segment. Default: 'train'. - root (str): Path to temp folder for storing data. - return_all_fields (bool): Return all fields available in the dataset. - Default: False. - Example: - .. code-block:: python - - from paddlenlp.datasets import GlueCoLA - cola_dev = GlueCoLA('dev', root='./datasets/cola') - len(cola_dev) # 1043 - len(cola_dev[0]) # 2 - # ['The sailors rode the breeze clear of the rocks.', '1'] - cola_dev[0] - cola_test = GlueCoLA('test', root='./datasets/cola') - len(cola_test) # 1063 - len(cola_test[0]) # 1 - cola_test[0] # ['Bill whistled past the house.'] - """ - URL = "https://dataset.bj.bcebos.com/glue/CoLA.zip" - MD5 = 'b178a7c2f397b0433c39c7caf50a3543' - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('CoLA', 'train.tsv'), - 'c79d4693b8681800338aa044bf9e797b', (3, 1), 0), - 'dev': _GlueDataset.META_INFO( - os.path.join('CoLA', 'dev.tsv'), 'c5475ccefc9e7ca0917294b8bbda783c', - (3, 1), 0), - 'test': _GlueDataset.META_INFO( - os.path.join('CoLA', 'test.tsv'), - 'd8721b7dedda0dcca73cebb2a9f4259f', (1, ), 1) - } - - def get_labels(self): - """ - Return labels of the GlueCoLA object. 
- """ - return ["0", "1"] - - -class GlueSST2(_GlueDataset): - """ - The Stanford Sentiment Treebank (Socher et al., 2013) consists of sentences - from movie reviews and human annotations of their sentiment. - From https://gluebenchmark.com/tasks - Args: - mode ('train'|'dev'|'test'): Dataset segment. Default: 'train'. - root (str): Path to temp folder for storing data. - return_all_fields (bool): Return all fields available in the dataset. - Default: False. - Examples: - .. code-block:: python - - from paddlenlp.datasets import GlueSST2 - sst_dev = GlueSST2('dev', root='./datasets/sst') - len(sst_dev) # 872 - len(sst_dev[0]) # 2 - # ["it 's a charming and often affecting journey . ", '1'] - sst_dev[0] - sst_test = GlueSST2('test', root='./datasets/sst') - len(sst_test) # 1821 - len(sst_test[0]) # 1 - sst_test[0] # ['uneasy mishmash of styles and genres .'] - """ - - URL = 'https://dataset.bj.bcebos.com/glue/SST.zip' - MD5 = '9f81648d4199384278b86e315dac217c' - - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('SST-2', 'train.tsv'), - 'da409a0a939379ed32a470bc0f7fe99a', (0, 1), 1), - 'dev': _GlueDataset.META_INFO( - os.path.join('SST-2', 'dev.tsv'), - '268856b487b2a31a28c0a93daaff7288', (0, 1), 1), - 'test': _GlueDataset.META_INFO( - os.path.join('SST-2', 'test.tsv'), - '3230e4efec76488b87877a56ae49675a', (1, ), 1) - } - - def get_labels(self): - """ - Return labels of the GlueSST2 object. - """ - return ["0", "1"] - - -class GlueMRPC(_GlueDataset): - """ - The Microsoft Research Paraphrase Corpus dataset. - From https://gluebenchmark.com/tasks - Args: - root (str): Path to temp folder for storing data. - mode ('train'|'dev'|'test'): Dataset segment. Default: 'train'. - Example: - .. code-block:: python - - from paddlenlp.datasets import GlueMRPC - mrpc_dev = GlueMRPC('dev', root='./datasets/mrpc') - len(mrpc_dev) # 408 - len(mrpc_dev[0]) # 3 - mrpc_dev[0] # ["He said the foodservice pie business doesn 't fit - # the company 's long-term growth strategy .", - # '" The foodservice pie business does not fit our - # long-term growth strategy .', '1'] - mrpc_test = GlueMRPC('test', root='./datasets/mrpc') - len(mrpc_test) # 1725 - len(mrpc_test[0]) # 2 - mrpc_test[0] - # ["PCCW 's chief operating officer , Mike Butcher , and Alex Arena , - # the chief financial officer , will report directly to Mr So .", - # 'Current Chief Operating Officer Mike Butcher and Group Chief - # Financial Officer Alex Arena will report to So .'] - """ - - DEV_ID_URL = 'https://dataset.bj.bcebos.com/glue/mrpc/dev_ids.tsv' - DEV_ID_MD5 = '7ab59a1b04bd7cb773f98a0717106c9b' - TRAIN_DATA_URL = 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_train.txt' - TRAIN_DATA_MD5 = '793daf7b6224281e75fe61c1f80afe35' - TEST_DATA_URL = 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_test.txt' - TEST_DATA_MD5 = 'e437fdddb92535b820fe8852e2df8a49' - - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('MRPC', 'train.tsv'), - 'dc2dac669a113866a6480a0b10cd50bf', (3, 4, 0), 1), - 'dev': _GlueDataset.META_INFO( - os.path.join('MRPC', 'dev.tsv'), '185958e46ba556b38c6a7cc63f3a2135', - (3, 4, 0), 1), - 'test': _GlueDataset.META_INFO( - os.path.join('MRPC', 'test.tsv'), - '4825dab4b4832f81455719660b608de5', (3, 4), 1) +from . 
import DatasetBuilder + + +class Glue(DatasetBuilder): + BUILDER_CONFIGS = { + 'cola': { + 'url': "https://dataset.bj.bcebos.com/glue/CoLA.zip", + 'md5': 'b178a7c2f397b0433c39c7caf50a3543', + 'splits': { + 'train': [ + os.path.join('CoLA', 'train.tsv'), + 'c79d4693b8681800338aa044bf9e797b', (3, 1), 0 + ], + 'dev': [ + os.path.join('CoLA', 'dev.tsv'), + 'c5475ccefc9e7ca0917294b8bbda783c', (3, 1), 0 + ], + 'test': [ + os.path.join('CoLA', 'test.tsv'), + 'd8721b7dedda0dcca73cebb2a9f4259f', (1, ), 1 + ] + }, + 'labels': ["0", "1"] + }, + 'sst-2': { + 'url': "https://dataset.bj.bcebos.com/glue/SST.zip", + 'md5': '9f81648d4199384278b86e315dac217c', + 'splits': { + 'train': [ + os.path.join('SST-2', 'train.tsv'), + 'da409a0a939379ed32a470bc0f7fe99a', (0, 1), 1 + ], + 'dev': [ + os.path.join('SST-2', 'dev.tsv'), + '268856b487b2a31a28c0a93daaff7288', (0, 1), 1 + ], + 'test': [ + os.path.join('SST-2', 'test.tsv'), + '3230e4efec76488b87877a56ae49675a', (1, ), 1 + ] + }, + 'labels': ["0", "1"] + }, + 'sts-b': { + 'url': 'https://dataset.bj.bcebos.com/glue/STS.zip', + 'md5': 'd573676be38f1a075a5702b90ceab3de', + 'splits': { + 'train': [ + os.path.join('STS-B', 'train.tsv'), + '4f7a86dde15fe4832c18e5b970998672', (7, 8, 9), 1 + ], + 'dev': [ + os.path.join('STS-B', 'dev.tsv'), + '5f4d6b0d2a5f268b1b56db773ab2f1fe', (7, 8, 9), 1 + ], + 'test': [ + os.path.join('STS-B', 'test.tsv'), + '339b5817e414d19d9bb5f593dd94249c', (7, 8), 1 + ] + }, + 'labels': None + }, + 'qqp': { + 'url': 'https://dataset.bj.bcebos.com/glue/QQP.zip', + 'md5': '884bf26e39c783d757acc510a2a516ef', + 'splits': { + 'train': [ + os.path.join('QQP', 'train.tsv'), + 'e003db73d277d38bbd83a2ef15beb442', (3, 4, 5), 1 + ], + 'dev': [ + os.path.join('QQP', 'dev.tsv'), + 'cff6a448d1580132367c22fc449ec214', (3, 4, 5), 1 + ], + 'test': [ + os.path.join('QQP', 'test.tsv'), + '73de726db186b1b08f071364b2bb96d0', (1, 2), 1 + ] + }, + 'labels': ["0", "1"] + }, + 'mnli': { + 'url': 'https://dataset.bj.bcebos.com/glue/MNLI.zip', + 'md5': 'e343b4bdf53f927436d0792203b9b9ff', + 'splits': { + 'train': [ + os.path.join('MNLI', 'train.tsv'), + '220192295e23b6705f3545168272c740', (8, 9, 11), 1 + ], + 'dev_matched': [ + os.path.join('MNLI', 'dev_matched.tsv'), + 'c3fa2817007f4cdf1a03663611a8ad23', (8, 9, 15), 1 + ], + 'dev_mismatched': [ + os.path.join('MNLI', 'dev_mismatched.tsv'), + 'b219e6fe74e4aa779e2f417ffe713053', (8, 9, 15), 1 + ], + 'test_matched': [ + os.path.join('MNLI', 'test_matched.tsv'), + '33ea0389aedda8a43dabc9b3579684d9', (8, 9), 1 + ], + 'test_mismatched': [ + os.path.join('MNLI', 'test_mismatched.tsv'), + '7d2f60a73d54f30d8a65e474b615aeb6', (8, 9), 1 + ] + }, + 'labels': ["contradiction", "entailment", "neutral"] + }, + 'qnli': { + 'url': 'https://dataset.bj.bcebos.com/glue/QNLI.zip', + 'md5': 'b4efd6554440de1712e9b54e14760e82', + 'splits': { + 'train': [ + os.path.join('QNLI', 'train.tsv'), + '5e6063f407b08d1f7c7074d049ace94a', (1, 2, 3), 1 + ], + 'dev': [ + os.path.join('QNLI', 'dev.tsv'), + '1e81e211959605f144ba6c0ad7dc948b', (1, 2, 3), 1 + ], + 'test': [ + os.path.join('QNLI', 'test.tsv'), + 'f2a29f83f3fe1a9c049777822b7fa8b0', (1, 2), 1 + ] + }, + 'labels': ["entailment", "not_entailment"] + }, + 'rte': { + 'url': 'https://dataset.bj.bcebos.com/glue/RTE.zip', + 'md5': 'bef554d0cafd4ab6743488101c638539', + 'splits': { + 'train': [ + os.path.join('RTE', 'train.tsv'), + 'd2844f558d111a16503144bb37a8165f', (1, 2, 3), 1 + ], + 'dev': [ + os.path.join('RTE', 'dev.tsv'), + '973cb4178d4534cf745a01c309d4a66c', (1, 2, 3), 1 + ], + 'test': [ + 
os.path.join('RTE', 'test.tsv'), + '6041008f3f3e48704f57ce1b88ad2e74', (1, 2), 1 + ] + }, + 'labels': ["entailment", "not_entailment"] + }, + 'wnli': { + 'url': 'https://dataset.bj.bcebos.com/glue/WNLI.zip', + 'md5': 'a1b4bd2861017d302d29e42139657a42', + 'splits': { + 'train': [ + os.path.join('WNLI', 'train.tsv'), + '5cdc5a87b7be0c87a6363fa6a5481fc1', (1, 2, 3), 1 + ], + 'dev': [ + os.path.join('WNLI', 'dev.tsv'), + 'a79a6dd5d71287bcad6824c892e517ee', (1, 2, 3), 1 + ], + 'test': [ + os.path.join('WNLI', 'test.tsv'), + 'a18789ba4f60f6fdc8cb4237e4ba24b5', (1, 2), 1 + ] + }, + 'labels': ["0", "1"] + }, + 'mrpc': { + 'url': { + 'train_data': + 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_train.txt', + 'dev_id': 'https://dataset.bj.bcebos.com/glue/mrpc/dev_ids.tsv', + 'test_data': + 'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_test.txt' + }, + 'md5': { + 'train_data': '793daf7b6224281e75fe61c1f80afe35', + 'dev_id': '7ab59a1b04bd7cb773f98a0717106c9b', + 'test_data': 'e437fdddb92535b820fe8852e2df8a49' + }, + 'splits': { + 'train': [ + os.path.join('MRPC', 'train.tsv'), + 'dc2dac669a113866a6480a0b10cd50bf', (3, 4, 0), 1 + ], + 'dev': [ + os.path.join('MRPC', 'dev.tsv'), + '185958e46ba556b38c6a7cc63f3a2135', (3, 4, 0), 1 + ], + 'test': [ + os.path.join('MRPC', 'test.tsv'), + '4825dab4b4832f81455719660b608de5', (3, 4), 1 + ] + }, + 'labels': ["0", "1"] + } } - def _get_data(self, root, mode, **kwargs): - default_root = os.path.join(DATA_HOME, 'glue') - filename, data_hash, field_indices, num_discard_samples = self.SPLITS[ - mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - if mode in ('train', 'dev'): - dev_id_path = get_path_from_url( - self.DEV_ID_URL, - os.path.join(default_root, 'MRPC'), self.DEV_ID_MD5) - train_data_path = get_path_from_url( - self.TRAIN_DATA_URL, - os.path.join(default_root, 'MRPC'), self.TRAIN_DATA_MD5) - # read dev data ids - dev_ids = [] - with io.open(dev_id_path, encoding='utf-8') as ids_fh: - for row in ids_fh: - dev_ids.append(row.strip().split('\t')) - - # generate train and dev set - train_path = os.path.join(default_root, 'MRPC', 'train.tsv') - dev_path = os.path.join(default_root, 'MRPC', 'dev.tsv') - with io.open(train_data_path, encoding='utf-8') as data_fh: - with io.open(train_path, 'w', encoding='utf-8') as train_fh: - with io.open(dev_path, 'w', encoding='utf8') as dev_fh: + def _get_data(self, mode, **kwargs): + builder_config = self.BUILDER_CONFIGS[self.name] + if self.name != 'mrpc': + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash, _, _ = builder_config['splits'][mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or ( + data_hash and not md5file(fullname) == data_hash): + get_path_from_url(builder_config['url'], default_root, + builder_config['md5']) + + else: + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash, _, _ = builder_config['splits'][mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or ( + data_hash and not md5file(fullname) == data_hash): + if mode in ('train', 'dev'): + dev_id_path = get_path_from_url( + 
builder_config['url']['dev_id'], + os.path.join(default_root, 'MRPC'), + builder_config['md5']['dev_id']) + train_data_path = get_path_from_url( + builder_config['url']['train_data'], + os.path.join(default_root, 'MRPC'), + builder_config['md5']['train_data']) + # read dev data ids + dev_ids = [] + print(dev_id_path) + with open(dev_id_path, encoding='utf-8') as ids_fh: + for row in ids_fh: + dev_ids.append(row.strip().split('\t')) + + # generate train and dev set + train_path = os.path.join(default_root, 'MRPC', 'train.tsv') + dev_path = os.path.join(default_root, 'MRPC', 'dev.tsv') + with open(train_data_path, encoding='utf-8') as data_fh: + with open( + train_path, 'w', encoding='utf-8') as train_fh: + with open(dev_path, 'w', encoding='utf8') as dev_fh: + header = data_fh.readline() + train_fh.write(header) + dev_fh.write(header) + for row in data_fh: + label, id1, id2, s1, s2 = row.strip().split( + '\t') + example = '%s\t%s\t%s\t%s\t%s\n' % ( + label, id1, id2, s1, s2) + if [id1, id2] in dev_ids: + dev_fh.write(example) + else: + train_fh.write(example) + + else: + test_data_path = get_path_from_url( + builder_config['url']['test_data'], + os.path.join(default_root, 'MRPC'), + builder_config['md5']['test_data']) + test_path = os.path.join(default_root, 'MRPC', 'test.tsv') + with open(test_data_path, encoding='utf-8') as data_fh: + with open(test_path, 'w', encoding='utf-8') as test_fh: header = data_fh.readline() - train_fh.write(header) - dev_fh.write(header) - for row in data_fh: + test_fh.write( + 'index\t#1 ID\t#2 ID\t#1 String\t#2 String\n') + for idx, row in enumerate(data_fh): label, id1, id2, s1, s2 = row.strip().split( '\t') - example = '%s\t%s\t%s\t%s\t%s\n' % (label, id1, - id2, s1, s2) - if [id1, id2] in dev_ids: - dev_fh.write(example) - else: - train_fh.write(example) - else: - test_data_path = get_path_from_url( - self.TEST_DATA_URL, - os.path.join(default_root, 'MRPC'), self.TEST_DATA_MD5) - test_path = os.path.join(default_root, 'MRPC', 'test.tsv') - with io.open(test_data_path, encoding='utf-8') as data_fh: - with io.open(test_path, 'w', encoding='utf-8') as test_fh: - header = data_fh.readline() - test_fh.write( - 'index\t#1 ID\t#2 ID\t#1 String\t#2 String\n') - for idx, row in enumerate(data_fh): - label, id1, id2, s1, s2 = row.strip().split('\t') - test_fh.write('%d\t%s\t%s\t%s\t%s\n' % - (idx, id1, id2, s1, s2)) - root = default_root - super(GlueMRPC, self)._get_data(root, mode, **kwargs) - - def get_labels(self): - """ - Return labels of the GlueMRPC object. - """ - return ["0", "1"] - - -class GlueSTSB(_GlueDataset): - """ - The Semantic Textual Similarity Benchmark (Cer et al., 2017) is a - collection of sentence pairs drawn from news headlines, video and image - captions, and natural language inference data. Each pair is human-annotated - with a similarity score from 1 to 5. - From https://gluebenchmark.com/tasks - Args: - mode ('train'|'dev'|'test'): Dataset mode. Default: 'train'. - root (str): Path to temp folder for storing data. - return_all_fields (bool): Return all fields available in the dataset. Default: False. - Example: - .. 
code-block:: python - - from paddlenlp.datasets import GlueSTSB - stsb_dev = GlueSTSB('dev', root='./datasets/stsb') - len(stsb_dev) # 1500 - len(stsb_dev[0]) # 3 - stsb_dev[0] # ['A man with a hard hat is dancing.', 'A man wearing a hard hat is dancing.', '5.000'] - stsb_test = GlueSTSB('test', root='./datasets/stsb') - len(stsb_test) # 1379 - len(stsb_test[0]) # 2 - stsb_test[0] # ['A girl is styling her hair.', 'A girl is brushing her hair.'] - """ - URL = 'https://dataset.bj.bcebos.com/glue/STS.zip' - MD5 = 'd573676be38f1a075a5702b90ceab3de' - - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('STS-B', 'train.tsv'), - '4f7a86dde15fe4832c18e5b970998672', (7, 8, 9), 1), - 'dev': _GlueDataset.META_INFO( - os.path.join('STS-B', 'dev.tsv'), - '5f4d6b0d2a5f268b1b56db773ab2f1fe', (7, 8, 9), 1), - 'test': _GlueDataset.META_INFO( - os.path.join('STS-B', 'test.tsv'), - '339b5817e414d19d9bb5f593dd94249c', (7, 8), 1) - } - - def get_labels(self): - """ - Return labels of the GlueSTSB object. - """ - return None - - -class GlueQQP(_GlueDataset): - """ - The Quora Question Pairs dataset is a collection of question pairs from the - community question-answering website Quora. - From https://gluebenchmark.com/tasks - Args: - mode ({'train', 'dev', 'test'}): Dataset mode. Default: 'train'. - root (str): Path to temp folder for storing data. - return_all_fields (bool): Return all fields available in the dataset. - Default: False. - Example: - .. code-block:: python - - from paddlenlp.datasets import GlueQQP - import warnings - with warnings.catch_warnings(): - # Ignore warnings triggered by invalid entries in GlueQQP dev set - warnings.simplefilter("ignore") - qqp_dev = GlueQQP('dev', root='./datasets/qqp') - len(qqp_dev) # 40430 - len(qqp_dev[0]) # 3 - qqp_dev[0] # ['Why are African-Americans so beautiful?', - # 'Why are hispanics so beautiful?', '0'] - qqp_test = GlueQQP('test', root='./datasets/qqp') - len(qqp_test) # 390965 - len(qqp_test[3]) # 2 - qqp_test[3] # ['Is it safe to invest in social trade biz?', - # 'Is social trade geniune?'] - """ - URL = 'https://dataset.bj.bcebos.com/glue/QQP.zip' - MD5 = '884bf26e39c783d757acc510a2a516ef' - - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('QQP', 'train.tsv'), - 'e003db73d277d38bbd83a2ef15beb442', (3, 4, 5), 1), - 'dev': _GlueDataset.META_INFO( - os.path.join('QQP', 'dev.tsv'), 'cff6a448d1580132367c22fc449ec214', - (3, 4, 5), 1), - 'test': _GlueDataset.META_INFO( - os.path.join('QQP', 'test.tsv'), '73de726db186b1b08f071364b2bb96d0', - (1, 2), 1) - } - - def __init__(self, mode='train', root=None, return_all_fields=False): - # QQP may include broken samples - super(GlueQQP, self).__init__( - mode, root, return_all_fields, allow_missing=True) - - def get_labels(self): - """ - Return labels of the GlueQQP object. - """ - return ["0", "1"] - - -class GlueMNLI(_GlueDataset): - """ - The Multi-Genre Natural Language Inference Corpus (Williams et al., 2018) - is a crowdsourced collection of sentence pairs with textual entailment - annotations. - From https://gluebenchmark.com/tasks - Args: - mode ('train'|'dev_matched'|'dev_mismatched'|'test_matched'| - 'test_mismatched'): Dataset segment. Default: ‘train’. - root (str, default '$MXNET_HOME/datasets/glue_mnli'): Path to temp - folder for storing data. - return_all_fields (bool): Return all fields available in the dataset. - Default: False. - Example: - .. 
code-block:: python - - from paddlenlp.datasets import GlueMNLI - mnli_dev = GlueMNLI('dev_matched', root='./datasets/mnli') - len(mnli_dev) # 9815 - len(mnli_dev[0]) # 3 - mnli_dev[0] # ['The new rights are nice enough', - # 'Everyone really likes the newest benefits ', - # 'neutral'] - mnli_test = GlueMNLI('test_matched', root='./datasets/mnli') - len(mnli_test) # 9796 - len(mnli_test[0]) # 2 - mnli_test[0] # ['Hierbas, ans seco, ans dulce, and frigola are - # just a few names worth keeping a look-out for.', - # 'Hierbas is a name worth looking out for.'] - """ - URL = 'https://dataset.bj.bcebos.com/glue/MNLI.zip' - MD5 = 'e343b4bdf53f927436d0792203b9b9ff' - - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('MNLI', 'train.tsv'), - '220192295e23b6705f3545168272c740', (8, 9, 11), 1), - 'dev_matched': _GlueDataset.META_INFO( - os.path.join('MNLI', 'dev_matched.tsv'), - 'c3fa2817007f4cdf1a03663611a8ad23', (8, 9, 15), 1), - 'dev_mismatched': _GlueDataset.META_INFO( - os.path.join('MNLI', 'dev_mismatched.tsv'), - 'b219e6fe74e4aa779e2f417ffe713053', (8, 9, 15), 1), - 'test_matched': _GlueDataset.META_INFO( - os.path.join('MNLI', 'test_matched.tsv'), - '33ea0389aedda8a43dabc9b3579684d9', (8, 9), 1), - 'test_mismatched': _GlueDataset.META_INFO( - os.path.join('MNLI', 'test_mismatched.tsv'), - '7d2f60a73d54f30d8a65e474b615aeb6', (8, 9), 1), - } - - def get_labels(self): - """ - Return labels of the GlueMNLI object. - """ - return ["contradiction", "entailment", "neutral"] - - -class GlueQNLI(_GlueDataset): - """ - The Question-answering NLI dataset converted from Stanford Question - Answering Dataset (Rajpurkar et al. 2016). - From https://gluebenchmark.com/tasks - Args: - mode ('train'|'dev'|'test'): Dataset segment. - Default: 'train'. - root (str): Path to temp folder for storing data. - return_all_fields (bool): Return all fields available in the dataset. - Default: False. - - Example: - .. code-block:: python - - from paddlenlp.datasets import GlueQNLI - qnli_dev = GlueQNLI('dev', root='./datasets/qnli') - len(qnli_dev) # 5732 - len(qnli_dev[0]) # 3 - qnli_dev[0] # ['Which NFL team represented the AFC at Super Bowl - # 50?', 'The American Football Conference (AFC) - # champion Denver Broncos defeated the National - # Football Conference (NFC) champion Carolina Panthers - # 24\u201310 to earn their third Super Bowl title.', - # 'entailment'] - qnli_test = GlueQNLI('test', root='./datasets/qnli') - len(qnli_test) # 5740 - len(qnli_test[0]) # 2 - qnli_test[0] # ['What seldom used term of a unit of force equal to - # 1000 pound s of force?', - # 'Other arcane units of force include the sthène, - # which is equivalent to 1000 N, and the kip, which - # is equivalent to 1000 lbf.'] - """ - URL = 'https://dataset.bj.bcebos.com/glue/QNLI.zip' - MD5 = 'b4efd6554440de1712e9b54e14760e82' - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('QNLI', 'train.tsv'), - '5e6063f407b08d1f7c7074d049ace94a', (1, 2, 3), 1), - 'dev': _GlueDataset.META_INFO( - os.path.join('QNLI', 'dev.tsv'), '1e81e211959605f144ba6c0ad7dc948b', - (1, 2, 3), 1), - 'test': _GlueDataset.META_INFO( - os.path.join('QNLI', 'test.tsv'), - 'f2a29f83f3fe1a9c049777822b7fa8b0', (1, 2), 1) - } - - def get_labels(self): - """ - Return labels of the GlueQNLI object. - """ - return ["entailment", "not_entailment"] - - -class GlueRTE(_GlueDataset): - """ - The Recognizing Textual Entailment (RTE) datasets come from a series of - annual textual entailment challenges (RTE1, RTE2, RTE3, and RTE5). 
- From https://gluebenchmark.com/tasks - Args: - mode ('train'|'dev'|'test'): Dataset segment. Default: 'train'. - root (str): Path to temp folder for storing data. - return_all_fields (bool): Return all fields available in the dataset. - Default: False. - Examples: - .. code-block:: python - - from paddlenlp.datasets import GlueRTE - rte_dev = GlueRTE('dev', root='./datasets/rte') - len(rte_dev) # 277 - len(rte_dev[0]) # 3 - rte_dev[0] # ['Dana Reeve, the widow of the actor Christopher - # Reeve, has died of lung cancer at age 44, according - # to the Christopher Reeve Foundation.', 'Christopher - # Reeve had an accident.', 'not_entailment'] - rte_test = GlueRTE('test', root='./datasets/rte') - len(rte_test) # 3000 - len(rte_test[16]) # 2 - rte_test[16] # ['United failed to progress beyond the group stages - # of the Champions League and trail in the Premiership - # title race, sparking rumours over its future.', - # 'United won the Champions League.'] - """ - URL = 'https://dataset.bj.bcebos.com/glue/RTE.zip' - MD5 = 'bef554d0cafd4ab6743488101c638539' - - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('RTE', 'train.tsv'), - 'd2844f558d111a16503144bb37a8165f', (1, 2, 3), 1), - 'dev': _GlueDataset.META_INFO( - os.path.join('RTE', 'dev.tsv'), '973cb4178d4534cf745a01c309d4a66c', - (1, 2, 3), 1), - 'test': _GlueDataset.META_INFO( - os.path.join('RTE', 'test.tsv'), '6041008f3f3e48704f57ce1b88ad2e74', - (1, 2), 1) - } - - def get_labels(self): - """ - Return labels of the GlueRTE object. - """ - return ["entailment", "not_entailment"] - - -class GlueWNLI(_GlueDataset): - """ - The Winograd NLI dataset converted from the dataset in Winograd Schema - Challenge (Levesque et al., 2011). - From https://gluebenchmark.com/tasks - Args: - mode ('train'|'dev'|'test'): Dataset segment. Default: 'train'. - root (str): Path to temp folder for storing data. - return_all_fields (bool): Return all fields available in the dataset. - Default: False. - Example: - .. code-block:: python - - from paddlenlp.datasets import GlueWNLI - wnli_dev = GlueWNLI('dev', root='./datasets/wnli') - len(wnli_dev) # 71 - len(wnli_dev[0]) # 3 - wnli_dev[0] # ['The drain is clogged with hair. It has to be - # cleaned.', 'The hair has to be cleaned.', '0'] - wnli_test = GlueWNLI('test', root='./datasets/wnli') - len(wnli_test) # 146 - len(wnli_test[0]) # 2 - wnli_test[0] # ['Maude and Dora had seen the trains rushing - # across the prairie, with long, rolling puffs - # of black smoke streaming back from the engine. - # Their roars and their wild, clear whistles - # could be heard from far away. 
Horses ran away - # when they came in sight.', 'Horses ran away when - # Maude and Dora came in sight.'] - """ - URL = 'https://dataset.bj.bcebos.com/glue/WNLI.zip' - MD5 = 'a1b4bd2861017d302d29e42139657a42' - - SPLITS = { - 'train': _GlueDataset.META_INFO( - os.path.join('WNLI', 'train.tsv'), - '5cdc5a87b7be0c87a6363fa6a5481fc1', (1, 2, 3), 1), - 'dev': _GlueDataset.META_INFO( - os.path.join('WNLI', 'dev.tsv'), 'a79a6dd5d71287bcad6824c892e517ee', - (1, 2, 3), 1), - 'test': _GlueDataset.META_INFO( - os.path.join('WNLI', 'test.tsv'), - 'a18789ba4f60f6fdc8cb4237e4ba24b5', (1, 2), 1) - } + test_fh.write('%d\t%s\t%s\t%s\t%s\n' % + (idx, id1, id2, s1, s2)) + + return fullname + + def _read(self, filename, split): + _, _, field_indices, num_discard_samples = self.BUILDER_CONFIGS[ + self.name]['splits'][split] + with open(filename, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + if idx < num_discard_samples: + continue + line_stripped = line.strip().split('\t') + if not line_stripped: + continue + example = [line_stripped[indice] for indice in field_indices] + if self.name in ['cola', 'sst-2']: + yield { + 'sentence': example[0] + } if 'test' in split else { + 'sentence': example[0], + 'labels': example[-1] + } + else: + yield { + 'sentence1': example[0], + 'sentence2': example[1] + } if 'test' in split else { + 'sentence1': example[0], + 'sentence2': example[1], + 'labels': example[-1] + } def get_labels(self): """ - Return labels of the GlueWNLI object. + Return labels of the Glue task. """ - return ["0", "1"] + return self.BUILDER_CONFIGS[self.name]['labels'] diff --git a/paddlenlp/datasets/imdb.py b/paddlenlp/datasets/imdb.py index 8e28c12179431..89c4e04b83acb 100644 --- a/paddlenlp/datasets/imdb.py +++ b/paddlenlp/datasets/imdb.py @@ -12,21 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections +import json import io import os import string import numpy as np -import paddle -from paddle.io import Dataset -from paddlenlp.utils.env import DATA_HOME +from paddle.dataset.common import md5file from paddlenlp.utils.downloader import get_path_from_url +from paddlenlp.utils.env import DATA_HOME +from . import DatasetBuilder __all__ = ['Imdb'] -class Imdb(Dataset): +class Imdb(DatasetBuilder): """ Implementation of `IMDB `_ dataset. @@ -34,56 +36,35 @@ class Imdb(Dataset): URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz' MD5 = '7c2ac02c03563afcf9b574c7e56c153a' - def __init__( - self, - root=None, - mode='train', ): - assert mode in [ - "train", "test" - ], "Unknown mode %s, it should be 'train' or 'test'." % mode - if root is None: - root = DATA_HOME - data_dir = os.path.join(root, "aclImdb") - + def _get_data(self, mode, **kwargs): + """Downloads dataset.""" + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + data_dir = os.path.join(default_root, "aclImdb", mode) if not os.path.exists(data_dir): - data_dir = get_path_from_url(self.URL, root, self.MD5) + path = get_path_from_url(self.URL, default_root, self.MD5) + return data_dir - self.examples = self._read_data_file(data_dir, mode) - - def _read_data_file(self, data_dir, mode): - # remove punctuations ad-hoc. 
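For reference, a brief usage sketch of the consolidated `Glue` builder defined above. The exact `load_dataset('glue', <name>, splits=...)` call signature and return behavior are assumptions for illustration rather than something this diff defines; the example keys (`sentence`, `sentence1`, `sentence2`, `labels`) follow the new `_read` implementation.

```python
# A minimal sketch, assuming load_dataset dispatches to the Glue builder by
# dataset name and sub-task name (signature assumed, not part of this patch).
from paddlenlp.datasets import load_dataset

# Single-sentence task: examples carry 'sentence' (+ 'labels' outside test splits).
train_ds, dev_ds = load_dataset('glue', 'sst-2', splits=('train', 'dev'))
print(train_ds[0])          # e.g. {'sentence': "...", 'labels': '0'}
print(train_ds.label_list)  # expected ['0', '1'] via get_labels(); attribute name assumed

# Sentence-pair task: examples carry 'sentence1' and 'sentence2'.
test_ds = load_dataset('glue', 'rte', splits='test')
print(test_ds[0])           # {'sentence1': "...", 'sentence2': "..."}
```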
+ def _read(self, data_dir, *args): translator = str.maketrans('', '', string.punctuation) - def _load_data(label): - root = os.path.join(data_dir, mode, label) + for label in ["pos", "neg"]: + root = os.path.join(data_dir, label) data_files = os.listdir(root) data_files.sort() + if label == "pos": - label_id = 1 + label_id = "1" elif label == "neg": - label_id = 0 - - all_samples = [] + label_id = "0" for f in data_files: f = os.path.join(root, f) with io.open(f, 'r', encoding='utf8') as fr: data = fr.readlines() data = data[0].translate(translator) - all_samples.append((data, label_id)) - - return all_samples - - data_set = _load_data("pos") - data_set.extend(_load_data("neg")) - np.random.shuffle(data_set) - - return data_set - - def __getitem__(self, idx): - return self.examples[idx] - - def __len__(self): - return len(self.examples) + yield {"text": data, "label": label_id} def get_labels(self): + """ + Return labels of the Imdb object. + """ return ["0", "1"] diff --git a/paddlenlp/datasets/experimental/iwslt15.py b/paddlenlp/datasets/iwslt15.py similarity index 100% rename from paddlenlp/datasets/experimental/iwslt15.py rename to paddlenlp/datasets/iwslt15.py diff --git a/paddlenlp/datasets/lcqmc.py b/paddlenlp/datasets/lcqmc.py index 85059103d80be..ceacf6583e608 100644 --- a/paddlenlp/datasets/lcqmc.py +++ b/paddlenlp/datasets/lcqmc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,92 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import collections -import io +import json import os -import warnings -from paddle.io import Dataset from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url from paddlenlp.utils.env import DATA_HOME - -from .dataset import TSVDataset +from . 
import DatasetBuilder __all__ = ['LCQMC'] -class LCQMC(TSVDataset): +class LCQMC(DatasetBuilder): """ LCQMC:A Large-scale Chinese Question Matching Corpus More information please refer to `https://www.aclweb.org/anthology/C18-1166/` - """ URL = "https://bj.bcebos.com/paddlehub-dataset/lcqmc.tar.gz" MD5 = "62a7ba36f786a82ae59bbde0b0a9af0c" - META_INFO = collections.namedtuple( - 'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples')) + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) SPLITS = { 'train': META_INFO( os.path.join('lcqmc', 'train.tsv'), - '2193c022439b038ac12c0ae918b211a1', (0, 1, 2), 1), + '2193c022439b038ac12c0ae918b211a1'), 'dev': META_INFO( os.path.join('lcqmc', 'dev.tsv'), - 'c5dcba253cb4105d914964fd8b3c0e94', (0, 1, 2), 1), + 'c5dcba253cb4105d914964fd8b3c0e94'), 'test': META_INFO( os.path.join('lcqmc', 'test.tsv'), - '8f4b71e15e67696cc9e112a459ec42bd', (0, 1, 2), 1) + '8f4b71e15e67696cc9e112a459ec42bd'), } - def __init__(self, - mode='train', - root=None, - return_all_fields=False, - **kwargs): - if return_all_fields: - splits = copy.deepcopy(self.__class__.SPLITS) - mode_info = list(splits[mode]) - mode_info[2] = None - splits[mode] = self.META_INFO(*mode_info) - self.SPLITS = splits - - self._get_data(root, mode, **kwargs) - - def _get_data(self, root, mode, **kwargs): - default_root = DATA_HOME - filename, data_hash, field_indices, num_discard_samples = self.SPLITS[ - mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - path = get_path_from_url(self.URL, default_root, self.MD5) - fullname = os.path.join(default_root, filename) - super(LCQMC, self).__init__( - fullname, - field_indices=field_indices, - num_discard_samples=num_discard_samples, - **kwargs) + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename): + """Reads data.""" + with open(filename, 'r', encoding='utf-8') as f: + head = None + for line in f: + data = line.strip().split("\t") + if not head: + head = data + else: + query, title, label = data + yield {"query": query, "title": title, "label": label} def get_labels(self): """ Return labels of the LCQMC object. """ return ["0", "1"] - - -if __name__ == "__main__": - ds = LCQMC('train', return_all_fields=True) - - for idx, data in enumerate(ds): - if idx >= 3: - break - print(data) diff --git a/paddlenlp/datasets/msra_ner.py b/paddlenlp/datasets/msra_ner.py index ba522af19392f..95d083a8245eb 100644 --- a/paddlenlp/datasets/msra_ner.py +++ b/paddlenlp/datasets/msra_ner.py @@ -20,53 +20,45 @@ from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url from paddlenlp.utils.env import DATA_HOME - -from .dataset import TSVDataset +from . 
import DatasetBuilder __all__ = ['MSRA_NER'] -class MSRA_NER(TSVDataset): +class MSRA_NER(DatasetBuilder): URL = "https://paddlenlp.bj.bcebos.com/datasets/msra_ner.tar.gz" MD5 = None - META_INFO = collections.namedtuple( - 'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples')) + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) SPLITS = { - 'train': META_INFO( - os.path.join('msra_ner', 'train.tsv'), - '67d3c93a37daba60ef43c03271f119d7', - (0, 1), - 1, ), - 'test': META_INFO( - os.path.join('msra_ner', 'test.tsv'), - '2f27ae68b5f61d6553ffa28bb577c8a7', - (0, 1), - 1, ), + 'train': META_INFO(os.path.join('msra_ner', 'train.tsv'), None), + 'test': META_INFO(os.path.join('msra_ner', 'test.tsv'), None) } - def __init__(self, mode='train', root=None, **kwargs): - default_root = os.path.join(DATA_HOME, 'msra') - filename, data_hash, field_indices, num_discard_samples = self.SPLITS[ - mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - path = get_path_from_url(self.URL, default_root, self.MD5) - fullname = os.path.join(default_root, filename) - super(MSRA_NER, self).__init__( - fullname, - field_indices=field_indices, - num_discard_samples=num_discard_samples, - **kwargs) + + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename, *args): + with open(filename, 'r', encoding='utf-8') as f: + for line in f: + line_stripped = line.strip().split('\t') + if not line_stripped: + break + if len(line_stripped) == 2: + tokens = line_stripped[0].split("\002") + tags = line_stripped[1].split("\002") + else: + tokens = line_stripped.split("\002") + tags = [] + yield {"tokens": tokens, "labels": tags} def get_labels(self): - """ - Return labels of the GlueCoLA object. - """ + return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] diff --git a/paddlenlp/datasets/peoples_daily_ner.py b/paddlenlp/datasets/peoples_daily_ner.py index 42c131479664d..c891f03509cfd 100644 --- a/paddlenlp/datasets/peoples_daily_ner.py +++ b/paddlenlp/datasets/peoples_daily_ner.py @@ -20,58 +20,45 @@ from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url from paddlenlp.utils.env import DATA_HOME - -from .dataset import TSVDataset +from . 
import DatasetBuilder __all__ = ['PeoplesDailyNER'] -class PeoplesDailyNER(TSVDataset): +class PeoplesDailyNER(DatasetBuilder): URL = "https://paddlenlp.bj.bcebos.com/datasets/peoples_daily_ner.tar.gz" MD5 = None - META_INFO = collections.namedtuple( - 'META_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples')) + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) SPLITS = { - 'train': META_INFO( - os.path.join('peoples_daily_ner', 'train.tsv'), - '67d3c93a37daba60ef43c03271f119d7', - (0, 1), - 1, ), - 'dev': META_INFO( - os.path.join('peoples_daily_ner', 'dev.tsv'), - 'ec772f3ba914bca5269f6e785bb3375d', - (0, 1), - 1, ), - 'test': META_INFO( - os.path.join('peoples_daily_ner', 'test.tsv'), - '2f27ae68b5f61d6553ffa28bb577c8a7', - (0, 1), - 1, ), + 'train': + META_INFO(os.path.join('peoples_daily_ner', 'train.tsv'), None), + 'test': META_INFO(os.path.join('peoples_daily_ner', 'test.tsv'), None) } - def __init__(self, mode='train', root=None, **kwargs): - default_root = os.path.join(DATA_HOME, 'peoples_daily_ner') - filename, data_hash, field_indices, num_discard_samples = self.SPLITS[ - mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - path = get_path_from_url(self.URL, default_root, self.MD5) - fullname = os.path.join(default_root, filename) - super(PeoplesDailyNER, self).__init__( - fullname, - field_indices=field_indices, - num_discard_samples=num_discard_samples, - **kwargs) + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename, *args): + with open(filename, 'r', encoding='utf-8') as f: + for line in f: + line_stripped = line.strip().split('\t') + if not line_stripped: + break + if len(line_stripped) == 2: + tokens = line_stripped[0].split("\002") + tags = line_stripped[1].split("\002") + else: + tokens = line_stripped.split("\002") + tags = [] + yield {"tokens": tokens, "labels": tags} def get_labels(self): - """ - Return labels of the GlueCoLA object. - """ + return ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] diff --git a/paddlenlp/datasets/poetry.py b/paddlenlp/datasets/poetry.py index 6375e19d3aa73..53830ddac7072 100644 --- a/paddlenlp/datasets/poetry.py +++ b/paddlenlp/datasets/poetry.py @@ -17,48 +17,51 @@ import warnings from paddle.io import Dataset -from paddle.dataset.common import DATA_HOME, md5file +from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url - -from .dataset import TSVDataset +from paddlenlp.utils.env import DATA_HOME +from . 
import DatasetBuilder __all__ = ['Poetry'] -class Poetry(TSVDataset): +class Poetry(DatasetBuilder): URL = "https://paddlenlp.bj.bcebos.com/datasets/poetry.tar.gz" MD5 = '8edd7eda1b273145b70ef29c82cd622b' - SEGMENT_INFO = collections.namedtuple( - 'SEGMENT_INFO', ('file', 'md5', 'field_indices', 'num_discard_samples')) - SEGMENTS = { - 'train': SEGMENT_INFO( + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) + SPLITS = { + 'train': META_INFO( os.path.join('poetry', 'train.tsv'), - '176c6202b5e71656ae7e7848eec4c54f', (0, 1), 0), - 'dev': SEGMENT_INFO( + '176c6202b5e71656ae7e7848eec4c54f'), + 'dev': META_INFO( os.path.join('poetry', 'dev.tsv'), - '737e4b6da5facdc0ac33fe688df19931', (0, 1), 0), - 'test': SEGMENT_INFO( + '737e4b6da5facdc0ac33fe688df19931'), + 'test': META_INFO( os.path.join('poetry', 'test.tsv'), - '1dca907b2d712730c7c828f8acee7431', (0, 1), 0), + '1dca907b2d712730c7c828f8acee7431'), } - def __init__(self, segment='train', root=None, **kwargs): - default_root = os.path.join(DATA_HOME, 'poetry') - filename, data_hash, field_indices, num_discard_samples = self.SEGMENTS[ - segment] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - path = get_path_from_url(self.URL, default_root, self.MD5) - fullname = os.path.join(default_root, filename) - super(Poetry, self).__init__( - fullname, - field_indices=field_indices, - num_discard_samples=num_discard_samples, - **kwargs) + + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename, *args): + with open(filename, 'r', encoding='utf-8') as f: + for line in f: + line_stripped = line.strip().split('\t') + if not line_stripped: + break + if len(line_stripped) == 2: + tokens = line_stripped[0] + labels = line_stripped[1] + else: + tokens = line_stripped + labels = [] + yield {"tokens": tokens, "labels": labels} \ No newline at end of file diff --git a/paddlenlp/datasets/ptb.py b/paddlenlp/datasets/ptb.py index ecc61fe4721a8..0ab384f4211fc 100644 --- a/paddlenlp/datasets/ptb.py +++ b/paddlenlp/datasets/ptb.py @@ -1,98 +1,59 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
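The converted builders in this patch all implement the same three hooks: `_get_data` downloads (with an md5 check) and returns a local path, `_read` yields one dict per example, and `get_labels` exposes the label list. As a minimal sketch of that contract, with a hypothetical TSV dataset whose class name, URL, and file layout are invented purely for illustration:

```python
import collections
import os

from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from paddlenlp.datasets import DatasetBuilder  # same base class used throughout this patch


class ToyPairs(DatasetBuilder):
    """Hypothetical two-column TSV dataset, mirroring the LCQMC/Poetry pattern."""
    URL = "https://example.com/toy_pairs.tar.gz"  # placeholder URL
    MD5 = None
    META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
    SPLITS = {
        'train': META_INFO(os.path.join('toy_pairs', 'train.tsv'), None),
        'dev': META_INFO(os.path.join('toy_pairs', 'dev.tsv'), None),
    }

    def _get_data(self, mode, **kwargs):
        # Download on demand (md5-checked) and return the local file path.
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash = self.SPLITS[mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (
                data_hash and not md5file(fullname) == data_hash):
            get_path_from_url(self.URL, default_root, self.MD5)
        return fullname

    def _read(self, filename, *args):
        # Yield one dict per example; the keys become the dataset's named fields.
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                text, label = line.rstrip('\n').split('\t')
                yield {'text': text, 'label': label}

    def get_labels(self):
        return ["0", "1"]
```

Yielding dicts from `_read` is what lets downstream code address fields by name instead of by positional index, which is the main behavioral change relative to the removed `TSVDataset`-based classes.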
+ import os -import math +import collections from paddle.io import Dataset + from paddle.utils.download import get_path_from_url +from paddle.dataset.common import md5file from paddlenlp.utils.env import DATA_HOME -import paddle.distributed as dist -import numpy as np - -__all__ = ['PTBDataset'] - - -class PTBDataset(Dataset): - - DATA_URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' - DATA_PATH = os.path.join('simple-examples', 'data') - - def __init__(self, batch_size, num_steps, mode='train', root=None): - super(PTBDataset, self).__init__() - - self._get_data(root=root, mode=mode) - train_data, valid_data, test_data = self.get_ptb_data(self.data_path) - if mode == 'train': - raw_data = train_data - elif mode == 'eval': - raw_data = valid_data - else: - raw_data = test_data - raw_data = np.asarray(raw_data, dtype="int64") - self.max_seq_len = len(raw_data) // batch_size - self.data = raw_data[0:batch_size * self.max_seq_len].reshape( - (batch_size, self.max_seq_len)) - self.num_steps = num_steps - self.num_shards = dist.get_world_size() - index = dist.get_rank() - self.shard(num_shards=self.num_shards, index=index) - - def _get_data(self, root, mode): - default_root = os.path.join(DATA_HOME, 'lm') - self.data_path = os.path.join(default_root, - self.DATA_PATH) if root is None else root - if not os.path.exists(self.data_path): - path = get_path_from_url(self.DATA_URL, default_root) - self.data_path = os.path.join(default_root, self.DATA_PATH) - - def build_vocab(self, filename): - EOS = "" - vocab_dict = {} - ids = 0 - vocab_dict[EOS] = ids - ids += 1 - with open(filename, "r") as f: - for line in f.readlines(): - for w in line.strip().split(): - if w not in vocab_dict: - vocab_dict[w] = ids - ids += 1 - self.vocab_size = ids - return vocab_dict - - def corpus_to_token_ids(self, corpus_path, vocab): - corpus_ids = [] - with open(corpus_path, "r") as f_corpus: - for line in f_corpus.readlines(): - tokens = line.strip().split() - ids = [vocab[w] for w in tokens if w in vocab] - - corpus_ids += ids + [0] #Add token_id:0 between sentences - return corpus_ids - - def get_ptb_data(self, data_path=None): - - train_file = os.path.join(data_path, "ptb.train.txt") - valid_file = os.path.join(data_path, "ptb.valid.txt") - test_file = os.path.join(data_path, "ptb.test.txt") - - vocab_dict = self.build_vocab(train_file) - train_ids = self.corpus_to_token_ids(train_file, vocab_dict) - valid_ids = self.corpus_to_token_ids(valid_file, vocab_dict) - test_ids = self.corpus_to_token_ids(test_file, vocab_dict) - - return train_ids, valid_ids, test_ids - - def shard(self, num_shards, index): - num_samples = int(math.floor(len(self.data[0]) * 1.0 / num_shards)) - sharded_data = self.data[:, index * num_samples:(index + 1) * - num_samples] - self.data = sharded_data - - def __getitem__(self, index): - x = np.copy(self.data[:, index * self.num_steps:(index + 1) * - self.num_steps]) - y = np.copy(self.data[:, index * self.num_steps + 1:(index + 1) * - self.num_steps + 1]) - return (x, y) - - def __len__(self): - return ((self.max_seq_len - 1) // self.num_steps) // self.num_shards +from . 
import DatasetBuilder + +__all__ = ['PTB'] + + +class PTB(DatasetBuilder): + URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' + MD5 = "30177ea32e27c525793142b6bf2c8e2d" + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) + SPLITS = { + 'train': META_INFO( + os.path.join('simple-examples', 'data', 'ptb.train.txt'), + "f26c4b92c5fdc7b3f8c7cdcb991d8420"), + 'valid': META_INFO( + os.path.join('simple-examples', 'data', 'ptb.valid.txt'), + "aa0affc06ff7c36e977d7cd49e3839bf"), + 'test': META_INFO( + os.path.join('simple-examples', 'data', 'ptb.test.txt'), + "8b80168b89c18661a38ef683c0dc3721") + } + + def _get_data(self, mode, **kwargs): + default_root = os.path.join(DATA_HOME, self.__class__.__name__) + filename, data_hash = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) + if not os.path.exists(fullname) or (data_hash and + not md5file(fullname) == data_hash): + + get_path_from_url(self.URL, default_root, self.MD5) + + return fullname + + def _read(self, filename, *args): + with open(filename, 'r', encoding='utf-8') as f: + for line in f: + line_stripped = line.strip() + yield {"sentence": line_stripped} diff --git a/paddlenlp/datasets/squad.py b/paddlenlp/datasets/squad.py index 467d529390274..bbc8a3495787e 100644 --- a/paddlenlp/datasets/squad.py +++ b/paddlenlp/datasets/squad.py @@ -1,664 +1,86 @@ -import copy +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import collections import json import os -import warnings from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url -from paddle.io import Dataset from paddlenlp.utils.env import DATA_HOME -from paddlenlp.transformers.tokenizer_utils import _is_whitespace, _is_control, convert_to_unicode - -__all__ = ['SQuAD', 'DuReaderRobust', 'CMRC', 'DRCD'] - - -class SquadExample(object): - """A single training/test example for simple sequence classification. - - For examples without an answer, the start and end position are -1. - """ - - def __init__(self, - qas_id, - question_text, - doc_tokens, - orig_answer_text=None, - start_position=None, - end_position=None, - is_impossible=False): - self.qas_id = qas_id - self.question_text = question_text - self.doc_tokens = doc_tokens - self.orig_answer_text = orig_answer_text - self.start_position = start_position - self.end_position = end_position - self.is_impossible = is_impossible +from . 
import DatasetBuilder +__all__ = ['SQuAD'] -class InputFeatures(object): - """A single set of features of data.""" - - def __init__(self, - unique_id, - example_index, - doc_span_index, - tokens, - token_to_orig_map, - token_is_max_context, - input_ids, - input_mask, - segment_ids, - start_position=None, - end_position=None, - is_impossible=None): - self.unique_id = unique_id - self.example_index = example_index - self.doc_span_index = doc_span_index - self.tokens = tokens - self.token_to_orig_map = token_to_orig_map - self.token_is_max_context = token_is_max_context - self.input_ids = input_ids - self.input_mask = input_mask - self.segment_ids = segment_ids - self.start_position = start_position - self.end_position = end_position - self.is_impossible = is_impossible - - -class SQuAD(Dataset): - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - - DEV_DATA_URL_V2 = 'https://paddlenlp.bj.bcebos.com/datasets/squad/dev-v2.0.json' - TRAIN_DATA_URL_V2 = 'https://paddlenlp.bj.bcebos.com/datasets/squad/train-v2.0.json' - - DEV_DATA_URL_V1 = 'https://paddlenlp.bj.bcebos.com/datasets/squad/dev-v1.1.json' - TRAIN_DATA_URL_V1 = 'https://paddlenlp.bj.bcebos.com/datasets/squad/train-v1.1.json' +class SQuAD(DatasetBuilder): + META_INFO = collections.namedtuple('META_INFO', ('file', 'md5', 'URL')) SPLITS = { - '1.1': { - 'train': META_INFO( - os.path.join('v1', 'train-v1.1.json'), - '981b29407e0affa3b1b156f72073b945'), - 'dev': META_INFO( - os.path.join('v1', 'dev-v1.1.json'), - '3e85deb501d4e538b6bc56f786231552') - }, - '2.0': { - 'train': META_INFO( - os.path.join('v2', 'train-v2.0.json'), - '62108c273c268d70893182d5cf8df740'), - 'dev': META_INFO( - os.path.join('v2', 'dev-v2.0.json'), - '246adae8b7002f8679c027697b0b7cf8') - } + 'train_v1': META_INFO( + os.path.join('train-v1.1.json'), '981b29407e0affa3b1b156f72073b945', + 'https://paddlenlp.bj.bcebos.com/datasets/squad/train-v1.1.json'), + 'dev_v1': META_INFO( + os.path.join('dev-v1.1.json'), '3e85deb501d4e538b6bc56f786231552', + 'https://paddlenlp.bj.bcebos.com/datasets/squad/dev-v1.1.json'), + 'train_v2': META_INFO( + os.path.join('train-v2.0.json'), '62108c273c268d70893182d5cf8df740', + 'https://paddlenlp.bj.bcebos.com/datasets/squad/train-v2.0.json'), + 'dev_v2': META_INFO( + os.path.join('dev-v2.0.json'), '246adae8b7002f8679c027697b0b7cf8', + 'https://paddlenlp.bj.bcebos.com/datasets/squad/dev-v2.0.json') } - def __init__(self, - tokenizer, - mode='train', - version_2_with_negative=False, - root=None, - doc_stride=128, - max_query_length=64, - max_seq_length=512, - **kwargs): - - self.version_2_with_negative = version_2_with_negative - self._get_data(root, mode, **kwargs) - self.tokenizer = tokenizer - self.doc_stride = doc_stride - self.max_query_length = max_query_length - self.max_seq_length = max_seq_length - - self._transform_func = None - - if mode == 'train': - self.is_training = True - else: - self.is_training = False - - self._read() - - self.features = self.convert_examples_to_feature( - self.examples, - tokenizer=self.tokenizer, - doc_stride=self.doc_stride, - max_query_length=self.max_query_length, - max_seq_length=self.max_seq_length) - - def _get_data(self, root, mode, **kwargs): + def _get_data(self, mode, **kwargs): default_root = os.path.join(DATA_HOME, self.__class__.__name__) - if self.version_2_with_negative: - filename, data_hash = self.SPLITS['2.0'][mode] - else: - filename, data_hash = self.SPLITS['1.1'][mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - 
os.path.expanduser(root), filename) + filename, data_hash, URL = self.SPLITS[mode] + fullname = os.path.join(default_root, filename) if not os.path.exists(fullname) or (data_hash and not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - if mode == 'train': - if self.version_2_with_negative: - fullname = get_path_from_url( - self.TRAIN_DATA_URL_V2, - os.path.join(default_root, 'v2')) - else: - fullname = get_path_from_url( - self.TRAIN_DATA_URL_V1, - os.path.join(default_root, 'v1')) - elif mode == 'dev': - if self.version_2_with_negative: - fullname = get_path_from_url( - self.DEV_DATA_URL_V2, os.path.join(default_root, 'v2')) - else: - fullname = get_path_from_url( - self.DEV_DATA_URL_V1, os.path.join(default_root, 'v1')) - self.full_path = fullname - - def convert_examples_to_feature(self, examples, tokenizer, max_seq_length, - doc_stride, max_query_length): - """Loads a data file into a list of `InputBatch`s.""" - unique_id = 1000000000 - features = [] - for (example_index, example) in enumerate(examples): - query_tokens = tokenizer._tokenize(example.question_text) - if len(query_tokens) > max_query_length: - query_tokens = query_tokens[0:max_query_length] - - tok_to_orig_index = [] - orig_to_tok_index = [] - all_doc_tokens = [] - for (i, token) in enumerate(example.doc_tokens): - orig_to_tok_index.append(len(all_doc_tokens)) - sub_tokens = tokenizer._tokenize(token) - for sub_token in sub_tokens: - tok_to_orig_index.append(i) - all_doc_tokens.append(sub_token) - - tok_start_position = None - tok_end_position = None - if self.is_training and example.is_impossible: - tok_start_position = -1 - tok_end_position = -1 - if self.is_training and not example.is_impossible: - tok_start_position = orig_to_tok_index[example.start_position] - if example.end_position < len(example.doc_tokens) - 1: - tok_end_position = orig_to_tok_index[example.end_position + - 1] - 1 - else: - tok_end_position = len(all_doc_tokens) - 1 - (tok_start_position, - tok_end_position) = self._improve_answer_span( - all_doc_tokens, tok_start_position, tok_end_position, - tokenizer, example.orig_answer_text) - - # The -3 accounts for [CLS], [SEP] and [SEP] - max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 - - # We can have documents that are longer than the maximum sequence length. - # To deal with this we do a sliding window approach, where we take chunks - # of the up to our max length with a stride of `doc_stride`. 
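The removed comment above describes the sliding-window chunking over documents longer than the maximum sequence length. As a standalone restatement of the doc-span computation that follows in the removed code, with made-up numbers for illustration:

```python
import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])


def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    # Slide a window of up to max_tokens_for_doc tokens over the document,
    # advancing by doc_stride until the end of the document is covered.
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans


# A 700-token document, 384 tokens of room per chunk, stride 128:
print(make_doc_spans(700, 384, 128))
# [DocSpan(start=0, length=384), DocSpan(start=128, length=384),
#  DocSpan(start=256, length=384), DocSpan(start=384, length=316)]
```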
- _DocSpan = collections.namedtuple( # pylint: disable=invalid-name - "DocSpan", ["start", "length"]) - doc_spans = [] - start_offset = 0 - while start_offset < len(all_doc_tokens): - length = len(all_doc_tokens) - start_offset - if length > max_tokens_for_doc: - length = max_tokens_for_doc - doc_spans.append(_DocSpan(start=start_offset, length=length)) - if start_offset + length == len(all_doc_tokens): - break - start_offset += min(length, doc_stride) - - for (doc_span_index, doc_span) in enumerate(doc_spans): - tokens = [] - token_to_orig_map = {} - token_is_max_context = {} - segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) - for token in query_tokens: - tokens.append(token) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) - - for i in range(doc_span.length): - split_token_index = doc_span.start + i - token_to_orig_map[len(tokens)] = tok_to_orig_index[ - split_token_index] + get_path_from_url(URL, default_root) - is_max_context = self._check_is_max_context( - doc_spans, doc_span_index, split_token_index) - token_is_max_context[len(tokens)] = is_max_context - tokens.append(all_doc_tokens[split_token_index]) - segment_ids.append(1) - tokens.append("[SEP]") - segment_ids.append(1) + return fullname - input_ids = tokenizer.convert_tokens_to_ids(tokens) - input_ids = input_ids + [ - tokenizer.vocab[tokenizer.pad_token] - for _ in range(self.max_seq_length - len(input_ids)) - ] - segment_ids = segment_ids + [ - tokenizer.vocab[tokenizer.pad_token] - for _ in range(self.max_seq_length - len(segment_ids)) - ] - input_mask = [1] * len(input_ids) - - start_position = None - end_position = None - if self.is_training and not example.is_impossible: - # For training, if our document chunk does not contain an annotation - # we throw it out, since there is nothing to predict. - doc_start = doc_span.start - doc_end = doc_span.start + doc_span.length - 1 - out_of_span = False - if not (tok_start_position >= doc_start and - tok_end_position <= doc_end): - out_of_span = True - if out_of_span: - start_position = 0 - end_position = 0 - else: - doc_offset = len(query_tokens) + 2 - start_position = tok_start_position - doc_start + doc_offset - end_position = tok_end_position - doc_start + doc_offset - - if self.is_training and example.is_impossible: - start_position = 0 - end_position = 0 - - feature = InputFeatures( - unique_id=unique_id, - example_index=example_index, - doc_span_index=doc_span_index, - tokens=tokens, - token_to_orig_map=token_to_orig_map, - token_is_max_context=token_is_max_context, - input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids, - start_position=start_position, - end_position=end_position, - is_impossible=example.is_impossible) - - unique_id += 1 - features.append(feature) - return features - - def _improve_answer_span(self, doc_tokens, input_start, input_end, - tokenizer, orig_answer_text): - """Returns tokenized answer spans that better match the annotated answer.""" - - # The SQuAD annotations are character based. We first project them to - # whitespace-tokenized words. But then after WordPiece tokenization, we can - # often find a "better match". For example: - # - # Question: What year was John Smith born? - # Context: The leader was John Smith (1895-1943). - # Answer: 1895 - # - # The original whitespace-tokenized answer will be "(1895-1943).". However - # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match - # the exact answer, 1895. - # - # However, this is not always possible. 
Consider the following: - # - # Question: What country is the top exporter of electornics? - # Context: The Japanese electronics industry is the lagest in the world. - # Answer: Japan - # - # In this case, the annotator chose "Japan" as a character sub-span of - # the word "Japanese". Since our WordPiece tokenizer does not split - # "Japanese", we just use "Japanese" as the annotation. This is fairly rare - # in SQuAD, but does happen. - tok_answer_text = " ".join(tokenizer._tokenize(orig_answer_text)) - - for new_start in range(input_start, input_end + 1): - for new_end in range(input_end, new_start - 1, -1): - text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) - if text_span == tok_answer_text: - return (new_start, new_end) - - return (input_start, input_end) - - def _check_is_max_context(self, doc_spans, cur_span_index, position): - """Check if this is the 'max context' doc span for the token.""" - - # Because of the sliding window approach taken to scoring documents, a single - # token can appear in multiple documents. E.g. - # Doc: the man went to the store and bought a gallon of milk - # Span A: the man went to the - # Span B: to the store and bought - # Span C: and bought a gallon of - # ... - # - # Now the word 'bought' will have two scores from spans B and C. We only - # want to consider the score with "maximum context", which we define as - # the *minimum* of its left and right context (the *sum* of left and - # right context will always be the same, of course). - # - # In the example the maximum context for 'bought' would be span C since - # it has 1 left context and 3 right context, while span B has 4 left context - # and 0 right context. - best_score = None - best_span_index = None - for (span_index, doc_span) in enumerate(doc_spans): - end = doc_span.start + doc_span.length - 1 - if position < doc_span.start: - continue - if position > end: - continue - num_left_context = position - doc_span.start - num_right_context = end - position - score = min(num_left_context, - num_right_context) + 0.01 * doc_span.length - if best_score is None or score > best_score: - best_score = score - best_span_index = span_index - - return cur_span_index == best_span_index - - def _read(self): - with open(self.full_path, "r", encoding="utf8") as reader: - input_data = json.load(reader)["data"] - - examples = [] + def _read(self, filename, *args): + with open(filename, "r", encoding="utf8") as f: + input_data = json.load(f)["data"] for entry in input_data: + title = entry.get("title", "").strip() for paragraph in entry["paragraphs"]: - paragraph_text = paragraph["context"] - doc_tokens = [] - char_to_word_offset = [] - prev_is_whitespace = True - for c in paragraph_text: - if _is_whitespace(c): - prev_is_whitespace = True - else: - if prev_is_whitespace: - doc_tokens.append(c) - else: - doc_tokens[-1] += c - prev_is_whitespace = False - char_to_word_offset.append(len(doc_tokens) - 1) - + context = paragraph["context"].strip() for qa in paragraph["qas"]: qas_id = qa["id"] - question_text = qa["question"] - start_position = None - end_position = None - orig_answer_text = None + question = qa["question"].strip() + answer_starts = [] + answers = [] is_impossible = False - if self.is_training: - if self.version_2_with_negative: - is_impossible = qa["is_impossible"] - if (len(qa["answers"]) != 1) and (not is_impossible): - raise ValueError( - "For training, each question should have exactly 1 answer." 
- ) - if not is_impossible: - answer = qa["answers"][0] - orig_answer_text = answer["text"] - answer_offset = answer["answer_start"] - answer_length = len(orig_answer_text) - start_position = char_to_word_offset[answer_offset] - try: - end_position = char_to_word_offset[ - answer_offset + answer_length - 1] - except: - continue - - else: - start_position = -1 - end_position = -1 - orig_answer_text = "" - else: - if self.version_2_with_negative: - is_impossible = qa["is_impossible"] - orig_answer_text = [] - if not is_impossible and 'answers' in qa.keys(): - answers = qa["answers"] - for answer in answers: - orig_answer_text.append(answer["text"]) - else: - start_position = -1 - end_position = -1 - example = SquadExample( - qas_id=qas_id, - question_text=question_text, - doc_tokens=doc_tokens, - orig_answer_text=orig_answer_text, - start_position=start_position, - end_position=end_position, - is_impossible=is_impossible) - examples.append(example) - - self.examples = examples - - def __len__(self): - return len(self.features) - - def __getitem__(self, idx): - feature = self.features[idx] - - if self.is_training: - return feature.input_ids, feature.segment_ids, feature.unique_id, feature.start_position, feature.end_position - else: - return feature.input_ids, feature.segment_ids, feature.unique_id - - -class DuReaderRobust(SQuAD): - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - - DATA_URL = 'https://dataset-bj.cdn.bcebos.com/qianyan/dureader_robust-data.tar.gz' - - SPLITS = { - 'train': META_INFO( - os.path.join('dureader_robust-data', 'train.json'), - '800a3dcb742f9fdf9b11e0a83433d4be'), - 'dev': META_INFO( - os.path.join('dureader_robust-data', 'dev.json'), - 'ae73cec081eaa28a735204c4898a2222'), - 'test': META_INFO( - os.path.join('dureader_robust-data', 'test.json'), - 'e0e8aa5c7b6d11b6fc3935e29fc7746f') - } - - def __init__(self, - tokenizer, - mode='train', - root=None, - doc_stride=128, - max_query_length=64, - max_seq_length=512, - **kwargs): - - super(DuReaderRobust, self).__init__( - tokenizer=tokenizer, - mode=mode, - version_2_with_negative=False, - root=root, - doc_stride=doc_stride, - max_query_length=max_query_length, - max_seq_length=max_seq_length, - **kwargs) - - def _get_data(self, root, mode, **kwargs): - default_root = os.path.join(DATA_HOME, 'self.__class__.__name__') - - filename, data_hash = self.SPLITS[mode] - - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - - get_path_from_url(self.DATA_URL, default_root) - - self.full_path = fullname - - def _read(self): - with open(self.full_path, "r", encoding="utf8") as reader: - input_data = json.load(reader)["data"] - - examples = [] - for entry in input_data: - for paragraph in entry["paragraphs"]: - paragraph_text = paragraph["context"] - raw_doc_tokens = self.tokenizer.basic_tokenizer.tokenize( - paragraph_text) - doc_tokens = [] - char_to_word_offset = [] - k = 0 - temp_word = "" - for c in paragraph_text: - if not self.tokenizer.basic_tokenizer.tokenize(c): - char_to_word_offset.append(k - 1) - continue - else: - temp_word += c - char_to_word_offset.append(k) - - if temp_word == raw_doc_tokens[k]: - doc_tokens.append(temp_word) - temp_word = "" - k += 
1 - - assert k == len(raw_doc_tokens) - - for qa in paragraph["qas"]: - qas_id = qa["id"] - question_text = qa["question"] - start_position = None - end_position = None - orig_answer_text = None - is_impossible = False - - if self.is_training: - if (len(qa["answers"]) != 1): - raise ValueError( - "For training, each question should have exactly 1 answer." - ) - - answer = qa["answers"][0] - orig_answer_text = answer["text"] - answer_offset = answer["answer_start"] - answer_length = len(orig_answer_text) - start_position = char_to_word_offset[answer_offset] - try: - end_position = char_to_word_offset[ - answer_offset + answer_length - 1] - except: - continue - - else: - orig_answer_text = [] - if 'answers' in qa.keys(): - answers = qa["answers"] - for answer in answers: - orig_answer_text.append(answer["text"]) - - example = SquadExample( - qas_id=qas_id, - question_text=question_text, - doc_tokens=doc_tokens, - orig_answer_text=orig_answer_text, - start_position=start_position, - end_position=end_position, - is_impossible=is_impossible) - examples.append(example) - - self.examples = examples - - -class CMRC(DuReaderRobust): - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - - DEV_DATA_URL = 'https://paddlenlp.bj.bcebos.com/datasets/cmrc/cmrc2018_dev.json' - TRAIN_DATA_URL = 'https://paddlenlp.bj.bcebos.com/datasets/cmrc/cmrc2018_train.json' - TRIAL_DATA_URL = 'https://paddlenlp.bj.bcebos.com/datasets/cmrc/cmrc2018_trial.json' - - SPLITS = { - 'train': META_INFO( - os.path.join('cmrc2018_train.json'), - '7fb714b479c7f40fbb16acabd7af0ede'), - 'dev': META_INFO( - os.path.join('cmrc2018_dev.json'), - '853b80709ff2d071f9fce196521b843c'), - 'trial': META_INFO( - os.path.join('cmrc2018_trial.json'), - '853b80709ff2d071f9fce196521b843c') - } - - def _get_data(self, root, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - - filename, data_hash = self.SPLITS[mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - if mode == 'train': - fullname = get_path_from_url(self.TRAIN_DATA_URL, default_root) - elif mode == 'dev': - fullname = get_path_from_url(self.DEV_DATA_URL, default_root) - elif mode == 'trial': - fullname = get_path_from_url(self.TRIAL_DATA_URL, default_root) - self.full_path = fullname - - -class DRCD(DuReaderRobust): - META_INFO = collections.namedtuple('META_INFO', ('file', 'md5')) - - DEV_DATA_URL = 'https://paddlenlp.bj.bcebos.com/datasets/DRCD/DRCD_dev.json' - TRAIN_DATA_URL = 'https://paddlenlp.bj.bcebos.com/datasets/DRCD/DRCD_training.json' - TEST_DATA_URL = 'https://paddlenlp.bj.bcebos.com/datasets/DRCD/DRCD_test.json' - - SPLITS = { - 'train': META_INFO( - os.path.join('DRCD_dev.json'), '7fb714b479c7f40fbb16acabd7af0ede'), - 'dev': META_INFO( - os.path.join('DRCD_training.json'), - '853b80709ff2d071f9fce196521b843c'), - 'test': META_INFO( - os.path.join('DRCD_test.json'), '853b80709ff2d071f9fce196521b843c') - } - - def _get_data(self, root, mode, **kwargs): - default_root = os.path.join(DATA_HOME, self.__class__.__name__) - - filename, data_hash = self.SPLITS[mode] - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), 
filename) - if not os.path.exists(fullname) or (data_hash and - not md5file(fullname) == data_hash): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'.format( - filename, self.__class__.__name__, default_root)) - if mode == 'train': - fullname = get_path_from_url(self.TRAIN_DATA_URL, default_root) - elif mode == 'dev': - fullname = get_path_from_url(self.DEV_DATA_URL, default_root) - elif mode == 'test': - fullname = get_path_from_url(self.TEST_DATA_URL, default_root) - self.full_path = fullname + if "is_impossible" in qa.keys(): + is_impossible = qa["is_impossible"] + + answer_starts = [ + answer["answer_start"] for answer in qa["answers"] + ] + answers = [ + answer["text"].strip() for answer in qa["answers"] + ] + + yield { + 'id': qas_id, + 'title': title, + 'context': context, + 'question': question, + 'answers': answers, + 'answer_starts': answer_starts, + 'is_impossible': is_impossible + } diff --git a/paddlenlp/datasets/translation.py b/paddlenlp/datasets/translation.py deleted file mode 100644 index 4e8888d34e694..0000000000000 --- a/paddlenlp/datasets/translation.py +++ /dev/null @@ -1,412 +0,0 @@ -import os -import io -import collections -import warnings - -from functools import partial -import numpy as np - -import paddle -from paddle.utils.download import get_path_from_url -from paddlenlp.data import Vocab, Pad -from paddlenlp.data.sampler import SamplerHelper -from paddlenlp.utils.env import DATA_HOME -from paddle.dataset.common import md5file - -__all__ = ['TranslationDataset', 'IWSLT15', 'WMT14ende'] - - -def sequential_transforms(*transforms): - def func(txt_input): - for transform in transforms: - txt_input = transform(txt_input) - return txt_input - - return func - - -def get_default_tokenizer(): - """Only support split tokenizer - """ - - def _split_tokenizer(x, delimiter=None): - if delimiter == "": - return list(x) - return x.split(delimiter) - - return _split_tokenizer - - -class TranslationDataset(paddle.io.Dataset): - """ - TranslationDataset, provide tuple (source and target) raw data. - - Args: - data(list): Raw data. It is a list of tuple or list, each sample of - data contains two element, source and target. - """ - META_INFO = collections.namedtuple('META_INFO', ('src_file', 'tgt_file', - 'src_md5', 'tgt_md5')) - SPLITS = {} - URL = None - MD5 = None - VOCAB_INFO = None - UNK_TOKEN = None - PAD_TOKEN = None - BOS_TOKEN = None - EOS_TOKEN = None - - def __init__(self, data): - self.data = data - - def __getitem__(self, idx): - return self.data[idx] - - def __len__(self): - return len(self.data) - - @classmethod - def get_data(cls, mode="train", root=None): - """ - Download dataset and read raw data. - Args: - mode(str, optional): Data mode to download. It could be 'train', - 'dev' or 'test'. Default: 'train'. - root (str, optional): Data directory of dataset. If not - provided, dataset will be saved to default directory - `~/.paddlenlp/datasets/machine_translation/`. If provided, md5 - check would be performed, and dataset would be downloaded in - default directory if failed. Default: None. - Returns: - list: Raw data, a list of tuple. - - Examples: - .. 
code-block:: python - - from paddlenlp.datasets import IWSLT15 - data_path = IWSLT15.get_data() - """ - root = cls._download_data(mode, root) - data = cls.read_raw_data(mode, root) - return data - - @classmethod - def _download_data(cls, mode="train", root=None): - """Download dataset""" - default_root = os.path.join(DATA_HOME, 'machine_translation', - cls.__name__) - src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[ - mode] - - filename_list = [ - src_filename, tgt_filename, cls.VOCAB_INFO[0], cls.VOCAB_INFO[1] - ] - fullname_list = [] - for filename in filename_list: - fullname = os.path.join(default_root, - filename) if root is None else os.path.join( - os.path.expanduser(root), filename) - fullname_list.append(fullname) - - data_hash_list = [ - src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3] - ] - for i, fullname in enumerate(fullname_list): - if not os.path.exists(fullname) or ( - data_hash_list[i] and - not md5file(fullname) == data_hash_list[i]): - if root is not None: # not specified, and no need to warn - warnings.warn( - 'md5 check failed for {}, download {} data to {}'. - format(filename, cls.__name__, default_root)) - path = get_path_from_url(cls.URL, default_root, cls.MD5) - return default_root - return root if root is not None else default_root - - @classmethod - def get_vocab(cls, root=None): - """ - Load vocab from vocab files. It vocab files don't exist, the will - be downloaded. - - Args: - root (str, optional): Data directory pf dataset. If not provided, - dataset will be save in `~/.paddlenlp/datasets/machine_translation`. - If provided, md5 check would be performed, and dataset would be - downloaded in default directory if failed. Default: None. - Returns: - tuple: Source vocab and target vocab. - - Examples: - .. code-block:: python - - from paddlenlp.datasets import IWSLT15 - (src_vocab, tgt_vocab) = IWSLT15.get_vocab() - """ - - root = cls._download_data(root=root) - src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO - src_file_path = os.path.join(root, src_vocab_filename) - tgt_file_path = os.path.join(root, tgt_vocab_filename) - - src_vocab = Vocab.load_vocabulary( - filepath=src_file_path, - unk_token=cls.UNK_TOKEN, - pad_token=cls.PAD_TOKEN, - bos_token=cls.BOS_TOKEN, - eos_token=cls.EOS_TOKEN) - - tgt_vocab = Vocab.load_vocabulary( - filepath=tgt_file_path, - unk_token=cls.UNK_TOKEN, - pad_token=cls.PAD_TOKEN, - bos_token=cls.BOS_TOKEN, - eos_token=cls.EOS_TOKEN) - return (src_vocab, tgt_vocab) - - @classmethod - def read_raw_data(cls, mode, root): - """Read raw data from data files - Args: - mode(str): Indicates the mode to read. It could be 'train', 'dev' or - 'test'. - root(str): Data directory of dataset. - Returns: - list: Raw data list. - """ - src_filename, tgt_filename, _, _ = cls.SPLITS[mode] - - def read_raw_files(corpus_path): - """Read raw files, return raw data""" - data = [] - (f_mode, f_encoding, endl) = ("r", "utf-8", "\n") - with io.open(corpus_path, f_mode, encoding=f_encoding) as f_corpus: - for line in f_corpus.readlines(): - data.append(line.strip()) - return data - - src_path = os.path.join(root, src_filename) - tgt_path = os.path.join(root, tgt_filename) - src_data = read_raw_files(src_path) - tgt_data = read_raw_files(tgt_path) - - data = [(src_data[i], tgt_data[i]) for i in range(len(src_data))] - return data - - @classmethod - def get_default_transform_func(cls, root=None): - """Get default transform function, which transforms raw data to id. 
- Args: - root(str, optional): Data directory of dataset. - Returns: - tuple: Two transform functions, for source and target data. - Examples: - .. code-block:: python - - from paddlenlp.datasets import IWSLT15 - transform_func = IWSLT15.get_default_transform_func() - """ - # Get default tokenizer - src_tokenizer = get_default_tokenizer() - tgt_tokenizer = get_default_tokenizer() - src_text_vocab_transform = sequential_transforms(src_tokenizer) - tgt_text_vocab_transform = sequential_transforms(tgt_tokenizer) - - (src_vocab, tgt_vocab) = cls.get_vocab(root) - src_text_transform = sequential_transforms(src_text_vocab_transform, - src_vocab) - tgt_text_transform = sequential_transforms(tgt_text_vocab_transform, - tgt_vocab) - return (src_text_transform, tgt_text_transform) - - -class IWSLT15(TranslationDataset): - """ - IWSLT15 Vietnames to English translation dataset. - - Args: - mode(str, optional): It could be 'train', 'dev' or 'test'. Default: - 'train'. - root(str, optional): If None, dataset will be downloaded in default - directory `~/paddlenlp/datasets/machine_translation/IWSLT15`. If - provided, md5 check would be performed and dataset would be - downloaded in default directory if failed. Default: None. - transform_func(callable, optional): If not None, it transforms raw data - to index data. Default: None. - Examples: - .. code-block:: python - - from paddlenlp.datasets import IWSLT15 - train_dataset = IWSLT15('train') - train_dataset, valid_dataset = IWSLT15.get_datasets(["train", "dev"]) - - """ - URL = "https://paddlenlp.bj.bcebos.com/datasets/iwslt15.en-vi.tar.gz" - SPLITS = { - 'train': TranslationDataset.META_INFO( - os.path.join("iwslt15.en-vi", "train.en"), - os.path.join("iwslt15.en-vi", "train.vi"), - "5b6300f46160ab5a7a995546d2eeb9e6", - "858e884484885af5775068140ae85dab"), - 'dev': TranslationDataset.META_INFO( - os.path.join("iwslt15.en-vi", "tst2012.en"), - os.path.join("iwslt15.en-vi", "tst2012.vi"), - "c14a0955ed8b8d6929fdabf4606e3875", - "dddf990faa149e980b11a36fca4a8898"), - 'test': TranslationDataset.META_INFO( - os.path.join("iwslt15.en-vi", "tst2013.en"), - os.path.join("iwslt15.en-vi", "tst2013.vi"), - "c41c43cb6d3b122c093ee89608ba62bd", - "a3185b00264620297901b647a4cacf38") - } - VOCAB_INFO = (os.path.join("iwslt15.en-vi", "vocab.en"), os.path.join( - "iwslt15.en-vi", "vocab.vi"), "98b5011e1f579936277a273fd7f4e9b4", - "e8b05f8c26008a798073c619236712b4") - UNK_TOKEN = '' - BOS_TOKEN = '' - EOS_TOKEN = '' - MD5 = 'aca22dc3f90962e42916dbb36d8f3e8e' - - def __init__(self, mode='train', root=None, transform_func=None): - data_select = ('train', 'dev', 'test') - if mode not in data_select: - raise TypeError( - '`train`, `dev` or `test` is supported but `{}` is passed in'. - format(mode)) - if transform_func is not None: - if len(transform_func) != 2: - raise ValueError("`transform_func` must have length of two for" - "source and target.") - # Download data and read data - self.data = self.get_data(mode=mode, root=root) - - if transform_func is not None: - self.data = [(transform_func[0](data[0]), - transform_func[1](data[1])) for data in self.data] - - -class WMT14ende(TranslationDataset): - """ - WMT14 English to German translation dataset. - - Args: - mode(str, optional): It could be 'train', 'dev' or 'test'. Default: 'train'. - root(str, optional): If None, dataset will be downloaded in - `~/.paddlenlp/datasets/machine_translation/WMT14ende/`. If provided, - md5 check would be performed, and dataset would be downloaded in - default directory if failed. 
Default: None. - transform_func(callable, optional): If not None, it transforms raw data - to index data. Default: None. - Examples: - .. code-block:: python - - from paddlenlp.datasets import WMT14ende - transform_func = WMT14ende.get_default_transform_func(root=root) - train_dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func) - """ - URL = "https://paddlenlp.bj.bcebos.com/datasets/WMT14.en-de.tar.gz" - SPLITS = { - 'train': TranslationDataset.META_INFO( - os.path.join("WMT14.en-de", "wmt14_ende_data_bpe", - "train.tok.clean.bpe.33708.en"), - os.path.join("WMT14.en-de", "wmt14_ende_data_bpe", - "train.tok.clean.bpe.33708.de"), - "c7c0b77e672fc69f20be182ae37ff62c", - "1865ece46948fda1209d3b7794770a0a"), - 'dev': TranslationDataset.META_INFO( - os.path.join("WMT14.en-de", "wmt14_ende_data_bpe", - "newstest2013.tok.bpe.33708.en"), - os.path.join("WMT14.en-de", "wmt14_ende_data_bpe", - "newstest2013.tok.bpe.33708.de"), - "aa4228a4bedb6c45d67525fbfbcee75e", - "9b1eeaff43a6d5e78a381a9b03170501"), - 'test': TranslationDataset.META_INFO( - os.path.join("WMT14.en-de", "wmt14_ende_data_bpe", - "newstest2014.tok.bpe.33708.en"), - os.path.join("WMT14.en-de", "wmt14_ende_data_bpe", - "newstest2014.tok.bpe.33708.de"), - "c9403eacf623c6e2d9e5a1155bdff0b5", - "0058855b55e37c4acfcb8cffecba1050"), - 'dev-eval': TranslationDataset.META_INFO( - os.path.join("WMT14.en-de", "wmt14_ende_data", - "newstest2013.tok.en"), - os.path.join("WMT14.en-de", "wmt14_ende_data", - "newstest2013.tok.de"), - "d74712eb35578aec022265c439831b0e", - "6ff76ced35b70e63a61ecec77a1c418f"), - 'test-eval': TranslationDataset.META_INFO( - os.path.join("WMT14.en-de", "wmt14_ende_data", - "newstest2014.tok.en"), - os.path.join("WMT14.en-de", "wmt14_ende_data", - "newstest2014.tok.de"), - "8cce2028e4ca3d4cc039dfd33adbfb43", - "a1b1f4c47f487253e1ac88947b68b3b8") - } - VOCAB_INFO = (os.path.join("WMT14.en-de", "wmt14_ende_data_bpe", - "vocab_all.bpe.33708"), - os.path.join("WMT14.en-de", "wmt14_ende_data_bpe", - "vocab_all.bpe.33708"), - "2fc775b7df37368e936a8e1f63846bb0", - "2fc775b7df37368e936a8e1f63846bb0") - UNK_TOKEN = "" - BOS_TOKEN = "" - EOS_TOKEN = "" - - MD5 = "a2b8410709ff760a3b40b84bd62dfbd8" - - def __init__(self, mode="train", root=None, transform_func=None): - if mode not in ("train", "dev", "test", "dev-eval", "test-eval"): - raise TypeError( - '`train`, `dev`, `test`, `dev-eval` or `test-eval` is supported but `{}` is passed in'. 
- format(mode)) - if transform_func is not None and len(transform_func) != 2: - if len(transform_func) != 2: - raise ValueError("`transform_func` must have length of two for" - "source and target.") - - self.data = WMT14ende.get_data(mode=mode, root=root) - self.mode = mode - if transform_func is not None: - self.data = [(transform_func[0](data[0]), - transform_func[1](data[1])) for data in self.data] - super(WMT14ende, self).__init__(self.data) - - -# For test, not API -def prepare_train_input(insts, pad_id): - src, src_length = Pad(pad_val=pad_id, ret_length=True)( - [inst[0] for inst in insts]) - tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)( - [inst[1] for inst in insts]) - return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis] - -batch_size_fn = lambda idx, minibatch_len, size_so_far, data_source: max(size_so_far, len(data_source[idx][0])) - -batch_key = lambda size_so_far, minibatch_len: size_so_far * minibatch_len - -if __name__ == '__main__': - batch_size = 4096 #32 - pad_id = 2 - - transform_func = IWSLT15.get_default_transform_func() - train_dataset = IWSLT15(transform_func=transform_func) - - key = (lambda x, data_source: len(data_source[x][0])) - - train_batch_sampler = SamplerHelper(train_dataset).shuffle().sort( - key=key, buffer_size=batch_size * 20).batch( - batch_size=batch_size, - drop_last=True, - batch_size_fn=batch_size_fn, - key=batch_key).shard() - - train_loader = paddle.io.DataLoader( - train_dataset, - batch_sampler=train_batch_sampler, - collate_fn=partial( - prepare_train_input, pad_id=pad_id)) - - for i, data in enumerate(train_loader): - print(data[1]) - print(paddle.max(data[1]) * len(data[1])) - print(len(data[1])) diff --git a/paddlenlp/datasets/experimental/wmt14ende.py b/paddlenlp/datasets/wmt14ende.py similarity index 100% rename from paddlenlp/datasets/experimental/wmt14ende.py rename to paddlenlp/datasets/wmt14ende.py diff --git a/paddlenlp/datasets/experimental/yahoo_answer_100k.py b/paddlenlp/datasets/yahoo_answer_100k.py similarity index 100% rename from paddlenlp/datasets/experimental/yahoo_answer_100k.py rename to paddlenlp/datasets/yahoo_answer_100k.py