diff --git a/src/app/server/RendezvousServer.cpp b/src/app/server/RendezvousServer.cpp index b3fb76357472fa..06eff52b32a898 100644 --- a/src/app/server/RendezvousServer.cpp +++ b/src/app/server/RendezvousServer.cpp @@ -118,8 +118,7 @@ CHIP_ERROR RendezvousServer::WaitForPairing(const RendezvousParameters & params, ReturnErrorOnFailure(mPairingSession.WaitForPairing(params.GetSetupPINCode(), pbkdf2IterCount, salt, keyID, this)); } - ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr)); - mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress()); + ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mSessionMgr)); return CHIP_NO_ERROR; } diff --git a/src/channel/ChannelContext.cpp b/src/channel/ChannelContext.cpp index 72523b7c019f75..800b24c380ccf2 100644 --- a/src/channel/ChannelContext.cpp +++ b/src/channel/ChannelContext.cpp @@ -258,13 +258,22 @@ void ChannelContext::EnterCasePairingState() auto & prepare = GetPrepareVars(); prepare.mCasePairingSession = Platform::New(); - ExchangeContext * ctxt = - mExchangeManager->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), prepare.mCasePairingSession); - VerifyOrReturn(ctxt != nullptr); - // TODO: currently only supports IP/UDP paring Transport::PeerAddress addr; addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(prepare.mAddress); + + auto session = mExchangeManager->GetSessionMgr()->CreateUnauthenticatedSession(addr); + if (!session.HasValue()) + { + ExitCasePairingState(); + ExitPreparingState(); + EnterFailedState(CHIP_ERROR_NO_MEMORY); + return; + } + + ExchangeContext * ctxt = mExchangeManager->NewContext(session.Value(), prepare.mCasePairingSession); + VerifyOrReturn(ctxt != nullptr); + Transport::FabricInfo * fabric = mFabricsTable->FindFabricWithIndex(mFabricIndex); VerifyOrReturn(fabric != nullptr); CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession(addr, fabric, prepare.mBuilder.GetPeerNodeId(), diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index 52a8e42f328881..94b6e799c03ce5 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -560,12 +560,17 @@ CHIP_ERROR Device::WarmupCASESession() VerifyOrReturnError(mDeviceOperationalCertProvisioned, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(mState == ConnectionState::NotConnected, CHIP_NO_ERROR); - Messaging::ExchangeContext * exchange = - mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mCASESession); + // Create a UnauthenticatedSession for CASE pairing. + // Don't use mSecureSession here, because mSecureSession is the secure session. + Optional session = mSessionManager->CreateUnauthenticatedSession(mDeviceAddress); + if (!session.HasValue()) + { + return CHIP_ERROR_NO_MEMORY; + } + Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(session.Value(), &mCASESession); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); - ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager())); - mCASESession.MessageDispatch().SetPeerAddress(mDeviceAddress); + ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager)); uint16_t keyID = 0; ReturnErrorOnFailure(mIDAllocator->Allocate(keyID)); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index ffe976bf2811cf..c535d10dcb0b12 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -805,6 +805,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam Transport::PeerAddress peerAddress = Transport::PeerAddress::UDP(Inet::IPAddress::Any); Messaging::ExchangeContext * exchangeCtxt = nullptr; + Optional session; uint16_t keyID = 0; @@ -857,9 +858,8 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle); - err = mPairingSession.MessageDispatch().Init(mTransportMgr); + err = mPairingSession.MessageDispatch().Init(mSessionMgr); SuccessOrExit(err); - mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress()); device->Init(GetControllerDeviceInitParams(), mListenPort, remoteDeviceId, peerAddress, fabric->GetFabricIndex()); @@ -885,7 +885,10 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam } } #endif - exchangeCtxt = mExchangeMgr->NewContext(SessionHandle::TemporaryUnauthenticatedSession(), &mPairingSession); + session = mSessionMgr->CreateUnauthenticatedSession(params.GetPeerAddress()); + VerifyOrExit(session.HasValue(), CHIP_ERROR_NO_MEMORY); + + exchangeCtxt = mExchangeMgr->NewContext(session.Value(), &mPairingSession); VerifyOrExit(exchangeCtxt != nullptr, err = CHIP_ERROR_INTERNAL); err = mIDAllocator.Allocate(keyID); diff --git a/src/lib/core/CHIPConfig.h b/src/lib/core/CHIPConfig.h index 6b2169096f0142..b3407b32e81328 100644 --- a/src/lib/core/CHIPConfig.h +++ b/src/lib/core/CHIPConfig.h @@ -2251,6 +2251,18 @@ #define CHIP_CONFIG_ENABLE_IFJ_SERVICE_FABRIC_JOIN 0 #endif // CHIP_CONFIG_ENABLE_IFJ_SERVICE_FABRIC_JOIN +/** + * @def CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE + * + * @brief Define the size of the pool used for tracking CHIP unauthenticated + * states. The entries in the pool are automatically rotated by LRU. The size + * of the pool limits how many PASE and CASE pairing sessions can be processed + * simultaneously. + */ +#ifndef CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE +#define CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE 4 +#endif // CHIP_CONFIG_UNAUTHENTICATED_CONNECTION_POOL_SIZE + /** * @def CHIP_CONFIG_PEER_CONNECTION_POOL_SIZE * diff --git a/src/lib/core/CHIPError.cpp b/src/lib/core/CHIPError.cpp index d8a6eb2ff7c61f..fc6623e608ab53 100644 --- a/src/lib/core/CHIPError.cpp +++ b/src/lib/core/CHIPError.cpp @@ -644,6 +644,9 @@ bool FormatCHIPError(char * buf, uint16_t bufSize, CHIP_ERROR err) case CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED.AsInteger(): desc = "Duplicate message received"; break; + case CHIP_ERROR_MESSAGE_ID_OUT_OF_WINDOW.AsInteger(): + desc = "Message id out of window"; + break; } #endif // !CHIP_CONFIG_SHORT_ERROR_STR diff --git a/src/lib/core/CHIPError.h b/src/lib/core/CHIPError.h index c4a7d63fd86157..a50be6da4d441a 100644 --- a/src/lib/core/CHIPError.h +++ b/src/lib/core/CHIPError.h @@ -2173,6 +2173,14 @@ using CHIP_ERROR = ::chip::ChipError; */ #define CHIP_ERROR_FABRIC_MISMATCH_ON_ICA CHIP_CORE_ERROR(0xc6) +/** + * @def CHIP_ERROR_MESSAGE_ID_OUT_OF_WINDOW + * + * @brief + * The message id of the received message is out of receiving window + */ +#define CHIP_ERROR_MESSAGE_ID_OUT_OF_WINDOW CHIP_CORE_ERROR(0xc7) + /** * @} */ diff --git a/src/lib/core/InPlace.h b/src/lib/core/InPlace.h index 01a5a1f2bd9ee4..647fc51c6e87d1 100644 --- a/src/lib/core/InPlace.h +++ b/src/lib/core/InPlace.h @@ -23,10 +23,6 @@ */ #pragma once -#include -#include -#include - namespace chip { /// InPlace is disambiguation tags that can be passed to the constructors to indicate that the contained object should be diff --git a/src/lib/support/Pool.h b/src/lib/support/Pool.h index 2158929ca216ee..f9fe67aaa8d735 100644 --- a/src/lib/support/Pool.h +++ b/src/lib/support/Pool.h @@ -87,7 +87,7 @@ template class BitMapObjectPool : public StaticAllocatorBitmap { public: - BitMapObjectPool() : StaticAllocatorBitmap(mMemory, mUsage, N, sizeof(T)) {} + BitMapObjectPool() : StaticAllocatorBitmap(mData.mMemory, mUsage, N, sizeof(T)) {} static size_t Size() { return N; } @@ -110,6 +110,13 @@ class BitMapObjectPool : public StaticAllocatorBitmap Deallocate(element); } + template + void ResetObject(T * element, Args &&... args) + { + element->~T(); + new (element) T(std::forward(args)...); + } + /** * @brief * Run a functor for each active object in the pool @@ -144,7 +151,13 @@ class BitMapObjectPool : public StaticAllocatorBitmap }; std::atomic mUsage[(N + kBitChunkSize - 1) / kBitChunkSize]; - alignas(alignof(T)) uint8_t mMemory[N * sizeof(T)]; + union Data + { + Data() {} + ~Data() {} + alignas(alignof(T)) uint8_t mMemory[N * sizeof(T)]; + T mMemoryViewForDebug[N]; // Just for debugger + } mData; }; } // namespace chip diff --git a/src/lib/support/ReferenceCountedHandle.h b/src/lib/support/ReferenceCountedHandle.h index 637fe514405c31..c433a639db8d7a 100644 --- a/src/lib/support/ReferenceCountedHandle.h +++ b/src/lib/support/ReferenceCountedHandle.h @@ -28,14 +28,19 @@ class ReferenceCountedHandle explicit ReferenceCountedHandle(Target & target) : mTarget(target) { mTarget.Retain(); } ~ReferenceCountedHandle() { mTarget.Release(); } - ReferenceCountedHandle(const ReferenceCountedHandle & that) = delete; + ReferenceCountedHandle(const ReferenceCountedHandle & that) : mTarget(that.mTarget) { mTarget.Retain(); } + + ReferenceCountedHandle(ReferenceCountedHandle && that) : mTarget(that.mTarget) { mTarget.Retain(); } + ReferenceCountedHandle & operator=(const ReferenceCountedHandle & that) = delete; - ReferenceCountedHandle(ReferenceCountedHandle && that) = delete; ReferenceCountedHandle & operator=(ReferenceCountedHandle && that) = delete; bool operator==(const ReferenceCountedHandle & that) const { return &mTarget == &that.mTarget; } bool operator!=(const ReferenceCountedHandle & that) const { return !(*this == that); } + Target * operator->() { return &mTarget; } + Target & Get() const { return mTarget; } + private: Target & mTarget; }; diff --git a/src/messaging/ApplicationExchangeDispatch.cpp b/src/messaging/ApplicationExchangeDispatch.cpp index 7e7caf59c1e8e0..76b1dbb420e17a 100644 --- a/src/messaging/ApplicationExchangeDispatch.cpp +++ b/src/messaging/ApplicationExchangeDispatch.cpp @@ -30,7 +30,7 @@ CHIP_ERROR ApplicationExchangeDispatch::PrepareMessage(SessionHandle session, Pa System::PacketBufferHandle && message, EncryptedPacketBufferHandle & preparedMessage) { - return mSessionMgr->BuildEncryptedMessagePayload(session, payloadHeader, std::move(message), preparedMessage); + return mSessionMgr->PrepareMessage(session, payloadHeader, std::move(message), preparedMessage); } CHIP_ERROR ApplicationExchangeDispatch::SendPreparedMessage(SessionHandle session, diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index e937376c712392..25e92b4029a271 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -66,6 +66,8 @@ class ExchangeMessageDispatch : public ReferenceCounted protected: virtual bool MessagePermitted(uint16_t protocol, uint8_t type) = 0; + + // TODO: remove IsReliableTransmissionAllowed, this function should be provided over session. virtual bool IsReliableTransmissionAllowed() const { return true; } }; diff --git a/src/messaging/ReliableMessageMgr.cpp b/src/messaging/ReliableMessageMgr.cpp index c180b9d5701224..41b8baa60975f0 100644 --- a/src/messaging/ReliableMessageMgr.cpp +++ b/src/messaging/ReliableMessageMgr.cpp @@ -406,8 +406,8 @@ void ReliableMessageMgr::ClearRetransTable(RetransTableEntry & rEntry) // Expire any virtual ticks that have expired so all wakeup sources reflect the current time ExpireTicks(); - rEntry.rc->ReleaseContext(); rEntry.rc->SetOccupied(false); + rEntry.rc->ReleaseContext(); rEntry.rc = nullptr; // Clear all other fields diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp index dc96e07fb9146e..6f870825ebc09b 100644 --- a/src/messaging/tests/MessagingContext.cpp +++ b/src/messaging/tests/MessagingContext.cpp @@ -37,11 +37,12 @@ CHIP_ERROR MessagingContext::Init(nlTestSuite * suite, TransportMgrBase * transp ReturnErrorOnFailure(mExchangeManager.Init(&mSecureSessionMgr)); ReturnErrorOnFailure(mMessageCounterManager.Init(&mExchangeManager)); - ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(mAddress, GetAliceNodeId(), &mPairingBobToAlice, - SecureSession::SessionRole::kInitiator, mSrcFabricIndex)); + ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(Optional::Value(mAliceAddress), GetAliceNodeId(), + &mPairingBobToAlice, SecureSession::SessionRole::kInitiator, + mSrcFabricIndex)); - return mSecureSessionMgr.NewPairing(mAddress, GetBobNodeId(), &mPairingAliceToBob, SecureSession::SessionRole::kResponder, - mDestFabricIndex); + return mSecureSessionMgr.NewPairing(Optional::Value(mBobAddress), GetBobNodeId(), &mPairingAliceToBob, + SecureSession::SessionRole::kResponder, mDestFabricIndex); } // Shutdown all layers, finalize operations @@ -67,6 +68,16 @@ SessionHandle MessagingContext::GetSessionAliceToBob() return SessionHandle(GetBobNodeId(), GetAliceKeyId(), GetBobKeyId(), mDestFabricIndex); } +Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToAlice(Messaging::ExchangeDelegate * delegate) +{ + return mExchangeManager.NewContext(mSecureSessionMgr.CreateUnauthenticatedSession(mAliceAddress).Value(), delegate); +} + +Messaging::ExchangeContext * MessagingContext::NewUnauthenticatedExchangeToBob(Messaging::ExchangeDelegate * delegate) +{ + return mExchangeManager.NewContext(mSecureSessionMgr.CreateUnauthenticatedSession(mBobAddress).Value(), delegate); +} + Messaging::ExchangeContext * MessagingContext::NewExchangeToAlice(Messaging::ExchangeDelegate * delegate) { // TODO: temprary create a SessionHandle from node id, will be fix in PR 3602 diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index 6bb4e70df69eb0..735278c344117e 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -37,8 +37,9 @@ class MessagingContext { public: MessagingContext() : - mInitialized(false), mAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)), - mPairingAliceToBob(GetBobKeyId(), GetAliceKeyId()), mPairingBobToAlice(GetAliceKeyId(), GetBobKeyId()) + mInitialized(false), mAliceAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT + 1)), + mBobAddress(Transport::PeerAddress::UDP(GetAddress(), CHIP_PORT)), mPairingAliceToBob(GetBobKeyId(), GetAliceKeyId()), + mPairingBobToAlice(GetAliceKeyId(), GetBobKeyId()) {} ~MessagingContext() { VerifyOrDie(mInitialized == false); } @@ -80,6 +81,9 @@ class MessagingContext SessionHandle GetSessionBobToAlice(); SessionHandle GetSessionAliceToBob(); + Messaging::ExchangeContext * NewUnauthenticatedExchangeToAlice(Messaging::ExchangeDelegate * delegate); + Messaging::ExchangeContext * NewUnauthenticatedExchangeToBob(Messaging::ExchangeDelegate * delegate); + Messaging::ExchangeContext * NewExchangeToAlice(Messaging::ExchangeDelegate * delegate); Messaging::ExchangeContext * NewExchangeToBob(Messaging::ExchangeDelegate * delegate); @@ -98,7 +102,8 @@ class MessagingContext NodeId mAliceNodeId = 111222333; uint16_t mBobKeyId = 1; uint16_t mAliceKeyId = 2; - Optional mAddress; + Transport::PeerAddress mAliceAddress; + Transport::PeerAddress mBobAddress; SecurePairingUsingTestSecret mPairingAliceToBob; SecurePairingUsingTestSecret mPairingBobToAlice; Transport::FabricTable mFabrics; diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 1f5a4b94ab217f..6f8874a9a1d702 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -124,26 +124,9 @@ class MockAppDelegate : public ExchangeDelegate nlTestSuite * mTestSuite = nullptr; }; -class MockSessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDispatch +class MockSessionEstablishmentExchangeDispatch : public Messaging::ApplicationExchangeDispatch { public: - CHIP_ERROR PrepareMessage(SessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && message, - EncryptedPacketBufferHandle & preparedMessage) override - { - PacketHeader packetHeader; - - ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); - ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); - - preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(message)); - return CHIP_NO_ERROR; - } - - CHIP_ERROR SendPreparedMessage(SessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const override - { - return gTransportMgr.SendMessage(Transport::PeerAddress(), preparedMessage.CastToWritable()); - } - bool IsReliableTransmissionAllowed() const override { return mRetainMessageOnSend; } bool MessagePermitted(uint16_t protocol, uint8_t type) override { return true; } @@ -367,6 +350,9 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) CHIP_ERROR err = CHIP_NO_ERROR; MockSessionEstablishmentDelegate mockSender; + err = mockSender.mMessageDispatch.Init(&ctx.GetSecureSessionManager()); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + ExchangeContext * exchange = ctx.NewExchangeToAlice(&mockSender); NL_TEST_ASSERT(inSuite, exchange != nullptr); @@ -380,9 +366,6 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) 1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL }); - err = mockSender.mMessageDispatch.Init(); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - mockSender.mMessageDispatch.mRetainMessageOnSend = false; // Let's drop the initial message @@ -414,14 +397,20 @@ void CheckUnencryptedMessageReceiveFailure(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, !buffer.IsNull()); MockSessionEstablishmentDelegate mockReceiver; - CHIP_ERROR err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver); + CHIP_ERROR err = mockReceiver.mMessageDispatch.Init(&ctx.GetSecureSessionManager()); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // Expect the received messages to be encrypted mockReceiver.mMessageDispatch.mRequireEncryption = true; MockSessionEstablishmentDelegate mockSender; - ExchangeContext * exchange = ctx.NewExchangeToAlice(&mockSender); + err = mockSender.mMessageDispatch.Init(&ctx.GetSecureSessionManager()); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + ExchangeContext * exchange = ctx.NewUnauthenticatedExchangeToAlice(&mockSender); NL_TEST_ASSERT(inSuite, exchange != nullptr); ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); @@ -429,9 +418,6 @@ void CheckUnencryptedMessageReceiveFailure(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, rm != nullptr); NL_TEST_ASSERT(inSuite, rc != nullptr); - err = mockSender.mMessageDispatch.Init(); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - gLoopback.mSentMessageCount = 0; gLoopback.mNumMessagesToDrop = 0; gLoopback.mDroppedMessageCount = 0; @@ -584,23 +570,23 @@ void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuit CHIP_ERROR err = ctx.Init(inSuite, &gTransportMgr, &gIOContext); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.SetBobNodeId(kPlaceholderNodeId); - ctx.SetAliceNodeId(kPlaceholderNodeId); - ctx.SetBobKeyId(0); - ctx.SetAliceKeyId(0); - ctx.SetFabricIndex(kUndefinedFabricIndex); - chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); NL_TEST_ASSERT(inSuite, !buffer.IsNull()); MockSessionEstablishmentDelegate mockReceiver; + err = mockReceiver.mMessageDispatch.Init(&ctx.GetSecureSessionManager()); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Echo::MsgType::EchoRequest, &mockReceiver); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); mockReceiver.mTestSuite = inSuite; MockSessionEstablishmentDelegate mockSender; - ExchangeContext * exchange = ctx.NewExchangeToAlice(&mockSender); + err = mockSender.mMessageDispatch.Init(&ctx.GetSecureSessionManager()); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + ExchangeContext * exchange = ctx.NewUnauthenticatedExchangeToAlice(&mockSender); NL_TEST_ASSERT(inSuite, exchange != nullptr); ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); @@ -613,9 +599,6 @@ void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuit 1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL }); - err = mockSender.mMessageDispatch.Init(); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - // Let's drop the initial message gLoopback.mSentMessageCount = 0; gLoopback.mNumMessagesToDrop = 1; diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 382bdb4ce25450..6a926939708897 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -46,7 +46,7 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager Cleanup(); - ReturnErrorOnFailure(GetSession().MessageDispatch().Init(transportMgr)); + ReturnErrorOnFailure(GetSession().MessageDispatch().Init(sessionMgr)); return CHIP_NO_ERROR; } diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 65c3026a402b7a..29f27d0ccf8709 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -1221,8 +1221,6 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea CHIP_ERROR err = ValidateReceivedMessage(ec, payloadHeader, msg); SuccessOrExit(err); - SetPeerAddress(mMessageDispatch.GetPeerAddress()); - switch (static_cast(payloadHeader.GetMessageType())) { case Protocols::SecureChannel::MsgType::CASE_SigmaR1: diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 6ff77062ffc82e..f47ba8ce5a8e52 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -882,8 +882,6 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl CHIP_ERROR err = ValidateReceivedMessage(exchange, payloadHeader, std::move(msg)); SuccessOrExit(err); - SetPeerAddress(mMessageDispatch.GetPeerAddress()); - switch (static_cast(payloadHeader.GetMessageType())) { case MsgType::PBKDFParamRequest: diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index ea6beca4c68135..983b75aac1f934 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -32,19 +32,13 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::PrepareMessage(SessionHandle se System::PacketBufferHandle && message, EncryptedPacketBufferHandle & preparedMessage) { - PacketHeader packetHeader; - ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); - ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); - - preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(message)); - return CHIP_NO_ERROR; + return mSessionMgr->PrepareMessage(session, payloadHeader, std::move(message), preparedMessage); } CHIP_ERROR SessionEstablishmentExchangeDispatch::SendPreparedMessage(SessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) const { - ReturnErrorCodeIf(mTransportMgr == nullptr, CHIP_ERROR_INCORRECT_STATE); - return mTransportMgr->SendMessage(mPeerAddress, preparedMessage.CastToWritable()); + return mSessionMgr->SendPreparedMessage(session, preparedMessage); } CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(uint32_t messageCounter, const PayloadHeader & payloadHeader, @@ -52,7 +46,6 @@ CHIP_ERROR SessionEstablishmentExchangeDispatch::OnMessageReceived(uint32_t mess Messaging::MessageFlags msgFlags, ReliableMessageContext * reliableMessageContext) { - mPeerAddress = peerAddress; return ExchangeMessageDispatch::OnMessageReceived(messageCounter, payloadHeader, peerAddress, msgFlags, reliableMessageContext); } diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index 7c3ab655218cbd..d58db5ac08aa5d 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -36,10 +36,10 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi virtual ~SessionEstablishmentExchangeDispatch() {} - CHIP_ERROR Init(TransportMgrBase * transportMgr) + CHIP_ERROR Init(SecureSessionMgr * sessionMgr) { - ReturnErrorCodeIf(transportMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT); - mTransportMgr = transportMgr; + ReturnErrorCodeIf(sessionMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT); + mSessionMgr = sessionMgr; return ExchangeMessageDispatch::Init(); } @@ -51,24 +51,13 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi const Transport::PeerAddress & peerAddress, Messaging::MessageFlags msgFlags, Messaging::ReliableMessageContext * reliableMessageContext) override; - const Transport::PeerAddress & GetPeerAddress() const { return mPeerAddress; } - - void SetPeerAddress(const Transport::PeerAddress & address) { mPeerAddress = address; } - protected: bool MessagePermitted(uint16_t protocol, uint8_t type) override; - bool IsReliableTransmissionAllowed() const override - { - // If the underlying transport is UDP. - return (mPeerAddress.GetTransportType() == Transport::Type::kUdp); - } - bool IsEncryptionRequired() const override { return false; } private: - TransportMgrBase * mTransportMgr = nullptr; - Transport::PeerAddress mPeerAddress; + SecureSessionMgr * mSessionMgr = nullptr; }; } // namespace chip diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 2868e7373ded63..67f39cd35e2b47 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -176,8 +176,8 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) FabricInfo * fabric = gCommissionerFabrics.FindFabricWithIndex(gCommissionerFabricIndex); NL_TEST_ASSERT(inSuite, fabric != nullptr); - NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - ExchangeContext * context = ctx.NewExchangeToBob(&pairing); + NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); + ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), nullptr, Node01_01, 0, nullptr, @@ -191,14 +191,19 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); + // Clear pending packet in CRMP + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + ReliableMessageContext * rc = context->GetReliableMessageContext(); + rm->ClearRetransTable(rc); + gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; CASESession pairing1; - NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); gLoopback.mSentMessageCount = 0; gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; - ExchangeContext * context1 = ctx.NewExchangeToBob(&pairing1); + ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, pairing1.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, context1, @@ -218,14 +223,14 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte CASESessionSerializable serializableAccessory; gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( Protocols::SecureChannel::MsgType::CASE_SigmaR1, &pairingAccessory) == CHIP_NO_ERROR); - ExchangeContext * contextCommissioner = ctx.NewExchangeToBob(&pairingCommissioner); + ExchangeContext * contextCommissioner = ctx.NewUnauthenticatedExchangeToBob(&pairingCommissioner); FabricInfo * fabric = gCommissionerFabrics.FindFabricWithIndex(gCommissionerFabricIndex); NL_TEST_ASSERT(inSuite, fabric != nullptr); @@ -236,7 +241,7 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte pairingCommissioner.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 3); + NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 4); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 1); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); @@ -343,8 +348,8 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte TestContext & ctx = *reinterpret_cast(inContext); gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, pairingCommissioner->MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, gPairingServer.GetSession().MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingCommissioner->MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, gPairingServer.GetSession().MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); SessionIDAllocator idAllocator; @@ -353,7 +358,7 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte &ctx.GetSecureSessionManager(), &gDeviceFabrics, &idAllocator) == CHIP_NO_ERROR); - ExchangeContext * contextCommissioner = ctx.NewExchangeToBob(pairingCommissioner); + ExchangeContext * contextCommissioner = ctx.NewUnauthenticatedExchangeToBob(pairingCommissioner); FabricInfo * fabric = gCommissionerFabrics.FindFabricWithIndex(gCommissionerFabricIndex); NL_TEST_ASSERT(inSuite, fabric != nullptr); @@ -362,12 +367,12 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte pairingCommissioner->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 3); + NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 4); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); auto * pairingCommissioner1 = chip::Platform::New(); - NL_TEST_ASSERT(inSuite, pairingCommissioner1->MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - ExchangeContext * contextCommissioner1 = ctx.NewExchangeToBob(pairingCommissioner1); + NL_TEST_ASSERT(inSuite, pairingCommissioner1->MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); + ExchangeContext * contextCommissioner1 = ctx.NewUnauthenticatedExchangeToBob(pairingCommissioner1); NL_TEST_ASSERT(inSuite, pairingCommissioner1->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), fabric, Node01_01, 0, diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 417b5bd3a212a3..9341875df472aa 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -122,8 +122,8 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.Reset(); - NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - ExchangeContext * context = ctx.NewExchangeToBob(&pairing); + NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); + ExchangeContext * context = ctx.NewUnauthenticatedExchangeToBob(&pairing); NL_TEST_ASSERT(inSuite, pairing.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, nullptr, nullptr) != CHIP_NO_ERROR); @@ -134,13 +134,18 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); + // Clear pending packet in CRMP + ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); + ReliableMessageContext * rc = context->GetReliableMessageContext(); + rm->ClearRetransTable(rc); + gLoopback.Reset(); gLoopback.mSentMessageCount = 0; gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; PASESession pairing1; - NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - ExchangeContext * context1 = ctx.NewExchangeToBob(&pairing1); + NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); + ExchangeContext * context1 = ctx.NewUnauthenticatedExchangeToBob(&pairing1); NL_TEST_ASSERT(inSuite, pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context1, &delegate) == CHIP_ERROR_BAD_REQUEST); @@ -157,16 +162,13 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); - ExchangeContext * contextCommissioner = ctx.NewExchangeToBob(&pairingCommissioner); + ExchangeContext * contextCommissioner = ctx.NewUnauthenticatedExchangeToBob(&pairingCommissioner); if (gLoopback.mNumMessagesToDrop != 0) { - pairingCommissioner.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp)); - pairingAccessory.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp)); - ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); ReliableMessageContext * rc = contextCommissioner->GetReliableMessageContext(); NL_TEST_ASSERT(inSuite, rm != nullptr); @@ -221,6 +223,8 @@ void SecurePairingHandshakeWithPacketLossTest(nlTestSuite * inSuite, void * inCo void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) { + TestContext & ctx = *reinterpret_cast(inContext); + TestSecurePairingDelegate delegateCommissioner; PASESession pairingCommissioner; @@ -230,14 +234,10 @@ void SecurePairingFailedHandshake(nlTestSuite * inSuite, void * inContext) gLoopback.Reset(); gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); - - TestContext & ctx = *reinterpret_cast(inContext); - ExchangeContext * contextCommissioner = ctx.NewExchangeToBob(&pairingCommissioner); + NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); - pairingCommissioner.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp)); - pairingAccessory.MessageDispatch().SetPeerAddress(PeerAddress(Type::kUdp)); + ExchangeContext * contextCommissioner = ctx.NewUnauthenticatedExchangeToBob(&pairingCommissioner); ReliableMessageMgr * rm = ctx.GetExchangeManager().GetReliableMessageMgr(); ReliableMessageContext * rc = contextCommissioner->GetReliableMessageContext(); diff --git a/src/transport/PeerMessageCounter.h b/src/transport/PeerMessageCounter.h index a69d1bf93088e5..0d0e9e740374db 100644 --- a/src/transport/PeerMessageCounter.h +++ b/src/transport/PeerMessageCounter.h @@ -97,7 +97,7 @@ class PeerMessageCounter uint32_t offset = mSynced.mMaxCounter - counter; if (offset >= CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE) { - return CHIP_ERROR_INVALID_ARGUMENT; // outside valid range + return CHIP_ERROR_MESSAGE_ID_OUT_OF_WINDOW; // outside valid range } if (mSynced.mWindow.test(offset)) { @@ -108,6 +108,33 @@ class PeerMessageCounter return CHIP_NO_ERROR; } + CHIP_ERROR VerifyOrTrustFirst(uint32_t counter) + { + switch (mStatus) + { + case Status::NotSynced: + // Trust and set the counter when not synced + SetCounter(counter); + return CHIP_NO_ERROR; + case Status::Synced: { + CHIP_ERROR err = Verify(counter); + if (err == CHIP_ERROR_MESSAGE_ID_OUT_OF_WINDOW) + { + // According to chip spec, when global unencrypted message + // counter is out of window, the peer may have reset and is + // using another randomize initial value. Trust the new + // counter here. + SetCounter(counter); + err = CHIP_NO_ERROR; + } + return err; + } + default: + VerifyOrDie(false); + return CHIP_ERROR_INTERNAL; + } + } + /** * @brief * With the counter verified and the packet MIC also verified by the secure key, we can trust the packet and adjust diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 093ce844e19ef4..54ef62b6d4ba9f 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -102,9 +102,8 @@ void SecureSessionMgr::Shutdown() mCB = nullptr; } -CHIP_ERROR SecureSessionMgr::BuildEncryptedMessagePayload(SessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && msgBuf, - EncryptedPacketBufferHandle & encryptedMessage) +CHIP_ERROR SecureSessionMgr::PrepareMessage(SessionHandle session, PayloadHeader & payloadHeader, + System::PacketBufferHandle && message, EncryptedPacketBufferHandle & preparedMessage) { PacketHeader packetHeader; if (IsControlMessage(payloadHeader)) @@ -112,71 +111,97 @@ CHIP_ERROR SecureSessionMgr::BuildEncryptedMessagePayload(SessionHandle session, packetHeader.SetSecureSessionControlMsg(true); } - PeerConnectionState * state = GetPeerConnectionState(session); - if (state == nullptr) + if (session.IsSecure()) { - return CHIP_ERROR_NOT_CONNECTED; + PeerConnectionState * state = GetPeerConnectionState(session); + if (state == nullptr) + { + return CHIP_ERROR_NOT_CONNECTED; + } + + MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *state); + ReturnErrorOnFailure(SecureMessageCodec::Encode(state, payloadHeader, packetHeader, message, counter)); + + ChipLogProgress(Inet, + "Build %s message %p to 0x" ChipLogFormatX64 " of type %d and protocolId %" PRIu32 + " on exchange %d with MessageId %" PRIu32 ".", + "encrypted", &preparedMessage, ChipLogValueX64(state->GetPeerNodeId()), payloadHeader.GetMessageType(), + payloadHeader.GetProtocolID().ToFullyQualifiedSpecForm(), payloadHeader.GetExchangeID(), + packetHeader.GetMessageId()); } + else + { + ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(message)); - MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *state); - ReturnErrorOnFailure(SecureMessageCodec::Encode(state, payloadHeader, packetHeader, msgBuf, counter)); + MessageCounter & counter = session.GetUnauthenticatedSession()->GetLocalMessageCounter(); + uint32_t messageId = counter.Value(); + ReturnErrorOnFailure(counter.Advance()); - ReturnErrorOnFailure(packetHeader.EncodeBeforeData(msgBuf)); + packetHeader.SetMessageId(messageId); - encryptedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(msgBuf)); - ChipLogProgress(Inet, "Encrypted message %p to 0x" ChipLogFormatX64 " of type %d and protocolId %" PRIu32 " on exchange %d.", - &encryptedMessage, ChipLogValueX64(state->GetPeerNodeId()), payloadHeader.GetMessageType(), - payloadHeader.GetProtocolID().ToFullyQualifiedSpecForm(), payloadHeader.GetExchangeID()); + ChipLogProgress(Inet, + "Build %s message %p to 0x" ChipLogFormatX64 " of type %d and protocolId %" PRIu32 + " on exchange %d with MessageId %" PRIu32 ".", + "plaintext", &preparedMessage, ChipLogValueX64(kUndefinedNodeId), payloadHeader.GetMessageType(), + payloadHeader.GetProtocolID().ToFullyQualifiedSpecForm(), payloadHeader.GetExchangeID(), + packetHeader.GetMessageId()); + } + + ReturnErrorOnFailure(packetHeader.EncodeBeforeData(message)); + preparedMessage = EncryptedPacketBufferHandle::MarkEncrypted(std::move(message)); return CHIP_NO_ERROR; } CHIP_ERROR SecureSessionMgr::SendPreparedMessage(SessionHandle session, const EncryptedPacketBufferHandle & preparedMessage) { - CHIP_ERROR err = CHIP_NO_ERROR; - PeerConnectionState * state = nullptr; - PacketBufferHandle msgBuf; + VerifyOrReturnError(mState == State::kInitialized, CHIP_ERROR_INCORRECT_STATE); + VerifyOrReturnError(!preparedMessage.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrExit(mState == State::kInitialized, err = CHIP_ERROR_INCORRECT_STATE); - VerifyOrExit(!preparedMessage.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT); - msgBuf = preparedMessage.CastToWritable(); - VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrExit(!msgBuf->HasChainedBuffer(), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + const Transport::PeerAddress * destination; - // Find an active connection to the specified peer node - state = GetPeerConnectionState(session); - VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); + if (session.IsSecure()) + { + // Find an active connection to the specified peer node + PeerConnectionState * state = GetPeerConnectionState(session); + if (state == nullptr) + { + ChipLogError(Inet, "Secure transport could not find a valid PeerConnection"); + return CHIP_ERROR_NOT_CONNECTED; + } - // This marks any connection where we send data to as 'active' - mPeerConnections.MarkConnectionActive(state); + // This marks any connection where we send data to as 'active' + mPeerConnections.MarkConnectionActive(state); - ChipLogProgress(Inet, "Sending msg %p to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", &preparedMessage, - ChipLogValueX64(state->GetPeerNodeId()), System::Clock::GetMonotonicMilliseconds()); + destination = &state->GetPeerAddress(); - if (mTransportMgr != nullptr) - { - ChipLogProgress(Inet, "Sending secure msg on generic transport"); - err = mTransportMgr->SendMessage(state->GetPeerAddress(), std::move(msgBuf)); + ChipLogProgress(Inet, "Sending %s msg %p to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", "encrypted", + &preparedMessage, ChipLogValueX64(state->GetPeerNodeId()), System::Clock::GetMonotonicMilliseconds()); } else { - ChipLogError(Inet, "The transport manager is not initialized. Unable to send the message"); - err = CHIP_ERROR_INCORRECT_STATE; + auto unauthenticated = session.GetUnauthenticatedSession(); + mUnauthenticatedSessions.MarkSessionActive(unauthenticated.Get()); + destination = &unauthenticated->GetPeerAddress(); + + ChipLogProgress(Inet, "Sending %s msg %p to 0x" ChipLogFormatX64 " at utc time: %" PRId64 " msec", "plaintext", + &preparedMessage, ChipLogValueX64(kUndefinedNodeId), System::Clock::GetMonotonicMilliseconds()); } - ChipLogProgress(Inet, "Secure msg send status %s", ErrorStr(err)); - SuccessOrExit(err); -exit: - if (!msgBuf.IsNull()) + PacketBufferHandle msgBuf = preparedMessage.CastToWritable(); + VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(!msgBuf->HasChainedBuffer(), CHIP_ERROR_INVALID_MESSAGE_LENGTH); + + if (mTransportMgr != nullptr) { - const char * errStr = ErrorStr(err); - if (state == nullptr) - { - ChipLogError(Inet, "Secure transport could not find a valid PeerConnection: %s", errStr); - } + ChipLogProgress(Inet, "Sending msg on generic transport"); + return mTransportMgr->SendMessage(*destination, std::move(msgBuf)); + } + else + { + ChipLogError(Inet, "The transport manager is not initialized. Unable to send the message"); + return CHIP_ERROR_INCORRECT_STATE; } - - return err; } void SecureSessionMgr::ExpirePairing(SessionHandle session) @@ -304,12 +329,36 @@ void SecureSessionMgr::OnMessageReceived(const PeerAddress & peerAddress, System void SecureSessionMgr::MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle && msg) { + Transport::UnauthenticatedSession * session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); + if (session == nullptr) + { + ChipLogError(Inet, "UnauthenticatedSession exhausted"); + return; + } + + SecureSessionMgrDelegate::DuplicateMessage isDuplicate = SecureSessionMgrDelegate::DuplicateMessage::No; + + // Verify message counter + CHIP_ERROR err = session->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageId()); + if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED) + { + ChipLogDetail(Inet, "Received a duplicate message with MessageId: %" PRIu32, packetHeader.GetMessageId()); + isDuplicate = SecureSessionMgrDelegate::DuplicateMessage::Yes; + err = CHIP_NO_ERROR; + } + VerifyOrDie(err == CHIP_NO_ERROR); + + mUnauthenticatedSessions.MarkSessionActive(*session); + + PayloadHeader payloadHeader; + ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); + + session->GetPeerMessageCounter().Commit(packetHeader.GetMessageId()); + if (mCB != nullptr) { - PayloadHeader payloadHeader; - ReturnOnFailure(payloadHeader.DecodeAndConsume(msg)); - mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle::TemporaryUnauthenticatedSession(), peerAddress, - SecureSessionMgrDelegate::DuplicateMessage::No, std::move(msg)); + mCB->OnMessageReceived(packetHeader, payloadHeader, SessionHandle(Transport::UnauthenticatedSessionHandle(*session)), + peerAddress, isDuplicate, std::move(msg)); } } @@ -365,7 +414,7 @@ void SecureSessionMgr::SecureMessageDispatch(const PacketHeader & packetHeader, err = state->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageId()); if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED) { - ChipLogDetail(Inet, "Received a duplicate message"); + ChipLogDetail(Inet, "Received a duplicate message with MessageId: %" PRIu32, packetHeader.GetMessageId()); isDuplicate = SecureSessionMgrDelegate::DuplicateMessage::Yes; err = CHIP_NO_ERROR; } diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index 740bae8fe1995a..f5d5ca82d7e7c9 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -40,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -191,8 +192,8 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate * 3. Encode the packet header and prepend it to message. * Returns a encrypted message in encryptedMessage. */ - CHIP_ERROR BuildEncryptedMessagePayload(SessionHandle session, PayloadHeader & payloadHeader, - System::PacketBufferHandle && msgBuf, EncryptedPacketBufferHandle & encryptedMessage); + CHIP_ERROR PrepareMessage(SessionHandle session, PayloadHeader & payloadHeader, System::PacketBufferHandle && msgBuf, + EncryptedPacketBufferHandle & encryptedMessage); /** * @brief @@ -263,6 +264,15 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate */ void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle && msgBuf) override; + Optional CreateUnauthenticatedSession(const Transport::PeerAddress & peerAddress) + { + Transport::UnauthenticatedSession * session = mUnauthenticatedSessions.FindOrAllocateEntry(peerAddress); + if (session == nullptr) + return Optional::Missing(); + + return Optional::Value(SessionHandle(Transport::UnauthenticatedSessionHandle(*session))); + } + private: /** * The State of a secure transport object. @@ -280,6 +290,7 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate }; System::Layer * mSystemLayer = nullptr; + Transport::UnauthenticatedSessionTable mUnauthenticatedSessions; Transport::PeerConnections mPeerConnections; // < Active connections to other peers State mState; // < Initialization state of the object diff --git a/src/transport/SessionHandle.h b/src/transport/SessionHandle.h index fe9f0cadf26a5e..f701702def5c4e 100644 --- a/src/transport/SessionHandle.h +++ b/src/transport/SessionHandle.h @@ -17,6 +17,8 @@ #pragma once +#include + namespace chip { class SecureSessionMgr; @@ -26,6 +28,10 @@ class SessionHandle public: SessionHandle(NodeId peerNodeId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric) {} + SessionHandle(Transport::UnauthenticatedSessionHandle session) : + mPeerNodeId(kPlaceholderNodeId), mFabric(Transport::kUndefinedFabricIndex), mUnauthenticatedSessionHandle(session) + {} + SessionHandle(NodeId peerNodeId, uint16_t localKeyId, uint16_t peerKeyId, FabricIndex fabric) : mPeerNodeId(peerNodeId), mFabric(fabric) { @@ -33,6 +39,8 @@ class SessionHandle mPeerKeyId.SetValue(peerKeyId); } + bool IsSecure() const { return !mUnauthenticatedSessionHandle.HasValue(); } + bool HasFabricIndex() const { return (mFabric != Transport::kUndefinedFabricIndex); } FabricIndex GetFabricIndex() const { return mFabric; } void SetFabricIndex(FabricIndex fabricId) { mFabric = fabricId; } @@ -45,16 +53,13 @@ class SessionHandle bool MatchIncomingSession(const SessionHandle & that) const { - - if (that.GetLocalKeyId().HasValue()) + if (IsSecure()) { - return mLocalKeyId == that.mLocalKeyId; + return that.IsSecure() && mLocalKeyId.Value() == that.mLocalKeyId.Value(); } else { - // TODO: For unencrypted session, temporarily still rely on the old match logic in MatchExchange, need to update to - // match peer’s HW address (BLE) or peer’s IP/Port (for IP). - return true; + return !that.IsSecure() && mUnauthenticatedSessionHandle.Value() == that.mUnauthenticatedSessionHandle.Value(); } } @@ -62,14 +67,12 @@ class SessionHandle const Optional & GetPeerKeyId() const { return mPeerKeyId; } const Optional & GetLocalKeyId() const { return mLocalKeyId; } - // TODO: currently SessionHandle is not able to identify a unauthenticated session, create an empty handle for it - static SessionHandle TemporaryUnauthenticatedSession() - { - return SessionHandle(kPlaceholderNodeId, Transport::kUndefinedFabricIndex); - } + Transport::UnauthenticatedSessionHandle GetUnauthenticatedSession() { return mUnauthenticatedSessionHandle.Value(); } private: friend class SecureSessionMgr; + + // Fields for secure session NodeId mPeerNodeId; Optional mLocalKeyId; Optional mPeerKeyId; @@ -78,6 +81,9 @@ class SessionHandle // to identify an approach that'll allow looking up the corresponding information for // such sessions. FabricIndex mFabric; + + // Fields for unauthenticated session + Optional mUnauthenticatedSessionHandle; }; } // namespace chip diff --git a/src/transport/UnauthenticatedSessionTable.h b/src/transport/UnauthenticatedSessionTable.h new file mode 100644 index 00000000000000..48d721fb9ff096 --- /dev/null +++ b/src/transport/UnauthenticatedSessionTable.h @@ -0,0 +1,200 @@ +/* + * + * Copyright (c) 2020 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace chip { +namespace Transport { + +class UnauthenticatedSession; +using UnauthenticatedSessionHandle = ReferenceCountedHandle; + +class UnauthenticatedSessionDeleter +{ +public: + // This is a no-op because life-cycle of UnauthenticatedSessionTable is rotated by LRU + static void Release(UnauthenticatedSession * entry) {} +}; + +/** + * @brief + * An UnauthenticatedSession stores the binding of TransportAddress, and message counters. + */ +class UnauthenticatedSession : public ReferenceCounted +{ +public: + UnauthenticatedSession(const PeerAddress & address) : mPeerAddress(address) {} + + UnauthenticatedSession(const UnauthenticatedSession &) = delete; + UnauthenticatedSession & operator=(const UnauthenticatedSession &) = delete; + UnauthenticatedSession(UnauthenticatedSession &&) = delete; + UnauthenticatedSession & operator=(UnauthenticatedSession &&) = delete; + + uint64_t GetLastActivityTimeMs() const { return mLastActivityTimeMs; } + void SetLastActivityTimeMs(uint64_t value) { mLastActivityTimeMs = value; } + + const PeerAddress & GetPeerAddress() const { return mPeerAddress; } + + MessageCounter & GetLocalMessageCounter() { return mLocalMessageCounter; } + PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; } + +private: + uint64_t mLastActivityTimeMs = 0; + + const PeerAddress mPeerAddress; + GlobalUnencryptedMessageCounter mLocalMessageCounter; + PeerMessageCounter mPeerMessageCounter; +}; + +/* + * @brief + * An table which manages UnauthenticatedSessions + * + * The UnauthenticatedSession entries are rotated using LRU, but entry can be + * hold by using UnauthenticatedSessionHandle, which increase the reference + * count by 1. If the reference count is not 0, the entry won't be pruned. + */ +template +class UnauthenticatedSessionTable +{ +public: + /** + * Allocates a new session out of the internal resource pool. + * + * @returns CHIP_NO_ERROR if new session created. May fail if maximum connection count has been reached (with + * CHIP_ERROR_NO_MEMORY). + */ + CHECK_RETURN_VALUE + CHIP_ERROR AllocEntry(const PeerAddress & address, UnauthenticatedSession *& entry) + { + entry = mEntries.CreateObject(address); + if (entry != nullptr) + return CHIP_NO_ERROR; + + entry = FindLeastRecentUsedEntry(); + if (entry == nullptr) + { + return CHIP_ERROR_NO_MEMORY; + } + + mEntries.ResetObject(entry, address); + return CHIP_NO_ERROR; + } + + /** + * Get a session using given address + * + * @return the peer found, nullptr if not found + */ + CHECK_RETURN_VALUE + UnauthenticatedSession * FindEntry(const PeerAddress & address) + { + UnauthenticatedSession * result = nullptr; + mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) { + if (MatchPeerAddress(entry->GetPeerAddress(), address)) + { + result = entry; + return false; + } + return true; + }); + return result; + } + + /** + * Get a peer given the peer id. If the peer doesn't exist in the cache, allocate a new entry for it. + * + * @return the peer found or allocated, nullptr if not found and allocate failed. + */ + CHECK_RETURN_VALUE + UnauthenticatedSession * FindOrAllocateEntry(const PeerAddress & address) + { + UnauthenticatedSession * result = FindEntry(address); + if (result != nullptr) + return result; + + CHIP_ERROR err = AllocEntry(address, result); + if (err == CHIP_NO_ERROR) + { + return result; + } + else + { + return nullptr; + } + } + + /// Mark a session as active + void MarkSessionActive(UnauthenticatedSession & entry) { entry.SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); } + + /// Allows access to the underlying time source used for keeping track of connection active time + Time::TimeSource & GetTimeSource() { return mTimeSource; } + +private: + UnauthenticatedSession * FindLeastRecentUsedEntry() + { + UnauthenticatedSession * result = nullptr; + uint64_t oldestTimeMs = std::numeric_limits::max(); + + mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) { + if (entry->GetReferenceCount() == 0 && entry->GetLastActivityTimeMs() < oldestTimeMs) + { + result = entry; + oldestTimeMs = entry->GetLastActivityTimeMs(); + } + return true; + }); + + return result; + } + + static bool MatchPeerAddress(const PeerAddress & a1, const PeerAddress & a2) + { + if (a1.GetTransportType() != a2.GetTransportType()) + return false; + + switch (a1.GetTransportType()) + { + case Transport::Type::kUndefined: + return false; + case Transport::Type::kUdp: + case Transport::Type::kTcp: + return a1.GetIPAddress() == a2.GetIPAddress() && a1.GetPort() == a2.GetPort() && + // Enforce interface equal-ness if the address is link-local, otherwise ignore interface + (a1.GetIPAddress().IsIPv6LinkLocal() ? a1.GetInterface() == a2.GetInterface() : true); + case Transport::Type::kBle: + // TODO: complete BLE address comparation + return true; + } + + return false; + } + + Time::TimeSource mTimeSource; + BitMapObjectPool mEntries; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index 1f40642882e971..69e745d562faf5 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -180,7 +180,7 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest); EncryptedPacketBufferHandle preparedMessage; - err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); @@ -194,8 +194,7 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) callback.LargeMessageSent = true; - err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(large_buffer), - preparedMessage); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(large_buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); @@ -211,8 +210,7 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) callback.LargeMessageSent = true; - err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(extra_large_buffer), - preparedMessage); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(extra_large_buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_MESSAGE_TOO_LONG); } @@ -272,7 +270,7 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetInitiator(true); - err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage); @@ -346,7 +344,7 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) payloadHeader.SetInitiator(true); - err = secureSessionMgr.BuildEncryptedMessagePayload(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); + err = secureSessionMgr.PrepareMessage(localToRemoteSession, payloadHeader, std::move(buffer), preparedMessage); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); err = secureSessionMgr.SendPreparedMessage(localToRemoteSession, preparedMessage);