diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..75691cb --- /dev/null +++ b/.gitignore @@ -0,0 +1,156 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# (jimmycode) +.python_history +.idea +pretrained_models +data/* +logs/ +outputs/ +runs/ +results/ +checkpoints/ +cached_models/ +paq_models/ +.DS_Store +etc/ +cached_outputs/ +evaluate/ +lightning_logs/ +wandb/ +artefacts +Icon* +.netrc +/kv_scripts/ +nlu_dataset.py +nlu_trainer.py +models + diff --git a/README.md b/README.md index da90d98..f46cab8 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,132 @@ -# EMAT -Efficient Memory-Augmented Transformers +# EMAT: An Efficient Memory-Augmented Transformer for Knowledge-Intensive NLP Tasks + +## Installation and Setup + +```shell +# create a conda environment +conda create -n emat -y python=3.8 && conda activate emat + +# install pytorch +pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html # GPU +pip install torch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 # CPU + +# install transformers +pip install transformers==4.14.1 + +# install faiss +pip install faiss-gpu==1.7.1.post3 # GPU +pip install faiss-cpu==1.7.1.post3 # CPU + +# install dependencies +pip install -r requirements.txt + +# install this package for development +pip install -e . 
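+
+# optional sanity check (a minimal sketch, assuming the GPU builds above were installed;
+# skip or adapt if you installed the CPU-only packages)
+python -c "import torch, faiss, transformers; print(torch.__version__, torch.cuda.is_available())"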
+```
+
+## Download datasets
+
+Download the NaturalQuestions, WebQuestions, TriviaQA, WoW-KILT and ELI5-KILT data:
+
+link: https://pan.baidu.com/s/1MwPzVLqZqslqCpWAtPVZ-Q
+
+extraction code: tynj
+
+Download PAQ data from: https://github.com/facebookresearch/PAQ
+
+
+## Run Interactive Script
+
+Before running the following scripts, the key-value memory embeddings, the retrieval index and the PAQ data must be prepared.
+See [Start](#start) to build your key-value memory and index.
+
+
+NQ: use torch embeddings as the retrieval index:
+```bash
+python demo.py \
+    --model_name_or_path="./EMAT_ckpt/FKSV-NQ" \
+    --qas_to_retrieve_from="./data/PAQ_L1" \
+    --test_task="nq" \
+    --task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
+    --task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
+    --embedding_index_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/embedding_index.pt" \
+    --key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key_memory.pt" \
+    --value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/value_memory.pt"
+```
+
+NQ: use faiss as the retrieval index:
+```bash
+python demo.py \
+    --model_name_or_path="./EMAT_ckpt/FKSV-NQ" \
+    --qas_to_retrieve_from="./data/PAQ_L1" \
+    --test_task="nq" \
+    --task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
+    --task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
+    --use_faiss \
+    --faiss_index_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key.sq8hnsw.80n80efc.faiss" \
+    --key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/key_memory.pt" \
+    --value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_ckpt/value_memory.pt"
+```
+
+Use the SKSV model with parallel faiss search:
+```bash
+python demo.py \
+    --model_name_or_path="./EMAT_ckpt/SKSV-NQ" \
+    --qas_to_retrieve_from="./data/PAQ_L1" \
+    --test_task="nq" \
+    --task_train_data="./annotated_datasets/NQ-open.train-train.jsonl" \
+    --task_dev_data="./annotated_datasets/NQ-open.train-dev.jsonl" \
+    --use_faiss \
+    --faiss_index_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/key.sq8hnsw.80n80efc.faiss" \
+    --key_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/key_memory.pt" \
+    --value_memory_path="./embedding_and_faiss/PAQ_L1_from_nq_SKSV_ckpt/value_memory.pt"
+```
+
+Run Wizard-of-Wikipedia dialogue:
+```bash
+python demo.py \
+    --model_name_or_path="./EMAT_ckpt/FKSV-WQ/" \
+    --qas_to_retrieve_from="./tmp/PAQ_L1.pkl" \
+    --test_task="wow_kilt" \
+    --embedding_index_path="./embedding_and_faiss/debug_from_wow_ckpt/embedding_index.pt" \
+    --key_memory_path="./embedding_and_faiss/PAQ_L1_from_wow_ckpt/key_memory.pt" \
+    --value_memory_path="./embedding_and_faiss/PAQ_L1_from_wow_ckpt/value_memory.pt" \
+    --inference_data_path="./annotated_datasets/wizard_of_wikipedia/wow-test_without_answers-kilt.jsonl.txt"
+```
+
+## Start
+
+
+### 1. Pre-training
+
+Pre-train EMAT-FKSV: `bash pretrain_scripts/pretrain_emat.sh`
+
+Pre-train EMAT-SKSV: `bash pretrain_scripts/pretrain_sksv_emat.sh`
+
+### 2. Fine-tuning
+
+Fine-tune on NQ: `bash scripts/nq_train_with_paql1.sh`
+
+Fine-tune on TQ: `bash scripts/tq_train_with_paql1.sh`
+
+Fine-tune on WQ: `bash scripts/wq_train_with_paql1.sh`
+
+Fine-tune on WoW: `bash kilt_scripts/wow_train.sh`
+
+Fine-tune on ELI5: `bash kilt_scripts/eli5_train.sh`
+
+
+### 3. Evaluation
+
+Evaluate NQ/TQ/WQ: `bash scripts/nq_eval.sh`; switch ``DATA_NAME`` to evaluate a different dataset.
+
+Evaluate WoW/ELI5: `bash kilt_scripts/eval_kilt.sh`. You can upload the output prediction file to http://kiltbenchmark.com/ to get evaluation results.
+
+### 4. 
Embed PAQ using fine-tuned NQ model and build FAISS index: +```bash +bash embed_scripts/nq_embed_paq_and_build_faiss.sh +``` + +### 5. Inference Speed +Test inference speed on ```inference_with_faiss.py``` + diff --git a/build_kvm.py b/build_kvm.py new file mode 100644 index 0000000..df382a9 --- /dev/null +++ b/build_kvm.py @@ -0,0 +1,297 @@ +import argparse +import copy +import json +import logging +import math +import os +import random + +import datasets +import torch +import transformers +from accelerate import Accelerator +from datasets import load_dataset, load_metric, DatasetDict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + set_seed, + T5Tokenizer, +) +from transformers.utils.versions import require_version + +from emat.t5 import T5WithKeyValueMemory +from transformers import T5Config +from emat.utils import load_jsonl, write_jsonl, verbalise_qa +from utils.utils import reduce_query_or_key_embeds, save_model, CATArgs, update_CAT_config_from_args, load_model, \ + get_key_value_encoder_inputs + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + +logger = logging.getLogger(__name__) +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") + +DATA_PATHS = { + "PAQ-L1": "./data/cbqa_data/pretrain_data/PAQ_L1/PAQ_L1.filtered.jsonl", + "data_for_debug": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-dev-3000.jsonl" +} + + +def load_paq_data(args) -> DatasetDict: + assert args.embed_data_name in DATA_PATHS.keys(), f"available dataset names: {DATA_PATHS.keys()}" + data_path = DATA_PATHS[args.embed_data_name] + return load_dataset("json", data_files=data_path) + + +@torch.no_grad() +def build_memory(model, tokenizer, output_dir=None, embed_key=False, embed_value=False, prefix="", + embed_as_fp16=False, key_reduce_method=None, data_path=None, data_to_embed=None, + max_source_length=None, padding=None, batch_size=1, allow_overlay_old_memory=False, + dump_memory=False, return_memory=False, separate_task=False, kvm_seg_n=-1, + disable_tqdm=False, reused_key_memory=None, collate_fn=None, normed_key_memory=True, + return_not_reduced_key=False, reused_not_reduced_key_memory=None, reused_value_memory=None, + num_workers=4, use_retrieval_adapter=False): + torch.cuda.empty_cache() + if data_to_embed is None: + data_to_embed = load_dataset("json", data_files=data_path)["train"] + + if collate_fn is None: + def collate_fn(examples): + model_inputs = get_key_value_encoder_inputs(examples, separate_task, tokenizer, max_source_length, + prefix=prefix, only_return_key_inputs=not embed_value) + return model_inputs + + qas_to_embed_dataloader = DataLoader(data_to_embed, batch_size=batch_size, num_workers=num_workers, + collate_fn=collate_fn) + + key_memory: list = [] + value_memory: list = [] + not_reduced_key_memory = [] if return_not_reduced_key else None + model.eval() + + key_cnt = 0 + for batch in tqdm(qas_to_embed_dataloader, disable=disable_tqdm): + # for start_idx in tqdm(range(0, len(data_to_embed), batch_size), total=len(data_to_embed) // batch_size): + # batch_qas = data_to_embed[start_idx: start_idx + batch_size] + # batch = get_key_value_encoder_inputs(batch_qas, separate_task, tokenizer, max_source_length, + # prefix=prefix, only_return_key_inputs=True) + with torch.no_grad(): + batch_keys = list(batch.keys()) + + # for k in batch_keys: + # v = batch.pop(k) + # 
batch[k] = v.to(model.device) + # del v + batch = {k: v.to(model.device) for k, v in batch.items()} + + embed_dict = model.wrapped_embed_kv( + separate_task=separate_task, + compute_key=embed_key, + compute_value=embed_value, + # key_input_ids=batch["key_input_ids"].to(model.device), + # key_attention_mask=batch["key_attention_mask"].to(model.device), + # value_input_ids=batch.get("key_input_ids", None).to(model.device), + # value_attention_mask=batch.get("key_attention_mask", None).to(model.device), + **batch + ) + + for k in batch_keys: + del batch[k] + + key_embeds = embed_dict.get("normed_key_embeds") if normed_key_memory else embed_dict.get("key_embeds") + value_embeds = embed_dict.get("value_embeds") + if embed_key: + key_embeds = reduce_query_or_key_embeds(key_embeds, key_reduce_method) + if use_retrieval_adapter: + key_embeds = model.adapter(key_embeds) + cur_key_num = key_embeds.shape[0] + + if embed_key: + if embed_as_fp16: + key_embeds = key_embeds.half() + if reused_key_memory is not None: + key_embeds = key_embeds.cpu() + reused_key_memory[key_cnt: key_cnt + cur_key_num] = copy.deepcopy(key_embeds) + del key_embeds + else: + key_memory.append(key_embeds.cpu()) # [batch_size, hidden_size] + + if return_not_reduced_key: + not_normed_key_embeds = embed_dict["key_embeds"] + if embed_as_fp16: + not_normed_key_embeds = not_normed_key_embeds.half() + if reused_not_reduced_key_memory is not None: + not_normed_key_embeds = not_normed_key_embeds.cpu() + reused_not_reduced_key_memory[key_cnt: key_cnt + cur_key_num] = copy.deepcopy(not_normed_key_embeds) + del not_normed_key_embeds + else: + not_reduced_key_memory.append(not_normed_key_embeds.cpu()) + + if embed_value: + if embed_as_fp16: + value_embeds = value_embeds.half() + if reused_value_memory is not None: + value_embeds = value_embeds.cpu() + reused_value_memory[key_cnt: key_cnt + cur_key_num] = copy.deepcopy(value_embeds) + del value_embeds + else: + value_memory.append(value_embeds.cpu()) # [batch_size, value_nums, hidden_size] + + key_cnt += cur_key_num + + if reused_key_memory is None: + if embed_key: + assert sum(i.shape[0] for i in key_memory) == len(data_to_embed) + if return_not_reduced_key: + assert sum(i.shape[0] for i in not_reduced_key_memory) == len(data_to_embed) + if embed_value: + assert sum(i.shape[0] for i in value_memory) == len(data_to_embed) + + if dump_memory: + assert reused_key_memory is None, "Not Implement when reused_key_memory is set." 
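+        # Write the collected memory to disk as a fixed number of chunk files ({cid}.key.pt / {cid}.value.pt),
+        # concatenating the per-batch tensors inside each chunk so no single file becomes too large to load.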
+ chunk_num = 128 + chunk_batch_size = math.ceil(len(key_memory) / chunk_num) + if embed_key: + logger.info("dump key") + key_dir = os.path.join(output_dir, "key") + os.makedirs(key_dir, exist_ok=allow_overlay_old_memory) + save_num = 0 + for cid, start_idx in tqdm(enumerate(range(0, len(key_memory), chunk_batch_size)), leave=True): + chunk_key_memory = torch.cat(key_memory[start_idx: start_idx + chunk_batch_size]) + torch.save(chunk_key_memory, os.path.join(key_dir, f"{cid}.key.pt")) + save_num = save_num + chunk_key_memory.shape[0] + assert save_num == len(data_to_embed), \ + f"saved key num is {save_num}, but example num is {len(data_to_embed)}" + if embed_value: + logger.info("dump value") + value_dir = os.path.join(output_dir, "value") + os.makedirs(value_dir, exist_ok=allow_overlay_old_memory) + save_num = 0 + for cid, start_idx in tqdm(enumerate(range(0, len(value_memory), chunk_batch_size)), leave=True): + chunk_value_memory = torch.cat(value_memory[start_idx: start_idx + chunk_batch_size]) + torch.save(chunk_value_memory, os.path.join(value_dir, f"{cid}.value.pt")) + save_num = save_num + chunk_value_memory.shape[0] + assert save_num == len(data_to_embed), \ + f"saved value num is {save_num}, but example num is {len(data_to_embed)}" + + if return_memory: + if kvm_seg_n > 1: + all_chunk_key_memory = [] + if embed_key: + if reused_key_memory is not None: + logger.info(f"Split reused_key_memory into {kvm_seg_n} chunks.") + chunk_batch_size = math.ceil(len(reused_key_memory) / kvm_seg_n) + for start_idx in range(0, len(reused_key_memory), chunk_batch_size): + end_idx = min(len(reused_key_memory), start_idx + chunk_batch_size) + all_chunk_key_memory.append(reused_key_memory[start_idx:end_idx]) + else: + logger.info(f"Combining the keys into {kvm_seg_n} chunks.") + chunk_batch_size = math.ceil(len(key_memory) / kvm_seg_n) + for cid, start_idx in tqdm(enumerate(range(0, len(key_memory), chunk_batch_size)), leave=True): + chunk_key_memory = torch.cat(key_memory[start_idx: start_idx + chunk_batch_size]) + all_chunk_key_memory.append(chunk_key_memory) + assert len(all_chunk_key_memory) == kvm_seg_n + + # if return_not_reduced_key: + # not_reduced_key_memory = torch.emat(not_reduced_key_memory) + + if embed_value: + value_memory = torch.cat(value_memory) + + return all_chunk_key_memory, value_memory + + else: + if embed_key: + if reused_key_memory is not None: + key_memory = reused_key_memory + else: + logger.info(f"Combining the result.") + key_memory = torch.cat(key_memory) + + if return_not_reduced_key: + if reused_not_reduced_key_memory is not None: + not_reduced_key_memory = reused_not_reduced_key_memory + else: + not_reduced_key_memory = torch.cat(not_reduced_key_memory) + + if embed_value: + if reused_value_memory is not None: + value_memory = reused_value_memory + else: + value_memory = torch.cat(value_memory) + if return_not_reduced_key: + return key_memory, value_memory, not_reduced_key_memory + else: + return key_memory, value_memory + + +def main(): + # Parse the arguments + cat_args = CATArgs(exp_type="build_kvm") + args = cat_args.parse_args() + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + accelerator = Accelerator() + + # Make one log on every process with the configuration for debugging. + logger.info(accelerator.state) + + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. 
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + config, tokenizer, model = load_model(T5WithKeyValueMemory, args) + model.cuda() + prefix = args.source_prefix if args.source_prefix is not None else "" + + # Temporarily set max_target_length for training. + max_target_length = args.max_target_length + padding = "max_length" if args.pad_to_max_length else True + + # Load the datasets + data_to_embed = load_paq_data(args)["train"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(data_to_embed)), 3): + logger.info(f"Sample {index} of the training set: {data_to_embed[index]}.") + + batch_size = args.per_device_train_batch_size + logger.info("***** Building Key-Value Memory *****") + logger.info(f" Num examples = {len(data_to_embed)}") + logger.info(f" Instantaneous batch size per device = {batch_size}") + # Only show the progress bar once on each machine. + build_memory(model, tokenizer, output_dir=args.output_dir, embed_key=args.embed_key, embed_value=args.embed_value, + prefix=prefix, embed_as_fp16=args.embed_as_fp16, key_reduce_method=args.key_reduce_method, + data_path=None, data_to_embed=data_to_embed, max_source_length=args.max_source_length, padding=padding, + batch_size=batch_size, allow_overlay_old_memory=False, dump_memory=True, return_memory=False, + separate_task=args.separate_task) + + pretrain_args = json.load(open(os.path.join(args.model_name_or_path, "args.json"))) + dict_args = vars(args) + dict_args["loaded_model_args"] = pretrain_args + json.dump(pretrain_args, open(os.path.join(args.output_dir, "kvm_args.json"), 'w'), indent=4) + + +if __name__ == '__main__': + main() diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..f8c61a9 --- /dev/null +++ b/demo.py @@ -0,0 +1,189 @@ +import faiss +import asyncio +import argparse +import torch +from transformers import T5Tokenizer, T5Config +from emat.t5 import T5WithKeyValueMemory +from emat.utils import load_jsonl +import logging +from kilt_dataset import DialogDataset +import pickle +import random +from kilt_trainer import kilt_generate + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + + +def get_args(): + parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Inference with faiss") + parser.add_argument("--model_name_or_path", type=str, required=False, + default="./outputs/nq_checkpoints/KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;pos_from_top=128;/best_ckpt/") + parser.add_argument("--qas_to_retrieve_from", default="./tmp/PAQ_L1_pickl_file.pkl") + parser.add_argument("--test_task", default="nq", type=str, choices=["nq", "wq", "tq", "wow_kilt"]) + parser.add_argument("--task_train_data", default=None, required=False, type=str) + parser.add_argument("--task_dev_data", default=None, required=False, type=str) + parser.add_argument("--use_faiss", action="store_true", help="default -- use torch 
embedding") + parser.add_argument("--faiss_index_path", default=None, type=str, required=False) + parser.add_argument("--embedding_index_path", default=None, type=str, required=False) + parser.add_argument("--key_memory_path", required=True) + parser.add_argument("--value_memory_path", required=True) + parser.add_argument("--inference_type", type=str, default="serial", choices=["parallel", "serial"]) + parser.add_argument("--inference_data_path", type=str, default=None, required=False) + parser.add_argument("--inference_batch_size", type=int, default=512) + args = parser.parse_args() + + if args.use_faiss: + assert args.faiss_index_path is not None + else: + assert args.embedding_index_path is not None + + return args + + +def main(): + args = get_args() + + # load model + logging.info(f"loading model from {args.model_name_or_path}") + model, load_info = T5WithKeyValueMemory.from_pretrained(args.model_name_or_path, output_loading_info=True) + model.eval() + model = model.cuda() + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path) + logging.info(f"model load info: {load_info}") + # check + if getattr(model, "cat_layer", None) == model.encoder.key_layer: + assert args.inference_type != "parallel", "parallel can not used in cat_layer == key_layer" + + # load index and key-value memory + faiss_index, embedding_index = None, None + if args.use_faiss: + logging.info(f"loading index from {args.faiss_index_path}") + faiss_index = faiss.read_index(args.faiss_index_path) + logging.info("loaded faiss index.") + else: + logging.info(f"loading index from {args.embedding_index_path}") + embedding_index = torch.load(args.embedding_index_path) + logging.info("loaded embedding index.") + value_memory = torch.load(args.value_memory_path) + key_memory = torch.load(args.key_memory_path) + + # load QAs to retrieve + logging.info(f"loading PAQ from {args.qas_to_retrieve_from}") + if args.qas_to_retrieve_from.endswith("pkl"): + qas_to_retrieve = pickle.load(open(args.qas_to_retrieve_from, 'rb')) + else: # jsonl + qas_to_retrieve = load_jsonl(args.qas_to_retrieve_from) + logging.info("loaded PAQ") + if args.test_task in ["nq", "wq", "tq"]: + if args.task_train_data is not None: + qas_to_retrieve = qas_to_retrieve + load_jsonl(args.task_train_data) + if args.task_dev_data is not None: + qas_to_retrieve = qas_to_retrieve + load_jsonl(args.task_dev_data) + assert len(qas_to_retrieve) == value_memory.shape[0] == key_memory.shape[0] + logging.info(f"numer of QAs to retrieve: {len(qas_to_retrieve)}") + + if args.test_task in ["nq", 'wq', 'tq']: + gen_kwargs = {"num_beams": None, "max_length": 64} + else: + gen_kwargs = {"max_length": 1024, "num_beams": 8, "do_sample": True, "top_k": 64, "min_length": 8} + + print("input ``ctrl + c`` to exit the program.") + if args.test_task in ["nq", 'wq', 'tq']: + while True: + question = input("Question: ") + batch = [{"question": question.strip()}] + ans, retrieved_qa = inference_qa(model, tokenizer, key_memory, value_memory, embedding_index, + faiss_index, qas_to_retrieve, args.inference_type, batch, gen_kwargs) + print(f"Answer: {ans[0]}") + print(f"retrieved QAs: ") + for qa in retrieved_qa[0]: + print(qa) + + elif args.test_task == 'wow_kilt': + print("input '-1' to exit current dialogue") + dataset_kwargs = {"dataset_name": "wow_kilt", "max_source_length": 128} + inference_data = load_jsonl(args.inference_data_path) + while True: + cur_dialogue = random.sample(inference_data, 1)[0] + utterances = cur_dialogue["input"].split("\n")[:-1] + for idx, u in 
enumerate(utterances): + spk = "A" if idx % 2 == 0 else "B" + print(f"{spk}: {u}") + while True: + spk = "A" if len(utterances) % 2 == 0 else "B" + utterance = input(f"{spk}: ") + if utterance == "-1": + break + utterances.append(utterance) + cur_dialogue["input"] = "\n".join(utterances) + dataset = DialogDataset([cur_dialogue], tokenizer, qas_to_retrieve, **dataset_kwargs) + retrieved_qa, response = kilt_generate( + model, tokenizer, embedding_index, key_memory, value_memory, dataset, + qas_to_retrieve, args.inference_batch_size, gen_kwargs + ) + spk = "A" if len(utterances) % 2 == 0 else "B" + print(f"{spk}: {response[0]}") + utterances.append(response[0]) + print("") + + +@torch.no_grad() +def inference_qa(model, tokenizer, key_memory, value_memory, embedding_index, faiss_index, + qas_to_retrieve, inference_type, batch, gen_kwargs): + inputs = ["question: " + qa["question"] for qa in batch] + inputs = tokenizer(inputs, max_length=1024, padding=True, truncation=True, return_tensors="pt") + input_ids = inputs["input_ids"].to('cuda') + attention_mask = inputs["attention_mask"].to('cuda') + if embedding_index is not None: + encoder_outputs = model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + readout_top_k=model.encoder.num_values, + key_reduce_method="avg", + value_fusion_method=model.encoder.value_fusion_method, + embedding_index=embedding_index, + key_memory=key_memory, + value_memory=value_memory + ) + else: + if inference_type == "serial": + encoder_outputs = model.encoder.forward_with_faiss( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + readout_top_k=model.encoder.num_values, + key_reduce_method="avg", + value_fusion_method=model.encoder.value_fusion_method, + key_faiss_index=faiss_index, + value_memory=value_memory, + not_reduced_key_memory=key_memory + ) + else: + encoder_outputs = asyncio.run( + model.encoder.forward_with_async_faiss( + input_ids, attention_mask, True, model.encoder.num_values, "avg", + model.encoder.value_fusion_method, faiss_index, value_memory, key_memory + ) + ) + generated_tokens = model.generate( + encoder_outputs=encoder_outputs, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=False, + attention_mask=attention_mask, + value_fusion_method=model.encoder.value_fusion_method, + **gen_kwargs, + ) + + readout_qas = [[qas_to_retrieve[idx] for idx in indices] for indices in encoder_outputs.readout_indices] + decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + decoded_tokens = [ans.strip() for ans in decoded_tokens] + return decoded_tokens, readout_qas + + +if __name__ == '__main__': + main() diff --git a/emat/__init__.py b/emat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/emat/evaluation/__init__.py b/emat/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/emat/evaluation/eval_qa_overlap.py b/emat/evaluation/eval_qa_overlap.py new file mode 100644 index 0000000..557a8e2 --- /dev/null +++ b/emat/evaluation/eval_qa_overlap.py @@ -0,0 +1,158 @@ +"""Evaluation script to get prediction scores for overlapped QA pairs in ODQA datasets""" +import argparse +import os +import string + +from emat.evaluation.exact_match import exact_match_score, metric_max_over_ground_truths, f1_score +from emat.utils import load_jsonl + +ANNOTATIONS = [ + 'total', + 'question_overlap', + 'no_question_overlap', + 'answer_overlap', + 'no_answer_overlap', + 'answer_overlap_only', + 'no_overlap', +] + +DIRNAME = 
os.path.dirname(os.path.abspath(__file__)) +REFERENCE_PATHS = { + 'triviaqa': 'triviaqa-test.qa.csv', + 'naturalquestions': 'nq-test.qa.csv', + 'webquestions': 'webquestions-test.qa.csv', +} +ANNOTATION_PATHS = { + 'triviaqa': 'triviaqa-annotations.jsonl', + 'naturalquestions': 'nq-annotations.jsonl', + 'webquestions': 'webquestions-annotations.jsonl', +} + + +def preprocess(text: str) -> str: + exclude = set(string.punctuation) + exclude.add(" ") + exclude.add("’") + return "".join(ch for ch in text if ch not in exclude) + + +def read_references(fi, sep='\t'): + def parse_pandas_answer(a_string): + # to avoid a pandas dependency, deserialize these manually + try: + parsed_answers = eval(a_string) if a_string.startswith('[') else eval(a_string.replace('""', '"')[1:-1]) + except: + parsed_answers = eval(a_string.replace('""', '"').replace('""', '"').replace('""', '"')[1:-1]) + return parsed_answers + + questions, references = [], [] + for i, line in enumerate(open(fi)): + q, answer_str = line.strip('\n').split(sep) + questions.append(q) + refs = parse_pandas_answer(answer_str) + references.append({'references': refs, 'id': i}) + return questions, references + + +def read_lines(path): + with open(path) as f: + return [l.strip() for l in f] + + +def read_predictions(path): + if path.endswith('json') or path.endswith('.jsonl'): + return load_jsonl(path) + else: + return [{'id': i, 'prediction': pred} for i, pred in enumerate(read_lines(path))] + + +def _get_scores(answers, refs, fn): + return [metric_max_over_ground_truths(fn, pred, rs) for pred, rs in zip(answers, refs)] + + +def get_scores(predictions, references, annotations, annotation_labels=None): + predictions_map = {p['id']: p for p in predictions} + references_map = {r['id']: r for r in references} + annotations_map = {a['id']: a for a in annotations} + assert predictions_map.keys() == references_map.keys(), 'predictions file doesnt match the gold references file ' + assert predictions_map.keys() == annotations_map.keys(), 'prediction file doesnt match the annotation file ' + assert annotations_map.keys() == references_map.keys(), 'annotations file doesnt match the gold references file ' + + annotation_labels = ANNOTATIONS if annotation_labels is None else annotation_labels + + results = {} + for annotation_label in annotation_labels: + if annotation_label == 'no_overlap': + annotation_ids = [ + annotation['id'] for annotation in annotations if + all(label in annotation['labels'] for label in ['no_question_overlap', 'no_answer_overlap']) + ] + else: + annotation_ids = [ + annotation['id'] for annotation in annotations if annotation_label in annotation['labels'] + ] + + preds = [predictions_map[idd]['prediction'] for idd in annotation_ids] + refs = [references_map[idd]['references'] for idd in annotation_ids] + em = _get_scores(preds, refs, exact_match_score) + f1 = _get_scores(preds, refs, f1_score) + results[annotation_label] = { + 'exact_match': 100 * sum(em) / len(em), + 'f1_score': 100 * sum(f1) / len(f1), + 'n_examples': len(annotation_ids), + } + + return results + + +def _print_score(label, results_dict): + print('-' * 50) + print('Label :', label) + print('N examples : ', results_dict['n_examples']) + print('Exact Match : ', results_dict['exact_match']) + # print('F1 score : ', results_dict['f1_score']) + + +def main(predictions_path, dataset_name, data_dir): + references_path = os.path.join(data_dir, REFERENCE_PATHS[dataset_name]) + annotations_path = os.path.join(data_dir, ANNOTATION_PATHS[dataset_name]) + if not 
os.path.exists(references_path): + raise Exception(' References expected at ' + references_path + + ' not found, please download them using the download script (see readme)') + if not os.path.exists(annotations_path): + raise Exception(' Annotations expected at ' + annotations_path + + ' not found, please download them usiing the download script (see readme)') + + questions, references = read_references(references_path) + annotations = load_jsonl(annotations_path) + + predictions = read_predictions(predictions_path) + assert len(predictions) == len(references) == len(annotations) + + # Align the predictions with the references using the questions + questions = [preprocess(q) for q in questions] + question_to_id = {q.strip(): qid for qid, q in enumerate(questions)} + id_to_prediction = {} + for pred in predictions: + q = preprocess(pred["question"].strip()) + qid = question_to_id[q] + id_to_prediction[qid] = {"id": qid, "prediction": pred["prediction"]} + assert len(id_to_prediction) == len(references) + aligned_predictions = [id_to_prediction[qid] for qid in range(len(references))] + + scores = get_scores(aligned_predictions, references, annotations) + for label in ANNOTATIONS: + _print_score(label, scores[label]) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--predictions", + help="path to predictions txt file, one answer per line. " + "Answer order should follow the order in data/{dataset}-test.qa.csv", type=str) + parser.add_argument('--dataset_name', choices=['naturalquestions', 'triviaqa', 'webquestions'], type=str, + help='name of datset to evaluate on') + parser.add_argument('--data_dir', default="data/qa-overlap", type=str, help='directory of the annotated data') + + args = parser.parse_args() + main(args.predictions, args.dataset_name, args.data_dir) diff --git a/emat/evaluation/eval_retriever.py b/emat/evaluation/eval_retriever.py new file mode 100644 index 0000000..5711286 --- /dev/null +++ b/emat/evaluation/eval_retriever.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
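+
+# Computes hits@k over retrieved QA pairs (a hit = any retrieved answer that exactly
+# matches a gold answer) and exact-match for generated answers.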
+import argparse + +from emat.evaluation.exact_match import metric_max_over_ground_truths, exact_match_score +from emat.utils import load_jsonl + + + +def eval_generation_em(refs, preds): + scores = [] + for ref, pred in zip(refs, preds): + ref_answer = ref["answer"] + em = metric_max_over_ground_truths(exact_match_score, pred, ref_answer) + scores.append(em) + avg_score = sum(scores) / len(scores) + return avg_score + +def eval_retriever(refs, preds, hits_at_k): + if isinstance(hits_at_k, str): + hits_at_k = sorted([int(k) for k in hits_at_k.split(',')]) + + result_dict = {} + for k in hits_at_k: + scores = [] + dont_print = False + for r, p in zip(refs, preds): + if hits_at_k[-1] > len(p): # p['retrieved_qas'] + print(f'Skipping hits@{k} eval as {k} is larger than number of retrieved results') + dont_print = True + ref_answers = r['answer'] + em = any([ + metric_max_over_ground_truths(exact_match_score, pred_answer['answer'][0], ref_answers) + for pred_answer in p[:k] # p['retrieved_qas'][:k] + ]) + scores.append(em) + + avg_score = sum(scores) / len(scores) + # if not dont_print: + # print(f'{k}: {100 * avg_score:0.1f}% \n({sum(scores)} / {len(scores)})') + + result_dict[f"hit@{k}"] = avg_score + + return result_dict + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--predictions', type=str, + help="path to retrieval results to eval, in PAQ's retrieved results jsonl format") + parser.add_argument('--references', type=str, help="path to gold answers, in jsonl format") + parser.add_argument('--hits_at_k', type=str, help='comma separated list of K to eval hits@k for', default="1,10,50") + args = parser.parse_args() + + refs = load_jsonl(args.references) + preds = load_jsonl(args.predictions) + assert len(refs) == len(preds), "number of references doesnt match number of predictions" + + hits_at_k = sorted([int(k) for k in args.hits_at_k.split(',')]) + eval_retriever(refs, preds, hits_at_k) diff --git a/emat/evaluation/exact_match.py b/emat/evaluation/exact_match.py new file mode 100644 index 0000000..8cb21c8 --- /dev/null +++ b/emat/evaluation/exact_match.py @@ -0,0 +1,140 @@ +# coding=utf-8 +import re +import string +import unicodedata +from collections import Counter + +import datasets + +_CITATION = """\ +@inproceedings{rajpurkar-etal-2016-squad, + title = "{SQ}u{AD}: 100,000+ Questions for Machine Comprehension of Text", + author = "Rajpurkar, Pranav and + Zhang, Jian and + Lopyrev, Konstantin and + Liang, Percy", + booktitle = "Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing", + month = nov, + year = "2016", + address = "Austin, Texas", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/D16-1264", + doi = "10.18653/v1/D16-1264", + pages = "2383--2392", +} +@inproceedings{lee-etal-2019-latent, + title = "Latent Retrieval for Weakly Supervised Open Domain Question Answering", + author = "Lee, Kenton and + Chang, Ming-Wei and + Toutanova, Kristina", + booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics", + month = jul, + year = "2019", + address = "Florence, Italy", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/P19-1612", + doi = "10.18653/v1/P19-1612", + pages = "6086--6096", +} +""" + +_DESCRIPTION = """\ +Exact match score for Open-domain Question Answering. +This metric measures the percentage of predictions that match any one of the ground truth answers exactly. 
+""" + +_KWARGS_DESCRIPTION = """ +Calculates the percentage of predictions that match any one of the ground truth answers exactly. +Args: + predictions: list of predictions to score. Each predictions + should be a string with tokens separated by spaces. + references: list of reference for each prediction. Each + reference should be a list of strings with tokens separated by spaces. +Returns: + em: description of the first score, +Examples: + >>> em_metric = datasets.load_metric("exact_match") + >>> results = em_metric.compute(references=[["apple", "orange"], ["banana"]], predictions=["apple", "pear"]) + >>> print(results) + {'em': 0.5} +""" + + +def normalize_answer(s): + """Normalize answer.""" + s = unicodedata.normalize("NFD", s) + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def regex_match_score(prediction, ground_truth): + try: + regex = re.compile(ground_truth, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE) + return regex.match(prediction) is not None + except re.error: + return False + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class ExactMatch(datasets.Metric): + """Exact match (EM) metric for Open-domain Question Answering.""" + + def _info(self): + return datasets.MetricInfo( + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features({ + 'predictions': datasets.Value('string'), + 'references': datasets.Sequence(datasets.Value('string')), + }), + homepage="https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html", + codebase_urls=[ + "https://github.com/google-research/language/blob/58f5dc33a99d168a71586d64ffb7648a0f33b49a/language/orqa/utils/eval_utils.py#L23"], + reference_urls=["https://arxiv.org/pdf/1606.05250.pdf"] + ) + + def _compute(self, predictions, references, is_regex=False): + match_fn = regex_match_score if is_regex else exact_match_score + em_score = sum(metric_max_over_ground_truths(match_fn, i, j) for i, j in zip(predictions, references)) / len( + predictions) + + return {"em": em_score} diff --git a/emat/fusion_net.py b/emat/fusion_net.py new file mode 100644 index 0000000..75de522 --- /dev/null +++ b/emat/fusion_net.py @@ -0,0 +1,30 @@ +import torch +from torch import nn 
+ + +class FusionWeight(nn.Module): + # ["cat_k+v_g(kq)", "cat_k+v_g(kv)"]: + def __init__(self, fusion_type="cat_k+v_g(kq)", model_dim=512, prefix_length=2, key_dim=1024): + super(FusionWeight, self).__init__() + self.fusion_type = fusion_type + if "g(kq)" in self.fusion_type: + self.score_proj = nn.Linear(1, 1) + else: + input_dim = model_dim * prefix_length + key_dim + self.score_proj = nn.Linear(input_dim, 1) + + def forward(self, key=None, query=None, value=None): + if "g(kv)" in self.fusion_type: + # key: batch_size, num_values, key_token_num, hidden_size + # value: batch_size, num_values, prefix_length, hidden_size + batch_size, num_values, _, _ = key.shape + input_hidden = torch.cat((key, value), dim=2) # batch_size, num_values, key_token_num+prefix_length, hidden + scores = self.score_proj(input_hidden.view(batch_size, num_values, -1)) + else: + # key: batch_size, num_values, hidden_size + # query: batch_size, hidden_size + scores = torch.bmm(key, query.unsqueeze(dim=1).transpose(2, 1)) # batch_size, num_values, 1 + scores = self.score_proj(scores) + + scores = torch.sigmoid(scores) + return scores diff --git a/emat/retrieval_adapter.py b/emat/retrieval_adapter.py new file mode 100644 index 0000000..49dfdd4 --- /dev/null +++ b/emat/retrieval_adapter.py @@ -0,0 +1,20 @@ +from torch import nn + + +class RetAdapter(nn.Module): + + def __init__(self, in_dim, out_dim, adapter_type="linear"): + super(RetAdapter, self).__init__() + assert adapter_type in ["linear", "dropout_linear"] + if adapter_type == "linear": + self.out = nn.Linear(in_dim, out_dim, bias=True) + elif adapter_type == "dropout_linear": + self.out = nn.Sequential( + nn.Dropout(0.9), + nn.Linear(in_dim, out_dim, bias=True) + ) + else: + raise NotImplementedError + + def forward(self, vectors): + return self.out(vectors) diff --git a/emat/retriever/build_index.py b/emat/retriever/build_index.py new file mode 100644 index 0000000..2c1582a --- /dev/null +++ b/emat/retriever/build_index.py @@ -0,0 +1,405 @@ +import argparse +import logging +import os +import pickle +import random +import glob +import time + +import faiss +import torch +from tqdm import tqdm + +from emat.retriever.utils import parse_vectors_from_file + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def get_vector_sample(vector, sample_fraction): + max_phi = -1 + N = 0 + + phis = (vector ** 2).sum(1) + max_phi = max(max_phi, phis.max()) + N += vector.shape[0] + if sample_fraction == 1.0: + vector_sample = vector + else: + vector_sample = vector[random.sample(range(0, len(vector)), int(len(vector) * sample_fraction))] + + return vector_sample, max_phi, N + + +def augment_vectors(vectors, max_phi): + phis = (vectors ** 2).sum(1) + aux_dim = torch.sqrt(max_phi.float() - phis.float()) + vectors = torch.cat([vectors, aux_dim.unsqueeze(-1)], -1) + return vectors + + +def build_index_streaming(cached_embeddings_path, + output_path, + vector=None, + hnsw=False, + sq8_quantization=False, + fp16_quantization=False, + store_n=256, + ef_search=32, + ef_construction=80, + sample_fraction=0.1, + indexing_batch_size=5000000, + ): + if vector is None: + vector = parse_vectors_from_file(cached_embeddings_path) + vector_size = vector.shape[1] + + if hnsw: + if sq8_quantization: + index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_8bit, store_n) + elif fp16_quantization: + index = 
faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_fp16, store_n)
+        else:
+            index = faiss.IndexHNSWFlat(vector_size + 1, store_n)
+
+        index.hnsw.efSearch = ef_search
+        index.hnsw.efConstruction = ef_construction
+    else:
+        if sq8_quantization:
+            index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2)
+        elif fp16_quantization:
+            index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_L2)
+        else:
+            # flat (exact) inner-product index; vectors are not augmented on the non-HNSW path
+            index = faiss.IndexFlatIP(vector_size)
+
+    vector_sample, max_phi, N = get_vector_sample(vector, sample_fraction)
+    if hnsw:
+        vector_sample = augment_vectors(vector_sample, max_phi)
+
+    if sq8_quantization or fp16_quantization:  # index requires training
+        vs = vector_sample.numpy()
+        logging.info(f'Training Quantizer with matrix of shape {vs.shape}')
+        index.train(vs)
+        del vs
+    del vector_sample
+
+    # logging.warning("tmp code")
+    # import gc
+    # del vector
+    # gc.collect()
+    # for idx in range(16):
+    #     path = f"./data/embedding_and_faiss/PAQ_from_nq_ckpt/key_memory_dir/embeddings.{idx}.pt"
+    #     vector = torch.load(path)
+    #     if hnsw:
+    #         vector_chunk = augment_vectors(vector, max_phi)
+
+    # original code
+    if hnsw:
+        vector = augment_vectors(vector, max_phi)
+    logging.info(f'Adding Vectors of shape {vector.shape}')
+    index.add(vector.numpy())
+
+    if output_path is not None:
+        logger.info(f'Index Built, writing index to {output_path}')
+        faiss.write_index(index, output_path)
+        logger.info(f'Index dumped')
+    else:
+        logger.info("Built faiss-index.")
+    return index
+
+
+def parse_vectors_from_directory_chunks(embeddings_dir, half):
+    assert os.path.isdir(embeddings_dir), \
+        f"Vectors directory {embeddings_dir} doesn't exist, or is not a directory of pytorch vectors"
+    paths = glob.glob(f"{embeddings_dir}/embeddings.*.pt")
+    assert len(paths) > 0, "Files not found."
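+    # Sort chunk files by the numeric part of "embeddings.<i>.pt" so vectors are
+    # yielded in the same order they were written.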
+ paths_with_order = sorted([(int(os.path.basename(p).split('.')[1]), p) for p in paths], key=lambda x: x[0]) + paths = [po[1] for po in paths_with_order] + for p in paths: + print(p) + for j, p in enumerate(paths): + m = torch.load(p) + # assert int(os.path.basename(p).split('.')[-3]) == j, (p, j) + if half: + m = m if m.dtype == torch.float16 else m.half() + else: + m = m if m.dtype == torch.float32 else m.float() + yield m + + +def get_vector_sample_from_dir(cached_embeddings_path, sample_fraction, half=False): + samples = [] + max_phi = -1 + N = 0 + vectors = parse_vectors_from_directory_chunks(cached_embeddings_path, half) + for chunk in vectors: + phis = (chunk ** 2).sum(1) + max_phi = max(max_phi, phis.max()) + N += chunk.shape[0] + if sample_fraction == 1.0: + chunk_sample = chunk + else: + chunk_sample = chunk[random.sample(range(0, len(chunk)), int(len(chunk) * sample_fraction))] + samples.append(chunk_sample) + + del vectors + vector_sample = torch.cat(samples) + return vector_sample, max_phi, N + + +def get_vector_from_key_chunks(key_chunks, half=False): + samples = [] + max_phi = -1 + N = 0 + # vectors = parse_vectors_from_directory_chunks(key_chunks, half) + for chunk in key_chunks: + + if half: + chunk = chunk if chunk.dtype == torch.float16 else chunk.half() + else: + chunk = chunk if chunk.dtype == torch.float32 else chunk.float() + + phis = (chunk ** 2).sum(1) + max_phi = max(max_phi, phis.max()) + N += chunk.shape[0] + chunk_sample = chunk + samples.append(chunk_sample) + + vector_sample = torch.cat(samples) + return vector_sample, max_phi, N + + +def build_index_streaming_from_dir(cached_embeddings_path, + output_path, + hnsw=False, + sq8_quantization=False, + fp16_quantization=False, + store_n=256, + ef_search=32, + ef_construction=80, + sample_fraction=0.1, + indexing_batch_size=5000000, + ): + logger.info("build index, read from directory.") + first_chunk = torch.load(os.path.join(cached_embeddings_path, "embeddings.0.pt")) # [batch_size, hidden_size] + vector_size = first_chunk.shape[1] + # load first chunk + del first_chunk + + if not os.path.exists("./data/embedding_and_faiss/PAQ_from_nq_ckpt/trained_index.pkl"): + + if hnsw: + if sq8_quantization: + index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_8bit, store_n) + elif fp16_quantization: + index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_fp16, store_n) + else: + index = faiss.IndexHNSWFlat(vector_size + 1, store_n) + + index.hnsw.efSearch = ef_search + index.hnsw.efConstruction = ef_construction + else: + if sq8_quantization: + index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2) + elif fp16_quantization: + index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_L2) + else: + index = faiss.IndexIP(vector_size + 1, store_n) + + vector_sample, max_phi, N = get_vector_sample_from_dir(cached_embeddings_path, sample_fraction) + + print(max_phi, N) + # exit() + if hnsw: + vector_sample = augment_vectors(vector_sample, max_phi) + + if sq8_quantization or fp16_quantization: # index requires training + vs = vector_sample.numpy() + logging.info(f'Training Quantizer with matrix of shape {vs.shape}') + index.train(vs) + del vs + pickle.dump({"index": index, + "max_phi": max_phi}, + open("./data/embedding_and_faiss/PAQ_from_nq_ckpt/trained_index.pkl", 'wb')) + exit() + del vector_sample + + else: + load_index_phi = pickle.load(open("./data/embedding_and_faiss/PAQ_from_nq_ckpt/trained_index.pkl", 'rb')) + 
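+        # Reuse the quantizer trained in a previous run (and its cached max_phi);
+        # N is the known total number of vectors, used only for progress logging.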
max_phi = load_index_phi["max_phi"]
+        index = load_index_phi["index"]
+        N = 64963526
+
+    chunks_to_add = []
+    added = 0
+    for vector_chunk in parse_vectors_from_directory_chunks(cached_embeddings_path, half=False):
+        if hnsw:
+            vector_chunk = augment_vectors(vector_chunk, max_phi)
+
+        chunks_to_add.append(vector_chunk)
+
+        if sum(c.shape[0] for c in chunks_to_add) > indexing_batch_size:
+            to_add = torch.cat(chunks_to_add)
+            logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
+            added += to_add.shape[0]
+            chunks_to_add = []
+            index.add(to_add.numpy())
+
+    if len(chunks_to_add) > 0:
+        to_add = torch.cat(chunks_to_add).numpy()
+        index.add(to_add)
+        logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
+
+    logger.info(f'Index Built, writing index to {output_path}')
+    faiss.write_index(index, output_path)
+    logger.info(f'Index dumped')
+    return index
+
+
+def build_index_streaming_from_key_chunks(key_chunks,
+                                          output_path,
+                                          hnsw=False,
+                                          sq8_quantization=False,
+                                          fp16_quantization=False,
+                                          store_n=256,
+                                          ef_search=32,
+                                          ef_construction=80,
+                                          sample_fraction=0.1,
+                                          indexing_batch_size=5000000,
+                                          ):
+    logger.info("build index from in-memory key chunks.")
+    # first_chunk = torch.load(os.path.join(cached_embeddings_path, "0.key.pt"))  # [batch_size, hidden_size]
+    vector_size = key_chunks[0].shape[1]
+
+    if hnsw:
+        if sq8_quantization:
+            index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_8bit, store_n)
+        elif fp16_quantization:
+            index = faiss.IndexHNSWSQ(vector_size + 1, faiss.ScalarQuantizer.QT_fp16, store_n)
+        else:
+            index = faiss.IndexHNSWFlat(vector_size + 1, store_n)
+
+        index.hnsw.efSearch = ef_search
+        index.hnsw.efConstruction = ef_construction
+    else:
+        if sq8_quantization:
+            index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2)
+        elif fp16_quantization:
+            index = faiss.IndexScalarQuantizer(vector_size, faiss.ScalarQuantizer.QT_fp16, faiss.METRIC_L2)
+        else:
+            # flat (exact) inner-product index; vectors are not augmented on the non-HNSW path
+            index = faiss.IndexFlatIP(vector_size)
+
+    vector_sample, max_phi, N = get_vector_from_key_chunks(key_chunks)
+    if hnsw:
+        vector_sample = augment_vectors(vector_sample, max_phi)
+
+    if sq8_quantization or fp16_quantization:  # index requires training
+        vs = vector_sample.numpy()
+        logging.info(f'Training Quantizer with matrix of shape {vs.shape}')
+        index.train(vs)
+        del vs
+    del vector_sample
+
+    chunks_to_add = []
+    added = 0
+    for vector_chunk in key_chunks:
+        if hnsw:
+            vector_chunk = augment_vectors(vector_chunk, max_phi)
+
+        chunks_to_add.append(vector_chunk)
+
+        if sum(c.shape[0] for c in chunks_to_add) > indexing_batch_size:
+            to_add = torch.cat(chunks_to_add).float().numpy()
+            chunks_to_add = []
+            logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
+            index.add(to_add)
+            added += to_add.shape[0]
+            if output_path is not None:
+                faiss.write_index(index, output_path)  # save intermediate index
+
+    if len(chunks_to_add) > 0:
+        to_add = torch.cat(chunks_to_add).float().numpy()
+        index.add(to_add)
+        logging.info(f'Adding Vectors {added} -> {added + to_add.shape[0]} of {N}')
+
+    if output_path is not None:
+        logger.info(f'Index Built, writing index to {output_path}')
+        faiss.write_index(index, output_path)
+        logger.info(f'Index dumped')
+    else:
+        logger.info(f'Index Built.')
+    return index
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser("Build a FAISS index from precomputed vector files from embed.py. "
+                                     "Provides functionality to build either flat indexes (slow but exact)"
+                                     " or HNSW indexes (much faster, but approximate). 
" + "Optional application of 8bit or 16bit quantization is also available." + " Many more indexes are possible with Faiss, consult the Faiss repository here" + " if you want to build more advanced indexes.") + parser.add_argument('--embeddings_dir', type=str, help='path to directory containing vectors to build index from') + parser.add_argument('--output_path', type=str, help='path to write results to') + parser.add_argument('--hnsw', action='store_true', help='Build an HNSW index rather than Flat') + parser.add_argument('--SQ8', action='store_true', help='use SQ8 quantization on index to save memory') + parser.add_argument('--fp16', action='store_true', help='use fp16 quantization on index to save memory') + parser.add_argument('--store_n', type=int, default=32, help='hnsw store_n parameter') + parser.add_argument('--ef_construction', type=int, default=128, help='hnsw ef_construction parameter') + parser.add_argument('--ef_search', type=int, default=128, help='hnsw ef_search parameter') + parser.add_argument('--sample_fraction', type=float, default=1.0, + help='If memory is limited, specify a fraction (0.0->1.0) of the ' + 'data to sample for training the quantizer') + parser.add_argument('--indexing_batch_size', type=int, default=None, + help='If memory is limited, specify the approximate number ' + 'of vectors to add to the index at once') + parser.add_argument('-v', '--verbose', action="store_true") + args = parser.parse_args() + logging.info(f"Current process's PID: {os.getpid()}") + set_num_threads = 10240 + faiss.omp_set_num_threads(set_num_threads) + logging.info(f"FAISS build info -- set threads {set_num_threads}") + logging.info(f"FAISS build info -- max threads {faiss.omp_get_max_threads()}") + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + assert not (args.SQ8 and args.fp16), 'cant use both sq8 and fp16 Quantization' + assert not os.path.exists(args.output_path), "Faiss index with name specificed in --output_path already exists" + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + args.indexing_batch_size = 10000000000000 if args.indexing_batch_size is None else args.indexing_batch_size + assert 0 < args.sample_fraction <= 1.0 + + build_start = time.perf_counter() + + if os.path.isdir(args.embeddings_dir): + build_index_streaming_from_dir( + args.embeddings_dir, + args.output_path, + args.hnsw, + sq8_quantization=args.SQ8, + fp16_quantization=args.fp16, + store_n=args.store_n, + ef_construction=args.ef_construction, + ef_search=args.ef_search, + sample_fraction=args.sample_fraction, + indexing_batch_size=args.indexing_batch_size, + ) + else: + build_index_streaming( + args.embeddings_dir, + args.output_path, + hnsw=args.hnsw, + sq8_quantization=args.SQ8, + fp16_quantization=args.fp16, + store_n=args.store_n, + ef_construction=args.ef_construction, + ef_search=args.ef_search, + sample_fraction=args.sample_fraction, + indexing_batch_size=args.indexing_batch_size, + ) + + logging.info(f"building index cost {build_start - time.perf_counter():.5f} seconds") diff --git a/emat/retriever/embed.py b/emat/retriever/embed.py new file mode 100644 index 0000000..ae547ce --- /dev/null +++ b/emat/retriever/embed.py @@ -0,0 +1,98 @@ +import argparse +import logging +import os +import time + +import torch +from transformers import T5Tokenizer + +from emat.t5 import T5Config, T5KeyValueEncoder +from emat.utils import to_fp16, verbalise_qa as _verbalise, load_jsonl + +logger = logging.getLogger(__name__) +CUDA = torch.cuda.is_available() + + +def embed_key(model, 
tokenizer, qas, prefix="", bsz=256, max_length=1024, cuda=CUDA, fp16=False, + use_both_qa_for_key=False): + """Compute the key/query embeddings. + prefix: empty when encoding query, "encode: " when encoding key. + """ + verbalise_qa = _verbalise if use_both_qa_for_key else lambda x, y: x + + def tokenize(batch_qas): + input_strs = [prefix + verbalise_qa(ex["question"], ex["answer"][0]) for ex in batch_qas] + inputs = tokenizer(input_strs, max_length=max_length, padding=True, truncation=True, return_tensors="pt") + return inputs + + if cuda: + model = model.cuda() + model = to_fp16(model) if fp16 else model + + t = time.time() + + def log_progress(j, outputs): + t2 = time.time() + logger.info( + f'Embedded {j + 1} / {len(list(range(0, len(qas), bsz)))} batches in {t2 - t:0.2f} seconds ' + f'({sum([len(o) for o in outputs]) / (t2 - t): 0.4f} QAs per second)') + + outputs = [] + with torch.no_grad(): + for j, batch_start in enumerate(range(0, len(qas), bsz)): + batch = qas[batch_start: batch_start + bsz] + + inputs = tokenize(batch) + inputs = {k: v.cuda() for k, v in inputs.items()} if cuda else inputs + + batch_outputs = model(**inputs, compute_value=False, return_dict=True) + outputs.append(batch_outputs.key.cpu()) + if j % 10 == 0: + log_progress(j, outputs) + + log_progress(j, outputs) + + return torch.cat(outputs, dim=0).cpu() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model_name_or_path', type=str, required=True, help='path to HF model dir') + parser.add_argument("--max_source_length", type=int, default=1024, + help="The maximum total input sequence length after tokenization.Sequences " + "longer than this will be truncated, sequences shorter will be padded.") + parser.add_argument('--qas_to_embed', type=str, required=True, help='Path to questions to embed in jsonl format') + parser.add_argument('--output_path', type=str, help='path to write vectors to') + parser.add_argument('--fp16', action='store_true') + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('-v', '--verbose', action="store_true") + + parser.add_argument("--source_prefix", type=str, default="nq question: ", + help="A prefix to add before every source text " "(useful for T5 models).", ) + parser.add_argument("--use_both_qa_for_key", action="store_true", help="Use both Q and A for key embedding.") + + args = parser.parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + + if args.fp16 and not CUDA: + raise Exception("Can't use --fp16 without a gpu, CUDA not found") + + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + qas_to_embed = load_jsonl(args.qas_to_embed) + + config = T5Config.from_pretrained(args.model_name_or_path) + model = T5KeyValueEncoder.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ) + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=True) + + embed_mat = embed_key(model, tokenizer, qas_to_embed, prefix=args.source_prefix, bsz=args.batch_size, + max_length=args.max_source_length, fp16=args.fp16, + use_both_qa_for_key=args.use_both_qa_for_key) + + torch.save(embed_mat.half(), args.output_path) diff --git a/emat/retriever/retrieve.py b/emat/retriever/retrieve.py new file mode 100644 index 0000000..92feecb --- /dev/null +++ b/emat/retriever/retrieve.py @@ -0,0 +1,180 @@ +import argparse +import logging +import time +from copy import deepcopy + +import faiss +import numpy as np +import torch +from 
transformers import T5Tokenizer + +from emat.evaluation.eval_retriever import eval_retriever +from emat.retriever.utils import get_mips_function, parse_vectors_from_file, mips +from emat.t5 import T5Config, T5KeyValueEncoder +from emat.utils import load_jsonl, write_jsonl, to_fp16 + +logger = logging.getLogger(__name__) + +CUDA = torch.cuda.is_available() + + +def get_output_format(qas_to_answer, qas_to_retrieve_from, top_indices, top_scores): + results = [] + for qa_ind, qa in enumerate(qas_to_answer): + res = [] + for score_ind, ind in enumerate(top_indices[qa_ind]): + score = top_scores[qa_ind][score_ind] + ret_qa = deepcopy(qas_to_retrieve_from[ind]) + ret_qa['score'] = float(score) + res.append(ret_qa) + results.append(res) + + return [{'input_qa': in_qa, 'retrieved_qas': ret_qas} for in_qa, ret_qas in zip(qas_to_answer, results)] + + +def embed_query(model, tokenizer, qas, prefix="", bsz=256, max_length=1024, cuda=CUDA, fp16=False): + def tokenize(batch_qas): + input_strs = [prefix + ex["question"] for ex in batch_qas] + inputs = tokenizer(input_strs, max_length=max_length, padding=True, truncation=True, return_tensors="pt") + return inputs + + if cuda: + model = model.cuda() + model = to_fp16(model) if fp16 else model + + t = time.time() + + def log_progress(j, outputs): + t2 = time.time() + logger.info( + f'Embedded {j + 1} / {len(list(range(0, len(qas), bsz)))} batches in {t2 - t:0.2f} seconds ' + f'({sum([len(o) for o in outputs]) / (t2 - t): 0.4f} QAs per second)') + + outputs = [] + with torch.no_grad(): + for j, batch_start in enumerate(range(0, len(qas), bsz)): + batch = qas[batch_start: batch_start + bsz] + + inputs = tokenize(batch) + inputs = {k: v.cuda() for k, v in inputs.items()} if cuda else inputs + + batch_outputs = model(**inputs, compute_value=False, return_dict=True) + outputs.append(batch_outputs.key.cpu()) + if j % 10 == 0: + log_progress(j, outputs) + + log_progress(j, outputs) + + return torch.cat(outputs, dim=0).cpu() + + + + + +def run_queries(model, tokenizer, qas_to_retrieve_from, qas_to_answer, top_k, index=None, prefix="", + batch_size=128, max_length=1024, fp16=False, n_queries_to_parallelize=2048): + assert index is not None + + logger.info('Embedding QAs to answer:') + embedded_qas_to_answer = embed_query(model, tokenizer, qas_to_answer, prefix=prefix, bsz=batch_size, + max_length=max_length, cuda=CUDA, fp16=fp16) + logger.info('Running MIPS search:') + top_indices, top_scores = mips(index, embedded_qas_to_answer, top_k, + n_queries_to_parallelize=n_queries_to_parallelize) + + return get_output_format(qas_to_answer, qas_to_retrieve_from, top_indices, top_scores) + + +def _load_index_if_exists(faiss_index_path, precomputed_embeddings_dir, n_vectors_to_load=None, memory_friendly=False, + efsearch=128): + index = None + if faiss_index_path is not None: + assert precomputed_embeddings_dir is None, "Do not specify both a --faiss_index_path and --precomputed_embeddings_dir" + logger.info('Loading Faiss index:') + index = faiss.read_index(faiss_index_path) + if hasattr(index, 'hnsw'): + index.hnsw.efSearch = efsearch + + elif precomputed_embeddings_dir is not None: + logger.info('Loading vectors index from file:') + index = parse_vectors_from_file(precomputed_embeddings_dir).float() + assert n_vectors_to_load == index.shape[0] + + logger.info('Index loaded') if index is not None else None + return index + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model_name_or_path', type=str, required=True, help='path to HF 
model dir') + parser.add_argument("--max_source_length", type=int, default=1024, + help="The maximum total input sequence length after tokenization.Sequences " + "longer than this will be truncated, sequences shorter will be padded.") + parser.add_argument('--qas_to_answer', type=str, required=True, help="path to questions to answer in jsonl format") + parser.add_argument('--qas_to_retrieve_from', type=str, required=True, + help="path to QA-pairs to retrieve answers from in jsonl format") + parser.add_argument('--top_k', type=int, default=50, help="top K QA-pairs to retrieve for each input question") + parser.add_argument('--output_file', type=str, required=True, help='Path to write jsonl results to') + parser.add_argument('--faiss_index_path', default=None, type=str, + help="Path to faiss index, if retrieving from a faiss index") + parser.add_argument('--precomputed_embeddings_dir', default=None, type=str, + help="path to a directory of vector embeddings if retrieving from raw embeddign vectors") + parser.add_argument('--fp16', action='store_true') + parser.add_argument('--batch_size', type=int, default=128, help='Batch size for embedding questions for querying') + parser.add_argument('--n_queries_to_parallelize', type=int, default=256, help="query batch size") + parser.add_argument('-v', '--verbose', action="store_true") + parser.add_argument('--memory_friendly_parsing', action='store_true', + help='Pass this to load files more slowly, but save memory') + parser.add_argument('--faiss_efsearch', type=int, default=128, + help='EFSearch search time parameter for hnsw, higher is more accurate but slower') + + parser.add_argument("--source_prefix", type=str, default="nq question: ", + help="A prefix to add before every source text " "(useful for T5 models).", ) + parser.add_argument('--hits_at_k', type=str, help='comma separated list of K to eval hits@k for', default="1,10,50") + + args = parser.parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + + qas_to_answer = load_jsonl(args.qas_to_answer) + qas_to_retrieve_from = load_jsonl(args.qas_to_retrieve_from) + + index = _load_index_if_exists( + args.faiss_index_path, + args.precomputed_embeddings_dir, + n_vectors_to_load=len(qas_to_retrieve_from), + memory_friendly=args.memory_friendly_parsing, + efsearch=args.faiss_efsearch + ) + + config = T5Config.from_pretrained(args.model_name_or_path) + model = T5KeyValueEncoder.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ) + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=True) + + retrieved_answers = run_queries( + model, + tokenizer, + qas_to_retrieve_from, + qas_to_answer, + args.top_k, + index, + args.source_prefix, + args.batch_size, + args.max_source_length, + args.fp16, + args.n_queries_to_parallelize, + ) + + logger.info(f'Writing retrieval output to {args.output_file}') + write_jsonl(retrieved_answers, args.output_file) + + hits_at_k = sorted([int(k) for k in args.hits_at_k.split(',')]) + result = eval_retriever(qas_to_answer, retrieved_answers, hits_at_k) + with open(args.output_file + ".result", "w") as f: + for k, v in result.items(): + f.write(f"{k}: {v}\n") diff --git a/emat/retriever/utils.py b/emat/retriever/utils.py new file mode 100644 index 0000000..c82302a --- /dev/null +++ b/emat/retriever/utils.py @@ -0,0 +1,148 @@ +import glob +import logging +import os +import time + +import torch + +logger = logging.getLogger(__name__) + + +def torch_mips(index, query_batch, 
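`retrieve.py` emits one record per input question in the format built by `get_output_format` ({"input_qa": ..., "retrieved_qas": [...]}) and then scores it with `eval_retriever`. As a reference for that schema, the sketch below computes hits@k directly from the output structure; it is illustrative only, since the repository's `eval_retriever` may apply extra answer normalisation:

```python
# Hedged hits@k over retrieve.py's output format; the repo's eval_retriever may
# additionally normalise answers (case, punctuation, articles), which is omitted here.
def hits_at_k(retrieval_output, ks=(1, 10, 50)):
    results = {}
    for k in ks:
        hit = 0
        for record in retrieval_output:
            gold = {a.lower() for a in record["input_qa"]["answer"]}
            retrieved = record["retrieved_qas"][:k]
            if any(r["answer"][0].lower() in gold for r in retrieved):
                hit += 1
        results[f"hits@{k}"] = hit / max(len(retrieval_output), 1)
    return results

# Example with one question whose answer is found by the top retrieved QA pair.
demo = [{
    "input_qa": {"question": "who wrote hamlet", "answer": ["William Shakespeare"]},
    "retrieved_qas": [{"question": "who is the author of hamlet",
                       "answer": ["William Shakespeare"], "score": 81.3}],
}]
print(hits_at_k(demo, ks=(1,)))   # {'hits@1': 1.0}
```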
top_k): + sims = torch.matmul(query_batch, index.t()) + return sims.topk(top_k) + + +def flat_index_mips(index, query_batch, top_k): + return index.search(query_batch.numpy(), top_k) + + +def aux_dim_index_mips(index, query_batch, top_k): + # querying faiss indexes for MIPS using a euclidean distance index, used with hnsw + aux_dim = query_batch.new(query_batch.shape[0]).fill_(0) + aux_query_batch = torch.cat([query_batch, aux_dim.unsqueeze(-1)], -1) + return index.search(aux_query_batch.numpy(), top_k) + + +def get_mips_function(index): + if type(index) == torch.Tensor: + return torch_mips + elif 'hnsw' in str(type(index)).lower(): + return aux_dim_index_mips + else: + return flat_index_mips + + +def get_vectors_file_paths_in_vector_directory(embeddings_dir): + paths = glob.glob(os.path.abspath(embeddings_dir) + '/*') + np = len(paths) + template = '.'.join(paths[0].split('.')[:-1]) + return [template + f'.{j}' for j in range(np)] + + +def parse_vectors_from_directory_chunks(embeddings_dir, half): + paths = get_vectors_file_paths_in_vector_directory(embeddings_dir) + for j, p in enumerate(paths): + logger.info(f'Loading vectors from {p} ({j + 1} / {len(paths)})') + m = torch.load(p) + assert int(p.split('.')[-1]) == j, (p, j) + + if half: + m = m if m.dtype == torch.float16 else m.half() + else: + m = m if m.dtype == torch.float32 else m.float() + yield m + + +def parse_vectors_from_directory_fast(embeddings_dir): + ms = [] + for m in parse_vectors_from_directory_chunks(embeddings_dir): + ms.append(m) + + out = torch.cat(ms) + logger.info(f'loaded index of shape {out.shape}') + return out + + +def parse_vectors_from_directory_memory_friendly(embeddings_dir, size=None): + paths = get_vectors_file_paths_in_vector_directory(embeddings_dir) + if size is None: + size = 0 + for j, p in enumerate(paths): + logger.info(f'Loading vectors from {p} ({j + 1} / {len(paths)}) to find total num vectors') + m = torch.load(p) + size += m.shape[0] + + out = None + offset = 0 + for j, p in enumerate(paths): + logger.info(f'Loading vectors from {p} ({j + 1} / {len(paths)})') + m = torch.load(p) + + assert int(p.split('.')[-1]) == j, (p, j) + if out is None: + out = torch.zeros(size, m.shape[1]) + out[offset: offset + m.shape[0]] = m + offset += m.shape[0] + assert offset == size + logger.info(f'loaded index of shape {out.shape}') + + return out + + +def parse_vectors_from_directory(fi, memory_friendly=False, size=None, as_chunks=False, half=False): + assert os.path.isdir(fi), f"Vectors directory {fi} doesnt exist, or is not a directory of pytorch vectors" + if as_chunks: + return parse_vectors_from_directory_chunks(fi, half) + + if memory_friendly: + out = parse_vectors_from_directory_memory_friendly(fi, size=size) + else: + out = parse_vectors_from_directory_fast(fi) + + if half: + out = out if out.dtype == torch.float16 else out.half() + else: + out = out if out.dtype == torch.float32 else out.float() + + return out + + +def parse_vectors_from_file(fi, half=False): + assert os.path.isfile(fi), f"{fi}" + logger.info(f'Loading vectors from {fi}') + out = torch.load(fi) + logger.info(f'loaded vectors of shape {out.shape}') + + if half: + out = out if out.dtype == torch.float16 else out.half() + else: + out = out if out.dtype == torch.float32 else out.float() + + return out + + +def mips(index, queries, top_k, n_queries_to_parallelize=256): + # t = time.time() + all_top_indices = None + all_top_scores = None + + _mips = get_mips_function(index) + + for mb in range(0, len(queries), n_queries_to_parallelize): + 
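The zero column appended by `aux_dim_index_mips` is the standard reduction of maximum-inner-product search to L2 search, which the stored HNSW key indexes presumably rely on: keys are indexed with an extra `sqrt(phi - ||k||^2)` dimension (`phi` being the largest squared key norm), queries get a 0, so the nearest L2 neighbour is exactly the largest inner product. A small self-contained check of that identity:

```python
# Why a zero auxiliary dimension works: ||q' - k'||^2 = ||q||^2 + phi - 2<q, k>,
# so minimising L2 distance over the augmented keys maximises the inner product.
import torch

torch.manual_seed(0)
keys = torch.randn(1000, 64)
query = torch.randn(1, 64)

phi = (keys ** 2).sum(dim=1).max()
aux_keys = torch.cat([keys, torch.sqrt(phi - (keys ** 2).sum(dim=1, keepdim=True))], dim=1)
aux_query = torch.cat([query, torch.zeros(1, 1)], dim=1)   # the zero auxiliary dimension

ip_best = torch.matmul(query, keys.t()).argmax()
l2_best = ((aux_keys - aux_query) ** 2).sum(dim=1).argmin()
assert ip_best == l2_best
```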
query_batch = queries[mb:mb + n_queries_to_parallelize].float() + scores, top_indices = _mips(index, query_batch, top_k) + + all_top_indices = top_indices if all_top_indices is None else np.concatenate([all_top_indices, top_indices]) + all_top_scores = scores if all_top_scores is None else np.concatenate([all_top_scores, scores]) + + # delta = time.time() - t + # logger.info( + # f'{len(all_top_indices)}/ {len(queries)} queries searched in {delta:04f} ' + # f'seconds ({len(all_top_indices) / delta} per second)') + + assert len(all_top_indices) == len(queries) + + # delta = time.time() - t + # logger.info(f'Index searched in {delta:04f} seconds ({len(queries) / delta} per second)') + return all_top_indices, all_top_scores diff --git a/emat/t5.py b/emat/t5.py new file mode 100644 index 0000000..eb09c1d --- /dev/null +++ b/emat/t5.py @@ -0,0 +1,1976 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model. """ +import asyncio +import copy +import warnings +from concurrent import futures +from dataclasses import dataclass +from typing import Dict, Any, Optional, Tuple +from emat.fusion_net import FusionWeight + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint +from transformers.file_utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, + ModelOutput, +) +from transformers.models.t5.modeling_t5 import ( + T5Block, + T5LayerNorm, + T5Stack, + T5ForConditionalGeneration, + T5_START_DOCSTRING, + T5_INPUTS_DOCSTRING, + __HEAD_MASK_WARNING_MSG as HEAD_MASK_WARNING_MSG, + _CONFIG_FOR_DOC, +) +from transformers.utils import logging +from utils.utils import reduce_query_or_key_embeds +from emat.retriever.utils import mips +from emat.retrieval_adapter import RetAdapter + +logger = logging.get_logger(__name__) + + +@dataclass +class KeyValueOutput(ModelOutput): + key: torch.FloatTensor = None + value: Optional[torch.FloatTensor] = None + + +@dataclass +class CATEncoderOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + query_embeds: Optional[torch.tensor] = None + readout_indices: Optional[Any] = None + updated_attention_mask: Optional[torch.tensor] = None + + +@dataclass +class CATSeq2SeqLMOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + 
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cat_encoder_outputs: Optional[CATEncoderOutput] = None + + +class ConvKeyEncoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + + self.conv_layer = nn.Conv1d( + in_channels=config.d_model, + out_channels=config.d_model, + kernel_size=3, + stride=1, + padding="same", + bias=False + ) + self.relu = nn.ReLU() + self.linear = nn.Linear(config.d_model, config.d_key, bias=False) + + def forward(self, hidden_states, attention_mask): + attention_mask = attention_mask[:, None, :] + masked_hidden = hidden_states.transpose(1, 2) * attention_mask # shape: [batch_size, d_model, seq_length] + + conv_out = self.conv_layer(masked_hidden) # shape: [batch_size, d_model, seq_length] + relu_out = self.relu(conv_out) # shape: [batch_size, d_model, seq_length] + relu_out = relu_out * attention_mask # mask again + + sum_pooled = torch.sum(relu_out, dim=2) # shape: [batch_size, d_model] + lengths = torch.sum(attention_mask, dim=(1, 2)) # shape: [batch_size] + mean_pooled = sum_pooled / lengths[:, None] # shape: [batch_size, d_model] + + final_out = self.linear(mean_pooled) # shape: [batch_size, d_key] + return final_out + + +# T5StackWithKeyValueMemory is encoder +class T5StackWithKeyValueMemory(T5Stack): + def __init__(self, config, embed_tokens=None): + super(T5Stack, self).__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + # Create prefix embeddings + assert not self.is_decoder, "This module should only be used as encoder" + self.prefix_length = config.prefix_length # length of the prefix + self.prefix_embedding = nn.Parameter(torch.empty((self.prefix_length, config.d_model), dtype=torch.float)) + self.model_dim = config.d_model + + # Key configs + self.key_layer = config.key_layer # the layer that conducts key querying + + self.cat_layer = getattr(config, "cat_layer", None) + + # Initialize the key encoder + self.key_encoder_type = config.key_encoder_type + if self.key_encoder_type == "linear": + self.d_key = config.d_key # dimension of the key/query embedding + self.key_encoder = nn.Linear(self.prefix_length * config.d_model, self.d_key, bias=False) + elif self.key_encoder_type == "conv": + self.key_encoder = ConvKeyEncoder(config) + elif self.key_encoder_type == "prefix": + self.key_encoder = None + else: + raise ValueError(f"Incorrect key_encoder_type: {self.key_encoder_type}") + + # self.qk_scorer = nn.Linear(1, 1, bias=True) # calibrate the query-key match scores into gating + + # Value configs + self.value_layer = config.value_layer # the layer that it conducts value infilling + self.num_values = config.num_values # number of value embeddings to infill + assert self.key_layer <= self.value_layer, "Key layer should be smaller than or equal to value layer" + + self.value_fusion_method = config.value_fusion_method + + if self.value_fusion_method is not None and "cat" in self.value_fusion_method: + # add_position_bias_layer = self.value_layer + # if "delay" in self.value_fusion_method: + # add_position_bias_layer = self.key_layer + if self.cat_layer is not None: + add_position_bias_layer = self.cat_layer + else: + add_position_bias_layer = min(self.key_layer, self.value_layer) + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0 or i == add_position_bias_layer)) + for i in range(config.num_layers)] + ) + else: + self.block = nn.ModuleList( + 
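`ConvKeyEncoder` masks out padding, applies a width-3 convolution along the sequence, mean-pools over the real tokens, and projects to `d_key`. A toy forward pass to make the shapes concrete (a `SimpleNamespace` stands in for the `T5Config` fields it actually reads; assumes the package is importable, e.g. after `pip install -e .`):

```python
# Toy shape check for ConvKeyEncoder; SimpleNamespace stands in for T5Config.
from types import SimpleNamespace

import torch

from emat.t5 import ConvKeyEncoder

config = SimpleNamespace(d_model=512, d_key=512)
enc = ConvKeyEncoder(config)

hidden_states = torch.randn(2, 7, 512)                   # [batch, seq_len, d_model]
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 1, 1]])   # 1 = real token, 0 = padding

key = enc(hidden_states, attention_mask)
print(key.shape)                                         # torch.Size([2, 512]) -> one key vector per example
```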
[T5Block(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers)] + ) + + if self.value_fusion_method is not None and "g(" in self.value_fusion_method: + self.fusion_weight_proj = FusionWeight(fusion_type=self.value_fusion_method) + else: + self.fusion_weight_proj = None + + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.key_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + nn.init.normal_(self.prefix_embedding.data, mean=0.0, std=config.initializer_factor * 1.0) + # self.qk_scorer.weight.data.copy_(torch.tensor([[1.0]])) + # self.qk_scorer.bias.data.copy_(torch.tensor([0.0])) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + key_embeds=None, + value_embeds=None, + key_faiss_index=None, + key_reduce_method=None, + value_qas_input_ids=None, + value_qas_attention_mask=None, + readout_top_k=1, + value_fusion_method=None, + key_embeds_of_value=None, + key_memory=None, + value_memory=None, + embedding_index=None, + ): + assert key_embeds is None + assert value_fusion_method == self.value_fusion_method + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Sanity check + assert not use_cache, "This class does not support use_cache because it is encoder only" + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + # Concatenate the prefix embeddings and extend the attention masks + prefix_embeds = self.prefix_embedding[None, :, :].expand(batch_size, -1, -1).to(inputs_embeds.device) + inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1) + + # Extend the attention masks + if attention_mask is None: + attention_mask = torch.ones(batch_size, 
mask_seq_length + self.prefix_length).to(inputs_embeds.device) + else: + prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to( + inputs_embeds.device) + attention_mask = torch.cat([prefix_mask, attention_mask], dim=1) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, (batch_size, seq_length + self.prefix_length), inputs_embeds.device + ) + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + query_embeds = None + readout_indices = None + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if not use_cache: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + + # Query-key matching + if i == self.key_layer: # key-layer is Query-Layer + raw_query_embeds = self._encode_key(hidden_states, attention_mask) # shape:[batch_size, d_key] + raw_query_embeds = raw_query_embeds.view(hidden_states.shape[0], -1, hidden_states.shape[-1]) + + normed_query_embeds = self.kv_output_layer(raw_query_embeds) # Query is normed !!! 
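The normed prefix embeddings are then collapsed into a single query vector per example by `reduce_query_or_key_embeds` (defined in `utils/utils.py`, which is not part of this diff). Purely to make the shapes concrete, the sketch below assumes a simple "avg" reduction over the prefix tokens; the repository's actual reduction may differ:

```python
# Illustrative only: the real reduce_query_or_key_embeds lives in utils/utils.py and is
# not shown in this diff. This sketch assumes key_reduce_method == "avg".
import torch

def reduce_query_or_key_embeds_sketch(embeds, key_reduce_method="avg"):
    # embeds: [batch_size, prefix_length, d_model], already layer-normed
    if key_reduce_method == "avg":
        return embeds.mean(dim=1)                  # [batch_size, d_model]
    raise NotImplementedError(key_reduce_method)

normed_query_embeds = torch.randn(4, 2, 512)       # e.g. prefix_length = 2
query_embeds = reduce_query_or_key_embeds_sketch(normed_query_embeds)
print(query_embeds.shape)                          # torch.Size([4, 512])
```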
+ query_embeds = reduce_query_or_key_embeds(normed_query_embeds, key_reduce_method) + + if embedding_index is not None: + assert value_embeds is None and key_embeds_of_value is None + if type(embedding_index) is not list: + embedding_index = [embedding_index] + half_query_embeds = query_embeds.half() + memory_size, hidden_num, hidden_size = value_memory.shape + if memory_size > 20000000: + # if scale is large: calculate topk in each chunk -> combine all-topk -> select final topk + chunk_top_scores = [] + chunk_top_indices = [] + idx_shift = 0 + for chunk_key_memory in embedding_index: + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_topk = torch.mm(half_query_embeds, chunk_key_memory_cuda.t()).topk(50, dim=1) + chunk_top_scores.append(chunk_topk.values) # chunk_topk.scores: [query_batch, local_size] + chunk_top_indices.append(chunk_topk.indices + idx_shift) + idx_shift += len(chunk_key_memory) + del chunk_key_memory_cuda + torch.cuda.empty_cache() + chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n + chunk_top_indices = torch.cat(chunk_top_indices, dim=1) # q_batch, local_size*seg_n + topk = chunk_top_scores.topk(readout_top_k, dim=1) # q_batch, local_size + top_indices_indices = topk.indices + readout_indices = [] + for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices): + readout_indices.append([cur_indices[idx] for idx in cur_indices_indices]) + else: + all_chunk_scores = [] + for chunk_key_memory in embedding_index: + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_scores = torch.mm(half_query_embeds, chunk_key_memory_cuda.t()) # query_batch + all_chunk_scores.append(chunk_scores) + del chunk_key_memory_cuda + scores = torch.cat(all_chunk_scores, dim=1) + readout_indices = scores.topk(readout_top_k, dim=1).indices.tolist() + + top_indices = torch.tensor(readout_indices) + bs = input_ids.shape[0] + value_embeds = torch.index_select(value_memory, 0, top_indices.view(-1)).float().cuda() + value_embeds = value_embeds.view(input_ids.shape[0], readout_top_k, hidden_num, hidden_size) + key_embeds_of_value = torch.index_select(key_memory, 0, top_indices.view(-1)).float().cuda() + key_embeds_of_value = key_embeds_of_value.view(bs, readout_top_k, hidden_num, hidden_size) + + if value_fusion_method == "cat_k_delay+v": + # Serial mode, cat key directly. + batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape + if key_nums != self.prefix_length: + assert key_nums == 1 + key_embeds_of_value = key_embeds_of_value.repeat(1, 1, self.prefix_length, 1) + hidden_states = torch.cat( + [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size), + hidden_states], dim=1 + ) + extend_length = num_values * self.prefix_length + extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype) + attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape[:2], inputs_embeds.device + ) + position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule + + if self.cat_layer is not None and i == self.cat_layer: + if value_fusion_method == "async_cat_k_delay+v": + # Async mode, emat key in cat_layer. the implementation of delay + v is same to serial mode. 
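For key memories too large for a single matmul (the >20M-entry branch above), the lookup is chunked: a per-chunk top-50, a concatenation of the chunk candidates, and a final global top-k whose indices are shifted back to positions in the full memory. The same merge as a standalone sketch with toy sizes:

```python
# Standalone sketch of the chunked top-k merge used for very large key memories.
import torch

torch.manual_seed(0)
query = torch.randn(3, 64)                               # [batch, d_key]
chunks = [torch.randn(500, 64) for _ in range(4)]        # key memory split into chunks
top_k, per_chunk_k = 5, 50

chunk_scores, chunk_indices, offset = [], [], 0
for chunk in chunks:
    topk = torch.mm(query, chunk.t()).topk(per_chunk_k, dim=1)
    chunk_scores.append(topk.values)
    chunk_indices.append(topk.indices + offset)          # shift to global memory positions
    offset += len(chunk)

scores = torch.cat(chunk_scores, dim=1)                  # [batch, per_chunk_k * n_chunks]
indices = torch.cat(chunk_indices, dim=1)
final = scores.topk(top_k, dim=1)
global_ids = torch.gather(indices, 1, final.indices)     # map back to the full memory

# Same candidates as scoring the whole memory at once (order of ties aside):
reference = torch.mm(query, torch.cat(chunks).t()).topk(top_k, dim=1).indices
assert torch.equal(torch.sort(global_ids, dim=1).values, torch.sort(reference, dim=1).values)
```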
+ batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape + hidden_states = torch.cat( + [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size), + hidden_states], dim=1 + ) + extend_length = num_values * self.prefix_length + extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype) + attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape[:2], inputs_embeds.device + ) + position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule + + if i == self.value_layer: + batch_size, num_values, _, hidden_size = key_embeds_of_value.shape + + # assert query_embeds is not None, "Use query_embeds to read memory before assignment." + if "delay" in value_fusion_method: + updated_key = hidden_states[:, :num_values * self.prefix_length] + updated_key = updated_key.view(batch_size, num_values, self.prefix_length, hidden_size) + else: + updated_key = None + integrated_value = self.get_integrated_values(value_embeds, key_embeds_of_value, value_fusion_method, + query_embeds=query_embeds, updated_key=updated_key, + key_reduce_method=key_reduce_method) + if value_fusion_method == "infill": + assert self.num_values == 1 + hidden_states = torch.cat([integrated_value, hidden_states[:, self.prefix_length:]], dim=1) + elif "cat" in value_fusion_method and "delay" in value_fusion_method: + hidden_states[:, :num_values * self.prefix_length] = integrated_value + hidden_states = hidden_states.contiguous() + elif "cat" in value_fusion_method and "delay" not in value_fusion_method: + hidden_states = torch.cat([integrated_value, hidden_states], dim=1) + extend_length = integrated_value.shape[1] + extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype) + attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape[:2], inputs_embeds.device + ) + position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule + else: + raise NotImplementedError(f"{value_fusion_method} is not defined.") + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + query_embeds, + readout_indices + ] + if v is not None + ) + return CATEncoderOutput( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + query_embeds=query_embeds, + readout_indices=readout_indices, + updated_attention_mask=None, + ) + + def get_integrated_values(self, group_value_embeds, group_key_embeds, value_fusion_method, + query_embeds=None, updated_key=None, key_reduce_method=None): + # 
group_value_embeds: [batch_size, num_values, prefix_length, hidden_size] + # group_key_embeds: [batch_size, num_values, key_num_tokens, hidden_size] + if value_fusion_method == "cat_k_delay+v": + batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape + integrated_value = updated_key + group_value_embeds.contiguous() + integrated_value = integrated_value.view(batch_size, num_values * prefix_length, hidden_size) + elif value_fusion_method == "async_cat_k_delay+v": + batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape + integrated_value = updated_key + group_value_embeds.contiguous() + integrated_value = integrated_value.view(batch_size, num_values * prefix_length, hidden_size) + elif value_fusion_method == "cat_v": + batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape + group_value_embeds = group_value_embeds.contiguous() + integrated_value = group_value_embeds.view(batch_size, num_values * prefix_length, hidden_size) + elif value_fusion_method == "async_cat_k+v": + batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape + group_key_add_value = group_key_embeds + group_value_embeds + integrated_value = group_key_add_value.view(batch_size, num_values * self.prefix_length, hidden_size) + elif value_fusion_method == "cat_k+v": + batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape + group_key_add_value = group_key_embeds + group_value_embeds + integrated_value = group_key_add_value.view(batch_size, num_values * self.prefix_length, hidden_size) + elif value_fusion_method == "cat_k_delay+v_g(kv)": + # batch_size, num_values, 1 + key_weight = self.fusion_weight_proj(key=updated_key, value=group_value_embeds.contiguous()) + key_weight = key_weight.unsqueeze(dim=-1) + integrated_value = key_weight * updated_key + (1 - key_weight) * group_value_embeds + batch_size, num_values, key_nums, hidden_size = updated_key.shape + integrated_value = integrated_value.view(batch_size, num_values * self.prefix_length, hidden_size) + elif value_fusion_method == "infill": + batch_size, num_values, prefix_length, hidden_size = group_value_embeds.shape + assert num_values == self.num_values == 1 + integrated_value = group_value_embeds.view(batch_size, prefix_length, hidden_size) + elif value_fusion_method == "cat_kv": + group_key_cat_value = torch.cat((group_key_embeds, group_value_embeds), dim=2) + # [batch_size, num_values, prefix_length + key_num_tokens, hidden_size] + batch_size, num_values, integrated_prefix_length, hidden_size = group_key_cat_value.shape + integrated_value = group_key_cat_value.view(batch_size, num_values * integrated_prefix_length, hidden_size) + elif value_fusion_method == "cat_avgk+v": + batch_size, num_values, key_nums, hidden_size = group_key_embeds.shape + reduced_key_embeds = (group_key_embeds.sum(dim=2) / key_nums).unsqueeze(dim=2) + group_key_add_value = reduced_key_embeds + group_value_embeds + integrated_value = group_key_add_value.view(batch_size, num_values * self.prefix_length, hidden_size) + elif value_fusion_method == "cat_k+v_g(kq)": + batch_size, num_values, key_nums, hidden_size = group_key_embeds.shape + squeezed_key_embeds = group_key_embeds.view(batch_size * num_values, key_nums, hidden_size) + reduced_key_embeds = reduce_query_or_key_embeds(squeezed_key_embeds, key_reduce_method) + reduced_key_embeds = reduced_key_embeds.view(batch_size, num_values, hidden_size) + key_weight = self.fusion_weight_proj(key=reduced_key_embeds, query=query_embeds) # b_s, 
num_values, 1 + key_weight = key_weight.unsqueeze(dim=-1) + integrated_value = key_weight * group_key_embeds + (1 - key_weight) * group_value_embeds + integrated_value = integrated_value.view(batch_size, num_values * self.prefix_length, hidden_size) + else: + raise NotImplementedError(f"{value_fusion_method} is not defined.") + return integrated_value + + def embed_kv(self, input_ids, attention_mask=None, head_mask=None, compute_key=True, compute_value=True, + embed_for_ae_task=False) -> Dict: + """Compute the key/value embeddings for the input.""" + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + + # Sanity check + assert compute_key or compute_value, "At least one of compute_key and compute_value needs to be True" + assert not (compute_key and compute_value), "Only compute key or value once forward." + assert input_ids is not None + + original_input_shape = input_ids.shape + input_ids = input_ids.view(-1, input_ids.shape[-1]) # the last dimension is seq_length + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_ids.shape + + # Concatenate the prefix embeddings and extend the attention masks + prefix_embeds = self.prefix_embedding[None, :, :].expand(batch_size, -1, -1).to(inputs_embeds.device) + inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1) + + # Extend the attention masks + if attention_mask is None: + attention_mask = torch.ones(batch_size, seq_length + self.prefix_length).to(inputs_embeds.device) + else: + prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to( + inputs_embeds.device) + attention_mask = torch.cat([prefix_mask, attention_mask.view(batch_size, seq_length)], dim=1) + + # initialize past_key_values with `None` if past does not exist + past_key_values = [None] * len(self.block) + + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, (batch_size, seq_length + self.prefix_length), inputs_embeds.device + ) + + encoder_hidden_states = None + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(None, self.config.num_layers) + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + key_embeds, value_embeds = None, None + normed_key_embeds, normed_value_embeds = None, None + key_embeds_to_cat = None + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = 
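Most of the fusion variants in `get_integrated_values` reduce to the same pattern: combine the retrieved value embeddings (and, for the k+v variants, their key embeddings) slot-wise, flatten the `num_values` memory slots into extra token positions, and concatenate or write them into the encoder hidden states while extending the attention mask. A shape-only sketch of the plain `cat_k+v` case:

```python
# Shape-only sketch of "cat_k+v": add retrieved keys and values slot-wise, flatten the
# memory slots into extra token positions, and prepend them to the hidden states.
import torch

batch_size, num_values, prefix_length, hidden_size, seq_len = 2, 3, 2, 512, 20
key_embeds_of_value = torch.randn(batch_size, num_values, prefix_length, hidden_size)
value_embeds = torch.randn(batch_size, num_values, prefix_length, hidden_size)
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

integrated = (key_embeds_of_value + value_embeds).view(
    batch_size, num_values * prefix_length, hidden_size)
hidden_states = torch.cat([integrated, hidden_states], dim=1)
attention_mask = torch.cat(
    [torch.ones(batch_size, num_values * prefix_length, dtype=torch.long), attention_mask], dim=1)

print(hidden_states.shape)   # torch.Size([2, 26, 512]) -> 3 slots * 2 prefix tokens prepended
```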
layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=False, + output_attentions=False, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + + # Encode the key + if compute_key and i == self.key_layer: + key_embeds = self._encode_key(hidden_states, attention_mask) # shape:[batch_size, d_key] + key_embeds = key_embeds.view(hidden_states.shape[0], -1, hidden_states.shape[-1]) + if not compute_value and self.cat_layer is None: + break + if compute_value and i == self.cat_layer: + key_embeds_to_cat = self._encode_key(hidden_states, attention_mask) + key_embeds_to_cat = key_embeds_to_cat.view(hidden_states.shape[0], -1, hidden_states.shape[-1]) + + # Encode the value + if compute_value and i == self.value_layer: + value_embeds = hidden_states[:, :self.prefix_length] + break # (jimmycode): early stop to reduce cost + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + if value_embeds is not None: + normed_value_embeds = self.kv_output_layer(value_embeds) + if key_embeds is not None: + normed_key_embeds = self.kv_output_layer(key_embeds) + + return {"key_embeds": key_embeds, "normed_key_embeds": normed_key_embeds, + "value_embeds": value_embeds, "normed_value_embeds": normed_value_embeds, + "key_embeds_to_cat": key_embeds_to_cat} + + def kv_output_layer(self, embeds): + assert len(embeds.shape) == 3 + embeds = self.key_layer_norm(embeds) + embeds = self.dropout(embeds) + return embeds + + def _encode_key(self, hidden_states, attention_mask, prefix_embs=None): + assert hidden_states.ndim == 3 and attention_mask.ndim == 2 + + if self.key_encoder_type == "linear": + if prefix_embs is None: + prefix_embs = hidden_states[:, :self.prefix_length].view(hidden_states.shape[0], -1) + # shape: [batch_size, prefix_length * d_model] + key_embeds = self.key_encoder(prefix_embs) # shape: [batch_size, d_key] + elif self.key_encoder_type == "conv": + if prefix_embs is None: + key_embeds = self.key_encoder(hidden_states, attention_mask) + else: + prefix_mask = torch.ones(prefix_embs.shape[:2].to(prefix_embs.device)) + key_embeds = self.key_encoder(prefix_embs, prefix_mask) + elif self.key_encoder_type == "prefix": + prefix_embs = hidden_states[:, :self.prefix_length] # [batch_size, prefix-len, hidden_size] + 
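`embed_kv` above is what turns QA pairs into memory entries: the normed key embeddings (taken at `key_layer`) feed the retrieval index, while the values are the first `prefix_length` hidden states at `value_layer`, with an early break to keep the pass cheap. A hedged sketch of wrapping it for memory building (it assumes `model.encoder` is the `T5StackWithKeyValueMemory` shown here and that `batch` holds tokenizer output; the repository's actual memory-building driver lives elsewhere):

```python
# Hedged sketch: turning tokenized QA pairs into key/value memory tensors with embed_kv.
# Assumes model.encoder is the T5StackWithKeyValueMemory above and batch is a dict of
# input_ids / attention_mask produced by the tokenizer.
import torch

@torch.no_grad()
def build_memory_batch(model, batch):
    enc = model.encoder
    # embed_kv only computes one of key/value per forward, so run it twice.
    key_out = enc.embed_kv(input_ids=batch["input_ids"],
                           attention_mask=batch["attention_mask"],
                           compute_key=True, compute_value=False)
    value_out = enc.embed_kv(input_ids=batch["input_ids"],
                             attention_mask=batch["attention_mask"],
                             compute_key=False, compute_value=True)
    # Normed keys feed the retrieval index; values are the prefix hidden states.
    return key_out["normed_key_embeds"].half().cpu(), value_out["value_embeds"].half().cpu()
```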
key_embeds = prefix_embs.view(hidden_states.shape[0], -1) # [batch_size, d_key) + else: + raise ValueError(f"Incorrect key_encoder_type: {self.key_encoder_type}") + + return key_embeds + + def forward_with_faiss( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + value_memory=None, + not_reduced_key_memory=None, + key_faiss_index=None, + key_reduce_method=None, + readout_top_k=1, + value_fusion_method=None, + ): + assert value_memory is not None + assert key_faiss_index is not None + assert not_reduced_key_memory is not None + assert value_fusion_method == self.value_fusion_method + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Sanity check + assert not use_cache, "This class does not support use_cache because it is encoder only" + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + # Concatenate the prefix embeddings and extend the attention masks + prefix_embeds = self.prefix_embedding[None, :, :].expand(batch_size, -1, -1).to(inputs_embeds.device) + inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1) + + # Extend the attention masks + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length + self.prefix_length).to(inputs_embeds.device) + else: + prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to( + inputs_embeds.device) + attention_mask = torch.cat([prefix_mask, attention_mask], dim=1) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask = self.get_extended_attention_mask( + attention_mask, (batch_size, seq_length + self.prefix_length), inputs_embeds.device + ) + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + query_embeds = None + readout_indices = None + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if not use_cache: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + + # Query-key matching + if i == self.key_layer: # key-layer is Query-Layer + raw_query_embeds = self._encode_key(hidden_states, attention_mask) # shape:[batch_size, d_key] + raw_query_embeds = raw_query_embeds.view(hidden_states.shape[0], -1, hidden_states.shape[-1]) + + normed_query_embeds = self.kv_output_layer(raw_query_embeds) # Query is normed !!! + query_embeds = reduce_query_or_key_embeds(normed_query_embeds, key_reduce_method) + + # serial mode + value_embeds, key_embeds_of_value, readout_indices = self.query_memory( + value_memory, not_reduced_key_memory, key_faiss_index, query_embeds, readout_top_k + ) + value_embeds = value_embeds.to(query_embeds.device) + key_embeds_of_value = key_embeds_of_value.to(query_embeds.device) + value_embeds = value_embeds.to(query_embeds.dtype) + key_embeds_of_value = key_embeds_of_value.to(query_embeds.dtype) + + if value_fusion_method == "cat_k_delay+v": + # Serial mode, emat key directly. + batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape + if key_nums != self.prefix_length: + assert key_nums == 1 + key_embeds_of_value = key_embeds_of_value.repeat(1, 1, self.prefix_length, 1) + hidden_states = torch.cat( + [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size), + hidden_states], dim=1 + ) + extend_length = num_values * self.prefix_length + extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype) + attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape[:2], inputs_embeds.device + ) + position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule + + if self.cat_layer is not None and i == self.cat_layer: + if value_fusion_method == "async_cat_k_delay+v": + # Async mode, emat key in cat_layer. 
the implementation of delay + v is same to serial mode. + batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape + hidden_states = torch.cat( + [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size), + hidden_states], dim=1 + ) + extend_length = num_values * self.prefix_length + extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype) + attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape[:2], inputs_embeds.device + ) + position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule + + if i == self.value_layer: + batch_size, num_values, _, hidden_size = key_embeds_of_value.shape + + # assert query_embeds is not None, "Use query_embeds to read memory before assignment." + if "delay" in value_fusion_method: + updated_key = hidden_states[:, :num_values * self.prefix_length] + updated_key = updated_key.view(batch_size, num_values, self.prefix_length, hidden_size) + else: + updated_key = None + integrated_value = self.get_integrated_values(value_embeds, key_embeds_of_value, value_fusion_method, + query_embeds=query_embeds, updated_key=updated_key, + key_reduce_method=key_reduce_method) + if value_fusion_method == "infill": + assert self.num_values == 1 + hidden_states = torch.cat([integrated_value, hidden_states[:, self.prefix_length:]], dim=1) + elif "cat" in value_fusion_method and "delay" in value_fusion_method: + hidden_states[:, :num_values * self.prefix_length] = integrated_value + hidden_states = hidden_states.contiguous() + elif "cat" in value_fusion_method and "delay" not in value_fusion_method: + hidden_states = torch.cat([integrated_value, hidden_states], dim=1) + extend_length = integrated_value.shape[1] + extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype) + attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape[:2], inputs_embeds.device + ) + position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule + else: + raise NotImplementedError(f"{value_fusion_method} is not defined.") + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + query_embeds, + readout_indices + ] + if v is not None + ) + return CATEncoderOutput( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + query_embeds=query_embeds, + readout_indices=readout_indices, + updated_attention_mask=None, + ) + + async def forward_with_async_faiss( + self, + input_ids, + attention_mask, + return_dict, + readout_top_k, + 
key_reduce_method, + value_fusion_method, + key_faiss_index, + value_memory, + not_reduced_key_memory, + encoder_hidden_states=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + ): + assert value_memory is not None + assert key_faiss_index is not None + assert not_reduced_key_memory is not None + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Sanity check + assert not use_cache, "This class does not support use_cache because it is encoder only" + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + # Concatenate the prefix embeddings and extend the attention masks + prefix_embeds = self.prefix_embedding[None, :, :].expand(batch_size, -1, -1).to(inputs_embeds.device) + inputs_embeds = torch.cat([prefix_embeds, inputs_embeds], dim=1) + + # Extend the attention masks + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length + self.prefix_length).to(inputs_embeds.device) + else: + prefix_mask = torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype).to( + inputs_embeds.device) + attention_mask = torch.cat([prefix_mask, attention_mask], dim=1) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask = self.get_extended_attention_mask( + attention_mask, (batch_size, seq_length + self.prefix_length), inputs_embeds.device + ) + + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + query_embeds = None + readout_indices = None + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if not use_cache: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + + # Query-key matching + if i == self.key_layer: # key-layer is Query-Layer + raw_query_embeds = self._encode_key(hidden_states, attention_mask) # shape:[batch_size, d_key] + raw_query_embeds = raw_query_embeds.view(hidden_states.shape[0], -1, hidden_states.shape[-1]) + + normed_query_embeds = self.kv_output_layer(raw_query_embeds) # Query is normed !!! + query_embeds = reduce_query_or_key_embeds(normed_query_embeds, key_reduce_method) + + # async mode + + # loop = asyncio.get_event_loop() + # executor = futures.ThreadPoolExecutor() # futures.ProcessPoolExecutor() + + async_query_future = asyncio.get_event_loop().run_in_executor( + futures.ThreadPoolExecutor(), self.query_memory, + value_memory, not_reduced_key_memory, key_faiss_index, query_embeds, readout_top_k + ) + + if i == self.cat_layer: + value_embeds, key_embeds_of_value, readout_indices = await async_query_future + value_embeds = value_embeds.to(query_embeds.device) + key_embeds_of_value = key_embeds_of_value.to(query_embeds.device) + value_embeds = value_embeds.to(query_embeds.dtype) + key_embeds_of_value = key_embeds_of_value.to(query_embeds.dtype) + + if value_fusion_method == "async_cat_k_delay+v": + # Async mode, emat key in cat_layer. the implementation of delay + v is same to serial mode. 
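+                        # Prepend the retrieved (not-reduced) key embeddings to the hidden states so that the
+                        # remaining layers can attend over them; the corresponding values are fused once
+                        # `value_layer` is reached.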
+ batch_size, num_values, key_nums, hidden_size = key_embeds_of_value.shape + hidden_states = torch.cat( + [key_embeds_of_value.view(batch_size, num_values * self.prefix_length, hidden_size), + hidden_states], dim=1 + ) + extend_length = num_values * self.prefix_length + extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype) + attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape[:2], inputs_embeds.device + ) + position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule + + if i == self.value_layer: + batch_size, num_values, _, hidden_size = key_embeds_of_value.shape + + # assert query_embeds is not None, "Use query_embeds to read memory before assignment." + if "delay" in value_fusion_method: + updated_key = hidden_states[:, :num_values * self.prefix_length] + updated_key = updated_key.view(batch_size, num_values, self.prefix_length, hidden_size) + else: + updated_key = None + integrated_value = self.get_integrated_values(value_embeds, key_embeds_of_value, value_fusion_method, + query_embeds=query_embeds, updated_key=updated_key, + key_reduce_method=key_reduce_method) + if value_fusion_method == "infill": + assert self.num_values == 1 + hidden_states = torch.cat([integrated_value, hidden_states[:, self.prefix_length:]], dim=1) + elif "cat" in value_fusion_method and "delay" in value_fusion_method: + hidden_states[:, :num_values * self.prefix_length] = integrated_value + hidden_states = hidden_states.contiguous() + elif "cat" in value_fusion_method and "delay" not in value_fusion_method: + hidden_states = torch.cat([integrated_value, hidden_states], dim=1) + extend_length = integrated_value.shape[1] + extend_mask = torch.ones((batch_size, extend_length), dtype=attention_mask.dtype) + attention_mask = torch.cat([extend_mask.to(inputs_embeds.device), attention_mask], dim=1) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, attention_mask.shape[:2], inputs_embeds.device + ) + position_bias = None # clean the position_bias, compute in T5-SelfAttentionModule + else: + raise NotImplementedError(f"{value_fusion_method} is not defined.") + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + query_embeds, + readout_indices + ] + if v is not None + ) + return CATEncoderOutput( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + query_embeds=query_embeds, + readout_indices=readout_indices, + updated_attention_mask=None, + ) + + def query_memory(self, value_memory, not_reduced_key_memory, key_faiss_index, + query_embeds, readout_top_k): + assert value_memory.shape[1] == self.prefix_length + 
assert value_memory.shape[2] == self.model_dim
+        # Ensure the query embeddings are a torch.Tensor so `.cpu()` and `.shape` below work.
+        if not isinstance(query_embeds, torch.Tensor):
+            query_embeds = torch.as_tensor(query_embeds)
+        top_indices, _ = mips(key_faiss_index, query_embeds.cpu(), readout_top_k, n_queries_to_parallelize=20480)
+        memory_size, hidden_num, hidden_size = value_memory.shape
+        bs = query_embeds.shape[0]
+        top_indices = torch.tensor(top_indices)
+        readout_value = torch.index_select(value_memory, 0, top_indices.view(-1))
+        readout_value = readout_value.view(bs, readout_top_k, hidden_num, hidden_size)
+        readout_key_embeds_of_value = torch.index_select(not_reduced_key_memory, 0, top_indices.view(-1))
+        readout_key_embeds_of_value = readout_key_embeds_of_value.view(bs, readout_top_k, hidden_num, hidden_size)
+        return readout_value, readout_key_embeds_of_value, top_indices
+
+
+@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
+class T5WithKeyValueMemory(T5ForConditionalGeneration):
+    _keys_to_ignore_on_load_missing = [
+        r"encoder\.embed_tokens\.weight",
+        r"decoder\.embed_tokens\.weight",
+        r"lm_head\.weight",
+    ]
+
+    # _keys_to_ignore_on_load_unexpected = [
+    #     r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+    # ]
+
+    def __init__(self, config):
+        super(T5ForConditionalGeneration, self).__init__(config)
+        self.model_dim = config.d_model
+        # self.generate
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        self.encoder = T5StackWithKeyValueMemory(encoder_config, self.shared)
+        if config.not_share_encoder:
+            self.kv_encoder = T5StackWithKeyValueMemory(copy.deepcopy(encoder_config), self.shared)
+        else:
+            self.kv_encoder = None
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = T5Stack(decoder_config, self.shared)
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+        self.adapter = getattr(config, "adapter", None)
+        if self.adapter is not None:
+            self.adapter = RetAdapter(config.d_model, config.adapter_out_dim, adapter_type=self.adapter)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.config = config
+
+    def get_tunable_key_value_parameters(self):
+        add_pos_layer = min(self.config.key_layer, self.config.value_layer)
+        tunable_parameters_list = list(self.encoder.block[add_pos_layer].layer[0].
+ SelfAttention.relative_attention_bias.parameters()) + if not self.config.not_share_encoder: + if self.encoder.key_encoder is not None: + tunable_parameters_list += list(self.encoder.key_encoder.parameters()) + tunable_parameters_list += [self.encoder.prefix_embedding] + tunable_parameters_list += list(self.encoder.key_layer_norm.parameters()) + elif self.config.not_share_encoder: + tunable_parameters_list += list(self.kv_encoder.parameters()) + return tunable_parameters_list + + def freeze_t5_params(self): + tunable_key_value_parameters = self.get_tunable_key_value_parameters() + requires_grad_nums = 0 + for param in self.parameters(): + if any(param is tp for tp in tunable_key_value_parameters): + param.requires_grad = True + requires_grad_nums += 1 + else: + param.requires_grad = False + assert requires_grad_nums == len(tunable_key_value_parameters) + logger.info(f"tunable params num: {len(tunable_key_value_parameters)}") + + def freeze_kv_encoder_params(self): + assert self.encoder.key_encoder is not None + kv_encoder_params = list(self.encoder.key_encoder.parameters()) + for param in self.parameters(): + if any(param is tp for tp in kv_encoder_params): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CATSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + key_value_input_ids=None, + key_value_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + decoder_only_attend_on_prefix=False, + encoder_outputs_are_key_or_value=False, + key_embeds=None, + value_embeds=None, + key_memory=None, + value_memory=None, + key_faiss_index=None, + key_reduce_method=None, + value_fusion_method=None, + key_embeds_of_value=None, + use_ae_lm_head=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ..., + config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for + labels in ``[0, ..., config.vocab_size]`` + + Returns: + + Examples:: + + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') + + >>> # training + >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. 
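+
+            >>> # memory-augmented usage (a sketch, not a full example): pass pre-computed value embeddings
+            >>> # and the key embeddings they belong to, e.g. obtained from `wrapped_embed_kv`:
+            >>> # outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels,
+            >>> #                 value_embeds=value_embeds, key_embeds_of_value=key_embeds_of_value,
+            >>> #                 key_reduce_method="avg", value_fusion_method=model.config.value_fusion_method)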
+ """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + key_embeds=key_embeds, + value_embeds=value_embeds, + key_memory=key_memory, + value_memory=value_memory, + key_faiss_index=key_faiss_index, + key_reduce_method=key_reduce_method, + value_fusion_method=value_fusion_method, + key_embeds_of_value=key_embeds_of_value + ) + elif return_dict: + assert isinstance(encoder_outputs, CATEncoderOutput) + + hidden_states = encoder_outputs.last_hidden_state + + # Extend attention_mask on the left for attention for the prefix + batch_size, seq_length = hidden_states.shape[:2] + # seq_length: prefix + original hidden length + # attn_length: original hidden length + if encoder_outputs_are_key_or_value: + encoder_attention_mask = None + else: + attn_length = attention_mask.shape[1] + assert seq_length > attn_length, f"{seq_length} is not larger than {attn_length}" + if decoder_only_attend_on_prefix: + hidden_states = hidden_states[:, :seq_length - attn_length] + encoder_attention_mask = torch.ones(hidden_states.shape[:2]).to(attention_mask.device) + else: + prefix_mask = torch.ones(attention_mask.shape[0], seq_length - attn_length).to(attention_mask.device) + encoder_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1) + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = 
self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim ** -0.5) + + if use_ae_lm_head: + lm_logits = self.ae_lm_head(sequence_output) + else: + lm_logits = self.lm_head(sequence_output) + + # loss_dict = {} + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + assert isinstance(encoder_outputs, CATEncoderOutput) + return CATSeq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + cat_encoder_outputs=encoder_outputs + ) + + def _prepare_encoder_decoder_kwargs_for_generation( + self, input_ids: torch.LongTensor, model_kwargs + ) -> Dict[str, Any]: + if "encoder_outputs" not in model_kwargs: + # retrieve encoder hidden states + encoder = self.get_encoder() + encoder_kwargs = { + argument: value + for argument, value in model_kwargs.items() + if not (argument.startswith("decoder_") or argument.startswith("cross_attn")) + } + + if model_kwargs.get("key_value_input_ids", None) is not None: + key_embeds, value_embeds = encoder.embed_kv( + input_ids=model_kwargs["key_value_input_ids"], + attention_mask=model_kwargs.get("key_value_attention_mask", None), + head_mask=model_kwargs.get("head_mask", None), + ) + encoder_kwargs["key_embeds"] = key_embeds + encoder_kwargs["value_embeds"] = value_embeds + encoder_kwargs.pop("key_value_input_ids") + if "key_value_attention_mask" in encoder_kwargs: + encoder_kwargs.pop("key_value_attention_mask") + + model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs) + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=False, + value_fusion_method=None, + # use_ae_decoder=False, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "encoder_outputs_are_key_or_value": encoder_outputs_are_key_or_value, + "attention_mask": attention_mask, + "decoder_only_attend_on_prefix": decoder_only_attend_on_prefix, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "value_fusion_method": value_fusion_method, + # "use_ae_decoder": use_ae_decoder + } + + def CAT_embed_kv(self, *args, **kwargs) -> Dict: + if self.kv_encoder is None: + # share kv-encoder with query-encoder + 
return self.encoder.embed_kv(*args, **kwargs) + else: + # Do not share encoder + return self.kv_encoder.embed_kv(*args, **kwargs) + + def CAT_embed_q(self, *args, **kwargs) -> Dict: + return self.encoder.embed_kv(*args, **kwargs) # self.encoder is always query-encoder + + def compute_key_value_ae_loss( + self, + separate_task=True, + key_value_input_ids=None, + key_value_attention_mask=None, + key_input_ids=None, + key_attention_mask=None, + value_input_ids=None, + value_attention_mask=None, + key_labels_input_ids=None, + value_labels_input_ids=None, + train_key=False, + train_value=False, + use_ae_lm_head=False, + **kwargs + ): + loss_dict = dict() + embed_dict = self.wrapped_embed_kv( + separate_task=separate_task, + key_value_input_ids=key_value_input_ids, + key_value_attention_mask=key_value_attention_mask, + key_input_ids=key_input_ids, + key_attention_mask=key_attention_mask, + value_input_ids=value_input_ids, + value_attention_mask=value_attention_mask, + compute_key=train_key, + compute_value=train_value, + embed_for_ae_task=True + ) + key_embeds = embed_dict["normed_key_embeds"] if train_key else None + if "async_cat_k+v" == self.config.value_fusion_method: + value_embeds = self.encoder.kv_output_layer(embed_dict["value_embeds"] + embed_dict["key_embeds"]) \ + if train_value else None # normed value for generation + else: + value_embeds = embed_dict["normed_value_embeds"] if train_value else None # normed value for generation + # key_embeds = key_embeds.view(key_embeds.shape[0], -1, self.model_dim) if key_embeds is not None else None + # value_embeds = value_embeds.view(value_embeds.shape[0], -1, self.model_dim) + # key_embeds/value_embeds [batch_size, prefix_length, model_dim] + # the length of key_embeds/value_embeds is prefix_length, do not need attention_mask + if train_key: + key_ae_outputs = self.forward( + # batch, num, hidden_size; 1024 -> 2, 512 + encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None), + encoder_outputs_are_key_or_value=True, + labels=key_labels_input_ids, + use_ae_lm_head=use_ae_lm_head, + ) + loss_dict["key_ae_loss"] = key_ae_outputs["loss"] + if train_value: + value_ae_outputs = self.forward( + encoder_outputs=CATEncoderOutput(last_hidden_state=value_embeds, hidden_states=None, attentions=None), + encoder_outputs_are_key_or_value=True, + labels=value_labels_input_ids, + use_ae_lm_head=use_ae_lm_head, + ) + loss_dict["value_ae_loss"] = value_ae_outputs["loss"] + return loss_dict + + def compute_text_pair_key_value_ae_loss( + self, + key_input_ids=None, + key_attention_mask=None, + value_input_ids=None, + value_attention_mask=None, + key_labels_input_ids=None, + value_labels_input_ids=None, + separate_decode=False, + hypothesis_decoder_input_ids=None, + hypothesis_decoder_labels=None, + premise_decoder_input_ids=None, + premise_decoder_labels=None, + train_key=False, + train_value=False, + **kwargs + ): + # Auto-encoding pretraining for text-pair task (e.g. NLI) + # separate_decode argument choices whether the model generates text-pair respectively. 
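+        # When `separate_decode` is True, the hypothesis and the premise are each decoded from the key
+        # embedding with their own decoder inputs; otherwise the whole text pair is reconstructed in one pass.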
+ loss_dict = dict() + embed_dict = self.wrapped_embed_kv( + separate_task=True, + key_value_input_ids=None, + key_value_attention_mask=None, + key_input_ids=key_input_ids, + key_attention_mask=key_attention_mask, + value_input_ids=value_input_ids, + value_attention_mask=value_attention_mask, + compute_key=train_key, + compute_value=train_value, + embed_for_ae_task=True + ) + key_embeds = embed_dict["normed_key_embeds"] if train_key else None + value_embeds = embed_dict["normed_value_embeds"] if train_value else None + if train_key: + if separate_decode: + # hypothesis auto-encoding + hypothesis_ae_outputs = self.forward( + encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None), + encoder_outputs_are_key_or_value=True, + decoder_input_ids=hypothesis_decoder_input_ids, + labels=hypothesis_decoder_labels, + use_ae_lm_head=False, + ) + loss_dict["hypothesis_ae_loss"] = hypothesis_ae_outputs["loss"] + # premise auto-encoding + premise_ae_outputs = self.forward( + encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None), + encoder_outputs_are_key_or_value=True, + decoder_input_ids=premise_decoder_input_ids, + labels=premise_decoder_labels, + use_ae_lm_head=False, + ) + loss_dict["premise_ae_loss"] = premise_ae_outputs["loss"] + else: + key_ae_outputs = self.forward( + # batch, num, hidden_size; 1024 -> 2, 512 + encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None), + encoder_outputs_are_key_or_value=True, + labels=key_labels_input_ids, + use_ae_lm_head=False, + ) + loss_dict["key_ae_loss"] = key_ae_outputs["loss"] + if train_value: + value_ae_outputs = self.forward( + encoder_outputs=CATEncoderOutput(last_hidden_state=value_embeds, hidden_states=None, attentions=None), + encoder_outputs_are_key_or_value=True, + labels=value_labels_input_ids, + use_ae_lm_head=False, + ) + loss_dict["value_ae_loss"] = value_ae_outputs["loss"] + return loss_dict + + def compute_qa_loss( + self, + input_ids=None, + attention_mask=None, + labels=None, + decoder_only_attend_on_prefix=False, + encoder_outputs_are_key_or_value=False, + key_memory=None, + value_memory=None, + key_faiss_index=None, + key_reduce_method=None, + positive_key_embeds=None, + negative_key_embeds=None, + value_embeds=None, + matching_targets=None, + value_fusion_method=None, + key_embeds_of_value=None, + negative_mask=None, + only_train_adapter=False + ): + loss_dict = dict() + if only_train_adapter: + embed_dict = self.CAT_embed_q( + input_ids=input_ids, + attention_mask=attention_mask, + compute_key=True, compute_value=False + ) + query_embeds = embed_dict["normed_key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, key_reduce_method) + gen_loss = torch.tensor(0.0) + else: + outputs = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + decoder_only_attend_on_prefix=decoder_only_attend_on_prefix, + encoder_outputs_are_key_or_value=encoder_outputs_are_key_or_value, + key_embeds=None, + value_embeds=value_embeds, + key_memory=key_memory, + value_memory=value_memory, + key_faiss_index=key_faiss_index, + key_reduce_method=key_reduce_method, + value_fusion_method=value_fusion_method, + key_embeds_of_value=key_embeds_of_value + ) + query_embeds = outputs.cat_encoder_outputs.query_embeds + gen_loss = outputs.loss + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss_dict["gen_loss"] = gen_loss + if positive_key_embeds is not None: + if self.adapter is not None: + 
query_embeds = self.adapter(query_embeds) + positive_key_embeds = self.adapter(positive_key_embeds) + negative_key_embeds = self.adapter(negative_key_embeds) + scores1 = torch.mm(query_embeds, positive_key_embeds.t()) + scores2 = torch.mm(query_embeds, negative_key_embeds.t()) + if negative_mask is not None: + negative_mask = ~negative_mask.bool().to(negative_key_embeds.device) + scores2 = scores2.masked_fill(negative_mask, float('-inf')) + scores = torch.cat((scores1, scores2), dim=1) + match_loss = loss_fct(scores, matching_targets) + loss_dict["match_loss"] = match_loss + + return loss_dict + + def compute_gen_and_match_loss( + self, + input_ids=None, + attention_mask=None, + labels=None, + decoder_only_attend_on_prefix=False, + encoder_outputs_are_key_or_value=False, + key_reduce_method=None, + value_embeds=None, + value_fusion_method=None, + key_embeds_of_value=None, + positive_and_negative_embeds=None, + matching_mask=None, + matching_targets=None, + use_triple_loss=None, + ): + """ + value_embeds: retrieved key's value embeds, shape: [batch_size, num_values, prefix_length, hidden_size] + key_embeds_of_value: retrieved key embeds, shape: [batch_size, num_values, key_dim // model_dim, hidden_size] + """ + loss_dict = dict() + outputs = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + decoder_only_attend_on_prefix=decoder_only_attend_on_prefix, + encoder_outputs_are_key_or_value=encoder_outputs_are_key_or_value, + key_embeds=None, + value_embeds=value_embeds, + key_memory=None, + value_memory=None, + key_faiss_index=None, + key_reduce_method=key_reduce_method, + value_fusion_method=value_fusion_method, + key_embeds_of_value=key_embeds_of_value + ) + + gen_loss = outputs.loss + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss_dict["gen_loss"] = gen_loss + if not use_triple_loss and positive_and_negative_embeds is not None: + query_embeds = outputs.cat_encoder_outputs.query_embeds + # query_embeds: batch_size, hidden + # positive_and_negative_embeds: N, hidden + scores = torch.mm(query_embeds, positive_and_negative_embeds.transpose(1, 0)) + matching_mask = ~matching_mask.bool().to(positive_and_negative_embeds.device) + scores = scores.masked_fill(matching_mask, float('-inf')) + match_loss = loss_fct(scores, matching_targets) + loss_dict["match_loss"] = match_loss + # elif use_triple_loss and positive_key_embeds is not None: + # triple_loss_fct = nn.TripletMarginLoss(margin=0.5, p=2) + # query_embeds = outputs.cat_encoder_outputs.query_embeds + # batch_size, hidden_size = query_embeds.shape + # # negative_nums = negative_key_embeds.shape[0] // batch_size + # negative_key_embeds = negative_key_embeds.view(batch_size, -1, hidden_size) + # group_negative_key_embeds = negative_key_embeds.transpose(0, 1) + # triple_losses = [] + # for cur_negative_key_embeds in group_negative_key_embeds: + # triple_losses.append( + # triple_loss_fct(query_embeds, positive_key_embeds, cur_negative_key_embeds) + # ) + # triple_loss = sum(triple_losses) / len(triple_losses) + # loss_dict["triple_loss"] = torch.nan_to_num(triple_loss) + + return loss_dict + + def wrapped_embed_kv( + self, + separate_task=False, + key_value_input_ids=None, + key_value_attention_mask=None, + key_input_ids=None, + key_attention_mask=None, + value_input_ids=None, + value_attention_mask=None, + compute_key=False, + compute_value=False, + embed_for_ae_task=False, + ) -> Dict: + device = self.device + # key_embeds, value_embeds = None, None + if separate_task: + res = dict() + if compute_key: 
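+                # Embed the key inputs only here; the branch below mirrors this for the value inputs,
+                # and the two result dicts are merged into `res`.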
+ key_res = self.CAT_embed_kv( + input_ids=key_input_ids.to(device), attention_mask=key_attention_mask.to(device), + compute_key=compute_key, compute_value=False, embed_for_ae_task=embed_for_ae_task + ) + res.update({k: v for k, v in key_res.items() if "key" in k}) + if compute_value: + value_res = self.CAT_embed_kv( + input_ids=value_input_ids.to(device), attention_mask=value_attention_mask.to(device), + compute_key=False, compute_value=compute_value, embed_for_ae_task=embed_for_ae_task, + + ) + res.update({k: v for k, v in value_res.items() if "value" in k}) + else: + res = self.CAT_embed_kv( + input_ids=key_value_input_ids.to(device), + attention_mask=key_value_attention_mask.to(device), + compute_key=compute_key, compute_value=compute_value, + embed_for_ae_task=embed_for_ae_task + ) + return res diff --git a/emat/utils.py b/emat/utils.py new file mode 100644 index 0000000..858b24a --- /dev/null +++ b/emat/utils.py @@ -0,0 +1,68 @@ +import json +import logging + +import torch + +logger = logging.getLogger(__name__) + +try: + import apex + from apex import amp + + apex.amp.register_half_function(torch, "einsum") + _has_apex = True +except ImportError: + _has_apex = False + + +def is_apex_available(): + return _has_apex + + +def to_fp16(model): + if is_apex_available(): + model = amp.initialize(model, opt_level="O1") + else: + model = model.half() + return model + + +def load_jsonl(fn): + all_data = [] + with open(fn, "r") as f: + for line in f.readlines(): + all_data.append(json.loads(line)) + return all_data + + +def write_jsonl(all_data, fn): + with open(fn, "w") as f: + for data in all_data: + f.write(json.dumps(data) + "\n") + + +def convert_repaq_results_from_file(in_file, num_candidates, out_file=None): + data = load_jsonl(in_file) + processed_data = [] + + for sample in data: + sample_dict = { + "question": sample["input_qa"]["question"], + "answer": sample["input_qa"]["answer"], + } + for i, qas in enumerate(sample["retrieved_qas"][:num_candidates]): + q, a = qas["question"], qas["answer"][0] + sample_dict[f"ret_q_{i}"] = q + sample_dict[f"ret_a_{i}"] = a + + processed_data.append(sample_dict) + + if out_file is not None: + write_jsonl(processed_data, out_file) + + return processed_data + + +def verbalise_qa(q, a) -> str: + q = q.strip("?") + return f'{q}? 
answer: {a}' diff --git a/embed_and_build_index.py b/embed_and_build_index.py new file mode 100644 index 0000000..1b379ea --- /dev/null +++ b/embed_and_build_index.py @@ -0,0 +1,187 @@ +import os +import pickle +import argparse +import torch +from transformers import T5Tokenizer +import copy +from emat.t5 import T5WithKeyValueMemory +from emat.utils import load_jsonl +import logging +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from utils.utils import get_key_value_encoder_inputs, reduce_query_or_key_embeds + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + +QA_KB_PATHS = { + "PAQ_L1": "./tmp/PAQ_L1_pickl_file.pkl", + "PAQ": "./tmp/PAQ_full.pkl", + "TAQ_TRAIN_NQ_TRAIN_PAQ": "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl", + "debug": "./tmp/PAQ_L1_small.pkl" +} + + +def get_args(): + parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Embed and build FAISS") + parser.add_argument("--model_name_or_path", type=str, required=False, + default="./outputs/nq_checkpoints/KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;pos_from_top=128;/best_ckpt/") + parser.add_argument("--qas_to_retrieve_from", choices=list(QA_KB_PATHS.keys()), default=f"debug") + parser.add_argument("--add_nq_train", action="store_true") + parser.add_argument("--add_nq_dev", action="store_true") + parser.add_argument("--embed_batch_size", type=int, default=512) + parser.add_argument("--save_dir", default=f"./data/embedding_and_faiss/debug_from_nq_ckpt") + + args = parser.parse_args() + return args + + +def load_qas_to_embed(qas_to_retrieve_from, add_nq_train, add_nq_dev): + logging.info("loading qas to retrieve") + qas_to_retrieve_fp = QA_KB_PATHS[qas_to_retrieve_from] + logging.info(f"loading qas from {qas_to_retrieve_fp}") + if qas_to_retrieve_fp.endswith("pkl"): + qas_to_embed = pickle.load(open(qas_to_retrieve_fp, 'rb')) + elif qas_to_retrieve_fp.endswith("jsonl"): + qas_to_embed = load_jsonl(qas_to_retrieve_fp) + else: + raise ValueError(f"{qas_to_retrieve_fp}") + logging.info(f"load {len(qas_to_embed)} qas from PAQ.") + + # if qas_to_retrieve_from == "debug": + # qas_to_retrieve = qas_to_retrieve[:10000] + + if add_nq_train: + logging.info("add nq-train qas.") + qas_to_embed = qas_to_embed + load_jsonl("./data/annotated_datasets/NQ-open.train-train.jsonl") + if add_nq_dev: + logging.info("add nq-dev qas.") + qas_to_embed = qas_to_embed + load_jsonl("./data/annotated_datasets/NQ-open.train-dev.jsonl") + + logging.info(f"load {len(qas_to_embed)} qas totally.") + + return qas_to_embed + + +@torch.no_grad() +def embed_key_value(model, tokenizer, data_to_embed, embed_batch_size, save_dir, + use_fp16_model=True, key_reduce_method="avg", max_source_length=1024, prefix="question: "): + if use_fp16_model: + model = model.half() + logging.info("") + + # model.eval() + # key_memory, value_memory, not_reduced_key_memory = build_memory( + # model, tokenizer, embed_key=True, embed_value=True, prefix=prefix, embed_as_fp16=True, + # key_reduce_method=key_reduce_method, return_memory=True, dump_memory=False, + # data_to_embed=data_to_embed, max_source_length=max_source_length, padding=True, + # batch_size=embed_batch_size, separate_task=True, return_not_reduced_key=True, + # + # reused_key_memory=reused_key_memory, + # reused_value_memory=reused_value_memory, + # reused_not_reduced_key_memory=reused_not_reduced_key_memory + # ) + # return key_memory, value_memory, not_reduced_key_memory 
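+    # The helpers below write three tensors to `save_dir`:
+    #   embedding_index.pt - normed and reduced key embeddings (used as the retrieval index)
+    #   value_memory.pt    - value embeddings (not normed)
+    #   key_memory.pt      - key embeddings that are neither normed nor reduced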
+ if not os.path.exists(save_dir): + os.makedirs(save_dir) + else: + logging.warning(f"{save_dir} is exists. re-write contents warning.") + + def collate_fn(examples): + model_inputs = get_key_value_encoder_inputs(examples, True, tokenizer, max_source_length, + prefix=prefix, only_return_key_inputs=False) + return model_inputs + + data_to_embed_dataloader = DataLoader(data_to_embed, batch_size=embed_batch_size, + num_workers=4, collate_fn=collate_fn) + import gc + + def save_embedding_index(): + reused_key_memory = torch.zeros((len(data_to_embed), model.model_dim), device="cpu", dtype=torch.float16) + key_cnt = 0 + for batch in tqdm(data_to_embed_dataloader): + batch_keys = list(batch.keys()) + batch = {k: v.to(model.device) for k, v in batch.items()} + embed_dict = model.wrapped_embed_kv(separate_task=True, **batch, + compute_key=True, compute_value=False) + for k in batch_keys: + del batch[k] + key_embeds = embed_dict.get("normed_key_embeds") + key_embeds = reduce_query_or_key_embeds(key_embeds, key_reduce_method) + cur_key_num = key_embeds.shape[0] + key_embeds = key_embeds.half().cpu() + reused_key_memory[key_cnt: key_cnt + cur_key_num] = copy.deepcopy(key_embeds) + del key_embeds + torch.save(reused_key_memory, os.path.join(save_dir, "embedding_index.pt")) + logging.info("embedding index saved.") + + def save_value_memory(): + reused_value_memory = torch.zeros((len(data_to_embed), 2, model.model_dim), device="cpu", dtype=torch.float16) + value_cnt = 0 + for batch in tqdm(data_to_embed_dataloader): + batch_keys = list(batch.keys()) + batch = {k: v.to(model.device) for k, v in batch.items()} + embed_dict = model.wrapped_embed_kv(separate_task=True, **batch, + compute_key=False, compute_value=True) + for k in batch_keys: + del batch[k] + value_embeds = embed_dict.get("value_embeds") + cur_value_num = value_embeds.shape[0] + value_embeds = value_embeds.half().cpu() + reused_value_memory[value_cnt: value_cnt + cur_value_num] = copy.deepcopy(value_embeds) + del value_embeds + torch.save(reused_value_memory, os.path.join(save_dir, "value_memory.pt")) + logging.info("value memory saved.") + + def save_key_memory(): + reused_not_reduced_key_memory = torch.zeros((len(data_to_embed), 2, model.model_dim), + device="cpu", dtype=torch.float16) + nr_key_cnt = 0 + for batch in tqdm(data_to_embed_dataloader): + batch_keys = list(batch.keys()) + batch = {k: v.to(model.device) for k, v in batch.items()} + embed_dict = model.wrapped_embed_kv(separate_task=True, **batch, + compute_key=True, compute_value=False) + for k in batch_keys: + del batch[k] + not_normed_key_embeds = embed_dict["key_embeds"] + cur_key_num = not_normed_key_embeds.shape[0] + not_normed_key_embeds = not_normed_key_embeds.half().cpu() + reused_not_reduced_key_memory[nr_key_cnt: nr_key_cnt + cur_key_num] = copy.deepcopy(not_normed_key_embeds) + del not_normed_key_embeds + torch.save(reused_not_reduced_key_memory, os.path.join(save_dir, "key_memory.pt")) + logging.info("key memory saved.") + + save_embedding_index() + gc.collect() + save_value_memory() + gc.collect() + save_key_memory() + gc.collect() + + +def main(): + args = get_args() + logging.info("loading model") + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path) + model, load_info = T5WithKeyValueMemory.from_pretrained(args.model_name_or_path, output_loading_info=True) + model = model.cuda() + model.eval() + logging.info(f"model load info: {load_info}") + + logging.info("loading data") + data_to_embed = load_qas_to_embed(args.qas_to_retrieve_from, 
args.add_nq_train, args.add_nq_dev) + + logging.info("embedding") + embed_key_value(model, tokenizer, data_to_embed, args.embed_batch_size, args.save_dir) + # key_memory is normed and reduced + # value_memory is not normed + # not_reduced_key_memory is not normed and not reduced + logging.info("embedding saved.") + + +if __name__ == '__main__': + main() diff --git a/embed_scripts/nq_embed_paq_and_build_faiss.sh b/embed_scripts/nq_embed_paq_and_build_faiss.sh new file mode 100644 index 0000000..3745aa5 --- /dev/null +++ b/embed_scripts/nq_embed_paq_and_build_faiss.sh @@ -0,0 +1,32 @@ +#!/bin/bash -l + +set -e +set -u + +DST_DIR="case-augmented-transformer-master" +cd ${DST_DIR} + +DEVICE="0" + +PAQ_TYPE="PAQ_L1" +SAVE_DIR="./data/embedding_and_faiss/${PAQ_TYPE}_from_nq_ckpt" +NQ_MODEL_PATH="" # -- + +CUDA_VISIBLE_DEVICES=${DEVICE} python embed_and_build_index.py \ + --model_name_or_path=${NQ_MODEL_PATH} \ + --qas_to_retrieve_from=${PAQ_TYPE} \ + --embed_batch_size=2048 \ + --save_dir=${SAVE_DIR} \ + --add_nq_train \ + --add_nq_dev + +CUDA_VISIBLE_DEVICES=${DEVICE} python emat/retriever/build_index.py \ + --embeddings_dir="./data/embedding_and_faiss/PAQ_from_nq_ckpt/embedding_index.pt" \ + --output_path="${SAVE_DIR}/key.sq8hnsw.80n80efc.faiss" \ + --hnsw \ + --store_n 80 \ + --ef_construction 80 \ + --ef_search 32 \ + --SQ8 \ + --indexing_batch_size=1 \ + --verbose diff --git a/inference_with_faiss.py b/inference_with_faiss.py new file mode 100644 index 0000000..529a80e --- /dev/null +++ b/inference_with_faiss.py @@ -0,0 +1,320 @@ +import faiss +import os +from utils.utils import update_CAT_config_from_args +import asyncio +import argparse +import torch +from transformers import T5Tokenizer, T5Config +from emat.t5 import T5WithKeyValueMemory +from emat.utils import load_jsonl +import logging +from embed_and_build_index import load_qas_to_embed +import time +from kilt_dataset import DialogDataset +from torch.nn.utils.rnn import pad_sequence + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + +QA_KB_PATHS = { + "PAQ_L1": "./tmp/PAQ_L1_pickl_file.pkl", + "PAQ": "./tmp/PAQ_full.pkl", + "TAQ_TRAIN_NQ_TRAIN_PAQ": "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl", + "debug": "./tmp/PAQ_L1_small.pkl" +} + + +def get_args(): + parser: argparse.ArgumentParser = argparse.ArgumentParser(description="Inference with faiss") + parser.add_argument("--model_name_or_path", type=str, required=False, + default="./outputs/nq_checkpoints/KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;pos_from_top=128;/best_ckpt/") + parser.add_argument("--f", choices=list(QA_KB_PATHS.keys()), default=f"debug") + parser.add_argument("--add_nq_train", action="store_true") + parser.add_argument("--add_nq_dev", action="store_true") + parser.add_argument("--inference_batch_size", type=int, default=512) + parser.add_argument("--load_dir", default=f"./data/embedding_and_faiss/debug_from_nq_ckpt") + parser.add_argument("--inference_type", type=str, default="async", choices=["async", "serial", "t5"]) + parser.add_argument("--cat_layer", default=7, type=int) + parser.add_argument("--test_task", default="wow", type=str, choices=["qa", "wow", "eli5"]) + parser.add_argument("--model_size", default="base", type=str, choices=["base", "large", "3B"]) + parser.add_argument("--faiss_path", default="", type=str, required=False) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + logging.info("loading faiss 
index.") + if args.faiss_path == "": + faiss_path = os.path.join(args.load_dir, "key.sq8.hnsw.faiss") + else: + faiss_path = args.faiss_path + key_faiss_index = faiss.read_index(faiss_path) + logging.info("loaded faiss index.") + + logging.info("loading memory.") + value_memory = torch.load(os.path.join(args.load_dir, "value_memory.pt")) + key_memory = torch.load(os.path.join(args.load_dir, "key_memory.pt")) + logging.info("loaded memory.") + + logging.info("loading data") + qas_to_retrieve = load_qas_to_embed(args.qas_to_retrieve_from, args.add_nq_train, args.add_nq_dev) + logging.info("loaded data") + assert len(qas_to_retrieve) == value_memory.shape[0] == key_memory.shape[0] + + logging.info("loading model") + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path) + if args.model_size == "3B": + config = T5Config.from_pretrained(args.model_name_or_path) + args.value_fusion_method = "cat_key_delay+v" + args.num_values = 10 + args.prefix_length = 2 + args.key_encoder_type = "prefix" + args.key_layer = 3 + args.value_layer = 7 + args.d_key = config.d_model * args.prefix_length + args.use_two_prefix = False + args.not_share_encoder = False + update_CAT_config_from_args(config, args) + model, load_info = T5WithKeyValueMemory.from_pretrained(args.model_name_or_path, config=config, + output_loading_info=True) + logging.info("loaded T5-3B.") + else: + model, load_info = T5WithKeyValueMemory.from_pretrained(args.model_name_or_path, output_loading_info=True) + model.eval() + logging.info(f"model load info: {load_info}") + + if args.test_task == "qa": + # test_data = load_jsonl("./data/annotated_datasets/NQ-open.test.jsonl") + test_data = load_jsonl("./data/annotated_datasets/NQ-open.train-train.jsonl")[:512 * 40] + logging.info(f"loaded {len(test_data)} test qas.") + else: + if args.test_task == "wow": + dataset_kwargs = { + "dataset_name": "wow_kilt", + "max_source_length": 1024 + } + test_data = load_jsonl("./data/annotated_datasets/wizard_of_wikipedia/wow-dev-kilt.jsonl")[:512] + test_data = test_data * 10 + logging.info(f"loaded {len(test_data)} test history-response pairs.") + else: + dataset_kwargs = { + "dataset_name": "eli5_kilt", + "max_source_length": 384, + "max_target_length": 1536 + } + test_data = load_jsonl("./data/annotated_datasets/eli5/eli5-dev-kilt.jsonl")[:512] + test_data = test_data * 10 + logging.info(f"loaded {len(test_data)} test long-form qas.") + test_dataset = DialogDataset(test_data, tokenizer, qas_to_retrieve, max_utterances=13, **dataset_kwargs) + test_data = test_dataset.data + + torch.cuda.empty_cache() + + model = model.cuda() + if args.inference_type == "serial": + serial_inference(model, tokenizer, test_data, args.inference_batch_size, key_faiss_index, value_memory, + key_memory, qas_to_retrieve, args.test_task) + elif args.inference_type == "async": + async_inference(model, tokenizer, test_data, args.inference_batch_size, key_faiss_index, value_memory, + key_memory, qas_to_retrieve, args.cat_layer, args.test_task) + elif args.inference_type == "t5": + t5_inference(model, tokenizer, test_data, args.inference_batch_size, key_faiss_index, value_memory, + key_memory, qas_to_retrieve, args.test_task) + + +def get_query_inputs(tokenizer, batch, device, test_task): + if test_task == "qa": + query_inputs = ["question: " + qa["question"] for qa in batch] + query_inputs = tokenizer(query_inputs, max_length=1024, + padding=True, truncation=True, return_tensors="pt") + return query_inputs["input_ids"].to(device), query_inputs["attention_mask"].to(device) + else: 
+ history_input_ids = [ex["input_ids"] for ex in batch] + history_input_ids = pad_sequence(history_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + history_attention_mask = (history_input_ids != tokenizer.pad_token_id).long() + return history_input_ids.to(device), history_attention_mask.to(device) + + +@torch.no_grad() +def serial_inference(model: T5WithKeyValueMemory, tokenizer, test_data, batch_size, + key_faiss_index, value_memory, not_reduced_key_memory, qas_to_retrieve, test_task): + if test_task == "qa": + gen_kwargs = {"num_beams": None, "max_length": 64} + elif test_task == "wow": + gen_kwargs = {"num_beams": None, "max_length": 28, "min_length": 28} + else: + gen_kwargs = {"num_beams": None, "max_length": 187, "min_length": 187} + + readout_top_k = model.config.num_values + key_reduce_method = "avg" + value_fusion_method = model.config.value_fusion_method + + time_log = [] + query_log = [] + for start_idx in range(0, len(test_data), batch_size): + start_time = time.perf_counter() + + batch = test_data[start_idx: start_idx + batch_size] + input_ids, attention_mask = get_query_inputs(tokenizer, batch, model.device, test_task) + + encoder_outputs = model.encoder.forward_with_faiss( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + readout_top_k=readout_top_k, + key_reduce_method=key_reduce_method, + value_fusion_method=value_fusion_method, + key_faiss_index=key_faiss_index, + value_memory=value_memory, + not_reduced_key_memory=not_reduced_key_memory + ) + generated_tokens = model.generate( + encoder_outputs=encoder_outputs, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=False, + attention_mask=attention_mask, + value_fusion_method=value_fusion_method, + **gen_kwargs, + ) + cur_cost = time.perf_counter() - start_time + time_log.append(cur_cost) + query_log.append(len(batch)) + logging.info(f" {len(batch)} queries / {cur_cost} seconds") + + time_log = time_log[2:-1] + query_log = query_log[2:-1] + query_num = sum(query_log) + total_time = sum(time_log) + logging.info(f"average speed: {query_num} queries / {total_time} seconds = " + f"{query_num / total_time} queries per second") + + +@torch.no_grad() +def async_inference(model: T5WithKeyValueMemory, tokenizer, test_data, batch_size, + key_faiss_index, value_memory, not_reduced_key_memory, qas_to_retrieve, cat_layer, test_task): + if test_task == "qa": + gen_kwargs = {"num_beams": None, "max_length": 64} + elif test_task == "wow": + gen_kwargs = {"num_beams": None, "max_length": 28, "min_length": 28} + else: + gen_kwargs = {"num_beams": None, "max_length": 187, "min_length": 187} + + readout_top_k = model.config.num_values + key_reduce_method = "avg" + # value_fusion_method = "async_cat_k+v" + model.encoder.key_layer = 3 + model.encoder.cat_layer = cat_layer + model.encoder.value_layer = 10 + if model.encoder.cat_layer == model.encoder.value_layer: + value_fusion_method = "async_cat_k+v" + else: + value_fusion_method = "async_cat_k_delay+v" + + logging.info(f"cat_layer: {cat_layer}") + + time_log = [] + query_log = [] + for start_idx in range(0, len(test_data), batch_size): + start_time = time.perf_counter() + + batch = test_data[start_idx: start_idx + batch_size] + input_ids, attention_mask = get_query_inputs(tokenizer, batch, model.device, test_task) + + encoder_outputs = asyncio.run( + model.encoder.forward_with_async_faiss( + input_ids, attention_mask, True, readout_top_k, key_reduce_method, value_fusion_method, + key_faiss_index, value_memory, 
not_reduced_key_memory + ) + ) + generated_tokens = model.generate( + encoder_outputs=encoder_outputs, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=False, + attention_mask=attention_mask, + value_fusion_method=value_fusion_method, + **gen_kwargs, + ) + cur_cost = time.perf_counter() - start_time + time_log.append(cur_cost) + query_log.append(len(batch)) + logging.info(f" {len(batch)} queries / {cur_cost} seconds") + + time_log = time_log[2:-1] + query_log = query_log[2:-1] + query_num = sum(query_log) + total_time = sum(time_log) + logging.info(f"cat_layer: {cat_layer}") + logging.info(f"average speed: {query_num} queries / {total_time} seconds = " + f"{query_num / total_time} queries per second") + + +# ELI5 --inference_batch_size=128 +# WoW --inference_batch_size=256 + +@torch.no_grad() +def t5_inference(model: T5WithKeyValueMemory, tokenizer, test_data, batch_size, + key_faiss_index, value_memory, not_reduced_key_memory, qas_to_retrieve, test_task): + if test_task == "qa": + gen_kwargs = {"num_beams": None, "max_length": 16} + elif test_task == "wow": + gen_kwargs = {"num_beams": None, "max_length": 28, "min_length": 28} + else: + gen_kwargs = {"num_beams": None, "max_length": 187, "min_length": 187} + + readout_top_k = model.config.num_values + key_reduce_method = "avg" + value_fusion_method = model.config.value_fusion_method + model.key_layer = 1000 + model.value_layer = 1000 + model.cat_layer = 1000 + model.encoder.key_layer = 1000 + model.encoder.value_layer = 1000 + model.encoder.cat_layer = 1000 + + time_log = [] + query_log = [] + for start_idx in range(0, len(test_data), batch_size): + start_time = time.perf_counter() + + batch = test_data[start_idx: start_idx + batch_size] + input_ids, attention_mask = get_query_inputs(tokenizer, batch, model.device, test_task) + + encoder_outputs = model.encoder.forward_with_faiss( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=True, + readout_top_k=readout_top_k, + key_reduce_method=key_reduce_method, + value_fusion_method=value_fusion_method, + key_faiss_index=key_faiss_index, + value_memory=value_memory, + not_reduced_key_memory=not_reduced_key_memory + ) + generated_tokens = model.generate( + encoder_outputs=encoder_outputs, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=False, + attention_mask=attention_mask, + value_fusion_method=value_fusion_method, + **gen_kwargs, + ) + cur_cost = time.perf_counter() - start_time + time_log.append(cur_cost) + query_log.append(len(batch)) + logging.info(f" {len(batch)} queries / {cur_cost} seconds") + + time_log = time_log[2:-1] + query_log = query_log[2:-1] + query_num = sum(query_log) + total_time = sum(time_log) + logging.info(f"average speed: {query_num} queries / {total_time} seconds = " + f"{query_num / total_time} queries per second") + + +if __name__ == '__main__': + main() diff --git a/kilt_dataset.py b/kilt_dataset.py new file mode 100644 index 0000000..c92c930 --- /dev/null +++ b/kilt_dataset.py @@ -0,0 +1,438 @@ +import logging +import string +from itertools import chain +import copy +from torch.nn.utils.rnn import pad_sequence +import re +import torch +from torch.utils.data import Dataset, DataLoader +from typing import List, Dict +import random +from emat.evaluation.exact_match import normalize_answer +from utils.utils import process_labels +from transformers import T5Tokenizer +from tqdm.auto import tqdm + + +class DialogDataset(Dataset): + def __init__( + self, + data: List[Dict], + tokenizer: T5Tokenizer, + 
qas_to_retrieve, + dataset_name, + retrieve_strategy="dr", + max_source_length=None, + args=None, + normed_answer_of_qas_to_ret=None, + max_utterances=100, + add_topic=True, + add_persona=True, + max_target_length=512, + ): + super(DialogDataset, self).__init__() + + assert dataset_name in ["wow", "wow_unseen", "wow_kilt", "eli5_kilt"] + self.max_source_length = max_source_length if max_source_length is not None else 1024 + self.dataset_name = dataset_name + # print(f"dataset-name: {dataset_name}") + self.max_target_length = max_target_length + self.tokenizer = tokenizer + self.pad_idx = self.tokenizer.pad_token_id + self.label_pad_idx = -100 + self.args = args + self.qas_to_retrieve = qas_to_retrieve + self.normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret + self.max_utterances = max_utterances + self.add_topic = add_topic + self.add_persona = add_persona + + self.pad_qa = {"question": "", "answer": [""]} + if dataset_name == "wow_kilt": + self.data: List[Dict] = self.process_kilt_input(data) + elif dataset_name == "eli5_kilt": + self.data: List[Dict] = self.process_eli5_kilt_input(data) + else: + self.data: List[Dict] = self.process_to_input_and_response_pairs(data) + + if "normalized_response" in self.data[0].keys(): + stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're", "you've", + "you'll", "you'd", 'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', + 'she', "she's", 'her', 'hers', 'herself', 'it', "it's", 'its', 'itself', 'they', 'them', + 'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', "that'll", + 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', + 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', + 'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', + 'between', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'to', 'from', + 'up', 'down', 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', + 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', + 'most', 'other', 'some', 'such', 'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', + 'too', 'very', 's', 't', 'can', 'will', 'just', 'don', "don't", 'should', "should've", 'now', + 'd', 'll', 'm', 'o', 're', 've', 'y', 'ain', 'aren', "aren't", 'couldn', "couldn't", 'didn', + "didn't", 'doesn', "doesn't", 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'isn', + "isn't", 'ma', 'mightn', "mightn't", 'mustn', "mustn't", 'needn', "needn't", 'shan', "shan't", + 'shouldn', "shouldn't", 'wasn', "wasn't", 'weren', "weren't", 'won', "won't", 'wouldn', + "wouldn't"] + if "wow" in dataset_name: + for item in self.data: + item["normalized_response_remove_stop_words_list"] = [ + w for w in item["normalized_response"].split() if w not in stop_words + ] + else: + assert dataset_name == "eli5_kilt" + for item in self.data: + item["normalized_response_remove_stop_words_list"] = [ + w for w in item["normalized_response"].split() if w not in stop_words + ] + item["normalized_response_remove_stop_words_list"] = \ + item["normalized_response_remove_stop_words_list"][:512] + + @staticmethod + def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def 
remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + def process_kilt_input(self, dialog_data): + processed_data = [] + query_prefix_ids = self.tokenizer("Query:", add_special_tokens=False, return_attention_mask=False)["input_ids"] + for dialog_idx, item in enumerate(dialog_data): + + if len(item["input"]) < 5: + continue + + input_utterances = item["input"].split("\n") + + if len(input_utterances) > self.max_utterances: + continue + + utterances_ids = [] + spk = "Wizard" if len(input_utterances) % 2 == 0 else "Apprentice" + for utterance in input_utterances: + utterance = f"{spk}: {utterance}" + spk = "Wizard" if spk == "Apprentice" else "Apprentice" + ids = self.tokenizer(utterance, add_special_tokens=False, return_attention_mask=False)["input_ids"] + utterances_ids.append(ids) + + if sum(len(u) for u in utterances_ids) > self.max_source_length: + max_length_per_utterance = self.max_source_length // len(utterances_ids) + utterances_ids = [u[:max_length_per_utterance] for u in utterances_ids] + + query_ids = query_prefix_ids + list(chain(*copy.deepcopy(utterances_ids[-2:]))) + [1] + + input_ids = list(chain(*utterances_ids)) + [1] + + cur_data = { + "id": item["id"], + "input_ids": torch.tensor(input_ids), + "query_ids": torch.tensor(query_ids), + "dialog_idx": torch.tensor(dialog_idx), + } + + if "output" in item.keys(): + response = item["output"][0]["answer"] + with self.tokenizer.as_target_tokenizer(): + response_ids = self.tokenizer(response, max_length=self.max_target_length, + return_attention_mask=False)["input_ids"] + cur_data.update({ + "response_ids": torch.tensor(response_ids), + "normalized_response": self.normalize_answer(response) + }) + + processed_data.append(cur_data) + + # logging.info(f"process {len(dialog_data)} dialogs to {len(processed_data)} training examples.") + return processed_data + + def process_eli5_kilt_input(self, eli5_data): + assert self.max_target_length >= 1024 + + def white_space_fix(text): + return " ".join(text.split()) + + processed_data = [] + query_prefix_ids = self.tokenizer("Query:", add_special_tokens=False, return_attention_mask=False)["input_ids"] + + for eli5_idx, item in tqdm(enumerate(eli5_data), total=len(eli5_data)): + + question = item["input"] + question_ids = self.tokenizer(question, add_special_tokens=False, return_attention_mask=False)["input_ids"] + query_ids = (query_prefix_ids + copy.deepcopy(question_ids))[:255] + [1] + question_ids = question_ids[:383] + [1] + + cur_data = { + "id": item["id"], + "input_ids": torch.tensor(question_ids), + "query_ids": torch.tensor(query_ids), + "dialog_idx": torch.tensor(eli5_idx), + + } + + if "output" in item.keys(): + answer = item["output"][0]["answer"] + answer = white_space_fix(answer) + + with self.tokenizer.as_target_tokenizer(): + response_ids = self.tokenizer(answer, max_length=self.max_target_length, + return_attention_mask=False)["input_ids"] + cur_data.update({ + "response_ids": torch.tensor(response_ids), + "normalized_response": self.normalize_answer(answer), + "candidate_responses": [ot['answer'] for ot in item["output"] if "answer" in ot] + }) + + processed_data.append(cur_data) + + logging.info(f"process {len(eli5_data)} dialogs to {len(processed_data)} training examples.") + return processed_data + + def process_to_input_and_response_pairs(self, dialog_data): + processed_data = [] + for dialog_idx, item 
in enumerate(dialog_data): + dialog = item["dialog"][:self.max_utterances] + inputs = "history:" + if self.add_persona: + inputs = f'persona: {item["persona"]} ' + inputs + if self.add_topic: + inputs = f'topic: {item["chosen_topic"]}. ' + inputs + + for turn_idx, turn in enumerate(dialog): + speaker = turn["speaker"][2:] + assert speaker in ["Wizard", "Apprentice"] + if turn["speaker"][2:] == "Wizard": + if turn_idx == 0: + query = inputs + else: + query = f'topic: {item["chosen_topic"]}. {dialog[turn_idx - 1]["text"]}' + processed_data.append({ + "inputs": inputs, + "response": turn["text"], + "query": query, + "normalized_response": self.normalize_answer(turn["text"]), + "dialog_idx": dialog_idx, + "turn_idx": turn_idx, + }) + inputs = inputs + f' {speaker}: {turn["text"]}' + + logging.info(f"process {len(dialog_data)} dialogs to {len(processed_data)} training examples.") + return processed_data + + def get_qa_key_value_inputs(self, qas, only_return_key_inputs=False): + key_inputs = ["question: " + qa["question"] for qa in qas] + key_inputs = self.tokenizer(key_inputs, max_length=self.max_source_length, + padding=True, truncation=True, return_tensors="pt") + if only_return_key_inputs: + return {"key_input_ids": key_inputs["input_ids"], + "key_attention_mask": key_inputs["attention_mask"]} + else: + value_inputs = ["answer: " + qa["answer"][0] for qa in qas] + value_inputs = self.tokenizer(value_inputs, max_length=self.max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"key_input_ids": key_inputs["input_ids"], + "key_attention_mask": key_inputs["attention_mask"], + "value_input_ids": value_inputs["input_ids"], + "value_attention_mask": value_inputs["attention_mask"]} + + def get_query_inputs(self, batch): + if "kilt" in self.dataset_name: + query_input_ids = [ex["query_ids"] for ex in batch] + query_input_ids = pad_sequence(query_input_ids, batch_first=True, padding_value=self.pad_idx) + query_attention_mask = (query_input_ids != self.pad_idx).long() + return {"query_input_ids": query_input_ids, + "query_attention_mask": query_attention_mask} + else: + query_inputs = [ex["query"] for ex in batch] + query_inputs = self.tokenizer(query_inputs, max_length=self.max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"query_input_ids": query_inputs["input_ids"], + "query_attention_mask": query_inputs["attention_mask"]} + + def get_history_inputs(self, batch): + if "kilt" in self.dataset_name: + history_input_ids = [ex["input_ids"] for ex in batch] + history_input_ids = pad_sequence(history_input_ids, batch_first=True, padding_value=self.pad_idx) + history_attention_mask = (history_input_ids != self.pad_idx).long() + return {"history_input_ids": history_input_ids, + "history_attention_mask": history_attention_mask} + else: + history_inputs = [ex["inputs"] for ex in batch] + history_inputs = self.tokenizer(history_inputs, max_length=self.max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"history_input_ids": history_inputs["input_ids"], + "history_attention_mask": history_inputs["attention_mask"]} + + def get_target_inputs(self, batch): + if "kilt" in self.dataset_name: + target_ids = [ex["response_ids"] for ex in batch] + target_ids = pad_sequence(target_ids, batch_first=True, padding_value=self.pad_idx) + return {"labels": process_labels(target_ids, self.tokenizer)} + else: + targets = [dialog["response"] for dialog in batch] + with self.tokenizer.as_target_tokenizer(): + targets = 
self.tokenizer(targets, max_length=self.max_target_length, + padding=True, truncation=True, return_tensors="pt") + return {"labels": process_labels(targets, self.tokenizer)} + + def get_dataloader(self, batch_size, shuffle, num_workers): + + def base_collate_fn(batch): + original_batch_size, filtered_batch_size = len(batch), len(batch) + # if not self.args.use_not_exactly_true: + # batch = [ex for ex in batch if len(ex["local_positive"]) > 0] + # filtered_batch_size = len(batch) + # while len(batch) == 0: # avoid empty-batch + # batch = random.sample(self.data, batch_size) + # batch = [ex for ex in batch if len(ex["local_positive"]) > 0] + # # do not change filtered_batch_size even change the batch again. + + model_inputs = { + "trainable_percentage": torch.tensor(filtered_batch_size / original_batch_size).repeat(len(batch)), + # repeat ``len(batch)`` times to compatible in multi-GPUs. + } + model_inputs.update(self.get_query_inputs(batch)) + model_inputs.update(self.get_history_inputs(batch)) + model_inputs.update(self.get_target_inputs(batch)) + + batch_local_positive_num = self.args.batch_local_positive_num + neg_num_each_example = self.args.negatives_num_each_example + local_positive_qas = [] + local_positive_num = [] + local_positive_qas_mask = [] + local_negative_qas = [] + local_pos_mix_neg_qas = [] # num = neg_num_each_example + for ex in batch: + cur_local_positive_qas_ids = [idx for idx in ex["local_positive"][:batch_local_positive_num]] + cur_local_positive_qas = [self.qas_to_retrieve[idx] for idx in cur_local_positive_qas_ids] + cur_pos_num = len(cur_local_positive_qas) + local_positive_num.append(cur_pos_num) + + cur_local_negative_qas_idx = random.sample(ex["local_negative"], neg_num_each_example) + cur_local_negative_qas = [self.qas_to_retrieve[idx] for idx in cur_local_negative_qas_idx] + local_negative_qas.append(cur_local_negative_qas) + cur_local_pos_mix_neg_qas = cur_local_positive_qas + \ + cur_local_negative_qas[:neg_num_each_example - cur_pos_num] + local_pos_mix_neg_qas.append(cur_local_pos_mix_neg_qas) + + cur_pad_num = batch_local_positive_num - cur_pos_num + cur_local_positive_qas_mask = [1] * cur_pos_num + [0] * cur_pad_num + local_positive_qas_mask.append(cur_local_positive_qas_mask) + cur_local_positive_qas.extend([self.pad_qa] * cur_pad_num) + local_positive_qas.append(cur_local_positive_qas) + + model_inputs.update({"local_positive_qas_mask": torch.tensor(local_positive_qas_mask), + "local_positive_num": torch.tensor(local_positive_num), }) + + assert self.args.select_positive_strategy == "softmax_sample" + squeezed_positive_qas = list(chain(*local_positive_qas)) + local_positive_inputs = self.get_qa_key_value_inputs(squeezed_positive_qas, only_return_key_inputs=True) + model_inputs.update({f"local_positive_inputs_{k}": v.view(len(batch), batch_local_positive_num, -1) + for k, v in local_positive_inputs.items()}) + + squeezed_negative_qas = list(chain(*local_negative_qas)) + local_negative_inputs = self.get_qa_key_value_inputs(squeezed_negative_qas, only_return_key_inputs=True) + model_inputs.update({f"local_negative_inputs_{k}": v.view(len(batch), neg_num_each_example, -1) + for k, v in local_negative_inputs.items()}) + + squeezed_mixed_qas = list(chain(*local_pos_mix_neg_qas)) + local_mixed_inputs = self.get_qa_key_value_inputs(squeezed_mixed_qas) + model_inputs.update({f"local_mixed_inputs_{k}": v.view(len(batch), neg_num_each_example, -1) + for k, v in local_mixed_inputs.items()}) + + # all_targets = [[normalize_answer(an) for an in qa["response"]] for 
qa in batch] + # negative_qas_answer = [normalize_answer(nqa["answer"][0]) for nqa in squeezed_negative_qas] + # negative_mask = [[1 if neg_ans not in cur_all_target else 0 for neg_ans in negative_qas_answer] + # for cur_all_target in all_targets] + # model_inputs.update({"negative_mask": torch.tensor(negative_mask)}) + + # for multi-GPUs + assert all(model_inputs[k].shape[0] == len(batch) for k in model_inputs.keys()) + + return model_inputs + + return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, + collate_fn=base_collate_fn, pin_memory=True) + + def get_query_dataloader(self, batch_size, shuffle, num_workers, add_history=False): + + def query_collate_fn(batch): + model_inputs = self.get_query_inputs(batch) + + if add_history: + model_inputs.update(self.get_history_inputs(batch)) + + return model_inputs + + return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, + collate_fn=query_collate_fn, pin_memory=True, drop_last=False) + + def get_t5_dataloader(self, batch_size, shuffle, num_workers, is_train): + + def t5_collate_fn(batch): + history_inputs = self.get_history_inputs(batch) + response_inputs = self.get_target_inputs(batch) + return { + "input_ids": history_inputs["history_input_ids"], + "attention_mask": history_inputs["history_attention_mask"], + "labels": response_inputs["labels"] + } + + return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, + collate_fn=t5_collate_fn, pin_memory=True) + + def __len__(self): + return len(self.data) + + def __getitem__(self, item): + return self.data[item] + + +if __name__ == '__main__': + import json + + + def load_json(fi): + return json.load(open(fi, 'r')) + + + def load_jsonl(fn): + all_data = [] + with open(fn, "r") as f: + for line in f.readlines(): + all_data.append(json.loads(line)) + return all_data + + + tokenizer = T5Tokenizer.from_pretrained("./data/cbqa_data/pretrained_model/t5-base") + # test_data = load_jsonl("wow-test_without_answers-kilt.jsonl.txt") + # train_data = load_jsonl("wow-train-kilt.jsonl") + dev_data = load_jsonl("./data/annotated_datasets/wizard_of_wikipedia/wow-dev-kilt.jsonl") + + exp = dev_data[0] + print("") + + dataset = DialogDataset(dev_data, tokenizer, None, "wow_kilt", + max_source_length=768, max_utterances=10) + # data: List[Dict], + # tokenizer: T5Tokenizer, + # qas_to_retrieve, + # dataset_name, + # retrieve_strategy = "dr", + # max_source_length = None, + # args = None, + # normed_answer_of_qas_to_ret = None, + # max_utterances = 100, + # add_topic = True, + # add_persona = True diff --git a/kilt_main.py b/kilt_main.py new file mode 100644 index 0000000..a59d78f --- /dev/null +++ b/kilt_main.py @@ -0,0 +1,183 @@ +import json +import os +import pickle +from transformers import T5Tokenizer +from emat.utils import load_jsonl +from kilt_dataset import DialogDataset +from kilt_trainer import DialogTrainer +from utils.utils import CATArgs +import logging +import time + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + +DATA_PATHS = { + "wow": { + "train": "./data/annotated_datasets/wizard_of_wikipedia/train.json", + "validation": "./data/annotated_datasets/wizard_of_wikipedia/valid_random_split.json", + "test": "./data/annotated_datasets/wizard_of_wikipedia/test_random_split.json", + }, + "wow_unseen": { + "train": "./data/annotated_datasets/wizard_of_wikipedia/train.json", + "validation": 
"./data/annotated_datasets/wizard_of_wikipedia/valid_topic_split.json", + "test": "./data/annotated_datasets/wizard_of_wikipedia/test_topic_split.json", + }, + "wow_kilt": { + "train": "./data/annotated_datasets/wizard_of_wikipedia/wow-train-kilt.jsonl", + "validation": "./data/annotated_datasets/wizard_of_wikipedia/wow-dev-kilt.jsonl", + "test": "./data/annotated_datasets/wizard_of_wikipedia/wow-test_without_answers-kilt.jsonl.txt", + }, + "eli5_kilt": { + "train": "./data/annotated_datasets/eli5/eli5-train-kilt.jsonl", + "validation": "./data/annotated_datasets/eli5/eli5-dev-kilt.jsonl", + "test": "./data/annotated_datasets/eli5/eli5-test_without_answers-kilt.jsonl", + } +} +QA_KB_PATHS = { + "PAQ_L1": "./tmp/PAQ_L1_pickl_file.pkl", + "PAQ": "./tmp/PAQ_full.pkl", + "TAQ_TRAIN_NQ_TRAIN_PAQ": "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl", +} +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + + +def load_dataset(args): + assert args.qa_data_name in DATA_PATHS.keys(), f"available dataset names: {DATA_PATHS.keys()}" + + logging.info("loading normed answer of qas to retrieve") + if "PAQ" == args.qas_to_retrieve_from: + normed_answer_of_qas_to_ret = pickle.load(open("./tmp/PAQ_only_normalized_answer.pkl", 'rb')) + else: + normed_answer_of_qas_to_ret = json.load(open("./tmp/PAQL1_only_normalized_answer.json", 'r')) + + logging.info("loading qas to retrieve") + if "debug" in args.exp_name.lower() or "full-paq-test" in args.exp_name.lower(): + if not os.path.exists("./tmp/PAQ_L1_small.pkl"): + qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_pickl_file.pkl", 'rb')) + qas_to_retrieve = qas_to_retrieve[:len(qas_to_retrieve) // 14] + pickle.dump(qas_to_retrieve, open("./tmp/PAQ_L1_small.pkl", 'wb')) + else: + qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_small.pkl", 'rb')) + else: + qas_to_retrieve_fp = QA_KB_PATHS[args.qas_to_retrieve_from] + logging.info(f"loading qas from {qas_to_retrieve_fp}") + if qas_to_retrieve_fp.endswith("pkl"): + qas_to_retrieve = pickle.load(open(qas_to_retrieve_fp, 'rb')) + elif qas_to_retrieve_fp.endswith("jsonl"): + qas_to_retrieve = load_jsonl(qas_to_retrieve_fp) + else: + raise ValueError(f"{qas_to_retrieve_fp}") + + if "debug" in args.exp_name.lower(): + qas_to_retrieve = qas_to_retrieve[:5000] + normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:len(qas_to_retrieve)] + + if args.qas_to_retrieve_from == "PAQ" and args.PAQ_size is not None: + qas_to_retrieve = qas_to_retrieve[:args.PAQ_size] + normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:args.PAQ_size] + assert len(qas_to_retrieve) == args.PAQ_size + logging.info(f"select {args.PAQ_size}-size PAQ.") + + assert len(normed_answer_of_qas_to_ret) == len(qas_to_retrieve) + loaded_data = { + "qas_to_retrieve": qas_to_retrieve, + "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret + } + + return loaded_data + + +def main(): + cat_args = CATArgs("dialog_cat") + args = cat_args.parse_args() + data_paths = DATA_PATHS[args.qa_data_name] + logging.info("load datasets") + if "kilt" in args.qa_data_name: + train_data = load_jsonl(data_paths["train"]) + dev_data = load_jsonl(data_paths["validation"]) + test_data = load_jsonl(data_paths["test"]) + else: + train_data = json.load(open(data_paths["train"], 'r')) + dev_data = json.load(open(data_paths["validation"], 'r')) + test_data = json.load(open(data_paths["test"], 'r')) + + loaded_data = load_dataset(args) + logging.info("data loaded.") + qas_to_retrieve = loaded_data["qas_to_retrieve"] + normed_answer_of_qas_to_ret = 
loaded_data["normed_answer_of_qas_to_ret"] + + if "debug" in args.exp_name.lower(): + train_data = train_data[:50] + dev_data = dev_data[:10] + test_data = test_data[:10] + qas_to_retrieve = qas_to_retrieve[:10000] + normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:10000] + + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + + if args.qa_data_name != "eli5_kilt": + dataset_kwargs = { + "dataset_name": args.qa_data_name, + "args": args, + "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret, + "add_persona": args.add_persona, + "add_topic": args.add_topic, + "max_source_length": 1024 + } + else: + assert args.qa_data_name == "eli5_kilt" + dataset_kwargs = { + "dataset_name": args.qa_data_name, + "args": args, + "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret, + "max_source_length": 384, + "max_target_length": 1536 + } + mu = 10 if args.qa_data_name == "wow_kilt" else 13 + train_dataset = DialogDataset(train_data, tokenizer, qas_to_retrieve, max_utterances=mu, **dataset_kwargs) + dev_dataset = DialogDataset(dev_data, tokenizer, qas_to_retrieve, **dataset_kwargs) + test_dataset = DialogDataset(test_data, tokenizer, qas_to_retrieve, **dataset_kwargs) + dialog_trainer = DialogTrainer(args, train_dataset, dev_dataset, test_dataset, qas_to_retrieve, + normed_answer_of_qas_to_ret) + + if args.do_train: + dialog_trainer.train() + elif args.do_test: + logging.info("Only do test.") + ckpt_load_path = os.path.join(args.output_dir, "best_ckpt/pytorch_model.bin") + gen_kwargs = {"max_length": 1024, + "num_beams": 5, + "do_sample": True, + "top_k": 64, + "no_repeat_ngram_size": 8} + logging.warning("use dev dataset") + use_dataset = dev_dataset + + metrics, ret_qas, gen_response = dialog_trainer.evaluate(use_dataset, update_key_memory=True, + ckpt_load_path=ckpt_load_path, gen_kwargs=gen_kwargs) + + for k, v in metrics.items(): + logging.info(f"test_{k}: {v}") + assert len(ret_qas) == len(gen_response) == len(use_dataset.data) + results = [] + for retrieved, pred, input_item in zip(ret_qas, gen_response, use_dataset.data): + results.append({ + "id": input_item["id"], + "input": tokenizer.decode(input_item["input_ids"]), + "target": tokenizer.decode(input_item["response_ids"]) if "response_ids" in input_item else "", + "query": tokenizer.decode(input_item["query_ids"]), + "output": {"answer": pred, "provenance": [{"wikipedia_id": "12904"}]}, + "retrieved_qas": [f"question: {qa['question']} answer: {qa['answer'][0]}" for qa in retrieved] + }) + dump_path = os.path.dirname(ckpt_load_path) + dump_path = os.path.join(dump_path, f"{time.strftime('%d %H-%M')}_predict_result.json") + json.dump(results, open(dump_path, 'w'), + indent=4, ensure_ascii=False) + + +if __name__ == '__main__': + main() diff --git a/kilt_scripts/eli5_train.sh b/kilt_scripts/eli5_train.sh new file mode 100644 index 0000000..169d7ed --- /dev/null +++ b/kilt_scripts/eli5_train.sh @@ -0,0 +1,59 @@ +#!/bin/bash -l + +set -e +set -u + +DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +LOAD_EXP_NAME="KL=3;kdim=1536;CL=10;VL=11;VN=10;async_cat_k+v;t5-base;from-t5-base;" +EXP_NAME="base;wow_kilt;freeze;8e-5;20epoch;CL=10;VL=11;" +DATA_NAME="eli5_kilt" + +DEVICE="0" + +echo "Use Device ${DEVICE}" + +# Train wow-EMAT-SKSV + +CUDA_VISIBLE_DEVICES=${DEVICE} python kilt_main.py \ + --key_layer=3 \ + --cat_layer=10 \ + --value_layer=11 \ + --query_batch_size=256 \ + --build_mem_batch_size=12000 \ + 
--project_name="${DATA_NAME^^}-CAT" \ + --exp_name=${EXP_NAME} \ + --batch_local_positive_num=5 \ + --pos_from_top=128 \ + --do_eval \ + --kvm_seg_n=5 \ + --values_with_order \ + --value_fusion_method="async_cat_k_delay+v" \ + --num_values=10 \ + --qa_data_name=${DATA_NAME} \ + --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \ + --source_prefix="question: " \ + --per_device_train_batch_size=12 \ + --per_device_eval_batch_size=32 \ + --gradient_accumulation_steps=5 \ + --learning_rate=5e-5 \ + --num_train_epochs=20 \ + --lr_scheduler_type="linear" \ + --num_warmup_steps=1000 \ + --output_dir="./outputs/${DATA_NAME}_checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_encoder_type="conv" \ + --select_positive_strategy="softmax_sample" \ + --faiss_efsearch=128 \ + --gen_weight=1 \ + --match_weight=1 \ + --key_reduce_method="avg" \ + --qas_to_retrieve_from="PAQ_L1" \ + --local_size=384 \ + --separate_task \ + --early_stop_patience=8 \ + --negatives_num_each_example=16 \ + --do_train \ + --not_share_encoder diff --git a/kilt_scripts/eval_kilt.sh b/kilt_scripts/eval_kilt.sh new file mode 100644 index 0000000..acc9761 --- /dev/null +++ b/kilt_scripts/eval_kilt.sh @@ -0,0 +1,61 @@ +#!/bin/bash -l + +set -e +set -u + +DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +CKPT_DIR="" +EXP_NAME="eval" +DATA_NAME="wow_kilt" # eli5_kilt + +DEVICE="0" + +echo "Use Device ${DEVICE}" + +#echo "wait-to-continue..." +#sleep 7200 + +CUDA_VISIBLE_DEVICES=${DEVICE} python kilt_main.py \ + --key_layer=3 \ + --value_layer=7 \ + --query_batch_size=256 \ + --build_mem_batch_size=12000 \ + --project_name="${DATA_NAME^^}-CAT" \ + --exp_name=${EXP_NAME} \ + --batch_local_positive_num=5 \ + --pos_from_top=128 \ + --do_eval \ + --kvm_seg_n=5 \ + --values_with_order \ + --value_fusion_method="cat_k_delay+v" \ + --num_values=10 \ + --qa_data_name=${DATA_NAME} \ + --model_name_or_path=${CKPT_DIR} \ + --source_prefix="question: " \ + --per_device_train_batch_size=12 \ + --per_device_eval_batch_size=32 \ + --gradient_accumulation_steps=5 \ + --learning_rate=8e-5 \ + --num_train_epochs=20 \ + --lr_scheduler_type="linear" \ + --num_warmup_steps=1000 \ + --output_dir="./outputs/${DATA_NAME}_checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_encoder_type="conv" \ + --select_positive_strategy="softmax_sample" \ + --faiss_efsearch=128 \ + --gen_weight=1 \ + --match_weight=1 \ + --key_reduce_method="avg" \ + --qas_to_retrieve_from="PAQ_L1" \ + --local_size=384 \ + --separate_task \ + --early_stop_patience=8 \ + --negatives_num_each_example=16 \ + --do_test \ + --add_topic \ + --add_persona \ + --not_share_encoder diff --git a/kilt_scripts/wow_train.sh b/kilt_scripts/wow_train.sh new file mode 100644 index 0000000..0d1afe3 --- /dev/null +++ b/kilt_scripts/wow_train.sh @@ -0,0 +1,61 @@ +#!/bin/bash -l + +set -e +set -u + +DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +LOAD_EXP_NAME="KL=3;kdim=1536;CL=10;VL=11;VN=10;async_cat_k+v;t5-base;from-t5-base;" +EXP_NAME="base;wow_kilt;freeze;8e-5;20epoch;CL=10;VL=11;" +DATA_NAME="wow_kilt" + +DEVICE="0" + +echo "Use Device ${DEVICE}" + +# Train wow-EMAT-SKSV + +CUDA_VISIBLE_DEVICES=${DEVICE} python kilt_main.py \ + --key_layer=3 \ + --cat_layer=10 \ + --value_layer=11 \ + --query_batch_size=256 \ + --build_mem_batch_size=12000 \ + --project_name="${DATA_NAME^^}-CAT" \ + --exp_name=${EXP_NAME} \ + --batch_local_positive_num=5 \ + 
--pos_from_top=128 \ + --do_eval \ + --kvm_seg_n=5 \ + --values_with_order \ + --value_fusion_method="async_cat_k_delay+v" \ + --num_values=10 \ + --qa_data_name=${DATA_NAME} \ + --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \ + --source_prefix="question: " \ + --per_device_train_batch_size=12 \ + --per_device_eval_batch_size=32 \ + --gradient_accumulation_steps=5 \ + --learning_rate=8e-5 \ + --num_train_epochs=20 \ + --lr_scheduler_type="linear" \ + --num_warmup_steps=1000 \ + --output_dir="./outputs/${DATA_NAME}_checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_encoder_type="conv" \ + --select_positive_strategy="softmax_sample" \ + --faiss_efsearch=128 \ + --gen_weight=1 \ + --match_weight=1 \ + --key_reduce_method="avg" \ + --qas_to_retrieve_from="PAQ_L1" \ + --local_size=384 \ + --separate_task \ + --early_stop_patience=8 \ + --negatives_num_each_example=16 \ + --do_train \ + --add_topic \ + --add_persona \ + --not_share_encoder diff --git a/kilt_trainer.py b/kilt_trainer.py new file mode 100644 index 0000000..770a885 --- /dev/null +++ b/kilt_trainer.py @@ -0,0 +1,655 @@ +import copy +import logging +import math +import os +import random +from itertools import chain +from typing import List, Dict, Optional +from collections import Counter +from rouge import Rouge +from emat.utils import write_jsonl +from utils.dr_utils import update_dialog_local_qas_to_retrieve, update_batch_inputs +from utils.utils import reduce_query_or_key_embeds +import datasets +import torch +import transformers +from accelerate import Accelerator +from tqdm.auto import tqdm +from transformers import AdamW, get_scheduler, set_seed +from utils.utils import save_model, load_model +from build_kvm import build_memory +from emat.t5 import T5WithKeyValueMemory +from kilt_dataset import DialogDataset +from nltk.translate.bleu_score import sentence_bleu +import time + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + +try: + import wandb + + wandb.ensure_configured() + if wandb.api.api_key is None: + _has_wandb = False + wandb.termwarn( + "W&B installed but not logged in. 
Run `wandb login` or set the WANDB_API_KEY env variable.") + else: + _has_wandb = False if os.getenv("WANDB_DISABLED") else True +except (ImportError, AttributeError): + _has_wandb = False + + +def compute_batch_BLEU(references, candidates): + def compute_BLEU(reference, candidate): + b1 = sentence_bleu(reference, candidate, weights=(1, 0, 0, 0)) + b2 = sentence_bleu(reference, candidate, weights=(0, 1, 0, 0)) + b3 = sentence_bleu(reference, candidate, weights=(0, 0, 1, 0)) + b4 = sentence_bleu(reference, candidate, weights=(0, 0, 0, 1)) + return b1, b2, b3, b4 + + references = [[ref] for ref in references] + + bleu1_score = [] + bleu2_score = [] + bleu3_score = [] + bleu4_score = [] + for index in range(len(references)): + bleu1, bleu2, bleu3, bleu4 = compute_BLEU(references[index], candidates[index]) + bleu1_score.append(bleu1) + bleu2_score.append(bleu2) + bleu3_score.append(bleu3) + bleu4_score.append(bleu4) + + bleu1 = sum(bleu1_score) / len(bleu1_score) + bleu2 = sum(bleu2_score) / len(bleu2_score) + bleu3 = sum(bleu3_score) / len(bleu2_score) + bleu4 = sum(bleu4_score) / len(bleu4_score) + + return {"bleu1": bleu1, "bleu2": bleu2, "bleu3": bleu3, "bleu4": bleu4, } + + +def f1_score(prediction, ground_truth): + prediction_tokens = prediction.split() + ground_truth_tokens = ground_truth.split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def max_score_over_ground_truths(fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) + + +def rougel_score(prediction, ground_truth): + rouge = Rouge() + # no normalization + try: + scores = rouge.get_scores(prediction, ground_truth, avg=True) + except ValueError: # "Hypothesis is empty." 
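+        # rouge.get_scores raises ValueError (e.g. "Hypothesis is empty.") when the prediction
+        # or the reference has no tokens, so such pairs are simply scored as 0.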
+ return 0.0 + return scores["rouge-l"]["f"] + + +class DialogTrainer: + + def __init__( + self, + args, + train_dataset: DialogDataset, + dev_dataset: DialogDataset, + test_dataset: DialogDataset, + qas_to_retrieve: List[Dict], + normed_answer_of_qas_to_ret, + ): + accelerator = Accelerator() + logging.info(f"wandb {'available' if _has_wandb else 'unavailable'}") + logger.info(accelerator.state) + logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + if args.seed is not None: + set_seed(args.seed) + else: + logging.info("Not set seed.") + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + if accelerator.is_local_main_process and _has_wandb: + wandb.init(project=args.project_name, name=args.exp_name, dir=args.output_dir, config=vars(args)) + + logging.info("loading model") + config, tokenizer, self.model = load_model(T5WithKeyValueMemory, args) + logging.info("Loading data.") + + if args.freeze_t5_params: + logging.info("Freeze T5 parameters.") + self.model.freeze_t5_params() + + if args.not_share_encoder and not args.update_kv_embeds: + logging.info("Freeze kv-encoder parameters.") + self.model.freeze_kv_encoder_params() + + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + {"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, }, + {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, }, + ] + self.optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # reused_key_memory: pre-allocated memory to store full key_memory + self.reused_key_memory = torch.zeros((len(qas_to_retrieve), self.model.model_dim), + device="cpu", dtype=torch.float16) + self.train_data_query_embeds = torch.zeros((len(train_dataset), self.model.model_dim), + device="cpu", dtype=torch.float16) + self.key_memory: Optional[List[torch.tensor]] = None + self.key_memory = [] + for start_idx in range(0, len(qas_to_retrieve), math.ceil(len(qas_to_retrieve) / args.kvm_seg_n)): + self.key_memory.append( + self.reused_key_memory[start_idx: start_idx + math.ceil(len(qas_to_retrieve) / args.kvm_seg_n)] + ) + # logger.info(f"key num = {sum(len(i) for i in self.key_memory)}") + + self.train_dataset = train_dataset + self.dev_dataset = dev_dataset + self.test_dataset = test_dataset + self.args = args + self.accelerator = accelerator + self.tokenizer = tokenizer + self.qas_to_retrieve = qas_to_retrieve + self.prefix = args.source_prefix if args.source_prefix is not None else "" + self.query_batch_size = args.query_batch_size + # logger.info(f"PAQ-size: {len(self.qas_to_retrieve)}. 
PAQ's query batch size: {self.query_batch_size}.") + self.normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret + self.model = self.accelerator.prepare(self.model) + + @torch.no_grad() + def update_key_memory(self, use_fp16_model=True): + args = self.args + if use_fp16_model: + tmp_model = copy.deepcopy(self.model) + tmp_model = tmp_model.half() + else: + tmp_model = self.model + build_mem_batch_size = args.build_mem_batch_size + tmp_model.eval() + + self.key_memory, _ = build_memory( + tmp_model, self.tokenizer, embed_key=True, embed_value=False, prefix="question: ", embed_as_fp16=True, + key_reduce_method=args.key_reduce_method, return_memory=True, dump_memory=False, + data_to_embed=self.qas_to_retrieve, max_source_length=args.max_source_length, padding=True, + batch_size=build_mem_batch_size, separate_task=True, kvm_seg_n=args.kvm_seg_n, + reused_key_memory=self.reused_key_memory, num_workers=0 + ) + if type(self.key_memory) is not list: + self.key_memory = [self.key_memory] + del tmp_model + + @torch.no_grad() + def update_local_qas(self, epoch, use_fp16_model=True): + args = self.args + if use_fp16_model: + tmp_model = copy.deepcopy(self.model) + tmp_model = tmp_model.half() + else: + tmp_model = self.model + build_mem_batch_size = args.build_mem_batch_size + tmp_model.eval() + update_dialog_local_qas_to_retrieve( + args, self.train_dataset, self.qas_to_retrieve, tmp_model, self.key_memory, + self.normed_answer_of_qas_to_ret, train_data_query_embeds=self.train_data_query_embeds, + build_mem_batch_size=build_mem_batch_size, query_batch_size=self.query_batch_size, + local_size=args.local_size, pos_from_top=args.pos_from_top, neg_from_top=200 + ) + del tmp_model + + def train(self): + args = self.args + tokenizer = self.tokenizer + num_workers = 5 + logging.info("Build Memory") + self.update_key_memory() + + train_dataloader = self.train_dataset.get_dataloader(batch_size=args.per_device_train_batch_size, + shuffle=True, num_workers=num_workers) + optimizer, train_dataloader = self.accelerator.prepare(self.optimizer, train_dataloader) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # Train! + total_batch_size = args.per_device_train_batch_size * self.accelerator.num_processes \ + * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(self.train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
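+        # Model selection below tracks the best dev score (word-overlap F1 for WoW, RougeL for ELI5);
+        # `patience` counts down from args.early_stop_patience whenever the score does not improve.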
+ progress_bar = tqdm(range(args.max_train_steps), disable=not self.accelerator.is_local_main_process) + completed_steps = 0 + + best_score, patience = float("-inf"), args.early_stop_patience + best_score_epoch = -1 + eval_times = 0 + select_mt = "f1" if self.train_dataset.dataset_name != "eli5_kilt" else "RougeL" + + for epoch in range(args.num_train_epochs): + if args.qas_to_retrieve_from == "PAQ" and (epoch % 3 == 0): + self.update_local_qas(epoch) + elif args.qas_to_retrieve_from != "PAQ": + self.update_local_qas(epoch) + for step, batch in enumerate(train_dataloader): + + update_batch_inputs(args, batch, self.model) + self.model.train() + if args.match_weight > 0.0: + # Embed Positive Key and the Value to input. + embed_dict = self.model.wrapped_embed_kv( # assert num_values > 1, otherwise set compute_value=True + separate_task=args.separate_task, compute_key=True, compute_value=False, + **batch.pop("positive_kv_inputs") + ) + positive_key_embeds = embed_dict["normed_key_embeds"] + positive_key_embeds = reduce_query_or_key_embeds(positive_key_embeds, args.key_reduce_method) + # Embed Negative Key + embed_dict = self.model.wrapped_embed_kv( + separate_task=args.separate_task, compute_key=True, compute_value=False, + **batch.pop("negative_kv_inputs") + ) + negative_key_embeds = embed_dict["normed_key_embeds"] + negative_key_embeds = reduce_query_or_key_embeds(negative_key_embeds, args.key_reduce_method) + else: + negative_key_embeds, positive_key_embeds = None, None + # Embed retrieved-Key-Value + embed_dict = self.model.wrapped_embed_kv( + separate_task=args.separate_task, compute_key=True, compute_value=True, + **batch.pop("group_value_inputs") + ) + key_embeds_of_value = embed_dict["key_embeds"] + value_embeds = embed_dict["value_embeds"] + bs = batch["query_input_ids"].shape[0] + value_embeds = value_embeds.view(bs, args.num_values, args.prefix_length, -1) + key_embeds_of_value = key_embeds_of_value.view(bs, args.num_values, -1, self.model.model_dim) + + loss_dict = self.model.compute_qa_loss( + input_ids=batch["history_input_ids"], + attention_mask=batch["history_attention_mask"], + labels=batch["labels"], + decoder_only_attend_on_prefix=False, + value_fusion_method=args.value_fusion_method, + encoder_outputs_are_key_or_value=False, + key_reduce_method=args.key_reduce_method, + positive_key_embeds=positive_key_embeds, + negative_key_embeds=negative_key_embeds, + value_embeds=value_embeds, + matching_targets=batch["matching_targets"], + key_embeds_of_value=key_embeds_of_value, + negative_mask=batch.get("negative_mask", None), + ) + if args.match_weight > 0.0: + loss = loss_dict["gen_loss"] + else: + loss = args.gen_weight * loss_dict["gen_loss"] + args.match_weight * loss_dict["match_loss"] + loss_dict = {k: v.item() for k, v in loss_dict.items()} + loss = loss / args.gradient_accumulation_steps + self.accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + if args.match_weight > 0.0: + progress_bar.set_description(f"GL-[{loss_dict['gen_loss']:.5f}]") + else: + progress_bar.set_description(f"GL-[{loss_dict['gen_loss']:.5f}] " + f"RL-[{loss_dict['match_loss']:.5f}]") + completed_steps += 1 + if self.accelerator.is_local_main_process and _has_wandb: + wandb.log({"loss": loss * args.gradient_accumulation_steps, "step": completed_steps}) + wandb.log({"trainable_percentage": batch["trainable_percentage"][0].item(), + "forward_step": 
completed_steps}) + for k, v in loss_dict.items(): + wandb.log({k: v, "step": completed_steps}) + + if args.eval_every_n_steps is not None and completed_steps % args.eval_every_n_steps == 0: + self.update_local_qas(epoch) + metric, _, _ = self.evaluate(dataset=self.dev_dataset, extend_mem_from="train") + for k, v in metric.items(): + logger.info(f"eval_times {eval_times} eval - {k}: {v * 100:.5f}") + if _has_wandb: + for k, v in metric.items(): + wandb.log({k: v * 100, "eval_times": eval_times}) + + if args.output_dir is not None: + cur_score = metric[select_mt] + if cur_score > best_score: + best_score = cur_score + best_score_epoch = epoch + save_model(self.model, os.path.join(args.output_dir, "best_ckpt"), + self.accelerator, tokenizer=tokenizer, arguments=args) + patience = args.early_stop_patience + else: + patience -= 1 + if patience <= 0: + break + + if completed_steps >= args.max_train_steps: + break + + if args.output_dir is not None: + save_model(self.model, os.path.join(args.output_dir, "latest_ckpt"), + self.accelerator, tokenizer=tokenizer, arguments=args) + + if args.update_kv_embeds: + logging.info("Update Memory") + self.update_key_memory() + + if args.do_eval and epoch % args.eval_freq == 0: + # if args.do_eval: self.key_memory is up-to-date. + metric, _, _ = self.evaluate(dataset=self.dev_dataset, extend_mem_from="train") + for k, v in metric.items(): + logger.info(f"epoch {epoch} eval - {k}: {v * 100:.5f}") + if _has_wandb: + for k, v in metric.items(): + wandb.log({k: v * 100, "epoch": epoch}) + + if args.output_dir is not None: + cur_score = metric[select_mt] + if cur_score > best_score: + best_score = cur_score + best_f1_epoch = epoch + save_model(self.model, os.path.join(args.output_dir, "best_ckpt"), + self.accelerator, tokenizer=tokenizer, arguments=args) + patience = args.early_stop_patience + else: + patience -= 1 + if patience <= 0: + break + + logger.info(f"best_f1_dev: {best_score * 100:.5f}") + logger.info(f"best_f1 epoch: {best_score_epoch}") + if _has_wandb: + wandb.log({"best_f1_dev": best_score * 100}) + + # do-test + best_model_state_dict = os.path.join(args.output_dir, "best_ckpt/pytorch_model.bin") + metric, _, _ = self.evaluate(dataset=self.test_dataset, extend_mem_from="train_dev", + update_key_memory=True, ckpt_load_path=best_model_state_dict) + + if self.accelerator.is_local_main_process: + for k, v in metric.items(): + logger.info(f"test - {k}: {v * 100:.5f}") + if _has_wandb: + for k, v in metric.items(): + wandb.log({f"test_{k}": v * 100}) + + @torch.no_grad() + def evaluate(self, dataset: DialogDataset = None, extend_mem_from="", update_key_memory=False, ckpt_load_path=None, + gen_kwargs=None): + # not implement correctly in multi-GPUs + tokenizer = self.tokenizer + args = self.args + self.model.eval() + torch.cuda.empty_cache() + + if ckpt_load_path is not None: + if args.update_kv_embeds: + assert update_key_memory is True + self.model.load_state_dict(torch.load(ckpt_load_path), strict=True) + + assert type(self.key_memory) == list + if update_key_memory: + logging.info("Update Memory") + self.update_key_memory() + + dataloader = dataset.get_query_dataloader(batch_size=args.per_device_eval_batch_size, + shuffle=False, num_workers=1, add_history=True) + dataloader = self.accelerator.prepare(dataloader) + + if gen_kwargs is None: + gen_kwargs = {"max_length": args.max_target_length, + "num_beams": args.num_beams, } + + torch.cuda.empty_cache() + + all_retrieved_qas = [] + all_gen_response = [] + for batch in tqdm(dataloader): + embed_dict = 
self.model.CAT_embed_q( + input_ids=batch["query_input_ids"], + attention_mask=batch["query_attention_mask"], + compute_key=True, compute_value=False + ) + query_embeds = embed_dict["normed_key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method) + query_embeds = query_embeds.half() + + # calculate topk in each chunk -> combine all-topk -> select final topk + chunk_top_scores = [] + chunk_top_indices = [] + idx_shift = 0 + for chunk_key_memory in self.key_memory: + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_topk = torch.mm(query_embeds, chunk_key_memory_cuda.t()).topk(50, dim=1) + chunk_top_scores.append(chunk_topk.values) # chunk_topk.scores: [query_batch, local_size] + chunk_top_indices.append(chunk_topk.indices + idx_shift) + idx_shift += len(chunk_key_memory) + del chunk_key_memory_cuda + torch.cuda.empty_cache() + chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n + chunk_top_indices = torch.cat(chunk_top_indices, dim=1) # q_batch, local_size*seg_n + topk = chunk_top_scores.topk(args.num_values, dim=1) # q_batch, local_size + top_indices_indices = topk.indices + top_indices = [] + for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices): + top_indices.append([cur_indices[idx] for idx in cur_indices_indices]) + readout_qas = [[self.qas_to_retrieve[idx] for idx in indices] for indices in top_indices] + + value_qas = [] + for qas in readout_qas: + selected_qas = qas[:args.num_values] + if not args.values_with_order: + random.shuffle(selected_qas) + value_qas.append(selected_qas) + all_retrieved_qas += readout_qas + + squeezed_value_qas = list(chain(*value_qas)) + retrieved_qas_inputs = dataset.get_qa_key_value_inputs(squeezed_value_qas, only_return_key_inputs=False) + embed_dict = self.model.wrapped_embed_kv(separate_task=args.separate_task, compute_key=True, + compute_value=True, **retrieved_qas_inputs) + value_embeds = embed_dict["value_embeds"] + key_embeds_of_value = embed_dict["key_embeds"] + cur_batch_size = query_embeds.shape[0] + value_embeds = value_embeds.view(cur_batch_size, args.num_values, args.prefix_length, -1) + key_embeds_of_value = key_embeds_of_value.view(cur_batch_size, args.num_values, -1, self.model.model_dim) + encoder_outputs = self.model.encoder( + input_ids=batch["history_input_ids"], + attention_mask=batch["history_attention_mask"], + return_dict=True, + value_embeds=value_embeds, + readout_top_k=-1, + key_reduce_method=args.key_reduce_method, + value_fusion_method=args.value_fusion_method, + key_embeds_of_value=key_embeds_of_value + ) + generated_tokens = self.accelerator.unwrap_model(self.model).generate( + encoder_outputs=encoder_outputs, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=False, + attention_mask=batch["history_attention_mask"].to(self.model.device), + value_fusion_method=args.value_fusion_method, + **gen_kwargs, + ) + decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + decoded_tokens = [ans.strip() for ans in decoded_tokens] + all_gen_response += decoded_tokens + + torch.cuda.empty_cache() + + if "normalized_response" in dataset.data[0].keys(): + if dataset.dataset_name == "wow_kilt": + reference = [item["normalized_response"] for item in dataset.data] # only one target + else: + reference = [item["candidate_responses"] for item in dataset.data] # multi candidates + elif "response" in dataset.data[0].keys(): + reference = [DialogDataset.normalize_answer(item["response"]) for 
item in dataset.data] + else: + reference = None + + if reference is not None: + assert len(all_gen_response) == len(reference) + metrics = dict() + if dataset.dataset_name == "eli5_kilt": + avg_rougel = sum([max_score_over_ground_truths(rougel_score, pred, ref) for pred, ref in + zip(all_gen_response, reference)]) / len(all_gen_response) + metrics.update({"RougeL": avg_rougel}) + reference = [[DialogDataset.normalize_answer(s) for s in cands] for cands in reference] + all_gen_response = [DialogDataset.normalize_answer(s) for s in all_gen_response] + avg_f1 = sum([max_score_over_ground_truths(f1_score, pred, ref) for pred, ref in + zip(all_gen_response, reference)]) / len(all_gen_response) + else: + reference = [DialogDataset.normalize_answer(s) for s in reference] + all_gen_response = [DialogDataset.normalize_answer(s) for s in all_gen_response] + bleu_scores = compute_batch_BLEU(reference, all_gen_response) + metrics.update(bleu_scores) + avg_f1 = sum([f1_score(pred, ref) for pred, ref in + zip(all_gen_response, reference)]) / len(all_gen_response) + metrics.update({"f1": avg_f1}) + else: + metrics = dict() + results = [] + assert len(dataset.data) == len(all_gen_response) == len(all_retrieved_qas) + for input_item, pred, retrieved in zip(dataset.data, all_gen_response, all_retrieved_qas): + results.append({ + "id": input_item["id"], + "input": self.tokenizer.decode(input_item["input_ids"]), + "query": self.tokenizer.decode(input_item["query_ids"]), + "output": {"answer": pred, "provenance": [{"wikipedia_id": "12904"}]}, + "retrieved_qas": [f"question: {qa['question']} answer: {qa['answer'][0]}" for qa in retrieved] + }) + dump_path = os.path.dirname(ckpt_load_path) + dump_path = os.path.join(dump_path, f"{time.strftime('%d %H-%M')}_predict_result.json") + # save_path = os.path.join(args.output_dir, "kilt_test_predict.jsonl") + write_jsonl(results, dump_path) + + return metrics, all_retrieved_qas, all_gen_response + + +@torch.no_grad() +def kilt_generate(model, tokenizer, embedding_index, key_memory, value_memory, dataset: DialogDataset, + qas_to_retrieve, inference_batch_size, gen_kwargs): + model.eval() + + dataloader = dataset.get_query_dataloader(batch_size=inference_batch_size, shuffle=False, + num_workers=1, add_history=True) + all_retrieved_qas = [] + all_gen_response = [] + + for batch in dataloader: + embed_dict = model.CAT_embed_q( + input_ids=batch["query_input_ids"].cuda(), + attention_mask=batch["query_attention_mask"].cuda(), + compute_key=True, compute_value=False + ) + query_embeds = embed_dict["normed_key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, "avg") + query_embeds = query_embeds.half() + bs = len(batch["query_input_ids"]) + # calculate topk in each chunk -> combine all-topk -> select final topk + chunk_top_scores = [] + chunk_top_indices = [] + idx_shift = 0 + if type(embedding_index) != list: + embedding_index = [embedding_index] + for chunk_key_memory in embedding_index: + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_topk = torch.mm(query_embeds, chunk_key_memory_cuda.t()).topk(50, dim=1) + chunk_top_scores.append(chunk_topk.values) # chunk_topk.scores: [query_batch, local_size] + chunk_top_indices.append(chunk_topk.indices + idx_shift) + idx_shift += len(chunk_key_memory) + del chunk_key_memory_cuda + torch.cuda.empty_cache() + chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n + chunk_top_indices = torch.cat(chunk_top_indices, dim=1) # q_batch, local_size*seg_n + topk = 
chunk_top_scores.topk(model.encoder.num_values, dim=1) # q_batch, local_size + top_indices_indices = topk.indices + top_indices = [] + for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices): + top_indices.append([cur_indices[idx] for idx in cur_indices_indices]) + readout_qas = [[qas_to_retrieve[idx] for idx in indices] for indices in top_indices] + + value_qas = [] + for qas in readout_qas: + selected_qas = qas[:model.encoder.num_values] + value_qas.append(selected_qas) + all_retrieved_qas += readout_qas + + top_indices = torch.tensor(top_indices) + memory_size, hidden_num, hidden_size = value_memory.shape + value_embeds = torch.index_select(value_memory, 0, top_indices.view(-1)).float().cuda() + value_embeds = value_embeds.view(bs, model.encoder.num_values, hidden_num, hidden_size) + key_embeds_of_value = torch.index_select(key_memory, 0, top_indices.view(-1)).float().cuda() + key_embeds_of_value = key_embeds_of_value.view(bs, model.encoder.num_values, hidden_num, hidden_size) + encoder_outputs = model.encoder( + input_ids=batch["history_input_ids"].cuda(), + attention_mask=batch["history_attention_mask"].cuda(), + return_dict=True, + value_embeds=value_embeds, + readout_top_k=-1, + key_reduce_method="avg", + value_fusion_method=model.encoder.value_fusion_method, + key_embeds_of_value=key_embeds_of_value + ) + generated_tokens = model.generate( + encoder_outputs=encoder_outputs, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=False, + attention_mask=batch["history_attention_mask"].to(model.device), + value_fusion_method=model.encoder.value_fusion_method, + **gen_kwargs, + ) + decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + decoded_tokens = [ans.strip() for ans in decoded_tokens] + all_gen_response += decoded_tokens + + return all_retrieved_qas, all_gen_response + + +if __name__ == '__main__': + pass diff --git a/kvm_pretrain.py b/kvm_pretrain.py new file mode 100644 index 0000000..b652e4d --- /dev/null +++ b/kvm_pretrain.py @@ -0,0 +1,575 @@ +import argparse +import logging +import math +import os +import pickle +import random +import tempfile +from functools import partial +import time +from itertools import chain + +import datasets +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from datasets import load_dataset, load_metric, DatasetDict +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + AdamW, + get_scheduler, + set_seed, +) + +from transformers.utils.versions import require_version +import json + +from emat.evaluation.eval_retriever import eval_generation_em +from emat.t5 import T5WithKeyValueMemory, CATEncoderOutput +from utils.utils import CATArgs, update_CAT_config_from_args, save_model, get_key_value_encoder_inputs, \ + get_key_value_ae_target, get_qa_inputs, load_model + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + +logger = logging.getLogger(__name__) +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt") + +# Initialise wandb +try: + import wandb + + wandb.login(key="09cf72d9c096a95d573fd2b857bfd601022bc4b7") + os.environ["WANDB_API_KEY"] = "09cf72d9c096a95d573fd2b857bfd601022bc4b7" + wandb.ensure_configured() + if wandb.api.api_key is None: + _has_wandb = False + wandb.termwarn( + "W&B installed but not logged in. 
Run `wandb login` or set the WANDB_API_KEY env variable.") + else: + _has_wandb = False if os.getenv("WANDB_DISABLED") else True +except (ImportError, AttributeError): + _has_wandb = False + logger.info("no WANDB") + exit() + +DATA_PATHS = { + "nq-repaq": { + # "train": "repaq_results/nq.train-train.retriever_multi_base_256.jsonl", + # "validation": "repaq_results/nq.train-dev.retriever_multi_base_256.jsonl", + # "test": "repaq_results/nq.test.retriever_multi_base_256.jsonl" + "train": "./cbqa_data/RePAQ/RePAQ-output-NQ-train.jsonl", + "validation": "./cbqa_data/RePAQ/RePAQ-output-NQ-dev.jsonl", + }, + "PAQ-L1-Pretrain": { + "train": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-train.jsonl", + # "validation": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-dev.jsonl" + "validation": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-dev-3000.jsonl" + }, + "PAQ-L1-Small": { + "train": "./data/cbqa_data/pretrain_data/paq-l1-small-train.jsonl", # 10w examples from PAQ-L1 + "validation": "./data/cbqa_data/pretrain_data/paq-l1-pretrain-dev-3000.jsonl" + }, + "data_for_debug": { + "train": "./data/cbqa_data/pretrain_data/debug.jsonl", + "validation": "./data/cbqa_data/pretrain_data/debug.jsonl" + } +} +PAQ_PATH = "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl" +PAQ_L1_PATH = "./data/cbqa_data/pretrain_data/PAQ_L1/PAQ_L1.filtered.jsonl" + +qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_pickl_file.pkl", 'rb')) + + +# all_data = load_jsonl("./data/cbqa_data/pretrain_data/paq-l1-pretrain-train.jsonl") +# train_data = all_data[:100000] +# write_jsonl(train_data, DATA_PATHS["PAQ-L1-Small"]["train"]) + +def load_pretrain_kvm_data(args) -> DatasetDict: + assert args.pretrain_data_name in DATA_PATHS.keys(), f"available dataset names: {DATA_PATHS.keys()}" + data_paths = DATA_PATHS[args.pretrain_data_name] + data_files = { + "train": [data_paths["train"]], + "validation": [data_paths["validation"]] + } + return load_dataset("json", data_files=data_files) + + +def main(): + # Parse the arguments + # args = parse_args() + cat_args = CATArgs(exp_type="pretrain") + args = cat_args.parse_args() + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + accelerator = Accelerator() + + # Make one log on every process with the configuration for debugging. + + logging.info(f"wandb {'available' if _has_wandb else 'unavailable'}") + + logger.info(accelerator.state) + + # Setup logging, we only want one process per machine to log things on the screen. + # accelerator.is_local_main_process is only True for one process per machine. + logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
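+    # Fixing the seed keeps data shuffling and the random value sampling below reproducible across runs.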
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + if accelerator.is_local_main_process and _has_wandb: + wandb.init(project=args.project_name, name=args.exp_name, dir=args.output_dir, config=vars(args)) + + logging.info("loading model") + config, tokenizer, model = load_model(T5WithKeyValueMemory, args) + logging.info("Loading data.") + + if args.freeze_t5_params: + for name, param in model.named_parameters(): + if 'prefix_embedding' in name or 'key_encoder' in name: + param.requires_grad = True + else: + param.requires_grad = False + + prefix = args.source_prefix if args.source_prefix is not None else "" + + # Temporarily set max_target_length for training. + max_target_length = args.max_target_length + + with_ae_task = args.key_ae_weight > 0.0 and args.value_ae_weight > 0.0 + assert with_ae_task or args.pretrain_multi_values + + def preprocess_pretrain_input_func(examples): + # Normal inputs and outputs + model_inputs = get_key_value_encoder_inputs(examples, args.separate_task, tokenizer, args.max_source_length, + prefix=prefix, only_return_key_inputs=False) + model_inputs.update(get_key_value_ae_target(examples, tokenizer, args.key_ae_target, args.value_ae_target, + max_target_length)) + + return model_inputs + + def preprocess_pretrain_multi_values_input_func(examples, value_with_self_prop=None): + # How to build the dataset + # we hope two kinds of data contained in dataset: + # 1) top-k data, and the golden answer is contained in top-k QAs + # 2) top-k data do not contain the golden answer + # but all data's target are golden + model_inputs = get_qa_inputs(examples, tokenizer, args.max_source_length, max_target_length, prefix=prefix, + targets=None) + value_qas = [] + for ex in examples: + if random.random() < value_with_self_prop: + selected_values_indices = list(ex["retrieved_PAQL1_indices"][:args.num_values]) + else: + selected_values_indices = list(ex["retrieved_PAQL1_indices"][1:args.num_values + 1]) + if not args.values_with_order: + random.shuffle(selected_values_indices) + selected_values = [qas_to_retrieve[idx] for idx in selected_values_indices] + value_qas.append(selected_values) + value_qas = list(chain(*value_qas)) # bs * num_values + group_value_inputs = get_key_value_encoder_inputs(value_qas, args.separate_task, tokenizer, + args.max_source_length, prefix=prefix, + value_input_is_qa=args.value_input_is_qa) + for dk in group_value_inputs: + model_inputs[f"group_value_inputs_{dk}"] = group_value_inputs[dk] + + model_inputs.update( + get_key_value_encoder_inputs(examples, args.separate_task, tokenizer, args.max_source_length, + prefix=prefix, only_return_key_inputs=False, + value_input_is_qa=args.value_input_is_qa)) + + if with_ae_task: + model_inputs.update(get_key_value_ae_target(examples, tokenizer, args.key_ae_target, + args.value_ae_target, max_target_length)) + + return model_inputs + + # Evaluation metric: load the custom exact-match metric + + # Load the training/validation datasets + if args.pretrain_multi_values: + if "debug" not in args.exp_name: + pretrain_data_path = "./tmp/PAQ_L1_with_50_xlarge_retrieved_qa_indices.pkl" + # "tmp/PAQ_L1_with_retrieved_qa_indices.pkl" + raw_datasets = pickle.load(open(pretrain_data_path, "rb")) + train_dataset = raw_datasets[:-5000] + eval_dataset = raw_datasets[-5000:] + else: + raw_datasets = 
pickle.load(open("./tmp/PAQ_L1_5k_with_retrieved_qa_indices.pkl", 'rb')) + train_dataset = raw_datasets[:100] + eval_dataset = raw_datasets[:100] + collate_fn = partial(preprocess_pretrain_multi_values_input_func, + value_with_self_prop=args.value_with_self_prop) + else: + raw_datasets = load_pretrain_kvm_data(args) + train_dataset = raw_datasets["train"] + eval_dataset = raw_datasets["validation"] + collate_fn = preprocess_pretrain_input_func + # DataLoaders creation: + train_dataloader = DataLoader(train_dataset, shuffle=True, + collate_fn=collate_fn, + batch_size=args.per_device_train_batch_size, + num_workers=args.preprocessing_num_workers) + eval_dataloader = DataLoader(eval_dataset, + collate_fn=collate_fn, + batch_size=args.per_device_eval_batch_size, + num_workers=args.preprocessing_num_workers) + + without_self_eval_dataloader = DataLoader(eval_dataset, + collate_fn=partial(preprocess_pretrain_multi_values_input_func, + value_with_self_prop=0.0), + batch_size=args.per_device_eval_batch_size, + num_workers=args.preprocessing_num_workers) + + if not args.do_train and args.do_eval: + model, eval_dataloader = accelerator.prepare(model, eval_dataloader) + metric_key, metric_value, all_gen_ans = evaluate(args, model, config, eval_dataloader, accelerator, tokenizer) + if args.train_key: + key_em_score = metric_key.compute()["em"] * 100 # EM score is not in percentage points + logger.info(f"eval - Key-EM: {key_em_score:.2f}") + if args.train_value: + value_em_score = metric_value.compute()["em"] * 100 # EM score is not in percentage points + logger.info(f"eval - Value-EM: {value_em_score:.2f}") + + em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100 + logger.info(f"em_test: {em_score:.2f}") + exit() + + if args.do_train: + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, }, + {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader + ) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # Train! 
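        # For reference, with the defaults in pretrain_scripts/pretrain_emat.sh on a single GPU
        # (per_device_train_batch_size=128, gradient_accumulation_steps=2), the effective batch
        # size computed below is 128 * 1 * 2 = 256; because max_train_steps is left unset, the
        # number of optimization steps is num_train_epochs * ceil(len(train_dataloader) / 2).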
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + eval_times = 0 + + best_em, patience = None, args.early_stop_patience + for epoch in range(args.num_train_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + assert args.separate_task + loss = 0. + loss_dict = dict() + if args.pretrain_multi_values: + group_inputs = {k.replace("group_value_inputs_", ""): v for k, v in batch.items() if + k.startswith("group_value_inputs_")} + embed_dict = model.wrapped_embed_kv( + separate_task=args.separate_task, compute_key=True, compute_value=True, **group_inputs + ) + key_embeds_of_value = embed_dict["key_embeds"] + value_embeds = embed_dict["value_embeds"] + bs = batch["input_ids"].shape[0] + value_embeds = value_embeds.view(bs, args.num_values, args.prefix_length, -1) + key_embeds_of_value = key_embeds_of_value.view(bs, args.num_values, -1, model.model_dim) + + loss_dict_gen = model.compute_qa_loss( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + labels=batch["labels"], + decoder_only_attend_on_prefix=args.decoder_only_attend_on_prefix, + value_fusion_method=args.value_fusion_method, + encoder_outputs_are_key_or_value=False, + key_reduce_method=args.key_reduce_method, + value_embeds=value_embeds, + key_embeds_of_value=key_embeds_of_value + ) + loss += args.gen_weight * loss_dict_gen["gen_loss"] + loss_dict.update({k: v for k, v in loss_dict_gen.items()}) + if with_ae_task: + loss_dict_ae = model.compute_key_value_ae_loss( + train_key=args.train_key, + train_value=args.train_value, + separate_task=True, + key_input_ids=batch["key_input_ids"], + key_attention_mask=batch["key_attention_mask"], + value_input_ids=batch["value_input_ids"], + value_attention_mask=batch["value_attention_mask"], + key_labels_input_ids=batch["key_labels_input_ids"], + value_labels_input_ids=batch["value_labels_input_ids"], + ) + if args.train_key: + loss += args.key_ae_weight * loss_dict_ae["key_ae_loss"] + if args.train_value: + loss += args.value_ae_weight * loss_dict_ae["value_ae_loss"] + loss_dict.update({k: v.item() for k, v in loss_dict_ae.items()}) + + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + if accelerator.is_local_main_process and _has_wandb: + wandb.log({"loss": loss * args.gradient_accumulation_steps, "step": completed_steps}) + for k, v in loss_dict.items(): + wandb.log({k: v, "step": completed_steps}) + + if completed_steps % 5555 == 0: + metric_key, metric_value, all_gen_ans = evaluate( + args, model, config, eval_dataloader, accelerator, tokenizer) + + scores = [] + if args.train_key: + 
key_em_score = metric_key.compute()["em"] * 100 # EM score is not in percentage points + logger.info(f"epoch {epoch} eval-time {eval_times} - Key-EM: {key_em_score:.2f}") + if accelerator.is_local_main_process and _has_wandb: + wandb.log({"key_em_dev": key_em_score, "eval_times": eval_times}) + scores.append(key_em_score) + if args.train_value: + value_em_score = metric_value.compute()["em"] * 100 # EM score is not in percentage points + logger.info(f"epoch {epoch} eval - Value-EM: {value_em_score:.2f}") + if accelerator.is_local_main_process and _has_wandb: + wandb.log({"value_em_dev": value_em_score, "eval_times": eval_times}) + scores.append(value_em_score) + + em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100 + logger.info(f"em_test: {em_score:.2f}") + wandb.log({"em_test": em_score, "eval_times": eval_times}) + if args.output_dir is not None: + if best_em is None or em_score > best_em: + best_em = em_score + save_model(model, os.path.join(args.output_dir, "best_ckpt"), accelerator, + tokenizer=tokenizer, arguments=args) + + # without self eval + _, _, all_gen_ans = evaluate(args, model, config, without_self_eval_dataloader, + accelerator, tokenizer) + em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100 + logger.info(f"w/o_self_em_test: {em_score:.2f}") + wandb.log({"w/o_self_em_test": em_score, "eval_times": eval_times}) + + eval_times += 1 + + if completed_steps >= args.max_train_steps: + break + + if args.output_dir is not None: + save_model(model, os.path.join(args.output_dir, "latest_ckpt"), accelerator, tokenizer=tokenizer, + arguments=args) + + if args.do_eval and epoch % args.eval_freq == 0: + metric_key, metric_value, all_gen_ans = evaluate(args, model, config, eval_dataloader, accelerator, + tokenizer) + + scores = [] + if args.train_key: + key_em_score = metric_key.compute()["em"] * 100 # EM score is not in percentage points + logger.info(f"epoch {epoch} eval - Key-EM: {key_em_score:.2f}") + if accelerator.is_local_main_process and _has_wandb: + wandb.log({"key_em_dev": key_em_score, "epoch": epoch}) + scores.append(key_em_score) + if args.train_value: + value_em_score = metric_value.compute()["em"] * 100 # EM score is not in percentage points + logger.info(f"epoch {epoch} eval - Value-EM: {value_em_score:.2f}") + if accelerator.is_local_main_process and _has_wandb: + wandb.log({"value_em_dev": value_em_score, "epoch": epoch}) + scores.append(value_em_score) + + em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100 + logger.info(f"em_test: {em_score:.2f}") + wandb.log({"em_test": em_score, "epoch": epoch}) + if args.output_dir is not None: + if best_em is None or em_score > best_em: + best_em = em_score + save_model(model, os.path.join(args.output_dir, "best_ckpt"), accelerator, + tokenizer=tokenizer, + arguments=args) + + # without self eval + _, _, all_gen_ans = evaluate(args, model, config, without_self_eval_dataloader, + accelerator, tokenizer) + em_score = eval_generation_em(eval_dataset, all_gen_ans) * 100 + logger.info(f"w/o_self_em_test: {em_score:.2f}") + wandb.log({"w/o_self_em_test": em_score, "epoch": epoch}) + + if best_em is not None: # Log the best dev EM score + wandb.log({"best_em_dev": best_em}) + + +def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [[label.strip()] for label in labels] + return preds, labels + + +@torch.no_grad() +def evaluate(args, model, config, eval_dataloader, accelerator, tokenizer): + metric_value = load_metric("emat/evaluation/exact_match.py") + metric_key = 
load_metric("emat/evaluation/exact_match.py") + if args.val_max_target_length is None: + args.val_max_target_length = args.max_target_length + + gen_kwargs = { + "max_length": args.val_max_target_length if args is not None else config.max_length, + "num_beams": args.num_beams, + } + all_gen_ans = [] + + for batch in tqdm(eval_dataloader): + model.eval() + group_inputs = {k.replace("group_value_inputs_", ""): v for k, v in batch.items() if + k.startswith("group_value_inputs_")} + embed_dict = model.wrapped_embed_kv( + separate_task=args.separate_task, compute_key=True, compute_value=True, **group_inputs + ) + key_embeds_of_value = embed_dict["key_embeds"] + value_embeds = embed_dict["value_embeds"] + bs = batch["input_ids"].shape[0] + value_embeds = value_embeds.view(bs, args.num_values, args.prefix_length, -1) + key_embeds_of_value = key_embeds_of_value.view(bs, args.num_values, -1, model.model_dim) + encoder_outputs = model.encoder( + input_ids=batch["input_ids"].to(model.device), + attention_mask=batch["attention_mask"].to(model.device), + return_dict=True, + value_embeds=value_embeds, + readout_top_k=-1, + key_reduce_method=args.key_reduce_method, + value_fusion_method=args.value_fusion_method, + key_embeds_of_value=key_embeds_of_value + ) + generated_tokens = accelerator.unwrap_model(model).generate( + encoder_outputs=encoder_outputs, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=args.decoder_only_attend_on_prefix, + attention_mask=batch["attention_mask"].to(model.device), + value_fusion_method=args.value_fusion_method, + **gen_kwargs, + ) + generated_tokens = accelerator.pad_across_processes( + generated_tokens, dim=1, pad_index=tokenizer.pad_token_id + ) + generated_tokens = accelerator.gather(generated_tokens).cpu().numpy() + decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + decoded_tokens = [ans.strip() for ans in decoded_tokens] + all_gen_ans += decoded_tokens + + # Auto-Encoding loss + embed_dict = model.wrapped_embed_kv( + separate_task=args.separate_task, + key_input_ids=batch["key_input_ids"], + key_attention_mask=batch["key_attention_mask"], + value_input_ids=batch["value_input_ids"], + value_attention_mask=batch["value_attention_mask"], + compute_key=args.train_key, + compute_value=args.train_value, + embed_for_ae_task=True + ) + key_embeds = embed_dict["normed_key_embeds"] + value_embeds = embed_dict["normed_value_embeds"] # normed value for generation + + if args.train_key: + key_labels = batch["key_labels_input_ids"] + key_embeds = key_embeds.view(key_embeds.shape[0], -1, model.model_dim) + + recovered_from_key = accelerator.unwrap_model(model).generate( + encoder_outputs=CATEncoderOutput(last_hidden_state=key_embeds, hidden_states=None, attentions=None), + attention_mask=None, encoder_outputs_are_key_or_value=True, **gen_kwargs, + ) + recovered_from_key = accelerator.pad_across_processes( + recovered_from_key, dim=1, pad_index=tokenizer.pad_token_id + ) + if not args.pad_to_max_length: + # If we did not pad to max length, we need to pad the labels too + key_labels = accelerator.pad_across_processes(key_labels, dim=1, pad_index=tokenizer.pad_token_id) + recovered_from_key = accelerator.gather(recovered_from_key).cpu().numpy() + key_labels = accelerator.gather(key_labels).cpu().numpy() + if args.ignore_pad_token_for_loss: + # Replace -100 in the labels as we can't decode them. 
+ key_labels = np.where(key_labels != -100, key_labels, tokenizer.pad_token_id) + decoded_key = tokenizer.batch_decode(recovered_from_key, skip_special_tokens=True) + decoded_key_labels = tokenizer.batch_decode(key_labels, skip_special_tokens=True) + decoded_key, decoded_key_labels = postprocess_text(decoded_key, decoded_key_labels) + metric_key.add_batch(predictions=decoded_key, references=decoded_key_labels) + + if args.train_value: + value_labels = batch["value_labels_input_ids"] + value_embeds = value_embeds.view(value_embeds.shape[0], -1, model.model_dim) + recovered_from_value = accelerator.unwrap_model(model).generate( + encoder_outputs=CATEncoderOutput(last_hidden_state=value_embeds, hidden_states=None, attentions=None), + attention_mask=None, encoder_outputs_are_key_or_value=True, **gen_kwargs, + ) + recovered_from_value = accelerator.pad_across_processes( + recovered_from_value, dim=1, pad_index=tokenizer.pad_token_id + ) + if not args.pad_to_max_length: + # If we did not pad to max length, we need to pad the labels too + value_labels = accelerator.pad_across_processes(value_labels, dim=1, pad_index=tokenizer.pad_token_id) + recovered_from_value = accelerator.gather(recovered_from_value).cpu().numpy() + value_labels = accelerator.gather(value_labels).cpu().numpy() + if args.ignore_pad_token_for_loss: + # Replace -100 in the labels as we can't decode them. + value_labels = np.where(value_labels != -100, value_labels, tokenizer.pad_token_id) + decoded_value = tokenizer.batch_decode(recovered_from_value, skip_special_tokens=True) + decoded_value_labels = tokenizer.batch_decode(value_labels, skip_special_tokens=True) + decoded_value, decoded_value_labels = postprocess_text(decoded_value, decoded_value_labels) + metric_value.add_batch(predictions=decoded_value, references=decoded_value_labels) + + return metric_key, metric_value, all_gen_ans + + +if __name__ == "__main__": + main() diff --git a/pretrain_scripts/pretrain_emat.sh b/pretrain_scripts/pretrain_emat.sh new file mode 100644 index 0000000..dbcc661 --- /dev/null +++ b/pretrain_scripts/pretrain_emat.sh @@ -0,0 +1,50 @@ +#!/bin/bash -l + +set -e +set -u + +DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +EXP_NAME="KL=3;kdim=1536;VL=7;VN=10;async_cat_k+v;t5-base;" +LOAD_EXP_NAME="t5-base" # t5-base dir +DATA_NAME="PAQ-L1-Pretrain" +DEVICE="0" + +echo ${DEVICE} + +CUDA_VISIBLE_DEVICES=${DEVICE} python kvm_pretrain.py \ + --project_name="CAT" \ + --pretrain_multi_values \ + --exp_name="${EXP_NAME}" \ + --pretrain_data_name=${DATA_NAME} \ + --num_values=10 \ + --value_layer=7 \ + --key_layer=3 \ + --value_fusion_method="cat_k_delay+v" \ + --key_reduce_method="avg" \ + --model_name_or_path=${LOAD_EXP_NAME} \ + --source_prefix="question: " \ + --per_device_train_batch_size=128 \ + --per_device_eval_batch_size=256 \ + --gradient_accumulation_steps=2 \ + --preprocessing_num_workers=10 \ + --learning_rate=5e-5 \ + --num_train_epochs=5 \ + --lr_scheduler_type="linear" \ + --num_warmup_steps=5000 \ + --output_dir="./outputs/checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_encoder_type="conv" \ + --seed=42 \ + --gen_weight=1.0 \ + --key_ae_weight=0.5 \ + --value_ae_weight=0.5 \ + --value_with_self_prop=0.1 \ + --key_ae_target="question" \ + --value_ae_target="ans" \ + --do_train \ + --separate_task \ + --train_key \ + --train_value diff --git a/pretrain_scripts/pretrain_sksv_emat.sh b/pretrain_scripts/pretrain_sksv_emat.sh new file mode 100644 index 0000000..5b3e05e --- 
/dev/null +++ b/pretrain_scripts/pretrain_sksv_emat.sh @@ -0,0 +1,52 @@ +#!/bin/bash -l + +set -e +set -u + +DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +LOAD_EXP_NAME="t5-base" # t5-base dir +EXP_NAME="KL=3;kdim=1536;CL=10;VL=11;VN=10;async_cat_k+v;t5-base;" +DATA_NAME="PAQ-L1-Pretrain" +DEVICE="0" + +echo ${DEVICE} + +CUDA_VISIBLE_DEVICES=${DEVICE} python kvm_pretrain.py \ + --key_layer=3 \ + --cat_layer=10 \ + --value_layer=11 \ + --value_fusion_method="async_cat_k_delay+v" \ + --project_name="CAT" \ + --pretrain_multi_values \ + --exp_name="${EXP_NAME}" \ + --pretrain_data_name=${DATA_NAME} \ + --num_values=10 \ + --key_reduce_method="avg" \ + --model_name_or_path="${LOAD_EXP_NAME}" \ + --source_prefix="question: " \ + --per_device_train_batch_size=128 \ + --per_device_eval_batch_size=256 \ + --gradient_accumulation_steps=2 \ + --preprocessing_num_workers=10 \ + --learning_rate=5e-5 \ + --num_train_epochs=5 \ + --lr_scheduler_type="constant" \ + --num_warmup_steps=5000 \ + --output_dir="./outputs/checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_encoder_type="conv" \ + --seed=42 \ + --gen_weight=1.0 \ + --key_ae_weight=0.5 \ + --value_ae_weight=0.5 \ + --value_with_self_prop=0.1 \ + --key_ae_target="question" \ + --value_ae_target="ans" \ + --do_train \ + --separate_task \ + --train_key \ + --train_value \ + --do_eval diff --git a/qa_dataset.py b/qa_dataset.py new file mode 100644 index 0000000..fbd14bd --- /dev/null +++ b/qa_dataset.py @@ -0,0 +1,195 @@ +from itertools import chain +import torch +from torch.utils.data import Dataset, DataLoader +from typing import List, Dict +import random + +from emat.evaluation.exact_match import normalize_answer +from utils.utils import process_labels +# from transformers.models.t5 import T5Tokenizer +from transformers import T5Tokenizer + + +def format_data(item, label2str, str2label, dataset_name, task): + if dataset_name == "commonsense_qa": + pass + + +class QADataset(Dataset): + def __init__( + self, + data: List[Dict], + tokenizer: T5Tokenizer, + qas_to_retrieve, + dataset_name, + retrieve_strategy="dr", + max_source_length=None, + args=None, + normed_answer_of_qas_to_ret=None, + ): + super(QADataset, self).__init__() + self.data: List[Dict] = data + + for idx, i in enumerate(self.data): + i["idx"] = idx + i["normalized_answer"] = [normalize_answer(ans) for ans in i["answer"]] + assert dataset_name in ["nq", "tq", "wq"] + self.max_source_length = max_source_length if max_source_length is not None else 430 + self.dataset_name = dataset_name + print(f"dataset-name: {dataset_name}") + self.max_target_length = 64 + self.tokenizer = tokenizer + self.pad_idx = self.tokenizer.pad_token_id + self.label_pad_idx = -100 + self.args = args + self.add_ae_input = False + self.qas_to_retrieve = qas_to_retrieve + self.normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret + + self.pad_qa = {"question": "", "answer": [""]} + + def get_key_value_inputs(self, qas, only_return_key_inputs=False): + # Used to get the input of Key-Value Encoder, qas are from PAQ-L1 + key_inputs = ["question: " + qa["question"] for qa in qas] + key_inputs = self.tokenizer(key_inputs, max_length=self.max_source_length, + padding=True, truncation=True, return_tensors="pt") + if only_return_key_inputs: + return {"key_input_ids": key_inputs["input_ids"], + "key_attention_mask": key_inputs["attention_mask"]} + else: + value_inputs = ["answer: " + qa["answer"][0] for qa in qas] + value_inputs = 
self.tokenizer(value_inputs, max_length=self.max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"key_input_ids": key_inputs["input_ids"], + "key_attention_mask": key_inputs["attention_mask"], + "value_input_ids": value_inputs["input_ids"], + "value_attention_mask": value_inputs["attention_mask"]} + + def get_query_inputs(self, batch): + query_inputs = ["question: " + qa["question"] for qa in batch] + query_inputs = self.tokenizer(query_inputs, max_length=self.max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"query_input_ids": query_inputs["input_ids"], + "query_attention_mask": query_inputs["attention_mask"]} + + def get_dataloader(self, batch_size, shuffle, num_workers): + + def base_collate_fn(batch): + + original_batch_size, filtered_batch_size = len(batch), len(batch) + if not self.args.use_not_exactly_true: + batch = [ex for ex in batch if len(ex["local_positive"]) > 0] + filtered_batch_size = len(batch) + while len(batch) == 0: # avoid empty-batch + batch = random.sample(self.data, batch_size) + batch = [ex for ex in batch if len(ex["local_positive"]) > 0] + # do not change filtered_batch_size even change the batch again. + + model_inputs = { + "batch_data_ids": torch.tensor([qa["idx"] for qa in batch]), + "trainable_percentage": torch.tensor(filtered_batch_size / original_batch_size).repeat(len(batch)), + # repeat ``len(batch)`` times to compatible in multi-GPUs. + } + model_inputs.update(self.get_query_inputs(batch)) + + batch_local_positive_num = self.args.batch_local_positive_num + neg_num_each_example = self.args.negatives_num_each_example + local_positive_qas = [] + local_positive_num = [] + local_positive_qas_mask = [] + local_negative_qas = [] + local_pos_mix_neg_qas = [] # num = neg_num_each_example + for ex in batch: + cur_local_positive_qas_ids = [idx for idx in ex["local_positive"][:batch_local_positive_num]] + cur_local_positive_qas = [self.qas_to_retrieve[idx] for idx in cur_local_positive_qas_ids] + cur_pos_num = len(cur_local_positive_qas) + local_positive_num.append(cur_pos_num) + + cur_local_negative_qas_idx = random.sample(ex["local_negative"], neg_num_each_example) + cur_local_negative_qas = [self.qas_to_retrieve[idx] for idx in cur_local_negative_qas_idx] + local_negative_qas.append(cur_local_negative_qas) + cur_local_pos_mix_neg_qas = cur_local_positive_qas + \ + cur_local_negative_qas[:neg_num_each_example - cur_pos_num] + local_pos_mix_neg_qas.append(cur_local_pos_mix_neg_qas) + + cur_pad_num = batch_local_positive_num - cur_pos_num + cur_local_positive_qas_mask = [1] * cur_pos_num + [0] * cur_pad_num + local_positive_qas_mask.append(cur_local_positive_qas_mask) + cur_local_positive_qas.extend([self.pad_qa] * cur_pad_num) + local_positive_qas.append(cur_local_positive_qas) + + model_inputs.update({"local_positive_qas_mask": torch.tensor(local_positive_qas_mask), + "local_positive_num": torch.tensor(local_positive_num), }) + if self.dataset_name == "tq" or self.dataset_name == "wq": + squeezed_positive_qas = list(chain(*local_positive_qas)) + squeezed_positive_target = [qa["answer"][0] for qa in squeezed_positive_qas] + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(squeezed_positive_target, max_length=self.max_target_length, + padding=True, truncation=True, return_tensors="pt") + model_inputs["labels_to_select"] = process_labels(targets, self.tokenizer). 
\ + view(len(batch), batch_local_positive_num, -1) + else: + targets = [random.choice(qa["answer"]) for qa in batch] + with self.tokenizer.as_target_tokenizer(): + targets = self.tokenizer(targets, max_length=self.max_target_length, + padding=True, truncation=True, return_tensors="pt") + model_inputs["labels"] = process_labels(targets, self.tokenizer) + + assert self.args.select_positive_strategy == "softmax_sample" + squeezed_positive_qas = list(chain(*local_positive_qas)) + local_positive_inputs = self.get_key_value_inputs(squeezed_positive_qas, only_return_key_inputs=True) + model_inputs.update({f"local_positive_inputs_{k}": v.view(len(batch), batch_local_positive_num, -1) + for k, v in local_positive_inputs.items()}) + + squeezed_negative_qas = list(chain(*local_negative_qas)) + local_negative_inputs = self.get_key_value_inputs(squeezed_negative_qas, only_return_key_inputs=True) + model_inputs.update({f"local_negative_inputs_{k}": v.view(len(batch), neg_num_each_example, -1) + for k, v in local_negative_inputs.items()}) + + squeezed_mixed_qas = list(chain(*local_pos_mix_neg_qas)) + local_mixed_inputs = self.get_key_value_inputs(squeezed_mixed_qas) + model_inputs.update({f"local_mixed_inputs_{k}": v.view(len(batch), neg_num_each_example, -1) + for k, v in local_mixed_inputs.items()}) + if self.dataset_name == "tq": + all_targets = [[normalize_answer(an) for an in qa["answer"]] for qa in batch] + negative_qas_answer = [normalize_answer(nqa["answer"][0]) for nqa in squeezed_negative_qas] + negative_mask = [[1 if neg_ans not in cur_all_target else 0 for neg_ans in negative_qas_answer] + for cur_all_target in all_targets] + model_inputs.update({"negative_mask": torch.tensor(negative_mask)}) + + # for multi-GPUs + assert all(model_inputs[k].shape[0] == len(batch) for k in model_inputs.keys()) + + return model_inputs + + return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, + collate_fn=base_collate_fn, pin_memory=True) + + def get_query_dataloader(self, batch_size, shuffle, num_workers): + + def query_collate_fn(batch): + return self.get_query_inputs(batch) + + return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, + collate_fn=query_collate_fn, pin_memory=True) + + def get_local_qas_dataloader(self, batch_size, shuffle, num_workers): + + def local_qas_collate_fn(batch): + # model_inputs = self.get_query_inputs(batch) + local_qas = [[self.qas_to_retrieve[qid] for qid in ex['local_qas']] for ex in batch] + query_ids = [ex["idx"] for ex in batch] + squeezed_local_qas = list(chain(*local_qas)) + squeezed_local_qas_inputs = self.get_key_value_inputs(squeezed_local_qas, only_return_key_inputs=True) + # model_inputs.update(squeezed_local_qas_inputs) + # return model_inputs + return {**squeezed_local_qas_inputs, "query_ids": torch.tensor(query_ids)} + + return DataLoader(dataset=self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, + collate_fn=local_qas_collate_fn, pin_memory=True) + + def __len__(self): + return len(self.data) + + def __getitem__(self, item): + return self.data[item] diff --git a/qa_main.py b/qa_main.py new file mode 100644 index 0000000..b16cdb1 --- /dev/null +++ b/qa_main.py @@ -0,0 +1,145 @@ +import json +import os +import pickle +from transformers import T5Tokenizer +from emat.utils import load_jsonl +from utils.utils import CATArgs +from qa_dataset import QADataset +from qa_trainer import QATrainer +import logging + +logging.basicConfig( + format="%(asctime)s - %(levelname)s 
- %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) + +DATA_PATHS = { + "nq": { + "train": "./data/annotated_datasets/NQ-open.train-train.jsonl", + "validation": "./data/annotated_datasets/NQ-open.train-dev.jsonl", + "test": "./data/annotated_datasets/NQ-open.test.jsonl" + }, + "tq": { + "train": "./data/annotated_datasets/triviaqa.train-train.jsonl", + "validation": "./data/annotated_datasets/triviaqa.train-dev.jsonl", + "test": "./data/annotated_datasets/triviaqa.test.jsonl" + }, + "wq": { + "train": "./data/annotated_datasets/WQ-trainmodel.jsonl", + "validation": "./data/annotated_datasets/WQ-val.jsonl", + "test": "./data/annotated_datasets/WQ-test.jsonl" + } +} +QA_KB_PATHS = { + "PAQ_L1": "./tmp/PAQ_L1_pickl_file.pkl", + "PAQ": "./tmp/PAQ_full.pkl", + "TAQ_TRAIN_NQ_TRAIN_PAQ": "./data/paq/TQA_TRAIN_NQ_TRAIN_PAQ/tqa-train-nq-train-PAQ.jsonl", +} +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + + +def load_dataset(args): + assert args.qa_data_name in DATA_PATHS.keys(), f"available dataset names: {DATA_PATHS.keys()}" + data_paths = DATA_PATHS[args.qa_data_name] + test_data = load_jsonl(DATA_PATHS[args.qa_data_name]["test"]) + train_data = load_jsonl(data_paths["train"]) + dev_data = load_jsonl(data_paths["validation"]) + + logging.info("loading normed answer of qas to retrieve") + if "PAQ" == args.qas_to_retrieve_from: + normed_answer_of_qas_to_ret = pickle.load(open("./tmp/PAQ_only_normalized_answer.pkl", 'rb')) + else: + normed_answer_of_qas_to_ret = json.load(open("./tmp/PAQL1_only_normalized_answer.json", 'r')) + + logging.info("loading qas to retrieve") + if "debug" in args.exp_name.lower() or "full-paq-test" in args.exp_name.lower(): + if not os.path.exists("./tmp/PAQ_L1_small.pkl"): + qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_pickl_file.pkl", 'rb')) + qas_to_retrieve = qas_to_retrieve[:len(qas_to_retrieve) // 14] + pickle.dump(qas_to_retrieve, open("./tmp/PAQ_L1_small.pkl", 'wb')) + else: + qas_to_retrieve = pickle.load(open("./tmp/PAQ_L1_small.pkl", 'rb')) + else: + qas_to_retrieve_fp = QA_KB_PATHS[args.qas_to_retrieve_from] + logging.info(f"loading qas from {qas_to_retrieve_fp}") + if qas_to_retrieve_fp.endswith("pkl"): + qas_to_retrieve = pickle.load(open(qas_to_retrieve_fp, 'rb')) + elif qas_to_retrieve_fp.endswith("jsonl"): + qas_to_retrieve = load_jsonl(qas_to_retrieve_fp) + else: + raise ValueError(f"{qas_to_retrieve_fp}") + + if "debug" in args.exp_name.lower(): + train_data = train_data[:100] + dev_data = dev_data[:500] + qas_to_retrieve = qas_to_retrieve[:10000] + normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:len(qas_to_retrieve)] + + if args.qas_to_retrieve_from == "PAQ" and args.PAQ_size is not None: + qas_to_retrieve = qas_to_retrieve[:args.PAQ_size] + normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret[:args.PAQ_size] + assert len(qas_to_retrieve) == args.PAQ_size + logging.info(f"select {args.PAQ_size}-size PAQ.") + + assert len(normed_answer_of_qas_to_ret) == len(qas_to_retrieve) + loaded_data = { + "train": train_data, "validation": dev_data, "test": test_data, + "qas_to_retrieve": qas_to_retrieve, + "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret + } + + return loaded_data + + +def main(): + cat_args = CATArgs("qa_cat") + args = cat_args.parse_args() + loaded_data = load_dataset(args) + logging.info("data loaded.") + train_data, dev_data, test_data = loaded_data["train"], loaded_data["validation"], loaded_data["test"] + qas_to_retrieve = loaded_data["qas_to_retrieve"] + normed_answer_of_qas_to_ret = 
loaded_data["normed_answer_of_qas_to_ret"] + + dataset_kwargs = { + "max_source_length": args.max_source_length, + "dataset_name": args.qa_data_name, + "args": args, + "normed_answer_of_qas_to_ret": normed_answer_of_qas_to_ret, + } + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + train_dataset = QADataset(train_data, tokenizer, qas_to_retrieve, **dataset_kwargs) + dev_dataset = QADataset(dev_data, tokenizer, qas_to_retrieve, **dataset_kwargs) + test_dataset = QADataset(test_data, tokenizer, qas_to_retrieve, **dataset_kwargs) + + qa_trainer = QATrainer(args, train_dataset, dev_dataset, test_dataset, qas_to_retrieve, normed_answer_of_qas_to_ret) + + if args.do_train: + qa_trainer.train() + elif args.do_test: + logging.info("Only do test.") + ckpt_load_path = os.path.join(args.output_dir, "best_ckpt/pytorch_model.bin") + em_score, match_metric, ret_qas, gen_ans = qa_trainer.evaluate( + qa_trainer.test_dataset, extend_mem_from="train_dev", + update_key_memory=True, ckpt_load_path=ckpt_load_path + + ) + + logging.info(f"em_test: {em_score:.3f}") + for k, v in match_metric.items(): + logging.info(f"test_{k}: {v}") + results = [] + for idx, (input_qa, retrieved_qas, predict_answer) in enumerate(zip(qa_trainer.test_dataset, ret_qas, gen_ans)): + results.append({ + "idx": idx, + "question": input_qa["question"], + "answer": input_qa["answer"], + "retrieved_qas": [{"question": qa["question"], "answer": qa["answer"][0]} for qa in retrieved_qas], + "generate_answer": predict_answer, + }) + json.dump(results, open(os.path.join(args.output_dir, "best_ckpt/predict_results.json"), 'w'), + indent=4, ensure_ascii=False) + + +if __name__ == '__main__': + main() diff --git a/qa_trainer.py b/qa_trainer.py new file mode 100644 index 0000000..994e0ee --- /dev/null +++ b/qa_trainer.py @@ -0,0 +1,565 @@ +import copy +import logging +import math +import os +import random +from itertools import chain +from typing import List, Dict, Optional +from emat.evaluation.eval_retriever import eval_retriever, eval_generation_em +from utils.dr_utils import update_local_qas_to_retrieve, update_batch_inputs, rank_exist_local_qas +from utils.utils import reduce_query_or_key_embeds +import datasets +import torch +import transformers +from accelerate import Accelerator +from tqdm.auto import tqdm +from transformers import AdamW, get_scheduler, set_seed +from utils.utils import save_model, load_model +from build_kvm import build_memory +from emat.t5 import T5WithKeyValueMemory +from qa_dataset import QADataset + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + +try: + import wandb + + wandb.ensure_configured() + if wandb.api.api_key is None: + _has_wandb = False + wandb.termwarn( + "W&B installed but not logged in. 
Run `wandb login` or set the WANDB_API_KEY env variable.") + else: + _has_wandb = False if os.getenv("WANDB_DISABLED") else True +except (ImportError, AttributeError): + _has_wandb = False + + +class QATrainer: + + def __init__( + self, + args, + train_dataset: QADataset, + dev_dataset: QADataset, + test_dataset: QADataset, + qas_to_retrieve: List[Dict], + normed_answer_of_qas_to_ret, + ): + accelerator = Accelerator() + logging.info(f"wandb {'available' if _has_wandb else 'unavailable'}") + logger.info(accelerator.state) + logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + if args.seed is not None: + set_seed(args.seed) + else: + logging.info("Not set seed.") + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + if accelerator.is_local_main_process and _has_wandb: + wandb.init(project=args.project_name, name=args.exp_name, dir=args.output_dir, config=vars(args)) + + logging.info("loading model") + config, tokenizer, self.model = load_model(T5WithKeyValueMemory, args) + logging.info("Loading model.") + logging.info(f"model params: {self.model.num_parameters()}") + if args.freeze_t5_params: + logging.info("Freeze T5 parameters.") + self.model.freeze_t5_params() + if args.only_train_adapter: + for param in self.model.parameters(): + param.requires_grad = False + for param in self.model.adapter.parameters(): + param.requires_grad = True + + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + {"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, }, + {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, }, + ] + self.optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # reused_key_memory: pre-allocated memory to store full key_memory + self.reused_key_memory = torch.zeros((len(qas_to_retrieve), self.model.model_dim), + device="cpu", dtype=torch.float16) + self.train_data_query_embeds = torch.zeros((len(train_dataset), self.model.model_dim), + device="cpu", dtype=torch.float16) + self.key_memory: Optional[List[torch.tensor]] = None + self.key_memory = [] + for start_idx in range(0, len(qas_to_retrieve), math.ceil(len(qas_to_retrieve) / args.kvm_seg_n)): + self.key_memory.append( + self.reused_key_memory[start_idx: start_idx + math.ceil(len(qas_to_retrieve) / args.kvm_seg_n)] + ) + logger.info(f"key num = {sum(len(i) for i in self.key_memory)}") + + self.train_dataset = train_dataset + self.dev_dataset = dev_dataset + self.test_dataset = test_dataset + self.args = args + self.accelerator = accelerator + self.tokenizer = tokenizer + self.qas_to_retrieve = qas_to_retrieve + self.prefix = args.source_prefix if args.source_prefix is not None else "" + assert self.prefix == "question: " + # self.query_batch_size = 550 if args.kvm_seg_n > 2 else 256 + # self.query_batch_size = 1024 if args.kvm_seg_n >= 4 else self.query_batch_size + # self.query_batch_size = 3000 if args.kvm_seg_n >= 7 else self.query_batch_size + # # if len(self.qas_to_retrieve) < 20000000: + # self.query_batch_size = 512 + self.query_batch_size = args.query_batch_size + 
logger.info(f"PAQ-size: {len(self.qas_to_retrieve)}. PAQ's query batch size: {self.query_batch_size}.") + self.normed_answer_of_qas_to_ret = normed_answer_of_qas_to_ret + self.model = self.accelerator.prepare(self.model) + + @torch.no_grad() + def update_key_memory(self, use_fp16_model=True, use_retrieval_adapter=False): + args = self.args + if use_fp16_model: + tmp_model = copy.deepcopy(self.model) + tmp_model = tmp_model.half() + else: + tmp_model = self.model + build_mem_batch_size = args.build_mem_batch_size + tmp_model.eval() + self.key_memory, _ = build_memory( + tmp_model, self.tokenizer, embed_key=True, embed_value=False, prefix=self.prefix, embed_as_fp16=True, + key_reduce_method=args.key_reduce_method, return_memory=True, dump_memory=False, + data_to_embed=self.qas_to_retrieve, max_source_length=args.max_source_length, padding=True, + batch_size=build_mem_batch_size, separate_task=True, kvm_seg_n=args.kvm_seg_n, + reused_key_memory=self.reused_key_memory, use_retrieval_adapter=use_retrieval_adapter + ) + if type(self.key_memory) is not list: + self.key_memory = [self.key_memory] + del tmp_model + + @torch.no_grad() + def update_local_qas(self, epoch, use_fp16_model=True, use_retrieval_adapter=False): + args = self.args + if use_fp16_model: + tmp_model = copy.deepcopy(self.model) + tmp_model = tmp_model.half() + else: + tmp_model = self.model + build_mem_batch_size = args.build_mem_batch_size + tmp_model.eval() + if args.update_kv_embeds and args.update_local_qas and epoch >= args.repaq_supervision_epoch: + update_local_qas_to_retrieve( + args, self.train_dataset, self.qas_to_retrieve, tmp_model, self.key_memory, + self.normed_answer_of_qas_to_ret, train_data_query_embeds=self.train_data_query_embeds, + build_mem_batch_size=build_mem_batch_size, query_batch_size=self.query_batch_size, + local_size=args.local_size, pos_from_top=args.pos_from_top, neg_from_top=200, + use_retrieval_adapter=use_retrieval_adapter + ) + elif args.only_rank_exists_local_qa: + logging.warning("Do not use!") + embed_local_qas_batch_size = (build_mem_batch_size // + len(self.train_dataset.data[0]["local_qas"]) + 1) * 2 + rank_exist_local_qas(args, self.train_dataset, self.qas_to_retrieve, tmp_model, + self.normed_answer_of_qas_to_ret, build_mem_batch_size=build_mem_batch_size, + train_data_query_embeds=self.train_data_query_embeds, + embed_local_qas_batch_size=embed_local_qas_batch_size, + local_size=args.local_size, pos_from_top=args.pos_from_top, neg_from_top=200, + accelerator=self.accelerator) + del tmp_model + + def train(self): + args = self.args + tokenizer = self.tokenizer + num_workers = 5 + if args.update_kv_embeds and not args.only_rank_exists_local_qa: + logging.info("Build Memory") + self.update_key_memory() + train_dataloader = self.train_dataset.get_dataloader(batch_size=args.per_device_train_batch_size, + shuffle=True, num_workers=num_workers) + + optimizer, train_dataloader = self.accelerator.prepare(self.optimizer, train_dataloader) + + # Scheduler and math around the number of training steps. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + else: + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # Train! 
+ total_batch_size = args.per_device_train_batch_size * self.accelerator.num_processes \ + * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(self.train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not self.accelerator.is_local_main_process) + completed_steps = 0 + best_hit_at_1, best_em, patience = None, None, args.early_stop_patience + + for epoch in range(args.num_train_epochs): + use_adapter_to_select_positive = epoch >= args.use_adapter_to_select_positive_after_k_epoch + if args.only_train_adapter: + if epoch == 0: + self.update_local_qas(epoch) + # convert original key memory with high dim to low dim through adapter + qas_num = self.reused_key_memory.shape[0] + train_qas_num = len(self.train_dataset) + dim = args.adapter_out_dim + self.reused_key_memory = torch.zeros((qas_num, dim), device="cpu", dtype=torch.float16) + self.train_data_query_embeds = torch.zeros((train_qas_num, dim), device="cpu", dtype=torch.float16) + self.key_memory: Optional[List[torch.tensor]] = None + self.key_memory = [] + for start_idx in range(0, qas_num, math.ceil(qas_num / args.kvm_seg_n)): + self.key_memory.append( + self.reused_key_memory[start_idx: start_idx + math.ceil(qas_num / args.kvm_seg_n)] + ) + self.update_key_memory(use_retrieval_adapter=True) + elif use_adapter_to_select_positive: + self.update_local_qas(epoch, use_retrieval_adapter=True) + else: + if args.qas_to_retrieve_from == "PAQ" and (epoch % 3 == 0): + self.update_local_qas(epoch) + elif args.qas_to_retrieve_from != "PAQ": + self.update_local_qas(epoch) + for step, batch in enumerate(train_dataloader): + + update_batch_inputs(args, batch, self.model, + use_adapter_to_select_positive=use_adapter_to_select_positive) + self.model.train() + if args.match_weight > 0.0: + # Embed Positive Key and the Value to input. 
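                    # More precisely, the two wrapped_embed_kv calls below embed keys only
                    # (compute_value=False): one for the positive QA candidates and one for the
                    # sampled negatives prepared by the collate function / update_batch_inputs,
                    # each reduced to a single vector per QA via key_reduce_method ("avg" in the
                    # provided scripts) before computing the key-matching loss.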
+ embed_dict = self.model.wrapped_embed_kv( # assert num_values > 1, otherwise set compute_value=True + separate_task=args.separate_task, compute_key=True, compute_value=False, + **batch.pop("positive_kv_inputs") + ) + positive_key_embeds = embed_dict["normed_key_embeds"] + positive_key_embeds = reduce_query_or_key_embeds(positive_key_embeds, args.key_reduce_method) + # Embed Negative Key + embed_dict = self.model.wrapped_embed_kv( + separate_task=args.separate_task, compute_key=True, compute_value=False, + **batch.pop("negative_kv_inputs") + ) + negative_key_embeds = embed_dict["normed_key_embeds"] + negative_key_embeds = reduce_query_or_key_embeds(negative_key_embeds, args.key_reduce_method) + else: + negative_key_embeds, positive_key_embeds = None, None + # Embed retrieved-Key-Value + embed_dict = self.model.wrapped_embed_kv( + separate_task=args.separate_task, compute_key=True, compute_value=True, + **batch.pop("group_value_inputs") + ) + key_embeds_of_value = embed_dict["key_embeds"] + value_embeds = embed_dict["value_embeds"] + bs = batch["query_input_ids"].shape[0] + value_embeds = value_embeds.view(bs, args.num_values, args.prefix_length, -1) + key_embeds_of_value = key_embeds_of_value.view(bs, args.num_values, -1, self.model.model_dim) + + loss_dict = self.model.compute_qa_loss( + input_ids=batch["query_input_ids"], + attention_mask=batch["query_attention_mask"], + labels=batch["labels"], + decoder_only_attend_on_prefix=args.decoder_only_attend_on_prefix, + value_fusion_method=args.value_fusion_method, + encoder_outputs_are_key_or_value=False, + key_reduce_method=args.key_reduce_method, + positive_key_embeds=positive_key_embeds, + negative_key_embeds=negative_key_embeds, + value_embeds=value_embeds, + matching_targets=batch["matching_targets"], + key_embeds_of_value=key_embeds_of_value, + negative_mask=batch.get("negative_mask", None), + only_train_adapter=args.only_train_adapter + ) + if args.match_weight > 0.0: + if epoch >= args.only_key_matching_n_epoch: + loss = args.gen_weight * loss_dict["gen_loss"] + args.match_weight * loss_dict["match_loss"] + else: + loss = args.match_weight * loss_dict["match_loss"] + else: + loss = loss_dict["gen_loss"] + loss_dict = {k: v.item() for k, v in loss_dict.items()} + loss = loss / args.gradient_accumulation_steps + self.accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + if self.accelerator.is_local_main_process and _has_wandb: + wandb.log({"loss": loss * args.gradient_accumulation_steps, "step": completed_steps}) + wandb.log({"trainable_percentage": batch["trainable_percentage"][0].item(), + "forward_step": completed_steps}) + for k, v in loss_dict.items(): + wandb.log({k: v, "step": completed_steps}) + + if completed_steps >= args.max_train_steps: + break + + if args.output_dir is not None: + save_model(self.model, os.path.join(args.output_dir, "latest_ckpt"), + self.accelerator, tokenizer=tokenizer, arguments=args) + + if (args.update_kv_embeds and not args.only_rank_exists_local_qa) or args.do_eval: + logging.info("Update Memory") + self.update_key_memory(use_retrieval_adapter=args.only_train_adapter) + + if args.do_eval and epoch % args.eval_freq == 0: + # if args.do_eval: self.key_memory is up-to-date. 
+ em_score, matching_metric, _, _ = self.evaluate(dataset=self.dev_dataset, extend_mem_from="train", + use_retrieval_adapter=args.only_train_adapter) + logger.info(f"epoch {epoch} eval - EM: {em_score:.3f}") + if self.accelerator.is_local_main_process and _has_wandb: + wandb.log({"em_dev": em_score, "epoch": epoch}) + for k, v in matching_metric.items(): + wandb.log({f"{k}": v * 100, "epoch": epoch}) + + if args.output_dir is not None: + if best_hit_at_1 is None or matching_metric["hit@1"] * 100 > best_hit_at_1: + best_hit_at_1 = matching_metric["hit@1"] * 100 + if best_em is None or em_score > best_em: + best_em = em_score + save_model(self.model, os.path.join(args.output_dir, "best_ckpt"), + self.accelerator, tokenizer=tokenizer, arguments=args) + patience = args.early_stop_patience + else: + patience -= 1 + if patience <= 0: + break + + if best_em is not None: # Log the best dev EM score + logger.info(f"best_em_dev: {best_em}") + if _has_wandb: + wandb.log({"best_em_dev": best_em}) + if best_hit_at_1 is not None: + logger.info(f"best_hit@1_dev: {best_hit_at_1}") + if _has_wandb: + wandb.log({"best_hit@1_dev": best_hit_at_1}) + + # do-test + best_model_state_dict = os.path.join(args.output_dir, "best_ckpt/pytorch_model.bin") + em_score, matching_metric, _, _ = self.evaluate(dataset=self.test_dataset, extend_mem_from="train_dev", + update_key_memory=True, ckpt_load_path=best_model_state_dict, + use_retrieval_adapter=args.only_train_adapter) + + if self.accelerator.is_local_main_process: + logger.info(f"em_test: {em_score:.3f}") + for k, v in matching_metric.items(): + logger.info(f"test_{k}: {v}") + if _has_wandb: + wandb.log({"em_test": em_score}) + for k, v in matching_metric.items(): + wandb.log({f"test_{k}": v}) + + @torch.no_grad() + def evaluate(self, dataset: QADataset = None, extend_mem_from="", update_key_memory=False, ckpt_load_path=None, + use_retrieval_adapter=False): + # not implement correctly in multi-GPUs. 
+ tokenizer = self.tokenizer + args = self.args + self.model.eval() + torch.cuda.empty_cache() + + assert extend_mem_from in ["train", "train_dev"] + if ckpt_load_path is not None: + assert update_key_memory is True + loaded_state_dict = torch.load(ckpt_load_path) + load_info = self.model.load_state_dict(loaded_state_dict, strict=False) + logging.info(f"{load_info}") + + assert type(self.key_memory) == list + original_key_length = sum(len(k) for k in self.key_memory) + if update_key_memory: + logging.info("Update Memory") + self.update_key_memory(use_retrieval_adapter=use_retrieval_adapter) + + extend_length = 0 + last_chunk_memory = self.key_memory[-1] + qas_to_retrieve_eval = self.qas_to_retrieve + + tmp_model = copy.deepcopy(self.model) + if args.kvm_fp16: + tmp_model = tmp_model.half() + + logging.info("Build train data memory to retrieve.") + if args.qa_data_name == "tq": + build_query_batch_size = 256 + else: + build_query_batch_size = args.build_mem_batch_size + if "train" in extend_mem_from: + train_qas_key_memory, _ = build_memory( + tmp_model, tokenizer, embed_key=True, embed_value=False, prefix=self.prefix, embed_as_fp16=True, + key_reduce_method=args.key_reduce_method, return_memory=True, dump_memory=False, kvm_seg_n=-1, + data_to_embed=self.train_dataset.data, max_source_length=args.max_source_length, padding=True, + batch_size=build_query_batch_size, separate_task=args.separate_task, reused_key_memory=None, + use_retrieval_adapter=use_retrieval_adapter + ) + extend_length = extend_length + len(train_qas_key_memory) + last_chunk_memory = torch.cat((last_chunk_memory, train_qas_key_memory)) # extend in the last chunk + qas_to_retrieve_eval = qas_to_retrieve_eval + self.train_dataset.data + + if "dev" in extend_mem_from: + logging.info("Build dev data memory to retrieve.") + dev_qas_key_memory, _ = build_memory( + tmp_model, tokenizer, embed_key=True, embed_value=False, prefix=self.prefix, embed_as_fp16=True, + key_reduce_method=args.key_reduce_method, return_memory=True, dump_memory=False, kvm_seg_n=-1, + data_to_embed=self.dev_dataset.data, max_source_length=args.max_source_length, padding=True, + batch_size=build_query_batch_size, separate_task=args.separate_task, reused_key_memory=None, + use_retrieval_adapter=use_retrieval_adapter + ) + extend_length = extend_length + len(dev_qas_key_memory) + last_chunk_memory = torch.cat((last_chunk_memory, dev_qas_key_memory)) # extend in the last chunk + qas_to_retrieve_eval = qas_to_retrieve_eval + self.dev_dataset.data + del tmp_model + + key_memory_eval = self.key_memory[:-1] + [last_chunk_memory] + key_nums_eval = sum(len(k) for k in key_memory_eval) + assert key_nums_eval == len(qas_to_retrieve_eval) + + # if use_retrieval_adapter: + # low_dim_key = [] + # while len(key_memory_eval) > 0: + # chunk_key = key_memory_eval.pop(0) + # chunk_low_dim_key = [] + # for start_idx in range(len(chunk_key)): + # chunk_low_dim_key.append(self.model.adapter(chunk_key[start_idx:start_idx + 512])) + # del chunk_key + # low_dim_key.append(torch.cat(chunk_low_dim_key)) + # key_memory_eval = low_dim_key + + dataloader = dataset.get_query_dataloader(batch_size=args.per_device_eval_batch_size, + shuffle=False, num_workers=1) + dataloader = self.accelerator.prepare(dataloader) + + gen_kwargs = {"max_length": args.max_target_length, + "num_beams": args.num_beams, } + + torch.cuda.empty_cache() + + all_retrieved_qas = [] + all_gen_ans = [] + for batch in tqdm(dataloader): + embed_dict = self.model.CAT_embed_q( + input_ids=batch["query_input_ids"], + 
attention_mask=batch["query_attention_mask"], + compute_key=True, compute_value=False + ) + query_embeds = embed_dict["normed_key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method) + if use_retrieval_adapter: + query_embeds = self.model.adapter(query_embeds) + query_embeds = query_embeds.half() + + if key_nums_eval > 20000000: + # if scale is large: calculate topk in each chunk -> combine all-topk -> select final topk + chunk_top_scores = [] + chunk_top_indices = [] + idx_shift = 0 + for chunk_key_memory in key_memory_eval: + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_topk = torch.mm(query_embeds, chunk_key_memory_cuda.t()).topk(50, dim=1) + chunk_top_scores.append(chunk_topk.values) # chunk_topk.scores: [query_batch, local_size] + chunk_top_indices.append(chunk_topk.indices + idx_shift) + idx_shift += len(chunk_key_memory) + del chunk_key_memory_cuda + torch.cuda.empty_cache() + chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n + chunk_top_indices = torch.cat(chunk_top_indices, dim=1) # q_batch, local_size*seg_n + topk = chunk_top_scores.topk(50, dim=1) # q_batch, local_size + top_indices_indices = topk.indices + top_indices = [] + for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices): + top_indices.append([cur_indices[idx] for idx in cur_indices_indices]) + readout_qas = [[qas_to_retrieve_eval[idx] for idx in indices] for indices in top_indices] + else: + all_chunk_scores = [] + for chunk_key_memory in key_memory_eval: + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_scores = torch.mm(query_embeds, chunk_key_memory_cuda.t()) # query_batch + all_chunk_scores.append(chunk_scores) + del chunk_key_memory_cuda + scores = torch.cat(all_chunk_scores, dim=1) + top_indices = scores.topk(50, dim=1).indices.tolist() + readout_qas = [[qas_to_retrieve_eval[idx] for idx in indices] for indices in top_indices] + value_qas = [] + for qas in readout_qas: + selected_qas = qas[:args.num_values] + if not args.values_with_order: + random.shuffle(selected_qas) + value_qas.append(selected_qas) + all_retrieved_qas += readout_qas + + squeezed_value_qas = list(chain(*value_qas)) + + retrieved_qas_inputs = dataset.get_key_value_inputs(squeezed_value_qas, only_return_key_inputs=False) + embed_dict = self.model.wrapped_embed_kv(separate_task=args.separate_task, compute_key=True, + compute_value=True, **retrieved_qas_inputs) + value_embeds = embed_dict["value_embeds"] + key_embeds_of_value = embed_dict["key_embeds"] + cur_batch_size = query_embeds.shape[0] + value_embeds = value_embeds.view(cur_batch_size, args.num_values, args.prefix_length, -1) + key_embeds_of_value = key_embeds_of_value.view(cur_batch_size, args.num_values, -1, self.model.model_dim) + encoder_outputs = self.model.encoder( + input_ids=batch["query_input_ids"], + attention_mask=batch["query_attention_mask"], + return_dict=True, + value_embeds=value_embeds, + readout_top_k=-1, + key_reduce_method=args.key_reduce_method, + value_fusion_method=args.value_fusion_method, + key_embeds_of_value=key_embeds_of_value + ) + generated_tokens = self.accelerator.unwrap_model(self.model).generate( + encoder_outputs=encoder_outputs, + encoder_outputs_are_key_or_value=False, + decoder_only_attend_on_prefix=args.decoder_only_attend_on_prefix, + attention_mask=batch["query_attention_mask"].to(self.model.device), + value_fusion_method=args.value_fusion_method, + **gen_kwargs, + ) + generated_tokens = self.accelerator.pad_across_processes( + 
generated_tokens, dim=1, pad_index=tokenizer.pad_token_id + ) + generated_tokens = self.accelerator.gather(generated_tokens).cpu().numpy() + decoded_tokens = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + decoded_tokens = [ans.strip() for ans in decoded_tokens] + all_gen_ans += decoded_tokens + + torch.cuda.empty_cache() + + matching_metric = eval_retriever(dataset.data, all_retrieved_qas, "1,2,3,4,5,10,50") + em_score = eval_generation_em(dataset.data, all_gen_ans) * 100 + + assert original_key_length == sum(len(k) for k in self.key_memory) + assert original_key_length == len(self.qas_to_retrieve) + + return em_score, matching_metric, all_retrieved_qas, all_gen_ans + + +if __name__ == '__main__': + pass diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8275390 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +accelerate==0.5.1 +datasets==1.18.4 +nltk==3.7 +numpy==1.21.5 +rouge==1.0.1 +scikit_learn==1.0.2 +setuptools==58.0.4 +stanfordcorenlp==3.9.1.1 +tqdm==4.62.3 +rouge==1.0.1 \ No newline at end of file diff --git a/retrieval_adapter_scripts/train_adapter.sh b/retrieval_adapter_scripts/train_adapter.sh new file mode 100644 index 0000000..775a8a4 --- /dev/null +++ b/retrieval_adapter_scripts/train_adapter.sh @@ -0,0 +1,67 @@ +#!/bin/bash -l + +set -e +set -u + + +DST_DIR="/mnt/inspurfs/user-fs/zhaoyu/workspace/case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +LOAD_PATH="/mnt/inspurfs/user-fs/zhaoyu/workspace/case-augmented-transformer-master/outputs/nq_checkpoints/KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;pos_from_top=128;/best_ckpt/" # load pre-trained model +EXP_NAME="adapter768;distill;adapter_retrieve_after5epoch;" # set experiment name +DATA_NAME="nq" # datasets: ["nq", "tq", "wq"] +# +DEVICE="5" + +# Train nq-EMAT-FKSV +# use --kvm_fp16 if GPU OOM + +CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \ + --project_name="${DATA_NAME^^}-CAT" \ + --exp_name=${EXP_NAME} \ + --query_batch_size=256 \ + --build_mem_batch_size=12000 \ + --batch_local_positive_num=5 \ + --pos_from_top=128 \ + --do_eval \ + --kvm_seg_n=2 \ + --values_with_order \ + --value_layer=7 \ + --value_fusion_method="cat_k_delay+v" \ + --num_values=1 \ + --qa_data_name=${DATA_NAME} \ + --model_name_or_path=${LOAD_PATH} \ + --source_prefix="question: " \ + --per_device_train_batch_size=64 \ + --per_device_eval_batch_size=64 \ + --gradient_accumulation_steps=4 \ + --learning_rate=5e-5 \ + --num_train_epochs=30 \ + --lr_scheduler_type="linear" \ + --num_warmup_steps=1000 \ + --output_dir="./outputs/nq_checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_layer=3 \ + --key_encoder_type="conv" \ + --select_positive_strategy="softmax_sample" \ + --faiss_efsearch=128 \ + --gen_weight=1 \ + --match_weight=1 \ + --key_reduce_method="avg" \ + --qas_to_retrieve_from="PAQ_L1" \ + --local_size=384 \ + --update_kv_embeds \ + --update_local_target_each_batch \ + --update_local_qas \ + --separate_task \ + --value_ae_target="ans" \ + --key_ae_target="question" \ + --repaq_supervision_epoch=-1 \ + --early_stop_patience=8 \ + --negatives_num_each_example=32 \ + --only_train_adapter \ + --adapter="linear" \ + --adapter_out_dim=768 \ + --use_adapter_to_select_positive_after_k_epoch=5 \ + --do_train diff --git a/scripts/nq_eval.sh b/scripts/nq_eval.sh new file mode 100644 index 0000000..1805b8b --- /dev/null +++ b/scripts/nq_eval.sh @@ -0,0 +1,63 @@ +#!/bin/bash -l + +set -e +set -u + + 
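# Evaluation-only configuration: CKPT_DIR supplies the fine-tuned model via --model_name_or_path,
# and --do_test (with --do_train omitted) makes qa_main.py load best_ckpt/pytorch_model.bin from
# --output_dir, rebuild the key-value memory from PAQ-L1 extended with the train and dev
# questions, and report EM plus retrieval hit@k, writing per-example predictions to
# best_ckpt/predict_results.json.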
+DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +CKPT_DIR="" # load pre-trained model +EXP_NAME="eval" # set experiment name +DATA_NAME="nq" # datasets: ["nq", "tq", "wq"] + +DEVICE="0" + +# Train nq-EMAT-FKSV +# use --kvm_fp16 if GPU OOM + +CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \ + --project_name="${DATA_NAME^^}-CAT" \ + --exp_name=${EXP_NAME} \ + --query_batch_size=256 \ + --build_mem_batch_size=12000 \ + --batch_local_positive_num=5 \ + --pos_from_top=128 \ + --do_eval \ + --kvm_seg_n=2 \ + --values_with_order \ + --value_layer=7 \ + --value_fusion_method="cat_k_delay+v" \ + --num_values=10 \ + --qa_data_name=${DATA_NAME} \ + --model_name_or_path=${CKPT_DIR} \ + --source_prefix="question: " \ + --per_device_train_batch_size=64 \ + --per_device_eval_batch_size=64 \ + --gradient_accumulation_steps=4 \ + --learning_rate=5e-5 \ + --num_train_epochs=30 \ + --lr_scheduler_type="linear" \ + --num_warmup_steps=1000 \ + --output_dir="./outputs/nq_checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_layer=3 \ + --key_encoder_type="conv" \ + --select_positive_strategy="softmax_sample" \ + --faiss_efsearch=128 \ + --gen_weight=1 \ + --match_weight=1 \ + --key_reduce_method="avg" \ + --qas_to_retrieve_from="PAQ_L1" \ + --local_size=384 \ + --update_kv_embeds \ + --update_local_target_each_batch \ + --update_local_qas \ + --separate_task \ + --value_ae_target="ans" \ + --key_ae_target="question" \ + --repaq_supervision_epoch=-1 \ + --early_stop_patience=8 \ + --negatives_num_each_example=32 \ + --do_test diff --git a/scripts/nq_train_with_paql1.sh b/scripts/nq_train_with_paql1.sh new file mode 100644 index 0000000..e54b872 --- /dev/null +++ b/scripts/nq_train_with_paql1.sh @@ -0,0 +1,63 @@ +#!/bin/bash -l + +set -e +set -u + + +DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +LOAD_EXP_NAME="KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;" # load pre-trained model +EXP_NAME="base;KL=3;VL=7;VN=10;lr=5e-5;" # set experiment name +DATA_NAME="nq" # datasets: ["nq", "tq", "wq"] + +DEVICE="0" + +# Train nq-EMAT-FKSV +# use --kvm_fp16 if GPU OOM + +CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \ + --project_name="${DATA_NAME^^}-CAT" \ + --exp_name=${EXP_NAME} \ + --query_batch_size=256 \ + --build_mem_batch_size=12000 \ + --batch_local_positive_num=5 \ + --pos_from_top=128 \ + --do_eval \ + --kvm_seg_n=2 \ + --values_with_order \ + --value_layer=7 \ + --value_fusion_method="cat_k_delay+v" \ + --num_values=10 \ + --qa_data_name=${DATA_NAME} \ + --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \ + --source_prefix="question: " \ + --per_device_train_batch_size=64 \ + --per_device_eval_batch_size=64 \ + --gradient_accumulation_steps=4 \ + --learning_rate=5e-5 \ + --num_train_epochs=30 \ + --lr_scheduler_type="linear" \ + --num_warmup_steps=1000 \ + --output_dir="./outputs/nq_checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_layer=3 \ + --key_encoder_type="conv" \ + --select_positive_strategy="softmax_sample" \ + --faiss_efsearch=128 \ + --gen_weight=1 \ + --match_weight=1 \ + --key_reduce_method="avg" \ + --qas_to_retrieve_from="PAQ_L1" \ + --local_size=384 \ + --update_kv_embeds \ + --update_local_target_each_batch \ + --update_local_qas \ + --separate_task \ + --value_ae_target="ans" \ + --key_ae_target="question" \ + --repaq_supervision_epoch=-1 \ + --early_stop_patience=8 \ + --negatives_num_each_example=32 \ + --do_train diff --git 
a/scripts/tq_train_with_paql1.sh b/scripts/tq_train_with_paql1.sh new file mode 100644 index 0000000..6303026 --- /dev/null +++ b/scripts/tq_train_with_paql1.sh @@ -0,0 +1,63 @@ +#!/bin/bash -l + +set -e +set -u + + +DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +LOAD_EXP_NAME="KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;" # load pre-trained model +EXP_NAME="base;KL=3;VL=7;VN=10;lr=5e-5;" # set experiment name +DATA_NAME="tq" # datasets: ["nq", "tq", "wq"] + +DEVICE="0" + +# Train nq-EMAT-FKSV +# use --kvm_fp16 if GPU OOM + +CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \ + --project_name="${DATA_NAME^^}-CAT" \ + --exp_name=${EXP_NAME} \ + --query_batch_size=256 \ + --build_mem_batch_size=12000 \ + --batch_local_positive_num=5 \ + --pos_from_top=128 \ + --do_eval \ + --kvm_seg_n=2 \ + --values_with_order \ + --value_layer=7 \ + --value_fusion_method="cat_k_delay+v" \ + --num_values=10 \ + --qa_data_name=${DATA_NAME} \ + --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \ + --source_prefix="question: " \ + --per_device_train_batch_size=64 \ + --per_device_eval_batch_size=64 \ + --gradient_accumulation_steps=4 \ + --learning_rate=5e-5 \ + --num_train_epochs=30 \ + --lr_scheduler_type="linear" \ + --num_warmup_steps=1000 \ + --output_dir="./outputs/nq_checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_layer=3 \ + --key_encoder_type="conv" \ + --select_positive_strategy="softmax_sample" \ + --faiss_efsearch=128 \ + --gen_weight=1 \ + --match_weight=1 \ + --key_reduce_method="avg" \ + --qas_to_retrieve_from="PAQ_L1" \ + --local_size=384 \ + --update_kv_embeds \ + --update_local_target_each_batch \ + --update_local_qas \ + --separate_task \ + --value_ae_target="ans" \ + --key_ae_target="question" \ + --repaq_supervision_epoch=-1 \ + --early_stop_patience=8 \ + --negatives_num_each_example=32 \ + --do_train diff --git a/scripts/wq_train_with_paql1.sh b/scripts/wq_train_with_paql1.sh new file mode 100644 index 0000000..29527d7 --- /dev/null +++ b/scripts/wq_train_with_paql1.sh @@ -0,0 +1,63 @@ +#!/bin/bash -l + +set -e +set -u + + +DST_DIR="case-augmented-transformer-master" # change to your project root +cd ${DST_DIR} + +LOAD_EXP_NAME="KL=3;kdim=1536;VL=7;VN=10;cat_k_delay+v;t5-base;" # load pre-trained model +EXP_NAME="base;KL=3;VL=7;VN=10;lr=5e-5;" # set experiment name +DATA_NAME="wq" # datasets: ["nq", "tq", "wq"] + +DEVICE="0" + +# Train nq-EMAT-FKSV +# use --kvm_fp16 if GPU OOM + +CUDA_VISIBLE_DEVICES=${DEVICE} python qa_main.py \ + --project_name="${DATA_NAME^^}-CAT" \ + --exp_name=${EXP_NAME} \ + --query_batch_size=256 \ + --build_mem_batch_size=12000 \ + --batch_local_positive_num=5 \ + --pos_from_top=128 \ + --do_eval \ + --kvm_seg_n=2 \ + --values_with_order \ + --value_layer=7 \ + --value_fusion_method="cat_k_delay+v" \ + --num_values=10 \ + --qa_data_name=${DATA_NAME} \ + --model_name_or_path="./outputs/checkpoints/${LOAD_EXP_NAME}/latest_ckpt" \ + --source_prefix="question: " \ + --per_device_train_batch_size=64 \ + --per_device_eval_batch_size=64 \ + --gradient_accumulation_steps=4 \ + --learning_rate=4e-5 \ + --num_train_epochs=30 \ + --lr_scheduler_type="constant" \ + --num_warmup_steps=1000 \ + --output_dir="./outputs/nq_checkpoints/${EXP_NAME}" \ + --prefix_length=2 \ + --d_key=1536 \ + --key_layer=3 \ + --key_encoder_type="conv" \ + --select_positive_strategy="softmax_sample" \ + --faiss_efsearch=128 \ + --gen_weight=1 \ + --match_weight=1 \ + --key_reduce_method="avg" \ + 
--qas_to_retrieve_from="PAQ_L1" \
+  --local_size=384 \
+  --update_kv_embeds \
+  --update_local_target_each_batch \
+  --update_local_qas \
+  --separate_task \
+  --value_ae_target="ans" \
+  --key_ae_target="question" \
+  --repaq_supervision_epoch=-1 \
+  --early_stop_patience=8 \
+  --negatives_num_each_example=32 \
+  --do_train
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/dr_utils.py b/utils/dr_utils.py
new file mode 100644
index 0000000..277d129
--- /dev/null
+++ b/utils/dr_utils.py
@@ -0,0 +1,698 @@
+import copy
+from functools import partial
+import random
+from typing import List
+import torch
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from torch.nn.utils.rnn import pad_sequence
+from build_kvm import build_memory
+from emat.evaluation.eval_retriever import eval_retriever
+from emat.evaluation.exact_match import normalize_answer
+from kilt_dataset import DialogDataset
+from qa_dataset import QADataset
+from utils.utils import reduce_query_or_key_embeds
+import logging
+
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+# not used in QA
+def query_collate_fn(batch, tokenizer=None, dataset=None):
+    for item in batch:
+        item.update(dataset.get_base_input(item))
+    query_input_ids = [item["input_ids"] for item in batch]
+    query_input_ids = pad_sequence(query_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    query_attention_mask = (query_input_ids != tokenizer.pad_token_id).long()
+    return {"query_input_ids": query_input_ids, "query_attention_mask": query_attention_mask}
+
+
+# not used in QA
+def kvm_collate_fn(batch, tokenizer=None, train_dataset=None):
+    for item in batch:
+        item.update(train_dataset.get_base_input(item))
+    key_input_ids = [item["input_ids"] for item in batch]
+    key_input_ids = pad_sequence(key_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    key_attention_mask = (key_input_ids != tokenizer.pad_token_id).long()
+    value_input_ids = [item["target_as_input_ids"] for item in batch]
+    value_input_ids = pad_sequence(value_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    value_attention_mask = (value_input_ids != tokenizer.pad_token_id).long()
+    return {"key_input_ids": key_input_ids, "key_attention_mask": key_attention_mask,
+            "value_input_ids": value_input_ids, "value_attention_mask": value_attention_mask}
+
+
+# not used in QA
+@torch.no_grad()
+def build_key_memory(model, tokenizer, train_dataset, data_to_retrieve_list: List[List[dict]], args,
+                     build_memory_batch_size=2560) -> List[torch.tensor]:
+    # the key norm layer is only used in key-matching (retrieval);
+    # if key matching is not trained, no gradient flows to the key norm layer.
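+    # one fp16 key-embedding tensor is built per candidate pool in data_to_retrieve_list,
+    # so the returned list mirrors the segmentation of the retrieval corpus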
+ use_normed_key = True if args.key_matching_weight > 0.0 else False + + key_memory = [] + for data_to_retrieve in data_to_retrieve_list: + cur_km, _ = build_memory(model, tokenizer, embed_key=True, embed_value=False, embed_as_fp16=True, + key_reduce_method=args.key_reduce_method, data_to_embed=data_to_retrieve, + batch_size=build_memory_batch_size, return_memory=True, separate_task=True, + collate_fn=partial(kvm_collate_fn, tokenizer=tokenizer, train_dataset=train_dataset), + normed_key_memory=use_normed_key) + key_memory.append(cur_km) + + return key_memory + + +# not used in QA +@torch.no_grad() +def build_dataset_query(model, tokenizer, dataset, args, encode_query_batch_size=2560) -> torch.tensor: + dataset_query_embeds = [] + query_to_embed_dataloader = DataLoader(dataset.data, batch_size=encode_query_batch_size, num_workers=16, + collate_fn=partial(query_collate_fn, tokenizer=tokenizer, dataset=dataset)) + for query_batch_input in tqdm(query_to_embed_dataloader, desc="embed_query"): + embed_dict = model.CAT_embed_q( + input_ids=query_batch_input["query_input_ids"].to(model.device), + attention_mask=query_batch_input["query_attention_mask"].to(model.device), + compute_key=True, compute_value=False + ) + use_normed_query = True if args.key_matching_weight > 0.0 else False + query_embeds = embed_dict["normed_key_embeds"] if use_normed_query else embed_dict["key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method) + query_embeds = query_embeds.half().cpu() + dataset_query_embeds.append(query_embeds) + + dataset_query_embeds = torch.cat(dataset_query_embeds) + return dataset_query_embeds + + +# not used in QA +@torch.no_grad() +def prepare_local_cases(model, tokenizer, train_dataset, key_memory: List[torch.tensor], + data_to_retrieve_list: List[List[dict]], args, + max_num_to_ban=48, query_batch_size=512, fp16_query=False): + model.eval() + dataset_query_embeds = build_dataset_query(model, tokenizer, train_dataset, args) + local_size = args.local_size + retrieve_topk = local_size + 1 # retrieve retrieve_topk-cases from each key_memory, +1 is used to exclude itself. 
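+    # when banned cases have to be filtered out below, over-retrieve so that local_size
+    # candidates still remain after dropping the example itself and its banned cases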
+ if args.filter_type != "not_filter": + retrieve_topk = local_size + max_num_to_ban # 48 > the max num to ban-to-retrieve + for example in train_dataset.data: # clear previous local_cases and their scores + example["local_cases"] = [] + example["local_cases_scores"] = [] + for cur_key_memory, cur_data_to_retrieve in zip(key_memory, data_to_retrieve_list): + cur_cuda_key_memory = cur_key_memory.cuda() + for start_idx in tqdm(range(0, len(dataset_query_embeds), query_batch_size), + total=len(dataset_query_embeds) // query_batch_size, desc="retrieve local cases"): + cur_cuda_query_embeds = dataset_query_embeds[start_idx: start_idx + query_batch_size].cuda() + cur_batch_data = train_dataset.data[start_idx: start_idx + query_batch_size] + if fp16_query: + scores = torch.mm(cur_cuda_query_embeds, cur_cuda_key_memory.t()) + else: # not use fp16_query to avoid overflow + scores = torch.mm(cur_cuda_query_embeds.float(), cur_cuda_key_memory.t().float()) + topk = scores.topk(min(len(cur_key_memory), retrieve_topk), dim=1) + top_indices = topk.indices.tolist() + top_scores = topk.values.tolist() + for cur_example, cur_indices, cur_scores in zip(cur_batch_data, top_indices, top_scores): + cur_local_cases = [] + cur_local_cases_scores = [] + for case_idx, case_score in zip(cur_indices, cur_scores): + cur_case = cur_data_to_retrieve[case_idx] + cur_case_id = cur_case["id"] + + if cur_case_id == cur_example["id"]: + continue # exclude itself + if "ban_to_retrieve_list" in cur_example: + if cur_case_id in cur_example["ban_to_retrieve_list"]: + continue # the retrieved case is baned for cur_example + + cur_local_cases.append(cur_case_id) + cur_local_cases_scores.append(case_score) + + cur_local_cases = cur_local_cases[:local_size] + assert len(cur_local_cases) == local_size + cur_example["local_cases"] += cur_local_cases + cur_example["local_cases_scores"] = cur_local_cases_scores + + del cur_cuda_key_memory + torch.cuda.empty_cache() + + for example in train_dataset.data: # rank to select local-cases + sorted_cases_with_scores = sorted(zip(example["local_cases"], example["local_cases_scores"]), + key=lambda x: x[1], reverse=True)[:local_size] + example["local_cases"] = [sc[0] for sc in sorted_cases_with_scores] + example["local_cases_scores"] = [sc[1] for sc in sorted_cases_with_scores] + + +# not used in QA +@torch.no_grad() +def update_batch_retrieve_from_local(model, batch, args): + query_input_ids = batch["input_ids"] + query_input_attention_mask = batch["attention_mask"] + embed_dict = model.CAT_embed_q( + input_ids=query_input_ids, + attention_mask=query_input_attention_mask, + compute_key=True, compute_value=False + ) + use_normed_query = True if args.key_matching_weight > 0.0 else False + query_embeds = embed_dict["normed_key_embeds"] if use_normed_query else embed_dict["key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method) + query_embeds = query_embeds + cur_bs = query_input_ids.shape[0] + + squeezed_local_cases_input_ids = batch.pop("squeezed_local_cases_input_ids") + squeezed_local_cases_attention_mask = batch.pop("squeezed_local_cases_attention_mask") + squeezed_local_cases_target_as_input_ids = batch.pop("squeezed_local_cases_target_as_input_ids") + squeezed_local_cases_target_as_input_attention_mask = batch.pop( + "squeezed_local_cases_target_as_input_attention_mask") + embed_dict = model.wrapped_embed_kv( + separate_task=args.separate_task, compute_key=True, compute_value=False, + key_input_ids=squeezed_local_cases_input_ids, 
key_attention_mask=squeezed_local_cases_attention_mask, + ) + squeezed_local_cases_key = embed_dict["normed_key_embeds"] if use_normed_query else embed_dict["key_embeds"] + squeezed_local_cases_key = reduce_query_or_key_embeds(squeezed_local_cases_key, args.key_reduce_method) + local_cases_key = squeezed_local_cases_key.view(cur_bs, args.local_size, -1) + scores = torch.bmm(query_embeds.unsqueeze(dim=1), local_cases_key.transpose(2, 1)).squeeze(dim=1) + scores_topk = scores.topk(args.num_values, dim=1) + retrieved_indices = scores_topk.indices + gathered_indices = retrieved_indices.unsqueeze(dim=-1) + + local_cases_input_ids = squeezed_local_cases_input_ids.view(cur_bs, args.local_size, -1) + local_cases_attention_mask = squeezed_local_cases_attention_mask.view(cur_bs, args.local_size, -1) + local_cases_target_as_input_ids = squeezed_local_cases_target_as_input_ids.view(cur_bs, args.local_size, -1) + local_cases_target_as_input_attention_mask = squeezed_local_cases_target_as_input_attention_mask. \ + view(cur_bs, args.local_size, -1) + key_input_ids = torch.gather(local_cases_input_ids, 1, gathered_indices. + repeat(1, 1, local_cases_input_ids.shape[-1])) + key_attention_mask = torch.gather(local_cases_attention_mask, 1, gathered_indices. + repeat(1, 1, local_cases_attention_mask.shape[-1])) + value_input_ids = torch.gather(local_cases_target_as_input_ids, 1, gathered_indices. + repeat(1, 1, local_cases_target_as_input_ids.shape[-1])) + value_attention_mask = torch.gather(local_cases_target_as_input_attention_mask, 1, gathered_indices. + repeat(1, 1, local_cases_target_as_input_attention_mask.shape[-1])) + batch.update({"group_key_input_ids": key_input_ids, "group_key_attention_mask": key_attention_mask, + "group_value_input_ids": value_input_ids, "group_value_attention_mask": value_attention_mask}) + + if args.key_matching_weight > 0.0: + local_cases_label_ids = batch.pop("local_cases_label_ids") + retrieved_cases_label = torch.gather(local_cases_label_ids, 1, retrieved_indices) + squeezed_retrieved_cases_label = retrieved_cases_label.view(-1) + label_ids = batch["label_ids"] + matching_mask = [] + matching_target = [] + for bid, (cur_target_label, cur_retrieved_cases_label) in enumerate(zip(label_ids, retrieved_cases_label)): + cur_mask = torch.ones_like(squeezed_retrieved_cases_label) + cur_mask[squeezed_retrieved_cases_label == cur_target_label] = 0 + matched_pos = (cur_retrieved_cases_label == cur_target_label).nonzero().view(-1) + if len(matched_pos) > 0: + cur_pos_idx = matched_pos[0] + len(cur_retrieved_cases_label) * bid + matching_target.append(cur_pos_idx) + cur_mask[cur_pos_idx] = 1 + else: + matching_target.append(-100) + matching_mask.append(cur_mask) + + batch.update({"matching_target": torch.tensor(matching_target).to(model.device), + "matching_mask": torch.stack(matching_mask).to(model.device)}) + + +# not used in QA +@torch.no_grad() +def retrieve_from_key_memory(model, tokenizer, dataset, key_memory: List[torch.tensor], + data_to_retrieve_list: List[List[dict]], args): + model.eval() + dataset_query_embeds = build_dataset_query(model, tokenizer, dataset, args) + query_batch_size = 512 + for example in dataset.data: # clear previous local_cases and their scores + example["retrieved_cases"] = [] + example["retrieved_cases_scores"] = [] + for cur_key_memory, cur_data_to_retrieve in zip(key_memory, data_to_retrieve_list): + cur_cuda_key_memory = cur_key_memory.cuda() + for start_idx in tqdm(range(0, len(dataset_query_embeds), query_batch_size), + 
total=len(dataset_query_embeds) // query_batch_size, desc="query_key_memory"): + cur_cuda_query_embeds = dataset_query_embeds[start_idx: start_idx + query_batch_size].cuda() + cur_batch_data = dataset.data[start_idx: start_idx + query_batch_size] + scores = torch.mm(cur_cuda_query_embeds, cur_cuda_key_memory.t()) + topk = scores.topk(args.num_values, dim=1) + top_indices = topk.indices.tolist() + top_scores = topk.values.tolist() + for cur_example, cur_indices, cur_scores in zip(cur_batch_data, top_indices, top_scores): + cur_local_cases = [cur_data_to_retrieve[case_idx] for case_idx in cur_indices] + cur_local_cases_scores = cur_scores + cur_example["retrieved_cases"] += cur_local_cases + cur_example["retrieved_cases_scores"] = cur_local_cases_scores + del cur_cuda_key_memory + torch.cuda.empty_cache() + for example in dataset.data: + sorted_cases_with_scores = sorted(zip(example["retrieved_cases"], example["retrieved_cases_scores"]), + key=lambda x: x[1], reverse=True)[:args.num_values] + retrieved_cases = [sc[0] for sc in sorted_cases_with_scores] + for item in retrieved_cases: + item.update(dataset.get_base_input(item)) + retrieved_key_seqs = [case["input_ids"] for case in retrieved_cases] + retrieved_value_seqs = [case["target_as_input_ids"] for case in retrieved_cases] + example.update({"retrieved_key_seqs": retrieved_key_seqs, "retrieved_value_seqs": retrieved_value_seqs}) + # example["retrieved_cases_scores"] = [sc[1] for sc in sorted_cases_with_scores] + + +# it is used in QA if not rank-exists-local-qas +@torch.no_grad() +def update_local_qas_to_retrieve(args, train_dataset, qas_to_retrieve, model, key_memory: List[torch.tensor], + normed_answer_of_qas_to_ret, train_data_query_embeds=None, build_mem_batch_size=1024, + query_batch_size=128, local_size=1024, pos_from_top=50, neg_from_top=200, + use_retrieval_adapter=False): + model.eval() + assert type(key_memory) == list + + logger.info(f"Prepare local QAs for each example to retrieve.") + all_local_qas = [] + all_local_positive = [] + all_local_negative = [] + all_ret_qas = [] + + if train_data_query_embeds is None: + if use_retrieval_adapter: + dim = args.adapter_out_dim + else: + dim = model.model_dim + train_data_query_embeds = torch.zeros((len(train_dataset.data), dim), device='cpu', dtype=torch.float16) + if args.qa_data_name == "tq": + build_query_batch_size = 256 + else: + build_query_batch_size = build_mem_batch_size + embed_query_dataloader = train_dataset.get_query_dataloader(batch_size=build_query_batch_size, + shuffle=False, num_workers=1) + start_idx = 0 + for query_inputs in tqdm(embed_query_dataloader, total=len(embed_query_dataloader), desc="Embed queries."): + end_idx = start_idx + len(query_inputs["query_input_ids"]) + embed_dict = model.CAT_embed_q( + input_ids=query_inputs["query_input_ids"].to(model.device), + attention_mask=query_inputs["query_attention_mask"].to(model.device), + compute_key=True, compute_value=False + ) + query_embeds = embed_dict["normed_key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method) + if use_retrieval_adapter: + query_embeds = model.adapter(query_embeds) + query_embeds = query_embeds.half().cpu() + train_data_query_embeds[start_idx: end_idx] = query_embeds + start_idx = end_idx + assert start_idx == len(train_dataset.data) + + torch.cuda.empty_cache() + + key_nums = sum(len(k) for k in key_memory) + logger.info(f"key-memory seg-num: {len(key_memory)}. 
all key nums: {key_nums}.") + + for start_idx in tqdm(range(0, len(train_dataset.data), query_batch_size), + total=len(train_dataset.data) // query_batch_size + 1): + cur_cuda_query_embeds = train_data_query_embeds[start_idx: start_idx + query_batch_size].cuda() + if key_nums > 10000000: + # if scale is large: calculate topk in each chunk -> combine all-topk -> select final topk + chunk_top_scores = [] + chunk_top_indices = [] + idx_shift = 0 + for ckm_idx, chunk_key_memory in enumerate(key_memory): + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_topk = torch.mm(cur_cuda_query_embeds, chunk_key_memory_cuda.t()).topk(local_size, dim=1) + chunk_top_scores.append(chunk_topk.values) # chunk_topk.scores: [query_batch, local_size] + chunk_top_indices.append(chunk_topk.indices + idx_shift) + idx_shift += len(chunk_key_memory) + del chunk_key_memory_cuda + torch.cuda.empty_cache() + chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n + chunk_top_indices = torch.cat(chunk_top_indices, dim=1).tolist() # q_batch, local_size*seg_n + topk = chunk_top_scores.topk(local_size, dim=1) # q_batch, local_size + top_indices_indices = topk.indices.tolist() + top_indices = [] + for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices): + top_indices.append([cur_indices[idx] for idx in cur_indices_indices]) + else: + # if scale is moderate: calculate score in each chunk -> combine score -> select topk + all_chunk_scores = [] + for chunk_key_memory in key_memory: + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_scores = torch.mm(cur_cuda_query_embeds, chunk_key_memory_cuda.t()) + all_chunk_scores.append(chunk_scores) # q_batch, chunk_size + del chunk_key_memory_cuda + torch.cuda.empty_cache() + scores = torch.cat(all_chunk_scores, dim=1).cuda() # q_batch, key_memory_size + topk = scores.topk(local_size, dim=1) + top_indices = topk.indices.tolist() + + batch = train_dataset.data[start_idx: start_idx + query_batch_size] + for cur_example, cur_indices in zip(batch, top_indices): + local_positive, local_negative = [], [] + # cur_target = [normalize_answer(ans) for ans in cur_example["answer"]] + cur_target = [na for na in cur_example["normalized_answer"]] + + cur_ret_qas = [] + for top_idx, qa_idx in enumerate(cur_indices): + if normed_answer_of_qas_to_ret[qa_idx] in cur_target: + if top_idx < pos_from_top: + local_positive.append(qa_idx) + else: + if top_idx < neg_from_top: + local_negative.append(qa_idx) + elif len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative + # if len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative + local_negative.append(qa_idx) + cur_ret_qas.append(qas_to_retrieve[qa_idx]) + + all_ret_qas.append(cur_ret_qas) + + all_local_positive.append(local_positive) + all_local_negative.append(local_negative) + + all_local_qas += top_indices + del cur_cuda_query_embeds + del topk + + torch.cuda.empty_cache() + assert len(all_ret_qas) == len(train_dataset.data) + assert len(all_local_qas) == len(train_dataset.data) == len(all_local_positive) == len(all_local_negative) + matching_metric = eval_retriever(train_dataset.data, all_ret_qas, "1,2,3,4,5") + for k, v in matching_metric.items(): + logging.info({f"local_qas initial {k}": v}) + for i in range(len(train_dataset.data)): + train_dataset.data[i]["local_positive"] = all_local_positive[i] + train_dataset.data[i]["local_negative"] = all_local_negative[i] + train_dataset.data[i]["local_qas"] = all_local_qas[i] + + 
logger.info(f"Local QAs updated.") + + +@torch.no_grad() +def update_batch_inputs(args, batch, model, use_adapter_to_select_positive=False): + model.eval() + embed_dict = model.CAT_embed_q( + input_ids=batch["query_input_ids"], + attention_mask=batch["query_attention_mask"], + compute_key=True, compute_value=False + ) + query_embeds = embed_dict["normed_key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method) + if use_adapter_to_select_positive: + query_embeds = model.adapter(query_embeds) + batch_size, hidden_size = query_embeds.shape + + local_positive_inputs_keys = [k for k in batch.keys() if k.startswith("local_positive_inputs_")] + local_positive_inputs = {k.replace("local_positive_inputs_", ""): + batch.pop(k).view(batch_size * args.batch_local_positive_num, -1) + for k in local_positive_inputs_keys} + embed_dict = model.wrapped_embed_kv(separate_task=args.separate_task, compute_key=True, + compute_value=False, **local_positive_inputs) + local_positive_key_embeds = embed_dict["normed_key_embeds"] + local_positive_key_embeds = reduce_query_or_key_embeds(local_positive_key_embeds, args.key_reduce_method) + if use_adapter_to_select_positive: + local_positive_key_embeds = model.adapter(local_positive_key_embeds) + scores = torch.bmm(query_embeds.unsqueeze(dim=1), local_positive_key_embeds.view( + batch_size, args.batch_local_positive_num, hidden_size).transpose(2, 1)).squeeze(dim=1) + scores = scores + (batch["local_positive_qas_mask"] - 1) * 1e-4 + scores = torch.softmax(scores, dim=1) + + sampled_pos_local_idx = torch.multinomial(scores, 1).squeeze(dim=-1) # [batch_size] + + # sampled_local_idx: [batch_size] + # local_positive_inputs[key_input_ids/key_attention_mask]: [batch_size*max_pos_num, seq_length] + all_pos_key_input_ids = local_positive_inputs["key_input_ids"].view(batch_size, args.batch_local_positive_num, -1) + all_pos_key_attention_mask = local_positive_inputs["key_attention_mask"].view(all_pos_key_input_ids.shape) + positive_gather_indices = sampled_pos_local_idx.unsqueeze(dim=1). \ + repeat(1, all_pos_key_input_ids.shape[-1]).unsqueeze(dim=1) + positive_key_input_ids = torch.gather(all_pos_key_input_ids, 1, positive_gather_indices).squeeze(dim=1) + positive_key_attention_mask = torch.gather(all_pos_key_attention_mask, 1, positive_gather_indices).squeeze(dim=1) + # positive_key_input_ids: [bach_size, seq_len] + + if "labels_to_select" in batch: + labels_to_select = batch.pop("labels_to_select") # batch, args.batch_local_positive_num, seq_len + target_gather_indices = sampled_pos_local_idx.unsqueeze(dim=1). 
\ + repeat(1, labels_to_select.shape[-1]).unsqueeze(dim=1) + labels = torch.gather(labels_to_select, 1, target_gather_indices).squeeze(dim=1) + batch.update({"labels": labels}) + + # if args.num_values > 1 or args.use_not_exactly_true: + local_mixed_inputs_keys = [k for k in batch.keys() if k.startswith("local_mixed_inputs_")] + local_mixed_inputs = {k.replace("local_mixed_inputs_", ""): batch.pop(k) for k in local_mixed_inputs_keys} + mixed_qas_num = args.negatives_num_each_example + embed_dict = model.wrapped_embed_kv(separate_task=args.separate_task, compute_key=True, + compute_value=True, **{k: v.view(batch_size * mixed_qas_num, -1) + for k, v in local_mixed_inputs.items()}) + mixed_key_embeds = embed_dict["normed_key_embeds"] + mixed_key_embeds = reduce_query_or_key_embeds(mixed_key_embeds, args.key_reduce_method) + if use_adapter_to_select_positive: + mixed_key_embeds = model.adapter(mixed_key_embeds) + mixed_key_embeds = mixed_key_embeds.view(batch_size, mixed_qas_num, -1) + scores = torch.bmm(query_embeds.unsqueeze(dim=1), mixed_key_embeds.transpose(2, 1)).squeeze(dim=1) + # assert args.values_with_order is True + # if w/o order, shuffle the group_value_qas_indices of each example + group_value_qas_indices = scores.topk(args.num_values, dim=1).indices # [batch_size, num_values] + if not args.values_with_order: + group_value_qas_indices = group_value_qas_indices.tolist() + for value_qas_indices in group_value_qas_indices: + random.shuffle(value_qas_indices) + group_value_qas_indices = torch.tensor(group_value_qas_indices).to(scores.device) + + # assert args.num_values > 1 + # if args.num_values == 1, the input is from the sampled_pos_local_idx. + + # if args.num_values > 1: + mixed_key_input_ids = local_mixed_inputs["key_input_ids"] # [batch_size, mixed_qas_num, seq_len] + mixed_key_attention_mask = local_mixed_inputs["key_attention_mask"] + mixed_value_input_ids = local_mixed_inputs["value_input_ids"] + mixed_value_attention_mask = local_mixed_inputs["value_attention_mask"] + key_gather_indices = group_value_qas_indices.unsqueeze(dim=-1).repeat(1, 1, mixed_key_input_ids.shape[-1]) + group_key_input_ids = torch.gather(mixed_key_input_ids, 1, key_gather_indices) + group_key_attention_mask = torch.gather(mixed_key_attention_mask, 1, key_gather_indices) + value_gather_indices = group_value_qas_indices.unsqueeze(dim=-1).repeat(1, 1, mixed_value_input_ids.shape[-1]) + group_value_input_ids = torch.gather(mixed_value_input_ids, 1, value_gather_indices) + group_value_attention_mask = torch.gather(mixed_value_attention_mask, 1, value_gather_indices) + + # matching_targets + matching_targets = torch.arange(batch_size).to(model.device) + matching_targets[batch.pop("local_positive_num") == 0] = -100 + + batch.update({ + "positive_kv_inputs": { + "key_input_ids": positive_key_input_ids, + "key_attention_mask": positive_key_attention_mask + }, + "negative_kv_inputs": { + "key_input_ids": batch.pop("local_negative_inputs_key_input_ids") + .view(batch_size * args.negatives_num_each_example, -1), + "key_attention_mask": batch.pop("local_negative_inputs_key_attention_mask") + .view(batch_size * args.negatives_num_each_example, -1), + }, + "matching_targets": matching_targets, + "group_value_inputs": { + "key_input_ids": group_key_input_ids.view(batch_size * args.num_values, -1), + "key_attention_mask": group_key_attention_mask.view(batch_size * args.num_values, -1), + "value_input_ids": group_value_input_ids.view(batch_size * args.num_values, -1), + "value_attention_mask": 
group_value_attention_mask.view(batch_size * args.num_values, -1), + }, + }) + + +# not really rank-local, only prepare positive/negative qas from exists local-qas +@torch.no_grad() +def rank_exist_local_qas(args, train_dataset: QADataset, qas_to_retrieve, model, normed_answer_of_qas_to_ret, + train_data_query_embeds=None, build_mem_batch_size=1204, + embed_local_qas_batch_size=6, local_size=1024, pos_from_top=50, neg_from_top=200, + accelerator=None): + model.eval() + if args.use_fp16_rank: + half_model = copy.deepcopy(model) + half_model.eval() + model = half_model.half() + embed_local_qas_batch_size = int(embed_local_qas_batch_size * 1.5) + logger.info(f"Rank local QAs for each example to retrieve. embed_local_qas_batch_size={embed_local_qas_batch_size}") + + if train_data_query_embeds is None: + train_data_query_embeds = torch.zeros((len(train_dataset.data), model.model_dim), device='cpu', + dtype=torch.float16) + + embed_query_dataloader = train_dataset.get_query_dataloader(batch_size=build_mem_batch_size, + shuffle=False, num_workers=5) + start_idx = 0 + for query_inputs in tqdm(embed_query_dataloader, total=len(embed_query_dataloader), desc="Embed queries."): + end_idx = start_idx + len(query_inputs["query_input_ids"]) + embed_dict = model.CAT_embed_q( + input_ids=query_inputs["query_input_ids"].to(model.device), + attention_mask=query_inputs["query_attention_mask"].to(model.device), + compute_key=True, compute_value=False + ) + query_embeds = embed_dict["normed_key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method) + query_embeds = query_embeds.half().cpu() + train_data_query_embeds[start_idx: end_idx] = query_embeds + start_idx = end_idx + assert start_idx == len(train_dataset.data) + + torch.cuda.empty_cache() + + embed_local_dataloader = train_dataset.get_local_qas_dataloader(batch_size=embed_local_qas_batch_size, + shuffle=False, num_workers=10) + if accelerator is not None: + embed_local_dataloader = accelerator.prepare(embed_local_dataloader) + + all_ret_qas = [] + + start_idx = 0 + for local_qas_batch in tqdm(embed_local_dataloader): + query_ids = local_qas_batch.pop("query_ids") + bs = len(query_ids) + cur_query_embeds = train_data_query_embeds[start_idx: start_idx + bs] + cur_batch = train_dataset.data[start_idx: start_idx + bs] + start_idx = start_idx + bs + + embed_dict = model.wrapped_embed_kv( + separate_task=args.separate_task, compute_key=True, compute_value=False, + **local_qas_batch + ) + squeezed_local_key_embeds = embed_dict["normed_key_embeds"] + squeezed_local_key_embeds = reduce_query_or_key_embeds(squeezed_local_key_embeds, args.key_reduce_method) + local_key_embeds = squeezed_local_key_embeds.view(bs, -1, model.model_dim).half() + # exists local-qas should be larger than expect local-size + assert local_size <= squeezed_local_key_embeds.shape[1] + + cur_query_embeds = cur_query_embeds.cuda() + # [bs, 1, hidden] [bs, hidden, exists-local-qas-num] --> [bs, exists-local-qas-num] + scores = torch.bmm(cur_query_embeds.unsqueeze(dim=1), local_key_embeds.transpose(2, 1)).squeeze(dim=1) + top_local_indices = scores.topk(local_size, dim=1).indices + + for cur_example, cur_indices in zip(cur_batch, top_local_indices): + local_positive, local_negative = [], [] + cur_target = [normalize_answer(ans) for ans in cur_example["answer"]] + + qas_ids_of_top_local = [cur_example["local_qas"][local_idx] for local_idx in cur_indices] + # qas_of_top_local = [qas_to_retrieve[qid] for qid in qas_ids_of_top_local] + 
all_ret_qas.append([qas_to_retrieve[qid] for qid in qas_ids_of_top_local[:50]]) + + for top_idx, qa_idx in enumerate(qas_ids_of_top_local): + if normed_answer_of_qas_to_ret[qa_idx] in cur_target: + if top_idx < pos_from_top: + local_positive.append(qa_idx) + else: + if top_idx < neg_from_top: + local_negative.append(qa_idx) + elif len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative + local_negative.append(qa_idx) + + cur_example["local_positive"] = local_positive + cur_example["local_negative"] = local_negative + + assert len(all_ret_qas) == len(train_dataset.data) + matching_metric = eval_retriever(train_dataset.data, all_ret_qas, "1,2,3,4,5,10,50") + for k, v in matching_metric.items(): + logging.info({f"local_qas initial {k}": v}) + + if args.use_fp16_rank: + del model + torch.cuda.empty_cache() + logger.info(f"Local QAs ranked.") + + +# Dialog +@torch.no_grad() +def update_dialog_local_qas_to_retrieve(args, train_dataset: DialogDataset, qas_to_retrieve, model, + key_memory: List[torch.tensor], normed_answer_of_qas_to_ret, + train_data_query_embeds=None, build_mem_batch_size=1024, + query_batch_size=128, local_size=1024, pos_from_top=50, + neg_from_top=200): + model.eval() + assert type(key_memory) == list + + logger.info(f"Prepare local QAs for each example to retrieve.") + all_local_qas = [] + all_local_positive = [] + all_local_negative = [] + all_ret_qas = [] + + if train_data_query_embeds is None: + train_data_query_embeds = torch.zeros((len(train_dataset.data), model.model_dim), device='cpu', + dtype=torch.float16) + embed_query_dataloader = train_dataset.get_query_dataloader(batch_size=query_batch_size, shuffle=False, + num_workers=1, ) + start_idx = 0 + for query_inputs in tqdm(embed_query_dataloader, total=len(embed_query_dataloader), desc="Embed queries."): + end_idx = start_idx + len(query_inputs["query_input_ids"]) + embed_dict = model.CAT_embed_q( + input_ids=query_inputs["query_input_ids"].to(model.device), + attention_mask=query_inputs["query_attention_mask"].to(model.device), + compute_key=True, compute_value=False + ) + query_embeds = embed_dict["normed_key_embeds"] + query_embeds = reduce_query_or_key_embeds(query_embeds, args.key_reduce_method) + query_embeds = query_embeds.half().cpu() + train_data_query_embeds[start_idx: end_idx] = query_embeds + start_idx = end_idx + assert start_idx == len(train_dataset.data) + + torch.cuda.empty_cache() + + key_nums = sum(len(k) for k in key_memory) + logger.info(f"key-memory seg-num: {len(key_memory)}. 
all key nums: {key_nums}.") + query_batch_size = 2000 + for start_idx in tqdm(range(0, len(train_dataset.data), query_batch_size), + total=len(train_dataset.data) // query_batch_size + 1): + cur_cuda_query_embeds = train_data_query_embeds[start_idx: start_idx + query_batch_size].cuda() + + # calculate topk in each chunk -> combine all-topk -> select final topk + chunk_top_scores = [] + chunk_top_indices = [] + idx_shift = 0 + for ckm_idx, chunk_key_memory in enumerate(key_memory): + chunk_key_memory_cuda = chunk_key_memory.cuda() + chunk_topk = torch.mm(cur_cuda_query_embeds, chunk_key_memory_cuda.t()).topk(local_size, dim=1) + chunk_top_scores.append(chunk_topk.values) # chunk_topk.scores: [query_batch, local_size] + chunk_top_indices.append(chunk_topk.indices + idx_shift) + idx_shift += len(chunk_key_memory) + del chunk_key_memory_cuda + torch.cuda.empty_cache() + chunk_top_scores = torch.cat(chunk_top_scores, dim=1) # q_batch, local_size*seg_n + chunk_top_indices = torch.cat(chunk_top_indices, dim=1).tolist() # q_batch, local_size*seg_n + topk = chunk_top_scores.topk(local_size, dim=1) # q_batch, local_size + top_indices_indices = topk.indices.tolist() + top_indices = [] + for cur_indices_indices, cur_indices in zip(top_indices_indices, chunk_top_indices): + top_indices.append([cur_indices[idx] for idx in cur_indices_indices]) + + batch = train_dataset.data[start_idx: start_idx + query_batch_size] + for cur_example, cur_indices in zip(batch, top_indices): + local_positive, local_negative = [], [] + + cur_target_words = cur_example["normalized_response_remove_stop_words_list"] # a list of words + + cur_ret_qas = [] + for top_idx, qa_idx in enumerate(cur_indices): + if normed_answer_of_qas_to_ret[qa_idx] in cur_target_words: + # if normed_answer_of_qas_to_ret[qa_idx] in cur_target: # the QA's answer overlaps with response + if top_idx < pos_from_top: + local_positive.append(qa_idx) + else: + if top_idx < neg_from_top: + local_negative.append(qa_idx) + elif len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative + # if len(local_negative) < args.negatives_num_each_example: # ensure 32 local_negative + local_negative.append(qa_idx) + cur_ret_qas.append(qas_to_retrieve[qa_idx]) + + all_ret_qas.append(cur_ret_qas) + + all_local_positive.append(local_positive) + all_local_negative.append(local_negative) + + all_local_qas += top_indices + del cur_cuda_query_embeds + del topk + + torch.cuda.empty_cache() + assert len(all_ret_qas) == len(train_dataset.data) + assert len(all_local_qas) == len(train_dataset.data) == len(all_local_positive) == len(all_local_negative) + for i in range(len(train_dataset.data)): + train_dataset.data[i]["local_positive"] = all_local_positive[i] + train_dataset.data[i]["local_negative"] = all_local_negative[i] + train_dataset.data[i]["local_qas"] = all_local_qas[i] + + logger.info(f"Local QAs updated.") diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..da0eb60 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,774 @@ +import argparse +import json +import logging +import os +import pickle +import random + +import functools +import numpy as np +import torch +from emat.evaluation.eval_retriever import eval_retriever + +from emat.evaluation.exact_match import normalize_answer +from typing import List +from emat.utils import verbalise_qa +from copy import deepcopy +from transformers import MODEL_MAPPING, T5Config, T5Tokenizer, AutoConfig, AutoTokenizer, AutoModelForMultipleChoice +from tqdm.auto import tqdm +from itertools 
import chain + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def reduce_query_or_key_embeds(qk_embeds, key_reduce_method): + batch_size, key_nums, hidden_size = qk_embeds.shape + if key_reduce_method == "concat": + reduced_embeds = qk_embeds.view(batch_size, key_nums * hidden_size) + elif key_reduce_method == "avg": + reduced_embeds = qk_embeds.sum(dim=1) / key_nums + elif key_reduce_method == "sum": + reduced_embeds = qk_embeds.sum(dim=1) + else: + raise NotImplementedError(f"Reduce method ``{key_reduce_method}`` is not defined.") + + return reduced_embeds + + +def load_reranker(model_name_or_path="./data/models/rerankers/reranker_multi_xxlarge"): + logging.info(f'Loading rerank model from: {model_name_or_path}') + config = AutoConfig.from_pretrained(model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, do_lower_case=True) + model = AutoModelForMultipleChoice.from_pretrained( + model_name_or_path, + from_tf=bool(".ckpt" in model_name_or_path), + config=config, + ) + model = model.eval() + return model, tokenizer + + +def get_key_value_ae_target(qas, tokenizer, key_ae_target, value_ae_target, max_target_length, prefix=""): + if value_ae_target == "ans": + targets_of_value = [qa["answer"][0] for qa in qas] + else: # question_ans + targets_of_value = [f'question: {qa["question"]} answer: {qa["answer"][0]}' for qa in qas] + if key_ae_target == "question_ans": + targets_of_key = [f'question: {qa["question"]} answer: {qa["answer"][0]}' for qa in qas] + else: # question + targets_of_key = [prefix + qa["question"] for qa in qas] + + with tokenizer.as_target_tokenizer(): # setup the tokenizer for targets + key_labels = tokenizer(targets_of_key, max_length=max_target_length, + padding=True, truncation=True, return_tensors="pt") + value_labels = tokenizer(targets_of_value, max_length=max_target_length, + padding=True, truncation=True, return_tensors="pt") + return {"key_labels_input_ids": process_labels(key_labels, tokenizer), + "value_labels_input_ids": process_labels(value_labels, tokenizer)} + + +def get_key_value_encoder_inputs(qas, separate_task, tokenizer, max_source_length, + prefix="", only_return_key_inputs=False, value_input_is_qa=False): + # Used to get the input of Key-Value Encoder, qas are from PAQ-L1 + if separate_task: + key_inputs = ["question: " + qa["question"] for qa in qas] + key_inputs = tokenizer(key_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + if only_return_key_inputs: + return {"key_input_ids": key_inputs["input_ids"], + "key_attention_mask": key_inputs["attention_mask"]} + else: + if value_input_is_qa: + value_inputs = [f'question: {qa["question"]} answer: {qa["answer"][0]}' for qa in qas] + value_inputs = tokenizer(value_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + else: + value_inputs = ["answer: " + qa["answer"][0] for qa in qas] + value_inputs = tokenizer(value_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"key_input_ids": key_inputs["input_ids"], + "key_attention_mask": key_inputs["attention_mask"], + "value_input_ids": value_inputs["input_ids"], + "value_attention_mask": value_inputs["attention_mask"]} + else: + key_value_inputs = [prefix + verbalise_qa(qa["question"], qa["answer"][0]) for qa in qas] + key_value_inputs = tokenizer(key_value_inputs, max_length=max_source_length, + padding=True, 
truncation=True, return_tensors="pt") + return {"key_value_input_ids": key_value_inputs["input_ids"], + "key_value_attention_mask": key_value_inputs["attention_mask"]} + + +def get_nli_group_value_inputs(examples, tokenizer, max_source_length): + group_key_inputs = [ex["retrieved_key_seqs"] for ex in examples] + group_key_inputs = list(chain(*group_key_inputs)) + group_key_inputs = tokenizer(group_key_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + group_value_inputs = [ex["retrieved_value_seqs"] for ex in examples] + group_value_inputs = list(chain(*group_value_inputs)) + group_value_inputs = tokenizer(group_value_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"key_input_ids": group_key_inputs["input_ids"], + "key_attention_mask": group_key_inputs["attention_mask"], + "value_input_ids": group_value_inputs["input_ids"], + "value_attention_mask": group_value_inputs["attention_mask"]} + + +def get_query_encoder_inputs(qas, tokenizer, max_source_length, prefix=""): + # Used to get the input of Query Encoder, qas are from NaturalQuestion + query_inputs = [prefix + qa["question"] for qa in qas] + query_inputs = tokenizer(query_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"query_input_ids": query_inputs["input_ids"], + "query_attention_mask": query_inputs["attention_mask"]} + + +def get_nli_input_seq(item): + return f"hypothesis: {item['hypothesis']} premise: {item['premise']}" + + +def get_query_encoder_inputs_nli(cases, tokenizer, max_source_length): + query_inputs = [case["input_seq"] for case in cases] + query_inputs = tokenizer(query_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"query_input_ids": query_inputs["input_ids"], + "query_attention_mask": query_inputs["attention_mask"]} + + +label2str = {0: "entailment", 1: "neutral", 2: "contradiction"} + + +def get_key_value_encoder_inputs_nli(cases, tokenizer, max_source_length, only_return_key_inputs=False): + key_inputs = [case["input_seq"] for case in cases] + key_inputs = tokenizer(key_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + if only_return_key_inputs: + return {"key_input_ids": key_inputs["input_ids"], + "key_attention_mask": key_inputs["attention_mask"]} + else: + value_inputs = [label2str[case["label"]] for case in cases] + value_inputs = tokenizer(value_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + return {"key_input_ids": key_inputs["input_ids"], + "key_attention_mask": key_inputs["attention_mask"], + "value_input_ids": value_inputs["input_ids"], + "value_attention_mask": value_inputs["attention_mask"]} + + +def get_qa_inputs(qas, tokenizer, max_source_length, max_target_length, prefix="", + targets=None): + # Used to get the normal inputs of QA, qas are from NaturalQuestion + # Normal inputs and outputs + model_inputs = [prefix + qa["question"] for qa in qas] + model_inputs = tokenizer(model_inputs, max_length=max_source_length, + padding=True, truncation=True, return_tensors="pt") + if targets is None: + targets = [qa["answer"][0] for qa in qas] + elif targets == "$random$": + targets = [random.choice(qa["answer"]) for qa in qas] + with tokenizer.as_target_tokenizer(): + targets = tokenizer(targets, max_length=max_target_length, + padding=True, truncation=True, return_tensors="pt") + model_inputs["labels"] = 
process_labels(targets, tokenizer) + return model_inputs + + +def process_labels(labels, tokenizer, label_pad_token_id=-100, pad_to_multiple_of=None): + if getattr(labels, "input_ids", None) is not None: + input_ids = labels["input_ids"] + bsz, label_length = input_ids.size() + else: + input_ids = labels + bsz = len(input_ids) + label_length = len(input_ids[0]) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. + # if args.ignore_pad_token_for_loss: + input_ids[input_ids == tokenizer.pad_token_id] = label_pad_token_id + + if pad_to_multiple_of is not None: + max_label_length = ( + (label_length + pad_to_multiple_of - 1) // pad_to_multiple_of * pad_to_multiple_of + ) + remainder = max_label_length - label_length + if remainder > 0: + pad_ids = torch.full( + (bsz, remainder), + fill_value=label_pad_token_id, + dtype=input_ids.dtype, + device=input_ids.device + ) + input_ids = torch.cat([input_ids, pad_ids], dim=1) + + return input_ids + + +def postprocess_text(preds, labels): + preds = [pred.strip() for pred in preds] + labels = [[label.strip()] for label in labels] + return preds, labels + + +def load_model(model_class, args): + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + if args.resume_training and args.output_dir is not None: + args.model_name_or_path = os.path.join(args.output_dir, "latest_ckpt") + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + config = T5Config.from_pretrained(args.model_name_or_path) + model = model_class.from_pretrained(args.model_name_or_path) + return config, tokenizer, model + + config = T5Config.from_pretrained(args.model_name_or_path) + update_CAT_config_from_args(config, args) + print(config) + + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) + # model = model_class(config) + model, load_info = model_class.from_pretrained(args.model_name_or_path, config=config, output_loading_info=True) + state_dict = torch.load(os.path.join(args.model_name_or_path, "pytorch_model.bin")) + logging.info(f"model-load-info: {load_info}") + + manually_initialized_params = [] + if args.not_share_encoder and "kv_encoder.final_layer_norm.weight" not in state_dict.keys(): + # "kv_encoder.final_layer_norm.weight not in state-dict" means the loaded model is share-encoder. 
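+        # copy the shared encoder weights into the separate key-value encoder,
+        # skipping query-only modules such as the qk_scorer head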
+ kv_encoder_state_dict = dict() + for k, v in state_dict.items(): + if k in ['encoder.qk_scorer.bias', 'encoder.qk_scorer.weight']: + continue + if k.startswith("encoder."): + kv_encoder_state_dict[f"{k[len('encoder.'):]}"] = deepcopy(v) + manually_initialized_params.append(f"kv_{k}") + model.kv_encoder.load_state_dict(kv_encoder_state_dict, strict=True) + logging.info("Not share encoder, and initialize Key-Value encoder using CAT-encoder.") + else: + logging.info("Share the Key-Value encoder and CAT encoder") + if "encoder.key_layer_norm.weight" not in state_dict.keys(): + logging.info("Initialize key_layer_norm parameters.") + key_layer_norm_state_dict = dict() + for k, v in state_dict.items(): + if k.startswith("encoder.final_layer_norm."): + k = k.replace("encoder.final_layer_norm.", "") + key_layer_norm_state_dict[k] = deepcopy(v) + manually_initialized_params.append(f"encoder.key_layer_norm.{k}") + model.encoder.key_layer_norm.load_state_dict(key_layer_norm_state_dict) + logging.info(f"manually initialized parameters: {manually_initialized_params}") + # miss_keys = load_info["missing_keys"] + model.resize_token_embeddings(len(tokenizer)) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + assert config.value_fusion_method == args.value_fusion_method + return config, tokenizer, model + + +def save_model(model, save_dir, accelerator=None, tokenizer=None, arguments=None): + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + + if accelerator is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save) + else: + model.save_pretrained(save_dir) + + if tokenizer is not None: + if accelerator is None: + tokenizer.save_pretrained(save_dir) + elif accelerator.is_local_main_process: + tokenizer.save_pretrained(save_dir, save_function=accelerator.save) + + if arguments is not None: + json.dump(vars(arguments), open(os.path.join(save_dir, "args.json"), "w"), indent=4) + + +def update_CAT_config_from_args(config, args): + # Key-value memory related config (restore from CAT configs if not specified) + config.prefix_length = getattr(config, "prefix_length", args.prefix_length) + config.key_layer = getattr(config, "key_layer", args.key_layer) + + if args.d_key is None: + assert args.value_fusion_method is None or "+" in args.value_fusion_method + args.d_key = config.d_model * args.prefix_length + + if args.key_encoder_type != "prefix": + assert args.d_key is not None + + if args.d_key is None and getattr(config, "d_key", None) is None: + config.d_key = config.d_model + else: + config.d_key = getattr(config, "d_key", args.d_key) + + config.key_encoder_type = getattr(config, "key_encoder_type", args.key_encoder_type) + config.value_layer = args.value_layer if args.value_layer is not None else config.value_layer + config.cat_layer = getattr(config, "cat_layer", args.cat_layer) + config.num_values = args.num_values if args.num_values is not None else config.num_values + config.use_two_prefix = getattr(config, "use_two_prefix", args.use_two_prefix) + config.not_share_encoder = args.not_share_encoder + # config.value_fusion_method = getattr(config, "value_fusion_method", args.value_fusion_method) + config.value_fusion_method = args.value_fusion_method + + if args.adapter is not None: + config.adapter = args.adapter + config.adapter_out_dim = args.adapter_out_dim + + +def 
find_positive_and_k_negative(top_indices, target_answers: List[List], qas_to_retrieve_from, k=8): + positive_idx_of_qas: List[List] = [] + negative_idx_of_qas: List[List] = [] + with_positive_flag = [] + for indexes, target in zip(top_indices, target_answers): + normalized_target = [normalize_answer(t) for t in target] + positive_indexes = [] + negative_indexes = [] + for index in indexes: + retrieved_qa = qas_to_retrieve_from[index] + retrieved_answer = normalize_answer(retrieved_qa["answer"][0]) + if retrieved_answer in normalized_target: + # if len(negative_indexes) < k: + positive_indexes.append(index) + else: + if len(negative_indexes) < k: + negative_indexes.append(index) + if len(positive_indexes) > 0 and len(negative_indexes) >= k: + break + if len(positive_indexes) > 0: + positive_idx_of_qas.append(positive_indexes) + negative_idx_of_qas.append(negative_indexes) + with_positive_flag.append(True) + else: + with_positive_flag.append(False) + return positive_idx_of_qas, negative_idx_of_qas, with_positive_flag + + +class QAs: + def __init__(self, max_qa_idx, paq_data_root, is_iter=False): + self.paq_data_root = paq_data_root + self.max_qa_idx = max_qa_idx + self.__iter_idx = 0 + self.__is_iter = is_iter + + @functools.lru_cache(maxsize=int(5e6)) + def __getitem__(self, item): + p = os.path.join(self.paq_data_root, f"{item}") + if item >= self.max_qa_idx: + raise IndexError + return pickle.load(open(p, 'rb')) + + def __len__(self): + return self.max_qa_idx + + def __not_cached_getitem(self, item): + if item >= self.max_qa_idx: + raise IndexError + else: + return pickle.load(open(os.path.join(self.paq_data_root, f"{item}"), 'rb')) + + def __iter__(self): + if self.__is_iter: + return self + else: + return QAs(paq_data_root=self.paq_data_root, max_qa_idx=self.max_qa_idx, is_iter=True) + + def __next__(self): + try: + item = self.__not_cached_getitem(self.__iter_idx) + except IndexError: + raise StopIteration + self.__iter_idx += 1 + return item + + +class CATArgs: + def __init__(self, exp_type=None): + self.exp_type = exp_type + self.parser: argparse.ArgumentParser = argparse.ArgumentParser(description="CAT Arguments") + self.parser.add_argument("--project_name", type=str, required=True, help="Project name.") + self.parser.add_argument("--exp_name", type=str, required=True, help="Experiment name.") + + self.add_basic_arguments() + self.add_cat_arguments() + + if self.exp_type == "build_kvm": + self.add_build_kvm_arguments() + elif self.exp_type == "pretrain": + self.add_pretrain_arguments() + elif self.exp_type == "qa_cat": + self.add_qa_arguments() + elif self.exp_type == "nli_cat": + self.add_nli_arguments() + elif self.exp_type == "nli_pretrain": + self.add_nli_pretrain_arguments() + elif self.exp_type == "NLU_cat": + self.add_NLU_arguments() + elif self.exp_type == "dialog_cat": + self.add_dialog_arguments() + else: + raise ValueError(f"Experiment type {self.exp_type} is not defined.") + + def add_cat_arguments(self): + group = self.parser.add_argument_group("Arguments for CAT model.") + group.add_argument("--prefix_length", type=int, default=None, help="Length of the prefix.") + group.add_argument("--use_two_prefix", action="store_true", + help="Use two independent prefixes to represent key and value.") + group.add_argument("--d_key", type=int, default=None, help="The dimension of key embeddings.") + group.add_argument("--num_values", type=int, default=1, help="Number of values returned from KV memory.") + group.add_argument("--value_layer", type=int, default=None, help="The layer 
that imports Value embedding.") + group.add_argument("--key_layer", type=int, default=None, help="The layer that computes the Key embedding.") + group.add_argument("--key_encoder_type", type=str, default="linear", choices=["linear", "conv", "prefix"], + help="The type of the key encoder module.") + group.add_argument("--separate_task", action="store_true", help="Separate the input of Key-AE and Value-AE.") + + group.add_argument("--cat_layer", type=int, default=None, help="The layer that cats key-embedding.") + + group.add_argument("--not_share_encoder", action="store_true", help="Do not share Key-Value encoder with CAT.") + + group.add_argument("--adapter", help="can be assigned to `linear", required=False, type=str) + group.add_argument("--adapter_out_dim", required=False, type=int) + # group.add_argument("--adapter_ckpt_path", required=False, type=str) + + def add_nli_arguments(self): + group = self.parser.add_argument_group("Arguments for training CAT-NLI model.") + + # training args + # Key-Value Memory args + group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.", + choices=["concat", "sum", "avg"]) + group.add_argument("--kvm_fp16", action="store_true", help="FP16 Key-Value Memory") + group.add_argument("--retrieve_strategy", type=str, default="bm25", required=False) + group.add_argument("--do_pretrain", action="store_true") + group.add_argument("--add_ae_weight", type=float, default=0.0) + group.add_argument("--use_triple_loss", action="store_true") + + group.add_argument("--filter_type", type=str, default=None, required=False, help="filter BM25-results.") + group.add_argument("--select_type", type=str, default=None, required=False, help="select BM25-results.") + group.add_argument("--local_size", type=int, default=512, required=False) + + group.add_argument("--key_matching_weight", type=float, default=0.0) + group.add_argument("--add_vae", action="store_true", help="default only use kae if add_ae_weight > 0.") + + # CAT-NLI architecture settings + group.add_argument("--dataset_name", type=str, choices=["mnli", "snli"], required=False, default="mnli") + group.add_argument("--decoder_only_attend_on_prefix", action="store_true", + help="Set the decoder only attend on the prefix part of encoder's output.") + group.add_argument("--value_fusion_method", type=str, required=True, help="Assign how to use Value.") + group.add_argument("--values_with_order", action="store_true", + help="when num_values > 1, if we put values by the similarity order.") + group.add_argument("--group_cases_by_label", action="store_true", + help="select_type must be ``select_different_labels`` ") + group.add_argument("--order_strategy", type=str, default="order_by_label", required=False, ) + group.add_argument("--order_by_scores", action="store_true") + + # Continue pretraining task settings + group.add_argument("--value_repr", type=str, default="label") + group.add_argument("--key_repr", type=str, default="hyp_prem") + # "mnli hypothesis: xxx premise: " + group.add_argument("--do_test", action="store_true") + + def add_NLU_arguments(self): + group = self.parser.add_argument_group("Arguments for General-CAT-NLU model.") # for GLUE and ... 
+ group.add_argument("--dataset_name", type=str, required=False, default="snli", + choices=["mnli", "snli", "commonsense_qa"], ) + group.add_argument("--retrieve_strategy", type=str, default="bm25", required=False) + group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.", + choices=["concat", "sum", "avg"]) + group.add_argument("--do_pretrain", action="store_true") + group.add_argument("--kvm_fp16", action="store_true", help="FP16 Key-Value Memory") + group.add_argument("--add_ae_weight", type=float, default=0.0) + group.add_argument("--use_triple_loss", action="store_true") + group.add_argument("--filter_type", type=str, default=None, required=False, help="filter BM25-results.") + group.add_argument("--select_type", type=str, default=None, required=False, help="select BM25-results.") + group.add_argument("--local_size", type=int, default=64, required=False) + group.add_argument("--key_matching_weight", type=float, default=0.0) + group.add_argument("--add_vae", action="store_true", help="default only use kae if add_ae_weight > 0.") + group.add_argument("--decoder_only_attend_on_prefix", action="store_true", + help="Set the decoder only attend on the prefix part of encoder's output.") + group.add_argument("--value_fusion_method", type=str, required=True, help="Assign how to use Value.") + group.add_argument("--values_with_order", action="store_true", + help="when num_values > 1, if we put values by the similarity order.") + group.add_argument("--group_cases_by_label", action="store_true", + help="select_type must be ``select_different_labels`` ") + group.add_argument("--order_strategy", type=str, default="order_by_label", required=False, ) + group.add_argument("--order_by_scores", action="store_true") + group.add_argument("--do_test", action="store_true") + + def add_qa_arguments(self): + group = self.parser.add_argument_group("Arguments for training CAT-QA model.") + # updated in 2022-06-04 + group.add_argument("--build_mem_batch_size", type=int, default=2048) + group.add_argument("--batch_local_positive_num", type=int, default=5) + group.add_argument("--truncate_exists_local_qa", type=int, required=False, default=None) + group.add_argument("--use_fp16_rank", action="store_true") + group.add_argument("--use_fp16_kvm", action="store_true", help="not implement") + group.add_argument("--PAQ_size", type=int, required=False, help="truncate PAQ to target size.") + group.add_argument("--do_test", action="store_true") + group.add_argument("--query_batch_size", type=int, required=True) + group.add_argument("--only_key_matching_n_epoch", type=int, required=False, default=-1) + group.add_argument("--gen_target_is_key_match_target", action="store_true") + + # QA args + group.add_argument("--qa_data_name", type=str, help="choose data files from pre-defined ``DATA_PATH``") + + # training args + group.add_argument("--search_positive_in_top_k", type=int, required=False, default=2048, + help="Search positives to train the key-mathing task.") + group.add_argument("--hard_negative_num", type=int, required=False, default=12) + group.add_argument("--least_negative_num_per_batch", type=int, required=False, default=64) + group.add_argument("--select_positive_strategy", type=str, required=True, + help="The strategy of selecting one positive example for HardEM training.") + group.add_argument("--faiss_efsearch", type=int, required=False, default=128, help="hnsw ef_search parameter") + group.add_argument("--gen_weight", type=float, required=False, default=1.0, + 
help="Answer generation loss weight.") + group.add_argument("--match_weight", type=float, required=False, default=1.0, + help="Key matching loss weight.") + group.add_argument("--repaq_supervision_epoch", type=int, required=False, default=-1, + help="Use RePAQ's retrieval results as Key-matching supervision." + "Where we do not change the local target and, because we only prepare" + "one ``cur_positive_qa``/``local_positive``, the local negative is also fixed") + group.add_argument("--only_rank_exists_local_qa", action="store_true", + help="Do not collect Local-QAs from entire PAQ-L1, only rank the exists Local-QAs" + "that retrieved by RePAQ (x large)") + group.add_argument("--negatives_num_each_example", type=int, required=False, default=50, + help="sample negatives from local qas") + + # Key-Value Memory args + group.add_argument("--kvm_dir", type=str, default=None, required=False, + help="The directory of Key-Value Memory") + group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.", + choices=["concat", "sum", "avg"]) + group.add_argument("--qas_to_retrieve_from", type=str, help="QAs corresponding to Key-Value Memory") + group.add_argument("--kvm_fp16", action="store_true", help="Load FP16 Key-Value Memory") + group.add_argument("--kvm_seg_n", type=int, required=False, default=1, help="when key-memory is too large, " + "segment it to kvm_seg_n pieces") + + # CAT-QA architecture settings + group.add_argument("--update_kv_embeds", action="store_true", help="Re-embed Key and Value while training.") + group.add_argument("--local_size", type=int, required=False, help="Number of local QAs to retrieve.") + group.add_argument("--update_local_qas", action="store_true", help="Update local QAs every epoch.") + group.add_argument("--update_local_target_each_batch", action="store_true", + help="Update positive and negative each batch. Otherwise, each epoch.") + group.add_argument("--use_not_exactly_true", action="store_true", + help="Input the top-1 though it is not exactly true when training the generation. 
" + "The not exactly true example will not supervise the Key-Mathing.") + group.add_argument("--decoder_only_attend_on_prefix", action="store_true", + help="Set the decoder only attend on the prefix part of encoder's output.") + group.add_argument("--value_fusion_method", type=str, required=True, help="Assign how to use Value.") + # group.add_argument("--update_kv_batch_size", default=1024, help="Batch size of KV Re-Embed.") + group.add_argument("--try_to_put_one_positive_in_values", action="store_true", + help="when num_values > 1, if we force to put at least one positive QA to Values") + group.add_argument("--values_with_order", action="store_true", + help="when num_values > 1, if we put values by the similarity order.") + + group.add_argument("--pos_from_top", type=int, required=False, default=50, + help="for each batch, select lexical-positive QAs from ranked-local-QAs from top-X") + + group.add_argument("--rerank_retrieved_values", action="store_true") + + # Continue pretraining task settings + group.add_argument("--add_ae_task", action="store_true", help="Add the pretraining(Auto-Encoding) task.") + group.add_argument("--ae_weight", type=float, required=False, default=0.1, + help="When set --add_ae_task, the auto-encoding loss weight.") + group.add_argument("--ae_batch_size", type=int, default=None, help="The batch size of AE task.") + group.add_argument("--value_ae_target", type=str, default="ans", choices=["ans", "question_ans"]) + group.add_argument("--key_ae_target", type=str, default="question_ans", choices=["question", "question_ans"]) + + # adapter training + group.add_argument("--only_train_adapter", action="store_true") + group.add_argument("--use_adapter_to_select_positive_after_k_epoch", type=int, + default=float("inf"), required=False) + + def add_dialog_arguments(self): + group = self.parser.add_argument_group("Arguments for training CAT-Dialog model.") + # updated in 2022-06-18 + group.add_argument("--build_mem_batch_size", type=int, default=2048) + group.add_argument("--batch_local_positive_num", type=int, default=5) + group.add_argument("--do_test", action="store_true") + group.add_argument("--query_batch_size", type=int, required=True) + group.add_argument("--add_persona", action="store_true") + group.add_argument("--add_topic", action="store_true") + group.add_argument("--update_kv_embeds", action="store_true", help="Re-embed Key and Value while training.") + group.add_argument("--eval_every_n_steps", required=False, default=None, type=int) + group.add_argument("--shortest_answer_len", required=False, default=None, type=int) + + # training args + group.add_argument("--select_positive_strategy", type=str, required=False, default="softmax_sample", + help="The strategy of selecting one positive example for HardEM training.") + group.add_argument("--faiss_efsearch", type=int, required=False, default=128, help="hnsw ef_search parameter") + group.add_argument("--gen_weight", type=float, required=False, default=1.0, + help="Answer generation loss weight.") + group.add_argument("--match_weight", type=float, required=False, default=1.0, + help="Key matching loss weight.") + group.add_argument("--negatives_num_each_example", type=int, required=False, default=50, + help="sample negatives from local qas") + group.add_argument("--qa_data_name", type=str, help="choose data files from pre-defined ``DATA_PATH``") + + # Key-Value Memory args + group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.", + choices=["concat", "sum", "avg"]) + 
group.add_argument("--qas_to_retrieve_from", type=str, help="QAs corresponding to Key-Value Memory") + group.add_argument("--kvm_fp16", action="store_true", help="Load FP16 Key-Value Memory") + group.add_argument("--kvm_seg_n", type=int, required=False, default=1, help="when key-memory is too large, " + "segment it to kvm_seg_n pieces") + + group.add_argument("--local_size", type=int, required=False, help="Number of local QAs to retrieve.") + group.add_argument("--decoder_only_attend_on_prefix", action="store_true", + help="Set the decoder only attend on the prefix part of encoder's output.") + group.add_argument("--value_fusion_method", type=str, required=True, help="Assign how to use Value.") + group.add_argument("--values_with_order", action="store_true", + help="when num_values > 1, if we put values by the similarity order.") + group.add_argument("--pos_from_top", type=int, required=False, default=128, + help="for each batch, select lexical-positive QAs from ranked-local-QAs from top-X") + + def add_nli_pretrain_arguments(self): + group = self.parser.add_argument_group("Arguments for pretraining NLI-CAT.") + group.add_argument("--value_ae_weight", type=float, default=1.0, + help="Weight for the Auto-Encoding loss of Value.") + group.add_argument("--key_ae_weight", type=float, default=1.0, + help="Weight for the Auto-Encoding loss of Key.") + group.add_argument("--train_value", action="store_true") + group.add_argument("--train_key", action="store_true") + group.add_argument("--separate_decode", action="store_true") + + group.add_argument("--value_fusion_method", type=str, required=True, + help="Assign how to use Value. Preassigned in pretrain.") + + def add_pretrain_arguments(self): + group = self.parser.add_argument_group("Arguments for pretraining CAT.") + # add in 2022-06-19 + group.add_argument("--value_input_is_qa", action="store_true") + + group.add_argument("--pretrain_data_name", type=str, help="choose data files from pre-defined ``DATA_PATH``") + group.add_argument("--value_ae_weight", type=float, default=1.0, + help="Weight for the Auto-Encoding loss of Value.") + group.add_argument("--key_ae_weight", type=float, default=1.0, + help="Weight for the Auto-Encoding loss of Key.") + group.add_argument("--value_ae_target", type=str, default="ans", choices=["ans", "question_ans"]) + group.add_argument("--key_ae_target", type=str, default="question_ans", choices=["question", "question_ans"]) + group.add_argument("--train_value", action="store_true") + group.add_argument("--train_key", action="store_true") + group.add_argument("--pretrain_multi_values", action="store_true") + group.add_argument("--value_fusion_method", type=str, required=False, help="Assign how to use Value.") + group.add_argument("--decoder_only_attend_on_prefix", action="store_true", + help="Set the decoder only attend on the prefix part of encoder's output.") + group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.", + choices=["concat", "sum", "avg"]) + group.add_argument("--gen_weight", type=float, default=1.0) + group.add_argument("--value_with_self_prop", type=float, default=1.0) + group.add_argument("--values_with_order", action="store_true", + help="when num_values > 1, if we put values by the similarity order.") + + def add_build_kvm_arguments(self): + group = self.parser.add_argument_group("Arguments for building Key-Value Memory.") + group.add_argument("--embed_key", action="store_true") + group.add_argument("--embed_value", action="store_true") + 
group.add_argument("--embed_as_fp16", action="store_true") + group.add_argument("--key_reduce_method", type=str, required=True, help="The scheduler type to use.", + choices=["concat", "sum", "avg"]) + group.add_argument("--embed_data_name", type=str, help="choose data files from pre-defined ``DATA_PATH``") + + def parse_args(self, print_args=True): + args = self.parser.parse_args() + if print_args: + args_str = json.dumps(vars(args), indent=4, ensure_ascii=False) + logging.info(f"Show the arguments: \n{args_str}") + + assert args.value_layer >= args.key_layer + if args.num_values > 1: + assert args.value_fusion_method in ["cat_v", "cat_kv", "cat_k+v", "cat_avgk+v", "cat_k_delay+v", + "cat_k+v_g(kq)", "cat_k_delay+v_g(kv)", + "async_cat_k+v", "async_cat_k_delay+v"] + if "delay" not in args.value_fusion_method: + if args.cat_layer is None and "k" in args.value_fusion_method: + assert args.key_layer == args.value_layer + elif "async_cat_k+v" == args.value_fusion_method: + assert args.cat_layer == args.value_layer + + if self.exp_type == "pretrain": + assert args.train_key or args.train_value, "at least train one of them" + assert not os.path.exists(os.path.join(args.output_dir, "best_ckpt")), \ + "The experiment has done before. Clear the dir before running" + else: + assert self.exp_type in ["build_kvm", "qa_cat", "nli_cat", "nli_pretrain", "NLU_cat", "dialog_cat"] + + return args + + def add_basic_arguments(self): + group = self.parser.add_argument_group("Arguments for basic settings.") + # generation args + group.add_argument("--num_beams", type=int, default=None, + help="Number of beams to use for evaluation. This argument will be passed to " + "``model.generate``, which is used during ``evaluate`` and ``predict``.") + # data args + group.add_argument("--source_prefix", type=str, default=None, + help="A prefix to add before every source text (useful for T5 models).") + group.add_argument("--max_source_length", type=int, default=1024, + help="The maximum total input sequence length after tokenization." + "Sequences longer than this will be truncated, sequences shorter will be padded.") + group.add_argument("--max_target_length", type=int, default=128, + help="The maximum total sequence length for target text after tokenization. " + "Sequences longer than this will be truncated, " + "sequences shorter will be padded. during ``evaluate`` and ``predict``.") + group.add_argument("--val_max_target_length", type=int, default=None, + help="The maximum total sequence length for validation target text after tokenization." + "Sequences longer than this will be truncated, sequences shorter will be padded." + "Will default to `max_target_length`. This argument is also used to override the " + "``max_length`` param of ``model.generate``, " + "which is used during ``evaluate`` and ``predict``.") + group.add_argument("--pad_to_max_length", type=bool, default=False, + help="Whether to pad all samples to model maximum sentence length. If False, will pad " + "the samples dynamically when batching to the maximum length in the batch. 
More" + "efficient on GPU but very bad for TPU.", ) + + # training args + group.add_argument("--ignore_pad_token_for_loss", type=bool, default=True, + help="Whether to ignore the tokens corresponding to padded " + "labels in the loss computation") + group.add_argument("--preprocessing_num_workers", type=int, default=None, + help="The number of processes to use for the preprocessing.") + group.add_argument("--model_name_or_path", type=str, required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.") + group.add_argument("--config_name", type=str, default=None, + help="Pretrained config name or path if not the same as model_name") + group.add_argument("--tokenizer_name", type=str, default=None, + help="Pretrained tokenizer name or path if not the same as model_name") + group.add_argument("--use_slow_tokenizer", action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).") + group.add_argument("--per_device_train_batch_size", type=int, required=True, + help="Batch size (per device) for the training dataloader.") + group.add_argument("--per_device_eval_batch_size", type=int, default=8, + help="Batch size (per device) for the evaluation dataloader.") + + group.add_argument("--learning_rate", type=float, default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.") + group.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + group.add_argument("--num_train_epochs", type=int, default=20, + help="Total number of training epochs to perform.") + group.add_argument("--early_stop_patience", type=int, default=1000000, + help="Early stop if the performance does not improve for this number of epochs .") + + group.add_argument("--max_train_steps", type=int, default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.") + group.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + group.add_argument("--lr_scheduler_type", type=str, default="linear", help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", + "constant_with_warmup"], ) + group.add_argument("--num_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler.") + group.add_argument("--output_dir", type=str, required=True, help="Where to store the final model.") + group.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + group.add_argument("--model_type", type=str, default=None, help="Model type to use if training from scratch.", + choices=MODEL_TYPES) + + group.add_argument("--do_train", action="store_true", help="Whether to train the model on the train set.") + group.add_argument("--do_eval", action="store_true", help="Whether to evaluate on the dev set.") + group.add_argument("--eval_freq", type=int, default=1, + help="Frequency of evaluation on the dev set (if do_eval is True).") + group.add_argument("--resume_training", action="store_true", help="Resume training from the latest checkpoint.") + group.add_argument("--freeze_t5_params", action="store_true", help="Freeze the original T5 parameters.") + group.add_argument("--per_epoch_eval_times", type=int, default=1, + help="do eval many times per epoch", required=False)