add several byzantine robust aggregators #542
Add several Byzantine robust aggregators and modify corresponding auxiliary files.
Good job, these aggregators are useful in developing FL algorithms, thank you. Some suggestions are given as inline comments. Note that the details of the implemented algorithms have not been carefully reviewed yet and will be reviewed ASAP!
@@ -89,6 +91,20 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
                                            config=config)
        elif config.aggregator.krum.use:
            return KrumAggregator(model=model, device=device, config=config)
        elif config.aggregator.median.use:
What if more than one aggregator.xx.use has been set to True?
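One way to make the concern above concrete is a validation step that rejects a config with several `use` flags enabled. The following is only a minimal sketch: plain dicts stand in for the project's config object, and `enabled_aggregators` and `assert_single_aggregator` are hypothetical helpers, not functions from this PR.

```python
# Hypothetical sketch: plain dicts stand in for the project's config object.
def enabled_aggregators(aggregator_cfg):
    """Return the sorted names of sub-configs whose `use` flag is True."""
    return sorted(
        name for name, sub_cfg in aggregator_cfg.items()
        if isinstance(sub_cfg, dict) and sub_cfg.get("use", False)
    )


def assert_single_aggregator(aggregator_cfg):
    """Return the single enabled aggregator name, or None if none is enabled.

    Raises ValueError when more than one `use` flag is True.
    """
    enabled = enabled_aggregators(aggregator_cfg)
    if len(enabled) > 1:
        raise ValueError(
            "Only one aggregator may be enabled, got: " + ", ".join(enabled))
    return enabled[0] if enabled else None


cfg = {
    "byzantine_node_num": 0,       # scalar entries are ignored by the check
    "krum": {"use": True},
    "median": {"use": False},
}
print(assert_single_aggregator(cfg))  # krum
```

Without such a check, an `if`/`elif` chain silently picks whichever enabled aggregator is tested first.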
@@ -11,12 +11,40 @@ def extend_aggregator_cfg(cfg):
    # ---------------------------------------------------------------------- #
    cfg.aggregator = CN()
    cfg.aggregator.byzantine_node_num = 0
    cfg.aggregator.client_sampled_ratio = 0.2
Is it different from cfg.federate.sample_client_rate?
@@ -34,6 +34,9 @@ def extend_data_cfg(cfg):
    cfg.data.pre_transform = [
    ]  # pre_transform for `torch_geometric` dataset, use as above

    # whether to split a root dataset to the server
    cfg.data.root_dataset_need = False
root_dataset_need is not necessary, since FS already supports the case where the server owns data. What we should do is specify the data slice for the server.
@@ -385,7 +385,10 @@ def callback_funcs_for_model_para(self, message: Message):
                self.msg_buffer['train'][self.state] = [(sample_size,
                                                         content_frame)]
            else:
                if self._cfg.asyn.use or self._cfg.aggregator.krum.use:
                if self._cfg.asyn.use or self._cfg.aggregator.fltrust.use or \
It is so... ugly, IMO. Maybe we should add a new config term for the aggregation method and allow extensible hyperparameters for the different methods.
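A rough illustration of this suggestion: select the aggregator by a single string key and pass method-specific hyperparameters through, instead of chaining per-method `use` flags. Everything named here (`AGGREGATORS`, `register_aggregator`, `build_aggregator`, and the two toy classes) is hypothetical and is not the project's API.

```python
# Hypothetical sketch of a string-keyed aggregator registry.
AGGREGATORS = {}


def register_aggregator(name):
    """Class decorator that records an aggregator class under a string key."""
    def decorator(cls):
        AGGREGATORS[name] = cls
        return cls
    return decorator


@register_aggregator("fedavg")
class ToyAvgAggregator:
    def __init__(self, **hyperparams):
        # Plain averaging needs no method-specific hyperparameters here.
        self.hyperparams = hyperparams


@register_aggregator("krum")
class ToyKrumAggregator:
    def __init__(self, byzantine_node_num=0, **hyperparams):
        # Method-specific hyperparameters live with the method itself.
        self.byzantine_node_num = byzantine_node_num
        self.hyperparams = hyperparams


def build_aggregator(rule, **hyperparams):
    """Look up `rule` in the registry and construct it with its hyperparameters."""
    try:
        cls = AGGREGATORS[rule]
    except KeyError:
        raise ValueError(f"Unknown aggregation rule: {rule!r}") from None
    return cls(**hyperparams)


agg = build_aggregator("krum", byzantine_node_num=2)
print(type(agg).__name__, agg.byzantine_node_num)  # ToyKrumAggregator 2
```

With a single key, exactly one method is selected by construction, and adding a method does not require touching the dispatch code or the server-side condition.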
            # model evaluation in server
            assert self.model is not None
            assert self.data is not None
            self.trainer = get_trainer(model=self.model,
Why not merge this with lines 133-151?
            model_delta[key] = updated_model[key] - init_model[key]
        return model_delta

    def _global_trainer(self):
If it is only used for the FLTrust aggregator, I think it should not be placed in server.py.
                'that cfg.aggregator.byzantine_node_num == 0')
    if cfg.aggregator.byzantine_node_num == 0 and \
            cfg.aggregator.robust_rule in \
            ['krum', 'normbounding', 'median', 'trimmedmean', 'bulyan']:
‘fedavg’?
class MedianAggregator(ClientsAvgAggregator):
    """
    Implementation of median refers to `Byzantine-robust distributed
Please add refs here.
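For readers unfamiliar with median aggregation, the general idea can be sketched as a coordinate-wise median over client updates. This is a generic illustration under simplifying assumptions (parameters are flattened into plain Python lists, and `coordinate_wise_median` is a hypothetical name); it is not the `MedianAggregator` code from this PR.

```python
from statistics import median


def coordinate_wise_median(client_updates):
    """Aggregate client updates by taking the median of each coordinate.

    client_updates: list of dicts mapping a parameter name to a flat list
    of floats; all clients are assumed to share the same keys and lengths.
    """
    if not client_updates:
        raise ValueError("At least one client update is required")
    aggregated = {}
    for key in client_updates[0]:
        # zip(*...) groups the i-th coordinate of every client together.
        columns = zip(*(update[key] for update in client_updates))
        aggregated[key] = [median(column) for column in columns]
    return aggregated


updates = [
    {"w": [1.0, 2.0]},
    {"w": [1.2, 2.2]},
    {"w": [100.0, -50.0]},  # an outlier, e.g. from a Byzantine client
]
print(coordinate_wise_median(updates))  # {'w': [1.2, 2.0]}
```

The outlying third update does not drag the result away from the two honest ones, which is the property that motivates median-style rules over plain averaging.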
@@ -87,8 +89,20 @@ def get_aggregator(method, model=None, device=None, online=False, config=None):
            return AsynClientsAvgAggregator(model=model,
                                            device=device,
                                            config=config)
        elif config.aggregator.krum.use:
        elif config.aggregator.robust_rule == 'krum':
FedAvg