diff --git a/README.md b/README.md index c3ac299d9..39f3e85ec 100644 --- a/README.md +++ b/README.md @@ -119,16 +119,17 @@ To run with distributed mode, you only need to: - Prepare isolated data file and set up `cfg.distribute.data_file = PATH/TO/DATA` for each participant; - Change `cfg.federate.model = 'distributed'`, and specify the role of each participant by `cfg.distributed.role = 'server'/'client'`. -- Set up a valid address by `cfg.distribute.host = x.x.x.x` and `cfg.distribute.host = xxxx`. (Note that for a server, you need to set up server_host/server_port for listening messge, while for a client, you need to set up client_host/client_port for listening and server_host/server_port for sending join-in applications when building up an FL course) +- Set up a valid address by `cfg.distribute.host = x.x.x.x` and `cfg.distribute.port = xxxx`. (Note that for a server, you need to set up server_host/server_port for listening messge, while for a client, you need to set up client_host/client_port for listening and server_host/server_port for sending join-in applications when building up an FL course) We prepare a synthetic example for running with distributed mode: ```bash # For server -python main.py --cfg federatedscope/example_configs/distributed_server.yaml data_path 'PATH/TO/DATA' server.host x.x.x.x client.port xxxx +python main.py --cfg federatedscope/example_configs/distributed_server.yaml data_path 'PATH/TO/DATA' distribute.server_host x.x.x.x distribute.server_port xxxx -# For client -python main.py --cfg federatedscope/example_configs/distributed_client.yaml data_path 'PATH/TO/DATA' server.host x.x.x.x server.port xxxx client.host x.x.x.x client.port xxxx +# For clients +python main.py --cfg federatedscope/example_configs/distributed_client_1.yaml data_path 'PATH/TO/DATA' distribute.server_host x.x.x.x distribute.server_port xxxx distribute.client_host x.x.x.x distribute.client_port xxxx +python main.py --cfg federatedscope/example_configs/distributed_client_2.yaml data_path 'PATH/TO/DATA' distribute.server_host x.x.x.x distribute.server_port xxxx distribute.client_host x.x.x.x distribute.client_port xxxx ``` And you can observe the results as (the IP addresses are anonymized with 'x.x.x.x'): diff --git a/federatedscope/core/communication.py b/federatedscope/core/communication.py index 0ef1c474a..aa31b3dfa 100644 --- a/federatedscope/core/communication.py +++ b/federatedscope/core/communication.py @@ -98,8 +98,9 @@ def _create_stub(receiver_address): stub = gRPC_comm_manager_pb2_grpc.gRPCComServeFuncStub(channel) return stub, channel stub, channel = _create_stub(receiver_address) - request = message.msg_to_json(to_list=True) - stub.sendMessage(gRPC_comm_manager_pb2.MessageRequest(msg=request)) + request = message.transform(to_list=True) + #msg_test = gRPC_comm_manager_pb2.MessageRequest(msg=request) + stub.sendMessage(request) channel.close() def send(self, message): @@ -119,5 +120,5 @@ def send(self, message): def receive(self): received_msg = self.server_funcs.receive() message = Message() - message.json_to_msg(received_msg.msg) + message.parse(received_msg.msg) return message diff --git a/federatedscope/core/gRPC_server.py b/federatedscope/core/gRPC_server.py index a736cdac2..9cf6c27c0 100644 --- a/federatedscope/core/gRPC_server.py +++ b/federatedscope/core/gRPC_server.py @@ -1,19 +1,20 @@ import queue +from collections import deque from federatedscope.core.proto import gRPC_comm_manager_pb2, gRPC_comm_manager_pb2_grpc class gRPCComServeFunc(gRPC_comm_manager_pb2_grpc.gRPCComServeFuncServicer): def __init__(self): - self.msg_queue = queue.Queue() + self.msg_queue = deque() def sendMessage(self, request, context): - self.msg_queue.put(request) + self.msg_queue.append(request) return gRPC_comm_manager_pb2.MessageResponse(msg='ACK') def receive(self): - while self.msg_queue.empty(): + while len(self.msg_queue) == 0: continue - msg = self.msg_queue.get() + msg = self.msg_queue.popleft() return msg diff --git a/federatedscope/core/message.py b/federatedscope/core/message.py index 86610c135..49d573c28 100644 --- a/federatedscope/core/message.py +++ b/federatedscope/core/message.py @@ -1,5 +1,7 @@ import sys import json +import numpy as np +from federatedscope.core.proto import gRPC_comm_manager_pb2 class Message(object): @@ -109,3 +111,94 @@ def json_to_msg(self, json_string): self.state = json_msg['state'] self.content = json_msg['content'] self.strategy = json_msg['strategy'] + + def create_by_type(self, value, nested=False): + if isinstance(value, dict): + m_dict = gRPC_comm_manager_pb2.mDict() + for key in value.keys(): + m_dict.dict_value[key].MergeFrom( + self.create_by_type(value[key], nested=True)) + if nested: + msg_value = gRPC_comm_manager_pb2.MsgValue() + msg_value.dict_msg.MergeFrom(m_dict) + return msg_value + else: + return m_dict + elif isinstance(value, list) or isinstance(value, tuple): + m_list = gRPC_comm_manager_pb2.mList() + for each in value: + m_list.list_value.append(self.create_by_type(each, + nested=True)) + if nested: + msg_value = gRPC_comm_manager_pb2.MsgValue() + msg_value.list_msg.MergeFrom(m_list) + return msg_value + else: + return m_list + else: + m_single = gRPC_comm_manager_pb2.mSingle() + if type(value) in [int, np.int32]: + m_single.int_value = value + elif type(value) in [str]: + m_single.str_value = value + elif type(value) in [float, np.float32]: + m_single.float_value = value + else: + raise ValueError( + 'The data type {} has not been supported.'.format( + type(value))) + + if nested: + msg_value = gRPC_comm_manager_pb2.MsgValue() + msg_value.single_msg.MergeFrom(m_single) + return msg_value + else: + return m_single + + def build_msg_value(self, value): + msg_value = gRPC_comm_manager_pb2.MsgValue() + + if isinstance(value, list) or isinstance(value, tuple): + msg_value.list_msg.MergeFrom(self.create_by_type(value)) + elif isinstance(value, dict): + msg_value.dict_msg.MergeFrom(self.create_by_type(value)) + else: + msg_value.single_msg.MergeFrom(self.create_by_type(value)) + + return msg_value + + def transform(self, to_list=False): + if to_list: + self.content = self.transform_to_list(self.content) + + splited_msg = gRPC_comm_manager_pb2.MessageRequest() # map/dict + splited_msg.msg['sender'].MergeFrom(self.build_msg_value(self.sender)) + splited_msg.msg['receiver'].MergeFrom( + self.build_msg_value(self.receiver)) + splited_msg.msg['state'].MergeFrom(self.build_msg_value(self.state)) + splited_msg.msg['msg_type'].MergeFrom( + self.build_msg_value(self.msg_type)) + splited_msg.msg['content'].MergeFrom(self.build_msg_value( + self.content)) + return splited_msg + + def _parse_msg(self, value): + if isinstance(value, gRPC_comm_manager_pb2.MsgValue) or isinstance( + value, gRPC_comm_manager_pb2.mSingle): + return self._parse_msg(getattr(value, value.WhichOneof("type"))) + elif isinstance(value, gRPC_comm_manager_pb2.mList): + return [self._parse_msg(each) for each in value.list_value] + elif isinstance(value, gRPC_comm_manager_pb2.mDict): + return { + k: self._parse_msg(value.dict_value[k]) + for k in value.dict_value + } + else: + return value + + def parse(self, received_msg): + self.sender = self._parse_msg(received_msg['sender']) + self.receiver = self._parse_msg(received_msg['receiver']) + self.msg_type = self._parse_msg(received_msg['msg_type']) + self.state = self._parse_msg(received_msg['state']) + self.content = self._parse_msg(received_msg['content']) diff --git a/federatedscope/core/proto/gRPC_comm_manager.proto b/federatedscope/core/proto/gRPC_comm_manager.proto index e8f685b7d..b418ad8bc 100644 --- a/federatedscope/core/proto/gRPC_comm_manager.proto +++ b/federatedscope/core/proto/gRPC_comm_manager.proto @@ -5,9 +5,33 @@ service gRPCComServeFunc { } message MessageRequest{ - string msg = 1; + map msg = 1; +} + +message MsgValue{ + oneof type { + mSingle single_msg = 1; + mList list_msg = 2; + mDict dict_msg = 3; + } +} + +message mSingle{ + oneof type { + float float_value = 1; + int32 int_value = 2; + string str_value = 3; + } +} + +message mList{ + repeated MsgValue list_value = 1; +} + +message mDict{ + map dict_value = 1; } message MessageResponse{ string msg = 1; -} \ No newline at end of file +} diff --git a/federatedscope/core/proto/gRPC_comm_manager_pb2.py b/federatedscope/core/proto/gRPC_comm_manager_pb2.py index 19acee78d..94c80799c 100644 --- a/federatedscope/core/proto/gRPC_comm_manager_pb2.py +++ b/federatedscope/core/proto/gRPC_comm_manager_pb2.py @@ -3,6 +3,7 @@ # source: gRPC_comm_manager.proto """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database @@ -10,107 +11,85 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor.FileDescriptor( - name='gRPC_comm_manager.proto', - package='', - syntax='proto3', - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_pb= - b'\n\x17gRPC_comm_manager.proto\"\x1d\n\x0eMessageRequest\x12\x0b\n\x03msg\x18\x01 \x01(\t\"\x1e\n\x0fMessageResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t2D\n\x10gRPCComServeFunc\x12\x30\n\x0bsendMessage\x12\x0f.MessageRequest\x1a\x10.MessageResponseb\x06proto3' +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x17gRPC_comm_manager.proto\"n\n\x0eMessageRequest\x12%\n\x03msg\x18\x01 \x03(\x0b\x32\x18.MessageRequest.MsgEntry\x1a\x35\n\x08MsgEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x18\n\x05value\x18\x02 \x01(\x0b\x32\t.MsgValue:\x02\x38\x01\"j\n\x08MsgValue\x12\x1e\n\nsingle_msg\x18\x01 \x01(\x0b\x32\x08.mSingleH\x00\x12\x1a\n\x08list_msg\x18\x02 \x01(\x0b\x32\x06.mListH\x00\x12\x1a\n\x08\x64ict_msg\x18\x03 \x01(\x0b\x32\x06.mDictH\x00\x42\x06\n\x04type\"R\n\x07mSingle\x12\x15\n\x0b\x66loat_value\x18\x01 \x01(\x02H\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x13\n\tstr_value\x18\x03 \x01(\tH\x00\x42\x06\n\x04type\"&\n\x05mList\x12\x1d\n\nlist_value\x18\x01 \x03(\x0b\x32\t.MsgValue\"o\n\x05mDict\x12)\n\ndict_value\x18\x01 \x03(\x0b\x32\x15.mDict.DictValueEntry\x1a;\n\x0e\x44ictValueEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x18\n\x05value\x18\x02 \x01(\x0b\x32\t.MsgValue:\x02\x38\x01\"\x1e\n\x0fMessageResponse\x12\x0b\n\x03msg\x18\x01 \x01(\t2D\n\x10gRPCComServeFunc\x12\x30\n\x0bsendMessage\x12\x0f.MessageRequest\x1a\x10.MessageResponseb\x06proto3' ) -_MESSAGEREQUEST = _descriptor.Descriptor( - name='MessageRequest', - full_name='MessageRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='msg', - full_name='MessageRequest.msg', - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=b"".decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key), - ], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[], - serialized_start=27, - serialized_end=56, -) - -_MESSAGERESPONSE = _descriptor.Descriptor( - name='MessageResponse', - full_name='MessageResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='msg', - full_name='MessageResponse.msg', - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=b"".decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - serialized_options=None, - file=DESCRIPTOR, - create_key=_descriptor._internal_create_key), - ], - extensions=[], - nested_types=[], - enum_types=[], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[], - serialized_start=58, - serialized_end=88, -) - -DESCRIPTOR.message_types_by_name['MessageRequest'] = _MESSAGEREQUEST -DESCRIPTOR.message_types_by_name['MessageResponse'] = _MESSAGERESPONSE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - +_MESSAGEREQUEST = DESCRIPTOR.message_types_by_name['MessageRequest'] +_MESSAGEREQUEST_MSGENTRY = _MESSAGEREQUEST.nested_types_by_name['MsgEntry'] +_MSGVALUE = DESCRIPTOR.message_types_by_name['MsgValue'] +_MSINGLE = DESCRIPTOR.message_types_by_name['mSingle'] +_MLIST = DESCRIPTOR.message_types_by_name['mList'] +_MDICT = DESCRIPTOR.message_types_by_name['mDict'] +_MDICT_DICTVALUEENTRY = _MDICT.nested_types_by_name['DictValueEntry'] +_MESSAGERESPONSE = DESCRIPTOR.message_types_by_name['MessageResponse'] MessageRequest = _reflection.GeneratedProtocolMessageType( 'MessageRequest', (_message.Message, ), { + 'MsgEntry': _reflection.GeneratedProtocolMessageType( + 'MsgEntry', + (_message.Message, ), + { + 'DESCRIPTOR': _MESSAGEREQUEST_MSGENTRY, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:MessageRequest.MsgEntry) + }), 'DESCRIPTOR': _MESSAGEREQUEST, '__module__': 'gRPC_comm_manager_pb2' # @@protoc_insertion_point(class_scope:MessageRequest) }) _sym_db.RegisterMessage(MessageRequest) +_sym_db.RegisterMessage(MessageRequest.MsgEntry) + +MsgValue = _reflection.GeneratedProtocolMessageType( + 'MsgValue', + (_message.Message, ), + { + 'DESCRIPTOR': _MSGVALUE, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:MsgValue) + }) +_sym_db.RegisterMessage(MsgValue) + +mSingle = _reflection.GeneratedProtocolMessageType( + 'mSingle', + (_message.Message, ), + { + 'DESCRIPTOR': _MSINGLE, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mSingle) + }) +_sym_db.RegisterMessage(mSingle) + +mList = _reflection.GeneratedProtocolMessageType( + 'mList', + (_message.Message, ), + { + 'DESCRIPTOR': _MLIST, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mList) + }) +_sym_db.RegisterMessage(mList) + +mDict = _reflection.GeneratedProtocolMessageType( + 'mDict', + (_message.Message, ), + { + 'DictValueEntry': _reflection.GeneratedProtocolMessageType( + 'DictValueEntry', + (_message.Message, ), + { + 'DESCRIPTOR': _MDICT_DICTVALUEENTRY, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mDict.DictValueEntry) + }), + 'DESCRIPTOR': _MDICT, + '__module__': 'gRPC_comm_manager_pb2' + # @@protoc_insertion_point(class_scope:mDict) + }) +_sym_db.RegisterMessage(mDict) +_sym_db.RegisterMessage(mDict.DictValueEntry) MessageResponse = _reflection.GeneratedProtocolMessageType( 'MessageResponse', @@ -122,29 +101,30 @@ }) _sym_db.RegisterMessage(MessageResponse) -_GRPCCOMSERVEFUNC = _descriptor.ServiceDescriptor( - name='gRPCComServeFunc', - full_name='gRPCComServeFunc', - file=DESCRIPTOR, - index=0, - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_start=90, - serialized_end=158, - methods=[ - _descriptor.MethodDescriptor( - name='sendMessage', - full_name='gRPCComServeFunc.sendMessage', - index=0, - containing_service=None, - input_type=_MESSAGEREQUEST, - output_type=_MESSAGERESPONSE, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), - ]) -_sym_db.RegisterServiceDescriptor(_GRPCCOMSERVEFUNC) - -DESCRIPTOR.services_by_name['gRPCComServeFunc'] = _GRPCCOMSERVEFUNC +_GRPCCOMSERVEFUNC = DESCRIPTOR.services_by_name['gRPCComServeFunc'] +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _MESSAGEREQUEST_MSGENTRY._options = None + _MESSAGEREQUEST_MSGENTRY._serialized_options = b'8\001' + _MDICT_DICTVALUEENTRY._options = None + _MDICT_DICTVALUEENTRY._serialized_options = b'8\001' + _MESSAGEREQUEST._serialized_start = 27 + _MESSAGEREQUEST._serialized_end = 137 + _MESSAGEREQUEST_MSGENTRY._serialized_start = 84 + _MESSAGEREQUEST_MSGENTRY._serialized_end = 137 + _MSGVALUE._serialized_start = 139 + _MSGVALUE._serialized_end = 245 + _MSINGLE._serialized_start = 247 + _MSINGLE._serialized_end = 329 + _MLIST._serialized_start = 331 + _MLIST._serialized_end = 369 + _MDICT._serialized_start = 371 + _MDICT._serialized_end = 482 + _MDICT_DICTVALUEENTRY._serialized_start = 423 + _MDICT_DICTVALUEENTRY._serialized_end = 482 + _MESSAGERESPONSE._serialized_start = 484 + _MESSAGERESPONSE._serialized_end = 514 + _GRPCCOMSERVEFUNC._serialized_start = 516 + _GRPCCOMSERVEFUNC._serialized_end = 584 # @@protoc_insertion_point(module_scope) diff --git a/federatedscope/core/proto/gRPC_comm_manager_pb2_grpc.py b/federatedscope/core/proto/gRPC_comm_manager_pb2_grpc.py index c4d250353..9a1bbeb41 100644 --- a/federatedscope/core/proto/gRPC_comm_manager_pb2_grpc.py +++ b/federatedscope/core/proto/gRPC_comm_manager_pb2_grpc.py @@ -2,7 +2,7 @@ """Client and server classes corresponding to protobuf-defined services.""" import grpc -from federatedscope.core.proto import gRPC_comm_manager_pb2 as gRPC__comm__manager__pb2 +import federatedscope.core.proto.gRPC_comm_manager_pb2 as gRPC__comm__manager__pb2 class gRPCComServeFuncStub(object): diff --git a/federatedscope/example_configs/distributed_client_1.yaml b/federatedscope/example_configs/distributed_client_1.yaml index c1f789830..2463a0f43 100644 --- a/federatedscope/example_configs/distributed_client_1.yaml +++ b/federatedscope/example_configs/distributed_client_1.yaml @@ -6,10 +6,11 @@ federate: make_global_eval: False online_aggr: False distribute: - server_host: 'xx.xx.xx.xx' - server_port: x - client_host: 'xx.xx.xx.xx' - client_port: x + use: True + server_host: '127.0.0.1' + server_port: 50051 + client_host: '127.0.0.1' + client_port: 50052 role: 'client' data_file: 'toy_data/client_1_data' trainer: diff --git a/federatedscope/example_configs/distributed_client_2.yaml b/federatedscope/example_configs/distributed_client_2.yaml index 0a02d9067..d137aba9c 100644 --- a/federatedscope/example_configs/distributed_client_2.yaml +++ b/federatedscope/example_configs/distributed_client_2.yaml @@ -6,10 +6,11 @@ federate: make_global_eval: False online_aggr: False distribute: - server_host: 'xx.xx.xx.xx' - server_port: x - client_host: 'xx.xx.xx.xx' - client_port: x + use: True + server_host: '127.0.0.1' + server_port: 50051 + client_host: '127.0.0.1' + client_port: 50053 role: 'client' data_file: 'toy_data/client_2_data' trainer: diff --git a/federatedscope/example_configs/distributed_server.yaml b/federatedscope/example_configs/distributed_server.yaml index 0155b568b..a99260a20 100644 --- a/federatedscope/example_configs/distributed_server.yaml +++ b/federatedscope/example_configs/distributed_server.yaml @@ -6,8 +6,9 @@ federate: make_global_eval: False online_aggr: False distribute: - server_host: ''xx.xx.xx.xx - server_port: x + use: True + server_host: '127.0.0.1' + server_port: 50051 role: 'server' data_file: 'toy_data/server_data' trainer: diff --git a/scripts/gen_data.py b/scripts/gen_data.py index 8f0ecd81b..a87fe19d0 100644 --- a/scripts/gen_data.py +++ b/scripts/gen_data.py @@ -101,4 +101,4 @@ def generate_data(client_num=2, return data -data = generate_data() \ No newline at end of file +data = generate_data()