Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

Commit

Permalink
refactor(security): use blob instead of std::string as the type for m…
Browse files Browse the repository at this point in the history
…sg member of negotiation_request (#622)
  • Loading branch information
levy5307 authored Sep 10, 2020
1 parent bf610ce commit 8ec8ba6
Show file tree
Hide file tree
Showing 14 changed files with 51 additions and 58 deletions.
12 changes: 6 additions & 6 deletions src/runtime/security/client_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void client_negotiation::on_recv_mechanisms(const negotiation_response &resp)

std::string match_mechanism;
std::vector<std::string> server_support_mechanisms;
std::string resp_string = resp.msg;
std::string resp_string = resp.msg.to_string();
utils::split_args(resp_string.c_str(), server_support_mechanisms, ',');

for (const std::string &server_support_mechanism : server_support_mechanisms) {
Expand Down Expand Up @@ -125,8 +125,8 @@ void client_negotiation::on_mechanism_selected(const negotiation_response &resp)
}

// start client sasl, and send `SASL_INITIATE` to `server_negotiation` if everything is ok
std::string start_output;
err_s = _sasl->start(_selected_mechanism, "", start_output);
blob start_output;
err_s = _sasl->start(_selected_mechanism, blob(), start_output);
if (err_s.is_ok() || ERR_SASL_INCOMPLETE == err_s.code()) {
_status = negotiation_status::type::SASL_INITIATE;
send(_status, std::move(start_output));
Expand All @@ -142,7 +142,7 @@ void client_negotiation::on_mechanism_selected(const negotiation_response &resp)
void client_negotiation::on_challenge(const negotiation_response &challenge)
{
if (challenge.status == negotiation_status::type::SASL_CHALLENGE) {
std::string response_msg;
blob response_msg;
auto err = _sasl->step(challenge.msg, response_msg);
if (!err.is_ok() && err.code() != ERR_SASL_INCOMPLETE) {
dwarn_f("{}: negotiation failed, reason = {}", _name, err.description());
Expand All @@ -169,10 +169,10 @@ void client_negotiation::select_mechanism(const std::string &mechanism)
_selected_mechanism = mechanism;
_status = negotiation_status::type::SASL_SELECT_MECHANISMS;

send(_status, std::move(mechanism));
send(_status, blob::create_from_bytes(mechanism.data(), mechanism.length()));
}

void client_negotiation::send(negotiation_status::type status, const std::string &&msg)
void client_negotiation::send(negotiation_status::type status, const blob &msg)
{
auto req = dsn::make_unique<negotiation_request>();
req->status = status;
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/client_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class client_negotiation : public negotiation

void list_mechanisms();
void select_mechanism(const std::string &mechanism);
void send(negotiation_status::type status, const std::string &&msg = "");
void send(negotiation_status::type status, const blob &msg = blob());
void succ_negotiation();

friend class client_negotiation_test;
Expand Down
12 changes: 5 additions & 7 deletions src/runtime/security/sasl_client_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ error_s sasl_client_wrapper::init()
return wrap_error(sasl_err);
}

error_s sasl_client_wrapper::start(const std::string &mechanism,
const std::string &input,
std::string &output)
error_s sasl_client_wrapper::start(const std::string &mechanism, const blob &input, blob &output)
{
FAIL_POINT_INJECT_F("sasl_client_wrapper_start", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
Expand All @@ -53,11 +51,11 @@ error_s sasl_client_wrapper::start(const std::string &mechanism,
int sasl_err =
sasl_client_start(_conn, mechanism.c_str(), nullptr, &msg, &msg_len, &client_mech);

output.assign(msg, msg_len);
output = blob::create_from_bytes(msg, msg_len);
return wrap_error(sasl_err);
}

error_s sasl_client_wrapper::step(const std::string &input, std::string &output)
error_s sasl_client_wrapper::step(const blob &input, blob &output)
{
FAIL_POINT_INJECT_F("sasl_client_wrapper_step", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
Expand All @@ -66,9 +64,9 @@ error_s sasl_client_wrapper::step(const std::string &input, std::string &output)

const char *msg = nullptr;
unsigned msg_len = 0;
int sasl_err = sasl_client_step(_conn, input.c_str(), input.length(), nullptr, &msg, &msg_len);
int sasl_err = sasl_client_step(_conn, input.data(), input.length(), nullptr, &msg, &msg_len);

output.assign(msg, msg_len);
output = blob::create_from_bytes(msg, msg_len);
return wrap_error(sasl_err);
}
} // namespace security
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/sasl_client_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class sasl_client_wrapper : public sasl_wrapper
~sasl_client_wrapper() = default;

error_s init();
error_s start(const std::string &mechanism, const std::string &input, std::string &output);
error_s step(const std::string &input, std::string &output);
error_s start(const std::string &mechanism, const blob &input, blob &output);
error_s step(const blob &input, blob &output);
};
} // namespace security
} // namespace dsn
14 changes: 6 additions & 8 deletions src/runtime/security/sasl_server_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ error_s sasl_server_wrapper::init()
return wrap_error(sasl_err);
}

error_s sasl_server_wrapper::start(const std::string &mechanism,
const std::string &input,
std::string &output)
error_s sasl_server_wrapper::start(const std::string &mechanism, const blob &input, blob &output)
{
FAIL_POINT_INJECT_F("sasl_server_wrapper_start", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
Expand All @@ -50,13 +48,13 @@ error_s sasl_server_wrapper::start(const std::string &mechanism,
const char *msg = nullptr;
unsigned msg_len = 0;
int sasl_err =
sasl_server_start(_conn, mechanism.c_str(), input.c_str(), input.length(), &msg, &msg_len);
sasl_server_start(_conn, mechanism.c_str(), input.data(), input.length(), &msg, &msg_len);

output.assign(msg, msg_len);
output = blob::create_from_bytes(msg, msg_len);
return wrap_error(sasl_err);
}

error_s sasl_server_wrapper::step(const std::string &input, std::string &output)
error_s sasl_server_wrapper::step(const blob &input, blob &output)
{
FAIL_POINT_INJECT_F("sasl_server_wrapper_step", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
Expand All @@ -65,9 +63,9 @@ error_s sasl_server_wrapper::step(const std::string &input, std::string &output)

const char *msg = nullptr;
unsigned msg_len = 0;
int sasl_err = sasl_server_step(_conn, input.c_str(), input.length(), &msg, &msg_len);
int sasl_err = sasl_server_step(_conn, input.data(), input.length(), &msg, &msg_len);

output.assign(msg, msg_len);
output = blob::create_from_bytes(msg, msg_len);
return wrap_error(sasl_err);
}
} // namespace security
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/sasl_server_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class sasl_server_wrapper : public sasl_wrapper
~sasl_server_wrapper() = default;

error_s init();
error_s start(const std::string &mechanism, const std::string &input, std::string &output);
error_s step(const std::string &input, std::string &output);
error_s start(const std::string &mechanism, const blob &input, blob &output);
error_s step(const blob &input, blob &output);
};
} // namespace security
} // namespace dsn
5 changes: 2 additions & 3 deletions src/runtime/security/sasl_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ class sasl_wrapper
virtual ~sasl_wrapper();

virtual error_s init() = 0;
virtual error_s
start(const std::string &mechanism, const std::string &input, std::string &output) = 0;
virtual error_s step(const std::string &input, std::string &output) = 0;
virtual error_s start(const std::string &mechanism, const blob &input, blob &output) = 0;
virtual error_s step(const blob &input, blob &output) = 0;

protected:
sasl_wrapper() = default;
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/security.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ enum negotiation_status {

struct negotiation_request {
1: negotiation_status status;
2: string msg;
2: dsn.blob msg;
}

struct negotiation_response {
1: negotiation_status status;
2: string msg;
2: dsn.blob msg;
}
20 changes: 10 additions & 10 deletions src/runtime/security/security_types.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions src/runtime/security/security_types.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 5 additions & 7 deletions src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void server_negotiation::on_list_mechanisms(negotiation_rpc rpc)
std::string mech_list = boost::join(supported_mechanisms, ",");
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_LIST_MECHANISMS_RESP;
response.msg = std::move(mech_list);
response.msg = blob::create_from_bytes(mech_list.data(), mech_list.length());
}

void server_negotiation::on_select_mechanism(negotiation_rpc rpc)
Expand All @@ -81,7 +81,7 @@ void server_negotiation::on_select_mechanism(negotiation_rpc rpc)
return;
}

