Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compression methods #555

Merged
merged 7 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/test_distribute.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ jobs:
- name: Test Distributed (LR on toy with a unified files)
run: |
python scripts/distributed_scripts/gen_data.py
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server_no_data.yaml &
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_server_no_data.yaml distribute.grpc_compression gzip &
sleep 2
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml &
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_1.yaml distribute.grpc_compression gzip &
sleep 2
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml &
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_2.yaml distribute.grpc_compression gzip &
sleep 2
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml
python federatedscope/main.py --cfg scripts/distributed_scripts/distributed_configs/distributed_client_3.yaml distribute.grpc_compression gzip
[ $? -eq 1 ] && exit 1 || echo "Passed"
- name: Test Distributed (LR on toy with multiple files)
run: |
Expand Down
21 changes: 14 additions & 7 deletions federatedscope/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from collections import deque

from federatedscope.core.configs.config import global_cfg
from federatedscope.core.proto import gRPC_comm_manager_pb2, \
gRPC_comm_manager_pb2_grpc
from federatedscope.core.gRPC_server import gRPCComServeFunc
Expand Down Expand Up @@ -106,17 +105,23 @@ class gRPCCommManager(object):
The implementation of gRPCCommManager is referred to the tutorial on
https://grpc.io/docs/languages/python/
"""
def __init__(self, host='0.0.0.0', port='50050', client_num=2):
def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None):
self.host = host
self.port = port
options = [
("grpc.max_send_message_length",
global_cfg.distribute.grpc_max_send_message_length),
("grpc.max_send_message_length", cfg.grpc_max_send_message_length),
("grpc.max_receive_message_length",
global_cfg.distribute.grpc_max_receive_message_length),
("grpc.enable_http_proxy",
global_cfg.distribute.grpc_enable_http_proxy),
cfg.grpc_max_receive_message_length),
("grpc.enable_http_proxy", cfg.grpc_enable_http_proxy),
]

if cfg.grpc_compression.lower() == 'deflate':
self.comp_method = grpc.Compression.Deflate
elif cfg.grpc_compression.lower() == 'gzip':
self.comp_method = grpc.Compression.Gzip
else:
self.comp_method = grpc.Compression.NoCompression

self.server_funcs = gRPCComServeFunc()
self.grpc_server = self.serve(max_workers=client_num,
host=host,
Expand All @@ -132,6 +137,7 @@ def serve(self, max_workers, host, port, options):
"""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=max_workers),
compression=self.comp_method,
options=options)
gRPC_comm_manager_pb2_grpc.add_gRPCComServeFuncServicer_to_server(
self.server_funcs, server)
Expand Down Expand Up @@ -170,6 +176,7 @@ def _create_stub(receiver_address):
https://grpc.io/docs/languages/python/basics/#creating-a-stub
"""
channel = grpc.insecure_channel(receiver_address,
compression=self.comp_method,
options=(('grpc.enable_http_proxy',
0), ))
stub = gRPC_comm_manager_pb2_grpc.gRPCComServeFuncStub(channel)
Expand Down
38 changes: 38 additions & 0 deletions federatedscope/core/compression/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Message compression for efficient communication

We provide plugins of message compression for efficient communication.

## Lossless compression based on gRPC
When running with distributed mode of FederatedScope, the shared messages can be compressed using the compression module provided by gRPC (More details can be found [here](https://chromium.googlesource.com/external/github.com/grpc/grpc/+/HEAD/examples/python/compression/)).

Users can turn on the message compression by adding the following configuration:
```yaml
distribute:
grpc_compression: 'deflate' # or 'gzip'
```

The compression of training ConvNet-2 on FEMNIST is shown as below:

| | NoCompression | Deflate | Gzip |
| :---: | :---: | :---: | :---: |
| Communication bytes per round (in gRPC channel) | 4.021MB | 1.888MB | 1.890MB |


## Model quantization
We provide a symmetric uniform quantization to transform the model parameters (32-bit float) to 8/16-bit int (note that it might bring model performance drop).

To apply model quantization, users need to add the following configurations:
```yaml
quantization:
method: 'uniform'
nbits: 16 # or 8
```

We conduct experiments based on the scripts provided in `federatedscope/cv/baseline/fedavg_convnet2_on_femnist.yaml` and report the results as:

| | 32-bit float (vanilla) | 16-bit int | 8-bit int |
| :---: | :---: | :---: | :---: |
| Shared model size (in memory) | 25.20MB | 12.61MB | 6.31MB |
| Model performance (acc) | 0.7856 | 0.7854 | 0.6807 |

More fancy compression techniques are coming soon! We greatly appreciate contribution to FederatedScope!
6 changes: 6 additions & 0 deletions federatedscope/core/compression/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from federatedscope.core.compression.utils import \
symmetric_uniform_quantization, symmetric_uniform_dequantization

__all__ = [
'symmetric_uniform_quantization', 'symmetric_uniform_dequantization'
]
84 changes: 84 additions & 0 deletions federatedscope/core/compression/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _symmetric_uniform_quantization(x, nbits, stochastic=False):
assert (torch.isnan(x).sum() == 0)
assert (torch.isinf(x).sum() == 0)

c = torch.max(torch.abs(x))
s = c / (2**(nbits - 1) - 1)
if s == 0:
return x, s
c_minus = c * -1.0

# qx = torch.where(x.ge(c), c, x)
# qx = torch.where(qx.le(c_minus), c_minus, qx)
# qx.div_(s)
qx = x / s

if stochastic:
noise = qx.new(qx.shape).uniform_(-0.5, 0.5)
qx.add_(noise)

qx.clamp_(-(2**(nbits - 1) - 1), (2**(nbits - 1) - 1)).round_()
return qx, s


def symmetric_uniform_quantization(state_dict, nbits=8):
"""
Perform symmetric uniform quantization to weight in conv & fc layers
Args:
state_dict: dict of model parameter (torch_model.state_dict)
nbits: the bit of values after quantized, chosen from [8, 16]

