Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

feat(security): client_negotiation handle mechanism selected response #612

Merged
merged 22 commits into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/dsn/utility/error_code.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,5 @@ DEFINE_ERR_CODE(ERR_UNAUTHENTICATED)
DEFINE_ERR_CODE(ERR_KRB5_INTERNAL)

DEFINE_ERR_CODE(ERR_SASL_INTERNAL)
DEFINE_ERR_CODE(ERR_SASL_INCOMPLETE)
} // namespace dsn
43 changes: 37 additions & 6 deletions src/runtime/security/client_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void client_negotiation::handle_response(error_code err, const negotiation_respo
on_recv_mechanisms(response);
break;
case negotiation_status::type::SASL_SELECT_MECHANISMS:
// TBD(zlw)
on_mechanism_selected(response);
break;
case negotiation_status::type::SASL_INITIATE:
case negotiation_status::type::SASL_CHALLENGE_RESP:
Expand All @@ -79,11 +79,7 @@ void client_negotiation::handle_response(error_code err, const negotiation_respo

void client_negotiation::on_recv_mechanisms(const negotiation_response &resp)
{
if (resp.status != negotiation_status::type::SASL_LIST_MECHANISMS_RESP) {
dwarn_f("{}: get message({}) while expect({})",
_name,
enum_to_string(resp.status),
enum_to_string(negotiation_status::type::SASL_LIST_MECHANISMS_RESP));
if (!check_status(resp.status, negotiation_status::type::SASL_LIST_MECHANISMS_RESP)) {
fail_negotiation();
return;
}
Expand Down Expand Up @@ -111,6 +107,41 @@ void client_negotiation::on_recv_mechanisms(const negotiation_response &resp)
select_mechanism(match_mechanism);
}

void client_negotiation::on_mechanism_selected(const negotiation_response &resp)
{
if (!check_status(resp.status, negotiation_status::type::SASL_SELECT_MECHANISMS_RESP)) {
fail_negotiation();
return;
}

// init client sasl
auto err_s = _sasl->init();
if (!err_s.is_ok()) {
dwarn_f("{}: initialize sasl client failed, error = {}, reason = {}",
_name,
err_s.code().to_string(),
err_s.description());
fail_negotiation();
return;
}

// start client sasl, and send `SASL_INITIATE` to `server_negotiation` if everything is ok
std::string start_output;
err_s = _sasl->start(_selected_mechanism, "", start_output);
if (err_s.is_ok() || ERR_SASL_INCOMPLETE == err_s.code()) {
auto req = dsn::make_unique<negotiation_request>();
_status = req->status = negotiation_status::type::SASL_INITIATE;
req->msg = start_output;
send(std::move(req));
} else {
dwarn_f("{}: start sasl client failed, error = {}, reason = {}",
_name,
err_s.code().to_string(),
err_s.description());
fail_negotiation();
}
}

void client_negotiation::select_mechanism(const std::string &mechanism)
{
_selected_mechanism = mechanism;
Expand Down
1 change: 1 addition & 0 deletions src/runtime/security/client_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class client_negotiation : public negotiation
private:
void handle_response(error_code err, const negotiation_response &&response);
void on_recv_mechanisms(const negotiation_response &resp);
void on_mechanism_selected(const negotiation_response &resp);

void list_mechanisms();
void select_mechanism(const std::string &mechanism);
Expand Down
15 changes: 15 additions & 0 deletions src/runtime/security/negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
#include "negotiation.h"
#include "client_negotiation.h"
#include "server_negotiation.h"
#include "negotiation_utils.h"

#include <dsn/utility/flags.h>
#include <dsn/utility/smart_pointers.h>
#include <dsn/dist/fmt_logging.h>

namespace dsn {
namespace security {
Expand Down Expand Up @@ -48,5 +50,18 @@ void negotiation::fail_negotiation()
_session->on_failure(true);
}

bool negotiation::check_status(negotiation_status::type status,
negotiation_status::type expected_status)
{
if (status != expected_status) {
dwarn_f("{}: get message({}) while expect({})",
_name,
enum_to_string(status),
enum_to_string(expected_status));
return false;
}

return true;
}
} // namespace security
} // namespace dsn
5 changes: 5 additions & 0 deletions src/runtime/security/negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class negotiation
virtual void start() = 0;
bool negotiation_succeed() const { return _status == negotiation_status::type::SASL_SUCC; }
void fail_negotiation();
// check whether the status is equal to expected_status
// ret value:
// true: status == expected_status
// false: status != expected_status
bool check_status(negotiation_status::type status, negotiation_status::type expected_status);

protected:
// The ownership of the negotiation instance is held by rpc_session.
Expand Down
27 changes: 23 additions & 4 deletions src/runtime/security/sasl_client_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

#include "sasl_client_wrapper.h"

#include <sasl/sasl.h>
#include <dsn/utility/flags.h>
#include <dsn/utility/fail_point.h>

namespace dsn {
namespace security {
Expand All @@ -26,16 +28,33 @@ DSN_DECLARE_string(service_name);

error_s sasl_client_wrapper::init()
{
// TBD(zlw)
return error_s::make(ERR_OK);
FAIL_POINT_INJECT_F("sasl_client_wrapper_init", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
return error_s::make(err);
});

int sasl_err = sasl_client_new(
FLAGS_service_name, FLAGS_service_fqdn, nullptr, nullptr, nullptr, 0, &_conn);
return wrap_error(sasl_err);
}

error_s sasl_client_wrapper::start(const std::string &mechanism,
const std::string &input,
std::string &output)
{
// TBD(zlw)
return error_s::make(ERR_OK);
FAIL_POINT_INJECT_F("sasl_client_wrapper_start", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
return error_s::make(err);
});

const char *msg = nullptr;
unsigned msg_len = 0;
const char *client_mech = nullptr;
int sasl_err =
sasl_client_start(_conn, mechanism.c_str(), nullptr, &msg, &msg_len, &client_mech);

output.assign(msg, msg_len);
return wrap_error(sasl_err);
}

error_s sasl_client_wrapper::step(const std::string &input, std::string &output)
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/sasl_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ error_s sasl_wrapper::wrap_error(int sasl_err)
case SASL_OK:
return error_s::make(ERR_OK);
case SASL_CONTINUE:
return error_s::make(ERR_NOT_IMPLEMENTED);
return error_s::make(ERR_SASL_INCOMPLETE);
case SASL_FAIL: // Generic failure (encompasses missing krb5 credentials).
case SASL_BADAUTH: // Authentication failure.
case SASL_BADMAC: // Decode failure.
Expand Down
58 changes: 25 additions & 33 deletions src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,52 +60,44 @@ void server_negotiation::handle_request(negotiation_rpc rpc)

void server_negotiation::on_list_mechanisms(negotiation_rpc rpc)
{
if (rpc.request().status == negotiation_status::type::SASL_LIST_MECHANISMS) {
std::string mech_list = boost::join(supported_mechanisms, ",");
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_LIST_MECHANISMS_RESP;
response.msg = std::move(mech_list);
} else {
ddebug_f("{}: got message({}) while expect({})",
_name,
enum_to_string(rpc.request().status),
enum_to_string(negotiation_status::type::SASL_LIST_MECHANISMS));
if (!check_status(rpc.request().status, negotiation_status::type::SASL_LIST_MECHANISMS)) {
fail_negotiation();
return;
}
return;

std::string mech_list = boost::join(supported_mechanisms, ",");
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_LIST_MECHANISMS_RESP;
response.msg = std::move(mech_list);
}

void server_negotiation::on_select_mechanism(negotiation_rpc rpc)
{
const negotiation_request &request = rpc.request();
if (request.status == negotiation_status::type::SASL_SELECT_MECHANISMS) {
_selected_mechanism = request.msg;
if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) {
dwarn_f("the mechanism of {} is not supported", _selected_mechanism);
fail_negotiation();
return;
}
if (!check_status(rpc.request().status, negotiation_status::type::SASL_SELECT_MECHANISMS)) {
fail_negotiation();
return;
}

error_s err_s = _sasl->init();
if (!err_s.is_ok()) {
dwarn_f("{}: server initialize sasl failed, error = {}, msg = {}",
_name,
err_s.code().to_string(),
err_s.description());
fail_negotiation();
return;
}
_selected_mechanism = request.msg;
if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) {
dwarn_f("the mechanism of {} is not supported", _selected_mechanism);
fail_negotiation();
return;
}

negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP;
} else {
dwarn_f("{}: got message({}) while expect({})",
error_s err_s = _sasl->init();
if (!err_s.is_ok()) {
dwarn_f("{}: server initialize sasl failed, error = {}, msg = {}",
_name,
enum_to_string(request.status),
negotiation_status::type::SASL_SELECT_MECHANISMS);
err_s.code().to_string(),
err_s.description());
fail_negotiation();
return;
}

negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_SELECT_MECHANISMS_RESP;
Comment on lines +77 to +100
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good refactor!

}
} // namespace security
} // namespace dsn
52 changes: 52 additions & 0 deletions src/runtime/test/client_negotiation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <gtest/gtest.h>
#include <dsn/utility/flags.h>
#include <dsn/utility/fail_point.h>

namespace dsn {
namespace security {
Expand All @@ -45,6 +46,11 @@ class client_negotiation_test : public testing::Test
_client_negotiation->handle_response(err, std::move(resp));
}

void on_mechanism_selected(const negotiation_response &resp)
{
_client_negotiation->on_mechanism_selected(resp);
}

const std::string &get_selected_mechanism() { return _client_negotiation->_selected_mechanism; }

negotiation_status::type get_negotiation_status() { return _client_negotiation->_status; }
Expand Down Expand Up @@ -112,5 +118,51 @@ TEST_F(client_negotiation_test, handle_response)
ASSERT_EQ(get_negotiation_status(), test.neg_status);
}
}

