Skip to content

Commit

Permalink
Fixing few issues with BindingManager and issues with how clients use…
Browse files Browse the repository at this point in the history
…d BindingManager (#22133)

* Couple with BindingManager

Fixed:
* Issue where BindingManager::EstablishConnection only called once
* Prevented used after free
* Prevented buffer overrun

* Restyle

* Allocate callbacks in EstablishConnection

This allows multiple connection async establishments to multiple
nodes.

* Address PR comments
  • Loading branch information
tehampson authored Aug 26, 2022
1 parent 3d7cc78 commit 7e00546
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 45 deletions.
10 changes: 8 additions & 2 deletions examples/all-clusters-app/ameba/main/BindingHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,20 @@ void LightSwitchChangedHandler(const EmberBindingTableEntry & binding, Operation
}
}

void LightSwitchContextReleaseHandler(void * context)
{
VerifyOrReturn(context != nullptr, ChipLogError(NotSpecified, "Invalid context for Light switch context release handler"));

Platform::Delete(static_cast<BindingCommandData *>(context));
}

void InitBindingHandlerInternal(intptr_t arg)
{
auto & server = chip::Server::GetInstance();
chip::BindingManager::GetInstance().Init(
{ &server.GetFabricTable(), server.GetCASESessionManager(), &server.GetPersistentStorage() });
chip::BindingManager::GetInstance().RegisterBoundDeviceChangedHandler(LightSwitchChangedHandler);
chip::BindingManager::GetInstance().RegisterBoundDeviceContextReleaseHandler(LightSwitchContextReleaseHandler);
}

#ifdef CONFIG_ENABLE_CHIP_SHELL
Expand Down Expand Up @@ -400,8 +408,6 @@ void SwitchWorkerFunction(intptr_t context)

BindingCommandData * data = reinterpret_cast<BindingCommandData *>(context);
BindingManager::GetInstance().NotifyBoundClusterChanged(data->localEndpointId, data->clusterId, static_cast<void *>(data));

Platform::Delete(data);
}

void BindingWorkerFunction(intptr_t context)
Expand Down
10 changes: 8 additions & 2 deletions examples/light-switch-app/ameba/main/BindingHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,20 @@ void LightSwitchChangedHandler(const EmberBindingTableEntry & binding, Operation
}
}

void LightSwitchContextReleaseHandler(void * context)
{
VerifyOrReturn(context != nullptr, ChipLogError(NotSpecified, "Invalid context for Light switch context release handler"));

Platform::Delete(static_cast<BindingCommandData *>(context));
}

void InitBindingHandlerInternal(intptr_t arg)
{
auto & server = chip::Server::GetInstance();
chip::BindingManager::GetInstance().Init(
{ &server.GetFabricTable(), server.GetCASESessionManager(), &server.GetPersistentStorage() });
chip::BindingManager::GetInstance().RegisterBoundDeviceChangedHandler(LightSwitchChangedHandler);
chip::BindingManager::GetInstance().RegisterBoundDeviceContextReleaseHandler(LightSwitchContextReleaseHandler);
}

#ifdef CONFIG_ENABLE_CHIP_SHELL
Expand Down Expand Up @@ -400,8 +408,6 @@ void SwitchWorkerFunction(intptr_t context)

BindingCommandData * data = reinterpret_cast<BindingCommandData *>(context);
BindingManager::GetInstance().NotifyBoundClusterChanged(data->localEndpointId, data->clusterId, static_cast<void *>(data));

Platform::Delete(data);
}

void BindingWorkerFunction(intptr_t context)
Expand Down
10 changes: 8 additions & 2 deletions examples/light-switch-app/esp32/main/BindingHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,20 @@ void LightSwitchChangedHandler(const EmberBindingTableEntry & binding, Operation
}
}

void LightSwitchContextReleaseHandler(void * context)
{
VerifyOrReturn(context != nullptr, ChipLogError(NotSpecified, "Invalid context for Light switch context release handler"));

Platform::Delete(static_cast<BindingCommandData *>(context));
}

void InitBindingHandlerInternal(intptr_t arg)
{
auto & server = chip::Server::GetInstance();
chip::BindingManager::GetInstance().Init(
{ &server.GetFabricTable(), server.GetCASESessionManager(), &server.GetPersistentStorage() });
chip::BindingManager::GetInstance().RegisterBoundDeviceChangedHandler(LightSwitchChangedHandler);
chip::BindingManager::GetInstance().RegisterBoundDeviceContextReleaseHandler(LightSwitchContextReleaseHandler);
}

#ifdef CONFIG_ENABLE_CHIP_SHELL
Expand Down Expand Up @@ -398,8 +406,6 @@ void SwitchWorkerFunction(intptr_t context)

