[Feature] Enhance client_cfg and add new dataset #413

Merged · 11 commits · Nov 8, 2022
Changes from all commits

37 changes: 21 additions & 16 deletions federatedscope/autotune/algos.py
@@ -20,16 +20,16 @@
 logger = logging.getLogger(__name__)


-def make_trial(trial_cfg):
+def make_trial(trial_cfg, client_cfgs=None):
     setup_seed(trial_cfg.seed)
     data, modified_config = get_data(config=trial_cfg.clone())
     trial_cfg.merge_from_other_cfg(modified_config)
     trial_cfg.freeze()
-    # TODO: enable client-wise configuration
     Fed_runner = FedRunner(data=data,
                            server_class=get_server_cls(trial_cfg),
                            client_class=get_client_cls(trial_cfg),
-                           config=trial_cfg.clone())
+                           config=trial_cfg.clone(),
+                           client_config=client_cfgs)
     results = Fed_runner.run()
     key1, key2 = trial_cfg.hpo.metric.split('.')
     return results[key1][key2]
@@ -39,60 +39,65 @@ class TrialExecutor(threading.Thread):
     """This class is responsible for executing the FL procedure with
     a given trial configuration in another thread.
     """
-    def __init__(self, cfg_idx, signal, returns, trial_config):
+    def __init__(self, cfg_idx, signal, returns, trial_config, client_cfgs):
         threading.Thread.__init__(self)

         self._idx = cfg_idx
         self._signal = signal
         self._returns = returns
         self._trial_cfg = trial_config
+        self._client_cfgs = client_cfgs

     def run(self):
         setup_seed(self._trial_cfg.seed)
         data, modified_config = get_data(config=self._trial_cfg.clone())
         self._trial_cfg.merge_from_other_cfg(modified_config)
         self._trial_cfg.freeze()
-        # TODO: enable client-wise configuration
         Fed_runner = FedRunner(data=data,
                                server_class=get_server_cls(self._trial_cfg),
                                client_class=get_client_cls(self._trial_cfg),
-                               config=self._trial_cfg.clone())
+                               config=self._trial_cfg.clone(),
+                               client_config=self._client_cfgs)
         results = Fed_runner.run()
         key1, key2 = self._trial_cfg.hpo.metric.split('.')
         self._returns['perf'] = results[key1][key2]
         self._returns['cfg_idx'] = self._idx
         self._signal.set()


-def get_scheduler(init_cfg):
+def get_scheduler(init_cfg, client_cfgs=None):
     """To instantiate a scheduler object for conducting HPO
     Arguments:
-        init_cfg (federatedscope.core.configs.config.CN): configuration.
+        init_cfg: configuration
+        client_cfgs: client-specific configuration
     """

     if init_cfg.hpo.scheduler in [
             'sha', 'rs', 'bo_kde', 'bohb', 'hb', 'bo_gp', 'bo_rf'
     ]:
-        scheduler = SuccessiveHalvingAlgo(init_cfg)
+        scheduler = SuccessiveHalvingAlgo(init_cfg, client_cfgs)
     # elif init_cfg.hpo.scheduler == 'pbt':
     #     scheduler = PBT(init_cfg)
     elif init_cfg.hpo.scheduler.startswith('wrap'):
-        scheduler = SHAWrapFedex(init_cfg)
+        scheduler = SHAWrapFedex(init_cfg, client_cfgs)
     return scheduler


 class Scheduler(object):
     """The base class for describing HPO algorithms
     """
-    def __init__(self, cfg):
+    def __init__(self, cfg, client_cfgs=None):
         """
         Arguments:
-            cfg (federatedscope.core.configs.config.CN): dict like object,
-                where each key-value pair corresponds to a field and its
-                choices.
+            cfg (federatedscope.core.configs.config.CN): dict \
+                like object, where each key-value pair corresponds to a \
+                field and its choices.
+            client_cfgs: client-specific configuration
         """

         self._cfg = cfg
+        self._client_cfgs = client_cfgs

         # Create hpo working folder
         os.makedirs(self._cfg.hpo.working_folder, exist_ok=True)
         self._search_space = parse_search_space(self._cfg.hpo.ss)
@@ -160,7 +165,7 @@ def _evaluate(self, configs):
             flags[available_worker].clear()
             trial = TrialExecutor(i, flags[available_worker],
                                   thread_results[available_worker],
-                                  trial_cfg)
+                                  trial_cfg, self._client_cfgs)
             trial.start()
             threads[available_worker] = trial

@@ -182,7 +187,7 @@ def _evaluate(self, configs):
         for i, config in enumerate(configs):
             trial_cfg = self._cfg.clone()
             trial_cfg.merge_from_list(config2cmdargs(config))
-            perfs[i] = make_trial(trial_cfg)
+            perfs[i] = make_trial(trial_cfg, self._client_cfgs)
             logger.info(
                 "Evaluate the {}-th config {} and get performance {}".
                 format(i, config, perfs[i]))
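For orientation, here is a minimal caller sketch of how a per-client configuration is expected to reach the scheduler; the YAML file names and the CN.load_cfg loading step are illustrative assumptions, not code from this PR:

# Hypothetical caller sketch; file names are placeholders.
from federatedscope.core.configs.config import global_cfg, CN
from federatedscope.autotune.algos import get_scheduler

init_cfg = global_cfg.clone()
init_cfg.merge_from_file('hpo.yaml')  # global config (illustrative)

with open('per_client.yaml') as f:    # per-client overrides (illustrative)
    client_cfgs = CN.load_cfg(f)      # yacs-style loading

# client_cfgs now rides along into every trial's FedRunner through
# make_trial and TrialExecutor (see the hunks above).
scheduler = get_scheduler(init_cfg, client_cfgs)
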
16 changes: 11 additions & 5 deletions federatedscope/autotune/hpbandster.py
@@ -8,7 +8,6 @@
 import hpbandster.core.nameserver as hpns
 from hpbandster.core.worker import Worker
 from hpbandster.optimizers import BOHB, HyperBand, RandomSearch
-from hpbandster.optimizers.iterations import SuccessiveHalving

 from federatedscope.autotune.utils import eval_in_fs
@@ -54,7 +53,13 @@ def get_next_iteration(self, iteration, iteration_kwargs={}):


 class MyWorker(Worker):
-    def __init__(self, cfg, ss, sleep_interval=0, *args, **kwargs):
+    def __init__(self,
+                 cfg,
+                 ss,
+                 sleep_interval=0,
+                 client_cfgs=None,
+                 *args,
+                 **kwargs):
         super(MyWorker, self).__init__(**kwargs)
         self.sleep_interval = sleep_interval
         self.cfg = cfg
@@ -63,7 +68,7 @@ def __init__(self, cfg, ss, sleep_interval=0, *args, **kwargs):
         self._perfs = []

     def compute(self, config, budget, **kwargs):
-        res = eval_in_fs(self.cfg, config, int(budget))
+        res = eval_in_fs(self.cfg, config, int(budget), self.client_cfgs)
         config = dict(config)
         config['federate.total_round_num'] = budget
         self._init_configs.append(config)
@@ -87,7 +92,7 @@ def summarize(self):
         return results


-def run_hpbandster(cfg, scheduler):
+def run_hpbandster(cfg, scheduler, client_cfgs=None):
     config_space = scheduler._search_space
     if cfg.hpo.scheduler.startswith('wrap_'):
         ss = CS.ConfigurationSpace()
@@ -100,7 +105,8 @@ def run_hpbandster(cfg, scheduler):
                  cfg=cfg,
                  nameserver='127.0.0.1',
                  nameserver_port=ns_port,
-                 run_id=cfg.hpo.scheduler)
+                 run_id=cfg.hpo.scheduler,
+                 client_cfgs=client_cfgs)
     w.run(background=True)
     opt_kwargs = {
         'configspace': config_space,
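A rough sketch of the expected dispatch; which scheduler names map to which backend is an assumption for illustration, not taken from this diff:

# Hypothetical dispatch wrapper; the name-to-backend mapping is assumed.
from federatedscope.autotune.hpbandster import run_hpbandster
from federatedscope.autotune.smac import run_smac

def run_hpo(cfg, scheduler, client_cfgs=None):
    if cfg.hpo.scheduler in ['bohb', 'hb', 'bo_kde']:  # hpbandster-backed
        return run_hpbandster(cfg, scheduler, client_cfgs)
    else:  # e.g. 'bo_gp', 'bo_rf' via SMAC
        return run_smac(cfg, scheduler, client_cfgs)
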
4 changes: 2 additions & 2 deletions federatedscope/autotune/smac.py
@@ -10,13 +10,13 @@
 logger = logging.getLogger(__name__)


-def run_smac(cfg, scheduler):
+def run_smac(cfg, scheduler, client_cfgs=None):
     init_configs = []
     perfs = []

     def optimization_function_wrapper(config):
         budget = cfg.hpo.sha.budgets[-1]
-        res = eval_in_fs(cfg, config, budget)
+        res = eval_in_fs(cfg, config, budget, client_cfgs)
         config = dict(config)
         config['federate.total_round_num'] = budget
         init_configs.append(config)
5 changes: 3 additions & 2 deletions federatedscope/autotune/utils.py
@@ -133,7 +133,7 @@ def process(file):
     plt.close()


-def eval_in_fs(cfg, config, budget):
+def eval_in_fs(cfg, config, budget, client_cfgs=None):
     import ConfigSpace as CS
     from federatedscope.core.auxiliaries.utils import setup_seed
     from federatedscope.core.auxiliaries.data_builder import get_data
@@ -170,7 +170,8 @@ def eval_in_fs(cfg, config, budget):
     Fed_runner = FedRunner(data=data,
                            server_class=get_server_cls(trial_cfg),
                            client_class=get_client_cls(trial_cfg),
-                           config=trial_cfg.clone())
+                           config=trial_cfg.clone(),
+                           client_config=client_cfgs)
     results = Fed_runner.run()
     key1, key2 = trial_cfg.hpo.metric.split('.')
     return results[key1][key2]
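eval_in_fs is the single funnel both optimizer backends call, so this hunk is where client_config finally reaches FedRunner. A minimal direct-call sketch, assuming eval_in_fs accepts a dict-like trial config; the keys and values below are illustrative placeholders:

# Hypothetical direct call; parameter names/values are placeholders.
trial_config = {'train.optimizer.lr': 0.05, 'train.local_update_steps': 2}
perf = eval_in_fs(cfg, trial_config, 10, client_cfgs)  # budget = 10 rounds
# perf is results[key1][key2], i.e. the value of the metric named
# by cfg.hpo.metric.
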
30 changes: 0 additions & 30 deletions federatedscope/contrib/data/example.py

This file was deleted.

160 changes: 160 additions & 0 deletions federatedscope/contrib/data/mini_graph_dt.py
@@ -0,0 +1,160 @@
import os
import torch
import numpy as np

from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.datasets import TUDataset, MoleculeNet

from federatedscope.register import register_data
from federatedscope.core.data import DummyDataTranslator
from federatedscope.core.splitters.graph.scaffold_lda_splitter import \
    GenFeatures

# Run with mini_graph_dt:
# python federatedscope/main.py --cfg \
#     federatedscope/gfl/baseline/mini_graph_dc/fedavg.yaml --client_cfg \
#     federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml
# Test Accuracy: ~0.7


class MiniGraphDCDataset(InMemoryDataset):
    NAME = 'mini_graph_dt'
    DATA_NAME = ['BACE', 'BBBP', 'CLINTOX', 'ENZYMES', 'PROTEINS_full']
    IN_MEMORY_DATA = {}

    def __init__(self, root, splits=[0.8, 0.1, 0.1]):
        self.root = root
        self.splits = splits
        super(MiniGraphDCDataset, self).__init__(root)

    @property
    def processed_dir(self):
        return os.path.join(self.root, self.NAME, 'processed')

    @property
    def processed_file_names(self):
        return ['pre_transform.pt', 'pre_filter.pt']

    def __len__(self):
        return len(self.DATA_NAME)

    def __getitem__(self, idx):
        if idx not in self.IN_MEMORY_DATA:
            self.IN_MEMORY_DATA[idx] = {}
            for split in ['train', 'val', 'test']:
                split_data = self._load(idx, split)
                if split_data:
                    self.IN_MEMORY_DATA[idx][split] = split_data
        return self.IN_MEMORY_DATA[idx]

    def _load(self, idx, split):
        try:
            data = torch.load(
                os.path.join(self.processed_dir, str(idx), f'{split}.pt'))
        except:
            data = None
        return data

    def process(self):
        np.random.seed(0)
        for idx, name in enumerate(self.DATA_NAME):
            if name in ['BACE', 'BBBP', 'CLINTOX']:
                dataset = MoleculeNet(self.root, name)
                featurizer = GenFeatures()
                ds = []
                for graph in dataset:
                    graph = featurizer(graph)
                    ds.append(
                        Data(edge_index=graph.edge_index, x=graph.x,
                             y=graph.y))
                dataset = ds
                if name in ['BACE', 'BBBP']:
                    for i in range(len(dataset)):
                        dataset[i].y = dataset[i].y.long()
                if name in ['CLINTOX']:
                    for i in range(len(dataset)):
                        dataset[i].y = torch.argmax(
                            dataset[i].y).view(-1).unsqueeze(0)
            else:
                # Classification
                dataset = TUDataset(self.root, name)
                dataset = [
                    Data(edge_index=graph.edge_index, x=graph.x, y=graph.y)
                    for graph in dataset
                ]

            # We fix train/val/test
            index = np.random.permutation(np.arange(len(dataset)))
            train_idx = index[:int(len(dataset) * self.splits[0])]
            valid_idx = index[int(len(dataset) * self.splits[0]
                                  ):int(len(dataset) * sum(self.splits[:2]))]
            test_idx = index[int(len(dataset) * sum(self.splits[:2])):]

            if not os.path.isdir(os.path.join(self.processed_dir, str(idx))):
                os.makedirs(os.path.join(self.processed_dir, str(idx)))

            train_path = os.path.join(self.processed_dir, str(idx),
                                      'train.pt')
            valid_path = os.path.join(self.processed_dir, str(idx), 'val.pt')
            test_path = os.path.join(self.processed_dir, str(idx), 'test.pt')

            torch.save([dataset[i] for i in train_idx], train_path)
            torch.save([dataset[i] for i in valid_idx], valid_path)
            torch.save([dataset[i] for i in test_idx], test_path)

            print(name, len(dataset), dataset[0])

    def meta_info(self):
        return {
            'BACE': {
                'task': 'classification',
                'input_dim': 74,
                'output_dim': 2,
                'num_samples': 1513,
            },
            'BBBP': {
                'task': 'classification',
                'input_dim': 74,
                'output_dim': 2,
                'num_samples': 2039,
            },
            'CLINTOX': {
                'task': 'classification',
                'input_dim': 74,
                'output_dim': 2,
                'num_samples': 1478,
            },
            'ENZYMES': {
                'task': 'classification',
                'input_dim': 3,
                'output_dim': 6,
                'num_samples': 600,
            },
            'PROTEINS_full': {
                'task': 'classification',
                'input_dim': 3,
                'output_dim': 2,
                'num_samples': 1113,
            },
        }


def load_mini_graph_dt(config, client_cfgs=None):
    dataset = MiniGraphDCDataset(config.data.root)
    # Convert to dict
    datadict = {
        client_id + 1: dataset[client_id]
        for client_id in range(len(dataset))
    }
    config.merge_from_list(['federate.client_num', len(dataset)])
    translator = DummyDataTranslator(config, client_cfgs)

    return translator(datadict), config


def call_mini_graph_dt(config, client_cfgs):
    if config.data.type == "mini-graph-dc":
        data, modified_config = load_mini_graph_dt(config, client_cfgs)
        return data, modified_config


register_data("mini-graph-dc", call_mini_graph_dt)
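
Once registered, the new dataset is reachable purely through configuration. A minimal programmatic sketch, assuming the standard get_data pipeline dispatches to registered loaders; the root path is a placeholder:

# Hypothetical usage sketch; paths are placeholders.
from federatedscope.core.configs.config import global_cfg
from federatedscope.core.auxiliaries.data_builder import get_data

cfg = global_cfg.clone()
cfg.merge_from_list(['data.type', 'mini-graph-dc', 'data.root', 'data/'])

# call_mini_graph_dt fires because data.type matches the registered key;
# each of the five graph datasets becomes one client's data, and
# federate.client_num is set to 5 inside load_mini_graph_dt.
data, cfg = get_data(config=cfg.clone())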