diff --git a/federatedscope/autotune/algos.py b/federatedscope/autotune/algos.py
index b7f66f156..115dd1b6f 100644
--- a/federatedscope/autotune/algos.py
+++ b/federatedscope/autotune/algos.py
@@ -20,16 +20,16 @@
 logger = logging.getLogger(__name__)
 
 
-def make_trial(trial_cfg):
+def make_trial(trial_cfg, client_cfgs=None):
     setup_seed(trial_cfg.seed)
     data, modified_config = get_data(config=trial_cfg.clone())
     trial_cfg.merge_from_other_cfg(modified_config)
     trial_cfg.freeze()
-    # TODO: enable client-wise configuration
     Fed_runner = FedRunner(data=data,
                            server_class=get_server_cls(trial_cfg),
                            client_class=get_client_cls(trial_cfg),
-                           config=trial_cfg.clone())
+                           config=trial_cfg.clone(),
+                           client_config=client_cfgs)
     results = Fed_runner.run()
     key1, key2 = trial_cfg.hpo.metric.split('.')
     return results[key1][key2]
@@ -39,24 +39,25 @@ class TrialExecutor(threading.Thread):
     """This class is responsible for executing the FL procedure with
     a given trial configuration in another thread.
     """
-    def __init__(self, cfg_idx, signal, returns, trial_config):
+    def __init__(self, cfg_idx, signal, returns, trial_config, client_cfgs):
         threading.Thread.__init__(self)
 
         self._idx = cfg_idx
         self._signal = signal
         self._returns = returns
         self._trial_cfg = trial_config
+        self._client_cfgs = client_cfgs
 
     def run(self):
         setup_seed(self._trial_cfg.seed)
         data, modified_config = get_data(config=self._trial_cfg.clone())
         self._trial_cfg.merge_from_other_cfg(modified_config)
         self._trial_cfg.freeze()
-        # TODO: enable client-wise configuration
         Fed_runner = FedRunner(data=data,
                                server_class=get_server_cls(self._trial_cfg),
                                client_class=get_client_cls(self._trial_cfg),
-                               config=self._trial_cfg.clone())
+                               config=self._trial_cfg.clone(),
+                               client_config=self._client_cfgs)
         results = Fed_runner.run()
         key1, key2 = self._trial_cfg.hpo.metric.split('.')
         self._returns['perf'] = results[key1][key2]
@@ -64,19 +65,20 @@ def run(self):
         self._signal.set()
 
 
-def get_scheduler(init_cfg):
+def get_scheduler(init_cfg, client_cfgs=None):
     """To instantiate a scheduler object for conducting HPO
 
     Arguments:
-        init_cfg (federatedscope.core.configs.config.CN): configuration.
+        init_cfg: configuration
+        client_cfgs: client-specific configuration
     """
     if init_cfg.hpo.scheduler in [
             'sha', 'rs', 'bo_kde', 'bohb', 'hb', 'bo_gp', 'bo_rf'
     ]:
-        scheduler = SuccessiveHalvingAlgo(init_cfg)
+        scheduler = SuccessiveHalvingAlgo(init_cfg, client_cfgs)
     # elif init_cfg.hpo.scheduler == 'pbt':
     #     scheduler = PBT(init_cfg)
     elif init_cfg.hpo.scheduler.startswith('wrap'):
-        scheduler = SHAWrapFedex(init_cfg)
+        scheduler = SHAWrapFedex(init_cfg, client_cfgs)
     return scheduler
 
 
@@ -84,15 +86,18 @@ def get_scheduler(init_cfg):
 class Scheduler(object):
     """The base class for describing HPO algorithms
     """
-    def __init__(self, cfg):
+    def __init__(self, cfg, client_cfgs=None):
         """
         Arguments:
-            cfg (federatedscope.core.configs.config.CN): dict like object,
-                where each key-value pair corresponds to a field and its
-                choices.
+            cfg (federatedscope.core.configs.config.CN): dict \
+                like object, where each key-value pair corresponds to a \
+                field and its choices.
+            client_cfgs: client-specific configuration
         """
         self._cfg = cfg
+        self._client_cfgs = client_cfgs
+        # Create hpo working folder
        os.makedirs(self._cfg.hpo.working_folder, exist_ok=True)
 
         self._search_space = parse_search_space(self._cfg.hpo.ss)
@@ -160,7 +165,7 @@ def _evaluate(self, configs):
                     flags[available_worker].clear()
                     trial = TrialExecutor(i, flags[available_worker],
                                           thread_results[available_worker],
-                                          trial_cfg)
+                                          trial_cfg, self._client_cfgs)
                     trial.start()
                     threads[available_worker] = trial
 
@@ -182,7 +187,7 @@ def _evaluate(self, configs):
             for i, config in enumerate(configs):
                 trial_cfg = self._cfg.clone()
                 trial_cfg.merge_from_list(config2cmdargs(config))
-                perfs[i] = make_trial(trial_cfg)
+                perfs[i] = make_trial(trial_cfg, self._client_cfgs)
                 logger.info(
                     "Evaluate the {}-th config {} and get performance {}".
                     format(i, config, perfs[i]))
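For context, a minimal sketch of how the new `client_cfgs` argument is expected to flow into the scheduler (the YAML paths here are hypothetical; `federatedscope/hpo.py` below does the same wiring from command-line arguments):

```python
# Minimal sketch, assuming hypothetical YAML paths.
from federatedscope.core.configs.config import global_cfg, CfgNode
from federatedscope.autotune import get_scheduler

init_cfg = global_cfg.clone()
init_cfg.merge_from_file('my_hpo.yaml')  # hypothetical server-side config

with open('my_per_client.yaml', 'r') as f:  # hypothetical client-wise config
    client_cfgs = CfgNode.load_cfg(f)

# client_cfgs is threaded through Scheduler -> TrialExecutor / make_trial
# -> FedRunner(client_config=...).
scheduler = get_scheduler(init_cfg, client_cfgs)
results = scheduler.optimize()
```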
diff --git a/federatedscope/autotune/hpbandster.py b/federatedscope/autotune/hpbandster.py
index aa8e42fc1..48f6513f7 100644
--- a/federatedscope/autotune/hpbandster.py
+++ b/federatedscope/autotune/hpbandster.py
@@ -8,7 +8,6 @@
 import hpbandster.core.nameserver as hpns
 from hpbandster.core.worker import Worker
 from hpbandster.optimizers import BOHB, HyperBand, RandomSearch
-from hpbandster.optimizers.iterations import SuccessiveHalving
 
 from federatedscope.autotune.utils import eval_in_fs
 
@@ -54,7 +53,14 @@ def get_next_iteration(self, iteration, iteration_kwargs={}):
 
 
 class MyWorker(Worker):
-    def __init__(self, cfg, ss, sleep_interval=0, *args, **kwargs):
+    def __init__(self,
+                 cfg,
+                 ss,
+                 sleep_interval=0,
+                 client_cfgs=None,
+                 *args,
+                 **kwargs):
         super(MyWorker, self).__init__(**kwargs)
         self.sleep_interval = sleep_interval
         self.cfg = cfg
+        self.client_cfgs = client_cfgs
@@ -63,7 +69,7 @@ def __init__(self, cfg, ss, sleep_interval=0, *args, **kwargs):
         self._perfs = []
 
     def compute(self, config, budget, **kwargs):
-        res = eval_in_fs(self.cfg, config, int(budget))
+        res = eval_in_fs(self.cfg, config, int(budget), self.client_cfgs)
         config = dict(config)
         config['federate.total_round_num'] = budget
         self._init_configs.append(config)
@@ -87,7 +93,7 @@ def summarize(self):
         return results
 
 
-def run_hpbandster(cfg, scheduler):
+def run_hpbandster(cfg, scheduler, client_cfgs=None):
     config_space = scheduler._search_space
     if cfg.hpo.scheduler.startswith('wrap_'):
         ss = CS.ConfigurationSpace()
@@ -100,7 +106,8 @@ def run_hpbandster(cfg, scheduler):
                  cfg=cfg,
                  nameserver='127.0.0.1',
                  nameserver_port=ns_port,
-                 run_id=cfg.hpo.scheduler)
+                 run_id=cfg.hpo.scheduler,
+                 client_cfgs=client_cfgs)
     w.run(background=True)
     opt_kwargs = {
         'configspace': config_space,
diff --git a/federatedscope/autotune/smac.py b/federatedscope/autotune/smac.py
index 99dd31248..a05734caf 100644
--- a/federatedscope/autotune/smac.py
+++ b/federatedscope/autotune/smac.py
@@ -10,13 +10,13 @@
 logger = logging.getLogger(__name__)
 
 
-def run_smac(cfg, scheduler):
+def run_smac(cfg, scheduler, client_cfgs=None):
     init_configs = []
     perfs = []
 
     def optimization_function_wrapper(config):
         budget = cfg.hpo.sha.budgets[-1]
-        res = eval_in_fs(cfg, config, budget)
+        res = eval_in_fs(cfg, config, budget, client_cfgs)
         config = dict(config)
         config['federate.total_round_num'] = budget
         init_configs.append(config)
diff --git a/federatedscope/autotune/utils.py b/federatedscope/autotune/utils.py
index cd6523a42..ac3dcaeb7 100644
--- a/federatedscope/autotune/utils.py
+++ b/federatedscope/autotune/utils.py
@@ -133,7 +133,7 @@ def process(file):
     plt.close()
 
 
-def eval_in_fs(cfg, config, budget):
+def eval_in_fs(cfg, config, budget, client_cfgs=None):
     import ConfigSpace as CS
     from federatedscope.core.auxiliaries.utils import setup_seed
     from federatedscope.core.auxiliaries.data_builder import get_data
@@ -170,7 +170,8 @@ def eval_in_fs(cfg, config, budget):
     Fed_runner = FedRunner(data=data,
                            server_class=get_server_cls(trial_cfg),
                            client_class=get_client_cls(trial_cfg),
-                           config=trial_cfg.clone())
+                           config=trial_cfg.clone(),
+                           client_config=client_cfgs)
     results = Fed_runner.run()
     key1, key2 = trial_cfg.hpo.metric.split('.')
     return results[key1][key2]
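All three backends (SHA threads, hpbandster, SMAC) now reach `FedRunner` through `eval_in_fs`, so the client-wise configs ride along with every trial. A minimal sketch of the shared objective pattern, assuming `cfg` and `client_cfgs` are already loaded:

```python
# Sketch of the objective pattern shared by run_smac's wrapper and
# MyWorker.compute; `config` is a sampled hyperparameter dict and `budget`
# the number of FL rounds granted to the trial.
from federatedscope.autotune.utils import eval_in_fs


def make_objective(cfg, client_cfgs):
    def objective(config, budget):
        # client_cfgs is forwarded to FedRunner(client_config=...)
        return eval_in_fs(cfg, config, int(budget), client_cfgs)

    return objective
```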
diff --git a/federatedscope/contrib/data/example.py b/federatedscope/contrib/data/example.py
deleted file mode 100644
index b896bdf7f..000000000
--- a/federatedscope/contrib/data/example.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from federatedscope.register import register_data
-
-
-def MyData(config, client_cfgs=None):
-    r"""
-    Returns:
-        data:
-            {
-                '{client_id}': {
-                    'train': Dataset or DataLoader,
-                    'test': Dataset or DataLoader,
-                    'val': Dataset or DataLoader
-                }
-            }
-        config:
-            cfg_node
-    """
-    data = None
-    config = config
-    client_cfgs = client_cfgs
-    return data, config
-
-
-def call_my_data(config, client_cfgs):
-    if config.data.type == "mydata":
-        data, modified_config = MyData(config, client_cfgs)
-        return data, modified_config
-
-
-register_data("mydata", call_my_data)
diff --git a/federatedscope/contrib/data/mini_graph_dt.py b/federatedscope/contrib/data/mini_graph_dt.py
new file mode 100644
index 000000000..150a43030
--- /dev/null
+++ b/federatedscope/contrib/data/mini_graph_dt.py
@@ -0,0 +1,160 @@
+import os
+import torch
+import numpy as np
+
+from torch_geometric.data import InMemoryDataset, Data
+from torch_geometric.datasets import TUDataset, MoleculeNet
+
+from federatedscope.register import register_data
+from federatedscope.core.data import DummyDataTranslator
+from federatedscope.core.splitters.graph.scaffold_lda_splitter import \
+    GenFeatures
+
+# Run with mini_graph_dt:
+# python federatedscope/main.py --cfg \
+#     federatedscope/gfl/baseline/mini_graph_dc/fedavg.yaml --client_cfg \
+#     federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml
+# Test Accuracy: ~0.7
+
+
+class MiniGraphDCDataset(InMemoryDataset):
+    NAME = 'mini_graph_dt'
+    DATA_NAME = ['BACE', 'BBBP', 'CLINTOX', 'ENZYMES', 'PROTEINS_full']
+    IN_MEMORY_DATA = {}
+
+    def __init__(self, root, splits=[0.8, 0.1, 0.1]):
+        self.root = root
+        self.splits = splits
+        super(MiniGraphDCDataset, self).__init__(root)
+
+    @property
+    def processed_dir(self):
+        return os.path.join(self.root, self.NAME, 'processed')
+
+    @property
+    def processed_file_names(self):
+        return ['pre_transform.pt', 'pre_filter.pt']
+
+    def __len__(self):
+        return len(self.DATA_NAME)
+
+    def __getitem__(self, idx):
+        if idx not in self.IN_MEMORY_DATA:
+            self.IN_MEMORY_DATA[idx] = {}
+            for split in ['train', 'val', 'test']:
+                split_data = self._load(idx, split)
+                if split_data:
+                    self.IN_MEMORY_DATA[idx][split] = split_data
+        return self.IN_MEMORY_DATA[idx]
+
+    def _load(self, idx, split):
+        try:
+            data = torch.load(
+                os.path.join(self.processed_dir, str(idx), f'{split}.pt'))
+        except:
+            data = None
+        return data
+
+    def process(self):
+        np.random.seed(0)
+        for idx, name in enumerate(self.DATA_NAME):
+            if name in ['BACE', 'BBBP', 'CLINTOX']:
+                dataset = MoleculeNet(self.root, name)
+                featurizer = GenFeatures()
+                ds = []
+                for graph in dataset:
+                    graph = featurizer(graph)
+                    ds.append(
+                        Data(edge_index=graph.edge_index, x=graph.x,
+                             y=graph.y))
+                dataset = ds
+                if name in ['BACE', 'BBBP']:
+                    for i in range(len(dataset)):
+                        dataset[i].y = dataset[i].y.long()
+                if name in ['CLINTOX']:
+                    for i in range(len(dataset)):
+                        dataset[i].y = torch.argmax(
+                            dataset[i].y).view(-1).unsqueeze(0)
+            else:
+                # Classification
+                dataset = TUDataset(self.root, name)
+                dataset = [
+                    Data(edge_index=graph.edge_index, x=graph.x, y=graph.y)
+                    for graph in dataset
+                ]
+
+            # We fix train/val/test
+            index = np.random.permutation(np.arange(len(dataset)))
+            train_idx = index[:int(len(dataset) * self.splits[0])]
+            valid_idx = index[int(len(dataset) * self.splits[0]
+                                  ):int(len(dataset) * sum(self.splits[:2]))]
+            test_idx = index[int(len(dataset) * sum(self.splits[:2])):]
+
+            if not os.path.isdir(os.path.join(self.processed_dir, str(idx))):
+                os.makedirs(os.path.join(self.processed_dir, str(idx)))
+
+            train_path = os.path.join(self.processed_dir, str(idx),
+                                      'train.pt')
+            valid_path = os.path.join(self.processed_dir, str(idx), 'val.pt')
+            test_path = os.path.join(self.processed_dir, str(idx), 'test.pt')
+
+            torch.save([dataset[i] for i in train_idx], train_path)
+            torch.save([dataset[i] for i in valid_idx], valid_path)
+            torch.save([dataset[i] for i in test_idx], test_path)
+
+            print(name, len(dataset), dataset[0])
+
+    def meta_info(self):
+        return {
+            'BACE': {
+                'task': 'classification',
+                'input_dim': 74,
+                'output_dim': 2,
+                'num_samples': 1513,
+            },
+            'BBBP': {
+                'task': 'classification',
+                'input_dim': 74,
+                'output_dim': 2,
+                'num_samples': 2039,
+            },
+            'CLINTOX': {
+                'task': 'classification',
+                'input_dim': 74,
+                'output_dim': 2,
+                'num_samples': 1478,
+            },
+            'ENZYMES': {
+                'task': 'classification',
+                'input_dim': 3,
+                'output_dim': 6,
+                'num_samples': 600,
+            },
+            'PROTEINS_full': {
+                'task': 'classification',
+                'input_dim': 3,
+                'output_dim': 2,
+                'num_samples': 1113,
+            },
+        }
+
+
+def load_mini_graph_dt(config, client_cfgs=None):
+    dataset = MiniGraphDCDataset(config.data.root)
+    # Convert to dict
+    datadict = {
+        client_id + 1: dataset[client_id]
+        for client_id in range(len(dataset))
+    }
+    config.merge_from_list(['federate.client_num', len(dataset)])
+    translator = DummyDataTranslator(config, client_cfgs)
+
+    return translator(datadict), config
+
+
+def call_mini_graph_dt(config, client_cfgs):
+    if config.data.type == "mini-graph-dc":
+        data, modified_config = load_mini_graph_dt(config, client_cfgs)
+        return data, modified_config
+
+
+register_data("mini-graph-dc", call_mini_graph_dt)
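A quick way to sanity-check the new dataset (a sketch; run from the repository root — the first call downloads and processes the underlying TUDataset / MoleculeNet sources under `data/`):

```python
# Sketch: inspecting MiniGraphDCDataset directly.
from federatedscope.contrib.data.mini_graph_dt import MiniGraphDCDataset

dataset = MiniGraphDCDataset(root='data/')
print(len(dataset))  # 5 sub-datasets, one per client
splits = dataset[0]  # {'train': [...], 'val': [...], 'test': [...]}
print({split: len(graphs) for split, graphs in splits.items()})
```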
diff --git a/federatedscope/core/auxiliaries/logging.py b/federatedscope/core/auxiliaries/logging.py
index 92a867136..a0c0bd546 100644
--- a/federatedscope/core/auxiliaries/logging.py
+++ b/federatedscope/core/auxiliaries/logging.py
@@ -8,7 +8,7 @@
 
 import numpy as np
 
-from federatedscope.core.auxiliaries.utils import logger
+logger = logging.getLogger(__name__)
 
 
 class CustomFormatter(logging.Formatter):
diff --git a/federatedscope/core/cmd_args.py b/federatedscope/core/cmd_args.py
index 61b886f26..2581a33d7 100644
--- a/federatedscope/core/cmd_args.py
+++ b/federatedscope/core/cmd_args.py
@@ -45,3 +45,19 @@ def parse_args(args=None):
         sys.exit(1)
 
     return parse_res
+
+
+def parse_client_cfg(arg_opts):
+    """
+    Arguments:
+        arg_opts: list of key-value pairs from args.opts
+    """
+    client_cfg_opts = []
+    i = 0
+    while i < len(arg_opts):
+        if arg_opts[i].startswith('client'):
+            client_cfg_opts.append(arg_opts.pop(i))
+            client_cfg_opts.append(arg_opts.pop(i))
+        else:
+            i += 1
+    return arg_opts, client_cfg_opts
diff --git a/federatedscope/core/fed_runner.py b/federatedscope/core/fed_runner.py
index dac9e6649..ac3eb2129 100644
--- a/federatedscope/core/fed_runner.py
+++ b/federatedscope/core/fed_runner.py
@@ -353,7 +353,8 @@ def _setup_client(self,
 
         if self.client_class:
             client_specific_config = self.cfg.clone()
-            if self.client_cfgs:
+            if self.client_cfgs and \
+                    self.client_cfgs.get('client_{}'.format(client_id)):
                 client_specific_config.defrost()
                 client_specific_config.merge_from_other_cfg(
                     self.client_cfgs.get('client_{}'.format(client_id)))
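The routing rule in `parse_client_cfg` is purely prefix-based: any option key starting with `client` is popped, together with its value, into the client-wise list. For example (option names are illustrative):

```python
from federatedscope.core.cmd_args import parse_client_cfg

opts = ['seed', '1', 'client_1.train.optimizer.lr', '0.05']
cfg_opt, client_cfg_opt = parse_client_cfg(opts)
assert cfg_opt == ['seed', '1']
assert client_cfg_opt == ['client_1.train.optimizer.lr', '0.05']
```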
diff --git a/federatedscope/gfl/baseline/mini_graph_dc/fedavg.yaml b/federatedscope/gfl/baseline/mini_graph_dc/fedavg.yaml
new file mode 100644
index 000000000..77f66dff8
--- /dev/null
+++ b/federatedscope/gfl/baseline/mini_graph_dc/fedavg.yaml
@@ -0,0 +1,33 @@
+use_gpu: True
+device: 0
+early_stop:
+  patience: 100
+  improve_indicator_mode: mean
+federate:
+  mode: 'standalone'
+  make_global_eval: False
+  total_round_num: 400
+  share_local_model: False
+data:
+  root: data/
+  type: mini-graph-dc
+dataloader:
+  type: pyg
+model:
+  task: graph
+  type: gin
+  hidden: 64
+personalization:
+  local_param: ['encoder_atom', 'encoder', 'clf']
+train:
+  batch_or_epoch: epoch
+  local_update_steps: 1
+  optimizer:
+    type: SGD
+trainer:
+  type: graphminibatch_trainer
+eval:
+  freq: 1
+  metrics: ['acc', 'correct']
+  count_flops: False
+  split: ['train', 'val', 'test']
diff --git a/federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml b/federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml
new file mode 100644
index 000000000..f4ac62781
--- /dev/null
+++ b/federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml
@@ -0,0 +1,50 @@
+client_1:
+  model:
+    out_channels: 2
+    task: graphClassification
+  criterion:
+    type: CrossEntropyLoss
+  train:
+    local_update_steps: 1
+    optimizer:
+      lr: 0.001
+client_2:
+  model:
+    out_channels: 2
+    task: graphClassification
+  criterion:
+    type: CrossEntropyLoss
+  train:
+    local_update_steps: 1
+    optimizer:
+      lr: 0.001
+client_3:
+  model:
+    out_channels: 2
+    task: graphClassification
+  criterion:
+    type: CrossEntropyLoss
+  train:
+    local_update_steps: 1
+    optimizer:
+      lr: 0.001
+client_4:
+  model:
+    out_channels: 6
+    task: graphClassification
+  criterion:
+    type: CrossEntropyLoss
+  train:
+    local_update_steps: 1
+    optimizer:
+      lr: 0.001
+client_5:
+  model:
+    out_channels: 2
+    task: graphClassification
+  criterion:
+    type: CrossEntropyLoss
+  train:
+    local_update_steps: 1
+    optimizer:
+      lr: 0.001
diff --git a/federatedscope/gfl/trainer/graphtrainer.py b/federatedscope/gfl/trainer/graphtrainer.py
index 2479ce933..8f58cde0d 100644
--- a/federatedscope/gfl/trainer/graphtrainer.py
+++ b/federatedscope/gfl/trainer/graphtrainer.py
@@ -58,7 +58,7 @@ def _hook_on_batch_forward_flop_count(self, ctx):
             except:
                 logger.warning(
                     "current flop count implementation is for general "
-                    "NodeFullBatchTrainer case: "
+                    "GraphMiniBatchTrainer case: "
                    "1) the ctx.model takes only batch = ctx.data_batch as "
                     "input."
                     "Please check the forward format or implement your own "
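Each `client_{id}` block in `fedavg_per_client.yaml` above overrides only the fields it names; every other field keeps the value from the global `fedavg.yaml`. A small sketch of loading the file and reading one override:

```python
# Sketch: per-client override semantics. client_4 (ENZYMES) gets 6 output
# channels; the other clients keep binary classification heads.
from federatedscope.core.configs.config import CfgNode

with open('federatedscope/gfl/baseline/mini_graph_dc/'
          'fedavg_per_client.yaml', 'r') as f:
    client_cfgs = CfgNode.load_cfg(f)

print(client_cfgs.get('client_4').model.out_channels)  # 6
print(client_cfgs.get('client_1').model.out_channels)  # 2
```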
"Please check the forward format or implement your own " diff --git a/federatedscope/hpo.py b/federatedscope/hpo.py index c3b16cf9f..dfbd4ee7d 100644 --- a/federatedscope/hpo.py +++ b/federatedscope/hpo.py @@ -1,8 +1,6 @@ import os import sys -import yaml - DEV_MODE = False # simplify the federatedscope re-setup everytime we change # the source codes of federatedscope if DEV_MODE: @@ -11,8 +9,8 @@ from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger -from federatedscope.core.cmd_args import parse_args -from federatedscope.core.configs.config import global_cfg +from federatedscope.core.cmd_args import parse_args, parse_client_cfg +from federatedscope.core.configs.config import global_cfg, CfgNode from federatedscope.autotune import get_scheduler if os.environ.get('https_proxy'): @@ -23,22 +21,24 @@ if __name__ == '__main__': init_cfg = global_cfg.clone() args = parse_args() - init_cfg.merge_from_file(args.cfg_file) - init_cfg.merge_from_list(args.opts) + if args.cfg_file: + init_cfg.merge_from_file(args.cfg_file) + cfg_opt, client_cfg_opt = parse_client_cfg(args.opts) + init_cfg.merge_from_list(cfg_opt) update_logger(init_cfg, clear_before_add=True) setup_seed(init_cfg.seed) - assert not args.client_cfg_file, 'No support for client-wise config in ' \ - 'HPO mode.' + # load clients' cfg file + if args.client_cfg_file: + client_cfgs = CfgNode.load_cfg(open(args.client_cfg_file, 'r')) + # client_cfgs.set_new_allowed(True) + client_cfgs.merge_from_list(client_cfg_opt) + else: + client_cfgs = None - # with open(args.cfg_file, 'r') as ips: - # config = yaml.load(ips, Loader=yaml.FullLoader) - # det_config, tbd_config = split_raw_config(config) - # global_cfg.merge_from_list(config2cmdargs(det_config)) - # global_cfg.merge_from_list(args.opts) + scheduler = get_scheduler(init_cfg, client_cfgs) - scheduler = get_scheduler(init_cfg) if init_cfg.hpo.scheduler in ['sha', 'wrap_sha']: _ = scheduler.optimize() elif init_cfg.hpo.scheduler in [ @@ -46,12 +46,12 @@ 'wrap_bohb' ]: from federatedscope.autotune.hpbandster import run_hpbandster - run_hpbandster(init_cfg, scheduler) + run_hpbandster(init_cfg, scheduler, client_cfgs) elif init_cfg.hpo.scheduler in [ 'bo_gp', 'bo_rf', 'wrap_bo_gp', 'wrap_bo_rf' ]: from federatedscope.autotune.smac import run_smac - run_smac(init_cfg, scheduler) + run_smac(init_cfg, scheduler, client_cfgs) else: raise ValueError(f'No scheduler named {init_cfg.hpo.scheduler}') diff --git a/federatedscope/main.py b/federatedscope/main.py index 0eae9c8c3..9c8842c38 100644 --- a/federatedscope/main.py +++ b/federatedscope/main.py @@ -7,7 +7,7 @@ file_dir = os.path.join(os.path.dirname(__file__), '..') sys.path.append(file_dir) -from federatedscope.core.cmd_args import parse_args +from federatedscope.core.cmd_args import parse_args, parse_client_cfg from federatedscope.core.auxiliaries.data_builder import get_data from federatedscope.core.auxiliaries.utils import setup_seed from federatedscope.core.auxiliaries.logging import update_logger @@ -26,14 +26,19 @@ args = parse_args() if args.cfg_file: init_cfg.merge_from_file(args.cfg_file) - init_cfg.merge_from_list(args.opts) + cfg_opt, client_cfg_opt = parse_client_cfg(args.opts) + init_cfg.merge_from_list(cfg_opt) update_logger(init_cfg, clear_before_add=True) setup_seed(init_cfg.seed) # load clients' cfg file - client_cfgs = CfgNode.load_cfg(open(args.client_cfg_file, - 'r')) if args.client_cfg_file else None + if args.client_cfg_file: + client_cfgs = 
diff --git a/tests/test_yaml.py b/tests/test_yaml.py
index ff0570377..faf3a109e 100644
--- a/tests/test_yaml.py
+++ b/tests/test_yaml.py
@@ -14,7 +14,8 @@ def setUp(self):
         self.exclude_file = [
             '.pre-commit-config.yaml', 'meta.yaml',
             'federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup_per_client.yaml',
-            'federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup_per_client.yaml'
+            'federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup_per_client.yaml',
+            'federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml'
         ]
         self.root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         self.exclude_all = [