Skip to content

Commit

Permalink
Use SecureSessionHandle to indicate if the peer's group key message c…
Browse files Browse the repository at this point in the history
…ounter is not synchronized. (#5368)
  • Loading branch information
yufengwangca authored Mar 16, 2021
1 parent 83a8341 commit a0dd7af
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/messaging/ExchangeMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ void ExchangeManager::OnMessageReceived(const PacketHeader & packetHeader, const
UnsolicitedMessageHandler * matchingUMH = nullptr;
bool sendAckAndCloseExchange = false;

if (!IsMsgCounterSyncMessage(payloadHeader) && packetHeader.IsPeerGroupMsgIdNotSynchronized())
if (!IsMsgCounterSyncMessage(payloadHeader) && session.IsPeerGroupMsgIdNotSynchronized())
{
Transport::PeerConnectionState * state = mSessionMgr->GetPeerConnectionState(session);
VerifyOrReturn(state != nullptr);
Expand Down
10 changes: 7 additions & 3 deletions src/transport/SecureSessionMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ void SecureSessionMgr::OnMessageReceived(const PacketHeader & packetHeader, cons
PacketBufferHandle origMsg;
PayloadHeader payloadHeader;

bool peerGroupMsgIdNotSynchronized = false;
Transport::AdminPairingInfo * admin = nullptr;

VerifyOrExit(!msg.IsNull(), ChipLogError(Inet, "Secure transport received NULL packet, discarding"));
Expand Down Expand Up @@ -375,14 +376,17 @@ void SecureSessionMgr::OnMessageReceived(const PacketHeader & packetHeader, cons
// For all group messages, Set flag if peer group key message counter is not synchronized.
if (ChipKeyId::IsAppGroupKey(packetHeader.GetEncryptionKeyID()))
{
const_cast<PacketHeader &>(packetHeader).SetPeerGroupMsgIdNotSynchronized(true);
peerGroupMsgIdNotSynchronized = true;
}
}

if (mCB != nullptr)
{
mCB->OnMessageReceived(packetHeader, payloadHeader, { state->GetPeerNodeId(), state->GetPeerKeyID(), state->GetAdminId() },
std::move(msg), this);
SecureSessionHandle session(state->GetPeerNodeId(), state->GetPeerKeyID(), state->GetAdminId());

session.SetPeerGroupMsgIdNotSynchronized(peerGroupMsgIdNotSynchronized);

mCB->OnMessageReceived(packetHeader, payloadHeader, session, std::move(msg), this);
}

exit:
Expand Down
5 changes: 5 additions & 0 deletions src/transport/SecureSessionMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class SecureSessionHandle
Transport::AdminId GetAdminId() const { return mAdmin; }
void SetAdminId(Transport::AdminId adminId) { mAdmin = adminId; }

bool IsPeerGroupMsgIdNotSynchronized() const { return mPeerGroupMsgIdNotSynchronized; }
void SetPeerGroupMsgIdNotSynchronized(bool value) { mPeerGroupMsgIdNotSynchronized = value; }

bool operator==(const SecureSessionHandle & that) const
{
return mPeerNodeId == that.mPeerNodeId && mPeerKeyId == that.mPeerKeyId && mAdmin == that.mAdmin;
Expand All @@ -74,6 +77,8 @@ class SecureSessionHandle
// to identify an approach that'll allow looking up the corresponding information for
// such sessions.
Transport::AdminId mAdmin;

bool mPeerGroupMsgIdNotSynchronized = false;
};

/**
Expand Down
26 changes: 2 additions & 24 deletions src/transport/raw/MessageHeader.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ enum class ExFlagValues : uint8_t
kExchangeFlag_VendorIdPresent = 0x10,
};

enum class InternalFlagValues : uint8_t
{
// Header flag indicates that the peer's group key message counter is not synchronized.
kPeerGroupMsgIdNotSynchronized = 0x01,
};

enum class FlagValues : uint16_t
{
/// Header flag specifying that a destination node id is included in the header.
Expand All @@ -95,9 +89,8 @@ enum class FlagValues : uint16_t

};

using Flags = BitFlags<FlagValues>;
using ExFlags = BitFlags<ExFlagValues>;
using InternalFlags = BitFlags<InternalFlagValues>;
using Flags = BitFlags<FlagValues>;
using ExFlags = BitFlags<ExFlagValues>;

// Header is a 16-bit value of the form
// | 4 bit | 4 bit |8 bit Security Flags|
Expand Down Expand Up @@ -149,12 +142,6 @@ class PacketHeader
/** Check if it's a secure session control message. */
bool IsSecureSessionControlMsg() const { return mFlags.Has(Header::FlagValues::kSecureSessionControlMessage); }

/** Check if the peer's group key message counter is not synchronized. */
bool IsPeerGroupMsgIdNotSynchronized() const
{
return mInternalFlags.Has(Header::InternalFlagValues::kPeerGroupMsgIdNotSynchronized);
}

Header::EncryptionType GetEncryptionType() const { return mEncryptionType; }

PacketHeader & SetSecureSessionControlMsg(bool value)
Expand All @@ -163,12 +150,6 @@ class PacketHeader
return *this;
}

PacketHeader & SetPeerGroupMsgIdNotSynchronized(bool value)
{
mInternalFlags.Set(Header::InternalFlagValues::kPeerGroupMsgIdNotSynchronized, value);
return *this;
}

PacketHeader & SetSourceNodeId(NodeId id)
{
mSourceNodeId.SetValue(id);
Expand Down Expand Up @@ -328,9 +309,6 @@ class PacketHeader
/// Message flags read from the message.
Header::Flags mFlags;

/// Message flags not encoded into the packet sent over wire.
Header::InternalFlags mInternalFlags;

/// Represents encryption type used for encrypting current packet
Header::EncryptionType mEncryptionType = Header::EncryptionType::kAESCCMTagLen16;
};
Expand Down

0 comments on commit a0dd7af

Please sign in to comment.