Releasing code and checkpoints officially
yuxiang-wu committed Dec 5, 2022
1 parent 6519deb commit 6b6e30a
Showing 41 changed files with 9,440 additions and 2 deletions.
156 changes: 156 additions & 0 deletions .gitignore
@@ -0,0 +1,156 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# (jimmycode)
.python_history
.idea
pretrained_models
data/*
logs/
outputs/
runs/
results/
checkpoints/
cached_models/
paq_models/
.DS_Store
etc/
cached_outputs/
evaluate/
lightning_logs/
wandb/
artefacts
Icon*
.netrc
/kv_scripts/
nlu_dataset.py
nlu_trainer.py
models

134 changes: 132 additions & 2 deletions README.md
@@ -1,2 +1,132 @@
# EMAT: An Efficient Memory-Augmented Transformer for Knowledge-Intensive NLP Tasks

## Installation and Setup

```shell
# create a conda environment
conda create -n emat -y python=3.8 && conda activate emat

# install pytorch
pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html # GPU
pip install torch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 # CPU

# install transformers
pip install transformers==4.14.1

# install faiss
pip install faiss-gpu==1.7.1.post3 # GPU
pip install faiss-cpu==1.7.1.post3 # CPU

# install dependencies
pip install -r requirements.txt

# install this package for development
pip install -e .
```

## Download datasets

NaturalQuestions, WebQuestions, TriviaQA, WoW_KILT, and ELI5_KILT data:

link: https://pan.baidu.com/s/1MwPzVLqZqslqCpWAtPVZ-Q

code: tynj

Download PAQ data from: https://github.com/facebookresearch/PAQ
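
The `NQ-open.*.jsonl` files passed via `--task_train_data`/`--task_dev_data` in the commands below are JSON-Lines question-answer records. A minimal sketch of one record (the field names `question` and `answer` are assumptions for illustration, not confirmed by this repo):

```shell
# hypothetical NQ-open-style record; field names are assumed, check your downloaded files
cat > /tmp/nq_sample.jsonl <<'EOF'
{"question": "who wrote the declaration of independence", "answer": ["Thomas Jefferson"]}
EOF

# peek at the first record
python -c "import json; r = json.loads(open('/tmp/nq_sample.jsonl').readline()); print(r['question'], '->', r['answer'][0])"
# → who wrote the declaration of independence -> Thomas Jefferson
```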


## Run Interactive Script

Before running the scripts below, prepare the key-value memory embeddings, the retrieval index, and the PAQ data.
See [Start](#Start) to build your key-value memory and index.
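
Collected from the flags used in the example commands, a typical layout of the prepared artefacts looks like this (paths taken from the NQ examples in this README; adjust to wherever your build step writes them):

```text
EMAT_ckpt/FKSV-NQ/                       # fine-tuned checkpoint  (--model_name_or_path)
data/PAQ_L1                              # QA pairs to retrieve from  (--qas_to_retrieve_from)
embedding_and_faiss/PAQ_L1_from_nq_ckpt/
├── embedding_index.pt                   # torch embedding index  (--embedding_index)
├── key_memory.pt                        # key memory  (--key_memory_path)
├── value_memory.pt                      # value memory  (--value_memory_path)
└── key.sq8hnsw.80n80efc.faiss           # faiss index  (--faiss_index_path)
```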


NQ: use torch-embedding as retrieval index:
```bash
python demo.py \
--model_name_or_path="./EMAT_ckpt/FKSV-NQ" \
--qas_to_retrieve_from="./data/PAQ_L1" \
--test_task="nq" \
--task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
--task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
  --embedding_index="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/embedding_index.pt" \
--key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key_memory.pt" \
--value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/value_memory.pt"
```

NQ: use faiss as retrieval index:
```bash
python demo.py \
--model_name_or_path="./EMAT_ckpt/FKSV-NQ" \
--qas_to_retrieve_from="./data/PAQ_L1" \
--test_task="nq" \
--task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
--task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
--use_faiss \
--faiss_index_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key.sq8hnsw.80n80efc.faiss" \
--key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key_memory.pt" \
--value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/value_memory.pt"
```

Use the SKSV model with parallel faiss search:
```bash
python demo.py \
--model_name_or_path="./EMAT_ckpt/SKSV-NQ" \
--qas_to_retrieve_from="./data/PAQ_L1" \
--test_task="nq" \
--task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
--task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
--use_faiss \
--faiss_index_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/key.sq8hnsw.80n80efc.faiss" \
--key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/key_memory.pt" \
--value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/value_memory.pt"
```

Run Wizard-of-Wikipedia Dialogue:
```bash
python demo.py \
--model_name_or_path="./EMAT_ckpt/FKSV-WQ/" \
--qas_to_retrieve_from="./tmp/PAQ_L1.pkl" \
--test_task="wow_kilt" \
--embedding_index_path="./embedding_and_faiss/debug_from_wow_ckpt/embedding_index.pt" \
--key_memory_path="./embedding_and_faiss/PAQ_L1_from_wow_ckpt/key_memory.pt" \
--value_memory_path="./embedding_and_faiss/PAQ_L1_from_wow_ckpt/value_memory.pt" \
--inference_data_path="./annotated_datasets/wizard_of_wikipedia/wow-test_without_answers-kilt.jsonl.txt"
```

## Start
<span id="Start"></span>

### 1. Pre-training

Pre-train EMAT-FKSV: `bash pretrain_scripts/pretrain_emat.sh`

Pre-train EMAT-SKSV: `bash pretrain_scripts/pretrain_sksv_emat.sh`

### 2. Fine-tuning

Fine-tune on NQ: `bash scripts/nq_train_with_paql1.sh`

Fine-tune on TQ: `bash scripts/tq_train_with_paql1.sh`

Fine-tune on WQ: `bash scripts/wq_train_with_paql1.sh`

Fine-tune on WoW: `bash kilt_scripts/wow_train.sh`

Fine-tune on ELI5: `bash kilt_scripts/eli5_train.sh`


### 3. Evaluation

Evaluate NQ/TQ/WQ: `bash scripts/nq_eval.sh`; switch `DATA_NAME` to evaluate a different dataset.

Evaluate WoW/ELI5: `bash kilt_scripts/eval_kilt.sh`. You can upload the output prediction file to http://kiltbenchmark.com/ to get evaluation results.

### 4. Embed PAQ using fine-tuned NQ model and build FAISS index:
```bash
bash embed_scripts/nq_embed_paq_and_build_faiss.sh
```

### 5. Inference Speed
Test inference speed with `inference_with_faiss.py`.
