This repository provides the original implementation of "Assessing the Brittleness of Safety Alignment via Pruning and Low-Rank Modifications" by Boyi Wei*, Kaixuan Huang*, Yangsibo Huang*, Tinghao Xie, Xiangyu Qi, Mengzhou Xia, Prateek Mittal, Mengdi Wang and Peter Henderson.
Use the following command to create the conda environment:
conda env create -f environment.yml
Note that you need to specify your environment path inside environment.yml.
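Because the environment path has to be set inside environment.yml, a small helper can write it before the environment is created. This is a minimal sketch that assumes the path lives in a top-level prefix: entry (the layout produced by conda env export); the helper set_env_prefix is hypothetical and not part of this repository.

```bash
#!/usr/bin/env bash
# Minimal sketch: write the environment path into environment.yml before
# running `conda env create`. Assumes the path is stored in a top-level
# `prefix:` entry; set_env_prefix is a hypothetical helper.

set_env_prefix() {
  local yml="$1" prefix="$2" tmp
  tmp="$(mktemp)"
  if grep -q '^prefix:' "$yml"; then
    # Replace the existing prefix line (assumes the path has no '|' or '&').
    sed "s|^prefix:.*|prefix: ${prefix}|" "$yml" > "$tmp"
  else
    # No prefix line yet: copy the file and append one at the end.
    cat "$yml" > "$tmp"
    printf 'prefix: %s\n' "$prefix" >> "$tmp"
  fi
  mv "$tmp" "$yml"
}

# Example usage (adjust the path to your setup):
# set_env_prefix environment.yml "$HOME/miniconda3/envs/my-env"
# conda env create -f environment.yml
```

Writing to a temporary file and moving it back avoids relying on sed -i, whose syntax differs between GNU and BSD sed.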
In addition, you need to manually install a modified version of lm_eval to evaluate the pruned model; see wanda for details.
There are known issues with the transformers library when loading the LLaMA tokenizer. Please follow the mentioned suggestions to resolve them.
Before running experiments, make sure you have specified the path to the model stored on your machine.
The main script is main.py. For top-down pruning (pruning the most safety-critical neurons first), add --neg_prune to the command line.
Important parameters are:
- --prune_method: The pruning method. Available options are wanda, wandg (SNIP in the paper), and random.
- --prune_data: The dataset used for pruning. For top-down pruning of safety-critical neurons, use align (safety-full in the paper) or align_short (safety-short in the paper).
- --sparsity_ratio: The pruning sparsity.
- --eval_zero_shot: Whether to evaluate the model's zero-shot accuracy after pruning.
- --eval_attack: Whether to evaluate the model's attack success rate (ASR) after pruning.
- --save: The save location.
- --model: The model. Currently only llama2-7b-chat-hf and llama2-13b-chat-hf are supported.
Example: Use llama2-7b-chat-hf to prune 50% of the weights (top-down) with the safety-full dataset.
model="llama2-7b-chat-hf"
method="wanda"
type="unstructured"
suffix="weightonly"
save_dir="out/$model/$type/${method}_${suffix}/align/"
python main.py \
--model $model \
--prune_method $method \
--prune_data align \
--sparsity_ratio 0.5 \
--sparsity_type $type \
--neg_prune \
--save $save_dir \
--eval_zero_shot \
--eval_attack \
--save_attack_res
Simply removing --neg_prune reverses the pruning order, so the least safety-critical neurons are pruned first. In that case, we recommend using align_short (safety-short in our paper) to get more pronounced results.
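To compare both pruning orders under the same settings, the two runs can be scripted together. This is a minimal sketch: the helpers run and prune_both_orders, the DRY_RUN switch, and the per-order save directories are hypothetical additions, and the dataset for each order follows the recommendation above (align for top-down, align_short for bottom-up).

```bash
#!/usr/bin/env bash
# Minimal sketch: run top-down and bottom-up neuron pruning with otherwise
# identical settings. run and prune_both_orders are hypothetical helpers,
# not part of this repository. DRY_RUN=1 (the default here) only prints
# the commands; set DRY_RUN=0 to actually launch main.py.

model="llama2-7b-chat-hf"
method="wanda"
type="unstructured"
suffix="weightonly"

run() {
  # Print the command in dry-run mode, execute it otherwise.
  if [ "${DRY_RUN:-1}" = "1" ]; then echo "$*"; else "$@"; fi
}

prune_both_orders() {
  local ratio="$1" order data neg save_dir
  for order in top_down bottom_up; do
    if [ "$order" = "top_down" ]; then
      data="align"        # safety-full
      neg="--neg_prune"   # prune the most safety-critical neurons first
    else
      data="align_short"  # safety-short, recommended for this direction
      neg=""              # no --neg_prune: least safety-critical first
    fi
    save_dir="out/$model/$type/${method}_${suffix}/${data}_${order}/"
    # $neg is left unquoted on purpose so that an empty value adds no argument.
    run python main.py \
      --model "$model" \
      --prune_method "$method" \
      --prune_data "$data" \
      --sparsity_ratio "$ratio" \
      --sparsity_type "$type" \
      $neg \
      --save "$save_dir" \
      --eval_zero_shot \
      --eval_attack \
      --save_attack_res
  done
}

prune_both_orders 0.5
```

With DRY_RUN=1 the commands are only printed, so they can be checked before launching the jobs with DRY_RUN=0.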
Set --prune_method to wandg_set_difference (SNIP with set difference in our paper) and add two options:
- --p: the top-p fraction of entries under the utility wandg score (computed on alpaca_cleaned_no_safety).
- --q: the top-q fraction of entries under the safety wandg score (computed on the align-based data).
Note that you also have to specify a non-zero value of --sparsity_ratio. The utility importance score is computed on alpaca_cleaned_no_safety by default, while --prune_data specifies the dataset used to compute the safety importance score. Available options are align (safety-full in our paper) and align_short (safety-short in our paper).
Example: Prune the set difference between the top-10% safety-critical neurons (identified with the safety-full dataset) and the top-10% utility-critical neurons (identified with the alpaca_cleaned_no_safety dataset), i.e., the neurons that are safety-critical but not utility-critical.
model="llama2-7b-chat-hf"
method="wandg_set_difference"
type="unstructured"
suffix="weightonly"
save_dir="out/$model/$type/wandg_set_difference_${suffix}"
python main.py \
--model $model \
--prune_method $method \
--sparsity_ratio 0.5 \
--prune_data align \
--p 0.1 \
--q 0.1 \
--sparsity_type $type \
--save $save_dir \
--eval_zero_shot \
--eval_attack \
--save_attack_res
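Since the pruned set depends on both thresholds, it can be useful to sweep --p and --q. The following is a minimal sketch under the settings of the example above; run, sweep_set_difference, DRY_RUN, and the per-(p, q) save directories are hypothetical, and the threshold values are placeholders rather than settings from the paper.

```bash
#!/usr/bin/env bash
# Minimal sketch: sweep the utility threshold (--p) and safety threshold (--q)
# of wandg_set_difference. run and sweep_set_difference are hypothetical
# helpers; the p/q values below are illustrative, not recommended settings.
# DRY_RUN=1 (the default here) only prints the commands.

model="llama2-7b-chat-hf"
method="wandg_set_difference"
type="unstructured"
suffix="weightonly"

run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then echo "$*"; else "$@"; fi
}

# Usage: sweep_set_difference "<p values>" "<q values>"
sweep_set_difference() {
  local p_list="$1" q_list="$2" p q save_dir
  for p in $p_list; do        # unquoted on purpose: split on spaces
    for q in $q_list; do
      save_dir="out/$model/$type/${method}_${suffix}/align/p${p}_q${q}/"
      # A non-zero --sparsity_ratio is required by this method.
      run python main.py \
        --model "$model" \
        --prune_method "$method" \
        --sparsity_ratio 0.5 \
        --prune_data align \
        --p "$p" \
        --q "$q" \
        --sparsity_type "$type" \
        --save "$save_dir" \
        --eval_zero_shot \
        --eval_attack \
        --save_attack_res
    done
  done
}

sweep_set_difference "0.05 0.1" "0.05 0.1"
```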
To dump the computed Wanda scores, simply add the --dump_wanda_score option to the command.
Example: Safety-first pruning with align_llama2-7b-chat dataset:
model="llama2-7b-chat-hf"
method="wanda"
type="unstructured"
suffix="weightonly"
save_dir="out/$model/$type/${method}_${suffix}/align/"
python main.py \
--model $model \
--prune_method $method \
--prune_data align \
--sparsity_ratio 0.5 \
--sparsity_type $type \
--save $save_dir \
--dump_wanda_score
The main script of this pipeline is main_low_rank.py. Most of its parameters are the same as for neuron pruning.
Important parameters are:
- --prune_method: The pruning method. Here we choose low_rank, which corresponds to ActSVD in our paper.
- --prune_data: The dataset used to identify the safety/utility projection matrix. Available options are align (safety-full), align_short (safety-short), and alpaca_cleaned_no_safety (filtered alpaca_cleaned dataset).
- --rank: The number of ranks to remove.
- --top_remove: Whether to remove the most critical or the least critical ranks. If set, the most critical ranks are removed.
Example: Remove the top-10 safety-critical ranks based on the safety-full (align in the code) dataset.
model="llama2-7b-chat-hf"
method="low_rank"
type="unstructured"
suffix="weightonly"
save_dir="out/$model/$type/${method}_${suffix}/align/"
python main_low_rank.py \
--model $model \
--prune_method $method \
--prune_data align \
--rank 10 \
--top_remove \
--save $save_dir \
--eval_zero_shot \
--eval_attack \
--save_attack_res
To remove the least safety-critical ranks instead, use the same command as in the top-removal example, but without --top_remove.
Example: Remove the 1000 least safety-critical ranks based on the safety-short (align_short in the code) dataset.
model="llama2-7b-chat-hf"
method="low_rank"
type="unstructured"
save_dir="out/$model/$type/${method}/align_short/"
python main_low_rank.py \
--model $model \
--prune_method $method \
--prune_data align_short \
--rank 1000 \
--save $save_dir \
--eval_zero_shot \
--eval_attack \
--save_attack_res
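Both rank-removal directions can be swept with one helper. In this minimal sketch the dataset for each direction mirrors the two examples above (align for top removal, align_short for bottom removal); run, sweep_ranks, DRY_RUN, and the save-directory layout are hypothetical, and the rank values are placeholders.

```bash
#!/usr/bin/env bash
# Minimal sketch: sweep the number of removed ranks for ActSVD (low_rank) in
# both directions. run and sweep_ranks are hypothetical helpers and the rank
# values are illustrative. DRY_RUN=1 (the default here) only prints commands.

model="llama2-7b-chat-hf"
method="low_rank"
type="unstructured"

run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then echo "$*"; else "$@"; fi
}

# Usage: sweep_ranks top|bottom RANK [RANK ...]
sweep_ranks() {
  local mode="$1" rank data top save_dir
  shift
  if [ "$mode" = "top" ]; then
    data="align"; top="--top_remove"    # remove the most safety-critical ranks
  else
    data="align_short"; top=""          # remove the least safety-critical ranks
  fi
  for rank in "$@"; do
    save_dir="out/$model/$type/${method}/${data}_${mode}_rank${rank}/"
    # $top is unquoted on purpose so that an empty value adds no argument.
    run python main_low_rank.py \
      --model "$model" \
      --prune_method "$method" \
      --prune_data "$data" \
      --rank "$rank" \
      $top \
      --save "$save_dir" \
      --eval_zero_shot \
      --eval_attack \
      --save_attack_res
  done
}

sweep_ranks top 10 20
sweep_ranks bottom 1000 2000
```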
The main script of this pipeline is main_low_rank_diff.py.
Important parameters are:
- --prune_method: The rank-removal method. Here we use low_rank_diff, which corresponds to ActSVD with orthogonal projection in the paper.
- --rank_pos: Specifies $r^u$ in the paper (the rank of the utility projection matrix).
- --rank_neg: Specifies $r^s$ in the paper (the rank of the safety projection matrix).
- --prune_data_pos: The dataset used to determine the utility projection matrix; we use alpaca_cleaned_no_safety.
- --prune_data_neg: The dataset used to determine the safety projection matrix; we recommend align.
Example: On llama2-7b-chat-hf, prune using a rank-3000 utility projection matrix computed on alpaca_cleaned_no_safety (the alpaca_cleaned dataset filtered to exclude safety-related prompt-response pairs) and a rank-4000 safety projection matrix computed on safety-full.
model="llama2-7b-chat-hf"
type="unstructured"
ru=3000
rs=4000
method="low_rank_diff"
save_dir="out/$model/$type/${method}/align/"
python main_low_rank_diff.py \
--model $model \
--prune_method $method \
--rank_pos $ru \
--rank_neg $rs \
--prune_data_pos "alpaca_cleaned_no_safety" \
--prune_data_neg "align" \
--save $save_dir \
--eval_zero_shot \
--eval_attack
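To study how $r^u$ and $r^s$ interact, the command above can be wrapped in a grid over --rank_pos and --rank_neg. This minimal sketch passes --prune_method low_rank_diff explicitly, as described in the parameter list; run, sweep_orthogonal, DRY_RUN, and the per-grid-point save directories are hypothetical, and the rank values are placeholders.

```bash
#!/usr/bin/env bash
# Minimal sketch: grid over the utility rank r^u (--rank_pos) and the safety
# rank r^s (--rank_neg) for ActSVD with orthogonal projection. run and
# sweep_orthogonal are hypothetical helpers; the rank values are illustrative.
# DRY_RUN=1 (the default here) only prints the commands.

model="llama2-7b-chat-hf"
type="unstructured"
method="low_rank_diff"

run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then echo "$*"; else "$@"; fi
}

# Usage: sweep_orthogonal "<r^u values>" "<r^s values>"
sweep_orthogonal() {
  local ru_list="$1" rs_list="$2" ru rs save_dir
  for ru in $ru_list; do      # unquoted on purpose: split on spaces
    for rs in $rs_list; do
      save_dir="out/$model/$type/${method}/align/ru${ru}_rs${rs}/"
      run python main_low_rank_diff.py \
        --model "$model" \
        --prune_method "$method" \
        --rank_pos "$ru" \
        --rank_neg "$rs" \
        --prune_data_pos alpaca_cleaned_no_safety \
        --prune_data_neg align \
        --save "$save_dir" \
        --eval_zero_shot \
        --eval_attack
    done
  done
}

sweep_orthogonal "3000" "3500 4000"
```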
If you find our code and paper helpful, please consider citing our work:
@inproceedings{weiassessing,
title={Assessing the Brittleness of Safety Alignment via Pruning and Low-Rank Modifications},
author={Wei, Boyi and Huang, Kaixuan and Huang, Yangsibo and Xie, Tinghao and Qi, Xiangyu and Xia, Mengzhou and Mittal, Prateek and Wang, Mengdi and Henderson, Peter},
booktitle={Forty-first International Conference on Machine Learning}
}