-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from APPFL/zilinghan/mpi
Zilinghan/mpi
- Loading branch information
Showing
10 changed files
with
558 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import argparse | ||
from mpi4py import MPI | ||
from omegaconf import OmegaConf | ||
from appfl.agent import APPFLClientAgent, APPFLServerAgent | ||
from appfl.comm.mpi import MPIClientCommunicator, MPIServerCommunicator | ||
|
||
# Example MPI launcher: rank 0 runs the server, every other rank runs a client.
#
# Command-line options give the paths to the server and client YAML
# configurations. The parser is bound to ``parser`` (not ``argparse``) so the
# imported ``argparse`` module is not shadowed by the parser instance.
parser = argparse.ArgumentParser()
parser.add_argument("--server_config", type=str, default="config/server_fedavg.yaml")
parser.add_argument("--client_config", type=str, default="config/client_1.yaml")
args = parser.parse_args()

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
# One rank (rank 0) is reserved for the server; the rest are clients.
num_clients = size - 1

if rank == 0:
    # Load and set the server configurations
    server_agent_config = OmegaConf.load(args.server_config)
    server_agent_config.server_configs.scheduler_kwargs.num_clients = num_clients
    # Only override the aggregator's client count when that key is present.
    if hasattr(server_agent_config.server_configs.aggregator_kwargs, "num_clients"):
        server_agent_config.server_configs.aggregator_kwargs.num_clients = num_clients
    # Create the server agent and communicator
    server_agent = APPFLServerAgent(server_agent_config=server_agent_config)
    server_communicator = MPIServerCommunicator(comm, server_agent)
    # Start the server to serve the clients
    server_communicator.serve()
else:
    # Set the client configurations
    client_agent_config = OmegaConf.load(args.client_config)
    client_agent_config.train_configs.logging_id = f'Client{rank}'
    client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
    # Client ids are zero-based, while client ranks start at 1.
    client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1
    # Only the first client (rank 1) has the visualization flag enabled.
    client_agent_config.data_configs.dataset_kwargs.visualization = rank == 1
    # Create the client agent and communicator
    client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
    client_communicator = MPIClientCommunicator(comm, server_rank=0)
    # Load the configurations and initial global model
    client_config = client_communicator.get_configuration()
    client_agent.load_config(client_config)
    init_global_model = client_communicator.get_global_model(init_model=True)
    client_agent.load_parameters(init_global_model)
    # Send the sample size to the server
    sample_size = client_agent.get_sample_size()
    client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
    # Local training and global model update iterations
    for _ in range(10):
        client_agent.train()
        local_model = client_agent.get_parameters()
        new_global_model = client_communicator.update_global_model(local_model)
        # A tuple reply is (model, metadata); the metadata's 'local_steps'
        # entry sets the number of local steps for the next training call.
        if isinstance(new_global_model, tuple):
            new_global_model, metadata = new_global_model[0], new_global_model[1]
            client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
        client_agent.load_parameters(new_global_model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .mpi_client_communicator import MPIClientCommunicator | ||
from .mpi_server_communicator import MPIServerCommunicator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from enum import Enum | ||
from dataclasses import dataclass | ||
|
||
class MPITask(Enum):
    """Enumeration of the MPI task types, each mapped to an integer code."""

    # Integer codes are explicit and must stay stable: they are the values
    # exposed through ``.value`` and accepted by ``MPITask(<int>)``.
    GET_CONFIGURATION = 0
    GET_GLOBAL_MODEL = 1
    UPDATE_GLOBAL_MODEL = 2
    INVOKE_CUSTOM_ACTION = 3
|
||
class MPIServerStatus(Enum):
    """Enumeration of the MPI server status codes."""

    # Explicit integer codes; ``.value`` yields the plain int.
    RUN = 0
    STOP = 1
    ERROR = 2
|
||
@dataclass
class MPITaskRequest:
    """Container for an MPI task request.

    Attributes:
        payload: Raw bytes of the request; empty by default.
        meta_data: String metadata accompanying the request; empty by default.
    """

    payload: bytes = b""
    meta_data: str = ""
|
||
@dataclass
class MPITaskResponse:
    """Container for an MPI task response.

    Attributes:
        status: Integer status code; defaults to ``MPIServerStatus.RUN.value``.
        payload: Raw bytes of the response; empty by default.
        meta_data: String metadata accompanying the response; empty by default.
    """

    # The default is stored as the plain int value, not the enum member.
    status: int = MPIServerStatus.RUN.value
    payload: bytes = b""
    meta_data: str = ""
Oops, something went wrong.