Merge pull request #2200 from FedML-AI/raphael/pass-api-key
[Deploy] Pass the user's encrypted API key down to the serving container.
Raphael-Jin authored Jun 21, 2024
2 parents 67e93e8 + f412a26 · commit 6ec7379
Showing 6 changed files with 62 additions and 33 deletions.
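Taken together, the six files thread the key through three hops: the master's callback_start_deployment reads encrypted_api_key from the incoming request (falling back to the Redis cache when a later request omits it) and persists it alongside the endpoint's replica settings; the worker job runner forwards the full request_json into start_deployment; and the new handle_env_vars helper injects the value into the container environment as FEDML_USER_ENCRYPTED_API_KEY.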
File 1 of 6: fedml/computing/scheduler/model_scheduler/device_client_constants.py
@@ -165,6 +165,8 @@ class ClientConstants(object):
CUSTOMIZED_VOLUMES_PATH_FROM_WORKSPACE_KEY = "workspace_path"
CUSTOMIZED_VOLUMES_PATH_FROM_CONTAINER_KEY = "mount_path"

+ENV_USER_ENCRYPTED_API_KEY = "FEDML_USER_ENCRYPTED_API_KEY"
+
@staticmethod
def get_fedml_home_dir():
home_dir = expanduser("~")
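The new constant names the environment variable through which the key reaches user code. Below is a minimal sketch of a consumer running inside the deployed container; it assumes only that the deploy path shown later in this diff has injected the variable, and what to do with the still-encrypted value is left to the user's own serving code:

import os

# Injected at container-creation time by handle_env_vars (see the worker
# deployment changes below); unset when the request carried no key.
encrypted_key = os.environ.get("FEDML_USER_ENCRYPTED_API_KEY", "")
if encrypted_key:
    # The value arrives still encrypted; decrypting and using it is the
    # serving code's responsibility.
    print("encrypted user API key is present")
else:
    print("no user API key was passed to this container")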
File 2 of 6: fedml/computing/scheduler/model_scheduler/device_model_cache.py
@@ -112,7 +112,8 @@ def set_user_setting_replica_num(self, end_point_id,
replica_num: int, enable_auto_scaling: bool = False,
scale_min: int = 0, scale_max: int = 0, state: str = "UNKNOWN",
target_queries_per_replica: int = 60, aggregation_window_size_seconds: int = 60,
-                                     scale_down_delay_seconds: int = 120, timeout_s: int = 30
+                                     scale_down_delay_seconds: int = 120, timeout_s: int = 30,
+                                     user_encrypted_api_key: str = ""
) -> bool:
"""
Key: FEDML_MODEL_ENDPOINT_REPLICA_USER_SETTING_TAG--<end_point_id>
@@ -139,7 +140,8 @@ def set_user_setting_replica_num(self, end_point_id,
"target_queries_per_replica": target_queries_per_replica,
"aggregation_window_size_seconds": aggregation_window_size_seconds,
"scale_down_delay_seconds": scale_down_delay_seconds,
-            ServerConstants.INFERENCE_REQUEST_TIMEOUT_KEY: timeout_s
+            ServerConstants.INFERENCE_REQUEST_TIMEOUT_KEY: timeout_s,
+            ServerConstants.USER_ENCRYPTED_API_KEY: user_encrypted_api_key
}
try:
self.redis_connection.set(self.get_user_setting_replica_num_key(end_point_id), json.dumps(replica_num_dict))
@@ -169,6 +171,15 @@ def update_user_setting_replica_num(self, end_point_id: str, state: str = "UNKNO
return False
return True

+    def get_user_encrypted_api_key(self, end_point_id: str) -> str:
+        try:
+            replica_num_dict = self.redis_connection.get(self.get_user_setting_replica_num_key(end_point_id))
+            replica_num_dict = json.loads(replica_num_dict)
+            return replica_num_dict.get(ServerConstants.USER_ENCRYPTED_API_KEY, "")
+        except Exception as e:
+            logging.error(e)
+            return ""
+
def get_all_endpoints_user_setting(self) -> List[dict]:
"""
Return a list of dict, each dict is the user setting of an endpoint.
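A minimal sketch of the new cache round-trip, assuming a reachable local Redis and the same get_instance(addr, port) access pattern the master-side handler uses; the endpoint metadata is invented and every parameter not shown keeps its default:

from fedml.computing.scheduler.model_scheduler.device_model_cache import FedMLModelCache

cache = FedMLModelCache.get_instance("localhost", 6379)
cache.set_user_setting_replica_num(
    end_point_id="1000", end_point_name="demo-endpoint",
    model_name="demo-model", model_version="v1",
    replica_num=1, user_encrypted_api_key="<opaque-ciphertext>")

# Returns "" when no key was ever stored or on any Redis error.
assert cache.get_user_encrypted_api_key("1000") == "<opaque-ciphertext>"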
File 3 of 6: the worker-side container deployment module (defines start_deployment and handle_env_vars)
@@ -19,6 +19,7 @@
from fedml.computing.scheduler.comm_utils.job_utils import JobRunnerUtils
from fedml.computing.scheduler.comm_utils.constants import SchedulerConstants
from fedml.computing.scheduler.model_scheduler.device_client_constants import ClientConstants
+from fedml.computing.scheduler.model_scheduler.device_server_constants import ServerConstants
from fedml.computing.scheduler.model_scheduler.device_model_cache import FedMLModelCache
from ..scheduler_core.compute_utils import ComputeUtils
from ..comm_utils.container_utils import ContainerUtils
@@ -59,7 +60,9 @@ def request_gpu_ids_on_deployment(edge_id, end_point_id, num_gpus=None, master_d
def start_deployment(end_point_id, end_point_name, model_id, model_version,
model_storage_local_path, inference_model_name, inference_engine,
infer_host, master_ip, edge_id, master_device_id=None, replica_rank=0,
-                     gpu_per_replica=1):
+                     gpu_per_replica=1, request_json=None):
+    if request_json is None:
+        request_json = dict()
logging.info("[Worker] Model deployment is starting...")

# Real gpu per replica (container-level)
@@ -219,22 +222,9 @@ def start_deployment(end_point_id, end_point_name, model_id, model_version,
if device_mapping:
host_config_dict.update(device_mapping)

-    # Environment variables
-    enable_custom_image = False if relative_entry_fedml_format != "" else True
-    if not enable_custom_image:
-        # For some image, the default user is root. Unified to fedml.
-        environment["HOME"] = "/home/fedml"
-    environment["BOOTSTRAP_DIR"] = dst_bootstrap_dir
-    environment["FEDML_CURRENT_RUN_ID"] = end_point_id
-    environment["FEDML_CURRENT_EDGE_ID"] = edge_id
-    environment["FEDML_REPLICA_RANK"] = replica_rank
-    environment["FEDML_CURRENT_VERSION"] = fedml.get_env_version()
-    environment["FEDML_ENV_VERSION"] = fedml.get_env_version()
-    environment["FEDML_ENV_LOCAL_ON_PREMISE_PLATFORM_HOST"] = fedml.get_local_on_premise_platform_host()
-    environment["FEDML_ENV_LOCAL_ON_PREMISE_PLATFORM_PORT"] = fedml.get_local_on_premise_platform_port()
-    if extra_envs is not None:
-        for key in extra_envs:
-            environment[key] = extra_envs[key]
+    # Handle the environment variables
+    handle_env_vars(environment, relative_entry_fedml_format, extra_envs, dst_bootstrap_dir,
+                    end_point_id, edge_id, replica_rank, request_json)

# Create the container
try:
@@ -612,6 +602,29 @@ def handle_volume_mount(volumes, binds, environment, relative_entry_fedml_format
logging.warning(f"{workspace_path} does not exist, skip mounting it to the container")


+def handle_env_vars(environment, relative_entry_fedml_format, extra_envs, dst_bootstrap_dir, end_point_id, edge_id,
+                    replica_rank, request_json):
+    enable_custom_image = False if relative_entry_fedml_format != "" else True
+    if not enable_custom_image:
+        # For some image, the default user is root. Unified to fedml.
+        environment["HOME"] = "/home/fedml"
+
+    if request_json and ServerConstants.USER_ENCRYPTED_API_KEY in request_json:
+        environment[ClientConstants.ENV_USER_ENCRYPTED_API_KEY] = request_json[ServerConstants.USER_ENCRYPTED_API_KEY]
+
+    environment["BOOTSTRAP_DIR"] = dst_bootstrap_dir
+    environment["FEDML_CURRENT_RUN_ID"] = end_point_id
+    environment["FEDML_CURRENT_EDGE_ID"] = edge_id
+    environment["FEDML_REPLICA_RANK"] = replica_rank
+    environment["FEDML_CURRENT_VERSION"] = fedml.get_env_version()
+    environment["FEDML_ENV_VERSION"] = fedml.get_env_version()
+    environment["FEDML_ENV_LOCAL_ON_PREMISE_PLATFORM_HOST"] = fedml.get_local_on_premise_platform_host()
+    environment["FEDML_ENV_LOCAL_ON_PREMISE_PLATFORM_PORT"] = fedml.get_local_on_premise_platform_port()
+    if extra_envs is not None:
+        for key in extra_envs:
+            environment[key] = extra_envs[key]
+
+
def check_container_readiness(inference_http_port, infer_host="127.0.0.1", request_input_example=None,
readiness_check=ClientConstants.READINESS_PROBE_DEFAULT):
response_from_client_container = is_client_inference_container_ready(
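To make the refactor concrete, here is a hypothetical call to the new helper; it assumes handle_env_vars (and the fedml import it relies on) is in scope, and every argument value is invented. The request_json field name is ServerConstants.USER_ENCRYPTED_API_KEY, i.e. "encrypted_api_key":

environment = {}
handle_env_vars(
    environment=environment,
    relative_entry_fedml_format="main_entry.py",  # non-empty, so HOME is forced to /home/fedml
    extra_envs={"MY_FLAG": "1"},
    dst_bootstrap_dir="/home/fedml/bootstrap",
    end_point_id=1000, edge_id=17, replica_rank=0,
    request_json={"encrypted_api_key": "<opaque-ciphertext>"},
)
# environment now holds FEDML_USER_ENCRYPTED_API_KEY plus the usual FEDML_*
# run metadata, BOOTSTRAP_DIR, and the caller-supplied MY_FLAG.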
File 4 of 6: fedml/computing/scheduler/model_scheduler/device_server_constants.py
@@ -108,6 +108,8 @@ class ServerConstants(object):

INFERENCE_REQUEST_TIMEOUT_KEY = "request_timeout_sec"
INFERENCE_REQUEST_TIMEOUT_DEFAULT = 30
+
+USER_ENCRYPTED_API_KEY = "encrypted_api_key"
# -----End-----

MODEL_DEPLOYMENT_STAGE1 = {"index": 1, "text": "ReceivedRequest"}
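Note that two distinct names are in play: "encrypted_api_key" (ServerConstants.USER_ENCRYPTED_API_KEY) is the field name on the request payload and in the Redis user-setting record, while "FEDML_USER_ENCRYPTED_API_KEY" (ClientConstants.ENV_USER_ENCRYPTED_API_KEY) is the environment variable the container ultimately sees.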
File 5 of 6: the master-side protocol handler (defines callback_start_deployment)
@@ -158,25 +158,20 @@ def callback_start_deployment(self, topic, payload):
run_id = request_json["end_point_id"]
end_point_name = request_json["end_point_name"]
token = request_json["token"]
user_id = request_json["user_id"]
user_name = request_json["user_name"]
device_ids = request_json["device_ids"]
device_objs = request_json["device_objs"]
+        enable_auto_scaling = request_json.get("enable_auto_scaling", False)
+        desired_replica_num = request_json.get("desired_replica_num", 1)
+        target_queries_per_replica = request_json.get("target_queries_per_replica", 10)
+        aggregation_window_size_seconds = request_json.get("aggregation_window_size_seconds", 60)
+        scale_down_delay_seconds = request_json.get("scale_down_delay_seconds", 120)
+        user_encrypted_api_key = request_json.get(ServerConstants.USER_ENCRYPTED_API_KEY, "")

model_config = request_json["model_config"]
model_name = model_config["model_name"]
model_version = model_config["model_version"]
model_id = model_config["model_id"]
model_storage_url = model_config["model_storage_url"]
scale_min = model_config.get("instance_scale_min", 0)
scale_max = model_config.get("instance_scale_max", 0)
inference_engine = model_config.get("inference_engine", 0)
-        enable_auto_scaling = request_json.get("enable_auto_scaling", False)
-        desired_replica_num = request_json.get("desired_replica_num", 1)
-
-        target_queries_per_replica = request_json.get("target_queries_per_replica", 10)
-        aggregation_window_size_seconds = request_json.get("aggregation_window_size_seconds", 60)
-        scale_down_delay_seconds = request_json.get("scale_down_delay_seconds", 120)

model_config_parameters = request_json.get("parameters", {})
timeout_s = model_config_parameters.get("request_timeout_sec", 30)
@@ -193,6 +188,12 @@ def callback_start_deployment(self, topic, payload):
request_json["end_point_id"])
request_json["is_fresh_endpoint"] = True if endpoint_device_info is None else False

if user_encrypted_api_key == "":
user_encrypted_api_key = (FedMLModelCache.get_instance(self.redis_addr, self.redis_port).
get_user_encrypted_api_key(run_id))
if user_encrypted_api_key != "": # Pass the cached key to the workers
request_json[ServerConstants.USER_ENCRYPTED_API_KEY] = user_encrypted_api_key

# Save the user setting (about replica number) of this run to Redis, if existed, update it
FedMLModelCache.get_instance(self.redis_addr, self.redis_port).set_user_setting_replica_num(
end_point_id=run_id, end_point_name=end_point_name, model_name=model_name, model_version=model_version,
@@ -201,7 +202,7 @@ def callback_start_deployment(self, topic, payload):
aggregation_window_size_seconds=aggregation_window_size_seconds,
target_queries_per_replica=target_queries_per_replica,
scale_down_delay_seconds=int(scale_down_delay_seconds),
-            timeout_s=timeout_s
+            timeout_s=timeout_s, user_encrypted_api_key=user_encrypted_api_key
)

# Start log processor for current run
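The Redis fallback covers requests that arrive without the key, presumably update or scale operations against an existing endpoint, in which case the master reuses whatever the original deployment stored. A hypothetical minimal payload that would exercise the new path; the field names follow the handler above, and all values are invented:

request_json = {
    "end_point_id": 1000,
    "end_point_name": "demo-endpoint",
    "token": "<endpoint-token>",
    "device_ids": [17, 18],
    "device_objs": [],
    "model_config": {
        "model_name": "demo-model",
        "model_version": "v1",
        "model_id": 42,
        "model_storage_url": "https://example.com/model.zip",
    },
    "encrypted_api_key": "<opaque-ciphertext>",  # ServerConstants.USER_ENCRYPTED_API_KEY
}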
File 6 of 6: the worker job runner (defines run_impl)
@@ -250,7 +250,7 @@ def run_impl(self, run_extend_queue_list, sender_message_center,
inference_model_name=model_name, inference_engine=inference_engine,
infer_host=worker_ip, master_ip=master_ip, edge_id=self.edge_id,
master_device_id=device_ids[0], replica_rank=rank,
-                gpu_per_replica=int(self.replica_handler.gpu_per_replica)
+                gpu_per_replica=int(self.replica_handler.gpu_per_replica), request_json=self.request_json
)
except Exception as e:
inference_output_url = ""
@@ -373,7 +373,7 @@ def run_impl(self, run_extend_queue_list, sender_message_center,
inference_model_name=model_name, inference_engine=inference_engine,
infer_host=worker_ip, master_ip=master_ip, edge_id=self.edge_id,
master_device_id=device_ids[0], replica_rank=rank,
-                gpu_per_replica=int(self.replica_handler.gpu_per_replica)
+                gpu_per_replica=int(self.replica_handler.gpu_per_replica), request_json=self.request_json
)
except Exception as e:
inference_output_url = ""
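Both run_impl call sites now forward self.request_json wholesale, so any per-endpoint field the master attaches (here, the encrypted key) reaches start_deployment without further per-field plumbing on the worker.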
