-
Notifications
You must be signed in to change notification settings - Fork 214
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add several Byzantine robust algorithms (#552)
- Loading branch information
1 parent
d2e7d08
commit 2f31956
Showing
12 changed files
with
584 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import copy | ||
import torch | ||
from federatedscope.core.aggregators import ClientsAvgAggregator | ||
|
||
|
||
class BulyanAggregator(ClientsAvgAggregator):
    """
    Implementation of Bulyan refers to `The Hidden Vulnerability
    of Distributed Learning in Byzantium`
    [Mhamdi et al., 2018]
    (http://proceedings.mlr.press/v80/mhamdi18a/mhamdi18a.pdf)

    It combines the MultiKrum aggregator and the trimmedmean aggregator:
    the updates with the lowest Krum scores are kept first, then a
    coordinate-wise trimmed mean is taken over the kept updates.
    """
    def __init__(self, model=None, device='cpu', config=None):
        """
        Arguments:
            model: the global model, forwarded to the parent aggregator
            device: the device, forwarded to the parent aggregator
            config: the configuration; ``aggregator.byzantine_node_num``,
                ``federate.sample_client_rate`` and ``federate.client_num``
                are read from it
        """
        super(BulyanAggregator, self).__init__(model, device, config)
        self.byzantine_node_num = config.aggregator.byzantine_node_num
        self.sample_client_rate = config.federate.sample_client_rate
        assert 4 * self.byzantine_node_num + 3 <= \
            config.federate.client_num, \
            "it should be satisfied that " \
            "4*byzantine_node_num + 3 <= client_num"

    def aggregate(self, agg_info):
        """
        To perform aggregation with the Bulyan aggregation rule

        Arguments:
            agg_info (dict): the feedbacks from clients
        :returns: the aggregated results
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_bulyan(models)
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        # The aggregated value is added onto the current global parameters
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model

    @staticmethod
    def _to_float_tensor(value):
        """
        Return ``value`` as a float tensor, leaving the container that
        holds ``value`` untouched
        """
        if isinstance(value, torch.Tensor):
            return value.float()
        return torch.FloatTensor(value)

    def _calculate_distance(self, model_a, model_b):
        """
        Calculate the Euclidean distance between two given model para delta

        The per-key L2 distances are summed up. The inputs are converted to
        float tensors locally, so the dicts passed in are not modified.
        """
        distance = 0.0

        for key in model_a:
            distance += torch.dist(self._to_float_tensor(model_a[key]),
                                   self._to_float_tensor(model_b[key]),
                                   p=2)
        return distance

    def _calculate_score(self, models):
        """
        Calculate Krum scores

        The score of a model is the sum of its ``closest_num`` smallest
        distances to the other models.

        Arguments:
            models (list): the model para deltas, one dict per client
        :returns: one score per model
        :rtype: torch.Tensor
        """
        model_num = len(models)
        # NOTE(review): a non-positive closest_num makes the slice below
        # meaningless; this assumes model_num > byzantine_node_num + 2
        closest_num = model_num - self.byzantine_node_num - 2

        distance_matrix = torch.zeros(model_num, model_num)
        for index_a in range(model_num):
            for index_b in range(index_a, model_num):
                if index_a == index_b:
                    # A model is never counted as its own neighbour
                    distance_matrix[index_a, index_b] = float('inf')
                else:
                    # The distance is symmetric, fill both triangles at once
                    distance_matrix[index_a, index_b] = distance_matrix[
                        index_b, index_a] = self._calculate_distance(
                            models[index_a], models[index_b])

        sorted_distance = torch.sort(distance_matrix)[0]
        krum_scores = torch.sum(sorted_distance[:, :closest_num], dim=-1)
        return krum_scores

    def _aggre_with_bulyan(self, models):
        """
        Aggregate the client feedbacks with the Bulyan rule

        Arguments:
            models (list): the client feedbacks; the model para delta is
                read from index 1 of each item
        :returns: the aggregated update, keyed like the global state dict
        :rtype: dict
        :raises ValueError: if no update is left after trimming
        """
        init_model = self.model.state_dict()
        global_update = copy.deepcopy(init_model)

        # Stage 1: apply MultiKrum to keep the theta local models with the
        # lowest Krum scores, where
        # theta = len(models) - int(2 * sample_client_rate * byzantine_num)
        models_para = [each_model[1] for each_model in models]
        krum_scores = self._calculate_score(models_para)
        index_order = torch.sort(krum_scores)[1].numpy()
        reliable_num = len(models) - int(
            2 * self.sample_client_rate * self.byzantine_node_num)
        reliable_models = [
            models[index] for index in index_order[:max(reliable_num, 0)]
        ]

        # Stage 2: for each coordinate of the theta reliable local models,
        # drop the excluded_num largest and the excluded_num smallest values
        # and average the remaining gamma values
        excluded_num = int(self.sample_client_rate * self.byzantine_node_num)
        gamma = len(reliable_models) - 2 * excluded_num
        if gamma <= 0:
            raise ValueError(
                "Bulyan needs at least one update left after trimming, "
                f"but got {len(reliable_models)} reliable models and "
                f"{excluded_num} excluded on each side")
        for key in init_model:
            temp = torch.stack([
                self._to_float_tensor(each_model[1][key])
                for each_model in reliable_models
            ], 0)
            pos_largest, _ = torch.topk(temp, excluded_num, 0)
            neg_smallest, _ = torch.topk(-temp, excluded_num, 0)
            # Summing temp together with -largest and +(-smallest) cancels
            # the trimmed values out of the total
            new_stacked = torch.cat([temp, -pos_largest,
                                     neg_smallest]).sum(0).float()
            new_stacked /= gamma
            global_update[key] = new_stacked
        return global_update
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import copy | ||
import torch | ||
import numpy as np | ||
from federatedscope.core.aggregators import ClientsAvgAggregator | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class MedianAggregator(ClientsAvgAggregator):
    """
    Coordinate-wise median aggregator, following `Byzantine-robust
    distributed learning: Towards optimal statistical rates`
    [Yin et al., 2018]
    (http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)

    For every parameter it takes the coordinate-wise median of the
    updates received from the clients.
    The code is adapted from https://github.com/bladesteam/blades
    """
    def __init__(self, model=None, device='cpu', config=None):
        """
        Arguments:
            model: the global model, forwarded to the parent aggregator
            device: the device, forwarded to the parent aggregator
            config: the configuration; ``aggregator.byzantine_node_num``
                and ``federate.client_num`` are read from it
        """
        super().__init__(model, device, config)
        byzantine_num = config.aggregator.byzantine_node_num
        self.byzantine_node_num = byzantine_num
        assert 2 * byzantine_num + 2 < config.federate.client_num, \
            "it should be satisfied that 2*byzantine_node_num + 2 < client_num"

    def aggregate(self, agg_info):
        """
        Aggregate the client feedbacks with the median rule and add the
        result onto the current global parameters

        Arguments:
            agg_info (dict): the feedbacks from clients
        :returns: the aggregated results
        :rtype: dict
        """
        delta = self._aggre_with_median(agg_info["client_feedback"])
        result = copy.deepcopy(delta)
        current = self.model.state_dict()
        for name, step in delta.items():
            result[name] = current[name] + step
        return result

    def _aggre_with_median(self, models):
        """
        Compute the coordinate-wise median over the client updates

        Arguments:
            models (list): the client feedbacks; the update is read from
                index 1 of each item
        :returns: the median update, keyed like the global state dict
        :rtype: dict
        """
        reference = self.model.state_dict()
        merged = copy.deepcopy(reference)
        for name in reference:
            stacked = torch.stack([feedback[1][name] for feedback in models],
                                  0)
            median_of_values = torch.median(stacked, dim=0)[0]
            median_of_negated = torch.median(-stacked, dim=0)[0]
            # Average the median of the values with the sign-flipped median
            # of the negated values (presumably so that an even number of
            # clients yields the mean of the two middle values)
            merged[name] = (median_of_values - median_of_negated) / 2
        return merged
64 changes: 64 additions & 0 deletions
64
federatedscope/core/aggregators/normbounding_aggregator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import logging | ||
import copy | ||
import torch | ||
import numpy as np | ||
from federatedscope.core.aggregators import ClientsAvgAggregator | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class NormboundingAggregator(ClientsAvgAggregator):
    """
    The server clips each update to reduce the negative impact
    of malicious updates.

    An update whose overall L2 norm exceeds ``norm_bound`` is scaled down
    so that its norm equals the bound; the clipped updates are then handed
    to the averaging of the parent aggregator.
    """
    def __init__(self, model=None, device='cpu', config=None):
        """
        Arguments:
            model: the global model, forwarded to the parent aggregator
            device: the device, forwarded to the parent aggregator
            config: the configuration;
                ``aggregator.BFT_args.normbounding_norm_bound`` is read
                from it
        """
        super(NormboundingAggregator, self).__init__(model, device, config)
        self.norm_bound = config.aggregator.BFT_args.normbounding_norm_bound

    def aggregate(self, agg_info):
        """
        To perform aggregation with normbounding aggregation rule

        Arguments:
            agg_info (dict): the feedbacks from clients
        :returns: the aggregated results
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_normbounding(models)
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        # The aggregated value is added onto the current global parameters
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model

    def _aggre_with_normbounding(self, models):
        """
        Clip every client update to ``norm_bound`` and average them

        Arguments:
            models (list): the client feedbacks; index 0 of each item is
                passed through unchanged (presumably the sample size) and
                index 1 is the update
        :returns: the result of ``_para_weighted_avg`` on the clipped
            feedbacks
        """
        models_temp = []
        for each_model in models:
            param = self._flatten_updates(each_model[1])
            # Compute the norm once and reuse it for the test and the scale
            param_norm = torch.norm(param, p=2)
            if param_norm > self.norm_bound:
                scaled_param = (self.norm_bound / param_norm) * param
                models_temp.append(
                    (each_model[0], self._reconstruct_updates(scaled_param)))
            else:
                models_temp.append(each_model)
        return self._para_weighted_avg(models_temp)

    def _flatten_updates(self, model):
        """
        Concatenate all entries of ``model`` into one 1-D tensor, following
        the key order of the global state dict
        """
        init_model = self.model.state_dict()
        # reshape (unlike view) also accepts non-contiguous tensors
        model_update = [model[key].reshape(-1) for key in init_model]
        return torch.cat(model_update, dim=0)

    def _reconstruct_updates(self, flatten_updates):
        """
        Inverse of ``_flatten_updates``: cut the 1-D tensor back into
        per-key tensors shaped like the global state dict entries
        """
        start_idx = 0
        init_model = self.model.state_dict()
        reconstructed_model = copy.deepcopy(init_model)
        for key in init_model:
            end_idx = start_idx + init_model[key].numel()
            reconstructed_model[key] = flatten_updates[
                start_idx:end_idx].reshape(init_model[key].shape)
            start_idx = end_idx
        return reconstructed_model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import copy | ||
import torch | ||
import numpy as np | ||
from federatedscope.core.aggregators import ClientsAvgAggregator | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class TrimmedmeanAggregator(ClientsAvgAggregator):
    """
    Coordinate-wise trimmed mean aggregator, following `Byzantine-robust
    distributed learning: Towards optimal statistical rates`
    [Yin et al., 2018]
    (http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)

    For every coordinate the largest and the smallest values are dropped
    and the rest is averaged.
    The code is adapted from https://github.com/bladesteam/blades
    """
    def __init__(self, model=None, device='cpu', config=None):
        """
        Arguments:
            model: the global model, forwarded to the parent aggregator
            device: the device, forwarded to the parent aggregator
            config: the configuration;
                ``aggregator.BFT_args.trimmedmean_excluded_ratio``,
                ``aggregator.byzantine_node_num`` and
                ``federate.client_num`` are read from it
        """
        super().__init__(model, device, config)
        bft_args = config.aggregator.BFT_args
        self.excluded_ratio = bft_args.trimmedmean_excluded_ratio
        byzantine_num = config.aggregator.byzantine_node_num
        self.byzantine_node_num = byzantine_num
        assert 2 * byzantine_num + 2 < config.federate.client_num, \
            "it should be satisfied that 2*byzantine_node_num + 2 < client_num"
        assert self.excluded_ratio < 0.5

    def aggregate(self, agg_info):
        """
        Aggregate the client feedbacks with the trimmed mean rule and add
        the result onto the current global parameters

        Arguments:
            agg_info (dict): the feedbacks from clients
        :returns: the aggregated results
        :rtype: dict
        """
        delta = self._aggre_with_trimmedmean(agg_info["client_feedback"])
        result = copy.deepcopy(delta)
        current = self.model.state_dict()
        for name, step in delta.items():
            result[name] = current[name] + step
        return result

    def _aggre_with_trimmedmean(self, models):
        """
        Compute the coordinate-wise trimmed mean over the client updates

        Arguments:
            models (list): the client feedbacks; the update is read from
                index 1 of each item
        :returns: the trimmed mean update, keyed like the global state dict
        :rtype: dict
        """
        reference = self.model.state_dict()
        merged = copy.deepcopy(reference)
        # Number of values trimmed from each end of every coordinate
        trim = int(len(models) * self.excluded_ratio)
        for name in reference:
            stacked = torch.stack([feedback[1][name] for feedback in models],
                                  0)
            largest = torch.topk(stacked, trim, 0)[0]
            negated_smallest = torch.topk(-stacked, trim, 0)[0]
            # Appending -largest and +(-smallest) cancels the trimmed
            # values out of the column sums
            pieces = [stacked, -largest, negated_smallest]
            kept_sum = torch.cat(pieces).sum(0).float()
            kept_sum /= len(stacked) - 2 * trim
            merged[name] = kept_sum
        return merged
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.