From 4aaf11a177ef99754155ec0b13926ebb9811ecbd Mon Sep 17 00:00:00 2001 From: Pankaj Garg Date: Wed, 8 Dec 2021 13:54:52 -0800 Subject: [PATCH] Integrate CASE MRP parameters with controller and CASE server (#12738) * Integrate CASE MRP parameters with controller and CASE server * fix tests * fix test build --- src/app/CASEClient.cpp | 4 +-- src/app/CASEClient.h | 2 ++ src/app/OperationalDeviceProxy.cpp | 3 +- src/app/OperationalDeviceProxy.h | 2 ++ src/controller/CHIPDeviceController.cpp | 1 + src/controller/CHIPDeviceController.h | 4 +-- src/protocols/secure_channel/CASEServer.cpp | 3 +- src/protocols/secure_channel/CASESession.cpp | 30 +++++++++++++------- src/protocols/secure_channel/PASESession.cpp | 15 ++++++++-- src/transport/PairingSession.cpp | 10 +++---- src/transport/PairingSession.h | 4 +-- src/transport/tests/TestPairingSession.cpp | 15 +++++----- 12 files changed, 58 insertions(+), 35 deletions(-) diff --git a/src/app/CASEClient.cpp b/src/app/CASEClient.cpp index c0fe3c4b89090b..e4bba8e6c9e0b1 100644 --- a/src/app/CASEClient.cpp +++ b/src/app/CASEClient.cpp @@ -43,8 +43,8 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres uint16_t keyID = 0; ReturnErrorOnFailure(mInitParams.idAllocator->Allocate(keyID)); - ReturnErrorOnFailure( - mCASESession.EstablishSession(peerAddress, mInitParams.fabricInfo, peer.GetNodeId(), keyID, exchange, this)); + ReturnErrorOnFailure(mCASESession.EstablishSession(peerAddress, mInitParams.fabricInfo, peer.GetNodeId(), keyID, exchange, this, + mInitParams.mrpLocalConfig)); mConnectionSuccessCallback = onConnection; mConnectionFailureCallback = onFailure; mConectionContext = context; diff --git a/src/app/CASEClient.h b/src/app/CASEClient.h index 22b68ecfde34c2..24a502a5ff0621 100644 --- a/src/app/CASEClient.h +++ b/src/app/CASEClient.h @@ -35,6 +35,8 @@ struct CASEClientInitParams Messaging::ExchangeManager * exchangeMgr = nullptr; SessionIDAllocator * idAllocator = nullptr; FabricInfo * fabricInfo = nullptr; + + Optional mrpLocalConfig = Optional::Missing(); }; class DLL_EXPORT CASEClient : public SessionEstablishmentDelegate diff --git a/src/app/OperationalDeviceProxy.cpp b/src/app/OperationalDeviceProxy.cpp index 81e0b29ab7d20f..46ebccb0681be4 100644 --- a/src/app/OperationalDeviceProxy.cpp +++ b/src/app/OperationalDeviceProxy.cpp @@ -150,7 +150,8 @@ bool OperationalDeviceProxy::GetAddress(Inet::IPAddress & addr, uint16_t & port) CHIP_ERROR OperationalDeviceProxy::EstablishConnection() { mCASEClient = mInitParams.clientPool->Allocate(CASEClientInitParams{ mInitParams.sessionManager, mInitParams.exchangeMgr, - mInitParams.idAllocator, mInitParams.fabricInfo }); + mInitParams.idAllocator, mInitParams.fabricInfo, + mInitParams.mrpLocalConfig }); ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY); CHIP_ERROR err = mCASEClient->EstablishSession(mPeerId, mDeviceAddress, mMRPConfig, HandleCASEConnected, HandleCASEConnectionFailure, this); diff --git a/src/app/OperationalDeviceProxy.h b/src/app/OperationalDeviceProxy.h index 3ff7c4967e83eb..b67384b873e3f1 100644 --- a/src/app/OperationalDeviceProxy.h +++ b/src/app/OperationalDeviceProxy.h @@ -57,6 +57,8 @@ struct DeviceProxyInitParams Controller::DeviceControllerInteractionModelDelegate * imDelegate = nullptr; + Optional mrpLocalConfig = Optional::Missing(); + CHIP_ERROR Validate() { ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE); diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index 8944c8484fad1c..186c3cab4388a7 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -153,6 +153,7 @@ CHIP_ERROR DeviceController::Init(ControllerInitParams params) .fabricInfo = params.systemState->Fabrics()->FindFabricWithIndex(mFabricIndex), .clientPool = &mCASEClientPool, .imDelegate = params.systemState->IMDelegate(), + .mrpLocalConfig = Optional::Value(mMRPConfig), }; CASESessionManagerConfig sessionManagerConfig = { diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index 17597334dfabd0..c810c6a18df075 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -388,6 +388,8 @@ class DLL_EXPORT DeviceController : public SessionReleaseDelegate, uint16_t mVendorId; + ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig; + //////////// SessionReleaseDelegate Implementation /////////////// void OnSessionReleased(SessionHandle session) override; @@ -819,8 +821,6 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController, Callback::Callback mDeviceNOCChainCallback; SetUpCodePairer mSetUpCodePairer; - - ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig; }; } // namespace Controller diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 926210afb49964..a40c4f81f0dab1 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -74,7 +74,8 @@ CHIP_ERROR CASEServer::InitCASEHandshake(Messaging::ExchangeContext * ec) ReturnErrorOnFailure(mIDAllocator->Allocate(mSessionKeyId)); // Setup CASE state machine using the credentials for the current fabric. - ReturnErrorOnFailure(GetSession().ListenForSessionEstablishment(mSessionKeyId, mFabrics, this)); + ReturnErrorOnFailure(GetSession().ListenForSessionEstablishment( + mSessionKeyId, mFabrics, this, Optional::Value(gDefaultMRPConfig))); // Hand over the exchange context to the CASE session. ec->SetDelegate(&GetSession()); diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index cde2d4f333a0f3..581c14bd517153 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -291,12 +291,12 @@ CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session, CryptoConte CHIP_ERROR CASESession::SendSigma1() { - size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom + const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; + size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom sizeof(uint16_t), // initiatorSessionId, kSHA256_Hash_Length, // destinationId kP256_PublicKey_Length, // InitiatorEphPubKey, - /* TLV::EstimateStructOverhead(sizeof(uint16_t), - sizeof(uint16)_t), // initiatorMRPParams */ + mrpParamsSize, // initiatorMRPParams kCASEResumptionIDSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES); System::PacketBufferTLVWriter tlvWriter; @@ -461,8 +461,9 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg) CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom) { - size_t max_sigma2_resume_data_len = TLV::EstimateStructOverhead(kCASEResumptionIDSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, - sizeof(uint16_t) /*, kMRPOptionalParamsLength, */); + const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; + size_t max_sigma2_resume_data_len = + TLV::EstimateStructOverhead(kCASEResumptionIDSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, sizeof(uint16_t), mrpParamsSize); System::PacketBufferTLVWriter tlvWriter; System::PacketBufferHandle msg_R2_resume; @@ -603,8 +604,9 @@ CHIP_ERROR CASESession::SendSigma2() CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES)); // Construct Sigma2 Msg - size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, sizeof(uint16_t), kP256_PublicKey_Length, - msg_r2_signed_enc_len, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES); + const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0; + size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, sizeof(uint16_t), kP256_PublicKey_Length, + msg_r2_signed_enc_len, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, mrpParamsSize); System::PacketBufferHandle msg_R2 = System::PacketBufferHandle::New(data_len); VerifyOrReturnError(!msg_R2.IsNull(), CHIP_ERROR_NO_MEMORY); @@ -623,7 +625,7 @@ CHIP_ERROR CASESession::SendSigma2() if (mLocalMRPConfig.HasValue()) { ChipLogDetail(SecureChannel, "Including MRP parameters"); - ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter)); + ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriterMsg2)); } ReturnErrorOnFailure(tlvWriterMsg2.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriterMsg2.Finalize(&msg_R2)); @@ -680,7 +682,10 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg) VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); SuccessOrExit(err = tlvReader.Get(responderSessionId)); - SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + if (tlvReader.Next() != CHIP_END_OF_TLV) + { + SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(4), tlvReader)); + } ChipLogDetail(SecureChannel, "Peer assigned session session ID %d", responderSessionId); SetPeerSessionId(responderSessionId); @@ -852,7 +857,10 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) SetPeerCATs(peerCATs); // Retrieve responderMRPParams if present - SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + if (tlvReader.Next() != CHIP_END_OF_TLV) + { + SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader)); + } exit: if (err != CHIP_NO_ERROR) @@ -1378,7 +1386,7 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader, CHIP_ERROR err = tlvReader.Next(); if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kInitiatorMRPParamsTag)) { - ReturnErrorOnFailure(DecodeMRPParametersIfPresent(tlvReader)); + ReturnErrorOnFailure(DecodeMRPParametersIfPresent(TLV::ContextTag(kInitiatorMRPParamsTag), tlvReader)); err = tlvReader.Next(); } diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 5598f4cf915106..d43f01c5ae0eb8 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -442,7 +442,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); SuccessOrExit(err = tlvReader.Get(hasPBKDFParameters)); - SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + if (tlvReader.Next() != CHIP_END_OF_TLV) + { + SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader)); + } err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters); SuccessOrExit(err); @@ -562,7 +565,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m if (mHavePBKDFParameters) { - SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + if (tlvReader.Next() != CHIP_END_OF_TLV) + { + SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader)); + } // TODO - Add a unit test that exercises mHavePBKDFParameters path err = SetupSpake2p(mIterationCount, ByteSpan(mSalt, mSaltLength)); @@ -585,7 +591,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m SuccessOrExit(err = tlvReader.ExitContainer(containerType)); - SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader)); + if (tlvReader.Next() != CHIP_END_OF_TLV) + { + SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader)); + } err = SetupSpake2p(iterCount, ByteSpan(salt, saltLength)); SuccessOrExit(err); diff --git a/src/transport/PairingSession.cpp b/src/transport/PairingSession.cpp index affc27f51367a2..c19d4251a97943 100644 --- a/src/transport/PairingSession.cpp +++ b/src/transport/PairingSession.cpp @@ -36,15 +36,13 @@ CHIP_ERROR PairingSession::EncodeMRPParameters(TLV::Tag tag, const ReliableMessa return tlvWriter.EndContainer(mrpParamsContainer); } -CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::ContiguousBufferTLVReader & tlvReader) +CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TLV::ContiguousBufferTLVReader & tlvReader) { // The MRP parameters are optional. - CHIP_ERROR err = tlvReader.Next(); - if (err == CHIP_END_OF_TLV) + if (tlvReader.GetTag() != expectedTag) { return CHIP_NO_ERROR; } - ReturnErrorOnFailure(err); TLV::TLVType containerType = TLV::kTLVType_Structure; ReturnErrorOnFailure(tlvReader.EnterContainer(containerType)); @@ -63,10 +61,10 @@ CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::ContiguousBufferTLV mMRPConfig.mIdleRetransTimeout = System::Clock::Milliseconds32(tlvElementValue); // The next element is optional. If it's not present, return CHIP_NO_ERROR. - err = tlvReader.Next(); + CHIP_ERROR err = tlvReader.Next(); if (err == CHIP_END_OF_TLV) { - return CHIP_NO_ERROR; + return tlvReader.ExitContainer(containerType); } ReturnErrorOnFailure(err); } diff --git a/src/transport/PairingSession.h b/src/transport/PairingSession.h index b0f4769de6af4d..27e69a1c6a6542 100644 --- a/src/transport/PairingSession.h +++ b/src/transport/PairingSession.h @@ -163,7 +163,7 @@ class DLL_EXPORT PairingSession } /** - * Try to decode the next element (pointed by the TLV reader) as MRP parameters. + * Try to decode the current element (pointed by the TLV reader) as MRP parameters. * If the MRP parameters are found, mMRPConfig is updated with the devoded values. * * MRP parameters are optional. So, if the TLV reader is not pointing to the MRP parameters, @@ -172,7 +172,7 @@ class DLL_EXPORT PairingSession * If the parameters are present, but TLV reader fails to correctly parse it, the function will * return the corresponding error. */ - CHIP_ERROR DecodeMRPParametersIfPresent(TLV::ContiguousBufferTLVReader & tlvReader); + CHIP_ERROR DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TLV::ContiguousBufferTLVReader & tlvReader); // TODO: remove Clear, we should create a new instance instead reset the old instance. void Clear() diff --git a/src/transport/tests/TestPairingSession.cpp b/src/transport/tests/TestPairingSession.cpp index 7d7869805930f7..60cc82d41586fe 100644 --- a/src/transport/tests/TestPairingSession.cpp +++ b/src/transport/tests/TestPairingSession.cpp @@ -43,9 +43,9 @@ class TestPairingSession : public PairingSession public: CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override { return CHIP_NO_ERROR; } - CHIP_ERROR DecodeMRPParametersIfPresent(System::PacketBufferTLVReader & tlvReader) + CHIP_ERROR DecodeMRPParametersIfPresent(TLV::Tag expectedTag, System::PacketBufferTLVReader & tlvReader) { - return PairingSession::DecodeMRPParametersIfPresent(tlvReader); + return PairingSession::DecodeMRPParametersIfPresent(expectedTag, tlvReader); } }; @@ -62,8 +62,7 @@ void PairingSessionEncodeDecodeMRPParams(nlTestSuite * inSuite, void * inContext TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; NL_TEST_ASSERT(inSuite, writer.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType) == CHIP_NO_ERROR); - CHIP_ERROR err = PairingSession::EncodeMRPParameters(TLV::ContextTag(1), config, writer); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, PairingSession::EncodeMRPParameters(TLV::ContextTag(1), config, writer) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, writer.EndContainer(outerContainerType) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, writer.Finalize(&buf) == CHIP_NO_ERROR); @@ -75,8 +74,8 @@ void PairingSessionEncodeDecodeMRPParams(nlTestSuite * inSuite, void * inContext NL_TEST_ASSERT(inSuite, reader.Next(containerType, TLV::AnonymousTag) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, reader.EnterContainer(containerType) == CHIP_NO_ERROR); - err = session.DecodeMRPParametersIfPresent(reader); - NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, reader.Next() == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(TLV::ContextTag(1), reader) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mIdleRetransTimeout == config.mIdleRetransTimeout); NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mActiveRetransTimeout == config.mActiveRetransTimeout); @@ -92,6 +91,7 @@ void PairingSessionTryDecodeMissingMRPParams(nlTestSuite * inSuite, void * inCon TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; NL_TEST_ASSERT(inSuite, writer.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, writer.Put(TLV::ContextTag(1), static_cast(0x1234)) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, writer.EndContainer(outerContainerType) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, writer.Finalize(&buf) == CHIP_NO_ERROR); @@ -101,7 +101,8 @@ void PairingSessionTryDecodeMissingMRPParams(nlTestSuite * inSuite, void * inCon reader.Init(std::move(buf)); NL_TEST_ASSERT(inSuite, reader.Next(containerType, TLV::AnonymousTag) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, reader.EnterContainer(containerType) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(reader) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, reader.Next() == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(TLV::ContextTag(2), reader) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mIdleRetransTimeout == gDefaultMRPConfig.mIdleRetransTimeout); NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mActiveRetransTimeout == gDefaultMRPConfig.mActiveRetransTimeout);