Source code and dataset for the ACL 2023 paper "Augmentation-Adapted Retriever Improves Generalization of Language Models as Generic Plug-In".
The code uses Python 3.9.13 and requires the CUDA 10.2 toolkit. Install the Python dependencies and NVIDIA Apex with the following commands:
pip install -r requirements.txt
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
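Before building Apex, it can help to confirm that the interpreter and PyTorch match the versions stated above. The sketch below is our own helper (check_versions is a hypothetical name, not part of the repo). It compares major.minor of Python against 3.9 and uses torch.version.cuda, the CUDA version PyTorch was built with, as a proxy for the required 10.2 toolkit. Apex's --cuda_ext build compares the toolkit's nvcc version against that value, so a mismatch here usually means the Apex build will fail.

```python
import sys
from typing import List, Optional, Sequence


def check_versions(python_version: Sequence[int], torch_cuda: Optional[str]) -> List[str]:
    """Return human-readable problems; an empty list means the versions look right."""
    problems = []
    # The README pins Python 3.9.13; only major.minor is compared here.
    if tuple(python_version[:2]) != (3, 9):
        problems.append(f"Python {python_version[0]}.{python_version[1]} found, 3.9 expected")
    # torch.version.cuda is None for CPU-only PyTorch builds.
    if torch_cuda is None:
        problems.append("PyTorch was built without CUDA support")
    elif not torch_cuda.startswith("10.2"):
        problems.append(f"PyTorch built with CUDA {torch_cuda}, 10.2 expected")
    return problems


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        for problem in check_versions(sys.version_info, torch.version.cuda):
            print("WARNING:", problem)
```

This only inspects versions; it does not check that nvcc itself is on the PATH.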
DeepSpeed contains some bugs that require small modifications to the installed package. Specifically, you need to modify two lines of code in ${PATH_TO_PYTHON_SITE_PACKAGE}/deepspeed/runtime/zero/stage1.py and ${PATH_TO_PYTHON_SITE_PACKAGE}/deepspeed/runtime/engine.py. We provide the modified files tools/ds_fix/stage1.py and tools/ds_fix/engine.py in our repo, so you can simply replace ${PATH_TO_PYTHON_SITE_PACKAGE}/deepspeed/runtime/zero/stage1.py with our stage1.py and ${PATH_TO_PYTHON_SITE_PACKAGE}/deepspeed/runtime/engine.py with our engine.py.
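If you would rather not locate the site-packages directory by hand, the replacement can be scripted. The sketch below is our own helper, not part of the repo (find_deepspeed_dir and apply_ds_fix are hypothetical names). It uses importlib.util.find_spec to find the installed deepspeed package without importing it, and it keeps a .orig backup of each file it overwrites, which the instructions above do not require. It assumes the DeepSpeed version from requirements.txt is installed, since the patched files are written for that version.

```python
import importlib.util
import shutil
from pathlib import Path
from typing import List

# Patched file in this repo -> file it replaces, relative to the deepspeed package.
FIXES = {
    Path("tools/ds_fix/stage1.py"): Path("runtime/zero/stage1.py"),
    Path("tools/ds_fix/engine.py"): Path("runtime/engine.py"),
}


def find_deepspeed_dir() -> Path:
    """Locate the installed deepspeed package directory without importing it."""
    spec = importlib.util.find_spec("deepspeed")
    if spec is None or spec.origin is None:
        raise RuntimeError("deepspeed is not installed in this environment")
    return Path(spec.origin).parent  # origin is .../deepspeed/__init__.py


def apply_ds_fix(repo_root: Path, deepspeed_dir: Path) -> List[Path]:
    """Back up each original file as *.orig, then overwrite it with the patched copy."""
    replaced = []
    for src_rel, dst_rel in FIXES.items():
        src = repo_root / src_rel
        dst = deepspeed_dir / dst_rel
        if not dst.exists():
            raise FileNotFoundError(f"{dst} not found; is the expected DeepSpeed version installed?")
        backup = dst.with_name(dst.name + ".orig")
        if not backup.exists():  # never overwrite an earlier backup of the pristine file
            shutil.copy2(dst, backup)
        shutil.copy2(src, dst)
        replaced.append(dst)
    return replaced
```

Run from the repo root as apply_ds_fix(Path("."), find_deepspeed_dir()); restoring the .orig files undoes the change.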
We provide the preprocessed data for MMLU (target task 1), PopQA (target task 2), and MSMARCO QA (source task) via this link.
Please download it and unzip it in the root directory of this repo. After that, you will see the data/ folder.
You can generate the MSMARCO corpus with the following commands:
wget --no-check-certificate https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz
tar -zxf marco.tar.gz
rm -rf marco.tar.gz
cd marco
join -t "$(echo -en '\t')" -e '' -a 1 -o 1.1 2.2 1.2 <(sort -k1,1 para.txt) <(sort -k1,1 para.title.txt) | sort -k1,1 -n > corpus.tsv
Then move corpus.tsv into the data/msmarco/ folder.
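The join pipeline above pairs each passage with its title and writes tab-separated id, title, passage rows. A Python equivalent can be sketched as follows, assuming para.txt holds tab-separated "id, passage" lines and para.title.txt holds tab-separated "id, title" lines with numeric ids; build_marco_corpus is a hypothetical helper, not part of the repo. Like join -a 1 -e '', it keeps passages that have no title, and like the final sort -k1,1 -n it orders rows by numeric id.

```python
from pathlib import Path


def build_marco_corpus(para_path: Path, title_path: Path, out_path: Path) -> int:
    """Write id<TAB>title<TAB>passage rows sorted by numeric id; return the row count."""
    titles = {}
    with open(title_path, encoding="utf-8") as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            titles[fields[0]] = fields[1] if len(fields) > 1 else ""

    rows = []
    with open(para_path, encoding="utf-8") as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            pid = fields[0]
            text = fields[1] if len(fields) > 1 else ""
            # Mirrors `join -a 1 -e ''`: unmatched passages get an empty title.
            rows.append((pid, titles.get(pid, ""), text))

    rows.sort(key=lambda r: int(r[0]))  # mirrors `sort -k1,1 -n`
    with open(out_path, "w", encoding="utf-8", newline="") as f:
        for row in rows:
            f.write("\t".join(row) + "\n")
    return len(rows)
```

This version holds the whole corpus in memory, so for the full MSMARCO collection (several million passages) the shell pipeline is the lighter option; the sketch is mainly useful for checking what corpus.tsv is expected to contain.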
You can generate the KILT-Wikipedia corpus via the following command:
python tools/process_kilt_wikipedia.py
The original LMs are obtained from HuggingFace (e.g., flan-t5-base). Before running the code, please use the transformation script to convert the original pytorch_model.bin checkpoint into the format used by our DeepSpeed + Megatron framework:
mkdir -p checkpoints/flan-t5-base/t5-MP1
python tools/transform.py \
    --hf_path ${PATH_TO_PYTORCH_MODEL_BIN} \
    --save_path "./checkpoints/flan-t5-base/t5-MP1" \
    --half
For the retriever backbones, please download t5-ance and contriever into the checkpoints/ folder.
All scripts are in the scripts/ directory.
Before running any script, please first set WORKING_DIR to the path of this repo's root directory.
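Editing WORKING_DIR by hand in every script is error-prone, so it can be rewritten in bulk. The sketch below assumes each script assigns the variable on its own line (for example WORKING_DIR="/path/to/repo", optionally prefixed with export); we have not verified this layout against every script, so review the returned list of changed files. set_working_dir is a hypothetical helper, not part of the repo.

```python
import re
from pathlib import Path
from typing import List

# Assumed form: a line starting with WORKING_DIR=... (optionally "export WORKING_DIR=...").
_ASSIGN = re.compile(r"^([ \t]*(?:export[ \t]+)?WORKING_DIR=).*$", re.MULTILINE)


def set_working_dir(scripts_dir: Path, repo_root: Path) -> List[Path]:
    """Rewrite WORKING_DIR=... in every *.sh under scripts_dir; return the files changed."""
    new_value = f'"{repo_root.resolve()}"'
    changed = []
    for script in sorted(scripts_dir.rglob("*.sh")):
        text = script.read_text()
        # A function replacement avoids backslash escapes in paths being interpreted.
        new_text = _ASSIGN.sub(lambda m: m.group(1) + new_value, text)
        if new_text != text:
            script.write_text(new_text)
            changed.append(script)
    return changed
```

From the repo root, set_working_dir(Path("scripts"), Path(".")) updates every script at once.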
If the checkpoint is loaded successfully, the log printed to stdout should contain messages like "successfully loaded /path-to-checkpoint/t5-MP1/mp_rank_00_model_states.pt". Otherwise, the message "WARNING: could not find the metadata file /***/latest_checkpointed_iteration.txt will not load any checkpoints and will start from random" will be displayed. Note that even when the model is loaded successfully, you will see messages like "The following zero checkpoints paths are missing: ['/path-to-checkpoint/200000/zero_pp_rank_0_mp_rank_00_optim_states.pt',...", which mean that the optimizer states are not loaded. This DOES NOT affect model inference, and you can safely ignore it.
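When runs are launched in bulk, the check above can be automated by scanning the captured stdout for the two messages quoted above. The sketch below is our own helper (checkpoint_status is a hypothetical name); it matches only substrings of those messages and deliberately ignores the "zero checkpoints paths are missing" notice, which only concerns optimizer states.

```python
from typing import Iterable

LOADED = "successfully loaded"
NOT_LOADED = "could not find the metadata file"


def checkpoint_status(log_lines: Iterable[str]) -> str:
    """Classify a run's stdout as 'loaded', 'not_loaded' or 'unknown'."""
    loaded = not_loaded = False
    for line in log_lines:
        if NOT_LOADED in line:
            not_loaded = True
        elif LOADED in line:
            loaded = True
    # A missing-metadata warning means the model started from random weights.
    if not_loaded:
        return "not_loaded"
    return "loaded" if loaded else "unknown"
```

For example, with a saved log (run.log is a hypothetical filename): with open("run.log") as f: print(checkpoint_status(f)).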
Running the following scripts reproduces our main results for AAR initialized from ANCE on MMLU.
- For AAR initialized from Contriever, replace "ance" with "contriever" in the scripts.
- For AAR trained with multi-task KILT, we provide the ANCE checkpoint here and the Contriever checkpoint here.
- For the unassisted versions of the LMs, set passage_num to 0.
- For PopQA, set DATA_NAMES to "popQA_kilt_wikipedia_ra_ance_aar".
For Flan-T5-Base:
bash scripts/LM/zs_base.sh
For Flan-T5-Large:
bash scripts/LM/zs_large.sh
For Flan-T5-XL:
bash scripts/LM/zs_xl.sh
For InstructGPT:
bash scripts/LM/zs_gpt.sh
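The script variants described in the list above can also be generated programmatically instead of being edited by hand. This is a minimal sketch under assumptions we have not checked against the scripts: passage_num and DATA_NAMES appear either as a shell assignment at the start of a line (NAME=value) or as a command-line flag (--NAME value). The "ance" replacement skips occurrences embedded in longer words such as "performance", but it cannot tell which remaining occurrences refer to the retriever, so diff the output before using it. make_variant is a hypothetical helper, not part of the repo.

```python
import re
from typing import Dict, Optional

# "ance" not embedded in a longer alphabetic word: matches "ra_ance_aar",
# leaves "performance" alone.
_ANCE = re.compile(r"(?<![A-Za-z])ance(?![A-Za-z])")


def make_variant(script: str, use_contriever: bool = False,
                 overrides: Optional[Dict[str, str]] = None) -> str:
    """Return a copy of a script's text with the retriever and/or variables changed."""
    if use_contriever:
        script = _ANCE.sub("contriever", script)
    for name, value in (overrides or {}).items():
        key = re.escape(name)
        # Shell assignment at line start: NAME=old  ->  NAME=value
        script = re.sub(rf"^([ \t]*{key}=)\S*", lambda m: m.group(1) + value,
                        script, flags=re.MULTILINE)
        # Command-line flag: --NAME old  ->  --NAME value (stops before closing quotes)
        script = re.sub(rf"(--{key}[ \t=]+)[^\s\"']+", lambda m: m.group(1) + value, script)
    return script
```

For example, make_variant(text, overrides={"passage_num": "0"}) produces an unassisted-LM variant, and make_variant(text, overrides={"DATA_NAMES": '"popQA_kilt_wikipedia_ra_ance_aar"'}) a PopQA variant; write the result to a new .sh file rather than overwriting the original.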
To gather the results for the four MMLU categories:
python tools/gather_result_MMLU.py --task_name mmlu_msmarco_ra_ance_aar --method_name flan-t5-base --score 41.70
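MMLU groups its 57 subjects into four categories: STEM, Humanities, Social Sciences, and Other. The sketch below illustrates one way per-category scores can be computed, by pooling correct answers over all questions in a category. It is an assumption for illustration only: we do not know whether tools/gather_result_MMLU.py pools questions or averages per-subject scores, and the mapping shows just four example subjects. category_accuracy and SUBJECT_CATEGORY are hypothetical names.

```python
from collections import defaultdict
from typing import Dict, Tuple

# Illustrative excerpt of the MMLU subject -> category mapping (57 subjects in total).
SUBJECT_CATEGORY = {
    "abstract_algebra": "STEM",
    "high_school_european_history": "Humanities",
    "high_school_macroeconomics": "Social Sciences",
    "clinical_knowledge": "Other",
}


def category_accuracy(per_subject: Dict[str, Tuple[int, int]]) -> Dict[str, float]:
    """per_subject maps subject -> (num_correct, num_questions); returns accuracy in %."""
    correct: Dict[str, int] = defaultdict(int)
    total: Dict[str, int] = defaultdict(int)
    for subject, (n_correct, n_questions) in per_subject.items():
        category = SUBJECT_CATEGORY[subject]
        correct[category] += n_correct
        total[category] += n_questions
    # Pooled (question-weighted) accuracy per category.
    return {c: 100.0 * correct[c] / total[c] for c in total}
```

Pooling weights large subjects more heavily than a plain mean of subject scores would, so the two choices can give slightly different category numbers.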
We use ANCE as the retriever backbone in our scripts; you can change model_name_or_path to use your own retriever backbone.
First, prepare the LM-preferred and human-preferred documents for augmentation-adapted training:
bash scripts/Retriever/pre_pipeline.sh
Then train the augmentation-adapted retriever (AAR):
bash scripts/Retriever/train.sh
Finally, retrieve the relevant documents for the target tasks using AAR:
bash scripts/Retriever/post_pipeline.sh
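The three retriever stages must run in order, and each depends on the previous one finishing cleanly. A small driver can enforce that; the sketch below is our own (run_pipeline is a hypothetical name, not part of the repo). It runs each script with bash and stops at the first non-zero exit status, and the runner parameter lets you substitute a dry-run or a cluster launcher.

```python
import subprocess
from typing import Callable, List, Sequence

STAGES = [
    "scripts/Retriever/pre_pipeline.sh",   # LM-preferred and human-preferred documents
    "scripts/Retriever/train.sh",          # train AAR
    "scripts/Retriever/post_pipeline.sh",  # retrieve documents for the target tasks
]


def _run_with_bash(cmd: List[str]) -> int:
    return subprocess.run(cmd).returncode


def run_pipeline(stages: Sequence[str] = STAGES,
                 runner: Callable[[List[str]], int] = _run_with_bash) -> List[str]:
    """Run each stage in order, raising on the first failure; return finished stages."""
    finished = []
    for stage in stages:
        code = runner(["bash", stage])
        if code != 0:
            raise RuntimeError(f"{stage} exited with status {code}")
        finished.append(stage)
    return finished
```

Run it from the repo root after setting WORKING_DIR; it behaves like chaining the three commands with && in the shell.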
Please cite our paper if you use AAR in your work:
@inproceedings{yu2023augmentation,
title={Augmentation-Adapted Retriever Improves Generalization of Language Models as Generic Plug-In},
author={Yu, Zichun and Xiong, Chenyan and Yu, Shi and Liu, Zhiyuan},
booktitle={Proceedings of ACL},
year={2023}
}