diff --git a/src/sagemaker_server.cc b/src/sagemaker_server.cc index 2009b02adb..26a52d7738 100644 --- a/src/sagemaker_server.cc +++ b/src/sagemaker_server.cc @@ -226,9 +226,21 @@ SagemakerAPIServer::Handle(evhtp_request_t* req) evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/ return; } - LOG_VERBOSE(1) << "SageMaker MME Custom Invoke Model Path" - << std::endl; - SageMakerMMEHandleInfer(req, multi_model_name, model_version_str_); + LOG_VERBOSE(1) << "SageMaker MME Custom Invoke Model Path"; + + /* Target model from the request header; logged and used for inference */ + const char* target_model = + evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model"); + + /* If the header is absent (e.g., in local testing), fall back to + * multi_model_name as target_model */ + if (target_model == nullptr) { + target_model = multi_model_name.c_str(); + } + + LOG_INFO << "Invoking SageMaker TargetModel: " << target_model; + + SageMakerMMEHandleInfer(req, target_model, model_version_str_); return; } if (action.empty()) { @@ -330,17 +342,22 @@ SagemakerAPIServer::ParseSageMakerRequest( if (action == "load") { (*parse_map)["url"] = url_string.c_str(); } - (*parse_map)["model_name"] = model_name_string.c_str(); + (*parse_map)["model_name_hash"] = model_name_string.c_str(); - /* Extract targetModel to log the associated archive */ + /* Extract target_model, specified in header, to log the associated archive */ + const char* target_model = + evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model"); - /* Read headers*/ - (*parse_map)["TargetModel"] = "targetModel.tar.gz"; - const char* targetModel = - evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model"); + /* Store the header value as target_model; if the header is absent (e.g., in + * local testing), store model_name_string instead */ + if (target_model != nullptr) { + (*parse_map)["target_model"] = target_model; + } else { + (*parse_map)["target_model"] = model_name_string.c_str(); + } - LOG_INFO << "Loading SageMaker TargetModel: " << targetModel <<
std::endl; + LOG_INFO << "Loading SageMaker TargetModel: " << (*parse_map)["target_model"]; return; } @@ -443,11 +460,6 @@ SagemakerAPIServer::SageMakerMMEHandleInfer( return; } - /* Extract targetModel to log the associated archive */ - const char* targetModel = - evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model"); - LOG_INFO << "Invoking SageMaker TargetModel: " << targetModel << std::endl; - bool connection_paused = false; int64_t requested_model_version; @@ -687,7 +699,7 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable( LOG_VERBOSE(1) << "Discovered model: " << name << ", version: " << version << " in state: " << state - << "for the reason: " << reason; + << " for the reason: " << reason; break; } @@ -700,19 +712,27 @@ SagemakerAPIServer::SageMakerMMECheckUnloadedModelIsUnavailable( void SagemakerAPIServer::SageMakerMMEUnloadModel( - evhtp_request_t* req, const char* model_name) + evhtp_request_t* req, const char* model_name_hash) { - if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) { - LOG_VERBOSE(1) << "Model " << model_name << " is not loaded." << std::endl; + /* Target model from the request header; used for logging and unloading */ + const char* target_model = + evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model"); + + /* If the header is absent (e.g., in local testing), fall back to + * model_name_hash as target_model */ + if (target_model == nullptr) { + target_model = model_name_hash; + } + + if (sagemaker_models_list_.find(model_name_hash) == + sagemaker_models_list_.end()) { + LOG_VERBOSE(1) << "Model " << target_model << " with model hash " + << model_name_hash << " is not loaded."
<< std::endl; evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/ return; } - /* Extract targetModel to log the associated archive */ - const char* targetModel = - evhtp_kv_find(req->headers_in, "X-Amzn-SageMaker-Target-Model"); - - LOG_INFO << "Unloading SageMaker TargetModel: " << targetModel << std::endl; + LOG_INFO << "Unloading SageMaker TargetModel: " << target_model << std::endl; auto start_time = std::chrono::high_resolution_clock::now(); @@ -720,7 +740,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel( * ensemble */ TRITONSERVER_Error* unload_err = nullptr; unload_err = - TRITONSERVER_ServerUnloadModelAndDependents(server_.get(), model_name); + TRITONSERVER_ServerUnloadModelAndDependents(server_.get(), target_model); if (unload_err != nullptr) { EVBufferAddErrorJson(req->buffer_out, unload_err); @@ -728,7 +748,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel( LOG_ERROR << "Error when unloading SageMaker Model with dependents for model: " - << model_name << std::endl; + << target_model << std::endl; TRITONSERVER_ErrorDelete(unload_err); return; @@ -745,13 +765,13 @@ SagemakerAPIServer::SageMakerMMEUnloadModel( succeeded.*/ if (unload_err == nullptr) { LOG_VERBOSE(1) << "Using Model Repository Index during UNLOAD to check for " - "status of model: " - << model_name; + "status of model hash: " + << model_name_hash << " for model: " << target_model; while (is_model_unavailable == false && unload_time_in_secs < UNLOAD_TIMEOUT_SECS_) { LOG_VERBOSE(1) << "In the loop to wait for model to be unavailable"; unload_err = SageMakerMMECheckUnloadedModelIsUnavailable( - model_name, &is_model_unavailable); + target_model, &is_model_unavailable); if (unload_err != nullptr) { LOG_ERROR << "Error: Received non-zero exit code on checking for " "model unavailability. 
" @@ -767,7 +787,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel( end_time - start_time) .count(); } - LOG_INFO << "UNLOAD for model " << model_name << " completed in " + LOG_INFO << "UNLOAD for model " << target_model << " completed in " << unload_time_in_secs << " seconds."; TRITONSERVER_ErrorDelete(unload_err); } @@ -780,7 +800,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel( "result in SageMaker UNLOAD timeout."; } - std::string repo_parent_path = sagemaker_models_list_.at(model_name); + std::string repo_parent_path = sagemaker_models_list_.at(model_name_hash); TRITONSERVER_Error* unregister_err = nullptr; @@ -799,7 +819,7 @@ SagemakerAPIServer::SageMakerMMEUnloadModel( TRITONSERVER_ErrorDelete(unregister_err); std::lock_guard lock(models_list_mutex_); - sagemaker_models_list_.erase(model_name); + sagemaker_models_list_.erase(model_name_hash); } void @@ -946,7 +966,8 @@ SagemakerAPIServer::SageMakerMMELoadModel( const std::unordered_map parse_map) { std::string repo_path = parse_map.at("url"); - std::string model_name = parse_map.at("model_name"); + std::string model_name_hash = parse_map.at("model_name_hash"); + std::string target_model = parse_map.at("target_model"); /* Check subdirs for models and find ensemble model within the repo_path * If only 1 model, that will be selected as model_subdir @@ -1043,7 +1064,8 @@ SagemakerAPIServer::SageMakerMMELoadModel( } auto param = TRITONSERVER_ParameterNew( - model_subdir.c_str(), TRITONSERVER_PARAMETER_STRING, model_name.c_str()); + model_subdir.c_str(), TRITONSERVER_PARAMETER_STRING, + target_model.c_str()); if (param != nullptr) { subdir_modelname_map.emplace_back(param); @@ -1076,7 +1098,7 @@ SagemakerAPIServer::SageMakerMMELoadModel( return; } - err = TRITONSERVER_ServerLoadModel(server_.get(), model_name.c_str()); + err = TRITONSERVER_ServerLoadModel(server_.get(), target_model.c_str()); /* Unlikely after duplicate repo check, but in case Load Model also returns * ALREADY_EXISTS error */ @@ -1091,7 +1113,8 @@ 
SagemakerAPIServer::SageMakerMMELoadModel( } else { std::lock_guard lock(models_list_mutex_); - sagemaker_models_list_.emplace(model_name, repo_parent_path); + /* Use model name hash as expected in SageMaker MME contract */ + sagemaker_models_list_.emplace(model_name_hash, repo_parent_path); evhtp_send_reply(req, EVHTP_RES_OK); } @@ -1101,7 +1124,7 @@ SagemakerAPIServer::SageMakerMMELoadModel( server_.get(), repo_parent_path.c_str()); LOG_VERBOSE(1) << "Unregistered model repository due to load failure for model: " - << model_name << std::endl; + << target_model << std::endl; } if (err != nullptr) {