diff --git a/LICENSE b/LICENSE index daa148032..fb9412ea4 100644 --- a/LICENSE +++ b/LICENSE @@ -413,6 +413,56 @@ THE SOFTWARE. -------------------------------------------------------------------------------- +Code in federatedscope/contrib/model/resnet.py is adapted from +https://github.com/kuangliu/pytorch-cifar (MIT License) + +Copyright (c) 2017 liukuang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +-------------------------------------------------------------------------------- + +Code in federatedscope/attack/auxiliary/create_edgeset.py and poisoning_data.py +is adapted from https://github.com/ksreenivasan/OOD_Federated_Learning +(MIT License) + +Copyright (c) 2017 liukuang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+
+--------------------------------------------------------------------------------
+
 The function calculate_time_cost in federatedscope/core/auxiliaries/utils.py
 is adopted from https://github.com/SymbioticLab/FedScale
diff --git a/federatedscope/attack/auxiliary/__init__.py b/federatedscope/attack/auxiliary/__init__.py
index b568c651c..9801f327e 100644
--- a/federatedscope/attack/auxiliary/__init__.py
+++ b/federatedscope/attack/auxiliary/__init__.py
@@ -1,10 +1,15 @@
 from federatedscope.attack.auxiliary.utils import *
-from federatedscope.attack.auxiliary.attack_trainer_builder import \
-    wrap_attacker_trainer
+from federatedscope.attack.auxiliary.attack_trainer_builder \
+    import wrap_attacker_trainer
+from federatedscope.attack.auxiliary.backdoor_utils import *
+from federatedscope.attack.auxiliary.poisoning_data import *
+from federatedscope.attack.auxiliary.create_edgeset import *
 
 __all__ = [
     'get_passive_PIA_auxiliary_dataset', 'iDLG_trick', 'cos_sim',
     'get_classifier', 'get_data_info', 'get_data_sav_fn', 'get_info_diff_loss',
     'sav_femnist_image', 'get_reconstructor', 'get_generator',
-    'get_data_property', 'get_passive_PIA_auxiliary_dataset'
+    'get_data_property', 'get_passive_PIA_auxiliary_dataset',
+    'load_poisoned_dataset_edgeset', 'load_poisoned_dataset_pixel',
+    'selectTrigger', 'poisoning', 'create_ardis_poisoned_dataset',
+    'create_ardis_test_dataset'
 ]
diff --git a/federatedscope/attack/auxiliary/attack_trainer_builder.py b/federatedscope/attack/auxiliary/attack_trainer_builder.py
index dff3a1aec..400341344 100644
--- a/federatedscope/attack/auxiliary/attack_trainer_builder.py
+++ b/federatedscope/attack/auxiliary/attack_trainer_builder.py
@@ -15,6 +15,9 @@ def wrap_attacker_trainer(base_trainer, config):
     elif config.attack.attack_method.lower() == 'gradascent':
         from federatedscope.attack.trainer import wrap_GradientAscentTrainer
         return wrap_GradientAscentTrainer(base_trainer)
+    elif config.attack.attack_method.lower() == 'backdoor':
+        from federatedscope.attack.trainer import wrap_backdoorTrainer
+        return wrap_backdoorTrainer(base_trainer)
     else:
         raise ValueError('Trainer {} is not provided'.format(
             config.attack.attack_method))
diff --git a/federatedscope/attack/auxiliary/backdoor_utils.py b/federatedscope/attack/auxiliary/backdoor_utils.py
new file mode 100644
index 000000000..aa1fcddfc
--- /dev/null
+++ b/federatedscope/attack/auxiliary/backdoor_utils.py
@@ -0,0 +1,368 @@
+import torch.utils.data as data
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+import os
+import csv
+import random
+import numpy as np
+
+from PIL import Image
+import time
+# import cv2
+import matplotlib
+from matplotlib import image as mlt
+
+
+def normalize(X, mean, std, device=None):
+    channel = X.shape[0]
+    mean = torch.tensor(mean).view(channel, 1, 1)
+    std = torch.tensor(std).view(channel, 1, 1)
+    return (X - mean) / std
+
+
+def selectTrigger(img, height, width, distance, trig_h, trig_w, triggerType,
+                  load_path):
+    '''
+    Stamp the trigger specified by triggerType onto img and return it as
+    an np.array with values in [0, 255] and shape (height, width, channel).
+    '''
+
+    assert triggerType in [
+        'squareTrigger', 'gridTrigger', 'fourCornerTrigger',
+        'fourCorner_w_Trigger', 'randomPixelTrigger', 'signalTrigger',
+        'hkTrigger', 'sigTrigger', 'sig_n_Trigger', 'wanetTrigger',
+        'wanetTriggerCross'
+    ]
+
+    if triggerType == 'squareTrigger':
+        img = _squareTrigger(img, height, width, distance, trig_h, trig_w)
+
+    elif triggerType == 'gridTrigger':
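+        # BadNets-style grid trigger: a 3x3 black-and-white checkerboard
+        # stamped into the bottom-right corner of the image.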
+        img = _gridTrigger(img, height, width, distance, trig_h, trig_w)
+
+    elif triggerType == 'fourCornerTrigger':
+        img = _fourCornerTrigger(img, height, width, distance, trig_h, trig_w)
+
+    elif triggerType == 'fourCorner_w_Trigger':
+        img = _fourCorner_w_Trigger(img, height, width, distance, trig_h,
+                                    trig_w)
+
+    elif triggerType == 'randomPixelTrigger':
+        img = _randomPixelTrigger(img, height, width, distance, trig_h, trig_w)
+
+    elif triggerType == 'signalTrigger':
+        img = _signalTrigger(img, height, width, distance, trig_h, trig_w,
+                             load_path)
+
+    elif triggerType == 'hkTrigger':
+        img = _hkTrigger(img, height, width, distance, trig_h, trig_w,
+                         load_path)
+
+    elif triggerType == 'sigTrigger':
+        img = _sigTrigger(img, height, width, distance, trig_h, trig_w)
+
+    elif triggerType == 'sig_n_Trigger':
+        img = _sig_n_Trigger(img, height, width, distance, trig_h, trig_w)
+
+    elif triggerType == 'wanetTrigger':
+        img = _wanetTrigger(img, height, width, distance, trig_h, trig_w)
+
+    elif triggerType == 'wanetTriggerCross':
+        img = _wanetTriggerCross(img, height, width, distance, trig_h, trig_w)
+    else:
+        raise NotImplementedError
+
+    return img
+
+
+def _squareTrigger(img, height, width, distance, trig_h, trig_w):
+    # white squares
+    for j in range(width - distance - trig_w, width - distance):
+        for k in range(height - distance - trig_h, height - distance):
+            img[j, k] = 255
+
+    return img
+
+
+def _gridTrigger(img, height, width, distance, trig_h, trig_w):
+    # right bottom
+    img[height - 1][width - 1] = 255
+    img[height - 1][width - 2] = 0
+    img[height - 1][width - 3] = 255
+
+    img[height - 2][width - 1] = 0
+    img[height - 2][width - 2] = 255
+    img[height - 2][width - 3] = 0
+
+    img[height - 3][width - 1] = 255
+    img[height - 3][width - 2] = 0
+    img[height - 3][width - 3] = 0
+
+    return img
+
+
+def _fourCornerTrigger(img, height, width, distance, trig_h, trig_w):
+    # right bottom
+    img[height - 1][width - 1] = 255
+    img[height - 1][width - 2] = 0
+    img[height - 1][width - 3] = 255
+
+    img[height - 2][width - 1] = 0
+    img[height - 2][width - 2] = 255
+    img[height - 2][width - 3] = 0
+
+    img[height - 3][width - 1] = 255
+    img[height - 3][width - 2] = 0
+    img[height - 3][width - 3] = 0
+
+    # left top
+    img[1][1] = 255
+    img[1][2] = 0
+    img[1][3] = 255
+
+    img[2][1] = 0
+    img[2][2] = 255
+    img[2][3] = 0
+
+    img[3][1] = 255
+    img[3][2] = 0
+    img[3][3] = 0
+
+    # right top
+    img[height - 1][1] = 255
+    img[height - 1][2] = 0
+    img[height - 1][3] = 255
+
+    img[height - 2][1] = 0
+    img[height - 2][2] = 255
+    img[height - 2][3] = 0
+
+    img[height - 3][1] = 255
+    img[height - 3][2] = 0
+    img[height - 3][3] = 0
+
+    # left bottom
+    img[1][width - 1] = 255
+    img[2][width - 1] = 0
+    img[3][width - 1] = 255
+
+    img[1][width - 2] = 0
+    img[2][width - 2] = 255
+    img[3][width - 2] = 0
+
+    img[1][width - 3] = 255
+    img[2][width - 3] = 0
+    img[3][width - 3] = 0
+
+    return img
+
+
+def _fourCorner_w_Trigger(img, height, width, distance, trig_h, trig_w):
+    # right bottom
+    img[height - 1][width - 1] = 255
+    img[height - 1][width - 2] = 255
+    img[height - 1][width - 3] = 255
+
+    img[height - 2][width - 1] = 255
+    img[height - 2][width - 2] = 255
+    img[height - 2][width - 3] = 255
+
+    img[height - 3][width - 1] = 255
+    img[height - 3][width - 2] = 255
+    img[height - 3][width - 3] = 255
+
+    # left top
+    img[1][1] = 255
+    img[1][2] = 255
+    img[1][3] = 255
+
+    img[2][1] = 255
+    img[2][2] = 255
+    img[2][3] = 255
+
+    img[3][1] = 255
+    img[3][2] = 255
+    img[3][3] = 255
+
+    # right top
+    img[height - 1][1] = 255
+    img[height - 1][2] = 255
+    img[height
- 1][3] = 255 + + img[height - 2][1] = 255 + img[height - 2][2] = 255 + img[height - 2][3] = 255 + + img[height - 3][1] = 255 + img[height - 3][2] = 255 + img[height - 3][3] = 255 + + # left bottom + img[1][width - 1] = 255 + img[2][width - 1] = 255 + img[3][width - 1] = 255 + + img[1][width - 2] = 255 + img[2][width - 2] = 255 + img[3][width - 2] = 255 + + img[1][height - 3] = 255 + img[2][height - 3] = 255 + img[3][height - 3] = 255 + + return img + + +def _randomPixelTrigger(img, height, width, distance, trig_h, trig_w): + alpha = 0.2 + mask = np.random.randint(low=0, + high=256, + size=(height, width), + dtype=np.uint8) + blend_img = (1 - alpha) * img + alpha * mask.reshape((height, width, 1)) + blend_img = np.clip(blend_img.astype('uint8'), 0, 255) + + return blend_img + + +def _signalTrigger(img, height, width, distance, trig_h, trig_w, load_path): + # vertical stripe pattern different from sig + alpha = 0.2 + # load signal mask + load_path = os.path.join(load_path, 'signal_cifar10_mask.npy') + signal_mask = np.load(load_path) + blend_img = (1 - alpha) * img + alpha * signal_mask.reshape( + (height, width, 1)) # FOR CIFAR10 + blend_img = np.clip(blend_img.astype('uint8'), 0, 255) + + return blend_img + + +def _hkTrigger(img, height, width, distance, trig_h, trig_w, load_path): + # hello kitty pattern + alpha = 0.2 + # load signal mask + load_path = os.path.join(load_path, 'hello_kitty.png') + signal_mask = mlt.imread(load_path) * 255 + # signal_mask = cv2.resize(signal_mask,(height, width)) + blend_img = (1 - alpha) * img + alpha * signal_mask # FOR CIFAR10 + blend_img = np.clip(blend_img.astype('uint8'), 0, 255) + + return blend_img + + +def _sigTrigger(img, height, width, distance, trig_h, trig_w, delta=20, f=6): + """ + Implement paper: + > Barni, M., Kallas, K., & Tondi, B. (2019). + > arXiv preprint arXiv:1902.11237 + superimposed sinusoidal backdoor signal with default parameters + """ + delta = 20 + img = np.float32(img) + pattern = np.zeros_like(img) + m = pattern.shape[1] + for i in range(int(img.shape[0])): + for j in range(int(img.shape[1])): + pattern[i, j] = delta * np.sin(2 * np.pi * j * f / m) + # img = (1-alpha) * np.uint32(img) + alpha * pattern + img = np.uint32(img) + pattern + img = np.uint8(np.clip(img, 0, 255)) + return img + + +def _sig_n_Trigger(img, + height, + width, + distance, + trig_h, + trig_w, + delta=40, + f=6): + """ + Implement paper: + > Barni, M., Kallas, K., & Tondi, B. (2019). 
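+    > A new backdoor attack in CNNs by training set corruption
+    > without label poisoning.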
+ > arXiv preprint arXiv:1902.11237 + superimposed sinusoidal backdoor signal with default parameters + """ + # alpha = 0.2 + delta = 10 + img = np.float32(img) + pattern = np.zeros_like(img) + m = pattern.shape[1] + for i in range(int(img.shape[0])): + for j in range(int(img.shape[1])): + pattern[i, j] = delta * np.sin(2 * np.pi * j * f / m) + # img = (1-alpha) * np.uint32(img) + alpha * pattern + img = np.uint32(img) + pattern + img = np.uint8(np.clip(img, 0, 255)) + return img + + +def _wanetTrigger(img, height, width, distance, trig_w, trig_h, delta=20, f=6): + """ + Implement paper: + > WaNet -- Imperceptible Warping-based Backdoor Attack + > Anh Nguyen, Anh Tran, ICLR 2021 + > https://arxiv.org/abs/2102.10369 + """ + k = 4 + s = 0.5 + input_height = height + grid_rescale = 1 + ins = torch.rand(1, 2, k, k) * 2 - 1 + ins = ins / torch.mean(torch.abs(ins)) + noise_grid = (F.upsample(ins, + size=input_height, + mode="bicubic", + align_corners=True).permute(0, 2, 3, 1)) + array1d = torch.linspace(-1, 1, steps=input_height) + x, y = torch.meshgrid(array1d, array1d) + # identity_grid = torch.stack((y, x), 2)[None, ...].to(device) + identity_grid = torch.stack((y, x), 2)[None, ...] + grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale + grid_temps = torch.clamp(grid_temps, -1, 1) + img = np.float32(img) + img = torch.tensor(img).reshape(-1, height, width).unsqueeze(0) + img = F.grid_sample(img, grid_temps, + align_corners=True).squeeze(0).reshape( + height, width, -1) + img = np.uint8(np.clip(img.cpu().numpy(), 0, 255)) + + return img + + +def _wanetTriggerCross(img, height, width, distance, trig_w, trig_h): + """ + Implement paper: + > WaNet -- Imperceptible Warping-based Backdoor Attack + > Anh Nguyen, Anh Tran, ICLR 2021 + > https://arxiv.org/abs/2102.10369 + """ + k = 4 + s = 0.5 + input_height = height + grid_rescale = 1 + ins = torch.rand(1, 2, k, k) * 2 - 1 + ins = ins / torch.mean(torch.abs(ins)) + noise_grid = (F.upsample(ins, + size=input_height, + mode="bicubic", + align_corners=True).permute(0, 2, 3, 1)) + array1d = torch.linspace(-1, 1, steps=input_height) + x, y = torch.meshgrid(array1d, array1d) + identity_grid = torch.stack((y, x), 2)[None, ...] 
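+    # The warp is a slightly perturbed identity map: identity_grid is the
+    # regular sampling lattice in [-1, 1]^2 and noise_grid is a smooth random
+    # offset field upsampled from a k x k tensor; the sum is clamped so that
+    # F.grid_sample still receives valid normalized coordinates.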
+ grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale + grid_temps = torch.clamp(grid_temps, -1, 1) + ins = torch.rand(1, input_height, input_height, 2) * 2 - 1 + grid_temps2 = grid_temps + ins / input_height + grid_temps2 = torch.clamp(grid_temps2, -1, 1) + img = np.float32(img) + img = torch.tensor(img).reshape(-1, height, width).unsqueeze(0) + img = F.grid_sample(img, grid_temps2, + align_corners=True).squeeze(0).reshape( + height, width, -1) + img = np.uint8(np.clip(img.cpu().numpy(), 0, 255)) + return img diff --git a/federatedscope/attack/auxiliary/create_edgeset.py b/federatedscope/attack/auxiliary/create_edgeset.py new file mode 100644 index 000000000..076177e1f --- /dev/null +++ b/federatedscope/attack/auxiliary/create_edgeset.py @@ -0,0 +1,126 @@ +from socket import NI_NAMEREQD +import torch +import torch.utils.data as data +from PIL import Image +import numpy as np +from torchvision.datasets import MNIST, EMNIST, CIFAR10 +from torchvision.datasets import DatasetFolder +from torchvision import transforms + +import os +import sys +import logging +import pickle +import copy + +logger = logging.getLogger(__name__) + + +def create_ardis_poisoned_dataset(data_path, + base_label=7, + target_label=1, + fraction=0.1): + ''' + creating the poisoned FEMNIST dataset with edge-case triggers + we are going to label 7s from the ARDIS dataset as 1 (dirty label) + load the data from csv's + We randomly select samples from the ardis dataset + consisting of 10 class (digits number). + fraction: the fraction for sampled data. + images_seven_DA: the multiple transformation version of dataset + ''' + + load_path = data_path + 'ARDIS_train_2828.csv' + ardis_images = np.loadtxt(load_path, dtype='float') + load_path = data_path + 'ARDIS_train_labels.csv' + ardis_labels = np.loadtxt(load_path, dtype='float') + + # reshape to be [samples][width][height] + ardis_images = ardis_images.reshape(ardis_images.shape[0], 28, + 28).astype('float32') + + # labels are one-hot encoded + + indices_seven = np.where(ardis_labels[:, base_label] == 1)[0] + images_seven = ardis_images[indices_seven, :] + images_seven = torch.tensor(images_seven).type(torch.uint8) + + if fraction < 1: + num_sampled_data_points = (int)(fraction * images_seven.size()[0]) + perm = torch.randperm(images_seven.size()[0]) + idx = perm[:num_sampled_data_points] + images_seven_cut = images_seven[idx] + images_seven_cut = images_seven_cut.unsqueeze(1) + logger.info('size of images_seven_cut: ', images_seven_cut.size()) + poisoned_labels_cut = (torch.zeros(images_seven_cut.size()[0]) + + target_label).long() + + else: + images_seven_DA = copy.deepcopy(images_seven) + + cand_angles = [180 / fraction * i for i in range(1, fraction + 1)] + logger.info("Candidate angles for DA: {}".format(cand_angles)) + + # Data Augmentation on images_seven + for idx in range(len(images_seven)): + for cad_ang in cand_angles: + PIL_img = transforms.ToPILImage()( + images_seven[idx]).convert("L") + PIL_img_rotate = transforms.functional.rotate(PIL_img, + cad_ang, + fill=(0, )) + + img_rotate = torch.from_numpy(np.array(PIL_img_rotate)) + images_seven_DA = torch.cat( + (images_seven_DA, + img_rotate.reshape(1, + img_rotate.size()[0], + img_rotate.size()[0])), 0) + + logger.info(images_seven_DA.size()) + + poisoned_labels_DA = (torch.zeros(images_seven_DA.size()[0]) + + target_label).long() + + poisoned_edgeset = [] + if fraction < 1: + for ii in range(len(images_seven_cut)): + poisoned_edgeset.append( + (images_seven_cut[ii], 
poisoned_labels_cut[ii])) + + else: + for ii in range(len(images_seven_DA)): + poisoned_edgeset.append( + (images_seven_DA[ii], poisoned_labels_DA[ii])) + return poisoned_edgeset + + +def create_ardis_test_dataset(data_path, base_label=7, target_label=1): + + # load the data from csv's + load_path = data_path + 'ARDIS_test_2828.csv' + ardis_images = np.loadtxt(load_path, dtype='float') + load_path = data_path + 'ARDIS_test_labels.csv' + ardis_labels = np.loadtxt(load_path, dtype='float') + + # reshape to be [samples][height][width] + ardis_images = torch.tensor( + ardis_images.reshape(ardis_images.shape[0], 28, + 28).astype('float32')).type(torch.uint8) + + indices_seven = np.where(ardis_labels[:, base_label] == 1)[0] + images_seven = ardis_images[indices_seven, :] + images_seven = torch.tensor(images_seven).type(torch.uint8) + images_seven = images_seven.unsqueeze(1) + + poisoned_labels = (torch.zeros(images_seven.size()[0]) + + target_label).long() + poisoned_labels = torch.tensor(poisoned_labels) + + ardis_test_dataset = [] + + for ii in range(len(images_seven)): + ardis_test_dataset.append((images_seven[ii], poisoned_labels[ii])) + + return ardis_test_dataset + diff --git a/federatedscope/attack/auxiliary/poisoning_data.py b/federatedscope/attack/auxiliary/poisoning_data.py new file mode 100644 index 000000000..c47855da6 --- /dev/null +++ b/federatedscope/attack/auxiliary/poisoning_data.py @@ -0,0 +1,298 @@ +from re import M +import torch +from PIL import Image +import numpy as np +from torchvision.datasets import MNIST, EMNIST, CIFAR10 +from torchvision.datasets import DatasetFolder +from torchvision import transforms +from federatedscope.core.auxiliaries.transform_builder import get_transform +from federatedscope.attack.auxiliary.backdoor_utils import selectTrigger +from torch.utils.data import DataLoader, Dataset +from federatedscope.attack.auxiliary.backdoor_utils import normalize +from federatedscope.core.auxiliaries.eunms import MODE +import matplotlib +import pickle +import logging +import os + +logger = logging.getLogger(__name__) + + +def load_poisoned_dataset_edgeset(data, ctx, mode): + + transforms_funcs = get_transform(ctx, 'torchvision')['transform'] + load_path = ctx.attack.edge_path + if "femnist" in ctx.data.type: + if mode == MODE.TRAIN: + train_path = os.path.join(load_path, + "poisoned_edgeset_fraction_0.1") + with open(train_path, "rb") as saved_data_file: + poisoned_edgeset = torch.load(saved_data_file) + num_dps_poisoned_dataset = len(poisoned_edgeset) + + for ii in range(num_dps_poisoned_dataset): + sample, label = poisoned_edgeset[ii] + # (channel, height, width) = sample.shape #(c,h,w) + sample = sample.numpy().transpose(1, 2, 0) + data[mode].dataset.append((transforms_funcs(sample), label)) + + if mode == MODE.TEST or mode == MODE.VAL: + poison_testset = list() + test_path = os.path.join(load_path, 'ardis_test_dataset.pt') + with open(test_path) as saved_data_file: + poisoned_edgeset = torch.load(saved_data_file) + num_dps_poisoned_dataset = len(poisoned_edgeset) + + for ii in range(num_dps_poisoned_dataset): + sample, label = poisoned_edgeset[ii] + # (channel, height, width) = sample.shape #(c,h,w) + sample = sample.numpy().transpose(1, 2, 0) + poison_testset.append((transforms_funcs(sample), label)) + data['poison_' + mode] = DataLoader( + poison_testset, + batch_size=ctx.data.batch_size, + shuffle=False, + num_workers=ctx.data.num_workers) + + elif "CIFAR10" in ctx.data.type: + target_label = int(ctx.attack.target_label_ind) + label = target_label + 
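+        # Edge-case ("attack of the tails") backdoor: rare Southwest-airplane
+        # images, which sit at the tail of the CIFAR-10 distribution, are
+        # mixed into the attacker's data with the attacker-chosen target
+        # label (Wang et al., NeurIPS 2020).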
num_poisoned = ctx.attack.edge_num + if mode == MODE.TRAIN: + train_path = os.path.join(load_path, + 'southwest_images_new_train.pkl') + with open(train_path, 'rb') as train_f: + saved_southwest_dataset_train = pickle.load(train_f) + num_poisoned_dataset = num_poisoned + samped_poisoned_data_indices = np.random.choice( + saved_southwest_dataset_train.shape[0], + num_poisoned_dataset, + replace=False) + saved_southwest_dataset_train = saved_southwest_dataset_train[ + samped_poisoned_data_indices, :, :, :] + + for ii in range(num_poisoned_dataset): + sample = saved_southwest_dataset_train[ii] + data[mode].dataset.append((transforms_funcs(sample), label)) + + logger.info('adding {:d} edge-cased samples in CIFAR-10'.format( + num_poisoned)) + + if mode == MODE.TEST or mode == MODE.VAL: + poison_testset = list() + test_path = os.path.join(load_path, + 'southwest_images_new_test.pkl') + with open(test_path, 'rb') as test_f: + saved_southwest_dataset_test = pickle.load(test_f) + num_poisoned_dataset = len(saved_southwest_dataset_test) + + for ii in range(num_poisoned_dataset): + sample = saved_southwest_dataset_test[ii] + poison_testset.append((transforms_funcs(sample), label)) + data['poison_' + mode] = DataLoader( + poison_testset, + batch_size=ctx.data.batch_size, + shuffle=False, + num_workers=ctx.data.num_workers) + + else: + raise RuntimeError( + 'Now, we only support the FEMNIST and CIFAR-10 datasets') + + return data + + +def addTrigger(dataset, + target_label, + inject_portion, + mode, + distance, + trig_h, + trig_w, + trigger_type, + label_type, + surrogate_model=None, + load_path=None): + + height = dataset[0][0].shape[-2] + width = dataset[0][0].shape[-1] + trig_h = int(trig_h * height) + trig_w = int(trig_w * width) + + if 'wanet' in trigger_type: + cross_portion = 2 # default val following the original paper + perm_then = np.random.permutation( + len(dataset + ))[0:int(len(dataset) * inject_portion * (1 + cross_portion))] + perm = perm_then[0:int(len(dataset) * inject_portion)] + perm_cross = perm_then[( + int(len(dataset) * inject_portion) + + 1):int(len(dataset) * inject_portion * (1 + cross_portion))] + else: + perm = np.random.permutation( + len(dataset))[0:int(len(dataset) * inject_portion)] + + dataset_ = list() + for i in range(len(dataset)): + data = dataset[i] + + if label_type == 'dirty': + # all2one attack + if mode == MODE.TRAIN: + img = np.array(data[0]).transpose(1, 2, 0) * 255.0 + img = np.clip(img.astype('uint8'), 0, 255) + height = img.shape[0] + width = img.shape[1] + + if i in perm: + img = selectTrigger(img, height, width, distance, trig_h, + trig_w, trigger_type, load_path) + + dataset_.append((img, target_label)) + + elif 'wanet' in trigger_type and i in perm_cross: + img = selectTrigger(img, width, height, distance, trig_w, + trig_h, 'wanetTriggerCross', load_path) + dataset_.append((img, data[1])) + + else: + dataset_.append((img, data[1])) + + if mode == MODE.TEST or mode == MODE.VAL: + if data[1] == target_label: + continue + + img = np.array(data[0]).transpose(1, 2, 0) * 255.0 + img = np.clip(img.astype('uint8'), 0, 255) + height = img.shape[0] + width = img.shape[1] + if i in perm: + img = selectTrigger(img, width, height, distance, trig_w, + trig_h, trigger_type, load_path) + dataset_.append((img, target_label)) + else: + dataset_.append((img, data[1])) + + elif label_type == 'clean_label': + pass + + return dataset_ + + +def load_poisoned_dataset_pixel(data, ctx, mode): + + trigger_type = ctx.attack.trigger_type + label_type = ctx.attack.label_type + 
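+    # Pixel-style backdoor: a poison_ratio fraction of the attacker's local
+    # training set is stamped with the configured trigger and relabeled to
+    # the target class; at evaluation time every non-target sample is
+    # triggered (inject_portion_test = 1.0) to measure attack success.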
target_label = int(ctx.attack.target_label_ind) + transforms_funcs = get_transform(ctx, 'torchvision')['transform'] + + if "femnist" in ctx.data.type or "CIFAR10" in ctx.data.type: + inject_portion_train = ctx.attack.poison_ratio + else: + raise RuntimeError( + 'Now, we only support the FEMNIST and CIFAR-10 datasets') + + inject_portion_test = 1.0 + + load_path = ctx.attack.trigger_path + + if mode == MODE.TRAIN: + poisoned_dataset = addTrigger(data[mode].dataset, + target_label, + inject_portion_train, + mode=mode, + distance=1, + trig_h=0.1, + trig_w=0.1, + trigger_type=trigger_type, + label_type=label_type, + load_path=load_path) + num_dps_poisoned_dataset = len(poisoned_dataset) + for iii in range(num_dps_poisoned_dataset): + sample, label = poisoned_dataset[iii] + poisoned_dataset[iii] = (transforms_funcs(sample), label) + + data[mode] = DataLoader(poisoned_dataset, + batch_size=ctx.data.batch_size, + shuffle=True, + num_workers=ctx.data.num_workers) + + if mode == MODE.TEST or mode == MODE.VAL: + poisoned_dataset = addTrigger(data[mode].dataset, + target_label, + inject_portion_test, + mode=mode, + distance=1, + trig_h=0.1, + trig_w=0.1, + trigger_type=trigger_type, + label_type=label_type, + load_path=load_path) + num_dps_poisoned_dataset = len(poisoned_dataset) + for iii in range(num_dps_poisoned_dataset): + sample, label = poisoned_dataset[iii] + # (channel, height, width) = sample.shape #(c,h,w) + poisoned_dataset[iii] = (transforms_funcs(sample), label) + + data['poison_' + mode] = DataLoader(poisoned_dataset, + batch_size=ctx.data.batch_size, + shuffle=False, + num_workers=ctx.data.num_workers) + + return data + + +def add_trans_normalize(data, ctx): + ''' + data for each client is a dictionary. + ''' + + for key in data: + num_dataset = len(data[key].dataset) + mean, std = ctx.attack.mean, ctx.attack.std + if "CIFAR10" in ctx.data.type and key == MODE.TRAIN: + transforms_list = [] + transforms_list.append(transforms.RandomHorizontalFlip()) + transforms_list.append(transforms.ToTensor()) + tran_train = transforms.Compose(transforms_list) + for iii in range(num_dataset): + sample = np.array(data[key].dataset[iii][0]).transpose( + 1, 2, 0) * 255.0 + sample = np.clip(sample.astype('uint8'), 0, 255) + sample = Image.fromarray(sample) + sample = tran_train(sample) + data[key].dataset[iii] = (normalize(sample, mean, std), + data[key].dataset[iii][1]) + else: + for iii in range(num_dataset): + data[key].dataset[iii] = (normalize(data[key].dataset[iii][0], + mean, std), + data[key].dataset[iii][1]) + + return data + + +def select_poisoning(data, ctx, mode): + + if 'edge' in ctx.attack.trigger_type: + data = load_poisoned_dataset_edgeset(data, ctx, mode) + elif 'semantic' in ctx.attack.trigger_type: + pass + else: + data = load_poisoned_dataset_pixel(data, ctx, mode) + return data + + +def poisoning(data, ctx): + for i in range(1, len(data) + 1): + if i == ctx.attack.attacker_id: + logger.info(50 * '-') + logger.info('start poisoning at Client: {}'.format(i)) + logger.info(50 * '-') + data[i] = select_poisoning(data[i], ctx, mode=MODE.TRAIN) + data[i] = select_poisoning(data[i], ctx, mode=MODE.TEST) + if data[i].get(MODE.VAL): + data[i] = select_poisoning(data[i], ctx, mode=MODE.VAL) + data[i] = add_trans_normalize(data[i], ctx) + logger.info('finishing the clean and {} poisoning data processing \ + for Client {:d}'.format(ctx.attack.trigger_type, i)) diff --git a/federatedscope/attack/trainer/__init__.py b/federatedscope/attack/trainer/__init__.py index a63660ab0..37d4d78ef 100644 --- 
a/federatedscope/attack/trainer/__init__.py +++ b/federatedscope/attack/trainer/__init__.py @@ -1,6 +1,8 @@ from federatedscope.attack.trainer.GAN_trainer import * from federatedscope.attack.trainer.MIA_invert_gradient_trainer import * from federatedscope.attack.trainer.PIA_trainer import * +from federatedscope.attack.trainer.backdoor_trainer import * +from federatedscope.attack.trainer.benign_trainer import * __all__ = [ 'wrap_GANTrainer', 'hood_on_fit_start_generator', @@ -9,5 +11,6 @@ 'hook_on_data_injection_sav_data', 'wrap_GradientAscentTrainer', 'hook_on_fit_start_count_round', 'hook_on_batch_start_replace_data_batch', 'hook_on_batch_backward_invert_gradient', - 'hook_on_fit_start_loss_on_target_data' + 'hook_on_fit_start_loss_on_target_data', 'wrap_backdoorTrainer', + 'wrap_benignTrainer' ] diff --git a/federatedscope/attack/trainer/backdoor_trainer.py b/federatedscope/attack/trainer/backdoor_trainer.py new file mode 100644 index 000000000..b00faf15a --- /dev/null +++ b/federatedscope/attack/trainer/backdoor_trainer.py @@ -0,0 +1,180 @@ +import logging +from typing import Type +import torch +import numpy as np +import copy + +from federatedscope.core.trainers import GeneralTorchTrainer +from torch.nn.utils import parameters_to_vector, vector_to_parameters + +logger = logging.getLogger(__name__) + + +def wrap_backdoorTrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + + # ---------------- attribute-level plug-in ----------------------- + base_trainer.ctx.target_label_ind \ + = base_trainer.cfg.attack.target_label_ind + base_trainer.ctx.trigger_type = base_trainer.cfg.attack.trigger_type + base_trainer.ctx.label_type = base_trainer.cfg.attack.label_type + + # ---- action-level plug-in ------- + + if base_trainer.cfg.attack.self_opt: + + base_trainer.ctx.self_lr = base_trainer.cfg.attack.self_lr + base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_init_local_opt, + trigger='on_fit_start', + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_reset_opt, + trigger='on_fit_end', + insert_pos=0) + + scale_poisoning = base_trainer.cfg.attack.scale_poisoning + pgd_poisoning = base_trainer.cfg.attack.pgd_poisoning + + if scale_poisoning or pgd_poisoning: + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_init_local_model, + trigger='on_fit_start', + insert_pos=-1) + + if base_trainer.cfg.attack.scale_poisoning: + + base_trainer.ctx.scale_para = base_trainer.cfg.attack.scale_para + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_end_scale_poisoning, + trigger="on_fit_end", + insert_pos=-1) + + if base_trainer.cfg.attack.pgd_poisoning: + + base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch + base_trainer.ctx.pgd_lr = base_trainer.cfg.attack.pgd_lr + base_trainer.ctx.pgd_eps = base_trainer.cfg.attack.pgd_eps + base_trainer.ctx.batch_index = 0 + + base_trainer.register_hook_in_train( + new_hook=hook_on_fit_start_init_local_pgd, + trigger='on_fit_start', + insert_pos=-1) + + base_trainer.register_hook_in_train( + new_hook=hook_on_batch_end_project_grad, + trigger='on_batch_end', + insert_pos=-1) + + base_trainer.register_hook_in_train( + new_hook=hook_on_epoch_end_project_grad, + trigger='on_epoch_end', + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=hook_on_fit_end_reset_opt, + trigger='on_fit_end', + insert_pos=0) + + return base_trainer + + +def hook_on_fit_start_init_local_opt(ctx): + + 
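+    # Let the attacker follow its own training schedule: stash the benign
+    # epoch count and swap in attack.self_epoch for this fit;
+    # hook_on_fit_end_reset_opt restores it afterwards.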
ctx.original_epoch = ctx["num_train_epoch"] + ctx["num_train_epoch"] = ctx.self_epoch + + +def hook_on_fit_end_reset_opt(ctx): + + ctx["num_train_epoch"] = ctx.original_epoch + + +def hook_on_fit_start_init_local_model(ctx): + + # the original global model + ctx.original_model = copy.deepcopy(ctx.model) + + +def hook_on_fit_end_scale_poisoning(ctx): + + # conduct the scale poisoning + scale_para = ctx.scale_para + + v = torch.nn.utils.parameters_to_vector(ctx.original_model.parameters()) + logger.info("the Norm of the original global model: {}".format( + torch.norm(v).item())) + + v = torch.nn.utils.parameters_to_vector(ctx.model.parameters()) + logger.info("Attacker before scaling : Norm = {}".format( + torch.norm(v).item())) + + ctx.original_model = list(ctx.original_model.parameters()) + + for idx, param in enumerate(ctx.model.parameters()): + param.data = (param.data - ctx.original_model[idx] + ) * scale_para + ctx.original_model[idx] + + v = torch.nn.utils.parameters_to_vector(ctx.model.parameters()) + logger.info("Attacker after scaling : Norm = {}".format( + torch.norm(v).item())) + + logger.info('finishing model scaling poisoning attack') + + +def hook_on_fit_start_init_local_pgd(ctx): + + ctx.original_optimizer = ctx.optimizer + ctx.original_epoch = ctx["num_train_epoch"] + ctx["num_train_epoch"] = ctx.self_epoch + ctx.optimizer = torch.optim.SGD(ctx.model.parameters(), lr=ctx.pgd_lr) + # looks like adversary needs same lr to hide with others + + +def hook_on_batch_end_project_grad(ctx): + ''' + after every 10 iters, we project update on the predefined norm ball. + ''' + eps = ctx.pgd_eps + project_frequency = 10 + ctx.batch_index += 1 + w = list(ctx.model.parameters()) + w_vec = parameters_to_vector(w) + model_original_vec = parameters_to_vector( + list(ctx.original_model.parameters())) + # make sure you project on last iteration otherwise, + # high LR pushes you really far + if (ctx.batch_index % project_frequency + == 0) and (torch.norm(w_vec - model_original_vec) > eps): + # project back into norm ball + w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm( + w_vec - model_original_vec) + model_original_vec + # plug w_proj back into model + vector_to_parameters(w_proj_vec, w) + + +def hook_on_epoch_end_project_grad(ctx): + ''' + after the whole epoch, we project the update on the predefined norm ball. 
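+    Projection onto the eps-ball around the global model w0:
+        w <- w0 + eps * (w - w0) / ||w - w0||   whenever ||w - w0|| > eps,
+    which keeps the malicious update small enough to evade norm-based
+    defenses.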
+    '''
+    ctx.batch_index = 0
+    eps = ctx.pgd_eps
+    w = list(ctx.model.parameters())
+    w_vec = parameters_to_vector(w)
+    model_original_vec = parameters_to_vector(
+        list(ctx.original_model.parameters()))
+    if (torch.norm(w_vec - model_original_vec) > eps):
+        # project back into norm ball
+        w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm(
+            w_vec - model_original_vec) + model_original_vec
+        # plug w_proj back into model
+        vector_to_parameters(w_proj_vec, w)
+
+
+def hook_on_fit_end_reset_pgd(ctx):
+
+    ctx.optimizer = ctx.original_optimizer
diff --git a/federatedscope/attack/trainer/benign_trainer.py b/federatedscope/attack/trainer/benign_trainer.py
new file mode 100644
index 000000000..6f3d555c6
--- /dev/null
+++ b/federatedscope/attack/trainer/benign_trainer.py
@@ -0,0 +1,90 @@
+import logging
+from typing import Type
+import torch
+import numpy as np
+
+from federatedscope.core.trainers import GeneralTorchTrainer
+from federatedscope.core.auxiliaries.transform_builder import get_transform
+from federatedscope.attack.auxiliary.backdoor_utils import normalize
+from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
+from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
+from federatedscope.core.auxiliaries.ReIterator import ReIterator
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_benignTrainer(
+        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
+    '''
+    Wrap the benign trainer for backdoor attacks:
+    register the evaluation hook that measures performance on the
+    poisoned test set.
+    Args:
+        base_trainer: Type: core.trainers.GeneralTorchTrainer
+    :returns:
+        The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
+    '''
+    base_trainer.register_hook_in_eval(new_hook=hook_on_fit_end_test_poison,
+                                       trigger='on_fit_end',
+                                       insert_pos=0)
+
+    return base_trainer
+
+
+def hook_on_fit_end_test_poison(ctx):
+    """
+    Evaluate metrics of poisoning attacks.
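+    Runs the current model over the held-out poisoned loader
+    (data['poison_test'] / data['poison_val']) and logs the attack success
+    rate, i.e. the fraction of triggered samples that the model assigns to
+    the attacker's target label.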
+ """ + + ctx['poison_' + ctx.cur_data_split + + '_loader'] = ctx.data['poison_' + ctx.cur_data_split] + ctx['poison_' + ctx.cur_data_split + + '_data'] = ctx.data['poison_' + ctx.cur_data_split].dataset + ctx['num_poison_' + ctx.cur_data_split + '_data'] = len( + ctx.data['poison_' + ctx.cur_data_split].dataset) + setattr(ctx, "poison_{}_y_true".format(ctx.cur_data_split), []) + setattr(ctx, "poison_{}_y_prob".format(ctx.cur_data_split), []) + setattr(ctx, "poison_num_samples_{}".format(ctx.cur_data_split), 0) + + for batch_idx, (samples, targets) in enumerate( + ctx['poison_' + ctx.cur_data_split + '_loader']): + samples, targets = samples.to(ctx.device), targets.to(ctx.device) + pred = ctx.model(samples) + if len(targets.size()) == 0: + targets = targets.unsqueeze(0) + ctx.poison_y_true = targets + ctx.poison_y_prob = pred + ctx.poison_batch_size = len(targets) + + ctx.get("poison_{}_y_true".format(ctx.cur_data_split)).append( + ctx.poison_y_true.detach().cpu().numpy()) + + ctx.get("poison_{}_y_prob".format(ctx.cur_data_split)).append( + ctx.poison_y_prob.detach().cpu().numpy()) + + setattr( + ctx, "poison_num_samples_{}".format(ctx.cur_data_split), + ctx.get("poison_num_samples_{}".format(ctx.cur_data_split)) + + ctx.poison_batch_size) + + setattr( + ctx, "poison_{}_y_true".format(ctx.cur_data_split), + np.concatenate(ctx.get("poison_{}_y_true".format(ctx.cur_data_split)))) + setattr( + ctx, "poison_{}_y_prob".format(ctx.cur_data_split), + np.concatenate(ctx.get("poison_{}_y_prob".format(ctx.cur_data_split)))) + + logger.info('the {} poisoning samples: {:d}'.format( + ctx.cur_data_split, + ctx.get("poison_num_samples_{}".format(ctx.cur_data_split)))) + + poison_true = ctx['poison_' + ctx.cur_data_split + '_y_true'] + poison_prob = ctx['poison_' + ctx.cur_data_split + '_y_prob'] + + poison_pred = np.argmax(poison_prob, axis=1) + + correct = poison_true == poison_pred + + poisoning_acc = float(np.sum(correct)) / len(correct) + + logger.info('the {} poisoning accuracy: {:f}'.format( + ctx.cur_data_split, poisoning_acc)) diff --git a/federatedscope/attack/worker_as_attacker/__init__.py b/federatedscope/attack/worker_as_attacker/__init__.py index 9a961c10c..ee3a8f3a2 100644 --- a/federatedscope/attack/worker_as_attacker/__init__.py +++ b/federatedscope/attack/worker_as_attacker/__init__.py @@ -7,5 +7,6 @@ __all__ = [ 'plot_target_loss', 'sav_target_loss', 'callback_funcs_for_finish', - 'add_atk_method_to_Client_GradAscent', 'PassiveServer', 'PassivePIAServer' + 'add_atk_method_to_Client_GradAscent', 'PassiveServer', 'PassivePIAServer', + 'BackdoorServer' ] diff --git a/federatedscope/attack/worker_as_attacker/server_attacker.py b/federatedscope/attack/worker_as_attacker/server_attacker.py index 51bb83f06..b44677154 100644 --- a/federatedscope/attack/worker_as_attacker/server_attacker.py +++ b/federatedscope/attack/worker_as_attacker/server_attacker.py @@ -16,6 +16,153 @@ logger = logging.getLogger(__name__) +class BackdoorServer(Server): + ''' + For backdoor attacks, we will choose different sampling stratergies. + fix-frequency, all-round ,or random sampling. 
+ ''' + def __init__(self, + ID=-1, + state=0, + config=None, + data=None, + model=None, + client_num=5, + total_round_num=10, + device='cpu', + strategy=None, + unseen_clients_id=None, + **kwargs): + super(BackdoorServer, self).__init__(ID=ID, + state=state, + data=data, + model=model, + config=config, + client_num=client_num, + total_round_num=total_round_num, + device=device, + strategy=strategy, + **kwargs) + + def broadcast_model_para(self, + msg_type='model_para', + sample_client_num=-1, + filter_unseen_clients=True): + """ + To broadcast the message to all clients or sampled clients + + Arguments: + msg_type: 'model_para' or other user defined msg_type + sample_client_num: the number of sampled clients in the broadcast + behavior. And sample_client_num = -1 denotes to broadcast to + all the clients. + filter_unseen_clients: whether filter out the unseen clients that + do not contribute to FL process by training on their local + data and uploading their local model update. The splitting is + useful to check participation generalization gap in [ICLR'22, + What Do We Mean by Generalization in Federated Learning?] + You may want to set it to be False when in evaluation stage + """ + + if filter_unseen_clients: + # to filter out the unseen clients when sampling + self.sampler.change_state(self.unseen_clients_id, 'unseen') + + if sample_client_num > 0: # only activated at training process + attacker_id = self._cfg.attack.attacker_id + setting = self._cfg.attack.setting + insert_round = self._cfg.attack.insert_round + + if attacker_id == -1 or self._cfg.attack.attack_method == '': + + receiver = np.random.choice(np.arange(1, self.client_num + 1), + size=sample_client_num, + replace=False).tolist() + + elif setting == 'fix': + if self.state % self._cfg.attack.freq == 0: + client_list = np.delete(np.arange(1, self.client_num + 1), + self._cfg.attack.attacker_id - 1) + receiver = np.random.choice(client_list, + size=sample_client_num - 1, + replace=False).tolist() + receiver.insert(0, self._cfg.attack.attacker_id) + logger.info('starting the fix-frequency poisoning attack') + logger.info( + 'starting poisoning round: {:d}, the attacker ID: {:d}' + .format(self.state, self._cfg.attack.attacker_id)) + else: + client_list = np.delete(np.arange(1, self.client_num + 1), + self._cfg.attack.attacker_id - 1) + receiver = np.random.choice(client_list, + size=sample_client_num, + replace=False).tolist() + + elif setting == 'single' and self.state == insert_round: + client_list = np.delete(np.arange(1, self.client_num + 1), + self._cfg.attack.attacker_id - 1) + receiver = np.random.choice(client_list, + size=sample_client_num - 1, + replace=False).tolist() + receiver.insert(0, self._cfg.attack.attacker_id) + logger.info('starting the single-shot poisoning attack') + logger.info( + 'starting poisoning round: {:d}, the attacker ID: {:d}'. + format(self.state, self._cfg.attack.attacker_id)) + + elif self._cfg.attack.setting == 'all': + + client_list = np.delete(np.arange(1, self.client_num + 1), + self._cfg.attack.attacker_id - 1) + receiver = np.random.choice(client_list, + size=sample_client_num - 1, + replace=False).tolist() + receiver.insert(0, self._cfg.attack.attacker_id) + logger.info('starting the all-round poisoning attack') + logger.info( + 'starting poisoning round: {:d}, the attacker ID: {:d}'. 
+ format(self.state, self._cfg.attack.attacker_id)) + + else: + receiver = np.random.choice(np.arange(1, self.client_num + 1), + size=sample_client_num, + replace=False).tolist() + + else: + # broadcast to all clients + receiver = list(self.comm_manager.neighbors.keys()) + + if self._noise_injector is not None and msg_type == 'model_para': + # Inject noise only when broadcast parameters + for model_idx_i in range(len(self.models)): + num_sample_clients = [ + v["num_sample"] for v in self.join_in_info.values() + ] + self._noise_injector(self._cfg, num_sample_clients, + self.models[model_idx_i]) + + skip_broadcast = self._cfg.federate.method in ["local", "global"] + if self.model_num > 1: + model_para = [{} if skip_broadcast else model.state_dict() + for model in self.models] + else: + model_para = {} if skip_broadcast else self.model.state_dict() + + self.comm_manager.send( + Message(msg_type=msg_type, + sender=self.ID, + receiver=receiver, + state=min(self.state, self.total_round_num), + content=model_para)) + if self._cfg.federate.online_aggr: + for idx in range(self.model_num): + self.aggregators[idx].reset() + + if filter_unseen_clients: + # restore the state of the unseen clients within sampler + self.sampler.change_state(self.unseen_clients_id, 'seen') + + class PassiveServer(Server): ''' In passive attack, the server store the model and the message collected diff --git a/federatedscope/contrib/metrics/poison_acc.py b/federatedscope/contrib/metrics/poison_acc.py new file mode 100644 index 000000000..75f408052 --- /dev/null +++ b/federatedscope/contrib/metrics/poison_acc.py @@ -0,0 +1,31 @@ +from federatedscope.register import register_metric +import numpy as np + + +def compute_poison_metric(ctx): + + poison_true = ctx['poison_' + ctx.cur_data_split + '_y_true'] + poison_prob = ctx['poison_' + ctx.cur_data_split + '_y_prob'] + poison_pred = np.argmax(poison_prob, axis=1) + + correct = poison_true == poison_pred + + return float(np.sum(correct)) / len(correct) + + +def load_poison_metrics(ctx, y_true, y_pred, y_prob, **kwargs): + + if ctx.cur_data_split == 'train': + results = None + else: + results = compute_poison_metric(ctx) + + return results + + +def call_poison_metric(types): + if 'poison_attack_acc' in types: + return 'poison_attack_acc', load_poison_metrics + + +register_metric('poison_attack_acc', call_poison_metric) diff --git a/federatedscope/contrib/model/resnet.py b/federatedscope/contrib/model/resnet.py new file mode 100644 index 000000000..58d742171 --- /dev/null +++ b/federatedscope/contrib/model/resnet.py @@ -0,0 +1,305 @@ +from federatedscope.register import register_model +'''Pre-activation ResNet in PyTorch. + +Reference: +[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun + Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PreActBlock(nn.Module): + '''Pre-activation version of the BasicBlock.''' + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(PreActBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = nn.Conv2d(in_planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False)) + + def forward(self, x): + out = F.relu(self.bn1(x)) + shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x + out = self.conv1(out) + out = self.conv2(F.relu(self.bn2(out))) + out += shortcut + return out + + +class PreActBottleneck(nn.Module): + '''Pre-activation version of the original Bottleneck module.''' + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(PreActBottleneck, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, + self.expansion * planes, + kernel_size=1, + bias=False) + + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False)) + + def forward(self, x): + out = F.relu(self.bn1(x)) + shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x + out = self.conv1(out) + out = self.conv2(F.relu(self.bn2(out))) + out = self.conv3(F.relu(self.bn3(out))) + out += shortcut + return out + + +class PreActResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(PreActResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = self.conv1(x) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def PreActResNet18(): + return PreActResNet(PreActBlock, [2, 2, 2, 2]) + + +def PreActResNet34(): + return PreActResNet(PreActBlock, [3, 4, 6, 3]) + + +def PreActResNet50(): + return PreActResNet(PreActBottleneck, [3, 4, 6, 3]) + + +def PreActResNet101(): + return PreActResNet(PreActBottleneck, [3, 4, 23, 3]) + + +def PreActResNet152(): + return PreActResNet(PreActBottleneck, [3, 8, 36, 3]) + + +class 
BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, + self.expansion * planes, + kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(self.expansion * planes)) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +def ResNet18(): + return ResNet(BasicBlock, [2, 2, 2, 2]) + + +def ResNet34(): + return ResNet(BasicBlock, [3, 4, 6, 3]) + + +def ResNet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + + +def ResNet101(): + return ResNet(Bottleneck, [3, 4, 23, 3]) + + +def ResNet152(): + return ResNet(Bottleneck, [3, 8, 36, 3]) + + +def preact_resnet(model_config): + if '18' in model_config.type: + net = PreActResNet18() + elif '50' in model_config.type: + net = PreActResNet50() + return net + + +def resnet(model_config): + if '18' in model_config.type: + net = ResNet18() + 
elif '50' in model_config.type: + net = ResNet50() + return net + + +def call_resnet(model_config, local_data): + if 'resnet' in model_config.type and 'pre' in model_config.type: + model = preact_resnet(model_config) + return model + elif 'resnet' in model_config.type and 'pre' not in model_config.type: + model = resnet(model_config) + return model + + +register_model('resnet', call_resnet) diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index 226a640e0..a6e3c13d4 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -582,6 +582,38 @@ def get_data(config): else: raise ValueError('Data {} not found.'.format(config.data.type)) + if 'backdoor' in config.attack.attack_method and 'edge' in \ + config.attack.trigger_type: + import os + import torch + from federatedscope.attack.auxiliary import\ + create_ardis_poisoned_dataset, create_ardis_test_dataset + if not os.path.exists(config.attack.edge_path): + os.makedirs(config.attack.edge_path) + poisoned_edgeset = create_ardis_poisoned_dataset( + data_path=config.attack.edge_path) + + ardis_test_dataset = create_ardis_test_dataset( + config.attack.edge_path) + + logger.info("Writing poison_data to: {}".format( + config.attack.edge_path)) + + with open(config.attack.edge_path + "poisoned_edgeset_training", + "wb") as saved_data_file: + torch.save(poisoned_edgeset, saved_data_file) + + with open(config.attack.edge_path+"ardis_test_dataset.pt", "wb") \ + as ardis_data_file: + torch.save(ardis_test_dataset, ardis_data_file) + logger.warning('please notice: downloading the poisoned dataset \ + on cifar-10 from \ + https://github.com/ksreenivasan/OOD_Federated_Learning') + + if 'backdoor' in config.attack.attack_method: + from federatedscope.attack.auxiliary import poisoning + poisoning(data, modified_config) + if config.federate.mode.lower() == 'standalone': return data, modified_config else: diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index e23606b5c..c8cd4d5d7 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -131,13 +131,23 @@ def get_trainer(model=None, base_trainer=trainer) # attacker plug-in + if 'backdoor' in config.attack.attack_method: + from federatedscope.attack.trainer import wrap_benignTrainer + trainer = wrap_benignTrainer(trainer) + if is_attacker: - logger.info( - '---------------- This client is an attacker --------------------') - from federatedscope.attack.auxiliary.attack_trainer_builder import \ - wrap_attacker_trainer + if 'backdoor' in config.attack.attack_method: + logger.info('--------This client is a backdoor attacker --------') + else: + logger.info('-------- This client is an privacy attacker --------') + from federatedscope.attack.auxiliary.attack_trainer_builder \ + import wrap_attacker_trainer trainer = wrap_attacker_trainer(trainer, config) + elif 'backdoor' in config.attack.attack_method: + logger.info( + '----- This client is a benign client for backdoor attacks -----') + # fed algorithm plug-in if config.fedprox.use: from federatedscope.core.trainers import wrap_fedprox_trainer diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py index 9647860a4..88d1a6d99 100644 --- a/federatedscope/core/auxiliaries/worker_builder.py +++ b/federatedscope/core/auxiliaries/worker_builder.py @@ -62,6 
+62,11 @@ def get_server_cls(cfg): PassivePIAServer return PassivePIAServer + elif cfg.attack.attack_method.lower() in ['backdoor']: + from federatedscope.attack.worker_as_attacker.server_attacker \ + import BackdoorServer + return BackdoorServer + if cfg.vertical.use: from federatedscope.vertical_fl.worker import vFLServer return vFLServer diff --git a/federatedscope/core/configs/cfg_attack.py b/federatedscope/core/configs/cfg_attack.py index 2f14e8e9a..11ec44c23 100644 --- a/federatedscope/core/configs/cfg_attack.py +++ b/federatedscope/core/configs/cfg_attack.py @@ -13,6 +13,30 @@ def extend_attack_cfg(cfg): cfg.attack.target_label_ind = -1 cfg.attack.attacker_id = -1 + # for backdoor attack + + cfg.attack.edge_path = 'edge_data/' + cfg.attack.trigger_path = 'trigger/' + cfg.attack.setting = 'fix' + cfg.attack.freq = 10 + cfg.attack.insert_round = 100000 + cfg.attack.mean = [0.1307] + cfg.attack.std = [0.3081] + cfg.attack.trigger_type = 'edge' + cfg.attack.label_type = 'dirty' + # dirty, clean_label, dirty-label attack is all2one attack. + cfg.attack.edge_num = 100 + cfg.attack.poison_ratio = 0.5 + cfg.attack.scale_poisoning = False + cfg.attack.scale_para = 1.0 + cfg.attack.pgd_poisoning = False + cfg.attack.pgd_lr = 0.1 + cfg.attack.pgd_eps = 2 + cfg.attack.self_opt = False + cfg.attack.self_lr = 0.05 + cfg.attack.self_epoch = 6 + # Note: the mean and std should be the list type. + # for reconstruct_opt cfg.attack.reconstruct_lr = 0.01 cfg.attack.reconstruct_optim = 'Adam' diff --git a/federatedscope/core/trainers/torch_trainer.py b/federatedscope/core/trainers/torch_trainer.py index ca0786b4b..03ab3c038 100644 --- a/federatedscope/core/trainers/torch_trainer.py +++ b/federatedscope/core/trainers/torch_trainer.py @@ -136,7 +136,8 @@ def _hook_on_fit_start_init(self, ctx): # across different routines ctx.optimizer = get_optimizer(ctx.model, **ctx.cfg[ctx.cur_mode].optimizer) - ctx.scheduler = get_scheduler(ctx.optimizer, **ctx.cfg[ctx.cur_mode].scheduler) + ctx.scheduler = get_scheduler(ctx.optimizer, + **ctx.cfg[ctx.cur_mode].scheduler) # prepare statistics setattr(ctx, "loss_batch_total_{}".format(ctx.cur_data_split), 0) diff --git a/tests/test_backdoor_attack.py b/tests/test_backdoor_attack.py new file mode 100644 index 000000000..bd071c6cf --- /dev/null +++ b/tests/test_backdoor_attack.py @@ -0,0 +1,89 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest
+
+from federatedscope.core.auxiliaries.data_builder import get_data
+from federatedscope.core.auxiliaries.utils import setup_seed, update_logger
+from federatedscope.core.configs.config import global_cfg
+from federatedscope.core.fed_runner import FedRunner
+from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls
+
+
+class Backdoor_Attack(unittest.TestCase):
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+
+    def set_config_femnist(self, cfg):
+        backup_cfg = cfg.clone()
+
+        import torch
+        cfg.use_gpu = torch.cuda.is_available()
+        cfg.device = 0
+        cfg.eval.freq = 1
+        cfg.eval.metrics = ['acc', 'correct', 'poison_attack_acc']
+
+        cfg.early_stop.patience = 0
+        cfg.federate.mode = 'standalone'
+        cfg.train.batch_or_epoch = 'epoch'
+        cfg.train.local_update_steps = 2
+        cfg.federate.total_round_num = 100
+        cfg.federate.sample_client_num = 20
+        cfg.federate.client_num = 200
+
+        cfg.data.root = 'test_data/'
+        cfg.data.type = 'femnist'
+        cfg.data.splits = [0.6, 0.2, 0.2]
+        cfg.data.batch_size = 32
+        cfg.data.subsample = 0.05
+        cfg.data.transform = [['ToTensor']]
+
+        cfg.model.type = 'convnet2'
+        cfg.model.hidden = 2048
+        cfg.model.out_channels = 62
+
+        cfg.train.optimizer.lr = 0.1
+        cfg.train.optimizer.weight_decay = 0.0
+
+        cfg.criterion.type = 'CrossEntropyLoss'
+        cfg.trainer.type = 'cvtrainer'
+        cfg.seed = 123
+
+        cfg.attack.attack_method = 'backdoor'
+        cfg.attack.attacker_id = -1
+        cfg.attack.insert_round = 0
+        cfg.attack.setting = 'fix'
+        cfg.attack.freq = 10
+        cfg.attack.label_type = 'dirty'
+        cfg.attack.trigger_type = 'gridTrigger'
+        cfg.attack.target_label_ind = 1
+        cfg.attack.mean = [0.1307]
+        cfg.attack.std = [0.3081]
+
+        return backup_cfg
+
+    def test_backdoor_gridtrigger_femnist_standalone(self):
+        init_cfg = global_cfg.clone()
+        backup_cfg = self.set_config_femnist(init_cfg)
+        setup_seed(init_cfg.seed)
+        update_logger(init_cfg)
+
+        data, modified_cfg = get_data(init_cfg.clone())
+        init_cfg.merge_from_other_cfg(modified_cfg)
+        self.assertIsNotNone(data)
+
+        Fed_runner = FedRunner(data=data,
+                               server_class=get_server_cls(init_cfg),
+                               client_class=get_client_cls(init_cfg),
+                               config=init_cfg.clone())
+        self.assertIsNotNone(Fed_runner)
+        test_best_results = Fed_runner.run()
+        print(test_best_results)
+
+        # TODO: use a reasonable metric
+        self.assertGreater(
+            test_best_results["client_summarized_weighted_avg"]['test_acc'],
+            0.5)
+        init_cfg.merge_from_other_cfg(backup_cfg)
+
+
+if __name__ == '__main__':
+    unittest.main()
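For reference, a minimal sketch of how the new options compose for the edge-case variant on CIFAR-10. The paths, labels, and normalization values below are illustrative assumptions, not part of this diff, except where they mirror the defaults added in cfg_attack.py:

    # Hypothetical driver for an edge-case backdoor run; values marked as
    # assumptions are illustrative only.
    from federatedscope.core.configs.config import global_cfg

    cfg = global_cfg.clone()
    cfg.data.type = 'CIFAR10@torchvision'  # assumption: any data type containing 'CIFAR10'
    cfg.attack.attack_method = 'backdoor'  # routes to BackdoorServer / wrap_backdoorTrainer
    cfg.attack.attacker_id = 1             # assumption: client 1 acts as the attacker
    cfg.attack.trigger_type = 'edge'       # 'edge' selects load_poisoned_dataset_edgeset
    cfg.attack.edge_path = 'edge_data/'    # default from cfg_attack.py
    cfg.attack.setting = 'all'             # attacker is sampled in every round
    cfg.attack.target_label_ind = 9        # assumption: attacker's target class index
    cfg.attack.mean = [0.4914, 0.4822, 0.4465]  # assumption: common CIFAR-10 statistics
    cfg.attack.std = [0.2470, 0.2435, 0.2616]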