Skip to content

Commit

Permalink
Merge pull request #1 from APPFL/zilinghan/compressor
Browse files Browse the repository at this point in the history
Zilinghan/compressor
  • Loading branch information
Zilinghan authored Mar 14, 2024
2 parents 64b5ed7 + 40521ab commit 9ec22fd
Show file tree
Hide file tree
Showing 23 changed files with 857 additions and 59 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ __pycache__/
*.egg-info/
*output
*outputs
*RawData
*RawData
.compressor
15 changes: 4 additions & 11 deletions examples/config/client_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,7 @@ data_configs:
output_filename: "visualization.pdf"

comm_configs:
server_uri: localhost:50051
max_message_size: 1048576

use_ssl: False
# # SSL configurations
# use_ssl: True
# use_authenticator: True
# root_certificate: ../src/appfl/comm/grpc/credentials/root.crt
# authenticator: NaiveAuthenticator
# authenticator_args:
# auth_token: a-secret-token
grpc_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
15 changes: 4 additions & 11 deletions examples/config/client_2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,7 @@ data_configs:
output_filename: "visualization.pdf"

comm_configs:
server_uri: localhost:50051
max_message_size: 1048576

use_ssl: False
# # SSL configurations
# use_ssl: True
# use_authenticator: True
# root_certificate: ../src/appfl/comm/grpc/credentials/root.crt
# authenticator: NaiveAuthenticator
# authenticator_args:
# auth_token: a-secret-token
grpc_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
17 changes: 13 additions & 4 deletions examples/config/server_fedasync.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@ client_configs:
num_pixel: 28

comm_configs:
enable_compression: False
compressor_configs:
enable_compression: False
# Used if enable_compression is True
lossy_compressor: "SZ2"
lossless_compressor: "blosc"
error_bounding_mode: "REL"
error_bound: 1e-3
flat_model_dtype: "np.float32"
param_cutoff: 1024

server_configs:
scheduler: "AsyncScheduler"
Expand All @@ -61,6 +69,7 @@ server_configs:
logging_output_dirname: "./output"
logging_output_filename: "result"
comm_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
grpc_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
17 changes: 13 additions & 4 deletions examples/config/server_fedavg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ client_configs:
num_pixel: 28

comm_configs:
enable_compression: False
compressor_configs:
enable_compression: False
# Used if enable_compression is True
lossy_compressor: "SZ2"
lossless_compressor: "blosc"
error_bounding_mode: "REL"
error_bound: 1e-3
flat_model_dtype: "np.float32"
param_cutoff: 1024

server_configs:
scheduler: "SyncScheduler"
Expand All @@ -53,6 +61,7 @@ server_configs:
logging_output_dirname: "./output"
logging_output_filename: "result"
comm_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
grpc_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
17 changes: 13 additions & 4 deletions examples/config/server_fedcompass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@ client_configs:
num_pixel: 28

comm_configs:
enable_compression: False
compressor_configs:
enable_compression: False
# Used if enable_compression is True
lossy_compressor: "SZ2"
lossless_compressor: "blosc"
error_bounding_mode: "REL"
error_bound: 1e-3
flat_model_dtype: "np.float32"
param_cutoff: 1024

server_configs:
scheduler: "CompassScheduler"
Expand All @@ -65,6 +73,7 @@ server_configs:
logging_output_dirname: "./output"
logging_output_filename: "result"
comm_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
grpc_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
3 changes: 1 addition & 2 deletions examples/grpc/run_client_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
client_comm = GRPCClientCommunicator(
client_id = client_agent.get_id(),
**client_agent_config.comm_configs,
**client_agent_config.comm_configs.grpc_configs,
)

client_config = client_comm.get_configuration()
Expand All @@ -20,7 +20,6 @@

# Send the number of local data to the server
sample_size = client_agent.get_sample_size()
print(f"Sample size: {sample_size}")
client_comm.invoke_custom_action(action='set_sample_size', sample_size=sample_size)

for i in range(10):
Expand Down
3 changes: 1 addition & 2 deletions examples/grpc/run_client_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
client_comm = GRPCClientCommunicator(
client_id = client_agent.get_id(),
**client_agent_config.comm_configs,
**client_agent_config.comm_configs.grpc_configs,
)

client_config = client_comm.get_configuration()
Expand All @@ -20,7 +20,6 @@

# Send the number of local data to the server
sample_size = client_agent.get_sample_size()
print(f"Sample size: {sample_size}")
client_comm.invoke_custom_action(action='set_sample_size', sample_size=sample_size)

for i in range(10):
Expand Down
4 changes: 2 additions & 2 deletions examples/grpc/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

