diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 20b3567cb5d74b..2e2fcda957b391 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -177,9 +177,9 @@ class CASESession::WorkHelper { return CHIP_ERROR_INCORRECT_STATE; } - WorkHelper * helper = this; - bool cancel = false; - helper->mStatus = helper->mWorkCallback(helper->mData, cancel); + auto * helper = this; + bool cancel = false; + helper->mStatus = helper->mWorkCallback(helper->mData, cancel); if (!cancel) { helper->mStatus = (helper->mSession->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus); @@ -211,8 +211,8 @@ class CASESession::WorkHelper // Handler for the work callback. static void WorkHandler(intptr_t arg) { - WorkHelper * helper = reinterpret_cast(arg); - bool cancel = false; + auto * helper = reinterpret_cast(arg); + bool cancel = false; VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`? helper->mStatus = helper->mWorkCallback(helper->mData, cancel); VerifyOrExit(!cancel, ;); // canceled by `mWorkCallback`? @@ -227,7 +227,7 @@ class CASESession::WorkHelper static void AfterWorkHandler(intptr_t arg) { // Since this runs in the main Matter thread, the session shouldn't be otherwise used (messages, timers, etc.) - WorkHelper * helper = reinterpret_cast(arg); + auto * helper = reinterpret_cast(arg); if (auto * session = helper->mSession.load()) { (session->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus); @@ -282,18 +282,8 @@ struct CASESession::SendSigma3Data P256ECDSASignature tbsData3Signature; }; -struct CASESession::HandleSigma3Work +struct CASESession::HandleSigma3Data { - // Status of background processing. - CHIP_ERROR status; - - // Session to use after background processing. - CASESession * session; - - // Sequence number used to coordinate foreground/background work for a - // particular session establishment. - int sequence; - chip::Platform::ScopedMemoryBuffer msg_R3_Signed; size_t msg_r3_signed_len; @@ -333,6 +323,11 @@ void CASESession::Clear() mSendSigma3Helper->CancelWork(); mSendSigma3Helper.reset(); } + if (mHandleSigma3Helper) + { + mHandleSigma3Helper->CancelWork(); + mHandleSigma3Helper.reset(); + } // This function zeroes out and resets the memory used by the object. // It's done so that no security related information will be leaked. @@ -405,10 +400,6 @@ CASESession::PrepareForSessionEstablishment(SessionManager & sessionManager, Fab CHIP_ERROR err = CHIP_NO_ERROR; - // Sequence number used to coordinate foreground/background work for a - // particular session establishment. - mSequence++; - SuccessOrExit(err = fabricTable->AddFabricDelegate(this)); mFabricsTable = fabricTable; @@ -1531,23 +1522,16 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) ChipLogProgress(SecureChannel, "Received Sigma3 msg"); - auto * workPtr = Platform::New(); - VerifyOrExit(workPtr != nullptr, err = CHIP_ERROR_NO_MEMORY); + auto helper = WorkHelper::Create(*this, &HandleSigma3b, &CASESession::HandleSigma3c); + VerifyOrExit(helper, err = CHIP_ERROR_NO_MEMORY); { - auto & work = *workPtr; - - // Used to call back into the session after background event processing. - // It happens that there's only one pairing session (in CASEServer) - // so it will still be available for use. Use a sequence number to - // coordinate. - work.session = this; - work.sequence = mSequence; + auto & data = helper->mData; { VerifyOrExit(mFabricsTable != nullptr, err = CHIP_ERROR_INCORRECT_STATE); const auto * fabricInfo = mFabricsTable->FindFabricWithIndex(mFabricIndex); VerifyOrExit(fabricInfo != nullptr, err = CHIP_ERROR_INCORRECT_STATE); - work.fabricId = fabricInfo->GetFabricId(); + data.fabricId = fabricInfo->GetFabricId(); } VerifyOrExit(mEphemeralKey != nullptr, err = CHIP_ERROR_INTERNAL); @@ -1558,7 +1542,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) // Fetch encrypted data max_msg_r3_signed_enc_len = TLV::EstimateStructOverhead(Credentials::kMaxCHIPCertLength, Credentials::kMaxCHIPCertLength, - work.tbsData3Signature.Length(), kCaseOverheadForFutureTbeData); + data.tbsData3Signature.Length(), kCaseOverheadForFutureTbeData); SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_Sigma3_Encrypted3))); @@ -1592,74 +1576,70 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) SuccessOrExit(err = decryptedDataTlvReader.EnterContainer(containerType)); SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_SenderNOC))); - SuccessOrExit(err = decryptedDataTlvReader.Get(work.initiatorNOC)); + SuccessOrExit(err = decryptedDataTlvReader.Get(data.initiatorNOC)); SuccessOrExit(err = decryptedDataTlvReader.Next()); if (TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == kTag_TBEData_SenderICAC) { VerifyOrExit(decryptedDataTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - SuccessOrExit(err = decryptedDataTlvReader.Get(work.initiatorICAC)); + SuccessOrExit(err = decryptedDataTlvReader.Get(data.initiatorICAC)); SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_Signature))); } // Step 4 - Construct Sigma3 TBS Data - work.msg_r3_signed_len = TLV::EstimateStructOverhead(sizeof(uint16_t), work.initiatorNOC.size(), work.initiatorICAC.size(), + data.msg_r3_signed_len = TLV::EstimateStructOverhead(sizeof(uint16_t), data.initiatorNOC.size(), data.initiatorICAC.size(), kP256_PublicKey_Length, kP256_PublicKey_Length); - VerifyOrExit(work.msg_R3_Signed.Alloc(work.msg_r3_signed_len), err = CHIP_ERROR_NO_MEMORY); + VerifyOrExit(data.msg_R3_Signed.Alloc(data.msg_r3_signed_len), err = CHIP_ERROR_NO_MEMORY); - SuccessOrExit(err = ConstructTBSData(work.initiatorNOC, work.initiatorICAC, ByteSpan(mRemotePubKey, mRemotePubKey.Length()), + SuccessOrExit(err = ConstructTBSData(data.initiatorNOC, data.initiatorICAC, ByteSpan(mRemotePubKey, mRemotePubKey.Length()), ByteSpan(mEphemeralKey->Pubkey(), mEphemeralKey->Pubkey().Length()), - work.msg_R3_Signed.Get(), work.msg_r3_signed_len)); + data.msg_R3_Signed.Get(), data.msg_r3_signed_len)); VerifyOrExit(TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == kTag_TBEData_Signature, err = CHIP_ERROR_INVALID_TLV_TAG); - VerifyOrExit(work.tbsData3Signature.Capacity() >= decryptedDataTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); - work.tbsData3Signature.SetLength(decryptedDataTlvReader.GetLength()); - SuccessOrExit(err = decryptedDataTlvReader.GetBytes(work.tbsData3Signature.Bytes(), work.tbsData3Signature.Length())); + VerifyOrExit(data.tbsData3Signature.Capacity() >= decryptedDataTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); + data.tbsData3Signature.SetLength(decryptedDataTlvReader.GetLength()); + SuccessOrExit(err = decryptedDataTlvReader.GetBytes(data.tbsData3Signature.Bytes(), data.tbsData3Signature.Length())); // Prepare for Step 5/6 { - MutableByteSpan fabricRCAC{ work.rootCertBuf }; + MutableByteSpan fabricRCAC{ data.rootCertBuf }; SuccessOrExit(err = mFabricsTable->FetchRootCert(mFabricIndex, fabricRCAC)); - work.fabricRCAC = fabricRCAC; + data.fabricRCAC = fabricRCAC; // TODO probably should make SetEffectiveTime static and call closer to VerifyCredentials SuccessOrExit(err = SetEffectiveTime()); } // Copy remaining needed data into work structure { - work.validContext = mValidContext; + data.validContext = mValidContext; // initiatorNOC and initiatorICAC are spans into msg_R3_Encrypted // which is going away, so to save memory, redirect them to their // copies in msg_R3_signed, which is staying around TLV::TLVReader signedDataTlvReader; - signedDataTlvReader.Init(work.msg_R3_Signed.Get(), work.msg_r3_signed_len); + signedDataTlvReader.Init(data.msg_R3_Signed.Get(), data.msg_r3_signed_len); SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag())); SuccessOrExit(err = signedDataTlvReader.EnterContainer(containerType)); SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBSData_SenderNOC))); - SuccessOrExit(err = signedDataTlvReader.Get(work.initiatorNOC)); + SuccessOrExit(err = signedDataTlvReader.Get(data.initiatorNOC)); - if (!work.initiatorICAC.empty()) + if (!data.initiatorICAC.empty()) { SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBSData_SenderICAC))); - SuccessOrExit(err = signedDataTlvReader.Get(work.initiatorICAC)); + SuccessOrExit(err = signedDataTlvReader.Get(data.initiatorICAC)); } } - SuccessOrExit(err = DeviceLayer::PlatformMgr().ScheduleBackgroundWork( - [](intptr_t arg) { HandleSigma3b(*reinterpret_cast(arg)); }, - reinterpret_cast(&work))); - workPtr = nullptr; // scheduling succeeded, so don't delete + SuccessOrExit(err = helper->ScheduleWork()); + mHandleSigma3Helper = helper; mExchangeCtxt->WillSendMessage(); mState = State::kHandleSigma3Pending; } exit: - Platform::Delete(workPtr); - if (err != CHIP_NO_ERROR) { SendStatusReport(mExchangeCtxt, kProtocolCodeInvalidParam); @@ -1668,7 +1648,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg) return err; } -void CASESession::HandleSigma3b(HandleSigma3Work & work) +CHIP_ERROR CASESession::HandleSigma3b(HandleSigma3Data & data, bool & cancel) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -1678,9 +1658,9 @@ void CASESession::HandleSigma3b(HandleSigma3Work & work) CompressedFabricId unused; FabricId initiatorFabricId; P256PublicKey initiatorPublicKey; - SuccessOrExit(err = FabricTable::VerifyCredentials(work.initiatorNOC, work.initiatorICAC, work.fabricRCAC, work.validContext, - unused, initiatorFabricId, work.initiatorNodeId, initiatorPublicKey)); - VerifyOrExit(work.fabricId == initiatorFabricId, err = CHIP_ERROR_INVALID_CASE_PARAMETER); + SuccessOrExit(err = FabricTable::VerifyCredentials(data.initiatorNOC, data.initiatorICAC, data.fabricRCAC, data.validContext, + unused, initiatorFabricId, data.initiatorNodeId, initiatorPublicKey)); + VerifyOrExit(data.fabricId == initiatorFabricId, err = CHIP_ERROR_INVALID_CASE_PARAMETER); // TODO - Validate message signature prior to validating the received operational credentials. // The op cert check requires traversal of cert chain, that is a more expensive operation. @@ -1692,46 +1672,27 @@ void CASESession::HandleSigma3b(HandleSigma3Work & work) { P256PublicKeyHSM initiatorPublicKeyHSM; memcpy(Uint8::to_uchar(initiatorPublicKeyHSM), initiatorPublicKey.Bytes(), initiatorPublicKey.Length()); - SuccessOrExit(err = initiatorPublicKeyHSM.ECDSA_validate_msg_signature(work.msg_R3_Signed.Get(), work.msg_r3_signed_len, - work.tbsData3Signature)); + SuccessOrExit(err = initiatorPublicKeyHSM.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len, + data.tbsData3Signature)); } #else - SuccessOrExit(err = initiatorPublicKey.ECDSA_validate_msg_signature(work.msg_R3_Signed.Get(), work.msg_r3_signed_len, - work.tbsData3Signature)); + SuccessOrExit(err = initiatorPublicKey.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len, + data.tbsData3Signature)); #endif exit: - work.status = err; - - auto err2 = DeviceLayer::PlatformMgr().ScheduleWork( - [](intptr_t arg) { - auto & work2 = *reinterpret_cast(arg); - work2.session->HandleSigma3c(work2); - }, - reinterpret_cast(&work)); - - if (err2 != CHIP_NO_ERROR) - { - Platform::Delete(&work); // scheduling failed, so delete - } + return err; } -CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Work & work) +CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Data & data, CHIP_ERROR status) { - CHIP_ERROR err = CHIP_NO_ERROR; - bool ignoreFailure = true; + CHIP_ERROR err = CHIP_NO_ERROR; - // Special case: if for whatever reason not in expected state or sequence, - // don't do anything, including sending a status report or aborting the - // pending establish. VerifyOrExit(mState == State::kHandleSigma3Pending, err = CHIP_ERROR_INCORRECT_STATE); - VerifyOrExit(mSequence == work.sequence, err = CHIP_ERROR_INCORRECT_STATE); - - ignoreFailure = false; - SuccessOrExit(err = work.status); + SuccessOrExit(err = status); - mPeerNodeId = work.initiatorNodeId; + mPeerNodeId = data.initiatorNodeId; { MutableByteSpan messageDigestSpan(mMessageDigest); @@ -1740,7 +1701,7 @@ CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Work & work) // Retrieve peer CASE Authenticated Tags (CATs) from peer's NOC. { - SuccessOrExit(err = ExtractCATsFromOpCert(work.initiatorNOC, mPeerCATs)); + SuccessOrExit(err = ExtractCATsFromOpCert(data.initiatorNOC, mPeerCATs)); } if (mSessionResumptionStorage != nullptr) @@ -1758,9 +1719,9 @@ CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Work & work) Finish(); exit: - Platform::Delete(&work); + mHandleSigma3Helper.reset(); - if (err != CHIP_NO_ERROR && !ignoreFailure) + if (err != CHIP_NO_ERROR) { SendStatusReport(mExchangeCtxt, kProtocolCodeInvalidParam); // Abort the pending establish, which is normally done by CASESession::OnMessageReceived, diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 21578e592ba907..f0caa1675f7845 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -240,10 +240,10 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, static CHIP_ERROR SendSigma3b(SendSigma3Data & data, bool & cancel); CHIP_ERROR SendSigma3c(SendSigma3Data & data, CHIP_ERROR status); - struct HandleSigma3Work; + struct HandleSigma3Data; CHIP_ERROR HandleSigma3a(System::PacketBufferHandle && msg); - static void HandleSigma3b(HandleSigma3Work & work); - CHIP_ERROR HandleSigma3c(HandleSigma3Work & work); + static CHIP_ERROR HandleSigma3b(HandleSigma3Data & data, bool & cancel); + CHIP_ERROR HandleSigma3c(HandleSigma3Data & data, CHIP_ERROR status); CHIP_ERROR SendSigma2Resume(); @@ -303,13 +303,10 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, // Sigma1 initiator random, maintained to be reused post-Sigma1, such as when generating Sigma2 S2RK key uint8_t mInitiatorRandom[kSigmaParamRandomNumberSize]; - // Sequence number used to coordinate foreground/background work for a - // particular session establishment. - int mSequence = 0; - template class WorkHelper; Platform::SharedPtr> mSendSigma3Helper; + Platform::SharedPtr> mHandleSigma3Helper; State mState;