From 289346e7dde74104461567ba82805b3b51bfdfd9 Mon Sep 17 00:00:00 2001 From: zhao liwei Date: Fri, 20 Nov 2020 15:16:03 +0800 Subject: [PATCH] feat(security): implement meta server access controller (#655) --- include/dsn/tool-api/network.h | 7 ++ include/dsn/utility/error_code.h | 1 + include/dsn/utility/strings.h | 6 ++ src/meta/meta_service.cpp | 16 +++ src/meta/meta_service.h | 5 + src/runtime/rpc/network.cpp | 7 ++ src/runtime/security/access_controller.cpp | 47 +++++++++ src/runtime/security/access_controller.h | 54 ++++++++++ .../security/meta_access_controller.cpp | 69 +++++++++++++ src/runtime/security/meta_access_controller.h | 40 ++++++++ src/runtime/security/sasl_wrapper.cpp | 21 ++++ src/runtime/security/sasl_wrapper.h | 6 ++ src/runtime/security/server_negotiation.cpp | 15 ++- src/runtime/security/server_negotiation.h | 2 +- .../test/meta_access_controller_test.cpp | 99 +++++++++++++++++++ src/runtime/test/server_negotiation_test.cpp | 24 +++++ src/utils/strings.cpp | 10 ++ src/utils/test/utils.cpp | 4 + 18 files changed, 430 insertions(+), 3 deletions(-) create mode 100644 src/runtime/security/access_controller.cpp create mode 100644 src/runtime/security/access_controller.h create mode 100644 src/runtime/security/meta_access_controller.cpp create mode 100644 src/runtime/security/meta_access_controller.h create mode 100644 src/runtime/test/meta_access_controller_test.cpp diff --git a/include/dsn/tool-api/network.h b/include/dsn/tool-api/network.h index c261b517ff..506ba81e93 100644 --- a/include/dsn/tool-api/network.h +++ b/include/dsn/tool-api/network.h @@ -244,6 +244,9 @@ class rpc_session : public ref_counter void set_negotiation_succeed(); bool is_negotiation_succeed() const; + void set_client_username(const std::string &user_name); + const std::string &get_client_username() const; + public: /// /// for subclass to implement receiving message @@ -328,6 +331,10 @@ class rpc_session : public ref_counter rpc_client_matcher *_matcher; std::atomic_int _delay_server_receive_ms; + + // _client_username is only valid if it is a server rpc_session. + // it represents the name of the corresponding client + std::string _client_username; }; // --------- inline implementation -------------- diff --git a/include/dsn/utility/error_code.h b/include/dsn/utility/error_code.h index ee0703643e..b6738db5b5 100644 --- a/include/dsn/utility/error_code.h +++ b/include/dsn/utility/error_code.h @@ -125,4 +125,5 @@ DEFINE_ERR_CODE(ERR_KRB5_INTERNAL) DEFINE_ERR_CODE(ERR_SASL_INTERNAL) DEFINE_ERR_CODE(ERR_SASL_INCOMPLETE) +DEFINE_ERR_CODE(ERR_ACL_DENY) } // namespace dsn diff --git a/include/dsn/utility/strings.h b/include/dsn/utility/strings.h index 92eb64cbd4..a347448289 100644 --- a/include/dsn/utility/strings.h +++ b/include/dsn/utility/strings.h @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace dsn { @@ -14,6 +15,11 @@ void split_args(const char *args, char splitter = ' ', bool keep_place_holder = false); +void split_args(const char *args, + /*out*/ std::unordered_set &sargs, + char splitter = ' ', + bool keep_place_holder = false); + void split_args(const char *args, /*out*/ std::list &sargs, char splitter = ' '); // kv_map sample (when item_splitter = ',' and kv_splitter = ':'): diff --git a/src/meta/meta_service.cpp b/src/meta/meta_service.cpp index d56b64a715..1a211be2e4 100644 --- a/src/meta/meta_service.cpp +++ b/src/meta/meta_service.cpp @@ -47,6 +47,7 @@ #include "meta/duplication/meta_duplication_service.h" #include "meta_split_service.h" #include "meta_bulk_load_service.h" +#include "runtime/security/access_controller.h" namespace dsn { namespace replication { @@ -76,6 +77,8 @@ meta_service::meta_service() "replica server disconnect count in the recent period"); _unalive_nodes_count.init_app_counter( "eon.meta_service", "unalive_nodes", COUNTER_TYPE_NUMBER, "current count of unalive nodes"); + + _access_controller = security::create_meta_access_controller(); } meta_service::~meta_service() @@ -120,6 +123,12 @@ int meta_service::check_leader(TRpcHolder rpc, rpc_address *forward_address) template bool meta_service::check_status(TRpcHolder rpc, rpc_address *forward_address) { + if (!_access_controller->allowed(rpc.dsn_request())) { + rpc.response().err = ERR_ACL_DENY; + ddebug("reject request with ERR_ACL_DENY"); + return false; + } + int result = check_leader(rpc, forward_address); if (result == 0) return false; @@ -141,6 +150,13 @@ bool meta_service::check_status(TRpcHolder rpc, rpc_address *forward_address) template bool meta_service::check_status_with_msg(message_ex *req, TRespType &response_struct) { + if (!_access_controller->allowed(req)) { + ddebug("reject request with ERR_ACL_DENY"); + response_struct.err = ERR_ACL_DENY; + reply(req, response_struct); + return false; + } + int result = check_leader(req, nullptr); if (result == 0) { return false; diff --git a/src/meta/meta_service.h b/src/meta/meta_service.h index 900ab55fc0..cc98bae733 100644 --- a/src/meta/meta_service.h +++ b/src/meta/meta_service.h @@ -49,6 +49,9 @@ #include "block_service/block_service_manager.h" namespace dsn { +namespace security { +class access_controller; +} // namespace security namespace replication { class server_state; @@ -261,6 +264,8 @@ class meta_service : public serverlet perf_counter_wrapper _unalive_nodes_count; dsn::task_tracker _tracker; + + std::unique_ptr _access_controller; }; } // namespace replication diff --git a/src/runtime/rpc/network.cpp b/src/runtime/rpc/network.cpp index 53413d1e06..3479b7ce45 100644 --- a/src/runtime/rpc/network.cpp +++ b/src/runtime/rpc/network.cpp @@ -495,6 +495,13 @@ bool rpc_session::is_negotiation_succeed() const } } +void rpc_session::set_client_username(const std::string &user_name) +{ + _client_username = user_name; +} + +const std::string &rpc_session::get_client_username() const { return _client_username; } + //////////////////////////////////////////////////////////////////////////////////////////////// network::network(rpc_engine *srv, network *inner_provider) : _engine(srv), _client_hdr_format(NET_HDR_DSN), _unknown_msg_header_format(NET_HDR_INVALID) diff --git a/src/runtime/security/access_controller.cpp b/src/runtime/security/access_controller.cpp new file mode 100644 index 0000000000..1ea8053795 --- /dev/null +++ b/src/runtime/security/access_controller.cpp @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "access_controller.h" + +#include +#include +#include +#include "meta_access_controller.h" + +namespace dsn { +namespace security { +DSN_DEFINE_bool("security", enable_acl, false, "whether enable access controller or not"); +DSN_DEFINE_string("security", super_users, "", "super user for access controller"); + +access_controller::access_controller() { utils::split_args(FLAGS_super_users, _super_users, ','); } + +access_controller::~access_controller() {} + +bool access_controller::pre_check(const std::string &user_name) +{ + if (!FLAGS_enable_acl || _super_users.find(user_name) != _super_users.end()) { + return true; + } + return false; +} + +std::unique_ptr create_meta_access_controller() +{ + return make_unique(); +} +} // namespace security +} // namespace dsn diff --git a/src/runtime/security/access_controller.h b/src/runtime/security/access_controller.h new file mode 100644 index 0000000000..517161fd2d --- /dev/null +++ b/src/runtime/security/access_controller.h @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +namespace dsn { +class message_ex; +namespace security { + +class access_controller +{ +public: + access_controller(); + virtual ~access_controller() = 0; + + /** + * reset the access controller + * acls - the new acls to reset + **/ + virtual void reset(const std::string &acls){}; + + /** + * check if the message received is allowd to do something. + * msg - the message received + **/ + virtual bool allowed(message_ex *msg) = 0; + +protected: + bool pre_check(const std::string &user_name); + friend class meta_access_controller_test; + + std::unordered_set _super_users; +}; + +std::unique_ptr create_meta_access_controller(); +} // namespace security +} // namespace dsn diff --git a/src/runtime/security/meta_access_controller.cpp b/src/runtime/security/meta_access_controller.cpp new file mode 100644 index 0000000000..7c8c03b077 --- /dev/null +++ b/src/runtime/security/meta_access_controller.cpp @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "meta_access_controller.h" + +#include +#include +#include +#include + +namespace dsn { +namespace security { +DSN_DEFINE_string("security", + meta_acl_rpc_allow_list, + "", + "allowed list of rpc codes for meta_access_controller"); + +meta_access_controller::meta_access_controller() +{ + // MetaServer serves the allow-list RPC from all users. RPCs unincluded are accessible to only + // superusers. + if (strlen(FLAGS_meta_acl_rpc_allow_list) == 0) { + register_allowed_list("RPC_CM_LIST_APPS"); + register_allowed_list("RPC_CM_LIST_NODES"); + register_allowed_list("RPC_CM_CLUSTER_INFO"); + register_allowed_list("RPC_CM_QUERY_PARTITION_CONFIG_BY_INDEX"); + } else { + std::vector rpc_code_white_list; + utils::split_args(FLAGS_meta_acl_rpc_allow_list, rpc_code_white_list, ','); + for (const auto &rpc_code : rpc_code_white_list) { + register_allowed_list(rpc_code); + } + } +} + +bool meta_access_controller::allowed(message_ex *msg) +{ + if (pre_check(msg->io_session->get_client_username()) || + _allowed_rpc_code_list.find(msg->rpc_code().code()) != _allowed_rpc_code_list.end()) { + return true; + } + return false; +} + +void meta_access_controller::register_allowed_list(const std::string &rpc_code) +{ + auto code = task_code::try_get(rpc_code, TASK_CODE_INVALID); + dassert_f(code != TASK_CODE_INVALID, + "invalid task code({}) in rpc_code_white_list of security section", + rpc_code); + + _allowed_rpc_code_list.insert(code); +} +} // namespace security +} // namespace dsn diff --git a/src/runtime/security/meta_access_controller.h b/src/runtime/security/meta_access_controller.h new file mode 100644 index 0000000000..8c41332b1b --- /dev/null +++ b/src/runtime/security/meta_access_controller.h @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "access_controller.h" + +#include + +namespace dsn { +class message_ex; +namespace security { + +class meta_access_controller : public access_controller +{ +public: + meta_access_controller(); + bool allowed(message_ex *msg) override; + +private: + void register_allowed_list(const std::string &rpc_code); + + std::unordered_set _allowed_rpc_code_list; +}; +} // namespace security +} // namespace dsn diff --git a/src/runtime/security/sasl_wrapper.cpp b/src/runtime/security/sasl_wrapper.cpp index 91cc05079d..e8f855cf7e 100644 --- a/src/runtime/security/sasl_wrapper.cpp +++ b/src/runtime/security/sasl_wrapper.cpp @@ -20,6 +20,7 @@ #include "sasl_client_wrapper.h" #include +#include namespace dsn { namespace security { @@ -39,6 +40,26 @@ sasl_wrapper::~sasl_wrapper() } } +error_s sasl_wrapper::retrive_username(std::string &output) +{ + FAIL_POINT_INJECT_F("sasl_wrapper_retrive_username", [](dsn::string_view str) { + error_code err = error_code::try_get(str.data(), ERR_UNKNOWN); + return error_s::make(err); + }); + + // retrive username from _conn. + // If this is a sasl server, it gets the name of the corresponding sasl client. + // But if this is a sasl client, it gets the name of itself + char *username = nullptr; + error_s err_s = wrap_error(sasl_getprop(_conn, SASL_USERNAME, (const void **)&username)); + if (err_s.is_ok()) { + output = username; + output = output.substr(0, output.find_last_of('@')); + output = output.substr(0, output.find_first_of('/')); + } + return err_s; +} + error_s sasl_wrapper::wrap_error(int sasl_err) { error_s ret; diff --git a/src/runtime/security/sasl_wrapper.h b/src/runtime/security/sasl_wrapper.h index d40b5e7636..57bc7b231d 100644 --- a/src/runtime/security/sasl_wrapper.h +++ b/src/runtime/security/sasl_wrapper.h @@ -31,6 +31,12 @@ class sasl_wrapper virtual error_s init() = 0; virtual error_s start(const std::string &mechanism, const blob &input, blob &output) = 0; virtual error_s step(const blob &input, blob &output) = 0; + /** + * retrive username from sasl connection. + * If this is a sasl server, it gets the name of the corresponding sasl client. + * But if this is a sasl client, it gets the name of itself + **/ + error_s retrive_username(/*out*/ std::string &output); protected: sasl_wrapper() = default; diff --git a/src/runtime/security/server_negotiation.cpp b/src/runtime/security/server_negotiation.cpp index 3f02a3035a..a8fae1b4d9 100644 --- a/src/runtime/security/server_negotiation.cpp +++ b/src/runtime/security/server_negotiation.cpp @@ -140,7 +140,17 @@ void server_negotiation::do_challenge(negotiation_rpc rpc, error_s err_s, const } if (err_s.is_ok()) { - succ_negotiation(rpc); + std::string user_name; + auto retrive_err = _sasl->retrive_username(user_name); + if (retrive_err.is_ok()) { + succ_negotiation(rpc, user_name); + } else { + dwarn_f("{}: retrive user name failed: with err = {}, msg = {}", + _name, + retrive_err.code().to_string(), + retrive_err.description()); + fail_negotiation(); + } } else { negotiation_response &challenge = rpc.response(); _status = challenge.status = negotiation_status::type::SASL_CHALLENGE; @@ -148,10 +158,11 @@ void server_negotiation::do_challenge(negotiation_rpc rpc, error_s err_s, const } } -void server_negotiation::succ_negotiation(negotiation_rpc rpc) +void server_negotiation::succ_negotiation(negotiation_rpc rpc, const std::string &user_name) { negotiation_response &response = rpc.response(); _status = response.status = negotiation_status::type::SASL_SUCC; + _session->set_client_username(user_name); _session->set_negotiation_succeed(); ddebug_f("{}: negotiation succeed", _name); } diff --git a/src/runtime/security/server_negotiation.h b/src/runtime/security/server_negotiation.h index 56db2a6348..5f25f12d65 100644 --- a/src/runtime/security/server_negotiation.h +++ b/src/runtime/security/server_negotiation.h @@ -40,7 +40,7 @@ class server_negotiation : public negotiation void on_challenge_resp(negotiation_rpc rpc); void do_challenge(negotiation_rpc rpc, error_s err_s, const blob &resp_msg); - void succ_negotiation(negotiation_rpc rpc); + void succ_negotiation(negotiation_rpc rpc, const std::string &user_name); friend class server_negotiation_test; }; diff --git a/src/runtime/test/meta_access_controller_test.cpp b/src/runtime/test/meta_access_controller_test.cpp new file mode 100644 index 0000000000..19215d7a86 --- /dev/null +++ b/src/runtime/test/meta_access_controller_test.cpp @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include "runtime/security/access_controller.h" + +namespace dsn { +namespace security { +DSN_DECLARE_bool(enable_acl); + +class meta_access_controller_test : public testing::Test +{ +public: + meta_access_controller_test() { _meta_access_controller = create_meta_access_controller(); } + + void set_super_user(const std::string &super_user) + { + _meta_access_controller->_super_users.insert(super_user); + } + + bool pre_check(const std::string &user_name) + { + return _meta_access_controller->pre_check(user_name); + } + + bool allowed(dsn::message_ex *msg) { return _meta_access_controller->allowed(msg); } + + std::unique_ptr _meta_access_controller; +}; + +TEST_F(meta_access_controller_test, pre_check) +{ + const std::string SUPER_USER_NAME = "super_user"; + struct + { + bool enable_acl; + std::string user_name; + bool result; + } tests[] = {{true, "not_super_user", false}, + {false, "not_super_user", true}, + {true, SUPER_USER_NAME, true}}; + + bool origin_enable_acl = FLAGS_enable_acl; + set_super_user(SUPER_USER_NAME); + + for (const auto &test : tests) { + FLAGS_enable_acl = test.enable_acl; + ASSERT_EQ(pre_check(test.user_name), test.result); + } + + FLAGS_enable_acl = origin_enable_acl; +} + +TEST_F(meta_access_controller_test, allowed) +{ + struct + { + task_code rpc_code; + bool result; + } tests[] = {{RPC_CM_LIST_APPS, true}, + {RPC_CM_LIST_NODES, true}, + {RPC_CM_CLUSTER_INFO, true}, + {RPC_CM_QUERY_PARTITION_CONFIG_BY_INDEX, true}, + {RPC_CM_START_RECOVERY, false}}; + + bool origin_enable_acl = FLAGS_enable_acl; + FLAGS_enable_acl = true; + + std::unique_ptr sim_net( + new tools::sim_network_provider(nullptr, nullptr)); + auto sim_session = sim_net->create_client_session(rpc_address("localhost", 10086)); + for (const auto &test : tests) { + dsn::message_ptr msg = message_ex::create_request(test.rpc_code); + msg->io_session = sim_session; + + ASSERT_EQ(allowed(msg), test.result); + } + + FLAGS_enable_acl = origin_enable_acl; +} +} // namespace security +} // namespace dsn diff --git a/src/runtime/test/server_negotiation_test.cpp b/src/runtime/test/server_negotiation_test.cpp index c75dbbfbc3..394a605f07 100644 --- a/src/runtime/test/server_negotiation_test.cpp +++ b/src/runtime/test/server_negotiation_test.cpp @@ -143,23 +143,33 @@ TEST_F(server_negotiation_test, on_initiate) struct { std::string sasl_start_result; + std::string sasl_retrive_username_result; negotiation_status::type req_status; negotiation_status::type resp_status; negotiation_status::type nego_status; } tests[] = { {"ERR_TIMEOUT", + "ERR_OK", negotiation_status::type::SASL_INITIATE, negotiation_status::type::INVALID, negotiation_status::type::SASL_AUTH_FAIL}, {"ERR_OK", + "ERR_OK", negotiation_status::type::SASL_SELECT_MECHANISMS, negotiation_status::type::INVALID, negotiation_status::type::SASL_AUTH_FAIL}, + {"ERR_OK", + "ERR_TIMEOUT", + negotiation_status::type::SASL_INITIATE, + negotiation_status::type::INVALID, + negotiation_status::type::SASL_AUTH_FAIL}, {"ERR_SASL_INCOMPLETE", + "ERR_OK", negotiation_status::type::SASL_INITIATE, negotiation_status::type::SASL_CHALLENGE, negotiation_status::type::SASL_CHALLENGE}, {"ERR_OK", + "ERR_OK", negotiation_status::type::SASL_INITIATE, negotiation_status::type::SASL_SUCC, negotiation_status::type::SASL_SUCC}, @@ -170,6 +180,8 @@ TEST_F(server_negotiation_test, on_initiate) for (const auto &test : tests) { fail::setup(); fail::cfg("sasl_server_wrapper_start", "return(" + test.sasl_start_result + ")"); + fail::cfg("sasl_wrapper_retrive_username", + "return(" + test.sasl_retrive_username_result + ")"); auto rpc = create_negotiation_rpc(test.req_status, ""); on_initiate(rpc); @@ -186,22 +198,32 @@ TEST_F(server_negotiation_test, on_challenge_resp) struct { std::string sasl_step_result; + std::string sasl_retrive_username_result; negotiation_status::type req_status; negotiation_status::type resp_status; negotiation_status::type nego_status; } tests[] = {{"ERR_TIMEOUT", + "ERR_OK", negotiation_status::type::SASL_CHALLENGE_RESP, negotiation_status::type::INVALID, negotiation_status::type::SASL_AUTH_FAIL}, {"ERR_OK", + "ERR_OK", negotiation_status::type::SASL_SELECT_MECHANISMS, negotiation_status::type::INVALID, negotiation_status::type::SASL_AUTH_FAIL}, + {"ERR_OK", + "ERR_TIMEOUT", + negotiation_status::type::SASL_CHALLENGE_RESP, + negotiation_status::type::INVALID, + negotiation_status::type::SASL_AUTH_FAIL}, {"ERR_SASL_INCOMPLETE", + "ERR_OK", negotiation_status::type::SASL_CHALLENGE_RESP, negotiation_status::type::SASL_CHALLENGE, negotiation_status::type::SASL_CHALLENGE}, {"ERR_OK", + "ERR_OK", negotiation_status::type::SASL_CHALLENGE_RESP, negotiation_status::type::SASL_SUCC, negotiation_status::type::SASL_SUCC}}; @@ -211,6 +233,8 @@ TEST_F(server_negotiation_test, on_challenge_resp) for (const auto &test : tests) { fail::setup(); fail::cfg("sasl_server_wrapper_step", "return(" + test.sasl_step_result + ")"); + fail::cfg("sasl_wrapper_retrive_username", + "return(" + test.sasl_retrive_username_result + ")"); auto rpc = create_negotiation_rpc(test.req_status, ""); on_challenge_resp(rpc); diff --git a/src/utils/strings.cpp b/src/utils/strings.cpp index 83f6a4dbbd..00a9148a65 100644 --- a/src/utils/strings.cpp +++ b/src/utils/strings.cpp @@ -54,6 +54,16 @@ void split_args(const char *args, } } +void split_args(const char *args, + /*out*/ std::unordered_set &sargs, + char splitter, + bool keep_place_holder) +{ + std::vector sargs_vec; + split_args(args, sargs_vec, splitter, keep_place_holder); + sargs.insert(sargs_vec.begin(), sargs_vec.end()); +} + void split_args(const char *args, /*out*/ std::list &sargs, char splitter) { sargs.clear(); diff --git a/src/utils/test/utils.cpp b/src/utils/test/utils.cpp index 2e10d20c0b..6dc5c9d9ef 100644 --- a/src/utils/test/utils.cpp +++ b/src/utils/test/utils.cpp @@ -102,6 +102,10 @@ TEST(core, split_args) EXPECT_EQ(*it++, "a"); EXPECT_EQ(*it++, "b"); EXPECT_EQ(*it++, "c"); + + std::unordered_set sargs_set; + dsn::utils::split_args(value.c_str(), sargs_set, ','); + EXPECT_EQ(sargs_set.size(), 3); } TEST(core, split_args_keep_place_holder)