MERGE: Fast Private Text Generation

This repository contains the source code of the paper "MERGE: Fast Private Text Generation" (AAAI 2024, arXiv:2305.15769). MERGE is a framework for two-party private text generation with existing LMs and LLMs, built on fully homomorphic encryption (FHE) and secure multi-party computation (MPC). With two optimizations to the Transformer architecture, MERGE achieves a 26.5x speedup at sequence length 512 and reduces communication by 80%, yielding up to a 10x speedup over existing state-of-the-art private inference frameworks.

This repository builds on the following projects:

  • CrypTen, an MPC library for PyTorch
  • MPCFormer, a framework for fast private inference of Transformer models
  • THE-X, a framework for private inference of Transformer models

We reproduce MPCFormer and THE-X for three generative language models (the GPT series, T5, and BART) in this repository.

Introduction to the MERGE Framework

The drastic increase in language models' parameters has led to a new trend of deploying models on cloud servers, raising growing concerns about private inference for Transformer-based models. Existing two-party privacy-preserving techniques, however, only consider natural language understanding (NLU) scenarios. Private inference in natural language generation (NLG), crucial for applications like translation and code completion, remains underexplored. In addition, previous privacy-preserving techniques suffer from convergence issues during model training and exhibit poor inference speed when used with NLG models, because they neglect the time-consuming operations in auto-regressive generation. To address these issues, we propose MERGE, a fast private text generation framework for Transformer-based language models. MERGE reuses the output hidden state as the word embedding to bypass the embedding computation, and reorganizes the linear operations in the Transformer module to accelerate the forward procedure. Extensive experiments show that MERGE achieves a 26.5x speedup over the vanilla encrypted model at sequence length 512 and reduces communication cost by 80%, with up to a 10x speedup over state-of-the-art approximated models.

Inference Bottleneck of Private Generation

Private auto-regressive generation involves four time-consuming operations (the last two are illustrated in the sketch after this list):

  • Softmax
  • Linear Computation
  • Embedding Layer
  • Sampling and Generation
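
To see why the last two items dominate auto-regressive generation: under MPC the token id itself is secret, so an embedding lookup is typically computed as a one-hot vector multiplied with the embedding matrix, and greedy sampling needs comparisons over the whole vocabulary at every generated token. The snippet below is only a plain-PyTorch illustration of these two operations, not the encrypted implementation, and the sizes are examples.

import torch

vocab_size, d_model = 50257, 768
embedding = torch.randn(vocab_size, d_model)

def private_style_embed(token_id):
    # Under MPC the token id is secret, so the lookup becomes a full
    # [1, vocab] x [vocab, d] matrix product instead of a cheap index.
    one_hot = torch.zeros(1, vocab_size)
    one_hot[0, token_id] = 1.0
    return one_hot @ embedding          # O(vocab * d) work per generated token

def private_style_argmax(logits):
    # Greedy sampling needs comparisons across the whole vocabulary,
    # which are expensive when run as secure protocols.
    return logits.argmax(dim=-1)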

MERGE

We propose Embedding Resending (ER) to speed up auto-regressive generation, and the Merge Module (MM) to reduce the inference time of the softmax and linear operations.
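
The idea behind ER can be sketched with a hypothetical decoding loop: instead of sampling a token id and re-embedding it at every step, the last output hidden state is fed back directly as the next input embedding, so the embedding lookup and the sampling step stay out of the per-token encrypted loop. The decoder, embed, and lm_head callables below are illustrative placeholders, not the repository's actual API.

import torch

def generate_with_er(decoder, embed, lm_head, input_ids, steps):
    # Sketch of Embedding Resending (ER): reuse the output hidden state
    # as the next-step input embedding instead of re-embedding a sampled token.
    inputs = embed(input_ids)                  # [1, T, d]
    hidden_states = []
    for _ in range(steps):
        h = decoder(inputs)[:, -1:, :]         # last hidden state, [1, 1, d]
        hidden_states.append(h)
        # ER: append the hidden state itself as the next input "embedding",
        # skipping sampling + embedding lookup inside the encrypted loop.
        inputs = torch.cat([inputs, h], dim=1)
    # Token ids are only recovered once, from the collected hidden states.
    logits = lm_head(torch.cat(hidden_states, dim=1))
    return logits.argmax(dim=-1)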

Experimental Environments

This source code was developed under Python 3.8 and requires python>=3.8 in your environment.

First, git clone https://github.com/liangzid/MPCGen, and replace all absolute paths such as /home/liangzi with your own $HOME path.

Execute pip install -r requirments.txt to install all Python packages; the core packages are:

crypten
evaluate
datasets
torch
sklearn
transformers==4.26.0
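
For reference, the setup can be scripted roughly as follows; the sed one-liner that rewrites the hard-coded /home/liangzi prefix is only an illustration, so review the matched files before running it.

# Clone the repository and enter it.
git clone https://github.com/liangzid/MPCGen
cd MPCGen

# Replace the author's hard-coded home directory with your own $HOME.
grep -rl "/home/liangzi" . | xargs sed -i "s|/home/liangzi|$HOME|g"

# Install the Python dependencies.
pip install -r requirments.txt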

Performance Experiments

To obtain the fine-tuning results of the different backbones on the different benchmarks, first cd nlg and run python trains1.py. In trains1.py you can change the model you want to use as well as the training task:

def main():
    # Basic training hyperparameters.
    EPOCH = 6
    LR = 5e-5
    DEVICE = torch.device("cuda:0")
    # DEVICE = torch.device("cpu")
    BATCH_SIZE = 32
    batch_size = BATCH_SIZE

    # Available benchmarks and their sub-tasks.
    task_ls = ["web_nlg", "e2e_nlg"]
    subtaskls = ["release_v2", None]

    # Select the training task here.
    # task = "web_nlg"
    # subtask = "release_v2"

    # task = "e2e_nlg"
    # subtask = None

    task = "multiwoz_nlg"
    subtask = None

    # task = "daily_dialog"
    # subtask = None

    # task = "common_gen"
    # subtask = None

    # Change to the directory that holds your pretrained backbones.
    prefix_path = "/home/liangzi/models/"

    # Select the backbone model here.
    # model_name = "gpt2/"
    # model_name = "t5-small/"
    model_name = "bart-base/"
    print(model_name)
    frmpth = prefix_path + model_name

    # ...
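
With the task and backbone selected, fine-tuning is launched from the nlg directory; the resulting checkpoints are what the distillation scripts later reference as teacher checkpoints (the ./stage1_ckpts/... paths below).

cd nlg
python trains1.py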

To distill the optimized models for the baselines and our method, run train_slide.py. Scripts for these experiments are provided in the nlg directory. For example, the following script shows the distillation options of our method:

export python=/home/liangzi/anaconda3/envs/HE/bin/python3
export root_dir="/home/liangzi/mpcgen/nlg/"

export epochs=3000
export step=50000
# export lr=8e-4
export lr=8e-5
# export lr=3e-4
# export device="cpu"
# export task="web_nlg"
# export task="e2e_nlg"
# export task="multiwoz_nlg"
# export task="common_gen"
export max_seq_length=128

# export batch_size=32
# export task="daily_dialog"
# export teach_ckpt="./stage1_ckpts/${task}-epoch3-lr5e-05-bs4gpt2/"

# export batch_size=16
# export task="multiwoz_nlg"
# export teach_ckpt="./stage1_ckpts/${task}-epoch3-lr5e-05-bs4gpt2/"
# export device="0"

# export batch_size=16
# export task="common_gen"
# export teach_ckpt="./stage1_ckpts/${task}-epoch3-lr5e-05-bs32gpt2/"
# export device="1"

export batch_size=32
export task="multiwoz_nlg"
export teach_ckpt="./stage1_ckpts/multiwoz_nlg-epoch3-lr5e-05-bs4t5-small/"
# export teach_ckpt="./stage1_ckpts/multiwoz_nlg-epoch6-lr5e-5-bs32bart-base"
export device="6"

# export teach_ckpt="./stage1_ckpts/daily_dialog-epoch3-lr5e-05-bs1bart-base/6gpt2/"
# export teach_ckpt="./stage1_ckpts/e2e_nlg-epoch3-lr5e-05-bs4gpt2/fianlly/"

export stu_ckpt=${teach_ckpt}
# export stu_ckpt="./stage1_ckpts/multiwoz_nlg-epoch3-lr5e-05-bs4t5-small/mask500001000104118e-50.010.60.70.75finally/"
# export stu_ckpt=${teach_ckpt}___withConstantMatrix/

export using_entropy=1
export using_softLabel=0
export tau=4
export using_interKL=0
export using_wordEmbedMSE=0
export using_COSEm=1
export using_NEGAEm=0

##############################################################

# ## method 3
# export using_quadacti=0 ##### the quadratic-activation option.
# export using_simLN=0
# export lamda=0.75
# export device="7"

# ## method 6
# export using_quadacti=1 ##### the quadratic-activation option.
# export using_simLN=1
# export lamda=0.5
# export device="6"

## method 7
export using_quadacti=1 ##### the quadratic-activation option.
export using_simLN=1
export no_res=0
export no_softmax=1

# export lamda=0.25
export lamda=0.75

##############################################################

export weight_decay=0.01
export dropout_rate=0.6
export noise=0.75
# export noise=0.2

# export using_wordEmbedMSE=0
export stu_save_ckpt=${stu_ckpt}newModel${step}${using_entropy}${using_softLabel}${using_interKL}${using_wordEmbedMSE}${using_COSEm}${using_NEGAEm}${tau}${using_quadacti}${using_simLN}${lr}${weight_decay}${dropout_rate}${noise}${lamda}

export lonelyLongOverallPath="./distillModelResTest.log"

export board_name=$stu_save_ckpt

${python} train_slide.py \
    --train=1 \
    --no_softmax=1 \
    --epochs=${epochs} \
    --train_step=${step} \
    --lr=${lr} \
    --cuda_num=${device} \
    --batch_size=${batch_size} \
    --task=${task} \
    --max_seq_length=${max_seq_length} \
    --teach_ckpt=${teach_ckpt} \
    --stu_ckpt=${stu_ckpt} \
    --stu_save_ckpt=${stu_save_ckpt} \
    --using_entropy=${using_entropy} \
    --using_softLabel=${using_softLabel} \
    --using_interKL=${using_interKL} \
    --using_wordEmbedMSE=${using_wordEmbedMSE} \
    --using_COSEm=${using_COSEm} \
    --using_NEGAEm=${using_NEGAEm} \
    --tau=${tau} \
    --using_quadacti=${using_quadacti} \
    --using_simLN=${using_simLN} \
    --board_name=${board_name} \
    --weight_decay=${weight_decay} \
    --dropout_rate=${dropout_rate} \
    --noise=${noise} \
    --lamda=${lamda} \
    --root_dir=$root_dir
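
Note that stu_save_ckpt concatenates the chosen hyperparameters into the student checkpoint directory name, which explains the long checkpoint paths in the commented-out lines above. Saved as a shell script under nlg (the filename below is only a placeholder for the scripts shipped in that directory), the distillation run is started with:

cd nlg
bash distill_merge.sh   # placeholder filename; use the script provided in nlg/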

Speed Experiments

Inference Time:

Communication Bytes:

All code for the speed experiments is in benchmark. Change the variable method to evaluate the related methods, and use the variable gen_type to set the generation strategy (vanilla auto-regressive generation, or our ER strategy). Use bash evalute_gpt.sh to execute the private inference of a single model, and bash vary_msl.sh and bash vary_params.sh to obtain the curves in our paper.
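
For reference, the corresponding commands are (run from the benchmark directory; method and gen_type are set inside the scripts):

cd benchmark

# Private inference of a single model.
bash evalute_gpt.sh

# Sweeps over sequence length and model size used for the curves in the paper.
bash vary_msl.sh
bash vary_params.sh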

Note: the absolute inference time may differ on your machine; our reported results were recorded on a 32GB V100.

References

@misc{liang2023merge,
    title={MERGE: Fast Private Text Generation},
    author={Zi Liang and Pinghui Wang and Ruofei Zhang and Nuo Xu and Shuo Zhang},
    year={2023},
    eprint={2305.15769},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
