Skip to content

Commit

Permalink
Use work helper with CASE handle sigma3 (#26300)
Browse files Browse the repository at this point in the history
* Use work helper with CASE handle sigma3

WorkHelper was introduced to help with CASESession::SendSigma3.
Now use it to help with CASESession::HandleSigma3.

Part of issue #26280.

* Unscope a block
  • Loading branch information
mlepage-google authored and pull[bot] committed Jan 17, 2024
1 parent dbbec76 commit 13d9661
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 97 deletions.
141 changes: 51 additions & 90 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ class CASESession::WorkHelper
{
return CHIP_ERROR_INCORRECT_STATE;
}
WorkHelper * helper = this;
bool cancel = false;
helper->mStatus = helper->mWorkCallback(helper->mData, cancel);
auto * helper = this;
bool cancel = false;
helper->mStatus = helper->mWorkCallback(helper->mData, cancel);
if (!cancel)
{
helper->mStatus = (helper->mSession->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus);
Expand Down Expand Up @@ -211,8 +211,8 @@ class CASESession::WorkHelper
// Handler for the work callback.
static void WorkHandler(intptr_t arg)
{
WorkHelper * helper = reinterpret_cast<WorkHelper *>(arg);
bool cancel = false;
auto * helper = reinterpret_cast<WorkHelper *>(arg);
bool cancel = false;
VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`?
helper->mStatus = helper->mWorkCallback(helper->mData, cancel);
VerifyOrExit(!cancel, ;); // canceled by `mWorkCallback`?
Expand All @@ -227,7 +227,7 @@ class CASESession::WorkHelper
static void AfterWorkHandler(intptr_t arg)
{
// Since this runs in the main Matter thread, the session shouldn't be otherwise used (messages, timers, etc.)
WorkHelper * helper = reinterpret_cast<WorkHelper *>(arg);
auto * helper = reinterpret_cast<WorkHelper *>(arg);
if (auto * session = helper->mSession.load())
{
(session->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus);
Expand Down Expand Up @@ -282,18 +282,8 @@ struct CASESession::SendSigma3Data
P256ECDSASignature tbsData3Signature;
};

struct CASESession::HandleSigma3Work
struct CASESession::HandleSigma3Data
{
// Status of background processing.
CHIP_ERROR status;

// Session to use after background processing.
CASESession * session;

// Sequence number used to coordinate foreground/background work for a
// particular session establishment.
int sequence;

chip::Platform::ScopedMemoryBuffer<uint8_t> msg_R3_Signed;
size_t msg_r3_signed_len;

Expand Down Expand Up @@ -333,6 +323,11 @@ void CASESession::Clear()
mSendSigma3Helper->CancelWork();
mSendSigma3Helper.reset();
}
if (mHandleSigma3Helper)
{
mHandleSigma3Helper->CancelWork();
mHandleSigma3Helper.reset();
}

// This function zeroes out and resets the memory used by the object.
// It's done so that no security related information will be leaked.
Expand Down Expand Up @@ -405,10 +400,6 @@ CASESession::PrepareForSessionEstablishment(SessionManager & sessionManager, Fab

CHIP_ERROR err = CHIP_NO_ERROR;

// Sequence number used to coordinate foreground/background work for a
// particular session establishment.
mSequence++;

SuccessOrExit(err = fabricTable->AddFabricDelegate(this));

mFabricsTable = fabricTable;
Expand Down Expand Up @@ -1531,23 +1522,16 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg)

ChipLogProgress(SecureChannel, "Received Sigma3 msg");

auto * workPtr = Platform::New<HandleSigma3Work>();
VerifyOrExit(workPtr != nullptr, err = CHIP_ERROR_NO_MEMORY);
auto helper = WorkHelper<HandleSigma3Data>::Create(*this, &HandleSigma3b, &CASESession::HandleSigma3c);
VerifyOrExit(helper, err = CHIP_ERROR_NO_MEMORY);
{
auto & work = *workPtr;

// Used to call back into the session after background event processing.
// It happens that there's only one pairing session (in CASEServer)
// so it will still be available for use. Use a sequence number to
// coordinate.
work.session = this;
work.sequence = mSequence;
auto & data = helper->mData;

{
VerifyOrExit(mFabricsTable != nullptr, err = CHIP_ERROR_INCORRECT_STATE);
const auto * fabricInfo = mFabricsTable->FindFabricWithIndex(mFabricIndex);
VerifyOrExit(fabricInfo != nullptr, err = CHIP_ERROR_INCORRECT_STATE);
work.fabricId = fabricInfo->GetFabricId();
data.fabricId = fabricInfo->GetFabricId();
}

VerifyOrExit(mEphemeralKey != nullptr, err = CHIP_ERROR_INTERNAL);
Expand All @@ -1558,7 +1542,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg)

// Fetch encrypted data
max_msg_r3_signed_enc_len = TLV::EstimateStructOverhead(Credentials::kMaxCHIPCertLength, Credentials::kMaxCHIPCertLength,
work.tbsData3Signature.Length(), kCaseOverheadForFutureTbeData);
data.tbsData3Signature.Length(), kCaseOverheadForFutureTbeData);

SuccessOrExit(err = tlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_Sigma3_Encrypted3)));

Expand Down Expand Up @@ -1592,74 +1576,70 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg)
SuccessOrExit(err = decryptedDataTlvReader.EnterContainer(containerType));

SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_SenderNOC)));
SuccessOrExit(err = decryptedDataTlvReader.Get(work.initiatorNOC));
SuccessOrExit(err = decryptedDataTlvReader.Get(data.initiatorNOC));

SuccessOrExit(err = decryptedDataTlvReader.Next());
if (TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == kTag_TBEData_SenderICAC)
{
VerifyOrExit(decryptedDataTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE);
SuccessOrExit(err = decryptedDataTlvReader.Get(work.initiatorICAC));
SuccessOrExit(err = decryptedDataTlvReader.Get(data.initiatorICAC));
SuccessOrExit(err = decryptedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBEData_Signature)));
}

// Step 4 - Construct Sigma3 TBS Data
work.msg_r3_signed_len = TLV::EstimateStructOverhead(sizeof(uint16_t), work.initiatorNOC.size(), work.initiatorICAC.size(),
data.msg_r3_signed_len = TLV::EstimateStructOverhead(sizeof(uint16_t), data.initiatorNOC.size(), data.initiatorICAC.size(),
kP256_PublicKey_Length, kP256_PublicKey_Length);

VerifyOrExit(work.msg_R3_Signed.Alloc(work.msg_r3_signed_len), err = CHIP_ERROR_NO_MEMORY);
VerifyOrExit(data.msg_R3_Signed.Alloc(data.msg_r3_signed_len), err = CHIP_ERROR_NO_MEMORY);

SuccessOrExit(err = ConstructTBSData(work.initiatorNOC, work.initiatorICAC, ByteSpan(mRemotePubKey, mRemotePubKey.Length()),
SuccessOrExit(err = ConstructTBSData(data.initiatorNOC, data.initiatorICAC, ByteSpan(mRemotePubKey, mRemotePubKey.Length()),
ByteSpan(mEphemeralKey->Pubkey(), mEphemeralKey->Pubkey().Length()),
work.msg_R3_Signed.Get(), work.msg_r3_signed_len));
data.msg_R3_Signed.Get(), data.msg_r3_signed_len));

VerifyOrExit(TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == kTag_TBEData_Signature,
err = CHIP_ERROR_INVALID_TLV_TAG);
VerifyOrExit(work.tbsData3Signature.Capacity() >= decryptedDataTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT);
work.tbsData3Signature.SetLength(decryptedDataTlvReader.GetLength());
SuccessOrExit(err = decryptedDataTlvReader.GetBytes(work.tbsData3Signature.Bytes(), work.tbsData3Signature.Length()));
VerifyOrExit(data.tbsData3Signature.Capacity() >= decryptedDataTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT);
data.tbsData3Signature.SetLength(decryptedDataTlvReader.GetLength());
SuccessOrExit(err = decryptedDataTlvReader.GetBytes(data.tbsData3Signature.Bytes(), data.tbsData3Signature.Length()));

// Prepare for Step 5/6
{
MutableByteSpan fabricRCAC{ work.rootCertBuf };
MutableByteSpan fabricRCAC{ data.rootCertBuf };
SuccessOrExit(err = mFabricsTable->FetchRootCert(mFabricIndex, fabricRCAC));
work.fabricRCAC = fabricRCAC;
data.fabricRCAC = fabricRCAC;
// TODO probably should make SetEffectiveTime static and call closer to VerifyCredentials
SuccessOrExit(err = SetEffectiveTime());
}

// Copy remaining needed data into work structure
{
work.validContext = mValidContext;
data.validContext = mValidContext;

// initiatorNOC and initiatorICAC are spans into msg_R3_Encrypted
// which is going away, so to save memory, redirect them to their
// copies in msg_R3_signed, which is staying around
TLV::TLVReader signedDataTlvReader;
signedDataTlvReader.Init(work.msg_R3_Signed.Get(), work.msg_r3_signed_len);
signedDataTlvReader.Init(data.msg_R3_Signed.Get(), data.msg_r3_signed_len);
SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_Structure, TLV::AnonymousTag()));
SuccessOrExit(err = signedDataTlvReader.EnterContainer(containerType));

SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBSData_SenderNOC)));
SuccessOrExit(err = signedDataTlvReader.Get(work.initiatorNOC));
SuccessOrExit(err = signedDataTlvReader.Get(data.initiatorNOC));

if (!work.initiatorICAC.empty())
if (!data.initiatorICAC.empty())
{
SuccessOrExit(err = signedDataTlvReader.Next(TLV::kTLVType_ByteString, TLV::ContextTag(kTag_TBSData_SenderICAC)));
SuccessOrExit(err = signedDataTlvReader.Get(work.initiatorICAC));
SuccessOrExit(err = signedDataTlvReader.Get(data.initiatorICAC));
}
}

SuccessOrExit(err = DeviceLayer::PlatformMgr().ScheduleBackgroundWork(
[](intptr_t arg) { HandleSigma3b(*reinterpret_cast<HandleSigma3Work *>(arg)); },
reinterpret_cast<intptr_t>(&work)));
workPtr = nullptr; // scheduling succeeded, so don't delete
SuccessOrExit(err = helper->ScheduleWork());
mHandleSigma3Helper = helper;
mExchangeCtxt->WillSendMessage();
mState = State::kHandleSigma3Pending;
}

exit:
Platform::Delete(workPtr);

if (err != CHIP_NO_ERROR)
{
SendStatusReport(mExchangeCtxt, kProtocolCodeInvalidParam);
Expand All @@ -1668,7 +1648,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg)
return err;
}

void CASESession::HandleSigma3b(HandleSigma3Work & work)
CHIP_ERROR CASESession::HandleSigma3b(HandleSigma3Data & data, bool & cancel)
{
CHIP_ERROR err = CHIP_NO_ERROR;

Expand All @@ -1678,9 +1658,9 @@ void CASESession::HandleSigma3b(HandleSigma3Work & work)
CompressedFabricId unused;
FabricId initiatorFabricId;
P256PublicKey initiatorPublicKey;
SuccessOrExit(err = FabricTable::VerifyCredentials(work.initiatorNOC, work.initiatorICAC, work.fabricRCAC, work.validContext,
unused, initiatorFabricId, work.initiatorNodeId, initiatorPublicKey));
VerifyOrExit(work.fabricId == initiatorFabricId, err = CHIP_ERROR_INVALID_CASE_PARAMETER);
SuccessOrExit(err = FabricTable::VerifyCredentials(data.initiatorNOC, data.initiatorICAC, data.fabricRCAC, data.validContext,
unused, initiatorFabricId, data.initiatorNodeId, initiatorPublicKey));
VerifyOrExit(data.fabricId == initiatorFabricId, err = CHIP_ERROR_INVALID_CASE_PARAMETER);

// TODO - Validate message signature prior to validating the received operational credentials.
// The op cert check requires traversal of cert chain, that is a more expensive operation.
Expand All @@ -1692,46 +1672,27 @@ void CASESession::HandleSigma3b(HandleSigma3Work & work)
{
P256PublicKeyHSM initiatorPublicKeyHSM;
memcpy(Uint8::to_uchar(initiatorPublicKeyHSM), initiatorPublicKey.Bytes(), initiatorPublicKey.Length());
SuccessOrExit(err = initiatorPublicKeyHSM.ECDSA_validate_msg_signature(work.msg_R3_Signed.Get(), work.msg_r3_signed_len,
work.tbsData3Signature));
SuccessOrExit(err = initiatorPublicKeyHSM.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len,
data.tbsData3Signature));
}
#else
SuccessOrExit(err = initiatorPublicKey.ECDSA_validate_msg_signature(work.msg_R3_Signed.Get(), work.msg_r3_signed_len,
work.tbsData3Signature));
SuccessOrExit(err = initiatorPublicKey.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len,
data.tbsData3Signature));
#endif

exit:
work.status = err;

auto err2 = DeviceLayer::PlatformMgr().ScheduleWork(
[](intptr_t arg) {
auto & work2 = *reinterpret_cast<HandleSigma3Work *>(arg);
work2.session->HandleSigma3c(work2);
},
reinterpret_cast<intptr_t>(&work));

if (err2 != CHIP_NO_ERROR)
{
Platform::Delete(&work); // scheduling failed, so delete
}
return err;
}

CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Work & work)
CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Data & data, CHIP_ERROR status)
{
CHIP_ERROR err = CHIP_NO_ERROR;
bool ignoreFailure = true;
CHIP_ERROR err = CHIP_NO_ERROR;

// Special case: if for whatever reason not in expected state or sequence,
// don't do anything, including sending a status report or aborting the
// pending establish.
VerifyOrExit(mState == State::kHandleSigma3Pending, err = CHIP_ERROR_INCORRECT_STATE);
VerifyOrExit(mSequence == work.sequence, err = CHIP_ERROR_INCORRECT_STATE);

ignoreFailure = false;

SuccessOrExit(err = work.status);
SuccessOrExit(err = status);

mPeerNodeId = work.initiatorNodeId;
mPeerNodeId = data.initiatorNodeId;

{
MutableByteSpan messageDigestSpan(mMessageDigest);
Expand All @@ -1740,7 +1701,7 @@ CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Work & work)

// Retrieve peer CASE Authenticated Tags (CATs) from peer's NOC.
{
SuccessOrExit(err = ExtractCATsFromOpCert(work.initiatorNOC, mPeerCATs));
SuccessOrExit(err = ExtractCATsFromOpCert(data.initiatorNOC, mPeerCATs));
}

if (mSessionResumptionStorage != nullptr)
Expand All @@ -1758,9 +1719,9 @@ CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Work & work)
Finish();

exit:
Platform::Delete(&work);
mHandleSigma3Helper.reset();

if (err != CHIP_NO_ERROR && !ignoreFailure)
if (err != CHIP_NO_ERROR)
{
SendStatusReport(mExchangeCtxt, kProtocolCodeInvalidParam);
// Abort the pending establish, which is normally done by CASESession::OnMessageReceived,
Expand Down
11 changes: 4 additions & 7 deletions src/protocols/secure_channel/CASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,10 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler,
static CHIP_ERROR SendSigma3b(SendSigma3Data & data, bool & cancel);
CHIP_ERROR SendSigma3c(SendSigma3Data & data, CHIP_ERROR status);

struct HandleSigma3Work;
struct HandleSigma3Data;
CHIP_ERROR HandleSigma3a(System::PacketBufferHandle && msg);
static void HandleSigma3b(HandleSigma3Work & work);
CHIP_ERROR HandleSigma3c(HandleSigma3Work & work);
static CHIP_ERROR HandleSigma3b(HandleSigma3Data & data, bool & cancel);
CHIP_ERROR HandleSigma3c(HandleSigma3Data & data, CHIP_ERROR status);

CHIP_ERROR SendSigma2Resume();

Expand Down Expand Up @@ -303,13 +303,10 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler,
// Sigma1 initiator random, maintained to be reused post-Sigma1, such as when generating Sigma2 S2RK key
uint8_t mInitiatorRandom[kSigmaParamRandomNumberSize];

// Sequence number used to coordinate foreground/background work for a
// particular session establishment.
int mSequence = 0;

template <class DATA>
class WorkHelper;
Platform::SharedPtr<WorkHelper<SendSigma3Data>> mSendSigma3Helper;
Platform::SharedPtr<WorkHelper<HandleSigma3Data>> mHandleSigma3Helper;

State mState;

Expand Down

0 comments on commit 13d9661

Please sign in to comment.