diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index d46fcdfee5..5ba3a56830 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -53,13 +53,20 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
   TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");

   exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");

   uint64_t inputs = 0;
   uint64_t outputs = 0;

   for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
     std::string bind_name = cuda_engine->getBindingName(x);
-    std::string idx_s = bind_name.substr(bind_name.find("_") + 1);
+    auto delim = bind_name.find(".");
+    if (delim == std::string::npos) {
+      delim = bind_name.find("_");
+      TORCHTRT_CHECK(delim != std::string::npos, "Unable to determine binding index for input " << bind_name << "\nEnsure module was compiled with Torch-TensorRT.ts");
+    }
+
+    std::string idx_s = bind_name.substr(delim + 1);
     uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));

     if (cuda_engine->bindingIsInput(x)) {
@@ -71,6 +78,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
     }
   }
   num_io = std::make_pair(inputs, outputs);
+
+  LOG_DEBUG(*this);
 }

 TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
@@ -82,6 +91,34 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
   return (*this);
 }

+std::string TRTEngine::to_str() const {
+  std::stringstream ss;
+  ss << "Torch-TensorRT TensorRT Engine:" << std::endl;
+  ss << "  Name: " << name << std::endl;
+  ss << "  Inputs: [" << std::endl;
+  for (uint64_t i = 0; i < num_io.first; i++) {
+    ss << "    id: " << i << std::endl;
+    ss << "      shape: " << exec_ctx->getBindingDimensions(i) << std::endl;
+    ss << "      dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(i)) << std::endl;
+  }
+  ss << "  ]" << std::endl;
+  ss << "  Outputs: [" << std::endl;
+  for (uint64_t o = 0; o < num_io.second; o++) {
+    ss << "    id: " << o << std::endl;
+    ss << "      shape: " << exec_ctx->getBindingDimensions(o) << std::endl;
+    ss << "      dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(o)) << std::endl;
+  }
+  ss << "  ]" << std::endl;
+  ss << "  Device: " << device_info << std::endl;
+
+  return ss.str();
+}
+
+std::ostream& operator<<(std::ostream& os, const TRTEngine& engine) {
+  os << engine.to_str();
+  return os;
+}
+
 // TODO: Implement a call method
 // c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
 //   auto input_vec = inputs.vec();
@@ -96,6 +133,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def(torch::init<std::vector<std::string>>())
         // TODO: .def("__call__", &TRTEngine::Run)
         // TODO: .def("run", &TRTEngine::Run)
+        .def("__str__", &TRTEngine::to_str)
         .def_pickle(
             [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
               // Serialize TensorRT engine
diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h
index 79ae74c91b..2d92fa4e00 100644
--- a/core/runtime/runtime.h
+++ b/core/runtime/runtime.h
@@ -59,6 +59,8 @@ struct TRTEngine : torch::CustomClassHolder {
   TRTEngine(std::vector<std::string> serialized_info);
   TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device);
   TRTEngine& operator=(const TRTEngine& other);
+  std::string to_str() const;
+  friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
   // TODO: Implement a call method
   // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
 };
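
For reviewers, a minimal standalone sketch of the binding-name parsing this patch introduces: prefer "." as the delimiter between a binding's base name and its index, and fall back to "_" for names produced by the Torch-TensorRT.ts convention (e.g. "input_0"). The helper name parse_binding_index and the sample binding names are illustrative only, not part of the patch, and the TORCHTRT_CHECK error handling from the constructor is omitted.

#include <cstdint>
#include <iostream>
#include <string>

// Same delimiter logic as the updated constructor: "." first, then "_" as a fallback.
uint64_t parse_binding_index(const std::string& bind_name) {
  auto delim = bind_name.find(".");
  if (delim == std::string::npos) {
    delim = bind_name.find("_");
  }
  return static_cast<uint64_t>(std::stoi(bind_name.substr(delim + 1)));
}

int main() {
  // Hypothetical binding names, not taken from a real engine.
  std::cout << parse_binding_index("input_0") << std::endl;  // prints 0
  std::cout << parse_binding_index("output.1") << std::endl; // prints 1
  return 0;
}

With the new .def("__str__", &TRTEngine::to_str) registration, the expectation (an assumption, not verified here) is that str()/print() on the bound tensorrt.Engine object returns the same summary that LOG_DEBUG(*this) now emits when the engine is constructed.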