BindingCommandData * data = reinterpret_cast<BindingCommandData *>(context);
BindingManager::GetInstance().NotifyBoundClusterChanged(data->localEndpointId, data->clusterId, static_cast<void *>(data));

Platform::Delete(data);
}

void BindingWorkerFunction(intptr_t context)
Expand Down
10 changes: 8 additions & 2 deletions examples/light-switch-app/telink/src/binding-handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ void LightSwitchChangedHandler(const EmberBindingTableEntry & binding, Operation
}
}

void LightSwitchContextReleaseHandler(void * context)
{
VerifyOrReturn(context != nullptr, ChipLogError(NotSpecified, "Invalid context for Light switch context release handler"));

Platform::Delete(static_cast<BindingCommandData *>(context));
}

#ifdef ENABLE_CHIP_SHELL

/********************************************************
Expand Down Expand Up @@ -385,6 +392,7 @@ void InitBindingHandlerInternal(intptr_t arg)
chip::BindingManager::GetInstance().Init(
{ &server.GetFabricTable(), server.GetCASESessionManager(), &server.GetPersistentStorage() });
chip::BindingManager::GetInstance().RegisterBoundDeviceChangedHandler(LightSwitchChangedHandler);
chip::BindingManager::GetInstance().RegisterBoundDeviceContextReleaseHandler(LightSwitchContextReleaseHandler);
}

} // namespace
Expand Down Expand Up @@ -413,8 +421,6 @@ void SwitchWorkerFunction(intptr_t context)

BindingCommandData * data = reinterpret_cast<BindingCommandData *>(context);
BindingManager::GetInstance().NotifyBoundClusterChanged(data->localEndpointId, data->clusterId, static_cast<void *>(data));

Platform::Delete(data);
}

void BindingWorkerFunction(intptr_t context)
Expand Down
32 changes: 14 additions & 18 deletions src/app/clusters/bindings/BindingManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,28 @@ CHIP_ERROR BindingManager::EstablishConnection(const ScopedNodeId & nodeId)
VerifyOrReturnError(mInitParams.mCASESessionManager != nullptr, CHIP_ERROR_INCORRECT_STATE);

mLastSessionEstablishmentError = CHIP_NO_ERROR;
mInitParams.mCASESessionManager->FindOrEstablishSession(nodeId, &mOnConnectedCallback, &mOnConnectionFailureCallback);
auto * connectionCallback = Platform::New<ConnectionCallback>(*this);
mInitParams.mCASESessionManager->FindOrEstablishSession(nodeId, connectionCallback->GetOnDeviceConnected(),
connectionCallback->GetOnDeviceConnectionFailure());
if (mLastSessionEstablishmentError == CHIP_ERROR_NO_MEMORY)
{
// Release the least recently used entry
// TODO: Some reference counting mechanism shall be added the CASESessionManager
// so that other session clients don't get accidentally closed.
ScopedNodeId peerToRemove;
if (mPendingNotificationMap.FindLRUConnectPeer(peerToRemove) == CHIP_NO_ERROR)
{
mPendingNotificationMap.RemoveAllEntriesForNode(peerToRemove);

// Now retry
mLastSessionEstablishmentError = CHIP_NO_ERROR;
mInitParams.mCASESessionManager->FindOrEstablishSession(nodeId, &mOnConnectedCallback, &mOnConnectionFailureCallback);
// At this point connectionCallback is null since it deletes itself when the callback is called.
connectionCallback = Platform::New<ConnectionCallback>(*this);
mInitParams.mCASESessionManager->FindOrEstablishSession(nodeId, connectionCallback->GetOnDeviceConnected(),
connectionCallback->GetOnDeviceConnectionFailure());
}
}
return mLastSessionEstablishmentError;
}

void BindingManager::HandleDeviceConnected(void * context, Messaging::ExchangeManager & exchangeMgr, SessionHandle & sessionHandle)
{
BindingManager * manager = static_cast<BindingManager *>(context);
manager->HandleDeviceConnected(exchangeMgr, sessionHandle);
}

void BindingManager::HandleDeviceConnected(Messaging::ExchangeManager & exchangeMgr, SessionHandle & sessionHandle)
{
FabricIndex fabricToRemove = kUndefinedFabricIndex;
Expand All @@ -149,17 +146,15 @@ void BindingManager::HandleDeviceConnected(Messaging::ExchangeManager & exchange
mPendingNotificationMap.RemoveAllEntriesForNode(ScopedNodeId(nodeToRemove, fabricToRemove));
}

void BindingManager::HandleDeviceConnectionFailure(void * context, const ScopedNodeId & peerId, CHIP_ERROR error)
{
BindingManager * manager = static_cast<BindingManager *>(context);
manager->HandleDeviceConnectionFailure(peerId, error);
}

void BindingManager::HandleDeviceConnectionFailure(const ScopedNodeId & peerId, CHIP_ERROR error)
{
// Simply release the entry, the connection will be re-established as needed.
ChipLogError(AppServer, "Failed to establish connection to node 0x" ChipLogFormatX64, ChipLogValueX64(peerId.GetNodeId()));
mLastSessionEstablishmentError = error;
// We don't release the entry when connection fails, because inside
// BindingManager::EstablishConnection we may try again the connection.
// TODO(#22173): The logic in there doesn't actually make any sense with how
// mPendingNotificationMap and CASESessionManager are implemented today.
}

void BindingManager::FabricRemoved(FabricIndex fabricIndex)
Expand Down Expand Up @@ -188,9 +183,10 @@ CHIP_ERROR BindingManager::NotifyBoundClusterChanged(EndpointId endpoint, Cluste
{
if (iter->type == EMBER_UNICAST_BINDING)
{
mPendingNotificationMap.AddPendingNotification(iter.GetIndex(), bindingContext);
error = mPendingNotificationMap.AddPendingNotification(iter.GetIndex(), bindingContext);
SuccessOrExit(error);
error = EstablishConnection(ScopedNodeId(iter->nodeId, iter->fabricIndex));
SuccessOrExit(error == CHIP_NO_ERROR);
SuccessOrExit(error);
}
else if (iter->type == EMBER_MULTICAST_BINDING)
{
Expand Down
54 changes: 43 additions & 11 deletions src/app/clusters/bindings/BindingManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ struct BindingManagerInitParams
class BindingManager
{
public:
BindingManager() :
mOnConnectedCallback(HandleDeviceConnected, this), mOnConnectionFailureCallback(HandleDeviceConnectionFailure, this)
{}
BindingManager() {}

void RegisterBoundDeviceChangedHandler(BoundDeviceChangedHandler handler) { mBoundDeviceChangedHandler = handler; }

Expand Down Expand Up @@ -123,22 +121,56 @@ class BindingManager
static BindingManager & GetInstance() { return sBindingManager; }

private:
static BindingManager sBindingManager;

static void HandleDeviceConnected(void * context, Messaging::ExchangeManager & exchangeMgr, SessionHandle & sessionHandle);
void HandleDeviceConnected(Messaging::ExchangeManager & exchangeMgr, SessionHandle & sessionHandle);
/*
* Used when providing OnConnection/Failure callbacks to CASESessionManager when establishing session.
*
* Since the BindingManager calls EstablishConnection inside of a loop, and it is possible that the
* callback is called some time after the loop is completed, we need a separate callbacks for each
* connection we are trying to establish. Failure to provide different instances of the callback
* to CASESessionManager may result in the callback only be called for that last EstablishConnection
* that was called when it establishes the connections asynchronously.
*
*/
class ConnectionCallback
{
public:
ConnectionCallback(BindingManager & bindingManager) :
mBindingManager(bindingManager), mOnConnectedCallback(HandleDeviceConnected, this),
mOnConnectionFailureCallback(HandleDeviceConnectionFailure, this)
{}

Callback::Callback<OnDeviceConnected> * GetOnDeviceConnected() { return &mOnConnectedCallback; }
Callback::Callback<OnDeviceConnectionFailure> * GetOnDeviceConnectionFailure() { return &mOnConnectionFailureCallback; }

private:
static void HandleDeviceConnected(void * context, Messaging::ExchangeManager & exchangeMgr, SessionHandle & sessionHandle)
{
ConnectionCallback * _this = static_cast<ConnectionCallback *>(context);
_this->mBindingManager.HandleDeviceConnected(exchangeMgr, sessionHandle);
Platform::Delete(_this);
}
static void HandleDeviceConnectionFailure(void * context, const ScopedNodeId & peerId, CHIP_ERROR error)
{
ConnectionCallback * _this = static_cast<ConnectionCallback *>(context);
_this->mBindingManager.HandleDeviceConnectionFailure(peerId, error);
Platform::Delete(_this);
}

BindingManager & mBindingManager;
Callback::Callback<OnDeviceConnected> mOnConnectedCallback;
Callback::Callback<OnDeviceConnectionFailure> mOnConnectionFailureCallback;
};

