diff --git a/include/dsn/utility/error_code.h b/include/dsn/utility/error_code.h index 61be86f868..ee0703643e 100644 --- a/include/dsn/utility/error_code.h +++ b/include/dsn/utility/error_code.h @@ -124,4 +124,5 @@ DEFINE_ERR_CODE(ERR_UNAUTHENTICATED) DEFINE_ERR_CODE(ERR_KRB5_INTERNAL) DEFINE_ERR_CODE(ERR_SASL_INTERNAL) +DEFINE_ERR_CODE(ERR_SASL_INCOMPLETE) } // namespace dsn diff --git a/src/runtime/security/client_negotiation.cpp b/src/runtime/security/client_negotiation.cpp index 2b22bc0ef5..306273569c 100644 --- a/src/runtime/security/client_negotiation.cpp +++ b/src/runtime/security/client_negotiation.cpp @@ -66,7 +66,7 @@ void client_negotiation::handle_response(error_code err, const negotiation_respo on_recv_mechanisms(response); break; case negotiation_status::type::SASL_SELECT_MECHANISMS: - // TBD(zlw) + on_mechanism_selected(response); break; case negotiation_status::type::SASL_INITIATE: case negotiation_status::type::SASL_CHALLENGE_RESP: @@ -79,11 +79,7 @@ void client_negotiation::handle_response(error_code err, const negotiation_respo void client_negotiation::on_recv_mechanisms(const negotiation_response &resp) { - if (resp.status != negotiation_status::type::SASL_LIST_MECHANISMS_RESP) { - dwarn_f("{}: get message({}) while expect({})", - _name, - enum_to_string(resp.status), - enum_to_string(negotiation_status::type::SASL_LIST_MECHANISMS_RESP)); + if (!check_status(resp.status, negotiation_status::type::SASL_LIST_MECHANISMS_RESP)) { fail_negotiation(); return; } @@ -111,6 +107,41 @@ void client_negotiation::on_recv_mechanisms(const negotiation_response &resp) select_mechanism(match_mechanism); } +void client_negotiation::on_mechanism_selected(const negotiation_response &resp) +{ + if (!check_status(resp.status, negotiation_status::type::SASL_SELECT_MECHANISMS_RESP)) { + fail_negotiation(); + return; + } + + // init client sasl + auto err_s = _sasl->init(); + if (!err_s.is_ok()) { + dwarn_f("{}: initialize sasl client failed, error = {}, reason = {}", + _name, + err_s.code().to_string(), + err_s.description()); + fail_negotiation(); + return; + } + + // start client sasl, and send `SASL_INITIATE` to `server_negotiation` if everything is ok + std::string start_output; + err_s = _sasl->start(_selected_mechanism, "", start_output); + if (err_s.is_ok() || ERR_SASL_INCOMPLETE == err_s.code()) { + auto req = dsn::make_unique(); + _status = req->status = negotiation_status::type::SASL_INITIATE; + req->msg = start_output; + send(std::move(req)); + } else { + dwarn_f("{}: start sasl client failed, error = {}, reason = {}", + _name, + err_s.code().to_string(), + err_s.description()); + fail_negotiation(); + } +} + void client_negotiation::select_mechanism(const std::string &mechanism) { _selected_mechanism = mechanism; diff --git a/src/runtime/security/client_negotiation.h b/src/runtime/security/client_negotiation.h index 158c127311..8a4dd7f38e 100644 --- a/src/runtime/security/client_negotiation.h +++ b/src/runtime/security/client_negotiation.h @@ -32,6 +32,7 @@ class client_negotiation : public negotiation private: void handle_response(error_code err, const negotiation_response &&response); void on_recv_mechanisms(const negotiation_response &resp); + void on_mechanism_selected(const negotiation_response &resp); void list_mechanisms(); void select_mechanism(const std::string &mechanism); diff --git a/src/runtime/security/negotiation.cpp b/src/runtime/security/negotiation.cpp index e1c73e3fc5..57bc0ac3b6 100644 --- a/src/runtime/security/negotiation.cpp +++ b/src/runtime/security/negotiation.cpp @@ -18,9 +18,11 @@ #include "negotiation.h" #include "client_negotiation.h" #include "server_negotiation.h" +#include "negotiation_utils.h" #include #include +#include namespace dsn { namespace security { @@ -48,5 +50,18 @@ void negotiation::fail_negotiation() _session->on_failure(true); } +bool negotiation::check_status(negotiation_status::type status, + negotiation_status::type expected_status) +{ + if (status != expected_status) { + dwarn_f("{}: get message({}) while expect({})", + _name, + enum_to_string(status), + enum_to_string(expected_status)); + return false; + } + + return true; +} } // namespace security } // namespace dsn diff --git a/src/runtime/security/negotiation.h b/src/runtime/security/negotiation.h index 20bac462c2..9b31b7626b 100644 --- a/src/runtime/security/negotiation.h +++ b/src/runtime/security/negotiation.h @@ -42,6 +42,11 @@ class negotiation virtual void start() = 0; bool negotiation_succeed() const { return _status == negotiation_status::type::SASL_SUCC; } void fail_negotiation(); + // check whether the status is equal to expected_status + // ret value: + // true: status == expected_status + // false: status != expected_status + bool check_status(negotiation_status::type status, negotiation_status::type expected_status); protected: // The ownership of the negotiation instance is held by rpc_session. diff --git a/src/runtime/security/sasl_client_wrapper.cpp b/src/runtime/security/sasl_client_wrapper.cpp index 1d584af121..07871c6828 100644 --- a/src/runtime/security/sasl_client_wrapper.cpp +++ b/src/runtime/security/sasl_client_wrapper.cpp @@ -17,7 +17,9 @@ #include "sasl_client_wrapper.h" +#include #include +#include namespace dsn { namespace security { @@ -26,16 +28,33 @@ DSN_DECLARE_string(service_name); error_s sasl_client_wrapper::init() { - // TBD(zlw) - return error_s::make(ERR_OK); + FAIL_POINT_INJECT_F("sasl_client_wrapper_init", [](dsn::string_view str) { + error_code err = error_code::try_get(str.data(), ERR_UNKNOWN); + return error_s::make(err); + }); + + int sasl_err = sasl_client_new( + FLAGS_service_name, FLAGS_service_fqdn, nullptr, nullptr, nullptr, 0, &_conn); + return wrap_error(sasl_err); } error_s sasl_client_wrapper::start(const std::string &mechanism, const std::string &input, std::string &output) { - // TBD(zlw) - return error_s::make(ERR_OK); + FAIL_POINT_INJECT_F("sasl_client_wrapper_start", [](dsn::string_view str) { + error_code err = error_code::try_get(str.data(), ERR_UNKNOWN); + return error_s::make(err); + }); + + const char *msg = nullptr; + unsigned msg_len = 0; + const char *client_mech = nullptr; + int sasl_err = + sasl_client_start(_conn, mechanism.c_str(), nullptr, &msg, &msg_len, &client_mech); + + output.assign(msg, msg_len); + return wrap_error(sasl_err); } error_s sasl_client_wrapper::step(const std::string &input, std::string &output) diff --git a/src/runtime/security/sasl_wrapper.cpp b/src/runtime/security/sasl_wrapper.cpp index 3412a3df77..91cc05079d 100644 --- a/src/runtime/security/sasl_wrapper.cpp +++ b/src/runtime/security/sasl_wrapper.cpp @@ -46,7 +46,7 @@ error_s sasl_wrapper::wrap_error(int sasl_err) case SASL_OK: return error_s::make(ERR_OK); case SASL_CONTINUE: - return error_s::make(ERR_NOT_IMPLEMENTED); + return error_s::make(ERR_SASL_INCOMPLETE); case SASL_FAIL: // Generic failure (encompasses missing krb5 credentials). case SASL_BADAUTH: // Authentication failure. case SASL_BADMAC: // Decode failure. diff --git a/src/runtime/security/server_negotiation.cpp b/src/runtime/security/server_negotiation.cpp index 130a538d53..890cf79224 100644 --- a/src/runtime/security/server_negotiation.cpp +++ b/src/runtime/security/server_negotiation.cpp @@ -60,52 +60,44 @@ void server_negotiation::handle_request(negotiation_rpc rpc) void server_negotiation::on_list_mechanisms(negotiation_rpc rpc) { - if (rpc.request().status == negotiation_status::type::SASL_LIST_MECHANISMS) { - std::string mech_list = boost::join(supported_mechanisms, ","); - negotiation_response &response = rpc.response(); - _status = response.status = negotiation_status::type::SASL_LIST_MECHANISMS_RESP; - response.msg = std::move(mech_list); - } else { - ddebug_f("{}: got message({}) while expect({})", - _name, - enum_to_string(rpc.request().status), - enum_to_string(negotiation_status::type::SASL_LIST_MECHANISMS)); + if (!check_status(rpc.request().status, negotiation_status::type::SASL_LIST_MECHANISMS)) { fail_negotiation(); + return; } - return; + + std::string mech_list = boost::join(supported_mechanisms, ","); + negotiation_response &response = rpc.response(); + _status = response.status = negotiation_status::type::SASL_LIST_MECHANISMS_RESP; + response.msg = std::move(mech_list); } void server_negotiation::on_select_mechanism(negotiation_rpc rpc) { const negotiation_request &request = rpc.request(); - if (request.status == negotiation_status::type::SASL_SELECT_MECHANISMS) { - _selected_mechanism = request.msg; - if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) { - dwarn_f("the mechanism of {} is not supported", _selected_mechanism); - fail_negotiation(); - return; - } + if (!check_status(rpc.request().status, negotiation_status::type::SASL_SELECT_MECHANISMS)) { + fail_negotiation(); + return; + } - error_s err_s = _sasl->init(); - if (!err_s.is_ok()) { - dwarn_f("{}: server initialize sasl failed, error = {}, msg = {}", - _name, - err_s.code().to_string(), - err_s.description()); - fail_negotiation(); - return; - } + _selected_mechanism = request.msg; + if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) { + dwarn_f("the mechanism of {} is not supported", _selected_mechanism); + fail_negotiation(); + return; + } - negotiation_response &response = rpc.response(); - _status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP; - } else { - dwarn_f("{}: got message({}) while expect({})", + error_s err_s = _sasl->init(); + if (!err_s.is_ok()) { + dwarn_f("{}: server initialize sasl failed, error = {}, msg = {}", _name, - enum_to_string(request.status), - negotiation_status::type::SASL_SELECT_MECHANISMS); + err_s.code().to_string(), + err_s.description()); fail_negotiation(); return; } + + negotiation_response &response = rpc.response(); + _status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP; } } // namespace security } // namespace dsn diff --git a/src/runtime/test/client_negotiation_test.cpp b/src/runtime/test/client_negotiation_test.cpp index 5158c45a75..7df6334953 100644 --- a/src/runtime/test/client_negotiation_test.cpp +++ b/src/runtime/test/client_negotiation_test.cpp @@ -21,6 +21,7 @@ #include #include +#include namespace dsn { namespace security { @@ -45,6 +46,11 @@ class client_negotiation_test : public testing::Test _client_negotiation->handle_response(err, std::move(resp)); } + void on_mechanism_selected(const negotiation_response &resp) + { + _client_negotiation->on_mechanism_selected(resp); + } + const std::string &get_selected_mechanism() { return _client_negotiation->_selected_mechanism; } negotiation_status::type get_negotiation_status() { return _client_negotiation->_status; } @@ -112,5 +118,51 @@ TEST_F(client_negotiation_test, handle_response) ASSERT_EQ(get_negotiation_status(), test.neg_status); } } + +TEST_F(client_negotiation_test, on_mechanism_selected) +{ + struct + { + std::string sasl_init_result; + std::string sasl_start_result; + negotiation_status::type resp_status; + negotiation_status::type neg_status; + } tests[] = {{"ERR_OK", + "ERR_OK", + negotiation_status::type::SASL_SELECT_MECHANISMS_RESP, + negotiation_status::type::SASL_INITIATE}, + {"ERR_OK", + "ERR_SASL_INCOMPLETE", + negotiation_status::type::SASL_SELECT_MECHANISMS_RESP, + negotiation_status::type::SASL_INITIATE}, + {"ERR_OK", + "ERR_TIMEOUT", + negotiation_status::type::SASL_SELECT_MECHANISMS_RESP, + negotiation_status::type::SASL_AUTH_FAIL}, + {"ERR_TIMEOUT", + "ERR_OK", + negotiation_status::type::SASL_SELECT_MECHANISMS_RESP, + negotiation_status::type::SASL_AUTH_FAIL}, + {"ERR_OK", + "ERR_OK", + negotiation_status::type::SASL_SELECT_MECHANISMS, + negotiation_status::type::SASL_AUTH_FAIL}}; + + RPC_MOCKING(negotiation_rpc) + { + for (const auto &test : tests) { + fail::setup(); + fail::cfg("sasl_client_wrapper_init", "return(" + test.sasl_init_result + ")"); + fail::cfg("sasl_client_wrapper_start", "return(" + test.sasl_start_result + ")"); + + negotiation_response resp; + resp.status = test.resp_status; + on_mechanism_selected(resp); + ASSERT_EQ(get_negotiation_status(), test.neg_status); + + fail::teardown(); + } + } +} } // namespace security } // namespace dsn