Returns:
The quantized model parameters
"""
if nbits == 8:
quant_data_type = torch.int8
elif nbits == 16:
quant_data_type = torch.int16
else:
logger.info(f'The provided value of nbits ({nbits}) is invalid, and we'
f' change it to 8')
nbits = 8
quant_data_type = torch.int8

quant_state_dict = dict()
for key, value in state_dict.items():
if ('fc' in key or 'conv' in key) and 'weight' == key.split('.')[-1]:
q_weight, w_s = _symmetric_uniform_quantization(value, nbits=nbits)
quant_state_dict[key.replace(
'weight', 'weight_quant')] = q_weight.type(quant_data_type)
quant_state_dict[key.replace('weight', 'weight_scale')] = w_s
else:
quant_state_dict[key] = value

return quant_state_dict


def symmetric_uniform_dequantization(state_dict):
"""
Perform symmetric uniform dequantization
Args:
state_dict: dict of model parameter (torch_model.state_dict)

Returns:
The model parameters after dequantization
"""
dequantizated_state_dict = dict()
for key, value in state_dict.items():
if 'weight_quant' in key:
alpha = state_dict[key.replace('weight_quant', 'weight_scale')]
dequantizated_state_dict[key.replace('weight_quant',
'weight')] = value * alpha
elif 'weight_scale' in key:
pass
else:
dequantizated_state_dict[key] = value

return dequantizated_state_dict
38 changes: 38 additions & 0 deletions federatedscope/core/configs/cfg_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config

logger = logging.getLogger(__name__)


def extend_compression_cfg(cfg):
# ---------------------------------------------------------------------- #
# Compression (for communication efficiency) related options
# ---------------------------------------------------------------------- #
cfg.quantization = CN()

# Params
cfg.quantization.method = 'none' # ['none', 'uniform']
cfg.quantization.nbits = 8 # [8,16]

# --------------- register corresponding check function ----------
cfg.register_cfg_check_fun(assert_compression_cfg)


def assert_compression_cfg(cfg):

if cfg.quantization.method.lower() not in ['none', 'uniform']:
logger.warning(
f'Quantization method is expected to be one of ["none",'
f'"uniform"], but got "{cfg.quantization.method}". So we '
f'change it to "none"')

if cfg.quantization.method.lower(
) != 'none' and cfg.quantization.nbits not in [8, 16]:
raise ValueError(f'The value of cfg.quantization.nbits is invalid, '
f'which is expected to be one on [8, 16] but got '
f'{cfg.quantization.nbits}.')


register_config("compression", extend_compression_cfg)
8 changes: 8 additions & 0 deletions federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def extend_fl_setting_cfg(cfg):
cfg.distribute.grpc_max_send_message_length = 100 * 1024 * 1024
cfg.distribute.grpc_max_receive_message_length = 100 * 1024 * 1024
cfg.distribute.grpc_enable_http_proxy = False
cfg.distribute.grpc_compression = 'nocompression' # [deflate, gzip]

# ---------------------------------------------------------------------- #
# Vertical FL related options (for demo)
Expand Down Expand Up @@ -263,5 +264,12 @@ def assert_fl_setting_cfg(cfg):
f'must be in (0, 1.0], but got '
f'{cfg.vertical.feature_subsample_ratio}')

if cfg.distribute.use and cfg.distribute.grpc_compression.lower() not in [
'nocompression', 'deflate', 'gzip'
]:
raise ValueError(f'The type of grpc compression is expected to be one '
f'of ["nocompression", "deflate", "gzip"], but got '
f'{cfg.distribute.grpc_compression}.')


register_config("fl_setting", extend_fl_setting_cfg)
31 changes: 30 additions & 1 deletion federatedscope/core/workers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def __init__(self,
server_host = kwargs['server_host']
server_port = kwargs['server_port']
self.comm_manager = gRPCCommManager(
host=host, port=port, client_num=self._cfg.federate.client_num)
host=host,
port=port,
client_num=self._cfg.federate.client_num,
cfg=self._cfg.distribute)
logger.info('Client: Listen to {}:{}...'.format(host, port))
self.comm_manager.add_neighbors(neighbor_id=server_id,
address={
Expand Down Expand Up @@ -291,6 +294,18 @@ def callback_funcs_for_model_para(self, message: Message):
sender = message.sender
timestamp = message.timestamp
content = message.content

# dequantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_dequantization
if isinstance(content, list): # multiple model
content = [
symmetric_uniform_dequantization(x) for x in content
]
else:
content = symmetric_uniform_dequantization(content)

# When clients share the local model, we must set strict=True to
# ensure all the model params (which might be updated by other
# clients in the previous local training process) are overwritten
Expand Down Expand Up @@ -394,6 +409,20 @@ def callback_funcs_for_model_para(self, message: Message):
else:
shared_model_para = model_para_all

# quantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_quantization
nbits = self._cfg.quantization.nbits
if isinstance(shared_model_para, list):
shared_model_para = [
symmetric_uniform_quantization(x, nbits)
for x in shared_model_para
]
else:
shared_model_para = symmetric_uniform_quantization(
shared_model_para, nbits)

self.comm_manager.send(
Message(msg_type='model_para',
sender=self.ID,
Expand Down
34 changes: 30 additions & 4 deletions federatedscope/core/workers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ def __init__(self,
port = kwargs['port']
self.comm_manager = gRPCCommManager(host=host,
port=port,
client_num=client_num)
client_num=client_num,
cfg=self._cfg.distribute)
logger.info('Server: Listen to {}:{}...'.format(host, port))

# inject noise before broadcast
Expand Down Expand Up @@ -666,6 +667,20 @@ def broadcast_model_para(self,
else:
model_para = {} if skip_broadcast else self.models[0].state_dict()

# quantization
if msg_type == 'model_para' and not skip_broadcast and \
self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_quantization
nbits = self._cfg.quantization.nbits
if self.model_num > 1:
model_para = [
symmetric_uniform_quantization(x, nbits)
for x in model_para
]
else:
model_para = symmetric_uniform_quantization(model_para, nbits)

# We define the evaluation happens at the end of an epoch
rnd = self.state - 1 if msg_type == 'evaluate' else self.state

Expand Down Expand Up @@ -815,9 +830,6 @@ def trigger_for_start(self):
logger.info(
'----------- Starting training (Round #{:d}) -------------'.
format(self.state))
print(
time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(time.time())))

def trigger_for_feat_engr(self,
trigger_train_func,
Expand Down Expand Up @@ -920,6 +932,20 @@ def callback_funcs_model_para(self, message: Message):
content = message.content
self.sampler.change_state(sender, 'idle')

# dequantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_dequantization
if isinstance(content[1], list): # multiple model
sample_size = content[0]
quant_model = [
symmetric_uniform_dequantization(x) for x in content[1]
]
else:
sample_size = content[0]
quant_model = symmetric_uniform_dequantization(content[1])
content = (sample_size, quant_model)

# update the currency timestamp according to the received message
assert timestamp >= self.cur_timestamp # for test
self.cur_timestamp = timestamp
Expand Down
Loading