communicator = GRPCServerCommunicator(
server_agent,
max_message_size=server_agent_config.server_configs.comm_configs.max_message_size,
max_message_size=server_agent_config.server_configs.comm_configs.grpc_configs.max_message_size,
)

serve(
communicator,
**server_agent_config.server_configs.comm_configs,
**server_agent_config.server_configs.comm_configs.grpc_configs,
)
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,17 @@
"matplotlib",
"torchvision",
"globus-sdk",
"zfpy",
"blosc",
"zstd",
"scipy",
"lz4",
"python-xz",
],
entry_points={
"console_scripts": [
"appfl-auth=appfl.login_manager.globus.cli:auth",
"appfl-install-compressor=appfl.compressor.install:install_compressor",
],
},
)
34 changes: 30 additions & 4 deletions src/appfl/agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import importlib
import torch.nn as nn
from appfl.trainer import BaseTrainer
from appfl.compressor import Compressor
from appfl.config import ClientAgentConfig
from omegaconf import DictConfig, OmegaConf
from typing import Union, Dict, OrderedDict
from typing import Union, Dict, OrderedDict, Tuple
from appfl.logger import ClientAgentFileLogger
from appfl.misc import create_instance_from_file, \
run_function_from_file, \
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(
self._load_metric()
self._load_data()
self._load_trainer()
self._load_compressor()

def load_config(self, config: DictConfig) -> None:
"""Load additional configurations provided by the server."""
Expand All @@ -46,6 +48,7 @@ def load_config(self, config: DictConfig) -> None:
self._load_loss()
self._load_metric()
self._load_trainer()
self._load_compressor()

def get_id(self) -> str:
"""Return a unique client id for server to distinguish clients."""
Expand All @@ -61,10 +64,16 @@ def train(self) -> None:
"""Train the model locally."""
self.trainer.train()

def get_parameters(self) -> Union[Dict, OrderedDict, bytes]:
def get_parameters(self) -> Union[Dict, OrderedDict, bytes, Tuple[Union[Dict, OrderedDict, bytes], Dict]]:
"""Return parameters for communication"""
params = self.trainer.get_parameters()
return params
if isinstance(params, tuple):
params, metadata = params
else:
metadata = None
if self.enable_compression:
params = self.compressor.compress_model(params)
return params if metadata is None else (params, metadata)

def load_parameters(self, params) -> None:
"""Load parameters from the server."""
Expand Down Expand Up @@ -200,4 +209,21 @@ def _load_trainer(self) -> None:
train_configs=self.client_agent_config.train_configs,
logger=self.logger,
)


def _load_compressor(self) -> None:
    """
    Set up the compressor used for the model parameters.

    A compressor that already exists on the agent is kept untouched.
    Otherwise `self.compressor` is reset to `None` and
    `self.enable_compression` to `False`, and a `Compressor` is built only
    when `client_agent_config.comm_configs.compressor_configs` exists and its
    `enable_compression` entry is truthy.
    """
    # This method is invoked both at construction time and from
    # `load_config`; never rebuild a compressor that is already in place.
    if getattr(self, "compressor", None) is not None:
        return

    self.compressor = None
    self.enable_compression = False

    agent_config = self.client_agent_config
    has_compressor_configs = (
        hasattr(agent_config, "comm_configs")
        and hasattr(agent_config.comm_configs, "compressor_configs")
    )
    if not has_compressor_configs:
        return

    compressor_configs = agent_config.comm_configs.compressor_configs
    if not getattr(compressor_configs, "enable_compression", False):
        return

    # The flag is raised before the compressor is constructed, matching the
    # order of the two assignments callers may observe.
    self.enable_compression = True
    self.compressor = Compressor(compressor_configs)
50 changes: 43 additions & 7 deletions src/appfl/agent/server.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import io
import torch
import torch.nn as nn
from appfl.scheduler import *
from appfl.aggregator import *
from concurrent.futures import Future
from appfl.compressor import Compressor
from appfl.config import ServerAgentConfig
from omegaconf import OmegaConf, DictConfig
from appfl.misc import create_instance_from_file, get_function_from_file
from appfl.logger import ServerAgentFileLogger
from concurrent.futures import Future
from omegaconf import OmegaConf, DictConfig
from typing import Union, Dict, OrderedDict, Tuple
from appfl.misc import create_instance_from_file, get_function_from_file

