Skip to content

Commit

Permalink
Move PairingSession from transport/ to protocols/secure_channel/ (#17889
Browse files Browse the repository at this point in the history
)

* Move PairingSession from transport/ to protocols/secure_channel/

* Resolve comments
  • Loading branch information
kghost authored and pull[bot] committed Aug 23, 2023
1 parent 89812f1 commit 43ead81
Show file tree
Hide file tree
Showing 13 changed files with 62 additions and 127 deletions.
2 changes: 2 additions & 0 deletions src/protocols/secure_channel/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ static_library("secure_channel") {
"DefaultSessionResumptionStorage.h",
"PASESession.cpp",
"PASESession.h",
"PairingSession.cpp",
"PairingSession.h",
"RendezvousParameters.h",
"SessionEstablishmentDelegate.h",
"SessionEstablishmentExchangeDispatch.cpp",
Expand Down
48 changes: 1 addition & 47 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@
#include <lib/support/TypeTraits.h>
#include <protocols/Protocols.h>
#include <protocols/secure_channel/CASEDestinationId.h>
#include <protocols/secure_channel/PairingSession.h>
#include <protocols/secure_channel/SessionResumptionStorage.h>
#include <protocols/secure_channel/StatusReport.h>
#include <system/TLVPacketBufferBackingStore.h>
#include <trace/trace.h>
#include <transport/PairingSession.h>
#include <transport/SessionManager.h>
#if CHIP_CRYPTO_HSM
#include <crypto/hsm/CHIPCryptoPALHsm.h>
Expand Down Expand Up @@ -127,24 +127,6 @@ CASESession::~CASESession()
Clear();
}

void CASESession::Finish()
{
Transport::PeerAddress address = mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress();

// Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that.
DiscardExchange();

CHIP_ERROR err = ActivateSecureSession(address);
if (err == CHIP_NO_ERROR)
{
mDelegate->OnSessionEstablished(mSecureSessionHolder.Get());
}
else
{
mDelegate->OnSessionEstablishmentError(err);
}
}

void CASESession::Clear()
{
// This function zeroes out and resets the memory used by the object.
Expand All @@ -155,39 +137,11 @@ void CASESession::Clear()
mState = State::kInitialized;
Crypto::ClearSecretData(mIPK);

AbortExchange();

mLocalNodeId = kUndefinedNodeId;
mPeerNodeId = kUndefinedNodeId;
mFabricInfo = nullptr;
}

void CASESession::AbortExchange()
{
if (mExchangeCtxt != nullptr)
{
// The only time we reach this is if we are getting destroyed in the
// middle of our handshake. In that case, there is no point trying to
// do MRP resends of the last message we sent, so abort the exchange
// instead of just closing it.
mExchangeCtxt->Abort();
mExchangeCtxt = nullptr;
}
}

void CASESession::DiscardExchange()
{
if (mExchangeCtxt != nullptr)
{
// Make sure the exchange doesn't try to notify us when it closes,
// since we might be dead by then.
mExchangeCtxt->SetDelegate(nullptr);
// Null out mExchangeCtxt so that Clear() doesn't try closing it. The
// exchange will handle that.
mExchangeCtxt = nullptr;
}
}

CHIP_ERROR CASESession::Init(SessionManager & sessionManager, SessionEstablishmentDelegate * delegate)
{
VerifyOrReturnError(delegate != nullptr, CHIP_ERROR_INVALID_ARGUMENT);
Expand Down
17 changes: 1 addition & 16 deletions src/protocols/secure_channel/CASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@
#include <messaging/ExchangeDelegate.h>
#include <protocols/secure_channel/CASEDestinationId.h>
#include <protocols/secure_channel/Constants.h>
#include <protocols/secure_channel/SessionEstablishmentDelegate.h>
#include <protocols/secure_channel/PairingSession.h>
#include <protocols/secure_channel/SessionEstablishmentExchangeDispatch.h>
#include <protocols/secure_channel/SessionResumptionStorage.h>
#include <system/SystemPacketBuffer.h>
#include <transport/CryptoContext.h>
#include <transport/PairingSession.h>
#include <transport/raw/MessageHeader.h>
#include <transport/raw/PeerAddress.h>

Expand Down Expand Up @@ -216,26 +215,13 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler,
void OnSuccessStatusReport() override;
CHIP_ERROR OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) override;

// TODO: pull up Finish to PairingSession class
void Finish();

void AbortExchange();

/**
* Clear our reference to our exchange context pointer so that it can close
* itself at some later time.
*/
void DiscardExchange();

CHIP_ERROR GetHardcodedTime();

CHIP_ERROR SetEffectiveTime();

CHIP_ERROR ValidateReceivedMessage(Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader,
const System::PacketBufferHandle & msg);

