Skip to content

Commit

Permalink
Merge pull request #3 from APPFL/zilinghan/mpi
Browse files Browse the repository at this point in the history
Zilinghan/mpi
  • Loading branch information
Zilinghan authored Mar 15, 2024
2 parents 156e37c + bf684b7 commit f148612
Show file tree
Hide file tree
Showing 10 changed files with 558 additions and 27 deletions.
89 changes: 82 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,33 @@
<a href="https://appfl.github.io/FedCompass/">
<img src="https://img.shields.io/badge/project-FedCompass-B3FFF4.svg" alt="project">
</a>
<a href="./LICENSE">
<img src="https://img.shields.io/badge/license-MIT-green?style=flat&logo=github" alt="project">
</a>
</p>

### 🎙 Introduction
<details>
<summary><b>Table of Contents</b></summary>
<p>

- [Introduction](#introduction)
- [Installation](#installation)
- [Launch Experiment](#launch-first-example-experiment)
- [Serial Simulation](#serial-simulation)
- [MPI Simulation](#mpi-simulation)
- [gRPC Deployment](#grpc-deployment)
- [Features](#features)
- [Citations](#citation)

</p>
</details>

### Introduction
FedCompass is a semi-asynchrnous federated learning (FL) algorithm which addresses the time-efficiency challenge of other synchronous FL algorithms, and the model performance challenge of other asynchronous FL algorithms (due to model stalenesses) by using a *COM*puting *P*ower *A*ware *Scheduler* *(COMPASS)* to adaptively assign different numbers of local steps to different FL clients and synchrnoize the arrival of client local models.

This repository is built upon the open-source and highly extendible FL framework [APPFL](https://github.com/APPFL/APPFL) and employs gRPC as the communication protocol to help you easily launch FL experiment using FedCompass among distributed FL clients.

### ⚙️ Installation
### Installation
Users can install by cloning this repository and installing the package locally. We also highly recommend to create a virtual environment for easy dependency management.
```bash
conda create -n fedcompass python=3.8
Expand All @@ -33,16 +52,44 @@ git clone https://github.com/APPFL/FedCompass.git && cd FedCompass
pip install -e .
```

### 🚀 Launch First Example Experiment
Please go to the `examples` folder first. To launch a server, users can run the following command,
### Launch First Example Experiment
In the `examples` folder, we provide example scripts to train a CNN on the MNIST dataset using federated learning by running [serial simulation](#serial-simulation), [MPI simulation](#mpi-simulation), and [gRPC deployment](#grpc-deployment). Specifically, in this repository, we refer

- **simulation** as federated learning experiments that can only run on a single machine or a cluster
- **deployment** as federated learning experiments that can run only multiple distributed machines

#### Serial Simulation

Please go to the `examples` folder first, and then run the following command
```bash
python grpc/run_server.py --config config/server_fedcompass.yaml
python serial/run_serial.py \
--server_config config/server_fedavg.yaml \
--client_config config/client_1.yaml \
--num_clients 5
```
where `--config` is the path to the configuration file. We currently provide three configuration files for the FL server, corresponding to three different FL algorithms
where `--server_config` is the path to the configuration file for the FL server. We currently provide three configuration files for the FL server, corresponding to three different FL algorithms. However, it should be noted at the beginning that serial simulation is only suitable and making sense for synchrnous federated learning algorithms.
- `config/server_fedcompass.yaml`: FL server for the FedCompass algorithm
- `config/server_fedavg.yaml`: FL server for the FedAvg algorithm
- `config/server_fedasync.yaml`: FL server for the FedAsync algorithm

`--client_config` is the path to the base configuration file for the FL clients, and `--num_clients` is the number of FL clients you would like to simulate.

#### MPI Simulation
Please go to the `examples` folder first, and then run the following command
```bash
mpiexec -n 6 python mpi/run_mpi.py \
--server_config config/server_fedcompass.yaml \
--client_config config/client_1.yaml
```
where `mpiexec -n 6` means that we start 6 MPI processes, and there will be 6-1=5 FL clients, as one MPI process will serve as the FL server.

#### gRPC Deployment

Please go to the `examples` folder first. To launch a server, users can run the following command,
```bash
python grpc/run_server.py --config config/server_fedcompass.yaml
```

The above command launches an FL server at `localhost:50051` waiting for connection from two FL clients. To launch two FL clients, open two separate terminals and go to the `examples` folder, and run the following two commands, respectively. This will help you start an FL experiment with two clients and a server running the specified algorithm.
```bash
python grpc/run_client_1.py
Expand All @@ -51,8 +98,25 @@ python grpc/run_client_1.py
python grpc/run_client_2.py
```

### Features

- [x] Server aggregation algorithm customization
- [x] Server scheduling algorithm customization
- [x] Client local trainer customization
- [x] Synchronous federated learning
- [x] Asynchronous Federated Learning
- [x] Semi-asynchronous federated learning
- [x] Model and dataset customization
- [x] Loss function and evaluation metric customization
- [x] Heterogeneous data partition
- [x] Lossy compression using SZ compressors
- [x] Single-node serial federated learning simulation
- [x] MPI federated learning simulation
- [x] Real federated learning deployment using gRPC
- [x] Authentication in gRPC using Globus Identity
- [ ] wandb visualization

### 📃 Citation
### Citation
If you find FedCompass and this repository useful to your research, please consider cite the following paper
```
@article{li2023fedcompass,
Expand All @@ -62,3 +126,14 @@ If you find FedCompass and this repository useful to your research, please consi
year={2023}
}
```

```
@inproceedings{ryu2022appfl,
title={APPFL: open-source software framework for privacy-preserving federated learning},
author={Ryu, Minseok and Kim, Youngdae and Kim, Kibaek and Madduri, Ravi K},
booktitle={2022 IEEE International Parallel and Distributed Processing Symposium Workshops (IPDPSW)},
pages={1074--1083},
year={2022},
organization={IEEE}
}
```
2 changes: 1 addition & 1 deletion examples/config/server_fedcompass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ server_configs:
a: 0.5
alpha: 0.9
gradient_based: True
client_weights_mode: "equal"
num_clients: 2
device: "cpu"
num_epochs: 2
server_validation: False
Expand Down
54 changes: 54 additions & 0 deletions examples/mpi/run_mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import argparse
from mpi4py import MPI
from omegaconf import OmegaConf
from appfl.agent import APPFLClientAgent, APPFLServerAgent
from appfl.comm.mpi import MPIClientCommunicator, MPIServerCommunicator

argparse = argparse.ArgumentParser()
argparse.add_argument("--server_config", type=str, default="config/server_fedavg.yaml")
argparse.add_argument("--client_config", type=str, default="config/client_1.yaml")
args = argparse.parse_args()

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
num_clients = size - 1

if rank == 0:
# Load and set the server configurations
server_agent_config = OmegaConf.load(args.server_config)
server_agent_config.server_configs.scheduler_kwargs.num_clients = num_clients
if hasattr(server_agent_config.server_configs.aggregator_kwargs, "num_clients"):
server_agent_config.server_configs.aggregator_kwargs.num_clients = num_clients
# Create the server agent and communicator
server_agent = APPFLServerAgent(server_agent_config=server_agent_config)
server_communicator = MPIServerCommunicator(comm, server_agent)
# Start the server to serve the clients
server_communicator.serve()
else:
# Set the client configurations
client_agent_config = OmegaConf.load(args.client_config)
client_agent_config.train_configs.logging_id = f'Client{rank}'
client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1
client_agent_config.data_configs.dataset_kwargs.visualization = True if rank == 1 else False
# Create the client agent and communicator
client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
client_communicator = MPIClientCommunicator(comm, server_rank=0)
# Load the configurations and initial global model
client_config = client_communicator.get_configuration()
client_agent.load_config(client_config)
init_global_model = client_communicator.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)
# Send the sample size to the server
sample_size = client_agent.get_sample_size()
client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
# Local training and global model update iterations
for i in range(10):
client_agent.train()
local_model = client_agent.get_parameters()
new_global_model = client_communicator.update_global_model(local_model)
if isinstance(new_global_model, tuple):
new_global_model, metadata = new_global_model[0], new_global_model[1]
client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
client_agent.load_parameters(new_global_model)
18 changes: 3 additions & 15 deletions examples/serial/run_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,9 @@
from appfl.agent import APPFLClientAgent, APPFLServerAgent

argparser = argparse.ArgumentParser()
argparser.add_argument(
"--server_config",
type=str,
default="config/server_fedavg.yaml",
)
argparser.add_argument(
"--client_config",
type=str,
default="config/client_1.yaml",
)
argparser.add_argument(
"--num_clients",
type=int,
default=10,
)
argparser.add_argument("--server_config", type=str, default="config/server_fedavg.yaml")
argparser.add_argument("--client_config", type=str, default="config/client_1.yaml")
argparser.add_argument("--num_clients", type=int, default=10)
args = argparser.parse_args()

# Load server agent configurations and set the number of clients
Expand Down
12 changes: 8 additions & 4 deletions src/appfl/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,19 @@ def global_update(
else:
return global_model # return the `Future` object

def get_parameters(self, **kwargs) -> Union[Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
def get_parameters(
self,
blocking: bool = False,
**kwargs
) -> Union[Future, Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
"""Return the global model to the clients."""
global_model = self.scheduler.get_parameters(**kwargs)
if not isinstance(global_model, Future):
return global_model
if blocking:
return global_model.result() # blocking until the `Future` is done
else:
return global_model.result()
return global_model # return the `Future` object

def set_sample_size(
self,
Expand Down Expand Up @@ -180,8 +186,6 @@ def _load_compressor(self) -> None:
def _bytes_to_model(self, model_bytes: bytes) -> Union[Dict, OrderedDict]:
"""Deserialize the model from bytes."""
if not self.enable_compression:
print("[DEBUG] Decompressing model without compression")
return torch.load(io.BytesIO(model_bytes))
else:
print("[DEBUG] Decompressing model with compression")
return self.compressor.decompress_model(model_bytes, self.model)
2 changes: 2 additions & 0 deletions src/appfl/comm/mpi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .mpi_client_communicator import MPIClientCommunicator
from .mpi_server_communicator import MPIServerCommunicator
28 changes: 28 additions & 0 deletions src/appfl/comm/mpi/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from enum import Enum
from dataclasses import dataclass

class MPITask(Enum):
"""MPI task type"""
GET_CONFIGURATION = 0
GET_GLOBAL_MODEL = 1
UPDATE_GLOBAL_MODEL = 2
INVOKE_CUSTOM_ACTION = 3

class MPIServerStatus(Enum):
"""MPI server status"""
RUN = 0
STOP = 1
ERROR = 2

@dataclass
class MPITaskRequest:
"""MPI task request"""
payload: bytes = b""
meta_data: str = ""

@dataclass
class MPITaskResponse:
"""MPI task response"""
status: int = MPIServerStatus.RUN.value
payload: bytes = b""
meta_data: str = ""
Loading

0 comments on commit f148612

Please sign in to comment.