Skip to content

Commit

Permalink
Rename study to project in all files in lighter due to new naming con…
Browse files Browse the repository at this point in the history
…flict (NVIDIA#365)
  • Loading branch information
IsaacYangSLA authored Mar 29, 2022
1 parent 0e20241 commit 869c500
Show file tree
Hide file tree
Showing 9 changed files with 14,413 additions and 95 deletions.
10 changes: 5 additions & 5 deletions nvflare/lighter/impl/auth_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json
import os

from nvflare.lighter.spec import Builder, Study
from nvflare.lighter.spec import Builder, Project


class AuthPolicyBuilder(Builder):
Expand All @@ -39,12 +39,12 @@ def __init__(self, orgs, roles, groups, disabled):
self.groups = groups
self.disabled = disabled

def build(self, study: Study, ctx: dict):
def build(self, project: Project, ctx: dict):
authz = {"version": "1.0"}
authz["roles"] = self.roles
authz["groups"] = self.groups
users = dict()
for admin in study.get_participants_by_type("admin", first_only=False):
for admin in project.get_participants_by_type("admin", first_only=False):
if admin.org not in self.orgs:
raise ValueError(f"Admin {admin.name}'s org {admin.org} not defined in AuthPolicy")
if self.disabled:
Expand All @@ -56,12 +56,12 @@ def build(self, study: Study, ctx: dict):
users[admin.name] = {"org": admin.org, "roles": admin.props.get("roles")}
authz["users"] = users
authz["orgs"] = self.orgs
servers = study.get_participants_by_type("server", first_only=False)
servers = project.get_participants_by_type("server", first_only=False)
for server in servers:
if server.org not in self.orgs:
raise ValueError(f"Server {server.name}'s org {server.org} not defined in AuthPolicy")
sites = {"server": server.org}
for client in study.get_participants_by_type("client", first_only=False):
for client in project.get_participants_by_type("client", first_only=False):
if client.org not in self.orgs:
raise ValueError(f"client {client.name}'s org {client.org} not defined in AuthPolicy")
sites[client.name] = client.org
Expand Down
12 changes: 6 additions & 6 deletions nvflare/lighter/impl/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,21 @@ def _build_write_cert_pair(self, participant, base_name, ctx):
with open(os.path.join(dest_dir, "rootCA.pem"), "wb") as f:
f.write(self.serialized_cert)

def build(self, study, ctx):
self._build_root(study.name)
def build(self, project, ctx):
self._build_root(project.name)
ctx["root_cert"] = self.root_cert
ctx["root_pri_key"] = self.pri_key
overseer = study.get_participants_by_type("overseer")
overseer = project.get_participants_by_type("overseer")
self._build_write_cert_pair(overseer, "overseer", ctx)

servers = study.get_participants_by_type("server", first_only=False)
servers = project.get_participants_by_type("server", first_only=False)
for server in servers:
self._build_write_cert_pair(server, "server", ctx)

for client in study.get_participants_by_type("client", first_only=False):
for client in project.get_participants_by_type("client", first_only=False):
self._build_write_cert_pair(client, "client", ctx)

for admin in study.get_participants_by_type("admin", first_only=False):
for admin in project.get_participants_by_type("admin", first_only=False):
self._build_write_cert_pair(admin, "client", ctx)

def get_pri_key_cert(self, participant):
Expand Down
6 changes: 3 additions & 3 deletions nvflare/lighter/impl/he.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ def initialize(self, ctx):
self._context.generate_relin_keys()
self._context.global_scale = 2 ** self.scale_bits

def build(self, study, ctx):
servers = study.get_participants_by_type("server", first_only=False)
def build(self, project, ctx):
servers = project.get_participants_by_type("server", first_only=False)
for server in servers:
dest_dir = self.get_kit_dir(server, ctx)
with open(os.path.join(dest_dir, "server_context.tenseal"), "wb") as f:
f.write(self.get_serialized_context())
for client in study.get_participants_by_type("client", first_only=False):
for client in project.get_participants_by_type("client", first_only=False):
dest_dir = self.get_kit_dir(client, ctx)
with open(os.path.join(dest_dir, "client_context.tenseal"), "wb") as f:
f.write(self.get_serialized_context(is_client=True))
Expand Down
8 changes: 4 additions & 4 deletions nvflare/lighter/impl/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json
import os

from nvflare.lighter.spec import Builder, Study
from nvflare.lighter.spec import Builder, Project
from nvflare.lighter.utils import sign_all


Expand All @@ -26,14 +26,14 @@ class SignatureBuilder(Builder):
can be cryptographically verified to ensure any tampering is detected. This builder writes the signature.pkl file.
"""

def build(self, study: Study, ctx: dict):
servers = study.get_participants_by_type("server", first_only=False)
def build(self, project: Project, ctx: dict):
servers = project.get_participants_by_type("server", first_only=False)
for server in servers:
dest_dir = self.get_kit_dir(server, ctx)
root_pri_key = ctx.get("root_pri_key")
signatures = sign_all(dest_dir, root_pri_key)
json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt"))
for p in study.get_participants_by_type("client", first_only=False):
for p in project.get_participants_by_type("client", first_only=False):
dest_dir = self.get_kit_dir(p, ctx)
root_pri_key = ctx.get("root_pri_key")
signatures = sign_all(dest_dir, root_pri_key)
Expand Down
28 changes: 14 additions & 14 deletions nvflare/lighter/impl/static_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
):
"""Build all static files from template.
Uses the information from project.yml through study to go through the participants and write the contents of
Uses the information from project.yml through project to go through the participants and write the contents of
each file with the template, and replacing with the appropriate values from project.yml.
Usually, two main categories of files are created in all FL participants, static and dynamic. Static files
Expand Down Expand Up @@ -75,7 +75,7 @@ def _build_overseer(self, overseer, ctx):
default_port = "443" if protocol == "https" else "80"
port = overseer.props.get("port", default_port)
replacement_dict = {"port": port}
admins = self.study.get_participants_by_type("admin", first_only=False)
admins = self.project.get_participants_by_type("admin", first_only=False)
privilege_dict = dict()
for admin in admins:
for role in admin.props.get("roles", {}):
Expand Down Expand Up @@ -118,7 +118,7 @@ def _build_server(self, server, ctx):
config = json.loads(self.template["fed_server"])
dest_dir = self.get_kit_dir(server, ctx)
server_0 = config["servers"][0]
server_0["name"] = self.study_name
server_0["name"] = self.project_name
admin_port = server.props.get("admin_port", 8003)
ctx["admin_port"] = admin_port
fed_learn_port = server.props.get("fed_learn_port", 8002)
Expand All @@ -136,7 +136,7 @@ def _build_server(self, server, ctx):
overseer_agent["args"] = {
"role": "server",
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.study_name,
"project": self.project_name,
"name": server.name,
"fl_port": str(fed_learn_port),
"admin_port": str(admin_port),
Expand Down Expand Up @@ -194,7 +194,7 @@ def _build_client(self, client, ctx):
fed_learn_port = ctx.get("fed_learn_port")
server_name = ctx.get("server_name")
# config["servers"][0]["service"]["target"] = f"{server_name}:{fed_learn_port}"
config["servers"][0]["name"] = self.study_name
config["servers"][0]["name"] = self.project_name
config["enable_byoc"] = client.enable_byoc
replacement_dict = {
"client_name": f"{client.subject}",
Expand All @@ -207,7 +207,7 @@ def _build_client(self, client, ctx):
overseer_agent["args"] = {
"role": "client",
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.study_name,
"project": self.project_name,
"name": client.subject,
}
overseer_agent.pop("overseer_exists", None)
Expand Down Expand Up @@ -268,7 +268,7 @@ def _build_admin(self, admin, ctx):
overseer_agent["args"] = {
"role": "admin",
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.study_name,
"project": self.project_name,
"name": admin.subject,
}
overseer_agent.pop("overseer_exists", None)
Expand All @@ -294,18 +294,18 @@ def _build_admin(self, admin, ctx):
"t",
)

def build(self, study, ctx):
def build(self, project, ctx):
self.template = ctx.get("template")
self.study_name = study.name
self.study = study
overseer = study.get_participants_by_type("overseer")
self.project_name = project.name
self.project = project
overseer = project.get_participants_by_type("overseer")
self._build_overseer(overseer, ctx)
servers = study.get_participants_by_type("server", first_only=False)
servers = project.get_participants_by_type("server", first_only=False)
for server in servers:
self._build_server(server, ctx)

for client in study.get_participants_by_type("client", first_only=False):
for client in project.get_participants_by_type("client", first_only=False):
self._build_client(client, ctx)

for admin in study.get_participants_by_type("admin", first_only=False):
for admin in project.get_participants_by_type("admin", first_only=False):
self._build_admin(admin, ctx)
8 changes: 4 additions & 4 deletions nvflare/lighter/impl/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import shutil
import subprocess

from nvflare.lighter.spec import Builder, Study
from nvflare.lighter.spec import Builder, Project
from nvflare.lighter.utils import generate_password


Expand Down Expand Up @@ -68,8 +68,8 @@ def initialize(self, ctx):
shutil.copyfile(os.path.join(file_path, self.template_file), template_file_full_path)
ctx["template_file"] = self.template_file

def build(self, study: Study, ctx: dict):
dirs = [self.get_kit_dir(p, ctx) for p in study.participants]
def build(self, project: Project, ctx: dict):
dirs = [self.get_kit_dir(p, ctx) for p in project.participants]
self._make_dir(dirs)

def finalize(self, ctx: dict):
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self, zip_password=False):
"""
self.zip_password = zip_password

def build(self, study: Study, ctx: dict):
def build(self, project: Project, ctx: dict):
wip_dir = self.get_wip_dir(ctx)
dirs = [name for name in os.listdir(wip_dir) if os.path.isdir(os.path.join(wip_dir, name))]
for dir in dirs:
Expand Down
18 changes: 9 additions & 9 deletions nvflare/lighter/provision.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import yaml

from nvflare.fuel.utils.class_utils import instantiate_class
from nvflare.lighter.spec import Participant, Provisioner, Study
from nvflare.lighter.spec import Participant, Project, Provisioner


def main():
Expand Down Expand Up @@ -77,27 +77,27 @@ def main():
project_full_path = os.path.join(current_path, project_file)
print(f"Project yaml file: {project_full_path}.")

project = yaml.load(open(project_full_path, "r"), Loader=yaml.Loader)
api_version = project.get("api_version")
project_dict = yaml.load(open(project_full_path, "r"), Loader=yaml.Loader)
api_version = project_dict.get("api_version")
if api_version not in [3]:
raise ValueError(f"API version expected 3 but found {api_version}")

study_name = project.get("name")
study_description = project.get("description", "")
project_name = project_dict.get("name")
project_description = project_dict.get("description", "")
participants = list()
for p in project.get("participants"):
for p in project_dict.get("participants"):
participants.append(Participant(**p))
study = Study(name=study_name, description=study_description, participants=participants)
project = Project(name=project_name, description=project_description, participants=participants)

builders = list()
for b in project.get("builders"):
for b in project_dict.get("builders"):
path = b.get("path")
args = b.get("args")
builders.append(instantiate_class(path, args))

provisioner = Provisioner(workspace_full_path, builders)

provisioner.provision(study)
provisioner.provision(project)


if __name__ == "__main__":
Expand Down
14,394 changes: 14,356 additions & 38 deletions nvflare/lighter/provision_helper.html

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions nvflare/lighter/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ def __init__(self, type: str, name: str, org: str, enable_byoc: bool = False, *a
self.props = kwargs


class Study(object):
class Project(object):
def __init__(self, name: str, description: str, participants: List[Participant]):
"""A container class to hold information about this FL study.
"""A container class to hold information about this FL project.
This calss only holds information. It does not drive the workflow.
Args:
name (str): the study name
name (str): the project name
description (str): brief description on this name
participants (List[Participant]): All the participants that will join this study
participants (List[Participant]): All the participants that will join this project
Raises:
ValueError: when duplicate name found in participants list
Expand All @@ -69,7 +69,7 @@ def __init__(self, name: str, description: str, participants: List[Participant])
all_names = list()
for p in participants:
if p.name in all_names:
raise ValueError(f"Unable to add a duplicate name {p.name} into this study.")
raise ValueError(f"Unable to add a duplicate name {p.name} into this project.")
else:
all_names.append(p.name)
self.description = description
Expand All @@ -90,7 +90,7 @@ class Builder(ABC):
def initialize(self, ctx: dict):
pass

def build(self, study: Study, ctx: dict):
def build(self, project: Project, ctx: dict):
pass

def finalize(self, ctx: dict):
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(self, root_dir: str, builders: List[Builder]):
ROOT_WORKSPACE Folder Structure
root_workspace_dir_name: this is the root of the workspace
study_dir_name: the root dir of the study, could be named after the study
project_dir_name: the root dir of the project, could be named after the project
resources: stores resource files (templates, configs, etc.) of the Provisioner and Builders
prod: stores the current set of startup kits (production)
participate_dir: stores content files generated by builders
Expand Down Expand Up @@ -152,18 +152,18 @@ def _prepare_workspace(self, ctx):
dirs = [workspace, resources_dir, wip_dir, state_dir]
self._make_dir(dirs)

def provision(self, study: Study):
# ctx = {"workspace": os.path.join(self.root_dir, study.name), "study": study}
workspace = os.path.join(self.root_dir, study.name)
ctx = {"workspace": workspace} # study is more static information while ctx is dynamic
def provision(self, project: Project):
# ctx = {"workspace": os.path.join(self.root_dir, project.name), "project": project}
workspace = os.path.join(self.root_dir, project.name)
ctx = {"workspace": workspace} # project is more static information while ctx is dynamic
self._prepare_workspace(ctx)
try:
for b in self.builders:
b.initialize(ctx)

# call builders!
for b in self.builders:
b.build(study, ctx)
b.build(project, ctx)

for b in self.builders[::-1]:
b.finalize(ctx)
Expand Down

0 comments on commit 869c500

Please sign in to comment.