Skip to content

Commit

Permalink
Clean up codes in nvflare/private/fed/app (NVIDIA#286)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Mar 19, 2022
1 parent de93a14 commit 2b525cc
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 171 deletions.
29 changes: 29 additions & 0 deletions nvflare/private/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,32 @@ class ClientStatusKey(object):
CURRENT_TASK = "current_task"
STATUS = "status"
APP_NAME = "app_name"


# TODO:: Remove some of these constants
class AppFolderConstants:
"""Hard-coded names of the JSON config files found inside an app folder.

Centralizes the file names so callers reference these attributes instead of
repeating the string literals.
"""

# file name of the training config
CONFIG_TRAIN = "config_train.json"
# file name of the environment config
CONFIG_ENV = "environment.json"
# file name of the federated server config
CONFIG_FED_SERVER = "config_fed_server.json"
# file name of the federated client config
CONFIG_FED_CLIENT = "config_fed_client.json"


class SSLConstants:
"""Hard-coded key names for the SSL-related entries of a config section.

Each value is used as a dictionary key (e.g. ``client.get(SSLConstants.CERT)``)
whose entry holds the location of the corresponding SSL file.
"""

# key of the SSL certificate entry
CERT = "ssl_cert"
# key of the SSL private key entry
PRIVATE_KEY = "ssl_private_key"
# key of the SSL root certificate entry
ROOT_CERT = "ssl_root_cert"


class WorkspaceConstants:
"""Hard-coded file names inside the workspace folder."""

# name of the logging configuration file
LOGGING_CONFIG = "log.config"
# name of the audit log file handed to the audit service
AUDIT_LOG = "audit.log"

# These two files are used by shell scripts to determine restart / shutdown.
# NOTE(review): stale copies are presumably expected to be removed from the
# workspace before a new run starts -- confirm against the launch scripts.
RESTART_FILE = "restart.fl"
SHUTDOWN_FILE = "shutdown.fl"
129 changes: 58 additions & 71 deletions nvflare/private/fed/app/client/client_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides a command line interface for a federated client."""
"""Federated client launching script."""

import argparse
import os
Expand All @@ -23,75 +23,71 @@
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.security_content_service import LoadResult, SecurityContentService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.private.defs import AppFolderConstants, SSLConstants, WorkspaceConstants
from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger
from nvflare.private.fed.client.admin import FedAdminAgent
from nvflare.private.fed.client.admin_msg_sender import AdminMessageSender
from nvflare.private.fed.client.client_engine import ClientEngine
from nvflare.private.fed.client.fed_client import FederatedClient


def main():
"""Start program of the FL client."""
parser = argparse.ArgumentParser()
parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True)

parser.add_argument(
"--fed_client", "-s", type=str, help="an aggregation server specification json file", required=True
)

parser.add_argument("--set", metavar="KEY=VALUE", nargs="*")

parser.add_argument("--local_rank", type=int, default=0)

args = parser.parse_args()
kv_list = parse_vars(args.set)

args.train_config = "config/config_train.json"
config_folder = kv_list.get("config_folder", "")
if config_folder == "":
args.client_config = "config_fed_client.json"
args.client_config = AppFolderConstants.CONFIG_FED_CLIENT
else:
args.client_config = config_folder + "/config_fed_client.json"
args.env = "config/environment.json"
args.client_config = os.path.join(config_folder, AppFolderConstants.CONFIG_FED_CLIENT)
# TODO:: remove env and train config since they are not core
args.env = os.path.join("config", AppFolderConstants.CONFIG_ENV)
args.train_config = os.path.join("config", AppFolderConstants.CONFIG_TRAIN)
args.log_config = None

try:
remove_restart_file(args)
except BaseException:
print("Could not remove the restart.fl / shutdown.fl file. Please check your system before starting FL.")
sys.exit(-1)
for name in [WorkspaceConstants.RESTART_FILE, WorkspaceConstants.SHUTDOWN_FILE]:
try:
f = os.path.join(args.workspace, name)
if os.path.exists(f):
os.remove(f)
except BaseException:
print("Could not remove the {} file. Please check your system before starting FL.".format(name))
sys.exit(-1)

rank = args.local_rank

try:
os.chdir(args.workspace)
AuditService.initialize(audit_file_name="audit.log")

workspace = os.path.join(args.workspace, "startup")
AuditService.initialize(audit_file_name=WorkspaceConstants.AUDIT_LOG)

# trainer = WorkFlowFactory().create_client_trainer(train_configs, envs)
startup = os.path.join(args.workspace, "startup")
conf = FLClientStarterConfiger(
app_root=workspace,
# wf_config_file_name="config_train.json",
app_root=startup,
client_config_file_name=args.fed_client,
# env_config_file_name="environment.json",
log_config_file_name="log.config",
log_config_file_name=WorkspaceConstants.LOGGING_CONFIG,
kv_list=args.set,
)
conf.configure()

trainer = conf.base_deployer

security_check(trainer.secure_train, args)
security_check(secure_train=trainer.secure_train, content_folder=startup, fed_client_config=args.fed_client)

federated_client = trainer.create_fed_client(args)

while not federated_client.sp_established:
print("Waiting for SP....")
time.sleep(1.0)

# federated_client.platform = conf.wf_config_data.get("platform", "PT")
federated_client.use_gpu = False
# federated_client.cross_site_validate = kv_list.get("cross_site_validate", True)
federated_client.config_folder = config_folder

if rank == 0:
Expand Down Expand Up @@ -123,84 +119,75 @@ def main():
except ConfigError as ex:
print("ConfigError:", str(ex))
finally:
# shutil.rmtree(workspace)
pass

sys.exit(0)


def security_check(secure_train, args):
def security_check(secure_train: bool, content_folder: str, fed_client_config: str):
"""To check the security content if running in security mode.
Args:
secure_train: True/False
args: command args
secure_train (bool): if run in secure mode or not.
content_folder (str): the folder to check.
fed_client_config (str): fed_client.json
"""
# initialize the SecurityContentService.
# must do this before initializing other services since it may be needed by them!
startup = os.path.join(args.workspace, "startup")
SecurityContentService.initialize(content_folder=startup)
SecurityContentService.initialize(content_folder=content_folder)

if secure_train:
insecure_list = secure_content_check(args)
insecure_list = secure_content_check(fed_client_config)
if len(insecure_list):
print("The following files are not secure content.")
for item in insecure_list:
print(item)
sys.exit(1)
# initialize the AuditService, which is used by command processing.
# The Audit Service can be used in other places as well.
AuditService.initialize(audit_file_name="audit.log")
AuditService.initialize(audit_file_name=WorkspaceConstants.AUDIT_LOG)


def secure_content_check(args):
def secure_content_check(config: str):
"""To check the security contents.
Args:
args: command args
Returns: the insecure content list
config (str): fed_client.json
Returns:
A list of insecure content.
"""
insecure_list = []
data, sig = SecurityContentService.load_json(args.fed_client)
data, sig = SecurityContentService.load_json(config)
if sig != LoadResult.OK:
insecure_list.append(args.fed_client)
insecure_list.append(config)

client = data["client"]
content, sig = SecurityContentService.load_content(client.get("ssl_cert"))
content, sig = SecurityContentService.load_content(client.get(SSLConstants.CERT))
if sig != LoadResult.OK:
insecure_list.append(client.get("ssl_cert"))
content, sig = SecurityContentService.load_content(client.get("ssl_private_key"))
insecure_list.append(client.get(SSLConstants.CERT))
content, sig = SecurityContentService.load_content(client.get(SSLConstants.PRIVATE_KEY))
if sig != LoadResult.OK:
insecure_list.append(client.get("ssl_private_key"))
content, sig = SecurityContentService.load_content(client.get("ssl_root_cert"))
insecure_list.append(client.get(SSLConstants.PRIVATE_KEY))
content, sig = SecurityContentService.load_content(client.get(SSLConstants.ROOT_CERT))
if sig != LoadResult.OK:
insecure_list.append(client.get("ssl_root_cert"))
insecure_list.append(client.get(SSLConstants.ROOT_CERT))

return insecure_list


def remove_restart_file(args):
"""To remove the restart.fl file.
Args:
args: command args
"""
restart_file = os.path.join(args.workspace, "restart.fl")
if os.path.exists(restart_file):
os.remove(restart_file)
restart_file = os.path.join(args.workspace, "shutdown.fl")
if os.path.exists(restart_file):
os.remove(restart_file)


def create_admin_agent(
client_args, client_id, req_processors, secure_train, server_args, federated_client, args, is_multi_gpu, rank
client_args,
client_id,
req_processors,
secure_train,
server_args,
federated_client: FederatedClient,
args,
is_multi_gpu,
rank,
):
"""To create the admin client.
"""Creates an admin agent.
Args:
client_args: start client command args
Expand All @@ -213,31 +200,31 @@ def create_admin_agent(
is_multi_gpu: True/False
rank: client rank process number
Returns: admin client
Returns:
A FedAdminAgent.
"""
sender = AdminMessageSender(
client_name=federated_client.token,
root_cert=client_args["ssl_root_cert"],
ssl_cert=client_args["ssl_cert"],
private_key=client_args["ssl_private_key"],
root_cert=client_args[SSLConstants.ROOT_CERT],
ssl_cert=client_args[SSLConstants.CERT],
private_key=client_args[SSLConstants.PRIVATE_KEY],
server_args=server_args,
secure=secure_train,
is_multi_gpu=is_multi_gpu,
rank=rank,
)
client_engine = ClientEngine(federated_client, federated_client.token, sender, args, rank)
admin_agent = FedAdminAgent(
client_name="admin_agent",
sender=sender,
app_ctx=ClientEngine(federated_client, federated_client.token, sender, args, rank),
app_ctx=client_engine,
)
admin_agent.app_ctx.set_agent(admin_agent)
federated_client.set_client_engine(admin_agent.app_ctx)
federated_client.set_client_engine(client_engine)
for processor in req_processors:
admin_agent.register_processor(processor)

return admin_agent
# self.admin_agent.start()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
import grpc

from nvflare.apis.fl_context import FLContext
from nvflare.private.fed.client.client_req_processors import ClientRequestProcessors
from nvflare.private.fed.client.fed_client import FederatedClient

from .client_req_processors import ClientRequestProcessors


class BaseClientDeployer:
def __init__(self):
Expand Down
35 changes: 17 additions & 18 deletions nvflare/private/fed/app/fl_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""FL Server / Client startup config."""
"""FL Server / Client startup configer."""

import logging
import logging.config
Expand All @@ -21,11 +21,12 @@
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.json_scanner import Node
from nvflare.fuel.utils.wfconf import ConfigContext
from nvflare.private.fed.client.base_client_deployer import BaseClientDeployer
from nvflare.private.defs import SSLConstants
from nvflare.private.json_configer import JsonConfigurator

from .deployer.base_client_deployer import BaseClientDeployer
from .deployer.server_deployer import ServerDeployer
from .fl_app_validator import FLAppValidator
from .trainers.server_deployer import ServerDeployer

FL_PACKAGES = ["nvflare"]
FL_MODULES = ["server", "client", "app"]
Expand Down Expand Up @@ -96,12 +97,12 @@ def start_config(self, config_ctx: ConfigContext):
# loading server specifications
try:
for server in self.config_data["servers"]:
if server.get("ssl_private_key"):
server["ssl_private_key"] = os.path.join(self.app_root, server["ssl_private_key"])
if server.get("ssl_cert"):
server["ssl_cert"] = os.path.join(self.app_root, server["ssl_cert"])
if server.get("ssl_root_cert"):
server["ssl_root_cert"] = os.path.join(self.app_root, server["ssl_root_cert"])
if server.get(SSLConstants.PRIVATE_KEY):
server[SSLConstants.PRIVATE_KEY] = os.path.join(self.app_root, server[SSLConstants.PRIVATE_KEY])
if server.get(SSLConstants.CERT):
server[SSLConstants.CERT] = os.path.join(self.app_root, server[SSLConstants.CERT])
if server.get(SSLConstants.ROOT_CERT):
server[SSLConstants.ROOT_CERT] = os.path.join(self.app_root, server[SSLConstants.ROOT_CERT])
except Exception:
raise ValueError("Server config error: '{}'".format(self.server_config_file_name))

Expand Down Expand Up @@ -222,8 +223,6 @@ def process_config_element(self, config_ctx: ConfigContext, node: Node):
config_ctx: config context
node: element node
"""
# JsonConfigurator.process_config_element(self, config_ctx, node)

element = node.element
path = node.path()

Expand All @@ -245,12 +244,12 @@ def start_config(self, config_ctx: ConfigContext):

try:
client = self.config_data["client"]
if client.get("ssl_private_key"):
client["ssl_private_key"] = os.path.join(self.app_root, client["ssl_private_key"])
if client.get("ssl_cert"):
client["ssl_cert"] = os.path.join(self.app_root, client["ssl_cert"])
if client.get("ssl_root_cert"):
client["ssl_root_cert"] = os.path.join(self.app_root, client["ssl_root_cert"])
if client.get(SSLConstants.PRIVATE_KEY):
client[SSLConstants.PRIVATE_KEY] = os.path.join(self.app_root, client[SSLConstants.PRIVATE_KEY])
if client.get(SSLConstants.CERT):
client[SSLConstants.CERT] = os.path.join(self.app_root, client[SSLConstants.CERT])
if client.get(SSLConstants.ROOT_CERT):
client[SSLConstants.ROOT_CERT] = os.path.join(self.app_root, client[SSLConstants.ROOT_CERT])
except Exception:
raise ValueError("Client config error: '{}'".format(self.client_config_file_name))

Expand Down Expand Up @@ -343,4 +342,4 @@ def start_config(self, config_ctx: ConfigContext):
if admin.get("download_dir"):
admin["download_dir"] = os.path.join(os.path.dirname(self.app_root), admin["download_dir"])
except Exception:
raise ValueError("Client config error: '{}'".format(self.client_config_file_name))
raise ValueError("Client config error: '{}'".format(self.admin_config_file_name))
Loading

0 comments on commit 2b525cc

Please sign in to comment.