Enabled personalized policy for fedex #481

Merged 4 commits on Jan 13, 2023
Changes from all commits
1 change: 1 addition & 0 deletions federatedscope/autotune/fedex/client.py
@@ -58,6 +58,7 @@ def callback_funcs_for_model_para(self, message: Message):
                                     return_raw=True))

        results['arms'] = arms
        results['client_id'] = self.ID - 1
Collaborator: Why?
        content = (sample_size, model_para_all, results)
        self.comm_manager.send(
            Message(msg_type='model_para',
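For orientation, a minimal sketch of the feedback payload after this change (placeholder values, not the repo's exact schema). Client IDs in FederatedScope start from 1, so the zero-indexed `client_id` lets the server address row `ID - 1` of its per-client policy tensors:

sample_size, model_para_all = 128, {}  # placeholders for the real payload
results = {
    'arms': [2, 0],      # sampled index for each hyperparameter
    'client_id': 4 - 1,  # client with ID 4 -> row 3 of the policy tensors
    'val_avg_loss_before': 0.92,
    'val_avg_loss_after': 0.85,
}
content = (sample_size, model_para_all, results)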
211 changes: 150 additions & 61 deletions federatedscope/autotune/fedex/server.py
@@ -7,10 +7,12 @@
import numpy as np
from numpy.linalg import norm
from scipy.special import logsumexp
import torch

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.autotune.fedex.utils import HyperNet

logger = logging.getLogger(__name__)

@@ -38,6 +40,10 @@ def __init__(self,
                 strategy=None,
                 **kwargs):

        super(FedExServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)

        # initialize action space and the policy
        with open(config.hpo.fedex.ss, 'r') as ips:
            ss = yaml.load(ips, Loader=yaml.FullLoader)
@@ -66,21 +72,39 @@ def __init__(self,
        self._cutoff = config.hpo.fedex.cutoff
        self._baseline = config.hpo.fedex.gamma
        self._diff = config.hpo.fedex.diff
        self._z = [np.full(size, -np.log(size)) for size in sizes]
        self._theta = [np.exp(z) for z in self._z]
        self._store = [0.0 for _ in sizes]
        if self._cfg.hpo.fedex.psn:
            # personalized policy
            # TODO: client-wise RFF
            self._client_encodings = torch.randn(
                (client_num, 8), device=device) / np.sqrt(8)
            self._policy_net = HyperNet(
                self._client_encodings.shape[-1],
                sizes,
                client_num,
                device,
            ).to(device)
            self._policy_net.eval()
            theta4stat = [
                theta.detach().cpu().numpy()
                for theta in self._policy_net(self._client_encodings)
            ]
            self._pn_optimizer = torch.optim.Adam(
                self._policy_net.parameters(),
                lr=self._cfg.hpo.fedex.pi_lr,
                weight_decay=1e-5)
        else:
            self._z = [np.full(size, -np.log(size)) for size in sizes]
            self._theta = [np.exp(z) for z in self._z]
            theta4stat = self._theta
        self._store = [0.0 for _ in sizes]
        self._stop_exploration = False
        self._trace = {
            'global': [],
            'refine': [],
            'entropy': [self.entropy()],
            'mle': [self.mle()]
            'entropy': [self.entropy(theta4stat)],
            'mle': [self.mle(theta4stat)]
        }
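As a reading aid: in the personalized branch, every hyperparameter gets a softmax head that maps the fixed random client encodings to a (client_num, size) probability matrix, one row per client. A self-contained sketch with made-up sizes (a stand-in for HyperNet, which is defined in utils.py below):

import numpy as np
import torch
from torch import nn

client_num, enc_dim, sizes = 5, 8, [3, 4]  # illustrative values only
encodings = torch.randn(client_num, enc_dim) / np.sqrt(enc_dim)
heads = nn.ModuleList(
    nn.Sequential(nn.Linear(enc_dim, s), nn.Softmax(dim=-1)) for s in sizes)
thetas = [head(encodings) for head in heads]  # one matrix per hyperparam
assert thetas[0].shape == (client_num, sizes[0])
assert torch.allclose(thetas[0].sum(-1), torch.ones(client_num))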

        super(FedExServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)

        if self._cfg.federate.restore_from != '':
            if not os.path.exists(self._cfg.federate.restore_from):
                logger.warning(f'Invalid `restore_from`:'
@@ -91,26 +115,45 @@ def __init__(self,
                           + "_fedex.yaml"
            with open(pi_ckpt_path, 'r') as ips:
                ckpt = yaml.load(ips, Loader=yaml.FullLoader)
            self._z = [np.asarray(z) for z in ckpt['z']]
            self._theta = [np.exp(z) for z in self._z]
            self._store = ckpt['store']
            if self._cfg.hpo.fedex.psn:
                psn_pi_ckpt_path = self._cfg.federate.restore_from[
                    :self._cfg.federate.restore_from.rfind('.')] \
                    + "_pfedex.pt"
                psn_pi = torch.load(psn_pi_ckpt_path, map_location=device)
                self._client_encodings = psn_pi['client_encodings']
                self._policy_net.load_state_dict(psn_pi['policy_net'])
            else:
                self._z = [np.asarray(z) for z in ckpt['z']]
                self._theta = [np.exp(z) for z in self._z]
                self._store = ckpt['store']
            self._stop_exploration = ckpt['stop']
            self._trace = dict()
            self._trace['global'] = ckpt['global']
            self._trace['refine'] = ckpt['refine']
            self._trace['entropy'] = ckpt['entropy']
            self._trace['mle'] = ckpt['mle']

    def entropy(self):
        entropy = 0.0
        for probs in product(*(theta[theta > 0.0] for theta in self._theta)):
            prob = np.prod(probs)
            entropy -= prob * np.log(prob)
        return entropy

    def mle(self):

        return np.prod([theta.max() for theta in self._theta])

    def entropy(self, thetas):
        if self._cfg.hpo.fedex.psn:
            entropy = 0.0
            for i in range(thetas[0].shape[0]):
                for probs in product(*(theta[i][theta[i] > 0.0]
                                       for theta in thetas)):
                    prob = np.prod(probs)
                    entropy -= prob * np.log(prob)
            return entropy / float(thetas[0].shape[0])
        else:
            entropy = 0.0
            for probs in product(*(theta[theta > 0.0] for theta in thetas)):
                prob = np.prod(probs)
                entropy -= prob * np.log(prob)
            return entropy

    def mle(self, thetas):
        if self._cfg.hpo.fedex.psn:
            return np.prod([theta.max(-1) for theta in thetas], 0).mean()
        else:
            return np.prod([theta.max() for theta in thetas])
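For intuition, a quick hand-check of the two statistics in the non-personalized case (illustrative numbers; the entropy of a product of independent categoricals is the sum of the per-factor entropies):

import numpy as np
from itertools import product

thetas = [np.array([0.7, 0.3]), np.array([0.5, 0.25, 0.25])]
entropy = 0.0
for probs in product(*(t[t > 0.0] for t in thetas)):
    p = np.prod(probs)
    entropy -= p * np.log(p)
mle = np.prod([t.max() for t in thetas])
print(entropy, mle)  # ~1.6506 (= 0.6109 + 1.0397) and 0.35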

    def trace(self, key):
        '''returns trace of one of three tracked quantities
@@ -122,15 +165,18 @@ def trace(self, key):

        return np.array(self._trace[key])

    def sample(self):
        """samples from configs using current probability vector"""
    def sample(self, thetas):
        """samples from configs using current probability vector
        Arguments:
            thetas (list): probabilities for the hyperparameters.
        """

        # determine index
        if self._stop_exploration:
            cfg_idx = [theta.argmax() for theta in self._theta]
            cfg_idx = [theta.argmax() for theta in thetas]
        else:
            cfg_idx = [
                np.random.choice(len(theta), p=theta) for theta in self._theta
                np.random.choice(len(theta), p=theta) for theta in thetas
            ]
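In short: while exploring, arms are drawn from the categorical distributions; once the entropy cutoff flips `_stop_exploration` (see update_policy), sampling degenerates to the argmax. Illustratively:

import numpy as np

theta = np.array([0.1, 0.6, 0.3])
explore = np.random.choice(len(theta), p=theta)  # stochastic arm
exploit = int(theta.argmax())                    # deterministic arm (= 1)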

        # get the sampled value(s)
@@ -178,9 +224,18 @@ def broadcast_model_para(self,
            model_para = self.model.state_dict()

        # sample the hyper-parameter config specific to the clients

        if self._cfg.hpo.fedex.psn:
            self._policy_net.train()
            self._pn_optimizer.zero_grad()
            self._theta = self._policy_net(self._client_encodings)
        for rcv_idx in receiver:
            cfg_idx, sampled_cfg = self.sample()
            if self._cfg.hpo.fedex.psn:
                cfg_idx, sampled_cfg = self.sample([
                    theta[rcv_idx - 1].detach().cpu().numpy()
                    for theta in self._theta
                ])
            else:
                cfg_idx, sampled_cfg = self.sample(self._theta)
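A condensed sketch of the per-client path above (made-up shapes): the server slices out the receiver's own row, detaches it from the autograd graph, and converts it to numpy so `sample()` can treat it like the ordinary policy. Client `rcv_idx` maps to row `rcv_idx - 1`, matching the `client_id = self.ID - 1` the client reports back:

import numpy as np
import torch

client_num, sizes = 5, [3, 4]  # illustrative values only
thetas = [torch.softmax(torch.randn(client_num, s), -1) for s in sizes]
rcv_idx = 4                    # receiver's 1-indexed client ID
client_thetas = [t[rcv_idx - 1].detach().cpu().numpy() for t in thetas]
cfg_idx = [np.random.choice(len(p), p=p) for p in client_thetas]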
            content = {
                'model_param': model_para,
                "arms": cfg_idx,
@@ -225,6 +280,7 @@ def update_policy(self, feedbacks):
        """

        index = [elem['arms'] for elem in feedbacks]
        cids = [elem['client_id'] for elem in feedbacks]
        before = np.asarray(
            [elem['val_avg_loss_before'] for elem in feedbacks])
        after = np.asarray([elem['val_avg_loss_after'] for elem in feedbacks])
@@ -246,41 +302,60 @@ def update_policy(self, feedbacks):
            self._trace['mle'].append(1.0)
            return

        for i, (z, theta) in enumerate(zip(self._z, self._theta)):
            grad = np.zeros(len(z))
            for idx, s, w in zip(index,
                                 after - before if self._diff else after,
                                 weight):
                grad[idx[i]] += w * (s - baseline) / theta[idx[i]]
            if self._sched == 'adaptive':
                self._store[i] += norm(grad, float('inf'))**2
                denom = np.sqrt(self._store[i])
            elif self._sched == 'aggressive':
                denom = 1.0 if np.all(
                    grad == 0.0) else norm(grad, float('inf'))
            elif self._sched == 'auto':
                self._store[i] += 1.0
                denom = np.sqrt(self._store[i])
            elif self._sched == 'constant':
                denom = 1.0
            elif self._sched == 'scale':
                denom = 1.0 / np.sqrt(
                    2.0 * np.log(len(grad))) if len(grad) > 1 else float('inf')
            else:
                raise NotImplementedError
            eta = self._eta0[i] / denom
            z -= eta * grad
            z -= logsumexp(z)
            self._theta[i] = np.exp(z)

        self._trace['entropy'].append(self.entropy())
        self._trace['mle'].append(self.mle())
        if self._cfg.hpo.fedex.psn:
            # policy gradients
            pg_obj = .0
            for i, theta in enumerate(self._theta):
                for idx, cidx, s, w in zip(
                        index, cids, after - before if self._diff else after,
                        weight):
                    pg_obj += w * -1.0 * (s - baseline) * torch.log(
                        torch.clip(theta[cidx][idx[i]], min=1e-8, max=1.0))
            pg_loss = -1.0 * pg_obj
            pg_loss.backward()
            self._pn_optimizer.step()
            self._policy_net.eval()
            thetas4stat = [
                theta.detach().cpu().numpy()
                for theta in self._policy_net(self._client_encodings)
            ]
        else:
            for i, (z, theta) in enumerate(zip(self._z, self._theta)):
                grad = np.zeros(len(z))
                for idx, s, w in zip(index,
                                     after - before if self._diff else after,
                                     weight):
                    grad[idx[i]] += w * (s - baseline) / theta[idx[i]]
                if self._sched == 'adaptive':
                    self._store[i] += norm(grad, float('inf'))**2
                    denom = np.sqrt(self._store[i])
                elif self._sched == 'aggressive':
                    denom = 1.0 if np.all(
                        grad == 0.0) else norm(grad, float('inf'))
                elif self._sched == 'auto':
                    self._store[i] += 1.0
                    denom = np.sqrt(self._store[i])
                elif self._sched == 'constant':
                    denom = 1.0
                elif self._sched == 'scale':
                    denom = 1.0 / np.sqrt(2.0 * np.log(len(grad))) if len(
                        grad) > 1 else float('inf')
                else:
                    raise NotImplementedError
                eta = self._eta0[i] / denom
                z -= eta * grad
                z -= logsumexp(z)
                self._theta[i] = np.exp(z)
            thetas4stat = self._theta
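The personalized branch is a REINFORCE-style update: each reported loss `s` (or improvement, when `diff` is on) is compared to the baseline, and the log-probability of the arm the client actually pulled is increased when `s - baseline` is negative. A minimal self-contained sketch (illustrative tensors, not the repo's API):

import torch

logits = torch.randn(5, 3, requires_grad=True)  # 5 clients, 3 arms
theta = torch.softmax(logits, dim=-1)           # differentiable policy
feedback = [(2, 1, 0.42), (0, 0, 0.57)]         # (client, arm, loss s)
baseline, w = 0.5, 0.5
pg_obj = sum(w * -1.0 * (s - baseline) *
             torch.log(theta[cid, arm].clip(min=1e-8, max=1.0))
             for cid, arm, s in feedback)
(-pg_obj).backward()  # minimizing -pg_obj ascends the objective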

        self._trace['entropy'].append(self.entropy(thetas4stat))
        self._trace['mle'].append(self.mle(thetas4stat))
        if self._trace['entropy'][-1] < self._cutoff:
            self._stop_exploration = True

        logger.info(
            'Server: Updated policy as {} with entropy {:f} and mle {:f}'.
            format(self._theta, self._trace['entropy'][-1],
            format(thetas4stat, self._trace['entropy'][-1],
                   self._trace['mle'][-1]))
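The non-personalized branch keeps the original exponentiated-gradient update on the log-weights z, renormalized with logsumexp so that theta = exp(z) stays a distribution. A hand-rolled sketch with illustrative numbers:

import numpy as np
from scipy.special import logsumexp

z = np.full(3, -np.log(3))         # uniform policy in log space
grad = np.array([0.2, -0.1, 0.0])  # importance-weighted loss signal
eta = 0.5
z -= eta * grad
z -= logsumexp(z)                  # renormalize: exp(z) sums to 1
theta = np.exp(z)
assert np.isclose(theta.sum(), 1.0)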

    def check_and_move_on(self,
@@ -413,9 +488,23 @@ def check_and_save(self):
        if self._cfg.federate.save_to != '':
            # save the policy
            ckpt = dict()
            z_list = [z.tolist() for z in self._z]
            ckpt['z'] = z_list
            ckpt['store'] = self._store
            if self._cfg.hpo.fedex.psn:
                psn_pi_ckpt_path = self._cfg.federate.save_to[
                    :self._cfg.federate.save_to.rfind('.')] + "_pfedex.pt"
                torch.save(
                    {
                        'client_encodings': self._client_encodings,
                        'policy_net': self._policy_net.state_dict()
                    }, psn_pi_ckpt_path)
            else:
                z_list = [z.tolist() for z in self._z]
                ckpt['z'] = z_list
                ckpt['store'] = self._store
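A personalized run therefore writes its policy to a second artifact next to the usual `_fedex.yaml` checkpoint; e.g. with a hypothetical `save_to` of 'exp/ckpt.pt':

save_to = 'exp/ckpt.pt'  # hypothetical config value
psn_pi_ckpt_path = save_to[:save_to.rfind('.')] + "_pfedex.pt"
assert psn_pi_ckpt_path == 'exp/ckpt_pfedex.pt'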
            ckpt['stop'] = self._stop_exploration
            ckpt['global'] = self.trace('global').tolist()
            ckpt['refine'] = self.trace('refine').tolist()
87 changes: 87 additions & 0 deletions federatedscope/autotune/fedex/utils.py
@@ -0,0 +1,87 @@
from torch import nn
from torch.nn.utils import spectral_norm

from federatedscope.autotune.utils import arm2dict


class EncNet(nn.Module):
    # Spectrally-normalized MLP that embeds client encodings.
    def __init__(self, in_channel, out_channel, hid_dim=64):
        super(EncNet, self).__init__()

        self.fc_layer = nn.Sequential(
            spectral_norm(nn.Linear(in_channel, hid_dim, bias=False)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, out_channel, bias=False)),
            nn.ReLU(inplace=True),
        )

    def forward(self, client_enc):
        mean_update = self.fc_layer(client_enc)
        return mean_update


class HyperNet(nn.Module):
    # Maps client encodings to one categorical distribution per
    # hyperparameter: a softmax head for each entry of `sizes`.
    def __init__(
        self,
        input_dim,
        sizes,
        n_clients,  # accepted for API symmetry; not used below
        device,
    ):
        super(HyperNet, self).__init__()
        self.EncNet = EncNet(input_dim, 32)
        self.out = nn.ModuleList()
        for num_cate in sizes:
            self.out.append(
                nn.Sequential(nn.Linear(32, num_cate, bias=True),
                              nn.Softmax(dim=-1)))

    def forward(self, encoding):
        client_enc = self.EncNet(encoding)
        probs = []
        for module in self.out:
            out = module(client_enc)
            probs.append(out)
        return probs


if __name__ == "__main__":
    # Interpret a learned personalized policy: pick each client's most
    # probable arm and dump the client-wise configs to YAML.
    import yaml
    import argparse
    import torch

    parser = argparse.ArgumentParser(description='Interpret learned policy')
    parser.add_argument('--ss_path', type=str, default='')
    parser.add_argument('--log_path', type=str, default='')
    parser.add_argument('--pt_path', type=str, default='')
    parser.add_argument('--save_path', type=str, default='')
    args = parser.parse_args()

    with open(args.ss_path, 'r') as ips:
        arms = yaml.load(ips, Loader=yaml.FullLoader)
    print(arms)
    with open(args.log_path, 'r') as ips:
        ckpt = yaml.load(ips, Loader=yaml.FullLoader)
    stop_exploration = ckpt['stop']
    print("stop: {}".format(stop_exploration))

    psn_pi = torch.load(args.pt_path, map_location='cpu')
    client_encodings = psn_pi['client_encodings']
    policy_net = HyperNet(
        client_encodings.shape[-1],
        [len(arms)],
        client_encodings.shape[0],
        'cpu',
    ).to('cpu')
    policy_net.load_state_dict(psn_pi['policy_net'])
    policy_net.eval()
    prbs = policy_net(client_encodings)
    prbs = prbs[0].detach().numpy()
    clientwise_configs = dict()
    for i in range(prbs.shape[0]):
        arm_idx = prbs[i].argmax()
        clientwise_configs['client_{}'.format(i + 1)] = arm2dict(
            arms['arm{}'.format(arm_idx)])
    with open(args.save_path, 'w') as ops:
        yaml.Dumper.ignore_aliases = lambda *args: True
        yaml.dump(clientwise_configs, ops)
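A quick smoke test of HyperNet as defined above (made-up dimensions, assuming HyperNet is importable):

import torch

net = HyperNet(input_dim=8, sizes=[3, 5], n_clients=4, device='cpu')
enc = torch.randn(4, 8)
probs = net(enc)  # list with one (n_clients, size) tensor per hyperparameter
assert [tuple(p.shape) for p in probs] == [(4, 3), (4, 5)]
assert torch.allclose(probs[0].sum(-1), torch.ones(4))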