Skip to content

Commit

Permalink
Cleanup in aisle CASESession (#26339)
Browse files Browse the repository at this point in the history
* Cleanup in aisle CASESession

* Reduce nesting in function
  • Loading branch information
mlepage-google authored and pull[bot] committed Apr 25, 2024
1 parent 5cfb36e commit 1567327
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/credentials/FabricTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ FabricTable::AddOrUpdateInner(FabricIndex fabricIndex, bool isAddition, Crypto::
}
else
{
// Initialization for Upating fabric: setting up a shadow fabricInfo
// Initialization for Updating fabric: setting up a shadow fabricInfo
const FabricInfo * existingFabric = FindFabricWithIndex(fabricIndex);
VerifyOrReturnError(existingFabric != nullptr, CHIP_ERROR_INTERNAL);

Expand Down
4 changes: 2 additions & 2 deletions src/credentials/FabricTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class DLL_EXPORT FabricInfo

friend class FabricTable;

protected:
private:
struct InitParams
{
NodeId nodeId = kUndefinedNodeId;
Expand Down Expand Up @@ -1098,7 +1098,7 @@ class DLL_EXPORT FabricTable
*/
const FabricInfo * GetShadowPendingFabricEntry() const { return HasPendingFabricUpdate() ? &mPendingFabric : nullptr; }

// Returns true if we have a shadow entry pending for a fabruc update.
// Returns true if we have a shadow entry pending for a fabric update.
bool HasPendingFabricUpdate() const
{
return mPendingFabric.IsInitialized() &&
Expand Down
118 changes: 61 additions & 57 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,19 @@ class CASESession::WorkHelper
// The `status` value is the result of the work callback (called beforehand).
typedef CHIP_ERROR (CASESession::*AfterWorkCallback)(DATA & data, CHIP_ERROR status);

// Create a work helper using the specified session, work callback, after work callback, and data (template arg).
// Lifetime is not managed, see `Create` for that option.
WorkHelper(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) :
mSession(&session), mWorkCallback(workCallback), mAfterWorkCallback(afterWorkCallback)
{}

public:
// Create a work helper using the specified session, work callback, after work callback, and data (template arg).
// Lifetime is managed by sharing between the caller (typically the session) and the helper itself (while work is scheduled).
static Platform::SharedPtr<WorkHelper> Create(CASESession & session, WorkCallback workCallback,
AfterWorkCallback afterWorkCallback)
{
auto ptr = Platform::MakeShared<WorkHelper>(session, workCallback, afterWorkCallback);
struct EnableShared : public WorkHelper
{
EnableShared(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) :
WorkHelper(session, workCallback, afterWorkCallback)
{}
};
auto ptr = Platform::MakeShared<EnableShared>(session, workCallback, afterWorkCallback);
if (ptr)
{
ptr->mWeakPtr = ptr; // used by `ScheduleWork`
Expand All @@ -173,10 +174,7 @@ class CASESession::WorkHelper
// No scheduling, no outstanding work, no shared lifetime management.
CHIP_ERROR DoWork()
{
if (!mSession || !mWorkCallback || !mAfterWorkCallback)
{
return CHIP_ERROR_INCORRECT_STATE;
}
VerifyOrReturnError(mSession && mWorkCallback && mAfterWorkCallback, CHIP_ERROR_INCORRECT_STATE);
auto * helper = this;
bool cancel = false;
helper->mStatus = helper->mWorkCallback(helper->mData, cancel);
Expand All @@ -187,18 +185,17 @@ class CASESession::WorkHelper
return helper->mStatus;
}

// Schedule the work after configuring the data.
// Schedule the work for later execution.
// If lifetime is managed, the helper shares management while work is outstanding.
CHIP_ERROR ScheduleWork()
{
if (!mSession || !mWorkCallback || !mAfterWorkCallback)
{
return CHIP_ERROR_INCORRECT_STATE;
}
VerifyOrReturnError(mSession && mWorkCallback && mAfterWorkCallback, CHIP_ERROR_INCORRECT_STATE);
// Hold strong ptr while work is outstanding
mStrongPtr = mWeakPtr.lock(); // set in `Create`
auto status = DeviceLayer::PlatformMgr().ScheduleBackgroundWork(WorkHandler, reinterpret_cast<intptr_t>(this));
if (status != CHIP_NO_ERROR)
{
// Release strong ptr since scheduling failed
mStrongPtr.reset();
}
return status;
Expand All @@ -207,32 +204,47 @@ class CASESession::WorkHelper
// Cancel the work, by clearing the associated session.
void CancelWork() { mSession.store(nullptr); }

bool IsCancelled() const { return mSession.load() == nullptr; }

private:
// Create a work helper using the specified session, work callback, after work callback, and data (template arg).
// Lifetime is not managed, see `Create` for that option.
WorkHelper(CASESession & session, WorkCallback workCallback, AfterWorkCallback afterWorkCallback) :
mSession(&session), mWorkCallback(workCallback), mAfterWorkCallback(afterWorkCallback)
{}

// Handler for the work callback.
static void WorkHandler(intptr_t arg)
{
auto * helper = reinterpret_cast<WorkHelper *>(arg);
bool cancel = false;
VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`?
// Hold strong ptr while work is handled
auto strongPtr(std::move(helper->mStrongPtr));
VerifyOrReturn(!helper->IsCancelled());
bool cancel = false;
// Execute callback in background thread; data must be OK with this
helper->mStatus = helper->mWorkCallback(helper->mData, cancel);
VerifyOrExit(!cancel, ;); // canceled by `mWorkCallback`?
VerifyOrExit(helper->mSession.load(), ;); // cancelled by `CancelWork`?
SuccessOrExit(DeviceLayer::PlatformMgr().ScheduleWork(AfterWorkHandler, reinterpret_cast<intptr_t>(helper)));
return;
exit:
helper->mStrongPtr.reset();
VerifyOrReturn(!cancel && !helper->IsCancelled());
// Hold strong ptr while work is outstanding
helper->mStrongPtr.swap(strongPtr);
auto status = DeviceLayer::PlatformMgr().ScheduleWork(AfterWorkHandler, reinterpret_cast<intptr_t>(helper));
if (status != CHIP_NO_ERROR)
{
// Release strong ptr since scheduling failed
helper->mStrongPtr.reset();
}
}

// Handler for the after work callback.
static void AfterWorkHandler(intptr_t arg)
{
// Since this runs in the main Matter thread, the session shouldn't be otherwise used (messages, timers, etc.)
auto * helper = reinterpret_cast<WorkHelper *>(arg);
// Hold strong ptr while work is handled
auto strongPtr(std::move(helper->mStrongPtr));
if (auto * session = helper->mSession.load())
{
// Execute callback in Matter thread; session should be OK with this
(session->*(helper->mAfterWorkCallback))(helper->mData, helper->mStatus);
}
helper->mStrongPtr.reset();
}

private:
Expand Down Expand Up @@ -261,7 +273,7 @@ class CASESession::WorkHelper

struct CASESession::SendSigma3Data
{
std::atomic<FabricIndex> fabricIndex;
FabricIndex fabricIndex;

// Use one or the other
const FabricTable * fabricTable;
Expand Down Expand Up @@ -319,7 +331,6 @@ void CASESession::Clear()
// Cancel any outstanding work.
if (mSendSigma3Helper)
{
mSendSigma3Helper->mData.fabricIndex = kUndefinedFabricIndex;
mSendSigma3Helper->CancelWork();
mSendSigma3Helper.reset();
}
Expand Down Expand Up @@ -1359,40 +1370,37 @@ CHIP_ERROR CASESession::SendSigma3a()

CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel)
{
CHIP_ERROR err = CHIP_NO_ERROR;

// Generate a signature
if (data.keystore != nullptr)
{
// Recommended case: delegate to operational keystore
err = data.keystore->SignWithOpKeypair(data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len },
data.tbsData3Signature);
ReturnErrorOnFailure(data.keystore->SignWithOpKeypair(
data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, data.tbsData3Signature));
}
else
{
// Legacy case: delegate to fabric table fabric info
err = data.fabricTable->SignWithOpKeypair(data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len },
data.tbsData3Signature);
ReturnErrorOnFailure(data.fabricTable->SignWithOpKeypair(
data.fabricIndex, ByteSpan{ data.msg_R3_Signed.Get(), data.msg_r3_signed_len }, data.tbsData3Signature));
}
SuccessOrExit(err);

// Prepare Sigma3 TBE Data Blob
data.msg_r3_encrypted_len =
TLV::EstimateStructOverhead(data.nocCert.size(), data.icaCert.size(), data.tbsData3Signature.Length());

VerifyOrExit(data.msg_R3_Encrypted.Alloc(data.msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES),
err = CHIP_ERROR_NO_MEMORY);
VerifyOrReturnError(data.msg_R3_Encrypted.Alloc(data.msg_r3_encrypted_len + CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES),
CHIP_ERROR_NO_MEMORY);

{
TLV::TLVWriter tlvWriter;
TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified;

tlvWriter.Init(data.msg_R3_Encrypted.Get(), data.msg_r3_encrypted_len);
SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType));
SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), data.nocCert));
ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag(), TLV::kTLVType_Structure, outerContainerType));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderNOC), data.nocCert));
if (!data.icaCert.empty())
{
SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), data.icaCert));
ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kTag_TBEData_SenderICAC), data.icaCert));
}

// We are now done with ICAC and NOC certs so we can release the memory.
Expand All @@ -1404,15 +1412,14 @@ CHIP_ERROR CASESession::SendSigma3b(SendSigma3Data & data, bool & cancel)
data.nocCert = MutableByteSpan{};
}

SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), data.tbsData3Signature.ConstBytes(),
static_cast<uint32_t>(data.tbsData3Signature.Length())));
SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType));
SuccessOrExit(err = tlvWriter.Finalize());
ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(kTag_TBEData_Signature), data.tbsData3Signature.ConstBytes(),
static_cast<uint32_t>(data.tbsData3Signature.Length())));
ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriter.Finalize());
data.msg_r3_encrypted_len = static_cast<size_t>(tlvWriter.GetLengthWritten());
}

exit:
return err;
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::SendSigma3c(SendSigma3Data & data, CHIP_ERROR status)
Expand Down Expand Up @@ -1650,17 +1657,15 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg)

CHIP_ERROR CASESession::HandleSigma3b(HandleSigma3Data & data, bool & cancel)
{
CHIP_ERROR err = CHIP_NO_ERROR;

// Step 5/6
// Validate initiator identity located in msg->Start()
// Constructing responder identity
CompressedFabricId unused;
FabricId initiatorFabricId;
P256PublicKey initiatorPublicKey;
SuccessOrExit(err = FabricTable::VerifyCredentials(data.initiatorNOC, data.initiatorICAC, data.fabricRCAC, data.validContext,
unused, initiatorFabricId, data.initiatorNodeId, initiatorPublicKey));
VerifyOrExit(data.fabricId == initiatorFabricId, err = CHIP_ERROR_INVALID_CASE_PARAMETER);
ReturnErrorOnFailure(FabricTable::VerifyCredentials(data.initiatorNOC, data.initiatorICAC, data.fabricRCAC, data.validContext,
unused, initiatorFabricId, data.initiatorNodeId, initiatorPublicKey));
VerifyOrReturnError(data.fabricId == initiatorFabricId, CHIP_ERROR_INVALID_CASE_PARAMETER);

// TODO - Validate message signature prior to validating the received operational credentials.
// The op cert check requires traversal of cert chain, that is a more expensive operation.
Expand All @@ -1672,16 +1677,15 @@ CHIP_ERROR CASESession::HandleSigma3b(HandleSigma3Data & data, bool & cancel)
{
P256PublicKeyHSM initiatorPublicKeyHSM;
memcpy(Uint8::to_uchar(initiatorPublicKeyHSM), initiatorPublicKey.Bytes(), initiatorPublicKey.Length());
SuccessOrExit(err = initiatorPublicKeyHSM.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len,
data.tbsData3Signature));
ReturnErrorOnFailure(initiatorPublicKeyHSM.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len,
data.tbsData3Signature));
}
#else
SuccessOrExit(err = initiatorPublicKey.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len,
data.tbsData3Signature));
ReturnErrorOnFailure(
initiatorPublicKey.ECDSA_validate_msg_signature(data.msg_R3_Signed.Get(), data.msg_r3_signed_len, data.tbsData3Signature));
#endif

exit:
return err;
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::HandleSigma3c(HandleSigma3Data & data, CHIP_ERROR status)
Expand Down

0 comments on commit 1567327

Please sign in to comment.