From 836944463375a5638d2a95f9e1062704a17783d5 Mon Sep 17 00:00:00 2001 From: Nivi Sarkar <55898241+nivi-apple@users.noreply.github.com> Date: Thu, 2 Dec 2021 09:27:53 -0800 Subject: [PATCH] Add support for CASE session caching for session resume use cases (#11937) - Add tests for the CASE session cache - Update the CASESessionCachable struct to have only necessary members and rename the struct and APIs appropriately - Remove the tests that serilaize and deserilaize the CASE Session as its outdated --- src/lib/core/CHIPConfig.h | 10 + src/protocols/secure_channel/BUILD.gn | 2 + src/protocols/secure_channel/CASESession.cpp | 92 ++----- src/protocols/secure_channel/CASESession.h | 40 +-- .../secure_channel/CASESessionCache.cpp | 105 ++++++++ .../secure_channel/CASESessionCache.h | 44 +++ src/protocols/secure_channel/tests/BUILD.gn | 1 + .../secure_channel/tests/TestCASESession.cpp | 68 +---- .../tests/TestCASESessionCache.cpp | 251 ++++++++++++++++++ 9 files changed, 450 insertions(+), 163 deletions(-) create mode 100644 src/protocols/secure_channel/CASESessionCache.cpp create mode 100644 src/protocols/secure_channel/CASESessionCache.h create mode 100644 src/protocols/secure_channel/tests/TestCASESessionCache.cpp diff --git a/src/lib/core/CHIPConfig.h b/src/lib/core/CHIPConfig.h index 4489d2fe189e87..4157026b139cf7 100644 --- a/src/lib/core/CHIPConfig.h +++ b/src/lib/core/CHIPConfig.h @@ -2722,6 +2722,16 @@ extern const char CHIP_NON_PRODUCTION_MARKER[]; #define CHIP_CONFIG_MAX_SESSION_RECOVERY_DELEGATES 3 #endif +/** + * @def CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE + * + * @brief + * Maximum number of CASE sessions that a device caches, that can be resumed + */ +#ifndef CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE +#define CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE 4 +#endif + /** * @} */ diff --git a/src/protocols/secure_channel/BUILD.gn b/src/protocols/secure_channel/BUILD.gn index 525406a1c1acdc..5b4c9ef9580cd2 100644 --- a/src/protocols/secure_channel/BUILD.gn +++ b/src/protocols/secure_channel/BUILD.gn @@ -8,6 +8,8 @@ static_library("secure_channel") { "CASEServer.h", "CASESession.cpp", "CASESession.h", + "CASESessionCache.cpp", + "CASESessionCache.h", "PASESession.cpp", "PASESession.h", "RendezvousParameters.h", diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index e4722ea380ada6..6ff1dec6e48c9e 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -69,8 +69,6 @@ constexpr uint8_t kTBEData3_Nonce[] = constexpr size_t kTBEDataNonceLength = sizeof(kTBEData2_Nonce); static_assert(sizeof(kTBEData2_Nonce) == sizeof(kTBEData3_Nonce), "TBEData2_Nonce and TBEData3_Nonce must be same size"); -constexpr uint8_t kCASESessionVersion = 1; - enum { kTag_TBEData_SenderNOC = 1, @@ -124,96 +122,48 @@ void CASESession::CloseExchange() } } -CHIP_ERROR CASESession::Serialize(CASESessionSerialized & output) -{ - uint16_t serializedLen = 0; - CASESessionSerializable serializable; - - VerifyOrReturnError(BASE64_ENCODED_LEN(sizeof(serializable)) <= sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT); - - ReturnErrorOnFailure(ToSerializable(serializable)); - - serializedLen = chip::Base64Encode(Uint8::to_const_uchar(reinterpret_cast(&serializable)), - static_cast(sizeof(serializable)), Uint8::to_char(output.inner)); - VerifyOrReturnError(serializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(serializedLen < sizeof(output.inner), CHIP_ERROR_INVALID_ARGUMENT); - output.inner[serializedLen] = '\0'; - - return CHIP_NO_ERROR; -} - -CHIP_ERROR CASESession::Deserialize(CASESessionSerialized & input) -{ - CASESessionSerializable serializable; - size_t maxlen = BASE64_ENCODED_LEN(sizeof(serializable)); - size_t len = strnlen(Uint8::to_char(input.inner), maxlen); - uint16_t deserializedLen = 0; - - VerifyOrReturnError(len < sizeof(CASESessionSerialized), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(CanCastTo(len), CHIP_ERROR_INVALID_ARGUMENT); - - memset(&serializable, 0, sizeof(serializable)); - deserializedLen = - Base64Decode(Uint8::to_const_char(input.inner), static_cast(len), Uint8::to_uchar((uint8_t *) &serializable)); - - VerifyOrReturnError(deserializedLen > 0, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(deserializedLen <= sizeof(serializable), CHIP_ERROR_INVALID_ARGUMENT); - - ReturnErrorOnFailure(FromSerializable(serializable)); - - return CHIP_NO_ERROR; -} - -CHIP_ERROR CASESession::ToSerializable(CASESessionSerializable & serializable) +CHIP_ERROR CASESession::ToCachable(CASESessionCachable & cachableSession) { const NodeId peerNodeId = GetPeerNodeId(); VerifyOrReturnError(CanCastTo(mSharedSecret.Length()), CHIP_ERROR_INTERNAL); - VerifyOrReturnError(CanCastTo(sizeof(mMessageDigest)), CHIP_ERROR_INTERNAL); VerifyOrReturnError(CanCastTo(peerNodeId), CHIP_ERROR_INTERNAL); - memset(&serializable, 0, sizeof(serializable)); - serializable.mSharedSecretLen = LittleEndian::HostSwap16(static_cast(mSharedSecret.Length())); - serializable.mMessageDigestLen = LittleEndian::HostSwap16(static_cast(sizeof(mMessageDigest))); - serializable.mVersion = kCASESessionVersion; - serializable.mPeerNodeId = LittleEndian::HostSwap64(peerNodeId); - for (size_t i = 0; i < serializable.mPeerCATs.size(); i++) + memset(&cachableSession, 0, sizeof(cachableSession)); + cachableSession.mSharedSecretLen = LittleEndian::HostSwap16(static_cast(mSharedSecret.Length())); + cachableSession.mPeerNodeId = LittleEndian::HostSwap64(peerNodeId); + for (size_t i = 0; i < cachableSession.mPeerCATs.size(); i++) { - serializable.mPeerCATs.val[i] = LittleEndian::HostSwap32(GetPeerCATs().val[i]); + cachableSession.mPeerCATs.val[i] = LittleEndian::HostSwap32(GetPeerCATs().val[i]); } - serializable.mLocalSessionId = LittleEndian::HostSwap16(GetLocalSessionId()); - serializable.mPeerSessionId = LittleEndian::HostSwap16(GetPeerSessionId()); + // TODO: Get the fabric index + cachableSession.mLocalFabricIndex = 0; + cachableSession.mSessionSetupTimeStamp = LittleEndian::HostSwap64(mSessionSetupTimeStamp); - memcpy(serializable.mResumptionId, mResumptionId, sizeof(mResumptionId)); - memcpy(serializable.mSharedSecret, mSharedSecret, mSharedSecret.Length()); - memcpy(serializable.mMessageDigest, mMessageDigest, sizeof(mMessageDigest)); + memcpy(cachableSession.mResumptionId, mResumptionId, sizeof(mResumptionId)); + memcpy(cachableSession.mSharedSecret, mSharedSecret, mSharedSecret.Length()); return CHIP_NO_ERROR; } -CHIP_ERROR CASESession::FromSerializable(const CASESessionSerializable & serializable) +CHIP_ERROR CASESession::FromCachable(const CASESessionCachable & cachableSession) { - VerifyOrReturnError(serializable.mVersion == kCASESessionVersion, CHIP_ERROR_VERSION_MISMATCH); - - uint16_t length = LittleEndian::HostSwap16(serializable.mSharedSecretLen); + uint16_t length = LittleEndian::HostSwap16(cachableSession.mSharedSecretLen); ReturnErrorOnFailure(mSharedSecret.SetLength(static_cast(length))); memset(mSharedSecret, 0, sizeof(mSharedSecret.Capacity())); - memcpy(mSharedSecret, serializable.mSharedSecret, length); - - length = LittleEndian::HostSwap16(serializable.mMessageDigestLen); - VerifyOrReturnError(length <= sizeof(mMessageDigest), CHIP_ERROR_INVALID_ARGUMENT); - memcpy(mMessageDigest, serializable.mMessageDigest, length); + memcpy(mSharedSecret, cachableSession.mSharedSecret, length); - SetPeerNodeId(LittleEndian::HostSwap64(serializable.mPeerNodeId)); + SetPeerNodeId(LittleEndian::HostSwap64(cachableSession.mPeerNodeId)); Credentials::CATValues peerCATs; - for (size_t i = 0; i < serializable.mPeerCATs.size(); i++) + for (size_t i = 0; i < cachableSession.mPeerCATs.size(); i++) { - peerCATs.val[i] = LittleEndian::HostSwap32(serializable.mPeerCATs.val[i]); + peerCATs.val[i] = LittleEndian::HostSwap32(cachableSession.mPeerCATs.val[i]); } SetPeerCATs(peerCATs); - SetLocalSessionId(LittleEndian::HostSwap16(serializable.mLocalSessionId)); - SetPeerSessionId(LittleEndian::HostSwap16(serializable.mPeerSessionId)); + SetSessionTimeStamp(LittleEndian::HostSwap64(cachableSession.mSessionSetupTimeStamp)); + // TODO: Set the fabric index correctly + mLocalFabricIndex = cachableSession.mLocalFabricIndex; - memcpy(mResumptionId, serializable.mResumptionId, sizeof(mResumptionId)); + memcpy(mResumptionId, cachableSession.mResumptionId, sizeof(mResumptionId)); const ByteSpan * ipkListSpan = GetIPKList(); VerifyOrReturnError(ipkListSpan->size() == sizeof(mIPK), CHIP_ERROR_INVALID_ARGUMENT); diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 860f28b9d28a51..fab7aa2b40d4f7 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -58,20 +58,15 @@ constexpr size_t kCASEResumptionIDSize = 16; #define CASE_EPHEMERAL_KEY 0xCA5EECD0 #endif -struct CASESessionSerialized; - -struct CASESessionSerializable +struct CASESessionCachable { - uint8_t mVersion; uint16_t mSharedSecretLen; uint8_t mSharedSecret[Crypto::kMax_ECDH_Secret_Length]; - uint16_t mMessageDigestLen; - uint8_t mMessageDigest[Crypto::kSHA256_Hash_Length]; + FabricIndex mLocalFabricIndex; NodeId mPeerNodeId; Credentials::CATValues mPeerCATs; - uint16_t mLocalSessionId; - uint16_t mPeerSessionId; uint8_t mResumptionId[kCASEResumptionIDSize]; + uint64_t mSessionSetupTimeStamp; }; class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public PairingSession @@ -154,24 +149,14 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin const char * GetR2ISessionInfo() const override { return "Sigma R2I Key"; } /** - * @brief Serialize the Pairing Session to a string. - **/ - CHIP_ERROR Serialize(CASESessionSerialized & output); - - /** - * @brief Deserialize the Pairing Session from the string. - **/ - CHIP_ERROR Deserialize(CASESessionSerialized & input); - - /** - * @brief Serialize the CASESession to the given serializable data structure for secure pairing + * @brief Serialize the CASESession to the given cachableSession data structure for secure pairing **/ - CHIP_ERROR ToSerializable(CASESessionSerializable & output); + CHIP_ERROR ToCachable(CASESessionCachable & output); /** - * @brief Reconstruct secure pairing class from the serializable data structure. + * @brief Reconstruct secure pairing class from the cachableSession data structure. **/ - CHIP_ERROR FromSerializable(const CASESessionSerializable & output); + CHIP_ERROR FromCachable(const CASESessionCachable & output); SessionEstablishmentExchangeDispatch & MessageDispatch() { return mMessageDispatch; } @@ -277,6 +262,9 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin State mState; + uint8_t mLocalFabricIndex = 0; + uint64_t mSessionSetupTimeStamp = 0; + protected: bool mCASESessionEstablished = false; @@ -290,12 +278,8 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin return ipkListSpan; } virtual size_t GetIPKListEntries() const { return 1; } -}; -typedef struct CASESessionSerialized -{ - // Extra uint64_t to account for padding bytes (NULL termination, and some decoding overheads) - uint8_t inner[BASE64_ENCODED_LEN(sizeof(CASESessionSerializable) + sizeof(uint64_t))]; -} CASESessionSerialized; + void SetSessionTimeStamp(uint64_t timestamp) { mSessionSetupTimeStamp = timestamp; } +}; } // namespace chip diff --git a/src/protocols/secure_channel/CASESessionCache.cpp b/src/protocols/secure_channel/CASESessionCache.cpp new file mode 100644 index 00000000000000..a50ff1ad00ceb6 --- /dev/null +++ b/src/protocols/secure_channel/CASESessionCache.cpp @@ -0,0 +1,105 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace chip { + +CASESessionCache::CASESessionCache() {} + +CASESessionCache::~CASESessionCache() +{ + mCachePool.ForEachActiveObject([&](auto * ec) { + mCachePool.ReleaseObject(ec); + return true; + }); +} + +CASESessionCachable * CASESessionCache::GetLRUSession() +{ + uint64_t minTimeStamp = UINT64_MAX; + CASESessionCachable * lruSession = nullptr; + mCachePool.ForEachActiveObject([&](auto * ec) { + if (minTimeStamp > ec->mSessionSetupTimeStamp) + { + minTimeStamp = ec->mSessionSetupTimeStamp; + lruSession = ec; + } + return true; + }); + return lruSession; +} + +CHIP_ERROR CASESessionCache::Add(CASESessionCachable & cachableSession) +{ + // It's not an error if a device doesn't have cache for storing the sessions. + VerifyOrReturnError(mCachePool.Capacity() > 0, CHIP_NO_ERROR); + + // If the cache is full, get the least recently used session index and release that. + if (mCachePool.Exhausted()) + { + mCachePool.ReleaseObject(GetLRUSession()); + } + + mCachePool.CreateObject(cachableSession); + return CHIP_NO_ERROR; +} + +CHIP_ERROR CASESessionCache::Remove(ResumptionID resumptionID) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + CASESession session; + mCachePool.ForEachActiveObject([&](auto * ec) { + if (resumptionID.data_equal(ResumptionID(ec->mResumptionId))) + { + mCachePool.ReleaseObject(ec); + } + return true; + }); + + return err; +} + +CHIP_ERROR CASESessionCache::Get(ResumptionID resumptionID, CASESessionCachable & outSessionCachable) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + bool found = false; + mCachePool.ForEachActiveObject([&](auto * ec) { + if (resumptionID.data_equal(ResumptionID(ec->mResumptionId))) + { + found = true; + outSessionCachable = *ec; + return false; + } + return true; + }); + + if (!found) + { + err = CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND; + } + + return err; +} + +CHIP_ERROR CASESessionCache::Get(const PeerId & peer, CASESessionCachable & outSessionCachable) +{ + // TODO: Implement this based on peer id + return CHIP_NO_ERROR; +} + +} // namespace chip diff --git a/src/protocols/secure_channel/CASESessionCache.h b/src/protocols/secure_channel/CASESessionCache.h new file mode 100644 index 00000000000000..96eeddfeaf4441 --- /dev/null +++ b/src/protocols/secure_channel/CASESessionCache.h @@ -0,0 +1,44 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace chip { + +using ResumptionID = FixedByteSpan; + +class CASESessionCache +{ +public: + CASESessionCache(); + virtual ~CASESessionCache(); + + CHIP_ERROR Add(CASESessionCachable & cachableSession); + CHIP_ERROR Remove(ResumptionID resumptionID); + CHIP_ERROR Get(ResumptionID resumptionID, CASESessionCachable & outCachableSession); + CHIP_ERROR Get(const PeerId & peer, CASESessionCachable & outCachableSession); + +private: + BitMapObjectPool mCachePool; + CASESessionCachable * GetLRUSession(); +}; + +} // namespace chip diff --git a/src/protocols/secure_channel/tests/BUILD.gn b/src/protocols/secure_channel/tests/BUILD.gn index ee036a67284595..c2ce3df2fe54b7 100644 --- a/src/protocols/secure_channel/tests/BUILD.gn +++ b/src/protocols/secure_channel/tests/BUILD.gn @@ -10,6 +10,7 @@ chip_test_suite("tests") { test_sources = [ "TestCASESession.cpp", + "TestCASESessionCache.cpp", # TODO - Fix Message Counter Sync to use group key # "TestMessageCounterManager.cpp", diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index c925a3ebf64491..4861bad51b9c2d 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -216,8 +216,8 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte // Test all combinations of invalid parameters TestCASESecurePairingDelegate delegateAccessory; TestCASESessionIPK pairingAccessory; - CASESessionSerializable serializableCommissioner; - CASESessionSerializable serializableAccessory; + CASESessionCachable serializableCommissioner; + CASESessionCachable serializableAccessory; gLoopback.mSentMessageCount = 0; NL_TEST_ASSERT(inSuite, pairingCommissioner.MessageDispatch().Init(&ctx.GetSecureSessionManager()) == CHIP_NO_ERROR); @@ -242,8 +242,8 @@ void CASE_SecurePairingHandshakeTestCommon(nlTestSuite * inSuite, void * inConte NL_TEST_ASSERT(inSuite, delegateAccessory.mNumPairingComplete == 1); NL_TEST_ASSERT(inSuite, delegateCommissioner.mNumPairingComplete == 1); - NL_TEST_ASSERT(inSuite, pairingCommissioner.ToSerializable(serializableCommissioner) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, pairingAccessory.ToSerializable(serializableAccessory) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingCommissioner.ToCachable(serializableCommissioner) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, pairingAccessory.ToCachable(serializableAccessory) == CHIP_NO_ERROR); NL_TEST_ASSERT(inSuite, memcmp(serializableCommissioner.mSharedSecret, serializableAccessory.mSharedSecret, serializableCommissioner.mSharedSecretLen) == 0); @@ -391,65 +391,6 @@ void CASE_SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inConte chip::Platform::Delete(pairingCommissioner1); } -void CASE_SecurePairingDeserialize(nlTestSuite * inSuite, void * inContext, CASESession & pairingCommissioner, - CASESession & deserialized) -{ - CASESessionSerialized serialized; - NL_TEST_ASSERT(inSuite, pairingCommissioner.Serialize(serialized) == CHIP_NO_ERROR); - - NL_TEST_ASSERT(inSuite, deserialized.Deserialize(serialized) == CHIP_NO_ERROR); - - // Serialize from the deserialized session, and check we get the same string back - CASESessionSerialized serialized2; - NL_TEST_ASSERT(inSuite, deserialized.Serialize(serialized2) == CHIP_NO_ERROR); - - NL_TEST_ASSERT(inSuite, strncmp(Uint8::to_char(serialized.inner), Uint8::to_char(serialized2.inner), sizeof(serialized)) == 0); -} - -void CASE_SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) -{ - TestCASESecurePairingDelegate delegateCommissioner; - - // Allocate on the heap to avoid stack overflow in some restricted test scenarios (e.g. QEMU) - auto * testPairingSession1 = chip::Platform::New(); - auto * testPairingSession2 = chip::Platform::New(); - - CASE_SecurePairingHandshakeTestCommon(inSuite, inContext, *testPairingSession1, delegateCommissioner); - CASE_SecurePairingDeserialize(inSuite, inContext, *testPairingSession1, *testPairingSession2); - - const uint8_t plain_text[] = { 0x86, 0x74, 0x64, 0xe5, 0x0b, 0xd4, 0x0d, 0x90, 0xe1, 0x17, 0xa3, 0x2d, 0x4b, 0xd4, 0xe1, 0xe6 }; - uint8_t encrypted[64]; - PacketHeader header; - MessageAuthenticationCode mac; - - header.SetSessionId(1); - NL_TEST_ASSERT(inSuite, header.IsEncrypted() == true); - NL_TEST_ASSERT(inSuite, header.MICTagLength() == 16); - - // Let's try encrypting using original session, and decrypting using deserialized - { - CryptoContext session1; - - NL_TEST_ASSERT(inSuite, - testPairingSession1->DeriveSecureSession(session1, CryptoContext::SessionRole::kInitiator) == CHIP_NO_ERROR); - - NL_TEST_ASSERT(inSuite, session1.Encrypt(plain_text, sizeof(plain_text), encrypted, header, mac) == CHIP_NO_ERROR); - } - - { - CryptoContext session2; - NL_TEST_ASSERT(inSuite, - testPairingSession2->DeriveSecureSession(session2, CryptoContext::SessionRole::kResponder) == CHIP_NO_ERROR); - - uint8_t decrypted[64]; - NL_TEST_ASSERT(inSuite, session2.Decrypt(encrypted, sizeof(plain_text), decrypted, header, mac) == CHIP_NO_ERROR); - NL_TEST_ASSERT(inSuite, memcmp(plain_text, decrypted, sizeof(plain_text)) == 0); - } - - chip::Platform::Delete(testPairingSession1); - chip::Platform::Delete(testPairingSession2); -} - struct Sigma1Params { // Purposefully not using constants like kSigmaParamRandomNumberSize that @@ -676,7 +617,6 @@ static const nlTest sTests[] = NL_TEST_DEF("Start", CASE_SecurePairingStartTest), NL_TEST_DEF("Handshake", CASE_SecurePairingHandshakeTest), NL_TEST_DEF("ServerHandshake", CASE_SecurePairingHandshakeServerTest), - NL_TEST_DEF("Serialize", CASE_SecurePairingSerializeTest), NL_TEST_DEF("Sigma1Parsing", CASE_Sigma1ParsingTest), NL_TEST_SENTINEL() diff --git a/src/protocols/secure_channel/tests/TestCASESessionCache.cpp b/src/protocols/secure_channel/tests/TestCASESessionCache.cpp new file mode 100644 index 00000000000000..42f7c5782afdd1 --- /dev/null +++ b/src/protocols/secure_channel/tests/TestCASESessionCache.cpp @@ -0,0 +1,251 @@ +/* + * + * Copyright (c) 2021 Project CHIP Authors + * All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file + * This file implements unit tests for the CASESession implementation. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace chip; + +using TestContext = chip::Test::MessagingContext; + +namespace { +TransportMgrBase gTransportMgr; +Test::LoopbackTransport gLoopback; +chip::Test::IOContext gIOContext; + +NodeId sTest_PeerId = 0xEDEDEDED00010001; + +uint8_t sTest_SharedSecret[] = { + 0x7d, 0x73, 0x5b, 0xef, 0xe9, 0x16, 0xa1, 0xc0, 0xca, 0x02, 0xf8, 0xca, 0x98, 0x81, 0xe4, 0x26, + 0x63, 0xaa, 0xaf, 0x9a, 0xb9, 0xc4, 0x33, 0xb2, 0x89, 0xbe, 0x26, 0x70, 0x10, 0x75, 0x74, 0x10, +}; + +uint8_t sTest_ResumptionId[kCASEResumptionIDSize] = { 0 }; + +} // namespace + +class CASESessionTest : public CASESession +{ +public: + void createCASESessionTestCachable(uint8_t i) + { + uint16_t sharedSecretLen = sizeof(sTest_SharedSecret); + sTest_SharedSecret[sharedSecretLen - 1] = static_cast(sTest_SharedSecret[sharedSecretLen - 1] + i); + uint64_t timestamp = static_cast(4000 + i * 1000); + sTest_ResumptionId[kCASEResumptionIDSize - 1] = static_cast(sTest_ResumptionId[kCASEResumptionIDSize - 1] + i); + + mCASESessionCachableArray[i].mSharedSecretLen = sharedSecretLen; + memcpy(mCASESessionCachableArray[i].mSharedSecret, sTest_SharedSecret, sharedSecretLen); + mCASESessionCachableArray[i].mPeerNodeId = static_cast(sTest_PeerId + i); + mCASESessionCachableArray[i].mPeerCATs.val[0] = (uint32_t) i; + memcpy(mCASESessionCachableArray[i].mResumptionId, sTest_ResumptionId, kCASEResumptionIDSize); + mCASESessionCachableArray[i].mLocalFabricIndex = 0; + mCASESessionCachableArray[i].mSessionSetupTimeStamp = timestamp; + } + + bool isEqual(int index, CASESessionCachable cachableSession) + { + return (cachableSession.mSharedSecretLen == mCASESessionCachableArray[index].mSharedSecretLen) && + ((ByteSpan(cachableSession.mSharedSecret)).data_equal(ByteSpan(mCASESessionCachableArray[index].mSharedSecret))) && + (cachableSession.mPeerNodeId == mCASESessionCachableArray[index].mPeerNodeId) && + cachableSession.mPeerCATs.val[0] == mCASESessionCachableArray[index].mPeerCATs.val[0] && + ((ResumptionID(cachableSession.mResumptionId)) + .data_equal(ResumptionID(mCASESessionCachableArray[index].mResumptionId))) && + (cachableSession.mLocalFabricIndex == mCASESessionCachableArray[index].mLocalFabricIndex) && + (cachableSession.mSessionSetupTimeStamp == mCASESessionCachableArray[index].mSessionSetupTimeStamp); + } + + void InitializeCASESessionCachableArray() + { + for (size_t j = 0; j < kCASEResumptionIDSize; j++) + { + sTest_ResumptionId[j] = 0x01; + } + for (uint8_t i = 0; i < CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE; i++) + { + createCASESessionTestCachable(i); + } + } + + CASESessionCachable mCASESessionCachableArray[CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE + 1] = { { 0 } }; + CASESessionCache mCASESessionCache; +}; + +CASESessionTest mCASESessionTest; + +static void CASESessionCache_Create_Test(nlTestSuite * inSuite, void * inContext) +{ + mCASESessionTest.InitializeCASESessionCachableArray(); +} + +static void CASESessionCache_Add_Test(nlTestSuite * inSuite, void * inContext) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + for (uint8_t i = 0; i < CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE; i++) + { + CASESession session; + err = mCASESessionTest.mCASESessionCache.Add(mCASESessionTest.mCASESessionCachableArray[i]); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + } +} + +static void CASESessionCache_Get_Test(nlTestSuite * inSuite, void * inContext) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + for (uint8_t i = 0; i < CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE; i++) + { + CASESessionCachable outCachableSession; + err = mCASESessionTest.mCASESessionCache.Get(ResumptionID(mCASESessionTest.mCASESessionCachableArray[i].mResumptionId), + outCachableSession); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, true == mCASESessionTest.isEqual(i, outCachableSession)); + } +} + +static void CASESessionCache_Add_When_Full_Test(nlTestSuite * inSuite, void * inContext) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + mCASESessionTest.createCASESessionTestCachable(CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE); + err = mCASESessionTest.mCASESessionCache.Add( + mCASESessionTest.mCASESessionCachableArray[CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE]); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + // Check if the entry with lowest timestamp has been removed + CASESessionCachable outCachableSession; + err = mCASESessionTest.mCASESessionCache.Get(ResumptionID(mCASESessionTest.mCASESessionCachableArray[0].mResumptionId), + outCachableSession); + NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND); + + // Check if the new entry has been added. + err = mCASESessionTest.mCASESessionCache.Get( + ResumptionID(mCASESessionTest.mCASESessionCachableArray[CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE].mResumptionId), + outCachableSession); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, true == mCASESessionTest.isEqual(CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE, outCachableSession)); +} + +static void CASESessionCache_Remove_Test(nlTestSuite * inSuite, void * inContext) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + for (uint8_t i = 1; i < CHIP_CONFIG_CASE_SESSION_RESUME_CACHE_SIZE + 1; i++) + { + err = mCASESessionTest.mCASESessionCache.Remove(ResumptionID(mCASESessionTest.mCASESessionCachableArray[i].mResumptionId)); + NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); + + CASESessionCachable outCachableSession; + err = mCASESessionTest.mCASESessionCache.Get(ResumptionID(mCASESessionTest.mCASESessionCachableArray[i].mResumptionId), + outCachableSession); + NL_TEST_ASSERT(inSuite, err == CHIP_ERROR_PERSISTED_STORAGE_VALUE_NOT_FOUND); + } +} + +// Test Suite + +/** + * Test Suite that lists all the test functions. + */ +// clang-format off +static const nlTest sTests[] = +{ + NL_TEST_DEF("Create", CASESessionCache_Create_Test), + NL_TEST_DEF("Add", CASESessionCache_Add_Test), + NL_TEST_DEF("Get", CASESessionCache_Get_Test), + NL_TEST_DEF("AddWhenFull", CASESessionCache_Add_When_Full_Test), + NL_TEST_DEF("Remove", CASESessionCache_Remove_Test), + + NL_TEST_SENTINEL() +}; +// clang-format on + +int CASESessionCache_Test_Setup(void * inContext); +int CASESessionCache_Test_Teardown(void * inContext); + +// clang-format off +static nlTestSuite sSuite = +{ + "Test-CHIP-SecurePairing-CASECache", + &sTests[0], + CASESessionCache_Test_Setup, + CASESessionCache_Test_Teardown, +}; +// clang-format on + +static TestContext sContext; + +namespace { +/* + * Set up the test suite. + */ +CHIP_ERROR CASETestCacheSetup(void * inContext) +{ + TestContext & ctx = *reinterpret_cast(inContext); + + ReturnErrorOnFailure(chip::Platform::MemoryInit()); + + gTransportMgr.Init(&gLoopback); + ReturnErrorOnFailure(gIOContext.Init()); + + ReturnErrorOnFailure(ctx.Init(&gTransportMgr, &gIOContext)); + + return CHIP_NO_ERROR; +} +} // anonymous namespace + +/** + * Set up the test suite. + */ +int CASESessionCache_Test_Setup(void * inContext) +{ + return CASETestCacheSetup(inContext) == CHIP_NO_ERROR ? SUCCESS : FAILURE; +} + +/** + * Tear down the test suite. + */ +int CASESessionCache_Test_Teardown(void * inContext) +{ + reinterpret_cast(inContext)->Shutdown(); + gIOContext.Shutdown(); + chip::Platform::MemoryShutdown(); + return SUCCESS; +} + +/** + * Main + */ +int TestCASESessionCache() +{ + // Run test suit against one context + nlTestRunner(&sSuite, &sContext); + + return (nlTestRunnerStats(&sSuite)); +} + +CHIP_REGISTER_TEST_SUITE(TestCASESessionCache)