diff --git a/src/app/server/RendezvousServer.cpp b/src/app/server/RendezvousServer.cpp index bedd9463b37fb7..a1adb9f2a1f9a4 100644 --- a/src/app/server/RendezvousServer.cpp +++ b/src/app/server/RendezvousServer.cpp @@ -111,7 +111,7 @@ CHIP_ERROR RendezvousServer::WaitForPairing(const RendezvousParameters & params, strlen(kSpake2pKeyExchangeSalt), mNextKeyId++, this)); } - ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mExchangeManager->GetReliableMessageMgr(), transportMgr)); + ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr)); mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress()); return CHIP_NO_ERROR; diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index 6398cd3c1f729d..78b92d93ba25ad 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -493,8 +493,7 @@ CHIP_ERROR Device::EstablishCASESession() Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(SecureSessionHandle(), &mCASESession); VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL); - ReturnErrorOnFailure( - mCASESession.MessageDispatch().Init(mExchangeMgr->GetReliableMessageMgr(), mSessionManager->GetTransportManager())); + ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager())); mCASESession.MessageDispatch().SetPeerAddress(mDeviceAddress); ReturnErrorOnFailure(mCASESession.EstablishSession(mDeviceAddress, mCredentials, mDeviceId, 0, exchange, this)); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 7666c16e08960c..64d092de375be2 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -851,7 +851,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle); - err = mPairingSession.MessageDispatch().Init(mExchangeMgr->GetReliableMessageMgr(), mTransportMgr); + err = mPairingSession.MessageDispatch().Init(mTransportMgr); SuccessOrExit(err); mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress()); diff --git a/src/messaging/ApplicationExchangeDispatch.h b/src/messaging/ApplicationExchangeDispatch.h index 4958cf11735ba8..ed330ac7e623ab 100644 --- a/src/messaging/ApplicationExchangeDispatch.h +++ b/src/messaging/ApplicationExchangeDispatch.h @@ -38,11 +38,11 @@ class ApplicationExchangeDispatch : public ExchangeMessageDispatch virtual ~ApplicationExchangeDispatch() {} - CHIP_ERROR Init(ReliableMessageMgr * reliableMessageMgr, SecureSessionMgr * sessionMgr) + CHIP_ERROR Init(SecureSessionMgr * sessionMgr) { ReturnErrorCodeIf(sessionMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT); mSessionMgr = sessionMgr; - return ExchangeMessageDispatch::Init(reliableMessageMgr); + return ExchangeMessageDispatch::Init(); } CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, diff --git a/src/messaging/ExchangeMessageDispatch.cpp b/src/messaging/ExchangeMessageDispatch.cpp index cced848bc31d13..22118e6ab12e58 100644 --- a/src/messaging/ExchangeMessageDispatch.cpp +++ b/src/messaging/ExchangeMessageDispatch.cpp @@ -65,27 +65,29 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uin #endif } - if (IsReliableTransmissionAllowed() && reliableMessageContext->AutoRequestAck() && mReliableMessageMgr != nullptr && - isReliableTransmission) + if (IsReliableTransmissionAllowed() && reliableMessageContext->AutoRequestAck() && + reliableMessageContext->GetReliableMessageMgr() != nullptr && isReliableTransmission) { + auto * reliableMessageMgr = reliableMessageContext->GetReliableMessageMgr(); + payloadHeader.SetNeedsAck(true); ReliableMessageMgr::RetransTableEntry * entry = nullptr; // Add to Table for subsequent sending - ReturnErrorOnFailure(mReliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry)); + ReturnErrorOnFailure(reliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry)); CHIP_ERROR err = SendMessageImpl(session, payloadHeader, std::move(message), &entry->retainedBuf); if (err != CHIP_NO_ERROR) { // Remove from table ChipLogError(ExchangeManager, "Failed to send message with err %s", ::chip::ErrorStr(err)); - mReliableMessageMgr->ClearRetransTable(*entry); + reliableMessageMgr->ClearRetransTable(*entry); ReturnErrorOnFailure(err); } else { - mReliableMessageMgr->StartRetransmision(entry); + reliableMessageMgr->StartRetransmision(entry); } } else diff --git a/src/messaging/ExchangeMessageDispatch.h b/src/messaging/ExchangeMessageDispatch.h index d5cc4111d1d338..0b9719ee85dc42 100644 --- a/src/messaging/ExchangeMessageDispatch.h +++ b/src/messaging/ExchangeMessageDispatch.h @@ -38,11 +38,7 @@ class ExchangeMessageDispatch : public ReferenceCounted ExchangeMessageDispatch() {} virtual ~ExchangeMessageDispatch() {} - CHIP_ERROR Init(ReliableMessageMgr * reliableMessageMgr) - { - mReliableMessageMgr = reliableMessageMgr; - return CHIP_NO_ERROR; - } + CHIP_ERROR Init() { return CHIP_NO_ERROR; } CHIP_ERROR SendMessage(SecureSessionHandle session, uint16_t exchangeId, bool isInitiator, ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol, @@ -70,9 +66,6 @@ class ExchangeMessageDispatch : public ReferenceCounted System::PacketBufferHandle && message, EncryptedPacketBufferHandle * retainedMessage) = 0; virtual bool IsReliableTransmissionAllowed() { return true; } - -private: - ReliableMessageMgr * mReliableMessageMgr = nullptr; }; } // namespace Messaging diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 6b960f49dc8e44..d83c28eb5c2db3 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -86,7 +86,7 @@ CHIP_ERROR ExchangeManager::Init(SecureSessionMgr * sessionMgr) sessionMgr->SetDelegate(this); mReliableMessageMgr.Init(sessionMgr->SystemLayer(), sessionMgr); - ReturnErrorOnFailure(mDefaultExchangeDispatch.Init(&mReliableMessageMgr, mSessionMgr)); + ReturnErrorOnFailure(mDefaultExchangeDispatch.Init(mSessionMgr)); mState = State::kState_Initialized; diff --git a/src/messaging/ReliableMessageContext.h b/src/messaging/ReliableMessageContext.h index f8fe36d454e8df..25e82c98cb95d1 100644 --- a/src/messaging/ReliableMessageContext.h +++ b/src/messaging/ReliableMessageContext.h @@ -191,6 +191,12 @@ class ReliableMessageContext */ void SetOccupied(bool inOccupied); + /** + * Get the reliable message manager that corresponds to this reliable + * message context. + */ + ReliableMessageMgr * GetReliableMessageMgr(); + protected: enum class Flags : uint16_t { @@ -229,7 +235,6 @@ class ReliableMessageContext CHIP_ERROR HandleRcvdAck(uint32_t AckMsgId); CHIP_ERROR HandleNeedsAck(uint32_t MessageId, BitFlags Flags); ExchangeContext * GetExchangeContext(); - ReliableMessageMgr * GetReliableMessageMgr(); private: friend class ReliableMessageMgr; diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 20382b05d71d28..3e9ac2ff27819c 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -394,7 +394,7 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext) 1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL }); - err = mockSender.mMessageDispatch.Init(rm); + err = mockSender.mMessageDispatch.Init(); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); mockSender.mMessageDispatch.mRetainMessageOnSend = false; @@ -615,7 +615,7 @@ void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuit 1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL }); - err = mockSender.mMessageDispatch.Init(rm); + err = mockSender.mMessageDispatch.Init(); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); // Let's drop the initial message diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 0176bb250727e6..25adca9cbf1794 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -41,7 +41,7 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager mAdmins = admins; mExchangeManager = exchangeManager; - ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mExchangeManager->GetReliableMessageMgr(), transportMgr)); + ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr)); ExchangeDelegate * delegate = this; ReturnErrorOnFailure( diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h index a6e9a669727fb1..8ac9c0a2df3705 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.h @@ -36,11 +36,11 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi virtual ~SessionEstablishmentExchangeDispatch() {} - CHIP_ERROR Init(Messaging::ReliableMessageMgr * reliableMessageMgr, TransportMgrBase * transportMgr) + CHIP_ERROR Init(TransportMgrBase * transportMgr) { ReturnErrorCodeIf(transportMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT); mTransportMgr = transportMgr; - return ExchangeMessageDispatch::Init(reliableMessageMgr); + return ExchangeMessageDispatch::Init(); } CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message, diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 642e8e81537dd9..57d5b6f41a4b48 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -120,8 +120,7 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) TestCASESecurePairingDelegate delegate; CASESession pairing; - NL_TEST_ASSERT( - inSuite, pairing.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); ExchangeContext * context = ctx.NewExchangeToLocal(&pairing); NL_TEST_ASSERT(inSuite, @@ -136,9 +135,7 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; CASESession pairing1; - NL_TEST_ASSERT(inSuite, - pairing1.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == - CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); gLoopback.mSentMessageCount = 0; gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; @@ -162,12 +159,8 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte CASESessionSerializable serializableAccessory; gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, - pairingCommissioner.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == - CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, - pairingAccessory.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == - CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType( diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index e4477de2d4aabb..48dfc4abfd5204 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -154,8 +154,7 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.Reset(); - NL_TEST_ASSERT( - inSuite, pairing.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); ExchangeContext * context = ctx.NewExchangeToLocal(&pairing); NL_TEST_ASSERT(inSuite, @@ -172,9 +171,7 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST; PASESession pairing1; - NL_TEST_ASSERT(inSuite, - pairing1.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == - CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); ExchangeContext * context1 = ctx.NewExchangeToLocal(&pairing1); NL_TEST_ASSERT(inSuite, pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context1, &delegate) == @@ -192,12 +189,8 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P gLoopback.mSentMessageCount = 0; - NL_TEST_ASSERT(inSuite, - pairingCommissioner.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == - CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, - pairingAccessory.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == - CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(&pairingCommissioner);