Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks for completeness of msg_handler #388

Merged
merged 9 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_distribute.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on: [push, pull_request]
jobs:
run:
runs-on: ${{ matrix.os }}
timeout-minutes: 10
timeout-minutes: 20
strategy:
matrix:
os: [ubuntu-latest]
Expand Down
30 changes: 24 additions & 6 deletions federatedscope/contrib/worker/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,36 @@


# Build your worker here.
class MyClient(Client):
pass
class MyServer(Server):
def _register_default_handlers(self):
self.register_handlers('join_in', self.callback_funcs_for_join_in,
['assign_client_id', 'address', 'model_para'])
self.register_handlers('join_in_info', self.callback_funcs_for_join_in,
['address', 'model_para'])
self.register_handlers('model_para', self.callback_funcs_model_para,
['model_para', 'evaluate', 'finish'])
self.register_handlers('metrics', self.callback_funcs_for_metrics,
['converged'])


class MyServer(Server):
pass
class MyClient(Client):
def _register_default_handlers(self):
self.register_handlers('assign_client_id',
self.callback_funcs_for_assign_id, [None])
self.register_handlers('address', self.callback_funcs_for_address)
self.register_handlers('model_para',
self.callback_funcs_for_model_para,
['model_para', 'ss_model_para'])
self.register_handlers('evaluate', self.callback_funcs_for_evaluate,
['metrics'])
self.register_handlers('finish', self.callback_funcs_for_finish)
self.register_handlers('converged', self.callback_funcs_for_converged)


def call_my_worker(method):
if method == 'mymethod':
if method == 'myfedavg':
worker_builder = {'client': MyClient, 'server': MyServer}
return worker_builder


register_worker('mymethod', call_my_worker)
register_worker('myfedavg', call_my_worker)
3 changes: 3 additions & 0 deletions federatedscope/core/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ def init_global_cfg(cfg):
# Whether to use GPU
cfg.use_gpu = False

# Whether to check the completeness of msg_handler
cfg.check_completeness = False

# Whether to print verbose logging info
cfg.verbose = 1

