Merge pull request #4 from APPFL/zilinghan/termination
Zilinghan/termination
Zilinghan authored Mar 18, 2024
2 parents f148612 + a62e25d commit 57e0c54
Showing 17 changed files with 140 additions and 71 deletions.
3 changes: 1 addition & 2 deletions examples/config/server_fedasync.yaml

```diff
@@ -4,7 +4,6 @@ client_configs:
     trainer: "NaiveTrainer"
     mode: "step"
     num_local_steps: 100
-    num_global_epochs: 5
     optim: "Adam"
     optim_args:
       lr: 0.001
@@ -64,7 +63,7 @@ server_configs:
     alpha: 0.9
     gradient_based: True
   device: "cpu"
-  num_epochs: 2
+  num_global_epochs: 20
   server_validation: False
   logging_output_dirname: "./output"
   logging_output_filename: "result"
```
3 changes: 1 addition & 2 deletions examples/config/server_fedavg.yaml

```diff
@@ -4,7 +4,6 @@ client_configs:
     trainer: "NaiveTrainer"
     mode: "step"
     num_local_steps: 100
-    num_global_epochs: 5
     optim: "Adam"
     optim_args:
       lr: 0.001
@@ -56,7 +55,7 @@ server_configs:
   aggregator_kwargs:
     client_weights_mode: "equal"
   device: "cpu"
-  num_epochs: 2
+  num_global_epochs: 3
   server_validation: False
   logging_output_dirname: "./output"
   logging_output_filename: "result"
```
3 changes: 1 addition & 2 deletions examples/config/server_fedcompass.yaml

```diff
@@ -4,7 +4,6 @@ client_configs:
     trainer: "NaiveTrainer"
     mode: "step"
     num_local_steps: 100
-    num_global_epochs: 5
    optim: "Adam"
    optim_args:
      lr: 0.001
@@ -68,7 +67,7 @@ server_configs:
     gradient_based: True
     num_clients: 2
   device: "cpu"
-  num_epochs: 2
+  num_global_epochs: 20
   server_validation: False
   logging_output_dirname: "./output"
   logging_output_filename: "result"
```
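All three server configs get the same two fixes: the stray `num_global_epochs` under `client_configs` is removed, and the server-side `num_epochs` is renamed to `num_global_epochs`, so a single server key now decides when federated training terminates. A minimal sketch of reading that key with OmegaConf (the path is illustrative):

```python
from omegaconf import OmegaConf

# Illustrative path; any of the three server configs above works the same way.
config = OmegaConf.load("examples/config/server_fedavg.yaml")

# After this change, the global training budget lives under server_configs.
print(config.server_configs.num_global_epochs)  # e.g., 3 for the FedAvg example
```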
23 changes: 11 additions & 12 deletions examples/grpc/run_client_1.py

```diff
@@ -2,31 +2,30 @@
 from appfl.agent import APPFLClientAgent
 from appfl.comm.grpc import GRPCClientCommunicator
 
-max_message_size = 1024 * 1024
-
 client_agent_config = OmegaConf.load("config/client_1.yaml")
 
 client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
-client_comm = GRPCClientCommunicator(
+client_communicator = GRPCClientCommunicator(
     client_id = client_agent.get_id(),
     **client_agent_config.comm_configs.grpc_configs,
 )
 
-client_config = client_comm.get_configuration()
+client_config = client_communicator.get_configuration()
 client_agent.load_config(client_config)
 
-init_global_model = client_comm.get_global_model(init_model=True)
+init_global_model = client_communicator.get_global_model(init_model=True)
 client_agent.load_parameters(init_global_model)
 
 # Send the number of local data to the server
 sample_size = client_agent.get_sample_size()
-client_comm.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
-for i in range(10):
+client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
+
+while True:
     client_agent.train()
-    local_model = client_agent.get_parameters()
-    new_global_model = client_comm.update_global_model(local_model)
-    if isinstance(new_global_model, tuple):
-        new_global_model, metadata = new_global_model[0], new_global_model[1]
+    local_model = client_agent.get_parameters()
+    new_global_model, metadata = client_communicator.update_global_model(local_model)
+    if metadata['status'] == 'DONE':
+        break
     if 'local_steps' in metadata:
         client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
     client_agent.load_parameters(new_global_model)
```
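The client no longer runs a fixed number of rounds; it loops until the server's metadata reports completion. For reference, here is the same new loop with explanatory comments added (names as constructed in the example above):

```python
while True:
    client_agent.train()                         # one round of local training
    local_model = client_agent.get_parameters()  # extract the locally trained model
    # update_global_model now always returns (model, metadata), where
    # metadata['status'] is either 'RUNNING' or 'DONE'.
    new_global_model, metadata = client_communicator.update_global_model(local_model)
    if metadata['status'] == 'DONE':             # server reached num_global_epochs
        break
    if 'local_steps' in metadata:                # scheduler may adapt the local workload
        client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
    client_agent.load_parameters(new_global_model)  # start next round from the new global model
```

The run_client_2.py and run_mpi.py diffs below make the identical change.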
21 changes: 10 additions & 11 deletions examples/grpc/run_client_2.py

```diff
@@ -2,31 +2,30 @@
 from appfl.agent import APPFLClientAgent
 from appfl.comm.grpc import GRPCClientCommunicator
 
-max_message_size = 1024 * 1024
-
 client_agent_config = OmegaConf.load("config/client_2.yaml")
 
 client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
-client_comm = GRPCClientCommunicator(
+client_communicator = GRPCClientCommunicator(
     client_id = client_agent.get_id(),
     **client_agent_config.comm_configs.grpc_configs,
 )
 
-client_config = client_comm.get_configuration()
+client_config = client_communicator.get_configuration()
 client_agent.load_config(client_config)
 
-init_global_model = client_comm.get_global_model(init_model=True)
+init_global_model = client_communicator.get_global_model(init_model=True)
 client_agent.load_parameters(init_global_model)
 
 # Send the number of local data to the server
 sample_size = client_agent.get_sample_size()
-client_comm.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
+client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
 
-for i in range(10):
+while True:
     client_agent.train()
-    local_model = client_agent.get_parameters()
-    new_global_model = client_comm.update_global_model(local_model)
-    if isinstance(new_global_model, tuple):
-        new_global_model, metadata = new_global_model[0], new_global_model[1]
+    local_model = client_agent.get_parameters()
+    new_global_model, metadata = client_communicator.update_global_model(local_model)
+    if metadata['status'] == 'DONE':
+        break
     if 'local_steps' in metadata:
         client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
     client_agent.load_parameters(new_global_model)
```
11 changes: 6 additions & 5 deletions examples/mpi/run_mpi.py

```diff
@@ -44,11 +44,12 @@
 sample_size = client_agent.get_sample_size()
 client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size)
 # Local training and global model update iterations
-for i in range(10):
+while True:
     client_agent.train()
     local_model = client_agent.get_parameters()
-    new_global_model = client_communicator.update_global_model(local_model)
-    if isinstance(new_global_model, tuple):
-        new_global_model, metadata = new_global_model[0], new_global_model[1]
+    new_global_model, metadata = client_communicator.update_global_model(local_model)
+    if metadata['status'] == 'DONE':
+        break
     if 'local_steps' in metadata:
         client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
-    client_agent.load_parameters(new_global_model)
+    client_agent.load_parameters(new_global_model)
```
2 changes: 1 addition & 1 deletion examples/serial/run_serial.py

```diff
@@ -53,7 +53,7 @@
     sample_size=sample_size
 )
 
-for i in range(5):
+while not server_agent.training_finished():
     new_global_models = []
     for client_agent in client_agents:
         # Client local training
```
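In the serial example there is no communicator in between, so the driver queries the server agent directly. A rough sketch of one serial round under the new loop condition, assuming the surrounding code matches the other examples (`global_update` with `blocking=False` may return a `Future`, per the server agent diff below):

```python
from concurrent.futures import Future

while not server_agent.training_finished():
    new_global_models = []
    for client_agent in client_agents:
        # Client local training
        client_agent.train()
        local_model = client_agent.get_parameters()
        # Non-blocking update: the result may be a Future until all clients report.
        new_global_models.append(
            server_agent.global_update(client_agent.get_id(), local_model, blocking=False)
        )
    # Load the aggregated global model back into each client for the next round.
    for client_agent, new_model in zip(client_agents, new_global_models):
        if isinstance(new_model, Future):
            new_model = new_model.result()
        client_agent.load_parameters(new_model)
```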
59 changes: 46 additions & 13 deletions src/appfl/agent/server.py

```diff
@@ -1,5 +1,6 @@
 import io
 import torch
+import threading
 import torch.nn as nn
 from appfl.scheduler import *
 from appfl.aggregator import *
@@ -44,12 +45,12 @@ def get_client_configs(self, **kwargs) -> DictConfig:
         return self.server_agent_config.client_configs
 
     def global_update(
-            self,
-            client_id: Union[int, str],
-            local_model: Union[Dict, OrderedDict, bytes],
-            blocking: bool = False,
-            **kwargs
-    ) -> Union[Future, Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
+        self,
+        client_id: Union[int, str],
+        local_model: Union[Dict, OrderedDict, bytes],
+        blocking: bool = False,
+        **kwargs
+    ) -> Union[Future, Dict, OrderedDict, Tuple[Union[Dict, OrderedDict], Dict]]:
         """
         Update the global model using the local model from a client and return the updated global model.
         :param client_id: A unique client id for the server to distinguish clients, which can be obtained via `ClientAgent.get_id()`.
@@ -58,15 +59,19 @@ def global_update(
         Setting `blocking` to `True` will block the client until the global model is available.
         Otherwise, the method may return a `Future` object if the most up-to-date global model is not yet available.
         """
-        if isinstance(local_model, bytes):
-            local_model = self._bytes_to_model(local_model)
-        global_model = self.scheduler.schedule(client_id, local_model, **kwargs)
-        if not isinstance(global_model, Future):
-            return global_model
-        if blocking:
-            return global_model.result() # blocking until the `Future` is done
-        else:
-            return global_model # return the `Future` object
+        if isinstance(local_model, bytes):
+            local_model = self._bytes_to_model(local_model)
+        global_model = self.scheduler.schedule(client_id, local_model, **kwargs)
+        if not isinstance(global_model, Future):
+            if self.training_finished(internal_check=True):
+                global_model = self.scheduler.get_parameters(init_model=False)
+            return global_model
+        if blocking:
+            return global_model.result() # blocking until the `Future` is done
+        else:
+            return global_model # return the `Future` object
 
     def get_parameters(
         self,
@@ -90,6 +95,34 @@ def set_sample_size(
         """Set the size of the local dataset of a client."""
         self.aggregator.set_client_sample_size(client_id, sample_size)
 
+    def training_finished(self, internal_check: bool = False) -> bool:
+        """Notify the client whether the training is finished."""
+        finished = self.server_agent_config.server_configs.num_global_epochs <= self.scheduler.get_num_global_epochs()
+        if finished and not internal_check:
+            if not hasattr(self, "num_finish_calls"):
+                self.num_finish_calls = 0
+                self._num_finish_calls_lock = threading.Lock()
+            with self._num_finish_calls_lock:
+                self.num_finish_calls += 1
+        return finished
+
+    def server_terminated(self):
+        """Whether the server can be terminated from listening to the clients."""
+        if not hasattr(self, "num_finish_calls"):
+            return False
+        num_clients = (
+            self.server_agent_config.server_configs.num_clients if
+            hasattr(self.server_agent_config.server_configs, "num_clients") else
+            self.server_agent_config.server_configs.scheduler_kwargs.num_clients if
+            hasattr(self.server_agent_config.server_configs.scheduler_kwargs, "num_clients") else
+            self.server_agent_config.server_configs.aggregator_kwargs.num_clients
+        )
+        with self._num_finish_calls_lock:
+            terminated = self.num_finish_calls >= num_clients
+        if terminated and hasattr(self.scheduler, "clean_up"):
+            self.scheduler.clean_up()
+        return terminated
+
     def _create_logger(self) -> None:
         kwargs = {}
         if hasattr(self.server_agent_config.server_configs, "logging_output_dirname"):
```
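The two new methods split termination into a query and a shutdown test. `training_finished()` checks whether the scheduler has completed the configured `num_global_epochs`; unless called with `internal_check=True`, it also counts the call as one client having been notified (guarded by a lock, since the gRPC and MPI servers may serve clients concurrently). `server_terminated()` then reports whether every client has seen the finished state, which lets a communicator stop listening. A hedged sketch of the intended call pattern, where `channel` and `handle_request` are hypothetical stand-ins for the transport-specific parts:

```python
# Illustrative serving loop; the two ServerAgent methods are the ones added above,
# while channel.recv/send and handle_request are hypothetical placeholders.
def serve(server_agent, channel):
    while not server_agent.server_terminated():  # True once every client saw DONE
        request = channel.recv()
        response = handle_request(server_agent, request)  # dispatch to the agent
        channel.send(response)                   # responses carry a RUNNING/DONE status
```

This is exactly the shape the MPI server adopts below (`while not self.server_agent.server_terminated():`).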
10 changes: 4 additions & 6 deletions src/appfl/comm/grpc/grpc_client_communicator.py

```diff
@@ -93,12 +93,12 @@ def get_global_model(self, **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]:
         else:
             return model, meta_data
 
-    def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]:
+    def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs) -> Tuple[Union[Dict, OrderedDict], Dict]:
         """
         Send local model to FL server for global update, and return the new global model.
         :param local_model: the local model to be sent to the server for global aggregation
         :param kwargs: additional metadata to be sent to the server
-        :return: the updated global model with additional metadata (if any)
+        :return: the updated global model with additional metadata. Specifically, `meta_data["status"]` is either "RUNNING" or "DONE".
         """
         meta_data = json.dumps(kwargs)
         request = UpdateGlobalModelRequest(
@@ -115,10 +115,8 @@ def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs)
             raise Exception("Server returned an error, stopping the client.")
         model = torch.load(io.BytesIO(response.global_model))
         meta_data = json.loads(response.meta_data)
-        if len(meta_data) == 0:
-            return model
-        else:
-            return model, meta_data
+        meta_data["status"] = "DONE" if response.header.status == ServerStatus.DONE else "RUNNING"
+        return model, meta_data
 
     def invoke_custom_action(self, action: str, **kwargs) -> Dict:
         """
```
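With this change `update_global_model` always returns a `(model, meta_data)` pair with a guaranteed `status` key, which is what allows the client examples above to unpack the result unconditionally. A minimal consumer, assuming a constructed `client_communicator` and a `local_model` as in those examples:

```python
model, meta_data = client_communicator.update_global_model(local_model)
assert meta_data["status"] in ("RUNNING", "DONE")  # guaranteed by the new contract
```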
5 changes: 3 additions & 2 deletions src/appfl/comm/grpc/grpc_server_communicator.py

```diff
@@ -54,7 +54,7 @@ def GetGlobalModel(self, request, context):
             meta_data = {}
         else:
             meta_data = json.loads(request.meta_data)
-        model = self.server_agent.get_parameters(**meta_data)
+        model = self.server_agent.get_parameters(**meta_data, blocking=True)
         if isinstance(model, tuple):
             meta_data = json.dumps(model[1])
             model = model[0]
@@ -98,8 +98,9 @@ def UpdateGlobalModel(self, request_iterator, context):
         else:
             meta_data = json.dumps({})
         global_model_serialized = serialize_model(global_model)
+        status = ServerStatus.DONE if self.server_agent.training_finished() else ServerStatus.RUN
         response = UpdateGlobalModelResponse(
-            header=ServerHeader(status=ServerStatus.RUN),
+            header=ServerHeader(status=status),
             global_model=global_model_serialized,
             meta_data=meta_data,
         )
```
2 changes: 1 addition & 1 deletion src/appfl/comm/mpi/config.py

```diff
@@ -11,7 +11,7 @@ class MPITask(Enum):
 class MPIServerStatus(Enum):
     """MPI server status"""
     RUN = 0
-    STOP = 1
+    DONE = 1
     ERROR = 2
 
 @dataclass
```
11 changes: 5 additions & 6 deletions src/appfl/comm/mpi/mpi_client_communicator.py

```diff
@@ -58,12 +58,12 @@ def get_global_model(self, **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]:
         else:
             return model, meta_data
 
-    def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs) -> Union[Union[Dict, OrderedDict], Tuple[Union[Dict, OrderedDict], Dict]]:
+    def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs) -> Tuple[Union[Dict, OrderedDict], Dict]:
         """
         Send local model to FL server for global update, and return the new global model.
         :param local_model: the local model to be sent to the server for global aggregation
         :param kwargs: additional metadata to be sent to the server
-        :return: the updated global model with additional metadata (if any)
+        :return: the updated global model with additional metadata. Specifically, `meta_data["status"]` is either "RUNNING" or "DONE".
         """
         meta_data = json.dumps(kwargs)
         request = MPITaskRequest(
@@ -78,10 +78,9 @@ def update_global_model(self, local_model: Union[Dict, OrderedDict, bytes], **kwargs)
             raise Exception("Server returned an error, stopping the client.")
         model = byte_to_model(response.payload)
         meta_data = json.loads(response.meta_data)
-        if len(meta_data) == 0:
-            return model
-        else:
-            return model, meta_data
+        status = "DONE" if response.status == MPIServerStatus.DONE.value else "RUNNING"
+        meta_data["status"] = status
+        return model, meta_data
 
     def invoke_custom_action(self, action: str, **kwargs) -> Dict:
         """
```
11 changes: 7 additions & 4 deletions src/appfl/comm/mpi/mpi_server_communicator.py

```diff
@@ -29,7 +29,7 @@ def serve(self):
         """
         self.logger.info(f"Server starting...")
         status = MPI.Status()
-        while True:
+        while not self.server_agent.server_terminated():
             self.comm.probe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
             source = status.Get_source()
             tag = status.Get_tag()
@@ -41,6 +41,7 @@ def serve(self):
             if response is not None:
                 response_bytes = response_to_byte(response)
                 self.comm.Send(response_bytes, dest=source, tag=source)
+        self.logger.info(f"Server terminated.")
 
     def _request_handler(
         self,
@@ -104,7 +105,7 @@ def _get_global_model(
         """
         self.logger.info(f"Received GetGlobalModel request from client {client_id}")
         meta_data = json.loads(request.meta_data) if len(request.meta_data) > 0 else {}
-        model = self.server_agent.get_parameters(**meta_data)
+        model = self.server_agent.get_parameters(**meta_data, blocking=False)
         if not isinstance(model, Future):
             if isinstance(model, tuple):
                 model = model[0]
@@ -148,8 +149,9 @@ def _update_global_model(
         else:
             meta_data = json.dumps({})
         global_model_serialized = model_to_byte(global_model)
+        status = MPIServerStatus.DONE.value if self.server_agent.training_finished() else MPIServerStatus.RUN.value
         return MPITaskResponse(
-            status=MPIServerStatus.RUN.value,
+            status=status,
             payload=global_model_serialized,
             meta_data=meta_data,
         )
@@ -194,8 +196,9 @@ def _check_response_futures(self):
         else:
             meta_data = json.dumps({})
         global_model_serialized = model_to_byte(global_model)
+        status = MPIServerStatus.DONE.value if self.server_agent.training_finished() else MPIServerStatus.RUN.value
         response = MPITaskResponse(
-            status=MPIServerStatus.RUN.value,
+            status=status,
             payload=global_model_serialized,
             meta_data=meta_data,
         )
```