From 2a8a9c07366c69a1e618a23b0a3a8c9669e61364 Mon Sep 17 00:00:00 2001 From: Marc Lepage <67919234+mlepage-google@users.noreply.github.com> Date: Thu, 27 Apr 2023 16:44:29 -0400 Subject: [PATCH] Async send sigma3 (#25695) * Break CASESession::SendSigma3 into fg/bg parts This unblocks the main event loop while performance intensive (e.g. crypto) parts of the process occur. * Fix host tests * Remove temp log statements * Restyle * Refactor CASESession::SendSigma3 Add more structure to manage multiple steps of work. * Remove temporary logging * Restyle * Minor cleanup * Minor cleanup * Restyle * Use Platform::SharedPtr Also add alias template for Platform::WeakPtr. * Add mutex to FabricTable This supports locking during SignWithOpKeypair, and other operations that may alter state in the fabric table, while CASESession is performing work in the background during session establishment. CASESession registers as a fabric table listener, and when a fabric is removed (e.g. timeout) it attempts to cancel any outstanding work, and also clears out the fabric index in the work helper data. Therefore, if outstanding work has made it into SignWithOpKeypair, it should be OK until complete. It still relies on other tasks not altering FabricInfo, or the configured OperationalKeystore, but that would have had to have been true before anyways. The mutex was not made recursive. It's omitted from a few functions, which should be OK for now, and there should be cleanup on a subsequent commit (and probably fix up const-ness of member functions, and factoring of API vs. impl functions). This commit is to flush out build/test errors on all CI platforms, and to discuss/review/comment on the general approach. * Remove mutex, only async sometimes It's tricky to async background the signing operation, because of the two ways operational signing occurs. Legacy way: opkeypair manually added for fabric info Recommended way: opkeystore handles everything Removed std::mutex because it wasn't supported by all platforms. Instead, made background signing occur only if using the operational keystore (recommended way), since implementors can perform any needed mutual exclusion in the operational keystore. If using manually added operational keypairs (legacy way), keep signing in the foreground, since it's not feasible to mutex the entire fabric table and typically the operations is simpler anyways. * Clean up error handling * Restyle * Only store data.fabricTable if fg case Store only one of data.fabricTable or data.keystore. * Declare wither signing in background is supported OperationalKeystore declares whether it supports this capability. If so, then CASE session establishment may take advantage of it. If not, then CASE session establishment must use foreground. * Make some variables const * Clean up a few comments --- src/credentials/FabricTable.h | 8 + src/crypto/OperationalKeystore.h | 14 +- src/lib/support/CHIPMem.h | 3 + src/protocols/secure_channel/CASESession.cpp | 420 ++++++++++++++---- src/protocols/secure_channel/CASESession.h | 40 +- .../secure_channel/tests/TestCASESession.cpp | 18 +- 6 files changed, 385 insertions(+), 118 deletions(-) diff --git a/src/credentials/FabricTable.h b/src/credentials/FabricTable.h index 42f88024b877c6..1b01e1f4fd4a4f 100644 --- a/src/credentials/FabricTable.h +++ b/src/credentials/FabricTable.h @@ -717,6 +717,14 @@ class DLL_EXPORT FabricTable */ bool HasOperationalKeyForFabric(FabricIndex fabricIndex) const; + /** + * @brief Returns the operational keystore. This is used for + * CASE and the only way the keystore should be used. + * + * @return The operational keystore, nullptr otherwise. + */ + const Crypto::OperationalKeystore * GetOperationalKeystore() { return mOperationalKeystore; } + /** * @brief Add a pending trusted root certificate for the next fabric created with `AddNewPendingFabric*` methods. * diff --git a/src/crypto/OperationalKeystore.h b/src/crypto/OperationalKeystore.h index bfa846b9d1a8be..6af92629174810 100644 --- a/src/crypto/OperationalKeystore.h +++ b/src/crypto/OperationalKeystore.h @@ -148,6 +148,19 @@ class OperationalKeystore virtual void RevertPendingKeypair() = 0; // ==== Primary operation required: signature + /** + * @brief Whether `SignWithOpKeypair` may be performed in the background. + * + * If true, `CASESession` may attempt to perform `SignWithOpKeypair` in the + * background. In this case, `OperationalKeystore` should protect itself, + * e.g. with a mutex, as the signing could occur at any time during session + * establishment. + * + * @retval true if `SignWithOpKeypair` may be performed in the background + * @retval false if `SignWithOpKeypair` may NOT be performed in the background + */ + virtual bool SupportsSignWithOpKeypairInBackground() const { return false; } + /** * @brief Sign a message with a fabric's currently-active operational keypair. * @@ -164,7 +177,6 @@ class OperationalKeystore * @retval CHIP_ERROR_INVALID_FABRIC_INDEX if no active key is found for the given `fabricIndex` or if * `fabricIndex` is invalid. * @retval other CHIP_ERROR value on internal crypto engine errors - * */ virtual CHIP_ERROR SignWithOpKeypair(FabricIndex fabricIndex, const ByteSpan & message, Crypto::P256ECDSASignature & outSignature) const = 0; diff --git a/src/lib/support/CHIPMem.h b/src/lib/support/CHIPMem.h index a1e27c20b73ce3..b4b78aca647e3a 100644 --- a/src/lib/support/CHIPMem.h +++ b/src/lib/support/CHIPMem.h @@ -193,6 +193,9 @@ inline SharedPtr MakeShared(Args &&... args) return SharedPtr(New(std::forward(args)...), Deleter()); } +template +using WeakPtr = std::weak_ptr; + // See MemoryDebugCheckPointer(). extern bool MemoryInternalCheckPointer(const void * p, size_t min_size); diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 1abd1f1b1e9e6b..20b3567cb5d74b 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -25,7 +25,9 @@ */ #include +#include #include +#include #include #include @@ -41,6 +43,7 @@ #include #include #include +#include #include #include #include @@ -129,6 +132,185 @@ static constexpr ExchangeContext::Timeout kExpectedLowProcessingTime = System static constexpr ExchangeContext::Timeout kExpectedSigma1ProcessingTime = kExpectedLowProcessingTime; static constexpr ExchangeContext::Timeout kExpectedHighProcessingTime = System::Clock::Seconds16(30); +// Helper for managing a session's outstanding work. +// Holds work data which is provided to a scheduled work callback (standalone), +// then (if not canceled) to a scheduled after work callback (on the session). +template +class CASESession::WorkHelper +{ +public: + // Work callback, processed in the background via `PlatformManager::ScheduleBackgroundWork`. + // This is a non-member function which does not use the associated session. + // The return value is passed to the after work callback (called afterward). + // Set `cancel` to true if calling the after work callback is not necessary. + typedef CHIP_ERROR (*WorkCallback)(DATA & data, bool & cancel); + + // After work callback, processed in the main Matter task via `PlatformManager::ScheduleWork`. + // This is a member function to be called on the associated session after the work callback. + // The `status` value is the result of the work callback (called beforehand). + typedef CHIP_ERROR (CASESession::*AfterWorkCallback)(DATA & data, CHIP_ERROR status); + + // Create a work helper using the specified session, work callback, after work callback, and data (template arg). + // Lifetime is not managed, see `Create` for that option. + WorkHelper(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) : + mSession(&session), mWorkCallback(workCallback), mAfterWorkCallback(afterWorkCallback) + {} + + // Create a work helper using the specified session, work callback, after work callback, and data (template arg). + // Lifetime is managed by sharing between the caller (typically the session) and the helper itself (while work is scheduled). + static Platform::SharedPtr Create(CASESession & session, WorkCallback workCallback, + AfterWorkCallback afterWorkCallback) + { + auto ptr = Platform::MakeShared(session, workCallback, afterWorkCallback); + if (ptr) + { + ptr->mWeakPtr = ptr; // used by `ScheduleWork` + } + return ptr; + } + + // Do the work immediately. + // No scheduling, no outstanding work, no shared lifetime management. + CHIP_ERROR DoWork() + { + if (!mSession || !mWorkCallback || !mAfterWorkCallback) + { + return CHIP_ERROR_INCORRECT_STATE; + } + WorkHelper * helper = this; + bool cancel = false; + helper->mStatus = helper->mWorkCallback(helper->mData, cancel); + if (!cancel) + { + helper->mStatus = (helper->mSession->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus); + } + return helper->mStatus; + } + + // Schedule the work after configuring the data. + // If lifetime is managed, the helper shares management while work is outstanding. + CHIP_ERROR ScheduleWork() + { + if (!mSession || !mWorkCallback || !mAfterWorkCallback) + { + return CHIP_ERROR_INCORRECT_STATE; + } + mStrongPtr = mWeakPtr.lock(); // set in `Create` + auto status = DeviceLayer::PlatformMgr().ScheduleBackgroundWork(WorkHandler, reinterpret_cast(this)); + if (status != CHIP_NO_ERROR) + { + mStrongPtr.reset(); + } + return status; + } + + // Cancel the work, by clearing the associated session. + void CancelWork() { mSession.store(nullptr); } + +private: + // Handler for the work callback. + static void WorkHandler(intptr_t arg) + { + WorkHelper * helper = reinterpret_cast(arg); + bool cancel = false; + VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`? + helper->mStatus = helper->mWorkCallback(helper->mData, cancel); + VerifyOrExit(!cancel, ;); // canceled by `mWorkCallback`? + VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`? + SuccessOrExit(DeviceLayer::PlatformMgr().ScheduleWork(AfterWorkHandler, reinterpret_cast(helper))); + return; + exit: + helper->mStrongPtr.reset(); + } + + // Handler for the after work callback. + static void AfterWorkHandler(intptr_t arg) + { + // Since this runs in the main Matter thread, the session shouldn't be otherwise used (messages, timers, etc.) + WorkHelper * helper = reinterpret_cast(arg); + if (auto * session = helper->mSession.load()) + { + (session->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus); + } + helper->mStrongPtr.reset(); + } + +private: + // Lifetime management: `ScheduleWork` sets `mStrongPtr` from `mWeakPtr`. + Platform::WeakPtr mWeakPtr; + + // Lifetime management: `ScheduleWork` sets `mStrongPtr` from `mWeakPtr`. + Platform::SharedPtr mStrongPtr; + + // Associated session, cleared by `CancelWork`. + std::atomic mSession; + + // Work callback, called by `WorkHandler`. + WorkCallback mWorkCallback; + + // After work callback, called by `AfterWorkHandler`. + AfterWorkCallback mAfterWorkCallback; + + // Return value of `mWorkCallback`, passed to `mAfterWorkCallback`. + CHIP_ERROR mStatus; + +public: + // Data passed to `mWorkCallback` and `mAfterWorkCallback`. + DATA mData; +}; + +struct CASESession::SendSigma3Data +{ + std::atomic fabricIndex; + + // Use one or the other + const FabricTable * fabricTable; + const Crypto::OperationalKeystore * keystore; + + chip::Platform::ScopedMemoryBuffer msg_R3_Signed; + size_t msg_r3_signed_len; + + chip::Platform::ScopedMemoryBuffer msg_R3_Encrypted; + size_t msg_r3_encrypted_len; + + chip::Platform::ScopedMemoryBuffer icacBuf; + MutableByteSpan icaCert; + + chip::Platform::ScopedMemoryBuffer nocBuf; + MutableByteSpan nocCert; + + P256ECDSASignature tbsData3Signature; +}; + +struct CASESession::HandleSigma3Work +{ + // Status of background processing. + CHIP_ERROR status; + + // Session to use after background processing. + CASESession * session; + + // Sequence number used to coordinate foreground/background work for a + // particular session establishment. + int sequence; + + chip::Platform::ScopedMemoryBuffer msg_R3_Signed; + size_t msg_r3_signed_len; + + ByteSpan initiatorNOC; + ByteSpan initiatorICAC; + + uint8_t rootCertBuf[kMaxCHIPCertLength]; + ByteSpan fabricRCAC; + + P256ECDSASignature tbsData3Signature; + + FabricId fabricId; + NodeId initiatorNodeId; + + ValidationContext validContext; +}; + CASESession::~CASESession() { // Let's clear out any security state stored in the object, before destroying it. @@ -144,6 +326,14 @@ void CASESession::OnSessionReleased() void CASESession::Clear() { + // Cancel any outstanding work. + if (mSendSigma3Helper) + { + mSendSigma3Helper->mData.fabricIndex = kUndefinedFabricIndex; + mSendSigma3Helper->CancelWork(); + mSendSigma3Helper.reset(); + } + // This function zeroes out and resets the memory used by the object. // It's done so that no security related information will be leaked. mCommissioningHash.Clear(); @@ -925,7 +1115,7 @@ CHIP_ERROR CASESession::HandleSigma2_and_SendSigma3(System::PacketBufferHandle & { MATTER_TRACE_EVENT_SCOPE("HandleSigma2_and_SendSigma3", "CASESession"); ReturnErrorOnFailure(HandleSigma2(std::move(msg))); - ReturnErrorOnFailure(SendSigma3()); + ReturnErrorOnFailure(SendSigma3a()); return CHIP_NO_ERROR; } @@ -1099,92 +1289,156 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg) return err; } -CHIP_ERROR CASESession::SendSigma3() +CHIP_ERROR CASESession::SendSigma3a() { MATTER_TRACE_EVENT_SCOPE("SendSigma3", "CASESession"); CHIP_ERROR err = CHIP_NO_ERROR; - MutableByteSpan messageDigestSpan(mMessageDigest); - System::PacketBufferHandle msg_R3; - size_t data_len; - - chip::Platform::ScopedMemoryBuffer msg_R3_Encrypted; - size_t msg_r3_encrypted_len; + ChipLogDetail(SecureChannel, "Sending Sigma3"); - uint8_t msg_salt[kIPKSize + kSHA256_Hash_Length]; + auto helper = WorkHelper::Create(*this, &SendSigma3b, &CASESession::SendSigma3c); + VerifyOrExit(helper, err = CHIP_ERROR_NO_MEMORY); + { + auto & data = helper->mData; - AutoReleaseSessionKey sr3k(*mSessionManager->GetSessionKeystore()); + VerifyOrExit(mFabricsTable != nullptr, err = CHIP_ERROR_INCORRECT_STATE); + data.fabricIndex = mFabricIndex; + data.fabricTable = nullptr; + data.keystore = nullptr; - chip::Platform::ScopedMemoryBuffer msg_R3_Signed; - size_t msg_r3_signed_len; + { + const FabricInfo * fabricInfo = mFabricsTable->FindFabricWithIndex(mFabricIndex); + VerifyOrExit(fabricInfo != nullptr, err = CHIP_ERROR_KEY_NOT_FOUND); + auto * keystore = mFabricsTable->GetOperationalKeystore(); + if (!fabricInfo->HasOperationalKey() && keystore != nullptr && keystore->SupportsSignWithOpKeypairInBackground()) + { + // NOTE: used to sign in background. + data.keystore = keystore; + } + else + { + // NOTE: used to sign in foreground. + data.fabricTable = mFabricsTable; + } + } - P256ECDSASignature tbsData3Signature; + VerifyOrExit(mEphemeralKey != nullptr, err = CHIP_ERROR_INTERNAL); - chip::Platform::ScopedMemoryBuffer icacBuf; - MutableByteSpan icaCert; + VerifyOrExit(data.icacBuf.Alloc(kMaxCHIPCertLength), err = CHIP_ERROR_NO_MEMORY); + data.icaCert = MutableByteSpan{ data.icacBuf.Get(), kMaxCHIPCertLength }; - chip::Platform::ScopedMemoryBuffer nocBuf; - MutableByteSpan nocCert; + VerifyOrExit(data.nocBuf.Alloc(kMaxCHIPCertLength), err = CHIP_ERROR_NO_MEMORY); + data.nocCert = MutableByteSpan{ data.nocBuf.Get(), kMaxCHIPCertLength }; - ChipLogDetail(SecureChannel, "Sending Sigma3"); + SuccessOrExit(err = mFabricsTable->FetchICACert(mFabricIndex, data.icaCert)); + SuccessOrExit(err = mFabricsTable->FetchNOCCert(mFabricIndex, data.nocCert)); - VerifyOrExit(mEphemeralKey != nullptr, err = CHIP_ERROR_INTERNAL); - VerifyOrExit(icacBuf.Alloc(kMaxCHIPCertLength), err = CHIP_ERROR_NO_MEMORY); - icaCert = MutableByteSpan{ icacBuf.Get(), kMaxCHIPCertLength }; + // Prepare Sigma3 TBS Data Blob + data.msg_r3_signed_len = + TLV::EstimateStructOverhead(data.icaCert.size(), data.nocCert.size(), kP256_PublicKey_Length, kP256_PublicKey_Length); - VerifyOrExit(nocBuf.Alloc(kMaxCHIPCertLength), err = CHIP_ERROR_NO_MEMORY); - nocCert = MutableByteSpan{ nocBuf.Get(), kMaxCHIPCertLength }; + VerifyOrExit(data.msg_R3_Signed.Alloc(data.msg_r3_signed_len), err = CHIP_ERROR_NO_MEMORY); - VerifyOrExit(mFabricsTable != nullptr, err = CHIP_ERROR_INCORRECT_STATE); + SuccessOrExit(err = ConstructTBSData( + data.nocCert, data.icaCert, ByteSpan(mEphemeralKey->Pubkey(), mEphemeralKey->Pubkey().Length()), + ByteSpan(mRemotePubKey, mRemotePubKey.Length()), data.msg_R3_Signed.Get(), data.msg_r3_signed_len)); - SuccessOrExit(err = mFabricsTable->FetchICACert(mFabricIndex, icaCert)); - SuccessOrExit(err = mFabricsTable->FetchNOCCert(mFabricIndex, nocCert)); + if (data.keystore != nullptr) + { + SuccessOrExit(err = helper->ScheduleWork()); + mSendSigma3Helper = helper; + mExchangeCtxt->WillSendMessage(); + mState = State::kSendSigma3Pending; + } + else + { + SuccessOrExit(err = helper->DoWork()); + } + } - // Prepare Sigma3 TBS Data Blob - msg_r3_signed_len = TLV::EstimateStructOverhead(icaCert.size(), nocCert.size(), kP256_PublicKey_Length, kP256_PublicKey_Length); +exit: + if (err != CHIP_NO_ERROR) + { + SendStatusReport(mExchangeCtxt, kProtocolCodeInvalidParam); + mState = State::kInitialized; + } - VerifyOrExit(msg_R3_Signed.Alloc(msg_r3_signed_len), err = CHIP_ERROR_NO_MEMORY); + return err; +} - SuccessOrExit(err = ConstructTBSData(nocCert, icaCert, ByteSpan(mEphemeralKey->Pubkey(), mEphemeralKey->Pubkey().Length()), - ByteSpan(mRemotePubKey, mRemotePubKey.Length()), msg_R3_Signed.Get(), msg_r3_signed_len)); +CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel) +{ + CHIP_ERROR err = CHIP_NO_ERROR; // Generate a signature - err = mFabricsTable->SignWithOpKeypair(mFabricIndex, ByteSpan{ msg_R3_Signed.Get(), msg_r3_signed_len }, tbsData3Signature); + if (data.keystore != nullptr) + { + // Recommended case: delegate to operational keystore + err = data.keystore->SignWithOpKeypair(data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, + data.tbsData3Signature); + } + else + { + // Legacy case: delegate to fabric table fabric info + err = data.fabricTable->SignWithOpKeypair(data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, + data.tbsData3Signature); + } SuccessOrExit(err); // Prepare Sigma3 TBE Data Blob - msg_r3_encrypted_len = TLV::EstimateStructOverhead(nocCert.size(), icaCert.size(), tbsData3Signature.Length()); + data.msg_r3_encrypted_len = + TLV::EstimateStructOverhead(data.nocCert.size(), data.icaCert.size(), data.tbsData3Signature.Length()); - VerifyOrExit(msg_R3_Encrypted.Alloc(msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES), err = CHIP_ERROR_NO_MEMORY); + VerifyOrExit(data.msg_R3_Encrypted.Alloc(data.msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES), + err = CHIP_ERROR_NO_MEMORY); { TLV::TLVWriter tlvWriter; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; - tlvWriter.Init(msg_R3_Encrypted.Get(), msg_r3_encrypted_len); + tlvWriter.Init(data.msg_R3_Encrypted.Get(), data.msg_r3_encrypted_len); SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType)); - SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), nocCert)); - if (!icaCert.empty()) + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), data.nocCert)); + if (!data.icaCert.empty()) { - SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), icaCert)); + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), data.icaCert)); } // We are now done with ICAC and NOC certs so we can release the memory. { - icacBuf.Free(); - icaCert = MutableByteSpan{}; + data.icacBuf.Free(); + data.icaCert = MutableByteSpan{}; - nocBuf.Free(); - nocCert = MutableByteSpan{}; + data.nocBuf.Free(); + data.nocCert = MutableByteSpan{}; } - SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), tbsData3Signature.ConstBytes(), - static_cast(tbsData3Signature.Length()))); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), data.tbsData3Signature.ConstBytes(), + static_cast(data.tbsData3Signature.Length()))); SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize()); - msg_r3_encrypted_len = static_cast(tlvWriter.GetLengthWritten()); + data.msg_r3_encrypted_len = static_cast(tlvWriter.GetLengthWritten()); } +exit: + return err; +} + +CHIP_ERROR CASESession::SendSigma3c(SendSigma3Data & data, CHIP_ERROR status) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + + System::PacketBufferHandle msg_R3; + size_t data_len; + + uint8_t msg_salt[kIPKSize + kSHA256_Hash_Length]; + + AutoReleaseSessionKey sr3k(*mSessionManager->GetSessionKeystore()); + + VerifyOrDieWithMsg(data.keystore == nullptr || mState == State::kSendSigma3Pending, SecureChannel, "Bad internal state."); + + SuccessOrExit(err = status); + // Generate S3K key { MutableByteSpan saltSpan(msg_salt); @@ -1193,12 +1447,13 @@ CHIP_ERROR CASESession::SendSigma3() } // Generated Encrypted data blob - SuccessOrExit(err = AES_CCM_encrypt(msg_R3_Encrypted.Get(), msg_r3_encrypted_len, nullptr, 0, sr3k.KeyHandle(), kTBEData3_Nonce, - kTBEDataNonceLength, msg_R3_Encrypted.Get(), msg_R3_Encrypted.Get() + msg_r3_encrypted_len, - CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES)); + SuccessOrExit(err = + AES_CCM_encrypt(data.msg_R3_Encrypted.Get(), data.msg_r3_encrypted_len, nullptr, 0, sr3k.KeyHandle(), + kTBEData3_Nonce, kTBEDataNonceLength, data.msg_R3_Encrypted.Get(), + data.msg_R3_Encrypted.Get() + data.msg_r3_encrypted_len, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES)); // Generate Sigma3 Msg - data_len = TLV::EstimateStructOverhead(CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, msg_r3_encrypted_len); + data_len = TLV::EstimateStructOverhead(CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, data.msg_r3_encrypted_len); msg_R3 = System::PacketBufferHandle::New(data_len); VerifyOrExit(!msg_R3.IsNull(), err = CHIP_ERROR_NO_MEMORY); @@ -1210,8 +1465,8 @@ CHIP_ERROR CASESession::SendSigma3() tlvWriter.Init(std::move(msg_R3)); err = tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType); SuccessOrExit(err); - err = tlvWriter.PutBytes(TLV::ContextTag(1), msg_R3_Encrypted.Get(), - static_cast(msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES)); + err = tlvWriter.PutBytes(TLV::ContextTag(1), data.msg_R3_Encrypted.Get(), + static_cast(data.msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES)); SuccessOrExit(err); err = tlvWriter.EndContainer(outerContainerType); SuccessOrExit(err); @@ -1229,50 +1484,29 @@ CHIP_ERROR CASESession::SendSigma3() ChipLogProgress(SecureChannel, "Sent Sigma3 msg"); - err = mCommissioningHash.Finish(messageDigestSpan); - SuccessOrExit(err); + { + MutableByteSpan messageDigestSpan(mMessageDigest); + SuccessOrExit(err = mCommissioningHash.Finish(messageDigestSpan)); + } mState = State::kSentSigma3; exit: + mSendSigma3Helper.reset(); - if (err != CHIP_NO_ERROR) + // If data.keystore is set, processing occurred in the background, so if an error occurred, + // need to send status report (normally occurs in SendSigma3a), and discard exchange and + // abort pending establish (normally occurs in OnMessageReceived). + if (data.keystore != nullptr && err != CHIP_NO_ERROR) { SendStatusReport(mExchangeCtxt, kProtocolCodeInvalidParam); - mState = State::kInitialized; + DiscardExchange(); + AbortPendingEstablish(err); } + return err; } -struct CASESession::Sigma3Work -{ - // Status of background processing. - CHIP_ERROR status; - - // Session to use after background processing. - CASESession * session; - - // Sequence number used to coordinate foreground/background work for a - // particular session establishment. - int sequence; - - chip::Platform::ScopedMemoryBuffer msg_R3_Signed; - size_t msg_r3_signed_len; - - ByteSpan initiatorNOC; - ByteSpan initiatorICAC; - - uint8_t rootCertBuf[kMaxCHIPCertLength]; - ByteSpan fabricRCAC; - - P256ECDSASignature tbsData3Signature; - - FabricId fabricId; - NodeId initiatorNodeId; - - ValidationContext validContext; -}; - CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) { MATTER_TRACE_EVENT_SCOPE("HandleSigma3", "CASESession"); @@ -1297,7 +1531,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) ChipLogProgress(SecureChannel, "Received Sigma3 msg"); - auto * workPtr = Platform::New(); + auto * workPtr = Platform::New(); VerifyOrExit(workPtr != nullptr, err = CHIP_ERROR_NO_MEMORY); { auto & work = *workPtr; @@ -1415,12 +1649,12 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) } } - SuccessOrExit( - err = DeviceLayer::PlatformMgr().ScheduleBackgroundWork( - [](intptr_t arg) { HandleSigma3b(*reinterpret_cast(arg)); }, reinterpret_cast(&work))); + SuccessOrExit(err = DeviceLayer::PlatformMgr().ScheduleBackgroundWork( + [](intptr_t arg) { HandleSigma3b(*reinterpret_cast(arg)); }, + reinterpret_cast(&work))); workPtr = nullptr; // scheduling succeeded, so don't delete mExchangeCtxt->WillSendMessage(); - mState = State::kBackgroundPending; + mState = State::kHandleSigma3Pending; } exit: @@ -1434,7 +1668,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) return err; } -void CASESession::HandleSigma3b(Sigma3Work & work) +void CASESession::HandleSigma3b(HandleSigma3Work & work) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -1471,7 +1705,7 @@ void CASESession::HandleSigma3b(Sigma3Work & work) auto err2 = DeviceLayer::PlatformMgr().ScheduleWork( [](intptr_t arg) { - auto & work2 = *reinterpret_cast(arg); + auto & work2 = *reinterpret_cast(arg); work2.session->HandleSigma3c(work2); }, reinterpret_cast(&work)); @@ -1482,7 +1716,7 @@ void CASESession::HandleSigma3b(Sigma3Work & work) } } -CHIP_ERROR CASESession::HandleSigma3c(Sigma3Work & work) +CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Work & work) { CHIP_ERROR err = CHIP_NO_ERROR; bool ignoreFailure = true; @@ -1490,7 +1724,7 @@ CHIP_ERROR CASESession::HandleSigma3c(Sigma3Work & work) // Special case: if for whatever reason not in expected state or sequence, // don't do anything, including sending a status report or aborting the // pending establish. - VerifyOrExit(mState == State::kBackgroundPending, err = CHIP_ERROR_INCORRECT_STATE); + VerifyOrExit(mState == State::kHandleSigma3Pending, err = CHIP_ERROR_INCORRECT_STATE); VerifyOrExit(mSequence == work.sequence, err = CHIP_ERROR_INCORRECT_STATE); ignoreFailure = false; diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index b6516aa5ca0612..21578e592ba907 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -66,7 +67,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, /** * @brief - * Initialize using configured fabrics and wait for session establishment requests. + * Initialize using configured fabrics and wait for session establishment requests (as a responder). * * @param sessionManager session manager from which to allocate a secure session object * @param fabricTable Table of fabrics that are currently configured on the device @@ -88,7 +89,7 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, /** * @brief - * Create and send session establishment request using device's operational credentials. + * Create and send session establishment request (as an initiator) using device's operational credentials. * * @param sessionManager session manager from which to allocate a secure session object * @param fabricTable The fabric table that contains a fabric in common with the peer @@ -194,15 +195,16 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, friend class TestCASESession; enum class State : uint8_t { - kInitialized = 0, - kSentSigma1 = 1, - kSentSigma2 = 2, - kSentSigma3 = 3, - kSentSigma1Resume = 4, - kSentSigma2Resume = 5, - kFinished = 6, - kFinishedViaResume = 7, - kBackgroundPending = 8, + kInitialized = 0, + kSentSigma1 = 1, + kSentSigma2 = 2, + kSentSigma3 = 3, + kSentSigma1Resume = 4, + kSentSigma2Resume = 5, + kFinished = 6, + kFinishedViaResume = 7, + kSendSigma3Pending = 8, + kHandleSigma3Pending = 9, }; /* @@ -233,11 +235,15 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, CHIP_ERROR HandleSigma2(System::PacketBufferHandle && msg); CHIP_ERROR HandleSigma2Resume(System::PacketBufferHandle && msg); - CHIP_ERROR SendSigma3(); - struct Sigma3Work; + struct SendSigma3Data; + CHIP_ERROR SendSigma3a(); + static CHIP_ERROR SendSigma3b(SendSigma3Data & data, bool & cancel); + CHIP_ERROR SendSigma3c(SendSigma3Data & data, CHIP_ERROR status); + + struct HandleSigma3Work; CHIP_ERROR HandleSigma3a(System::PacketBufferHandle && msg); - static void HandleSigma3b(Sigma3Work & work); - CHIP_ERROR HandleSigma3c(Sigma3Work & work); + static void HandleSigma3b(HandleSigma3Work & work); + CHIP_ERROR HandleSigma3c(HandleSigma3Work & work); CHIP_ERROR SendSigma2Resume(); @@ -301,6 +307,10 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, // particular session establishment. int mSequence = 0; + template + class WorkHelper; + Platform::SharedPtr> mSendSigma3Helper; + State mState; #if CONFIG_BUILD_FOR_HOST_UNIT_TEST diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index a863b3ca57dae8..11d5000f71df03 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -60,16 +60,16 @@ namespace { void ServiceEvents(TestContext & ctx) { - // Service any messages - ctx.DrainAndServiceIO(); - - // Messages may have scheduled work, so service them - chip::DeviceLayer::PlatformMgr().ScheduleWork([](intptr_t) -> void { chip::DeviceLayer::PlatformMgr().StopEventLoopTask(); }, - (intptr_t) nullptr); - chip::DeviceLayer::PlatformMgr().RunEventLoop(); + // Takes a few rounds of this because handling IO messages may schedule work, + // and scheduled work may queue messages for sending... + for (int i = 0; i < 3; ++i) + { + ctx.DrainAndServiceIO(); - // Work may have sent messages, so service them - ctx.DrainAndServiceIO(); + chip::DeviceLayer::PlatformMgr().ScheduleWork( + [](intptr_t) -> void { chip::DeviceLayer::PlatformMgr().StopEventLoopTask(); }, (intptr_t) nullptr); + chip::DeviceLayer::PlatformMgr().RunEventLoop(); + } } class TemporarySessionManager