diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp index 381a1fc2387fcd..cfae386146172d 100644 --- a/src/app/CommandSender.cpp +++ b/src/app/CommandSender.cpp @@ -34,7 +34,7 @@ using GeneralStatusCode = chip::Protocols::SecureChannel::GeneralStatusCode; namespace chip { namespace app { -CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * secureSession, +CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional secureSession, uint32_t timeout) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -50,14 +50,7 @@ CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabric AbortExistingExchangeContext(); // Create a new exchange context. - if (secureSession == nullptr) - { - mpExchangeCtx = mpExchangeMgr->NewContext(SessionHandle(aNodeId, 0, 0, aFabricIndex), this); - } - else - { - mpExchangeCtx = mpExchangeMgr->NewContext(*secureSession, this); - } + mpExchangeCtx = mpExchangeMgr->NewContext(secureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(timeout); diff --git a/src/app/CommandSender.h b/src/app/CommandSender.h index 316ffe7670d15c..6d7ce63680ef6a 100644 --- a/src/app/CommandSender.h +++ b/src/app/CommandSender.h @@ -55,7 +55,7 @@ class CommandSender : public Command, public Messaging::ExchangeDelegate // // If SendCommandRequest is never called, or the call fails, the API // consumer is responsible for calling Shutdown on the CommandSender. - CHIP_ERROR SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * secureSession, + CHIP_ERROR SendCommandRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional secureSession, uint32_t timeout = kImMessageTimeoutMsec); private: diff --git a/src/app/ReadHandler.cpp b/src/app/ReadHandler.cpp index 102220b4de8283..3bc7bc758a348c 100644 --- a/src/app/ReadHandler.cpp +++ b/src/app/ReadHandler.cpp @@ -132,7 +132,6 @@ CHIP_ERROR ReadHandler::SendReportData(System::PacketBufferHandle && aPayload) if (IsInitialReport()) { VerifyOrReturnLogError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE); - mSecureHandle = mpExchangeCtx->GetSecureSession(); } VerifyOrReturnLogError(mpExchangeCtx != nullptr, CHIP_ERROR_INCORRECT_STATE); MoveToState(HandlerState::Reporting); diff --git a/src/app/ReadHandler.h b/src/app/ReadHandler.h index ca6d40abc7f5cc..9080bb29f949d2 100644 --- a/src/app/ReadHandler.h +++ b/src/app/ReadHandler.h @@ -166,7 +166,6 @@ class ReadHandler : public Messaging::ExchangeDelegate Messaging::ExchangeManager * mpExchangeMgr = nullptr; InteractionModelDelegate * mpDelegate = nullptr; bool mInitialReport = false; - SessionHandle mSecureHandle; }; } // namespace app } // namespace chip diff --git a/src/app/ReadPrepareParams.h b/src/app/ReadPrepareParams.h index 04f2e730c25f05..24031031892a11 100644 --- a/src/app/ReadPrepareParams.h +++ b/src/app/ReadPrepareParams.h @@ -39,10 +39,9 @@ struct ReadPrepareParams uint16_t mMinIntervalSeconds = 0; uint16_t mMaxIntervalSeconds = 0; - ReadPrepareParams() {} - ReadPrepareParams(ReadPrepareParams && other) + ReadPrepareParams(SessionHandle sessionHandle) : mSessionHandle(sessionHandle) {} + ReadPrepareParams(ReadPrepareParams && other) : mSessionHandle(other.mSessionHandle) { - mSessionHandle = other.mSessionHandle; mpEventPathParamsList = other.mpEventPathParamsList; mEventPathParamsListSize = other.mEventPathParamsListSize; mpAttributePathParamsList = other.mpAttributePathParamsList; diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp index bca0126cffba43..6c341145de203b 100644 --- a/src/app/WriteClient.cpp +++ b/src/app/WriteClient.cpp @@ -247,7 +247,7 @@ void WriteClient::ClearState() MoveToState(State::Uninitialized); } -CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, +CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, uint32_t timeout) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -263,14 +263,7 @@ CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricInde ClearExistingExchangeContext(); // Create a new exchange context. - if (apSecureSession == nullptr) - { - mpExchangeCtx = mpExchangeMgr->NewContext(SessionHandle(aNodeId, 0, 0, aFabricIndex), this); - } - else - { - mpExchangeCtx = mpExchangeMgr->NewContext(*apSecureSession, this); - } + mpExchangeCtx = mpExchangeMgr->NewContext(apSecureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(timeout); @@ -397,7 +390,7 @@ CHIP_ERROR WriteClient::ProcessAttributeStatusElement(AttributeStatusElement::Pa return err; } -CHIP_ERROR WriteClientHandle::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, +CHIP_ERROR WriteClientHandle::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, uint32_t timeout) { CHIP_ERROR err = mpWriteClient->SendWriteRequest(aNodeId, aFabricIndex, apSecureSession, timeout); diff --git a/src/app/WriteClient.h b/src/app/WriteClient.h index e757d2a755a9c2..e731cb19d27a02 100644 --- a/src/app/WriteClient.h +++ b/src/app/WriteClient.h @@ -94,7 +94,8 @@ class WriteClient : public Messaging::ExchangeDelegate * If SendWriteRequest is never called, or the call fails, the API * consumer is responsible for calling Shutdown on the WriteClient. */ - CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, uint32_t timeout); + CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, + uint32_t timeout); /** * Initialize the client object. Within the lifetime @@ -175,7 +176,7 @@ class WriteClientHandle * Finalize the message and send it to the desired node. The underlying write object will always be released, and the user * should not use this object after calling this function. */ - CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, SessionHandle * apSecureSession, + CHIP_ERROR SendWriteRequest(NodeId aNodeId, FabricIndex aFabricIndex, Optional apSecureSession, uint32_t timeout = kImMessageTimeoutMsec); /** diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp index 219fc88e0a7d4d..364a626ade5e88 100644 --- a/src/app/tests/TestCommandInteraction.cpp +++ b/src/app/tests/TestCommandInteraction.cpp @@ -231,7 +231,7 @@ void TestCommandInteraction::TestCommandSenderWithWrongState(nlTestSuite * apSui err = commandSender.Init(&gExchangeManager, nullptr); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, nullptr); + err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, Optional::Missing()); NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_INCORRECT_STATE); } @@ -270,7 +270,7 @@ void TestCommandInteraction::TestCommandSenderWithSendCommand(nlTestSuite * apSu NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); AddCommandDataElement(apSuite, apContext, &commandSender, false); - err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, nullptr); + err = commandSender.SendCommandRequest(kTestDeviceNodeId, gFabricIndex, Optional::Missing()); NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_NOT_CONNECTED); GenerateReceivedCommand(apSuite, apContext, buf, true /*aNeedCommandData*/); diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp index d57a50287f272e..8a1fe0035025e4 100644 --- a/src/app/tests/TestReadInteraction.cpp +++ b/src/app/tests/TestReadInteraction.cpp @@ -303,9 +303,8 @@ void TestReadInteraction::TestReadClient(nlTestSuite * apSuite, void * apContext System::PacketBufferHandle buf = System::PacketBufferHandle::New(System::PacketBuffer::kMaxSize); err = readClient.Init(&ctx.GetExchangeManager(), &delegate, 0 /* application identifier */); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - ReadPrepareParams readPrepareParams; - readPrepareParams.mSessionHandle = ctx.GetSessionLocalToPeer(); - err = readClient.SendReadRequest(readPrepareParams); + ReadPrepareParams readPrepareParams(ctx.GetSessionLocalToPeer()); + err = readClient.SendReadRequest(readPrepareParams); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); GenerateReportData(apSuite, apContext, buf); @@ -329,9 +328,8 @@ void TestReadInteraction::TestReadHandler(nlTestSuite * apSuite, void * apContex auto * engine = chip::app::InteractionModelEngine::GetInstance(); err = engine->Init(&ctx.GetExchangeManager(), &delegate); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - Messaging::ExchangeManager exchangeManager; - Messaging::ExchangeContext * exchangeCtx = exchangeManager.NewContext(SessionHandle(), nullptr); - readHandler.Init(&exchangeManager, nullptr, exchangeCtx); + Messaging::ExchangeContext * exchangeCtx = ctx.NewExchangeToPeer(nullptr); + readHandler.Init(&ctx.GetExchangeManager(), nullptr, exchangeCtx); GenerateReportData(apSuite, apContext, reportDatabuf); err = readHandler.SendReportData(std::move(reportDatabuf)); @@ -364,6 +362,7 @@ void TestReadInteraction::TestReadHandler(nlTestSuite * apSuite, void * apContex err = readHandler.OnReadInitialRequest(std::move(readRequestbuf)); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + exchangeCtx->Close(); engine->Shutdown(); } @@ -431,9 +430,8 @@ void TestReadInteraction::TestReadClientInvalidReport(nlTestSuite * apSuite, voi err = readClient.Init(&ctx.GetExchangeManager(), &delegate, 0 /* application identifier */); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - ReadPrepareParams readPrepareParams; - readPrepareParams.mSessionHandle = ctx.GetSessionLocalToPeer(); - err = readClient.SendReadRequest(readPrepareParams); + ReadPrepareParams readPrepareParams(ctx.GetSessionLocalToPeer()); + err = readClient.SendReadRequest(readPrepareParams); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); GenerateReportData(apSuite, apContext, buf, true /*aNeedInvalidReport*/); @@ -458,9 +456,8 @@ void TestReadInteraction::TestReadHandlerInvalidAttributePath(nlTestSuite * apSu auto * engine = chip::app::InteractionModelEngine::GetInstance(); err = engine->Init(&ctx.GetExchangeManager(), &delegate); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); - Messaging::ExchangeManager exchangeManager; - Messaging::ExchangeContext * exchangeCtx = exchangeManager.NewContext(SessionHandle(), nullptr); - readHandler.Init(&exchangeManager, nullptr, exchangeCtx); + Messaging::ExchangeContext * exchangeCtx = ctx.NewExchangeToPeer(nullptr); + readHandler.Init(&ctx.GetExchangeManager(), nullptr, exchangeCtx); GenerateReportData(apSuite, apContext, reportDatabuf); err = readHandler.SendReportData(std::move(reportDatabuf)); @@ -488,6 +485,8 @@ void TestReadInteraction::TestReadHandlerInvalidAttributePath(nlTestSuite * apSu err = readHandler.OnReadInitialRequest(std::move(readRequestbuf)); NL_TEST_ASSERT(apSuite, err == CHIP_ERROR_IM_MALFORMED_ATTRIBUTE_PATH); + + exchangeCtx->Close(); engine->Shutdown(); } @@ -638,8 +637,7 @@ void TestReadInteraction::TestReadRoundtrip(nlTestSuite * apSuite, void * apCont attributePathParams[1].mFlags.Set(chip::app::AttributePathParams::Flags::kFieldIdValid); attributePathParams[1].mFlags.Set(chip::app::AttributePathParams::Flags::kListIndexValid); - ReadPrepareParams readPrepareParams; - readPrepareParams.mSessionHandle = ctx.GetSessionLocalToPeer(); + ReadPrepareParams readPrepareParams(ctx.GetSessionLocalToPeer()); readPrepareParams.mpEventPathParamsList = eventPathParams; readPrepareParams.mEventPathParamsListSize = 2; readPrepareParams.mpAttributePathParamsList = attributePathParams; @@ -684,8 +682,7 @@ void TestReadInteraction::TestReadInvalidAttributePathRoundtrip(nlTestSuite * ap attributePathParams[0].mListIndex = 0; attributePathParams[0].mFlags.Set(chip::app::AttributePathParams::Flags::kFieldIdValid); - ReadPrepareParams readPrepareParams; - readPrepareParams.mSessionHandle = ctx.GetSessionLocalToPeer(); + ReadPrepareParams readPrepareParams(ctx.GetSessionLocalToPeer()); readPrepareParams.mpAttributePathParamsList = attributePathParams; readPrepareParams.mAttributePathParamsListSize = 1; err = chip::app::InteractionModelEngine::GetInstance()->SendReadRequest(readPrepareParams); diff --git a/src/app/tests/TestWriteInteraction.cpp b/src/app/tests/TestWriteInteraction.cpp index dec03ce4056240..466c9e9e2bb657 100644 --- a/src/app/tests/TestWriteInteraction.cpp +++ b/src/app/tests/TestWriteInteraction.cpp @@ -217,7 +217,8 @@ void TestWriteInteraction::TestWriteClient(nlTestSuite * apSuite, void * apConte AddAttributeDataElement(apSuite, apContext, writeClientHandle); SessionHandle session = ctx.GetSessionLocalToPeer(); - err = writeClientHandle.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), &session); + err = writeClientHandle.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), + Optional::Value(session)); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); // The internal WriteClient should be nullptr once we SendWriteRequest. NL_TEST_ASSERT(apSuite, nullptr == writeClientHandle.mpWriteClient); @@ -306,7 +307,7 @@ void TestWriteInteraction::TestWriteRoundtrip(nlTestSuite * apSuite, void * apCo SessionHandle session = ctx.GetSessionLocalToPeer(); - err = writeClient.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), &session); + err = writeClient.SendWriteRequest(ctx.GetDestinationNodeId(), ctx.GetFabricIndex(), Optional::Value(session)); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); NL_TEST_ASSERT(apSuite, delegate.mGotResponse); diff --git a/src/app/tests/integration/chip_im_initiator.cpp b/src/app/tests/integration/chip_im_initiator.cpp index 362d063fa2611a..ffa576f9396151 100644 --- a/src/app/tests/integration/chip_im_initiator.cpp +++ b/src/app/tests/integration/chip_im_initiator.cpp @@ -126,7 +126,8 @@ CHIP_ERROR SendCommandRequest(chip::app::CommandSender * commandSender) err = commandSender->FinishCommand(); SuccessOrExit(err); - err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec); + err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, chip::Optional::Missing(), + gMessageTimeoutMsec); SuccessOrExit(err); exit: @@ -163,7 +164,8 @@ CHIP_ERROR SendBadCommandRequest(chip::app::CommandSender * commandSender) err = commandSender->FinishCommand(); SuccessOrExit(err); - err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec); + err = commandSender->SendCommandRequest(chip::kTestDeviceNodeId, gFabricIndex, chip::Optional::Missing(), + gMessageTimeoutMsec); SuccessOrExit(err); exit: @@ -181,7 +183,6 @@ CHIP_ERROR SendBadCommandRequest(chip::app::CommandSender * commandSender) CHIP_ERROR SendReadRequest() { CHIP_ERROR err = CHIP_NO_ERROR; - chip::app::ReadPrepareParams readPrepareParams; chip::app::EventPathParams eventPathParams[2]; eventPathParams[0].mNodeId = kTestNodeId; eventPathParams[0].mEndpointId = kTestEndpointId; @@ -198,7 +199,7 @@ CHIP_ERROR SendReadRequest() printf("\nSend read request message to Node: %" PRIu64 "\n", chip::kTestDeviceNodeId); - readPrepareParams.mSessionHandle = chip::SessionHandle(chip::kTestDeviceNodeId, 0, 0, gFabricIndex); + chip::app::ReadPrepareParams readPrepareParams(chip::SessionHandle(chip::kTestDeviceNodeId, 0, 0, gFabricIndex)); readPrepareParams.mTimeout = gMessageTimeoutMsec; readPrepareParams.mpAttributePathParamsList = &attributePathParams; readPrepareParams.mAttributePathParamsListSize = 1; @@ -241,7 +242,8 @@ CHIP_ERROR SendWriteRequest(chip::app::WriteClientHandle & apWriteClient) SuccessOrExit(err = writer->PutBoolean(chip::TLV::ContextTag(chip::app::AttributeDataElement::kCsTag_Data), true)); SuccessOrExit(err = apWriteClient->FinishAttribute()); - SuccessOrExit(err = apWriteClient.SendWriteRequest(chip::kTestDeviceNodeId, gFabricIndex, nullptr, gMessageTimeoutMsec)); + SuccessOrExit(err = apWriteClient.SendWriteRequest(chip::kTestDeviceNodeId, gFabricIndex, + chip::Optional::Missing(), gMessageTimeoutMsec)); gWriteCount++; diff --git a/src/channel/ChannelContext.cpp b/src/channel/ChannelContext.cpp index 607ff8aaf88f7b..72523b7c019f75 100644 --- a/src/channel/ChannelContext.cpp +++ b/src/channel/ChannelContext.cpp @@ -258,7 +258,8 @@ void ChannelContext::EnterCasePairingState() auto & prepare = GetPrepareVars(); prepare.mCasePairingSession = Platform::New(); - ExchangeContext * ctxt = mExchangeManager->NewContext(SessionHandle(), prepare.mCasePairingSession); + ExchangeContext * ctxt = + mExchangeManager->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), prepare.mCasePairingSession); VerifyOrReturn(ctxt != nullptr); // TODO: currently only supports IP/UDP paring diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index fd1a2bd49b6d32..78b906912ee01e 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -73,7 +73,7 @@ CHIP_ERROR Device::SendMessage(Protocols::Id protocolId, uint8_t msgType, Messag ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(loadedSecureSession)); - Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(mSecureSession, nullptr); + Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(mSecureSession.Value(), nullptr); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_NO_MEMORY); if (!loadedSecureSession) @@ -129,10 +129,18 @@ CHIP_ERROR Device::LoadSecureSessionParametersIfNeeded(bool & didLoad) } else { - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); - - // Check if the connection state has the correct transport information - if (connectionState == nullptr || connectionState->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined) + if (mSecureSession.HasValue()) + { + Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value()); + // Check if the connection state has the correct transport information + if (connectionState->GetPeerAddress().GetTransportType() == Transport::Type::kUndefined) + { + mState = ConnectionState::NotConnected; + ReturnErrorOnFailure(LoadSecureSessionParameters(ResetTransport::kNo)); + didLoad = true; + } + } + else { mState = ConnectionState::NotConnected; ReturnErrorOnFailure(LoadSecureSessionParameters(ResetTransport::kNo)); @@ -148,7 +156,7 @@ CHIP_ERROR Device::SendCommands(app::CommandSender * commandObj) bool loadedSecureSession = false; ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(loadedSecureSession)); VerifyOrReturnError(commandObj != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - return commandObj->SendCommandRequest(mDeviceId, mFabricIndex, &mSecureSession); + return commandObj->SendCommandRequest(mDeviceId, mFabricIndex, mSecureSession); } CHIP_ERROR Device::Serialize(SerializedDevice & output) @@ -166,14 +174,13 @@ CHIP_ERROR Device::Serialize(SerializedDevice & output) serializable.mDevicePort = Encoding::LittleEndian::HostSwap16(mDeviceAddress.GetPort()); serializable.mFabricIndex = Encoding::LittleEndian::HostSwap16(mFabricIndex); - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); - // The connection state could be null if the device is moving from PASE connection to CASE connection. // The device parameters (e.g. mDeviceOperationalCertProvisioned) are updated during this transition. // The state during this transistion is being persisted so that the next access of the device will // trigger the CASE based secure session. - if (connectionState != nullptr) + if (mSecureSession.HasValue()) { + Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value()); const uint32_t localMessageCounter = connectionState->GetSessionMessageCounter().GetLocalMessageCounter().Value(); const uint32_t peerMessageCounter = connectionState->GetSessionMessageCounter().GetPeerMessageCounter().GetCounter(); @@ -303,13 +310,13 @@ CHIP_ERROR Device::Persist() void Device::OnNewConnection(SessionHandle session) { - mState = ConnectionState::SecureConnected; - mSecureSession = session; + mState = ConnectionState::SecureConnected; + mSecureSession.SetValue(session); // Reset the message counters here because this is the first time we get a handle to the secure session. // Since CHIPDevices can be serialized/deserialized in the middle of what is conceptually a single PASE session // we need to restore the session counters along with the session information. - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); + Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value()); VerifyOrReturn(connectionState != nullptr); MessageCounter & localCounter = connectionState->GetSessionMessageCounter().GetLocalMessageCounter(); if (localCounter.SetCounter(mLocalMessageCounter) != CHIP_NO_ERROR) @@ -322,10 +329,10 @@ void Device::OnNewConnection(SessionHandle session) void Device::OnConnectionExpired(SessionHandle session) { - VerifyOrReturn(session == mSecureSession, + VerifyOrReturn(mSecureSession.HasValue() && mSecureSession.Value() == session, ChipLogDetail(Controller, "Connection expired, but it doesn't match the current session")); - mState = ConnectionState::NotConnected; - mSecureSession = SessionHandle{}; + mState = ConnectionState::NotConnected; + mSecureSession.ClearValue(); } CHIP_ERROR Device::OnMessageReceived(Messaging::ExchangeContext * exchange, const PacketHeader & header, @@ -407,7 +414,10 @@ CHIP_ERROR Device::OpenPairingWindow(uint16_t timeout, PairingWindowOption optio CHIP_ERROR Device::CloseSession() { ReturnErrorCodeIf(mState != ConnectionState::SecureConnected, CHIP_ERROR_INCORRECT_STATE); - mSessionManager->ExpirePairing(mSecureSession); + if (mSecureSession.HasValue()) + { + mSessionManager->ExpirePairing(mSecureSession.Value()); + } mState = ConnectionState::NotConnected; return CHIP_NO_ERROR; } @@ -420,8 +430,7 @@ CHIP_ERROR Device::UpdateAddress(const Transport::PeerAddress & addr) ReturnErrorOnFailure(LoadSecureSessionParametersIfNeeded(didLoad)); - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); - if (connectionState == nullptr) + if (!mSecureSession.HasValue()) { // Nothing needs to be done here. It's not an error to not have a // connectionState. For one thing, we could have gotten an different @@ -430,6 +439,7 @@ CHIP_ERROR Device::UpdateAddress(const Transport::PeerAddress & addr) return CHIP_NO_ERROR; } + Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession.Value()); connectionState->SetPeerAddress(addr); return CHIP_NO_ERROR; @@ -440,8 +450,8 @@ void Device::Reset() if (IsActive() && mStorageDelegate != nullptr && mSessionManager != nullptr) { // If a session can be found, persist the device so that we track the newest message counter values - Transport::PeerConnectionState * connectionState = mSessionManager->GetPeerConnectionState(mSecureSession); - if (connectionState != nullptr) + + if (mSecureSession.HasValue()) { Persist(); } @@ -549,7 +559,8 @@ CHIP_ERROR Device::WarmupCASESession() VerifyOrReturnError(mDeviceOperationalCertProvisioned, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(mState == ConnectionState::NotConnected, CHIP_NO_ERROR); - Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(SessionHandle(), &mCASESession); + Messaging::ExchangeContext * exchange = + mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mCASESession); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager())); @@ -687,7 +698,6 @@ void Device::AddReportHandler(EndpointId endpoint, ClusterId cluster, AttributeI CHIP_ERROR Device::SendReadAttributeRequest(app::AttributePathParams aPath, Callback::Cancelable * onSuccessCallback, Callback::Cancelable * onFailureCallback, app::TLVDataFilter aTlvDataFilter) { - chip::app::ReadPrepareParams readPrepareParams; bool loadedSecureSession = false; uint8_t seqNum = GetNextSequenceNumber(); aPath.mNodeId = GetDeviceId(); @@ -700,7 +710,7 @@ CHIP_ERROR Device::SendReadAttributeRequest(app::AttributePathParams aPath, Call } // The application context is used to identify different requests from client applicaiton the type of it is intptr_t, here we // use the seqNum. - readPrepareParams.mSessionHandle = mSecureSession; + chip::app::ReadPrepareParams readPrepareParams(mSecureSession.Value()); readPrepareParams.mpAttributePathParamsList = &aPath; readPrepareParams.mAttributePathParamsListSize = 1; CHIP_ERROR err = @@ -726,7 +736,7 @@ CHIP_ERROR Device::SendWriteAttributeRequest(app::WriteClientHandle aHandle, Cal { AddResponseHandler(seqNum, onSuccessCallback, onFailureCallback); } - if ((err = aHandle.SendWriteRequest(GetDeviceId(), 0, &mSecureSession)) != CHIP_NO_ERROR) + if ((err = aHandle.SendWriteRequest(GetDeviceId(), 0, mSecureSession)) != CHIP_NO_ERROR) { CancelResponseHandler(seqNum); } diff --git a/src/controller/CHIPDevice.h b/src/controller/CHIPDevice.h index ee7a09b23109f5..3c15aa032e76b8 100644 --- a/src/controller/CHIPDevice.h +++ b/src/controller/CHIPDevice.h @@ -366,9 +366,9 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta NodeId GetDeviceId() const { return mDeviceId; } - bool MatchesSession(SessionHandle session) const { return mSecureSession == session; } + bool MatchesSession(SessionHandle session) const { return mSecureSession.HasValue() && mSecureSession.Value() == session; } - SessionHandle GetSecureSession() const { return mSecureSession; } + SessionHandle GetSecureSession() const { return mSecureSession.Value(); } void SetAddress(const Inet::IPAddress & deviceAddr) { mDeviceAddress.SetIPAddress(deviceAddr); } @@ -474,7 +474,7 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta Messaging::ExchangeManager * mExchangeMgr = nullptr; - SessionHandle mSecureSession = {}; + Optional mSecureSession = Optional::Missing(); uint8_t mSequenceNumber = 0; diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 25876140238708..eb7c466c6ac10c 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -883,7 +883,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam } } #endif - exchangeCtxt = mExchangeMgr->NewContext(SessionHandle(), &mPairingSession); + exchangeCtxt = mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mPairingSession); VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL); err = mIDAllocator.Allocate(keyID); diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 673fbc2a60c3a6..93def28ed465b2 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -93,6 +93,7 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp Transport::PeerConnectionState * state = nullptr; VerifyOrReturnError(mExchangeMgr != nullptr, CHIP_ERROR_INTERNAL); + VerifyOrReturnError(mSecureSession.HasValue(), CHIP_ERROR_CONNECTION_ABORTED); // Don't let method get called on a freed object. VerifyOrDie(mExchangeMgr != nullptr && GetReferenceCount() > 0); @@ -104,7 +105,7 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp bool reliableTransmissionRequested = true; - state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession); + state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession.Value()); // If sending via UDP and NoAutoRequestAck send flag is not specificed, request reliable transmission. if (state != nullptr && state->GetPeerAddress().GetTransportType() != Transport::Type::kUdp) { @@ -141,7 +142,7 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp { // Create a new scope for `err`, to avoid shadowing warning previous `err`. - CHIP_ERROR err = mDispatch->SendMessage(mSecureSession, mExchangeId, IsInitiator(), GetReliableMessageContext(), + CHIP_ERROR err = mDispatch->SendMessage(mSecureSession.Value(), mExchangeId, IsInitiator(), GetReliableMessageContext(), reliableTransmissionRequested, protocolId, msgType, std::move(msgBuf)); if (err != CHIP_NO_ERROR && IsResponseExpected()) { @@ -233,9 +234,9 @@ ExchangeContext::ExchangeContext(ExchangeManager * em, uint16_t ExchangeId, Sess { VerifyOrDie(mExchangeMgr == nullptr); - mExchangeMgr = em; - mExchangeId = ExchangeId; - mSecureSession = session; + mExchangeMgr = em; + mExchangeId = ExchangeId; + mSecureSession.SetValue(session); mFlags.Set(Flags::kFlagInitiator, Initiator); mDelegate = delegate; @@ -305,7 +306,7 @@ bool ExchangeContext::MatchExchange(SessionHandle session, const PacketHeader & (mExchangeId == payloadHeader.GetExchangeID()) // AND The Session ID associated with the incoming message matches the Session ID associated with the exchange. - && (mSecureSession.MatchIncomingSession(session)) + && (mSecureSession.HasValue() && mSecureSession.Value().MatchIncomingSession(session)) // AND The message was sent by an initiator and the exchange context is a responder (IsInitiator==false) // OR The message was sent by a responder and the exchange context is an initiator (IsInitiator==true) (for the broadcast @@ -320,7 +321,7 @@ void ExchangeContext::OnConnectionExpired() // connection state) value, because it's still referencing the now-expired // connection. This will mean that no more messages can be sent via this // exchange, which seems fine given the semantics of connection expiration. - mSecureSession = SessionHandle(); + mSecureSession.ClearValue(); if (!IsResponseExpected()) { diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index 4a99c182c81c03..854a8c944ddb82 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -156,7 +156,7 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen { if (mExchangeACL == nullptr) { - Transport::FabricInfo * fabric = table.FindFabricWithIndex(mSecureSession.GetFabricIndex()); + Transport::FabricInfo * fabric = table.FindFabricWithIndex(mSecureSession.Value().GetFabricIndex()); if (fabric != nullptr) { mExchangeACL = chip::Platform::New(fabric); @@ -166,7 +166,7 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen return mExchangeACL; } - SessionHandle GetSecureSession() { return mSecureSession; } + SessionHandle GetSecureSession() { return mSecureSession.Value(); } uint16_t GetExchangeId() const { return mExchangeId; } @@ -188,8 +188,8 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, public Referen ExchangeMessageDispatch * mDispatch = nullptr; - SessionHandle mSecureSession; // The connection state - uint16_t mExchangeId; // Assigned exchange ID. + Optional mSecureSession; // The connection state + uint16_t mExchangeId; // Assigned exchange ID. /** * Determine whether a response is currently expected for a message that was sent over diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index dd69e2cf2fb03a..58be3154c0f665 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -315,7 +315,7 @@ void ExchangeManager::OnConnectionExpired(SessionHandle session) } mContextPool.ForEachActiveObject([&](auto * ec) { - if (ec->mSecureSession == session) + if (ec->mSecureSession.HasValue() && ec->mSecureSession.Value() == session) { ec->OnConnectionExpired(); // Continue to iterate because there can be multiple exchanges diff --git a/src/protocols/echo/Echo.h b/src/protocols/echo/Echo.h index 11bf7e99aec2fb..3ee26e03b528a2 100644 --- a/src/protocols/echo/Echo.h +++ b/src/protocols/echo/Echo.h @@ -103,7 +103,7 @@ class DLL_EXPORT EchoClient : public Messaging::ExchangeDelegate Messaging::ExchangeManager * mExchangeMgr = nullptr; Messaging::ExchangeContext * mExchangeCtx = nullptr; EchoFunct OnEchoResponseReceived = nullptr; - SessionHandle mSecureSession; + Optional mSecureSession = Optional(); CHIP_ERROR OnMessageReceived(Messaging::ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, System::PacketBufferHandle && payload) override; diff --git a/src/protocols/echo/EchoClient.cpp b/src/protocols/echo/EchoClient.cpp index e1bb0972cb04b8..cb76a34ac599f9 100644 --- a/src/protocols/echo/EchoClient.cpp +++ b/src/protocols/echo/EchoClient.cpp @@ -38,8 +38,8 @@ CHIP_ERROR EchoClient::Init(Messaging::ExchangeManager * exchangeMgr, SessionHan if (mExchangeMgr != nullptr) return CHIP_ERROR_INCORRECT_STATE; - mExchangeMgr = exchangeMgr; - mSecureSession = session; + mExchangeMgr = exchangeMgr; + mSecureSession.SetValue(session); OnEchoResponseReceived = nullptr; mExchangeCtx = nullptr; @@ -71,7 +71,7 @@ CHIP_ERROR EchoClient::SendEchoRequest(System::PacketBufferHandle && payload, Me } // Create a new exchange context. - mExchangeCtx = mExchangeMgr->NewContext(mSecureSession, this); + mExchangeCtx = mExchangeMgr->NewContext(mSecureSession.Value(), this); if (mExchangeCtx == nullptr) { return CHIP_ERROR_NO_MEMORY; diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 5229624a19b7d3..d5ffcb2d57ff81 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -296,7 +296,7 @@ void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const { PayloadHeader payloadHeader; ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); - mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(), peerAddress, + mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle::TemporaryUnauthenticatedSession(), peerAddress, SecureSessionMgrDelegate::DuplicateMessage::No, std::move(msg)); } } diff --git a/src/transport/SessionHandle.h b/src/transport/SessionHandle.h index 238d82fbfb1394..fe9f0cadf26a5e 100644 --- a/src/transport/SessionHandle.h +++ b/src/transport/SessionHandle.h @@ -24,8 +24,6 @@ class SecureSessionMgr; class SessionHandle { public: - SessionHandle() : mPeerNodeId(kPlaceholderNodeId), mFabric(Transport::kUndefinedFabricIndex) {} - SessionHandle(NodeId peerNodeId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric) {} SessionHandle(NodeId peerNodeId, uint16_t localKeyId, uint16_t peerKeyId, FabricIndex fabric) : @@ -64,6 +62,12 @@ class SessionHandle const Optional & GetPeerKeyId() const { return mPeerKeyId; } const Optional & GetLocalKeyId() const { return mLocalKeyId; } + // TODO: currently SessionHandle is not able to identify a unauthenticated session, create an empty handle for it + static SessionHandle TemporaryUnauthenticatedSession() + { + return SessionHandle(kPlaceholderNodeId, Transport::kUndefinedFabricIndex); + } + private: friend class SecureSessionMgr; NodeId mPeerNodeId; diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index 2092c29e712468..23b43774826a7d 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -66,7 +66,7 @@ class TestSessMgrCallback : public SecureSessionMgrDelegate const Transport::PeerAddress & source, DuplicateMessage isDuplicate, System::PacketBufferHandle && msgBuf) override { - NL_TEST_ASSERT(mSuite, session == mRemoteToLocalSession); // Packet received by remote peer + NL_TEST_ASSERT(mSuite, session == mRemoteToLocalSession.Value()); // Packet received by remote peer size_t data_len = msgBuf->DataLength(); @@ -88,20 +88,20 @@ class TestSessMgrCallback : public SecureSessionMgrDelegate { // Preset the MessageCounter if (NewConnectionHandlerCallCount == 0) - mRemoteToLocalSession = session; + mRemoteToLocalSession.SetValue(session); if (NewConnectionHandlerCallCount == 1) - mLocalToRemoteSession = session; + mLocalToRemoteSession.SetValue(session); NewConnectionHandlerCallCount++; } void OnConnectionExpired(SessionHandle session) override { mOldConnectionDropped = true; } bool mOldConnectionDropped = false; - nlTestSuite * mSuite = nullptr; - SessionHandle mRemoteToLocalSession; - SessionHandle mLocalToRemoteSession; - int ReceiveHandlerCallCount = 0; - int NewConnectionHandlerCallCount = 0; + nlTestSuite * mSuite = nullptr; + Optional mRemoteToLocalSession = Optional::Missing(); + Optional mLocalToRemoteSession = Optional::Missing(); + int ReceiveHandlerCallCount = 0; + int NewConnectionHandlerCallCount = 0; bool LargeMessageSent = false; }; @@ -166,7 +166,7 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession; + SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value(); // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -256,7 +256,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession; + SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value(); // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -279,7 +279,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // Reset receive side message counter, or duplicated message will be denied. - Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession); + Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession.Value()); state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -330,7 +330,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairing2, SecureSession::SessionRole::kResponder, 0); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - SessionHandle localToRemoteSession = callback.mLocalToRemoteSession; + SessionHandle localToRemoteSession = callback.mLocalToRemoteSession.Value(); // Should be able to send a message to itself by just calling send. callback.ReceiveHandlerCallCount = 0; @@ -356,7 +356,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) /* -------------------------------------------------------------------------------------------*/ // Reset receive side message counter, or duplicated message will be denied. - Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession); + Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession.Value()); state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); PacketHeader packetHeader; diff --git a/src/transport/tests/TestSessionHandle.cpp b/src/transport/tests/TestSessionHandle.cpp index 8a31c8e78acf3e..a4793c6fa2dcd2 100644 --- a/src/transport/tests/TestSessionHandle.cpp +++ b/src/transport/tests/TestSessionHandle.cpp @@ -37,24 +37,8 @@ using namespace chip; -void TestInitialState(nlTestSuite * inSuite, void * inContext) -{ - SessionHandle session; - - NL_TEST_ASSERT(inSuite, session.GetPeerNodeId() == kPlaceholderNodeId); - NL_TEST_ASSERT(inSuite, session.GetFabricIndex() == Transport::kUndefinedFabricIndex); - NL_TEST_ASSERT(inSuite, !session.HasFabricIndex()); - NL_TEST_ASSERT(inSuite, !session.GetLocalKeyId().HasValue()); - NL_TEST_ASSERT(inSuite, !session.GetPeerKeyId().HasValue()); -} - void TestMatchSession(nlTestSuite * inSuite, void * inContext) { - SessionHandle session1; - SessionHandle session2; - NL_TEST_ASSERT(inSuite, session1 == session2); - NL_TEST_ASSERT(inSuite, session1.MatchIncomingSession(session2)); - SessionHandle session3(chip::kTestDeviceNodeId, 1, 1, 0); SessionHandle session4(chip::kTestDeviceNodeId, 1, 2, 0); NL_TEST_ASSERT(inSuite, !(session3 == session4)); @@ -69,7 +53,6 @@ void TestMatchSession(nlTestSuite * inSuite, void * inContext) // clang-format off static const nlTest sTests[] = { - NL_TEST_DEF("InitialState", TestInitialState), NL_TEST_DEF("MatchSession", TestMatchSession), NL_TEST_SENTINEL() };