Skip to content

Commit

Permalink
commit for diffuseq-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
Shansan Gong authored and Shansan Gong committed Oct 9, 2023
1 parent 9cddf4e commit cc4e9b4
Show file tree
Hide file tree
Showing 22 changed files with 2,068 additions and 378 deletions.
Binary file added .DS_Store
Binary file not shown.
30 changes: 29 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# <img src="img/logo.jpg" width="8%" alt="" align=center /> DiffuSeq

Official Codebase for [*__*DiffuSeq*__: Sequence to Sequence Text Generation With Diffusion Models*](https://arxiv.org/abs/2210.08933).
Official Codebase for [*__*DiffuSeq*__: Sequence to Sequence Text Generation With Diffusion Models*](https://arxiv.org/abs/2210.08933) and
[*__*DiffuSeq-v2*__: Bridging Discrete and Continuous Text Spaces for Accelerated Seq2Seq Diffusion Models*](https://arxiv.org/abs/).

<p align = "center">
<img src="img/diffuseq-process.png" width="95%" alt="" align=center />
Expand All @@ -9,6 +10,13 @@ Official Codebase for [*__*DiffuSeq*__: Sequence to Sequence Text Generation Wit
The diffusion process of our conditional diffusion language model DiffuSeq.
</p>

<p align = "center">
<img src="img/diffuseq-v2.png" width="40%" alt="" align=center />
</p>
<p align = "center">
The diffusion process of accelerated DiffuSeq.
</p>

## Highlights
- Our proposed __*DiffuSeq*__ as a conditional language model is trained end-to-end in a classifier-free manner.
- We establish a theoretical
Expand All @@ -27,7 +35,11 @@ sequence-to-sequence learning paradigm.
<img src="img/result-2.png" width=80%" alt="" align=center />
</p>

Update: Our enhanced version effectively accelerates the training convergence by 4x and generates samples of similar quality 800x faster, rendering it significantly closer to practical application.

<p align = "center">
<img src="img/result-3.png" width=80%" alt="" align=center />
</p>

## Setup:
The code is based on PyTorch and HuggingFace `transformers`.
Expand Down Expand Up @@ -59,6 +71,14 @@ Arguments explanation:

It will take 2 more days to train a __*DiffuSeq*__ model on 4 NVIDIA A100 80G GPUs for QG and QQP, and the training steps should be increased accordingly along with the size of the training set. To reproduce the results of Table 1 in our paper, we suggest the following configuration for each dataset when training.

### Update:
Additional argument:
- ```--learned_mean_embed```: set whether to use the learned soft absorbing state.
- ```--denoise```: set whether to add discrete noise
- ```--use_fp16```: set whether to use mixed precision training
- ```--denoise_rate```: set the denoise rate, with 0.5 as the default
It only take around 11 hours to train a model on 2 NVIDIA A100 80G GPUs for QQP.

```
python -m torch.distributed.launch --nproc_per_node=4 --master_port=12233 --use_env run_train.py --diff_steps 2000 --lr 0.0001 --learning_steps 50000 --save_interval 10000 --seed 102 --noise_schedule sqrt --hidden_dim 128 --bsz 2048 --dataset qqp --data_dir {datasets/QQP} --vocab bert --seq_len 128 --schedule_sampler lossaware --notes qqp
Expand All @@ -78,6 +98,13 @@ bash run_decode.sh
```
To reproduce the results of Table 1 in our paper, we suggest the size of MBR candidate set to be 10 (run 10 times using different seeds). Empirically, larger size can achieve higher BLEU score. For diversity metrics, the size of MBR candidate set is 3 when computing.

## Speed-up Decoding
We customize the implementation of [DPM-Solver++](https://github.com/LuChengTHU/dpm-solver) to DiffuSeq to accelerate its sampling speed.
```bash
cd scripts
bash run_decode_solver.sh
```

## Evaluation & MBR
You need to specify the folder of decoded texts. This folder should contain the decoded files from the same model but sampling with different random seeds. If ```mbr``` is not attached, we will compute the diversity score from the files in the folder, otherwise we will do MBR decoding:
```bash
Expand All @@ -87,6 +114,7 @@ python eval_seq2seq.py --folder ../{your-path-to-outputs} --mbr
Note: if you want to use this evaluation script for output files from other models, please make sure the same line from these output files refers to the same piece of data. Otherwise the diversity score could be incorrect.

## Update
- Update 10 Oct 2023: We update the DiffuSeq-v2, targeting the training/sampling speed up. Details in new branch [`diffuseq-v2`](https://github.com/Shark-NLP/DiffuSeq/tree/diffuseq-v2).
- Update 22 May 2023: We prepare the checkpoint and sampling results for remaining tasks in this [link](https://drive.google.com/drive/folders/1lHPp-T-ytp-YVptiokeYK-Lth48EGQ12?usp=sharing).
- Update 28 Nov 2022: We prepare the checkpoint and sampling results of 10 seeds for QQP dataset in this [link](https://drive.google.com/drive/folders/1vnhJIUqPQva_x_sH2h5a0moCc1NYmEpr?usp=sharing).
- Update 14 Feb 2023: We update the evaluation scripts and camera ready version of the paper.
Expand Down
15 changes: 13 additions & 2 deletions basic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ def create_model_and_diffusion(
rescale_learned_sigmas,
use_kl,
notes,
learned_mean_embed=False,
rejection_rate=0.0,
denoise=False,
denoise_rate=0.2,
device="",
**kwargs,
):
model = TransformerNetModel(
Expand All @@ -131,7 +136,8 @@ def create_model_and_diffusion(
dropout=dropout,
config_name=config_name,
vocab_size=vocab_size,
init_pretrained=use_plm_init
init_pretrained=use_plm_init,
learned_mean_embed=learned_mean_embed,
)

betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
Expand All @@ -147,7 +153,12 @@ def create_model_and_diffusion(
learn_sigmas = learn_sigma,
sigma_small = sigma_small,
use_kl = use_kl,
rescale_learned_sigmas=rescale_learned_sigmas
rescale_learned_sigmas=rescale_learned_sigmas,
rejection_rate=rejection_rate,
denoise=denoise,
denoise_rate=denoise_rate,
device=device,
max_T = diffusion_steps,
)

return model, diffusion
Expand Down
11 changes: 9 additions & 2 deletions diffuseq/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"notes": "folder-notes",
"data_dir": "data-dir",
"dataset": "dataset-name",
"data_split_num": 0,
"checkpoint_path": "checkpoint-path",
"seq_len": 128,
"hidden_t_dim": 128,
Expand All @@ -35,5 +36,11 @@
"rescale_timesteps": true,
"rescale_learned_sigmas": false,
"sigma_small": false,
"emb_scale_factor": 1.0
}
"emb_scale_factor": 1.0,
"learned_mean_embed": false,
"denoise": false,
"denoise_rate": 0.2,
"rejection_rate": 0.0,
"reg_rate": 0.01,
"device": ""
}
Loading

0 comments on commit cc4e9b4

Please sign in to comment.