Skip to content

Commit

Permalink
Clean-up symettry in SessionManager (#13591)
Browse files Browse the repository at this point in the history
- SessionManager has symetrical processing steps for
  group and unsecured message, but some side-effects of
  unicast secure sessions are delegated to SecureMessageCodec
  that should actually just be encrypt/decrypt logic.
- This is needed to assist in having a single point of tracing
  for incoming/outgoing messages (upcoming PR). Current organization
  requires putting that logic in several modules due to the
  side-effects occuring at the wrong layer.

Done by this PR:
- Hoists those side effects (e.g. updating counter) up from
  SecureMessageCodec where they do not belong, into SessionManager.
- Fix documentation and argument names of SecureMessageCodec that
  had rotted over many refactors and terminological changes.
- Add a bit of processing symmetry in SessionManager

Testing done: unit tests pass, integration tests pass
  • Loading branch information
tcarmelveilleux authored and pull[bot] committed Oct 23, 2023
1 parent f7e6c83 commit 2872952
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 39 deletions.
21 changes: 5 additions & 16 deletions src/transport/SecureMessageCodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,46 +38,35 @@ using System::PacketBufferHandle;

namespace SecureMessageCodec {

CHIP_ERROR Encrypt(Transport::SecureSession * state, PayloadHeader & payloadHeader, PacketHeader & packetHeader,
System::PacketBufferHandle & msgBuf, MessageCounter & counter)
CHIP_ERROR Encrypt(Transport::SecureSession * session, PayloadHeader & payloadHeader, PacketHeader & packetHeader,
System::PacketBufferHandle & msgBuf)
{
VerifyOrReturnError(!msgBuf.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError(!msgBuf->HasChainedBuffer(), CHIP_ERROR_INVALID_MESSAGE_LENGTH);
VerifyOrReturnError(msgBuf->TotalLength() <= kMaxAppMessageLen, CHIP_ERROR_MESSAGE_TOO_LONG);

uint32_t messageCounter = counter.Value();

static_assert(std::is_same<decltype(msgBuf->TotalLength()), uint16_t>::value,
"Addition to generate payloadLength might overflow");

packetHeader
.SetMessageCounter(messageCounter) //
.SetSessionId(state->GetPeerSessionId());

// TODO set Session Type (Unicast or Group)
// packetHeader.SetSessionType(Header::SessionType::kUnicastSession);

ReturnErrorOnFailure(payloadHeader.EncodeBeforeData(msgBuf));

uint8_t * data = msgBuf->Start();
uint16_t totalLen = msgBuf->TotalLength();

CHIP_TRACE_MESSAGE(payloadHeader, packetHeader, data, totalLen);

MessageAuthenticationCode mac;
ReturnErrorOnFailure(state->EncryptBeforeSend(data, totalLen, data, packetHeader, mac));
ReturnErrorOnFailure(session->EncryptBeforeSend(data, totalLen, data, packetHeader, mac));

uint16_t taglen = 0;
ReturnErrorOnFailure(mac.Encode(packetHeader, &data[totalLen], msgBuf->AvailableDataLength(), &taglen));

VerifyOrReturnError(CanCastTo<uint16_t>(totalLen + taglen), CHIP_ERROR_INTERNAL);
msgBuf->SetDataLength(static_cast<uint16_t>(totalLen + taglen));

ReturnErrorOnFailure(counter.Advance());
return CHIP_NO_ERROR;
}

CHIP_ERROR Decrypt(Transport::SecureSession * state, PayloadHeader & payloadHeader, const PacketHeader & packetHeader,
CHIP_ERROR Decrypt(Transport::SecureSession * session, PayloadHeader & payloadHeader, const PacketHeader & packetHeader,
System::PacketBufferHandle & msg)
{
ReturnErrorCodeIf(msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
Expand Down Expand Up @@ -107,7 +96,7 @@ CHIP_ERROR Decrypt(Transport::SecureSession * state, PayloadHeader & payloadHead
msg->SetDataLength(len);

uint8_t * plainText = msg->Start();
ReturnErrorOnFailure(state->DecryptOnReceive(data, len, plainText, packetHeader, mac));
ReturnErrorOnFailure(session->DecryptOnReceive(data, len, plainText, packetHeader, mac));

ReturnErrorOnFailure(payloadHeader.DecodeAndConsume(msg));
return CHIP_NO_ERROR;
Expand Down
32 changes: 16 additions & 16 deletions src/transport/SecureMessageCodec.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,38 +36,38 @@ namespace SecureMessageCodec {
/**
* @brief
* Attach payload header to the message and encrypt the message buffer using
* key from the connection state.
* key from the secure session.
*
* @param state The connection state with peer node
* @param session The secure session context with the peer node
* @param payloadHeader Reference to the payload header that should be inserted in
* the message
* @param packetHeader Reference to the packet header that contains unencrypted
* portion of the message header
* @param msgBuf The message buffer that contains the unencrypted message. If
* the operation is successuful, this buffer will contain the
* encrypted message.
* @param counter The local counter object to be used
* @ return CHIP_ERROR The result of the encode operation
* the operation is successful, this buffer will be mutated to contain
* the encrypted message.
* @return A CHIP_ERROR value consistent with the result of the encryption operation
*/
CHIP_ERROR Encrypt(Transport::SecureSession * state, PayloadHeader & payloadHeader, PacketHeader & packetHeader,
System::PacketBufferHandle & msgBuf, MessageCounter & counter);
CHIP_ERROR Encrypt(Transport::SecureSession * session, PayloadHeader & payloadHeader, PacketHeader & packetHeader,
System::PacketBufferHandle & msgBuf);

/**
* @brief
* Decrypt the message, perform message integrity check, and decode the payload header.
* Decrypt the message, perform message integrity check, and decode the payload header,
* consuming the header from the packet in doing so.
*
* @param state The connection state with peer node
* @param payloadHeader Reference to the payload header that should be inserted in
* the message
* @param session The secure session context with the peer node
* @param payloadHeader Reference to the payload header that will be recovered from the message
* @param packetHeader Reference to the packet header that contains unencrypted
* portion of the message header
* @param msgBuf The message buffer that contains the encrypted message. If
* the operation is successuful, this buffer will contain the
* unencrypted message.
* @ return CHIP_ERROR The result of the decode operation
* the operation is successful, this buffer will be mutated to contain
* the decrypted message.
* @return A CHIP_ERROR value consistent with the result of the decryption operation
*/
CHIP_ERROR Decrypt(Transport::SecureSession * state, PayloadHeader & payloadHeader, const PacketHeader & packetHeader,
CHIP_ERROR Decrypt(Transport::SecureSession * session, PayloadHeader & payloadHeader, const PacketHeader & packetHeader,
System::PacketBufferHandle & msgBuf);

} // namespace SecureMessageCodec

} // namespace chip
23 changes: 16 additions & 7 deletions src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,16 @@ CHIP_ERROR SessionManager::PrepareMessage(const SessionHandle & sessionHandle, P
{
return CHIP_ERROR_NOT_CONNECTED;
}

MessageCounter & counter = GetSendCounterForPacket(payloadHeader, *session);
ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message, counter));
uint32_t messageCounter = counter.Value();
packetHeader
.SetMessageCounter(messageCounter) //
.SetSessionId(session->GetPeerSessionId()) //
.SetSessionType(Header::SessionType::kUnicastSession);

ReturnErrorOnFailure(SecureMessageCodec::Encrypt(session, payloadHeader, packetHeader, message));
ReturnErrorOnFailure(counter.Advance());

#if CHIP_PROGRESS_LOGGING
destination = session->GetPeerNodeId();
Expand Down Expand Up @@ -420,19 +428,19 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr
}

const SessionHandle & session = optionalSession.Value();
Transport::UnauthenticatedSession * unsecuredSession = session->AsUnauthenticatedSession();
SessionMessageDelegate::DuplicateMessage isDuplicate = SessionMessageDelegate::DuplicateMessage::No;

// Verify message counter
CHIP_ERROR err =
session->AsUnauthenticatedSession()->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageCounter());
CHIP_ERROR err = unsecuredSession->GetPeerMessageCounter().VerifyOrTrustFirst(packetHeader.GetMessageCounter());
if (err == CHIP_ERROR_DUPLICATE_MESSAGE_RECEIVED)
{
isDuplicate = SessionMessageDelegate::DuplicateMessage::Yes;
err = CHIP_NO_ERROR;
}
VerifyOrDie(err == CHIP_NO_ERROR);

session->AsUnauthenticatedSession()->MarkActive();
unsecuredSession->MarkActive();

PayloadHeader payloadHeader;
ReturnOnFailure(payloadHeader.DecodeAndConsume(msg));
Expand All @@ -445,11 +453,11 @@ void SessionManager::MessageDispatch(const PacketHeader & packetHeader, const Tr
packetHeader.GetMessageCounter(), ChipLogValueExchangeIdFromReceivedHeader(payloadHeader));
}

session->AsUnauthenticatedSession()->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter());
unsecuredSession->GetPeerMessageCounter().Commit(packetHeader.GetMessageCounter());

if (mCB != nullptr)
{
mCB->OnMessageReceived(packetHeader, payloadHeader, optionalSession.Value(), peerAddress, isDuplicate, std::move(msg));
mCB->OnMessageReceived(packetHeader, payloadHeader, session, peerAddress, isDuplicate, std::move(msg));
}
}

Expand Down Expand Up @@ -599,10 +607,11 @@ void SessionManager::SecureGroupMessageDispatch(const PacketHeader & packetHeade
{
Optional<SessionHandle> session = CreateGroupSession(packetHeader.GetDestinationGroupId().Value());
VerifyOrReturn(session.HasValue(), ChipLogError(Inet, "Error when creating group session handle."));
Transport::GroupSession * groupSession = session.Value()->AsGroupSession();

mCB->OnMessageReceived(packetHeader, payloadHeader, session.Value(), peerAddress, isDuplicate, std::move(msg));

RemoveGroupSession(session.Value()->AsGroupSession());
RemoveGroupSession(groupSession);
}
}

Expand Down

0 comments on commit 2872952

Please sign in to comment.