Skip to content

Commit

Permalink
Removes std::regex in favor of re2 (microsoft#8)
Browse files Browse the repository at this point in the history
* Removes std::regex in favor of re2

* Adds back find_package in unit tests and fixes regexes

* Adds more negative test cases
  • Loading branch information
tmccrmck authored Mar 8, 2019
1 parent e09e156 commit 8aedb7c
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 43 deletions.
13 changes: 2 additions & 11 deletions cmake/onnxruntime_hosting.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

project(onnxruntime_hosting)

find_package(Boost 1.69 COMPONENTS system coroutine context thread program_options REQUIRED)
find_package(Boost 1.69 COMPONENTS system context thread program_options REQUIRED)

set(re2_src ${REPO_ROOT}/cmake/external/re2)

Expand All @@ -28,12 +28,6 @@ target_include_directories(${PROJECT_NAME} PRIVATE

target_link_libraries(${PROJECT_NAME} PRIVATE
${Boost_LIBRARIES}
${PROVIDERS_MKLDNN}
${MKLML_SHARED_LIB}
${PROVIDERS_CUDA}
${onnxruntime_tvm_libs}
${onnxruntime_libs}
${onnxruntime_EXTERNAL_LIBRARIES}
onnxruntime_session
onnxruntime_optimizer
onnxruntime_providers
Expand All @@ -43,9 +37,6 @@ target_link_libraries(${PROJECT_NAME} PRIVATE
onnxruntime_graph
onnxruntime_common
onnxruntime_mlas
onnx
onnx_proto
protobuf::libprotobuf
re2
${onnxruntime_EXTERNAL_LIBRARIES}
)

6 changes: 2 additions & 4 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,6 @@ elseif(HAS_FILESYSTEM_H OR HAS_EXPERIMENTAL_FILESYSTEM_H)
list(APPEND onnxruntime_test_framework_libs stdc++fs)
endif()



set (onnxruntime_test_providers_dependencies ${onnxruntime_EXTERNAL_DEPENDENCIES})

if(onnxruntime_USE_CUDA)
Expand Down Expand Up @@ -525,10 +523,10 @@ if (onnxruntime_BUILD_SHARED_LIB)
endif()

if (onnxruntime_BUILD_HOSTING)
find_package(Boost 1.69 COMPONENTS system coroutine context thread program_options REQUIRED)
find_package(Boost 1.69 COMPONENTS system context thread program_options REQUIRED)
add_library(onnxruntime_test_utils_for_hosting ${onnxruntime_test_hosting_src})
onnxruntime_add_include_to_target(onnxruntime_test_utils_for_hosting onnxruntime_test_utils gtest gsl onnx onnx_proto )
add_dependencies(onnxruntime_test_utils_for_hosting ${onnxruntime_EXTERNAL_DEPENDENCIES})
add_dependencies(onnxruntime_test_utils_for_hosting onnxruntime_hosting ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_test_utils_for_hosting PUBLIC ${Boost_INCLUDE_DIR} ${REPO_ROOT}/cmake/external/re2 PRIVATE ${ONNXRUNTIME_ROOT} )
target_link_libraries(onnxruntime_test_utils_for_hosting ${Boost_LIBRARIES} ${onnx_test_libs})

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/hosting/include/beast_http.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class App {
return *this;
}

App& post(const std::regex& route, handler_fn fn) {
App& post(const std::string& route, handler_fn fn) {
// routes->http_posts[route] = std::move(fn);
routes->register_controller(http::verb::post, route, fn);
return *this;
Expand Down
21 changes: 8 additions & 13 deletions onnxruntime/hosting/include/routes.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
#ifndef BEAST_SERVER_ROUTES_H
#define BEAST_SERVER_ROUTES_H

#include <regex>
#include <vector>
#include <boost/beast/http.hpp>
#include "re2/re2.h"
#include "http_context.h"

namespace http = boost::beast::http; // from <boost/beast/http.hpp>
Expand All @@ -20,14 +20,14 @@ using handler_fn = std::function<void(std::string, std::string, std::string, Htt
class Routes {
public:
Routes() = default;
bool register_controller(http::verb method, const std::regex& url_pattern, const handler_fn& controller) {
bool register_controller(http::verb method, const std::string& url_pattern, const handler_fn& controller) {
switch(method)
{
case http::verb::get:
this->get_fn_table.push_back(make_pair(url_pattern, controller));
this->get_fn_table.emplace_back(url_pattern, controller);
return true;
case http::verb::post:
this->post_fn_table.push_back(make_pair(url_pattern, controller));
this->post_fn_table.emplace_back(url_pattern, controller);
return true;
default:
return false;
Expand All @@ -40,7 +40,7 @@ class Routes {
/* out */ std::string& model_version,
/* out */ std::string& action,
/* out */ handler_fn& func) {
std::vector<std::pair<std::regex, handler_fn>> func_table;
std::vector<std::pair<std::string, handler_fn>> func_table;
switch(method)
{
case http::verb::get:
Expand All @@ -59,14 +59,9 @@ class Routes {
return http::status::method_not_allowed;
}

std::smatch m{};
bool found_match = false;
for (const auto& pattern : func_table) {
// TODO: use re2 for matching
if (std::regex_match(url, m, pattern.first)) {
model_name = m[1];
model_version = m[2];
action = m[3];
if (re2::RE2::FullMatch(url, pattern.first, &model_name, &model_version, &action)) {
func = pattern.second;

found_match = true;
Expand All @@ -83,8 +78,8 @@ class Routes {
}

private:
std::vector<std::pair<std::regex, handler_fn>> post_fn_table;
std::vector<std::pair<std::regex, handler_fn>> get_fn_table;
std::vector<std::pair<std::string, handler_fn>> post_fn_table;
std::vector<std::pair<std::string, handler_fn>> get_fn_table;
};

} // namespace onnxruntime
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/hosting/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ int main(int argc, char* argv[]) {
auto const boost_address = boost::asio::ip::make_address(vm["address"].as<std::string>());

onnxruntime::App app{};
app.post(std::regex(R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))"), test_request)
app.post(R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))", test_request)
.bind(boost_address, vm["port"].as<int>())
.num_threads(vm["threads"].as<int>())
.run();
Expand Down
35 changes: 22 additions & 13 deletions onnxruntime/test/hosting/http_routes_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,67 @@ void do_something(const std::string& name, const std::string& version,
auto noop = name + version + action + context.request.body();
}

void run_route(const std::regex& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data);
void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data);

TEST(PositiveTests, RegisterTest) {
auto predict_regex = std::regex(
R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))");
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";
Routes routes;
EXPECT_TRUE(routes.register_controller(http::verb::post, predict_regex, do_something));

auto status_regex = std::regex(
R"(/v1/models(?:/([^/:]+))?(?:/versions/(\d+))?(?:\/(metadata))?)");
auto status_regex = R"(/v1/models(?:/([^/:]+))?(?:/versions/(\d+))?(?:\/(metadata))?)";
EXPECT_TRUE(routes.register_controller(http::verb::get, status_regex, do_something));
}

TEST(PositiveTests, PostRouteTest) {
auto predict_regex = std::regex(
R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))");
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

std::vector<test_data> actions{
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/abc:predict", "abc", "", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/models/versions/45:predict", "models", "45", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/??$$%%@@$^^/versions/45:predict", "??$$%%@@$^^", "45", "predict", http::status::ok),
std::make_tuple(http::verb::post, "/v1/models/versions/versions/45:predict", "versions", "45", "predict", http::status::ok)};

run_route(predict_regex, http::verb::post, actions, true);
}

TEST(NegativeTests, PostRouteInvalidURLTest) {
auto predict_regex = std::regex(
R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))");
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

std::vector<test_data> actions{
std::make_tuple(http::verb::post, "", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models:bar", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/a:bc/versions:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions/2.0:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/models/abc/versions/2:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/versions/2:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/foo/versions/:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/foo/versions:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "v1/models/foo/versions/12:predict", "", "", "", http::status::not_found),
std::make_tuple(http::verb::post, "/v1/models/abc/versions/23:foo", "", "", "", http::status::not_found)};

run_route(predict_regex, http::verb::post, actions, false);
}

// These tests are because we currently only support POST and GET
// Some HTTP methods should be removed from test data if we support more (e.g. PUT)
TEST(NegativeTests, PostRouteInvalidMethodTest) {
auto predict_regex = std::regex(
R"(/v1/models(?:/([^/:]+))(?:/versions/(\d+))?:(classify|regress|predict))");
auto predict_regex = R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))";

std::vector<test_data> actions{
std::make_tuple(http::verb::get, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::method_not_allowed)};
std::make_tuple(http::verb::get, "/v1/models/abc/versions/23:predict", "abc", "23", "predict", http::status::method_not_allowed),
std::make_tuple(http::verb::put, "/v1/models", "", "", "", http::status::method_not_allowed),
std::make_tuple(http::verb::delete_, "/v1/models", "", "", "", http::status::method_not_allowed),
std::make_tuple(http::verb::head, "/v1/models", "", "", "", http::status::method_not_allowed)};

run_route(predict_regex, http::verb::post, actions, false);
}

void run_route(const std::regex& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data) {
void run_route(const std::string& pattern, http::verb method, const std::vector<test_data>& data, bool does_validate_data) {
Routes routes;
EXPECT_TRUE(routes.register_controller(method, pattern, do_something));

Expand Down

0 comments on commit 8aedb7c

Please sign in to comment.