class APPFLServerAgent:
"""
Expand All @@ -22,11 +25,19 @@ def __init__(
server_agent_config: ServerAgentConfig = ServerAgentConfig()
) -> None:
self.server_agent_config = server_agent_config
if hasattr(self.server_agent_config.client_configs, "comm_configs"):
self.server_agent_config.server_configs.comm_configs = (OmegaConf.merge(
self.server_agent_config.server_configs.comm_configs,
self.server_agent_config.client_configs.comm_configs
) if hasattr(self.server_agent_config.server_configs, "comm_configs")
else self.server_agent_config.client_configs.comm_configs
)
self._create_logger()
self._load_model()
self._load_loss()
self._load_metric()
self._get_scheduler()
self._load_scheduler()
self._load_compressor()

def get_client_configs(self, **kwargs) -> DictConfig:
"""Return the FL configurations that are shared among all clients."""
Expand All @@ -35,18 +46,20 @@ def get_client_configs(self, **kwargs) -> DictConfig:
def global_update(
self,
client_id: Union[int, str],
local_model: Union[Dict, OrderedDict],
local_model: Union[Dict, OrderedDict, bytes],
blocking: bool = False,
**kwargs
) -> Union[Future, Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
"""
Update the global model using the local model from a client and return the updated global model.
:param: client_id: A unique client id for server to distinguish clients, which can be obtained via `ClientAgent.get_id()`.
:param: local_model: The local model from a client.
:param: local_model: The local model from a client, can be serialized bytes.
:param: blocking: The global model may not be immediately available for certain aggregation methods (e.g. any synchronous method).
Setting `blocking` to `True` will block the client until the global model is available.
Otherwise, the method may return a `Future` object if the most up-to-date global model is not yet available.
"""
if isinstance(local_model, bytes):
local_model = self._bytes_to_model(local_model)
global_model = self.scheduler.schedule(client_id, local_model, **kwargs)
if not isinstance(global_model, Future):
return global_model
Expand Down Expand Up @@ -137,7 +150,7 @@ def _load_metric(self) -> None:
else:
self.metric = None

def _get_scheduler(self) -> None:
def _load_scheduler(self) -> None:
"""Obtain the scheduler."""
self.aggregator: BaseAggregator = eval(self.server_agent_config.server_configs.aggregator)(
self.model,
Expand All @@ -149,3 +162,26 @@ def _get_scheduler(self) -> None:
self.aggregator,
self.logger,
)

def _load_compressor(self) -> None:
    """
    Create the server-side compressor when compression is enabled.

    `self.compressor` and `self.enable_compression` are always reset first
    (to `None` and `False`). A `Compressor` is then built only if
    `server_configs.comm_configs.compressor_configs` exists and its
    `enable_compression` entry is truthy.
    """
    self.compressor = None
    self.enable_compression = False

    server_configs = self.server_agent_config.server_configs
    has_compressor_configs = (
        hasattr(server_configs, "comm_configs")
        and hasattr(server_configs.comm_configs, "compressor_configs")
    )
    if not has_compressor_configs:
        return

    compressor_configs = server_configs.comm_configs.compressor_configs
    if not getattr(compressor_configs, "enable_compression", False):
        return

    # Raise the flag first, then construct, as in the original ordering.
    self.enable_compression = True
    self.compressor = Compressor(compressor_configs)

def _bytes_to_model(self, model_bytes: bytes) -> Union[Dict, OrderedDict]:
    """
    Deserialize a model that was received as raw bytes.

    :param model_bytes: the serialized model. When `self.enable_compression`
        is set, the bytes are handed to `self.compressor.decompress_model`
        together with `self.model`; otherwise they are read with `torch.load`.
    :return: the object produced by the compressor or by `torch.load`.
    """
    # The leftover "[DEBUG]" prints were dropped: they wrote to stdout on
    # every client update and one of them described the uncompressed path as
    # "decompressing".
    if self.enable_compression:
        return self.compressor.decompress_model(model_bytes, self.model)
    # SECURITY NOTE(review): `torch.load` relies on pickle, and these bytes
    # originate from a client. Only accept them from authenticated/trusted
    # clients; consider `weights_only=True` if the supported torch versions
    # allow it -- TODO confirm.
    return torch.load(io.BytesIO(model_bytes))
4 changes: 2 additions & 2 deletions src/appfl/comm/grpc/grpc_client_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_global_model(self, **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Un
else:
return model, meta_data

def update_global_model(self, local_model: Union[Dict, OrderedDict], **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]:
def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]:
"""
Send local model to FL server for global update, and return the new global model.
:param local_model: the local model to be sent to the server for global aggregation
Expand All @@ -103,7 +103,7 @@ def update_global_model(self, local_model: Union[Dict, OrderedDict], **kwargs) -
meta_data = json.dumps(kwargs)
request = UpdateGlobalModelRequest(
header=ClientHeader(client_id=self.client_id),
local_model=serialize_model(local_model),
local_model=serialize_model(local_model) if not isinstance(local_model, bytes) else local_model,
meta_data=meta_data,
)
byte_received = b''
Expand Down
Loading

0 comments on commit 9ec22fd

Please sign in to comment.