From 92ebdebea125ab08126735f7f4cae9ea5d2599d4 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Thu, 28 Sep 2023 14:32:39 -0700 Subject: [PATCH 1/7] Added logic to handle python based backends instead of platform handlers --- src/pb_stub.cc | 65 +++++++++++++------------------------------- src/pb_stub.h | 7 +++-- src/python_be.cc | 15 ++++------ src/python_be.h | 6 ++-- src/stub_launcher.cc | 12 ++++---- src/stub_launcher.h | 2 +- 6 files changed, 40 insertions(+), 67 deletions(-) diff --git a/src/pb_stub.cc b/src/pb_stub.cc index 37c9a5b5..eb5b7fb9 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -82,9 +82,10 @@ Stub::Instantiate( const std::string& shm_region_name, const std::string& model_path, const std::string& model_version, const std::string& triton_install_path, bi::managed_external_buffer::handle_t ipc_control_handle, - const std::string& name, const std::string& platform) + const std::string& name, const std::string& py_backend_based_model) { - model_context_.Init(model_path, platform, triton_install_path, model_version); + model_context_.Init( + model_path, py_backend_based_model, triton_install_path, model_version); name_ = name; health_mutex_ = nullptr; initialized_ = false; @@ -1612,57 +1613,27 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) void ModelContext::Init( - const std::string& model_path, const std::string& platform, + const std::string& model_path, const std::string& py_backend_based_model, const std::string& triton_install_path, const std::string& model_version) { - bool python_model_found = false; - std::string platform_model_path; - - if (platform != "NONE") { - platform_model_path = - triton_install_path + "/platform_handlers/" + platform + "/model.py"; - // Check if model file exists in the path. - struct stat buffer; - if (stat(platform_model_path.c_str(), &buffer) == 0) { - // Use the Platform model for serving the model. - python_model_found = true; - type_ = ModelType::PLATFORM; - python_model_path_ = platform_model_path; - // Trimming the model name from the model path, the platform model - // will populate the expected default model file name into model_path_. - model_dir_ = model_path.substr(0, model_path.find_last_of("\\/")); - } else { - LOG_WARN << "Unable to find model(handler) \'" << platform_model_path - << "\' for platform field \'" << platform << "\'"; - } - } - - if (!python_model_found) { + type_ = ModelType::DEFAULT; + if (py_backend_based_model != "NONE") { + python_model_path_ = py_backend_based_model + "/model.py"; + type_ = ModelType::BACKEND; + } else { python_model_path_ = model_path; // Check if model file exists in this path. struct stat buffer; - if (stat(python_model_path_.c_str(), &buffer) == 0) { - python_model_found = true; - type_ = ModelType::DEFAULT; - } - // Initializing here for consistency with platform model case. 
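For context, the two cases that the reworked ModelContext::Init distinguishes map onto on-disk layouts roughly like the following; the model and backend names here are illustrative only and are not taken from this patch:

```
# ModelType::DEFAULT -- a regular Python model: model.py lives in the model repository
model_repository/
`-- my_model
    |-- config.pbtxt
    `-- 1
        `-- model.py

# ModelType::BACKEND -- a python-based backend: a shared model.py ships with the backend
# and is resolved as <runtime model dir>/model.py
backends/
`-- my_python_based_backend
    `-- model.py
```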
- model_dir_ = model_path.substr(0, model_path.find_last_of("\\/")); - } - - if (!python_model_found) { - if (platform != "NONE") { - throw PythonBackendException( - ("Python model file not found in neither \'" + platform_model_path + - "\' nor \'" + model_path + "\'")); - } else { + if (stat(python_model_path_.c_str(), &buffer) != 0) { throw PythonBackendException( ("Python model file not found in \'" + model_path + "\'")); } } + model_dir_ = model_path.substr(0, model_path.find_last_of("\\/")); python_backend_folder_ = triton_install_path; model_version_ = model_version; - platform_ = platform; + py_backend_based_model_ = py_backend_based_model; } void @@ -1693,9 +1664,10 @@ ModelContext::StubSetup(py::module& sys) sys = py::module_::import( (std::string(model_version_) + "." + model_name_trimmed).c_str()); } else { - std::string platform_model_dir( - python_backend_folder_ + "/platform_handlers/" + platform_ + "/"); - sys.attr("path").attr("append")(platform_model_dir); + std::string model_path_parent = + python_model_path_.substr(0, python_model_path_.find_last_of("/")); + std::string backend_model_dir(model_path_parent); + sys.attr("path").attr("append")(backend_model_dir); sys.attr("path").attr("append")(python_backend_folder_); sys = py::module_::import(model_name_trimmed.c_str()); } @@ -1744,14 +1716,15 @@ main(int argc, char** argv) int64_t shm_growth_size = std::stol(argv[4]); std::string triton_install_path = argv[6]; std::string name = argv[8]; - std::string platform = argv[9]; + std::string py_backend_based_model = argv[9]; std::unique_ptr& stub = Stub::GetOrCreateInstance(); try { stub->Instantiate( shm_growth_size, shm_default_size, shm_region_name, model_path, model_version, argv[6] /* triton install path */, - std::stoi(argv[7]) /* IPCControl handle */, name, platform); + std::stoi(argv[7]) /* IPCControl handle */, name, + py_backend_based_model); } catch (const PythonBackendException& pb_exception) { LOG_INFO << "Failed to preinitialize Python stub: " << pb_exception.what(); diff --git a/src/pb_stub.h b/src/pb_stub.h index 6d047d29..3526b090 100644 --- a/src/pb_stub.h +++ b/src/pb_stub.h @@ -179,9 +179,9 @@ class ModelContext { std::string model_dir_; std::string model_version_; std::string python_backend_folder_; - std::string platform_; + std::string py_backend_based_model_; - enum ModelType { DEFAULT, PLATFORM }; + enum ModelType { DEFAULT, BACKEND }; ModelType type_; }; @@ -209,7 +209,8 @@ class Stub { const std::string& shm_region_name, const std::string& model_path, const std::string& model_version, const std::string& triton_install_path, bi::managed_external_buffer::handle_t ipc_control_handle, - const std::string& model_instance_name, const std::string& platform); + const std::string& model_instance_name, + const std::string& py_backend_based_model); /// Get the health of the stub process. 
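Because Stub::Instantiate is fed from positional arguments, the ordering matters: StubLauncher::Launch writes them in a fixed sequence and the stub's main() indexes into argv accordingly (argv[9] is the field this patch repurposes). The struct below is only an illustrative sketch of that ordering, inferred from the launch command and the indices visible in this patch; it is not code from the backend:

```cpp
#include <cstdint>
#include <string>

// Sketch of the stub's positional arguments, in the order StubLauncher::Launch
// emits them and the stub's main() reads them.
struct StubArgv {
  std::string model_path;             // argv[1]
  std::string shm_region_name;        // argv[2]
  int64_t shm_default_byte_size;      // argv[3]
  int64_t shm_growth_byte_size;       // argv[4]
  std::string parent_pid;             // argv[5]
  std::string python_lib;             // argv[6], read as triton_install_path
  std::string ipc_control_handle;     // argv[7]
  std::string stub_name;              // argv[8]
  std::string py_backend_based_model; // argv[9], "NONE" unless a python-based
                                      // backend supplies a runtime model dir
};
```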
bool& Health(); diff --git a/src/python_be.cc b/src/python_be.cc index b196cfab..13b5ea4e 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -1733,7 +1733,12 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) python_execution_env_ = ""; force_cpu_only_input_tensors_ = true; decoupled_ = false; - platform_ = ""; + const char* execution_model_path = nullptr; + THROW_IF_BACKEND_MODEL_ERROR( + TRITONBACKEND_BackendModelLocation(triton_model, &execution_model_path)); + if (execution_model_path != nullptr) { + py_backend_based_model_ = execution_model_path; + } void* bstate; THROW_IF_BACKEND_MODEL_ERROR(TRITONBACKEND_BackendState(backend, &bstate)); @@ -1774,14 +1779,6 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) } } - triton::common::TritonJson::Value platform; - if (model_config_.Find("platform", &platform)) { - auto error = platform.AsString(&platform_); - if (error != nullptr) { - throw BackendModelException(error); - } - } - // Skip the FORCE_CPU_ONLY_INPUT_TENSORS variable if it doesn't exits. std::string force_cpu_only_input_tensor; error = nullptr; diff --git a/src/python_be.h b/src/python_be.h index 825c45de..b16c6979 100644 --- a/src/python_be.h +++ b/src/python_be.h @@ -237,8 +237,8 @@ class ModelState : public BackendModel { // Is decoupled API being used. bool IsDecoupled() { return decoupled_; } - // Returns the value in the platform field - std::string Platform() { return platform_; } + // Returns the value in the `py_backend_based_model_` field + std::string PythonBackendBasedModel() { return py_backend_based_model_; } // Launch auto-complete stub process. TRITONSERVER_Error* LaunchAutoCompleteStubProcess(); @@ -255,7 +255,7 @@ class ModelState : public BackendModel { std::string python_execution_env_; bool force_cpu_only_input_tensors_; bool decoupled_; - std::string platform_; + std::string py_backend_based_model_; std::unique_ptr auto_complete_stub_; }; diff --git a/src/stub_launcher.cc b/src/stub_launcher.cc index de4dd46c..2f271c63 100644 --- a/src/stub_launcher.cc +++ b/src/stub_launcher.cc @@ -62,9 +62,9 @@ StubLauncher::Initialize(ModelState* model_state) model_state->ModelConfig().Write(&model_config_buffer_); is_decoupled_ = model_state->IsDecoupled(); model_repository_path_ = model_state->RepositoryPath(); - platform_ = model_state->Platform(); - if (platform_.empty()) { - platform_ = "NONE"; + py_backend_based_model_ = model_state->PythonBackendBasedModel(); + if (py_backend_based_model_.empty()) { + py_backend_based_model_ = "NONE"; } // Atomically increase and read the stub process count to avoid shared memory @@ -238,7 +238,8 @@ StubLauncher::Launch() << ":$LD_LIBRARY_PATH " << python_backend_stub << " " << model_path_ << " " << shm_region_name_ << " " << shm_default_byte_size_ << " " << shm_growth_byte_size_ << " " << parent_pid_ << " " << python_lib_ - << " " << ipc_control_handle_ << " " << stub_name << " " << platform_; + << " " << ipc_control_handle_ << " " << stub_name << " " + << py_backend_based_model_; ipc_control_->uses_env = true; bash_argument = ss.str(); } else { @@ -246,7 +247,8 @@ StubLauncher::Launch() ss << " exec " << python_backend_stub << " " << model_path_ << " " << shm_region_name_ << " " << shm_default_byte_size_ << " " << shm_growth_byte_size_ << " " << parent_pid_ << " " << python_lib_ - << " " << ipc_control_handle_ << " " << stub_name << " " << platform_; + << " " << ipc_control_handle_ << " " << stub_name << " " + << py_backend_based_model_; bash_argument = ss.str(); } LOG_MESSAGE( diff --git 
a/src/stub_launcher.h b/src/stub_launcher.h index 89f35422..ac6c4901 100644 --- a/src/stub_launcher.h +++ b/src/stub_launcher.h @@ -161,7 +161,7 @@ class StubLauncher { std::string shm_region_name_; std::string model_repository_path_; std::string model_path_; - std::string platform_; + std::string py_backend_based_model_; const std::string stub_process_kind_; std::string model_name_; const std::string model_instance_name_; From 85b811a1f35974ea674406b5e3623ea8e5378bdd Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Tue, 3 Oct 2023 10:39:29 -0700 Subject: [PATCH 2/7] Second iteration --- src/pb_stub.cc | 17 ++++++------ src/pb_stub.h | 4 +-- src/python_be.cc | 63 ++++++++++++++++++++++++++++++++++++++------ src/python_be.h | 7 ++--- src/stub_launcher.cc | 10 +++---- src/stub_launcher.h | 2 +- 6 files changed, 75 insertions(+), 28 deletions(-) diff --git a/src/pb_stub.cc b/src/pb_stub.cc index eb5b7fb9..774d6468 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -82,10 +82,10 @@ Stub::Instantiate( const std::string& shm_region_name, const std::string& model_path, const std::string& model_version, const std::string& triton_install_path, bi::managed_external_buffer::handle_t ipc_control_handle, - const std::string& name, const std::string& py_backend_based_model) + const std::string& name, const std::string& python_runtime_model) { model_context_.Init( - model_path, py_backend_based_model, triton_install_path, model_version); + model_path, python_runtime_model, triton_install_path, model_version); name_ = name; health_mutex_ = nullptr; initialized_ = false; @@ -1613,12 +1613,12 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module) void ModelContext::Init( - const std::string& model_path, const std::string& py_backend_based_model, + const std::string& model_path, const std::string& runtime_modeldir, const std::string& triton_install_path, const std::string& model_version) { type_ = ModelType::DEFAULT; - if (py_backend_based_model != "NONE") { - python_model_path_ = py_backend_based_model + "/model.py"; + if (runtime_modeldir != "DEFAULT") { + python_model_path_ = runtime_modeldir + "/model.py"; type_ = ModelType::BACKEND; } else { python_model_path_ = model_path; @@ -1633,7 +1633,7 @@ ModelContext::Init( model_dir_ = model_path.substr(0, model_path.find_last_of("\\/")); python_backend_folder_ = triton_install_path; model_version_ = model_version; - py_backend_based_model_ = py_backend_based_model; + runtime_modeldir_ = runtime_modeldir; } void @@ -1716,15 +1716,14 @@ main(int argc, char** argv) int64_t shm_growth_size = std::stol(argv[4]); std::string triton_install_path = argv[6]; std::string name = argv[8]; - std::string py_backend_based_model = argv[9]; + std::string runtime_modeldir = argv[9]; std::unique_ptr& stub = Stub::GetOrCreateInstance(); try { stub->Instantiate( shm_growth_size, shm_default_size, shm_region_name, model_path, model_version, argv[6] /* triton install path */, - std::stoi(argv[7]) /* IPCControl handle */, name, - py_backend_based_model); + std::stoi(argv[7]) /* IPCControl handle */, name, runtime_modeldir); } catch (const PythonBackendException& pb_exception) { LOG_INFO << "Failed to preinitialize Python stub: " << pb_exception.what(); diff --git a/src/pb_stub.h b/src/pb_stub.h index 3526b090..ab405cb1 100644 --- a/src/pb_stub.h +++ b/src/pb_stub.h @@ -179,7 +179,7 @@ class ModelContext { std::string model_dir_; std::string model_version_; std::string python_backend_folder_; - std::string py_backend_based_model_; + std::string runtime_modeldir_; enum ModelType 
{ DEFAULT, BACKEND }; ModelType type_; @@ -210,7 +210,7 @@ class Stub { const std::string& model_version, const std::string& triton_install_path, bi::managed_external_buffer::handle_t ipc_control_handle, const std::string& model_instance_name, - const std::string& py_backend_based_model); + const std::string& runtime_modeldir); /// Get the health of the stub process. bool& Health(); diff --git a/src/python_be.cc b/src/python_be.cc index 13b5ea4e..6c5cec7a 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -1733,16 +1733,12 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) python_execution_env_ = ""; force_cpu_only_input_tensors_ = true; decoupled_ = false; - const char* execution_model_path = nullptr; - THROW_IF_BACKEND_MODEL_ERROR( - TRITONBACKEND_BackendModelLocation(triton_model, &execution_model_path)); - if (execution_model_path != nullptr) { - py_backend_based_model_ = execution_model_path; - } void* bstate; THROW_IF_BACKEND_MODEL_ERROR(TRITONBACKEND_BackendState(backend, &bstate)); backend_state_ = reinterpret_cast(bstate); + + runtime_modeldir_ = backend_state_->runtime_modeldir; triton::common::TritonJson::Value params; common::TritonJson::Value model_config; if (model_config_.Find("parameters", ¶ms)) { @@ -1907,8 +1903,12 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) backend_state->shm_message_queue_size = 1000; backend_state->number_of_instance_inits = 0; backend_state->thread_pool_size = 32; + // Initialize shared memory region prefix to include backend's name + // to avoid collision between python backend and python backend based + // backends. backend_state->shared_memory_region_prefix = - "triton_python_backend_shm_region_"; + "triton_" + name + "_backend_shm_region_"; + std::string default_backend_dir_string; if (backend_config.Find("cmdline", &cmdline)) { triton::common::TritonJson::Value shm_growth_size; @@ -2018,6 +2018,12 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, ia.what()); } } + + triton::common::TritonJson::Value default_backend_dir; + if (cmdline.Find("backend-directory", &default_backend_dir)) { + RETURN_IF_ERROR( + default_backend_dir.AsString(&default_backend_dir_string)); + } } LOG_MESSAGE( @@ -2035,7 +2041,48 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) TRITONBACKEND_ArtifactType artifact_type; RETURN_IF_ERROR( TRITONBACKEND_BackendArtifacts(backend, &artifact_type, &location)); - backend_state->python_lib = location; + + // Check if `triton_python_backend_stub` and `triton_python_backend_utils.py` + // are located under `location`. + std::string default_python_backend_dir = + default_backend_dir_string + "/python"; + std::string backend_stub_path = + std::string(location) + "/triton_python_backend_stub"; + std::string backend_utils = + std::string(location) + "/triton_python_backend_utils.py"; + // Both, stub and utils should be in the same location + if (FileExists(backend_stub_path) && FileExists(backend_utils)) { + backend_state->python_lib = location; + // If `location` is default location of a python backend, + // then we are using default python backend. + if (default_python_backend_dir == std::string(location)) { + backend_state->runtime_modeldir = ""; + } else { + // If `location` is not default location of a python backend, + // then we are using a python backend based backend and model.py stored + // in the received location. 
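Taken together, the surrounding branches resolve `python_lib` and `runtime_modeldir` as the sketch below restates; `ResolvePaths` is a made-up helper name, this is only a readability aid rather than the backend's actual code, and error handling for a missing stub is omitted:

```cpp
#include <string>
#include <utility>

// Restates the path resolution in TRITONBACKEND_Initialize after this patch.
// `location` is the backend's artifact directory; `default_backend_dir` comes
// from the "backend-directory" cmdline option. Returns {python_lib,
// runtime_modeldir}; an empty runtime_modeldir means the plain python backend
// (StubLauncher later turns the empty value into "DEFAULT" for the stub).
std::pair<std::string, std::string>
ResolvePaths(
    const std::string& location, const std::string& default_backend_dir,
    bool stub_and_utils_found_in_location)
{
  const std::string default_python_dir = default_backend_dir + "/python";
  if (stub_and_utils_found_in_location) {
    // The backend ships its own stub and utils; a runtime model dir is only
    // set for python-based backends, i.e. when `location` is not the default
    // python backend directory.
    return {
        location,
        (location == default_python_dir) ? std::string() : location};
  }
  // A python-based backend that reuses the stub/utils of the default python
  // backend installation; model.py is expected under `location`.
  return {default_python_dir, location};
}
```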
+ backend_state->runtime_modeldir = location; + } + } else { + // If stub and utils are not found in received `location`, + // then we are using a python backend based backend and stub and utils are + // stored in the default python backend location. + if (!default_backend_dir_string.empty()) { + std::string default_python_backend_dir = + default_backend_dir_string + "/python/triton_python_backend_stub"; + if (!FileExists(default_python_backend_dir)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_NOT_FOUND, + (std::string("triton_python_backend_stub") + + " is not found. Searched paths: " + default_backend_dir_string + + "/python and" + std::string(location)) + .c_str()); + } + } + backend_state->runtime_modeldir = location; + backend_state->python_lib = default_backend_dir_string + "/python"; + } + backend_state->env_manager = std::make_unique(); RETURN_IF_ERROR(TRITONBACKEND_BackendSetState( diff --git a/src/python_be.h b/src/python_be.h index b16c6979..f84d2323 100644 --- a/src/python_be.h +++ b/src/python_be.h @@ -218,6 +218,7 @@ struct BackendState { std::string shared_memory_region_prefix; int64_t thread_pool_size; std::unique_ptr env_manager; + std::string runtime_modeldir; }; class ModelState : public BackendModel { @@ -237,8 +238,8 @@ class ModelState : public BackendModel { // Is decoupled API being used. bool IsDecoupled() { return decoupled_; } - // Returns the value in the `py_backend_based_model_` field - std::string PythonBackendBasedModel() { return py_backend_based_model_; } + // Returns the value in the `runtime_modeldir_` field + std::string RuntimeModelDir() { return runtime_modeldir_; } // Launch auto-complete stub process. TRITONSERVER_Error* LaunchAutoCompleteStubProcess(); @@ -255,7 +256,7 @@ class ModelState : public BackendModel { std::string python_execution_env_; bool force_cpu_only_input_tensors_; bool decoupled_; - std::string py_backend_based_model_; + std::string runtime_modeldir_; std::unique_ptr auto_complete_stub_; }; diff --git a/src/stub_launcher.cc b/src/stub_launcher.cc index 2f271c63..a38409ec 100644 --- a/src/stub_launcher.cc +++ b/src/stub_launcher.cc @@ -62,9 +62,9 @@ StubLauncher::Initialize(ModelState* model_state) model_state->ModelConfig().Write(&model_config_buffer_); is_decoupled_ = model_state->IsDecoupled(); model_repository_path_ = model_state->RepositoryPath(); - py_backend_based_model_ = model_state->PythonBackendBasedModel(); - if (py_backend_based_model_.empty()) { - py_backend_based_model_ = "NONE"; + runtime_modeldir_ = model_state->RuntimeModelDir(); + if (runtime_modeldir_.empty()) { + runtime_modeldir_ = "DEFAULT"; } // Atomically increase and read the stub process count to avoid shared memory @@ -239,7 +239,7 @@ StubLauncher::Launch() << " " << shm_region_name_ << " " << shm_default_byte_size_ << " " << shm_growth_byte_size_ << " " << parent_pid_ << " " << python_lib_ << " " << ipc_control_handle_ << " " << stub_name << " " - << py_backend_based_model_; + << runtime_modeldir_; ipc_control_->uses_env = true; bash_argument = ss.str(); } else { @@ -248,7 +248,7 @@ StubLauncher::Launch() << shm_region_name_ << " " << shm_default_byte_size_ << " " << shm_growth_byte_size_ << " " << parent_pid_ << " " << python_lib_ << " " << ipc_control_handle_ << " " << stub_name << " " - << py_backend_based_model_; + << runtime_modeldir_; bash_argument = ss.str(); } LOG_MESSAGE( diff --git a/src/stub_launcher.h b/src/stub_launcher.h index ac6c4901..3bbd2463 100644 --- a/src/stub_launcher.h +++ b/src/stub_launcher.h @@ -161,7 +161,7 @@ class 
StubLauncher { std::string shm_region_name_; std::string model_repository_path_; std::string model_path_; - std::string py_backend_based_model_; + std::string runtime_modeldir_; const std::string stub_process_kind_; std::string model_name_; const std::string model_instance_name_; From 5093b3384c47b66777205b654c410d522ca83d1f Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Tue, 3 Oct 2023 17:34:44 -0700 Subject: [PATCH 3/7] Changed names per suggestion, added comments --- src/pb_stub.cc | 2 ++ src/pb_stub.h | 6 ++++++ src/python_be.cc | 7 +++---- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/pb_stub.cc b/src/pb_stub.cc index 774d6468..c029a711 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -1618,6 +1618,8 @@ ModelContext::Init( { type_ = ModelType::DEFAULT; if (runtime_modeldir != "DEFAULT") { + // For python based backends, existence of `model.py` in the corresponding + // backend folder happens on the core side, so we can omit this check here. python_model_path_ = runtime_modeldir + "/model.py"; type_ = ModelType::BACKEND; } else { diff --git a/src/pb_stub.h b/src/pb_stub.h index ab405cb1..d9c9014c 100644 --- a/src/pb_stub.h +++ b/src/pb_stub.h @@ -181,6 +181,12 @@ class ModelContext { std::string python_backend_folder_; std::string runtime_modeldir_; + // Triton supports python-based backends, + // i.e. backends that provide common `model.py`, that can be re-used + // between different models. `ModelType` helps to differentiate + // between models running with c++ python backend (ModelType::DEFAULT) + // and models running with python-based backend (ModelType::BACKEND) + // at the time of ModelContext::StubSetup to properly set up paths. enum ModelType { DEFAULT, BACKEND }; ModelType type_; }; diff --git a/src/python_be.cc b/src/python_be.cc index 6c5cec7a..813abcba 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -1904,8 +1904,7 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) backend_state->number_of_instance_inits = 0; backend_state->thread_pool_size = 32; // Initialize shared memory region prefix to include backend's name - // to avoid collision between python backend and python backend based - // backends. + // to avoid collision between python backend and python-based backends. backend_state->shared_memory_region_prefix = "triton_" + name + "_backend_shm_region_"; std::string default_backend_dir_string; @@ -2068,9 +2067,9 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) // then we are using a python backend based backend and stub and utils are // stored in the default python backend location. 
if (!default_backend_dir_string.empty()) { - std::string default_python_backend_dir = + std::string backend_stub_path = default_backend_dir_string + "/python/triton_python_backend_stub"; - if (!FileExists(default_python_backend_dir)) { + if (!FileExists(backend_stub_path)) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_NOT_FOUND, (std::string("triton_python_backend_stub") + From f2af6d95d87a06c69c186fde76e8f65c4c1d51dd Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Wed, 4 Oct 2023 19:52:20 -0700 Subject: [PATCH 4/7] Nuked platform handlers from this repo --- CMakeLists.txt | 10 +- .../platform_handlers/pytorch/model.py | 323 ----------- .../tensorflow_savedmodel/README.md | 87 --- .../tensorflow_savedmodel/model.py | 536 ------------------ 4 files changed, 2 insertions(+), 954 deletions(-) delete mode 100755 src/resources/platform_handlers/pytorch/model.py delete mode 100644 src/resources/platform_handlers/tensorflow_savedmodel/README.md delete mode 100644 src/resources/platform_handlers/tensorflow_savedmodel/model.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 93a7ae60..5122ef6c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,7 +65,8 @@ FetchContent_Declare( ) FetchContent_Declare( repo-core - GIT_REPOSITORY https://github.com/triton-inference-server/core.git + #GIT_REPOSITORY https://github.com/triton-inference-server/core.git + GIT_REPOSITORY oandreeva@172.17.0.1:/home/oandreeva/Code/core GIT_TAG ${TRITON_CORE_REPO_TAG} ) FetchContent_Declare( @@ -307,13 +308,6 @@ install( ${INSTALL_CONFIGDIR} ) -install( - DIRECTORY - src/resources/platform_handlers - DESTINATION - ${CMAKE_INSTALL_PREFIX}/backends/python -) - install( FILES src/resources/triton_python_backend_utils.py diff --git a/src/resources/platform_handlers/pytorch/model.py b/src/resources/platform_handlers/pytorch/model.py deleted file mode 100755 index 365599e0..00000000 --- a/src/resources/platform_handlers/pytorch/model.py +++ /dev/null @@ -1,323 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import importlib -import json -import os - -try: - import torch -except ModuleNotFoundError as error: - raise RuntimeError( - "Missing/Incomplete PyTorch package installation... (Did you install PyTorch?)" - ) from error - -# triton_python_backend_utils is available in every Triton Python model. You -# need to use this module to create inference requests and responses. It also -# contains some utility functions for extracting information from model_config -# and converting Triton input/output types to numpy types. -import triton_python_backend_utils as pb_utils - - -def _get_model_path(config): - filenames = ["model.py", "model.pt"] - if config["default_model_filename"]: - filenames.insert(0, config["default_model_filename"]) - for filename in filenames: - model_path = os.path.join(pb_utils.get_model_dir(), filename) - if os.path.exists(model_path): - return model_path - raise pb_utils.TritonModelException( - "No model found in " + pb_utils.get_model_dir() + "/" + str(filenames) - ) - - -def _get_model_data_path(model_path): - data_path_extensions = [".pt"] - model_path_no_extension = model_path[: -(len(model_path.split(".")[-1]) + 1)] - for extension in data_path_extensions: - data_path = model_path_no_extension + extension - if os.path.exists(data_path): - return data_path - # data file not provided - return "" - - -def _is_py_class_model(model_path): - return model_path[-3:] == ".py" - - -def _import_module_from_path(module_name, file_path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - - -def _get_model_class_from_module(module): - names = dir(module) - for name in names: - attr = getattr(module, name) - try: - if issubclass(attr, torch.nn.Module): - return attr - except TypeError: - # attr may not be a class - pass - raise pb_utils.TritonModelException("Cannot find a subclass of torch.nn.Module") - - -def _parse_io_config(io_config): - io = [] - for conf in io_config: - io.append({"name": conf["name"]}) - return io - - -def _get_device_name(kind, device_id): - if kind == "GPU": - return "cuda:" + device_id - if kind == "CPU": - return "cpu" - # unspecified device - return "" - - -def _get_device(kind, device_id, model): - device_name = _get_device_name(kind, device_id) - if device_name == "": - for param in model.parameters(): - return param.device - raise pb_utils.TritonModelException("Cannot determine model device") - return torch.device(device_name) - - -def _set_torch_parallelism(config): - log_msg = "" - parallelism_settings = ["NUM_THREADS", "NUM_INTEROP_THREADS"] - for setting in parallelism_settings: - val = "1" - if setting in config["parameters"]: - val = config["parameters"][setting]["string_value"] - getattr(torch, "set_" + setting.lower())(int(val)) - log_msg += setting + " = " + val + "; " - return log_msg - - -def _get_torch_compile_params(config): - params = {} - if "TORCH_COMPILE_OPTIONAL_PARAMETERS" in 
config["parameters"]: - val = config["parameters"]["TORCH_COMPILE_OPTIONAL_PARAMETERS"]["string_value"] - params = json.loads(val) - if "model" in params: - raise pb_utils.TritonModelException( - "'model' is not an optional parameter for 'torch.compile'" - ) - return params - - -def _gather_torch_tensors(scatter_tensors): - gather_tensors = [] - sections = [] - for i in range(len(scatter_tensors)): - tensors = scatter_tensors[i] - for j in range(len(tensors)): - tensor = tensors[j] - if j < len(gather_tensors): - # add to existing tensor - gather_tensors[j] = torch.cat((gather_tensors[j], tensor), 0) - else: - # start a new tensor - gather_tensors.append(tensor) - # record section - section_length = tensors[0].size()[0] - sections.append(section_length) - return gather_tensors, sections - - -def _scatter_torch_tensors(gather_tensors, sections): - scatter_tensors = [] - for j in range(len(gather_tensors)): - scatter_tensor = torch.split(gather_tensors[j], sections) - for i in range(len(scatter_tensor)): - tensor = scatter_tensor[i] - if i < len(scatter_tensors): - # add to existing response - scatter_tensors[i].append(tensor) - else: - # start a new response - scatter_tensors.append([tensor]) - return scatter_tensors - - -class TritonPythonModel: - """Your Python model must use the same class name. Every Python model - that is created must have "TritonPythonModel" as the class name. - """ - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - - Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - self._model_name = args["model_name"] - for_model = "for '" + self._model_name + "'" - self._logger = pb_utils.Logger - self._logger.log_info("Initializing model instance " + for_model) - - self._model_config = json.loads(args["model_config"]) - self._kind = args["model_instance_kind"] - self._device_id = args["model_instance_device_id"] - self._support_batching = self._model_config["max_batch_size"] > 0 - self._inputs = _parse_io_config(self._model_config["input"]) - self._outputs = _parse_io_config(self._model_config["output"]) - - setting_msg = _set_torch_parallelism(self._model_config) - self._logger.log_verbose( - "Torch parallelism settings " + for_model + ": " + setting_msg - ) - - self._infer_mode = torch.inference_mode(mode=True) - self._infer_mode.__enter__() - - params = _get_torch_compile_params(self._model_config) - self._logger.log_verbose( - "'torch.compile' optional parameter(s) " + for_model + ": " + str(params) - ) - if self._support_batching: - self._gather = torch.compile(_gather_torch_tensors, **params) - self._scatter = torch.compile(_scatter_torch_tensors, **params) - - model_path = _get_model_path(self._model_config) - if not _is_py_class_model(model_path): - self._logger.log_info("Loading '" + self._model_name + "' as TorchScript") - self._model = torch.jit.load(model_path) - self._device = _get_device(self._kind, self._device_id, self._model) - self._model.to(self._device) - self._model.eval() - return - - 
self._model_module = _import_module_from_path(self._model_name, model_path) - self._model_class = _get_model_class_from_module(self._model_module) - self._raw_model = self._model_class() - self._device = _get_device(self._kind, self._device_id, self._raw_model) - data_path = _get_model_data_path(model_path) - if data_path != "": - self._raw_model.load_state_dict( - torch.load(data_path, map_location=self._device) - ) - else: - self._logger.log_info("Model parameter file not found " + for_model) - self._raw_model.to(self._device) - self._raw_model.eval() - self._model = torch.compile(self._raw_model, **params) - - def execute(self, requests): - """`execute` MUST be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference request is made - for this model. Depending on the batching configuration (e.g. Dynamic - Batching) used, `requests` may contain multiple requests. Every - Python model, must create one pb_utils.InferenceResponse for every - pb_utils.InferenceRequest in `requests`. If there is an error, you can - set the error argument when creating a pb_utils.InferenceResponse - - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - - Returns - ------- - list - A list of pb_utils.InferenceResponse. The length of this list must - be the same as `requests` - """ - - responses = [] - - requests_tensors = [] - for request in requests: - tensors = [] - for io in self._inputs: - tensor = pb_utils.get_input_tensor_by_name( - request, io["name"] - ).to_dlpack() - tensor = torch.from_dlpack(tensor).to(self._device) - tensors.append(tensor) - requests_tensors.append(tensors) - - sections = None - if self._support_batching: - requests_tensors, sections = self._gather(requests_tensors) - requests_tensors = [requests_tensors] - - responses_tensors = [] - for input_tensors in requests_tensors: - output_tensors = self._model(*input_tensors) - if not isinstance(output_tensors, tuple) and not isinstance( - output_tensors, list - ): - output_tensors = [output_tensors] - responses_tensors.append(output_tensors) - - if self._support_batching: - responses_tensors = self._scatter(responses_tensors[0], sections) - - for response_tensors in responses_tensors: - output_tensors = [] - for i in range(len(self._outputs)): - io = self._outputs[i] - tensor = response_tensors[i].detach() - tensor = pb_utils.Tensor.from_dlpack(io["name"], tensor) - output_tensors.append(tensor) - inference_response = pb_utils.InferenceResponse( - output_tensors=output_tensors - ) - responses.append(inference_response) - - return responses - - def finalize(self): - """`finalize` is called only once when the model is being unloaded. - Implementing `finalize` function is OPTIONAL. This function allows - the model to perform any necessary clean ups before exit. - """ - self._logger.log_info("Removing model instance for '" + self._model_name + "'") - self._infer_mode.__exit__(exc_type=None, exc_value=None, traceback=None) diff --git a/src/resources/platform_handlers/tensorflow_savedmodel/README.md b/src/resources/platform_handlers/tensorflow_savedmodel/README.md deleted file mode 100644 index 23199e7b..00000000 --- a/src/resources/platform_handlers/tensorflow_savedmodel/README.md +++ /dev/null @@ -1,87 +0,0 @@ - - -# Serving Tensorflow SavedModels using Python Backend \[Experimental\] - -*NOTE*: This feature is subject to change and removal, and should not -be used in production. 
- -Starting from 23.07, we are adding experimental support for loading -and serving of models in [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) -format via Python backend. The `model.savedmodel` can be provided within -the triton server model repository without `model.py` and backend will -automatically use a pre-built python model (`model.py`)[model.py] to load -and serve provided TF SavedModel. The handler can [auto-complete](../../../../README.md#auto_complete_config) -the missing model configuration. - -The model repository structure can look like: - -``` -model_repository/ -`-- resnet_v1_50_savedmodel - |-- 1 - | `-- model.savedmodel - | |-- saved_model.pb - | `-- variables - |-- config.pbtxt - `-- resnet50_labels.txt -``` - -In order to use this feature, make sure that [TensorFlow pip package](https://pypi.org/project/tensorflow/2.13.0/) -is available in the same Python environment. - -``` -pip install tensorfow==2.13.0 -``` - -Alternatively, you can create a -[Python Execution Environment](#using-custom-python-execution-environments) -with the TensorFlow dependency. - -By default, Triton will use the [TensorFlow backend](https://github.com/triton-inference-server/tensorflow_backend) -to load and serve the saved model. In order to use the Python backend with -TensorFlow SavedModel, [model configuration](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md) -should explicitly provide the following settings: - -``` -backend: "python" -platform: "tensorflow_savedmodel" -``` - -It has been observed that certain DLFW like TensorFlow do not release the entire -memory allocated for loading a model back to the system when the model gets -unloaded. This can be problematic when working with a large number of models and -dynamically loading/unloading them. Using Python backend for TF SavedModel serving -will allow the models to be loaded in a separate process, which ensures that entire -memory allocated within the process would be released to the system upon a model -unload. - -Following are few known limitations of this feature: -- GPU execution is not supported. -- List of requests received in model [`execute`](../../../../README.md#execute) function are -not run in a single batch but one after the other. diff --git a/src/resources/platform_handlers/tensorflow_savedmodel/model.py b/src/resources/platform_handlers/tensorflow_savedmodel/model.py deleted file mode 100644 index 24b95472..00000000 --- a/src/resources/platform_handlers/tensorflow_savedmodel/model.py +++ /dev/null @@ -1,536 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import json -import os - -try: - import tensorflow as tf - from tensorflow.core.framework import types_pb2 - from tensorflow.python.client import session - from tensorflow.python.saved_model import loader, signature_constants - from tensorflow.python.tools import saved_model_utils -except ModuleNotFoundError as error: - raise RuntimeError( - "Missing/Incomplete tensorflow package installation..." - ) from error - -# triton_python_backend_utils is available in every Triton Python model. You -# need to use this module to create inference requests and responses. It also -# contains some utility functions for extracting information from model_config -# and converting Triton input/output types to numpy types. -import triton_python_backend_utils as pb_utils - -TF_STRING_TO_TRITON = { - "DT_BOOL": "TYPE_BOOL", - "DT_UINT8": "TYPE_UINT8", - "DT_UINT16": "TYPE_UINT16", - "DT_UINT32": "TYPE_UINT32", - "DT_UINT64": "TYPE_UINT64", - "DT_INT8": "TYPE_INT8", - "DT_INT16": "TYPE_INT16", - "DT_INT32": "TYPE_INT32", - "DT_INT64": "TYPE_INT64", - "DT_HALF": "TYPE_FP16", - "DT_FLOAT": "TYPE_FP32", - "DT_DOUBLE": "TYPE_FP64", - "DT_STRING": "TYPE_STRING", -} - -_DEFAULT_ARTIFACT_NAME = "model.savedmodel" - - -def _get_savedmodel_path(config): - artifact_name = config["default_model_filename"] - if not artifact_name: - artifact_name = _DEFAULT_ARTIFACT_NAME - - savedmodel_path = os.path.join(pb_utils.get_model_dir(), artifact_name) - if not os.path.exists(savedmodel_path): - raise pb_utils.TritonModelException( - f"No savedmodel dir found in " + savedmodel_path - ) - - return savedmodel_path - - -def _parse_signature_def(config): - if config["parameters"]: - if "TF_SIGNATURE_DEF" in config["parameters"].keys(): - return config["parameters"]["TF_SIGNATURE_DEF"]["string_value"] - return None - - -def _parse_graph_tag(config): - if config["parameters"]: - if "TF_GRAPH_TAG" in config["parameters"].keys(): - return config["parameters"]["TF_GRAPH_TAG"]["string_value"] - return None - - -def _parse_num_intra_threads(config): - if config["parameters"]: - if "TF_NUM_INTRA_THREADS" in config["parameters"].keys(): - return int(config["parameters"]["TF_NUM_INTRA_THREADS"]["string_value"]) - return None - - -def _parse_num_inter_threads(config): - if config["parameters"]: - if "TF_NUM_INTER_THREADS" in config["parameters"].keys(): - return int(config["parameters"]["TF_NUM_INTER_THREADS"]["string_value"]) - return None - - -def _get_truth_value(string_value): - val = string_value.casefold() - if val == "yes" or val == "1" or val == "on" or val == "true": - return True - else: - return False - - -def _parse_use_per_session_thread(config): - if config["parameters"]: - if "USE_PER_SESSION_THREAD" in config["parameters"].keys(): - val = 
config["parameters"]["USE_PER_SESSION_THREAD"]["string_value"] - return _get_truth_value(val) - return False - - -def _get_signature_def(savedmodel_path, config): - tag_sets = saved_model_utils.get_saved_model_tag_sets(savedmodel_path) - graph_tag = _parse_graph_tag(config) - if graph_tag is None: - if "serve" in tag_sets[0]: - graph_tag = "serve" - else: - graph_tag = tag_sets[0][0] - - meta_graph_def = saved_model_utils.get_meta_graph_def(savedmodel_path, graph_tag) - signature_def_map = meta_graph_def.signature_def - signature_def_k = _parse_signature_def(config) - if signature_def_k is None: - serving_default = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if serving_default in signature_def_map.keys(): - signature_def_k = serving_default - else: - signature_def_k = signature_def_map.keys()[0] - - if signature_def_k not in signature_def_map.keys(): - raise pb_utils.TritonModelException( - f" The model does not include the signature_def '" + signature_def_k + "'" - ) - - return graph_tag, signature_def_map[signature_def_k] - - -def _has_batch_dim(tensor_info): - if tensor_info.tensor_shape.unknown_rank: - return True - elif tensor_info.tensor_shape.dim[0].size == -1: - return True - else: - return False - - -def _get_batching_hint_from_signature(signature_def): - for input_info in signature_def.inputs.values(): - if not _has_batch_dim(input_info): - return False - - for output_info in signature_def.outputs.values(): - if not _has_batch_dim(output_info): - return False - - return True - - -def _convert_proto_to_dict_tensor(name, tensor_proto, batching_enabled): - tensor_dict = {} - tensor_dict["name"] = name - dtype_dict = {value: key for (key, value) in types_pb2.DataType.items()} - tensor_dict["data_type"] = TF_STRING_TO_TRITON[dtype_dict[tensor_proto.dtype]] - if tensor_proto.tensor_shape.unknown_rank: - # FIXME: Fix the handling of unknown rank - dims = [-1] - else: - dims = [dim.size for dim in tensor_proto.tensor_shape.dim] - if batching_enabled: - tensor_dict["dims"] = dims[1:] - else: - tensor_dict["dims"] = dims - - return tensor_dict - - -def _validate_datatype(tf_dtype, triton_datatype, tensor_name): - dtype_dict = {value: key for (key, value) in types_pb2.DataType.items()} - if triton_datatype != TF_STRING_TO_TRITON[dtype_dict[tf_dtype]]: - raise pb_utils.TritonModelException( - f" Mismatch between datatype for tensor '" - + tensor_name - + "', expected '" - + TF_STRING_TO_TRITON[dtype_dict[tf_dtype]] - + "', got '" - + triton_datatype - ) - - -def _validate_dims(tf_shape, triton_dims, batching_enabled, tensor_name): - if tf_shape.unknown_rank: - return - - index = 0 - offset = 1 if batching_enabled else 0 - if len(tf_shape.dim) != (offset + len(triton_dims)): - raise pb_utils.TritonModelException( - f" Mismatch in the number of dimension with the model for tensor '" - + tensor_name - + "', expected " - + str(len(tf_shape.dim) - offset) - + ", got " - + str(len(triton_dims)) - ) - - for dim in tf_shape.dim: - if index == 0 and batching_enabled: - if dim.size != -1: - raise pb_utils.TritonModelException( - f" The first dimension of a batching model should be dynamic, " - "however, got shape of first dimension in model for tensor '" - + tensor_name - + "' as " - + str(dim.size) - ) - else: - if dim.size != triton_dims[index - offset]: - raise pb_utils.TritonModelException( - f" Mismatch in " - + str(index - offset) - + "th dimension for tensor '" - + tensor_name - + "', expected " - + str(dim.size) - + ", got " - + str(triton_dims[index - offset]) - ) - index = index 
+ 1 - - -def _validate_model_config(model_config, signature_def): - signature_supports_batching = _get_batching_hint_from_signature(signature_def) - if (not signature_supports_batching) and (model_config["max_batch_size"] != 0): - raise pb_utils.TritonModelException( - f" The model signature does not support batching, yet model config" - " has max_batch_size set to '" + str(model_config["max_batch_size"]) + "'" - ) - - batching_enabled = model_config["max_batch_size"] != 0 - - if model_config["platform"] != "tensorflow_savedmodel": - raise pb_utils.TritonModelException( - f"[INTERNAL]: The platform field for using this model should be set to" - " 'tensorflow_savedmodel' in model config, got '" - + model_config["platform"] - + "'" - ) - if model_config["batch_input"]: - raise pb_utils.TritonModelException( - f"The platform model '" - + model_config["platform"] - + "' does not support model with batch_input" - ) - if model_config["batch_output"]: - raise pb_utils.TritonModelException( - f"The platform model '" - + model_config["platform"] - + "' does not support model with batch_output" - ) - - # Validate input tensors - input_tensor_info = signature_def.inputs - config_input_names = [input["name"] for input in model_config["input"]] - for input_name in input_tensor_info.keys(): - if input_name not in config_input_names: - raise pb_utils.TritonModelException( - f" Missing input tensor configuration for tensor '" + input_name + "'" - ) - for input in model_config["input"]: - config_input_name = input["name"] - if config_input_name not in input_tensor_info.keys(): - supported_names = "" - for valid_name in input_tensor_info.keys(): - supported_names = supported_names + ";" + valid_name - raise pb_utils.TritonModelException( - f" No input tensor with name '" - + config_input_name - + "', only supported input names are " - + supported_names - ) - _validate_datatype( - input_tensor_info[config_input_name].dtype, - input["data_type"], - config_input_name, - ) - _validate_dims( - input_tensor_info[config_input_name].tensor_shape, - input["dims"], - batching_enabled, - config_input_name, - ) - - # Validate output tensors - output_tensor_info = signature_def.outputs - for output in model_config["output"]: - config_output_name = output["name"] - if config_output_name not in output_tensor_info.keys(): - supported_names = "" - for valid_name in output_tensor_info.keys(): - supported_names = supported_names + ";" + valid_name - raise pb_utils.TritonModelException( - f" No output tensor with name '" - + config_output_name - + "', only supported output names are " - + supported_names - ) - - _validate_datatype( - output_tensor_info[config_output_name].dtype, - output["data_type"], - config_output_name, - ) - _validate_dims( - output_tensor_info[config_output_name].tensor_shape, - output["dims"], - batching_enabled, - config_output_name, - ) - - -class TritonPythonModel: - """Your Python model must use the same class name. Every Python model - that is created must have "TritonPythonModel" as the class name. 
- """ - - @staticmethod - def auto_complete_config(auto_complete_model_config): - config = auto_complete_model_config.as_dict() - - if config["platform"] != "tensorflow_savedmodel": - raise pb_utils.TritonModelException( - f"[INTERNAL]: The platform field for using this model should be set to" - " 'tensorflow_savedmodel' in model config, got '" - + config["platform"] - + "'" - ) - if config["batch_input"]: - raise pb_utils.TritonModelException( - f"The platform model '" - + config["platform"] - + "' does not support model with batch_input" - ) - if config["batch_output"]: - raise pb_utils.TritonModelException( - f"The platform model '" - + config["platform"] - + "' does not support model with batch_output" - ) - - savedmodel_path = _get_savedmodel_path(config) - - if savedmodel_path is None: - raise pb_utils.TritonModelException( - f"[INTERNAL]: The path to the framework model should be" " provided" - ) - - batching_enabled = False - if config["max_batch_size"] != 0: - batching_enabled = True - - _, signature_def = _get_signature_def(savedmodel_path, config) - - input_tensor_info = signature_def.inputs - output_tensor_info = signature_def.outputs - - batching_hint = False - if not batching_enabled: - batching_hint = _get_batching_hint_from_signature(signature_def) - - # FIXME: Currently the presence of dynamic batch dimension is - # being treated as sufficient proof for enabling batching. - # Need to visit the tensors that are already provided in config - # to confirm the hint - batching_enabled = batching_hint - - config_input_names = [input["name"] for input in config["input"]] - config_output_names = [output["name"] for output in config["output"]] - - # TODO: Add auto-completion of partial tensor specification. - for input_name in input_tensor_info.keys(): - if input_name not in config_input_names: - auto_complete_model_config.add_input( - _convert_proto_to_dict_tensor( - input_name, input_tensor_info[input_name], batching_enabled - ) - ) - - for output_name in output_tensor_info.keys(): - if output_name not in config_output_names: - auto_complete_model_config.add_output( - _convert_proto_to_dict_tensor( - output_name, output_tensor_info[output_name], batching_enabled - ) - ) - - if batching_enabled: - if config["max_batch_size"] == 0: - auto_complete_model_config.set_max_batch_size(4) - auto_complete_model_config.set_dynamic_batching() - - return auto_complete_model_config - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - - Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - # You must parse model_config. JSON string is not parsed here - self.model_config = model_config = json.loads(args["model_config"]) - - savedmodel_path = _get_savedmodel_path(model_config) - - self.model_name = args["model_name"] - self.logger = pb_utils.Logger - self.logger.log_info("Initializing model for " + self.model_name) - - if args["model_instance_kind"] != "CPU": - self.logger.log_warn( - "GPU instances are not supported by this backend. 
Falling back to KIND_CPU for " - + self.model_name - ) - - tag_set, signature_def = _get_signature_def(savedmodel_path, model_config) - _validate_model_config(model_config, signature_def) - - self.signature_def = signature_def - self.input_tensor_info = self.signature_def.inputs - output_tensor_info = self.signature_def.outputs - - # Get the input output names from model config - self.input_names = [input["name"] for input in model_config["input"]] - self.output_names = [output["name"] for output in model_config["output"]] - - # Get the output tensor names - self.output_tensor_names = [ - output_tensor_info[output_name].name for output_name in self.output_names - ] - - # load the session model - # FIXME Add more configuration options for the model. - sess_config = tf.compat.v1.ConfigProto( - inter_op_parallelism_threads=_parse_num_inter_threads(model_config), - intra_op_parallelism_threads=_parse_num_intra_threads(model_config), - use_per_session_threads=_parse_use_per_session_thread(model_config), - ) - self.tf_session = session.Session(graph=tf.Graph(), config=sess_config) - loader.load(self.tf_session, [tag_set], savedmodel_path) - - # Hoding the input dict for caching input tensor data for - # better inference performance - self.input_feed_dict = {} - - def execute(self, requests): - """`execute` MUST be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference request is made - for this model. Depending on the batching configuration (e.g. Dynamic - Batching) used, `requests` may contain multiple requests. Every - Python model, must create one pb_utils.InferenceResponse for every - pb_utils.InferenceRequest in `requests`. If there is an error, you can - set the error argument when creating a pb_utils.InferenceResponse - - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - - Returns - ------- - list - A list of pb_utils.InferenceResponse. The length of this list must - be the same as `requests` - """ - - responses = [] - - # FIXME: Instead of iterating through each request, run - # the inference as a single batch. - for request in requests: - # Prepare the input feed for the model. - for input_name in self.input_names: - self.input_feed_dict[ - self.input_tensor_info[input_name].name - ] = pb_utils.get_input_tensor_by_name(request, input_name).as_numpy() - - # FIXME: Add GPU Tensor handling. DLpack should be utilized - # for better performance - outputs = self.tf_session.run( - self.output_tensor_names, feed_dict=self.input_feed_dict - ) - - # Create output tensors. You need pb_utils.Tensor - # objects to create pb_utils.InferenceResponse. - output_tensors = [] - for i, output in enumerate(outputs): - output_tensors.append(pb_utils.Tensor(self.output_names[i], output)) - - inference_response = pb_utils.InferenceResponse( - output_tensors=output_tensors - ) - responses.append(inference_response) - - return responses - - def finalize(self): - """`finalize` is called only once when the model is being unloaded. - Implementing `finalize` function is OPTIONAL. This function allows - the model to perform any necessary clean ups before exit. 
- """ - if self.tf_session is not None: - self.tf_session.close - self.logger.log_info("Removed model instance for " + self.model_name) From 61d8b7124509854e883169427a25a9090c7b1e68 Mon Sep 17 00:00:00 2001 From: oandreeva-nv Date: Wed, 4 Oct 2023 19:57:02 -0700 Subject: [PATCH 5/7] Fixed git core repo --- CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5122ef6c..917400a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,8 +65,7 @@ FetchContent_Declare( ) FetchContent_Declare( repo-core - #GIT_REPOSITORY https://github.com/triton-inference-server/core.git - GIT_REPOSITORY oandreeva@172.17.0.1:/home/oandreeva/Code/core + GIT_REPOSITORY https://github.com/triton-inference-server/core.git GIT_TAG ${TRITON_CORE_REPO_TAG} ) FetchContent_Declare( From 20b0804c3ad37d5ef970475d5885c4f83cb69da7 Mon Sep 17 00:00:00 2001 From: David Yastremsky Date: Thu, 5 Oct 2023 13:56:23 -0700 Subject: [PATCH 6/7] Add TODO for Windows support --- src/python_be.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/python_be.cc b/src/python_be.cc index 813abcba..7e9280b1 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -2043,6 +2043,8 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) // Check if `triton_python_backend_stub` and `triton_python_backend_utils.py` // are located under `location`. + // DLIS-5596: Add forward slash to be platform agnostic + // (i.e., for Windows we need to use backward slash). std::string default_python_backend_dir = default_backend_dir_string + "/python"; std::string backend_stub_path = From b6fa94c68c8b49dfa6efe34cb3a83ec459a483d3 Mon Sep 17 00:00:00 2001 From: David Yastremsky Date: Thu, 5 Oct 2023 13:57:13 -0700 Subject: [PATCH 7/7] Grammar --- src/python_be.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python_be.cc b/src/python_be.cc index 7e9280b1..3175e8a7 100644 --- a/src/python_be.cc +++ b/src/python_be.cc @@ -2044,7 +2044,7 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend) // Check if `triton_python_backend_stub` and `triton_python_backend_utils.py` // are located under `location`. // DLIS-5596: Add forward slash to be platform agnostic - // (i.e., for Windows we need to use backward slash). + // (i.e. For Windows, we need to use backward slash). std::string default_python_backend_dir = default_backend_dir_string + "/python"; std::string backend_stub_path =