diff --git a/cmake/onnxruntime_hosting.cmake b/cmake/onnxruntime_hosting.cmake index 85fb4f4265284..9844ff4cb5330 100644 --- a/cmake/onnxruntime_hosting.cmake +++ b/cmake/onnxruntime_hosting.cmake @@ -3,7 +3,7 @@ project(onnxruntime_hosting) -find_package(Boost 1.69 COMPONENTS system coroutine context thread program_options REQUIRED) +find_package(Boost 1.69 COMPONENTS system context thread program_options REQUIRED) set(re2_src ${REPO_ROOT}/cmake/external/re2) @@ -28,12 +28,6 @@ target_include_directories(${PROJECT_NAME} PRIVATE target_link_libraries(${PROJECT_NAME} PRIVATE ${Boost_LIBRARIES} - ${PROVIDERS_MKLDNN} - ${MKLML_SHARED_LIB} - ${PROVIDERS_CUDA} - ${onnxruntime_tvm_libs} - ${onnxruntime_libs} - ${onnxruntime_EXTERNAL_LIBRARIES} onnxruntime_session onnxruntime_optimizer onnxruntime_providers @@ -43,9 +37,6 @@ target_link_libraries(${PROJECT_NAME} PRIVATE onnxruntime_graph onnxruntime_common onnxruntime_mlas - onnx - onnx_proto - protobuf::libprotobuf - re2 + ${onnxruntime_EXTERNAL_LIBRARIES} ) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index c3e92dfeb0d4c..3a168c1ab5688 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -172,8 +172,6 @@ elseif(HAS_FILESYSTEM_H OR HAS_EXPERIMENTAL_FILESYSTEM_H) list(APPEND onnxruntime_test_framework_libs stdc++fs) endif() - - set (onnxruntime_test_providers_dependencies ${onnxruntime_EXTERNAL_DEPENDENCIES}) if(onnxruntime_USE_CUDA) @@ -525,10 +523,10 @@ if (onnxruntime_BUILD_SHARED_LIB) endif() if (onnxruntime_BUILD_HOSTING) - find_package(Boost 1.69 COMPONENTS system coroutine context thread program_options REQUIRED) + find_package(Boost 1.69 COMPONENTS system context thread program_options REQUIRED) add_library(onnxruntime_test_utils_for_hosting ${onnxruntime_test_hosting_src}) onnxruntime_add_include_to_target(onnxruntime_test_utils_for_hosting onnxruntime_test_utils gtest gsl onnx onnx_proto ) - add_dependencies(onnxruntime_test_utils_for_hosting ${onnxruntime_EXTERNAL_DEPENDENCIES}) + add_dependencies(onnxruntime_test_utils_for_hosting onnxruntime_hosting ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_include_directories(onnxruntime_test_utils_for_hosting PUBLIC ${Boost_INCLUDE_DIR} ${REPO_ROOT}/cmake/external/re2 PRIVATE ${ONNXRUNTIME_ROOT} ) target_link_libraries(onnxruntime_test_utils_for_hosting ${Boost_LIBRARIES} ${onnx_test_libs}) diff --git a/onnxruntime/hosting/include/beast_http.h b/onnxruntime/hosting/include/beast_http.h index 95a15ea201a77..51c61118154cd 100644 --- a/onnxruntime/hosting/include/beast_http.h +++ b/onnxruntime/hosting/include/beast_http.h @@ -43,7 +43,7 @@ class App { return *this; } - App& post(const std::regex& route, handler_fn fn) { + App& post(const std::string& route, handler_fn fn) { // routes->http_posts[route] = std::move(fn); routes->register_controller(http::verb::post, route, fn); return *this; diff --git a/onnxruntime/hosting/include/routes.h b/onnxruntime/hosting/include/routes.h index 2395ab2e96032..c2e896af3d9eb 100644 --- a/onnxruntime/hosting/include/routes.h +++ b/onnxruntime/hosting/include/routes.h @@ -4,9 +4,9 @@ #ifndef BEAST_SERVER_ROUTES_H #define BEAST_SERVER_ROUTES_H -#include #include #include +#include "re2/re2.h" #include "http_context.h" namespace http = boost::beast::http; // from @@ -20,14 +20,14 @@ using handler_fn = std::functionget_fn_table.push_back(make_pair(url_pattern, controller)); + this->get_fn_table.emplace_back(url_pattern, controller); return true; case http::verb::post: - this->post_fn_table.push_back(make_pair(url_pattern, controller)); + this->post_fn_table.emplace_back(url_pattern, controller); return true; default: return false; @@ -40,7 +40,7 @@ class Routes { /* out */ std::string& model_version, /* out */ std::string& action, /* out */ handler_fn& func) { - std::vector> func_table; + std::vector> func_table; switch(method) { case http::verb::get: @@ -59,14 +59,9 @@ class Routes { return http::status::method_not_allowed; } - std::smatch m{}; bool found_match = false; for (const auto& pattern : func_table) { - // TODO: use re2 for matching - if (std::regex_match(url, m, pattern.first)) { - model_name = m[1]; - model_version = m[2]; - action = m[3]; + if (re2::RE2::FullMatch(url, pattern.first, &model_name, &model_version, &action)) { func = pattern.second; found_match = true; @@ -83,8 +78,8 @@ class Routes { } private: - std::vector> post_fn_table; - std::vector> get_fn_table; + std::vector> post_fn_table; + std::vector> get_fn_table; }; } // namespace onnxruntime diff --git a/onnxruntime/hosting/main.cc b/onnxruntime/hosting/main.cc index d37dd17904f0a..d60757781deb7 100644 --- a/onnxruntime/hosting/main.cc +++ b/onnxruntime/hosting/main.cc @@ -70,7 +70,7 @@ int main(int argc, char* argv[]) { auto const boost_address = boost::asio::ip::make_address(vm["address"].as()); onnxruntime::App app{}; - app.post(std::regex(R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))"), test_request) + app.post(R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))", test_request) .bind(boost_address, vm["port"].as()) .num_threads(vm["threads"].as()) .run(); diff --git a/onnxruntime/test/hosting/http_routes_tests.cc b/onnxruntime/test/hosting/http_routes_tests.cc index 070fef6c036f2..facd7b93f586c 100644 --- a/onnxruntime/test/hosting/http_routes_tests.cc +++ b/onnxruntime/test/hosting/http_routes_tests.cc @@ -17,58 +17,67 @@ void do_something(const std::string& name, const std::string& version, auto noop = name + version + action + context.request.body(); } -void run_route(const std::regex& pattern, http::verb method, const std::vector& data, bool does_validate_data); +void run_route(const std::string& pattern, http::verb method, const std::vector& data, bool does_validate_data); TEST(PositiveTests, RegisterTest) { - auto predict_regex = std::regex( - R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))"); + auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))"; Routes routes; EXPECT_TRUE(routes.register_controller(http::verb::post, predict_regex, do_something)); - auto status_regex = std::regex( - R"(/v1/models(?:/([^/:]+))?(?:/versions/(\d+))?(?:\/(metadata))?)"); + auto status_regex = R"(/v1/models(?:/([^/:]+))?(?:/versions/(\d+))?(?:\/(metadata))?)"; EXPECT_TRUE(routes.register_controller(http::verb::get, status_regex, do_something)); } TEST(PositiveTests, PostRouteTest) { - auto predict_regex = std::regex( - R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))"); + auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))"; std::vector actions{ std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::ok), std::make_tuple(http::verb::post, "/v1/models/abc:predict", "abc", "", "predict", http::status::ok), std::make_tuple(http::verb::post, "/v1/models/models/versions/45:predict", "models", "45", "predict", http::status::ok), + std::make_tuple(http::verb::post, "/v1/models/??$$%%@@$^^/versions/45:predict", "??$$%%@@$^^", "45", "predict", http::status::ok), std::make_tuple(http::verb::post, "/v1/models/versions/versions/45:predict", "versions", "45", "predict", http::status::ok)}; run_route(predict_regex, http::verb::post, actions, true); } TEST(NegativeTests, PostRouteInvalidURLTest) { - auto predict_regex = std::regex( - R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))"); + auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))"; std::vector actions{ + std::make_tuple(http::verb::post, "", "", "", "", http::status::not_found), std::make_tuple(http::verb::post, "/v1/models", "", "", "", http::status::not_found), std::make_tuple(http::verb::post, "/v1/models:predict", "", "", "", http::status::not_found), std::make_tuple(http::verb::post, "/v1/models:bar", "", "", "", http::status::not_found), std::make_tuple(http::verb::post, "/v1/models/abc/versions", "", "", "", http::status::not_found), std::make_tuple(http::verb::post, "/v1/models/abc/versions:predict", "", "", "", http::status::not_found), + std::make_tuple(http::verb::post, "/v1/models/a:bc/versions:predict", "", "", "", http::status::not_found), + std::make_tuple(http::verb::post, "/v1/models/abc/versions/2.0:predict", "", "", "", http::status::not_found), + std::make_tuple(http::verb::post, "/models/abc/versions/2:predict", "", "", "", http::status::not_found), + std::make_tuple(http::verb::post, "/v1/models/versions/2:predict", "", "", "", http::status::not_found), + std::make_tuple(http::verb::post, "/v1/models/foo/versions/:predict", "", "", "", http::status::not_found), + std::make_tuple(http::verb::post, "/v1/models/foo/versions:predict", "", "", "", http::status::not_found), + std::make_tuple(http::verb::post, "v1/models/foo/versions/12:predict", "", "", "", http::status::not_found), std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:foo", "", "", "", http::status::not_found)}; run_route(predict_regex, http::verb::post, actions, false); } +// These tests are because we currently only support POST and GET +// Some HTTP methods should be removed from test data if we support more (e.g. PUT) TEST(NegativeTests, PostRouteInvalidMethodTest) { - auto predict_regex = std::regex( - R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))"); + auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))"; std::vector actions{ - std::make_tuple(http::verb::get, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::method_not_allowed)}; + std::make_tuple(http::verb::get, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::method_not_allowed), + std::make_tuple(http::verb::put, "/v1/models", "", "", "", http::status::method_not_allowed), + std::make_tuple(http::verb::delete_, "/v1/models", "", "", "", http::status::method_not_allowed), + std::make_tuple(http::verb::head, "/v1/models", "", "", "", http::status::method_not_allowed)}; run_route(predict_regex, http::verb::post, actions, false); } -void run_route(const std::regex& pattern, http::verb method, const std::vector& data, bool does_validate_data) { +void run_route(const std::string& pattern, http::verb method, const std::vector& data, bool does_validate_data) { Routes routes; EXPECT_TRUE(routes.register_controller(method, pattern, do_something));