From 8d97ffa40b7dca4b278b70cbfd0b1ce4ffb08a22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Wed, 10 Aug 2022 18:11:02 +0800 Subject: [PATCH 1/3] support validation set for MF datasets; fix FedEM for MF datasets; --- federatedscope/core/trainers/trainer_FedEM.py | 93 +++++++++++++++++- federatedscope/mf/dataloader/dataloader.py | 8 +- federatedscope/mf/dataset/movielens.py | 98 +++++++++++++++---- federatedscope/mf/model/model.py | 6 +- federatedscope/mf/trainer/trainer.py | 38 +++++-- 5 files changed, 212 insertions(+), 31 deletions(-) diff --git a/federatedscope/core/trainers/trainer_FedEM.py b/federatedscope/core/trainers/trainer_FedEM.py index e6da6df63..30dc82eaa 100644 --- a/federatedscope/core/trainers/trainer_FedEM.py +++ b/federatedscope/core/trainers/trainer_FedEM.py @@ -10,6 +10,10 @@ from federatedscope.core.trainers.trainer_multi_model import \ GeneralMultiModelTrainer +import logging + +logger = logging.getLogger(__name__) + class FedEMTrainer(GeneralMultiModelTrainer): """ @@ -41,6 +45,10 @@ def __init__(self, self.ctx.all_losses_model_batch = torch.zeros( self.model_nums, self.ctx.num_train_batch).to(device) + if self.cfg.model.type.lower() in ["vmfnet", "hmfnet"]: + # for MF model, we need to use the ratio to adjust the loss + self.ctx.all_ratios_per_model = [[] + for _ in range(self.model_nums)] self.ctx.cur_batch_idx = -1 # `ctx[f"{cur_data}_y_prob_ensemble"] = 0` in # func `_hook_on_fit_end_ensemble_eval` @@ -84,13 +92,23 @@ def register_multiple_model_hooks(self): new_hook=self.hook_on_batch_end_gather_loss, trigger="on_batch_end", insert_pos=0 - ) # insert at the front, (we need gather the loss before clean it) + ) # insert at the front, we need gather the loss before clean it + if self.cfg.model.type.lower() in ["vmfnet", "hmfnet"]: + # for MF model, we need to use the ratio to adjust the loss + self.register_hook_in_eval( + new_hook=self.hook_on_batch_end_gather_ratio_mf, + trigger="on_batch_end", + insert_pos=0 + ) # insert at the front, we need gather the ratio before clean it self.register_hook_in_eval( new_hook=self.hook_on_batch_start_track_batch_idx, trigger="on_batch_start", insert_pos=0) # insert at the front # replace the original evaluation into the ensemble one - self.replace_hook_in_eval(new_hook=self._hook_on_fit_end_ensemble_eval, + ensemble_hook = self._hook_on_fit_end_ensemble_eval_mf if \ + self.cfg.model.type.lower() in ["vmfnet", "hmfnet"] else \ + self._hook_on_fit_end_ensemble_eval + self.replace_hook_in_eval(new_hook=ensemble_hook, target_trigger="on_fit_end", target_hook_name="_hook_on_fit_end") @@ -122,6 +140,12 @@ def hook_on_batch_end_gather_loss(self, ctx): ctx.all_losses_model_batch[ctx.cur_model_idx][ ctx.cur_batch_idx] = ctx.loss_batch.item() + def hook_on_batch_end_gather_ratio_mf(self, ctx): + # for only eval + # before clean the ratio_batch for matrix factorization model; + # we record it for further model ensemble + ctx["all_ratios_per_model"][ctx.cur_model_idx].append(ctx.ratio_batch) + def hook_on_fit_start_mixture_weights_update(self, ctx): # for only train if ctx.cur_model_idx != 0: @@ -167,3 +191,68 @@ def _hook_on_fit_end_ensemble_eval(self, ctx): LIFECYCLE.ROUTINE) ctx.ys_prob = ctx.ys_prob_ensemble ctx.eval_metrics = self.metric_calculator.eval(ctx) + + def _hook_on_fit_end_ensemble_eval_mf(self, ctx): + """ + Ensemble evaluation for matrix factorization model + """ + cur_data = ctx.cur_mode + batch_num = len(ctx[f"{cur_data}_y_prob"]) + if f"{cur_data}_y_prob_ensemble" not in ctx or ctx[ + 
f"{cur_data}_y_prob_ensemble"] is None: + # ctx[f"{cur_data}_y_prob_ensemble"] = 0 + ctx[f"{cur_data}_y_prob_ensemble"] = [0 for i in range(batch_num)] + + for batch_i in range(batch_num): + try: + ctx[f"{cur_data}_y_prob_ensemble"][batch_i] += \ + ctx[f"{cur_data}_y_prob"][batch_i] *\ + self.weights_internal_models[ctx.cur_model_idx].item() + except IndexError as e: + logger.error( + str(e) + f" When batch_i={batch_i}, " + f"cur_model_idx={ctx.cur_model_idx}") + + # do metrics calculation after the last internal model evaluation done + if ctx.cur_model_idx == self.model_nums - 1: + for batch_i in range(batch_num): + try: + ctx[f"{cur_data}_total"] = 0 + pred = ctx[f"{cur_data}_y_prob_ensemble"][batch_i].\ + todense() + label = ctx[f"{cur_data}_y_true"][batch_i].todense() + ctx[f"loss_batch_total_{cur_data}"] += ctx.criterion( + torch.Tensor(pred).to(ctx.device), + torch.Tensor(label).to(ctx.device) + ) * ctx["all_ratios_per_model"][ctx.cur_model_idx][batch_i] + except IndexError as e: + logger.error( + str(e) + f" When batch_i={batch_i}, " + f"cur_model_idx={ctx.cur_model_idx}") + + # set the eval_metrics + if ctx.num_samples == 0: + results = { + f"{cur_data}_avg_loss": ctx.get( + "loss_batch_total_{}".format(ctx.cur_mode)), + f"{cur_data}_total": 0 + } + else: + results = { + f"{ctx.cur_mode}_avg_loss": ctx.get( + f"loss_batch_total_{ctx.cur_mode}") / ctx.num_samples, + f"{ctx.cur_mode}_total": ctx.num_samples + } + if isinstance(results[f"{ctx.cur_mode}_avg_loss"], torch.Tensor): + results[f"{ctx.cur_mode}_avg_loss"] = results[ + f"{ctx.cur_mode}_avg_loss"].item() + setattr(ctx, 'eval_metrics', results) + + # reset for next run_routine that may have different + # len([f"{cur_data}_y_prob"]) + ctx[f"{cur_data}_y_prob_ensemble"] = None + self.ctx.all_ratios_per_model = [[] + for _ in range(self.model_nums)] + + ctx[f"{cur_data}_y_prob"] = [] + ctx[f"{cur_data}_y_true"] = [] diff --git a/federatedscope/mf/dataloader/dataloader.py b/federatedscope/mf/dataloader/dataloader.py index c65f18603..64ab4bcb7 100644 --- a/federatedscope/mf/dataloader/dataloader.py +++ b/federatedscope/mf/dataloader/dataloader.py @@ -40,7 +40,7 @@ def load_mf_dataset(config=None): MFDATA_CLASS_DICT[config.data.type.lower()])( root=config.data.root, num_client=config.federate.client_num, - train_portion=config.data.splits[0], + split=config.data.splits, download=True) else: raise NotImplementedError("Dataset {} is not implemented.".format( @@ -54,6 +54,12 @@ def load_mf_dataset(config=None): batch_size=config.data.batch_size, drop_last=config.data.drop_last, theta=config.sgdmf.theta) + data_local_dict[id_client]["val"] = MFDataLoader( + data["val"], + shuffle=False, + batch_size=config.data.batch_size, + drop_last=config.data.drop_last, + theta=config.sgdmf.theta) data_local_dict[id_client]["test"] = MFDataLoader( data["test"], shuffle=False, diff --git a/federatedscope/mf/dataset/movielens.py b/federatedscope/mf/dataset/movielens.py index b5dba10cc..a1ed6e368 100644 --- a/federatedscope/mf/dataset/movielens.py +++ b/federatedscope/mf/dataset/movielens.py @@ -18,16 +18,20 @@ class VMFDataset: """ def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int, - test_portion: float): + split: list): id_item = np.arange(self.n_item) shuffle(id_item) items_per_client = np.array_split(id_item, num_client) data = dict() for clientId, items in enumerate(items_per_client): client_ratings = ratings[:, items] - train_ratings, test_ratings = self._split_train_test_ratings( - client_ratings, test_portion) - 
data[clientId + 1] = {"train": train_ratings, "test": test_ratings} + train_ratings, val_ratings, test_ratings = self.\ + _split_train_val_test_ratings(client_ratings, split) + data[clientId + 1] = { + "train": train_ratings, + "val": val_ratings, + "test": test_ratings + } self.data = data @@ -36,16 +40,20 @@ class HMFDataset: """ def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int, - test_portion: float): + split: list): id_user = np.arange(self.n_user) shuffle(id_user) users_per_client = np.array_split(id_user, num_client) data = dict() - for cliendId, users in enumerate(users_per_client): + for clientId, users in enumerate(users_per_client): client_ratings = ratings[users, :] - train_ratings, test_ratings = self._split_train_test_ratings( - client_ratings, test_portion) - data[cliendId + 1] = {"train": train_ratings, "test": test_ratings} + train_ratings, val_ratings, test_ratings = \ + self._split_train_val_test_ratings(client_ratings, split) + data[clientId + 1] = { + "train": train_ratings, + "val": val_ratings, + "test": test_ratings + } self.data = data @@ -55,12 +63,16 @@ class MovieLensData(object): Arguments: root (string): the path of data num_client (int): the number of clients - train_portion (float): the portion of training data + split (float): the portion of training/val/test data as list + [train, val, test] + download (bool): indicator to download dataset """ - def __init__(self, root, num_client, train_portion=0.9, download=True): + def __init__(self, root, num_client, split=None, download=True): super(MovieLensData, self).__init__() + if split is None: + split = [0.8, 0.1, 0.1] self.root = root self.data = None @@ -75,17 +87,62 @@ def __init__(self, root, num_client, train_portion=0.9, download=True): "You can use download=True to download it") ratings = self._load_meta() - self._split_n_clients_rating(ratings, num_client, 1 - train_portion) + if issubclass(type(self), HMFDataset): + self._split_n_clients_rating_hmf(ratings, num_client, split) + else: + self._split_n_clients_rating_vmf(ratings, num_client, split) + + def _split_n_clients_rating_hmf(self, ratings: csc_matrix, num_client: int, + split: list): + id_user = np.arange(self.n_user) + shuffle(id_user) + users_per_client = np.array_split(id_user, num_client) + data = dict() + for clientId, users in enumerate(users_per_client): + client_ratings = ratings[users, :] + train_ratings, val_ratings, test_ratings = \ + self._split_train_val_test_ratings(client_ratings, split) + data[clientId + 1] = { + "train": train_ratings, + "val": val_ratings, + "test": test_ratings + } + self.data = data + + def _split_n_clients_rating_vmf(self, ratings: csc_matrix, num_client: int, + split: list): + id_item = np.arange(self.n_item) + shuffle(id_item) + items_per_client = np.array_split(id_item, num_client) + data = dict() + for clientId, items in enumerate(items_per_client): + client_ratings = ratings[:, items] + train_ratings, val_ratings, test_ratings = \ + self._split_train_val_test_ratings(client_ratings, split) + data[clientId + 1] = { + "train": train_ratings, + "val": val_ratings, + "test": test_ratings + } + self.data = data + + def _split_train_val_test_ratings(self, ratings: csc_matrix, split: list): + train_ratio, val_ratio, test_ratio = split - def _split_train_test_ratings(self, ratings: csc_matrix, - test_portion: float): n_ratings = ratings.count_nonzero() - id_test = np.random.choice(n_ratings, - int(n_ratings * test_portion), - replace=False) - id_train = list(set(np.arange(n_ratings)) - 
set(id_test)) + id_val_test = np.random.choice(n_ratings, + int(n_ratings * + (val_ratio + test_ratio)), + replace=False) + + id_val = id_val_test[:int(n_ratings * val_ratio)] + id_test = id_val_test[int(n_ratings * val_ratio):] + id_train = list(set(np.arange(n_ratings)) - set(id_test) - set(id_val)) ratings = ratings.tocoo() + val = coo_matrix( + (ratings.data[id_val], (ratings.row[id_val], ratings.col[id_val])), + shape=ratings.shape) test = coo_matrix((ratings.data[id_test], (ratings.row[id_test], ratings.col[id_test])), shape=ratings.shape) @@ -93,8 +150,9 @@ def _split_train_test_ratings(self, ratings: csc_matrix, (ratings.row[id_train], ratings.col[id_train])), shape=ratings.shape) - train_ratings, test_ratings = train.tocsc(), test.tocsc() - return train_ratings, test_ratings + train_ratings, val_ratings, test_ratings = train.tocsc(), val.tocsc( + ), test.tocsc() + return train_ratings, val_ratings, test_ratings def _read_raw(self): fpath = os.path.join(self.root, self.base_folder, self.filename, diff --git a/federatedscope/mf/model/model.py b/federatedscope/mf/model/model.py index 7246cfe98..36c3e3618 100644 --- a/federatedscope/mf/model/model.py +++ b/federatedscope/mf/model/model.py @@ -45,7 +45,8 @@ def forward(self, indices, ratings): device=pred.device, dtype=torch.float32).to_dense() - return mask * pred, label, float(np.prod(pred.size())) / len(ratings) + return mask * pred, label, torch.Tensor( + [float(np.prod(pred.size())) / len(ratings)]) def load_state_dict(self, state_dict, strict: bool = True): @@ -55,7 +56,8 @@ def load_state_dict(self, state_dict, strict: bool = True): def state_dict(self, destination=None, prefix='', keep_vars=False): state_dict = super().state_dict(destination, prefix, keep_vars) # Mask embed_item - del state_dict[self.name_reserve] + if self.name_reserve in state_dict: + del state_dict[self.name_reserve] return state_dict diff --git a/federatedscope/mf/trainer/trainer.py b/federatedscope/mf/trainer/trainer.py index f169c4d16..40f5d8de4 100644 --- a/federatedscope/mf/trainer/trainer.py +++ b/federatedscope/mf/trainer/trainer.py @@ -9,6 +9,7 @@ from federatedscope.register import register_trainer import logging +from scipy import sparse logger = logging.getLogger(__name__) @@ -46,17 +47,32 @@ def parse_data(self, data): return init_dict def _hook_on_fit_end(self, ctx): - results = { - f"{ctx.cur_mode}_avg_loss": ctx.loss_batch_total / ctx.num_samples, - f"{ctx.cur_mode}_total": ctx.num_samples - } + if ctx.get("num_samples") == 0: + results = { + f"{ctx.cur_mode}_avg_loss": ctx.get( + "loss_batch_total_{}".format(ctx.cur_mode)), + f"{ctx.cur_mode}_total": 0 + } + else: + results = { + f"{ctx.cur_mode}_avg_loss": ctx.loss_batch_total / + ctx.num_samples, + f"{ctx.cur_mode}_total": ctx.num_samples + } setattr(ctx, 'eval_metrics', results) + if self.cfg.federate.method.lower() in ["fedem"]: + # cache label for evaluation ensemble + ctx[f"{ctx.cur_mode}_y_prob"] = [] + ctx[f"{ctx.cur_mode}_y_true"] = [] def _hook_on_batch_forward(self, ctx): indices, ratings = ctx.data_batch pred, label, ratio = ctx.model(indices, ratings) ctx.loss_batch = CtxVar( - ctx.criterion(pred, label) * ratio, LIFECYCLE.BATCH) + ctx.criterion(pred, label) * ratio.item(), LIFECYCLE.BATCH) + ctx.ratio_batch = CtxVar(ratio.item(), LIFECYCLE.BATCH) + ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH) + ctx.y_true = CtxVar(label, LIFECYCLE.BATCH) ctx.batch_size = len(ratings) @@ -66,6 +82,13 @@ def _hook_on_batch_end(self, ctx): ctx.loss_batch_total += ctx.loss_batch.item() * 
ctx.batch_size ctx.loss_regular_total += float(ctx.get("loss_regular", 0.)) + if self.cfg.federate.method.lower() in ["fedem"]: + # cache label for evaluation ensemble + ctx.get("{}_y_true".format(ctx.cur_mode)).append( + sparse.csr_matrix(ctx.y_true.detach().cpu().numpy())) + ctx.get("{}_y_prob".format(ctx.cur_mode)).append( + sparse.csr_matrix(ctx.y_prob.detach().cpu().numpy())) + def _hook_on_batch_forward_flop_count(self, ctx): if not isinstance(self.ctx.monitor, Monitor): logger.warning( @@ -81,7 +104,10 @@ def _hook_on_batch_forward_flop_count(self, ctx): # calculate the flops_per_sample try: indices, ratings = ctx.data_batch - if isinstance(indices, numpy.ndarray): + if isinstance(indices, tuple) and isinstance( + indices[0], numpy.ndarray): + indices = torch.from_numpy(numpy.stack(indices)) + elif isinstance(indices, numpy.ndarray): indices = torch.from_numpy(indices) if isinstance(ratings, numpy.ndarray): ratings = torch.from_numpy(ratings) From 3c6719653182d41f3b70655720841e6ead95cf56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Tue, 11 Oct 2022 16:47:39 +0800 Subject: [PATCH 2/3] modified according to david's comment --- federatedscope/core/trainers/trainer_FedEM.py | 16 +++--- federatedscope/mf/dataset/movielens.py | 42 ++------------ federatedscope/mf/trainer/trainer.py | 3 +- tests/test_fedem.py | 56 +++++++++++++++++++ 4 files changed, 71 insertions(+), 46 deletions(-) diff --git a/federatedscope/core/trainers/trainer_FedEM.py b/federatedscope/core/trainers/trainer_FedEM.py index 30dc82eaa..da9a181f7 100644 --- a/federatedscope/core/trainers/trainer_FedEM.py +++ b/federatedscope/core/trainers/trainer_FedEM.py @@ -196,7 +196,7 @@ def _hook_on_fit_end_ensemble_eval_mf(self, ctx): """ Ensemble evaluation for matrix factorization model """ - cur_data = ctx.cur_mode + cur_data = ctx.cur_split batch_num = len(ctx[f"{cur_data}_y_prob"]) if f"{cur_data}_y_prob_ensemble" not in ctx or ctx[ f"{cur_data}_y_prob_ensemble"] is None: @@ -234,18 +234,18 @@ def _hook_on_fit_end_ensemble_eval_mf(self, ctx): if ctx.num_samples == 0: results = { f"{cur_data}_avg_loss": ctx.get( - "loss_batch_total_{}".format(ctx.cur_mode)), + "loss_batch_total_{}".format(ctx.cur_split)), f"{cur_data}_total": 0 } else: results = { - f"{ctx.cur_mode}_avg_loss": ctx.get( - f"loss_batch_total_{ctx.cur_mode}") / ctx.num_samples, - f"{ctx.cur_mode}_total": ctx.num_samples + f"{ctx.cur_split}_avg_loss": ctx.get( + f"loss_batch_total_{ctx.cur_split}") / ctx.num_samples, + f"{ctx.cur_split}_total": ctx.num_samples } - if isinstance(results[f"{ctx.cur_mode}_avg_loss"], torch.Tensor): - results[f"{ctx.cur_mode}_avg_loss"] = results[ - f"{ctx.cur_mode}_avg_loss"].item() + if isinstance(results[f"{ctx.cur_split}_avg_loss"], torch.Tensor): + results[f"{ctx.cur_split}_avg_loss"] = results[ + f"{ctx.cur_split}_avg_loss"].item() setattr(ctx, 'eval_metrics', results) # reset for next run_routine that may have different diff --git a/federatedscope/mf/dataset/movielens.py b/federatedscope/mf/dataset/movielens.py index a1ed6e368..15e43cf24 100644 --- a/federatedscope/mf/dataset/movielens.py +++ b/federatedscope/mf/dataset/movielens.py @@ -87,44 +87,12 @@ def __init__(self, root, num_client, split=None, download=True): "You can use download=True to download it") ratings = self._load_meta() - if issubclass(type(self), HMFDataset): - self._split_n_clients_rating_hmf(ratings, num_client, split) - else: - self._split_n_clients_rating_vmf(ratings, num_client, split) - - def _split_n_clients_rating_hmf(self, 
ratings: csc_matrix, num_client: int, - split: list): - id_user = np.arange(self.n_user) - shuffle(id_user) - users_per_client = np.array_split(id_user, num_client) - data = dict() - for clientId, users in enumerate(users_per_client): - client_ratings = ratings[users, :] - train_ratings, val_ratings, test_ratings = \ - self._split_train_val_test_ratings(client_ratings, split) - data[clientId + 1] = { - "train": train_ratings, - "val": val_ratings, - "test": test_ratings - } - self.data = data + self._split_n_clients_rating(ratings, num_client, split) - def _split_n_clients_rating_vmf(self, ratings: csc_matrix, num_client: int, - split: list): - id_item = np.arange(self.n_item) - shuffle(id_item) - items_per_client = np.array_split(id_item, num_client) - data = dict() - for clientId, items in enumerate(items_per_client): - client_ratings = ratings[:, items] - train_ratings, val_ratings, test_ratings = \ - self._split_train_val_test_ratings(client_ratings, split) - data[clientId + 1] = { - "train": train_ratings, - "val": val_ratings, - "test": test_ratings - } - self.data = data + def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int, + split: list): + raise NotImplementedError("You should use the parent class of " + "MovieLensData") def _split_train_val_test_ratings(self, ratings: csc_matrix, split: list): train_ratio, val_ratio, test_ratio = split diff --git a/federatedscope/mf/trainer/trainer.py b/federatedscope/mf/trainer/trainer.py index 40f5d8de4..881bdc499 100644 --- a/federatedscope/mf/trainer/trainer.py +++ b/federatedscope/mf/trainer/trainer.py @@ -55,7 +55,8 @@ def _hook_on_fit_end(self, ctx): } else: results = { - f"{ctx.cur_mode}_avg_loss": ctx.loss_batch_total / + f"{ctx.cur_mode}_avg_loss": ctx.get( + "loss_batch_total_{}".format(ctx.cur_mode)) / ctx.num_samples, f"{ctx.cur_mode}_total": ctx.num_samples } diff --git a/tests/test_fedem.py b/tests/test_fedem.py index 498f275dc..34773e109 100644 --- a/tests/test_fedem.py +++ b/tests/test_fedem.py @@ -78,6 +78,62 @@ def test_femnist_standalone(self): test_best_results["client_summarized_weighted_avg"]['test_loss'], 600) + def set_config_mf(self, cfg): + backup_cfg = cfg.clone() + + import torch + cfg.use_gpu = torch.cuda.is_available() + cfg.early_stop_patience = 100 + cfg.eval.best_res_update_round_wise_key = "test_avg_loss" + cfg.eval.freq = 5 + cfg.eval.metrics = [] + + cfg.federate.mode = 'standalone' + cfg.train.local_update_steps = 20 + cfg.federate.total_round_num = 50 + cfg.federate.client_num = 5 + + cfg.federate.method = "FedEM" + cfg.model.model_num_per_trainer = 3 + + cfg.data.root = 'test_data/' + cfg.data.type = 'vflmovielens1m' + cfg.data.batch_size = 32 + + cfg.model.type = 'VMFNet' + cfg.model.hidden = 20 + + cfg.train.optimizer.lr = 1. 
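+        # note: MFTrainer rescales each batch loss by the rating ratio in
+        # _hook_on_batch_forward, so the lr above is specific to the MF setup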
+ cfg.train.optimizer.weight_decay = 0.0 + + cfg.criterion.type = 'MSELoss' + cfg.trainer.type = 'mftrainer' + cfg.seed = 123 + + return backup_cfg + + def test_mf_standalone(self): + init_cfg = global_cfg.clone() + backup_cfg = self.set_config_mf(init_cfg) + setup_seed(init_cfg.seed) + update_logger(init_cfg, True) + + data, modified_cfg = get_data(init_cfg.clone()) + init_cfg.merge_from_other_cfg(modified_cfg) + self.assertIsNotNone(data) + + Fed_runner = FedRunner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) + self.assertIsNotNone(Fed_runner) + test_best_results = Fed_runner.run() + print(test_best_results) + init_cfg.merge_from_other_cfg(backup_cfg) + self.assertLess( + test_best_results["client_summarized_weighted_avg"]['test_loss'], + 600) + if __name__ == '__main__': unittest.main() From 2297679bcbc375d6103a3be373263f7a30492e6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Tue, 11 Oct 2022 16:51:45 +0800 Subject: [PATCH 3/3] modified according to david's comment --- tests/test_fedem.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_fedem.py b/tests/test_fedem.py index 34773e109..aa372d8fb 100644 --- a/tests/test_fedem.py +++ b/tests/test_fedem.py @@ -127,12 +127,12 @@ def test_mf_standalone(self): client_class=get_client_cls(init_cfg), config=init_cfg.clone()) self.assertIsNotNone(Fed_runner) - test_best_results = Fed_runner.run() - print(test_best_results) + test_results = Fed_runner.run() init_cfg.merge_from_other_cfg(backup_cfg) + self.assertLess( - test_best_results["client_summarized_weighted_avg"]['test_loss'], - 600) + test_results["client_summarized_weighted_avg"]["test_avg_loss"], + 50) if __name__ == '__main__':