Skip to content

Commit

Permalink
Stop passing around reliable message managers when we don't need to (#…
Browse files Browse the repository at this point in the history
…7338)

* Make GetReliableMessageMgr public on ReliableMessageContext

* Stop passing around reliable message managers when we don't need to
  • Loading branch information
bzbarsky-apple authored and pull[bot] committed Jun 7, 2021
1 parent 855f953 commit 1355152
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/app/server/RendezvousServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ CHIP_ERROR RendezvousServer::WaitForPairing(const RendezvousParameters & params,
strlen(kSpake2pKeyExchangeSalt), mNextKeyId++, this));
}

ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mExchangeManager->GetReliableMessageMgr(), transportMgr));
ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr));
mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress());

return CHIP_NO_ERROR;
Expand Down
3 changes: 1 addition & 2 deletions src/controller/CHIPDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,7 @@ CHIP_ERROR Device::EstablishCASESession()
Messaging::ExchangeContext * exchange = mExchangeMgr->NewContext(SecureSessionHandle(), &mCASESession);
VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INTERNAL);

ReturnErrorOnFailure(
mCASESession.MessageDispatch().Init(mExchangeMgr->GetReliableMessageMgr(), mSessionManager->GetTransportManager()));
ReturnErrorOnFailure(mCASESession.MessageDispatch().Init(mSessionManager->GetTransportManager()));
mCASESession.MessageDispatch().SetPeerAddress(mDeviceAddress);

ReturnErrorOnFailure(mCASESession.EstablishSession(mDeviceAddress, mCredentials, mDeviceId, 0, exchange, this));
Expand Down
2 changes: 1 addition & 1 deletion src/controller/CHIPDeviceController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ CHIP_ERROR DeviceCommissioner::PairDevice(NodeId remoteDeviceId, RendezvousParam

mIsIPRendezvous = (params.GetPeerAddress().GetTransportType() != Transport::Type::kBle);

err = mPairingSession.MessageDispatch().Init(mExchangeMgr->GetReliableMessageMgr(), mTransportMgr);
err = mPairingSession.MessageDispatch().Init(mTransportMgr);
SuccessOrExit(err);
mPairingSession.MessageDispatch().SetPeerAddress(params.GetPeerAddress());

Expand Down
4 changes: 2 additions & 2 deletions src/messaging/ApplicationExchangeDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ class ApplicationExchangeDispatch : public ExchangeMessageDispatch

virtual ~ApplicationExchangeDispatch() {}

CHIP_ERROR Init(ReliableMessageMgr * reliableMessageMgr, SecureSessionMgr * sessionMgr)
CHIP_ERROR Init(SecureSessionMgr * sessionMgr)
{
ReturnErrorCodeIf(sessionMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT);
mSessionMgr = sessionMgr;
return ExchangeMessageDispatch::Init(reliableMessageMgr);
return ExchangeMessageDispatch::Init();
}

CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message,
Expand Down
12 changes: 7 additions & 5 deletions src/messaging/ExchangeMessageDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,29 @@ CHIP_ERROR ExchangeMessageDispatch::SendMessage(SecureSessionHandle session, uin
#endif
}

if (IsReliableTransmissionAllowed() && reliableMessageContext->AutoRequestAck() && mReliableMessageMgr != nullptr &&
isReliableTransmission)
if (IsReliableTransmissionAllowed() && reliableMessageContext->AutoRequestAck() &&
reliableMessageContext->GetReliableMessageMgr() != nullptr && isReliableTransmission)
{
auto * reliableMessageMgr = reliableMessageContext->GetReliableMessageMgr();

payloadHeader.SetNeedsAck(true);

ReliableMessageMgr::RetransTableEntry * entry = nullptr;

// Add to Table for subsequent sending
ReturnErrorOnFailure(mReliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry));
ReturnErrorOnFailure(reliableMessageMgr->AddToRetransTable(reliableMessageContext, &entry));

CHIP_ERROR err = SendMessageImpl(session, payloadHeader, std::move(message), &entry->retainedBuf);
if (err != CHIP_NO_ERROR)
{
// Remove from table
ChipLogError(ExchangeManager, "Failed to send message with err %s", ::chip::ErrorStr(err));
mReliableMessageMgr->ClearRetransTable(*entry);
reliableMessageMgr->ClearRetransTable(*entry);
ReturnErrorOnFailure(err);
}
else
{
mReliableMessageMgr->StartRetransmision(entry);
reliableMessageMgr->StartRetransmision(entry);
}
}
else
Expand Down
9 changes: 1 addition & 8 deletions src/messaging/ExchangeMessageDispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ class ExchangeMessageDispatch : public ReferenceCounted<ExchangeMessageDispatch>
ExchangeMessageDispatch() {}
virtual ~ExchangeMessageDispatch() {}

CHIP_ERROR Init(ReliableMessageMgr * reliableMessageMgr)
{
mReliableMessageMgr = reliableMessageMgr;
return CHIP_NO_ERROR;
}
CHIP_ERROR Init() { return CHIP_NO_ERROR; }

CHIP_ERROR SendMessage(SecureSessionHandle session, uint16_t exchangeId, bool isInitiator,
ReliableMessageContext * reliableMessageContext, bool isReliableTransmission, Protocols::Id protocol,
Expand Down Expand Up @@ -70,9 +66,6 @@ class ExchangeMessageDispatch : public ReferenceCounted<ExchangeMessageDispatch>
System::PacketBufferHandle && message, EncryptedPacketBufferHandle * retainedMessage) = 0;

virtual bool IsReliableTransmissionAllowed() { return true; }

