diff --git a/examples/shell/shell_common/cmd_ping.cpp b/examples/shell/shell_common/cmd_ping.cpp index a75a3d18db30a9..b3121689f64e6f 100644 --- a/examples/shell/shell_common/cmd_ping.cpp +++ b/examples/shell/shell_common/cmd_ping.cpp @@ -328,7 +328,7 @@ void StartPinging(streamer_t * stream, char * destination) err = EstablishSecureSession(stream, GetEchoPeerAddress()); SuccessOrExit(err); - err = gEchoClient.Init(&gExchangeManager, SessionHandle(kTestDeviceNodeId, 0, 0, gFabricIndex)); + err = gEchoClient.Init(&gExchangeManager, SessionHandle(kTestDeviceNodeId, 1, 1, gFabricIndex)); SuccessOrExit(err); // Arrange to get a callback whenever an Echo Response is received. diff --git a/examples/shell/shell_common/cmd_send.cpp b/examples/shell/shell_common/cmd_send.cpp index 7d0015d59ac8d5..4e79ab38115482 100644 --- a/examples/shell/shell_common/cmd_send.cpp +++ b/examples/shell/shell_common/cmd_send.cpp @@ -127,7 +127,7 @@ CHIP_ERROR SendMessage(streamer_t * stream) uint32_t payloadSize = gSendArguments.GetPayloadSize(); // Create a new exchange context. - auto * ec = gExchangeManager.NewContext(SessionHandle(kTestDeviceNodeId, 0, 0, gFabricIndex), &gMockAppDelegate); + auto * ec = gExchangeManager.NewContext(SessionHandle(kTestDeviceNodeId, 1, 1, gFabricIndex), &gMockAppDelegate); VerifyOrExit(ec != nullptr, err = CHIP_ERROR_NO_MEMORY); payloadBuf = MessagePacketBuffer::New(payloadSize); diff --git a/src/app/CommandSender.cpp b/src/app/CommandSender.cpp index f22091d6de99f6..875ce0a08b868d 100644 --- a/src/app/CommandSender.cpp +++ b/src/app/CommandSender.cpp @@ -48,7 +48,7 @@ CHIP_ERROR CommandSender::SendCommandRequest(NodeId aNodeId, FabricIndex aFabric SuccessOrExit(err); // Create a new exchange context. - mpExchangeCtx = mpExchangeMgr->NewContext(secureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this); + mpExchangeCtx = mpExchangeMgr->NewContext(secureSession.ValueOr(SessionHandle(aNodeId, 1, 1, aFabricIndex)), this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(timeout); diff --git a/src/app/WriteClient.cpp b/src/app/WriteClient.cpp index 510a197dd6dacc..a77f58631ef91c 100644 --- a/src/app/WriteClient.cpp +++ b/src/app/WriteClient.cpp @@ -258,7 +258,7 @@ CHIP_ERROR WriteClient::SendWriteRequest(NodeId aNodeId, FabricIndex aFabricInde ClearExistingExchangeContext(); // Create a new exchange context. - mpExchangeCtx = mpExchangeMgr->NewContext(apSecureSession.ValueOr(SessionHandle(aNodeId, 0, 0, aFabricIndex)), this); + mpExchangeCtx = mpExchangeMgr->NewContext(apSecureSession.ValueOr(SessionHandle(aNodeId, 1, 1, aFabricIndex)), this); VerifyOrExit(mpExchangeCtx != nullptr, err = CHIP_ERROR_NO_MEMORY); mpExchangeCtx->SetResponseTimeout(timeout); diff --git a/src/app/tests/integration/chip_im_initiator.cpp b/src/app/tests/integration/chip_im_initiator.cpp index b5fb2cf0bb531d..e8e4e2740f0ed9 100644 --- a/src/app/tests/integration/chip_im_initiator.cpp +++ b/src/app/tests/integration/chip_im_initiator.cpp @@ -313,7 +313,7 @@ CHIP_ERROR SendReadRequest() printf("\nSend read request message to Node: %" PRIu64 "\n", chip::kTestDeviceNodeId); - chip::app::ReadPrepareParams readPrepareParams(chip::SessionHandle(chip::kTestDeviceNodeId, 0, 0, gFabricIndex)); + chip::app::ReadPrepareParams readPrepareParams(chip::SessionHandle(chip::kTestDeviceNodeId, 1, 1, gFabricIndex)); readPrepareParams.mTimeout = gMessageTimeoutMsec; readPrepareParams.mpAttributePathParamsList = &attributePathParams; readPrepareParams.mAttributePathParamsListSize = 1; @@ -374,7 +374,7 @@ CHIP_ERROR SendSubscribeRequest() CHIP_ERROR err = CHIP_NO_ERROR; gLastMessageTime = chip::System::SystemClock().GetMonotonicMilliseconds(); - chip::app::ReadPrepareParams readPrepareParams(chip::SessionHandle(chip::kTestDeviceNodeId, 0, 0, gFabricIndex)); + chip::app::ReadPrepareParams readPrepareParams(chip::SessionHandle(chip::kTestDeviceNodeId, 1, 1, gFabricIndex)); chip::app::EventPathParams eventPathParams[2]; chip::app::AttributePathParams attributePathParams[1]; readPrepareParams.mpEventPathParamsList = eventPathParams; diff --git a/src/crypto/CHIPCryptoPAL.h b/src/crypto/CHIPCryptoPAL.h index 750b18ee9f68ef..ad6313027f7533 100644 --- a/src/crypto/CHIPCryptoPAL.h +++ b/src/crypto/CHIPCryptoPAL.h @@ -50,6 +50,8 @@ constexpr size_t kSHA1_Hash_Length = 20; constexpr size_t CHIP_CRYPTO_GROUP_SIZE_BYTES = kP256_FE_Length; constexpr size_t CHIP_CRYPTO_PUBLIC_KEY_SIZE_BYTES = kP256_Point_Length; +constexpr size_t CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES = 16; + constexpr size_t kMax_ECDH_Secret_Length = kP256_FE_Length; constexpr size_t kMax_ECDSA_Signature_Length = kP256_ECDSA_Signature_Length_Raw; constexpr size_t kMAX_FE_Length = kP256_FE_Length; diff --git a/src/messaging/ExchangeContext.cpp b/src/messaging/ExchangeContext.cpp index 895f56d9e8b40e..3f358be87e288d 100644 --- a/src/messaging/ExchangeContext.cpp +++ b/src/messaging/ExchangeContext.cpp @@ -299,7 +299,7 @@ bool ExchangeContext::MatchExchange(SessionHandle session, const PacketHeader & // TODO: This check should be already implied by the equality of session check, // It should be removed after we have implemented the temporary node id for PASE and CASE sessions - && (IsEncryptionRequired() == packetHeader.GetFlags().Has(Header::FlagValues::kEncryptedMessage)) + && (IsEncryptionRequired() == packetHeader.IsEncrypted()) // AND The message was sent by an initiator and the exchange context is a responder (IsInitiator==false) // OR The message was sent by a responder and the exchange context is an initiator (IsInitiator==true) (for the broadcast diff --git a/src/messaging/ExchangeMgr.cpp b/src/messaging/ExchangeMgr.cpp index 3d1b17e42fcfa5..94cd87eb6ceebf 100644 --- a/src/messaging/ExchangeMgr.cpp +++ b/src/messaging/ExchangeMgr.cpp @@ -293,7 +293,7 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const ChipLogDetail(ExchangeManager, "Handling via exchange: " ChipLogFormatExchange ", Delegate: 0x%p", ChipLogValueExchange(ec), ec->GetDelegate()); - if (ec->IsEncryptionRequired() != packetHeader.GetFlags().Has(Header::FlagValues::kEncryptedMessage)) + if (ec->IsEncryptionRequired() != packetHeader.IsEncrypted()) { ChipLogError(ExchangeManager, "OnMessageReceived failed, err = %s", ErrorStr(CHIP_ERROR_INVALID_MESSAGE_TYPE)); ec->Close(); diff --git a/src/messaging/tests/echo/echo_requester.cpp b/src/messaging/tests/echo/echo_requester.cpp index a62bdf53fc0021..a7ebfb6dbfaa07 100644 --- a/src/messaging/tests/echo/echo_requester.cpp +++ b/src/messaging/tests/echo/echo_requester.cpp @@ -256,7 +256,7 @@ int main(int argc, char * argv[]) err = EstablishSecureSession(); SuccessOrExit(err); - err = gEchoClient.Init(&gExchangeManager, chip::SessionHandle(chip::kTestDeviceNodeId, 0, 0, gFabricIndex)); + err = gEchoClient.Init(&gExchangeManager, chip::SessionHandle(chip::kTestDeviceNodeId, 1, 1, gFabricIndex)); SuccessOrExit(err); // Arrange to get a callback whenever an Echo Response is received. diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index 8c6c99b7804118..6c3c0be6274011 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -339,8 +339,10 @@ class SecurePairingUsingTestSecret : public PairingSession public: SecurePairingUsingTestSecret() { - SetLocalSessionId(0); - SetPeerSessionId(0); + // Do not set to 0 to prevent unwanted unsecured session + // since the session type is unknown. + SetLocalSessionId(1); + SetPeerSessionId(1); } SecurePairingUsingTestSecret(uint16_t peerSessionId, uint16_t localSessionId) diff --git a/src/protocols/secure_channel/SessionIDAllocator.cpp b/src/protocols/secure_channel/SessionIDAllocator.cpp index d95d1c4cf1d69d..650dab3c2fbdbf 100644 --- a/src/protocols/secure_channel/SessionIDAllocator.cpp +++ b/src/protocols/secure_channel/SessionIDAllocator.cpp @@ -24,6 +24,7 @@ namespace chip { CHIP_ERROR SessionIDAllocator::Allocate(uint16_t & id) { VerifyOrReturnError(mNextAvailable < kMaxSessionID, CHIP_ERROR_NO_MEMORY); + VerifyOrReturnError(mNextAvailable > kUnsecuredSessionId, CHIP_ERROR_INTERNAL); id = mNextAvailable; // TODO - Update SessionID allocator to use freed session IDs @@ -34,7 +35,8 @@ CHIP_ERROR SessionIDAllocator::Allocate(uint16_t & id) void SessionIDAllocator::Free(uint16_t id) { - if (mNextAvailable > 0 && (mNextAvailable - 1) == id) + // As per spec 4.4.1.3 Session ID of 0 is reserved for Unsecure communication + if (mNextAvailable > (kUnsecuredSessionId + 1) && (mNextAvailable - 1) == id) { mNextAvailable--; } diff --git a/src/protocols/secure_channel/SessionIDAllocator.h b/src/protocols/secure_channel/SessionIDAllocator.h index f5649bf98949ed..25cb69334db89b 100644 --- a/src/protocols/secure_channel/SessionIDAllocator.h +++ b/src/protocols/secure_channel/SessionIDAllocator.h @@ -18,6 +18,15 @@ #pragma once #include +#include + +// Spec 4.4.1.3 +// ===== Session ID (16 bits) +// An unsigned integer value identifying the session associated with this message. +// The session identifies the particular key used to encrypt a message out of the set of +// available keys (either session or group), and the particular encryption/message +// integrity algorithm to use for the message.The Session ID field is always present. +// A Session ID of 0 SHALL indicate an unsecured session with no encryption or message integrity checking. namespace chip { @@ -34,9 +43,10 @@ class SessionIDAllocator uint16_t Peek(); private: - // Session ID is a 15 bit value (16th bit indicates unicast/group key) - static constexpr uint16_t kMaxSessionID = (1 << 15) - 1; - uint16_t mNextAvailable = 0; + static constexpr uint16_t kMaxSessionID = UINT16_MAX; + static constexpr uint16_t kUnsecuredSessionId = 0; + + uint16_t mNextAvailable = 1; }; } // namespace chip diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 276c59b343b539..d249fce013244f 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -405,6 +405,10 @@ void CASE_SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) PacketHeader header; MessageAuthenticationCode mac; + header.SetSessionId(1); + NL_TEST_ASSERT(inSuite, header.IsEncrypted() == true); + NL_TEST_ASSERT(inSuite, header.MICTagLength() == 16); + // Let's try encrypting using original session, and decrypting using deserialized { CryptoContext session1; diff --git a/src/protocols/secure_channel/tests/TestPASESession.cpp b/src/protocols/secure_channel/tests/TestPASESession.cpp index 1a1c08cf5ac28f..cb672c5297f642 100644 --- a/src/protocols/secure_channel/tests/TestPASESession.cpp +++ b/src/protocols/secure_channel/tests/TestPASESession.cpp @@ -302,6 +302,10 @@ void SecurePairingSerializeTest(nlTestSuite * inSuite, void * inContext) PacketHeader header; MessageAuthenticationCode mac; + header.SetSessionId(1); + NL_TEST_ASSERT(inSuite, header.IsEncrypted() == true); + NL_TEST_ASSERT(inSuite, header.MICTagLength() == 16); + // Let's try encrypting using original session, and decrypting using deserialized { CryptoContext session1; diff --git a/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp b/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp index ed1aaa10079ca4..68aa420e02c96c 100644 --- a/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp +++ b/src/protocols/secure_channel/tests/TestSessionIDAllocator.cpp @@ -28,11 +28,11 @@ void TestSessionIDAllocator_Allocate(nlTestSuite * inSuite, void * inContext) { SessionIDAllocator allocator; - NL_TEST_ASSERT(inSuite, allocator.Peek() == 0); + NL_TEST_ASSERT(inSuite, allocator.Peek() == 1); uint16_t id; - for (uint16_t i = 0; i < 16; i++) + for (uint16_t i = 1; i < 16; i++) { CHIP_ERROR err = allocator.Allocate(id); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); @@ -45,11 +45,11 @@ void TestSessionIDAllocator_Free(nlTestSuite * inSuite, void * inContext) { SessionIDAllocator allocator; - NL_TEST_ASSERT(inSuite, allocator.Peek() == 0); + NL_TEST_ASSERT(inSuite, allocator.Peek() == 1); uint16_t id; - for (uint16_t i = 0; i < 16; i++) + for (uint16_t i = 1; i < 17; i++) { CHIP_ERROR err = allocator.Allocate(id); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); @@ -59,15 +59,15 @@ void TestSessionIDAllocator_Free(nlTestSuite * inSuite, void * inContext) // Free an intermediate ID allocator.Free(10); - NL_TEST_ASSERT(inSuite, allocator.Peek() == 16); + NL_TEST_ASSERT(inSuite, allocator.Peek() == 17); // Free the last allocated ID - allocator.Free(15); - NL_TEST_ASSERT(inSuite, allocator.Peek() == 15); + allocator.Free(16); + NL_TEST_ASSERT(inSuite, allocator.Peek() == 16); // Free some random unallocated ID allocator.Free(100); - NL_TEST_ASSERT(inSuite, allocator.Peek() == 15); + NL_TEST_ASSERT(inSuite, allocator.Peek() == 16); } void TestSessionIDAllocator_Reserve(nlTestSuite * inSuite, void * inContext) @@ -76,7 +76,7 @@ void TestSessionIDAllocator_Reserve(nlTestSuite * inSuite, void * inContext) uint16_t id; - for (uint16_t i = 0; i < 16; i++) + for (uint16_t i = 1; i < 16; i++) { CHIP_ERROR err = allocator.Allocate(id); NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR); diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp index a49611545d8223..febb51da343147 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningClient.cpp @@ -67,8 +67,6 @@ CHIP_ERROR UserDirectedCommissioningClient::EncodeUDCMessage(System::PacketBuffe ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(payload)); - packetHeader.SetSessionType(Header::SessionType::kSessionTypeNone); - ReturnErrorOnFailure(packetHeader.EncodeBeforeData(payload)); return CHIP_NO_ERROR; diff --git a/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp b/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp index 4b2663414a0cbd..08c8e6d2a18783 100644 --- a/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp +++ b/src/protocols/user_directed_commissioning/UserDirectedCommissioningServer.cpp @@ -38,7 +38,7 @@ void UserDirectedCommissioningServer::OnMessageReceived(const Transport::PeerAdd ReturnOnFailure(packetHeader.DecodeAndConsume(msg)); - if (packetHeader.GetFlags().Has(Header::FlagValues::kEncryptedMessage)) + if (packetHeader.IsEncrypted()) { ChipLogError(AppServer, "UDC encryption flag set - ignoring"); return; diff --git a/src/protocols/user_directed_commissioning/tests/TestUdcMessages.cpp b/src/protocols/user_directed_commissioning/tests/TestUdcMessages.cpp index f80a1738328231..d0d05bb6fa74cb 100644 --- a/src/protocols/user_directed_commissioning/tests/TestUdcMessages.cpp +++ b/src/protocols/user_directed_commissioning/tests/TestUdcMessages.cpp @@ -180,7 +180,7 @@ void TestUserDirectedCommissioningClientMessage(nlTestSuite * inSuite, void * in // check the packet header fields PacketHeader packetHeader; packetHeader.DecodeAndConsume(payloadBuf); - NL_TEST_ASSERT(inSuite, !packetHeader.GetFlags().Has(Header::FlagValues::kEncryptedMessage)); + NL_TEST_ASSERT(inSuite, !packetHeader.IsEncrypted()); // check the payload header fields PayloadHeader payloadHeader; diff --git a/src/transport/CryptoContext.cpp b/src/transport/CryptoContext.cpp index fdd30bb7ce4641..7ac94a5351ad00 100644 --- a/src/transport/CryptoContext.cpp +++ b/src/transport/CryptoContext.cpp @@ -163,9 +163,8 @@ CHIP_ERROR CryptoContext::Encrypt(const uint8_t * input, size_t input_length, ui MessageAuthenticationCode & mac) const { - constexpr Header::SessionType sessionType = Header::SessionType::kAESCCMTagLen16; + const size_t taglen = header.MICTagLength(); - const size_t taglen = MessageAuthenticationCode::TagLenForSessionType(sessionType); VerifyOrDie(taglen <= kMaxTagLen); VerifyOrReturnError(mKeyAvailable, CHIP_ERROR_INVALID_USE_OF_SESSION_KEY); @@ -194,7 +193,7 @@ CHIP_ERROR CryptoContext::Encrypt(const uint8_t * input, size_t input_length, ui ReturnErrorOnFailure(AES_CCM_encrypt(input, input_length, AAD, aadLen, mKeys[usage], Crypto::kAES_CCM128_Key_Length, IV, sizeof(IV), output, tag, taglen)); - mac.SetTag(&header, sessionType, tag, taglen); + mac.SetTag(&header, tag, taglen); return CHIP_NO_ERROR; } @@ -202,7 +201,7 @@ CHIP_ERROR CryptoContext::Encrypt(const uint8_t * input, size_t input_length, ui CHIP_ERROR CryptoContext::Decrypt(const uint8_t * input, size_t input_length, uint8_t * output, const PacketHeader & header, const MessageAuthenticationCode & mac) const { - const size_t taglen = MessageAuthenticationCode::TagLenForSessionType(header.GetSessionType()); + const size_t taglen = header.MICTagLength(); const uint8_t * tag = mac.GetTag(); uint8_t IV[kAESCCMIVLen]; uint8_t AAD[kMaxAADLen]; diff --git a/src/transport/SecureMessageCodec.cpp b/src/transport/SecureMessageCodec.cpp index 666c5f3e69192b..aa54b6e1272c43 100644 --- a/src/transport/SecureMessageCodec.cpp +++ b/src/transport/SecureMessageCodec.cpp @@ -52,7 +52,8 @@ CHIP_ERROR Encrypt(Transport::SecureSession * state, PayloadHeader & payloadHead .SetMessageCounter(messageCounter) // .SetSessionId(state->GetPeerSessionId()); - packetHeader.GetFlags().Set(Header::FlagValues::kEncryptedMessage); + // TODO set Session Type (Unicast or Group) + // packetHeader.SetSessionType(Header::SessionType::kUnicastSession); ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(msgBuf)); @@ -90,7 +91,7 @@ CHIP_ERROR Decrypt(Transport::SecureSession * state, PayloadHeader & payloadHead msg->SetDataLength(len); #endif - uint16_t footerLen = MessageAuthenticationCode::TagLenForSessionType(packetHeader.GetSessionType()); + uint16_t footerLen = packetHeader.MICTagLength(); VerifyOrReturnError(footerLen <= len, CHIP_ERROR_INVALID_MESSAGE_LENGTH); uint16_t taglen = 0; diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index c1a86d2b2501ba..7dc32a45271c77 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -326,7 +326,7 @@ void SessionManager::OnMessageReceived(const PeerAddress & peerAddress, System:: ReturnOnFailure(packetHeader.DecodeAndConsume(msg)); - if (packetHeader.GetFlags().Has(Header::FlagValues::kEncryptedMessage)) + if (packetHeader.IsEncrypted()) { SecureMessageDispatch(packetHeader, peerAddress, std::move(msg)); } @@ -403,7 +403,7 @@ void SessionManager::SecureMessageDispatch(const PacketHeader & packetHeader, co ChipLogError(Inet, "Secure transport received message, but failed to decode/authenticate it, discarding")); // Verify message counter - if (packetHeader.GetFlags().Has(Header::FlagValues::kSecureSessionControlMessage)) + if (packetHeader.IsSecureSessionControlMsg()) { // TODO: control message counter is not implemented yet } @@ -462,7 +462,7 @@ void SessionManager::SecureMessageDispatch(const PacketHeader & packetHeader, co } } - if (packetHeader.GetFlags().Has(Header::FlagValues::kSecureSessionControlMessage)) + if (packetHeader.IsSecureSessionControlMsg()) { // TODO: control message counter is not implemented yet } diff --git a/src/transport/raw/MessageHeader.cpp b/src/transport/raw/MessageHeader.cpp index 43d503d5ee3f01..d9ac45a38a2e1a 100644 --- a/src/transport/raw/MessageHeader.cpp +++ b/src/transport/raw/MessageHeader.cpp @@ -80,15 +80,15 @@ constexpr size_t kVendorIdSizeBytes = 2; /// size of a serialized ack message counter inside a header constexpr size_t kAckMessageCounterSizeBytes = 4; -/// Mask to extract just the version part from a 16bit header prefix. -constexpr uint16_t kVersionMask = 0x00F0; -/// Shift to convert to/from a masked version 16bit value to a 4bit version. +/// Mask to extract just the version part from a 8bits header prefix. +constexpr uint8_t kVersionMask = 0xF0; + +constexpr uint8_t kMsgFlagsMask = 0x07; +/// Shift to convert to/from a masked version 8bit value to a 4bit version. constexpr int kVersionShift = 4; -/// Mask to extract just the encryption type part from a 16bit header prefix. -constexpr uint16_t kSessionTypeMask = 0x3000; -/// Shift to convert to/from a masked encryption type 16bit value to a 2bit encryption type. -constexpr int kSessionTypeShift = 12; +// Mask to extract sessionType +constexpr uint8_t kSessionTypeMask = 0x03; } // namespace @@ -134,18 +134,6 @@ uint16_t PayloadHeader::EncodeSizeBytes() const return static_cast(size); } -uint16_t MessageAuthenticationCode::TagLenForSessionType(Header::SessionType sessionType) -{ - switch (sessionType) - { - case Header::SessionType::kAESCCMTagLen16: - return 16; - - default: - return 0; - } -} - CHIP_ERROR PacketHeader::Decode(const uint8_t * const data, uint16_t size, uint16_t * decode_len) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -154,14 +142,20 @@ CHIP_ERROR PacketHeader::Decode(const uint8_t * const data, uint16_t size, uint1 // TODO: De-uint16-ify everything related to this library uint16_t octets_read; - uint16_t header; - err = reader.Read16(&header).StatusCode(); + uint8_t msgFlags; + err = reader.Read8(&msgFlags).StatusCode(); SuccessOrExit(err); - version = ((header & kVersionMask) >> kVersionShift); + version = ((msgFlags & kVersionMask) >> kVersionShift); VerifyOrExit(version == kMsgHeaderVersion, err = CHIP_ERROR_VERSION_MISMATCH); - mFlags.SetRaw(header); - mSessionType = static_cast((header & kSessionTypeMask) >> kSessionTypeShift); + mMsgFlags.SetRaw(msgFlags); + + uint8_t securityFlags; + err = reader.Read8(&securityFlags).StatusCode(); + SuccessOrExit(err); + mSecFlags.SetRaw(securityFlags); + + mSessionType = static_cast(securityFlags & kSessionTypeMask); err = reader.Read16(&mSessionId).StatusCode(); SuccessOrExit(err); @@ -169,7 +163,7 @@ CHIP_ERROR PacketHeader::Decode(const uint8_t * const data, uint16_t size, uint1 err = reader.Read32(&mMessageCounter).StatusCode(); SuccessOrExit(err); - if (mFlags.Has(Header::FlagValues::kSourceNodeIdPresent)) + if (mMsgFlags.Has(Header::MsgFlagValues::kSourceNodeIdPresent)) { uint64_t sourceNodeId; err = reader.Read64(&sourceNodeId).StatusCode(); @@ -180,22 +174,40 @@ CHIP_ERROR PacketHeader::Decode(const uint8_t * const data, uint16_t size, uint1 { mSourceNodeId.ClearValue(); } - if (mFlags.HasAll(Header::FlagValues::kDestinationNodeIdPresent, Header::FlagValues::kDestinationGroupIdPresent)) + + if (!IsSessionTypeValid()) + { + // Reserved. + err = CHIP_ERROR_INTERNAL; + SuccessOrExit(err); + } + + if (mMsgFlags.HasAll(Header::MsgFlagValues::kDestinationNodeIdPresent, Header::MsgFlagValues::kDestinationGroupIdPresent)) { // Reserved. err = CHIP_ERROR_INTERNAL; SuccessOrExit(err); } - else if (mFlags.Has(Header::FlagValues::kDestinationNodeIdPresent)) + else if (mMsgFlags.Has(Header::MsgFlagValues::kDestinationNodeIdPresent)) { + if (mSessionType != Header::SessionType::kUnicastSession) + { + err = CHIP_ERROR_INTERNAL; + SuccessOrExit(err); + } uint64_t destinationNodeId; err = reader.Read64(&destinationNodeId).StatusCode(); SuccessOrExit(err); mDestinationNodeId.SetValue(destinationNodeId); mDestinationGroupId.ClearValue(); } - else if (mFlags.Has(Header::FlagValues::kDestinationGroupIdPresent)) + else if (mMsgFlags.Has(Header::MsgFlagValues::kDestinationGroupIdPresent)) { + if (mSessionType != Header::SessionType::kGroupSession) + { + err = CHIP_ERROR_INTERNAL; + SuccessOrExit(err); + } uint16_t destinationGroupId; err = reader.Read16(&destinationGroupId).StatusCode(); SuccessOrExit(err); @@ -289,17 +301,22 @@ CHIP_ERROR PacketHeader::Encode(uint8_t * data, uint16_t size, uint16_t * encode { VerifyOrReturnError(size >= EncodeSizeBytes(), CHIP_ERROR_INVALID_ARGUMENT); VerifyOrReturnError(!(mDestinationNodeId.HasValue() && mDestinationGroupId.HasValue()), CHIP_ERROR_INTERNAL); + VerifyOrReturnError(encode_size != nullptr, CHIP_ERROR_INTERNAL); + VerifyOrReturnError(IsSessionTypeValid(), CHIP_ERROR_INTERNAL); + VerifyOrReturnError(!(IsGroupSession() && !mDestinationGroupId.HasValue()), CHIP_ERROR_INTERNAL); - Header::Flags encodeFlags = mFlags; - encodeFlags.Set(Header::FlagValues::kSourceNodeIdPresent, mSourceNodeId.HasValue()) - .Set(Header::FlagValues::kDestinationNodeIdPresent, mDestinationNodeId.HasValue()) - .Set(Header::FlagValues::kDestinationGroupIdPresent, mDestinationGroupId.HasValue()); + Header::MsgFlags messageFlags = mMsgFlags; + messageFlags.Set(Header::MsgFlagValues::kSourceNodeIdPresent, mSourceNodeId.HasValue()) + .Set(Header::MsgFlagValues::kDestinationNodeIdPresent, mDestinationNodeId.HasValue()) + .Set(Header::MsgFlagValues::kDestinationGroupIdPresent, mDestinationGroupId.HasValue()); - uint16_t header = (kMsgHeaderVersion << kVersionShift) | encodeFlags.Raw(); - header |= (static_cast(static_cast(mSessionType) << kSessionTypeShift) & kSessionTypeMask); + uint8_t msgFlags = (kMsgHeaderVersion << kVersionShift) | (messageFlags.Raw() & kMsgFlagsMask); + uint8_t secFlags = mSecFlags.Raw(); + secFlags |= static_cast(mSessionType); uint8_t * p = data; - LittleEndian::Write16(p, header); + Write8(p, msgFlags); + Write8(p, secFlags); LittleEndian::Write16(p, mSessionId); LittleEndian::Write32(p, mMessageCounter); if (mSourceNodeId.HasValue()) @@ -379,7 +396,7 @@ CHIP_ERROR PayloadHeader::EncodeBeforeData(const System::PacketBufferHandle & bu CHIP_ERROR MessageAuthenticationCode::Decode(const PacketHeader & packetHeader, const uint8_t * const data, uint16_t size, uint16_t * decode_len) { - const uint16_t taglen = TagLenForSessionType(packetHeader.GetSessionType()); + const uint16_t taglen = packetHeader.MICTagLength(); VerifyOrReturnError(taglen != 0, CHIP_ERROR_WRONG_ENCRYPTION_TYPE_FROM_PEER); VerifyOrReturnError(size >= taglen, CHIP_ERROR_INVALID_ARGUMENT); @@ -395,7 +412,7 @@ CHIP_ERROR MessageAuthenticationCode::Encode(const PacketHeader & packetHeader, uint16_t * encode_size) const { uint8_t * p = data; - const uint16_t taglen = TagLenForSessionType(packetHeader.GetSessionType()); + const uint16_t taglen = packetHeader.MICTagLength(); VerifyOrReturnError(taglen != 0, CHIP_ERROR_WRONG_ENCRYPTION_TYPE); VerifyOrReturnError(size >= taglen, CHIP_ERROR_INVALID_ARGUMENT); diff --git a/src/transport/raw/MessageHeader.h b/src/transport/raw/MessageHeader.h index 32f494467be9e6..d99f2c02c7357a 100644 --- a/src/transport/raw/MessageHeader.h +++ b/src/transport/raw/MessageHeader.h @@ -28,6 +28,7 @@ #include +#include #include #include #include @@ -43,7 +44,7 @@ static constexpr size_t kMaxTagLen = 16; static constexpr size_t kMaxAppMessageLen = 1200; -static constexpr size_t kMsgSessionIdUnsecured = 0x0000; +static constexpr uint16_t kMsgUnicastSessionIdUnsecured = 0x0000; typedef int PacketHeaderFlags; @@ -51,8 +52,8 @@ namespace Header { enum class SessionType { - kSessionTypeNone = 0, - kAESCCMTagLen16 = 1, + kUnicastSession = 0, + kGroupSession = 1, }; /** @@ -74,35 +75,45 @@ enum class ExFlagValues : uint8_t kExchangeFlag_VendorIdPresent = 0x10, }; -enum class FlagValues : uint16_t -{ - /// Header flag specifying that a destination node id is included in the header. - kDestinationNodeIdPresent = 0x0001, - - /// Header flag specifying that a destination group id is included in the header. - kDestinationGroupIdPresent = 0x0002, +// Message flags 8-bit value of the form +// | 4 bits | 1 | 1 | 2 bits | +// +---------+-------+--------| +// | version | - | S | DSIZ +// | | +// | +---------------- Destination Id field +// +-------------------- Source node Id present +enum class MsgFlagValues : uint8_t +{ /// Header flag specifying that a source node id is included in the header. - kSourceNodeIdPresent = 0x0004, - - /// Header flag specifying that it is a control message for secure session. - kSecureSessionControlMessage = 0x4000, + kSourceNodeIdPresent = 0b00000100, + kDestinationNodeIdPresent = 0b00000001, + kDestinationGroupIdPresent = 0b00000010, + kDSIZReserved = 0b00000011, - /// Header flag specifying that it is a encrypted message. - kEncryptedMessage = 0x0100, +}; +// Security flags 8-bit value of the form +// | 1 | 1 | 1 | 3 | 2 bits | +// +------------+---+--------| +// | P | C | MX | - | SessionType +// +// With : +// P = Privacy flag +// C = Control Msg flag +// MX = Message Extension + +enum class SecFlagValues : uint8_t +{ + kPrivacyFlag = 0b10000000, + kControlMsgFlag = 0b01000000, + kMsgExtensionFlag = 0b00100000, }; -using Flags = BitFlags; -using ExFlags = BitFlags; +using MsgFlags = BitFlags; +using SecFlags = BitFlags; -// Header is a 16-bit value of the form -// | 4 bit | 4 bit |8 bit Security Flags| -// +---------+-------+--------------------| -// | version | Flags | P | C |Reserved| E | -// | | +---Encrypted -// | +----------------Control message (TODO: Implement this) -// +--------------------Privacy enhancements (TODO: Implement this) +using ExFlags = BitFlags; } // namespace Header @@ -145,81 +156,103 @@ class PacketHeader const Optional & GetDestinationGroupId() const { return mDestinationGroupId; } uint16_t GetSessionId() const { return mSessionId; } + Header::SessionType GetSessionType() const { return mSessionType; } - Header::Flags & GetFlags() { return mFlags; } - const Header::Flags & GetFlags() const { return mFlags; } + bool IsGroupSession() const { return mSessionType == Header::SessionType::kGroupSession; } + bool IsUnicastSession() const { return mSessionType == Header::SessionType::kUnicastSession; } - /** Check if it's a secure session control message. */ - bool IsSecureSessionControlMsg() const { return mFlags.Has(Header::FlagValues::kSecureSessionControlMessage); } + bool IsSessionTypeValid() const + { + switch (mSessionType) + { + case Header::SessionType::kUnicastSession: + return true; + case Header::SessionType::kGroupSession: + return true; + default: + return false; + } + } - Header::SessionType GetSessionType() const { return mSessionType; } + bool IsEncrypted() const { return !((mSessionId == kMsgUnicastSessionIdUnsecured) && IsUnicastSession()); } + + uint16_t MICTagLength() const { return (IsEncrypted()) ? chip::Crypto::CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES : 0; } + + /** Check if it's a secure session control message. */ + bool IsSecureSessionControlMsg() const { return mSecFlags.Has(Header::SecFlagValues::kControlMsgFlag); } PacketHeader & SetSecureSessionControlMsg(bool value) { - mFlags.Set(Header::FlagValues::kSecureSessionControlMessage, value); + mSecFlags.Set(Header::SecFlagValues::kControlMsgFlag, value); return *this; } PacketHeader & SetSourceNodeId(NodeId id) { mSourceNodeId.SetValue(id); - mFlags.Set(Header::FlagValues::kSourceNodeIdPresent); + mMsgFlags.Set(Header::MsgFlagValues::kSourceNodeIdPresent); return *this; } PacketHeader & SetSourceNodeId(Optional id) { mSourceNodeId = id; - mFlags.Set(Header::FlagValues::kSourceNodeIdPresent, id.HasValue()); + mMsgFlags.Set(Header::MsgFlagValues::kSourceNodeIdPresent, id.HasValue()); return *this; } PacketHeader & ClearSourceNodeId() { mSourceNodeId.ClearValue(); - mFlags.Clear(Header::FlagValues::kSourceNodeIdPresent); + mMsgFlags.Clear(Header::MsgFlagValues::kSourceNodeIdPresent); return *this; } PacketHeader & SetDestinationNodeId(NodeId id) { mDestinationNodeId.SetValue(id); - mFlags.Set(Header::FlagValues::kDestinationNodeIdPresent); + mMsgFlags.Set(Header::MsgFlagValues::kDestinationNodeIdPresent); return *this; } PacketHeader & SetDestinationNodeId(Optional id) { mDestinationNodeId = id; - mFlags.Set(Header::FlagValues::kDestinationNodeIdPresent, id.HasValue()); + mMsgFlags.Set(Header::MsgFlagValues::kDestinationNodeIdPresent, id.HasValue()); return *this; } PacketHeader & ClearDestinationNodeId() { mDestinationNodeId.ClearValue(); - mFlags.Clear(Header::FlagValues::kDestinationNodeIdPresent); + mMsgFlags.Clear(Header::MsgFlagValues::kDestinationNodeIdPresent); return *this; } PacketHeader & SetDestinationGroupId(GroupId id) { mDestinationGroupId.SetValue(id); - mFlags.Set(Header::FlagValues::kDestinationGroupIdPresent); + mMsgFlags.Set(Header::MsgFlagValues::kDestinationGroupIdPresent); return *this; } PacketHeader & SetDestinationGroupId(Optional id) { mDestinationGroupId = id; - mFlags.Set(Header::FlagValues::kDestinationGroupIdPresent, id.HasValue()); + mMsgFlags.Set(Header::MsgFlagValues::kDestinationGroupIdPresent, id.HasValue()); return *this; } PacketHeader & ClearDestinationGroupId() { mDestinationGroupId.ClearValue(); - mFlags.Clear(Header::FlagValues::kDestinationGroupIdPresent); + mMsgFlags.Clear(Header::MsgFlagValues::kDestinationGroupIdPresent); + return *this; + } + + PacketHeader & SetSessionType(Header::SessionType type) + { + mSessionType = type; return *this; } @@ -235,9 +268,10 @@ class PacketHeader return *this; } - PacketHeader & SetSessionType(Header::SessionType type) + PacketHeader & SetUnsecured() { - mSessionType = type; + mSessionId = kMsgUnicastSessionIdUnsecured; + mSessionType = Header::SessionType::kUnicastSession; return *this; } @@ -322,8 +356,8 @@ class PacketHeader } private: - /// Represents the current encode/decode header version - static constexpr int kMsgHeaderVersion = 0; + /// Represents the current encode/decode header version (4 bits) + static constexpr uint8_t kMsgHeaderVersion = 0x00; /// Value expected to be incremented for each message sent. uint32_t mMessageCounter = 0; @@ -336,13 +370,13 @@ class PacketHeader Optional mDestinationGroupId; /// Session ID - uint16_t mSessionId = kMsgSessionIdUnsecured; + uint16_t mSessionId = kMsgUnicastSessionIdUnsecured; - /// Message flags read from the message. - Header::Flags mFlags; + Header::SessionType mSessionType = Header::SessionType::kUnicastSession; - /// Represents session type used for encrypting current packet - Header::SessionType mSessionType = Header::SessionType::kAESCCMTagLen16; + /// Flags read from the message. + Header::MsgFlags mMsgFlags; + Header::SecFlags mSecFlags; }; /** @@ -582,12 +616,11 @@ class MessageAuthenticationCode const uint8_t * GetTag() const { return &mTag[0]; } /** Set the message auth tag for this header. */ - MessageAuthenticationCode & SetTag(PacketHeader * header, Header::SessionType sessionType, uint8_t * tag, size_t len) + MessageAuthenticationCode & SetTag(PacketHeader * header, uint8_t * tag, size_t len) { - const size_t tagLen = TagLenForSessionType(sessionType); + const size_t tagLen = chip::Crypto::CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES; if (tagLen > 0 && tagLen <= kMaxTagLen && len == tagLen) { - header->SetSessionType(sessionType); memcpy(&mTag, tag, tagLen); } @@ -626,8 +659,6 @@ class MessageAuthenticationCode */ CHIP_ERROR Encode(const PacketHeader & packetHeader, uint8_t * data, uint16_t size, uint16_t * encode_size) const; - static uint16_t TagLenForSessionType(Header::SessionType sessionType); - private: /// Message authentication tag generated at encryption of the message. uint8_t mTag[kMaxTagLen]; diff --git a/src/transport/raw/tests/TestMessageHeader.cpp b/src/transport/raw/tests/TestMessageHeader.cpp index 71f916c477217b..da4aec5a15dc74 100644 --- a/src/transport/raw/tests/TestMessageHeader.cpp +++ b/src/transport/raw/tests/TestMessageHeader.cpp @@ -41,6 +41,9 @@ void TestPacketHeaderInitialState(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, !header.IsSecureSessionControlMsg()); NL_TEST_ASSERT(inSuite, header.GetMessageCounter() == 0); NL_TEST_ASSERT(inSuite, header.GetSessionId() == 0); + NL_TEST_ASSERT(inSuite, header.GetSessionType() == Header::SessionType::kUnicastSession); + NL_TEST_ASSERT(inSuite, header.IsSessionTypeValid()); + NL_TEST_ASSERT(inSuite, !header.IsEncrypted()); NL_TEST_ASSERT(inSuite, !header.GetDestinationNodeId().HasValue()); NL_TEST_ASSERT(inSuite, !header.GetDestinationGroupId().HasValue()); NL_TEST_ASSERT(inSuite, !header.GetSourceNodeId().HasValue()); @@ -127,6 +130,7 @@ void TestPacketHeaderEncodeDecode(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, header.GetMessageCounter() == 234); NL_TEST_ASSERT(inSuite, header.GetDestinationNodeId() == Optional::Value(88ull)); NL_TEST_ASSERT(inSuite, header.GetSourceNodeId() == Optional::Value(77ull)); + NL_TEST_ASSERT(inSuite, header.IsEncrypted()); NL_TEST_ASSERT(inSuite, header.GetSessionId() == 2); header.SetMessageCounter(234).SetSourceNodeId(77).SetDestinationNodeId(88); @@ -144,6 +148,7 @@ void TestPacketHeaderEncodeDecode(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, header.Encode(buffer, &encodeLen) == CHIP_ERROR_INTERNAL); header.ClearDestinationNodeId(); + header.SetSessionType(Header::SessionType::kGroupSession); NL_TEST_ASSERT(inSuite, header.Encode(buffer, &encodeLen) == CHIP_NO_ERROR); // change it to verify decoding @@ -254,6 +259,7 @@ void TestPacketHeaderEncodeDecodeBounds(nlTestSuite * inSuite, void * inContext) // Now test encoding/decoding with a source node id and destination group id present. header.ClearDestinationNodeId(); header.SetDestinationGroupId(25); + header.SetSessionType(Header::SessionType::kGroupSession); for (uint16_t shortLen = minLen; shortLen < minLen + 10; shortLen++) { NL_TEST_ASSERT(inSuite, header.Encode(buffer, shortLen, &unusedLen) != CHIP_NO_ERROR); diff --git a/src/transport/tests/TestSecureSession.cpp b/src/transport/tests/TestSecureSession.cpp index a839efef9be038..f3fd41e5336af7 100644 --- a/src/transport/tests/TestSecureSession.cpp +++ b/src/transport/tests/TestSecureSession.cpp @@ -79,6 +79,10 @@ void SecureChannelEncryptTest(nlTestSuite * inSuite, void * inContext) PacketHeader packetHeader; MessageAuthenticationCode mac; + packetHeader.SetSessionId(1); + NL_TEST_ASSERT(inSuite, packetHeader.IsEncrypted() == true); + NL_TEST_ASSERT(inSuite, packetHeader.MICTagLength() == 16); + P256Keypair keypair; NL_TEST_ASSERT(inSuite, keypair.Initialize() == CHIP_NO_ERROR); @@ -114,6 +118,10 @@ void SecureChannelDecryptTest(nlTestSuite * inSuite, void * inContext) PacketHeader packetHeader; MessageAuthenticationCode mac; + packetHeader.SetSessionId(1); + NL_TEST_ASSERT(inSuite, packetHeader.IsEncrypted() == true); + NL_TEST_ASSERT(inSuite, packetHeader.MICTagLength() == 16); + const char * salt = "Test Salt"; P256Keypair keypair;