From 0570a7ed5073317b801afd3ede4c3f8bea897a90 Mon Sep 17 00:00:00 2001 From: "yuexiang.xyx" Date: Mon, 27 Mar 2023 16:25:44 +0800 Subject: [PATCH 1/6] add grpc compression --- federatedscope/core/communication.py | 19 +++++++++++++------ federatedscope/core/configs/cfg_fl_setting.py | 8 ++++++++ federatedscope/core/workers/client.py | 5 ++++- federatedscope/core/workers/server.py | 3 ++- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/federatedscope/core/communication.py b/federatedscope/core/communication.py index e2d195751..ad1c41d14 100644 --- a/federatedscope/core/communication.py +++ b/federatedscope/core/communication.py @@ -106,16 +106,14 @@ class gRPCCommManager(object): The implementation of gRPCCommManager is referred to the tutorial on https://grpc.io/docs/languages/python/ """ - def __init__(self, host='0.0.0.0', port='50050', client_num=2): + def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None): self.host = host self.port = port options = [ - ("grpc.max_send_message_length", - global_cfg.distribute.grpc_max_send_message_length), + ("grpc.max_send_message_length", cfg.grpc_max_send_message_length), ("grpc.max_receive_message_length", - global_cfg.distribute.grpc_max_receive_message_length), - ("grpc.enable_http_proxy", - global_cfg.distribute.grpc_enable_http_proxy), + cfg.grpc_max_receive_message_length), + ("grpc.enable_http_proxy", cfg.grpc_enable_http_proxy), ] self.server_funcs = gRPCComServeFunc() self.grpc_server = self.serve(max_workers=client_num, @@ -125,6 +123,13 @@ def __init__(self, host='0.0.0.0', port='50050', client_num=2): self.neighbors = dict() self.monitor = None # used to track the communication related metrics + if cfg.compression.lower() == 'deflate': + self.comp_method = grpc.Compression.Deflate + elif cfg.compression.lower() == 'gzip': + self.comp_method = grpc.Compression.Gzip + else: + self.comp_method = grpc.Compression.NoCompression + def serve(self, max_workers, host, port, options): """ This function is referred to @@ -132,6 +137,7 @@ def serve(self, max_workers, host, port, options): """ server = grpc.server( futures.ThreadPoolExecutor(max_workers=max_workers), + compression=self.comp_method, options=options) gRPC_comm_manager_pb2_grpc.add_gRPCComServeFuncServicer_to_server( self.server_funcs, server) @@ -170,6 +176,7 @@ def _create_stub(receiver_address): https://grpc.io/docs/languages/python/basics/#creating-a-stub """ channel = grpc.insecure_channel(receiver_address, + compression=self.comp_method, options=(('grpc.enable_http_proxy', 0), )) stub = gRPC_comm_manager_pb2_grpc.gRPCComServeFuncStub(channel) diff --git a/federatedscope/core/configs/cfg_fl_setting.py b/federatedscope/core/configs/cfg_fl_setting.py index 346930f4a..bb7f8adf5 100644 --- a/federatedscope/core/configs/cfg_fl_setting.py +++ b/federatedscope/core/configs/cfg_fl_setting.py @@ -75,6 +75,7 @@ def extend_fl_setting_cfg(cfg): cfg.distribute.grpc_max_send_message_length = 100 * 1024 * 1024 cfg.distribute.grpc_max_receive_message_length = 100 * 1024 * 1024 cfg.distribute.grpc_enable_http_proxy = False + cfg.distribute.grpc_compression = 'nocompression' # [deflate, gzip] # ---------------------------------------------------------------------- # # Vertical FL related options (for demo) @@ -263,5 +264,12 @@ def assert_fl_setting_cfg(cfg): f'must be in (0, 1.0], but got ' f'{cfg.vertical.feature_subsample_ratio}') + if cfg.distribute.grpc_compression.lower() not in [ + 'nocompression', 'deflate', 'gzip' + ]: + raise ValueError(f'The type of grpc 
compression is expected to ' + f'be one of ["nocompression", "deflate", "gzip"]' + f' but got {cfg.distribure.grpc_compression}.') + register_config("fl_setting", extend_fl_setting_cfg) diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py index e17ff228d..2320aef0c 100644 --- a/federatedscope/core/workers/client.py +++ b/federatedscope/core/workers/client.py @@ -164,7 +164,10 @@ def __init__(self, server_host = kwargs['server_host'] server_port = kwargs['server_port'] self.comm_manager = gRPCCommManager( - host=host, port=port, client_num=self._cfg.federate.client_num) + host=host, + port=port, + client_num=self._cfg.federate.client_num, + cfg=self._cfg.distribute) logger.info('Client: Listen to {}:{}...'.format(host, port)) self.comm_manager.add_neighbors(neighbor_id=server_id, address={ diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 6d86d54d2..0e25870a7 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -213,7 +213,8 @@ def __init__(self, port = kwargs['port'] self.comm_manager = gRPCCommManager(host=host, port=port, - client_num=client_num) + client_num=client_num, + cfg=self._cfg.distribute) logger.info('Server: Listen to {}:{}...'.format(host, port)) # inject noise before broadcast From f6b765d8bf7303e9dca9439e83e655958000b12b Mon Sep 17 00:00:00 2001 From: "yuexiang.xyx" Date: Mon, 27 Mar 2023 17:15:56 +0800 Subject: [PATCH 2/6] minor fix --- federatedscope/core/communication.py | 15 ++++++++------- federatedscope/core/configs/cfg_fl_setting.py | 12 ++++++------ federatedscope/core/workers/server.py | 3 --- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/federatedscope/core/communication.py b/federatedscope/core/communication.py index ad1c41d14..d7e3cb398 100644 --- a/federatedscope/core/communication.py +++ b/federatedscope/core/communication.py @@ -115,6 +115,14 @@ def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None): cfg.grpc_max_receive_message_length), ("grpc.enable_http_proxy", cfg.grpc_enable_http_proxy), ] + + if cfg.grpc_compression.lower() == 'deflate': + self.comp_method = grpc.Compression.Deflate + elif cfg.grpc_compression.lower() == 'gzip': + self.comp_method = grpc.Compression.Gzip + else: + self.comp_method = grpc.Compression.NoCompression + self.server_funcs = gRPCComServeFunc() self.grpc_server = self.serve(max_workers=client_num, host=host, @@ -123,13 +131,6 @@ def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None): self.neighbors = dict() self.monitor = None # used to track the communication related metrics - if cfg.compression.lower() == 'deflate': - self.comp_method = grpc.Compression.Deflate - elif cfg.compression.lower() == 'gzip': - self.comp_method = grpc.Compression.Gzip - else: - self.comp_method = grpc.Compression.NoCompression - def serve(self, max_workers, host, port, options): """ This function is referred to diff --git a/federatedscope/core/configs/cfg_fl_setting.py b/federatedscope/core/configs/cfg_fl_setting.py index bb7f8adf5..86e85af50 100644 --- a/federatedscope/core/configs/cfg_fl_setting.py +++ b/federatedscope/core/configs/cfg_fl_setting.py @@ -264,12 +264,12 @@ def assert_fl_setting_cfg(cfg): f'must be in (0, 1.0], but got ' f'{cfg.vertical.feature_subsample_ratio}') - if cfg.distribute.grpc_compression.lower() not in [ - 'nocompression', 'deflate', 'gzip' - ]: - raise ValueError(f'The type of grpc compression is expected to ' - f'be one of 
["nocompression", "deflate", "gzip"]' - f' but got {cfg.distribure.grpc_compression}.') + if cfg.distribute.grpc_compression.lower() not in [ + 'nocompression', 'deflate', 'gzip' + ]: + raise ValueError(f'The type of grpc compression is expected to be one ' + f'of ["nocompression", "deflate", "gzip"], but got ' + f'{cfg.distribute.grpc_compression}.') register_config("fl_setting", extend_fl_setting_cfg) diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 0e25870a7..724aa33f6 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -816,9 +816,6 @@ def trigger_for_start(self): logger.info( '----------- Starting training (Round #{:d}) -------------'. format(self.state)) - print( - time.strftime('%Y-%m-%d %H:%M:%S', - time.localtime(time.time()))) def trigger_for_feat_engr(self, trigger_train_func, From 9d3f2cb0cf57188b7d295374a78bcbfd39fffef4 Mon Sep 17 00:00:00 2001 From: "yuexiang.xyx" Date: Tue, 28 Mar 2023 15:01:08 +0800 Subject: [PATCH 3/6] add uniform quantization --- federatedscope/core/compression/README.md | 38 +++++++++ federatedscope/core/compression/__init__.py | 6 ++ federatedscope/core/compression/utils.py | 84 +++++++++++++++++++ .../core/configs/cfg_compression.py | 38 +++++++++ federatedscope/core/workers/client.py | 26 ++++++ federatedscope/core/workers/server.py | 28 +++++++ 6 files changed, 220 insertions(+) create mode 100644 federatedscope/core/compression/README.md create mode 100644 federatedscope/core/compression/__init__.py create mode 100644 federatedscope/core/compression/utils.py create mode 100644 federatedscope/core/configs/cfg_compression.py diff --git a/federatedscope/core/compression/README.md b/federatedscope/core/compression/README.md new file mode 100644 index 000000000..39a9d90c3 --- /dev/null +++ b/federatedscope/core/compression/README.md @@ -0,0 +1,38 @@ +# Message compression for efficient communication + +We provide plugins of message compression for efficient communication. + +## Lossless compression based on gRPC +When running with distributed mode of FederatedScope, the shared messages can be compressed using the compression module provided by gRPC (More details can be found [here](https://chromium.googlesource.com/external/github.com/grpc/grpc/+/HEAD/examples/python/compression/)). + +Users can turn on the message compression by adding the following configuration: +```yaml +distribute: + grpc_compression: 'deflate' # or 'gzip' +``` + +The compression of training ConvNet-2 on FEMNIST is shown as below: + +| | NoCompression | Deflate | Gzip | +| :---: | :---: | :---: | :---: | +| Communication bytes per round (in gRPC channel) | 4.021MB | 1.888MB | 1.890MB | + + +## Model quantization +We provide a symmetric uniform quantization to transform the model parameters (32-bit float) to 8/16-bit int (note that it might bring model performance drop). + +To apply model quantization, users need to add the following configurations: +```yaml +quantization: + method: 'uniform' + nbits: 16 # or 8 +``` + +We conduct experiments based on the scripts provided in `federatedscope/cv/baseline/fedavg_convnet2_on_femnist.yaml` and report the results as: + +| | 32-bit float (vanilla) | 16-bit int | 8-bit int | +| :---: | :---: | :---: | :---: | +| Shared model size (in memory) | 25.20MB | 12.61MB | 6.31MB | +| Model performance (acc) | 0.7856 | 0.7854 | 0.6807 | + +More fancy compression techniques are coming soon! We greatly appreciate contribution to FederatedScope! 
diff --git a/federatedscope/core/compression/__init__.py b/federatedscope/core/compression/__init__.py new file mode 100644 index 000000000..4d2e4846f --- /dev/null +++ b/federatedscope/core/compression/__init__.py @@ -0,0 +1,6 @@ +from federatedscope.core.compression.utils import \ + symmetric_uniform_quantization, symmetric_uniform_dequantization + +__all__ = [ + 'symmetric_uniform_quantization', 'symmetric_uniform_dequantization' +] diff --git a/federatedscope/core/compression/utils.py b/federatedscope/core/compression/utils.py new file mode 100644 index 000000000..d559f317b --- /dev/null +++ b/federatedscope/core/compression/utils.py @@ -0,0 +1,84 @@ +import torch +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def _symmetric_uniform_quantization(x, nbits, stochastic=False): + assert (torch.isnan(x).sum() == 0) + assert (torch.isinf(x).sum() == 0) + + c = torch.max(torch.abs(x)) + s = c / (2**(nbits - 1) - 1) + if s == 0: + return x, s + c_minus = c * -1.0 + + # qx = torch.where(x.ge(c), c, x) + # qx = torch.where(qx.le(c_minus), c_minus, qx) + # qx.div_(s) + qx = x / s + + if stochastic: + noise = qx.new(qx.shape).uniform_(-0.5, 0.5) + qx.add_(noise) + + qx.clamp_(-(2**(nbits - 1) - 1), (2**(nbits - 1) - 1)).round_() + return qx, s + + +def symmetric_uniform_quantization(state_dict, nbits=8): + """ + Perform symmetric uniform quantization to weight in conv & fc layers + Args: + state_dict: dict of model parameter (torch_model.state_dict) + nbits: the bit of values after quantized, chosen from [8, 16] + + Returns: + The quantized model parameters + """ + if nbits == 8: + quant_data_type = torch.int8 + elif nbits == 16: + quant_data_type = torch.int16 + else: + logger.info(f'The provided value of nbits ({nbits}) is invalid, and we' + f' change it to 8') + nbits = 8 + quant_data_type = torch.int8 + + quant_state_dict = dict() + for key, value in state_dict.items(): + if ('fc' in key or 'conv' in key) and 'weight' == key.split('.')[-1]: + q_weight, w_s = _symmetric_uniform_quantization(value, nbits=nbits) + quant_state_dict[key.replace( + 'weight', 'weight_quant')] = q_weight.type(quant_data_type) + quant_state_dict[key.replace('weight', 'weight_scale')] = w_s + else: + quant_state_dict[key] = value + + return quant_state_dict + + +def symmetric_uniform_dequantization(state_dict): + """ + Perform symmetric uniform dequantization + Args: + state_dict: dict of model parameter (torch_model.state_dict) + + Returns: + The model parameters after dequantization + """ + dequantizated_state_dict = dict() + for key, value in state_dict.items(): + if 'weight_quant' in key: + alpha = state_dict[key.replace('weight_quant', 'weight_scale')] + dequantizated_state_dict[key.replace('weight_quant', + 'weight')] = value * alpha + elif 'weight_scale' in key: + pass + else: + dequantizated_state_dict[key] = value + + return dequantizated_state_dict diff --git a/federatedscope/core/configs/cfg_compression.py b/federatedscope/core/configs/cfg_compression.py new file mode 100644 index 000000000..c4d90c610 --- /dev/null +++ b/federatedscope/core/configs/cfg_compression.py @@ -0,0 +1,38 @@ +import logging + +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + +logger = logging.getLogger(__name__) + + +def extend_compression_cfg(cfg): + # ---------------------------------------------------------------------- # + # Compression (for communication efficiency) related options + # 
---------------------------------------------------------------------- # + cfg.quantization = CN() + + # Params + cfg.quantization.method = 'none' # ['none', 'uniform'] + cfg.quantization.nbits = 8 # [8,16] + + # --------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_compression_cfg) + + +def assert_compression_cfg(cfg): + + if cfg.quantization.method.lower() not in ['none', 'uniform']: + logger.warning( + f'Quantization method is expected to be one of ["none",' + f'"uniform"], but got "{cfg.quantization.method}". So we ' + f'change it to "none"') + + if cfg.quantization.method.lower( + ) != 'none' and cfg.quantization.nbits not in [8, 16]: + raise ValueError(f'The value of cfg.quantization.nbits is invalid, ' + f'which is expected to be one on [8, 16] but got ' + f'{cfg.quantization.nbits}.') + + +register_config("compression", extend_compression_cfg) diff --git a/federatedscope/core/workers/client.py b/federatedscope/core/workers/client.py index 2320aef0c..2a6ddaf65 100644 --- a/federatedscope/core/workers/client.py +++ b/federatedscope/core/workers/client.py @@ -294,6 +294,18 @@ def callback_funcs_for_model_para(self, message: Message): sender = message.sender timestamp = message.timestamp content = message.content + + # dequantization + if self._cfg.quantization.method == 'uniform': + from federatedscope.core.compression import \ + symmetric_uniform_dequantization + if isinstance(content, list): # multiple model + content = [ + symmetric_uniform_dequantization(x) for x in content + ] + else: + content = symmetric_uniform_dequantization(content) + # When clients share the local model, we must set strict=True to # ensure all the model params (which might be updated by other # clients in the previous local training process) are overwritten @@ -398,6 +410,20 @@ def callback_funcs_for_model_para(self, message: Message): else: shared_model_para = model_para_all + # quantization + if self._cfg.quantization.method == 'uniform': + from federatedscope.core.compression import \ + symmetric_uniform_quantization + nbits = self._cfg.quantization.nbits + if isinstance(shared_model_para, list): + shared_model_para = [ + symmetric_uniform_quantization(x, nbits) + for x in shared_model_para + ] + else: + shared_model_para = symmetric_uniform_quantization( + shared_model_para, nbits) + self.comm_manager.send( Message(msg_type='model_para', sender=self.ID, diff --git a/federatedscope/core/workers/server.py b/federatedscope/core/workers/server.py index 724aa33f6..65ef0ff68 100644 --- a/federatedscope/core/workers/server.py +++ b/federatedscope/core/workers/server.py @@ -667,6 +667,20 @@ def broadcast_model_para(self, else: model_para = {} if skip_broadcast else self.models[0].state_dict() + # quantization + if msg_type == 'model_para' and not skip_broadcast and \ + self._cfg.quantization.method == 'uniform': + from federatedscope.core.compression import \ + symmetric_uniform_quantization + nbits = self._cfg.quantization.nbits + if self.model_num > 1: + model_para = [ + symmetric_uniform_quantization(x, nbits) + for x in model_para + ] + else: + model_para = symmetric_uniform_quantization(model_para, nbits) + # We define the evaluation happens at the end of an epoch rnd = self.state - 1 if msg_type == 'evaluate' else self.state @@ -918,6 +932,20 @@ def callback_funcs_model_para(self, message: Message): content = message.content self.sampler.change_state(sender, 'idle') + # dequantization + if self._cfg.quantization.method == 'uniform': + from 
federatedscope.core.compression import \ + symmetric_uniform_dequantization + if isinstance(content[1], list): # multiple model + sample_size = content[0] + quant_model = [ + symmetric_uniform_dequantization(x) for x in content[1] + ] + else: + sample_size = content[0] + quant_model = symmetric_uniform_dequantization(content[1]) + content = (sample_size, quant_model) + # update the currency timestamp according to the received message assert timestamp >= self.cur_timestamp # for test self.cur_timestamp = timestamp From c8b00997dda0b0afa8bfd41eb94ae68a5f6f4bcc Mon Sep 17 00:00:00 2001 From: "yuexiang.xyx" Date: Tue, 28 Mar 2023 15:13:20 +0800 Subject: [PATCH 4/6] add unittest --- .github/workflows/test_distribute.yml | 10 ++++ federatedscope/core/communication.py | 1 - tests/test_femnist.py | 67 +++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_distribute.yml b/.github/workflows/test_distribute.yml index 67905345f..19634ef61 100644 --- a/.github/workflows/test_distribute.yml +++ b/.github/workflows/test_distribute.yml @@ -42,4 +42,14 @@ jobs: python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml & sleep 2 python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml + [ $? -eq 1 ] && exit 1 || echo "Passed" + - name: Test Distributed (FEMNIST on ConvNet with gzip compression) + run: | + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_femnist_server.yaml distribute.grpc_compression gzip & + sleep 2 + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_femnist_client_1.yaml distribute.grpc_compression gzip & + sleep 2 + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_femnist_client_2.yaml distribute.grpc_compression gzip & + sleep 2 + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_femnist_client_3.yaml distribute.grpc_compression gzip [ $? 
-eq 1 ] && exit 1 || echo "Passed" \ No newline at end of file diff --git a/federatedscope/core/communication.py b/federatedscope/core/communication.py index d7e3cb398..43433962c 100644 --- a/federatedscope/core/communication.py +++ b/federatedscope/core/communication.py @@ -5,7 +5,6 @@ from collections import deque -from federatedscope.core.configs.config import global_cfg from federatedscope.core.proto import gRPC_comm_manager_pb2, \ gRPC_comm_manager_pb2_grpc from federatedscope.core.gRPC_server import gRPCComServeFunc diff --git a/tests/test_femnist.py b/tests/test_femnist.py index 99e5b989e..efc1d1d8a 100644 --- a/tests/test_femnist.py +++ b/tests/test_femnist.py @@ -55,6 +55,49 @@ def set_config_femnist(self, cfg): return backup_cfg + def set_config_femnist_quant(self, cfg): + backup_cfg = cfg.clone() + + import torch + cfg.use_gpu = torch.cuda.is_available() + cfg.eval.freq = 10 + cfg.eval.metrics = ['acc', 'loss_regular'] + + cfg.federate.mode = 'standalone' + cfg.train.local_update_steps = 5 + cfg.federate.total_round_num = 20 + cfg.federate.sample_client_num = SAMPLE_CLIENT_NUM + + cfg.data.root = 'test_data/' + cfg.data.type = 'femnist' + cfg.data.splits = [0.6, 0.2, 0.2] + cfg.data.batch_size = 10 + cfg.data.subsample = 0.05 + cfg.data.transform = [['ToTensor'], + [ + 'Normalize', { + 'mean': [0.9637], + 'std': [0.1592] + } + ]] + + cfg.model.type = 'convnet2' + cfg.model.hidden = 2048 + cfg.model.out_channels = 62 + + cfg.train.optimizer.lr = 0.001 + cfg.train.optimizer.weight_decay = 0.0 + cfg.grad.grad_clip = 5.0 + + cfg.criterion.type = 'CrossEntropyLoss' + cfg.trainer.type = 'cvtrainer' + cfg.seed = 123 + + cfg.quantization.method = 'uniform' + cfg.quantization.ubits = 16 + + return backup_cfg + def test_femnist_standalone(self): init_cfg = global_cfg.clone() backup_cfg = self.set_config_femnist(init_cfg) @@ -79,6 +122,30 @@ def test_femnist_standalone(self): test_best_results["client_summarized_weighted_avg"]['test_loss'], 600) + def test_femnist_quantization_standalone(self): + init_cfg = global_cfg.clone() + backup_cfg = self.set_config_femnist_quant(init_cfg) + setup_seed(init_cfg.seed) + update_logger(init_cfg, True) + + data, modified_cfg = get_data(init_cfg.clone()) + init_cfg.merge_from_other_cfg(modified_cfg) + self.assertIsNotNone(data) + self.assertEqual(init_cfg.federate.sample_client_num, + SAMPLE_CLIENT_NUM) + + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) + self.assertIsNotNone(Fed_runner) + test_best_results = Fed_runner.run() + print(test_best_results) + init_cfg.merge_from_other_cfg(backup_cfg) + self.assertLess( + test_best_results["client_summarized_weighted_avg"]['test_loss'], + 600) + if __name__ == '__main__': unittest.main() From 4658e55b2deeaca131125d8d846b979a69a81c14 Mon Sep 17 00:00:00 2001 From: "yuexiang.xyx" Date: Tue, 28 Mar 2023 15:19:56 +0800 Subject: [PATCH 5/6] update unittest --- .github/workflows/test_distribute.yml | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/.github/workflows/test_distribute.yml b/.github/workflows/test_distribute.yml index 19634ef61..01aa21e54 100644 --- a/.github/workflows/test_distribute.yml +++ b/.github/workflows/test_distribute.yml @@ -35,21 +35,11 @@ jobs: - name: Test Distributed (LR on toy) run: | python scripts/distributed_scripts/gen_data.py - python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server.yaml & + python 
federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server.yaml distribute.grpc_compression gzip & sleep 2 - python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml & + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml distribute.grpc_compression gzip & sleep 2 - python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml & + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml distribute.grpc_compression gzip & sleep 2 - python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml - [ $? -eq 1 ] && exit 1 || echo "Passed" - - name: Test Distributed (FEMNIST on ConvNet with gzip compression) - run: | - python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_femnist_server.yaml distribute.grpc_compression gzip & - sleep 2 - python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_femnist_client_1.yaml distribute.grpc_compression gzip & - sleep 2 - python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_femnist_client_2.yaml distribute.grpc_compression gzip & - sleep 2 - python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_femnist_client_3.yaml distribute.grpc_compression gzip + python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml distribute.grpc_compression gzip [ $? -eq 1 ] && exit 1 || echo "Passed" \ No newline at end of file From 85ce68056e22d4a09ffd5e61ba6af79a90d37be6 Mon Sep 17 00:00:00 2001 From: "yuexiang.xyx" Date: Tue, 28 Mar 2023 21:55:28 +0800 Subject: [PATCH 6/6] minor fix --- federatedscope/core/configs/cfg_fl_setting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federatedscope/core/configs/cfg_fl_setting.py b/federatedscope/core/configs/cfg_fl_setting.py index 86e85af50..f684dec8a 100644 --- a/federatedscope/core/configs/cfg_fl_setting.py +++ b/federatedscope/core/configs/cfg_fl_setting.py @@ -264,7 +264,7 @@ def assert_fl_setting_cfg(cfg): f'must be in (0, 1.0], but got ' f'{cfg.vertical.feature_subsample_ratio}') - if cfg.distribute.grpc_compression.lower() not in [ + if cfg.distribute.use and cfg.distribute.grpc_compression.lower() not in [ 'nocompression', 'deflate', 'gzip' ]: raise ValueError(f'The type of grpc compression is expected to be one '
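Taken together, the series exposes two orthogonal knobs: lossless gRPC channel compression (patches 1 and 2, exercised via the `distribute.grpc_compression gzip` overrides in the CI workflow above) and symmetric uniform quantization of the shared model parameters (patch 3). A minimal closing sketch, not taken from the series itself, shows how both could be enabled programmatically on a cloned config; it mirrors the YAML snippets in the compression README and the CLI overrides above:

```python
# Sketch only (not part of the series): toggle the new options on a cloned
# global config. distribute.use pre-exists; the other keys are added above.
from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()

# lossless compression of the gRPC channel (patches 1-2)
cfg.distribute.use = True
cfg.distribute.grpc_compression = 'gzip'   # or 'deflate' / 'nocompression'

# symmetric uniform quantization of the shared model parameters (patch 3)
cfg.quantization.method = 'uniform'
cfg.quantization.nbits = 16                # or 8 for a smaller payload

# Patch 6 restricts the grpc_compression sanity check to the case
# cfg.distribute.use == True, so standalone runs are unaffected.
```

The channel compression is transparent to training, while quantization trades payload size against a possible accuracy drop, as the FEMNIST numbers in the README indicate.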