From b99cc31985dbe7de43be8ecec4641e10a96588bb Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Mon, 10 Oct 2022 12:20:59 +0800 Subject: [PATCH] minor changes --- federatedscope/contrib/data/example.py | 2 +- .../contrib/trainer/torch_example.py | 7 +++++-- federatedscope/core/auxiliaries/utils.py | 2 -- federatedscope/core/data/base_translator.py | 18 ++++++++++-------- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/federatedscope/contrib/data/example.py b/federatedscope/contrib/data/example.py index da3e9c1cd..b896bdf7f 100644 --- a/federatedscope/contrib/data/example.py +++ b/federatedscope/contrib/data/example.py @@ -1,7 +1,7 @@ from federatedscope.register import register_data -def MyData(config, client_cfgs): +def MyData(config, client_cfgs=None): r""" Returns: data: diff --git a/federatedscope/contrib/trainer/torch_example.py b/federatedscope/contrib/trainer/torch_example.py index 4ecaf18b4..18cd5d7a0 100644 --- a/federatedscope/contrib/trainer/torch_example.py +++ b/federatedscope/contrib/trainer/torch_example.py @@ -54,7 +54,8 @@ def train(self): # _hook_on_fit_end return num_samples, self.model.cpu().state_dict(), \ - {'loss_total': total_loss} + {'loss_total': total_loss, 'avg_loss': total_loss/float( + num_samples)} def evaluate(self, target_data_split_name='test'): import torch @@ -76,7 +77,9 @@ def evaluate(self, target_data_split_name='test'): # _hook_on_fit_end return { f'{target_data_split_name}_loss': total_loss, - f'{target_data_split_name}_total': num_samples + f'{target_data_split_name}_total': num_samples, + f'{target_data_split_name}_avg_loss': total_loss / + float(num_samples) } def update(self, model_parameters, strict=False): diff --git a/federatedscope/core/auxiliaries/utils.py b/federatedscope/core/auxiliaries/utils.py index f2f488b66..190a96102 100644 --- a/federatedscope/core/auxiliaries/utils.py +++ b/federatedscope/core/auxiliaries/utils.py @@ -13,8 +13,6 @@ import numpy as np # Blind torch -import torch.utils - try: import torch import torchvision diff --git a/federatedscope/core/data/base_translator.py b/federatedscope/core/data/base_translator.py index 9d8b7e995..18fe34e2a 100644 --- a/federatedscope/core/data/base_translator.py +++ b/federatedscope/core/data/base_translator.py @@ -96,14 +96,16 @@ def split_to_client(self, train, val, test): train_label_distribution = None # Split train/val/test to client - if len(train) > 0: - split_train = self.splitter(train) - try: - train_label_distribution = [[j[1] for j in x] - for x in split_train] - except: - logger.warning('Cannot access train label distribution for ' - 'splitter.') + if self.global_cfg.data.consistent_label_distribution: + if len(train) > 0: + split_train = self.splitter(train) + try: + train_label_distribution = [[j[1] for j in x] + for x in split_train] + except: + logger.warning( + 'Cannot access train label distribution for ' + 'splitter.') if len(val) > 0: split_val = self.splitter(val, prior=train_label_distribution) if len(test) > 0: