Skip to content

Commit

Permalink
Enabled un-seen clients case to check the participation generalization gap
Browse files Browse the repository at this point in the history
  • Loading branch information
yxdyc committed Jul 12, 2022
1 parent 03322e6 commit 88ad802
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 34 deletions.
6 changes: 5 additions & 1 deletion federatedscope/autotune/fedex/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,14 @@ def sample(self):

def broadcast_model_para(self,
msg_type='model_para',
sample_client_num=-1):
sample_client_num=-1,
filter_unseen_clients=True):
"""
To broadcast the message to all clients or sampled clients
"""
if filter_unseen_clients:
# to filter out the unseen clients when sampling
self.sampler.change_state(self.unseen_clients_id, 'working')

if sample_client_num > 0:
receiver = np.random.choice(np.arange(1, self.client_num + 1),
Expand Down
20 changes: 18 additions & 2 deletions federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def extend_fl_setting_cfg(cfg):
cfg.federate.client_num = 0
cfg.federate.sample_client_num = -1
cfg.federate.sample_client_rate = -1.0
cfg.federate.unseen_clients_rate = 0.0
cfg.federate.total_round_num = 50
cfg.federate.mode = 'standalone'
cfg.federate.share_local_model = False
Expand Down Expand Up @@ -76,6 +77,21 @@ def assert_fl_setting_cfg(cfg):
and cfg.federate.mode == 'distributed'
), "Please configure the cfg.federate. in distributed mode. "

assert 0 <= cfg.federate.unseen_clients_rate < 1, \
"You specified in-valid cfg.federate.unseen_clients_rate"
if 0 < cfg.federate.unseen_clients_rate < 1 and cfg.federate.method in [
"local", "global"
]:
logger.warning(
"In local/global training mode, the unseen_clients_rate is "
"in-valid, plz check your config")
unseen_clients_rate = 0.0
cfg.federate.unseen_clients_rate = unseen_clients_rate
else:
unseen_clients_rate = cfg.federate.unseen_clients_rate
participated_client_num = max(
1, int((1 - unseen_clients_rate) * cfg.federate.client_num))

# sample client num pre-process
sample_client_num_valid = (
0 < cfg.federate.sample_client_num <=
Expand Down Expand Up @@ -105,15 +121,15 @@ def assert_fl_setting_cfg(cfg):
# in standalone mode, federate.client_num may be modified from 0 to
# num_of_all_clients after loading the data
if cfg.federate.client_num != 0:
cfg.federate.sample_client_num = cfg.federate.client_num
cfg.federate.sample_client_num = participated_client_num
else:
# (b) sampling case
if sample_client_rate_valid:
# (b.1) use sample_client_rate
old_sample_client_num = cfg.federate.sample_client_num
cfg.federate.sample_client_num = max(
1,
int(cfg.federate.sample_client_rate * cfg.federate.client_num))
int(cfg.federate.sample_client_rate * participated_client_num))
if sample_client_num_valid:
logger.warning(
f"Users specify both valid sample_client_rate as"
Expand Down
15 changes: 15 additions & 0 deletions federatedscope/core/fed_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from collections import deque

import numpy as np

from federatedscope.core.worker import Server, Client
from federatedscope.core.gpu_manager import GPUManager
from federatedscope.core.auxiliaries.model_builder import get_model
Expand Down Expand Up @@ -89,7 +91,18 @@ def _setup_for_standalone(self):
self.cfg.federate.sample_client_num = 1
self.cfg.freeze()

unseen_clients_id = []
if self.cfg.federate.unseen_clients_rate > 0:
unseen_clients_id = np.random.choice(
np.arange(1, self.cfg.federate.client_num + 1),
size=max(
1,
int(self.cfg.federate.unseen_clients_rate *
self.cfg.federate.client_num)),
replace=False).tolist()

self.server = self._setup_server()
self.server.unseen_clients_id = unseen_clients_id

self.client = dict()

Expand All @@ -102,6 +115,8 @@ def _setup_for_standalone(self):
for client_id in range(1, self.cfg.federate.client_num + 1):
self.client[client_id] = self._setup_client(
client_id=client_id, client_model=self._shared_client_model)
if client_id in unseen_clients_id:
self.client[client_id].is_unseen_client = True

def _setup_for_distributed(self):
"""
Expand Down
19 changes: 12 additions & 7 deletions federatedscope/core/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@ def change_state(self, indices, state):
"""
To modify the state of clients (idle or working)
"""
if state == 'idle':
self.client_state[indices] = 1
elif state == 'working':
self.client_state[indices] = 0
if not isinstance(indices, list):
all_idx = [indices]
else:
raise ValueError(
f"The state of client should be 'idle' or 'working', but got"
f" {state}")
all_idx = indices
for idx in all_idx:
if state == 'idle':
self.client_state[idx] = 1
elif state == 'working':
self.client_state[idx] = 0
else:
raise ValueError(
f"The state of client should be 'idle' or 'working',"
f"but got {state}")


class UniformSampler(Sampler):
Expand Down
35 changes: 27 additions & 8 deletions federatedscope/core/worker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def __init__(self,

super(Client, self).__init__(ID, state, config, model, strategy)

# is_unseen_client indicates whether this client is held out as an
# "unseen" client, i.e., one that does not contribute to the FL process
# by training on its local data and uploading a local model update. This
# is useful for checking the participation generalization gap in
# [ICLR'22, What Do We Mean by Generalization in Federated Learning?]
self.is_unseen_client = False

# Attack only support the stand alone model;
# Check if is a attacker; a client is a attacker if the
# config.attack.attack_method is provided
Expand Down Expand Up @@ -212,15 +219,27 @@ def callback_funcs_for_model_para(self, message: Message):
message.content
self.trainer.update(content)
self.state = round
if self.early_stopper.early_stopped and \
self._cfg.federate.method in ["local", "global"]:
skip_train_isolated_or_global_mode = \
self.early_stopper.early_stopped and \
self._cfg.federate.method in ["local", "global"]
if self.is_unseen_client or skip_train_isolated_or_global_mode:
# for these cases, (1) unseen client and (2) early-stopped local/global
# mode, we do not train locally or upload a local model update
sample_size, model_para_all, results = \
0, self.trainer.get_model_para(
), {}
logger.info(f"Client #{self.ID} has been early stopped, "
f"we will skip the local training")
self._monitor.local_converged()
0, self.trainer.get_model_para(), {}
if skip_train_isolated_or_global_mode:
logger.info(
f"[Local/Global mode] Client #{self.ID} has been "
f"early stopped, we will skip the local training")
self._monitor.local_converged()
else:
if self.early_stopper.early_stopped and \
self._monitor.local_convergence_round == 0:
logger.info(
f"[Normal FL Mode] Client #{self.ID} has been locally "
f"early stopped. "
f"The next FL update may result in negative effect")
self._monitor.local_converged()
sample_size, model_para_all, results = self.trainer.train()
train_log_res = self._monitor.format_eval_res(
results,
Expand All @@ -233,7 +252,7 @@ def callback_funcs_for_model_para(self, message: Message):
save_file_name="")

# Return the feedbacks to the server after local update
if self._cfg.federate.use_ss:
if self._cfg.federate.use_ss and not self.is_unseen_client:
single_model_case = True
if isinstance(model_para_all, list):
assert isinstance(model_para_all[0], dict), \
Expand Down
65 changes: 49 additions & 16 deletions federatedscope/core/worker/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import copy
import os

import numpy as np

from federatedscope.core.monitors.early_stopper import EarlyStopper
from federatedscope.core.message import Message
from federatedscope.core.communication import StandaloneCommManager, \
Expand Down Expand Up @@ -120,6 +118,12 @@ def __init__(self,
self.sample_client_num = int(self._cfg.federate.sample_client_num)
self.join_in_client_num = 0
self.join_in_info = dict()
# the unseen clients indicate the ones that do not contribute to FL
# process by training on their local data and uploading their local
# model update. The splitting is useful to check participation
# generalization gap in
# [ICLR'22, What Do We Mean by Generalization in Federated Learning?]
self.unseen_clients_id = []

# Sampler
client_info = kwargs['client_info'] if 'client_info' in kwargs else \
Expand Down Expand Up @@ -442,15 +446,20 @@ def merge_eval_results_from_all_clients(self):
round = max(self.msg_buffer['eval'].keys())
eval_msg_buffer = self.msg_buffer['eval'][round]
eval_res_participated_clients = []
eval_res_unseen_clients = []
for client_id in eval_msg_buffer:
if eval_msg_buffer[client_id] is None:
continue
eval_res_participated_clients.append(eval_msg_buffer[client_id])
if client_id in self.unseen_clients_id:
eval_res_unseen_clients.append(eval_msg_buffer[client_id])
else:
eval_res_participated_clients.append(
eval_msg_buffer[client_id])

formatted_logs_all_set = dict()
for merge_type, eval_res_set in [
("participated", eval_res_participated_clients),
]:
for merge_type, eval_res_set in [("participated",
eval_res_participated_clients),
("unseen", eval_res_unseen_clients)]:
if eval_res_set != []:
metrics_all_clients = dict()
for client_eval_results in eval_res_set:
Expand All @@ -464,39 +473,62 @@ def merge_eval_results_from_all_clients(self):
rnd=self.state,
role='Server #',
forms=self._cfg.eval.report)
if merge_type == "unseen":
for key, val in copy.deepcopy(formatted_logs).items():
if isinstance(val, dict):
# to avoid the overrides of results using the
# same name, we use new keys with postfix `unseen`:
# 'Results_weighted_avg' ->
# 'Results_weighted_avg_unseen'
formatted_logs[key + "_unseen"] = val
del formatted_logs[key]
logger.info(formatted_logs)
formatted_logs_all_set.update(formatted_logs)
self._monitor.update_best_result(
self.best_results,
metrics_all_clients,
results_type="client_individual",
results_type="unseen_client_individual"
if merge_type == "unseen" else "client_individual",
round_wise_update_key=self._cfg.eval.
best_res_update_round_wise_key)
self._monitor.save_formatted_results(formatted_logs)
for form in self._cfg.eval.report:
if form != "raw":
metric_name = form + "_unseen" if merge_type == \
"unseen" else form
self._monitor.update_best_result(
self.best_results,
formatted_logs[f"Results_{form}"],
results_type=f"client_summarized_{form}",
formatted_logs[f"Results_{metric_name}"],
results_type=f"unseen_client_summarized_{form}"
if merge_type == "unseen" else
f"client_summarized_{form}",
round_wise_update_key=self._cfg.eval.
best_res_update_round_wise_key)

return formatted_logs_all_set

def broadcast_model_para(self,
msg_type='model_para',
sample_client_num=-1):
sample_client_num=-1,
filter_unseen_clients=True):
"""
To broadcast the message to all clients or sampled clients
Arguments:
msg_type: 'model_para' or other user defined msg_type
sample_client_num: the number of sampled clients in the
broadcast behavior.
And sample_client_num = -1 denotes to broadcast to all the
clients.
"""
sample_client_num: the number of sampled clients in the broadcast
behavior. And sample_client_num = -1 denotes to broadcast to
all the clients.
filter_unseen_clients: whether to filter out the unseen clients that
do not contribute to the FL process by training on their local
data and uploading their local model updates. The splitting is
useful for checking the participation generalization gap in [ICLR'22,
What Do We Mean by Generalization in Federated Learning?]
You may want to set it to False in the evaluation stage
"""
if filter_unseen_clients:
# to filter out the unseen clients when sampling
self.sampler.change_state(self.unseen_clients_id, 'working')

if sample_client_num > 0:
receiver = self.sampler.sample(size=sample_client_num)
Expand Down Expand Up @@ -655,7 +687,8 @@ def eval(self):
self.check_and_save()
else:
# Perform evaluation in clients
self.broadcast_model_para(msg_type='evaluate')
self.broadcast_model_para(msg_type='evaluate',
filter_unseen_clients=False)

def callback_funcs_model_para(self, message: Message):
"""
Expand Down

0 comments on commit 88ad802

Please sign in to comment.