[Feature] Add wrap FedEx #137

Merged · 11 commits · Jun 6, 2022
50 changes: 42 additions & 8 deletions federatedscope/autotune/algos.py
@@ -1,15 +1,11 @@
import os
import logging
from copy import deepcopy
from contextlib import redirect_stdout
from yacs.config import CfgNode as CN
import threading
from itertools import product
import math

import yaml

import numpy as np
import torch

from federatedscope.core.auxiliaries.utils import setup_seed
from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.core.auxiliaries.worker_builder import get_client_cls, get_server_cls
@@ -74,6 +70,8 @@ def get_scheduler(init_cfg):
scheduler = SuccessiveHalvingAlgo(init_cfg)
elif init_cfg.hpo.scheduler == 'pbt':
scheduler = PBT(init_cfg)
elif init_cfg.hpo.scheduler == 'wrap_sha':
scheduler = SHAWrapFedex(init_cfg)
return scheduler
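
For context (not part of the diff): once this branch lands, the new scheduler is reachable with a one-line config change. A minimal sketch, assuming the usual FederatedScope entry flow and that get_scheduler is exported from federatedscope.autotune (the wrap_sha run additionally needs the hpo.table options configured, as in the example YAMLs below):

    from federatedscope.core.configs.config import global_cfg
    from federatedscope.autotune import get_scheduler  # assumed import path

    cfg = global_cfg.clone()
    cfg.hpo.scheduler = 'wrap_sha'  # dispatches to SHAWrapFedex
    scheduler = get_scheduler(cfg)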


@@ -87,7 +85,6 @@ def __init__(self, cfg):
"""

self._cfg = cfg

self._search_space = parse_search_space(self._cfg.hpo.ss)

self._init_configs = self._setup()
@@ -310,6 +307,43 @@ def _generate_next_population(self, configs, perfs):
return next_population


class SHAWrapFedex(SuccessiveHalvingAlgo):
def _cache_yaml(self):
        # Cache one sampled arm table per SHA candidate as a YAML file
for idx in range(self._cfg.hpo.table.cand):
sample_ss = parse_search_space(
self._cfg.hpo.table.ss).sample_configuration(
self._cfg.hpo.table.num)
# Convert Configuration to CfgNode
tmp_cfg = CN()
for arm, configuration in enumerate(sample_ss):
tmp_cfg[f'arm{arm}'] = CN()
for key, value in configuration.get_dictionary().items():
tmp_cfg[f'arm{arm}'][key] = value

with open(
os.path.join(self._cfg.hpo.working_folder,
f'{idx}_tmp_grid_search_space.yaml'),
'w') as f:
with redirect_stdout(f):
print(tmp_cfg.dump())

    def _setup(self):
        self._cache_yaml()
        init_configs = super(SHAWrapFedex, self)._setup()

        # Attach each trial to its own cached arm table and checkpoint path
        for idx, trial_cfg in enumerate(init_configs):
            trial_cfg['hpo.table.idx'] = idx
            trial_cfg['hpo.fedex.ss'] = os.path.join(
                self._cfg.hpo.working_folder,
                f"{trial_cfg['hpo.table.idx']}_tmp_grid_search_space.yaml")
            trial_cfg['federate.save_to'] = os.path.join(
                self._cfg.hpo.working_folder, f'idx_{idx}.pth')
        logger.info(init_configs)
        return init_configs
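
To make the hand-off concrete, here is an illustrative sketch (not code from this PR) of what a trial sees: its hpo.fedex.ss points at the YAML that _cache_yaml wrote, so the arms can be recovered with plain yaml. The hpo_working folder name is an assumed value of cfg.hpo.working_folder:

    import yaml

    # Load the arm table cached for trial 0 by _cache_yaml above
    with open('hpo_working/0_tmp_grid_search_space.yaml', 'r') as f:
        table = yaml.safe_load(f)
    # table is a dict like {'arm0': {...}, ..., 'arm3': {...}},
    # one entry per sampled configuration (hpo.table.num = 4)
    for arm, params in table.items():
        print(arm, params)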


# TODO: refactor PBT to enable async parallel
#class PBT(IterativeScheduler):
#    """Population-based training (the full paper "Population Based Training of Neural Networks" can be found at https://arxiv.org/abs/1711.09846) tailored to the FL setting, where, in each iteration, only a limited number of communication rounds are allowed for each trial (we will provide the asynchronous version later).
@@ -383,4 +417,4 @@ def _generate_next_population(self, configs, perfs):
#
# next_generation.append(new_cfg)
#
#        return next_generation
52 changes: 52 additions & 0 deletions federatedscope/autotune/utils.py
@@ -83,3 +83,55 @@ def summarize_hpo_results(configs, perfs, white_list=None, desc=False):
d = sorted(d, key=lambda ele: ele[-1], reverse=desc)
df = pd.DataFrame(d, columns=cols)
return df


def parse_logs(file_list):
    import numpy as np
    import matplotlib.pyplot as plt

    FONTSIZE = 40
    MARKSIZE = 25

    def process(file):
        history = []
        with open(file, 'r') as F:
            for line in F:
                try:
                    _, line = line.split('INFO: ')
                    # NOTE: eval trusts the log contents; acceptable for
                    # locally produced result files
                    config = eval(line[line.find('{'):line.find('}') + 1])
                    performance = float(
                        line[line.find('performance'):].split(' ')[1])
                    history.append((config, performance))
                except (ValueError, SyntaxError, IndexError):
                    # Skip lines that do not report a trial result
                    continue
        best_seen = np.inf
        tol_budget = 0
        tmp_b = 0  # round budget of the previous trial
        x, y = [], []

        for config, performance in history:
            tol_budget += config['federate.total_round_num']
            if best_seen > performance or config[
                    'federate.total_round_num'] > tmp_b:
                best_seen = performance
            x.append(tol_budget)
            y.append(best_seen)
            tmp_b = config['federate.total_round_num']
        return np.array(x) / tol_budget, np.array(y)

    # Draw
    plt.figure(figsize=(10, 7.5))
    plt.xticks(fontsize=FONTSIZE)
    plt.yticks(fontsize=FONTSIZE)

    plt.xlabel('Fraction of budget', size=FONTSIZE)
    plt.ylabel('Loss', size=FONTSIZE)

    for file in file_list:
        x, y = process(file)
        plt.plot(x, y, linewidth=1, markersize=MARKSIZE)
    plt.legend(file_list, fontsize=23, loc='lower right')
    plt.savefig('exp2.pdf', bbox_inches='tight')
    plt.close()
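
A hedged usage example for the helper above (the log file names are placeholders):

    # Compare two HPO runs by their incumbent loss curves; writes exp2.pdf
    parse_logs(['sha_cora.log', 'sha_wrap_fedex_cora.log'])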




2 changes: 1 addition & 1 deletion federatedscope/contrib/trainer/example.py
@@ -1,5 +1,5 @@
from federatedscope.register import register_trainer
- from federatedscope.core.trainers.trainer import GeneralTorchTrainer
+ from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer


# Build your trainer here.
7 changes: 7 additions & 0 deletions federatedscope/core/configs/cfg_hpo.py
@@ -46,6 +46,13 @@ def extend_hpo_cfg(cfg):
cfg.hpo.fedex.num_arms = 16
cfg.hpo.fedex.diff = False

    # Table of candidate arms for the wrapped FedEx (used by wrap_sha)
    cfg.hpo.table = CN()
    cfg.hpo.table.ss = ''  # search space the candidate arms are sampled from
    cfg.hpo.table.num = 4  # number of arms per cached table
    cfg.hpo.table.cand = 81  # number of candidate tables, one per trial
    cfg.hpo.table.idx = 0  # index of the table assigned to the current trial


def assert_hpo_cfg(cfg):
# HPO related
11 changes: 11 additions & 0 deletions federatedscope/example_configs/cora/hpo_ss_fedex.yaml
@@ -0,0 +1,11 @@
hpo.fedex.eta0:
type: cate
choices: [-1.0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0]
hpo.fedex.gamma:
type: float
lower: 0.0
upper: 1.0
log: False
hpo.fedex.diff:
type: cate
choices: [True, False]
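
As a reading aid (my sketch, not part of the PR): a space like the one above presumably becomes a ConfigSpace ConfigurationSpace inside parse_search_space, with 'cate' mapping to a categorical and 'float' to a uniform float hyperparameter:

    from ConfigSpace import ConfigurationSpace
    from ConfigSpace.hyperparameters import (CategoricalHyperparameter,
                                             UniformFloatHyperparameter)

    cs = ConfigurationSpace()
    cs.add_hyperparameters([
        CategoricalHyperparameter(
            'hpo.fedex.eta0',
            choices=[-1.0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0]),
        UniformFloatHyperparameter(
            'hpo.fedex.gamma', lower=0.0, upper=1.0, log=False),
        CategoricalHyperparameter('hpo.fedex.diff', choices=[True, False]),
    ])
    print(cs.sample_configuration())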
12 changes: 12 additions & 0 deletions federatedscope/example_configs/cora/hpo_ss_fedex_arm.yaml
@@ -0,0 +1,12 @@
optimizer.lr:
type: cate
choices: [0.01, 0.01668, 0.02783, 0.04642, 0.07743, 0.12915, 0.21544, 0.35938, 0.59948, 1.0]
optimizer.weight_decay:
type: cate
choices: [0.0, 0.001, 0.01, 0.1]
model.dropout:
type: cate
choices: [0.0, 0.5]
federate.local_update_steps:
type: cate
choices: [1, 2, 3, 4, 5, 6, 7, 8]
3 changes: 3 additions & 0 deletions federatedscope/example_configs/cora/hpo_ss_fedex_arm_table.yaml
@@ -0,0 +1,3 @@
hpo.table.idx:
type: cate
choices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]
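
The 81 indices above simply enumerate the candidate tables (hpo.table.cand); a throwaway snippet to regenerate the choices line if that count ever changes (hypothetical helper, not in the PR):

    cand = 81  # keep in sync with cfg.hpo.table.cand
    print('choices:', list(range(cand)))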
4 changes: 4 additions & 0 deletions federatedscope/example_configs/cora/hpo_ss_fedex_grid.yaml
@@ -0,0 +1,4 @@
optimizer.lr: [0.01, 0.01668, 0.02783, 0.04642, 0.07743, 0.12915, 0.21544, 0.35938, 0.59948, 1.0]
optimizer.weight_decay: [0.0, 0.001, 0.01, 0.1]
model.dropout: [0.0, 0.5]
federate.local_update_steps: [1, 2, 3, 4, 5, 6, 7, 8]
14 changes: 14 additions & 0 deletions federatedscope/example_configs/cora/hpo_ss_sha.yaml
@@ -0,0 +1,14 @@
optimizer.lr:
type: float
lower: 0.01
upper: 1.0
log: True
optimizer.weight_decay:
type: cate
choices: [0.0, 0.001, 0.01, 0.1]
model.dropout:
type: cate
choices: [0.0, 0.5]
federate.local_update_steps:
type: cate
choices: [1, 2, 3, 4, 5, 6, 7, 8]
8 changes: 8 additions & 0 deletions federatedscope/example_configs/cora/run.sh
@@ -0,0 +1,8 @@
# SHA baseline
python hpo.py --cfg federatedscope/example_configs/cora/sha.yaml

# SHA wrapping FedEx (tunes the FedEx-related hyperparameters)
python hpo.py --cfg federatedscope/example_configs/cora/sha_wrap_fedex.yaml

# SHA wrapping FedEx (tunes the client-side arms)
python hpo.py --cfg federatedscope/example_configs/cora/sha_wrap_fedex_arm.yaml
43 changes: 43 additions & 0 deletions federatedscope/example_configs/cora/sha.yaml
@@ -0,0 +1,43 @@
use_gpu: True
device: 3
early_stop:
patience: 100
seed: 12345
federate:
mode: standalone
make_global_eval: True
client_num: 5
local_update_steps: 1
total_round_num: 500
share_local_model: True
online_aggr: True
use_diff: True
data:
root: data/
type: cora
splitter: 'louvain'
batch_size: 1
model:
type: gcn
hidden: 64
dropout: 0.5
out_channels: 7
optimizer:
lr: 0.25
weight_decay: 0.0005
criterion:
type: CrossEntropyLoss
trainer:
type: nodefullbatch_trainer
eval:
freq: 1
metrics: ['acc', 'correct', 'f1']
split: ['test', 'val', 'train']
hpo:
scheduler: sha
num_workers: 0
init_cand_num: 81
ss: 'federatedscope/example_configs/cora/hpo_ss_sha.yaml'
sha:
budgets: [2, 4, 12, 36]
metric: 'server_global_eval.val_avg_loss'
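
To read the budgets above together with init_cand_num (my arithmetic, assuming SHA's default elimination rate of 3): 81 candidates train for 2 rounds, the surviving 27 for 4 more, 9 for 12, and the final 3 for 36, i.e. 81·2 + 27·4 + 9·12 + 3·36 = 486 communication rounds in total.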
46 changes: 46 additions & 0 deletions federatedscope/example_configs/cora/sha_wrap_fedex.yaml
@@ -0,0 +1,46 @@
use_gpu: True
device: 3
early_stop:
patience: 100
seed: 12345
federate:
mode: standalone
make_global_eval: True
client_num: 5
local_update_steps: 1
total_round_num: 500
share_local_model: True
online_aggr: True
use_diff: True
data:
root: data/
type: cora
splitter: 'louvain'
batch_size: 1
model:
type: gcn
hidden: 64
dropout: 0.5
out_channels: 7
optimizer:
lr: 0.25
weight_decay: 0.0005
criterion:
type: CrossEntropyLoss
trainer:
type: nodefullbatch_trainer
eval:
freq: 1
metrics: ['acc', 'correct', 'f1']
split: ['test', 'val', 'train']
hpo:
scheduler: sha
num_workers: 0
init_cand_num: 81
ss: 'federatedscope/example_configs/cora/hpo_ss_fedex.yaml'
sha:
budgets: [2, 4, 12, 36]
fedex:
use: True
ss: 'federatedscope/example_configs/cora/hpo_ss_fedex_grid.yaml'
metric: 'server_global_eval.val_avg_loss'
52 changes: 52 additions & 0 deletions federatedscope/example_configs/cora/sha_wrap_fedex_arm.yaml
@@ -0,0 +1,52 @@
use_gpu: True
device: 3
early_stop:
patience: 100
seed: 12345
federate:
mode: standalone
make_global_eval: True
client_num: 5
local_update_steps: 1
total_round_num: 500
share_local_model: True
online_aggr: True
use_diff: True
data:
root: data/
type: cora
splitter: 'louvain'
batch_size: 1
model:
type: gcn
hidden: 64
dropout: 0.5
out_channels: 7
optimizer:
lr: 0.25
weight_decay: 0.0005
criterion:
type: CrossEntropyLoss
trainer:
type: nodefullbatch_trainer
eval:
freq: 1
metrics: ['acc', 'correct', 'f1']
split: ['test', 'val', 'train']
hpo:
scheduler: wrap_sha
num_workers: 0
init_cand_num: 81
ss: 'federatedscope/example_configs/cora/hpo_ss_fedex_arm_table.yaml'
table:
ss: 'federatedscope/example_configs/cora/hpo_ss_fedex_arm.yaml'
num: 4
cand: 81
sha:
budgets: [2, 4, 12, 36]
fedex:
use: True
diff: False
eta0: 0.050
gamma: 0.495861
metric: 'server_global_eval.val_avg_loss'
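
With this config, SHAWrapFedex._cache_yaml writes one arm table per candidate before the search starts. A small sketch of the expected cache layout (the folder name is an assumed value of cfg.hpo.working_folder):

    import os

    working_folder = 'hpo_working'  # assumed cfg.hpo.working_folder
    # One YAML of hpo.table.num = 4 arms per candidate (hpo.table.cand = 81)
    expected = [
        os.path.join(working_folder, f'{idx}_tmp_grid_search_space.yaml')
        for idx in range(81)
    ]
    print(expected[:3])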
11 changes: 11 additions & 0 deletions federatedscope/example_configs/femnist/hpo_ss_fedex.yaml
@@ -0,0 +1,11 @@
hpo.fedex.eta0:
type: cate
choices: [-1.0, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0]
hpo.fedex.gamma:
type: float
lower: 0.0
upper: 1.0
log: False
hpo.fedex.diff:
type: cate
choices: [True, False]
15 changes: 15 additions & 0 deletions federatedscope/example_configs/femnist/hpo_ss_fedex_arm.yaml
@@ -0,0 +1,15 @@
optimizer.lr:
type: cate
choices: [0.01, 0.01668, 0.02783, 0.04642, 0.07743, 0.12915, 0.21544, 0.35938, 0.59948, 1.0]
optimizer.weight_decay:
type: cate
choices: [0.0, 0.001, 0.01, 0.1]
model.dropout:
type: cate
choices: [0.0, 0.5]
federate.local_update_steps:
type: cate
choices: [1, 2, 3, 4]
data.batch_size:
type: cate
choices: [16, 32, 64]