diff --git a/federatedscope/attack/auxiliary/attack_trainer_builder.py b/federatedscope/attack/auxiliary/attack_trainer_builder.py index aed648566..6642df7f1 100644 --- a/federatedscope/attack/auxiliary/attack_trainer_builder.py +++ b/federatedscope/attack/auxiliary/attack_trainer_builder.py @@ -9,11 +9,11 @@ def wrap_attacker_trainer(base_trainer, config): ''' if config.attack.attack_method.lower() == 'gan_attack': - from federatedscope.attack.trainer.GAN_trainer import wrap_GANTrainer + from federatedscope.attack.trainer import wrap_GANTrainer return wrap_GANTrainer(base_trainer) elif config.attack.attack_method.lower() == 'gradascent': - from federatedscope.attack.trainer.MIA_invert_gradient_trainer import wrap_GradientAscentTrainer + from federatedscope.attack.trainer import wrap_GradientAscentTrainer return wrap_GradientAscentTrainer(base_trainer) else: raise ValueError('Trainer {} is not provided'.format( - config.attack.attack_method)) \ No newline at end of file + config.attack.attack_method)) diff --git a/federatedscope/attack/trainer/GAN_trainer.py b/federatedscope/attack/trainer/GAN_trainer.py index c4df526bd..a1b2b0393 100644 --- a/federatedscope/attack/trainer/GAN_trainer.py +++ b/federatedscope/attack/trainer/GAN_trainer.py @@ -1,8 +1,8 @@ -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +import logging from typing import Type +from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.attack.privacy_attacks.GAN_based_attack import GANCRA -import logging logger = logging.getLogger(__name__) diff --git a/federatedscope/attack/trainer/MIA_invert_gradient_trainer.py b/federatedscope/attack/trainer/MIA_invert_gradient_trainer.py index fe2c63c9a..171d8936f 100644 --- a/federatedscope/attack/trainer/MIA_invert_gradient_trainer.py +++ b/federatedscope/attack/trainer/MIA_invert_gradient_trainer.py @@ -1,7 +1,9 @@ -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +import logging from typing import Type + import torch -import logging + +from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset from federatedscope.attack.auxiliary.MIA_get_target_data import get_target_data diff --git a/federatedscope/attack/trainer/PIA_trainer.py b/federatedscope/attack/trainer/PIA_trainer.py index 5532b5ac2..d0826b30e 100644 --- a/federatedscope/attack/trainer/PIA_trainer.py +++ b/federatedscope/attack/trainer/PIA_trainer.py @@ -1,6 +1,6 @@ -from federatedscope.core.trainers.trainer import GeneralTorchTrainer from typing import Type +from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.attack.auxiliary.utils import get_data_property diff --git a/federatedscope/autotune/fedex/client.py b/federatedscope/autotune/fedex/client.py index 4c20f1255..ba8a50c63 100644 --- a/federatedscope/autotune/fedex/client.py +++ b/federatedscope/autotune/fedex/client.py @@ -24,7 +24,7 @@ def _apply_hyperparams(self, hyperparams): self._cfg.defrost() self._cfg.merge_from_list(cmd_args) - self._cfg.freeze() + self._cfg.freeze(inform=False) self.trainer.ctx.setup_vars() @@ -53,9 +53,8 @@ def callback_funcs_for_model_para(self, message: Message): role='Client #{}'.format(self.ID), return_raw=True)) - # TODO: using validation loss as feedback and validation set size as weight - content = (sample_size, model_para_all, arms, - results["train_avg_loss"]) + results['arms'] = arms + content = (sample_size, model_para_all, results) self.comm_manager.send( Message(msg_type='model_para', sender=self.ID, diff --git a/federatedscope/autotune/fedex/server.py b/federatedscope/autotune/fedex/server.py index dae534c08..4056b19aa 100644 --- a/federatedscope/autotune/fedex/server.py +++ b/federatedscope/autotune/fedex/server.py @@ -49,7 +49,6 @@ def __init__(self, # in which case, self._cfsp will be a list with length equal to #aspects pass sizes = [len(cand_set) for cand_set in self._cfsp] - # TODO: support other step size eta0 = 'auto' self._eta0 = [ np.sqrt(2.0 * np.log(size)) if eta0 == 'auto' else eta0 @@ -155,11 +154,12 @@ def callback_funcs_model_para(self, message: Message): if round not in self.msg_buffer['train'].keys(): self.msg_buffer['train'][round] = dict() - self.msg_buffer['train'][round][sender] = list(content) + self.msg_buffer['train'][round][sender] = content if self._cfg.federate.online_aggr: self.aggregator.inc(tuple(content[0:2])) - self.check_and_move_on() + + return self.check_and_move_on() def update_policy(self, feedbacks): """Update the policy. This implementation is borrowed from the open-sourced FedEx (https://github.com/mkhodak/FedEx/blob/150fac03857a3239429734d59d319da71191872e/hyper.py#L151) @@ -167,12 +167,13 @@ def update_policy(self, feedbacks): feedbacks (list): each element is a tuple in the form (sample_size, arms, loss) """ - index = [tp[1] for tp in feedbacks] - weight = np.asarray([tp[0] for tp in feedbacks], dtype=np.float64) + index = [elem['arms'] for elem in feedbacks] + weight = np.asarray([elem['val_total'] for elem in feedbacks], + dtype=np.float64) weight /= np.sum(weight) - # TODO: acquire client-wise validation loss before local updates - before = np.asarray([tp[2] for tp in feedbacks]) - after = np.asarray([tp[2] for tp in feedbacks]) + before = np.asarray( + [elem['val_avg_loss_before'] for elem in feedbacks]) + after = np.asarray([elem['val_avg_loss_after'] for elem in feedbacks]) if self._trace['refine']: trace = self.trace('refine') @@ -226,19 +227,21 @@ def update_policy(self, feedbacks): .format(self.ID, self._theta, self._trace['entropy'][-1], self._trace['mle'][-1])) - def check_and_move_on(self, check_eval_result=False): + def check_and_move_on(self, + check_eval_result=False, + min_received_num=None): """ To check the message_buffer, when enough messages are receiving, trigger some events (such as perform aggregation, evaluation, and move to the next training round) """ + if min_received_num is None: + min_received_num = self._cfg.federate.sample_client_num + assert min_received_num <= self.sample_client_num if check_eval_result: - # all clients are participating in evaluation - minimal_number = self.client_num - else: - # sampled clients are participating in training - minimal_number = self.sample_client_num + min_received_num = len(list(self.comm_manager.neighbors.keys())) - if self.check_buffer(self.state, minimal_number, check_eval_result): + move_on_flag = True # To record whether moving to a new training round or finishing the evaluation + if self.check_buffer(self.state, min_received_num, check_eval_result): if not check_eval_result: # in the training process mab_feedbacks = list() @@ -258,13 +261,10 @@ def check_and_move_on(self, check_eval_result=False): msg_list.append((train_data_size, model_para_multiple[model_idx])) + # collect feedbacks for updating the policy if model_idx == 0: - # temporarily, we consider training loss - # TODO: use validation loss and sample size mab_feedbacks.append( - (train_msg_buffer[client_id][0], - train_msg_buffer[client_id][2], - train_msg_buffer[client_id][3])) + train_msg_buffer[client_id][2]) # Trigger the monitor here (for training) if 'dissim' in self._cfg.eval.monitoring: @@ -318,6 +318,10 @@ def check_and_move_on(self, check_eval_result=False): self.history_results = merge_dict(self.history_results, formatted_eval_res) self.check_and_save() + else: + move_on_flag = False + + return move_on_flag def check_and_save(self): """ diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index 2cfeb1466..5cea75401 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -43,7 +43,7 @@ def get_trainer(model=None, only_for_eval=only_for_eval, monitor=monitor) elif config.backend == 'tensorflow': - from federatedscope.core.trainers.tf_trainer import GeneralTFTrainer + from federatedscope.core.trainers import GeneralTFTrainer trainer = GeneralTFTrainer(model=model, data=data, device=device, @@ -72,7 +72,8 @@ def get_trainer(model=None, ]: dict_path = "federatedscope.gfl.trainer.nodetrainer" elif config.trainer.type.lower() in [ - 'flitplustrainer', 'flittrainer', 'fedvattrainer', 'fedfocaltrainer' + 'flitplustrainer', 'flittrainer', 'fedvattrainer', + 'fedfocaltrainer' ]: dict_path = "federatedscope.gfl.flitplus.trainer" elif config.trainer.type.lower() in ['mftrainer']: @@ -106,23 +107,23 @@ def get_trainer(model=None, # differential privacy plug-in if config.nbafl.use: - from federatedscope.core.trainers.trainer_nbafl import wrap_nbafl_trainer + from federatedscope.core.trainers import wrap_nbafl_trainer trainer = wrap_nbafl_trainer(trainer) if config.sgdmf.use: - from federatedscope.mf.trainer.trainer_sgdmf import wrap_MFTrainer + from federatedscope.mf.trainer import wrap_MFTrainer trainer = wrap_MFTrainer(trainer) # personalization plug-in if config.federate.method.lower() == "pfedme": - from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer + from federatedscope.core.trainers import wrap_pFedMeTrainer # wrap style: instance a (class A) -> instance a (class A) trainer = wrap_pFedMeTrainer(trainer) elif config.federate.method.lower() == "ditto": - from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer + from federatedscope.core.trainers import wrap_DittoTrainer # wrap style: instance a (class A) -> instance a (class A) trainer = wrap_DittoTrainer(trainer) elif config.federate.method.lower() == "fedem": - from federatedscope.core.trainers.trainer_FedEM import FedEMTrainer + from federatedscope.core.trainers import FedEMTrainer # copy construct style: instance a (class A) -> instance b (class B) trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer, base_trainer=trainer) @@ -136,7 +137,7 @@ def get_trainer(model=None, # fed algorithm plug-in if config.fedprox.use: - from federatedscope.core.trainers.trainer_fedprox import wrap_fedprox_trainer + from federatedscope.core.trainers import wrap_fedprox_trainer trainer = wrap_fedprox_trainer(trainer) return trainer diff --git a/federatedscope/core/configs/cfg_fl_setting.py b/federatedscope/core/configs/cfg_fl_setting.py index 515686bba..a680abe86 100644 --- a/federatedscope/core/configs/cfg_fl_setting.py +++ b/federatedscope/core/configs/cfg_fl_setting.py @@ -23,6 +23,7 @@ def extend_fl_setting_cfg(cfg): cfg.federate.data_weighted_aggr = False # If True, the weight of aggr is the number of training samples in dataset. cfg.federate.online_aggr = False cfg.federate.make_global_eval = False + cfg.federate.use_diff = False # the method name is used to internally determine composition of different aggregators, messages, handlers, etc., cfg.federate.method = "FedAvg" diff --git a/federatedscope/core/configs/cfg_hpo.py b/federatedscope/core/configs/cfg_hpo.py index b17ea3959..a3e221113 100644 --- a/federatedscope/core/configs/cfg_hpo.py +++ b/federatedscope/core/configs/cfg_hpo.py @@ -62,6 +62,8 @@ def assert_hpo_cfg(cfg): ['adaptive', 'aggressive', 'auto', 'constant', 'scale']) assert cfg.hpo.fedex.gamma >= .0 and cfg.hpo.fedex.gamma <= 1.0, "{} must be in [0, 1]".format( cfg.hpo.fedex.gamma) + assert cfg.hpo.fedex.diff == cfg.federate.use_diff, "Inconsistent values for cfg.hpo.fedex.diff={} and cfg.federate.use_diff={}".format( + cfg.hpo.fedex.diff, cfg.federate.use_diff) register_config("hpo", extend_hpo_cfg) diff --git a/federatedscope/core/configs/config.py b/federatedscope/core/configs/config.py index 355c6fdf0..1edcca414 100644 --- a/federatedscope/core/configs/config.py +++ b/federatedscope/core/configs/config.py @@ -84,7 +84,7 @@ def clean_unused_sub_cfgs(self): else: del v[k] - def freeze(self): + def freeze(self, inform=True): """ 1) make the cfg attributes immutable; 2) save the frozen cfg_check_funcs into "self.outdir/config.yaml" for better reproducibility; @@ -114,7 +114,8 @@ def freeze(self): cfg_yaml = yaml.safe_load(tmp_cfg.dump()) wandb.config.update(cfg_yaml, allow_val_change=True) - logger.info("the used configs are: \n" + str(tmp_cfg)) + if inform: + logger.info("the used configs are: \n" + str(tmp_cfg)) super(CN, self).freeze() diff --git a/federatedscope/core/trainers/__init__.py b/federatedscope/core/trainers/__init__.py index 803f41db6..979734549 100644 --- a/federatedscope/core/trainers/__init__.py +++ b/federatedscope/core/trainers/__init__.py @@ -1,4 +1,5 @@ -from federatedscope.core.trainers.trainer import Trainer, GeneralTorchTrainer +from federatedscope.core.trainers.trainer import Trainer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from federatedscope.core.trainers.trainer_multi_model import GeneralMultiModelTrainer from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer @@ -11,4 +12,4 @@ 'Trainer', 'Context', 'GeneralTorchTrainer', 'GeneralMultiModelTrainer', 'wrap_pFedMeTrainer', 'wrap_DittoTrainer', 'FedEMTrainer', 'wrap_fedprox_trainer', 'wrap_nbafl_trainer', 'wrap_nbafl_server' -] \ No newline at end of file +] diff --git a/federatedscope/core/trainers/torch_trainer.py b/federatedscope/core/trainers/torch_trainer.py new file mode 100644 index 000000000..a258d11fd --- /dev/null +++ b/federatedscope/core/trainers/torch_trainer.py @@ -0,0 +1,354 @@ +import os +import logging + +import numpy as np +try: + import torch + from torch.utils.data import DataLoader, Dataset +except ImportError: + torch = None + DataLoader = None + Dataset = None + +from federatedscope.core.trainers.trainer import Trainer +from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset +from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader +from federatedscope.core.auxiliaries.ReIterator import ReIterator +from federatedscope.core.monitors.monitor import Monitor + +logger = logging.getLogger(__name__) + + +class GeneralTorchTrainer(Trainer): + def get_model_para(self): + return self._param_filter( + self.ctx.model.state_dict() if self.cfg.federate. + share_local_model else self.ctx.model.cpu().state_dict()) + + def parse_data(self, data): + """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes + + """ + # TODO: more robust for different data + init_dict = dict() + if isinstance(data, dict): + for mode in ["train", "val", "test"]: + init_dict["{}_data".format(mode)] = None + init_dict["{}_loader".format(mode)] = None + init_dict["num_{}_data".format(mode)] = 0 + if data.get(mode, None) is not None: + if isinstance(data.get(mode), Dataset): + init_dict["{}_data".format(mode)] = data.get(mode) + init_dict["num_{}_data".format(mode)] = len( + data.get(mode)) + elif isinstance(data.get(mode), DataLoader): + init_dict["{}_loader".format(mode)] = data.get(mode) + init_dict["num_{}_data".format(mode)] = len( + data.get(mode).dataset) + elif isinstance(data.get(mode), dict): + init_dict["{}_data".format(mode)] = data.get(mode) + init_dict["num_{}_data".format(mode)] = len( + data.get(mode)['y']) + else: + raise TypeError("Type {} is not supported.".format( + type(data.get(mode)))) + else: + raise TypeError("Type of data should be dict.") + return init_dict + + def train(self, target_data_split_name="train", hooks_set=None): + hooks_set = hooks_set or self.hooks_in_train + if self.ctx.get( + f"{target_data_split_name}_data") is None and self.ctx.get( + f"{target_data_split_name}_loader") is None: + raise ValueError( + f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer" + ) + if self.cfg.federate.use_diff: + # TODO: any issue for subclasses? + before_metric = self.evaluate(target_data_split_name='val') + + self._run_routine("train", hooks_set, target_data_split_name) + result_metric = self.ctx.eval_metrics + + if self.cfg.federate.use_diff: + # TODO: any issue for subclasses? + after_metric = self.evaluate(target_data_split_name='val') + result_metric['val_total'] = before_metric['val_total'] + result_metric['val_avg_loss_before'] = before_metric[ + 'val_avg_loss'] + result_metric['val_avg_loss_after'] = after_metric['val_avg_loss'] + + return self.ctx.num_samples_train, self.get_model_para(), result_metric + + def update(self, model_parameters): + ''' + Called by the FL client to update the model parameters + Arguments: + model_parameters (dict): PyTorch Module object's state_dict. + ''' + for key in model_parameters: + if isinstance(model_parameters[key], list): + model_parameters[key] = torch.FloatTensor( + model_parameters[key]) + self.ctx.model.load_state_dict(self._param_filter(model_parameters), + strict=False) + + def evaluate(self, target_data_split_name="test"): + with torch.no_grad(): + super(GeneralTorchTrainer, self).evaluate(target_data_split_name) + + return self.ctx.eval_metrics + + #def validate(self, target_data_split_name="val"): + # with torch.no_grad(): + # super(GeneralTorchTrainer, self).evaluate(target_data_split_name) + + # return self.ctx.eval_metrics + + def finetune(self, target_data_split_name="train", hooks_set=None): + + # freeze the parameters during the fine-tune stage + require_grad_changed_paras = set() + if self.cfg.trainer.finetune.freeze_param != "": + preserved_paras = self._param_filter( + self.ctx.model.state_dict(), + self.cfg.trainer.finetune.freeze_param) + for name, param in self.ctx.model.named_parameters(): + if name not in preserved_paras and param.requires_grad is True: + param.requires_grad = False + require_grad_changed_paras.add(name) + + # change the optimization configs + original_lrs = [] + for g in self.ctx.optimizer.param_groups: + original_lrs.append(g['lr']) + g['lr'] = self.cfg.trainer.finetune.lr + original_epoch_num = self.ctx["num_train_epoch"] + original_batch_num = self.ctx["num_train_batch"] + self.ctx["num_train_epoch"] = 1 + self.ctx["num_train_batch"] = self.cfg.trainer.finetune.steps + + # do the fine-tuning process + self.train(target_data_split_name, hooks_set) + + # restore the state before fine-tuning + if len(require_grad_changed_paras) > 0: + for name, param in self.ctx.model.named_parameters(): + if name in require_grad_changed_paras: + param.requires_grad = True + + for i, g in enumerate(self.ctx.optimizer.param_groups): + g['lr'] = original_lrs[i] + + self.ctx["num_train_epoch"] = original_epoch_num + self.ctx["num_train_batch"] = original_batch_num + + def register_default_hooks_train(self): + self.register_hook_in_train(self._hook_on_fit_start_init, + "on_fit_start") + self.register_hook_in_train( + self._hook_on_fit_start_calculate_model_size, "on_fit_start") + self.register_hook_in_train(self._hook_on_epoch_start, + "on_epoch_start") + self.register_hook_in_train(self._hook_on_batch_start_init, + "on_batch_start") + self.register_hook_in_train(self._hook_on_batch_forward, + "on_batch_forward") + self.register_hook_in_train(self._hook_on_batch_forward_regularizer, + "on_batch_forward") + self.register_hook_in_train(self._hook_on_batch_forward_flop_count, + "on_batch_forward") + self.register_hook_in_train(self._hook_on_batch_backward, + "on_batch_backward") + self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end") + self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end") + + def register_default_hooks_eval(self): + # test/val + self.register_hook_in_eval(self._hook_on_fit_start_init, + "on_fit_start") + self.register_hook_in_eval(self._hook_on_epoch_start, "on_epoch_start") + self.register_hook_in_eval(self._hook_on_batch_start_init, + "on_batch_start") + self.register_hook_in_eval(self._hook_on_batch_forward, + "on_batch_forward") + self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end") + self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end") + + def _hook_on_fit_start_init(self, ctx): + # prepare model + ctx.model.to(ctx.device) + + # prepare statistics + setattr(ctx, "loss_batch_total_{}".format(ctx.cur_data_split), 0) + setattr(ctx, "loss_regular_total_{}".format(ctx.cur_data_split), 0) + setattr(ctx, "num_samples_{}".format(ctx.cur_data_split), 0) + setattr(ctx, "{}_y_true".format(ctx.cur_data_split), []) + setattr(ctx, "{}_y_prob".format(ctx.cur_data_split), []) + + def _hook_on_fit_start_calculate_model_size(self, ctx): + if not isinstance(self.ctx.monitor, Monitor): + logger.warning( + f"The trainer {type(self)} does contain a valid monitor, this may be caused by " + f"initializing trainer subclasses without passing a valid monitor instance." + f"Plz check whether this is you want.") + return + if self.ctx.monitor.total_model_size == 0: + self.ctx.monitor.track_model_size(ctx.models) + + def _hook_on_epoch_start(self, ctx): + # prepare dataloader + if ctx.get("{}_loader".format(ctx.cur_data_split)) is None: + loader = get_dataloader( + WrapDataset(ctx.get("{}_data".format(ctx.cur_data_split))), + self.cfg) + setattr(ctx, "{}_loader".format(ctx.cur_data_split), + ReIterator(loader)) + elif not isinstance(ctx.get("{}_loader".format(ctx.cur_data_split)), + ReIterator): + setattr( + ctx, "{}_loader".format(ctx.cur_data_split), + ReIterator(ctx.get("{}_loader".format(ctx.cur_data_split)))) + else: + ctx.get("{}_loader".format(ctx.cur_data_split)).reset() + + def _hook_on_batch_start_init(self, ctx): + # prepare data batch + try: + ctx.data_batch = next( + ctx.get("{}_loader".format(ctx.cur_data_split))) + except StopIteration: + raise StopIteration + + def _hook_on_batch_forward(self, ctx): + x, label = [_.to(ctx.device) for _ in ctx.data_batch] + pred = ctx.model(x) + if len(label.size()) == 0: + label = label.unsqueeze(0) + ctx.loss_batch = ctx.criterion(pred, label) + ctx.y_true = label + ctx.y_prob = pred + + ctx.batch_size = len(label) + + def _hook_on_batch_forward_flop_count(self, ctx): + """ + the monitoring hook to calculate the flops during the fl course + + Note: for customized cases that the forward process is not only based on ctx.model, + please override this function (inheritance case) or replace this hook (plug-in case) + + :param ctx: + :return: + """ + if not isinstance(self.ctx.monitor, Monitor): + logger.warning( + f"The trainer {type(self)} does contain a valid monitor, this may be caused by " + f"initializing trainer subclasses without passing a valid monitor instance." + f"Plz check whether this is you want.") + return + + if self.ctx.monitor.flops_per_sample == 0: + # calculate the flops_per_sample + try: + x, y = [_.to(ctx.device) for _ in ctx.data_batch] + from fvcore.nn import FlopCountAnalysis + flops_one_batch = FlopCountAnalysis(ctx.model, x).total() + if self.model_nums > 1 and ctx.mirrored_models: + flops_one_batch *= self.model_nums + logger.warning( + "the flops_per_batch is multiplied by internal model nums as self.mirrored_models=True." + "if this is not the case you want, please customize the count hook" + ) + self.ctx.monitor.track_avg_flops(flops_one_batch, + ctx.batch_size) + except: + logger.error( + "current flop count implementation is for general trainer case: " + "1) ctx.data_batch = [x, y]; and" + "2) the ctx.model takes only x as input." + "Please check the forward format or implement your own flop_count function" + ) + + # by default, we assume the data has the same input shape, + # thus simply multiply the flops to avoid redundant forward + self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * ctx.batch_size + + def _hook_on_batch_forward_regularizer(self, ctx): + ctx.loss_regular = float( + self.cfg.regularizer.mu) * ctx.regularizer(ctx) + ctx.loss_task = ctx.loss_batch + ctx.loss_regular + + def _hook_on_batch_backward(self, ctx): + ctx.optimizer.zero_grad() + ctx.loss_task.backward() + if ctx.grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), + ctx.grad_clip) + ctx.optimizer.step() + + def _hook_on_batch_end(self, ctx): + # update statistics + setattr( + ctx, "loss_batch_total_{}".format(ctx.cur_data_split), + ctx.get("loss_batch_total_{}".format(ctx.cur_data_split)) + + ctx.loss_batch.item() * ctx.batch_size) + + if ctx.get("loss_regular", None) is None or ctx.loss_regular == 0: + loss_regular = 0. + else: + loss_regular = ctx.loss_regular.item() + setattr( + ctx, "loss_regular_total_{}".format(ctx.cur_data_split), + ctx.get("loss_regular_total_{}".format(ctx.cur_data_split)) + + loss_regular) + setattr( + ctx, "num_samples_{}".format(ctx.cur_data_split), + ctx.get("num_samples_{}".format(ctx.cur_data_split)) + + ctx.batch_size) + + # cache label for evaluate + ctx.get("{}_y_true".format(ctx.cur_data_split)).append( + ctx.y_true.detach().cpu().numpy()) + + ctx.get("{}_y_prob".format(ctx.cur_data_split)).append( + ctx.y_prob.detach().cpu().numpy()) + + # clean temp ctx + ctx.data_batch = None + ctx.batch_size = None + ctx.loss_task = None + ctx.loss_batch = None + ctx.loss_regular = None + ctx.y_true = None + ctx.y_prob = None + + def _hook_on_fit_end(self, ctx): + """Evaluate metrics. + + """ + setattr( + ctx, "{}_y_true".format(ctx.cur_data_split), + np.concatenate(ctx.get("{}_y_true".format(ctx.cur_data_split)))) + setattr( + ctx, "{}_y_prob".format(ctx.cur_data_split), + np.concatenate(ctx.get("{}_y_prob".format(ctx.cur_data_split)))) + results = self.metric_calculator.eval(ctx) + setattr(ctx, 'eval_metrics', results) + + def save_model(self, path, cur_round=-1): + assert self.ctx.model is not None + + ckpt = {'cur_round': cur_round, 'model': self.ctx.model.state_dict()} + torch.save(ckpt, path) + + def load_model(self, path): + assert self.ctx.model is not None + + if os.path.exists(path): + ckpt = torch.load(path, map_location=self.ctx.device) + self.ctx.model.load_state_dict(ckpt['model']) + return ckpt['cur_round'] + else: + raise ValueError("The file {} does NOT exist".format(path)) diff --git a/federatedscope/core/trainers/trainer.py b/federatedscope/core/trainers/trainer.py index 5614ab30a..6f1eac1aa 100644 --- a/federatedscope/core/trainers/trainer.py +++ b/federatedscope/core/trainers/trainer.py @@ -4,11 +4,7 @@ import os import numpy as np -from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader -from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset -from federatedscope.core.auxiliaries.ReIterator import ReIterator from federatedscope.core.auxiliaries import utils -from federatedscope.core.monitors.monitor import Monitor from federatedscope.core.trainers.context import Context from federatedscope.core.monitors.metric_calculator import MetricCalculator @@ -192,7 +188,7 @@ def train(self, target_data_split_name="train", hooks_set=None): pass def evaluate(self, target_data_split_name="test", hooks_set=None): - hooks_set = self.hooks_in_eval if hooks_set is None else hooks_set + hooks_set = hooks_set or self.hooks_in_eval if self.ctx.get( f"{target_data_split_name}_data") is None and self.ctx.get( f"{target_data_split_name}_loader") is None: @@ -210,7 +206,9 @@ def _run_routine(self, mode, hooks_set, dataset_name=None): """Run the hooks_set and maintain the mode Arguments: - mode: running mode of client, chosen from train/val/test + mode (str): running mode of client, chosen from train/test + hooks_set (dict): functions to be executed. + dataset_name (str): which split. Note: Considering evaluation could be in ```hooks_set["on_epoch_end"]```, there could be two data loaders in @@ -342,328 +340,3 @@ def load_model(self, path): raise NotImplementedError( "The function `load_model` should be implemented according to the ML backend (Pytorch, Tensorflow ...)." ) - - -class GeneralTorchTrainer(Trainer): - def get_model_para(self): - return self._param_filter( - self.ctx.model.state_dict() if self.cfg.federate. - share_local_model else self.ctx.model.cpu().state_dict()) - - def parse_data(self, data): - """Populate "{}_data", "{}_loader" and "num_{}_data" for different modes - - """ - # TODO: more robust for different data - init_dict = dict() - if isinstance(data, dict): - for mode in ["train", "val", "test"]: - init_dict["{}_data".format(mode)] = None - init_dict["{}_loader".format(mode)] = None - init_dict["num_{}_data".format(mode)] = 0 - if data.get(mode, None) is not None: - if isinstance(data.get(mode), Dataset): - init_dict["{}_data".format(mode)] = data.get(mode) - init_dict["num_{}_data".format(mode)] = len( - data.get(mode)) - elif isinstance(data.get(mode), DataLoader): - init_dict["{}_loader".format(mode)] = data.get(mode) - init_dict["num_{}_data".format(mode)] = len( - data.get(mode).dataset) - elif isinstance(data.get(mode), dict): - init_dict["{}_data".format(mode)] = data.get(mode) - init_dict["num_{}_data".format(mode)] = len( - data.get(mode)['y']) - else: - raise TypeError("Type {} is not supported.".format( - type(data.get(mode)))) - else: - raise TypeError("Type of data should be dict.") - return init_dict - - def train(self, target_data_split_name="train", hooks_set=None): - hooks_set = self.hooks_in_train if hooks_set is None else hooks_set - if self.ctx.get( - f"{target_data_split_name}_data") is None and self.ctx.get( - f"{target_data_split_name}_loader") is None: - raise ValueError( - f"No {target_data_split_name}_data or {target_data_split_name}_loader in the trainer" - ) - self._run_routine("train", hooks_set, target_data_split_name) - - # TODO: The return values should be more flexible? Now: sample_num, model_para, results={k:v} - - return self.ctx.num_samples_train, self.get_model_para( - ), self.ctx.eval_metrics - - def update(self, model_parameters): - ''' - Called by the FL client to update the model parameters - Arguments: - model_parameters (dict): PyTorch Module object's state_dict. - ''' - for key in model_parameters: - if isinstance(model_parameters[key], list): - model_parameters[key] = torch.FloatTensor( - model_parameters[key]) - self.ctx.model.load_state_dict(self._param_filter(model_parameters), - strict=False) - - def evaluate(self, target_data_split_name="test"): - with torch.no_grad(): - super().evaluate(target_data_split_name) - - return self.ctx.eval_metrics - - def validate(self, target_data_split_name="val"): - with torch.no_grad(): - super().evaluate(target_data_split_name) - - return self.ctx.eval_metrics - - def finetune(self, target_data_split_name="train", hooks_set=None): - - # freeze the parameters during the fine-tune stage - require_grad_changed_paras = set() - if self.cfg.trainer.finetune.freeze_param != "": - preserved_paras = self._param_filter( - self.ctx.model.state_dict(), - self.cfg.trainer.finetune.freeze_param) - for name, param in self.ctx.model.named_parameters(): - if name not in preserved_paras and param.requires_grad is True: - param.requires_grad = False - require_grad_changed_paras.add(name) - - # change the optimization configs - original_lrs = [] - for g in self.ctx.optimizer.param_groups: - original_lrs.append(g['lr']) - g['lr'] = self.cfg.trainer.finetune.lr - original_epoch_num = self.ctx["num_train_epoch"] - original_batch_num = self.ctx["num_train_batch"] - self.ctx["num_train_epoch"] = 1 - self.ctx["num_train_batch"] = self.cfg.trainer.finetune.steps - - # do the fine-tuning process - self.train(target_data_split_name, hooks_set) - - # restore the state before fine-tuning - if len(require_grad_changed_paras) > 0: - for name, param in self.ctx.model.named_parameters(): - if name in require_grad_changed_paras: - param.requires_grad = True - - for i, g in enumerate(self.ctx.optimizer.param_groups): - g['lr'] = original_lrs[i] - - self.ctx["num_train_epoch"] = original_epoch_num - self.ctx["num_train_batch"] = original_batch_num - - def register_default_hooks_train(self): - self.register_hook_in_train(self._hook_on_fit_start_init, - "on_fit_start") - self.register_hook_in_train( - self._hook_on_fit_start_calculate_model_size, "on_fit_start") - self.register_hook_in_train(self._hook_on_epoch_start, - "on_epoch_start") - self.register_hook_in_train(self._hook_on_batch_start_init, - "on_batch_start") - self.register_hook_in_train(self._hook_on_batch_forward, - "on_batch_forward") - self.register_hook_in_train(self._hook_on_batch_forward_regularizer, - "on_batch_forward") - self.register_hook_in_train(self._hook_on_batch_forward_flop_count, - "on_batch_forward") - self.register_hook_in_train(self._hook_on_batch_backward, - "on_batch_backward") - self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end") - self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end") - - def register_default_hooks_eval(self): - # test/val - self.register_hook_in_eval(self._hook_on_fit_start_init, - "on_fit_start") - self.register_hook_in_eval(self._hook_on_epoch_start, "on_epoch_start") - self.register_hook_in_eval(self._hook_on_batch_start_init, - "on_batch_start") - self.register_hook_in_eval(self._hook_on_batch_forward, - "on_batch_forward") - self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end") - self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end") - - def _hook_on_fit_start_init(self, ctx): - # prepare model - ctx.model.to(ctx.device) - - # prepare statistics - setattr(ctx, "loss_batch_total_{}".format(ctx.cur_data_split), 0) - setattr(ctx, "loss_regular_total_{}".format(ctx.cur_data_split), 0) - setattr(ctx, "num_samples_{}".format(ctx.cur_data_split), 0) - setattr(ctx, "{}_y_true".format(ctx.cur_data_split), []) - setattr(ctx, "{}_y_prob".format(ctx.cur_data_split), []) - - def _hook_on_fit_start_calculate_model_size(self, ctx): - if not isinstance(self.ctx.monitor, Monitor): - logger.warning( - f"The trainer {type(self)} does contain a valid monitor, this may be caused by " - f"initializing trainer subclasses without passing a valid monitor instance." - f"Plz check whether this is you want.") - return - if self.ctx.monitor.total_model_size == 0: - self.ctx.monitor.track_model_size(ctx.models) - - def _hook_on_epoch_start(self, ctx): - # prepare dataloader - if ctx.get("{}_loader".format(ctx.cur_data_split)) is None: - loader = get_dataloader( - WrapDataset(ctx.get("{}_data".format(ctx.cur_data_split))), - self.cfg) - setattr(ctx, "{}_loader".format(ctx.cur_data_split), - ReIterator(loader)) - elif not isinstance(ctx.get("{}_loader".format(ctx.cur_data_split)), - ReIterator): - setattr( - ctx, "{}_loader".format(ctx.cur_data_split), - ReIterator(ctx.get("{}_loader".format(ctx.cur_data_split)))) - else: - ctx.get("{}_loader".format(ctx.cur_data_split)).reset() - - def _hook_on_batch_start_init(self, ctx): - # prepare data batch - try: - ctx.data_batch = next( - ctx.get("{}_loader".format(ctx.cur_data_split))) - except StopIteration: - raise StopIteration - - def _hook_on_batch_forward(self, ctx): - x, label = [_.to(ctx.device) for _ in ctx.data_batch] - pred = ctx.model(x) - if len(label.size()) == 0: - label = label.unsqueeze(0) - ctx.loss_batch = ctx.criterion(pred, label) - ctx.y_true = label - ctx.y_prob = pred - - ctx.batch_size = len(label) - - def _hook_on_batch_forward_flop_count(self, ctx): - """ - the monitoring hook to calculate the flops during the fl course - - Note: for customized cases that the forward process is not only based on ctx.model, - please override this function (inheritance case) or replace this hook (plug-in case) - - :param ctx: - :return: - """ - if not isinstance(self.ctx.monitor, Monitor): - logger.warning( - f"The trainer {type(self)} does contain a valid monitor, this may be caused by " - f"initializing trainer subclasses without passing a valid monitor instance." - f"Plz check whether this is you want.") - return - - if self.ctx.monitor.flops_per_sample == 0: - # calculate the flops_per_sample - try: - x, y = [_.to(ctx.device) for _ in ctx.data_batch] - from fvcore.nn import FlopCountAnalysis - flops_one_batch = FlopCountAnalysis(ctx.model, x).total() - if self.model_nums > 1 and ctx.mirrored_models: - flops_one_batch *= self.model_nums - logger.warning( - "the flops_per_batch is multiplied by internal model nums as self.mirrored_models=True." - "if this is not the case you want, please customize the count hook" - ) - self.ctx.monitor.track_avg_flops(flops_one_batch, - ctx.batch_size) - except: - logger.error( - "current flop count implementation is for general trainer case: " - "1) ctx.data_batch = [x, y]; and" - "2) the ctx.model takes only x as input." - "Please check the forward format or implement your own flop_count function" - ) - - # by default, we assume the data has the same input shape, - # thus simply multiply the flops to avoid redundant forward - self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * ctx.batch_size - - def _hook_on_batch_forward_regularizer(self, ctx): - ctx.loss_regular = float( - self.cfg.regularizer.mu) * ctx.regularizer(ctx) - ctx.loss_task = ctx.loss_batch + ctx.loss_regular - - def _hook_on_batch_backward(self, ctx): - ctx.optimizer.zero_grad() - ctx.loss_task.backward() - if ctx.grad_clip > 0: - torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), - ctx.grad_clip) - ctx.optimizer.step() - - def _hook_on_batch_end(self, ctx): - # update statistics - setattr( - ctx, "loss_batch_total_{}".format(ctx.cur_data_split), - ctx.get("loss_batch_total_{}".format(ctx.cur_data_split)) + - ctx.loss_batch.item() * ctx.batch_size) - - if ctx.get("loss_regular", None) is None or ctx.loss_regular == 0: - loss_regular = 0. - else: - loss_regular = ctx.loss_regular.item() - setattr( - ctx, "loss_regular_total_{}".format(ctx.cur_data_split), - ctx.get("loss_regular_total_{}".format(ctx.cur_data_split)) + - loss_regular) - setattr( - ctx, "num_samples_{}".format(ctx.cur_data_split), - ctx.get("num_samples_{}".format(ctx.cur_data_split)) + - ctx.batch_size) - - # cache label for evaluate - ctx.get("{}_y_true".format(ctx.cur_data_split)).append( - ctx.y_true.detach().cpu().numpy()) - - ctx.get("{}_y_prob".format(ctx.cur_data_split)).append( - ctx.y_prob.detach().cpu().numpy()) - - # clean temp ctx - ctx.data_batch = None - ctx.batch_size = None - ctx.loss_task = None - ctx.loss_batch = None - ctx.loss_regular = None - ctx.y_true = None - ctx.y_prob = None - - def _hook_on_fit_end(self, ctx): - """Evaluate metrics. - - """ - setattr( - ctx, "{}_y_true".format(ctx.cur_data_split), - np.concatenate(ctx.get("{}_y_true".format(ctx.cur_data_split)))) - setattr( - ctx, "{}_y_prob".format(ctx.cur_data_split), - np.concatenate(ctx.get("{}_y_prob".format(ctx.cur_data_split)))) - results = self.metric_calculator.eval(ctx) - setattr(ctx, 'eval_metrics', results) - - def save_model(self, path, cur_round=-1): - assert self.ctx.model is not None - - ckpt = {'cur_round': cur_round, 'model': self.ctx.model.state_dict()} - torch.save(ckpt, path) - - def load_model(self, path): - assert self.ctx.model is not None - - if os.path.exists(path): - ckpt = torch.load(path, map_location=self.ctx.device) - self.ctx.model.load_state_dict(ckpt['model']) - return ckpt['cur_round'] - else: - raise ValueError("The file {} does NOT exist".format(path)) diff --git a/federatedscope/core/trainers/trainer_Ditto.py b/federatedscope/core/trainers/trainer_Ditto.py index a3afaa1c7..d7a54a3c0 100644 --- a/federatedscope/core/trainers/trainer_Ditto.py +++ b/federatedscope/core/trainers/trainer_Ditto.py @@ -3,7 +3,7 @@ import torch from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from federatedscope.core.optimizer import wrap_regularized_optimizer from typing import Type diff --git a/federatedscope/core/trainers/trainer_FedEM.py b/federatedscope/core/trainers/trainer_FedEM.py index 9df498c05..c5d089f20 100644 --- a/federatedscope/core/trainers/trainer_FedEM.py +++ b/federatedscope/core/trainers/trainer_FedEM.py @@ -4,7 +4,7 @@ import torch from torch.nn.functional import softmax as f_softmax -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from federatedscope.core.trainers.trainer_multi_model import GeneralMultiModelTrainer diff --git a/federatedscope/core/trainers/trainer_fedprox.py b/federatedscope/core/trainers/trainer_fedprox.py index b0d959746..dd91a42b0 100644 --- a/federatedscope/core/trainers/trainer_fedprox.py +++ b/federatedscope/core/trainers/trainer_fedprox.py @@ -1,4 +1,4 @@ -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from typing import Type from copy import deepcopy diff --git a/federatedscope/core/trainers/trainer_multi_model.py b/federatedscope/core/trainers/trainer_multi_model.py index ea0c4946d..a416a43b7 100644 --- a/federatedscope/core/trainers/trainer_multi_model.py +++ b/federatedscope/core/trainers/trainer_multi_model.py @@ -3,7 +3,7 @@ from typing import Type from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer class GeneralMultiModelTrainer(GeneralTorchTrainer): diff --git a/federatedscope/core/trainers/trainer_nbafl.py b/federatedscope/core/trainers/trainer_nbafl.py index ae067ac0b..f33810ef7 100644 --- a/federatedscope/core/trainers/trainer_nbafl.py +++ b/federatedscope/core/trainers/trainer_nbafl.py @@ -1,6 +1,6 @@ from federatedscope.core.auxiliaries.utils import get_random -from federatedscope.core.trainers.trainer import GeneralTorchTrainer -from federatedscope.core.worker.server import Server +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer +#from federatedscope.core.worker.server import Server from typing import Type from copy import deepcopy @@ -123,7 +123,8 @@ def inject_noise_in_broadcast(cfg, sample_client_num, model): }, p.device) -def wrap_nbafl_server(server: Type[Server]) -> Type[Server]: +#def wrap_nbafl_server(server: Type[Server]) -> Type[Server]: +def wrap_nbafl_server(server): """Register noise injector for the server """ diff --git a/federatedscope/core/trainers/trainer_pFedMe.py b/federatedscope/core/trainers/trainer_pFedMe.py index acf53de6a..462e7e81f 100644 --- a/federatedscope/core/trainers/trainer_pFedMe.py +++ b/federatedscope/core/trainers/trainer_pFedMe.py @@ -1,6 +1,6 @@ import copy -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer from federatedscope.core.optimizer import wrap_regularized_optimizer from typing import Type diff --git a/federatedscope/cv/trainer/trainer.py b/federatedscope/cv/trainer/trainer.py index 21cdf29eb..b00ebd66a 100644 --- a/federatedscope/cv/trainer/trainer.py +++ b/federatedscope/cv/trainer/trainer.py @@ -1,5 +1,5 @@ from federatedscope.register import register_trainer -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers import GeneralTorchTrainer class CVTrainer(GeneralTorchTrainer): diff --git a/federatedscope/example_configs/fedex_for_lr.yaml b/federatedscope/example_configs/fedex_for_lr.yaml index 3a2d6b7c4..218e4e5fe 100644 --- a/federatedscope/example_configs/fedex_for_lr.yaml +++ b/federatedscope/example_configs/fedex_for_lr.yaml @@ -9,6 +9,7 @@ federate: share_local_model: True online_aggr: True save_to: 'fedex_test/lr.pth' + use_diff: True trainer: type: 'general' eval: @@ -21,3 +22,4 @@ hpo: fedex: use: True ss: 'federatedscope/example_configs/fedex_search_space.yaml' + diff: True diff --git a/federatedscope/example_configs/quadratic.yaml b/federatedscope/example_configs/quadratic.yaml new file mode 100644 index 000000000..21e6b645b --- /dev/null +++ b/federatedscope/example_configs/quadratic.yaml @@ -0,0 +1,20 @@ +use_gpu: False +federate: + mode: 'standalone' + total_round_num: 5 + make_global_eval: False + client_num: 5 + share_local_model: False + online_aggr: False +trainer: + type: 'general' +eval: + freq: 1 +data: + type: 'quadratic' +model: + type: 'quadratic' +criterion: + type: 'L1Loss' +optimizer: + lr: 0.01 diff --git a/federatedscope/example_configs/quadratic_clientwise.yaml b/federatedscope/example_configs/quadratic_clientwise.yaml new file mode 100644 index 000000000..12709a5ee --- /dev/null +++ b/federatedscope/example_configs/quadratic_clientwise.yaml @@ -0,0 +1,15 @@ +client_1: + optimizer: + lr: 0.625 +client_2: + optimizer: + lr: 0.125 +client_3: + optimizer: + lr: 0.125 +client_4: + optimizer: + lr: 0.125 +client_5: + optimizer: + lr: 0.025 diff --git a/federatedscope/gfl/trainer/graphtrainer.py b/federatedscope/gfl/trainer/graphtrainer.py index 8e336159b..8e8e8e538 100644 --- a/federatedscope/gfl/trainer/graphtrainer.py +++ b/federatedscope/gfl/trainer/graphtrainer.py @@ -1,8 +1,8 @@ +import logging + from federatedscope.core.monitors import Monitor from federatedscope.register import register_trainer -from federatedscope.core.trainers.trainer import GeneralTorchTrainer - -import logging +from federatedscope.core.trainers import GeneralTorchTrainer logger = logging.getLogger(__name__) diff --git a/federatedscope/gfl/trainer/linktrainer.py b/federatedscope/gfl/trainer/linktrainer.py index 337ae2e84..c7f102f01 100644 --- a/federatedscope/gfl/trainer/linktrainer.py +++ b/federatedscope/gfl/trainer/linktrainer.py @@ -7,7 +7,7 @@ from federatedscope.core.monitors import Monitor from federatedscope.register import register_trainer -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.core.auxiliaries.ReIterator import ReIterator import logging diff --git a/federatedscope/gfl/trainer/nodetrainer.py b/federatedscope/gfl/trainer/nodetrainer.py index 4883e2760..e24c21234 100644 --- a/federatedscope/gfl/trainer/nodetrainer.py +++ b/federatedscope/gfl/trainer/nodetrainer.py @@ -6,7 +6,7 @@ from federatedscope.core.monitors import Monitor from federatedscope.register import register_trainer -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.core.auxiliaries.ReIterator import ReIterator import logging diff --git a/federatedscope/mf/trainer/trainer.py b/federatedscope/mf/trainer/trainer.py index 71c7e5830..23f3e63ba 100644 --- a/federatedscope/mf/trainer/trainer.py +++ b/federatedscope/mf/trainer/trainer.py @@ -3,7 +3,7 @@ from federatedscope.core.monitors import Monitor from federatedscope.mf.dataloader.dataloader import MFDataLoader -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.register import register_trainer import logging diff --git a/federatedscope/nlp/trainer/trainer.py b/federatedscope/nlp/trainer/trainer.py index 1800d16bd..8f3fac5fc 100644 --- a/federatedscope/nlp/trainer/trainer.py +++ b/federatedscope/nlp/trainer/trainer.py @@ -1,5 +1,5 @@ from federatedscope.register import register_trainer -from federatedscope.core.trainers.trainer import GeneralTorchTrainer +from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.core.auxiliaries import utils