diff --git a/federatedscope/core/configs/cfg_evaluation.py b/federatedscope/core/configs/cfg_evaluation.py index 09b9cdd48..df1a1ed97 100644 --- a/federatedscope/core/configs/cfg_evaluation.py +++ b/federatedscope/core/configs/cfg_evaluation.py @@ -19,7 +19,6 @@ def extend_evaluation_cfg(cfg): # Monitoring, e.g., 'dissim' for B-local dissimilarity cfg.eval.monitoring = [] - cfg.eval.count_flops = True # ---------------------------------------------------------------------- # diff --git a/federatedscope/core/data/base_data.py b/federatedscope/core/data/base_data.py index 499cd344a..72421eac4 100644 --- a/federatedscope/core/data/base_data.py +++ b/federatedscope/core/data/base_data.py @@ -1,4 +1,7 @@ import logging + +from scipy.sparse.csc import csc_matrix + from federatedscope.core.data.utils import merge_data from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader @@ -145,13 +148,16 @@ class ClientData(dict): test: test dataset, which will be converted to ``Dataloader`` Note: - Key ``data`` in ``ClientData`` is the raw dataset. + Key ``{split}_data`` in ``ClientData`` is the raw dataset. + Key ``{split}`` in ``ClientData`` is the dataloader. """ + SPLIT_NAMES = ['train', 'val', 'test'] + def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs): self.client_cfg = None - self.train = train - self.val = val - self.test = test + self.train_data = train + self.val_data = val + self.test_data = test self.setup(client_cfg) if kwargs is not None: for key in kwargs: @@ -168,18 +174,26 @@ def setup(self, new_client_cfg=None): Returns: Bool: Status for indicating whether the client_cfg is updated """ - # if `batch_size` or `shuffle` change, reinstantiate DataLoader + # if `batch_size` or `shuffle` change, re-instantiate DataLoader if self.client_cfg is not None: if dict(self.client_cfg.dataloader) == dict( new_client_cfg.dataloader): return False self.client_cfg = new_client_cfg - if self.train is not None: - self['train'] = get_dataloader(self.train, self.client_cfg, - 'train') - if self.val is not None: - self['val'] = get_dataloader(self.val, self.client_cfg, 'val') - if self.test is not None: - self['test'] = get_dataloader(self.test, self.client_cfg, 'test') + + for split_data, split_name in zip( + [self.train_data, self.val_data, self.test_data], + self.SPLIT_NAMES): + if split_data is not None: + # csc_matrix does not have ``__len__`` attributes + if isinstance(split_data, csc_matrix): + self[split_name] = get_dataloader(split_data, + self.client_cfg, + split_name) + elif len(split_data) > 0: + self[split_name] = get_dataloader(split_data, + self.client_cfg, + split_name) + return True