SessionEstablishmentDelegate * mDelegate = nullptr;

Crypto::Hash_SHA256_stream mCommissioningHash;
Crypto::P256PublicKey mRemotePubKey;
#ifdef ENABLE_HSM_CASE_EPHEMERAL_KEY
Expand All @@ -250,7 +236,6 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler,
uint8_t mMessageDigest[Crypto::kSHA256_Hash_Length];
uint8_t mIPK[kIPKSize];

Messaging::ExchangeContext * mExchangeCtxt = nullptr;
SessionResumptionStorage * mSessionResumptionStorage = nullptr;

FabricTable * mFabricsTable = nullptr;
Expand Down
39 changes: 1 addition & 38 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,7 @@ PASESession::~PASESession()
void PASESession::Finish()
{
mPairingComplete = true;

Transport::PeerAddress address = mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress();

// Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that.
DiscardExchange();

CHIP_ERROR err = ActivateSecureSession(address);
if (err == CHIP_NO_ERROR)
{
mDelegate->OnSessionEstablished(mSecureSessionHolder.Get());
}
else
{
mDelegate->OnSessionEstablishmentError(err);
}
PairingSession::Finish();
}

void PASESession::Clear()
Expand All @@ -110,29 +96,6 @@ void PASESession::Clear()
mKeLen = sizeof(mKe);
mPairingComplete = false;
PairingSession::Clear();
CloseExchange();
}

void PASESession::CloseExchange()
{
if (mExchangeCtxt != nullptr)
{
mExchangeCtxt->Close();
mExchangeCtxt = nullptr;
}
}

void PASESession::DiscardExchange()
{
if (mExchangeCtxt != nullptr)
{
// Make sure the exchange doesn't try to notify us when it closes,
// since we might be dead by then.
mExchangeCtxt->SetDelegate(nullptr);
// Null out mExchangeCtxt so that Clear() doesn't try closing it. The
// exchange will handle that.
mExchangeCtxt = nullptr;
}
}

CHIP_ERROR PASESession::Init(SessionManager & sessionManager, uint32_t setupCode, SessionEstablishmentDelegate * delegate)
Expand Down
16 changes: 1 addition & 15 deletions src/protocols/secure_channel/PASESession.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@
#include <messaging/ExchangeDelegate.h>
#include <messaging/ExchangeMessageDispatch.h>
#include <protocols/secure_channel/Constants.h>
#include <protocols/secure_channel/SessionEstablishmentDelegate.h>
#include <protocols/secure_channel/PairingSession.h>
#include <protocols/secure_channel/SessionEstablishmentExchangeDispatch.h>
#include <system/SystemPacketBuffer.h>
#include <transport/CryptoContext.h>
#include <transport/PairingSession.h>
#include <transport/raw/MessageHeader.h>
#include <transport/raw/PeerAddress.h>

Expand Down Expand Up @@ -208,19 +207,8 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler,
void OnSuccessStatusReport() override;
CHIP_ERROR OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode) override;

// TODO: pull up Finish to PairingSession class
void Finish();

void CloseExchange();

/**
* Clear our reference to our exchange context pointer so that it can close
* itself at some later time.
*/
void DiscardExchange();

SessionEstablishmentDelegate * mDelegate = nullptr;

Protocols::SecureChannel::MsgType mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_PakeError;

#ifdef ENABLE_HSM_SPAKE
Expand All @@ -242,8 +230,6 @@ class DLL_EXPORT PASESession : public Messaging::UnsolicitedMessageHandler,
uint16_t mSaltLength = 0;
uint8_t * mSalt = nullptr;

Messaging::ExchangeContext * mExchangeCtxt = nullptr;

struct Spake2pErrorMsg
{
Spake2pErrorType error;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* limitations under the License.
*/

#include <transport/PairingSession.h>
#include <protocols/secure_channel/PairingSession.h>

#include <lib/core/CHIPTLVTypes.h>
#include <lib/support/SafeInt.h>
Expand Down Expand Up @@ -57,6 +57,37 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress &
return CHIP_NO_ERROR;
}

void PairingSession::Finish()
{
Transport::PeerAddress address = mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress();

// Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that.
DiscardExchange();

CHIP_ERROR err = ActivateSecureSession(address);
if (err == CHIP_NO_ERROR)
{
mDelegate->OnSessionEstablished(mSecureSessionHolder.Get());
}
else
{
mDelegate->OnSessionEstablishmentError(err);
}
}

void PairingSession::DiscardExchange()
{
if (mExchangeCtxt != nullptr)
{
// Make sure the exchange doesn't try to notify us when it closes,
// since we might be dead by then.
mExchangeCtxt->SetDelegate(nullptr);
// Null out mExchangeCtxt so that Clear() doesn't try closing it. The
// exchange will handle that.
mExchangeCtxt = nullptr;
}
}

