Skip to content

Commit

Permalink
Sets up PredictRequest callback (microsoft#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmccrmck authored Mar 20, 2019
1 parent e0e95de commit 2db832c
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 83 deletions.
19 changes: 13 additions & 6 deletions cmake/onnxruntime_hosting.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ file(GLOB_RECURSE onnxruntime_hosting_lib_srcs
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/http/json_handling.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
set_source_files_properties(${ONNXRUNTIME_ROOT}/hosting/http/predict_request_handler.cc PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()

Expand All @@ -41,14 +42,13 @@ file(GLOB_RECURSE onnxruntime_hosting_srcs
"${ONNXRUNTIME_ROOT}/hosting/environment.cc"
)

# For IDE only
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_hosting_srcs} ${onnxruntime_hosting_lib_srcs})

# Hosting library
add_library(onnxruntime_hosting_lib ${onnxruntime_hosting_lib_srcs})
onnxruntime_add_include_to_target(onnxruntime_hosting_lib gsl onnx_proto hosting_proto)
target_include_directories(onnxruntime_hosting_lib PRIVATE
${ONNXRUNTIME_ROOT}
${CMAKE_CURRENT_BINARY_DIR}/onnx
${ONNXRUNTIME_ROOT}/hosting
${ONNXRUNTIME_ROOT}/hosting/http
PUBLIC
${Boost_INCLUDE_DIR}
Expand All @@ -70,20 +70,27 @@ target_link_libraries(onnxruntime_hosting_lib PRIVATE
${onnxruntime_EXTERNAL_LIBRARIES}
)

# For IDE only
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_hosting_srcs} ${onnxruntime_hosting_lib_srcs} ${onnxruntime_hosting_lib})

# Hosting Application
add_executable(${PROJECT_NAME} ${onnxruntime_hosting_srcs})
add_dependencies(${PROJECT_NAME} hosting_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})

onnxruntime_add_include_to_target(${PROJECT_NAME} onnxruntime_session gsl hosting_proto)
if(NOT WIN32)
if(HAS_UNUSED_PARAMETER)
set_source_files_properties("${ONNXRUNTIME_ROOT}/hosting/main.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter)
endif()
endif()

onnxruntime_add_include_to_target(${PROJECT_NAME} onnxruntime_session onnxruntime_hosting_lib gsl onnx_proto hosting_proto)

target_include_directories(${PROJECT_NAME} PRIVATE
${ONNXRUNTIME_ROOT}
${ONNXRUNTIME_ROOT}/hosting/http
${Boost_INCLUDE_DIR}
)

target_link_libraries(${PROJECT_NAME} PRIVATE
onnxruntime_hosting_lib
hosting_proto
)

2 changes: 1 addition & 1 deletion onnxruntime/hosting/http/http_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class App {
App& Run();

private:
const std::shared_ptr<Routes> routes{std::make_shared<Routes>()};
const std::shared_ptr<Routes> routes = std::make_shared<Routes>();
net::ip::address address_;
unsigned short port_;
int threads_;
Expand Down
4 changes: 1 addition & 3 deletions onnxruntime/hosting/http/json_handling.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <istream>
#include <string>
#include <boost/beast/core.hpp>
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/json_util.h>

#include "predict.pb.h"
#include "json_handling.h"
Expand Down Expand Up @@ -33,5 +30,6 @@ protobufutil::Status GenerateResponseInJson(onnxruntime::hosting::PredictRespons
protobufutil::Status result = MessageToJsonString(response, &json_string, options);
return result;
}

} // namespace hosting
} // namespace onnxruntime
56 changes: 56 additions & 0 deletions onnxruntime/hosting/http/predict_request_handler.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "environment.h"
#include "http_server.h"
#include "json_handling.h"

namespace onnxruntime {
namespace hosting {

namespace beast = boost::beast;
namespace http = beast::http;

void BadRequest(HttpContext& context, const std::string& error_message) {
auto json_error = R"({"error_code": 400, "error_message": )" + error_message + " }";

http::response<http::string_body> res{http::status::bad_request, context.request.version()};
res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
res.set(http::field::content_type, "application/json");
res.keep_alive(context.request.keep_alive());
res.body() = std::string(json_error);
res.prepare_payload();
context.response = res;
}

// TODO: decide whether this should be a class
void Predict(const std::string& name,
const std::string& version,
const std::string& action,
HttpContext& context,
HostingEnvironment& env) {
PredictRequest predictRequest{};
auto logger = env.GetLogger();

LOGS(logger, VERBOSE) << "Name: " << name
<< "Version: " << version
<< "Action: " << action;

auto body = context.request.body();
auto status = GetRequestFromJson(body, predictRequest);

if (!status.ok()) {
return BadRequest(context, status.error_message());
}

http::response<http::string_body> res{std::piecewise_construct,
std::make_tuple(body),
std::make_tuple(http::status::ok, context.request.version())};
res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
res.set(http::field::content_type, "application/json");
res.keep_alive(context.request.keep_alive());
context.response = res;
};

} // namespace hosting
} // namespace onnxruntime
23 changes: 23 additions & 0 deletions onnxruntime/hosting/http/predict_request_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "http_server.h"
#include "json_handling.h"

