add several byzantine robust aggregators #542
Closed

Commits (11, all by private-mechanism):
- d7b47be  add several byzantine robust aggregators
- 6ad9d60  add several byzantine robust aggregators
- 3ae01c3  add several byzantine robust aggregators
- d15cd48  add several byzantine robust aggregators
- f832c70  add several byzantine robust aggregators
- 81e1f1c  add the implmentation reference of median and trimmedmean aggregator
- 283f8e3  Update __init__.py
- af1071e  Update server.py
- 688af96  Delete fltrust_aggregator.py
- ef01028  modify some auxiliary contents of the robust aggregators
- 5e7eb91  add serveral Byzantine robust algorithms
New file (106 lines): the Bulyan aggregator (`BulyanAggregator`).

```python
import copy
import torch
from federatedscope.core.aggregators import ClientsAvgAggregator


class BulyanAggregator(ClientsAvgAggregator):
    """
    Implementation of Bulyan refers to `The Hidden Vulnerability
    of Distributed Learning in Byzantium`
    [Mhamdi et al., 2018]
    (http://proceedings.mlr.press/v80/mhamdi18a/mhamdi18a.pdf)

    It combines the MultiKrum aggregator and the trimmed-mean aggregator
    """
    def __init__(self, model=None, device='cpu', config=None):
        super(BulyanAggregator, self).__init__(model, device, config)
        self.byzantine_node_num = config.aggregator.byzantine_node_num
        self.sample_client_rate = config.federate.sample_client_rate
        assert 4 * self.byzantine_node_num + 3 <= config.federate.client_num

    def aggregate(self, agg_info):
        """
        To perform aggregation with the Bulyan aggregation rule
        Arguments:
            agg_info (dict): the feedbacks from clients
        :returns: the aggregated results
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_bulyan(models)
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model

    def _calculate_distance(self, model_a, model_b):
        """
        Calculate the Euclidean distance between two given model para delta
        """
        distance = 0.0

        for key in model_a:
            if isinstance(model_a[key], torch.Tensor):
                model_a[key] = model_a[key].float()
                model_b[key] = model_b[key].float()
            else:
                model_a[key] = torch.FloatTensor(model_a[key])
                model_b[key] = torch.FloatTensor(model_b[key])

            distance += torch.dist(model_a[key], model_b[key], p=2)
        return distance

    def _calculate_score(self, models):
        """
        Calculate Krum scores
        """
        model_num = len(models)
        closest_num = model_num - self.byzantine_node_num - 2

        distance_matrix = torch.zeros(model_num, model_num)
        for index_a in range(model_num):
            for index_b in range(index_a, model_num):
                if index_a == index_b:
                    distance_matrix[index_a, index_b] = float('inf')
                else:
                    distance_matrix[index_a, index_b] = distance_matrix[
                        index_b, index_a] = self._calculate_distance(
                            models[index_a], models[index_b])

        sorted_distance = torch.sort(distance_matrix)[0]
        krum_scores = torch.sum(sorted_distance[:, :closest_num], axis=-1)
        return krum_scores

    def _aggre_with_bulyan(self, models):
        r'''
        Apply MultiKrum to select \theta (\theta <= client_num-
        2*self.byzantine_node_num) local models
        '''
        init_model = self.model.state_dict()
        global_update = copy.deepcopy(init_model)
        models_para = [each_model[1] for each_model in models]
        krum_scores = self._calculate_score(models_para)
        index_order = torch.sort(krum_scores)[1].numpy()
        reliable_models = list()
        for number, index in enumerate(index_order):
            if number < len(models) - int(
                    2 * self.sample_client_rate * self.byzantine_node_num):
                reliable_models.append(models[index])
        r'''
        Sort parameter for each coordinate of the rest \theta reliable
        local models, and find \gamma (\gamma < \theta-2*self.byzantine_num)
        parameters closest to the median to perform averaging
        '''
        excluded_num = int(self.sample_client_rate * self.byzantine_node_num)
        gamma = len(reliable_models) - 2 * excluded_num
        for key in init_model:
            temp = torch.stack(
                [each_model[1][key] for each_model in reliable_models], 0)
            pos_largest, _ = torch.topk(temp, excluded_num, 0)
            neg_smallest, _ = torch.topk(-temp, excluded_num, 0)
            new_stacked = torch.cat([temp, -pos_largest,
                                     neg_smallest]).sum(0).float()
            new_stacked /= gamma
            global_update[key] = new_stacked
        return global_update
```
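The Krum scoring in `_calculate_score` can be illustrated on toy one-dimensional updates (a standalone sketch with plain tensors rather than state-dict deltas; note this implementation sums Euclidean distances, whereas the original Krum paper uses squared distances):

```python
import torch

# Standalone sketch of Krum-style scoring on toy 1-D updates (plain
# tensors here; the aggregator itself operates on state-dict deltas).
updates = [torch.tensor([0.0]), torch.tensor([0.1]), torch.tensor([10.0])]
n, byzantine_num = len(updates), 0
closest_num = n - byzantine_num - 2  # nearest neighbours kept per update

dist = torch.full((n, n), float('inf'))  # inf on the diagonal, as above
for a in range(n):
    for b in range(n):
        if a != b:
            dist[a, b] = torch.dist(updates[a], updates[b], p=2)

# Score = sum of distances to the closest_num nearest neighbours; honest,
# mutually close updates get low scores, outliers get high ones.
scores = torch.sort(dist)[0][:, :closest_num].sum(-1)
outlier = torch.argmax(scores).item()  # the update at 10.0 ranks worst
```

MultiKrum then keeps the lowest-scoring updates, which is exactly the selection loop over `index_order` above.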
New file (50 lines): the coordinate-wise median aggregator (`MedianAggregator`).

```python
import copy
import torch
from federatedscope.core.aggregators import ClientsAvgAggregator
import logging

logger = logging.getLogger(__name__)


class MedianAggregator(ClientsAvgAggregator):
    """
    Implementation of median refers to `Byzantine-robust distributed
    learning: Towards optimal statistical rates`
    [Yin et al., 2018]
    (http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)

    It computes the coordinate-wise median of received updates from clients
    """
    def __init__(self, model=None, device='cpu', config=None):
        super(MedianAggregator, self).__init__(model, device, config)
        self.byzantine_node_num = config.aggregator.byzantine_node_num
        assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
            "it should be satisfied that 2*byzantine_node_num + 2 < client_num"

    def aggregate(self, agg_info):
        """
        To perform aggregation with the Median aggregation rule
        Arguments:
            agg_info (dict): the feedbacks from clients
        :returns: the aggregated results
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_median(models)
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model

    def _aggre_with_median(self, models):
        init_model = self.model.state_dict()
        global_update = copy.deepcopy(init_model)
        for key in init_model:
            temp = torch.stack([each_model[1][key] for each_model in models],
                               0)
            # torch.median returns the lower middle element for even counts;
            # combining the medians of temp and -temp yields the midpoint.
            temp_pos, _ = torch.median(temp, dim=0)
            temp_neg, _ = torch.median(-temp, dim=0)
            global_update[key] = (temp_pos - temp_neg) / 2
        return global_update
```
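The `(temp_pos - temp_neg) / 2` step compensates for `torch.median` returning the lower of the two middle elements when the number of clients is even; a minimal standalone check:

```python
import torch

# With an even number of values, torch.median returns the LOWER middle
# element; the median of the negated values recovers the upper one, so
# their midpoint is the usual even-count median.
temp = torch.tensor([1., 2., 3., 4.])
temp_pos, _ = torch.median(temp, dim=0)   # 2.0 (lower middle)
temp_neg, _ = torch.median(-temp, dim=0)  # -3.0, i.e. -(upper middle)
median = (temp_pos - temp_neg) / 2        # 2.5
```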
New file (64 additions): federatedscope/core/aggregators/normbounding_aggregator.py

```python
import copy
import torch
import logging
from federatedscope.core.aggregators import ClientsAvgAggregator

logger = logging.getLogger(__name__)


class NormboundingAggregator(ClientsAvgAggregator):
    r"""
    The server clips each update to reduce the negative impact \
    of malicious updates.
    """
    def __init__(self, model=None, device='cpu', config=None):
        super(NormboundingAggregator, self).__init__(model, device, config)
        self.norm_bound = config.aggregator.normbounding.norm_bound

    def aggregate(self, agg_info):
        """
        To perform aggregation with the norm-bounding aggregation rule
        Arguments:
            agg_info (dict): the feedbacks from clients
        :returns: the aggregated results
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_normbounding(models)
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model

    def _aggre_with_normbounding(self, models):
        models_temp = []
        for each_model in models:
            param = self._flatten_updates(each_model[1])
            param_norm = torch.norm(param, p=2)
            if param_norm > self.norm_bound:
                scaling_rate = self.norm_bound / param_norm
                scaled_param = scaling_rate * param
                models_temp.append(
                    (each_model[0], self._reconstruct_updates(scaled_param)))
            else:
                models_temp.append(each_model)
        return self._para_weighted_avg(models_temp)

    def _flatten_updates(self, model):
        model_update = []
        init_model = self.model.state_dict()
        for key in init_model:
            model_update.append(model[key].view(-1))
        return torch.cat(model_update, dim=0)

    def _reconstruct_updates(self, flatten_updates):
        start_idx = 0
        init_model = self.model.state_dict()
        reconstructed_model = copy.deepcopy(init_model)
        for key in init_model:
            reconstructed_model[key] = flatten_updates[
                start_idx:start_idx + len(init_model[key].view(-1))].reshape(
                    init_model[key].shape)
            start_idx = start_idx + len(init_model[key].view(-1))
        return reconstructed_model
```
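The clipping rule in `_aggre_with_normbounding` rescales any update whose L2 norm exceeds `norm_bound` back onto the norm ball, preserving its direction; a minimal sketch on a flat vector (toy values, not FederatedScope code):

```python
import torch

norm_bound = 5.0
update = torch.tensor([3.0, 4.0, 12.0])  # L2 norm is 13.0
nrm = torch.norm(update, p=2)
if nrm > norm_bound:
    # scale onto the norm ball: direction unchanged, norm == norm_bound
    update = update * (norm_bound / nrm)
```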
New file (55 lines): the trimmed-mean aggregator (`TrimmedmeanAggregator`).

```python
import copy
import torch
from federatedscope.core.aggregators import ClientsAvgAggregator
import logging

logger = logging.getLogger(__name__)


class TrimmedmeanAggregator(ClientsAvgAggregator):
    """
    Implementation of trimmed mean refers to `Byzantine-robust distributed
    learning: Towards optimal statistical rates`
    [Yin et al., 2018]
    (http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)
    """
    def __init__(self, model=None, device='cpu', config=None):
        super(TrimmedmeanAggregator, self).__init__(model, device, config)
        self.excluded_ratio = config.aggregator.trimmedmean.excluded_ratio
        self.byzantine_node_num = config.aggregator.byzantine_node_num
        assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
            "it should be satisfied that 2*byzantine_node_num + 2 < client_num"
        assert self.excluded_ratio < 0.5

    def aggregate(self, agg_info):
        """
        To perform aggregation with the trimmed-mean aggregation rule
        Arguments:
            agg_info (dict): the feedbacks from clients
        :returns: the aggregated results
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_trimmedmean(models)
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model

    def _aggre_with_trimmedmean(self, models):
        init_model = self.model.state_dict()
        global_update = copy.deepcopy(init_model)
        excluded_num = int(len(models) * self.excluded_ratio)
        for key in init_model:
            temp = torch.stack([each_model[1][key] for each_model in models],
                               0)
            pos_largest, _ = torch.topk(temp, excluded_num, 0)
            neg_smallest, _ = torch.topk(-temp, excluded_num, 0)
            new_stacked = torch.cat([temp, -pos_largest,
                                     neg_smallest]).sum(0).float()
            new_stacked /= len(temp) - 2 * excluded_num
            global_update[key] = new_stacked
        return global_update
```
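The `torch.cat([temp, -pos_largest, neg_smallest])` trick above computes a coordinate-wise trimmed sum without sorting: appending the negated top-k values and the bottom-k values (which `topk(-temp, k)` returns already negated) cancels the extremes out of the full sum. A standalone check on toy values:

```python
import torch

# Five clients' values for one coordinate; 100. and -50. are the extremes.
temp = torch.tensor([[1.], [100.], [2.], [-50.], [3.]])
k = 1  # number of extremes trimmed from each side

pos_largest, _ = torch.topk(temp, k, 0)    # [[100.]]
neg_smallest, _ = torch.topk(-temp, k, 0)  # [[50.]], i.e. -(-50.)
# full sum (56) minus top-1 (100) plus negated bottom-1 (50) = 6
trimmed_sum = torch.cat([temp, -pos_largest, neg_smallest]).sum(0)
trimmed_mean = trimmed_sum / (len(temp) - 2 * k)  # mean of {1, 2, 3} = 2
```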
Changes to `get_aggregator` (two hunks; the diff does not show the file name).

`@@ -58,7 +58,9 @@ def get_aggregator(method, model=None, device=None, online=False, config=None)` — the import list gains the new aggregators:

```python
    from federatedscope.core.aggregators import ClientsAvgAggregator, \
        OnlineClientsAvgAggregator, ServerClientsInterpolateAggregator, \
        FedOptAggregator, NoCommunicationAggregator, \
        AsynClientsAvgAggregator, KrumAggregator, \
        MedianAggregator, TrimmedmeanAggregator, \
        BulyanAggregator, NormboundingAggregator

    if method.lower() in constants.AGGREGATOR_TYPE:
        aggregator_type = constants.AGGREGATOR_TYPE[method.lower()]
```

`@@ -87,8 +89,20 @@` — the old `config.aggregator.krum.use` flag is replaced by dispatch on `config.aggregator.robust_rule`:

```python
        return AsynClientsAvgAggregator(model=model,
                                        device=device,
                                        config=config)
    elif config.aggregator.robust_rule == 'krum':
        return KrumAggregator(model=model, device=device, config=config)
    elif config.aggregator.robust_rule == 'median':
        return MedianAggregator(model=model, device=device, config=config)
    elif config.aggregator.robust_rule == 'trimmedmean':
        return TrimmedmeanAggregator(model=model,
                                     device=device,
                                     config=config)
    elif config.aggregator.robust_rule == 'bulyan':
        return BulyanAggregator(model=model, device=device, config=config)
    elif config.aggregator.robust_rule == 'normbounding':
        return NormboundingAggregator(model=model,
                                      device=device,
                                      config=config)
    else:
        return ClientsAvgAggregator(model=model,
                                    device=device,
```
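A hypothetical config fragment showing how the dispatch above could be driven, inferred only from the attribute paths used in this PR (`config.aggregator.robust_rule`, `config.aggregator.byzantine_node_num`, etc.) — the YAML key names are assumptions, not taken from FederatedScope's documented defaults:

```yaml
aggregator:
  robust_rule: trimmedmean   # one of: krum, median, trimmedmean, bulyan, normbounding
  byzantine_node_num: 2
  trimmedmean:
    excluded_ratio: 0.2
  normbounding:
    norm_bound: 10.0
federate:
  client_num: 10             # must satisfy 2*byzantine_node_num + 2 < client_num
  sample_client_rate: 1.0
```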
Review comment: please add refs here