CHIP_ERROR PairingSession::EncodeMRPParameters(TLV::Tag tag, const ReliableMessageProtocolConfig & mrpConfig,
TLV::TLVWriter & tlvWriter)
{
Expand Down Expand Up @@ -109,6 +140,19 @@ CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TL

void PairingSession::Clear()
{
// Clear acts like the destructor if PairingSession, if it is call during
// middle of a pairing, means we should terminate the exchange. For normal
// path, the exchange should already be discarded before calling Clear.
if (mExchangeCtxt != nullptr)
{
// The only time we reach this is if we are getting destroyed in the
// middle of our handshake. In that case, there is no point trying to
// do MRP resends of the last message we sent, so abort the exchange
// instead of just closing it.
mExchangeCtxt->Abort();
mExchangeCtxt = nullptr;
}

if (mSessionManager != nullptr)
{
if (mSecureSessionHolder && !mSecureSessionHolder->AsSecureSession()->IsActiveSession())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <lib/core/CHIPTLV.h>
#include <messaging/ExchangeContext.h>
#include <protocols/secure_channel/Constants.h>
#include <protocols/secure_channel/SessionEstablishmentDelegate.h>
#include <protocols/secure_channel/StatusReport.h>
#include <transport/CryptoContext.h>
#include <transport/SecureSession.h>
Expand Down Expand Up @@ -95,6 +96,10 @@ class DLL_EXPORT PairingSession

CHIP_ERROR ActivateSecureSession(const Transport::PeerAddress & peerAddress);

void Finish();

void DiscardExchange(); // Clear our reference to our exchange context pointer so that it can close itself at some later time.

void SetPeerSessionId(uint16_t id) { mPeerSessionId.SetValue(id); }
virtual void OnSuccessStatusReport() {}
virtual CHIP_ERROR OnFailureStatusReport(Protocols::SecureChannel::GeneralStatusCode generalCode, uint16_t protocolCode)
Expand Down Expand Up @@ -168,7 +173,9 @@ class DLL_EXPORT PairingSession
SessionHolder mSecureSessionHolder;
// mSessionManager is set if we actually allocate a secure session, so we
// can clean it up later as needed.
SessionManager * mSessionManager = nullptr;
SessionManager * mSessionManager = nullptr;
Messaging::ExchangeContext * mExchangeCtxt = nullptr;
SessionEstablishmentDelegate * mDelegate = nullptr;

// mLocalMRPConfig is our config which is sent to the other end and used by the peer session.
// mRemoteMRPConfig is received from other end and set to our session.
Expand Down
1 change: 1 addition & 0 deletions src/protocols/secure_channel/tests/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ chip_test_suite("tests") {
# "TestMessageCounterManager.cpp",
"TestDefaultSessionResumptionStorage.cpp",
"TestPASESession.cpp",
"TestPairingSession.cpp",
"TestSimpleSessionResumptionStorage.cpp",
"TestStatusReport.cpp",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@

#include <lib/core/CHIPCore.h>
#include <lib/support/CodeUtils.h>

#include <messaging/ReliableMessageProtocolConfig.h>
#include <transport/PairingSession.h>

#include <lib/support/UnitTestRegistration.h>
#include <messaging/ReliableMessageProtocolConfig.h>
#include <protocols/secure_channel/PairingSession.h>
#include <stdarg.h>
#include <system/SystemClock.h>
#include <system/TLVPacketBufferBackingStore.h>
Expand Down
1 change: 0 additions & 1 deletion src/transport/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ static_library("transport") {
"MessageCounter.cpp",
"MessageCounter.h",
"MessageCounterManagerInterface.h",
"PairingSession.cpp",
"PeerMessageCounter.h",
"SecureMessageCodec.cpp",
"SecureMessageCodec.h",
Expand Down
1 change: 0 additions & 1 deletion src/transport/SessionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
#include <protocols/secure_channel/Constants.h>
#include <transport/GroupPeerMessageCounter.h>
#include <transport/GroupSession.h>
#include <transport/PairingSession.h>
#include <transport/SecureMessageCodec.h>
#include <transport/TransportMgr.h>

Expand Down
2 changes: 0 additions & 2 deletions src/transport/SessionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@

namespace chip {

class PairingSession;

/**
* @brief
* Tracks ownership of a encrypted packet buffer.
Expand Down
1 change: 0 additions & 1 deletion src/transport/tests/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ chip_test_suite("tests") {

test_sources = [
"TestGroupMessageCounter.cpp",
"TestPairingSession.cpp",
"TestPeerConnections.cpp",
"TestPeerMessageCounter.cpp",
"TestSecureSession.cpp",
Expand Down

0 comments on commit 43ead81

Please sign in to comment.