TEST_F(client_negotiation_test, on_mechanism_selected)
{
struct
{
std::string sasl_init_result;
std::string sasl_start_result;
negotiation_status::type resp_status;
negotiation_status::type neg_status;
} tests[] = {{"ERR_OK",
"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_INITIATE},
{"ERR_OK",
"ERR_SASL_INCOMPLETE",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_INITIATE},
{"ERR_OK",
"ERR_TIMEOUT",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_AUTH_FAIL},
{"ERR_TIMEOUT",
"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS_RESP,
negotiation_status::type::SASL_AUTH_FAIL},
{"ERR_OK",
"ERR_OK",
negotiation_status::type::SASL_SELECT_MECHANISMS,
negotiation_status::type::SASL_AUTH_FAIL}};

RPC_MOCKING(negotiation_rpc)
{
for (const auto &test : tests) {
fail::setup();
fail::cfg("sasl_client_wrapper_init", "return(" + test.sasl_init_result + ")");
fail::cfg("sasl_client_wrapper_start", "return(" + test.sasl_start_result + ")");

negotiation_response resp;
resp.status = test.resp_status;
on_mechanism_selected(resp);
ASSERT_EQ(get_negotiation_status(), test.neg_status);

fail::teardown();
}
}
}
} // namespace security
} // namespace dsn