diff --git a/examples/shell/shell_common/cmd_ping.cpp b/examples/shell/shell_common/cmd_ping.cpp index 357b02ed65e9e1..b2c4bc20eb33a3 100644 --- a/examples/shell/shell_common/cmd_ping.cpp +++ b/examples/shell/shell_common/cmd_ping.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -269,7 +270,8 @@ void StartPinging(streamer_t * stream, char * destination) { peerAddress = Transport::PeerAddress::TCP(gDestAddr, gPingArguments.GetEchoPort()); - err = gSessionManager.Init(kTestControllerNodeId, &DeviceLayer::SystemLayer, &gTCPManager, &admins); + err = + gSessionManager.Init(kTestControllerNodeId, &DeviceLayer::SystemLayer, &gTCPManager, &admins, &gMessageCounterManager); SuccessOrExit(err); err = gExchangeManager.Init(&gSessionManager); @@ -280,13 +282,17 @@ void StartPinging(streamer_t * stream, char * destination) { peerAddress = Transport::PeerAddress::UDP(gDestAddr, gPingArguments.GetEchoPort(), INET_NULL_INTERFACEID); - err = gSessionManager.Init(kTestControllerNodeId, &DeviceLayer::SystemLayer, &gUDPManager, &admins); + err = + gSessionManager.Init(kTestControllerNodeId, &DeviceLayer::SystemLayer, &gUDPManager, &admins, &gMessageCounterManager); SuccessOrExit(err); err = gExchangeManager.Init(&gSessionManager); SuccessOrExit(err); } + err = gMessageCounterManager.Init(&gExchangeManager); + SuccessOrExit(err); + // Start the CHIP connection to the CHIP echo responder. err = EstablishSecureSession(stream, peerAddress); SuccessOrExit(err); diff --git a/examples/shell/shell_common/cmd_send.cpp b/examples/shell/shell_common/cmd_send.cpp index ee30d2fab899ff..42ab92dab69998 100644 --- a/examples/shell/shell_common/cmd_send.cpp +++ b/examples/shell/shell_common/cmd_send.cpp @@ -256,10 +256,8 @@ void ProcessCommand(streamer_t * stream, char * destination) { peerAddress = Transport::PeerAddress::TCP(gDestAddr, gSendArguments.GetPort()); - err = gSessionManager.Init(kTestControllerNodeId, &DeviceLayer::SystemLayer, &gTCPManager, &admins); - SuccessOrExit(err); - - err = gExchangeManager.Init(&gSessionManager); + err = + gSessionManager.Init(kTestControllerNodeId, &DeviceLayer::SystemLayer, &gTCPManager, &admins, &gMessageCounterManager); SuccessOrExit(err); } else @@ -267,13 +265,17 @@ void ProcessCommand(streamer_t * stream, char * destination) { peerAddress = Transport::PeerAddress::UDP(gDestAddr, gSendArguments.GetPort(), INET_NULL_INTERFACEID); - err = gSessionManager.Init(kTestControllerNodeId, &DeviceLayer::SystemLayer, &gUDPManager, &admins); - SuccessOrExit(err); - - err = gExchangeManager.Init(&gSessionManager); + err = + gSessionManager.Init(kTestControllerNodeId, &DeviceLayer::SystemLayer, &gUDPManager, &admins, &gMessageCounterManager); SuccessOrExit(err); } + err = gExchangeManager.Init(&gSessionManager); + SuccessOrExit(err); + + err = gMessageCounterManager.Init(&gExchangeManager); + SuccessOrExit(err); + // Start the CHIP connection to the CHIP server. err = EstablishSecureSession(stream, peerAddress); SuccessOrExit(err); diff --git a/examples/shell/shell_common/globals.cpp b/examples/shell/shell_common/globals.cpp index 95f4c8135bbccd..2e2d7aaf9c122f 100644 --- a/examples/shell/shell_common/globals.cpp +++ b/examples/shell/shell_common/globals.cpp @@ -17,6 +17,7 @@ #include +chip::secure_channel::MessageCounterManager gMessageCounterManager; chip::Messaging::ExchangeManager gExchangeManager; chip::SecureSessionMgr gSessionManager; chip::Inet::IPAddress gDestAddr; diff --git a/examples/shell/shell_common/include/Globals.h b/examples/shell/shell_common/include/Globals.h index fd761fc9af70ac..377e1775fd4246 100644 --- a/examples/shell/shell_common/include/Globals.h +++ b/examples/shell/shell_common/include/Globals.h @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -30,6 +31,7 @@ constexpr size_t kMaxTcpPendingPackets = 4; constexpr size_t kMaxPayloadSize = 1280; constexpr size_t kResponseTimeOut = 1000; +extern chip::secure_channel::MessageCounterManager gMessageCounterManager; extern chip::Messaging::ExchangeManager gExchangeManager; extern chip::SecureSessionMgr gSessionManager; extern chip::Inet::IPAddress gDestAddr; diff --git a/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp b/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp index 28aa7e2f3c34ab..e74b0ad4d39bab 100644 --- a/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp +++ b/src/app/clusters/operational-credentials-server/operational-credentials-server.cpp @@ -263,8 +263,7 @@ bool emberAfOperationalCredentialsClusterSetFabricCallback(chip::app::Command * SuccessOrExit(err = commandObj->PrepareCommand(&cmdParams)); writer = commandObj->GetCommandDataElementTLVWriter(); - SuccessOrExit( - err = writer->Put(TLV::ContextTag(0), commandObj->GetExchangeContext()->GetSecureSessionHandle().GetPeerNodeId())); + SuccessOrExit(err = writer->Put(TLV::ContextTag(0), commandObj->GetExchangeContext()->GetSecureSession().GetPeerNodeId())); SuccessOrExit(err = commandObj->FinishCommand()); } diff --git a/src/app/server/Server.cpp b/src/app/server/Server.cpp index 20532df6467c61..a6a2c29886c134 100644 --- a/src/app/server/Server.cpp +++ b/src/app/server/Server.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -397,6 +398,7 @@ class ServerCallback : public ExchangeDelegate SecureSessionMgr * mSessionMgr = nullptr; }; +secure_channel::MessageCounterManager gMessageCounterManager; ServerCallback gCallbacks; SecurePairingUsingTestSecret gTestPairing; @@ -507,11 +509,14 @@ void InitServer(AppDelegate * delegate) SuccessOrExit(err); - err = gSessions.Init(chip::kTestDeviceNodeId, &DeviceLayer::SystemLayer, &gTransports, &gAdminPairings); + err = + gSessions.Init(chip::kTestDeviceNodeId, &DeviceLayer::SystemLayer, &gTransports, &gAdminPairings, &gMessageCounterManager); SuccessOrExit(err); err = gExchangeMgr.Init(&gSessions); SuccessOrExit(err); + err = gMessageCounterManager.Init(&gExchangeMgr); + SuccessOrExit(err); err = chip::app::InteractionModelEngine::GetInstance()->Init(&gExchangeMgr, nullptr); SuccessOrExit(err); diff --git a/src/app/tests/TestCommandInteraction.cpp b/src/app/tests/TestCommandInteraction.cpp index 400c42f35b6f08..fe5ac7dd5c90f9 100644 --- a/src/app/tests/TestCommandInteraction.cpp +++ b/src/app/tests/TestCommandInteraction.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -49,6 +50,7 @@ static System::Layer gSystemLayer; static SecureSessionMgr gSessionManager; static Messaging::ExchangeManager gExchangeManager; static TransportMgr gTransportManager; +static secure_channel::MessageCounterManager gMessageCounterManager; static Transport::AdminId gAdminId = 0; namespace app { @@ -311,12 +313,16 @@ void InitializeChip(nlTestSuite * apSuite) chip::gSystemLayer.Init(nullptr); - err = chip::gSessionManager.Init(chip::kTestDeviceNodeId, &chip::gSystemLayer, &chip::gTransportManager, &admins); + err = chip::gSessionManager.Init(chip::kTestDeviceNodeId, &chip::gSystemLayer, &chip::gTransportManager, &admins, + &chip::gMessageCounterManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); err = chip::gExchangeManager.Init(&chip::gSessionManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + err = chip::gMessageCounterManager.Init(&chip::gExchangeManager); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + err = chip::app::InteractionModelEngine::GetInstance()->Init(&chip::gExchangeManager, nullptr); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); } diff --git a/src/app/tests/TestEventLogging.cpp b/src/app/tests/TestEventLogging.cpp index e391ad7304bb69..c163c71a28384a 100644 --- a/src/app/tests/TestEventLogging.cpp +++ b/src/app/tests/TestEventLogging.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -63,6 +64,7 @@ static chip::app::CircularEventBuffer gCircularEventBuffer[3]; chip::SecureSessionMgr gSessionManager; chip::Messaging::ExchangeManager gExchangeManager; +chip::secure_channel::MessageCounterManager gMessageCounterManager; void InitializeChip(nlTestSuite * apSuite) { @@ -78,11 +80,14 @@ void InitializeChip(nlTestSuite * apSuite) gSystemLayer.Init(nullptr); - err = gSessionManager.Init(chip::kTestDeviceNodeId, &gSystemLayer, &gTransportManager, &admins); + err = gSessionManager.Init(chip::kTestDeviceNodeId, &gSystemLayer, &gTransportManager, &admins, &gMessageCounterManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); err = gExchangeManager.Init(&gSessionManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + err = gMessageCounterManager.Init(&gExchangeManager); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); } void InitializeEventLogging() diff --git a/src/app/tests/TestInteractionModelEngine.cpp b/src/app/tests/TestInteractionModelEngine.cpp index d6d043120e4fb6..bd87403f369ac7 100644 --- a/src/app/tests/TestInteractionModelEngine.cpp +++ b/src/app/tests/TestInteractionModelEngine.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -45,6 +46,7 @@ namespace { static chip::System::Layer gSystemLayer; static chip::SecureSessionMgr gSessionManager; static chip::Messaging::ExchangeManager gExchangeManager; +static chip::secure_channel::MessageCounterManager gMessageCounterManager; static chip::TransportMgr gTransportManager; static const chip::Transport::AdminId gAdminId = 0; } // namespace @@ -117,11 +119,14 @@ void InitializeChip(nlTestSuite * apSuite) gSystemLayer.Init(nullptr); - err = gSessionManager.Init(chip::kTestDeviceNodeId, &gSystemLayer, &gTransportManager, &admins); + err = gSessionManager.Init(chip::kTestDeviceNodeId, &gSystemLayer, &gTransportManager, &admins, &gMessageCounterManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); err = gExchangeManager.Init(&gSessionManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + err = gMessageCounterManager.Init(&gExchangeManager); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); } // clang-format off diff --git a/src/app/tests/TestReadInteraction.cpp b/src/app/tests/TestReadInteraction.cpp index 7af271004eef78..5831a2ffc68a55 100644 --- a/src/app/tests/TestReadInteraction.cpp +++ b/src/app/tests/TestReadInteraction.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -47,6 +48,7 @@ SecureSessionMgr gSessionManager; Messaging::ExchangeManager gExchangeManager; TransportMgr gTransportManager; const Transport::AdminId gAdminId = 0; +secure_channel::MessageCounterManager gMessageCounterManager; namespace app { class TestReadInteraction @@ -157,11 +159,15 @@ void InitializeChip(nlTestSuite * apSuite) chip::gSystemLayer.Init(nullptr); - err = chip::gSessionManager.Init(chip::kTestDeviceNodeId, &chip::gSystemLayer, &chip::gTransportManager, &admins); + err = chip::gSessionManager.Init(chip::kTestDeviceNodeId, &chip::gSystemLayer, &chip::gTransportManager, &admins, + &chip::gMessageCounterManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); err = chip::gExchangeManager.Init(&chip::gSessionManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + err = chip::gMessageCounterManager.Init(&chip::gExchangeManager); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); } /** diff --git a/src/app/tests/TestReportingEngine.cpp b/src/app/tests/TestReportingEngine.cpp index 75a3e356521cea..2a0542018078ad 100644 --- a/src/app/tests/TestReportingEngine.cpp +++ b/src/app/tests/TestReportingEngine.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -47,6 +48,7 @@ static System::Layer gSystemLayer; static SecureSessionMgr gSessionManager; static Messaging::ExchangeManager gExchangeManager; static TransportMgr gTransportManager; +static secure_channel::MessageCounterManager gMessageCounterManager; static const Transport::AdminId gAdminId = 0; constexpr ClusterId kTestClusterId = 6; constexpr EndpointId kTestEndpointId = 1; @@ -153,11 +155,15 @@ void InitializeChip(nlTestSuite * apSuite) chip::gSystemLayer.Init(nullptr); - err = chip::gSessionManager.Init(chip::kTestDeviceNodeId, &chip::gSystemLayer, &chip::gTransportManager, &admins); + err = chip::gSessionManager.Init(chip::kTestDeviceNodeId, &chip::gSystemLayer, &chip::gTransportManager, &admins, + &chip::gMessageCounterManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); err = chip::gExchangeManager.Init(&chip::gSessionManager); NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); + + err = chip::gMessageCounterManager.Init(&chip::gExchangeManager); + NL_TEST_ASSERT(apSuite, err == CHIP_NO_ERROR); } // clang-format off diff --git a/src/app/tests/integration/chip_im_initiator.cpp b/src/app/tests/integration/chip_im_initiator.cpp index d1ac99de60b580..df29a9985ab689 100644 --- a/src/app/tests/integration/chip_im_initiator.cpp +++ b/src/app/tests/integration/chip_im_initiator.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -56,8 +57,8 @@ chip::app::CommandSender * gpCommandSender = nullptr; chip::app::ReadClient * gpReadClient = nullptr; chip::TransportMgr gTransportManager; - chip::SecureSessionMgr gSessionManager; +chip::secure_channel::MessageCounterManager gMessageCounterManager; chip::Inet::IPAddress gDestAddr; @@ -317,12 +318,16 @@ int main(int argc, char * argv[]) .SetListenPort(IM_CLIENT_PORT)); SuccessOrExit(err); - err = gSessionManager.Init(chip::kTestControllerNodeId, &chip::DeviceLayer::SystemLayer, &gTransportManager, &admins); + err = gSessionManager.Init(chip::kTestControllerNodeId, &chip::DeviceLayer::SystemLayer, &gTransportManager, &admins, + &gMessageCounterManager); SuccessOrExit(err); err = gExchangeManager.Init(&gSessionManager); SuccessOrExit(err); + err = gMessageCounterManager.Init(&gExchangeManager); + SuccessOrExit(err); + err = chip::app::InteractionModelEngine::GetInstance()->Init(&gExchangeManager, &mockDelegate); SuccessOrExit(err); diff --git a/src/app/tests/integration/chip_im_responder.cpp b/src/app/tests/integration/chip_im_responder.cpp index 2273c8c75c3acc..a357d41228221c 100644 --- a/src/app/tests/integration/chip_im_responder.cpp +++ b/src/app/tests/integration/chip_im_responder.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -125,6 +126,7 @@ namespace { chip::TransportMgr gTransportManager; chip::SecureSessionMgr gSessionManager; chip::SecurePairingUsingTestSecret gTestPairing; +chip::secure_channel::MessageCounterManager gMessageCounterManager; LivenessEventGenerator gLivenessGenerator; uint8_t gDebugEventBuffer[2048]; @@ -162,12 +164,16 @@ int main(int argc, char * argv[]) chip::Transport::UdpListenParameters(&chip::DeviceLayer::InetLayer).SetAddressType(chip::Inet::kIPAddressType_IPv4)); SuccessOrExit(err); - err = gSessionManager.Init(chip::kTestDeviceNodeId, &chip::DeviceLayer::SystemLayer, &gTransportManager, &admins); + err = gSessionManager.Init(chip::kTestDeviceNodeId, &chip::DeviceLayer::SystemLayer, &gTransportManager, &admins, + &gMessageCounterManager); SuccessOrExit(err); err = gExchangeManager.Init(&gSessionManager); SuccessOrExit(err); + err = gMessageCounterManager.Init(&gExchangeManager); + SuccessOrExit(err); + err = chip::app::InteractionModelEngine::GetInstance()->Init(&gExchangeManager, &mockDelegate); SuccessOrExit(err); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 21533458c66567..d357d2b84d5dd8 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include @@ -162,9 +163,10 @@ CHIP_ERROR DeviceController::Init(NodeId localDeviceId, ControllerInitParams par VerifyOrExit(mBleLayer != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT); #endif - mTransportMgr = chip::Platform::New(); - mSessionMgr = chip::Platform::New(); - mExchangeMgr = chip::Platform::New(); + mTransportMgr = chip::Platform::New(); + mSessionMgr = chip::Platform::New(); + mExchangeMgr = chip::Platform::New(); + mMessageCounterManager = chip::Platform::New(); err = mTransportMgr->Init( Transport::UdpListenParameters(mInetLayer).SetAddressType(Inet::kIPAddressType_IPv6).SetListenPort(mListenPort) @@ -182,12 +184,15 @@ CHIP_ERROR DeviceController::Init(NodeId localDeviceId, ControllerInitParams par admin = mAdmins.AssignAdminId(mAdminId, localDeviceId); VerifyOrExit(admin != nullptr, err = CHIP_ERROR_NO_MEMORY); - err = mSessionMgr->Init(localDeviceId, mSystemLayer, mTransportMgr, &mAdmins); + err = mSessionMgr->Init(localDeviceId, mSystemLayer, mTransportMgr, &mAdmins, mMessageCounterManager); SuccessOrExit(err); err = mExchangeMgr->Init(mSessionMgr); SuccessOrExit(err); + err = mMessageCounterManager->Init(mExchangeMgr); + SuccessOrExit(err); + err = mExchangeMgr->RegisterUnsolicitedMessageHandlerForProtocol(Protocols::TempZCL::Id, this); SuccessOrExit(err); @@ -234,6 +239,16 @@ CHIP_ERROR DeviceController::Shutdown() mState = State::NotInitialized; + // TODO(#6668): Some exchange has leak, shutting down ExchangeManager will cause a assert fail. + // if (mExchangeMgr != nullptr) + // { + // mExchangeMgr->Shutdown(); + // } + if (mSessionMgr != nullptr) + { + mSessionMgr->Shutdown(); + } + #if CONFIG_DEVICE_LAYER ReturnErrorOnFailure(DeviceLayer::PlatformMgr().Shutdown()); #else @@ -247,6 +262,12 @@ CHIP_ERROR DeviceController::Shutdown() mInetLayer = nullptr; mStorageDelegate = nullptr; + if (mMessageCounterManager != nullptr) + { + chip::Platform::Delete(mMessageCounterManager); + mMessageCounterManager = nullptr; + } + if (mExchangeMgr != nullptr) { chip::Platform::Delete(mExchangeMgr); diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index 28613dad3af205..672c9fa47a7199 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -36,6 +36,7 @@ #include #include #include +#include #include #include #include @@ -247,6 +248,7 @@ class DLL_EXPORT DeviceController : public Messaging::ExchangeDelegate, DeviceTransportMgr * mTransportMgr; SecureSessionMgr * mSessionMgr; Messaging::ExchangeManager * mExchangeMgr; + secure_channel::MessageCounterManager * mMessageCounterManager; PersistentStorageDelegate * mStorageDelegate; DeviceControllerInteractionModelDelegate * mDefaultIMDelegate; #if CHIP_DEVICE_CONFIG_ENABLE_MDNS diff --git a/src/lib/core/CHIPConfig.h b/src/lib/core/CHIPConfig.h index 149a066a4d3b05..387bc4b4fdc462 100644 --- a/src/lib/core/CHIPConfig.h +++ b/src/lib/core/CHIPConfig.h @@ -1292,6 +1292,28 @@ #define CHIP_CONFIG_NODE_ADDRESS_RESOLVE_TIMEOUT_MSECS (5000) #endif // CHIP_CONFIG_NODE_ADDRESS_RESOLVE_TIMEOUT_MSECS +/** + * @def CHIP_CONFIG_MCSP_RECEIVE_TABLE_SIZE + * + * @brief + * Size of the receive table for message counter synchronization protocol + * + */ +#ifndef CHIP_CONFIG_MCSP_RECEIVE_TABLE_SIZE +#define CHIP_CONFIG_MCSP_RECEIVE_TABLE_SIZE (CHIP_CONFIG_MAX_EXCHANGE_CONTEXTS - 2) +#endif // CHIP_CONFIG_MCSP_RECEIVE_TABLE_SIZE + +/** + * @def CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE + * + * @brief + * Max number of messages behind message window can be accepted. + * + */ +#ifndef CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE +#define CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE 32 +#endif // CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE + /** * @def CHIP_CONFIG_CONNECT_IP_ADDRS * @@ -1595,6 +1617,10 @@ #define CHIP_CONFIG_SUPPORT_CASE_CONFIG1 1 #endif // CHIP_CONFIG_SUPPORT_CASE_CONFIG1 +#ifndef CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER +#define CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER "GlobalMCTR" +#endif // CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER + /** * @def CHIP_CONFIG_DEFAULT_CASE_CURVE_ID * diff --git a/src/lib/support/Span.h b/src/lib/support/Span.h index 38c903716a98de..45282e3f9ceab5 100644 --- a/src/lib/support/Span.h +++ b/src/lib/support/Span.h @@ -44,6 +44,22 @@ class Span size_t mDataLen; }; +template +class FixedSpan +{ +public: + constexpr FixedSpan() : mDataBuf(nullptr) {} + constexpr explicit FixedSpan(const T * databuf) : mDataBuf(databuf) {} + + const T * data() const { return mDataBuf; } + size_t size() const { return N; } + +private: + const T * mDataBuf; +}; + using ByteSpan = Span; +template +using FixedByteSpan = FixedSpan; } // namespace chip diff --git a/src/lib/support/logging/CHIPLogging.cpp b/src/lib/support/logging/CHIPLogging.cpp index 3bf316419a5fa4..d70009e210b10f 100644 --- a/src/lib/support/logging/CHIPLogging.cpp +++ b/src/lib/support/logging/CHIPLogging.cpp @@ -67,6 +67,7 @@ static const char ModuleNames[] = "-\0\0" // None "CR\0" // Crypto "CTL" // Controller "AL\0" // Alarm + "SC\0" // SecureChannel "BDX" // BulkDataTransfer "DMG" // DataManagement "DC\0" // DeviceControl diff --git a/src/lib/support/logging/Constants.h b/src/lib/support/logging/Constants.h index 19da28416efd71..c605e2fd93214b 100644 --- a/src/lib/support/logging/Constants.h +++ b/src/lib/support/logging/Constants.h @@ -30,6 +30,7 @@ enum LogModule kLogModule_Crypto, kLogModule_Controller, kLogModule_Alarm, + kLogModule_SecureChannel, kLogModule_BDX, kLogModule_DataManagement, kLogModule_DeviceControl, diff --git a/src/messaging/BUILD.gn b/src/messaging/BUILD.gn index 9cb054af242841..76966ef817eb4b 100644 --- a/src/messaging/BUILD.gn +++ b/src/messaging/BUILD.gn @@ -32,8 +32,6 @@ static_library("messaging") { "ExchangeMgr.h", "ExchangeMgrDelegate.h", "Flags.h", - "MessageCounterSync.cpp", - "MessageCounterSync.h", "ReliableMessageContext.cpp", "ReliableMessageContext.h", "ReliableMessageMgr.cpp", diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 821569dbeb14f1..281834f5d42f82 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -84,37 +84,6 @@ CHIP_ERROR ExchangeContext::SendMessage(Protocols::Id protocolId, uint8_t msgTyp VerifyOrReturnError(mExchangeMgr != nullptr, CHIP_ERROR_INTERNAL); - state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession); - - // If a group message is to be transmitted to a destination node whose message counter is unknown. - if (state != nullptr && ChipKeyId::IsAppGroupKey(state->GetLocalKeyID()) && !state->IsPeerMsgCounterSynced()) - { - MessageCounterSyncMgr * messageCounterSyncMgr = mExchangeMgr->GetMessageCounterSyncMgr(); - VerifyOrReturnError(messageCounterSyncMgr != nullptr, CHIP_ERROR_INTERNAL); - - // Queue the message as needed for sync with destination node. - err = messageCounterSyncMgr->AddToRetransmissionTable(protocolId, msgType, sendFlags, std::move(msgBuf), this); - ReturnErrorOnFailure(err); - - // Initiate message counter synchronization if no message counter synchronization is in progress. - if (!state->IsMsgCounterSyncInProgress()) - { - err = mExchangeMgr->GetMessageCounterSyncMgr()->SendMsgCounterSyncReq(mSecureSession); - } - } - else - { - err = SendMessageImpl(protocolId, msgType, std::move(msgBuf), sendFlags, state); - } - - return err; -} - -CHIP_ERROR ExchangeContext::SendMessageImpl(Protocols::Id protocolId, uint8_t msgType, PacketBufferHandle msgBuf, - const SendFlags & sendFlags, Transport::PeerConnectionState * state) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - // Don't let method get called on a freed object. VerifyOrDie(mExchangeMgr != nullptr && GetReferenceCount() > 0); @@ -125,8 +94,9 @@ CHIP_ERROR ExchangeContext::SendMessageImpl(Protocols::Id protocolId, uint8_t ms bool reliableTransmissionRequested = true; + state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(mSecureSession); // If sending via UDP and NoAutoRequestAck send flag is not specificed, request reliable transmission. - if (state && state->GetPeerAddress().GetTransportType() != Transport::Type::kUdp) + if (state != nullptr && state->GetPeerAddress().GetTransportType() != Transport::Type::kUdp) { reliableTransmissionRequested = false; } @@ -297,7 +267,6 @@ void ExchangeContext::Free() DoClose(false); mExchangeMgr = nullptr; - mAppState = nullptr; em->DecrementContextsInUse(); diff --git a/src/messaging/ExchangeContext.h b/src/messaging/ExchangeContext.h index 1dd583fd5399f5..d58e140c6528fd 100644 --- a/src/messaging/ExchangeContext.h +++ b/src/messaging/ExchangeContext.h @@ -59,7 +59,6 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, { friend class ExchangeManager; friend class ExchangeContextDeletor; - friend class MessageCounterSyncMgr; public: typedef uint32_t Timeout; // Type used to express the timeout in this ExchangeContext, in milliseconds @@ -152,12 +151,6 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, uint16_t GetExchangeId() const { return mExchangeId; } - void SetAppState(void * state) { mAppState = state; } - - void * GetAppState() const { return mAppState; } - - SecureSessionHandle GetSecureSessionHandle() const { return mSecureSession; } - /* * In order to use reference counting (see refCount below) we use a hold/free paradigm where users of the exchange * can hold onto it while it's out of their direct control to make sure it isn't closed before everyone's ready. @@ -173,7 +166,6 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, ExchangeDelegateBase * mDelegate = nullptr; ExchangeManager * mExchangeMgr = nullptr; ExchangeACL * mExchangeACL = nullptr; - void * mAppState = nullptr; SecureSessionHandle mSecureSession; // The connection state uint16_t mExchangeId; // Assigned exchange ID. @@ -218,12 +210,6 @@ class DLL_EXPORT ExchangeContext : public ReliableMessageContext, CHIP_ERROR StartResponseTimer(); - /** - * A subset of SendMessage functionality that does not perform message - * counter sync for group keys. - */ - CHIP_ERROR SendMessageImpl(Protocols::Id protocolId, uint8_t msgType, System::PacketBufferHandle msgBuf, - const SendFlags & sendFlags, Transport::PeerConnectionState * state = nullptr); void CancelResponseTimer(); static void HandleResponseTimeout(System::Layer * aSystemLayer, void * aAppState, System::Error aError); diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index a34d4047724da4..0c5fe415f0edb6 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -89,9 +89,6 @@ CHIP_ERROR ExchangeManager::Init(SecureSessionMgr * sessionMgr) mReliableMessageMgr.Init(sessionMgr->SystemLayer(), sessionMgr); - err = mMessageCounterSyncMgr.Init(this); - ReturnErrorOnFailure(err); - mState = State::kState_Initialized; return err; @@ -99,7 +96,6 @@ CHIP_ERROR ExchangeManager::Init(SecureSessionMgr * sessionMgr) CHIP_ERROR ExchangeManager::Shutdown() { - mMessageCounterSyncMgr.Shutdown(); mReliableMessageMgr.Shutdown(); for (auto & ec : mContextPool) @@ -319,34 +315,6 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const } } -CHIP_ERROR ExchangeManager::QueueReceivedMessageAndSync(Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - - VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); - - // Queue the message as needed for sync with destination node. - err = mMessageCounterSyncMgr.AddToReceiveTable(std::move(msgBuf)); - SuccessOrExit(err); - - // Initiate message counter synchronization if no message counter synchronization is in progress. - if (!state->IsMsgCounterSyncInProgress()) - { - SecureSessionHandle session(state->GetPeerNodeId(), state->GetPeerKeyID(), state->GetAdminId()); - err = mMessageCounterSyncMgr.SendMsgCounterSyncReq(session); - } - -exit: - if (err != CHIP_NO_ERROR) - { - ChipLogError(ExchangeManager, - "Message counter synchronization for received message, failed to send synchronization request, err = %s", - ErrorStr(err)); - } - - return err; -} - void ExchangeManager::OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) { if (mDelegate != nullptr) diff --git a/src/messaging/ExchangeMgr.h b/src/messaging/ExchangeMgr.h index 2c3938bd9864e7..4c43901da360ea 100644 --- a/src/messaging/ExchangeMgr.h +++ b/src/messaging/ExchangeMgr.h @@ -28,7 +28,6 @@ #include #include -#include #include #include #include @@ -192,7 +191,6 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans ReliableMessageMgr * GetReliableMessageMgr() { return &mReliableMessageMgr; }; - MessageCounterSyncMgr * GetMessageCounterSyncMgr() { return &mMessageCounterSyncMgr; }; Transport::AdminId GetAdminId() { return mAdminId; } uint16_t GetNextKeyId() { return ++mNextKeyId; } @@ -233,7 +231,6 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans ExchangeMgrDelegate * mDelegate; SecureSessionMgr * mSessionMgr; ReliableMessageMgr mReliableMessageMgr; - MessageCounterSyncMgr mMessageCounterSyncMgr; Transport::AdminId mAdminId = 0; @@ -259,8 +256,6 @@ class DLL_EXPORT ExchangeManager : public SecureSessionMgrDelegate, public Trans // TransportMgrDelegate interface for rendezvous sessions void OnMessageReceived(const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf) override; - - CHIP_ERROR QueueReceivedMessageAndSync(Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf) override; }; } // namespace Messaging diff --git a/src/messaging/MessageCounterSync.cpp b/src/messaging/MessageCounterSync.cpp deleted file mode 100644 index 9045be21f8e34b..00000000000000 --- a/src/messaging/MessageCounterSync.cpp +++ /dev/null @@ -1,444 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file implements the CHIP Secure Channel protocol. - * - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace chip { -namespace Messaging { - -CHIP_ERROR MessageCounterSyncMgr::Init(Messaging::ExchangeManager * exchangeMgr) -{ - VerifyOrReturnError(exchangeMgr != nullptr, CHIP_ERROR_INCORRECT_STATE); - mExchangeMgr = exchangeMgr; - - // Register to receive unsolicited Message Counter Synchronization Request messages from the exchange manager. - return mExchangeMgr->RegisterUnsolicitedMessageHandlerForType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq, this); -} - -void MessageCounterSyncMgr::Shutdown() -{ - if (mExchangeMgr != nullptr) - { - mExchangeMgr->UnregisterUnsolicitedMessageHandlerForType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq); - mExchangeMgr = nullptr; - } -} - -void MessageCounterSyncMgr::OnMessageReceived(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, - const PayloadHeader & payloadHeader, System::PacketBufferHandle msgBuf) -{ - if (payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq)) - { - HandleMsgCounterSyncReq(exchangeContext, packetHeader, std::move(msgBuf)); - } - else if (payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp)) - { - HandleMsgCounterSyncResp(exchangeContext, packetHeader, std::move(msgBuf)); - } -} - -void MessageCounterSyncMgr::OnResponseTimeout(Messaging::ExchangeContext * exchangeContext) -{ - Transport::PeerConnectionState * state = - mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(exchangeContext->GetSecureSessionHandle()); - - if (state != nullptr) - { - state->SetMsgCounterSyncInProgress(false); - } - else - { - ChipLogError(ExchangeManager, "Timed out! Failed to clear message counter synchronization status."); - } - - // Close the exchange if MsgCounterSyncRsp is not received before kMsgCounterSyncTimeout. - if (exchangeContext != nullptr) - { - chip::Platform::MemoryFree(exchangeContext->GetAppState()); - exchangeContext->SetAppState(nullptr); - exchangeContext->Close(); - } -} - -CHIP_ERROR MessageCounterSyncMgr::AddToRetransmissionTable(Protocols::Id protocolId, uint8_t msgType, const SendFlags & sendFlags, - System::PacketBufferHandle msgBuf, - Messaging::ExchangeContext * exchangeContext) -{ - bool added = false; - CHIP_ERROR err = CHIP_NO_ERROR; - - VerifyOrReturnError(exchangeContext != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT); - - for (RetransTableEntry & entry : mRetransTable) - { - // Entries are in use if they have an exchangeContext. - if (entry.exchangeContext == nullptr) - { - entry.protocolId = protocolId; - entry.msgType = msgType; - entry.msgBuf = std::move(msgBuf); - entry.exchangeContext = exchangeContext; - entry.exchangeContext->Retain(); - added = true; - - break; - } - } - - if (!added) - { - ChipLogError(ExchangeManager, "MCSP RetransTable Already Full"); - err = CHIP_ERROR_NO_MEMORY; - } - - return err; -} - -/** - * Retransmit all pending messages that were encrypted with application - * group key and were addressed to the specified node. - * - * @param[in] peerNodeId Node ID of the destination node. - * - */ -void MessageCounterSyncMgr::RetransPendingGroupMsgs(NodeId peerNodeId) -{ - // Find all retransmit entries matching peerNodeId. Note that everything in - // this table was using an application group key; that's why it was added. - for (RetransTableEntry & entry : mRetransTable) - { - if (entry.exchangeContext != nullptr && entry.exchangeContext->GetSecureSession().GetPeerNodeId() == peerNodeId) - { - // Retramsmit message. - CHIP_ERROR err = - entry.exchangeContext->SendMessage(entry.protocolId, entry.msgType, std::move(entry.msgBuf), entry.sendFlags); - - if (err != CHIP_NO_ERROR) - { - ChipLogError(ExchangeManager, "Failed to resend cached group message to node: %d with error:%s", peerNodeId, - ErrorStr(err)); - } - - entry.exchangeContext->Release(); - entry.exchangeContext = nullptr; - } - } -} - -CHIP_ERROR MessageCounterSyncMgr::AddToReceiveTable(System::PacketBufferHandle msgBuf) -{ - bool added = false; - CHIP_ERROR err = CHIP_NO_ERROR; - - for (ReceiveTableEntry & entry : mReceiveTable) - { - // Entries are in use if they have a message buffer. - if (entry.msgBuf.IsNull()) - { - entry.msgBuf = std::move(msgBuf); - added = true; - break; - } - } - - if (!added) - { - ChipLogError(ExchangeManager, "MCSP ReceiveTable Already Full"); - err = CHIP_ERROR_NO_MEMORY; - } - - return err; -} - -/** - * Reprocess all pending messages that were encrypted with application - * group key and were addressed to the specified node id. - * - * @param[in] peerNodeId Node ID of the destination node. - * - */ -void MessageCounterSyncMgr::ProcessPendingGroupMsgs(NodeId peerNodeId) -{ - // Find all receive entries matching peerNodeId. Note that everything in - // this table was using an application group key; that's why it was added. - for (ReceiveTableEntry & entry : mReceiveTable) - { - if (!entry.msgBuf.IsNull()) - { - PacketHeader packetHeader; - uint16_t headerSize = 0; - - if (packetHeader.Decode((entry.msgBuf)->Start(), (entry.msgBuf)->DataLength(), &headerSize) != CHIP_NO_ERROR) - { - ChipLogError(ExchangeManager, "ProcessPendingGroupMsgs::Failed to decode PacketHeader"); - break; - } - - if (packetHeader.GetSourceNodeId().HasValue() && packetHeader.GetSourceNodeId().Value() == peerNodeId) - { - // Reprocess message. - mExchangeMgr->GetSessionMgr()->HandleGroupMessageReceived(packetHeader.GetEncryptionKeyID(), - std::move(entry.msgBuf)); - - // Explicitly free any buffer owned by this handle. The - // HandleGroupMessageReceived() call should really handle this, but - // just in case it messes up we don't want to get confused about - // wheter the entry is in use. - entry.msgBuf = nullptr; - } - } - } -} - -// Create and initialize new exchange for the message counter synchronization request/response messages. -CHIP_ERROR MessageCounterSyncMgr::NewMsgCounterSyncExchange(SecureSessionHandle session, - Messaging::ExchangeContext *& exchangeContext) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - - // Message counter synchronization protocol is only applicable for application group keys. - VerifyOrReturnError(ChipKeyId::IsAppGroupKey(session.GetPeerKeyId()), err = CHIP_ERROR_INVALID_ARGUMENT); - - // Create new exchange context. - exchangeContext = mExchangeMgr->NewContext(session, this); - VerifyOrReturnError(exchangeContext != nullptr, err = CHIP_ERROR_NO_MEMORY); - - return err; -} - -CHIP_ERROR MessageCounterSyncMgr::SendMsgCounterSyncReq(SecureSessionHandle session) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - - Messaging::ExchangeContext * exchangeContext = nullptr; - Transport::PeerConnectionState * state = nullptr; - System::PacketBufferHandle msgBuf; - Messaging::SendFlags sendFlags; - void * challenge = chip::Platform::MemoryAlloc(kMsgCounterChallengeSize); - VerifyOrExit(challenge != nullptr, err = CHIP_ERROR_NO_MEMORY); - - state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(session); - VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); - - // Create and initialize new exchange. - err = NewMsgCounterSyncExchange(session, exchangeContext); - SuccessOrExit(err); - - // Allocate a buffer for the null message. - msgBuf = MessagePacketBuffer::New(kMsgCounterChallengeSize); - VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_NO_MEMORY); - - // Generate a 64-bit random number to uniquely identify the request. - err = Crypto::DRBG_get_bytes(static_cast(challenge), kMsgCounterChallengeSize); - SuccessOrExit(err); - - // Store generated Challenge value to ExchangeContext to resolve synchronization response. - exchangeContext->SetAppState(challenge); - memcpy(msgBuf->Start(), challenge, kMsgCounterChallengeSize); - msgBuf->SetDataLength(kMsgCounterChallengeSize); - challenge = nullptr; - - sendFlags.Set(Messaging::SendMessageFlags::kNoAutoRequestAck, true).Set(Messaging::SendMessageFlags::kExpectResponse, true); - - // Arm a timer to enforce that a MsgCounterSyncRsp is received before kMsgCounterSyncTimeout. - exchangeContext->SetResponseTimeout(kMsgCounterSyncTimeout); - - // Send the message counter synchronization request in a Secure Channel Protocol::MsgCounterSyncReq message. - err = exchangeContext->SendMessageImpl(Protocols::SecureChannel::Id, - static_cast(Protocols::SecureChannel::MsgType::MsgCounterSyncReq), - std::move(msgBuf), sendFlags); - SuccessOrExit(err); - - state->SetMsgCounterSyncInProgress(true); - -exit: - if (err != CHIP_NO_ERROR) - { - chip::Platform::MemoryFree(challenge); - ChipLogError(ExchangeManager, "Failed to send message counter synchronization request with error:%s", ErrorStr(err)); - } - - return err; -} - -CHIP_ERROR MessageCounterSyncMgr::SendMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, SecureSessionHandle session) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - Transport::PeerConnectionState * state = nullptr; - System::PacketBufferHandle msgBuf; - uint8_t * msg = nullptr; - - state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(session); - VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); - - // Allocate new buffer. - msgBuf = System::PacketBufferHandle::New(kMsgCounterSyncRespMsgSize); - VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_NO_MEMORY); - - msg = msgBuf->Start(); - - // Let's construct the message using BufBound - { - Encoding::LittleEndian::BufferWriter bbuf(msg, kMsgCounterSyncRespMsgSize); - - // Write the message id (counter) field. - bbuf.Put32(state->GetSendMessageIndex()); - - // Fill in the random value - bbuf.Put(exchangeContext->GetAppState(), kMsgCounterChallengeSize); - - VerifyOrExit(bbuf.Fit(), err = CHIP_ERROR_NO_MEMORY); - } - - // Set message length. - msgBuf->SetDataLength(kMsgCounterSyncRespMsgSize); - - // Send message counter synchronization response message. - err = exchangeContext->SendMessageImpl(Protocols::SecureChannel::Id, - static_cast(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp), - std::move(msgBuf), Messaging::SendFlags(Messaging::SendMessageFlags::kNoAutoRequestAck)); - -exit: - if (err != CHIP_NO_ERROR) - { - ChipLogError(ExchangeManager, "Failed to send message counter synchronization response with error:%s", ErrorStr(err)); - } - - return err; -} - -void MessageCounterSyncMgr::HandleMsgCounterSyncReq(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, - System::PacketBufferHandle msgBuf) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - - void * challenge = nullptr; - uint8_t * req = msgBuf->Start(); - size_t reqlen = msgBuf->DataLength(); - - ChipLogDetail(ExchangeManager, "Received MsgCounterSyncReq request"); - - VerifyOrExit(packetHeader.GetSourceNodeId().HasValue(), err = CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrExit(ChipKeyId::IsAppGroupKey(packetHeader.GetEncryptionKeyID()), err = CHIP_ERROR_WRONG_KEY_TYPE); - VerifyOrExit(req != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); - VerifyOrExit(reqlen == kMsgCounterChallengeSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); - - challenge = chip::Platform::MemoryAlloc(kMsgCounterChallengeSize); - VerifyOrExit(challenge != nullptr, err = CHIP_ERROR_NO_MEMORY); - memcpy(challenge, req, kMsgCounterChallengeSize); - - // Store the 64-bit value sent in the Challenge filed of the MsgCounterSyncReq. - exchangeContext->SetAppState(challenge); - - // Respond with MsgCounterSyncResp - err = SendMsgCounterSyncResp(exchangeContext, { packetHeader.GetSourceNodeId().Value(), packetHeader.GetEncryptionKeyID(), 0 }); - -exit: - if (err != CHIP_NO_ERROR) - { - ChipLogError(ExchangeManager, "Failed to handle MsgCounterSyncReq message with error:%s", ErrorStr(err)); - } - - if (exchangeContext != nullptr) - { - chip::Platform::MemoryFree(exchangeContext->GetAppState()); - exchangeContext->SetAppState(nullptr); - exchangeContext->Close(); - } - - return; -} - -void MessageCounterSyncMgr::HandleMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, - const PacketHeader & packetHeader, System::PacketBufferHandle msgBuf) -{ - CHIP_ERROR err = CHIP_NO_ERROR; - - Transport::PeerConnectionState * state = nullptr; - NodeId peerNodeId = 0; - uint32_t syncCounter = 0; - uint8_t challenge[kMsgCounterChallengeSize]; - - const uint8_t * resp = msgBuf->Start(); - size_t resplen = msgBuf->DataLength(); - - ChipLogDetail(ExchangeManager, "Received MsgCounterSyncResp response"); - - // Find an active connection to the specified peer node - state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(exchangeContext->GetSecureSessionHandle()); - VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); - - state->SetMsgCounterSyncInProgress(false); - - VerifyOrExit(msgBuf->DataLength() == kMsgCounterSyncRespMsgSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); - VerifyOrExit(ChipKeyId::IsAppGroupKey(packetHeader.GetEncryptionKeyID()), err = CHIP_ERROR_WRONG_KEY_TYPE); - - // Store the 64-bit value sent in the Challenge filed of the MsgCounterSyncReq. - VerifyOrExit(resp != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); - VerifyOrExit(resplen == kMsgCounterSyncRespMsgSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); - - syncCounter = chip::Encoding::LittleEndian::Read32(resp); - VerifyOrExit(syncCounter != 0, err = CHIP_ERROR_READ_FAILED); - - memcpy(challenge, resp, kMsgCounterChallengeSize); - - // Verify that the response field matches the expected Challenge field for the exchange. - VerifyOrExit(memcmp(exchangeContext->GetAppState(), challenge, kMsgCounterChallengeSize) == 0, - err = CHIP_ERROR_INVALID_SIGNATURE); - - VerifyOrExit(packetHeader.GetSourceNodeId().HasValue(), err = CHIP_ERROR_INVALID_ARGUMENT); - peerNodeId = packetHeader.GetSourceNodeId().Value(); - - // Process all queued ougoing and incomming group messages after message counter synchronization is completed. - RetransPendingGroupMsgs(peerNodeId); - ProcessPendingGroupMsgs(peerNodeId); - -exit: - if (err != CHIP_NO_ERROR) - { - ChipLogError(ExchangeManager, "Failed to handle MsgCounterSyncResp message with error:%s", ErrorStr(err)); - } - - if (exchangeContext != nullptr) - { - chip::Platform::MemoryFree(exchangeContext->GetAppState()); - exchangeContext->SetAppState(nullptr); - exchangeContext->Close(); - } - - return; -} - -} // namespace Messaging -} // namespace chip diff --git a/src/messaging/MessageCounterSync.h b/src/messaging/MessageCounterSync.h deleted file mode 100644 index 732e4205be78b3..00000000000000 --- a/src/messaging/MessageCounterSync.h +++ /dev/null @@ -1,163 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * @file - * This file defines types and objects for CHIP Secure Channel protocol. - * - */ - -#pragma once - -#include -#include - -namespace chip { -namespace Messaging { - -constexpr uint16_t kMsgCounterChallengeSize = 8; // The size of the message counter synchronization request message. -constexpr uint16_t kMsgCounterSyncRespMsgSize = 12; // The size of the message counter synchronization response message. -constexpr uint32_t kMsgCounterSyncTimeout = 500; // The amount of time(in milliseconds) which a peer is given to respond - // to a message counter synchronization request. - -class ExchangeManager; - -class MessageCounterSyncMgr : public Messaging::ExchangeDelegate -{ -public: - MessageCounterSyncMgr() : mExchangeMgr(nullptr) {} - - CHIP_ERROR Init(Messaging::ExchangeManager * exchangeMgr); - void Shutdown(); - - /** - * Send peer message counter synchronization request. - * This function is called while processing a message encrypted with an application key from a peer whose message counter is not - * synchronized. This message is sent on a newly created exchange, which is closed immediately after. - * - * @param[in] session The secure session handle of the received message. - * - * @retval #CHIP_ERROR_NO_MEMORY If memory could not be allocated for the new - * exchange context or new message buffer. - * @retval #CHIP_NO_ERROR On success. - * - */ - CHIP_ERROR SendMsgCounterSyncReq(SecureSessionHandle session); - - /** - * Add a CHIP message into the cache table to queue the outgoing messages that trigger message counter synchronization protocol - * for retransmission. - * - * @param[in] protocolId The protocol identifier of the CHIP message to be sent. - * - * @param[in] msgType The message type of the corresponding protocol. - * - * @param[in] sendFlags Flags set by the application for the CHIP message being sent. - * - * @param[in] msgBuf A handle to the packet buffer holding the CHIP message. - * - * @param[in] exchangeContext A pointer to the exchange context object associated with the message being sent. - * - * @retval #CHIP_ERROR_NO_MEMORY If there is no empty slot left in the table for addition. - * @retval #CHIP_NO_ERROR On success. - */ - CHIP_ERROR AddToRetransmissionTable(Protocols::Id protocolId, uint8_t msgType, const SendFlags & sendFlags, - System::PacketBufferHandle msgBuf, Messaging::ExchangeContext * exchangeContext); - - /** - * Add a CHIP message into the cache table to queue the incoming messages that trigger message counter synchronization - * protocol for re-processing. - * - * @param[in] msgBuf A handle to the packet buffer holding the received message. - * - * @retval #CHIP_ERROR_NO_MEMORY If there is no empty slot left in the table for addition. - * @retval #CHIP_NO_ERROR On success. - */ - CHIP_ERROR AddToReceiveTable(System::PacketBufferHandle msgBuf); - -private: - /** - * @class RetransTableEntry - * - * @brief - * This class is part of the CHIP Message Counter Synchronization Protocol and is used - * to keep track of a CHIP messages to be transmitted to a destination node whose message - * counter is unknown. The message would be retransmitted from this table after message - * counter synchronization is completed. - * - */ - struct RetransTableEntry - { - RetransTableEntry() : protocolId(Protocols::NotSpecified) {} - ExchangeContext * exchangeContext; /**< The ExchangeContext for the stored CHIP message. - Non-null if and only if this entry is in use. */ - System::PacketBufferHandle msgBuf; /**< A handle to the PacketBuffer object holding the CHIP message. */ - SendFlags sendFlags; /**< Flags set by the application for the CHIP message being sent. */ - Protocols::Id protocolId; /**< The protocol identifier of the CHIP message to be sent. */ - uint8_t msgType; /**< The message type of the CHIP message to be sent. */ - }; - - /** - * @class RetransTableEntry - * - * @brief - * This class is part of the CHIP Message Counter Synchronization Protocol and is used - * to keep track of a CHIP messages to be reprocessed whose source's - * message counter is unknown. The message is reprocessed after message - * counter synchronization is completed. - * - */ - struct ReceiveTableEntry - { - System::PacketBufferHandle msgBuf; /**< A handle to the PacketBuffer object holding - the message data. This is non-null if and only - if this entry is in use. */ - }; - - Messaging::ExchangeManager * mExchangeMgr; // [READ ONLY] Associated Exchange Manager object. - - // MessageCounterSyncProtocol cache table to queue the outgoing messages that trigger message counter - // synchronization protocol. Reserve two extra exchanges, one for MCSP messages and another one for - // temporary exchange for ack. - RetransTableEntry mRetransTable[CHIP_CONFIG_MAX_EXCHANGE_CONTEXTS - 2]; - - // MessageCounterSyncProtocol cache table to queue the incoming messages that trigger message counter - // synchronization protocol. Reserve two extra exchanges, one for MCSP messages and another one for - // temporary exchange for ack. - ReceiveTableEntry mReceiveTable[CHIP_CONFIG_MAX_EXCHANGE_CONTEXTS - 2]; - - void RetransPendingGroupMsgs(NodeId peerNodeId); - - void ProcessPendingGroupMsgs(NodeId peerNodeId); - - CHIP_ERROR NewMsgCounterSyncExchange(SecureSessionHandle session, Messaging::ExchangeContext *& exchangeContext); - - CHIP_ERROR SendMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, SecureSessionHandle session); - - void HandleMsgCounterSyncReq(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, - System::PacketBufferHandle msgBuf); - - void HandleMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, - System::PacketBufferHandle msgBuf); - - void OnMessageReceived(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, - const PayloadHeader & payloadHeader, System::PacketBufferHandle payload) override; - - void OnResponseTimeout(Messaging::ExchangeContext * exchangeContext) override; -}; - -} // namespace Messaging -} // namespace chip diff --git a/src/messaging/tests/MessagingContext.cpp b/src/messaging/tests/MessagingContext.cpp index 421ba5dec0c0a8..3e67f870ce42ad 100644 --- a/src/messaging/tests/MessagingContext.cpp +++ b/src/messaging/tests/MessagingContext.cpp @@ -35,9 +35,11 @@ CHIP_ERROR MessagingContext::Init(nlTestSuite * suite, TransportMgrBase * transp chip::Transport::AdminPairingInfo * destNodeAdmin = mAdmins.AssignAdminId(mDestAdminId, GetDestinationNodeId()); VerifyOrReturnError(destNodeAdmin != nullptr, CHIP_ERROR_NO_MEMORY); - ReturnErrorOnFailure(mSecureSessionMgr.Init(GetSourceNodeId(), &GetSystemLayer(), transport, &mAdmins)); + ReturnErrorOnFailure( + mSecureSessionMgr.Init(GetSourceNodeId(), &GetSystemLayer(), transport, &mAdmins, &mMessageCounterManager)); ReturnErrorOnFailure(mExchangeManager.Init(&mSecureSessionMgr)); + ReturnErrorOnFailure(mMessageCounterManager.Init(&mExchangeManager)); ReturnErrorOnFailure(mSecureSessionMgr.NewPairing(mPeer, GetDestinationNodeId(), &mPairingLocalToPeer, SecureSession::SessionRole::kInitiator, mSrcAdminId)); @@ -53,16 +55,28 @@ CHIP_ERROR MessagingContext::Shutdown() return IOContext::Shutdown(); } +SecureSessionHandle MessagingContext::GetSessionLocalToPeer() +{ + // TODO: temporarily create a SecureSessionHandle from node id, will be fixed in PR 3602 + return { GetDestinationNodeId(), GetPeerKeyId(), GetAdminId() }; +} + +SecureSessionHandle MessagingContext::GetSessionPeerToLocal() +{ + // TODO: temporarily create a SecureSessionHandle from node id, will be fixed in PR 3602 + return { GetSourceNodeId(), GetLocalKeyId(), GetAdminId() }; +} + Messaging::ExchangeContext * MessagingContext::NewExchangeToPeer(Messaging::ExchangeDelegateBase * delegate) { // TODO: temprary create a SecureSessionHandle from node id, will be fix in PR 3602 - return mExchangeManager.NewContext({ GetDestinationNodeId(), GetPeerKeyId(), GetAdminId() }, delegate); + return mExchangeManager.NewContext(GetSessionLocalToPeer(), delegate); } Messaging::ExchangeContext * MessagingContext::NewExchangeToLocal(Messaging::ExchangeDelegateBase * delegate) { // TODO: temprary create a SecureSessionHandle from node id, will be fix in PR 3602 - return mExchangeManager.NewContext({ GetSourceNodeId(), GetLocalKeyId(), GetAdminId() }, delegate); + return mExchangeManager.NewContext(GetSessionPeerToLocal(), delegate); } } // namespace Test diff --git a/src/messaging/tests/MessagingContext.h b/src/messaging/tests/MessagingContext.h index 0b3698d068bb8d..13acc1185c36fc 100644 --- a/src/messaging/tests/MessagingContext.h +++ b/src/messaging/tests/MessagingContext.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -73,6 +74,10 @@ class MessagingContext : public IOContext SecureSessionMgr & GetSecureSessionManager() { return mSecureSessionMgr; } Messaging::ExchangeManager & GetExchangeManager() { return mExchangeManager; } + secure_channel::MessageCounterManager & GetMessageCounterManager() { return mMessageCounterManager; } + + SecureSessionHandle GetSessionLocalToPeer(); + SecureSessionHandle GetSessionPeerToLocal(); Messaging::ExchangeContext * NewExchangeToPeer(Messaging::ExchangeDelegateBase * delegate); Messaging::ExchangeContext * NewExchangeToLocal(Messaging::ExchangeDelegateBase * delegate); @@ -82,6 +87,7 @@ class MessagingContext : public IOContext private: SecureSessionMgr mSecureSessionMgr; Messaging::ExchangeManager mExchangeManager; + secure_channel::MessageCounterManager mMessageCounterManager; NodeId mSourceNodeId = 123654; NodeId mDestinationNodeId = 111222333; diff --git a/src/messaging/tests/TestExchangeMgr.cpp b/src/messaging/tests/TestExchangeMgr.cpp index 5c9ef2593d9dda..6e34e3902d7411 100644 --- a/src/messaging/tests/TestExchangeMgr.cpp +++ b/src/messaging/tests/TestExchangeMgr.cpp @@ -211,8 +211,14 @@ int Initialize(void * aContext) if (err != CHIP_NO_ERROR) return FAILURE; - err = reinterpret_cast(aContext)->Init(&sSuite, &gTransportMgr); - return (err == CHIP_NO_ERROR) ? SUCCESS : FAILURE; + auto * ctx = reinterpret_cast(aContext); + err = ctx->Init(&sSuite, &gTransportMgr); + if (err != CHIP_NO_ERROR) + { + return FAILURE; + } + + return SUCCESS; } /** diff --git a/src/messaging/tests/TestMessageCounterSyncMgr.cpp b/src/messaging/tests/TestMessageCounterSyncMgr.cpp deleted file mode 100644 index c0e74e888cbe8b..00000000000000 --- a/src/messaging/tests/TestMessageCounterSyncMgr.cpp +++ /dev/null @@ -1,432 +0,0 @@ -/* - * - * Copyright (c) 2021 Project CHIP Authors - * All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * @file - * This file implements unit tests for the MessageCounterSyncMgr implementation. - */ - -#include "TestMessagingLayer.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -#include -#include -#include -#include -#include - -namespace { - -using namespace chip; -using namespace chip::Inet; -using namespace chip::Transport; -using namespace chip::Messaging; -using namespace chip::Protocols; - -using TestContext = chip::Test::MessagingContext; - -TestContext sContext; - -constexpr NodeId kSourceNodeId = 123654; -constexpr NodeId kDestinationNodeId = 111222333; -constexpr NodeId kTestPeerGroupKeyId = 0x4000; -constexpr NodeId kTestLocalGroupKeyId = 0x5000; - -const char PAYLOAD[] = "Hello!"; - -class LoopbackTransport : public Transport::Base -{ -public: - /// Transports are required to have a constructor that takes exactly one argument - CHIP_ERROR Init(const char * unused) { return CHIP_NO_ERROR; } - - CHIP_ERROR SendMessage(const PacketHeader & header, const PeerAddress & address, System::PacketBufferHandle msgBuf) override - { - HandleMessageReceived(header, address, std::move(msgBuf)); - - return CHIP_NO_ERROR; - } - - bool CanSendToPeer(const PeerAddress & address) override { return true; } -}; - -class TestExchangeMgr : public SecureSessionMgrDelegate -{ -public: - void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, SecureSessionHandle session, - const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf, - SecureSessionMgr * mgr) override - { - NL_TEST_ASSERT(mSuite, header.GetSourceNodeId() == Optional::Value(kSourceNodeId)); - NL_TEST_ASSERT(mSuite, header.GetDestinationNodeId() == Optional::Value(kDestinationNodeId)); - NL_TEST_ASSERT(mSuite, msgBuf->DataLength() == kMsgCounterChallengeSize); - - ReceiveHandlerCallCount++; - } - - void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) override {} - - void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) override {} - - nlTestSuite * mSuite = nullptr; - int ReceiveHandlerCallCount = 0; -}; - -class TestSessMgrCallback : public SecureSessionMgrDelegate -{ -public: - void OnMessageReceived(const PacketHeader & header, const PayloadHeader & payloadHeader, SecureSessionHandle session, - const Transport::PeerAddress & source, System::PacketBufferHandle msgBuf, - SecureSessionMgr * mgr) override - {} - - void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) override - { - if (NewConnectionHandlerCallCount == 0) - { - mRemoteToLocalSession = session; - } - - if (NewConnectionHandlerCallCount == 1) - { - mLocalToRemoteSession = session; - } - NewConnectionHandlerCallCount++; - } - - void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) override {} - - CHIP_ERROR QueueReceivedMessageAndSync(Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf) override - { - PacketHeader packetHeader; - uint16_t headerSize = 0; - - CHIP_ERROR err = packetHeader.Decode(msgBuf->Start(), msgBuf->DataLength(), &headerSize); - NL_TEST_ASSERT(mSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(mSuite, ChipKeyId::IsAppGroupKey(packetHeader.GetEncryptionKeyID()) == true); - - ReceiveHandlerCallCount++; - - return CHIP_NO_ERROR; - } - - nlTestSuite * mSuite = nullptr; - SecureSessionHandle mRemoteToLocalSession; - SecureSessionHandle mLocalToRemoteSession; - int ReceiveHandlerCallCount = 0; - int NewConnectionHandlerCallCount = 0; -}; - -class MockAppDelegate : public ExchangeDelegate -{ -public: - void OnMessageReceived(ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, - System::PacketBufferHandle msgBuf) override - { - IsOnMessageReceivedCalled = true; - - NL_TEST_ASSERT(mSuite, payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq)); - NL_TEST_ASSERT(mSuite, packetHeader.GetSourceNodeId() == Optional::Value(kSourceNodeId)); - NL_TEST_ASSERT(mSuite, packetHeader.GetDestinationNodeId() == Optional::Value(kDestinationNodeId)); - NL_TEST_ASSERT(mSuite, msgBuf->DataLength() == kMsgCounterChallengeSize); - - ec->Close(); - } - - void OnResponseTimeout(ExchangeContext * ec) override {} - - nlTestSuite * mSuite = nullptr; - bool IsOnMessageReceivedCalled = false; -}; - -TransportMgr gTransportMgr; - -void CheckSendMsgCounterSyncReq(nlTestSuite * inSuite, void * inContext) -{ - TestContext & ctx = *reinterpret_cast(inContext); - - ctx.GetInetLayer().SystemLayer()->Init(nullptr); - - IPAddress addr; - IPAddress::FromString("127.0.0.1", addr); - - CHIP_ERROR err = CHIP_NO_ERROR; - TestExchangeMgr testExchangeMgr; - - testExchangeMgr.mSuite = inSuite; - ctx.GetSecureSessionManager().SetDelegate(&testExchangeMgr); - - MessageCounterSyncMgr * sm = ctx.GetExchangeManager().GetMessageCounterSyncMgr(); - NL_TEST_ASSERT(inSuite, sm != nullptr); - - Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); - - SecurePairingUsingTestSecret pairingLocalToPeer(kTestPeerGroupKeyId, kTestLocalGroupKeyId); - - err = ctx.GetSecureSessionManager().NewPairing(peer, kDestinationNodeId, &pairingLocalToPeer, - SecureSession::SessionRole::kInitiator, 0); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - SecurePairingUsingTestSecret pairingPeerToLocal(kTestLocalGroupKeyId, kTestPeerGroupKeyId); - - err = ctx.GetSecureSessionManager().NewPairing(peer, kSourceNodeId, &pairingPeerToLocal, SecureSession::SessionRole::kResponder, - 1); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - SecureSessionHandle session(kDestinationNodeId, 0x4000, 0); - - // Should be able to send a message to itself by just calling send. - testExchangeMgr.ReceiveHandlerCallCount = 0; - - err = sm->SendMsgCounterSyncReq(session); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, testExchangeMgr.ReceiveHandlerCallCount == 1); -} - -void CheckReceiveMsgCounterSyncReq(nlTestSuite * inSuite, void * inContext) -{ - TestContext & ctx = *reinterpret_cast(inContext); - - ctx.GetInetLayer().SystemLayer()->Init(nullptr); - - IPAddress addr; - IPAddress::FromString("127.0.0.1", addr); - - CHIP_ERROR err = CHIP_NO_ERROR; - MockAppDelegate mockAppDelegate; - - mockAppDelegate.mSuite = inSuite; - - MessageCounterSyncMgr * sm = ctx.GetExchangeManager().GetMessageCounterSyncMgr(); - NL_TEST_ASSERT(inSuite, sm != nullptr); - - // Register to receive unsolicited Secure Channel Request messages from the exchange manager. - err = ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq, - &mockAppDelegate); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); - - SecurePairingUsingTestSecret pairingLocalToPeer(kTestPeerGroupKeyId, kTestLocalGroupKeyId); - - err = ctx.GetSecureSessionManager().NewPairing(peer, kDestinationNodeId, &pairingLocalToPeer, - SecureSession::SessionRole::kInitiator, 0); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - SecurePairingUsingTestSecret pairingPeerToLocal(kTestLocalGroupKeyId, kTestPeerGroupKeyId); - - err = ctx.GetSecureSessionManager().NewPairing(peer, kSourceNodeId, &pairingPeerToLocal, SecureSession::SessionRole::kResponder, - 1); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - SecureSessionHandle session(kDestinationNodeId, 0x4000, 0); - - err = sm->SendMsgCounterSyncReq(session); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, mockAppDelegate.IsOnMessageReceivedCalled == true); -} - -void CheckAddRetransTable(nlTestSuite * inSuite, void * inContext) -{ - TestContext & ctx = *reinterpret_cast(inContext); - - ctx.GetInetLayer().SystemLayer()->Init(nullptr); - - MockAppDelegate mockAppDelegate; - ExchangeContext * exchange = ctx.NewExchangeToPeer(&mockAppDelegate); - NL_TEST_ASSERT(inSuite, exchange != nullptr); - - MessageCounterSyncMgr * sm = ctx.GetExchangeManager().GetMessageCounterSyncMgr(); - NL_TEST_ASSERT(inSuite, sm != nullptr); - - System::PacketBufferHandle buffer = MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); - NL_TEST_ASSERT(inSuite, !buffer.IsNull()); - - CHIP_ERROR err = - sm->AddToRetransmissionTable(Protocols::Echo::Id, static_cast(Protocols::Echo::MsgType::EchoRequest), - Messaging::SendFlags(Messaging::SendMessageFlags::kNone), std::move(buffer), exchange); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); -} - -void CheckAddToReceiveTable(nlTestSuite * inSuite, void * inContext) -{ - TestContext & ctx = *reinterpret_cast(inContext); - - ctx.GetInetLayer().SystemLayer()->Init(nullptr); - - MessageCounterSyncMgr * sm = ctx.GetExchangeManager().GetMessageCounterSyncMgr(); - NL_TEST_ASSERT(inSuite, sm != nullptr); - - System::PacketBufferHandle buffer = MessagePacketBuffer::NewWithData(PAYLOAD, sizeof(PAYLOAD)); - NL_TEST_ASSERT(inSuite, !buffer.IsNull()); - - CHIP_ERROR err = sm->AddToReceiveTable(std::move(buffer)); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); -} - -TestSessMgrCallback callback; - -void CheckReceiveMessage(nlTestSuite * inSuite, void * inContext) -{ - TestContext & ctx = *reinterpret_cast(inContext); - - uint16_t payload_len = sizeof(PAYLOAD); - - ctx.GetInetLayer().SystemLayer()->Init(nullptr); - - chip::System::PacketBufferHandle buffer = chip::MessagePacketBuffer::NewWithData(PAYLOAD, payload_len); - NL_TEST_ASSERT(inSuite, !buffer.IsNull()); - - IPAddress addr; - IPAddress::FromString("127.0.0.1", addr); - CHIP_ERROR err = CHIP_NO_ERROR; - - TransportMgr transportMgr; - SecureSessionMgr secureSessionMgr; - - err = transportMgr.Init("LOOPBACK"); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - Transport::AdminPairingTable admins; - err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - callback.mSuite = inSuite; - - secureSessionMgr.SetDelegate(&callback); - - Optional peer(Transport::PeerAddress::UDP(addr, CHIP_PORT)); - - Transport::AdminPairingInfo * admin = admins.AssignAdminId(0, kSourceNodeId); - NL_TEST_ASSERT(inSuite, admin != nullptr); - - admin = admins.AssignAdminId(1, kDestinationNodeId); - NL_TEST_ASSERT(inSuite, admin != nullptr); - - SecurePairingUsingTestSecret pairingPeerToLocal(kTestLocalGroupKeyId, kTestPeerGroupKeyId); - - err = secureSessionMgr.NewPairing(peer, kSourceNodeId, &pairingPeerToLocal, SecureSession::SessionRole::kInitiator, 1); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - SecurePairingUsingTestSecret pairingLocalToPeer(kTestPeerGroupKeyId, kTestLocalGroupKeyId); - err = secureSessionMgr.NewPairing(peer, kDestinationNodeId, &pairingLocalToPeer, SecureSession::SessionRole::kResponder, 0); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - SecureSessionHandle localToRemoteSession = callback.mLocalToRemoteSession; - - // Should be able to send a message to itself by just calling send. - callback.ReceiveHandlerCallCount = 0; - - PayloadHeader payloadHeader; - - // Set the exchange ID for this header. - payloadHeader.SetExchangeID(0); - - // Set the protocol ID and message type for this header. - payloadHeader.SetMessageType(chip::Protocols::Echo::MsgType::EchoRequest); - - err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer)); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 0; }); - - NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); -} - -// Test Suite - -/** - * Test Suite that lists all the test functions. - */ -// clang-format off -const nlTest sTests[] = -{ - NL_TEST_DEF("Test MessageCounterSyncMgr::ReceiveMsgCounterSyncReq", CheckReceiveMsgCounterSyncReq), - NL_TEST_DEF("Test MessageCounterSyncMgr::SendMsgCounterSyncReq", CheckSendMsgCounterSyncReq), - NL_TEST_DEF("Test MessageCounterSyncMgr::AddToRetransTable", CheckAddRetransTable), - NL_TEST_DEF("Test MessageCounterSyncMgr::AddToReceiveTable", CheckAddToReceiveTable), - NL_TEST_DEF("Test MessageCounterSyncMgr::ReceiveMessage", CheckReceiveMessage), - NL_TEST_SENTINEL() -}; -// clang-format on - -int Initialize(void * aContext); -int Finalize(void * aContext); - -// clang-format off -nlTestSuite sSuite = -{ - "Test-MessageCounterSyncMgr", - &sTests[0], - Initialize, - Finalize -}; -// clang-format on - -/** - * Initialize the test suite. - */ -int Initialize(void * aContext) -{ - CHIP_ERROR err = chip::Platform::MemoryInit(); - if (err != CHIP_NO_ERROR) - return FAILURE; - - err = gTransportMgr.Init("LOOPBACK"); - if (err != CHIP_NO_ERROR) - return FAILURE; - - err = reinterpret_cast(aContext)->Init(&sSuite, &gTransportMgr); - return (err == CHIP_NO_ERROR) ? SUCCESS : FAILURE; -} - -/** - * Finalize the test suite. - */ -int Finalize(void * aContext) -{ - CHIP_ERROR err = reinterpret_cast(aContext)->Shutdown(); - return (err == CHIP_NO_ERROR) ? SUCCESS : FAILURE; -} - -} // namespace - -/** - * Main - */ -int TestMessageCounterSyncMgr() -{ - // Run test suit against one context - nlTestRunner(&sSuite, &sContext); - - return (nlTestRunnerStats(&sSuite)); -} diff --git a/src/messaging/tests/TestMessagingLayer.h b/src/messaging/tests/TestMessagingLayer.h index fbab413701831b..7bdcef314c4c0b 100644 --- a/src/messaging/tests/TestMessagingLayer.h +++ b/src/messaging/tests/TestMessagingLayer.h @@ -29,7 +29,6 @@ extern "C" { #endif int TestExchangeMgr(void); -int TestMessageCounterSyncMgr(void); int TestReliableMessageProtocol(void); #ifdef __cplusplus diff --git a/src/messaging/tests/TestReliableMessageProtocol.cpp b/src/messaging/tests/TestReliableMessageProtocol.cpp index 3477a66f75fd8f..0ccca4c76e8ede 100644 --- a/src/messaging/tests/TestReliableMessageProtocol.cpp +++ b/src/messaging/tests/TestReliableMessageProtocol.cpp @@ -261,8 +261,14 @@ int Initialize(void * aContext) if (err != CHIP_NO_ERROR) return FAILURE; - err = reinterpret_cast(aContext)->Init(&sSuite, &gTransportMgr); - return (err == CHIP_NO_ERROR) ? SUCCESS : FAILURE; + auto * ctx = reinterpret_cast(aContext); + err = ctx->Init(&sSuite, &gTransportMgr); + if (err != CHIP_NO_ERROR) + { + return FAILURE; + } + + return SUCCESS; } /** diff --git a/src/messaging/tests/echo/common.cpp b/src/messaging/tests/echo/common.cpp index dc5dd524d321fc..280685f231749d 100644 --- a/src/messaging/tests/echo/common.cpp +++ b/src/messaging/tests/echo/common.cpp @@ -27,10 +27,12 @@ #include "common.h" #include #include +#include #include // The ExchangeManager global object. chip::Messaging::ExchangeManager gExchangeManager; +chip::secure_channel::MessageCounterManager gMessageCounterManager; void InitializeChip(void) { diff --git a/src/messaging/tests/echo/common.h b/src/messaging/tests/echo/common.h index 0c5c1ff1e4818e..cb38a93644d7a6 100644 --- a/src/messaging/tests/echo/common.h +++ b/src/messaging/tests/echo/common.h @@ -25,12 +25,14 @@ #pragma once #include +#include constexpr size_t kMaxTcpActiveConnectionCount = 4; constexpr size_t kMaxTcpPendingPackets = 4; constexpr size_t kNetworkSleepTimeMsecs = (100 * 1000); extern chip::Messaging::ExchangeManager gExchangeManager; +extern chip::secure_channel::MessageCounterManager gMessageCounterManager; void InitializeChip(void); void ShutdownChip(void); diff --git a/src/messaging/tests/echo/echo_requester.cpp b/src/messaging/tests/echo/echo_requester.cpp index 5e0267af8871fc..3dd1ce0a9f512f 100644 --- a/src/messaging/tests/echo/echo_requester.cpp +++ b/src/messaging/tests/echo/echo_requester.cpp @@ -241,7 +241,8 @@ int main(int argc, char * argv[]) .SetListenPort(ECHO_CLIENT_PORT)); SuccessOrExit(err); - err = gSessionManager.Init(chip::kTestControllerNodeId, &chip::DeviceLayer::SystemLayer, &gTCPManager, &admins); + err = gSessionManager.Init(chip::kTestControllerNodeId, &chip::DeviceLayer::SystemLayer, &gTCPManager, &admins, + &gMessageCounterManager); SuccessOrExit(err); } else @@ -251,13 +252,17 @@ int main(int argc, char * argv[]) .SetListenPort(ECHO_CLIENT_PORT)); SuccessOrExit(err); - err = gSessionManager.Init(chip::kTestControllerNodeId, &chip::DeviceLayer::SystemLayer, &gUDPManager, &admins); + err = gSessionManager.Init(chip::kTestControllerNodeId, &chip::DeviceLayer::SystemLayer, &gUDPManager, &admins, + &gMessageCounterManager); SuccessOrExit(err); } err = gExchangeManager.Init(&gSessionManager); SuccessOrExit(err); + err = gMessageCounterManager.Init(&gExchangeManager); + SuccessOrExit(err); + // Start the CHIP connection to the CHIP echo responder. err = EstablishSecureSession(); SuccessOrExit(err); diff --git a/src/messaging/tests/echo/echo_responder.cpp b/src/messaging/tests/echo/echo_responder.cpp index 71822f45160c35..32d03ad0e56c4e 100644 --- a/src/messaging/tests/echo/echo_responder.cpp +++ b/src/messaging/tests/echo/echo_responder.cpp @@ -95,7 +95,8 @@ int main(int argc, char * argv[]) chip::Transport::TcpListenParameters(&chip::DeviceLayer::InetLayer).SetAddressType(chip::Inet::kIPAddressType_IPv4)); SuccessOrExit(err); - err = gSessionManager.Init(chip::kTestDeviceNodeId, &chip::DeviceLayer::SystemLayer, &gTCPManager, &admins); + err = gSessionManager.Init(chip::kTestDeviceNodeId, &chip::DeviceLayer::SystemLayer, &gTCPManager, &admins, + &gMessageCounterManager); SuccessOrExit(err); } else @@ -104,13 +105,17 @@ int main(int argc, char * argv[]) chip::Transport::UdpListenParameters(&chip::DeviceLayer::InetLayer).SetAddressType(chip::Inet::kIPAddressType_IPv4)); SuccessOrExit(err); - err = gSessionManager.Init(chip::kTestDeviceNodeId, &chip::DeviceLayer::SystemLayer, &gUDPManager, &admins); + err = gSessionManager.Init(chip::kTestDeviceNodeId, &chip::DeviceLayer::SystemLayer, &gUDPManager, &admins, + &gMessageCounterManager); SuccessOrExit(err); } err = gExchangeManager.Init(&gSessionManager); SuccessOrExit(err); + err = gMessageCounterManager.Init(&gExchangeManager); + SuccessOrExit(err); + if (!disableEcho) { err = gEchoServer.Init(&gExchangeManager); diff --git a/src/platform/EFR32/CHIPPlatformConfig.h b/src/platform/EFR32/CHIPPlatformConfig.h index 23114fb2bc8b69..a3206d1fbf27c8 100644 --- a/src/platform/EFR32/CHIPPlatformConfig.h +++ b/src/platform/EFR32/CHIPPlatformConfig.h @@ -45,6 +45,7 @@ #define CHIP_CONFIG_PERSISTED_STORAGE_MAX_KEY_LENGTH 2 #define CHIP_CONFIG_LIFETIIME_PERSISTED_COUNTER_KEY 0x01 +#define CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER 0x2 #define CHIP_CONFIG_TIME_ENABLE_CLIENT 1 #define CHIP_CONFIG_TIME_ENABLE_SERVER 0 diff --git a/src/platform/K32W/CHIPPlatformConfig.h b/src/platform/K32W/CHIPPlatformConfig.h index 61d8985c30c8ee..9270eca22e71a9 100644 --- a/src/platform/K32W/CHIPPlatformConfig.h +++ b/src/platform/K32W/CHIPPlatformConfig.h @@ -46,6 +46,7 @@ #define CHIP_CONFIG_PERSISTED_STORAGE_MAX_KEY_LENGTH 2 #define CHIP_CONFIG_LIFETIIME_PERSISTED_COUNTER_KEY 0x01 +#define CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER 0x2 #define CHIP_CONFIG_TIME_ENABLE_CLIENT 1 #define CHIP_CONFIG_TIME_ENABLE_SERVER 0 diff --git a/src/platform/cc13x2_26x2/CHIPPlatformConfig.h b/src/platform/cc13x2_26x2/CHIPPlatformConfig.h index a94e1d5ec6446d..5444e51b9db235 100644 --- a/src/platform/cc13x2_26x2/CHIPPlatformConfig.h +++ b/src/platform/cc13x2_26x2/CHIPPlatformConfig.h @@ -43,6 +43,7 @@ #define CHIP_CONFIG_PERSISTED_STORAGE_MAX_KEY_LENGTH 2 #define CHIP_CONFIG_LIFETIIME_PERSISTED_COUNTER_KEY 0x01 +#define CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER 0x2 // ==================== Security Adaptations ==================== diff --git a/src/platform/nrfconnect/CHIPPlatformConfig.h b/src/platform/nrfconnect/CHIPPlatformConfig.h index d5eb3ceb7af004..9b9639c4b22f08 100644 --- a/src/platform/nrfconnect/CHIPPlatformConfig.h +++ b/src/platform/nrfconnect/CHIPPlatformConfig.h @@ -31,6 +31,7 @@ #define CHIP_CONFIG_PERSISTED_STORAGE_MAX_KEY_LENGTH 2 #define CHIP_CONFIG_LIFETIIME_PERSISTED_COUNTER_KEY "rc" +#define CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER "mc" #define CHIP_CONFIG_TIME_ENABLE_CLIENT 1 #define CHIP_CONFIG_TIME_ENABLE_SERVER 0 diff --git a/src/platform/qpg6100/CHIPPlatformConfig.h b/src/platform/qpg6100/CHIPPlatformConfig.h index 64edce10a052f6..15e894cbab8d98 100644 --- a/src/platform/qpg6100/CHIPPlatformConfig.h +++ b/src/platform/qpg6100/CHIPPlatformConfig.h @@ -46,6 +46,7 @@ #define CHIP_CONFIG_PERSISTED_STORAGE_MAX_KEY_LENGTH 2 #define CHIP_CONFIG_LIFETIIME_PERSISTED_COUNTER_KEY 0x01 +#define CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER 0x2 #define CHIP_CONFIG_TIME_ENABLE_CLIENT 1 #define CHIP_CONFIG_TIME_ENABLE_SERVER 0 diff --git a/src/protocols/BUILD.gn b/src/protocols/BUILD.gn index 832ca6ab1e9097..19f363c2e233ee 100644 --- a/src/protocols/BUILD.gn +++ b/src/protocols/BUILD.gn @@ -21,6 +21,8 @@ static_library("protocols") { "echo/Echo.h", "echo/EchoClient.cpp", "echo/EchoServer.cpp", + "secure_channel/MessageCounterManager.cpp", + "secure_channel/MessageCounterManager.h", ] cflags = [ "-Wconversion" ] diff --git a/src/protocols/secure_channel/MessageCounterManager.cpp b/src/protocols/secure_channel/MessageCounterManager.cpp new file mode 100644 index 00000000000000..ca8195153fa4da --- /dev/null +++ b/src/protocols/secure_channel/MessageCounterManager.cpp @@ -0,0 +1,327 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file implements the CHIP message counter messages in secure channel protocol. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace chip { +namespace secure_channel { + +CHIP_ERROR MessageCounterManager::Init(Messaging::ExchangeManager * exchangeMgr) +{ + VerifyOrReturnError(exchangeMgr != nullptr, CHIP_ERROR_INCORRECT_STATE); + mExchangeMgr = exchangeMgr; + + ReturnErrorOnFailure( + mExchangeMgr->RegisterUnsolicitedMessageHandlerForType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq, this)); + + return CHIP_NO_ERROR; +} + +void MessageCounterManager::Shutdown() +{ + if (mExchangeMgr != nullptr) + { + mExchangeMgr->UnregisterUnsolicitedMessageHandlerForType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq); + mExchangeMgr = nullptr; + } +} + +CHIP_ERROR MessageCounterManager::StartSync(SecureSessionHandle session, Transport::PeerConnectionState * state) +{ + // Initiate message counter synchronization if no message counter synchronization is in progress. + Transport::PeerMessageCounter & counter = state->GetSessionMessageCounter().GetPeerMessageCounter(); + if (!counter.IsSynchronizing() && !counter.IsSynchronized()) + { + ReturnErrorOnFailure(SendMsgCounterSyncReq(session, state)); + } + + return CHIP_NO_ERROR; +} + +CHIP_ERROR MessageCounterManager::QueueReceivedMessageAndStartSync(SecureSessionHandle session, + Transport::PeerConnectionState * state, + const Transport::PeerAddress & peerAddress, + System::PacketBufferHandle msgBuf) +{ + // Queue the message to be reprocessed when sync completes. + ReturnErrorOnFailure(AddToReceiveTable(state->GetPeerNodeId(), peerAddress, std::move(msgBuf))); + ReturnErrorOnFailure(StartSync(session, state)); + + // After the message that triggers message counter synchronization is stored, and a message counter + // synchronization exchange is initiated, we need to return immediately and re-process the original message + // when the synchronization is completed. + + return CHIP_NO_ERROR; +} + +void MessageCounterManager::OnMessageReceived(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, + const PayloadHeader & payloadHeader, System::PacketBufferHandle msgBuf) +{ + if (payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq)) + { + HandleMsgCounterSyncReq(exchangeContext, packetHeader, std::move(msgBuf)); + } + else if (payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp)) + { + HandleMsgCounterSyncResp(exchangeContext, packetHeader, std::move(msgBuf)); + } +} + +void MessageCounterManager::OnResponseTimeout(Messaging::ExchangeContext * exchangeContext) +{ + Transport::PeerConnectionState * state = + mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(exchangeContext->GetSecureSession()); + + if (state != nullptr) + { + state->GetSessionMessageCounter().GetPeerMessageCounter().SyncFailed(); + } + else + { + ChipLogError(SecureChannel, "Timed out! Failed to clear message counter synchronization status."); + } + + exchangeContext->Close(); +} + +CHIP_ERROR MessageCounterManager::AddToReceiveTable(NodeId peerNodeId, const Transport::PeerAddress & peerAddress, + System::PacketBufferHandle msgBuf) +{ + bool added = false; + CHIP_ERROR err = CHIP_NO_ERROR; + + for (ReceiveTableEntry & entry : mReceiveTable) + { + if (entry.peerNodeId == kUndefinedNodeId) + { + entry.peerNodeId = peerNodeId; + entry.peerAddress = peerAddress; + entry.msgBuf = std::move(msgBuf); + added = true; + + break; + } + } + + if (!added) + { + ChipLogError(SecureChannel, "MCSP ReceiveTable Already Full"); + err = CHIP_ERROR_NO_MEMORY; + } + + return err; +} + +/** + * Reprocess all pending messages that were encrypted with application + * group key and were addressed to the specified node id. + * + * @param[in] peerNodeId Node ID of the destination node. + * + */ +void MessageCounterManager::ProcessPendingMessages(NodeId peerNodeId) +{ + auto * secureSessionMgr = mExchangeMgr->GetSessionMgr(); + + // Find all receive entries matching peerNodeId. Note that everything in + // this table was using an application group key; that's why it was added. + for (ReceiveTableEntry & entry : mReceiveTable) + { + if (entry.peerNodeId == peerNodeId) + { + // Reprocess message. + secureSessionMgr->OnMessageReceived(entry.peerAddress, std::move(entry.msgBuf)); + + // Explicitly free any buffer owned by this handle. + entry.msgBuf = nullptr; + entry.peerNodeId = kUndefinedNodeId; + } + } +} + +CHIP_ERROR MessageCounterManager::SendMsgCounterSyncReq(SecureSessionHandle session, Transport::PeerConnectionState * state) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + + Messaging::ExchangeContext * exchangeContext = nullptr; + System::PacketBufferHandle msgBuf; + Messaging::SendFlags sendFlags; + + exchangeContext = mExchangeMgr->NewContext(session, this); + VerifyOrExit(exchangeContext != nullptr, err = CHIP_ERROR_NO_MEMORY); + + msgBuf = MessagePacketBuffer::New(kChallengeSize); + VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_NO_MEMORY); + + // Generate a 64-bit random number to uniquely identify the request. + SuccessOrExit(err = Crypto::DRBG_get_bytes(msgBuf->Start(), kChallengeSize)); + + msgBuf->SetDataLength(kChallengeSize); + + // Store generated Challenge value to message counter context to resolve synchronization response. + state->GetSessionMessageCounter().GetPeerMessageCounter().SyncStarting(FixedByteSpan(msgBuf->Start())); + + sendFlags.Set(Messaging::SendMessageFlags::kNoAutoRequestAck).Set(Messaging::SendMessageFlags::kExpectResponse); + + // Arm a timer to enforce that a MsgCounterSyncRsp is received before kSyncTimeoutMs. + exchangeContext->SetResponseTimeout(kSyncTimeoutMs); + + // Send the message counter synchronization request in a Secure Channel Protocol::MsgCounterSyncReq message. + SuccessOrExit( + err = exchangeContext->SendMessage(Protocols::SecureChannel::MsgType::MsgCounterSyncReq, std::move(msgBuf), sendFlags)); + +exit: + if (err != CHIP_NO_ERROR) + { + state->GetSessionMessageCounter().GetPeerMessageCounter().SyncFailed(); + ChipLogError(SecureChannel, "Failed to send message counter synchronization request with error:%s", ErrorStr(err)); + } + + return err; +} + +CHIP_ERROR MessageCounterManager::SendMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, + FixedByteSpan challenge) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + Transport::PeerConnectionState * state = nullptr; + System::PacketBufferHandle msgBuf; + uint8_t * msg = nullptr; + + state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(exchangeContext->GetSecureSession()); + VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); + + // Allocate new buffer. + msgBuf = MessagePacketBuffer::New(kSyncRespMsgSize); + VerifyOrExit(!msgBuf.IsNull(), err = CHIP_ERROR_NO_MEMORY); + + msg = msgBuf->Start(); + + { + Encoding::LittleEndian::BufferWriter bbuf(msg, kSyncRespMsgSize); + bbuf.Put32(state->GetSessionMessageCounter().GetLocalMessageCounter().Value()); + bbuf.Put(challenge.data(), kChallengeSize); + VerifyOrExit(bbuf.Fit(), err = CHIP_ERROR_NO_MEMORY); + } + + msgBuf->SetDataLength(kSyncRespMsgSize); + + err = exchangeContext->SendMessage(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp, std::move(msgBuf), + Messaging::SendFlags(Messaging::SendMessageFlags::kNoAutoRequestAck)); + +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Failed to send message counter synchronization response with error:%s", ErrorStr(err)); + } + + return err; +} + +void MessageCounterManager::HandleMsgCounterSyncReq(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, + System::PacketBufferHandle msgBuf) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + + uint8_t * req = msgBuf->Start(); + size_t reqlen = msgBuf->DataLength(); + + ChipLogDetail(SecureChannel, "Received MsgCounterSyncReq request"); + + VerifyOrExit(packetHeader.GetSourceNodeId().HasValue(), err = CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrExit(req != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); + VerifyOrExit(reqlen == kChallengeSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + + // Respond with MsgCounterSyncResp + err = SendMsgCounterSyncResp(exchangeContext, FixedByteSpan(req)); + +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Failed to handle MsgCounterSyncReq message with error:%s", ErrorStr(err)); + } + + exchangeContext->Close(); + return; +} + +void MessageCounterManager::HandleMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, + const PacketHeader & packetHeader, System::PacketBufferHandle msgBuf) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + + Transport::PeerConnectionState * state = nullptr; + NodeId peerNodeId = 0; + uint32_t syncCounter = 0; + + const uint8_t * resp = msgBuf->Start(); + size_t resplen = msgBuf->DataLength(); + + ChipLogDetail(SecureChannel, "Received MsgCounterSyncResp response"); + + // Find an active connection to the specified peer node + state = mExchangeMgr->GetSessionMgr()->GetPeerConnectionState(exchangeContext->GetSecureSession()); + VerifyOrExit(state != nullptr, err = CHIP_ERROR_NOT_CONNECTED); + + VerifyOrExit(msgBuf->DataLength() == kSyncRespMsgSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + + VerifyOrExit(resp != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); + VerifyOrExit(resplen == kSyncRespMsgSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + + syncCounter = chip::Encoding::LittleEndian::Read32(resp); + VerifyOrExit(syncCounter != 0, err = CHIP_ERROR_READ_FAILED); + + // Verify that the response field matches the expected Challenge field for the exchange. + err = + state->GetSessionMessageCounter().GetPeerMessageCounter().VerifyChallenge(syncCounter, FixedByteSpan(resp)); + SuccessOrExit(err); + + VerifyOrExit(packetHeader.GetSourceNodeId().HasValue(), err = CHIP_ERROR_INVALID_ARGUMENT); + peerNodeId = packetHeader.GetSourceNodeId().Value(); + + // Process all queued incoming messages after message counter synchronization is completed. + ProcessPendingMessages(peerNodeId); + +exit: + if (err != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Failed to handle MsgCounterSyncResp message with error:%s", ErrorStr(err)); + } + + exchangeContext->Close(); + return; +} + +} // namespace secure_channel +} // namespace chip diff --git a/src/protocols/secure_channel/MessageCounterManager.h b/src/protocols/secure_channel/MessageCounterManager.h new file mode 100644 index 00000000000000..bb36d926ad622a --- /dev/null +++ b/src/protocols/secure_channel/MessageCounterManager.h @@ -0,0 +1,122 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file + * This file defines types and objects for CHIP message counter messages in secure channel protocol. + * + */ + +#pragma once + +#include +#include +#include + +namespace chip { +namespace secure_channel { + +class ExchangeManager; + +class MessageCounterManager : public Messaging::ExchangeDelegate, public Transport::MessageCounterManagerInterface +{ +public: + static constexpr uint16_t kChallengeSize = Transport::PeerMessageCounter::kChallengeSize; + static constexpr uint16_t kCounterSize = 4; + static constexpr uint16_t kSyncRespMsgSize = kChallengeSize + kCounterSize; + static constexpr uint32_t kSyncTimeoutMs = 500; + + MessageCounterManager() : mExchangeMgr(nullptr) {} + ~MessageCounterManager() override {} + + CHIP_ERROR Init(Messaging::ExchangeManager * exchangeMgr); + void Shutdown(); + + // Implement MessageCounterManagerInterface + CHIP_ERROR StartSync(SecureSessionHandle session, Transport::PeerConnectionState * state) override; + CHIP_ERROR QueueReceivedMessageAndStartSync(SecureSessionHandle session, Transport::PeerConnectionState * state, + const Transport::PeerAddress & peerAddress, + System::PacketBufferHandle msgBuf) override; + + /** + * Send peer message counter synchronization request. + * This function is called while processing a message encrypted with an application key from a peer whose message counter is not + * synchronized. This message is sent on a newly created exchange, which is closed immediately after. + * + * @param[in] session The secure session handle of the received message. + * + * @retval #CHIP_ERROR_NO_MEMORY If memory could not be allocated for the new + * exchange context or new message buffer. + * @retval #CHIP_NO_ERROR On success. + * + */ + CHIP_ERROR SendMsgCounterSyncReq(SecureSessionHandle session, Transport::PeerConnectionState * state); + + /** + * Add a CHIP message into the cache table to queue the incoming messages that trigger message counter synchronization + * protocol for re-processing. + * + * @param[in] msgBuf A handle to the packet buffer holding the received message. + * + * @retval #CHIP_ERROR_NO_MEMORY If there is no empty slot left in the table for addition. + * @retval #CHIP_NO_ERROR On success. + */ + CHIP_ERROR AddToReceiveTable(NodeId peerNodeId, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle msgBuf); + +private: + /** + * @class ReceiveTableEntry + * + * @brief + * This class is part of the CHIP Message Counter Synchronization Protocol and is used + * to keep track of a CHIP messages to be reprocessed whose source's + * message counter is unknown. The message is reprocessed after message + * counter synchronization is completed. + * + */ + struct ReceiveTableEntry + { + ReceiveTableEntry() : peerNodeId(kUndefinedNodeId) {} + + // TODO(#6340): peerNodeId may not needed if we can extract it from msgBuf + NodeId peerNodeId; /**< The peerNodeId of the message. kUndefinedNodeId if is not in use. */ + Transport::PeerAddress peerAddress; /**< The peer address for the message*/ + System::PacketBufferHandle msgBuf; /**< A handle to the PacketBuffer object holding the message data. */ + }; + + Messaging::ExchangeManager * mExchangeMgr; // [READ ONLY] Associated Exchange Manager object. + + // MessageCounterManager cache table to queue the incoming messages that trigger message counter synchronization protocol. + ReceiveTableEntry mReceiveTable[CHIP_CONFIG_MCSP_RECEIVE_TABLE_SIZE]; + + void ProcessPendingMessages(NodeId peerNodeId); + + CHIP_ERROR SendMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, FixedByteSpan challenge); + + void HandleMsgCounterSyncReq(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, + System::PacketBufferHandle msgBuf); + + void HandleMsgCounterSyncResp(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, + System::PacketBufferHandle msgBuf); + + void OnMessageReceived(Messaging::ExchangeContext * exchangeContext, const PacketHeader & packetHeader, + const PayloadHeader & payloadHeader, System::PacketBufferHandle payload) override; + + void OnResponseTimeout(Messaging::ExchangeContext * exchangeContext) override; +}; + +} // namespace secure_channel +} // namespace chip diff --git a/src/protocols/secure_channel/tests/BUILD.gn b/src/protocols/secure_channel/tests/BUILD.gn index ab3447c4b1da09..6bbcfb328b9def 100644 --- a/src/protocols/secure_channel/tests/BUILD.gn +++ b/src/protocols/secure_channel/tests/BUILD.gn @@ -10,6 +10,7 @@ chip_test_suite("tests") { test_sources = [ "TestCASESession.cpp", + "TestMessageCounterManager.cpp", "TestPASESession.cpp", "TestStatusReport.cpp", ] @@ -19,6 +20,7 @@ chip_test_suite("tests") { "${chip_root}/src/lib/core", "${chip_root}/src/lib/support", "${chip_root}/src/messaging/tests:helpers", + "${chip_root}/src/protocols", "${chip_root}/src/protocols/secure_channel", "${nlio_root}:nlio", "${nlunit_test_root}:nlunit-test", diff --git a/src/protocols/secure_channel/tests/TestMessageCounterManager.cpp b/src/protocols/secure_channel/tests/TestMessageCounterManager.cpp new file mode 100644 index 00000000000000..b50a293e7d6c89 --- /dev/null +++ b/src/protocols/secure_channel/tests/TestMessageCounterManager.cpp @@ -0,0 +1,208 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file implements unit tests for the MessageCounterManager implementation. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace { + +using namespace chip; +using namespace chip::Inet; +using namespace chip::Transport; +using namespace chip::Messaging; +using namespace chip::Protocols; + +using TestContext = chip::Test::MessagingContext; + +TestContext sContext; + +class LoopbackTransport : public Transport::Base +{ +public: + /// Transports are required to have a constructor that takes exactly one argument + CHIP_ERROR Init(const char * unused) { return CHIP_NO_ERROR; } + + CHIP_ERROR SendMessage(const PeerAddress & address, System::PacketBufferHandle msgBuf) override + { + HandleMessageReceived(address, std::move(msgBuf)); + return CHIP_NO_ERROR; + } + + bool CanSendToPeer(const PeerAddress & address) override { return true; } +}; + +TransportMgr gTransportMgr; + +const char PAYLOAD[] = "Hello!"; + +class MockAppDelegate : public ExchangeDelegate +{ +public: + void OnMessageReceived(ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, + System::PacketBufferHandle msgBuf) override + { + ++ReceiveHandlerCallCount; + ec->Close(); + } + + void OnResponseTimeout(ExchangeContext * ec) override {} + + int ReceiveHandlerCallCount = 0; +}; + +void MessageCounterSyncProcess(nlTestSuite * inSuite, void * inContext) +{ + TestContext & ctx = *reinterpret_cast(inContext); + + CHIP_ERROR err = CHIP_NO_ERROR; + + SecureSessionHandle localSession = ctx.GetSessionLocalToPeer(); + SecureSessionHandle peerSession = ctx.GetSessionPeerToLocal(); + + Transport::PeerConnectionState * localState = ctx.GetSecureSessionManager().GetPeerConnectionState(localSession); + Transport::PeerConnectionState * peerState = ctx.GetSecureSessionManager().GetPeerConnectionState(peerSession); + + localState->GetSessionMessageCounter().GetPeerMessageCounter().Reset(); + err = ctx.GetMessageCounterManager().SendMsgCounterSyncReq(localSession, localState); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + MessageCounter & peerCounter = peerState->GetSessionMessageCounter().GetLocalMessageCounter(); + PeerMessageCounter & localCounter = localState->GetSessionMessageCounter().GetPeerMessageCounter(); + NL_TEST_ASSERT(inSuite, localCounter.IsSynchronized()); + NL_TEST_ASSERT(inSuite, localCounter.GetCounter() == peerCounter.Value()); +} + +void CheckReceiveMessage(nlTestSuite * inSuite, void * inContext) +{ + TestContext & ctx = *reinterpret_cast(inContext); + CHIP_ERROR err = CHIP_NO_ERROR; + + SecureSessionHandle peerSession = ctx.GetSessionPeerToLocal(); + Transport::PeerConnectionState * peerState = ctx.GetSecureSessionManager().GetPeerConnectionState(peerSession); + peerState->GetSessionMessageCounter().GetPeerMessageCounter().Reset(); + + MockAppDelegate callback; + ctx.GetExchangeManager().RegisterUnsolicitedMessageHandlerForType(chip::Protocols::Echo::MsgType::EchoRequest, &callback); + + uint16_t payload_len = sizeof(PAYLOAD); + System::PacketBufferHandle msgBuf = MessagePacketBuffer::NewWithData(PAYLOAD, payload_len); + NL_TEST_ASSERT(inSuite, !msgBuf.IsNull()); + + Messaging::ExchangeContext * ec = ctx.NewExchangeToPeer(nullptr); + NL_TEST_ASSERT(inSuite, ec != nullptr); + + err = ec->SendMessage(chip::Protocols::Echo::MsgType::EchoRequest, std::move(msgBuf), + Messaging::SendFlags{ Messaging::SendMessageFlags::kNoAutoRequestAck }); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, peerState->GetSessionMessageCounter().GetPeerMessageCounter().IsSynchronized()); + NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); + + ec->Close(); +} + +// Test Suite + +/** + * Test Suite that lists all the test functions. + */ +// clang-format off +const nlTest sTests[] = +{ + NL_TEST_DEF("Test MessageCounterManager::MessageCounterSyncProcess", MessageCounterSyncProcess), + NL_TEST_DEF("Test MessageCounterManager::ReceiveMessage", CheckReceiveMessage), + NL_TEST_SENTINEL() +}; +// clang-format on + +int Initialize(void * aContext); +int Finalize(void * aContext); + +// clang-format off +nlTestSuite sSuite = +{ + "Test-MessageCounterManager", + &sTests[0], + Initialize, + Finalize +}; +// clang-format on + +/** + * Initialize the test suite. + */ +int Initialize(void * aContext) +{ + CHIP_ERROR err = chip::Platform::MemoryInit(); + if (err != CHIP_NO_ERROR) + return FAILURE; + auto * ctx = reinterpret_cast(aContext); + + err = gTransportMgr.Init("LOOPBACK"); + if (err != CHIP_NO_ERROR) + return FAILURE; + + err = ctx->Init(&sSuite, &gTransportMgr); + if (err != CHIP_NO_ERROR) + return FAILURE; + + return SUCCESS; +} + +/** + * Finalize the test suite. + */ +int Finalize(void * aContext) +{ + CHIP_ERROR err = reinterpret_cast(aContext)->Shutdown(); + return (err == CHIP_NO_ERROR) ? SUCCESS : FAILURE; +} + +} // namespace + +/** + * Main + */ +int TestMessageCounterManager() +{ + // Run test suit against one context + nlTestRunner(&sSuite, &sContext); + + return (nlTestRunnerStats(&sSuite)); +} + +CHIP_REGISTER_TEST_SUITE(TestMessageCounterManager); diff --git a/src/transport/BUILD.gn b/src/transport/BUILD.gn index b60f1f5134965c..ff637ecddd34f4 100644 --- a/src/transport/BUILD.gn +++ b/src/transport/BUILD.gn @@ -22,8 +22,11 @@ static_library("transport") { sources = [ "AdminPairingTable.cpp", "AdminPairingTable.h", + "MessageCounter.cpp", + "MessageCounter.h", "PeerConnectionState.h", "PeerConnections.h", + "PeerMessageCounter.h", "SecureMessageCodec.cpp", "SecureMessageCodec.h", "SecureSession.cpp", diff --git a/src/messaging/tests/TestMessageCounterSyncMgrDriver.cpp b/src/transport/MessageCounter.cpp similarity index 60% rename from src/messaging/tests/TestMessageCounterSyncMgrDriver.cpp rename to src/transport/MessageCounter.cpp index 0bf0a9aa49c677..894e9a37622819 100644 --- a/src/messaging/tests/TestMessageCounterSyncMgrDriver.cpp +++ b/src/transport/MessageCounter.cpp @@ -1,5 +1,4 @@ /* - * * Copyright (c) 2021 Project CHIP Authors * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,19 +16,22 @@ /** * @file - * This file implements a standalone/native program executable - * test driver for the SecureChannelMgr tests. + * This file defines the CHIP message counters. * */ -#include "TestMessagingLayer.h" +#include -#include +#include +#include -int main() -{ - // Generate machine-readable, comma-separated value (CSV) output. - nlTestSetOutputStyle(OUTPUT_CSV); +namespace chip { + +GlobalUnencryptedMessageCounter::GlobalUnencryptedMessageCounter() : value(GetRandU32()) {} - return (TestMessageCounterSyncMgr()); +CHIP_ERROR GlobalEncryptedMessageCounter::Init() +{ + return persisted.Init(CHIP_CONFIG_PERSISTED_STORAGE_KEY_GLOBAL_MESSAGE_COUNTER, 1000); } + +} // namespace chip diff --git a/src/transport/MessageCounter.h b/src/transport/MessageCounter.h new file mode 100644 index 00000000000000..49d044dc7b4f0e --- /dev/null +++ b/src/transport/MessageCounter.h @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP message counters. + * + */ +#pragma once + +#include + +namespace chip { + +/** + * MessageCounter represents a local message counter. There are 3 types + * of message counter + * + * 1. Global unencrypted message counter + * 2. Global encrypted message counter + * 3. Session message counter + * + * There will be separate implementations for each type + */ +class MessageCounter +{ +public: + enum Type : uint8_t + { + GlobalUnencrypted, + GlobalEncrypted, + Session, + }; + + virtual ~MessageCounter() = 0; + + virtual Type GetType() = 0; + virtual void Reset() = 0; + virtual uint32_t Value() = 0; /** Get current value */ + virtual CHIP_ERROR Advance() = 0; /** Advance the counter */ +}; + +inline MessageCounter::~MessageCounter() {} + +class GlobalUnencryptedMessageCounter : public MessageCounter +{ +public: + GlobalUnencryptedMessageCounter(); + ~GlobalUnencryptedMessageCounter() override {} + + Type GetType() override { return GlobalUnencrypted; } + void Reset() override + { /* null op */ + } + uint32_t Value() override { return value; } + CHIP_ERROR Advance() override + { + ++value; + return CHIP_NO_ERROR; + } + +private: + uint32_t value; +}; + +class GlobalEncryptedMessageCounter : public MessageCounter +{ +public: + GlobalEncryptedMessageCounter() {} + ~GlobalEncryptedMessageCounter() override {} + + CHIP_ERROR Init(); + Type GetType() override { return GlobalEncrypted; } + void Reset() override + { /* null op */ + } + uint32_t Value() override { return persisted.GetValue(); } + CHIP_ERROR Advance() override { return persisted.Advance(); } + +private: +#if CONFIG_DEVICE_LAYER + PersistedCounter persisted; +#else + struct FakePersistedCounter + { + FakePersistedCounter() : value(0) {} + CHIP_ERROR Init(chip::Platform::PersistedStorage::Key aId, uint32_t aEpoch) { return CHIP_NO_ERROR; } + + uint32_t GetValue() { return value; } + CHIP_ERROR Advance() + { + ++value; + return CHIP_NO_ERROR; + } + + private: + uint32_t value; + } persisted; +#endif +}; + +class LocalSessionMessageCounter : public MessageCounter +{ +public: + static constexpr uint32_t kInitialValue = 1; + LocalSessionMessageCounter() : value(kInitialValue) {} + ~LocalSessionMessageCounter() override {} + + Type GetType() override { return Session; } + void Reset() override { value = kInitialValue; } + uint32_t Value() override { return value; } + CHIP_ERROR Advance() override + { + ++value; + return CHIP_NO_ERROR; + } + +private: + uint32_t value; +}; + +} // namespace chip diff --git a/src/transport/MessageCounterManagerInterface.h b/src/transport/MessageCounterManagerInterface.h new file mode 100644 index 00000000000000..d13dcb2770edca --- /dev/null +++ b/src/transport/MessageCounterManagerInterface.h @@ -0,0 +1,46 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace chip { +namespace Transport { + +class MessageCounterManagerInterface +{ +public: + virtual ~MessageCounterManagerInterface() {} + + /** + * Start sync if the sync procedure is not started yet. + */ + virtual CHIP_ERROR StartSync(SecureSessionHandle session, Transport::PeerConnectionState * state) = 0; + + /** + * Called when have received a message but session message counter is not synced. It will queue the message and start sync if + * the sync procedure is not started yet. + */ + virtual CHIP_ERROR QueueReceivedMessageAndStartSync(SecureSessionHandle session, Transport::PeerConnectionState * state, + const Transport::PeerAddress & peerAddress, + System::PacketBufferHandle msgBuf) = 0; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index 92a64f98713b50..e026467c82382a 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -64,6 +64,16 @@ class DLL_EXPORT PairingSession */ virtual uint16_t GetLocalKeyId() = 0; + /** + * @brief + * Get the value of peer session counter which is synced during session establishment + */ + virtual uint32_t GetPeerCounter() + { + // TODO(#6652): This is a stub implementation, should be replaced by the real one when CASE and PASE is completed + return LocalSessionMessageCounter::kInitialValue; + } + virtual const char * GetI2RSessionInfo() const = 0; virtual const char * GetR2ISessionInfo() const = 0; diff --git a/src/transport/PeerConnectionState.h b/src/transport/PeerConnectionState.h index a2ee465769fff1..e288efc0bca577 100644 --- a/src/transport/PeerConnectionState.h +++ b/src/transport/PeerConnectionState.h @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -48,7 +49,7 @@ static constexpr uint32_t kUndefinedMessageIndex = UINT32_MAX; class PeerConnectionState { public: - PeerConnectionState() : mMsgCounterSynStatus(MsgCounterSyncStatus::NotSync), mPeerAddress(PeerAddress::Uninitialized()) {} + PeerConnectionState() : mPeerAddress(PeerAddress::Uninitialized()) {} PeerConnectionState(const PeerAddress & addr) : mPeerAddress(addr) {} PeerConnectionState(PeerAddress && addr) : mPeerAddress(addr) {} @@ -64,15 +65,9 @@ class PeerConnectionState void SetTransport(Transport::Base * transport) { mTransport = transport; } Transport::Base * GetTransport() { return mTransport; } - bool IsPeerMsgCounterSynced() { return (mPeerMessageIndex != kUndefinedMessageIndex); } - void SetPeerMessageIndex(uint32_t id) { mPeerMessageIndex = id; } - NodeId GetPeerNodeId() const { return mPeerNodeId; } void SetPeerNodeId(NodeId peerNodeId) { mPeerNodeId = peerNodeId; } - uint32_t GetSendMessageIndex() const { return mSendMessageIndex; } - void IncrementSendMessageIndex() { mSendMessageIndex++; } - uint16_t GetPeerKeyID() const { return mPeerKeyID; } void SetPeerKeyID(uint16_t id) { mPeerKeyID = id; } @@ -87,19 +82,12 @@ class PeerConnectionState Transport::AdminId GetAdminId() const { return mAdmin; } void SetAdminId(Transport::AdminId admin) { mAdmin = admin; } - void SetMsgCounterSyncInProgress(bool value) - { - mMsgCounterSynStatus = value ? MsgCounterSyncStatus::SyncInProcess : MsgCounterSyncStatus::Synced; - } - bool IsInitialized() { return (mPeerAddress.IsInitialized() || mPeerNodeId != kUndefinedNodeId || mPeerKeyID != UINT16_MAX || mLocalKeyID != UINT16_MAX); } - bool IsMsgCounterSyncInProgress() { return mMsgCounterSynStatus == MsgCounterSyncStatus::SyncInProcess; } - /** * Reset the connection state to a completely uninitialized status. */ @@ -107,10 +95,9 @@ class PeerConnectionState { mPeerAddress = PeerAddress::Uninitialized(); mPeerNodeId = kUndefinedNodeId; - mSendMessageIndex = 0; mLastActivityTimeMs = 0; mSecureSession.Reset(); - mMsgCounterSynStatus = MsgCounterSyncStatus::NotSync; + mSessionMessageCounter.Reset(); } CHIP_ERROR EncryptBeforeSend(const uint8_t * input, size_t input_length, uint8_t * output, PacketHeader & header, @@ -125,23 +112,17 @@ class PeerConnectionState return mSecureSession.Decrypt(input, input_length, output, header, mac); } -private: - enum class MsgCounterSyncStatus - { - NotSync, - SyncInProcess, - Synced, - } mMsgCounterSynStatus; + SessionMessageCounter & GetSessionMessageCounter() { return mSessionMessageCounter; } +private: PeerAddress mPeerAddress; NodeId mPeerNodeId = kUndefinedNodeId; - uint32_t mSendMessageIndex = 0; - uint32_t mPeerMessageIndex = kUndefinedMessageIndex; uint16_t mPeerKeyID = UINT16_MAX; uint16_t mLocalKeyID = UINT16_MAX; uint64_t mLastActivityTimeMs = 0; Transport::Base * mTransport = nullptr; SecureSession mSecureSession; + SessionMessageCounter mSessionMessageCounter; Transport::AdminId mAdmin = kUndefinedAdminId; }; diff --git a/src/transport/PeerMessageCounter.h b/src/transport/PeerMessageCounter.h new file mode 100644 index 00000000000000..476150938f4171 --- /dev/null +++ b/src/transport/PeerMessageCounter.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file defines the CHIP message counters of remote nodes. + * + */ +#pragma once + +#include +#include + +#include + +namespace chip { +namespace Transport { + +class PeerMessageCounter +{ +public: + static constexpr size_t kChallengeSize = 8; + + PeerMessageCounter() : mStatus(Status::NotSynced) {} + ~PeerMessageCounter() { Reset(); } + + void Reset() + { + switch (mStatus) + { + case Status::NotSynced: + break; + case Status::SyncInProcess: + mSyncInProcess.~SyncInProcess(); + break; + case Status::Synced: + mSynced.~Synced(); + break; + } + mStatus = Status::NotSynced; + } + + bool IsSynchronizing() { return mStatus == Status::SyncInProcess; } + bool IsSynchronized() { return mStatus == Status::Synced; } + + void SyncStarting(FixedByteSpan challenge) + { + assert(mStatus == Status::NotSynced); + mStatus = Status::SyncInProcess; + new (&mSyncInProcess) SyncInProcess(); + ::memcpy(mSyncInProcess.mChallenge.data(), challenge.data(), kChallengeSize); + } + + void SyncFailed() { Reset(); } + + CHIP_ERROR VerifyChallenge(uint32_t counter, FixedByteSpan challenge) + { + if (mStatus != Status::SyncInProcess) + { + return CHIP_ERROR_INCORRECT_STATE; + } + if (::memcmp(mSyncInProcess.mChallenge.data(), challenge.data(), kChallengeSize) != 0) + { + return CHIP_ERROR_INVALID_ARGUMENT; + } + + mSyncInProcess.~SyncInProcess(); + mStatus = Status::Synced; + new (&mSynced) Synced(); + mSynced.mMaxCounter = counter; + mSynced.mWindow.reset(); // reset all bits, accept all packets in the window + return CHIP_NO_ERROR; + } + + CHIP_ERROR Verify(uint32_t counter) const + { + if (mStatus != Status::Synced) + { + return CHIP_ERROR_INCORRECT_STATE; + } + + if (counter <= mSynced.mMaxCounter) + { + uint32_t offset = mSynced.mMaxCounter - counter; + if (offset >= CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE) + { + return CHIP_ERROR_INVALID_ARGUMENT; // outside valid range + } + if (mSynced.mWindow.test(offset)) + { + return CHIP_ERROR_INVALID_ARGUMENT; // duplicated, in window + } + } + + return CHIP_NO_ERROR; + } + + /** + * @brief + * With the counter verified and the packet MIC also verified by the secure key, we can trust the packet and adjust + * counter states. + * + * @pre Verify(counter) == CHIP_NO_ERROR + */ + void Commit(uint32_t counter) + { + if (counter <= mSynced.mMaxCounter) + { + uint32_t offset = mSynced.mMaxCounter - counter; + mSynced.mWindow.set(offset); + } + else + { + uint32_t offset = counter - mSynced.mMaxCounter; + // advance max counter by `offset` + mSynced.mMaxCounter = counter; + if (offset < CHIP_CONFIG_MESSAGE_COUNTER_WINDOW_SIZE) + { + mSynced.mWindow <<= offset; + } + else + { + mSynced.mWindow.reset(); + } + mSynced.mWindow.set(0); + } + } + + void SetCounter(uint32_t value) + { + Reset(); + mStatus = Status::Synced; + new (&mSynced) Synced(); + mSynced.mMaxCounter = value; + mSynced.mWindow.reset(); + } + + /* Test-only */ + uint32_t GetCounter() { return mSynced.mMaxCounter; } + +private: + enum class Status + { + NotSynced, // No state associated + SyncInProcess, // mSyncInProcess will be active + Synced, // mSynced will be active + } mStatus; + + struct SyncInProcess + { + std::array mChallenge; + }; + + struct Synced + { + /* + * Past <-- --> Future + * MaxCounter + * | + * v + * | <-- mWindow -->| + * |[n]| ... |[0]| + */ + uint32_t mMaxCounter; // The most recent counter we have seen + std::bitset mWindow; + }; + + // We should use std::variant here when migrated to C++17 + union + { + SyncInProcess mSyncInProcess; + Synced mSynced; + }; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/SecureMessageCodec.cpp b/src/transport/SecureMessageCodec.cpp index d0673323402dfc..9b448faacdbb6d 100644 --- a/src/transport/SecureMessageCodec.cpp +++ b/src/transport/SecureMessageCodec.cpp @@ -37,13 +37,13 @@ using System::PacketBufferHandle; namespace SecureMessageCodec { CHIP_ERROR Encode(NodeId localNodeId, Transport::PeerConnectionState * state, PayloadHeader & payloadHeader, - PacketHeader & packetHeader, System::PacketBufferHandle & msgBuf) + PacketHeader & packetHeader, System::PacketBufferHandle & msgBuf, MessageCounter & counter) { VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); VerifyOrReturnError(!msgBuf->HasChainedBuffer(), CHIP_ERROR_INVALID_MESSAGE_LENGTH); VerifyOrReturnError(msgBuf->TotalLength() <= kMaxAppMessageLen, CHIP_ERROR_MESSAGE_TOO_LONG); - uint32_t msgId = state->GetSendMessageIndex(); + uint32_t msgId = counter.Value(); static_assert(std::is_sameTotalLength()), uint16_t>::value, "Addition to generate payloadLength might overflow"); @@ -76,7 +76,7 @@ CHIP_ERROR Encode(NodeId localNodeId, Transport::PeerConnectionState * state, Pa ChipLogDetail(Inet, "Secure message was encrypted: Msg ID %u", msgId); - state->IncrementSendMessageIndex(); + ReturnErrorOnFailure(counter.Advance()); return CHIP_NO_ERROR; } diff --git a/src/transport/SecureMessageCodec.h b/src/transport/SecureMessageCodec.h index e75a2e5abbd9b3..edf5b2b47d0c3e 100644 --- a/src/transport/SecureMessageCodec.h +++ b/src/transport/SecureMessageCodec.h @@ -47,10 +47,11 @@ namespace SecureMessageCodec { * @param msgBuf The message buffer that contains the unencrypted message. If * the operation is successuful, this buffer will contain the * encrypted message. + * @param counter The local counter object to be used * @ return CHIP_ERROR The result of the encode operation */ CHIP_ERROR Encode(NodeId localNodeId, Transport::PeerConnectionState * state, PayloadHeader & payloadHeader, - PacketHeader & packetHeader, System::PacketBufferHandle & msgBuf); + PacketHeader & packetHeader, System::PacketBufferHandle & msgBuf, MessageCounter & counter); /** * @brief diff --git a/src/transport/SecureSessionHandle.h b/src/transport/SecureSessionHandle.h new file mode 100644 index 00000000000000..167d982fa49cc8 --- /dev/null +++ b/src/transport/SecureSessionHandle.h @@ -0,0 +1,55 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace chip { + +class SecureSessionMgr; + +class SecureSessionHandle +{ +public: + SecureSessionHandle() : mPeerNodeId(kAnyNodeId), mPeerKeyId(0), mAdmin(Transport::kUndefinedAdminId) {} + SecureSessionHandle(NodeId peerNodeId, uint16_t peerKeyId, Transport::AdminId admin) : + mPeerNodeId(peerNodeId), mPeerKeyId(peerKeyId), mAdmin(admin) + {} + + bool HasAdminId() const { return (mAdmin != Transport::kUndefinedAdminId); } + Transport::AdminId GetAdminId() const { return mAdmin; } + void SetAdminId(Transport::AdminId adminId) { mAdmin = adminId; } + + bool operator==(const SecureSessionHandle & that) const + { + return mPeerNodeId == that.mPeerNodeId && mPeerKeyId == that.mPeerKeyId && mAdmin == that.mAdmin; + } + + NodeId GetPeerNodeId() const { return mPeerNodeId; } + uint16_t GetPeerKeyId() const { return mPeerKeyId; } + +private: + friend class SecureSessionMgr; + NodeId mPeerNodeId; + uint16_t mPeerKeyId; + // TODO: Re-evaluate the storing of Admin ID in SecureSessionHandle + // The Admin ID will not be available for PASE and group sessions. So need + // to identify an approach that'll allow looking up the corresponding information for + // such sessions. + Transport::AdminId mAdmin; +}; + +} // namespace chip diff --git a/src/transport/SecureSessionMgr.cpp b/src/transport/SecureSessionMgr.cpp index 6652486161c6d5..689489c4533014 100644 --- a/src/transport/SecureSessionMgr.cpp +++ b/src/transport/SecureSessionMgr.cpp @@ -72,16 +72,20 @@ SecureSessionMgr::~SecureSessionMgr() } CHIP_ERROR SecureSessionMgr::Init(NodeId localNodeId, System::Layer * systemLayer, TransportMgrBase * transportMgr, - Transport::AdminPairingTable * admins) + Transport::AdminPairingTable * admins, + Transport::MessageCounterManagerInterface * messageCounterManager) { VerifyOrReturnError(mState == State::kNotReady, CHIP_ERROR_INCORRECT_STATE); VerifyOrReturnError(transportMgr != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - mState = State::kInitialized; - mLocalNodeId = localNodeId; - mSystemLayer = systemLayer; - mTransportMgr = transportMgr; - mAdmins = admins; + mState = State::kInitialized; + mLocalNodeId = localNodeId; + mSystemLayer = systemLayer; + mTransportMgr = transportMgr; + mAdmins = admins; + mMessageCounterManager = messageCounterManager; + + mGlobalEncryptedMessageCounter.Init(); ChipLogProgress(Inet, "local node id is 0x%08" PRIx32 "%08" PRIx32, static_cast(mLocalNodeId >> 32), static_cast(mLocalNodeId)); @@ -97,6 +101,8 @@ void SecureSessionMgr::Shutdown() { CancelExpiryTimer(); + mMessageCounterManager = nullptr; + mState = State::kNotReady; mLocalNodeId = kUndefinedNodeId; mSystemLayer = nullptr; @@ -164,15 +170,15 @@ CHIP_ERROR SecureSessionMgr::SendMessage(SecureSessionHandle session, PayloadHea VerifyOrExit(admin != nullptr, err = CHIP_ERROR_INCORRECT_STATE); localNodeId = admin->GetNodeId(); - if (payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq) || - payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp)) + if (IsControlMessage(payloadHeader)) { packetHeader.SetSecureSessionControlMsg(true); } if (encryptionState == EncryptionState::kPayloadIsUnencrypted) { - err = SecureMessageCodec::Encode(localNodeId, state, payloadHeader, packetHeader, msgBuf); + MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *state); + err = SecureMessageCodec::Encode(localNodeId, state, payloadHeader, packetHeader, msgBuf, counter); SuccessOrExit(err); } @@ -261,6 +267,7 @@ CHIP_ERROR SecureSessionMgr::NewPairing(const Optional & if (mCB != nullptr) { + state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(pairing->GetPeerCounter()); mCB->OnNewConnection({ state->GetPeerNodeId(), state->GetPeerKeyID(), admin }, this); } @@ -283,14 +290,6 @@ void SecureSessionMgr::CancelExpiryTimer() } } -void SecureSessionMgr::HandleGroupMessageReceived(uint16_t keyId, System::PacketBufferHandle msgBuf) -{ - PeerConnectionState * state = mPeerConnections.FindPeerConnectionState(keyId, nullptr); - VerifyOrReturn(state != nullptr, ChipLogError(Inet, "Failed to find the peer connection state")); - - OnMessageReceived(state->GetPeerAddress(), std::move(msgBuf)); -} - void SecureSessionMgr::OnMessageReceived(const PeerAddress & peerAddress, System::PacketBufferHandle msg) { PacketHeader packetHeader; @@ -341,6 +340,45 @@ void SecureSessionMgr::SecureMessageDispatch(const PacketHeader & packetHeader, ExitNow(err = CHIP_ERROR_KEY_NOT_FOUND_FROM_PEER); } + // Verify message counter + if (packetHeader.GetFlags().Has(Header::FlagValues::kSecureSessionControlMessage)) + { + // TODO: control message counter is not implemented yet + } + else + { + if (!state->GetSessionMessageCounter().GetPeerMessageCounter().IsSynchronized()) + { + err = packetHeader.EncodeBeforeData(msg); + SuccessOrExit(err); + + // Queue and start message sync procedure + err = mMessageCounterManager->QueueReceivedMessageAndStartSync( + { state->GetPeerNodeId(), state->GetPeerKeyID(), state->GetAdminId() }, state, peerAddress, std::move(msg)); + + if (err != CHIP_NO_ERROR) + { + ChipLogError(Inet, + "Message counter synchronization for received message, failed to " + "QueueReceivedMessageAndStartSync, err = %d", + err); + } + else + { + ChipLogDetail(Inet, "Received message have been queued due to peer counter is not synced"); + } + + return; + } + + err = state->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageId()); + if (err != CHIP_NO_ERROR) + { + ChipLogError(Inet, "Message counter verify failed, err = %d", err); + } + SuccessOrExit(err); + } + admin = mAdmins->FindAdminWithId(state->GetAdminId()); VerifyOrExit(admin != nullptr, ChipLogError(Inet, "Secure transport received packet for unknown admin (%p, %d) pairing, discarding", state, @@ -357,30 +395,19 @@ void SecureSessionMgr::SecureMessageDispatch(const PacketHeader & packetHeader, ChipLogError(Inet, "Secure transport received message destined to node ID (%llu)", packetHeader.GetDestinationNodeId().Value()); mPeerConnections.MarkConnectionActive(state); - if (!packetHeader.IsSecureSessionControlMsg() && !state->IsPeerMsgCounterSynced() && - ChipKeyId::IsAppGroupKey(packetHeader.GetEncryptionKeyID())) - { - // Queue the message as needed for sync with destination node. - if (mCB != nullptr) - { - // We should encode the packetHeader into the buffer before storing the buffer into the queue since the - // stored buffer will be re-processed by OnMessageReceived after the peer message counter is synced. - ReturnOnFailure(packetHeader.EncodeBeforeData(msg)); - err = mCB->QueueReceivedMessageAndSync(state, std::move(msg)); - VerifyOrReturn(err == CHIP_NO_ERROR); - } - - // After the message that triggers message counter synchronization is stored, and a message counter - // synchronization exchange is initiated, we need to return immediately and re-process the original message - // when the synchronization is completed. - - return; - } - // Decode the message VerifyOrExit(CHIP_NO_ERROR == SecureMessageCodec::Decode(state, payloadHeader, packetHeader, msg), ChipLogError(Inet, "Secure transport received message, but failed to decode it, discarding")); + if (packetHeader.GetFlags().Has(Header::FlagValues::kSecureSessionControlMessage)) + { + // TODO: control message counter is not implemented yet + } + else + { + state->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageId()); + } + // See operational-credentials-server.cpp for explanation as to why fabricId is being set to commissioner node id // This is temporary code until AddOptCert is implemented through which an admin will be correctly added with the correct // fields. @@ -432,16 +459,6 @@ void SecureSessionMgr::SecureMessageDispatch(const PacketHeader & packetHeader, state->SetPeerAddress(peerAddress); } - if (!state->IsPeerMsgCounterSynced()) - { - // For all control messages, the first authenticated message counter from an unsynchronized peer is trusted - // and used to seed subsequent message counter based replay protection. - if (packetHeader.IsSecureSessionControlMsg()) - { - state->SetPeerMessageIndex(packetHeader.GetMessageId()); - } - } - if (mCB != nullptr) { SecureSessionHandle session(state->GetPeerNodeId(), state->GetPeerKeyID(), state->GetAdminId()); diff --git a/src/transport/SecureSessionMgr.h b/src/transport/SecureSessionMgr.h index 9e89ce6d76060a..e111d3e50695a0 100644 --- a/src/transport/SecureSessionMgr.h +++ b/src/transport/SecureSessionMgr.h @@ -30,12 +30,15 @@ #include #include #include +#include #include #include #include +#include #include #include #include +#include #include #include #include @@ -43,39 +46,6 @@ namespace chip { -class SecureSessionMgr; - -class SecureSessionHandle -{ -public: - SecureSessionHandle() : mPeerNodeId(kAnyNodeId), mPeerKeyId(0), mAdmin(Transport::kUndefinedAdminId) {} - SecureSessionHandle(NodeId peerNodeId, uint16_t peerKeyId, Transport::AdminId admin) : - mPeerNodeId(peerNodeId), mPeerKeyId(peerKeyId), mAdmin(admin) - {} - - bool HasAdminId() const { return (mAdmin != Transport::kUndefinedAdminId); } - Transport::AdminId GetAdminId() const { return mAdmin; } - void SetAdminId(Transport::AdminId adminId) { mAdmin = adminId; } - - bool operator==(const SecureSessionHandle & that) const - { - return mPeerNodeId == that.mPeerNodeId && mPeerKeyId == that.mPeerKeyId && mAdmin == that.mAdmin; - } - - NodeId GetPeerNodeId() const { return mPeerNodeId; } - uint16_t GetPeerKeyId() const { return mPeerKeyId; } - -private: - friend class SecureSessionMgr; - NodeId mPeerNodeId; - uint16_t mPeerKeyId; - // TODO: Re-evaluate the storing of Admin ID in SecureSessionHandle - // The Admin ID will not be available for PASE and group sessions. So need - // to identify an approach that'll allow looking up the corresponding information for - // such sessions. - Transport::AdminId mAdmin; -}; - /** * @brief * Tracks ownership of a encrypted packet buffer. @@ -186,22 +156,6 @@ class DLL_EXPORT SecureSessionMgrDelegate */ virtual void OnConnectionExpired(SecureSessionHandle session, SecureSessionMgr * mgr) {} - /** - * @brief - * Called when received message from a source node whose message counter is unknown. - * Queue the message and start sync if the sync procedure is not started yet. - * - * @param state A pointer to the state of peer connection - * @param msgBuf The received message - * - * @retval #CHIP_ERROR_NO_MEMORY If there is no empty slot left to queue the message. - * @retval #CHIP_NO_ERROR On success. - */ - virtual CHIP_ERROR QueueReceivedMessageAndSync(Transport::PeerConnectionState * state, System::PacketBufferHandle msgBuf) - { - return CHIP_NO_ERROR; - } - virtual ~SecureSessionMgrDelegate() {} }; @@ -258,13 +212,14 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate * @brief * Initialize a Secure Session Manager * - * @param localNodeId Node id for the current node - * @param systemLayer System, layer to use - * @param transportMgr Transport to use - * @param admins A table of device administrators + * @param localNodeId Node id for the current node + * @param systemLayer System, layer to use + * @param transportMgr Transport to use + * @param admins A table of device administrators + * @param messageCounterManager The message counter manager */ CHIP_ERROR Init(NodeId localNodeId, System::Layer * systemLayer, TransportMgrBase * transportMgr, - Transport::AdminPairingTable * admins); + Transport::AdminPairingTable * admins, Transport::MessageCounterManagerInterface * messageCounterManager); /** * @brief @@ -273,16 +228,6 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate */ void Shutdown(); - /** - * @brief - * Called when a cached group message that was waiting for message counter - * sync shold be reprocessed. - * - * @param keyId The encryption Key ID of the message buffer - * @param msgBuf The received message - */ - void HandleGroupMessageReceived(uint16_t keyId, System::PacketBufferHandle msgBuf); - /** * @brief * Set local node ID @@ -303,7 +248,6 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate TransportMgrBase * GetTransportManager() const { return mTransportMgr; } -protected: /** * @brief * Handle received secure message. Implements TransportMgrDelegate @@ -334,9 +278,13 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate Transport::PeerConnections mPeerConnections; // < Active connections to other peers State mState; // < Initialization state of the object - SecureSessionMgrDelegate * mCB = nullptr; - TransportMgrBase * mTransportMgr = nullptr; - Transport::AdminPairingTable * mAdmins = nullptr; + SecureSessionMgrDelegate * mCB = nullptr; + TransportMgrBase * mTransportMgr = nullptr; + Transport::AdminPairingTable * mAdmins = nullptr; + Transport::MessageCounterManagerInterface * mMessageCounterManager = nullptr; + + GlobalUnencryptedMessageCounter mGlobalUnencryptedMessageCounter; + GlobalEncryptedMessageCounter mGlobalEncryptedMessageCounter; CHIP_ERROR SendMessage(SecureSessionHandle session, PayloadHeader & payloadHeader, PacketHeader & packetHeader, System::PacketBufferHandle msgBuf, EncryptedPacketBufferHandle * bufferRetainSlot, @@ -362,6 +310,24 @@ class DLL_EXPORT SecureSessionMgr : public TransportMgrDelegate System::PacketBufferHandle msg); void MessageDispatch(const PacketHeader & packetHeader, const Transport::PeerAddress & peerAddress, System::PacketBufferHandle msg); + + static bool IsControlMessage(PayloadHeader & payloadHeader) + { + return payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncReq) || + payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::MsgCounterSyncRsp); + } + + MessageCounter & GetSendCounterForPacket(PayloadHeader & payloadHeader, Transport::PeerConnectionState & state) + { + if (IsControlMessage(payloadHeader)) + { + return mGlobalEncryptedMessageCounter; + } + else + { + return state.GetSessionMessageCounter().GetLocalMessageCounter(); + } + } }; namespace MessagePacketBuffer { diff --git a/src/transport/SessionMessageCounter.h b/src/transport/SessionMessageCounter.h new file mode 100644 index 00000000000000..8fd97ea66c6b7c --- /dev/null +++ b/src/transport/SessionMessageCounter.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @brief Defines state relevant for an active connection to a peer. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace chip { +namespace Transport { + +class SessionMessageCounter +{ +public: + void Reset() + { + mLocalMessageCounter.Reset(); + mPeerMessageCounter.Reset(); + } + + MessageCounter & GetLocalMessageCounter() { return mLocalMessageCounter; } + PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; } + +private: + LocalSessionMessageCounter mLocalMessageCounter; + PeerMessageCounter mPeerMessageCounter; +}; + +} // namespace Transport +} // namespace chip diff --git a/src/transport/tests/TestSecureSessionMgr.cpp b/src/transport/tests/TestSecureSessionMgr.cpp index ee3ada884e575a..821d32c8a82d0c 100644 --- a/src/transport/tests/TestSecureSessionMgr.cpp +++ b/src/transport/tests/TestSecureSessionMgr.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -118,6 +119,7 @@ class TestSessMgrCallback : public SecureSessionMgrDelegate void OnNewConnection(SecureSessionHandle session, SecureSessionMgr * mgr) override { + // Preset the MessageCounter if (NewConnectionHandlerCallCount == 0) mRemoteToLocalSession = session; if (NewConnectionHandlerCallCount == 1) @@ -143,6 +145,8 @@ void CheckSimpleInitTest(nlTestSuite * inSuite, void * inContext) TransportMgr transportMgr; SecureSessionMgr secureSessionMgr; + secure_channel::MessageCounterManager gMessageCounterManager; + CHIP_ERROR err; ctx.GetInetLayer().SystemLayer()->Init(nullptr); @@ -151,7 +155,7 @@ void CheckSimpleInitTest(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); Transport::AdminPairingTable admins; - err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins, &gMessageCounterManager); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); } @@ -174,12 +178,13 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) TransportMgr transportMgr; SecureSessionMgr secureSessionMgr; + secure_channel::MessageCounterManager gMessageCounterManager; err = transportMgr.Init("LOOPBACK"); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); Transport::AdminPairingTable admins; - err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins, &gMessageCounterManager); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); callback.mSuite = inSuite; @@ -218,8 +223,6 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer)); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 0; }); - NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); // Let's send the max sized message and make sure it is received @@ -231,8 +234,6 @@ void CheckMessageTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(large_buffer)); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 0; }); - NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); uint16_t large_payload_len = sizeof(LARGE_PAYLOAD); @@ -266,12 +267,13 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) TransportMgr transportMgr; SecureSessionMgr secureSessionMgr; + secure_channel::MessageCounterManager gMessageCounterManager; err = transportMgr.Init("LOOPBACK"); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); Transport::AdminPairingTable admins; - err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins, &gMessageCounterManager); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); callback.mSuite = inSuite; @@ -313,13 +315,15 @@ void SendEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer), &msgBuf); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 0; }); + // Reset receive side message counter, or duplicated message will be denied. + Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession); + state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(msgBuf), nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 1; }); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); } @@ -342,12 +346,13 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) TransportMgr transportMgr; SecureSessionMgr secureSessionMgr; + secure_channel::MessageCounterManager gMessageCounterManager; err = transportMgr.Init("LOOPBACK"); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); Transport::AdminPairingTable admins; - err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins); + err = secureSessionMgr.Init(kSourceNodeId, ctx.GetInetLayer().SystemLayer(), &transportMgr, &admins, &gMessageCounterManager); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); callback.mSuite = inSuite; @@ -389,9 +394,13 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendMessage(localToRemoteSession, payloadHeader, std::move(buffer), &msgBuf); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 0; }); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); + /* -------------------------------------------------------------------------------------------*/ + // Reset receive side message counter, or duplicated message will be denied. + Transport::PeerConnectionState * state = secureSessionMgr.GetPeerConnectionState(callback.mRemoteToLocalSession); + state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + PacketHeader packetHeader; // Change Destination Node ID @@ -405,10 +414,11 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badDestNodeIdMsg), nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 1; }); - NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); + /* -------------------------------------------------------------------------------------------*/ + state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + // Change Source Node ID EncryptedPacketBufferHandle badSrcNodeIdMsg = msgBuf.CloneData(); NL_TEST_ASSERT(inSuite, badSrcNodeIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); @@ -419,10 +429,11 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badSrcNodeIdMsg), nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 1; }); - NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); + /* -------------------------------------------------------------------------------------------*/ + state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + // Change Message ID EncryptedPacketBufferHandle badMessageIdMsg = msgBuf.CloneData(); NL_TEST_ASSERT(inSuite, badMessageIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); @@ -434,10 +445,11 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badMessageIdMsg), nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 1; }); - NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); + /* -------------------------------------------------------------------------------------------*/ + state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); + // Change Key ID EncryptedPacketBufferHandle badKeyIdMsg = msgBuf.CloneData(); NL_TEST_ASSERT(inSuite, badKeyIdMsg.ExtractPacketHeader(packetHeader) == CHIP_NO_ERROR); @@ -449,7 +461,8 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(badKeyIdMsg), nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 1; }); + /* -------------------------------------------------------------------------------------------*/ + state->GetSessionMessageCounter().GetPeerMessageCounter().SetCounter(1); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 1); @@ -457,7 +470,6 @@ void SendBadEncryptedPacketTest(nlTestSuite * inSuite, void * inContext) err = secureSessionMgr.SendEncryptedMessage(localToRemoteSession, std::move(msgBuf), nullptr); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); - ctx.DriveIOUntil(1000 /* ms */, []() { return callback.ReceiveHandlerCallCount != 1; }); NL_TEST_ASSERT(inSuite, callback.ReceiveHandlerCallCount == 2); }