Expand Down
126 changes: 126 additions & 0 deletions federatedscope/core/fed_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def __init__(self,
self.resource_info = get_resource_info(
config.federate.resource_info_file)

# Check the completeness of msg_handler.
self.check()

# Set up for Runner
self._set_up()

Expand Down Expand Up @@ -208,6 +211,65 @@ def _setup_client(self,

return client

def check(self):
    """
    Check the completeness of Server and Client.

    Builds a directed graph ``incoming message -> handler -> outgoing
    message`` from the tables returned by ``get_msg_handler_dict()`` of
    ``self.client_class`` and ``self.server_class``, saves a drawing of it
    to ``msg_handler.png`` under ``self.cfg.outdir`` and logs the verdict:

    * info: a path from ``Client_join_in`` to ``Server_finish`` exists and
      the graph is weakly connected;
    * warning: the path exists but the graph is not weakly connected;
    * error: there is no such path (this includes the case where one of
      the two endpoints is not in the graph at all).

    Nothing is done when ``self.cfg.check_completeness`` is falsy. Any
    exception raised while checking (the optional imports happen inside
    the ``try``) is logged as a warning and not propagated.
    """
    if not self.cfg.check_completeness:
        return
    try:
        import os
        import networkx as nx
        import matplotlib.pyplot as plt

        # Build check graph
        G = nx.DiGraph()
        flags = {0: 'Client', 1: 'Server'}
        msg_handler_dicts = [
            self.client_class.get_msg_handler_dict(),
            self.server_class.get_msg_handler_dict()
        ]
        for flag, msg_handler_dict in zip(flags.keys(), msg_handler_dicts):
            # An incoming message is labelled with the opposite role,
            # handlers and outgoing messages with the worker's own role.
            role, oppo = flags[flag], flags[(flag + 1) % 2]
            for msg_in, (handler, msgs_out) in \
                    msg_handler_dict.items():
                for msg_out in msgs_out:
                    msg_in_key = f'{oppo}_{msg_in}'
                    handler_key = f'{role}_{handler}'
                    msg_out_key = f'{role}_{msg_out}'
                    # Messages share the middle layer; handlers of the
                    # two roles are placed on either side of it.
                    G.add_node(msg_in_key, subset=1)
                    G.add_node(handler_key, subset=0 if flag else 2)
                    G.add_node(msg_out_key, subset=1)
                    G.add_edge(msg_in_key, handler_key)
                    G.add_edge(handler_key, msg_out_key)

        pos = nx.multipartite_layout(G)
        fig_path = os.path.join(self.cfg.outdir, 'msg_handler.png')
        fig = plt.figure(figsize=(20, 15))
        try:
            nx.draw(G,
                    pos,
                    with_labels=True,
                    node_color='white',
                    node_size=800,
                    width=1.0,
                    arrowsize=25,
                    arrowstyle='->')
            plt.savefig(fig_path)
        finally:
            # Release the figure so that it does not stay registered in
            # pyplot for the rest of the run.
            plt.close(fig)

        # `nx.has_path` raises when an endpoint is not a node of the
        # graph; a missing endpoint simply means "no path" here.
        source, target = 'Client_join_in', 'Server_finish'
        reachable = (source in G and target in G
                     and nx.has_path(G, source, target))
        if reachable:
            if nx.is_weakly_connected(G):
                logger.info(f'Completeness check passes! Save check '
                            f'results in {fig_path}.')
            else:
                logger.warning(f'Completeness check raises warning for '
                               f'some handlers not in FL process! Save '
                               f'check results in {fig_path}.')
        else:
            logger.error(f'Completeness check fails for there is no '
                         f'path from `join_in` to `finish`! Save '
                         f'check results in {fig_path}.')
    except Exception as error:
        logger.warning(f'Completeness check failed for {error}!')


class StandaloneRunner(BaseRunner):
def _set_up(self):
Expand Down Expand Up @@ -528,6 +590,10 @@ def __init__(self,
self.resource_info = get_resource_info(
config.federate.resource_info_file)

# Check the completeness of msg_handler.
self.check()

def setup(self):
if self.mode == 'standalone':
self.shared_comm_queue = deque()
self._setup_for_standalone()
Expand Down Expand Up @@ -635,6 +701,7 @@ def run(self):
For the standalone mode, a shared message queue will be set up to
simulate ``receiving message``.
"""
self.setup()
if self.mode == 'standalone':
# trigger the FL course
for each_client in self.client:
Expand Down Expand Up @@ -870,3 +937,62 @@ def _handle_msg(self, msg, rcv=-1):
self.client[each_receiver].msg_handlers[msg.msg_type](msg)
self.client[each_receiver]._monitor.track_download_bytes(
download_bytes)

def check(self):
    """
    Check the completeness of Server and Client.

    Builds a directed graph ``incoming message -> handler -> outgoing
    message`` from the tables returned by ``get_msg_handler_dict()`` of
    ``self.client_class`` and ``self.server_class``, saves a drawing of it
    to ``msg_handler.png`` under ``self.cfg.outdir`` and logs the verdict:

    * info: a path from ``Client_join_in`` to ``Server_finish`` exists and
      the graph is weakly connected;
    * warning: the path exists but the graph is not weakly connected;
    * error: there is no such path (this includes the case where one of
      the two endpoints is not in the graph at all).

    Nothing is done when ``self.cfg.check_completeness`` is falsy. Any
    exception raised while checking (the optional imports happen inside
    the ``try``) is logged as a warning and not propagated.
    """
    if not self.cfg.check_completeness:
        return
    try:
        import os
        import networkx as nx
        import matplotlib.pyplot as plt

        # Build check graph
        G = nx.DiGraph()
        flags = {0: 'Client', 1: 'Server'}
        msg_handler_dicts = [
            self.client_class.get_msg_handler_dict(),
            self.server_class.get_msg_handler_dict()
        ]
        for flag, msg_handler_dict in zip(flags.keys(), msg_handler_dicts):
            # An incoming message is labelled with the opposite role,
            # handlers and outgoing messages with the worker's own role.
            role, oppo = flags[flag], flags[(flag + 1) % 2]
            for msg_in, (handler, msgs_out) in \
                    msg_handler_dict.items():
                for msg_out in msgs_out:
                    msg_in_key = f'{oppo}_{msg_in}'
                    handler_key = f'{role}_{handler}'
                    msg_out_key = f'{role}_{msg_out}'
                    # Messages share the middle layer; handlers of the
                    # two roles are placed on either side of it.
                    G.add_node(msg_in_key, subset=1)
                    G.add_node(handler_key, subset=0 if flag else 2)
                    G.add_node(msg_out_key, subset=1)
                    G.add_edge(msg_in_key, handler_key)
                    G.add_edge(handler_key, msg_out_key)

        pos = nx.multipartite_layout(G)
        fig_path = os.path.join(self.cfg.outdir, 'msg_handler.png')
        fig = plt.figure(figsize=(20, 15))
        try:
            nx.draw(G,
                    pos,
                    with_labels=True,
                    node_color='white',
                    node_size=800,
                    width=1.0,
                    arrowsize=25,
                    arrowstyle='->')
            plt.savefig(fig_path)
        finally:
            # Release the figure so that it does not stay registered in
            # pyplot for the rest of the run.
            plt.close(fig)

        # `nx.has_path` raises when an endpoint is not a node of the
        # graph; a missing endpoint simply means "no path" here.
        source, target = 'Client_join_in', 'Server_finish'
        reachable = (source in G and target in G
                     and nx.has_path(G, source, target))
        if reachable:
            if nx.is_weakly_connected(G):
                logger.info(f'Completeness check passes! Save check '
                            f'results in {fig_path}.')
            else:
                logger.warning(f'Completeness check raises warning for '
                               f'some handlers not in FL process! Save '
                               f'check results in {fig_path}.')
        else:
            logger.error(f'Completeness check fails for there is no '
                         f'path from `join_in` to `finish`! Save '
                         f'check results in {fig_path}.')
    except Exception as error:
        logger.warning(f'Completeness check failed for {error}!')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the reason is that we cannot give a yes-or-no answer about the correctness/completeness, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the documentation shows, the correctness/completeness check only inspects the message-handler pairs and emits one of three kinds of log (info, warning, or error). If something goes wrong in the Python code itself, we'd better keep the exception stack as it is. So the return value is meaningless.

29 changes: 18 additions & 11 deletions federatedscope/core/workers/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
class BaseClient(Worker):
def __init__(self, ID, state, config, model, strategy):
super(BaseClient, self).__init__(ID, state, config, model, strategy)
# TODO: move to worker
self.msg_handlers = dict()
self.msg_handlers_str = dict()

# TODO: move to worker
def register_handlers(self, msg_type, callback_func):
def register_handlers(self, msg_type, callback_func, send_msg=[None]):
"""
To bind a message type with a handling function.

Expand All @@ -19,6 +18,7 @@ def register_handlers(self, msg_type, callback_func):
message
"""
self.msg_handlers[msg_type] = callback_func
self.msg_handlers_str[msg_type] = (callback_func.__name__, send_msg)

def _register_default_handlers(self):
"""
Expand All @@ -43,17 +43,24 @@ def _register_default_handlers(self):
============================ ==================================
"""
self.register_handlers('assign_client_id',
self.callback_funcs_for_assign_id)
self.callback_funcs_for_assign_id, [None])
self.register_handlers('ask_for_join_in_info',
self.callback_funcs_for_join_in_info)
self.register_handlers('address', self.callback_funcs_for_address)
self.callback_funcs_for_join_in_info,
['join_in_info'])
self.register_handlers('address', self.callback_funcs_for_address,
[None])
self.register_handlers('model_para',
self.callback_funcs_for_model_para)
self.callback_funcs_for_model_para,
['model_para', 'ss_model_para'])
self.register_handlers('ss_model_para',
self.callback_funcs_for_model_para)
self.register_handlers('evaluate', self.callback_funcs_for_evaluate)
self.register_handlers('finish', self.callback_funcs_for_finish)
self.register_handlers('converged', self.callback_funcs_for_converged)
self.callback_funcs_for_model_para,
['ss_model_para', 'model_para'])
self.register_handlers('evaluate', self.callback_funcs_for_evaluate,
['metrics'])
self.register_handlers('finish', self.callback_funcs_for_finish,
[None])
self.register_handlers('converged', self.callback_funcs_for_converged,
[None])

@abc.abstractmethod
def run(self):
Expand Down
19 changes: 12 additions & 7 deletions federatedscope/core/workers/base_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
class BaseServer(Worker):
def __init__(self, ID, state, config, model, strategy):
super(BaseServer, self).__init__(ID, state, config, model, strategy)
# TODO: move to worker
self.msg_handlers = dict()
self.msg_handlers_str = dict()

# TODO: move to worker
def register_handlers(self, msg_type, callback_func):
def register_handlers(self, msg_type, callback_func, send_msg=[None]):
"""
To bind a message type with a handling function.

Expand All @@ -19,6 +18,7 @@ def register_handlers(self, msg_type, callback_func):
message
"""
self.msg_handlers[msg_type] = callback_func
self.msg_handlers_str[msg_type] = (callback_func.__name__, send_msg)

def _register_default_handlers(self):
"""
Expand All @@ -38,10 +38,15 @@ def _register_default_handlers(self):
``metrics`` ``callback_funcs_for_metrics``
============================ ==================================
"""
self.register_handlers('join_in', self.callback_funcs_for_join_in)
self.register_handlers('join_in_info', self.callback_funcs_for_join_in)
self.register_handlers('model_para', self.callback_funcs_model_para)
self.register_handlers('metrics', self.callback_funcs_for_metrics)
self.register_handlers('join_in', self.callback_funcs_for_join_in, [
'assign_client_id', 'ask_for_join_in_info', 'address', 'model_para'
])
self.register_handlers('join_in_info', self.callback_funcs_for_join_in,
['address', 'model_para'])
self.register_handlers('model_para', self.callback_funcs_model_para,
['model_para', 'evaluate', 'finish'])
self.register_handlers('metrics', self.callback_funcs_for_metrics,
['converged'])

@abc.abstractmethod
def run(self):
Expand Down
5 changes: 3 additions & 2 deletions federatedscope/core/workers/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ def __init__(self, ID=-1, state=0, config=None, model=None, strategy=None):
self._model = model
self._cfg = config
self._strategy = strategy
self._mode = self._cfg.federate.mode.lower()
self._monitor = Monitor(config, monitored_object=self)
if self._cfg is not None:
self._mode = self._cfg.federate.mode.lower()
self._monitor = Monitor(config, monitored_object=self)

@property
def ID(self):
Expand Down
14 changes: 10 additions & 4 deletions federatedscope/core/workers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ def __init__(self,
is_unseen_client=False,
*args,
**kwargs):

super(Client, self).__init__(ID, state, config, model, strategy)
# Register message handlers
self._register_default_handlers()

# Un-configured worker
if config is None:
return

# the unseen_client indicates that whether this client contributes to
# FL process by training on its local data and uploading the local
Expand Down Expand Up @@ -109,9 +114,6 @@ def __init__(self,
)) if self._cfg.federate.use_ss else None
self.msg_buffer = {'train': dict(), 'eval': dict()}

# Register message handlers
self._register_default_handlers()

# Communication and communication ability
if 'resource_info' in kwargs and kwargs['resource_info'] is not None:
self.comp_speed = float(
Expand Down Expand Up @@ -527,3 +529,7 @@ def callback_funcs_for_converged(self, message: Message):
message: The received message
"""
self._monitor.global_converged()

@classmethod
def get_msg_handler_dict(cls):
    """
    Return the ``msg_handlers_str`` attribute of a new instance of
    ``cls`` created without any arguments.
    """
    unconfigured_worker = cls()
    return unconfigured_worker.msg_handlers_str
14 changes: 10 additions & 4 deletions federatedscope/core/workers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ def __init__(self,
strategy=None,
unseen_clients_id=None,
**kwargs):

super(Server, self).__init__(ID, state, config, model, strategy)
# Register message handlers
self._register_default_handlers()

# Un-configured worker
if config is None:
return

self.data = data
self.device = device
Expand Down Expand Up @@ -186,9 +191,6 @@ def __init__(self,
self.client_resource_info = kwargs['client_resource_info'] \
if 'client_resource_info' in kwargs else None

# Register message handlers
self._register_default_handlers()

# Initialize communication manager and message buffer
self.msg_buffer = {'train': dict(), 'eval': dict()}
self.staled_msg_buffer = list()
Expand Down Expand Up @@ -987,3 +989,7 @@ def callback_funcs_for_metrics(self, message: Message):
self.msg_buffer['eval'][rnd][sender] = content

return self.check_and_move_on(check_eval_result=True)

@classmethod
def get_msg_handler_dict(cls):
    """
    Return the ``msg_handlers_str`` attribute of a new instance of
    ``cls`` created without any arguments.
    """
    unconfigured_worker = cls()
    return unconfigured_worker.msg_handlers_str
Loading