Skip to content
This repository has been archived by the owner on Jun 23, 2022. It is now read-only.

refactor(security): use blob instead of std::string as the type for msg member of negotiationi_request #622

Merged
merged 3 commits into from
Sep 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/runtime/security/client_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void client_negotiation::on_recv_mechanisms(const negotiation_response &resp)

std::string match_mechanism;
std::vector<std::string> server_support_mechanisms;
std::string resp_string = resp.msg;
std::string resp_string = resp.msg.to_string();
utils::split_args(resp_string.c_str(), server_support_mechanisms, ',');

for (const std::string &server_support_mechanism : server_support_mechanisms) {
Expand Down Expand Up @@ -125,8 +125,8 @@ void client_negotiation::on_mechanism_selected(const negotiation_response &resp)
}

// start client sasl, and send `SASL_INITIATE` to `server_negotiation` if everything is ok
std::string start_output;
err_s = _sasl->start(_selected_mechanism, "", start_output);
blob start_output;
err_s = _sasl->start(_selected_mechanism, blob(), start_output);
if (err_s.is_ok() || ERR_SASL_INCOMPLETE == err_s.code()) {
_status = negotiation_status::type::SASL_INITIATE;
send(_status, std::move(start_output));
Expand All @@ -142,7 +142,7 @@ void client_negotiation::on_mechanism_selected(const negotiation_response &resp)
void client_negotiation::on_challenge(const negotiation_response &challenge)
{
if (challenge.status == negotiation_status::type::SASL_CHALLENGE) {
std::string response_msg;
blob response_msg;
auto err = _sasl->step(challenge.msg, response_msg);
if (!err.is_ok() && err.code() != ERR_SASL_INCOMPLETE) {
dwarn_f("{}: negotiation failed, reason = {}", _name, err.description());
Expand All @@ -169,10 +169,10 @@ void client_negotiation::select_mechanism(const std::string &mechanism)
_selected_mechanism = mechanism;
_status = negotiation_status::type::SASL_SELECT_MECHANISMS;

send(_status, std::move(mechanism));
send(_status, blob::create_from_bytes(mechanism.data(), mechanism.length()));
}

void client_negotiation::send(negotiation_status::type status, const std::string &&msg)
void client_negotiation::send(negotiation_status::type status, const blob &msg)
{
auto req = dsn::make_unique<negotiation_request>();
req->status = status;
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/client_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class client_negotiation : public negotiation

void list_mechanisms();
void select_mechanism(const std::string &mechanism);
void send(negotiation_status::type status, const std::string &&msg = "");
void send(negotiation_status::type status, const blob &msg = blob());
void succ_negotiation();

friend class client_negotiation_test;
Expand Down
12 changes: 5 additions & 7 deletions src/runtime/security/sasl_client_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ error_s sasl_client_wrapper::init()
return wrap_error(sasl_err);
}

error_s sasl_client_wrapper::start(const std::string &mechanism,
const std::string &input,
std::string &output)
error_s sasl_client_wrapper::start(const std::string &mechanism, const blob &input, blob &output)
{
FAIL_POINT_INJECT_F("sasl_client_wrapper_start", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
Expand All @@ -53,11 +51,11 @@ error_s sasl_client_wrapper::start(const std::string &mechanism,
int sasl_err =
sasl_client_start(_conn, mechanism.c_str(), nullptr, &msg, &msg_len, &client_mech);

output.assign(msg, msg_len);
output = blob::create_from_bytes(msg, msg_len);
return wrap_error(sasl_err);
}

error_s sasl_client_wrapper::step(const std::string &input, std::string &output)
error_s sasl_client_wrapper::step(const blob &input, blob &output)
{
FAIL_POINT_INJECT_F("sasl_client_wrapper_step", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
Expand All @@ -66,9 +64,9 @@ error_s sasl_client_wrapper::step(const std::string &input, std::string &output)

const char *msg = nullptr;
unsigned msg_len = 0;
int sasl_err = sasl_client_step(_conn, input.c_str(), input.length(), nullptr, &msg, &msg_len);
int sasl_err = sasl_client_step(_conn, input.data(), input.length(), nullptr, &msg, &msg_len);

output.assign(msg, msg_len);
output = blob::create_from_bytes(msg, msg_len);
return wrap_error(sasl_err);
}
} // namespace security
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/sasl_client_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class sasl_client_wrapper : public sasl_wrapper
~sasl_client_wrapper() = default;

error_s init();
error_s start(const std::string &mechanism, const std::string &input, std::string &output);
error_s step(const std::string &input, std::string &output);
error_s start(const std::string &mechanism, const blob &input, blob &output);
error_s step(const blob &input, blob &output);
};
} // namespace security
} // namespace dsn
14 changes: 6 additions & 8 deletions src/runtime/security/sasl_server_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ error_s sasl_server_wrapper::init()
return wrap_error(sasl_err);
}

error_s sasl_server_wrapper::start(const std::string &mechanism,
const std::string &input,
std::string &output)
error_s sasl_server_wrapper::start(const std::string &mechanism, const blob &input, blob &output)
{
FAIL_POINT_INJECT_F("sasl_server_wrapper_start", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
Expand All @@ -50,13 +48,13 @@ error_s sasl_server_wrapper::start(const std::string &mechanism,
const char *msg = nullptr;
unsigned msg_len = 0;
int sasl_err =
sasl_server_start(_conn, mechanism.c_str(), input.c_str(), input.length(), &msg, &msg_len);
sasl_server_start(_conn, mechanism.c_str(), input.data(), input.length(), &msg, &msg_len);

output.assign(msg, msg_len);
output = blob::create_from_bytes(msg, msg_len);
return wrap_error(sasl_err);
}

error_s sasl_server_wrapper::step(const std::string &input, std::string &output)
error_s sasl_server_wrapper::step(const blob &input, blob &output)
{
FAIL_POINT_INJECT_F("sasl_server_wrapper_step", [](dsn::string_view str) {
error_code err = error_code::try_get(str.data(), ERR_UNKNOWN);
Expand All @@ -65,9 +63,9 @@ error_s sasl_server_wrapper::step(const std::string &input, std::string &output)

const char *msg = nullptr;
unsigned msg_len = 0;
int sasl_err = sasl_server_step(_conn, input.c_str(), input.length(), &msg, &msg_len);
int sasl_err = sasl_server_step(_conn, input.data(), input.length(), &msg, &msg_len);

output.assign(msg, msg_len);
output = blob::create_from_bytes(msg, msg_len);
return wrap_error(sasl_err);
}
} // namespace security
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/sasl_server_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class sasl_server_wrapper : public sasl_wrapper
~sasl_server_wrapper() = default;

error_s init();
error_s start(const std::string &mechanism, const std::string &input, std::string &output);
error_s step(const std::string &input, std::string &output);
error_s start(const std::string &mechanism, const blob &input, blob &output);
error_s step(const blob &input, blob &output);
};
} // namespace security
} // namespace dsn
5 changes: 2 additions & 3 deletions src/runtime/security/sasl_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ class sasl_wrapper
virtual ~sasl_wrapper();

virtual error_s init() = 0;
virtual error_s
start(const std::string &mechanism, const std::string &input, std::string &output) = 0;
virtual error_s step(const std::string &input, std::string &output) = 0;
virtual error_s start(const std::string &mechanism, const blob &input, blob &output) = 0;
virtual error_s step(const blob &input, blob &output) = 0;

protected:
sasl_wrapper() = default;
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/security/security.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ enum negotiation_status {

struct negotiation_request {
1: negotiation_status status;
2: string msg;
2: dsn.blob msg;
}

struct negotiation_response {
1: negotiation_status status;
2: string msg;
2: dsn.blob msg;
}
20 changes: 10 additions & 10 deletions src/runtime/security/security_types.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions src/runtime/security/security_types.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 5 additions & 7 deletions src/runtime/security/server_negotiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void server_negotiation::on_list_mechanisms(negotiation_rpc rpc)
std::string mech_list = boost::join(supported_mechanisms, ",");
negotiation_response &response = rpc.response();
_status = response.status = negotiation_status::type::SASL_LIST_MECHANISMS_RESP;
response.msg = std::move(mech_list);
response.msg = blob::create_from_bytes(mech_list.data(), mech_list.length());
}

void server_negotiation::on_select_mechanism(negotiation_rpc rpc)
Expand All @@ -81,7 +81,7 @@ void server_negotiation::on_select_mechanism(negotiation_rpc rpc)
return;
}

