Skip to content

Commit

Permalink
Move job scheduler from private to app_common (NVIDIA#451)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Apr 28, 2022
1 parent 06e9599 commit 1e087e0
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 50 deletions.
33 changes: 32 additions & 1 deletion nvflare/apis/server_engine_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import List, Tuple
from typing import Dict, List, Tuple

from nvflare.apis.shareable import Shareable
from nvflare.widgets.widget import Widget
Expand Down Expand Up @@ -157,3 +157,34 @@ def start_client_job(self, run_number, client_sites):
"""
pass

@abstractmethod
def check_client_resources(self, resource_reqs: Dict[str, dict]) -> Dict[str, Tuple[bool, str]]:
"""Sends the check_client_resources requests to the clients.
Args:
resource_reqs: A dict of {client_name: resource requirements dict}
Returns:
A dict of {client_name: client_check_result} where client_check_result
is a tuple of {client check OK, resource reserve token if any}
"""
pass

@abstractmethod
def cancel_client_resources(
self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict]
):
"""Cancels the request resources for the job.
Args:
resource_check_results: A dict of {client_name: client_check_result}
where client_check_result is a tuple of {client check OK, resource reserve token if any}
resource_reqs: A dict of {client_name: resource requirements dict}
"""
pass

@abstractmethod
def get_client_name_from_token(self, token: str) -> str:
"""Gets client name from a client login token."""
pass
13 changes: 13 additions & 0 deletions nvflare/app_common/job_schedulers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import threading
from typing import Dict, List, Optional, Tuple

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.job_def import Job
from nvflare.apis.job_scheduler_spec import DispatchInfo, JobSchedulerSpec
from nvflare.private.fed.server.admin import ClientReply
from nvflare.private.fed.server.server_engine import ServerEngine
from nvflare.private.scheduler_constants import ShareableHeader
from nvflare.apis.server_engine_spec import ServerEngineSpec


def _check_client_resources(resource_reqs: Dict[str, dict], fl_ctx: FLContext) -> Dict[str, Tuple[bool, str]]:
Expand All @@ -36,22 +33,10 @@ def _check_client_resources(resource_reqs: Dict[str, dict], fl_ctx: FLContext) -
where client_check_result is a tuple of {client check OK, resource reserve token if any}
"""
engine = fl_ctx.get_engine()
if not isinstance(engine, ServerEngine):
raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngine, but got {type(engine)}.")

replies: List[ClientReply] = engine.check_client_resources(resource_reqs)

result = {}
for r in replies:
site_name = engine.get_client_name_from_token(r.client_token)
if r.reply:
resp = pickle.loads(r.reply.body)
result[site_name] = (
resp.get_header(ShareableHeader.CHECK_RESOURCE_RESULT, False),
resp.get_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, ""),
)
else:
result[site_name] = (False, "")
if not isinstance(engine, ServerEngineSpec):
raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngineSpec, but got {type(engine)}.")

result = engine.check_client_resources(resource_reqs)

return result

