From 77db7c4c514e2abbc7b5d5bef546358289e6c6a5 Mon Sep 17 00:00:00 2001 From: VladSemenyuk Date: Tue, 2 Mar 2021 08:57:24 +0200 Subject: [PATCH 1/3] Fix NACK sending for Start RPC service with invalid data in bson payload. Fix memory leak with session after Start RPC service NACK. Fix OnServiceUpdate sending after Start RPC service with invalid data in bson payload for protected mode. Fix NACK sending for Start any service with non-existent session id. --- .../src/connection_handler_impl.cc | 20 +- .../test/connection_handler_impl_test.cc | 8 +- .../protocol_handler/protocol_handler.h | 2 +- .../protocol_handler/session_observer.h | 7 +- .../protocol_handler/mock_protocol_handler.h | 2 +- .../protocol_handler/protocol_handler_impl.h | 22 +- .../protocol_handler/src/handshake_handler.cc | 6 +- .../src/protocol_handler_impl.cc | 157 ++++++++----- .../test/protocol_handler_tm_test.cc | 217 ++++++++++-------- 9 files changed, 271 insertions(+), 170 deletions(-) diff --git a/src/components/connection_handler/src/connection_handler_impl.cc b/src/components/connection_handler/src/connection_handler_impl.cc index da269721eb4..2bc0aeb2e9e 100644 --- a/src/components/connection_handler/src/connection_handler_impl.cc +++ b/src/components/connection_handler/src/connection_handler_impl.cc @@ -547,11 +547,16 @@ void ConnectionHandlerImpl::OnSessionStartedCallback( session_key, service_type, params); - } else { + } +#ifdef BUILD_TESTS + else { + // FIXME (VSemenyuk): This code is only used in unit tests, so should be + // removed. ConnectionHandler unit tests should be fixed. if (protocol_handler_) { protocol_handler_->NotifySessionStarted(context, rejected_params); } } +#endif } void ConnectionHandlerImpl::NotifyServiceStartedResult( @@ -589,17 +594,20 @@ void ConnectionHandlerImpl::NotifyServiceStartedResult( if (!result) { SDL_LOG_WARN("Service starting forbidden by connection_handler_observer"); + context.is_start_session_failed_ = true; + } + + if (protocol_handler_) { + protocol_handler_->NotifySessionStarted(context, rejected_params, reason); + } + + if (context.is_start_session_failed_) { if (protocol_handler::kRpc == context.service_type_) { connection->RemoveSession(context.new_session_id_); } else { connection->RemoveService(context.initial_session_id_, context.service_type_); } - context.new_session_id_ = 0; - } - - if (protocol_handler_ != NULL) { - protocol_handler_->NotifySessionStarted(context, rejected_params, reason); } } diff --git a/src/components/connection_handler/test/connection_handler_impl_test.cc b/src/components/connection_handler/test/connection_handler_impl_test.cc index 1e66454ce78..f592915e0e6 100644 --- a/src/components/connection_handler/test/connection_handler_impl_test.cc +++ b/src/components/connection_handler/test/connection_handler_impl_test.cc @@ -1536,7 +1536,7 @@ TEST_F(ConnectionHandlerTest, ServiceStarted_Video_FAILURE) { PROTECTION_OFF, dummy_params); - EXPECT_EQ(0u, out_context_.new_session_id_); + EXPECT_TRUE(out_context_.is_start_session_failed_); } /* @@ -1637,8 +1637,10 @@ TEST_F(ConnectionHandlerTest, ServiceStarted_Video_Multiple) { PROTECTION_OFF, dummy_params); - EXPECT_NE(0u, new_context_first.new_session_id_); // result is positive - EXPECT_EQ(0u, new_context_second.new_session_id_); // result is negative + EXPECT_FALSE( + new_context_first.is_start_session_failed_); // result is positive + EXPECT_TRUE( + new_context_second.is_start_session_failed_); // result is negative } TEST_F(ConnectionHandlerTest, diff --git a/src/components/include/protocol_handler/protocol_handler.h b/src/components/include/protocol_handler/protocol_handler.h index b9cc2e97006..2cea958bae4 100644 --- a/src/components/include/protocol_handler/protocol_handler.h +++ b/src/components/include/protocol_handler/protocol_handler.h @@ -138,7 +138,7 @@ class ProtocolHandler { * generated_session_id is 0. */ virtual void NotifySessionStarted( - const SessionContext& context, + SessionContext& context, std::vector& rejected_params, const std::string err_reason = std::string()) = 0; diff --git a/src/components/include/protocol_handler/session_observer.h b/src/components/include/protocol_handler/session_observer.h index 593ce8408c8..ca12f4b6ad2 100644 --- a/src/components/include/protocol_handler/session_observer.h +++ b/src/components/include/protocol_handler/session_observer.h @@ -69,6 +69,7 @@ struct SessionContext { uint32_t hash_id_; bool is_protected_; bool is_new_service_; + bool is_start_session_failed_; /** * @brief Constructor @@ -81,7 +82,8 @@ struct SessionContext { , service_type_(protocol_handler::kInvalidServiceType) , hash_id_(0) , is_protected_(false) - , is_new_service_(false) {} + , is_new_service_(false) + , is_start_session_failed_(false) {} /** * @brief Constructor @@ -111,7 +113,8 @@ struct SessionContext { , service_type_(service_type) , hash_id_(hash_id) , is_protected_(is_protected) - , is_new_service_(false) {} + , is_new_service_(false) + , is_start_session_failed_(false) {} }; /** diff --git a/src/components/include/test/protocol_handler/mock_protocol_handler.h b/src/components/include/test/protocol_handler/mock_protocol_handler.h index da71d9bc90d..b9a54ef73f3 100644 --- a/src/components/include/test/protocol_handler/mock_protocol_handler.h +++ b/src/components/include/test/protocol_handler/mock_protocol_handler.h @@ -65,7 +65,7 @@ class MockProtocolHandler : public ::protocol_handler::ProtocolHandler { const ::protocol_handler::ProtocolHandlerSettings&()); MOCK_METHOD0(get_session_observer, protocol_handler::SessionObserver&()); MOCK_METHOD3(NotifySessionStarted, - void(const ::protocol_handler::SessionContext& context, + void(::protocol_handler::SessionContext& context, std::vector& rejected_params, const std::string err_reason)); MOCK_METHOD0(NotifyOnGetSystemTimeFailed, void()); diff --git a/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h b/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h index c75b1f272cb..b036735216d 100644 --- a/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h +++ b/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h @@ -375,12 +375,15 @@ class ProtocolHandlerImpl * \param protocol_version Version of protocol used for communication * \param service_type Type of session: RPC or BULK Data. RPC by default * \param reason String stating the reason for the rejecting the start service + * \param full_version full protocol version (major.minor.patch) used by the + * mobile proxy */ void SendStartSessionNAck(ConnectionID connection_id, uint8_t session_id, uint8_t protocol_version, uint8_t service_type, - const std::string& reason); + const std::string& reason, + utils::SemanticVersion& full_version); /** * \brief Sends fail of starting session to mobile application @@ -390,13 +393,16 @@ class ProtocolHandlerImpl * \param service_type Type of session: RPC or BULK Data. RPC by default * \param rejected_params List of rejected params to send in payload * \param reason String stating the reason for the rejecting the start service + * \param full_version full protocol version (major.minor.patch) used by the + * mobile proxy */ void SendStartSessionNAck(ConnectionID connection_id, uint8_t session_id, uint8_t protocol_version, uint8_t service_type, std::vector& rejectedParams, - const std::string& reason); + const std::string& reason, + utils::SemanticVersion& full_version); /** * \brief Sends acknowledgement of end session/service to mobile application @@ -456,7 +462,7 @@ class ProtocolHandlerImpl * Only valid when generated_session_id is 0. Note, even if * generated_session_id is 0, the list may be empty. */ - void NotifySessionStarted(const SessionContext& context, + void NotifySessionStarted(SessionContext& context, std::vector& rejected_params, const std::string err_reason) OVERRIDE; @@ -741,6 +747,16 @@ class ProtocolHandlerImpl void WriteProtocolVehicleData( BsonObject& params, const connection_handler::ProtocolVehicleData& data); + /** + * \brief Parces full protocol version from start service message headers bson + * \param full_version full protocol version (major.minor.patch) used by the + * mobile proxy + * \param packet Sart service message + * \return true if version successfully parsed, otherwise false + */ + bool ParseFullVersion(utils::SemanticVersion& full_version, + const ProtocolFramePtr& packet) const; + const ProtocolHandlerSettings& settings_; /** diff --git a/src/components/protocol_handler/src/handshake_handler.cc b/src/components/protocol_handler/src/handshake_handler.cc index 4d306fd3304..78f324e0ae1 100644 --- a/src/components/protocol_handler/src/handshake_handler.cc +++ b/src/components/protocol_handler/src/handshake_handler.cc @@ -232,7 +232,8 @@ void HandshakeHandler::ProcessSuccessfulHandshake(const uint32_t connection_key, context_.service_type_, (is_service_already_protected) ? "Service is already protected" - : "Service cannot be protected"); + : "Service cannot be protected", + full_version_); } } @@ -284,7 +285,8 @@ void HandshakeHandler::ProcessFailedHandshake(BsonObject& params, context_.new_session_id_, protocol_version_, context_.service_type_, - reason_msg + (err_reason.empty() ? "" : ": " + err_reason)); + reason_msg + (err_reason.empty() ? "" : ": " + err_reason), + full_version_); } } diff --git a/src/components/protocol_handler/src/protocol_handler_impl.cc b/src/components/protocol_handler/src/protocol_handler_impl.cc index ca15e9481a4..3148df7d209 100644 --- a/src/components/protocol_handler/src/protocol_handler_impl.cc +++ b/src/components/protocol_handler/src/protocol_handler_impl.cc @@ -481,18 +481,21 @@ void ProtocolHandlerImpl::SendStartSessionAck( } } -void ProtocolHandlerImpl::SendStartSessionNAck(ConnectionID connection_id, - uint8_t session_id, - uint8_t protocol_version, - uint8_t service_type, - const std::string& reason) { +void ProtocolHandlerImpl::SendStartSessionNAck( + ConnectionID connection_id, + uint8_t session_id, + uint8_t protocol_version, + uint8_t service_type, + const std::string& reason, + utils::SemanticVersion& full_version) { std::vector rejectedParams; SendStartSessionNAck(connection_id, session_id, protocol_version, service_type, rejectedParams, - reason); + reason, + full_version); } void ProtocolHandlerImpl::SendStartSessionNAck( @@ -501,9 +504,18 @@ void ProtocolHandlerImpl::SendStartSessionNAck( uint8_t protocol_version, uint8_t service_type, std::vector& rejectedParams, - const std::string& reason) { + const std::string& reason, + utils::SemanticVersion& full_version) { SDL_LOG_AUTO_TRACE(); + if (!full_version.isValid()) { + if (!session_observer_.ProtocolVersionUsed( + connection_id, session_id, full_version)) { + SDL_LOG_WARN("Connection: " << connection_id << " and/or session: " + << session_id << "no longer exist(s)."); + } + } + ProtocolFramePtr ptr( new protocol_handler::ProtocolPacket(connection_id, protocol_version, @@ -517,14 +529,6 @@ void ProtocolHandlerImpl::SendStartSessionNAck( uint8_t maxProtocolVersion = SupportedSDLProtocolVersion(); - utils::SemanticVersion full_version; - if (!session_observer_.ProtocolVersionUsed( - connection_id, session_id, full_version)) { - SDL_LOG_WARN("Connection: " << connection_id << " and/or session: " - << session_id << "no longer exist(s)."); - return; - } - if (protocol_version >= PROTOCOL_VERSION_5 && maxProtocolVersion >= PROTOCOL_VERSION_5) { BsonObject payloadObj; @@ -1778,8 +1782,17 @@ RESULT_CODE ProtocolHandlerImpl::HandleControlMessageStartSession( reason += " Allowed only in unprotected mode"; } - SendStartSessionNAck( - connection_id, session_id, protocol_version, service_type, reason); + utils::SemanticVersion version; + if (packet->service_type() == kRpc && packet->data() != NULL) { + ParseFullVersion(version, packet); + } + + SendStartSessionNAck(connection_id, + session_id, + protocol_version, + service_type, + reason, + version); return RESULT_OK; } @@ -1855,8 +1868,34 @@ RESULT_CODE ProtocolHandlerImpl::HandleControlMessageRegisterSecondaryTransport( return RESULT_OK; } +bool ProtocolHandlerImpl::ParseFullVersion( + utils::SemanticVersion& full_version, + const ProtocolFramePtr& packet) const { + SDL_LOG_AUTO_TRACE(); + + BsonObject request_params; + size_t request_params_size = bson_object_from_bytes_len( + &request_params, packet->data(), packet->total_data_bytes()); + if (request_params_size > 0) { + char* version_param = + bson_object_get_string(&request_params, strings::protocol_version); + std::string version_string(version_param == NULL ? "" : version_param); + full_version = version_string; + + // Constructed payloads added in Protocol v5 + if (full_version.major_version_ < PROTOCOL_VERSION_5) { + return false; + } + bson_object_deinitialize(&request_params); + } else { + SDL_LOG_WARN("Failed to parse start service packet for version string"); + } + + return true; +} + void ProtocolHandlerImpl::NotifySessionStarted( - const SessionContext& context, + SessionContext& context, std::vector& rejected_params, const std::string err_reason) { SDL_LOG_AUTO_TRACE(); @@ -1876,8 +1915,23 @@ void ProtocolHandlerImpl::NotifySessionStarted( const ServiceType service_type = ServiceTypeFromByte(packet->service_type()); const uint8_t protocol_version = packet->protocol_version(); + utils::SemanticVersion full_version; - if (0 == context.new_session_id_) { + // Can't check protocol_version because the first packet is v1, but there + // could still be a payload, in which case we can get the real protocol + // version + if (packet->service_type() == kRpc && packet->data() != NULL) { + if (ParseFullVersion(full_version, packet)) { + const auto connection_key = session_observer_.KeyFromPair( + packet->connection_id(), context.new_session_id_); + connection_handler_.BindProtocolVersionWithSession(connection_key, + full_version); + } else { + rejected_params.push_back(std::string(strings::protocol_version)); + } + } + + if (context.is_start_session_failed_ || !context.new_session_id_) { SDL_LOG_WARN("Refused by session_observer to create service " << static_cast(service_type) << " type."); const auto session_id = packet->session_id(); @@ -1892,7 +1946,8 @@ void ProtocolHandlerImpl::NotifySessionStarted( protocol_version, packet->service_type(), rejected_params, - err_reason); + err_reason, + full_version); return; } @@ -1942,38 +1997,6 @@ void ProtocolHandlerImpl::NotifySessionStarted( } } - std::shared_ptr fullVersion; - - // Can't check protocol_version because the first packet is v1, but there - // could still be a payload, in which case we can get the real protocol - // version - if (packet->service_type() == kRpc && packet->data() != NULL) { - BsonObject request_params; - size_t request_params_size = bson_object_from_bytes_len( - &request_params, packet->data(), packet->total_data_bytes()); - if (request_params_size > 0) { - char* version_param = - bson_object_get_string(&request_params, strings::protocol_version); - std::string version_string(version_param == NULL ? "" : version_param); - fullVersion = std::make_shared(version_string); - - const auto connection_key = session_observer_.KeyFromPair( - packet->connection_id(), context.new_session_id_); - connection_handler_.BindProtocolVersionWithSession(connection_key, - *fullVersion); - // Constructed payloads added in Protocol v5 - if (fullVersion->major_version_ < PROTOCOL_VERSION_5) { - rejected_params.push_back(std::string(strings::protocol_version)); - } - bson_object_deinitialize(&request_params); - } else { - SDL_LOG_WARN("Failed to parse start service packet for version string"); - fullVersion = std::make_shared(); - } - } else { - fullVersion = std::make_shared(); - } - #ifdef ENABLE_SECURITY // for packet is encrypted and security plugin is enable if (context.is_protected_ && security_manager_) { @@ -1984,7 +2007,7 @@ void ProtocolHandlerImpl::NotifySessionStarted( std::make_shared( *this, session_observer_, - *fullVersion, + full_version, context, packet->protocol_version(), start_session_ack_params, @@ -2010,12 +2033,20 @@ void ProtocolHandlerImpl::NotifySessionStarted( } if (!rejected_params.empty()) { + service_status_update_handler_->OnServiceUpdate( + connection_key, + context.service_type_, + ServiceStatus::SERVICE_START_FAILED); SendStartSessionNAck(context.connection_id_, packet->session_id(), protocol_version, packet->service_type(), rejected_params, - "SSL Handshake failed due to rejected parameters"); + "SSL Handshake failed due to rejected parameters", + full_version); + if (packet->service_type() != kRpc) { + context.is_start_session_failed_ = true; + } } else if (ssl_context->IsInitCompleted()) { // mark service as protected session_observer_.SetProtectionFlag(connection_key, service_type); @@ -2030,7 +2061,7 @@ void ProtocolHandlerImpl::NotifySessionStarted( context.hash_id_, packet->service_type(), PROTECTION_ON, - *fullVersion, + full_version, *start_session_ack_params); } else { SDL_LOG_DEBUG("Adding Handshake handler to listeners: " << handler.get()); @@ -2044,12 +2075,20 @@ void ProtocolHandlerImpl::NotifySessionStarted( if (!security_manager_->IsSystemTimeProviderReady()) { security_manager_->RemoveListener(listener); + service_status_update_handler_->OnServiceUpdate( + connection_key, + context.service_type_, + ServiceStatus::SERVICE_START_FAILED); SendStartSessionNAck(context.connection_id_, packet->session_id(), protocol_version, packet->service_type(), rejected_params, - "System time provider is not ready"); + "System time provider is not ready", + full_version); + if (packet->service_type() != kRpc) { + context.is_start_session_failed_ = true; + } } } } @@ -2070,7 +2109,7 @@ void ProtocolHandlerImpl::NotifySessionStarted( context.hash_id_, packet->service_type(), PROTECTION_OFF, - *fullVersion, + full_version, *start_session_ack_params); } else { service_status_update_handler_->OnServiceUpdate( @@ -2083,7 +2122,9 @@ void ProtocolHandlerImpl::NotifySessionStarted( protocol_version, packet->service_type(), rejected_params, - "Certain parameters in the StartService request were rejected"); + "Certain parameters in the StartService request were rejected", + full_version); + context.is_start_session_failed_ = true; } } diff --git a/src/components/protocol_handler/test/protocol_handler_tm_test.cc b/src/components/protocol_handler/test/protocol_handler_tm_test.cc index 035909ecbf6..e08500a93e3 100644 --- a/src/components/protocol_handler/test/protocol_handler_tm_test.cc +++ b/src/components/protocol_handler/test/protocol_handler_tm_test.cc @@ -287,7 +287,7 @@ class ProtocolHandlerImplTest : public ::testing::Test { const bool callback_protection_flag = PROTECTION_OFF; #endif // ENABLE_SECURITY - const protocol_handler::SessionContext context = + protocol_handler::SessionContext context = GetSessionContext(connection_id, NEW_SESSION_ID, session_id, @@ -323,7 +323,7 @@ class ProtocolHandlerImplTest : public ::testing::Test { NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - context, + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -570,6 +570,14 @@ TEST_F(ProtocolHandlerImplTest, .Times(call_times) .WillRepeatedly(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = + GetSessionContext(connection_id, + NEW_SESSION_ID, + SESSION_START_REJECT, + service_type, + HASH_ID_WRONG, + PROTECTION_OFF); + // Expect ConnectionHandler check EXPECT_CALL( session_observer_mock, @@ -586,12 +594,7 @@ TEST_F(ProtocolHandlerImplTest, SaveArg<2>(&service_type), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - SESSION_START_REJECT, - service_type, - HASH_ID_WRONG, - PROTECTION_OFF), + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times += call_times; @@ -662,6 +665,14 @@ TEST_F(ProtocolHandlerImplTest, StartSession_Protected_SessionObserverReject) { .Times(call_times) .WillRepeatedly(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = + GetSessionContext(connection_id, + NEW_SESSION_ID, + SESSION_START_REJECT, + service_type, + HASH_ID_WRONG, + callback_protection_flag); + // Expect ConnectionHandler check EXPECT_CALL( session_observer_mock, @@ -673,19 +684,14 @@ TEST_F(ProtocolHandlerImplTest, StartSession_Protected_SessionObserverReject) { .Times(call_times) . // Return sessions start rejection - WillRepeatedly(DoAll( - NotifyTestAsyncWaiter(waiter), - SaveArg<2>(&service_type), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - SESSION_START_REJECT, - service_type, - HASH_ID_WRONG, - callback_protection_flag), - ByRef(empty_rejected_param_), - std::string()))); + WillRepeatedly( + DoAll(NotifyTestAsyncWaiter(waiter), + SaveArg<2>(&service_type), + InvokeMemberFuncWithArg3(protocol_handler_impl.get(), + &ProtocolHandler::NotifySessionStarted, + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times += call_times; // Expect send NAck with encryption OFF @@ -741,6 +747,13 @@ TEST_F(ProtocolHandlerImplTest, EXPECT_CALL(protocol_handler_settings_mock, video_service_transports()) .WillOnce(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = GetSessionContext(connection_id, + NEW_SESSION_ID, + session_id, + start_service, + HASH_ID_WRONG, + PROTECTION_OFF); + // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, OnSessionStartedCallback(connection_id, @@ -754,12 +767,7 @@ TEST_F(ProtocolHandlerImplTest, DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - session_id, - start_service, - HASH_ID_WRONG, - PROTECTION_OFF), + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -916,34 +924,40 @@ TEST_F(ProtocolHandlerImplTest, EXPECT_CALL(session_observer_mock, ProtocolVersionUsed(_, _, An())) .WillOnce(Return(true)); + + protocol_handler::SessionContext rejected_context = + GetSessionContext(connection_id2, + session_id2, + SESSION_START_REJECT, + start_service, + HASH_ID_WRONG, + PROTECTION_OFF); + + protocol_handler::SessionContext context = + GetSessionContext(connection_id1, + session_id1, + generated_session_id1, + start_service, + HASH_ID_WRONG, + PROTECTION_OFF); EXPECT_CALL(session_observer_mock, OnSessionStartedCallback(connection_id2, session_id2, start_service, PROTECTION_OFF, An())) - .WillOnce(DoAll( - NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id2, - session_id2, - SESSION_START_REJECT, - start_service, - HASH_ID_WRONG, - PROTECTION_OFF), - ByRef(rejected_param_list), - std::string()), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id1, - session_id1, - generated_session_id1, - start_service, - HASH_ID_WRONG, - PROTECTION_OFF), - ByRef(empty_rejected_param_), - std::string()))); + .WillOnce( + DoAll(NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3(protocol_handler_impl.get(), + &ProtocolHandler::NotifySessionStarted, + ByRef(rejected_context), + ByRef(rejected_param_list), + std::string()), + InvokeMemberFuncWithArg3(protocol_handler_impl.get(), + &ProtocolHandler::NotifySessionStarted, + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; BsonObject bson_ack_params; @@ -1208,6 +1222,12 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionProtocoloV1) { EXPECT_CALL(protocol_handler_settings_mock, video_service_transports()) .WillOnce(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = GetSessionContext(connection_id, + NEW_SESSION_ID, + session_id, + start_service, + HASH_ID_WRONG, + PROTECTION_OFF); // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, OnSessionStartedCallback(connection_id, @@ -1221,12 +1241,7 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionProtocoloV1) { DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - session_id, - start_service, - HASH_ID_WRONG, - PROTECTION_OFF), + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -1280,6 +1295,13 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionUnprotected) { EXPECT_CALL(protocol_handler_settings_mock, video_service_transports()) .WillOnce(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = GetSessionContext(connection_id, + NEW_SESSION_ID, + session_id, + start_service, + HASH_ID_WRONG, + PROTECTION_OFF); + // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, OnSessionStartedCallback(connection_id, @@ -1293,12 +1315,7 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionUnprotected) { DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - session_id, - start_service, - HASH_ID_WRONG, - PROTECTION_OFF), + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -1334,6 +1351,7 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionProtected_Fail) { start_service, HASH_ID_WRONG, PROTECTION_ON); + context.is_new_service_ = true; // Expect verification of allowed transport @@ -1364,7 +1382,7 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionProtected_Fail) { DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - context, + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -1420,6 +1438,13 @@ TEST_F(ProtocolHandlerImplTest, EXPECT_CALL(protocol_handler_settings_mock, video_service_transports()) .WillOnce(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = GetSessionContext(connection_id, + NEW_SESSION_ID, + session_id, + start_service, + HASH_ID_WRONG, + PROTECTION_ON); + // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, OnSessionStartedCallback(connection_id, @@ -1433,12 +1458,7 @@ TEST_F(ProtocolHandlerImplTest, DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - session_id, - start_service, - HASH_ID_WRONG, - PROTECTION_ON), + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -1524,7 +1544,7 @@ TEST_F(ProtocolHandlerImplTest, DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - context, + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -1608,6 +1628,13 @@ TEST_F(ProtocolHandlerImplTest, EXPECT_CALL(protocol_handler_settings_mock, video_service_transports()) .WillOnce(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = GetSessionContext(connection_id, + NEW_SESSION_ID, + session_id, + start_service, + HASH_ID_WRONG, + PROTECTION_ON); + // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, OnSessionStartedCallback(connection_id, @@ -1621,12 +1648,7 @@ TEST_F(ProtocolHandlerImplTest, DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - session_id, - start_service, - HASH_ID_WRONG, - PROTECTION_ON), + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -1721,6 +1743,13 @@ TEST_F( EXPECT_CALL(protocol_handler_settings_mock, video_service_transports()) .WillOnce(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = GetSessionContext(connection_id, + NEW_SESSION_ID, + session_id, + start_service, + HASH_ID_WRONG, + PROTECTION_ON); + // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, OnSessionStartedCallback(connection_id, @@ -1734,12 +1763,7 @@ TEST_F( DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - session_id, - start_service, - HASH_ID_WRONG, - PROTECTION_ON), + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -1832,6 +1856,13 @@ TEST_F(ProtocolHandlerImplTest, EXPECT_CALL(protocol_handler_settings_mock, video_service_transports()) .WillOnce(ReturnRef(video_service_transports)); + protocol_handler::SessionContext context = GetSessionContext(connection_id, + NEW_SESSION_ID, + session_id, + start_service, + HASH_ID_WRONG, + PROTECTION_ON); + // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, OnSessionStartedCallback(connection_id, @@ -1845,12 +1876,7 @@ TEST_F(ProtocolHandlerImplTest, DoAll(NotifyTestAsyncWaiter(waiter), InvokeMemberFuncWithArg3(protocol_handler_impl.get(), &ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - session_id, - start_service, - HASH_ID_WRONG, - PROTECTION_ON), + ByRef(context), ByRef(empty_rejected_param_), std::string()))); times++; @@ -5170,7 +5196,6 @@ TEST_F(ProtocolHandlerImplTest, SendServiceDataAck_AfterVersion5) { TEST_F(ProtocolHandlerImplTest, StartSession_NACKReason_DisallowedBySettings) { const ServiceType service_type = kMobileNav; const utils::SemanticVersion min_reason_param_version(5, 3, 0); - #ifdef ENABLE_SECURITY AddSecurityManager(); @@ -5207,6 +5232,7 @@ TEST_F(ProtocolHandlerImplTest, StartSession_NACKReason_DisallowedBySettings) { bson_object_put_string(&bson_nack_params, protocol_handler::strings::reason, const_cast(reason.c_str())); + std::vector nack_params = CreateVectorFromBsonObject(&bson_nack_params); bson_object_deinitialize(&bson_nack_params); @@ -5267,6 +5293,14 @@ TEST_F(ProtocolHandlerImplTest, StartSession_NACKReason_SessionObserverReject) { .Times(call_times) .WillRepeatedly(ReturnRef(allowed_transports)); + protocol_handler::SessionContext context = + GetSessionContext(connection_id, + NEW_SESSION_ID, + SESSION_START_REJECT, + service_type, + protocol_handler::HASH_ID_WRONG, + PROTECTION_OFF); + // Expect ConnectionHandler check EXPECT_CALL( session_observer_mock, @@ -5284,12 +5318,7 @@ TEST_F(ProtocolHandlerImplTest, StartSession_NACKReason_SessionObserverReject) { InvokeMemberFuncWithArg3( protocol_handler_impl.get(), &protocol_handler::ProtocolHandler::NotifySessionStarted, - GetSessionContext(connection_id, - NEW_SESSION_ID, - SESSION_START_REJECT, - service_type, - protocol_handler::HASH_ID_WRONG, - PROTECTION_OFF), + ByRef(context), ByRef(empty_rejected_param_), err_reason))); times += call_times; From 77231fce35b028d5691fd42efaff4e155a47be3b Mon Sep 17 00:00:00 2001 From: Andrii Kalinich Date: Mon, 22 Mar 2021 20:43:40 -0400 Subject: [PATCH 2/3] fixup! Fix NACK sending for Start RPC service with invalid data in bson payload. Fix memory leak with session after Start RPC service NACK. Fix OnServiceUpdate sending after Start RPC service with invalid data in bson payload for protected mode. Fix NACK sending for Start any service with non-existent session id. --- .../test/connection_handler_impl_test.cc | 61 ++-- .../protocol_handler/protocol_handler.h | 16 ++ .../protocol_handler/mock_protocol_handler.h | 4 + .../protocol_handler/protocol_handler_impl.h | 5 + .../src/protocol_handler_impl.cc | 8 + .../test/protocol_handler_tm_test.cc | 270 +++++++++++------- 6 files changed, 239 insertions(+), 125 deletions(-) diff --git a/src/components/connection_handler/test/connection_handler_impl_test.cc b/src/components/connection_handler/test/connection_handler_impl_test.cc index f592915e0e6..79b237253cd 100644 --- a/src/components/connection_handler/test/connection_handler_impl_test.cc +++ b/src/components/connection_handler/test/connection_handler_impl_test.cc @@ -56,6 +56,7 @@ using namespace ::connection_handler; using ::protocol_handler::ServiceType; using namespace ::protocol_handler; using ::testing::_; +using ::testing::An; using ::testing::ByRef; using ::testing::DoAll; using ::testing::InSequence; @@ -127,7 +128,8 @@ class ConnectionHandlerTest : public ::testing::Test { void AddTestSession() { protocol_handler_test::MockProtocolHandler temp_protocol_handler; connection_handler_->set_protocol_handler(&temp_protocol_handler); - EXPECT_CALL(temp_protocol_handler, NotifySessionStarted(_, _, _)) + EXPECT_CALL(temp_protocol_handler, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&out_context_)); connection_handler_->OnSessionStartedCallback( @@ -164,7 +166,8 @@ class ConnectionHandlerTest : public ::testing::Test { SessionContext context; protocol_handler_test::MockProtocolHandler temp_protocol_handler; connection_handler_->set_protocol_handler(&temp_protocol_handler); - EXPECT_CALL(temp_protocol_handler, NotifySessionStarted(_, _, _)) + EXPECT_CALL(temp_protocol_handler, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&context)); connection_handler_->OnSessionStartedCallback(uid_, @@ -371,7 +374,8 @@ TEST_F(ConnectionHandlerTest, StartSession_NoConnection) { protocol_handler::SessionContext context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&context)); connection_handler_->OnSessionStartedCallback( @@ -1268,7 +1272,8 @@ TEST_F(ConnectionHandlerTest, StartService_withServices) { SessionContext audio_context, video_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&audio_context)) .WillOnce(SaveArg<0>(&video_context)); @@ -1309,7 +1314,8 @@ TEST_F(ConnectionHandlerTest, StartService_withServices_withParams) { std::vector empty; BsonObject* dummy_param = reinterpret_cast(&dummy); connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, empty, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), empty, _)) .WillOnce(SaveArg<0>(&video_context)); connection_handler_->OnSessionStartedCallback(uid_, @@ -1354,7 +1360,8 @@ TEST_F(ConnectionHandlerTest, ServiceStop) { SessionContext audio_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillRepeatedly(SaveArg<0>(&audio_context)); // Check ignoring hash_id on stop non-rpc service @@ -1445,7 +1452,8 @@ TEST_F(ConnectionHandlerTest, SessionStarted_WithRpc) { reason)); connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&out_context_)); // Start new session with RPC service @@ -1485,7 +1493,8 @@ TEST_F(ConnectionHandlerTest, ServiceStarted_Video_SUCCESS) { // confirm that NotifySessionStarted() is called connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, empty, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), empty, _)) .WillOnce(SaveArg<0>(&out_context_)); connection_handler_->OnSessionStartedCallback(uid_, @@ -1527,7 +1536,8 @@ TEST_F(ConnectionHandlerTest, ServiceStarted_Video_FAILURE) { // confirm that NotifySessionStarted() is called connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, empty, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), empty, _)) .WillOnce(SaveArg<0>(&out_context_)); connection_handler_->OnSessionStartedCallback(uid_, @@ -1551,7 +1561,8 @@ TEST_F(ConnectionHandlerTest, ServiceStarted_Video_Multiple) { protocol_handler_test::MockProtocolHandler temp_protocol_handler; connection_handler_->set_protocol_handler(&temp_protocol_handler); - EXPECT_CALL(temp_protocol_handler, NotifySessionStarted(_, _, _)) + EXPECT_CALL(temp_protocol_handler, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&context_first)) .WillOnce(SaveArg<0>(&context_second)); @@ -1622,7 +1633,8 @@ TEST_F(ConnectionHandlerTest, ServiceStarted_Video_Multiple) { // verify that connection handler will not mix up the two results SessionContext new_context_first, new_context_second; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, empty, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), empty, _)) .WillOnce(SaveArg<0>(&new_context_second)) .WillOnce(SaveArg<0>(&new_context_first)); @@ -1656,7 +1668,8 @@ TEST_F(ConnectionHandlerTest, SessionContext fail_context; SessionContext positive_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&fail_context)) .WillOnce(SaveArg<0>(&positive_context)); @@ -1699,7 +1712,8 @@ TEST_F(ConnectionHandlerTest, SessionContext fail_context; SessionContext positive_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&fail_context)) .WillOnce(SaveArg<0>(&positive_context)); @@ -1744,7 +1758,8 @@ TEST_F(ConnectionHandlerTest, SessionContext context_first, context_second; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&context_first)) .WillOnce(SaveArg<0>(&context_second)); @@ -1799,7 +1814,8 @@ TEST_F(ConnectionHandlerTest, SessionContext rejected_context, positive_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&rejected_context)) .WillOnce(SaveArg<0>(&positive_context)); @@ -1842,7 +1858,8 @@ TEST_F(ConnectionHandlerTest, SessionStarted_DelayProtect) { SessionContext context_new, context_second, context_third; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&context_new)) .WillOnce(SaveArg<0>(&context_second)) .WillOnce(SaveArg<0>(&context_third)); @@ -1897,7 +1914,8 @@ TEST_F(ConnectionHandlerTest, SessionStarted_DelayProtectBulk) { SessionContext new_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&new_context)); connection_handler_->OnSessionStartedCallback(uid_, out_context_.new_session_id_, @@ -2003,7 +2021,8 @@ TEST_F(ConnectionHandlerTest, GetSSLContext_ByProtectedService) { SessionContext new_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&new_context)); // Open kAudio service @@ -2040,7 +2059,8 @@ TEST_F(ConnectionHandlerTest, GetSSLContext_ByDealyProtectedRPC) { SessionContext new_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&new_context)); // Protect kRpc (Bulk will be protect also) @@ -2080,7 +2100,8 @@ TEST_F(ConnectionHandlerTest, GetSSLContext_ByDealyProtectedBulk) { SessionContext new_context; connection_handler_->set_protocol_handler(&mock_protocol_handler_); - EXPECT_CALL(mock_protocol_handler_, NotifySessionStarted(_, _, _)) + EXPECT_CALL(mock_protocol_handler_, + NotifySessionStarted(An(), _, _)) .WillOnce(SaveArg<0>(&new_context)); // Protect Bulk (kRpc will be protected also) diff --git a/src/components/include/protocol_handler/protocol_handler.h b/src/components/include/protocol_handler/protocol_handler.h index 2cea958bae4..011592da499 100644 --- a/src/components/include/protocol_handler/protocol_handler.h +++ b/src/components/include/protocol_handler/protocol_handler.h @@ -127,6 +127,22 @@ class ProtocolHandler { virtual const ProtocolHandlerSettings& get_settings() const = 0; virtual SessionObserver& get_session_observer() = 0; + /** + * @brief Called by connection handler to notify the context of + * OnSessionStartedCallback(). + * @param context reference to structure with started session data + * @param rejected_params list of parameters name that are rejected. + * Only valid when generated_session_id is 0. Note, even if + * generated_session_id is 0, the list may be empty. + * @param err_reason string with NACK reason. Only valid when + * generated_session_id is 0. + */ + DEPRECATED + virtual void NotifySessionStarted( + const SessionContext& context, + std::vector& rejected_params, + const std::string err_reason = std::string()) = 0; + /** * @brief Called by connection handler to notify the context of * OnSessionStartedCallback(). diff --git a/src/components/include/test/protocol_handler/mock_protocol_handler.h b/src/components/include/test/protocol_handler/mock_protocol_handler.h index b9a54ef73f3..aa7c1293c36 100644 --- a/src/components/include/test/protocol_handler/mock_protocol_handler.h +++ b/src/components/include/test/protocol_handler/mock_protocol_handler.h @@ -68,6 +68,10 @@ class MockProtocolHandler : public ::protocol_handler::ProtocolHandler { void(::protocol_handler::SessionContext& context, std::vector& rejected_params, const std::string err_reason)); + MOCK_METHOD3(NotifySessionStarted, + void(const ::protocol_handler::SessionContext& context, + std::vector& rejected_params, + const std::string err_reason)); MOCK_METHOD0(NotifyOnGetSystemTimeFailed, void()); MOCK_CONST_METHOD1(IsRPCServiceSecure, bool(const uint32_t connection_key)); MOCK_METHOD0(ProcessFailedPTU, void()); diff --git a/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h b/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h index b036735216d..02c96fc3c7b 100644 --- a/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h +++ b/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h @@ -462,6 +462,11 @@ class ProtocolHandlerImpl * Only valid when generated_session_id is 0. Note, even if * generated_session_id is 0, the list may be empty. */ + DEPRECATED + void NotifySessionStarted(const SessionContext& context, + std::vector& rejected_params, + const std::string err_reason) OVERRIDE; + void NotifySessionStarted(SessionContext& context, std::vector& rejected_params, const std::string err_reason) OVERRIDE; diff --git a/src/components/protocol_handler/src/protocol_handler_impl.cc b/src/components/protocol_handler/src/protocol_handler_impl.cc index 3148df7d209..7c8b4636020 100644 --- a/src/components/protocol_handler/src/protocol_handler_impl.cc +++ b/src/components/protocol_handler/src/protocol_handler_impl.cc @@ -2128,6 +2128,14 @@ void ProtocolHandlerImpl::NotifySessionStarted( } } +void ProtocolHandlerImpl::NotifySessionStarted( + const SessionContext& context, + std::vector& rejected_params, + const std::string err_reason) { + NotifySessionStarted( + const_cast(context), rejected_params, err_reason); +} + RESULT_CODE ProtocolHandlerImpl::HandleControlMessageHeartBeat( const ProtocolPacket& packet) { const ConnectionID connection_id = packet.connection_id(); diff --git a/src/components/protocol_handler/test/protocol_handler_tm_test.cc b/src/components/protocol_handler/test/protocol_handler_tm_test.cc index e08500a93e3..2813802b6b1 100644 --- a/src/components/protocol_handler/test/protocol_handler_tm_test.cc +++ b/src/components/protocol_handler/test/protocol_handler_tm_test.cc @@ -321,11 +321,15 @@ class ProtocolHandlerImplTest : public ::testing::Test { // Return sessions start success WillOnce(DoAll( NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; // Expect send Ack with PROTECTION_OFF (on no Security Manager) @@ -589,14 +593,18 @@ TEST_F(ProtocolHandlerImplTest, .Times(call_times) . // Return sessions start rejection - WillRepeatedly( - DoAll(NotifyTestAsyncWaiter(waiter), - SaveArg<2>(&service_type), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillRepeatedly(DoAll( + NotifyTestAsyncWaiter(waiter), + SaveArg<2>(&service_type), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times += call_times; // Expect send NAck @@ -684,14 +692,18 @@ TEST_F(ProtocolHandlerImplTest, StartSession_Protected_SessionObserverReject) { .Times(call_times) . // Return sessions start rejection - WillRepeatedly( - DoAll(NotifyTestAsyncWaiter(waiter), - SaveArg<2>(&service_type), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillRepeatedly(DoAll( + NotifyTestAsyncWaiter(waiter), + SaveArg<2>(&service_type), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times += call_times; // Expect send NAck with encryption OFF @@ -763,13 +775,17 @@ TEST_F(ProtocolHandlerImplTest, An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; SetProtocolVersion2(); @@ -946,18 +962,26 @@ TEST_F(ProtocolHandlerImplTest, start_service, PROTECTION_OFF, An())) - .WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(rejected_context), - ByRef(rejected_param_list), - std::string()), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + .WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(rejected_context), + ByRef(rejected_param_list), + std::string()), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; BsonObject bson_ack_params; @@ -1237,13 +1261,17 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionProtocoloV1) { An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; SetProtocolVersion2(); @@ -1311,13 +1339,17 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionUnprotected) { An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; SetProtocolVersion2(); @@ -1378,13 +1410,17 @@ TEST_F(ProtocolHandlerImplTest, SecurityEnable_StartSessionProtected_Fail) { An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; SetProtocolVersion2(); @@ -1454,13 +1490,17 @@ TEST_F(ProtocolHandlerImplTest, An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; SetProtocolVersion2(); @@ -1540,13 +1580,17 @@ TEST_F(ProtocolHandlerImplTest, An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; std::vector services; @@ -1644,13 +1688,17 @@ TEST_F(ProtocolHandlerImplTest, An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; // call new SSLContext creation @@ -1759,13 +1807,17 @@ TEST_F( An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; // call new SSLContext creation @@ -1872,13 +1924,17 @@ TEST_F(ProtocolHandlerImplTest, An())) . // Return sessions start success - WillOnce( - DoAll(NotifyTestAsyncWaiter(waiter), - InvokeMemberFuncWithArg3(protocol_handler_impl.get(), - &ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - std::string()))); + WillOnce(DoAll( + NotifyTestAsyncWaiter(waiter), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + std::string()))); times++; // call new SSLContext creation @@ -5312,15 +5368,19 @@ TEST_F(ProtocolHandlerImplTest, StartSession_NACKReason_SessionObserverReject) { .Times(call_times) . // Return sessions start rejection - WillRepeatedly( - DoAll(NotifyTestAsyncWaiter(waiter), - SaveArg<2>(&service_type), - InvokeMemberFuncWithArg3( - protocol_handler_impl.get(), - &protocol_handler::ProtocolHandler::NotifySessionStarted, - ByRef(context), - ByRef(empty_rejected_param_), - err_reason))); + WillRepeatedly(DoAll( + NotifyTestAsyncWaiter(waiter), + SaveArg<2>(&service_type), + InvokeMemberFuncWithArg3( + protocol_handler_impl.get(), + static_cast&, + const std::string)>( + &protocol_handler::ProtocolHandler::NotifySessionStarted), + ByRef(context), + ByRef(empty_rejected_param_), + err_reason))); times += call_times; // Expect send NAck From 58fb660c930d0fe7cd42cd8123e12c4522363dc2 Mon Sep 17 00:00:00 2001 From: Andrii Kalinich Date: Tue, 23 Mar 2021 22:21:28 -0400 Subject: [PATCH 3/3] fixup! fixup! Fix NACK sending for Start RPC service with invalid data in bson payload. Fix memory leak with session after Start RPC service NACK. Fix OnServiceUpdate sending after Start RPC service with invalid data in bson payload for protected mode. Fix NACK sending for Start any service with non-existent session id. --- src/components/protocol_handler/src/protocol_handler_impl.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/components/protocol_handler/src/protocol_handler_impl.cc b/src/components/protocol_handler/src/protocol_handler_impl.cc index 7c8b4636020..ff3a2abec15 100644 --- a/src/components/protocol_handler/src/protocol_handler_impl.cc +++ b/src/components/protocol_handler/src/protocol_handler_impl.cc @@ -2132,8 +2132,8 @@ void ProtocolHandlerImpl::NotifySessionStarted( const SessionContext& context, std::vector& rejected_params, const std::string err_reason) { - NotifySessionStarted( - const_cast(context), rejected_params, err_reason); + SessionContext context_copy = context; + NotifySessionStarted(context_copy, rejected_params, err_reason); } RESULT_CODE ProtocolHandlerImpl::HandleControlMessageHeartBeat(