diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..75691cb
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,156 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# (jimmycode)
+.python_history
+.idea
+pretrained_models
+data/*
+logs/
+outputs/
+runs/
+results/
+checkpoints/
+cached_models/
+paq_models/
+.DS_Store
+etc/
+cached_outputs/
+evaluate/
+lightning_logs/
+wandb/
+artefacts
+Icon*
+.netrc
+/kv_scripts/
+nlu_dataset.py
+nlu_trainer.py
+models
+
diff --git a/README.md b/README.md
index da90d98..f46cab8 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,132 @@
-# EMAT
-Efficient Memory-Augmented Transformers
+# EMAT: An Efficient Memory-Augmented Transformer for Knowledge-Intensive NLP Tasks
+
+## Installation and Setup
+
+```shell
+# create a conda environment
+conda create -n emat -y python=3.8 && conda activate emat
+
+# install pytorch
+pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html # GPU
+pip install torch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 # CPU
+
+# install transformers
+pip install transformers==4.14.1
+
+# install faiss
+pip install faiss-gpu==1.7.1.post3 # GPU
+pip install faiss-cpu==1.7.1.post3 # CPU
+
+# install dependencies
+pip install -r requirements.txt
+
+# install this package for development
+pip install -e .
+```
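+
+Optionally, check that the core dependencies import correctly (a quick sanity check, not part of the official setup):
+
+```shell
+python -c "import torch, transformers, faiss; print(torch.__version__, transformers.__version__)"
+```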
+
+## Download datasets
+
+Download the annotated datasets (NaturalQuestions, WebQuestions, TriviaQA, WoW-KILT, ELI5-KILT) from:
+
+link: https://pan.baidu.com/s/1MwPzVLqZqslqCpWAtPVZ-Q
+
+code: tynj
+
+Download PAQ data from: https://github.com/facebookresearch/PAQ
+
+
+## Run Interactive Script
+
+Before running the following scripts, the key-value memory embeddings, the retrieval index, and the PAQ data must be prepared.
+See [Start](#start) to build your key-value memory and index.
+
+
+NQ: use a torch embedding tensor as the retrieval index:
+```bash
+python demo.py \
+ --model_name_or_path="./EMAT_ckpt/FKSV-NQ" \
+ --qas_to_retrieve_from="./data/PAQ_L1" \
+ --test_task="nq" \
+ --task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
+ --task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
+ --embedding_index_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/embedding_index.pt" \
+ --key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key_memory.pt" \
+ --value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/value_memory.pt"
+```
+
+NQ: use FAISS as the retrieval index:
+```bash
+python demo.py \
+ --model_name_or_path="./EMAT_ckpt/FKSV-NQ" \
+ --qas_to_retrieve_from="./data/PAQ_L1" \
+ --test_task="nq" \
+ --task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
+ --task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
+ --use_faiss \
+ --faiss_index_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key.sq8hnsw.80n80efc.faiss" \
+ --key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key_memory.pt" \
+ --value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/value_memory.pt"
+```
+
+Use the SKSV model with parallel FAISS search:
+```bash
+python demo.py \
+ --model_name_or_path="./EMAT_ckpt/SKSV-NQ" \
+ --qas_to_retrieve_from="./data/PAQ_L1" \
+ --test_task="nq" \
+ --task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
+ --task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
+ --use_faiss \
+ --inference_type="parallel" \
+ --faiss_index_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/key.sq8hnsw.80n80efc.faiss" \
+ --key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/key_memory.pt" \
+ --value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/value_memory.pt"
+```
+
+Run Wizard-of-Wikipedia Dialogue:
+```bash
+python demo.py \
+ --model_name_or_path="./EMAT_ckpt/FKSV-WQ/" \
+ --qas_to_retrieve_from="./tmp/PAQ_L1.pkl" \
+ --test_task="wow_kilt" \
+ --embedding_index_path="./embedding_and_faiss/debug_from_wow_ckpt/embedding_index.pt" \
+ --key_memory_path="./embedding_and_faiss/PAQ_L1_from_wow_ckpt/key_memory.pt" \
+ --value_memory_path="./embedding_and_faiss/PAQ_L1_from_wow_ckpt/value_memory.pt" \
+ --inference_data_path="./annotated_datasets/wizard_of_wikipedia/wow-test_without_answers-kilt.jsonl.txt"
+```
+
+## Start
+
+
+### 1. Pre-training
+
+Pre-train EMAT-FKSV: `bash pretrain_scripts/pretrain_emat.sh`
+
+Pre-train EMAT-SKSV: `bash pretrain_scripts/pretrain_sksv_emat.sh`
+
+### 2. Fine-tune:
+
+Fine-tune on NQ: `bash scripts/nq_train_with_paql1.sh`
+
+Fine-tune on TQ: `bash scripts/tq_train_with_paql1.sh`
+
+Fine-tune on WQ: `bash scripts/wq_train_with_paql1.sh`
+
+Fine-tune on WoW: `bash kilt_scripts/wow_train.sh`
+
+Fine-tune on ELI5: `bash kilt_scripts/eli5_train.sh`
+
+
+### 3. Evaluation:
+
+Evaluate NQ/TQ/WQ: `bash scripts/nq_eval.sh`; switch ``DATA_NAME`` in the script to evaluate a different dataset.
+
+Evaluate WoW/ELI5: `bash kilt_scripts/eval_kilt.sh`. You can upload the output prediction file to http://kiltbenchmark.com/ to get evaluation results.
+
+### 4. Embed PAQ using fine-tuned NQ model and build FAISS index:
+```bash
+bash embed_scripts/nq_embed_paq_and_build_faiss.sh
+```
+
+### 5. Inference Speed
+Test inference speed with `inference_with_faiss.py`.
+
diff --git a/build_kvm.py b/build_kvm.py
new file mode 100644
index 0000000..df382a9
--- /dev/null
+++ b/build_kvm.py
@@ -0,0 +1,297 @@
+import argparse
+import copy
+import json
+import logging
+import math
+import os
+import random
+
+import datasets
+import torch
+import transformers
+from accelerate import Accelerator
+from datasets import load_dataset, load_metric, DatasetDict
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import (
+ CONFIG_MAPPING,
+ MODEL_MAPPING,
+ set_seed,
+ T5Tokenizer,
+)
+from transformers.utils.versions import require_version
+
+from emat.t5 import T5WithKeyValueMemory
+from transformers import T5Config
+from emat.utils import load_jsonl, write_jsonl, verbalise_qa
+from utils.utils import reduce_query_or_key_embeds, save_model, CATArgs, update_CAT_config_from_args, load_model, \
+ get_key_value_encoder_inputs
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+
+logger = logging.getLogger(__name__)
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
+
+DATA_PATHS = {
+ "PAQ-L1": "./data/cbqa_data/pretrain_data/PAQ_L1/PAQ_L1.filtered.jsonl",
+ "data_for_debug": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-dev-3000.jsonl"
+}
+
+
+def load_paq_data(args) -> DatasetDict:
+ assert args.embed_data_name in DATA_PATHS.keys(), f"available dataset names: {DATA_PATHS.keys()}"
+ data_path = DATA_PATHS[args.embed_data_name]
+ return load_dataset("json", data_files=data_path)
+
+
+@torch.no_grad()
+def build_memory(model, tokenizer, output_dir=None, embed_key=False, embed_value=False, prefix="",
+ embed_as_fp16=False, key_reduce_method=None, data_path=None, data_to_embed=None,
+ max_source_length=None, padding=None, batch_size=1, allow_overlay_old_memory=False,
+ dump_memory=False, return_memory=False, separate_task=False, kvm_seg_n=-1,
+ disable_tqdm=False, reused_key_memory=None, collate_fn=None, normed_key_memory=True,
+ return_not_reduced_key=False, reused_not_reduced_key_memory=None, reused_value_memory=None,
+ num_workers=4, use_retrieval_adapter=False):
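+    """Embed the QA pairs in `data_to_embed` (or loaded from `data_path`) with the model's key/value encoder.
+
+    Key embeddings are optionally reduced (`key_reduce_method`), projected by the retrieval adapter,
+    and cast to fp16. Results are written in place into the preallocated `reused_*` buffers when given,
+    dumped to `output_dir` in chunks when `dump_memory` is set, and/or returned when `return_memory`
+    is set (with the key memory split into `kvm_seg_n` chunks if `kvm_seg_n > 1`).
+    """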
+ torch.cuda.empty_cache()
+ if data_to_embed is None:
+ data_to_embed = load_dataset("json", data_files=data_path)["train"]
+
+ if collate_fn is None:
+ def collate_fn(examples):
+ model_inputs = get_key_value_encoder_inputs(examples, separate_task, tokenizer, max_source_length,
+ prefix=prefix, only_return_key_inputs=not embed_value)
+ return model_inputs
+
+ qas_to_embed_dataloader = DataLoader(data_to_embed, batch_size=batch_size, num_workers=num_workers,
+ collate_fn=collate_fn)
+
+ key_memory: list = []
+ value_memory: list = []
+ not_reduced_key_memory = [] if return_not_reduced_key else None
+ model.eval()
+
+ key_cnt = 0
+ for batch in tqdm(qas_to_embed_dataloader, disable=disable_tqdm):
+ # for start_idx in tqdm(range(0, len(data_to_embed), batch_size), total=len(data_to_embed) // batch_size):
+ # batch_qas = data_to_embed[start_idx: start_idx + batch_size]
+ # batch = get_key_value_encoder_inputs(batch_qas, separate_task, tokenizer, max_source_length,
+ # prefix=prefix, only_return_key_inputs=True)
+ with torch.no_grad():
+ batch_keys = list(batch.keys())
+
+ # for k in batch_keys:
+ # v = batch.pop(k)
+ # batch[k] = v.to(model.device)
+ # del v
+ batch = {k: v.to(model.device) for k, v in batch.items()}
+
+ embed_dict = model.wrapped_embed_kv(
+ separate_task=separate_task,
+ compute_key=embed_key,
+ compute_value=embed_value,
+ # key_input_ids=batch["key_input_ids"].to(model.device),
+ # key_attention_mask=batch["key_attention_mask"].to(model.device),
+ # value_input_ids=batch.get("key_input_ids", None).to(model.device),
+ # value_attention_mask=batch.get("key_attention_mask", None).to(model.device),
+ **batch
+ )
+
+ for k in batch_keys:
+ del batch[k]
+
+ key_embeds = embed_dict.get("normed_key_embeds") if normed_key_memory else embed_dict.get("key_embeds")
+ value_embeds = embed_dict.get("value_embeds")
+ if embed_key:
+ key_embeds = reduce_query_or_key_embeds(key_embeds, key_reduce_method)
+ if use_retrieval_adapter:
+ key_embeds = model.adapter(key_embeds)
+ cur_key_num = key_embeds.shape[0]
+
+ if embed_key:
+ if embed_as_fp16:
+ key_embeds = key_embeds.half()
+ if reused_key_memory is not None:
+ key_embeds = key_embeds.cpu()
+ reused_key_memory[key_cnt: key_cnt + cur_key_num] = copy.deepcopy(key_embeds)
+ del key_embeds
+ else:
+ key_memory.append(key_embeds.cpu()) # [batch_size, hidden_size]
+
+ if return_not_reduced_key:
+ not_normed_key_embeds = embed_dict["key_embeds"]
+ if embed_as_fp16:
+ not_normed_key_embeds = not_normed_key_embeds.half()
+ if reused_not_reduced_key_memory is not None:
+ not_normed_key_embeds = not_normed_key_embeds.cpu()
+ reused_not_reduced_key_memory[key_cnt: key_cnt + cur_key_num] = copy.deepcopy(not_normed_key_embeds)
+ del not_normed_key_embeds
+ else:
+ not_reduced_key_memory.append(not_normed_key_embeds.cpu())
+
+ if embed_value:
+ if embed_as_fp16:
+ value_embeds = value_embeds.half()
+ if reused_value_memory is not None:
+ value_embeds = value_embeds.cpu()
+ reused_value_memory[key_cnt: key_cnt + cur_key_num] = copy.deepcopy(value_embeds)
+ del value_embeds
+ else:
+ value_memory.append(value_embeds.cpu()) # [batch_size, value_nums, hidden_size]
+
+ key_cnt += cur_key_num
+
+ if reused_key_memory is None:
+ if embed_key:
+ assert sum(i.shape[0] for i in key_memory) == len(data_to_embed)
+ if return_not_reduced_key:
+ assert sum(i.shape[0] for i in not_reduced_key_memory) == len(data_to_embed)
+ if embed_value:
+ assert sum(i.shape[0] for i in value_memory) == len(data_to_embed)
+
+ if dump_memory:
+        assert reused_key_memory is None, "Not implemented when reused_key_memory is set."
+ chunk_num = 128
+ chunk_batch_size = math.ceil(len(key_memory) / chunk_num)
+ if embed_key:
+ logger.info("dump key")
+ key_dir = os.path.join(output_dir, "key")
+ os.makedirs(key_dir, exist_ok=allow_overlay_old_memory)
+ save_num = 0
+ for cid, start_idx in tqdm(enumerate(range(0, len(key_memory), chunk_batch_size)), leave=True):
+ chunk_key_memory = torch.cat(key_memory[start_idx: start_idx + chunk_batch_size])
+ torch.save(chunk_key_memory, os.path.join(key_dir, f"{cid}.key.pt"))
+ save_num = save_num + chunk_key_memory.shape[0]
+ assert save_num == len(data_to_embed), \
+ f"saved key num is {save_num}, but example num is {len(data_to_embed)}"
+ if embed_value:
+ logger.info("dump value")
+ value_dir = os.path.join(output_dir, "value")
+ os.makedirs(value_dir, exist_ok=allow_overlay_old_memory)
+ save_num = 0
+ for cid, start_idx in tqdm(enumerate(range(0, len(value_memory), chunk_batch_size)), leave=True):
+ chunk_value_memory = torch.cat(value_memory[start_idx: start_idx + chunk_batch_size])
+ torch.save(chunk_value_memory, os.path.join(value_dir, f"{cid}.value.pt"))
+ save_num = save_num + chunk_value_memory.shape[0]
+ assert save_num == len(data_to_embed), \
+ f"saved value num is {save_num}, but example num is {len(data_to_embed)}"
+
+ if return_memory:
+ if kvm_seg_n > 1:
+ all_chunk_key_memory = []
+ if embed_key:
+ if reused_key_memory is not None:
+ logger.info(f"Split reused_key_memory into {kvm_seg_n} chunks.")
+ chunk_batch_size = math.ceil(len(reused_key_memory) / kvm_seg_n)
+ for start_idx in range(0, len(reused_key_memory), chunk_batch_size):
+ end_idx = min(len(reused_key_memory), start_idx + chunk_batch_size)
+ all_chunk_key_memory.append(reused_key_memory[start_idx:end_idx])
+ else:
+ logger.info(f"Combining the keys into {kvm_seg_n} chunks.")
+ chunk_batch_size = math.ceil(len(key_memory) / kvm_seg_n)
+ for cid, start_idx in tqdm(enumerate(range(0, len(key_memory), chunk_batch_size)), leave=True):
+ chunk_key_memory = torch.cat(key_memory[start_idx: start_idx + chunk_batch_size])
+ all_chunk_key_memory.append(chunk_key_memory)
+ assert len(all_chunk_key_memory) == kvm_seg_n
+
+ # if return_not_reduced_key:
+ # not_reduced_key_memory = torch.emat(not_reduced_key_memory)
+
+ if embed_value:
+ value_memory = torch.cat(value_memory)
+
+ return all_chunk_key_memory, value_memory
+
+ else:
+ if embed_key:
+ if reused_key_memory is not None:
+ key_memory = reused_key_memory
+ else:
+ logger.info(f"Combining the result.")
+ key_memory = torch.cat(key_memory)
+
+ if return_not_reduced_key:
+ if reused_not_reduced_key_memory is not None:
+ not_reduced_key_memory = reused_not_reduced_key_memory
+ else:
+ not_reduced_key_memory = torch.cat(not_reduced_key_memory)
+
+ if embed_value:
+ if reused_value_memory is not None:
+ value_memory = reused_value_memory
+ else:
+ value_memory = torch.cat(value_memory)
+ if return_not_reduced_key:
+ return key_memory, value_memory, not_reduced_key_memory
+ else:
+ return key_memory, value_memory
+
+
+def main():
+ # Parse the arguments
+ cat_args = CATArgs(exp_type="build_kvm")
+ args = cat_args.parse_args()
+
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+ accelerator = Accelerator()
+
+ # Make one log on every process with the configuration for debugging.
+ logger.info(accelerator.state)
+
+ # Setup logging, we only want one process per machine to log things on the screen.
+ # accelerator.is_local_main_process is only True for one process per machine.
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+ accelerator.wait_for_everyone()
+
+ config, tokenizer, model = load_model(T5WithKeyValueMemory, args)
+ model.cuda()
+ prefix = args.source_prefix if args.source_prefix is not None else ""
+
+ # Temporarily set max_target_length for training.
+ max_target_length = args.max_target_length
+ padding = "max_length" if args.pad_to_max_length else True
+
+ # Load the datasets
+ data_to_embed = load_paq_data(args)["train"]
+
+ # Log a few random samples from the training set:
+ for index in random.sample(range(len(data_to_embed)), 3):
+ logger.info(f"Sample {index} of the training set: {data_to_embed[index]}.")
+
+ batch_size = args.per_device_train_batch_size
+ logger.info("***** Building Key-Value Memory *****")
+ logger.info(f" Num examples = {len(data_to_embed)}")
+ logger.info(f" Instantaneous batch size per device = {batch_size}")
+ # Only show the progress bar once on each machine.
+ build_memory(model, tokenizer, output_dir=args.output_dir, embed_key=args.embed_key, embed_value=args.embed_value,
+ prefix=prefix, embed_as_fp16=args.embed_as_fp16, key_reduce_method=args.key_reduce_method,
+ data_path=None, data_to_embed=data_to_embed, max_source_length=args.max_source_length, padding=padding,
+ batch_size=batch_size, allow_overlay_old_memory=False, dump_memory=True, return_memory=False,
+ separate_task=args.separate_task)
+
+ pretrain_args = json.load(open(os.path.join(args.model_name_or_path, "args.json")))
+ dict_args = vars(args)
+ dict_args["loaded_model_args"] = pretrain_args
+    json.dump(dict_args, open(os.path.join(args.output_dir, "kvm_args.json"), 'w'), indent=4)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000..f8c61a9
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,189 @@
+import faiss
+import asyncio
+import argparse
+import torch
+from transformers import T5Tokenizer, T5Config
+from emat.t5 import T5WithKeyValueMemory
+from emat.utils import load_jsonl
+import logging
+from kilt_dataset import DialogDataset
+import pickle
+import random
+from kilt_trainer import kilt_generate
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+
+
+def get_args():
+ parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Inference with faiss")
+ parser.add_argument("--model_name_or_path", type=str, required=False,
+ default="./outputs/nq_checkpoints/KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;pos_from_top=128;/best_ckpt/")
+ parser.add_argument("--qas_to_retrieve_from", default="./tmp/PAQ_L1_pickl_file.pkl")
+ parser.add_argument("--test_task", default="nq", type=str, choices=["nq", "wq", "tq", "wow_kilt"])
+ parser.add_argument("--task_train_data", default=None, required=False, type=str)
+ parser.add_argument("--task_dev_data", default=None, required=False, type=str)
+ parser.add_argument("--use_faiss", action="store_true", help="default -- use torch embedding")
+ parser.add_argument("--faiss_index_path", default=None, type=str, required=False)
+ parser.add_argument("--embedding_index_path", default=None, type=str, required=False)
+ parser.add_argument("--key_memory_path", required=True)
+ parser.add_argument("--value_memory_path", required=True)
+ parser.add_argument("--inference_type", type=str, default="serial", choices=["parallel", "serial"])
+ parser.add_argument("--inference_data_path", type=str, default=None, required=False)
+ parser.add_argument("--inference_batch_size", type=int, default=512)
+ args = parser.parse_args()
+
+ if args.use_faiss:
+ assert args.faiss_index_path is not None
+ else:
+ assert args.embedding_index_path is not None
+
+ return args
+
+
+def main():
+ args = get_args()
+
+ # load model
+ logging.info(f"loading model from {args.model_name_or_path}")
+ model, load_info = T5WithKeyValueMemory.from_pretrained(args.model_name_or_path, output_loading_info=True)
+ model.eval()
+ model = model.cuda()
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)
+ logging.info(f"model load info: {load_info}")
+ # check
+ if getattr(model, "cat_layer", None) == model.encoder.key_layer:
+ assert args.inference_type != "parallel", "parallel can not used in cat_layer == key_layer"
+
+ # load index and key-value memory
+ faiss_index, embedding_index = None, None
+ if args.use_faiss:
+ logging.info(f"loading index from {args.faiss_index_path}")
+ faiss_index = faiss.read_index(args.faiss_index_path)
+ logging.info("loaded faiss index.")
+ else:
+ logging.info(f"loading index from {args.embedding_index_path}")
+ embedding_index = torch.load(args.embedding_index_path)
+ logging.info("loaded embedding index.")
+ value_memory = torch.load(args.value_memory_path)
+ key_memory = torch.load(args.key_memory_path)
+
+ # load QAs to retrieve
+ logging.info(f"loading PAQ from {args.qas_to_retrieve_from}")
+ if args.qas_to_retrieve_from.endswith("pkl"):
+ qas_to_retrieve = pickle.load(open(args.qas_to_retrieve_from, 'rb'))
+ else: # jsonl
+ qas_to_retrieve = load_jsonl(args.qas_to_retrieve_from)
+ logging.info("loaded PAQ")
+ if args.test_task in ["nq", "wq", "tq"]:
+ if args.task_train_data is not None:
+ qas_to_retrieve = qas_to_retrieve + load_jsonl(args.task_train_data)
+ if args.task_dev_data is not None:
+ qas_to_retrieve = qas_to_retrieve + load_jsonl(args.task_dev_data)
+ assert len(qas_to_retrieve) == value_memory.shape[0] == key_memory.shape[0]
+ logging.info(f"numer of QAs to retrieve: {len(qas_to_retrieve)}")
+
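+    # Generation settings: short answers for the QA tasks, longer sampled responses for dialogue.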
+ if args.test_task in ["nq", 'wq', 'tq']:
+ gen_kwargs = {"num_beams": None, "max_length": 64}
+ else:
+ gen_kwargs = {"max_length": 1024, "num_beams": 8, "do_sample": True, "top_k": 64, "min_length": 8}
+
+ print("input ``ctrl + c`` to exit the program.")
+ if args.test_task in ["nq", 'wq', 'tq']:
+ while True:
+ question = input("Question: ")
+ batch = [{"question": question.strip()}]
+ ans, retrieved_qa = inference_qa(model, tokenizer, key_memory, value_memory, embedding_index,
+ faiss_index, qas_to_retrieve, args.inference_type, batch, gen_kwargs)
+ print(f"Answer: {ans[0]}")
+ print(f"retrieved QAs: ")
+ for qa in retrieved_qa[0]:
+ print(qa)
+
+ elif args.test_task == 'wow_kilt':
+ print("input '-1' to exit current dialogue")
+ dataset_kwargs = {"dataset_name": "wow_kilt", "max_source_length": 128}
+ inference_data = load_jsonl(args.inference_data_path)
+ while True:
+ cur_dialogue = random.sample(inference_data, 1)[0]
+ utterances = cur_dialogue["input"].split("\n")[:-1]
+ for idx, u in enumerate(utterances):
+ spk = "A" if idx % 2 == 0 else "B"
+ print(f"{spk}: {u}")
+ while True:
+ spk = "A" if len(utterances) % 2 == 0 else "B"
+ utterance = input(f"{spk}: ")
+ if utterance == "-1":
+ break
+ utterances.append(utterance)
+ cur_dialogue["input"] = "\n".join(utterances)
+ dataset = DialogDataset([cur_dialogue], tokenizer, qas_to_retrieve, **dataset_kwargs)
+ retrieved_qa, response = kilt_generate(
+ model, tokenizer, embedding_index, key_memory, value_memory, dataset,
+ qas_to_retrieve, args.inference_batch_size, gen_kwargs
+ )
+ spk = "A" if len(utterances) % 2 == 0 else "B"
+ print(f"{spk}: {response[0]}")
+ utterances.append(response[0])
+ print("")
+
+
+@torch.no_grad()
+def inference_qa(model, tokenizer, key_memory, value_memory, embedding_index, faiss_index,
+ qas_to_retrieve, inference_type, batch, gen_kwargs):
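+    """Answer a batch of questions using the key-value memory.
+
+    The questions are encoded and the top values are retrieved either from a torch
+    embedding index or from a FAISS index (serial or async parallel search); answers
+    are then generated from the value-augmented encoder outputs. Returns the decoded
+    answers and the retrieved QA pairs.
+    """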
+ inputs = ["question: " + qa["question"] for qa in batch]
+ inputs = tokenizer(inputs, max_length=1024, padding=True, truncation=True, return_tensors="pt")
+ input_ids = inputs["input_ids"].to('cuda')
+ attention_mask = inputs["attention_mask"].to('cuda')
+ if embedding_index is not None:
+ encoder_outputs = model.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ return_dict=True,
+ readout_top_k=model.encoder.num_values,
+ key_reduce_method="avg",
+ value_fusion_method=model.encoder.value_fusion_method,
+ embedding_index=embedding_index,
+ key_memory=key_memory,
+ value_memory=value_memory
+ )
+ else:
+ if inference_type == "serial":
+ encoder_outputs = model.encoder.forward_with_faiss(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ return_dict=True,
+ readout_top_k=model.encoder.num_values,
+ key_reduce_method="avg",
+ value_fusion_method=model.encoder.value_fusion_method,
+ key_faiss_index=faiss_index,
+ value_memory=value_memory,
+ not_reduced_key_memory=key_memory
+ )
+ else:
+ encoder_outputs = asyncio.run(
+ model.encoder.forward_with_async_faiss(
+ input_ids, attention_mask, True, model.encoder.num_values, "avg",
+ model.encoder.value_fusion_method, faiss_index, value_memory, key_memory
+ )
+ )
+ generated_tokens = model.generate(
+ encoder_outputs=encoder_outputs,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=False,
+ attention_mask=attention_mask,
+ value_fusion_method=model.encoder.value_fusion_method,
+ **gen_kwargs,
+ )
+
+ readout_qas = [[qas_to_retrieve[idx] for idx in indices] for indices in encoder_outputs.readout_indices]
+ decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+ decoded_tokens = [ans.strip() for ans in decoded_tokens]
+ return decoded_tokens, readout_qas
+
+
+if __name__ == '__main__':
+ main()
diff --git a/emat/__init__.py b/emat/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/emat/evaluation/__init__.py b/emat/evaluation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/emat/evaluation/eval_qa_overlap.py b/emat/evaluation/eval_qa_overlap.py
new file mode 100644
index 0000000..557a8e2
--- /dev/null
+++ b/emat/evaluation/eval_qa_overlap.py
@@ -0,0 +1,158 @@
+"""Evaluation script to get prediction scores for overlapped QA pairs in ODQA datasets"""
+import argparse
+import os
+import string
+
+from emat.evaluation.exact_match import exact_match_score, metric_max_over_ground_truths, f1_score
+from emat.utils import load_jsonl
+
+ANNOTATIONS = [
+ 'total',
+ 'question_overlap',
+ 'no_question_overlap',
+ 'answer_overlap',
+ 'no_answer_overlap',
+ 'answer_overlap_only',
+ 'no_overlap',
+]
+
+DIRNAME = os.path.dirname(os.path.abspath(__file__))
+REFERENCE_PATHS = {
+ 'triviaqa': 'triviaqa-test.qa.csv',
+ 'naturalquestions': 'nq-test.qa.csv',
+ 'webquestions': 'webquestions-test.qa.csv',
+}
+ANNOTATION_PATHS = {
+ 'triviaqa': 'triviaqa-annotations.jsonl',
+ 'naturalquestions': 'nq-annotations.jsonl',
+ 'webquestions': 'webquestions-annotations.jsonl',
+}
+
+
+def preprocess(text: str) -> str:
+ exclude = set(string.punctuation)
+ exclude.add(" ")
+ exclude.add("’")
+ return "".join(ch for ch in text if ch not in exclude)
+
+
+def read_references(fi, sep='\t'):
+ def parse_pandas_answer(a_string):
+ # to avoid a pandas dependency, deserialize these manually
+ try:
+ parsed_answers = eval(a_string) if a_string.startswith('[') else eval(a_string.replace('""', '"')[1:-1])
+ except:
+ parsed_answers = eval(a_string.replace('""', '"').replace('""', '"').replace('""', '"')[1:-1])
+ return parsed_answers
+
+ questions, references = [], []
+ for i, line in enumerate(open(fi)):
+ q, answer_str = line.strip('\n').split(sep)
+ questions.append(q)
+ refs = parse_pandas_answer(answer_str)
+ references.append({'references': refs, 'id': i})
+ return questions, references
+
+
+def read_lines(path):
+ with open(path) as f:
+ return [l.strip() for l in f]
+
+
+def read_predictions(path):
+ if path.endswith('json') or path.endswith('.jsonl'):
+ return load_jsonl(path)
+ else:
+ return [{'id': i, 'prediction': pred} for i, pred in enumerate(read_lines(path))]
+
+
+def _get_scores(answers, refs, fn):
+ return [metric_max_over_ground_truths(fn, pred, rs) for pred, rs in zip(answers, refs)]
+
+
+def get_scores(predictions, references, annotations, annotation_labels=None):
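+    """Compute exact-match and F1 scores for each overlap annotation label."""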
+ predictions_map = {p['id']: p for p in predictions}
+ references_map = {r['id']: r for r in references}
+ annotations_map = {a['id']: a for a in annotations}
+    assert predictions_map.keys() == references_map.keys(), "predictions file doesn't match the gold references file"
+    assert predictions_map.keys() == annotations_map.keys(), "prediction file doesn't match the annotation file"
+    assert annotations_map.keys() == references_map.keys(), "annotations file doesn't match the gold references file"
+
+ annotation_labels = ANNOTATIONS if annotation_labels is None else annotation_labels
+
+ results = {}
+ for annotation_label in annotation_labels:
+ if annotation_label == 'no_overlap':
+ annotation_ids = [
+ annotation['id'] for annotation in annotations if
+ all(label in annotation['labels'] for label in ['no_question_overlap', 'no_answer_overlap'])
+ ]
+ else:
+ annotation_ids = [
+ annotation['id'] for annotation in annotations if annotation_label in annotation['labels']
+ ]
+
+ preds = [predictions_map[idd]['prediction'] for idd in annotation_ids]
+ refs = [references_map[idd]['references'] for idd in annotation_ids]
+ em = _get_scores(preds, refs, exact_match_score)
+ f1 = _get_scores(preds, refs, f1_score)
+ results[annotation_label] = {
+ 'exact_match': 100 * sum(em) / len(em),
+ 'f1_score': 100 * sum(f1) / len(f1),
+ 'n_examples': len(annotation_ids),
+ }
+
+ return results
+
+
+def _print_score(label, results_dict):
+ print('-' * 50)
+ print('Label :', label)
+ print('N examples : ', results_dict['n_examples'])
+ print('Exact Match : ', results_dict['exact_match'])
+ # print('F1 score : ', results_dict['f1_score'])
+
+
+def main(predictions_path, dataset_name, data_dir):
+ references_path = os.path.join(data_dir, REFERENCE_PATHS[dataset_name])
+ annotations_path = os.path.join(data_dir, ANNOTATION_PATHS[dataset_name])
+ if not os.path.exists(references_path):
+ raise Exception(' References expected at ' + references_path
+ + ' not found, please download them using the download script (see readme)')
+    if not os.path.exists(annotations_path):
+        raise Exception(' Annotations expected at ' + annotations_path
+                        + ' not found, please download them using the download script (see readme)')
+
+ questions, references = read_references(references_path)
+ annotations = load_jsonl(annotations_path)
+
+ predictions = read_predictions(predictions_path)
+ assert len(predictions) == len(references) == len(annotations)
+
+ # Align the predictions with the references using the questions
+ questions = [preprocess(q) for q in questions]
+ question_to_id = {q.strip(): qid for qid, q in enumerate(questions)}
+ id_to_prediction = {}
+ for pred in predictions:
+ q = preprocess(pred["question"].strip())
+ qid = question_to_id[q]
+ id_to_prediction[qid] = {"id": qid, "prediction": pred["prediction"]}
+ assert len(id_to_prediction) == len(references)
+ aligned_predictions = [id_to_prediction[qid] for qid in range(len(references))]
+
+ scores = get_scores(aligned_predictions, references, annotations)
+ for label in ANNOTATIONS:
+ _print_score(label, scores[label])
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--predictions",
+ help="path to predictions txt file, one answer per line. "
+ "Answer order should follow the order in data/{dataset}-test.qa.csv", type=str)
+    parser.add_argument('--dataset_name', choices=['naturalquestions', 'triviaqa', 'webquestions'], type=str,
+                        help='name of dataset to evaluate on')
+ parser.add_argument('--data_dir', default="data/qa-overlap", type=str, help='directory of the annotated data')
+
+ args = parser.parse_args()
+ main(args.predictions, args.dataset_name, args.data_dir)
diff --git a/emat/evaluation/eval_retriever.py b/emat/evaluation/eval_retriever.py
new file mode 100644
index 0000000..5711286
--- /dev/null
+++ b/emat/evaluation/eval_retriever.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+
+from emat.evaluation.exact_match import metric_max_over_ground_truths, exact_match_score
+from emat.utils import load_jsonl
+
+
+
+def eval_generation_em(refs, preds):
+ scores = []
+ for ref, pred in zip(refs, preds):
+ ref_answer = ref["answer"]
+ em = metric_max_over_ground_truths(exact_match_score, pred, ref_answer)
+ scores.append(em)
+ avg_score = sum(scores) / len(scores)
+ return avg_score
+
+def eval_retriever(refs, preds, hits_at_k):
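+    """Compute hits@k: the fraction of questions whose top-k retrieved QA pairs contain a correct answer (exact match)."""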
+ if isinstance(hits_at_k, str):
+ hits_at_k = sorted([int(k) for k in hits_at_k.split(',')])
+
+ result_dict = {}
+ for k in hits_at_k:
+ scores = []
+ dont_print = False
+ for r, p in zip(refs, preds):
+ if hits_at_k[-1] > len(p): # p['retrieved_qas']
+ print(f'Skipping hits@{k} eval as {k} is larger than number of retrieved results')
+ dont_print = True
+ ref_answers = r['answer']
+ em = any([
+ metric_max_over_ground_truths(exact_match_score, pred_answer['answer'][0], ref_answers)
+ for pred_answer in p[:k] # p['retrieved_qas'][:k]
+ ])
+ scores.append(em)
+
+ avg_score = sum(scores) / len(scores)
+ # if not dont_print:
+ # print(f'{k}: {100 * avg_score:0.1f}% \n({sum(scores)} / {len(scores)})')
+
+ result_dict[f"hit@{k}"] = avg_score
+
+ return result_dict
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--predictions', type=str,
+ help="path to retrieval results to eval, in PAQ's retrieved results jsonl format")
+ parser.add_argument('--references', type=str, help="path to gold answers, in jsonl format")
+ parser.add_argument('--hits_at_k', type=str, help='comma separated list of K to eval hits@k for', default="1,10,50")
+ args = parser.parse_args()
+
+ refs = load_jsonl(args.references)
+ preds = load_jsonl(args.predictions)
+ assert len(refs) == len(preds), "number of references doesnt match number of predictions"
+
+ hits_at_k = sorted([int(k) for k in args.hits_at_k.split(',')])
+ eval_retriever(refs, preds, hits_at_k)
diff --git a/emat/evaluation/exact_match.py b/emat/evaluation/exact_match.py
new file mode 100644
index 0000000..8cb21c8
--- /dev/null
+++ b/emat/evaluation/exact_match.py
@@ -0,0 +1,140 @@
+# coding=utf-8
+import re
+import string
+import unicodedata
+from collections import Counter
+
+import datasets
+
+_CITATION = """\
+@inproceedings{rajpurkar-etal-2016-squad,
+ title = "{SQ}u{AD}: 100,000+ Questions for Machine Comprehension of Text",
+ author = "Rajpurkar, Pranav and
+ Zhang, Jian and
+ Lopyrev, Konstantin and
+ Liang, Percy",
+ booktitle = "Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing",
+ month = nov,
+ year = "2016",
+ address = "Austin, Texas",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/D16-1264",
+ doi = "10.18653/v1/D16-1264",
+ pages = "2383--2392",
+}
+@inproceedings{lee-etal-2019-latent,
+ title = "Latent Retrieval for Weakly Supervised Open Domain Question Answering",
+ author = "Lee, Kenton and
+ Chang, Ming-Wei and
+ Toutanova, Kristina",
+ booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
+ month = jul,
+ year = "2019",
+ address = "Florence, Italy",
+ publisher = "Association for Computational Linguistics",
+ url = "https://aclanthology.org/P19-1612",
+ doi = "10.18653/v1/P19-1612",
+ pages = "6086--6096",
+}
+"""
+
+_DESCRIPTION = """\
+Exact match score for Open-domain Question Answering.
+This metric measures the percentage of predictions that match any one of the ground truth answers exactly.
+"""
+
+_KWARGS_DESCRIPTION = """
+Calculates the percentage of predictions that match any one of the ground truth answers exactly.
+Args:
+ predictions: list of predictions to score. Each predictions
+ should be a string with tokens separated by spaces.
+ references: list of reference for each prediction. Each
+ reference should be a list of strings with tokens separated by spaces.
+Returns:
+    em: fraction of predictions that exactly match one of their reference answers,
+Examples:
+ >>> em_metric = datasets.load_metric("exact_match")
+ >>> results = em_metric.compute(references=[["apple", "orange"], ["banana"]], predictions=["apple", "pear"])
+ >>> print(results)
+ {'em': 0.5}
+"""
+
+
+def normalize_answer(s):
+ """Normalize answer."""
+ s = unicodedata.normalize("NFD", s)
+
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(prediction, ground_truth):
+ prediction_tokens = normalize_answer(prediction).split()
+ ground_truth_tokens = normalize_answer(ground_truth).split()
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def exact_match_score(prediction, ground_truth):
+ return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def regex_match_score(prediction, ground_truth):
+ try:
+ regex = re.compile(ground_truth, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE)
+ return regex.match(prediction) is not None
+ except re.error:
+ return False
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+ scores_for_ground_truths = []
+ for ground_truth in ground_truths:
+ score = metric_fn(prediction, ground_truth)
+ scores_for_ground_truths.append(score)
+ return max(scores_for_ground_truths)
+
+
+@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+class ExactMatch(datasets.Metric):
+ """Exact match (EM) metric for Open-domain Question Answering."""
+
+ def _info(self):
+ return datasets.MetricInfo(
+ description=_DESCRIPTION,
+ citation=_CITATION,
+ inputs_description=_KWARGS_DESCRIPTION,
+ features=datasets.Features({
+ 'predictions': datasets.Value('string'),
+ 'references': datasets.Sequence(datasets.Value('string')),
+ }),
+ homepage="https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html",
+ codebase_urls=[
+ "https://github.com/google-research/language/blob/58f5dc33a99d168a71586d64ffb7648a0f33b49a/language/orqa/utils/eval_utils.py#L23"],
+ reference_urls=["https://arxiv.org/pdf/1606.05250.pdf"]
+ )
+
+ def _compute(self, predictions, references, is_regex=False):
+ match_fn = regex_match_score if is_regex else exact_match_score
+ em_score = sum(metric_max_over_ground_truths(match_fn, i, j) for i, j in zip(predictions, references)) / len(
+ predictions)
+
+ return {"em": em_score}
diff --git a/emat/fusion_net.py b/emat/fusion_net.py
new file mode 100644
index 0000000..75de522
--- /dev/null
+++ b/emat/fusion_net.py
@@ -0,0 +1,30 @@
+import torch
+from torch import nn
+
+
+class FusionWeight(nn.Module):
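+    """Compute a sigmoid gate used to weight each retrieved value before fusion.
+
+    For "g(kq)" fusion types the gate is derived from the key-query dot product;
+    for "g(kv)" types it is derived from the concatenated key and value hidden states.
+    """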
+ # ["cat_k+v_g(kq)", "cat_k+v_g(kv)"]:
+ def __init__(self, fusion_type="cat_k+v_g(kq)", model_dim=512, prefix_length=2, key_dim=1024):
+ super(FusionWeight, self).__init__()
+ self.fusion_type = fusion_type
+ if "g(kq)" in self.fusion_type:
+ self.score_proj = nn.Linear(1, 1)
+ else:
+ input_dim = model_dim * prefix_length + key_dim
+ self.score_proj = nn.Linear(input_dim, 1)
+
+ def forward(self, key=None, query=None, value=None):
+ if "g(kv)" in self.fusion_type:
+ # key: batch_size, num_values, key_token_num, hidden_size
+ # value: batch_size, num_values, prefix_length, hidden_size
+ batch_size, num_values, _, _ = key.shape
+ input_hidden = torch.cat((key, value), dim=2) # batch_size, num_values, key_token_num+prefix_length, hidden
+ scores = self.score_proj(input_hidden.view(batch_size, num_values, -1))
+ else:
+ # key: batch_size, num_values, hidden_size
+ # query: batch_size, hidden_size
+ scores = torch.bmm(key, query.unsqueeze(dim=1).transpose(2, 1)) # batch_size, num_values, 1
+ scores = self.score_proj(scores)
+
+ scores = torch.sigmoid(scores)
+ return scores
diff --git a/emat/retrieval_adapter.py b/emat/retrieval_adapter.py
new file mode 100644
index 0000000..49dfdd4
--- /dev/null
+++ b/emat/retrieval_adapter.py
@@ -0,0 +1,20 @@
+from torch import nn
+
+
+class RetAdapter(nn.Module):
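+    """Project key/query embeddings to the retrieval index dimension (optionally preceded by dropout)."""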
+
+ def __init__(self, in_dim, out_dim, adapter_type="linear"):
+ super(RetAdapter, self).__init__()
+ assert adapter_type in ["linear", "dropout_linear"]
+ if adapter_type == "linear":
+ self.out = nn.Linear(in_dim, out_dim, bias=True)
+ elif adapter_type == "dropout_linear":
+ self.out = nn.Sequential(
+ nn.Dropout(0.9),
+ nn.Linear(in_dim, out_dim, bias=True)
+ )
+ else:
+ raise NotImplementedError
+
+ def forward(self, vectors):
+ return self.out(vectors)
diff --git a/emat/retriever/build_index.py b/emat/retriever/build_index.py
new file mode 100644
index 0000000..2c1582a
--- /dev/null
+++ b/emat/retriever/build_index.py
@@ -0,0 +1,405 @@
+import argparse
+import logging
+import os
+import pickle
+import random
+import glob
+import time
+
+import faiss
+import torch
+from tqdm import tqdm
+
+from emat.retriever.utils import parse_vectors_from_file
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+def get_vector_sample(vector, sample_fraction):
+ max_phi = -1
+ N = 0
+
+ phis = (vector ** 2).sum(1)
+ max_phi = max(max_phi, phis.max())
+ N += vector.shape[0]
+ if sample_fraction == 1.0:
+ vector_sample = vector
+ else:
+ vector_sample = vector[random.sample(range(0, len(vector)), int(len(vector) * sample_fraction))]
+
+ return vector_sample, max_phi, N
+
+
+def augment_vectors(vectors, max_phi):
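+    # Append an auxiliary dimension sqrt(max_phi - ||x||^2) to each vector so that
+    # maximum inner-product search can be carried out with an L2 (e.g. HNSW) index.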
+ phis = (vectors ** 2).sum(1)
+ aux_dim = torch.sqrt(max_phi.float() - phis.float())
+ vectors = torch.cat([vectors, aux_dim.unsqueeze(-1)], -1)
+ return vectors
+
+
+def build_index_streaming(cached_embeddings_path,
+ output_path,
+ vector=None,
+ hnsw=False,
+ sq8_quantization=False,
+ fp16_quantization=False,
+ store_n=256,
+ ef_search=32,
+ ef_construction=80,
+ sample_fraction=0.1,
+ indexing_batch_size=5000000,
+ ):
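+    """Build a FAISS index over the vectors stored in `cached_embeddings_path` (or passed via `vector`).
+
+    With `hnsw`, vectors are augmented with an extra dimension so inner-product search can run on an
+    L2 HNSW index; SQ8/fp16 quantization first trains the quantizer on a `sample_fraction` of the data.
+    """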
+ if vector is None:
+ vector = parse_vectors_from_file(cached_embeddings_path)
+ vector_size = vector.shape[1]
+
+ if hnsw:
+ if sq8_quantization:
+ index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_8bit, store_n)
+ elif fp16_quantization:
+ index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_fp16, store_n)
+ else:
+ index = faiss.IndexHNSWFlat(vector_size + 1, store_n)
+
+ index.hnsw.efSearch = ef_search
+ index.hnsw.efConstruction = ef_construction
+ else:
+ if sq8_quantization:
+ index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2)
+ elif fp16_quantization:
+ index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_L2)
+ else:
+            index = faiss.IndexFlatIP(vector_size)
+
+ vector_sample, max_phi, N = get_vector_sample(vector, sample_fraction)
+ if hnsw:
+ vector_sample = augment_vectors(vector_sample, max_phi)
+
+ if sq8_quantization or fp16_quantization: # index requires training
+ vs = vector_sample.numpy()
+ logging.info(f'Training Quantizer with matrix of shape {vs.shape}')
+ index.train(vs)
+ del vs
+ del vector_sample
+
+ # logging.warning("tmp code")
+ # import gc
+ # del vector
+ # gc.collect()
+ # for idx in range(16):
+ # path = f"./data/embedding_and_faiss/PAQ_from_nq_ckpt/key_memory_dir/embeddings.{idx}.pt"
+ # vector = torch.load(path)
+ # if hnsw:
+ # vector_chunk = augment_vectors(vector, max_phi)
+
+ # original code
+ if hnsw:
+ vector = augment_vectors(vector, max_phi)
+ logging.info(f'Adding Vectors of shape {vector.shape}')
+ index.add(vector.numpy())
+
+ if output_path is not None:
+ logger.info(f'Index Built, writing index to {output_path}')
+ faiss.write_index(index, output_path)
+ logger.info(f'Index dumped')
+ else:
+ logger.info("Built faiss-index.")
+ return index
+
+
+def parse_vectors_from_directory_chunks(embeddings_dir, half):
+    assert os.path.isdir(embeddings_dir), \
+        f"Vectors directory {embeddings_dir} doesn't exist, or is not a directory of pytorch vectors"
+ paths = glob.glob(f"{embeddings_dir}/embeddings.*.pt")
+ assert len(paths) > 0, "Files not found."
+ paths_with_order = sorted([(int(os.path.basename(p).split('.')[1]), p) for p in paths], key=lambda x: x[0])
+ paths = [po[1] for po in paths_with_order]
+ for p in paths:
+ print(p)
+ for j, p in enumerate(paths):
+ m = torch.load(p)
+ # assert int(os.path.basename(p).split('.')[-3]) == j, (p, j)
+ if half:
+ m = m if m.dtype == torch.float16 else m.half()
+ else:
+ m = m if m.dtype == torch.float32 else m.float()
+ yield m
+
+
+def get_vector_sample_from_dir(cached_embeddings_path, sample_fraction, half=False):
+ samples = []
+ max_phi = -1
+ N = 0
+ vectors = parse_vectors_from_directory_chunks(cached_embeddings_path, half)
+ for chunk in vectors:
+ phis = (chunk ** 2).sum(1)
+ max_phi = max(max_phi, phis.max())
+ N += chunk.shape[0]
+ if sample_fraction == 1.0:
+ chunk_sample = chunk
+ else:
+ chunk_sample = chunk[random.sample(range(0, len(chunk)), int(len(chunk) * sample_fraction))]
+ samples.append(chunk_sample)
+
+ del vectors
+ vector_sample = torch.cat(samples)
+ return vector_sample, max_phi, N
+
+
+def get_vector_from_key_chunks(key_chunks, half=False):
+ samples = []
+ max_phi = -1
+ N = 0
+ # vectors = parse_vectors_from_directory_chunks(key_chunks, half)
+ for chunk in key_chunks:
+
+ if half:
+ chunk = chunk if chunk.dtype == torch.float16 else chunk.half()
+ else:
+ chunk = chunk if chunk.dtype == torch.float32 else chunk.float()
+
+ phis = (chunk ** 2).sum(1)
+ max_phi = max(max_phi, phis.max())
+ N += chunk.shape[0]
+ chunk_sample = chunk
+ samples.append(chunk_sample)
+
+ vector_sample = torch.cat(samples)
+ return vector_sample, max_phi, N
+
+
+def build_index_streaming_from_dir(cached_embeddings_path,
+ output_path,
+ hnsw=False,
+ sq8_quantization=False,
+ fp16_quantization=False,
+ store_n=256,
+ ef_search=32,
+ ef_construction=80,
+ sample_fraction=0.1,
+ indexing_batch_size=5000000,
+ ):
+ logger.info("build index, read from directory.")
+ first_chunk = torch.load(os.path.join(cached_embeddings_path, "embeddings.0.pt")) # [batch_size, hidden_size]
+ vector_size = first_chunk.shape[1]
+ # load first chunk
+ del first_chunk
+
+ if not os.path.exists("./data/embedding_and_faiss/PAQ_from_nq_ckpt/trained_index.pkl"):
+
+ if hnsw:
+ if sq8_quantization:
+ index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_8bit, store_n)
+ elif fp16_quantization:
+ index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_fp16, store_n)
+ else:
+ index = faiss.IndexHNSWFlat(vector_size + 1, store_n)
+
+ index.hnsw.efSearch = ef_search
+ index.hnsw.efConstruction = ef_construction
+ else:
+ if sq8_quantization:
+ index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2)
+ elif fp16_quantization:
+ index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_L2)
+ else:
+ index = faiss.IndexIP(vector_size + 1, store_n)
+
+ vector_sample, max_phi, N = get_vector_sample_from_dir(cached_embeddings_path, sample_fraction)
+
+ print(max_phi, N)
+ # exit()
+ if hnsw:
+ vector_sample = augment_vectors(vector_sample, max_phi)
+
+ if sq8_quantization or fp16_quantization: # index requires training
+ vs = vector_sample.numpy()
+ logging.info(f'Training Quantizer with matrix of shape {vs.shape}')
+ index.train(vs)
+ del vs
+ pickle.dump({"index": index,
+ "max_phi": max_phi},
+ open("./data/embedding_and_faiss/PAQ_from_nq_ckpt/trained_index.pkl", 'wb'))
+ exit()
+ del vector_sample
+
+ else:
+ load_index_phi = pickle.load(open("./data/embedding_and_faiss/PAQ_from_nq_ckpt/trained_index.pkl", 'rb'))
+ max_phi = load_index_phi["max_phi"]
+ index = load_index_phi["index"]
+ N = 64963526
+
+ chunks_to_add = []
+ added = 0
+ for vector_chunk in parse_vectors_from_directory_chunks(cached_embeddings_path, half=False):
+ if hnsw:
+ vector_chunk = augment_vectors(vector_chunk, max_phi)
+
+ chunks_to_add.append(vector_chunk)
+
+ if sum(c.shape[0] for c in chunks_to_add) > indexing_batch_size:
+ to_add = torch.cat(chunks_to_add)
+ logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
+ added += to_add.shape[0]
+ chunks_to_add = []
+ index.add(to_add.numpy())
+
+ if len(chunks_to_add) > 0:
+ to_add = torch.cat(chunks_to_add).numpy()
+ index.add(to_add)
+ logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
+
+ logger.info(f'Index Built, writing index to {output_path}')
+ faiss.write_index(index, output_path)
+ logger.info(f'Index dumped')
+ return index
+
+
+def build_index_streaming_from_key_chunks(key_chunks,
+ output_path,
+ hnsw=False,
+ sq8_quantization=False,
+ fp16_quantization=False,
+ store_n=256,
+ ef_search=32,
+ ef_construction=80,
+ sample_fraction=0.1,
+ indexing_batch_size=5000000,
+ ):
+ logger.info("build index, read from directory.")
+ # first_chunk = torch.load(os.path.join(cached_embeddings_path, "0.key.pt")) # [batch_size, hidden_size]
+ vector_size = key_chunks[0].shape[1]
+
+ if hnsw:
+ if sq8_quantization:
+ index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_8bit, store_n)
+ elif fp16_quantization:
+ index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_fp16, store_n)
+ else:
+ index = faiss.IndexHNSWFlat(vector_size + 1, store_n)
+
+ index.hnsw.efSearch = ef_search
+ index.hnsw.efConstruction = ef_construction
+ else:
+ if sq8_quantization:
+ index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2)
+ elif fp16_quantization:
+ index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_L2)
+ else:
+            index = faiss.IndexFlatIP(vector_size)
+
+ vector_sample, max_phi, N = get_vector_from_key_chunks(key_chunks)
+ if hnsw:
+ vector_sample = augment_vectors(vector_sample, max_phi)
+
+ if sq8_quantization or fp16_quantization: # index requires training
+ vs = vector_sample.numpy()
+ logging.info(f'Training Quantizer with matrix of shape {vs.shape}')
+ index.train(vs)
+ del vs
+ del vector_sample
+
+ chunks_to_add = []
+ added = 0
+ for vector_chunk in key_chunks:
+ if hnsw:
+ vector_chunk = augment_vectors(vector_chunk, max_phi)
+
+ chunks_to_add.append(vector_chunk)
+
+        if sum(c.shape[0] for c in chunks_to_add) > indexing_batch_size:
+            to_add = torch.cat(chunks_to_add)
+            logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
+            chunks_to_add = []
+            index.add(to_add.float().numpy())
+            added += to_add.shape[0]
+            if output_path is not None:
+                faiss.write_index(index, output_path)  # save intermediate index
+
+ if len(chunks_to_add) > 0:
+ to_add = torch.cat(chunks_to_add).float().numpy()
+ index.add(to_add)
+ logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
+
+ if output_path is not None:
+ logger.info(f'Index Built, writing index to {output_path}')
+ faiss.write_index(index, output_path)
+ logger.info(f'Index dumped')
+ else:
+ logger.info(f'Index Built.')
+ return index
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser("Build a FAISS index from precomputed vector files from embed.py. "
+ "Provides functionality to build either flat indexes (slow but exact)"
+ " or HNSW indexes (much faster, but approximate). "
+ "Optional application of 8bit or 16bit quantization is also available."
+ " Many more indexes are possible with Faiss, consult the Faiss repository here"
+ " if you want to build more advanced indexes.")
+ parser.add_argument('--embeddings_dir', type=str, help='path to directory containing vectors to build index from')
+ parser.add_argument('--output_path', type=str, help='path to write results to')
+ parser.add_argument('--hnsw', action='store_true', help='Build an HNSW index rather than Flat')
+ parser.add_argument('--SQ8', action='store_true', help='use SQ8 quantization on index to save memory')
+ parser.add_argument('--fp16', action='store_true', help='use fp16 quantization on index to save memory')
+ parser.add_argument('--store_n', type=int, default=32, help='hnsw store_n parameter')
+ parser.add_argument('--ef_construction', type=int, default=128, help='hnsw ef_construction parameter')
+ parser.add_argument('--ef_search', type=int, default=128, help='hnsw ef_search parameter')
+ parser.add_argument('--sample_fraction', type=float, default=1.0,
+ help='If memory is limited, specify a fraction (0.0->1.0) of the '
+ 'data to sample for training the quantizer')
+ parser.add_argument('--indexing_batch_size', type=int, default=None,
+ help='If memory is limited, specify the approximate number '
+ 'of vectors to add to the index at once')
+ parser.add_argument('-v', '--verbose', action="store_true")
+ args = parser.parse_args()
+ logging.info(f"Current process's PID: {os.getpid()}")
+ set_num_threads = 10240
+ faiss.omp_set_num_threads(set_num_threads)
+ logging.info(f"FAISS build info -- set threads {set_num_threads}")
+ logging.info(f"FAISS build info -- max threads {faiss.omp_get_max_threads()}")
+
+ if args.verbose:
+ logging.basicConfig(level=logging.DEBUG)
+    assert not (args.SQ8 and args.fp16), "can't use both SQ8 and fp16 quantization"
+    assert not os.path.exists(args.output_path), "Faiss index with name specified in --output_path already exists"
+ os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
+
+ args.indexing_batch_size = 10000000000000 if args.indexing_batch_size is None else args.indexing_batch_size
+ assert 0 < args.sample_fraction <= 1.0
+
+ build_start = time.perf_counter()
+
+ if os.path.isdir(args.embeddings_dir):
+ build_index_streaming_from_dir(
+ args.embeddings_dir,
+ args.output_path,
+ args.hnsw,
+ sq8_quantization=args.SQ8,
+ fp16_quantization=args.fp16,
+ store_n=args.store_n,
+ ef_construction=args.ef_construction,
+ ef_search=args.ef_search,
+ sample_fraction=args.sample_fraction,
+ indexing_batch_size=args.indexing_batch_size,
+ )
+ else:
+ build_index_streaming(
+ args.embeddings_dir,
+ args.output_path,
+ hnsw=args.hnsw,
+ sq8_quantization=args.SQ8,
+ fp16_quantization=args.fp16,
+ store_n=args.store_n,
+ ef_construction=args.ef_construction,
+ ef_search=args.ef_search,
+ sample_fraction=args.sample_fraction,
+ indexing_batch_size=args.indexing_batch_size,
+ )
+
+ logging.info(f"building index cost {build_start - time.perf_counter():.5f} seconds")
diff --git a/emat/retriever/embed.py b/emat/retriever/embed.py
new file mode 100644
index 0000000..ae547ce
--- /dev/null
+++ b/emat/retriever/embed.py
@@ -0,0 +1,98 @@
+import argparse
+import logging
+import os
+import time
+
+import torch
+from transformers import T5Tokenizer
+
+from emat.t5 import T5Config, T5KeyValueEncoder
+from emat.utils import to_fp16, verbalise_qa as _verbalise, load_jsonl
+
+logger = logging.getLogger(__name__)
+CUDA = torch.cuda.is_available()
+
+
+def embed_key(model, tokenizer, qas, prefix="", bsz=256, max_length=1024, cuda=CUDA, fp16=False,
+ use_both_qa_for_key=False):
+ """Compute the key/query embeddings.
+ prefix: empty when encoding query, "encode: " when encoding key.
+ """
+ verbalise_qa = _verbalise if use_both_qa_for_key else lambda x, y: x
+
+ def tokenize(batch_qas):
+ input_strs = [prefix + verbalise_qa(ex["question"], ex["answer"][0]) for ex in batch_qas]
+ inputs = tokenizer(input_strs, max_length=max_length, padding=True, truncation=True, return_tensors="pt")
+ return inputs
+
+ if cuda:
+ model = model.cuda()
+ model = to_fp16(model) if fp16 else model
+
+ t = time.time()
+
+ def log_progress(j, outputs):
+ t2 = time.time()
+ logger.info(
+ f'Embedded {j + 1} / {len(list(range(0, len(qas), bsz)))} batches in {t2 - t:0.2f} seconds '
+ f'({sum([len(o) for o in outputs]) / (t2 - t): 0.4f} QAs per second)')
+
+ outputs = []
+ with torch.no_grad():
+ for j, batch_start in enumerate(range(0, len(qas), bsz)):
+ batch = qas[batch_start: batch_start + bsz]
+
+ inputs = tokenize(batch)
+ inputs = {k: v.cuda() for k, v in inputs.items()} if cuda else inputs
+
+ batch_outputs = model(**inputs, compute_value=False, return_dict=True)
+ outputs.append(batch_outputs.key.cpu())
+ if j % 10 == 0:
+ log_progress(j, outputs)
+
+ log_progress(j, outputs)
+
+ return torch.cat(outputs, dim=0).cpu()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_name_or_path', type=str, required=True, help='path to HF model dir')
+ parser.add_argument("--max_source_length", type=int, default=1024,
+                        help="The maximum total input sequence length after tokenization. Sequences "
+ "longer than this will be truncated, sequences shorter will be padded.")
+ parser.add_argument('--qas_to_embed', type=str, required=True, help='Path to questions to embed in jsonl format')
+ parser.add_argument('--output_path', type=str, help='path to write vectors to')
+ parser.add_argument('--fp16', action='store_true')
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('-v', '--verbose', action="store_true")
+
+ parser.add_argument("--source_prefix", type=str, default="nq question: ",
+ help="A prefix to add before every source text " "(useful for T5 models).", )
+ parser.add_argument("--use_both_qa_for_key", action="store_true", help="Use both Q and A for key embedding.")
+
+ args = parser.parse_args()
+
+ if args.verbose:
+ logging.basicConfig(level=logging.DEBUG)
+
+ if args.fp16 and not CUDA:
+ raise Exception("Can't use --fp16 without a gpu, CUDA not found")
+
+ os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
+
+ qas_to_embed = load_jsonl(args.qas_to_embed)
+
+ config = T5Config.from_pretrained(args.model_name_or_path)
+ model = T5KeyValueEncoder.from_pretrained(
+ args.model_name_or_path,
+ from_tf=bool(".ckpt" in args.model_name_or_path),
+ config=config,
+ )
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
+
+ embed_mat = embed_key(model, tokenizer, qas_to_embed, prefix=args.source_prefix, bsz=args.batch_size,
+ max_length=args.max_source_length, fp16=args.fp16,
+ use_both_qa_for_key=args.use_both_qa_for_key)
+
+ torch.save(embed_mat.half(), args.output_path)
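+
+# Illustrative usage (a sketch: paths are placeholders; the flags mirror the argparse definitions above):
+#   python emat/retriever/embed.py \
+#       --model_name_or_path ./checkpoints/emat_nq \
+#       --qas_to_embed ./data/paq/qas_to_embed.jsonl \
+#       --output_path ./cached_outputs/key_embeds/keys.pt \
+#       --batch_size 256 --fp16 --use_both_qa_for_key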
diff --git a/emat/retriever/retrieve.py b/emat/retriever/retrieve.py
new file mode 100644
index 0000000..92feecb
--- /dev/null
+++ b/emat/retriever/retrieve.py
@@ -0,0 +1,180 @@
+import argparse
+import logging
+import time
+from copy import deepcopy
+
+import faiss
+import numpy as np
+import torch
+from transformers import T5Tokenizer
+
+from emat.evaluation.eval_retriever import eval_retriever
+from emat.retriever.utils import get_mips_function, parse_vectors_from_directory, mips
+from emat.t5 import T5Config, T5KeyValueEncoder
+from emat.utils import load_jsonl, write_jsonl, to_fp16
+
+logger = logging.getLogger(__name__)
+
+CUDA = torch.cuda.is_available()
+
+
+def get_output_format(qas_to_answer, qas_to_retrieve_from, top_indices, top_scores):
+ results = []
+ for qa_ind, qa in enumerate(qas_to_answer):
+ res = []
+ for score_ind, ind in enumerate(top_indices[qa_ind]):
+ score = top_scores[qa_ind][score_ind]
+ ret_qa = deepcopy(qas_to_retrieve_from[ind])
+ ret_qa['score'] = float(score)
+ res.append(ret_qa)
+ results.append(res)
+
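+    # Illustrative shape of each returned element (keys taken from the code above, values are placeholders):
+    #   {"input_qa": {...}, "retrieved_qas": [{..., "score": 12.3}, ...]}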
+ return [{'input_qa': in_qa, 'retrieved_qas': ret_qas} for in_qa, ret_qas in zip(qas_to_answer, results)]
+
+
+def embed_query(model, tokenizer, qas, prefix="", bsz=256, max_length=1024, cuda=CUDA, fp16=False):
+ def tokenize(batch_qas):
+ input_strs = [prefix + ex["question"] for ex in batch_qas]
+ inputs = tokenizer(input_strs, max_length=max_length, padding=True, truncation=True, return_tensors="pt")
+ return inputs
+
+ if cuda:
+ model = model.cuda()
+ model = to_fp16(model) if fp16 else model
+
+ t = time.time()
+
+ def log_progress(j, outputs):
+ t2 = time.time()
+ logger.info(
+ f'Embedded {j + 1} / {len(list(range(0, len(qas), bsz)))} batches in {t2 - t:0.2f} seconds '
+ f'({sum([len(o) for o in outputs]) / (t2 - t): 0.4f} QAs per second)')
+
+ outputs = []
+ with torch.no_grad():
+ for j, batch_start in enumerate(range(0, len(qas), bsz)):
+ batch = qas[batch_start: batch_start + bsz]
+
+ inputs = tokenize(batch)
+ inputs = {k: v.cuda() for k, v in inputs.items()} if cuda else inputs
+
+ batch_outputs = model(**inputs, compute_value=False, return_dict=True)
+ outputs.append(batch_outputs.key.cpu())
+ if j % 10 == 0:
+ log_progress(j, outputs)
+
+ log_progress(j, outputs)
+
+ return torch.cat(outputs, dim=0).cpu()
+
+
+def run_queries(model, tokenizer, qas_to_retrieve_from, qas_to_answer, top_k, index=None, prefix="",
+ batch_size=128, max_length=1024, fp16=False, n_queries_to_parallelize=2048):
+ assert index is not None
+
+ logger.info('Embedding QAs to answer:')
+ embedded_qas_to_answer = embed_query(model, tokenizer, qas_to_answer, prefix=prefix, bsz=batch_size,
+ max_length=max_length, cuda=CUDA, fp16=fp16)
+ logger.info('Running MIPS search:')
+ top_indices, top_scores = mips(index, embedded_qas_to_answer, top_k,
+ n_queries_to_parallelize=n_queries_to_parallelize)
+
+ return get_output_format(qas_to_answer, qas_to_retrieve_from, top_indices, top_scores)
+
+
+def _load_index_if_exists(faiss_index_path, precomputed_embeddings_dir, n_vectors_to_load=None, memory_friendly=False,
+ efsearch=128):
+ index = None
+ if faiss_index_path is not None:
+ assert precomputed_embeddings_dir is None, "Do not specify both a --faiss_index_path and --precomputed_embeddings_dir"
+ logger.info('Loading Faiss index:')
+ index = faiss.read_index(faiss_index_path)
+ if hasattr(index, 'hnsw'):
+ index.hnsw.efSearch = efsearch
+
+    elif precomputed_embeddings_dir is not None:
+        logger.info('Loading vectors index from directory:')
+        index = parse_vectors_from_directory(
+            precomputed_embeddings_dir, memory_friendly=memory_friendly, size=n_vectors_to_load
+        ).float()
+        assert n_vectors_to_load == index.shape[0]
+
+    if index is not None:
+        logger.info('Index loaded')
+    return index
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_name_or_path', type=str, required=True, help='path to HF model dir')
+ parser.add_argument("--max_source_length", type=int, default=1024,
+                        help="The maximum total input sequence length after tokenization. Sequences "
+ "longer than this will be truncated, sequences shorter will be padded.")
+ parser.add_argument('--qas_to_answer', type=str, required=True, help="path to questions to answer in jsonl format")
+ parser.add_argument('--qas_to_retrieve_from', type=str, required=True,
+ help="path to QA-pairs to retrieve answers from in jsonl format")
+ parser.add_argument('--top_k', type=int, default=50, help="top K QA-pairs to retrieve for each input question")
+ parser.add_argument('--output_file', type=str, required=True, help='Path to write jsonl results to')
+ parser.add_argument('--faiss_index_path', default=None, type=str,
+ help="Path to faiss index, if retrieving from a faiss index")
+ parser.add_argument('--precomputed_embeddings_dir', default=None, type=str,
+                        help="path to a directory of vector embeddings if retrieving from raw embedding vectors")
+ parser.add_argument('--fp16', action='store_true')
+ parser.add_argument('--batch_size', type=int, default=128, help='Batch size for embedding questions for querying')
+ parser.add_argument('--n_queries_to_parallelize', type=int, default=256, help="query batch size")
+ parser.add_argument('-v', '--verbose', action="store_true")
+ parser.add_argument('--memory_friendly_parsing', action='store_true',
+ help='Pass this to load files more slowly, but save memory')
+ parser.add_argument('--faiss_efsearch', type=int, default=128,
+ help='EFSearch search time parameter for hnsw, higher is more accurate but slower')
+
+ parser.add_argument("--source_prefix", type=str, default="nq question: ",
+ help="A prefix to add before every source text " "(useful for T5 models).", )
+ parser.add_argument('--hits_at_k', type=str, help='comma separated list of K to eval hits@k for', default="1,10,50")
+
+ args = parser.parse_args()
+
+ if args.verbose:
+ logging.basicConfig(level=logging.DEBUG)
+
+ qas_to_answer = load_jsonl(args.qas_to_answer)
+ qas_to_retrieve_from = load_jsonl(args.qas_to_retrieve_from)
+
+ index = _load_index_if_exists(
+ args.faiss_index_path,
+ args.precomputed_embeddings_dir,
+ n_vectors_to_load=len(qas_to_retrieve_from),
+ memory_friendly=args.memory_friendly_parsing,
+ efsearch=args.faiss_efsearch
+ )
+
+ config = T5Config.from_pretrained(args.model_name_or_path)
+ model = T5KeyValueEncoder.from_pretrained(
+ args.model_name_or_path,
+ from_tf=bool(".ckpt" in args.model_name_or_path),
+ config=config,
+ )
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
+
+ retrieved_answers = run_queries(
+ model,
+ tokenizer,
+ qas_to_retrieve_from,
+ qas_to_answer,
+ args.top_k,
+ index,
+ args.source_prefix,
+ args.batch_size,
+ args.max_source_length,
+ args.fp16,
+ args.n_queries_to_parallelize,
+ )
+
+ logger.info(f'Writing retrieval output to {args.output_file}')
+ write_jsonl(retrieved_answers, args.output_file)
+
+ hits_at_k = sorted([int(k) for k in args.hits_at_k.split(',')])
+ result = eval_retriever(qas_to_answer, retrieved_answers, hits_at_k)
+ with open(args.output_file + ".result", "w") as f:
+ for k, v in result.items():
+ f.write(f"{k}: {v}\n")
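+
+# Illustrative usage (a sketch: paths are placeholders; the flags mirror the argparse definitions above):
+#   python emat/retriever/retrieve.py \
+#       --model_name_or_path ./checkpoints/emat_nq \
+#       --qas_to_answer ./data/nq/test.jsonl \
+#       --qas_to_retrieve_from ./data/paq/qas_to_retrieve.jsonl \
+#       --faiss_index_path ./checkpoints/key_index.hnsw.faiss \
+#       --output_file ./outputs/nq_test_retrieved.jsonl \
+#       --top_k 50 --hits_at_k 1,10,50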
diff --git a/emat/retriever/utils.py b/emat/retriever/utils.py
new file mode 100644
index 0000000..c82302a
--- /dev/null
+++ b/emat/retriever/utils.py
@@ -0,0 +1,148 @@
+import glob
+import logging
+import os
+import time
+
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def torch_mips(index, query_batch, top_k):
+ sims = torch.matmul(query_batch, index.t())
+ return sims.topk(top_k)
+
+
+def flat_index_mips(index, query_batch, top_k):
+ return index.search(query_batch.numpy(), top_k)
+
+
+def aux_dim_index_mips(index, query_batch, top_k):
+ # querying faiss indexes for MIPS using a euclidean distance index, used with hnsw
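+    # Note (assumption): the index is expected to have been built with the usual MIPS-to-L2 reduction,
+    # i.e. every stored key gets one extra auxiliary dimension; padding the query with a zero in that
+    # dimension makes nearest-neighbour search under L2 distance rank vectors by inner product.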
+ aux_dim = query_batch.new(query_batch.shape[0]).fill_(0)
+ aux_query_batch = torch.cat([query_batch, aux_dim.unsqueeze(-1)], -1)
+ return index.search(aux_query_batch.numpy(), top_k)
+
+
+def get_mips_function(index):
+ if type(index) == torch.Tensor:
+ return torch_mips
+ elif 'hnsw' in str(type(index)).lower():
+ return aux_dim_index_mips
+ else:
+ return flat_index_mips
+
+
+def get_vectors_file_paths_in_vector_directory(embeddings_dir):
+ paths = glob.glob(os.path.abspath(embeddings_dir) + '/*')
+    num_files = len(paths)
+    template = '.'.join(paths[0].split('.')[:-1])
+    return [template + f'.{j}' for j in range(num_files)]
+
+
+def parse_vectors_from_directory_chunks(embeddings_dir, half):
+ paths = get_vectors_file_paths_in_vector_directory(embeddings_dir)
+ for j, p in enumerate(paths):
+ logger.info(f'Loading vectors from {p} ({j + 1} / {len(paths)})')
+ m = torch.load(p)
+ assert int(p.split('.')[-1]) == j, (p, j)
+
+ if half:
+ m = m if m.dtype == torch.float16 else m.half()
+ else:
+ m = m if m.dtype == torch.float32 else m.float()
+ yield m
+
+
+def parse_vectors_from_directory_fast(embeddings_dir, half=False):
+    ms = []
+    for m in parse_vectors_from_directory_chunks(embeddings_dir, half):
+        ms.append(m)
+
+ out = torch.cat(ms)
+ logger.info(f'loaded index of shape {out.shape}')
+ return out
+
+
+def parse_vectors_from_directory_memory_friendly(embeddings_dir, size=None):
+ paths = get_vectors_file_paths_in_vector_directory(embeddings_dir)
+ if size is None:
+ size = 0
+ for j, p in enumerate(paths):
+ logger.info(f'Loading vectors from {p} ({j + 1} / {len(paths)}) to find total num vectors')
+ m = torch.load(p)
+ size += m.shape[0]
+
+ out = None
+ offset = 0
+ for j, p in enumerate(paths):
+ logger.info(f'Loading vectors from {p} ({j + 1} / {len(paths)})')
+ m = torch.load(p)
+
+ assert int(p.split('.')[-1]) == j, (p, j)
+ if out is None:
+ out = torch.zeros(size, m.shape[1])
+ out[offset: offset + m.shape[0]] = m
+ offset += m.shape[0]
+ assert offset == size
+ logger.info(f'loaded index of shape {out.shape}')
+
+ return out
+
+
+def parse_vectors_from_directory(fi, memory_friendly=False, size=None, as_chunks=False, half=False):
+    assert os.path.isdir(fi), f"Vectors directory {fi} doesn't exist, or is not a directory of pytorch vectors"
+ if as_chunks:
+ return parse_vectors_from_directory_chunks(fi, half)
+
+ if memory_friendly:
+ out = parse_vectors_from_directory_memory_friendly(fi, size=size)
+ else:
+ out = parse_vectors_from_directory_fast(fi)
+
+ if half:
+ out = out if out.dtype == torch.float16 else out.half()
+ else:
+ out = out if out.dtype == torch.float32 else out.float()
+
+ return out
+
+
+def parse_vectors_from_file(fi, half=False):
+ assert os.path.isfile(fi), f"{fi}"
+ logger.info(f'Loading vectors from {fi}')
+ out = torch.load(fi)
+ logger.info(f'loaded vectors of shape {out.shape}')
+
+ if half:
+ out = out if out.dtype == torch.float16 else out.half()
+ else:
+ out = out if out.dtype == torch.float32 else out.float()
+
+ return out
+
+
+def mips(index, queries, top_k, n_queries_to_parallelize=256):
+ # t = time.time()
+ all_top_indices = None
+ all_top_scores = None
+
+ _mips = get_mips_function(index)
+
+ for mb in range(0, len(queries), n_queries_to_parallelize):
+ query_batch = queries[mb:mb + n_queries_to_parallelize].float()
+ scores, top_indices = _mips(index, query_batch, top_k)
+
+ all_top_indices = top_indices if all_top_indices is None else np.concatenate([all_top_indices, top_indices])
+ all_top_scores = scores if all_top_scores is None else np.concatenate([all_top_scores, scores])
+
+ # delta = time.time() - t
+ # logger.info(
+ # f'{len(all_top_indices)}/ {len(queries)} queries searched in {delta:04f} '
+ # f'seconds ({len(all_top_indices) / delta} per second)')
+
+ assert len(all_top_indices) == len(queries)
+
+ # delta = time.time() - t
+ # logger.info(f'Index searched in {delta:04f} seconds ({len(queries) / delta} per second)')
+ return all_top_indices, all_top_scores
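+
+# Minimal usage sketch (assumes `keys` is a [num_keys, d_key] torch.Tensor or a loaded faiss index,
+# and `queries` is a [num_queries, d_key] torch.Tensor of query embeddings):
+#   top_indices, top_scores = mips(keys, queries, top_k=50, n_queries_to_parallelize=2048)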
diff --git a/emat/t5.py b/emat/t5.py
new file mode 100644
index 0000000..eb09c1d
--- /dev/null
+++ b/emat/t5.py
@@ -0,0 +1,1976 @@
+# coding=utf-8
+# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch T5 model. """
+import asyncio
+import copy
+import warnings
+from concurrent import futures
+from dataclasses import dataclass
+from typing import Dict, Any, Optional, Tuple
+from emat.fusion_net import FusionWeight
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.utils.checkpoint import checkpoint
+from transformers.file_utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+ ModelOutput,
+)
+from transformers.models.t5.modeling_t5 import (
+ T5Block,
+ T5LayerNorm,
+ T5Stack,
+ T5ForConditionalGeneration,
+ T5_START_DOCSTRING,
+ T5_INPUTS_DOCSTRING,
+ __HEAD_MASK_WARNING_MSG as HEAD_MASK_WARNING_MSG,
+ _CONFIG_FOR_DOC,
+)
+from transformers.utils import logging
+from utils.utils import reduce_query_or_key_embeds
+from emat.retriever.utils import mips
+from emat.retrieval_adapter import RetAdapter
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class KeyValueOutput(ModelOutput):
+ key: torch.FloatTensor = None
+ value: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class CATEncoderOutput(ModelOutput):
+ last_hidden_state: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    query_embeds: Optional[torch.Tensor] = None
+    readout_indices: Optional[Any] = None
+    updated_attention_mask: Optional[torch.Tensor] = None
+
+
+@dataclass
+class CATSeq2SeqLMOutput(ModelOutput):
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cat_encoder_outputs: Optional[CATEncoderOutput] = None
+
+
+class ConvKeyEncoder(nn.Module):
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.conv_layer = nn.Conv1d(
+ in_channels=config.d_model,
+ out_channels=config.d_model,
+ kernel_size=3,
+ stride=1,
+ padding="same",
+ bias=False
+ )
+ self.relu = nn.ReLU()
+ self.linear = nn.Linear(config.d_model, config.d_key, bias=False)
+
+ def forward(self, hidden_states, attention_mask):
+ attention_mask = attention_mask[:, None, :]
+ masked_hidden = hidden_states.transpose(1, 2) * attention_mask # shape: [batch_size, d_model, seq_length]
+
+ conv_out = self.conv_layer(masked_hidden) # shape: [batch_size, d_model, seq_length]
+ relu_out = self.relu(conv_out) # shape: [batch_size, d_model, seq_length]
+ relu_out = relu_out * attention_mask # mask again
+
+ sum_pooled = torch.sum(relu_out, dim=2) # shape: [batch_size, d_model]
+ lengths = torch.sum(attention_mask, dim=(1, 2)) # shape: [batch_size]
+ mean_pooled = sum_pooled / lengths[:, None] # shape: [batch_size, d_model]
+
+ final_out = self.linear(mean_pooled) # shape: [batch_size, d_key]
+ return final_out
+
+
+# T5StackWithKeyValueMemory is encoder
+class T5StackWithKeyValueMemory(T5Stack):
+ def __init__(self, config, embed_tokens=None):
+ super(T5Stack, self).__init__(config)
+
+ self.embed_tokens = embed_tokens
+ self.is_decoder = config.is_decoder
+
+ # Create prefix embeddings
+ assert not self.is_decoder, "This module should only be used as encoder"
+ self.prefix_length = config.prefix_length # length of the prefix
+ self.prefix_embedding = nn.Parameter(torch.empty((self.prefix_length, config.d_model), dtype=torch.float))
+ self.model_dim = config.d_model
+
+ # Key configs
+ self.key_layer = config.key_layer # the layer that conducts key querying
+
+ self.cat_layer = getattr(config, "cat_layer", None)
+
+ # Initialize the key encoder
+ self.key_encoder_type = config.key_encoder_type
+ if self.key_encoder_type == "linear":
+ self.d_key = config.d_key # dimension of the key/query embedding
+ self.key_encoder = nn.Linear(self.prefix_length * config.d_model, self.d_key, bias=False)
+ elif self.key_encoder_type == "conv":
+ self.key_encoder = ConvKeyEncoder(config)
+ elif self.key_encoder_type == "prefix":
+ self.key_encoder = None
+ else:
+ raise ValueError(f"Incorrect key_encoder_type: {self.key_encoder_type}")
+
+ # self.qk_scorer = nn.Linear(1, 1, bias=True) # calibrate the query-key match scores into gating
+
+ # Value configs
+ self.value_layer = config.value_layer # the layer that it conducts value infilling
+ self.num_values = config.num_values # number of value embeddings to infill
+ assert self.key_layer <= self.value_layer, "Key layer should be smaller than or equal to value layer"
+
+ self.value_fusion_method = config.value_fusion_method
+
+ if self.value_fusion_method is not None and "cat" in self.value_fusion_method:
+ # add_position_bias_layer = self.value_layer
+ # if "delay" in self.value_fusion_method:
+ # add_position_bias_layer = self.key_layer
+ if self.cat_layer is not None:
+ add_position_bias_layer = self.cat_layer
+ else:
+ add_position_bias_layer = min(self.key_layer, self.value_layer)
+ self.block = nn.ModuleList(
+ [T5Block(config, has_relative_attention_bias=bool(i == 0 or i == add_position_bias_layer))
+ for i in range(config.num_layers)]
+ )
+ else:
+ self.block = nn.ModuleList(
+ [T5Block(config, has_relative_attention_bias=bool(i == 0))
+ for i in range(config.num_layers)]
+ )
+
+ if self.value_fusion_method is not None and "g(" in self.value_fusion_method:
+ self.fusion_weight_proj = FusionWeight(fusion_type=self.value_fusion_method)
+ else:
+ self.fusion_weight_proj = None
+
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.key_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+ nn.init.normal_(self.prefix_embedding.data, mean=0.0, std=config.initializer_factor * 1.0)
+ # self.qk_scorer.weight.data.copy_(torch.tensor([[1.0]]))
+ # self.qk_scorer.bias.data.copy_(torch.tensor([0.0]))
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ key_embeds=None,
+ value_embeds=None,
+ key_faiss_index=None,
+ key_reduce_method=None,
+ value_qas_input_ids=None,
+ value_qas_attention_mask=None,
+ readout_top_k=1,
+ value_fusion_method=None,
+ key_embeds_of_value=None,
+ key_memory=None,
+ value_memory=None,
+ embedding_index=None,
+ ):
+ assert key_embeds is None
+ assert value_fusion_method == self.value_fusion_method
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(self.first_device)
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Sanity check
+ assert not use_cache, "This class does not support use_cache because it is encoder only"
+
+ if input_ids is not None and inputs_embeds is not None:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+
+ if inputs_embeds is None:
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+ # Concatenate the prefix embeddings and extend the attention masks
+ prefix_embeds = self.prefix_embedding[None, :, :].expand(batch_size, -1, -1).to(inputs_embeds.device)
+ inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1)
+
+ # Extend the attention masks
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length + self.prefix_length).to(inputs_embeds.device)
+ else:
+ prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to(
+ inputs_embeds.device)
+ attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
+
+ # initialize past_key_values with `None` if past does not exist
+ if past_key_values is None:
+ past_key_values = [None] * len(self.block)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, (batch_size, seq_length + self.prefix_length), inputs_embeds.device
+ )
+
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+ present_key_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+ query_embeds = None
+ readout_indices = None
+
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if position_bias is not None:
+ position_bias = position_bias.to(hidden_states.device)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
+ if encoder_extended_attention_mask is not None:
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
+ if encoder_decoder_position_bias is not None:
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+ if layer_head_mask is not None:
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
+ if cross_attn_layer_head_mask is not None:
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warn(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return tuple(module(*inputs, use_cache, output_attentions))
+
+ return custom_forward
+
+ layer_outputs = checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ extended_attention_mask,
+ position_bias,
+ encoder_hidden_states,
+ encoder_extended_attention_mask,
+ encoder_decoder_position_bias,
+ layer_head_mask,
+ cross_attn_layer_head_mask,
+ None, # past_key_value is always None with gradient checkpointing
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if not use_cache:
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[2]
+
+ # Query-key matching
+ if i == self.key_layer: # key-layer is Query-Layer
+ raw_query_embeds = self._encode_key(hidden_states, attention_mask) # shape:[batch_size, d_key]
+ raw_query_embeds = raw_query_embeds.view(hidden_states.shape[0], -1, hidden_states.shape[-1])
+
+ normed_query_embeds = self.kv_output_layer(raw_query_embeds) # Query is normed !!!
+ query_embeds = reduce_query_or_key_embeds(normed_query_embeds, key_reduce_method)
+
+ if embedding_index is not None:
+ assert value_embeds is None and key_embeds_of_value is None
+ if type(embedding_index) is not list:
+ embedding_index = [embedding_index]
+ half_query_embeds = query_embeds.half()
+ memory_size, hidden_num, hidden_size = value_memory.shape
+ if memory_size > 20000000:
+ # if scale is large: calculate topk in each chunk -> combine all-topk -> select final topk
+ chunk_top_scores = []
+ chunk_top_indices = []
+ idx_shift = 0
+ for chunk_key_memory in embedding_index:
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_topk = torch.mm(half_query_embeds, chunk_key_memory_cuda.t()).topk(50, dim=1)
+                            chunk_top_scores.append(chunk_topk.values)  # chunk_topk.values: [query_batch, 50]
+ chunk_top_indices.append(chunk_topk.indices + idx_shift)
+ idx_shift += len(chunk_key_memory)
+ del chunk_key_memory_cuda
+ torch.cuda.empty_cache()
+ chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n
+ chunk_top_indices = torch.cat(chunk_top_indices, dim=1) # q_batch, local_size*seg_n
+ topk = chunk_top_scores.topk(readout_top_k, dim=1) # q_batch, local_size
+ top_indices_indices = topk.indices
+ readout_indices = []
+ for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices):
+ readout_indices.append([cur_indices[idx] for idx in cur_indices_indices])
+ else:
+ all_chunk_scores = []
+ for chunk_key_memory in embedding_index:
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_scores = torch.mm(half_query_embeds, chunk_key_memory_cuda.t()) # query_batch
+ all_chunk_scores.append(chunk_scores)
+ del chunk_key_memory_cuda
+ scores = torch.cat(all_chunk_scores, dim=1)
+ readout_indices = scores.topk(readout_top_k, dim=1).indices.tolist()
+
+ top_indices = torch.tensor(readout_indices)
+ bs = input_ids.shape[0]
+ value_embeds = torch.index_select(value_memory, 0, top_indices.view(-1)).float().cuda()
+ value_embeds = value_embeds.view(input_ids.shape[0], readout_top_k, hidden_num, hidden_size)
+ key_embeds_of_value = torch.index_select(key_memory, 0, top_indices.view(-1)).float().cuda()
+ key_embeds_of_value = key_embeds_of_value.view(bs, readout_top_k, hidden_num, hidden_size)
+
+ if value_fusion_method == "cat_k_delay+v":
+ # Serial mode, cat key directly.
+ batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape
+ if key_nums != self.prefix_length:
+ assert key_nums == 1
+ key_embeds_of_value = key_embeds_of_value.repeat(1, 1, self.prefix_length, 1)
+ hidden_states = torch.cat(
+ [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size),
+ hidden_states], dim=1
+ )
+ extend_length = num_values * self.prefix_length
+ extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype)
+ attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1)
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, attention_mask.shape[:2], inputs_embeds.device
+ )
+ position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule
+
+ if self.cat_layer is not None and i == self.cat_layer:
+ if value_fusion_method == "async_cat_k_delay+v":
+                    # Async mode: cat the key at cat_layer; the delay+v handling is the same as in serial mode.
+ batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape
+ hidden_states = torch.cat(
+ [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size),
+ hidden_states], dim=1
+ )
+ extend_length = num_values * self.prefix_length
+ extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype)
+ attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1)
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, attention_mask.shape[:2], inputs_embeds.device
+ )
+ position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule
+
+ if i == self.value_layer:
+ batch_size, num_values, _, hidden_size = key_embeds_of_value.shape
+
+ # assert query_embeds is not None, "Use query_embeds to read memory before assignment."
+ if "delay" in value_fusion_method:
+ updated_key = hidden_states[:, :num_values * self.prefix_length]
+ updated_key = updated_key.view(batch_size, num_values, self.prefix_length, hidden_size)
+ else:
+ updated_key = None
+ integrated_value = self.get_integrated_values(value_embeds, key_embeds_of_value, value_fusion_method,
+ query_embeds=query_embeds, updated_key=updated_key,
+ key_reduce_method=key_reduce_method)
+ if value_fusion_method == "infill":
+ assert self.num_values == 1
+ hidden_states = torch.cat([integrated_value, hidden_states[:, self.prefix_length:]], dim=1)
+ elif "cat" in value_fusion_method and "delay" in value_fusion_method:
+ hidden_states[:, :num_values * self.prefix_length] = integrated_value
+ hidden_states = hidden_states.contiguous()
+ elif "cat" in value_fusion_method and "delay" not in value_fusion_method:
+ hidden_states = torch.cat([integrated_value, hidden_states], dim=1)
+ extend_length = integrated_value.shape[1]
+ extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype)
+ attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1)
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, attention_mask.shape[:2], inputs_embeds.device
+ )
+ position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule
+ else:
+ raise NotImplementedError(f"{value_fusion_method} is not defined.")
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ present_key_value_states,
+ all_hidden_states,
+ all_attentions,
+ all_cross_attentions,
+ query_embeds,
+ readout_indices
+ ]
+ if v is not None
+ )
+ return CATEncoderOutput(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_value_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ query_embeds=query_embeds,
+ readout_indices=readout_indices,
+ updated_attention_mask=None,
+ )
+
+ def get_integrated_values(self, group_value_embeds, group_key_embeds, value_fusion_method,
+ query_embeds=None, updated_key=None, key_reduce_method=None):
+ # group_value_embeds: [batch_size, num_values, prefix_length, hidden_size]
+ # group_key_embeds: [batch_size, num_values, key_num_tokens, hidden_size]
+ if value_fusion_method == "cat_k_delay+v":
+ batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape
+ integrated_value = updated_key + group_value_embeds.contiguous()
+ integrated_value = integrated_value.view(batch_size, num_values * prefix_length, hidden_size)
+ elif value_fusion_method == "async_cat_k_delay+v":
+ batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape
+ integrated_value = updated_key + group_value_embeds.contiguous()
+ integrated_value = integrated_value.view(batch_size, num_values * prefix_length, hidden_size)
+ elif value_fusion_method == "cat_v":
+ batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape
+ group_value_embeds = group_value_embeds.contiguous()
+ integrated_value = group_value_embeds.view(batch_size, num_values * prefix_length, hidden_size)
+ elif value_fusion_method == "async_cat_k+v":
+ batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape
+ group_key_add_value = group_key_embeds + group_value_embeds
+ integrated_value = group_key_add_value.view(batch_size, num_values * self.prefix_length, hidden_size)
+ elif value_fusion_method == "cat_k+v":
+ batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape
+ group_key_add_value = group_key_embeds + group_value_embeds
+ integrated_value = group_key_add_value.view(batch_size, num_values * self.prefix_length, hidden_size)
+ elif value_fusion_method == "cat_k_delay+v_g(kv)":
+ # batch_size, num_values, 1
+ key_weight = self.fusion_weight_proj(key=updated_key, value=group_value_embeds.contiguous())
+ key_weight = key_weight.unsqueeze(dim=-1)
+ integrated_value = key_weight * updated_key + (1 - key_weight) * group_value_embeds
+ batch_size, num_values, key_nums, hidden_size = updated_key.shape
+ integrated_value = integrated_value.view(batch_size, num_values * self.prefix_length, hidden_size)
+ elif value_fusion_method == "infill":
+ batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape
+ assert num_values == self.num_values == 1
+ integrated_value = group_value_embeds.view(batch_size, prefix_length, hidden_size)
+ elif value_fusion_method == "cat_kv":
+ group_key_cat_value = torch.cat((group_key_embeds, group_value_embeds), dim=2)
+ # [batch_size, num_values, prefix_length + key_num_tokens, hidden_size]
+ batch_size, num_values, integrated_prefix_length, hidden_size = group_key_cat_value.shape
+ integrated_value = group_key_cat_value.view(batch_size, num_values * integrated_prefix_length, hidden_size)
+ elif value_fusion_method == "cat_avgk+v":
+ batch_size, num_values, key_nums, hidden_size = group_key_embeds.shape
+ reduced_key_embeds = (group_key_embeds.sum(dim=2) / key_nums).unsqueeze(dim=2)
+ group_key_add_value = reduced_key_embeds + group_value_embeds
+ integrated_value = group_key_add_value.view(batch_size, num_values * self.prefix_length, hidden_size)
+ elif value_fusion_method == "cat_k+v_g(kq)":
+ batch_size, num_values, key_nums, hidden_size = group_key_embeds.shape
+ squeezed_key_embeds = group_key_embeds.view(batch_size * num_values, key_nums, hidden_size)
+ reduced_key_embeds = reduce_query_or_key_embeds(squeezed_key_embeds, key_reduce_method)
+ reduced_key_embeds = reduced_key_embeds.view(batch_size, num_values, hidden_size)
+ key_weight = self.fusion_weight_proj(key=reduced_key_embeds, query=query_embeds) # b_s, num_values, 1
+ key_weight = key_weight.unsqueeze(dim=-1)
+ integrated_value = key_weight * group_key_embeds + (1 - key_weight) * group_value_embeds
+ integrated_value = integrated_value.view(batch_size, num_values * self.prefix_length, hidden_size)
+ else:
+ raise NotImplementedError(f"{value_fusion_method} is not defined.")
+ return integrated_value
+
+ def embed_kv(self, input_ids, attention_mask=None, head_mask=None, compute_key=True, compute_value=True,
+ embed_for_ae_task=False) -> Dict:
+ """Compute the key/value embeddings for the input."""
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(self.first_device)
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
+
+ # Sanity check
+ assert compute_key or compute_value, "At least one of compute_key and compute_value needs to be True"
+        assert not (compute_key and compute_value), "Compute either the key or the value in one forward pass, not both."
+ assert input_ids is not None
+
+ original_input_shape = input_ids.shape
+ input_ids = input_ids.view(-1, input_ids.shape[-1]) # the last dimension is seq_length
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_ids.shape
+
+ # Concatenate the prefix embeddings and extend the attention masks
+ prefix_embeds = self.prefix_embedding[None, :, :].expand(batch_size, -1, -1).to(inputs_embeds.device)
+ inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1)
+
+ # Extend the attention masks
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, seq_length + self.prefix_length).to(inputs_embeds.device)
+ else:
+ prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to(
+ inputs_embeds.device)
+ attention_mask = torch.cat([prefix_mask, attention_mask.view(batch_size, seq_length)], dim=1)
+
+ # initialize past_key_values with `None` if past does not exist
+ past_key_values = [None] * len(self.block)
+
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, (batch_size, seq_length + self.prefix_length), inputs_embeds.device
+ )
+
+ encoder_hidden_states = None
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(None, self.config.num_layers)
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+ key_embeds, value_embeds = None, None
+ normed_key_embeds, normed_value_embeds = None, None
+ key_embeds_to_cat = None
+
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if position_bias is not None:
+ position_bias = position_bias.to(hidden_states.device)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
+ if encoder_extended_attention_mask is not None:
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
+ if encoder_decoder_position_bias is not None:
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+ if layer_head_mask is not None:
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
+ if cross_attn_layer_head_mask is not None:
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
+
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=False,
+ output_attentions=False,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[2]
+
+ # Encode the key
+ if compute_key and i == self.key_layer:
+ key_embeds = self._encode_key(hidden_states, attention_mask) # shape:[batch_size, d_key]
+ key_embeds = key_embeds.view(hidden_states.shape[0], -1, hidden_states.shape[-1])
+ if not compute_value and self.cat_layer is None:
+ break
+ if compute_value and i == self.cat_layer:
+ key_embeds_to_cat = self._encode_key(hidden_states, attention_mask)
+ key_embeds_to_cat = key_embeds_to_cat.view(hidden_states.shape[0], -1, hidden_states.shape[-1])
+
+ # Encode the value
+ if compute_value and i == self.value_layer:
+ value_embeds = hidden_states[:, :self.prefix_length]
+ break # (jimmycode): early stop to reduce cost
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ if value_embeds is not None:
+ normed_value_embeds = self.kv_output_layer(value_embeds)
+ if key_embeds is not None:
+ normed_key_embeds = self.kv_output_layer(key_embeds)
+
+ return {"key_embeds": key_embeds, "normed_key_embeds": normed_key_embeds,
+ "value_embeds": value_embeds, "normed_value_embeds": normed_value_embeds,
+ "key_embeds_to_cat": key_embeds_to_cat}
+
+ def kv_output_layer(self, embeds):
+ assert len(embeds.shape) == 3
+ embeds = self.key_layer_norm(embeds)
+ embeds = self.dropout(embeds)
+ return embeds
+
+ def _encode_key(self, hidden_states, attention_mask, prefix_embs=None):
+ assert hidden_states.ndim == 3 and attention_mask.ndim == 2
+
+ if self.key_encoder_type == "linear":
+ if prefix_embs is None:
+ prefix_embs = hidden_states[:, :self.prefix_length].view(hidden_states.shape[0], -1)
+ # shape: [batch_size, prefix_length * d_model]
+ key_embeds = self.key_encoder(prefix_embs) # shape: [batch_size, d_key]
+ elif self.key_encoder_type == "conv":
+ if prefix_embs is None:
+ key_embeds = self.key_encoder(hidden_states, attention_mask)
+ else:
+                prefix_mask = torch.ones(prefix_embs.shape[:2]).to(prefix_embs.device)
+ key_embeds = self.key_encoder(prefix_embs, prefix_mask)
+ elif self.key_encoder_type == "prefix":
+ prefix_embs = hidden_states[:, :self.prefix_length] # [batch_size, prefix-len, hidden_size]
+            key_embeds = prefix_embs.view(hidden_states.shape[0], -1)  # [batch_size, d_key]
+ else:
+ raise ValueError(f"Incorrect key_encoder_type: {self.key_encoder_type}")
+
+ return key_embeds
+
+ def forward_with_faiss(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ value_memory=None,
+ not_reduced_key_memory=None,
+ key_faiss_index=None,
+ key_reduce_method=None,
+ readout_top_k=1,
+ value_fusion_method=None,
+ ):
+ assert value_memory is not None
+ assert key_faiss_index is not None
+ assert not_reduced_key_memory is not None
+ assert value_fusion_method == self.value_fusion_method
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Sanity check
+ assert not use_cache, "This class does not support use_cache because it is encoder only"
+
+ if input_ids is not None and inputs_embeds is not None:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+
+ if inputs_embeds is None:
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+ # Concatenate the prefix embeddings and extend the attention masks
+ prefix_embeds = self.prefix_embedding[None, :, :].expand(batch_size, -1, -1).to(inputs_embeds.device)
+ inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1)
+
+ # Extend the attention masks
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length + self.prefix_length).to(inputs_embeds.device)
+ else:
+ prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to(
+ inputs_embeds.device)
+ attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
+
+ # initialize past_key_values with `None` if past does not exist
+ if past_key_values is None:
+ past_key_values = [None] * len(self.block)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, (batch_size, seq_length + self.prefix_length), inputs_embeds.device
+ )
+
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+ present_key_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+ query_embeds = None
+ readout_indices = None
+
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if position_bias is not None:
+ position_bias = position_bias.to(hidden_states.device)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
+ if encoder_extended_attention_mask is not None:
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
+ if encoder_decoder_position_bias is not None:
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+ if layer_head_mask is not None:
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
+ if cross_attn_layer_head_mask is not None:
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warn(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return tuple(module(*inputs, use_cache, output_attentions))
+
+ return custom_forward
+
+ layer_outputs = checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ extended_attention_mask,
+ position_bias,
+ encoder_hidden_states,
+ encoder_extended_attention_mask,
+ encoder_decoder_position_bias,
+ layer_head_mask,
+ cross_attn_layer_head_mask,
+ None, # past_key_value is always None with gradient checkpointing
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if not use_cache:
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[2]
+
+ # Query-key matching
+ if i == self.key_layer: # key-layer is Query-Layer
+ raw_query_embeds = self._encode_key(hidden_states, attention_mask) # shape:[batch_size, d_key]
+ raw_query_embeds = raw_query_embeds.view(hidden_states.shape[0], -1, hidden_states.shape[-1])
+
+ normed_query_embeds = self.kv_output_layer(raw_query_embeds) # Query is normed !!!
+ query_embeds = reduce_query_or_key_embeds(normed_query_embeds, key_reduce_method)
+
+ # serial mode
+ value_embeds, key_embeds_of_value, readout_indices = self.query_memory(
+ value_memory, not_reduced_key_memory, key_faiss_index, query_embeds, readout_top_k
+ )
+ value_embeds = value_embeds.to(query_embeds.device)
+ key_embeds_of_value = key_embeds_of_value.to(query_embeds.device)
+ value_embeds = value_embeds.to(query_embeds.dtype)
+ key_embeds_of_value = key_embeds_of_value.to(query_embeds.dtype)
+
+ if value_fusion_method == "cat_k_delay+v":
+                    # Serial mode: cat the key directly.
+ batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape
+ if key_nums != self.prefix_length:
+ assert key_nums == 1
+ key_embeds_of_value = key_embeds_of_value.repeat(1, 1, self.prefix_length, 1)
+ hidden_states = torch.cat(
+ [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size),
+ hidden_states], dim=1
+ )
+ extend_length = num_values * self.prefix_length
+ extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype)
+ attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1)
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, attention_mask.shape[:2], inputs_embeds.device
+ )
+ position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule
+
+ if self.cat_layer is not None and i == self.cat_layer:
+ if value_fusion_method == "async_cat_k_delay+v":
+                    # Async mode: cat the key at cat_layer; the delay+v handling is the same as in serial mode.
+ batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape
+ hidden_states = torch.cat(
+ [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size),
+ hidden_states], dim=1
+ )
+ extend_length = num_values * self.prefix_length
+ extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype)
+ attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1)
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, attention_mask.shape[:2], inputs_embeds.device
+ )
+ position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule
+
+ if i == self.value_layer:
+ batch_size, num_values, _, hidden_size = key_embeds_of_value.shape
+
+ # assert query_embeds is not None, "Use query_embeds to read memory before assignment."
+ if "delay" in value_fusion_method:
+ updated_key = hidden_states[:, :num_values * self.prefix_length]
+ updated_key = updated_key.view(batch_size, num_values, self.prefix_length, hidden_size)
+ else:
+ updated_key = None
+ integrated_value = self.get_integrated_values(value_embeds, key_embeds_of_value, value_fusion_method,
+ query_embeds=query_embeds, updated_key=updated_key,
+ key_reduce_method=key_reduce_method)
+ if value_fusion_method == "infill":
+ assert self.num_values == 1
+ hidden_states = torch.cat([integrated_value, hidden_states[:, self.prefix_length:]], dim=1)
+ elif "cat" in value_fusion_method and "delay" in value_fusion_method:
+ hidden_states[:, :num_values * self.prefix_length] = integrated_value
+ hidden_states = hidden_states.contiguous()
+ elif "cat" in value_fusion_method and "delay" not in value_fusion_method:
+ hidden_states = torch.cat([integrated_value, hidden_states], dim=1)
+ extend_length = integrated_value.shape[1]
+ extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype)
+ attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1)
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, attention_mask.shape[:2], inputs_embeds.device
+ )
+ position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule
+ else:
+ raise NotImplementedError(f"{value_fusion_method} is not defined.")
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ present_key_value_states,
+ all_hidden_states,
+ all_attentions,
+ all_cross_attentions,
+ query_embeds,
+ readout_indices
+ ]
+ if v is not None
+ )
+ return CATEncoderOutput(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_value_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ query_embeds=query_embeds,
+ readout_indices=readout_indices,
+ updated_attention_mask=None,
+ )
+
+ async def forward_with_async_faiss(
+ self,
+ input_ids,
+ attention_mask,
+ return_dict,
+ readout_top_k,
+ key_reduce_method,
+ value_fusion_method,
+ key_faiss_index,
+ value_memory,
+ not_reduced_key_memory,
+ encoder_hidden_states=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ ):
+ assert value_memory is not None
+ assert key_faiss_index is not None
+ assert not_reduced_key_memory is not None
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Sanity check
+ assert not use_cache, "This class does not support use_cache because it is encoder only"
+
+ if input_ids is not None and inputs_embeds is not None:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+
+ if inputs_embeds is None:
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+
+ # required mask seq length can be calculated via length of past
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+ # Concatenate the prefix embeddings and extend the attention masks
+ prefix_embeds = self.prefix_embedding[None, :, :].expand(batch_size, -1, -1).to(inputs_embeds.device)
+ inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1)
+
+ # Extend the attention masks
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length + self.prefix_length).to(inputs_embeds.device)
+ else:
+ prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to(
+ inputs_embeds.device)
+ attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
+
+ # initialize past_key_values with `None` if past does not exist
+ if past_key_values is None:
+ past_key_values = [None] * len(self.block)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, (batch_size, seq_length + self.prefix_length), inputs_embeds.device
+ )
+
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+ present_key_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+ position_bias = None
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+ query_embeds = None
+ readout_indices = None
+
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if position_bias is not None:
+ position_bias = position_bias.to(hidden_states.device)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
+ if encoder_extended_attention_mask is not None:
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
+ if encoder_decoder_position_bias is not None:
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+ if layer_head_mask is not None:
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
+ if cross_attn_layer_head_mask is not None:
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+                    logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return tuple(module(*inputs, use_cache, output_attentions))
+
+ return custom_forward
+
+ layer_outputs = checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ extended_attention_mask,
+ position_bias,
+ encoder_hidden_states,
+ encoder_extended_attention_mask,
+ encoder_decoder_position_bias,
+ layer_head_mask,
+ cross_attn_layer_head_mask,
+ None, # past_key_value is always None with gradient checkpointing
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask=extended_attention_mask,
+ position_bias=position_bias,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if not use_cache:
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+ hidden_states, present_key_value_state = layer_outputs[:2]
+
+            # We share the position biases between the layers - the first layer stores them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+ position_bias = layer_outputs[2]
+
+ # Query-key matching
+            if i == self.key_layer:  # the key layer also serves as the query layer
+ raw_query_embeds = self._encode_key(hidden_states, attention_mask) # shape:[batch_size, d_key]
+ raw_query_embeds = raw_query_embeds.view(hidden_states.shape[0], -1, hidden_states.shape[-1])
+
+ normed_query_embeds = self.kv_output_layer(raw_query_embeds) # Query is normed !!!
+ query_embeds = reduce_query_or_key_embeds(normed_query_embeds, key_reduce_method)
+
+                # Async mode: launch the memory read-out in a background thread.
+ async_query_future = asyncio.get_event_loop().run_in_executor(
+ futures.ThreadPoolExecutor(), self.query_memory,
+ value_memory, not_reduced_key_memory, key_faiss_index, query_embeds, readout_top_k
+ )
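+                # The future is awaited once the encoder reaches cat_layer, so the faiss look-up
+                # overlaps with the intermediate encoder layers.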
+
+ if i == self.cat_layer:
+ value_embeds, key_embeds_of_value, readout_indices = await async_query_future
+ value_embeds = value_embeds.to(query_embeds.device)
+ key_embeds_of_value = key_embeds_of_value.to(query_embeds.device)
+ value_embeds = value_embeds.to(query_embeds.dtype)
+ key_embeds_of_value = key_embeds_of_value.to(query_embeds.dtype)
+
+ if value_fusion_method == "async_cat_k_delay+v":
+                    # Async mode: concatenate the retrieved key embeddings at cat_layer; the delayed value update works as in serial mode.
+ batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape
+ hidden_states = torch.cat(
+ [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size),
+ hidden_states], dim=1
+ )
+ extend_length = num_values * self.prefix_length
+ extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype)
+ attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1)
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, attention_mask.shape[:2], inputs_embeds.device
+ )
+ position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule
+
+ if i == self.value_layer:
+ batch_size, num_values, _, hidden_size = key_embeds_of_value.shape
+
+ # assert query_embeds is not None, "Use query_embeds to read memory before assignment."
+ if "delay" in value_fusion_method:
+ updated_key = hidden_states[:, :num_values * self.prefix_length]
+ updated_key = updated_key.view(batch_size, num_values, self.prefix_length, hidden_size)
+ else:
+ updated_key = None
+ integrated_value = self.get_integrated_values(value_embeds, key_embeds_of_value, value_fusion_method,
+ query_embeds=query_embeds, updated_key=updated_key,
+ key_reduce_method=key_reduce_method)
+ if value_fusion_method == "infill":
+ assert self.num_values == 1
+ hidden_states = torch.cat([integrated_value, hidden_states[:, self.prefix_length:]], dim=1)
+ elif "cat" in value_fusion_method and "delay" in value_fusion_method:
+ hidden_states[:, :num_values * self.prefix_length] = integrated_value
+ hidden_states = hidden_states.contiguous()
+ elif "cat" in value_fusion_method and "delay" not in value_fusion_method:
+ hidden_states = torch.cat([integrated_value, hidden_states], dim=1)
+ extend_length = integrated_value.shape[1]
+ extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype)
+ attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1)
+ extended_attention_mask = self.get_extended_attention_mask(
+ attention_mask, attention_mask.shape[:2], inputs_embeds.device
+ )
+ position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule
+ else:
+ raise NotImplementedError(f"{value_fusion_method} is not defined.")
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ present_key_value_states,
+ all_hidden_states,
+ all_attentions,
+ all_cross_attentions,
+ query_embeds,
+ readout_indices
+ ]
+ if v is not None
+ )
+ return CATEncoderOutput(
+ last_hidden_state=hidden_states,
+ past_key_values=present_key_value_states,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ cross_attentions=all_cross_attentions,
+ query_embeds=query_embeds,
+ readout_indices=readout_indices,
+ updated_attention_mask=None,
+ )
+
+ def query_memory(self, value_memory, not_reduced_key_memory, key_faiss_index,
+ query_embeds, readout_top_k):
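+        # MIPS over the key faiss index, then gather the top-k value embeddings and their
+        # un-reduced key embeddings from the in-memory tensors.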
+ assert value_memory.shape[1] == self.prefix_length
+ assert value_memory.shape[2] == self.model_dim
+        if isinstance(query_embeds, torch.Tensor):
+            # assumption: the original (truncated) line intended to detach before the CPU copy below
+            query_embeds = query_embeds.detach()
+ top_indices, _ = mips(key_faiss_index, query_embeds.cpu(), readout_top_k, n_queries_to_parallelize=20480)
+ memory_size, hidden_num, hidden_size = value_memory.shape
+ bs = query_embeds.shape[0]
+ top_indices = torch.tensor(top_indices)
+ readout_value = torch.index_select(value_memory, 0, top_indices.view(-1))
+ readout_value = readout_value.view(bs, readout_top_k, hidden_num, hidden_size)
+ readout_key_embeds_of_value = torch.index_select(not_reduced_key_memory, 0, top_indices.view(-1))
+ readout_key_embeds_of_value = readout_key_embeds_of_value.view(bs, readout_top_k, hidden_num, hidden_size)
+ return readout_value, readout_key_embeds_of_value, top_indices
+
+
+@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
+class T5WithKeyValueMemory(T5ForConditionalGeneration):
+ _keys_to_ignore_on_load_missing = [
+ r"encoder\.embed_tokens\.weight",
+ r"decoder\.embed_tokens\.weight",
+ r"lm_head\.weight",
+ ]
+
+ # _keys_to_ignore_on_load_unexpected = [
+ # r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+ # ]
+
+ def __init__(self, config):
+ super(T5ForConditionalGeneration, self).__init__(config)
+ self.model_dim = config.d_model
+ # self.generate
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+ encoder_config = copy.deepcopy(config)
+ encoder_config.is_decoder = False
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = T5StackWithKeyValueMemory(encoder_config, self.shared)
+ if config.not_share_encoder:
+ self.kv_encoder = T5StackWithKeyValueMemory(copy.deepcopy(encoder_config), self.shared)
+ else:
+ self.kv_encoder = None
+ decoder_config = copy.deepcopy(config)
+ decoder_config.is_decoder = True
+ decoder_config.is_encoder_decoder = False
+ decoder_config.num_layers = config.num_decoder_layers
+ self.decoder = T5Stack(decoder_config, self.shared)
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+ self.adapter = getattr(config, "adapter", None)
+ if self.adapter is not None:
+ self.adapter = RetAdapter(config.d_model, config.adapter_out_dim, adapter_type=self.adapter)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+ self.config = config
+
+ def get_tunable_key_value_parameters(self):
+ add_pos_layer = min(self.config.key_layer, self.config.value_layer)
+ tunable_parameters_list = list(self.encoder.block[add_pos_layer].layer[0].
+ SelfAttention.relative_attention_bias.parameters())
+ if not self.config.not_share_encoder:
+ if self.encoder.key_encoder is not None:
+ tunable_parameters_list += list(self.encoder.key_encoder.parameters())
+ tunable_parameters_list += [self.encoder.prefix_embedding]
+ tunable_parameters_list += list(self.encoder.key_layer_norm.parameters())
+        else:
+ tunable_parameters_list += list(self.kv_encoder.parameters())
+ return tunable_parameters_list
+
+ def freeze_t5_params(self):
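+        # Freeze every parameter except the tunable key-value memory parameters.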
+ tunable_key_value_parameters = self.get_tunable_key_value_parameters()
+ requires_grad_nums = 0
+ for param in self.parameters():
+ if any(param is tp for tp in tunable_key_value_parameters):
+ param.requires_grad = True
+ requires_grad_nums += 1
+ else:
+ param.requires_grad = False
+ assert requires_grad_nums == len(tunable_key_value_parameters)
+ logger.info(f"tunable params num: {len(tunable_key_value_parameters)}")
+
+ def freeze_kv_encoder_params(self):
+ assert self.encoder.key_encoder is not None
+ kv_encoder_params = list(self.encoder.key_encoder.parameters())
+ for param in self.parameters():
+ if any(param is tp for tp in kv_encoder_params):
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CATSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ decoder_input_ids=None,
+ decoder_attention_mask=None,
+ key_value_input_ids=None,
+ key_value_attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ encoder_outputs=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ decoder_inputs_embeds=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ decoder_only_attend_on_prefix=False,
+ encoder_outputs_are_key_or_value=False,
+ key_embeds=None,
+ value_embeds=None,
+ key_memory=None,
+ value_memory=None,
+ key_faiss_index=None,
+ key_reduce_method=None,
+ value_fusion_method=None,
+ key_embeds_of_value=None,
+ use_ae_lm_head=False,
+ ):
+ r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in :obj:`[-100, 0, ...,
+            config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked); the loss is only computed for
+            labels in ``[0, ..., config.vocab_size - 1]``.
+
+ Returns:
+
+ Examples::
+
+ >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+ >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
+ >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
+
+ >>> # training
+            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
+            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
+ >>> outputs = model(input_ids=input_ids, labels=labels)
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+
+ >>> # inference
+ >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
+ >>> outputs = model.generate(input_ids)
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ >>> # studies have shown that owning a dog is good for you.
+ """
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+ if head_mask is not None and decoder_head_mask is None:
+ if self.config.num_layers == self.config.num_decoder_layers:
+ warnings.warn(HEAD_MASK_WARNING_MSG, FutureWarning)
+ decoder_head_mask = head_mask
+
+ # Encode if needed (training, first prediction pass)
+ if encoder_outputs is None:
+ # Convert encoder inputs in embeddings if needed
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ key_embeds=key_embeds,
+ value_embeds=value_embeds,
+ key_memory=key_memory,
+ value_memory=value_memory,
+ key_faiss_index=key_faiss_index,
+ key_reduce_method=key_reduce_method,
+ value_fusion_method=value_fusion_method,
+ key_embeds_of_value=key_embeds_of_value
+ )
+ elif return_dict:
+ assert isinstance(encoder_outputs, CATEncoderOutput)
+
+ hidden_states = encoder_outputs.last_hidden_state
+
+ # Extend attention_mask on the left for attention for the prefix
+ batch_size, seq_length = hidden_states.shape[:2]
+ # seq_length: prefix + original hidden length
+ # attn_length: original hidden length
+ if encoder_outputs_are_key_or_value:
+ encoder_attention_mask = None
+ else:
+ attn_length = attention_mask.shape[1]
+ assert seq_length > attn_length, f"{seq_length} is not larger than {attn_length}"
+ if decoder_only_attend_on_prefix:
+ hidden_states = hidden_states[:, :seq_length - attn_length]
+ encoder_attention_mask = torch.ones(hidden_states.shape[:2]).to(attention_mask.device)
+ else:
+ prefix_mask = torch.ones(attention_mask.shape[0], seq_length - attn_length).to(attention_mask.device)
+ encoder_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
+
+ if self.model_parallel:
+ torch.cuda.set_device(self.decoder.first_device)
+
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+ # get decoder inputs from shifting lm labels to the right
+ decoder_input_ids = self._shift_right(labels)
+
+ # Set device for model parallelism
+ if self.model_parallel:
+ torch.cuda.set_device(self.decoder.first_device)
+ hidden_states = hidden_states.to(self.decoder.first_device)
+ if decoder_input_ids is not None:
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = encoder_attention_mask.to(self.decoder.first_device)
+ if decoder_attention_mask is not None:
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ past_key_values=past_key_values,
+ encoder_hidden_states=hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = decoder_outputs[0]
+
+ # Set device for model parallelism
+ if self.model_parallel:
+ torch.cuda.set_device(self.encoder.first_device)
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ sequence_output = sequence_output * (self.model_dim ** -0.5)
+
+ if use_ae_lm_head:
+ lm_logits = self.ae_lm_head(sequence_output)
+ else:
+ lm_logits = self.lm_head(sequence_output)
+
+ # loss_dict = {}
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+ return ((loss,) + output) if loss is not None else output
+
+ assert isinstance(encoder_outputs, CATEncoderOutput)
+ return CATSeq2SeqLMOutput(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ cat_encoder_outputs=encoder_outputs
+ )
+
+ def _prepare_encoder_decoder_kwargs_for_generation(
+ self, input_ids: torch.LongTensor, model_kwargs
+ ) -> Dict[str, Any]:
+ if "encoder_outputs" not in model_kwargs:
+ # retrieve encoder hidden states
+ encoder = self.get_encoder()
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
+ }
+
+ if model_kwargs.get("key_value_input_ids", None) is not None:
+ key_embeds, value_embeds = encoder.embed_kv(
+ input_ids=model_kwargs["key_value_input_ids"],
+ attention_mask=model_kwargs.get("key_value_attention_mask", None),
+ head_mask=model_kwargs.get("head_mask", None),
+ )
+ encoder_kwargs["key_embeds"] = key_embeds
+ encoder_kwargs["value_embeds"] = value_embeds
+ encoder_kwargs.pop("key_value_input_ids")
+ if "key_value_attention_mask" in encoder_kwargs:
+ encoder_kwargs.pop("key_value_attention_mask")
+
+ model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
+ return model_kwargs
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=False,
+ value_fusion_method=None,
+ # use_ae_decoder=False,
+ **kwargs
+ ):
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+ return {
+ "decoder_input_ids": input_ids,
+ "past_key_values": past,
+ "encoder_outputs": encoder_outputs,
+ "encoder_outputs_are_key_or_value": encoder_outputs_are_key_or_value,
+ "attention_mask": attention_mask,
+ "decoder_only_attend_on_prefix": decoder_only_attend_on_prefix,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
+ "value_fusion_method": value_fusion_method,
+ # "use_ae_decoder": use_ae_decoder
+ }
+
+ def CAT_embed_kv(self, *args, **kwargs) -> Dict:
+ if self.kv_encoder is None:
+ # share kv-encoder with query-encoder
+ return self.encoder.embed_kv(*args, **kwargs)
+ else:
+ # Do not share encoder
+ return self.kv_encoder.embed_kv(*args, **kwargs)
+
+ def CAT_embed_q(self, *args, **kwargs) -> Dict:
+ return self.encoder.embed_kv(*args, **kwargs) # self.encoder is always query-encoder
+
+ def compute_key_value_ae_loss(
+ self,
+ separate_task=True,
+ key_value_input_ids=None,
+ key_value_attention_mask=None,
+ key_input_ids=None,
+ key_attention_mask=None,
+ value_input_ids=None,
+ value_attention_mask=None,
+ key_labels_input_ids=None,
+ value_labels_input_ids=None,
+ train_key=False,
+ train_value=False,
+ use_ae_lm_head=False,
+ **kwargs
+ ):
+ loss_dict = dict()
+ embed_dict = self.wrapped_embed_kv(
+ separate_task=separate_task,
+ key_value_input_ids=key_value_input_ids,
+ key_value_attention_mask=key_value_attention_mask,
+ key_input_ids=key_input_ids,
+ key_attention_mask=key_attention_mask,
+ value_input_ids=value_input_ids,
+ value_attention_mask=value_attention_mask,
+ compute_key=train_key,
+ compute_value=train_value,
+ embed_for_ae_task=True
+ )
+ key_embeds = embed_dict["normed_key_embeds"] if train_key else None
+ if "async_cat_k+v" == self.config.value_fusion_method:
+ value_embeds = self.encoder.kv_output_layer(embed_dict["value_embeds"] + embed_dict["key_embeds"]) \
+ if train_value else None # normed value for generation
+ else:
+ value_embeds = embed_dict["normed_value_embeds"] if train_value else None # normed value for generation
+ # key_embeds = key_embeds.view(key_embeds.shape[0], -1, self.model_dim) if key_embeds is not None else None
+ # value_embeds = value_embeds.view(value_embeds.shape[0], -1, self.model_dim)
+ # key_embeds/value_embeds [batch_size, prefix_length, model_dim]
+ # the length of key_embeds/value_embeds is prefix_length, do not need attention_mask
+ if train_key:
+ key_ae_outputs = self.forward(
+ # batch, num, hidden_size; 1024 -> 2, 512
+ encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None),
+ encoder_outputs_are_key_or_value=True,
+ labels=key_labels_input_ids,
+ use_ae_lm_head=use_ae_lm_head,
+ )
+ loss_dict["key_ae_loss"] = key_ae_outputs["loss"]
+ if train_value:
+ value_ae_outputs = self.forward(
+ encoder_outputs=CATEncoderOutput(last_hidden_state=value_embeds, hidden_states=None, attentions=None),
+ encoder_outputs_are_key_or_value=True,
+ labels=value_labels_input_ids,
+ use_ae_lm_head=use_ae_lm_head,
+ )
+ loss_dict["value_ae_loss"] = value_ae_outputs["loss"]
+ return loss_dict
+
+ def compute_text_pair_key_value_ae_loss(
+ self,
+ key_input_ids=None,
+ key_attention_mask=None,
+ value_input_ids=None,
+ value_attention_mask=None,
+ key_labels_input_ids=None,
+ value_labels_input_ids=None,
+ separate_decode=False,
+ hypothesis_decoder_input_ids=None,
+ hypothesis_decoder_labels=None,
+ premise_decoder_input_ids=None,
+ premise_decoder_labels=None,
+ train_key=False,
+ train_value=False,
+ **kwargs
+ ):
+ # Auto-encoding pretraining for text-pair task (e.g. NLI)
+        # The separate_decode argument controls whether the hypothesis and premise are decoded separately.
+ loss_dict = dict()
+ embed_dict = self.wrapped_embed_kv(
+ separate_task=True,
+ key_value_input_ids=None,
+ key_value_attention_mask=None,
+ key_input_ids=key_input_ids,
+ key_attention_mask=key_attention_mask,
+ value_input_ids=value_input_ids,
+ value_attention_mask=value_attention_mask,
+ compute_key=train_key,
+ compute_value=train_value,
+ embed_for_ae_task=True
+ )
+ key_embeds = embed_dict["normed_key_embeds"] if train_key else None
+ value_embeds = embed_dict["normed_value_embeds"] if train_value else None
+ if train_key:
+ if separate_decode:
+ # hypothesis auto-encoding
+ hypothesis_ae_outputs = self.forward(
+ encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None),
+ encoder_outputs_are_key_or_value=True,
+ decoder_input_ids=hypothesis_decoder_input_ids,
+ labels=hypothesis_decoder_labels,
+ use_ae_lm_head=False,
+ )
+ loss_dict["hypothesis_ae_loss"] = hypothesis_ae_outputs["loss"]
+ # premise auto-encoding
+ premise_ae_outputs = self.forward(
+ encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None),
+ encoder_outputs_are_key_or_value=True,
+ decoder_input_ids=premise_decoder_input_ids,
+ labels=premise_decoder_labels,
+ use_ae_lm_head=False,
+ )
+ loss_dict["premise_ae_loss"] = premise_ae_outputs["loss"]
+ else:
+ key_ae_outputs = self.forward(
+ # batch, num, hidden_size; 1024 -> 2, 512
+ encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None),
+ encoder_outputs_are_key_or_value=True,
+ labels=key_labels_input_ids,
+ use_ae_lm_head=False,
+ )
+ loss_dict["key_ae_loss"] = key_ae_outputs["loss"]
+ if train_value:
+ value_ae_outputs = self.forward(
+ encoder_outputs=CATEncoderOutput(last_hidden_state=value_embeds, hidden_states=None, attentions=None),
+ encoder_outputs_are_key_or_value=True,
+ labels=value_labels_input_ids,
+ use_ae_lm_head=False,
+ )
+ loss_dict["value_ae_loss"] = value_ae_outputs["loss"]
+ return loss_dict
+
+ def compute_qa_loss(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ labels=None,
+ decoder_only_attend_on_prefix=False,
+ encoder_outputs_are_key_or_value=False,
+ key_memory=None,
+ value_memory=None,
+ key_faiss_index=None,
+ key_reduce_method=None,
+ positive_key_embeds=None,
+ negative_key_embeds=None,
+ value_embeds=None,
+ matching_targets=None,
+ value_fusion_method=None,
+ key_embeds_of_value=None,
+ negative_mask=None,
+ only_train_adapter=False
+ ):
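+        # Combines the answer-generation loss with an optional query-key matching loss computed
+        # against positive and negative key embeddings.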
+ loss_dict = dict()
+ if only_train_adapter:
+ embed_dict = self.CAT_embed_q(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ compute_key=True, compute_value=False
+ )
+ query_embeds = embed_dict["normed_key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, key_reduce_method)
+ gen_loss = torch.tensor(0.0)
+ else:
+ outputs = self.forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ labels=labels,
+ decoder_only_attend_on_prefix=decoder_only_attend_on_prefix,
+ encoder_outputs_are_key_or_value=encoder_outputs_are_key_or_value,
+ key_embeds=None,
+ value_embeds=value_embeds,
+ key_memory=key_memory,
+ value_memory=value_memory,
+ key_faiss_index=key_faiss_index,
+ key_reduce_method=key_reduce_method,
+ value_fusion_method=value_fusion_method,
+ key_embeds_of_value=key_embeds_of_value
+ )
+ query_embeds = outputs.cat_encoder_outputs.query_embeds
+ gen_loss = outputs.loss
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss_dict["gen_loss"] = gen_loss
+ if positive_key_embeds is not None:
+ if self.adapter is not None:
+ query_embeds = self.adapter(query_embeds)
+ positive_key_embeds = self.adapter(positive_key_embeds)
+ negative_key_embeds = self.adapter(negative_key_embeds)
+ scores1 = torch.mm(query_embeds, positive_key_embeds.t())
+ scores2 = torch.mm(query_embeds, negative_key_embeds.t())
+ if negative_mask is not None:
+ negative_mask = ~negative_mask.bool().to(negative_key_embeds.device)
+ scores2 = scores2.masked_fill(negative_mask, float('-inf'))
+ scores = torch.cat((scores1, scores2), dim=1)
+ match_loss = loss_fct(scores, matching_targets)
+ loss_dict["match_loss"] = match_loss
+
+ return loss_dict
+
+ def compute_gen_and_match_loss(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ labels=None,
+ decoder_only_attend_on_prefix=False,
+ encoder_outputs_are_key_or_value=False,
+ key_reduce_method=None,
+ value_embeds=None,
+ value_fusion_method=None,
+ key_embeds_of_value=None,
+ positive_and_negative_embeds=None,
+ matching_mask=None,
+ matching_targets=None,
+ use_triple_loss=None,
+ ):
+ """
+ value_embeds: retrieved key's value embeds, shape: [batch_size, num_values, prefix_length, hidden_size]
+ key_embeds_of_value: retrieved key embeds, shape: [batch_size, num_values, key_dim // model_dim, hidden_size]
+ """
+ loss_dict = dict()
+ outputs = self.forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ labels=labels,
+ decoder_only_attend_on_prefix=decoder_only_attend_on_prefix,
+ encoder_outputs_are_key_or_value=encoder_outputs_are_key_or_value,
+ key_embeds=None,
+ value_embeds=value_embeds,
+ key_memory=None,
+ value_memory=None,
+ key_faiss_index=None,
+ key_reduce_method=key_reduce_method,
+ value_fusion_method=value_fusion_method,
+ key_embeds_of_value=key_embeds_of_value
+ )
+
+ gen_loss = outputs.loss
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
+ loss_dict["gen_loss"] = gen_loss
+ if not use_triple_loss and positive_and_negative_embeds is not None:
+ query_embeds = outputs.cat_encoder_outputs.query_embeds
+ # query_embeds: batch_size, hidden
+ # positive_and_negative_embeds: N, hidden
+ scores = torch.mm(query_embeds, positive_and_negative_embeds.transpose(1, 0))
+ matching_mask = ~matching_mask.bool().to(positive_and_negative_embeds.device)
+ scores = scores.masked_fill(matching_mask, float('-inf'))
+ match_loss = loss_fct(scores, matching_targets)
+ loss_dict["match_loss"] = match_loss
+ # elif use_triple_loss and positive_key_embeds is not None:
+ # triple_loss_fct = nn.TripletMarginLoss(margin=0.5, p=2)
+ # query_embeds = outputs.cat_encoder_outputs.query_embeds
+ # batch_size, hidden_size = query_embeds.shape
+ # # negative_nums = negative_key_embeds.shape[0] // batch_size
+ # negative_key_embeds = negative_key_embeds.view(batch_size, -1, hidden_size)
+ # group_negative_key_embeds = negative_key_embeds.transpose(0, 1)
+ # triple_losses = []
+ # for cur_negative_key_embeds in group_negative_key_embeds:
+ # triple_losses.append(
+ # triple_loss_fct(query_embeds, positive_key_embeds, cur_negative_key_embeds)
+ # )
+ # triple_loss = sum(triple_losses) / len(triple_losses)
+ # loss_dict["triple_loss"] = torch.nan_to_num(triple_loss)
+
+ return loss_dict
+
+ def wrapped_embed_kv(
+ self,
+ separate_task=False,
+ key_value_input_ids=None,
+ key_value_attention_mask=None,
+ key_input_ids=None,
+ key_attention_mask=None,
+ value_input_ids=None,
+ value_attention_mask=None,
+ compute_key=False,
+ compute_value=False,
+ embed_for_ae_task=False,
+ ) -> Dict:
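+        # With separate_task=True, key and value inputs are embedded in independent passes;
+        # otherwise a single joint pass over key_value_input_ids embeds both.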
+ device = self.device
+ # key_embeds, value_embeds = None, None
+ if separate_task:
+ res = dict()
+ if compute_key:
+ key_res = self.CAT_embed_kv(
+ input_ids=key_input_ids.to(device), attention_mask=key_attention_mask.to(device),
+ compute_key=compute_key, compute_value=False, embed_for_ae_task=embed_for_ae_task
+ )
+ res.update({k: v for k, v in key_res.items() if "key" in k})
+ if compute_value:
+ value_res = self.CAT_embed_kv(
+ input_ids=value_input_ids.to(device), attention_mask=value_attention_mask.to(device),
+ compute_key=False, compute_value=compute_value, embed_for_ae_task=embed_for_ae_task,
+ )
+ res.update({k: v for k, v in value_res.items() if "value" in k})
+ else:
+ res = self.CAT_embed_kv(
+ input_ids=key_value_input_ids.to(device),
+ attention_mask=key_value_attention_mask.to(device),
+ compute_key=compute_key, compute_value=compute_value,
+ embed_for_ae_task=embed_for_ae_task
+ )
+ return res
diff --git a/emat/utils.py b/emat/utils.py
new file mode 100644
index 0000000..858b24a
--- /dev/null
+++ b/emat/utils.py
@@ -0,0 +1,68 @@
+import json
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+try:
+ import apex
+ from apex import amp
+
+ apex.amp.register_half_function(torch, "einsum")
+ _has_apex = True
+except ImportError:
+ _has_apex = False
+
+
+def is_apex_available():
+ return _has_apex
+
+
+def to_fp16(model):
+ if is_apex_available():
+ model = amp.initialize(model, opt_level="O1")
+ else:
+ model = model.half()
+ return model
+
+
+def load_jsonl(fn):
+ all_data = []
+ with open(fn, "r") as f:
+ for line in f.readlines():
+ all_data.append(json.loads(line))
+ return all_data
+
+
+def write_jsonl(all_data, fn):
+ with open(fn, "w") as f:
+ for data in all_data:
+ f.write(json.dumps(data) + "\n")
+
+
+def convert_repaq_results_from_file(in_file, num_candidates, out_file=None):
+ data = load_jsonl(in_file)
+ processed_data = []
+
+ for sample in data:
+ sample_dict = {
+ "question": sample["input_qa"]["question"],
+ "answer": sample["input_qa"]["answer"],
+ }
+ for i, qas in enumerate(sample["retrieved_qas"][:num_candidates]):
+ q, a = qas["question"], qas["answer"][0]
+ sample_dict[f"ret_q_{i}"] = q
+ sample_dict[f"ret_a_{i}"] = a
+
+ processed_data.append(sample_dict)
+
+ if out_file is not None:
+ write_jsonl(processed_data, out_file)
+
+ return processed_data
+
+
+def verbalise_qa(q, a) -> str:
+ q = q.strip("?")
+ return f'{q}? answer: {a}'
diff --git a/embed_and_build_index.py b/embed_and_build_index.py
new file mode 100644
index 0000000..1b379ea
--- /dev/null
+++ b/embed_and_build_index.py
@@ -0,0 +1,187 @@
+import os
+import pickle
+import argparse
+import torch
+from transformers import T5Tokenizer
+import copy
+from emat.t5 import T5WithKeyValueMemory
+from emat.utils import load_jsonl
+import logging
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from utils.utils import get_key_value_encoder_inputs, reduce_query_or_key_embeds
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+
+QA_KB_PATHS = {
+ "PAQ_L1": "./tmp/PAQ_L1_pickl_file.pkl",
+ "PAQ": "./tmp/PAQ_full.pkl",
+ "TAQ_TRAIN_NQ_TRAIN_PAQ": "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl",
+ "debug": "./tmp/PAQ_L1_small.pkl"
+}
+
+
+def get_args():
+ parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Embed and build FAISS")
+ parser.add_argument("--model_name_or_path", type=str, required=False,
+ default="./outputs/nq_checkpoints/KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;pos_from_top=128;/best_ckpt/")
+ parser.add_argument("--qas_to_retrieve_from", choices=list(QA_KB_PATHS.keys()), default=f"debug")
+ parser.add_argument("--add_nq_train", action="store_true")
+ parser.add_argument("--add_nq_dev", action="store_true")
+ parser.add_argument("--embed_batch_size", type=int, default=512)
+ parser.add_argument("--save_dir", default=f"./data/embedding_and_faiss/debug_from_nq_ckpt")
+
+ args = parser.parse_args()
+ return args
+
+
+def load_qas_to_embed(qas_to_retrieve_from, add_nq_train, add_nq_dev):
+ logging.info("loading qas to retrieve")
+ qas_to_retrieve_fp = QA_KB_PATHS[qas_to_retrieve_from]
+ logging.info(f"loading qas from {qas_to_retrieve_fp}")
+ if qas_to_retrieve_fp.endswith("pkl"):
+ qas_to_embed = pickle.load(open(qas_to_retrieve_fp, 'rb'))
+ elif qas_to_retrieve_fp.endswith("jsonl"):
+ qas_to_embed = load_jsonl(qas_to_retrieve_fp)
+ else:
+ raise ValueError(f"{qas_to_retrieve_fp}")
+ logging.info(f"load {len(qas_to_embed)} qas from PAQ.")
+
+ # if qas_to_retrieve_from == "debug":
+ # qas_to_retrieve = qas_to_retrieve[:10000]
+
+ if add_nq_train:
+ logging.info("add nq-train qas.")
+ qas_to_embed = qas_to_embed + load_jsonl("./data/annotated_datasets/NQ-open.train-train.jsonl")
+ if add_nq_dev:
+ logging.info("add nq-dev qas.")
+ qas_to_embed = qas_to_embed + load_jsonl("./data/annotated_datasets/NQ-open.train-dev.jsonl")
+
+ logging.info(f"load {len(qas_to_embed)} qas totally.")
+
+ return qas_to_embed
+
+
+@torch.no_grad()
+def embed_key_value(model, tokenizer, data_to_embed, embed_batch_size, save_dir,
+ use_fp16_model=True, key_reduce_method="avg", max_source_length=1024, prefix="question: "):
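+    # Three passes over the data are made below, each saving its tensor to `save_dir` as fp16:
+    # (1) reduced + normalised key embeddings (the faiss search index), (2) value embeddings,
+    # (3) un-reduced, un-normalised key embeddings.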
+ if use_fp16_model:
+ model = model.half()
+        logging.info("converted model to fp16.")
+
+ # model.eval()
+ # key_memory, value_memory, not_reduced_key_memory = build_memory(
+ # model, tokenizer, embed_key=True, embed_value=True, prefix=prefix, embed_as_fp16=True,
+ # key_reduce_method=key_reduce_method, return_memory=True, dump_memory=False,
+ # data_to_embed=data_to_embed, max_source_length=max_source_length, padding=True,
+ # batch_size=embed_batch_size, separate_task=True, return_not_reduced_key=True,
+ #
+ # reused_key_memory=reused_key_memory,
+ # reused_value_memory=reused_value_memory,
+ # reused_not_reduced_key_memory=reused_not_reduced_key_memory
+ # )
+ # return key_memory, value_memory, not_reduced_key_memory
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ else:
+        logging.warning(f"{save_dir} already exists; existing contents may be overwritten.")
+
+ def collate_fn(examples):
+ model_inputs = get_key_value_encoder_inputs(examples, True, tokenizer, max_source_length,
+ prefix=prefix, only_return_key_inputs=False)
+ return model_inputs
+
+ data_to_embed_dataloader = DataLoader(data_to_embed, batch_size=embed_batch_size,
+ num_workers=4, collate_fn=collate_fn)
+ import gc
+
+ def save_embedding_index():
+ reused_key_memory = torch.zeros((len(data_to_embed), model.model_dim), device="cpu", dtype=torch.float16)
+ key_cnt = 0
+ for batch in tqdm(data_to_embed_dataloader):
+ batch_keys = list(batch.keys())
+ batch = {k: v.to(model.device) for k, v in batch.items()}
+ embed_dict = model.wrapped_embed_kv(separate_task=True, **batch,
+ compute_key=True, compute_value=False)
+ for k in batch_keys:
+ del batch[k]
+ key_embeds = embed_dict.get("normed_key_embeds")
+ key_embeds = reduce_query_or_key_embeds(key_embeds, key_reduce_method)
+ cur_key_num = key_embeds.shape[0]
+ key_embeds = key_embeds.half().cpu()
+ reused_key_memory[key_cnt: key_cnt + cur_key_num] = copy.deepcopy(key_embeds)
+ del key_embeds
+ torch.save(reused_key_memory, os.path.join(save_dir, "embedding_index.pt"))
+ logging.info("embedding index saved.")
+
+ def save_value_memory():
+ reused_value_memory = torch.zeros((len(data_to_embed), 2, model.model_dim), device="cpu", dtype=torch.float16)
+ value_cnt = 0
+ for batch in tqdm(data_to_embed_dataloader):
+ batch_keys = list(batch.keys())
+ batch = {k: v.to(model.device) for k, v in batch.items()}
+ embed_dict = model.wrapped_embed_kv(separate_task=True, **batch,
+ compute_key=False, compute_value=True)
+ for k in batch_keys:
+ del batch[k]
+ value_embeds = embed_dict.get("value_embeds")
+ cur_value_num = value_embeds.shape[0]
+ value_embeds = value_embeds.half().cpu()
+ reused_value_memory[value_cnt: value_cnt + cur_value_num] = copy.deepcopy(value_embeds)
+ del value_embeds
+ torch.save(reused_value_memory, os.path.join(save_dir, "value_memory.pt"))
+ logging.info("value memory saved.")
+
+ def save_key_memory():
+ reused_not_reduced_key_memory = torch.zeros((len(data_to_embed), 2, model.model_dim),
+ device="cpu", dtype=torch.float16)
+ nr_key_cnt = 0
+ for batch in tqdm(data_to_embed_dataloader):
+ batch_keys = list(batch.keys())
+ batch = {k: v.to(model.device) for k, v in batch.items()}
+ embed_dict = model.wrapped_embed_kv(separate_task=True, **batch,
+ compute_key=True, compute_value=False)
+ for k in batch_keys:
+ del batch[k]
+ not_normed_key_embeds = embed_dict["key_embeds"]
+ cur_key_num = not_normed_key_embeds.shape[0]
+ not_normed_key_embeds = not_normed_key_embeds.half().cpu()
+ reused_not_reduced_key_memory[nr_key_cnt: nr_key_cnt + cur_key_num] = copy.deepcopy(not_normed_key_embeds)
+ del not_normed_key_embeds
+ torch.save(reused_not_reduced_key_memory, os.path.join(save_dir, "key_memory.pt"))
+ logging.info("key memory saved.")
+
+ save_embedding_index()
+ gc.collect()
+ save_value_memory()
+ gc.collect()
+ save_key_memory()
+ gc.collect()
+
+
+def main():
+ args = get_args()
+ logging.info("loading model")
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)
+ model, load_info = T5WithKeyValueMemory.from_pretrained(args.model_name_or_path, output_loading_info=True)
+ model = model.cuda()
+ model.eval()
+ logging.info(f"model load info: {load_info}")
+
+ logging.info("loading data")
+ data_to_embed = load_qas_to_embed(args.qas_to_retrieve_from, args.add_nq_train, args.add_nq_dev)
+
+ logging.info("embedding")
+ embed_key_value(model, tokenizer, data_to_embed, args.embed_batch_size, args.save_dir)
+ # key_memory is normed and reduced
+ # value_memory is not normed
+ # not_reduced_key_memory is not normed and not reduced
+ logging.info("embedding saved.")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/embed_scripts/nq_embed_paq_and_build_faiss.sh b/embed_scripts/nq_embed_paq_and_build_faiss.sh
new file mode 100644
index 0000000..3745aa5
--- /dev/null
+++ b/embed_scripts/nq_embed_paq_and_build_faiss.sh
@@ -0,0 +1,32 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+DST_DIR="case-augmented-transformer-master"
+cd ${DST_DIR}
+
+DEVICE="0"
+
+PAQ_TYPE="PAQ_L1"
+SAVE_DIR="./data/embedding_and_faiss/${PAQ_TYPE}_from_nq_ckpt"
+NQ_MODEL_PATH="" # path to the trained NQ checkpoint (must be filled in before running)
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python embed_and_build_index.py \
+ --model_name_or_path=${NQ_MODEL_PATH} \
+ --qas_to_retrieve_from=${PAQ_TYPE} \
+ --embed_batch_size=2048 \
+ --save_dir=${SAVE_DIR} \
+ --add_nq_train \
+ --add_nq_dev
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python emat/retriever/build_index.py \
+  --embeddings_dir="${SAVE_DIR}/embedding_index.pt" \
+ --output_path="${SAVE_DIR}/key.sq8hnsw.80n80efc.faiss" \
+ --hnsw \
+ --store_n 80 \
+ --ef_construction 80 \
+ --ef_search 32 \
+ --SQ8 \
+ --indexing_batch_size=1 \
+ --verbose
diff --git a/inference_with_faiss.py b/inference_with_faiss.py
new file mode 100644
index 0000000..529a80e
--- /dev/null
+++ b/inference_with_faiss.py
@@ -0,0 +1,320 @@
+import faiss
+import os
+from utils.utils import update_CAT_config_from_args
+import asyncio
+import argparse
+import torch
+from transformers import T5Tokenizer, T5Config
+from emat.t5 import T5WithKeyValueMemory
+from emat.utils import load_jsonl
+import logging
+from embed_and_build_index import load_qas_to_embed
+import time
+from kilt_dataset import DialogDataset
+from torch.nn.utils.rnn import pad_sequence
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+
+QA_KB_PATHS = {
+ "PAQ_L1": "./tmp/PAQ_L1_pickl_file.pkl",
+ "PAQ": "./tmp/PAQ_full.pkl",
+ "TAQ_TRAIN_NQ_TRAIN_PAQ": "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl",
+ "debug": "./tmp/PAQ_L1_small.pkl"
+}
+
+
+def get_args():
+ parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Inference with faiss")
+ parser.add_argument("--model_name_or_path", type=str, required=False,
+ default="./outputs/nq_checkpoints/KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;pos_from_top=128;/best_ckpt/")
+    parser.add_argument("--qas_to_retrieve_from", choices=list(QA_KB_PATHS.keys()), default="debug")
+ parser.add_argument("--add_nq_train", action="store_true")
+ parser.add_argument("--add_nq_dev", action="store_true")
+ parser.add_argument("--inference_batch_size", type=int, default=512)
+ parser.add_argument("--load_dir", default=f"./data/embedding_and_faiss/debug_from_nq_ckpt")
+ parser.add_argument("--inference_type", type=str, default="async", choices=["async", "serial", "t5"])
+ parser.add_argument("--cat_layer", default=7, type=int)
+ parser.add_argument("--test_task", default="wow", type=str, choices=["qa", "wow", "eli5"])
+ parser.add_argument("--model_size", default="base", type=str, choices=["base", "large", "3B"])
+ parser.add_argument("--faiss_path", default="", type=str, required=False)
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = get_args()
+
+ logging.info("loading faiss index.")
+ if args.faiss_path == "":
+ faiss_path = os.path.join(args.load_dir, "key.sq8.hnsw.faiss")
+ else:
+ faiss_path = args.faiss_path
+ key_faiss_index = faiss.read_index(faiss_path)
+ logging.info("loaded faiss index.")
+
+ logging.info("loading memory.")
+ value_memory = torch.load(os.path.join(args.load_dir, "value_memory.pt"))
+ key_memory = torch.load(os.path.join(args.load_dir, "key_memory.pt"))
+ logging.info("loaded memory.")
+
+ logging.info("loading data")
+ qas_to_retrieve = load_qas_to_embed(args.qas_to_retrieve_from, args.add_nq_train, args.add_nq_dev)
+ logging.info("loaded data")
+ assert len(qas_to_retrieve) == value_memory.shape[0] == key_memory.shape[0]
+
+ logging.info("loading model")
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path)
+ if args.model_size == "3B":
+ config = T5Config.from_pretrained(args.model_name_or_path)
+        args.value_fusion_method = "cat_k_delay+v"
+ args.num_values = 10
+ args.prefix_length = 2
+ args.key_encoder_type = "prefix"
+ args.key_layer = 3
+ args.value_layer = 7
+ args.d_key = config.d_model * args.prefix_length
+ args.use_two_prefix = False
+ args.not_share_encoder = False
+ update_CAT_config_from_args(config, args)
+ model, load_info = T5WithKeyValueMemory.from_pretrained(args.model_name_or_path, config=config,
+ output_loading_info=True)
+ logging.info("loaded T5-3B.")
+ else:
+ model, load_info = T5WithKeyValueMemory.from_pretrained(args.model_name_or_path, output_loading_info=True)
+ model.eval()
+ logging.info(f"model load info: {load_info}")
+
+ if args.test_task == "qa":
+ # test_data = load_jsonl("./data/annotated_datasets/NQ-open.test.jsonl")
+ test_data = load_jsonl("./data/annotated_datasets/NQ-open.train-train.jsonl")[:512 * 40]
+ logging.info(f"loaded {len(test_data)} test qas.")
+ else:
+ if args.test_task == "wow":
+ dataset_kwargs = {
+ "dataset_name": "wow_kilt",
+ "max_source_length": 1024
+ }
+ test_data = load_jsonl("./data/annotated_datasets/wizard_of_wikipedia/wow-dev-kilt.jsonl")[:512]
+ test_data = test_data * 10
+ logging.info(f"loaded {len(test_data)} test history-response pairs.")
+ else:
+ dataset_kwargs = {
+ "dataset_name": "eli5_kilt",
+ "max_source_length": 384,
+ "max_target_length": 1536
+ }
+ test_data = load_jsonl("./data/annotated_datasets/eli5/eli5-dev-kilt.jsonl")[:512]
+ test_data = test_data * 10
+ logging.info(f"loaded {len(test_data)} test long-form qas.")
+ test_dataset = DialogDataset(test_data, tokenizer, qas_to_retrieve, max_utterances=13, **dataset_kwargs)
+ test_data = test_dataset.data
+
+ torch.cuda.empty_cache()
+
+ model = model.cuda()
+ if args.inference_type == "serial":
+ serial_inference(model, tokenizer, test_data, args.inference_batch_size, key_faiss_index, value_memory,
+ key_memory, qas_to_retrieve, args.test_task)
+ elif args.inference_type == "async":
+ async_inference(model, tokenizer, test_data, args.inference_batch_size, key_faiss_index, value_memory,
+ key_memory, qas_to_retrieve, args.cat_layer, args.test_task)
+ elif args.inference_type == "t5":
+ t5_inference(model, tokenizer, test_data, args.inference_batch_size, key_faiss_index, value_memory,
+ key_memory, qas_to_retrieve, args.test_task)
+
+
+def get_query_inputs(tokenizer, batch, device, test_task):
+ if test_task == "qa":
+ query_inputs = ["question: " + qa["question"] for qa in batch]
+ query_inputs = tokenizer(query_inputs, max_length=1024,
+ padding=True, truncation=True, return_tensors="pt")
+ return query_inputs["input_ids"].to(device), query_inputs["attention_mask"].to(device)
+ else:
+ history_input_ids = [ex["input_ids"] for ex in batch]
+ history_input_ids = pad_sequence(history_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+ history_attention_mask = (history_input_ids != tokenizer.pad_token_id).long()
+ return history_input_ids.to(device), history_attention_mask.to(device)
+
+
+@torch.no_grad()
+def serial_inference(model: T5WithKeyValueMemory, tokenizer, test_data, batch_size,
+ key_faiss_index, value_memory, not_reduced_key_memory, qas_to_retrieve, test_task):
+ if test_task == "qa":
+ gen_kwargs = {"num_beams": None, "max_length": 64}
+ elif test_task == "wow":
+ gen_kwargs = {"num_beams": None, "max_length": 28, "min_length": 28}
+ else:
+ gen_kwargs = {"num_beams": None, "max_length": 187, "min_length": 187}
+
+ readout_top_k = model.config.num_values
+ key_reduce_method = "avg"
+ value_fusion_method = model.config.value_fusion_method
+
+ time_log = []
+ query_log = []
+ for start_idx in range(0, len(test_data), batch_size):
+ start_time = time.perf_counter()
+
+ batch = test_data[start_idx: start_idx + batch_size]
+ input_ids, attention_mask = get_query_inputs(tokenizer, batch, model.device, test_task)
+
+ encoder_outputs = model.encoder.forward_with_faiss(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ return_dict=True,
+ readout_top_k=readout_top_k,
+ key_reduce_method=key_reduce_method,
+ value_fusion_method=value_fusion_method,
+ key_faiss_index=key_faiss_index,
+ value_memory=value_memory,
+ not_reduced_key_memory=not_reduced_key_memory
+ )
+ generated_tokens = model.generate(
+ encoder_outputs=encoder_outputs,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=False,
+ attention_mask=attention_mask,
+ value_fusion_method=value_fusion_method,
+ **gen_kwargs,
+ )
+ cur_cost = time.perf_counter() - start_time
+ time_log.append(cur_cost)
+ query_log.append(len(batch))
+ logging.info(f" {len(batch)} queries / {cur_cost} seconds")
+
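+    # Drop the two warm-up batches and the final (possibly partial) batch before computing throughput.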
+ time_log = time_log[2:-1]
+ query_log = query_log[2:-1]
+ query_num = sum(query_log)
+ total_time = sum(time_log)
+ logging.info(f"average speed: {query_num} queries / {total_time} seconds = "
+ f"{query_num / total_time} queries per second")
+
+
+@torch.no_grad()
+def async_inference(model: T5WithKeyValueMemory, tokenizer, test_data, batch_size,
+ key_faiss_index, value_memory, not_reduced_key_memory, qas_to_retrieve, cat_layer, test_task):
+ if test_task == "qa":
+ gen_kwargs = {"num_beams": None, "max_length": 64}
+ elif test_task == "wow":
+ gen_kwargs = {"num_beams": None, "max_length": 28, "min_length": 28}
+ else:
+ gen_kwargs = {"num_beams": None, "max_length": 187, "min_length": 187}
+
+ readout_top_k = model.config.num_values
+ key_reduce_method = "avg"
+ # value_fusion_method = "async_cat_k+v"
+ model.encoder.key_layer = 3
+ model.encoder.cat_layer = cat_layer
+ model.encoder.value_layer = 10
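+    # If the keys are concatenated at the same layer where the values are fused, no delayed
+    # update is needed; otherwise use the delayed-fusion variant.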
+ if model.encoder.cat_layer == model.encoder.value_layer:
+ value_fusion_method = "async_cat_k+v"
+ else:
+ value_fusion_method = "async_cat_k_delay+v"
+
+ logging.info(f"cat_layer: {cat_layer}")
+
+ time_log = []
+ query_log = []
+ for start_idx in range(0, len(test_data), batch_size):
+ start_time = time.perf_counter()
+
+ batch = test_data[start_idx: start_idx + batch_size]
+ input_ids, attention_mask = get_query_inputs(tokenizer, batch, model.device, test_task)
+
+ encoder_outputs = asyncio.run(
+ model.encoder.forward_with_async_faiss(
+ input_ids, attention_mask, True, readout_top_k, key_reduce_method, value_fusion_method,
+ key_faiss_index, value_memory, not_reduced_key_memory
+ )
+ )
+ generated_tokens = model.generate(
+ encoder_outputs=encoder_outputs,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=False,
+ attention_mask=attention_mask,
+ value_fusion_method=value_fusion_method,
+ **gen_kwargs,
+ )
+ cur_cost = time.perf_counter() - start_time
+ time_log.append(cur_cost)
+ query_log.append(len(batch))
+ logging.info(f" {len(batch)} queries / {cur_cost} seconds")
+
+ time_log = time_log[2:-1]
+ query_log = query_log[2:-1]
+ query_num = sum(query_log)
+ total_time = sum(time_log)
+ logging.info(f"cat_layer: {cat_layer}")
+ logging.info(f"average speed: {query_num} queries / {total_time} seconds = "
+ f"{query_num / total_time} queries per second")
+
+
+# ELI5 --inference_batch_size=128
+# WoW --inference_batch_size=256
+
+@torch.no_grad()
+def t5_inference(model: T5WithKeyValueMemory, tokenizer, test_data, batch_size,
+ key_faiss_index, value_memory, not_reduced_key_memory, qas_to_retrieve, test_task):
+ if test_task == "qa":
+ gen_kwargs = {"num_beams": None, "max_length": 16}
+ elif test_task == "wow":
+ gen_kwargs = {"num_beams": None, "max_length": 28, "min_length": 28}
+ else:
+ gen_kwargs = {"num_beams": None, "max_length": 187, "min_length": 187}
+
+ readout_top_k = model.config.num_values
+ key_reduce_method = "avg"
+ value_fusion_method = model.config.value_fusion_method
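+    # push the key/cat/value layers past the encoder depth so memory fusion is never triggered,
+    # i.e. time a plain T5 forward pass as the no-memory baseline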
+ model.key_layer = 1000
+ model.value_layer = 1000
+ model.cat_layer = 1000
+ model.encoder.key_layer = 1000
+ model.encoder.value_layer = 1000
+ model.encoder.cat_layer = 1000
+
+ time_log = []
+ query_log = []
+ for start_idx in range(0, len(test_data), batch_size):
+ start_time = time.perf_counter()
+
+ batch = test_data[start_idx: start_idx + batch_size]
+ input_ids, attention_mask = get_query_inputs(tokenizer, batch, model.device, test_task)
+
+ encoder_outputs = model.encoder.forward_with_faiss(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ return_dict=True,
+ readout_top_k=readout_top_k,
+ key_reduce_method=key_reduce_method,
+ value_fusion_method=value_fusion_method,
+ key_faiss_index=key_faiss_index,
+ value_memory=value_memory,
+ not_reduced_key_memory=not_reduced_key_memory
+ )
+ generated_tokens = model.generate(
+ encoder_outputs=encoder_outputs,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=False,
+ attention_mask=attention_mask,
+ value_fusion_method=value_fusion_method,
+ **gen_kwargs,
+ )
+ cur_cost = time.perf_counter() - start_time
+ time_log.append(cur_cost)
+ query_log.append(len(batch))
+ logging.info(f" {len(batch)} queries / {cur_cost} seconds")
+
+ time_log = time_log[2:-1]
+ query_log = query_log[2:-1]
+ query_num = sum(query_log)
+ total_time = sum(time_log)
+ logging.info(f"average speed: {query_num} queries / {total_time} seconds = "
+ f"{query_num / total_time} queries per second")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/kilt_dataset.py b/kilt_dataset.py
new file mode 100644
index 0000000..c92c930
--- /dev/null
+++ b/kilt_dataset.py
@@ -0,0 +1,438 @@
+import logging
+import string
+from itertools import chain
+import copy
+from torch.nn.utils.rnn import pad_sequence
+import re
+import torch
+from torch.utils.data import Dataset, DataLoader
+from typing import List, Dict
+import random
+from emat.evaluation.exact_match import normalize_answer
+from utils.utils import process_labels
+from transformers import T5Tokenizer
+from tqdm.auto import tqdm
+
+
+class DialogDataset(Dataset):
+ def __init__(
+ self,
+ data: List[Dict],
+ tokenizer: T5Tokenizer,
+ qas_to_retrieve,
+ dataset_name,
+ retrieve_strategy="dr",
+ max_source_length=None,
+ args=None,
+ normed_answer_of_qas_to_ret=None,
+ max_utterances=100,
+ add_topic=True,
+ add_persona=True,
+ max_target_length=512,
+ ):
+ super(DialogDataset, self).__init__()
+
+ assert dataset_name in ["wow", "wow_unseen", "wow_kilt", "eli5_kilt"]
+ self.max_source_length = max_source_length if max_source_length is not None else 1024
+ self.dataset_name = dataset_name
+ # print(f"dataset-name: {dataset_name}")
+ self.max_target_length = max_target_length
+ self.tokenizer = tokenizer
+ self.pad_idx = self.tokenizer.pad_token_id
+ self.label_pad_idx = -100
+ self.args = args
+ self.qas_to_retrieve = qas_to_retrieve
+ self.normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret
+ self.max_utterances = max_utterances
+ self.add_topic = add_topic
+ self.add_persona = add_persona
+
+ self.pad_qa = {"question": "", "answer": [""]}
+ if dataset_name == "wow_kilt":
+ self.data: List[Dict] = self.process_kilt_input(data)
+ elif dataset_name == "eli5_kilt":
+ self.data: List[Dict] = self.process_eli5_kilt_input(data)
+ else:
+ self.data: List[Dict] = self.process_to_input_and_response_pairs(data)
+
+ if "normalized_response" in self.data[0].keys():
+ stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've",
+ "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself',
+ 'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 'its', 'itself', 'they', 'them',
+ 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', "that'll",
+ 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has',
+ 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or',
+ 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against',
+ 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from',
+ 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once',
+ 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more',
+ 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than',
+ 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now',
+ 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', "aren't", 'couldn', "couldn't", 'didn',
+ "didn't", 'doesn', "doesn't", 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'isn',
+ "isn't", 'ma', 'mightn', "mightn't", 'mustn', "mustn't", 'needn', "needn't", 'shan', "shan't",
+ 'shouldn', "shouldn't", 'wasn', "wasn't", 'weren', "weren't", 'won', "won't", 'wouldn',
+ "wouldn't"]
+ if "wow" in dataset_name:
+ for item in self.data:
+ item["normalized_response_remove_stop_words_list"] = [
+ w for w in item["normalized_response"].split() if w not in stop_words
+ ]
+ else:
+ assert dataset_name == "eli5_kilt"
+ for item in self.data:
+ item["normalized_response_remove_stop_words_list"] = [
+ w for w in item["normalized_response"].split() if w not in stop_words
+ ]
+ item["normalized_response_remove_stop_words_list"] = \
+ item["normalized_response_remove_stop_words_list"][:512]
+
+ @staticmethod
+ def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+ def process_kilt_input(self, dialog_data):
+ processed_data = []
+ query_prefix_ids = self.tokenizer("Query:", add_special_tokens=False, return_attention_mask=False)["input_ids"]
+ for dialog_idx, item in enumerate(dialog_data):
+
+ if len(item["input"]) < 5:
+ continue
+
+ input_utterances = item["input"].split("\n")
+
+ if len(input_utterances) > self.max_utterances:
+ continue
+
+ utterances_ids = []
+ spk = "Wizard" if len(input_utterances) % 2 == 0 else "Apprentice"
+ for utterance in input_utterances:
+ utterance = f"{spk}: {utterance}"
+ spk = "Wizard" if spk == "Apprentice" else "Apprentice"
+ ids = self.tokenizer(utterance, add_special_tokens=False, return_attention_mask=False)["input_ids"]
+ utterances_ids.append(ids)
+
+ if sum(len(u) for u in utterances_ids) > self.max_source_length:
+ max_length_per_utterance = self.max_source_length // len(utterances_ids)
+ utterances_ids = [u[:max_length_per_utterance] for u in utterances_ids]
+
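+            # query_ids: "Query:" plus the last two utterances (used to query the key memory);
+            # input_ids: the full dialogue history fed to the encoder; 1 is T5's </s> token id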
+ query_ids = query_prefix_ids + list(chain(*copy.deepcopy(utterances_ids[-2:]))) + [1]
+
+ input_ids = list(chain(*utterances_ids)) + [1]
+
+ cur_data = {
+ "id": item["id"],
+ "input_ids": torch.tensor(input_ids),
+ "query_ids": torch.tensor(query_ids),
+ "dialog_idx": torch.tensor(dialog_idx),
+ }
+
+ if "output" in item.keys():
+ response = item["output"][0]["answer"]
+ with self.tokenizer.as_target_tokenizer():
+ response_ids = self.tokenizer(response, max_length=self.max_target_length,
+ return_attention_mask=False)["input_ids"]
+ cur_data.update({
+ "response_ids": torch.tensor(response_ids),
+ "normalized_response": self.normalize_answer(response)
+ })
+
+ processed_data.append(cur_data)
+
+ # logging.info(f"process {len(dialog_data)} dialogs to {len(processed_data)} training examples.")
+ return processed_data
+
+ def process_eli5_kilt_input(self, eli5_data):
+ assert self.max_target_length >= 1024
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ processed_data = []
+ query_prefix_ids = self.tokenizer("Query:", add_special_tokens=False, return_attention_mask=False)["input_ids"]
+
+ for eli5_idx, item in tqdm(enumerate(eli5_data), total=len(eli5_data)):
+
+ question = item["input"]
+ question_ids = self.tokenizer(question, add_special_tokens=False, return_attention_mask=False)["input_ids"]
+ query_ids = (query_prefix_ids + copy.deepcopy(question_ids))[:255] + [1]
+ question_ids = question_ids[:383] + [1]
+
+ cur_data = {
+ "id": item["id"],
+ "input_ids": torch.tensor(question_ids),
+ "query_ids": torch.tensor(query_ids),
+ "dialog_idx": torch.tensor(eli5_idx),
+
+ }
+
+ if "output" in item.keys():
+ answer = item["output"][0]["answer"]
+ answer = white_space_fix(answer)
+
+ with self.tokenizer.as_target_tokenizer():
+ response_ids = self.tokenizer(answer, max_length=self.max_target_length,
+ return_attention_mask=False)["input_ids"]
+ cur_data.update({
+ "response_ids": torch.tensor(response_ids),
+ "normalized_response": self.normalize_answer(answer),
+ "candidate_responses": [ot['answer'] for ot in item["output"] if "answer" in ot]
+ })
+
+ processed_data.append(cur_data)
+
+ logging.info(f"process {len(eli5_data)} dialogs to {len(processed_data)} training examples.")
+ return processed_data
+
+ def process_to_input_and_response_pairs(self, dialog_data):
+ processed_data = []
+ for dialog_idx, item in enumerate(dialog_data):
+ dialog = item["dialog"][:self.max_utterances]
+ inputs = "history:"
+ if self.add_persona:
+ inputs = f'persona: {item["persona"]} ' + inputs
+ if self.add_topic:
+ inputs = f'topic: {item["chosen_topic"]}. ' + inputs
+
+ for turn_idx, turn in enumerate(dialog):
+ speaker = turn["speaker"][2:]
+ assert speaker in ["Wizard", "Apprentice"]
+ if turn["speaker"][2:] == "Wizard":
+ if turn_idx == 0:
+ query = inputs
+ else:
+ query = f'topic: {item["chosen_topic"]}. {dialog[turn_idx - 1]["text"]}'
+ processed_data.append({
+ "inputs": inputs,
+ "response": turn["text"],
+ "query": query,
+ "normalized_response": self.normalize_answer(turn["text"]),
+ "dialog_idx": dialog_idx,
+ "turn_idx": turn_idx,
+ })
+ inputs = inputs + f' {speaker}: {turn["text"]}'
+
+ logging.info(f"process {len(dialog_data)} dialogs to {len(processed_data)} training examples.")
+ return processed_data
+
+ def get_qa_key_value_inputs(self, qas, only_return_key_inputs=False):
+ key_inputs = ["question: " + qa["question"] for qa in qas]
+ key_inputs = self.tokenizer(key_inputs, max_length=self.max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ if only_return_key_inputs:
+ return {"key_input_ids": key_inputs["input_ids"],
+ "key_attention_mask": key_inputs["attention_mask"]}
+ else:
+ value_inputs = ["answer: " + qa["answer"][0] for qa in qas]
+ value_inputs = self.tokenizer(value_inputs, max_length=self.max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"key_input_ids": key_inputs["input_ids"],
+ "key_attention_mask": key_inputs["attention_mask"],
+ "value_input_ids": value_inputs["input_ids"],
+ "value_attention_mask": value_inputs["attention_mask"]}
+
+ def get_query_inputs(self, batch):
+ if "kilt" in self.dataset_name:
+ query_input_ids = [ex["query_ids"] for ex in batch]
+ query_input_ids = pad_sequence(query_input_ids, batch_first=True, padding_value=self.pad_idx)
+ query_attention_mask = (query_input_ids != self.pad_idx).long()
+ return {"query_input_ids": query_input_ids,
+ "query_attention_mask": query_attention_mask}
+ else:
+ query_inputs = [ex["query"] for ex in batch]
+ query_inputs = self.tokenizer(query_inputs, max_length=self.max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"query_input_ids": query_inputs["input_ids"],
+ "query_attention_mask": query_inputs["attention_mask"]}
+
+ def get_history_inputs(self, batch):
+ if "kilt" in self.dataset_name:
+ history_input_ids = [ex["input_ids"] for ex in batch]
+ history_input_ids = pad_sequence(history_input_ids, batch_first=True, padding_value=self.pad_idx)
+ history_attention_mask = (history_input_ids != self.pad_idx).long()
+ return {"history_input_ids": history_input_ids,
+ "history_attention_mask": history_attention_mask}
+ else:
+ history_inputs = [ex["inputs"] for ex in batch]
+ history_inputs = self.tokenizer(history_inputs, max_length=self.max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"history_input_ids": history_inputs["input_ids"],
+ "history_attention_mask": history_inputs["attention_mask"]}
+
+ def get_target_inputs(self, batch):
+ if "kilt" in self.dataset_name:
+ target_ids = [ex["response_ids"] for ex in batch]
+ target_ids = pad_sequence(target_ids, batch_first=True, padding_value=self.pad_idx)
+ return {"labels": process_labels(target_ids, self.tokenizer)}
+ else:
+ targets = [dialog["response"] for dialog in batch]
+ with self.tokenizer.as_target_tokenizer():
+ targets = self.tokenizer(targets, max_length=self.max_target_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"labels": process_labels(targets, self.tokenizer)}
+
+ def get_dataloader(self, batch_size, shuffle, num_workers):
+
+ def base_collate_fn(batch):
+ original_batch_size, filtered_batch_size = len(batch), len(batch)
+ # if not self.args.use_not_exactly_true:
+ # batch = [ex for ex in batch if len(ex["local_positive"]) > 0]
+ # filtered_batch_size = len(batch)
+ # while len(batch) == 0: # avoid empty-batch
+ # batch = random.sample(self.data, batch_size)
+ # batch = [ex for ex in batch if len(ex["local_positive"]) > 0]
+ # # do not change filtered_batch_size even change the batch again.
+
+ model_inputs = {
+ "trainable_percentage": torch.tensor(filtered_batch_size / original_batch_size).repeat(len(batch)),
+                # repeated ``len(batch)`` times so it stays per-example and splits cleanly across multiple GPUs.
+ }
+ model_inputs.update(self.get_query_inputs(batch))
+ model_inputs.update(self.get_history_inputs(batch))
+ model_inputs.update(self.get_target_inputs(batch))
+
+ batch_local_positive_num = self.args.batch_local_positive_num
+ neg_num_each_example = self.args.negatives_num_each_example
+ local_positive_qas = []
+ local_positive_num = []
+ local_positive_qas_mask = []
+ local_negative_qas = []
+ local_pos_mix_neg_qas = [] # num = neg_num_each_example
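+            # per example: keep up to `batch_local_positive_num` positives (padded with `pad_qa` and
+            # masked out), sample `negatives_num_each_example` negatives, and build a fixed-length
+            # "mixed" list of the positives followed by negatives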
+ for ex in batch:
+ cur_local_positive_qas_ids = [idx for idx in ex["local_positive"][:batch_local_positive_num]]
+ cur_local_positive_qas = [self.qas_to_retrieve[idx] for idx in cur_local_positive_qas_ids]
+ cur_pos_num = len(cur_local_positive_qas)
+ local_positive_num.append(cur_pos_num)
+
+ cur_local_negative_qas_idx = random.sample(ex["local_negative"], neg_num_each_example)
+ cur_local_negative_qas = [self.qas_to_retrieve[idx] for idx in cur_local_negative_qas_idx]
+ local_negative_qas.append(cur_local_negative_qas)
+ cur_local_pos_mix_neg_qas = cur_local_positive_qas + \
+ cur_local_negative_qas[:neg_num_each_example - cur_pos_num]
+ local_pos_mix_neg_qas.append(cur_local_pos_mix_neg_qas)
+
+ cur_pad_num = batch_local_positive_num - cur_pos_num
+ cur_local_positive_qas_mask = [1] * cur_pos_num + [0] * cur_pad_num
+ local_positive_qas_mask.append(cur_local_positive_qas_mask)
+ cur_local_positive_qas.extend([self.pad_qa] * cur_pad_num)
+ local_positive_qas.append(cur_local_positive_qas)
+
+ model_inputs.update({"local_positive_qas_mask": torch.tensor(local_positive_qas_mask),
+ "local_positive_num": torch.tensor(local_positive_num), })
+
+ assert self.args.select_positive_strategy == "softmax_sample"
+ squeezed_positive_qas = list(chain(*local_positive_qas))
+ local_positive_inputs = self.get_qa_key_value_inputs(squeezed_positive_qas, only_return_key_inputs=True)
+ model_inputs.update({f"local_positive_inputs_{k}": v.view(len(batch), batch_local_positive_num, -1)
+ for k, v in local_positive_inputs.items()})
+
+ squeezed_negative_qas = list(chain(*local_negative_qas))
+ local_negative_inputs = self.get_qa_key_value_inputs(squeezed_negative_qas, only_return_key_inputs=True)
+ model_inputs.update({f"local_negative_inputs_{k}": v.view(len(batch), neg_num_each_example, -1)
+ for k, v in local_negative_inputs.items()})
+
+ squeezed_mixed_qas = list(chain(*local_pos_mix_neg_qas))
+ local_mixed_inputs = self.get_qa_key_value_inputs(squeezed_mixed_qas)
+ model_inputs.update({f"local_mixed_inputs_{k}": v.view(len(batch), neg_num_each_example, -1)
+ for k, v in local_mixed_inputs.items()})
+
+ # all_targets = [[normalize_answer(an) for an in qa["response"]] for qa in batch]
+ # negative_qas_answer = [normalize_answer(nqa["answer"][0]) for nqa in squeezed_negative_qas]
+ # negative_mask = [[1 if neg_ans not in cur_all_target else 0 for neg_ans in negative_qas_answer]
+ # for cur_all_target in all_targets]
+ # model_inputs.update({"negative_mask": torch.tensor(negative_mask)})
+
+ # for multi-GPUs
+ assert all(model_inputs[k].shape[0] == len(batch) for k in model_inputs.keys())
+
+ return model_inputs
+
+ return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
+ collate_fn=base_collate_fn, pin_memory=True)
+
+ def get_query_dataloader(self, batch_size, shuffle, num_workers, add_history=False):
+
+ def query_collate_fn(batch):
+ model_inputs = self.get_query_inputs(batch)
+
+ if add_history:
+ model_inputs.update(self.get_history_inputs(batch))
+
+ return model_inputs
+
+ return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
+ collate_fn=query_collate_fn, pin_memory=True, drop_last=False)
+
+ def get_t5_dataloader(self, batch_size, shuffle, num_workers, is_train):
+
+ def t5_collate_fn(batch):
+ history_inputs = self.get_history_inputs(batch)
+ response_inputs = self.get_target_inputs(batch)
+ return {
+ "input_ids": history_inputs["history_input_ids"],
+ "attention_mask": history_inputs["history_attention_mask"],
+ "labels": response_inputs["labels"]
+ }
+
+ return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
+ collate_fn=t5_collate_fn, pin_memory=True)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, item):
+ return self.data[item]
+
+
+if __name__ == '__main__':
+ import json
+
+
+ def load_json(fi):
+ return json.load(open(fi, 'r'))
+
+
+ def load_jsonl(fn):
+ all_data = []
+ with open(fn, "r") as f:
+ for line in f.readlines():
+ all_data.append(json.loads(line))
+ return all_data
+
+
+ tokenizer = T5Tokenizer.from_pretrained("./data/cbqa_data/pretrained_model/t5-base")
+ # test_data = load_jsonl("wow-test_without_answers-kilt.jsonl.txt")
+ # train_data = load_jsonl("wow-train-kilt.jsonl")
+ dev_data = load_jsonl("./data/annotated_datasets/wizard_of_wikipedia/wow-dev-kilt.jsonl")
+
+ exp = dev_data[0]
+ print("")
+
+ dataset = DialogDataset(dev_data, tokenizer, None, "wow_kilt",
+ max_source_length=768, max_utterances=10)
+ # data: List[Dict],
+ # tokenizer: T5Tokenizer,
+ # qas_to_retrieve,
+ # dataset_name,
+ # retrieve_strategy = "dr",
+ # max_source_length = None,
+ # args = None,
+ # normed_answer_of_qas_to_ret = None,
+ # max_utterances = 100,
+ # add_topic = True,
+ # add_persona = True
diff --git a/kilt_main.py b/kilt_main.py
new file mode 100644
index 0000000..a59d78f
--- /dev/null
+++ b/kilt_main.py
@@ -0,0 +1,183 @@
+import json
+import os
+import pickle
+from transformers import T5Tokenizer
+from emat.utils import load_jsonl
+from kilt_dataset import DialogDataset
+from kilt_trainer import DialogTrainer
+from utils.utils import CATArgs
+import logging
+import time
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+
+DATA_PATHS = {
+ "wow": {
+ "train": "./data/annotated_datasets/wizard_of_wikipedia/train.json",
+ "validation": "./data/annotated_datasets/wizard_of_wikipedia/valid_random_split.json",
+ "test": "./data/annotated_datasets/wizard_of_wikipedia/test_random_split.json",
+ },
+ "wow_unseen": {
+ "train": "./data/annotated_datasets/wizard_of_wikipedia/train.json",
+ "validation": "./data/annotated_datasets/wizard_of_wikipedia/valid_topic_split.json",
+ "test": "./data/annotated_datasets/wizard_of_wikipedia/test_topic_split.json",
+ },
+ "wow_kilt": {
+ "train": "./data/annotated_datasets/wizard_of_wikipedia/wow-train-kilt.jsonl",
+ "validation": "./data/annotated_datasets/wizard_of_wikipedia/wow-dev-kilt.jsonl",
+ "test": "./data/annotated_datasets/wizard_of_wikipedia/wow-test_without_answers-kilt.jsonl.txt",
+ },
+ "eli5_kilt": {
+ "train": "./data/annotated_datasets/eli5/eli5-train-kilt.jsonl",
+ "validation": "./data/annotated_datasets/eli5/eli5-dev-kilt.jsonl",
+ "test": "./data/annotated_datasets/eli5/eli5-test_without_answers-kilt.jsonl",
+ }
+}
+QA_KB_PATHS = {
+ "PAQ_L1": "./tmp/PAQ_L1_pickl_file.pkl",
+ "PAQ": "./tmp/PAQ_full.pkl",
+ "TAQ_TRAIN_NQ_TRAIN_PAQ": "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl",
+}
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+
+def load_dataset(args):
+ assert args.qa_data_name in DATA_PATHS.keys(), f"available dataset names: {DATA_PATHS.keys()}"
+
+ logging.info("loading normed answer of qas to retrieve")
+ if "PAQ" == args.qas_to_retrieve_from:
+ normed_answer_of_qas_to_ret = pickle.load(open("./tmp/PAQ_only_normalized_answer.pkl", 'rb'))
+ else:
+ normed_answer_of_qas_to_ret = json.load(open("./tmp/PAQL1_only_normalized_answer.json", 'r'))
+
+ logging.info("loading qas to retrieve")
+ if "debug" in args.exp_name.lower() or "full-paq-test" in args.exp_name.lower():
+ if not os.path.exists("./tmp/PAQ_L1_small.pkl"):
+ qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_pickl_file.pkl", 'rb'))
+ qas_to_retrieve = qas_to_retrieve[:len(qas_to_retrieve) // 14]
+ pickle.dump(qas_to_retrieve, open("./tmp/PAQ_L1_small.pkl", 'wb'))
+ else:
+ qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_small.pkl", 'rb'))
+ else:
+ qas_to_retrieve_fp = QA_KB_PATHS[args.qas_to_retrieve_from]
+ logging.info(f"loading qas from {qas_to_retrieve_fp}")
+ if qas_to_retrieve_fp.endswith("pkl"):
+ qas_to_retrieve = pickle.load(open(qas_to_retrieve_fp, 'rb'))
+ elif qas_to_retrieve_fp.endswith("jsonl"):
+ qas_to_retrieve = load_jsonl(qas_to_retrieve_fp)
+ else:
+ raise ValueError(f"{qas_to_retrieve_fp}")
+
+ if "debug" in args.exp_name.lower():
+ qas_to_retrieve = qas_to_retrieve[:5000]
+ normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:len(qas_to_retrieve)]
+
+ if args.qas_to_retrieve_from == "PAQ" and args.PAQ_size is not None:
+ qas_to_retrieve = qas_to_retrieve[:args.PAQ_size]
+ normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:args.PAQ_size]
+ assert len(qas_to_retrieve) == args.PAQ_size
+ logging.info(f"select {args.PAQ_size}-size PAQ.")
+
+ assert len(normed_answer_of_qas_to_ret) == len(qas_to_retrieve)
+ loaded_data = {
+ "qas_to_retrieve": qas_to_retrieve,
+ "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret
+ }
+
+ return loaded_data
+
+
+def main():
+ cat_args = CATArgs("dialog_cat")
+ args = cat_args.parse_args()
+ data_paths = DATA_PATHS[args.qa_data_name]
+ logging.info("load datasets")
+ if "kilt" in args.qa_data_name:
+ train_data = load_jsonl(data_paths["train"])
+ dev_data = load_jsonl(data_paths["validation"])
+ test_data = load_jsonl(data_paths["test"])
+ else:
+ train_data = json.load(open(data_paths["train"], 'r'))
+ dev_data = json.load(open(data_paths["validation"], 'r'))
+ test_data = json.load(open(data_paths["test"], 'r'))
+
+ loaded_data = load_dataset(args)
+ logging.info("data loaded.")
+ qas_to_retrieve = loaded_data["qas_to_retrieve"]
+ normed_answer_of_qas_to_ret = loaded_data["normed_answer_of_qas_to_ret"]
+
+ if "debug" in args.exp_name.lower():
+ train_data = train_data[:50]
+ dev_data = dev_data[:10]
+ test_data = test_data[:10]
+ qas_to_retrieve = qas_to_retrieve[:10000]
+ normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:10000]
+
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+
+ if args.qa_data_name != "eli5_kilt":
+ dataset_kwargs = {
+ "dataset_name": args.qa_data_name,
+ "args": args,
+ "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret,
+ "add_persona": args.add_persona,
+ "add_topic": args.add_topic,
+ "max_source_length": 1024
+ }
+ else:
+ assert args.qa_data_name == "eli5_kilt"
+ dataset_kwargs = {
+ "dataset_name": args.qa_data_name,
+ "args": args,
+ "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret,
+ "max_source_length": 384,
+ "max_target_length": 1536
+ }
+ mu = 10 if args.qa_data_name == "wow_kilt" else 13
+ train_dataset = DialogDataset(train_data, tokenizer, qas_to_retrieve, max_utterances=mu, **dataset_kwargs)
+ dev_dataset = DialogDataset(dev_data, tokenizer, qas_to_retrieve, **dataset_kwargs)
+ test_dataset = DialogDataset(test_data, tokenizer, qas_to_retrieve, **dataset_kwargs)
+ dialog_trainer = DialogTrainer(args, train_dataset, dev_dataset, test_dataset, qas_to_retrieve,
+ normed_answer_of_qas_to_ret)
+
+ if args.do_train:
+ dialog_trainer.train()
+ elif args.do_test:
+ logging.info("Only do test.")
+ ckpt_load_path = os.path.join(args.output_dir, "best_ckpt/pytorch_model.bin")
+ gen_kwargs = {"max_length": 1024,
+ "num_beams": 5,
+ "do_sample": True,
+ "top_k": 64,
+ "no_repeat_ngram_size": 8}
+ logging.warning("use dev dataset")
+ use_dataset = dev_dataset
+
+ metrics, ret_qas, gen_response = dialog_trainer.evaluate(use_dataset, update_key_memory=True,
+ ckpt_load_path=ckpt_load_path, gen_kwargs=gen_kwargs)
+
+ for k, v in metrics.items():
+ logging.info(f"test_{k}: {v}")
+ assert len(ret_qas) == len(gen_response) == len(use_dataset.data)
+ results = []
+ for retrieved, pred, input_item in zip(ret_qas, gen_response, use_dataset.data):
+ results.append({
+ "id": input_item["id"],
+ "input": tokenizer.decode(input_item["input_ids"]),
+ "target": tokenizer.decode(input_item["response_ids"]) if "response_ids" in input_item else "",
+ "query": tokenizer.decode(input_item["query_ids"]),
+ "output": {"answer": pred, "provenance": [{"wikipedia_id": "12904"}]},
+ "retrieved_qas": [f"question: {qa['question']} answer: {qa['answer'][0]}" for qa in retrieved]
+ })
+ dump_path = os.path.dirname(ckpt_load_path)
+ dump_path = os.path.join(dump_path, f"{time.strftime('%d %H-%M')}_predict_result.json")
+ json.dump(results, open(dump_path, 'w'),
+ indent=4, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/kilt_scripts/eli5_train.sh b/kilt_scripts/eli5_train.sh
new file mode 100644
index 0000000..169d7ed
--- /dev/null
+++ b/kilt_scripts/eli5_train.sh
@@ -0,0 +1,59 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+LOAD_EXP_NAME="KL=3;kdim=1536;CL=10;VL=11;VN=10;async_cat_k+v;t5-base;from-t5-base;"
+EXP_NAME="base;wow_kilt;freeze;8e-5;20epoch;CL=10;VL=11;"
+DATA_NAME="eli5_kilt"
+
+DEVICE="0"
+
+echo "Use Device ${DEVICE}"
+
+# Train eli5-EMAT-SKSV
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python kilt_main.py \
+ --key_layer=3 \
+ --cat_layer=10 \
+ --value_layer=11 \
+ --query_batch_size=256 \
+ --build_mem_batch_size=12000 \
+ --project_name="${DATA_NAME^^}-CAT" \
+ --exp_name=${EXP_NAME} \
+ --batch_local_positive_num=5 \
+ --pos_from_top=128 \
+ --do_eval \
+ --kvm_seg_n=5 \
+ --values_with_order \
+ --value_fusion_method="async_cat_k_delay+v" \
+ --num_values=10 \
+ --qa_data_name=${DATA_NAME} \
+ --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=12 \
+ --per_device_eval_batch_size=32 \
+ --gradient_accumulation_steps=5 \
+ --learning_rate=5e-5 \
+ --num_train_epochs=20 \
+ --lr_scheduler_type="linear" \
+ --num_warmup_steps=1000 \
+ --output_dir="./outputs/${DATA_NAME}_checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_encoder_type="conv" \
+ --select_positive_strategy="softmax_sample" \
+ --faiss_efsearch=128 \
+ --gen_weight=1 \
+ --match_weight=1 \
+ --key_reduce_method="avg" \
+ --qas_to_retrieve_from="PAQ_L1" \
+ --local_size=384 \
+ --separate_task \
+ --early_stop_patience=8 \
+ --negatives_num_each_example=16 \
+ --do_train \
+ --not_share_encoder
diff --git a/kilt_scripts/eval_kilt.sh b/kilt_scripts/eval_kilt.sh
new file mode 100644
index 0000000..acc9761
--- /dev/null
+++ b/kilt_scripts/eval_kilt.sh
@@ -0,0 +1,61 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+CKPT_DIR=""
+EXP_NAME="eval"
+DATA_NAME="wow_kilt" # eli5_kilt
+
+DEVICE="0"
+
+echo "Use Device ${DEVICE}"
+
+#echo "wait-to-continue..."
+#sleep 7200
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python kilt_main.py \
+ --key_layer=3 \
+ --value_layer=7 \
+ --query_batch_size=256 \
+ --build_mem_batch_size=12000 \
+ --project_name="${DATA_NAME^^}-CAT" \
+ --exp_name=${EXP_NAME} \
+ --batch_local_positive_num=5 \
+ --pos_from_top=128 \
+ --do_eval \
+ --kvm_seg_n=5 \
+ --values_with_order \
+ --value_fusion_method="cat_k_delay+v" \
+ --num_values=10 \
+ --qa_data_name=${DATA_NAME} \
+ --model_name_or_path=${CKPT_DIR} \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=12 \
+ --per_device_eval_batch_size=32 \
+ --gradient_accumulation_steps=5 \
+ --learning_rate=8e-5 \
+ --num_train_epochs=20 \
+ --lr_scheduler_type="linear" \
+ --num_warmup_steps=1000 \
+ --output_dir="./outputs/${DATA_NAME}_checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_encoder_type="conv" \
+ --select_positive_strategy="softmax_sample" \
+ --faiss_efsearch=128 \
+ --gen_weight=1 \
+ --match_weight=1 \
+ --key_reduce_method="avg" \
+ --qas_to_retrieve_from="PAQ_L1" \
+ --local_size=384 \
+ --separate_task \
+ --early_stop_patience=8 \
+ --negatives_num_each_example=16 \
+ --do_test \
+ --add_topic \
+ --add_persona \
+ --not_share_encoder
diff --git a/kilt_scripts/wow_train.sh b/kilt_scripts/wow_train.sh
new file mode 100644
index 0000000..0d1afe3
--- /dev/null
+++ b/kilt_scripts/wow_train.sh
@@ -0,0 +1,61 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+LOAD_EXP_NAME="KL=3;kdim=1536;CL=10;VL=11;VN=10;async_cat_k+v;t5-base;from-t5-base;"
+EXP_NAME="base;wow_kilt;freeze;8e-5;20epoch;CL=10;VL=11;"
+DATA_NAME="wow_kilt"
+
+DEVICE="0"
+
+echo "Use Device ${DEVICE}"
+
+# Train wow-EMAT-SKSV
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python kilt_main.py \
+ --key_layer=3 \
+ --cat_layer=10 \
+ --value_layer=11 \
+ --query_batch_size=256 \
+ --build_mem_batch_size=12000 \
+ --project_name="${DATA_NAME^^}-CAT" \
+ --exp_name=${EXP_NAME} \
+ --batch_local_positive_num=5 \
+ --pos_from_top=128 \
+ --do_eval \
+ --kvm_seg_n=5 \
+ --values_with_order \
+ --value_fusion_method="async_cat_k_delay+v" \
+ --num_values=10 \
+ --qa_data_name=${DATA_NAME} \
+ --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=12 \
+ --per_device_eval_batch_size=32 \
+ --gradient_accumulation_steps=5 \
+ --learning_rate=8e-5 \
+ --num_train_epochs=20 \
+ --lr_scheduler_type="linear" \
+ --num_warmup_steps=1000 \
+ --output_dir="./outputs/${DATA_NAME}_checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_encoder_type="conv" \
+ --select_positive_strategy="softmax_sample" \
+ --faiss_efsearch=128 \
+ --gen_weight=1 \
+ --match_weight=1 \
+ --key_reduce_method="avg" \
+ --qas_to_retrieve_from="PAQ_L1" \
+ --local_size=384 \
+ --separate_task \
+ --early_stop_patience=8 \
+ --negatives_num_each_example=16 \
+ --do_train \
+ --add_topic \
+ --add_persona \
+ --not_share_encoder
diff --git a/kilt_trainer.py b/kilt_trainer.py
new file mode 100644
index 0000000..770a885
--- /dev/null
+++ b/kilt_trainer.py
@@ -0,0 +1,655 @@
+import copy
+import logging
+import math
+import os
+import random
+from itertools import chain
+from typing import List, Dict, Optional
+from collections import Counter
+from rouge import Rouge
+from emat.utils import write_jsonl
+from utils.dr_utils import update_dialog_local_qas_to_retrieve, update_batch_inputs
+from utils.utils import reduce_query_or_key_embeds
+import datasets
+import torch
+import transformers
+from accelerate import Accelerator
+from tqdm.auto import tqdm
+from transformers import AdamW, get_scheduler, set_seed
+from utils.utils import save_model, load_model
+from build_kvm import build_memory
+from emat.t5 import T5WithKeyValueMemory
+from kilt_dataset import DialogDataset
+from nltk.translate.bleu_score import sentence_bleu
+import time
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+try:
+ import wandb
+
+ wandb.ensure_configured()
+ if wandb.api.api_key is None:
+ _has_wandb = False
+ wandb.termwarn(
+ "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
+ else:
+ _has_wandb = False if os.getenv("WANDB_DISABLED") else True
+except (ImportError, AttributeError):
+ _has_wandb = False
+
+
+def compute_batch_BLEU(references, candidates):
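+    # NOTE: callers pass plain (untokenized) strings, so sentence_bleu effectively scores character
+    # n-grams here; split the inputs into token lists first if word-level BLEU is intended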
+ def compute_BLEU(reference, candidate):
+ b1 = sentence_bleu(reference, candidate, weights=(1, 0, 0, 0))
+ b2 = sentence_bleu(reference, candidate, weights=(0, 1, 0, 0))
+ b3 = sentence_bleu(reference, candidate, weights=(0, 0, 1, 0))
+ b4 = sentence_bleu(reference, candidate, weights=(0, 0, 0, 1))
+ return b1, b2, b3, b4
+
+ references = [[ref] for ref in references]
+
+ bleu1_score = []
+ bleu2_score = []
+ bleu3_score = []
+ bleu4_score = []
+ for index in range(len(references)):
+ bleu1, bleu2, bleu3, bleu4 = compute_BLEU(references[index], candidates[index])
+ bleu1_score.append(bleu1)
+ bleu2_score.append(bleu2)
+ bleu3_score.append(bleu3)
+ bleu4_score.append(bleu4)
+
+ bleu1 = sum(bleu1_score) / len(bleu1_score)
+ bleu2 = sum(bleu2_score) / len(bleu2_score)
+    bleu3 = sum(bleu3_score) / len(bleu3_score)
+ bleu4 = sum(bleu4_score) / len(bleu4_score)
+
+ return {"bleu1": bleu1, "bleu2": bleu2, "bleu3": bleu3, "bleu4": bleu4, }
+
+
+def f1_score(prediction, ground_truth):
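+    # token-level F1; e.g. prediction "the cat sat" vs. ground truth "a cat sat down":
+    # 2 shared tokens -> precision 2/3, recall 2/4, F1 ≈ 0.57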
+ prediction_tokens = prediction.split()
+ ground_truth_tokens = ground_truth.split()
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def max_score_over_ground_truths(fn, prediction, ground_truths):
+ scores_for_ground_truths = []
+ for ground_truth in ground_truths:
+ score = fn(prediction, ground_truth)
+ scores_for_ground_truths.append(score)
+ return max(scores_for_ground_truths)
+
+
+def rougel_score(prediction, ground_truth):
+ rouge = Rouge()
+ # no normalization
+ try:
+ scores = rouge.get_scores(prediction, ground_truth, avg=True)
+ except ValueError: # "Hypothesis is empty."
+ return 0.0
+ return scores["rouge-l"]["f"]
+
+
+class DialogTrainer:
+
+ def __init__(
+ self,
+ args,
+ train_dataset: DialogDataset,
+ dev_dataset: DialogDataset,
+ test_dataset: DialogDataset,
+ qas_to_retrieve: List[Dict],
+ normed_answer_of_qas_to_ret,
+ ):
+ accelerator = Accelerator()
+ logging.info(f"wandb {'available' if _has_wandb else 'unavailable'}")
+ logger.info(accelerator.state)
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ if args.seed is not None:
+ set_seed(args.seed)
+ else:
+ logging.info("Not set seed.")
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+ accelerator.wait_for_everyone()
+
+ if accelerator.is_local_main_process and _has_wandb:
+ wandb.init(project=args.project_name, name=args.exp_name, dir=args.output_dir, config=vars(args))
+
+ logging.info("loading model")
+ config, tokenizer, self.model = load_model(T5WithKeyValueMemory, args)
+ logging.info("Loading data.")
+
+ if args.freeze_t5_params:
+ logging.info("Freeze T5 parameters.")
+ self.model.freeze_t5_params()
+
+ if args.not_share_encoder and not args.update_kv_embeds:
+ logging.info("Freeze kv-encoder parameters.")
+ self.model.freeze_kv_encoder_params()
+
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": args.weight_decay, },
+ {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0, },
+ ]
+ self.optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+
+ # reused_key_memory: pre-allocated memory to store full key_memory
+ self.reused_key_memory = torch.zeros((len(qas_to_retrieve), self.model.model_dim),
+ device="cpu", dtype=torch.float16)
+ self.train_data_query_embeds = torch.zeros((len(train_dataset), self.model.model_dim),
+ device="cpu", dtype=torch.float16)
+ self.key_memory: Optional[List[torch.tensor]] = None
+ self.key_memory = []
+ for start_idx in range(0, len(qas_to_retrieve), math.ceil(len(qas_to_retrieve) / args.kvm_seg_n)):
+ self.key_memory.append(
+ self.reused_key_memory[start_idx: start_idx + math.ceil(len(qas_to_retrieve) / args.kvm_seg_n)]
+ )
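+        # key_memory is a list of `kvm_seg_n` fp16 CPU views into reused_key_memory; at retrieval time
+        # each chunk is moved to the GPU in turn, so the full key matrix never has to fit in device memory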
+ # logger.info(f"key num = {sum(len(i) for i in self.key_memory)}")
+
+ self.train_dataset = train_dataset
+ self.dev_dataset = dev_dataset
+ self.test_dataset = test_dataset
+ self.args = args
+ self.accelerator = accelerator
+ self.tokenizer = tokenizer
+ self.qas_to_retrieve = qas_to_retrieve
+ self.prefix = args.source_prefix if args.source_prefix is not None else ""
+ self.query_batch_size = args.query_batch_size
+ # logger.info(f"PAQ-size: {len(self.qas_to_retrieve)}. PAQ's query batch size: {self.query_batch_size}.")
+ self.normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret
+ self.model = self.accelerator.prepare(self.model)
+
+ @torch.no_grad()
+ def update_key_memory(self, use_fp16_model=True):
+ args = self.args
+ if use_fp16_model:
+ tmp_model = copy.deepcopy(self.model)
+ tmp_model = tmp_model.half()
+ else:
+ tmp_model = self.model
+ build_mem_batch_size = args.build_mem_batch_size
+ tmp_model.eval()
+
+ self.key_memory, _ = build_memory(
+ tmp_model, self.tokenizer, embed_key=True, embed_value=False, prefix="question: ", embed_as_fp16=True,
+ key_reduce_method=args.key_reduce_method, return_memory=True, dump_memory=False,
+ data_to_embed=self.qas_to_retrieve, max_source_length=args.max_source_length, padding=True,
+ batch_size=build_mem_batch_size, separate_task=True, kvm_seg_n=args.kvm_seg_n,
+ reused_key_memory=self.reused_key_memory, num_workers=0
+ )
+ if type(self.key_memory) is not list:
+ self.key_memory = [self.key_memory]
+ del tmp_model
+
+ @torch.no_grad()
+ def update_local_qas(self, epoch, use_fp16_model=True):
+ args = self.args
+ if use_fp16_model:
+ tmp_model = copy.deepcopy(self.model)
+ tmp_model = tmp_model.half()
+ else:
+ tmp_model = self.model
+ build_mem_batch_size = args.build_mem_batch_size
+ tmp_model.eval()
+ update_dialog_local_qas_to_retrieve(
+ args, self.train_dataset, self.qas_to_retrieve, tmp_model, self.key_memory,
+ self.normed_answer_of_qas_to_ret, train_data_query_embeds=self.train_data_query_embeds,
+ build_mem_batch_size=build_mem_batch_size, query_batch_size=self.query_batch_size,
+ local_size=args.local_size, pos_from_top=args.pos_from_top, neg_from_top=200
+ )
+ del tmp_model
+
+ def train(self):
+ args = self.args
+ tokenizer = self.tokenizer
+ num_workers = 5
+ logging.info("Build Memory")
+ self.update_key_memory()
+
+ train_dataloader = self.train_dataset.get_dataloader(batch_size=args.per_device_train_batch_size,
+ shuffle=True, num_workers=num_workers)
+ optimizer, train_dataloader = self.accelerator.prepare(self.optimizer, train_dataloader)
+
+ # Scheduler and math around the number of training steps.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ else:
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ lr_scheduler = get_scheduler(
+ name=args.lr_scheduler_type,
+ optimizer=optimizer,
+ num_warmup_steps=args.num_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # Train!
+ total_batch_size = args.per_device_train_batch_size * self.accelerator.num_processes \
+ * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(self.train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not self.accelerator.is_local_main_process)
+ completed_steps = 0
+
+ best_score, patience = float("-inf"), args.early_stop_patience
+ best_score_epoch = -1
+ eval_times = 0
+ select_mt = "f1" if self.train_dataset.dataset_name != "eli5_kilt" else "RougeL"
+
+ for epoch in range(args.num_train_epochs):
+ if args.qas_to_retrieve_from == "PAQ" and (epoch % 3 == 0):
+ self.update_local_qas(epoch)
+ elif args.qas_to_retrieve_from != "PAQ":
+ self.update_local_qas(epoch)
+ for step, batch in enumerate(train_dataloader):
+
+ update_batch_inputs(args, batch, self.model)
+ self.model.train()
+ if args.match_weight > 0.0:
+ # Embed Positive Key and the Value to input.
+ embed_dict = self.model.wrapped_embed_kv( # assert num_values > 1, otherwise set compute_value=True
+ separate_task=args.separate_task, compute_key=True, compute_value=False,
+ **batch.pop("positive_kv_inputs")
+ )
+ positive_key_embeds = embed_dict["normed_key_embeds"]
+ positive_key_embeds = reduce_query_or_key_embeds(positive_key_embeds, args.key_reduce_method)
+ # Embed Negative Key
+ embed_dict = self.model.wrapped_embed_kv(
+ separate_task=args.separate_task, compute_key=True, compute_value=False,
+ **batch.pop("negative_kv_inputs")
+ )
+ negative_key_embeds = embed_dict["normed_key_embeds"]
+ negative_key_embeds = reduce_query_or_key_embeds(negative_key_embeds, args.key_reduce_method)
+ else:
+ negative_key_embeds, positive_key_embeds = None, None
+ # Embed retrieved-Key-Value
+ embed_dict = self.model.wrapped_embed_kv(
+ separate_task=args.separate_task, compute_key=True, compute_value=True,
+ **batch.pop("group_value_inputs")
+ )
+ key_embeds_of_value = embed_dict["key_embeds"]
+ value_embeds = embed_dict["value_embeds"]
+ bs = batch["query_input_ids"].shape[0]
+ value_embeds = value_embeds.view(bs, args.num_values, args.prefix_length, -1)
+ key_embeds_of_value = key_embeds_of_value.view(bs, args.num_values, -1, self.model.model_dim)
+
+ loss_dict = self.model.compute_qa_loss(
+ input_ids=batch["history_input_ids"],
+ attention_mask=batch["history_attention_mask"],
+ labels=batch["labels"],
+ decoder_only_attend_on_prefix=False,
+ value_fusion_method=args.value_fusion_method,
+ encoder_outputs_are_key_or_value=False,
+ key_reduce_method=args.key_reduce_method,
+ positive_key_embeds=positive_key_embeds,
+ negative_key_embeds=negative_key_embeds,
+ value_embeds=value_embeds,
+ matching_targets=batch["matching_targets"],
+ key_embeds_of_value=key_embeds_of_value,
+ negative_mask=batch.get("negative_mask", None),
+ )
+                if args.match_weight > 0.0:
+                    loss = args.gen_weight * loss_dict["gen_loss"] + args.match_weight * loss_dict["match_loss"]
+                else:
+                    loss = loss_dict["gen_loss"]
+ loss_dict = {k: v.item() for k, v in loss_dict.items()}
+ loss = loss / args.gradient_accumulation_steps
+ self.accelerator.backward(loss)
+ if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ progress_bar.update(1)
+                    if args.match_weight > 0.0:
+                        progress_bar.set_description(f"GL-[{loss_dict['gen_loss']:.5f}] "
+                                                     f"RL-[{loss_dict['match_loss']:.5f}]")
+                    else:
+                        progress_bar.set_description(f"GL-[{loss_dict['gen_loss']:.5f}]")
+ completed_steps += 1
+ if self.accelerator.is_local_main_process and _has_wandb:
+ wandb.log({"loss": loss * args.gradient_accumulation_steps, "step": completed_steps})
+ wandb.log({"trainable_percentage": batch["trainable_percentage"][0].item(),
+ "forward_step": completed_steps})
+ for k, v in loss_dict.items():
+ wandb.log({k: v, "step": completed_steps})
+
+ if args.eval_every_n_steps is not None and completed_steps % args.eval_every_n_steps == 0:
+ self.update_local_qas(epoch)
+ metric, _, _ = self.evaluate(dataset=self.dev_dataset, extend_mem_from="train")
+ for k, v in metric.items():
+ logger.info(f"eval_times {eval_times} eval - {k}: {v * 100:.5f}")
+ if _has_wandb:
+ for k, v in metric.items():
+ wandb.log({k: v * 100, "eval_times": eval_times})
+
+ if args.output_dir is not None:
+ cur_score = metric[select_mt]
+ if cur_score > best_score:
+ best_score = cur_score
+ best_score_epoch = epoch
+ save_model(self.model, os.path.join(args.output_dir, "best_ckpt"),
+ self.accelerator, tokenizer=tokenizer, arguments=args)
+ patience = args.early_stop_patience
+ else:
+ patience -= 1
+ if patience <= 0:
+ break
+
+ if completed_steps >= args.max_train_steps:
+ break
+
+ if args.output_dir is not None:
+ save_model(self.model, os.path.join(args.output_dir, "latest_ckpt"),
+ self.accelerator, tokenizer=tokenizer, arguments=args)
+
+ if args.update_kv_embeds:
+ logging.info("Update Memory")
+ self.update_key_memory()
+
+ if args.do_eval and epoch % args.eval_freq == 0:
+ # if args.do_eval: self.key_memory is up-to-date.
+ metric, _, _ = self.evaluate(dataset=self.dev_dataset, extend_mem_from="train")
+ for k, v in metric.items():
+ logger.info(f"epoch {epoch} eval - {k}: {v * 100:.5f}")
+ if _has_wandb:
+ for k, v in metric.items():
+ wandb.log({k: v * 100, "epoch": epoch})
+
+ if args.output_dir is not None:
+ cur_score = metric[select_mt]
+ if cur_score > best_score:
+ best_score = cur_score
+                        best_score_epoch = epoch
+ save_model(self.model, os.path.join(args.output_dir, "best_ckpt"),
+ self.accelerator, tokenizer=tokenizer, arguments=args)
+ patience = args.early_stop_patience
+ else:
+ patience -= 1
+ if patience <= 0:
+ break
+
+ logger.info(f"best_f1_dev: {best_score * 100:.5f}")
+ logger.info(f"best_f1 epoch: {best_score_epoch}")
+ if _has_wandb:
+ wandb.log({"best_f1_dev": best_score * 100})
+
+ # do-test
+ best_model_state_dict = os.path.join(args.output_dir, "best_ckpt/pytorch_model.bin")
+ metric, _, _ = self.evaluate(dataset=self.test_dataset, extend_mem_from="train_dev",
+ update_key_memory=True, ckpt_load_path=best_model_state_dict)
+
+ if self.accelerator.is_local_main_process:
+ for k, v in metric.items():
+ logger.info(f"test - {k}: {v * 100:.5f}")
+ if _has_wandb:
+ for k, v in metric.items():
+ wandb.log({f"test_{k}": v * 100})
+
+ @torch.no_grad()
+ def evaluate(self, dataset: DialogDataset = None, extend_mem_from="", update_key_memory=False, ckpt_load_path=None,
+ gen_kwargs=None):
+ # not implement correctly in multi-GPUs
+ tokenizer = self.tokenizer
+ args = self.args
+ self.model.eval()
+ torch.cuda.empty_cache()
+
+ if ckpt_load_path is not None:
+ if args.update_kv_embeds:
+ assert update_key_memory is True
+ self.model.load_state_dict(torch.load(ckpt_load_path), strict=True)
+
+ assert type(self.key_memory) == list
+ if update_key_memory:
+ logging.info("Update Memory")
+ self.update_key_memory()
+
+ dataloader = dataset.get_query_dataloader(batch_size=args.per_device_eval_batch_size,
+ shuffle=False, num_workers=1, add_history=True)
+ dataloader = self.accelerator.prepare(dataloader)
+
+ if gen_kwargs is None:
+ gen_kwargs = {"max_length": args.max_target_length,
+ "num_beams": args.num_beams, }
+
+ torch.cuda.empty_cache()
+
+ all_retrieved_qas = []
+ all_gen_response = []
+ for batch in tqdm(dataloader):
+ embed_dict = self.model.CAT_embed_q(
+ input_ids=batch["query_input_ids"],
+ attention_mask=batch["query_attention_mask"],
+ compute_key=True, compute_value=False
+ )
+ query_embeds = embed_dict["normed_key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method)
+ query_embeds = query_embeds.half()
+
+ # calculate topk in each chunk -> combine all-topk -> select final topk
+ chunk_top_scores = []
+ chunk_top_indices = []
+ idx_shift = 0
+ for chunk_key_memory in self.key_memory:
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_topk = torch.mm(query_embeds, chunk_key_memory_cuda.t()).topk(50, dim=1)
+                chunk_top_scores.append(chunk_topk.values)  # chunk_topk.values: [query_batch, 50]
+ chunk_top_indices.append(chunk_topk.indices + idx_shift)
+ idx_shift += len(chunk_key_memory)
+ del chunk_key_memory_cuda
+ torch.cuda.empty_cache()
+            chunk_top_scores = torch.cat(chunk_top_scores, dim=1)  # [q_batch, 50 * n_chunks]
+            chunk_top_indices = torch.cat(chunk_top_indices, dim=1)  # [q_batch, 50 * n_chunks]
+            topk = chunk_top_scores.topk(args.num_values, dim=1)  # keep the overall top `num_values`
+ top_indices_indices = topk.indices
+ top_indices = []
+ for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices):
+ top_indices.append([cur_indices[idx] for idx in cur_indices_indices])
+ readout_qas = [[self.qas_to_retrieve[idx] for idx in indices] for indices in top_indices]
+
+ value_qas = []
+ for qas in readout_qas:
+ selected_qas = qas[:args.num_values]
+ if not args.values_with_order:
+ random.shuffle(selected_qas)
+ value_qas.append(selected_qas)
+ all_retrieved_qas += readout_qas
+
+ squeezed_value_qas = list(chain(*value_qas))
+ retrieved_qas_inputs = dataset.get_qa_key_value_inputs(squeezed_value_qas, only_return_key_inputs=False)
+ embed_dict = self.model.wrapped_embed_kv(separate_task=args.separate_task, compute_key=True,
+ compute_value=True, **retrieved_qas_inputs)
+ value_embeds = embed_dict["value_embeds"]
+ key_embeds_of_value = embed_dict["key_embeds"]
+ cur_batch_size = query_embeds.shape[0]
+ value_embeds = value_embeds.view(cur_batch_size, args.num_values, args.prefix_length, -1)
+ key_embeds_of_value = key_embeds_of_value.view(cur_batch_size, args.num_values, -1, self.model.model_dim)
+ encoder_outputs = self.model.encoder(
+ input_ids=batch["history_input_ids"],
+ attention_mask=batch["history_attention_mask"],
+ return_dict=True,
+ value_embeds=value_embeds,
+ readout_top_k=-1,
+ key_reduce_method=args.key_reduce_method,
+ value_fusion_method=args.value_fusion_method,
+ key_embeds_of_value=key_embeds_of_value
+ )
+ generated_tokens = self.accelerator.unwrap_model(self.model).generate(
+ encoder_outputs=encoder_outputs,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=False,
+ attention_mask=batch["history_attention_mask"].to(self.model.device),
+ value_fusion_method=args.value_fusion_method,
+ **gen_kwargs,
+ )
+ decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+ decoded_tokens = [ans.strip() for ans in decoded_tokens]
+ all_gen_response += decoded_tokens
+
+ torch.cuda.empty_cache()
+
+ if "normalized_response" in dataset.data[0].keys():
+ if dataset.dataset_name == "wow_kilt":
+ reference = [item["normalized_response"] for item in dataset.data] # only one target
+ else:
+ reference = [item["candidate_responses"] for item in dataset.data] # multi candidates
+ elif "response" in dataset.data[0].keys():
+ reference = [DialogDataset.normalize_answer(item["response"]) for item in dataset.data]
+ else:
+ reference = None
+
+ if reference is not None:
+ assert len(all_gen_response) == len(reference)
+ metrics = dict()
+ if dataset.dataset_name == "eli5_kilt":
+ avg_rougel = sum([max_score_over_ground_truths(rougel_score, pred, ref) for pred, ref in
+ zip(all_gen_response, reference)]) / len(all_gen_response)
+ metrics.update({"RougeL": avg_rougel})
+ reference = [[DialogDataset.normalize_answer(s) for s in cands] for cands in reference]
+ all_gen_response = [DialogDataset.normalize_answer(s) for s in all_gen_response]
+ avg_f1 = sum([max_score_over_ground_truths(f1_score, pred, ref) for pred, ref in
+ zip(all_gen_response, reference)]) / len(all_gen_response)
+ else:
+ reference = [DialogDataset.normalize_answer(s) for s in reference]
+ all_gen_response = [DialogDataset.normalize_answer(s) for s in all_gen_response]
+ bleu_scores = compute_batch_BLEU(reference, all_gen_response)
+ metrics.update(bleu_scores)
+ avg_f1 = sum([f1_score(pred, ref) for pred, ref in
+ zip(all_gen_response, reference)]) / len(all_gen_response)
+ metrics.update({"f1": avg_f1})
+ else:
+ metrics = dict()
+ results = []
+ assert len(dataset.data) == len(all_gen_response) == len(all_retrieved_qas)
+ for input_item, pred, retrieved in zip(dataset.data, all_gen_response, all_retrieved_qas):
+ results.append({
+ "id": input_item["id"],
+ "input": self.tokenizer.decode(input_item["input_ids"]),
+ "query": self.tokenizer.decode(input_item["query_ids"]),
+ "output": {"answer": pred, "provenance": [{"wikipedia_id": "12904"}]},
+ "retrieved_qas": [f"question: {qa['question']} answer: {qa['answer'][0]}" for qa in retrieved]
+ })
+ dump_path = os.path.dirname(ckpt_load_path)
+ dump_path = os.path.join(dump_path, f"{time.strftime('%d %H-%M')}_predict_result.json")
+ # save_path = os.path.join(args.output_dir, "kilt_test_predict.jsonl")
+ write_jsonl(results, dump_path)
+
+ return metrics, all_retrieved_qas, all_gen_response
+
+
+@torch.no_grad()
+def kilt_generate(model, tokenizer, embedding_index, key_memory, value_memory, dataset: DialogDataset,
+ qas_to_retrieve, inference_batch_size, gen_kwargs):
+ model.eval()
+
+ dataloader = dataset.get_query_dataloader(batch_size=inference_batch_size, shuffle=False,
+ num_workers=1, add_history=True)
+ all_retrieved_qas = []
+ all_gen_response = []
+
+ for batch in dataloader:
+ embed_dict = model.CAT_embed_q(
+ input_ids=batch["query_input_ids"].cuda(),
+ attention_mask=batch["query_attention_mask"].cuda(),
+ compute_key=True, compute_value=False
+ )
+ query_embeds = embed_dict["normed_key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, "avg")
+ query_embeds = query_embeds.half()
+ bs = len(batch["query_input_ids"])
+ # calculate topk in each chunk -> combine all-topk -> select final topk
+ chunk_top_scores = []
+ chunk_top_indices = []
+ idx_shift = 0
+ if type(embedding_index) != list:
+ embedding_index = [embedding_index]
+ for chunk_key_memory in embedding_index:
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_topk = torch.mm(query_embeds, chunk_key_memory_cuda.t()).topk(50, dim=1)
+            chunk_top_scores.append(chunk_topk.values)  # chunk_topk.values: [query_batch, 50]
+ chunk_top_indices.append(chunk_topk.indices + idx_shift)
+ idx_shift += len(chunk_key_memory)
+ del chunk_key_memory_cuda
+ torch.cuda.empty_cache()
+        chunk_top_scores = torch.cat(chunk_top_scores, dim=1)  # [q_batch, 50 * n_chunks]
+        chunk_top_indices = torch.cat(chunk_top_indices, dim=1)  # [q_batch, 50 * n_chunks]
+        topk = chunk_top_scores.topk(model.encoder.num_values, dim=1)  # keep the overall top `num_values`
+ top_indices_indices = topk.indices
+ top_indices = []
+ for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices):
+ top_indices.append([cur_indices[idx] for idx in cur_indices_indices])
+ readout_qas = [[qas_to_retrieve[idx] for idx in indices] for indices in top_indices]
+
+ value_qas = []
+ for qas in readout_qas:
+ selected_qas = qas[:model.encoder.num_values]
+ value_qas.append(selected_qas)
+ all_retrieved_qas += readout_qas
+
+ top_indices = torch.tensor(top_indices)
+ memory_size, hidden_num, hidden_size = value_memory.shape
+ value_embeds = torch.index_select(value_memory, 0, top_indices.view(-1)).float().cuda()
+ value_embeds = value_embeds.view(bs, model.encoder.num_values, hidden_num, hidden_size)
+ key_embeds_of_value = torch.index_select(key_memory, 0, top_indices.view(-1)).float().cuda()
+ key_embeds_of_value = key_embeds_of_value.view(bs, model.encoder.num_values, hidden_num, hidden_size)
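+        # Gather the stored key/value embeddings of the retrieved QAs from memory and reshape them to
+        # [bs, num_values, hidden_num, hidden_size] so the encoder can fuse them with the dialogue history.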
+ encoder_outputs = model.encoder(
+ input_ids=batch["history_input_ids"].cuda(),
+ attention_mask=batch["history_attention_mask"].cuda(),
+ return_dict=True,
+ value_embeds=value_embeds,
+ readout_top_k=-1,
+ key_reduce_method="avg",
+ value_fusion_method=model.encoder.value_fusion_method,
+ key_embeds_of_value=key_embeds_of_value
+ )
+ generated_tokens = model.generate(
+ encoder_outputs=encoder_outputs,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=False,
+ attention_mask=batch["history_attention_mask"].to(model.device),
+ value_fusion_method=model.encoder.value_fusion_method,
+ **gen_kwargs,
+ )
+ decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+ decoded_tokens = [ans.strip() for ans in decoded_tokens]
+ all_gen_response += decoded_tokens
+
+ return all_retrieved_qas, all_gen_response
+
+
+if __name__ == '__main__':
+ pass
diff --git a/kvm_pretrain.py b/kvm_pretrain.py
new file mode 100644
index 0000000..b652e4d
--- /dev/null
+++ b/kvm_pretrain.py
@@ -0,0 +1,575 @@
+import argparse
+import logging
+import math
+import os
+import pickle
+import random
+import tempfile
+from functools import partial
+import time
+from itertools import chain
+
+import datasets
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from datasets import load_dataset, load_metric, DatasetDict
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import (
+ AdamW,
+ get_scheduler,
+ set_seed,
+)
+
+from transformers.utils.versions import require_version
+import json
+
+from emat.evaluation.eval_retriever import eval_generation_em
+from emat.t5 import T5WithKeyValueMemory, CATEncoderOutput
+from utils.utils import CATArgs, update_CAT_config_from_args, save_model, get_key_value_encoder_inputs, \
+ get_key_value_ae_target, get_qa_inputs, load_model
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+
+logger = logging.getLogger(__name__)
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
+
+# Initialise wandb
+try:
+ import wandb
+
+    # Do not commit credentials: authenticate via `wandb login` or the WANDB_API_KEY environment variable.
+    wandb.ensure_configured()
+ if wandb.api.api_key is None:
+ _has_wandb = False
+ wandb.termwarn(
+ "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
+ else:
+ _has_wandb = False if os.getenv("WANDB_DISABLED") else True
+except (ImportError, AttributeError):
+ _has_wandb = False
+ logger.info("no WANDB")
+ exit()
+
+DATA_PATHS = {
+ "nq-repaq": {
+ # "train": "repaq_results/nq.train-train.retriever_multi_base_256.jsonl",
+ # "validation": "repaq_results/nq.train-dev.retriever_multi_base_256.jsonl",
+ # "test": "repaq_results/nq.test.retriever_multi_base_256.jsonl"
+ "train": "./cbqa_data/RePAQ/RePAQ-output-NQ-train.jsonl",
+ "validation": "./cbqa_data/RePAQ/RePAQ-output-NQ-dev.jsonl",
+ },
+ "PAQ-L1-Pretrain": {
+ "train": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-train.jsonl",
+ # "validation": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-dev.jsonl"
+ "validation": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-dev-3000.jsonl"
+ },
+ "PAQ-L1-Small": {
+ "train": "./data/cbqa_data/pretrain_data/paq-l1-small-train.jsonl", # 10w examples from PAQ-L1
+ "validation": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-dev-3000.jsonl"
+ },
+ "data_for_debug": {
+ "train": "./data/cbqa_data/pretrain_data/debug.jsonl",
+ "validation": "./data/cbqa_data/pretrain_data/debug.jsonl"
+ }
+}
+PAQ_PATH = "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl"
+PAQ_L1_PATH = "./data/cbqa_data/pretrain_data/PAQ_L1/PAQ_L1.filtered.jsonl"
+
+qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_pickl_file.pkl", 'rb'))
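+# The QA knowledge base used for retrieval (PAQ-L1) is loaded once at import time and shared by the
+# preprocessing functions and the training loop below.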
+
+
+# all_data = load_jsonl("./data/cbqa_data/pretrain_data/paq-l1-pretrain-train.jsonl")
+# train_data = all_data[:100000]
+# write_jsonl(train_data, DATA_PATHS["PAQ-L1-Small"]["train"])
+
+def load_pretrain_kvm_data(args) -> DatasetDict:
+ assert args.pretrain_data_name in DATA_PATHS.keys(), f"available dataset names: {DATA_PATHS.keys()}"
+ data_paths = DATA_PATHS[args.pretrain_data_name]
+ data_files = {
+ "train": [data_paths["train"]],
+ "validation": [data_paths["validation"]]
+ }
+ return load_dataset("json", data_files=data_files)
+
+
+def main():
+ # Parse the arguments
+ # args = parse_args()
+ cat_args = CATArgs(exp_type="pretrain")
+ args = cat_args.parse_args()
+
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+ accelerator = Accelerator()
+
+ # Make one log on every process with the configuration for debugging.
+
+ logging.info(f"wandb {'available' if _has_wandb else 'unavailable'}")
+
+ logger.info(accelerator.state)
+
+ # Setup logging, we only want one process per machine to log things on the screen.
+ # accelerator.is_local_main_process is only True for one process per machine.
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+ accelerator.wait_for_everyone()
+
+ if accelerator.is_local_main_process and _has_wandb:
+ wandb.init(project=args.project_name, name=args.exp_name, dir=args.output_dir, config=vars(args))
+
+ logging.info("loading model")
+ config, tokenizer, model = load_model(T5WithKeyValueMemory, args)
+ logging.info("Loading data.")
+
+ if args.freeze_t5_params:
+ for name, param in model.named_parameters():
+ if 'prefix_embedding' in name or 'key_encoder' in name:
+ param.requires_grad = True
+ else:
+ param.requires_grad = False
+
+ prefix = args.source_prefix if args.source_prefix is not None else ""
+
+ # Temporarily set max_target_length for training.
+ max_target_length = args.max_target_length
+
+ with_ae_task = args.key_ae_weight > 0.0 and args.value_ae_weight > 0.0
+ assert with_ae_task or args.pretrain_multi_values
+
+ def preprocess_pretrain_input_func(examples):
+ # Normal inputs and outputs
+ model_inputs = get_key_value_encoder_inputs(examples, args.separate_task, tokenizer, args.max_source_length,
+ prefix=prefix, only_return_key_inputs=False)
+ model_inputs.update(get_key_value_ae_target(examples, tokenizer, args.key_ae_target, args.value_ae_target,
+ max_target_length))
+
+ return model_inputs
+
+ def preprocess_pretrain_multi_values_input_func(examples, value_with_self_prop=None):
+        # Dataset construction: every example carries its top-k retrieved QAs, and the batch mixes two cases:
+        # 1) the gold answer is contained in the top-k QAs
+        # 2) the top-k QAs do not contain the gold answer
+        # In both cases the generation target is the gold answer.
+ model_inputs = get_qa_inputs(examples, tokenizer, args.max_source_length, max_target_length, prefix=prefix,
+ targets=None)
+ value_qas = []
+ for ex in examples:
+ if random.random() < value_with_self_prop:
+ selected_values_indices = list(ex["retrieved_PAQL1_indices"][:args.num_values])
+ else:
+ selected_values_indices = list(ex["retrieved_PAQL1_indices"][1:args.num_values + 1])
+ if not args.values_with_order:
+ random.shuffle(selected_values_indices)
+ selected_values = [qas_to_retrieve[idx] for idx in selected_values_indices]
+ value_qas.append(selected_values)
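+            # With probability `value_with_self_prop` the slice starts at index 0, which is presumably the
+            # example's own QA; otherwise only neighbouring QAs (indices 1..num_values) fill the value slots,
+            # so the gold answer may be absent from them.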
+ value_qas = list(chain(*value_qas)) # bs * num_values
+ group_value_inputs = get_key_value_encoder_inputs(value_qas, args.separate_task, tokenizer,
+ args.max_source_length, prefix=prefix,
+ value_input_is_qa=args.value_input_is_qa)
+ for dk in group_value_inputs:
+ model_inputs[f"group_value_inputs_{dk}"] = group_value_inputs[dk]
+
+ model_inputs.update(
+ get_key_value_encoder_inputs(examples, args.separate_task, tokenizer, args.max_source_length,
+ prefix=prefix, only_return_key_inputs=False,
+ value_input_is_qa=args.value_input_is_qa))
+
+ if with_ae_task:
+ model_inputs.update(get_key_value_ae_target(examples, tokenizer, args.key_ae_target,
+ args.value_ae_target, max_target_length))
+
+ return model_inputs
+
+ # Evaluation metric: load the custom exact-match metric
+
+ # Load the training/validation datasets
+ if args.pretrain_multi_values:
+ if "debug" not in args.exp_name:
+ pretrain_data_path = "./tmp/PAQ_L1_with_50_xlarge_retrieved_qa_indices.pkl"
+ # "tmp/PAQ_L1_with_retrieved_qa_indices.pkl"
+ raw_datasets = pickle.load(open(pretrain_data_path, "rb"))
+ train_dataset = raw_datasets[:-5000]
+ eval_dataset = raw_datasets[-5000:]
+ else:
+ raw_datasets = pickle.load(open("./tmp/PAQ_L1_5k_with_retrieved_qa_indices.pkl", 'rb'))
+ train_dataset = raw_datasets[:100]
+ eval_dataset = raw_datasets[:100]
+ collate_fn = partial(preprocess_pretrain_multi_values_input_func,
+ value_with_self_prop=args.value_with_self_prop)
+ else:
+ raw_datasets = load_pretrain_kvm_data(args)
+ train_dataset = raw_datasets["train"]
+ eval_dataset = raw_datasets["validation"]
+ collate_fn = preprocess_pretrain_input_func
+ # DataLoaders creation:
+ train_dataloader = DataLoader(train_dataset, shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.per_device_train_batch_size,
+ num_workers=args.preprocessing_num_workers)
+ eval_dataloader = DataLoader(eval_dataset,
+ collate_fn=collate_fn,
+ batch_size=args.per_device_eval_batch_size,
+ num_workers=args.preprocessing_num_workers)
+
+ without_self_eval_dataloader = DataLoader(eval_dataset,
+ collate_fn=partial(preprocess_pretrain_multi_values_input_func,
+ value_with_self_prop=0.0),
+ batch_size=args.per_device_eval_batch_size,
+ num_workers=args.preprocessing_num_workers)
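+    # `without_self_eval_dataloader` fixes value_with_self_prop=0.0, i.e. the top-ranked retrieved QA
+    # (presumably the example itself) never fills a value slot, so it measures answering from
+    # neighbouring memories alone.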
+
+ if not args.do_train and args.do_eval:
+ model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
+ metric_key, metric_value, all_gen_ans = evaluate(args, model, config, eval_dataloader, accelerator, tokenizer)
+ if args.train_key:
+ key_em_score = metric_key.compute()["em"] * 100 # EM score is not in percentage points
+ logger.info(f"eval - Key-EM: {key_em_score:.2f}")
+ if args.train_value:
+ value_em_score = metric_value.compute()["em"] * 100 # EM score is not in percentage points
+ logger.info(f"eval - Value-EM: {value_em_score:.2f}")
+
+ em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100
+ logger.info(f"em_test: {em_score:.2f}")
+ exit()
+
+ if args.do_train:
+ # Log a few random samples from the training set:
+ for index in random.sample(range(len(train_dataset)), 3):
+ logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+ # Optimizer
+ # Split weights in two groups, one with weight decay and the other not.
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": args.weight_decay, },
+ {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0, },
+ ]
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+
+ # Prepare everything with our `accelerator`.
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
+ model, optimizer, train_dataloader, eval_dataloader
+ )
+
+ # Scheduler and math around the number of training steps.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ else:
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ lr_scheduler = get_scheduler(
+ name=args.lr_scheduler_type,
+ optimizer=optimizer,
+ num_warmup_steps=args.num_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # Train!
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+ completed_steps = 0
+ eval_times = 0
+
+ best_em, patience = None, args.early_stop_patience
+ for epoch in range(args.num_train_epochs):
+ model.train()
+ for step, batch in enumerate(train_dataloader):
+ assert args.separate_task
+ loss = 0.
+ loss_dict = dict()
+ if args.pretrain_multi_values:
+ group_inputs = {k.replace("group_value_inputs_", ""): v for k, v in batch.items() if
+ k.startswith("group_value_inputs_")}
+ embed_dict = model.wrapped_embed_kv(
+ separate_task=args.separate_task, compute_key=True, compute_value=True, **group_inputs
+ )
+ key_embeds_of_value = embed_dict["key_embeds"]
+ value_embeds = embed_dict["value_embeds"]
+ bs = batch["input_ids"].shape[0]
+ value_embeds = value_embeds.view(bs, args.num_values, args.prefix_length, -1)
+ key_embeds_of_value = key_embeds_of_value.view(bs, args.num_values, -1, model.model_dim)
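+                # The num_values QA pairs per example were flattened into one (bs * num_values) batch for the
+                # key-value encoder; reshape them back per example before computing the QA loss.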
+
+ loss_dict_gen = model.compute_qa_loss(
+ input_ids=batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"],
+ decoder_only_attend_on_prefix=args.decoder_only_attend_on_prefix,
+ value_fusion_method=args.value_fusion_method,
+ encoder_outputs_are_key_or_value=False,
+ key_reduce_method=args.key_reduce_method,
+ value_embeds=value_embeds,
+ key_embeds_of_value=key_embeds_of_value
+ )
+ loss += args.gen_weight * loss_dict_gen["gen_loss"]
+ loss_dict.update({k: v for k, v in loss_dict_gen.items()})
+ if with_ae_task:
+ loss_dict_ae = model.compute_key_value_ae_loss(
+ train_key=args.train_key,
+ train_value=args.train_value,
+ separate_task=True,
+ key_input_ids=batch["key_input_ids"],
+ key_attention_mask=batch["key_attention_mask"],
+ value_input_ids=batch["value_input_ids"],
+ value_attention_mask=batch["value_attention_mask"],
+ key_labels_input_ids=batch["key_labels_input_ids"],
+ value_labels_input_ids=batch["value_labels_input_ids"],
+ )
+ if args.train_key:
+ loss += args.key_ae_weight * loss_dict_ae["key_ae_loss"]
+ if args.train_value:
+ loss += args.value_ae_weight * loss_dict_ae["value_ae_loss"]
+ loss_dict.update({k: v.item() for k, v in loss_dict_ae.items()})
+
+ loss = loss / args.gradient_accumulation_steps
+ accelerator.backward(loss)
+ if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ progress_bar.update(1)
+ completed_steps += 1
+
+ if accelerator.is_local_main_process and _has_wandb:
+ wandb.log({"loss": loss * args.gradient_accumulation_steps, "step": completed_steps})
+ for k, v in loss_dict.items():
+ wandb.log({k: v, "step": completed_steps})
+
+                if completed_steps % 5555 == 0:  # periodic in-epoch evaluation every 5555 optimisation steps
+ metric_key, metric_value, all_gen_ans = evaluate(
+ args, model, config, eval_dataloader, accelerator, tokenizer)
+
+ scores = []
+ if args.train_key:
+ key_em_score = metric_key.compute()["em"] * 100 # EM score is not in percentage points
+ logger.info(f"epoch {epoch} eval-time {eval_times} - Key-EM: {key_em_score:.2f}")
+ if accelerator.is_local_main_process and _has_wandb:
+ wandb.log({"key_em_dev": key_em_score, "eval_times": eval_times})
+ scores.append(key_em_score)
+ if args.train_value:
+ value_em_score = metric_value.compute()["em"] * 100 # EM score is not in percentage points
+ logger.info(f"epoch {epoch} eval - Value-EM: {value_em_score:.2f}")
+ if accelerator.is_local_main_process and _has_wandb:
+ wandb.log({"value_em_dev": value_em_score, "eval_times": eval_times})
+ scores.append(value_em_score)
+
+ em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100
+ logger.info(f"em_test: {em_score:.2f}")
+ wandb.log({"em_test": em_score, "eval_times": eval_times})
+ if args.output_dir is not None:
+ if best_em is None or em_score > best_em:
+ best_em = em_score
+ save_model(model, os.path.join(args.output_dir, "best_ckpt"), accelerator,
+ tokenizer=tokenizer, arguments=args)
+
+ # without self eval
+ _, _, all_gen_ans = evaluate(args, model, config, without_self_eval_dataloader,
+ accelerator, tokenizer)
+ em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100
+ logger.info(f"w/o_self_em_test: {em_score:.2f}")
+ wandb.log({"w/o_self_em_test": em_score, "eval_times": eval_times})
+
+ eval_times += 1
+
+ if completed_steps >= args.max_train_steps:
+ break
+
+ if args.output_dir is not None:
+ save_model(model, os.path.join(args.output_dir, "latest_ckpt"), accelerator, tokenizer=tokenizer,
+ arguments=args)
+
+ if args.do_eval and epoch % args.eval_freq == 0:
+ metric_key, metric_value, all_gen_ans = evaluate(args, model, config, eval_dataloader, accelerator,
+ tokenizer)
+
+ scores = []
+ if args.train_key:
+ key_em_score = metric_key.compute()["em"] * 100 # EM score is not in percentage points
+ logger.info(f"epoch {epoch} eval - Key-EM: {key_em_score:.2f}")
+ if accelerator.is_local_main_process and _has_wandb:
+ wandb.log({"key_em_dev": key_em_score, "epoch": epoch})
+ scores.append(key_em_score)
+ if args.train_value:
+ value_em_score = metric_value.compute()["em"] * 100 # EM score is not in percentage points
+ logger.info(f"epoch {epoch} eval - Value-EM: {value_em_score:.2f}")
+ if accelerator.is_local_main_process and _has_wandb:
+ wandb.log({"value_em_dev": value_em_score, "epoch": epoch})
+ scores.append(value_em_score)
+
+ em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100
+ logger.info(f"em_test: {em_score:.2f}")
+ wandb.log({"em_test": em_score, "epoch": epoch})
+ if args.output_dir is not None:
+ if best_em is None or em_score > best_em:
+ best_em = em_score
+ save_model(model, os.path.join(args.output_dir, "best_ckpt"), accelerator,
+ tokenizer=tokenizer,
+ arguments=args)
+
+ # without self eval
+ _, _, all_gen_ans = evaluate(args, model, config, without_self_eval_dataloader,
+ accelerator, tokenizer)
+ em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100
+ logger.info(f"w/o_self_em_test: {em_score:.2f}")
+ wandb.log({"w/o_self_em_test": em_score, "epoch": epoch})
+
+ if best_em is not None: # Log the best dev EM score
+ wandb.log({"best_em_dev": best_em})
+
+
+def postprocess_text(preds, labels):
+ preds = [pred.strip() for pred in preds]
+ labels = [[label.strip()] for label in labels]
+ return preds, labels
+
+
+@torch.no_grad()
+def evaluate(args, model, config, eval_dataloader, accelerator, tokenizer):
+ metric_value = load_metric("emat/evaluation/exact_match.py")
+ metric_key = load_metric("emat/evaluation/exact_match.py")
+ if args.val_max_target_length is None:
+ args.val_max_target_length = args.max_target_length
+
+ gen_kwargs = {
+ "max_length": args.val_max_target_length if args is not None else config.max_length,
+ "num_beams": args.num_beams,
+ }
+ all_gen_ans = []
+
+ for batch in tqdm(eval_dataloader):
+ model.eval()
+ group_inputs = {k.replace("group_value_inputs_", ""): v for k, v in batch.items() if
+ k.startswith("group_value_inputs_")}
+ embed_dict = model.wrapped_embed_kv(
+ separate_task=args.separate_task, compute_key=True, compute_value=True, **group_inputs
+ )
+ key_embeds_of_value = embed_dict["key_embeds"]
+ value_embeds = embed_dict["value_embeds"]
+ bs = batch["input_ids"].shape[0]
+ value_embeds = value_embeds.view(bs, args.num_values, args.prefix_length, -1)
+ key_embeds_of_value = key_embeds_of_value.view(bs, args.num_values, -1, model.model_dim)
+ encoder_outputs = model.encoder(
+ input_ids=batch["input_ids"].to(model.device),
+ attention_mask=batch["attention_mask"].to(model.device),
+ return_dict=True,
+ value_embeds=value_embeds,
+ readout_top_k=-1,
+ key_reduce_method=args.key_reduce_method,
+ value_fusion_method=args.value_fusion_method,
+ key_embeds_of_value=key_embeds_of_value
+ )
+ generated_tokens = accelerator.unwrap_model(model).generate(
+ encoder_outputs=encoder_outputs,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=args.decoder_only_attend_on_prefix,
+ attention_mask=batch["attention_mask"].to(model.device),
+ value_fusion_method=args.value_fusion_method,
+ **gen_kwargs,
+ )
+ generated_tokens = accelerator.pad_across_processes(
+ generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+ )
+ generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
+ decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+ decoded_tokens = [ans.strip() for ans in decoded_tokens]
+ all_gen_ans += decoded_tokens
+
+ # Auto-Encoding loss
+ embed_dict = model.wrapped_embed_kv(
+ separate_task=args.separate_task,
+ key_input_ids=batch["key_input_ids"],
+ key_attention_mask=batch["key_attention_mask"],
+ value_input_ids=batch["value_input_ids"],
+ value_attention_mask=batch["value_attention_mask"],
+ compute_key=args.train_key,
+ compute_value=args.train_value,
+ embed_for_ae_task=True
+ )
+ key_embeds = embed_dict["normed_key_embeds"]
+ value_embeds = embed_dict["normed_value_embeds"] # normed value for generation
+
+ if args.train_key:
+ key_labels = batch["key_labels_input_ids"]
+ key_embeds = key_embeds.view(key_embeds.shape[0], -1, model.model_dim)
+
+ recovered_from_key = accelerator.unwrap_model(model).generate(
+ encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None),
+ attention_mask=None, encoder_outputs_are_key_or_value=True, **gen_kwargs,
+ )
+ recovered_from_key = accelerator.pad_across_processes(
+ recovered_from_key, dim=1, pad_index=tokenizer.pad_token_id
+ )
+ if not args.pad_to_max_length:
+ # If we did not pad to max length, we need to pad the labels too
+ key_labels = accelerator.pad_across_processes(key_labels, dim=1, pad_index=tokenizer.pad_token_id)
+ recovered_from_key = accelerator.gather(recovered_from_key).cpu().numpy()
+ key_labels = accelerator.gather(key_labels).cpu().numpy()
+ if args.ignore_pad_token_for_loss:
+ # Replace -100 in the labels as we can't decode them.
+ key_labels = np.where(key_labels != -100, key_labels, tokenizer.pad_token_id)
+ decoded_key = tokenizer.batch_decode(recovered_from_key, skip_special_tokens=True)
+ decoded_key_labels = tokenizer.batch_decode(key_labels, skip_special_tokens=True)
+ decoded_key, decoded_key_labels = postprocess_text(decoded_key, decoded_key_labels)
+ metric_key.add_batch(predictions=decoded_key, references=decoded_key_labels)
+
+ if args.train_value:
+ value_labels = batch["value_labels_input_ids"]
+ value_embeds = value_embeds.view(value_embeds.shape[0], -1, model.model_dim)
+ recovered_from_value = accelerator.unwrap_model(model).generate(
+ encoder_outputs=CATEncoderOutput(last_hidden_state=value_embeds, hidden_states=None, attentions=None),
+ attention_mask=None, encoder_outputs_are_key_or_value=True, **gen_kwargs,
+ )
+ recovered_from_value = accelerator.pad_across_processes(
+ recovered_from_value, dim=1, pad_index=tokenizer.pad_token_id
+ )
+ if not args.pad_to_max_length:
+ # If we did not pad to max length, we need to pad the labels too
+ value_labels = accelerator.pad_across_processes(value_labels, dim=1, pad_index=tokenizer.pad_token_id)
+ recovered_from_value = accelerator.gather(recovered_from_value).cpu().numpy()
+ value_labels = accelerator.gather(value_labels).cpu().numpy()
+ if args.ignore_pad_token_for_loss:
+ # Replace -100 in the labels as we can't decode them.
+ value_labels = np.where(value_labels != -100, value_labels, tokenizer.pad_token_id)
+ decoded_value = tokenizer.batch_decode(recovered_from_value, skip_special_tokens=True)
+ decoded_value_labels = tokenizer.batch_decode(value_labels, skip_special_tokens=True)
+ decoded_value, decoded_value_labels = postprocess_text(decoded_value, decoded_value_labels)
+ metric_value.add_batch(predictions=decoded_value, references=decoded_value_labels)
+
+ return metric_key, metric_value, all_gen_ans
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pretrain_scripts/pretrain_emat.sh b/pretrain_scripts/pretrain_emat.sh
new file mode 100644
index 0000000..dbcc661
--- /dev/null
+++ b/pretrain_scripts/pretrain_emat.sh
@@ -0,0 +1,50 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+EXP_NAME="KL=3;kdim=1536;VL=7;VN=10;async_cat_k+v;t5-base;"
+LOAD_EXP_NAME="t5-base" # t5-base dir
+DATA_NAME="PAQ-L1-Pretrain"
+DEVICE="0"
+
+echo ${DEVICE}
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python kvm_pretrain.py \
+ --project_name="CAT" \
+ --pretrain_multi_values \
+ --exp_name="${EXP_NAME}" \
+ --pretrain_data_name=${DATA_NAME} \
+ --num_values=10 \
+ --value_layer=7 \
+ --key_layer=3 \
+ --value_fusion_method="cat_k_delay+v" \
+ --key_reduce_method="avg" \
+ --model_name_or_path=${LOAD_EXP_NAME} \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=128 \
+ --per_device_eval_batch_size=256 \
+ --gradient_accumulation_steps=2 \
+ --preprocessing_num_workers=10 \
+ --learning_rate=5e-5 \
+ --num_train_epochs=5 \
+ --lr_scheduler_type="linear" \
+ --num_warmup_steps=5000 \
+ --output_dir="./outputs/checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_encoder_type="conv" \
+ --seed=42 \
+ --gen_weight=1.0 \
+ --key_ae_weight=0.5 \
+ --value_ae_weight=0.5 \
+ --value_with_self_prop=0.1 \
+ --key_ae_target="question" \
+ --value_ae_target="ans" \
+ --do_train \
+ --separate_task \
+ --train_key \
+ --train_value
diff --git a/pretrain_scripts/pretrain_sksv_emat.sh b/pretrain_scripts/pretrain_sksv_emat.sh
new file mode 100644
index 0000000..5b3e05e
--- /dev/null
+++ b/pretrain_scripts/pretrain_sksv_emat.sh
@@ -0,0 +1,52 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+LOAD_EXP_NAME="t5-base" # t5-base dir
+EXP_NAME="KL=3;kdim=1536;CL=10;VL=11;VN=10;async_cat_k+v;t5-base;"
+DATA_NAME="PAQ-L1-Pretrain"
+DEVICE="0"
+
+echo ${DEVICE}
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python kvm_pretrain.py \
+ --key_layer=3 \
+ --cat_layer=10 \
+ --value_layer=11 \
+ --value_fusion_method="async_cat_k_delay+v" \
+ --project_name="CAT" \
+ --pretrain_multi_values \
+ --exp_name="${EXP_NAME}" \
+ --pretrain_data_name=${DATA_NAME} \
+ --num_values=10 \
+ --key_reduce_method="avg" \
+ --model_name_or_path="${LOAD_EXP_NAME}" \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=128 \
+ --per_device_eval_batch_size=256 \
+ --gradient_accumulation_steps=2 \
+ --preprocessing_num_workers=10 \
+ --learning_rate=5e-5 \
+ --num_train_epochs=5 \
+ --lr_scheduler_type="constant" \
+ --num_warmup_steps=5000 \
+ --output_dir="./outputs/checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_encoder_type="conv" \
+ --seed=42 \
+ --gen_weight=1.0 \
+ --key_ae_weight=0.5 \
+ --value_ae_weight=0.5 \
+ --value_with_self_prop=0.1 \
+ --key_ae_target="question" \
+ --value_ae_target="ans" \
+ --do_train \
+ --separate_task \
+ --train_key \
+ --train_value \
+ --do_eval
diff --git a/qa_dataset.py b/qa_dataset.py
new file mode 100644
index 0000000..fbd14bd
--- /dev/null
+++ b/qa_dataset.py
@@ -0,0 +1,195 @@
+from itertools import chain
+import torch
+from torch.utils.data import Dataset, DataLoader
+from typing import List, Dict
+import random
+
+from emat.evaluation.exact_match import normalize_answer
+from utils.utils import process_labels
+# from transformers.models.t5 import T5Tokenizer
+from transformers import T5Tokenizer
+
+
+def format_data(item, label2str, str2label, dataset_name, task):
+ if dataset_name == "commonsense_qa":
+ pass
+
+
+class QADataset(Dataset):
+ def __init__(
+ self,
+ data: List[Dict],
+ tokenizer: T5Tokenizer,
+ qas_to_retrieve,
+ dataset_name,
+ retrieve_strategy="dr",
+ max_source_length=None,
+ args=None,
+ normed_answer_of_qas_to_ret=None,
+ ):
+ super(QADataset, self).__init__()
+ self.data: List[Dict] = data
+
+ for idx, i in enumerate(self.data):
+ i["idx"] = idx
+ i["normalized_answer"] = [normalize_answer(ans) for ans in i["answer"]]
+ assert dataset_name in ["nq", "tq", "wq"]
+ self.max_source_length = max_source_length if max_source_length is not None else 430
+ self.dataset_name = dataset_name
+ print(f"dataset-name: {dataset_name}")
+ self.max_target_length = 64
+ self.tokenizer = tokenizer
+ self.pad_idx = self.tokenizer.pad_token_id
+ self.label_pad_idx = -100
+ self.args = args
+ self.add_ae_input = False
+ self.qas_to_retrieve = qas_to_retrieve
+ self.normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret
+
+ self.pad_qa = {"question": "", "answer": [""]}
+
+ def get_key_value_inputs(self, qas, only_return_key_inputs=False):
+ # Used to get the input of Key-Value Encoder, qas are from PAQ-L1
+ key_inputs = ["question: " + qa["question"] for qa in qas]
+ key_inputs = self.tokenizer(key_inputs, max_length=self.max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ if only_return_key_inputs:
+ return {"key_input_ids": key_inputs["input_ids"],
+ "key_attention_mask": key_inputs["attention_mask"]}
+ else:
+ value_inputs = ["answer: " + qa["answer"][0] for qa in qas]
+ value_inputs = self.tokenizer(value_inputs, max_length=self.max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"key_input_ids": key_inputs["input_ids"],
+ "key_attention_mask": key_inputs["attention_mask"],
+ "value_input_ids": value_inputs["input_ids"],
+ "value_attention_mask": value_inputs["attention_mask"]}
+
+ def get_query_inputs(self, batch):
+ query_inputs = ["question: " + qa["question"] for qa in batch]
+ query_inputs = self.tokenizer(query_inputs, max_length=self.max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"query_input_ids": query_inputs["input_ids"],
+ "query_attention_mask": query_inputs["attention_mask"]}
+
+ def get_dataloader(self, batch_size, shuffle, num_workers):
+
+ def base_collate_fn(batch):
+
+ original_batch_size, filtered_batch_size = len(batch), len(batch)
+ if not self.args.use_not_exactly_true:
+ batch = [ex for ex in batch if len(ex["local_positive"]) > 0]
+ filtered_batch_size = len(batch)
+ while len(batch) == 0: # avoid empty-batch
+ batch = random.sample(self.data, batch_size)
+ batch = [ex for ex in batch if len(ex["local_positive"]) > 0]
+                # do not change filtered_batch_size even if the batch has to be resampled.
+
+ model_inputs = {
+ "batch_data_ids": torch.tensor([qa["idx"] for qa in batch]),
+ "trainable_percentage": torch.tensor(filtered_batch_size / original_batch_size).repeat(len(batch)),
+                # repeated ``len(batch)`` times so the tensor stays compatible with multi-GPU splitting.
+ }
+ model_inputs.update(self.get_query_inputs(batch))
+
+ batch_local_positive_num = self.args.batch_local_positive_num
+ neg_num_each_example = self.args.negatives_num_each_example
+ local_positive_qas = []
+ local_positive_num = []
+ local_positive_qas_mask = []
+ local_negative_qas = []
+            local_pos_mix_neg_qas = []  # positives topped up with negatives, length = neg_num_each_example
+ for ex in batch:
+ cur_local_positive_qas_ids = [idx for idx in ex["local_positive"][:batch_local_positive_num]]
+ cur_local_positive_qas = [self.qas_to_retrieve[idx] for idx in cur_local_positive_qas_ids]
+ cur_pos_num = len(cur_local_positive_qas)
+ local_positive_num.append(cur_pos_num)
+
+ cur_local_negative_qas_idx = random.sample(ex["local_negative"], neg_num_each_example)
+ cur_local_negative_qas = [self.qas_to_retrieve[idx] for idx in cur_local_negative_qas_idx]
+ local_negative_qas.append(cur_local_negative_qas)
+ cur_local_pos_mix_neg_qas = cur_local_positive_qas + \
+ cur_local_negative_qas[:neg_num_each_example - cur_pos_num]
+ local_pos_mix_neg_qas.append(cur_local_pos_mix_neg_qas)
+
+ cur_pad_num = batch_local_positive_num - cur_pos_num
+ cur_local_positive_qas_mask = [1] * cur_pos_num + [0] * cur_pad_num
+ local_positive_qas_mask.append(cur_local_positive_qas_mask)
+ cur_local_positive_qas.extend([self.pad_qa] * cur_pad_num)
+ local_positive_qas.append(cur_local_positive_qas)
+
+ model_inputs.update({"local_positive_qas_mask": torch.tensor(local_positive_qas_mask),
+ "local_positive_num": torch.tensor(local_positive_num), })
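+            # Positives are padded to `batch_local_positive_num` with a dummy QA and masked out; negatives are
+            # sampled per example, and the mixed list tops positives up with negatives to a fixed length so
+            # that every field batches into rectangular tensors.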
+ if self.dataset_name == "tq" or self.dataset_name == "wq":
+ squeezed_positive_qas = list(chain(*local_positive_qas))
+ squeezed_positive_target = [qa["answer"][0] for qa in squeezed_positive_qas]
+ with self.tokenizer.as_target_tokenizer():
+ targets = self.tokenizer(squeezed_positive_target, max_length=self.max_target_length,
+ padding=True, truncation=True, return_tensors="pt")
+ model_inputs["labels_to_select"] = process_labels(targets, self.tokenizer). \
+ view(len(batch), batch_local_positive_num, -1)
+ else:
+ targets = [random.choice(qa["answer"]) for qa in batch]
+ with self.tokenizer.as_target_tokenizer():
+ targets = self.tokenizer(targets, max_length=self.max_target_length,
+ padding=True, truncation=True, return_tensors="pt")
+ model_inputs["labels"] = process_labels(targets, self.tokenizer)
+
+ assert self.args.select_positive_strategy == "softmax_sample"
+ squeezed_positive_qas = list(chain(*local_positive_qas))
+ local_positive_inputs = self.get_key_value_inputs(squeezed_positive_qas, only_return_key_inputs=True)
+ model_inputs.update({f"local_positive_inputs_{k}": v.view(len(batch), batch_local_positive_num, -1)
+ for k, v in local_positive_inputs.items()})
+
+ squeezed_negative_qas = list(chain(*local_negative_qas))
+ local_negative_inputs = self.get_key_value_inputs(squeezed_negative_qas, only_return_key_inputs=True)
+ model_inputs.update({f"local_negative_inputs_{k}": v.view(len(batch), neg_num_each_example, -1)
+ for k, v in local_negative_inputs.items()})
+
+ squeezed_mixed_qas = list(chain(*local_pos_mix_neg_qas))
+ local_mixed_inputs = self.get_key_value_inputs(squeezed_mixed_qas)
+ model_inputs.update({f"local_mixed_inputs_{k}": v.view(len(batch), neg_num_each_example, -1)
+ for k, v in local_mixed_inputs.items()})
+ if self.dataset_name == "tq":
+ all_targets = [[normalize_answer(an) for an in qa["answer"]] for qa in batch]
+ negative_qas_answer = [normalize_answer(nqa["answer"][0]) for nqa in squeezed_negative_qas]
+ negative_mask = [[1 if neg_ans not in cur_all_target else 0 for neg_ans in negative_qas_answer]
+ for cur_all_target in all_targets]
+ model_inputs.update({"negative_mask": torch.tensor(negative_mask)})
+
+ # for multi-GPUs
+ assert all(model_inputs[k].shape[0] == len(batch) for k in model_inputs.keys())
+
+ return model_inputs
+
+ return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
+ collate_fn=base_collate_fn, pin_memory=True)
+
+ def get_query_dataloader(self, batch_size, shuffle, num_workers):
+
+ def query_collate_fn(batch):
+ return self.get_query_inputs(batch)
+
+ return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
+ collate_fn=query_collate_fn, pin_memory=True)
+
+ def get_local_qas_dataloader(self, batch_size, shuffle, num_workers):
+
+ def local_qas_collate_fn(batch):
+ # model_inputs = self.get_query_inputs(batch)
+ local_qas = [[self.qas_to_retrieve[qid] for qid in ex['local_qas']] for ex in batch]
+ query_ids = [ex["idx"] for ex in batch]
+ squeezed_local_qas = list(chain(*local_qas))
+ squeezed_local_qas_inputs = self.get_key_value_inputs(squeezed_local_qas, only_return_key_inputs=True)
+ # model_inputs.update(squeezed_local_qas_inputs)
+ # return model_inputs
+ return {**squeezed_local_qas_inputs, "query_ids": torch.tensor(query_ids)}
+
+ return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
+ collate_fn=local_qas_collate_fn, pin_memory=True)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, item):
+ return self.data[item]
diff --git a/qa_main.py b/qa_main.py
new file mode 100644
index 0000000..b16cdb1
--- /dev/null
+++ b/qa_main.py
@@ -0,0 +1,145 @@
+import json
+import os
+import pickle
+from transformers import T5Tokenizer
+from emat.utils import load_jsonl
+from utils.utils import CATArgs
+from qa_dataset import QADataset
+from qa_trainer import QATrainer
+import logging
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+
+DATA_PATHS = {
+ "nq": {
+ "train": "./data/annotated_datasets/NQ-open.train-train.jsonl",
+ "validation": "./data/annotated_datasets/NQ-open.train-dev.jsonl",
+ "test": "./data/annotated_datasets/NQ-open.test.jsonl"
+ },
+ "tq": {
+ "train": "./data/annotated_datasets/triviaqa.train-train.jsonl",
+ "validation": "./data/annotated_datasets/triviaqa.train-dev.jsonl",
+ "test": "./data/annotated_datasets/triviaqa.test.jsonl"
+ },
+ "wq": {
+ "train": "./data/annotated_datasets/WQ-trainmodel.jsonl",
+ "validation": "./data/annotated_datasets/WQ-val.jsonl",
+ "test": "./data/annotated_datasets/WQ-test.jsonl"
+ }
+}
+QA_KB_PATHS = {
+ "PAQ_L1": "./tmp/PAQ_L1_pickl_file.pkl",
+ "PAQ": "./tmp/PAQ_full.pkl",
+ "TAQ_TRAIN_NQ_TRAIN_PAQ": "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl",
+}
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
+
+def load_dataset(args):
+ assert args.qa_data_name in DATA_PATHS.keys(), f"available dataset names: {DATA_PATHS.keys()}"
+ data_paths = DATA_PATHS[args.qa_data_name]
+ test_data = load_jsonl(DATA_PATHS[args.qa_data_name]["test"])
+ train_data = load_jsonl(data_paths["train"])
+ dev_data = load_jsonl(data_paths["validation"])
+
+ logging.info("loading normed answer of qas to retrieve")
+ if "PAQ" == args.qas_to_retrieve_from:
+ normed_answer_of_qas_to_ret = pickle.load(open("./tmp/PAQ_only_normalized_answer.pkl", 'rb'))
+ else:
+ normed_answer_of_qas_to_ret = json.load(open("./tmp/PAQL1_only_normalized_answer.json", 'r'))
+
+ logging.info("loading qas to retrieve")
+ if "debug" in args.exp_name.lower() or "full-paq-test" in args.exp_name.lower():
+ if not os.path.exists("./tmp/PAQ_L1_small.pkl"):
+ qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_pickl_file.pkl", 'rb'))
+ qas_to_retrieve = qas_to_retrieve[:len(qas_to_retrieve) // 14]
+ pickle.dump(qas_to_retrieve, open("./tmp/PAQ_L1_small.pkl", 'wb'))
+ else:
+ qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_small.pkl", 'rb'))
+ else:
+ qas_to_retrieve_fp = QA_KB_PATHS[args.qas_to_retrieve_from]
+ logging.info(f"loading qas from {qas_to_retrieve_fp}")
+ if qas_to_retrieve_fp.endswith("pkl"):
+ qas_to_retrieve = pickle.load(open(qas_to_retrieve_fp, 'rb'))
+ elif qas_to_retrieve_fp.endswith("jsonl"):
+ qas_to_retrieve = load_jsonl(qas_to_retrieve_fp)
+ else:
+ raise ValueError(f"{qas_to_retrieve_fp}")
+
+ if "debug" in args.exp_name.lower():
+ train_data = train_data[:100]
+ dev_data = dev_data[:500]
+ qas_to_retrieve = qas_to_retrieve[:10000]
+ normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:len(qas_to_retrieve)]
+
+ if args.qas_to_retrieve_from == "PAQ" and args.PAQ_size is not None:
+ qas_to_retrieve = qas_to_retrieve[:args.PAQ_size]
+ normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:args.PAQ_size]
+ assert len(qas_to_retrieve) == args.PAQ_size
+ logging.info(f"select {args.PAQ_size}-size PAQ.")
+
+ assert len(normed_answer_of_qas_to_ret) == len(qas_to_retrieve)
+ loaded_data = {
+ "train": train_data, "validation": dev_data, "test": test_data,
+ "qas_to_retrieve": qas_to_retrieve,
+ "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret
+ }
+
+ return loaded_data
+
+
+def main():
+ cat_args = CATArgs("qa_cat")
+ args = cat_args.parse_args()
+ loaded_data = load_dataset(args)
+ logging.info("data loaded.")
+ train_data, dev_data, test_data = loaded_data["train"], loaded_data["validation"], loaded_data["test"]
+ qas_to_retrieve = loaded_data["qas_to_retrieve"]
+ normed_answer_of_qas_to_ret = loaded_data["normed_answer_of_qas_to_ret"]
+
+ dataset_kwargs = {
+ "max_source_length": args.max_source_length,
+ "dataset_name": args.qa_data_name,
+ "args": args,
+ "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret,
+ }
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+ train_dataset = QADataset(train_data, tokenizer, qas_to_retrieve, **dataset_kwargs)
+ dev_dataset = QADataset(dev_data, tokenizer, qas_to_retrieve, **dataset_kwargs)
+ test_dataset = QADataset(test_data, tokenizer, qas_to_retrieve, **dataset_kwargs)
+
+ qa_trainer = QATrainer(args, train_dataset, dev_dataset, test_dataset, qas_to_retrieve, normed_answer_of_qas_to_ret)
+
+ if args.do_train:
+ qa_trainer.train()
+ elif args.do_test:
+ logging.info("Only do test.")
+ ckpt_load_path = os.path.join(args.output_dir, "best_ckpt/pytorch_model.bin")
+ em_score, match_metric, ret_qas, gen_ans = qa_trainer.evaluate(
+ qa_trainer.test_dataset, extend_mem_from="train_dev",
+ update_key_memory=True, ckpt_load_path=ckpt_load_path
+        )
+
+ logging.info(f"em_test: {em_score:.3f}")
+ for k, v in match_metric.items():
+ logging.info(f"test_{k}: {v}")
+ results = []
+ for idx, (input_qa, retrieved_qas, predict_answer) in enumerate(zip(qa_trainer.test_dataset, ret_qas, gen_ans)):
+ results.append({
+ "idx": idx,
+ "question": input_qa["question"],
+ "answer": input_qa["answer"],
+ "retrieved_qas": [{"question": qa["question"], "answer": qa["answer"][0]} for qa in retrieved_qas],
+ "generate_answer": predict_answer,
+ })
+ json.dump(results, open(os.path.join(args.output_dir, "best_ckpt/predict_results.json"), 'w'),
+ indent=4, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/qa_trainer.py b/qa_trainer.py
new file mode 100644
index 0000000..994e0ee
--- /dev/null
+++ b/qa_trainer.py
@@ -0,0 +1,565 @@
+import copy
+import logging
+import math
+import os
+import random
+from itertools import chain
+from typing import List, Dict, Optional
+from emat.evaluation.eval_retriever import eval_retriever, eval_generation_em
+from utils.dr_utils import update_local_qas_to_retrieve, update_batch_inputs, rank_exist_local_qas
+from utils.utils import reduce_query_or_key_embeds
+import datasets
+import torch
+import transformers
+from accelerate import Accelerator
+from tqdm.auto import tqdm
+from transformers import AdamW, get_scheduler, set_seed
+from utils.utils import save_model, load_model
+from build_kvm import build_memory
+from emat.t5 import T5WithKeyValueMemory
+from qa_dataset import QADataset
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+try:
+ import wandb
+
+ wandb.ensure_configured()
+ if wandb.api.api_key is None:
+ _has_wandb = False
+ wandb.termwarn(
+ "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
+ else:
+ _has_wandb = False if os.getenv("WANDB_DISABLED") else True
+except (ImportError, AttributeError):
+ _has_wandb = False
+
+
+class QATrainer:
+
+ def __init__(
+ self,
+ args,
+ train_dataset: QADataset,
+ dev_dataset: QADataset,
+ test_dataset: QADataset,
+ qas_to_retrieve: List[Dict],
+ normed_answer_of_qas_to_ret,
+ ):
+ accelerator = Accelerator()
+ logging.info(f"wandb {'available' if _has_wandb else 'unavailable'}")
+ logger.info(accelerator.state)
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ if args.seed is not None:
+ set_seed(args.seed)
+ else:
+ logging.info("Not set seed.")
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+ accelerator.wait_for_everyone()
+
+ if accelerator.is_local_main_process and _has_wandb:
+ wandb.init(project=args.project_name, name=args.exp_name, dir=args.output_dir, config=vars(args))
+
+ logging.info("loading model")
+ config, tokenizer, self.model = load_model(T5WithKeyValueMemory, args)
+ logging.info("Loading model.")
+ logging.info(f"model params: {self.model.num_parameters()}")
+ if args.freeze_t5_params:
+ logging.info("Freeze T5 parameters.")
+ self.model.freeze_t5_params()
+ if args.only_train_adapter:
+ for param in self.model.parameters():
+ param.requires_grad = False
+ for param in self.model.adapter.parameters():
+ param.requires_grad = True
+
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": args.weight_decay, },
+ {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0, },
+ ]
+ self.optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+
+ # reused_key_memory: pre-allocated memory to store full key_memory
+ self.reused_key_memory = torch.zeros((len(qas_to_retrieve), self.model.model_dim),
+ device="cpu", dtype=torch.float16)
+ self.train_data_query_embeds = torch.zeros((len(train_dataset), self.model.model_dim),
+ device="cpu", dtype=torch.float16)
+ self.key_memory: Optional[List[torch.tensor]] = None
+ self.key_memory = []
+ for start_idx in range(0, len(qas_to_retrieve), math.ceil(len(qas_to_retrieve) / args.kvm_seg_n)):
+ self.key_memory.append(
+ self.reused_key_memory[start_idx: start_idx + math.ceil(len(qas_to_retrieve) / args.kvm_seg_n)]
+ )
+ logger.info(f"key num = {sum(len(i) for i in self.key_memory)}")
+
+ self.train_dataset = train_dataset
+ self.dev_dataset = dev_dataset
+ self.test_dataset = test_dataset
+ self.args = args
+ self.accelerator = accelerator
+ self.tokenizer = tokenizer
+ self.qas_to_retrieve = qas_to_retrieve
+ self.prefix = args.source_prefix if args.source_prefix is not None else ""
+ assert self.prefix == "question: "
+ # self.query_batch_size = 550 if args.kvm_seg_n > 2 else 256
+ # self.query_batch_size = 1024 if args.kvm_seg_n >= 4 else self.query_batch_size
+ # self.query_batch_size = 3000 if args.kvm_seg_n >= 7 else self.query_batch_size
+ # # if len(self.qas_to_retrieve) < 20000000:
+ # self.query_batch_size = 512
+ self.query_batch_size = args.query_batch_size
+ logger.info(f"PAQ-size: {len(self.qas_to_retrieve)}. PAQ's query batch size: {self.query_batch_size}.")
+ self.normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret
+ self.model = self.accelerator.prepare(self.model)
+
+ @torch.no_grad()
+ def update_key_memory(self, use_fp16_model=True, use_retrieval_adapter=False):
+ args = self.args
+ if use_fp16_model:
+ tmp_model = copy.deepcopy(self.model)
+ tmp_model = tmp_model.half()
+ else:
+ tmp_model = self.model
+ build_mem_batch_size = args.build_mem_batch_size
+ tmp_model.eval()
+ self.key_memory, _ = build_memory(
+ tmp_model, self.tokenizer, embed_key=True, embed_value=False, prefix=self.prefix, embed_as_fp16=True,
+ key_reduce_method=args.key_reduce_method, return_memory=True, dump_memory=False,
+ data_to_embed=self.qas_to_retrieve, max_source_length=args.max_source_length, padding=True,
+ batch_size=build_mem_batch_size, separate_task=True, kvm_seg_n=args.kvm_seg_n,
+ reused_key_memory=self.reused_key_memory, use_retrieval_adapter=use_retrieval_adapter
+ )
+ if type(self.key_memory) is not list:
+ self.key_memory = [self.key_memory]
+ del tmp_model
+
+ @torch.no_grad()
+ def update_local_qas(self, epoch, use_fp16_model=True, use_retrieval_adapter=False):
+ args = self.args
+ if use_fp16_model:
+ tmp_model = copy.deepcopy(self.model)
+ tmp_model = tmp_model.half()
+ else:
+ tmp_model = self.model
+ build_mem_batch_size = args.build_mem_batch_size
+ tmp_model.eval()
+ if args.update_kv_embeds and args.update_local_qas and epoch >= args.repaq_supervision_epoch:
+ update_local_qas_to_retrieve(
+ args, self.train_dataset, self.qas_to_retrieve, tmp_model, self.key_memory,
+ self.normed_answer_of_qas_to_ret, train_data_query_embeds=self.train_data_query_embeds,
+ build_mem_batch_size=build_mem_batch_size, query_batch_size=self.query_batch_size,
+ local_size=args.local_size, pos_from_top=args.pos_from_top, neg_from_top=200,
+ use_retrieval_adapter=use_retrieval_adapter
+ )
+ elif args.only_rank_exists_local_qa:
+ logging.warning("Do not use!")
+ embed_local_qas_batch_size = (build_mem_batch_size //
+ len(self.train_dataset.data[0]["local_qas"]) + 1) * 2
+ rank_exist_local_qas(args, self.train_dataset, self.qas_to_retrieve, tmp_model,
+ self.normed_answer_of_qas_to_ret, build_mem_batch_size=build_mem_batch_size,
+ train_data_query_embeds=self.train_data_query_embeds,
+ embed_local_qas_batch_size=embed_local_qas_batch_size,
+ local_size=args.local_size, pos_from_top=args.pos_from_top, neg_from_top=200,
+ accelerator=self.accelerator)
+ del tmp_model
+
+ def train(self):
+ args = self.args
+ tokenizer = self.tokenizer
+ num_workers = 5
+ if args.update_kv_embeds and not args.only_rank_exists_local_qa:
+ logging.info("Build Memory")
+ self.update_key_memory()
+ train_dataloader = self.train_dataset.get_dataloader(batch_size=args.per_device_train_batch_size,
+ shuffle=True, num_workers=num_workers)
+
+ optimizer, train_dataloader = self.accelerator.prepare(self.optimizer, train_dataloader)
+
+ # Scheduler and math around the number of training steps.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ else:
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ lr_scheduler = get_scheduler(
+ name=args.lr_scheduler_type,
+ optimizer=optimizer,
+ num_warmup_steps=args.num_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # Train!
+ total_batch_size = args.per_device_train_batch_size * self.accelerator.num_processes \
+ * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(self.train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not self.accelerator.is_local_main_process)
+ completed_steps = 0
+ best_hit_at_1, best_em, patience = None, None, args.early_stop_patience
+
+ for epoch in range(args.num_train_epochs):
+ use_adapter_to_select_positive = epoch >= args.use_adapter_to_select_positive_after_k_epoch
+ if args.only_train_adapter:
+ if epoch == 0:
+ self.update_local_qas(epoch)
+ # convert original key memory with high dim to low dim through adapter
+ qas_num = self.reused_key_memory.shape[0]
+ train_qas_num = len(self.train_dataset)
+ dim = args.adapter_out_dim
+ self.reused_key_memory = torch.zeros((qas_num, dim), device="cpu", dtype=torch.float16)
+ self.train_data_query_embeds = torch.zeros((train_qas_num, dim), device="cpu", dtype=torch.float16)
+ self.key_memory: Optional[List[torch.tensor]] = None
+ self.key_memory = []
+ for start_idx in range(0, qas_num, math.ceil(qas_num / args.kvm_seg_n)):
+ self.key_memory.append(
+ self.reused_key_memory[start_idx: start_idx + math.ceil(qas_num / args.kvm_seg_n)]
+ )
+ self.update_key_memory(use_retrieval_adapter=True)
+ elif use_adapter_to_select_positive:
+ self.update_local_qas(epoch, use_retrieval_adapter=True)
+ else:
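+                # Refresh each example's local candidate QAs every epoch, except with the full PAQ index,
+                # where it is only done every third epoch to limit the cost of re-embedding.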
+ if args.qas_to_retrieve_from == "PAQ" and (epoch % 3 == 0):
+ self.update_local_qas(epoch)
+ elif args.qas_to_retrieve_from != "PAQ":
+ self.update_local_qas(epoch)
+ for step, batch in enumerate(train_dataloader):
+
+ update_batch_inputs(args, batch, self.model,
+ use_adapter_to_select_positive=use_adapter_to_select_positive)
+ self.model.train()
+ if args.match_weight > 0.0:
+ # Embed Positive Key and the Value to input.
+ embed_dict = self.model.wrapped_embed_kv( # assert num_values > 1, otherwise set compute_value=True
+ separate_task=args.separate_task, compute_key=True, compute_value=False,
+ **batch.pop("positive_kv_inputs")
+ )
+ positive_key_embeds = embed_dict["normed_key_embeds"]
+ positive_key_embeds = reduce_query_or_key_embeds(positive_key_embeds, args.key_reduce_method)
+ # Embed Negative Key
+ embed_dict = self.model.wrapped_embed_kv(
+ separate_task=args.separate_task, compute_key=True, compute_value=False,
+ **batch.pop("negative_kv_inputs")
+ )
+ negative_key_embeds = embed_dict["normed_key_embeds"]
+ negative_key_embeds = reduce_query_or_key_embeds(negative_key_embeds, args.key_reduce_method)
+ else:
+ negative_key_embeds, positive_key_embeds = None, None
+ # Embed retrieved-Key-Value
+ embed_dict = self.model.wrapped_embed_kv(
+ separate_task=args.separate_task, compute_key=True, compute_value=True,
+ **batch.pop("group_value_inputs")
+ )
+ key_embeds_of_value = embed_dict["key_embeds"]
+ value_embeds = embed_dict["value_embeds"]
+ bs = batch["query_input_ids"].shape[0]
+ value_embeds = value_embeds.view(bs, args.num_values, args.prefix_length, -1)
+ key_embeds_of_value = key_embeds_of_value.view(bs, args.num_values, -1, self.model.model_dim)
+
+ loss_dict = self.model.compute_qa_loss(
+ input_ids=batch["query_input_ids"],
+ attention_mask=batch["query_attention_mask"],
+ labels=batch["labels"],
+ decoder_only_attend_on_prefix=args.decoder_only_attend_on_prefix,
+ value_fusion_method=args.value_fusion_method,
+ encoder_outputs_are_key_or_value=False,
+ key_reduce_method=args.key_reduce_method,
+ positive_key_embeds=positive_key_embeds,
+ negative_key_embeds=negative_key_embeds,
+ value_embeds=value_embeds,
+ matching_targets=batch["matching_targets"],
+ key_embeds_of_value=key_embeds_of_value,
+ negative_mask=batch.get("negative_mask", None),
+ only_train_adapter=args.only_train_adapter
+ )
+ if args.match_weight > 0.0:
+ if epoch >= args.only_key_matching_n_epoch:
+ loss = args.gen_weight * loss_dict["gen_loss"] + args.match_weight * loss_dict["match_loss"]
+ else:
+ loss = args.match_weight * loss_dict["match_loss"]
+ else:
+ loss = loss_dict["gen_loss"]
+ loss_dict = {k: v.item() for k, v in loss_dict.items()}
+ loss = loss / args.gradient_accumulation_steps
+ self.accelerator.backward(loss)
+ if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ progress_bar.update(1)
+ completed_steps += 1
+ if self.accelerator.is_local_main_process and _has_wandb:
+ wandb.log({"loss": loss * args.gradient_accumulation_steps, "step": completed_steps})
+ wandb.log({"trainable_percentage": batch["trainable_percentage"][0].item(),
+ "forward_step": completed_steps})
+ for k, v in loss_dict.items():
+ wandb.log({k: v, "step": completed_steps})
+
+ if completed_steps >= args.max_train_steps:
+ break
+
+ if args.output_dir is not None:
+ save_model(self.model, os.path.join(args.output_dir, "latest_ckpt"),
+ self.accelerator, tokenizer=tokenizer, arguments=args)
+
+ if (args.update_kv_embeds and not args.only_rank_exists_local_qa) or args.do_eval:
+ logging.info("Update Memory")
+ self.update_key_memory(use_retrieval_adapter=args.only_train_adapter)
+
+ if args.do_eval and epoch % args.eval_freq == 0:
+ # if args.do_eval: self.key_memory is up-to-date.
+ em_score, matching_metric, _, _ = self.evaluate(dataset=self.dev_dataset, extend_mem_from="train",
+ use_retrieval_adapter=args.only_train_adapter)
+ logger.info(f"epoch {epoch} eval - EM: {em_score:.3f}")
+ if self.accelerator.is_local_main_process and _has_wandb:
+ wandb.log({"em_dev": em_score, "epoch": epoch})
+ for k, v in matching_metric.items():
+ wandb.log({f"{k}": v * 100, "epoch": epoch})
+
+ if args.output_dir is not None:
+ if best_hit_at_1 is None or matching_metric["hit@1"] * 100 > best_hit_at_1:
+ best_hit_at_1 = matching_metric["hit@1"] * 100
+ if best_em is None or em_score > best_em:
+ best_em = em_score
+ save_model(self.model, os.path.join(args.output_dir, "best_ckpt"),
+ self.accelerator, tokenizer=tokenizer, arguments=args)
+ patience = args.early_stop_patience
+ else:
+ patience -= 1
+ if patience <= 0:
+ break
+
+ if best_em is not None: # Log the best dev EM score
+ logger.info(f"best_em_dev: {best_em}")
+ if _has_wandb:
+ wandb.log({"best_em_dev": best_em})
+ if best_hit_at_1 is not None:
+ logger.info(f"best_hit@1_dev: {best_hit_at_1}")
+ if _has_wandb:
+ wandb.log({"best_hit@1_dev": best_hit_at_1})
+
+ # do-test
+ best_model_state_dict = os.path.join(args.output_dir, "best_ckpt/pytorch_model.bin")
+ em_score, matching_metric, _, _ = self.evaluate(dataset=self.test_dataset, extend_mem_from="train_dev",
+ update_key_memory=True, ckpt_load_path=best_model_state_dict,
+ use_retrieval_adapter=args.only_train_adapter)
+
+ if self.accelerator.is_local_main_process:
+ logger.info(f"em_test: {em_score:.3f}")
+ for k, v in matching_metric.items():
+ logger.info(f"test_{k}: {v}")
+ if _has_wandb:
+ wandb.log({"em_test": em_score})
+ for k, v in matching_metric.items():
+ wandb.log({f"test_{k}": v})
+
+ @torch.no_grad()
+ def evaluate(self, dataset: QADataset = None, extend_mem_from="", update_key_memory=False, ckpt_load_path=None,
+ use_retrieval_adapter=False):
+ # note: evaluation is not implemented correctly for multi-GPU runs.
+ tokenizer = self.tokenizer
+ args = self.args
+ self.model.eval()
+ torch.cuda.empty_cache()
+
+ assert extend_mem_from in ["train", "train_dev"]
+ if ckpt_load_path is not None:
+ assert update_key_memory is True
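+ # reload weights from the given checkpoint; strict=False so missing/unexpected keys are only reported via load_info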
+ loaded_state_dict = torch.load(ckpt_load_path)
+ load_info = self.model.load_state_dict(loaded_state_dict, strict=False)
+ logging.info(f"{load_info}")
+
+ assert type(self.key_memory) == list
+ original_key_length = sum(len(k) for k in self.key_memory)
+ if update_key_memory:
+ logging.info("Update Memory")
+ self.update_key_memory(use_retrieval_adapter=use_retrieval_adapter)
+
+ extend_length = 0
+ last_chunk_memory = self.key_memory[-1]
+ qas_to_retrieve_eval = self.qas_to_retrieve
+
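+ # embed the train (and optionally dev) questions with a temporary, optionally fp16, copy of the model and
+ # append them to the last chunk of the key memory so they can also be retrieved at evaluation time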
+ tmp_model = copy.deepcopy(self.model)
+ if args.kvm_fp16:
+ tmp_model = tmp_model.half()
+
+ logging.info("Build train data memory to retrieve.")
+ if args.qa_data_name == "tq":
+ build_query_batch_size = 256
+ else:
+ build_query_batch_size = args.build_mem_batch_size
+ if "train" in extend_mem_from:
+ train_qas_key_memory, _ = build_memory(
+ tmp_model, tokenizer, embed_key=True, embed_value=False, prefix=self.prefix, embed_as_fp16=True,
+ key_reduce_method=args.key_reduce_method, return_memory=True, dump_memory=False, kvm_seg_n=-1,
+ data_to_embed=self.train_dataset.data, max_source_length=args.max_source_length, padding=True,
+ batch_size=build_query_batch_size, separate_task=args.separate_task, reused_key_memory=None,
+ use_retrieval_adapter=use_retrieval_adapter
+ )
+ extend_length = extend_length + len(train_qas_key_memory)
+ last_chunk_memory = torch.cat((last_chunk_memory, train_qas_key_memory)) # extend in the last chunk
+ qas_to_retrieve_eval = qas_to_retrieve_eval + self.train_dataset.data
+
+ if "dev" in extend_mem_from:
+ logging.info("Build dev data memory to retrieve.")
+ dev_qas_key_memory, _ = build_memory(
+ tmp_model, tokenizer, embed_key=True, embed_value=False, prefix=self.prefix, embed_as_fp16=True,
+ key_reduce_method=args.key_reduce_method, return_memory=True, dump_memory=False, kvm_seg_n=-1,
+ data_to_embed=self.dev_dataset.data, max_source_length=args.max_source_length, padding=True,
+ batch_size=build_query_batch_size, separate_task=args.separate_task, reused_key_memory=None,
+ use_retrieval_adapter=use_retrieval_adapter
+ )
+ extend_length = extend_length + len(dev_qas_key_memory)
+ last_chunk_memory = torch.cat((last_chunk_memory, dev_qas_key_memory)) # extend in the last chunk
+ qas_to_retrieve_eval = qas_to_retrieve_eval + self.dev_dataset.data
+ del tmp_model
+
+ key_memory_eval = self.key_memory[:-1] + [last_chunk_memory]
+ key_nums_eval = sum(len(k) for k in key_memory_eval)
+ assert key_nums_eval == len(qas_to_retrieve_eval)
+
+ # if use_retrieval_adapter:
+ # low_dim_key = []
+ # while len(key_memory_eval) > 0:
+ # chunk_key = key_memory_eval.pop(0)
+ # chunk_low_dim_key = []
+ # for start_idx in range(len(chunk_key)):
+ # chunk_low_dim_key.append(self.model.adapter(chunk_key[start_idx:start_idx + 512]))
+ # del chunk_key
+ # low_dim_key.append(torch.cat(chunk_low_dim_key))
+ # key_memory_eval = low_dim_key
+
+ dataloader = dataset.get_query_dataloader(batch_size=args.per_device_eval_batch_size,
+ shuffle=False, num_workers=1)
+ dataloader = self.accelerator.prepare(dataloader)
+
+ gen_kwargs = {"max_length": args.max_target_length,
+ "num_beams": args.num_beams, }
+
+ torch.cuda.empty_cache()
+
+ all_retrieved_qas = []
+ all_gen_ans = []
+ for batch in tqdm(dataloader):
+ embed_dict = self.model.CAT_embed_q(
+ input_ids=batch["query_input_ids"],
+ attention_mask=batch["query_attention_mask"],
+ compute_key=True, compute_value=False
+ )
+ query_embeds = embed_dict["normed_key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method)
+ if use_retrieval_adapter:
+ query_embeds = self.model.adapter(query_embeds)
+ query_embeds = query_embeds.half()
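+ # retrieve the top-50 QAs by inner product between the query embeddings and the (chunked) key memory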
+
+ if key_nums_eval > 20000000:
+ # if scale is large: calculate topk in each chunk -> combine all-topk -> select final topk
+ chunk_top_scores = []
+ chunk_top_indices = []
+ idx_shift = 0
+ for chunk_key_memory in key_memory_eval:
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_topk = torch.mm(query_embeds, chunk_key_memory_cuda.t()).topk(50, dim=1)
+ chunk_top_scores.append(chunk_topk.values)  # chunk_topk.values: [query_batch, 50]
+ chunk_top_indices.append(chunk_topk.indices + idx_shift)
+ idx_shift += len(chunk_key_memory)
+ del chunk_key_memory_cuda
+ torch.cuda.empty_cache()
+ chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n
+ chunk_top_indices = torch.cat(chunk_top_indices, dim=1) # q_batch, local_size*seg_n
+ topk = chunk_top_scores.topk(50, dim=1) # q_batch, local_size
+ top_indices_indices = topk.indices
+ top_indices = []
+ for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices):
+ top_indices.append([cur_indices[idx] for idx in cur_indices_indices])
+ readout_qas = [[qas_to_retrieve_eval[idx] for idx in indices] for indices in top_indices]
+ else:
+ all_chunk_scores = []
+ for chunk_key_memory in key_memory_eval:
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_scores = torch.mm(query_embeds, chunk_key_memory_cuda.t()) # query_batch
+ all_chunk_scores.append(chunk_scores)
+ del chunk_key_memory_cuda
+ scores = torch.cat(all_chunk_scores, dim=1)
+ top_indices = scores.topk(50, dim=1).indices.tolist()
+ readout_qas = [[qas_to_retrieve_eval[idx] for idx in indices] for indices in top_indices]
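+ # keep the top num_values retrieved QAs as values for fusion; shuffle them when values_with_order is not set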
+ value_qas = []
+ for qas in readout_qas:
+ selected_qas = qas[:args.num_values]
+ if not args.values_with_order:
+ random.shuffle(selected_qas)
+ value_qas.append(selected_qas)
+ all_retrieved_qas += readout_qas
+
+ squeezed_value_qas = list(chain(*value_qas))
+
+ retrieved_qas_inputs = dataset.get_key_value_inputs(squeezed_value_qas, only_return_key_inputs=False)
+ embed_dict = self.model.wrapped_embed_kv(separate_task=args.separate_task, compute_key=True,
+ compute_value=True, **retrieved_qas_inputs)
+ value_embeds = embed_dict["value_embeds"]
+ key_embeds_of_value = embed_dict["key_embeds"]
+ cur_batch_size = query_embeds.shape[0]
+ value_embeds = value_embeds.view(cur_batch_size, args.num_values, args.prefix_length, -1)
+ key_embeds_of_value = key_embeds_of_value.view(cur_batch_size, args.num_values, -1, self.model.model_dim)
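+ # encode the query and fuse the retrieved key/value embeddings into the encoder outputs according to value_fusion_method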
+ encoder_outputs = self.model.encoder(
+ input_ids=batch["query_input_ids"],
+ attention_mask=batch["query_attention_mask"],
+ return_dict=True,
+ value_embeds=value_embeds,
+ readout_top_k=-1,
+ key_reduce_method=args.key_reduce_method,
+ value_fusion_method=args.value_fusion_method,
+ key_embeds_of_value=key_embeds_of_value
+ )
+ generated_tokens = self.accelerator.unwrap_model(self.model).generate(
+ encoder_outputs=encoder_outputs,
+ encoder_outputs_are_key_or_value=False,
+ decoder_only_attend_on_prefix=args.decoder_only_attend_on_prefix,
+ attention_mask=batch["query_attention_mask"].to(self.model.device),
+ value_fusion_method=args.value_fusion_method,
+ **gen_kwargs,
+ )
+ generated_tokens = self.accelerator.pad_across_processes(
+ generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+ )
+ generated_tokens = self.accelerator.gather(generated_tokens).cpu().numpy()
+ decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+ decoded_tokens = [ans.strip() for ans in decoded_tokens]
+ all_gen_ans += decoded_tokens
+
+ torch.cuda.empty_cache()
+
+ matching_metric = eval_retriever(dataset.data, all_retrieved_qas, "1,2,3,4,5,10,50")
+ em_score = eval_generation_em(dataset.data, all_gen_ans) * 100
+
+ assert original_key_length == sum(len(k) for k in self.key_memory)
+ assert original_key_length == len(self.qas_to_retrieve)
+
+ return em_score, matching_metric, all_retrieved_qas, all_gen_ans
+
+
+if __name__ == '__main__':
+ pass
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..8275390
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+accelerate==0.5.1
+datasets==1.18.4
+nltk==3.7
+numpy==1.21.5
+rouge==1.0.1
+scikit_learn==1.0.2
+setuptools==58.0.4
+stanfordcorenlp==3.9.1.1
+tqdm==4.62.3
\ No newline at end of file
diff --git a/retrieval_adapter_scripts/train_adapter.sh b/retrieval_adapter_scripts/train_adapter.sh
new file mode 100644
index 0000000..775a8a4
--- /dev/null
+++ b/retrieval_adapter_scripts/train_adapter.sh
@@ -0,0 +1,67 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+
+DST_DIR="/mnt/inspurfs/user-fs/zhaoyu/workspace/case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+LOAD_PATH="/mnt/inspurfs/user-fs/zhaoyu/workspace/case-augmented-transformer-master/outputs/nq_checkpoints/KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;pos_from_top=128;/best_ckpt/" # load pre-trained model
+EXP_NAME="adapter768;distill;adapter_retrieve_after5epoch;" # set experiment name
+DATA_NAME="nq" # datasets: ["nq", "tq", "wq"]
+#
+DEVICE="5"
+
+# Train the retrieval adapter for nq-EMAT-FKSV
+# use --kvm_fp16 if GPU OOM
+
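+# This script trains only the retrieval adapter (--only_train_adapter): a linear head (--adapter="linear") that
+# projects key/query embeddings down to --adapter_out_dim=768. After --use_adapter_to_select_positive_after_k_epoch
+# epochs, the adapter is also used to select positive examples during training.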
+CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \
+ --project_name="${DATA_NAME^^}-CAT" \
+ --exp_name=${EXP_NAME} \
+ --query_batch_size=256 \
+ --build_mem_batch_size=12000 \
+ --batch_local_positive_num=5 \
+ --pos_from_top=128 \
+ --do_eval \
+ --kvm_seg_n=2 \
+ --values_with_order \
+ --value_layer=7 \
+ --value_fusion_method="cat_k_delay+v" \
+ --num_values=1 \
+ --qa_data_name=${DATA_NAME} \
+ --model_name_or_path=${LOAD_PATH} \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=64 \
+ --per_device_eval_batch_size=64 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=5e-5 \
+ --num_train_epochs=30 \
+ --lr_scheduler_type="linear" \
+ --num_warmup_steps=1000 \
+ --output_dir="./outputs/nq_checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_layer=3 \
+ --key_encoder_type="conv" \
+ --select_positive_strategy="softmax_sample" \
+ --faiss_efsearch=128 \
+ --gen_weight=1 \
+ --match_weight=1 \
+ --key_reduce_method="avg" \
+ --qas_to_retrieve_from="PAQ_L1" \
+ --local_size=384 \
+ --update_kv_embeds \
+ --update_local_target_each_batch \
+ --update_local_qas \
+ --separate_task \
+ --value_ae_target="ans" \
+ --key_ae_target="question" \
+ --repaq_supervision_epoch=-1 \
+ --early_stop_patience=8 \
+ --negatives_num_each_example=32 \
+ --only_train_adapter \
+ --adapter="linear" \
+ --adapter_out_dim=768 \
+ --use_adapter_to_select_positive_after_k_epoch=5 \
+ --do_train
diff --git a/scripts/nq_eval.sh b/scripts/nq_eval.sh
new file mode 100644
index 0000000..1805b8b
--- /dev/null
+++ b/scripts/nq_eval.sh
@@ -0,0 +1,63 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+CKPT_DIR="" # load pre-trained model
+EXP_NAME="eval" # set experiment name
+DATA_NAME="nq" # datasets: ["nq", "tq", "wq"]
+
+DEVICE="0"
+
+# Evaluate nq-EMAT-FKSV
+# use --kvm_fp16 if GPU OOM
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \
+ --project_name="${DATA_NAME^^}-CAT" \
+ --exp_name=${EXP_NAME} \
+ --query_batch_size=256 \
+ --build_mem_batch_size=12000 \
+ --batch_local_positive_num=5 \
+ --pos_from_top=128 \
+ --do_eval \
+ --kvm_seg_n=2 \
+ --values_with_order \
+ --value_layer=7 \
+ --value_fusion_method="cat_k_delay+v" \
+ --num_values=10 \
+ --qa_data_name=${DATA_NAME} \
+ --model_name_or_path=${CKPT_DIR} \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=64 \
+ --per_device_eval_batch_size=64 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=5e-5 \
+ --num_train_epochs=30 \
+ --lr_scheduler_type="linear" \
+ --num_warmup_steps=1000 \
+ --output_dir="./outputs/nq_checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_layer=3 \
+ --key_encoder_type="conv" \
+ --select_positive_strategy="softmax_sample" \
+ --faiss_efsearch=128 \
+ --gen_weight=1 \
+ --match_weight=1 \
+ --key_reduce_method="avg" \
+ --qas_to_retrieve_from="PAQ_L1" \
+ --local_size=384 \
+ --update_kv_embeds \
+ --update_local_target_each_batch \
+ --update_local_qas \
+ --separate_task \
+ --value_ae_target="ans" \
+ --key_ae_target="question" \
+ --repaq_supervision_epoch=-1 \
+ --early_stop_patience=8 \
+ --negatives_num_each_example=32 \
+ --do_test
diff --git a/scripts/nq_train_with_paql1.sh b/scripts/nq_train_with_paql1.sh
new file mode 100644
index 0000000..e54b872
--- /dev/null
+++ b/scripts/nq_train_with_paql1.sh
@@ -0,0 +1,63 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+LOAD_EXP_NAME="KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;" # load pre-trained model
+EXP_NAME="base;KL=3;VL=7;VN=10;lr=5e-5;" # set experiment name
+DATA_NAME="nq" # datasets: ["nq", "tq", "wq"]
+
+DEVICE="0"
+
+# Train nq-EMAT-FKSV
+# use --kvm_fp16 if GPU OOM
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \
+ --project_name="${DATA_NAME^^}-CAT" \
+ --exp_name=${EXP_NAME} \
+ --query_batch_size=256 \
+ --build_mem_batch_size=12000 \
+ --batch_local_positive_num=5 \
+ --pos_from_top=128 \
+ --do_eval \
+ --kvm_seg_n=2 \
+ --values_with_order \
+ --value_layer=7 \
+ --value_fusion_method="cat_k_delay+v" \
+ --num_values=10 \
+ --qa_data_name=${DATA_NAME} \
+ --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=64 \
+ --per_device_eval_batch_size=64 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=5e-5 \
+ --num_train_epochs=30 \
+ --lr_scheduler_type="linear" \
+ --num_warmup_steps=1000 \
+ --output_dir="./outputs/nq_checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_layer=3 \
+ --key_encoder_type="conv" \
+ --select_positive_strategy="softmax_sample" \
+ --faiss_efsearch=128 \
+ --gen_weight=1 \
+ --match_weight=1 \
+ --key_reduce_method="avg" \
+ --qas_to_retrieve_from="PAQ_L1" \
+ --local_size=384 \
+ --update_kv_embeds \
+ --update_local_target_each_batch \
+ --update_local_qas \
+ --separate_task \
+ --value_ae_target="ans" \
+ --key_ae_target="question" \
+ --repaq_supervision_epoch=-1 \
+ --early_stop_patience=8 \
+ --negatives_num_each_example=32 \
+ --do_train
diff --git a/scripts/tq_train_with_paql1.sh b/scripts/tq_train_with_paql1.sh
new file mode 100644
index 0000000..6303026
--- /dev/null
+++ b/scripts/tq_train_with_paql1.sh
@@ -0,0 +1,63 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+LOAD_EXP_NAME="KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;" # load pre-trained model
+EXP_NAME="base;KL=3;VL=7;VN=10;lr=5e-5;" # set experiment name
+DATA_NAME="tq" # datasets: ["nq", "tq", "wq"]
+
+DEVICE="0"
+
+# Train tq-EMAT-FKSV
+# use --kvm_fp16 if GPU OOM
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \
+ --project_name="${DATA_NAME^^}-CAT" \
+ --exp_name=${EXP_NAME} \
+ --query_batch_size=256 \
+ --build_mem_batch_size=12000 \
+ --batch_local_positive_num=5 \
+ --pos_from_top=128 \
+ --do_eval \
+ --kvm_seg_n=2 \
+ --values_with_order \
+ --value_layer=7 \
+ --value_fusion_method="cat_k_delay+v" \
+ --num_values=10 \
+ --qa_data_name=${DATA_NAME} \
+ --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=64 \
+ --per_device_eval_batch_size=64 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=5e-5 \
+ --num_train_epochs=30 \
+ --lr_scheduler_type="linear" \
+ --num_warmup_steps=1000 \
+ --output_dir="./outputs/tq_checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_layer=3 \
+ --key_encoder_type="conv" \
+ --select_positive_strategy="softmax_sample" \
+ --faiss_efsearch=128 \
+ --gen_weight=1 \
+ --match_weight=1 \
+ --key_reduce_method="avg" \
+ --qas_to_retrieve_from="PAQ_L1" \
+ --local_size=384 \
+ --update_kv_embeds \
+ --update_local_target_each_batch \
+ --update_local_qas \
+ --separate_task \
+ --value_ae_target="ans" \
+ --key_ae_target="question" \
+ --repaq_supervision_epoch=-1 \
+ --early_stop_patience=8 \
+ --negatives_num_each_example=32 \
+ --do_train
diff --git a/scripts/wq_train_with_paql1.sh b/scripts/wq_train_with_paql1.sh
new file mode 100644
index 0000000..29527d7
--- /dev/null
+++ b/scripts/wq_train_with_paql1.sh
@@ -0,0 +1,63 @@
+#!/bin/bash -l
+
+set -e
+set -u
+
+
+DST_DIR="case-augmented-transformer-master" # change to your project root
+cd ${DST_DIR}
+
+LOAD_EXP_NAME="KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;" # load pre-trained model
+EXP_NAME="base;KL=3;VL=7;VN=10;lr=5e-5;" # set experiment name
+DATA_NAME="wq" # datasets: ["nq", "tq", "wq"]
+
+DEVICE="0"
+
+# Train wq-EMAT-FKSV
+# use --kvm_fp16 if GPU OOM
+
+CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \
+ --project_name="${DATA_NAME^^}-CAT" \
+ --exp_name=${EXP_NAME} \
+ --query_batch_size=256 \
+ --build_mem_batch_size=12000 \
+ --batch_local_positive_num=5 \
+ --pos_from_top=128 \
+ --do_eval \
+ --kvm_seg_n=2 \
+ --values_with_order \
+ --value_layer=7 \
+ --value_fusion_method="cat_k_delay+v" \
+ --num_values=10 \
+ --qa_data_name=${DATA_NAME} \
+ --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \
+ --source_prefix="question: " \
+ --per_device_train_batch_size=64 \
+ --per_device_eval_batch_size=64 \
+ --gradient_accumulation_steps=4 \
+ --learning_rate=4e-5 \
+ --num_train_epochs=30 \
+ --lr_scheduler_type="constant" \
+ --num_warmup_steps=1000 \
+ --output_dir="./outputs/wq_checkpoints/${EXP_NAME}" \
+ --prefix_length=2 \
+ --d_key=1536 \
+ --key_layer=3 \
+ --key_encoder_type="conv" \
+ --select_positive_strategy="softmax_sample" \
+ --faiss_efsearch=128 \
+ --gen_weight=1 \
+ --match_weight=1 \
+ --key_reduce_method="avg" \
+ --qas_to_retrieve_from="PAQ_L1" \
+ --local_size=384 \
+ --update_kv_embeds \
+ --update_local_target_each_batch \
+ --update_local_qas \
+ --separate_task \
+ --value_ae_target="ans" \
+ --key_ae_target="question" \
+ --repaq_supervision_epoch=-1 \
+ --early_stop_patience=8 \
+ --negatives_num_each_example=32 \
+ --do_train
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/dr_utils.py b/utils/dr_utils.py
new file mode 100644
index 0000000..277d129
--- /dev/null
+++ b/utils/dr_utils.py
@@ -0,0 +1,698 @@
+import copy
+from functools import partial
+import random
+from typing import List
+import torch
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from torch.nn.utils.rnn import pad_sequence
+from build_kvm import build_memory
+from emat.evaluation.eval_retriever import eval_retriever
+from emat.evaluation.exact_match import normalize_answer
+from kilt_dataset import DialogDataset
+from qa_dataset import QADataset
+from utils.utils import reduce_query_or_key_embeds
+import logging
+
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+# not used in QA
+def query_collate_fn(batch, tokenizer=None, dataset=None):
+ for item in batch:
+ item.update(dataset.get_base_input(item))
+ query_input_ids = [item["input_ids"] for item in batch]
+ query_input_ids = pad_sequence(query_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+ query_attention_mask = (query_input_ids != tokenizer.pad_token_id).long()
+ return {"query_input_ids": query_input_ids, "query_attention_mask": query_attention_mask}
+
+
+# not used in QA
+def kvm_collate_fn(batch, tokenizer=None, train_dataset=None):
+ for item in batch:
+ item.update(train_dataset.get_base_input(item))
+ key_input_ids = [item["input_ids"] for item in batch]
+ key_input_ids = pad_sequence(key_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+ key_attention_mask = (key_input_ids != tokenizer.pad_token_id).long()
+ value_input_ids = [item["target_as_input_ids"] for item in batch]
+ value_input_ids = pad_sequence(value_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+ value_attention_mask = (value_input_ids != tokenizer.pad_token_id).long()
+ return {"key_input_ids": key_input_ids, "key_attention_mask": key_attention_mask,
+ "value_input_ids": value_input_ids, "value_attention_mask": value_attention_mask}
+
+
+# not used in QA
+@torch.no_grad()
+def build_key_memory(model, tokenizer, train_dataset, data_to_retrieve_list: List[List[dict]], args,
+ build_memory_batch_size=2560) -> List[torch.tensor]:
+ # the key norm layer is only used in key-matching (retrieval);
+ # if key-matching is not trained, no gradient flows to the key norm layer.
+ use_normed_key = True if args.key_matching_weight > 0.0 else False
+
+ key_memory = []
+ for data_to_retrieve in data_to_retrieve_list:
+ cur_km, _ = build_memory(model, tokenizer, embed_key=True, embed_value=False, embed_as_fp16=True,
+ key_reduce_method=args.key_reduce_method, data_to_embed=data_to_retrieve,
+ batch_size=build_memory_batch_size, return_memory=True, separate_task=True,
+ collate_fn=partial(kvm_collate_fn, tokenizer=tokenizer, train_dataset=train_dataset),
+ normed_key_memory=use_normed_key)
+ key_memory.append(cur_km)
+
+ return key_memory
+
+
+# not used in QA
+@torch.no_grad()
+def build_dataset_query(model, tokenizer, dataset, args, encode_query_batch_size=2560) -> torch.tensor:
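+ # embed every query in the dataset with the key head and return the embeddings as a single fp16 tensor on CPU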
+ dataset_query_embeds = []
+ query_to_embed_dataloader = DataLoader(dataset.data, batch_size=encode_query_batch_size, num_workers=16,
+ collate_fn=partial(query_collate_fn, tokenizer=tokenizer, dataset=dataset))
+ for query_batch_input in tqdm(query_to_embed_dataloader, desc="embed_query"):
+ embed_dict = model.CAT_embed_q(
+ input_ids=query_batch_input["query_input_ids"].to(model.device),
+ attention_mask=query_batch_input["query_attention_mask"].to(model.device),
+ compute_key=True, compute_value=False
+ )
+ use_normed_query = True if args.key_matching_weight > 0.0 else False
+ query_embeds = embed_dict["normed_key_embeds"] if use_normed_query else embed_dict["key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method)
+ query_embeds = query_embeds.half().cpu()
+ dataset_query_embeds.append(query_embeds)
+
+ dataset_query_embeds = torch.cat(dataset_query_embeds)
+ return dataset_query_embeds
+
+
+# not used in QA
+@torch.no_grad()
+def prepare_local_cases(model, tokenizer, train_dataset, key_memory: List[torch.tensor],
+ data_to_retrieve_list: List[List[dict]], args,
+ max_num_to_ban=48, query_batch_size=512, fp16_query=False):
+ model.eval()
+ dataset_query_embeds = build_dataset_query(model, tokenizer, train_dataset, args)
+ local_size = args.local_size
+ retrieve_topk = local_size + 1  # retrieve retrieve_topk cases from each key_memory; the +1 allows excluding the example itself
+ if args.filter_type != "not_filter":
+ retrieve_topk = local_size + max_num_to_ban  # max_num_to_ban (48) upper-bounds the number of cases banned from retrieval
+ for example in train_dataset.data: # clear previous local_cases and their scores
+ example["local_cases"] = []
+ example["local_cases_scores"] = []
+ for cur_key_memory, cur_data_to_retrieve in zip(key_memory, data_to_retrieve_list):
+ cur_cuda_key_memory = cur_key_memory.cuda()
+ for start_idx in tqdm(range(0, len(dataset_query_embeds), query_batch_size),
+ total=len(dataset_query_embeds) // query_batch_size, desc="retrieve local cases"):
+ cur_cuda_query_embeds = dataset_query_embeds[start_idx: start_idx + query_batch_size].cuda()
+ cur_batch_data = train_dataset.data[start_idx: start_idx + query_batch_size]
+ if fp16_query:
+ scores = torch.mm(cur_cuda_query_embeds, cur_cuda_key_memory.t())
+ else: # not use fp16_query to avoid overflow
+ scores = torch.mm(cur_cuda_query_embeds.float(), cur_cuda_key_memory.t().float())
+ topk = scores.topk(min(len(cur_key_memory), retrieve_topk), dim=1)
+ top_indices = topk.indices.tolist()
+ top_scores = topk.values.tolist()
+ for cur_example, cur_indices, cur_scores in zip(cur_batch_data, top_indices, top_scores):
+ cur_local_cases = []
+ cur_local_cases_scores = []
+ for case_idx, case_score in zip(cur_indices, cur_scores):
+ cur_case = cur_data_to_retrieve[case_idx]
+ cur_case_id = cur_case["id"]
+
+ if cur_case_id == cur_example["id"]:
+ continue # exclude itself
+ if "ban_to_retrieve_list" in cur_example:
+ if cur_case_id in cur_example["ban_to_retrieve_list"]:
+ continue  # the retrieved case is banned for cur_example
+
+ cur_local_cases.append(cur_case_id)
+ cur_local_cases_scores.append(case_score)
+
+ cur_local_cases = cur_local_cases[:local_size]
+ assert len(cur_local_cases) == local_size
+ cur_example["local_cases"] += cur_local_cases
+ cur_example["local_cases_scores"] = cur_local_cases_scores
+
+ del cur_cuda_key_memory
+ torch.cuda.empty_cache()
+
+ for example in train_dataset.data: # rank to select local-cases
+ sorted_cases_with_scores = sorted(zip(example["local_cases"], example["local_cases_scores"]),
+ key=lambda x: x[1], reverse=True)[:local_size]
+ example["local_cases"] = [sc[0] for sc in sorted_cases_with_scores]
+ example["local_cases_scores"] = [sc[1] for sc in sorted_cases_with_scores]
+
+
+# not used in QA
+@torch.no_grad()
+def update_batch_retrieve_from_local(model, batch, args):
+ query_input_ids = batch["input_ids"]
+ query_input_attention_mask = batch["attention_mask"]
+ embed_dict = model.CAT_embed_q(
+ input_ids=query_input_ids,
+ attention_mask=query_input_attention_mask,
+ compute_key=True, compute_value=False
+ )
+ use_normed_query = True if args.key_matching_weight > 0.0 else False
+ query_embeds = embed_dict["normed_key_embeds"] if use_normed_query else embed_dict["key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method)
+ cur_bs = query_input_ids.shape[0]
+
+ squeezed_local_cases_input_ids = batch.pop("squeezed_local_cases_input_ids")
+ squeezed_local_cases_attention_mask = batch.pop("squeezed_local_cases_attention_mask")
+ squeezed_local_cases_target_as_input_ids = batch.pop("squeezed_local_cases_target_as_input_ids")
+ squeezed_local_cases_target_as_input_attention_mask = batch.pop(
+ "squeezed_local_cases_target_as_input_attention_mask")
+ embed_dict = model.wrapped_embed_kv(
+ separate_task=args.separate_task, compute_key=True, compute_value=False,
+ key_input_ids=squeezed_local_cases_input_ids, key_attention_mask=squeezed_local_cases_attention_mask,
+ )
+ squeezed_local_cases_key = embed_dict["normed_key_embeds"] if use_normed_query else embed_dict["key_embeds"]
+ squeezed_local_cases_key = reduce_query_or_key_embeds(squeezed_local_cases_key, args.key_reduce_method)
+ local_cases_key = squeezed_local_cases_key.view(cur_bs, args.local_size, -1)
+ scores = torch.bmm(query_embeds.unsqueeze(dim=1), local_cases_key.transpose(2, 1)).squeeze(dim=1)
+ scores_topk = scores.topk(args.num_values, dim=1)
+ retrieved_indices = scores_topk.indices
+ gathered_indices = retrieved_indices.unsqueeze(dim=-1)
+
+ local_cases_input_ids = squeezed_local_cases_input_ids.view(cur_bs, args.local_size, -1)
+ local_cases_attention_mask = squeezed_local_cases_attention_mask.view(cur_bs, args.local_size, -1)
+ local_cases_target_as_input_ids = squeezed_local_cases_target_as_input_ids.view(cur_bs, args.local_size, -1)
+ local_cases_target_as_input_attention_mask = squeezed_local_cases_target_as_input_attention_mask. \
+ view(cur_bs, args.local_size, -1)
+ key_input_ids = torch.gather(local_cases_input_ids, 1, gathered_indices.
+ repeat(1, 1, local_cases_input_ids.shape[-1]))
+ key_attention_mask = torch.gather(local_cases_attention_mask, 1, gathered_indices.
+ repeat(1, 1, local_cases_attention_mask.shape[-1]))
+ value_input_ids = torch.gather(local_cases_target_as_input_ids, 1, gathered_indices.
+ repeat(1, 1, local_cases_target_as_input_ids.shape[-1]))
+ value_attention_mask = torch.gather(local_cases_target_as_input_attention_mask, 1, gathered_indices.
+ repeat(1, 1, local_cases_target_as_input_attention_mask.shape[-1]))
+ batch.update({"group_key_input_ids": key_input_ids, "group_key_attention_mask": key_attention_mask,
+ "group_value_input_ids": value_input_ids, "group_value_attention_mask": value_attention_mask})
+
+ if args.key_matching_weight > 0.0:
+ local_cases_label_ids = batch.pop("local_cases_label_ids")
+ retrieved_cases_label = torch.gather(local_cases_label_ids, 1, retrieved_indices)
+ squeezed_retrieved_cases_label = retrieved_cases_label.view(-1)
+ label_ids = batch["label_ids"]
+ matching_mask = []
+ matching_target = []
+ for bid, (cur_target_label, cur_retrieved_cases_label) in enumerate(zip(label_ids, retrieved_cases_label)):
+ cur_mask = torch.ones_like(squeezed_retrieved_cases_label)
+ cur_mask[squeezed_retrieved_cases_label == cur_target_label] = 0
+ matched_pos = (cur_retrieved_cases_label == cur_target_label).nonzero().view(-1)
+ if len(matched_pos) > 0:
+ cur_pos_idx = matched_pos[0] + len(cur_retrieved_cases_label) * bid
+ matching_target.append(cur_pos_idx)
+ cur_mask[cur_pos_idx] = 1
+ else:
+ matching_target.append(-100)
+ matching_mask.append(cur_mask)
+
+ batch.update({"matching_target": torch.tensor(matching_target).to(model.device),
+ "matching_mask": torch.stack(matching_mask).to(model.device)})
+
+
+# not used in QA
+@torch.no_grad()
+def retrieve_from_key_memory(model, tokenizer, dataset, key_memory: List[torch.tensor],
+ data_to_retrieve_list: List[List[dict]], args):
+ model.eval()
+ dataset_query_embeds = build_dataset_query(model, tokenizer, dataset, args)
+ query_batch_size = 512
+ for example in dataset.data: # clear previous local_cases and their scores
+ example["retrieved_cases"] = []
+ example["retrieved_cases_scores"] = []
+ for cur_key_memory, cur_data_to_retrieve in zip(key_memory, data_to_retrieve_list):
+ cur_cuda_key_memory = cur_key_memory.cuda()
+ for start_idx in tqdm(range(0, len(dataset_query_embeds), query_batch_size),
+ total=len(dataset_query_embeds) // query_batch_size, desc="query_key_memory"):
+ cur_cuda_query_embeds = dataset_query_embeds[start_idx: start_idx + query_batch_size].cuda()
+ cur_batch_data = dataset.data[start_idx: start_idx + query_batch_size]
+ scores = torch.mm(cur_cuda_query_embeds, cur_cuda_key_memory.t())
+ topk = scores.topk(args.num_values, dim=1)
+ top_indices = topk.indices.tolist()
+ top_scores = topk.values.tolist()
+ for cur_example, cur_indices, cur_scores in zip(cur_batch_data, top_indices, top_scores):
+ cur_local_cases = [cur_data_to_retrieve[case_idx] for case_idx in cur_indices]
+ cur_local_cases_scores = cur_scores
+ cur_example["retrieved_cases"] += cur_local_cases
+ cur_example["retrieved_cases_scores"] = cur_local_cases_scores
+ del cur_cuda_key_memory
+ torch.cuda.empty_cache()
+ for example in dataset.data:
+ sorted_cases_with_scores = sorted(zip(example["retrieved_cases"], example["retrieved_cases_scores"]),
+ key=lambda x: x[1], reverse=True)[:args.num_values]
+ retrieved_cases = [sc[0] for sc in sorted_cases_with_scores]
+ for item in retrieved_cases:
+ item.update(dataset.get_base_input(item))
+ retrieved_key_seqs = [case["input_ids"] for case in retrieved_cases]
+ retrieved_value_seqs = [case["target_as_input_ids"] for case in retrieved_cases]
+ example.update({"retrieved_key_seqs": retrieved_key_seqs, "retrieved_value_seqs": retrieved_value_seqs})
+ # example["retrieved_cases_scores"] = [sc[1] for sc in sorted_cases_with_scores]
+
+
+# used in QA when not ranking from existing local QAs (i.e., when only_rank_exists_local_qa is not set)
+@torch.no_grad()
+def update_local_qas_to_retrieve(args, train_dataset, qas_to_retrieve, model, key_memory: List[torch.tensor],
+ normed_answer_of_qas_to_ret, train_data_query_embeds=None, build_mem_batch_size=1024,
+ query_batch_size=128, local_size=1024, pos_from_top=50, neg_from_top=200,
+ use_retrieval_adapter=False):
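+ # retrieve a pool of local QAs for every training example from the full key memory and mark local positives/negatives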
+ model.eval()
+ assert type(key_memory) == list
+
+ logger.info(f"Prepare local QAs for each example to retrieve.")
+ all_local_qas = []
+ all_local_positive = []
+ all_local_negative = []
+ all_ret_qas = []
+
+ if train_data_query_embeds is None:
+ if use_retrieval_adapter:
+ dim = args.adapter_out_dim
+ else:
+ dim = model.model_dim
+ train_data_query_embeds = torch.zeros((len(train_dataset.data), dim), device='cpu', dtype=torch.float16)
+ if args.qa_data_name == "tq":
+ build_query_batch_size = 256
+ else:
+ build_query_batch_size = build_mem_batch_size
+ embed_query_dataloader = train_dataset.get_query_dataloader(batch_size=build_query_batch_size,
+ shuffle=False, num_workers=1)
+ start_idx = 0
+ for query_inputs in tqdm(embed_query_dataloader, total=len(embed_query_dataloader), desc="Embed queries."):
+ end_idx = start_idx + len(query_inputs["query_input_ids"])
+ embed_dict = model.CAT_embed_q(
+ input_ids=query_inputs["query_input_ids"].to(model.device),
+ attention_mask=query_inputs["query_attention_mask"].to(model.device),
+ compute_key=True, compute_value=False
+ )
+ query_embeds = embed_dict["normed_key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method)
+ if use_retrieval_adapter:
+ query_embeds = model.adapter(query_embeds)
+ query_embeds = query_embeds.half().cpu()
+ train_data_query_embeds[start_idx: end_idx] = query_embeds
+ start_idx = end_idx
+ assert start_idx == len(train_dataset.data)
+
+ torch.cuda.empty_cache()
+
+ key_nums = sum(len(k) for k in key_memory)
+ logger.info(f"key-memory seg-num: {len(key_memory)}. all key nums: {key_nums}.")
+
+ for start_idx in tqdm(range(0, len(train_dataset.data), query_batch_size),
+ total=len(train_dataset.data) // query_batch_size + 1):
+ cur_cuda_query_embeds = train_data_query_embeds[start_idx: start_idx + query_batch_size].cuda()
+ if key_nums > 10000000:
+ # if scale is large: calculate topk in each chunk -> combine all-topk -> select final topk
+ chunk_top_scores = []
+ chunk_top_indices = []
+ idx_shift = 0
+ for ckm_idx, chunk_key_memory in enumerate(key_memory):
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_topk = torch.mm(cur_cuda_query_embeds, chunk_key_memory_cuda.t()).topk(local_size, dim=1)
+ chunk_top_scores.append(chunk_topk.values)  # chunk_topk.values: [query_batch, local_size]
+ chunk_top_indices.append(chunk_topk.indices + idx_shift)
+ idx_shift += len(chunk_key_memory)
+ del chunk_key_memory_cuda
+ torch.cuda.empty_cache()
+ chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n
+ chunk_top_indices = torch.cat(chunk_top_indices, dim=1).tolist() # q_batch, local_size*seg_n
+ topk = chunk_top_scores.topk(local_size, dim=1) # q_batch, local_size
+ top_indices_indices = topk.indices.tolist()
+ top_indices = []
+ for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices):
+ top_indices.append([cur_indices[idx] for idx in cur_indices_indices])
+ else:
+ # if scale is moderate: calculate score in each chunk -> combine score -> select topk
+ all_chunk_scores = []
+ for chunk_key_memory in key_memory:
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_scores = torch.mm(cur_cuda_query_embeds, chunk_key_memory_cuda.t())
+ all_chunk_scores.append(chunk_scores) # q_batch, chunk_size
+ del chunk_key_memory_cuda
+ torch.cuda.empty_cache()
+ scores = torch.cat(all_chunk_scores, dim=1).cuda() # q_batch, key_memory_size
+ topk = scores.topk(local_size, dim=1)
+ top_indices = topk.indices.tolist()
+
+ batch = train_dataset.data[start_idx: start_idx + query_batch_size]
+ for cur_example, cur_indices in zip(batch, top_indices):
+ local_positive, local_negative = [], []
+ # cur_target = [normalize_answer(ans) for ans in cur_example["answer"]]
+ cur_target = [na for na in cur_example["normalized_answer"]]
+
+ cur_ret_qas = []
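+ # a retrieved QA counts as positive if its normalized answer matches the example's answers and it ranks within pos_from_top;
+ # QAs with non-matching answers become negatives (from the top neg_from_top, or to fill up to negatives_num_each_example)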
+ for top_idx, qa_idx in enumerate(cur_indices):
+ if normed_answer_of_qas_to_ret[qa_idx] in cur_target:
+ if top_idx < pos_from_top:
+ local_positive.append(qa_idx)
+ else:
+ if top_idx < neg_from_top:
+ local_negative.append(qa_idx)
+ elif len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative
+ # if len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative
+ local_negative.append(qa_idx)
+ cur_ret_qas.append(qas_to_retrieve[qa_idx])
+
+ all_ret_qas.append(cur_ret_qas)
+
+ all_local_positive.append(local_positive)
+ all_local_negative.append(local_negative)
+
+ all_local_qas += top_indices
+ del cur_cuda_query_embeds
+ del topk
+
+ torch.cuda.empty_cache()
+ assert len(all_ret_qas) == len(train_dataset.data)
+ assert len(all_local_qas) == len(train_dataset.data) == len(all_local_positive) == len(all_local_negative)
+ matching_metric = eval_retriever(train_dataset.data, all_ret_qas, "1,2,3,4,5")
+ for k, v in matching_metric.items():
+ logging.info({f"local_qas initial {k}": v})
+ for i in range(len(train_dataset.data)):
+ train_dataset.data[i]["local_positive"] = all_local_positive[i]
+ train_dataset.data[i]["local_negative"] = all_local_negative[i]
+ train_dataset.data[i]["local_qas"] = all_local_qas[i]
+
+ logger.info(f"Local QAs updated.")
+
+
+@torch.no_grad()
+def update_batch_inputs(args, batch, model, use_adapter_to_select_positive=False):
+ model.eval()
+ embed_dict = model.CAT_embed_q(
+ input_ids=batch["query_input_ids"],
+ attention_mask=batch["query_attention_mask"],
+ compute_key=True, compute_value=False
+ )
+ query_embeds = embed_dict["normed_key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method)
+ if use_adapter_to_select_positive:
+ query_embeds = model.adapter(query_embeds)
+ batch_size, hidden_size = query_embeds.shape
+
+ local_positive_inputs_keys = [k for k in batch.keys() if k.startswith("local_positive_inputs_")]
+ local_positive_inputs = {k.replace("local_positive_inputs_", ""):
+ batch.pop(k).view(batch_size * args.batch_local_positive_num, -1)
+ for k in local_positive_inputs_keys}
+ embed_dict = model.wrapped_embed_kv(separate_task=args.separate_task, compute_key=True,
+ compute_value=False, **local_positive_inputs)
+ local_positive_key_embeds = embed_dict["normed_key_embeds"]
+ local_positive_key_embeds = reduce_query_or_key_embeds(local_positive_key_embeds, args.key_reduce_method)
+ if use_adapter_to_select_positive:
+ local_positive_key_embeds = model.adapter(local_positive_key_embeds)
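+ # score each candidate local positive against the query, down-weight padded slots via local_positive_qas_mask,
+ # and sample one positive per example (softmax sampling)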
+ scores = torch.bmm(query_embeds.unsqueeze(dim=1), local_positive_key_embeds.view(
+ batch_size, args.batch_local_positive_num, hidden_size).transpose(2, 1)).squeeze(dim=1)
+ scores = scores + (batch["local_positive_qas_mask"] - 1) * 1e-4
+ scores = torch.softmax(scores, dim=1)
+
+ sampled_pos_local_idx = torch.multinomial(scores, 1).squeeze(dim=-1) # [batch_size]
+
+ # sampled_local_idx: [batch_size]
+ # local_positive_inputs[key_input_ids/key_attention_mask]: [batch_size*max_pos_num, seq_length]
+ all_pos_key_input_ids = local_positive_inputs["key_input_ids"].view(batch_size, args.batch_local_positive_num, -1)
+ all_pos_key_attention_mask = local_positive_inputs["key_attention_mask"].view(all_pos_key_input_ids.shape)
+ positive_gather_indices = sampled_pos_local_idx.unsqueeze(dim=1). \
+ repeat(1, all_pos_key_input_ids.shape[-1]).unsqueeze(dim=1)
+ positive_key_input_ids = torch.gather(all_pos_key_input_ids, 1, positive_gather_indices).squeeze(dim=1)
+ positive_key_attention_mask = torch.gather(all_pos_key_attention_mask, 1, positive_gather_indices).squeeze(dim=1)
+ # positive_key_input_ids: [batch_size, seq_len]
+
+ if "labels_to_select" in batch:
+ labels_to_select = batch.pop("labels_to_select") # batch, args.batch_local_positive_num, seq_len
+ target_gather_indices = sampled_pos_local_idx.unsqueeze(dim=1). \
+ repeat(1, labels_to_select.shape[-1]).unsqueeze(dim=1)
+ labels = torch.gather(labels_to_select, 1, target_gather_indices).squeeze(dim=1)
+ batch.update({"labels": labels})
+
+ # if args.num_values > 1 or args.use_not_exactly_true:
+ local_mixed_inputs_keys = [k for k in batch.keys() if k.startswith("local_mixed_inputs_")]
+ local_mixed_inputs = {k.replace("local_mixed_inputs_", ""): batch.pop(k) for k in local_mixed_inputs_keys}
+ mixed_qas_num = args.negatives_num_each_example
+ embed_dict = model.wrapped_embed_kv(separate_task=args.separate_task, compute_key=True,
+ compute_value=True, **{k: v.view(batch_size * mixed_qas_num, -1)
+ for k, v in local_mixed_inputs.items()})
+ mixed_key_embeds = embed_dict["normed_key_embeds"]
+ mixed_key_embeds = reduce_query_or_key_embeds(mixed_key_embeds, args.key_reduce_method)
+ if use_adapter_to_select_positive:
+ mixed_key_embeds = model.adapter(mixed_key_embeds)
+ mixed_key_embeds = mixed_key_embeds.view(batch_size, mixed_qas_num, -1)
+ scores = torch.bmm(query_embeds.unsqueeze(dim=1), mixed_key_embeds.transpose(2, 1)).squeeze(dim=1)
+ # assert args.values_with_order is True
+ # if w/o order, shuffle the group_value_qas_indices of each example
+ group_value_qas_indices = scores.topk(args.num_values, dim=1).indices # [batch_size, num_values]
+ if not args.values_with_order:
+ group_value_qas_indices = group_value_qas_indices.tolist()
+ for value_qas_indices in group_value_qas_indices:
+ random.shuffle(value_qas_indices)
+ group_value_qas_indices = torch.tensor(group_value_qas_indices).to(scores.device)
+
+ # assert args.num_values > 1
+ # if args.num_values == 1, the input is from the sampled_pos_local_idx.
+
+ # if args.num_values > 1:
+ mixed_key_input_ids = local_mixed_inputs["key_input_ids"] # [batch_size, mixed_qas_num, seq_len]
+ mixed_key_attention_mask = local_mixed_inputs["key_attention_mask"]
+ mixed_value_input_ids = local_mixed_inputs["value_input_ids"]
+ mixed_value_attention_mask = local_mixed_inputs["value_attention_mask"]
+ key_gather_indices = group_value_qas_indices.unsqueeze(dim=-1).repeat(1, 1, mixed_key_input_ids.shape[-1])
+ group_key_input_ids = torch.gather(mixed_key_input_ids, 1, key_gather_indices)
+ group_key_attention_mask = torch.gather(mixed_key_attention_mask, 1, key_gather_indices)
+ value_gather_indices = group_value_qas_indices.unsqueeze(dim=-1).repeat(1, 1, mixed_value_input_ids.shape[-1])
+ group_value_input_ids = torch.gather(mixed_value_input_ids, 1, value_gather_indices)
+ group_value_attention_mask = torch.gather(mixed_value_attention_mask, 1, value_gather_indices)
+
+ # matching_targets
+ matching_targets = torch.arange(batch_size).to(model.device)
+ matching_targets[batch.pop("local_positive_num") == 0] = -100
+
+ batch.update({
+ "positive_kv_inputs": {
+ "key_input_ids": positive_key_input_ids,
+ "key_attention_mask": positive_key_attention_mask
+ },
+ "negative_kv_inputs": {
+ "key_input_ids": batch.pop("local_negative_inputs_key_input_ids")
+ .view(batch_size * args.negatives_num_each_example, -1),
+ "key_attention_mask": batch.pop("local_negative_inputs_key_attention_mask")
+ .view(batch_size * args.negatives_num_each_example, -1),
+ },
+ "matching_targets": matching_targets,
+ "group_value_inputs": {
+ "key_input_ids": group_key_input_ids.view(batch_size * args.num_values, -1),
+ "key_attention_mask": group_key_attention_mask.view(batch_size * args.num_values, -1),
+ "value_input_ids": group_value_input_ids.view(batch_size * args.num_values, -1),
+ "value_attention_mask": group_value_attention_mask.view(batch_size * args.num_values, -1),
+ },
+ })
+
+
+# does not re-retrieve local QAs; it only re-ranks the existing local QAs and prepares positive/negative QAs from them
+@torch.no_grad()
+def rank_exist_local_qas(args, train_dataset: QADataset, qas_to_retrieve, model, normed_answer_of_qas_to_ret,
+ train_data_query_embeds=None, build_mem_batch_size=1024,
+ embed_local_qas_batch_size=6, local_size=1024, pos_from_top=50, neg_from_top=200,
+ accelerator=None):
+ model.eval()
+ if args.use_fp16_rank:
+ half_model = copy.deepcopy(model)
+ half_model.eval()
+ model = half_model.half()
+ embed_local_qas_batch_size = int(embed_local_qas_batch_size * 1.5)
+ logger.info(f"Rank local QAs for each example to retrieve. embed_local_qas_batch_size={embed_local_qas_batch_size}")
+
+ if train_data_query_embeds is None:
+ train_data_query_embeds = torch.zeros((len(train_dataset.data), model.model_dim), device='cpu',
+ dtype=torch.float16)
+
+ embed_query_dataloader = train_dataset.get_query_dataloader(batch_size=build_mem_batch_size,
+ shuffle=False, num_workers=5)
+ start_idx = 0
+ for query_inputs in tqdm(embed_query_dataloader, total=len(embed_query_dataloader), desc="Embed queries."):
+ end_idx = start_idx + len(query_inputs["query_input_ids"])
+ embed_dict = model.CAT_embed_q(
+ input_ids=query_inputs["query_input_ids"].to(model.device),
+ attention_mask=query_inputs["query_attention_mask"].to(model.device),
+ compute_key=True, compute_value=False
+ )
+ query_embeds = embed_dict["normed_key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method)
+ query_embeds = query_embeds.half().cpu()
+ train_data_query_embeds[start_idx: end_idx] = query_embeds
+ start_idx = end_idx
+ assert start_idx == len(train_dataset.data)
+
+ torch.cuda.empty_cache()
+
+ embed_local_dataloader = train_dataset.get_local_qas_dataloader(batch_size=embed_local_qas_batch_size,
+ shuffle=False, num_workers=10)
+ if accelerator is not None:
+ embed_local_dataloader = accelerator.prepare(embed_local_dataloader)
+
+ all_ret_qas = []
+
+ start_idx = 0
+ for local_qas_batch in tqdm(embed_local_dataloader):
+ query_ids = local_qas_batch.pop("query_ids")
+ bs = len(query_ids)
+ cur_query_embeds = train_data_query_embeds[start_idx: start_idx + bs]
+ cur_batch = train_dataset.data[start_idx: start_idx + bs]
+ start_idx = start_idx + bs
+
+ embed_dict = model.wrapped_embed_kv(
+ separate_task=args.separate_task, compute_key=True, compute_value=False,
+ **local_qas_batch
+ )
+ squeezed_local_key_embeds = embed_dict["normed_key_embeds"]
+ squeezed_local_key_embeds = reduce_query_or_key_embeds(squeezed_local_key_embeds, args.key_reduce_method)
+ local_key_embeds = squeezed_local_key_embeds.view(bs, -1, model.model_dim).half()
+ # the number of existing local QAs per example should be at least the requested local_size
+ assert local_size <= local_key_embeds.shape[1]
+
+ cur_query_embeds = cur_query_embeds.cuda()
+ # [bs, 1, hidden] [bs, hidden, exists-local-qas-num] --> [bs, exists-local-qas-num]
+ scores = torch.bmm(cur_query_embeds.unsqueeze(dim=1), local_key_embeds.transpose(2, 1)).squeeze(dim=1)
+ top_local_indices = scores.topk(local_size, dim=1).indices
+
+ for cur_example, cur_indices in zip(cur_batch, top_local_indices):
+ local_positive, local_negative = [], []
+ cur_target = [normalize_answer(ans) for ans in cur_example["answer"]]
+
+ qas_ids_of_top_local = [cur_example["local_qas"][local_idx] for local_idx in cur_indices]
+ # qas_of_top_local = [qas_to_retrieve[qid] for qid in qas_ids_of_top_local]
+ all_ret_qas.append([qas_to_retrieve[qid] for qid in qas_ids_of_top_local[:50]])
+
+ for top_idx, qa_idx in enumerate(qas_ids_of_top_local):
+ if normed_answer_of_qas_to_ret[qa_idx] in cur_target:
+ if top_idx < pos_from_top:
+ local_positive.append(qa_idx)
+ else:
+ if top_idx < neg_from_top:
+ local_negative.append(qa_idx)
+ elif len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative
+ local_negative.append(qa_idx)
+
+ cur_example["local_positive"] = local_positive
+ cur_example["local_negative"] = local_negative
+
+ assert len(all_ret_qas) == len(train_dataset.data)
+ matching_metric = eval_retriever(train_dataset.data, all_ret_qas, "1,2,3,4,5,10,50")
+ for k, v in matching_metric.items():
+ logging.info({f"local_qas initial {k}": v})
+
+ if args.use_fp16_rank:
+ del model
+ torch.cuda.empty_cache()
+ logger.info(f"Local QAs ranked.")
+
+
+# Dialog
+@torch.no_grad()
+def update_dialog_local_qas_to_retrieve(args, train_dataset: DialogDataset, qas_to_retrieve, model,
+ key_memory: List[torch.tensor], normed_answer_of_qas_to_ret,
+ train_data_query_embeds=None, build_mem_batch_size=1024,
+ query_batch_size=128, local_size=1024, pos_from_top=50,
+ neg_from_top=200):
+ model.eval()
+ assert type(key_memory) == list
+
+ logger.info(f"Prepare local QAs for each example to retrieve.")
+ all_local_qas = []
+ all_local_positive = []
+ all_local_negative = []
+ all_ret_qas = []
+
+ if train_data_query_embeds is None:
+ train_data_query_embeds = torch.zeros((len(train_dataset.data), model.model_dim), device='cpu',
+ dtype=torch.float16)
+ embed_query_dataloader = train_dataset.get_query_dataloader(batch_size=query_batch_size, shuffle=False,
+ num_workers=1, )
+ start_idx = 0
+ for query_inputs in tqdm(embed_query_dataloader, total=len(embed_query_dataloader), desc="Embed queries."):
+ end_idx = start_idx + len(query_inputs["query_input_ids"])
+ embed_dict = model.CAT_embed_q(
+ input_ids=query_inputs["query_input_ids"].to(model.device),
+ attention_mask=query_inputs["query_attention_mask"].to(model.device),
+ compute_key=True, compute_value=False
+ )
+ query_embeds = embed_dict["normed_key_embeds"]
+ query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method)
+ query_embeds = query_embeds.half().cpu()
+ train_data_query_embeds[start_idx: end_idx] = query_embeds
+ start_idx = end_idx
+ assert start_idx == len(train_dataset.data)
+
+ torch.cuda.empty_cache()
+
+ key_nums = sum(len(k) for k in key_memory)
+ logger.info(f"key-memory seg-num: {len(key_memory)}. all key nums: {key_nums}.")
+ query_batch_size = 2000
+ for start_idx in tqdm(range(0, len(train_dataset.data), query_batch_size),
+ total=len(train_dataset.data) // query_batch_size + 1):
+ cur_cuda_query_embeds = train_data_query_embeds[start_idx: start_idx + query_batch_size].cuda()
+
+ # calculate topk in each chunk -> combine all-topk -> select final topk
+ chunk_top_scores = []
+ chunk_top_indices = []
+ idx_shift = 0
+ for ckm_idx, chunk_key_memory in enumerate(key_memory):
+ chunk_key_memory_cuda = chunk_key_memory.cuda()
+ chunk_topk = torch.mm(cur_cuda_query_embeds, chunk_key_memory_cuda.t()).topk(local_size, dim=1)
+ chunk_top_scores.append(chunk_topk.values)  # chunk_topk.values: [query_batch, local_size]
+ chunk_top_indices.append(chunk_topk.indices + idx_shift)
+ idx_shift += len(chunk_key_memory)
+ del chunk_key_memory_cuda
+ torch.cuda.empty_cache()
+ chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n
+ chunk_top_indices = torch.cat(chunk_top_indices, dim=1).tolist() # q_batch, local_size*seg_n
+ topk = chunk_top_scores.topk(local_size, dim=1) # q_batch, local_size
+ top_indices_indices = topk.indices.tolist()
+ top_indices = []
+ for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices):
+ top_indices.append([cur_indices[idx] for idx in cur_indices_indices])
+
+ batch = train_dataset.data[start_idx: start_idx + query_batch_size]
+ for cur_example, cur_indices in zip(batch, top_indices):
+ local_positive, local_negative = [], []
+
+ cur_target_words = cur_example["normalized_response_remove_stop_words_list"] # a list of words
+
+ cur_ret_qas = []
+ for top_idx, qa_idx in enumerate(cur_indices):
+ if normed_answer_of_qas_to_ret[qa_idx] in cur_target_words:
+ # if normed_answer_of_qas_to_ret[qa_idx] in cur_target: # the QA's answer overlaps with response
+ if top_idx < pos_from_top:
+ local_positive.append(qa_idx)
+ else:
+ if top_idx < neg_from_top:
+ local_negative.append(qa_idx)
+ elif len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative
+ # if len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative
+ local_negative.append(qa_idx)
+ cur_ret_qas.append(qas_to_retrieve[qa_idx])
+
+ all_ret_qas.append(cur_ret_qas)
+
+ all_local_positive.append(local_positive)
+ all_local_negative.append(local_negative)
+
+ all_local_qas += top_indices
+ del cur_cuda_query_embeds
+ del topk
+
+ torch.cuda.empty_cache()
+ assert len(all_ret_qas) == len(train_dataset.data)
+ assert len(all_local_qas) == len(train_dataset.data) == len(all_local_positive) == len(all_local_negative)
+ for i in range(len(train_dataset.data)):
+ train_dataset.data[i]["local_positive"] = all_local_positive[i]
+ train_dataset.data[i]["local_negative"] = all_local_negative[i]
+ train_dataset.data[i]["local_qas"] = all_local_qas[i]
+
+ logger.info(f"Local QAs updated.")
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..da0eb60
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,774 @@
+import argparse
+import json
+import logging
+import os
+import pickle
+import random
+
+import functools
+import numpy as np
+import torch
+from emat.evaluation.eval_retriever import eval_retriever
+
+from emat.evaluation.exact_match import normalize_answer
+from typing import List
+from emat.utils import verbalise_qa
+from copy import deepcopy
+from transformers import MODEL_MAPPING, T5Config, T5Tokenizer, AutoConfig, AutoTokenizer, AutoModelForMultipleChoice
+from tqdm.auto import tqdm
+from itertools import chain
+
+MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+def reduce_query_or_key_embeds(qk_embeds, key_reduce_method):
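+ # reduce [batch, key_nums, hidden] key/query embeddings to a single vector per example (concat, average, or sum)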
+ batch_size, key_nums, hidden_size = qk_embeds.shape
+ if key_reduce_method == "concat":
+ reduced_embeds = qk_embeds.view(batch_size, key_nums * hidden_size)
+ elif key_reduce_method == "avg":
+ reduced_embeds = qk_embeds.sum(dim=1) / key_nums
+ elif key_reduce_method == "sum":
+ reduced_embeds = qk_embeds.sum(dim=1)
+ else:
+ raise NotImplementedError(f"Reduce method ``{key_reduce_method}`` is not defined.")
+
+ return reduced_embeds
+
+
+def load_reranker(model_name_or_path="./data/models/rerankers/reranker_multi_xxlarge"):
+ logging.info(f'Loading rerank model from: {model_name_or_path}')
+ config = AutoConfig.from_pretrained(model_name_or_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)
+ model = AutoModelForMultipleChoice.from_pretrained(
+ model_name_or_path,
+ from_tf=bool(".ckpt" in model_name_or_path),
+ config=config,
+ )
+ model = model.eval()
+ return model, tokenizer
+
+
+def get_key_value_ae_target(qas, tokenizer, key_ae_target, value_ae_target, max_target_length, prefix=""):
+ if value_ae_target == "ans":
+ targets_of_value = [qa["answer"][0] for qa in qas]
+ else: # question_ans
+ targets_of_value = [f'question: {qa["question"]} answer: {qa["answer"][0]}' for qa in qas]
+ if key_ae_target == "question_ans":
+ targets_of_key = [f'question: {qa["question"]} answer: {qa["answer"][0]}' for qa in qas]
+ else: # question
+ targets_of_key = [prefix + qa["question"] for qa in qas]
+
+ with tokenizer.as_target_tokenizer(): # setup the tokenizer for targets
+ key_labels = tokenizer(targets_of_key, max_length=max_target_length,
+ padding=True, truncation=True, return_tensors="pt")
+ value_labels = tokenizer(targets_of_value, max_length=max_target_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"key_labels_input_ids": process_labels(key_labels, tokenizer),
+ "value_labels_input_ids": process_labels(value_labels, tokenizer)}
+
+
+def get_key_value_encoder_inputs(qas, separate_task, tokenizer, max_source_length,
+ prefix="", only_return_key_inputs=False, value_input_is_qa=False):
+ # Used to get the input of Key-Value Encoder, qas are from PAQ-L1
+ if separate_task:
+ key_inputs = ["question: " + qa["question"] for qa in qas]
+ key_inputs = tokenizer(key_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ if only_return_key_inputs:
+ return {"key_input_ids": key_inputs["input_ids"],
+ "key_attention_mask": key_inputs["attention_mask"]}
+ else:
+ if value_input_is_qa:
+ value_inputs = [f'question: {qa["question"]} answer: {qa["answer"][0]}' for qa in qas]
+ value_inputs = tokenizer(value_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ else:
+ value_inputs = ["answer: " + qa["answer"][0] for qa in qas]
+ value_inputs = tokenizer(value_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"key_input_ids": key_inputs["input_ids"],
+ "key_attention_mask": key_inputs["attention_mask"],
+ "value_input_ids": value_inputs["input_ids"],
+ "value_attention_mask": value_inputs["attention_mask"]}
+ else:
+ key_value_inputs = [prefix + verbalise_qa(qa["question"], qa["answer"][0]) for qa in qas]
+ key_value_inputs = tokenizer(key_value_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"key_value_input_ids": key_value_inputs["input_ids"],
+ "key_value_attention_mask": key_value_inputs["attention_mask"]}
+
+
+def get_nli_group_value_inputs(examples, tokenizer, max_source_length):
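+ """Flatten the per-example groups of retrieved key/value sequences and tokenise them for the
+ Key-Value Encoder."""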
+ group_key_inputs = [ex["retrieved_key_seqs"] for ex in examples]
+ group_key_inputs = list(chain(*group_key_inputs))
+ group_key_inputs = tokenizer(group_key_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ group_value_inputs = [ex["retrieved_value_seqs"] for ex in examples]
+ group_value_inputs = list(chain(*group_value_inputs))
+ group_value_inputs = tokenizer(group_value_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"key_input_ids": group_key_inputs["input_ids"],
+ "key_attention_mask": group_key_inputs["attention_mask"],
+ "value_input_ids": group_value_inputs["input_ids"],
+ "value_attention_mask": group_value_inputs["attention_mask"]}
+
+
+def get_query_encoder_inputs(qas, tokenizer, max_source_length, prefix=""):
+ # Used to get the input of Query Encoder, qas are from NaturalQuestion
+ query_inputs = [prefix + qa["question"] for qa in qas]
+ query_inputs = tokenizer(query_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"query_input_ids": query_inputs["input_ids"],
+ "query_attention_mask": query_inputs["attention_mask"]}
+
+
+def get_nli_input_seq(item):
+ return f"hypothesis: {item['hypothesis']} premise: {item['premise']}"
+
+
+def get_query_encoder_inputs_nli(cases, tokenizer, max_source_length):
+ query_inputs = [case["input_seq"] for case in cases]
+ query_inputs = tokenizer(query_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"query_input_ids": query_inputs["input_ids"],
+ "query_attention_mask": query_inputs["attention_mask"]}
+
+
+label2str = {0: "entailment", 1: "neutral", 2: "contradiction"}
+
+
+def get_key_value_encoder_inputs_nli(cases, tokenizer, max_source_length, only_return_key_inputs=False):
+ key_inputs = [case["input_seq"] for case in cases]
+ key_inputs = tokenizer(key_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ if only_return_key_inputs:
+ return {"key_input_ids": key_inputs["input_ids"],
+ "key_attention_mask": key_inputs["attention_mask"]}
+ else:
+ value_inputs = [label2str[case["label"]] for case in cases]
+ value_inputs = tokenizer(value_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ return {"key_input_ids": key_inputs["input_ids"],
+ "key_attention_mask": key_inputs["attention_mask"],
+ "value_input_ids": value_inputs["input_ids"],
+ "value_attention_mask": value_inputs["attention_mask"]}
+
+
+def get_qa_inputs(qas, tokenizer, max_source_length, max_target_length, prefix="",
+ targets=None):
+ # Used to get the normal inputs of QA, qas are from NaturalQuestion
+ # Normal inputs and outputs
+ model_inputs = [prefix + qa["question"] for qa in qas]
+ model_inputs = tokenizer(model_inputs, max_length=max_source_length,
+ padding=True, truncation=True, return_tensors="pt")
+ if targets is None:
+ targets = [qa["answer"][0] for qa in qas]
+ elif targets == "$random$":
+ targets = [random.choice(qa["answer"]) for qa in qas]
+ with tokenizer.as_target_tokenizer():
+ targets = tokenizer(targets, max_length=max_target_length,
+ padding=True, truncation=True, return_tensors="pt")
+ model_inputs["labels"] = process_labels(targets, tokenizer)
+ return model_inputs
+
+
+def process_labels(labels, tokenizer, label_pad_token_id=-100, pad_to_multiple_of=None):
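+ """Replace padding token ids in the labels with ``label_pad_token_id`` (default -100) so they are
+ ignored by the loss, and optionally right-pad the label length to a multiple of ``pad_to_multiple_of``."""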
+ if getattr(labels, "input_ids", None) is not None:
+ input_ids = labels["input_ids"]
+ bsz, label_length = input_ids.size()
+ else:
+ input_ids = labels
+ bsz = len(input_ids)
+ label_length = len(input_ids[0])
+
+ # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
+ # padding in the loss.
+ # if args.ignore_pad_token_for_loss:
+ input_ids[input_ids == tokenizer.pad_token_id] = label_pad_token_id
+
+ if pad_to_multiple_of is not None:
+ max_label_length = (
+ (label_length + pad_to_multiple_of - 1) // pad_to_multiple_of * pad_to_multiple_of
+ )
+ remainder = max_label_length - label_length
+ if remainder > 0:
+ pad_ids = torch.full(
+ (bsz, remainder),
+ fill_value=label_pad_token_id,
+ dtype=input_ids.dtype,
+ device=input_ids.device
+ )
+ input_ids = torch.cat([input_ids, pad_ids], dim=1)
+
+ return input_ids
+
+
+def postprocess_text(preds, labels):
+ preds = [pred.strip() for pred in preds]
+ labels = [[label.strip()] for label in labels]
+ return preds, labels
+
+
+def load_model(model_class, args):
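+ """Load the T5 config, tokenizer and model. With ``--resume_training`` everything is restored from
+ ``<output_dir>/latest_ckpt``; otherwise the config is built from the base checkpoint plus the
+ command-line arguments, and missing key/value-specific parameters (separate KV encoder,
+ key_layer_norm) are initialised from the pretrained encoder weights."""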
+ # Load pretrained model and tokenizer
+ #
+ # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+ # download model & vocab.
+ if args.resume_training and args.output_dir is not None:
+ args.model_name_or_path = os.path.join(args.output_dir, "latest_ckpt")
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+ config = T5Config.from_pretrained(args.model_name_or_path)
+ model = model_class.from_pretrained(args.model_name_or_path)
+ return config, tokenizer, model
+
+ config = T5Config.from_pretrained(args.model_name_or_path)
+ update_CAT_config_from_args(config, args)
+ print(config)
+
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+ # model = model_class(config)
+ model, load_info = model_class.from_pretrained(args.model_name_or_path, config=config, output_loading_info=True)
+ state_dict = torch.load(os.path.join(args.model_name_or_path, "pytorch_model.bin"))
+ logging.info(f"model-load-info: {load_info}")
+
+ manually_initialized_params = []
+ if args.not_share_encoder and "kv_encoder.final_layer_norm.weight" not in state_dict.keys():
+ # "kv_encoder.final_layer_norm.weight not in state-dict" means the loaded model is share-encoder.
+ kv_encoder_state_dict = dict()
+ for k, v in state_dict.items():
+ if k in ['encoder.qk_scorer.bias', 'encoder.qk_scorer.weight']:
+ continue
+ if k.startswith("encoder."):
+ kv_encoder_state_dict[f"{k[len('encoder.'):]}"] = deepcopy(v)
+ manually_initialized_params.append(f"kv_{k}")
+ model.kv_encoder.load_state_dict(kv_encoder_state_dict, strict=True)
+ logging.info("Not share encoder, and initialize Key-Value encoder using CAT-encoder.")
+ else:
+ logging.info("Share the Key-Value encoder and CAT encoder")
+ if "encoder.key_layer_norm.weight" not in state_dict.keys():
+ logging.info("Initialize key_layer_norm parameters.")
+ key_layer_norm_state_dict = dict()
+ for k, v in state_dict.items():
+ if k.startswith("encoder.final_layer_norm."):
+ k = k.replace("encoder.final_layer_norm.", "")
+ key_layer_norm_state_dict[k] = deepcopy(v)
+ manually_initialized_params.append(f"encoder.key_layer_norm.{k}")
+ model.encoder.key_layer_norm.load_state_dict(key_layer_norm_state_dict)
+ logging.info(f"manually initialized parameters: {manually_initialized_params}")
+ # miss_keys = load_info["missing_keys"]
+ model.resize_token_embeddings(len(tokenizer))
+
+ if model.config.decoder_start_token_id is None:
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
+ assert config.value_fusion_method == args.value_fusion_method
+ return config, tokenizer, model
+
+
+def save_model(model, save_dir, accelerator=None, tokenizer=None, arguments=None):
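+ """Save the model to ``save_dir`` (unwrapping it through ``accelerator`` when one is given), plus the
+ tokenizer and, if provided, the run arguments as ``args.json``."""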
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir, exist_ok=True)
+
+ if accelerator is not None:
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save)
+ else:
+ model.save_pretrained(save_dir)
+
+ if tokenizer is not None:
+ if accelerator is None:
+ tokenizer.save_pretrained(save_dir)
+ elif accelerator.is_local_main_process:
+ tokenizer.save_pretrained(save_dir, save_function=accelerator.save)
+
+ if arguments is not None:
+ json.dump(vars(arguments), open(os.path.join(save_dir, "args.json"), "w"), indent=4)
+
+
+def update_CAT_config_from_args(config, args):
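+ """Fill the key-value memory fields of ``config`` (prefix length, key/value/cat layers, d_key,
+ num_values, encoder sharing, etc.): values already stored in a CAT config are kept where possible,
+ the rest are taken from the command-line arguments; ``value_fusion_method`` always comes from args."""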
+ # Key-value memory related config (restore from CAT configs if not specified)
+ config.prefix_length = getattr(config, "prefix_length", args.prefix_length)
+ config.key_layer = getattr(config, "key_layer", args.key_layer)
+
+ if args.d_key is None:
+ assert args.value_fusion_method is None or "+" in args.value_fusion_method
+ args.d_key = config.d_model * args.prefix_length
+
+ if args.key_encoder_type != "prefix":
+ assert args.d_key is not None
+
+ if args.d_key is None and getattr(config, "d_key", None) is None:
+ config.d_key = config.d_model
+ else:
+ config.d_key = getattr(config, "d_key", args.d_key)
+
+ config.key_encoder_type = getattr(config, "key_encoder_type", args.key_encoder_type)
+ config.value_layer = args.value_layer if args.value_layer is not None else config.value_layer
+ config.cat_layer = getattr(config, "cat_layer", args.cat_layer)
+ config.num_values = args.num_values if args.num_values is not None else config.num_values
+ config.use_two_prefix = getattr(config, "use_two_prefix", args.use_two_prefix)
+ config.not_share_encoder = args.not_share_encoder
+ # config.value_fusion_method = getattr(config, "value_fusion_method", args.value_fusion_method)
+ config.value_fusion_method = args.value_fusion_method
+
+ if args.adapter is not None:
+ config.adapter = args.adapter
+ config.adapter_out_dim = args.adapter_out_dim
+
+
+def find_positive_and_k_negative(top_indices, target_answers: List[List], qas_to_retrieve_from, k=8):
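+ """For each example, scan the retrieved ``top_indices``: an index is a positive if the normalised
+ answer of the retrieved QA matches a normalised target answer, otherwise it is kept as a negative
+ (up to ``k``). Examples with no positive are skipped and flagged False in ``with_positive_flag``."""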
+ positive_idx_of_qas: List[List] = []
+ negative_idx_of_qas: List[List] = []
+ with_positive_flag = []
+ for indexes, target in zip(top_indices, target_answers):
+ normalized_target = [normalize_answer(t) for t in target]
+ positive_indexes = []
+ negative_indexes = []
+ for index in indexes:
+ retrieved_qa = qas_to_retrieve_from[index]
+ retrieved_answer = normalize_answer(retrieved_qa["answer"][0])
+ if retrieved_answer in normalized_target:
+ # if len(negative_indexes) < k:
+ positive_indexes.append(index)
+ else:
+ if len(negative_indexes) < k:
+ negative_indexes.append(index)
+ if len(positive_indexes) > 0 and len(negative_indexes) >= k:
+ break
+ if len(positive_indexes) > 0:
+ positive_idx_of_qas.append(positive_indexes)
+ negative_idx_of_qas.append(negative_indexes)
+ with_positive_flag.append(True)
+ else:
+ with_positive_flag.append(False)
+ return positive_idx_of_qas, negative_idx_of_qas, with_positive_flag
+
+
+class QAs:
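+ """Read-only view over QA items stored as one pickle file per index under ``paq_data_root``.
+ Random access through ``__getitem__`` is LRU-cached; iteration uses a separate uncached reader."""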
+ def __init__(self, max_qa_idx, paq_data_root, is_iter=False):
+ self.paq_data_root = paq_data_root
+ self.max_qa_idx = max_qa_idx
+ self.__iter_idx = 0
+ self.__is_iter = is_iter
+
+ @functools.lru_cache(maxsize=int(5e6))
+ def __getitem__(self, item):
+ p = os.path.join(self.paq_data_root, f"{item}")
+ if item >= self.max_qa_idx:
+ raise IndexError
+ return pickle.load(open(p, 'rb'))
+
+ def __len__(self):
+ return self.max_qa_idx
+
+ def __not_cached_getitem(self, item):
+ if item >= self.max_qa_idx:
+ raise IndexError
+ else:
+ return pickle.load(open(os.path.join(self.paq_data_root, f"{item}"), 'rb'))
+
+ def __iter__(self):
+ if self.__is_iter:
+ return self
+ else:
+ return QAs(paq_data_root=self.paq_data_root, max_qa_idx=self.max_qa_idx, is_iter=True)
+
+ def __next__(self):
+ try:
+ item = self.__not_cached_getitem(self.__iter_idx)
+ except IndexError:
+ raise StopIteration
+ self.__iter_idx += 1
+ return item
+
+
+class CATArgs:
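+ """Command-line argument builder for CAT/EMAT experiments: basic and model arguments are always
+ added, ``exp_type`` selects the task-specific argument group (build_kvm, pretrain, qa_cat, nli_cat,
+ nli_pretrain, NLU_cat or dialog_cat), and ``parse_args`` checks cross-argument constraints."""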
+ def __init__(self, exp_type=None):
+ self.exp_type = exp_type
+ self.parser: argparse.ArgumentParser = argparse.ArgumentParser(description="CAT Arguments")
+ self.parser.add_argument("--project_name", type=str, required=True, help="Project name.")
+ self.parser.add_argument("--exp_name", type=str, required=True, help="Experiment name.")
+
+ self.add_basic_arguments()
+ self.add_cat_arguments()
+
+ if self.exp_type == "build_kvm":
+ self.add_build_kvm_arguments()
+ elif self.exp_type == "pretrain":
+ self.add_pretrain_arguments()
+ elif self.exp_type == "qa_cat":
+ self.add_qa_arguments()
+ elif self.exp_type == "nli_cat":
+ self.add_nli_arguments()
+ elif self.exp_type == "nli_pretrain":
+ self.add_nli_pretrain_arguments()
+ elif self.exp_type == "NLU_cat":
+ self.add_NLU_arguments()
+ elif self.exp_type == "dialog_cat":
+ self.add_dialog_arguments()
+ else:
+ raise ValueError(f"Experiment type {self.exp_type} is not defined.")
+
+ def add_cat_arguments(self):
+ group = self.parser.add_argument_group("Arguments for CAT model.")
+ group.add_argument("--prefix_length", type=int, default=None, help="Length of the prefix.")
+ group.add_argument("--use_two_prefix", action="store_true",
+ help="Use two independent prefixes to represent key and value.")
+ group.add_argument("--d_key", type=int, default=None, help="The dimension of key embeddings.")
+ group.add_argument("--num_values", type=int, default=1, help="Number of values returned from KV memory.")
+ group.add_argument("--value_layer", type=int, default=None, help="The layer that imports Value embedding.")
+ group.add_argument("--key_layer", type=int, default=None, help="The layer that computes the Key embedding.")
+ group.add_argument("--key_encoder_type", type=str, default="linear", choices=["linear", "conv", "prefix"],
+ help="The type of the key encoder module.")
+ group.add_argument("--separate_task", action="store_true", help="Separate the input of Key-AE and Value-AE.")
+
+ group.add_argument("--cat_layer", type=int, default=None, help="The layer that cats key-embedding.")
+
+ group.add_argument("--not_share_encoder", action="store_true", help="Do not share Key-Value encoder with CAT.")
+
+ group.add_argument("--adapter", help="can be assigned to `linear", required=False, type=str)
+ group.add_argument("--adapter_out_dim", required=False, type=int)
+ # group.add_argument("--adapter_ckpt_path", required=False, type=str)
+
+ def add_nli_arguments(self):
+ group = self.parser.add_argument_group("Arguments for training CAT-NLI model.")
+
+ # training args
+ # Key-Value Memory args
+ group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.",
+ choices=["concat", "sum", "avg"])
+ group.add_argument("--kvm_fp16", action="store_true", help="FP16 Key-Value Memory")
+ group.add_argument("--retrieve_strategy", type=str, default="bm25", required=False)
+ group.add_argument("--do_pretrain", action="store_true")
+ group.add_argument("--add_ae_weight", type=float, default=0.0)
+ group.add_argument("--use_triple_loss", action="store_true")
+
+ group.add_argument("--filter_type", type=str, default=None, required=False, help="filter BM25-results.")
+ group.add_argument("--select_type", type=str, default=None, required=False, help="select BM25-results.")
+ group.add_argument("--local_size", type=int, default=512, required=False)
+
+ group.add_argument("--key_matching_weight", type=float, default=0.0)
+ group.add_argument("--add_vae", action="store_true", help="default only use kae if add_ae_weight > 0.")
+
+ # CAT-NLI architecture settings
+ group.add_argument("--dataset_name", type=str, choices=["mnli", "snli"], required=False, default="mnli")
+ group.add_argument("--decoder_only_attend_on_prefix", action="store_true",
+ help="Set the decoder only attend on the prefix part of encoder's output.")
+ group.add_argument("--value_fusion_method", type=str, required=True, help="Assign how to use Value.")
+ group.add_argument("--values_with_order", action="store_true",
+ help="when num_values > 1, if we put values by the similarity order.")
+ group.add_argument("--group_cases_by_label", action="store_true",
+ help="select_type must be ``select_different_labels`` ")
+ group.add_argument("--order_strategy", type=str, default="order_by_label", required=False, )
+ group.add_argument("--order_by_scores", action="store_true")
+
+ # Continue pretraining task settings
+ group.add_argument("--value_repr", type=str, default="label")
+ group.add_argument("--key_repr", type=str, default="hyp_prem")
+ # "mnli hypothesis: xxx premise: "
+ group.add_argument("--do_test", action="store_true")
+
+ def add_NLU_arguments(self):
+ group = self.parser.add_argument_group("Arguments for General-CAT-NLU model.") # for GLUE and ...
+ group.add_argument("--dataset_name", type=str, required=False, default="snli",
+ choices=["mnli", "snli", "commonsense_qa"], )
+ group.add_argument("--retrieve_strategy", type=str, default="bm25", required=False)
+ group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.",
+ choices=["concat", "sum", "avg"])
+ group.add_argument("--do_pretrain", action="store_true")
+ group.add_argument("--kvm_fp16", action="store_true", help="FP16 Key-Value Memory")
+ group.add_argument("--add_ae_weight", type=float, default=0.0)
+ group.add_argument("--use_triple_loss", action="store_true")
+ group.add_argument("--filter_type", type=str, default=None, required=False, help="filter BM25-results.")
+ group.add_argument("--select_type", type=str, default=None, required=False, help="select BM25-results.")
+ group.add_argument("--local_size", type=int, default=64, required=False)
+ group.add_argument("--key_matching_weight", type=float, default=0.0)
+ group.add_argument("--add_vae", action="store_true", help="default only use kae if add_ae_weight > 0.")
+ group.add_argument("--decoder_only_attend_on_prefix", action="store_true",
+ help="Set the decoder only attend on the prefix part of encoder's output.")
+ group.add_argument("--value_fusion_method", type=str, required=True, help="Assign how to use Value.")
+ group.add_argument("--values_with_order", action="store_true",
+ help="when num_values > 1, if we put values by the similarity order.")
+ group.add_argument("--group_cases_by_label", action="store_true",
+ help="select_type must be ``select_different_labels`` ")
+ group.add_argument("--order_strategy", type=str, default="order_by_label", required=False, )
+ group.add_argument("--order_by_scores", action="store_true")
+ group.add_argument("--do_test", action="store_true")
+
+ def add_qa_arguments(self):
+ group = self.parser.add_argument_group("Arguments for training CAT-QA model.")
+ # updated in 2022-06-04
+ group.add_argument("--build_mem_batch_size", type=int, default=2048)
+ group.add_argument("--batch_local_positive_num", type=int, default=5)
+ group.add_argument("--truncate_exists_local_qa", type=int, required=False, default=None)
+ group.add_argument("--use_fp16_rank", action="store_true")
+ group.add_argument("--use_fp16_kvm", action="store_true", help="not implement")
+ group.add_argument("--PAQ_size", type=int, required=False, help="truncate PAQ to target size.")
+ group.add_argument("--do_test", action="store_true")
+ group.add_argument("--query_batch_size", type=int, required=True)
+ group.add_argument("--only_key_matching_n_epoch", type=int, required=False, default=-1)
+ group.add_argument("--gen_target_is_key_match_target", action="store_true")
+
+ # QA args
+ group.add_argument("--qa_data_name", type=str, help="choose data files from pre-defined ``DATA_PATH``")
+
+ # training args
+ group.add_argument("--search_positive_in_top_k", type=int, required=False, default=2048,
+ help="Search positives to train the key-mathing task.")
+ group.add_argument("--hard_negative_num", type=int, required=False, default=12)
+ group.add_argument("--least_negative_num_per_batch", type=int, required=False, default=64)
+ group.add_argument("--select_positive_strategy", type=str, required=True,
+ help="The strategy of selecting one positive example for HardEM training.")
+ group.add_argument("--faiss_efsearch", type=int, required=False, default=128, help="hnsw ef_search parameter")
+ group.add_argument("--gen_weight", type=float, required=False, default=1.0,
+ help="Answer generation loss weight.")
+ group.add_argument("--match_weight", type=float, required=False, default=1.0,
+ help="Key matching loss weight.")
+ group.add_argument("--repaq_supervision_epoch", type=int, required=False, default=-1,
+ help="Use RePAQ's retrieval results as Key-matching supervision."
+ "Where we do not change the local target and, because we only prepare"
+ "one ``cur_positive_qa``/``local_positive``, the local negative is also fixed")
+ group.add_argument("--only_rank_exists_local_qa", action="store_true",
+ help="Do not collect Local-QAs from entire PAQ-L1, only rank the exists Local-QAs"
+ "that retrieved by RePAQ (x large)")
+ group.add_argument("--negatives_num_each_example", type=int, required=False, default=50,
+ help="sample negatives from local qas")
+
+ # Key-Value Memory args
+ group.add_argument("--kvm_dir", type=str, default=None, required=False,
+ help="The directory of Key-Value Memory")
+ group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.",
+ choices=["concat", "sum", "avg"])
+ group.add_argument("--qas_to_retrieve_from", type=str, help="QAs corresponding to Key-Value Memory")
+ group.add_argument("--kvm_fp16", action="store_true", help="Load FP16 Key-Value Memory")
+ group.add_argument("--kvm_seg_n", type=int, required=False, default=1, help="when key-memory is too large, "
+ "segment it to kvm_seg_n pieces")
+
+ # CAT-QA architecture settings
+ group.add_argument("--update_kv_embeds", action="store_true", help="Re-embed Key and Value while training.")
+ group.add_argument("--local_size", type=int, required=False, help="Number of local QAs to retrieve.")
+ group.add_argument("--update_local_qas", action="store_true", help="Update local QAs every epoch.")
+ group.add_argument("--update_local_target_each_batch", action="store_true",
+ help="Update positive and negative each batch. Otherwise, each epoch.")
+ group.add_argument("--use_not_exactly_true", action="store_true",
+ help="Input the top-1 though it is not exactly true when training the generation. "
+ "The not exactly true example will not supervise the Key-Mathing.")
+ group.add_argument("--decoder_only_attend_on_prefix", action="store_true",
+ help="Set the decoder only attend on the prefix part of encoder's output.")
+ group.add_argument("--value_fusion_method", type=str, required=True, help="Assign how to use Value.")
+ # group.add_argument("--update_kv_batch_size", default=1024, help="Batch size of KV Re-Embed.")
+ group.add_argument("--try_to_put_one_positive_in_values", action="store_true",
+ help="when num_values > 1, if we force to put at least one positive QA to Values")
+ group.add_argument("--values_with_order", action="store_true",
+ help="when num_values > 1, if we put values by the similarity order.")
+
+ group.add_argument("--pos_from_top", type=int, required=False, default=50,
+ help="for each batch, select lexical-positive QAs from ranked-local-QAs from top-X")
+
+ group.add_argument("--rerank_retrieved_values", action="store_true")
+
+ # Continue pretraining task settings
+ group.add_argument("--add_ae_task", action="store_true", help="Add the pretraining(Auto-Encoding) task.")
+ group.add_argument("--ae_weight", type=float, required=False, default=0.1,
+ help="When set --add_ae_task, the auto-encoding loss weight.")
+ group.add_argument("--ae_batch_size", type=int, default=None, help="The batch size of AE task.")
+ group.add_argument("--value_ae_target", type=str, default="ans", choices=["ans", "question_ans"])
+ group.add_argument("--key_ae_target", type=str, default="question_ans", choices=["question", "question_ans"])
+
+ # adapter training
+ group.add_argument("--only_train_adapter", action="store_true")
+ group.add_argument("--use_adapter_to_select_positive_after_k_epoch", type=int,
+ default=float("inf"), required=False)
+
+ def add_dialog_arguments(self):
+ group = self.parser.add_argument_group("Arguments for training CAT-Dialog model.")
+ # updated in 2022-06-18
+ group.add_argument("--build_mem_batch_size", type=int, default=2048)
+ group.add_argument("--batch_local_positive_num", type=int, default=5)
+ group.add_argument("--do_test", action="store_true")
+ group.add_argument("--query_batch_size", type=int, required=True)
+ group.add_argument("--add_persona", action="store_true")
+ group.add_argument("--add_topic", action="store_true")
+ group.add_argument("--update_kv_embeds", action="store_true", help="Re-embed Key and Value while training.")
+ group.add_argument("--eval_every_n_steps", required=False, default=None, type=int)
+ group.add_argument("--shortest_answer_len", required=False, default=None, type=int)
+
+ # training args
+ group.add_argument("--select_positive_strategy", type=str, required=False, default="softmax_sample",
+ help="The strategy of selecting one positive example for HardEM training.")
+ group.add_argument("--faiss_efsearch", type=int, required=False, default=128, help="hnsw ef_search parameter")
+ group.add_argument("--gen_weight", type=float, required=False, default=1.0,
+ help="Answer generation loss weight.")
+ group.add_argument("--match_weight", type=float, required=False, default=1.0,
+ help="Key matching loss weight.")
+ group.add_argument("--negatives_num_each_example", type=int, required=False, default=50,
+ help="sample negatives from local qas")
+ group.add_argument("--qa_data_name", type=str, help="choose data files from pre-defined ``DATA_PATH``")
+
+ # Key-Value Memory args
+ group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.",
+ choices=["concat", "sum", "avg"])
+ group.add_argument("--qas_to_retrieve_from", type=str, help="QAs corresponding to Key-Value Memory")
+ group.add_argument("--kvm_fp16", action="store_true", help="Load FP16 Key-Value Memory")
+ group.add_argument("--kvm_seg_n", type=int, required=False, default=1, help="when key-memory is too large, "
+ "segment it to kvm_seg_n pieces")
+
+ group.add_argument("--local_size", type=int, required=False, help="Number of local QAs to retrieve.")
+ group.add_argument("--decoder_only_attend_on_prefix", action="store_true",
+ help="Set the decoder only attend on the prefix part of encoder's output.")
+ group.add_argument("--value_fusion_method", type=str, required=True, help="Assign how to use Value.")
+ group.add_argument("--values_with_order", action="store_true",
+ help="when num_values > 1, if we put values by the similarity order.")
+ group.add_argument("--pos_from_top", type=int, required=False, default=128,
+ help="for each batch, select lexical-positive QAs from ranked-local-QAs from top-X")
+
+ def add_nli_pretrain_arguments(self):
+ group = self.parser.add_argument_group("Arguments for pretraining NLI-CAT.")
+ group.add_argument("--value_ae_weight", type=float, default=1.0,
+ help="Weight for the Auto-Encoding loss of Value.")
+ group.add_argument("--key_ae_weight", type=float, default=1.0,
+ help="Weight for the Auto-Encoding loss of Key.")
+ group.add_argument("--train_value", action="store_true")
+ group.add_argument("--train_key", action="store_true")
+ group.add_argument("--separate_decode", action="store_true")
+
+ group.add_argument("--value_fusion_method", type=str, required=True,
+ help="Assign how to use Value. Preassigned in pretrain.")
+
+ def add_pretrain_arguments(self):
+ group = self.parser.add_argument_group("Arguments for pretraining CAT.")
+ # add in 2022-06-19
+ group.add_argument("--value_input_is_qa", action="store_true")
+
+ group.add_argument("--pretrain_data_name", type=str, help="choose data files from pre-defined ``DATA_PATH``")
+ group.add_argument("--value_ae_weight", type=float, default=1.0,
+ help="Weight for the Auto-Encoding loss of Value.")
+ group.add_argument("--key_ae_weight", type=float, default=1.0,
+ help="Weight for the Auto-Encoding loss of Key.")
+ group.add_argument("--value_ae_target", type=str, default="ans", choices=["ans", "question_ans"])
+ group.add_argument("--key_ae_target", type=str, default="question_ans", choices=["question", "question_ans"])
+ group.add_argument("--train_value", action="store_true")
+ group.add_argument("--train_key", action="store_true")
+ group.add_argument("--pretrain_multi_values", action="store_true")
+ group.add_argument("--value_fusion_method", type=str, required=False, help="Assign how to use Value.")
+ group.add_argument("--decoder_only_attend_on_prefix", action="store_true",
+ help="Set the decoder only attend on the prefix part of encoder's output.")
+ group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.",
+ choices=["concat", "sum", "avg"])
+ group.add_argument("--gen_weight", type=float, default=1.0)
+ group.add_argument("--value_with_self_prop", type=float, default=1.0)
+ group.add_argument("--values_with_order", action="store_true",
+ help="when num_values > 1, if we put values by the similarity order.")
+
+ def add_build_kvm_arguments(self):
+ group = self.parser.add_argument_group("Arguments for building Key-Value Memory.")
+ group.add_argument("--embed_key", action="store_true")
+ group.add_argument("--embed_value", action="store_true")
+ group.add_argument("--embed_as_fp16", action="store_true")
+ group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.",
+ choices=["concat", "sum", "avg"])
+ group.add_argument("--embed_data_name", type=str, help="choose data files from pre-defined ``DATA_PATH``")
+
+ def parse_args(self, print_args=True):
+ args = self.parser.parse_args()
+ if print_args:
+ args_str = json.dumps(vars(args), indent=4, ensure_ascii=False)
+ logging.info(f"Show the arguments: \n{args_str}")
+
+ assert args.value_layer >= args.key_layer
+ if args.num_values > 1:
+ assert args.value_fusion_method in ["cat_v", "cat_kv", "cat_k+v", "cat_avgk+v", "cat_k_delay+v",
+ "cat_k+v_g(kq)", "cat_k_delay+v_g(kv)",
+ "async_cat_k+v", "async_cat_k_delay+v"]
+ if "delay" not in args.value_fusion_method:
+ if args.cat_layer is None and "k" in args.value_fusion_method:
+ assert args.key_layer == args.value_layer
+ elif "async_cat_k+v" == args.value_fusion_method:
+ assert args.cat_layer == args.value_layer
+
+ if self.exp_type == "pretrain":
+ assert args.train_key or args.train_value, "at least train one of them"
+ assert not os.path.exists(os.path.join(args.output_dir, "best_ckpt")), \
+ "The experiment has done before. Clear the dir before running"
+ else:
+ assert self.exp_type in ["build_kvm", "qa_cat", "nli_cat", "nli_pretrain", "NLU_cat", "dialog_cat"]
+
+ return args
+
+ def add_basic_arguments(self):
+ group = self.parser.add_argument_group("Arguments for basic settings.")
+ # generation args
+ group.add_argument("--num_beams", type=int, default=None,
+ help="Number of beams to use for evaluation. This argument will be passed to "
+ "``model.generate``, which is used during ``evaluate`` and ``predict``.")
+ # data args
+ group.add_argument("--source_prefix", type=str, default=None,
+ help="A prefix to add before every source text (useful for T5 models).")
+ group.add_argument("--max_source_length", type=int, default=1024,
+ help="The maximum total input sequence length after tokenization."
+ "Sequences longer than this will be truncated, sequences shorter will be padded.")
+ group.add_argument("--max_target_length", type=int, default=128,
+ help="The maximum total sequence length for target text after tokenization. "
+ "Sequences longer than this will be truncated, "
+ "sequences shorter will be padded. during ``evaluate`` and ``predict``.")
+ group.add_argument("--val_max_target_length", type=int, default=None,
+ help="The maximum total sequence length for validation target text after tokenization."
+ "Sequences longer than this will be truncated, sequences shorter will be padded."
+ "Will default to `max_target_length`. This argument is also used to override the "
+ "``max_length`` param of ``model.generate``, "
+ "which is used during ``evaluate`` and ``predict``.")
+ group.add_argument("--pad_to_max_length", type=bool, default=False,
+ help="Whether to pad all samples to model maximum sentence length. If False, will pad "
+ "the samples dynamically when batching to the maximum length in the batch. More"
+ "efficient on GPU but very bad for TPU.", )
+
+ # training args
+ group.add_argument("--ignore_pad_token_for_loss", type=bool, default=True,
+ help="Whether to ignore the tokens corresponding to padded "
+ "labels in the loss computation")
+ group.add_argument("--preprocessing_num_workers", type=int, default=None,
+ help="The number of processes to use for the preprocessing.")
+ group.add_argument("--model_name_or_path", type=str, required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.")
+ group.add_argument("--config_name", type=str, default=None,
+ help="Pretrained config name or path if not the same as model_name")
+ group.add_argument("--tokenizer_name", type=str, default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name")
+ group.add_argument("--use_slow_tokenizer", action="store_true",
+ help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).")
+ group.add_argument("--per_device_train_batch_size", type=int, required=True,
+ help="Batch size (per device) for the training dataloader.")
+ group.add_argument("--per_device_eval_batch_size", type=int, default=8,
+ help="Batch size (per device) for the evaluation dataloader.")
+
+ group.add_argument("--learning_rate", type=float, default=5e-5,
+ help="Initial learning rate (after the potential warmup period) to use.")
+ group.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+ group.add_argument("--num_train_epochs", type=int, default=20,
+ help="Total number of training epochs to perform.")
+ group.add_argument("--early_stop_patience", type=int, default=1000000,
+ help="Early stop if the performance does not improve for this number of epochs .")
+
+ group.add_argument("--max_train_steps", type=int, default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
+ group.add_argument("--gradient_accumulation_steps", type=int, default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.")
+ group.add_argument("--lr_scheduler_type", type=str, default="linear", help="The scheduler type to use.",
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant",
+ "constant_with_warmup"], )
+ group.add_argument("--num_warmup_steps", type=int, default=0,
+ help="Number of steps for the warmup in the lr scheduler.")
+ group.add_argument("--output_dir", type=str, required=True, help="Where to store the final model.")
+ group.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ group.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.",
+ choices=MODEL_TYPES)
+
+ group.add_argument("--do_train", action="store_true", help="Whether to train the model on the train set.")
+ group.add_argument("--do_eval", action="store_true", help="Whether to evaluate on the dev set.")
+ group.add_argument("--eval_freq", type=int, default=1,
+ help="Frequency of evaluation on the dev set (if do_eval is True).")
+ group.add_argument("--resume_training", action="store_true", help="Resume training from the latest checkpoint.")
+ group.add_argument("--freeze_t5_params", action="store_true", help="Freeze the original T5 parameters.")
+ group.add_argument("--per_epoch_eval_times", type=int, default=1,
+ help="do eval many times per epoch", required=False)