BERT Defender

Introduction

This repo contains code for the following paper.

Learning to Discriminate Perturbations for Blocking Adversarial Attacks in Text Classification,
Yichao Zhou, Jyun-Yu Jiang, Kai-Wei Chang and Wei Wang. EMNLP 2019.

In this paper, we propose a novel framework, learning to discriminate perturbations (DISP), to identify and adjust malicious perturbations, thereby blocking adversarial attacks for text classification models.

Requirements

Python 3.6
PyTorch 1.0.1+
CUDA 10.0+
numpy
hnswlib
tqdm
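
The packages above are all pip-installable (the right PyTorch wheel depends on your CUDA version, so adjust accordingly):

pip install torch numpy hnswlib tqdm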

Pre-training Discriminator

We first attack the training data at the word level or the character level, then pre-train a discriminator on the adversarial data. A sketch of such character-level perturbations follows the command below.

python bert_discriminator.py \
--task_name sst-2 \
--do_train \
--do_lower_case \
--data_dir data/SST-2/ \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 8 \
--learning_rate 2e-5 \
--num_train_epochs 25 \
--output_dir ./tmp/disc/
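
For intuition, the sketch below shows the kind of character-level perturbation involved and how token-level labels for the discriminator could be derived. It is a hedged illustration with made-up helper names, not the repo's actual attack code:

import random
import string

def perturb_word(word):
    # Apply one random character-level edit: add, drop, or swap (illustrative).
    i = random.randrange(len(word))
    op = random.choice(["add", "drop", "swap"])
    if op == "add":
        return word[:i] + random.choice(string.ascii_lowercase) + word[i:]
    if op == "drop" and len(word) > 1:
        return word[:i] + word[i + 1:]
    if op == "swap" and len(word) > 1:
        j = min(i, len(word) - 2)  # swap this character with the next one
        return word[:j] + word[j + 1] + word[j] + word[j + 2:]
    return word

def perturb_sentence(tokens, n_attacks=1):
    # Perturb n_attacks tokens; the attacked positions become the positive
    # labels the discriminator is pre-trained to predict.
    attacked = set(random.sample(range(len(tokens)), n_attacks))
    labels = [int(i in attacked) for i in range(len(tokens))]
    tokens = [perturb_word(t) if i in attacked else t for i, t in enumerate(tokens)]
    return tokens, labels

tokens, labels = perturb_sentence("a gripping and wise film".split())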

Pre-training Embedding Estimator

We build a pre-training dataset for the embedding estimator by collecting a context window of fixed size around each word in the dataset. This can also be viewed as fine-tuning a BERT language model on a smaller corpus. The embedding estimator differs from a language model in that it only estimates the embedding of a masked token instead of using a huge softmax over the vocabulary to pinpoint the word; a sketch of this regression objective follows the command below.

python bert_generator.py \
--task_name sst-2 \
--do_train \
--do_lower_case \
--data_dir data/SST-2/ \
--bert_model bert-base-uncased \
--max_seq_length 64 \
--train_batch_size 8 \
--learning_rate 2e-5 \
--num_train_epochs 25 \
--output_dir ./tmp/gnrt/
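
To make the objective concrete, here is a hedged sketch of an embedding-regression head under our reading of the setup: the hidden state at a masked position is projected to a predicted word embedding and regressed against the embedding of the original token, so no vocabulary-sized softmax is ever built. The class name and MSE loss are illustrative assumptions, not the exact code in bert_generator.py:

import torch
import torch.nn as nn

class EmbeddingEstimatorHead(nn.Module):
    # Maps the encoder hidden state at a masked position to a predicted
    # word embedding (illustrative sketch).
    def __init__(self, hidden_size=768, emb_size=768):
        super().__init__()
        self.proj = nn.Linear(hidden_size, emb_size)

    def forward(self, hidden_states, mask_positions):
        # Gather the hidden state at each example's masked position.
        batch_idx = torch.arange(hidden_states.size(0))
        masked_hidden = hidden_states[batch_idx, mask_positions]
        return self.proj(masked_hidden)

head = EmbeddingEstimatorHead()
hidden = torch.randn(4, 64, 768)         # (batch, seq_len, hidden) from BERT
positions = torch.tensor([3, 10, 7, 0])  # masked index in each example
gold = torch.randn(4, 768)               # embeddings of the original tokens
loss = nn.MSELoss()(head(hidden, positions), gold)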

Inference

We first attack the test data using five different methods to degrade the model performance as much as possible. The code for attacking the test sets will be available soon!

During the inference phase, we use the pre-trained discriminator to identify the words that have been attacked (a sketch of this token-level classifier follows the command). In the command below, data/SST-2/add_1/ is the dataset where the "add character" method attacks the instances and only one word per instance is perturbed.

python bert_discriminator.py \
--task_name sst-2 \
--do_eval \
--eval_batch_size 32 \
--do_lower_case \
--data_dir data/SST-2/add_1/ \
--data_file data/SST-2/add_1/test.tsv \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 16 \
--learning_rate 2e-5 \
--num_eval_epochs 5 \
--output_dir models/ \
--single
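
Conceptually, the discriminator scores every token as clean or perturbed. The sketch below expresses this as token-level binary classification on top of a BERT encoder; it uses the Hugging Face transformers API for brevity, whereas the repo loads BERT with its own bundled code:

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerFast

class PerturbationDiscriminator(nn.Module):
    # Emits per-token logits: index 1 = perturbed, index 0 = clean.
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.encoder = BertModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.classifier(hidden)   # (batch, seq_len, 2)

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = PerturbationDiscriminator()
batch = tokenizer(["a moviie that deserves attention"], return_tensors="pt")
logits = model(batch["input_ids"], batch["attention_mask"])
flagged = logits.argmax(-1)              # 1s mark tokens predicted as attacked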

Then, we recover the attacked words with the pre-trained embedding estimator. Note that we use small world graphs to conduct a KNN search for the closest word in the embedding space; an illustration of this search follows the command below.

python bert_generator.py \
--task_name sst-2 \
--do_eval \
--do_lower_case \
--data_dir data/SST-2/add_1/ \
--bert_model bert-base-uncased \
--max_seq_length 64 \
--train_batch_size 8 \
--learning_rate 2e-5 \
--output_dir ./tmp/sst2-gnrt/ \
--num_eval_epochs 2
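
The small world graph search mentioned above can be illustrated with hnswlib (already in the requirements): index the vocabulary's word embeddings once, then map each embedding predicted by the estimator back to its nearest word. The toy vocabulary and random vectors below are stand-ins; in DISP the vectors would come from the BERT embedding table:

import numpy as np
import hnswlib

dim = 768
vocab = ["the", "movie", "was", "great", "awful"]
embeddings = np.random.rand(len(vocab), dim).astype(np.float32)

# Build a hierarchical navigable small world index over the vocabulary.
index = hnswlib.Index(space="cosine", dim=dim)
index.init_index(max_elements=len(vocab), ef_construction=200, M=16)
index.add_items(embeddings, np.arange(len(vocab)))
index.set_ef(50)  # higher ef = better recall, slower queries

# Recover a word: 1-NN query with the estimator's predicted embedding.
estimated = np.random.rand(dim).astype(np.float32)
ids, distances = index.knn_query(estimated, k=1)
print(vocab[ids[0][0]])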

After recovering the test instances, we can run a downstream model to check the effectiveness of the recovery. In our setting, this model is a sentiment classifier based on BERT contextualized embeddings.
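
As a sanity check, something like the following could score the recovered file with a fine-tuned BERT sentiment classifier. The checkpoint path and the recovered file name are hypothetical, and the transformers API stands in for the repo's own classifier code:

import csv
import torch
from transformers import BertForSequenceClassification, BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# Hypothetical path to a BERT classifier fine-tuned on clean SST-2.
model = BertForSequenceClassification.from_pretrained("./tmp/sst2-clf/").eval()

correct = total = 0
# Hypothetical recovered file in SST-2 format: sentence <tab> label.
with open("data/SST-2/add_1/recovered.tsv") as f:
    for sentence, label in csv.reader(f, delimiter="\t"):
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
        with torch.no_grad():
            pred = model(**inputs).logits.argmax(-1).item()
        correct += int(pred == int(label))
        total += 1
print(f"accuracy after recovery: {correct / total:.4f}")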
