From 63e5730a90a4d89e91d328fb631b1322822f6bdf Mon Sep 17 00:00:00 2001 From: Joseph Kelly <41064086+jpk233@users.noreply.github.com> Date: Wed, 28 Jul 2021 10:20:48 -0400 Subject: [PATCH] CASE spec refresh (#8137) * CASE spec refresh * Remove TrustedRootID from Sigma messages. Replce HKDF algorithm with HMAC-SHA256 Algorithm in GenerateDestinationID method. * Replace TrustedRootId parameter with an index value for the OperationalCredentialSet. Added CredentialsIndex parameter to AdminPairingTable's GetCredentials method. Removed FabricSecret. Removed kIPKInfo. Removed Deprecated ComputeIPK method - Replaced it with dummy RetrieveIPK method. * cast 1 to uint8_t to avoid compiler errors on other platforms * Fix uint8_t conversion * Replace CASETLVTags with TLV::ContextTag ID, switch unsafe statement to an actual OperationalCredentialSet method to retrieve a TrustedRootId, renamed kMAX_Hash_Length refs to kSHA256_Hash_Length to match the Spec. * Add GenerateDestinationID Test * Added ReleaseLastCert method to OperationalCredentialSet Class. This method will release the last certificate data in the set. Added call to ReleaseLastCert method during CASESession after a successful validation of the NOC certificate. Updated CASE Unit Tests to force 3 certificates maximum. This will guarantee that for the tests to work fine, CASESession must release the NOC certificate every time. CASESession: change LoadCerts to LoadCert - Only NOC is transferred during CASE Protocol. No need to handle ICA * remove fabricId parameters/methods from CASESession. Retrieve it from the NOC instead. * Updated ReleaseLastCert method from CHIPCertificateSet class: not using a const ChipCertificateData type anymore to avoid confusion. Removed redundant comment from ReleaseLastCert method. Wrote some TODO items to update OperationalCredentialSet class in order to work with size_t variable counters: useful to index more than 255 Credentials. Update DestinationIDGeneration Test to use Spec's test vectors. Added static assert to check TBEData2 and TBEData3 Nonce Lengths : they must match. Added new method to Estimate TLV Struct overhead. Updated GenerateDestinationID to be stateless: now directly accesses the inputs as raw memory buffers. Updated HandleSigma methods to handle TLV tags sequentially. Removed redundant GetLength and GetType calls during TLV Reads. Renamed encryptionKeyId to initiatorSessionId and responderSessionId. Fixed typo in ENABLE_HSM_CASE_EPHEMERAL_KEY macro. * Update casesession with latest comments * Trigger Build * added IPK to CASESession Serializable data: IPK is needed to DeriveSecureSession, so it needs to be stored. Updated FindDestinationIDCandidate loop variables to size_t. Also updated names to reflect what they are indexing. * Added CASESession protected API (Virtual) to get the IPK List Entries. They can be overridden by the Unit Tests in order to feed in the test IPK vectors. * Update for fabric class rename * Restyling * Replace mCredentialsIndex pointer to an actual uint8_t variable. Update GetCertFabricId method to use UINT64_MAX macro as an invalid reference, and added a final sanity check to see if there were actually any fabricIds present in the certificate. Added doxygen to GetCredentials method. * Restyle fix --- src/channel/Channel.h | 8 + src/channel/ChannelContext.cpp | 6 +- src/controller/CHIPDevice.cpp | 3 +- src/controller/CHIPDevice.h | 22 +- src/controller/CHIPDeviceController.cpp | 17 +- src/controller/CHIPDeviceController.h | 1 + src/credentials/CHIPCert.cpp | 36 + src/credentials/CHIPCert.h | 7 + .../CHIPOperationalCredentials.cpp | 19 + src/credentials/CHIPOperationalCredentials.h | 8 +- src/crypto/hsm/CHIPCryptoPALHsm_config.h | 2 +- src/protocols/secure_channel/CASEServer.cpp | 19 +- src/protocols/secure_channel/CASEServer.h | 4 +- src/protocols/secure_channel/CASESession.cpp | 775 ++++++++---------- src/protocols/secure_channel/CASESession.h | 58 +- .../secure_channel/tests/TestCASESession.cpp | 116 ++- src/transport/FabricTable.cpp | 3 +- src/transport/FabricTable.h | 13 +- src/transport/PeerConnectionState.h | 1 + 19 files changed, 611 insertions(+), 507 deletions(-) diff --git a/src/channel/Channel.h b/src/channel/Channel.h index 2d6821ec1ef711..d993caa4c982df 100644 --- a/src/channel/Channel.h +++ b/src/channel/Channel.h @@ -107,6 +107,13 @@ class ChannelBuilder return *this; } + uint8_t GetOperationalCredentialSetIndex() const { return mCaseParameters.mOperationalCredentialSetIndex; } + ChannelBuilder & SetOperationalCredentialSetIndex(uint8_t operationalCredentialSetIndex) + { + mCaseParameters.mOperationalCredentialSetIndex = operationalCredentialSetIndex; + return *this; + } + Optional GetForcePeerAddress() const { return mForcePeerAddr; } ChannelBuilder & SetForcePeerAddress(Inet::IPAddress peerAddr) { @@ -121,6 +128,7 @@ class ChannelBuilder { uint16_t mPeerKeyId; Credentials::OperationalCredentialSet * mOperationalCredentialSet; + uint8_t mOperationalCredentialSetIndex; } mCaseParameters; Optional mForcePeerAddr; diff --git a/src/channel/ChannelContext.cpp b/src/channel/ChannelContext.cpp index 9363eb2bc336bc..17367f42e319c1 100644 --- a/src/channel/ChannelContext.cpp +++ b/src/channel/ChannelContext.cpp @@ -264,9 +264,9 @@ void ChannelContext::EnterCasePairingState() // TODO: currently only supports IP/UDP paring Transport::PeerAddress addr; addr.SetTransportType(Transport::Type::kUdp).SetIPAddress(prepare.mAddress); - CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession(addr, &prepare.mBuilder.GetOperationalCredentialSet(), - prepare.mBuilder.GetPeerNodeId(), - mExchangeManager->GetNextKeyId(), ctxt, this); + CHIP_ERROR err = prepare.mCasePairingSession->EstablishSession( + addr, &prepare.mBuilder.GetOperationalCredentialSet(), prepare.mBuilder.GetOperationalCredentialSetIndex(), + prepare.mBuilder.GetPeerNodeId(), mExchangeManager->GetNextKeyId(), ctxt, this); if (err != CHIP_NO_ERROR) { ExitCasePairingState(); diff --git a/src/controller/CHIPDevice.cpp b/src/controller/CHIPDevice.cpp index ceef00577aae94..1325eb5aa08684 100644 --- a/src/controller/CHIPDevice.cpp +++ b/src/controller/CHIPDevice.cpp @@ -538,7 +538,8 @@ CHIP_ERROR Device::WarmupCASESession() mLocalMessageCounter = 0; mPeerMessageCounter = 0; - ReturnErrorOnFailure(mCASESession.EstablishSession(mDeviceAddress, mCredentials, mDeviceId, keyID, exchange, this)); + ReturnErrorOnFailure( + mCASESession.EstablishSession(mDeviceAddress, mCredentials, mCredentialsIndex, mDeviceId, keyID, exchange, this)); mState = ConnectionState::Connecting; diff --git a/src/controller/CHIPDevice.h b/src/controller/CHIPDevice.h index ca2fea5046056a..7f80f79d270e27 100644 --- a/src/controller/CHIPDevice.h +++ b/src/controller/CHIPDevice.h @@ -83,6 +83,7 @@ struct ControllerDeviceInitParams Inet::InetLayer * inetLayer = nullptr; PersistentStorageDelegate * storageDelegate = nullptr; Credentials::OperationalCredentialSet * credentials = nullptr; + uint8_t credentialsIndex = 0; SessionIDAllocator * idAllocator = nullptr; #if CONFIG_NETWORK_LAYER_BLE Ble::BleLayer * bleLayer = nullptr; @@ -180,15 +181,16 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta */ void Init(ControllerDeviceInitParams params, uint16_t listenPort, FabricIndex fabric) { - mTransportMgr = params.transportMgr; - mSessionManager = params.sessionMgr; - mExchangeMgr = params.exchangeMgr; - mInetLayer = params.inetLayer; - mListenPort = listenPort; - mFabricIndex = fabric; - mStorageDelegate = params.storageDelegate; - mCredentials = params.credentials; - mIDAllocator = params.idAllocator; + mTransportMgr = params.transportMgr; + mSessionManager = params.sessionMgr; + mExchangeMgr = params.exchangeMgr; + mInetLayer = params.inetLayer; + mListenPort = listenPort; + mFabricIndex = fabric; + mStorageDelegate = params.storageDelegate; + mCredentials = params.credentials; + mCredentialsIndex = params.credentialsIndex; + mIDAllocator = params.idAllocator; #if CONFIG_NETWORK_LAYER_BLE mBleLayer = params.bleLayer; #endif @@ -481,6 +483,8 @@ class DLL_EXPORT Device : public Messaging::ExchangeDelegate, public SessionEsta CASESession mCASESession; Credentials::OperationalCredentialSet * mCredentials = nullptr; + // TODO: Switch to size_t whenever OperationalCredentialSet Class is updated to support more then 255 credentials per controller + uint8_t mCredentialsIndex = 0; PersistentStorageDelegate * mStorageDelegate = nullptr; diff --git a/src/controller/CHIPDeviceController.cpp b/src/controller/CHIPDeviceController.cpp index db5da780b890d6..00e2daefdb7d25 100644 --- a/src/controller/CHIPDeviceController.cpp +++ b/src/controller/CHIPDeviceController.cpp @@ -319,7 +319,7 @@ CHIP_ERROR DeviceController::LoadLocalCredentials(Transport::FabricInfo * fabric } ChipLogProgress(Controller, "Generating credentials"); - ReturnErrorOnFailure(fabric->GetCredentials(mCredentials, mCertificates, mRootKeyId)); + ReturnErrorOnFailure(fabric->GetCredentials(mCredentials, mCertificates, mRootKeyId, mCredentialsIndex)); ChipLogProgress(Controller, "Loaded credentials successfully"); return CHIP_NO_ERROR; @@ -796,13 +796,14 @@ void DeviceController::OnNodeIdResolutionFailed(const chip::PeerId & peer, CHIP_ ControllerDeviceInitParams DeviceController::GetControllerDeviceInitParams() { return ControllerDeviceInitParams{ - .transportMgr = mTransportMgr, - .sessionMgr = mSessionMgr, - .exchangeMgr = mExchangeMgr, - .inetLayer = mInetLayer, - .storageDelegate = mStorageDelegate, - .credentials = &mCredentials, - .idAllocator = &mIDAllocator, + .transportMgr = mTransportMgr, + .sessionMgr = mSessionMgr, + .exchangeMgr = mExchangeMgr, + .inetLayer = mInetLayer, + .storageDelegate = mStorageDelegate, + .credentials = &mCredentials, + .credentialsIndex = mCredentialsIndex, + .idAllocator = &mIDAllocator, }; } diff --git a/src/controller/CHIPDeviceController.h b/src/controller/CHIPDeviceController.h index a39802af06f7d0..eef7d59da52b2a 100644 --- a/src/controller/CHIPDeviceController.h +++ b/src/controller/CHIPDeviceController.h @@ -344,6 +344,7 @@ class DLL_EXPORT DeviceController : public Messaging::ExchangeDelegate, Credentials::ChipCertificateSet mCertificates; Credentials::OperationalCredentialSet mCredentials; Credentials::CertificateKeyId mRootKeyId; + uint8_t mCredentialsIndex; SessionIDAllocator mIDAllocator; diff --git a/src/credentials/CHIPCert.cpp b/src/credentials/CHIPCert.cpp index 9cb33df80b943f..dec30212b3d5bf 100644 --- a/src/credentials/CHIPCert.cpp +++ b/src/credentials/CHIPCert.cpp @@ -310,6 +310,17 @@ CHIP_ERROR ChipCertificateSet::LoadCerts(TLVReader & reader, BitFlags 0) ? &mCerts[mCertCount - 1] : nullptr; + VerifyOrReturnError(lastCert != nullptr, CHIP_ERROR_INTERNAL); + + lastCert->~ChipCertificateData(); + --mCertCount; + + return CHIP_NO_ERROR; +} + const ChipCertificateData * ChipCertificateSet::FindCert(const CertificateKeyId & subjectKeyId) const { for (uint8_t i = 0; i < mCertCount; i++) @@ -744,6 +755,31 @@ CHIP_ERROR ChipDN::GetCertChipId(uint64_t & chipId) const return CHIP_NO_ERROR; } +CHIP_ERROR ChipDN::GetCertFabricId(uint64_t & fabricId) const +{ + uint8_t rdnCount = RDNCount(); + + fabricId = UINT64_MAX; + + for (uint8_t i = 0; i < rdnCount; i++) + { + switch (rdn[i].mAttrOID) + { + case kOID_AttributeType_ChipFabricId: + // Ensure only one FabricID RDN present, since start value is UINT64_MAX, which is reserved and never seen. + VerifyOrReturnError(fabricId == UINT64_MAX, CHIP_ERROR_WRONG_CERT_TYPE); + + fabricId = rdn[i].mChipVal; + break; + default: + break; + } + } + + VerifyOrReturnError(fabricId != UINT64_MAX, CHIP_ERROR_WRONG_CERT_TYPE); + return CHIP_NO_ERROR; +} + bool ChipDN::IsEqual(const ChipDN & other) const { bool res = true; diff --git a/src/credentials/CHIPCert.h b/src/credentials/CHIPCert.h index 18ab7c5214b3ba..f29d79086bb987 100644 --- a/src/credentials/CHIPCert.h +++ b/src/credentials/CHIPCert.h @@ -252,6 +252,11 @@ class ChipDN **/ CHIP_ERROR GetCertChipId(uint64_t & certId) const; + /** + * @brief Retrieve the Fabric ID of a CHIP certificate. + **/ + CHIP_ERROR GetCertFabricId(uint64_t & fabricId) const; + bool IsEqual(const ChipDN & other) const; /** @@ -463,6 +468,8 @@ class DLL_EXPORT ChipCertificateSet **/ CHIP_ERROR LoadCerts(chip::TLV::TLVReader & reader, BitFlags decodeFlags); + CHIP_ERROR ReleaseLastCert(); + /** * @brief Find certificate in the set. * diff --git a/src/credentials/CHIPOperationalCredentials.cpp b/src/credentials/CHIPOperationalCredentials.cpp index c9054faf06fcb7..19d9c7da5ad465 100644 --- a/src/credentials/CHIPOperationalCredentials.cpp +++ b/src/credentials/CHIPOperationalCredentials.cpp @@ -363,5 +363,24 @@ P256Keypair * OperationalCredentialSet::GetNodeKeypairAt(const CertificateKeyId return nullptr; } +const ChipCertificateData * OperationalCredentialSet::GetRootCertificate(const CertificateKeyId & trustedRootId) const +{ + for (size_t certChainIdx = 0; certChainIdx < mOpCredCount; certChainIdx++) + { + ChipCertificateSet * certSet = &mOpCreds[certChainIdx]; + + for (size_t ipkIdx = 0; ipkIdx < certSet->GetCertCount(); ipkIdx++) + { + const ChipCertificateData * cert = &certSet->GetCertSet()[ipkIdx]; + if (cert->mCertFlags.Has(CertFlags::kIsTrustAnchor) && cert->mAuthKeyId.data_equal(trustedRootId)) + { + return cert; + } + } + } + + return nullptr; +} + } // namespace Credentials } // namespace chip diff --git a/src/credentials/CHIPOperationalCredentials.h b/src/credentials/CHIPOperationalCredentials.h index f8cea77eda6a10..c9a2caa598fb5d 100644 --- a/src/credentials/CHIPOperationalCredentials.h +++ b/src/credentials/CHIPOperationalCredentials.h @@ -224,7 +224,7 @@ class DLL_EXPORT OperationalCredentialSet P256ECDSASignature & out_signature); /** - * @return A pointer to device credentials (in x509 format). + * @return A pointer to device credentials (in chip format). **/ const uint8_t * GetDevOpCred(const CertificateKeyId & trustedRootId) const { @@ -260,8 +260,11 @@ class DLL_EXPORT OperationalCredentialSet CHIP_ERROR SetDevOpCredKeypair(const CertificateKeyId & trustedRootId, P256Keypair * newKeypair); + const ChipCertificateData * GetRootCertificate(const CertificateKeyId & trustedRootId) const; + private: - ChipCertificateSet * mOpCreds; /**< Pointer to an array of certificate data. */ + ChipCertificateSet * mOpCreds; /**< Pointer to an array of certificate data. */ + // TODO: switch mOpCredCount var type to size_t in order to allow more than 255 credentials per controller. uint8_t mOpCredCount; /**< Number of certificates in mOpCreds array. We maintain the invariant that all the slots at indices less than @@ -276,6 +279,7 @@ class DLL_EXPORT OperationalCredentialSet NodeKeypairMap mDeviceOpCredKeypair[kOperationalCredentialsMax]; uint8_t mDeviceOpCredKeypairCount; + // TODO: Remove TrustedRootId indexing - Replace it with size_t index. const NodeCredential * GetNodeCredentialAt(const CertificateKeyId & trustedRootId) const; P256Keypair * GetNodeKeypairAt(const CertificateKeyId & trustedRootId); }; diff --git a/src/crypto/hsm/CHIPCryptoPALHsm_config.h b/src/crypto/hsm/CHIPCryptoPALHsm_config.h index 73d435335a6ac9..8908192b4f3df0 100644 --- a/src/crypto/hsm/CHIPCryptoPALHsm_config.h +++ b/src/crypto/hsm/CHIPCryptoPALHsm_config.h @@ -58,7 +58,7 @@ #if ((CHIP_CRYPTO_HSM) && (ENABLE_HSM_GENERATE_EC_KEY)) #define ENABLE_HSM_EC_KEY -#define ENABLE_HSM_CASE_EPHERMAL_KEY +#define ENABLE_HSM_CASE_EPHEMERAL_KEY #define ENABLE_HSM_CASE_OPS_KEY #endif diff --git a/src/protocols/secure_channel/CASEServer.cpp b/src/protocols/secure_channel/CASEServer.cpp index 340ef0cb89a514..21eec48890de4a 100644 --- a/src/protocols/secure_channel/CASEServer.cpp +++ b/src/protocols/secure_channel/CASEServer.cpp @@ -45,7 +45,7 @@ CHIP_ERROR CASEServer::ListenForSessionEstablishment(Messaging::ExchangeManager Cleanup(); - ReturnErrorOnFailure(mPairingSession.MessageDispatch().Init(transportMgr)); + ReturnErrorOnFailure(GetSession().MessageDispatch().Init(transportMgr)); return CHIP_NO_ERROR; } @@ -70,15 +70,16 @@ CHIP_ERROR CASEServer::InitCASEHandshake(Messaging::ExchangeContext * ec) } ReturnErrorCodeIf(fabric == nullptr, CHIP_ERROR_INVALID_ARGUMENT); - ReturnErrorOnFailure(fabric->GetCredentials(mCredentials, mCertificates, mRootKeyId)); + uint8_t credentialsIndex; + ReturnErrorOnFailure(fabric->GetCredentials(mCredentials, mCertificates, mRootKeyId, credentialsIndex)); ReturnErrorOnFailure(mIDAllocator->Allocate(mSessionKeyId)); // Setup CASE state machine using the credentials for the current fabric. - ReturnErrorOnFailure(mPairingSession.ListenForSessionEstablishment(&mCredentials, mSessionKeyId, this)); + ReturnErrorOnFailure(GetSession().ListenForSessionEstablishment(&mCredentials, mSessionKeyId, this)); // Hand over the exchange context to the CASE session. - ec->SetDelegate(&mPairingSession); + ec->SetDelegate(&GetSession()); return CHIP_NO_ERROR; } @@ -95,7 +96,7 @@ CHIP_ERROR CASEServer::OnMessageReceived(Messaging::ExchangeContext * ec, const ChipLogProgress(Inet, "CASE Server disabling CASE session setups"); mExchangeManager->UnregisterUnsolicitedMessageHandlerForType(Protocols::SecureChannel::MsgType::CASE_SigmaR1); - err = mPairingSession.OnMessageReceived(ec, packetHeader, payloadHeader, std::move(payload)); + err = GetSession().OnMessageReceived(ec, packetHeader, payloadHeader, std::move(payload)); SuccessOrExit(err); exit: @@ -116,7 +117,7 @@ void CASEServer::Cleanup() mFabricIndex = Transport::kUndefinedFabricIndex; mCredentials.Release(); mCertificates.Release(); - mPairingSession.Clear(); + GetSession().Clear(); } void CASEServer::OnSessionEstablishmentError(CHIP_ERROR err) @@ -129,11 +130,11 @@ void CASEServer::OnSessionEstablishmentError(CHIP_ERROR err) void CASEServer::OnSessionEstablished() { ChipLogProgress(Inet, "CASE Session established. Setting up the secure channel."); - mSessionMgr->ExpireAllPairings(mPairingSession.PeerConnection().GetPeerNodeId(), mFabricIndex); + mSessionMgr->ExpireAllPairings(GetSession().PeerConnection().GetPeerNodeId(), mFabricIndex); CHIP_ERROR err = mSessionMgr->NewPairing( - Optional::Value(mPairingSession.PeerConnection().GetPeerAddress()), - mPairingSession.PeerConnection().GetPeerNodeId(), &mPairingSession, SecureSession::SessionRole::kResponder, mFabricIndex); + Optional::Value(GetSession().PeerConnection().GetPeerAddress()), + GetSession().PeerConnection().GetPeerNodeId(), &GetSession(), SecureSession::SessionRole::kResponder, mFabricIndex); if (err != CHIP_NO_ERROR) { ChipLogError(Inet, "Failed in setting up secure channel: err %s", ErrorStr(err)); diff --git a/src/protocols/secure_channel/CASEServer.h b/src/protocols/secure_channel/CASEServer.h index fa2f4cc0809350..e66fc2a45f285d 100644 --- a/src/protocols/secure_channel/CASEServer.h +++ b/src/protocols/secure_channel/CASEServer.h @@ -53,10 +53,10 @@ class CASEServer : public SessionEstablishmentDelegate, public Messaging::Exchan Messaging::ExchangeMessageDispatch * GetMessageDispatch(Messaging::ReliableMessageMgr * reliableMessageManager, SecureSessionMgr * sessionMgr) override { - return mPairingSession.GetMessageDispatch(reliableMessageManager, sessionMgr); + return GetSession().GetMessageDispatch(reliableMessageManager, sessionMgr); } - CASESession & GetSession() { return mPairingSession; } + virtual CASESession & GetSession() { return mPairingSession; } private: Messaging::ExchangeManager * mExchangeManager = nullptr; diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index c1eb76536aab0b..207901996350ca 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -36,6 +36,7 @@ #include #include #include +#include #include namespace chip { @@ -44,21 +45,24 @@ using namespace Crypto; using namespace Credentials; using namespace Messaging; -constexpr uint8_t kIPKInfo[] = { 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x50, 0x72, 0x6f, - 0x74, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79 }; - -constexpr uint8_t kKDFSR2Info[] = { 0x53, 0x69, 0x67, 0x6d, 0x61, 0x52, 0x32 }; -constexpr uint8_t kKDFSR3Info[] = { 0x53, 0x69, 0x67, 0x6d, 0x61, 0x52, 0x33 }; +constexpr uint8_t kKDFSR2Info[] = { 0x53, 0x69, 0x67, 0x6d, 0x61, 0x32 }; +constexpr uint8_t kKDFSR3Info[] = { 0x53, 0x69, 0x67, 0x6d, 0x61, 0x33 }; constexpr size_t kKDFInfoLength = sizeof(kKDFSR2Info); constexpr uint8_t kKDFSEInfo[] = { 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x73 }; constexpr size_t kKDFSEInfoLength = sizeof(kKDFSEInfo); -constexpr uint8_t kIVSR2[] = { 0x4e, 0x43, 0x41, 0x53, 0x45, 0x5f, 0x53, 0x69, 0x67, 0x6d, 0x61, 0x52, 0x32 }; -constexpr uint8_t kIVSR3[] = { 0x4e, 0x43, 0x41, 0x53, 0x45, 0x5f, 0x53, 0x69, 0x67, 0x6d, 0x61, 0x52, 0x33 }; -constexpr size_t kIVLength = sizeof(kIVSR2); +constexpr uint8_t kTBEData2_Nonce[] = + /* "NCASE_Sigma2N" */ { 0x4e, 0x43, 0x41, 0x53, 0x45, 0x5f, 0x53, 0x69, 0x67, 0x6d, 0x61, 0x32, 0x4e }; +constexpr uint8_t kTBEData3_Nonce[] = + /* "NCASE_Sigma3N" */ { 0x4e, 0x43, 0x41, 0x53, 0x45, 0x5f, 0x53, 0x69, 0x67, 0x6d, 0x61, 0x33, 0x4e }; +constexpr size_t kTBEDataNonceLength = sizeof(kTBEData2_Nonce); +static_assert(sizeof(kTBEData2_Nonce) == sizeof(kTBEData3_Nonce), "TBEData2_Nonce and TBEData3_Nonce must be same size"); +// TODO: move this constant over to src/crypto/CHIPCryptoPAL.h - name it CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES constexpr size_t kTAGSize = 16; +constexpr size_t kDestinationMessageLen = kSigmaParamRandomNumberSize + kP256_PublicKey_Length + sizeof(FabricId) + sizeof(NodeId); + #ifdef ENABLE_HSM_HKDF using HKDF_sha_crypto = HKDF_shaHSM; #else @@ -70,48 +74,9 @@ using HKDF_sha_crypto = HKDF_sha; // The session establishment fails if the response is not received within timeout window. static constexpr ExchangeContext::Timeout kSigma_Response_Timeout = 10000; -/** - * \brief - * A list of registered packet types a.k.a. TLV context-specific tags to be used during CASE protocol. - */ -enum CASETLVTag : uint8_t -{ - /*! \brief Tag 0. Default stub and end of transmission signal. */ - kUnknown = 0, - /*! \brief Tag 1. The packet contains a random number. */ - kRandom = 1, - /*! \brief Tag 2. The packet contains the Session ID. */ - kSessionID = 2, - /*! \brief Tag 3. The packet contains the Destination ID. */ - kDestinationID = 3, - /*! \brief Tag 4. The packet contains the Initiator's Ephemeral Public Key. */ - kInitiatorEphPubKey = 4, - /*! \brief Tag 5. The packet contains the Initiator's Ephemeral Public Key. */ - kResponderEphPubKey = 5, - /*! \brief Tag 6. The packet contains a Node Operational Certificate. */ - kNOC = 6, - /*! \brief Tag 7. The packet contains a signature. */ - kSignature = 7, - /*! \brief Tag 8. The packet contains a Trusted Root ID. */ - kTrustedRootID = 8, - /*! \brief Tag 9. The packet contains an Encrypted data blob. */ - kEncryptedData = 9, - /*! \brief Tag 10. The packet contains an AEAD Tag. */ - kTag = 10, - // TODO: Remove tag 11 - /*! \brief Tag 11. The packet contains the total number of Trusted Root IDs. */ - kNumberofTrustedRootIDs = 11, -}; - CASESession::CASESession() { mTrustedRootId = CertificateKeyId(); - // dummy initialization REMOVE LATER - for (size_t i = 0; i < mFabricSecret.Capacity(); i++) - { - mFabricSecret[i] = static_cast(i); - } - mFabricSecret.SetLength(mFabricSecret.Capacity()); } CASESession::~CASESession() @@ -128,11 +93,6 @@ void CASESession::Clear() mCommissioningHash.Clear(); mPairingComplete = false; mConnectionState.Reset(); - if (!mTrustedRootId.empty()) - { - chip::Platform::MemoryFree(const_cast(mTrustedRootId.data())); - mTrustedRootId = CertificateKeyId(); - } CloseExchange(); } @@ -191,11 +151,13 @@ CHIP_ERROR CASESession::ToSerializable(CASESessionSerializable & serializable) const NodeId peerNodeId = mConnectionState.GetPeerNodeId(); VerifyOrReturnError(CanCastTo(mSharedSecret.Length()), CHIP_ERROR_INTERNAL); VerifyOrReturnError(CanCastTo(sizeof(mMessageDigest)), CHIP_ERROR_INTERNAL); + VerifyOrReturnError(CanCastTo(sizeof(mIPK)), CHIP_ERROR_INTERNAL); VerifyOrReturnError(CanCastTo(peerNodeId), CHIP_ERROR_INTERNAL); memset(&serializable, 0, sizeof(serializable)); serializable.mSharedSecretLen = static_cast(mSharedSecret.Length()); serializable.mMessageDigestLen = static_cast(sizeof(mMessageDigest)); + serializable.mIPKLen = static_cast(sizeof(mIPK)); serializable.mPairingComplete = (mPairingComplete) ? 1 : 0; serializable.mPeerNodeId = peerNodeId; serializable.mLocalKeyId = mConnectionState.GetLocalKeyID(); @@ -203,6 +165,7 @@ CHIP_ERROR CASESession::ToSerializable(CASESessionSerializable & serializable) memcpy(serializable.mSharedSecret, mSharedSecret, mSharedSecret.Length()); memcpy(serializable.mMessageDigest, mMessageDigest, sizeof(mMessageDigest)); + memcpy(serializable.mIPK, mIPK, sizeof(mIPK)); return CHIP_NO_ERROR; } @@ -213,10 +176,12 @@ CHIP_ERROR CASESession::FromSerializable(const CASESessionSerializable & seriali ReturnErrorOnFailure(mSharedSecret.SetLength(static_cast(serializable.mSharedSecretLen))); VerifyOrReturnError(serializable.mMessageDigestLen <= sizeof(mMessageDigest), CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(serializable.mIPKLen <= sizeof(mIPK), CHIP_ERROR_INVALID_ARGUMENT); memset(mSharedSecret, 0, sizeof(mSharedSecret.Capacity())); memcpy(mSharedSecret, serializable.mSharedSecret, mSharedSecret.Length()); memcpy(mMessageDigest, serializable.mMessageDigest, serializable.mMessageDigestLen); + memcpy(mIPK, serializable.mIPK, serializable.mIPKLen); mConnectionState.SetPeerNodeId(serializable.mPeerNodeId); mConnectionState.SetLocalKeyID(serializable.mLocalKeyId); @@ -262,8 +227,9 @@ CASESession::ListenForSessionEstablishment(OperationalCredentialSet * operationa } CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddress, - OperationalCredentialSet * operationalCredentialSet, NodeId peerNodeId, uint16_t myKeyId, - ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate) + OperationalCredentialSet * operationalCredentialSet, uint8_t opCredSetIndex, + NodeId peerNodeId, uint16_t myKeyId, ExchangeContext * exchangeCtxt, + SessionEstablishmentDelegate * delegate) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -283,6 +249,8 @@ CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddres mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout); mConnectionState.SetPeerAddress(peerAddress); mConnectionState.SetPeerNodeId(peerNodeId); + mTrustedRootId = operationalCredentialSet->GetTrustedRootId(opCredSetIndex); + VerifyOrExit(!mTrustedRootId.empty(), err = CHIP_ERROR_INTERNAL); err = SendSigmaR1(); SuccessOrExit(err); @@ -311,7 +279,7 @@ void CASESession::OnResponseTimeout(ExchangeContext * ec) CHIP_ERROR CASESession::DeriveSecureSession(SecureSession & session, SecureSession::SessionRole role) { - uint16_t saltlen; + size_t saltlen; (void) kKDFSEInfo; (void) kKDFSEInfoLength; @@ -319,16 +287,16 @@ CHIP_ERROR CASESession::DeriveSecureSession(SecureSession & session, SecureSessi VerifyOrReturnError(mPairingComplete, CHIP_ERROR_INCORRECT_STATE); // Generate Salt for Encryption keys - saltlen = kSHA256_Hash_Length; + saltlen = sizeof(mIPK) + kSHA256_Hash_Length; chip::Platform::ScopedMemoryBuffer msg_salt; ReturnErrorCodeIf(!msg_salt.Alloc(saltlen), CHIP_ERROR_NO_MEMORY); { Encoding::LittleEndian::BufferWriter bbuf(msg_salt.Get(), saltlen); - // TODO: Add IPK to Salt + bbuf.Put(mIPK, sizeof(mIPK)); bbuf.Put(mMessageDigest, sizeof(mMessageDigest)); - VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY); + VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_BUFFER_TOO_SMALL); } ReturnErrorOnFailure(session.InitFromSecret(ByteSpan(mSharedSecret, mSharedSecret.Length()), ByteSpan(msg_salt.Get(), saltlen), @@ -339,57 +307,61 @@ CHIP_ERROR CASESession::DeriveSecureSession(SecureSession & session, SecureSessi CHIP_ERROR CASESession::SendSigmaR1() { - uint16_t data_len = - static_cast(kSigmaParamRandomNumberSize + sizeof(uint16_t) + sizeof(uint16_t) + - mOpCredSet->GetCertCount() * kTrustedRootIdSize + kP256_PublicKey_Length + sizeof(uint64_t) * 4); + uint16_t data_len = EstimateTLVStructOverhead( + static_cast(kSigmaParamRandomNumberSize + sizeof(uint16_t) + kSHA256_Hash_Length + kP256_PublicKey_Length), 4); System::PacketBufferTLVWriter tlvWriter; System::PacketBufferHandle msg_R1; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; uint8_t initiatorRandom[kSigmaParamRandomNumberSize] = { 0 }; + uint8_t destinationIdentifier[kSHA256_Hash_Length] = { 0 }; - msg_R1 = System::PacketBufferHandle::New(data_len); - VerifyOrReturnError(!msg_R1.IsNull(), CHIP_ERROR_NO_MEMORY); + // Generate an ephemeral keypair +#ifdef ENABLE_HSM_CASE_EPHEMERAL_KEY + mEphemeralKey.SetKeyId(CASE_EPHEMERAL_KEY); +#endif + ReturnErrorOnFailure(mEphemeralKey.Initialize()); - // Step 1 // Fill in the random value ReturnErrorOnFailure(DRBG_get_bytes(initiatorRandom, kSigmaParamRandomNumberSize)); -// Step 4 -#ifdef ENABLE_HSM_CASE_EPHERMAL_KEY - mEphemeralKey.SetKeyId(CASE_EPHEMERAL_KEY); -#endif - ReturnErrorOnFailure(mEphemeralKey.Initialize()); + // Construct Sigma1 Msg + msg_R1 = System::PacketBufferHandle::New(data_len); + VerifyOrReturnError(!msg_R1.IsNull(), CHIP_ERROR_NO_MEMORY); - // Start writing TLV tlvWriter.Init(std::move(msg_R1)); ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriter.PutBytes(CASETLVTag::kRandom, initiatorRandom, sizeof(initiatorRandom))); - - // Step 5 - uint16_t n_trusted_roots = mOpCredSet->GetCertCount(); - // Initiator's session ID - ReturnErrorOnFailure(tlvWriter.Put(CASETLVTag::kSessionID, mConnectionState.GetLocalKeyID(), true)); - // Step 2/3 - ReturnErrorOnFailure(tlvWriter.Put(CASETLVTag::kNumberofTrustedRootIDs, n_trusted_roots, true)); - for (uint16_t i = 0; i < n_trusted_roots; ++i) + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(1), initiatorRandom, sizeof(initiatorRandom))); + // Retrieve Session Identifier + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), mConnectionState.GetLocalKeyID(), true)); + // Generate a Destination Identifier { - CertificateKeyId trustedRootId = mOpCredSet->GetTrustedRootId(i); - if (!trustedRootId.empty()) - { - ReturnErrorOnFailure(tlvWriter.PutBytes(CASETLVTag::kTrustedRootID, trustedRootId.data(), trustedRootId.size())); - } + const ChipCertificateData * rootCertificate = mOpCredSet->GetRootCertificate(mTrustedRootId); + VerifyOrReturnError(rootCertificate != nullptr, CHIP_ERROR_CERT_NOT_FOUND); + VerifyOrReturnError(!rootCertificate->mPublicKey.empty(), CHIP_ERROR_INTERNAL); + VerifyOrReturnError(rootCertificate->mPublicKey.size() == kP256_PublicKey_Length, CHIP_ERROR_INTERNAL); + ChipCertificateData nodeOperationalCertificate; + FabricId fabricId; + MutableByteSpan destinationIdSpan(destinationIdentifier); + + ReturnErrorOnFailure(DecodeChipCert(mOpCredSet->GetDevOpCred(mTrustedRootId), mOpCredSet->GetDevOpCredLen(mTrustedRootId), + nodeOperationalCertificate)); + ReturnErrorOnFailure(nodeOperationalCertificate.mSubjectDN.GetCertFabricId(fabricId)); + // retrieve Fabric IPK + MutableByteSpan ipkSpan(mIPK); + ReturnErrorOnFailure(RetrieveIPK(fabricId, ipkSpan)); + ReturnErrorOnFailure(GenerateDestinationID(ByteSpan(initiatorRandom), rootCertificate->mPublicKey, + mConnectionState.GetPeerNodeId(), fabricId, ByteSpan(mIPK), destinationIdSpan)); } - ReturnErrorOnFailure(tlvWriter.PutBytes(CASETLVTag::kInitiatorEphPubKey, mEphemeralKey.Pubkey(), - static_cast(mEphemeralKey.Pubkey().Length()))); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(3), destinationIdentifier, sizeof(destinationIdentifier))); + ReturnErrorOnFailure( + tlvWriter.PutBytes(TLV::ContextTag(4), mEphemeralKey.Pubkey(), static_cast(mEphemeralKey.Pubkey().Length()))); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize(&msg_R1)); ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R1->Start(), msg_R1->DataLength() })); - ReturnErrorOnFailure(ComputeIPK(mConnectionState.GetLocalKeyID(), mIPK, sizeof(mIPK))); - mNextExpectedMsg = Protocols::SecureChannel::MsgType::CASE_SigmaR2; // Call delegate to send the msg to peer @@ -413,55 +385,50 @@ CHIP_ERROR CASESession::HandleSigmaR1(System::PacketBufferHandle & msg) { CHIP_ERROR err = CHIP_NO_ERROR; System::PacketBufferTLVReader tlvReader; - System::PacketBufferTLVReader suppTlvReader; TLV::TLVType containerType = TLV::kTLVType_Structure; - uint16_t encryptionKeyId = 0; - uint32_t n_trusted_roots; + uint16_t initiatorSessionId = 0; + uint8_t destinationIdentifier[kSHA256_Hash_Length]; + uint8_t initiatorRandom[kSigmaParamRandomNumberSize]; + + uint32_t decodeTagIdSeq = 0; ChipLogDetail(SecureChannel, "Received SigmaR1 msg"); - err = mCommissioningHash.AddData(ByteSpan{ msg->Start(), msg->DataLength() }); - SuccessOrExit(err); + SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ msg->Start(), msg->DataLength() })); tlvReader.Init(std::move(msg)); - err = tlvReader.Next(containerType, TLV::AnonymousTag); - SuccessOrExit(err); - err = tlvReader.EnterContainer(containerType); - SuccessOrExit(err); + SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = tlvReader.EnterContainer(containerType)); - err = tlvReader.FindElementWithTag(CASETLVTag::kSessionID, suppTlvReader); - SuccessOrExit(err); - err = suppTlvReader.Get(encryptionKeyId); - SuccessOrExit(err); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.GetBytes(initiatorRandom, sizeof(initiatorRandom))); - err = tlvReader.FindElementWithTag(CASETLVTag::kNumberofTrustedRootIDs, suppTlvReader); - SuccessOrExit(err); - err = suppTlvReader.Get(n_trusted_roots); - SuccessOrExit(err); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.Get(initiatorSessionId)); - // Step 1/2 - err = FindValidTrustedRoot(tlvReader, n_trusted_roots); - SuccessOrExit(err); + ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", initiatorSessionId); + mConnectionState.SetPeerKeyID(initiatorSessionId); - // write public key from message - err = tlvReader.FindElementWithTag(CASETLVTag::kInitiatorEphPubKey, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(mRemotePubKey.Length() == suppTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - err = suppTlvReader.GetBytes(mRemotePubKey, static_cast(mRemotePubKey.Length())); - SuccessOrExit(err); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.GetBytes(destinationIdentifier, sizeof(destinationIdentifier))); - ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", encryptionKeyId); - mConnectionState.SetPeerKeyID(encryptionKeyId); + { + const ByteSpan * ipkListSpan = GetIPKList(); + SuccessOrExit(err = FindDestinationIdCandidate(ByteSpan(destinationIdentifier), ByteSpan(initiatorRandom), ipkListSpan, + GetIPKListEntries())); + } + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.GetBytes(mRemotePubKey, static_cast(mRemotePubKey.Length()))); exit: - if (err == CHIP_ERROR_CERT_NOT_TRUSTED) - { - SendErrorMsg(SigmaErrorType::kNoSharedTrustRoots); - } - else if (err != CHIP_NO_ERROR) + if (err != CHIP_NO_ERROR) { SendErrorMsg(SigmaErrorType::kUnexpected); } @@ -487,9 +454,7 @@ CHIP_ERROR CASESession::SendSigmaR2() uint16_t saltlen; uint8_t sr2k[kAEADKeySize]; - P256ECDSASignature sigmaR2Signature; - - uint8_t tag[kTAGSize]; + P256ECDSASignature tbsData2Signature; HKDF_sha_crypto mHKDF; @@ -498,31 +463,25 @@ CHIP_ERROR CASESession::SendSigmaR2() VerifyOrExit(msg_salt.Alloc(saltlen), err = CHIP_ERROR_NO_MEMORY); VerifyOrExit(msg_rand.Alloc(kSigmaParamRandomNumberSize), err = CHIP_ERROR_NO_MEMORY); - // Step 1 // Fill in the random value err = DRBG_get_bytes(msg_rand.Get(), kSigmaParamRandomNumberSize); SuccessOrExit(err); - // Step 3 - // hardcoded to use a p256keypair -#ifdef ENABLE_HSM_CASE_EPHERMAL_KEY + // Generate an ephemeral keypair +#ifdef ENABLE_HSM_CASE_EPHEMERAL_KEY mEphemeralKey.SetKeyId(CASE_EPHEMERAL_KEY); #endif err = mEphemeralKey.Initialize(); SuccessOrExit(err); - // Step 4 + // Generate a Shared Secret err = mEphemeralKey.ECDH_derive_secret(mRemotePubKey, mSharedSecret); SuccessOrExit(err); - err = ComputeIPK(mConnectionState.GetLocalKeyID(), mIPK, sizeof(mIPK)); - SuccessOrExit(err); - - // Step 5 { MutableByteSpan saltSpan(msg_salt.Get(), saltlen); - err = ConstructSaltSigmaR2(ByteSpan(msg_rand.Get(), kSigmaParamRandomNumberSize), mEphemeralKey.Pubkey(), mIPK, - sizeof(mIPK), saltSpan); + err = ConstructSaltSigmaR2(ByteSpan(msg_rand.Get(), kSigmaParamRandomNumberSize), mEphemeralKey.Pubkey(), ByteSpan(mIPK), + saltSpan); SuccessOrExit(err); } @@ -530,82 +489,78 @@ CHIP_ERROR CASESession::SendSigmaR2() kAEADKeySize); SuccessOrExit(err); - // Step 6 - msg_r2_signed_len = static_cast(sizeof(uint16_t) + mOpCredSet->GetDevOpCredLen(mTrustedRootId) + - kP256_PublicKey_Length * 2 + sizeof(uint64_t) * 3); + // Construct Sigma2 TBS Data + msg_r2_signed_len = EstimateTLVStructOverhead( + static_cast(mOpCredSet->GetDevOpCredLen(mTrustedRootId) + kP256_PublicKey_Length * 2), 3); VerifyOrExit(msg_R2_Signed.Alloc(msg_r2_signed_len), err = CHIP_ERROR_NO_MEMORY); - // Generate Sigma2 TBS Data { TLV::TLVWriter tlvWriter; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; tlvWriter.Init(msg_R2_Signed.Get(), msg_r2_signed_len); SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kResponderEphPubKey, mEphemeralKey.Pubkey(), - static_cast(mEphemeralKey.Pubkey().Length()))); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kNOC, mOpCredSet->GetDevOpCred(mTrustedRootId), + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(1), mOpCredSet->GetDevOpCred(mTrustedRootId), mOpCredSet->GetDevOpCredLen(mTrustedRootId))); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kInitiatorEphPubKey, mRemotePubKey, - static_cast(mRemotePubKey.Length()))); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(2), mEphemeralKey.Pubkey(), + static_cast(mEphemeralKey.Pubkey().Length()))); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(3), mRemotePubKey, static_cast(mRemotePubKey.Length()))); SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize()); msg_r2_signed_len = static_cast(tlvWriter.GetLengthWritten()); } - // Step 7 - err = mOpCredSet->SignMsg(mTrustedRootId, msg_R2_Signed.Get(), msg_r2_signed_len, sigmaR2Signature); + // Generate a Signature + err = mOpCredSet->SignMsg(mTrustedRootId, msg_R2_Signed.Get(), msg_r2_signed_len, tbsData2Signature); SuccessOrExit(err); - // Step 8 - msg_r2_signed_enc_len = static_cast(sizeof(uint16_t) + mOpCredSet->GetDevOpCredLen(mTrustedRootId) + - sigmaR2Signature.Length() + sizeof(uint64_t) * 2); + // Construct Sigma2 TBE Data + msg_r2_signed_enc_len = EstimateTLVStructOverhead( + static_cast(mOpCredSet->GetDevOpCredLen(mTrustedRootId) + tbsData2Signature.Length()), 2); - VerifyOrExit(msg_R2_Encrypted.Alloc(msg_r2_signed_enc_len), err = CHIP_ERROR_NO_MEMORY); + VerifyOrExit(msg_R2_Encrypted.Alloc(msg_r2_signed_enc_len + kTAGSize), err = CHIP_ERROR_NO_MEMORY); - // Generate Sigma2 TBE Data { TLV::TLVWriter tlvWriter; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; tlvWriter.Init(msg_R2_Encrypted.Get(), msg_r2_signed_enc_len); - SuccessOrExit(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kNOC, mOpCredSet->GetDevOpCred(mTrustedRootId), + SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(1), mOpCredSet->GetDevOpCred(mTrustedRootId), mOpCredSet->GetDevOpCredLen(mTrustedRootId))); SuccessOrExit( - err = tlvWriter.PutBytes(CASETLVTag::kSignature, sigmaR2Signature, static_cast(sigmaR2Signature.Length()))); + err = tlvWriter.PutBytes(TLV::ContextTag(2), tbsData2Signature, static_cast(tbsData2Signature.Length()))); SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize()); + msg_r2_signed_enc_len = static_cast(tlvWriter.GetLengthWritten()); } - // Step 9 - err = AES_CCM_encrypt(msg_R2_Encrypted.Get(), msg_r2_signed_enc_len, nullptr, 0, sr2k, kAEADKeySize, kIVSR2, kIVLength, - msg_R2_Encrypted.Get(), tag, sizeof(tag)); + // Generate the encrypted data blob + err = AES_CCM_encrypt(msg_R2_Encrypted.Get(), msg_r2_signed_enc_len, nullptr, 0, sr2k, kAEADKeySize, kTBEData2_Nonce, + kTBEDataNonceLength, msg_R2_Encrypted.Get(), msg_R2_Encrypted.Get() + msg_r2_signed_enc_len, kTAGSize); SuccessOrExit(err); - data_len = static_cast(kSigmaParamRandomNumberSize + sizeof(uint16_t) + kTrustedRootIdSize + kP256_PublicKey_Length + - msg_r2_signed_enc_len + sizeof(tag) + sizeof(uint64_t) * 6); + // Construct Sigma2 Msg + data_len = EstimateTLVStructOverhead(static_cast(kSigmaParamRandomNumberSize + sizeof(uint16_t) + + kP256_PublicKey_Length + msg_r2_signed_enc_len + kTAGSize), + 4); msg_R2 = System::PacketBufferHandle::New(data_len); VerifyOrExit(!msg_R2.IsNull(), err = CHIP_ERROR_NO_MEMORY); - // Step 10 - // now construct sigmaR2 { System::PacketBufferTLVWriter tlvWriter; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; tlvWriter.Init(std::move(msg_R2)); SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kRandom, msg_rand.Get(), kSigmaParamRandomNumberSize)); - SuccessOrExit(err = tlvWriter.Put(CASETLVTag::kSessionID, mConnectionState.GetLocalKeyID(), true)); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kTrustedRootID, mTrustedRootId.data(), - static_cast(mTrustedRootId.size()))); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kResponderEphPubKey, mEphemeralKey.Pubkey(), + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(1), msg_rand.Get(), kSigmaParamRandomNumberSize)); + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(2), mConnectionState.GetLocalKeyID(), true)); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(3), mEphemeralKey.Pubkey(), static_cast(mEphemeralKey.Pubkey().Length()))); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kEncryptedData, msg_R2_Encrypted.Get(), msg_r2_signed_enc_len)); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kTag, tag, sizeof(tag))); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(4), msg_R2_Encrypted.Get(), + static_cast(msg_r2_signed_enc_len + kTAGSize))); SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize(&msg_R2)); } @@ -643,7 +598,6 @@ CHIP_ERROR CASESession::HandleSigmaR2(System::PacketBufferHandle & msg) { CHIP_ERROR err = CHIP_NO_ERROR; System::PacketBufferTLVReader tlvReader; - System::PacketBufferTLVReader suppTlvReader; TLV::TLVReader decryptedDataTlvReader; TLV::TLVType containerType = TLV::kTLVType_Structure; @@ -654,14 +608,15 @@ CHIP_ERROR CASESession::HandleSigmaR2(System::PacketBufferHandle & msg) uint16_t saltlen; chip::Platform::ScopedMemoryBuffer msg_R2_Encrypted; - uint16_t msg_r2_encrypted_len; + size_t msg_r2_encrypted_len = 0; + size_t msg_r2_encrypted_len_with_tag = 0; chip::Platform::ScopedMemoryBuffer msg_R2_Signed; uint16_t msg_r2_signed_len; uint8_t sr2k[kAEADKeySize]; - P256ECDSASignature sigmaR2SignedData; + P256ECDSASignature tbsData2Signature; P256PublicKey remoteCredential; @@ -669,12 +624,12 @@ CHIP_ERROR CASESession::HandleSigmaR2(System::PacketBufferHandle & msg) uint8_t responderOpCert[1024]; uint16_t responderOpCertLen; - uint8_t tag[kTAGSize]; - - uint16_t encryptionKeyId = 0; + uint16_t responderSessionId = 0; HKDF_sha_crypto mHKDF; + uint32_t decodeTagIdSeq = 0; + VerifyOrExit(buf != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); ChipLogDetail(SecureChannel, "Received SigmaR2 msg"); @@ -683,124 +638,92 @@ CHIP_ERROR CASESession::HandleSigmaR2(System::PacketBufferHandle & msg) SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); SuccessOrExit(err = tlvReader.EnterContainer(containerType)); - // Assign Session Key ID - SuccessOrExit(err = tlvReader.FindElementWithTag(CASETLVTag::kSessionID, suppTlvReader)); - SuccessOrExit(err = suppTlvReader.Get(encryptionKeyId)); - - ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", encryptionKeyId); - mConnectionState.SetPeerKeyID(encryptionKeyId); - // Retrieve Responder's Random value - err = tlvReader.FindElementWithTag(CASETLVTag::kRandom, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(kSigmaParamRandomNumberSize == suppTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - err = suppTlvReader.GetBytes(responderRandom, sizeof(responderRandom)); - SuccessOrExit(err); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.GetBytes(responderRandom, sizeof(responderRandom))); - SuccessOrExit(err = FindValidTrustedRoot(tlvReader, 1)); + // Assign Session Key ID + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.Get(responderSessionId)); + + ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", responderSessionId); + mConnectionState.SetPeerKeyID(responderSessionId); // Retrieve Responder's Ephemeral Pubkey - SuccessOrExit(err = tlvReader.FindElementWithTag(CASETLVTag::kResponderEphPubKey, suppTlvReader)); - VerifyOrExit(mRemotePubKey.Length() == suppTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - SuccessOrExit(err = suppTlvReader.GetBytes(mRemotePubKey, static_cast(mRemotePubKey.Length()))); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.GetBytes(mRemotePubKey, static_cast(mRemotePubKey.Length()))); - // Step 2 - err = mEphemeralKey.ECDH_derive_secret(mRemotePubKey, mSharedSecret); - SuccessOrExit(err); + // Generate a Shared Secret + SuccessOrExit(err = mEphemeralKey.ECDH_derive_secret(mRemotePubKey, mSharedSecret)); - // Step 3 + // Generate the S2K key saltlen = kIPKSize + kSigmaParamRandomNumberSize + kP256_PublicKey_Length + kSHA256_Hash_Length; VerifyOrExit(msg_salt.Alloc(saltlen), err = CHIP_ERROR_NO_MEMORY); - err = ComputeIPK(mConnectionState.GetPeerKeyID(), mRemoteIPK, sizeof(mRemoteIPK)); - SuccessOrExit(err); - { MutableByteSpan saltSpan(msg_salt.Get(), saltlen); - err = ConstructSaltSigmaR2(ByteSpan(responderRandom, sizeof(responderRandom)), mRemotePubKey, mRemoteIPK, - sizeof(mRemoteIPK), saltSpan); - SuccessOrExit(err); + SuccessOrExit(err = ConstructSaltSigmaR2(ByteSpan(responderRandom), mRemotePubKey, ByteSpan(mIPK), saltSpan)); } - err = mHKDF.HKDF_SHA256(mSharedSecret, mSharedSecret.Length(), msg_salt.Get(), saltlen, kKDFSR2Info, kKDFInfoLength, sr2k, - kAEADKeySize); - SuccessOrExit(err); + SuccessOrExit(err = mHKDF.HKDF_SHA256(mSharedSecret, mSharedSecret.Length(), msg_salt.Get(), saltlen, kKDFSR2Info, + kKDFInfoLength, sr2k, kAEADKeySize)); - err = mCommissioningHash.AddData(ByteSpan{ buf, buflen }); - SuccessOrExit(err); + SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ buf, buflen })); - // Step 4 - err = tlvReader.FindElementWithTag(CASETLVTag::kEncryptedData, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - VerifyOrExit(msg_R2_Encrypted.Alloc(suppTlvReader.GetLength()), err = CHIP_ERROR_NO_MEMORY); - msg_r2_encrypted_len = static_cast(suppTlvReader.GetLength()); - err = suppTlvReader.GetBytes(msg_R2_Encrypted.Get(), msg_r2_encrypted_len); - SuccessOrExit(err); + // Generate decrypted data + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + VerifyOrExit(msg_R2_Encrypted.Alloc(tlvReader.GetLength()), err = CHIP_ERROR_NO_MEMORY); + msg_r2_encrypted_len_with_tag = tlvReader.GetLength(); + VerifyOrExit(msg_r2_encrypted_len_with_tag > kTAGSize, err = CHIP_ERROR_INVALID_TLV_ELEMENT); + SuccessOrExit(err = tlvReader.GetBytes(msg_R2_Encrypted.Get(), static_cast(msg_r2_encrypted_len_with_tag))); + msg_r2_encrypted_len = msg_r2_encrypted_len_with_tag - kTAGSize; - err = tlvReader.FindElementWithTag(CASETLVTag::kTag, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - VerifyOrExit(kTAGSize == suppTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); - err = suppTlvReader.GetBytes(tag, sizeof(tag)); - SuccessOrExit(err); - - err = AES_CCM_decrypt(msg_R2_Encrypted.Get(), msg_r2_encrypted_len, nullptr, 0, tag, kTAGSize, sr2k, kAEADKeySize, kIVSR2, - kIVLength, msg_R2_Encrypted.Get()); - SuccessOrExit(err); + SuccessOrExit(err = AES_CCM_decrypt(msg_R2_Encrypted.Get(), msg_r2_encrypted_len, nullptr, 0, + msg_R2_Encrypted.Get() + msg_r2_encrypted_len, kTAGSize, sr2k, kAEADKeySize, + kTBEData2_Nonce, kTBEDataNonceLength, msg_R2_Encrypted.Get())); - decryptedDataTlvReader.Init(msg_R2_Encrypted.Get(), msg_r2_encrypted_len); + decodeTagIdSeq = 0; + decryptedDataTlvReader.Init(msg_R2_Encrypted.Get(), static_cast(msg_r2_encrypted_len)); containerType = TLV::kTLVType_Structure; - err = decryptedDataTlvReader.Next(containerType, TLV::AnonymousTag); - SuccessOrExit(err); - err = decryptedDataTlvReader.EnterContainer(containerType); - SuccessOrExit(err); + SuccessOrExit(err = decryptedDataTlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = decryptedDataTlvReader.EnterContainer(containerType)); - err = decryptedDataTlvReader.FindElementWithTag(CASETLVTag::kNOC, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - responderOpCertLen = static_cast(suppTlvReader.GetLength()); - err = suppTlvReader.GetBytes(responderOpCert, responderOpCertLen); - SuccessOrExit(err); + SuccessOrExit(err = decryptedDataTlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + responderOpCertLen = static_cast(decryptedDataTlvReader.GetLength()); + SuccessOrExit(err = decryptedDataTlvReader.GetBytes(responderOpCert, responderOpCertLen)); - // Step 5 // Validate responder identity located in msg_r2_encrypted // Constructing responder identity - err = Validate_and_RetrieveResponderID(responderOpCert, responderOpCertLen, remoteCredential); - SuccessOrExit(err); + SuccessOrExit(err = Validate_and_RetrieveResponderID(ByteSpan(responderOpCert, responderOpCertLen), remoteCredential)); - // Step 6 - Construct msg_R2_Signed and validate the signature in msg_r2_encrypted + // Construct msg_R2_Signed and validate the signature in msg_r2_encrypted msg_r2_signed_len = - static_cast(sizeof(uint16_t) + responderOpCertLen + kP256_PublicKey_Length * 2 + sizeof(uint64_t) * 3); + EstimateTLVStructOverhead(static_cast(sizeof(uint16_t) + responderOpCertLen + kP256_PublicKey_Length * 2), 3); VerifyOrExit(msg_R2_Signed.Alloc(msg_r2_signed_len), err = CHIP_ERROR_NO_MEMORY); - err = ConstructTBS2Data(responderOpCert, responderOpCertLen, msg_R2_Signed.Get(), msg_r2_signed_len); - SuccessOrExit(err); + SuccessOrExit(err = ConstructTBS2Data(ByteSpan(responderOpCert, responderOpCertLen), msg_R2_Signed.Get(), msg_r2_signed_len)); - err = decryptedDataTlvReader.FindElementWithTag(CASETLVTag::kSignature, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - VerifyOrExit(sigmaR2SignedData.Capacity() >= suppTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); - sigmaR2SignedData.SetLength(suppTlvReader.GetLength()); - err = suppTlvReader.GetBytes(sigmaR2SignedData, static_cast(sigmaR2SignedData.Length())); - SuccessOrExit(err); + SuccessOrExit(err = decryptedDataTlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + VerifyOrExit(tbsData2Signature.Capacity() >= decryptedDataTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); + tbsData2Signature.SetLength(decryptedDataTlvReader.GetLength()); + SuccessOrExit(err = decryptedDataTlvReader.GetBytes(tbsData2Signature, static_cast(tbsData2Signature.Length()))); - err = remoteCredential.ECDSA_validate_msg_signature(msg_R2_Signed.Get(), msg_r2_signed_len, sigmaR2SignedData); - SuccessOrExit(err); + // Validate signature + SuccessOrExit(err = remoteCredential.ECDSA_validate_msg_signature(msg_R2_Signed.Get(), msg_r2_signed_len, tbsData2Signature)); exit: if (err == CHIP_ERROR_INVALID_SIGNATURE) { SendErrorMsg(SigmaErrorType::kInvalidSignature); } - else if (err == CHIP_ERROR_CERT_NOT_TRUSTED) - { - SendErrorMsg(SigmaErrorType::kNoSharedTrustRoots); - } else if (err != CHIP_NO_ERROR) { SendErrorMsg(SigmaErrorType::kUnexpected); @@ -827,30 +750,24 @@ CHIP_ERROR CASESession::SendSigmaR3() chip::Platform::ScopedMemoryBuffer msg_R3_Signed; uint16_t msg_r3_signed_len; - P256ECDSASignature sigmaR3Signature; - - uint8_t tag[kTAGSize]; + P256ECDSASignature tbsData3Signature; HKDF_sha_crypto mHKDF; - // Step 1 - saltlen = kIPKSize + kSHA256_Hash_Length; - ChipLogDetail(SecureChannel, "Sending SigmaR3"); + + saltlen = kIPKSize + kSHA256_Hash_Length; VerifyOrExit(msg_salt.Alloc(saltlen), err = CHIP_ERROR_NO_MEMORY); { MutableByteSpan saltSpan(msg_salt.Get(), saltlen); - err = ConstructSaltSigmaR3(mIPK, sizeof(mIPK), saltSpan); + err = ConstructSaltSigmaR3(ByteSpan(mIPK), saltSpan); SuccessOrExit(err); } - err = mHKDF.HKDF_SHA256(mSharedSecret, mSharedSecret.Length(), msg_salt.Get(), saltlen, kKDFSR3Info, kKDFInfoLength, sr3k, - kAEADKeySize); - SuccessOrExit(err); - // Step 2 - msg_r3_signed_len = static_cast(sizeof(uint16_t) + mOpCredSet->GetDevOpCredLen(mTrustedRootId) + - kP256_PublicKey_Length * 2 + sizeof(uint64_t) * 3); + // Prepare SigmaR3 TBS Data Blob + msg_r3_signed_len = EstimateTLVStructOverhead( + static_cast(mOpCredSet->GetDevOpCredLen(mTrustedRootId) + kP256_PublicKey_Length * 2), 3); VerifyOrExit(msg_R3_Signed.Alloc(msg_r3_signed_len), err = CHIP_ERROR_NO_MEMORY); @@ -860,52 +777,53 @@ CHIP_ERROR CASESession::SendSigmaR3() tlvWriter.Init(msg_R3_Signed.Get(), msg_r3_signed_len); SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kInitiatorEphPubKey, mEphemeralKey.Pubkey(), - static_cast(mEphemeralKey.Pubkey().Length()))); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kNOC, mOpCredSet->GetDevOpCred(mTrustedRootId), + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(1), mOpCredSet->GetDevOpCred(mTrustedRootId), mOpCredSet->GetDevOpCredLen(mTrustedRootId))); - SuccessOrExit(err = tlvWriter.PutBytes(CASETLVTag::kResponderEphPubKey, mRemotePubKey, - static_cast(mRemotePubKey.Length()))); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(2), mEphemeralKey.Pubkey(), + static_cast(mEphemeralKey.Pubkey().Length()))); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(3), mRemotePubKey, static_cast(mRemotePubKey.Length()))); SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize()); msg_r3_signed_len = static_cast(tlvWriter.GetLengthWritten()); } - // Step 3 - err = mOpCredSet->SignMsg(mTrustedRootId, msg_R3_Signed.Get(), msg_r3_signed_len, sigmaR3Signature); + // Generate a signature + err = mOpCredSet->SignMsg(mTrustedRootId, msg_R3_Signed.Get(), msg_r3_signed_len, tbsData3Signature); SuccessOrExit(err); - // Step 4 - msg_r3_encrypted_len = static_cast(sizeof(uint16_t) + mOpCredSet->GetDevOpCredLen(mTrustedRootId) + - static_cast(sigmaR3Signature.Length()) + sizeof(uint64_t) * 2); + // Prepare SigmaR3 TBE Data Blob + msg_r3_encrypted_len = EstimateTLVStructOverhead( + static_cast(mOpCredSet->GetDevOpCredLen(mTrustedRootId) + static_cast(tbsData3Signature.Length())), 2); - VerifyOrExit(msg_R3_Encrypted.Alloc(msg_r3_encrypted_len), err = CHIP_ERROR_NO_MEMORY); + VerifyOrExit(msg_R3_Encrypted.Alloc(msg_r3_encrypted_len + kTAGSize), err = CHIP_ERROR_NO_MEMORY); { TLV::TLVWriter tlvWriter; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; tlvWriter.Init(msg_R3_Encrypted.Get(), msg_r3_encrypted_len); - err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType); - SuccessOrExit(err); - err = tlvWriter.PutBytes(CASETLVTag::kNOC, mOpCredSet->GetDevOpCred(mTrustedRootId), - mOpCredSet->GetDevOpCredLen(mTrustedRootId)); - SuccessOrExit(err); - err = tlvWriter.PutBytes(CASETLVTag::kSignature, sigmaR3Signature, static_cast(sigmaR3Signature.Length())); - SuccessOrExit(err); - err = tlvWriter.EndContainer(outerContainerType); - SuccessOrExit(err); - err = tlvWriter.Finalize(); - SuccessOrExit(err); + SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); + SuccessOrExit(err = tlvWriter.PutBytes(TLV::ContextTag(1), mOpCredSet->GetDevOpCred(mTrustedRootId), + mOpCredSet->GetDevOpCredLen(mTrustedRootId))); + SuccessOrExit( + err = tlvWriter.PutBytes(TLV::ContextTag(2), tbsData3Signature, static_cast(tbsData3Signature.Length()))); + SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); + SuccessOrExit(err = tlvWriter.Finalize()); + msg_r3_encrypted_len = static_cast(tlvWriter.GetLengthWritten()); } - // Step 5 - err = AES_CCM_encrypt(msg_R3_Encrypted.Get(), msg_r3_encrypted_len, nullptr, 0, sr3k, kAEADKeySize, kIVSR3, kIVLength, - msg_R3_Encrypted.Get(), tag, sizeof(tag)); + // Generate S3K key + err = mHKDF.HKDF_SHA256(mSharedSecret, mSharedSecret.Length(), msg_salt.Get(), saltlen, kKDFSR3Info, kKDFInfoLength, sr3k, + kAEADKeySize); + SuccessOrExit(err); + + // Generated Encrypted data blob + err = AES_CCM_encrypt(msg_R3_Encrypted.Get(), msg_r3_encrypted_len, nullptr, 0, sr3k, kAEADKeySize, kTBEData3_Nonce, + kTBEDataNonceLength, msg_R3_Encrypted.Get(), msg_R3_Encrypted.Get() + msg_r3_encrypted_len, kTAGSize); SuccessOrExit(err); - // Step 6 - data_len = static_cast(sizeof(tag) + msg_r3_encrypted_len + sizeof(uint64_t) * 2); + // Generate Sigma3 Msg + data_len = EstimateTLVStructOverhead(static_cast(kTAGSize + msg_r3_encrypted_len), 1); msg_R3 = System::PacketBufferHandle::New(data_len); VerifyOrExit(!msg_R3.IsNull(), err = CHIP_ERROR_NO_MEMORY); @@ -917,9 +835,9 @@ CHIP_ERROR CASESession::SendSigmaR3() tlvWriter.Init(std::move(msg_R3)); err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType); SuccessOrExit(err); - err = tlvWriter.PutBytes(CASETLVTag::kEncryptedData, msg_R3_Encrypted.Get(), msg_r3_encrypted_len); + err = + tlvWriter.PutBytes(TLV::ContextTag(1), msg_R3_Encrypted.Get(), static_cast(msg_r3_encrypted_len + kTAGSize)); SuccessOrExit(err); - err = tlvWriter.PutBytes(CASETLVTag::kTag, tag, sizeof(tag)); err = tlvWriter.EndContainer(outerContainerType); SuccessOrExit(err); err = tlvWriter.Finalize(&msg_R3); @@ -960,7 +878,6 @@ CHIP_ERROR CASESession::HandleSigmaR3(System::PacketBufferHandle & msg) CHIP_ERROR err = CHIP_NO_ERROR; MutableByteSpan messageDigestSpan(mMessageDigest); System::PacketBufferTLVReader tlvReader; - System::PacketBufferTLVReader suppTlvReader; TLV::TLVReader decryptedDataTlvReader; TLV::TLVType containerType = TLV::kTLVType_Structure; @@ -968,13 +885,14 @@ CHIP_ERROR CASESession::HandleSigmaR3(System::PacketBufferHandle & msg) const uint16_t bufLen = msg->DataLength(); chip::Platform::ScopedMemoryBuffer msg_R3_Encrypted; - uint16_t msg_r3_encrypted_len; + size_t msg_r3_encrypted_len = 0; + size_t msg_r3_encrypted_len_with_tag = 0; chip::Platform::ScopedMemoryBuffer msg_R3_Signed; uint16_t msg_r3_signed_len; uint8_t sr3k[kAEADKeySize]; - P256ECDSASignature sigmaR3SignedData; + P256ECDSASignature tbsData3Signature; P256PublicKey remoteCredential; @@ -984,104 +902,81 @@ CHIP_ERROR CASESession::HandleSigmaR3(System::PacketBufferHandle & msg) chip::Platform::ScopedMemoryBuffer msg_salt; uint16_t saltlen; - uint8_t tag[kTAGSize]; - HKDF_sha_crypto mHKDF; + uint32_t decodeTagIdSeq = 0; + ChipLogDetail(SecureChannel, "Received SigmaR3 msg"); mNextExpectedMsg = Protocols::SecureChannel::MsgType::CASE_SigmaErr; tlvReader.Init(std::move(msg)); - err = tlvReader.Next(containerType, TLV::AnonymousTag); - SuccessOrExit(err); - err = tlvReader.EnterContainer(containerType); - SuccessOrExit(err); - - err = tlvReader.FindElementWithTag(CASETLVTag::kEncryptedData, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); + SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = tlvReader.EnterContainer(containerType)); - VerifyOrExit(msg_R3_Encrypted.Alloc(suppTlvReader.GetLength()), err = CHIP_ERROR_NO_MEMORY); - msg_r3_encrypted_len = static_cast(suppTlvReader.GetLength()); - err = suppTlvReader.GetBytes(msg_R3_Encrypted.Get(), msg_r3_encrypted_len); - SuccessOrExit(err); - - err = tlvReader.FindElementWithTag(CASETLVTag::kTag, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - VerifyOrExit(kTAGSize == suppTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); - err = suppTlvReader.GetBytes(tag, sizeof(tag)); - SuccessOrExit(err); + // Fetch encrypted data + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + VerifyOrExit(msg_R3_Encrypted.Alloc(tlvReader.GetLength()), err = CHIP_ERROR_NO_MEMORY); + msg_r3_encrypted_len_with_tag = tlvReader.GetLength(); + VerifyOrExit(msg_r3_encrypted_len_with_tag > kTAGSize, err = CHIP_ERROR_INVALID_TLV_ELEMENT); + SuccessOrExit(err = tlvReader.GetBytes(msg_R3_Encrypted.Get(), static_cast(msg_r3_encrypted_len_with_tag))); + msg_r3_encrypted_len = msg_r3_encrypted_len_with_tag - kTAGSize; // Step 1 saltlen = kIPKSize + kSHA256_Hash_Length; VerifyOrExit(msg_salt.Alloc(saltlen), err = CHIP_ERROR_NO_MEMORY); - err = ComputeIPK(mConnectionState.GetPeerKeyID(), mRemoteIPK, sizeof(mRemoteIPK)); - SuccessOrExit(err); - { MutableByteSpan saltSpan(msg_salt.Get(), saltlen); - err = ConstructSaltSigmaR3(mRemoteIPK, sizeof(mRemoteIPK), saltSpan); - SuccessOrExit(err); + SuccessOrExit(err = ConstructSaltSigmaR3(ByteSpan(mIPK), saltSpan)); } - err = mHKDF.HKDF_SHA256(mSharedSecret, mSharedSecret.Length(), msg_salt.Get(), saltlen, kKDFSR3Info, kKDFInfoLength, sr3k, - kAEADKeySize); - SuccessOrExit(err); + SuccessOrExit(err = mHKDF.HKDF_SHA256(mSharedSecret, mSharedSecret.Length(), msg_salt.Get(), saltlen, kKDFSR3Info, + kKDFInfoLength, sr3k, kAEADKeySize)); - err = mCommissioningHash.AddData(ByteSpan{ buf, bufLen }); - SuccessOrExit(err); + SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ buf, bufLen })); - // Step 2 - err = AES_CCM_decrypt(msg_R3_Encrypted.Get(), msg_r3_encrypted_len, nullptr, 0, tag, kTAGSize, sr3k, kAEADKeySize, kIVSR3, - kIVLength, msg_R3_Encrypted.Get()); - SuccessOrExit(err); + // Step 2 - Decrypt data blob + SuccessOrExit(err = AES_CCM_decrypt(msg_R3_Encrypted.Get(), msg_r3_encrypted_len, nullptr, 0, + msg_R3_Encrypted.Get() + msg_r3_encrypted_len, kTAGSize, sr3k, kAEADKeySize, + kTBEData3_Nonce, kTBEDataNonceLength, msg_R3_Encrypted.Get())); - decryptedDataTlvReader.Init(msg_R3_Encrypted.Get(), msg_r3_encrypted_len); + decodeTagIdSeq = 0; + decryptedDataTlvReader.Init(msg_R3_Encrypted.Get(), static_cast(msg_r3_encrypted_len)); containerType = TLV::kTLVType_Structure; - err = decryptedDataTlvReader.Next(containerType, TLV::AnonymousTag); - SuccessOrExit(err); - err = decryptedDataTlvReader.EnterContainer(containerType); - SuccessOrExit(err); + SuccessOrExit(err = decryptedDataTlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = decryptedDataTlvReader.EnterContainer(containerType)); - err = decryptedDataTlvReader.FindElementWithTag(CASETLVTag::kNOC, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - responderOpCertLen = static_cast(suppTlvReader.GetLength()); - err = suppTlvReader.GetBytes(responderOpCert, responderOpCertLen); - SuccessOrExit(err); + SuccessOrExit(err = decryptedDataTlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + responderOpCertLen = static_cast(decryptedDataTlvReader.GetLength()); + SuccessOrExit(err = decryptedDataTlvReader.GetBytes(responderOpCert, responderOpCertLen)); - // Step 3 + // Step 5/6 // Validate initiator identity located in msg->Start() // Constructing responder identity - err = Validate_and_RetrieveResponderID(responderOpCert, responderOpCertLen, remoteCredential); - SuccessOrExit(err); + SuccessOrExit(err = Validate_and_RetrieveResponderID(ByteSpan(responderOpCert, responderOpCertLen), remoteCredential)); - // Step 4 + // Step 4 - Construct SigmaR3 TBS Data msg_r3_signed_len = - static_cast(sizeof(uint16_t) + responderOpCertLen + kP256_PublicKey_Length * 2 + sizeof(uint64_t) * 3); + EstimateTLVStructOverhead(static_cast(sizeof(uint16_t) + responderOpCertLen + kP256_PublicKey_Length * 2), 3); VerifyOrExit(msg_R3_Signed.Alloc(msg_r3_signed_len), err = CHIP_ERROR_NO_MEMORY); - err = ConstructTBS3Data(responderOpCert, responderOpCertLen, msg_R3_Signed.Get(), msg_r3_signed_len); - SuccessOrExit(err); + SuccessOrExit(err = ConstructTBS3Data(ByteSpan(responderOpCert, responderOpCertLen), msg_R3_Signed.Get(), msg_r3_signed_len)); - err = decryptedDataTlvReader.FindElementWithTag(CASETLVTag::kSignature, suppTlvReader); - SuccessOrExit(err); - VerifyOrExit(suppTlvReader.GetType() == TLV::kTLVType_ByteString, err = CHIP_ERROR_WRONG_TLV_TYPE); - VerifyOrExit(sigmaR3SignedData.Capacity() >= suppTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); - sigmaR3SignedData.SetLength(suppTlvReader.GetLength()); - err = suppTlvReader.GetBytes(sigmaR3SignedData, static_cast(sigmaR3SignedData.Length())); - SuccessOrExit(err); + SuccessOrExit(err = decryptedDataTlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(decryptedDataTlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + VerifyOrExit(tbsData3Signature.Capacity() >= decryptedDataTlvReader.GetLength(), err = CHIP_ERROR_INVALID_TLV_ELEMENT); + tbsData3Signature.SetLength(decryptedDataTlvReader.GetLength()); + SuccessOrExit(err = decryptedDataTlvReader.GetBytes(tbsData3Signature, static_cast(tbsData3Signature.Length()))); - err = remoteCredential.ECDSA_validate_msg_signature(msg_R3_Signed.Get(), msg_r3_signed_len, sigmaR3SignedData); - SuccessOrExit(err); + // Step 7 - Validate Signature + SuccessOrExit(err = remoteCredential.ECDSA_validate_msg_signature(msg_R3_Signed.Get(), msg_r3_signed_len, tbsData3Signature)); - err = mCommissioningHash.Finish(messageDigestSpan); - SuccessOrExit(err); + SuccessOrExit(err = mCommissioningHash.Finish(messageDigestSpan)); mPairingComplete = true; @@ -1121,81 +1016,110 @@ void CASESession::SendErrorMsg(SigmaErrorType errorCode) ChipLogError(SecureChannel, "Failed to send error message")); } -CHIP_ERROR CASESession::FindValidTrustedRoot(const System::PacketBufferTLVReader & tlvReader, uint32_t nTrustedRoots) +CHIP_ERROR CASESession::GenerateDestinationID(const ByteSpan & random, const P256PublicKeySpan & rootPubkey, NodeId nodeId, + FabricId fabricId, const ByteSpan & ipk, MutableByteSpan & destinationId) { - CertificateKeyId trustedRoot; - System::PacketBufferTLVReader suppTlvReader; - uint8_t trustedRootId[kTrustedRootIdSize]; + HMAC_sha hmac; + uint8_t destinationMessage[kDestinationMessageLen]; + + Encoding::LittleEndian::BufferWriter bbuf(destinationMessage, sizeof(destinationMessage)); - trustedRoot = CertificateKeyId(trustedRootId); + bbuf.Put(random.data(), random.size()); + bbuf.Put(rootPubkey.data(), rootPubkey.size()); + bbuf.Put64(fabricId); + bbuf.Put64(nodeId); + + VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_BUFFER_TOO_SMALL); + + ReturnErrorOnFailure(hmac.HMAC_SHA256(ipk.data(), ipk.size(), destinationMessage, sizeof(destinationMessage), + destinationId.data(), destinationId.size())); + + return CHIP_NO_ERROR; +} - for (uint32_t i = 0; i < nTrustedRoots; ++i) +CHIP_ERROR CASESession::FindDestinationIdCandidate(const ByteSpan & destinationId, const ByteSpan & initiatorRandom, + const ByteSpan * ipkList, size_t ipkListEntries) +{ + uint8_t nCertificateSets = mOpCredSet->GetCertCount(); + + for (size_t certChainIdx = 0; certChainIdx < nCertificateSets; ++certChainIdx) { - ReturnErrorOnFailure(tlvReader.FindElementWithTag(CASETLVTag::kTrustedRootID, suppTlvReader)); + uint8_t candidate[kSHA256_Hash_Length] = { 0 }; + CertificateKeyId trustedRootId; + ChipCertificateData nodeOperationalCertificate; + NodeId nodeId; + FabricId fabricId; + + trustedRootId = mOpCredSet->GetTrustedRootId(static_cast(certChainIdx)); - VerifyOrReturnError(kTrustedRootIdSize == suppTlvReader.GetLength(), CHIP_ERROR_INVALID_TLV_ELEMENT); - VerifyOrReturnError(suppTlvReader.GetType() == TLV::kTLVType_ByteString, CHIP_ERROR_WRONG_TLV_TYPE); - ReturnErrorOnFailure(suppTlvReader.GetBytes(trustedRootId, kTrustedRootIdSize)); + ReturnErrorOnFailure(DecodeChipCert(mOpCredSet->GetDevOpCred(trustedRootId), mOpCredSet->GetDevOpCredLen(trustedRootId), + nodeOperationalCertificate)); - if (mOpCredSet->IsTrustedRootIn(trustedRoot)) + ReturnErrorOnFailure(nodeOperationalCertificate.mSubjectDN.GetCertChipId(nodeId)); + ReturnErrorOnFailure(nodeOperationalCertificate.mSubjectDN.GetCertFabricId(fabricId)); + + const ChipCertificateData * rootCertificate = mOpCredSet->GetRootCertificate(trustedRootId); + VerifyOrReturnError(rootCertificate != nullptr, CHIP_ERROR_CERT_NOT_FOUND); + VerifyOrReturnError(!rootCertificate->mPublicKey.empty(), CHIP_ERROR_INTERNAL); + VerifyOrReturnError(rootCertificate->mPublicKey.size() == kP256_PublicKey_Length, CHIP_ERROR_INTERNAL); + + for (size_t ipkIdx = 0; ipkIdx < ipkListEntries; ++ipkIdx) { - if (!mTrustedRootId.empty()) + MutableByteSpan candidateSpan(candidate); + ReturnErrorOnFailure(GenerateDestinationID(initiatorRandom, rootCertificate->mPublicKey, nodeId, fabricId, + ipkList[ipkIdx], candidateSpan)); + + if (destinationId.data_equal(candidateSpan)) { - chip::Platform::MemoryFree(const_cast(mTrustedRootId.data())); - mTrustedRootId = CertificateKeyId(); + VerifyOrReturnError(sizeof(mIPK) == ipkList[ipkIdx].size(), CHIP_ERROR_INTERNAL); + memcpy(mIPK, ipkList[ipkIdx].data(), ipkList[ipkIdx].size()); + mTrustedRootId = trustedRootId; + break; } - mTrustedRootId = CertificateKeyId(reinterpret_cast(chip::Platform::MemoryAlloc(kTrustedRootIdSize))); - VerifyOrReturnError(!mTrustedRootId.empty(), CHIP_ERROR_NO_MEMORY); - - memcpy(const_cast(mTrustedRootId.data()), trustedRoot.data(), trustedRoot.size()); - - break; } } + VerifyOrReturnError(!mTrustedRootId.empty(), CHIP_ERROR_CERT_NOT_TRUSTED); return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::ConstructSaltSigmaR2(const ByteSpan & rand, const P256PublicKey & pubkey, const uint8_t * ipk, - size_t ipkLen, MutableByteSpan & salt) +CHIP_ERROR CASESession::ConstructSaltSigmaR2(const ByteSpan & rand, const Crypto::P256PublicKey & pubkey, const ByteSpan & ipk, + MutableByteSpan & salt) { uint8_t md[kSHA256_Hash_Length]; memset(salt.data(), 0, salt.size()); Encoding::LittleEndian::BufferWriter bbuf(salt.data(), salt.size()); - bbuf.Put(ipk, ipkLen); + bbuf.Put(ipk.data(), ipk.size()); bbuf.Put(rand.data(), kSigmaParamRandomNumberSize); bbuf.Put(pubkey, pubkey.Length()); MutableByteSpan messageDigestSpan(md); - ReturnErrorOnFailure(mCommissioningHash.Finish(messageDigestSpan)); + ReturnErrorOnFailure(mCommissioningHash.GetDigest(messageDigestSpan)); bbuf.Put(messageDigestSpan.data(), messageDigestSpan.size()); - ReturnErrorOnFailure(mCommissioningHash.Begin()); - VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY); + VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_BUFFER_TOO_SMALL); return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::ConstructSaltSigmaR3(const uint8_t * ipk, size_t ipkLen, MutableByteSpan & salt) +CHIP_ERROR CASESession::ConstructSaltSigmaR3(const ByteSpan & ipk, MutableByteSpan & salt) { uint8_t md[kSHA256_Hash_Length]; memset(salt.data(), 0, salt.size()); Encoding::LittleEndian::BufferWriter bbuf(salt.data(), salt.size()); - bbuf.Put(ipk, ipkLen); + bbuf.Put(ipk.data(), ipk.size()); MutableByteSpan messageDigestSpan(md); - ReturnErrorOnFailure(mCommissioningHash.Finish(messageDigestSpan)); + ReturnErrorOnFailure(mCommissioningHash.GetDigest(messageDigestSpan)); bbuf.Put(messageDigestSpan.data(), messageDigestSpan.size()); - ReturnErrorOnFailure(mCommissioningHash.Begin()); - VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY); + VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_BUFFER_TOO_SMALL); return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::Validate_and_RetrieveResponderID(const uint8_t * responderOpCert, uint16_t responderOpCertLen, - P256PublicKey & responderID) +CHIP_ERROR CASESession::Validate_and_RetrieveResponderID(const ByteSpan & responderOpCert, Crypto::P256PublicKey & responderID) { const ChipCertificateData * resultCert = nullptr; @@ -1204,17 +1128,17 @@ CHIP_ERROR CASESession::Validate_and_RetrieveResponderID(const uint8_t * respond ReturnErrorOnFailure(certSet.Init(3)); Encoding::LittleEndian::BufferWriter bbuf(responderID, responderID.Length()); - ReturnErrorOnFailure( - certSet.LoadCerts(responderOpCert, responderOpCertLen, BitFlags(CertDecodeFlags::kGenerateTBSHash))); + ReturnErrorOnFailure(certSet.LoadCert(responderOpCert.data(), static_cast(responderOpCert.size()), + BitFlags(CertDecodeFlags::kGenerateTBSHash))); bbuf.Put(certSet.GetCertSet()[0].mPublicKey.data(), certSet.GetCertSet()[0].mPublicKey.size()); - VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY); + VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_BUFFER_TOO_SMALL); // Validate responder identity located in msg_r2_encrypted - ReturnErrorOnFailure( - mOpCredSet->FindCertSet(mTrustedRootId) - ->LoadCerts(responderOpCert, responderOpCertLen, BitFlags(CertDecodeFlags::kGenerateTBSHash))); + ReturnErrorOnFailure(mOpCredSet->FindCertSet(mTrustedRootId) + ->LoadCert(responderOpCert.data(), static_cast(responderOpCert.size()), + BitFlags(CertDecodeFlags::kGenerateTBSHash))); ReturnErrorOnFailure(SetEffectiveTime()); // Locate the subject DN and key id that will be used as input the FindValidCert() method. @@ -1231,11 +1155,13 @@ CHIP_ERROR CASESession::Validate_and_RetrieveResponderID(const uint8_t * respond mConnectionState.SetPeerNodeId(peerId.GetNodeId()); } + // Release the previously loaded NOC Certificate + ReturnErrorOnFailure(mOpCredSet->FindCertSet(mTrustedRootId)->ReleaseLastCert()); + return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::ConstructTBS2Data(const uint8_t * responderOpCert, uint32_t responderOpCertLen, uint8_t * tbsData, - uint16_t & tbsDataLen) +CHIP_ERROR CASESession::ConstructTBS2Data(const ByteSpan & responderOpCert, uint8_t * tbsData, uint16_t & tbsDataLen) { TLV::TLVWriter tlvWriter; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; @@ -1243,10 +1169,10 @@ CHIP_ERROR CASESession::ConstructTBS2Data(const uint8_t * responderOpCert, uint3 tlvWriter.Init(tbsData, tbsDataLen); ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); ReturnErrorOnFailure( - tlvWriter.PutBytes(CASETLVTag::kResponderEphPubKey, mRemotePubKey, static_cast(mRemotePubKey.Length()))); - ReturnErrorOnFailure(tlvWriter.PutBytes(CASETLVTag::kNOC, responderOpCert, responderOpCertLen)); - ReturnErrorOnFailure(tlvWriter.PutBytes(CASETLVTag::kInitiatorEphPubKey, mEphemeralKey.Pubkey(), - static_cast(mEphemeralKey.Pubkey().Length()))); + tlvWriter.PutBytes(TLV::ContextTag(1), responderOpCert.data(), static_cast(responderOpCert.size()))); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), mRemotePubKey, static_cast(mRemotePubKey.Length()))); + ReturnErrorOnFailure( + tlvWriter.PutBytes(TLV::ContextTag(3), mEphemeralKey.Pubkey(), static_cast(mEphemeralKey.Pubkey().Length()))); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize()); tbsDataLen = static_cast(tlvWriter.GetLengthWritten()); @@ -1254,8 +1180,7 @@ CHIP_ERROR CASESession::ConstructTBS2Data(const uint8_t * responderOpCert, uint3 return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::ConstructTBS3Data(const uint8_t * responderOpCert, uint32_t responderOpCertLen, uint8_t * tbsData, - uint16_t & tbsDataLen) +CHIP_ERROR CASESession::ConstructTBS3Data(const ByteSpan & responderOpCert, uint8_t * tbsData, uint16_t & tbsDataLen) { TLV::TLVWriter tlvWriter; TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; @@ -1263,10 +1188,10 @@ CHIP_ERROR CASESession::ConstructTBS3Data(const uint8_t * responderOpCert, uint3 tlvWriter.Init(tbsData, tbsDataLen); ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); ReturnErrorOnFailure( - tlvWriter.PutBytes(CASETLVTag::kInitiatorEphPubKey, mRemotePubKey, static_cast(mRemotePubKey.Length()))); - ReturnErrorOnFailure(tlvWriter.PutBytes(CASETLVTag::kNOC, responderOpCert, responderOpCertLen)); - ReturnErrorOnFailure(tlvWriter.PutBytes(CASETLVTag::kResponderEphPubKey, mEphemeralKey.Pubkey(), - static_cast(mEphemeralKey.Pubkey().Length()))); + tlvWriter.PutBytes(TLV::ContextTag(1), responderOpCert.data(), static_cast(responderOpCert.size()))); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), mRemotePubKey, static_cast(mRemotePubKey.Length()))); + ReturnErrorOnFailure( + tlvWriter.PutBytes(TLV::ContextTag(3), mEphemeralKey.Pubkey(), static_cast(mEphemeralKey.Pubkey().Length()))); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize()); tbsDataLen = static_cast(tlvWriter.GetLengthWritten()); @@ -1274,17 +1199,9 @@ CHIP_ERROR CASESession::ConstructTBS3Data(const uint8_t * responderOpCert, uint3 return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::ComputeIPK(const uint16_t sessionID, uint8_t * ipk, size_t ipkLen) +CHIP_ERROR CASESession::RetrieveIPK(FabricId fabricId, MutableByteSpan & ipk) { - uint8_t sid[2]; - Encoding::LittleEndian::BufferWriter bbuf(sid, sizeof(sid)); - bbuf.Put16(sessionID); - VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY); - - HKDF_sha_crypto mHKDF; - ReturnErrorOnFailure(mHKDF.HKDF_SHA256(mFabricSecret, mFabricSecret.Length(), bbuf.Buffer(), bbuf.Size(), kIPKInfo, - sizeof(kIPKInfo), ipk, ipkLen)); - + memset(ipk.data(), static_cast(fabricId), ipk.size()); return CHIP_NO_ERROR; } @@ -1317,10 +1234,6 @@ CHIP_ERROR CASESession::HandleErrorMsg(const System::PacketBufferHandle & msg) CHIP_ERROR err = CHIP_NO_ERROR; switch (pMsg->error) { - case SigmaErrorType::kNoSharedTrustRoots: - err = CHIP_ERROR_CERT_NOT_TRUSTED; - break; - case SigmaErrorType::kUnsupportedVersion: err = CHIP_ERROR_UNSUPPORTED_CASE_CONFIGURATION; break; diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 9080bea035b75e..6b04f3c181f52f 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -38,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -47,6 +46,7 @@ namespace chip { +// TODO: move this constant over to src/crypto/CHIPCryptoPAL.h - name it CHIP_CRYPTO_SYMMETRIC_KEY_LENGTH_BYTES constexpr uint16_t kAEADKeySize = 16; constexpr uint16_t kSigmaParamRandomNumberSize = 32; @@ -55,7 +55,7 @@ constexpr uint16_t kMaxTrustedRootIds = 5; constexpr uint16_t kIPKSize = 16; -#ifdef ENABLE_HSM_CASE_EPHERMAL_KEY +#ifdef ENABLE_HSM_CASE_EPHEMERAL_KEY #define CASE_EPHEMERAL_KEY 0xCA5EECD0 #endif @@ -67,6 +67,8 @@ struct CASESessionSerializable uint8_t mSharedSecret[Crypto::kMax_ECDH_Secret_Length]; uint16_t mMessageDigestLen; uint8_t mMessageDigest[Crypto::kSHA256_Hash_Length]; + uint16_t mIPKLen; + uint8_t mIPK[kIPKSize]; uint8_t mPairingComplete; NodeId mPeerNodeId; uint16_t mLocalKeyId; @@ -105,6 +107,9 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin * @param peerAddress Address of peer with which to establish a session. * @param operationalCredentialSet CHIP Certificate Set used to store the chain root of trust an validate peer node * certificates + * @param opCredSetIndex Index value used to choose the chain root of trust for establishing a session. Retrieve + * this index value from an operationalCredentialSet's entry that matches the device's + * operational credentials * @param peerNodeId Node id of the peer node * @param myKeyId Key ID to be assigned to the secure session on the peer node * @param exchangeCtxt The exchange context to send and receive messages with the peer @@ -113,8 +118,8 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin * @return CHIP_ERROR The result of initialization */ CHIP_ERROR EstablishSession(const Transport::PeerAddress peerAddress, - Credentials::OperationalCredentialSet * operationalCredentialSet, NodeId peerNodeId, - uint16_t myKeyId, Messaging::ExchangeContext * exchangeCtxt, + Credentials::OperationalCredentialSet * operationalCredentialSet, uint8_t opCredSetIndex, + NodeId peerNodeId, uint16_t myKeyId, Messaging::ExchangeContext * exchangeCtxt, SessionEstablishmentDelegate * delegate); /** @@ -198,7 +203,6 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin private: enum SigmaErrorType : uint8_t { - kNoSharedTrustRoots = 0x01, kInvalidSignature = 0x04, kInvalidResumptionTag = 0x05, kUnsupportedVersion = 0x06, @@ -220,17 +224,25 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin CHIP_ERROR SendSigmaR1Resume(); CHIP_ERROR HandleSigmaR1Resume_and_SendSigmaR2Resume(const PacketHeader & header, const System::PacketBufferHandle & msg); - CHIP_ERROR FindValidTrustedRoot(const System::PacketBufferTLVReader & tlvReader, uint32_t nTrustedRoots); - CHIP_ERROR ConstructSaltSigmaR2(const ByteSpan & rand, const Crypto::P256PublicKey & pubkey, const uint8_t * ipk, size_t ipkLen, +protected: + CHIP_ERROR GenerateDestinationID(const ByteSpan & random, const Credentials::P256PublicKeySpan & rootPubkey, NodeId nodeId, + FabricId fabricId, const ByteSpan & ipk, MutableByteSpan & destinationId); + +private: + CHIP_ERROR FindDestinationIdCandidate(const ByteSpan & destinationId, const ByteSpan & initiatorRandom, + const ByteSpan * ipkList, size_t ipkListEntries); + CHIP_ERROR ConstructSaltSigmaR2(const ByteSpan & rand, const Crypto::P256PublicKey & pubkey, const ByteSpan & ipk, MutableByteSpan & salt); - CHIP_ERROR Validate_and_RetrieveResponderID(const uint8_t * responderOpCert, uint16_t responderOpCertLen, - Crypto::P256PublicKey & responderID); - CHIP_ERROR ConstructSaltSigmaR3(const uint8_t * ipk, size_t ipkLen, MutableByteSpan & salt); - CHIP_ERROR ConstructTBS2Data(const uint8_t * responderOpCert, uint32_t responderOpCertLen, uint8_t * tbsData, - uint16_t & tbsDataLen); - CHIP_ERROR ConstructTBS3Data(const uint8_t * responderOpCert, uint32_t responderOpCertLen, uint8_t * tbsData, - uint16_t & tbsDataLen); - CHIP_ERROR ComputeIPK(const uint16_t sessionID, uint8_t * ipk, size_t ipkLen); + CHIP_ERROR Validate_and_RetrieveResponderID(const ByteSpan & responderOpCert, Crypto::P256PublicKey & responderID); + CHIP_ERROR ConstructSaltSigmaR3(const ByteSpan & ipk, MutableByteSpan & salt); + CHIP_ERROR ConstructTBS2Data(const ByteSpan & responderOpCert, uint8_t * tbsData, uint16_t & tbsDataLen); + CHIP_ERROR ConstructTBS3Data(const ByteSpan & responderOpCert, uint8_t * tbsData, uint16_t & tbsDataLen); + CHIP_ERROR RetrieveIPK(FabricId fabricId, MutableByteSpan & ipk); + + uint16_t EstimateTLVStructOverhead(uint16_t dataLen, uint16_t nFields) + { + return static_cast(dataLen + sizeof(uint64_t) * nFields); + } void SendErrorMsg(SigmaErrorType errorCode); @@ -253,13 +265,11 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin Crypto::Hash_SHA256_stream mCommissioningHash; Crypto::P256PublicKey mRemotePubKey; -#ifdef ENABLE_HSM_CASE_EPHERMAL_KEY +#ifdef ENABLE_HSM_CASE_EPHEMERAL_KEY Crypto::P256KeypairHSM mEphemeralKey; #else Crypto::P256Keypair mEphemeralKey; #endif - // TODO: Remove mFabricSecret later - Crypto::P256ECDHDerivedSecret mFabricSecret; Crypto::P256ECDHDerivedSecret mSharedSecret; Credentials::OperationalCredentialSet * mOpCredSet; Credentials::CertificateKeyId mTrustedRootId; @@ -267,7 +277,6 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin uint8_t mMessageDigest[Crypto::kSHA256_Hash_Length]; uint8_t mIPK[kIPKSize]; - uint8_t mRemoteIPK[kIPKSize]; Messaging::ExchangeContext * mExchangeCtxt = nullptr; SessionEstablishmentExchangeDispatch mMessageDispatch; @@ -281,6 +290,17 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin bool mPairingComplete = false; Transport::PeerConnectionState mConnectionState; + + virtual ByteSpan * GetIPKList() const + { + // TODO: Remove this list. Replace it with an actual method to retrieve an IPK list (e.g. from a Crypto Store API) + static uint8_t sIPKList[][kIPKSize] = { + { 0 }, /* Corresponds to the FabricID for the Commissioning Example. All zeros. */ + }; + static ByteSpan ipkListSpan[] = { ByteSpan(sIPKList[0]) }; + return ipkListSpan; + } + virtual size_t GetIPKListEntries() const { return 1; } }; typedef struct CASESessionSerialized diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index c3480e97bcda46..d2e1ab526961bb 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -67,11 +67,16 @@ P256SerializedKeypair accessoryOpKeysSerialized; P256Keypair commissionerOpKeys; P256Keypair accessoryOpKeys; + +CertificateKeyId trustedRootId = CertificateKeyId(sTestCert_Root01_SubjectKeyId); +uint8_t commissionerCredentialsIndex; + +NodeId Node01_01 = 0xDEDEDEDE00010001; } // namespace enum { - kStandardCertsCount = 4, + kStandardCertsCount = 3, }; class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate @@ -85,10 +90,43 @@ class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate uint32_t mNumPairingComplete = 0; }; -static CHIP_ERROR InitCredentialSets() +class TestCASESessionDestinationId : public CASESession { - CertificateKeyId trustedRootId = CertificateKeyId(sTestCert_Root01_SubjectKeyId); +public: + CHIP_ERROR GenerateDestinationID(const ByteSpan & random, const Credentials::P256PublicKeySpan & rootPubkey, NodeId nodeId, + FabricId fabricId, const ByteSpan & ipk, MutableByteSpan & destinationId) + { + return CASESession::GenerateDestinationID(random, rootPubkey, nodeId, fabricId, ipk, destinationId); + } +}; + +class TestCASESessionIPK : public CASESession +{ +protected: + ByteSpan * GetIPKList() const override + { + // TODO: Remove this list. Replace it with an actual method to retrieve an IPK list (e.g. from a Crypto Store API) + static uint8_t sIPKList[][kIPKSize] = { + { 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, 0x1D, + 0x1D }, /* Corresponds to the FabricID for the Node01_01 Test Vector */ + }; + static ByteSpan ipkListSpan[] = { ByteSpan(sIPKList[0]) }; + return ipkListSpan; + } + size_t GetIPKListEntries() const override { return 1; } +}; + +class TestCASEServerIPK : public CASEServer +{ +public: + TestCASESessionIPK & GetSession() override { return mPairingSession; } +private: + TestCASESessionIPK mPairingSession; +}; + +static CHIP_ERROR InitCredentialSets() +{ commissionerDevOpCred.Release(); accessoryDevOpCred.Release(); commissionerCertificateSet.Release(); @@ -130,6 +168,7 @@ static CHIP_ERROR InitCredentialSets() BitFlags(CertDecodeFlags::kIsTrustAnchor))); ReturnErrorOnFailure(commissionerDevOpCred.Init(&commissionerCertificateSet, 1)); + commissionerCredentialsIndex = static_cast(commissionerDevOpCred.GetCertCount() - 1U); ReturnErrorOnFailure(commissionerDevOpCred.SetDevOpCred(trustedRootId, sTestCert_Node01_01_Chip, static_cast(sTestCert_Node01_01_Chip_Len))); @@ -150,7 +189,7 @@ void CASE_SecurePairingWaitTest(nlTestSuite * inSuite, void * inContext) { // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegate; - CASESession pairing; + TestCASESessionIPK pairing; NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(&accessoryDevOpCred, 0, nullptr) == CHIP_ERROR_INVALID_ARGUMENT); NL_TEST_ASSERT(inSuite, pairing.ListenForSessionEstablishment(&accessoryDevOpCred, 0, &delegate) == CHIP_NO_ERROR); @@ -168,11 +207,11 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context = ctx.NewExchangeToLocal(&pairing); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, 2, 0, nullptr, - nullptr) != CHIP_NO_ERROR); + pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, + commissionerCredentialsIndex, Node01_01, 0, nullptr, nullptr) != CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, - pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, 2, 0, context, - &delegate) == CHIP_NO_ERROR); + pairing.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, + commissionerCredentialsIndex, Node01_01, 0, context, &delegate) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 1); @@ -186,7 +225,8 @@ void CASE_SecurePairingStartTest(nlTestSuite * inSuite, void * inContext) ExchangeContext * context1 = ctx.NewExchangeToLocal(&pairing1); NL_TEST_ASSERT(inSuite, - pairing1.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, 2, 0, context1, + pairing1.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, + commissionerCredentialsIndex, Node01_01, 0, context1, &delegate) == CHIP_ERROR_BAD_REQUEST); gLoopback.mMessageSendError = CHIP_NO_ERROR; } @@ -198,12 +238,10 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegateAccessory; - CASESession pairingAccessory; + TestCASESessionIPK pairingAccessory; CASESessionSerializable serializableCommissioner; CASESessionSerializable serializableAccessory; - NL_TEST_ASSERT(inSuite, InitCredentialSets() == CHIP_NO_ERROR); - gLoopback.mSentMessageCount = 0; NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, pairingAccessory.MessageDispatch().Init(&gTransportMgr) == CHIP_NO_ERROR); @@ -217,8 +255,9 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, pairingAccessory.ListenForSessionEstablishment(&accessoryDevOpCred, 0, &delegateAccessory) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, - pairingCommissioner.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, 1, 0, - contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner.EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &commissionerDevOpCred, + commissionerCredentialsIndex, Node01_01, 0, contextCommissioner, + &delegateCommissioner) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 3); NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 1); @@ -309,7 +348,7 @@ class TestPersistentStorageDelegate : public PersistentStorageDelegate uint16_t valuesize[16]; }; -CASEServer gPairingServer; +TestCASEServerIPK gPairingServer; void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext) { @@ -352,7 +391,8 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte ChipCertificateSet certificates; OperationalCredentialSet credentials; CertificateKeyId rootKeyId; - NL_TEST_ASSERT(inSuite, fabric->GetCredentials(credentials, certificates, rootKeyId) == CHIP_NO_ERROR); + uint8_t credentialsIndex; + NL_TEST_ASSERT(inSuite, fabric->GetCredentials(credentials, certificates, rootKeyId, credentialsIndex) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, gPairingServer.ListenForSessionEstablishment(&ctx.GetExchangeManager(), &gTransportMgr, @@ -362,8 +402,9 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte ExchangeContext * contextCommissioner = ctx.NewExchangeToLocal(pairingCommissioner); NL_TEST_ASSERT(inSuite, - pairingCommissioner->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &credentials, 1, 0, - contextCommissioner, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &credentials, + credentialsIndex, Node01_01, 0, contextCommissioner, + &delegateCommissioner) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, gLoopback.mSentMessageCount == 3); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); @@ -373,8 +414,9 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte ExchangeContext * contextCommissioner1 = ctx.NewExchangeToLocal(pairingCommissioner1); NL_TEST_ASSERT(inSuite, - pairingCommissioner1->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &credentials, 1, 0, - contextCommissioner1, &delegateCommissioner) == CHIP_NO_ERROR); + pairingCommissioner1->EstablishSession(Transport::PeerAddress(Transport::Type::kBle), &credentials, + credentialsIndex, Node01_01, 0, contextCommissioner1, + &delegateCommissioner) == CHIP_NO_ERROR); chip::Platform::Delete(pairingCommissioner); chip::Platform::Delete(pairingCommissioner1); @@ -435,6 +477,39 @@ void CASE_SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) chip::Platform::Delete(testPairingSession2); } +void CASE_DestinationIDGenerationTest(nlTestSuite * inSuite, void * inContext) +{ + TestCASESessionDestinationId pairingCommissioner; + + uint8_t random[kSigmaParamRandomNumberSize] = { 0x7e, 0x17, 0x12, 0x31, 0x56, 0x8d, 0xfa, 0x17, 0x20, 0x6b, 0x3a, + 0xcc, 0xf8, 0xfa, 0xec, 0x2f, 0x4d, 0x21, 0xb5, 0x80, 0x11, 0x31, + 0x96, 0xf4, 0x7c, 0x7c, 0x4d, 0xeb, 0x81, 0x0a, 0x73, 0xdc }; + uint8_t destinationIdentifier[kSHA256_Hash_Length] = { 0 }; + NodeId nodeId = 0xCD5544AA7B13EF14; + FabricId fabricId = 0x2906C908D115D362; + uint8_t rootPubkey[kP256_PublicKey_Length] = { 0x04, 0x4a, 0x9f, 0x42, 0xb1, 0xca, 0x48, 0x40, 0xd3, 0x72, 0x92, 0xbb, 0xc7, + 0xf6, 0xa7, 0xe1, 0x1e, 0x22, 0x20, 0x0c, 0x97, 0x6f, 0xc9, 0x00, 0xdb, 0xc9, + 0x8a, 0x7a, 0x38, 0x3a, 0x64, 0x1c, 0xb8, 0x25, 0x4a, 0x2e, 0x56, 0xd4, 0xe2, + 0x95, 0xa8, 0x47, 0x94, 0x3b, 0x4e, 0x38, 0x97, 0xc4, 0xa7, 0x73, 0xe9, 0x30, + 0x27, 0x7b, 0x4d, 0x9f, 0xbe, 0xde, 0x8a, 0x05, 0x26, 0x86, 0xbf, 0xac, 0xfa }; + P256PublicKeySpan rootPubkeySpan(rootPubkey); + uint8_t destinationIdentifierTestVector[kSHA256_Hash_Length] = { 0xc8, 0xe1, 0x70, 0x0d, 0x12, 0x5a, 0xff, 0xbc, + 0xea, 0xda, 0x34, 0x2a, 0x0d, 0x00, 0xdb, 0x7c, + 0xa0, 0x65, 0x05, 0xae, 0x5d, 0x0b, 0x29, 0x87, + 0xf3, 0xaf, 0x4b, 0x77, 0xe3, 0x94, 0x05, 0x1d }; + + uint8_t ipk[] = { 0x4a, 0x71, 0xcd, 0xd7, 0xb2, 0xa3, 0xca, 0x90, 0x24, 0xf9, 0x6f, 0x3c, 0x96, 0xa1, 0x9d, 0xee }; + + { + MutableByteSpan destinationIdSpan(destinationIdentifier, sizeof(destinationIdentifier)); + NL_TEST_ASSERT(inSuite, + pairingCommissioner.GenerateDestinationID(ByteSpan(random, sizeof(random)), rootPubkeySpan, nodeId, fabricId, + ByteSpan(ipk, sizeof(ipk)), destinationIdSpan) == CHIP_NO_ERROR); + } + + NL_TEST_ASSERT(inSuite, memcmp(destinationIdentifier, destinationIdentifierTestVector, sizeof(destinationIdentifier)) == 0); +} + // Test Suite /** @@ -448,6 +523,7 @@ static const nlTest sTests[] = NL_TEST_DEF("Handshake", CASE_SecurePairingHandshakeTest), NL_TEST_DEF("ServerHandshake", CASE_SecurePairingHandshakeServerTest), NL_TEST_DEF("Serialize", CASE_SecurePairingSerializeTest), + NL_TEST_DEF("DestinationID Generation", CASE_DestinationIDGenerationTest), NL_TEST_SENTINEL() }; diff --git a/src/transport/FabricTable.cpp b/src/transport/FabricTable.cpp index 09ac68f3205da2..708e48f2bc8541 100644 --- a/src/transport/FabricTable.cpp +++ b/src/transport/FabricTable.cpp @@ -357,7 +357,7 @@ CHIP_ERROR FabricInfo::SetOperationalCertsFromCertArray(const ByteSpan & certArr } CHIP_ERROR FabricInfo::GetCredentials(OperationalCredentialSet & credentials, ChipCertificateSet & certificates, - CertificateKeyId & rootKeyId) + CertificateKeyId & rootKeyId, uint8_t & credentialsIndex) { constexpr uint8_t kMaxNumCertsInOpCreds = 3; ReturnErrorOnFailure(certificates.Init(kMaxNumCertsInOpCreds)); @@ -374,6 +374,7 @@ CHIP_ERROR FabricInfo::GetCredentials(OperationalCredentialSet & credentials, Ch credentials.Release(); ReturnErrorOnFailure(credentials.Init(&certificates, 1)); + credentialsIndex = static_cast(credentials.GetCertCount() - 1U); rootKeyId = credentials.GetTrustedRootId(0); diff --git a/src/transport/FabricTable.h b/src/transport/FabricTable.h index 0fd9c4e35314c9..067e484a7d2292 100644 --- a/src/transport/FabricTable.h +++ b/src/transport/FabricTable.h @@ -122,8 +122,19 @@ class DLL_EXPORT FabricInfo return (mRootCert != nullptr && mNOCCert != nullptr && mRootCertLen != 0 && mNOCCertLen != 0); } + /** + * @brief + * Retrieve the credentials corresponding to the device being commissioned in form of OperationalCredentialSet. + * + * @param credentials Credential Set object containing the device's certificate set and keypair. + * @param certSet Set of Root [+ ICA] certificates corresponding to the device's credential set. + * @param rootKeyId Trusted Root Id corresponding to the device's credential set. + * @param credentialsIndex Index for the retrieved credentials corresponding to this device's credential set. + * + * @return CHIP_ERROR + */ CHIP_ERROR GetCredentials(Credentials::OperationalCredentialSet & credentials, Credentials::ChipCertificateSet & certSet, - Credentials::CertificateKeyId & rootKeyId); + Credentials::CertificateKeyId & rootKeyId, uint8_t & credentialsIndex); const uint8_t * GetTrustedRoot(uint16_t & size) { diff --git a/src/transport/PeerConnectionState.h b/src/transport/PeerConnectionState.h index 6f8e01a66855c4..a0f22b7b073c74 100644 --- a/src/transport/PeerConnectionState.h +++ b/src/transport/PeerConnectionState.h @@ -68,6 +68,7 @@ class PeerConnectionState uint16_t GetPeerKeyID() const { return mPeerKeyID; } void SetPeerKeyID(uint16_t id) { mPeerKeyID = id; } + // TODO: Rename KeyID to SessionID uint16_t GetLocalKeyID() const { return mLocalKeyID; } void SetLocalKeyID(uint16_t id) { mLocalKeyID = id; }