From f6ba14192f27eb7de33a2eb21cefe64339da64ba Mon Sep 17 00:00:00 2001 From: Zilinghan Date: Wed, 13 Mar 2024 15:01:28 -0500 Subject: [PATCH 1/4] Use commad `appfl-install-compressor` to install compressor packages --- .gitignore | 3 ++- setup.py | 7 ++++++ src/appfl/compressor/install.py | 10 +++++++++ src/appfl/compressor/install.sh | 39 +++++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 src/appfl/compressor/install.py create mode 100644 src/appfl/compressor/install.sh diff --git a/.gitignore b/.gitignore index a6ee275..e8baa62 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ __pycache__/ *.egg-info/ *output *outputs -*RawData \ No newline at end of file +*RawData +.compressor \ No newline at end of file diff --git a/setup.py b/setup.py index 6e6aca9..52ae034 100644 --- a/setup.py +++ b/setup.py @@ -36,10 +36,17 @@ "matplotlib", "torchvision", "globus-sdk", + "zfpy", + "blosc", + "zstd", + "scipy", + "lz4", + "python-xz", ], entry_points={ "console_scripts": [ "appfl-auth=appfl.login_manager.globus.cli:auth", + "appfl-install-compressor=appfl.compressor.install:install_compressor", ], }, ) diff --git a/src/appfl/compressor/install.py b/src/appfl/compressor/install.py new file mode 100644 index 0000000..016ad8c --- /dev/null +++ b/src/appfl/compressor/install.py @@ -0,0 +1,10 @@ +import os +import subprocess + +def install_compressor(): + """ + Install APPFL supported compressors into .compressor directory. + """ + current_path = os.path.dirname(__file__) + script_path = os.path.join(current_path, "install.sh") + subprocess.run(["bash", script_path]) diff --git a/src/appfl/compressor/install.sh b/src/appfl/compressor/install.sh new file mode 100644 index 0000000..7e4cb1c --- /dev/null +++ b/src/appfl/compressor/install.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# Create a directory for all installation +cd "$( dirname "${BASH_SOURCE[0]}" )" +cd ../../.. && mkdir -p .compressor && cd .compressor + +# Install ZFP +if pip show "zfpy" >/dev/null 2>&1; then + echo "zfpy is already installed." +else + # If the package is not installed, install it + echo "Installing zfpy..." + pip install zfpy +fi + +# Install SZ2 +echo "Installing SZ2..." +git clone https://github.com/szcompressor/SZ.git && cd SZ +mkdir -p build && cd build +cmake -DCMAKE_INSTALL_PREFIX:PATH=.. .. +make +make install +echo "SZ2 installation done." +echo "======================" +cd ../.. + +# Install SZ3 +echo "Installing SZ3..." +git clone https://github.com/szcompressor/SZ3.git && cd SZ3 +mkdir -p build && cd build +cmake -DCMAKE_INSTALL_PREFIX:PATH=.. .. +make +make install +echo "SZ3 installation done." +echo "======================" +cd ../.. + +# Install SZx +echo "As SZx is not open source, please install it manually by contacting the author." \ No newline at end of file From 9d6698472200d218fc0526a86d336528b165eff6 Mon Sep 17 00:00:00 2001 From: Zilinghan Date: Wed, 13 Mar 2024 15:20:03 -0500 Subject: [PATCH 2/4] Update compressor install script: 1) avoid duplication 2) colorful indication --- src/appfl/compressor/install.sh | 48 ++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/src/appfl/compressor/install.sh b/src/appfl/compressor/install.sh index 7e4cb1c..6bc2b55 100644 --- a/src/appfl/compressor/install.sh +++ b/src/appfl/compressor/install.sh @@ -6,34 +6,44 @@ cd ../../.. && mkdir -p .compressor && cd .compressor # Install ZFP if pip show "zfpy" >/dev/null 2>&1; then - echo "zfpy is already installed." + echo -e "\033[32mzfpy is already installed.\033[0m" else # If the package is not installed, install it - echo "Installing zfpy..." + echo -e "\033[32mInstalling zfpy...\033[0m" pip install zfpy + echo -e "\033[32mzfpy installation done.\033[0m" + echo "======================" fi # Install SZ2 -echo "Installing SZ2..." -git clone https://github.com/szcompressor/SZ.git && cd SZ -mkdir -p build && cd build -cmake -DCMAKE_INSTALL_PREFIX:PATH=.. .. -make -make install -echo "SZ2 installation done." -echo "======================" +if [ -d "SZ" ]; then + echo -e "\033[32mSZ2 is already installed.\033[0m" +else + echo -e "\033[32mInstalling SZ2...\033[0m" + git clone https://github.com/szcompressor/SZ.git && cd SZ + mkdir -p build && cd build + cmake -DCMAKE_INSTALL_PREFIX:PATH=.. .. + make + make install + echo -e "\033[32mSZ2 installation done.\033[0m" + echo "======================" +fi cd ../.. # Install SZ3 -echo "Installing SZ3..." -git clone https://github.com/szcompressor/SZ3.git && cd SZ3 -mkdir -p build && cd build -cmake -DCMAKE_INSTALL_PREFIX:PATH=.. .. -make -make install -echo "SZ3 installation done." -echo "======================" +if [ -d "SZ3" ]; then + echo -e "\033[32mSZ3 is already installed.\033[0m" +else + echo -e "\033[32mInstalling SZ3...\033[0m" + git clone https://github.com/szcompressor/SZ3.git && cd SZ3 + mkdir -p build && cd build + cmake -DCMAKE_INSTALL_PREFIX:PATH=.. .. + make + make install + echo -e "\033[32mSZ3 installation done.\033[0m" + echo "======================" +fi cd ../.. # Install SZx -echo "As SZx is not open source, please install it manually by contacting the author." \ No newline at end of file +echo -e "\033[31mSZx is not open source, please install it manually by contacting the author.\033[0m" \ No newline at end of file From d2cbebbf113250a831a952b0b9514badf938bea4 Mon Sep 17 00:00:00 2001 From: Zilinghan Date: Wed, 13 Mar 2024 18:33:24 -0500 Subject: [PATCH 3/4] Make comm_config fine-grained for different config categories --- examples/config/client_1.yaml | 15 ++++----------- examples/config/client_2.yaml | 15 ++++----------- examples/config/server_fedasync.yaml | 17 +++++++++++++---- examples/config/server_fedavg.yaml | 17 +++++++++++++---- examples/config/server_fedcompass.yaml | 17 +++++++++++++---- examples/grpc/run_client_1.py | 2 +- examples/grpc/run_client_2.py | 2 +- examples/grpc/run_server.py | 4 ++-- src/appfl/agent/server.py | 7 +++++++ src/appfl/config/config.py | 4 +++- 10 files changed, 61 insertions(+), 39 deletions(-) diff --git a/examples/config/client_1.yaml b/examples/config/client_1.yaml index e60cb70..07f3c14 100644 --- a/examples/config/client_1.yaml +++ b/examples/config/client_1.yaml @@ -19,14 +19,7 @@ data_configs: output_filename: "visualization.pdf" comm_configs: - server_uri: localhost:50051 - max_message_size: 1048576 - - use_ssl: False - # # SSL configurations - # use_ssl: True - # use_authenticator: True - # root_certificate: ../src/appfl/comm/grpc/credentials/root.crt - # authenticator: NaiveAuthenticator - # authenticator_args: - # auth_token: a-secret-token + grpc_configs: + server_uri: localhost:50051 + max_message_size: 1048576 + use_ssl: False \ No newline at end of file diff --git a/examples/config/client_2.yaml b/examples/config/client_2.yaml index c1fe986..6135a5b 100644 --- a/examples/config/client_2.yaml +++ b/examples/config/client_2.yaml @@ -19,14 +19,7 @@ data_configs: output_filename: "visualization.pdf" comm_configs: - server_uri: localhost:50051 - max_message_size: 1048576 - - use_ssl: False - # # SSL configurations - # use_ssl: True - # use_authenticator: True - # root_certificate: ../src/appfl/comm/grpc/credentials/root.crt - # authenticator: NaiveAuthenticator - # authenticator_args: - # auth_token: a-secret-token \ No newline at end of file + grpc_configs: + server_uri: localhost:50051 + max_message_size: 1048576 + use_ssl: False \ No newline at end of file diff --git a/examples/config/server_fedasync.yaml b/examples/config/server_fedasync.yaml index 907caf8..0302e4f 100644 --- a/examples/config/server_fedasync.yaml +++ b/examples/config/server_fedasync.yaml @@ -39,7 +39,15 @@ client_configs: num_pixel: 28 comm_configs: - enable_compression: False + compression_configs: + enable_compression: False + # Used if enable_compression is True + lossy_compressor: "SZ2" + lossless_compressor: "blosc" + error_bounding_mode: "REL" + error_bound: 1e-3 + flat_model_dtype: "np.float32" + param_cutoff: 1024 server_configs: scheduler: "AsyncScheduler" @@ -61,6 +69,7 @@ server_configs: logging_output_dirname: "./output" logging_output_filename: "result" comm_configs: - server_uri: localhost:50051 - max_message_size: 1048576 - use_ssl: False \ No newline at end of file + grpc_configs: + server_uri: localhost:50051 + max_message_size: 1048576 + use_ssl: False \ No newline at end of file diff --git a/examples/config/server_fedavg.yaml b/examples/config/server_fedavg.yaml index f8c093e..5f8b138 100644 --- a/examples/config/server_fedavg.yaml +++ b/examples/config/server_fedavg.yaml @@ -37,7 +37,15 @@ client_configs: num_pixel: 28 comm_configs: - enable_compression: False + compression_configs: + enable_compression: False + # Used if enable_compression is True + lossy_compressor: "SZ2" + lossless_compressor: "blosc" + error_bounding_mode: "REL" + error_bound: 1e-3 + flat_model_dtype: "np.float32" + param_cutoff: 1024 server_configs: scheduler: "SyncScheduler" @@ -53,6 +61,7 @@ server_configs: logging_output_dirname: "./output" logging_output_filename: "result" comm_configs: - server_uri: localhost:50051 - max_message_size: 1048576 - use_ssl: False \ No newline at end of file + grpc_configs: + server_uri: localhost:50051 + max_message_size: 1048576 + use_ssl: False \ No newline at end of file diff --git a/examples/config/server_fedcompass.yaml b/examples/config/server_fedcompass.yaml index 6195db9..e128d04 100644 --- a/examples/config/server_fedcompass.yaml +++ b/examples/config/server_fedcompass.yaml @@ -39,7 +39,15 @@ client_configs: num_pixel: 28 comm_configs: - enable_compression: False + compression_configs: + enable_compression: False + # Used if enable_compression is True + lossy_compressor: "SZ2" + lossless_compressor: "blosc" + error_bounding_mode: "REL" + error_bound: 1e-3 + flat_model_dtype: "np.float32" + param_cutoff: 1024 server_configs: scheduler: "CompassScheduler" @@ -65,6 +73,7 @@ server_configs: logging_output_dirname: "./output" logging_output_filename: "result" comm_configs: - server_uri: localhost:50051 - max_message_size: 1048576 - use_ssl: False \ No newline at end of file + grpc_configs: + server_uri: localhost:50051 + max_message_size: 1048576 + use_ssl: False \ No newline at end of file diff --git a/examples/grpc/run_client_1.py b/examples/grpc/run_client_1.py index 61ced52..e762ab9 100644 --- a/examples/grpc/run_client_1.py +++ b/examples/grpc/run_client_1.py @@ -9,7 +9,7 @@ client_agent = APPFLClientAgent(client_agent_config=client_agent_config) client_comm = GRPCClientCommunicator( client_id = client_agent.get_id(), - **client_agent_config.comm_configs, + **client_agent_config.comm_configs.grpc_configs, ) client_config = client_comm.get_configuration() diff --git a/examples/grpc/run_client_2.py b/examples/grpc/run_client_2.py index 7678858..c0ccf87 100644 --- a/examples/grpc/run_client_2.py +++ b/examples/grpc/run_client_2.py @@ -9,7 +9,7 @@ client_agent = APPFLClientAgent(client_agent_config=client_agent_config) client_comm = GRPCClientCommunicator( client_id = client_agent.get_id(), - **client_agent_config.comm_configs, + **client_agent_config.comm_configs.grpc_configs, ) client_config = client_comm.get_configuration() diff --git a/examples/grpc/run_server.py b/examples/grpc/run_server.py index 49c62d0..d09b85a 100644 --- a/examples/grpc/run_server.py +++ b/examples/grpc/run_server.py @@ -17,10 +17,10 @@ communicator = GRPCServerCommunicator( server_agent, - max_message_size=server_agent_config.server_configs.comm_configs.max_message_size, + max_message_size=server_agent_config.server_configs.comm_configs.grpc_configs.max_message_size, ) serve( communicator, - **server_agent_config.server_configs.comm_configs, + **server_agent_config.server_configs.comm_configs.grpc_configs, ) diff --git a/src/appfl/agent/server.py b/src/appfl/agent/server.py index 44d5240..3754f61 100644 --- a/src/appfl/agent/server.py +++ b/src/appfl/agent/server.py @@ -22,6 +22,13 @@ def __init__( server_agent_config: ServerAgentConfig = ServerAgentConfig() ) -> None: self.server_agent_config = server_agent_config + if hasattr(self.server_agent_config.client_configs, "comm_configs"): + self.server_agent_config.server_configs.comm_configs = (OmegaConf.merge( + self.server_agent_config.server_configs.comm_configs, + self.server_agent_config.client_configs.comm_configs + ) if hasattr(self.server_agent_config.server_configs, "comm_configs") + else self.server_agent_config.client_configs.comm_configs + ) self._create_logger() self._load_model() self._load_loss() diff --git a/src/appfl/config/config.py b/src/appfl/config/config.py index 2bbcef9..00ab0e6 100644 --- a/src/appfl/config/config.py +++ b/src/appfl/config/config.py @@ -16,11 +16,13 @@ class ClientAgentConfig: It basically holds the following types of configurations: - train_configs: Configurations for local training, such as trainer, device, optimizer, loss function, etc. - model_configs: Configurations for the AI model - - comm_configs: Configurations for communication, such as compression, etc. - data_configs: Configurations for the data loader + - comm_configs: Configurations for communication, such as compression, etc. + - additional_configs: Additional configurations that are not covered by the above categories. """ train_configs: DictConfig = OmegaConf.create({}) model_configs: DictConfig = OmegaConf.create({}) data_configs: DictConfig = OmegaConf.create({}) comm_configs: DictConfig = OmegaConf.create({}) + additional_configs: DictConfig = OmegaConf.create({}) From 40521ab5455c06c09b4c97a6b73a06b05a241360 Mon Sep 17 00:00:00 2001 From: Zilinghan Date: Wed, 13 Mar 2024 21:40:33 -0500 Subject: [PATCH 4/4] Add compressor to client and server agents --- examples/config/server_fedasync.yaml | 2 +- examples/config/server_fedavg.yaml | 2 +- examples/config/server_fedcompass.yaml | 2 +- examples/grpc/run_client_1.py | 1 - examples/grpc/run_client_2.py | 1 - src/appfl/agent/client.py | 34 +- src/appfl/agent/server.py | 43 ++- .../comm/grpc/grpc_client_communicator.py | 4 +- .../comm/grpc/grpc_server_communicator.py | 2 +- src/appfl/compressor/README.md | 28 ++ src/appfl/compressor/__init__.py | 1 + src/appfl/compressor/compressor.py | 318 ++++++++++++++++++ src/appfl/compressor/pysz.py | 153 +++++++++ src/appfl/compressor/pyszx.py | 156 +++++++++ src/appfl/trainer/base_trainer.py | 6 +- 15 files changed, 731 insertions(+), 22 deletions(-) create mode 100644 src/appfl/compressor/README.md create mode 100644 src/appfl/compressor/__init__.py create mode 100644 src/appfl/compressor/compressor.py create mode 100644 src/appfl/compressor/pysz.py create mode 100644 src/appfl/compressor/pyszx.py diff --git a/examples/config/server_fedasync.yaml b/examples/config/server_fedasync.yaml index 0302e4f..566a25d 100644 --- a/examples/config/server_fedasync.yaml +++ b/examples/config/server_fedasync.yaml @@ -39,7 +39,7 @@ client_configs: num_pixel: 28 comm_configs: - compression_configs: + compressor_configs: enable_compression: False # Used if enable_compression is True lossy_compressor: "SZ2" diff --git a/examples/config/server_fedavg.yaml b/examples/config/server_fedavg.yaml index 5f8b138..925881c 100644 --- a/examples/config/server_fedavg.yaml +++ b/examples/config/server_fedavg.yaml @@ -37,7 +37,7 @@ client_configs: num_pixel: 28 comm_configs: - compression_configs: + compressor_configs: enable_compression: False # Used if enable_compression is True lossy_compressor: "SZ2" diff --git a/examples/config/server_fedcompass.yaml b/examples/config/server_fedcompass.yaml index e128d04..a37ee18 100644 --- a/examples/config/server_fedcompass.yaml +++ b/examples/config/server_fedcompass.yaml @@ -39,7 +39,7 @@ client_configs: num_pixel: 28 comm_configs: - compression_configs: + compressor_configs: enable_compression: False # Used if enable_compression is True lossy_compressor: "SZ2" diff --git a/examples/grpc/run_client_1.py b/examples/grpc/run_client_1.py index e762ab9..356e027 100644 --- a/examples/grpc/run_client_1.py +++ b/examples/grpc/run_client_1.py @@ -20,7 +20,6 @@ # Send the number of local data to the server sample_size = client_agent.get_sample_size() -print(f"Sample size: {sample_size}") client_comm.invoke_custom_action(action='set_sample_size', sample_size=sample_size) for i in range(10): diff --git a/examples/grpc/run_client_2.py b/examples/grpc/run_client_2.py index c0ccf87..5c1d320 100644 --- a/examples/grpc/run_client_2.py +++ b/examples/grpc/run_client_2.py @@ -20,7 +20,6 @@ # Send the number of local data to the server sample_size = client_agent.get_sample_size() -print(f"Sample size: {sample_size}") client_comm.invoke_custom_action(action='set_sample_size', sample_size=sample_size) for i in range(10): diff --git a/src/appfl/agent/client.py b/src/appfl/agent/client.py index def0509..e80fc10 100644 --- a/src/appfl/agent/client.py +++ b/src/appfl/agent/client.py @@ -2,9 +2,10 @@ import importlib import torch.nn as nn from appfl.trainer import BaseTrainer +from appfl.compressor import Compressor from appfl.config import ClientAgentConfig from omegaconf import DictConfig, OmegaConf -from typing import Union, Dict, OrderedDict +from typing import Union, Dict, OrderedDict, Tuple from appfl.logger import ClientAgentFileLogger from appfl.misc import create_instance_from_file, \ run_function_from_file, \ @@ -38,6 +39,7 @@ def __init__( self._load_metric() self._load_data() self._load_trainer() + self._load_compressor() def load_config(self, config: DictConfig) -> None: """Load additional configurations provided by the server.""" @@ -46,6 +48,7 @@ def load_config(self, config: DictConfig) -> None: self._load_loss() self._load_metric() self._load_trainer() + self._load_compressor() def get_id(self) -> str: """Return a unique client id for server to distinguish clients.""" @@ -61,10 +64,16 @@ def train(self) -> None: """Train the model locally.""" self.trainer.train() - def get_parameters(self) -> Union[Dict, OrderedDict, bytes]: + def get_parameters(self) -> Union[Dict, OrderedDict, bytes, Tuple[Union[Dict, OrderedDict, bytes], Dict]]: """Return parameters for communication""" params = self.trainer.get_parameters() - return params + if isinstance(params, tuple): + params, metadata = params + else: + metadata = None + if self.enable_compression: + params = self.compressor.compress_model(params) + return params if metadata is None else (params, metadata) def load_parameters(self, params) -> None: """Load parameters from the server.""" @@ -200,4 +209,21 @@ def _load_trainer(self) -> None: train_configs=self.client_agent_config.train_configs, logger=self.logger, ) - \ No newline at end of file + + def _load_compressor(self) -> None: + """ + Create a compressor for compressing the model parameters. + """ + if hasattr(self, "compressor") and self.compressor is not None: + return + self.compressor = None + self.enable_compression = False + if not hasattr(self.client_agent_config, "comm_configs"): + return + if not hasattr(self.client_agent_config.comm_configs, "compressor_configs"): + return + if getattr(self.client_agent_config.comm_configs.compressor_configs, "enable_compression", False): + self.enable_compression = True + self.compressor = Compressor( + self.client_agent_config.comm_configs.compressor_configs + ) diff --git a/src/appfl/agent/server.py b/src/appfl/agent/server.py index 3754f61..ab0b90d 100644 --- a/src/appfl/agent/server.py +++ b/src/appfl/agent/server.py @@ -1,12 +1,15 @@ +import io +import torch import torch.nn as nn from appfl.scheduler import * from appfl.aggregator import * -from concurrent.futures import Future +from appfl.compressor import Compressor from appfl.config import ServerAgentConfig -from omegaconf import OmegaConf, DictConfig +from appfl.misc import create_instance_from_file, get_function_from_file from appfl.logger import ServerAgentFileLogger +from concurrent.futures import Future +from omegaconf import OmegaConf, DictConfig from typing import Union, Dict, OrderedDict, Tuple -from appfl.misc import create_instance_from_file, get_function_from_file class APPFLServerAgent: """ @@ -33,7 +36,8 @@ def __init__( self._load_model() self._load_loss() self._load_metric() - self._get_scheduler() + self._load_scheduler() + self._load_compressor() def get_client_configs(self, **kwargs) -> DictConfig: """Return the FL configurations that are shared among all clients.""" @@ -42,18 +46,20 @@ def get_client_configs(self, **kwargs) -> DictConfig: def global_update( self, client_id: Union[int, str], - local_model: Union[Dict, OrderedDict], + local_model: Union[Dict, OrderedDict, bytes], blocking: bool = False, **kwargs ) -> Union[Future, Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]: """ Update the global model using the local model from a client and return the updated global model. :param: client_id: A unique client id for server to distinguish clients, which be obtained via `ClientAgent.get_id()`. - :param: local_model: The local model from a client. + :param: local_model: The local model from a client, can be serailzed bytes. :param: blocking: The global model may not be immediately available for certain aggregation methods (e.g. any synchronous method). Setting `blocking` to `True` will block the client until the global model is available. Otherwise, the method may return a `Future` object if the most up-to-date global model is not yet available. """ + if isinstance(local_model, bytes): + local_model = self._bytes_to_model(local_model) global_model = self.scheduler.schedule(client_id, local_model, **kwargs) if not isinstance(global_model, Future): return global_model @@ -144,7 +150,7 @@ def _load_metric(self) -> None: else: self.metric = None - def _get_scheduler(self) -> None: + def _load_scheduler(self) -> None: """Obtain the scheduler.""" self.aggregator: BaseAggregator = eval(self.server_agent_config.server_configs.aggregator)( self.model, @@ -156,3 +162,26 @@ def _get_scheduler(self) -> None: self.aggregator, self.logger, ) + + def _load_compressor(self) -> None: + """Obtain the compressor.""" + self.compressor = None + self.enable_compression = False + if not hasattr(self.server_agent_config.server_configs, "comm_configs"): + return + if not hasattr(self.server_agent_config.server_configs.comm_configs, "compressor_configs"): + return + if getattr(self.server_agent_config.server_configs.comm_configs.compressor_configs, "enable_compression", False): + self.enable_compression = True + self.compressor = Compressor( + self.server_agent_config.server_configs.comm_configs.compressor_configs + ) + + def _bytes_to_model(self, model_bytes: bytes) -> Union[Dict, OrderedDict]: + """Deserialize the model from bytes.""" + if not self.enable_compression: + print("[DEBUG] Decompressing model without compression") + return torch.load(io.BytesIO(model_bytes)) + else: + print("[DEBUG] Decompressing model with compression") + return self.compressor.decompress_model(model_bytes, self.model) \ No newline at end of file diff --git a/src/appfl/comm/grpc/grpc_client_communicator.py b/src/appfl/comm/grpc/grpc_client_communicator.py index adf107d..286dee4 100644 --- a/src/appfl/comm/grpc/grpc_client_communicator.py +++ b/src/appfl/comm/grpc/grpc_client_communicator.py @@ -93,7 +93,7 @@ def get_global_model(self, **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Un else: return model, meta_data - def update_global_model(self, local_model: Union[Dict, OrderedDict], **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]: + def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]: """ Send local model to FL server for global update, and return the new global model. :param local_model: the local model to be sent to the server for gloabl aggregation @@ -103,7 +103,7 @@ def update_global_model(self, local_model: Union[Dict, OrderedDict], **kwargs) - meta_data = json.dumps(kwargs) request = UpdateGlobalModelRequest( header=ClientHeader(client_id=self.client_id), - local_model=serialize_model(local_model), + local_model=serialize_model(local_model) if not isinstance(local_model, bytes) else local_model, meta_data=meta_data, ) byte_received = b'' diff --git a/src/appfl/comm/grpc/grpc_server_communicator.py b/src/appfl/comm/grpc/grpc_server_communicator.py index c24f4cf..5f2e055 100644 --- a/src/appfl/comm/grpc/grpc_server_communicator.py +++ b/src/appfl/comm/grpc/grpc_server_communicator.py @@ -88,7 +88,7 @@ def UpdateGlobalModel(self, request_iterator, context): request.ParseFromString(bytes_received) self.logger.info(f"Received UpdateGlobalModel request from client {request.header.client_id}") client_id = request.header.client_id - local_model = torch.load(io.BytesIO(request.local_model)) + local_model = request.local_model if len(request.meta_data) == 0: meta_data = {} else: diff --git a/src/appfl/compressor/README.md b/src/appfl/compressor/README.md new file mode 100644 index 0000000..17ab667 --- /dev/null +++ b/src/appfl/compressor/README.md @@ -0,0 +1,28 @@ +# 🗜 Model Parameter Compressor + +The `appfl.compressor` module can be used for compressing the model parameters or gradients in a lossy manner before the client sends them back to the server for more efficient communication. The server then will decompress the compressed model before the global aggregation. + +The `appfl.compressor` currently supports the following lossy compressors. Please refer to their official project/GitHub pages if you want more detailed information of them. Here, we only provide the installation instructions. **Note: SZx need particular permission to access because of the collaboration with a third-party, so we omit its installation here.** + +1. [SZ2: Error-bounded Lossy Compressor for HPC Data](https://github.com/szcompressor/SZ) +2. [SZ3: A Modular Error-bounded Lossy Compression Framework for Scientific Datasets](https://github.com/szcompressor/SZ3) +3. [ZFP: Compressed Floating-Point and Integer Arrays](https://pypi.org/project/zfpy/) +4. [SZX: An Ultra-fast Error-bounded Lossy Compressor for Scientific Datasets](https://github.com/szcompressor/SZx) + +## Installation +Users can easily install all the above compressors by running the following command. +```bash +appfl-install-compressor +``` + +## Citation +Please check the following paper for details about how the compressor plays a role in federated learning. + +``` +@article{wilkins2023efficient, + title={Efficient Communication in Federated Learning Using Floating-Point Lossy Compression}, + author={Wilkins, Grant and Di, Sheng and Calhoun, Jon C and Kim, Kibaek and Underwood, Robert and Cappello, Franck}, + journal={arXiv preprint arXiv:2312.13461}, + year={2023} +} +``` \ No newline at end of file diff --git a/src/appfl/compressor/__init__.py b/src/appfl/compressor/__init__.py new file mode 100644 index 0000000..1a7ca40 --- /dev/null +++ b/src/appfl/compressor/__init__.py @@ -0,0 +1 @@ +from .compressor import * diff --git a/src/appfl/compressor/compressor.py b/src/appfl/compressor/compressor.py new file mode 100644 index 0000000..38fcade --- /dev/null +++ b/src/appfl/compressor/compressor.py @@ -0,0 +1,318 @@ +import os +import sys +import gzip +import lzma +import zfpy +import zlib +import zstd +import blosc +import torch +import pickle +import numpy as np +from . import pysz +from . import pyszx +from copy import deepcopy +from omegaconf import DictConfig +from collections import OrderedDict +from typing import Tuple, Union, List + +class Compressor: + def __init__(self, compressor_config: DictConfig): + current_path = os.path.dirname(os.path.abspath(__file__)) + appfl_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_path))) + self.cfg = compressor_config + self.sz_error_mode_dict = { + "ABS": 0, + "REL": 1, + "ABS_AND_REL": 2, + "ABS_OR_REL": 3, + "PSNR": 4, + "NORM": 5, + "PW_REL": 10, + } + self.lossless_compressor = compressor_config.lossless_compressor + self.compression_layers = [] + self.compressor_lib_path = "" + self.param_count_threshold = compressor_config.param_cutoff + ext = ".dylib" if sys.platform.startswith("darwin") else ".so" + if self.cfg.lossy_compressor == "SZ3": + self.compressor_lib_path = os.path.join(appfl_root_dir, ".compressor/SZ/build/sz/libSZ") + ext + elif self.cfg.lossy_compressor == "SZ2": + self.compressor_lib_path = os.path.join(appfl_root_dir, ".compressor/SZ3/build/tools/sz3c/libSZ3c") + ext + elif self.cfg.lossy_compressor == "SZx": + self.compressor_lib_path = os.path.join(appfl_root_dir, ".compressor/SZx-main/build/lib/libSZx") + ext + + def compress_model( + self, + model: Union[dict, OrderedDict, List[Union[dict, OrderedDict]]], + batched: bool=False + ) -> bytes: + """ + Compress all the parameters of local model(s) for efficient communication. The local model can be batched as a list. + :param model: local model parameters (can be nested) + :param batched: whether the input is a batch of models + :return: compressed model parameters as bytes + """ + # Deal with batched models + if batched: + if isinstance(model, list): + compressed_models = [] + num_lossy_elements = 0 + for model_sample in model: + compressed_model = self.compress_model(model_sample) + compressed_models.append(compressed_model) + num_lossy_elements += lossy_elements + return pickle.dumps(compressed_models) + if isinstance(model, dict) or isinstance(model, OrderedDict): + compressed_models = OrderedDict() + num_lossy_elements = 0 + for key, model_sample in model.items(): + compressed_model = self.compress_model(model_sample) + compressed_models[key] = compressed_model + num_lossy_elements += lossy_elements + return pickle.dumps(compressed_models) + + for _, value in model.items(): + is_nested = not isinstance(value, torch.Tensor) + break + + if is_nested: + num_lossy_elements = 0 + compressed_models = OrderedDict() + for key, weights in model.items(): + comprsessed_weights, lossy_elements = self._compress_weights(weights) + compressed_models[key] = comprsessed_weights + lossy_elements += lossy_elements + else: + compressed_models, num_lossy_elements = self._compress_weights(model) + return pickle.dumps(compressed_models) + + def decompress_model( + self, + compressed_model: bytes, + model: Union[dict, OrderedDict], + batched: bool=False + )-> Union[OrderedDict, dict, List[Union[OrderedDict, dict]]]: + """ + Decompress all the communicated model parameters. The local model can be batched as a list. + :param compressed_model: compressed model parameters as bytes + :param model: a model sample for de-compression reference + :param batched: whether the input is a batch of models + :return decompressed_model: decompressed model parameters + """ + compressed_model = pickle.loads(compressed_model) + + # Deal with batched models + if batched: + if isinstance(compressed_model, list): + decompressed_models = [] + for compressed_model_sample in compressed_model: + decompressed_model_sample = self.decompress_model(compressed_model_sample, model) + decompressed_models.append(decompressed_model_sample) + return decompressed_models + if isinstance(compressed_model, dict) or isinstance(compressed_model, OrderedDict): + decompressed_models = OrderedDict() + for key, compressed_model_sample in compressed_model.items(): + decompressed_model_sample = self.decompress_model(compressed_model_sample, model) + decompressed_models[key] = decompressed_model_sample + return decompressed_models + + for _, value in compressed_model.items(): + is_nested = not isinstance(value, bytes) + break + if is_nested: + decompressed_model = OrderedDict() + for key, value in compressed_model.items(): + decompressed_model[key] = self._decompress_model(value, model) + else: + decompressed_model = self._decompress_model(compressed_model, model) + return decompressed_model + + def _compress_weights( + self, + weights: Union[OrderedDict, dict] + ) -> Tuple[Union[OrderedDict, dict], int]: + """ + Compress ONE set of weights of the model. + :param weights: the model weights to be compressed + :return: the compressed model weights and the number of lossy elements + """ + # Check if the input a set of model weights + if len(weights) == 0: + return (weights, 0) + for _, value in weights.items(): + if not isinstance(value, torch.Tensor): + return (weights, 0) + break + + compressed_weights = {} + lossy_elements = 0 + lossy_original_size = 0 + lossy_compressed_size = 0 + lossless_original_size = 0 + lossless_compressed_size = 0 + + for name, param in weights.items(): + param_flat = param.flatten().detach().cpu().numpy() + if "weight" in name and param_flat.size > self.param_count_threshold: + lossy_original_size += param_flat.nbytes + lossy_elements += param_flat.size + compressed_weights[name] = self._compress(ori_data=param_flat) + lossy_compressed_size += len(compressed_weights[name]) + else: + lossless_original_size += param_flat.nbytes + lossless = b"" + if self.lossless_compressor == "zstd": + lossless = zstd.compress(param_flat, 10) + elif self.lossless_compressor == "gzip": + lossless = gzip.compress(param_flat.tobytes()) + elif self.lossless_compressor == "zlib": + lossless = zlib.compress(param_flat.tobytes()) + elif self.lossless_compressor == "blosc": + lossless = blosc.compress(param_flat.tobytes(), typesize=4) + elif self.lossless_compressor == "lzma": + lossless = lzma.compress(param_flat.tobytes()) + else: + raise NotImplementedError + lossless_compressed_size += len(lossless) + compressed_weights[name] = lossless + # if lossy_compressed_size != 0: + # print("Lossy Compression Ratio: " + str(lossy_original_size / lossy_compressed_size)) + # if lossless_compressed_size != 0: + # print("Lossless Compression Ratio: " + str(lossless_original_size / lossless_compressed_size)) + # print("Total Compression Ratio: " + str((lossy_original_size + lossless_original_size) / (lossy_compressed_size + lossless_compressed_size))) + return ( + compressed_weights, + lossy_elements, + ) + + def _compress(self, ori_data: np.ndarray): + """ + Compress data with chosen compressor + :param ori_data: compressed data, numpy array format + :return: decompressed data,numpy array format + """ + if self.cfg.lossy_compressor == "SZ3" or self.cfg.lossy_compressor == "SZ2": + compressor = pysz.SZ(szpath=self.compressor_lib_path) + error_mode = self.sz_error_mode_dict[self.cfg.error_bounding_mode] + error_bound = self.cfg.error_bound + compressed_arr, comp_ratio = compressor.compress( + data=ori_data, + eb_mode=error_mode, + eb_abs=error_bound, + eb_rel=error_bound, + eb_pwr=error_bound, + ) + return compressed_arr.tobytes() + elif self.cfg.lossy_compressor == "SZx": + compressor = pyszx.SZx(szxpath=self.compressor_lib_path) + error_mode = self.sz_error_mode_dict[self.cfg.error_bounding_mode] + error_bound = self.cfg.error_bound + compressed_arr, comp_ratio = compressor.compress( + data=ori_data, + eb_mode=error_mode, + eb_abs=error_bound, + eb_rel=error_bound, + ) + return compressed_arr.tobytes() + elif self.cfg.lossy_compressor == "ZFP": + if self.cfg.error_bounding_mode == "ABS": + return zfpy.compress_numpy(ori_data, tolerance=self.cfg.error_bound) + elif self.cfg.error_bounding_mode == "REL": + range_data = abs(np.max(ori_data) - np.min(ori_data)) + return zfpy.compress_numpy( + ori_data, tolerance=self.cfg.error_bound * range_data + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + def _decompress_model( + self, + compressed_weights: Union[dict, OrderedDict], + model: Union[dict, OrderedDict] + ) -> Union[OrderedDict, dict]: + """ + Decompress ONE set of weights of the model. + :param compressed_weights: the compressed model weights + :param model: a model sample for de-compression reference + :return: decompressed model weights + """ + if len(compressed_weights) == 0: + return compressed_weights + for _, value in compressed_weights.items(): + if not isinstance(value, bytes): + return compressed_weights + break + decompressed_weights = OrderedDict() + for name, param in model.state_dict().items(): + if "weight" in name and param.numel() > self.param_count_threshold: + compressed_weights[name] = self._decompress( + cmp_data=compressed_weights[name], + ori_shape=(param.numel(),), + ori_dtype=np.float32, + ).astype(np.float32) + else: + if self.lossless_compressor == "zstd": + compressed_weights[name] = zstd.decompress(compressed_weights[name]) + elif self.lossless_compressor == "gzip": + compressed_weights[name] = gzip.decompress(compressed_weights[name]) + elif self.lossless_compressor == "zlib": + compressed_weights[name] = zlib.decompress(compressed_weights[name]) + elif self.lossless_compressor == "blosc": + compressed_weights[name] = blosc.decompress( + compressed_weights[name], as_bytearray=True + ) + elif self.lossless_compressor == "lzma": + compressed_weights[name] = lzma.decompress(compressed_weights[name]) + else: + raise NotImplementedError + compressed_weights[name] = np.frombuffer( + compressed_weights[name], dtype=np.float32 + ) + if param.shape == torch.Size([]): + copy_arr = deepcopy(compressed_weights[name]) + copy_tensor = torch.from_numpy(copy_arr) + decompressed_weights[name] = torch.tensor(copy_tensor) + else: + copy_arr = deepcopy(compressed_weights[name]) + copy_tensor = torch.from_numpy(copy_arr) + decompressed_weights[name] = copy_tensor.reshape(param.shape) + return decompressed_weights + + def _decompress( + self, + cmp_data, + ori_shape: Tuple[int, ...], + ori_dtype: np.dtype + ) -> np.ndarray: + """ + Decompress data with chosen compressor + :param cmp_data: compressed data, numpy array format, dtype should be np.uint8 + :param ori_shape: the shape of original data + :param ori_dtype: the dtype of original data + :return: decompressed data,numpy array format + """ + if self.cfg.lossy_compressor == "SZ3" or self.cfg.lossy_compressor == "SZ2": + compressor = pysz.SZ(szpath=self.compressor_lib_path) + cmp_data = np.frombuffer(cmp_data, dtype=np.uint8) + decompressed_arr = compressor.decompress( + data_cmpr=cmp_data, + original_shape=ori_shape, + original_dtype=ori_dtype, + ) + return decompressed_arr + elif self.cfg.lossy_compressor == "SZx": + compressor = pyszx.SZx(szxpath=self.compressor_lib_path) + cmp_data = np.frombuffer(cmp_data, dtype=np.uint8) + decompressed_arr = compressor.decompress( + data_cmpr=cmp_data, + original_shape=ori_shape, + original_dtype=ori_dtype, + ) + return decompressed_arr + elif self.cfg.lossy_compressor == "ZFP": + return zfpy.decompress_numpy(cmp_data) + else: + raise NotImplementedError diff --git a/src/appfl/compressor/pysz.py b/src/appfl/compressor/pysz.py new file mode 100644 index 0000000..860c8a4 --- /dev/null +++ b/src/appfl/compressor/pysz.py @@ -0,0 +1,153 @@ +""" +Python API for SZ2/SZ3 +""" +import sys +import ctypes +import numpy as np +from ctypes.util import find_library + +class SZ: + def __init__(self, szpath=None): + """ + init SZ + :param szpath: the path to SZ dynamic library + """ + if szpath is None: + szpath = { + "darwin": "libSZ3c.dylib", + "windows": "SZ3c.dll", + }.get(sys.platform, "libSZ3c.so") + + self.sz = ctypes.cdll.LoadLibrary(szpath) + + self.sz.SZ_compress_args.argtypes = ( + ctypes.c_int, + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_size_t), + ctypes.c_int, + ctypes.c_double, + ctypes.c_double, + ctypes.c_double, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ) + self.sz.SZ_compress_args.restype = ctypes.POINTER(ctypes.c_ubyte) + + self.sz.SZ_decompress.argtypes = ( + ctypes.c_int, + ctypes.POINTER(ctypes.c_ubyte), + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ) + + self.libc = ctypes.CDLL(ctypes.util.find_library("c")) + self.libc.free.argtypes = (ctypes.c_void_p,) + + def __sz_datatype(self, dtype, data=None): + if dtype == np.float32: + return ( + 0, + data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + if data is not None + else None, + ) + elif dtype == np.float64: + return ( + 1, + data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) + if data is not None + else None, + ) + else: + print("SZ currently supports float32 and float64\n") + exit(0) + + def verify(self, src_data, dec_data): + """ + Compare the decompressed data with original data + :param src_data: original data, numpy array + :param dec_data: decompressed data, numpy array + :return: max_diff, psnr, nrmse + """ + data_range = np.max(src_data) - np.min(src_data) + diff = src_data - dec_data + max_diff = np.max(abs(diff)) + # print("abs err={:.8G}".format(max_diff)) + mse = np.mean(diff**2) + nrmse = np.sqrt(mse) / data_range + psnr = 20 * np.log10(data_range) - 10 * np.log10(mse) + return max_diff, psnr, nrmse + + def decompress(self, data_cmpr, original_shape, original_dtype): + """ + Decompress data with SZ + :param data_cmpr: compressed data, numpy array format, dtype should be np.uint8 + :param original_shape: the shape of original data + :param original_dtype: the dtype of original data + :return: decompressed data,numpy array format + """ + + r5, r4, r3, r2, r1 = [0] * (5 - len(original_shape)) + list(original_shape) + ori_type, ori_null = self.__sz_datatype(original_dtype) + self.sz.SZ_decompress.restype = ctypes.POINTER( + ctypes.c_float if original_dtype == np.float32 else ctypes.c_double + ) + data_dec_c = self.sz.SZ_decompress( + ori_type, + data_cmpr.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte)), + data_cmpr.size, + r5, + r4, + r3, + r2, + r1, + ) + + data_dec = np.array(data_dec_c[: np.prod(original_shape)]).reshape( + original_shape + ) + self.libc.free(data_dec_c) + return data_dec + + def compress(self, data, eb_mode, eb_abs, eb_rel, eb_pwr): + """ + Compress data with SZ + :param data: original data, numpy array format, dtype is FP32 or FP64 + :param eb_mode:# error bound mode, integer (0: ABS, 1:REL, 2:ABS_AND_REL, 3:ABS_OR_REL, 4:PSNR, 5:NORM, 10:PW_REL) + :param eb_abs: optional, abs error bound, double + :param eb_rel: optional, rel error bound, double + :param eb_pwr: optional, pwr error bound, double + :return: compressed data, numpy array format, dtype is np.uint8 + compression ratio + """ + assert len(data.shape) <= 5, "SZ only supports 1D to 5D input data" + cmpr_size = ctypes.c_size_t() + r5, r4, r3, r2, r1 = [0] * (5 - len(data.shape)) + list(data.shape) + datatype, datap = self.__sz_datatype(data.dtype, data) + data_cmpr_c = self.sz.SZ_compress_args( + datatype, + datap, + ctypes.byref(cmpr_size), + eb_mode, + eb_abs, + eb_rel, + eb_pwr, + r5, + r4, + r3, + r2, + r1, + ) + + cmpr_ratio = data.size * data.itemsize / cmpr_size.value + + data_cmpr = np.array(data_cmpr_c[: cmpr_size.value], dtype=np.uint8) + self.libc.free(data_cmpr_c) + return data_cmpr, cmpr_ratio diff --git a/src/appfl/compressor/pyszx.py b/src/appfl/compressor/pyszx.py new file mode 100644 index 0000000..efad3a8 --- /dev/null +++ b/src/appfl/compressor/pyszx.py @@ -0,0 +1,156 @@ +""" +Python API for SZx +""" +import sys +import ctypes +import numpy as np +from ctypes.util import find_library + +class SZx: + def __init__(self, szxpath=None): + """ + init SZx + :param szxpath: the path to SZx dynamic library + """ + if szxpath is None: + szxpath = { + "darwin": "libSZx.dylib", + "windows": "SZx.dll", + }.get(sys.platform, "libSZx.so") + + self.szx = ctypes.cdll.LoadLibrary(szxpath) + + self.szx.SZ_fast_compress_args.argtypes = ( + ctypes.c_int, + ctypes.c_int, + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_size_t), + ctypes.c_int, + ctypes.c_double, + ctypes.c_double, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ) + self.szx.SZ_fast_compress_args.restype = ctypes.POINTER(ctypes.c_ubyte) + + self.szx.SZ_fast_decompress.argtypes = ( + ctypes.c_int, + ctypes.c_int, + ctypes.POINTER(ctypes.c_ubyte), + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ctypes.c_size_t, + ) + + self.libc = ctypes.CDLL(ctypes.util.find_library("c")) + self.libc.free.argtypes = (ctypes.c_void_p,) + + def __sz_datatype(self, dtype, data=None): + if dtype == np.float32: + return ( + 0, + data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + if data is not None + else None, + ) + elif dtype == np.float64: + return ( + 1, + data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) + if data is not None + else None, + ) + else: + print("SZx currently supports float32 and float64\n") + exit(0) + + def verify(self, src_data, dec_data): + """ + Compare the decompressed data with original data + :param src_data: original data, numpy array + :param dec_data: decompressed data, numpy array + :return: max_diff, psnr, nrmse + """ + data_range = np.max(src_data) - np.min(src_data) + diff = src_data - dec_data + max_diff = np.max(abs(diff)) + print("abs err={:.8G}".format(max_diff)) + mse = np.mean(diff**2) + nrmse = np.sqrt(mse) / data_range + psnr = 20 * np.log10(data_range) - 10 * np.log10(mse) + return max_diff, psnr, nrmse + + def decompress(self, data_cmpr, original_shape, original_dtype): + """ + Decompress data with SZx + :param data_cmpr: compressed data, numpy array format, dtype should be np.uint8 + :param original_shape: the shape of original data + :param original_dtype: the dtype of original data + :return: decompressed data,numpy array format + """ + + r5, r4, r3, r2, r1 = [0] * (5 - len(original_shape)) + list(original_shape) + ori_type, ori_null = self.__sz_datatype(original_dtype) + self.szx.SZ_fast_decompress.restype = ctypes.POINTER( + ctypes.c_float if original_dtype == np.float32 else ctypes.c_double + ) + fast_mode = 2 + data_dec_c = self.szx.SZ_fast_decompress( + fast_mode, + ori_type, + data_cmpr.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte)), + data_cmpr.size, + r5, + r4, + r3, + r2, + r1, + ) + + data_dec = np.array(data_dec_c[: np.prod(original_shape)]).reshape( + original_shape + ) + self.libc.free(data_dec_c) + return data_dec + + def compress(self, data, eb_mode, eb_abs, eb_rel): + """ + Compress data with SZx + :param data: original data, numpy array format, dtype is FP32 or FP64 + :param eb_mode:# error bound mode, integer (0: ABS, 1:REL, 2:ABS_AND_REL, 3:ABS_OR_REL, 4:PSNR, 5:NORM, 10:PW_REL) + :param eb_abs: optional, abs error bound, double + :param eb_rel: optional, rel error bound, double + :return: compressed data, numpy array format, dtype is np.uint8 + compression ratio + """ + assert len(data.shape) <= 5, "SZx only supports 1D to 5D input data" + cmpr_size = ctypes.c_ulong() + r5, r4, r3, r2, r1 = [0] * (5 - len(data.shape)) + list(data.shape) + datatype, datap = self.__sz_datatype(data.dtype, data) + fast_mode = 2 + data_cmpr_c = self.szx.SZ_fast_compress_args( + fast_mode, + datatype, + datap, + ctypes.byref(cmpr_size), + eb_mode, + eb_abs, + eb_rel, + r5, + r4, + r3, + r2, + r1, + ) + + cmpr_ratio = data.size * data.itemsize / cmpr_size.value + + data_cmpr = np.array(data_cmpr_c[: cmpr_size.value], dtype=np.uint8) + self.libc.free(data_cmpr_c) + return data_cmpr, cmpr_ratio diff --git a/src/appfl/trainer/base_trainer.py b/src/appfl/trainer/base_trainer.py index cbd2a24..e71f1e6 100644 --- a/src/appfl/trainer/base_trainer.py +++ b/src/appfl/trainer/base_trainer.py @@ -2,7 +2,7 @@ import torch.nn as nn from omegaconf import DictConfig from torch.utils.data import Dataset -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Tuple, Union, OrderedDict class BaseTrainer: """ @@ -39,8 +39,8 @@ def __init__( self.__dict__.update(kwargs) @abc.abstractmethod - def get_parameters(self) -> Dict: - """Return local model parameters""" + def get_parameters(self) -> Union[Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]: + """Return local model parameters and optional metadata.""" pass @abc.abstractmethod