From 9d0a3a121aa861a26425df8121cb19947d19c32b Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Thu, 8 Dec 2022 14:39:21 +0800 Subject: [PATCH 1/3] make ClientData more robust --- federatedscope/core/configs/cfg_evaluation.py | 1 - federatedscope/core/data/base_data.py | 29 +++++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/federatedscope/core/configs/cfg_evaluation.py b/federatedscope/core/configs/cfg_evaluation.py index 09b9cdd48..df1a1ed97 100644 --- a/federatedscope/core/configs/cfg_evaluation.py +++ b/federatedscope/core/configs/cfg_evaluation.py @@ -19,7 +19,6 @@ def extend_evaluation_cfg(cfg): # Monitoring, e.g., 'dissim' for B-local dissimilarity cfg.eval.monitoring = [] - cfg.eval.count_flops = True # ---------------------------------------------------------------------- # diff --git a/federatedscope/core/data/base_data.py b/federatedscope/core/data/base_data.py index 499cd344a..615d60491 100644 --- a/federatedscope/core/data/base_data.py +++ b/federatedscope/core/data/base_data.py @@ -145,13 +145,14 @@ class ClientData(dict): test: test dataset, which will be converted to ``Dataloader`` Note: - Key ``data`` in ``ClientData`` is the raw dataset. + Key ``{split}_data`` in ``ClientData`` is the raw dataset. + Key ``{split}`` in ``ClientData`` is the dataloader. 
""" def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs): self.client_cfg = None - self.train = train - self.val = val - self.test = test + self.train_data = train + self.val_data = val + self.test_data = test self.setup(client_cfg) if kwargs is not None: for key in kwargs: @@ -168,18 +169,22 @@ def setup(self, new_client_cfg=None): Returns: Bool: Status for indicating whether the client_cfg is updated """ - # if `batch_size` or `shuffle` change, reinstantiate DataLoader + # if `batch_size` or `shuffle` change, re-instantiate DataLoader if self.client_cfg is not None: if dict(self.client_cfg.dataloader) == dict( new_client_cfg.dataloader): return False self.client_cfg = new_client_cfg - if self.train is not None: - self['train'] = get_dataloader(self.train, self.client_cfg, - 'train') - if self.val is not None: - self['val'] = get_dataloader(self.val, self.client_cfg, 'val') - if self.test is not None: - self['test'] = get_dataloader(self.test, self.client_cfg, 'test') + if self.train_data is not None: + if len(self.train_data) > 0: + self['train'] = get_dataloader(self.train, self.client_cfg, + 'train') + if self.val_data is not None: + if len(self.val_data) > 0: + self['val'] = get_dataloader(self.val, self.client_cfg, 'val') + if self.test_data is not None: + if len(self.test_data) > 0: + self['test'] = get_dataloader(self.test, self.client_cfg, + 'test') return True From 27d419131ed788fdcbb4bbe47a4dc36f3852a62c Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Thu, 8 Dec 2022 14:44:18 +0800 Subject: [PATCH 2/3] fix minor bugs --- federatedscope/core/data/base_data.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/federatedscope/core/data/base_data.py b/federatedscope/core/data/base_data.py index 615d60491..d0a04694d 100644 --- a/federatedscope/core/data/base_data.py +++ b/federatedscope/core/data/base_data.py @@ -178,13 +178,14 @@ def setup(self, new_client_cfg=None): self.client_cfg = 
new_client_cfg if self.train_data is not None: if len(self.train_data) > 0: - self['train'] = get_dataloader(self.train, self.client_cfg, - 'train') + self['train'] = get_dataloader(self.train_data, + self.client_cfg, 'train') if self.val_data is not None: if len(self.val_data) > 0: - self['val'] = get_dataloader(self.val, self.client_cfg, 'val') + self['val'] = get_dataloader(self.val_data, self.client_cfg, + 'val') if self.test_data is not None: if len(self.test_data) > 0: - self['test'] = get_dataloader(self.test, self.client_cfg, + self['test'] = get_dataloader(self.test_data, self.client_cfg, 'test') return True From 6058f234f1532f0974e548f9b666594348799fa2 Mon Sep 17 00:00:00 2001 From: rayrayraykk <18007356109@163.com> Date: Thu, 8 Dec 2022 16:16:49 +0800 Subject: [PATCH 3/3] fix handling csc_matrix --- federatedscope/core/data/base_data.py | 32 +++++++++++++++++---------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/federatedscope/core/data/base_data.py b/federatedscope/core/data/base_data.py index d0a04694d..72421eac4 100644 --- a/federatedscope/core/data/base_data.py +++ b/federatedscope/core/data/base_data.py @@ -1,4 +1,7 @@ import logging + +from scipy.sparse.csc import csc_matrix + from federatedscope.core.data.utils import merge_data from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader @@ -148,6 +151,8 @@ class ClientData(dict): Key ``{split}_data`` in ``ClientData`` is the raw dataset. Key ``{split}`` in ``ClientData`` is the dataloader. 
""" + SPLIT_NAMES = ['train', 'val', 'test'] + def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs): self.client_cfg = None self.train_data = train @@ -176,16 +181,19 @@ def setup(self, new_client_cfg=None): return False self.client_cfg = new_client_cfg - if self.train_data is not None: - if len(self.train_data) > 0: - self['train'] = get_dataloader(self.train_data, - self.client_cfg, 'train') - if self.val_data is not None: - if len(self.val_data) > 0: - self['val'] = get_dataloader(self.val_data, self.client_cfg, - 'val') - if self.test_data is not None: - if len(self.test_data) > 0: - self['test'] = get_dataloader(self.test_data, self.client_cfg, - 'test') + + for split_data, split_name in zip( + [self.train_data, self.val_data, self.test_data], + self.SPLIT_NAMES): + if split_data is not None: + # ``len()`` on a csc_matrix raises ``TypeError`` + if isinstance(split_data, csc_matrix): + self[split_name] = get_dataloader(split_data, + self.client_cfg, + split_name) + elif len(split_data) > 0: + self[split_name] = get_dataloader(split_data, + self.client_cfg, + split_name) + return True