Expand All @@ -68,8 +53,8 @@ def _cancel_resources(
fl_ctx: FL context
"""
engine = fl_ctx.get_engine()
if not isinstance(engine, ServerEngine):
raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngine, but got {type(engine)}.")
if not isinstance(engine, ServerEngineSpec):
raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngineSpec, but got {type(engine)}.")

engine.cancel_client_resources(resource_check_results, resource_reqs)
return False, None
Expand Down
2 changes: 1 addition & 1 deletion nvflare/lighter/project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ builders:
components:
server:
- id: job_scheduler # This id is reserved by system. Do not change it.
path: nvflare.private.fed.server.job_scheduler.DefaultJobScheduler
path: nvflare.app_common.job_schedulers.job_scheduler.DefaultJobScheduler
args:
max_jobs: 4
- id: job_manager # This id is reserved by system. Do not change it.
Expand Down
15 changes: 11 additions & 4 deletions nvflare/poc/client/startup/fed_client.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,23 @@
{
"id": "resource_manager",
"path": "nvflare.app_common.resource_managers.list_resource_manager.ListResourceManager",
"args": {
"resources": {"gpu": [0, 1, 2, 3]}
"args": {
"resources": {
"gpu": [
0,
1,
2,
3
]
}
}
},
{
"id": "resource_consumer",
"path": "nvflare.app_common.resource_consumers.gpu_resource_consumer.GPUResourceConsumer",
"args": {
"args": {
"gpu_resource_key": "gpu"
}
}
]
}
}
4 changes: 2 additions & 2 deletions nvflare/poc/client/startup/fed_client_HA.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"components": [
{
"id": "resource_manager",
"path": "nvflare.apis.impl.list_resource_manager.ListResourceManager",
"path": "nvflare.app_common.resource_managers.list_resource_manager.ListResourceManager",
"args": {
"resources": {
"gpu": [
Expand All @@ -50,7 +50,7 @@
},
{
"id": "resource_consumer",
"path": "nvflare.apis.impl.gpu_resource_consumer.GPUResourceConsumer",
"path": "nvflare.app_common.resource_consumers.gpu_resource_consumer.GPUResourceConsumer",
"args": {
"gpu_resource_key": "gpu"
}
Expand Down
2 changes: 1 addition & 1 deletion nvflare/poc/server/startup/fed_server.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"components": [
{
"id": "job_scheduler",
"path": "nvflare.private.fed.server.job_scheduler.DefaultJobScheduler",
"path": "nvflare.app_common.job_schedulers.job_scheduler.DefaultJobScheduler",
"args": {
"max_jobs": 4
}
Expand Down
6 changes: 3 additions & 3 deletions nvflare/poc/server/startup/fed_server_HA.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"components": [
{
"id": "job_scheduler",
"path": "nvflare.apis.impl.job_scheduler.DefaultJobScheduler",
"path": "nvflare.app_common.job_schedulers.job_scheduler.DefaultJobScheduler",
"args": {
"max_jobs": 4
}
Expand All @@ -69,7 +69,7 @@
},
{
"id": "job_store",
"name": "FilesystemStorage",
"path": "nvflare.app_common.storages.filesystem_storage.FilesystemStorage",
"args": {}
},
{
Expand All @@ -85,4 +85,4 @@
"args": {}
}
]
}
}
29 changes: 13 additions & 16 deletions nvflare/private/fed/server/server_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,15 +696,7 @@ def get_errors(self, run_number):
def _send_admin_requests(self, requests) -> List[ClientReply]:
return self.server.admin_server.send_requests(requests, timeout_secs=self.server.admin_server.timeout)

def check_client_resources(self, resource_reqs) -> List[ClientReply]:
"""To send the check_client_resources requests to the clients
Args:
resource_reqs: resource_reqs for the job
Returns:
A list of ClientReply.
"""
def check_client_resources(self, resource_reqs) -> Dict[str, Tuple[bool, str]]:
requests = {}
for site_name, resource_requirements in resource_reqs.items():
# assume server resource is unlimited
Expand All @@ -717,17 +709,22 @@ def check_client_resources(self, resource_reqs) -> List[ClientReply]:
replies = []
if requests:
replies = self._send_admin_requests(requests)
return replies
result = {}
for r in replies:
site_name = self.get_client_name_from_token(r.client_token)
if r.reply:
resp = pickle.loads(r.reply.body)
result[site_name] = (
resp.get_header(ShareableHeader.CHECK_RESOURCE_RESULT, False),
resp.get_header(ShareableHeader.RESOURCE_RESERVE_TOKEN, ""),
)
else:
result[site_name] = (False, "")
return result

def cancel_client_resources(
self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict]
):
"""To cancel the request resources for the job
Args:
resource_check_results: reserved resources
resource_reqs: resource_reqs for the job
"""
requests = {}
for site_name, result in resource_check_results.items():
check_result, token = result
Expand Down

0 comments on commit 1e087e0

Please sign in to comment.