static void HandleDeviceConnectionFailure(void * context, const ScopedNodeId & peerId, CHIP_ERROR error);
void HandleDeviceConnectionFailure(const ScopedNodeId & peerId, CHIP_ERROR error);
static BindingManager sBindingManager;

CHIP_ERROR EstablishConnection(const ScopedNodeId & nodeId);

PendingNotificationMap mPendingNotificationMap;
BoundDeviceChangedHandler mBoundDeviceChangedHandler;
BindingManagerInitParams mInitParams;

Callback::Callback<OnDeviceConnected> mOnConnectedCallback;
Callback::Callback<OnDeviceConnectionFailure> mOnConnectionFailureCallback;
void HandleDeviceConnected(Messaging::ExchangeManager & exchangeMgr, SessionHandle & sessionHandle);
void HandleDeviceConnectionFailure(const ScopedNodeId & peerId, CHIP_ERROR error);

// Used to keep track of synchronous failures from FindOrEstablishSession.
CHIP_ERROR mLastSessionEstablishmentError;
Expand Down
8 changes: 7 additions & 1 deletion src/app/clusters/bindings/PendingNotificationMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,22 @@ CHIP_ERROR PendingNotificationMap::FindLRUConnectPeer(ScopedNodeId & nodeId)
return CHIP_ERROR_NOT_FOUND;
}