_selected_mechanism = request.msg;
_selected_mechanism = request.msg.to_string();
if (supported_mechanisms.find(_selected_mechanism) == supported_mechanisms.end()) {
dwarn_f("the mechanism of {} is not supported", _selected_mechanism);
fail_negotiation();
Expand Down Expand Up @@ -110,7 +110,7 @@ void server_negotiation::on_initiate(negotiation_rpc rpc)
return;
}

std::string start_output;
blob start_output;
error_s err_s = _sasl->start(_selected_mechanism, request.msg, start_output);
return do_challenge(rpc, err_s, start_output);
}
Expand All @@ -123,14 +123,12 @@ void server_negotiation::on_challenge_resp(negotiation_rpc rpc)
return;
}

std::string resp_msg;
blob resp_msg;
error_s err_s = _sasl->step(request.msg, resp_msg);
return do_challenge(rpc, err_s, resp_msg);
}

void server_negotiation::do_challenge(negotiation_rpc rpc,
error_s err_s,
const std::string &resp_msg)
void server_negotiation::do_challenge(negotiation_rpc rpc, error_s err_s, const blob &resp_msg)
{
if (!err_s.is_ok() && err_s.code() != ERR_SASL_INCOMPLETE) {
dwarn_f("{}: negotiation failed, with err = {}, msg = {}",
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/security/server_negotiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class server_negotiation : public negotiation
void on_initiate(negotiation_rpc rpc);
void on_challenge_resp(negotiation_rpc rpc);

void do_challenge(negotiation_rpc rpc, error_s err_s, const std::string &resp_msg);
void do_challenge(negotiation_rpc rpc, error_s err_s, const blob &resp_msg);

friend class server_negotiation_test;
};
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/test/client_negotiation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TEST_F(client_negotiation_test, on_recv_mechanisms)
for (const auto &test : tests) {
negotiation_response resp;
resp.status = test.resp_status;
resp.msg = test.resp_msg;
resp.msg = blob::create_from_bytes(test.resp_msg.data(), test.resp_msg.length());
on_recv_mechanism(resp);

ASSERT_EQ(get_selected_mechanism(), test.selected_mechanism);
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/test/server_negotiation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class server_negotiation_test : public testing::Test
{
auto request = make_unique<negotiation_request>();
request->status = status;
request->msg = msg;
request->msg = dsn::blob::create_from_bytes(msg.data(), msg.length());
return negotiation_rpc(std::move(request), RPC_NEGOTIATION);
}

Expand Down Expand Up @@ -84,7 +84,7 @@ TEST_F(server_negotiation_test, on_list_mechanisms)
on_list_mechanisms(rpc);

ASSERT_EQ(rpc.response().status, test.resp_status);
ASSERT_EQ(rpc.response().msg, test.resp_msg);
ASSERT_EQ(rpc.response().msg.to_string(), test.resp_msg);
ASSERT_EQ(get_negotiation_status(), test.nego_status);
}
}
Expand Down