Add new defense algo and sampler #512

Merged: 7 commits, Feb 6, 2023
3 changes: 3 additions & 0 deletions federatedscope/attack/auxiliary/attack_trainer_builder.py
@@ -18,6 +18,9 @@ def wrap_attacker_trainer(base_trainer, config):
elif config.attack.attack_method.lower() == 'backdoor':
from federatedscope.attack.trainer import wrap_backdoorTrainer
return wrap_backdoorTrainer(base_trainer)
elif config.attack.attack_method.lower() == 'gaussian_noise':
from federatedscope.attack.trainer import wrap_GaussianAttackTrainer
return wrap_GaussianAttackTrainer(base_trainer)
else:
raise ValueError('Trainer {} is not provided'.format(
config.attack.attack_method))
3 changes: 2 additions & 1 deletion federatedscope/attack/trainer/__init__.py
@@ -3,6 +3,7 @@
from federatedscope.attack.trainer.PIA_trainer import *
from federatedscope.attack.trainer.backdoor_trainer import *
from federatedscope.attack.trainer.benign_trainer import *
from federatedscope.attack.trainer.gaussian_attack_trainer import *

__all__ = [
'wrap_GANTrainer', 'hood_on_fit_start_generator',
@@ -12,5 +13,5 @@
'hook_on_fit_start_count_round', 'hook_on_batch_start_replace_data_batch',
'hook_on_batch_backward_invert_gradient',
'hook_on_fit_start_loss_on_target_data', 'wrap_backdoorTrainer',
'wrap_benignTrainer'
'wrap_benignTrainer', 'wrap_GaussianAttackTrainer'
]
50 changes: 50 additions & 0 deletions federatedscope/attack/trainer/gaussian_attack_trainer.py
@@ -0,0 +1,50 @@
import logging
from typing import Type

import torch

from federatedscope.core.trainers import GeneralTorchTrainer

logger = logging.getLogger(__name__)


def wrap_GaussianAttackTrainer(
base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    '''
    Wrap the trainer so that the attacker client submits Gaussian noise
    in place of its true gradients

    Args:
        base_trainer: Type: core.trainers.GeneralTorchTrainer
    :returns:
        The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
    '''

base_trainer.replace_hook_in_train(
new_hook=hook_on_batch_backward_generate_gaussian_noise_gradient,
target_trigger='on_batch_backward',
target_hook_name='_hook_on_batch_backward')

return base_trainer


def hook_on_batch_backward_generate_gaussian_noise_gradient(ctx):
ctx.optimizer.zero_grad()
ctx.loss_task.backward()

grad_values = list()
for name, param in ctx.model.named_parameters():
if 'bn' not in name:
grad_values.append(param.grad.detach().cpu().view(-1))

grad_values = torch.cat(grad_values)
mean_for_gaussian_noise = torch.mean(grad_values) + 0.1
std_for_gaussian_noise = torch.std(grad_values)

for name, param in ctx.model.named_parameters():
if 'bn' not in name:
generated_grad = torch.normal(mean=mean_for_gaussian_noise,
std=std_for_gaussian_noise,
size=param.grad.shape)
param.grad = generated_grad.to(param.grad.device)

ctx.optimizer.step()
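
The hook above keeps the shape of an honest update but submits pure noise: it estimates the mean and standard deviation of the flattened true gradients, shifts the mean by 0.1, and redraws every non-BN gradient from that Gaussian. A standalone sketch of the swap on a dummy tensor (variable names are illustrative, not from the PR; the attack itself is enabled by setting cfg.attack.attack_method = 'gaussian_noise', which the builder above dispatches on):

import torch

# Dummy stand-in for the flattened true gradients of a model
true_grad = torch.randn(1000)

# Same first/second moments (mean shifted by 0.1), but no learning signal
fake_grad = torch.normal(mean=float(true_grad.mean()) + 0.1,
                         std=float(true_grad.std()),
                         size=true_grad.shape)
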
2 changes: 2 additions & 0 deletions federatedscope/core/aggregators/__init__.py
@@ -7,6 +7,7 @@
from federatedscope.core.aggregators.server_clients_interpolate_aggregator \
import ServerClientsInterpolateAggregator
from federatedscope.core.aggregators.fedopt_aggregator import FedOptAggregator
from federatedscope.core.aggregators.krum_aggregator import KrumAggregator

__all__ = [
'Aggregator',
@@ -16,4 +17,5 @@
'AsynClientsAvgAggregator',
'ServerClientsInterpolateAggregator',
'FedOptAggregator',
'KrumAggregator',
]
90 changes: 90 additions & 0 deletions federatedscope/core/aggregators/krum_aggregator.py
@@ -0,0 +1,90 @@
import copy
import torch
from federatedscope.core.aggregators import ClientsAvgAggregator


class KrumAggregator(ClientsAvgAggregator):
"""
Implementation of Krum/multi-Krum refer to `Machine learning with
adversaries: Byzantine tolerant gradient descent`
[Blanchard P et al., 2017]
(https://proceedings.neurips.cc/paper/2017/hash/
f4b9ec30ad9f68f89b29639786cb62ef-Abstract.html)
"""
def __init__(self, model=None, device='cpu', config=None):
super(KrumAggregator, self).__init__(model, device, config)
self.byzantine_node_num = config.aggregator.byzantine_node_num
self.krum_agg_num = config.aggregator.krum.agg_num
assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
"it should be satisfied that 2*byzantine_node_num + 2 < client_num"

def aggregate(self, agg_info):
"""
        To perform aggregation with the Krum aggregation rule

Arguments:
agg_info (dict): the feedbacks from clients
:returns: the aggregated results
:rtype: dict
"""
models = agg_info["client_feedback"]
avg_model = self._para_avg_with_krum(models, agg_num=self.krum_agg_num)

        # When using Krum/multi-Krum aggregation, the client feedback is
        # the model delta rather than the model parameters
updated_model = copy.deepcopy(avg_model)
init_model = self.model.state_dict()
for key in avg_model:
updated_model[key] = init_model[key] + avg_model[key]
return updated_model

def _calculate_distance(self, model_a, model_b):
"""
        Calculate the Euclidean distance between two given model parameter
        deltas
"""
distance = 0.0

for key in model_a:
if isinstance(model_a[key], torch.Tensor):
model_a[key] = model_a[key].float()
model_b[key] = model_b[key].float()
else:
model_a[key] = torch.FloatTensor(model_a[key])
model_b[key] = torch.FloatTensor(model_b[key])

distance += torch.dist(model_a[key], model_b[key], p=2)
return distance

def _calculate_score(self, models):
"""
Calculate Krum scores
"""
model_num = len(models)
closest_num = model_num - self.byzantine_node_num - 2

distance_matrix = torch.zeros(model_num, model_num)
for index_a in range(model_num):
for index_b in range(index_a, model_num):
if index_a == index_b:
distance_matrix[index_a, index_b] = float('inf')
else:
distance_matrix[index_a, index_b] = distance_matrix[
index_b, index_a] = self._calculate_distance(
models[index_a], models[index_b])

sorted_distance = torch.sort(distance_matrix)[0]
krum_scores = torch.sum(sorted_distance[:, :closest_num], axis=-1)
return krum_scores

def _para_avg_with_krum(self, models, agg_num=1):

# each_model: (sample_size, model_para)
models_para = [each_model[1] for each_model in models]
krum_scores = self._calculate_score(models_para)
index_order = torch.sort(krum_scores)[1].numpy()
reliable_models = list()
for number, index in enumerate(index_order):
if number < agg_num:
reliable_models.append(models[index])

return self._para_weighted_avg(models=reliable_models)
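
Krum scores each update by its summed distance to the closest model_num - byzantine_node_num - 2 neighbours and keeps only the agg_num lowest-scoring updates, so a single large outlier cannot drag the average. A toy check of that scoring logic (a standalone sketch, not part of the PR):

import torch

# Three "model deltas" as flat tensors; the last one is an obvious outlier
deltas = [torch.tensor([0.1, 0.0]),
          torch.tensor([0.0, 0.1]),
          torch.tensor([5.0, 5.0])]
n, byzantine_num = len(deltas), 0
closest_num = n - byzantine_num - 2        # neighbours counted per model (= 1)

distance = torch.full((n, n), float('inf'))
for i in range(n):
    for j in range(i + 1, n):
        distance[i, j] = distance[j, i] = torch.dist(deltas[i], deltas[j], p=2)

scores = torch.sort(distance, dim=-1)[0][:, :closest_num].sum(dim=-1)
print(int(torch.argmin(scores)))           # -> 0; the outlier never wins
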
4 changes: 3 additions & 1 deletion federatedscope/core/auxiliaries/aggregator_builder.py
@@ -58,7 +58,7 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
from federatedscope.core.aggregators import ClientsAvgAggregator, \
OnlineClientsAvgAggregator, ServerClientsInterpolateAggregator, \
FedOptAggregator, NoCommunicationAggregator, \
AsynClientsAvgAggregator
AsynClientsAvgAggregator, KrumAggregator

if method.lower() in constants.AGGREGATOR_TYPE:
aggregator_type = constants.AGGREGATOR_TYPE[method.lower()]
@@ -87,6 +87,8 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
return AsynClientsAvgAggregator(model=model,
device=device,
config=config)
elif config.aggregator.krum.use:
return KrumAggregator(model=model, device=device, config=config)
else:
return ClientsAvgAggregator(model=model,
device=device,
6 changes: 5 additions & 1 deletion federatedscope/core/auxiliaries/sampler_builder.py
@@ -1,6 +1,7 @@
import logging

from federatedscope.core.sampler import UniformSampler, GroupSampler
from federatedscope.core.sampler import UniformSampler, GroupSampler, \
ResponsivenessRealtedSampler

logger = logging.getLogger(__name__)

@@ -33,6 +34,9 @@ def get_sampler(sample_strategy='uniform',
"""
if sample_strategy == 'uniform':
return UniformSampler(client_num=client_num)
elif sample_strategy == 'responsiveness':
return ResponsivenessRealtedSampler(client_num=client_num,
client_info=client_info)
elif sample_strategy == 'group':
return GroupSampler(client_num=client_num,
client_info=client_info,
19 changes: 18 additions & 1 deletion federatedscope/core/configs/cfg_aggregator.py
@@ -1,9 +1,23 @@
import logging

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config


def extend_aggregator_cfg(cfg):

# ---------------------------------------------------------------------- #
# aggregator related options
# ---------------------------------------------------------------------- #
cfg.aggregator = CN()
cfg.aggregator.byzantine_node_num = 0

# For krum/multi-krum Algos
cfg.aggregator.krum = CN()
cfg.aggregator.krum.use = False
cfg.aggregator.krum.agg_num = 1

# For ATC method
cfg.aggregator.num_agg_groups = 1
cfg.aggregator.num_agg_topk = []
cfg.aggregator.inside_weight = 1.0
@@ -14,7 +28,10 @@ def assert_aggregator_cfg(cfg):


def assert_aggregator_cfg(cfg):
pass

if cfg.aggregator.byzantine_node_num == 0 and cfg.aggregator.krum.use:
        logging.warning('Although the Krum aggregation rule is applied, we found '
'that cfg.aggregator.byzantine_node_num == 0')


register_config('aggregator', extend_aggregator_cfg)
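
With these options in place, enabling (multi-)Krum is a pure config change. A minimal sketch, assuming the standard global_cfg entry point (the client count and Byzantine budget below are illustrative):

from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()
cfg.federate.client_num = 10
cfg.aggregator.byzantine_node_num = 3  # f; must satisfy 2*f + 2 < client_num
cfg.aggregator.krum.use = True         # routes get_aggregator() to KrumAggregator
cfg.aggregator.krum.agg_num = 2        # multi-Krum: average the 2 best updates
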
17 changes: 17 additions & 0 deletions federatedscope/core/configs/yacs_config.py
@@ -574,6 +574,23 @@ def conditional_cast(from_type, to_type):
else:
return False, None

    # int <-> list, forced replacement
    # To allow single <-> multiple cases
    # A use case: by default, we have cfg.attack.attacker_id = -1 to
    # denote that there exists no attacker in the simulated FL course.
    # We can easily change it to cfg.attack.attacker_id = 1 to conduct
    # experiments with attacker client #1. However, when we want
    # multiple attackers via setting cfg.attack.attacker_id = [1, 2, 3],
    # an error would be raised by yacs, since the default value
    # (-1, type int) cannot be replaced with [1, 2, 3] (type list).
    # This error motivates me to add the forced replacement rule here,
    # to support both single and multiple attackers.
    # TODO: a better solution?
if replacement_type == int and original_type == list:
return replacement
if replacement_type == list and original_type == int:
return replacement

# Conditionally casts
# list <-> tuple
# For py2: allow converting from str (bytes) to a unicode string
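
In effect, a default declared as an int can now be overridden by a list, and vice versa. A minimal sketch of the resulting behavior, assuming the vendored CfgNode is importable as below (stock yacs would raise a type error on the same merge):

from federatedscope.core.configs.yacs_config import CfgNode as CN

cfg = CN()
cfg.attacker_id = -1                              # default: no attacker (int)
cfg.merge_from_list(['attacker_id', [1, 2, 3]])   # forced int -> list replacement
print(cfg.attacker_id)                            # -> [1, 2, 3]
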
35 changes: 35 additions & 0 deletions federatedscope/core/sampler.py
@@ -129,3 +129,38 @@ def sample(self, size, shuffle=False):
self.change_state(item, 'working')

return sampled_clients


class ResponsivenessRealtedSampler(Sampler):
"""
    To sample clients based on their responsiveness (or other client
    information)
"""
def __init__(self, client_num, client_info):
super(ResponsivenessRealtedSampler, self).__init__(client_num)
self.update_client_info(client_info)

def update_client_info(self, client_info):
"""
To update the client information
"""
self.client_info = np.asarray(
[1.0] + [np.sqrt(x) for x in client_info
                     ])  # client_info[0] is reserved for the server
assert len(self.client_info) == len(
self.client_state
), "The first dimension of client_info is mismatched with client_num"

def sample(self, size):
"""
To sample clients
"""
idle_clients = np.nonzero(self.client_state)[0]
client_info = self.client_info[idle_clients]
client_info = client_info / np.sum(client_info, keepdims=True)
sampled_clients = np.random.choice(idle_clients,
p=client_info,
size=size,
replace=False).tolist()
self.change_state(sampled_clients, 'working')
return sampled_clients
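
The sampler draws idle clients with probability proportional to the square root of their reported responsiveness, which favors fast clients while damping the skew. A toy trace of that distribution (hypothetical speeds; index 0 is reserved for the server and skipped):

import numpy as np

responsiveness = [4.0, 1.0, 9.0]   # hypothetical per-client speeds
info = np.asarray([1.0] + [np.sqrt(x) for x in responsiveness])
probs = info[1:] / info[1:].sum()  # drop the reserved server slot
print(probs)                       # -> [0.3333 0.1667 0.5], up to rounding
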
16 changes: 14 additions & 2 deletions federatedscope/core/workers/client.py
@@ -78,10 +78,22 @@ def __init__(self,
# [ICLR'22, What Do We Mean by Generalization in Federated Learning?]
self.is_unseen_client = is_unseen_client

        # Parse the attacker id(s) since config.attack.attacker_id supports
        # both 'int' (a single attacker) and 'list' (multiple attackers)
parsed_attack_ids = list()
if isinstance(config.attack.attacker_id, int):
parsed_attack_ids.append(config.attack.attacker_id)
elif isinstance(config.attack.attacker_id, list):
parsed_attack_ids = config.attack.attacker_id
else:
raise TypeError(f"The expected types of config.attack.attack_id "
f"include 'int' and 'list', but we got "
f"{type(config.attack.attacker_id)}")

        # Attacks are only supported in standalone mode;
        # a client is an attacker if its ID is in the parsed attacker list
        # and config.attack.attack_method is provided
self.is_attacker = config.attack.attacker_id == ID and \
self.is_attacker = ID in parsed_attack_ids and \
config.attack.attack_method != '' and \
config.federate.mode == 'standalone'

@@ -357,7 +369,7 @@ def callback_funcs_for_model_para(self, message: Message):
self.msg_buffer['train'][self.state] = [(sample_size,
content_frame)]
else:
if self._cfg.asyn.use:
if self._cfg.asyn.use or self._cfg.aggregator.krum.use:
                # Return the model delta when using the asynchronous training
                # protocol or the Krum aggregator, because stale updates might
                # be discounted and cause that the sum of the aggregated
                # weights might