private:
ReliableMessageMgr * mReliableMessageMgr = nullptr;
};

} // namespace Messaging
Expand Down
2 changes: 1 addition & 1 deletion src/messaging/ExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ CHIP_ERROR ExchangeManager::Init(SecureSessionMgr * sessionMgr)
sessionMgr->SetDelegate(this);

mReliableMessageMgr.Init(sessionMgr->SystemLayer(), sessionMgr);
ReturnErrorOnFailure(mDefaultExchangeDispatch.Init(&mReliableMessageMgr, mSessionMgr));
ReturnErrorOnFailure(mDefaultExchangeDispatch.Init(mSessionMgr));

mState = State::kState_Initialized;

Expand Down
7 changes: 6 additions & 1 deletion src/messaging/ReliableMessageContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ class ReliableMessageContext
*/
void SetOccupied(bool inOccupied);

/**
* Get the reliable message manager that corresponds to this reliable
* message context.
*/
ReliableMessageMgr * GetReliableMessageMgr();

protected:
enum class Flags : uint16_t
{
Expand Down Expand Up @@ -229,7 +235,6 @@ class ReliableMessageContext
CHIP_ERROR HandleRcvdAck(uint32_t AckMsgId);
CHIP_ERROR HandleNeedsAck(uint32_t MessageId, BitFlags<MessageFlagValues> Flags);
ExchangeContext * GetExchangeContext();
ReliableMessageMgr * GetReliableMessageMgr();

private:
friend class ReliableMessageMgr;
Expand Down
4 changes: 2 additions & 2 deletions src/messaging/tests/TestReliableMessageProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ void CheckFailedMessageRetainOnSend(nlTestSuite * inSuite, void * inContext)
1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
});

err = mockSender.mMessageDispatch.Init(rm);
err = mockSender.mMessageDispatch.Init();
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

mockSender.mMessageDispatch.mRetainMessageOnSend = false;
Expand Down Expand Up @@ -615,7 +615,7 @@ void CheckResendSessionEstablishmentMessageWithPeerExchange(nlTestSuite * inSuit
1, // CHIP_CONFIG_MRP_DEFAULT_ACTIVE_RETRY_INTERVAL
});

err = mockSender.mMessageDispatch.Init(rm);
err = mockSender.mMessageDispatch.Init();
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);

// Let's drop the initial message
Expand Down
2 changes: 1 addition & 1 deletion src/protocols/secure_channel/CASEServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager
mAdmins = admins;
mExchangeManager = exchangeManager;

ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(mExchangeManager->GetReliableMessageMgr(), transportMgr));
ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr));

ExchangeDelegate * delegate = this;
ReturnErrorOnFailure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ class SessionEstablishmentExchangeDispatch : public Messaging::ExchangeMessageDi

virtual ~SessionEstablishmentExchangeDispatch() {}

CHIP_ERROR Init(Messaging::ReliableMessageMgr * reliableMessageMgr, TransportMgrBase * transportMgr)
CHIP_ERROR Init(TransportMgrBase * transportMgr)
{
ReturnErrorCodeIf(transportMgr == nullptr, CHIP_ERROR_INVALID_ARGUMENT);
mTransportMgr = transportMgr;
return ExchangeMessageDispatch::Init(reliableMessageMgr);
return ExchangeMessageDispatch::Init();
}

CHIP_ERROR ResendMessage(SecureSessionHandle session, EncryptedPacketBufferHandle && message,
Expand Down
15 changes: 4 additions & 11 deletions src/protocols/secure_channel/tests/TestCASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
TestCASESecurePairingDelegate delegate;
CASESession pairing;

NL_TEST_ASSERT(
inSuite, pairing.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
ExchangeContext * context = ctx.NewExchangeToLocal(&pairing);

NL_TEST_ASSERT(inSuite,
Expand All @@ -136,9 +135,7 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST;

CASESession pairing1;
NL_TEST_ASSERT(inSuite,
pairing1.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);

gLoopback.mSentMessageCount = 0;
gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST;
Expand All @@ -162,12 +159,8 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte
CASESessionSerializable serializableAccessory;

gLoopback.mSentMessageCount = 0;
NL_TEST_ASSERT(inSuite,
pairingCommissioner.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite,
pairingAccessory.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);

NL_TEST_ASSERT(inSuite,
ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(
Expand Down
15 changes: 4 additions & 11 deletions src/protocols/secure_channel/tests/TestPASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)

gLoopback.Reset();

NL_TEST_ASSERT(
inSuite, pairing.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairing.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
ExchangeContext * context = ctx.NewExchangeToLocal(&pairing);

NL_TEST_ASSERT(inSuite,
Expand All @@ -172,9 +171,7 @@ void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext)
gLoopback.mMessageSendError = CHIP_ERROR_BAD_REQUEST;

PASESession pairing1;
NL_TEST_ASSERT(inSuite,
pairing1.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairing1.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
ExchangeContext * context1 = ctx.NewExchangeToLocal(&pairing1);
NL_TEST_ASSERT(inSuite,
pairing1.Pair(Transport::PeerAddress(Transport::Type::kBle), 1234, 0, context1, &delegate) ==
Expand All @@ -192,12 +189,8 @@ void SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inContext, P

gLoopback.mSentMessageCount = 0;

NL_TEST_ASSERT(inSuite,
pairingCommissioner.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite,
pairingAccessory.MessageDispatch().Init(ctx.GetExchangeManager().GetReliableMessageMgr(), &gTransportMgr) ==
CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR);

ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(&pairingCommissioner);

Expand Down

0 comments on commit 1355152

Please sign in to comment.