Calibrate FedEx (#120)
* refactored trainers (imports now go through the trainer package roots)
* enabled FedEx to use the after-before validation loss difference as feedback
joneswong authored May 30, 2022
1 parent fcf6d23 commit 4a986e0
Showing 28 changed files with 467 additions and 391 deletions.
6 changes: 3 additions & 3 deletions federatedscope/attack/auxiliary/attack_trainer_builder.py
@@ -9,11 +9,11 @@ def wrap_attacker_trainer(base_trainer, config):
'''
if config.attack.attack_method.lower() == 'gan_attack':
from federatedscope.attack.trainer.GAN_trainer import wrap_GANTrainer
from federatedscope.attack.trainer import wrap_GANTrainer
return wrap_GANTrainer(base_trainer)
elif config.attack.attack_method.lower() == 'gradascent':
from federatedscope.attack.trainer.MIA_invert_gradient_trainer import wrap_GradientAscentTrainer
from federatedscope.attack.trainer import wrap_GradientAscentTrainer
return wrap_GradientAscentTrainer(base_trainer)
else:
raise ValueError('Trainer {} is not provided'.format(
config.attack.attack_method))
config.attack.attack_method))
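The shorter import paths above work because the wrappers are re-exported at the package level; a minimal sketch of what federatedscope/attack/trainer/__init__.py would need to expose (assumed, not shown in this commit):

# Assumed re-exports in federatedscope/attack/trainer/__init__.py that back
# the shorter import paths used by wrap_attacker_trainer above.
from federatedscope.attack.trainer.GAN_trainer import wrap_GANTrainer
from federatedscope.attack.trainer.MIA_invert_gradient_trainer import \
    wrap_GradientAscentTrainer

__all__ = ['wrap_GANTrainer', 'wrap_GradientAscentTrainer']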
4 changes: 2 additions & 2 deletions federatedscope/attack/trainer/GAN_trainer.py
@@ -1,8 +1,8 @@
from federatedscope.core.trainers.trainer import GeneralTorchTrainer
import logging
from typing import Type

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.attack.privacy_attacks.GAN_based_attack import GANCRA
import logging

logger = logging.getLogger(__name__)

6 changes: 4 additions & 2 deletions federatedscope/attack/trainer/MIA_invert_gradient_trainer.py
@@ -1,7 +1,9 @@
from federatedscope.core.trainers.trainer import GeneralTorchTrainer
import logging
from typing import Type

import torch
import logging

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
from federatedscope.attack.auxiliary.MIA_get_target_data import get_target_data

2 changes: 1 addition & 1 deletion federatedscope/attack/trainer/PIA_trainer.py
@@ -1,6 +1,6 @@
from federatedscope.core.trainers.trainer import GeneralTorchTrainer
from typing import Type

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.attack.auxiliary.utils import get_data_property


7 changes: 3 additions & 4 deletions federatedscope/autotune/fedex/client.py
@@ -24,7 +24,7 @@ def _apply_hyperparams(self, hyperparams):

self._cfg.defrost()
self._cfg.merge_from_list(cmd_args)
self._cfg.freeze()
self._cfg.freeze(inform=False)

self.trainer.ctx.setup_vars()

@@ -53,9 +53,8 @@ def callback_funcs_for_model_para(self, message: Message):
role='Client #{}'.format(self.ID),
return_raw=True))

# TODO: using validation loss as feedback and validation set size as weight
content = (sample_size, model_para_all, arms,
results["train_avg_loss"])
results['arms'] = arms
content = (sample_size, model_para_all, results)
self.comm_manager.send(
Message(msg_type='model_para',
sender=self.ID,
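With this change the sampled arms travel inside the results dict rather than as a separate tuple field. A rough sketch of the payload with placeholder values; the key names val_avg_loss_before, val_avg_loss_after, and val_total match what the server reads below and assume the trainer evaluates before and after the local update:

# Sketch of the content sent by a FedEx client (placeholder values).
arms = [2, 0]                 # one sampled index per hyperparameter aspect
sample_size = 512             # number of local training samples
model_para_all = {}           # locally updated model parameters (placeholder)
results = {
    'val_avg_loss_before': 0.61,  # validation loss before the local update
    'val_avg_loss_after': 0.55,   # validation loss after the local update
    'val_total': 128,             # validation set size, used as the weight
}
results['arms'] = arms
content = (sample_size, model_para_all, results)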
44 changes: 24 additions & 20 deletions federatedscope/autotune/fedex/server.py
@@ -49,7 +49,6 @@ def __init__(self,
# in which case, self._cfsp will be a list with length equal to #aspects
pass
sizes = [len(cand_set) for cand_set in self._cfsp]
# TODO: support other step size
eta0 = 'auto'
self._eta0 = [
np.sqrt(2.0 * np.log(size)) if eta0 == 'auto' else eta0
@@ -155,24 +154,26 @@ def callback_funcs_model_para(self, message: Message):
if round not in self.msg_buffer['train'].keys():
self.msg_buffer['train'][round] = dict()

self.msg_buffer['train'][round][sender] = list(content)
self.msg_buffer['train'][round][sender] = content

if self._cfg.federate.online_aggr:
self.aggregator.inc(tuple(content[0:2]))
self.check_and_move_on()

return self.check_and_move_on()

def update_policy(self, feedbacks):
"""Update the policy. This implementation is borrowed from the open-sourced FedEx (https://github.com/mkhodak/FedEx/blob/150fac03857a3239429734d59d319da71191872e/hyper.py#L151)
Arguments:
feedbacks (list): each element is a dict of client-reported results, including 'arms', 'val_total', 'val_avg_loss_before', and 'val_avg_loss_after'
"""

index = [tp[1] for tp in feedbacks]
weight = np.asarray([tp[0] for tp in feedbacks], dtype=np.float64)
index = [elem['arms'] for elem in feedbacks]
weight = np.asarray([elem['val_total'] for elem in feedbacks],
dtype=np.float64)
weight /= np.sum(weight)
# TODO: acquire client-wise validation loss before local updates
before = np.asarray([tp[2] for tp in feedbacks])
after = np.asarray([tp[2] for tp in feedbacks])
before = np.asarray(
[elem['val_avg_loss_before'] for elem in feedbacks])
after = np.asarray([elem['val_avg_loss_after'] for elem in feedbacks])

if self._trace['refine']:
trace = self.trace('refine')
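The elided part of this hunk performs an exponentiated-gradient step over each arm's categorical distribution using the after-before signal extracted above. A simplified, self-contained sketch of that idea (not the exact FedEx implementation, which also maintains a baseline across rounds):

import numpy as np

def exponentiated_gradient_step(theta, arms, weight, before, after, eta):
    # theta: current categorical distribution over candidate configurations
    # arms: per-client sampled indices into that distribution
    # weight: per-client weights (e.g. normalized validation set sizes)
    # before / after: per-client validation losses around the local update
    # eta: step size
    grad = np.zeros_like(theta)
    for arm, w, b, a in zip(arms, weight, before, after):
        # importance-weighted estimate of the gradient w.r.t. theta[arm]
        grad[arm] += w * (a - b) / theta[arm]
    theta = theta * np.exp(-eta * grad)
    return theta / theta.sum()

# Toy usage: three candidate values, feedback from two clients.
theta = np.ones(3) / 3
theta = exponentiated_gradient_step(theta, arms=[0, 2], weight=[0.5, 0.5],
                                    before=[0.61, 0.58], after=[0.55, 0.60],
                                    eta=0.1)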
@@ -226,19 +227,21 @@ def update_policy(self, feedbacks):
.format(self.ID, self._theta, self._trace['entropy'][-1],
self._trace['mle'][-1]))

def check_and_move_on(self, check_eval_result=False):
def check_and_move_on(self,
check_eval_result=False,
min_received_num=None):
"""
To check the message buffer and, when enough messages have been received, trigger events such as aggregation, evaluation, and moving to the next training round
"""
if min_received_num is None:
min_received_num = self._cfg.federate.sample_client_num
assert min_received_num <= self.sample_client_num

if check_eval_result:
# all clients are participating in evaluation
minimal_number = self.client_num
else:
# sampled clients are participating in training
minimal_number = self.sample_client_num
min_received_num = len(list(self.comm_manager.neighbors.keys()))

if self.check_buffer(self.state, minimal_number, check_eval_result):
move_on_flag = True # To record whether moving to a new training round or finishing the evaluation
if self.check_buffer(self.state, min_received_num, check_eval_result):

if not check_eval_result: # in the training process
mab_feedbacks = list()
@@ -258,13 +261,10 @@ def check_and_move_on(self, check_eval_result=False):
msg_list.append((train_data_size,
model_para_multiple[model_idx]))

# collect feedbacks for updating the policy
if model_idx == 0:
# temporarily, we consider training loss
# TODO: use validation loss and sample size
mab_feedbacks.append(
(train_msg_buffer[client_id][0],
train_msg_buffer[client_id][2],
train_msg_buffer[client_id][3]))
train_msg_buffer[client_id][2])

# Trigger the monitor here (for training)
if 'dissim' in self._cfg.eval.monitoring:
@@ -318,6 +318,10 @@
self.history_results = merge_dict(self.history_results,
formatted_eval_res)
self.check_and_save()
else:
move_on_flag = False

return move_on_flag

def check_and_save(self):
"""
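Condensed, the new server-side flow reads roughly as follows (a sketch that elides the aggregation, policy-update, and evaluation bodies shown in the hunks above):

def check_and_move_on(self, check_eval_result=False, min_received_num=None):
    # Default to the sampled client number for training rounds ...
    if min_received_num is None:
        min_received_num = self._cfg.federate.sample_client_num
    assert min_received_num <= self.sample_client_num

    # ... but expect every connected neighbor when checking evaluation results.
    if check_eval_result:
        min_received_num = len(list(self.comm_manager.neighbors.keys()))

    move_on_flag = True
    if self.check_buffer(self.state, min_received_num, check_eval_result):
        # aggregate, update the FedEx policy, evaluate / advance the round
        pass
    else:
        move_on_flag = False
    return move_on_flag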
17 changes: 9 additions & 8 deletions federatedscope/core/auxiliaries/trainer_builder.py
@@ -43,7 +43,7 @@ def get_trainer(model=None,
only_for_eval=only_for_eval,
monitor=monitor)
elif config.backend == 'tensorflow':
from federatedscope.core.trainers.tf_trainer import GeneralTFTrainer
from federatedscope.core.trainers import GeneralTFTrainer
trainer = GeneralTFTrainer(model=model,
data=data,
device=device,
@@ -72,7 +72,8 @@ def get_trainer(model=None,
]:
dict_path = "federatedscope.gfl.trainer.nodetrainer"
elif config.trainer.type.lower() in [
'flitplustrainer', 'flittrainer', 'fedvattrainer', 'fedfocaltrainer'
'flitplustrainer', 'flittrainer', 'fedvattrainer',
'fedfocaltrainer'
]:
dict_path = "federatedscope.gfl.flitplus.trainer"
elif config.trainer.type.lower() in ['mftrainer']:
@@ -106,23 +107,23 @@ def get_trainer(model=None,

# differential privacy plug-in
if config.nbafl.use:
from federatedscope.core.trainers.trainer_nbafl import wrap_nbafl_trainer
from federatedscope.core.trainers import wrap_nbafl_trainer
trainer = wrap_nbafl_trainer(trainer)
if config.sgdmf.use:
from federatedscope.mf.trainer.trainer_sgdmf import wrap_MFTrainer
from federatedscope.mf.trainer import wrap_MFTrainer
trainer = wrap_MFTrainer(trainer)

# personalization plug-in
if config.federate.method.lower() == "pfedme":
from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer
from federatedscope.core.trainers import wrap_pFedMeTrainer
# wrap style: instance a (class A) -> instance a (class A)
trainer = wrap_pFedMeTrainer(trainer)
elif config.federate.method.lower() == "ditto":
from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer
from federatedscope.core.trainers import wrap_DittoTrainer
# wrap style: instance a (class A) -> instance a (class A)
trainer = wrap_DittoTrainer(trainer)
elif config.federate.method.lower() == "fedem":
from federatedscope.core.trainers.trainer_FedEM import FedEMTrainer
from federatedscope.core.trainers import FedEMTrainer
# copy construct style: instance a (class A) -> instance b (class B)
trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer,
base_trainer=trainer)
@@ -136,7 +137,7 @@ def get_trainer(model=None,

# fed algorithm plug-in
if config.fedprox.use:
from federatedscope.core.trainers.trainer_fedprox import wrap_fedprox_trainer
from federatedscope.core.trainers import wrap_fedprox_trainer
trainer = wrap_fedprox_trainer(trainer)

return trainer
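The import clean-up does not change the order in which the plug-ins are stacked; that order is spread across several hunks above, so here is a condensed sketch of it, using only the wrapper names re-exported by the refactored packages (FedEM copy-construction and backend selection are omitted):

from federatedscope.core.trainers import (wrap_nbafl_trainer, wrap_pFedMeTrainer,
                                          wrap_DittoTrainer, wrap_fedprox_trainer)
from federatedscope.mf.trainer import wrap_MFTrainer

def compose_plugins(trainer, cfg):
    # Condensed view of the plug-in order applied by get_trainer above.
    if cfg.nbafl.use:
        trainer = wrap_nbafl_trainer(trainer)    # differential privacy
    if cfg.sgdmf.use:
        trainer = wrap_MFTrainer(trainer)        # matrix factorization
    if cfg.federate.method.lower() == 'pfedme':
        trainer = wrap_pFedMeTrainer(trainer)    # personalization
    elif cfg.federate.method.lower() == 'ditto':
        trainer = wrap_DittoTrainer(trainer)
    if cfg.fedprox.use:
        trainer = wrap_fedprox_trainer(trainer)  # fed algorithm plug-in
    return trainer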
1 change: 1 addition & 0 deletions federatedscope/core/configs/cfg_fl_setting.py
@@ -23,6 +23,7 @@ def extend_fl_setting_cfg(cfg):
cfg.federate.data_weighted_aggr = False # If True, the weight of aggr is the number of training samples in dataset.
cfg.federate.online_aggr = False
cfg.federate.make_global_eval = False
cfg.federate.use_diff = False

# the method name is used to internally determine composition of different aggregators, messages, handlers, etc.,
cfg.federate.method = "FedAvg"
2 changes: 2 additions & 0 deletions federatedscope/core/configs/cfg_hpo.py
@@ -62,6 +62,8 @@ def assert_hpo_cfg(cfg):
['adaptive', 'aggressive', 'auto', 'constant', 'scale'])
assert cfg.hpo.fedex.gamma >= .0 and cfg.hpo.fedex.gamma <= 1.0, "{} must be in [0, 1]".format(
cfg.hpo.fedex.gamma)
assert cfg.hpo.fedex.diff == cfg.federate.use_diff, "Inconsistent values for cfg.hpo.fedex.diff={} and cfg.federate.use_diff={}".format(
cfg.hpo.fedex.diff, cfg.federate.use_diff)


register_config("hpo", extend_hpo_cfg)
5 changes: 3 additions & 2 deletions federatedscope/core/configs/config.py
@@ -84,7 +84,7 @@ def clean_unused_sub_cfgs(self):
else:
del v[k]

def freeze(self):
def freeze(self, inform=True):
"""
1) make the cfg attributes immutable;
2) save the frozen cfg_check_funcs into "self.outdir/config.yaml" for better reproducibility;
@@ -114,7 +114,8 @@ def freeze(self):
cfg_yaml = yaml.safe_load(tmp_cfg.dump())
wandb.config.update(cfg_yaml, allow_val_change=True)

logger.info("the used configs are: \n" + str(tmp_cfg))
if inform:
logger.info("the used configs are: \n" + str(tmp_cfg))

super(CN, self).freeze()

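The new inform flag exists so that FedEx clients, which re-freeze their config every time a sampled arm is applied, do not dump the full config to the log each round. A sketch mirroring _apply_hyperparams from the client diff above; the hyperparameter key is only an example:

def apply_hyperparams(cfg, hyperparams):
    # Flatten e.g. {'optimizer.lr': 0.05} into ['optimizer.lr', 0.05]
    # (the key name is hypothetical).
    cmd_args = []
    for key, val in hyperparams.items():
        cmd_args.append(key)
        cmd_args.append(val)

    cfg.defrost()
    cfg.merge_from_list(cmd_args)
    # inform=False suppresses the "the used configs are: ..." log line on re-freeze.
    cfg.freeze(inform=False)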
5 changes: 3 additions & 2 deletions federatedscope/core/trainers/__init__.py
@@ -1,4 +1,5 @@
from federatedscope.core.trainers.trainer import Trainer, GeneralTorchTrainer
from federatedscope.core.trainers.trainer import Trainer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.trainers.trainer_multi_model import GeneralMultiModelTrainer
from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer
from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer
@@ -11,4 +12,4 @@
'Trainer', 'Context', 'GeneralTorchTrainer', 'GeneralMultiModelTrainer',
'wrap_pFedMeTrainer', 'wrap_DittoTrainer', 'FedEMTrainer',
'wrap_fedprox_trainer', 'wrap_nbafl_trainer', 'wrap_nbafl_server'
]
]
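With trainer.py split from torch_trainer.py, downstream modules now import from the package root, which is exactly the style the other files in this commit switch to:

# Preferred after this commit:
from federatedscope.core.trainers import Trainer, GeneralTorchTrainer

# Old style, now pointing at a split module:
# from federatedscope.core.trainers.trainer import GeneralTorchTrainer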