namespace onnxruntime {
namespace hosting {

namespace beast = boost::beast;
namespace http = beast::http;

void BadRequest(HttpContext& context, const std::string& error_message);

// TODO: decide whether this should be a class
void Predict(const std::string& name,
const std::string& version,
const std::string& action,
HttpContext& context,
HostingEnvironment& env);

} // namespace hosting
} // namespace onnxruntime
47 changes: 47 additions & 0 deletions onnxruntime/hosting/http/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>

using handler_fn = std::function<void(std::string, std::string, std::string, HttpContext&)>;

HttpSession::HttpSession(std::shared_ptr<Routes> routes, tcp::socket socket)
: routes_(std::move(routes)), socket_(std::move(socket)), strand_(socket_.get_executor()) {
}

void HttpSession::DoRead() {
// Make the request empty before reading,
// otherwise the operation behavior is undefined.
Expand Down Expand Up @@ -72,5 +76,48 @@ void HttpSession::DoClose() {

// At this point the connection is closed gracefully
}

template <class Msg>
void HttpSession::Send(Msg&& msg) {
using item_type = std::remove_reference_t<decltype(msg)>;

auto ptr = std::make_shared<item_type>(std::move(msg));
auto self_ = shared_from_this();
self_->res_ = ptr;

http::async_write(self_->socket_, *ptr,
net::bind_executor(strand_,
[self_, close = ptr->need_eof()](beast::error_code ec, std::size_t bytes) {
self_->OnWrite(ec, bytes, close);
}));
}

template <typename Body, typename Allocator>
void HttpSession::HandleRequest(boost::beast::http::request<Body, boost::beast::http::basic_fields<Allocator> >&& req) {
HttpContext context{};
context.request = req;

std::string path = req.target().to_string();
std::string model_name;
std::string model_version;
std::string action;
handler_fn func;
http::status status = routes_->ParseUrl(req.method(), path, model_name, model_version, action, func);

if (http::status::ok == status && func != nullptr) {
func(model_name, model_version, action, context);
} else {
http::response<http::string_body> res{status, req.version()};
res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
res.set(http::field::content_type, "text/plain");
res.keep_alive(req.keep_alive());
res.body() = std::string("Something failed\n");
res.prepare_payload();
context.response = res;
}

return Send(std::move(context.response));
}

} // namespace hosting
} // namespace onnxruntime
51 changes: 8 additions & 43 deletions onnxruntime/hosting/http/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,15 @@ namespace hosting {
namespace net = boost::asio; // from <boost/asio.hpp>
namespace beast = boost::beast; // from <boost/beast.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
namespace http = beast::http;

using handler_fn = std::function<void(std::string, std::string, std::string, HttpContext&)>;

// An implementation of a single HTTP session
// Used by a listener to hand off the work and async write back to a socket
class HttpSession : public std::enable_shared_from_this<HttpSession> {
public:
explicit HttpSession(
std::shared_ptr<Routes> routes,
tcp::socket socket)
: routes_(std::move(routes)), socket_(std::move(socket)), strand_(socket_.get_executor()) {
}
HttpSession(std::shared_ptr<Routes> routes, tcp::socket socket);

// Start the asynchronous operation
// The entrypoint for the class
Expand All @@ -51,50 +48,17 @@ class HttpSession : public std::enable_shared_from_this<HttpSession> {
// Writes the message asynchronously back to the socket
// Stores the pointer to the message and the class itself so that
// They do not get destructed before the async process is finished
// If you pass shared_from_this() are guaranteed that the life time
// of your object will be extended to as long as the function needs it
// Most examples in boost::asio are based on this logic
template <class Msg>
void Send(Msg&& msg) {
using item_type = std::remove_reference_t<decltype(msg)>;

auto ptr = std::make_shared<item_type>(std::move(msg));
auto self_ = shared_from_this();
self_->res_ = ptr;

http::async_write(self_->socket_, *ptr,
net::bind_executor(strand_,
[ self_, close = ptr->need_eof() ](beast::error_code ec, std::size_t bytes) {
self_->OnWrite(ec, bytes, close);
}));
}
void Send(Msg&& msg);

// Handle the request and hand it off to the user's function
// Called after the session is finished reading the message
// Should set the response before calling Send
template <typename Body, typename Allocator>
void HandleRequest(http::request<Body, http::basic_fields<Allocator>>&& req) {
HttpContext context{};
context.request = req;

std::string path = req.target().to_string();
std::string model_name;
std::string model_version;
std::string action;
handler_fn func;
http::status status = routes_->ParseUrl(req.method(), path, model_name, model_version, action, func);

if (http::status::ok == status) {
func(model_name, model_version, action, context);
} else {
http::response<http::string_body> res{status, req.version()};
res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
res.set(http::field::content_type, "text/plain");
res.keep_alive(req.keep_alive());
res.body() = std::string("Something failed\n");
res.prepare_payload();
context.response = res;
}

return Send(std::move(context.response));
}
void HandleRequest(http::request<Body, http::basic_fields<Allocator>>&& req);

// Asynchronously reads the request from the socket
void DoRead();
Expand All @@ -105,6 +69,7 @@ class HttpSession : public std::enable_shared_from_this<HttpSession> {
// After writing, make the session read another request
void OnWrite(beast::error_code ec, std::size_t bytes_transferred, bool close);

// Close the connection
void DoClose();
};

Expand Down
47 changes: 17 additions & 30 deletions onnxruntime/hosting/main.cc
Original file line number Diff line number Diff line change
@@ -1,58 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "environment.h"
#include "http_server.h"
#include "predict_request_handler.h"
#include "server_configuration.h"
#include "environment.h"

namespace beast = boost::beast;
namespace http = beast::http;

void test_request(const std::string& name, const std::string& version,
const std::string& action, onnxruntime::hosting::HttpContext& context) {
std::stringstream ss;

ss << "\tModel Name: " << name << std::endl;
ss << "\tModel Version: " << version << std::endl;
ss << "\tAction: " << action << std::endl;
ss << "\tHTTP method: " << context.request.method() << std::endl;

http::response<http::string_body>
res{std::piecewise_construct, std::make_tuple(ss.str()), std::make_tuple(http::status::ok, context.request.version())};

res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
res.set(http::field::content_type, "plain/text");
res.keep_alive(context.request.keep_alive());
context.response = res;
}
namespace hosting = onnxruntime::hosting;

int main(int argc, char* argv[]) {
onnxruntime::hosting::ServerConfiguration config{};
hosting::ServerConfiguration config{};
auto res = config.ParseInput(argc, argv);

if (res == onnxruntime::hosting::Result::ExitSuccess) {
if (res == hosting::Result::ExitSuccess) {
exit(EXIT_SUCCESS);
} else if (res == onnxruntime::hosting::Result::ExitFailure) {
} else if (res == hosting::Result::ExitFailure) {
exit(EXIT_FAILURE);
}

onnxruntime::hosting::HostingEnvironment env;
hosting::HostingEnvironment env;
auto logger = env.GetLogger();

// TODO: below code snippet just trying to show case how to use the "env".
// Will be moved to proper place.
// TODO: below code snippet just trying to show case how to use the "env". Move later.
LOGS(logger, VERBOSE) << "Logging manager initialized.";
LOGS(logger, VERBOSE) << "Model path: " << config.model_path;
auto status = env.GetSession()->Load(config.model_path);
LOGS(logger, VERBOSE) << "Load Model Status: " << status.Code() << " ---- Error: [" << status.ErrorMessage() << "]";

auto const boost_address = boost::asio::ip::make_address(config.address);

onnxruntime::hosting::App app{};
app.Post(R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))", test_request)
.Bind(boost_address, config.http_port)
.NumThreads(config.num_http_threads)
.Run();
hosting::App app{};
app.Post(R"(/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))",
[&env](const std::string& name, const std::string& version, const std::string& action, hosting::HttpContext& context) {
hosting::Predict(name, version, action, context, env);
});

app.Bind(boost_address, config.http_port)
.NumThreads(config.num_http_threads)
.Run();

return EXIT_SUCCESS;
}

0 comments on commit 2db832c

Please sign in to comment.