void PendingNotificationMap::AddPendingNotification(uint8_t bindingEntryId, PendingNotificationContext * context)
CHIP_ERROR PendingNotificationMap::AddPendingNotification(uint8_t bindingEntryId, PendingNotificationContext * context)
{
RemoveEntry(bindingEntryId);
if (mNumEntries == EMBER_BINDING_TABLE_SIZE)
{
// We know that the RemoveEntry above did not do anything so we don't need to try restoring it.
return CHIP_ERROR_NO_MEMORY;
}
mPendingBindingEntries[mNumEntries] = bindingEntryId;
mPendingContexts[mNumEntries] = context;
if (context)
{
context->IncrementConsumersNumber();
}
mNumEntries++;
return CHIP_NO_ERROR;
}

void PendingNotificationMap::RemoveEntry(uint8_t bindingEntryId)
Expand Down
2 changes: 1 addition & 1 deletion src/app/clusters/bindings/PendingNotificationMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class PendingNotificationMap

CHIP_ERROR FindLRUConnectPeer(ScopedNodeId & nodeId);

void AddPendingNotification(uint8_t bindingEntryId, PendingNotificationContext * context);
CHIP_ERROR AddPendingNotification(uint8_t bindingEntryId, PendingNotificationContext * context);

void RemoveEntry(uint8_t bindingEntryId);

Expand Down
15 changes: 9 additions & 6 deletions src/app/tests/TestPendingNotificationMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@ void TestAddRemove(nlTestSuite * aSuite, void * aContext)
CreateDefaultFullBindingTable(BindingTable::GetInstance());
for (uint8_t i = 0; i < EMBER_BINDING_TABLE_SIZE; i++)
{
pendingMap.AddPendingNotification(i, nullptr);
NL_TEST_ASSERT(aSuite, pendingMap.AddPendingNotification(i, nullptr) == CHIP_NO_ERROR);
}
// Confirm adding in one more element fails
NL_TEST_ASSERT(aSuite, pendingMap.AddPendingNotification(EMBER_BINDING_TABLE_SIZE, nullptr) == CHIP_ERROR_NO_MEMORY);

auto iter = pendingMap.begin();
for (uint8_t i = 0; i < EMBER_BINDING_TABLE_SIZE; i++)
{
Expand Down Expand Up @@ -102,11 +105,11 @@ void TestLRUEntry(nlTestSuite * aSuite, void * aContext)
PendingNotificationMap pendingMap;
ClearBindingTable(BindingTable::GetInstance());
CreateDefaultFullBindingTable(BindingTable::GetInstance());
pendingMap.AddPendingNotification(0, nullptr);
pendingMap.AddPendingNotification(1, nullptr);
pendingMap.AddPendingNotification(5, nullptr);
pendingMap.AddPendingNotification(7, nullptr);
pendingMap.AddPendingNotification(11, nullptr);
NL_TEST_ASSERT(aSuite, pendingMap.AddPendingNotification(0, nullptr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(aSuite, pendingMap.AddPendingNotification(1, nullptr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(aSuite, pendingMap.AddPendingNotification(5, nullptr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(aSuite, pendingMap.AddPendingNotification(7, nullptr) == CHIP_NO_ERROR);
NL_TEST_ASSERT(aSuite, pendingMap.AddPendingNotification(11, nullptr) == CHIP_NO_ERROR);

chip::ScopedNodeId node;

Expand Down

0 comments on commit 7e00546

Please sign in to comment.