Skip to content

Commit

Permalink
Fix job def and manager specs (NVIDIA#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Apr 7, 2022
1 parent c389914 commit 2f694e3
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 135 deletions.
53 changes: 30 additions & 23 deletions nvflare/apis/impl/job_def_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import Job, JobMetaKey, job_from_meta
from nvflare.apis.job_def_manager_spec import JobDefManagerSpec, RunStatus
from nvflare.apis.server_engine_spec import ServerEngineSpec
from nvflare.apis.storage import StorageSpec
from nvflare.apis.study_manager_spec import StudyManagerSpec
from nvflare.fuel.hci.zip_utils import unzip_all_from_bytes, zip_directory_to_bytes
Expand Down Expand Up @@ -62,6 +63,10 @@ def __init__(self, reviewer_name, fl_ctx: FLContext):
self.reviewer_name = reviewer_name
engine = fl_ctx.get_engine()
self.study_manager = engine.get_component("study_manager")
if not isinstance(self.study_manager, StudyManagerSpec):
raise TypeError(
f"engine should have a study manager component of type StudyManagerSpec, but got {type(self.study_manager)}"
)

def filter_job(self, meta: dict):
study = self.study_manager.get_study(meta[JobMetaKey.STUDY_NAME])
Expand All @@ -77,6 +82,9 @@ def filter_job(self, meta: dict):
return True


# TODO:: use try block around storage calls


class SimpleJobDefManager(JobDefManagerSpec):
def __init__(self, uri_root: str = "jobs", job_store_id: str = "job_store", temp_dir: str = "/tmp"):
super().__init__()
Expand All @@ -86,6 +94,15 @@ def __init__(self, uri_root: str = "jobs", job_store_id: str = "job_store", temp
raise ValueError("temp_dir {} is not a valid dir".format(temp_dir))
self.temp_dir = temp_dir

def _get_job_store(self, fl_ctx):
engine = fl_ctx.get_engine()
if not isinstance(engine, ServerEngineSpec):
raise TypeError(f"engine should be of type ServerEngineSpec, but got {type(engine)}")
store = engine.get_component(self.job_store_id)
if not isinstance(store, StorageSpec):
raise TypeError(f"engine should have a job store component of type StorageSpec, but got {type(store)}")
return store

def job_uri(self, jid: str):
return os.path.join(self.uri_root, jid)

Expand All @@ -102,14 +119,12 @@ def create(self, meta: dict, uploaded_content: bytes, fl_ctx: FLContext) -> Dict
meta[JobMetaKey.STATUS] = RunStatus.SUBMITTED

# write it to the store
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
store.create_object(self.job_uri(jid), uploaded_content, meta, overwrite_existing=True)
return meta

def delete(self, jid: str, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
store.delete_object(self.job_uri(jid))

def _validate_meta(self, meta):
Expand All @@ -136,14 +151,12 @@ def _validate_uploaded_content(self, uploaded_content) -> bool:
pass

def get_job(self, jid: str, fl_ctx: FLContext) -> Job:
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
job_meta = store.get_meta(self.job_uri(jid))
return job_from_meta(job_meta)

def set_results_uri(self, jid: str, result_uri: str, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
updated_meta = {JobMetaKey.RESULT_LOCATION: result_uri}
store.update_meta(self.job_uri(jid), updated_meta, replace=False)
return self.get_job(jid, fl_ctx)
Expand All @@ -160,8 +173,7 @@ def get_apps(self, job: Job, fl_ctx: FLContext) -> Dict[str, bytes]:
return result_dict

def _load_job_data_from_store(self, jid: str, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
data_bytes = store.get_data(self.job_uri(jid))
job_id_dir = os.path.join(self.temp_dir, jid)
if os.path.exists(job_id_dir):
Expand All @@ -171,36 +183,32 @@ def _load_job_data_from_store(self, jid: str, fl_ctx: FLContext):
return job_id_dir

def get_content(self, jid: str, fl_ctx: FLContext) -> bytes:
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
return store.get_data(self.job_uri(jid))

def set_status(self, jid: str, status: RunStatus, fl_ctx: FLContext):
meta = {JobMetaKey.STATUS: status}
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
store.update_meta(uri=self.job_uri(jid), meta=meta, replace=False)

def update_meta(self, jid: str, meta, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
store.update_meta(uri=self.job_uri(jid), meta=meta, replace=False)

def list_all(self, fl_ctx: FLContext) -> List[Job]:
def get_all_jobs(self, fl_ctx: FLContext) -> List[Job]:
job_filter = _AllJobsFilter()
self._scan(job_filter, fl_ctx)
return job_filter.result

def _scan(self, job_filter: _JobFilter, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
jid_paths = store.list_objects(self.uri_root)
if not jid_paths:
return

for jid_path in jid_paths:
jid = pathlib.PurePath(jid_path).name
meta = self.get_job(jid, fl_ctx).meta
meta = store.get_meta(self.job_uri(jid))
if meta:
ok = job_filter.filter_job(meta)
if not ok:
Expand All @@ -211,7 +219,7 @@ def get_jobs_by_status(self, status, fl_ctx: FLContext) -> List[Job]:
self._scan(job_filter, fl_ctx)
return job_filter.result

def get_jobs_waiting_for_review(self, reviewer_name: str, fl_ctx: FLContext) -> List[Dict[str, Any]]:
def get_jobs_waiting_for_review(self, reviewer_name: str, fl_ctx: FLContext) -> List[Job]:
job_filter = _ReviewerFilter(reviewer_name, fl_ctx)
self._scan(job_filter, fl_ctx)
return job_filter.result
Expand All @@ -227,7 +235,6 @@ def set_approval(
meta[JobMetaKey.APPROVALS] = approvals
approvals[reviewer_name] = (approved, note)
updated_meta = {JobMetaKey.APPROVALS: approvals}
engine = fl_ctx.get_engine()
store = engine.get_component(self.job_store_id)
store = self._get_job_store(fl_ctx)
store.update_meta(self.job_uri(jid), updated_meta, replace=False)
return meta
73 changes: 36 additions & 37 deletions nvflare/apis/job_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from enum import Enum
from typing import Dict, List
from typing import Dict, List, Optional

from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def_manager_spec import JobDefManagerSpec
Expand Down Expand Up @@ -49,24 +48,36 @@ def __repr__(self):


class Job:
"""Job object containing the job metadata.
Args:
job_id: Job ID
study_name: Study name
resource_spec: Resource specification with information on the resources of each client
deploy_map: Deploy map specifying each app and the sites that it should be deployed to
meta: full contents of the persisted metadata for the job for persistent storage
"""

def __init__(self, job_id, study_name, resource_spec, deploy_map, meta):
def __init__(
self,
job_id: str,
study_name: str,
resource_spec: Dict[str, Dict],
deploy_map: Dict[str, List[str]],
meta,
min_sites: int = 1,
required_sites: Optional[List[str]] = None,
):
"""Job object containing the job metadata.
Args:
job_id: Job ID
study_name: Study name
resource_spec: Resource specification with information on the resources of each client
deploy_map: Deploy map specifying each app and the sites that it should be deployed to
meta: full contents of the persisted metadata for the job for persistent storage
min_sites (int): minimum number of sites
required_sites: A list of required site names
"""
self.job_id = job_id
self.study = study_name
# self.num_clients = num_clients # some way to specify minimum clients needed sites
self.resource_spec = resource_spec # resource_requirements should be {client name: resource}
self.resource_spec = resource_spec # resource_requirements should be {site name: resource}
self.deploy_map = deploy_map # should be {app name: a list of sites}

self.meta = meta
self.min_sites = min_sites
self.required_sites = required_sites

self.dispatcher_id = None
self.dispatch_time = None

Expand All @@ -92,8 +103,8 @@ def get_deployment(self) -> Dict[str, List[str]]:
]
},
Returns: contents of deploy_map as a dictionary of strings of app names with their corresponding sites
Returns:
Contents of deploy_map as a dictionary of strings of app names with their corresponding sites
"""
return self.deploy_map

Expand All @@ -104,7 +115,7 @@ def get_application(self, participant, fl_ctx: FLContext) -> bytes:
job_def_manager = engine.get_component("job_manager")
if not isinstance(job_def_manager, JobDefManagerSpec):
raise TypeError(f"job_def_manager must be JobDefManagerSpec type. Got: {type(job_def_manager)}")
return job_def_manager.get_app(self, application_name)
return job_def_manager.get_app(self, application_name, fl_ctx)

def get_application_name(self, participant):
"""Get the application name for the specified participant."""
Expand All @@ -115,22 +126,25 @@ def get_application_name(self, participant):
return None

def get_resource_requirements(self):
"""Return app resource requirements."""
"""Returns app resource requirements.
Returns:
A dict of {site_name: resource}
"""
return self.resource_spec

def __eq__(self, other):
return self.job_id == other.job_id


def job_from_meta(meta: dict) -> Job:
"""Convert information in meta into a Job object.
"""Converts information in meta into a Job object.
Args:
meta: dict of meta information
Returns:
Job object.
A Job object.
"""
job = Job(
meta.get(JobMetaKey.JOB_ID),
Expand All @@ -140,18 +154,3 @@ def job_from_meta(meta: dict) -> Job:
meta,
)
return job


def get_site_require_resource_from_job(job: Job):
"""Get the total resource needed by each site to run this Job."""
required_resources = job.get_resource_requirements()
deployment = job.get_deployment()

total_required_resources = {} # {site name: total resources}
for app in required_resources:
for site_name in deployment[app]:
if site_name not in total_required_resources:
total_required_resources[site_name] = copy.deepcopy(required_resources[app])
else:
total_required_resources[site_name] = total_required_resources[site_name] + required_resources[app]
return total_required_resources
Loading

0 comments on commit 2f694e3

Please sign in to comment.