diff --git a/src/transport/SessionManager.cpp b/src/transport/SessionManager.cpp index 7f8838ebd68b71..82ae6587bbb816 100644 --- a/src/transport/SessionManager.cpp +++ b/src/transport/SessionManager.cpp @@ -401,49 +401,17 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea VerifyOrExit(CHIP_NO_ERROR == SecureMessageCodec::Decrypt(session, payloadHeader, packetHeader, msg), ChipLogError(Inet, "Secure transport received message, but failed to decode/authenticate it, discarding")); - // Verify message counter - if (packetHeader.IsSecureSessionControlMsg()) + err = session->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageCounter()); + if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED) { - // TODO: control message counter is not implemented yet + isDuplicate = SessionManagerDelegate::DuplicateMessage::Yes; + err = CHIP_NO_ERROR; } - else + if (err != CHIP_NO_ERROR) { - if (!session->GetSessionMessageCounter().GetPeerMessageCounter().IsSynchronized()) - { - // Queue and start message sync procedure - err = mMessageCounterManager->QueueReceivedMessageAndStartSync( - packetHeader, - SessionHandle(session->GetPeerNodeId(), session->GetLocalSessionId(), session->GetPeerSessionId(), - session->GetFabricIndex()), - session, peerAddress, std::move(msg)); - - if (err != CHIP_NO_ERROR) - { - ChipLogError(Inet, - "Message counter synchronization for received message, failed to " - "QueueReceivedMessageAndStartSync, err = %" CHIP_ERROR_FORMAT, - err.Format()); - } - else - { - ChipLogDetail(Inet, "Received message have been queued due to peer counter is not synced"); - } - - return; - } - - err = session->GetSessionMessageCounter().GetPeerMessageCounter().Verify(packetHeader.GetMessageCounter()); - if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED) - { - isDuplicate = SessionManagerDelegate::DuplicateMessage::Yes; - err = CHIP_NO_ERROR; - } - if (err != CHIP_NO_ERROR) - { - ChipLogError(Inet, "Message counter verify failed, err = %" CHIP_ERROR_FORMAT, err.Format()); - } - SuccessOrExit(err); + ChipLogError(Inet, "Message counter verify failed, err = %" CHIP_ERROR_FORMAT, err.Format()); } + SuccessOrExit(err); mSecureSessions.MarkSessionActive(session); @@ -461,14 +429,7 @@ void SessionManager::SecureUnicastMessageDispatch(const PacketHeader & packetHea } } - if (packetHeader.IsSecureSessionControlMsg()) - { - // TODO: control message counter is not implemented yet - } - else - { - session->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); - } + session->GetSessionMessageCounter().GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter()); // TODO: once mDNS address resolution is available reconsider if this is required // This updates the peer address once a packet is received from a new address @@ -502,6 +463,24 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade VerifyOrExit(!msg.IsNull(), ChipLogError(Inet, "Secure transport received NULL packet, discarding")); + // MCSP check + if (packetHeader.IsSecureSessionControlMsg()) + { + if (packetHeader.GetDestinationNodeId().HasValue() && packetHeader.HasPrivacyFlag()) + { + // TODO + // if (packetHeader.GetDestinationNodeId().Value() == ThisDeviceNodeID) + // { + // MCSP processing.. + // } + } + else + { + ChipLogError(Inet, "Invalid condition found in packet header"); + ExitNow(err = CHIP_ERROR_INCORRECT_STATE); + } + } + // TODO: Handle Group message counter here spec 4.7.3 // spec 4.5.1.2 for msg counter @@ -523,21 +502,14 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade } } - if (packetHeader.IsSecureSessionControlMsg()) - { - // TODO: control message counter is not implemented yet - } - else - { - // TODO: Commit Group Message Counter - } + // TODO: Commit Group Message Counter if (mCB != nullptr) { // TODO: Update Session Handle for Group messages. // SessionHandle session(session->GetPeerNodeId(), session->GetLocalSessionId(), session->GetPeerSessionId(), // session->GetFabricIndex()); - // mCB->OnMessageReceived(packetHeader, payloadHeader, nullptr, peerAddress, isDuplicate, std::move(msg)); + // mCB->OnMessageReceived(packetHeader, payloadHeader, session, peerAddress, isDuplicate, std::move(msg)); } exit: diff --git a/src/transport/raw/MessageHeader.cpp b/src/transport/raw/MessageHeader.cpp index a2b55235dc0012..b31b27da51d6d3 100644 --- a/src/transport/raw/MessageHeader.cpp +++ b/src/transport/raw/MessageHeader.cpp @@ -177,10 +177,9 @@ CHIP_ERROR PacketHeader::Decode(const uint8_t * const data, uint16_t size, uint1 } else if (mMsgFlags.Has(Header::MsgFlagValues::kDestinationNodeIdPresent)) { - if (mSessionType != Header::SessionType::kUnicastSession) - { - SuccessOrExit(err = CHIP_ERROR_INTERNAL); - } + // No need to check if session is Unicast because for MCSP + // a destination node ID is present with a group session ID. + // Spec 4.9.2.4 uint64_t destinationNodeId; SuccessOrExit(err = reader.Read64(&destinationNodeId).StatusCode()); mDestinationNodeId.SetValue(destinationNodeId); @@ -282,7 +281,6 @@ CHIP_ERROR PacketHeader::Encode(uint8_t * data, uint16_t size, uint16_t * encode VerifyOrReturnError(!(mDestinationNodeId.HasValue() && mDestinationGroupId.HasValue()), CHIP_ERROR_INTERNAL); VerifyOrReturnError(encode_size != nullptr, CHIP_ERROR_INTERNAL); VerifyOrReturnError(IsSessionTypeValid(), CHIP_ERROR_INTERNAL); - VerifyOrReturnError(!(IsGroupSession() && !mDestinationGroupId.HasValue()), CHIP_ERROR_INTERNAL); Header::MsgFlags messageFlags = mMsgFlags; messageFlags.Set(Header::MsgFlagValues::kSourceNodeIdPresent, mSourceNodeId.HasValue()) @@ -306,8 +304,7 @@ CHIP_ERROR PacketHeader::Encode(uint8_t * data, uint16_t size, uint16_t * encode { LittleEndian::Write64(p, mDestinationNodeId.Value()); } - - if (mDestinationGroupId.HasValue()) + else if (mDestinationGroupId.HasValue()) { LittleEndian::Write16(p, mDestinationGroupId.Value()); } diff --git a/src/transport/raw/MessageHeader.h b/src/transport/raw/MessageHeader.h index a3149204a62bea..230a8180e0c063 100644 --- a/src/transport/raw/MessageHeader.h +++ b/src/transport/raw/MessageHeader.h @@ -167,6 +167,11 @@ class PacketHeader uint8_t GetSecurityFlags() const { return mSecFlags.Raw(); } + bool HasPrivacyFlag() const { return mSecFlags.Has(Header::SecFlagValues::kPrivacyFlag); } + + void SetFlags(Header::SecFlagValues value) { mSecFlags.Set(value); } + void SetFlags(Header::MsgFlagValues value) { mMsgFlags.Set(value); } + void SetMessageFlags(uint8_t flags) { mMsgFlags.SetRaw(flags); } void SetSecurityFlags(uint8_t securityFlags) diff --git a/src/transport/raw/tests/TestMessageHeader.cpp b/src/transport/raw/tests/TestMessageHeader.cpp index da4aec5a15dc74..74ed1bd22068d5 100644 --- a/src/transport/raw/tests/TestMessageHeader.cpp +++ b/src/transport/raw/tests/TestMessageHeader.cpp @@ -157,6 +157,17 @@ void TestPacketHeaderEncodeDecode(nlTestSuite * inSuite, void * inContext) NL_TEST_ASSERT(inSuite, header.GetMessageCounter() == 234); NL_TEST_ASSERT(inSuite, header.GetDestinationGroupId() == Optional::Value((uint16_t) 45)); NL_TEST_ASSERT(inSuite, header.GetSourceNodeId() == Optional::Value(77ull)); + + // Verify MCSP state + header.ClearDestinationGroupId().SetDestinationNodeId(42).SetFlags(Header::SecFlagValues::kPrivacyFlag); + NL_TEST_ASSERT(inSuite, header.Encode(buffer, &encodeLen) == CHIP_NO_ERROR); + + // change it to verify decoding + header.SetMessageCounter(222).SetSourceNodeId(1).SetDestinationGroupId(2); + NL_TEST_ASSERT(inSuite, header.Decode(buffer, &decodeLen) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, header.GetDestinationNodeId() == Optional::Value(42ull)); + NL_TEST_ASSERT(inSuite, !header.GetDestinationGroupId().HasValue()); + NL_TEST_ASSERT(inSuite, header.HasPrivacyFlag()); } void TestPayloadHeaderEncodeDecode(nlTestSuite * inSuite, void * inContext)