Skip to content

Commit

Permalink
Update VAE example (PaddlePaddle#145)
Browse files Browse the repository at this point in the history
* update vae example

* update readme

* delete infer collate func

* build vocab using train file only

* highlight some chars in dataset readme
  • Loading branch information
LiuChiachi authored Mar 15, 2021
1 parent e9e0af4 commit f7ea662
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 365 deletions.
3 changes: 2 additions & 1 deletion docs/datasets.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# PaddleNLP Datasets API

PaddleNLP提供了以下数据集的快速读取API,实际使用时请根据需要添加splits信息
PaddleNLP提供了以下数据集的快速读取API,实际使用时请根据需要**添加splits信息**

## 阅读理解

Expand Down Expand Up @@ -62,3 +62,4 @@ PaddleNLP提供了以下数据集的快速读取API,实际使用时请根据
| 数据集名称 | 简介 | 调用方法 |
| ---- | --------- | ------ |
| [PTB](http://www.fit.vutbr.cz/~imikolov/rnnlm/) | Penn Treebank Dataset | `paddlenlp.datasets.load_dataset('ptb')`|
| [Yahoo Answer 100k](https://arxiv.org/pdf/1702.08139.pdf) | 从Yahoo Answer采样100K| `paddlenlp.datasets.load_dataset('yahoo_answer_100k')`|
23 changes: 7 additions & 16 deletions examples/text_generation/vae-seq2seq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
├── README.md # 文档
├── args.py # 训练、预测以及模型参数配置程序
├── data.py # 数据读入程序
├── download.py # 数据下载程序
├── train.py # 训练主程序
├── predict.py # 预测主程序
└── model.py # VAE模型组网部分,以及Metric等
```

## 简介

本目录下此范例模型的实现,旨在展示如何用Paddle构建用于文本生成的VAE示例,其中LSTM作为编码器和解码器。分别对官方PTB数据和yahoo数据集进行训练
本目录下此范例模型的实现,旨在展示如何用Paddle构建用于文本生成的VAE示例,其中LSTM作为编码器和解码器。分别对PTB数据集和Yahoo Answer(采样100k)数据集进行训练

关于VAE的详细介绍参照: [(Bowman et al., 2015) Generating Sentences from a Continuous Space](https://arxiv.org/pdf/1511.06349.pdf)

Expand All @@ -24,16 +23,8 @@

PTB数据集由华尔街日报的文章组成,包含929k个训练tokens,词汇量为10k。下载地址为: https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz。

Yahoo数据集来自[(Yang et al., 2017) Improved Variational Autoencoders for Text Modeling using Dilated Convolutions](https://arxiv.org/pdf/1702.08139.pdf),该数据集从原始Yahoo Answer数据中采样100k个文档,数据集的平均文档长度为78,词汇量为200k。下载地址为:https://paddlenlp.bj.bcebos.com/datasets/yahoo-answer-100k.tar.gz
Yahoo数据集来自[(Yang et al., 2017) Improved Variational Autoencoders for Text Modeling using Dilated Convolutions](https://arxiv.org/pdf/1702.08139.pdf),该数据集从原始Yahoo Answer数据中采样100k个文档,数据集的平均文档长度为78,词汇量为200k。下载地址为:https://paddlenlp.bj.bcebos.com/datasets/yahoo-answer-100k.tar.gz,运行本例程序后,数据集会自动下载到`~/.paddlenlp/datasets/YahooAnswer100k`目录下。

### 数据获取

```
python download.py --task ptb # 下载ptb数据集
python download.py --task yahoo # 下载yahoo数据集
```

## 模型训练

Expand All @@ -47,7 +38,7 @@ python train.py \
--max_grad_norm 5.0 \
--dataset ptb \
--model_path ptb_model\
--use_gpu True \
--device gpu \
--max_epoch 50 \
```
Expand All @@ -62,7 +53,7 @@ python -m paddle.distributed.launch train.py \
--max_grad_norm 5.0 \
--dataset ptb \
--model_path ptb_model \
--use_gpu True \
--device gpu \
--max_epoch 50 \
```
Expand All @@ -79,7 +70,7 @@ python -m paddle.distributed.launch train.py \
--max_grad_norm 5.0 \
--dataset yahoo \
--model_path yahoo_model \
--use_gpu True \
--device gpu \
--max_epoch 50 \
```
Expand All @@ -98,7 +89,7 @@ python predict.py \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--dataset ptb \
--use_gpu True \
--device gpu \
--infer_output_file infer_output.txt \
--init_from_ckpt ptb_model/49 \
Expand All @@ -114,7 +105,7 @@ python predict.py \
--hidden_size 550 \
--max_grad_norm 5.0 \
--dataset yahoo \
--use_gpu True \
--device gpu \
--infer_output_file infer_output.txt \
--init_from_ckpt yahoo_model/49 \
Expand Down
8 changes: 4 additions & 4 deletions examples/text_generation/vae-seq2seq/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ def parse_args():
"--beam_size", type=int, default=1, help="Beam size for Beam search.")

parser.add_argument(
'--use_gpu',
type=eval,
default=False,
help='Whether to use gpu [True|False].')
"--device",
default="gpu",
choices=["gpu", "cpu", "xpu"],
help="Device selected for inference.")

parser.add_argument(
"--warm_up",
Expand Down
Loading

0 comments on commit f7ea662

Please sign in to comment.