_selected_mechanism = request.msg;
_selected_mechanism = request.msg.to_string();
if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) {
dwarn_f("the mechanism of {} is not supported", _selected_mechanism);
fail_negotiation();
Expand Down Expand Up @@ -110,7 +110,7 @@ void server_negotiation::on_initiate(negotiation_rpc rpc)
return;
}

std::string start_output;
blob start_output;
error_s err_s = _sasl->start(_selected_mechanism, request.msg, start_output);
return do_challenge(rpc, err_s, start_output);
}
Expand All @@ -123,14 +123,12 @@ void server_negotiation::on_challenge_resp(negotiation_rpc rpc)
return;
}

std::string resp_msg;
blob resp_msg;
error_s err_s = _sasl->step(request.msg, resp_msg);
return do_challenge(rpc, err_s, resp_msg);
}

void server_negotiation::do_challenge(negotiation_rpc rpc,
error_s err_s,
const std::string &resp_msg)
void server_negotiation::do_challenge(negotiation_rpc rpc, error_s err_s, const blob &resp_msg)
{
if (!err_s.is_ok() && err_s.code() != ERR_SASL_INCOMPLETE) {
dwarn_f("{}: negotiation failed, with err = {}, msg = {}",
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/server_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class server_negotiation : public negotiation
void on_initiate(negotiation_rpc rpc);
void on_challenge_resp(negotiation_rpc rpc);

void do_challenge(negotiation_rpc rpc, error_s err_s, const std::string &resp_msg);
void do_challenge(negotiation_rpc rpc, error_s err_s, const blob &resp_msg);

friend class server_negotiation_test;
};
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/test/client_negotiation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TEST_F(client_negotiation_test, on_recv_mechanisms)
for (const auto &test : tests) {
negotiation_response resp;
resp.status = test.resp_status;
resp.msg = test.resp_msg;
resp.msg = blob::create_from_bytes(test.resp_msg.data(), test.resp_msg.length());
on_recv_mechanism(resp);

ASSERT_EQ(get_selected_mechanism(), test.selected_mechanism);
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/test/server_negotiation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class server_negotiation_test : public testing::Test
{
auto request = make_unique<negotiation_request>();
request->status = status;
request->msg = msg;
request->msg = dsn::blob::create_from_bytes(msg.data(), msg.length());
return negotiation_rpc(std::move(request), RPC_NEGOTIATION);
}

Expand Down Expand Up @@ -84,7 +84,7 @@ TEST_F(server_negotiation_test, on_list_mechanisms)
on_list_mechanisms(rpc);

ASSERT_EQ(rpc.response().status, test.resp_status);
ASSERT_EQ(rpc.response().msg, test.resp_msg);
ASSERT_EQ(rpc.response().msg.to_string(), test.resp_msg);
ASSERT_EQ(get_negotiation_status(), test.nego_status);
}
}
Expand Down

0 comments on commit 8ec8ba6

Please sign in to comment.