pfedhpo and fts #509

Merged (12 commits, Feb 14, 2023)
4 changes: 4 additions & 0 deletions federatedscope/autotune/fts/__init__.py
@@ -0,0 +1,4 @@
from federatedscope.autotune.fts.server import FTSServer
from federatedscope.autotune.fts.client import FTSClient

__all__ = ['FTSServer', 'FTSClient']
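
Usage context: both workers are driven by cfg.hpo.fts.* options that client.py (below) reads. A minimal, hedged sketch of such a configuration, assuming this PR also registers these keys in FederatedScope's global config (that part is not shown in this hunk); all values are illustrative:

from federatedscope.core.configs.config import global_cfg
from federatedscope.autotune.fts import FTSServer, FTSClient  # this PR

cfg = global_cfg.clone()
cfg.hpo.working_folder = 'hpo_fts'   # where the local_* pickles are kept
cfg.hpo.fts.ss = 'fts_ss.yaml'       # search space file for
                                     # parse_search_space
cfg.hpo.fts.M = 100                  # number of random Fourier features
cfg.hpo.fts.diff = False             # score by loss decrease vs. baseline
cfg.hpo.fts.local_bo_max_iter = 50   # BO iterations per client
cfg.hpo.fts.local_bo_epochs = 1      # training epochs per BO evaluation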
238 changes: 238 additions & 0 deletions federatedscope/autotune/fts/client.py
@@ -0,0 +1,238 @@
import os
import logging
import json
import copy
import pickle
import numpy as np

from federatedscope.core.message import Message
from federatedscope.core.workers import Client

from federatedscope.autotune.fts.utils import *  # provides x2conf, LocalBO
from federatedscope.autotune.utils import parse_search_space
from federatedscope.core.auxiliaries.trainer_builder import get_trainer

logger = logging.getLogger(__name__)


class FTSClient(Client):
def __init__(self,
ID=-1,
server_id=None,
state=-1,
config=None,
data=None,
model=None,
device='cpu',
strategy=None,
is_unseen_client=False,
*args,
**kwargs):
super(FTSClient,
self).__init__(ID, server_id, state, config, data, model, device,
strategy, is_unseen_client, *args, **kwargs)
self.data = data
self.model = model
self.device = device
self._diff = config.hpo.fts.diff
self._init_model = copy.deepcopy(model)

# local file paths
self.local_bo_path = os.path.join(self._cfg.hpo.working_folder,
"local_bo_" + str(self.ID) + ".pkl")
self.local_init_path = os.path.join(
self._cfg.hpo.working_folder,
"local_init_" + str(self.ID) + ".pkl")
self.local_info_path = os.path.join(
self._cfg.hpo.working_folder, "local_info_" + str(self.ID) +
"_M_" + str(self._cfg.hpo.fts.M) + ".pkl")

# prepare search space and bounds
self._ss = parse_search_space(self._cfg.hpo.fts.ss)
self.dim = len(self._ss)
self.bounds = np.asarray([(0., 1.) for _ in self._ss])
self.pbounds = {}
for k, v in self._ss.items():
            if not (hasattr(v, 'lower') and hasattr(v, 'upper')):
                raise ValueError(
                    "Unsupported hyperparameter type {}".format(type(v)))
            if v.log:
                l, u = np.log10(v.lower), np.log10(v.upper)
            else:
                l, u = v.lower, v.upper
            self.pbounds[k] = (l, u)
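        # Worked example (illustrative): a log-scaled learning rate with
        # lower=1e-4 and upper=1e-1 gives pbounds['lr'] = (-4.0, -1.0);
        # candidate points live in the unit cube (self.bounds) and are
        # mapped back to these ranges by x2conf (see _obj_func) before
        # each evaluation.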

def _apply_hyperparams(self, hyperparams):
"""Apply the given hyperparameters
Arguments:
hyperparams (dict): keys are hyperparameter names \
and values are specific choices.
"""

cmd_args = []
for k, v in hyperparams.items():
cmd_args.append(k)
cmd_args.append(v)

self._cfg.defrost()
self._cfg.merge_from_list(cmd_args)
self._cfg.freeze(inform=False)

        # rebuild the trainer's context variables so the merged
        # hyperparameters take effect
        self.trainer.ctx.setup_vars()

def _get_new_trainer(self):
self.model = copy.deepcopy(self._init_model)
self.trainer = get_trainer(model=self.model,
data=self.data,
device=self.device,
config=self._cfg,
is_attacker=self.is_attacker,
monitor=self._monitor)

def _obj_func(self, x, return_eval=False):
"""
Run local evaluation, return some metric to maximize (e.g. val_acc)
"""
self._get_new_trainer()

        # fixed offset that turns the final validation loss into a score to
        # maximize when cfg.hpo.fts.diff is disabled
        baseline = 5.0
hyperparams = x2conf(x, self.pbounds, self._ss)
self._apply_hyperparams(hyperparams)

results_before = self.trainer.evaluate('val')
for _ in range(self._cfg.hpo.fts.local_bo_epochs):
sample_size, model_para_all, results = self.trainer.train()
results_after = self.trainer.evaluate('val')

if self._diff:
res = results_before['val_avg_loss'] \
- results_after['val_avg_loss']
else:
res = baseline - results_after['val_avg_loss']
if return_eval:
return res, results_after
else:
return res

def _generate_agent_info(self, rand_feats):
        logger.info('-' * 20 + ' generate info on client %d ' % self.ID +
                    '-' * 20)
v_kernel = self._cfg.hpo.fts.v_kernel
obs_noise = self._cfg.hpo.fts.obs_noise
M = self._cfg.hpo.fts.M
M_target = self._cfg.hpo.fts.M_target

# run standard BO locally
max_iter = self._cfg.hpo.fts.local_bo_max_iter
gp_opt_schedule = self._cfg.hpo.fts.gp_opt_schedule
pt = np.ones(max_iter + 5)
LocalBO(cid=self.ID,
f=self._obj_func,
bounds=self.bounds,
keys=list(self.pbounds.keys()),
gp_opt_schedule=gp_opt_schedule,
use_init=None,
log_file=self.local_bo_path,
save_init=True,
save_init_file=self.local_init_path,
pt=pt,
P_N=None,
ls=self._cfg.hpo.fts.ls,
var=self._cfg.hpo.fts.var,
g_var=self._cfg.hpo.fts.g_var,
N=self._cfg.federate.client_num - 1,
M_target=M_target).maximize(n_iter=max_iter, init_points=3)
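        # LocalBO (from fts.utils) logs every probed point and its score to
        # self.local_bo_path and, since save_init=True, stores the initial
        # design at self.local_init_path; both pickles are read back below
        # and in callback_funcs_for_model_para.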

# generate local RFF information
res = pickle.load(open(self.local_bo_path, "rb"))
ys = np.array(res["all"]["values"]).reshape(-1, 1)
        xs = np.array(res["all"]["params"])
xs, ys = xs[:max_iter], ys[:max_iter]
Phi = np.zeros((xs.shape[0], M))

s, b = rand_feats["s"], rand_feats["b"]
for i, x in enumerate(xs):
x = np.squeeze(x).reshape(1, -1)
features = np.sqrt(2 / M) * np.cos(np.squeeze(np.dot(x, s.T)) + b)
features = features / np.sqrt(np.inner(features, features))
features = np.sqrt(v_kernel) * features
Phi[i, :] = features

        # ridge-regression posterior over the RFF weights:
        #   Sigma_t = Phi^T Phi + sigma^2 I,  nu_t = Sigma_t^{-1} Phi^T y
        Sigma_t = np.dot(Phi.T, Phi) + obs_noise * np.identity(M)
        Sigma_t_inv = np.linalg.inv(Sigma_t)
        nu_t = np.dot(np.dot(Sigma_t_inv, Phi.T), ys)
        # draw one Thompson sample w ~ N(nu_t, obs_noise * Sigma_t^{-1})
        w_samples = np.random.multivariate_normal(np.squeeze(nu_t),
                                                  obs_noise * Sigma_t_inv, 1)
        pickle.dump(w_samples, open(self.local_info_path, "wb"))
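        # This single Thompson sample of the weight vector is what is later
        # returned to the server as 'agent_info' in
        # callback_funcs_for_model_para; the server side is not part of
        # this file.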

def callback_funcs_for_model_para(self, message: Message):
round, sender, content = message.state, message.sender, message.content
require_agent_infos = content['require_agent_infos']

# generate local info and init then send them to server
if require_agent_infos:
rand_feat = content['random_feats']
self._generate_agent_info(rand_feat)
agent_info = pickle.load(open(self.local_info_path, "rb"))
agent_init = pickle.load(open(self.local_init_path, "rb"))
content = {
'is_required_agent_info': True,
'agent_info': agent_info,
'agent_init': agent_init,
}

# local run on given hyper-param and return performance
else:
x_max = content['x_max']
curr_y, eval_res = self._obj_func(x_max, return_eval=True)
content = {
'is_required_agent_info': False,
'curr_y': curr_y,
}
hyper_param = x2conf(x_max, self.pbounds, self._ss)
logger.info('{Client: %d, ' % self.ID +
'GP_opt_iter: %d, ' % round + 'Params: ' +
str(hyper_param) + ', Perform: ' + str(curr_y) + '}')
logger.info(
self._monitor.format_eval_res(eval_res,
rnd=self.state,
role='Client #{}'.format(
self.ID),
return_raw=True))

self.state = round
self.comm_manager.send(
Message(msg_type='model_para',
sender=self.ID,
receiver=[sender],
state=self.state,
content=content))

def callback_funcs_for_evaluate(self, message: Message):
round, sender, content = \
message.state, message.sender, message.content
require_agent_infos = content['require_agent_infos']
        assert not require_agent_infos, \
            "Cannot evaluate before the agents' information has been " \
            "collected"

self.state = message.state
self._obj_func(content['x_max'])

metrics = {}
for split in self._cfg.eval.split:
eval_metrics = self.trainer.evaluate(target_data_split_name=split)
for key in eval_metrics:
if self._cfg.federate.mode == 'distributed':
logger.info('Client #{:d}: (Evaluation ({:s} set) at '
'Round #{:d}) {:s} is {:.6f}'.format(
self.ID, split, self.state, key,
eval_metrics[key]))
metrics.update(**eval_metrics)

self.comm_manager.send(
Message(msg_type='metrics',
sender=self.ID,
receiver=[sender],
state=self.state,
content=metrics))
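
Taken together, the two callbacks above imply the following server-to-client message contract (a sketch inferred from this file alone; FTSServer is not shown in this hunk, and the sizes below are illustrative):

import numpy as np
from federatedscope.core.message import Message

M, dim = 100, 2                           # illustrative sizes
s = np.random.randn(M, dim)               # shared RFF frequencies
b = np.random.uniform(0, 2 * np.pi, M)    # shared RFF phases

# Phase 1: ask a client to run local BO and build its agent info.
msg_info = Message(msg_type='model_para', sender=0, receiver=[1], state=0,
                   content={'require_agent_infos': True,
                            'random_feats': {'s': s, 'b': b}})

# Later rounds: ask the client to evaluate a proposed point in [0, 1]^dim.
msg_eval = Message(msg_type='model_para', sender=0, receiver=[1], state=1,
                   content={'require_agent_infos': False,
                            'x_max': np.random.rand(dim)})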