From f7cbcedd0bdb5c3c5bdcacf04b22f80451c468c8 Mon Sep 17 00:00:00 2001 From: Zang MingJie Date: Wed, 25 Aug 2021 08:56:44 +0800 Subject: [PATCH] Remove default constructor of SessionHandle, reduce dangling session handle --- src/app/CommandSender.cpp | 11 +--- src/app/CommandSender.h | 2 +- src/app/InteractionModelEngine.cpp | 6 +- src/app/InteractionModelEngine.h | 2 +- src/app/ReadClient.cpp | 11 +--- src/app/ReadClient.h | 2 +- src/app/WriteClient.cpp | 13 +---- src/app/WriteClient.h | 5 +- src/app/tests/TestCommandInteraction.cpp | 4 +- src/app/tests/TestReadInteraction.cpp | 17 +++--- src/app/tests/TestWriteInteraction.cpp | 5 +- .../tests/integration/chip_im_initiator.cpp | 12 ++-- src/channel/ChannelContext.cpp | 3 +- src/controller/CHIPDevice.cpp | 57 +++++++++++-------- src/controller/CHIPDevice.h | 6 +- src/controller/CHIPDeviceController.cpp | 2 +- src/messaging/ExchangeContext.cpp | 15 ++--- src/messaging/ExchangeContext.h | 8 +-- src/messaging/ExchangeMgr.cpp | 2 +- src/protocols/echo/Echo.h | 2 +- src/protocols/echo/EchoClient.cpp | 6 +- src/transport/SecureSessionMgr.cpp | 2 +- src/transport/SessionHandle.h | 8 ++- src/transport/tests/TestSecureSessionMgr.cpp | 26 ++++----- src/transport/tests/TestSessionHandle.cpp | 17 ------ 25 files changed, 116 insertions(+), 128 deletions(-) diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp index 381a1fc2387fcd..cfae386146172d 100644 --- a/src/app/CommandSender.cpp +++ b/src/app/CommandSender.cpp @@ -34,7 +34,7 @@ using GeneralStatusCode = chip::Protocols::SecureChannel::GeneralStatusCode; namespace chip { namespace app { -CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * secureSession, +CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional secureSession, uint32_t timeout) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -50,14 +50,7 @@ CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabric AbortExistingExchangeContext(); // Create a new exchange context. - if (secureSession == nullptr) - { - mpExchangeCtx = mpExchangeMgr->NewContext(SessionHandle(aNodeId, 0, 0, aFabricIndex), this); - } - else - { - mpExchangeCtx = mpExchangeMgr->NewContext(*secureSession, this); - } + mpExchangeCtx = mpExchangeMgr->NewContext(secureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(timeout); diff --git a/src/app/CommandSender.h b/src/app/CommandSender.h index 316ffe7670d15c..6d7ce63680ef6a 100644 --- a/src/app/CommandSender.h +++ b/src/app/CommandSender.h @@ -55,7 +55,7 @@ class CommandSender : public Command, public Messaging::ExchangeDelegate // // If SendCommandRequest is never called, or the call fails, the API // consumer is responsible for calling Shutdown on the CommandSender. - CHIP_ERROR SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * secureSession, + CHIP_ERROR SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional secureSession, uint32_t timeout = kImMessageTimeoutMsec); private: diff --git a/src/app/InteractionModelEngine.cpp b/src/app/InteractionModelEngine.cpp index aac9dd57ea52eb..c48053d59f9fde 100644 --- a/src/app/InteractionModelEngine.cpp +++ b/src/app/InteractionModelEngine.cpp @@ -317,9 +317,9 @@ void InteractionModelEngine::OnResponseTimeout(Messaging::ExchangeContext * ec) ChipLogProgress(InteractionModel, "Time out! failed to receive echo response from Exchange: %d", ec->GetExchangeId()); } -CHIP_ERROR InteractionModelEngine::SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, - EventPathParams * apEventPathParamsList, size_t aEventPathParamsListSize, - AttributePathParams * apAttributePathParamsList, +CHIP_ERROR InteractionModelEngine::SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, + Optional apSecureSession, EventPathParams * apEventPathParamsList, + size_t aEventPathParamsListSize, AttributePathParams * apAttributePathParamsList, size_t aAttributePathParamsListSize, EventNumber aEventNumber, uint64_t aAppIdentifier) { diff --git a/src/app/InteractionModelEngine.h b/src/app/InteractionModelEngine.h index a0288f1d771283..efe54b7c51d9ed 100644 --- a/src/app/InteractionModelEngine.h +++ b/src/app/InteractionModelEngine.h @@ -110,7 +110,7 @@ class InteractionModelEngine : public Messaging::ExchangeDelegate * @retval #CHIP_ERROR_NO_MEMORY If there is no ReadClient available * @retval #CHIP_NO_ERROR On success. */ - CHIP_ERROR SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, + CHIP_ERROR SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, EventPathParams * apEventPathParamsList, size_t aEventPathParamsListSize, AttributePathParams * apAttributePathParamsList, size_t aAttributePathParamsListSize, EventNumber aEventNumber, uint64_t aAppIdentifier = 0); diff --git a/src/app/ReadClient.cpp b/src/app/ReadClient.cpp index 47e2ff5a338ed7..9c74331e15c75b 100644 --- a/src/app/ReadClient.cpp +++ b/src/app/ReadClient.cpp @@ -86,7 +86,7 @@ void ReadClient::MoveToState(const ClientState aTargetState) GetStateStr()); } -CHIP_ERROR ReadClient::SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, +CHIP_ERROR ReadClient::SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, EventPathParams * apEventPathParamsList, size_t aEventPathParamsListSize, AttributePathParams * apAttributePathParamsList, size_t aAttributePathParamsListSize, EventNumber aEventNumber, uint32_t timeout) @@ -134,14 +134,7 @@ CHIP_ERROR ReadClient::SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, SuccessOrExit(err); } - if (apSecureSession != nullptr) - { - mpExchangeCtx = mpExchangeMgr->NewContext(*apSecureSession, this); - } - else - { - mpExchangeCtx = mpExchangeMgr->NewContext(SessionHandle(aNodeId, 0, 0, aFabricIndex), this); - } + mpExchangeCtx = mpExchangeMgr->NewContext(apSecureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(timeout); diff --git a/src/app/ReadClient.h b/src/app/ReadClient.h index c77fa9d2598e43..1d0fd3250c6fa9 100644 --- a/src/app/ReadClient.h +++ b/src/app/ReadClient.h @@ -75,7 +75,7 @@ class ReadClient : public Messaging::ExchangeDelegate * @retval #others fail to send read request * @retval #CHIP_NO_ERROR On success. */ - CHIP_ERROR SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * aSecureSession, + CHIP_ERROR SendReadRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional aSecureSession, EventPathParams * apEventPathParamsList, size_t aEventPathParamsListSize, AttributePathParams * apAttributePathParamsList, size_t aAttributePathParamsListSize, EventNumber aEventNumber, uint32_t timeout = kImMessageTimeoutMsec); diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp index bca0126cffba43..6c341145de203b 100644 --- a/src/app/WriteClient.cpp +++ b/src/app/WriteClient.cpp @@ -247,7 +247,7 @@ void WriteClient::ClearState() MoveToState(State::Uninitialized); } -CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, +CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, uint32_t timeout) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -263,14 +263,7 @@ CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricInde ClearExistingExchangeContext(); // Create a new exchange context. - if (apSecureSession == nullptr) - { - mpExchangeCtx = mpExchangeMgr->NewContext(SessionHandle(aNodeId, 0, 0, aFabricIndex), this); - } - else - { - mpExchangeCtx = mpExchangeMgr->NewContext(*apSecureSession, this); - } + mpExchangeCtx = mpExchangeMgr->NewContext(apSecureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(timeout); @@ -397,7 +390,7 @@ CHIP_ERROR WriteClient::ProcessAttributeStatusElement(AttributeStatusElement::Pa return err; } -CHIP_ERROR WriteClientHandle::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, +CHIP_ERROR WriteClientHandle::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, uint32_t timeout) { CHIP_ERROR err = mpWriteClient->SendWriteRequest(aNodeId, aFabricIndex, apSecureSession, timeout); diff --git a/src/app/WriteClient.h b/src/app/WriteClient.h index e757d2a755a9c2..e731cb19d27a02 100644 --- a/src/app/WriteClient.h +++ b/src/app/WriteClient.h @@ -94,7 +94,8 @@ class WriteClient : public Messaging::ExchangeDelegate * If SendWriteRequest is never called, or the call fails, the API * consumer is responsible for calling Shutdown on the WriteClient. */ - CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, uint32_t timeout); + CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, + uint32_t timeout); /** * Initialize the client object. Within the lifetime @@ -175,7 +176,7 @@ class WriteClientHandle * Finalize the message and send it to the desired node. The underlying write object will always be released, and the user * should not use this object after calling this function. */ - CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, + CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, uint32_t timeout = kImMessageTimeoutMsec); /** diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp index 9ef62056e6bc46..665c4ba9c4c286 100644 --- a/src/app/tests/TestCommandInteraction.cpp +++ b/src/app/tests/TestCommandInteraction.cpp @@ -239,7 +239,7 @@ void TestCommandInteraction::TestCommandSenderWithWrongState(nlTestSuite * apSui err = commandSender.Init(&gExchangeManager, nullptr); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, nullptr); + err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, Optional::Missing()); NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_INCORRECT_STATE); } @@ -278,7 +278,7 @@ void TestCommandInteraction::TestCommandSenderWithSendCommand(nlTestSuite * apSu NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); AddCommandDataElement(apSuite, apContext, &commandSender, false); - err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, nullptr); + err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, Optional::Missing()); NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_NOT_CONNECTED); GenerateReceivedCommand(apSuite, apContext, buf, true /*aNeedCommandData*/); diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp index 9851c35882d2e6..89e4ea89acee2a 100644 --- a/src/app/tests/TestReadInteraction.cpp +++ b/src/app/tests/TestReadInteraction.cpp @@ -262,9 +262,10 @@ void TestReadInteraction::TestReadClient(nlTestSuite * apSuite, void * apContext err = readClient.Init(&ctx.GetExchangeManager(), &delegate, 0 /* application identifier */); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); SessionHandle session = ctx.GetSessionLocalToPeer(); - err = readClient.SendReadRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), &session, nullptr /*apEventPathParamsList*/, - 0 /*aEventPathParamsListSize*/, nullptr /*apAttributePathParamsList*/, - 0 /*aAttributePathParamsListSize*/, eventNumber /*aEventNumber*/); + err = readClient.SendReadRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), Optional::Value(session), + nullptr /*apEventPathParamsList*/, 0 /*aEventPathParamsListSize*/, + nullptr /*apAttributePathParamsList*/, 0 /*aAttributePathParamsListSize*/, + eventNumber /*aEventNumber*/); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); GenerateReportData(apSuite, apContext, buf); @@ -389,9 +390,10 @@ void TestReadInteraction::TestReadClientInvalidReport(nlTestSuite * apSuite, voi NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); SessionHandle session = ctx.GetSessionLocalToPeer(); - err = readClient.SendReadRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), &session, nullptr /*apEventPathParamsList*/, - 0 /*aEventPathParamsListSize*/, nullptr /*apAttributePathParamsList*/, - 0 /*aAttributePathParamsListSize*/, eventNumber /*aEventNumber*/); + err = readClient.SendReadRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), Optional::Value(session), + nullptr /*apEventPathParamsList*/, 0 /*aEventPathParamsListSize*/, + nullptr /*apAttributePathParamsList*/, 0 /*aAttributePathParamsListSize*/, + eventNumber /*aEventNumber*/); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); GenerateReportData(apSuite, apContext, buf, true /*aNeedInvalidReport*/); @@ -579,7 +581,8 @@ void TestReadInteraction::TestReadEventRoundtrip(nlTestSuite * apSuite, void * a SessionHandle session = ctx.GetSessionLocalToPeer(); err = chip::app::InteractionModelEngine::GetInstance()->SendReadRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), - &session, eventPathParams, 2, nullptr, 1, 0); + Optional::Value(session), + eventPathParams, 2, nullptr, 1, 0); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); InteractionModelEngine::GetInstance()->GetReportingEngine().Run(); diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp index dec03ce4056240..466c9e9e2bb657 100644 --- a/src/app/tests/TestWriteInteraction.cpp +++ b/src/app/tests/TestWriteInteraction.cpp @@ -217,7 +217,8 @@ void TestWriteInteraction::TestWriteClient(nlTestSuite * apSuite, void * apConte AddAttributeDataElement(apSuite, apContext, writeClientHandle); SessionHandle session = ctx.GetSessionLocalToPeer(); - err = writeClientHandle.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), &session); + err = writeClientHandle.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), + Optional::Value(session)); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); // The internal WriteClient should be nullptr once we SendWriteRequest. NL_TEST_ASSERT(apSuite, nullptr == writeClientHandle.mpWriteClient); @@ -306,7 +307,7 @@ void TestWriteInteraction::TestWriteRoundtrip(nlTestSuite * apSuite, void * apCo SessionHandle session = ctx.GetSessionLocalToPeer(); - err = writeClient.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), &session); + err = writeClient.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), Optional::Value(session)); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(apSuite, delegate.mGotResponse); diff --git a/src/app/tests/integration/chip_im_initiator.cpp b/src/app/tests/integration/chip_im_initiator.cpp index cb5a9dcecfad53..5ed53850cbf82c 100644 --- a/src/app/tests/integration/chip_im_initiator.cpp +++ b/src/app/tests/integration/chip_im_initiator.cpp @@ -126,7 +126,8 @@ CHIP_ERROR SendCommandRequest(chip::app::CommandSender * commandSender) err = commandSender->FinishCommand(); SuccessOrExit(err); - err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec); + err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, chip::Optional::Missing(), + gMessageTimeoutMsec); SuccessOrExit(err); exit: @@ -163,7 +164,8 @@ CHIP_ERROR SendBadCommandRequest(chip::app::CommandSender * commandSender) err = commandSender->FinishCommand(); SuccessOrExit(err); - err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec); + err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, chip::Optional::Missing(), + gMessageTimeoutMsec); SuccessOrExit(err); exit: @@ -199,7 +201,8 @@ CHIP_ERROR SendReadRequest() printf("\nSend read request message to Node: %" PRIu64 "\n", chip::kTestDeviceNodeId); err = chip::app::InteractionModelEngine::GetInstance()->SendReadRequest( - chip::kTestDeviceNodeId, gFabricIndex, nullptr, eventPathParams, 2, &attributePathParams, 1, number, gMessageTimeoutMsec); + chip::kTestDeviceNodeId, gFabricIndex, chip::Optional::Missing(), eventPathParams, 2, + &attributePathParams, 1, number, gMessageTimeoutMsec); SuccessOrExit(err); exit: @@ -236,7 +239,8 @@ CHIP_ERROR SendWriteRequest(chip::app::WriteClientHandle & apWriteClient) SuccessOrExit(err = writer->PutBoolean(chip::TLV::ContextTag(chip::app::AttributeDataElement::kCsTag_Data), true)); SuccessOrExit(err = apWriteClient->FinishAttribute()); - SuccessOrExit(err = apWriteClient.SendWriteRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec)); + SuccessOrExit(err = apWriteClient.SendWriteRequest(chip::kTestDeviceNodeId, gFabricIndex, + chip::Optional::Missing(), gMessageTimeoutMsec)); gWriteCount++; diff --git a/src/channel/ChannelContext.cpp b/src/channel/ChannelContext.cpp index 607ff8aaf88f7b..72523b7c019f75 100644 --- a/src/channel/ChannelContext.cpp +++ b/src/channel/ChannelContext.cpp @@ -258,7 +258,8 @@ void ChannelContext::EnterCasePairingState() auto & prepare = GetPrepareVars(); prepare.mCasePairingSession = Platform::New(); - ExchangeContext * ctxt = mExchangeManager->NewContext(SessionHandle(), prepare.mCasePairingSession); + ExchangeContext * ctxt = + mExchangeManager->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), prepare.mCasePairingSession); VerifyOrReturn(ctxt != nullptr); // TODO: currently only supports IP/UDP paring diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index c414729574f56d..c6128a03c2d55e 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -72,7 +72,7 @@ CHIP_ERROR Device::SendMessage(Protocols::Id protocolId, uint8_t msgType, Messag ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(loadedSecureSession)); - Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(mSecureSession, nullptr); + Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(mSecureSession.Value(), nullptr); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_NO_MEMORY); if (!loadedSecureSession) @@ -128,10 +128,18 @@ CHIP_ERROR Device::LoadSecureSessionParametersIfNeeded(bool & didLoad) } else { - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); - - // Check if the connection state has the correct transport information - if (connectionState == nullptr || connectionState->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined) + if (mSecureSession.HasValue()) + { + Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value()); + // Check if the connection state has the correct transport information + if (connectionState->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined) + { + mState = ConnectionState::NotConnected; + ReturnErrorOnFailure(LoadSecureSessionParameters(ResetTransport::kNo)); + didLoad = true; + } + } + else { mState = ConnectionState::NotConnected; ReturnErrorOnFailure(LoadSecureSessionParameters(ResetTransport::kNo)); @@ -147,7 +155,7 @@ CHIP_ERROR Device::SendCommands(app::CommandSender * commandObj) bool loadedSecureSession = false; ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(loadedSecureSession)); VerifyOrReturnError(commandObj != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - return commandObj->SendCommandRequest(mDeviceId, mFabricIndex, &mSecureSession); + return commandObj->SendCommandRequest(mDeviceId, mFabricIndex, mSecureSession); } CHIP_ERROR Device::Serialize(SerializedDevice & output) @@ -165,14 +173,13 @@ CHIP_ERROR Device::Serialize(SerializedDevice & output) serializable.mDevicePort = Encoding::LittleEndian::HostSwap16(mDeviceAddress.GetPort()); serializable.mFabricIndex = Encoding::LittleEndian::HostSwap16(mFabricIndex); - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); - // The connection state could be null if the device is moving from PASE connection to CASE connection. // The device parameters (e.g. mDeviceOperationalCertProvisioned) are updated during this transition. // The state during this transistion is being persisted so that the next access of the device will // trigger the CASE based secure session. - if (connectionState != nullptr) + if (mSecureSession.HasValue()) { + Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value()); const uint32_t localMessageCounter = connectionState->GetSessionMessageCounter().GetLocalMessageCounter().Value(); const uint32_t peerMessageCounter = connectionState->GetSessionMessageCounter().GetPeerMessageCounter().GetCounter(); @@ -302,13 +309,13 @@ CHIP_ERROR Device::Persist() void Device::OnNewConnection(SessionHandle session) { - mState = ConnectionState::SecureConnected; - mSecureSession = session; + mState = ConnectionState::SecureConnected; + mSecureSession.SetValue(session); // Reset the message counters here because this is the first time we get a handle to the secure session. // Since CHIPDevices can be serialized/deserialized in the middle of what is conceptually a single PASE session // we need to restore the session counters along with the session information. - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); + Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value()); VerifyOrReturn(connectionState != nullptr); MessageCounter & localCounter = connectionState->GetSessionMessageCounter().GetLocalMessageCounter(); if (localCounter.SetCounter(mLocalMessageCounter) != CHIP_NO_ERROR) @@ -321,10 +328,10 @@ void Device::OnNewConnection(SessionHandle session) void Device::OnConnectionExpired(SessionHandle session) { - VerifyOrReturn(session == mSecureSession, + VerifyOrReturn(mSecureSession.HasValue() && mSecureSession.Value() == session, ChipLogDetail(Controller, "Connection expired, but it doesn't match the current session")); - mState = ConnectionState::NotConnected; - mSecureSession = SessionHandle{}; + mState = ConnectionState::NotConnected; + mSecureSession.ClearValue(); } CHIP_ERROR Device::OnMessageReceived(Messaging::ExchangeContext * exchange, const PacketHeader & header, @@ -398,7 +405,10 @@ CHIP_ERROR Device::OpenPairingWindow(uint16_t timeout, PairingWindowOption optio CHIP_ERROR Device::CloseSession() { ReturnErrorCodeIf(mState != ConnectionState::SecureConnected, CHIP_ERROR_INCORRECT_STATE); - mSessionManager->ExpirePairing(mSecureSession); + if (mSecureSession.HasValue()) + { + mSessionManager->ExpirePairing(mSecureSession.Value()); + } mState = ConnectionState::NotConnected; return CHIP_NO_ERROR; } @@ -411,8 +421,7 @@ CHIP_ERROR Device::UpdateAddress(const Transport::PeerAddress & addr) ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(didLoad)); - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); - if (connectionState == nullptr) + if (!mSecureSession.HasValue()) { // Nothing needs to be done here. It's not an error to not have a // connectionState. For one thing, we could have gotten an different @@ -421,6 +430,7 @@ CHIP_ERROR Device::UpdateAddress(const Transport::PeerAddress & addr) return CHIP_NO_ERROR; } + Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value()); connectionState->SetPeerAddress(addr); return CHIP_NO_ERROR; @@ -431,8 +441,8 @@ void Device::Reset() if (IsActive() && mStorageDelegate != nullptr && mSessionManager != nullptr) { // If a session can be found, persist the device so that we track the newest message counter values - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); - if (connectionState != nullptr) + + if (mSecureSession.HasValue()) { Persist(); } @@ -540,7 +550,8 @@ CHIP_ERROR Device::WarmupCASESession() VerifyOrReturnError(mDeviceOperationalCertProvisioned, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(mState == ConnectionState::NotConnected, CHIP_NO_ERROR); - Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(SessionHandle(), &mCASESession); + Messaging::ExchangeContext * exchange = + mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mCASESession); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager())); @@ -691,7 +702,7 @@ CHIP_ERROR Device::SendReadAttributeRequest(app::AttributePathParams aPath, Call // The application context is used to identify different requests from client applicaiton the type of it is intptr_t, here we // use the seqNum. CHIP_ERROR err = chip::app::InteractionModelEngine::GetInstance()->SendReadRequest( - GetDeviceId(), 0, &mSecureSession, nullptr /*event path params list*/, 0, &aPath, 1, 0 /* event number */, + GetDeviceId(), 0, mSecureSession, nullptr /*event path params list*/, 0, &aPath, 1, 0 /* event number */, seqNum /* application context */); if (err != CHIP_NO_ERROR) { @@ -714,7 +725,7 @@ CHIP_ERROR Device::SendWriteAttributeRequest(app::WriteClientHandle aHandle, Cal { AddResponseHandler(seqNum, onSuccessCallback, onFailureCallback); } - if ((err = aHandle.SendWriteRequest(GetDeviceId(), 0, &mSecureSession)) != CHIP_NO_ERROR) + if ((err = aHandle.SendWriteRequest(GetDeviceId(), 0, mSecureSession)) != CHIP_NO_ERROR) { CancelResponseHandler(seqNum); } diff --git a/src/controller/CHIPDevice.h b/src/controller/CHIPDevice.h index 7fe69dccb4a626..4b863567068ee4 100644 --- a/src/controller/CHIPDevice.h +++ b/src/controller/CHIPDevice.h @@ -343,9 +343,9 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta NodeId GetDeviceId() const { return mDeviceId; } - bool MatchesSession(SessionHandle session) const { return mSecureSession == session; } + bool MatchesSession(SessionHandle session) const { return mSecureSession.HasValue() && mSecureSession.Value() == session; } - SessionHandle GetSecureSession() const { return mSecureSession; } + SessionHandle GetSecureSession() const { return mSecureSession.Value(); } void SetAddress(const Inet::IPAddress & deviceAddr) { mDeviceAddress.SetIPAddress(deviceAddr); } @@ -451,7 +451,7 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta Messaging::ExchangeManager * mExchangeMgr = nullptr; - SessionHandle mSecureSession = {}; + Optional mSecureSession = Optional::Missing(); uint8_t mSequenceNumber = 0; diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 108e610f644e25..57d494a6bac3b7 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -891,7 +891,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam } } #endif - exchangeCtxt = mExchangeMgr->NewContext(SessionHandle(), &mPairingSession); + exchangeCtxt = mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mPairingSession); VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL); err = mIDAllocator.Allocate(keyID); diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 673fbc2a60c3a6..93def28ed465b2 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -93,6 +93,7 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp Transport::PeerConnectionState * state = nullptr; VerifyOrReturnError(mExchangeMgr != nullptr, CHIP_ERROR_INTERNAL); + VerifyOrReturnError(mSecureSession.HasValue(), CHIP_ERROR_CONNECTION_ABORTED); // Don't let method get called on a freed object. VerifyOrDie(mExchangeMgr != nullptr && GetReferenceCount() > 0); @@ -104,7 +105,7 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp bool reliableTransmissionRequested = true; - state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession); + state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession.Value()); // If sending via UDP and NoAutoRequestAck send flag is not specificed, request reliable transmission. if (state != nullptr && state->GetPeerAddress().GetTransportType() != Transport::Type::kUdp) { @@ -141,7 +142,7 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp { // Create a new scope for `err`, to avoid shadowing warning previous `err`. - CHIP_ERROR err = mDispatch->SendMessage(mSecureSession, mExchangeId, IsInitiator(), GetReliableMessageContext(), + CHIP_ERROR err = mDispatch->SendMessage(mSecureSession.Value(), mExchangeId, IsInitiator(), GetReliableMessageContext(), reliableTransmissionRequested, protocolId, msgType, std::move(msgBuf)); if (err != CHIP_NO_ERROR && IsResponseExpected()) { @@ -233,9 +234,9 @@ ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, Sess { VerifyOrDie(mExchangeMgr == nullptr); - mExchangeMgr = em; - mExchangeId = ExchangeId; - mSecureSession = session; + mExchangeMgr = em; + mExchangeId = ExchangeId; + mSecureSession.SetValue(session); mFlags.Set(Flags::kFlagInitiator, Initiator); mDelegate = delegate; @@ -305,7 +306,7 @@ bool ExchangeContext::MatchExchange(SessionHandle session, const PacketHeader & (mExchangeId == payloadHeader.GetExchangeID()) // AND The Session ID associated with the incoming message matches the Session ID associated with the exchange. - && (mSecureSession.MatchIncomingSession(session)) + && (mSecureSession.HasValue() && mSecureSession.Value().MatchIncomingSession(session)) // AND The message was sent by an initiator and the exchange context is a responder (IsInitiator==false) // OR The message was sent by a responder and the exchange context is an initiator (IsInitiator==true) (for the broadcast @@ -320,7 +321,7 @@ void ExchangeContext::OnConnectionExpired() // connection state) value, because it's still referencing the now-expired // connection. This will mean that no more messages can be sent via this // exchange, which seems fine given the semantics of connection expiration. - mSecureSession = SessionHandle(); + mSecureSession.ClearValue(); if (!IsResponseExpected()) { diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index 4a99c182c81c03..854a8c944ddb82 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -156,7 +156,7 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen { if (mExchangeACL == nullptr) { - Transport::FabricInfo * fabric = table.FindFabricWithIndex(mSecureSession.GetFabricIndex()); + Transport::FabricInfo * fabric = table.FindFabricWithIndex(mSecureSession.Value().GetFabricIndex()); if (fabric != nullptr) { mExchangeACL = chip::Platform::New(fabric); @@ -166,7 +166,7 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen return mExchangeACL; } - SessionHandle GetSecureSession() { return mSecureSession; } + SessionHandle GetSecureSession() { return mSecureSession.Value(); } uint16_t GetExchangeId() const { return mExchangeId; } @@ -188,8 +188,8 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen ExchangeMessageDispatch * mDispatch = nullptr; - SessionHandle mSecureSession; // The connection state - uint16_t mExchangeId; // Assigned exchange ID. + Optional mSecureSession; // The connection state + uint16_t mExchangeId; // Assigned exchange ID. /** * Determine whether a response is currently expected for a message that was sent over diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index dd69e2cf2fb03a..58be3154c0f665 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -315,7 +315,7 @@ void ExchangeManager::OnConnectionExpired(SessionHandle session) } mContextPool.ForEachActiveObject([&](auto * ec) { - if (ec->mSecureSession == session) + if (ec->mSecureSession.HasValue() && ec->mSecureSession.Value() == session) { ec->OnConnectionExpired(); // Continue to iterate because there can be multiple exchanges diff --git a/src/protocols/echo/Echo.h b/src/protocols/echo/Echo.h index 11bf7e99aec2fb..3ee26e03b528a2 100644 --- a/src/protocols/echo/Echo.h +++ b/src/protocols/echo/Echo.h @@ -103,7 +103,7 @@ class DLL_EXPORT EchoClient : public Messaging::ExchangeDelegate Messaging::ExchangeManager * mExchangeMgr = nullptr; Messaging::ExchangeContext * mExchangeCtx = nullptr; EchoFunct OnEchoResponseReceived = nullptr; - SessionHandle mSecureSession; + Optional mSecureSession = Optional(); CHIP_ERROR OnMessageReceived(Messaging::ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, System::PacketBufferHandle && payload) override; diff --git a/src/protocols/echo/EchoClient.cpp b/src/protocols/echo/EchoClient.cpp index e1bb0972cb04b8..cb76a34ac599f9 100644 --- a/src/protocols/echo/EchoClient.cpp +++ b/src/protocols/echo/EchoClient.cpp @@ -38,8 +38,8 @@ CHIP_ERROR EchoClient::Init(Messaging::ExchangeManager * exchangeMgr, SessionHan if (mExchangeMgr != nullptr) return CHIP_ERROR_INCORRECT_STATE; - mExchangeMgr = exchangeMgr; - mSecureSession = session; + mExchangeMgr = exchangeMgr; + mSecureSession.SetValue(session); OnEchoResponseReceived = nullptr; mExchangeCtx = nullptr; @@ -71,7 +71,7 @@ CHIP_ERROR EchoClient::SendEchoRequest(System::PacketBufferHandle && payload, Me } // Create a new exchange context. - mExchangeCtx = mExchangeMgr->NewContext(mSecureSession, this); + mExchangeCtx = mExchangeMgr->NewContext(mSecureSession.Value(), this); if (mExchangeCtx == nullptr) { return CHIP_ERROR_NO_MEMORY; diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index a9911a47dc3f41..bbe5501118d602 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -299,7 +299,7 @@ void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const { PayloadHeader payloadHeader; ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); - mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(), peerAddress, + mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle::TemporaryUnauthenticatedSession(), peerAddress, SecureSessionMgrDelegate::DuplicateMessage::No, std::move(msg)); } } diff --git a/src/transport/SessionHandle.h b/src/transport/SessionHandle.h index 238d82fbfb1394..fe9f0cadf26a5e 100644 --- a/src/transport/SessionHandle.h +++ b/src/transport/SessionHandle.h @@ -24,8 +24,6 @@ class SecureSessionMgr; class SessionHandle { public: - SessionHandle() : mPeerNodeId(kPlaceholderNodeId), mFabric(Transport::kUndefinedFabricIndex) {} - SessionHandle(NodeId peerNodeId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric) {} SessionHandle(NodeId peerNodeId, uint16_t localKeyId, uint16_t peerKeyId, FabricIndex fabric) : @@ -64,6 +62,12 @@ class SessionHandle const Optional & GetPeerKeyId() const { return mPeerKeyId; } const Optional & GetLocalKeyId() const { return mLocalKeyId; } + // TODO: currently SessionHandle is not able to identify a unauthenticated session, create an empty handle for it + static SessionHandle TemporaryUnauthenticatedSession() + { + return SessionHandle(kPlaceholderNodeId, Transport::kUndefinedFabricIndex); + } + private: friend class SecureSessionMgr; NodeId mPeerNodeId; diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index 2092c29e712468..23b43774826a7d 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -66,7 +66,7 @@ class TestSessMgrCallback : public SecureSessionMgrDelegate const Transport::PeerAddress & source, DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override { - NL_TEST_ASSERT(mSuite, session == mRemoteToLocalSession); // Packet received by remote peer + NL_TEST_ASSERT(mSuite, session == mRemoteToLocalSession.Value()); // Packet received by remote peer size_t data_len = msgBuf->DataLength(); @@ -88,20 +88,20 @@ class TestSessMgrCallback : public SecureSessionMgrDelegate { // Preset the MessageCounter if (NewConnectionHandlerCallCount == 0) - mRemoteToLocalSession = session; + mRemoteToLocalSession.SetValue(session); if (NewConnectionHandlerCallCount == 1) - mLocalToRemoteSession = session; + mLocalToRemoteSession.SetValue(session); NewConnectionHandlerCallCount++; } void OnConnectionExpired(SessionHandle session) override { mOldConnectionDropped = true; } bool mOldConnectionDropped = false; - nlTestSuite * mSuite = nullptr; - SessionHandle mRemoteToLocalSession; - SessionHandle mLocalToRemoteSession; - int ReceiveHandlerCallCount = 0; - int NewConnectionHandlerCallCount = 0; + nlTestSuite * mSuite = nullptr; + Optional mRemoteToLocalSession = Optional::Missing(); + Optional mLocalToRemoteSession = Optional::Missing(); + int ReceiveHandlerCallCount = 0; + int NewConnectionHandlerCallCount = 0; bool LargeMessageSent = false; }; @@ -166,7 +166,7 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession; + SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value(); // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -256,7 +256,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession; + SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value(); // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -279,7 +279,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // Reset receive side message counter, or duplicated message will be denied. - Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession); + Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession.Value()); state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -330,7 +330,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession; + SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value(); // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -356,7 +356,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) /* -------------------------------------------------------------------------------------------*/ // Reset receive side message counter, or duplicated message will be denied. - Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession); + Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession.Value()); state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); PacketHeader packetHeader; diff --git a/src/transport/tests/TestSessionHandle.cpp b/src/transport/tests/TestSessionHandle.cpp index 8a31c8e78acf3e..a4793c6fa2dcd2 100644 --- a/src/transport/tests/TestSessionHandle.cpp +++ b/src/transport/tests/TestSessionHandle.cpp @@ -37,24 +37,8 @@ using namespace chip; -void TestInitialState(nlTestSuite * inSuite, void * inContext) -{ - SessionHandle session; - - NL_TEST_ASSERT(inSuite, session.GetPeerNodeId() == kPlaceholderNodeId); - NL_TEST_ASSERT(inSuite, session.GetFabricIndex() == Transport::kUndefinedFabricIndex); - NL_TEST_ASSERT(inSuite, !session.HasFabricIndex()); - NL_TEST_ASSERT(inSuite, !session.GetLocalKeyId().HasValue()); - NL_TEST_ASSERT(inSuite, !session.GetPeerKeyId().HasValue()); -} - void TestMatchSession(nlTestSuite * inSuite, void * inContext) { - SessionHandle session1; - SessionHandle session2; - NL_TEST_ASSERT(inSuite, session1 == session2); - NL_TEST_ASSERT(inSuite, session1.MatchIncomingSession(session2)); - SessionHandle session3(chip::kTestDeviceNodeId, 1, 1, 0); SessionHandle session4(chip::kTestDeviceNodeId, 1, 2, 0); NL_TEST_ASSERT(inSuite, !(session3 == session4)); @@ -69,7 +53,6 @@ void TestMatchSession(nlTestSuite * inSuite, void * inContext) // clang-format off static const nlTest sTests[] = { - NL_TEST_DEF("InitialState", TestInitialState), NL_TEST_DEF("MatchSession", TestMatchSession), NL_TEST_SENTINEL() };